diff --git a/paddle/pten/kernels/gpu/cast_kernel.cu b/paddle/pten/kernels/gpu/cast_kernel.cu
index 84816664164fab910b016a8315c68684360a49ac..e413a38d5e01e39e58674d326ecf661708fafae5 100644
--- a/paddle/pten/kernels/gpu/cast_kernel.cu
+++ b/paddle/pten/kernels/gpu/cast_kernel.cu
@@ -54,13 +54,10 @@ __global__ void CastCUDAKernel(const InT* in, const int64_t N, OutT* out) {
 }
 
 template
-void CastCUDAKernelImpl(const GPUContext& dev_ctx,
-                        const DenseTensor& x,
-                        DenseTensor* out) {
-  auto* in_data = x.data();
-  auto size = x.numel();
-  auto* out_data = out->mutable_data();
-
+void CastCUDAKernelImplWithPtr(const GPUContext& dev_ctx,
+                               const InT* in_data,
+                               OutT* out_data,
+                               int64_t size) {
   paddle::platform::GpuLaunchConfig config =
       paddle::platform::GetGpuLaunchConfig1D(dev_ctx, size);
   int vec_size = paddle::platform::GetVectorizedSize(out_data);
@@ -78,6 +75,16 @@ void CastCUDAKernelImpl(const GPUContext& dev_ctx,
   }
 }
 
+template
+void CastCUDAKernelImpl(const GPUContext& dev_ctx,
+                        const DenseTensor& x,
+                        DenseTensor* out) {
+  auto* in_data = x.data();
+  auto size = x.numel();
+  auto* out_data = out->mutable_data();
+  CastCUDAKernelImplWithPtr(dev_ctx, in_data, out_data, size);
+}
+
 template
 void Cast(const ContextT& dev_ctx,
           const DenseTensor& x,