From afbc636705d583d4e5ea9a7570b6a754f7f0b8c4 Mon Sep 17 00:00:00 2001
From: Zhang Ting
Date: Mon, 11 Jan 2021 17:04:11 +0800
Subject: [PATCH] [cherry-pick]add cast cuda kernel (#29352) #30263

add cast cuda kernel
cherry-pick #29352
---
Review notes (this section is ignored by `git am`; it is not part of the
commit message):

* CastCUDAKernel<InT, OutT> writes static_cast<OutT>(in[index]) to out[index]
  for each index produced by CUDA_KERNEL_LOOP(index, N).
* CastOpFunctor<platform::CUDADeviceContext, InT> stores the input tensor, the
  output tensor and the device context. apply<OutT>() reads in_->data<InT>(),
  obtains the output pointer from out_->mutable_data<OutT>(ctx_.GetPlace()),
  and launches CastCUDAKernel over in_->numel() elements on ctx_.stream(),
  using the configuration returned by
  platform::GetGpuLaunchConfig1D(ctx_, size).
* NOTE(review): this copy of the patch had lost every angle-bracketed span
  (template parameter lists, explicit template arguments, and the <<<...>>>
  launch arguments) as well as its line breaks. They have been restored so
  that the hunk again contains 33 added lines; confirm the restored tokens,
  in particular the launch arguments and the CastOpFunctor specialization
  arguments, against upstream PR #29352 before applying.

 paddle/fluid/operators/cast_op.cu | 33 +++++++++++++++++++++++++++++++
 1 file changed, 33 insertions(+)

diff --git a/paddle/fluid/operators/cast_op.cu b/paddle/fluid/operators/cast_op.cu
index f71af20576..55cc5a675b 100644
--- a/paddle/fluid/operators/cast_op.cu
+++ b/paddle/fluid/operators/cast_op.cu
@@ -14,6 +14,39 @@ limitations under the License. */
 
 #include "paddle/fluid/operators/cast_op.h"
 #include "paddle/fluid/platform/float16.h"
+#include "paddle/fluid/platform/gpu_launch_config.h"
+
+namespace paddle {
+namespace operators {
+
+template <typename InT, typename OutT>
+__global__ void CastCUDAKernel(const InT* in, const int64_t N, OutT* out) {
+  CUDA_KERNEL_LOOP(index, N) { out[index] = static_cast<OutT>(in[index]); }
+}
+
+template <typename InT>
+struct CastOpFunctor<platform::CUDADeviceContext, InT> {
+  const framework::Tensor* in_;
+  framework::Tensor* out_;
+  const platform::CUDADeviceContext& ctx_;
+  CastOpFunctor(const framework::Tensor* in, framework::Tensor* out,
+                const platform::CUDADeviceContext& ctx)
+      : in_(in), out_(out), ctx_(ctx) {}
+
+  template <typename OutT>
+  void apply() const {
+    auto* in = in_->data<InT>();
+    auto size = in_->numel();
+    auto* out = out_->mutable_data<OutT>(ctx_.GetPlace());
+    platform::GpuLaunchConfig config =
+        platform::GetGpuLaunchConfig1D(ctx_, size);
+    CastCUDAKernel<InT, OutT><<<config.block_per_grid, config.thread_per_block,
+                                0, ctx_.stream()>>>(in, size, out);
+  }
+};
+
+}  // namespace operators
+}  // namespace paddle
 
 namespace ops = paddle::operators;
 
-- 
GitLab