From b0d9a3aa526303c4f276e7ef1044ab2913f40aea Mon Sep 17 00:00:00 2001 From: like15 Date: Tue, 27 Oct 2020 18:01:02 +0800 Subject: [PATCH] fix: Add framework_type arg in Net so that we can get it from Net besides Op --- tools/python/transform/base_converter.py | 14 ++++++++++++++ tools/python/transform/caffe_converter.py | 2 ++ tools/python/transform/onnx_converter.py | 2 ++ tools/python/transform/pytorch_converter.py | 2 ++ tools/python/transform/tensorflow_converter.py | 2 ++ tools/python/transform/transformer.py | 6 +++--- 6 files changed, 25 insertions(+), 3 deletions(-) diff --git a/tools/python/transform/base_converter.py b/tools/python/transform/base_converter.py index 9273f156..9b0762a8 100644 --- a/tools/python/transform/base_converter.py +++ b/tools/python/transform/base_converter.py @@ -714,3 +714,17 @@ class ConverterUtil(object): return DataFormat.OIHW else: return None + + @staticmethod + def set_framework_type(net, framework_type): + framework_type_arg = net.arg.add() + framework_type_arg.name = MaceKeyword.mace_framework_type_str + framework_type_arg.i = framework_type + + @staticmethod + def framework_type(net): + framework_type_arg = ConverterUtil.get_arg( + net, MaceKeyword.mace_framework_type_str) + if framework_type_arg is None: + return None + return framework_type_arg.i diff --git a/tools/python/transform/caffe_converter.py b/tools/python/transform/caffe_converter.py index 0eddb5f9..b9a0777d 100644 --- a/tools/python/transform/caffe_converter.py +++ b/tools/python/transform/caffe_converter.py @@ -209,6 +209,8 @@ class CaffeConverter(base_converter.ConverterInterface): self._mace_net_def = mace_pb2.NetDef() ConverterUtil.set_filter_format(self._mace_net_def, DataFormat.OIHW) ConverterUtil.add_data_format_arg(self._mace_net_def, DataFormat.NCHW) + ConverterUtil.set_framework_type( + self._mace_net_def, FrameworkType.CAFFE.value) self._caffe_net = CaffeNet() self._caffe_layers = caffe_pb2.NetParameter() caffe_weights = caffe_pb2.NetParameter() diff --git a/tools/python/transform/onnx_converter.py b/tools/python/transform/onnx_converter.py index 041b625c..4a4662af 100644 --- a/tools/python/transform/onnx_converter.py +++ b/tools/python/transform/onnx_converter.py @@ -415,6 +415,8 @@ class OnnxConverter(base_converter.ConverterInterface): ConverterUtil.set_filter_format(self._mace_net_def, DataFormat.OIHW) ConverterUtil.add_data_format_arg(self._mace_net_def, self._data_format) + ConverterUtil.set_framework_type( + self._mace_net_def, FrameworkType.ONNX.value) onnx_model = onnx.load(src_model_file) ir_version = onnx_model.ir_version diff --git a/tools/python/transform/pytorch_converter.py b/tools/python/transform/pytorch_converter.py index 019cce83..5ed65c30 100644 --- a/tools/python/transform/pytorch_converter.py +++ b/tools/python/transform/pytorch_converter.py @@ -204,6 +204,8 @@ class PytorchConverter(base_converter.ConverterInterface): self._mace_net_def = mace_pb2.NetDef() ConverterUtil.set_filter_format(self._mace_net_def, DataFormat.OIHW) ConverterUtil.add_data_format_arg(self._mace_net_def, DataFormat.NCHW) + ConverterUtil.set_framework_type( + self._mace_net_def, FrameworkType.PYTORCH.value) self._op_converters = { NodeKind.AdaptiveAvgPool2D: self.convert_pool, NodeKind.Add: self.convert_add, diff --git a/tools/python/transform/tensorflow_converter.py b/tools/python/transform/tensorflow_converter.py index 1ee69b0d..67d72ab0 100644 --- a/tools/python/transform/tensorflow_converter.py +++ b/tools/python/transform/tensorflow_converter.py @@ -306,6 +306,8 @@ class TensorflowConverter(base_converter.ConverterInterface): self._mace_net_def = mace_pb2.NetDef() ConverterUtil.set_filter_format(self._mace_net_def, DataFormat.HWIO) ConverterUtil.add_data_format_arg(self._mace_net_def, DataFormat.NHWC) + ConverterUtil.set_framework_type( + self._mace_net_def, FrameworkType.TENSORFLOW.value) # import tensorflow graph tf_graph_def = tf.GraphDef() diff --git a/tools/python/transform/transformer.py b/tools/python/transform/transformer.py index 31b4dcac..ff9788df 100644 --- a/tools/python/transform/transformer.py +++ b/tools/python/transform/transformer.py @@ -1322,8 +1322,7 @@ class Transformer(base_converter.ConverterInterface): # transform `input(4D) -> reshape(2D) -> matmul` to `fc(2D)` # fc output is 2D in transformer, using as 4D in op kernel # work for TensorFlow/PyTorch/ONNX - framework = ConverterUtil.get_arg( - op, MaceKeyword.mace_framework_type_str).i + framework = ConverterUtil.framework_type(net) is_torch = framework == FrameworkType.PYTORCH.value is_tf = framework == FrameworkType.TENSORFLOW.value is_onnx = framework == FrameworkType.ONNX.value @@ -1333,7 +1332,8 @@ class Transformer(base_converter.ConverterInterface): op.input[1] in self._consts and \ len(op.output_shape[0].dims) == 2 and \ (is_tf or is_torch or is_onnx) and \ - op.input[0] in self._producer: + op.input[0] in self._producer and \ + op.output[0] in self._consumers: input_op = self._producer[op.input[0]] input_shape = input_op.output_shape[0].dims # check input op -- GitLab