diff --git a/paddle/fluid/operators/temporal_shift_op.cu b/paddle/fluid/operators/temporal_shift_op.cu index a292f16fe20d10aa683dfc3e94eae4f7b6125a3c..b61d9aeff7d4c2b92c4861444e8c4d1bb5d9d1cc 100644 --- a/paddle/fluid/operators/temporal_shift_op.cu +++ b/paddle/fluid/operators/temporal_shift_op.cu @@ -11,6 +11,7 @@ #include "paddle/fluid/operators/temporal_shift_op.h" #include "paddle/fluid/platform/cuda_primitives.h" +#include "paddle/fluid/platform/gpu_launch_config.h" namespace paddle { namespace operators { @@ -112,11 +113,11 @@ class TemporalShiftOpCUDAKernel : public framework::OpKernel { T* output_data = output->mutable_data({nt, c, h, w}, ctx.GetPlace()); int pixelNum = nt * chw; - int grid_dim = (pixelNum + 512 - 1) / 512; - grid_dim = grid_dim > 8 ? 8 : grid_dim; + platform::GpuLaunchConfig config = + platform::GetGpuLaunchConfig1D(ctx.cuda_device_context(), pixelNum); - KeTemporalShiftFw< - T><<>>( + KeTemporalShiftFw<<>>( input_data, output_data, ntchw, tchw, chw, hw, w, t, c, shift_ratio); } }; @@ -148,11 +149,11 @@ class TemporalShiftGradOpCUDAKernel : public framework::OpKernel { static_cast(0)); int pixelNum = nt * chw; - int grid_dim = (pixelNum + 512 - 1) / 512; - grid_dim = grid_dim > 8 ? 8 : grid_dim; + platform::GpuLaunchConfig config = + platform::GetGpuLaunchConfig1D(ctx.cuda_device_context(), pixelNum); - KeTemporalShiftBw< - T><<>>( + KeTemporalShiftBw<<>>( output_grad_data, input_grad_data, ntchw, tchw, chw, hw, w, t, c, shift_ratio); }