2017-09-28 73 views
0

我通常在張量流中處理索引張量。 我有圖像數據和額外的標量數據。我只能使用一個佔位符將所有數據輸入到神經網絡。Tensorflow:通過單個佔位符輸入封裝數據批量

圖像(img)是形狀爲(84,84,3)的numpy陣列,我的數據a的形狀爲(2),b的形狀爲(1)

現在我創建一個單個樣本

sample = np.reshape(np.array([img,a,b]),(3,1)) #shape (3,1) 

佔位符是

input = tf.placeholder(dtype=tf.float32,shape=[None] + list(sample.shape)) 

現在,當TF讀取一批樣品,我想取回一批圖像,該批次的,和一批b,因爲它們需要在神經網絡中的不同位置輸入。

這裏是一個小例子:

import tensorflow as tf 
from tensorflow.contrib import layers 
import numpy as np 

#Numpy 
img = np.random.rand(84,84,3) 
a = np.random.rand(2) 
b = np.random.rand(1) 
sample = np.reshape(np.array([img,a,b]),(3,1)) #shape (3,1) 
batch = np.repeat(np.expand_dims(sample,axis=0),32,axis=0) #shape (32,3,1) 

#TF 
input = tf.placeholder(dtype=tf.float32,shape=[None] + list(sample.shape)) 

#TODO: 
tf_img = tf.#get image batch from input 
tf_a = tf.#get a batch from input 
tf_b = tf.#get b batch from input 

out = layers.convolution2d(tf_img,num_outputs=64,kernel_size=8,stride=2,activation_fn=tf.nn.relu) 
out = layers.flatten(out) 
out = tf.concat([out,tf_a,tf_b]) 
out = layers.fully_connected(out,10,activation_fn=tf.nn.relu) 

init = tf.global_variables_initializer() 

with tf.Session() as sess: 
    sess.run(init) 
    _ = sess.run(out,feed_dict={input:batch}) 

我怎樣才能提取輸入的各個部分從與形狀(?,3,1)的張量,使用所述圖像數據以創建一個嵌入並連接其他兩個部分,以該輸出嵌入。

有沒有更好的方式輸入數據?我唯一的約束是它必須是一個單獨的佔位符。

+0

請說明爲什麼你只限於一個佔位符!如果輸入數據的所有維都是已知的,則可以將每個樣本換成單個形狀張量'(None,84 * 84 * 3 + 2 + 1)'並使用[slicing](https://www.tensorflow。 org/api_docs/python/tf/slice)和[重塑](https://www.tensorflow.org/api_docs/python/tf/reshape)。 – chrert

+0

感謝您的回答。我正在處理一個大型代碼模板,並且我想盡可能少地進行更改,因爲它非常複雜(並且假定只有一個圖像輸入)。我也想知道這是否可能。你能提供一個更大的例子嗎? – user3142067

回答

0

下面是我的上述評論的完整例子:

import numpy as np 
import tensorflow as tf 

im_height = 84 
im_width = 84 
im_channels = 3 
a_len = 2 
b_len = 1 

np_img = np.random.rand(im_height, im_width, im_channels) 
np_a = np.random.rand(a_len) 
np_b = np.random.rand(b_len) 

# flatten the input and concatenate to a single 1D numpy array 
np_sample = np.concatenate((np_img.reshape(-1), np_a.reshape(-1), np_b.reshape(-1)), axis=0) 
# construct a pseudo batch 
np_batch = np.repeat(np_sample[np.newaxis, :], 32, axis=0) 

tf_batch = tf.placeholder(shape=(None, im_height*im_width*im_channels + a_len + b_len), dtype=tf.float32) 

img_stop = im_height*im_width*im_channels 
a_stop = img_stop+a_len 

# you could also use tf.slice(...) here 
tf_img = tf.reshape(tf_batch[:, 0:img_stop], (-1, im_height, im_width, im_channels)) 
tf_a = tf.reshape(tf_batch[:, img_stop:a_stop], (-1, a_len)) 
tf_b = tf.reshape(tf_batch[:, a_stop:], (-1, b_len)) 

with tf.Session() as sess: 
    fetch_dict = {'img': tf_img, 'a': tf_a, 'b': tf_b} 
    feed_dict = {tf_batch: np_batch} 
    res = sess.run(fetch_dict, feed_dict=feed_dict) 

assert(np.isclose(res['img'][0, ...], np_img).all()) 
assert(np.isclose(res['a'][0, :], np_a).all()) 
assert(np.isclose(res['b'][0, :], np_b).all()) 

但是,這至少是添加適當的佔位符的代碼侵入。另外,在我看來,它的可讀性要差得多。

相關問題