未验证 提交 28492bf6 编写于 作者: Y YuanRisheng 提交者: GitHub

[PHI]Unify fluid kernel (Part5) (#53003)

* unify_kernel

* fix compile bugs

* fix py3 bugs

* fix xpu bugs
上级 4812d8e4
......@@ -86,9 +86,12 @@ namespace plat = paddle::platform;
REGISTER_OP_WITHOUT_GRADIENT(send_v2, ops::SendOpV2, ops::SendOpV2Maker);
REGISTER_OP_CPU_KERNEL(send_v2,
ops::SendOpV2CPUKernel<float>,
ops::SendOpV2CPUKernel<double>,
ops::SendOpV2CPUKernel<int>,
ops::SendOpV2CPUKernel<int64_t>,
ops::SendOpV2CPUKernel<plat::float16>);
PD_REGISTER_STRUCT_KERNEL(send_v2,
CPU,
ALL_LAYOUT,
ops::SendOpV2CPUKernel,
float,
double,
int,
int64_t,
plat::float16) {}
......@@ -104,7 +104,7 @@ void send_shape_info(const phi::DenseTensor& x,
}
#endif
template <typename T>
template <typename T, typename DeviceContext>
class SendOpV2CUDAKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
......@@ -217,13 +217,17 @@ class SendOpV2CUDAKernel : public framework::OpKernel<T> {
namespace ops = paddle::operators;
namespace plat = paddle::platform;
REGISTER_OP_CUDA_KERNEL(send_v2,
ops::SendOpV2CUDAKernel<float>,
ops::SendOpV2CUDAKernel<double>,
PD_REGISTER_STRUCT_KERNEL(send_v2,
GPU,
ALL_LAYOUT,
ops::SendOpV2CUDAKernel,
float,
double,
#if NCCL_VERSION_CODE >= 21000
ops::SendOpV2CUDAKernel<plat::bfloat16>,
plat::bfloat16,
#endif
ops::SendOpV2CUDAKernel<int>,
ops::SendOpV2CUDAKernel<int64_t>,
ops::SendOpV2CUDAKernel<int8_t>,
ops::SendOpV2CUDAKernel<plat::float16>);
int,
int64_t,
int8_t,
plat::float16) {
}
......@@ -25,7 +25,7 @@ limitations under the License. */
namespace paddle {
namespace operators {
template <typename T>
template <typename T, typename DeviceContext>
class SendOpV2CPUKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
......
......@@ -26,7 +26,7 @@ limitations under the License. */
namespace paddle {
namespace operators {
template <typename DeviceContext, typename T>
template <typename T, typename DeviceContext>
class SendAndRecvKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
......@@ -98,17 +98,25 @@ class SendAndRecvOpMaker : public framework::OpProtoAndCheckerMaker {
namespace ops = paddle::operators;
REGISTER_OPERATOR(send_and_recv, ops::SendAndRecvOp, ops::SendAndRecvOpMaker);
REGISTER_OP_CUDA_KERNEL(send_and_recv,
ops::SendAndRecvKernel<phi::GPUContext, float>,
ops::SendAndRecvKernel<phi::GPUContext, double>,
ops::SendAndRecvKernel<phi::GPUContext, int>,
ops::SendAndRecvKernel<phi::GPUContext, int64_t>);
REGISTER_OP_CPU_KERNEL(send_and_recv,
ops::SendAndRecvKernel<phi::CPUContext, float>,
ops::SendAndRecvKernel<phi::CPUContext, double>,
ops::SendAndRecvKernel<phi::CPUContext, int>,
ops::SendAndRecvKernel<phi::CPUContext, int64_t>);
PD_REGISTER_STRUCT_KERNEL(send_and_recv,
CPU,
ALL_LAYOUT,
ops::SendAndRecvKernel,
float,
double,
int,
int64_t) {}
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
PD_REGISTER_STRUCT_KERNEL(send_and_recv,
GPU,
ALL_LAYOUT,
ops::SendAndRecvKernel,
float,
double,
int,
int64_t) {}
#endif
REGISTER_OP_VERSION(send_and_recv)
.AddCheckpoint(
R"ROC(add new attributes [next_endpoints] [previous_endpoints] and [mode])ROC",
......
......@@ -35,7 +35,7 @@ using MultiVarMsg = ::paddle::distributed::MultiVariableMessage;
using VarMsg = ::paddle::distributed::VariableMessage;
USE_OP_ITSELF(scale);
USE_OP(send_and_recv);
PD_DECLARE_KERNEL(send_and_recv, CPU, ALL_LAYOUT);
std::string get_ip_port() {
std::mt19937 rng;
......
......@@ -39,7 +39,7 @@ using MultiVarMsg = ::paddle::distributed::MultiVariableMessage;
using VarMsg = ::paddle::distributed::VariableMessage;
USE_OP_ITSELF(scale);
USE_OP(send_and_recv);
PD_DECLARE_KERNEL(send_and_recv, GPU, ALL_LAYOUT);
std::shared_ptr<distributed::HeterServer> b_rpc_service2;
......
......@@ -272,9 +272,11 @@ REGISTER_OPERATOR(sample_logits,
ops::SampleLogitsGradMaker<paddle::framework::OpDesc>,
ops::SampleLogitsGradMaker<paddle::imperative::OpBase>);
REGISTER_OPERATOR(sample_logits_grad, ops::SampleLogitsOpGrad);
REGISTER_OP_CPU_KERNEL(sample_logits,
ops::SampleLogitsKernel<float>,
ops::SampleLogitsKernel<double>);
REGISTER_OP_CPU_KERNEL(sample_logits_grad,
ops::SampleLogitsGradKernel<float>,
ops::SampleLogitsGradKernel<double>);
PD_REGISTER_STRUCT_KERNEL(
sample_logits, CPU, ALL_LAYOUT, ops::SampleLogitsKernel, float, double) {}
PD_REGISTER_STRUCT_KERNEL(sample_logits_grad,
CPU,
ALL_LAYOUT,
ops::SampleLogitsGradKernel,
float,
double) {}
......@@ -109,7 +109,7 @@ __global__ void gpu_compute_remove_accidental_hits(const int size,
}
}
template <typename T>
template <typename T, typename DeviceContext>
class SampleLogitsCUDAKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
......@@ -235,7 +235,7 @@ class SampleLogitsCUDAKernel : public framework::OpKernel<T> {
}
};
template <typename T>
template <typename T, typename DeviceContext>
class SampleLogitsGradCUDAKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
......@@ -287,9 +287,15 @@ class SampleLogitsGradCUDAKernel : public framework::OpKernel<T> {
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL(sample_logits,
ops::SampleLogitsCUDAKernel<float>,
ops::SampleLogitsCUDAKernel<double>);
REGISTER_OP_CUDA_KERNEL(sample_logits_grad,
ops::SampleLogitsGradCUDAKernel<float>,
ops::SampleLogitsGradCUDAKernel<double>);
PD_REGISTER_STRUCT_KERNEL(sample_logits,
GPU,
ALL_LAYOUT,
ops::SampleLogitsCUDAKernel,
float,
double) {}
PD_REGISTER_STRUCT_KERNEL(sample_logits_grad,
GPU,
ALL_LAYOUT,
ops::SampleLogitsGradCUDAKernel,
float,
double) {}
......@@ -208,7 +208,7 @@ static void compute_remove_accidental_hits(const platform::DeviceContext& ctx,
}
}
template <typename T>
template <typename T, typename DeviceContext>
class SampleLogitsKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
......@@ -304,7 +304,7 @@ class SampleLogitsKernel : public framework::OpKernel<T> {
}
};
template <typename T>
template <typename T, typename DeviceContext>
class SampleLogitsGradKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
......
......@@ -82,6 +82,9 @@ REGISTER_OPERATOR(
paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>);
REGISTER_OP_CPU_KERNEL(sampling_id,
paddle::operators::SamplingIdKernel<float>,
paddle::operators::SamplingIdKernel<double>);
PD_REGISTER_STRUCT_KERNEL(sampling_id,
CPU,
ALL_LAYOUT,
paddle::operators::SamplingIdKernel,
float,
double) {}
......@@ -15,6 +15,9 @@
#include "paddle/fluid/operators/sampling_id_op.h"
namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL(sampling_id,
paddle::operators::SamplingIdKernel<float>,
paddle::operators::SamplingIdKernel<double>);
PD_REGISTER_STRUCT_KERNEL(sampling_id,
GPU,
ALL_LAYOUT,
paddle::operators::SamplingIdKernel,
float,
double) {}
......@@ -27,7 +27,7 @@
namespace paddle {
namespace operators {
template <typename T>
template <typename T, typename DeviceContext>
class SamplingIdKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
......
......@@ -13,8 +13,11 @@
limitations under the License. */
#include "paddle/fluid/operators/sampling_id_op.h"
#include "paddle/fluid/platform/device_context.h"
namespace ops = paddle::operators;
using XPUCtx = paddle::platform::XPUDeviceContext;
REGISTER_OP_XPU_KERNEL(sampling_id,
paddle::operators::SamplingIdKernel<float>,
paddle::operators::SamplingIdKernel<double>);
paddle::operators::SamplingIdKernel<float, XPUCtx>,
paddle::operators::SamplingIdKernel<double, XPUCtx>);
......@@ -53,7 +53,7 @@ REGISTER_OPERATOR(
ops::SeedOpMaker,
paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>);
REGISTER_OP_CPU_KERNEL(seed, ops::CPUSeedKernel<phi::CPUContext, int>);
PD_REGISTER_STRUCT_KERNEL(seed, CPU, ALL_LAYOUT, ops::CPUSeedKernel, int) {}
/* ========================== register checkpoint ===========================*/
REGISTER_OP_VERSION(seed).AddCheckpoint(
......
......@@ -18,7 +18,7 @@
namespace paddle {
namespace operators {
template <typename Place, typename T>
template <typename T, typename DeviceContext>
class GPUSeedKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext &context) const override {
......@@ -53,5 +53,5 @@ class GPUSeedKernel : public framework::OpKernel<T> {
} // namespace operators
} // namespace paddle
REGISTER_OP_CUDA_KERNEL(seed,
paddle::operators::GPUSeedKernel<phi::GPUContext, int>);
PD_REGISTER_STRUCT_KERNEL(
seed, GPU, ALL_LAYOUT, paddle::operators::GPUSeedKernel, int) {}
......@@ -44,7 +44,7 @@ static int get_seed(const framework::ExecutionContext& context) {
return seed;
}
template <typename DeviceContext, typename T>
template <typename T, typename DeviceContext>
class CPUSeedKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
......
......@@ -146,17 +146,23 @@ REGISTER_OPERATOR(sequence_concat,
op::SeqConcatOpMaker,
op::SeqConcatGradOpMaker<paddle::framework::OpDesc>,
op::SeqConcatGradOpMaker<paddle::imperative::OpBase>);
REGISTER_OP_CPU_KERNEL(sequence_concat,
op::SeqConcatKernel<phi::CPUContext, float>,
op::SeqConcatKernel<phi::CPUContext, double>,
op::SeqConcatKernel<phi::CPUContext, int>,
op::SeqConcatKernel<phi::CPUContext, int64_t>);
PD_REGISTER_STRUCT_KERNEL(sequence_concat,
CPU,
ALL_LAYOUT,
op::SeqConcatKernel,
float,
double,
int,
int64_t) {}
REGISTER_OPERATOR(sequence_concat_grad,
op::SeqConcatGradOp,
op::SeqConcatGradNoNeedBufferVarsInferer);
REGISTER_OP_CPU_KERNEL(sequence_concat_grad,
op::SeqConcatGradKernel<phi::CPUContext, float>,
op::SeqConcatGradKernel<phi::CPUContext, double>,
op::SeqConcatGradKernel<phi::CPUContext, int>,
op::SeqConcatGradKernel<phi::CPUContext, int64_t>);
PD_REGISTER_STRUCT_KERNEL(sequence_concat_grad,
CPU,
ALL_LAYOUT,
op::SeqConcatGradKernel,
float,
double,
int,
int64_t) {}
......@@ -16,15 +16,19 @@
#include "paddle/fluid/framework/op_registry.h"
REGISTER_OP_CUDA_KERNEL(
sequence_concat,
paddle::operators::SeqConcatKernel<phi::GPUContext, float>,
paddle::operators::SeqConcatKernel<phi::GPUContext, double>,
paddle::operators::SeqConcatKernel<phi::GPUContext, int>,
paddle::operators::SeqConcatKernel<phi::GPUContext, int64_t>);
REGISTER_OP_CUDA_KERNEL(
sequence_concat_grad,
paddle::operators::SeqConcatGradKernel<phi::GPUContext, float>,
paddle::operators::SeqConcatGradKernel<phi::GPUContext, double>,
paddle::operators::SeqConcatGradKernel<phi::GPUContext, int>,
paddle::operators::SeqConcatGradKernel<phi::GPUContext, int64_t>);
PD_REGISTER_STRUCT_KERNEL(sequence_concat,
GPU,
ALL_LAYOUT,
paddle::operators::SeqConcatKernel,
float,
double,
int,
int64_t) {}
PD_REGISTER_STRUCT_KERNEL(sequence_concat_grad,
GPU,
ALL_LAYOUT,
paddle::operators::SeqConcatGradKernel,
float,
double,
int,
int64_t) {}
......@@ -62,7 +62,7 @@ inline std::vector<std::reference_wrapper<T>> GetDataVectorSafely(
}
} // namespace detail
template <typename DeviceContext, typename T>
template <typename T, typename DeviceContext>
class SeqConcatKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext &context) const override {
......@@ -107,7 +107,7 @@ class SeqConcatKernel : public framework::OpKernel<T> {
}
};
template <typename DeviceContext, typename T>
template <typename T, typename DeviceContext>
class SeqConcatGradKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext &context) const override {
......
......@@ -22,7 +22,6 @@ register_unity_group(
sequence_softmax_op.cc
sequence_topk_avg_pooling_op.cc
sequence_unpad_op.cc
sequence_concat_op.cu.cc
sequence_conv_op.cu.cc)
register_unity_group(
cu
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册