2017-10-09 54 views
3

我試圖使用tensorflow.contrib.learn.KMeansClustering作爲Tensorflow中圖的一部分。我想用它作爲圖的組成部分,給我預測和中心。代碼的相關部分如下:在張量流中使用KMeans tflearn估計器作爲圖的一部分

with tf.variable_scope('kmeans'): 
    kmeans = KMeansClustering(num_clusters=num_clusters, 
           relative_tolerance=0.0001) 
    kmeans.fit(input_fn= (lambda : [X, None])) 
    clusters = kmeans.clusters() 

init_vars = tf.global_variables_initializer() 
sess = tf.Session() 
sess.run(init_vars, feed_dict={X: full_data_x}) 
clusters_np = sess.run(clusters, feed_dict={X: full_data_x}) 

不過,我得到以下錯誤:

ValueError: Tensor("kmeans/strided_slice:0", shape=(), dtype=int32) must be from the same graph as Tensor("sub:0", shape=(), dtype=int32). 

我相信這是因爲k-平均算法是TFLearn估計;這將比單個模塊更類似於整個圖。那是對的嗎?我可以將它轉換爲默認圖形的模塊嗎?如果沒有,是否有一個函數可以在另一個圖表中執行KMeans?

謝謝!

回答

1

KMeansClustering估計器使用tf.contrib.factorization中的ops。 factorization MNIST example使用沒有Estimator的KMeans。

+0

我已經查看了您的建議和該函數的[documentation](https://www.tensorflow.org/versions/master/api_docs/python/tf/contrib/factorization/KMeans)。如果我沒有弄錯,他們只提供從輸入到每個輸入的中心和標籤的距離。但是,我正在尋找這些中心,想知道如何獲得它們?謝謝! – etal

+0

看起來不像公共訪問者。您可以使用集羣賦值對點進行平均,也可以用'reuse = True'調用'variable_scople'中的'_create_variables'(您也必須將原始變量創建封裝在變量範圍內)。 –

+0

我可以看到_create_variables輸出中心,但我看不到輸入數據如何/何時受到影響。你如何使用variable_scope來做到這一點? – etal

1

的k-平均算法估算API建立自己的tf.Graph並自行管理tf.Session,這樣你就不會需要運行一個tf.Session養活值(由input_fn完成),這就是爲什麼ValueError出現。

k-平均算法估計的正確用法只是:

kmeans = KMeansClustering(num_clusters=num_clusters, 
          relative_tolerance=0.0001) 
kmeans.fit(input_fn=(lambda: [X, None])) 
clusters = kmeans.clusters() 

其中X是保持的值的輸入tf.constant張量(例如定義爲Xnp.array不是使用tf.convert_to_tensor)。這裏X不是tf.placeholder,需要以tf.Session運行。

更新TensorFlow 1.4:

使用tf.contrib.factorization.KMeansClustering API找到聚類中心:

kmeans=tf.contrib.factorization.KMeansClustering(num_clusters=num_clusters) 
kmeans.train(input_fn=(lambda: [X, None])) 
centers = kmeans.cluster_centers() 

來預測給定功能中心只需使用:

predictions = kmeans.predict(input_fn=(lambda:[another_X, None])) 
0

這裏link是使用方法KMeans聚類 u唱着tf.contrib.factorization.KMeansClustering。它告訴解決方案是通過將輸入Tensor(X)放入input_fn lambda中來延遲創建輸入,該輸入將在train()內調用。那麼你將不會得到上面提到的錯誤。

相關問題