diff --git a/paddle/fluid/operators/add_position_encoding_op.cc b/paddle/fluid/operators/add_position_encoding_op.cc index 0f52362c21e911d50339c94255cb0a8a5bd4a99d..f965cc077c48138bab02cf750d91da3537658fe5 100644 --- a/paddle/fluid/operators/add_position_encoding_op.cc +++ b/paddle/fluid/operators/add_position_encoding_op.cc @@ -121,11 +121,15 @@ REGISTER_OPERATOR( ops::AddPositionEncodingGradOpMaker); REGISTER_OPERATOR(add_position_encoding_grad, ops::AddPositionEncodingOpGrad); -REGISTER_OP_CPU_KERNEL(add_position_encoding, - ops::AddPositionEncodingKernel, - ops::AddPositionEncodingKernel); - -REGISTER_OP_CPU_KERNEL( - add_position_encoding_grad, - ops::AddPositionEncodingGradKernel, - ops::AddPositionEncodingGradKernel); +PD_REGISTER_STRUCT_KERNEL(add_position_encoding, + CPU, + ALL_LAYOUT, + ops::AddPositionEncodingKernel, + float, + double) {} +PD_REGISTER_STRUCT_KERNEL(add_position_encoding_grad, + CPU, + ALL_LAYOUT, + ops::AddPositionEncodingGradKernel, + float, + double) {} diff --git a/paddle/fluid/operators/add_position_encoding_op.h b/paddle/fluid/operators/add_position_encoding_op.h index 0cf67de6ca29b2bb9aa981209ad543a1a5b13224..4547f6321a01d4912dc28d70539d1b40234b1bfc 100644 --- a/paddle/fluid/operators/add_position_encoding_op.h +++ b/paddle/fluid/operators/add_position_encoding_op.h @@ -19,7 +19,7 @@ limitations under the License. */ namespace paddle { namespace operators { -template +template class AddPositionEncodingKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& context) const override { @@ -99,7 +99,7 @@ class AddPositionEncodingKernel : public framework::OpKernel { } }; -template +template class AddPositionEncodingGradKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& context) const override { diff --git a/paddle/fluid/operators/affine_channel_op.cc b/paddle/fluid/operators/affine_channel_op.cc index 90d8c8b0ce12fb9f355f46b2ccfd7e9707cd95f7..565054c265151149bf54ed914180f29e6a06d6b5 100644 --- a/paddle/fluid/operators/affine_channel_op.cc +++ b/paddle/fluid/operators/affine_channel_op.cc @@ -184,7 +184,7 @@ template using ConstEigenVectorArrayMap = Eigen::Map>; -template +template class AffineChannelKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& ctx) const override { @@ -228,7 +228,7 @@ class AffineChannelKernel : public framework::OpKernel { } }; -template +template class AffineChannelGradKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& ctx) const override { @@ -353,9 +353,12 @@ REGISTER_OPERATOR(affine_channel_grad, ops::AffineChannelNoNeedBufferVarsInference, ops::AffineChannelGradInplaceInferer); -REGISTER_OP_CPU_KERNEL(affine_channel, - ops::AffineChannelKernel, - ops::AffineChannelKernel); -REGISTER_OP_CPU_KERNEL(affine_channel_grad, - ops::AffineChannelGradKernel, - ops::AffineChannelGradKernel); +PD_REGISTER_STRUCT_KERNEL( + affine_channel, CPU, ALL_LAYOUT, ops::AffineChannelKernel, float, double) {} + +PD_REGISTER_STRUCT_KERNEL(affine_channel_grad, + CPU, + ALL_LAYOUT, + ops::AffineChannelGradKernel, + float, + double) {} diff --git a/paddle/fluid/operators/affine_channel_op.cu b/paddle/fluid/operators/affine_channel_op.cu index 16c297459ce046bf48731febb44a58ab102dd826..6ec8d77da2c85687036028b2c26eea93e9884504 100644 --- a/paddle/fluid/operators/affine_channel_op.cu +++ b/paddle/fluid/operators/affine_channel_op.cu @@ -48,7 +48,7 @@ __global__ void KeAffineChannelCUDA(const T* x, } } -template +template class AffineChannelCUDAKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& ctx) const override { @@ -132,7 +132,7 @@ __global__ void AffineChannelScaleBiasGradientCUDAKernel(const T* dy, } } -template +template class AffineChannelGradCUDAKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& ctx) const override { @@ -211,9 +211,15 @@ class AffineChannelGradCUDAKernel : public framework::OpKernel { namespace ops = paddle::operators; using CUDA = phi::GPUContext; -REGISTER_OP_CUDA_KERNEL(affine_channel, - ops::AffineChannelCUDAKernel, - ops::AffineChannelCUDAKernel); -REGISTER_OP_CUDA_KERNEL(affine_channel_grad, - ops::AffineChannelGradCUDAKernel, - ops::AffineChannelGradCUDAKernel); +PD_REGISTER_STRUCT_KERNEL(affine_channel, + GPU, + ALL_LAYOUT, + ops::AffineChannelCUDAKernel, + float, + double) {} +PD_REGISTER_STRUCT_KERNEL(affine_channel_grad, + GPU, + ALL_LAYOUT, + ops::AffineChannelGradCUDAKernel, + float, + double) {} diff --git a/paddle/fluid/operators/amp/alloc_float_status_op.cc b/paddle/fluid/operators/amp/alloc_float_status_op.cc index 24e960867716ef0aac04b26c66b24bdfd0098808..2c1b4b201e5c3940219ca775a9597b4018244a6d 100644 --- a/paddle/fluid/operators/amp/alloc_float_status_op.cc +++ b/paddle/fluid/operators/amp/alloc_float_status_op.cc @@ -51,7 +51,7 @@ class AllocFloatStatusMaker : public framework::OpProtoAndCheckerMaker { } }; -template +template class AllocFloatStatusKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& ctx) const override { @@ -73,5 +73,5 @@ REGISTER_OPERATOR( paddle::framework::EmptyGradOpMaker, paddle::framework::EmptyGradOpMaker); -REGISTER_OP_CPU_KERNEL(alloc_float_status, - ops::AllocFloatStatusKernel); +PD_REGISTER_STRUCT_KERNEL( + alloc_float_status, CPU, ALL_LAYOUT, ops::AllocFloatStatusKernel, float) {} diff --git a/paddle/fluid/operators/ascend_trigger_op.cc b/paddle/fluid/operators/ascend_trigger_op.cc index b312f97d3f93d1d8147950a64362297bb3b18cae..7d23e16804ce32939d19512459be03f121502657 100644 --- a/paddle/fluid/operators/ascend_trigger_op.cc +++ b/paddle/fluid/operators/ascend_trigger_op.cc @@ -50,4 +50,6 @@ namespace ops = paddle::operators; REGISTER_OPERATOR(ascend_trigger, ops::AscendTriggerOp, ops::AscendTriggerOpMaker); -REGISTER_OP_CPU_KERNEL(ascend_trigger, ops::AscendTriggerCPUKernel) + +PD_REGISTER_STRUCT_KERNEL( + ascend_trigger, CPU, ALL_LAYOUT, ops::AscendTriggerCPUKernel, float) {} diff --git a/paddle/fluid/operators/ascend_trigger_op.h b/paddle/fluid/operators/ascend_trigger_op.h index 943960e1bb1c5a6cc3cdb3c27d8ca110819dd8db..09e160d8f6aac36cc8aa75b0a2dc1cbef378747e 100644 --- a/paddle/fluid/operators/ascend_trigger_op.h +++ b/paddle/fluid/operators/ascend_trigger_op.h @@ -25,7 +25,7 @@ namespace paddle { namespace operators { -template +template class AscendTriggerCPUKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext &ctx) const override { diff --git a/paddle/fluid/operators/collective/allreduce_op.cc b/paddle/fluid/operators/collective/allreduce_op.cc index e136d8ef6e3889fb92b80daaab289cbe0e569cfa..ca9cd1ca52952b4eb9d243c9a74ae7d17e09c207 100644 --- a/paddle/fluid/operators/collective/allreduce_op.cc +++ b/paddle/fluid/operators/collective/allreduce_op.cc @@ -73,9 +73,12 @@ REGISTER_OP_WITHOUT_GRADIENT(allreduce, ops::AllReduceOp, ops::AllReduceOpMaker); -REGISTER_OP_CPU_KERNEL(allreduce, - ops::AllReduceOpKernel, - ops::AllReduceOpKernel, - ops::AllReduceOpKernel, - ops::AllReduceOpKernel, - ops::AllReduceOpKernel); +PD_REGISTER_STRUCT_KERNEL(allreduce, + CPU, + ALL_LAYOUT, + ops::AllReduceOpKernel, + float, + double, + int, + int64_t, + plat::float16) {} diff --git a/paddle/fluid/operators/collective/allreduce_op.cu.cc b/paddle/fluid/operators/collective/allreduce_op.cu.cc index 174a5afa69dc60c373f65f00796a8398ed42d7e9..0c9b95c76866bb1f67f0735348217a13be31d43d 100644 --- a/paddle/fluid/operators/collective/allreduce_op.cu.cc +++ b/paddle/fluid/operators/collective/allreduce_op.cu.cc @@ -17,9 +17,12 @@ limitations under the License. */ namespace ops = paddle::operators; namespace plat = paddle::platform; -REGISTER_OP_CUDA_KERNEL(allreduce, - ops::AllReduceOpKernel, - ops::AllReduceOpKernel, - ops::AllReduceOpKernel, - ops::AllReduceOpKernel, - ops::AllReduceOpKernel); +PD_REGISTER_STRUCT_KERNEL(allreduce, + GPU, + ALL_LAYOUT, + ops::AllReduceOpKernel, + float, + double, + int, + int64_t, + plat::float16) {} diff --git a/paddle/fluid/operators/collective/allreduce_op.h b/paddle/fluid/operators/collective/allreduce_op.h index a4f935a9c95860d28c490f62c4757f3d6f1dd1bb..794e37c312a9babae07c6f3330381c55d473db66 100644 --- a/paddle/fluid/operators/collective/allreduce_op.h +++ b/paddle/fluid/operators/collective/allreduce_op.h @@ -28,7 +28,7 @@ limitations under the License. */ namespace paddle { namespace operators { -template +template class AllReduceOpKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& ctx) const override { diff --git a/paddle/fluid/operators/collective/alltoall_op.cc b/paddle/fluid/operators/collective/alltoall_op.cc index e6fa37e0e42d8f55e33dea841308efb3d087790d..5248a9ac37edb61cc534e64701f10f48d021674b 100644 --- a/paddle/fluid/operators/collective/alltoall_op.cc +++ b/paddle/fluid/operators/collective/alltoall_op.cc @@ -69,9 +69,12 @@ namespace plat = paddle::platform; REGISTER_OP_WITHOUT_GRADIENT(alltoall, ops::AllToAllOp, ops::AllToAllOpMaker) -REGISTER_OP_CPU_KERNEL(alltoall, - ops::AllToAllOpCPUKernel, - ops::AllToAllOpCPUKernel, - ops::AllToAllOpCPUKernel, - ops::AllToAllOpCPUKernel, - ops::AllToAllOpCPUKernel); +PD_REGISTER_STRUCT_KERNEL(alltoall, + CPU, + ALL_LAYOUT, + ops::AllToAllOpCPUKernel, + float, + double, + int, + int64_t, + plat::float16) {} diff --git a/paddle/fluid/operators/collective/alltoall_op.cu.cc b/paddle/fluid/operators/collective/alltoall_op.cu.cc index fd67342b3affa3e7b97f46c03734058c1a2b74cb..aacd76af4af0586de2cd2c97b439d8c380eaeefc 100644 --- a/paddle/fluid/operators/collective/alltoall_op.cu.cc +++ b/paddle/fluid/operators/collective/alltoall_op.cu.cc @@ -22,7 +22,7 @@ limitations under the License. */ namespace paddle { namespace operators { -template +template class AllToAllOpCUDAKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& ctx) const override { @@ -92,12 +92,16 @@ class AllToAllOpCUDAKernel : public framework::OpKernel { namespace ops = paddle::operators; namespace plat = paddle::platform; -REGISTER_OP_CUDA_KERNEL(alltoall, - ops::AllToAllOpCUDAKernel, - ops::AllToAllOpCUDAKernel, +PD_REGISTER_STRUCT_KERNEL(alltoall, + GPU, + ALL_LAYOUT, + ops::AllToAllOpCUDAKernel, + float, + double, #if NCCL_VERSION_CODE >= 21000 - ops::AllToAllOpCUDAKernel, + plat::bfloat16, #endif - ops::AllToAllOpCUDAKernel, - ops::AllToAllOpCUDAKernel, - ops::AllToAllOpCUDAKernel); + int, + int64_t, + plat::float16) { +} diff --git a/paddle/fluid/operators/collective/alltoall_op.h b/paddle/fluid/operators/collective/alltoall_op.h index 61eec44093794ccaf820d257d7c2c6b363e10391..bbded798efd364641b0bd8e0c319b8bf5895d64e 100644 --- a/paddle/fluid/operators/collective/alltoall_op.h +++ b/paddle/fluid/operators/collective/alltoall_op.h @@ -29,7 +29,7 @@ limitations under the License. */ namespace paddle { namespace operators { -template +template class AllToAllOpCPUKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& ctx) const override { diff --git a/paddle/fluid/operators/detection/anchor_generator_op.cc b/paddle/fluid/operators/detection/anchor_generator_op.cc index 7a1397ba08f17d75aa7abec2d3e2e426f9a02b59..71f7bb938a92d8d3046015420c89e571dffe7d65 100644 --- a/paddle/fluid/operators/detection/anchor_generator_op.cc +++ b/paddle/fluid/operators/detection/anchor_generator_op.cc @@ -175,6 +175,9 @@ REGISTER_OPERATOR( paddle::framework::EmptyGradOpMaker, paddle::framework::EmptyGradOpMaker); -REGISTER_OP_CPU_KERNEL(anchor_generator, - ops::AnchorGeneratorOpKernel, - ops::AnchorGeneratorOpKernel); +PD_REGISTER_STRUCT_KERNEL(anchor_generator, + CPU, + ALL_LAYOUT, + ops::AnchorGeneratorOpKernel, + float, + double) {} diff --git a/paddle/fluid/operators/detection/anchor_generator_op.cu b/paddle/fluid/operators/detection/anchor_generator_op.cu index eeb4d731b7b3b8e2398a8e8e648939b7e7d66845..342a492794d5585f836f915c81577518335ea9b4 100644 --- a/paddle/fluid/operators/detection/anchor_generator_op.cu +++ b/paddle/fluid/operators/detection/anchor_generator_op.cu @@ -71,7 +71,7 @@ __global__ void SetVariance(T* out, CUDA_KERNEL_LOOP(i, num) { out[i] = var[i % vnum]; } } -template +template class AnchorGeneratorOpCUDAKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& ctx) const override { @@ -133,6 +133,10 @@ class AnchorGeneratorOpCUDAKernel : public framework::OpKernel { } // namespace paddle namespace ops = paddle::operators; -REGISTER_OP_CUDA_KERNEL(anchor_generator, - ops::AnchorGeneratorOpCUDAKernel, - ops::AnchorGeneratorOpCUDAKernel); + +PD_REGISTER_STRUCT_KERNEL(anchor_generator, + GPU, + ALL_LAYOUT, + ops::AnchorGeneratorOpCUDAKernel, + float, + double) {} diff --git a/paddle/fluid/operators/detection/anchor_generator_op.h b/paddle/fluid/operators/detection/anchor_generator_op.h index 726b65fb1f427ce19cb75f9495bf94f6cfb237a0..9e667d9f99fc1cbbe02c24f607485bcc63bf135b 100644 --- a/paddle/fluid/operators/detection/anchor_generator_op.h +++ b/paddle/fluid/operators/detection/anchor_generator_op.h @@ -44,7 +44,7 @@ extern __global__ void SetVariance(T* out, const int num); #endif -template +template class AnchorGeneratorOpKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& ctx) const override { diff --git a/python/paddle/fluid/tests/unittests/test_anchor_generator_op.py b/python/paddle/fluid/tests/unittests/test_anchor_generator_op.py index cfe29dce1058d751d1303583c4891016f134d384..6c30031e6308bc6a6dea6a17474554c7f04b3ffa 100644 --- a/python/paddle/fluid/tests/unittests/test_anchor_generator_op.py +++ b/python/paddle/fluid/tests/unittests/test_anchor_generator_op.py @@ -77,7 +77,7 @@ class TestAnchorGeneratorOp(OpTest): self.outputs = {'Anchors': self.out_anchors, 'Variances': self.out_var} def test_check_output(self): - self.check_output() + self.check_output(check_dygraph=False) def setUp(self): self.op_type = "anchor_generator"