2016-12-01 18 views
1

我爲了自己的目的移植了Tensorflow Cifar-10教程文件,並遇到了一個有趣的問題,由於Tensorflow的圖形和會話體系結構,我無法輕鬆進行概念化。如何複製Tensorflow隊列中張量屬性(「過採樣」)條件下的輸入張量?

問題在於我的輸入數據集高度不平衡,因此我需要在輸入管道中「過度採樣」(並擴充)輸入管道中的某些圖像,並以其標籤爲條件。在一個正常的Python環境中,我可以設置一個簡單的控制流程語句,格式爲if label then duplicate,但由於控制流操作存在於正在運行的會話之外,因此我無法在Tensorflow中編寫相同的語法,在此情況下,此操作不支持label返回一個值。

我的問題是,對Tensorflow隊列內張量進行過採樣的最簡單方法是什麼?

我知道我可以在輸入操作之前簡單地複製感興趣的數據,但是這顯然消除了運行時過採樣導致的存儲節省。

我想要做的是評估張量的標籤(在Cifar-10的情況下,通過檢查1D image.label屬性),然後通過固定因子複製該張量(例如,如果標籤是「狗」 「)並將所有的張量發送到配料操作。我最初的做法是在讀取操作之後和批處理操作之前嘗試複製步驟,但這也是在運行會話之外。我正在考慮使用TF的while控制流量表,但我不確定這個功能是否能夠執行除輸入張量外的其他任何操作。你怎麼看?取決於的值


更新#1

基本上我試圖創建了在扁平圖像字節和標籤字節的py_func(),並垂直堆疊在同一圖像字節N次標籤,然後將其作爲(N x image_bytes)張量(py_func()自動將輸入張量轉換爲numpy並返回)返回。我試圖從形狀報告爲(?,image_bytes)的可變高度張量創建一個input_queue,然後實例化一個閱讀器來撕掉image_byte大小記錄。好吧,你似乎無法建立未知數據大小的隊列,所以這種方法對我來說並不合適,這在後見之明是有道理的,但我仍然無法概念化一種方法來識別隊列中的記錄,並重復記錄具體次數。


更新#2

48小時我終於想出了一個解決辦法,這要歸功於this SO thread,我能挖後嘛。該線程中概述的解決方案僅假設2類數據,因此如果pred爲True,則tf.cond()函數可以對一個類進行過採樣,如果pred爲False,則可以對另一個類進行過採樣。爲了有一個n路有條件的,我試圖建立一個tf.case()功能,導致ValueError: Cannot infer Tensor's rank。原因是tf.case()函數不保留shape屬性,並且由於輸入流水線末端的任何批處理操作必須採用形狀參數,或者採用定義形狀的張量,因此按照documentation中的註釋:

注意:您必須確保(i)形狀參數已通過,或(ii)張量中的所有張量都必須具有完全定義的形狀。如果這兩個條件都不成立,ValueError將會被提升。

進一步深挖表明,這是一個known issuetf.case(),目前尚未得到解決的2016年十二月在Tensorflow許多控制流頭刮管的只有一個問題。無論如何,我的簡裝解決n路的過採樣的問題是這樣的:

# Initiate a queue of "raw" input data with embedded Queue Runner. 
queue = tf.train.string_input_producer(rawdata_filename) 

# Instantiate Reader Op to read examples from files in the filename queue. 
reader = tf.FixedLengthRecordReader(record_bytes) 

# Pull off one instance, decode and cast image and label to 3D, 1D Tensors. 
result.key, value = reader.read(queue) 
image_raw, label_raw = decode(value) 
image = tf.cast(image_raw, dtype) #3D tensor 
label = tf.cast(label_raw, dtype) #1D tensor 

# Assume your oversampling factors per class are fixed 
# and you have 4 classes. 
OVERSAMPLE_FACTOR = [1,2,4,10] 

# Now we need to reshape input image tensors to 4D, where the 
# first dimension is the image number in a batch of oversampled tensors. 
# images = tf.expand_dims(image, 0) # so, (*,height,width,channels) in 4D 

# Set up your predicates, which are 1D boolean tensors. 
# Note you will have to squash the boolean tensors to 0-dimension. 
# This seems illogical to me, but it is what it is. 
pred0 = tf.reshape(tf.equal(label, tf.convert_to_tensor([0])), []) #0D tf.bool 
pred1 = tf.reshape(tf.equal(label, tf.convert_to_tensor([1])), []) #0D tf.bool 
pred2 = tf.reshape(tf.equal(label, tf.convert_to_tensor([2])), []) #0D tf.bool 
pred3 = tf.reshape(tf.equal(label, tf.convert_to_tensor([3])), []) #0D tf.bool 

# Build your callables (functions) that vertically stack an input image and 
# label tensors X times depending on the accompanying oversample factor. 
def f0(): return tf.concat(0, [images]*OVERSAMPLE_FACTOR[0]), tf.concat(0, [label]*OVERSAMPLE_FACTOR[0]) 
def f1(): return tf.concat(0, [images]*OVERSAMPLE_FACTOR[1]), tf.concat(0, [label]*OVERSAMPLE_FACTOR[1]) 
def f2(): return tf.concat(0, [images]*OVERSAMPLE_FACTOR[2]), tf.concat(0, [label]*OVERSAMPLE_FACTOR[2]) 
def f3(): return tf.concat(0, [images]*OVERSAMPLE_FACTOR[3]), tf.concat(0, [label]*OVERSAMPLE_FACTOR[3]) 

# Here we have N conditionals, one for each class. These are exclusive 
# but due to tf.case() not behaving every conditional gets evaluated. 
[images, label] = tf.cond(pred0, f0, lambda: [images,label]) 
[images, label] = tf.cond(pred1, f1, lambda: [images,label]) 
[images, label] = tf.cond(pred2, f2, lambda: [images,label]) 
[images, label] = tf.cond(pred3, f3, lambda: [images,label]) 

# Pass the 4D batch of oversampled tensors to a batching op at the end 
# of the input data queue. The batching op must be set up to accept 
# batches of tensors (4D) as opposed to individual tensors (in our case, 3D). 
images, label_batch = tf.train.batch([images, label], 
            batch_size=batch_size, 
            num_threads=num_threads, 
            capacity=capacity, 
            enqueue_many = True) #accept batches 
+0

多煩人。我遇到了同樣的問題,期望'tf.case'能夠像開關盒那樣工作,但是另外就像'tf.cond'一樣。 –

回答

1

我的問題的解決方案是一個解決辦法,並在「更新2」在原來的問題概述。