diff --git a/paddle/fluid/operators/cast_op.cu b/paddle/fluid/operators/cast_op.cu
index 1ac110b3cafd6bfd9da29daaebb65df570a02cb0..0beb2291060169628a7d736a21d728993a2bc0e5 100644
--- a/paddle/fluid/operators/cast_op.cu
+++ b/paddle/fluid/operators/cast_op.cu
@@ -40,7 +40,8 @@ __global__ void VecCastCUDAKernel(const InT* in, const int64_t N, OutT* out) {
   int64_t idx = blockDim.x * blockIdx.x + threadIdx.x;
   using LoadT = AlignedVector<InT, VecSize>;
   using StoreT = AlignedVector<OutT, VecSize>;
-  for (int i = idx * VecSize; i < N; i += blockDim.x * gridDim.x * VecSize) {
+  for (int64_t i = idx * VecSize; i < N;
+       i += blockDim.x * gridDim.x * VecSize) {
     InT in_vec[VecSize];
     LoadT* in_value = reinterpret_cast<LoadT*>(&in_vec);
     *in_value = *reinterpret_cast<const LoadT*>(&in[i]);
diff --git a/paddle/fluid/platform/gpu_launch_config.h b/paddle/fluid/platform/gpu_launch_config.h
index 4da91b4e764a5285b005ebc459c4dfa4e52df9cd..a82262419066fab0c1a58d3b6781bc765fa1a4c6 100644
--- a/paddle/fluid/platform/gpu_launch_config.h
+++ b/paddle/fluid/platform/gpu_launch_config.h
@@ -41,7 +41,7 @@ struct GpuLaunchConfig {
 };
 
 inline GpuLaunchConfig GetGpuLaunchConfig1D(
-    const platform::CUDADeviceContext& context, int element_count,
+    const platform::CUDADeviceContext& context, int64_t element_count,
 #ifdef PADDLE_WITH_HIP
     // HIP will throw GPU memory access fault if threads > 256
     int max_threads = 256) {