From 49659ddb06fbd2298cb59e5926179098fc58ad32 Mon Sep 17 00:00:00 2001 From: Bin Li Date: Tue, 20 Oct 2020 12:21:05 +0800 Subject: [PATCH] fix: Fix folding MatMul BiasAdd, Reshape --- mace/ops/reshape.cc | 5 +++-- tools/python/transform/onnx_converter.py | 2 +- tools/python/transform/transformer.py | 4 +++- 3 files changed, 7 insertions(+), 4 deletions(-) diff --git a/mace/ops/reshape.cc b/mace/ops/reshape.cc index 6183179e..9321cd5e 100644 --- a/mace/ops/reshape.cc +++ b/mace/ops/reshape.cc @@ -169,7 +169,7 @@ void RegisterReshape(OpRegistry *op_registry) { int has_data_format = ProtoArgHelper::GetOptionalArg( *op, "has_data_format", 0); - if (has_data_format) { + if (has_data_format && op->input_size() == 1) { return {DeviceType::CPU, DeviceType::GPU}; } @@ -183,7 +183,8 @@ void RegisterReshape(OpRegistry *op_registry) { op->output_shape(0).dims_size(); if (op_data_format == DataFormat::NHWC && 4 == tensor_shape_info->at(input_0).size() && - (out_dims_size == 4 || out_dims_size == 2)) { + (out_dims_size == 4 || out_dims_size == 2) && + op->input_size() == 1) { return {DeviceType::CPU, DeviceType::GPU}; } diff --git a/tools/python/transform/onnx_converter.py b/tools/python/transform/onnx_converter.py index ca384e59..57e50dca 100644 --- a/tools/python/transform/onnx_converter.py +++ b/tools/python/transform/onnx_converter.py @@ -604,8 +604,8 @@ class OnnxConverter(base_converter.ConverterInterface): for output in node.outputs: op.output.append(output) if with_shape: + output_shape = op.output_shape.add() if output in self._graph_shapes_dict: - output_shape = op.output_shape.add() shape_info = self._graph_shapes_dict[output] output_shape.dims.extend(shape_info) diff --git a/tools/python/transform/transformer.py b/tools/python/transform/transformer.py index 459299bc..a15a93bd 100644 --- a/tools/python/transform/transformer.py +++ b/tools/python/transform/transformer.py @@ -867,7 +867,9 @@ class Transformer(base_converter.ConverterInterface): if (((op.type == MaceOp.Conv2D.name or op.type == MaceOp.DepthwiseConv2d.name or op.type == MaceOp.FullyConnected.name - or op.type == MaceOp.MatMul.name) + or (op.type == MaceOp.MatMul.name + and self._option.device == DeviceType.CPU.value + and not self._option.quantize)) and len(op.input) == 2) or (op.type == MaceOp.Deconv2D.name and ((ConverterUtil.get_arg( -- GitLab