diff --git a/paddle/fluid/lite/arm/math/elementwise.cc b/paddle/fluid/lite/arm/math/elementwise.cc
index 68140a5d7dbccc9fa0028e9cde3e9d074275a7ee..cf300616245c452fddbf89abfcb346188539edcd 100644
--- a/paddle/fluid/lite/arm/math/elementwise.cc
+++ b/paddle/fluid/lite/arm/math/elementwise.cc
@@ -41,15 +41,15 @@ void elementwise_add(const float* dinx, const float* diny, float* dout,
     float32x4_t diny2 = vld1q_f32(diny_ptr + 8);
     float32x4_t diny3 = vld1q_f32(diny_ptr + 12);
 
-    float32x4_t vsum0 = vaddq_f32(dinx0, diny0);
-    float32x4_t vsum1 = vaddq_f32(dinx1, diny1);
-    float32x4_t vsum2 = vaddq_f32(dinx2, diny2);
-    float32x4_t vsum3 = vaddq_f32(dinx3, diny3);
+    dinx0 = vaddq_f32(dinx0, diny0);
+    dinx1 = vaddq_f32(dinx1, diny1);
+    dinx2 = vaddq_f32(dinx2, diny2);
+    dinx3 = vaddq_f32(dinx3, diny3);
 
-    vst1q_f32(dout_ptr, vsum0);
-    vst1q_f32(dout_ptr + 4, vsum1);
-    vst1q_f32(dout_ptr + 8, vsum2);
-    vst1q_f32(dout_ptr + 12, vsum3);
+    vst1q_f32(dout_ptr, dinx0);
+    vst1q_f32(dout_ptr + 4, dinx1);
+    vst1q_f32(dout_ptr + 8, dinx2);
+    vst1q_f32(dout_ptr + 12, dinx3);
   }
   if (remain > 0) {
     const float* dinx_ptr = dinx + (cnt << 4);
@@ -64,6 +64,74 @@ void elementwise_add(const float* dinx, const float* diny, float* dout,
   }
 }
 
+// Adds one value of diny per channel to dinx and writes the sums to dout.
+//
+// dinx and dout are addressed as [batch][channels][num] in one contiguous
+// buffer; every element of slice (i, j) is increased by diny[j].
+// The float specialization processes 16, then 8, then 4 floats at a time
+// with NEON and finishes the remainder with scalar adds.
+template <>
+void elementwise_add_axis<float>(const float* dinx, const float* diny,
+                                 float* dout, int batch, int channels,
+                                 int num) {
+#pragma omp parallel for collapse(2)
+  for (int i = 0; i < batch; ++i) {
+    for (int j = 0; j < channels; ++j) {
+      int offset = (i * channels + j) * num;
+      const float* din_ptr = dinx + offset;
+      const float diny_data = diny[j];
+      float* dout_ptr = dout + offset;
+
+      int cnt = num >> 4;
+      int remain = num % 16;
+      // Broadcast the channel's addend into all four lanes.
+      float32x4_t rb = vdupq_n_f32(diny_data);
+      // Main loop: 16 floats per iteration.
+      for (int k = 0; k < cnt; ++k) {
+        float32x4_t din0 = vld1q_f32(din_ptr);
+        float32x4_t din1 = vld1q_f32(din_ptr + 4);
+        float32x4_t din2 = vld1q_f32(din_ptr + 8);
+        float32x4_t din3 = vld1q_f32(din_ptr + 12);
+
+        din0 = vaddq_f32(din0, rb);
+        din1 = vaddq_f32(din1, rb);
+        din2 = vaddq_f32(din2, rb);
+        din3 = vaddq_f32(din3, rb);
+
+        vst1q_f32(dout_ptr, din0);
+        vst1q_f32(dout_ptr + 4, din1);
+        vst1q_f32(dout_ptr + 8, din2);
+        vst1q_f32(dout_ptr + 12, din3);
+        din_ptr += 16;
+        dout_ptr += 16;
+      }
+      // Tail: at most one 8-wide step, one 4-wide step, then scalars.
+      if (remain >= 8) {
+        float32x4_t din0 = vld1q_f32(din_ptr);
+        float32x4_t din1 = vld1q_f32(din_ptr + 4);
+        din0 = vaddq_f32(din0, rb);
+        din1 = vaddq_f32(din1, rb);
+        vst1q_f32(dout_ptr, din0);
+        vst1q_f32(dout_ptr + 4, din1);
+        din_ptr += 8;
+        dout_ptr += 8;
+        remain -= 8;
+      }
+      if (remain >= 4) {
+        float32x4_t din0 = vld1q_f32(din_ptr);
+        din0 = vaddq_f32(din0, rb);
+        vst1q_f32(dout_ptr, din0);
+        din_ptr += 4;
+        dout_ptr += 4;
+        remain -= 4;
+      }
+      for (int p = 0; p < remain; ++p) {
+        dout_ptr[p] = din_ptr[p] + diny_data;
+      }
+    }
+  }
+}
+
 }  // namespace math
 }  // namespace arm
 }  // namespace lite
diff --git a/paddle/fluid/lite/arm/math/elementwise.h b/paddle/fluid/lite/arm/math/elementwise.h
index cf4c8e46b0703a888bc9ac9a4a395d4e57ba886d..7e907cd5e04eeaa5f61f426897502278e49ac9ad 100644
--- a/paddle/fluid/lite/arm/math/elementwise.h
+++ b/paddle/fluid/lite/arm/math/elementwise.h
@@ -22,6 +22,13 @@ namespace math {
 template <typename T>
 void elementwise_add(const T* dinx, const T* diny, T* dout, int num);
 
+// Broadcast add: dinx/dout are viewed as [batch][channels][num] and diny
+// holds one value per channel, which is added to every element of the
+// corresponding slice.
+template <typename T>
+void elementwise_add_axis(const T* dinx, const T* diny, T* dout, int batch,
+                          int channels, int num);
+
 }  // namespace math
 }  // namespace arm
 }  // namespace lite
diff --git a/paddle/fluid/lite/kernels/arm/elementwise_add_compute.cc b/paddle/fluid/lite/kernels/arm/elementwise_add_compute.cc
index 310cde17bbd2f235789250fa02f8e8f82f672ff0..1f06e6285bb73e9116dcb9c0f0cf85751c16fdb7 100644
--- a/paddle/fluid/lite/kernels/arm/elementwise_add_compute.cc
+++ b/paddle/fluid/lite/kernels/arm/elementwise_add_compute.cc
@@ -25,8 +25,36 @@ void ElementwiseAddCompute::Run() {
   const float* x_data = param.X->data<float>();
   const float* y_data = param.Y->data<float>();
   float* out_data = param.Out->mutable_data<float>();
-  int n = param.X->dims().production();
-  lite::arm::math::elementwise_add(x_data, y_data, out_data, n);
+  int axis = param.axis;
+  auto x_dims = param.X->dims();
+  auto y_dims = param.Y->dims();
+  // A negative axis aligns Y with the trailing dimensions of X.
+  if (axis < 0) {
+    axis = x_dims.size() - y_dims.size();
+  }
+  // The flat kernel consumes one Y element per X element, so it may only be
+  // used when Y is as large as X. Testing axis == 0 is not enough: with
+  // X = {n, c, h, w} and Y = {n} it would read past the end of Y.
+  if (x_dims.production() == y_dims.production()) {
+    lite::arm::math::elementwise_add(x_data, y_data, out_data,
+                                     x_dims.production());
+  } else {
+    // View X as [batch, channels, num], where Y spans the channels part.
+    int batch = 1;
+    int channels = 1;
+    int num = 1;
+    for (int i = 0; i < axis; ++i) {
+      batch *= x_dims[i];
+    }
+    for (int i = 0; i < y_dims.size(); ++i) {
+      channels *= y_dims[i];
+    }
+    for (int i = y_dims.size() + axis; i < x_dims.size(); ++i) {
+      num *= x_dims[i];
+    }
+    lite::arm::math::elementwise_add_axis(x_data, y_data, out_data, batch,
+                                          channels, num);
+  }
 }
 
 }  // namespace arm
diff --git a/paddle/fluid/lite/kernels/arm/elementwise_add_compute_test.cc b/paddle/fluid/lite/kernels/arm/elementwise_add_compute_test.cc
index 7156d08ce77df9c93ec46c1c55fb3a11df44a308..b2bbe2d3ae0f453161f9c7bb03ce852c43b048d5 100644
--- a/paddle/fluid/lite/kernels/arm/elementwise_add_compute_test.cc
+++ b/paddle/fluid/lite/kernels/arm/elementwise_add_compute_test.cc
@@ -41,40 +41,102 @@ void elementwise_add_compute_ref(const operators::ElementwiseParam& param) {
   const dtype* x_data = param.X->data<const dtype>();
   const dtype* y_data = param.Y->data<const dtype>();
   dtype* out_data = param.Out->mutable_data<dtype>();
-  DDim dim = param.X->dims();
-  ASSERT_EQ(dim.data(), param.Out->dims().data());
-  for (int i = 0; i < dim.production(); i++) {
-    out_data[i] = x_data[i] + y_data[i];
+  auto x_dims = param.X->dims();
+  auto y_dims = param.Y->dims();
+  int axis = param.axis;
+  // Resolve a negative axis the same way the kernel does.
+  if (axis < 0) {
+    axis = x_dims.size() - y_dims.size();
+  }
+  int batch = 1;
+  int channels = 1;
+  int num = 1;
+  for (int i = 0; i < axis; ++i) {
+    batch *= x_dims[i];
+  }
+  for (int i = 0; i < y_dims.size(); ++i) {
+    channels *= y_dims[i];
+  }
+  for (int i = y_dims.size() + axis; i < x_dims.size(); ++i) {
+    num *= x_dims[i];
+  }
+  for (int i = 0; i < batch; ++i) {
+    for (int j = 0; j < channels; ++j) {
+      int offset = (i * channels + j) * num;
+      const dtype* din_ptr = x_data + offset;
+      const dtype diny_data = y_data[j];
+      dtype* dout_ptr = out_data + offset;
+      for (int k = 0; k < num; ++k) {
+        dout_ptr[k] = din_ptr[k] + diny_data;
+      }
+    }
   }
 }
 
 TEST(elementwise_add, compute) {
   ElementwiseAddCompute elementwise_add;
   operators::ElementwiseParam param;
+  lite::Tensor x, y, output, output_ref;
 
-  lite::Tensor x, y, out, out_ref;
-  x.Resize(DDim(std::vector<int64_t>({2, 3, 4, 5})));
-  y.Resize(DDim(std::vector<int64_t>({2, 3, 4, 5})));
-  out.Resize(DDim(std::vector<int64_t>({2, 3, 4, 5})));
-  out_ref.Resize(DDim(std::vector<int64_t>({2, 3, 4, 5})));
-  auto* x_data = x.mutable_data<float>();
-  auto* y_data = y.mutable_data<float>();
-  auto* out_data = out.mutable_data<float>();
-  auto* out_ref_data = out_ref.mutable_data<float>();
-  for (int i = 0; i < x.dims().production(); i++) {
-    x_data[i] = y_data[i] = i;
-  }
+  for (auto n : {1, 3, 4, 11}) {
+    for (auto c : {1, 3, 4, 11}) {
+      for (auto h : {1, 3, 4, 11}) {
+        for (auto w : {1, 3, 4, 11}) {
+          for (auto axis : {-1, 0, 1, 2, 3}) {
+            // Candidate Y shapes: every contiguous run of X's dimensions.
+            for (const auto& yd : std::vector<std::vector<int64_t>>{
+                     {n},
+                     {c},
+                     {h},
+                     {w},
+                     {n, c},
+                     {c, h},
+                     {h, w},
+                     {n, c, h},
+                     {c, h, w},
+                     {n, c, h, w}}) {
+              auto x_dim = DDim(std::vector<int64_t>({n, c, h, w}));
+              auto y_dim = DDim(yd);
+              int axis_t = axis < 0 ? x_dim.size() - y_dim.size() : axis;
 
-  param.X = &x;
-  param.Y = &y;
-  param.Out = &out;
-  elementwise_add.SetParam(param);
-  elementwise_add.Run();
+              // Skip combinations where Y does not fit X at this axis.
+              if (axis_t + y_dim.size() > 4) continue;
+              bool flag = false;
+              for (int i = 0; i < y_dim.size(); i++) {
+                if (x_dim[i + axis_t] != y_dim[i]) flag = true;
+              }
+              if (flag) continue;
 
-  param.Out = &out_ref;
-  elementwise_add_compute_ref<float>(param);
-  for (int i = 0; i < out.dims().production(); i++) {
-    EXPECT_NEAR(out_data[i], out_ref_data[i], 1e-5);
+              x.Resize(x_dim);
+              y.Resize(y_dim);
+              output.Resize(x_dim);
+              output_ref.Resize(x_dim);
+              auto* x_data = x.mutable_data<float>();
+              auto* y_data = y.mutable_data<float>();
+              auto* output_data = output.mutable_data<float>();
+              auto* output_ref_data = output_ref.mutable_data<float>();
+              for (int i = 0; i < x.dims().production(); i++) {
+                x_data[i] = i;
+              }
+              for (int i = 0; i < y.dims().production(); i++) {
+                y_data[i] = i;
+              }
+              param.X = &x;
+              param.Y = &y;
+              param.axis = axis;
+              param.Out = &output;
+              elementwise_add.SetParam(param);
+              elementwise_add.Run();
+              param.Out = &output_ref;
+              elementwise_add_compute_ref<float>(param);
+              for (int i = 0; i < output.dims().production(); i++) {
+                EXPECT_NEAR(output_data[i], output_ref_data[i], 1e-5);
+              }
+            }
+          }
+        }
+      }
+    }
   }
 }
 