Unverified commit eee9c788, authored by H huangjiyi, committed by GitHub

Register fluid xpu kernels to phi [part 2] (#53188)

* update

* fix bug
Parent 166964b1
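
Summary of the recurring change in this commit: each fluid XPU kernel's class template swaps its parameters from <DeviceContext, T> to <T, DeviceContext>, and its registration moves from the fluid macro REGISTER_OP_XPU_KERNEL (which lists one fully instantiated class per dtype) to phi's PD_REGISTER_STRUCT_KERNEL (which takes the backend, layout, the uninstantiated kernel class template, and a dtype list, followed by a body that is left empty throughout this diff). A minimal before/after sketch; my_op and MyOpXPUKernel are illustrative placeholders, not names from this diff:

    // Before: fluid-style registration; one fully instantiated kernel
    // class is listed per supported dtype.
    template <typename DeviceContext, typename T>
    class MyOpXPUKernel : public framework::OpKernel<T> { /* ... */ };

    REGISTER_OP_XPU_KERNEL(my_op,
                           ops::MyOpXPUKernel<phi::XPUContext, float>);

    // After: phi-style registration; the class template is passed
    // uninstantiated, and backend, layout, and dtypes are explicit.
    template <typename T, typename DeviceContext>
    class MyOpXPUKernel : public framework::OpKernel<T> { /* ... */ };

    PD_REGISTER_STRUCT_KERNEL(
        my_op, XPU, ALL_LAYOUT, ops::MyOpXPUKernel, float) {}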
@@ -29,7 +29,7 @@ limitations under the License. */
 namespace paddle {
 namespace operators {

-template <typename DeviceContext, typename T>
+template <typename T, typename DeviceContext>
 class CSoftmaxWithCrossEntropyOp : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext& ctx) const override {
@@ -468,7 +468,7 @@ struct CSoftmaxWithCrossEntropyFunctor<phi::XPUContext, T> {
   }
 };

-template <typename DeviceContext, typename T>
+template <typename T, typename DeviceContext>
 class CSoftmaxWithCrossEntropyGrad : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext& context) const override {
@@ -535,9 +535,13 @@ class CSoftmaxWithCrossEntropyGrad : public framework::OpKernel<T> {
 namespace ops = paddle::operators;
 namespace plat = paddle::platform;

-REGISTER_OP_XPU_KERNEL(c_softmax_with_cross_entropy,
-                       ops::CSoftmaxWithCrossEntropyOp<phi::XPUContext, float>);
-
-REGISTER_OP_XPU_KERNEL(
-    c_softmax_with_cross_entropy_grad,
-    ops::CSoftmaxWithCrossEntropyGrad<phi::XPUContext, float>);
+PD_REGISTER_STRUCT_KERNEL(c_softmax_with_cross_entropy,
+                          XPU,
+                          ALL_LAYOUT,
+                          ops::CSoftmaxWithCrossEntropyOp,
+                          float) {}
+PD_REGISTER_STRUCT_KERNEL(c_softmax_with_cross_entropy_grad,
+                          XPU,
+                          ALL_LAYOUT,
+                          ops::CSoftmaxWithCrossEntropyGrad,
+                          float) {}
@@ -22,7 +22,7 @@ limitations under the License. */
 namespace paddle {
 namespace operators {

-template <typename T>
+template <typename T, typename DeviceContext>
 class CSplitOpXPUKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext& ctx) const override {
@@ -87,7 +87,10 @@ class CSplitOpXPUKernel : public framework::OpKernel<T> {
 namespace ops = paddle::operators;
 namespace plat = paddle::platform;

-REGISTER_OP_XPU_KERNEL(c_split,
-                       ops::CSplitOpXPUKernel<float>,
-                       ops::CSplitOpXPUKernel<int>,
-                       ops::CSplitOpXPUKernel<plat::float16>);
+PD_REGISTER_STRUCT_KERNEL(c_split,
+                          XPU,
+                          ALL_LAYOUT,
+                          ops::CSplitOpXPUKernel,
+                          float,
+                          int,
+                          plat::float16) {}
@@ -17,5 +17,5 @@ limitations under the License. */
 namespace ops = paddle::operators;
 namespace plat = paddle::platform;

-REGISTER_OP_XPU_KERNEL(c_sync_calc_stream,
-                       ops::CSyncCalcStreamKernel<float, plat::XPUPlace>)
+PD_REGISTER_STRUCT_KERNEL(
+    c_sync_calc_stream, XPU, ALL_LAYOUT, ops::CSyncCalcStreamKernel, float) {}
@@ -17,5 +17,5 @@ limitations under the License. */
 namespace ops = paddle::operators;
 namespace plat = paddle::platform;

-REGISTER_OP_XPU_KERNEL(c_sync_comm_stream,
-                       ops::CSyncCommStreamKernel<float, plat::XPUPlace>);
+PD_REGISTER_STRUCT_KERNEL(
+    c_sync_comm_stream, XPU, ALL_LAYOUT, ops::CSyncCommStreamKernel, float) {}
@@ -156,7 +156,12 @@ class BinaryLogicalOpXPUKernel : public framework::OpKernel<T> {
   }
 };

-template <typename T>
+#define DEFINE_BINARY_LOGICAL_OP_XPU_KERNEL(op_name, xpu_type) \
+  template <typename T, typename DeviceContext>                \
+  class BinaryLogical##op_name##CPUKernel                      \
+      : public BinaryLogicalOpXPUKernel<xpu_type, T> {};
+
+template <typename T, typename DeviceContext>
 class UnaryLogicalOpXPUKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext& context) const override {
...
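
For reference, the new DEFINE_BINARY_LOGICAL_OP_XPU_KERNEL helper relies on standard ## token pasting: the line DEFINE_BINARY_LOGICAL_OP_XPU_KERNEL(AND, XpuLogicalType::XPU_AND); used for the logical_and registration below expands to roughly:

    template <typename T, typename DeviceContext>
    class BinaryLogicalANDCPUKernel
        : public BinaryLogicalOpXPUKernel<XpuLogicalType::XPU_AND, T> {};

This produces the ops::BinaryLogicalANDCPUKernel name that PD_REGISTER_STRUCT_KERNEL then instantiates once per listed dtype.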
@@ -14,14 +14,23 @@ limitations under the License. */
 #ifdef PADDLE_WITH_XPU

 #include "paddle/fluid/operators/controlflow/logical_op_xpu.h"

+namespace paddle {
+namespace operators {
+DEFINE_BINARY_LOGICAL_OP_XPU_KERNEL(AND, XpuLogicalType::XPU_AND);
+}  // namespace operators
+}  // namespace paddle
+
 namespace ops = paddle::operators;
-REGISTER_OP_XPU_KERNEL(
-    logical_and,
-    ops::BinaryLogicalOpXPUKernel<ops::XpuLogicalType::XPU_AND, bool>,
-    ops::BinaryLogicalOpXPUKernel<ops::XpuLogicalType::XPU_AND, int8_t>,
-    ops::BinaryLogicalOpXPUKernel<ops::XpuLogicalType::XPU_AND, int16_t>,
-    ops::BinaryLogicalOpXPUKernel<ops::XpuLogicalType::XPU_AND, int>,
-    ops::BinaryLogicalOpXPUKernel<ops::XpuLogicalType::XPU_AND, int64_t>,
-    ops::BinaryLogicalOpXPUKernel<ops::XpuLogicalType::XPU_AND, float>,
-    ops::BinaryLogicalOpXPUKernel<ops::XpuLogicalType::XPU_AND, double>);
+PD_REGISTER_STRUCT_KERNEL(logical_and,
+                          XPU,
+                          ALL_LAYOUT,
+                          ops::BinaryLogicalANDCPUKernel,
+                          bool,
+                          int8_t,
+                          int16_t,
+                          int,
+                          int64_t,
+                          float,
+                          double) {}
 #endif
@@ -15,12 +15,15 @@ limitations under the License. */
 #ifdef PADDLE_WITH_XPU

 #include "paddle/fluid/operators/controlflow/logical_op_xpu.h"

 namespace ops = paddle::operators;
-REGISTER_OP_XPU_KERNEL(logicalnot,
-                       ops::UnaryLogicalOpXPUKernel<bool>,
-                       ops::UnaryLogicalOpXPUKernel<int8_t>,
-                       ops::UnaryLogicalOpXPUKernel<int16_t>,
-                       ops::UnaryLogicalOpXPUKernel<int>,
-                       ops::UnaryLogicalOpXPUKernel<int64_t>,
-                       ops::UnaryLogicalOpXPUKernel<float>,
-                       ops::UnaryLogicalOpXPUKernel<double>);
+PD_REGISTER_STRUCT_KERNEL(logicalnot,
+                          XPU,
+                          ALL_LAYOUT,
+                          ops::UnaryLogicalOpXPUKernel,
+                          bool,
+                          int8_t,
+                          int16_t,
+                          int,
+                          int64_t,
+                          float,
+                          double) {}
 #endif
@@ -15,14 +15,22 @@ limitations under the License. */
 #ifdef PADDLE_WITH_XPU

 #include "paddle/fluid/operators/controlflow/logical_op_xpu.h"

+namespace paddle {
+namespace operators {
+DEFINE_BINARY_LOGICAL_OP_XPU_KERNEL(OR, XpuLogicalType::XPU_OR);
+}  // namespace operators
+}  // namespace paddle
+
 namespace ops = paddle::operators;
-REGISTER_OP_XPU_KERNEL(
-    logical_or,
-    ops::BinaryLogicalOpXPUKernel<ops::XpuLogicalType::XPU_OR, bool>,
-    ops::BinaryLogicalOpXPUKernel<ops::XpuLogicalType::XPU_OR, int8_t>,
-    ops::BinaryLogicalOpXPUKernel<ops::XpuLogicalType::XPU_OR, int16_t>,
-    ops::BinaryLogicalOpXPUKernel<ops::XpuLogicalType::XPU_OR, int>,
-    ops::BinaryLogicalOpXPUKernel<ops::XpuLogicalType::XPU_OR, int64_t>,
-    ops::BinaryLogicalOpXPUKernel<ops::XpuLogicalType::XPU_OR, float>,
-    ops::BinaryLogicalOpXPUKernel<ops::XpuLogicalType::XPU_OR, double>);
+PD_REGISTER_STRUCT_KERNEL(logical_or,
+                          XPU,
+                          ALL_LAYOUT,
+                          ops::BinaryLogicalORCPUKernel,
+                          bool,
+                          int8_t,
+                          int16_t,
+                          int,
+                          int64_t,
+                          float,
+                          double) {}
 #endif
@@ -19,7 +19,7 @@ limitations under the License. */
 namespace paddle {
 namespace operators {

-template <typename DeviceContext, typename T>
+template <typename T, typename DeviceContext>
 class XPUIOUSimilarityKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext& ctx) const override {
@@ -57,6 +57,7 @@ class XPUIOUSimilarityKernel : public framework::OpKernel<T> {
 namespace ops = paddle::operators;
 using XPU = paddle::platform::XPUDeviceContext;

-REGISTER_OP_XPU_KERNEL(iou_similarity, ops::XPUIOUSimilarityKernel<XPU, float>);
+PD_REGISTER_STRUCT_KERNEL(
+    iou_similarity, XPU, ALL_LAYOUT, ops::XPUIOUSimilarityKernel, float) {}

 #endif
@@ -22,7 +22,7 @@ limitations under the License. */
 namespace paddle {
 namespace operators {

-template <typename DeviceContext, typename T>
+template <typename T, typename DeviceContext>
 class FusedGemmEpilogueXPUKernel : public framework::OpKernel<T> {
   using XPUType = typename XPUTypeTrait<T>::Type;
@@ -102,7 +102,7 @@ class FusedGemmEpilogueXPUKernel : public framework::OpKernel<T> {
   }
 };

-template <typename DeviceContext, typename T>
+template <typename T, typename DeviceContext>
 class FusedGemmEpilogueXPUGradKernel : public framework::OpKernel<T> {
   using XPUType = typename XPUTypeTrait<T>::Type;
@@ -227,15 +227,17 @@ class FusedGemmEpilogueXPUGradKernel : public framework::OpKernel<T> {
 }  // namespace paddle

 namespace ops = paddle::operators;
+namespace plat = paddle::platform;

-REGISTER_OP_XPU_KERNEL(
-    fused_gemm_epilogue,
-    ops::FusedGemmEpilogueXPUKernel<phi::XPUContext, float>,
-    ops::FusedGemmEpilogueXPUKernel<phi::XPUContext,
-                                    paddle::platform::float16>);
-
-REGISTER_OP_XPU_KERNEL(
-    fused_gemm_epilogue_grad,
-    ops::FusedGemmEpilogueXPUGradKernel<phi::XPUContext, float>,
-    ops::FusedGemmEpilogueXPUGradKernel<phi::XPUContext,
-                                        paddle::platform::float16>);
+PD_REGISTER_STRUCT_KERNEL(fused_gemm_epilogue,
+                          XPU,
+                          ALL_LAYOUT,
+                          ops::FusedGemmEpilogueXPUKernel,
+                          float,
+                          plat::float16) {}
+PD_REGISTER_STRUCT_KERNEL(fused_gemm_epilogue_grad,
+                          XPU,
+                          ALL_LAYOUT,
+                          ops::FusedGemmEpilogueXPUGradKernel,
+                          float,
+                          plat::float16) {}
@@ -15,11 +15,12 @@ limitations under the License. */
 #include "paddle/fluid/operators/load_combine_op.h"

 namespace ops = paddle::operators;
-using XPUCtx = paddle::platform::XPUDeviceContext;
-
-REGISTER_OP_XPU_KERNEL(load_combine,
-                       ops::LoadCombineOpKernel<float, XPUCtx>,
-                       ops::LoadCombineOpKernel<double, XPUCtx>,
-                       ops::LoadCombineOpKernel<int, XPUCtx>,
-                       ops::LoadCombineOpKernel<int8_t, XPUCtx>,
-                       ops::LoadCombineOpKernel<int64_t, XPUCtx>);
+PD_REGISTER_STRUCT_KERNEL(load_combine,
+                          XPU,
+                          ALL_LAYOUT,
+                          ops::LoadCombineOpKernel,
+                          float,
+                          double,
+                          int,
+                          int8_t,
+                          int64_t) {}
@@ -259,13 +259,15 @@ PD_REGISTER_STRUCT_KERNEL(lod_reset,
                           int64_t) {}

 #ifdef PADDLE_WITH_XPU
-using XPUCtx = paddle::platform::XPUDeviceContext;
-REGISTER_OP_XPU_KERNEL(lod_reset,
-                       ops::LoDResetKernel<paddle::platform::float16, XPUCtx>,
-                       ops::LoDResetKernel<float, XPUCtx>,
-                       ops::LoDResetKernel<double, XPUCtx>,
-                       ops::LoDResetKernel<int, XPUCtx>,
-                       ops::LoDResetKernel<int64_t, XPUCtx>);
+PD_REGISTER_STRUCT_KERNEL(lod_reset,
+                          XPU,
+                          ALL_LAYOUT,
+                          ops::LoDResetKernel,
+                          plat::float16,
+                          float,
+                          double,
+                          int,
+                          int64_t) {}
 #endif

 PD_REGISTER_STRUCT_KERNEL(lod_reset_grad,
...
@@ -17,7 +17,7 @@ limitations under the License. */
 namespace paddle {
 namespace operators {

-template <typename DeviceContext, typename T, typename AttrType = T>
+template <typename T, typename DeviceContext, typename AttrType = T>
 class LogLossXPUKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext& ctx) const override {
@@ -37,7 +37,7 @@ class LogLossXPUKernel : public framework::OpKernel<T> {
     PADDLE_ENFORCE_XDNN_SUCCESS(r, "log_loss");
   }
 };

-template <typename DeviceContext, typename T, typename AttrType = T>
+template <typename T, typename DeviceContext, typename AttrType = T>
 class LogLossGradXPUKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext& ctx) const override {
@@ -67,10 +67,9 @@ class LogLossGradXPUKernel : public framework::OpKernel<T> {
 }  // namespace operators
 }  // namespace paddle

 namespace ops = paddle::operators;
-REGISTER_OP_XPU_KERNEL(
-    log_loss, ops::LogLossXPUKernel<paddle::platform::XPUDeviceContext, float>);
-REGISTER_OP_XPU_KERNEL(
-    log_loss_grad,
-    ops::LogLossGradXPUKernel<paddle::platform::XPUDeviceContext, float>);
+PD_REGISTER_STRUCT_KERNEL(
+    log_loss, XPU, ALL_LAYOUT, ops::LogLossXPUKernel, float) {}
+PD_REGISTER_STRUCT_KERNEL(
+    log_loss_grad, XPU, ALL_LAYOUT, ops::LogLossGradXPUKernel, float) {}
 #endif
@@ -21,7 +21,7 @@ limitations under the License. */
 namespace paddle {
 namespace operators {

-template <typename DeviceContext, typename T>
+template <typename T, typename DeviceContext>
 class AccuracyXPUKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext& ctx) const override {
@@ -73,8 +73,6 @@ class AccuracyXPUKernel : public framework::OpKernel<T> {
 }  // namespace paddle

 namespace ops = paddle::operators;
-REGISTER_OP_XPU_KERNEL(
-    accuracy,
-    ops::AccuracyXPUKernel<paddle::platform::XPUDeviceContext, float>);
+PD_REGISTER_STRUCT_KERNEL(
+    accuracy, XPU, ALL_LAYOUT, ops::AccuracyXPUKernel, float) {}
 #endif
@@ -20,7 +20,7 @@ limitations under the License. */
 namespace paddle {
 namespace operators {

-template <typename T>
+template <typename T, typename DeviceContext>
 class LarsMomentumOpXPUKernel : public framework::OpKernel<T> {
   using XPUType = typename XPUTypeTrait<T>::Type;
@@ -115,7 +115,11 @@ class LarsMomentumOpXPUKernel : public framework::OpKernel<T> {
 }  // namespace paddle

 namespace ops = paddle::operators;
-REGISTER_OP_XPU_KERNEL(lars_momentum,
-                       ops::LarsMomentumOpXPUKernel<paddle::platform::float16>,
-                       ops::LarsMomentumOpXPUKernel<float>);
+namespace plat = paddle::platform;
+PD_REGISTER_STRUCT_KERNEL(lars_momentum,
+                          XPU,
+                          ALL_LAYOUT,
+                          ops::LarsMomentumOpXPUKernel,
+                          float,
+                          plat::float16) {}
 #endif
@@ -70,6 +70,10 @@ class XPULogsumexpKernel : public framework::OpKernel<T> {
 }  // namespace paddle

 namespace ops = paddle::operators;
+// This kernel cannot be registered in phi, because the logsumexp op should
+// run phi::LogsumexpKernel rather than XPULogsumexpKernel here. If the xpu
+// logsumexp kernel were registered in phi, the logsumexp op would run
+// XPULogsumexpKernel instead and raise an error.
 REGISTER_OP_XPU_KERNEL(
     logsumexp,
     ops::XPULogsumexpKernel<paddle::platform::XPUDeviceContext, float>);
...