提交 030af09f 编写于 作者: M mindspore-ci-bot 提交者: Gitee

!4665 [MS][LITE][Develop] support inferring data type and format when infer-shape fails

Merge pull request !4665 from chenjianping/lite_dev3
......@@ -43,6 +43,11 @@ int AddN::InferShape(std::vector<tensor::Tensor *> inputs, std::vector<tensor::T
MS_LOG(ERROR) << "input size" << inputs.size() << " is error!";
return RET_INPUT_TENSOR_ERROR;
}
output->SetFormat(input->GetFormat());
output->set_data_type(input->data_type());
if (!GetInferFlag()) {
return RET_OK;
}
for (int i = 1; i < inputs.size(); ++i) {
if (inputs.at(i)->shape() != inputs.at(0)->shape()) {
MS_LOG(ERROR) << "AddN inputs shape is not equal!";
......@@ -53,9 +58,8 @@ int AddN::InferShape(std::vector<tensor::Tensor *> inputs, std::vector<tensor::T
return RET_INPUT_TENSOR_ERROR;
}
}
output->SetFormat(input->GetFormat());
output->set_shape(input->shape());
output->set_data_type(input->data_type());
return RET_OK;
}
} // namespace lite
......
......@@ -55,6 +55,12 @@ int ArgMax::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tensor
if (inputs_.size() != kSingleNum || outputs_.size() != kSingleNum) {
MS_LOG(ERROR) << "tensor number is error.";
}
output->SetFormat(input->GetFormat());
output->set_data_type(input->data_type());
if (!GetInferFlag()) {
return RET_OK;
}
auto argmax_prim = this->primitive->value_as_ArgMax();
std::vector<int> output_shape(input->shape());
auto input_shape_size = input->shape().size();
......@@ -68,9 +74,8 @@ int ArgMax::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tensor
} else {
output_shape[axis] = argmax_prim->topK();
}
output->SetFormat(input->GetFormat());
output->set_shape(output_shape);
output->set_data_type(input->data_type());
return RET_OK;
}
} // namespace lite
......
......@@ -55,6 +55,11 @@ int ArgMin::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tensor
if (inputs_.size() != kSingleNum || outputs_.size() != kSingleNum) {
MS_LOG(ERROR) << "tensor number is error.";
}
output->SetFormat(input->GetFormat());
output->set_data_type(input->data_type());
if (!GetInferFlag()) {
return RET_OK;
}
auto argmin_prim = this->primitive->value_as_ArgMin();
auto input_shape_size = input->shape().size();
int axis = argmin_prim->axis() < 0 ? argmin_prim->axis() + input_shape_size : argmin_prim->axis();
......@@ -68,9 +73,8 @@ int ArgMin::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tensor
} else {
output_shape[axis] = argmin_prim->topK();
}
output->SetFormat(input->GetFormat());
output->set_shape(output_shape);
output->set_data_type(input->data_type());
return RET_OK;
}
} // namespace lite
......
......@@ -46,6 +46,11 @@ int BroadcastTo::InferShape(std::vector<lite::tensor::Tensor *> inputs, std::vec
return 1;
}
auto input = inputs.at(0);
outputs[0]->SetFormat(input->GetFormat());
outputs[0]->set_data_type(input->data_type());
if (!GetInferFlag()) {
return RET_OK;
}
std::vector<int32_t> dst_shape(this->primitive->value_as_BroadcastTo()->dst_shape()->begin(),
this->primitive->value_as_BroadcastTo()->dst_shape()->end());
auto input_shape = input->shape();
......@@ -72,10 +77,8 @@ int BroadcastTo::InferShape(std::vector<lite::tensor::Tensor *> inputs, std::vec
shape[i] = dst_shape[i];
--input_shape_index;
}
outputs[0]->SetFormat(input->GetFormat());
outputs[0]->set_shape(shape);
outputs[0]->set_data_type(input->data_type());
return 0;
return RET_OK;
}
} // namespace lite
} // namespace mindspore
......@@ -44,8 +44,14 @@ int Cast::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tensor::
MS_LOG(ERROR) << "tensor number is error.";
return RET_INPUT_TENSOR_ERROR;
}
output->SetFormat(input->GetFormat());
auto cast_prim = this->primitive->value_as_Cast();
MS_ASSERT(cast_prim != nullptr);
output->set_data_type(static_cast<TypeId>(cast_prim->dstT()));
if (!GetInferFlag()) {
return RET_OK;
}
if (input->data_type() != cast_prim->srcT()) {
MS_LOG(ERROR) << "input dataType is error";
return RET_INPUT_TENSOR_ERROR;
......@@ -54,13 +60,8 @@ int Cast::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tensor::
MS_LOG(ERROR) << "Unsupported input data type " << input->data_type();
return RET_INPUT_TENSOR_ERROR;
}
if (cast_prim->dstT() != kNumberTypeFloat && cast_prim->dstT() != kNumberTypeFloat32) {
MS_LOG(ERROR) << "Invalid output datatype " << cast_prim->dstT();
return RET_INPUT_TENSOR_ERROR;
}
output->SetFormat(input->GetFormat());
output->set_shape(input->shape());
output->set_data_type(TypeId::kNumberTypeFloat32);
return RET_OK;
}
} // namespace lite
......
......@@ -50,16 +50,19 @@ int ConstantOfShape::InferShape(std::vector<tensor::Tensor *> inputs_, std::vect
return RET_ERROR;
}
auto in_tensor = inputs_.front();
auto in_data = reinterpret_cast<int *>(in_tensor->Data());
auto out_tensor = outputs_.front();
out_tensor->set_data_type(kNumberTypeFloat32);
out_tensor->SetFormat(in_tensor->GetFormat());
if (!GetInferFlag()) {
return RET_OK;
}
auto in_data = reinterpret_cast<int *>(in_tensor->Data());
int size = in_tensor->ElementsNum();
std::vector<int> out_shape(size);
for (int i = 0; i < size; ++i) {
out_shape[i] = in_data[i];
}
out_tensor->set_shape(out_shape);
out_tensor->set_data_type(kNumberTypeFloat32);
out_tensor->SetFormat(in_tensor->GetFormat());
return RET_OK;
}
......
......@@ -46,9 +46,12 @@ int Crop::InferShape(std::vector<tensor::Tensor *> inputs, std::vector<tensor::T
MS_LOG(ERROR) << "Invalid output/input size! output size: " << outputs.size() << ",input size: " << inputs.size();
return RET_PARAM_INVALID;
}
outputs[0]->set_shape(inputs[1]->shape());
outputs[0]->SetFormat(inputs[0]->GetFormat());
outputs[0]->set_data_type(inputs[0]->data_type());
if (!GetInferFlag()) {
return RET_OK;
}
outputs[0]->set_shape(inputs[1]->shape());
return RET_OK;
}
} // namespace lite
......
......@@ -103,7 +103,11 @@ int DeConv2D::InferShape(std::vector<lite::tensor::Tensor *> inputs_, std::vecto
MS_ASSERT(weight != nullptr);
auto output = outputs_.front();
MS_ASSERT(output != nullptr);
output->SetFormat(input->GetFormat());
output->set_data_type(input->data_type());
if (!GetInferFlag()) {
return RET_OK;
}
int32_t input_h = input->Height();
int32_t input_w = input->Width();
......@@ -138,8 +142,6 @@ int DeConv2D::InferShape(std::vector<lite::tensor::Tensor *> inputs_, std::vecto
std::vector<int> out_shape = {output_n, output_h, output_w, output_c};
output->set_shape(out_shape);
output->SetFormat(input->GetFormat());
output->set_data_type(input->data_type());
return 0;
}
} // namespace lite
......
......@@ -126,7 +126,11 @@ int DeDepthwiseConv2D::InferShape(std::vector<lite::tensor::Tensor *> inputs_,
MS_ASSERT(weight != nullptr);
auto output = outputs_.front();
MS_ASSERT(output != nullptr);
output->SetFormat(input->GetFormat());
output->set_data_type(input->data_type());
if (!GetInferFlag()) {
return RET_OK;
}
auto in_shape = input->shape();
int input_h = in_shape.at(1);
int input_w = in_shape.at(2);
......@@ -155,8 +159,6 @@ int DeDepthwiseConv2D::InferShape(std::vector<lite::tensor::Tensor *> inputs_,
out_shape.at(3) = weight->shape()[0] * weight->shape()[3]; // in_channel * out_channel
output->set_shape(out_shape);
output->SetFormat(input->GetFormat());
output->set_data_type(input->data_type());
return 0;
}
} // namespace lite
......
......@@ -50,6 +50,11 @@ int DepthToSpace::InferShape(std::vector<lite::tensor::Tensor *> inputs, std::ve
MS_LOG(ERROR) << "depth_to_space only support NHWC now!";
return 1;
}
outputs[0]->set_data_type(input->data_type());
outputs[0]->SetFormat(input->GetFormat());
if (!GetInferFlag()) {
return RET_OK;
}
auto input_shape = input->shape();
if (input_shape.size() != kDimension_4d) {
MS_LOG(ERROR) << "input shape dimension size should == " << kDimension_4d;
......@@ -68,10 +73,7 @@ int DepthToSpace::InferShape(std::vector<lite::tensor::Tensor *> inputs, std::ve
output_shape[NHWC_W] = input_shape[NHWC_W] * block_size;
output_shape[NHWC_C] = input_shape[NHWC_C] / (block_size * block_size);
outputs[0]->set_shape(output_shape);
outputs[0]->set_data_type(input->data_type());
outputs[0]->SetFormat(input->GetFormat());
return 0;
return RET_OK;
}
} // namespace lite
} // namespace mindspore
......@@ -120,7 +120,11 @@ int DepthwiseConv2D::InferShape(std::vector<lite::tensor::Tensor *> inputs_,
MS_ASSERT(weight != nullptr);
auto output = outputs_.front();
MS_ASSERT(output != nullptr);
output->SetFormat(input->GetFormat());
output->set_data_type(input->data_type());
if (!GetInferFlag()) {
return RET_OK;
}
auto in_shape = input->shape();
int input_h = in_shape.at(1);
int input_w = in_shape.at(2);
......@@ -158,8 +162,6 @@ int DepthwiseConv2D::InferShape(std::vector<lite::tensor::Tensor *> inputs_,
out_shape.at(3) = weight->shape()[0] * weight->shape()[3]; // in_channel * out_channel
output->set_shape(out_shape);
output->SetFormat(input->GetFormat());
output->set_data_type(input->data_type());
return 0;
}
} // namespace lite
......
......@@ -46,6 +46,12 @@ int EmbeddingLookup::InferShape(std::vector<tensor::Tensor *> inputs_, std::vect
MS_ASSERT(ids != nullptr);
auto output = outputs_.front();
MS_ASSERT(output != nullptr);
output->SetFormat(params_->GetFormat());
output->set_data_type(params_->data_type());
if (!GetInferFlag()) {
return RET_OK;
}
auto embedding_shape = params_->shape();
embedding_shape.erase(embedding_shape.begin());
std::vector<int> output_shape(ids->shape());
......@@ -61,7 +67,6 @@ int EmbeddingLookup::InferShape(std::vector<tensor::Tensor *> inputs_, std::vect
}
}
output->set_shape(output_shape);
output->set_data_type(params_->data_type());
return RET_OK;
}
} // namespace lite
......
......@@ -42,6 +42,11 @@ int ExpandDims::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<te
if (outputs_.size() != kSingleNum) {
MS_LOG(ERROR) << "output size is invalid";
}
output->set_data_type(input->data_type());
output->SetFormat(input->GetFormat());
if (!GetInferFlag()) {
return RET_OK;
}
auto expand_dims_prim = this->primitive->value_as_ExpandDims();
int dim = expand_dims_prim->dim();
if (dim < 0) {
......@@ -54,8 +59,6 @@ int ExpandDims::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<te
auto out_shape = input->shape();
out_shape.insert(out_shape.begin() + dim, 1, 1);
output->set_shape(out_shape);
output->set_data_type(input->data_type());
output->SetFormat(input->GetFormat());
return RET_OK;
}
} // namespace lite
......
......@@ -45,6 +45,11 @@ int Fill::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tensor::
MS_LOG(ERROR) << "input size: " << inputs_.size() << ", output size: " << outputs_.size();
return RET_INPUT_TENSOR_ERROR;
}
output->set_data_type(input->data_type());
output->SetFormat(input->GetFormat());
if (!GetInferFlag()) {
return RET_OK;
}
auto fill_prim = this->primitive->value_as_Fill();
if (fill_prim == nullptr) {
MS_LOG(ERROR) << "Fill primitive is null!";
......@@ -53,8 +58,6 @@ int Fill::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tensor::
std::vector<int> output_shape;
(void)output_shape.insert(output_shape.begin(), fill_prim->dims()->begin(), fill_prim->dims()->end());
output->set_shape(output_shape);
output->set_data_type(input->data_type());
output->SetFormat(input->GetFormat());
return RET_OK;
}
} // namespace lite
......
......@@ -31,6 +31,13 @@ int Flatten::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tenso
MS_LOG(ERROR) << "input size: " << inputs_.size() << ", output size: " << outputs_.size();
return RET_INPUT_TENSOR_ERROR;
}
output->set_data_type(input->data_type());
output->SetFormat(input->GetFormat());
if (!GetInferFlag()) {
return RET_OK;
}
auto input_shape = input->shape();
std::vector<int> output_shape(2);
output_shape[0] = input_shape[0];
......@@ -39,8 +46,6 @@ int Flatten::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tenso
output_shape[1] *= input_shape[i];
}
output->set_shape(output_shape);
output->set_data_type(input->data_type());
output->SetFormat(input->GetFormat());
return RET_OK;
}
} // namespace lite
......
......@@ -51,7 +51,11 @@ int FullConnection::InferShape(std::vector<lite::tensor::Tensor *> inputs_,
MS_ASSERT(input1 != nullptr);
auto output = outputs_.front();
MS_ASSERT(output != nullptr);
output->set_data_type(input0->data_type());
output->SetFormat(input0->GetFormat());
if (!GetInferFlag()) {
return RET_OK;
}
if ((GetHasBias() && inputs_.size() != kMultiNum) || (!GetHasBias() && inputs_.size() != kDoubleNum)) {
MS_LOG(ERROR) << "Input tensors num error";
return 1;
......@@ -78,8 +82,6 @@ int FullConnection::InferShape(std::vector<lite::tensor::Tensor *> inputs_,
out_shape.resize(GetAxis() + 1);
out_shape[GetAxis()] = input1->shape()[0];
output->set_shape(out_shape);
output->set_data_type(input0->data_type());
output->SetFormat(input0->GetFormat());
return 0;
}
......
......@@ -46,6 +46,12 @@ int GatherNd::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tens
MS_ASSERT(indices != nullptr);
auto output = outputs_.front();
MS_ASSERT(output != nullptr);
output->set_data_type(input->data_type());
output->SetFormat(input->GetFormat());
if (!GetInferFlag()) {
return RET_OK;
}
auto in_shape = input->shape();
int in_rank = in_shape.size();
auto indices_shape = indices->shape();
......@@ -63,8 +69,6 @@ int GatherNd::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tens
out_shape.emplace_back(in_shape[i]);
}
output->set_shape(out_shape);
output->set_data_type(input->data_type());
output->SetFormat(input->GetFormat());
return RET_OK;
}
} // namespace lite
......
......@@ -44,6 +44,14 @@ int Lstm::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tensor::
MS_ASSERT(input0 != nullptr);
auto output = outputs_.front();
MS_ASSERT(output != nullptr);
for (int i = 0; i < kLstmOutputNum; i++) {
outputs_[i]->set_data_type(input->data_type());
outputs_[i]->SetFormat(input->GetFormat());
}
if (!GetInferFlag()) {
return RET_OK;
}
std::vector<int> in_shape = input->shape();
std::vector<int> w_shape = weight_i->shape(); // layer, hidden_size * 4, input_size
if (in_shape.size() != 3 || w_shape.size() != 3) {
......@@ -65,10 +73,7 @@ int Lstm::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tensor::
state_shape[2] = hidden_size;
outputs_[1]->set_shape(state_shape);
outputs_[2]->set_shape(state_shape);
for (int i = 0; i < kLstmOutputNum; i++) {
outputs_[i]->set_data_type(input->data_type());
outputs_[i]->SetFormat(input->GetFormat());
}
return RET_OK;
}
} // namespace lite
......
......@@ -43,6 +43,13 @@ int MatMul::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tensor
MS_ASSERT(input1 != nullptr);
auto output = outputs_.front();
MS_ASSERT(output != nullptr);
output->set_data_type(input0->data_type());
output->SetFormat(input0->GetFormat());
if (!GetInferFlag()) {
return RET_OK;
}
std::vector<int> a_shape = input0->shape();
std::vector<int> b_shape = input1->shape();
if (a_shape.size() < 2 || b_shape.size() < 2) {
......@@ -65,8 +72,6 @@ int MatMul::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tensor
std::vector<int> c_shape(a_shape);
c_shape[c_shape.size() - 1] = b_shape[b_shape.size() - 1];
output->set_shape(c_shape);
output->set_data_type(input0->data_type());
output->SetFormat(input0->GetFormat());
return RET_OK;
}
} // namespace lite
......
......@@ -50,6 +50,11 @@ int Mean::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tensor::
if (input == nullptr || output == nullptr) {
return RET_NULL_PTR;
}
output->set_data_type(input->data_type());
output->SetFormat(input->GetFormat());
if (!GetInferFlag()) {
return RET_OK;
}
if (this->primitive == nullptr) {
return RET_NULL_PTR;
}
......@@ -88,8 +93,6 @@ int Mean::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tensor::
}
}
output->set_shape(out_shape);
output->set_data_type(input->data_type());
output->SetFormat(input->GetFormat());
return RET_OK;
}
} // namespace lite
......
......@@ -25,6 +25,11 @@ int Nchw2Nhwc::InferShape(std::vector<lite::tensor::Tensor *> inputs_, std::vect
MS_ASSERT(input != nullptr);
auto output = outputs_.front();
MS_ASSERT(output != nullptr);
output->SetFormat(schema::Format_NHWC);
output->set_data_type(input->data_type());
if (!GetInferFlag()) {
return RET_OK;
}
std::vector<int> nchw_shape = input->shape();
if (nchw_shape.size() != 4) {
output->set_shape(nchw_shape);
......@@ -36,8 +41,6 @@ int Nchw2Nhwc::InferShape(std::vector<lite::tensor::Tensor *> inputs_, std::vect
nhwc_shape[NHWC_C] = nchw_shape[NCHW_C];
output->set_shape(nhwc_shape);
}
output->SetFormat(schema::Format_NHWC);
output->set_data_type(input->data_type());
return RET_OK;
}
} // namespace lite
......
......@@ -25,6 +25,11 @@ int Nhwc2Nchw::InferShape(std::vector<lite::tensor::Tensor *> inputs_, std::vect
MS_ASSERT(input != nullptr);
auto output = outputs_.front();
MS_ASSERT(output != nullptr);
output->SetFormat(schema::Format_NCHW);
output->set_data_type(input->data_type());
if (!GetInferFlag()) {
return RET_OK;
}
std::vector<int> nhwc_shape = input->shape();
if (nhwc_shape.size() != 4) {
output->set_shape(nhwc_shape);
......@@ -36,8 +41,6 @@ int Nhwc2Nchw::InferShape(std::vector<lite::tensor::Tensor *> inputs_, std::vect
nchw_shape[NCHW_W] = nhwc_shape[NHWC_W];
output->set_shape(nchw_shape);
}
output->SetFormat(schema::Format_NCHW);
output->set_data_type(input->data_type());
return RET_OK;
}
} // namespace lite
......
......@@ -56,6 +56,19 @@ int OneHot::InferShape(std::vector<tensor::Tensor *> inputs, std::vector<tensor:
if (input == nullptr) {
return RET_NULL_PTR;
}
auto on_value = inputs.at(2);
if (on_value == nullptr) {
return RET_NULL_PTR;
}
auto output = outputs.front();
if (output == nullptr) {
return RET_NULL_PTR;
}
output->set_data_type(on_value->data_type());
output->SetFormat(on_value->GetFormat());
if (!GetInferFlag()) {
return RET_OK;
}
const auto input_shape = input->shape();
int input_rank = static_cast<int>(input_shape.size());
if (axis < 0) {
......@@ -63,17 +76,7 @@ int OneHot::InferShape(std::vector<tensor::Tensor *> inputs, std::vector<tensor:
}
std::vector<int> output_shape(input_shape);
output_shape.insert(output_shape.cbegin() + axis, *depth);
auto output = outputs.front();
if (output == nullptr) {
return RET_NULL_PTR;
}
output->set_shape(output_shape);
auto on_value = inputs.at(2);
if (on_value == nullptr) {
return RET_NULL_PTR;
}
output->set_data_type(on_value->data_type());
output->SetFormat(on_value->GetFormat());
return RET_OK;
}
} // namespace lite
......
......@@ -61,6 +61,15 @@ int Pad::InferShape(std::vector<tensor::Tensor *> inputs, std::vector<tensor::Te
if (input == nullptr) {
return RET_NULL_PTR;
}
auto output = outputs.front();
if (output == nullptr) {
return RET_NULL_PTR;
}
output->SetFormat(input->GetFormat());
output->set_data_type(input->data_type());
if (!GetInferFlag()) {
return RET_OK;
}
auto input_shape = input->shape();
std::vector<int> output_shape;
MS_ASSERT(input->shape().size() <= kInputRank);
......@@ -69,13 +78,8 @@ int Pad::InferShape(std::vector<tensor::Tensor *> inputs, std::vector<tensor::Te
auto shape = input_shape[i] + (*paddings)[2 * paddings_index] + (*paddings)[2 * paddings_index + 1];
output_shape.push_back(shape);
}
auto output = outputs.front();
if (output == nullptr) {
return RET_NULL_PTR;
}
output->SetFormat(input->GetFormat());
output->set_shape(output_shape);
output->set_data_type(input->data_type());
return RET_OK;
}
} // namespace lite
......
......@@ -95,6 +95,11 @@ int Pooling::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tenso
MS_ASSERT(input != nullptr);
auto output = outputs_.front();
MS_ASSERT(output != nullptr);
output->set_data_type(input->data_type());
output->SetFormat(schema::Format_NHWC);
if (!GetInferFlag()) {
return RET_OK;
}
int input_h = input->shape().at(1);
int input_w = input->shape().at(2);
auto pooling_prim = this->primitive->value_as_Pooling();
......@@ -137,9 +142,6 @@ int Pooling::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tenso
input_shape.at(1) = output_h;
input_shape.at(2) = output_w;
output->set_shape(input_shape);
output->set_data_type(input->data_type());
// todo: temp fix
output->SetFormat(schema::Format_NHWC);
return RET_OK;
}
} // namespace lite
......
......@@ -49,15 +49,19 @@ int Power::InferShape(std::vector<tensor::Tensor *> inputs, std::vector<tensor::
}
auto output_tensor = outputs[0];
MS_ASSERT(output_tensor != nullptr);
output_tensor->set_data_type(x_tensor->data_type());
output_tensor->SetFormat(x_tensor->GetFormat());
if (!GetInferFlag()) {
return RET_OK;
}
if (exp_tensor != nullptr) {
if (exp_tensor->shape() != x_tensor->shape() || exp_tensor->data_type() != x_tensor->data_type()) {
MS_LOG(ERROR) << "Power inputs shape or type is not equal!";
return RET_INPUT_TENSOR_ERROR;
}
}
output_tensor->SetFormat(x_tensor->GetFormat());
output_tensor->set_shape(x_tensor->shape());
output_tensor->set_data_type(x_tensor->data_type());
return RET_OK;
}
} // namespace lite
......
......@@ -99,6 +99,15 @@ constexpr int kPriorBoxC = 2;
int PriorBox::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tensor::Tensor *> outputs_) {
auto param = this->primitive->value_as_PriorBox();
MS_ASSERT(param != nullptr);
auto input = inputs_.at(0);
MS_ASSERT(input != nullptr);
auto output = outputs_.at(0);
MS_ASSERT(output != nullptr);
output->set_data_type(kNumberTypeFloat32);
output->SetFormat(input->GetFormat());
if (!GetInferFlag()) {
return RET_OK;
}
std::vector<float> different_aspect_ratios{1.0f};
auto aspect_ratios = param->aspect_ratios();
MS_ASSERT(aspect_ratios != nullptr);
......@@ -114,15 +123,9 @@ int PriorBox::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tens
}
}
int32_t num_priors_box = param->min_sizes()->size() * different_aspect_ratios.size() + param->max_sizes()->size();
auto input = inputs_.at(0);
MS_ASSERT(input != nullptr);
int32_t h = input->Height() * input->Width() * num_priors_box * kPriorBoxPoints;
std::vector<int> output_shape{kPriorBoxN, h, kPriorBoxW, kPriorBoxC};
auto output = outputs_.at(0);
MS_ASSERT(output != nullptr);
output->set_shape(output_shape);
output->set_data_type(kNumberTypeFloat32);
output->SetFormat(input->GetFormat());
return RET_OK;
}
} // namespace lite
......
......@@ -40,11 +40,14 @@ int QuantDTypeCast::InferShape(std::vector<tensor::Tensor *> inputs_, std::vecto
MS_ASSERT(input != nullptr);
auto output = outputs_.front();
MS_ASSERT(output != nullptr);
output->set_shape(input->shape());
auto param = primitive->value_as_QuantDTypeCast();
MS_ASSERT(input->data_type() == param->srcT);
output->set_data_type(static_cast<TypeId>(param->dstT()));
output->SetFormat(input->GetFormat());
if (!GetInferFlag()) {
return RET_OK;
}
output->set_shape(input->shape());
return RET_OK;
}
} // namespace lite
......
......@@ -50,12 +50,18 @@ int Range::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tensor:
MS_ASSERT(output != nullptr);
auto range_prim = this->primitive->value_as_Range();
MS_ASSERT(range_prim != nullptr);
output->set_data_type(input->data_type());
output->SetFormat(input->GetFormat());
if (!GetInferFlag()) {
return RET_OK;
}
int shape_size = std::ceil(static_cast<float>(range_prim->limit() - range_prim->start()) / range_prim->delta());
std::vector<int> in_shape(1);
in_shape.push_back(shape_size);
output->set_shape(in_shape);
output->set_data_type(input->data_type());
output->SetFormat(input->GetFormat());
return RET_OK;
}
} // namespace lite
......
......@@ -25,10 +25,13 @@ int Rank::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tensor::
MS_ASSERT(input != nullptr);
auto output = outputs_.front();
MS_ASSERT(output != nullptr);
std::vector<int> in_shape(1, 1);
output->set_shape(in_shape);
output->set_data_type(input->data_type());
output->SetFormat(input->GetFormat());
if (!GetInferFlag()) {
return RET_OK;
}
std::vector<int> in_shape(1, 1);
output->set_shape(in_shape);
return RET_OK;
}
} // namespace lite
......
......@@ -66,6 +66,11 @@ int Resize::InferShape(std::vector<lite::tensor::Tensor *> inputs_, std::vector<
if (output == nullptr) {
return 1;
}
output->set_data_type(input->data_type());
output->SetFormat(input->GetFormat());
if (!GetInferFlag()) {
return RET_OK;
}
auto new_height = GetNewHeight();
auto new_width = GetNewWidth();
......@@ -75,10 +80,8 @@ int Resize::InferShape(std::vector<lite::tensor::Tensor *> inputs_, std::vector<
output_shape.push_back(new_width);
output_shape.push_back(input->Channel());
output->set_shape(output_shape);
output->set_data_type(input->data_type());
output->SetFormat(input->GetFormat());
return 0;
return RET_OK;
}
} // namespace lite
} // namespace mindspore
......@@ -52,9 +52,13 @@ int ReverseSequence::InferShape(std::vector<tensor::Tensor *> inputs, std::vecto
auto output = outputs.front();
MS_ASSERT(input != nullptr);
MS_ASSERT(output != nullptr);
output->set_shape(input->shape());
output->set_data_type(input->data_type());
output->SetFormat(input->GetFormat());
if (!GetInferFlag()) {
return RET_OK;
}
output->set_shape(input->shape());
return RET_OK;
}
} // namespace lite
......
......@@ -56,6 +56,11 @@ int ROIPooling::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<te
if (output == nullptr) {
return RET_NULL_PTR;
}
output->set_data_type(input->data_type());
output->SetFormat(input->GetFormat());
if (!GetInferFlag()) {
return RET_OK;
}
auto ROIPooling = this->primitive->value_as_ROIPooling();
auto new_h = ROIPooling->pooledH();
auto new_w = ROIPooling->pooledW();
......@@ -66,8 +71,6 @@ int ROIPooling::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<te
output_shape.push_back(new_w);
output_shape.push_back(input->Channel());
output->set_shape(output_shape);
output->set_data_type(input->data_type());
output->SetFormat(input->GetFormat());
return RET_OK;
}
} // namespace lite
......
......@@ -51,11 +51,14 @@ int ScatterND::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<ten
return RET_ERROR;
}
auto output = outputs_.front();
output->set_data_type(update->data_type());
output->SetFormat(update->GetFormat());
if (!GetInferFlag()) {
return RET_OK;
}
auto shape_data = reinterpret_cast<int *>(shape->Data());
std::vector<int> out_shape(shape_data, shape_data + shape->DataSize());
output->set_shape(out_shape);
output->set_data_type(update->data_type());
output->SetFormat(update->GetFormat());
return RET_OK;
}
} // namespace lite
......
......@@ -63,6 +63,11 @@ int SpaceToBatch::InferShape(std::vector<lite::tensor::Tensor *> inputs, std::ve
MS_LOG(ERROR) << "space_to_batch only support NHWC now!";
return 1;
}
outputs[0]->set_data_type(input->data_type());
outputs[0]->SetFormat(input->GetFormat());
if (!GetInferFlag()) {
return RET_OK;
}
auto input_shape = input->shape();
if (input_shape.size() != kDimension_4d) {
MS_LOG(ERROR) << "input shape dimension size should == " << kDimension_4d;
......@@ -106,8 +111,7 @@ int SpaceToBatch::InferShape(std::vector<lite::tensor::Tensor *> inputs, std::ve
output_shape[NHWC_W] = input_shape[NHWC_W] / block_sizes_[NHWC_H];
output_shape[NHWC_C] = input_shape[NHWC_C];
outputs[0]->set_shape(output_shape);
outputs[0]->set_data_type(input->data_type());
return 0;
return RET_OK;
}
} // namespace lite
} // namespace mindspore
......@@ -51,6 +51,11 @@ int SpaceToDepth::InferShape(std::vector<lite::tensor::Tensor *> inputs, std::ve
MS_LOG(ERROR) << "space_to_depth only support NHWC now!";
return 1;
}
outputs[0]->SetFormat(input->GetFormat());
outputs[0]->set_data_type(input->data_type());
if (!GetInferFlag()) {
return RET_OK;
}
auto input_shape = input->shape();
if (input_shape.size() != kDimension_4d) {
MS_LOG(ERROR) << "input shape dimension size should == " << kDimension_4d;
......@@ -69,8 +74,7 @@ int SpaceToDepth::InferShape(std::vector<lite::tensor::Tensor *> inputs, std::ve
output_shape[NHWC_W] = input_shape[NHWC_W] / block_size;
output_shape[NHWC_C] = input_shape[NHWC_C] * (block_size * block_size);
outputs[0]->set_shape(output_shape);
outputs[0]->set_data_type(input->data_type());
return 0;
return RET_OK;
}
} // namespace lite
} // namespace mindspore
......@@ -66,6 +66,13 @@ int Split::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tensor:
MS_LOG(ERROR) << "outputs number is not equal to " << number_split;
return RET_ERROR;
}
for (int i = 0; i < number_split; ++i) {
outputs_[i]->set_data_type(input->data_type());
outputs_[i]->SetFormat(input->GetFormat());
}
if (!GetInferFlag()) {
return RET_OK;
}
int split_dim = spilt_prim->splitDim();
std::vector<int> input_shape = input->shape();
std::vector<int> size_split;
......
......@@ -48,6 +48,11 @@ int Squeeze::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tenso
return -1;
}
auto *in_tensor = inputs_.front();
outputs_.front()->set_data_type(in_tensor->data_type());
outputs_.front()->SetFormat(in_tensor->GetFormat());
if (!GetInferFlag()) {
return RET_OK;
}
auto in_shape = in_tensor->shape();
std::vector<int> out_shape;
// todo: getAxis
......@@ -77,8 +82,6 @@ int Squeeze::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tenso
}
}
outputs_.front()->set_shape(out_shape);
outputs_.front()->set_data_type(in_tensor->data_type());
outputs_.front()->SetFormat(in_tensor->GetFormat());
return 0;
}
} // namespace lite
......
......@@ -56,6 +56,11 @@ int Stack::InferShape(std::vector<tensor::Tensor *> inputs, std::vector<tensor::
return RET_PARAM_INVALID;
}
auto input = inputs.at(0);
outputs[0]->set_data_type(input->data_type());
outputs[0]->SetFormat(input->GetFormat());
if (!GetInferFlag()) {
return RET_OK;
}
auto input_shape = input->shape();
auto stack_prim = this->primitive->value_as_Stack();
std::vector<int32_t> output_shape = input_shape;
......@@ -84,8 +89,6 @@ int Stack::InferShape(std::vector<tensor::Tensor *> inputs, std::vector<tensor::
}
output_shape.insert(output_shape.begin() + axis, inputs.size());
outputs[0]->set_shape(output_shape);
outputs[0]->set_data_type(input->data_type());
outputs[0]->SetFormat(input->GetFormat());
return RET_OK;
}
} // namespace lite
......
......@@ -164,6 +164,11 @@ int StridedSlice::InferShape(std::vector<lite::tensor::Tensor *> inputs, std::ve
return RET_PARAM_INVALID;
}
auto input = inputs.at(0);
outputs.front()->set_data_type(input->data_type());
outputs[0]->SetFormat(input->GetFormat());
if (!GetInferFlag()) {
return RET_OK;
}
MS_ASSERT(input != nullptr);
auto input_shape = input->shape();
std::vector<int> output_shape;
......@@ -214,8 +219,6 @@ int StridedSlice::InferShape(std::vector<lite::tensor::Tensor *> inputs, std::ve
output_shape = ApplyShrinkMask(output_shape);
outputs.front()->set_shape(output_shape);
outputs.front()->set_data_type(input->data_type());
outputs[0]->SetFormat(input->GetFormat());
return RET_OK;
}
......
......@@ -40,6 +40,11 @@ int Tile::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tensor::
MS_ASSERT(input != nullptr);
auto output = outputs_.front();
MS_ASSERT(output != nullptr);
output->set_data_type(input->data_type());
output->SetFormat(input->GetFormat());
if (!GetInferFlag()) {
return RET_OK;
}
auto tile_prim = this->primitive->value_as_Tile();
MS_ASSERT(tile_prim != nullptr);
std::vector<int> out_shape;
......@@ -49,9 +54,8 @@ int Tile::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tensor::
int tmp = input->shape()[i] * multiples[i];
out_shape.push_back(tmp);
}
output->SetFormat(input->GetFormat());
output->set_shape(out_shape);
output->set_data_type(input->data_type());
return RET_OK;
}
} // namespace lite
......
......@@ -46,16 +46,19 @@ int TopK::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tensor::
MS_ASSERT(output0 != nullptr);
auto output1 = outputs_.at(1);
MS_ASSERT(output1 != nullptr);
output0->set_data_type(input->data_type());
output0->SetFormat(input->GetFormat());
output1->set_data_type(kNumberTypeInt32);
output1->SetFormat(input->GetFormat());
if (!GetInferFlag()) {
return RET_OK;
}
auto topk_prim = this->primitive->value_as_TopK();
MS_ASSERT(topk_prim != nullptr);
auto out_shape = input->shape();
out_shape[out_shape.size() - 1] = topk_prim->k();
output0->set_shape(out_shape);
output0->set_data_type(input->data_type());
output0->SetFormat(input->GetFormat());
output1->set_shape(out_shape);
output1->set_data_type(kNumberTypeInt32);
output1->SetFormat(input->GetFormat());
return RET_OK;
}
} // namespace lite
......
......@@ -42,12 +42,15 @@ int Unique::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tensor
MS_ASSERT(output0 != nullptr);
auto &output1 = outputs_.at(1);
MS_ASSERT(output1 != nullptr);
output0->set_shape(input->shape());
output0->set_data_type(input->data_type());
output1->set_shape(input->shape());
output1->set_data_type(kNumberTypeInt32);
output1->SetFormat(input->GetFormat());
output0->SetFormat(input->GetFormat());
if (!GetInferFlag()) {
return RET_OK;
}
output0->set_shape(input->shape());
output1->set_shape(input->shape());
return RET_OK;
}
} // namespace lite
......
......@@ -44,6 +44,14 @@ int Unstack::InferShape(std::vector<tensor::Tensor *> inputs, std::vector<tensor
MS_LOG(ERROR) << "Invalid axis " << prim->axis();
return RET_PARAM_INVALID;
}
for (auto &out : outputs) {
MS_ASSERT(out != nullptr);
out->set_data_type(input->data_type());
out->SetFormat(input->GetFormat());
}
if (!GetInferFlag()) {
return RET_OK;
}
std::vector<int> output_shape;
for (size_t i = 0; i < input_shape.size(); ++i) {
if (i != axis) {
......@@ -53,8 +61,6 @@ int Unstack::InferShape(std::vector<tensor::Tensor *> inputs, std::vector<tensor
for (auto &out : outputs) {
MS_ASSERT(out != nullptr);
out->set_shape(output_shape);
out->set_data_type(input->data_type());
out->SetFormat(input->GetFormat());
}
return RET_OK;
}
......
......@@ -53,6 +53,11 @@ int Where::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tensor:
auto input0 = inputs_.at(0);
auto input1 = inputs_.at(1);
auto input2 = inputs_.at(2);
output->set_data_type(input->data_type());
output->SetFormat(input->GetFormat());
if (!GetInferFlag()) {
return RET_OK;
}
int num = input0->ElementsNum();
int num1 = input1->ElementsNum();
int num2 = input2->ElementsNum();
......@@ -85,8 +90,6 @@ int Where::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tensor:
auto output_shape = shape_tmp;
output_shape[axisout] = nummax;
outputs_[0]->set_shape(output_shape);
output->set_data_type(input->data_type());
output->SetFormat(input->GetFormat());
return RET_OK;
}
} // namespace lite
......
......@@ -29,10 +29,12 @@ int ZerosLike::InferShape(std::vector<lite::tensor::Tensor *> inputs_, std::vect
<< ", output size: " << outputs_.size();
return RET_INPUT_TENSOR_ERROR;
}
output->set_shape(input->shape());
output->set_data_type(input->data_type());
output->SetFormat(input->GetFormat());
if (!GetInferFlag()) {
return RET_OK;
}
output->set_shape(input->shape());
return RET_OK;
}
} // namespace lite
......
......@@ -18,15 +18,29 @@
#include <float.h>
// qsort comparator: sorts ArgElement entries ascending by their float payload.
// NOTE: returning the raw float difference cast to int is wrong (e.g. a
// difference of 0.5 truncates to 0, reporting "equal"), and the leftover
// legacy return also made the correct comparison below unreachable dead code.
// qsort requires a strict -1/0/1 style three-way result, so compare explicitly.
int ArgCompareAscFp32(const void *a, const void *b) {
  float a_value = ((ArgElement *)a)->data_.f_data_;
  float b_value = ((ArgElement *)b)->data_.f_data_;
  if (b_value > a_value) {
    return -1;
  }
  if (b_value < a_value) {
    return 1;
  }
  return 0;
}
// qsort comparator: sorts ArgElement entries descending by their float payload.
// The previous version had two defects: (1) the diff left both the old `auto`
// declarations and the new `float` ones in place, redefining a_value/b_value
// (a compile error); (2) the old ternary `b > a ? 1 : -1` never returned 0 for
// equal elements, violating the qsort comparator contract (can yield an
// inconsistent ordering). Return an explicit -1/0/1 instead.
int ArgCompareDescFp32(const void *a, const void *b) {
  float b_value = ((ArgElement *)b)->data_.f_data_;
  float a_value = ((ArgElement *)a)->data_.f_data_;
  if (b_value > a_value) {
    return 1;
  }
  if (b_value < a_value) {
    return -1;
  }
  return 0;
}
void ArgMaxDim0OutValue(const float *input, float *output, const int *in_shape, ArgMinMaxParameter *param) {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册