提交 4be0678b，作者: alinag

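Refine reduce_sum methods: replace the per-axis helpers reduce_sum_c / reduce_sum_h / reduce_sum_w with a single reduce_sum(src, dst, x_dim, dims)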

test=develop
上级 0fcf3774
...@@ -119,55 +119,32 @@ void reduce_sum_w(const float* src, ...@@ -119,55 +119,32 @@ void reduce_sum_w(const float* src,
} }
} }
void reduce_sum_c( void reduce_sum(const float* src, float* dst, const DDim& x_dim, int dims) {
const float* src, float* dst, int channel_in, int height_in, int width_in) { int reduce_b[3] = {1, 1, 1};
int hw_size = height_in * width_in; reduce_b[dims] = 0;
int data_index, src_index; DDim reduce_dim{x_dim};
for (int h = 0; h < height_in; ++h) { reduce_dim[dims] = 1;
for (int w = 0; w < width_in; ++w) { int dim_size = 1;
data_index = h * width_in + w; for (int i = 0; i < 3; i++) {
dst[data_index] = 0.0; if (i != dims) {
for (int n = 0; n < channel_in; ++n) { dim_size *= x_dim[i];
src_index = n * hw_size + data_index;
dst[data_index] += static_cast<float>(src[src_index]);
}
} }
} }
} for (int i = 0; i < dim_size; i++) {
void reduce_sum_h( dst[i] = 0.0;
const float* src, float* dst, int channel_in, int height_in, int width_in) {
int hw_size = height_in * width_in;
int data_index, src_index0, src_index;
for (int c = 0; c < channel_in; ++c) {
for (int w = 0; w < width_in; ++w) {
data_index = c * width_in + w;
src_index0 = c * hw_size + w;
dst[data_index] = 0.0;
for (int h = 0; h < height_in; ++h) {
src_index = src_index0 + h * width_in;
dst[data_index] += static_cast<float>(src[src_index]);
}
}
} }
} for (int i = 0; i < x_dim[0]; i++) {
void reduce_sum_w( for (int j = 0; j < x_dim[1]; j++) {
const float* src, float* dst, int channel_in, int height_in, int width_in) { for (int k = 0; k < x_dim[2]; k++) {
int hw_size = height_in * width_in; int src_index = i * x_dim[1] * x_dim[2] + j * x_dim[2] + k;
int data_index = 0; int dst_index = i * reduce_dim[1] * reduce_dim[2] * reduce_b[0] +
int src_index0 = 0; j * reduce_dim[2] * reduce_b[1] + k * reduce_b[2];
int src_index = 0; dst[dst_index] += static_cast<float>(src[src_index]);
for (int c = 0; c < channel_in; ++c) {
for (int h = 0; h < height_in; ++h) {
data_index = c * height_in + h;
src_index0 = c * hw_size + h * width_in;
dst[data_index] = 0.0;
for (int w = 0; w < width_in; ++w) {
src_index = src_index0 + w;
dst[data_index] += static_cast<float>(src[src_index]);
} }
} }
} }
} }
void reduce_sum_all(const float* src, void reduce_sum_all(const float* src,
float* dst, float* dst,
int num_in, int num_in,
...@@ -341,22 +318,11 @@ class ReduceSumComputeTester : public arena::TestCase { ...@@ -341,22 +318,11 @@ class ReduceSumComputeTester : public arena::TestCase {
} }
} }
} else { } else {
int in_c = x_dims_[0];
int in_h = x_dims_[1];
int in_w = x_dims_[2];
if (dim_.size() == 1 && !reduce_all_) { if (dim_.size() == 1 && !reduce_all_) {
switch (dim_[0]) { if (dim_[0] < x_dims_.size()) {
case 0: reduce_sum(x_data, out_data, x_dims_, dim_[0]);
reduce_sum_c(x_data, out_data, in_c, in_h, in_w); } else {
break; LOG(FATAL) << "error!!!";
case 1:
reduce_sum_h(x_data, out_data, in_c, in_h, in_w);
break;
case 2:
reduce_sum_w(x_data, out_data, in_c, in_h, in_w);
break;
default:
LOG(FATAL) << "error!!!";
} }
} }
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册