未验证 提交 63efdaee 编写于 作者: H huangjiyi 提交者: GitHub

register fluid kerenls to phi [part 4] (#52116)

* update

* fix bug

* fix bug

* revert diag_op

* revert expand_op and expand_as_op

* fix bug

* fix bug
上级 34069c46
...@@ -267,7 +267,7 @@ The required data format for this layer is one of the following: ...@@ -267,7 +267,7 @@ The required data format for this layer is one of the following:
}; };
template <typename T> template <typename T>
class DataNormKernel<phi::CPUContext, T> : public framework::OpKernel<T> { class DataNormKernel<T, phi::CPUContext> : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext &ctx) const override { void Compute(const framework::ExecutionContext &ctx) const override {
// const bool is_test = ctx.Attr<bool>("is_test"); // const bool is_test = ctx.Attr<bool>("is_test");
...@@ -509,7 +509,7 @@ class DataNormGradOp : public framework::OperatorWithKernel { ...@@ -509,7 +509,7 @@ class DataNormGradOp : public framework::OperatorWithKernel {
}; };
template <typename T> template <typename T>
class DataNormGradKernel<phi::CPUContext, T> : public framework::OpKernel<T> { class DataNormGradKernel<T, phi::CPUContext> : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext &ctx) const override { void Compute(const framework::ExecutionContext &ctx) const override {
const auto *x = ctx.Input<phi::DenseTensor>("X"); const auto *x = ctx.Input<phi::DenseTensor>("X");
...@@ -764,12 +764,11 @@ REGISTER_OPERATOR(data_norm, ...@@ -764,12 +764,11 @@ REGISTER_OPERATOR(data_norm,
ops::DataNormGradMaker<paddle::imperative::OpBase>); ops::DataNormGradMaker<paddle::imperative::OpBase>);
REGISTER_OPERATOR(data_norm_grad, ops::DataNormGradOp); REGISTER_OPERATOR(data_norm_grad, ops::DataNormGradOp);
REGISTER_OP_CPU_KERNEL(data_norm, PD_REGISTER_STRUCT_KERNEL(
ops::DataNormKernel<phi::CPUContext, float>, data_norm, CPU, ALL_LAYOUT, ops::DataNormKernel, float, double) {}
ops::DataNormKernel<phi::CPUContext, double>); PD_REGISTER_STRUCT_KERNEL(
REGISTER_OP_CPU_KERNEL(data_norm_grad, data_norm_grad, CPU, ALL_LAYOUT, ops::DataNormGradKernel, float, double) {}
ops::DataNormGradKernel<phi::CPUContext, float>,
ops::DataNormGradKernel<phi::CPUContext, double>);
REGISTER_OP_VERSION(data_norm).AddCheckpoint( REGISTER_OP_VERSION(data_norm).AddCheckpoint(
R"ROC( R"ROC(
upgrad data_norm op by adding scale_w to support scale and shift.)ROC", upgrad data_norm op by adding scale_w to support scale and shift.)ROC",
......
...@@ -102,7 +102,7 @@ __global__ void KernelUpdateParam(int C, ...@@ -102,7 +102,7 @@ __global__ void KernelUpdateParam(int C,
} }
template <typename T> template <typename T>
class DataNormKernel<phi::GPUContext, T> : public framework::OpKernel<T> { class DataNormKernel<T, phi::GPUContext> : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext &ctx) const override { void Compute(const framework::ExecutionContext &ctx) const override {
const auto *x = ctx.Input<phi::DenseTensor>("X"); const auto *x = ctx.Input<phi::DenseTensor>("X");
...@@ -154,7 +154,7 @@ class DataNormKernel<phi::GPUContext, T> : public framework::OpKernel<T> { ...@@ -154,7 +154,7 @@ class DataNormKernel<phi::GPUContext, T> : public framework::OpKernel<T> {
}; };
template <typename T> template <typename T>
class DataNormGradKernel<phi::GPUContext, T> : public framework::OpKernel<T> { class DataNormGradKernel<T, phi::GPUContext> : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext &ctx) const override { void Compute(const framework::ExecutionContext &ctx) const override {
const auto *x = ctx.Input<phi::DenseTensor>("X"); const auto *x = ctx.Input<phi::DenseTensor>("X");
...@@ -267,9 +267,8 @@ class DataNormGradKernel<phi::GPUContext, T> : public framework::OpKernel<T> { ...@@ -267,9 +267,8 @@ class DataNormGradKernel<phi::GPUContext, T> : public framework::OpKernel<T> {
} // namespace paddle } // namespace paddle
namespace ops = paddle::operators; namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL(data_norm,
ops::DataNormKernel<phi::GPUContext, float>, PD_REGISTER_STRUCT_KERNEL(
ops::DataNormKernel<phi::GPUContext, double>); data_norm, GPU, ALL_LAYOUT, ops::DataNormKernel, float, double) {}
REGISTER_OP_CUDA_KERNEL(data_norm_grad, PD_REGISTER_STRUCT_KERNEL(
ops::DataNormGradKernel<phi::GPUContext, float>, data_norm_grad, GPU, ALL_LAYOUT, ops::DataNormGradKernel, float, double) {}
ops::DataNormGradKernel<phi::GPUContext, double>);
...@@ -19,13 +19,13 @@ limitations under the License. */ ...@@ -19,13 +19,13 @@ limitations under the License. */
namespace paddle { namespace paddle {
namespace operators { namespace operators {
template <typename DeviceContext, typename T> template <typename T, typename DeviceContext>
class DataNormKernel : public framework::OpKernel<T> { class DataNormKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& ctx) const override; void Compute(const framework::ExecutionContext& ctx) const override;
}; };
template <typename DeviceContext, typename T> template <typename T, typename DeviceContext>
class DataNormGradKernel : public framework::OpKernel<T> { class DataNormGradKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& ctx) const override; void Compute(const framework::ExecutionContext& ctx) const override;
......
...@@ -348,7 +348,6 @@ class DeformablePSROIPoolGradOp : public framework::OperatorWithKernel { ...@@ -348,7 +348,6 @@ class DeformablePSROIPoolGradOp : public framework::OperatorWithKernel {
} // namespace paddle } // namespace paddle
namespace ops = paddle::operators; namespace ops = paddle::operators;
using CPU = phi::CPUContext;
REGISTER_OPERATOR( REGISTER_OPERATOR(
deformable_psroi_pooling, deformable_psroi_pooling,
ops::DeformablePSROIPoolOp, ops::DeformablePSROIPoolOp,
...@@ -357,9 +356,16 @@ REGISTER_OPERATOR( ...@@ -357,9 +356,16 @@ REGISTER_OPERATOR(
ops::DeformablePSROIPoolGradOpMaker<paddle::imperative::OpBase>); ops::DeformablePSROIPoolGradOpMaker<paddle::imperative::OpBase>);
REGISTER_OPERATOR(deformable_psroi_pooling_grad, REGISTER_OPERATOR(deformable_psroi_pooling_grad,
ops::DeformablePSROIPoolGradOp); ops::DeformablePSROIPoolGradOp);
REGISTER_OP_CPU_KERNEL(deformable_psroi_pooling,
ops::DeformablePSROIPoolCPUKernel<CPU, float>, PD_REGISTER_STRUCT_KERNEL(deformable_psroi_pooling,
ops::DeformablePSROIPoolCPUKernel<CPU, double>); CPU,
REGISTER_OP_CPU_KERNEL(deformable_psroi_pooling_grad, ALL_LAYOUT,
ops::DeformablePSROIPoolGradCPUKernel<CPU, float>, ops::DeformablePSROIPoolCPUKernel,
ops::DeformablePSROIPoolGradCPUKernel<CPU, double>); float,
double) {}
PD_REGISTER_STRUCT_KERNEL(deformable_psroi_pooling_grad,
CPU,
ALL_LAYOUT,
ops::DeformablePSROIPoolGradCPUKernel,
float,
double) {}
...@@ -178,7 +178,7 @@ __global__ void DeformablePSROIPoolForwardKernel(const int count, ...@@ -178,7 +178,7 @@ __global__ void DeformablePSROIPoolForwardKernel(const int count,
} }
} }
template <typename DeviceContext, typename T> template <typename T, typename DeviceContext>
class DeformablePSROIPoolCUDAKernel : public framework::OpKernel<T> { class DeformablePSROIPoolCUDAKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& ctx) const override { void Compute(const framework::ExecutionContext& ctx) const override {
...@@ -479,7 +479,7 @@ __global__ void DeformablePSROIPoolBackwardAccKernel( ...@@ -479,7 +479,7 @@ __global__ void DeformablePSROIPoolBackwardAccKernel(
} }
} }
template <typename DeviceContext, typename T> template <typename T, typename DeviceContext>
class DeformablePSROIPoolGradCUDAKernel : public framework::OpKernel<T> { class DeformablePSROIPoolGradCUDAKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& ctx) const override { void Compute(const framework::ExecutionContext& ctx) const override {
...@@ -626,10 +626,16 @@ class DeformablePSROIPoolGradCUDAKernel : public framework::OpKernel<T> { ...@@ -626,10 +626,16 @@ class DeformablePSROIPoolGradCUDAKernel : public framework::OpKernel<T> {
} // namespace paddle } // namespace paddle
namespace ops = paddle::operators; namespace ops = paddle::operators;
using CUDA = phi::GPUContext;
REGISTER_OP_CUDA_KERNEL(deformable_psroi_pooling, PD_REGISTER_STRUCT_KERNEL(deformable_psroi_pooling,
ops::DeformablePSROIPoolCUDAKernel<CUDA, float>, GPU,
ops::DeformablePSROIPoolCUDAKernel<CUDA, double>); ALL_LAYOUT,
REGISTER_OP_CUDA_KERNEL(deformable_psroi_pooling_grad, ops::DeformablePSROIPoolCUDAKernel,
ops::DeformablePSROIPoolGradCUDAKernel<CUDA, float>, float,
ops::DeformablePSROIPoolGradCUDAKernel<CUDA, double>); double) {}
PD_REGISTER_STRUCT_KERNEL(deformable_psroi_pooling_grad,
GPU,
ALL_LAYOUT,
ops::DeformablePSROIPoolGradCUDAKernel,
float,
double) {}
...@@ -166,7 +166,7 @@ void DeformablePSROIPoolForwardCPUKernel(const int count, ...@@ -166,7 +166,7 @@ void DeformablePSROIPoolForwardCPUKernel(const int count,
} }
} }
template <typename DeviceContext, typename T> template <typename T, typename DeviceContext>
class DeformablePSROIPoolCPUKernel : public framework::OpKernel<T> { class DeformablePSROIPoolCPUKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& ctx) const override { void Compute(const framework::ExecutionContext& ctx) const override {
...@@ -468,7 +468,7 @@ void DeformablePSROIPoolBackwardAccCPUKernel(const int count, ...@@ -468,7 +468,7 @@ void DeformablePSROIPoolBackwardAccCPUKernel(const int count,
} }
} }
template <typename DeviceContext, typename T> template <typename T, typename DeviceContext>
class DeformablePSROIPoolGradCPUKernel : public framework::OpKernel<T> { class DeformablePSROIPoolGradCPUKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& ctx) const override { void Compute(const framework::ExecutionContext& ctx) const override {
......
...@@ -101,7 +101,6 @@ $$Out = \frac{scale*X}{ max\_range }$$ ...@@ -101,7 +101,6 @@ $$Out = \frac{scale*X}{ max\_range }$$
} // namespace paddle } // namespace paddle
namespace ops = paddle::operators; namespace ops = paddle::operators;
using CPU = phi::CPUContext;
REGISTER_OPERATOR( REGISTER_OPERATOR(
dequantize_abs_max, dequantize_abs_max,
...@@ -109,6 +108,10 @@ REGISTER_OPERATOR( ...@@ -109,6 +108,10 @@ REGISTER_OPERATOR(
ops::DequantizeMaxAbsOpMaker, ops::DequantizeMaxAbsOpMaker,
paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>, paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>); paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>);
REGISTER_OP_CPU_KERNEL(dequantize_abs_max,
ops::DequantizeMaxAbsKernel<CPU, int8_t>, PD_REGISTER_STRUCT_KERNEL(dequantize_abs_max,
ops::DequantizeMaxAbsKernel<CPU, int16_t>); CPU,
ALL_LAYOUT,
ops::DequantizeMaxAbsKernel,
int8_t,
int16_t) {}
...@@ -53,7 +53,10 @@ template struct DequantizeFunctor<phi::GPUContext, int16_t>; ...@@ -53,7 +53,10 @@ template struct DequantizeFunctor<phi::GPUContext, int16_t>;
} // namespace paddle } // namespace paddle
namespace ops = paddle::operators; namespace ops = paddle::operators;
using CUDA = phi::GPUContext;
REGISTER_OP_CUDA_KERNEL(dequantize_abs_max, PD_REGISTER_STRUCT_KERNEL(dequantize_abs_max,
ops::DequantizeMaxAbsKernel<CUDA, int8_t>, GPU,
ops::DequantizeMaxAbsKernel<CUDA, int16_t>); ALL_LAYOUT,
ops::DequantizeMaxAbsKernel,
int8_t,
int16_t) {}
...@@ -36,7 +36,7 @@ struct DequantizeFunctor { ...@@ -36,7 +36,7 @@ struct DequantizeFunctor {
phi::DenseTensor* out); phi::DenseTensor* out);
}; };
template <typename DeviceContext, typename T> template <typename T, typename DeviceContext>
class DequantizeMaxAbsKernel : public framework::OpKernel<T> { class DequantizeMaxAbsKernel : public framework::OpKernel<T> {
public: public:
virtual void Compute(const framework::ExecutionContext& ctx) const { virtual void Compute(const framework::ExecutionContext& ctx) const {
......
...@@ -107,7 +107,6 @@ This calculation is an opposite operation of QuantizeLogOp: ...@@ -107,7 +107,6 @@ This calculation is an opposite operation of QuantizeLogOp:
} // namespace paddle } // namespace paddle
namespace ops = paddle::operators; namespace ops = paddle::operators;
using CPU = phi::CPUContext;
REGISTER_OPERATOR( REGISTER_OPERATOR(
dequantize_log, dequantize_log,
...@@ -115,4 +114,6 @@ REGISTER_OPERATOR( ...@@ -115,4 +114,6 @@ REGISTER_OPERATOR(
ops::DequantizeLogOpMaker, ops::DequantizeLogOpMaker,
paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>, paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>); paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>);
REGISTER_OP_CPU_KERNEL(dequantize_log, ops::DequantizeLogKernel<CPU, int8_t>);
PD_REGISTER_STRUCT_KERNEL(
dequantize_log, CPU, ALL_LAYOUT, ops::DequantizeLogKernel, int8_t) {}
...@@ -60,5 +60,6 @@ template struct DequantizeFunctor<phi::GPUContext, int8_t>; ...@@ -60,5 +60,6 @@ template struct DequantizeFunctor<phi::GPUContext, int8_t>;
} // namespace paddle } // namespace paddle
namespace ops = paddle::operators; namespace ops = paddle::operators;
using CUDA = phi::GPUContext;
REGISTER_OP_CUDA_KERNEL(dequantize_log, ops::DequantizeLogKernel<CUDA, int8_t>); PD_REGISTER_STRUCT_KERNEL(
dequantize_log, GPU, ALL_LAYOUT, ops::DequantizeLogKernel, int8_t) {}
...@@ -34,7 +34,7 @@ struct DequantizeFunctor { ...@@ -34,7 +34,7 @@ struct DequantizeFunctor {
phi::DenseTensor* out); phi::DenseTensor* out);
}; };
template <typename DeviceContext, typename T> template <typename T, typename DeviceContext>
class DequantizeLogKernel : public framework::OpKernel<T> { class DequantizeLogKernel : public framework::OpKernel<T> {
public: public:
virtual void Compute(const framework::ExecutionContext& ctx) const { virtual void Compute(const framework::ExecutionContext& ctx) const {
......
...@@ -262,6 +262,9 @@ REGISTER_OPERATOR( ...@@ -262,6 +262,9 @@ REGISTER_OPERATOR(
paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>, paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>); paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>);
REGISTER_OP_CPU_KERNEL(density_prior_box, PD_REGISTER_STRUCT_KERNEL(density_prior_box,
ops::DensityPriorBoxOpKernel<float>, CPU,
ops::DensityPriorBoxOpKernel<double>); ALL_LAYOUT,
ops::DensityPriorBoxOpKernel,
float,
double) {}
...@@ -83,7 +83,7 @@ static __global__ void GenDensityPriorBox(const int height, ...@@ -83,7 +83,7 @@ static __global__ void GenDensityPriorBox(const int height,
} }
} }
template <typename T> template <typename T, typename DeviceContext>
class DensityPriorBoxOpCUDAKernel : public framework::OpKernel<T> { class DensityPriorBoxOpCUDAKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& ctx) const override { void Compute(const framework::ExecutionContext& ctx) const override {
...@@ -188,6 +188,10 @@ class DensityPriorBoxOpCUDAKernel : public framework::OpKernel<T> { ...@@ -188,6 +188,10 @@ class DensityPriorBoxOpCUDAKernel : public framework::OpKernel<T> {
} // namespace paddle } // namespace paddle
namespace ops = paddle::operators; namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL(density_prior_box,
ops::DensityPriorBoxOpCUDAKernel<float>, PD_REGISTER_STRUCT_KERNEL(density_prior_box,
ops::DensityPriorBoxOpCUDAKernel<double>); GPU,
ALL_LAYOUT,
ops::DensityPriorBoxOpCUDAKernel,
float,
double) {}
...@@ -18,7 +18,7 @@ limitations under the License. */ ...@@ -18,7 +18,7 @@ limitations under the License. */
namespace paddle { namespace paddle {
namespace operators { namespace operators {
template <typename T> template <typename T, typename DeviceContext>
class DensityPriorBoxOpKernel : public framework::OpKernel<T> { class DensityPriorBoxOpKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& ctx) const override { void Compute(const framework::ExecutionContext& ctx) const override {
......
...@@ -224,7 +224,6 @@ REGISTER_OPERATOR( ...@@ -224,7 +224,6 @@ REGISTER_OPERATOR(
ops::DetectionMAPOpMaker, ops::DetectionMAPOpMaker,
paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>, paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>); paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>);
REGISTER_OP_CPU_KERNEL(
detection_map, PD_REGISTER_STRUCT_KERNEL(
ops::DetectionMAPOpKernel<paddle::platform::CPUPlace, float>, detection_map, CPU, ALL_LAYOUT, ops::DetectionMAPOpKernel, float, double) {}
ops::DetectionMAPOpKernel<paddle::platform::CPUPlace, double>);
...@@ -56,7 +56,7 @@ inline void GetAccumulation(std::vector<std::pair<T, int>> in_pairs, ...@@ -56,7 +56,7 @@ inline void GetAccumulation(std::vector<std::pair<T, int>> in_pairs,
} }
} }
template <typename Place, typename T> template <typename T, typename DeviceContext>
class DetectionMAPOpKernel : public framework::OpKernel<T> { class DetectionMAPOpKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& ctx) const override { void Compute(const framework::ExecutionContext& ctx) const override {
......
...@@ -68,5 +68,5 @@ REGISTER_OP_WITHOUT_GRADIENT(dgc_clip_by_norm, ...@@ -68,5 +68,5 @@ REGISTER_OP_WITHOUT_GRADIENT(dgc_clip_by_norm,
ops::DGCClipByNormOp, ops::DGCClipByNormOp,
ops::DGCClipByNormOpMaker); ops::DGCClipByNormOpMaker);
REGISTER_OP_CPU_KERNEL(dgc_clip_by_norm, PD_REGISTER_STRUCT_KERNEL(
ops::DGCClipByNormKernel<phi::CPUContext, float>); dgc_clip_by_norm, CPU, ALL_LAYOUT, ops::DGCClipByNormKernel, float) {}
...@@ -15,5 +15,5 @@ limitations under the License. */ ...@@ -15,5 +15,5 @@ limitations under the License. */
#include "paddle/fluid/operators/dgc_clip_by_norm_op.h" #include "paddle/fluid/operators/dgc_clip_by_norm_op.h"
namespace ops = paddle::operators; namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL(dgc_clip_by_norm, PD_REGISTER_STRUCT_KERNEL(
ops::DGCClipByNormKernel<phi::GPUContext, float>); dgc_clip_by_norm, GPU, ALL_LAYOUT, ops::DGCClipByNormKernel, float) {}
...@@ -21,7 +21,7 @@ limitations under the License. */ ...@@ -21,7 +21,7 @@ limitations under the License. */
namespace paddle { namespace paddle {
namespace operators { namespace operators {
template <typename DeviceContext, typename T> template <typename T, typename DeviceContext>
class DGCClipByNormKernel : public framework::OpKernel<T> { class DGCClipByNormKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& ctx) const override { void Compute(const framework::ExecutionContext& ctx) const override {
......
...@@ -16,4 +16,4 @@ limitations under the License. */ ...@@ -16,4 +16,4 @@ limitations under the License. */
namespace ops = paddle::operators; namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL(dgc, ops::DGCOpKernel<phi::GPUContext, float>); PD_REGISTER_STRUCT_KERNEL(dgc, GPU, ALL_LAYOUT, ops::DGCOpKernel, float) {}
...@@ -49,7 +49,7 @@ inline float get_period_sparcity(const std::vector<float>& sparsity, ...@@ -49,7 +49,7 @@ inline float get_period_sparcity(const std::vector<float>& sparsity,
return sparsity[idx]; return sparsity[idx];
} }
template <typename DeviceContext, typename T> template <typename T, typename DeviceContext>
class DGCOpKernel : public framework::OpKernel<T> { class DGCOpKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& ctx) const override { void Compute(const framework::ExecutionContext& ctx) const override {
......
...@@ -277,9 +277,12 @@ REGISTER_OPERATOR( ...@@ -277,9 +277,12 @@ REGISTER_OPERATOR(
ops::FakeDequantizeMaxAbsOpMaker, ops::FakeDequantizeMaxAbsOpMaker,
paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>, paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>); paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>);
REGISTER_OP_CPU_KERNEL(fake_dequantize_max_abs, PD_REGISTER_STRUCT_KERNEL(fake_dequantize_max_abs,
ops::FakeDequantizeMaxAbsKernel<CPU, float>, CPU,
ops::FakeDequantizeMaxAbsKernel<CPU, double>); ALL_LAYOUT,
ops::FakeDequantizeMaxAbsKernel,
float,
double) {}
REGISTER_OPERATOR( REGISTER_OPERATOR(
fake_channel_wise_dequantize_max_abs, fake_channel_wise_dequantize_max_abs,
...@@ -287,9 +290,12 @@ REGISTER_OPERATOR( ...@@ -287,9 +290,12 @@ REGISTER_OPERATOR(
ops::FakeChannelWiseDequantizeMaxAbsOpMaker, ops::FakeChannelWiseDequantizeMaxAbsOpMaker,
paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>, paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>); paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>);
REGISTER_OP_CPU_KERNEL(fake_channel_wise_dequantize_max_abs, PD_REGISTER_STRUCT_KERNEL(fake_channel_wise_dequantize_max_abs,
ops::FakeChannelWiseDequantizeMaxAbsKernel<CPU, float>, CPU,
ops::FakeChannelWiseDequantizeMaxAbsKernel<CPU, double>); ALL_LAYOUT,
ops::FakeChannelWiseDequantizeMaxAbsKernel,
float,
double) {}
REGISTER_OP_VERSION(fake_channel_wise_dequantize_max_abs) REGISTER_OP_VERSION(fake_channel_wise_dequantize_max_abs)
.AddCheckpoint( .AddCheckpoint(
......
...@@ -16,14 +16,19 @@ limitations under the License. */ ...@@ -16,14 +16,19 @@ limitations under the License. */
#include "paddle/fluid/operators/fake_dequantize_op.cu.h" #include "paddle/fluid/operators/fake_dequantize_op.cu.h"
namespace ops = paddle::operators; namespace ops = paddle::operators;
using CUDA = phi::GPUContext;
using float16 = paddle::platform::float16; using float16 = paddle::platform::float16;
REGISTER_OP_CUDA_KERNEL(fake_dequantize_max_abs,
ops::FakeDequantizeMaxAbsKernel<CUDA, float>, PD_REGISTER_STRUCT_KERNEL(fake_dequantize_max_abs,
ops::FakeDequantizeMaxAbsKernel<CUDA, double>, GPU,
ops::FakeDequantizeMaxAbsKernel<CUDA, float16>); ALL_LAYOUT,
REGISTER_OP_CUDA_KERNEL( ops::FakeDequantizeMaxAbsKernel,
fake_channel_wise_dequantize_max_abs, float,
ops::FakeChannelWiseDequantizeMaxAbsKernel<CUDA, float>, double,
ops::FakeChannelWiseDequantizeMaxAbsKernel<CUDA, double>, float16) {}
ops::FakeChannelWiseDequantizeMaxAbsKernel<CUDA, float16>); PD_REGISTER_STRUCT_KERNEL(fake_channel_wise_dequantize_max_abs,
GPU,
ALL_LAYOUT,
ops::FakeChannelWiseDequantizeMaxAbsKernel,
float,
double,
float16) {}
...@@ -44,7 +44,7 @@ struct ChannelDequantizeFunctor { ...@@ -44,7 +44,7 @@ struct ChannelDequantizeFunctor {
phi::DenseTensor* out); phi::DenseTensor* out);
}; };
template <typename DeviceContext, typename T> template <typename T, typename DeviceContext>
class FakeDequantizeMaxAbsKernel : public framework::OpKernel<T> { class FakeDequantizeMaxAbsKernel : public framework::OpKernel<T> {
public: public:
virtual void Compute(const framework::ExecutionContext& ctx) const { virtual void Compute(const framework::ExecutionContext& ctx) const {
...@@ -62,7 +62,7 @@ class FakeDequantizeMaxAbsKernel : public framework::OpKernel<T> { ...@@ -62,7 +62,7 @@ class FakeDequantizeMaxAbsKernel : public framework::OpKernel<T> {
} }
}; };
template <typename DeviceContext, typename T> template <typename T, typename DeviceContext>
class FakeChannelWiseDequantizeMaxAbsKernel : public framework::OpKernel<T> { class FakeChannelWiseDequantizeMaxAbsKernel : public framework::OpKernel<T> {
public: public:
virtual void Compute(const framework::ExecutionContext& ctx) const { virtual void Compute(const framework::ExecutionContext& ctx) const {
......
...@@ -878,8 +878,11 @@ REGISTER_OPERATOR( ...@@ -878,8 +878,11 @@ REGISTER_OPERATOR(
ops::FakeQuantOrWithDequantAbsMaxOpMaker, ops::FakeQuantOrWithDequantAbsMaxOpMaker,
paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>, paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>); paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>);
REGISTER_OP_CPU_KERNEL(fake_quantize_abs_max, PD_REGISTER_STRUCT_KERNEL(fake_quantize_abs_max,
ops::FakeQuantizeAbsMaxKernel<CPU, float>); CPU,
ALL_LAYOUT,
ops::FakeQuantizeAbsMaxKernel,
float) {}
REGISTER_OPERATOR( REGISTER_OPERATOR(
fake_quantize_dequantize_abs_max, fake_quantize_dequantize_abs_max,
...@@ -887,8 +890,11 @@ REGISTER_OPERATOR( ...@@ -887,8 +890,11 @@ REGISTER_OPERATOR(
ops::FakeQuantOrWithDequantAbsMaxOpMaker, ops::FakeQuantOrWithDequantAbsMaxOpMaker,
ops::StrightThroughEstimatorMaker<paddle::framework::OpDesc>, ops::StrightThroughEstimatorMaker<paddle::framework::OpDesc>,
ops::StrightThroughEstimatorMaker<paddle::imperative::OpBase>); ops::StrightThroughEstimatorMaker<paddle::imperative::OpBase>);
REGISTER_OP_CPU_KERNEL(fake_quantize_dequantize_abs_max, PD_REGISTER_STRUCT_KERNEL(fake_quantize_dequantize_abs_max,
ops::FakeQuantizeDequantizeAbsMaxKernel<CPU, float>); CPU,
ALL_LAYOUT,
ops::FakeQuantizeDequantizeAbsMaxKernel,
float) {}
REGISTER_OPERATOR( REGISTER_OPERATOR(
fake_quantize_range_abs_max, fake_quantize_range_abs_max,
...@@ -896,8 +902,11 @@ REGISTER_OPERATOR( ...@@ -896,8 +902,11 @@ REGISTER_OPERATOR(
ops::FakeQuantizeRangeAbsMaxOpMaker, ops::FakeQuantizeRangeAbsMaxOpMaker,
paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>, paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>); paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>);
REGISTER_OP_CPU_KERNEL(fake_quantize_range_abs_max, PD_REGISTER_STRUCT_KERNEL(fake_quantize_range_abs_max,
ops::FakeQuantizeRangeAbsMaxKernel<CPU, float>); CPU,
ALL_LAYOUT,
ops::FakeQuantizeRangeAbsMaxKernel,
float) {}
REGISTER_OPERATOR( REGISTER_OPERATOR(
fake_quantize_moving_average_abs_max, fake_quantize_moving_average_abs_max,
...@@ -905,8 +914,11 @@ REGISTER_OPERATOR( ...@@ -905,8 +914,11 @@ REGISTER_OPERATOR(
ops::FakeQuantOrWithDequantMovingAverageAbsMaxOpMaker, ops::FakeQuantOrWithDequantMovingAverageAbsMaxOpMaker,
paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>, paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>); paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>);
REGISTER_OP_CPU_KERNEL(fake_quantize_moving_average_abs_max, PD_REGISTER_STRUCT_KERNEL(fake_quantize_moving_average_abs_max,
ops::FakeQuantizeMovingAverageAbsMaxKernel<CPU, float>); CPU,
ALL_LAYOUT,
ops::FakeQuantizeMovingAverageAbsMaxKernel,
float) {}
REGISTER_OPERATOR( REGISTER_OPERATOR(
fake_quantize_dequantize_moving_average_abs_max, fake_quantize_dequantize_moving_average_abs_max,
...@@ -914,9 +926,11 @@ REGISTER_OPERATOR( ...@@ -914,9 +926,11 @@ REGISTER_OPERATOR(
ops::FakeQuantOrWithDequantMovingAverageAbsMaxOpMaker, ops::FakeQuantOrWithDequantMovingAverageAbsMaxOpMaker,
ops::StrightThroughEstimatorMaker<paddle::framework::OpDesc>, ops::StrightThroughEstimatorMaker<paddle::framework::OpDesc>,
ops::StrightThroughEstimatorMaker<paddle::imperative::OpBase>); ops::StrightThroughEstimatorMaker<paddle::imperative::OpBase>);
REGISTER_OP_CPU_KERNEL( PD_REGISTER_STRUCT_KERNEL(fake_quantize_dequantize_moving_average_abs_max,
fake_quantize_dequantize_moving_average_abs_max, CPU,
ops::FakeQuantizeDequantizeMovingAverageAbsMaxKernel<CPU, float>); ALL_LAYOUT,
ops::FakeQuantizeDequantizeMovingAverageAbsMaxKernel,
float) {}
REGISTER_OPERATOR( REGISTER_OPERATOR(
fake_channel_wise_quantize_abs_max, fake_channel_wise_quantize_abs_max,
...@@ -924,8 +938,11 @@ REGISTER_OPERATOR( ...@@ -924,8 +938,11 @@ REGISTER_OPERATOR(
ops::FakeChannelWiseQuantizeAbsMaxOpMaker, ops::FakeChannelWiseQuantizeAbsMaxOpMaker,
paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>, paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>); paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>);
REGISTER_OP_CPU_KERNEL(fake_channel_wise_quantize_abs_max, PD_REGISTER_STRUCT_KERNEL(fake_channel_wise_quantize_abs_max,
ops::FakeChannelWiseQuantizeAbsMaxKernel<CPU, float>); CPU,
ALL_LAYOUT,
ops::FakeChannelWiseQuantizeAbsMaxKernel,
float) {}
REGISTER_OPERATOR( REGISTER_OPERATOR(
moving_average_abs_max_scale, moving_average_abs_max_scale,
...@@ -933,13 +950,19 @@ REGISTER_OPERATOR( ...@@ -933,13 +950,19 @@ REGISTER_OPERATOR(
ops::MovingAverageAbsMaxScaleOpMaker, ops::MovingAverageAbsMaxScaleOpMaker,
ops::StrightThroughEstimatorMaker<paddle::framework::OpDesc>, ops::StrightThroughEstimatorMaker<paddle::framework::OpDesc>,
ops::StrightThroughEstimatorMaker<paddle::imperative::OpBase>); ops::StrightThroughEstimatorMaker<paddle::imperative::OpBase>);
REGISTER_OP_CPU_KERNEL(moving_average_abs_max_scale, PD_REGISTER_STRUCT_KERNEL(moving_average_abs_max_scale,
ops::MovingAverageAbsMaxScaleKernel<CPU, float>); CPU,
ALL_LAYOUT,
ops::MovingAverageAbsMaxScaleKernel,
float) {}
REGISTER_OPERATOR(stright_throuth_estimator_grad, REGISTER_OPERATOR(stright_throuth_estimator_grad,
ops::StrightThroughEstimatorGradOp); ops::StrightThroughEstimatorGradOp);
REGISTER_OP_CPU_KERNEL(stright_throuth_estimator_grad, PD_REGISTER_STRUCT_KERNEL(stright_throuth_estimator_grad,
ops::StrightThroughEstimatorGradKernel<CPU, float>); CPU,
ALL_LAYOUT,
ops::StrightThroughEstimatorGradKernel,
float) {}
REGISTER_OPERATOR( REGISTER_OPERATOR(
fake_channel_wise_quantize_dequantize_abs_max, fake_channel_wise_quantize_dequantize_abs_max,
...@@ -947,9 +970,11 @@ REGISTER_OPERATOR( ...@@ -947,9 +970,11 @@ REGISTER_OPERATOR(
ops::FakeChannelWiseQuantizeDequantizeAbsMaxOpMaker, ops::FakeChannelWiseQuantizeDequantizeAbsMaxOpMaker,
ops::StrightThroughEstimatorMaker<paddle::framework::OpDesc>, ops::StrightThroughEstimatorMaker<paddle::framework::OpDesc>,
ops::StrightThroughEstimatorMaker<paddle::imperative::OpBase>); ops::StrightThroughEstimatorMaker<paddle::imperative::OpBase>);
REGISTER_OP_CPU_KERNEL( PD_REGISTER_STRUCT_KERNEL(fake_channel_wise_quantize_dequantize_abs_max,
fake_channel_wise_quantize_dequantize_abs_max, CPU,
ops::FakeChannelWiseQuantizeDequantizeAbsMaxKernel<CPU, float>); ALL_LAYOUT,
ops::FakeChannelWiseQuantizeDequantizeAbsMaxKernel,
float) {}
REGISTER_OP_VERSION(fake_channel_wise_quantize_abs_max) REGISTER_OP_VERSION(fake_channel_wise_quantize_abs_max)
.AddCheckpoint( .AddCheckpoint(
......
...@@ -16,35 +16,58 @@ limitations under the License. */ ...@@ -16,35 +16,58 @@ limitations under the License. */
#include "paddle/fluid/operators/fake_quantize_op.cu.h" #include "paddle/fluid/operators/fake_quantize_op.cu.h"
namespace ops = paddle::operators; namespace ops = paddle::operators;
using CUDA = phi::GPUContext;
using float16 = paddle::platform::float16; using float16 = paddle::platform::float16;
REGISTER_OP_CUDA_KERNEL(fake_quantize_abs_max,
ops::FakeQuantizeAbsMaxKernel<CUDA, float>, PD_REGISTER_STRUCT_KERNEL(fake_quantize_abs_max,
ops::FakeQuantizeAbsMaxKernel<CUDA, float16>); GPU,
REGISTER_OP_CUDA_KERNEL(fake_quantize_dequantize_abs_max, ALL_LAYOUT,
ops::FakeQuantizeDequantizeAbsMaxKernel<CUDA, float>, ops::FakeQuantizeAbsMaxKernel,
ops::FakeQuantizeDequantizeAbsMaxKernel<CUDA, float16>); float,
REGISTER_OP_CUDA_KERNEL( float16) {}
fake_channel_wise_quantize_abs_max, PD_REGISTER_STRUCT_KERNEL(fake_quantize_dequantize_abs_max,
ops::FakeChannelWiseQuantizeAbsMaxKernel<CUDA, float>, GPU,
ops::FakeChannelWiseQuantizeAbsMaxKernel<CUDA, float16>); ALL_LAYOUT,
REGISTER_OP_CUDA_KERNEL(fake_quantize_range_abs_max, ops::FakeQuantizeDequantizeAbsMaxKernel,
ops::FakeQuantizeRangeAbsMaxKernel<CUDA, float>, float,
ops::FakeQuantizeRangeAbsMaxKernel<CUDA, float16>); float16) {}
REGISTER_OP_CUDA_KERNEL( PD_REGISTER_STRUCT_KERNEL(fake_channel_wise_quantize_abs_max,
fake_quantize_moving_average_abs_max, GPU,
ops::FakeQuantizeMovingAverageAbsMaxKernel<CUDA, float>, ALL_LAYOUT,
ops::FakeQuantizeMovingAverageAbsMaxKernel<CUDA, float16>); ops::FakeChannelWiseQuantizeAbsMaxKernel,
REGISTER_OP_CUDA_KERNEL(moving_average_abs_max_scale, float,
ops::MovingAverageAbsMaxScaleKernel<CUDA, float>, float16) {}
ops::MovingAverageAbsMaxScaleKernel<CUDA, float16>); PD_REGISTER_STRUCT_KERNEL(fake_quantize_range_abs_max,
REGISTER_OP_CUDA_KERNEL( GPU,
fake_quantize_dequantize_moving_average_abs_max, ALL_LAYOUT,
ops::FakeQuantizeDequantizeMovingAverageAbsMaxKernel<CUDA, float>, ops::FakeQuantizeRangeAbsMaxKernel,
ops::FakeQuantizeDequantizeMovingAverageAbsMaxKernel<CUDA, float16>); float,
REGISTER_OP_CUDA_KERNEL(stright_throuth_estimator_grad, float16) {}
ops::StrightThroughEstimatorGradKernel<CUDA, float>, PD_REGISTER_STRUCT_KERNEL(fake_quantize_moving_average_abs_max,
ops::StrightThroughEstimatorGradKernel<CUDA, float16>); GPU,
REGISTER_OP_CUDA_KERNEL( ALL_LAYOUT,
fake_channel_wise_quantize_dequantize_abs_max, ops::FakeQuantizeMovingAverageAbsMaxKernel,
ops::FakeChannelWiseQuantizeDequantizeAbsMaxKernel<CUDA, float>); float,
float16) {}
PD_REGISTER_STRUCT_KERNEL(moving_average_abs_max_scale,
GPU,
ALL_LAYOUT,
ops::MovingAverageAbsMaxScaleKernel,
float,
float16) {}
PD_REGISTER_STRUCT_KERNEL(fake_quantize_dequantize_moving_average_abs_max,
GPU,
ALL_LAYOUT,
ops::FakeQuantizeDequantizeMovingAverageAbsMaxKernel,
float,
float16) {}
PD_REGISTER_STRUCT_KERNEL(stright_throuth_estimator_grad,
GPU,
ALL_LAYOUT,
ops::StrightThroughEstimatorGradKernel,
float,
float16) {}
PD_REGISTER_STRUCT_KERNEL(fake_channel_wise_quantize_dequantize_abs_max,
GPU,
ALL_LAYOUT,
ops::FakeChannelWiseQuantizeDequantizeAbsMaxKernel,
float) {}
...@@ -176,7 +176,7 @@ class FakeAbsMaxKernelBase : public framework::OpKernel<T> { ...@@ -176,7 +176,7 @@ class FakeAbsMaxKernelBase : public framework::OpKernel<T> {
phi::DenseTensor *out) const = 0; phi::DenseTensor *out) const = 0;
}; };
template <typename DeviceContext, typename T> template <typename T, typename DeviceContext>
class FakeQuantizeAbsMaxKernel : public FakeAbsMaxKernelBase<DeviceContext, T> { class FakeQuantizeAbsMaxKernel : public FakeAbsMaxKernelBase<DeviceContext, T> {
protected: protected:
void RunClipFunctor(const DeviceContext &dev_ctx, void RunClipFunctor(const DeviceContext &dev_ctx,
...@@ -190,7 +190,7 @@ class FakeQuantizeAbsMaxKernel : public FakeAbsMaxKernelBase<DeviceContext, T> { ...@@ -190,7 +190,7 @@ class FakeQuantizeAbsMaxKernel : public FakeAbsMaxKernelBase<DeviceContext, T> {
} }
}; };
template <typename DeviceContext, typename T> template <typename T, typename DeviceContext>
class FakeQuantizeDequantizeAbsMaxKernel class FakeQuantizeDequantizeAbsMaxKernel
: public FakeAbsMaxKernelBase<DeviceContext, T> { : public FakeAbsMaxKernelBase<DeviceContext, T> {
protected: protected:
...@@ -205,7 +205,7 @@ class FakeQuantizeDequantizeAbsMaxKernel ...@@ -205,7 +205,7 @@ class FakeQuantizeDequantizeAbsMaxKernel
} }
}; };
template <typename DeviceContext, typename T> template <typename T, typename DeviceContext>
class FakeChannelWiseQuantizeAbsMaxKernel : public framework::OpKernel<T> { class FakeChannelWiseQuantizeAbsMaxKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext &context) const override { void Compute(const framework::ExecutionContext &context) const override {
...@@ -232,7 +232,7 @@ class FakeChannelWiseQuantizeAbsMaxKernel : public framework::OpKernel<T> { ...@@ -232,7 +232,7 @@ class FakeChannelWiseQuantizeAbsMaxKernel : public framework::OpKernel<T> {
} }
}; };
template <typename DeviceContext, typename T> template <typename T, typename DeviceContext>
class FakeChannelWiseQuantizeDequantizeAbsMaxKernel class FakeChannelWiseQuantizeDequantizeAbsMaxKernel
: public framework::OpKernel<T> { : public framework::OpKernel<T> {
public: public:
...@@ -257,7 +257,7 @@ class FakeChannelWiseQuantizeDequantizeAbsMaxKernel ...@@ -257,7 +257,7 @@ class FakeChannelWiseQuantizeDequantizeAbsMaxKernel
} }
}; };
template <typename DeviceContext, typename T> template <typename T, typename DeviceContext>
class FakeQuantizeRangeAbsMaxKernel : public framework::OpKernel<T> { class FakeQuantizeRangeAbsMaxKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext &context) const override { void Compute(const framework::ExecutionContext &context) const override {
...@@ -304,7 +304,7 @@ class FakeQuantizeRangeAbsMaxKernel : public framework::OpKernel<T> { ...@@ -304,7 +304,7 @@ class FakeQuantizeRangeAbsMaxKernel : public framework::OpKernel<T> {
} }
}; };
template <typename DeviceContext, typename T> template <typename T, typename DeviceContext>
class FakeMovingAverageAbsMaxKernelBase : public framework::OpKernel<T> { class FakeMovingAverageAbsMaxKernelBase : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext &context) const override { void Compute(const framework::ExecutionContext &context) const override {
...@@ -367,9 +367,9 @@ class FakeMovingAverageAbsMaxKernelBase : public framework::OpKernel<T> { ...@@ -367,9 +367,9 @@ class FakeMovingAverageAbsMaxKernelBase : public framework::OpKernel<T> {
phi::DenseTensor *out) const = 0; phi::DenseTensor *out) const = 0;
}; };
template <typename DeviceContext, typename T> template <typename T, typename DeviceContext>
class FakeQuantizeMovingAverageAbsMaxKernel class FakeQuantizeMovingAverageAbsMaxKernel
: public FakeMovingAverageAbsMaxKernelBase<DeviceContext, T> { : public FakeMovingAverageAbsMaxKernelBase<T, DeviceContext> {
protected: protected:
void RunClipFunctor(const DeviceContext &dev_ctx, void RunClipFunctor(const DeviceContext &dev_ctx,
const phi::DenseTensor &in, const phi::DenseTensor &in,
...@@ -382,9 +382,9 @@ class FakeQuantizeMovingAverageAbsMaxKernel ...@@ -382,9 +382,9 @@ class FakeQuantizeMovingAverageAbsMaxKernel
} }
}; };
template <typename DeviceContext, typename T> template <typename T, typename DeviceContext>
class FakeQuantizeDequantizeMovingAverageAbsMaxKernel class FakeQuantizeDequantizeMovingAverageAbsMaxKernel
: public FakeMovingAverageAbsMaxKernelBase<DeviceContext, T> { : public FakeMovingAverageAbsMaxKernelBase<T, DeviceContext> {
protected: protected:
void RunClipFunctor(const DeviceContext &dev_ctx, void RunClipFunctor(const DeviceContext &dev_ctx,
const phi::DenseTensor &in, const phi::DenseTensor &in,
...@@ -397,7 +397,7 @@ class FakeQuantizeDequantizeMovingAverageAbsMaxKernel ...@@ -397,7 +397,7 @@ class FakeQuantizeDequantizeMovingAverageAbsMaxKernel
} }
}; };
template <typename DeviceContext, typename T> template <typename T, typename DeviceContext>
class MovingAverageAbsMaxScaleKernel : public framework::OpKernel<T> { class MovingAverageAbsMaxScaleKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext &context) const override { void Compute(const framework::ExecutionContext &context) const override {
...@@ -445,7 +445,7 @@ class MovingAverageAbsMaxScaleKernel : public framework::OpKernel<T> { ...@@ -445,7 +445,7 @@ class MovingAverageAbsMaxScaleKernel : public framework::OpKernel<T> {
} }
}; };
template <typename DeviceContext, typename T> template <typename T, typename DeviceContext>
class StrightThroughEstimatorGradKernel : public framework::OpKernel<T> { class StrightThroughEstimatorGradKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext &context) const override { void Compute(const framework::ExecutionContext &context) const override {
......
...@@ -130,5 +130,6 @@ namespace ops = paddle::operators; ...@@ -130,5 +130,6 @@ namespace ops = paddle::operators;
REGISTER_OP_WITHOUT_GRADIENT(decayed_adagrad, REGISTER_OP_WITHOUT_GRADIENT(decayed_adagrad,
ops::DecayedAdagradOp, ops::DecayedAdagradOp,
ops::DecayedAdagradOpMaker); ops::DecayedAdagradOpMaker);
REGISTER_OP_CPU_KERNEL(decayed_adagrad,
ops::DecayedAdagradOpKernel<phi::CPUContext, float>); PD_REGISTER_STRUCT_KERNEL(
decayed_adagrad, CPU, ALL_LAYOUT, ops::DecayedAdagradOpKernel, float) {}
...@@ -14,5 +14,6 @@ limitations under the License. */ ...@@ -14,5 +14,6 @@ limitations under the License. */
#include "paddle/fluid/operators/optimizers/decayed_adagrad_op.h" #include "paddle/fluid/operators/optimizers/decayed_adagrad_op.h"
namespace ops = paddle::operators; namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL(decayed_adagrad,
ops::DecayedAdagradOpKernel<phi::GPUContext, float>); PD_REGISTER_STRUCT_KERNEL(
decayed_adagrad, GPU, ALL_LAYOUT, ops::DecayedAdagradOpKernel, float) {}
...@@ -19,7 +19,7 @@ limitations under the License. */ ...@@ -19,7 +19,7 @@ limitations under the License. */
namespace paddle { namespace paddle {
namespace operators { namespace operators {
template <typename DeviceContext, typename T> template <typename T, typename DeviceContext>
class DecayedAdagradOpKernel : public framework::OpKernel<T> { class DecayedAdagradOpKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& ctx) const override { void Compute(const framework::ExecutionContext& ctx) const override {
......
...@@ -76,5 +76,5 @@ REGISTER_OP_WITHOUT_GRADIENT(dgc_momentum, ...@@ -76,5 +76,5 @@ REGISTER_OP_WITHOUT_GRADIENT(dgc_momentum,
ops::DGCMomentumOp, ops::DGCMomentumOp,
ops::DGCMomentumOpMaker); ops::DGCMomentumOpMaker);
REGISTER_OP_CPU_KERNEL(dgc_momentum, PD_REGISTER_STRUCT_KERNEL(
ops::DGCMomentumKernel<phi::CPUContext, float>); dgc_momentum, CPU, ALL_LAYOUT, ops::DGCMomentumKernel, float) {}
...@@ -15,5 +15,6 @@ ...@@ -15,5 +15,6 @@
#include "paddle/fluid/operators/optimizers/dgc_momentum_op.h" #include "paddle/fluid/operators/optimizers/dgc_momentum_op.h"
namespace ops = paddle::operators; namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL(dgc_momentum,
ops::DGCMomentumKernel<phi::GPUContext, float>); PD_REGISTER_STRUCT_KERNEL(
dgc_momentum, GPU, ALL_LAYOUT, ops::DGCMomentumKernel, float) {}
...@@ -23,7 +23,7 @@ ...@@ -23,7 +23,7 @@
namespace paddle { namespace paddle {
namespace operators { namespace operators {
template <typename DeviceContext, typename T> template <typename T, typename DeviceContext>
class DGCMomentumKernel : public framework::OpKernel<T> { class DGCMomentumKernel : public framework::OpKernel<T> {
public: public:
DGCMomentumKernel() {} DGCMomentumKernel() {}
......
...@@ -118,6 +118,8 @@ REGISTER_OP_WITHOUT_GRADIENT(distributed_fused_lamb_init, ...@@ -118,6 +118,8 @@ REGISTER_OP_WITHOUT_GRADIENT(distributed_fused_lamb_init,
ops::DistributedFusedLambInitOp, ops::DistributedFusedLambInitOp,
ops::DistributedFusedLambInitOpMaker); ops::DistributedFusedLambInitOpMaker);
REGISTER_OP_CPU_KERNEL( PD_REGISTER_STRUCT_KERNEL(distributed_fused_lamb_init,
distributed_fused_lamb_init, CPU,
ops::DistributedFusedLambInitOpKernel<phi::CPUContext, float>); ALL_LAYOUT,
ops::DistributedFusedLambInitOpKernel,
float) {}
...@@ -340,7 +340,7 @@ static T ClipByBound(T x, T low_value, T high_value) { ...@@ -340,7 +340,7 @@ static T ClipByBound(T x, T low_value, T high_value) {
} }
template <typename T> template <typename T>
class DistributedFusedLambInitOpKernel<phi::GPUContext, T> class DistributedFusedLambInitOpKernel<T, phi::GPUContext>
: public framework::OpKernel<T> { : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext &ctx) const override { void Compute(const framework::ExecutionContext &ctx) const override {
...@@ -790,6 +790,8 @@ class DistributedFusedLambInitOpKernel<phi::GPUContext, T> ...@@ -790,6 +790,8 @@ class DistributedFusedLambInitOpKernel<phi::GPUContext, T>
namespace ops = paddle::operators; namespace ops = paddle::operators;
namespace plat = paddle::platform; namespace plat = paddle::platform;
REGISTER_OP_CUDA_KERNEL( PD_REGISTER_STRUCT_KERNEL(distributed_fused_lamb_init,
distributed_fused_lamb_init, GPU,
ops::DistributedFusedLambInitOpKernel<phi::GPUContext, float>); ALL_LAYOUT,
ops::DistributedFusedLambInitOpKernel,
float) {}
...@@ -20,7 +20,7 @@ ...@@ -20,7 +20,7 @@
namespace paddle { namespace paddle {
namespace operators { namespace operators {
template <typename DevCtx, typename T> template <typename T, typename DevCtx>
class DistributedFusedLambInitOpKernel : public framework::OpKernel<T> { class DistributedFusedLambInitOpKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext &ctx) const override { void Compute(const framework::ExecutionContext &ctx) const override {
......
...@@ -170,6 +170,8 @@ REGISTER_OP_WITHOUT_GRADIENT(distributed_fused_lamb, ...@@ -170,6 +170,8 @@ REGISTER_OP_WITHOUT_GRADIENT(distributed_fused_lamb,
ops::DistributedFusedLambOp, ops::DistributedFusedLambOp,
ops::DistributedFusedLambOpMaker); ops::DistributedFusedLambOpMaker);
REGISTER_OP_CPU_KERNEL( PD_REGISTER_STRUCT_KERNEL(distributed_fused_lamb,
distributed_fused_lamb, CPU,
ops::DistributedFusedLambOpKernel<phi::CPUContext, float>); ALL_LAYOUT,
ops::DistributedFusedLambOpKernel,
float) {}
...@@ -1330,7 +1330,7 @@ static void LaunchElementwiseAddWithCastKernel(const phi::GPUContext &dev_ctx, ...@@ -1330,7 +1330,7 @@ static void LaunchElementwiseAddWithCastKernel(const phi::GPUContext &dev_ctx,
} }
template <typename T> template <typename T>
class DistributedFusedLambOpKernel<phi::GPUContext, T> class DistributedFusedLambOpKernel<T, phi::GPUContext>
: public framework::OpKernel<T> { : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext &ctx) const override { void Compute(const framework::ExecutionContext &ctx) const override {
...@@ -2300,6 +2300,8 @@ class DistributedFusedLambOpKernel<phi::GPUContext, T> ...@@ -2300,6 +2300,8 @@ class DistributedFusedLambOpKernel<phi::GPUContext, T>
namespace plat = paddle::platform; namespace plat = paddle::platform;
namespace ops = paddle::operators; namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL( PD_REGISTER_STRUCT_KERNEL(distributed_fused_lamb,
distributed_fused_lamb, GPU,
ops::DistributedFusedLambOpKernel<phi::GPUContext, float>); ALL_LAYOUT,
ops::DistributedFusedLambOpKernel,
float) {}
...@@ -21,7 +21,7 @@ ...@@ -21,7 +21,7 @@
namespace paddle { namespace paddle {
namespace operators { namespace operators {
template <typename DevCtx, typename T> template <typename T, typename DevCtx>
class DistributedFusedLambOpKernel : public framework::OpKernel<T> { class DistributedFusedLambOpKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext &ctx) const override { void Compute(const framework::ExecutionContext &ctx) const override {
......
...@@ -131,6 +131,6 @@ CCS16 - Deep Learning with Differential Privacy. ...@@ -131,6 +131,6 @@ CCS16 - Deep Learning with Differential Privacy.
namespace ops = paddle::operators; namespace ops = paddle::operators;
REGISTER_OP_WITHOUT_GRADIENT(dpsgd, ops::DpsgdOp, ops::DpsgdOpMaker); REGISTER_OP_WITHOUT_GRADIENT(dpsgd, ops::DpsgdOp, ops::DpsgdOpMaker);
REGISTER_OP_CPU_KERNEL(dpsgd,
ops::DpsgdOpKernel<phi::CPUContext, float>, PD_REGISTER_STRUCT_KERNEL(
ops::DpsgdOpKernel<phi::CPUContext, double>); dpsgd, CPU, ALL_LAYOUT, ops::DpsgdOpKernel, float, double) {}
...@@ -24,7 +24,7 @@ limitations under the License. */ ...@@ -24,7 +24,7 @@ limitations under the License. */
namespace paddle { namespace paddle {
namespace operators { namespace operators {
template <typename DeviceContext, typename T> template <typename T, typename DeviceContext>
class DpsgdOpKernel : public framework::OpKernel<T> { class DpsgdOpKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext &ctx) const override { void Compute(const framework::ExecutionContext &ctx) const override {
......
...@@ -151,6 +151,8 @@ REGISTER_OPERATOR(distributed_lookup_table, ...@@ -151,6 +151,8 @@ REGISTER_OPERATOR(distributed_lookup_table,
ops::DistributedLookupTableOp, ops::DistributedLookupTableOp,
ops::DistributedLookupTableOpMaker); ops::DistributedLookupTableOpMaker);
REGISTER_OP_CPU_KERNEL( PD_REGISTER_STRUCT_KERNEL(distributed_lookup_table,
distributed_lookup_table, CPU,
ops::DistributedLookupTableKernel<phi::CPUContext, float>); ALL_LAYOUT,
ops::DistributedLookupTableKernel,
float) {}
...@@ -17,6 +17,8 @@ ...@@ -17,6 +17,8 @@
namespace ops = paddle::operators; namespace ops = paddle::operators;
namespace plat = paddle::platform; namespace plat = paddle::platform;
REGISTER_OP_CUDA_KERNEL( PD_REGISTER_STRUCT_KERNEL(distributed_lookup_table,
distributed_lookup_table, GPU,
ops::DistributedLookupTableKernel<phi::GPUContext, float>); ALL_LAYOUT,
ops::DistributedLookupTableKernel,
float) {}
...@@ -23,7 +23,7 @@ ...@@ -23,7 +23,7 @@
namespace paddle { namespace paddle {
namespace operators { namespace operators {
template <typename DeviceContext, typename T> template <typename T, typename DeviceContext>
class DistributedLookupTableKernel : public framework::OpKernel<T> { class DistributedLookupTableKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext &context) const override { void Compute(const framework::ExecutionContext &context) const override {
......
...@@ -134,7 +134,9 @@ REGISTER_OPERATOR(distributed_push_sparse, ...@@ -134,7 +134,9 @@ REGISTER_OPERATOR(distributed_push_sparse,
ops::DistributedPushSparseOp, ops::DistributedPushSparseOp,
ops::DistributedPushSparseOpMaker); ops::DistributedPushSparseOpMaker);
REGISTER_OP_CPU_KERNEL( PD_REGISTER_STRUCT_KERNEL(distributed_push_sparse,
distributed_push_sparse, CPU,
ops::DistributedPushSparseKernel<phi::CPUContext, float>, ALL_LAYOUT,
ops::DistributedPushSparseKernel<phi::CPUContext, double>); ops::DistributedPushSparseKernel,
float,
double) {}
...@@ -17,7 +17,9 @@ ...@@ -17,7 +17,9 @@
namespace ops = paddle::operators; namespace ops = paddle::operators;
namespace plat = paddle::platform; namespace plat = paddle::platform;
REGISTER_OP_CUDA_KERNEL( PD_REGISTER_STRUCT_KERNEL(distributed_push_sparse,
distributed_push_sparse, GPU,
ops::DistributedPushSparseKernel<phi::GPUContext, float>, ALL_LAYOUT,
ops::DistributedPushSparseKernel<phi::GPUContext, double>); ops::DistributedPushSparseKernel,
float,
double) {}
...@@ -23,7 +23,7 @@ ...@@ -23,7 +23,7 @@
namespace paddle { namespace paddle {
namespace operators { namespace operators {
template <typename DeviceContext, typename T> template <typename T, typename DeviceContext>
class DistributedPushSparseKernel : public framework::OpKernel<T> { class DistributedPushSparseKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext &context) const override { void Compute(const framework::ExecutionContext &context) const override {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册