我想部署一個簡單的TensorFlow模型,並在像Flask這樣的REST服務中運行它。 在github上或這裏找不到目前爲止的好例子。TensorFlow REST前端,但不是TensorFlow服務
我不準備使用TF擔任其他職位的建議,這是谷歌完美的解決方案,但它矯枉過正我與GRPC,巴澤爾,C++編碼,protobuf的任務......
我想部署一個簡單的TensorFlow模型,並在像Flask這樣的REST服務中運行它。 在github上或這裏找不到目前爲止的好例子。TensorFlow REST前端,但不是TensorFlow服務
我不準備使用TF擔任其他職位的建議,這是谷歌完美的解決方案,但它矯枉過正我與GRPC,巴澤爾,C++編碼,protobuf的任務......
有不同的方法來做到這一點。純粹地,使用張量流並不是非常靈活,然而相對簡單。這種方法的缺點是您必須重新生成圖形並在恢復模型的代碼中初始化變量。有一種方法顯示在tensorflow skflow/contrib learn這是更優雅,但是這似乎並沒有功能的時刻和文檔已過時。
我在github here上放了一個簡短的例子,它展示瞭如何將GET或POST參數命名爲瓶式REST部署的tensorflow模型。
主要代碼然後在需要立足崗位字典的功能/ GET數據:
@app.route('/model', methods=['GET', 'POST'])
@parse_postget
def apply_model(d):
tf.reset_default_graph()
with tf.Session() as session:
n = 1
x = tf.placeholder(tf.float32, [n], name='x')
y = tf.placeholder(tf.float32, [n], name='y')
m = tf.Variable([1.0], name='m')
b = tf.Variable([1.0], name='b')
y = tf.add(tf.mul(m, x), b) # fit y_i = m * x_i + b
y_act = tf.placeholder(tf.float32, [n], name='y_')
error = tf.sqrt((y - y_act) * (y - y_act))
train_step = tf.train.AdamOptimizer(0.05).minimize(error)
feed_dict = {x: np.array([float(d['x_in'])]), y_act: np.array([float(d['y_star'])])}
saver = tf.train.Saver()
saver.restore(session, 'linear.chk')
y_i, _, _ = session.run([y, m, b], feed_dict)
return jsonify(output=float(y_i))
這github project顯示恢復模型檢查點並使用Flask的工作示例。
@app.route('/api/mnist', methods=['POST'])
def mnist():
input = ((255 - np.array(request.json, dtype=np.uint8))/255.0).reshape(1, 784)
output1 = simple(input)
output2 = convolutional(input)
return jsonify(results=[output1, output2])
在線demo看起來很快。
我不喜歡把數據/模型處理太多的代碼在燒瓶寧靜的文件。我通常有tf模型課等。 即它可能是這樣的:
# model init, loading data
cifar10_recognizer = Cifar10_Recognizer()
cifar10_recognizer.load('data/c10_model.ckpt')
@app.route('/tf/api/v1/SomePath', methods=['GET', 'POST'])
def upload():
X = []
if request.method == 'POST':
if 'photo' in request.files:
# place for uploading process workaround, obtaining input for tf
X = generate_X_c10(f)
if len(X) != 0:
# designing desired result here
answer = np.squeeze(cifar10_recognizer.predict(X))
top3 = (-answer).argsort()[:3]
res = ([cifar10_labels[i] for i in top3], [answer[i] for i in top3])
# you can simply print this to console
# return 'Prediction answer: {}'.format(res)
# or generate some html with result
return fk.render_template('demos/c10_show_result.html',
name=file,
result=res)
if request.method == 'GET':
# in html I have simple form to upload img file
return fk.render_template('demos/c10_classifier.html')
cifar10_recognizer.predict(X)是簡單FUNC,在TF會話中運行預測操作:
def predict(self, image):
logits = self.sess.run(self.model, feed_dict={self.input: image})
return logits
P.S.從文件中保存/恢復模型是一個非常漫長的過程,儘量避免這種情況,同時服務發佈/獲取請求
縮小問題只是想在保存器加載模型後返回結果的Flask示例 – chro