From 32d881beabdfb7130072bb624bc29fa6c6b30904 Mon Sep 17 00:00:00 2001
From: qingqing01
Date: Tue, 26 Dec 2017 05:46:23 -0800
Subject: [PATCH] Optimize the rowwise add function.

---
 paddle/operators/math/math_function.cc     | 32 ++++++++++++++++++++++
 paddle/operators/math/math_function.cu     | 27 ++++++++++++++++++
 paddle/operators/math/math_function_impl.h | 19 -------------
 3 files changed, 59 insertions(+), 19 deletions(-)

diff --git a/paddle/operators/math/math_function.cc b/paddle/operators/math/math_function.cc
index 2b35e4532a..1a4829c49f 100644
--- a/paddle/operators/math/math_function.cc
+++ b/paddle/operators/math/math_function.cc
@@ -302,8 +302,40 @@ void set_constant(const platform::DeviceContext& context,
 #endif
   }
 }
 
+template <typename T>
+struct RowwiseAdd<platform::CPUDeviceContext, T> {
+  void operator()(const platform::CPUDeviceContext& context,
+                  const framework::Tensor& input,
+                  const framework::Tensor& vector, framework::Tensor* output) {
+    auto in_dims = input.dims();
+    auto size = input.numel() / in_dims[0];
+    PADDLE_ENFORCE_EQ(vector.numel(), size);
+    PADDLE_ENFORCE_EQ(output->dims(), in_dims);
+
+    // auto in = framework::EigenMatrix<T>::From(input);
+    // auto vec = framework::EigenVector<T>::Flatten(vector);
+    // auto out = framework::EigenMatrix<T>::From(*output);
+    // for (int64_t i = 0; i < in_dims[0]; ++i) {
+    //   out.chip(i, 0) = in.chip(i, 0) + vec;
+    // }
+
+    auto* in = input.data<T>();
+    auto* vec = vector.data<T>();
+    auto* out = output->data<T>();
+
+    int64_t h = in_dims[0];
+    int64_t w = in_dims[1];
+    for (int64_t i = 0; i < h; ++i) {
+      for (int64_t j = 0; j < w; ++j) {
+        out[i * w + j] = in[i * w + j] + vec[j];
+      }
+    }
+  }
+};
+
 template struct RowwiseAdd<platform::CPUDeviceContext, float>;
 template struct RowwiseAdd<platform::CPUDeviceContext, double>;
+
 template struct ColwiseSum<platform::CPUDeviceContext, float>;
 template struct ColwiseSum<platform::CPUDeviceContext, double>;
diff --git a/paddle/operators/math/math_function.cu b/paddle/operators/math/math_function.cu
index 927838a094..36e6cc8914 100644
--- a/paddle/operators/math/math_function.cu
+++ b/paddle/operators/math/math_function.cu
@@ -273,6 +273,33 @@ void set_constant_with_place(
                            TensorSetConstantGPU(context, tensor, value));
 }
 
+template <typename T>
+__global__ void RowwiseAddKernel(const T* a, const T* b, T* c, int64_t height,
+                                 int64_t width) {
+  int64_t num = height * width;
+  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < num;
+       i += blockDim.x * gridDim.x) {
+    int h = i / width;
+    int w = i % width;
+    int idx = h * width + w;
+    c[idx] = a[idx] + b[w];
+  }
+}
+
+template <typename T>
+struct RowwiseAdd<platform::CUDADeviceContext, T> {
+  void operator()(const platform::CUDADeviceContext& context,
+                  const framework::Tensor& input,
+                  const framework::Tensor& vector, framework::Tensor* output) {
+    auto in_dims = input.dims();
+    int blocks = 512;
+    int grids = (input.numel() + blocks - 1) / blocks;
+    RowwiseAddKernel<T><<<grids, blocks, 0, context.stream()>>>(
+        input.data<T>(), vector.data<T>(), output->data<T>(), in_dims[0],
+        in_dims[1]);
+  }
+};
+
 template struct RowwiseAdd<platform::CUDADeviceContext, float>;
 template struct RowwiseAdd<platform::CUDADeviceContext, double>;
 template struct ColwiseSum<platform::CUDADeviceContext, float>;
diff --git a/paddle/operators/math/math_function_impl.h b/paddle/operators/math/math_function_impl.h
index ddd798dace..de591626df 100644
--- a/paddle/operators/math/math_function_impl.h
+++ b/paddle/operators/math/math_function_impl.h
@@ -45,25 +45,6 @@ void Transpose<DeviceContext, T, Rank>::operator()(
   eigen_out.device(*dev) = eigen_in.shuffle(permute);
 }
 
-template <typename DeviceContext, typename T>
-void RowwiseAdd<DeviceContext, T>::operator()(const DeviceContext& context,
-                                              const framework::Tensor& input,
-                                              const framework::Tensor& vector,
-                                              framework::Tensor* output) {
-  auto in_dims = input.dims();
-  auto size = input.numel() / in_dims[0];
-  PADDLE_ENFORCE_EQ(vector.numel(), size);
-  PADDLE_ENFORCE_EQ(output->dims(), in_dims);
-
-  auto in = framework::EigenMatrix<T>::From(input);
-  auto vec = framework::EigenMatrix<T>::From(vector);
-  auto out = framework::EigenMatrix<T>::From(*output);
-  Eigen::array<int, 2> shape({{1, static_cast<int>(size)}});
-  Eigen::array<int, 2> bcast({{static_cast<int>(in_dims[0]), 1}});
-  out.device(*context.eigen_device()) =
-      in + vec.reshape(shape).broadcast(bcast);
-}
-
 template <typename DeviceContext, typename T>
 void ColwiseSum<DeviceContext, T>::operator()(const DeviceContext& context,
                                               const framework::Tensor& input,
--
GitLab
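
Note (not part of the patch): the GPU path above launches RowwiseAddKernel as a
grid-stride loop, so a single launch of at most `grids` blocks covers a tensor of
any size. The following is a minimal, self-contained sketch of the same pattern,
not Paddle code: the names (rowwise_add, h, w) are hypothetical, it assumes the
default stream, and it uses a 64-bit loop index so the stride arithmetic cannot
overflow on very large tensors.

// Hypothetical standalone sketch of the grid-stride rowwise add; build with nvcc.
#include <cuda_runtime.h>

#include <cassert>
#include <cstdint>
#include <cstdio>
#include <vector>

template <typename T>
__global__ void rowwise_add(const T* in, const T* vec, T* out, int64_t height,
                            int64_t width) {
  int64_t num = height * width;
  // Grid-stride loop: each thread processes several elements whenever the
  // grid is smaller than height * width.
  for (int64_t i = blockIdx.x * blockDim.x + threadIdx.x; i < num;
       i += int64_t(blockDim.x) * gridDim.x) {
    out[i] = in[i] + vec[i % width];  // i % width is the column index
  }
}

int main() {
  const int64_t h = 128, w = 256, n = h * w;
  std::vector<float> in(n), vec(w), out(n);
  for (int64_t i = 0; i < n; ++i) in[i] = static_cast<float>(i);
  for (int64_t j = 0; j < w; ++j) vec[j] = 0.5f * static_cast<float>(j);

  float *d_in, *d_vec, *d_out;
  cudaMalloc(&d_in, n * sizeof(float));
  cudaMalloc(&d_vec, w * sizeof(float));
  cudaMalloc(&d_out, n * sizeof(float));
  cudaMemcpy(d_in, in.data(), n * sizeof(float), cudaMemcpyHostToDevice);
  cudaMemcpy(d_vec, vec.data(), w * sizeof(float), cudaMemcpyHostToDevice);

  const int blocks = 512;  // threads per block, matching the patch
  const int grids = static_cast<int>((n + blocks - 1) / blocks);
  rowwise_add<float><<<grids, blocks>>>(d_in, d_vec, d_out, h, w);
  cudaMemcpy(out.data(), d_out, n * sizeof(float), cudaMemcpyDeviceToHost);

  // Cross-check against the CPU double loop from math_function.cc; all sums
  // here are exactly representable in float, so == comparison is safe.
  for (int64_t i = 0; i < h; ++i)
    for (int64_t j = 0; j < w; ++j)
      assert(out[i * w + j] == in[i * w + j] + vec[j]);
  std::printf("rowwise add OK\n");

  cudaFree(d_in);
  cudaFree(d_vec);
  cudaFree(d_out);
  return 0;
}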