diff --git a/tools/python/transform/base_converter.py b/tools/python/transform/base_converter.py index 9273f156438f4e753f1c4abaac3d252e8400d2df..9b0762a8bc8f640abfff4246ad3e73282f4582cb 100644 --- a/tools/python/transform/base_converter.py +++ b/tools/python/transform/base_converter.py @@ -714,3 +714,17 @@ class ConverterUtil(object): return DataFormat.OIHW else: return None + + @staticmethod + def set_framework_type(net, framework_type): + framework_type_arg = net.arg.add() + framework_type_arg.name = MaceKeyword.mace_framework_type_str + framework_type_arg.i = framework_type + + @staticmethod + def framework_type(net): + framework_type_arg = ConverterUtil.get_arg( + net, MaceKeyword.mace_framework_type_str) + if framework_type_arg is None: + return None + return framework_type_arg.i diff --git a/tools/python/transform/caffe_converter.py b/tools/python/transform/caffe_converter.py index 0eddb5f958871004aff28bc279a95fbb6dc36529..b9a0777ddfb8dda9a381558460ec6ec55cdb4910 100644 --- a/tools/python/transform/caffe_converter.py +++ b/tools/python/transform/caffe_converter.py @@ -209,6 +209,8 @@ class CaffeConverter(base_converter.ConverterInterface): self._mace_net_def = mace_pb2.NetDef() ConverterUtil.set_filter_format(self._mace_net_def, DataFormat.OIHW) ConverterUtil.add_data_format_arg(self._mace_net_def, DataFormat.NCHW) + ConverterUtil.set_framework_type( + self._mace_net_def, FrameworkType.CAFFE.value) self._caffe_net = CaffeNet() self._caffe_layers = caffe_pb2.NetParameter() caffe_weights = caffe_pb2.NetParameter() diff --git a/tools/python/transform/onnx_converter.py b/tools/python/transform/onnx_converter.py index 041b625c822a54d58459af38dcb78e488682154b..4a4662af2453d89dc710e3646251cffdf9bb03a2 100644 --- a/tools/python/transform/onnx_converter.py +++ b/tools/python/transform/onnx_converter.py @@ -415,6 +415,8 @@ class OnnxConverter(base_converter.ConverterInterface): ConverterUtil.set_filter_format(self._mace_net_def, DataFormat.OIHW) ConverterUtil.add_data_format_arg(self._mace_net_def, self._data_format) + ConverterUtil.set_framework_type( + self._mace_net_def, FrameworkType.ONNX.value) onnx_model = onnx.load(src_model_file) ir_version = onnx_model.ir_version diff --git a/tools/python/transform/pytorch_converter.py b/tools/python/transform/pytorch_converter.py index 019cce83e65abae9b1c6a814d6e05ed45c1634f0..5ed65c30b7c0c96252c28096153b9c3fb226d950 100644 --- a/tools/python/transform/pytorch_converter.py +++ b/tools/python/transform/pytorch_converter.py @@ -204,6 +204,8 @@ class PytorchConverter(base_converter.ConverterInterface): self._mace_net_def = mace_pb2.NetDef() ConverterUtil.set_filter_format(self._mace_net_def, DataFormat.OIHW) ConverterUtil.add_data_format_arg(self._mace_net_def, DataFormat.NCHW) + ConverterUtil.set_framework_type( + self._mace_net_def, FrameworkType.PYTORCH.value) self._op_converters = { NodeKind.AdaptiveAvgPool2D: self.convert_pool, NodeKind.Add: self.convert_add, diff --git a/tools/python/transform/tensorflow_converter.py b/tools/python/transform/tensorflow_converter.py index 1ee69b0dc3889a63e77cdf8c8c8bd567e94cb5ff..67d72ab0c90445efc395693f2e435369ae0e084e 100644 --- a/tools/python/transform/tensorflow_converter.py +++ b/tools/python/transform/tensorflow_converter.py @@ -306,6 +306,8 @@ class TensorflowConverter(base_converter.ConverterInterface): self._mace_net_def = mace_pb2.NetDef() ConverterUtil.set_filter_format(self._mace_net_def, DataFormat.HWIO) ConverterUtil.add_data_format_arg(self._mace_net_def, DataFormat.NHWC) + ConverterUtil.set_framework_type( + self._mace_net_def, FrameworkType.TENSORFLOW.value) # import tensorflow graph tf_graph_def = tf.GraphDef() diff --git a/tools/python/transform/transformer.py b/tools/python/transform/transformer.py index 31b4dcac6d00f70d3b01344831d8b6cebb659731..ff9788dfc409d83874a530a8b6809afc33b38d1d 100644 --- a/tools/python/transform/transformer.py +++ b/tools/python/transform/transformer.py @@ -1322,8 +1322,7 @@ class Transformer(base_converter.ConverterInterface): # transform `input(4D) -> reshape(2D) -> matmul` to `fc(2D)` # fc output is 2D in transformer, using as 4D in op kernel # work for TensorFlow/PyTorch/ONNX - framework = ConverterUtil.get_arg( - op, MaceKeyword.mace_framework_type_str).i + framework = ConverterUtil.framework_type(net) is_torch = framework == FrameworkType.PYTORCH.value is_tf = framework == FrameworkType.TENSORFLOW.value is_onnx = framework == FrameworkType.ONNX.value @@ -1333,7 +1332,8 @@ class Transformer(base_converter.ConverterInterface): op.input[1] in self._consts and \ len(op.output_shape[0].dims) == 2 and \ (is_tf or is_torch or is_onnx) and \ - op.input[0] in self._producer: + op.input[0] in self._producer and \ + op.output[0] in self._consumers: input_op = self._producer[op.input[0]] input_shape = input_op.output_shape[0].dims # check input op