我有一些基本的功能,取入的圖像的URL,並通過VGG-16 CNN其轉換:tensorflow多對圖像特徵提取
def convert_url(_id, url):
im = get_image(url)
return _id, np.squeeze(sess.run(end_points['vgg_16/fc7'], feed_dict={input_tensor: im}))
我有一個大組URL(〜60,000)我想在其上執行此功能。每次迭代需要一秒以上的時間,這太慢了。我想通過並行使用多個進程來加速它。沒有共享狀態可以擔心,所以多線程的通常陷阱不是問題。
但是,我不確定如何實際使tensorflow與多處理程序包一起工作。我知道你不能將tensorflow session
傳遞給Pool變量。所以不是,我想初始化的session
多個實例:
def init():
global sess;
sess = tf.Session()
但是當我真正啓動過程中,它只是掛起無限期:
with Pool(processes=3,initializer=init) as pool:
results = pool.starmap(convert_url, list(id_img_dict.items())[0:5])
注意,tensorflow圖被全局定義。我認爲這是正確的做法,但我不確定:
input_tensor = tf.placeholder(tf.float32, shape=(None,224,224,3), name='input_image')
scaled_input_tensor = tf.scalar_mul((1.0/255), input_tensor)
scaled_input_tensor = tf.subtract(scaled_input_tensor, 0.5)
scaled_input_tensor = tf.multiply(scaled_input_tensor, 2.0)
arg_scope = vgg_arg_scope()
with slim.arg_scope(arg_scope):
_, end_points = vgg_16(scaled_input_tensor, is_training=False)
saver = tf.train.Saver()
saver.restore(sess, checkpoint_file)
任何人都可以幫助我實現這個工作嗎?多謝。
如果你在使用在線文件,你應該看看使用異步,這應該會產生很大的加速。 –