tf.device()出现异常

Posted

tags:

篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了tf.device()出现异常相关的知识,希望对你有一定的参考价值。

当我使用tf.device()分配GPU编号时,它似乎是一个例外。这是我第一次在Stack Overflow中提问,如果有错误,请原谅我,并告诉我。

当我在代码中放入allow_soft_placement = True时,它可以工作。

答案
def init_graph(self):
    """
    init bert graph
    """
    with self.tf_instance.device('device:GPU:{}'.format(str(self.gpu_no))):
        # add tokenizer
        from bert import tokenization
        self.tokenizer = tokenization.FullTokenizer(self.args.vocab_file)
        from bert import modeling
        bert_config = modeling.BertConfig.from_json_file(self.args.config_file)
        self.model = modeling.BertModel(config=bert_config,
                                        is_training=False,
                                        input_ids=self.input_ids,
                                        input_mask=self.input_mask,
                                        token_type_ids=self.input_type_ids,
                                        use_one_hot_embeddings=False)

        # get output weights and output bias
        reader = self.tf_instance.train.NewCheckpointReader(self.args.ckpt_file)
        output_weights = reader.get_tensor('output_weights')
        output_bias = reader.get_tensor('output_bias')

        # get result op
        output_layer = self.model.get_pooled_output()
        logits = self.tf_instance.matmul(output_layer, output_weights, transpose_b=True)
        logits = self.tf_instance.nn.bias_add(logits, output_bias)
        self.probabilities = self.tf_instance.nn.softmax(logits, axis=-1)

        sess_config = self.tf_instance.ConfigProto()
        sess_config.gpu_options.allow_growth = True

        graph = self.probabilities.graph
        saver = self.tf_instance.train.Saver()
        self.sess = self.tf_instance.Session(config=sess_config, graph=graph)
        self.sess.run(self.tf_instance.global_variables_initializer())
        self.tf_instance.reset_default_graph()
        saver.restore(self.sess, self.args.ckpt_file)

以上是关于tf.device()出现异常的主要内容,如果未能解决你的问题,请参考以下文章

出现异常:片段已激活

活动到片段通信:当我尝试从活动更新片段中的文本视图时,出现空指针异常

异常和TCP通讯

以编程方式创建 MapView 并添加标记导致片段中出现空指针异常

片段中的 Listview 适配器给出空指针异常

TensorFlow如何通过tf.device函数来指定运行每一个操作的设备?