提交 cccb99d5 编写于 作者: 卢旭辉

Merge branch 'framework_type_in_net' into 'master'

fix: Add framework_type arg in Net so that we can get it from Net besides Op

See merge request applied-machine-learning/sysml/mace!1312
......@@ -714,3 +714,17 @@ class ConverterUtil(object):
return DataFormat.OIHW
else:
return None
@staticmethod
def set_framework_type(net, framework_type):
framework_type_arg = net.arg.add()
framework_type_arg.name = MaceKeyword.mace_framework_type_str
framework_type_arg.i = framework_type
@staticmethod
def framework_type(net):
framework_type_arg = ConverterUtil.get_arg(
net, MaceKeyword.mace_framework_type_str)
if framework_type_arg is None:
return None
return framework_type_arg.i
......@@ -209,6 +209,8 @@ class CaffeConverter(base_converter.ConverterInterface):
self._mace_net_def = mace_pb2.NetDef()
ConverterUtil.set_filter_format(self._mace_net_def, DataFormat.OIHW)
ConverterUtil.add_data_format_arg(self._mace_net_def, DataFormat.NCHW)
ConverterUtil.set_framework_type(
self._mace_net_def, FrameworkType.CAFFE.value)
self._caffe_net = CaffeNet()
self._caffe_layers = caffe_pb2.NetParameter()
caffe_weights = caffe_pb2.NetParameter()
......
......@@ -415,6 +415,8 @@ class OnnxConverter(base_converter.ConverterInterface):
ConverterUtil.set_filter_format(self._mace_net_def, DataFormat.OIHW)
ConverterUtil.add_data_format_arg(self._mace_net_def,
self._data_format)
ConverterUtil.set_framework_type(
self._mace_net_def, FrameworkType.ONNX.value)
onnx_model = onnx.load(src_model_file)
ir_version = onnx_model.ir_version
......
......@@ -204,6 +204,8 @@ class PytorchConverter(base_converter.ConverterInterface):
self._mace_net_def = mace_pb2.NetDef()
ConverterUtil.set_filter_format(self._mace_net_def, DataFormat.OIHW)
ConverterUtil.add_data_format_arg(self._mace_net_def, DataFormat.NCHW)
ConverterUtil.set_framework_type(
self._mace_net_def, FrameworkType.PYTORCH.value)
self._op_converters = {
NodeKind.AdaptiveAvgPool2D: self.convert_pool,
NodeKind.Add: self.convert_add,
......
......@@ -306,6 +306,8 @@ class TensorflowConverter(base_converter.ConverterInterface):
self._mace_net_def = mace_pb2.NetDef()
ConverterUtil.set_filter_format(self._mace_net_def, DataFormat.HWIO)
ConverterUtil.add_data_format_arg(self._mace_net_def, DataFormat.NHWC)
ConverterUtil.set_framework_type(
self._mace_net_def, FrameworkType.TENSORFLOW.value)
# import tensorflow graph
tf_graph_def = tf.GraphDef()
......
......@@ -1322,8 +1322,7 @@ class Transformer(base_converter.ConverterInterface):
# transform `input(4D) -> reshape(2D) -> matmul` to `fc(2D)`
# fc output is 2D in transformer, using as 4D in op kernel
# work for TensorFlow/PyTorch/ONNX
framework = ConverterUtil.get_arg(
op, MaceKeyword.mace_framework_type_str).i
framework = ConverterUtil.framework_type(net)
is_torch = framework == FrameworkType.PYTORCH.value
is_tf = framework == FrameworkType.TENSORFLOW.value
is_onnx = framework == FrameworkType.ONNX.value
......@@ -1333,7 +1332,8 @@ class Transformer(base_converter.ConverterInterface):
op.input[1] in self._consts and \
len(op.output_shape[0].dims) == 2 and \
(is_tf or is_torch or is_onnx) and \
op.input[0] in self._producer:
op.input[0] in self._producer and \
op.output[0] in self._consumers:
input_op = self._producer[op.input[0]]
input_shape = input_op.output_shape[0].dims
# check input op
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册