提交 0fcf3774 编写于 作者: alinag's avatar alinag

Add reduce sum test of some cases (dimension 3)

test=develop
上级 16b8f9de
...@@ -119,6 +119,55 @@ void reduce_sum_w(const float* src, ...@@ -119,6 +119,55 @@ void reduce_sum_w(const float* src,
} }
} }
void reduce_sum_c(
const float* src, float* dst, int channel_in, int height_in, int width_in) {
int hw_size = height_in * width_in;
int data_index, src_index;
for (int h = 0; h < height_in; ++h) {
for (int w = 0; w < width_in; ++w) {
data_index = h * width_in + w;
dst[data_index] = 0.0;
for (int n = 0; n < channel_in; ++n) {
src_index = n * hw_size + data_index;
dst[data_index] += static_cast<float>(src[src_index]);
}
}
}
}
void reduce_sum_h(
const float* src, float* dst, int channel_in, int height_in, int width_in) {
int hw_size = height_in * width_in;
int data_index, src_index0, src_index;
for (int c = 0; c < channel_in; ++c) {
for (int w = 0; w < width_in; ++w) {
data_index = c * width_in + w;
src_index0 = c * hw_size + w;
dst[data_index] = 0.0;
for (int h = 0; h < height_in; ++h) {
src_index = src_index0 + h * width_in;
dst[data_index] += static_cast<float>(src[src_index]);
}
}
}
}
void reduce_sum_w(
const float* src, float* dst, int channel_in, int height_in, int width_in) {
int hw_size = height_in * width_in;
int data_index = 0;
int src_index0 = 0;
int src_index = 0;
for (int c = 0; c < channel_in; ++c) {
for (int h = 0; h < height_in; ++h) {
data_index = c * height_in + h;
src_index0 = c * hw_size + h * width_in;
dst[data_index] = 0.0;
for (int w = 0; w < width_in; ++w) {
src_index = src_index0 + w;
dst[data_index] += static_cast<float>(src[src_index]);
}
}
}
}
void reduce_sum_all(const float* src, void reduce_sum_all(const float* src,
float* dst, float* dst,
int num_in, int num_in,
...@@ -255,39 +304,60 @@ class ReduceSumComputeTester : public arena::TestCase { ...@@ -255,39 +304,60 @@ class ReduceSumComputeTester : public arena::TestCase {
out->Resize(DDim(out_dims)); out->Resize(DDim(out_dims));
auto* out_data = out->mutable_data<float>(); auto* out_data = out->mutable_data<float>();
int in_n = x_dims_[0]; if (x_dims_.size() == 4) {
int in_c = x_dims_[1]; int in_n = x_dims_[0];
int in_h = x_dims_[2]; int in_c = x_dims_[1];
int in_w = x_dims_[3]; int in_h = x_dims_[2];
int in_w = x_dims_[3];
if (reduce_all_) { if (reduce_all_) {
reduce_sum_all(x_data, out_data, in_n, in_c, in_h, in_w); reduce_sum_all(x_data, out_data, in_n, in_c, in_h, in_w);
} else if (dim_.size() == 1) { } else if (dim_.size() == 1) {
switch (dim_[0]) { switch (dim_[0]) {
case 0: case 0:
reduce_sum_n(x_data, out_data, in_n, in_c, in_h, in_w); reduce_sum_n(x_data, out_data, in_n, in_c, in_h, in_w);
break; break;
case 1: case 1:
reduce_sum_c(x_data, out_data, in_n, in_c, in_h, in_w); reduce_sum_c(x_data, out_data, in_n, in_c, in_h, in_w);
break; break;
case 2: case 2:
reduce_sum_h(x_data, out_data, in_n, in_c, in_h, in_w); reduce_sum_h(x_data, out_data, in_n, in_c, in_h, in_w);
break; break;
case 3: case 3:
reduce_sum_w(x_data, out_data, in_n, in_c, in_h, in_w); reduce_sum_w(x_data, out_data, in_n, in_c, in_h, in_w);
break; break;
default: default:
LOG(FATAL) << "error!!!"; LOG(FATAL) << "error!!!";
}
} else if (dim_.size() == 2) {
if (dim_[0] == 0 && dim_[1] == 1) {
reduce_sum_nc(x_data, out_data, in_n, in_c, in_h, in_w);
} else if (dim_[0] == 1 && dim_[1] == 2) {
reduce_sum_ch(x_data, out_data, in_n, in_c, in_h, in_w);
} else if (dim_[0] == 2 && dim_[1] == 3) {
reduce_sum_hw(x_data, out_data, in_n, in_c, in_h, in_w);
} else {
LOG(FATAL) << "invalid dims_!!";
}
} }
} else if (dim_.size() == 2) { } else {
if (dim_[0] == 0 && dim_[1] == 1) { int in_c = x_dims_[0];
reduce_sum_nc(x_data, out_data, in_n, in_c, in_h, in_w); int in_h = x_dims_[1];
} else if (dim_[0] == 1 && dim_[1] == 2) { int in_w = x_dims_[2];
reduce_sum_ch(x_data, out_data, in_n, in_c, in_h, in_w); if (dim_.size() == 1 && !reduce_all_) {
} else if (dim_[0] == 2 && dim_[1] == 3) { switch (dim_[0]) {
reduce_sum_hw(x_data, out_data, in_n, in_c, in_h, in_w); case 0:
} else { reduce_sum_c(x_data, out_data, in_c, in_h, in_w);
LOG(FATAL) << "invalid dims_!!"; break;
case 1:
reduce_sum_h(x_data, out_data, in_c, in_h, in_w);
break;
case 2:
reduce_sum_w(x_data, out_data, in_c, in_h, in_w);
break;
default:
LOG(FATAL) << "error!!!";
}
} }
} }
} }
...@@ -333,6 +403,20 @@ void test_reduce_sum(Place place) { ...@@ -333,6 +403,20 @@ void test_reduce_sum(Place place) {
} }
} }
} }
std::vector<std::vector<int>> reduce_dimm{{0}, {1}, {2}};
for (auto dim : reduce_dimm) {
for (auto c : {1, 3}) {
for (auto h : {1, 3}) {
for (auto w : {1, 4}) {
auto x_dims = DDim(std::vector<int64_t>({c, h, w}));
std::unique_ptr<arena::TestCase> tester(new ReduceSumComputeTester(
place, "def", dim, false, false, x_dims));
arena::Arena arena(std::move(tester), place, 2e-5);
arena.TestPrecision();
}
}
}
}
} }
TEST(ReduceSum, precision) { TEST(ReduceSum, precision) {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册