diff --git a/mindspore/lite/test/ut/tools/converter/parser/tflite/test_data/conv.tflite b/mindspore/lite/test/ut/tools/converter/parser/tflite/test_data/conv.tflite index 2cb33152c33ca7603cd015dfc0d16f71c1fb5e18..78614e5be67f1aeae9899266a56219f4fe39e70f 100644 Binary files a/mindspore/lite/test/ut/tools/converter/parser/tflite/test_data/conv.tflite and b/mindspore/lite/test/ut/tools/converter/parser/tflite/test_data/conv.tflite differ diff --git a/mindspore/lite/test/ut/tools/converter/parser/tflite/test_data/deconv.tflite b/mindspore/lite/test/ut/tools/converter/parser/tflite/test_data/deconv.tflite index 0fc30506d0ded7f8d117fdc91cecd5c3a649a912..4817b9ea41669f18d524e57be96bb876ce261208 100644 Binary files a/mindspore/lite/test/ut/tools/converter/parser/tflite/test_data/deconv.tflite and b/mindspore/lite/test/ut/tools/converter/parser/tflite/test_data/deconv.tflite differ diff --git a/mindspore/lite/test/ut/tools/converter/parser/tflite/test_data/depthwise_conv1.tflite b/mindspore/lite/test/ut/tools/converter/parser/tflite/test_data/depthwise_conv1.tflite index d820f874f942c1e29219657293cb3911f5041b92..d658f804ccb9e9ea102b86e27bb97197db88318a 100644 Binary files a/mindspore/lite/test/ut/tools/converter/parser/tflite/test_data/depthwise_conv1.tflite and b/mindspore/lite/test/ut/tools/converter/parser/tflite/test_data/depthwise_conv1.tflite differ diff --git a/mindspore/lite/test/ut/tools/converter/parser/tflite/test_data/depthwise_conv2.tflite b/mindspore/lite/test/ut/tools/converter/parser/tflite/test_data/depthwise_conv2.tflite index 9baa40effdd1924ae945bfcee0a4706f631adf3b..56190f5ee8e4d682abe011ef41fca7cb9ce999ce 100644 Binary files a/mindspore/lite/test/ut/tools/converter/parser/tflite/test_data/depthwise_conv2.tflite and b/mindspore/lite/test/ut/tools/converter/parser/tflite/test_data/depthwise_conv2.tflite differ diff --git a/mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_conv_parser_test.cc b/mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_conv_parser_test.cc index dc2a763d499645c6f3c8b38edbf55b0f78d6e87f..202079abc6483831cd68a3661654c5fefe992b01 100644 --- a/mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_conv_parser_test.cc +++ b/mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_conv_parser_test.cc @@ -28,13 +28,12 @@ TEST_F(TestTfliteParserConv, OpType) { ASSERT_NE(meta_graph, nullptr); ASSERT_GT(meta_graph->nodes.size(), 0); ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr); - ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_Transpose) << "wrong Op Type"; - ASSERT_EQ(meta_graph->nodes.at(1)->primitive->value.type, schema::PrimitiveType_Conv2D) << "wrong Op Type"; + ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_Conv2D) << "wrong Op Type"; } TEST_F(TestTfliteParserConv, AttrValue) { - ASSERT_NE(meta_graph->nodes.at(1)->primitive->value.AsConv2D(), nullptr); - auto val = meta_graph->nodes.at(1)->primitive->value.AsConv2D(); + ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsConv2D(), nullptr); + auto val = meta_graph->nodes.front()->primitive->value.AsConv2D(); ASSERT_EQ(val->format, schema::Format_NHWC); ASSERT_EQ(val->group, 1); ASSERT_EQ(val->activationType, schema::ActivationType_NO_ACTIVATION); diff --git a/mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_deconv_parser_test.cc b/mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_deconv_parser_test.cc index 9ff4704c7c2d43c1323cdd103eba8d6a0dee2fa9..8badaa1d334ab4f5f135987ee471ff0b5d01f9c2 100644 --- a/mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_deconv_parser_test.cc +++ b/mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_deconv_parser_test.cc @@ -28,13 +28,12 @@ TEST_F(TestTfliteParserDeConv, OpType) { ASSERT_NE(meta_graph, nullptr); ASSERT_GT(meta_graph->nodes.size(), 0); ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr); - ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_Transpose) << "wrong Op Type"; - ASSERT_EQ(meta_graph->nodes.at(1)->primitive->value.type, schema::PrimitiveType_DeConv2D) << "wrong Op Type"; + ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_DeConv2D) << "wrong Op Type"; } TEST_F(TestTfliteParserDeConv, AttrValue) { - ASSERT_NE(meta_graph->nodes.at(1)->primitive->value.AsDeConv2D(), nullptr); - auto val = meta_graph->nodes.at(1)->primitive->value.AsDeConv2D(); + ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsDeConv2D(), nullptr); + auto val = meta_graph->nodes.front()->primitive->value.AsDeConv2D(); ASSERT_EQ(val->format, schema::Format_NHWC); ASSERT_EQ(val->group, 1); ASSERT_EQ(val->activationType, schema::ActivationType_NO_ACTIVATION); diff --git a/mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_depthwise_conv_parser_test.cc b/mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_depthwise_conv_parser_test.cc index 6ba9c4d1e68af75d2b0d93beb8d07723ea2e4a13..243bbcc867a9d0d17d41299a8aab0ca0b144ee7f 100644 --- a/mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_depthwise_conv_parser_test.cc +++ b/mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_depthwise_conv_parser_test.cc @@ -28,13 +28,12 @@ TEST_F(TestTfliteParserDepthwiseConv1, OpType) { ASSERT_NE(meta_graph, nullptr); ASSERT_GT(meta_graph->nodes.size(), 0); ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr); - ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_Reshape) << "wrong Op Type"; - ASSERT_EQ(meta_graph->nodes.at(1)->primitive->value.type, schema::PrimitiveType_Conv2D) << "wrong Op Type"; + ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_Conv2D) << "wrong Op Type"; } TEST_F(TestTfliteParserDepthwiseConv1, AttrValue) { - ASSERT_NE(meta_graph->nodes.at(1)->primitive->value.AsConv2D(), nullptr); - auto val = meta_graph->nodes.at(1)->primitive->value.AsConv2D(); + ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsConv2D(), nullptr); + auto val = meta_graph->nodes.front()->primitive->value.AsConv2D(); ASSERT_EQ(val->format, schema::Format_NHWC); ASSERT_EQ(val->group, 0); ASSERT_EQ(val->activationType, schema::ActivationType_NO_ACTIVATION); @@ -64,13 +63,12 @@ TEST_F(TestTfliteParserDepthwiseConv2, OpType) { ASSERT_NE(meta_graph, nullptr); ASSERT_GT(meta_graph->nodes.size(), 0); ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr); - ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_Reshape) << "wrong Op Type"; - ASSERT_EQ(meta_graph->nodes.at(1)->primitive->value.type, schema::PrimitiveType_DepthwiseConv2D) << "wrong Op Type"; + ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_DepthwiseConv2D) << "wrong Op Type"; } TEST_F(TestTfliteParserDepthwiseConv2, AttrValue) { - ASSERT_NE(meta_graph->nodes.at(1)->primitive->value.AsDepthwiseConv2D(), nullptr); - auto val = meta_graph->nodes.at(1)->primitive->value.AsDepthwiseConv2D(); + ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsDepthwiseConv2D(), nullptr); + auto val = meta_graph->nodes.front()->primitive->value.AsDepthwiseConv2D(); ASSERT_EQ(val->format, schema::Format_NHWC); ASSERT_EQ(val->activationType, schema::ActivationType_NO_ACTIVATION); ASSERT_EQ(val->hasBias, true); diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_model_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_model_parser.cc index b9990d0ace7b45dea1182d2367faf54f22df2de1..6c25cc895b9cdcb2290b2f87b7b71b0d6ac06c1f 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_model_parser.cc +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_model_parser.cc @@ -222,12 +222,29 @@ STATUS TfliteModelParser::UpdateOp(schema::MetaGraphT* sub_graph) { if (op->primitive->value.type == schema::PrimitiveType_DepthwiseConv2D) { auto attr = op->primitive->value.AsDepthwiseConv2D(); if (attr->channelMultiplier > 1) { - // update attr std::unique_ptr conv_attr(new schema::Conv2DT); + // get channel attr + if (op->inputIndex.empty()) { + MS_LOG(ERROR) << "the input of DepthwiseConv2D is null"; + return RET_NULL_PTR; + } + auto data_id = op->inputIndex[0]; + if (sub_graph->allTensors.size() <= data_id) { + MS_LOG(ERROR) << "the number of allTensors is less than " << data_id; + return RET_ERROR; + } + auto &data_tensor = sub_graph->allTensors.at(data_id); + if (data_tensor == nullptr) { + MS_LOG(ERROR) << "the data tensor is null"; + return RET_NULL_PTR; + } + auto data_shape = data_tensor->dims; + conv_attr->channelIn = data_shape[3]; + conv_attr->channelOut = conv_attr->channelIn * attr->channelMultiplier; + + // update attr conv_attr->group = 0; conv_attr->format = attr->format; - conv_attr->channelIn = attr->channelIn; - conv_attr->channelOut = attr->channelIn * attr->channelMultiplier; conv_attr->kernelH = attr->kernelH; conv_attr->kernelW = attr->kernelW; conv_attr->strideH = attr->strideH;