未验证 提交 e7d43b02 编写于 作者: Q Qinghe JING 提交者: GitHub

add support for 3-dimensional input in reduce max (#4054)

* add support for three-dimensional input in reduce max test=develop
上级 ab7b2855
...@@ -46,6 +46,68 @@ void reduce_n<float>(const float* src, ...@@ -46,6 +46,68 @@ void reduce_n<float>(const float* src,
} }
} }
// Max-reduction over the first axis of a contiguous
// [first_in, second_in, third_in] float tensor. dst receives
// second_in * third_in values, indexed as [row * third_in + col].
template <>
void reduce_first_of_three<float>(
    const float* src, float* dst, int first_in, int second_in, int third_in) {
  for (int row = 0; row < second_in; ++row) {
    for (int col = 0; col < third_in; ++col) {
      const int out_idx = row * third_in + col;
      // Seed with the element from the first slice, then scan the rest.
      float best = src[out_idx];
      for (int slice = 1; slice < first_in; ++slice) {
        const float candidate = src[slice * second_in * third_in + out_idx];
        if (candidate > best) {
          best = candidate;
        }
      }
      dst[out_idx] = best;
    }
  }
}
// Max-reduction over the second axis of a contiguous
// [first_in, second_in, third_in] float tensor. dst receives
// first_in * third_in values, indexed as [outer * third_in + col].
template <>
void reduce_second_of_three<float>(
    const float* src, float* dst, int first_in, int second_in, int third_in) {
  for (int outer = 0; outer < first_in; ++outer) {
    for (int col = 0; col < third_in; ++col) {
      // Offset of element (outer, 0, col) in the input.
      const int base = outer * second_in * third_in + col;
      float best = src[base];
      for (int mid = 1; mid < second_in; ++mid) {
        const float candidate = src[base + third_in * mid];
        if (candidate > best) {
          best = candidate;
        }
      }
      dst[outer * third_in + col] = best;
    }
  }
}
// Max-reduction over the third (innermost) axis of a contiguous
// [first_in, second_in, third_in] float tensor. dst receives
// first_in * second_in values, indexed as [i * second_in + j].
//
// Element (i, j, k) lives at (i * second_in + j) * third_in + k, so each
// (i, j) row starts at a multiple of third_in. The previous code stepped rows
// by second_in instead, which read the wrong elements whenever
// second_in != third_in and ran past the buffer when second_in > third_in.
template <>
void reduce_third_of_three<float>(
    const float* src, float* dst, int first_in, int second_in, int third_in) {
  for (int i = 0; i < first_in; i++) {
    for (int j = 0; j < second_in; j++) {
      const int out_idx = i * second_in + j;
      const float* row = src + out_idx * third_in;
      // Seed with the first element of the row, then scan the remainder.
      float best = row[0];
      for (int k = 1; k < third_in; k++) {
        if (row[k] > best) {
          best = row[k];
        }
      }
      dst[out_idx] = best;
    }
  }
}
// Max-reduction over every element of a contiguous
// [first_in, second_in, third_in] float tensor; the single result is written
// to dst[0].
template <>
void reduce_all_of_three<float>(
    const float* src, float* dst, int first_in, int second_in, int third_in) {
  const int count = first_in * second_in * third_in;
  // src[0] seeds the running maximum, so scanning can begin at index 1.
  float best = src[0];
  for (int idx = 1; idx < count; ++idx) {
    if (src[idx] > best) {
      best = src[idx];
    }
  }
  dst[0] = best;
}
template <> template <>
void reduce_c<float>(const float* src, void reduce_c<float>(const float* src,
float* dst, float* dst,
......
...@@ -35,6 +35,22 @@ void reduce_c(const T* src, ...@@ -35,6 +35,22 @@ void reduce_c(const T* src,
int height_in, int height_in,
int width_in); int width_in);
// Reduction helpers for rank-3 inputs stored contiguously with extents
// [first_in, second_in, third_in]. These are declarations only; the float
// specializations (max reductions) are provided separately.

// Reduces over all three axes; the result is a single value in dst[0].
template <typename T>
void reduce_all_of_three(
const T* src, T* dst, int first_in, int second_in, int third_in);
// Reduces over the first axis; dst holds second_in * third_in values.
template <typename T>
void reduce_first_of_three(
const T* src, T* dst, int first_in, int second_in, int third_in);
// Reduces over the second axis; dst holds first_in * third_in values.
template <typename T>
void reduce_second_of_three(
const T* src, T* dst, int first_in, int second_in, int third_in);
// Reduces over the third axis; dst holds first_in * second_in values.
template <typename T>
void reduce_third_of_three(
const T* src, T* dst, int first_in, int second_in, int third_in);
template <typename T> template <typename T>
void reduce_h(const T* src, void reduce_h(const T* src,
T* dst, T* dst,
......
...@@ -25,6 +25,7 @@ void ReduceMaxCompute::Run() { ...@@ -25,6 +25,7 @@ void ReduceMaxCompute::Run() {
auto& param = Param<operators::ReduceMaxParam>(); auto& param = Param<operators::ReduceMaxParam>();
const float* input = param.X->data<float>(); const float* input = param.X->data<float>();
auto x_dims = param.X->dims(); auto x_dims = param.X->dims();
int x_rank = x_dims.size(); int x_rank = x_dims.size();
float* output = param.Out->mutable_data<float>(); float* output = param.Out->mutable_data<float>();
bool keep_dim = param.keep_dim; bool keep_dim = param.keep_dim;
...@@ -37,41 +38,74 @@ void ReduceMaxCompute::Run() { ...@@ -37,41 +38,74 @@ void ReduceMaxCompute::Run() {
} }
} }
} }
int n_in = x_dims[0];
int c_in = x_dims[1]; if (x_dims.size() == 3) {
int h_in = x_dims[2]; if (dim.size() == 0 || dim.size() == 3) {
int w_in = x_dims[3]; lite::arm::math::reduce_all_of_three(
if (dim.size() == 0) { input, output, x_dims[0], x_dims[1], x_dims[2]);
lite::arm::math::reduce_all(input, output, n_in, c_in, h_in, w_in); } else if (dim.size() == 1) {
} else if (dim.size() == 1) { switch (dim[0]) {
switch (dim[0]) { case 0:
case 0: lite::arm::math::reduce_first_of_three(
lite::arm::math::reduce_n(input, output, n_in, c_in, h_in, w_in); input, output, x_dims[0], x_dims[1], x_dims[2]);
break; break;
case 1: case 1:
lite::arm::math::reduce_c(input, output, n_in, c_in, h_in, w_in); lite::arm::math::reduce_second_of_three(
break; input, output, x_dims[0], x_dims[1], x_dims[2]);
case 2: break;
lite::arm::math::reduce_h(input, output, n_in, c_in, h_in, w_in);
break; case 2:
case 3: lite::arm::math::reduce_third_of_three(
lite::arm::math::reduce_w(input, output, n_in, c_in, h_in, w_in); input, output, x_dims[0], x_dims[1], x_dims[2]);
break; break;
default: default:
LOG(FATAL) << "error!!!"; LOG(FATAL) << "error!!!";
}
} else if (dim.size() == 2) {
LOG(FATAL) << "Will support later!!";
} else {
LOG(FATAL) << "dim size should not larger than 3!!!";
} }
} else if (dim.size() == 2) { } else if (x_dims.size() == 4) {
if (dim[0] == 0 && dim[1] == 1) { int n_in = x_dims[0];
lite::arm::math::reduce_nc(input, output, n_in, c_in, h_in, w_in); int c_in = x_dims[1];
} else if (dim[0] == 1 && dim[1] == 2) { int h_in = x_dims[2];
lite::arm::math::reduce_ch(input, output, n_in, c_in, h_in, w_in); int w_in = x_dims[3];
} else if (dim[0] == 2 && dim[1] == 3) {
lite::arm::math::reduce_hw(input, output, n_in, c_in, h_in, w_in); if (dim.size() == 0) {
lite::arm::math::reduce_all(input, output, n_in, c_in, h_in, w_in);
} else if (dim.size() == 1) {
switch (dim[0]) {
case 0:
lite::arm::math::reduce_n(input, output, n_in, c_in, h_in, w_in);
break;
case 1:
lite::arm::math::reduce_c(input, output, n_in, c_in, h_in, w_in);
break;
case 2:
lite::arm::math::reduce_h(input, output, n_in, c_in, h_in, w_in);
break;
case 3:
lite::arm::math::reduce_w(input, output, n_in, c_in, h_in, w_in);
break;
default:
LOG(FATAL) << "error!!!";
}
} else if (dim.size() == 2) {
if (dim[0] == 0 && dim[1] == 1) {
lite::arm::math::reduce_nc(input, output, n_in, c_in, h_in, w_in);
} else if (dim[0] == 1 && dim[1] == 2) {
lite::arm::math::reduce_ch(input, output, n_in, c_in, h_in, w_in);
} else if (dim[0] == 2 && dim[1] == 3) {
lite::arm::math::reduce_hw(input, output, n_in, c_in, h_in, w_in);
} else {
LOG(FATAL) << "invalid dim!!";
}
} else { } else {
LOG(FATAL) << "invalid dim!!"; LOG(FATAL) << "dim's size over than 2, which is not supported now!!";
} }
} else { } else {
LOG(FATAL) << "dim's size over than 2, which is not supported now!!"; LOG(FATAL) << "only support input with 3&4 dimensions now!!";
} }
} }
......
...@@ -190,6 +190,64 @@ void reduce_hw(const float* src, ...@@ -190,6 +190,64 @@ void reduce_hw(const float* src,
reduce_w(tmp_out, dst, num_in, channel_in, 1, width_in); reduce_w(tmp_out, dst, num_in, channel_in, 1, width_in);
} }
// Reference max-reduction over the first axis of a contiguous
// [first_in, second_in, third_in] float tensor. dst receives
// second_in * third_in values, indexed as [row * third_in + col].
void reduce_first_of_three(
    const float* src, float* dst, int first_in, int second_in, int third_in) {
  for (int row = 0; row < second_in; ++row) {
    for (int col = 0; col < third_in; ++col) {
      const int out_idx = row * third_in + col;
      // Seed with the element from the first slice, then scan the rest.
      float best = src[out_idx];
      for (int slice = 1; slice < first_in; ++slice) {
        const float candidate = src[slice * second_in * third_in + out_idx];
        if (candidate > best) {
          best = candidate;
        }
      }
      dst[out_idx] = best;
    }
  }
}
// Reference max-reduction over the second axis of a contiguous
// [first_in, second_in, third_in] float tensor. dst receives
// first_in * third_in values, indexed as [outer * third_in + col].
void reduce_second_of_three(
    const float* src, float* dst, int first_in, int second_in, int third_in) {
  for (int outer = 0; outer < first_in; ++outer) {
    for (int col = 0; col < third_in; ++col) {
      // Offset of element (outer, 0, col) in the input.
      const int base = outer * second_in * third_in + col;
      float best = src[base];
      for (int mid = 1; mid < second_in; ++mid) {
        const float candidate = src[base + third_in * mid];
        if (candidate > best) {
          best = candidate;
        }
      }
      dst[outer * third_in + col] = best;
    }
  }
}
// Reference max-reduction over the third (innermost) axis of a contiguous
// [first_in, second_in, third_in] float tensor. dst receives
// first_in * second_in values, indexed as [i * second_in + j].
//
// Element (i, j, k) lives at (i * second_in + j) * third_in + k, so each
// (i, j) row starts at a multiple of third_in. The previous code stepped rows
// by second_in instead, which read the wrong elements whenever
// second_in != third_in and ran past the buffer when second_in > third_in.
void reduce_third_of_three(
    const float* src, float* dst, int first_in, int second_in, int third_in) {
  for (int i = 0; i < first_in; i++) {
    for (int j = 0; j < second_in; j++) {
      const int out_idx = i * second_in + j;
      const float* row = src + out_idx * third_in;
      // Seed with the first element of the row, then scan the remainder.
      float best = row[0];
      for (int k = 1; k < third_in; k++) {
        if (row[k] > best) {
          best = row[k];
        }
      }
      dst[out_idx] = best;
    }
  }
}
// Reference max-reduction over every element of a contiguous
// [first_in, second_in, third_in] float tensor; the single result is written
// to dst[0].
void reduce_all_of_three(
    const float* src, float* dst, int first_in, int second_in, int third_in) {
  const int count = first_in * second_in * third_in;
  // src[0] seeds the running maximum, so scanning can begin at index 1.
  float best = src[0];
  for (int idx = 1; idx < count; ++idx) {
    if (src[idx] > best) {
      best = src[idx];
    }
  }
  dst[0] = best;
}
class ReduceMaxComputeTester : public arena::TestCase { class ReduceMaxComputeTester : public arena::TestCase {
protected: protected:
// common attributes for this op. // common attributes for this op.
...@@ -256,39 +314,69 @@ class ReduceMaxComputeTester : public arena::TestCase { ...@@ -256,39 +314,69 @@ class ReduceMaxComputeTester : public arena::TestCase {
} }
auto* out_data = out->mutable_data<float>(); auto* out_data = out->mutable_data<float>();
int in_n = x_dims_[0];
int in_c = x_dims_[1];
int in_h = x_dims_[2];
int in_w = x_dims_[3];
if (dim_.size() == 0) { if (x_dims_.size() == 3) {
reduce_all(x_data, out_data, in_n, in_c, in_h, in_w); if (dim_.size() == 0 || dim_.size() == 3) {
} else if (dim_.size() == 1) { reduce_all_of_three(
switch (dim_[0]) { x_data, out_data, x_dims_[0], x_dims_[1], x_dims_[2]);
case 0: } else if (dim_.size() == 1) {
reduce_n(x_data, out_data, in_n, in_c, in_h, in_w); switch (dim_[0]) {
break; case 0:
case 1: reduce_first_of_three(
reduce_c(x_data, out_data, in_n, in_c, in_h, in_w); x_data, out_data, x_dims_[0], x_dims_[1], x_dims_[2]);
break; break;
case 2: case 1:
reduce_h(x_data, out_data, in_n, in_c, in_h, in_w); reduce_second_of_three(
break; x_data, out_data, x_dims_[0], x_dims_[1], x_dims_[2]);
case 3: break;
reduce_w(x_data, out_data, in_n, in_c, in_h, in_w);
break; case 2:
default: reduce_third_of_three(
LOG(FATAL) << "error!!!"; x_data, out_data, x_dims_[0], x_dims_[1], x_dims_[2]);
} break;
} else if (dim_.size() == 2) { default:
if (dim_[0] == 0 && dim_[1] == 1) { LOG(FATAL) << "error!!!";
reduce_nc(x_data, out_data, in_n, in_c, in_h, in_w); }
} else if (dim_[0] == 1 && dim_[1] == 2) { } else if (dim_.size() == 2) {
reduce_ch(x_data, out_data, in_n, in_c, in_h, in_w);
} else if (dim_[0] == 2 && dim_[1] == 3) {
reduce_hw(x_data, out_data, in_n, in_c, in_h, in_w);
} else {
LOG(FATAL) << "invalid dims_!!"; LOG(FATAL) << "invalid dims_!!";
} else {
LOG(FATAL) << "dim size should not larger than 3!!!";
}
} else if (x_dims_.size() == 4) {
int in_n = x_dims_[0];
int in_c = x_dims_[1];
int in_h = x_dims_[2];
int in_w = x_dims_[3];
if (dim_.size() == 0) {
reduce_all(x_data, out_data, in_n, in_c, in_h, in_w);
} else if (dim_.size() == 1) {
switch (dim_[0]) {
case 0:
reduce_n(x_data, out_data, in_n, in_c, in_h, in_w);
break;
case 1:
reduce_c(x_data, out_data, in_n, in_c, in_h, in_w);
break;
case 2:
reduce_h(x_data, out_data, in_n, in_c, in_h, in_w);
break;
case 3:
reduce_w(x_data, out_data, in_n, in_c, in_h, in_w);
break;
default:
LOG(FATAL) << "error!!!";
}
} else if (dim_.size() == 2) {
if (dim_[0] == 0 && dim_[1] == 1) {
reduce_nc(x_data, out_data, in_n, in_c, in_h, in_w);
} else if (dim_[0] == 1 && dim_[1] == 2) {
reduce_ch(x_data, out_data, in_n, in_c, in_h, in_w);
} else if (dim_[0] == 2 && dim_[1] == 3) {
reduce_hw(x_data, out_data, in_n, in_c, in_h, in_w);
} else {
LOG(FATAL) << "invalid dims_!!";
}
} }
} }
} }
...@@ -333,6 +421,19 @@ void test_reduce_max(Place place) { ...@@ -333,6 +421,19 @@ void test_reduce_max(Place place) {
} }
} }
// Runs the reduce_max precision check on a 3 x 4 x 5 input, reducing over
// each single axis in turn, once without and once with keep_dim.
void test_reduce_max_for_three(Place place) {
  const std::vector<std::vector<int>> axes_to_reduce{{0}, {1}, {2}};
  const bool keep_dim_options[] = {false, true};
  for (bool keep_dim : keep_dim_options) {
    for (std::vector<int> dim : axes_to_reduce) {
      auto x_dims = DDim(std::vector<int64_t>({3, 4, 5}));
      std::unique_ptr<arena::TestCase> tester(
          new ReduceMaxComputeTester(place, "def", dim, keep_dim, x_dims));
      // Compare kernel output with the reference at an absolute tolerance
      // of 2e-5.
      arena::Arena checker(std::move(tester), place, 2e-5);
      checker.TestPrecision();
    }
  }
}
TEST(ReduceMax, precision) { TEST(ReduceMax, precision) {
// #ifdef LITE_WITH_X86 // #ifdef LITE_WITH_X86
// Place place(TARGET(kX86)); // Place place(TARGET(kX86));
...@@ -340,6 +441,7 @@ TEST(ReduceMax, precision) { ...@@ -340,6 +441,7 @@ TEST(ReduceMax, precision) {
#ifdef LITE_WITH_ARM #ifdef LITE_WITH_ARM
Place place(TARGET(kARM)); Place place(TARGET(kARM));
test_reduce_max(place); test_reduce_max(place);
test_reduce_max_for_three(place);
#endif #endif
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册