我想在Android應用程序中使用我的Tensorflow算法。 Tensorflow Android示例首先下載一個包含模型定義和權重的GraphDef(在* .pb文件中)。現在這應該來自我的Scikit Flow算法(Tensorflow的一部分)。Tensorflow Scikit Flow獲取GraphDef for Android(保存* .pb文件)
第一眼看起來很簡單,你只需要說classifier.save('model /'),但保存到該文件夾的文件不是* .ckpt,* .def,當然不是* .pb。相反,你必須處理一個* .pbtxt和一個檢查點(沒有結束)文件。
我在那裏呆了很長一段時間。下面的代碼示例,出口的東西:
#imports
import tensorflow as tf
import tensorflow.contrib.learn as skflow
import tensorflow.contrib.learn.python.learn as learn
from sklearn import datasets, metrics
#skflow example
iris = datasets.load_iris()
feature_columns = learn.infer_real_valued_columns_from_input(iris.data)
classifier = learn.LinearClassifier(n_classes=3, feature_columns=feature_columns,model_dir="modeltest")
classifier.fit(iris.data, iris.target, steps=200, batch_size=32)
iris_predictions = list(classifier.predict(iris.data, as_iterable=True))
score = metrics.accuracy_score(iris.target, iris_predictions)
print("Accuracy: %f" % score)
的文件,你得到的是:
- 關卡
- graph.pbtxt
- model.ckpt-1.meta
- model.ckpt -1-00000-of-00001
- model.ckpt-200.meta
- model.ckpt-200-00 000-of-00001
我發現的許多可能的解決方法都需要在變量中使用GraphDef(不知道如何使用Scikit Flow)。或者使用Scikit Flow似乎不需要的Tensorflow會話。
您是否設法找到解決方案? – idoshamun
我決定使用Scikit Flow進行實驗(我的NN需要多少層),然後用純張量流重建模型。然後,我通過創建第二個模型,將已經訓練的權重作爲常量(切換到iOS,但可能與Android相同)來避免整個freeze_graph bazel的內容。這不是一個真正的建議,只是我採取的路徑 – CodingYourLife