未验证 提交 9814f895 编写于 作者: Z Zeng Jinle 提交者: GitHub

fix cast cuda implementation (#36266)

上级 730dcaf4
...@@ -47,12 +47,12 @@ __global__ void CastCUDAKernel(const InT* in, const int64_t N, OutT* out) { ...@@ -47,12 +47,12 @@ __global__ void CastCUDAKernel(const InT* in, const int64_t N, OutT* out) {
} }
template <typename InT> template <typename InT>
struct CastOpFunctor<platform::CUDADeviceContext, InT> { struct CastCUDAOpFunctor {
const framework::Tensor* in_; const framework::Tensor* in_;
framework::Tensor* out_; framework::Tensor* out_;
const platform::CUDADeviceContext& ctx_; const platform::CUDADeviceContext& ctx_;
CastOpFunctor(const framework::Tensor* in, framework::Tensor* out, CastCUDAOpFunctor(const framework::Tensor* in, framework::Tensor* out,
const platform::CUDADeviceContext& ctx) const platform::CUDADeviceContext& ctx)
: in_(in), out_(out), ctx_(ctx) {} : in_(in), out_(out), ctx_(ctx) {}
template <typename OutT> template <typename OutT>
...@@ -75,6 +75,21 @@ struct CastOpFunctor<platform::CUDADeviceContext, InT> { ...@@ -75,6 +75,21 @@ struct CastOpFunctor<platform::CUDADeviceContext, InT> {
} }
}; };
template <typename InT>
class CastCUDAOpKernel : public framework::OpKernel<InT> {
 public:
  // Reads input tensor "X" (element type InT) and writes output "Out",
  // cast to the element type named by the integer attribute "out_dtype".
  void Compute(const framework::ExecutionContext& context) const override {
    const auto* src = context.Input<framework::Tensor>("X");
    auto* dst = context.Output<framework::Tensor>("Out");
    const auto& dev_ctx =
        context.template device_context<platform::CUDADeviceContext>();
    // The target dtype is only known at runtime; VisitDataType dispatches it
    // to the templated operator() of the functor, which instantiates the
    // CUDA cast for each possible OutT.
    auto out_dtype = static_cast<framework::proto::VarType::Type>(
        context.Attr<int>("out_dtype"));
    framework::VisitDataType(out_dtype,
                             CastCUDAOpFunctor<InT>(src, dst, dev_ctx));
  }
};
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
...@@ -82,34 +97,21 @@ namespace ops = paddle::operators; ...@@ -82,34 +97,21 @@ namespace ops = paddle::operators;
#ifdef PADDLE_WITH_HIP #ifdef PADDLE_WITH_HIP
REGISTER_OP_CUDA_KERNEL( REGISTER_OP_CUDA_KERNEL(
cast, ops::CastOpKernel<paddle::platform::CUDADeviceContext, float>, cast, ops::CastCUDAOpKernel<float>, ops::CastCUDAOpKernel<double>,
ops::CastOpKernel<paddle::platform::CUDADeviceContext, double>, ops::CastCUDAOpKernel<int>, ops::CastCUDAOpKernel<int64_t>,
ops::CastOpKernel<paddle::platform::CUDADeviceContext, int>, ops::CastCUDAOpKernel<int16_t>, ops::CastCUDAOpKernel<bool>,
ops::CastOpKernel<paddle::platform::CUDADeviceContext, int64_t>, ops::CastCUDAOpKernel<uint8_t>,
ops::CastOpKernel<paddle::platform::CUDADeviceContext, int16_t>, ops::CastCUDAOpKernel<paddle::platform::float16>,
ops::CastOpKernel<paddle::platform::CUDADeviceContext, bool>, ops::CastCUDAOpKernel<paddle::platform::complex<float>>,
ops::CastOpKernel<paddle::platform::CUDADeviceContext, uint8_t>, ops::CastCUDAOpKernel<paddle::platform::complex<double>>);
ops::CastOpKernel<paddle::platform::CUDADeviceContext,
paddle::platform::float16>,
ops::CastOpKernel<paddle::platform::CUDADeviceContext,
paddle::platform::complex<float>>,
ops::CastOpKernel<paddle::platform::CUDADeviceContext,
paddle::platform::complex<double>>);
#else #else
REGISTER_OP_CUDA_KERNEL( REGISTER_OP_CUDA_KERNEL(
cast, ops::CastOpKernel<paddle::platform::CUDADeviceContext, float>, cast, ops::CastCUDAOpKernel<float>, ops::CastCUDAOpKernel<double>,
ops::CastOpKernel<paddle::platform::CUDADeviceContext, double>, ops::CastCUDAOpKernel<int>, ops::CastCUDAOpKernel<int64_t>,
ops::CastOpKernel<paddle::platform::CUDADeviceContext, int>, ops::CastCUDAOpKernel<int16_t>, ops::CastCUDAOpKernel<bool>,
ops::CastOpKernel<paddle::platform::CUDADeviceContext, int64_t>, ops::CastCUDAOpKernel<uint8_t>,
ops::CastOpKernel<paddle::platform::CUDADeviceContext, int16_t>, ops::CastCUDAOpKernel<paddle::platform::float16>,
ops::CastOpKernel<paddle::platform::CUDADeviceContext, bool>, ops::CastCUDAOpKernel<paddle::platform::bfloat16>,
ops::CastOpKernel<paddle::platform::CUDADeviceContext, uint8_t>, ops::CastCUDAOpKernel<paddle::platform::complex<float>>,
ops::CastOpKernel<paddle::platform::CUDADeviceContext, ops::CastCUDAOpKernel<paddle::platform::complex<double>>);
paddle::platform::float16>,
ops::CastOpKernel<paddle::platform::CUDADeviceContext,
paddle::platform::bfloat16>,
ops::CastOpKernel<paddle::platform::CUDADeviceContext,
paddle::platform::complex<float>>,
ops::CastOpKernel<paddle::platform::CUDADeviceContext,
paddle::platform::complex<double>>);
#endif #endif
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册