From 15cbf81b87ca03fcc77a6b82e0a72cc913c6d163 Mon Sep 17 00:00:00 2001 From: sneaxiy <32832641+sneaxiy@users.noreply.github.com> Date: Thu, 30 Dec 2021 11:39:47 +0800 Subject: [PATCH] try to expose cast with ptr function (#38598) --- paddle/pten/kernels/gpu/cast_kernel.cu | 21 ++++++++++++++------- 1 file changed, 14 insertions(+), 7 deletions(-) diff --git a/paddle/pten/kernels/gpu/cast_kernel.cu b/paddle/pten/kernels/gpu/cast_kernel.cu index 8481666416..e413a38d5e 100644 --- a/paddle/pten/kernels/gpu/cast_kernel.cu +++ b/paddle/pten/kernels/gpu/cast_kernel.cu @@ -54,13 +54,10 @@ __global__ void CastCUDAKernel(const InT* in, const int64_t N, OutT* out) { } template -void CastCUDAKernelImpl(const GPUContext& dev_ctx, - const DenseTensor& x, - DenseTensor* out) { - auto* in_data = x.data(); - auto size = x.numel(); - auto* out_data = out->mutable_data(); - +void CastCUDAKernelImplWithPtr(const GPUContext& dev_ctx, + const InT* in_data, + OutT* out_data, + int64_t size) { paddle::platform::GpuLaunchConfig config = paddle::platform::GetGpuLaunchConfig1D(dev_ctx, size); int vec_size = paddle::platform::GetVectorizedSize(out_data); @@ -78,6 +75,16 @@ void CastCUDAKernelImpl(const GPUContext& dev_ctx, } } +template +void CastCUDAKernelImpl(const GPUContext& dev_ctx, + const DenseTensor& x, + DenseTensor* out) { + auto* in_data = x.data(); + auto size = x.numel(); + auto* out_data = out->mutable_data(); + CastCUDAKernelImplWithPtr(dev_ctx, in_data, out_data, size); +} + template void Cast(const ContextT& dev_ctx, const DenseTensor& x, -- GitLab