未验证 提交 c5f7cec9 编写于 作者: X xiaogang 提交者: GitHub

fix: fix mean_grad compute (#3402)

上级 3190c354
...@@ -198,6 +198,23 @@ void reduce_mean_hw<float>(const float* src, ...@@ -198,6 +198,23 @@ void reduce_mean_hw<float>(const float* src,
reduce_mean_w(tmp_out, dst, num_in, channel_in, 1, width_in); reduce_mean_w(tmp_out, dst, num_in, channel_in, 1, width_in);
} }
template <>
void mean_grad<float>(const float* out_grad, float* in_grad, int size) {
float grad = out_grad[0] / size;
float32x4_t grad_v = vdupq_n_f32(grad);
int loop = size >> 2;
int remain = size & 3;
#pragma omp parallel for
for (int i = 0; i < loop; ++i) {
vst1q_f32(in_grad, grad_v);
in_grad += 4;
}
for (int i = 0; i < remain; ++i) {
in_grad[i] = grad;
}
}
} // namespace math } // namespace math
} // namespace arm } // namespace arm
} // namespace lite } // namespace lite
......
...@@ -83,6 +83,9 @@ void reduce_mean_all(const T* src, ...@@ -83,6 +83,9 @@ void reduce_mean_all(const T* src,
int height_in, int height_in,
int width_in); int width_in);
template <typename T>
void mean_grad(const T* out_grad, T* in_grad, int size);
} // namespace math } // namespace math
} // namespace arm } // namespace arm
} // namespace lite } // namespace lite
......
...@@ -13,7 +13,7 @@ ...@@ -13,7 +13,7 @@
// limitations under the License. // limitations under the License.
#include "lite/kernels/arm/mean_grad_compute.h" #include "lite/kernels/arm/mean_grad_compute.h"
#include "lite/backends/arm/math/reduce_mean.h"
namespace paddle { namespace paddle {
namespace lite { namespace lite {
namespace kernels { namespace kernels {
...@@ -31,10 +31,7 @@ void MeanGradCompute::Run() { ...@@ -31,10 +31,7 @@ void MeanGradCompute::Run() {
int input_grad_size = input_grad->dims().production(); int input_grad_size = input_grad->dims().production();
// TODO(mapingshuo): use parallel methods to accelerate this for loop lite::arm::math::mean_grad(out_grad_data, input_grad_data, input_grad_size);
for (int i = 0; i < input_grad_size; i++) {
input_grad_data[i] = out_grad_data[0] / input_grad_size;
}
} }
} // namespace arm } // namespace arm
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册