提交 9171f737 编写于 作者: D danleifeng 提交者: gongweibao

fix fp16 grid_size for size=1; test=develop (#20812)

上级 cd1c4043
......@@ -40,8 +40,9 @@ struct SameDimsElemwiseAdd<platform::CUDADeviceContext, platform::float16> {
const framework::Tensor* x, const framework::Tensor* y,
framework::Tensor* z) {
auto size = x->numel();
dim3 gird_size = dim3(
(size / 2 + PADDLE_CUDA_THREAD_SIZE - 1) / PADDLE_CUDA_THREAD_SIZE, 1);
dim3 grid_size = dim3(((size + 1) / 2 + PADDLE_CUDA_THREAD_SIZE - 1) /
PADDLE_CUDA_THREAD_SIZE,
1);
dim3 block_size = dim3(PADDLE_CUDA_THREAD_SIZE, 1);
const half* x2 =
reinterpret_cast<const half*>(x->data<platform::float16>());
......@@ -49,7 +50,7 @@ struct SameDimsElemwiseAdd<platform::CUDADeviceContext, platform::float16> {
reinterpret_cast<const half*>(y->data<platform::float16>());
half* z2 = reinterpret_cast<half*>(z->data<platform::float16>());
SameDimsElemwiseAddCUDAKernel<<<
gird_size, block_size, 0,
grid_size, block_size, 0,
ctx.template device_context<platform::CUDADeviceContext>().stream()>>>(
x2, y2, z2, size);
}
......@@ -78,10 +79,10 @@ elementwise_add_grad(const framework::ExecutionContext& ctx,
framework::Tensor* dy) {
dim3 block_size = dim3(PADDLE_CUDA_THREAD_SIZE, 1);
auto size = x->numel();
dim3 gird_size =
dim3 grid_size =
dim3((size + PADDLE_CUDA_THREAD_SIZE - 1) / PADDLE_CUDA_THREAD_SIZE, 1);
SimpleElemwiseAddGradCUDAKernel<
T><<<gird_size, block_size, 0,
T><<<grid_size, block_size, 0,
ctx.template device_context<plat::CUDADeviceContext>().stream()>>>(
dout->data<T>(), size, dx->mutable_data<T>(ctx.GetPlace()),
dy->mutable_data<T>(ctx.GetPlace()));
......
......@@ -41,8 +41,9 @@ struct SameDimsElemwiseDiv<platform::CUDADeviceContext, platform::float16> {
const framework::Tensor* x, const framework::Tensor* y,
framework::Tensor* z) {
auto size = x->numel();
dim3 gird_size = dim3(
(size / 2 + PADDLE_CUDA_THREAD_SIZE - 1) / PADDLE_CUDA_THREAD_SIZE, 1);
dim3 grid_size = dim3(((size + 1) / 2 + PADDLE_CUDA_THREAD_SIZE - 1) /
PADDLE_CUDA_THREAD_SIZE,
1);
dim3 block_size = dim3(PADDLE_CUDA_THREAD_SIZE, 1);
const half* x2 =
reinterpret_cast<const half*>(x->data<platform::float16>());
......@@ -50,7 +51,7 @@ struct SameDimsElemwiseDiv<platform::CUDADeviceContext, platform::float16> {
reinterpret_cast<const half*>(y->data<platform::float16>());
half* z2 = reinterpret_cast<half*>(z->data<platform::float16>());
SameDimsElemwiseDivCUDAKernel<<<
gird_size, block_size, 0,
grid_size, block_size, 0,
ctx.template device_context<platform::CUDADeviceContext>().stream()>>>(
x2, y2, z2, size);
}
......@@ -82,10 +83,10 @@ elementwise_div_grad(const framework::ExecutionContext& ctx,
framework::Tensor* dy) {
dim3 block_size = dim3(PADDLE_CUDA_THREAD_SIZE, 1);
auto size = x->numel();
dim3 gird_size =
dim3 grid_size =
dim3((size + PADDLE_CUDA_THREAD_SIZE - 1) / PADDLE_CUDA_THREAD_SIZE, 1);
SimpleElemwiseDivGradCUDAKernel<
T><<<gird_size, block_size, 0,
T><<<grid_size, block_size, 0,
ctx.template device_context<plat::CUDADeviceContext>().stream()>>>(
x->data<T>(), y->data<T>(), out->data<T>(), dout->data<T>(), size,
dx->mutable_data<T>(ctx.GetPlace()), dy->mutable_data<T>(ctx.GetPlace()));
......
......@@ -41,8 +41,9 @@ struct SameDimsElemwiseMul<platform::CUDADeviceContext, platform::float16> {
const framework::Tensor* x, const framework::Tensor* y,
framework::Tensor* z) {
auto size = x->numel();
dim3 gird_size = dim3(
(size / 2 + PADDLE_CUDA_THREAD_SIZE - 1) / PADDLE_CUDA_THREAD_SIZE, 1);
dim3 grid_size = dim3(((size + 1) / 2 + PADDLE_CUDA_THREAD_SIZE - 1) /
PADDLE_CUDA_THREAD_SIZE,
1);
dim3 block_size = dim3(PADDLE_CUDA_THREAD_SIZE, 1);
const half* x2 =
reinterpret_cast<const half*>(x->data<platform::float16>());
......@@ -50,7 +51,7 @@ struct SameDimsElemwiseMul<platform::CUDADeviceContext, platform::float16> {
reinterpret_cast<const half*>(y->data<platform::float16>());
half* z2 = reinterpret_cast<half*>(z->data<platform::float16>());
SameDimsElemwiseMulCUDAKernel<<<
gird_size, block_size, 0,
grid_size, block_size, 0,
ctx.template device_context<platform::CUDADeviceContext>().stream()>>>(
x2, y2, z2, size);
}
......@@ -82,10 +83,10 @@ elementwise_mul_grad(const framework::ExecutionContext& ctx,
framework::Tensor* dy) {
dim3 block_size = dim3(PADDLE_CUDA_THREAD_SIZE, 1);
auto size = x->numel();
dim3 gird_size =
dim3 grid_size =
dim3((size + PADDLE_CUDA_THREAD_SIZE - 1) / PADDLE_CUDA_THREAD_SIZE, 1);
SimpleElemwiseMulGradCUDAKernel<
T><<<gird_size, block_size, 0,
T><<<grid_size, block_size, 0,
ctx.template device_context<plat::CUDADeviceContext>().stream()>>>(
x->data<T>(), y->data<T>(), out->data<T>(), dout->data<T>(), size,
dx->mutable_data<T>(ctx.GetPlace()), dy->mutable_data<T>(ctx.GetPlace()));
......
......@@ -41,8 +41,9 @@ struct SameDimsElemwiseSub<platform::CUDADeviceContext, platform::float16> {
const framework::Tensor* x, const framework::Tensor* y,
framework::Tensor* z) {
auto size = x->numel();
dim3 gird_size = dim3(
(size / 2 + PADDLE_CUDA_THREAD_SIZE - 1) / PADDLE_CUDA_THREAD_SIZE, 1);
dim3 grid_size = dim3(((size + 1) / 2 + PADDLE_CUDA_THREAD_SIZE - 1) /
PADDLE_CUDA_THREAD_SIZE,
1);
dim3 block_size = dim3(PADDLE_CUDA_THREAD_SIZE, 1);
const half* x2 =
reinterpret_cast<const half*>(x->data<platform::float16>());
......@@ -50,7 +51,7 @@ struct SameDimsElemwiseSub<platform::CUDADeviceContext, platform::float16> {
reinterpret_cast<const half*>(y->data<platform::float16>());
half* z2 = reinterpret_cast<half*>(z->data<platform::float16>());
SameDimsElemwiseSubCUDAKernel<<<
gird_size, block_size, 0,
grid_size, block_size, 0,
ctx.template device_context<platform::CUDADeviceContext>().stream()>>>(
x2, y2, z2, size);
}
......@@ -79,10 +80,10 @@ elementwise_sub_grad(const framework::ExecutionContext& ctx,
framework::Tensor* dy) {
dim3 block_size = dim3(PADDLE_CUDA_THREAD_SIZE, 1);
auto size = x->numel();
dim3 gird_size =
dim3 grid_size =
dim3((size + PADDLE_CUDA_THREAD_SIZE - 1) / PADDLE_CUDA_THREAD_SIZE, 1);
SimpleElemwiseSubGradCUDAKernel<
T><<<gird_size, block_size, 0,
T><<<grid_size, block_size, 0,
ctx.template device_context<plat::CUDADeviceContext>().stream()>>>(
dout->data<T>(), size, dx->mutable_data<T>(ctx.GetPlace()),
dy->mutable_data<T>(ctx.GetPlace()));
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册