提交 4be0678b 编写于 作者: alinag's avatar alinag

refine methods of reduce sum

test=develop
上级 0fcf3774
......@@ -119,55 +119,32 @@ void reduce_sum_w(const float* src,
}
}
void reduce_sum_c(
const float* src, float* dst, int channel_in, int height_in, int width_in) {
int hw_size = height_in * width_in;
int data_index, src_index;
for (int h = 0; h < height_in; ++h) {
for (int w = 0; w < width_in; ++w) {
data_index = h * width_in + w;
dst[data_index] = 0.0;
for (int n = 0; n < channel_in; ++n) {
src_index = n * hw_size + data_index;
dst[data_index] += static_cast<float>(src[src_index]);
void reduce_sum(const float* src, float* dst, const DDim& x_dim, int dims) {
int reduce_b[3] = {1, 1, 1};
reduce_b[dims] = 0;
DDim reduce_dim{x_dim};
reduce_dim[dims] = 1;
int dim_size = 1;
for (int i = 0; i < 3; i++) {
if (i != dims) {
dim_size *= x_dim[i];
}
}
for (int i = 0; i < dim_size; i++) {
dst[i] = 0.0;
}
}
void reduce_sum_h(
const float* src, float* dst, int channel_in, int height_in, int width_in) {
int hw_size = height_in * width_in;
int data_index, src_index0, src_index;
for (int c = 0; c < channel_in; ++c) {
for (int w = 0; w < width_in; ++w) {
data_index = c * width_in + w;
src_index0 = c * hw_size + w;
dst[data_index] = 0.0;
for (int h = 0; h < height_in; ++h) {
src_index = src_index0 + h * width_in;
dst[data_index] += static_cast<float>(src[src_index]);
}
}
}
}
void reduce_sum_w(
const float* src, float* dst, int channel_in, int height_in, int width_in) {
int hw_size = height_in * width_in;
int data_index = 0;
int src_index0 = 0;
int src_index = 0;
for (int c = 0; c < channel_in; ++c) {
for (int h = 0; h < height_in; ++h) {
data_index = c * height_in + h;
src_index0 = c * hw_size + h * width_in;
dst[data_index] = 0.0;
for (int w = 0; w < width_in; ++w) {
src_index = src_index0 + w;
dst[data_index] += static_cast<float>(src[src_index]);
for (int i = 0; i < x_dim[0]; i++) {
for (int j = 0; j < x_dim[1]; j++) {
for (int k = 0; k < x_dim[2]; k++) {
int src_index = i * x_dim[1] * x_dim[2] + j * x_dim[2] + k;
int dst_index = i * reduce_dim[1] * reduce_dim[2] * reduce_b[0] +
j * reduce_dim[2] * reduce_b[1] + k * reduce_b[2];
dst[dst_index] += static_cast<float>(src[src_index]);
}
}
}
}
void reduce_sum_all(const float* src,
float* dst,
int num_in,
......@@ -341,21 +318,10 @@ class ReduceSumComputeTester : public arena::TestCase {
}
}
} else {
int in_c = x_dims_[0];
int in_h = x_dims_[1];
int in_w = x_dims_[2];
if (dim_.size() == 1 && !reduce_all_) {
switch (dim_[0]) {
case 0:
reduce_sum_c(x_data, out_data, in_c, in_h, in_w);
break;
case 1:
reduce_sum_h(x_data, out_data, in_c, in_h, in_w);
break;
case 2:
reduce_sum_w(x_data, out_data, in_c, in_h, in_w);
break;
default:
if (dim_[0] < x_dims_.size()) {
reduce_sum(x_data, out_data, x_dims_, dim_[0]);
} else {
LOG(FATAL) << "error!!!";
}
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册