提交 2cbb280b 编写于 作者: M mindspore-ci-bot 提交者: Gitee

!4678 change ops getter

Merge pull request !4678 from yeyunpeng2020/master_cops_3
......@@ -61,18 +61,17 @@ int ArgMax::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tensor
if (!GetInferFlag()) {
return RET_OK;
}
auto argmax_prim = this->primitive->value_as_ArgMax();
std::vector<int> output_shape(input->shape());
auto input_shape_size = input->shape().size();
int axis = argmax_prim->axis() < 0 ? argmax_prim->axis() + input_shape_size : argmax_prim->axis();
int axis = GetAxis() < 0 ? GetAxis() + input_shape_size : GetAxis();
if (axis >= input_shape_size || axis < 0) {
MS_LOG(ERROR) << "Invalid axis " << argmax_prim->axis() << ", input shape size: " << input_shape_size;
MS_LOG(ERROR) << "Invalid axis " << GetAxis() << ", input shape size: " << input_shape_size;
return RET_PARAM_INVALID;
}
if (argmax_prim->topK() == 1 && !argmax_prim->keepDims()) {
if (GetTopK() == 1 && !GetKeepDims()) {
output_shape.erase(output_shape.begin() + axis);
} else {
output_shape[axis] = argmax_prim->topK();
output_shape[axis] = GetTopK();
}
output->set_shape(output_shape);
......
......@@ -46,7 +46,7 @@ void ArgMin::SetKeepDims(bool keep_dims) {}
void ArgMin::SetAxisType(int axis_type) {}
#endif
int ArgMin::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tensor::Tensor *> outputs_) {
int ArgMin::InferShape(std::vector<lite::tensor::Tensor *> inputs_, std::vector<lite::tensor::Tensor *> outputs_) {
MS_ASSERT(this->primitive != nullptr);
auto input = inputs_.front();
MS_ASSERT(input != nullptr);
......@@ -60,18 +60,17 @@ int ArgMin::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tensor
if (!GetInferFlag()) {
return RET_OK;
}
auto argmin_prim = this->primitive->value_as_ArgMin();
auto input_shape_size = input->shape().size();
int axis = argmin_prim->axis() < 0 ? argmin_prim->axis() + input_shape_size : argmin_prim->axis();
int axis = GetAxis() < 0 ? GetAxis() + input_shape_size : GetAxis();
if (axis >= input_shape_size || axis < 0) {
MS_LOG(ERROR) << "Invalid axis " << argmin_prim->axis() << ", input shape size: " << input_shape_size;
MS_LOG(ERROR) << "Invalid axis " << GetAxis() << ", input shape size: " << input_shape_size;
return RET_PARAM_INVALID;
}
std::vector<int> output_shape(input->shape());
if (argmin_prim->topK() == 1 && !argmin_prim->keepDims()) {
if (GetTopK() == 1 && !GetKeepDims()) {
output_shape.erase(output_shape.begin() + axis);
} else {
output_shape[axis] = argmin_prim->topK();
output_shape[axis] = GetTopK();
}
output->set_shape(output_shape);
......
......@@ -39,11 +39,10 @@ constexpr int kBroadcastToInputNum = 1;
constexpr int kBroadcastToOutputNum = 1;
} // namespace
int BroadcastTo::InferShape(std::vector<lite::tensor::Tensor *> inputs, std::vector<lite::tensor::Tensor *> outputs) {
MS_ASSERT(this->primitive != nullptr);
int BroadcastTo::InferShape(std::vector<tensor::Tensor *> inputs, std::vector<tensor::Tensor *> outputs) {
if (inputs.size() != kBroadcastToInputNum || outputs.size() != kBroadcastToOutputNum) {
MS_LOG(ERROR) << "input size:" << inputs.size() << ", output size:" << outputs.size();
return 1;
return RET_PARAM_INVALID;
}
auto input = inputs.at(0);
outputs[0]->SetFormat(input->GetFormat());
......@@ -51,27 +50,26 @@ int BroadcastTo::InferShape(std::vector<lite::tensor::Tensor *> inputs, std::vec
if (!GetInferFlag()) {
return RET_OK;
}
std::vector<int32_t> dst_shape(this->primitive->value_as_BroadcastTo()->dst_shape()->begin(),
this->primitive->value_as_BroadcastTo()->dst_shape()->end());
std::vector<int32_t> dst_shape(GetDstShape().begin(), GetDstShape().end());
auto input_shape = input->shape();
std::vector<int> shape(dst_shape.size());
int input_shape_index = input_shape.size() - 1;
if (input_shape.size() > dst_shape.size()) {
MS_LOG(ERROR) << "input shape size " << input_shape.size() << " should <= broadcast to shape size "
<< dst_shape.size() << "!";
return 1;
return RET_PARAM_INVALID;
}
for (int i = dst_shape.size() - 1; i >= 0; --i) {
if (dst_shape[i] < 0) {
MS_LOG(ERROR) << "shape[" << i << "] = " << dst_shape[i] << " ] should be > 0!";
return 1;
return RET_PARAM_INVALID;
}
if (input_shape_index >= 0) {
auto dim = input_shape[input_shape_index];
if (dim != dst_shape[i] && dim != 1) {
MS_LOG(ERROR) << "Invalid broadcast shape!";
return 1;
return RET_PARAM_INVALID;
}
}
shape[i] = dst_shape[i];
......
......@@ -45,14 +45,14 @@ int Cast::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tensor::
return RET_INPUT_TENSOR_ERROR;
}
output->SetFormat(input->GetFormat());
auto cast_prim = this->primitive->value_as_Cast();
MS_ASSERT(cast_prim != nullptr);
output->set_data_type(static_cast<TypeId>(cast_prim->dstT()));
output->set_data_type(static_cast<TypeId>(GetDstT()));
if (!GetInferFlag()) {
return RET_OK;
}
if (input->data_type() != cast_prim->srcT()) {
if (input->data_type() != GetSrcT()) {
MS_LOG(ERROR) << "input dataType is error";
return RET_INPUT_TENSOR_ERROR;
}
......
......@@ -55,10 +55,10 @@ int Concat::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tensor
if (!GetInferFlag()) {
return RET_OK;
}
auto concat_prim = this->primitive->value_as_Concat();
MS_ASSERT(concat_prim != nullptr);
auto input0_shape = inputs_.at(0)->shape();
int axis = concat_prim->axis() < 0 ? concat_prim->axis() + input0_shape.size() : concat_prim->axis();
int axis = GetAxis() < 0 ? GetAxis() + input0_shape.size() : GetAxis();
if (axis < 0 || axis >= input0_shape.size()) {
MS_LOG(ERROR) << "Invalid axis: " << axis;
return RET_PARAM_INVALID;
......
......@@ -41,7 +41,6 @@ constexpr int kCropOutputNum = 1;
constexpr int kCropInputNum = 2;
} // namespace
int Crop::InferShape(std::vector<tensor::Tensor *> inputs, std::vector<tensor::Tensor *> outputs) {
MS_ASSERT(this->primitive != nullptr);
if (outputs.size() != kCropOutputNum || inputs.size() != kCropInputNum) {
MS_LOG(ERROR) << "Invalid output/input size! output size: " << outputs.size() << ",input size: " << inputs.size();
return RET_PARAM_INVALID;
......
......@@ -139,7 +139,6 @@ int DeConv2D::InferShape(std::vector<lite::tensor::Tensor *> inputs_, std::vecto
} else {
MS_LOG(ERROR) << "unsupported pad mode for deconv";
}
std::vector<int> out_shape = {output_n, output_h, output_w, output_c};
output->set_shape(out_shape);
return 0;
......
......@@ -154,7 +154,7 @@ int DeDepthwiseConv2D::InferShape(std::vector<lite::tensor::Tensor *> inputs_,
out_shape.at(2) = output_w;
if (GetChannelMultiplier() * input_channel != weight->shape()[0]) {
MS_LOG(ERROR) << "Conv depthwise only support group equals output channel.";
return 1;
return RET_ERROR;
}
out_shape.at(3) = weight->shape()[0] * weight->shape()[3]; // in_channel * out_channel
......
......@@ -42,13 +42,13 @@ int DepthToSpace::InferShape(std::vector<lite::tensor::Tensor *> inputs, std::ve
MS_ASSERT(this->primitive != nullptr);
if (outputs.size() != kDepthToSpaceOutputNum || inputs.size() != kDepthToSpaceInputNum) {
MS_LOG(ERROR) << "Invalid output/input size! output size: " << outputs.size() << ",input size: " << inputs.size();
return 1;
return RET_PARAM_INVALID;
}
auto input = inputs.at(0);
if (input->GetFormat() != schema::Format_NHWC) {
MS_LOG(ERROR) << "depth_to_space only support NHWC now!";
return 1;
return RET_FORMAT_ERR;
}
outputs[0]->set_data_type(input->data_type());
outputs[0]->SetFormat(input->GetFormat());
......@@ -58,14 +58,14 @@ int DepthToSpace::InferShape(std::vector<lite::tensor::Tensor *> inputs, std::ve
auto input_shape = input->shape();
if (input_shape.size() != kDimension_4d) {
MS_LOG(ERROR) << "input shape dimension size should == " << kDimension_4d;
return 1;
return RET_PARAM_INVALID;
}
int32_t block_size = GetBlockSize();
if (input_shape[NHWC_C] % (block_size * block_size) != 0 || input_shape[NHWC_C] == 0) {
MS_LOG(ERROR) << "input dimension c size " << input_shape[NHWC_C] << " should be mulitple of block_size("
<< block_size << ") * block_size)!";
return 1;
return RET_PARAM_INVALID;
}
std::vector<int32_t> output_shape(input_shape.size());
output_shape[NHWC_N] = input_shape[NHWC_N];
......
......@@ -47,8 +47,7 @@ int ExpandDims::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<te
if (!GetInferFlag()) {
return RET_OK;
}
auto expand_dims_prim = this->primitive->value_as_ExpandDims();
int dim = expand_dims_prim->dim();
int dim = GetDim();
if (dim < 0) {
dim += input->shape().size() + 1;
}
......
......@@ -58,10 +58,10 @@ int Gather::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tensor
if (!GetInferFlag()) {
return RET_OK;
}
auto gather_prim = this->primitive->value_as_Gather();
MS_ASSERT(gather_prim != nullptr);
int axis = gather_prim->axis();
int batch_dims = gather_prim->batchDims();
int axis = GetAxis();
int batch_dims = GetBatchDims();
if (axis < 0) {
axis += input->shape().size();
}
......
......@@ -58,18 +58,18 @@ int Lstm::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tensor::
MS_LOG(ERROR) << "OpLstm input dims should be 3.";
return RET_ERROR;
}
auto lstm_prim = this->primitive->value_as_Lstm();
int hidden_size = w_shape[1] / 4;
// set output
std::vector<int> out_shape(in_shape);
out_shape[2] = hidden_size;
if (lstm_prim->bidirection()) {
if (GetBidirection()) {
out_shape.insert(out_shape.begin() + 1, 2);
}
output->set_shape(out_shape);
// set hidden state, cell state
std::vector<int> state_shape(in_shape);
state_shape[0] = lstm_prim->bidirection() ? 2 : 1;
state_shape[0] = GetBidirection() ? 2 : 1;
state_shape[2] = hidden_size;
outputs_[1]->set_shape(state_shape);
outputs_[2]->set_shape(state_shape);
......
......@@ -62,11 +62,11 @@ int MatMul::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tensor
return RET_INPUT_TENSOR_ERROR;
}
}
auto matmul_prim = this->primitive->value_as_MatMul();
if (matmul_prim->transposeA()) {
if (GetTransposeA()) {
std::swap(a_shape[a_shape.size() - 1], a_shape[a_shape.size() - 2]);
}
if (matmul_prim->transposeB()) {
if (GetTransposeB()) {
std::swap(b_shape[b_shape.size() - 1], b_shape[b_shape.size() - 2]);
}
std::vector<int> c_shape(a_shape);
......
......@@ -58,12 +58,12 @@ int Mean::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tensor::
if (this->primitive == nullptr) {
return RET_NULL_PTR;
}
auto mean_prim = this->primitive->value_as_Mean();
bool keep_dims = static_cast<bool>(mean_prim->keepDims());
bool keep_dims = static_cast<bool>(GetKeepDims());
std::vector<int> in_shape = input->shape();
std::vector<int> out_shape;
const auto &axes = mean_prim->axis();
auto num_axes = axes->size();
const auto &axes = GetAxis();
auto num_axes = axes.size();
// reduce on all axes
if (num_axes == 0) {
if (keep_dims) {
......@@ -79,7 +79,7 @@ int Mean::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tensor::
for (size_t i = 0; i < in_shape.size(); i++) {
bool reduce_axis = false;
for (int idx = 0; idx < num_axes; ++idx) {
if (static_cast<size_t>((*axes)[idx]) == i) {
if (static_cast<size_t>(axes[idx]) == i) {
reduce_axis = true;
break;
}
......
......@@ -37,11 +37,8 @@ int OneHot::InferShape(std::vector<tensor::Tensor *> inputs, std::vector<tensor:
if (this->primitive == nullptr) {
return RET_NULL_PTR;
}
auto one_hot_prim = this->primitive->value_as_OneHot();
if (one_hot_prim == nullptr) {
return RET_NULL_PTR;
}
int axis = one_hot_prim->axis();
int axis = GetAxis();
// indices, depth, on_value, off_value
if (inputs.size() != kOneHotInputNum) {
MS_LOG(ERROR) << "OneHot got inputs num " << inputs.size() << ", should be " << kOneHotInputNum;
......
......@@ -49,14 +49,9 @@ int Pad::InferShape(std::vector<tensor::Tensor *> inputs, std::vector<tensor::Te
if (this->primitive == nullptr) {
return RET_NULL_PTR;
}
auto pad_prim = this->primitive->value_as_Pad();
if (pad_prim == nullptr) {
return RET_NULL_PTR;
}
auto paddings = pad_prim->paddings();
if (paddings == nullptr) {
return RET_NULL_PTR;
}
auto paddings = GetPaddings();
auto input = inputs.front();
if (input == nullptr) {
return RET_NULL_PTR;
......@@ -75,7 +70,7 @@ int Pad::InferShape(std::vector<tensor::Tensor *> inputs, std::vector<tensor::Te
MS_ASSERT(input->shape().size() <= kInputRank);
for (size_t i = 0; i < input_shape.size(); i++) {
auto paddings_index = i + kInputRank - input_shape.size();
auto shape = input_shape[i] + (*paddings)[2 * paddings_index] + (*paddings)[2 * paddings_index + 1];
auto shape = input_shape[i] + paddings[2 * paddings_index] + paddings[2 * paddings_index + 1];
output_shape.push_back(shape);
}
......
......@@ -97,7 +97,6 @@ constexpr int kPriorBoxW = 1;
constexpr int kPriorBoxC = 2;
} // namespace
int PriorBox::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tensor::Tensor *> outputs_) {
auto param = this->primitive->value_as_PriorBox();
MS_ASSERT(param != nullptr);
auto input = inputs_.at(0);
MS_ASSERT(input != nullptr);
......@@ -109,20 +108,20 @@ int PriorBox::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tens
return RET_OK;
}
std::vector<float> different_aspect_ratios{1.0f};
auto aspect_ratios = param->aspect_ratios();
auto aspect_ratios = GetAspectRatios();
MS_ASSERT(aspect_ratios != nullptr);
for (auto i = 0; i < aspect_ratios->size(); i++) {
float ratio = (*aspect_ratios)[i];
for (auto i = 0; i < aspect_ratios.size(); i++) {
float ratio = aspect_ratios[i];
bool exist = std::any_of(different_aspect_ratios.begin(), different_aspect_ratios.end(),
[&](float v) { return abs(ratio - v) < 1e-6; });
if (!exist) {
different_aspect_ratios.emplace_back(ratio);
if (param->flip()) {
if (GetFlip()) {
different_aspect_ratios.emplace_back(1.0f / ratio);
}
}
}
int32_t num_priors_box = param->min_sizes()->size() * different_aspect_ratios.size() + param->max_sizes()->size();
int32_t num_priors_box = GetMinSizes().size() * different_aspect_ratios.size() + GetMaxSizes().size();
int32_t h = input->Height() * input->Width() * num_priors_box * kPriorBoxPoints;
std::vector<int> output_shape{kPriorBoxN, h, kPriorBoxW, kPriorBoxC};
output->set_shape(output_shape);
......
......@@ -40,9 +40,8 @@ int QuantDTypeCast::InferShape(std::vector<tensor::Tensor *> inputs_, std::vecto
MS_ASSERT(input != nullptr);
auto output = outputs_.front();
MS_ASSERT(output != nullptr);
auto param = primitive->value_as_QuantDTypeCast();
MS_ASSERT(input->data_type() == param->srcT);
output->set_data_type(static_cast<TypeId>(param->dstT()));
output->set_data_type(static_cast<TypeId>(GetDstT()));
output->SetFormat(input->GetFormat());
if (!GetInferFlag()) {
return RET_OK;
......
......@@ -48,7 +48,7 @@ int Range::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tensor:
MS_ASSERT(input != nullptr);
auto output = outputs_.front();
MS_ASSERT(output != nullptr);
auto range_prim = this->primitive->value_as_Range();
MS_ASSERT(range_prim != nullptr);
output->set_data_type(input->data_type());
......@@ -57,7 +57,7 @@ int Range::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tensor:
return RET_OK;
}
int shape_size = std::ceil(static_cast<float>(range_prim->limit() - range_prim->start()) / range_prim->delta());
int shape_size = std::ceil(static_cast<float>(GetLimit() - GetStart()) / GetDelta());
std::vector<int> in_shape(1);
in_shape.push_back(shape_size);
output->set_shape(in_shape);
......
......@@ -62,12 +62,12 @@ int Reduce::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tensor
if (this->primitive == nullptr) {
return RET_NULL_PTR;
}
auto reduce_prim = this->primitive->value_as_Reduce();
bool keep_dims = static_cast<bool>(reduce_prim->keepDims());
bool keep_dims = static_cast<bool>(GetKeepDims());
std::vector<int> in_shape = input->shape();
std::vector<int> out_shape;
const auto &axes = reduce_prim->axes();
auto num_axes = axes->size();
const auto &axes = GetAxes();
auto num_axes = axes.size();
// reduce on all axes
if (num_axes == 0) {
if (keep_dims) {
......@@ -83,7 +83,7 @@ int Reduce::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tensor
for (size_t i = 0; i < in_shape.size(); i++) {
bool reduce_axis = false;
for (int idx = 0; idx < num_axes; ++idx) {
if (static_cast<size_t>((*axes)[idx]) == i || static_cast<size_t>((*axes)[idx] + in_shape.size()) == i) {
if (static_cast<size_t>(axes[idx]) == i || static_cast<size_t>(axes[idx] + in_shape.size()) == i) {
reduce_axis = true;
break;
}
......
......@@ -61,9 +61,9 @@ int ROIPooling::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<te
if (!GetInferFlag()) {
return RET_OK;
}
auto ROIPooling = this->primitive->value_as_ROIPooling();
auto new_h = ROIPooling->pooledH();
auto new_w = ROIPooling->pooledW();
auto new_h = GetPooledH();
auto new_w = GetPooledW();
auto shape_data = roi->shape();
std::vector<int> output_shape;
output_shape.push_back(shape_data[0]);
......
......@@ -55,12 +55,10 @@ int Squeeze::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tenso
}
auto in_shape = in_tensor->shape();
std::vector<int> out_shape;
// todo: getAxis
auto squeeze_prim = this->primitive->value_as_Squeeze();
MS_EXCEPTION_IF_NULL(squeeze_prim);
auto axis = squeeze_prim->axis();
auto axis = GetAxis();
std::vector<int> axes_;
for (auto iter = axis->begin(); iter != axis->end(); iter++) {
for (auto iter = axis.begin(); iter != axis.end(); iter++) {
axes_.push_back(*iter);
}
if (axes_.size() == 0) {
......
......@@ -62,11 +62,11 @@ int Stack::InferShape(std::vector<tensor::Tensor *> inputs, std::vector<tensor::
return RET_OK;
}
auto input_shape = input->shape();
auto stack_prim = this->primitive->value_as_Stack();
std::vector<int32_t> output_shape = input_shape;
int axis = stack_prim->axis() < 0 ? stack_prim->axis() + input_shape.size() : stack_prim->axis();
int axis = GetAxis() < 0 ? GetAxis() + input_shape.size() : GetAxis();
if (axis < 0 || axis > input_shape.size()) {
MS_LOG(ERROR) << "Invalid axis " << stack_prim->axis();
MS_LOG(ERROR) << "Invalid axis " << GetAxis();
return RET_PARAM_INVALID;
}
schema::Format input0_format = input->GetFormat();
......
......@@ -174,10 +174,6 @@ int StridedSlice::InferShape(std::vector<lite::tensor::Tensor *> inputs, std::ve
std::vector<int> output_shape;
ndim_ = static_cast<int>(GetBegin().size());
MS_ASSERT(ndim_ == static_cast<int>(strided_slice_prim->end()->size()));
MS_ASSERT(ndim_ == static_cast<int>(strided_slice_prim->stride()->size()));
MS_ASSERT(ndim_ == static_cast<int>(input_shape.size()));
for (int i = 0; i < ndim_; i++) {
in_shape_.emplace_back(input_shape.at(i));
begins_.emplace_back((GetBegin())[i]);
......
......@@ -53,10 +53,9 @@ int TopK::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tensor::
if (!GetInferFlag()) {
return RET_OK;
}
auto topk_prim = this->primitive->value_as_TopK();
MS_ASSERT(topk_prim != nullptr);
auto out_shape = input->shape();
out_shape[out_shape.size() - 1] = topk_prim->k();
out_shape[out_shape.size() - 1] = GetK();
output0->set_shape(out_shape);
output1->set_shape(out_shape);
return RET_OK;
......
......@@ -53,11 +53,11 @@ int Unsqueeze::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<ten
if (!GetInferFlag()) {
return RET_OK;
}
auto unsqueeze_prim = this->primitive->value_as_Unsqueeze();
auto dims = unsqueeze_prim->axis()->data();
auto dims = GetAxis().data();
auto in_shape = input->shape();
auto in_rank = in_shape.size();
auto dim_rank = unsqueeze_prim->axis()->size();
auto dim_rank = GetAxis().size();
std::vector<int> out_shape;
if (dim_rank == 0) {
for (auto d : in_shape) {
......
......@@ -38,10 +38,10 @@ int Unstack::InferShape(std::vector<tensor::Tensor *> inputs, std::vector<tensor
auto input = inputs.at(0);
MS_ASSERT(input != nullptr);
auto input_shape = input->shape();
auto prim = this->primitive->value_as_Unstack();
int axis = prim->axis() < 0 ? prim->axis() + input_shape.size() : prim->axis();
int axis = GetAxis() < 0 ? GetAxis() + input_shape.size() : GetAxis();
if (axis < 0 || axis >= input_shape.size()) {
MS_LOG(ERROR) << "Invalid axis " << prim->axis();
MS_LOG(ERROR) << "Invalid axis " << GetAxis();
return RET_PARAM_INVALID;
}
for (auto &out : outputs) {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册