tensorflow添加新操作(Op)
Posted inshallah
tags:
篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了tensorflow添加新操作(Op)相关的知识,希望对你有一定的参考价值。
参考:https://tensorflow.juejin.im/extend/adding_an_op.html
https://zhuanlan.zhihu.com/p/34168765
为了加入一个定制操作,你需要:
- 在 C++ 文件中注册这个新操作。操作的注册为此操作的功能定义了一个接口(规范)。比如,操作的注册定义了此操作的名称和它的输入输出。它还定义了 shape 函数,用于获取张量的形状。
- 在 C++ 中实现这个操作。操作的实现称为内核,它是你在步骤 1 中注册的规范的具体实现。对于不同的输入输出类型或架构(比如不同的 CPUs 或 GPUs),可能有多个内核。
- 创建一个 Python 包装器(可选)。这个包装器是用于在 Python 中创建操作的公共 API。操作的注册可以产生一个默认的包装器,它可以直接使用,或添加。
- 为操作编写一个函数来计算梯度(可选)。
tf.test.compute_gradient_error
1. 定义接口:
#include "tensorflow/core/framework/op.h" #include "tensorflow/core/framework/shape_inference.h" using namespace tensorflow; REGISTER_OP("ZeroOut") .Input("to_zero: int32") .Output("zeroed: int32") .SetShapeFn([](::tensorflow::shape_inference::InferenceContext* c) { c->set_output(0, c->input(0)); return Status::OK(); });
关于命名的备注:操作名称必须首字母大写,而且不能和库中已经注册的其它操作重名。
2. 实现操作的内核
定义接口后,接下来就需要为此操作提供一个或多个内核实现了。
为了实现这些内核,创建一个继承自 OpKernel
的类,并重载 Compute
方法。
Compute
方法有一个类型为 OpKernelContext*
的参数 context
,从中可以访问输入和输出张量等有用的信息。
将你的内核加到上面创建的文件中。这个内核的代码形如:
#include "tensorflow/core/framework/op_kernel.h" using namespace tensorflow; class ZeroOutOp : public OpKernel { public: explicit ZeroOutOp(OpKernelConstruction* context) : OpKernel(context) {} void Compute(OpKernelContext* context) override { // 得到输入张量 const Tensor& input_tensor = context->input(0); auto input = input_tensor.flat<int32>(); // 创建输出张量 Tensor* output_tensor = NULL; OP_REQUIRES_OK(context, context->allocate_output(0, input_tensor.shape(), &output_tensor)); auto output_flat = output_tensor->flat<int32>(); // 除第一个元素外,输出张量的其它所有元素都设置为 0 const int N = input.size(); for (int i = 1; i < N; i++) { output_flat(i) = 0; } // 如果可能的话,保留第一个输入值 if (N > 0) output_flat(0) = input(0); } };
给 ZeroOut
操作加上约束条件:
REGISTER_KERNEL_BUILDER(Name("ZeroOut").Device(DEVICE_CPU), ZeroOutOp);
这里注册的操作名是ZeroOut,通过上面的语句和ZeroOutOp对应吧,
输入和输出
下面对前面的示例做个总结,一个操作注册可以指定多个输入输出:
REGISTER_OP("MultipleInsAndOuts") .Input("y: int32") .Input("z: float") .Output("a: string") .Output("b: int32");
以上是关于tensorflow添加新操作(Op)的主要内容,如果未能解决你的问题,请参考以下文章