Unverified commit 09da1c4c authored by YuanRisheng, committed by GitHub

unify kernel (#52594)

Parent 5662adcc
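The change is the same for every operator touched below: the structural kernel's template parameters move from <DeviceContext, T> (or plain <T>) to <T, DeviceContext>, and the per-device REGISTER_OP_CPU_KERNEL / REGISTER_OP_CUDA_KERNEL macros are replaced by a single PD_REGISTER_STRUCT_KERNEL(op, Backend, Layout, Kernel, dtypes...) registration. Below is a minimal sketch of that migration, not part of the commit itself: "foo" and FooKernel are hypothetical stand-ins, and the macros and base classes are assumed to come from Paddle's fluid op-registry header as in the files changed here.

// Sketch only: hypothetical op "foo"; registration macros and OpKernel are
// assumed to be provided by paddle/fluid/framework/op_registry.h, as in the
// operator files below.
#include "paddle/fluid/framework/op_registry.h"

namespace paddle {
namespace operators {

// New parameter order: data type first, device context second.
template <typename T, typename DeviceContext>
class FooKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    // Kernel bodies are unchanged by this commit; only the template
    // signature and the registration macro change.
  }
};

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;

// Old style: one macro per device, context baked into each instantiation:
//   REGISTER_OP_CPU_KERNEL(foo, ops::FooKernel<phi::CPUContext, float>);
// New style: one registration per backend, data types listed at the end:
PD_REGISTER_STRUCT_KERNEL(foo, CPU, ALL_LAYOUT, ops::FooKernel, float, double) {}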
@@ -70,4 +70,5 @@ REGISTER_OPERATOR(
paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>,
ops::PushDenseNoNeedBufferVarsInferer);
-REGISTER_OP_CPU_KERNEL(push_dense, ops::PushDenseCPUKernel<float>)
+PD_REGISTER_STRUCT_KERNEL(
+    push_dense, CPU, ALL_LAYOUT, ops::PushDenseCPUKernel, float) {}
@@ -60,7 +60,7 @@ void PushDenseFunctor(const framework::ExecutionContext& ctx) {
#endif
}
-template <typename T>
+template <typename T, typename DeviceContext>
class PushDenseCPUKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
@@ -232,7 +232,7 @@ class PyramidHashOP : public framework::OperatorWithKernel {
}
};
-template <typename DeviceContext, typename T>
+template <typename T, typename DeviceContext>
class CPUPyramidHashOPKernel : public framework::OpKernel<T> {
public:
bool should_use_term(math::bloomfilter* _filter,
@@ -492,7 +492,7 @@ class PyramidHashGradOpMaker : public framework::SingleGradOpMaker<T> {
}
};
-template <typename DeviceContext, typename T>
+template <typename T, typename DeviceContext>
class CPUPyramidHashOPGradKernel : public framework::OpKernel<T> {
public:
void hash_embedding_bp(const T* hash_id,
@@ -584,8 +584,11 @@ REGISTER_OPERATOR(pyramid_hash,
ops::PyramidHashGradOpMaker<paddle::imperative::OpBase>);
REGISTER_OPERATOR(pyramid_hash_grad, ops::PyramidHashOpGrad);
-REGISTER_OP_CPU_KERNEL(pyramid_hash,
-                       ops::CPUPyramidHashOPKernel<phi::CPUContext, float>,
-                       ops::CPUPyramidHashOPKernel<phi::CPUContext, int8_t>);
-REGISTER_OP_CPU_KERNEL(pyramid_hash_grad,
-                       ops::CPUPyramidHashOPGradKernel<phi::CPUContext, float>);
+PD_REGISTER_STRUCT_KERNEL(
+    pyramid_hash, CPU, ALL_LAYOUT, ops::CPUPyramidHashOPKernel, float, int8_t) {
+}
+PD_REGISTER_STRUCT_KERNEL(pyramid_hash_grad,
+                          CPU,
+                          ALL_LAYOUT,
+                          ops::CPUPyramidHashOPGradKernel,
+                          float) {}
@@ -222,7 +222,6 @@ $$0 \leq c \lt \ the\ channel\ number\ of\ X$$
} // namespace paddle
namespace ops = paddle::operators;
-using CPU = phi::CPUContext;
REGISTER_OPERATOR(
quantize_linear,
@@ -231,7 +230,8 @@ REGISTER_OPERATOR(
paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>);
-REGISTER_OP_CPU_KERNEL(quantize_linear, ops::QuantizeLinearKernel<CPU, float>);
+PD_REGISTER_STRUCT_KERNEL(
+    quantize_linear, CPU, ALL_LAYOUT, ops::QuantizeLinearKernel, float) {}
REGISTER_OPERATOR(
dequantize_linear,
@@ -240,7 +240,10 @@ REGISTER_OPERATOR(
paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>);
-REGISTER_OP_CPU_KERNEL(dequantize_linear,
-                       ops::DeQuantizeLinearKernel<CPU, float>,
-                       ops::DeQuantizeLinearKernel<CPU, int8_t>,
-                       ops::DeQuantizeLinearKernel<CPU, double>);
+PD_REGISTER_STRUCT_KERNEL(dequantize_linear,
+                          CPU,
+                          ALL_LAYOUT,
+                          ops::DeQuantizeLinearKernel,
+                          float,
+                          int8_t,
+                          double) {}
@@ -123,12 +123,18 @@ template struct ChannelDequantizeFunctorV2<phi::GPUContext, double>;
namespace ops = paddle::operators;
using CUDA = phi::GPUContext;
-REGISTER_OP_CUDA_KERNEL(dequantize_linear,
-                        ops::DeQuantizeLinearKernel<CUDA, float>,
-                        ops::DeQuantizeLinearKernel<CUDA, float16>,
-                        ops::DeQuantizeLinearKernel<CUDA, int8_t>,
-                        ops::DeQuantizeLinearKernel<CUDA, double>);
-REGISTER_OP_CUDA_KERNEL(quantize_linear,
-                        ops::QuantizeLinearKernel<CUDA, float>,
-                        ops::QuantizeLinearKernel<CUDA, float16>);
+PD_REGISTER_STRUCT_KERNEL(dequantize_linear,
+                          GPU,
+                          ALL_LAYOUT,
+                          ops::DeQuantizeLinearKernel,
+                          float,
+                          float16,
+                          int8_t,
+                          double) {}
+PD_REGISTER_STRUCT_KERNEL(quantize_linear,
+                          GPU,
+                          ALL_LAYOUT,
+                          ops::QuantizeLinearKernel,
+                          float,
+                          float16) {}
@@ -47,7 +47,7 @@ struct ChannelDequantizeFunctorV2 {
phi::DenseTensor* out);
};
-template <typename DeviceContext, typename T>
+template <typename T, typename DeviceContext>
class QuantizeLinearKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
@@ -130,7 +130,7 @@ class QuantizeLinearKernel : public framework::OpKernel<T> {
}
};
-template <typename DeviceContext, typename T>
+template <typename T, typename DeviceContext>
class DeQuantizeLinearKernel : public framework::OpKernel<T> {
public:
template <typename D>
@@ -96,11 +96,12 @@ REGISTER_OPERATOR(
paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>);
-template <typename T>
-using Kernel = ops::RandomCropKernel<phi::CPUContext, T>;
-REGISTER_OP_CPU_KERNEL(random_crop,
-                       Kernel<float>,
-                       Kernel<int>,
-                       Kernel<double>,
-                       Kernel<uint8_t>,
-                       Kernel<int16_t>);
+PD_REGISTER_STRUCT_KERNEL(random_crop,
+                          CPU,
+                          ALL_LAYOUT,
+                          ops::RandomCropKernel,
+                          float,
+                          int,
+                          double,
+                          uint8_t,
+                          int16_t) {}
@@ -15,11 +15,13 @@
#include "paddle/fluid/operators/random_crop_op.h"
namespace ops = paddle::operators;
-template <typename T>
-using Kernel = ops::RandomCropKernel<phi::GPUContext, T>;
-REGISTER_OP_CUDA_KERNEL(random_crop,
-                        Kernel<float>,
-                        Kernel<int>,
-                        Kernel<double>,
-                        Kernel<uint8_t>,
-                        Kernel<int16_t>);
+PD_REGISTER_STRUCT_KERNEL(random_crop,
+                          GPU,
+                          ALL_LAYOUT,
+                          ops::RandomCropKernel,
+                          float,
+                          int,
+                          double,
+                          uint8_t,
+                          int16_t) {}
@@ -176,7 +176,7 @@ struct RandomCropFunctor {
}
};
-template <typename DeviceContext, typename T>
+template <typename T, typename DeviceContext>
class RandomCropKernel : public framework::OpKernel<T> {
public:
virtual void Compute(const framework::ExecutionContext& ctx) const {
@@ -98,6 +98,9 @@ REGISTER_OPERATOR(
paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>,
ops::RandomRoutingInplaceInferer)
-REGISTER_OP_CPU_KERNEL(random_routing,
-                       ops::RandomRoutingOpCPUKernel<float>,
-                       ops::RandomRoutingOpCPUKernel<double>);
+PD_REGISTER_STRUCT_KERNEL(random_routing,
+                          CPU,
+                          ALL_LAYOUT,
+                          ops::RandomRoutingOpCPUKernel,
+                          float,
+                          double) {}
@@ -47,7 +47,7 @@ __global__ void random_routing_kernel(int64_t* data,
}
}
-template <typename T>
+template <typename T, typename DeviceContext>
class RandomRoutingOpCUDAKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
@@ -82,7 +82,10 @@ class RandomRoutingOpCUDAKernel : public framework::OpKernel<T> {
namespace ops = paddle::operators;
namespace plat = paddle::platform;
-REGISTER_OP_CUDA_KERNEL(random_routing,
-                        ops::RandomRoutingOpCUDAKernel<float>,
-                        ops::RandomRoutingOpCUDAKernel<double>,
-                        ops::RandomRoutingOpCUDAKernel<plat::float16>);
+PD_REGISTER_STRUCT_KERNEL(random_routing,
+                          GPU,
+                          ALL_LAYOUT,
+                          ops::RandomRoutingOpCUDAKernel,
+                          float,
+                          double,
+                          plat::float16) {}
@@ -24,7 +24,7 @@
namespace paddle {
namespace operators {
-template <typename T>
+template <typename T, typename DeviceContext>
class RandomRoutingOpCPUKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
@@ -192,9 +192,8 @@ REGISTER_OPERATOR(rank_attention_grad,
ops::RankAttentionGradOp,
ops::RankAttentionGradOpNoNeedBufferVarsInference);
-REGISTER_OP_CPU_KERNEL(rank_attention,
-                       ops::RankAttentionKernel<phi::CPUContext, float>,
-                       ops::RankAttentionKernel<phi::CPUContext, double>);
+PD_REGISTER_STRUCT_KERNEL(
+    rank_attention, CPU, ALL_LAYOUT, ops::RankAttentionKernel, float, double) {}
REGISTER_OP_VERSION(rank_attention)
.AddCheckpoint(
@@ -24,7 +24,7 @@ limitations under the License. */
namespace paddle {
namespace operators {
-template <typename DeviceContext, typename T>
+template <typename T, typename DeviceContext>
class RankAttentionCUDAKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext &ctx) const override {
@@ -150,7 +150,7 @@ class RankAttentionCUDAKernel : public framework::OpKernel<T> {
}
};
-template <typename DeviceContext, typename T>
+template <typename T, typename DeviceContext>
class RankAttentionGradOpCUDAKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext &ctx) const override {
@@ -243,10 +243,16 @@ class RankAttentionGradOpCUDAKernel : public framework::OpKernel<T> {
namespace ops = paddle::operators;
using GPUCtx = phi::GPUContext;
-REGISTER_OP_CUDA_KERNEL(rank_attention,
-                        ops::RankAttentionCUDAKernel<GPUCtx, float>,
-                        ops::RankAttentionCUDAKernel<GPUCtx, double>);
-REGISTER_OP_CUDA_KERNEL(rank_attention_grad,
-                        ops::RankAttentionGradOpCUDAKernel<GPUCtx, float>,
-                        ops::RankAttentionGradOpCUDAKernel<GPUCtx, double>);
+PD_REGISTER_STRUCT_KERNEL(rank_attention,
+                          GPU,
+                          ALL_LAYOUT,
+                          ops::RankAttentionCUDAKernel,
+                          float,
+                          double) {}
+PD_REGISTER_STRUCT_KERNEL(rank_attention_grad,
+                          GPU,
+                          ALL_LAYOUT,
+                          ops::RankAttentionGradOpCUDAKernel,
+                          float,
+                          double) {}
@@ -19,7 +19,7 @@ limitations under the License. */
namespace paddle {
namespace operators {
-template <typename DeviceContext, typename T>
+template <typename T, typename DeviceContext>
class RankAttentionKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
@@ -24,7 +24,7 @@
namespace paddle {
namespace operators {
-template <typename T>
+template <typename T, typename DeviceContext>
class CPUReadFileKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
@@ -92,4 +92,5 @@ REGISTER_OPERATOR(
paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>)
-REGISTER_OP_CPU_KERNEL(read_file, ops::CPUReadFileKernel<uint8_t>)
+PD_REGISTER_STRUCT_KERNEL(
+    read_file, CPU, ALL_LAYOUT, ops::CPUReadFileKernel, uint8_t) {}