Unverified commit 267275d9, authored by sneaxiy and committed by GitHub

Add int16 support for several ops (#39636)

* add more op int16 support

* fix xpu ci
Parent 2fe04264
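
A minimal usage sketch of what this patch enables, assuming a Paddle build that includes it (the tensor values and variable names below are illustrative, not taken from the diff):

    import paddle

    # int16 tensors created via to_tensor avoid the Python-side dtype checks
    x = paddle.to_tensor([[1, 2], [3, 4]], dtype='int16')

    y = paddle.cumsum(x, axis=1)     # cumsum gains CPU and CUDA int16 kernels
    z = paddle.reshape(x, [4])       # reshape now registers int16 kernels
    d = x - x                        # elementwise_sub gains int16 kernels
    f = paddle.full_like(x, 7)       # fill_any_like keeps the int16 dtype
    idx = paddle.to_tensor([[0, 1]])
    g = paddle.gather_nd(x, idx)     # gather_nd accepts int16 inputs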
@@ -48,6 +48,7 @@ class CompareOpKernel<platform::CUDADeviceContext, Functor, InverseFunctor>
   REGISTER_OP_CUDA_KERNEL(                                                     \
       op_type,                                                                 \
       ops::CompareOpKernel<plat::CUDADeviceContext, ops::func<bool>, void>,    \
+      ops::CompareOpKernel<plat::CUDADeviceContext, ops::func<int16_t>, void>, \
      ops::CompareOpKernel<plat::CUDADeviceContext, ops::func<int>, void>,     \
      ops::CompareOpKernel<plat::CUDADeviceContext, ops::func<int64_t>, void>, \
      ops::CompareOpKernel<plat::CUDADeviceContext, ops::func<float>, void>,   \
...
@@ -95,6 +95,9 @@ class CompareOpKernel
       ::paddle::operators::CompareOpKernel<                \
           ::paddle::platform::dev##DeviceContext,          \
           functor<int>, inverse_functor<int>>,             \
+      ::paddle::operators::CompareOpKernel<                \
+          ::paddle::platform::dev##DeviceContext,          \
+          functor<int16_t>, inverse_functor<int16_t>>,     \
      ::paddle::operators::CompareOpKernel<                \
          ::paddle::platform::dev##DeviceContext,          \
          functor<int64_t>, inverse_functor<int64_t>>,     \
...
@@ -93,6 +93,7 @@ REGISTER_OPERATOR(cumsum, ops::CumOp, ops::CumsumOpMaker,
                   ops::CumsumGradMaker<paddle::imperative::OpBase>);
 REGISTER_OP_CPU_KERNEL(cumsum, ops::CumKernel<CPU, ops::CumsumFunctor<float>>,
                        ops::CumKernel<CPU, ops::CumsumFunctor<double>>,
+                       ops::CumKernel<CPU, ops::CumsumFunctor<int16_t>>,
                        ops::CumKernel<CPU, ops::CumsumFunctor<int>>,
                        ops::CumKernel<CPU, ops::CumsumFunctor<int64_t>>);
...
@@ -320,5 +320,6 @@ namespace ops = paddle::operators;
 REGISTER_OP_CUDA_KERNEL(
     cumsum, ops::CumCUDAKernel<paddle::platform::CUDADeviceContext, float>,
     ops::CumCUDAKernel<paddle::platform::CUDADeviceContext, double>,
+    ops::CumCUDAKernel<paddle::platform::CUDADeviceContext, int16_t>,
    ops::CumCUDAKernel<paddle::platform::CUDADeviceContext, int>,
    ops::CumCUDAKernel<paddle::platform::CUDADeviceContext, int64_t>);
@@ -96,6 +96,7 @@ REGISTER_OP_CPU_KERNEL(
     elementwise_sub,
     ops::ElementwiseSubKernel<paddle::platform::CPUDeviceContext, float>,
     ops::ElementwiseSubKernel<paddle::platform::CPUDeviceContext, double>,
+    ops::ElementwiseSubKernel<paddle::platform::CPUDeviceContext, int16_t>,
     ops::ElementwiseSubKernel<paddle::platform::CPUDeviceContext, int>,
     ops::ElementwiseSubKernel<paddle::platform::CPUDeviceContext, int64_t>,
     ops::ElementwiseSubKernel<paddle::platform::CPUDeviceContext,
@@ -106,6 +107,7 @@ REGISTER_OP_CPU_KERNEL(
     elementwise_sub_grad,
     ops::ElementwiseSubGradKernel<paddle::platform::CPUDeviceContext, float>,
     ops::ElementwiseSubGradKernel<paddle::platform::CPUDeviceContext, double>,
+    ops::ElementwiseSubGradKernel<paddle::platform::CPUDeviceContext, int16_t>,
     ops::ElementwiseSubGradKernel<paddle::platform::CPUDeviceContext, int>,
     ops::ElementwiseSubGradKernel<paddle::platform::CPUDeviceContext, int64_t>,
     ops::ElementwiseSubGradKernel<paddle::platform::CPUDeviceContext,
@@ -118,6 +120,8 @@ REGISTER_OP_CPU_KERNEL(
                                         float>,
     ops::ElementwiseSubDoubleGradKernel<paddle::platform::CPUDeviceContext,
                                         double>,
+    ops::ElementwiseSubDoubleGradKernel<paddle::platform::CPUDeviceContext,
+                                        int16_t>,
     ops::ElementwiseSubDoubleGradKernel<paddle::platform::CPUDeviceContext,
                                         int>,
     ops::ElementwiseSubDoubleGradKernel<paddle::platform::CPUDeviceContext,
...
@@ -94,6 +94,7 @@ REGISTER_OPERATOR(
 REGISTER_OP_CPU_KERNEL(
     fill_any_like,
+    ops::FillAnyLikeKernel<paddle::platform::CPUDeviceContext, int16_t>,
     ops::FillAnyLikeKernel<paddle::platform::CPUDeviceContext, int>,
     ops::FillAnyLikeKernel<paddle::platform::CPUDeviceContext, int64_t>,
     ops::FillAnyLikeKernel<paddle::platform::CPUDeviceContext, float>,
...
@@ -19,6 +19,7 @@ limitations under the License. */
 namespace ops = paddle::operators;
 REGISTER_OP_CUDA_KERNEL(
     fill_any_like,
+    ops::FillAnyLikeKernel<paddle::platform::CUDADeviceContext, int16_t>,
     ops::FillAnyLikeKernel<paddle::platform::CUDADeviceContext, int32_t>,
     ops::FillAnyLikeKernel<paddle::platform::CUDADeviceContext, int64_t>,
     ops::FillAnyLikeKernel<paddle::platform::CUDADeviceContext, float>,
...
@@ -183,7 +183,9 @@ REGISTER_OPERATOR(gather_nd_grad, ops::GatherNdGradOp,
 REGISTER_OP_CPU_KERNEL(gather_nd, ops::GatherNdOpKernel<float>,
                        ops::GatherNdOpKernel<double>,
                        ops::GatherNdOpKernel<int64_t>,
-                       ops::GatherNdOpKernel<int>, ops::GatherNdOpKernel<bool>,
+                       ops::GatherNdOpKernel<int>,
+                       ops::GatherNdOpKernel<int16_t>,
+                       ops::GatherNdOpKernel<bool>,
                        ops::GatherNdOpKernel<uint8_t>);
 REGISTER_OP_CPU_KERNEL(gather_nd_grad, ops::GatherNdGradOpKernel<float>,
...
@@ -103,6 +103,7 @@ REGISTER_OP_CUDA_KERNEL(gather_nd, ops::GatherNdOpCUDAKernel<CUDA, float>,
                         ops::GatherNdOpCUDAKernel<CUDA, double>,
                         ops::GatherNdOpCUDAKernel<CUDA, int64_t>,
                         ops::GatherNdOpCUDAKernel<CUDA, int>,
+                        ops::GatherNdOpCUDAKernel<CUDA, int16_t>,
                         ops::GatherNdOpCUDAKernel<CUDA, bool>,
                         ops::GatherNdOpCUDAKernel<CUDA, plat::float16>);
...
@@ -116,6 +116,8 @@ REGISTER_OP_CPU_KERNEL(
                       ops::SumFunctor>,
     ops::ReduceKernel<paddle::platform::CPUDeviceContext,
                       paddle::platform::float16, ops::SumFunctor>,
+    ops::ReduceKernel<paddle::platform::CPUDeviceContext, int16_t,
+                      ops::SumFunctor>,
     ops::ReduceKernel<paddle::platform::CPUDeviceContext, int, ops::SumFunctor>,
     ops::ReduceKernel<paddle::platform::CPUDeviceContext, int64_t,
                       ops::SumFunctor>,
...
@@ -20,6 +20,7 @@ REGISTER_OP_CUDA_KERNEL(
     ops::ReduceCudaKernel<double, kps::AddFunctor, kps::IdentityFunctor>,
     ops::ReduceCudaKernel<paddle::platform::float16, kps::AddFunctor,
                           kps::IdentityFunctor>,
+    ops::ReduceCudaKernel<int16_t, kps::AddFunctor, kps::IdentityFunctor>,
     ops::ReduceCudaKernel<int, kps::AddFunctor, kps::IdentityFunctor>,
     ops::ReduceCudaKernel<int64_t, kps::AddFunctor, kps::IdentityFunctor>,
     ops::ReduceCudaKernel<paddle::platform::complex<float>, kps::AddFunctor,
...
@@ -639,10 +639,12 @@ REGISTER_OPERATOR(reshape_grad, ops::ReshapeGradOp,
                   ops::ReshapeGradInplaceInferer);
 REGISTER_OP_CPU_KERNEL_FUNCTOR(reshape, float, ops::ReshapeKernel, double,
-                               ops::ReshapeKernel, int, ops::ReshapeKernel,
-                               int64_t, ops::ReshapeKernel);
+                               ops::ReshapeKernel, int16_t, ops::ReshapeKernel,
+                               int, ops::ReshapeKernel, int64_t,
+                               ops::ReshapeKernel);
 REGISTER_OP_CPU_KERNEL_FUNCTOR(reshape_grad, float, ops::ReshapeGradKernel,
-                               double, ops::ReshapeGradKernel, int,
+                               double, ops::ReshapeGradKernel, int16_t,
+                               ops::ReshapeGradKernel, int,
                                ops::ReshapeGradKernel, int64_t,
                                ops::ReshapeGradKernel);
 REGISTER_OPERATOR(reshape2, ops::Reshape2Op, ops::Reshape2OpMaker,
@@ -659,15 +661,15 @@ REGISTER_OPERATOR(reshape2_grad_grad, ops::Reshape2DoubleGradOp,
 #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
 REGISTER_OP_CUDA_KERNEL_FUNCTOR(reshape, float, ops::ReshapeKernel, double,
-                                ops::ReshapeKernel, int, ops::ReshapeKernel,
-                                uint8_t, ops::ReshapeKernel, int64_t,
-                                ops::ReshapeKernel, plat::float16,
-                                ops::ReshapeKernel, plat::bfloat16,
-                                ops::ReshapeKernel);
+                                ops::ReshapeKernel, int16_t, ops::ReshapeKernel,
+                                int, ops::ReshapeKernel, uint8_t,
+                                ops::ReshapeKernel, int64_t, ops::ReshapeKernel,
+                                plat::float16, ops::ReshapeKernel,
+                                plat::bfloat16, ops::ReshapeKernel);
 REGISTER_OP_CUDA_KERNEL_FUNCTOR(reshape_grad, float, ops::ReshapeGradKernel,
-                                double, ops::ReshapeGradKernel, int,
-                                ops::ReshapeGradKernel, int64_t,
-                                ops::ReshapeGradKernel, uint8_t,
+                                double, ops::ReshapeGradKernel, int16_t,
+                                ops::ReshapeGradKernel, int, ops::ReshapeGradKernel,
+                                int64_t, ops::ReshapeGradKernel, uint8_t,
                                 ops::ReshapeGradKernel, plat::float16,
                                 ops::ReshapeGradKernel, plat::bfloat16,
                                 ops::ReshapeGradKernel);
...
@@ -362,6 +362,7 @@ REGISTER_OP_CPU_KERNEL(
     ops::UnsqueezeKernel<paddle::platform::CPUDeviceContext, double>,
     ops::UnsqueezeKernel<paddle::platform::CPUDeviceContext, bool>,
     ops::UnsqueezeKernel<paddle::platform::CPUDeviceContext, int>,
+    ops::UnsqueezeKernel<paddle::platform::CPUDeviceContext, int16_t>,
     ops::UnsqueezeKernel<paddle::platform::CPUDeviceContext, uint8_t>,
     ops::UnsqueezeKernel<paddle::platform::CPUDeviceContext, int8_t>,
     ops::UnsqueezeKernel<paddle::platform::CPUDeviceContext, int64_t>,
@@ -377,6 +378,7 @@ REGISTER_OP_CPU_KERNEL(
     ops::UnsqueezeGradKernel<paddle::platform::CPUDeviceContext, double>,
     ops::UnsqueezeGradKernel<paddle::platform::CPUDeviceContext, bool>,
     ops::UnsqueezeGradKernel<paddle::platform::CPUDeviceContext, int>,
+    ops::UnsqueezeGradKernel<paddle::platform::CPUDeviceContext, int16_t>,
     ops::UnsqueezeGradKernel<paddle::platform::CPUDeviceContext, uint8_t>,
     ops::UnsqueezeGradKernel<paddle::platform::CPUDeviceContext, int8_t>,
     ops::UnsqueezeGradKernel<paddle::platform::CPUDeviceContext, int64_t>,
@@ -391,6 +393,7 @@ REGISTER_OP_CPU_KERNEL(
     ops::UnsqueezeKernel<paddle::platform::CPUDeviceContext, double>,
     ops::UnsqueezeKernel<paddle::platform::CPUDeviceContext, bool>,
     ops::UnsqueezeKernel<paddle::platform::CPUDeviceContext, int>,
+    ops::UnsqueezeKernel<paddle::platform::CPUDeviceContext, int16_t>,
     ops::UnsqueezeKernel<paddle::platform::CPUDeviceContext, uint8_t>,
     ops::UnsqueezeKernel<paddle::platform::CPUDeviceContext, int8_t>,
     ops::UnsqueezeKernel<paddle::platform::CPUDeviceContext, int64_t>,
@@ -406,6 +409,7 @@ REGISTER_OP_CPU_KERNEL(
     ops::Unsqueeze2GradKernel<paddle::platform::CPUDeviceContext, double>,
     ops::Unsqueeze2GradKernel<paddle::platform::CPUDeviceContext, bool>,
     ops::Unsqueeze2GradKernel<paddle::platform::CPUDeviceContext, int>,
+    ops::Unsqueeze2GradKernel<paddle::platform::CPUDeviceContext, int16_t>,
     ops::Unsqueeze2GradKernel<paddle::platform::CPUDeviceContext, uint8_t>,
     ops::Unsqueeze2GradKernel<paddle::platform::CPUDeviceContext, int8_t>,
     ops::Unsqueeze2GradKernel<paddle::platform::CPUDeviceContext, int64_t>,
...
@@ -24,6 +24,7 @@ REGISTER_OP_CUDA_KERNEL(
     ops::UnsqueezeKernel<paddle::platform::CUDADeviceContext, plat::bfloat16>,
     ops::UnsqueezeKernel<paddle::platform::CUDADeviceContext, bool>,
     ops::UnsqueezeKernel<paddle::platform::CUDADeviceContext, int>,
+    ops::UnsqueezeKernel<paddle::platform::CUDADeviceContext, int16_t>,
     ops::UnsqueezeKernel<paddle::platform::CUDADeviceContext, uint8_t>,
     ops::UnsqueezeKernel<paddle::platform::CUDADeviceContext, int8_t>,
     ops::UnsqueezeKernel<paddle::platform::CUDADeviceContext, int64_t>,
@@ -41,6 +42,7 @@ REGISTER_OP_CUDA_KERNEL(
                              plat::bfloat16>,
     ops::UnsqueezeGradKernel<paddle::platform::CUDADeviceContext, bool>,
     ops::UnsqueezeGradKernel<paddle::platform::CUDADeviceContext, int>,
+    ops::UnsqueezeGradKernel<paddle::platform::CUDADeviceContext, int16_t>,
     ops::UnsqueezeGradKernel<paddle::platform::CUDADeviceContext, int8_t>,
     ops::UnsqueezeGradKernel<paddle::platform::CUDADeviceContext, uint8_t>,
     ops::UnsqueezeGradKernel<paddle::platform::CUDADeviceContext, int64_t>,
@@ -56,6 +58,7 @@ REGISTER_OP_CUDA_KERNEL(
     ops::UnsqueezeKernel<paddle::platform::CUDADeviceContext, plat::bfloat16>,
     ops::UnsqueezeKernel<paddle::platform::CUDADeviceContext, bool>,
     ops::UnsqueezeKernel<paddle::platform::CUDADeviceContext, int>,
+    ops::UnsqueezeKernel<paddle::platform::CUDADeviceContext, int16_t>,
     ops::UnsqueezeKernel<paddle::platform::CUDADeviceContext, uint8_t>,
     ops::UnsqueezeKernel<paddle::platform::CUDADeviceContext, int8_t>,
     ops::UnsqueezeKernel<paddle::platform::CUDADeviceContext, int64_t>,
@@ -73,6 +76,7 @@ REGISTER_OP_CUDA_KERNEL(
                               plat::bfloat16>,
     ops::Unsqueeze2GradKernel<paddle::platform::CUDADeviceContext, bool>,
     ops::Unsqueeze2GradKernel<paddle::platform::CUDADeviceContext, int>,
+    ops::Unsqueeze2GradKernel<paddle::platform::CUDADeviceContext, int16_t>,
     ops::Unsqueeze2GradKernel<paddle::platform::CUDADeviceContext, uint8_t>,
     ops::Unsqueeze2GradKernel<paddle::platform::CUDADeviceContext, int8_t>,
     ops::Unsqueeze2GradKernel<paddle::platform::CUDADeviceContext, int64_t>,
...
@@ -57,6 +57,7 @@ REGISTER_OP_WITHOUT_GRADIENT(where_index, ops::WhereIndexOp,
                              ops::WhereIndexOpMaker);
 REGISTER_OP_CPU_KERNEL(where_index, ops::CPUWhereIndexKernel<int64_t>,
                        ops::CPUWhereIndexKernel<int>,
+                       ops::CPUWhereIndexKernel<int16_t>,
                        ops::CPUWhereIndexKernel<bool>,
                        ops::CPUWhereIndexKernel<float>,
                        ops::CPUWhereIndexKernel<double>);
@@ -158,6 +158,7 @@ class CUDAWhereIndexKernel : public framework::OpKernel<T> {
 namespace ops = paddle::operators;
 REGISTER_OP_CUDA_KERNEL(where_index, ops::CUDAWhereIndexKernel<int64_t>,
                         ops::CUDAWhereIndexKernel<int>,
+                        ops::CUDAWhereIndexKernel<int16_t>,
                         ops::CUDAWhereIndexKernel<bool>,
                         ops::CUDAWhereIndexKernel<float>,
                         ops::CUDAWhereIndexKernel<double>);
@@ -132,6 +132,7 @@ PT_REGISTER_KERNEL(add_grad,
                    pten::AddGradKernel,
                    float,
                    double,
+                   int16_t,
                    int,
                    int64_t,
                    pten::dtype::complex<float>,
@@ -143,6 +144,7 @@ PT_REGISTER_KERNEL(add_double_grad,
                    pten::AddDoubleGradKernel,
                    float,
                    double,
+                   int16_t,
                    int,
                    int64_t,
                    pten::dtype::complex<float>,
@@ -154,6 +156,7 @@ PT_REGISTER_KERNEL(add_triple_grad,
                    pten::AddTripleGradKernel,
                    float,
                    double,
+                   int16_t,
                    int,
                    int64_t,
                    pten::dtype::complex<float>,
@@ -165,6 +168,7 @@ PT_REGISTER_KERNEL(subtract_grad,
                    pten::SubtractGradKernel,
                    float,
                    double,
+                   int16_t,
                    int,
                    int64_t,
                    pten::dtype::complex<float>,
@@ -176,6 +180,7 @@ PT_REGISTER_KERNEL(subtract_double_grad,
                    pten::SubtractDoubleGradKernel,
                    float,
                    double,
+                   int16_t,
                    int,
                    int64_t,
                    pten::dtype::complex<float>,
...
@@ -95,6 +95,7 @@ PT_REGISTER_KERNEL(full_like,
                    pten::FullLikeKernel,
                    float,
                    double,
+                   int16_t,
                    int,
                    int64_t,
                    bool,
...
@@ -124,6 +124,7 @@ PT_REGISTER_KERNEL(add_raw,
                    pten::AddRawKernel,
                    float,
                    double,
+                   int16_t,
                    int,
                    int64_t,
                    complex64,
@@ -134,6 +135,7 @@ PT_REGISTER_KERNEL(subtract_raw,
                    pten::SubtractRawKernel,
                    float,
                    double,
+                   int16_t,
                    int,
                    int64_t,
                    complex64,
@@ -167,6 +169,7 @@ PT_REGISTER_KERNEL(sum_raw,
                    float,
                    double,
                    pten::dtype::float16,
+                   int16_t,
                    int,
                    int64_t,
                    complex64,
...
@@ -56,6 +56,7 @@ PT_REGISTER_KERNEL(flatten,
                    double,
                    uint8_t,
                    int8_t,
+                   int16_t,
                    int,
                    int64_t) {}
@@ -67,6 +68,7 @@ PT_REGISTER_KERNEL(flatten_with_xshape,
                    double,
                    uint8_t,
                    int8_t,
+                   int16_t,
                    int,
                    int64_t) {}
@@ -80,6 +82,7 @@ PT_REGISTER_KERNEL(flatten,
                    double,
                    uint8_t,
                    int8_t,
+                   int16_t,
                    int,
                    int64_t) {}
@@ -92,6 +95,7 @@ PT_REGISTER_KERNEL(flatten_with_xshape,
                    double,
                    uint8_t,
                    int8_t,
+                   int16_t,
                    int,
                    int64_t) {}
 #endif
@@ -104,6 +108,7 @@ PT_REGISTER_KERNEL(flatten,
                    float,
                    pten::dtype::float16,
                    int8_t,
+                   int16_t,
                    int,
                    int64_t) {}
@@ -114,6 +119,7 @@ PT_REGISTER_KERNEL(flatten_with_xshape,
                    float,
                    pten::dtype::float16,
                    int8_t,
+                   int16_t,
                    int,
                    int64_t) {}
 #endif
@@ -119,6 +119,7 @@ PT_REGISTER_KERNEL(full_like,
                    pten::FullLikeKernel,
                    float,
                    double,
+                   int16_t,
                    int,
                    int64_t,
                    bool,
...
@@ -101,6 +101,7 @@ PT_REGISTER_KERNEL(add_raw,
                    pten::AddRawKernel,
                    float,
                    double,
+                   int16_t,
                    int,
                    int64_t,
                    float16,
@@ -112,6 +113,7 @@ PT_REGISTER_KERNEL(subtract_raw,
                    pten::SubtractRawKernel,
                    float,
                    double,
+                   int16_t,
                    int,
                    int64_t,
                    float16,
@@ -148,6 +150,7 @@ PT_REGISTER_KERNEL(sum_raw,
                    float,
                    double,
                    float16,
+                   int16_t,
                    int,
                    int64_t,
                    complex64,
...
@@ -92,6 +92,7 @@ PT_REGISTER_KERNEL(sum,
                    float,
                    double,
                    pten::dtype::float16,
+                   int16_t,
                    int,
                    int64_t,
                    complex64,
@@ -105,6 +106,7 @@ PT_REGISTER_KERNEL(add,
                    pten::AddKernel,
                    float,
                    double,
+                   int16_t,
                    int,
                    int64_t,
                    complex64,
@@ -115,6 +117,7 @@ PT_REGISTER_KERNEL(subtract,
                    pten::SubtractKernel,
                    float,
                    double,
+                   int16_t,
                    int,
                    int64_t,
                    complex64,
@@ -158,6 +161,7 @@ PT_REGISTER_KERNEL(sum,
                    float,
                    double,
                    pten::dtype::float16,
+                   int16_t,
                    int,
                    int64_t,
                    complex64,
@@ -170,6 +174,7 @@ PT_REGISTER_KERNEL(add,
                    pten::AddKernel,
                    float,
                    double,
+                   int16_t,
                    int,
                    int64_t,
                    pten::dtype::float16,
@@ -181,6 +186,7 @@ PT_REGISTER_KERNEL(subtract,
                    pten::SubtractKernel,
                    float,
                    double,
+                   int16_t,
                    int,
                    int64_t,
                    pten::dtype::float16,
...
@@ -6276,7 +6276,8 @@ def reshape(x, shape, actual_shape=None, act=None, inplace=False, name=None):
         return dygraph_utils._append_activation_in_dygraph(out, act)
     check_variable_and_dtype(x, 'x', [
-        'float16', 'float32', 'float64', 'int32', 'int64', 'bool', 'uint16'
+        'float16', 'float32', 'float64', 'int16', 'int32', 'int64', 'bool',
+        'uint16'
     ], 'reshape')
     check_type(shape, 'shape', (list, tuple, Variable), 'reshape')
     check_type(actual_shape, 'actual_shape', (Variable, type(None)), 'reshape')
@@ -6456,10 +6457,10 @@ def unsqueeze(input, axes, name=None):
         return out
     check_type(axes, 'axis/axes', (int, list, tuple, Variable), 'unsqueeze')
-    check_variable_and_dtype(
-        input, 'input',
-        ['float16', 'float32', 'float64', 'bool', 'int8', 'int32', 'int64'],
-        'unsqueeze')
+    check_variable_and_dtype(input, 'input', [
+        'float16', 'float32', 'float64', 'bool', 'int8', 'int16', 'int32',
+        'int64'
+    ], 'unsqueeze')
     helper = LayerHelper("unsqueeze2", **locals())
     inputs = {"X": input}
     attrs = {}
@@ -8539,9 +8540,9 @@ def gather_nd(input, index, name=None):
     """
     if in_dygraph_mode():
         return _C_ops.gather_nd(input, index)
-    check_variable_and_dtype(input, 'input',
-                             ['bool', 'float32', 'float64', 'int32', 'int64'],
-                             'gather_np')
+    check_variable_and_dtype(
+        input, 'input',
+        ['bool', 'float32', 'float64', 'int16', 'int32', 'int64'], 'gather_np')
     check_variable_and_dtype(index, 'index', ['int32', 'int64'], 'gather_np')
     helper = LayerHelper('gather_nd', **locals())
     dtype = helper.input_dtype()
...
@@ -250,12 +250,12 @@ def cast(x, dtype):
         return out
     check_variable_and_dtype(x, 'x', [
-        'bool', 'float16', 'float32', 'float64', 'int32', 'int64', 'uint8',
-        'uint16'
+        'bool', 'float16', 'float32', 'float64', 'int16', 'int32', 'int64',
+        'uint8', 'uint16'
     ], 'cast')
     check_dtype(dtype, 'dtype', [
-        'bool', 'float16', 'float32', 'float64', 'int8', 'int32', 'int64',
-        'uint8', 'uint16'
+        'bool', 'float16', 'float32', 'float64', 'int8', 'int16', 'int32',
+        'int64', 'uint8', 'uint16'
     ], 'cast')
     helper = LayerHelper('cast', **locals())
...
@@ -109,15 +109,6 @@ class TestCastOpError(unittest.TestCase):
             x1 = fluid.create_lod_tensor(
                 np.array([[-1]]), [[1]], fluid.CPUPlace())
             self.assertRaises(TypeError, fluid.layers.cast, x1, 'int32')
-            # The input dtype of cast_op must be bool, float16, float32, float64, int32, int64, uint8.
-            x2 = fluid.layers.data(name='x2', shape=[4], dtype='int16')
-            self.assertRaises(TypeError, fluid.layers.cast, x2, 'int32')
-
-        def test_dtype_type():
-            x4 = fluid.layers.data(name='x4', shape=[4], dtype='int32')
-            output = fluid.layers.cast(x=x4, dtype='int16')
-
-        self.assertRaises(TypeError, test_dtype_type)
 
 if __name__ == '__main__':
...
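The assertions deleted above encoded the old rule that cast rejected int16 inputs and targets; with 'int16' added to both whitelists in the cast diff earlier, a sketch of the now-valid call (illustrative only, assuming a build with this patch):

    import paddle

    x = paddle.ones([4], dtype='int32')
    y = paddle.cast(x, 'int16')   # previously failed the dtype check
    assert y.dtype == paddle.int16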
@@ -81,12 +81,6 @@ class TestFullOpError(unittest.TestCase):
                 x=input_data,
                 fill_value=2,
                 dtype='uint4')
-            self.assertRaises(
-                TypeError,
-                paddle.full_like,
-                x=input_data,
-                fill_value=2,
-                dtype='int16')
 
 if __name__ == "__main__":
...
@@ -67,15 +67,6 @@ class TestCastOpError(unittest.TestCase):
             x1 = fluid.create_lod_tensor(
                 np.array([[-1]]), [[1]], fluid.XPUPlace(0))
             self.assertRaises(TypeError, fluid.layers.cast, x1, 'int32')
-            # The input dtype of cast_op must be float32, int32, int64.
-            x2 = fluid.layers.data(name='x2', shape=[4], dtype='int16')
-            self.assertRaises(TypeError, fluid.layers.cast, x2, 'int32')
-
-        def test_dtype_type():
-            x4 = fluid.layers.data(name='x4', shape=[4], dtype='int32')
-            output = fluid.layers.cast(x=x4, dtype='int16')
-
-        self.assertRaises(TypeError, test_dtype_type)
 
 if __name__ == '__main__':
...
@@ -219,11 +219,13 @@ def full_like(x, fill_value, dtype=None, name=None):
     helper = LayerHelper("full_like", **locals())
     check_variable_and_dtype(
-        x, 'x', ['bool', 'float16', 'float32', 'float64', 'int32', 'int64'],
+        x, 'x',
+        ['bool', 'float16', 'float32', 'float64', 'int16', 'int32', 'int64'],
         'full_like')
-    check_dtype(dtype, 'dtype',
-                ['bool', 'float16', 'float32', 'float64', 'int32', 'int64'],
-                'full_like/zeros_like/ones_like')
+    check_dtype(
+        dtype, 'dtype',
+        ['bool', 'float16', 'float32', 'float64', 'int16', 'int32', 'int64'],
+        'full_like/zeros_like/ones_like')
     out = helper.create_variable_for_type_inference(dtype=dtype)
     helper.append_op(
...
@@ -672,7 +672,8 @@ def flatten(x, start_axis=0, stop_axis=-1, name=None):
     if not in_dygraph_mode():
         check_variable_and_dtype(
-            x, 'x', ['float32', 'float64', 'int8', 'int32', 'int64', 'uint8'],
+            x, 'x',
+            ['float32', 'float64', 'int8', 'int16', 'int32', 'int64', 'uint8'],
             'flatten')
     x_dim = len(x.shape)
...
@@ -885,7 +885,7 @@ def sum(x, axis=None, dtype=None, keepdim=False, name=None):
     check_variable_and_dtype(
         x, 'x', ['bool', 'float16', 'float32', 'float64',
-                 'int32', 'int64', 'complex64', 'complex128',
+                 'int16', 'int32', 'int64', 'complex64', 'complex128',
                  u'bool', u'float16', u'float32', u'float64',
                  u'int32', u'int64', u'complex64', u'complex128'], 'sum')
...
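With 'int16' added to sum's dtype check and the int16 SumFunctor kernels registered earlier in the patch, reducing an int16 tensor should now pass the Python-side validation. A hedged sketch, assuming a build with this patch:

    import numpy as np
    import paddle

    x = paddle.to_tensor(np.ones((3, 4), dtype=np.int16))
    total = paddle.sum(x, axis=0)   # dispatches to the new int16 sum kernel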