Unverified commit 09da1c4c, authored by YuanRisheng, committed by GitHub

unify kernel (#52594)

Parent: 5662adcc
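
This commit switches the operators below (push_dense, pyramid_hash, quantize_linear / dequantize_linear, random_crop, random_routing, rank_attention, read_file) from the per-device REGISTER_OP_CPU_KERNEL / REGISTER_OP_CUDA_KERNEL macros to the unified PD_REGISTER_STRUCT_KERNEL macro, and reorders the kernel template parameters from <DeviceContext, T> to <T, DeviceContext>. A minimal sketch of the before/after pattern, using a hypothetical example_op and ExampleKernel that are not part of this commit (it assumes the Paddle fluid operator headers and the ops namespace alias used in the files below):

// Before: one registration macro per device, one explicit instantiation per data type.
// template <typename DeviceContext, typename T>
// class ExampleKernel : public framework::OpKernel<T> { ... };
// REGISTER_OP_CPU_KERNEL(example_op,
//                        ops::ExampleKernel<phi::CPUContext, float>,
//                        ops::ExampleKernel<phi::CPUContext, double>);

// After: a single macro form for every backend; the device (CPU/GPU), layout,
// kernel class, and data types are listed once, and the data type becomes the
// first template parameter of the kernel class.
template <typename T, typename DeviceContext>
class ExampleKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {}
};

PD_REGISTER_STRUCT_KERNEL(
    example_op, CPU, ALL_LAYOUT, ops::ExampleKernel, float, double) {}
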
@@ -70,4 +70,5 @@ REGISTER_OPERATOR(
     paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
     paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>,
     ops::PushDenseNoNeedBufferVarsInferer);
-REGISTER_OP_CPU_KERNEL(push_dense, ops::PushDenseCPUKernel<float>)
+PD_REGISTER_STRUCT_KERNEL(
+    push_dense, CPU, ALL_LAYOUT, ops::PushDenseCPUKernel, float) {}

@@ -60,7 +60,7 @@ void PushDenseFunctor(const framework::ExecutionContext& ctx) {
 #endif
 }
 
-template <typename T>
+template <typename T, typename DeviceContext>
 class PushDenseCPUKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext& ctx) const override {
...

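The header-side change above pairs with the new registration in the .cc hunk: PD_REGISTER_STRUCT_KERNEL names the kernel class once and lists its data types, and, as the matched changes throughout this commit suggest, instantiates the class with the data type first and the device context second. That is why every struct kernel is now templated as <typename T, typename DeviceContext>, even when the device context is unused in the body. A simplified sketch of how the push_dense pieces fit together after the change (condensed from the surrounding hunks; illustrative, not the full file):

// Assumes the Paddle fluid operator headers included by push_dense_op.h.
template <typename T, typename DeviceContext>  // DeviceContext is new and unused here
class PushDenseCPUKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    // Kernel body unchanged by this commit: forward to the functor defined above.
    PushDenseFunctor<T>(ctx);
  }
};

// One instantiation is generated per listed data type, e.g.
// ops::PushDenseCPUKernel<float, phi::CPUContext> for the CPU/float entry.
PD_REGISTER_STRUCT_KERNEL(
    push_dense, CPU, ALL_LAYOUT, ops::PushDenseCPUKernel, float) {}
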
@@ -232,7 +232,7 @@ class PyramidHashOP : public framework::OperatorWithKernel {
   }
 };
 
-template <typename DeviceContext, typename T>
+template <typename T, typename DeviceContext>
 class CPUPyramidHashOPKernel : public framework::OpKernel<T> {
  public:
   bool should_use_term(math::bloomfilter* _filter,
@@ -492,7 +492,7 @@ class PyramidHashGradOpMaker : public framework::SingleGradOpMaker<T> {
   }
 };
 
-template <typename DeviceContext, typename T>
+template <typename T, typename DeviceContext>
 class CPUPyramidHashOPGradKernel : public framework::OpKernel<T> {
  public:
   void hash_embedding_bp(const T* hash_id,
@@ -584,8 +584,11 @@ REGISTER_OPERATOR(pyramid_hash,
                   ops::PyramidHashGradOpMaker<paddle::imperative::OpBase>);
 REGISTER_OPERATOR(pyramid_hash_grad, ops::PyramidHashOpGrad);
 
-REGISTER_OP_CPU_KERNEL(pyramid_hash,
-                       ops::CPUPyramidHashOPKernel<phi::CPUContext, float>,
-                       ops::CPUPyramidHashOPKernel<phi::CPUContext, int8_t>);
-REGISTER_OP_CPU_KERNEL(pyramid_hash_grad,
-                       ops::CPUPyramidHashOPGradKernel<phi::CPUContext, float>);
+PD_REGISTER_STRUCT_KERNEL(
+    pyramid_hash, CPU, ALL_LAYOUT, ops::CPUPyramidHashOPKernel, float, int8_t) {
+}
+PD_REGISTER_STRUCT_KERNEL(pyramid_hash_grad,
+                          CPU,
+                          ALL_LAYOUT,
+                          ops::CPUPyramidHashOPGradKernel,
+                          float) {}

@@ -222,7 +222,6 @@ $$0 \leq c \lt \ the\ channel\ number\ of\ X$$
 }  // namespace paddle
 
 namespace ops = paddle::operators;
-using CPU = phi::CPUContext;
 
 REGISTER_OPERATOR(
     quantize_linear,
@@ -231,7 +230,8 @@ REGISTER_OPERATOR(
     paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
     paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>);
 
-REGISTER_OP_CPU_KERNEL(quantize_linear, ops::QuantizeLinearKernel<CPU, float>);
+PD_REGISTER_STRUCT_KERNEL(
+    quantize_linear, CPU, ALL_LAYOUT, ops::QuantizeLinearKernel, float) {}
 
 REGISTER_OPERATOR(
     dequantize_linear,
@@ -240,7 +240,10 @@ REGISTER_OPERATOR(
     paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
     paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>);
 
-REGISTER_OP_CPU_KERNEL(dequantize_linear,
-                       ops::DeQuantizeLinearKernel<CPU, float>,
-                       ops::DeQuantizeLinearKernel<CPU, int8_t>,
-                       ops::DeQuantizeLinearKernel<CPU, double>);
+PD_REGISTER_STRUCT_KERNEL(dequantize_linear,
+                          CPU,
+                          ALL_LAYOUT,
+                          ops::DeQuantizeLinearKernel,
+                          float,
+                          int8_t,
+                          double) {}

@@ -123,12 +123,18 @@ template struct ChannelDequantizeFunctorV2<phi::GPUContext, double>;
 namespace ops = paddle::operators;
 using CUDA = phi::GPUContext;
 
-REGISTER_OP_CUDA_KERNEL(dequantize_linear,
-                        ops::DeQuantizeLinearKernel<CUDA, float>,
-                        ops::DeQuantizeLinearKernel<CUDA, float16>,
-                        ops::DeQuantizeLinearKernel<CUDA, int8_t>,
-                        ops::DeQuantizeLinearKernel<CUDA, double>);
+PD_REGISTER_STRUCT_KERNEL(dequantize_linear,
+                          GPU,
+                          ALL_LAYOUT,
+                          ops::DeQuantizeLinearKernel,
+                          float,
+                          float16,
+                          int8_t,
+                          double) {}
 
-REGISTER_OP_CUDA_KERNEL(quantize_linear,
-                        ops::QuantizeLinearKernel<CUDA, float>,
-                        ops::QuantizeLinearKernel<CUDA, float16>);
+PD_REGISTER_STRUCT_KERNEL(quantize_linear,
+                          GPU,
+                          ALL_LAYOUT,
+                          ops::QuantizeLinearKernel,
+                          float,
+                          float16) {}

@@ -47,7 +47,7 @@ struct ChannelDequantizeFunctorV2 {
                   phi::DenseTensor* out);
 };
 
-template <typename DeviceContext, typename T>
+template <typename T, typename DeviceContext>
 class QuantizeLinearKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext& context) const override {
@@ -130,7 +130,7 @@ class QuantizeLinearKernel : public framework::OpKernel<T> {
   }
 };
 
-template <typename DeviceContext, typename T>
+template <typename T, typename DeviceContext>
 class DeQuantizeLinearKernel : public framework::OpKernel<T> {
  public:
   template <typename D>
...

@@ -96,11 +96,12 @@ REGISTER_OPERATOR(
     paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
     paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>);
 
-template <typename T>
-using Kernel = ops::RandomCropKernel<phi::CPUContext, T>;
-REGISTER_OP_CPU_KERNEL(random_crop,
-                       Kernel<float>,
-                       Kernel<int>,
-                       Kernel<double>,
-                       Kernel<uint8_t>,
-                       Kernel<int16_t>);
+PD_REGISTER_STRUCT_KERNEL(random_crop,
+                          CPU,
+                          ALL_LAYOUT,
+                          ops::RandomCropKernel,
+                          float,
+                          int,
+                          double,
+                          uint8_t,
+                          int16_t) {}

@@ -15,11 +15,13 @@
 #include "paddle/fluid/operators/random_crop_op.h"
 
 namespace ops = paddle::operators;
-template <typename T>
-using Kernel = ops::RandomCropKernel<phi::GPUContext, T>;
-REGISTER_OP_CUDA_KERNEL(random_crop,
-                        Kernel<float>,
-                        Kernel<int>,
-                        Kernel<double>,
-                        Kernel<uint8_t>,
-                        Kernel<int16_t>);
+
+PD_REGISTER_STRUCT_KERNEL(random_crop,
+                          GPU,
+                          ALL_LAYOUT,
+                          ops::RandomCropKernel,
+                          float,
+                          int,
+                          double,
+                          uint8_t,
+                          int16_t) {}

@@ -176,7 +176,7 @@ struct RandomCropFunctor {
   }
 };
 
-template <typename DeviceContext, typename T>
+template <typename T, typename DeviceContext>
 class RandomCropKernel : public framework::OpKernel<T> {
  public:
   virtual void Compute(const framework::ExecutionContext& ctx) const {
...

@@ -98,6 +98,9 @@ REGISTER_OPERATOR(
     paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>,
     ops::RandomRoutingInplaceInferer)
 
-REGISTER_OP_CPU_KERNEL(random_routing,
-                       ops::RandomRoutingOpCPUKernel<float>,
-                       ops::RandomRoutingOpCPUKernel<double>);
+PD_REGISTER_STRUCT_KERNEL(random_routing,
+                          CPU,
+                          ALL_LAYOUT,
+                          ops::RandomRoutingOpCPUKernel,
+                          float,
+                          double) {}

@@ -47,7 +47,7 @@ __global__ void random_routing_kernel(int64_t* data,
   }
 }
 
-template <typename T>
+template <typename T, typename DeviceContext>
 class RandomRoutingOpCUDAKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext& context) const override {
@@ -82,7 +82,10 @@ class RandomRoutingOpCUDAKernel : public framework::OpKernel<T> {
 namespace ops = paddle::operators;
 namespace plat = paddle::platform;
 
-REGISTER_OP_CUDA_KERNEL(random_routing,
-                        ops::RandomRoutingOpCUDAKernel<float>,
-                        ops::RandomRoutingOpCUDAKernel<double>,
-                        ops::RandomRoutingOpCUDAKernel<plat::float16>);
+PD_REGISTER_STRUCT_KERNEL(random_routing,
+                          GPU,
+                          ALL_LAYOUT,
+                          ops::RandomRoutingOpCUDAKernel,
+                          float,
+                          double,
+                          plat::float16) {}

@@ -24,7 +24,7 @@
 namespace paddle {
 namespace operators {
 
-template <typename T>
+template <typename T, typename DeviceContext>
 class RandomRoutingOpCPUKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext& ctx) const override {
...

@@ -192,9 +192,8 @@ REGISTER_OPERATOR(rank_attention_grad,
                   ops::RankAttentionGradOp,
                   ops::RankAttentionGradOpNoNeedBufferVarsInference);
 
-REGISTER_OP_CPU_KERNEL(rank_attention,
-                       ops::RankAttentionKernel<phi::CPUContext, float>,
-                       ops::RankAttentionKernel<phi::CPUContext, double>);
+PD_REGISTER_STRUCT_KERNEL(
+    rank_attention, CPU, ALL_LAYOUT, ops::RankAttentionKernel, float, double) {}
 
 REGISTER_OP_VERSION(rank_attention)
     .AddCheckpoint(
...

@@ -24,7 +24,7 @@ limitations under the License. */
 namespace paddle {
 namespace operators {
 
-template <typename DeviceContext, typename T>
+template <typename T, typename DeviceContext>
 class RankAttentionCUDAKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext &ctx) const override {
@@ -150,7 +150,7 @@ class RankAttentionCUDAKernel : public framework::OpKernel<T> {
   }
 };
 
-template <typename DeviceContext, typename T>
+template <typename T, typename DeviceContext>
 class RankAttentionGradOpCUDAKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext &ctx) const override {
@@ -243,10 +243,16 @@ class RankAttentionGradOpCUDAKernel : public framework::OpKernel<T> {
 namespace ops = paddle::operators;
 using GPUCtx = phi::GPUContext;
 
-REGISTER_OP_CUDA_KERNEL(rank_attention,
-                        ops::RankAttentionCUDAKernel<GPUCtx, float>,
-                        ops::RankAttentionCUDAKernel<GPUCtx, double>);
+PD_REGISTER_STRUCT_KERNEL(rank_attention,
+                          GPU,
+                          ALL_LAYOUT,
+                          ops::RankAttentionCUDAKernel,
+                          float,
+                          double) {}
 
-REGISTER_OP_CUDA_KERNEL(rank_attention_grad,
-                        ops::RankAttentionGradOpCUDAKernel<GPUCtx, float>,
-                        ops::RankAttentionGradOpCUDAKernel<GPUCtx, double>);
+PD_REGISTER_STRUCT_KERNEL(rank_attention_grad,
+                          GPU,
+                          ALL_LAYOUT,
+                          ops::RankAttentionGradOpCUDAKernel,
+                          float,
+                          double) {}

@@ -19,7 +19,7 @@ limitations under the License. */
 namespace paddle {
 namespace operators {
 
-template <typename DeviceContext, typename T>
+template <typename T, typename DeviceContext>
 class RankAttentionKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext& ctx) const override {
...

@@ -24,7 +24,7 @@
 namespace paddle {
 namespace operators {
 
-template <typename T>
+template <typename T, typename DeviceContext>
 class CPUReadFileKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext& ctx) const override {
@@ -92,4 +92,5 @@ REGISTER_OPERATOR(
     paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
     paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>)
 
-REGISTER_OP_CPU_KERNEL(read_file, ops::CPUReadFileKernel<uint8_t>)
+PD_REGISTER_STRUCT_KERNEL(
+    read_file, CPU, ALL_LAYOUT, ops::CPUReadFileKernel, uint8_t) {}

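One detail all of the registrations above share: every PD_REGISTER_STRUCT_KERNEL call ends in an empty {}. If the struct-kernel macro follows the same pattern as phi's PD_REGISTER_KERNEL, that block is the body of a per-kernel hook that receives the registered kernel descriptor and may adjust its argument metadata; every registration in this commit leaves it empty. A hypothetical illustration of a non-empty body under that assumption (not taken from this commit):

// Hypothetical: pin the first output's dtype in the kernel's registered metadata,
// assuming the trailing block exposes the phi::Kernel* as `kernel`.
PD_REGISTER_STRUCT_KERNEL(
    example_op, CPU, ALL_LAYOUT, ops::ExampleKernel, float) {
  kernel->OutputAt(0).SetDataType(phi::DataType::FLOAT32);
}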