Embedding layer error

Posted by yangxiaoling

Relevant code:

import tensorflow as tf  # TensorFlow 1.x API (tf.get_variable)
...
TRG_VOCAB_SIZE = 4000  # target-language vocabulary size
...
self.trg_embedding = tf.get_variable("trg_emb", [TRG_VOCAB_SIZE, HIDDEN_SIZE])
...
# variables of the softmax layer
if SHARE_EMB_AND_SOFTMAX:
    self.softmax_weight = tf.transpose(self.trg_embedding)
else:
    self.softmax_weight = tf.get_variable("weight", [HIDDEN_SIZE, TRG_VOCAB_SIZE])
self.softmax_bias = tf.get_variable("softmax_bias", [TRG_VOCAB_SIZE])
...
# decoder input: every word ID in trg_input must lie in [0, TRG_VOCAB_SIZE)
trg_emb = tf.nn.embedding_lookup(self.trg_embedding, trg_input)
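To see why the lookup fails, here is a minimal pure-Python sketch of what `tf.nn.embedding_lookup` does with a single table: it gathers one row of the embedding matrix per word ID. The function below is a hypothetical stand-in, not TensorFlow's implementation; it only mimics the CPU behavior of refusing any ID outside `[0, vocab_size)`, raising `IndexError` where TensorFlow raises `InvalidArgumentError`.

```python
def embedding_lookup(table, ids):
    """Return table[i] for every integer ID i in the nested list `ids`.

    Hypothetical stand-in for tf.nn.embedding_lookup with one table;
    raises IndexError for IDs outside [0, len(table)).
    """
    vocab_size = len(table)

    def gather(node, path):
        # Recurse into nested lists, remembering the index path for the message.
        if isinstance(node, list):
            return [gather(child, path + (i,)) for i, child in enumerate(node)]
        if not 0 <= node < vocab_size:
            where = ",".join(str(p) for p in path)
            raise IndexError(
                "indices[{}] = {} is not in [0, {})".format(where, node, vocab_size)
            )
        return table[node]  # one embedding row per ID

    return gather(ids, ())


# A toy 3-word vocabulary with 2-dimensional embeddings.
table = [[0.0, 0.0], [0.1, 0.2], [0.3, 0.4]]
print(embedding_lookup(table, [[0, 2]]))  # [[[0.0, 0.0], [0.3, 0.4]]]
```

Calling `embedding_lookup(table, [[0, 1, 5]])` raises `IndexError: indices[0,2] = 5 is not in [0, 3)`, the same form as the message in the traceback. Note that the current `tf.gather` documentation says this check only raises on CPU; on GPU an out-of-range index yields zeros instead, so the same data bug can go unnoticed there.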

Error:

Traceback (most recent call last):
  File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/client/session.py", line 1350, in _do_call
    return fn(*args)
  File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/client/session.py", line 1329, in _run_fn
    status, run_metadata)
  File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/framework/errors_impl.py", line 473, in __exit__
    c_api.TF_GetCode(self.status.status))
tensorflow.python.framework.errors_impl.InvalidArgumentError: indices[0,2] = 5545 is not in [0, 4000)
     [[Node: embedding_lookup_1 = Gather[Tindices=DT_INT32, Tparams=DT_FLOAT, _class=["loc:@nmt_model/trg_emb"], validate_indices=true, _device="/job:localhost/replica:0/task:0/device:CPU:0"](nmt_model/trg_emb/read, IteratorGetNext:2)]]

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/home/error/PycharmProjects/tensortry/lang.py", line 170, in <module>
    main()
  File "/home/error/PycharmProjects/tensortry/lang.py", line 165, in main
    trg_emb = sess.run(trg_emb)
  File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/client/session.py", line 895, in run
    run_metadata_ptr)
  File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/client/session.py", line 1128, in _run
    feed_dict_tensor, options, run_metadata)
  File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/client/session.py", line 1344, in _do_run
    options, run_metadata)
  File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/client/session.py", line 1363, in _do_call
    raise type(e)(node_def, op, message)
tensorflow.python.framework.errors_impl.InvalidArgumentError: indices[0,2] = 5545 is not in [0, 4000)
     [[Node: embedding_lookup_1 = Gather[Tindices=DT_INT32, Tparams=DT_FLOAT, _class=["loc:@nmt_model/trg_emb"], validate_indices=true, _device="/job:localhost/replica:0/task:0/device:CPU:0"](nmt_model/trg_emb/read, IteratorGetNext:2)]]

Caused by op embedding_lookup_1, defined at:
  File "/home/error/PycharmProjects/tensortry/lang.py", line 170, in <module>
    main()
  File "/home/error/PycharmProjects/tensortry/lang.py", line 155, in main
    trg_emb = train_model.forward(src, src_size, trg_input, trg_label, trg_size)
  File "/home/error/PycharmProjects/tensortry/lang.py", line 109, in forward
    trg_emb = tf.nn.embedding_lookup(self.trg_embedding, trg_input)  # decoder input trg_input
  File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/ops/embedding_ops.py", line 325, in embedding_lookup
    transform_fn=None)
  File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/ops/embedding_ops.py", line 150, in _embedding_lookup_and_transform
    result = _clip(_gather(params[0], ids, name=name), ids, max_norm)
  File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/ops/embedding_ops.py", line 54, in _gather
    return array_ops.gather(params, ids, name=name)
  File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/ops/array_ops.py", line 2585, in gather
    params, indices, validate_indices=validate_indices, name=name)
  File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/ops/gen_array_ops.py", line 1864, in gather
    validate_indices=validate_indices, name=name)
  File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/framework/op_def_library.py", line 787, in _apply_op_helper
    op_def=op_def)
  File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/framework/ops.py", line 3160, in create_op
    op_def=op_def)
  File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/framework/ops.py", line 1625, in __init__
    self._traceback = self._graph._extract_stack()  # pylint: disable=protected-access

InvalidArgumentError (see above for traceback): indices[0,2] = 5545 is not in [0, 4000)
     [[Node: embedding_lookup_1 = Gather[Tindices=DT_INT32, Tparams=DT_FLOAT, _class=["loc:@nmt_model/trg_emb"], validate_indices=true, _device="/job:localhost/replica:0/task:0/device:CPU:0"](nmt_model/trg_emb/read, IteratorGetNext:2)]]

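The `4000` in the error message is the target-language vocabulary size (`TRG_VOCAB_SIZE`). `tf.nn.embedding_lookup` can only look up word IDs in `[0, TRG_VOCAB_SIZE)`, but `trg_input` contains the ID 5545, so the value was set too small for the data. Changing it to a number larger than 5545 fixed the error. Strictly, it must be at least one more than the largest word ID anywhere in the target data, because 5545 is only the first out-of-range index reported. The alternative is to keep 4000 and regenerate the target ID data from a vocabulary capped at 4000 words, so that rarer words map to `<unk>`.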
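Before picking a new value, it helps to check the data itself. This is a minimal sketch, assuming the target side has already been converted to text with one sentence per line of space-separated integer word IDs (an assumption; the post does not show its data format). `max_word_id` gives the largest ID, so `max_word_id(...) + 1` is the smallest vocabulary size that fits; `find_out_of_range` lists the offending IDs; `clip_to_unk` illustrates the other option of keeping 4000 and mapping out-of-range IDs to an `<unk>` ID, whose value (1 below) is hypothetical.

```python
def max_word_id(lines):
    """Largest word ID in an iterable of lines of space-separated integer IDs."""
    return max((int(tok) for line in lines for tok in line.split()), default=-1)


def find_out_of_range(lines, vocab_size):
    """List (line_number, position, word_id) for every ID outside [0, vocab_size)."""
    bad = []
    for line_no, line in enumerate(lines):
        for pos, tok in enumerate(line.split()):
            word_id = int(tok)
            if not 0 <= word_id < vocab_size:
                bad.append((line_no, pos, word_id))
    return bad


def clip_to_unk(line, vocab_size, unk_id):
    """Replace out-of-range IDs in one line with unk_id (a hypothetical <unk> ID)."""
    return " ".join(
        tok if 0 <= int(tok) < vocab_size else str(unk_id) for tok in line.split()
    )


# Toy target-side data; a real file would be read line by line.
sample = ["3 17 5545 2", "8 9 2"]
print(max_word_id(sample) + 1)          # smallest TRG_VOCAB_SIZE that fits: 5546
print(find_out_of_range(sample, 4000))  # [(0, 2, 5545)]
print(clip_to_unk(sample[0], 4000, 1))  # 3 17 1 2
```

To scan a real file, pass the open file object, e.g. `with open(path) as f: print(max_word_id(f))`, where `path` is your target ID file. The `(line_number, position)` pairs refer to the file, not to `indices[0,2]` in the error, which indexes the current batch.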