tensorflow 在加载大型的embedding模型参数时,会遇到cannot be larger than 2GB

Posted gongxijun

tags:

篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了tensorflow 在加载大型的embedding模型参数时,会遇到cannot be larger than 2GB相关的知识,希望对你有一定的参考价值。

      这种问题是,对于每一个变量 variable 由于是基于protobuf存在这大小限制(2G),这个时候,我们需要将embedding拆开,拆分成N等分,来使得每一个

variable都在2G以下; 

  

 1 # !/usr/bin/env/python
 2 # coding=utf-8
 3 import tensorflow as tf
 4 import numpy as np
 5 
 6 input_ids = tf.placeholder(dtype=tf.int32, shape=[None,None])
 7 
 8 num_shards = 3
 9 weights = []
10 weights_shape = np.arange(27).reshape(9, 3)
11 # assert weights_shape[0] % num_shards == 0
12 num_shards_len = (weights_shape.shape[0]) / num_shards
13 assert  (weights_shape.shape[0]) % num_shards ==0
14 begin_ = 0
15 ends_ = num_shards_len
16 for i in range(0, num_shards):
17     if (i + 1) * num_shards_len < weights_shape.shape[0]:
18         begin_ = i * num_shards_len
19         if i + 1 == num_shards:
20             ends_ = weights_shape.shape[0]
21         else:
22             ends_ = (i + 1) * num_shards_len
23     else:
24         begin_ = i * num_shards_len
25         ends_ = weights_shape.shape[0]
26     weights_i = tf.get_variable("words-%02d" % i,
27                                 initializer=tf.constant(weights_shape[begin_: ends_, ]))
28     weights.append(weights_i)
29 
30 input_embedding = tf.nn.embedding_lookup(weights, input_ids,partition_strategy="div")
31 
32 sess = tf.InteractiveSession()
33 sess.run(tf.global_variables_initializer())
34 print(sess.run(weights))
35 
36 print(sess.run(input_embedding, feed_dict={input_ids: [[1, 2], [3, 0], [8, 2], [5, 1]]}))

 结果为:

    

[array([[0, 1, 2],
       [3, 4, 5],
       [6, 7, 8]]), array([[ 9, 10, 11],
       [12, 13, 14],
       [15, 16, 17]]), array([[18, 19, 20],
       [21, 22, 23],
       [24, 25, 26]])]
[[[ 3  4  5]
  [ 6  7  8]]

 [[ 9 10 11]
  [ 0  1  2]]

 [[24 25 26]
  [ 6  7  8]]

 [[15 16 17]
  [ 3  4  5]]]

 

以上是关于tensorflow 在加载大型的embedding模型参数时,会遇到cannot be larger than 2GB的主要内容,如果未能解决你的问题,请参考以下文章

如何为 .Net Embedding (Github Embeddinator-4000) 构建 objcgen 工具?

在 Tensorflow 中使用来自大型 numpy 数组的数据集

Tensorflow:加载大数据的现代方式

如何在 tensorflow 2.0 w/keras 中保存/恢复大型模型?

带你轻松使用 TensorFlow 创建大型线性模型

TensorFlow - tf.data.Dataset 读取大型 HDF5 文件