diff --git a/mace/ops/reshape.cc b/mace/ops/reshape.cc index 6183179e04694b600ae079c313d5b5161c4f7108..9321cd5e5570707c091f831a552a91fb07f11679 100644 --- a/mace/ops/reshape.cc +++ b/mace/ops/reshape.cc @@ -169,7 +169,7 @@ void RegisterReshape(OpRegistry *op_registry) { int has_data_format = ProtoArgHelper::GetOptionalArg( *op, "has_data_format", 0); - if (has_data_format) { + if (has_data_format && op->input_size() == 1) { return {DeviceType::CPU, DeviceType::GPU}; } @@ -183,7 +183,8 @@ void RegisterReshape(OpRegistry *op_registry) { op->output_shape(0).dims_size(); if (op_data_format == DataFormat::NHWC && 4 == tensor_shape_info->at(input_0).size() && - (out_dims_size == 4 || out_dims_size == 2)) { + (out_dims_size == 4 || out_dims_size == 2) && + op->input_size() == 1) { return {DeviceType::CPU, DeviceType::GPU}; } diff --git a/tools/python/transform/onnx_converter.py b/tools/python/transform/onnx_converter.py index ca384e59a18053968bd46344f92b4c3a815d110f..57e50dca16394de4411ffa67da7eb3a29c79e7fa 100644 --- a/tools/python/transform/onnx_converter.py +++ b/tools/python/transform/onnx_converter.py @@ -604,8 +604,8 @@ class OnnxConverter(base_converter.ConverterInterface): for output in node.outputs: op.output.append(output) if with_shape: + output_shape = op.output_shape.add() if output in self._graph_shapes_dict: - output_shape = op.output_shape.add() shape_info = self._graph_shapes_dict[output] output_shape.dims.extend(shape_info) diff --git a/tools/python/transform/transformer.py b/tools/python/transform/transformer.py index 459299bc3d623a57527fefbd7e7df867e575a9c3..a15a93bd4e2158053dba094591328ed615f2d0a4 100644 --- a/tools/python/transform/transformer.py +++ b/tools/python/transform/transformer.py @@ -867,7 +867,9 @@ class Transformer(base_converter.ConverterInterface): if (((op.type == MaceOp.Conv2D.name or op.type == MaceOp.DepthwiseConv2d.name or op.type == MaceOp.FullyConnected.name - or op.type == MaceOp.MatMul.name) + or (op.type == MaceOp.MatMul.name + and self._option.device == DeviceType.CPU.value + and not self._option.quantize)) and len(op.input) == 2) or (op.type == MaceOp.Deconv2D.name and ((ConverterUtil.get_arg(