保留MLIR的TF方言中的输入控制结构。

Posted

tags:

篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了保留MLIR的TF方言中的输入控制结构。相关的知识,希望对你有一定的参考价值。

我正在尝试使用 TensorFlow(2.2.0)作为前端来生成 MLIR,并想澄清以下问题。考虑下面的例子,它用朴素的三重循环实现了两个 2x2 矩阵的乘法。

import tensorflow as tf
import tensorflow.mlir as mlir

# Trace the computation into an explicit tf.Graph so that its GraphDef can be
# passed to the MLIR converter at the bottom of the script.
with tf.Graph().as_default() as g:
    with tf.device('/cpu:0'):

        @tf.function
        def mymatmul(A, B, C):
            """Accumulate element products of A and B into every C[i, j].

            Args:
                A: left operand, read as A[i, k] for i, k in {0, 1}.
                B: right operand, read as B[i, j] for i, j in {0, 1}.
                C: output, written element-wise through ``C[i, j].assign``.
            """
            # All three loop bounds are plain Python ints (range(2)); the MLIR
            # printed for this script contains no loop ops, only straight-line
            # StridedSlice/Mul/AddV2 ops.
            for i in range(2):
                for j in range(2):
                    # Python float accumulator, reset for each (i, j).
                    cij = 0.0
                    for k in range(2):
                        # NOTE(review): B is indexed with i rather than k; a
                        # textbook matrix product would read B[k, j]. Confirm
                        # intent before reusing this snippet.
                        cij += A[i, k]*B[i, j]
                    C[i, j].assign(cij)


        # 2x2 constant inputs and a zero-initialised 2x2 variable as output.
        A = tf.constant([[1., 2.], [3., 4.]])
        B = tf.constant([[2., 1.], [4., 3.]])
        C = tf.Variable([[0., 0.], [0., 0.]])
        mymatmul(A, B, C)

# Convert the traced GraphDef to its MLIR form and print it.
tf_mlir_graph = mlir.experimental.convert_graph_def(g.as_graph_def())
print(tf_mlir_graph)

这段代码输出了如下的 MLIR。

module attributes {tf.versions = {bad_consumers = [], min_consumer = 12 : i32, producer = 175 : i32}} {
  func @main() {
    %0 = "tf.Const"() {value = dense<0.000000e+00> : tensor<2x2xf32>} : () -> tensor<2x2xf32>
    %1 = "tf.Const"() {value = dense<[[2.000000e+00, 1.000000e+00], [4.000000e+00, 3.000000e+00]]> : tensor<2x2xf32>} : () -> tensor<2x2xf32>
    %2 = "tf.Const"() {value = dense<[[1.000000e+00, 2.000000e+00], [3.000000e+00, 4.000000e+00]]> : tensor<2x2xf32>} : () -> tensor<2x2xf32>
    %3 = "tf.VarHandleOp"() {_class = ["loc:@Variable"], container = "", device = "/device:CPU:0", dtype = f32, shape = "tfshape$dim { size: 2 } dim { size: 2 }", shared_name = "Variable"} : () -> tensor<!tf.resource<tensor<2x2xf32>>>
    "tf.StatefulPartitionedCall"(%2, %1, %3) {Tin = ["tfdtype$DT_FLOAT", "tfdtype$DT_FLOAT", "tfdtype$DT_RESOURCE"], Tout = [], _read_only_resource_inputs = [], config = "", config_proto = "A7A3CPU101A7A3GPU10022J081", device = "/device:CPU:0", executor_type = "", f = @__inference_mymatmul_1160} : (tensor<2x2xf32>, tensor<2x2xf32>, tensor<!tf.resource<tensor<2x2xf32>>>) -> ()
    %4 = "tf.VarIsInitializedOp"(%3) {device = "/device:CPU:0"} : (tensor<!tf.resource<tensor<2x2xf32>>>) -> tensor<i1>
    %5 = "tf.ReadVariableOp"(%3) {device = "/device:CPU:0", dtype = f32} : (tensor<!tf.resource<tensor<2x2xf32>>>) -> tensor<2x2xf32>
    "tf.AssignVariableOp"(%3, %0) {device = "/device:CPU:0", dtype = f32} : (tensor<!tf.resource<tensor<2x2xf32>>>, tensor<2x2xf32>) -> ()
    return
  }
  func @__inference_mymatmul_1160(%arg0: tensor<2x2xf32>, %arg1: tensor<2x2xf32>, %arg2: tensor<!tf.resource>) attributes {tf.signature.is_stateful} {
    %0 = "tf.Const"() {value = dense<1> : tensor<2xi32>} : () -> tensor<2xi32>
    %1 = "tf.Const"() {value = dense<[1, 2]> : tensor<2xi32>} : () -> tensor<2xi32>
    %2 = "tf.Const"() {value = dense<[0, 1]> : tensor<2xi32>} : () -> tensor<2xi32>
    %3 = "tf.Const"() {value = dense<0> : tensor<2xi32>} : () -> tensor<2xi32>
    %4 = "tf.Const"() {value = dense<2> : tensor<2xi32>} : () -> tensor<2xi32>
    %5 = "tf.Const"() {value = dense<[2, 1]> : tensor<2xi32>} : () -> tensor<2xi32>
    %6 = "tf.Const"() {value = dense<[1, 0]> : tensor<2xi32>} : () -> tensor<2xi32>
    %7 = "tf.Const"() {value = dense<0.000000e+00> : tensor<f32>} : () -> tensor<f32>
    %8 = "tf.ReadVariableOp"(%arg2) {device = "", dtype = f32} : (tensor<!tf.resource>) -> tensor<*xf32>
    %9 = "tf.StridedSlice"(%arg0, %3, %0, %0) {Index = i32, T = f32, begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 3 : i64} : (tensor<2x2xf32>, tensor<2xi32>, tensor<2xi32>, tensor<2xi32>) -> tensor<f32>
    %10 = "tf.StridedSlice"(%arg1, %3, %0, %0) {Index = i32, T = f32, begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 3 : i64} : (tensor<2x2xf32>, tensor<2xi32>, tensor<2xi32>, tensor<2xi32>) -> tensor<f32>
    %11 = "tf.Mul"(%9, %10) {T = f32, device = ""} : (tensor<f32>, tensor<f32>) -> tensor<f32>
    %12 = "tf.AddV2"(%11, %7) {T = f32, device = ""} : (tensor<f32>, tensor<f32>) -> tensor<f32>
    %13 = "tf.StridedSlice"(%arg0, %6, %5, %0) {Index = i32, T = f32, begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 3 : i64} : (tensor<2x2xf32>, tensor<2xi32>, tensor<2xi32>, tensor<2xi32>) -> tensor<f32>
    %14 = "tf.StridedSlice"(%arg1, %6, %5, %0) {Index = i32, T = f32, begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 3 : i64} : (tensor<2x2xf32>, tensor<2xi32>, tensor<2xi32>, tensor<2xi32>) -> tensor<f32>
    %15 = "tf.Mul"(%13, %14) {T = f32, device = ""} : (tensor<f32>, tensor<f32>) -> tensor<f32>
    %16 = "tf.AddV2"(%15, %7) {T = f32, device = ""} : (tensor<f32>, tensor<f32>) -> tensor<f32>
    %17 = "tf.StridedSlice"(%arg0, %0, %4, %0) {Index = i32, T = f32, begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 3 : i64} : (tensor<2x2xf32>, tensor<2xi32>, tensor<2xi32>, tensor<2xi32>) -> tensor<f32>
    %18 = "tf.Mul"(%17, %14) {T = f32, device = ""} : (tensor<f32>, tensor<f32>) -> tensor<f32>
    %19 = "tf.AddV2"(%16, %18) {T = f32, device = ""} : (tensor<f32>, tensor<f32>) -> tensor<f32>
    %20 = "tf.StridedSlice"(%arg1, %0, %4, %0) {Index = i32, T = f32, begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 3 : i64} : (tensor<2x2xf32>, tensor<2xi32>, tensor<2xi32>, tensor<2xi32>) -> tensor<f32>
    %21 = "tf.Mul"(%13, %20) {T = f32, device = ""} : (tensor<f32>, tensor<f32>) -> tensor<f32>
    %22 = "tf.AddV2"(%21, %7) {T = f32, device = ""} : (tensor<f32>, tensor<f32>) -> tensor<f32>
    %23 = "tf.Mul"(%17, %20) {T = f32, device = ""} : (tensor<f32>, tensor<f32>) -> tensor<f32>
    %24 = "tf.AddV2"(%22, %23) {T = f32, device = ""} : (tensor<f32>, tensor<f32>) -> tensor<f32>
    %25 = "tf.StridedSlice"(%arg0, %2, %1, %0) {Index = i32, T = f32, begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 3 : i64} : (tensor<2x2xf32>, tensor<2xi32>, tensor<2xi32>, tensor<2xi32>) -> tensor<f32>
    %26 = "tf.Mul"(%25, %10) {T = f32, device = ""} : (tensor<f32>, tensor<f32>) -> tensor<f32>
    %27 = "tf.AddV2"(%12, %26) {T = f32, device = ""} : (tensor<f32>, tensor<f32>) -> tensor<f32>
    "tf.ResourceStridedSliceAssign"(%arg2, %3, %0, %0, %27) {Index = i32, T = f32, begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 3 : i64} : (tensor<!tf.resource>, tensor<2xi32>, tensor<2xi32>, tensor<2xi32>, tensor<f32>) -> ()
    %28 = "tf.ReadVariableOp"(%arg2) {device = "", dtype = f32} : (tensor<!tf.resource>) -> tensor<*xf32>
    %29 = "tf.StridedSlice"(%arg1, %2, %1, %0) {Index = i32, T = f32, begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 3 : i64} : (tensor<2x2xf32>, tensor<2xi32>, tensor<2xi32>, tensor<2xi32>) -> tensor<f32>
    %30 = "tf.Mul"(%9, %29) {T = f32, device = ""} : (tensor<f32>, tensor<f32>) -> tensor<f32>
    %31 = "tf.AddV2"(%30, %7) {T = f32, device = ""} : (tensor<f32>, tensor<f32>) -> tensor<f32>
    %32 = "tf.Mul"(%25, %29) {T = f32, device = ""} : (tensor<f32>, tensor<f32>) -> tensor<f32>
    %33 = "tf.AddV2"(%31, %32) {T = f32, device = ""} : (tensor<f32>, tensor<f32>) -> tensor<f32>
    "tf.ResourceStridedSliceAssign"(%arg2, %2, %1, %0, %33) {Index = i32, T = f32, begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 3 : i64} : (tensor<!tf.resource>, tensor<2xi32>, tensor<2xi32>, tensor<2xi32>, tensor<f32>) -> ()
    %34 = "tf.ReadVariableOp"(%arg2) {device = "", dtype = f32} : (tensor<!tf.resource>) -> tensor<*xf32>
    "tf.ResourceStridedSliceAssign"(%arg2, %6, %5, %0, %19) {Index = i32, T = f32, begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 3 : i64} : (tensor<!tf.resource>, tensor<2xi32>, tensor<2xi32>, tensor<2xi32>, tensor<f32>) -> ()
    %35 = "tf.ReadVariableOp"(%arg2) {device = "", dtype = f32} : (tensor<!tf.resource>) -> tensor<*xf32>
    "tf.ResourceStridedSliceAssign"(%arg2, %0, %4, %0, %24) {Index = i32, T = f32, begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 3 : i64} : (tensor<!tf.resource>, tensor<2xi32>, tensor<2xi32>, tensor<2xi32>, tensor<f32>) -> ()
    return
  }
}

值得注意的是(至少就我的目的而言),计算的循环结构丢失了:在 tf 方言中,循环被展平了,而我希望输出的 MLIR 能够反映 TF 算子图中表达的原始循环结构。换一种方式提问:TensorFlow 方言是否支持控制结构(我相信它通过 tf.IfOp、tf.WhileOp 提供支持)?如果支持,输入需要遵守哪些特殊的语法限制才能保留循环结构?

P.S. 我怀疑这可能与 Eager Execution(即时执行)有关——它是 TF >= 2.0 的默认行为。也许有人能验证这一点?

谢谢。

答案

将输入计算修改为下面的形式后,问题得到了解决。我认为问题(至少部分地)出在使用 Python 变量而不是 tf 变量上。下面的写法有效地保留了计算的符号结构。

# Same experiment as the question, but the loop bounds are passed in as
# tensors instead of being Python ints.
with tf.Graph().as_default() as g:
    with tf.device('/cpu:0'):
        @tf.function
        def mymatmul(A, B, C, m, n):
            """Accumulate A[i, k] * B[k, j] into C[i, j] and return C.

            Args:
                A: left operand, read as A[i, k].
                B: right operand, read as B[k, j].
                C: variable updated in place through ``C[i, j].assign``; its
                    current value is the starting point of the accumulation.
                m: bound of the i and j loops.
                n: bound of the k loop.

            Returns:
                C, after all assignments.
            """
            for i in range(m):
                # NOTE(review): j is bounded by m as well; this matches the
                # square 2x2 inputs used below — confirm before generalising.
                for j in range(m):
                    for k in range(n):
                        # Read-modify-write on the variable itself instead of a
                        # Python-side accumulator.
                        C[i,j].assign(tf.math.add(C[i, j], tf.math.multiply(A[i, k], B[k, j])))

            return C

        # 2x2 constant inputs and a zero-initialised 2x2 float32 variable.
        A = tf.constant([[1., 2.], [3., 4.]])
        B = tf.constant([[2., 1.], [4., 3.]])
        C = tf.Variable((tf.zeros((2, 2), dtype=tf.float32)))
        # Loop bounds are tf.constant tensors; the MLIR printed for this
        # version contains tf.While ops.
        m = tf.constant(2)
        n = tf.constant(2)
        mymatmul(A, B, C, m, n)

这将产生以下包含 tf.While 的 MLIR。

module attributes {tf.versions = {bad_consumers = [], min_consumer = 12 : i32, producer = 412 : i32}} {
  func @main() {
    %0 = "tf.Const"() {value = dense<[[1.000000e+00, 2.000000e+00], [3.000000e+00, 4.000000e+00]]> : tensor<2x2xf32>} : () -> tensor<2x2xf32>
    %1 = "tf.Const"() {value = dense<[[2.000000e+00, 1.000000e+00], [4.000000e+00, 3.000000e+00]]> : tensor<2x2xf32>} : () -> tensor<2x2xf32>
    %2 = "tf.Const"() {value = dense<2> : tensor<i32>} : () -> tensor<i32>
    %3 = "tf.Const"() {value = dense<0.000000e+00> : tensor<2x2xf32>} : () -> tensor<2x2xf32>
    %4 = "tf.VarHandleOp"() {_class = ["loc:@Variable"], allowed_devices = [], container = "", device = "/device:CPU:0", shared_name = "Variable"} : () -> tensor<!tf.resource<tensor<2x2xf32>>>
    %5 = "tf.StatefulPartitionedCall"(%0, %1, %4, %2, %2) {_collective_manager_ids = [], _read_only_resource_inputs = [], config = "", config_proto = "A7A3CPU101A7A3GPU10022J081", device = "/device:CPU:0", executor_type = "", f = @__inference_mymatmul_3650} : (tensor<2x2xf32>, tensor<2x2xf32>, tensor<!tf.resource<tensor<2x2xf32>>>, tensor<i32>, tensor<i32>) -> tensor<2x2xf32>
    %6 = "tf.VarIsInitializedOp"(%4) {device = "/device:CPU:0"} : (tensor<!tf.resource<tensor<2x2xf32>>>) -> tensor<i1>
    %7 = "tf.ReadVariableOp"(%4) {device = "/device:CPU:0"} : (tensor<!tf.resource<tensor<2x2xf32>>>) -> tensor<2x2xf32>
    "tf.AssignVariableOp"(%4, %3) {device = "/device:CPU:0"} : (tensor<!tf.resource<tensor<2x2xf32>>>, tensor<2x2xf32>) -> ()
    return
  }
  func @__inference_mymatmul_3650(%arg0: tensor<2x2xf32>, %arg1: tensor<2x2xf32>, %arg2: tensor<!tf.resource<tensor<2x2xf32>>>, %arg3: tensor<i32>, %arg4: tensor<i32>) -> tensor<2x2xf32> attributes {tf.signature.is_stateful} {
    %0 = "tf.Const"() {value = dense<[[1.000000e+00, 2.000000e+00], [3.000000e+00, 4.000000e+00]]> : tensor<2x2xf32>} : () -> tensor<2x2xf32>
    %1 = "tf.Const"() {value = dense<[[2.000000e+00, 1.000000e+00], [4.000000e+00, 3.000000e+00]]> : tensor<2x2xf32>} : () -> tensor<2x2xf32>
    %2 = "tf.Const"() {value = dense<1> : tensor<i32>} : () -> tensor<i32>
    %3 = "tf.Const"() {value = dense<0> : tensor<i32>} : () -> tensor<i32>
    %4 = "tf.Const"() {value = dense<2> : tensor<i32>} : () -> tensor<i32>
    %5:10 = "tf.While"(%3, %4, %3, %2, %4, %4, %4, %arg2, %0, %1) {_lower_using_switch_merge = true, _num_original_outputs = 10 : i64, _read_only_resource_inputs = [], body = @while_body_1410, cond = @while_cond_1400, device = "", is_stateless = false, output_shapes = [#tf.shape<>, #tf.shape<>, #tf.shape<>, #tf.shape<>, #tf.shape<>, #tf.shape<>, #tf.shape<>, #tf.shape<>, #tf.shape<2x2>, #tf.shape<2x2>], parallel_iterations = 10 : i64} : (tensor<i32>, tensor<i32>, tensor<i32>, tensor<i32>, tensor<i32>, tensor<i32>, tensor<i32>, tensor<!tf.resource<tensor<2x2xf32>>>, tensor<2x2xf32>, tensor<2x2xf32>) -> (tensor<i32>, tensor<i32>, tensor<i32>, tensor<i32>, tensor<i32>, tensor<i32>, tensor<i32>, tensor<!tf.resource<tensor<2x2xf32>>>, tensor<2x2xf32>, tensor<2x2xf32>)
    %6 = "tf.ReadVariableOp"(%arg2) {device = ""} : (tensor<!tf.resource<tensor<2x2xf32>>>) -> tensor<2x2xf32>
    %7 = "tf.Identity"(%6) {device = ""} : (tensor<2x2xf32>) -> tensor<2x2xf32>
    return %7 : tensor<2x2xf32>
  }
  func @while_body_1410(%arg0: tensor<i32>, %arg1: tensor<i32>, %arg2: tensor<i32>, %arg3: tensor<i32>, %arg4: tensor<i32>, %arg5: tensor<i32>, %arg6: tensor<i32>, %arg7: tensor<!tf.resource<tensor<2x2xf32>>>, %arg8: tensor<2x2xf32>, %arg9: tensor<2x2xf32>) -> (tensor<i32>, tensor<i32>, tensor<i32>, tensor<i32>, tensor<i32>, tensor<i32>, tensor<i32>, tensor<!tf.resource<tensor<2x2xf32>>>, tensor<2x2xf32>, tensor<2x2xf32>) attributes {tf.signature.is_stateful} {
    %0 = "tf.Const"() {value = dense<1> : tensor<i32>} : () -> tensor<i32>
    %1 = "tf.Const"() {value = dense<0> : tensor<i32>} : () -> tensor<i32>
    %2 = "tf.Maximum"(%arg5, %1) {device = ""} : (tensor<i32>, tensor<i32>) -> tensor<i32>
    %3 = "tf.FloorDiv"(%2, %0) {device = ""} : (tensor<i32>, tensor<i32>) -> tensor<i32>
    %4 = "tf.FloorMod"(%2, %0) {device = ""} : (tensor<i32>, tensor<i32>) -> tensor<i32>
    %5 = "tf.AddV2"(%arg2, %arg3) {device = ""} : (tensor<i32>, tensor<i32>) -> tensor<i32>
    %6 = "tf.AddV2"(%arg0, %0) {device = ""} : (tensor<i32>, tensor<i32>) -> tensor<i32>
    %7 = "tf.NotEqual"(%4, %1) {device = "", incompatible_shape_error = true} : (tensor<i32>, tensor<i32>) -> tensor<i1>
    %8 = "tf.Cast"(%7) {Truncate = false, device = ""} : (tensor<i1>) -> tensor<i32>
    %9 = "tf.AddV2"(%3, %8) {device = ""} : (tensor<i32>, tensor<i32>) -> tensor<i32>
    %10 = "tf.Maximum"(%9, %1) {device = ""} : (tensor<i32>, tensor<i32>) -> tensor<i32>
    %11:10 = "tf.While"(%1, %10, %1, %0, %2, %arg6, %arg7, %arg2, %arg8, %arg9) {_lower_using_switch_merge = true, _num_original_outputs = 10 : i64, _read_only_resource_inputs = [], body = @while_body_1830, cond = @while_cond_1820, device = "", is_stateless = false, output_shapes = [#tf.shape<>, #tf.shape<>, #tf.shape<>, #tf.shape<>, #tf.shape<>, #tf.shape<>, #tf.shape<>, #tf.shape<>, #tf.shape<2x2>, #tf.shape<2x2>], parallel_iterations = 10 : i64} : (tensor<i32>, tensor<i32>, tensor<i32>, tensor<i32>, tensor<i32>, tensor<i32>, tensor<!tf.resource<tensor<2x2xf32>>>, tensor<i32>, tensor<2x2xf32>, tensor<2x2xf32>) -> (tensor<i32>, tensor<i32>, tensor<i32>, tensor<i32>, tensor<i32>, tensor<i32>, tensor<!tf.resource<tensor<2x2xf32>>>, tensor<i32>, tensor<2x2xf32>, tensor<2x2xf32>)
    %12 = "tf.Identity"(%6) {device = ""} : (tensor<i32>) -> tensor<i32>
    %13 = "tf.Identity"(%arg1) {device = ""} : (tensor<i32>) -> tensor<i32>
    %14 = "tf.Identity"(%5) {device = ""} : (tensor<i32>) -> tensor<i32>
    return %12, %13, %14, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9 : tensor<i32>, tensor<i32>, tensor<i32>, tensor<i32>, tensor<i32>, tensor<i32>, tensor<i32>, tensor<!tf.res

以上是关于保留MLIR的TF方言中的输入控制结构。的主要内容,如果未能解决你的问题,请参考以下文章

MLIR:摩尔定律终结的编译器基础结构 论文解读

如何将 Thymeleaf HTML 中的多个参数传递给 Thymeleaf 方言处理器

MLIR:摩尔定律终结的编译器基础结构 论文解读

Buddy-MLIR 项目详解(入门 MLIR 极佳选择)

Buddy-MLIR 项目详解(入门 MLIR 极佳选择)

“不要保留活动” - 当应用程序恢复时,片段仅可见一秒钟