diff --git a/paddle/fluid/operators/cast_op.cu b/paddle/fluid/operators/cast_op.cu index 06300817e0a1288536620c668654108cd03afee9..601735c2f148adc14c94e786c7707b920f00fd8b 100644 --- a/paddle/fluid/operators/cast_op.cu +++ b/paddle/fluid/operators/cast_op.cu @@ -47,12 +47,12 @@ __global__ void CastCUDAKernel(const InT* in, const int64_t N, OutT* out) { } template -struct CastOpFunctor { +struct CastCUDAOpFunctor { const framework::Tensor* in_; framework::Tensor* out_; const platform::CUDADeviceContext& ctx_; - CastOpFunctor(const framework::Tensor* in, framework::Tensor* out, - const platform::CUDADeviceContext& ctx) + CastCUDAOpFunctor(const framework::Tensor* in, framework::Tensor* out, + const platform::CUDADeviceContext& ctx) : in_(in), out_(out), ctx_(ctx) {} template @@ -75,6 +75,21 @@ struct CastOpFunctor { } }; +template +class CastCUDAOpKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& context) const override { + auto* in = context.Input("X"); + auto* out = context.Output("Out"); + framework::VisitDataType( + static_cast( + context.Attr("out_dtype")), + CastCUDAOpFunctor( + in, out, + context.template device_context())); + } +}; + } // namespace operators } // namespace paddle @@ -82,34 +97,21 @@ namespace ops = paddle::operators; #ifdef PADDLE_WITH_HIP REGISTER_OP_CUDA_KERNEL( - cast, ops::CastOpKernel, - ops::CastOpKernel, - ops::CastOpKernel, - ops::CastOpKernel, - ops::CastOpKernel, - ops::CastOpKernel, - ops::CastOpKernel, - ops::CastOpKernel, - ops::CastOpKernel>, - ops::CastOpKernel>); + cast, ops::CastCUDAOpKernel, ops::CastCUDAOpKernel, + ops::CastCUDAOpKernel, ops::CastCUDAOpKernel, + ops::CastCUDAOpKernel, ops::CastCUDAOpKernel, + ops::CastCUDAOpKernel, + ops::CastCUDAOpKernel, + ops::CastCUDAOpKernel>, + ops::CastCUDAOpKernel>); #else REGISTER_OP_CUDA_KERNEL( - cast, ops::CastOpKernel, - ops::CastOpKernel, - ops::CastOpKernel, - ops::CastOpKernel, - ops::CastOpKernel, - ops::CastOpKernel, - ops::CastOpKernel, - ops::CastOpKernel, - ops::CastOpKernel, - ops::CastOpKernel>, - ops::CastOpKernel>); + cast, ops::CastCUDAOpKernel, ops::CastCUDAOpKernel, + ops::CastCUDAOpKernel, ops::CastCUDAOpKernel, + ops::CastCUDAOpKernel, ops::CastCUDAOpKernel, + ops::CastCUDAOpKernel, + ops::CastCUDAOpKernel, + ops::CastCUDAOpKernel, + ops::CastCUDAOpKernel>, + ops::CastCUDAOpKernel>); #endif