Unverified · Commit f6f48780 · authored by huangjiyi · committed by GitHub

Register fluid xpu kernels to phi [part 1] (#53187)

* update

* fix bug

* Revert "affine_channel_op"
Parent c1a61fc0
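This PR migrates a first batch of fluid XPU kernels from the legacy REGISTER_OP_XPU_KERNEL macro to the phi struct-kernel registration. The recurring pattern, sketched below with a hypothetical FooOpXPUKernel (the name is illustrative, not from this diff):

    // Before: legacy fluid registration. The device context (or place) is
    // spelled out in every template instantiation.
    REGISTER_OP_XPU_KERNEL(foo, ops::FooOpXPUKernel<float>);

    // After: phi struct-kernel registration. Device (XPU) and layout
    // (ALL_LAYOUT) are macro arguments, the data types are listed once, and
    // the kernel template takes <typename T, typename DeviceContext> in
    // that order.
    PD_REGISTER_STRUCT_KERNEL(foo, XPU, ALL_LAYOUT, ops::FooOpXPUKernel, float) {}

This is also why each kernel below swaps its template parameters from <DeviceContext, T> (or plain <T>) to <T, DeviceContext>.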
......
@@ -470,6 +470,7 @@ function(op_library TARGET)
   foreach(xpu_src ${xpu_cc_srcs})
     set(op_name "")
     find_register(${xpu_src} "REGISTER_OP_XPU_KERNEL" op_name)
+    find_phi_register(${xpu_src} ${pybind_file} "PD_REGISTER_STRUCT_KERNEL")
     if(NOT ${op_name} EQUAL "")
       file(APPEND ${pybind_file} "USE_OP_DEVICE_KERNEL(${op_name}, XPU);\n")
       set(pybind_flag 1)
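find_phi_register parallels the find_register call above it: it scans the XPU source for the new PD_REGISTER_STRUCT_KERNEL macro and, on a match, appends a matching kernel declaration to ${pybind_file}, so phi-registered kernels are still force-linked into the Python bindings just as the USE_OP_DEVICE_KERNEL line does for legacy registrations. (The exact line it appends is not shown in this hunk.)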
......
......
@@ -20,7 +20,7 @@ limitations under the License. */
 namespace paddle {
 namespace operators {
-template <typename DeviceContext, typename T>
+template <typename T, typename DeviceContext>
 class BeamSearchDecodeXPUKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext& context) const override {
......
@@ -111,13 +111,15 @@ class BeamSearchDecodeXPUKernel : public framework::OpKernel<T> {
 }  // namespace paddle
 
 namespace ops = paddle::operators;
-REGISTER_OP_XPU_KERNEL(
-    beam_search_decode,
-    ops::BeamSearchDecodeXPUKernel<paddle::platform::XPUDeviceContext, float>,
-    ops::BeamSearchDecodeXPUKernel<paddle::platform::XPUDeviceContext, double>,
-    ops::BeamSearchDecodeXPUKernel<paddle::platform::XPUDeviceContext,
-                                   paddle::platform::float16>,
-    ops::BeamSearchDecodeXPUKernel<paddle::platform::XPUDeviceContext, int>,
-    ops::BeamSearchDecodeXPUKernel<paddle::platform::XPUDeviceContext,
-                                   int64_t>);
+namespace plat = paddle::platform;
+PD_REGISTER_STRUCT_KERNEL(beam_search_decode,
+                          XPU,
+                          ALL_LAYOUT,
+                          ops::BeamSearchDecodeXPUKernel,
+                          float,
+                          double,
+                          plat::float16,
+                          int,
+                          int64_t) {}
 #endif
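A note on the trailing {} in PD_REGISTER_STRUCT_KERNEL: the macro expands to a function definition among its registration boilerplate, and the empty braces supply that function's body, which is why every call site ends with {} rather than a semicolon. The sketch below is an assumption about the shape of the expansion, with an illustrative name, not the actual macro output:

    // Hypothetical shape of part of the expansion: the user-written {}
    // becomes the body of a per-kernel hook that can customize the
    // registered kernel's metadata.
    void __pd_struct_kernel_hook_beam_search_decode_XPU(/* kernel metadata */) {}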
......
@@ -18,11 +18,12 @@ limitations under the License. */
 #include "paddle/fluid/operators/beam_search_op.h"
 
 namespace ops = paddle::operators;
-using XPUCtx = paddle::platform::XPUDeviceContext;
-
-REGISTER_OP_XPU_KERNEL(beam_search,
-                       ops::BeamSearchOpKernel<float, XPUCtx>,
-                       ops::BeamSearchOpKernel<double, XPUCtx>,
-                       ops::BeamSearchOpKernel<int, XPUCtx>,
-                       ops::BeamSearchOpKernel<int64_t, XPUCtx>);
+PD_REGISTER_STRUCT_KERNEL(beam_search,
+                          XPU,
+                          ALL_LAYOUT,
+                          ops::BeamSearchOpKernel,
+                          float,
+                          double,
+                          int,
+                          int64_t) {}
 #endif
......
@@ -22,7 +22,7 @@ limitations under the License. */
 namespace paddle {
 namespace operators {
-template <typename T>
+template <typename T, typename DeviceContext>
 class CAllGatherOpXPUKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext& ctx) const override {
......
@@ -80,9 +80,12 @@ class CAllGatherOpXPUKernel : public framework::OpKernel<T> {
 namespace ops = paddle::operators;
 namespace plat = paddle::platform;
 
-REGISTER_OP_XPU_KERNEL(c_allgather,
-                       ops::CAllGatherOpXPUKernel<float>,
-                       ops::CAllGatherOpXPUKernel<double>,
-                       ops::CAllGatherOpXPUKernel<int>,
-                       ops::CAllGatherOpXPUKernel<int64_t>,
-                       ops::CAllGatherOpXPUKernel<plat::float16>);
+PD_REGISTER_STRUCT_KERNEL(c_allgather,
+                          XPU,
+                          ALL_LAYOUT,
+                          ops::CAllGatherOpXPUKernel,
+                          float,
+                          double,
+                          plat::float16,
+                          int,
+                          int64_t) {}
......
@@ -14,10 +14,18 @@ limitations under the License. */
 #include "paddle/fluid/operators/collective/c_allreduce_op.h"
 
+namespace paddle {
+namespace operators {
+DEFINE_C_ALLREDUCE_XPU_KERNEL(CAllReduceMax, kRedMax)
+}  // namespace operators
+}  // namespace paddle
+
 namespace ops = paddle::operators;
 namespace plat = paddle::platform;
 
-REGISTER_OP_XPU_KERNEL(c_allreduce_max,
-                       ops::CAllReduceOpXPUKernel<ops::kRedMax, plat::float16>,
-                       ops::CAllReduceOpXPUKernel<ops::kRedMax, int>,
-                       ops::CAllReduceOpXPUKernel<ops::kRedMax, float>)
+PD_REGISTER_STRUCT_KERNEL(c_allreduce_max,
+                          XPU,
+                          ALL_LAYOUT,
+                          ops::CAllReduceMaxXPUKernel,
+                          float,
+                          int,
+                          plat::float16) {}
......
@@ -14,8 +14,18 @@ limitations under the License. */
 #include "paddle/fluid/operators/collective/c_allreduce_op.h"
 
+namespace paddle {
+namespace operators {
+DEFINE_C_ALLREDUCE_XPU_KERNEL(CAllReduceMin, kRedMin)
+}  // namespace operators
+}  // namespace paddle
+
 namespace ops = paddle::operators;
 namespace plat = paddle::platform;
 
-REGISTER_OP_XPU_KERNEL(c_allreduce_min,
-                       ops::CAllReduceOpXPUKernel<ops::kRedMin, float>)
+PD_REGISTER_STRUCT_KERNEL(c_allreduce_min,
+                          XPU,
+                          ALL_LAYOUT,
+                          ops::CAllReduceMinXPUKernel,
+                          float,
+                          int,
+                          plat::float16) {}
......
@@ -223,6 +223,10 @@ class CAllReduceOpXPUKernel : public framework::OpKernel<T> {
   }
 };
 
+#define DEFINE_C_ALLREDUCE_XPU_KERNEL(op_name, red_type)                   \
+  template <typename T, typename DeviceContext>                            \
+  class op_name##XPUKernel : public CAllReduceOpXPUKernel<red_type, T> {};
+
 template <ReduceType red_type, typename T>
 class CAllReduceOpCUDAKernel : public framework::OpKernel<T> {
  public:
......
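For reference, the three macro lines above expand, for one instantiation, to exactly this thin wrapper (written out here for illustration):

    // DEFINE_C_ALLREDUCE_XPU_KERNEL(CAllReduceMax, kRedMax) produces a class
    // that fixes the reduce type and exposes the <T, DeviceContext> template
    // signature that PD_REGISTER_STRUCT_KERNEL expects:
    template <typename T, typename DeviceContext>
    class CAllReduceMaxXPUKernel : public CAllReduceOpXPUKernel<kRedMax, T> {};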
......
@@ -14,8 +14,18 @@ limitations under the License. */
 #include "paddle/fluid/operators/collective/c_allreduce_op.h"
 
+namespace paddle {
+namespace operators {
+DEFINE_C_ALLREDUCE_XPU_KERNEL(CAllReduceProd, kRedProd)
+}  // namespace operators
+}  // namespace paddle
+
 namespace ops = paddle::operators;
 namespace plat = paddle::platform;
 
-REGISTER_OP_XPU_KERNEL(c_allreduce_prod,
-                       ops::CAllReduceOpXPUKernel<ops::kRedProd, float>)
+PD_REGISTER_STRUCT_KERNEL(c_allreduce_prod,
+                          XPU,
+                          ALL_LAYOUT,
+                          ops::CAllReduceProdXPUKernel,
+                          float,
+                          int,
+                          plat::float16) {}
......
@@ -14,10 +14,18 @@ limitations under the License. */
 #include "paddle/fluid/operators/collective/c_allreduce_op.h"
 
+namespace paddle {
+namespace operators {
+DEFINE_C_ALLREDUCE_XPU_KERNEL(CAllReduceSum, kRedSum)
+}  // namespace operators
+}  // namespace paddle
+
 namespace ops = paddle::operators;
 namespace plat = paddle::platform;
 
-REGISTER_OP_XPU_KERNEL(c_allreduce_sum,
-                       ops::CAllReduceOpXPUKernel<ops::kRedSum, float>,
-                       ops::CAllReduceOpXPUKernel<ops::kRedSum, plat::float16>,
-                       ops::CAllReduceOpXPUKernel<ops::kRedSum, int>)
+PD_REGISTER_STRUCT_KERNEL(c_allreduce_sum,
+                          XPU,
+                          ALL_LAYOUT,
+                          ops::CAllReduceSumXPUKernel,
+                          float,
+                          int,
+                          plat::float16) {}
......
@@ -22,7 +22,7 @@ limitations under the License. */
 namespace paddle {
 namespace operators {
-template <typename T>
+template <typename T, typename DeviceContext>
 class CBroadcastOpXPUKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext& ctx) const override {
......
@@ -113,6 +113,9 @@ class CBroadcastOpXPUKernel : public framework::OpKernel<T> {
 namespace ops = paddle::operators;
 namespace plat = paddle::platform;
 
-REGISTER_OP_XPU_KERNEL(c_broadcast,
-                       ops::CBroadcastOpXPUKernel<float>,
-                       ops::CBroadcastOpXPUKernel<plat::float16>);
+PD_REGISTER_STRUCT_KERNEL(c_broadcast,
+                          XPU,
+                          ALL_LAYOUT,
+                          ops::CBroadcastOpXPUKernel,
+                          float,
+                          plat::float16) {}
......
@@ -28,7 +28,7 @@ limitations under the License. */
 namespace paddle {
 namespace operators {
-template <typename T>
+template <typename T, typename DeviceContext>
 class CConcatOpXPUKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext& ctx) const override {
......
@@ -118,6 +118,5 @@ class CConcatOpXPUKernel : public framework::OpKernel<T> {
 namespace ops = paddle::operators;
 namespace plat = paddle::platform;
 
-REGISTER_OP_XPU_KERNEL(c_concat,
-                       ops::CConcatOpXPUKernel<float>,
-                       ops::CConcatOpXPUKernel<plat::float16>);
+PD_REGISTER_STRUCT_KERNEL(
+    c_concat, XPU, ALL_LAYOUT, ops::CConcatOpXPUKernel, float, plat::float16) {}
......
@@ -15,7 +15,7 @@ limitations under the License. */
 namespace paddle {
 namespace operators {
-template <typename DeviceContext, typename T>
+template <typename T, typename DeviceContext>
 class CEmbeddingOpXPUKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext& ctx) const override {
......
@@ -68,7 +68,7 @@ class CEmbeddingOpXPUKernel : public framework::OpKernel<T> {
   }
 };
 
-template <typename DeviceContext, typename T>
+template <typename T, typename DeviceContext>
 class CEmbeddingGradOpXPUKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext& context) const override {
......
@@ -140,9 +140,7 @@ class CEmbeddingGradOpXPUKernel : public framework::OpKernel<T> {
 namespace ops = paddle::operators;
 namespace plat = paddle::platform;
 
-REGISTER_OP_XPU_KERNEL(
-    c_embedding,
-    ops::CEmbeddingOpXPUKernel<paddle::platform::XPUDeviceContext, float>);
-REGISTER_OP_XPU_KERNEL(
-    c_embedding_grad,
-    ops::CEmbeddingGradOpXPUKernel<paddle::platform::XPUDeviceContext, float>);
+PD_REGISTER_STRUCT_KERNEL(
+    c_embedding, XPU, ALL_LAYOUT, ops::CEmbeddingOpXPUKernel, float) {}
+PD_REGISTER_STRUCT_KERNEL(
+    c_embedding_grad, XPU, ALL_LAYOUT, ops::CEmbeddingGradOpXPUKernel, float) {}
......
@@ -14,9 +14,12 @@ limitations under the License. */
 namespace ops = paddle::operators;
 namespace plat = paddle::platform;
 
-REGISTER_OP_XPU_KERNEL(c_identity,
-                       ops::CIdentityOpKernel<float, plat::XPUPlace>,
-                       ops::CIdentityOpKernel<double, plat::XPUPlace>,
-                       ops::CIdentityOpKernel<int, plat::XPUPlace>,
-                       ops::CIdentityOpKernel<int64_t, plat::XPUPlace>,
-                       ops::CIdentityOpKernel<plat::float16, plat::XPUPlace>);
+PD_REGISTER_STRUCT_KERNEL(c_identity,
+                          XPU,
+                          ALL_LAYOUT,
+                          ops::CIdentityOpKernel,
+                          float,
+                          double,
+                          int,
+                          int64_t,
+                          plat::float16) {}
......
@@ -14,8 +14,14 @@ limitations under the License. */
 #include "paddle/fluid/operators/collective/c_reduce_op.h"
 
+namespace paddle {
+namespace operators {
+DEFINE_C_REDUCE_XPU_KERNEL(CReduceMax, kRedMax);
+}  // namespace operators
+}  // namespace paddle
+
 namespace ops = paddle::operators;
 namespace plat = paddle::platform;
 
-REGISTER_OP_XPU_KERNEL(c_reduce_max,
-                       ops::CReduceOpXPUKernel<ops::kRedMax, float>)
+PD_REGISTER_STRUCT_KERNEL(
+    c_reduce_max, XPU, ALL_LAYOUT, ops::CReduceMaxXPUKernel, float) {}
......
@@ -14,8 +14,14 @@ limitations under the License. */
 #include "paddle/fluid/operators/collective/c_reduce_op.h"
 
+namespace paddle {
+namespace operators {
+DEFINE_C_REDUCE_XPU_KERNEL(CReduceMin, kRedMin);
+}  // namespace operators
+}  // namespace paddle
+
 namespace ops = paddle::operators;
 namespace plat = paddle::platform;
 
-REGISTER_OP_XPU_KERNEL(c_reduce_min,
-                       ops::CReduceOpXPUKernel<ops::kRedMin, float>)
+PD_REGISTER_STRUCT_KERNEL(
+    c_reduce_min, XPU, ALL_LAYOUT, ops::CReduceMinXPUKernel, float) {}
......
@@ -198,6 +198,10 @@ class CReduceOpXPUKernel : public framework::OpKernel<T> {
   }
 };
 
+#define DEFINE_C_REDUCE_XPU_KERNEL(op_name, red_type)                   \
+  template <typename T, typename DeviceContext>                         \
+  class op_name##XPUKernel : public CReduceOpXPUKernel<red_type, T> {};
+
 template <ReduceType red_type, typename T>
 class CReduceOpCUDAKernel : public framework::OpKernel<T> {
  public:
......
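DEFINE_C_REDUCE_XPU_KERNEL mirrors DEFINE_C_ALLREDUCE_XPU_KERNEL above: for example, DEFINE_C_REDUCE_XPU_KERNEL(CReduceMax, kRedMax); declares a CReduceMaxXPUKernel<T, DeviceContext> wrapper deriving from CReduceOpXPUKernel<kRedMax, T>, which the per-op .cc files below then register via PD_REGISTER_STRUCT_KERNEL.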
......
@@ -14,8 +14,14 @@ limitations under the License. */
 #include "paddle/fluid/operators/collective/c_reduce_op.h"
 
+namespace paddle {
+namespace operators {
+DEFINE_C_REDUCE_XPU_KERNEL(CReduceProd, kRedProd);
+}  // namespace operators
+}  // namespace paddle
+
 namespace ops = paddle::operators;
 namespace plat = paddle::platform;
 
-REGISTER_OP_XPU_KERNEL(c_reduce_prod,
-                       ops::CReduceOpXPUKernel<ops::kRedProd, float>)
+PD_REGISTER_STRUCT_KERNEL(
+    c_reduce_prod, XPU, ALL_LAYOUT, ops::CReduceProdXPUKernel, float) {}
......
@@ -14,9 +14,14 @@ limitations under the License. */
 #include "paddle/fluid/operators/collective/c_reduce_op.h"
 
+namespace paddle {
+namespace operators {
+DEFINE_C_REDUCE_XPU_KERNEL(CReduceSum, kRedSum);
+}  // namespace operators
+}  // namespace paddle
+
 namespace ops = paddle::operators;
 namespace plat = paddle::platform;
 
-REGISTER_OP_XPU_KERNEL(c_reduce_sum,
-                       ops::CReduceOpXPUKernel<ops::kRedSum, plat::float16>,
-                       ops::CReduceOpXPUKernel<ops::kRedSum, float>)
+PD_REGISTER_STRUCT_KERNEL(
+    c_reduce_sum, XPU, ALL_LAYOUT, ops::CReduceSumXPUKernel, float) {}
......
@@ -14,10 +14,18 @@
 #include "paddle/fluid/operators/collective/c_allreduce_op.h"
 
+namespace paddle {
+namespace operators {
+DEFINE_C_ALLREDUCE_XPU_KERNEL(CAllReduceSum, kRedSum)
+}  // namespace operators
+}  // namespace paddle
+
 namespace ops = paddle::operators;
 namespace plat = paddle::platform;
 
-REGISTER_OP_XPU_KERNEL(mp_allreduce_sum,
-                       ops::CAllReduceOpXPUKernel<ops::kRedSum, float>,
-                       ops::CAllReduceOpXPUKernel<ops::kRedSum, plat::float16>,
-                       ops::CAllReduceOpXPUKernel<ops::kRedSum, int>)
+PD_REGISTER_STRUCT_KERNEL(mp_allreduce_sum,
+                          XPU,
+                          ALL_LAYOUT,
+                          ops::CAllReduceSumXPUKernel,
+                          float,
+                          int,
+                          plat::float16) {}