tf.device()出现异常
Posted
tags:
篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了tf.device()出现异常相关的知识,希望对你有一定的参考价值。
当我使用tf.device()分配GPU编号时,它似乎是一个例外。这是我第一次在Stack Overflow中提问,如果有错误,请原谅我,并告诉我。
当我在代码中放入allow_soft_placement = True时,它可以工作。
答案
def init_graph(self):
"""
init bert graph
"""
with self.tf_instance.device('device:GPU:{}'.format(str(self.gpu_no))):
# add tokenizer
from bert import tokenization
self.tokenizer = tokenization.FullTokenizer(self.args.vocab_file)
from bert import modeling
bert_config = modeling.BertConfig.from_json_file(self.args.config_file)
self.model = modeling.BertModel(config=bert_config,
is_training=False,
input_ids=self.input_ids,
input_mask=self.input_mask,
token_type_ids=self.input_type_ids,
use_one_hot_embeddings=False)
# get output weights and output bias
reader = self.tf_instance.train.NewCheckpointReader(self.args.ckpt_file)
output_weights = reader.get_tensor('output_weights')
output_bias = reader.get_tensor('output_bias')
# get result op
output_layer = self.model.get_pooled_output()
logits = self.tf_instance.matmul(output_layer, output_weights, transpose_b=True)
logits = self.tf_instance.nn.bias_add(logits, output_bias)
self.probabilities = self.tf_instance.nn.softmax(logits, axis=-1)
sess_config = self.tf_instance.ConfigProto()
sess_config.gpu_options.allow_growth = True
graph = self.probabilities.graph
saver = self.tf_instance.train.Saver()
self.sess = self.tf_instance.Session(config=sess_config, graph=graph)
self.sess.run(self.tf_instance.global_variables_initializer())
self.tf_instance.reset_default_graph()
saver.restore(self.sess, self.args.ckpt_file)
以上是关于tf.device()出现异常的主要内容,如果未能解决你的问题,请参考以下文章
活动到片段通信:当我尝试从活动更新片段中的文本视图时,出现空指针异常