在 Python 中编写和注册自定义 TensorFlow 操作
Posted
技术标签:
【中文标题】在 Python 中编写和注册自定义 TensorFlow 操作【英文标题】:Writing and Registering a Custom Tensorflow Op in Python 【发布时间】:2017-05-01 18:24:04 【问题描述】:我想在 Python 中编写一个自定义的 Tensorflow 操作并将其注册到 Protobuf 注册表中,以进行类似here 解释的操作。 Protobuf 注册是关键,因为我不会直接从 Python 使用此操作,但如果它像 C++ 操作一样注册并加载到 Python 运行时环境中,那么我可以在我的环境中运行它。
我希望代码看起来像,
import tensorflow as tf
from google.protobuf import json_format
from tensorflow.python.ops.data_flow_ops import QueueBase, _as_type_list, _as_shape_list, _as_name_list
""" Missing the Python equivalent of,
class HDF5QueueOp : public ResourceOpKernel<QueueInterface>
public:
// Implementation
;
REGISTER_OP("HDF5Queue")
.Output("handle: resource")
.Attr("filename: string")
.Attr("datasets: list(string)")
.Attr("overwrite: bool = false")
.Attr("component_types: list(type) >= 0 = []")
.Attr("shapes: list(shape) >= 0 = []")
.Attr("shared_name: string = ''")
.Attr("container: string = ''")
.Attr("capacity: int = -1")
.SetIsStateful()
.SetShapeFn(TwoElementOutput);
"""
class HDF5Queue(QueueBase):
def __init__(self, stream_id, stream_columns, dtypes=None, capacity=100,
shapes=None, names=None, name="hdf5_queue"):
if not dtypes:
dtypes = [tf.int64, tf.float32]
if not shapes:
shapes = [[1], [1]]
dtypes = _as_type_list(dtypes)
shapes = _as_shape_list(shapes, dtypes)
names = _as_name_list(names, dtypes)
queue_ref = _op_def_lib.apply_op("HDF5Queue", stream_id=stream_id,
stream_columns=stream_columns, capacity=capacity,
component_types=dtypes, shapes=shapes,
name=name, container=None, shared_name=None)
super(HDF5Queue, self).__init__(dtypes, shapes,
names, queue_ref)
以上是 TF 的标准。例如,可以看到 FIFOQueue。 Python Wrapper、Protobuf Registration、C++ Implementation。在编译过程中生成了一个我不喜欢的 Python 包装器,但是您可以通过运行 grep -A 10 -B 10 -n FIFO $(find /usr/local -name "*gen_data_flow*.py") /dev/null
看到它在哪里使用
下面将以 JSON 格式转储 TF Graph 的 Protobuf 消息。我希望这会与 HDF5Queue 操作的块一起转储,就像我编写 C++ 操作一样。
with tf.Session() as sess:
queue = HDF5Queue(stream_id=0xa)
write = queue.enqueue([[1], [1.2]])
read = queue.dequeue()
print json_format.MessageToJson(tf.train.export_meta_graph())
【问题讨论】:
【参考方案1】:这可以使用py_func
来完成。这是一个例子。
import tensorflow as tf
from google.protobuf import json_format
import sys, json, base64, numpy
from tensorflow.python.ops.script_ops import _py_funcs as py_func_registry
from tensorflow.core.protobuf.meta_graph_pb2 import MetaGraphDef
graph = tf.Graph()
graph2 = tf.Graph()
def f(x):
return x
def g(x):
return 2*x
with graph.as_default():
x = tf.placeholder(tf.float32, shape=(3,), name='x')
y = tf.py_func(f, [x], tf.float32, name='y')
# py_func_registry._funcs.clear() # Optional line to clear the Python function registry
msg = json.loads(json_format.MessageToJson(tf.train.export_meta_graph()))
# Change the function being used by py_func
msg['graphDef']['node'][1]['attr']['token']['s'] = base64.b64encode(py_func_registry.insert(g))
with graph2.as_default():
# Load graph
meta_graph_def = MetaGraphDef()
json_format.Parse(json.dumps(msg), meta_graph_def)
tf.train.import_meta_graph(meta_graph_def)
sess = tf.Session(graph=graph2)
print sess.run('y:0', feed_dict='x:0':numpy.array([1, 2, 3]))
print g(numpy.array([1, 2, 3]))
【讨论】:
以上是关于在 Python 中编写和注册自定义 TensorFlow 操作的主要内容,如果未能解决你的问题,请参考以下文章
mindspore的tensor与numpy数据类型转换问题?
AttributeError:“Tensor”对象在自定义损失函数中没有属性“numpy”(Tensorflow 2.1.0)