Out of memory when using torch.serialize in a thread

I'm trying to add a parallel data loader to torch-dataframe in order to add torchnet compatibility. I've used tnt.ParallelDatasetIterator and changed it so that:

  1. the basic batch loading is done outside the threads
  2. the batch is serialized and sent to a thread
  3. in the thread, the batch is deserialized and the batch data is converted to tensors
  4. the tensors are returned in a table with input and target keys to match the tnt.Engine setup (the flow is sketched right after this list)
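
Condensed into pseudo-Lua, the intended flow of the four steps looks roughly like this (method names such as get_batch and to_tensor are taken from the code further down; this is only a sketch of the flow, not the actual implementation):

-- 1. load the batch in the main thread
local batch = dataset:get_batch(batch_size)
-- 2. serialize it so it can cross the thread boundary
local payload = torch.serialize(batch)
-- 3. inside the worker: deserialize and convert the batch data to tensors
local worker_batch = torch.deserialize(payload)
local input, target = worker_batch:to_tensor()
-- 4. return a table shaped for tnt.Engine
return { input = input, target = target }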

The problem occurs the second time enqueue is called, with the error: .../torch_distro/install/bin/luajit: not enough memory. I'm currently only using mnist with an adapted mnist-example. The enqueue loop now looks like this (with debug memory output):

-- `samplePlaceholder` stands in for samples which have been
-- filtered out by the `filter` function
local samplePlaceholder = {}

-- The enqueue does the main loop
local idx = 1
local function enqueue()
  while idx <= size and threads:acceptsjob() do
    local batch, reset = self.dataset:get_batch(batch_size)

    if (reset) then
      idx = size + 1
    else
      idx = idx + 1
    end

    if (batch) then
      local serialized_batch = torch.serialize(batch)

      -- Only the to_tensor call runs in the parallel section, but that
      -- should be the computationally expensive operation
      threads:addjob(
        function(argList)
          io.stderr:write("\n Start")
          io.stderr:write("\n 1: " .. tostring(collectgarbage("count")))
          local origIdx, serialized_batch, samplePlaceholder = unpack(argList)

          io.stderr:write("\n 2: " .. tostring(collectgarbage("count")))
          local batch = torch.deserialize(serialized_batch)
          serialized_batch = nil

          collectgarbage()
          collectgarbage()

          io.stderr:write("\n 3: " .. tostring(collectgarbage("count")))
          batch = transform(batch)

          io.stderr:write("\n 4: " .. tostring(collectgarbage("count")))
          local sample = samplePlaceholder
          if (filter(batch)) then
            sample = {}
            sample.input, sample.target = batch:to_tensor()
          end
          io.stderr:write("\n 5: " .. tostring(collectgarbage("count")))

          collectgarbage()
          collectgarbage()
          io.stderr:write("\n 6: " .. tostring(collectgarbage("count")))

          io.stderr:write("\n End \n")
          return {
            sample,
            origIdx
          }
        end,
        function(argList)
          sample, sampleOrigIdx = unpack(argList)
        end,
        {idx, serialized_batch, samplePlaceholder}
      )
    end
  end
end
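
For reference, here is a minimal sketch of the threading contract used above, assuming the standard torch 'threads' package: the first function passed to addjob runs in a worker thread, its return values are handed to the second function, and that callback runs on the main thread when dojob is called.

local threads = require 'threads'
local pool = threads.Threads(2)

pool:addjob(
  function() return 42 end,          -- runs in a worker thread
  function(x) print('got', x) end    -- runs in the main thread during dojob()
)
pool:dojob()                         -- executes one pending main-thread callback
pool:terminate()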

I've sprinkled collectgarbage calls throughout and tried removing any objects that aren't needed. The memory output is fairly straightforward:

Start 
1: 374840.87695312 
2: 374840.94433594 
3: 372023.79101562 
4: 372023.85839844 
5: 372075.41308594 
6: 372023.73632812 
End 
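
Note that collectgarbage("count") reports the size of the Lua heap in kilobytes, so the figures above amount to roughly 366 MB:

print(string.format("%.1f MB", 374840.87695312 / 1024))  -- prints 366.1 MB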

The function that loops over enqueue is the non-ordered iterator function, which is trivial (the out-of-memory error is thrown at the second enqueue):

iterFunction = function()
  while threads:hasjob() do
    enqueue()
    threads:dojob()
    if threads:haserror() then
      threads:synchronize()
    end
    enqueue()

    if table.exact_length(sample) > 0 then
      return sample
    end
  end
end
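
Assuming this closure is what the iterator's run() method hands back (as with torchnet's tnt.DatasetIterator), it would be consumed with the usual generic-for pattern:

for sample in iterator:run() do
  -- sample.input and sample.target feed straight into the tnt.Engine hooks
end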

Answer

So the problem was torch.serialize, where a function defined in the setup coupled the entire dataset to that function. When adding:

serialized_batch = nil 
collectgarbage() 
collectgarbage() 

the problem was resolved. I further wanted to know what was taking up so much space, and the culprit turned out to be that I had defined the function in an environment containing a large dataset, which became intertwined with the function and massively increased its size. Here is the original definition of the data local:

mnist = require 'mnist'
local dataset = mnist[mode .. 'dataset']()

-- PROBLEMATIC LINE BELOW --
local ext_resource = dataset.data:reshape(dataset.data:size(1),
  dataset.data:size(2) * dataset.data:size(3)):double()

-- Create a Dataframe with the label. The actual images will be loaded
-- as an external resource
local df = Dataframe(
  Df_Dict{
    label = dataset.label:totable(),
    row_id = torch.range(1, dataset.data:size(1)):totable()
  })

-- Since the mnist package already has taken care of the data
-- splitting we create a single subsetter
df:create_subsets{
  subsets = Df_Dict{core = 1},
  class_args = Df_Tbl({
    batch_args = Df_Tbl({
      label = Df_Array("label"),
      data = function(row)
        return ext_resource[row.row_id]
      end
    })
  })
}
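
What is going on mechanically (a hedged, self-contained demonstration, not code from the project): the data retriever closes over ext_resource, and torch.serialize stores a closure's upvalues along with the closure itself, so every serialized batch drags the full double tensor behind it. A 60000 x 784 tensor of doubles is almost exactly the 358 MB reported below:

local big = torch.DoubleTensor(60000, 784)  -- same shape as ext_resource for mnist

local retriever = function(row)
  return big[row.row_id]                    -- `big` is captured as an upvalue
end

local str = torch.serialize(retriever)
print(#str / 1024^2)                        -- ~358.9: the upvalue is serialized too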

It turns out that removing the highlighted line reduced the memory usage of the original definition from 358 MB down to 0.0008 MB! The code I used for testing the performance was:

local mem = {} 
table.insert(mem, collectgarbage("count")) 

local ser_data = torch.serialize(batch.dataset) 
table.insert(mem, collectgarbage("count")) 

local ser_retriever = torch.serialize(batch.batchframe_defaults.data) 
table.insert(mem, collectgarbage("count")) 

local ser_raw_retriever = torch.serialize(function(row) 
    return ext_resource[row.row_id] 
end) 
table.insert(mem, collectgarbage("count")) 

local serialized_batch = torch.serialize(batch) 
table.insert(mem, collectgarbage("count")) 

for i=2,#mem do 
    print(i-1, (mem[i] - mem[i-1])/1024) 
end 

which produced the following output for the original definition:

1 0.0082607269287109
2 358.23344707489
3 0.0017471313476562
4 358.90182781219

and after the fix:

1 0.0094480514526367
2 0.00080204010009766
3 0.00090408325195312
4 0.010146141052246

I tried using setfenv for the function, but it did not solve the issue. There is still a performance penalty for sending the serialized data to the thread, but the main problem is resolved, and without the expensive data retriever the function is considerably smaller.
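
As for why setfenv could not help (my reading, shown with another hedged sketch): setfenv only replaces the table a function uses for global lookups, while ext_resource is a lexical upvalue, and upvalues are not part of the environment:

local big = torch.DoubleTensor(60000, 784)
local retriever = function(row) return big[row.row_id] end

setfenv(retriever, {})                       -- swaps the *global* environment only
print(#torch.serialize(retriever) / 1024^2)  -- still ~358.9: the upvalue rides along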