Torch out of memory in thread when using torch.serialize twice

Posted: 2016-11-20 03:42:54

【Question】:

I'm trying to add a parallel dataloader to torch-dataframe in order to add torchnet compatibility. I've used tnt.ParallelDatasetIterator and changed it so that:

1. a basic batch is loaded outside the threads
2. the batch is serialized and sent to the thread (a minimal sketch of this hand-off follows the list)
3. in the thread, the batch is deserialized and the batch data is converted to tensors
4. the tensors are returned in a table with input and target keys to match the tnt.Engine setup
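
Here is a minimal sketch of the serialize/deserialize hand-off in steps 2 and 3, assuming the standard torch threads package; the batch content and pool size are illustrative, not from my actual code:

local torch = require 'torch'
local threads = require 'threads'

-- a pool of two workers; each loads torch on startup
local pool = threads.Threads(2, function() require 'torch' end)

local batch = {data = torch.rand(4, 10)}   -- illustrative batch
local serialized = torch.serialize(batch)  -- a plain string, cheap to hand around

pool:addjob(
  function()
    -- rebuild the batch inside the worker thread
    local b = torch.deserialize(serialized)
    return b.data:sum()
  end,
  function(s) print('sum:', s) end
)
pool:synchronize()

Note that serialized is an upvalue of the job function, so the threads library serializes it a second time together with the closure; as it turns out below, hidden upvalues are exactly where this gets expensive.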

The problem occurs at the second call to enqueue, with the error .../torch_distro/install/bin/luajit: not enough memory. I'm currently only using mnist together with an adapted mnist-example. The enqueue loop now looks like this (with debugging memory output):

-- `samplePlaceholder` stands in for samples which have been
-- filtered out by the `filter` function
local samplePlaceholder = {}

-- The enque does the main loop
local idx = 1
local function enqueue()
  while idx <= size and threads:acceptsjob() do
    local batch, reset = self.dataset:get_batch(batch_size)

    if (reset) then
      idx = size + 1
    else
      idx = idx + 1
    end

    if (batch) then
      local serialized_batch = torch.serialize(batch)

      -- In the parallel section only the to_tensor is run in parallel;
      --  this should, though, be the computationally expensive operation
      threads:addjob(
        function(argList)
          io.stderr:write("\n Start");
          io.stderr:write("\n 1: " ..tostring(collectgarbage("count")))
          local origIdx, serialized_batch, samplePlaceholder = unpack(argList)

          io.stderr:write("\n 2: " ..tostring(collectgarbage("count")))
          local batch = torch.deserialize(serialized_batch)
          serialized_batch = nil

          collectgarbage()
          collectgarbage()

          io.stderr:write("\n 3: " .. tostring(collectgarbage("count")))
          batch = transform(batch)

          io.stderr:write("\n 4: " .. tostring(collectgarbage("count")))
          local sample = samplePlaceholder
          if (filter(batch)) then
            sample = {}
            sample.input, sample.target = batch:to_tensor()
          end
          io.stderr:write("\n 5: " ..tostring(collectgarbage("count")))

          collectgarbage()
          collectgarbage()
          io.stderr:write("\n 6: " ..tostring(collectgarbage("count")))

          io.stderr:write("\n End \n");
          return {
            sample,
            origIdx
          }
        end,
        function(argList)
          sample, sampleOrigIdx = unpack(argList)
        end,
        {idx, serialized_batch, samplePlaceholder}
      )
    end
  end
end

I've sprinkled collectgarbage calls around and tried to remove any objects that aren't needed. The memory output is rather straightforward:

 Start
 1: 374840.87695312
 2: 374840.94433594
 3: 372023.79101562
 4: 372023.85839844
 5: 372075.41308594
 6: 372023.73632812
 End 

The function that loops the enqueue is the trivial unordered function (the memory error is thrown at the second enqueue):

iterFunction = function()
  while threads:hasjob() do
    enqueue()
    threads:dojob()
    if threads:haserror() then
      threads:synchronize()
    end
    enqueue()

    if table.exact_length(sample) > 0 then
      return sample
    end
  end
end

【Answer 1】:

So the problem was torch.serialize, where the function in the setup coupled the entire dataset to the serialized closure. When adding:

serialized_batch = nil
collectgarbage()
collectgarbage()

the problem was solved (the double collectgarbage is a common Lua idiom: objects with finalizers released in the first full cycle are only actually reclaimed in the second). I further wanted to know what was taking up so much space, and the culprit turned out to be that I had defined the function in an environment containing a large dataset that got intertwined with the function, massively increasing its size. Here is the original definition of the local data:

mnist = require 'mnist'
local dataset = mnist[mode .. 'dataset']()

-- PROBLEMATIC LINE BELOW --
local ext_resource = dataset.data:reshape(dataset.data:size(1),
  dataset.data:size(2) * dataset.data:size(3)):double()

-- Create a Dataframe with the label. The actual images will be loaded
--  as an external resource
local df = Dataframe(
  Df_Dict{
    label = dataset.label:totable(),
    row_id = torch.range(1, dataset.data:size(1)):totable()
  })

-- Since the mnist package already has taken care of the data
--  splitting we create a single subsetter
df:create_subsets{
  subsets = Df_Dict{core = 1},
  class_args = Df_Tbl({
    batch_args = Df_Tbl({
      label = Df_Array("label"),
      data = function(row)
        return ext_resource[row.row_id]
      end
    })
  })
}
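
To see why that line matters, here is a small sketch of my own (the tensor is sized to mimic the flattened mnist training data, roughly 358 MB): torch.serialize writes a closure together with its upvalues, so a function capturing a large tensor serializes to essentially the tensor's full size, while an otherwise identical function that captures nothing big stays tiny.

local torch = require 'torch'

-- roughly the size of the flattened mnist training data (60000 x 784 doubles)
local big = torch.DoubleTensor(60000, 784):zero()

local heavy = function(row) return big[row.row_id] end  -- captures big as an upvalue
local light = function(row) return row.row_id end       -- captures nothing

print(#torch.serialize(heavy) / (1024 * 1024))  -- hundreds of MB
print(#torch.serialize(light))                  -- a few hundred bytes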

It turns out that removing the line I highlighted reduced the memory usage from 358 MB to 0.0008 MB! The code I used to test the performance was:

local mem = {}
table.insert(mem, collectgarbage("count"))

local ser_data = torch.serialize(batch.dataset)
table.insert(mem, collectgarbage("count"))

local ser_retriever = torch.serialize(batch.batchframe_defaults.data)
table.insert(mem, collectgarbage("count"))

local ser_raw_retriever = torch.serialize(function(row)
  return ext_resource[row.row_id]
end)
table.insert(mem, collectgarbage("count"))

local serialized_batch = torch.serialize(batch)
table.insert(mem, collectgarbage("count"))

for i=2,#mem do
  print(i-1, (mem[i] - mem[i-1])/1024)
end

which originally produced the output:

1   0.0082607269287109  
2   358.23344707489 
3   0.0017471313476562  
4   358.90182781219 

and after the fix:

1   0.0094480514526367  
2   0.00080204010009766 
3   0.00090408325195312 
4   0.010146141052246

I tried using setfenv for the function, but it didn't solve the problem. There is still a performance penalty from sending serialized data to the threads, but the main problem is solved, and without the expensive data retriever the functions are considerably smaller.
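
One way to keep a retriever cheap to serialize (a sketch of my own, not what the code above does; the mnist calls are real, the make_retriever helper is hypothetical) is to leave the heavy resource out of the closure until first use, so it is rebuilt inside whichever thread actually runs the function:

-- Returns a retriever whose captured upvalue is nil at serialization
-- time, so the closure stays small; the big tensor is only materialized
-- on the first call, i.e. inside the executing thread.
local function make_retriever(mode)
  local resource -- nil when the closure is serialized
  return function(row)
    if not resource then
      local mnist = require 'mnist'
      local ds = mnist[mode .. 'dataset']()
      resource = ds.data:reshape(ds.data:size(1),
        ds.data:size(2) * ds.data:size(3)):double()
    end
    return resource[row.row_id]
  end
end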

