提交 19367389 编写于 作者: Q qingqing01

Update the CUDA kernel.

上级 41372ded
......@@ -274,15 +274,14 @@ void set_constant_with_place<platform::CUDAPlace>(
}
// RowwiseAddKernel: for every flat index i below the element count, adds the
// entry of `b` selected by i's column (i modulo width) to a[i] and stores the
// sum in `c` at the same position.
//
// NOTE(review): this span contains two interleaved versions of the kernel (two
// signatures and two loop bodies). It looks like the removed and added lines of
// a diff hunk rendered without +/- markers; the hunk header's line counts
// (15 old, 14 new) are consistent with the int64_t/division form being the
// removed side. Confirm against the real file before relying on either form.
template <typename T>
// Form 1: receives height and width and derives the element count from them.
__global__ void RowwiseAddKernel(const T* a, const T* b, T* c, int64_t height,
int64_t width) {
int64_t num = height * width;
// Form 2: receives width and the total element count directly, both as int.
__global__ void RowwiseAddKernel(const T* a, const T* b, T* c, int width,
int num) {
// Reciprocal of width stored as T; Form 2 multiplies by it instead of dividing.
// NOTE(review): 1.0 is a double literal, and i * tmp is a rounded
// floating-point product, so truncating it may not equal i / width for every
// i (and tmp would truncate to 0 for width > 1 if T were an integral type).
// TODO confirm h is exact for all supported T and tensor sizes.
T tmp = 1.0 / width;
// Grid-stride loop: each thread starts at its global thread index and advances
// by the total number of launched threads (blockDim.x * gridDim.x).
for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < num;
i += blockDim.x * gridDim.x) {
// Form 1 body: h is the row, w the column; idx recombines them into the flat
// offset before the add.
int h = i / width;
int w = i % width;
int idx = h * width + w;
c[idx] = a[idx] + b[w];
// Form 2 body: h is the truncated product i * tmp, w is what remains after
// removing h full rows; the flat index i is used directly for a and c.
int h = i * tmp;
int w = i - h * width;
c[i] = a[i] + b[w];
}
}
......@@ -292,11 +291,14 @@ struct RowwiseAdd<platform::CUDADeviceContext, T> {
const framework::Tensor& input,
const framework::Tensor& vector, framework::Tensor* output) {
auto in_dims = input.dims();
auto size = input.numel() / in_dims[0];
PADDLE_ENFORCE_EQ(vector.numel(), size);
PADDLE_ENFORCE_EQ(output->dims(), in_dims);
int blocks = 512;
int grids = (input.numel() + blocks - 1) / blocks;
RowwiseAddKernel<T><<<grids, blocks, 0, context.stream()>>>(
input.data<T>(), vector.data<T>(), output->data<T>(), in_dims[0],
in_dims[1]);
input.data<T>(), vector.data<T>(), output->data<T>(),
static_cast<int>(in_dims[1]), static_cast<int>(input.numel()));
}
};
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册