Unverified commit f6f48780, authored by huangjiyi, committed by GitHub

Register fluid xpu kernels to phi [part 1] (#53187)

* update

* fix bug

* Revert "affine_channel_op"
Parent c1a61fc0
@@ -470,6 +470,7 @@ function(op_library TARGET)
     foreach(xpu_src ${xpu_cc_srcs})
       set(op_name "")
       find_register(${xpu_src} "REGISTER_OP_XPU_KERNEL" op_name)
+      find_phi_register(${xpu_src} ${pybind_file} "PD_REGISTER_STRUCT_KERNEL")
       if(NOT ${op_name} EQUAL "")
         file(APPEND ${pybind_file} "USE_OP_DEVICE_KERNEL(${op_name}, XPU);\n")
         set(pybind_flag 1)
...
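(A hedged aside, assuming find_phi_register behaves like the find_register helper above: it scans each XPU source for PD_REGISTER_STRUCT_KERNEL registrations and appends the matching kernel declaration directly to ${pybind_file}, so operators migrated off REGISTER_OP_XPU_KERNEL stay visible to the pybind layer. The exact line it emits depends on its definition elsewhere in this cmake file.)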
@@ -20,7 +20,7 @@ limitations under the License. */

 namespace paddle {
 namespace operators {
-template <typename DeviceContext, typename T>
+template <typename T, typename DeviceContext>
 class BeamSearchDecodeXPUKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext& context) const override {
@@ -111,13 +111,15 @@ class BeamSearchDecodeXPUKernel : public framework::OpKernel<T> {
 }  // namespace paddle

 namespace ops = paddle::operators;
-REGISTER_OP_XPU_KERNEL(
-    beam_search_decode,
-    ops::BeamSearchDecodeXPUKernel<paddle::platform::XPUDeviceContext, float>,
-    ops::BeamSearchDecodeXPUKernel<paddle::platform::XPUDeviceContext, double>,
-    ops::BeamSearchDecodeXPUKernel<paddle::platform::XPUDeviceContext,
-                                   paddle::platform::float16>,
-    ops::BeamSearchDecodeXPUKernel<paddle::platform::XPUDeviceContext, int>,
-    ops::BeamSearchDecodeXPUKernel<paddle::platform::XPUDeviceContext,
-                                   int64_t>);
+namespace plat = paddle::platform;
+
+PD_REGISTER_STRUCT_KERNEL(beam_search_decode,
+                          XPU,
+                          ALL_LAYOUT,
+                          ops::BeamSearchDecodeXPUKernel,
+                          float,
+                          double,
+                          plat::float16,
+                          int,
+                          int64_t) {}
 #endif
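Every C++ hunk below repeats the same mechanical migration: the kernel template swaps its parameters from <DeviceContext, T> to <T, DeviceContext>, and the per-dtype REGISTER_OP_XPU_KERNEL instantiation list collapses into a single PD_REGISTER_STRUCT_KERNEL(op, backend, layout, kernel, dtypes...) {} call. A minimal sketch of the pattern (FooXPUKernel and the foo op are illustrative placeholders, not part of this commit):

// Before: one explicit template instantiation per supported dtype.
// REGISTER_OP_XPU_KERNEL(
//     foo,
//     ops::FooXPUKernel<paddle::platform::XPUDeviceContext, float>,
//     ops::FooXPUKernel<paddle::platform::XPUDeviceContext, int>);

// After: the kernel is declared as template <typename T, typename DeviceContext>
// and registered once, with the dtype list trailing the kernel name.
PD_REGISTER_STRUCT_KERNEL(foo, XPU, ALL_LAYOUT, ops::FooXPUKernel, float, int) {}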
@@ -18,11 +18,12 @@ limitations under the License. */
 #include "paddle/fluid/operators/beam_search_op.h"

 namespace ops = paddle::operators;
-using XPUCtx = paddle::platform::XPUDeviceContext;
-
-REGISTER_OP_XPU_KERNEL(beam_search,
-                       ops::BeamSearchOpKernel<float, XPUCtx>,
-                       ops::BeamSearchOpKernel<double, XPUCtx>,
-                       ops::BeamSearchOpKernel<int, XPUCtx>,
-                       ops::BeamSearchOpKernel<int64_t, XPUCtx>);
+
+PD_REGISTER_STRUCT_KERNEL(beam_search,
+                          XPU,
+                          ALL_LAYOUT,
+                          ops::BeamSearchOpKernel,
+                          float,
+                          double,
+                          int,
+                          int64_t) {}
 #endif
@@ -22,7 +22,7 @@ limitations under the License. */
 namespace paddle {
 namespace operators {

-template <typename T>
+template <typename T, typename DeviceContext>
 class CAllGatherOpXPUKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext& ctx) const override {
@@ -80,9 +80,12 @@ class CAllGatherOpXPUKernel : public framework::OpKernel<T> {
 namespace ops = paddle::operators;
 namespace plat = paddle::platform;

-REGISTER_OP_XPU_KERNEL(c_allgather,
-                       ops::CAllGatherOpXPUKernel<float>,
-                       ops::CAllGatherOpXPUKernel<double>,
-                       ops::CAllGatherOpXPUKernel<int>,
-                       ops::CAllGatherOpXPUKernel<int64_t>,
-                       ops::CAllGatherOpXPUKernel<plat::float16>);
+PD_REGISTER_STRUCT_KERNEL(c_allgather,
+                          XPU,
+                          ALL_LAYOUT,
+                          ops::CAllGatherOpXPUKernel,
+                          float,
+                          double,
+                          plat::float16,
+                          int,
+                          int64_t) {}
@@ -14,10 +14,18 @@ limitations under the License. */
 #include "paddle/fluid/operators/collective/c_allreduce_op.h"

+namespace paddle {
+namespace operators {
+DEFINE_C_ALLREDUCE_XPU_KERNEL(CAllReduceMax, kRedMax)
+}  // namespace operators
+}  // namespace paddle
+
 namespace ops = paddle::operators;
 namespace plat = paddle::platform;

-REGISTER_OP_XPU_KERNEL(c_allreduce_max,
-                       ops::CAllReduceOpXPUKernel<ops::kRedMax, plat::float16>,
-                       ops::CAllReduceOpXPUKernel<ops::kRedMax, int>,
-                       ops::CAllReduceOpXPUKernel<ops::kRedMax, float>)
+PD_REGISTER_STRUCT_KERNEL(c_allreduce_max,
+                          XPU,
+                          ALL_LAYOUT,
+                          ops::CAllReduceMaxXPUKernel,
+                          float,
+                          int,
+                          plat::float16) {}
@@ -14,8 +14,18 @@ limitations under the License. */
 #include "paddle/fluid/operators/collective/c_allreduce_op.h"

+namespace paddle {
+namespace operators {
+DEFINE_C_ALLREDUCE_XPU_KERNEL(CAllReduceMin, kRedMin)
+}  // namespace operators
+}  // namespace paddle
+
 namespace ops = paddle::operators;
 namespace plat = paddle::platform;

-REGISTER_OP_XPU_KERNEL(c_allreduce_min,
-                       ops::CAllReduceOpXPUKernel<ops::kRedMin, float>)
+PD_REGISTER_STRUCT_KERNEL(c_allreduce_min,
+                          XPU,
+                          ALL_LAYOUT,
+                          ops::CAllReduceMinXPUKernel,
+                          float,
+                          int,
+                          plat::float16) {}
@@ -223,6 +223,10 @@ class CAllReduceOpXPUKernel : public framework::OpKernel<T> {
   }
 };

+#define DEFINE_C_ALLREDUCE_XPU_KERNEL(op_name, red_type)                   \
+  template <typename T, typename DeviceContext>                            \
+  class op_name##XPUKernel : public CAllReduceOpXPUKernel<red_type, T> {};
+
 template <ReduceType red_type, typename T>
 class CAllReduceOpCUDAKernel : public framework::OpKernel<T> {
  public:
...
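For readers unfamiliar with the macro, hand-expanding DEFINE_C_ALLREDUCE_XPU_KERNEL(CAllReduceMax, kRedMax) shows what it buys: a thin <T, DeviceContext> wrapper that PD_REGISTER_STRUCT_KERNEL can name directly, while the reduce type stays baked in:

// Hand-expanded form of DEFINE_C_ALLREDUCE_XPU_KERNEL(CAllReduceMax, kRedMax).
template <typename T, typename DeviceContext>
class CAllReduceMaxXPUKernel : public CAllReduceOpXPUKernel<kRedMax, T> {};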
@@ -14,8 +14,18 @@ limitations under the License. */
 #include "paddle/fluid/operators/collective/c_allreduce_op.h"

+namespace paddle {
+namespace operators {
+DEFINE_C_ALLREDUCE_XPU_KERNEL(CAllReduceProd, kRedProd)
+}  // namespace operators
+}  // namespace paddle
+
 namespace ops = paddle::operators;
 namespace plat = paddle::platform;

-REGISTER_OP_XPU_KERNEL(c_allreduce_prod,
-                       ops::CAllReduceOpXPUKernel<ops::kRedProd, float>)
+PD_REGISTER_STRUCT_KERNEL(c_allreduce_prod,
+                          XPU,
+                          ALL_LAYOUT,
+                          ops::CAllReduceProdXPUKernel,
+                          float,
+                          int,
+                          plat::float16) {}
@@ -14,10 +14,18 @@ limitations under the License. */
 #include "paddle/fluid/operators/collective/c_allreduce_op.h"

+namespace paddle {
+namespace operators {
+DEFINE_C_ALLREDUCE_XPU_KERNEL(CAllReduceSum, kRedSum)
+}  // namespace operators
+}  // namespace paddle
+
 namespace ops = paddle::operators;
 namespace plat = paddle::platform;

-REGISTER_OP_XPU_KERNEL(c_allreduce_sum,
-                       ops::CAllReduceOpXPUKernel<ops::kRedSum, float>,
-                       ops::CAllReduceOpXPUKernel<ops::kRedSum, plat::float16>,
-                       ops::CAllReduceOpXPUKernel<ops::kRedSum, int>)
+PD_REGISTER_STRUCT_KERNEL(c_allreduce_sum,
+                          XPU,
+                          ALL_LAYOUT,
+                          ops::CAllReduceSumXPUKernel,
+                          float,
+                          int,
+                          plat::float16) {}
@@ -22,7 +22,7 @@ limitations under the License. */
 namespace paddle {
 namespace operators {

-template <typename T>
+template <typename T, typename DeviceContext>
 class CBroadcastOpXPUKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext& ctx) const override {
@@ -113,6 +113,9 @@ class CBroadcastOpXPUKernel : public framework::OpKernel<T> {
 namespace ops = paddle::operators;
 namespace plat = paddle::platform;

-REGISTER_OP_XPU_KERNEL(c_broadcast,
-                       ops::CBroadcastOpXPUKernel<float>,
-                       ops::CBroadcastOpXPUKernel<plat::float16>);
+PD_REGISTER_STRUCT_KERNEL(c_broadcast,
+                          XPU,
+                          ALL_LAYOUT,
+                          ops::CBroadcastOpXPUKernel,
+                          float,
+                          plat::float16) {}
@@ -28,7 +28,7 @@ limitations under the License. */
 namespace paddle {
 namespace operators {

-template <typename T>
+template <typename T, typename DeviceContext>
 class CConcatOpXPUKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext& ctx) const override {
@@ -118,6 +118,5 @@ class CConcatOpXPUKernel : public framework::OpKernel<T> {
 namespace ops = paddle::operators;
 namespace plat = paddle::platform;

-REGISTER_OP_XPU_KERNEL(c_concat,
-                       ops::CConcatOpXPUKernel<float>,
-                       ops::CConcatOpXPUKernel<plat::float16>);
+PD_REGISTER_STRUCT_KERNEL(
+    c_concat, XPU, ALL_LAYOUT, ops::CConcatOpXPUKernel, float, plat::float16) {}
@@ -15,7 +15,7 @@ limitations under the License. */
 namespace paddle {
 namespace operators {

-template <typename DeviceContext, typename T>
+template <typename T, typename DeviceContext>
 class CEmbeddingOpXPUKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext& ctx) const override {
@@ -68,7 +68,7 @@ class CEmbeddingOpXPUKernel : public framework::OpKernel<T> {
   }
 };

-template <typename DeviceContext, typename T>
+template <typename T, typename DeviceContext>
 class CEmbeddingGradOpXPUKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext& context) const override {
@@ -140,9 +140,7 @@ class CEmbeddingGradOpXPUKernel : public framework::OpKernel<T> {
 namespace ops = paddle::operators;
 namespace plat = paddle::platform;

-REGISTER_OP_XPU_KERNEL(
-    c_embedding,
-    ops::CEmbeddingOpXPUKernel<paddle::platform::XPUDeviceContext, float>);
-REGISTER_OP_XPU_KERNEL(
-    c_embedding_grad,
-    ops::CEmbeddingGradOpXPUKernel<paddle::platform::XPUDeviceContext, float>);
+PD_REGISTER_STRUCT_KERNEL(
+    c_embedding, XPU, ALL_LAYOUT, ops::CEmbeddingOpXPUKernel, float) {}
+PD_REGISTER_STRUCT_KERNEL(
+    c_embedding_grad, XPU, ALL_LAYOUT, ops::CEmbeddingGradOpXPUKernel, float) {}
@@ -14,9 +14,12 @@ limitations under the License. */
 namespace ops = paddle::operators;
 namespace plat = paddle::platform;

-REGISTER_OP_XPU_KERNEL(c_identity,
-                       ops::CIdentityOpKernel<float, plat::XPUPlace>,
-                       ops::CIdentityOpKernel<double, plat::XPUPlace>,
-                       ops::CIdentityOpKernel<int, plat::XPUPlace>,
-                       ops::CIdentityOpKernel<int64_t, plat::XPUPlace>,
-                       ops::CIdentityOpKernel<plat::float16, plat::XPUPlace>);
+PD_REGISTER_STRUCT_KERNEL(c_identity,
+                          XPU,
+                          ALL_LAYOUT,
+                          ops::CIdentityOpKernel,
+                          float,
+                          double,
+                          int,
+                          int64_t,
+                          plat::float16) {}
@@ -14,8 +14,14 @@ limitations under the License. */
 #include "paddle/fluid/operators/collective/c_reduce_op.h"

+namespace paddle {
+namespace operators {
+DEFINE_C_REDUCE_XPU_KERNEL(CReduceMax, kRedMax);
+}  // namespace operators
+}  // namespace paddle
+
 namespace ops = paddle::operators;
 namespace plat = paddle::platform;

-REGISTER_OP_XPU_KERNEL(c_reduce_max,
-                       ops::CReduceOpXPUKernel<ops::kRedMax, float>)
+PD_REGISTER_STRUCT_KERNEL(
+    c_reduce_max, XPU, ALL_LAYOUT, ops::CReduceMaxXPUKernel, float) {}
@@ -14,8 +14,14 @@ limitations under the License. */
 #include "paddle/fluid/operators/collective/c_reduce_op.h"

+namespace paddle {
+namespace operators {
+DEFINE_C_REDUCE_XPU_KERNEL(CReduceMin, kRedMin);
+}  // namespace operators
+}  // namespace paddle
+
 namespace ops = paddle::operators;
 namespace plat = paddle::platform;

-REGISTER_OP_XPU_KERNEL(c_reduce_min,
-                       ops::CReduceOpXPUKernel<ops::kRedMin, float>)
+PD_REGISTER_STRUCT_KERNEL(
+    c_reduce_min, XPU, ALL_LAYOUT, ops::CReduceMinXPUKernel, float) {}
@@ -198,6 +198,10 @@ class CReduceOpXPUKernel : public framework::OpKernel<T> {
   }
 };

+#define DEFINE_C_REDUCE_XPU_KERNEL(op_name, red_type)                   \
+  template <typename T, typename DeviceContext>                         \
+  class op_name##XPUKernel : public CReduceOpXPUKernel<red_type, T> {};
+
 template <ReduceType red_type, typename T>
 class CReduceOpCUDAKernel : public framework::OpKernel<T> {
  public:
...
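Same trick as in c_allreduce_op.h; for example, DEFINE_C_REDUCE_XPU_KERNEL(CReduceMax, kRedMax) expands to:

// Hand-expanded form of DEFINE_C_REDUCE_XPU_KERNEL(CReduceMax, kRedMax);
// this is the class the c_reduce_max registration above refers to.
template <typename T, typename DeviceContext>
class CReduceMaxXPUKernel : public CReduceOpXPUKernel<kRedMax, T> {};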
@@ -14,8 +14,14 @@ limitations under the License. */
 #include "paddle/fluid/operators/collective/c_reduce_op.h"

+namespace paddle {
+namespace operators {
+DEFINE_C_REDUCE_XPU_KERNEL(CReduceProd, kRedProd);
+}  // namespace operators
+}  // namespace paddle
+
 namespace ops = paddle::operators;
 namespace plat = paddle::platform;

-REGISTER_OP_XPU_KERNEL(c_reduce_prod,
-                       ops::CReduceOpXPUKernel<ops::kRedProd, float>)
+PD_REGISTER_STRUCT_KERNEL(
+    c_reduce_prod, XPU, ALL_LAYOUT, ops::CReduceProdXPUKernel, float) {}
@@ -14,9 +14,14 @@ limitations under the License. */
 #include "paddle/fluid/operators/collective/c_reduce_op.h"

+namespace paddle {
+namespace operators {
+DEFINE_C_REDUCE_XPU_KERNEL(CReduceSum, kRedSum);
+}  // namespace operators
+}  // namespace paddle
+
 namespace ops = paddle::operators;
 namespace plat = paddle::platform;

-REGISTER_OP_XPU_KERNEL(c_reduce_sum,
-                       ops::CReduceOpXPUKernel<ops::kRedSum, plat::float16>,
-                       ops::CReduceOpXPUKernel<ops::kRedSum, float>)
+PD_REGISTER_STRUCT_KERNEL(
+    c_reduce_sum, XPU, ALL_LAYOUT, ops::CReduceSumXPUKernel, float) {}
@@ -14,10 +14,18 @@
 #include "paddle/fluid/operators/collective/c_allreduce_op.h"

+namespace paddle {
+namespace operators {
+DEFINE_C_ALLREDUCE_XPU_KERNEL(CAllReduceSum, kRedSum)
+}  // namespace operators
+}  // namespace paddle
+
 namespace ops = paddle::operators;
 namespace plat = paddle::platform;

-REGISTER_OP_XPU_KERNEL(mp_allreduce_sum,
-                       ops::CAllReduceOpXPUKernel<ops::kRedSum, float>,
-                       ops::CAllReduceOpXPUKernel<ops::kRedSum, plat::float16>,
-                       ops::CAllReduceOpXPUKernel<ops::kRedSum, int>)
+PD_REGISTER_STRUCT_KERNEL(mp_allreduce_sum,
+                          XPU,
+                          ALL_LAYOUT,
+                          ops::CAllReduceSumXPUKernel,
+                          float,
+                          int,
+                          plat::float16) {}