diff --git a/mace/python/tools/converter_tool/tensorflow_converter.py b/mace/python/tools/converter_tool/tensorflow_converter.py index 4b48ab9d71824e249a52600e47bd38e1cf37a534..a76fca486a210249e6eb5fd8866868d2f09930b3 100644 --- a/mace/python/tools/converter_tool/tensorflow_converter.py +++ b/mace/python/tools/converter_tool/tensorflow_converter.py @@ -885,10 +885,10 @@ class TensorflowConverter(base_converter.ConverterInterface): op.output_type.extend([mace_pb2.DT_INT32]) def convert_split(self, tf_op): - axis = tf_op.inputs[0].eval().astype(np.int32) - axis = len(op.output_shape[0].dims) + axis if axis < 0 else axis op = self.convert_general_op(tf_op) op.type = MaceOp.Split.name + axis = tf_op.inputs[0].eval().astype(np.int32) + axis = len(op.output_shape[0].dims) + axis if axis < 0 else axis del op.input[0] axis_arg = op.arg.add()