From 9814f89551e2133c6733352f6445d4d668da6f63 Mon Sep 17 00:00:00 2001 From: Zeng Jinle <32832641+sneaxiy@users.noreply.github.com> Date: Fri, 8 Oct 2021 10:47:13 +0800 Subject: [PATCH] fix cast cuda implementation (#36266) --- paddle/fluid/operators/cast_op.cu | 64 ++++++++++++++++--------------- 1 file changed, 33 insertions(+), 31 deletions(-) diff --git a/paddle/fluid/operators/cast_op.cu b/paddle/fluid/operators/cast_op.cu index 06300817e0..601735c2f1 100644 --- a/paddle/fluid/operators/cast_op.cu +++ b/paddle/fluid/operators/cast_op.cu @@ -47,12 +47,12 @@ __global__ void CastCUDAKernel(const InT* in, const int64_t N, OutT* out) { } template -struct CastOpFunctor { +struct CastCUDAOpFunctor { const framework::Tensor* in_; framework::Tensor* out_; const platform::CUDADeviceContext& ctx_; - CastOpFunctor(const framework::Tensor* in, framework::Tensor* out, - const platform::CUDADeviceContext& ctx) + CastCUDAOpFunctor(const framework::Tensor* in, framework::Tensor* out, + const platform::CUDADeviceContext& ctx) : in_(in), out_(out), ctx_(ctx) {} template @@ -75,6 +75,21 @@ struct CastOpFunctor { } }; +template +class CastCUDAOpKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& context) const override { + auto* in = context.Input("X"); + auto* out = context.Output("Out"); + framework::VisitDataType( + static_cast( + context.Attr("out_dtype")), + CastCUDAOpFunctor( + in, out, + context.template device_context())); + } +}; + } // namespace operators } // namespace paddle @@ -82,34 +97,21 @@ namespace ops = paddle::operators; #ifdef PADDLE_WITH_HIP REGISTER_OP_CUDA_KERNEL( - cast, ops::CastOpKernel, - ops::CastOpKernel, - ops::CastOpKernel, - ops::CastOpKernel, - ops::CastOpKernel, - ops::CastOpKernel, - ops::CastOpKernel, - ops::CastOpKernel, - ops::CastOpKernel>, - ops::CastOpKernel>); + cast, ops::CastCUDAOpKernel, ops::CastCUDAOpKernel, + ops::CastCUDAOpKernel, ops::CastCUDAOpKernel, + ops::CastCUDAOpKernel, ops::CastCUDAOpKernel, + ops::CastCUDAOpKernel, + ops::CastCUDAOpKernel, + ops::CastCUDAOpKernel>, + ops::CastCUDAOpKernel>); #else REGISTER_OP_CUDA_KERNEL( - cast, ops::CastOpKernel, - ops::CastOpKernel, - ops::CastOpKernel, - ops::CastOpKernel, - ops::CastOpKernel, - ops::CastOpKernel, - ops::CastOpKernel, - ops::CastOpKernel, - ops::CastOpKernel, - ops::CastOpKernel>, - ops::CastOpKernel>); + cast, ops::CastCUDAOpKernel, ops::CastCUDAOpKernel, + ops::CastCUDAOpKernel, ops::CastCUDAOpKernel, + ops::CastCUDAOpKernel, ops::CastCUDAOpKernel, + ops::CastCUDAOpKernel, + ops::CastCUDAOpKernel, + ops::CastCUDAOpKernel, + ops::CastCUDAOpKernel>, + ops::CastCUDAOpKernel>); #endif -- GitLab