2017-10-20 108 views
0

我有以下代碼:Tensorflow Dataset.from_tensor_slices時間太長

data = np.load("data.npy") 
print(data) # Makes sure the array gets loaded in memory 
dataset = tf.contrib.data.Dataset.from_tensor_slices((data)) 

文件"data.npy"爲3.3 GB。用numpy讀取文件需要幾秒鐘,但是接下來創建tensorflow數據集對象的那一行需要很長時間才能執行。這是爲什麼?它在底下做了什麼?

回答

2

引用此answer

一個npznp.load只返回一個文件加載器,而不是實際的數據。這是一個'懶惰的加載程序',只有在訪問時加載特定的數組。

這就是爲什麼它很快。

編輯1:擴大多一點這樣的回答,從tensorflow's documentation另一句名言:

如果所有輸入數據存放在內存中,最簡單的方法來創建他們Dataset是轉換他們到tf.Tensor對象並使用Dataset.from_tensor_slices()

這適用於小數據集,但浪費內存---因爲數組內容將被複制多次---並可能運行到tf.GraphDef協議緩衝區的2GB限制。

該鏈接還顯示如何有效地做到這一點。

+0

如果我嘗試打印'data',以便確保它實際上被加載,它仍然需要幾秒鐘,而'Dataset'需要幾分鐘。 – niko

+0

它不一定打印所有數據。打印與否並不能確保它「實際上被加載」。我不是tensorflow方面的專家,只是看着'from_tensor_slices'循環遍歷整個數據集的代碼(並且速度相當慢),這肯定會*加載所有數據。海事組織這可能可以加快,但公平我沒有嘗試過。在某些情況下,如果您的計算機在內存中佔用3.3GB的空間,您可能只需要投入更多硬件。 – Iguananaut

+0

我更新了我的答案,給你更多的細節。 –