diff --git a/lite/backends/arm/math/reduce_mean.cc b/lite/backends/arm/math/reduce_mean.cc
index 56104550d8d68e53ad9a2ac3148887d67480d6f6..a84eef2970b2837159609c1ded1ca0d9991ccfc6 100644
--- a/lite/backends/arm/math/reduce_mean.cc
+++ b/lite/backends/arm/math/reduce_mean.cc
@@ -198,6 +198,34 @@ void reduce_mean_hw(const float* src,
   reduce_mean_w(tmp_out, dst, num_in, channel_in, 1, width_in);
 }
 
+// Gradient of the mean op for float data: every element of the input
+// gradient receives out_grad[0] / size.
+//   out_grad: upstream gradient; only out_grad[0] is read.
+//   in_grad:  destination buffer; elements [0, size) are written.
+//   size:     number of elements in in_grad; nothing is written if <= 0.
+template <>
+void mean_grad(const float* out_grad, float* in_grad, int size) {
+  if (size <= 0) {
+    return;
+  }
+  float grad = out_grad[0] / size;
+  float32x4_t grad_v = vdupq_n_f32(grad);
+  int loop = size >> 2;
+  int remain = size & 3;
+
+  // Each iteration computes its own destination from i. Advancing the
+  // shared in_grad pointer inside the parallel loop would be a data race.
+#pragma omp parallel for
+  for (int i = 0; i < loop; ++i) {
+    vst1q_f32(in_grad + i * 4, grad_v);
+  }
+  // Scalar tail for the last size % 4 elements after the vector stores.
+  float* tail = in_grad + loop * 4;
+  for (int i = 0; i < remain; ++i) {
+    tail[i] = grad;
+  }
+}
+
 }  // namespace math
 }  // namespace arm
 }  // namespace lite
diff --git a/lite/backends/arm/math/reduce_mean.h b/lite/backends/arm/math/reduce_mean.h
index 277ed209c058b5b4be76ce18a00683610e6afb7a..aaa9ff42c18d0cfa6a7cf11408dfba06a9444adc 100644
--- a/lite/backends/arm/math/reduce_mean.h
+++ b/lite/backends/arm/math/reduce_mean.h
@@ -83,6 +83,9 @@ void reduce_mean_all(const T* src,
                      int height_in,
                      int width_in);
 
+template <typename T>
+void mean_grad(const T* out_grad, T* in_grad, int size);
+
 }  // namespace math
 }  // namespace arm
 }  // namespace lite
diff --git a/lite/kernels/arm/mean_grad_compute.cc b/lite/kernels/arm/mean_grad_compute.cc
index f7a5be8be1ebd4e02a188ab40026de04b319c76e..f72ccf47dba0c0e9d0a4e793f4b582c106cfeecd 100644
--- a/lite/kernels/arm/mean_grad_compute.cc
+++ b/lite/kernels/arm/mean_grad_compute.cc
@@ -13,7 +13,7 @@
 // limitations under the License.
#include "lite/kernels/arm/mean_grad_compute.h" - +#include "lite/backends/arm/math/reduce_mean.h" namespace paddle { namespace lite { namespace kernels { @@ -31,10 +31,7 @@ void MeanGradCompute::Run() { int input_grad_size = input_grad->dims().production(); - // TODO(mapingshuo): use parallel methods to accelerate this for loop - for (int i = 0; i < input_grad_size; i++) { - input_grad_data[i] = out_grad_data[0] / input_grad_size; - } + lite::arm::math::mean_grad(out_grad_data, input_grad_data, input_grad_size); } } // namespace arm