diff --git a/mace/kernels/expand_dims.h b/mace/kernels/expand_dims.h new file mode 100644 index 0000000000000000000000000000000000000000..94386f3532406905e7e9425e4b15f4d0093259ef --- /dev/null +++ b/mace/kernels/expand_dims.h @@ -0,0 +1,72 @@ +// Copyright 2018 Xiaomi, Inc. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef MACE_KERNELS_EXPAND_DIMS_H_ +#define MACE_KERNELS_EXPAND_DIMS_H_ + +#include + +#include "mace/core/future.h" +#include "mace/core/tensor.h" + +#ifdef MACE_ENABLE_OPENCL +#include "mace/core/runtime/opencl/cl2_header.h" +#endif // MACE_ENABLE_OPENCL + +namespace mace { +namespace kernels { + +struct ExpandDimsBase { + explicit ExpandDimsBase(int axis) : axis_(axis) {} + + int axis_; +}; + +template +struct ExpandDimsFunctor; + +template +struct ExpandDimsFunctor : ExpandDimsBase { + explicit ExpandDimsFunctor(int axis) : ExpandDimsBase(axis) {} + + MaceStatus operator()(const Tensor *input, + Tensor *output, + StatsFuture *future) { + MACE_UNUSED(future); + + index_t input_dims_size = input->dim_size(); + if ( axis_ < 0 ) { + axis_ += input_dims_size + 1; + } + MACE_CHECK(axis_ >= 0 && axis_ <= input_dims_size, + "axis is out of bound: ", axis_); + const std::vector input_shape = input->shape(); + std::vector output_shape; + output_shape.insert(output_shape.end(), input_shape.begin(), + input_shape.begin() + axis_); + output_shape.insert(output_shape.end(), 1); + output_shape.insert(output_shape.end(), input_shape.begin() + axis_, + input_shape.end()); + + output->ReuseTensorBuffer(*input); + output->Reshape(output_shape); + + return MACE_SUCCESS; + } +}; + +} // namespace kernels +} // namespace mace + +#endif // MACE_KERNELS_EXPAND_DIMS_H_ diff --git a/mace/kernels/reverse.h b/mace/kernels/reverse.h new file mode 100644 index 0000000000000000000000000000000000000000..60cc44eba7a1b1a390b82e3100021e0868e27367 --- /dev/null +++ b/mace/kernels/reverse.h @@ -0,0 +1,81 @@ +// Copyright 2018 Xiaomi, Inc. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef MACE_KERNELS_REVERSE_H_ +#define MACE_KERNELS_REVERSE_H_ + +#include +#include + +#include "mace/core/future.h" +#include "mace/core/tensor.h" + +#ifdef MACE_ENABLE_OPENCL +#include "mace/core/runtime/opencl/cl2_header.h" +#endif // MACE_ENABLE_OPENCL + +namespace mace { +namespace kernels { + +template +struct ReverseFunctor; + +template +struct ReverseFunctor { + MaceStatus operator()(const Tensor *input, + const Tensor *axis, + Tensor *output, + StatsFuture *future) { + MACE_CHECK(axis->dim_size() == 1, "Only support reverse in one axis now"); + + const int32_t *axis_data = axis->data(); + const index_t reverse_dim = *axis_data >= 0 ? + *axis_data : *axis_data + input->dim_size(); + MACE_CHECK(reverse_dim >= 0 && reverse_dim < input->dim_size(), + "axis must be in the range [-rank(input), rank(input))"); + + const std::vector input_shape = input->shape(); + MACE_RETURN_IF_ERROR(output->ResizeLike(input)); + + index_t high_dim_elem_size = + std::accumulate(input_shape.begin(), input_shape.begin() + reverse_dim, + 1, std::multiplies()); + index_t low_dim_elem_size = + std::accumulate(input_shape.begin() + reverse_dim + 1, + input_shape.end(), 1, std::multiplies()); + + const T *input_data = input->data(); + T *output_data = output->mutable_data(); + + index_t reverse_size = input_shape[reverse_dim] * low_dim_elem_size; + for (index_t h = 0; h < high_dim_elem_size; ++h) { + int input_idx = h * reverse_size; + int output_idx = input_idx + reverse_size; + for (index_t i = 0; i < input_shape[reverse_dim]; ++i) { + output_idx -= low_dim_elem_size; + memcpy(output_data + output_idx, input_data + input_idx, + sizeof(T) * low_dim_elem_size); + input_idx += low_dim_elem_size; + } + } + + SetFutureDefaultWaitFn(future); + return MACE_SUCCESS; + } +}; + +} // namespace kernels +} // namespace mace + +#endif // MACE_KERNELS_REVERSE_H_ diff --git a/mace/ops/expand_dims.cc b/mace/ops/expand_dims.cc new file mode 100644 index 0000000000000000000000000000000000000000..07d940c6320f77fe2e6efbf16eba9d8a414a99c3 --- /dev/null +++ b/mace/ops/expand_dims.cc @@ -0,0 +1,34 @@ +// Copyright 2018 Xiaomi, Inc. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "mace/ops/expand_dims.h" + +namespace mace { +namespace ops { + +void Register_ExpandDims(OperatorRegistryBase *op_registry) { + MACE_REGISTER_OPERATOR(op_registry, OpKeyBuilder("ExpandDims") + .Device(DeviceType::CPU) + .TypeConstraint("T") + .Build(), + ExpandDimsOp); + MACE_REGISTER_OPERATOR(op_registry, OpKeyBuilder("ExpandDims") + .Device(DeviceType::CPU) + .TypeConstraint("T") + .Build(), + ExpandDimsOp); +} + +} // namespace ops +} // namespace mace diff --git a/mace/ops/expand_dims.h b/mace/ops/expand_dims.h new file mode 100644 index 0000000000000000000000000000000000000000..b466c7c41072023025e29ba247c0582da7ac37ab --- /dev/null +++ b/mace/ops/expand_dims.h @@ -0,0 +1,50 @@ +// Copyright 2018 Xiaomi, Inc. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef MACE_OPS_EXPAND_DIMS_H_ +#define MACE_OPS_EXPAND_DIMS_H_ + +#include + +#include "mace/core/operator.h" +#include "mace/kernels/expand_dims.h" + +namespace mace { +namespace ops { + +template +class ExpandDimsOp : public Operator { + public: + ExpandDimsOp(const OperatorDef &op_def, Workspace *ws) + : Operator(op_def, ws), + functor_(OperatorBase::GetOptionalArg("axis", 0)) {} + + MaceStatus Run(StatsFuture *future) override { + const Tensor *input = this->Input(INPUT); + Tensor *output = this->Output(OUTPUT); + + return functor_(input, output, future); + } + + private: + kernels::ExpandDimsFunctor functor_; + + MACE_OP_INPUT_TAGS(INPUT); + MACE_OP_OUTPUT_TAGS(OUTPUT); +}; + +} // namespace ops +} // namespace mace + +#endif // MACE_OPS_EXPAND_DIMS_H_ diff --git a/mace/ops/expand_dims_test.cc b/mace/ops/expand_dims_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..f5650c9cd141514d4fe47167e72f48cc79ad1646 --- /dev/null +++ b/mace/ops/expand_dims_test.cc @@ -0,0 +1,65 @@ +// Copyright 2018 Xiaomi, Inc. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "gmock/gmock.h" +#include "mace/core/operator.h" +#include "mace/ops/ops_test_util.h" + +namespace mace { +namespace ops { +namespace test { + +class ExpandDimsTest : public OpsTestBase {}; + +namespace { +template +void TestExpandDims(const std::vector &input_shape, + const int &axis, + const std::vector &output_shape) { + // Construct graph + OpsTestNet net; + OpDefBuilder("ExpandDims", "ExpandDimsTest") + .Input("Input") + .AddIntArg("axis", static_cast(axis)) + .Output("Output") + .Finalize(net.NewOperatorDef()); + + // Add input data + net.AddRandomInput("Input", input_shape); + + // Run + net.RunOp(); + + auto input = net.GetTensor("Input"); + auto output = net.GetTensor("Output"); + + EXPECT_THAT(output->shape(), ::testing::ContainerEq(output_shape)); + + const T *input_ptr = input->data(); + const T *output_ptr = output->data(); + const int size = output->size(); + for (int i = 0; i < size; ++i) { + ASSERT_EQ(input_ptr[i], output_ptr[i]); + } +} +} // namespace + +TEST_F(ExpandDimsTest, SimpleCPU) { + TestExpandDims({3, 2, 1}, 1, {3, 1, 2, 1}); + TestExpandDims({1, 2, 3}, -1, {1, 2, 3, 1}); +} + +} // namespace test +} // namespace ops +} // namespace mace diff --git a/mace/ops/ops_register.cc b/mace/ops/ops_register.cc index 8db515578e2efd437a1d68af321cb28ca05ab516..dddc032614e37a43cc4968e969ee43bafa1b30f8 100644 --- a/mace/ops/ops_register.cc +++ b/mace/ops/ops_register.cc @@ -34,6 +34,7 @@ extern void Register_DepthToSpace(OperatorRegistryBase *op_registry); extern void Register_DepthwiseConv2d(OperatorRegistryBase *op_registry); extern void Register_Dequantize(OperatorRegistryBase *op_registry); extern void Register_Eltwise(OperatorRegistryBase *op_registry); +extern void Register_ExpandDims(OperatorRegistryBase *op_registry); extern void Register_Fill(OperatorRegistryBase *op_registry); extern void Register_FoldedBatchNorm(OperatorRegistryBase *op_registry); extern void Register_FullyConnected(OperatorRegistryBase *op_registry); @@ -50,6 +51,7 @@ extern void Register_ReduceMean(OperatorRegistryBase *op_registry); extern void Register_Reshape(OperatorRegistryBase *op_registry); extern void Register_ResizeBicubic(OperatorRegistryBase *op_registry); extern void Register_ResizeBilinear(OperatorRegistryBase *op_registry); +extern void Register_Reverse(OperatorRegistryBase *op_registry); extern void Register_ScalarMath(OperatorRegistryBase *op_registry); extern void Register_Shape(OperatorRegistryBase *op_registry); extern void Register_Split(OperatorRegistryBase *op_registry); @@ -90,6 +92,7 @@ OperatorRegistry::OperatorRegistry() : OperatorRegistryBase() { ops::Register_DepthwiseConv2d(this); ops::Register_Dequantize(this); ops::Register_Eltwise(this); + ops::Register_ExpandDims(this); ops::Register_Fill(this); ops::Register_FoldedBatchNorm(this); ops::Register_FullyConnected(this); @@ -106,6 +109,7 @@ OperatorRegistry::OperatorRegistry() : OperatorRegistryBase() { ops::Register_Reshape(this); ops::Register_ResizeBicubic(this); ops::Register_ResizeBilinear(this); + ops::Register_Reverse(this); ops::Register_ScalarMath(this); ops::Register_Shape(this); ops::Register_Split(this); diff --git a/mace/ops/reverse.cc b/mace/ops/reverse.cc new file mode 100644 index 0000000000000000000000000000000000000000..4660fba7dd9f73de277aa0893585c639f85578de --- /dev/null +++ b/mace/ops/reverse.cc @@ -0,0 +1,29 @@ +// Copyright 2018 Xiaomi, Inc. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "mace/ops/reverse.h" + +namespace mace { +namespace ops { + +void Register_Reverse(OperatorRegistryBase *op_registry) { + MACE_REGISTER_OPERATOR(op_registry, OpKeyBuilder("Reverse") + .Device(DeviceType::CPU) + .TypeConstraint("T") + .Build(), + ReverseOp); +} + +} // namespace ops +} // namespace mace diff --git a/mace/ops/reverse.h b/mace/ops/reverse.h new file mode 100644 index 0000000000000000000000000000000000000000..8f60dc66a7a089dd0e54ba6541e32734ff9893e1 --- /dev/null +++ b/mace/ops/reverse.h @@ -0,0 +1,50 @@ +// Copyright 2018 Xiaomi, Inc. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef MACE_OPS_REVERSE_H_ +#define MACE_OPS_REVERSE_H_ + +#include + +#include "mace/core/operator.h" +#include "mace/kernels/reverse.h" + +namespace mace { +namespace ops { + +template +class ReverseOp : public Operator { + public: + ReverseOp(const OperatorDef &operator_def, Workspace *ws) + : Operator(operator_def, ws) {} + + MaceStatus Run(StatsFuture *future) override { + const Tensor *input = this->Input(INPUT); + const Tensor *axis = this->Input(AXIS); + Tensor *output = this->Output(OUTPUT); + return functor_(input, axis, output, future); + } + + private: + kernels::ReverseFunctor functor_; + + protected: + MACE_OP_INPUT_TAGS(INPUT, AXIS); + MACE_OP_OUTPUT_TAGS(OUTPUT); +}; + +} // namespace ops +} // namespace mace + +#endif // MACE_OPS_REVERSE_H_ diff --git a/mace/ops/reverse_test.cc b/mace/ops/reverse_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..afa17e502e9800556eacc7eebd0700c7e429a58f --- /dev/null +++ b/mace/ops/reverse_test.cc @@ -0,0 +1,62 @@ +// Copyright 2018 Xiaomi, Inc. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "mace/core/operator.h" +#include "mace/ops/ops_test_util.h" + +namespace mace { +namespace ops { +namespace test { + +class ReverseOpTest : public OpsTestBase {}; + +namespace { + +void TestReverse(const std::vector &input_shape, + const std::vector &input, + const std::vector &axis_shape, + const std::vector &axis, + const std::vector &outputs) { + OpsTestNet net; + net.AddInputFromArray("Input", input_shape, input); + net.AddInputFromArray("Axis", axis_shape, axis); + + OpDefBuilder("Reverse", "ReverseOpTest") + .Input("Input") + .Input("Axis") + .Output("Output") + .Finalize(net.NewOperatorDef()); + + net.RunOp(); + + net.AddInputFromArray("ExpectedOutput", input_shape, + outputs); + ExpectTensorNear(*net.GetOutput("ExpectedOutput"), + *net.GetOutput("Output")); +} + +} // namespace + +TEST_F(ReverseOpTest, SimpleCPU) { + TestReverse({2, 3, 2}, {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11}, {1}, {0}, + {6, 7, 8, 9, 10, 11, 0, 1, 2, 3, 4, 5}); + TestReverse({2, 3, 2}, {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11}, {1}, {1}, + {4, 5, 2, 3, 0, 1, 10, 11, 8, 9, 6, 7}); + TestReverse({2, 3, 2}, {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11}, {1}, {2}, + {1, 0, 3, 2, 5, 4, 7, 6, 9, 8, 11, 10}); +} + +} // namespace test +} // namespace ops +} // namespace mace diff --git a/mace/python/tools/converter_tool/base_converter.py b/mace/python/tools/converter_tool/base_converter.py index a46af22440ad2d38233a98527eaeaa16fd791a83..faebe494d321c11675bacaa24316a4741e183861 100644 --- a/mace/python/tools/converter_tool/base_converter.py +++ b/mace/python/tools/converter_tool/base_converter.py @@ -87,6 +87,7 @@ MaceSupportedOps = [ 'DepthwiseConv2d', 'Dequantize', 'Eltwise', + 'ExpandDims', 'FoldedBatchNorm', 'Fill', 'FullyConnected', @@ -104,6 +105,7 @@ MaceSupportedOps = [ 'Reshape', 'ResizeBicubic', 'ResizeBilinear', + 'Reverse', 'ScalarMath', 'Slice', 'Split', diff --git a/mace/python/tools/converter_tool/tensorflow_converter.py b/mace/python/tools/converter_tool/tensorflow_converter.py index 24799631bacfd4fea83e46b7fc4fcc6660b57cc7..08baa46208f1ab831c06c500a2ce1ca21e674a28 100644 --- a/mace/python/tools/converter_tool/tensorflow_converter.py +++ b/mace/python/tools/converter_tool/tensorflow_converter.py @@ -73,6 +73,7 @@ TFSupportedOps = [ 'FusedBatchNorm', 'AvgPool', 'MaxPool', + 'ExpandDims', 'Squeeze', 'MatMul', 'BatchMatMul', @@ -95,6 +96,7 @@ TFSupportedOps = [ 'Gather', 'StridedSlice', 'Slice', + 'ReverseV2', 'Stack', 'Pack', 'Unstack', @@ -181,6 +183,7 @@ class TensorflowConverter(base_converter.ConverterInterface): TFOpType.Identity.name: self.convert_identity, TFOpType.Reshape.name: self.convert_reshape, TFOpType.Shape.name: self.convert_shape, + TFOpType.ExpandDims.name: self.convert_expand_dims, TFOpType.Squeeze.name: self.convert_squeeze, TFOpType.Transpose.name: self.convert_transpose, TFOpType.Softmax.name: self.convert_softmax, @@ -198,6 +201,7 @@ class TensorflowConverter(base_converter.ConverterInterface): TFOpType.Gather.name: self.convert_gather, TFOpType.StridedSlice.name: self.convert_stridedslice, TFOpType.Slice.name: self.convert_slice, + TFOpType.ReverseV2.name: self.convert_reverse, TFOpType.Pack.name: self.convert_stack, TFOpType.Stack.name: self.convert_stack, TFOpType.Unpack.name: self.convert_unstack, @@ -225,9 +229,12 @@ class TensorflowConverter(base_converter.ConverterInterface): self._skip_tensor = set() + self._output_shape_list = [] + self._output_shape_op_list = [] + def run(self): with tf.Session() as session: - self.convert_ops() + self.convert_ops(session) self.replace_input_output_tensor_name() return self._mace_net_def @@ -267,12 +274,19 @@ class TensorflowConverter(base_converter.ConverterInterface): else: return tensor_name[:idx] - def convert_ops(self): + def update_output_shapes(self, sess): + output_shapes = sess.run(self._output_shape_op_list, + feed_dict=self._placeholders) + for i in range(len(self._output_shape_list)): + self._output_shape_list[i].dims.extend(output_shapes[i]) + + def convert_ops(self, sess): for tf_op in self._tf_graph.get_operations(): mace_check(tf_op.type in self._op_converters, "Mace does not support tensorflow op type %s yet" % tf_op.type) self._op_converters[tf_op.type](tf_op) + self.update_output_shapes(sess) self.convert_tensors() def convert_tensors(self): @@ -306,7 +320,7 @@ class TensorflowConverter(base_converter.ConverterInterface): # this function tries to infer tensor shape, but some dimension shape # may be undefined due to variance of input length - def infer_tensor_shape(self, tensor): + def infer_tensor_shape(self, output_shape, tensor): inferred_tensor_shape = tensor.shape.as_list() inferred_success = True for _, dim in enumerate(inferred_tensor_shape): @@ -314,10 +328,10 @@ class TensorflowConverter(base_converter.ConverterInterface): inferred_success = False break if inferred_success: - return inferred_tensor_shape - - tensor_shape = tf.shape(tensor).eval(feed_dict=self._placeholders) - return tensor_shape + output_shape.dims.extend(inferred_tensor_shape) + else: + self._output_shape_list.append(output_shape) + self._output_shape_op_list.append(tf.shape(tensor)) def convert_nop(self, tf_op): pass @@ -330,7 +344,7 @@ class TensorflowConverter(base_converter.ConverterInterface): op.output.extend([tf_output.name for tf_output in tf_op.outputs]) for tf_output in tf_op.outputs: output_shape = op.output_shape.add() - output_shape.dims.extend(self.infer_tensor_shape(tf_output)) + self.infer_tensor_shape(output_shape, tf_output) data_type_arg = op.arg.add() data_type_arg.name = 'T' @@ -678,6 +692,21 @@ class TensorflowConverter(base_converter.ConverterInterface): op = self.convert_general_op(tf_op) op.type = MaceOp.Reshape.name + def convert_expand_dims(self, tf_op): + op = self.convert_general_op(tf_op) + op.type = MaceOp.ExpandDims.name + + axis_arg = op.arg.add() + axis_arg.name = MaceKeyword.mace_axis_str + try: + axis_value = tf_op.get_attr('dim') + except ValueError: + try: + axis_value = tf_op.get_attr('axis') + except ValueError: + axis_value = 0 + axis_arg.i = axis_value + def convert_squeeze(self, tf_op): op = self.convert_general_op(tf_op) op.type = MaceOp.Squeeze.name @@ -783,6 +812,10 @@ class TensorflowConverter(base_converter.ConverterInterface): arg.name = 'slice' arg.i = 1 + def convert_reverse(self, tf_op): + op = self.convert_general_op(tf_op) + op.type = MaceOp.Reverse.name + def convert_stack(self, tf_op): op = self.convert_general_op(tf_op) op.type = MaceOp.Stack.name