未验证 提交 37ca3b4c 编写于 作者: H huangjiyi 提交者: GitHub

register fluid kernels to phi [part 6.4] (#52881)

* update

* revert lookup_table_op
上级 4d5a3ad6
......@@ -521,15 +521,15 @@ void RegisterCustomDeviceCommonKernel(const std::string& dev_type) {
load_combine,
device_type,
paddle::operators::
LoadCombineOpKernel<paddle::platform::CustomDeviceContext, float>,
LoadCombineOpKernel<float, paddle::platform::CustomDeviceContext>,
paddle::operators::
LoadCombineOpKernel<paddle::platform::CustomDeviceContext, double>,
LoadCombineOpKernel<double, paddle::platform::CustomDeviceContext>,
paddle::operators::
LoadCombineOpKernel<paddle::platform::CustomDeviceContext, int>,
LoadCombineOpKernel<int, paddle::platform::CustomDeviceContext>,
paddle::operators::
LoadCombineOpKernel<paddle::platform::CustomDeviceContext, int8_t>,
LoadCombineOpKernel<int8_t, paddle::platform::CustomDeviceContext>,
paddle::operators::
LoadCombineOpKernel<paddle::platform::CustomDeviceContext, int64_t>);
LoadCombineOpKernel<int64_t, paddle::platform::CustomDeviceContext>);
REGISTER_OP_CUSTOM_DEVICE_KERNEL(
c_concat,
device_type,
......
......@@ -77,16 +77,19 @@ that were saved using the SaveCombine operator.
} // namespace paddle
namespace ops = paddle::operators;
namespace plat = paddle::platform;
REGISTER_OPERATOR(load_combine,
ops::LoadCombineOp,
ops::LoadCombineOpProtoMaker);
REGISTER_OP_CPU_KERNEL(
load_combine,
ops::LoadCombineOpKernel<phi::CPUContext, float>,
ops::LoadCombineOpKernel<phi::CPUContext, double>,
ops::LoadCombineOpKernel<phi::CPUContext, paddle::platform::bfloat16>,
ops::LoadCombineOpKernel<phi::CPUContext, int>,
ops::LoadCombineOpKernel<phi::CPUContext, int8_t>,
ops::LoadCombineOpKernel<phi::CPUContext, int64_t>);
PD_REGISTER_STRUCT_KERNEL(load_combine,
CPU,
ALL_LAYOUT,
ops::LoadCombineOpKernel,
float,
double,
plat::bfloat16,
int,
int8_t,
int64_t) {}
......@@ -15,10 +15,12 @@ limitations under the License. */
#include "paddle/fluid/operators/load_combine_op.h"
namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL(load_combine,
ops::LoadCombineOpKernel<phi::GPUContext, float>,
ops::LoadCombineOpKernel<phi::GPUContext, double>,
ops::LoadCombineOpKernel<phi::GPUContext, int>,
ops::LoadCombineOpKernel<phi::GPUContext, int8_t>,
ops::LoadCombineOpKernel<phi::GPUContext, int64_t>);
PD_REGISTER_STRUCT_KERNEL(load_combine,
GPU,
ALL_LAYOUT,
ops::LoadCombineOpKernel,
float,
double,
int,
int8_t,
int64_t) {}
......@@ -28,7 +28,7 @@ limitations under the License. */
namespace paddle {
namespace operators {
template <typename DeviceContext, typename T>
template <typename T, typename DeviceContext>
class LoadCombineOpKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext &ctx) const override {
......
......@@ -15,11 +15,11 @@ limitations under the License. */
#include "paddle/fluid/operators/load_combine_op.h"
namespace ops = paddle::operators;
using XPUCtx = paddle::platform::XPUDeviceContext;
REGISTER_OP_XPU_KERNEL(
load_combine,
ops::LoadCombineOpKernel<paddle::platform::XPUDeviceContext, float>,
ops::LoadCombineOpKernel<paddle::platform::XPUDeviceContext, double>,
ops::LoadCombineOpKernel<paddle::platform::XPUDeviceContext, int>,
ops::LoadCombineOpKernel<paddle::platform::XPUDeviceContext, int8_t>,
ops::LoadCombineOpKernel<paddle::platform::XPUDeviceContext, int64_t>);
REGISTER_OP_XPU_KERNEL(load_combine,
ops::LoadCombineOpKernel<float, XPUCtx>,
ops::LoadCombineOpKernel<double, XPUCtx>,
ops::LoadCombineOpKernel<int, XPUCtx>,
ops::LoadCombineOpKernel<int8_t, XPUCtx>,
ops::LoadCombineOpKernel<int64_t, XPUCtx>);
......@@ -235,6 +235,7 @@ DECLARE_NO_NEED_BUFFER_VARS_INFERER(LoDResetGradNoNeedBufferVarInferer, "X");
} // namespace paddle
namespace ops = paddle::operators;
namespace plat = paddle::platform;
REGISTER_OPERATOR(lod_reset,
ops::LoDResetOp,
ops::LoDResetOpMaker,
......@@ -247,30 +248,32 @@ REGISTER_OPERATOR(lod_reset_grad,
ops::LoDResetGradNoNeedBufferVarInferer,
ops::LoDResetGradInplaceInferer);
REGISTER_OP_CPU_KERNEL(
lod_reset,
ops::LoDResetKernel<paddle::platform::CPUPlace, paddle::platform::float16>,
ops::LoDResetKernel<paddle::platform::CPUPlace, float>,
ops::LoDResetKernel<paddle::platform::CPUPlace, double>,
ops::LoDResetKernel<paddle::platform::CPUPlace, int>,
ops::LoDResetKernel<paddle::platform::CPUPlace, int64_t>);
PD_REGISTER_STRUCT_KERNEL(lod_reset,
CPU,
ALL_LAYOUT,
ops::LoDResetKernel,
plat::float16,
float,
double,
int,
int64_t) {}
#ifdef PADDLE_WITH_XPU
REGISTER_OP_XPU_KERNEL(
lod_reset,
ops::LoDResetKernel<paddle::platform::XPUDeviceContext,
paddle::platform::float16>,
ops::LoDResetKernel<paddle::platform::XPUDeviceContext, float>,
ops::LoDResetKernel<paddle::platform::XPUDeviceContext, double>,
ops::LoDResetKernel<paddle::platform::XPUDeviceContext, int>,
ops::LoDResetKernel<paddle::platform::XPUDeviceContext, int64_t>);
using XPUCtx = paddle::platform::XPUDeviceContext;
REGISTER_OP_XPU_KERNEL(lod_reset,
ops::LoDResetKernel<paddle::platform::float16, XPUCtx>,
ops::LoDResetKernel<float, XPUCtx>,
ops::LoDResetKernel<double, XPUCtx>,
ops::LoDResetKernel<int, XPUCtx>,
ops::LoDResetKernel<int64_t, XPUCtx>);
#endif
REGISTER_OP_CPU_KERNEL(
lod_reset_grad,
ops::LoDResetGradKernel<paddle::platform::CPUPlace,
paddle::platform::float16>,
ops::LoDResetGradKernel<paddle::platform::CPUPlace, float>,
ops::LoDResetGradKernel<paddle::platform::CPUPlace, double>,
ops::LoDResetGradKernel<paddle::platform::CPUPlace, int>,
ops::LoDResetGradKernel<paddle::platform::CPUPlace, int64_t>);
PD_REGISTER_STRUCT_KERNEL(lod_reset_grad,
CPU,
ALL_LAYOUT,
ops::LoDResetGradKernel,
plat::float16,
float,
double,
int,
int64_t) {}
......@@ -16,13 +16,19 @@ limitations under the License. */
namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL(lod_reset,
ops::LoDResetKernel<phi::GPUContext, float>,
ops::LoDResetKernel<phi::GPUContext, double>,
ops::LoDResetKernel<phi::GPUContext, int>,
ops::LoDResetKernel<phi::GPUContext, int64_t>);
REGISTER_OP_CUDA_KERNEL(lod_reset_grad,
ops::LoDResetGradKernel<phi::GPUContext, float>,
ops::LoDResetGradKernel<phi::GPUContext, double>,
ops::LoDResetGradKernel<phi::GPUContext, int>,
ops::LoDResetGradKernel<phi::GPUContext, int64_t>);
PD_REGISTER_STRUCT_KERNEL(lod_reset,
GPU,
ALL_LAYOUT,
ops::LoDResetKernel,
float,
double,
int,
int64_t) {}
PD_REGISTER_STRUCT_KERNEL(lod_reset_grad,
GPU,
ALL_LAYOUT,
ops::LoDResetGradKernel,
float,
double,
int,
int64_t) {}
......@@ -35,7 +35,7 @@ limitations under the License. */
namespace paddle {
namespace operators {
template <typename DeviceContext, typename T>
template <typename T, typename DeviceContext>
class LoDResetKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const {
......@@ -123,7 +123,7 @@ class LoDResetKernel : public framework::OpKernel<T> {
}
};
template <typename DeviceContext, typename T>
template <typename T, typename DeviceContext>
class LoDResetGradKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const {
......
......@@ -133,5 +133,9 @@ REGISTER_OPERATOR(
ops::LookupTableDequantOpMaker,
paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>);
REGISTER_OP_CPU_KERNEL(lookup_table_dequant,
ops::LookupTableDequantKernel<float>);
PD_REGISTER_STRUCT_KERNEL(lookup_table_dequant,
CPU,
ALL_LAYOUT,
ops::LookupTableDequantKernel,
float) {}
......@@ -46,7 +46,7 @@ void dequant(const unsigned char *in,
constexpr int64_t kNoPadding = -1;
template <typename T>
template <typename T, typename DeviceContext>
class LookupTableDequantKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext &context) const override {
......
......@@ -54,7 +54,7 @@ class MarkerOpMaker : public framework::OpProtoAndCheckerMaker {
}
};
template <typename T>
template <typename T, typename DeviceContext>
class MarkerOpCPUKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
......@@ -75,4 +75,5 @@ class MarkerOpCPUKernel : public framework::OpKernel<T> {
namespace ops = paddle::operators;
REGISTER_OP_WITHOUT_GRADIENT(marker, ops::MarkerOp, ops::MarkerOpMaker);
REGISTER_OP_CPU_KERNEL(marker, ops::MarkerOpCPUKernel<float>);
PD_REGISTER_STRUCT_KERNEL(
marker, CPU, ALL_LAYOUT, ops::MarkerOpCPUKernel, float) {}
......@@ -29,7 +29,7 @@ __global__ void SimpleMarkerKernel(T* in, T* out, int ndim) {
}
}
template <typename T>
template <typename T, typename DeviceContext>
class MarkerOpCUDAKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
......@@ -61,4 +61,5 @@ class MarkerOpCUDAKernel : public framework::OpKernel<T> {
namespace ops = paddle::operators;
namespace plat = paddle::platform;
REGISTER_OP_CUDA_KERNEL(marker, ops::MarkerOpCUDAKernel<float>);
PD_REGISTER_STRUCT_KERNEL(
marker, GPU, ALL_LAYOUT, ops::MarkerOpCUDAKernel, float) {}
......@@ -23,8 +23,9 @@ limitations under the License. */
#include "paddle/phi/core/kernel_registry.h"
USE_OP_ITSELF(save_combine);
USE_OP_ITSELF(load_combine);
PD_DECLARE_KERNEL(save_combine_tensor, CPU, ALL_LAYOUT);
USE_CPU_ONLY_OP(load_combine);
PD_DECLARE_KERNEL(load_combine, CPU, ALL_LAYOUT);
template <typename T, typename U>
T* CreateForSaveCombineOp(int x,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册