diff --git a/cmake/operators.cmake b/cmake/operators.cmake index 1897ded550d00118517e504c5490d0da932e54fa..42c7cc5862a9f2c5a2a37feb1547234073f0b155 100644 --- a/cmake/operators.cmake +++ b/cmake/operators.cmake @@ -475,6 +475,9 @@ function(op_library TARGET) foreach(hip_src ${hip_srcs}) set(op_name "") find_register(${hip_src} "REGISTER_OP_CUDA_KERNEL" op_name) + find_phi_register(${hip_src} ${pybind_file} "PD_REGISTER_KERNEL") + find_phi_register(${hip_src} ${pybind_file} "PD_REGISTER_STRUCT_KERNEL") + find_phi_register(${hip_src} ${pybind_file} "PD_REGISTER_GENERAL_KERNEL") if(NOT ${op_name} EQUAL "") file(APPEND ${pybind_file} "USE_OP_DEVICE_KERNEL(${op_name}, CUDA);\n") set(pybind_flag 1) diff --git a/paddle/fluid/framework/new_executor/interpretercore.cc b/paddle/fluid/framework/new_executor/interpretercore.cc index 11e67dcc9399544ff03ffdcb2140f450c6aec1d2..8d38da543ad03ce3fd305ea78902327fdb1c95a0 100644 --- a/paddle/fluid/framework/new_executor/interpretercore.cc +++ b/paddle/fluid/framework/new_executor/interpretercore.cc @@ -947,23 +947,28 @@ void InterpreterCore::RunOperator(const Instruction& instr_node) { platform::TracerEventType::OperatorInner, 1, platform::EventRole::kInnerOp); - if (op_with_kernel == nullptr) { + if (op_with_kernel == nullptr) { // operator base instr_node.OpBase()->Run(*local_scope, place_); } else { - // fit for phi - if (instr_node.PhiKernel() && instr_node.PhiKernel()->IsValid()) { - VLOG(4) << "Run phi kernel: " << op->Type(); - VLOG(4) << instr_node.InnerRuntimeContext().get() << " " - << &instr_node.DeviceContext(); - phi::KernelContext phi_kernel_context; - op_with_kernel->BuildPhiKernelContext( - *instr_node.InnerRuntimeContext().get(), - const_cast(&instr_node.DeviceContext()), - &phi_kernel_context); - - (*instr_node.PhiKernel())(&phi_kernel_context); - - } else { + phi::Kernel* kernel = instr_node.PhiKernel(); + if (kernel && kernel->IsValid()) { // phi kernel + if (kernel->GetKernelRegisteredType() == + phi::KernelRegisteredType::FUNCTION) { + VLOG(4) << "Run function kernel: " << op->Type(); + VLOG(4) << instr_node.InnerRuntimeContext().get() << " " + << &instr_node.DeviceContext(); + phi::KernelContext phi_kernel_context; + op_with_kernel->BuildPhiKernelContext( + *instr_node.InnerRuntimeContext().get(), + const_cast(&instr_node.DeviceContext()), + &phi_kernel_context); + + (*kernel)(&phi_kernel_context); + } else { + VLOG(4) << "Run structure kernel: " << op->Type(); + (*kernel)(instr_node.InnerExecutionContext().get()); + } + } else { // fluid kernel instr_node.KernelFunc()(*instr_node.InnerExecutionContext().get()); } } diff --git a/paddle/fluid/operators/assign_pos_op.cc b/paddle/fluid/operators/assign_pos_op.cc index 24fc4adc60d94f6c0ce11e080dfb2501c5e4daaf..c14a87a27a2ce70eb5cd7cce97dce7b92f836796 100644 --- a/paddle/fluid/operators/assign_pos_op.cc +++ b/paddle/fluid/operators/assign_pos_op.cc @@ -79,6 +79,5 @@ REGISTER_OP_WITHOUT_GRADIENT(assign_pos, ops::AssignPosOp, ops::AssignPosOpMaker); -REGISTER_OP_CPU_KERNEL(assign_pos, - ops::AssignPosOpCPUKernel, - ops::AssignPosOpCPUKernel); +PD_REGISTER_STRUCT_KERNEL( + assign_pos, CPU, ALL_LAYOUT, ops::AssignPosOpCPUKernel, int, int64_t) {} diff --git a/paddle/fluid/operators/assign_pos_op.cu b/paddle/fluid/operators/assign_pos_op.cu index e5f783ec2d6ac9e3211cd69966fb47695581fcce..3b0492406d82872852943159dad22f237bf7eb55 100644 --- a/paddle/fluid/operators/assign_pos_op.cu +++ b/paddle/fluid/operators/assign_pos_op.cu @@ -53,7 +53,7 @@ __global__ void AssignPos(T* cum_count, } } -template +template class AssignPosCUDAKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& context) const override { @@ -102,4 +102,6 @@ class AssignPosCUDAKernel : public framework::OpKernel { namespace ops = paddle::operators; namespace plat = paddle::platform; -REGISTER_OP_CUDA_KERNEL(assign_pos, ops::AssignPosCUDAKernel); + +PD_REGISTER_STRUCT_KERNEL( + assign_pos, GPU, ALL_LAYOUT, ops::AssignPosCUDAKernel, int64_t) {} diff --git a/paddle/fluid/operators/assign_pos_op.h b/paddle/fluid/operators/assign_pos_op.h index 6c75fb55f58468f0bf9f9ddbadbab45ee7cd058f..038d40896a07a4d90f89d34aff8d0ba02fdc78fa 100644 --- a/paddle/fluid/operators/assign_pos_op.h +++ b/paddle/fluid/operators/assign_pos_op.h @@ -20,7 +20,7 @@ limitations under the License. */ namespace paddle { namespace operators { -template +template class AssignPosOpCPUKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& ctx) const override { diff --git a/paddle/fluid/operators/attention_lstm_op.cc b/paddle/fluid/operators/attention_lstm_op.cc index 5ee8b2c7efbb26db49a24c1d150af8ec42a78a7a..6b9e3b7c296f0f5730986a4b39e7c22fab046057 100644 --- a/paddle/fluid/operators/attention_lstm_op.cc +++ b/paddle/fluid/operators/attention_lstm_op.cc @@ -340,12 +340,10 @@ inline void vec_softmax(const int n, const T* x, T* y) { phi::funcs::vec_scal(n, static_cast(1) / scalar, y); // scale } -template +template class AttentionLSTMKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& ctx) const override { - using DeviceContext = phi::CPUContext; - auto* x = ctx.Input("X"); auto* h0 = ctx.Input("H0"); auto* c0 = ctx.Input("C0"); @@ -525,6 +523,5 @@ REGISTER_OPERATOR(attention_lstm, ops::AttentionLSTMOp, ops::AttentionLSTMOpMaker); -REGISTER_OP_CPU_KERNEL(attention_lstm, - ops::AttentionLSTMKernel, - ops::AttentionLSTMKernel); +PD_REGISTER_STRUCT_KERNEL( + attention_lstm, CPU, ALL_LAYOUT, ops::AttentionLSTMKernel, float, double) {} diff --git a/paddle/fluid/operators/batch_fc_op.cc b/paddle/fluid/operators/batch_fc_op.cc index 9010cadd1533238d0e401be425330c3170595a4e..706cb17e40f3411b4acd55b7bc5b2586d936d5c9 100644 --- a/paddle/fluid/operators/batch_fc_op.cc +++ b/paddle/fluid/operators/batch_fc_op.cc @@ -165,6 +165,5 @@ REGISTER_OPERATOR(batch_fc_grad, ops::BatchFCGradOp, ops::BatchFCGradOpNoNeedBufferVarsInferer); -REGISTER_OP_CPU_KERNEL(batch_fc, - ops::BatchFCKernel, - ops::BatchFCKernel); +PD_REGISTER_STRUCT_KERNEL( + batch_fc, CPU, ALL_LAYOUT, ops::BatchFCKernel, float, double) {} diff --git a/paddle/fluid/operators/batch_fc_op.cu b/paddle/fluid/operators/batch_fc_op.cu index 178e57d7a261a166c70495a717ca602474083f70..00a09563c00ad4d03203348fcdc2d2b6b73d4d17 100644 --- a/paddle/fluid/operators/batch_fc_op.cu +++ b/paddle/fluid/operators/batch_fc_op.cu @@ -85,7 +85,7 @@ void add_bias_grad(gpuStream_t stream, dout_data, slot_pairs_num, ins_num, out_dim, db_data); } -template +template class BatchFCCUDAKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& ctx) const override { @@ -149,7 +149,7 @@ class BatchFCCUDAKernel : public framework::OpKernel { } }; -template +template class BatchFCGradOpCUDAKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& ctx) const override { @@ -238,10 +238,13 @@ class BatchFCGradOpCUDAKernel : public framework::OpKernel { namespace ops = paddle::operators; using GPUCtx = phi::GPUContext; -REGISTER_OP_CUDA_KERNEL(batch_fc, - ops::BatchFCCUDAKernel, - ops::BatchFCCUDAKernel); -REGISTER_OP_CUDA_KERNEL(batch_fc_grad, - ops::BatchFCGradOpCUDAKernel, - ops::BatchFCGradOpCUDAKernel); +PD_REGISTER_STRUCT_KERNEL( + batch_fc, GPU, ALL_LAYOUT, ops::BatchFCCUDAKernel, float, double) {} + +PD_REGISTER_STRUCT_KERNEL(batch_fc_grad, + GPU, + ALL_LAYOUT, + ops::BatchFCGradOpCUDAKernel, + float, + double) {} diff --git a/paddle/fluid/operators/batch_fc_op.h b/paddle/fluid/operators/batch_fc_op.h index d2a07e1015a7e2be8d5b29d2dc17da1f6c6953c9..ca8c22243dbe4ca308695b9988e2a1635993dea5 100644 --- a/paddle/fluid/operators/batch_fc_op.h +++ b/paddle/fluid/operators/batch_fc_op.h @@ -19,7 +19,7 @@ limitations under the License. */ namespace paddle { namespace operators { -template +template class BatchFCKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& ctx) const override { diff --git a/paddle/fluid/operators/beam_search_decode_op.cc b/paddle/fluid/operators/beam_search_decode_op.cc index 0add535509026dba4ca52684d3e63b0229d2eb29..ec3ced614bd92abd7324b3c3812b8583679a1673 100644 --- a/paddle/fluid/operators/beam_search_decode_op.cc +++ b/paddle/fluid/operators/beam_search_decode_op.cc @@ -108,10 +108,12 @@ REGISTER_OPERATOR(beam_search_decode, paddle::operators::BeamSearchDecodeOpProtoMaker, paddle::operators::BeamSearchDecodeInferVarType); -REGISTER_OP_CPU_KERNEL( - beam_search_decode, - ops::BeamSearchDecodeOpKernel, - ops::BeamSearchDecodeOpKernel, - ops::BeamSearchDecodeOpKernel, - ops::BeamSearchDecodeOpKernel, - ops::BeamSearchDecodeOpKernel); +PD_REGISTER_STRUCT_KERNEL(beam_search_decode, + CPU, + ALL_LAYOUT, + ops::BeamSearchDecodeOpKernel, + float, + double, + paddle::platform::float16, + int, + int64_t) {} diff --git a/paddle/fluid/operators/beam_search_decode_op.cu.cc b/paddle/fluid/operators/beam_search_decode_op.cu.cc index fef36ea6d9a36287febdf5bd0477cb0b725b620a..bab5423c99b05890ac77726b229013f9a1378f44 100644 --- a/paddle/fluid/operators/beam_search_decode_op.cu.cc +++ b/paddle/fluid/operators/beam_search_decode_op.cu.cc @@ -16,10 +16,13 @@ limitations under the License. */ #include "paddle/fluid/framework/op_registry.h" namespace ops = paddle::operators; -REGISTER_OP_CUDA_KERNEL( - beam_search_decode, - ops::BeamSearchDecodeOpKernel, - ops::BeamSearchDecodeOpKernel, - ops::BeamSearchDecodeOpKernel, - ops::BeamSearchDecodeOpKernel, - ops::BeamSearchDecodeOpKernel); + +PD_REGISTER_STRUCT_KERNEL(beam_search_decode, + GPU, + ALL_LAYOUT, + ops::BeamSearchDecodeOpKernel, + float, + double, + paddle::platform::float16, + int, + int64_t) {} diff --git a/paddle/fluid/operators/beam_search_decode_op.h b/paddle/fluid/operators/beam_search_decode_op.h index e635405f3884eaea398b1621c6f284e86a5cc4c6..c4f7b3b5785f48e31967fa50503e49fd36f857dc 100644 --- a/paddle/fluid/operators/beam_search_decode_op.h +++ b/paddle/fluid/operators/beam_search_decode_op.h @@ -123,7 +123,7 @@ struct BeamSearchDecodeFunctor { phi::DenseTensor* score_tensor_; }; -template +template class BeamSearchDecodeOpKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& context) const override { diff --git a/paddle/fluid/operators/beam_search_op.cc b/paddle/fluid/operators/beam_search_op.cc index 1e569c4bb27324a91bdcebdfca3f5af55a284802..1edac2ebf810ad603ce3079c7109470a9eea8d25 100644 --- a/paddle/fluid/operators/beam_search_op.cc +++ b/paddle/fluid/operators/beam_search_op.cc @@ -148,8 +148,12 @@ REGISTER_OPERATOR(beam_search, ops::BeamSearchOp, ops::BeamSearchOpMaker, ops::BeamSearchInferVarType); -REGISTER_OP_CPU_KERNEL(beam_search, - ops::BeamSearchOpKernel, - ops::BeamSearchOpKernel, - ops::BeamSearchOpKernel, - ops::BeamSearchOpKernel); + +PD_REGISTER_STRUCT_KERNEL(beam_search, + CPU, + ALL_LAYOUT, + ops::BeamSearchOpKernel, + float, + double, + int, + int64_t) {} diff --git a/paddle/fluid/operators/beam_search_op.cu.cc b/paddle/fluid/operators/beam_search_op.cu.cc index 93f538e67890674c6bb502ab3c3aa4d8b4fa9555..53d3743d0bcfb934f5667e5254dda577ad216a95 100644 --- a/paddle/fluid/operators/beam_search_op.cu.cc +++ b/paddle/fluid/operators/beam_search_op.cu.cc @@ -17,8 +17,12 @@ limitations under the License. */ #include "paddle/fluid/framework/op_registry.h" namespace ops = paddle::operators; -REGISTER_OP_CUDA_KERNEL(beam_search, - ops::BeamSearchOpKernel, - ops::BeamSearchOpKernel, - ops::BeamSearchOpKernel, - ops::BeamSearchOpKernel); + +PD_REGISTER_STRUCT_KERNEL(beam_search, + GPU, + ALL_LAYOUT, + ops::BeamSearchOpKernel, + float, + double, + int, + int64_t) {} diff --git a/paddle/fluid/operators/beam_search_op.h b/paddle/fluid/operators/beam_search_op.h index 1f72452a13b6ad7c32b0575275bf048b0ffab78a..fea706bb54a9351762636c9c32ee5a672a71d361 100644 --- a/paddle/fluid/operators/beam_search_op.h +++ b/paddle/fluid/operators/beam_search_op.h @@ -20,7 +20,7 @@ limitations under the License. */ namespace paddle { namespace operators { -template +template class BeamSearchOpKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& context) const override { diff --git a/paddle/fluid/operators/beam_search_op_npu.cc b/paddle/fluid/operators/beam_search_op_npu.cc index f5fa0ac026d57645e444e0e40b08e91313db72f9..147d1be2262556359d8d3e3581bd1bbabb1c156a 100644 --- a/paddle/fluid/operators/beam_search_op_npu.cc +++ b/paddle/fluid/operators/beam_search_op_npu.cc @@ -16,9 +16,10 @@ limitations under the License. */ #include "paddle/fluid/operators/beam_search_op.h" namespace ops = paddle::operators; -REGISTER_OP_NPU_KERNEL( - beam_search, - ops::BeamSearchOpKernel, - ops::BeamSearchOpKernel, - ops::BeamSearchOpKernel, - ops::BeamSearchOpKernel); +using NPUCtx = paddle::platform::NPUDeviceContext; + +REGISTER_OP_NPU_KERNEL(beam_search, + ops::BeamSearchOpKernel, + ops::BeamSearchOpKernel, + ops::BeamSearchOpKernel, + ops::BeamSearchOpKernel); diff --git a/paddle/fluid/operators/beam_search_op_xpu.cc b/paddle/fluid/operators/beam_search_op_xpu.cc index ab52f09c2b668969d59b24b11d65c6cc22b78e69..9f1d1488d9a64946c8af2e47f731a67193286c81 100644 --- a/paddle/fluid/operators/beam_search_op_xpu.cc +++ b/paddle/fluid/operators/beam_search_op_xpu.cc @@ -18,10 +18,11 @@ limitations under the License. */ #include "paddle/fluid/operators/beam_search_op.h" namespace ops = paddle::operators; -REGISTER_OP_XPU_KERNEL( - beam_search, - ops::BeamSearchOpKernel, - ops::BeamSearchOpKernel, - ops::BeamSearchOpKernel, - ops::BeamSearchOpKernel); +using XPUCtx = paddle::platform::XPUDeviceContext; + +REGISTER_OP_XPU_KERNEL(beam_search, + ops::BeamSearchOpKernel, + ops::BeamSearchOpKernel, + ops::BeamSearchOpKernel, + ops::BeamSearchOpKernel); #endif diff --git a/paddle/fluid/operators/bilateral_slice_op.cc b/paddle/fluid/operators/bilateral_slice_op.cc index c824fd9e6316046fd4b0009fb6a27087c930a44c..53386c1551d0f595fd847bd7eec0cbb9583a3031 100644 --- a/paddle/fluid/operators/bilateral_slice_op.cc +++ b/paddle/fluid/operators/bilateral_slice_op.cc @@ -175,7 +175,7 @@ class BilateralSliceGradMaker : public framework::SingleGradOpMaker { } }; -template +template class BilateralSliceKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& ctx) const override { @@ -196,6 +196,10 @@ REGISTER_OPERATOR(bilateral_slice, ops::BilateralSliceGradMaker, ops::BilateralSliceGradMaker); REGISTER_OPERATOR(bilateral_slice_grad, ops::BilateralSliceOpGrad); -REGISTER_OP_CPU_KERNEL(bilateral_slice, - ops::BilateralSliceKernel, - ops::BilateralSliceKernel); + +PD_REGISTER_STRUCT_KERNEL(bilateral_slice, + CPU, + ALL_LAYOUT, + ops::BilateralSliceKernel, + float, + double) {} diff --git a/paddle/fluid/operators/bilateral_slice_op.cu b/paddle/fluid/operators/bilateral_slice_op.cu index c995c3ed091dd31862e05ebd8a4f3d7b66c67dc4..08f7e454bd47ba95cd7e7598805cd365ccb6dce9 100644 --- a/paddle/fluid/operators/bilateral_slice_op.cu +++ b/paddle/fluid/operators/bilateral_slice_op.cu @@ -126,7 +126,7 @@ __global__ void BilateralSliceCudaForwardKernel(T* output, } } -template +template class BilateralSliceOpCUDAKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& ctx) const override { @@ -442,7 +442,7 @@ __global__ void BilateralSliceCudaInputGradKernel(T* out_input_grad, } } -template +template class BilateralSliceGradOpCUDAKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& ctx) const override { @@ -557,9 +557,16 @@ class BilateralSliceGradOpCUDAKernel : public framework::OpKernel { } // namespace paddle namespace ops = paddle::operators; -REGISTER_OP_CUDA_KERNEL(bilateral_slice, - ops::BilateralSliceOpCUDAKernel, - ops::BilateralSliceOpCUDAKernel); -REGISTER_OP_CUDA_KERNEL(bilateral_slice_grad, - ops::BilateralSliceGradOpCUDAKernel, - ops::BilateralSliceGradOpCUDAKernel); + +PD_REGISTER_STRUCT_KERNEL(bilateral_slice, + GPU, + ALL_LAYOUT, + ops::BilateralSliceOpCUDAKernel, + float, + double) {} +PD_REGISTER_STRUCT_KERNEL(bilateral_slice_grad, + GPU, + ALL_LAYOUT, + ops::BilateralSliceGradOpCUDAKernel, + float, + double) {} diff --git a/paddle/fluid/operators/collective/barrier_op.cc b/paddle/fluid/operators/collective/barrier_op.cc index 3f154a42e2be8f825e6b4a386c0c262f31b0edda..c90669804732a66b47444b7c6e377d311c837e92 100644 --- a/paddle/fluid/operators/collective/barrier_op.cc +++ b/paddle/fluid/operators/collective/barrier_op.cc @@ -44,4 +44,6 @@ namespace ops = paddle::operators; namespace plat = paddle::platform; REGISTER_OP_WITHOUT_GRADIENT(barrier, ops::BarrierOp, ops::BarrierOpMaker); -REGISTER_OP_CPU_KERNEL(barrier, ops::BarrierOpCPUKernel); + +PD_REGISTER_STRUCT_KERNEL( + barrier, CPU, ALL_LAYOUT, ops::BarrierOpCPUKernel, int) {} diff --git a/paddle/fluid/operators/collective/barrier_op.cu.cc b/paddle/fluid/operators/collective/barrier_op.cu.cc index 648b8fdc83b878be13a2b0b885e721bafe5ea2f6..79d036cdae89b72029e1c9d30184b18a35009589 100644 --- a/paddle/fluid/operators/collective/barrier_op.cu.cc +++ b/paddle/fluid/operators/collective/barrier_op.cu.cc @@ -22,7 +22,7 @@ limitations under the License. */ namespace paddle { namespace operators { -template +template class BarrierOpCUDAKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& ctx) const override { @@ -58,4 +58,5 @@ class BarrierOpCUDAKernel : public framework::OpKernel { namespace ops = paddle::operators; namespace plat = paddle::platform; -REGISTER_OP_CUDA_KERNEL(barrier, ops::BarrierOpCUDAKernel); +PD_REGISTER_STRUCT_KERNEL( + barrier, GPU, ALL_LAYOUT, ops::BarrierOpCUDAKernel, int) {} diff --git a/paddle/fluid/operators/collective/barrier_op.h b/paddle/fluid/operators/collective/barrier_op.h index 36b7973e2c12818ff0a23bf55c094981efcd228b..099d6cccb9a0393c47405984d38269a29336c872 100644 --- a/paddle/fluid/operators/collective/barrier_op.h +++ b/paddle/fluid/operators/collective/barrier_op.h @@ -32,7 +32,7 @@ limitations under the License. */ namespace paddle { namespace operators { -template +template class BarrierOpCPUKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& ctx) const override {