Embedding layer error
Posted by yangxiaoling
Relevant code:
```python
...
TRG_VOCAB_SIZE = 4000  # target-language vocabulary size
...
self.trg_embedding = tf.get_variable('trg_emb', [TRG_VOCAB_SIZE, HIDDEN_SIZE])
...
# Variables of the softmax layer
if SHARE_EMB_AND_SOFTMAX:
    self.softmax_weight = tf.transpose(self.trg_embedding)
else:
    self.softmax_weight = tf.get_variable('weight', [HIDDEN_SIZE, TRG_VOCAB_SIZE])
self.softmax_bias = tf.get_variable('softmax_bias', [TRG_VOCAB_SIZE])
...
trg_emb = tf.nn.embedding_lookup(self.trg_embedding, trg_input)
```
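To see why `TRG_VOCAB_SIZE` matters here: `trg_emb` is a `[TRG_VOCAB_SIZE, HIDDEN_SIZE]` matrix, and an embedding lookup simply picks one row per token ID, so every ID must lie in `[0, TRG_VOCAB_SIZE)`. The sketch below is a pure-Python emulation written for illustration only (the function name and toy table are hypothetical, and this is not TensorFlow's implementation). It performs the same range check that the CPU `Gather` op in the traceback enforces and reports the failing position in the same `indices[i,j]` form.

```python
from typing import List, Sequence


def embedding_lookup(table: Sequence[Sequence[float]],
                     ids: Sequence[Sequence[int]]) -> List[List[Sequence[float]]]:
    """Gather one row of `table` for each token ID in a 2-D batch of IDs.

    Illustrative emulation only: like the CPU Gather op, it rejects any
    ID outside [0, vocab_size), where vocab_size is the number of rows.
    """
    vocab_size = len(table)
    result = []
    for i, row in enumerate(ids):
        out_row = []
        for j, token_id in enumerate(row):
            if not 0 <= token_id < vocab_size:
                # Same shape as TensorFlow's message, e.g.
                # "indices[0,2] = 5545 is not in [0, 4000)"
                raise ValueError(
                    f"indices[{i},{j}] = {token_id} is not in [0, {vocab_size})")
            out_row.append(table[token_id])
        result.append(out_row)
    return result


# Toy table: vocabulary of 4 tokens, hidden size 2.
table = [[0.0, 0.1], [1.0, 1.1], [2.0, 2.1], [3.0, 3.1]]
print(embedding_lookup(table, [[0, 3]]))  # [[[0.0, 0.1], [3.0, 3.1]]]
try:
    embedding_lookup(table, [[1, 2, 5]])
except ValueError as e:
    print(e)  # indices[0,2] = 5 is not in [0, 4)
```

In `indices[i,j]`, `i` is the row (sentence) in the batch and `j` is the token position, so `indices[0,2] = 5545` in the real error means the third token of the first sentence in the batch has ID 5545.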
Error:
```
Traceback (most recent call last):
  File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/client/session.py", line 1350, in _do_call
    return fn(*args)
  File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/client/session.py", line 1329, in _run_fn
    status, run_metadata)
  File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/framework/errors_impl.py", line 473, in __exit__
    c_api.TF_GetCode(self.status.status))
tensorflow.python.framework.errors_impl.InvalidArgumentError: indices[0,2] = 5545 is not in [0, 4000)
  [[Node: embedding_lookup_1 = Gather[Tindices=DT_INT32, Tparams=DT_FLOAT, _class=["loc:@nmt_model/trg_emb"], validate_indices=true, _device="/job:localhost/replica:0/task:0/device:CPU:0"](nmt_model/trg_emb/read, IteratorGetNext:2)]]

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/home/error/PycharmProjects/tensortry/lang.py", line 170, in <module>
    main()
  File "/home/error/PycharmProjects/tensortry/lang.py", line 165, in main
    trg_emb = sess.run(trg_emb)
  File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/client/session.py", line 895, in run
    run_metadata_ptr)
  File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/client/session.py", line 1128, in _run
    feed_dict_tensor, options, run_metadata)
  File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/client/session.py", line 1344, in _do_run
    options, run_metadata)
  File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/client/session.py", line 1363, in _do_call
    raise type(e)(node_def, op, message)
tensorflow.python.framework.errors_impl.InvalidArgumentError: indices[0,2] = 5545 is not in [0, 4000)
  [[Node: embedding_lookup_1 = Gather[Tindices=DT_INT32, Tparams=DT_FLOAT, _class=["loc:@nmt_model/trg_emb"], validate_indices=true, _device="/job:localhost/replica:0/task:0/device:CPU:0"](nmt_model/trg_emb/read, IteratorGetNext:2)]]

Caused by op 'embedding_lookup_1', defined at:
  File "/home/error/PycharmProjects/tensortry/lang.py", line 170, in <module>
    main()
  File "/home/error/PycharmProjects/tensortry/lang.py", line 155, in main
    trg_emb = train_model.forward(src, src_size, trg_input, trg_label, trg_size)
  File "/home/error/PycharmProjects/tensortry/lang.py", line 109, in forward
    trg_emb = tf.nn.embedding_lookup(self.trg_embedding, trg_input)  # decoder input trg_input
  File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/ops/embedding_ops.py", line 325, in embedding_lookup
    transform_fn=None)
  File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/ops/embedding_ops.py", line 150, in _embedding_lookup_and_transform
    result = _clip(_gather(params[0], ids, name=name), ids, max_norm)
  File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/ops/embedding_ops.py", line 54, in _gather
    return array_ops.gather(params, ids, name=name)
  File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/ops/array_ops.py", line 2585, in gather
    params, indices, validate_indices=validate_indices, name=name)
  File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/ops/gen_array_ops.py", line 1864, in gather
    validate_indices=validate_indices, name=name)
  File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/framework/op_def_library.py", line 787, in _apply_op_helper
    op_def=op_def)
  File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/framework/ops.py", line 3160, in create_op
    op_def=op_def)
  File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/framework/ops.py", line 1625, in __init__
    self._traceback = self._graph._extract_stack()  # pylint: disable=protected-access

InvalidArgumentError (see above for traceback): indices[0,2] = 5545 is not in [0, 4000)
  [[Node: embedding_lookup_1 = Gather[Tindices=DT_INT32, Tparams=DT_FLOAT, _class=["loc:@nmt_model/trg_emb"], validate_indices=true, _device="/job:localhost/replica:0/task:0/device:CPU:0"](nmt_model/trg_emb/read, IteratorGetNext:2)]]
```
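The `4000` in the error message is the target-language vocabulary size `TRG_VOCAB_SIZE`, i.e. the number of rows in the `trg_emb` embedding table. The input batch contains token ID 5545, which lies outside the valid index range `[0, 4000)`, so the value was most likely set too small. Changing it to a number larger than 5545 (strictly, larger than the largest token ID in the target-side data) makes the error go away.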
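Rather than raising the constant by trial and error, one option is to scan the target-side ID data once before training and compute the smallest vocabulary size that covers every ID. The sketch below assumes a hypothetical data format of one sentence per line with space-separated integer token IDs; the helper names are also hypothetical. It scans the file directly because the `indices[0,2]` position in the TensorFlow message refers to the batched `trg_input` tensor, which after shuffling and batching generally does not correspond to a line in the data file.

```python
from typing import Iterable, Iterator, List, Tuple


def parse_ids(line: str) -> List[int]:
    """Parse one line of space-separated integer token IDs (assumed format)."""
    return [int(tok) for tok in line.split()]


def min_vocab_size(lines: Iterable[str]) -> int:
    """Smallest vocabulary size that makes every ID a valid row index (max ID + 1)."""
    max_id = -1
    for line in lines:
        ids = parse_ids(line)
        if ids:
            max_id = max(max_id, max(ids))
    return max_id + 1


def find_out_of_range(lines: Iterable[str],
                      vocab_size: int) -> Iterator[Tuple[int, int, int]]:
    """Yield (line_index, position, token_id) for each ID outside [0, vocab_size)."""
    for i, line in enumerate(lines):
        for j, token_id in enumerate(parse_ids(line)):
            if not 0 <= token_id < vocab_size:
                yield i, j, token_id


# Hypothetical target-side data: one sentence per line, IDs separated by spaces.
data = ["1 17 5545 2", "1 300 2"]
TRG_VOCAB_SIZE = 4000
print(list(find_out_of_range(data, TRG_VOCAB_SIZE)))  # [(0, 2, 5545)]
print(min_vocab_size(data))                           # 5546
```

With real files, an open file object can be passed instead of a list; because a file iterator is consumed after one pass, it needs to be reopened for the second call. If the required size turns out much larger than intended, it may also be worth checking that the step that converted words to IDs used the same vocabulary size (an assumption about the pipeline, not something stated above).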