提交 756b8346 编写于 作者: M mindspore-ci-bot 提交者: Gitee

!4263 [MS][LITE][Develop]argmax,argmin support keepdim

Merge pull request !4263 from chenjianping/lite_dev
......@@ -38,9 +38,9 @@ int ArgMax::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tensor
MS_LOG(ERROR) << "Invalid axis " << argmax_prim->axis() << ", input shape size: " << input_shape_size;
return RET_PARAM_INVALID;
}
if (argmax_prim->topK() == 1) {
if (argmax_prim->topK() == 1 && !argmax_prim->keepDims()) {
output_shape.erase(output_shape.begin() + axis);
} else if (argmax_prim->axisType() == 1) {
} else {
output_shape[axis] = argmax_prim->topK();
}
......
......@@ -37,9 +37,9 @@ int ArgMin::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tensor
return RET_PARAM_INVALID;
}
std::vector<int> output_shape(input->shape());
if (argmin_prim->topK() == 1) {
if (argmin_prim->topK() == 1 && !argmin_prim->keepDims()) {
output_shape.erase(output_shape.begin() + axis);
} else if (argmin_prim->axisType() == 1) {
} else {
output_shape[axis] = argmin_prim->topK();
}
......
......@@ -17,6 +17,7 @@
#include "src/runtime/kernel/arm/nnacl/arg_min_max.h"
#include "src/runtime/kernel/arm/fp32/argminmax.h"
#include "src/runtime/kernel/arm/int8/argminmax_int8.h"
#include "src/runtime/kernel/arm/nnacl/arithmetic_common.h"
#include "schema/model_generated.h"
#include "src/kernel_factory.h"
#include "include/errorcode.h"
......@@ -60,7 +61,7 @@ int ArgMinMaxBaseCPUKernel::ReSize() {
return RET_PARAM_INVALID;
}
param->topk_ = MSMIN(param->topk_, in_shape[axis]);
if (param->topk_ > 1) {
if (param->topk_ > 1 || param->keep_dims_) {
if (context_ != nullptr && context_->allocator != nullptr) {
param->arg_elements_ =
reinterpret_cast<ArgElement *>(context_->allocator->Malloc(sizeof(ArgElement) * in_shape[axis]));
......@@ -73,6 +74,9 @@ int ArgMinMaxBaseCPUKernel::ReSize() {
return RET_ERROR;
}
}
ComputeStrides(in_shape.data(), param->in_strides_, in_shape.size());
auto out_shape = outputs_.at(0)->shape();
ComputeStrides(out_shape.data(), param->out_strides_, out_shape.size());
return RET_OK;
}
......
......@@ -89,7 +89,7 @@ void ArgMinMaxTopknFp32(const float *input, float *output, const int *in_shape,
}
void ArgMinMax(const void *input, void *output, const int *in_shape, ArgMinMaxParameter *param) {
if (param->topk_ == 1) {
if (param->topk_ == 1 && !param->keep_dims_) {
ArgMinMaxTopk1(input, output, in_shape, param);
return;
}
......
......@@ -40,6 +40,34 @@ TEST_F(TestArgMinMaxTestFp32, ArgMaxTest1) {
param.data_type_ = 43;
param.dims_size_ = 2;
param.get_max_ = true;
param.keep_dims_ = false;
ArgMinMax(in.data(), out, shape.data(), &param);
for (size_t i = 0; i < except_out.size(); ++i) {
std::cout << out[i] << " ";
}
std::cout << "\n";
CompareOutputData(out, except_out.data(), except_out.size(), 0.000001);
}
TEST_F(TestArgMinMaxTestFp32, ArgMaxTest1_keep_dim) {
std::vector<float> in = {10, 20, 30, 40, 90,
20, 11, 15, 1, 50,
30, 45, 25, 50, 30};
std::vector<float> except_out = {2, 2, 0, 2, 0};
std::vector<int> shape = {3, 5};
float out[5];
ArgMinMaxParameter param;
param.topk_ = 1;
param.out_value_ = false;
param.axis_ = 0;
param.data_type_ = 43;
param.dims_size_ = 2;
param.get_max_ = true;
param.keep_dims_ = true;
param.arg_elements_ = reinterpret_cast<ArgElement *>(malloc(shape[param.axis_] * sizeof(ArgElement)));
std::vector<int> out_shape = {1, 5};
ComputeStrides(shape.data(), param.in_strides_, shape.size());
ComputeStrides(out_shape.data(), param.out_strides_, out_shape.size());
ArgMinMax(in.data(), out, shape.data(), &param);
for (size_t i = 0; i < except_out.size(); ++i) {
std::cout << out[i] << " ";
......@@ -62,6 +90,7 @@ TEST_F(TestArgMinMaxTestFp32, ArgMaxTest2) {
param.data_type_ = 43;
param.dims_size_ = 2;
param.get_max_ = true;
param.keep_dims_ = false;
ArgMinMax(in.data(), out, shape.data(), &param);
CompareOutputData(out, except_out.data(), except_out.size(), 0.000001);
}
......@@ -80,6 +109,7 @@ TEST_F(TestArgMinMaxTestFp32, ArgMinTest2) {
param.data_type_ = 43;
param.dims_size_ = 2;
param.get_max_ = false;
param.keep_dims_ = false;
ArgMinMax(in.data(), out, shape.data(), &param);
CompareOutputData(out, except_out.data(), except_out.size(), 0.000001);
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册