未验证 提交 93d01787 编写于 作者: H huangjiyi 提交者: GitHub

register fluid kernels to phi [part 1] (#52014)

* update assign_pos

* update attention_lstm

* update barrier

* update batch_fc

* update beam_search

* update beam_search_decode

* update bilateral_slice

* fix bug

* Handle Structure kernel for InterpreterCore::RunOperator

* fix bug

* fix rocm compile

* fix rocm compile

* Revert "fix rocm compile"

* test

* revert test and update cmake

---------
Co-authored-by: chenruibiao <chenruibiao@baidu.com>
上级 70ebef81
......@@ -475,6 +475,9 @@ function(op_library TARGET)
foreach(hip_src ${hip_srcs})
set(op_name "")
find_register(${hip_src} "REGISTER_OP_CUDA_KERNEL" op_name)
find_phi_register(${hip_src} ${pybind_file} "PD_REGISTER_KERNEL")
find_phi_register(${hip_src} ${pybind_file} "PD_REGISTER_STRUCT_KERNEL")
find_phi_register(${hip_src} ${pybind_file} "PD_REGISTER_GENERAL_KERNEL")
if(NOT ${op_name} EQUAL "")
file(APPEND ${pybind_file} "USE_OP_DEVICE_KERNEL(${op_name}, CUDA);\n")
set(pybind_flag 1)
......
......@@ -947,23 +947,28 @@ void InterpreterCore::RunOperator(const Instruction& instr_node) {
platform::TracerEventType::OperatorInner,
1,
platform::EventRole::kInnerOp);
if (op_with_kernel == nullptr) {
if (op_with_kernel == nullptr) { // operator base
instr_node.OpBase()->Run(*local_scope, place_);
} else {
// fit for phi
if (instr_node.PhiKernel() && instr_node.PhiKernel()->IsValid()) {
VLOG(4) << "Run phi kernel: " << op->Type();
VLOG(4) << instr_node.InnerRuntimeContext().get() << " "
<< &instr_node.DeviceContext();
phi::KernelContext phi_kernel_context;
op_with_kernel->BuildPhiKernelContext(
*instr_node.InnerRuntimeContext().get(),
const_cast<platform::DeviceContext*>(&instr_node.DeviceContext()),
&phi_kernel_context);
(*instr_node.PhiKernel())(&phi_kernel_context);
} else {
phi::Kernel* kernel = instr_node.PhiKernel();
if (kernel && kernel->IsValid()) { // phi kernel
if (kernel->GetKernelRegisteredType() ==
phi::KernelRegisteredType::FUNCTION) {
VLOG(4) << "Run function kernel: " << op->Type();
VLOG(4) << instr_node.InnerRuntimeContext().get() << " "
<< &instr_node.DeviceContext();
phi::KernelContext phi_kernel_context;
op_with_kernel->BuildPhiKernelContext(
*instr_node.InnerRuntimeContext().get(),
const_cast<platform::DeviceContext*>(&instr_node.DeviceContext()),
&phi_kernel_context);
(*kernel)(&phi_kernel_context);
} else {
VLOG(4) << "Run structure kernel: " << op->Type();
(*kernel)(instr_node.InnerExecutionContext().get());
}
} else { // fluid kernel
instr_node.KernelFunc()(*instr_node.InnerExecutionContext().get());
}
}
......
......@@ -79,6 +79,5 @@ REGISTER_OP_WITHOUT_GRADIENT(assign_pos,
ops::AssignPosOp,
ops::AssignPosOpMaker);
REGISTER_OP_CPU_KERNEL(assign_pos,
ops::AssignPosOpCPUKernel<int>,
ops::AssignPosOpCPUKernel<int64_t>);
PD_REGISTER_STRUCT_KERNEL(
assign_pos, CPU, ALL_LAYOUT, ops::AssignPosOpCPUKernel, int, int64_t) {}
......@@ -53,7 +53,7 @@ __global__ void AssignPos(T* cum_count,
}
}
template <typename T>
template <typename T, typename DeviceContext>
class AssignPosCUDAKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
......@@ -102,4 +102,6 @@ class AssignPosCUDAKernel : public framework::OpKernel<T> {
namespace ops = paddle::operators;
namespace plat = paddle::platform;
REGISTER_OP_CUDA_KERNEL(assign_pos, ops::AssignPosCUDAKernel<int64_t>);
PD_REGISTER_STRUCT_KERNEL(
assign_pos, GPU, ALL_LAYOUT, ops::AssignPosCUDAKernel, int64_t) {}
......@@ -20,7 +20,7 @@ limitations under the License. */
namespace paddle {
namespace operators {
template <typename T>
template <typename T, typename DeviceContext>
class AssignPosOpCPUKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
......
......@@ -340,12 +340,10 @@ inline void vec_softmax(const int n, const T* x, T* y) {
phi::funcs::vec_scal<T>(n, static_cast<T>(1) / scalar, y); // scale
}
template <typename T>
template <typename T, typename DeviceContext>
class AttentionLSTMKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
using DeviceContext = phi::CPUContext;
auto* x = ctx.Input<phi::DenseTensor>("X");
auto* h0 = ctx.Input<phi::DenseTensor>("H0");
auto* c0 = ctx.Input<phi::DenseTensor>("C0");
......@@ -525,6 +523,5 @@ REGISTER_OPERATOR(attention_lstm,
ops::AttentionLSTMOp,
ops::AttentionLSTMOpMaker);
REGISTER_OP_CPU_KERNEL(attention_lstm,
ops::AttentionLSTMKernel<float>,
ops::AttentionLSTMKernel<double>);
PD_REGISTER_STRUCT_KERNEL(
attention_lstm, CPU, ALL_LAYOUT, ops::AttentionLSTMKernel, float, double) {}
......@@ -165,6 +165,5 @@ REGISTER_OPERATOR(batch_fc_grad,
ops::BatchFCGradOp,
ops::BatchFCGradOpNoNeedBufferVarsInferer);
REGISTER_OP_CPU_KERNEL(batch_fc,
ops::BatchFCKernel<phi::CPUContext, float>,
ops::BatchFCKernel<phi::CPUContext, double>);
PD_REGISTER_STRUCT_KERNEL(
batch_fc, CPU, ALL_LAYOUT, ops::BatchFCKernel, float, double) {}
......@@ -85,7 +85,7 @@ void add_bias_grad(gpuStream_t stream,
dout_data, slot_pairs_num, ins_num, out_dim, db_data);
}
template <typename DeviceContext, typename T>
template <typename T, typename DeviceContext>
class BatchFCCUDAKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
......@@ -149,7 +149,7 @@ class BatchFCCUDAKernel : public framework::OpKernel<T> {
}
};
template <typename DeviceContext, typename T>
template <typename T, typename DeviceContext>
class BatchFCGradOpCUDAKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
......@@ -238,10 +238,13 @@ class BatchFCGradOpCUDAKernel : public framework::OpKernel<T> {
namespace ops = paddle::operators;
using GPUCtx = phi::GPUContext;
REGISTER_OP_CUDA_KERNEL(batch_fc,
ops::BatchFCCUDAKernel<GPUCtx, float>,
ops::BatchFCCUDAKernel<GPUCtx, double>);
REGISTER_OP_CUDA_KERNEL(batch_fc_grad,
ops::BatchFCGradOpCUDAKernel<GPUCtx, float>,
ops::BatchFCGradOpCUDAKernel<GPUCtx, double>);
PD_REGISTER_STRUCT_KERNEL(
batch_fc, GPU, ALL_LAYOUT, ops::BatchFCCUDAKernel, float, double) {}
PD_REGISTER_STRUCT_KERNEL(batch_fc_grad,
GPU,
ALL_LAYOUT,
ops::BatchFCGradOpCUDAKernel,
float,
double) {}
......@@ -19,7 +19,7 @@ limitations under the License. */
namespace paddle {
namespace operators {
template <typename DeviceContext, typename T>
template <typename T, typename DeviceContext>
class BatchFCKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
......
......@@ -108,10 +108,12 @@ REGISTER_OPERATOR(beam_search_decode,
paddle::operators::BeamSearchDecodeOpProtoMaker,
paddle::operators::BeamSearchDecodeInferVarType);
REGISTER_OP_CPU_KERNEL(
beam_search_decode,
ops::BeamSearchDecodeOpKernel<phi::CPUContext, float>,
ops::BeamSearchDecodeOpKernel<phi::CPUContext, double>,
ops::BeamSearchDecodeOpKernel<phi::CPUContext, paddle::platform::float16>,
ops::BeamSearchDecodeOpKernel<phi::CPUContext, int>,
ops::BeamSearchDecodeOpKernel<phi::CPUContext, int64_t>);
PD_REGISTER_STRUCT_KERNEL(beam_search_decode,
CPU,
ALL_LAYOUT,
ops::BeamSearchDecodeOpKernel,
float,
double,
paddle::platform::float16,
int,
int64_t) {}
......@@ -16,10 +16,13 @@ limitations under the License. */
#include "paddle/fluid/framework/op_registry.h"
namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL(
beam_search_decode,
ops::BeamSearchDecodeOpKernel<phi::GPUContext, float>,
ops::BeamSearchDecodeOpKernel<phi::GPUContext, double>,
ops::BeamSearchDecodeOpKernel<phi::GPUContext, paddle::platform::float16>,
ops::BeamSearchDecodeOpKernel<phi::GPUContext, int>,
ops::BeamSearchDecodeOpKernel<phi::GPUContext, int64_t>);
PD_REGISTER_STRUCT_KERNEL(beam_search_decode,
GPU,
ALL_LAYOUT,
ops::BeamSearchDecodeOpKernel,
float,
double,
paddle::platform::float16,
int,
int64_t) {}
......@@ -123,7 +123,7 @@ struct BeamSearchDecodeFunctor {
phi::DenseTensor* score_tensor_;
};
template <typename DeviceContext, typename T>
template <typename T, typename DeviceContext>
class BeamSearchDecodeOpKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
......
......@@ -148,8 +148,12 @@ REGISTER_OPERATOR(beam_search,
ops::BeamSearchOp,
ops::BeamSearchOpMaker,
ops::BeamSearchInferVarType);
REGISTER_OP_CPU_KERNEL(beam_search,
ops::BeamSearchOpKernel<phi::CPUContext, float>,
ops::BeamSearchOpKernel<phi::CPUContext, double>,
ops::BeamSearchOpKernel<phi::CPUContext, int>,
ops::BeamSearchOpKernel<phi::CPUContext, int64_t>);
PD_REGISTER_STRUCT_KERNEL(beam_search,
CPU,
ALL_LAYOUT,
ops::BeamSearchOpKernel,
float,
double,
int,
int64_t) {}
......@@ -17,8 +17,12 @@ limitations under the License. */
#include "paddle/fluid/framework/op_registry.h"
namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL(beam_search,
ops::BeamSearchOpKernel<phi::GPUContext, float>,
ops::BeamSearchOpKernel<phi::GPUContext, double>,
ops::BeamSearchOpKernel<phi::GPUContext, int>,
ops::BeamSearchOpKernel<phi::GPUContext, int64_t>);
PD_REGISTER_STRUCT_KERNEL(beam_search,
GPU,
ALL_LAYOUT,
ops::BeamSearchOpKernel,
float,
double,
int,
int64_t) {}
......@@ -20,7 +20,7 @@ limitations under the License. */
namespace paddle {
namespace operators {
template <typename DeviceContext, typename T>
template <typename T, typename DeviceContext>
class BeamSearchOpKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
......
......@@ -16,9 +16,10 @@ limitations under the License. */
#include "paddle/fluid/operators/beam_search_op.h"
namespace ops = paddle::operators;
REGISTER_OP_NPU_KERNEL(
beam_search,
ops::BeamSearchOpKernel<paddle::platform::NPUDeviceContext, float>,
ops::BeamSearchOpKernel<paddle::platform::NPUDeviceContext, double>,
ops::BeamSearchOpKernel<paddle::platform::NPUDeviceContext, int>,
ops::BeamSearchOpKernel<paddle::platform::NPUDeviceContext, int64_t>);
using NPUCtx = paddle::platform::NPUDeviceContext;
REGISTER_OP_NPU_KERNEL(beam_search,
ops::BeamSearchOpKernel<float, NPUCtx>,
ops::BeamSearchOpKernel<double, NPUCtx>,
ops::BeamSearchOpKernel<int, NPUCtx>,
ops::BeamSearchOpKernel<int64_t, NPUCtx>);
......@@ -18,10 +18,11 @@ limitations under the License. */
#include "paddle/fluid/operators/beam_search_op.h"
namespace ops = paddle::operators;
REGISTER_OP_XPU_KERNEL(
beam_search,
ops::BeamSearchOpKernel<paddle::platform::XPUDeviceContext, float>,
ops::BeamSearchOpKernel<paddle::platform::XPUDeviceContext, double>,
ops::BeamSearchOpKernel<paddle::platform::XPUDeviceContext, int>,
ops::BeamSearchOpKernel<paddle::platform::XPUDeviceContext, int64_t>);
using XPUCtx = paddle::platform::XPUDeviceContext;
REGISTER_OP_XPU_KERNEL(beam_search,
ops::BeamSearchOpKernel<float, XPUCtx>,
ops::BeamSearchOpKernel<double, XPUCtx>,
ops::BeamSearchOpKernel<int, XPUCtx>,
ops::BeamSearchOpKernel<int64_t, XPUCtx>);
#endif
......@@ -175,7 +175,7 @@ class BilateralSliceGradMaker : public framework::SingleGradOpMaker<T> {
}
};
template <typename T>
template <typename T, typename DeviceContext>
class BilateralSliceKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
......@@ -196,6 +196,10 @@ REGISTER_OPERATOR(bilateral_slice,
ops::BilateralSliceGradMaker<paddle::framework::OpDesc>,
ops::BilateralSliceGradMaker<paddle::imperative::OpBase>);
REGISTER_OPERATOR(bilateral_slice_grad, ops::BilateralSliceOpGrad);
REGISTER_OP_CPU_KERNEL(bilateral_slice,
ops::BilateralSliceKernel<float>,
ops::BilateralSliceKernel<double>);
PD_REGISTER_STRUCT_KERNEL(bilateral_slice,
CPU,
ALL_LAYOUT,
ops::BilateralSliceKernel,
float,
double) {}
......@@ -126,7 +126,7 @@ __global__ void BilateralSliceCudaForwardKernel(T* output,
}
}
template <typename T>
template <typename T, typename DeviceContext>
class BilateralSliceOpCUDAKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
......@@ -442,7 +442,7 @@ __global__ void BilateralSliceCudaInputGradKernel(T* out_input_grad,
}
}
template <typename T>
template <typename T, typename DeviceContext>
class BilateralSliceGradOpCUDAKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
......@@ -557,9 +557,16 @@ class BilateralSliceGradOpCUDAKernel : public framework::OpKernel<T> {
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL(bilateral_slice,
ops::BilateralSliceOpCUDAKernel<float>,
ops::BilateralSliceOpCUDAKernel<double>);
REGISTER_OP_CUDA_KERNEL(bilateral_slice_grad,
ops::BilateralSliceGradOpCUDAKernel<float>,
ops::BilateralSliceGradOpCUDAKernel<double>);
PD_REGISTER_STRUCT_KERNEL(bilateral_slice,
GPU,
ALL_LAYOUT,
ops::BilateralSliceOpCUDAKernel,
float,
double) {}
PD_REGISTER_STRUCT_KERNEL(bilateral_slice_grad,
GPU,
ALL_LAYOUT,
ops::BilateralSliceGradOpCUDAKernel,
float,
double) {}
......@@ -44,4 +44,6 @@ namespace ops = paddle::operators;
namespace plat = paddle::platform;
REGISTER_OP_WITHOUT_GRADIENT(barrier, ops::BarrierOp, ops::BarrierOpMaker);
REGISTER_OP_CPU_KERNEL(barrier, ops::BarrierOpCPUKernel<int>);
PD_REGISTER_STRUCT_KERNEL(
barrier, CPU, ALL_LAYOUT, ops::BarrierOpCPUKernel, int) {}
......@@ -22,7 +22,7 @@ limitations under the License. */
namespace paddle {
namespace operators {
template <typename T>
template <typename T, typename DeviceContext>
class BarrierOpCUDAKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
......@@ -58,4 +58,5 @@ class BarrierOpCUDAKernel : public framework::OpKernel<T> {
namespace ops = paddle::operators;
namespace plat = paddle::platform;
REGISTER_OP_CUDA_KERNEL(barrier, ops::BarrierOpCUDAKernel<int>);
PD_REGISTER_STRUCT_KERNEL(
barrier, GPU, ALL_LAYOUT, ops::BarrierOpCUDAKernel, int) {}
......@@ -32,7 +32,7 @@ limitations under the License. */
namespace paddle {
namespace operators {
template <typename T>
template <typename T, typename DeviceContext>
class BarrierOpCPUKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册