未验证 提交 37ca3b4c 编写于 作者: H huangjiyi 提交者: GitHub

register fluid kerenls to phi [part 6.4] (#52881)

* update

* revert lookup_table_op
上级 4d5a3ad6
...@@ -521,15 +521,15 @@ void RegisterCustomDeviceCommonKernel(const std::string& dev_type) { ...@@ -521,15 +521,15 @@ void RegisterCustomDeviceCommonKernel(const std::string& dev_type) {
load_combine, load_combine,
device_type, device_type,
paddle::operators:: paddle::operators::
LoadCombineOpKernel<paddle::platform::CustomDeviceContext, float>, LoadCombineOpKernel<float, paddle::platform::CustomDeviceContext>,
paddle::operators:: paddle::operators::
LoadCombineOpKernel<paddle::platform::CustomDeviceContext, double>, LoadCombineOpKernel<double, paddle::platform::CustomDeviceContext>,
paddle::operators:: paddle::operators::
LoadCombineOpKernel<paddle::platform::CustomDeviceContext, int>, LoadCombineOpKernel<int, paddle::platform::CustomDeviceContext>,
paddle::operators:: paddle::operators::
LoadCombineOpKernel<paddle::platform::CustomDeviceContext, int8_t>, LoadCombineOpKernel<int8_t, paddle::platform::CustomDeviceContext>,
paddle::operators:: paddle::operators::
LoadCombineOpKernel<paddle::platform::CustomDeviceContext, int64_t>); LoadCombineOpKernel<int64_t, paddle::platform::CustomDeviceContext>);
REGISTER_OP_CUSTOM_DEVICE_KERNEL( REGISTER_OP_CUSTOM_DEVICE_KERNEL(
c_concat, c_concat,
device_type, device_type,
......
...@@ -77,16 +77,19 @@ that were saved using the SaveCombine operator. ...@@ -77,16 +77,19 @@ that were saved using the SaveCombine operator.
} // namespace paddle } // namespace paddle
namespace ops = paddle::operators; namespace ops = paddle::operators;
namespace plat = paddle::platform;
REGISTER_OPERATOR(load_combine, REGISTER_OPERATOR(load_combine,
ops::LoadCombineOp, ops::LoadCombineOp,
ops::LoadCombineOpProtoMaker); ops::LoadCombineOpProtoMaker);
REGISTER_OP_CPU_KERNEL( PD_REGISTER_STRUCT_KERNEL(load_combine,
load_combine, CPU,
ops::LoadCombineOpKernel<phi::CPUContext, float>, ALL_LAYOUT,
ops::LoadCombineOpKernel<phi::CPUContext, double>, ops::LoadCombineOpKernel,
ops::LoadCombineOpKernel<phi::CPUContext, paddle::platform::bfloat16>, float,
ops::LoadCombineOpKernel<phi::CPUContext, int>, double,
ops::LoadCombineOpKernel<phi::CPUContext, int8_t>, plat::bfloat16,
ops::LoadCombineOpKernel<phi::CPUContext, int64_t>); int,
int8_t,
int64_t) {}
...@@ -15,10 +15,12 @@ limitations under the License. */ ...@@ -15,10 +15,12 @@ limitations under the License. */
#include "paddle/fluid/operators/load_combine_op.h" #include "paddle/fluid/operators/load_combine_op.h"
namespace ops = paddle::operators; namespace ops = paddle::operators;
PD_REGISTER_STRUCT_KERNEL(load_combine,
REGISTER_OP_CUDA_KERNEL(load_combine, GPU,
ops::LoadCombineOpKernel<phi::GPUContext, float>, ALL_LAYOUT,
ops::LoadCombineOpKernel<phi::GPUContext, double>, ops::LoadCombineOpKernel,
ops::LoadCombineOpKernel<phi::GPUContext, int>, float,
ops::LoadCombineOpKernel<phi::GPUContext, int8_t>, double,
ops::LoadCombineOpKernel<phi::GPUContext, int64_t>); int,
int8_t,
int64_t) {}
...@@ -28,7 +28,7 @@ limitations under the License. */ ...@@ -28,7 +28,7 @@ limitations under the License. */
namespace paddle { namespace paddle {
namespace operators { namespace operators {
template <typename DeviceContext, typename T> template <typename T, typename DeviceContext>
class LoadCombineOpKernel : public framework::OpKernel<T> { class LoadCombineOpKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext &ctx) const override { void Compute(const framework::ExecutionContext &ctx) const override {
......
...@@ -15,11 +15,11 @@ limitations under the License. */ ...@@ -15,11 +15,11 @@ limitations under the License. */
#include "paddle/fluid/operators/load_combine_op.h" #include "paddle/fluid/operators/load_combine_op.h"
namespace ops = paddle::operators; namespace ops = paddle::operators;
using XPUCtx = paddle::platform::XPUDeviceContext;
REGISTER_OP_XPU_KERNEL( REGISTER_OP_XPU_KERNEL(load_combine,
load_combine, ops::LoadCombineOpKernel<float, XPUCtx>,
ops::LoadCombineOpKernel<paddle::platform::XPUDeviceContext, float>, ops::LoadCombineOpKernel<double, XPUCtx>,
ops::LoadCombineOpKernel<paddle::platform::XPUDeviceContext, double>, ops::LoadCombineOpKernel<int, XPUCtx>,
ops::LoadCombineOpKernel<paddle::platform::XPUDeviceContext, int>, ops::LoadCombineOpKernel<int8_t, XPUCtx>,
ops::LoadCombineOpKernel<paddle::platform::XPUDeviceContext, int8_t>, ops::LoadCombineOpKernel<int64_t, XPUCtx>);
ops::LoadCombineOpKernel<paddle::platform::XPUDeviceContext, int64_t>);
...@@ -235,6 +235,7 @@ DECLARE_NO_NEED_BUFFER_VARS_INFERER(LoDResetGradNoNeedBufferVarInferer, "X"); ...@@ -235,6 +235,7 @@ DECLARE_NO_NEED_BUFFER_VARS_INFERER(LoDResetGradNoNeedBufferVarInferer, "X");
} // namespace paddle } // namespace paddle
namespace ops = paddle::operators; namespace ops = paddle::operators;
namespace plat = paddle::platform;
REGISTER_OPERATOR(lod_reset, REGISTER_OPERATOR(lod_reset,
ops::LoDResetOp, ops::LoDResetOp,
ops::LoDResetOpMaker, ops::LoDResetOpMaker,
...@@ -247,30 +248,32 @@ REGISTER_OPERATOR(lod_reset_grad, ...@@ -247,30 +248,32 @@ REGISTER_OPERATOR(lod_reset_grad,
ops::LoDResetGradNoNeedBufferVarInferer, ops::LoDResetGradNoNeedBufferVarInferer,
ops::LoDResetGradInplaceInferer); ops::LoDResetGradInplaceInferer);
REGISTER_OP_CPU_KERNEL( PD_REGISTER_STRUCT_KERNEL(lod_reset,
lod_reset, CPU,
ops::LoDResetKernel<paddle::platform::CPUPlace, paddle::platform::float16>, ALL_LAYOUT,
ops::LoDResetKernel<paddle::platform::CPUPlace, float>, ops::LoDResetKernel,
ops::LoDResetKernel<paddle::platform::CPUPlace, double>, plat::float16,
ops::LoDResetKernel<paddle::platform::CPUPlace, int>, float,
ops::LoDResetKernel<paddle::platform::CPUPlace, int64_t>); double,
int,
int64_t) {}
#ifdef PADDLE_WITH_XPU #ifdef PADDLE_WITH_XPU
REGISTER_OP_XPU_KERNEL( using XPUCtx = paddle::platform::XPUDeviceContext;
lod_reset, REGISTER_OP_XPU_KERNEL(lod_reset,
ops::LoDResetKernel<paddle::platform::XPUDeviceContext, ops::LoDResetKernel<paddle::platform::float16, XPUCtx>,
paddle::platform::float16>, ops::LoDResetKernel<float, XPUCtx>,
ops::LoDResetKernel<paddle::platform::XPUDeviceContext, float>, ops::LoDResetKernel<double, XPUCtx>,
ops::LoDResetKernel<paddle::platform::XPUDeviceContext, double>, ops::LoDResetKernel<int, XPUCtx>,
ops::LoDResetKernel<paddle::platform::XPUDeviceContext, int>, ops::LoDResetKernel<int64_t, XPUCtx>);
ops::LoDResetKernel<paddle::platform::XPUDeviceContext, int64_t>);
#endif #endif
REGISTER_OP_CPU_KERNEL( PD_REGISTER_STRUCT_KERNEL(lod_reset_grad,
lod_reset_grad, CPU,
ops::LoDResetGradKernel<paddle::platform::CPUPlace, ALL_LAYOUT,
paddle::platform::float16>, ops::LoDResetGradKernel,
ops::LoDResetGradKernel<paddle::platform::CPUPlace, float>, plat::float16,
ops::LoDResetGradKernel<paddle::platform::CPUPlace, double>, float,
ops::LoDResetGradKernel<paddle::platform::CPUPlace, int>, double,
ops::LoDResetGradKernel<paddle::platform::CPUPlace, int64_t>); int,
int64_t) {}
...@@ -16,13 +16,19 @@ limitations under the License. */ ...@@ -16,13 +16,19 @@ limitations under the License. */
namespace ops = paddle::operators; namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL(lod_reset, PD_REGISTER_STRUCT_KERNEL(lod_reset,
ops::LoDResetKernel<phi::GPUContext, float>, GPU,
ops::LoDResetKernel<phi::GPUContext, double>, ALL_LAYOUT,
ops::LoDResetKernel<phi::GPUContext, int>, ops::LoDResetKernel,
ops::LoDResetKernel<phi::GPUContext, int64_t>); float,
REGISTER_OP_CUDA_KERNEL(lod_reset_grad, double,
ops::LoDResetGradKernel<phi::GPUContext, float>, int,
ops::LoDResetGradKernel<phi::GPUContext, double>, int64_t) {}
ops::LoDResetGradKernel<phi::GPUContext, int>, PD_REGISTER_STRUCT_KERNEL(lod_reset_grad,
ops::LoDResetGradKernel<phi::GPUContext, int64_t>); GPU,
ALL_LAYOUT,
ops::LoDResetGradKernel,
float,
double,
int,
int64_t) {}
...@@ -35,7 +35,7 @@ limitations under the License. */ ...@@ -35,7 +35,7 @@ limitations under the License. */
namespace paddle { namespace paddle {
namespace operators { namespace operators {
template <typename DeviceContext, typename T> template <typename T, typename DeviceContext>
class LoDResetKernel : public framework::OpKernel<T> { class LoDResetKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& ctx) const { void Compute(const framework::ExecutionContext& ctx) const {
...@@ -123,7 +123,7 @@ class LoDResetKernel : public framework::OpKernel<T> { ...@@ -123,7 +123,7 @@ class LoDResetKernel : public framework::OpKernel<T> {
} }
}; };
template <typename DeviceContext, typename T> template <typename T, typename DeviceContext>
class LoDResetGradKernel : public framework::OpKernel<T> { class LoDResetGradKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& ctx) const { void Compute(const framework::ExecutionContext& ctx) const {
......
...@@ -133,5 +133,9 @@ REGISTER_OPERATOR( ...@@ -133,5 +133,9 @@ REGISTER_OPERATOR(
ops::LookupTableDequantOpMaker, ops::LookupTableDequantOpMaker,
paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>, paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>); paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>);
REGISTER_OP_CPU_KERNEL(lookup_table_dequant,
ops::LookupTableDequantKernel<float>); PD_REGISTER_STRUCT_KERNEL(lookup_table_dequant,
CPU,
ALL_LAYOUT,
ops::LookupTableDequantKernel,
float) {}
...@@ -46,7 +46,7 @@ void dequant(const unsigned char *in, ...@@ -46,7 +46,7 @@ void dequant(const unsigned char *in,
constexpr int64_t kNoPadding = -1; constexpr int64_t kNoPadding = -1;
template <typename T> template <typename T, typename DeviceContext>
class LookupTableDequantKernel : public framework::OpKernel<T> { class LookupTableDequantKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext &context) const override { void Compute(const framework::ExecutionContext &context) const override {
......
...@@ -54,7 +54,7 @@ class MarkerOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -54,7 +54,7 @@ class MarkerOpMaker : public framework::OpProtoAndCheckerMaker {
} }
}; };
template <typename T> template <typename T, typename DeviceContext>
class MarkerOpCPUKernel : public framework::OpKernel<T> { class MarkerOpCPUKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& ctx) const override { void Compute(const framework::ExecutionContext& ctx) const override {
...@@ -75,4 +75,5 @@ class MarkerOpCPUKernel : public framework::OpKernel<T> { ...@@ -75,4 +75,5 @@ class MarkerOpCPUKernel : public framework::OpKernel<T> {
namespace ops = paddle::operators; namespace ops = paddle::operators;
REGISTER_OP_WITHOUT_GRADIENT(marker, ops::MarkerOp, ops::MarkerOpMaker); REGISTER_OP_WITHOUT_GRADIENT(marker, ops::MarkerOp, ops::MarkerOpMaker);
REGISTER_OP_CPU_KERNEL(marker, ops::MarkerOpCPUKernel<float>); PD_REGISTER_STRUCT_KERNEL(
marker, CPU, ALL_LAYOUT, ops::MarkerOpCPUKernel, float) {}
...@@ -29,7 +29,7 @@ __global__ void SimpleMarkerKernel(T* in, T* out, int ndim) { ...@@ -29,7 +29,7 @@ __global__ void SimpleMarkerKernel(T* in, T* out, int ndim) {
} }
} }
template <typename T> template <typename T, typename DeviceContext>
class MarkerOpCUDAKernel : public framework::OpKernel<T> { class MarkerOpCUDAKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& ctx) const override { void Compute(const framework::ExecutionContext& ctx) const override {
...@@ -61,4 +61,5 @@ class MarkerOpCUDAKernel : public framework::OpKernel<T> { ...@@ -61,4 +61,5 @@ class MarkerOpCUDAKernel : public framework::OpKernel<T> {
namespace ops = paddle::operators; namespace ops = paddle::operators;
namespace plat = paddle::platform; namespace plat = paddle::platform;
REGISTER_OP_CUDA_KERNEL(marker, ops::MarkerOpCUDAKernel<float>); PD_REGISTER_STRUCT_KERNEL(
marker, GPU, ALL_LAYOUT, ops::MarkerOpCUDAKernel, float) {}
...@@ -23,8 +23,9 @@ limitations under the License. */ ...@@ -23,8 +23,9 @@ limitations under the License. */
#include "paddle/phi/core/kernel_registry.h" #include "paddle/phi/core/kernel_registry.h"
USE_OP_ITSELF(save_combine); USE_OP_ITSELF(save_combine);
USE_OP_ITSELF(load_combine);
PD_DECLARE_KERNEL(save_combine_tensor, CPU, ALL_LAYOUT); PD_DECLARE_KERNEL(save_combine_tensor, CPU, ALL_LAYOUT);
USE_CPU_ONLY_OP(load_combine); PD_DECLARE_KERNEL(load_combine, CPU, ALL_LAYOUT);
template <typename T, typename U> template <typename T, typename U>
T* CreateForSaveCombineOp(int x, T* CreateForSaveCombineOp(int x,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册