From 28492bf6b3699eabe991449616c0853637ec0212 Mon Sep 17 00:00:00 2001 From: YuanRisheng Date: Wed, 19 Apr 2023 17:06:02 +0800 Subject: [PATCH] [PHI]Unify fluid kernel (Part5) (#53003) * unify_kernel * fix compile bugs * fix py3 bugs * fix xpu bugs --- .../fluid/operators/collective/send_v2_op.cc | 15 ++++++---- .../operators/collective/send_v2_op.cu.cc | 22 ++++++++------ .../fluid/operators/collective/send_v2_op.h | 2 +- .../operators/pscore/send_and_recv_op.cc | 30 ++++++++++++------- .../pscore/send_and_recv_op_cpu_test.cc | 2 +- .../pscore/send_and_recv_op_gpu_test.cc | 2 +- paddle/fluid/operators/sample_logits_op.cc | 14 +++++---- paddle/fluid/operators/sample_logits_op.cu | 22 +++++++++----- paddle/fluid/operators/sample_logits_op.h | 4 +-- paddle/fluid/operators/sampling_id_op.cc | 9 ++++-- paddle/fluid/operators/sampling_id_op.cu | 9 ++++-- paddle/fluid/operators/sampling_id_op.h | 2 +- paddle/fluid/operators/sampling_id_op_xpu.cc | 7 +++-- paddle/fluid/operators/seed_op.cc | 2 +- paddle/fluid/operators/seed_op.cu | 6 ++-- paddle/fluid/operators/seed_op.h | 2 +- .../sequence_ops/sequence_concat_op.cc | 26 +++++++++------- .../sequence_ops/sequence_concat_op.cu.cc | 28 +++++++++-------- .../sequence_ops/sequence_concat_op.h | 4 +-- .../sequence_ops/unity_build_rule.cmake | 1 - 20 files changed, 125 insertions(+), 84 deletions(-) diff --git a/paddle/fluid/operators/collective/send_v2_op.cc b/paddle/fluid/operators/collective/send_v2_op.cc index 8fca90c5ccc..5f1f766316c 100644 --- a/paddle/fluid/operators/collective/send_v2_op.cc +++ b/paddle/fluid/operators/collective/send_v2_op.cc @@ -86,9 +86,12 @@ namespace plat = paddle::platform; REGISTER_OP_WITHOUT_GRADIENT(send_v2, ops::SendOpV2, ops::SendOpV2Maker); -REGISTER_OP_CPU_KERNEL(send_v2, - ops::SendOpV2CPUKernel, - ops::SendOpV2CPUKernel, - ops::SendOpV2CPUKernel, - ops::SendOpV2CPUKernel, - ops::SendOpV2CPUKernel); +PD_REGISTER_STRUCT_KERNEL(send_v2, + CPU, + ALL_LAYOUT, + ops::SendOpV2CPUKernel, + float, + double, + int, + int64_t, + plat::float16) {} diff --git a/paddle/fluid/operators/collective/send_v2_op.cu.cc b/paddle/fluid/operators/collective/send_v2_op.cu.cc index e6418e0f032..7c1ab8ace34 100644 --- a/paddle/fluid/operators/collective/send_v2_op.cu.cc +++ b/paddle/fluid/operators/collective/send_v2_op.cu.cc @@ -104,7 +104,7 @@ void send_shape_info(const phi::DenseTensor& x, } #endif -template +template class SendOpV2CUDAKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& ctx) const override { @@ -217,13 +217,17 @@ class SendOpV2CUDAKernel : public framework::OpKernel { namespace ops = paddle::operators; namespace plat = paddle::platform; -REGISTER_OP_CUDA_KERNEL(send_v2, - ops::SendOpV2CUDAKernel, - ops::SendOpV2CUDAKernel, +PD_REGISTER_STRUCT_KERNEL(send_v2, + GPU, + ALL_LAYOUT, + ops::SendOpV2CUDAKernel, + float, + double, #if NCCL_VERSION_CODE >= 21000 - ops::SendOpV2CUDAKernel, + plat::bfloat16, #endif - ops::SendOpV2CUDAKernel, - ops::SendOpV2CUDAKernel, - ops::SendOpV2CUDAKernel, - ops::SendOpV2CUDAKernel); + int, + int64_t, + int8_t, + plat::float16) { +} diff --git a/paddle/fluid/operators/collective/send_v2_op.h b/paddle/fluid/operators/collective/send_v2_op.h index 6215fb1f3b6..047796dfe24 100644 --- a/paddle/fluid/operators/collective/send_v2_op.h +++ b/paddle/fluid/operators/collective/send_v2_op.h @@ -25,7 +25,7 @@ limitations under the License. */ namespace paddle { namespace operators { -template +template class SendOpV2CPUKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& ctx) const override { diff --git a/paddle/fluid/operators/pscore/send_and_recv_op.cc b/paddle/fluid/operators/pscore/send_and_recv_op.cc index d2526211163..99e8d04a9e3 100644 --- a/paddle/fluid/operators/pscore/send_and_recv_op.cc +++ b/paddle/fluid/operators/pscore/send_and_recv_op.cc @@ -26,7 +26,7 @@ limitations under the License. */ namespace paddle { namespace operators { -template +template class SendAndRecvKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& ctx) const override { @@ -98,17 +98,25 @@ class SendAndRecvOpMaker : public framework::OpProtoAndCheckerMaker { namespace ops = paddle::operators; REGISTER_OPERATOR(send_and_recv, ops::SendAndRecvOp, ops::SendAndRecvOpMaker); -REGISTER_OP_CUDA_KERNEL(send_and_recv, - ops::SendAndRecvKernel, - ops::SendAndRecvKernel, - ops::SendAndRecvKernel, - ops::SendAndRecvKernel); -REGISTER_OP_CPU_KERNEL(send_and_recv, - ops::SendAndRecvKernel, - ops::SendAndRecvKernel, - ops::SendAndRecvKernel, - ops::SendAndRecvKernel); +PD_REGISTER_STRUCT_KERNEL(send_and_recv, + CPU, + ALL_LAYOUT, + ops::SendAndRecvKernel, + float, + double, + int, + int64_t) {} +#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) +PD_REGISTER_STRUCT_KERNEL(send_and_recv, + GPU, + ALL_LAYOUT, + ops::SendAndRecvKernel, + float, + double, + int, + int64_t) {} +#endif REGISTER_OP_VERSION(send_and_recv) .AddCheckpoint( R"ROC(add new attributes [next_endpoints] [previous_endpoints] and [mode])ROC", diff --git a/paddle/fluid/operators/pscore/send_and_recv_op_cpu_test.cc b/paddle/fluid/operators/pscore/send_and_recv_op_cpu_test.cc index b575e7b5fa4..5087857c2f0 100644 --- a/paddle/fluid/operators/pscore/send_and_recv_op_cpu_test.cc +++ b/paddle/fluid/operators/pscore/send_and_recv_op_cpu_test.cc @@ -35,7 +35,7 @@ using MultiVarMsg = ::paddle::distributed::MultiVariableMessage; using VarMsg = ::paddle::distributed::VariableMessage; USE_OP_ITSELF(scale); -USE_OP(send_and_recv); +PD_DECLARE_KERNEL(send_and_recv, CPU, ALL_LAYOUT); std::string get_ip_port() { std::mt19937 rng; diff --git a/paddle/fluid/operators/pscore/send_and_recv_op_gpu_test.cc b/paddle/fluid/operators/pscore/send_and_recv_op_gpu_test.cc index faac4865975..e00d7f1dac1 100644 --- a/paddle/fluid/operators/pscore/send_and_recv_op_gpu_test.cc +++ b/paddle/fluid/operators/pscore/send_and_recv_op_gpu_test.cc @@ -39,7 +39,7 @@ using MultiVarMsg = ::paddle::distributed::MultiVariableMessage; using VarMsg = ::paddle::distributed::VariableMessage; USE_OP_ITSELF(scale); -USE_OP(send_and_recv); +PD_DECLARE_KERNEL(send_and_recv, GPU, ALL_LAYOUT); std::shared_ptr b_rpc_service2; diff --git a/paddle/fluid/operators/sample_logits_op.cc b/paddle/fluid/operators/sample_logits_op.cc index db9944ffb11..469f91b16d1 100644 --- a/paddle/fluid/operators/sample_logits_op.cc +++ b/paddle/fluid/operators/sample_logits_op.cc @@ -272,9 +272,11 @@ REGISTER_OPERATOR(sample_logits, ops::SampleLogitsGradMaker, ops::SampleLogitsGradMaker); REGISTER_OPERATOR(sample_logits_grad, ops::SampleLogitsOpGrad); -REGISTER_OP_CPU_KERNEL(sample_logits, - ops::SampleLogitsKernel, - ops::SampleLogitsKernel); -REGISTER_OP_CPU_KERNEL(sample_logits_grad, - ops::SampleLogitsGradKernel, - ops::SampleLogitsGradKernel); +PD_REGISTER_STRUCT_KERNEL( + sample_logits, CPU, ALL_LAYOUT, ops::SampleLogitsKernel, float, double) {} +PD_REGISTER_STRUCT_KERNEL(sample_logits_grad, + CPU, + ALL_LAYOUT, + ops::SampleLogitsGradKernel, + float, + double) {} diff --git a/paddle/fluid/operators/sample_logits_op.cu b/paddle/fluid/operators/sample_logits_op.cu index a24cb99b6ea..6a853f71e6f 100644 --- a/paddle/fluid/operators/sample_logits_op.cu +++ b/paddle/fluid/operators/sample_logits_op.cu @@ -109,7 +109,7 @@ __global__ void gpu_compute_remove_accidental_hits(const int size, } } -template +template class SampleLogitsCUDAKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& context) const override { @@ -235,7 +235,7 @@ class SampleLogitsCUDAKernel : public framework::OpKernel { } }; -template +template class SampleLogitsGradCUDAKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& context) const override { @@ -287,9 +287,15 @@ class SampleLogitsGradCUDAKernel : public framework::OpKernel { } // namespace paddle namespace ops = paddle::operators; -REGISTER_OP_CUDA_KERNEL(sample_logits, - ops::SampleLogitsCUDAKernel, - ops::SampleLogitsCUDAKernel); -REGISTER_OP_CUDA_KERNEL(sample_logits_grad, - ops::SampleLogitsGradCUDAKernel, - ops::SampleLogitsGradCUDAKernel); +PD_REGISTER_STRUCT_KERNEL(sample_logits, + GPU, + ALL_LAYOUT, + ops::SampleLogitsCUDAKernel, + float, + double) {} +PD_REGISTER_STRUCT_KERNEL(sample_logits_grad, + GPU, + ALL_LAYOUT, + ops::SampleLogitsGradCUDAKernel, + float, + double) {} diff --git a/paddle/fluid/operators/sample_logits_op.h b/paddle/fluid/operators/sample_logits_op.h index a8413a4988d..bf58a054dad 100644 --- a/paddle/fluid/operators/sample_logits_op.h +++ b/paddle/fluid/operators/sample_logits_op.h @@ -208,7 +208,7 @@ static void compute_remove_accidental_hits(const platform::DeviceContext& ctx, } } -template +template class SampleLogitsKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& context) const override { @@ -304,7 +304,7 @@ class SampleLogitsKernel : public framework::OpKernel { } }; -template +template class SampleLogitsGradKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& context) const override { diff --git a/paddle/fluid/operators/sampling_id_op.cc b/paddle/fluid/operators/sampling_id_op.cc index 7e84077fd60..785d148f79d 100644 --- a/paddle/fluid/operators/sampling_id_op.cc +++ b/paddle/fluid/operators/sampling_id_op.cc @@ -82,6 +82,9 @@ REGISTER_OPERATOR( paddle::framework::EmptyGradOpMaker, paddle::framework::EmptyGradOpMaker); -REGISTER_OP_CPU_KERNEL(sampling_id, - paddle::operators::SamplingIdKernel, - paddle::operators::SamplingIdKernel); +PD_REGISTER_STRUCT_KERNEL(sampling_id, + CPU, + ALL_LAYOUT, + paddle::operators::SamplingIdKernel, + float, + double) {} diff --git a/paddle/fluid/operators/sampling_id_op.cu b/paddle/fluid/operators/sampling_id_op.cu index 64b50d7fa77..2ec00d125bc 100644 --- a/paddle/fluid/operators/sampling_id_op.cu +++ b/paddle/fluid/operators/sampling_id_op.cu @@ -15,6 +15,9 @@ #include "paddle/fluid/operators/sampling_id_op.h" namespace ops = paddle::operators; -REGISTER_OP_CUDA_KERNEL(sampling_id, - paddle::operators::SamplingIdKernel, - paddle::operators::SamplingIdKernel); +PD_REGISTER_STRUCT_KERNEL(sampling_id, + GPU, + ALL_LAYOUT, + paddle::operators::SamplingIdKernel, + float, + double) {} diff --git a/paddle/fluid/operators/sampling_id_op.h b/paddle/fluid/operators/sampling_id_op.h index 6894746ee46..38c0ea3834a 100644 --- a/paddle/fluid/operators/sampling_id_op.h +++ b/paddle/fluid/operators/sampling_id_op.h @@ -27,7 +27,7 @@ namespace paddle { namespace operators { -template +template class SamplingIdKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& context) const override { diff --git a/paddle/fluid/operators/sampling_id_op_xpu.cc b/paddle/fluid/operators/sampling_id_op_xpu.cc index 027db5508de..0b720c21381 100644 --- a/paddle/fluid/operators/sampling_id_op_xpu.cc +++ b/paddle/fluid/operators/sampling_id_op_xpu.cc @@ -13,8 +13,11 @@ limitations under the License. */ #include "paddle/fluid/operators/sampling_id_op.h" +#include "paddle/fluid/platform/device_context.h" namespace ops = paddle::operators; +using XPUCtx = paddle::platform::XPUDeviceContext; + REGISTER_OP_XPU_KERNEL(sampling_id, - paddle::operators::SamplingIdKernel, - paddle::operators::SamplingIdKernel); + paddle::operators::SamplingIdKernel, + paddle::operators::SamplingIdKernel); diff --git a/paddle/fluid/operators/seed_op.cc b/paddle/fluid/operators/seed_op.cc index f6d19749689..08b400af416 100644 --- a/paddle/fluid/operators/seed_op.cc +++ b/paddle/fluid/operators/seed_op.cc @@ -53,7 +53,7 @@ REGISTER_OPERATOR( ops::SeedOpMaker, paddle::framework::EmptyGradOpMaker, paddle::framework::EmptyGradOpMaker); -REGISTER_OP_CPU_KERNEL(seed, ops::CPUSeedKernel); +PD_REGISTER_STRUCT_KERNEL(seed, CPU, ALL_LAYOUT, ops::CPUSeedKernel, int) {} /* ========================== register checkpoint ===========================*/ REGISTER_OP_VERSION(seed).AddCheckpoint( diff --git a/paddle/fluid/operators/seed_op.cu b/paddle/fluid/operators/seed_op.cu index 87ba439d792..042d289673c 100644 --- a/paddle/fluid/operators/seed_op.cu +++ b/paddle/fluid/operators/seed_op.cu @@ -18,7 +18,7 @@ namespace paddle { namespace operators { -template +template class GPUSeedKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext &context) const override { @@ -53,5 +53,5 @@ class GPUSeedKernel : public framework::OpKernel { } // namespace operators } // namespace paddle -REGISTER_OP_CUDA_KERNEL(seed, - paddle::operators::GPUSeedKernel); +PD_REGISTER_STRUCT_KERNEL( + seed, GPU, ALL_LAYOUT, paddle::operators::GPUSeedKernel, int) {} diff --git a/paddle/fluid/operators/seed_op.h b/paddle/fluid/operators/seed_op.h index 0d8f06c01d2..b9cbb81dd2d 100644 --- a/paddle/fluid/operators/seed_op.h +++ b/paddle/fluid/operators/seed_op.h @@ -44,7 +44,7 @@ static int get_seed(const framework::ExecutionContext& context) { return seed; } -template +template class CPUSeedKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& context) const override { diff --git a/paddle/fluid/operators/sequence_ops/sequence_concat_op.cc b/paddle/fluid/operators/sequence_ops/sequence_concat_op.cc index 762ca5e42d7..3ef695b1119 100644 --- a/paddle/fluid/operators/sequence_ops/sequence_concat_op.cc +++ b/paddle/fluid/operators/sequence_ops/sequence_concat_op.cc @@ -146,17 +146,23 @@ REGISTER_OPERATOR(sequence_concat, op::SeqConcatOpMaker, op::SeqConcatGradOpMaker, op::SeqConcatGradOpMaker); -REGISTER_OP_CPU_KERNEL(sequence_concat, - op::SeqConcatKernel, - op::SeqConcatKernel, - op::SeqConcatKernel, - op::SeqConcatKernel); +PD_REGISTER_STRUCT_KERNEL(sequence_concat, + CPU, + ALL_LAYOUT, + op::SeqConcatKernel, + float, + double, + int, + int64_t) {} REGISTER_OPERATOR(sequence_concat_grad, op::SeqConcatGradOp, op::SeqConcatGradNoNeedBufferVarsInferer); -REGISTER_OP_CPU_KERNEL(sequence_concat_grad, - op::SeqConcatGradKernel, - op::SeqConcatGradKernel, - op::SeqConcatGradKernel, - op::SeqConcatGradKernel); +PD_REGISTER_STRUCT_KERNEL(sequence_concat_grad, + CPU, + ALL_LAYOUT, + op::SeqConcatGradKernel, + float, + double, + int, + int64_t) {} diff --git a/paddle/fluid/operators/sequence_ops/sequence_concat_op.cu.cc b/paddle/fluid/operators/sequence_ops/sequence_concat_op.cu.cc index 2374ec02e8f..b668a9d2558 100644 --- a/paddle/fluid/operators/sequence_ops/sequence_concat_op.cu.cc +++ b/paddle/fluid/operators/sequence_ops/sequence_concat_op.cu.cc @@ -16,15 +16,19 @@ #include "paddle/fluid/framework/op_registry.h" -REGISTER_OP_CUDA_KERNEL( - sequence_concat, - paddle::operators::SeqConcatKernel, - paddle::operators::SeqConcatKernel, - paddle::operators::SeqConcatKernel, - paddle::operators::SeqConcatKernel); -REGISTER_OP_CUDA_KERNEL( - sequence_concat_grad, - paddle::operators::SeqConcatGradKernel, - paddle::operators::SeqConcatGradKernel, - paddle::operators::SeqConcatGradKernel, - paddle::operators::SeqConcatGradKernel); +PD_REGISTER_STRUCT_KERNEL(sequence_concat, + GPU, + ALL_LAYOUT, + paddle::operators::SeqConcatKernel, + float, + double, + int, + int64_t) {} +PD_REGISTER_STRUCT_KERNEL(sequence_concat_grad, + GPU, + ALL_LAYOUT, + paddle::operators::SeqConcatGradKernel, + float, + double, + int, + int64_t) {} diff --git a/paddle/fluid/operators/sequence_ops/sequence_concat_op.h b/paddle/fluid/operators/sequence_ops/sequence_concat_op.h index 65d4d579b03..463cadc3ce7 100644 --- a/paddle/fluid/operators/sequence_ops/sequence_concat_op.h +++ b/paddle/fluid/operators/sequence_ops/sequence_concat_op.h @@ -62,7 +62,7 @@ inline std::vector> GetDataVectorSafely( } } // namespace detail -template +template class SeqConcatKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext &context) const override { @@ -107,7 +107,7 @@ class SeqConcatKernel : public framework::OpKernel { } }; -template +template class SeqConcatGradKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext &context) const override { diff --git a/paddle/fluid/operators/sequence_ops/unity_build_rule.cmake b/paddle/fluid/operators/sequence_ops/unity_build_rule.cmake index 9a87e27b241..b23035082e4 100644 --- a/paddle/fluid/operators/sequence_ops/unity_build_rule.cmake +++ b/paddle/fluid/operators/sequence_ops/unity_build_rule.cmake @@ -22,7 +22,6 @@ register_unity_group( sequence_softmax_op.cc sequence_topk_avg_pooling_op.cc sequence_unpad_op.cc - sequence_concat_op.cu.cc sequence_conv_op.cu.cc) register_unity_group( cu -- GitLab