提交 701e86ca 编写于 作者: 卢旭辉

Merge branch 'fold' into 'master'

fix: Fix folding MatMul BiasAdd

See merge request applied-machine-learning/sysml/mace!1306
...@@ -169,7 +169,7 @@ void RegisterReshape(OpRegistry *op_registry) { ...@@ -169,7 +169,7 @@ void RegisterReshape(OpRegistry *op_registry) {
int has_data_format = int has_data_format =
ProtoArgHelper::GetOptionalArg<OperatorDef, int>( ProtoArgHelper::GetOptionalArg<OperatorDef, int>(
*op, "has_data_format", 0); *op, "has_data_format", 0);
if (has_data_format) { if (has_data_format && op->input_size() == 1) {
return {DeviceType::CPU, DeviceType::GPU}; return {DeviceType::CPU, DeviceType::GPU};
} }
...@@ -183,7 +183,8 @@ void RegisterReshape(OpRegistry *op_registry) { ...@@ -183,7 +183,8 @@ void RegisterReshape(OpRegistry *op_registry) {
op->output_shape(0).dims_size(); op->output_shape(0).dims_size();
if (op_data_format == DataFormat::NHWC && if (op_data_format == DataFormat::NHWC &&
4 == tensor_shape_info->at(input_0).size() && 4 == tensor_shape_info->at(input_0).size() &&
(out_dims_size == 4 || out_dims_size == 2)) { (out_dims_size == 4 || out_dims_size == 2) &&
op->input_size() == 1) {
return {DeviceType::CPU, DeviceType::GPU}; return {DeviceType::CPU, DeviceType::GPU};
} }
......
...@@ -604,8 +604,8 @@ class OnnxConverter(base_converter.ConverterInterface): ...@@ -604,8 +604,8 @@ class OnnxConverter(base_converter.ConverterInterface):
for output in node.outputs: for output in node.outputs:
op.output.append(output) op.output.append(output)
if with_shape: if with_shape:
output_shape = op.output_shape.add()
if output in self._graph_shapes_dict: if output in self._graph_shapes_dict:
output_shape = op.output_shape.add()
shape_info = self._graph_shapes_dict[output] shape_info = self._graph_shapes_dict[output]
output_shape.dims.extend(shape_info) output_shape.dims.extend(shape_info)
......
...@@ -867,7 +867,9 @@ class Transformer(base_converter.ConverterInterface): ...@@ -867,7 +867,9 @@ class Transformer(base_converter.ConverterInterface):
if (((op.type == MaceOp.Conv2D.name if (((op.type == MaceOp.Conv2D.name
or op.type == MaceOp.DepthwiseConv2d.name or op.type == MaceOp.DepthwiseConv2d.name
or op.type == MaceOp.FullyConnected.name or op.type == MaceOp.FullyConnected.name
or op.type == MaceOp.MatMul.name) or (op.type == MaceOp.MatMul.name
and self._option.device == DeviceType.CPU.value
and not self._option.quantize))
and len(op.input) == 2) and len(op.input) == 2)
or (op.type == MaceOp.Deconv2D.name or (op.type == MaceOp.Deconv2D.name
and ((ConverterUtil.get_arg( and ((ConverterUtil.get_arg(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册