2016-05-04 68 views
0

我正在實施球面上的k-means,從@ dga的gist開始。 單位範數約束基本上意味着使用內積而不是成對距離,使用argmax而不是argmin和sum +歸一化來代替平均來更新質心。使用布爾值掩碼和索引列表替換Tensorflow中的某些行的一些行

現在我試圖用最少的表示數據點來替換死亡的質心。 unsorted_segment_sum將返回0之死重心:

total = tf.unsorted_segment_sum(points, best_centroids, K) 

從這個我弄不死質心的布爾掩碼:

deads = tf.equal(total, 0) 

...死了質心的計數:

dead_count = tf.reduce_sum(tf.as_type(deads, 'int64')) 

...並且最後列出了當前模型中表示最差的數據點的索引:

_, dead_replacement_idx = tf.nn.top_k(-assignment_qualities, 
             k=dead_count, sorted=False) 

現在我該如何更換死質心? 在numpy的這個現在將回落到這一點:

means[deads] = points[dead_replacement_idx] 

我怎樣才能做到在Tensorflow類似的東西?

+0

https://github.com/tensorflow/tensorflow/issues/206 ...我覺得有人正在開始工作 –

回答

0

如果您存儲手段Variables可以使用scatter_update

tf.reset_default_graph() 

means = tf.Variable(np.array([[1,1],[2,2],[3,3]]), dtype=np.float32) 
indices = tf.constant([0, 2]) 
new_mean = tf.constant([-1, -1], dtype=np.float32) 

new_mean_matrix = tf.reshape(new_mean, [1, -1]) 
tile_shape = tf.pack([tf.size(indices), 1]) 
new_mean_matrix_tiled = tf.tile(new_mean_matrix, tile_shape) 
update_op = tf.scatter_update(means, indices, new_mean_matrix_tiled) 

sess = tf.InteractiveSession() 
sess.run(tf.initialize_all_variables()) 
print "Before update" 
print sess.run(means) 
print "Updating rows", indices.eval(), "to", new_mean.eval() 
sess.run(update_op) 
print "After update" 
print sess.run(means) 

結果

Before update 
[[ 1. 1.] 
[ 2. 2.] 
[ 3. 3.]] 
Updating rows [0 2] to [-1. -1.] 
After update 
[[-1. -1.] 
[ 2. 2.] 
[-1. -1.]] 
+0

在我的情況下,'means'是一個來自* centroids變量的張量。如果賦值改變了,'centroids'和'cluster_assignments'在control_dependencies():... tf.group()'中被賦值''assign。我想我可以在之後做'scatter_update()',但是我需要另外一個'control_dependencies()'以確保在'assign'之後死亡質心被替換*。 –

+0

確保一組計算髮生在別人之後的最簡單方法是使用兩個不同的運行調用。 IE,'sess.run(update_variables); sess.run(do_calculation)'。如果使用'with_dependencies',確保所有相關操作都在'with_dependencies'塊內創建。 –

相關問題