diff --git a/paddle/fluid/operators/grid_sampler_op.cu b/paddle/fluid/operators/grid_sampler_op.cu index 95bf96073bdd21ddcadd5e447ba38ecb8dd21b83..a227a8e312765b4311314ea884f2c32443924fbc 100644 --- a/paddle/fluid/operators/grid_sampler_op.cu +++ b/paddle/fluid/operators/grid_sampler_op.cu @@ -17,6 +17,7 @@ limitations under the License. */ #include "paddle/fluid/operators/grid_sampler_op.h" #include "paddle/fluid/platform/device/gpu/gpu_device_function.h" #include "paddle/fluid/platform/device/gpu/gpu_info.h" +#include "paddle/fluid/platform/device/gpu/gpu_launch_config.h" #include "paddle/fluid/platform/device/gpu/gpu_primitives.h" namespace paddle { @@ -292,15 +293,12 @@ class GridSampleOpCUDAKernel : public framework::OpKernel { auto* output_data = output->mutable_data(ctx.GetPlace()); VLOG(3) << "out dims: " << output->dims()[0] << "; " << output->dims()[1] << "; " << output->dims()[2] << "; " << output->dims()[3]; - phi::funcs::SetConstant()( - dev_ctx, output, static_cast(0)); int count = static_cast(n * out_h * out_w); auto cu_stream = dev_ctx.stream(); - int block_size = 512; - int grid_size = (count + block_size - 1) / block_size; - VLOG(3) << "cuda launch - grid dims: " << grid_size << "; block dims" - << block_size; - grid_sample_cuda_kernel<<>>( + platform::GpuLaunchConfig config = + platform::GetGpuLaunchConfig1D(dev_ctx, count); + grid_sample_cuda_kernel< + T><<>>( count, n, c, out_h, out_w, in_h, in_w, input->data(), grid->data(), output_data, mode, padding_mode, align_corners); } @@ -467,19 +465,14 @@ class GridSampleGradOpCUDAKernel : public framework::OpKernel { if (ctx.HasOutput(framework::GradVarName("Grid"))) { auto* grid_grad = ctx.Output(framework::GradVarName("Grid")); grid_grad_data = grid_grad->mutable_data(ctx.GetPlace()); - phi::funcs::SetConstant()( - ctx.template device_context(), - grid_grad, static_cast(0)); } int count = static_cast(n * out_h * out_w); auto cu_stream = dev_ctx.stream(); - int block_size = 512; - int grid_size = (count + block_size - 1) / block_size; - VLOG(3) << "cuda launch grad kernel - grid dims: " << grid_size - << "; block dims" << block_size << "; count: " << count; + platform::GpuLaunchConfig config = + platform::GetGpuLaunchConfig1D(dev_ctx, count); grid_sampler_cuda_backward_kernel< - T><<>>( + T><<>>( count, output_grad->data(), input->data(), grid->data(), n, c, out_h, out_w, in_h, in_w, input_grad->data(), grid_grad_data, mode, padding_mode, align_corners);