未验证 提交 412877e6 编写于 作者: D duanboqiang 提交者: GitHub

fix cast op (#35156)

上级 c71025eb
...@@ -119,6 +119,7 @@ REGISTER_OPERATOR(cast, ops::CastOp, ...@@ -119,6 +119,7 @@ REGISTER_OPERATOR(cast, ops::CastOp,
REGISTER_OP_CPU_KERNEL( REGISTER_OP_CPU_KERNEL(
cast, ops::CastOpKernel<CPU, float>, ops::CastOpKernel<CPU, double>, cast, ops::CastOpKernel<CPU, float>, ops::CastOpKernel<CPU, double>,
ops::CastOpKernel<CPU, int>, ops::CastOpKernel<CPU, int64_t>, ops::CastOpKernel<CPU, int>, ops::CastOpKernel<CPU, int64_t>,
ops::CastOpKernel<CPU, int>, ops::CastOpKernel<CPU, int16_t>,
ops::CastOpKernel<CPU, bool>, ops::CastOpKernel<CPU, uint8_t>, ops::CastOpKernel<CPU, bool>, ops::CastOpKernel<CPU, uint8_t>,
ops::CastOpKernel<CPU, paddle::platform::float16>, ops::CastOpKernel<CPU, paddle::platform::float16>,
ops::CastOpKernel<CPU, paddle::platform::bfloat16>, ops::CastOpKernel<CPU, paddle::platform::bfloat16>,
......
...@@ -102,6 +102,7 @@ REGISTER_OP_CUDA_KERNEL( ...@@ -102,6 +102,7 @@ REGISTER_OP_CUDA_KERNEL(
ops::CastOpKernel<paddle::platform::CUDADeviceContext, double>, ops::CastOpKernel<paddle::platform::CUDADeviceContext, double>,
ops::CastOpKernel<paddle::platform::CUDADeviceContext, int>, ops::CastOpKernel<paddle::platform::CUDADeviceContext, int>,
ops::CastOpKernel<paddle::platform::CUDADeviceContext, int64_t>, ops::CastOpKernel<paddle::platform::CUDADeviceContext, int64_t>,
ops::CastOpKernel<paddle::platform::CUDADeviceContext, int16_t>,
ops::CastOpKernel<paddle::platform::CUDADeviceContext, bool>, ops::CastOpKernel<paddle::platform::CUDADeviceContext, bool>,
ops::CastOpKernel<paddle::platform::CUDADeviceContext, uint8_t>, ops::CastOpKernel<paddle::platform::CUDADeviceContext, uint8_t>,
ops::CastOpKernel<paddle::platform::CUDADeviceContext, ops::CastOpKernel<paddle::platform::CUDADeviceContext,
...@@ -116,6 +117,7 @@ REGISTER_OP_CUDA_KERNEL( ...@@ -116,6 +117,7 @@ REGISTER_OP_CUDA_KERNEL(
ops::CastOpKernel<paddle::platform::CUDADeviceContext, double>, ops::CastOpKernel<paddle::platform::CUDADeviceContext, double>,
ops::CastOpKernel<paddle::platform::CUDADeviceContext, int>, ops::CastOpKernel<paddle::platform::CUDADeviceContext, int>,
ops::CastOpKernel<paddle::platform::CUDADeviceContext, int64_t>, ops::CastOpKernel<paddle::platform::CUDADeviceContext, int64_t>,
ops::CastOpKernel<paddle::platform::CUDADeviceContext, int16_t>,
ops::CastOpKernel<paddle::platform::CUDADeviceContext, bool>, ops::CastOpKernel<paddle::platform::CUDADeviceContext, bool>,
ops::CastOpKernel<paddle::platform::CUDADeviceContext, uint8_t>, ops::CastOpKernel<paddle::platform::CUDADeviceContext, uint8_t>,
ops::CastOpKernel<paddle::platform::CUDADeviceContext, ops::CastOpKernel<paddle::platform::CUDADeviceContext,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册