From 19367389c0f2245669e1d05afaa9e6cdd19022a0 Mon Sep 17 00:00:00 2001
From: qingqing01
Date: Tue, 26 Dec 2017 18:59:53 -0800
Subject: [PATCH] Update the CUDA kernel.

---
 paddle/operators/math/math_function.cu | 20 +++++++++++---------
 1 file changed, 11 insertions(+), 9 deletions(-)

diff --git a/paddle/operators/math/math_function.cu b/paddle/operators/math/math_function.cu
index 36e6cc8914..d47a7f818d 100644
--- a/paddle/operators/math/math_function.cu
+++ b/paddle/operators/math/math_function.cu
@@ -274,15 +274,14 @@ void set_constant_with_place(
 }
 
 template <typename T>
-__global__ void RowwiseAddKernel(const T* a, const T* b, T* c, int64_t height,
-                                 int64_t width) {
-  int64_t num = height * width;
+__global__ void RowwiseAddKernel(const T* a, const T* b, T* c, int width,
+                                 int num) {
+  // Exact integer math: a float reciprocal (1.0 / width) can mis-round h.
   for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < num;
        i += blockDim.x * gridDim.x) {
-    int h = i / width;
-    int w = i % width;
-    int idx = h * width + w;
-    c[idx] = a[idx] + b[w];
+    int h = i / width;
+    int w = i - h * width;
+    c[i] = a[i] + b[w];
   }
 }
 
@@ -292,11 +291,14 @@ struct RowwiseAdd {
                   const framework::Tensor& input,
                   const framework::Tensor& vector, framework::Tensor* output) {
     auto in_dims = input.dims();
+    auto size = input.numel() / in_dims[0];
+    PADDLE_ENFORCE_EQ(vector.numel(), size);
+    PADDLE_ENFORCE_EQ(output->dims(), in_dims);
     int blocks = 512;
     int grids = (input.numel() + blocks - 1) / blocks;
     RowwiseAddKernel<T><<<grids, blocks, 0, context.stream()>>>(
-        input.data<T>(), vector.data<T>(), output->data<T>(), in_dims[0],
-        in_dims[1]);
+        input.data<T>(), vector.data<T>(), output->data<T>(),
+        static_cast<int>(size), static_cast<int>(input.numel()));
   }
 };
 
-- 
GitLab