2017-08-27 89 views
1

我有一些基本的功能,取入的圖像的URL,並通過VGG-16 CNN其轉換:tensorflow多對圖像特徵提取

def convert_url(_id, url): 
    im = get_image(url) 
    return _id, np.squeeze(sess.run(end_points['vgg_16/fc7'], feed_dict={input_tensor: im})) 

我有一個大組URL(〜60,000)我想在其上執行此功能。每次迭代需要一秒以上的時間,這太慢了。我想通過並行使用多個進程來加速它。沒有共享狀態可以擔心,所以多線程的通常陷阱不是問題。

但是,我不確定如何實際使tensorflow與多處理程序包一起工作。我知道你不能將tensorflow session傳遞給Pool變量。所以不是,我想初始化的session多個實例:

def init(): 
    global sess; 
    sess = tf.Session() 

但是當我真正啓動過程中,它只是掛起無限期:

with Pool(processes=3,initializer=init) as pool: 
    results = pool.starmap(convert_url, list(id_img_dict.items())[0:5]) 

注意,tensorflow圖被全局定義。我認爲這是正確的做法,但我不確定:

input_tensor = tf.placeholder(tf.float32, shape=(None,224,224,3), name='input_image') 
scaled_input_tensor = tf.scalar_mul((1.0/255), input_tensor) 
scaled_input_tensor = tf.subtract(scaled_input_tensor, 0.5) 
scaled_input_tensor = tf.multiply(scaled_input_tensor, 2.0) 

arg_scope = vgg_arg_scope() 
with slim.arg_scope(arg_scope): 
    _, end_points = vgg_16(scaled_input_tensor, is_training=False) 
saver = tf.train.Saver() 
saver.restore(sess, checkpoint_file) 

任何人都可以幫助我實現這個工作嗎?多謝。

+0

如果你在使用在線文件,你應該看看使用異步,這應該會產生很大的加速。 –

回答

1

忘掉python的正常多線程工具並使用tensorflow.contrib.data.Dataset。嘗試如下所示。

urls = ['img1.jpg', 'img2.jpg', ...] 
batch_size = 16 
n_batches = len(urls) // batch_size # do something more elegant for remainder 


def load_img(url): 
    image = tf.read_file(url, name='image_data') 
    image = tf.image.decode_jpeg(image, channels=3, name='image') 
    return image 


def preprocess(img_tensor): 
    img_tensor = (tf.cast(img_tensor, tf.float32)/255 - 0.5)*2 
    img_tensor.set_shape((256, 256, 3)) # whatever shape 
    return img_tensor 


dataset = tf.contrib.data.Dataset.from_tensor_slices(urls) 
dataset = dataset.map(load_img).map(preprocess) 

preprocessed_images = dataset.batch(
    batch_size).make_one_shot_iterator().get_next() 


arg_scope = vgg_arg_scope() 
with slim.arg_scope(arg_scope): 
    _, end_points = vgg_16(preprocessed_images, is_training=False) 
    output = end_points['vgg_16/fc7'] 


results = [] 

with tf.Session() as sess: 
    tf.train.Saver().restore(sess, checkpoint_file) 
    for i in range(n_batches): 
     batch_results = sess.run(output) 
     results.extend(batch_results) 
     print('Done batch %d/%d' % (i+1, n_batches)) 
+0

欣賞響應!如果我在本地保存所有文件,這似乎是一個好方法。但是,我只在線存儲URL到jpeg或png文件。顯然,實際使用'urllib'或'requests'來獲取圖像文件本身是微不足道的,但是這會削弱使用'Datasets'來並行化圖形的價值嗎? – anon

+0

我明白了。在這種情況下,你可以很大程度上忽略這個答案。我沒有做過很多正常的python多線程,但我想象如果你並行化數據提取(單獨從tensorflow),你應該能夠使用上面的代碼,批量大於1來顯着提高速度,例如,獲取16張圖片,當所有圖片加載時,都會提供給tf.sess。 – DomJack

+0

我寧願讓每個線程都負責數據獲取和tensorflow轉換,因爲這是一個更清晰的解決方案。你的解決方案是可行的,但我想知道通過多處理可以一次提取16張圖像多快。也就是說,即使我連續抓住這16個圖像,我也可以通過並行運行VGG,這樣就可以工作。儘管如此,我仍然堅持有人知道如何做我最初想要的東西! – anon