diff --git a/paddle/fluid/lite/arm/math/elementwise.cc b/paddle/fluid/lite/arm/math/elementwise.cc index 68140a5d7dbccc9fa0028e9cde3e9d074275a7ee..2a74e7ee4ec4be51b420b1fa2d2a1be7c3f148fb 100644 --- a/paddle/fluid/lite/arm/math/elementwise.cc +++ b/paddle/fluid/lite/arm/math/elementwise.cc @@ -41,15 +41,15 @@ void elementwise_add(const float* dinx, const float* diny, float* dout, float32x4_t diny2 = vld1q_f32(diny_ptr + 8); float32x4_t diny3 = vld1q_f32(diny_ptr + 12); - float32x4_t vsum0 = vaddq_f32(dinx0, diny0); - float32x4_t vsum1 = vaddq_f32(dinx1, diny1); - float32x4_t vsum2 = vaddq_f32(dinx2, diny2); - float32x4_t vsum3 = vaddq_f32(dinx3, diny3); + dinx0 = vaddq_f32(dinx0, diny0); + dinx1 = vaddq_f32(dinx1, diny1); + dinx2 = vaddq_f32(dinx2, diny2); + dinx3 = vaddq_f32(dinx3, diny3); - vst1q_f32(dout_ptr, vsum0); - vst1q_f32(dout_ptr + 4, vsum1); - vst1q_f32(dout_ptr + 8, vsum2); - vst1q_f32(dout_ptr + 12, vsum3); + vst1q_f32(dout_ptr, dinx0); + vst1q_f32(dout_ptr + 4, dinx1); + vst1q_f32(dout_ptr + 8, dinx2); + vst1q_f32(dout_ptr + 12, dinx3); } if (remain > 0) { const float* dinx_ptr = dinx + (cnt << 4); @@ -64,6 +64,69 @@ void elementwise_add(const float* dinx, const float* diny, float* dout, } } +template <> +void elementwise_add_axis(const float* dinx, const float* diny, + float* dout, int batch, int channels, + int num) { +#pragma omp parallel for collapse(2) + for (int i = 0; i < batch; ++i) { + for (int j = 0; j < channels; ++j) { + int offset = (i * channels + j) * num; + const float* din_ptr = dinx + offset; + const float diny_data = diny[j]; + float* dout_ptr = dout + offset; + + int cnt = num >> 4; + int remain = num % 16; + float32x4_t rb = vdupq_n_f32(diny_data); + for (int k = 0; k < cnt; ++k) { + float32x4_t din0 = vld1q_f32(din_ptr); + float32x4_t din1 = vld1q_f32(din_ptr + 4); + float32x4_t din2 = vld1q_f32(din_ptr + 8); + float32x4_t din3 = vld1q_f32(din_ptr + 12); + + din0 = vaddq_f32(din0, rb); + din1 = vaddq_f32(din1, rb); + din2 = vaddq_f32(din2, rb); + din3 = vaddq_f32(din3, rb); + + vst1q_f32(dout_ptr, din0); + vst1q_f32(dout_ptr + 4, din1); + vst1q_f32(dout_ptr + 8, din2); + vst1q_f32(dout_ptr + 12, din3); + din_ptr += 16; + dout_ptr += 16; + } + if (remain >= 8) { + float32x4_t din0 = vld1q_f32(din_ptr); + float32x4_t din1 = vld1q_f32(din_ptr + 4); + din0 = vaddq_f32(din0, rb); + din1 = vaddq_f32(din1, rb); + vst1q_f32(dout_ptr, din0); + vst1q_f32(dout_ptr + 4, din1); + din_ptr += 8; + dout_ptr += 8; + remain -= 8; + } + if (remain >= 4) { + float32x4_t din0 = vld1q_f32(din_ptr); + din0 = vaddq_f32(din0, rb); + vst1q_f32(dout_ptr, din0); + din_ptr += 4; + dout_ptr += 4; + remain -= 4; + } + if (remain > 0) { + for (int p = 0; p < remain; p++) { + *dout_ptr = *din_ptr + diny_data; + dout_ptr++; + din_ptr++; + } + } + } + } +} + } // namespace math } // namespace arm } // namespace lite diff --git a/paddle/fluid/lite/arm/math/elementwise.h b/paddle/fluid/lite/arm/math/elementwise.h index cf4c8e46b0703a888bc9ac9a4a395d4e57ba886d..ca8f87895fcea80f9a1a178a0bf43b34c44182bb 100644 --- a/paddle/fluid/lite/arm/math/elementwise.h +++ b/paddle/fluid/lite/arm/math/elementwise.h @@ -22,6 +22,10 @@ namespace math { template void elementwise_add(const T* dinx, const T* diny, T* dout, int num); +template +void elementwise_add_axis(const T* dinx, const T* diny, T* dout, int batch, + int channels, int num); + } // namespace math } // namespace arm } // namespace lite diff --git a/paddle/fluid/lite/kernels/arm/conv_compute.cc b/paddle/fluid/lite/kernels/arm/conv_compute.cc index 5e9ddb6271684120c8cab68e6e10bade3a3ab015..a8a2ac790a3c045642277ef75367bbdd878f0d6d 100644 --- a/paddle/fluid/lite/kernels/arm/conv_compute.cc +++ b/paddle/fluid/lite/kernels/arm/conv_compute.cc @@ -100,7 +100,7 @@ void ConvCompute::Run() { REGISTER_LITE_KERNEL(conv2d, kARM, kFloat, kNCHW, paddle::lite::kernels::arm::ConvCompute, def) .BindInput("Input", {LiteType::GetTensorTy(TARGET(kARM))}) - .BindInput("Bias", {LiteType::GetTensorTy(TARGET(kARM))}) + // .BindInput("Bias", {LiteType::GetTensorTy(TARGET(kARM))}) .BindInput("Filter", {LiteType::GetTensorTy(TARGET(kARM))}) .BindOutput("Output", {LiteType::GetTensorTy(TARGET(kARM))}) .Finalize(); @@ -108,7 +108,7 @@ REGISTER_LITE_KERNEL(conv2d, kARM, kFloat, kNCHW, REGISTER_LITE_KERNEL(depthwise_conv2d, kARM, kFloat, kNCHW, paddle::lite::kernels::arm::ConvCompute, def) .BindInput("Input", {LiteType::GetTensorTy(TARGET(kARM))}) - .BindInput("Bias", {LiteType::GetTensorTy(TARGET(kARM))}) + // .BindInput("Bias", {LiteType::GetTensorTy(TARGET(kARM))}) .BindInput("Filter", {LiteType::GetTensorTy(TARGET(kARM))}) .BindOutput("Output", {LiteType::GetTensorTy(TARGET(kARM))}) .Finalize(); diff --git a/paddle/fluid/lite/kernels/arm/elementwise_add_compute.cc b/paddle/fluid/lite/kernels/arm/elementwise_add_compute.cc index 3c11451ce576f203794fdf4857832886b88477cb..e9d9f4927b7ee18b3e18efa69a00dcb1c813bf3b 100644 --- a/paddle/fluid/lite/kernels/arm/elementwise_add_compute.cc +++ b/paddle/fluid/lite/kernels/arm/elementwise_add_compute.cc @@ -25,8 +25,31 @@ void ElementwiseAddCompute::Run() { const float* x_data = param.X->data(); const float* y_data = param.Y->data(); float* out_data = param.Out->mutable_data(); - int n = param.X->dims().production(); - // lite::arm::math::elementwise_add(x_data, y_data, out_data, n); + int axis = param.axis; + auto x_dims = param.X->dims(); + auto y_dims = param.Y->dims(); + if (axis < 0) { + axis = x_dims.size() - y_dims.size(); + } + if (x_dims.size() == y_dims.size()) { + lite::arm::math::elementwise_add(x_data, y_data, out_data, + x_dims.production()); + } else { + int batch = 1; + int channels = 1; + int num = 1; + for (int i = 0; i < axis; ++i) { + batch *= x_dims[i]; + } + for (int i = 0; i < y_dims.size(); ++i) { + channels *= y_dims[i]; + } + for (int i = y_dims.size() + axis; i < x_dims.size(); ++i) { + num *= x_dims[i]; + } + lite::arm::math::elementwise_add_axis(x_data, y_data, out_data, batch, + channels, num); + } } } // namespace arm diff --git a/paddle/fluid/lite/kernels/arm/elementwise_add_compute_test.cc b/paddle/fluid/lite/kernels/arm/elementwise_add_compute_test.cc index 7156d08ce77df9c93ec46c1c55fb3a11df44a308..20b998dc6cfa8a9606fcf0f716470366fdd60338 100644 --- a/paddle/fluid/lite/kernels/arm/elementwise_add_compute_test.cc +++ b/paddle/fluid/lite/kernels/arm/elementwise_add_compute_test.cc @@ -41,40 +41,97 @@ void elementwise_add_compute_ref(const operators::ElementwiseParam& param) { const dtype* x_data = param.X->data(); const dtype* y_data = param.Y->data(); dtype* out_data = param.Out->mutable_data(); - DDim dim = param.X->dims(); - ASSERT_EQ(dim.data(), param.Out->dims().data()); - for (int i = 0; i < dim.production(); i++) { - out_data[i] = x_data[i] + y_data[i]; + auto x_dims = param.X->dims(); + auto y_dims = param.Y->dims(); + int axis = param.axis; + if (axis < 0) { + axis = x_dims.size() - y_dims.size(); + } + int batch = 1; + int channels = 1; + int num = 1; + for (int i = 0; i < axis; ++i) { + batch *= x_dims[i]; + } + for (int i = 0; i < y_dims.size(); ++i) { + channels *= y_dims[i]; + } + for (int i = y_dims.size() + axis; i < x_dims.size(); ++i) { + num *= x_dims[i]; + } + for (int i = 0; i < batch; ++i) { + for (int j = 0; j < channels; ++j) { + int offset = (i * channels + j) * num; + const dtype* din_ptr = x_data + offset; + const dtype diny_data = y_data[j]; + dtype* dout_ptr = out_data + offset; + for (int k = 0; k < num; ++k) { + *dout_ptr = *din_ptr + diny_data; + dout_ptr++; + din_ptr++; + } + } } } TEST(elementwise_add, compute) { ElementwiseAddCompute elementwise_add; operators::ElementwiseParam param; + lite::Tensor x, y, output, output_ref; - lite::Tensor x, y, out, out_ref; - x.Resize(DDim(std::vector({2, 3, 4, 5}))); - y.Resize(DDim(std::vector({2, 3, 4, 5}))); - out.Resize(DDim(std::vector({2, 3, 4, 5}))); - out_ref.Resize(DDim(std::vector({2, 3, 4, 5}))); - auto* x_data = x.mutable_data(); - auto* y_data = y.mutable_data(); - auto* out_data = out.mutable_data(); - auto* out_ref_data = out_ref.mutable_data(); - for (int i = 0; i < x.dims().production(); i++) { - x_data[i] = y_data[i] = i; - } + for (auto n : {1, 3, 4, 11}) { + for (auto c : {1, 3, 4, 11}) { + for (auto h : {1, 3, 4, 11}) { + for (auto w : {1, 3, 4, 11}) { + for (auto axis : {-1, 0, 1, 2, 3}) { + for (auto yd : + {std::vector({n}), std::vector({c}), + std::vector({h}), std::vector({w}), + std::vector({n, c}), std::vector({c, h}), + std::vector({h, w}), std::vector({n, c, h}), + std::vector({c, h, w}), + std::vector({n, c, h, w})}) { + auto x_dim = DDim(std::vector({n, c, h, w})); + auto y_dim = DDim(yd); + int axis_t = axis < 0 ? x_dim.size() - y_dim.size() : axis; - param.X = &x; - param.Y = &y; - param.Out = &out; - elementwise_add.SetParam(param); - elementwise_add.Run(); + if (axis_t + y_dim.size() > 4) continue; + bool flag = false; + for (int i = 0; i < y_dim.size(); i++) { + if (x_dim[i + axis_t] != y_dim[i]) flag = true; + } + if (flag) continue; - param.Out = &out_ref; - elementwise_add_compute_ref(param); - for (int i = 0; i < out.dims().production(); i++) { - EXPECT_NEAR(out_data[i], out_ref_data[i], 1e-5); + x.Resize(x_dim); + y.Resize(y_dim); + output.Resize(x_dim); + output_ref.Resize(x_dim); + auto* x_data = x.mutable_data(); + auto* y_data = y.mutable_data(); + auto* output_data = output.mutable_data(); + auto* output_ref_data = output_ref.mutable_data(); + for (int i = 0; i < x_dim.production(); i++) { + x_data[i] = i; + } + for (int i = 0; i < y_dim.production(); i++) { + y_data[i] = i; + } + param.X = &x; + param.Y = &y; + param.axis = axis; + param.Out = &output; + elementwise_add.SetParam(param); + elementwise_add.Run(); + param.Out = &output_ref; + elementwise_add_compute_ref(param); + for (int i = 0; i < output.dims().production(); i++) { + EXPECT_NEAR(output_data[i], output_ref_data[i], 1e-5); + } + } + } + } + } + } } }