未验证 提交 28492bf6 编写于 作者: Y YuanRisheng 提交者: GitHub

[PHI]Unify fluid kernel (Part5) (#53003)

* unify_kernel

* fix compile bugs

* fix py3 bugs

* fix xpu bugs
上级 4812d8e4
...@@ -86,9 +86,12 @@ namespace plat = paddle::platform; ...@@ -86,9 +86,12 @@ namespace plat = paddle::platform;
REGISTER_OP_WITHOUT_GRADIENT(send_v2, ops::SendOpV2, ops::SendOpV2Maker); REGISTER_OP_WITHOUT_GRADIENT(send_v2, ops::SendOpV2, ops::SendOpV2Maker);
REGISTER_OP_CPU_KERNEL(send_v2, PD_REGISTER_STRUCT_KERNEL(send_v2,
ops::SendOpV2CPUKernel<float>, CPU,
ops::SendOpV2CPUKernel<double>, ALL_LAYOUT,
ops::SendOpV2CPUKernel<int>, ops::SendOpV2CPUKernel,
ops::SendOpV2CPUKernel<int64_t>, float,
ops::SendOpV2CPUKernel<plat::float16>); double,
int,
int64_t,
plat::float16) {}
...@@ -104,7 +104,7 @@ void send_shape_info(const phi::DenseTensor& x, ...@@ -104,7 +104,7 @@ void send_shape_info(const phi::DenseTensor& x,
} }
#endif #endif
template <typename T> template <typename T, typename DeviceContext>
class SendOpV2CUDAKernel : public framework::OpKernel<T> { class SendOpV2CUDAKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& ctx) const override { void Compute(const framework::ExecutionContext& ctx) const override {
...@@ -217,13 +217,17 @@ class SendOpV2CUDAKernel : public framework::OpKernel<T> { ...@@ -217,13 +217,17 @@ class SendOpV2CUDAKernel : public framework::OpKernel<T> {
namespace ops = paddle::operators; namespace ops = paddle::operators;
namespace plat = paddle::platform; namespace plat = paddle::platform;
REGISTER_OP_CUDA_KERNEL(send_v2, PD_REGISTER_STRUCT_KERNEL(send_v2,
ops::SendOpV2CUDAKernel<float>, GPU,
ops::SendOpV2CUDAKernel<double>, ALL_LAYOUT,
ops::SendOpV2CUDAKernel,
float,
double,
#if NCCL_VERSION_CODE >= 21000 #if NCCL_VERSION_CODE >= 21000
ops::SendOpV2CUDAKernel<plat::bfloat16>, plat::bfloat16,
#endif #endif
ops::SendOpV2CUDAKernel<int>, int,
ops::SendOpV2CUDAKernel<int64_t>, int64_t,
ops::SendOpV2CUDAKernel<int8_t>, int8_t,
ops::SendOpV2CUDAKernel<plat::float16>); plat::float16) {
}
...@@ -25,7 +25,7 @@ limitations under the License. */ ...@@ -25,7 +25,7 @@ limitations under the License. */
namespace paddle { namespace paddle {
namespace operators { namespace operators {
template <typename T> template <typename T, typename DeviceContext>
class SendOpV2CPUKernel : public framework::OpKernel<T> { class SendOpV2CPUKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& ctx) const override { void Compute(const framework::ExecutionContext& ctx) const override {
......
...@@ -26,7 +26,7 @@ limitations under the License. */ ...@@ -26,7 +26,7 @@ limitations under the License. */
namespace paddle { namespace paddle {
namespace operators { namespace operators {
template <typename DeviceContext, typename T> template <typename T, typename DeviceContext>
class SendAndRecvKernel : public framework::OpKernel<T> { class SendAndRecvKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& ctx) const override { void Compute(const framework::ExecutionContext& ctx) const override {
...@@ -98,17 +98,25 @@ class SendAndRecvOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -98,17 +98,25 @@ class SendAndRecvOpMaker : public framework::OpProtoAndCheckerMaker {
namespace ops = paddle::operators; namespace ops = paddle::operators;
REGISTER_OPERATOR(send_and_recv, ops::SendAndRecvOp, ops::SendAndRecvOpMaker); REGISTER_OPERATOR(send_and_recv, ops::SendAndRecvOp, ops::SendAndRecvOpMaker);
REGISTER_OP_CUDA_KERNEL(send_and_recv,
ops::SendAndRecvKernel<phi::GPUContext, float>,
ops::SendAndRecvKernel<phi::GPUContext, double>,
ops::SendAndRecvKernel<phi::GPUContext, int>,
ops::SendAndRecvKernel<phi::GPUContext, int64_t>);
REGISTER_OP_CPU_KERNEL(send_and_recv,
ops::SendAndRecvKernel<phi::CPUContext, float>,
ops::SendAndRecvKernel<phi::CPUContext, double>,
ops::SendAndRecvKernel<phi::CPUContext, int>,
ops::SendAndRecvKernel<phi::CPUContext, int64_t>);
PD_REGISTER_STRUCT_KERNEL(send_and_recv,
CPU,
ALL_LAYOUT,
ops::SendAndRecvKernel,
float,
double,
int,
int64_t) {}
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
PD_REGISTER_STRUCT_KERNEL(send_and_recv,
GPU,
ALL_LAYOUT,
ops::SendAndRecvKernel,
float,
double,
int,
int64_t) {}
#endif
REGISTER_OP_VERSION(send_and_recv) REGISTER_OP_VERSION(send_and_recv)
.AddCheckpoint( .AddCheckpoint(
R"ROC(add new attributes [next_endpoints] [previous_endpoints] and [mode])ROC", R"ROC(add new attributes [next_endpoints] [previous_endpoints] and [mode])ROC",
......
...@@ -35,7 +35,7 @@ using MultiVarMsg = ::paddle::distributed::MultiVariableMessage; ...@@ -35,7 +35,7 @@ using MultiVarMsg = ::paddle::distributed::MultiVariableMessage;
using VarMsg = ::paddle::distributed::VariableMessage; using VarMsg = ::paddle::distributed::VariableMessage;
USE_OP_ITSELF(scale); USE_OP_ITSELF(scale);
USE_OP(send_and_recv); PD_DECLARE_KERNEL(send_and_recv, CPU, ALL_LAYOUT);
std::string get_ip_port() { std::string get_ip_port() {
std::mt19937 rng; std::mt19937 rng;
......
...@@ -39,7 +39,7 @@ using MultiVarMsg = ::paddle::distributed::MultiVariableMessage; ...@@ -39,7 +39,7 @@ using MultiVarMsg = ::paddle::distributed::MultiVariableMessage;
using VarMsg = ::paddle::distributed::VariableMessage; using VarMsg = ::paddle::distributed::VariableMessage;
USE_OP_ITSELF(scale); USE_OP_ITSELF(scale);
USE_OP(send_and_recv); PD_DECLARE_KERNEL(send_and_recv, GPU, ALL_LAYOUT);
std::shared_ptr<distributed::HeterServer> b_rpc_service2; std::shared_ptr<distributed::HeterServer> b_rpc_service2;
......
...@@ -272,9 +272,11 @@ REGISTER_OPERATOR(sample_logits, ...@@ -272,9 +272,11 @@ REGISTER_OPERATOR(sample_logits,
ops::SampleLogitsGradMaker<paddle::framework::OpDesc>, ops::SampleLogitsGradMaker<paddle::framework::OpDesc>,
ops::SampleLogitsGradMaker<paddle::imperative::OpBase>); ops::SampleLogitsGradMaker<paddle::imperative::OpBase>);
REGISTER_OPERATOR(sample_logits_grad, ops::SampleLogitsOpGrad); REGISTER_OPERATOR(sample_logits_grad, ops::SampleLogitsOpGrad);
REGISTER_OP_CPU_KERNEL(sample_logits, PD_REGISTER_STRUCT_KERNEL(
ops::SampleLogitsKernel<float>, sample_logits, CPU, ALL_LAYOUT, ops::SampleLogitsKernel, float, double) {}
ops::SampleLogitsKernel<double>); PD_REGISTER_STRUCT_KERNEL(sample_logits_grad,
REGISTER_OP_CPU_KERNEL(sample_logits_grad, CPU,
ops::SampleLogitsGradKernel<float>, ALL_LAYOUT,
ops::SampleLogitsGradKernel<double>); ops::SampleLogitsGradKernel,
float,
double) {}
...@@ -109,7 +109,7 @@ __global__ void gpu_compute_remove_accidental_hits(const int size, ...@@ -109,7 +109,7 @@ __global__ void gpu_compute_remove_accidental_hits(const int size,
} }
} }
template <typename T> template <typename T, typename DeviceContext>
class SampleLogitsCUDAKernel : public framework::OpKernel<T> { class SampleLogitsCUDAKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& context) const override { void Compute(const framework::ExecutionContext& context) const override {
...@@ -235,7 +235,7 @@ class SampleLogitsCUDAKernel : public framework::OpKernel<T> { ...@@ -235,7 +235,7 @@ class SampleLogitsCUDAKernel : public framework::OpKernel<T> {
} }
}; };
template <typename T> template <typename T, typename DeviceContext>
class SampleLogitsGradCUDAKernel : public framework::OpKernel<T> { class SampleLogitsGradCUDAKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& context) const override { void Compute(const framework::ExecutionContext& context) const override {
...@@ -287,9 +287,15 @@ class SampleLogitsGradCUDAKernel : public framework::OpKernel<T> { ...@@ -287,9 +287,15 @@ class SampleLogitsGradCUDAKernel : public framework::OpKernel<T> {
} // namespace paddle } // namespace paddle
namespace ops = paddle::operators; namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL(sample_logits, PD_REGISTER_STRUCT_KERNEL(sample_logits,
ops::SampleLogitsCUDAKernel<float>, GPU,
ops::SampleLogitsCUDAKernel<double>); ALL_LAYOUT,
REGISTER_OP_CUDA_KERNEL(sample_logits_grad, ops::SampleLogitsCUDAKernel,
ops::SampleLogitsGradCUDAKernel<float>, float,
ops::SampleLogitsGradCUDAKernel<double>); double) {}
PD_REGISTER_STRUCT_KERNEL(sample_logits_grad,
GPU,
ALL_LAYOUT,
ops::SampleLogitsGradCUDAKernel,
float,
double) {}
...@@ -208,7 +208,7 @@ static void compute_remove_accidental_hits(const platform::DeviceContext& ctx, ...@@ -208,7 +208,7 @@ static void compute_remove_accidental_hits(const platform::DeviceContext& ctx,
} }
} }
template <typename T> template <typename T, typename DeviceContext>
class SampleLogitsKernel : public framework::OpKernel<T> { class SampleLogitsKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& context) const override { void Compute(const framework::ExecutionContext& context) const override {
...@@ -304,7 +304,7 @@ class SampleLogitsKernel : public framework::OpKernel<T> { ...@@ -304,7 +304,7 @@ class SampleLogitsKernel : public framework::OpKernel<T> {
} }
}; };
template <typename T> template <typename T, typename DeviceContext>
class SampleLogitsGradKernel : public framework::OpKernel<T> { class SampleLogitsGradKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& context) const override { void Compute(const framework::ExecutionContext& context) const override {
......
...@@ -82,6 +82,9 @@ REGISTER_OPERATOR( ...@@ -82,6 +82,9 @@ REGISTER_OPERATOR(
paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>, paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>); paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>);
REGISTER_OP_CPU_KERNEL(sampling_id, PD_REGISTER_STRUCT_KERNEL(sampling_id,
paddle::operators::SamplingIdKernel<float>, CPU,
paddle::operators::SamplingIdKernel<double>); ALL_LAYOUT,
paddle::operators::SamplingIdKernel,
float,
double) {}
...@@ -15,6 +15,9 @@ ...@@ -15,6 +15,9 @@
#include "paddle/fluid/operators/sampling_id_op.h" #include "paddle/fluid/operators/sampling_id_op.h"
namespace ops = paddle::operators; namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL(sampling_id, PD_REGISTER_STRUCT_KERNEL(sampling_id,
paddle::operators::SamplingIdKernel<float>, GPU,
paddle::operators::SamplingIdKernel<double>); ALL_LAYOUT,
paddle::operators::SamplingIdKernel,
float,
double) {}
...@@ -27,7 +27,7 @@ ...@@ -27,7 +27,7 @@
namespace paddle { namespace paddle {
namespace operators { namespace operators {
template <typename T> template <typename T, typename DeviceContext>
class SamplingIdKernel : public framework::OpKernel<T> { class SamplingIdKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& context) const override { void Compute(const framework::ExecutionContext& context) const override {
......
...@@ -13,8 +13,11 @@ ...@@ -13,8 +13,11 @@
limitations under the License. */ limitations under the License. */
#include "paddle/fluid/operators/sampling_id_op.h" #include "paddle/fluid/operators/sampling_id_op.h"
#include "paddle/fluid/platform/device_context.h"
namespace ops = paddle::operators; namespace ops = paddle::operators;
using XPUCtx = paddle::platform::XPUDeviceContext;
REGISTER_OP_XPU_KERNEL(sampling_id, REGISTER_OP_XPU_KERNEL(sampling_id,
paddle::operators::SamplingIdKernel<float>, paddle::operators::SamplingIdKernel<float, XPUCtx>,
paddle::operators::SamplingIdKernel<double>); paddle::operators::SamplingIdKernel<double, XPUCtx>);
...@@ -53,7 +53,7 @@ REGISTER_OPERATOR( ...@@ -53,7 +53,7 @@ REGISTER_OPERATOR(
ops::SeedOpMaker, ops::SeedOpMaker,
paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>, paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>); paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>);
REGISTER_OP_CPU_KERNEL(seed, ops::CPUSeedKernel<phi::CPUContext, int>); PD_REGISTER_STRUCT_KERNEL(seed, CPU, ALL_LAYOUT, ops::CPUSeedKernel, int) {}
/* ========================== register checkpoint ===========================*/ /* ========================== register checkpoint ===========================*/
REGISTER_OP_VERSION(seed).AddCheckpoint( REGISTER_OP_VERSION(seed).AddCheckpoint(
......
...@@ -18,7 +18,7 @@ ...@@ -18,7 +18,7 @@
namespace paddle { namespace paddle {
namespace operators { namespace operators {
template <typename Place, typename T> template <typename T, typename DeviceContext>
class GPUSeedKernel : public framework::OpKernel<T> { class GPUSeedKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext &context) const override { void Compute(const framework::ExecutionContext &context) const override {
...@@ -53,5 +53,5 @@ class GPUSeedKernel : public framework::OpKernel<T> { ...@@ -53,5 +53,5 @@ class GPUSeedKernel : public framework::OpKernel<T> {
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
REGISTER_OP_CUDA_KERNEL(seed, PD_REGISTER_STRUCT_KERNEL(
paddle::operators::GPUSeedKernel<phi::GPUContext, int>); seed, GPU, ALL_LAYOUT, paddle::operators::GPUSeedKernel, int) {}
...@@ -44,7 +44,7 @@ static int get_seed(const framework::ExecutionContext& context) { ...@@ -44,7 +44,7 @@ static int get_seed(const framework::ExecutionContext& context) {
return seed; return seed;
} }
template <typename DeviceContext, typename T> template <typename T, typename DeviceContext>
class CPUSeedKernel : public framework::OpKernel<T> { class CPUSeedKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& context) const override { void Compute(const framework::ExecutionContext& context) const override {
......
...@@ -146,17 +146,23 @@ REGISTER_OPERATOR(sequence_concat, ...@@ -146,17 +146,23 @@ REGISTER_OPERATOR(sequence_concat,
op::SeqConcatOpMaker, op::SeqConcatOpMaker,
op::SeqConcatGradOpMaker<paddle::framework::OpDesc>, op::SeqConcatGradOpMaker<paddle::framework::OpDesc>,
op::SeqConcatGradOpMaker<paddle::imperative::OpBase>); op::SeqConcatGradOpMaker<paddle::imperative::OpBase>);
REGISTER_OP_CPU_KERNEL(sequence_concat, PD_REGISTER_STRUCT_KERNEL(sequence_concat,
op::SeqConcatKernel<phi::CPUContext, float>, CPU,
op::SeqConcatKernel<phi::CPUContext, double>, ALL_LAYOUT,
op::SeqConcatKernel<phi::CPUContext, int>, op::SeqConcatKernel,
op::SeqConcatKernel<phi::CPUContext, int64_t>); float,
double,
int,
int64_t) {}
REGISTER_OPERATOR(sequence_concat_grad, REGISTER_OPERATOR(sequence_concat_grad,
op::SeqConcatGradOp, op::SeqConcatGradOp,
op::SeqConcatGradNoNeedBufferVarsInferer); op::SeqConcatGradNoNeedBufferVarsInferer);
REGISTER_OP_CPU_KERNEL(sequence_concat_grad, PD_REGISTER_STRUCT_KERNEL(sequence_concat_grad,
op::SeqConcatGradKernel<phi::CPUContext, float>, CPU,
op::SeqConcatGradKernel<phi::CPUContext, double>, ALL_LAYOUT,
op::SeqConcatGradKernel<phi::CPUContext, int>, op::SeqConcatGradKernel,
op::SeqConcatGradKernel<phi::CPUContext, int64_t>); float,
double,
int,
int64_t) {}
...@@ -16,15 +16,19 @@ ...@@ -16,15 +16,19 @@
#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_registry.h"
REGISTER_OP_CUDA_KERNEL( PD_REGISTER_STRUCT_KERNEL(sequence_concat,
sequence_concat, GPU,
paddle::operators::SeqConcatKernel<phi::GPUContext, float>, ALL_LAYOUT,
paddle::operators::SeqConcatKernel<phi::GPUContext, double>, paddle::operators::SeqConcatKernel,
paddle::operators::SeqConcatKernel<phi::GPUContext, int>, float,
paddle::operators::SeqConcatKernel<phi::GPUContext, int64_t>); double,
REGISTER_OP_CUDA_KERNEL( int,
sequence_concat_grad, int64_t) {}
paddle::operators::SeqConcatGradKernel<phi::GPUContext, float>, PD_REGISTER_STRUCT_KERNEL(sequence_concat_grad,
paddle::operators::SeqConcatGradKernel<phi::GPUContext, double>, GPU,
paddle::operators::SeqConcatGradKernel<phi::GPUContext, int>, ALL_LAYOUT,
paddle::operators::SeqConcatGradKernel<phi::GPUContext, int64_t>); paddle::operators::SeqConcatGradKernel,
float,
double,
int,
int64_t) {}
...@@ -62,7 +62,7 @@ inline std::vector<std::reference_wrapper<T>> GetDataVectorSafely( ...@@ -62,7 +62,7 @@ inline std::vector<std::reference_wrapper<T>> GetDataVectorSafely(
} }
} // namespace detail } // namespace detail
template <typename DeviceContext, typename T> template <typename T, typename DeviceContext>
class SeqConcatKernel : public framework::OpKernel<T> { class SeqConcatKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext &context) const override { void Compute(const framework::ExecutionContext &context) const override {
...@@ -107,7 +107,7 @@ class SeqConcatKernel : public framework::OpKernel<T> { ...@@ -107,7 +107,7 @@ class SeqConcatKernel : public framework::OpKernel<T> {
} }
}; };
template <typename DeviceContext, typename T> template <typename T, typename DeviceContext>
class SeqConcatGradKernel : public framework::OpKernel<T> { class SeqConcatGradKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext &context) const override { void Compute(const framework::ExecutionContext &context) const override {
......
...@@ -22,7 +22,6 @@ register_unity_group( ...@@ -22,7 +22,6 @@ register_unity_group(
sequence_softmax_op.cc sequence_softmax_op.cc
sequence_topk_avg_pooling_op.cc sequence_topk_avg_pooling_op.cc
sequence_unpad_op.cc sequence_unpad_op.cc
sequence_concat_op.cu.cc
sequence_conv_op.cu.cc) sequence_conv_op.cu.cc)
register_unity_group( register_unity_group(
cu cu
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册