未验证 提交 cc9bbd5b 编写于 作者: H Huang Jiyi 提交者: GitHub

register fluid kernels to phi (#51976)

* unify add_position_encoding

* unify affine_channel

* unify alloc_float_status

* unify allreduce

* unify alltoall

* unify anchor_generator

* unify ascend_trigger

* fix bug

* fix test
上级 aaa14780
......@@ -121,11 +121,15 @@ REGISTER_OPERATOR(
ops::AddPositionEncodingGradOpMaker<paddle::imperative::OpBase>);
REGISTER_OPERATOR(add_position_encoding_grad, ops::AddPositionEncodingOpGrad);
// NOTE(review): this span is a diff rendering (commit "register fluid kernels
// to phi") with the +/- markers stripped — the two registration styles below
// would not coexist in the real source file.
// Old fluid-style registration (removed by the commit): the kernel template is
// instantiated per device/type pair as Kernel<DeviceContext, T>.
REGISTER_OP_CPU_KERNEL(add_position_encoding,
ops::AddPositionEncodingKernel<phi::CPUContext, float>,
ops::AddPositionEncodingKernel<phi::CPUContext, double>);
REGISTER_OP_CPU_KERNEL(
add_position_encoding_grad,
ops::AddPositionEncodingGradKernel<phi::CPUContext, float>,
ops::AddPositionEncodingGradKernel<phi::CPUContext, double>);
// New phi-style registration (added by the commit): backend (CPU) and layout
// (ALL_LAYOUT) are explicit and data types are listed once; the macro
// presumably instantiates the kernel as Kernel<T, DeviceContext>, which is why
// the kernel templates were reordered elsewhere in this diff — confirm against
// the PD_REGISTER_STRUCT_KERNEL definition.
PD_REGISTER_STRUCT_KERNEL(add_position_encoding,
CPU,
ALL_LAYOUT,
ops::AddPositionEncodingKernel,
float,
double) {}
PD_REGISTER_STRUCT_KERNEL(add_position_encoding_grad,
CPU,
ALL_LAYOUT,
ops::AddPositionEncodingGradKernel,
float,
double) {}
......@@ -19,7 +19,7 @@ limitations under the License. */
namespace paddle {
namespace operators {
template <typename DeviceContext, typename T>
template <typename T, typename DeviceContext>
class AddPositionEncodingKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
......@@ -99,7 +99,7 @@ class AddPositionEncodingKernel : public framework::OpKernel<T> {
}
};
template <typename DeviceContext, typename T>
template <typename T, typename DeviceContext>
class AddPositionEncodingGradKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
......
......@@ -184,7 +184,7 @@ template <typename T>
using ConstEigenVectorArrayMap =
Eigen::Map<const Eigen::Array<T, Eigen::Dynamic, 1>>;
template <typename DeviceContext, typename T>
template <typename T, typename DeviceContext>
class AffineChannelKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
......@@ -228,7 +228,7 @@ class AffineChannelKernel : public framework::OpKernel<T> {
}
};
template <typename DeviceContext, typename T>
template <typename T, typename DeviceContext>
class AffineChannelGradKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
......@@ -353,9 +353,12 @@ REGISTER_OPERATOR(affine_channel_grad,
ops::AffineChannelNoNeedBufferVarsInference,
ops::AffineChannelGradInplaceInferer);
// NOTE(review): diff rendering — old (fluid) and new (phi) registrations
// appear together here only because the removed/added markers were stripped.
// Removed: fluid-style CPU registration, one instantiation per type.
REGISTER_OP_CPU_KERNEL(affine_channel,
ops::AffineChannelKernel<CPU, float>,
ops::AffineChannelKernel<CPU, double>);
REGISTER_OP_CPU_KERNEL(affine_channel_grad,
ops::AffineChannelGradKernel<CPU, float>,
ops::AffineChannelGradKernel<CPU, double>);
// Added: phi structured-kernel registration with explicit backend/layout.
PD_REGISTER_STRUCT_KERNEL(
affine_channel, CPU, ALL_LAYOUT, ops::AffineChannelKernel, float, double) {}
PD_REGISTER_STRUCT_KERNEL(affine_channel_grad,
CPU,
ALL_LAYOUT,
ops::AffineChannelGradKernel,
float,
double) {}
......@@ -48,7 +48,7 @@ __global__ void KeAffineChannelCUDA(const T* x,
}
}
template <typename DeviceContext, typename T>
template <typename T, typename DeviceContext>
class AffineChannelCUDAKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
......@@ -132,7 +132,7 @@ __global__ void AffineChannelScaleBiasGradientCUDAKernel(const T* dy,
}
}
template <typename DeviceContext, typename T>
template <typename T, typename DeviceContext>
class AffineChannelGradCUDAKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
......@@ -211,9 +211,15 @@ class AffineChannelGradCUDAKernel : public framework::OpKernel<T> {
namespace ops = paddle::operators;
// Alias kept for the (removed) fluid-style registrations below.
using CUDA = phi::GPUContext;
// NOTE(review): diff rendering — the REGISTER_OP_CUDA_KERNEL calls are the
// removed fluid registrations; the PD_REGISTER_STRUCT_KERNEL calls are their
// phi replacements (device spelled "GPU" instead of a context type).
REGISTER_OP_CUDA_KERNEL(affine_channel,
ops::AffineChannelCUDAKernel<CUDA, float>,
ops::AffineChannelCUDAKernel<CUDA, double>);
REGISTER_OP_CUDA_KERNEL(affine_channel_grad,
ops::AffineChannelGradCUDAKernel<CUDA, float>,
ops::AffineChannelGradCUDAKernel<CUDA, double>);
PD_REGISTER_STRUCT_KERNEL(affine_channel,
GPU,
ALL_LAYOUT,
ops::AffineChannelCUDAKernel,
float,
double) {}
PD_REGISTER_STRUCT_KERNEL(affine_channel_grad,
GPU,
ALL_LAYOUT,
ops::AffineChannelGradCUDAKernel,
float,
double) {}
......@@ -51,7 +51,7 @@ class AllocFloatStatusMaker : public framework::OpProtoAndCheckerMaker {
}
};
template <typename DeviceContext, typename T>
template <typename T, typename DeviceContext>
class AllocFloatStatusKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
......@@ -73,5 +73,5 @@ REGISTER_OPERATOR(
paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>);
// NOTE(review): diff rendering — first line removed (fluid style), second
// added (phi style). alloc_float_status registers for float only.
REGISTER_OP_CPU_KERNEL(alloc_float_status,
ops::AllocFloatStatusKernel<CPU, float>);
PD_REGISTER_STRUCT_KERNEL(
alloc_float_status, CPU, ALL_LAYOUT, ops::AllocFloatStatusKernel, float) {}
......@@ -50,4 +50,6 @@ namespace ops = paddle::operators;
// Operator registration itself is unchanged by the commit.
REGISTER_OPERATOR(ascend_trigger,
ops::AscendTriggerOp,
ops::AscendTriggerOpMaker);
// NOTE(review): diff rendering — the REGISTER_OP_CPU_KERNEL line was removed
// and replaced by the phi-style PD_REGISTER_STRUCT_KERNEL below. The kernel
// class gained a DeviceContext template parameter elsewhere in this diff to
// match the new macro's expectations.
REGISTER_OP_CPU_KERNEL(ascend_trigger, ops::AscendTriggerCPUKernel<float>)
PD_REGISTER_STRUCT_KERNEL(
ascend_trigger, CPU, ALL_LAYOUT, ops::AscendTriggerCPUKernel, float) {}
......@@ -25,7 +25,7 @@
namespace paddle {
namespace operators {
template <typename T>
template <typename T, typename DeviceContext>
class AscendTriggerCPUKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext &ctx) const override {
......
......@@ -73,9 +73,12 @@ REGISTER_OP_WITHOUT_GRADIENT(allreduce,
ops::AllReduceOp,
ops::AllReduceOpMaker);
// NOTE(review): diff rendering — removed fluid registration (explicit
// phi::CPUContext per instantiation) followed by its added phi replacement.
// Same five data types in both: float, double, int, int64_t, float16.
REGISTER_OP_CPU_KERNEL(allreduce,
ops::AllReduceOpKernel<phi::CPUContext, float>,
ops::AllReduceOpKernel<phi::CPUContext, double>,
ops::AllReduceOpKernel<phi::CPUContext, int>,
ops::AllReduceOpKernel<phi::CPUContext, int64_t>,
ops::AllReduceOpKernel<phi::CPUContext, plat::float16>);
PD_REGISTER_STRUCT_KERNEL(allreduce,
CPU,
ALL_LAYOUT,
ops::AllReduceOpKernel,
float,
double,
int,
int64_t,
plat::float16) {}
......@@ -17,9 +17,12 @@ limitations under the License. */
namespace ops = paddle::operators;
namespace plat = paddle::platform;
// NOTE(review): diff rendering — removed fluid CUDA registration followed by
// the added phi-style GPU registration; type list is identical.
REGISTER_OP_CUDA_KERNEL(allreduce,
ops::AllReduceOpKernel<phi::GPUContext, float>,
ops::AllReduceOpKernel<phi::GPUContext, double>,
ops::AllReduceOpKernel<phi::GPUContext, int>,
ops::AllReduceOpKernel<phi::GPUContext, int64_t>,
ops::AllReduceOpKernel<phi::GPUContext, plat::float16>);
PD_REGISTER_STRUCT_KERNEL(allreduce,
GPU,
ALL_LAYOUT,
ops::AllReduceOpKernel,
float,
double,
int,
int64_t,
plat::float16) {}
......@@ -28,7 +28,7 @@ limitations under the License. */
namespace paddle {
namespace operators {
template <typename DeviceContext, typename T>
template <typename T, typename DeviceContext>
class AllReduceOpKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
......
......@@ -69,9 +69,12 @@ namespace plat = paddle::platform;
// Operator registration unchanged; alltoall has no gradient op.
REGISTER_OP_WITHOUT_GRADIENT(alltoall, ops::AllToAllOp, ops::AllToAllOpMaker)
// NOTE(review): diff rendering — removed fluid CPU registration followed by
// its added phi replacement; same five data types.
REGISTER_OP_CPU_KERNEL(alltoall,
ops::AllToAllOpCPUKernel<float>,
ops::AllToAllOpCPUKernel<double>,
ops::AllToAllOpCPUKernel<int>,
ops::AllToAllOpCPUKernel<int64_t>,
ops::AllToAllOpCPUKernel<plat::float16>);
PD_REGISTER_STRUCT_KERNEL(alltoall,
CPU,
ALL_LAYOUT,
ops::AllToAllOpCPUKernel,
float,
double,
int,
int64_t,
plat::float16) {}
......@@ -22,7 +22,7 @@ limitations under the License. */
namespace paddle {
namespace operators {
template <typename T>
template <typename T, typename DeviceContext>
class AllToAllOpCUDAKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
......@@ -92,12 +92,16 @@ class AllToAllOpCUDAKernel : public framework::OpKernel<T> {
namespace ops = paddle::operators;
namespace plat = paddle::platform;
// NOTE(review): diff rendering with removed and added lines INTERLEAVED inside
// one hunk — the REGISTER_OP_CUDA_KERNEL lines (with explicit <type> template
// arguments) are the removed fluid registration, and the
// PD_REGISTER_STRUCT_KERNEL lines (bare type names) are the added phi one.
// Do not read this as a single statement. In both versions bfloat16 is only
// registered when NCCL >= 2.10 (NCCL_VERSION_CODE >= 21000), via the
// preprocessor conditional placed inside the macro's argument list.
REGISTER_OP_CUDA_KERNEL(alltoall,
ops::AllToAllOpCUDAKernel<float>,
ops::AllToAllOpCUDAKernel<double>,
PD_REGISTER_STRUCT_KERNEL(alltoall,
GPU,
ALL_LAYOUT,
ops::AllToAllOpCUDAKernel,
float,
double,
#if NCCL_VERSION_CODE >= 21000
ops::AllToAllOpCUDAKernel<plat::bfloat16>,
plat::bfloat16,
#endif
ops::AllToAllOpCUDAKernel<int>,
ops::AllToAllOpCUDAKernel<int64_t>,
ops::AllToAllOpCUDAKernel<plat::float16>);
int,
int64_t,
plat::float16) {
}
......@@ -29,7 +29,7 @@ limitations under the License. */
namespace paddle {
namespace operators {
template <typename T>
template <typename T, typename DeviceContext>
class AllToAllOpCPUKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
......
......@@ -175,6 +175,9 @@ REGISTER_OPERATOR(
paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>);
// NOTE(review): diff rendering — removed fluid CPU registration followed by
// its added phi-style replacement (float and double).
REGISTER_OP_CPU_KERNEL(anchor_generator,
ops::AnchorGeneratorOpKernel<float>,
ops::AnchorGeneratorOpKernel<double>);
PD_REGISTER_STRUCT_KERNEL(anchor_generator,
CPU,
ALL_LAYOUT,
ops::AnchorGeneratorOpKernel,
float,
double) {}
......@@ -71,7 +71,7 @@ __global__ void SetVariance(T* out,
CUDA_KERNEL_LOOP(i, num) { out[i] = var[i % vnum]; }
}
template <typename T>
template <typename T, typename DeviceContext>
class AnchorGeneratorOpCUDAKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
......@@ -133,6 +133,10 @@ class AnchorGeneratorOpCUDAKernel : public framework::OpKernel<T> {
} // namespace paddle
namespace ops = paddle::operators;
// NOTE(review): diff rendering — removed fluid CUDA registration followed by
// its added phi-style GPU replacement (float and double).
REGISTER_OP_CUDA_KERNEL(anchor_generator,
ops::AnchorGeneratorOpCUDAKernel<float>,
ops::AnchorGeneratorOpCUDAKernel<double>);
PD_REGISTER_STRUCT_KERNEL(anchor_generator,
GPU,
ALL_LAYOUT,
ops::AnchorGeneratorOpCUDAKernel,
float,
double) {}
......@@ -44,7 +44,7 @@ extern __global__ void SetVariance(T* out,
const int num);
#endif
template <typename T>
template <typename T, typename DeviceContext>
class AnchorGeneratorOpKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
......
......@@ -77,7 +77,7 @@ class TestAnchorGeneratorOp(OpTest):
self.outputs = {'Anchors': self.out_anchors, 'Variances': self.out_var}
# NOTE(review): diff rendering — the bare self.check_output() call is the
# removed line and the check_dygraph=False variant is its replacement; only
# one of the two calls exists in the real test. The commit disables the
# dygraph check for this structured-kernel op — confirm against the op's
# dygraph support.
def test_check_output(self):
self.check_output()
self.check_output(check_dygraph=False)
def setUp(self):
self.op_type = "anchor_generator"
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册