From cb81befa9474a7f3d1b30f7afbfba74ea6806093 Mon Sep 17 00:00:00 2001 From: huangjiyi <43315610+huangjiyi@users.noreply.github.com> Date: Tue, 18 Apr 2023 15:05:53 +0800 Subject: [PATCH] register fluid kernels to phi [part 6.5] (#52882) * update * fix bug * update * fix bug --- paddle/fluid/framework/operator.cc | 41 ++++++++++++------- .../collective/mp_allreduce_sum_op.cc | 17 +++++--- .../collective/mp_allreduce_sum_op.cu.cc | 26 ++++++++---- .../operators/detection/iou_similarity_op.cc | 5 +-- .../operators/detection/iou_similarity_op.cu | 5 +-- .../operators/detection/iou_similarity_op.h | 2 +- .../detection/locality_aware_nms_op.cc | 11 +++-- .../detection/mine_hard_examples_op.cc | 11 +++-- .../operators/detection/multiclass_nms_op.cc | 20 +++++---- .../operators/fused/multihead_matmul_op.cu | 17 ++++---- .../fluid/operators/match_matrix_tensor_op.cc | 21 ++++++---- paddle/fluid/operators/mean_iou_op.cc | 7 ++-- paddle/fluid/operators/mean_iou_op.cu | 8 ++-- paddle/fluid/operators/mean_iou_op.h | 2 +- paddle/fluid/operators/minus_op.cc | 6 ++- paddle/fluid/operators/minus_op.h | 2 +- .../operators/optimizers/lars_momentum_op.cc | 6 +-- .../operators/optimizers/lars_momentum_op.cu | 15 ++++--- .../operators/optimizers/lars_momentum_op.h | 2 +- .../optimizers/unity_build_rule.cmake | 1 - 20 files changed, 133 insertions(+), 92 deletions(-) diff --git a/paddle/fluid/framework/operator.cc b/paddle/fluid/framework/operator.cc index c37a5260bcd..fc605bebc54 100644 --- a/paddle/fluid/framework/operator.cc +++ b/paddle/fluid/framework/operator.cc @@ -2483,21 +2483,32 @@ Scope* OperatorWithKernel::PrepareData( } std::unique_ptr new_expected_kernel_key = nullptr; - if (run_phi_kernel_ && in_def != nullptr && - in_def->backend != phi::Backend::ALL_BACKEND) { - auto tensor_backend = phi::TransToPhiBackend(tensor_in->place()); - if ((in_def->backend != tensor_backend && - !(in_def->backend == phi::Backend::GPUDNN && - tensor_backend == phi::Backend::GPU) && - 
!(in_def->backend == phi::Backend::KPS && - tensor_backend == phi::Backend::XPU) && - !(in_def->backend == phi::Backend::ONEDNN && - tensor_backend == phi::Backend::CPU)) || - tensor_in->place().GetType() == AllocationType::GPUPINNED) { - new_expected_kernel_key = - std::make_unique(in_def->backend, - expected_kernel_key.layout(), - expected_kernel_key.dtype()); + if (run_phi_kernel_) { + if (phi_kernel_->GetKernelRegisteredType() == + phi::KernelRegisteredType::STRUCTURE) { + if (!backends_are_same_class(kernel_type_for_var.backend(), + expected_kernel_key.backend())) { + new_expected_kernel_key = + std::make_unique(expected_kernel_key.backend(), + expected_kernel_key.layout(), + expected_kernel_key.dtype()); + } + } else if (in_def != nullptr && + in_def->backend != phi::Backend::ALL_BACKEND) { + auto tensor_backend = phi::TransToPhiBackend(tensor_in->place()); + if ((in_def->backend != tensor_backend && + !(in_def->backend == phi::Backend::GPUDNN && + tensor_backend == phi::Backend::GPU) && + !(in_def->backend == phi::Backend::KPS && + tensor_backend == phi::Backend::XPU) && + !(in_def->backend == phi::Backend::ONEDNN && + tensor_backend == phi::Backend::CPU)) || + tensor_in->place().GetType() == AllocationType::GPUPINNED) { + new_expected_kernel_key = + std::make_unique(in_def->backend, + expected_kernel_key.layout(), + expected_kernel_key.dtype()); + } } } diff --git a/paddle/fluid/operators/collective/mp_allreduce_sum_op.cc b/paddle/fluid/operators/collective/mp_allreduce_sum_op.cc index 121595714e0..dcc59f703ff 100644 --- a/paddle/fluid/operators/collective/mp_allreduce_sum_op.cc +++ b/paddle/fluid/operators/collective/mp_allreduce_sum_op.cc @@ -73,6 +73,8 @@ class MpAllReduceSumOpGradMaker : public framework::SingleGradOpMaker { DECLARE_INPLACE_OP_INFERER(MpAllReduceSumInplaceInferer, {"X", "Out"}); +DEFINE_C_ALLREDUCE_CPU_KERNEL(MpAllReduceSum, kRedSum); + } // namespace operators } // namespace paddle @@ -86,9 +88,12 @@ REGISTER_OPERATOR(mp_allreduce_sum, 
ops::MpAllReduceSumOpMaker, ops::MpAllReduceSumInplaceInferer); -REGISTER_OP_CPU_KERNEL(mp_allreduce_sum, - ops::CAllReduceOpCPUKernel, - ops::CAllReduceOpCPUKernel, - ops::CAllReduceOpCPUKernel, - ops::CAllReduceOpCPUKernel, - ops::CAllReduceOpCPUKernel) +PD_REGISTER_STRUCT_KERNEL(mp_allreduce_sum, + CPU, + ALL_LAYOUT, + ops::MpAllReduceSumCPUKernel, + float, + double, + int, + int64_t, + plat::float16) {} diff --git a/paddle/fluid/operators/collective/mp_allreduce_sum_op.cu.cc b/paddle/fluid/operators/collective/mp_allreduce_sum_op.cu.cc index 26092a55e0f..b6af2dbd1c8 100644 --- a/paddle/fluid/operators/collective/mp_allreduce_sum_op.cu.cc +++ b/paddle/fluid/operators/collective/mp_allreduce_sum_op.cu.cc @@ -15,16 +15,24 @@ #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/operators/collective/c_allreduce_op.h" +namespace paddle { +namespace operators { +DEFINE_C_ALLREDUCE_CUDA_KERNEL(MpAllReduceSum, kRedSum) +} // namespace operators +} // namespace paddle + namespace ops = paddle::operators; namespace plat = paddle::platform; - -REGISTER_OP_CUDA_KERNEL( - mp_allreduce_sum, - ops::CAllReduceOpCUDAKernel, +PD_REGISTER_STRUCT_KERNEL(mp_allreduce_sum, + GPU, + ALL_LAYOUT, + ops::MpAllReduceSumCUDAKernel, + float, + double, + int, + int64_t, #if NCCL_VERSION_CODE >= 21000 - ops::CAllReduceOpCUDAKernel, + plat::bfloat16, #endif - ops::CAllReduceOpCUDAKernel, - ops::CAllReduceOpCUDAKernel, - ops::CAllReduceOpCUDAKernel, - ops::CAllReduceOpCUDAKernel) + plat::float16) { +} diff --git a/paddle/fluid/operators/detection/iou_similarity_op.cc b/paddle/fluid/operators/detection/iou_similarity_op.cc index 406114c588a..ca107077232 100644 --- a/paddle/fluid/operators/detection/iou_similarity_op.cc +++ b/paddle/fluid/operators/detection/iou_similarity_op.cc @@ -114,6 +114,5 @@ REGISTER_OPERATOR( paddle::framework::EmptyGradOpMaker, paddle::framework::EmptyGradOpMaker); -REGISTER_OP_CPU_KERNEL(iou_similarity, - ops::IOUSimilarityKernel, - 
ops::IOUSimilarityKernel); +PD_REGISTER_STRUCT_KERNEL( + iou_similarity, CPU, ALL_LAYOUT, ops::IOUSimilarityKernel, float, double) {} diff --git a/paddle/fluid/operators/detection/iou_similarity_op.cu b/paddle/fluid/operators/detection/iou_similarity_op.cu index dc27f326538..e4e001e0965 100644 --- a/paddle/fluid/operators/detection/iou_similarity_op.cu +++ b/paddle/fluid/operators/detection/iou_similarity_op.cu @@ -15,6 +15,5 @@ limitations under the License. */ #include "paddle/fluid/operators/detection/iou_similarity_op.h" namespace ops = paddle::operators; -REGISTER_OP_CUDA_KERNEL(iou_similarity, - ops::IOUSimilarityKernel, - ops::IOUSimilarityKernel); +PD_REGISTER_STRUCT_KERNEL( + iou_similarity, GPU, ALL_LAYOUT, ops::IOUSimilarityKernel, float, double) {} diff --git a/paddle/fluid/operators/detection/iou_similarity_op.h b/paddle/fluid/operators/detection/iou_similarity_op.h index 3d14bb2ae62..75e7b909624 100644 --- a/paddle/fluid/operators/detection/iou_similarity_op.h +++ b/paddle/fluid/operators/detection/iou_similarity_op.h @@ -105,7 +105,7 @@ struct IOUSimilarityFunctor { namespace paddle { namespace operators { -template +template class IOUSimilarityKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& ctx) const override { diff --git a/paddle/fluid/operators/detection/locality_aware_nms_op.cc b/paddle/fluid/operators/detection/locality_aware_nms_op.cc index 9a230dc3224..dc512c9f839 100644 --- a/paddle/fluid/operators/detection/locality_aware_nms_op.cc +++ b/paddle/fluid/operators/detection/locality_aware_nms_op.cc @@ -160,7 +160,7 @@ void GetMaxScoreIndexWithLocalityAware( } } -template +template class LocalityAwareNMSKernel : public framework::OpKernel { public: void LocalityAwareNMSFast(phi::DenseTensor* bbox, @@ -520,6 +520,9 @@ REGISTER_OPERATOR( ops::LocalityAwareNMSOpMaker, paddle::framework::EmptyGradOpMaker, paddle::framework::EmptyGradOpMaker); -REGISTER_OP_CPU_KERNEL(locality_aware_nms, - 
ops::LocalityAwareNMSKernel, - ops::LocalityAwareNMSKernel); +PD_REGISTER_STRUCT_KERNEL(locality_aware_nms, + CPU, + ALL_LAYOUT, + ops::LocalityAwareNMSKernel, + float, + double) {} diff --git a/paddle/fluid/operators/detection/mine_hard_examples_op.cc b/paddle/fluid/operators/detection/mine_hard_examples_op.cc index a673d64c52d..3e2ad2c8564 100644 --- a/paddle/fluid/operators/detection/mine_hard_examples_op.cc +++ b/paddle/fluid/operators/detection/mine_hard_examples_op.cc @@ -49,7 +49,7 @@ inline MiningType GetMiningType(std::string str) { } } -template +template class MineHardExamplesKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& ctx) const override { @@ -403,6 +403,9 @@ REGISTER_OPERATOR( paddle::framework::EmptyGradOpMaker, paddle::framework::EmptyGradOpMaker); -REGISTER_OP_CPU_KERNEL(mine_hard_examples, - ops::MineHardExamplesKernel, - ops::MineHardExamplesKernel); +PD_REGISTER_STRUCT_KERNEL(mine_hard_examples, + CPU, + ALL_LAYOUT, + ops::MineHardExamplesKernel, + float, + double) {} diff --git a/paddle/fluid/operators/detection/multiclass_nms_op.cc b/paddle/fluid/operators/detection/multiclass_nms_op.cc index ed5fca119cf..159027dacad 100644 --- a/paddle/fluid/operators/detection/multiclass_nms_op.cc +++ b/paddle/fluid/operators/detection/multiclass_nms_op.cc @@ -143,7 +143,7 @@ void SliceOneClass(const platform::DeviceContext& ctx, } } -template +template class MultiClassNMSKernel : public framework::OpKernel { public: void NMSFast(const phi::DenseTensor& bbox, @@ -629,6 +629,9 @@ class MultiClassNMS3OpMaker : public MultiClassNMS2OpMaker { } }; +template +class MultiClassNMS2Kernel : public MultiClassNMSKernel {}; + } // namespace operators } // namespace paddle @@ -643,18 +646,21 @@ REGISTER_OPERATOR( ops::MultiClassNMSOpMaker, paddle::framework::EmptyGradOpMaker, paddle::framework::EmptyGradOpMaker); -REGISTER_OP_CPU_KERNEL(multiclass_nms, - ops::MultiClassNMSKernel, - ops::MultiClassNMSKernel); 
+PD_REGISTER_STRUCT_KERNEL( + multiclass_nms, CPU, ALL_LAYOUT, ops::MultiClassNMSKernel, float, double) {} + REGISTER_OPERATOR( multiclass_nms2, ops::MultiClassNMS2Op, ops::MultiClassNMS2OpMaker, paddle::framework::EmptyGradOpMaker, paddle::framework::EmptyGradOpMaker); -REGISTER_OP_CPU_KERNEL(multiclass_nms2, - ops::MultiClassNMSKernel, - ops::MultiClassNMSKernel); +PD_REGISTER_STRUCT_KERNEL(multiclass_nms2, + CPU, + ALL_LAYOUT, + ops::MultiClassNMS2Kernel, + float, + double) {} REGISTER_OPERATOR( multiclass_nms3, diff --git a/paddle/fluid/operators/fused/multihead_matmul_op.cu b/paddle/fluid/operators/fused/multihead_matmul_op.cu index 0b9f23657a2..4cd7254eaf1 100644 --- a/paddle/fluid/operators/fused/multihead_matmul_op.cu +++ b/paddle/fluid/operators/fused/multihead_matmul_op.cu @@ -270,7 +270,7 @@ __global__ void broadcast_batch_head_number(const T *src, } } -template +template class MultiHeadMatMulV2Kernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext &context) const override { @@ -423,12 +423,15 @@ class MultiHeadMatMulV2Kernel : public framework::OpKernel { } // namespace paddle namespace ops = paddle::operators; +namespace plat = paddle::platform; #if defined(PADDLE_WITH_CUDA) && CUDA_VERSION >= 10000 -REGISTER_OP_CUDA_KERNEL( - multihead_matmul, - ops::MultiHeadMatMulV2Kernel, - ops::MultiHeadMatMulV2Kernel); +PD_REGISTER_STRUCT_KERNEL(multihead_matmul, + GPU, + ALL_LAYOUT, + ops::MultiHeadMatMulV2Kernel, + float, + plat::float16) {} #else -REGISTER_OP_CUDA_KERNEL(multihead_matmul, - ops::MultiHeadMatMulV2Kernel); +PD_REGISTER_STRUCT_KERNEL( + multihead_matmul, GPU, ALL_LAYOUT, ops::MultiHeadMatMulV2Kernel, float) {} #endif diff --git a/paddle/fluid/operators/match_matrix_tensor_op.cc b/paddle/fluid/operators/match_matrix_tensor_op.cc index 773d9f223f8..271a027c456 100644 --- a/paddle/fluid/operators/match_matrix_tensor_op.cc +++ b/paddle/fluid/operators/match_matrix_tensor_op.cc @@ -240,7 +240,7 @@ void 
MatchMatrixTensorOpMaker::Make() { )DOC"); } -template +template class CPUMatchMatrixTensorOPKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& ctx) const override { @@ -321,7 +321,7 @@ class CPUMatchMatrixTensorOPKernel : public framework::OpKernel { } }; -template +template class CPUMatchMatrixTensorOPGradKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& ctx) const override { @@ -458,10 +458,13 @@ REGISTER_OPERATOR( ops::MatchMatrixTensorGradOpMaker); REGISTER_OPERATOR(match_matrix_tensor_grad, ops::MatchMatrixTensorOpGrad); -REGISTER_OP_CPU_KERNEL( - match_matrix_tensor, - ops::CPUMatchMatrixTensorOPKernel); - -REGISTER_OP_CPU_KERNEL( - match_matrix_tensor_grad, - ops::CPUMatchMatrixTensorOPGradKernel); +PD_REGISTER_STRUCT_KERNEL(match_matrix_tensor, + CPU, + ALL_LAYOUT, + ops::CPUMatchMatrixTensorOPKernel, + float) {} +PD_REGISTER_STRUCT_KERNEL(match_matrix_tensor_grad, + CPU, + ALL_LAYOUT, + ops::CPUMatchMatrixTensorOPGradKernel, + float) {} diff --git a/paddle/fluid/operators/mean_iou_op.cc b/paddle/fluid/operators/mean_iou_op.cc index 3728fbee534..27fd86fab08 100644 --- a/paddle/fluid/operators/mean_iou_op.cc +++ b/paddle/fluid/operators/mean_iou_op.cc @@ -107,7 +107,6 @@ REGISTER_OPERATOR( ops::MeanIoUOpMaker, paddle::framework::EmptyGradOpMaker, paddle::framework::EmptyGradOpMaker); -REGISTER_OP_CPU_KERNEL(mean_iou, - ops::MeanIoUKernel, - ops::MeanIoUKernel, - ops::MeanIoUKernel); + +PD_REGISTER_STRUCT_KERNEL( + mean_iou, CPU, ALL_LAYOUT, ops::MeanIoUKernel, int, int64_t) {} diff --git a/paddle/fluid/operators/mean_iou_op.cu b/paddle/fluid/operators/mean_iou_op.cu index e73496a46a0..1dbc9f6fdc8 100644 --- a/paddle/fluid/operators/mean_iou_op.cu +++ b/paddle/fluid/operators/mean_iou_op.cu @@ -88,7 +88,7 @@ __global__ void ComputeIoUCUDAKernel( } } -template +template class MeanIoUCUDAOpKernel : public framework::OpKernel { public: void Compute(const 
framework::ExecutionContext& ctx) const override { @@ -166,7 +166,5 @@ class MeanIoUCUDAOpKernel : public framework::OpKernel { } // namespace paddle namespace ops = paddle::operators; -REGISTER_OP_CUDA_KERNEL(mean_iou, - ops::MeanIoUCUDAOpKernel, - ops::MeanIoUCUDAOpKernel, - ops::MeanIoUCUDAOpKernel); +PD_REGISTER_STRUCT_KERNEL( + mean_iou, GPU, ALL_LAYOUT, ops::MeanIoUCUDAOpKernel, int, int64_t) {} diff --git a/paddle/fluid/operators/mean_iou_op.h b/paddle/fluid/operators/mean_iou_op.h index 9be97f5ba95..436cf84a548 100644 --- a/paddle/fluid/operators/mean_iou_op.h +++ b/paddle/fluid/operators/mean_iou_op.h @@ -27,7 +27,7 @@ template using EigenTensor = framework::EigenTensor; -template +template class MeanIoUKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& ctx) const override { diff --git a/paddle/fluid/operators/minus_op.cc b/paddle/fluid/operators/minus_op.cc index 398a254f45c..8c33a5da1ba 100644 --- a/paddle/fluid/operators/minus_op.cc +++ b/paddle/fluid/operators/minus_op.cc @@ -155,6 +155,8 @@ REGISTER_OPERATOR(minus, ops::MinusOpMaker, ops::MinusGradDescMaker, ops::MinusGradMaker); -REGISTER_OP_CPU_KERNEL(minus, ops::MinusKernel); +PD_REGISTER_STRUCT_KERNEL(minus, CPU, ALL_LAYOUT, ops::MinusKernel, float) {} -REGISTER_OP_CUDA_KERNEL(minus, ops::MinusKernel); +#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) +PD_REGISTER_STRUCT_KERNEL(minus, GPU, ALL_LAYOUT, ops::MinusKernel, float) {} +#endif diff --git a/paddle/fluid/operators/minus_op.h b/paddle/fluid/operators/minus_op.h index 0a576e875a4..8cc18fe0c97 100644 --- a/paddle/fluid/operators/minus_op.h +++ b/paddle/fluid/operators/minus_op.h @@ -20,7 +20,7 @@ limitations under the License. 
*/ namespace paddle { namespace operators { -template +template class MinusKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& context) const override { diff --git a/paddle/fluid/operators/optimizers/lars_momentum_op.cc b/paddle/fluid/operators/optimizers/lars_momentum_op.cc index b5b15fa09ea..bf8461d49a0 100644 --- a/paddle/fluid/operators/optimizers/lars_momentum_op.cc +++ b/paddle/fluid/operators/optimizers/lars_momentum_op.cc @@ -233,6 +233,6 @@ REGISTER_OPERATOR( paddle::framework::EmptyGradOpMaker, paddle::framework::EmptyGradOpMaker, ops::LarsMomentumOpVarTypeInference); -REGISTER_OP_CPU_KERNEL(lars_momentum, - ops::LarsMomentumOpKernel, - ops::LarsMomentumOpKernel); + +PD_REGISTER_STRUCT_KERNEL( + lars_momentum, CPU, ALL_LAYOUT, ops::LarsMomentumOpKernel, float, double) {} diff --git a/paddle/fluid/operators/optimizers/lars_momentum_op.cu b/paddle/fluid/operators/optimizers/lars_momentum_op.cu index c91752fef0b..83293c991e9 100644 --- a/paddle/fluid/operators/optimizers/lars_momentum_op.cu +++ b/paddle/fluid/operators/optimizers/lars_momentum_op.cu @@ -474,7 +474,7 @@ inline void SeparatedLarsMomentumOpCUDAKernel(const phi::GPUContext& cuda_ctx, is_amp); } -template +template class LarsMomentumOpCUDAKernel : public framework::OpKernel { using MT = MultiPrecisionType; @@ -679,8 +679,11 @@ class LarsMomentumOpCUDAKernel : public framework::OpKernel { } // namespace paddle namespace ops = paddle::operators; -REGISTER_OP_CUDA_KERNEL( - lars_momentum, - ops::LarsMomentumOpCUDAKernel, - ops::LarsMomentumOpCUDAKernel, - ops::LarsMomentumOpCUDAKernel); +namespace plat = paddle::platform; +PD_REGISTER_STRUCT_KERNEL(lars_momentum, + GPU, + ALL_LAYOUT, + ops::LarsMomentumOpCUDAKernel, + float, + double, + plat::float16) {} diff --git a/paddle/fluid/operators/optimizers/lars_momentum_op.h b/paddle/fluid/operators/optimizers/lars_momentum_op.h index 530afa20a4d..70bf0c9186b 100644 --- 
a/paddle/fluid/operators/optimizers/lars_momentum_op.h +++ b/paddle/fluid/operators/optimizers/lars_momentum_op.h @@ -19,7 +19,7 @@ limitations under the License. */ namespace paddle { namespace operators { -template +template class LarsMomentumOpKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& ctx) const override { diff --git a/paddle/fluid/operators/optimizers/unity_build_rule.cmake b/paddle/fluid/operators/optimizers/unity_build_rule.cmake index 8f89abf1a09..272eb118892 100644 --- a/paddle/fluid/operators/optimizers/unity_build_rule.cmake +++ b/paddle/fluid/operators/optimizers/unity_build_rule.cmake @@ -10,7 +10,6 @@ register_unity_group( lars_momentum_op.cc proximal_adagrad_op.cc adam_op.cc - dgc_momentum_op.cc proximal_gd_op.cc decayed_adagrad_op.cc adadelta_op.cc -- GitLab