未验证 提交 4d3c7f33 编写于 作者: Z Zeng Jinle 提交者: GitHub

fix cast cuda implementation (#36679)

上级 bdcc2ad4
......@@ -47,12 +47,12 @@ __global__ void CastCUDAKernel(const InT* in, const int64_t N, OutT* out) {
}
template <typename InT>
struct CastOpFunctor<platform::CUDADeviceContext, InT> {
struct CastCUDAOpFunctor {
const framework::Tensor* in_;
framework::Tensor* out_;
const platform::CUDADeviceContext& ctx_;
CastOpFunctor(const framework::Tensor* in, framework::Tensor* out,
const platform::CUDADeviceContext& ctx)
CastCUDAOpFunctor(const framework::Tensor* in, framework::Tensor* out,
const platform::CUDADeviceContext& ctx)
: in_(in), out_(out), ctx_(ctx) {}
template <typename OutT>
......@@ -75,6 +75,21 @@ struct CastOpFunctor<platform::CUDADeviceContext, InT> {
}
};
template <typename InT>
class CastCUDAOpKernel : public framework::OpKernel<InT> {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto* in = context.Input<framework::Tensor>("X");
auto* out = context.Output<framework::Tensor>("Out");
framework::VisitDataType(
static_cast<framework::proto::VarType::Type>(
context.Attr<int>("out_dtype")),
CastCUDAOpFunctor<InT>(
in, out,
context.template device_context<platform::CUDADeviceContext>()));
}
};
} // namespace operators
} // namespace paddle
......@@ -82,34 +97,21 @@ namespace ops = paddle::operators;
#ifdef PADDLE_WITH_HIP
REGISTER_OP_CUDA_KERNEL(
cast, ops::CastOpKernel<paddle::platform::CUDADeviceContext, float>,
ops::CastOpKernel<paddle::platform::CUDADeviceContext, double>,
ops::CastOpKernel<paddle::platform::CUDADeviceContext, int>,
ops::CastOpKernel<paddle::platform::CUDADeviceContext, int64_t>,
ops::CastOpKernel<paddle::platform::CUDADeviceContext, int16_t>,
ops::CastOpKernel<paddle::platform::CUDADeviceContext, bool>,
ops::CastOpKernel<paddle::platform::CUDADeviceContext, uint8_t>,
ops::CastOpKernel<paddle::platform::CUDADeviceContext,
paddle::platform::float16>,
ops::CastOpKernel<paddle::platform::CUDADeviceContext,
paddle::platform::complex<float>>,
ops::CastOpKernel<paddle::platform::CUDADeviceContext,
paddle::platform::complex<double>>);
cast, ops::CastCUDAOpKernel<float>, ops::CastCUDAOpKernel<double>,
ops::CastCUDAOpKernel<int>, ops::CastCUDAOpKernel<int64_t>,
ops::CastCUDAOpKernel<int16_t>, ops::CastCUDAOpKernel<bool>,
ops::CastCUDAOpKernel<uint8_t>,
ops::CastCUDAOpKernel<paddle::platform::float16>,
ops::CastCUDAOpKernel<paddle::platform::complex<float>>,
ops::CastCUDAOpKernel<paddle::platform::complex<double>>);
#else
REGISTER_OP_CUDA_KERNEL(
cast, ops::CastOpKernel<paddle::platform::CUDADeviceContext, float>,
ops::CastOpKernel<paddle::platform::CUDADeviceContext, double>,
ops::CastOpKernel<paddle::platform::CUDADeviceContext, int>,
ops::CastOpKernel<paddle::platform::CUDADeviceContext, int64_t>,
ops::CastOpKernel<paddle::platform::CUDADeviceContext, int16_t>,
ops::CastOpKernel<paddle::platform::CUDADeviceContext, bool>,
ops::CastOpKernel<paddle::platform::CUDADeviceContext, uint8_t>,
ops::CastOpKernel<paddle::platform::CUDADeviceContext,
paddle::platform::float16>,
ops::CastOpKernel<paddle::platform::CUDADeviceContext,
paddle::platform::bfloat16>,
ops::CastOpKernel<paddle::platform::CUDADeviceContext,
paddle::platform::complex<float>>,
ops::CastOpKernel<paddle::platform::CUDADeviceContext,
paddle::platform::complex<double>>);
cast, ops::CastCUDAOpKernel<float>, ops::CastCUDAOpKernel<double>,
ops::CastCUDAOpKernel<int>, ops::CastCUDAOpKernel<int64_t>,
ops::CastCUDAOpKernel<int16_t>, ops::CastCUDAOpKernel<bool>,
ops::CastCUDAOpKernel<uint8_t>,
ops::CastCUDAOpKernel<paddle::platform::float16>,
ops::CastCUDAOpKernel<paddle::platform::bfloat16>,
ops::CastCUDAOpKernel<paddle::platform::complex<float>>,
ops::CastCUDAOpKernel<paddle::platform::complex<double>>);
#endif
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册