From 7e214b498515b50820f8535927d30879f048f6a2 Mon Sep 17 00:00:00 2001 From: Yu Yang Date: Thu, 21 Dec 2017 19:45:37 +0800 Subject: [PATCH] Speed up ColwiseSum in CPU (#6834) * Remove unnecessary reshape in ColwiseSum Speed up 12s -> 10s. * Hand write ColwiseAdd in CPU --- paddle/operators/math/math_function_impl.h | 39 ++++++++++++++++++---- 1 file changed, 33 insertions(+), 6 deletions(-) diff --git a/paddle/operators/math/math_function_impl.h b/paddle/operators/math/math_function_impl.h index 3e6d83386..aced2690b 100644 --- a/paddle/operators/math/math_function_impl.h +++ b/paddle/operators/math/math_function_impl.h @@ -67,18 +67,45 @@ void RowwiseAdd::operator()(const DeviceContext& context, template void ColwiseSum::operator()(const DeviceContext& context, const framework::Tensor& input, - framework::Tensor* vector) { + framework::Tensor* out) { auto in_dims = input.dims(); auto size = input.numel() / in_dims[0]; - PADDLE_ENFORCE_EQ(vector->numel(), size); + PADDLE_ENFORCE_EQ(out->numel(), size); - auto vec = framework::EigenMatrix::From(*vector); auto in = framework::EigenMatrix::From(input); - Eigen::array shape({{1, static_cast(size)}}); - vec.reshape(shape).device(*context.eigen_device()) = - in.sum(Eigen::array({{0}})).reshape(shape); + auto vec = framework::EigenVector::Flatten(*out); + + vec.device(*context.eigen_device()) = in.sum(Eigen::array({{0}})); } +// Specialize for CPU, since Eigen implement a general reduce. However, +// colwise-sum can be easily implemented. General reduce has a huge overhead in +// CPU +template +class ColwiseSum { + public: + void operator()(const platform::CPUDeviceContext& context, + const framework::Tensor& input, framework::Tensor* out) { + auto& in_dims = input.dims(); + auto height = in_dims[0]; + auto size = in_dims[1]; + PADDLE_ENFORCE_EQ(out->numel(), size); + + T* out_buf = out->mutable_data(out->place()); + const T* in_buf = input.data(); + + for (size_t i = 0; i < height; ++i) { + for (size_t j = 0; j < size; ++j) { + if (i == 0) { + out_buf[j] = in_buf[i * size + j]; + } else { + out_buf[j] += in_buf[i * size + j]; + } + } + } + } +}; + } // namespace math } // namespace operators } // namespace paddle -- GitLab