提交 cccb99d5 编写于 作者: 卢旭辉

Merge branch 'framework_type_in_net' into 'master'

fix: Add framework_type arg in Net so that we can get it from Net besides Op

See merge request applied-machine-learning/sysml/mace!1312
...@@ -714,3 +714,17 @@ class ConverterUtil(object): ...@@ -714,3 +714,17 @@ class ConverterUtil(object):
return DataFormat.OIHW return DataFormat.OIHW
else: else:
return None return None
@staticmethod
def set_framework_type(net, framework_type):
framework_type_arg = net.arg.add()
framework_type_arg.name = MaceKeyword.mace_framework_type_str
framework_type_arg.i = framework_type
@staticmethod
def framework_type(net):
framework_type_arg = ConverterUtil.get_arg(
net, MaceKeyword.mace_framework_type_str)
if framework_type_arg is None:
return None
return framework_type_arg.i
...@@ -209,6 +209,8 @@ class CaffeConverter(base_converter.ConverterInterface): ...@@ -209,6 +209,8 @@ class CaffeConverter(base_converter.ConverterInterface):
self._mace_net_def = mace_pb2.NetDef() self._mace_net_def = mace_pb2.NetDef()
ConverterUtil.set_filter_format(self._mace_net_def, DataFormat.OIHW) ConverterUtil.set_filter_format(self._mace_net_def, DataFormat.OIHW)
ConverterUtil.add_data_format_arg(self._mace_net_def, DataFormat.NCHW) ConverterUtil.add_data_format_arg(self._mace_net_def, DataFormat.NCHW)
ConverterUtil.set_framework_type(
self._mace_net_def, FrameworkType.CAFFE.value)
self._caffe_net = CaffeNet() self._caffe_net = CaffeNet()
self._caffe_layers = caffe_pb2.NetParameter() self._caffe_layers = caffe_pb2.NetParameter()
caffe_weights = caffe_pb2.NetParameter() caffe_weights = caffe_pb2.NetParameter()
......
...@@ -415,6 +415,8 @@ class OnnxConverter(base_converter.ConverterInterface): ...@@ -415,6 +415,8 @@ class OnnxConverter(base_converter.ConverterInterface):
ConverterUtil.set_filter_format(self._mace_net_def, DataFormat.OIHW) ConverterUtil.set_filter_format(self._mace_net_def, DataFormat.OIHW)
ConverterUtil.add_data_format_arg(self._mace_net_def, ConverterUtil.add_data_format_arg(self._mace_net_def,
self._data_format) self._data_format)
ConverterUtil.set_framework_type(
self._mace_net_def, FrameworkType.ONNX.value)
onnx_model = onnx.load(src_model_file) onnx_model = onnx.load(src_model_file)
ir_version = onnx_model.ir_version ir_version = onnx_model.ir_version
......
...@@ -204,6 +204,8 @@ class PytorchConverter(base_converter.ConverterInterface): ...@@ -204,6 +204,8 @@ class PytorchConverter(base_converter.ConverterInterface):
self._mace_net_def = mace_pb2.NetDef() self._mace_net_def = mace_pb2.NetDef()
ConverterUtil.set_filter_format(self._mace_net_def, DataFormat.OIHW) ConverterUtil.set_filter_format(self._mace_net_def, DataFormat.OIHW)
ConverterUtil.add_data_format_arg(self._mace_net_def, DataFormat.NCHW) ConverterUtil.add_data_format_arg(self._mace_net_def, DataFormat.NCHW)
ConverterUtil.set_framework_type(
self._mace_net_def, FrameworkType.PYTORCH.value)
self._op_converters = { self._op_converters = {
NodeKind.AdaptiveAvgPool2D: self.convert_pool, NodeKind.AdaptiveAvgPool2D: self.convert_pool,
NodeKind.Add: self.convert_add, NodeKind.Add: self.convert_add,
......
...@@ -306,6 +306,8 @@ class TensorflowConverter(base_converter.ConverterInterface): ...@@ -306,6 +306,8 @@ class TensorflowConverter(base_converter.ConverterInterface):
self._mace_net_def = mace_pb2.NetDef() self._mace_net_def = mace_pb2.NetDef()
ConverterUtil.set_filter_format(self._mace_net_def, DataFormat.HWIO) ConverterUtil.set_filter_format(self._mace_net_def, DataFormat.HWIO)
ConverterUtil.add_data_format_arg(self._mace_net_def, DataFormat.NHWC) ConverterUtil.add_data_format_arg(self._mace_net_def, DataFormat.NHWC)
ConverterUtil.set_framework_type(
self._mace_net_def, FrameworkType.TENSORFLOW.value)
# import tensorflow graph # import tensorflow graph
tf_graph_def = tf.GraphDef() tf_graph_def = tf.GraphDef()
......
...@@ -1322,8 +1322,7 @@ class Transformer(base_converter.ConverterInterface): ...@@ -1322,8 +1322,7 @@ class Transformer(base_converter.ConverterInterface):
# transform `input(4D) -> reshape(2D) -> matmul` to `fc(2D)` # transform `input(4D) -> reshape(2D) -> matmul` to `fc(2D)`
# fc output is 2D in transformer, using as 4D in op kernel # fc output is 2D in transformer, using as 4D in op kernel
# work for TensorFlow/PyTorch/ONNX # work for TensorFlow/PyTorch/ONNX
framework = ConverterUtil.get_arg( framework = ConverterUtil.framework_type(net)
op, MaceKeyword.mace_framework_type_str).i
is_torch = framework == FrameworkType.PYTORCH.value is_torch = framework == FrameworkType.PYTORCH.value
is_tf = framework == FrameworkType.TENSORFLOW.value is_tf = framework == FrameworkType.TENSORFLOW.value
is_onnx = framework == FrameworkType.ONNX.value is_onnx = framework == FrameworkType.ONNX.value
...@@ -1333,7 +1332,8 @@ class Transformer(base_converter.ConverterInterface): ...@@ -1333,7 +1332,8 @@ class Transformer(base_converter.ConverterInterface):
op.input[1] in self._consts and \ op.input[1] in self._consts and \
len(op.output_shape[0].dims) == 2 and \ len(op.output_shape[0].dims) == 2 and \
(is_tf or is_torch or is_onnx) and \ (is_tf or is_torch or is_onnx) and \
op.input[0] in self._producer: op.input[0] in self._producer and \
op.output[0] in self._consumers:
input_op = self._producer[op.input[0]] input_op = self._producer[op.input[0]]
input_shape = input_op.output_shape[0].dims input_shape = input_op.output_shape[0].dims
# check input op # check input op
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册