diff --git a/mindspore/lite/src/ops/addn.cc b/mindspore/lite/src/ops/addn.cc index c4ad7cdbfb4c2861ce846af9770daa5670767f8a..91ad51afc70777749088ea8fe48bcfa43ebc94d8 100644 --- a/mindspore/lite/src/ops/addn.cc +++ b/mindspore/lite/src/ops/addn.cc @@ -43,6 +43,11 @@ int AddN::InferShape(std::vector inputs, std::vectorSetFormat(input->GetFormat()); + output->set_data_type(input->data_type()); + if (!GetInferFlag()) { + return RET_OK; + } for (int i = 1; i < inputs.size(); ++i) { if (inputs.at(i)->shape() != inputs.at(0)->shape()) { MS_LOG(ERROR) << "AddN inputs shape is not equal!"; @@ -53,9 +58,8 @@ int AddN::InferShape(std::vector inputs, std::vectorSetFormat(input->GetFormat()); + output->set_shape(input->shape()); - output->set_data_type(input->data_type()); return RET_OK; } } // namespace lite diff --git a/mindspore/lite/src/ops/argmax.cc b/mindspore/lite/src/ops/argmax.cc index 208511c8fc2f6c8723f3b4a833094ea701857fcd..be50f47acbcc22af865abed94caec1fab364d67f 100644 --- a/mindspore/lite/src/ops/argmax.cc +++ b/mindspore/lite/src/ops/argmax.cc @@ -55,6 +55,12 @@ int ArgMax::InferShape(std::vector inputs_, std::vectorSetFormat(input->GetFormat()); + output->set_data_type(input->data_type()); + if (!GetInferFlag()) { + return RET_OK; + } auto argmax_prim = this->primitive->value_as_ArgMax(); std::vector output_shape(input->shape()); auto input_shape_size = input->shape().size(); @@ -68,9 +74,8 @@ int ArgMax::InferShape(std::vector inputs_, std::vectortopK(); } - output->SetFormat(input->GetFormat()); + output->set_shape(output_shape); - output->set_data_type(input->data_type()); return RET_OK; } } // namespace lite diff --git a/mindspore/lite/src/ops/argmin.cc b/mindspore/lite/src/ops/argmin.cc index c6584d159404e63f2635ee7de1bbe00053815d15..0d9940f0414c84e72d4c4e7f6010c51ffaca8794 100644 --- a/mindspore/lite/src/ops/argmin.cc +++ b/mindspore/lite/src/ops/argmin.cc @@ -55,6 +55,11 @@ int ArgMin::InferShape(std::vector inputs_, std::vectorSetFormat(input->GetFormat()); + output->set_data_type(input->data_type()); + if (!GetInferFlag()) { + return RET_OK; + } auto argmin_prim = this->primitive->value_as_ArgMin(); auto input_shape_size = input->shape().size(); int axis = argmin_prim->axis() < 0 ? argmin_prim->axis() + input_shape_size : argmin_prim->axis(); @@ -68,9 +73,8 @@ int ArgMin::InferShape(std::vector inputs_, std::vectortopK(); } - output->SetFormat(input->GetFormat()); + output->set_shape(output_shape); - output->set_data_type(input->data_type()); return RET_OK; } } // namespace lite diff --git a/mindspore/lite/src/ops/broadcast_to.cc b/mindspore/lite/src/ops/broadcast_to.cc index 8ac0ce3b9f269526002ecb37d5d2dfa312296cbf..eb7ef3dc7b5dab9484f125318850db7fb418e36e 100644 --- a/mindspore/lite/src/ops/broadcast_to.cc +++ b/mindspore/lite/src/ops/broadcast_to.cc @@ -46,6 +46,11 @@ int BroadcastTo::InferShape(std::vector inputs, std::vec return 1; } auto input = inputs.at(0); + outputs[0]->SetFormat(input->GetFormat()); + outputs[0]->set_data_type(input->data_type()); + if (!GetInferFlag()) { + return RET_OK; + } std::vector dst_shape(this->primitive->value_as_BroadcastTo()->dst_shape()->begin(), this->primitive->value_as_BroadcastTo()->dst_shape()->end()); auto input_shape = input->shape(); @@ -72,10 +77,8 @@ int BroadcastTo::InferShape(std::vector inputs, std::vec shape[i] = dst_shape[i]; --input_shape_index; } - outputs[0]->SetFormat(input->GetFormat()); outputs[0]->set_shape(shape); - outputs[0]->set_data_type(input->data_type()); - return 0; + return RET_OK; } } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/src/ops/cast.cc b/mindspore/lite/src/ops/cast.cc index 0cc8cb8dc5715b0c6c918b57e6c7d497dd88ecd5..fa7850bccf945eb5b93a7deb27be3b0f401b8a7f 100644 --- a/mindspore/lite/src/ops/cast.cc +++ b/mindspore/lite/src/ops/cast.cc @@ -44,8 +44,14 @@ int Cast::InferShape(std::vector inputs_, std::vectorSetFormat(input->GetFormat()); auto cast_prim = this->primitive->value_as_Cast(); MS_ASSERT(cast_prim != nullptr); + output->set_data_type(static_cast(cast_prim->dstT())); + if (!GetInferFlag()) { + return RET_OK; + } + if (input->data_type() != cast_prim->srcT()) { MS_LOG(ERROR) << "input dataType is error"; return RET_INPUT_TENSOR_ERROR; @@ -54,13 +60,8 @@ int Cast::InferShape(std::vector inputs_, std::vectordata_type(); return RET_INPUT_TENSOR_ERROR; } - if (cast_prim->dstT() != kNumberTypeFloat && cast_prim->dstT() != kNumberTypeFloat32) { - MS_LOG(ERROR) << "Invalid output datatype " << cast_prim->dstT(); - return RET_INPUT_TENSOR_ERROR; - } - output->SetFormat(input->GetFormat()); + output->set_shape(input->shape()); - output->set_data_type(TypeId::kNumberTypeFloat32); return RET_OK; } } // namespace lite diff --git a/mindspore/lite/src/ops/constant_of_shape.cc b/mindspore/lite/src/ops/constant_of_shape.cc index 4fc49d8a120ec32bebb2cba02ae2cc424fb80a94..d214d72aa76309a85c4e79b5487b008fd2949641 100644 --- a/mindspore/lite/src/ops/constant_of_shape.cc +++ b/mindspore/lite/src/ops/constant_of_shape.cc @@ -50,16 +50,19 @@ int ConstantOfShape::InferShape(std::vector inputs_, std::vect return RET_ERROR; } auto in_tensor = inputs_.front(); - auto in_data = reinterpret_cast(in_tensor->Data()); auto out_tensor = outputs_.front(); + out_tensor->set_data_type(kNumberTypeFloat32); + out_tensor->SetFormat(in_tensor->GetFormat()); + if (!GetInferFlag()) { + return RET_OK; + } + auto in_data = reinterpret_cast(in_tensor->Data()); int size = in_tensor->ElementsNum(); std::vector out_shape(size); for (int i = 0; i < size; ++i) { out_shape[i] = in_data[i]; } out_tensor->set_shape(out_shape); - out_tensor->set_data_type(kNumberTypeFloat32); - out_tensor->SetFormat(in_tensor->GetFormat()); return RET_OK; } diff --git a/mindspore/lite/src/ops/crop.cc b/mindspore/lite/src/ops/crop.cc index 8c5ddb7cb7cd96b816c9cdfcae0a0beeda298a25..71c07885564ae7fe2b2d8bec8426a99e711b5a46 100644 --- a/mindspore/lite/src/ops/crop.cc +++ b/mindspore/lite/src/ops/crop.cc @@ -46,9 +46,12 @@ int Crop::InferShape(std::vector inputs, std::vectorset_shape(inputs[1]->shape()); outputs[0]->SetFormat(inputs[0]->GetFormat()); outputs[0]->set_data_type(inputs[0]->data_type()); + if (!GetInferFlag()) { + return RET_OK; + } + outputs[0]->set_shape(inputs[1]->shape()); return RET_OK; } } // namespace lite diff --git a/mindspore/lite/src/ops/deconv2d.cc b/mindspore/lite/src/ops/deconv2d.cc index 972c0a60f73f6e3192d4665325572352d7ac5af3..8aaa7d4135048ea1819aebcae3de57902e7880dc 100644 --- a/mindspore/lite/src/ops/deconv2d.cc +++ b/mindspore/lite/src/ops/deconv2d.cc @@ -103,7 +103,11 @@ int DeConv2D::InferShape(std::vector inputs_, std::vecto MS_ASSERT(weight != nullptr); auto output = outputs_.front(); MS_ASSERT(output != nullptr); - + output->SetFormat(input->GetFormat()); + output->set_data_type(input->data_type()); + if (!GetInferFlag()) { + return RET_OK; + } int32_t input_h = input->Height(); int32_t input_w = input->Width(); @@ -138,8 +142,6 @@ int DeConv2D::InferShape(std::vector inputs_, std::vecto std::vector out_shape = {output_n, output_h, output_w, output_c}; output->set_shape(out_shape); - output->SetFormat(input->GetFormat()); - output->set_data_type(input->data_type()); return 0; } } // namespace lite diff --git a/mindspore/lite/src/ops/dedepthwise_conv2d.cc b/mindspore/lite/src/ops/dedepthwise_conv2d.cc index 81548ea4621f0a5ffde765881949f6c706ac0f56..c07e521fe55cc9b338d9da1d64a4dbc39891d797 100644 --- a/mindspore/lite/src/ops/dedepthwise_conv2d.cc +++ b/mindspore/lite/src/ops/dedepthwise_conv2d.cc @@ -126,7 +126,11 @@ int DeDepthwiseConv2D::InferShape(std::vector inputs_, MS_ASSERT(weight != nullptr); auto output = outputs_.front(); MS_ASSERT(output != nullptr); - + output->SetFormat(input->GetFormat()); + output->set_data_type(input->data_type()); + if (!GetInferFlag()) { + return RET_OK; + } auto in_shape = input->shape(); int input_h = in_shape.at(1); int input_w = in_shape.at(2); @@ -155,8 +159,6 @@ int DeDepthwiseConv2D::InferShape(std::vector inputs_, out_shape.at(3) = weight->shape()[0] * weight->shape()[3]; // in_channel * out_channel output->set_shape(out_shape); - output->SetFormat(input->GetFormat()); - output->set_data_type(input->data_type()); return 0; } } // namespace lite diff --git a/mindspore/lite/src/ops/depth_to_space.cc b/mindspore/lite/src/ops/depth_to_space.cc index 4bfcbf369e559f89793b8558e797fe508e843f71..ab0edb0f208501ab870a450e8faf036cc6021867 100644 --- a/mindspore/lite/src/ops/depth_to_space.cc +++ b/mindspore/lite/src/ops/depth_to_space.cc @@ -50,6 +50,11 @@ int DepthToSpace::InferShape(std::vector inputs, std::ve MS_LOG(ERROR) << "depth_to_space only support NHWC now!"; return 1; } + outputs[0]->set_data_type(input->data_type()); + outputs[0]->SetFormat(input->GetFormat()); + if (!GetInferFlag()) { + return RET_OK; + } auto input_shape = input->shape(); if (input_shape.size() != kDimension_4d) { MS_LOG(ERROR) << "input shape dimension size should == " << kDimension_4d; @@ -68,10 +73,7 @@ int DepthToSpace::InferShape(std::vector inputs, std::ve output_shape[NHWC_W] = input_shape[NHWC_W] * block_size; output_shape[NHWC_C] = input_shape[NHWC_C] / (block_size * block_size); outputs[0]->set_shape(output_shape); - outputs[0]->set_data_type(input->data_type()); - outputs[0]->SetFormat(input->GetFormat()); - - return 0; + return RET_OK; } } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/src/ops/depthwise_conv2d.cc b/mindspore/lite/src/ops/depthwise_conv2d.cc index 0cfb482bbd5a970fae96004c0da66e86f8468ded..de6c6acef3ffa5ff31dcd6106373183e3f7f3d28 100644 --- a/mindspore/lite/src/ops/depthwise_conv2d.cc +++ b/mindspore/lite/src/ops/depthwise_conv2d.cc @@ -120,7 +120,11 @@ int DepthwiseConv2D::InferShape(std::vector inputs_, MS_ASSERT(weight != nullptr); auto output = outputs_.front(); MS_ASSERT(output != nullptr); - + output->SetFormat(input->GetFormat()); + output->set_data_type(input->data_type()); + if (!GetInferFlag()) { + return RET_OK; + } auto in_shape = input->shape(); int input_h = in_shape.at(1); int input_w = in_shape.at(2); @@ -158,8 +162,6 @@ int DepthwiseConv2D::InferShape(std::vector inputs_, out_shape.at(3) = weight->shape()[0] * weight->shape()[3]; // in_channel * out_channel output->set_shape(out_shape); - output->SetFormat(input->GetFormat()); - output->set_data_type(input->data_type()); return 0; } } // namespace lite diff --git a/mindspore/lite/src/ops/embedding_lookup.cc b/mindspore/lite/src/ops/embedding_lookup.cc index 0653f1d8c1d2b9e20732aa33ebc708ce06acd272..c1aa76a4bb0cadb0ef756a4d8c3ad5cfdcc60315 100644 --- a/mindspore/lite/src/ops/embedding_lookup.cc +++ b/mindspore/lite/src/ops/embedding_lookup.cc @@ -46,6 +46,12 @@ int EmbeddingLookup::InferShape(std::vector inputs_, std::vect MS_ASSERT(ids != nullptr); auto output = outputs_.front(); MS_ASSERT(output != nullptr); + output->SetFormat(params_->GetFormat()); + output->set_data_type(params_->data_type()); + if (!GetInferFlag()) { + return RET_OK; + } + auto embedding_shape = params_->shape(); embedding_shape.erase(embedding_shape.begin()); std::vector output_shape(ids->shape()); @@ -61,7 +67,6 @@ int EmbeddingLookup::InferShape(std::vector inputs_, std::vect } } output->set_shape(output_shape); - output->set_data_type(params_->data_type()); return RET_OK; } } // namespace lite diff --git a/mindspore/lite/src/ops/expand_dims.cc b/mindspore/lite/src/ops/expand_dims.cc index f959d3c501d31512ccc7afbdecc16eff284cad45..0cdff13698b91a4c91833b17010a899740fde71f 100644 --- a/mindspore/lite/src/ops/expand_dims.cc +++ b/mindspore/lite/src/ops/expand_dims.cc @@ -42,6 +42,11 @@ int ExpandDims::InferShape(std::vector inputs_, std::vectorset_data_type(input->data_type()); + output->SetFormat(input->GetFormat()); + if (!GetInferFlag()) { + return RET_OK; + } auto expand_dims_prim = this->primitive->value_as_ExpandDims(); int dim = expand_dims_prim->dim(); if (dim < 0) { @@ -54,8 +59,6 @@ int ExpandDims::InferShape(std::vector inputs_, std::vectorshape(); out_shape.insert(out_shape.begin() + dim, 1, 1); output->set_shape(out_shape); - output->set_data_type(input->data_type()); - output->SetFormat(input->GetFormat()); return RET_OK; } } // namespace lite diff --git a/mindspore/lite/src/ops/fill.cc b/mindspore/lite/src/ops/fill.cc index 9e0ed36f538d0c4a2bf735a4ffe80a0160c35df8..df6b1268811b69edc487ff7b8dfb65e8e457e1e9 100644 --- a/mindspore/lite/src/ops/fill.cc +++ b/mindspore/lite/src/ops/fill.cc @@ -45,6 +45,11 @@ int Fill::InferShape(std::vector inputs_, std::vectorset_data_type(input->data_type()); + output->SetFormat(input->GetFormat()); + if (!GetInferFlag()) { + return RET_OK; + } auto fill_prim = this->primitive->value_as_Fill(); if (fill_prim == nullptr) { MS_LOG(ERROR) << "Fill primitive is null!"; @@ -53,8 +58,6 @@ int Fill::InferShape(std::vector inputs_, std::vector output_shape; (void)output_shape.insert(output_shape.begin(), fill_prim->dims()->begin(), fill_prim->dims()->end()); output->set_shape(output_shape); - output->set_data_type(input->data_type()); - output->SetFormat(input->GetFormat()); return RET_OK; } } // namespace lite diff --git a/mindspore/lite/src/ops/flatten.cc b/mindspore/lite/src/ops/flatten.cc index fbc008f3119e4eff0811cc853e656f6cb8face66..f5a96068388b471bf3fc73d8c15e6f173eca1c30 100644 --- a/mindspore/lite/src/ops/flatten.cc +++ b/mindspore/lite/src/ops/flatten.cc @@ -31,6 +31,13 @@ int Flatten::InferShape(std::vector inputs_, std::vectorset_data_type(input->data_type()); + output->SetFormat(input->GetFormat()); + if (!GetInferFlag()) { + return RET_OK; + } + auto input_shape = input->shape(); std::vector output_shape(2); output_shape[0] = input_shape[0]; @@ -39,8 +46,6 @@ int Flatten::InferShape(std::vector inputs_, std::vectorset_shape(output_shape); - output->set_data_type(input->data_type()); - output->SetFormat(input->GetFormat()); return RET_OK; } } // namespace lite diff --git a/mindspore/lite/src/ops/full_connection.cc b/mindspore/lite/src/ops/full_connection.cc index 69b5c4417e064e064389157ff198ec8246de093e..8f7bb738285d2088468de7f4000c394b74677a98 100644 --- a/mindspore/lite/src/ops/full_connection.cc +++ b/mindspore/lite/src/ops/full_connection.cc @@ -51,7 +51,11 @@ int FullConnection::InferShape(std::vector inputs_, MS_ASSERT(input1 != nullptr); auto output = outputs_.front(); MS_ASSERT(output != nullptr); - + output->set_data_type(input0->data_type()); + output->SetFormat(input0->GetFormat()); + if (!GetInferFlag()) { + return RET_OK; + } if ((GetHasBias() && inputs_.size() != kMultiNum) || (!GetHasBias() && inputs_.size() != kDoubleNum)) { MS_LOG(ERROR) << "Input tensors num error"; return 1; @@ -78,8 +82,6 @@ int FullConnection::InferShape(std::vector inputs_, out_shape.resize(GetAxis() + 1); out_shape[GetAxis()] = input1->shape()[0]; output->set_shape(out_shape); - output->set_data_type(input0->data_type()); - output->SetFormat(input0->GetFormat()); return 0; } diff --git a/mindspore/lite/src/ops/gather_nd.cc b/mindspore/lite/src/ops/gather_nd.cc index 36eeaacd1360c76e3cb0d9e77e981e8b4c346577..a6696bfefe8b6320593630f44f5bb83ecf001d50 100644 --- a/mindspore/lite/src/ops/gather_nd.cc +++ b/mindspore/lite/src/ops/gather_nd.cc @@ -46,6 +46,12 @@ int GatherNd::InferShape(std::vector inputs_, std::vectorset_data_type(input->data_type()); + output->SetFormat(input->GetFormat()); + if (!GetInferFlag()) { + return RET_OK; + } auto in_shape = input->shape(); int in_rank = in_shape.size(); auto indices_shape = indices->shape(); @@ -63,8 +69,6 @@ int GatherNd::InferShape(std::vector inputs_, std::vectorset_shape(out_shape); - output->set_data_type(input->data_type()); - output->SetFormat(input->GetFormat()); return RET_OK; } } // namespace lite diff --git a/mindspore/lite/src/ops/lstm.cc b/mindspore/lite/src/ops/lstm.cc index 6983619636ad7c9f6bc8178cfd8a42f4d19da5b5..7c7ed2391d6c7a9d588bc39c6d5a10f28ddd843d 100644 --- a/mindspore/lite/src/ops/lstm.cc +++ b/mindspore/lite/src/ops/lstm.cc @@ -44,6 +44,14 @@ int Lstm::InferShape(std::vector inputs_, std::vectorset_data_type(input->data_type()); + outputs_[i]->SetFormat(input->GetFormat()); + } + if (!GetInferFlag()) { + return RET_OK; + } + std::vector in_shape = input->shape(); std::vector w_shape = weight_i->shape(); // layer, hidden_size * 4, input_size if (in_shape.size() != 3 || w_shape.size() != 3) { @@ -65,10 +73,7 @@ int Lstm::InferShape(std::vector inputs_, std::vectorset_shape(state_shape); outputs_[2]->set_shape(state_shape); - for (int i = 0; i < kLstmOutputNum; i++) { - outputs_[i]->set_data_type(input->data_type()); - outputs_[i]->SetFormat(input->GetFormat()); - } + return RET_OK; } } // namespace lite diff --git a/mindspore/lite/src/ops/matmul.cc b/mindspore/lite/src/ops/matmul.cc index 7a6d7e145223f8018231c00ba94b6ff16ea53384..ba19adb38ccd6133f58b2c7787a0234bc02a37e1 100644 --- a/mindspore/lite/src/ops/matmul.cc +++ b/mindspore/lite/src/ops/matmul.cc @@ -43,6 +43,13 @@ int MatMul::InferShape(std::vector inputs_, std::vectorset_data_type(input0->data_type()); + output->SetFormat(input0->GetFormat()); + if (!GetInferFlag()) { + return RET_OK; + } + std::vector a_shape = input0->shape(); std::vector b_shape = input1->shape(); if (a_shape.size() < 2 || b_shape.size() < 2) { @@ -65,8 +72,6 @@ int MatMul::InferShape(std::vector inputs_, std::vector c_shape(a_shape); c_shape[c_shape.size() - 1] = b_shape[b_shape.size() - 1]; output->set_shape(c_shape); - output->set_data_type(input0->data_type()); - output->SetFormat(input0->GetFormat()); return RET_OK; } } // namespace lite diff --git a/mindspore/lite/src/ops/mean.cc b/mindspore/lite/src/ops/mean.cc index f7f2b487e269a4cd1c582bcaaa0f8ed44b41decc..a2901b8890e4b2557f847322dcc41de985f9854f 100644 --- a/mindspore/lite/src/ops/mean.cc +++ b/mindspore/lite/src/ops/mean.cc @@ -50,6 +50,11 @@ int Mean::InferShape(std::vector inputs_, std::vectorset_data_type(input->data_type()); + output->SetFormat(input->GetFormat()); + if (!GetInferFlag()) { + return RET_OK; + } if (this->primitive == nullptr) { return RET_NULL_PTR; } @@ -88,8 +93,6 @@ int Mean::InferShape(std::vector inputs_, std::vectorset_shape(out_shape); - output->set_data_type(input->data_type()); - output->SetFormat(input->GetFormat()); return RET_OK; } } // namespace lite diff --git a/mindspore/lite/src/ops/nchw2nhwc.cc b/mindspore/lite/src/ops/nchw2nhwc.cc index c4bd16c581cbf882177c3508d6d581e5001f12c8..a18558d23f329915a773c89f65e7234986acc5b8 100644 --- a/mindspore/lite/src/ops/nchw2nhwc.cc +++ b/mindspore/lite/src/ops/nchw2nhwc.cc @@ -25,6 +25,11 @@ int Nchw2Nhwc::InferShape(std::vector inputs_, std::vect MS_ASSERT(input != nullptr); auto output = outputs_.front(); MS_ASSERT(output != nullptr); + output->SetFormat(schema::Format_NHWC); + output->set_data_type(input->data_type()); + if (!GetInferFlag()) { + return RET_OK; + } std::vector nchw_shape = input->shape(); if (nchw_shape.size() != 4) { output->set_shape(nchw_shape); @@ -36,8 +41,6 @@ int Nchw2Nhwc::InferShape(std::vector inputs_, std::vect nhwc_shape[NHWC_C] = nchw_shape[NCHW_C]; output->set_shape(nhwc_shape); } - output->SetFormat(schema::Format_NHWC); - output->set_data_type(input->data_type()); return RET_OK; } } // namespace lite diff --git a/mindspore/lite/src/ops/nhwc2nchw.cc b/mindspore/lite/src/ops/nhwc2nchw.cc index a13e858853d79a284fa02a8a098da00c63b725a1..39b9b3854fd96b88292b49b3713e33ff2f4fa086 100644 --- a/mindspore/lite/src/ops/nhwc2nchw.cc +++ b/mindspore/lite/src/ops/nhwc2nchw.cc @@ -25,6 +25,11 @@ int Nhwc2Nchw::InferShape(std::vector inputs_, std::vect MS_ASSERT(input != nullptr); auto output = outputs_.front(); MS_ASSERT(output != nullptr); + output->SetFormat(schema::Format_NCHW); + output->set_data_type(input->data_type()); + if (!GetInferFlag()) { + return RET_OK; + } std::vector nhwc_shape = input->shape(); if (nhwc_shape.size() != 4) { output->set_shape(nhwc_shape); @@ -36,8 +41,6 @@ int Nhwc2Nchw::InferShape(std::vector inputs_, std::vect nchw_shape[NCHW_W] = nhwc_shape[NHWC_W]; output->set_shape(nchw_shape); } - output->SetFormat(schema::Format_NCHW); - output->set_data_type(input->data_type()); return RET_OK; } } // namespace lite diff --git a/mindspore/lite/src/ops/one_hot.cc b/mindspore/lite/src/ops/one_hot.cc index 7406216debd1a45e6e38d4d423d5f3e3ae103f80..848f763583ee9f27ddddb3c5087e478fa307b8de 100644 --- a/mindspore/lite/src/ops/one_hot.cc +++ b/mindspore/lite/src/ops/one_hot.cc @@ -56,6 +56,19 @@ int OneHot::InferShape(std::vector inputs, std::vectorset_data_type(on_value->data_type()); + output->SetFormat(on_value->GetFormat()); + if (!GetInferFlag()) { + return RET_OK; + } const auto input_shape = input->shape(); int input_rank = static_cast(input_shape.size()); if (axis < 0) { @@ -63,17 +76,7 @@ int OneHot::InferShape(std::vector inputs, std::vector output_shape(input_shape); output_shape.insert(output_shape.cbegin() + axis, *depth); - auto output = outputs.front(); - if (output == nullptr) { - return RET_NULL_PTR; - } output->set_shape(output_shape); - auto on_value = inputs.at(2); - if (on_value == nullptr) { - return RET_NULL_PTR; - } - output->set_data_type(on_value->data_type()); - output->SetFormat(on_value->GetFormat()); return RET_OK; } } // namespace lite diff --git a/mindspore/lite/src/ops/pad.cc b/mindspore/lite/src/ops/pad.cc index e35a050b6464a339bc82128909179ade9c56a4cc..fef616abf623b2a785c675712d82fd7ea6d17c56 100644 --- a/mindspore/lite/src/ops/pad.cc +++ b/mindspore/lite/src/ops/pad.cc @@ -61,6 +61,15 @@ int Pad::InferShape(std::vector inputs, std::vectorSetFormat(input->GetFormat()); + output->set_data_type(input->data_type()); + if (!GetInferFlag()) { + return RET_OK; + } auto input_shape = input->shape(); std::vector output_shape; MS_ASSERT(input->shape().size() <= kInputRank); @@ -69,13 +78,8 @@ int Pad::InferShape(std::vector inputs, std::vectorSetFormat(input->GetFormat()); + output->set_shape(output_shape); - output->set_data_type(input->data_type()); return RET_OK; } } // namespace lite diff --git a/mindspore/lite/src/ops/pooling.cc b/mindspore/lite/src/ops/pooling.cc index 0f7807b3861e824487c6324124f72ff88ee1393e..c9535ff99df898b6c9825724e04115cb475bac1e 100644 --- a/mindspore/lite/src/ops/pooling.cc +++ b/mindspore/lite/src/ops/pooling.cc @@ -95,6 +95,11 @@ int Pooling::InferShape(std::vector inputs_, std::vectorset_data_type(input->data_type()); + output->SetFormat(schema::Format_NHWC); + if (!GetInferFlag()) { + return RET_OK; + } int input_h = input->shape().at(1); int input_w = input->shape().at(2); auto pooling_prim = this->primitive->value_as_Pooling(); @@ -137,9 +142,6 @@ int Pooling::InferShape(std::vector inputs_, std::vectorset_shape(input_shape); - output->set_data_type(input->data_type()); - // todo: temp fix - output->SetFormat(schema::Format_NHWC); return RET_OK; } } // namespace lite diff --git a/mindspore/lite/src/ops/power.cc b/mindspore/lite/src/ops/power.cc index 7f9aeb6796f5a7ed9e33c4fa2edbc0616fa2e7ab..82398aa6836120b125f7d7b62d3efbba7d4956c0 100644 --- a/mindspore/lite/src/ops/power.cc +++ b/mindspore/lite/src/ops/power.cc @@ -49,15 +49,19 @@ int Power::InferShape(std::vector inputs, std::vectorset_data_type(x_tensor->data_type()); + output_tensor->SetFormat(x_tensor->GetFormat()); + if (!GetInferFlag()) { + return RET_OK; + } if (exp_tensor != nullptr) { if (exp_tensor->shape() != x_tensor->shape() || exp_tensor->data_type() != x_tensor->data_type()) { MS_LOG(ERROR) << "Power inputs shape or type is not equal!"; return RET_INPUT_TENSOR_ERROR; } } - output_tensor->SetFormat(x_tensor->GetFormat()); + output_tensor->set_shape(x_tensor->shape()); - output_tensor->set_data_type(x_tensor->data_type()); return RET_OK; } } // namespace lite diff --git a/mindspore/lite/src/ops/prior_box.cc b/mindspore/lite/src/ops/prior_box.cc index bd7b92189766e4063d4d8b2c948420feee218d73..eedb6ec2b7441b5d4f1568e2bcead6c24a21d1c9 100644 --- a/mindspore/lite/src/ops/prior_box.cc +++ b/mindspore/lite/src/ops/prior_box.cc @@ -99,6 +99,15 @@ constexpr int kPriorBoxC = 2; int PriorBox::InferShape(std::vector inputs_, std::vector outputs_) { auto param = this->primitive->value_as_PriorBox(); MS_ASSERT(param != nullptr); + auto input = inputs_.at(0); + MS_ASSERT(input != nullptr); + auto output = outputs_.at(0); + MS_ASSERT(output != nullptr); + output->set_data_type(kNumberTypeFloat32); + output->SetFormat(input->GetFormat()); + if (!GetInferFlag()) { + return RET_OK; + } std::vector different_aspect_ratios{1.0f}; auto aspect_ratios = param->aspect_ratios(); MS_ASSERT(aspect_ratios != nullptr); @@ -114,15 +123,9 @@ int PriorBox::InferShape(std::vector inputs_, std::vectormin_sizes()->size() * different_aspect_ratios.size() + param->max_sizes()->size(); - auto input = inputs_.at(0); - MS_ASSERT(input != nullptr); int32_t h = input->Height() * input->Width() * num_priors_box * kPriorBoxPoints; std::vector output_shape{kPriorBoxN, h, kPriorBoxW, kPriorBoxC}; - auto output = outputs_.at(0); - MS_ASSERT(output != nullptr); output->set_shape(output_shape); - output->set_data_type(kNumberTypeFloat32); - output->SetFormat(input->GetFormat()); return RET_OK; } } // namespace lite diff --git a/mindspore/lite/src/ops/quant_dtype_cast.cc b/mindspore/lite/src/ops/quant_dtype_cast.cc index 50cd868c6d23c9f365a74735978f251d1d6b7052..ddbb1eadd507823aa5e116bb20c20f9afbc723ad 100644 --- a/mindspore/lite/src/ops/quant_dtype_cast.cc +++ b/mindspore/lite/src/ops/quant_dtype_cast.cc @@ -40,11 +40,14 @@ int QuantDTypeCast::InferShape(std::vector inputs_, std::vecto MS_ASSERT(input != nullptr); auto output = outputs_.front(); MS_ASSERT(output != nullptr); - output->set_shape(input->shape()); auto param = primitive->value_as_QuantDTypeCast(); MS_ASSERT(input->data_type() == param->srcT); output->set_data_type(static_cast(param->dstT())); output->SetFormat(input->GetFormat()); + if (!GetInferFlag()) { + return RET_OK; + } + output->set_shape(input->shape()); return RET_OK; } } // namespace lite diff --git a/mindspore/lite/src/ops/range.cc b/mindspore/lite/src/ops/range.cc index 29173b35f5c021e8dd475b6ad2b2f05e1c1a82dd..08e7f89728af7fd5b2ef3900718d3a2de97859d8 100644 --- a/mindspore/lite/src/ops/range.cc +++ b/mindspore/lite/src/ops/range.cc @@ -50,12 +50,18 @@ int Range::InferShape(std::vector inputs_, std::vectorprimitive->value_as_Range(); MS_ASSERT(range_prim != nullptr); + + output->set_data_type(input->data_type()); + output->SetFormat(input->GetFormat()); + if (!GetInferFlag()) { + return RET_OK; + } + int shape_size = std::ceil(static_cast(range_prim->limit() - range_prim->start()) / range_prim->delta()); std::vector in_shape(1); in_shape.push_back(shape_size); output->set_shape(in_shape); - output->set_data_type(input->data_type()); - output->SetFormat(input->GetFormat()); + return RET_OK; } } // namespace lite diff --git a/mindspore/lite/src/ops/rank.cc b/mindspore/lite/src/ops/rank.cc index 5ee331bc76dd66b34358a24f541c69e8fc08d343..8057f52b9e96f470e6efd10ae354ad368d6f8248 100644 --- a/mindspore/lite/src/ops/rank.cc +++ b/mindspore/lite/src/ops/rank.cc @@ -25,10 +25,13 @@ int Rank::InferShape(std::vector inputs_, std::vector in_shape(1, 1); - output->set_shape(in_shape); output->set_data_type(input->data_type()); output->SetFormat(input->GetFormat()); + if (!GetInferFlag()) { + return RET_OK; + } + std::vector in_shape(1, 1); + output->set_shape(in_shape); return RET_OK; } } // namespace lite diff --git a/mindspore/lite/src/ops/resize.cc b/mindspore/lite/src/ops/resize.cc index be21df3a8cb2bd534c87f2cc6e4379ac6c895ac6..1ef6c3c2deabc6bf97f898167948a7c95fce00d9 100644 --- a/mindspore/lite/src/ops/resize.cc +++ b/mindspore/lite/src/ops/resize.cc @@ -66,6 +66,11 @@ int Resize::InferShape(std::vector inputs_, std::vector< if (output == nullptr) { return 1; } + output->set_data_type(input->data_type()); + output->SetFormat(input->GetFormat()); + if (!GetInferFlag()) { + return RET_OK; + } auto new_height = GetNewHeight(); auto new_width = GetNewWidth(); @@ -75,10 +80,8 @@ int Resize::InferShape(std::vector inputs_, std::vector< output_shape.push_back(new_width); output_shape.push_back(input->Channel()); output->set_shape(output_shape); - output->set_data_type(input->data_type()); - output->SetFormat(input->GetFormat()); - return 0; + return RET_OK; } } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/src/ops/reverse_sequence.cc b/mindspore/lite/src/ops/reverse_sequence.cc index f0512cd522088168b9ae12e8102d504202cb976f..fc7179d09ab834f58ebc952ffb8cde692b09ec22 100644 --- a/mindspore/lite/src/ops/reverse_sequence.cc +++ b/mindspore/lite/src/ops/reverse_sequence.cc @@ -52,9 +52,13 @@ int ReverseSequence::InferShape(std::vector inputs, std::vecto auto output = outputs.front(); MS_ASSERT(input != nullptr); MS_ASSERT(output != nullptr); - output->set_shape(input->shape()); + output->set_data_type(input->data_type()); output->SetFormat(input->GetFormat()); + if (!GetInferFlag()) { + return RET_OK; + } + output->set_shape(input->shape()); return RET_OK; } } // namespace lite diff --git a/mindspore/lite/src/ops/roi_pooling.cc b/mindspore/lite/src/ops/roi_pooling.cc index 114c10c2b25fd17b25e603b8800a20c34a18d7f8..f4853425988464e5f4c8d5388c931d8d2bd72154 100644 --- a/mindspore/lite/src/ops/roi_pooling.cc +++ b/mindspore/lite/src/ops/roi_pooling.cc @@ -56,6 +56,11 @@ int ROIPooling::InferShape(std::vector inputs_, std::vectorset_data_type(input->data_type()); + output->SetFormat(input->GetFormat()); + if (!GetInferFlag()) { + return RET_OK; + } auto ROIPooling = this->primitive->value_as_ROIPooling(); auto new_h = ROIPooling->pooledH(); auto new_w = ROIPooling->pooledW(); @@ -66,8 +71,6 @@ int ROIPooling::InferShape(std::vector inputs_, std::vectorChannel()); output->set_shape(output_shape); - output->set_data_type(input->data_type()); - output->SetFormat(input->GetFormat()); return RET_OK; } } // namespace lite diff --git a/mindspore/lite/src/ops/scatter_nd.cc b/mindspore/lite/src/ops/scatter_nd.cc index 0fa21c636acf274fa37f5ece9da2bdb9a5890b6e..a0331927085a1a5ef933b7267abd9434eaf16e50 100644 --- a/mindspore/lite/src/ops/scatter_nd.cc +++ b/mindspore/lite/src/ops/scatter_nd.cc @@ -51,11 +51,14 @@ int ScatterND::InferShape(std::vector inputs_, std::vectorset_data_type(update->data_type()); + output->SetFormat(update->GetFormat()); + if (!GetInferFlag()) { + return RET_OK; + } auto shape_data = reinterpret_cast(shape->Data()); std::vector out_shape(shape_data, shape_data + shape->DataSize()); output->set_shape(out_shape); - output->set_data_type(update->data_type()); - output->SetFormat(update->GetFormat()); return RET_OK; } } // namespace lite diff --git a/mindspore/lite/src/ops/space_to_batch.cc b/mindspore/lite/src/ops/space_to_batch.cc index 08d411f4c6954bf784374f91acdcf1d10473e1ee..db7bd7dd77d2dd875697eec7319de25f08a7499a 100644 --- a/mindspore/lite/src/ops/space_to_batch.cc +++ b/mindspore/lite/src/ops/space_to_batch.cc @@ -63,6 +63,11 @@ int SpaceToBatch::InferShape(std::vector inputs, std::ve MS_LOG(ERROR) << "space_to_batch only support NHWC now!"; return 1; } + outputs[0]->set_data_type(input->data_type()); + outputs[0]->SetFormat(input->GetFormat()); + if (!GetInferFlag()) { + return RET_OK; + } auto input_shape = input->shape(); if (input_shape.size() != kDimension_4d) { MS_LOG(ERROR) << "input shape dimension size should == " << kDimension_4d; @@ -106,8 +111,7 @@ int SpaceToBatch::InferShape(std::vector inputs, std::ve output_shape[NHWC_W] = input_shape[NHWC_W] / block_sizes_[NHWC_H]; output_shape[NHWC_C] = input_shape[NHWC_C]; outputs[0]->set_shape(output_shape); - outputs[0]->set_data_type(input->data_type()); - return 0; + return RET_OK; } } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/src/ops/space_to_depth.cc b/mindspore/lite/src/ops/space_to_depth.cc index 79f3f175b3d3f2c4e055e24edb6176ab9cc177e8..261a2aa19b7b0c8dcfca740f75fd5a5a6a90093a 100644 --- a/mindspore/lite/src/ops/space_to_depth.cc +++ b/mindspore/lite/src/ops/space_to_depth.cc @@ -51,6 +51,11 @@ int SpaceToDepth::InferShape(std::vector inputs, std::ve MS_LOG(ERROR) << "space_to_depth only support NHWC now!"; return 1; } + outputs[0]->SetFormat(input->GetFormat()); + outputs[0]->set_data_type(input->data_type()); + if (!GetInferFlag()) { + return RET_OK; + } auto input_shape = input->shape(); if (input_shape.size() != kDimension_4d) { MS_LOG(ERROR) << "input shape dimension size should == " << kDimension_4d; @@ -69,8 +74,7 @@ int SpaceToDepth::InferShape(std::vector inputs, std::ve output_shape[NHWC_W] = input_shape[NHWC_W] / block_size; output_shape[NHWC_C] = input_shape[NHWC_C] * (block_size * block_size); outputs[0]->set_shape(output_shape); - outputs[0]->set_data_type(input->data_type()); - return 0; + return RET_OK; } } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/src/ops/split.cc b/mindspore/lite/src/ops/split.cc index 8d3afd8aa70ef3b49dd1c881fa2b850b1f91bb81..4b48debd838e856ee3ac46b797d9478d374c618f 100644 --- a/mindspore/lite/src/ops/split.cc +++ b/mindspore/lite/src/ops/split.cc @@ -66,6 +66,13 @@ int Split::InferShape(std::vector inputs_, std::vectorset_data_type(input->data_type()); + outputs_[i]->SetFormat(input->GetFormat()); + } + if (!GetInferFlag()) { + return RET_OK; + } int split_dim = spilt_prim->splitDim(); std::vector input_shape = input->shape(); std::vector size_split; diff --git a/mindspore/lite/src/ops/squeeze.cc b/mindspore/lite/src/ops/squeeze.cc index efa625aea03ac937f81f8888b01c89d8bb807717..dcc3e4c07e6bd14699714cfa80e04b3e288f57cd 100644 --- a/mindspore/lite/src/ops/squeeze.cc +++ b/mindspore/lite/src/ops/squeeze.cc @@ -48,6 +48,11 @@ int Squeeze::InferShape(std::vector inputs_, std::vectorset_data_type(in_tensor->data_type()); + outputs_.front()->SetFormat(in_tensor->GetFormat()); + if (!GetInferFlag()) { + return RET_OK; + } auto in_shape = in_tensor->shape(); std::vector out_shape; // todo: getAxis @@ -77,8 +82,6 @@ int Squeeze::InferShape(std::vector inputs_, std::vectorset_shape(out_shape); - outputs_.front()->set_data_type(in_tensor->data_type()); - outputs_.front()->SetFormat(in_tensor->GetFormat()); return 0; } } // namespace lite diff --git a/mindspore/lite/src/ops/stack.cc b/mindspore/lite/src/ops/stack.cc index 985aaa79bee0da92b09bda8d13d5d07c498cf976..3ac6c465f14cc50241e1a91defa4cce5a0cc5019 100644 --- a/mindspore/lite/src/ops/stack.cc +++ b/mindspore/lite/src/ops/stack.cc @@ -56,6 +56,11 @@ int Stack::InferShape(std::vector inputs, std::vectorset_data_type(input->data_type()); + outputs[0]->SetFormat(input->GetFormat()); + if (!GetInferFlag()) { + return RET_OK; + } auto input_shape = input->shape(); auto stack_prim = this->primitive->value_as_Stack(); std::vector output_shape = input_shape; @@ -84,8 +89,6 @@ int Stack::InferShape(std::vector inputs, std::vectorset_shape(output_shape); - outputs[0]->set_data_type(input->data_type()); - outputs[0]->SetFormat(input->GetFormat()); return RET_OK; } } // namespace lite diff --git a/mindspore/lite/src/ops/strided_slice.cc b/mindspore/lite/src/ops/strided_slice.cc index 8722d88930e1c84455a0306cd1fe6f4030aa3713..7ec9fd1af38f51b0890f7e14c22f9f68e61a171a 100644 --- a/mindspore/lite/src/ops/strided_slice.cc +++ b/mindspore/lite/src/ops/strided_slice.cc @@ -164,6 +164,11 @@ int StridedSlice::InferShape(std::vector inputs, std::ve return RET_PARAM_INVALID; } auto input = inputs.at(0); + outputs.front()->set_data_type(input->data_type()); + outputs[0]->SetFormat(input->GetFormat()); + if (!GetInferFlag()) { + return RET_OK; + } MS_ASSERT(input != nullptr); auto input_shape = input->shape(); std::vector output_shape; @@ -214,8 +219,6 @@ int StridedSlice::InferShape(std::vector inputs, std::ve output_shape = ApplyShrinkMask(output_shape); outputs.front()->set_shape(output_shape); - outputs.front()->set_data_type(input->data_type()); - outputs[0]->SetFormat(input->GetFormat()); return RET_OK; } diff --git a/mindspore/lite/src/ops/tile.cc b/mindspore/lite/src/ops/tile.cc index 5332882a9e68c966b6710d5d41c44f5777ef6f34..38cbe30991b381a1f4b99cd1763a8083b2abe63e 100644 --- a/mindspore/lite/src/ops/tile.cc +++ b/mindspore/lite/src/ops/tile.cc @@ -40,6 +40,11 @@ int Tile::InferShape(std::vector inputs_, std::vectorset_data_type(input->data_type()); + output->SetFormat(input->GetFormat()); + if (!GetInferFlag()) { + return RET_OK; + } auto tile_prim = this->primitive->value_as_Tile(); MS_ASSERT(tile_prim != nullptr); std::vector out_shape; @@ -49,9 +54,8 @@ int Tile::InferShape(std::vector inputs_, std::vectorshape()[i] * multiples[i]; out_shape.push_back(tmp); } - output->SetFormat(input->GetFormat()); + output->set_shape(out_shape); - output->set_data_type(input->data_type()); return RET_OK; } } // namespace lite diff --git a/mindspore/lite/src/ops/topk.cc b/mindspore/lite/src/ops/topk.cc index a5e1abdc048d8157a8857668a8f7874ec8b052db..03463686d0dce892aa0236a64c1437e98ebfe25b 100644 --- a/mindspore/lite/src/ops/topk.cc +++ b/mindspore/lite/src/ops/topk.cc @@ -46,16 +46,19 @@ int TopK::InferShape(std::vector inputs_, std::vectorset_data_type(input->data_type()); + output0->SetFormat(input->GetFormat()); + output1->set_data_type(kNumberTypeInt32); + output1->SetFormat(input->GetFormat()); + if (!GetInferFlag()) { + return RET_OK; + } auto topk_prim = this->primitive->value_as_TopK(); MS_ASSERT(topk_prim != nullptr); auto out_shape = input->shape(); out_shape[out_shape.size() - 1] = topk_prim->k(); output0->set_shape(out_shape); - output0->set_data_type(input->data_type()); - output0->SetFormat(input->GetFormat()); output1->set_shape(out_shape); - output1->set_data_type(kNumberTypeInt32); - output1->SetFormat(input->GetFormat()); return RET_OK; } } // namespace lite diff --git a/mindspore/lite/src/ops/unique.cc b/mindspore/lite/src/ops/unique.cc index e2561690fe14a75029880f4eee5fc2dc9382d23b..56b8a3b2836b053c860456a4aa4ad0c29a60cdb1 100644 --- a/mindspore/lite/src/ops/unique.cc +++ b/mindspore/lite/src/ops/unique.cc @@ -42,12 +42,15 @@ int Unique::InferShape(std::vector inputs_, std::vectorset_shape(input->shape()); output0->set_data_type(input->data_type()); - output1->set_shape(input->shape()); output1->set_data_type(kNumberTypeInt32); output1->SetFormat(input->GetFormat()); output0->SetFormat(input->GetFormat()); + if (!GetInferFlag()) { + return RET_OK; + } + output0->set_shape(input->shape()); + output1->set_shape(input->shape()); return RET_OK; } } // namespace lite diff --git a/mindspore/lite/src/ops/unstack.cc b/mindspore/lite/src/ops/unstack.cc index 63444e50d05d431615471a8ab51aeb8180f1722a..490e9285b9a2de4f1846bb3653ea73fb9c21ca7e 100644 --- a/mindspore/lite/src/ops/unstack.cc +++ b/mindspore/lite/src/ops/unstack.cc @@ -44,6 +44,14 @@ int Unstack::InferShape(std::vector inputs, std::vectoraxis(); return RET_PARAM_INVALID; } + for (auto &out : outputs) { + MS_ASSERT(out != nullptr); + out->set_data_type(input->data_type()); + out->SetFormat(input->GetFormat()); + } + if (!GetInferFlag()) { + return RET_OK; + } std::vector output_shape; for (size_t i = 0; i < input_shape.size(); ++i) { if (i != axis) { @@ -53,8 +61,6 @@ int Unstack::InferShape(std::vector inputs, std::vectorset_shape(output_shape); - out->set_data_type(input->data_type()); - out->SetFormat(input->GetFormat()); } return RET_OK; } diff --git a/mindspore/lite/src/ops/where.cc b/mindspore/lite/src/ops/where.cc index 36b8c8caa4048390ba7f82ce7d9a4bf248a7b1ac..03dbf1beb38e1f686f74c155c0f3d98d4b19bdc3 100644 --- a/mindspore/lite/src/ops/where.cc +++ b/mindspore/lite/src/ops/where.cc @@ -53,6 +53,11 @@ int Where::InferShape(std::vector inputs_, std::vectorset_data_type(input->data_type()); + output->SetFormat(input->GetFormat()); + if (!GetInferFlag()) { + return RET_OK; + } int num = input0->ElementsNum(); int num1 = input1->ElementsNum(); int num2 = input2->ElementsNum(); @@ -85,8 +90,6 @@ int Where::InferShape(std::vector inputs_, std::vectorset_shape(output_shape); - output->set_data_type(input->data_type()); - output->SetFormat(input->GetFormat()); return RET_OK; } } // namespace lite diff --git a/mindspore/lite/src/ops/zeros_like.cc b/mindspore/lite/src/ops/zeros_like.cc index 83fe130c37aa853ad6c53c9d08ea3dd3acf9ad82..eaa87c1cf373df3e53c8f6fd28f7058a30557eb5 100644 --- a/mindspore/lite/src/ops/zeros_like.cc +++ b/mindspore/lite/src/ops/zeros_like.cc @@ -29,10 +29,12 @@ int ZerosLike::InferShape(std::vector inputs_, std::vect << ", output size: " << outputs_.size(); return RET_INPUT_TENSOR_ERROR; } - output->set_shape(input->shape()); output->set_data_type(input->data_type()); output->SetFormat(input->GetFormat()); - + if (!GetInferFlag()) { + return RET_OK; + } + output->set_shape(input->shape()); return RET_OK; } } // namespace lite diff --git a/mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/arg_min_max.c b/mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/arg_min_max.c index 5c01441cd18aab4aa2cc0702e5b344339fad840e..83b256055e50fb4d08742301e7a17e26085bf468 100644 --- a/mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/arg_min_max.c +++ b/mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/arg_min_max.c @@ -18,15 +18,29 @@ #include int ArgCompareAscFp32(const void *a, const void *b) { - return ((ArgElement *)a)->data_.f_data_ - ((ArgElement *)b)->data_.f_data_; + float a_value = ((ArgElement *)a)->data_.f_data_; + float b_value = ((ArgElement *)b)->data_.f_data_; + if (b_value > a_value) { + return -1; + } + if (b_value < a_value) { + return 1; + } + + return 0; } int ArgCompareDescFp32(const void *a, const void *b) { - // cmp funtion of qsort must return int type - auto b_value = ((ArgElement *)b)->data_.f_data_; - auto a_value = ((ArgElement *)a)->data_.f_data_; - int res = b_value > a_value ? 1 : -1; - return res; + float b_value = ((ArgElement *)b)->data_.f_data_; + float a_value = ((ArgElement *)a)->data_.f_data_; + if (b_value > a_value) { + return 1; + } + if (b_value < a_value) { + return -1; + } + + return 0; } void ArgMaxDim0OutValue(const float *input, float *output, const int *in_shape, ArgMinMaxParameter *param) {