diff --git a/paddle/fluid/operators/elementwise/elementwise_add_op.cu b/paddle/fluid/operators/elementwise/elementwise_add_op.cu
index de121b3786f3c9cc94a0c3dab789f372d9e17e72..71019872802eaca964373fd58a7ccc6445d9c489 100644
--- a/paddle/fluid/operators/elementwise/elementwise_add_op.cu
+++ b/paddle/fluid/operators/elementwise/elementwise_add_op.cu
@@ -40,8 +40,9 @@ struct SameDimsElemwiseAdd<platform::CUDADeviceContext, platform::float16> {
                   const framework::Tensor* x, const framework::Tensor* y,
                   framework::Tensor* z) {
     auto size = x->numel();
-    dim3 gird_size = dim3(
-        (size / 2 + PADDLE_CUDA_THREAD_SIZE - 1) / PADDLE_CUDA_THREAD_SIZE, 1);
+    dim3 grid_size = dim3(((size + 1) / 2 + PADDLE_CUDA_THREAD_SIZE - 1) /
+                              PADDLE_CUDA_THREAD_SIZE,
+                          1);
     dim3 block_size = dim3(PADDLE_CUDA_THREAD_SIZE, 1);
     const half* x2 =
         reinterpret_cast<const half*>(x->data<platform::float16>());
@@ -49,7 +50,7 @@ struct SameDimsElemwiseAdd<platform::CUDADeviceContext, platform::float16> {
         reinterpret_cast<const half*>(y->data<platform::float16>());
     half* z2 = reinterpret_cast<half*>(z->data<platform::float16>());
     SameDimsElemwiseAddCUDAKernel<<<
-        gird_size, block_size, 0,
+        grid_size, block_size, 0,
         ctx.template device_context<platform::CUDADeviceContext>().stream()>>>(
         x2, y2, z2, size);
   }
@@ -78,10 +79,10 @@ elementwise_add_grad(const framework::ExecutionContext& ctx,
                      framework::Tensor* dy) {
   dim3 block_size = dim3(PADDLE_CUDA_THREAD_SIZE, 1);
   auto size = x->numel();
-  dim3 gird_size =
+  dim3 grid_size =
       dim3((size + PADDLE_CUDA_THREAD_SIZE - 1) / PADDLE_CUDA_THREAD_SIZE, 1);
   SimpleElemwiseAddGradCUDAKernel<
-      T><<<gird_size, block_size, 0,
+      T><<<grid_size, block_size, 0,
            ctx.template device_context<plat::CUDADeviceContext>().stream()>>>(
       dout->data<T>(), size, dx->mutable_data<T>(ctx.GetPlace()),
       dy->mutable_data<T>(ctx.GetPlace()));
diff --git a/paddle/fluid/operators/elementwise/elementwise_div_op.cu b/paddle/fluid/operators/elementwise/elementwise_div_op.cu
index b1698491180d773dc35d09d8d7ec642d3ed914fe..e31722a2881f2c424052701b84dfbe058340b266 100644
--- a/paddle/fluid/operators/elementwise/elementwise_div_op.cu
+++ b/paddle/fluid/operators/elementwise/elementwise_div_op.cu
@@ -41,8 +41,9 @@ struct SameDimsElemwiseDiv<platform::CUDADeviceContext, platform::float16> {
                   const framework::Tensor* x, const framework::Tensor* y,
                   framework::Tensor* z) {
     auto size = x->numel();
-    dim3 gird_size = dim3(
-        (size / 2 + PADDLE_CUDA_THREAD_SIZE - 1) / PADDLE_CUDA_THREAD_SIZE, 1);
+    dim3 grid_size = dim3(((size + 1) / 2 + PADDLE_CUDA_THREAD_SIZE - 1) /
+                              PADDLE_CUDA_THREAD_SIZE,
+                          1);
     dim3 block_size = dim3(PADDLE_CUDA_THREAD_SIZE, 1);
     const half* x2 =
         reinterpret_cast<const half*>(x->data<platform::float16>());
@@ -50,7 +51,7 @@ struct SameDimsElemwiseDiv<platform::CUDADeviceContext, platform::float16> {
         reinterpret_cast<const half*>(y->data<platform::float16>());
     half* z2 = reinterpret_cast<half*>(z->data<platform::float16>());
     SameDimsElemwiseDivCUDAKernel<<<
-        gird_size, block_size, 0,
+        grid_size, block_size, 0,
         ctx.template device_context<platform::CUDADeviceContext>().stream()>>>(
         x2, y2, z2, size);
   }
@@ -82,10 +83,10 @@ elementwise_div_grad(const framework::ExecutionContext& ctx,
                      framework::Tensor* dy) {
   dim3 block_size = dim3(PADDLE_CUDA_THREAD_SIZE, 1);
   auto size = x->numel();
-  dim3 gird_size =
+  dim3 grid_size =
       dim3((size + PADDLE_CUDA_THREAD_SIZE - 1) / PADDLE_CUDA_THREAD_SIZE, 1);
   SimpleElemwiseDivGradCUDAKernel<
-      T><<<gird_size, block_size, 0,
+      T><<<grid_size, block_size, 0,
           ctx.template device_context<plat::CUDADeviceContext>().stream()>>>(
       x->data<T>(), y->data<T>(), out->data<T>(), dout->data<T>(), size,
       dx->mutable_data<T>(ctx.GetPlace()), dy->mutable_data<T>(ctx.GetPlace()));
diff --git a/paddle/fluid/operators/elementwise/elementwise_mul_op.cu b/paddle/fluid/operators/elementwise/elementwise_mul_op.cu
index 4814cb144f057d4cb76b416e896e24ea227e92a2..8533189f81abbc4ebe0e47cb7272063af5a4aca7 100644
--- a/paddle/fluid/operators/elementwise/elementwise_mul_op.cu
+++ b/paddle/fluid/operators/elementwise/elementwise_mul_op.cu
@@ -41,8 +41,9 @@ struct SameDimsElemwiseMul<platform::CUDADeviceContext, platform::float16> {
                   const framework::Tensor* x, const framework::Tensor* y,
                   framework::Tensor* z) {
     auto size = x->numel();
-    dim3 gird_size = dim3(
-        (size / 2 + PADDLE_CUDA_THREAD_SIZE - 1) / PADDLE_CUDA_THREAD_SIZE, 1);
+    dim3 grid_size = dim3(((size + 1) / 2 + PADDLE_CUDA_THREAD_SIZE - 1) /
+                              PADDLE_CUDA_THREAD_SIZE,
+                          1);
     dim3 block_size = dim3(PADDLE_CUDA_THREAD_SIZE, 1);
     const half* x2 =
         reinterpret_cast<const half*>(x->data<platform::float16>());
@@ -50,7 +51,7 @@ struct SameDimsElemwiseMul<platform::CUDADeviceContext, platform::float16> {
         reinterpret_cast<const half*>(y->data<platform::float16>());
     half* z2 = reinterpret_cast<half*>(z->data<platform::float16>());
     SameDimsElemwiseMulCUDAKernel<<<
-        gird_size, block_size, 0,
+        grid_size, block_size, 0,
         ctx.template device_context<platform::CUDADeviceContext>().stream()>>>(
         x2, y2, z2, size);
   }
@@ -82,10 +83,10 @@ elementwise_mul_grad(const framework::ExecutionContext& ctx,
                      framework::Tensor* dy) {
   dim3 block_size = dim3(PADDLE_CUDA_THREAD_SIZE, 1);
   auto size = x->numel();
-  dim3 gird_size =
+  dim3 grid_size =
       dim3((size + PADDLE_CUDA_THREAD_SIZE - 1) / PADDLE_CUDA_THREAD_SIZE, 1);
   SimpleElemwiseMulGradCUDAKernel<
-      T><<<gird_size, block_size, 0,
+      T><<<grid_size, block_size, 0,
           ctx.template device_context<plat::CUDADeviceContext>().stream()>>>(
       x->data<T>(), y->data<T>(), out->data<T>(), dout->data<T>(), size,
       dx->mutable_data<T>(ctx.GetPlace()), dy->mutable_data<T>(ctx.GetPlace()));
diff --git a/paddle/fluid/operators/elementwise/elementwise_sub_op.cu b/paddle/fluid/operators/elementwise/elementwise_sub_op.cu
index 7ff72028091ed78e7ca5c27d2b8bb362c12fd152..9913927ee3c5901cc4a74cf1ace1e1085d2b8ff3 100644
--- a/paddle/fluid/operators/elementwise/elementwise_sub_op.cu
+++ b/paddle/fluid/operators/elementwise/elementwise_sub_op.cu
@@ -41,8 +41,9 @@ struct SameDimsElemwiseSub<platform::CUDADeviceContext, platform::float16> {
                   const framework::Tensor* x, const framework::Tensor* y,
                   framework::Tensor* z) {
     auto size = x->numel();
-    dim3 gird_size = dim3(
-        (size / 2 + PADDLE_CUDA_THREAD_SIZE - 1) / PADDLE_CUDA_THREAD_SIZE, 1);
+    dim3 grid_size = dim3(((size + 1) / 2 + PADDLE_CUDA_THREAD_SIZE - 1) /
+                              PADDLE_CUDA_THREAD_SIZE,
+                          1);
     dim3 block_size = dim3(PADDLE_CUDA_THREAD_SIZE, 1);
     const half* x2 =
         reinterpret_cast<const half*>(x->data<platform::float16>());
@@ -50,7 +51,7 @@ struct SameDimsElemwiseSub<platform::CUDADeviceContext, platform::float16> {
         reinterpret_cast<const half*>(y->data<platform::float16>());
     half* z2 = reinterpret_cast<half*>(z->data<platform::float16>());
     SameDimsElemwiseSubCUDAKernel<<<
-        gird_size, block_size, 0,
+        grid_size, block_size, 0,
         ctx.template device_context<platform::CUDADeviceContext>().stream()>>>(
         x2, y2, z2, size);
   }
@@ -79,10 +80,10 @@ elementwise_sub_grad(const framework::ExecutionContext& ctx,
                      framework::Tensor* dy) {
   dim3 block_size = dim3(PADDLE_CUDA_THREAD_SIZE, 1);
   auto size = x->numel();
-  dim3 gird_size =
+  dim3 grid_size =
       dim3((size + PADDLE_CUDA_THREAD_SIZE - 1) / PADDLE_CUDA_THREAD_SIZE, 1);
   SimpleElemwiseSubGradCUDAKernel<
-      T><<<gird_size, block_size, 0,
+      T><<<grid_size, block_size, 0,
           ctx.template device_context<plat::CUDADeviceContext>().stream()>>>(
       dout->data<T>(), size, dx->mutable_data<T>(ctx.GetPlace()),
       dy->mutable_data<T>(ctx.GetPlace()));
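
Note on the grid-size change: the float16 forward kernels are launched with one thread per *pair* of half elements, so the launch must cover ceil(size / 2) threads, i.e. (size + 1) / 2. The old size / 2 (floor) under-covers odd-sized tensors; for example, with a 512-thread block and size = 1025 the old formula gives ceil(512 / 512) = 1 block (512 threads for 513 pairs), so the last element is never written, while the new formula gives 2 blocks. The sketch below only illustrates this ceiling-division launch math; HalfPairAddKernel, LaunchHalfPairAdd and kThreadsPerBlock are made-up names standing in for the Paddle kernel and PADDLE_CUDA_THREAD_SIZE, not the actual implementation.

```cuda
#include <cuda_fp16.h>

#include <cstdint>

// Stand-in for PADDLE_CUDA_THREAD_SIZE (illustrative value).
constexpr int kThreadsPerBlock = 512;

// Each thread adds one pair of half values; the last thread also covers the
// odd tail element when size is not divisible by 2.
__global__ void HalfPairAddKernel(const half* x, const half* y, half* z,
                                  int64_t size) {
  int64_t pair = static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
  int64_t i = pair * 2;  // first element owned by this thread
  if (i + 1 < size) {
    // Add in float for portability (native half2 intrinsics need sm_53+).
    z[i] = __float2half(__half2float(x[i]) + __half2float(y[i]));
    z[i + 1] = __float2half(__half2float(x[i + 1]) + __half2float(y[i + 1]));
  } else if (i < size) {  // odd-sized tensor: single trailing element
    z[i] = __float2half(__half2float(x[i]) + __half2float(y[i]));
  }
}

// Ceiling division twice: pairs = ceil(size / 2), blocks = ceil(pairs / block).
void LaunchHalfPairAdd(const half* x, const half* y, half* z, int64_t size,
                       cudaStream_t stream) {
  int64_t pairs = (size + 1) / 2;
  dim3 grid_size((pairs + kThreadsPerBlock - 1) / kThreadsPerBlock, 1);
  dim3 block_size(kThreadsPerBlock, 1);
  HalfPairAddKernel<<<grid_size, block_size, 0, stream>>>(x, y, z, size);
}
```

The backward kernels launch one thread per element, so their grid only needed the gird_size -> grid_size rename, not a formula change.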