2016-09-25 37 views
2

我想在Android應用程序中使用我的Tensorflow算法。 Tensorflow Android示例首先下載一個包含模型定義和權重的GraphDef(在* .pb文件中)。現在這應該來自我的Scikit Flow算法(Tensorflow的一部分)。Tensorflow Scikit Flow獲取GraphDef for Android(保存* .pb文件)

第一眼看起來很簡單,你只需要說classifier.save('model /'),但保存到該文件夾​​的文件不是* .ckpt,* .def,當然不是* .pb。相反,你必須處理一個* .pbtxt和一個檢查點(沒有結束)文件。

我在那裏呆了很長一段時間。下面的代碼示例,出口的東西:

#imports 
import tensorflow as tf 
import tensorflow.contrib.learn as skflow 
import tensorflow.contrib.learn.python.learn as learn 
from sklearn import datasets, metrics 

#skflow example 
iris = datasets.load_iris() 
feature_columns = learn.infer_real_valued_columns_from_input(iris.data) 
classifier = learn.LinearClassifier(n_classes=3, feature_columns=feature_columns,model_dir="modeltest") 
classifier.fit(iris.data, iris.target, steps=200, batch_size=32) 
iris_predictions = list(classifier.predict(iris.data, as_iterable=True)) 
score = metrics.accuracy_score(iris.target, iris_predictions) 
print("Accuracy: %f" % score) 

的文件,你得到的是:

  • 關卡
  • graph.pbtxt
  • model.ckpt-1.meta
  • model.ckpt -1-00000-of-00001
  • model.ckpt-200.meta
  • model.ckpt-200-00 000-of-00001

我發現的許多可能的解決方法都需要在變量中使用GraphDef(不知道如何使用Scikit Flow)。或者使用Scikit Flow似乎不需要的Tensorflow會話。

+0

您是否設法找到解決方案? – idoshamun

+0

我決定使用Scikit Flow進行實驗(我的NN需要多少層),然後用純張量流重建模型。然後,我通過創建第二個模型,將已經訓練的權重作爲常量(切換到iOS,但可能與Android相同)來避免整個freeze_graph bazel的內容。這不是一個真正的建議,只是我採取的路徑 – CodingYourLife

回答

1

要保存爲pb文件,您需要從構建的圖中提取graph_def。你可以做到這一點as--

from tensorflow.python.framework import tensor_shape, graph_util 
from tensorflow.python.platform import gfile 
sess = tf.Session() 
final_tensor_name = 'results:0'  #Replace final_tensor_name with name of the final tensor in your graph 
#########Build your graph and train######## 
## Your tensorflow code to build the graph 
########################################### 

outpt_filename = 'output_graph.pb' 
output_graph_def = sess.graph.as_graph_def() 
with gfile.FastGFile(outpt_filename, 'wb') as f: 
    f.write(output_graph_def.SerializeToString()) 

如果你想你的訓練有素的變量轉化爲常量(避免使用CKPT文件加載權重),你可以使用:

output_graph_def = graph_util.convert_variables_to_constants(sess, sess.graph.as_graph_def(), [final_tensor_name]) 

希望這有助於!

+0

說「名稱」旗幟'沒有定義「。它似乎工作,如果我把目標路徑改爲* .pb文件。另外我不知道你的意思是[network_prefix + final_tensor_name]。看來我可以把輸出佔位符(因爲字符串似乎需要一個名字) – CodingYourLife

+0

@CodingYourLife:是的。我編輯了這個改變。用張量或張量名稱(輸出節點)替換final_tensor_name,它將正確地輸出輸出圖形 –

相關問題