From 30d9589afe26f9a08979b3d4506b2f2a802f0236 Mon Sep 17 00:00:00 2001
From: Zhang Ting
Date: Fri, 11 Dec 2020 18:08:53 +0800
Subject: [PATCH] add cast cuda kernel (#29352)

---
 paddle/fluid/operators/cast_op.cu | 33 +++++++++++++++++++++++++++++++
 1 file changed, 33 insertions(+)

diff --git a/paddle/fluid/operators/cast_op.cu b/paddle/fluid/operators/cast_op.cu
index f71af20576..55cc5a675b 100644
--- a/paddle/fluid/operators/cast_op.cu
+++ b/paddle/fluid/operators/cast_op.cu
@@ -14,6 +14,39 @@ limitations under the License. */
 
 #include "paddle/fluid/operators/cast_op.h"
 #include "paddle/fluid/platform/float16.h"
+#include "paddle/fluid/platform/gpu_launch_config.h"
+
+namespace paddle {
+namespace operators {
+
+template <typename InT, typename OutT>
+__global__ void CastCUDAKernel(const InT* in, const int64_t N, OutT* out) {
+  CUDA_KERNEL_LOOP(index, N) { out[index] = static_cast<OutT>(in[index]); }
+}
+
+template <typename InT>
+struct CastOpFunctor<platform::CUDADeviceContext, InT> {
+  const framework::Tensor* in_;
+  framework::Tensor* out_;
+  const platform::CUDADeviceContext& ctx_;
+  CastOpFunctor(const framework::Tensor* in, framework::Tensor* out,
+                const platform::CUDADeviceContext& ctx)
+      : in_(in), out_(out), ctx_(ctx) {}
+
+  template <typename OutT>
+  void apply() const {
+    auto* in = in_->data<InT>();
+    auto size = in_->numel();
+    auto* out = out_->mutable_data<OutT>(ctx_.GetPlace());
+    platform::GpuLaunchConfig config =
+        platform::GetGpuLaunchConfig1D(ctx_, size);
+    CastCUDAKernel<InT, OutT><<<config.block_per_grid, config.thread_per_block,
+                                0, ctx_.stream()>>>(in, size, out);
+  }
+};
+
+}  // namespace operators
+}  // namespace paddle
 
 namespace ops = paddle::operators;
 
-- 
GitLab