Unverified commit eee9c788, authored by huangjiyi, committed by GitHub

Register fluid xpu kernels to phi [part 2] (#53188)

* update

* fix bug
Parent 166964b1
@@ -29,7 +29,7 @@ limitations under the License. */
namespace paddle {
namespace operators {
-template <typename DeviceContext, typename T>
+template <typename T, typename DeviceContext>
class CSoftmaxWithCrossEntropyOp : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
@@ -468,7 +468,7 @@ struct CSoftmaxWithCrossEntropyFunctor<phi::XPUContext, T> {
}
};
-template <typename DeviceContext, typename T>
+template <typename T, typename DeviceContext>
class CSoftmaxWithCrossEntropyGrad : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
@@ -535,9 +535,13 @@ class CSoftmaxWithCrossEntropyGrad : public framework::OpKernel<T> {
namespace ops = paddle::operators;
namespace plat = paddle::platform;
-REGISTER_OP_XPU_KERNEL(c_softmax_with_cross_entropy,
-                       ops::CSoftmaxWithCrossEntropyOp<phi::XPUContext, float>);
-REGISTER_OP_XPU_KERNEL(
-    c_softmax_with_cross_entropy_grad,
-    ops::CSoftmaxWithCrossEntropyGrad<phi::XPUContext, float>);
+PD_REGISTER_STRUCT_KERNEL(c_softmax_with_cross_entropy,
+                          XPU,
+                          ALL_LAYOUT,
+                          ops::CSoftmaxWithCrossEntropyOp,
+                          float) {}
+PD_REGISTER_STRUCT_KERNEL(c_softmax_with_cross_entropy_grad,
+                          XPU,
+                          ALL_LAYOUT,
+                          ops::CSoftmaxWithCrossEntropyGrad,
+                          float) {}
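
[Annotation] The change above is the template this whole commit follows: the fluid-era REGISTER_OP_XPU_KERNEL took fully instantiated kernel types such as ops::CSoftmaxWithCrossEntropyOp<phi::XPUContext, float>, while PD_REGISTER_STRUCT_KERNEL takes the kernel class template itself plus a backend, a layout, and a dtype list — which is why every kernel first swaps its template parameters to <T, DeviceContext>. Below is a minimal, self-contained sketch of that fan-out; every name in it is an illustrative stand-in, not Paddle's real macro expansion.

#include <cstdio>

struct FakeXPUContext {};  // hypothetical stand-in for phi::XPUContext

// New parameter order: dtype first, device context second.
template <typename T, typename DeviceContext>
struct ExampleKernel {
  static void Run() { std::printf("instantiated, sizeof(T)=%zu\n", sizeof(T)); }
};

// Assumed behavior of the variadic registration macro: instantiate the
// kernel class once per listed dtype, always as Kernel<T, DeviceContext>.
template <template <typename, typename> class Kernel,
          typename DeviceContext,
          typename... DTypes>
void RegisterStructKernelSketch() {
  (Kernel<DTypes, DeviceContext>::Run(), ...);  // C++17 fold expression
}

int main() {
  // Mirrors PD_REGISTER_STRUCT_KERNEL(op, XPU, ALL_LAYOUT, Kernel, float, int).
  RegisterStructKernelSketch<ExampleKernel, FakeXPUContext, float, int>();
  return 0;
}
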
@@ -22,7 +22,7 @@ limitations under the License. */
namespace paddle {
namespace operators {
-template <typename T>
+template <typename T, typename DeviceContext>
class CSplitOpXPUKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
@@ -87,7 +87,10 @@ class CSplitOpXPUKernel : public framework::OpKernel<T> {
namespace ops = paddle::operators;
namespace plat = paddle::platform;
-REGISTER_OP_XPU_KERNEL(c_split,
-                       ops::CSplitOpXPUKernel<float>,
-                       ops::CSplitOpXPUKernel<int>,
-                       ops::CSplitOpXPUKernel<plat::float16>);
+PD_REGISTER_STRUCT_KERNEL(c_split,
+                          XPU,
+                          ALL_LAYOUT,
+                          ops::CSplitOpXPUKernel,
+                          float,
+                          int,
+                          plat::float16) {}
@@ -17,5 +17,5 @@ limitations under the License. */
namespace ops = paddle::operators;
namespace plat = paddle::platform;
-REGISTER_OP_XPU_KERNEL(c_sync_calc_stream,
-                       ops::CSyncCalcStreamKernel<float, plat::XPUPlace>)
+PD_REGISTER_STRUCT_KERNEL(
+    c_sync_calc_stream, XPU, ALL_LAYOUT, ops::CSyncCalcStreamKernel, float) {}
@@ -17,5 +17,5 @@ limitations under the License. */
namespace ops = paddle::operators;
namespace plat = paddle::platform;
-REGISTER_OP_XPU_KERNEL(c_sync_comm_stream,
-                       ops::CSyncCommStreamKernel<float, plat::XPUPlace>);
+PD_REGISTER_STRUCT_KERNEL(
+    c_sync_comm_stream, XPU, ALL_LAYOUT, ops::CSyncCommStreamKernel, float) {}
@@ -156,7 +156,12 @@ class BinaryLogicalOpXPUKernel : public framework::OpKernel<T> {
}
};
-template <typename T>
+#define DEFINE_BINARY_LOGICAL_OP_XPU_KERNEL(op_name, xpu_type) \
+  template <typename T, typename DeviceContext>                \
+  class BinaryLogical##op_name##CPUKernel                      \
+      : public BinaryLogicalOpXPUKernel<xpu_type, T> {};
+
+template <typename T, typename DeviceContext>
class UnaryLogicalOpXPUKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
......
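
[Annotation] The macro added above stamps out one thin subclass per logical op via token pasting, so each op gets a distinct class name for registration (the BinaryLogicalANDCPUKernel / BinaryLogicalORCPUKernel names referenced in the registrations below — the "CPUKernel" suffix is the name the commit itself uses). A self-contained sketch of the pattern, using invented names rather than Paddle's:

#include <iostream>

enum class XpuLogicalType { XPU_AND, XPU_OR };

// Stand-in for the shared kernel, templated over the op kind and dtype.
template <XpuLogicalType Type, typename T>
struct BinaryLogicalOpKernel {
  static void Run(T a, T b) {
    bool r = (Type == XpuLogicalType::XPU_AND) ? (a && b) : (a || b);
    std::cout << r << '\n';
  }
};

// Token pasting (##) builds a distinct class name per logical op, with the
// <T, DeviceContext> signature the struct-kernel registration expects.
#define DEFINE_BINARY_LOGICAL_KERNEL(op_name, xpu_type)  \
  template <typename T, typename DeviceContext>          \
  struct BinaryLogical##op_name##Kernel                  \
      : BinaryLogicalOpKernel<xpu_type, T> {};

DEFINE_BINARY_LOGICAL_KERNEL(AND, XpuLogicalType::XPU_AND)
DEFINE_BINARY_LOGICAL_KERNEL(OR, XpuLogicalType::XPU_OR)

int main() {
  BinaryLogicalANDKernel<bool, int>::Run(true, false);  // prints 0
  BinaryLogicalORKernel<bool, int>::Run(true, false);   // prints 1
}
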
@@ -14,14 +14,23 @@ limitations under the License. */
#ifdef PADDLE_WITH_XPU
#include "paddle/fluid/operators/controlflow/logical_op_xpu.h"
+namespace paddle {
+namespace operators {
+DEFINE_BINARY_LOGICAL_OP_XPU_KERNEL(AND, XpuLogicalType::XPU_AND);
+}  // namespace operators
+}  // namespace paddle
namespace ops = paddle::operators;
-REGISTER_OP_XPU_KERNEL(
-    logical_and,
-    ops::BinaryLogicalOpXPUKernel<ops::XpuLogicalType::XPU_AND, bool>,
-    ops::BinaryLogicalOpXPUKernel<ops::XpuLogicalType::XPU_AND, int8_t>,
-    ops::BinaryLogicalOpXPUKernel<ops::XpuLogicalType::XPU_AND, int16_t>,
-    ops::BinaryLogicalOpXPUKernel<ops::XpuLogicalType::XPU_AND, int>,
-    ops::BinaryLogicalOpXPUKernel<ops::XpuLogicalType::XPU_AND, int64_t>,
-    ops::BinaryLogicalOpXPUKernel<ops::XpuLogicalType::XPU_AND, float>,
-    ops::BinaryLogicalOpXPUKernel<ops::XpuLogicalType::XPU_AND, double>);
+PD_REGISTER_STRUCT_KERNEL(logical_and,
+                          XPU,
+                          ALL_LAYOUT,
+                          ops::BinaryLogicalANDCPUKernel,
+                          bool,
+                          int8_t,
+                          int16_t,
+                          int,
+                          int64_t,
+                          float,
+                          double) {}
#endif
@@ -15,12 +15,15 @@ limitations under the License. */
#ifdef PADDLE_WITH_XPU
#include "paddle/fluid/operators/controlflow/logical_op_xpu.h"
namespace ops = paddle::operators;
-REGISTER_OP_XPU_KERNEL(logicalnot,
-                       ops::UnaryLogicalOpXPUKernel<bool>,
-                       ops::UnaryLogicalOpXPUKernel<int8_t>,
-                       ops::UnaryLogicalOpXPUKernel<int16_t>,
-                       ops::UnaryLogicalOpXPUKernel<int>,
-                       ops::UnaryLogicalOpXPUKernel<int64_t>,
-                       ops::UnaryLogicalOpXPUKernel<float>,
-                       ops::UnaryLogicalOpXPUKernel<double>);
+PD_REGISTER_STRUCT_KERNEL(logicalnot,
+                          XPU,
+                          ALL_LAYOUT,
+                          ops::UnaryLogicalOpXPUKernel,
+                          bool,
+                          int8_t,
+                          int16_t,
+                          int,
+                          int64_t,
+                          float,
+                          double) {}
#endif
@@ -15,14 +15,22 @@ limitations under the License. */
#ifdef PADDLE_WITH_XPU
#include "paddle/fluid/operators/controlflow/logical_op_xpu.h"
+namespace paddle {
+namespace operators {
+DEFINE_BINARY_LOGICAL_OP_XPU_KERNEL(OR, XpuLogicalType::XPU_OR);
+}  // namespace operators
+}  // namespace paddle
namespace ops = paddle::operators;
-REGISTER_OP_XPU_KERNEL(
-    logical_or,
-    ops::BinaryLogicalOpXPUKernel<ops::XpuLogicalType::XPU_OR, bool>,
-    ops::BinaryLogicalOpXPUKernel<ops::XpuLogicalType::XPU_OR, int8_t>,
-    ops::BinaryLogicalOpXPUKernel<ops::XpuLogicalType::XPU_OR, int16_t>,
-    ops::BinaryLogicalOpXPUKernel<ops::XpuLogicalType::XPU_OR, int>,
-    ops::BinaryLogicalOpXPUKernel<ops::XpuLogicalType::XPU_OR, int64_t>,
-    ops::BinaryLogicalOpXPUKernel<ops::XpuLogicalType::XPU_OR, float>,
-    ops::BinaryLogicalOpXPUKernel<ops::XpuLogicalType::XPU_OR, double>);
+PD_REGISTER_STRUCT_KERNEL(logical_or,
+                          XPU,
+                          ALL_LAYOUT,
+                          ops::BinaryLogicalORCPUKernel,
+                          bool,
+                          int8_t,
+                          int16_t,
+                          int,
+                          int64_t,
+                          float,
+                          double) {}
#endif
@@ -19,7 +19,7 @@ limitations under the License. */
namespace paddle {
namespace operators {
-template <typename DeviceContext, typename T>
+template <typename T, typename DeviceContext>
class XPUIOUSimilarityKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
@@ -57,6 +57,7 @@ class XPUIOUSimilarityKernel : public framework::OpKernel<T> {
namespace ops = paddle::operators;
using XPU = paddle::platform::XPUDeviceContext;
-REGISTER_OP_XPU_KERNEL(iou_similarity, ops::XPUIOUSimilarityKernel<XPU, float>);
+PD_REGISTER_STRUCT_KERNEL(
+    iou_similarity, XPU, ALL_LAYOUT, ops::XPUIOUSimilarityKernel, float) {}
#endif
@@ -22,7 +22,7 @@ limitations under the License. */
namespace paddle {
namespace operators {
-template <typename DeviceContext, typename T>
+template <typename T, typename DeviceContext>
class FusedGemmEpilogueXPUKernel : public framework::OpKernel<T> {
using XPUType = typename XPUTypeTrait<T>::Type;
@@ -102,7 +102,7 @@ class FusedGemmEpilogueXPUKernel : public framework::OpKernel<T> {
}
};
-template <typename DeviceContext, typename T>
+template <typename T, typename DeviceContext>
class FusedGemmEpilogueXPUGradKernel : public framework::OpKernel<T> {
using XPUType = typename XPUTypeTrait<T>::Type;
@@ -227,15 +227,17 @@ class FusedGemmEpilogueXPUGradKernel : public framework::OpKernel<T> {
} // namespace paddle
namespace ops = paddle::operators;
-REGISTER_OP_XPU_KERNEL(
-    fused_gemm_epilogue,
-    ops::FusedGemmEpilogueXPUKernel<phi::XPUContext, float>,
-    ops::FusedGemmEpilogueXPUKernel<phi::XPUContext,
-                                    paddle::platform::float16>);
-REGISTER_OP_XPU_KERNEL(
-    fused_gemm_epilogue_grad,
-    ops::FusedGemmEpilogueXPUGradKernel<phi::XPUContext, float>,
-    ops::FusedGemmEpilogueXPUGradKernel<phi::XPUContext,
-                                        paddle::platform::float16>);
+namespace plat = paddle::platform;
+PD_REGISTER_STRUCT_KERNEL(fused_gemm_epilogue,
+                          XPU,
+                          ALL_LAYOUT,
+                          ops::FusedGemmEpilogueXPUKernel,
+                          float,
+                          plat::float16) {}
+PD_REGISTER_STRUCT_KERNEL(fused_gemm_epilogue_grad,
+                          XPU,
+                          ALL_LAYOUT,
+                          ops::FusedGemmEpilogueXPUGradKernel,
+                          float,
+                          plat::float16) {}
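
[Annotation] Both fused_gemm_epilogue kernels convert the framework dtype to a device-side dtype through XPUTypeTrait (the `using XPUType = typename XPUTypeTrait<T>::Type;` lines above) before calling into the XPU runtime; the real trait plays this role for plat::float16. A minimal sketch of such a trait with invented types, showing only the mapping mechanism:

#include <cstdint>

// Invented stand-ins: a framework-side half type and a device-side half type.
struct FrameworkFloat16 { uint16_t bits; };
struct DeviceFloat16 { uint16_t bits; };

// Identity mapping by default; specialized for types the device renames.
template <typename T>
struct XPUTypeTraitSketch {
  using Type = T;
};

template <>
struct XPUTypeTraitSketch<FrameworkFloat16> {
  using Type = DeviceFloat16;
};

template <typename T>
void LaunchSketch(const T* x, int n) {
  using XPUType = typename XPUTypeTraitSketch<T>::Type;
  // A real kernel would reinterpret_cast the buffer to XPUType* before the
  // device call; here we only demonstrate that the mapping compiles.
  const XPUType* device_x = reinterpret_cast<const XPUType*>(x);
  (void)device_x;
  (void)n;
}

int main() {
  FrameworkFloat16 hbuf[4] = {};
  LaunchSketch(hbuf, 4);  // half maps to the device half type
  float fbuf[4] = {};
  LaunchSketch(fbuf, 4);  // float maps to itself
}
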
@@ -15,11 +15,12 @@ limitations under the License. */
#include "paddle/fluid/operators/load_combine_op.h"
namespace ops = paddle::operators;
-using XPUCtx = paddle::platform::XPUDeviceContext;
-REGISTER_OP_XPU_KERNEL(load_combine,
-                       ops::LoadCombineOpKernel<float, XPUCtx>,
-                       ops::LoadCombineOpKernel<double, XPUCtx>,
-                       ops::LoadCombineOpKernel<int, XPUCtx>,
-                       ops::LoadCombineOpKernel<int8_t, XPUCtx>,
-                       ops::LoadCombineOpKernel<int64_t, XPUCtx>);
+PD_REGISTER_STRUCT_KERNEL(load_combine,
+                          XPU,
+                          ALL_LAYOUT,
+                          ops::LoadCombineOpKernel,
+                          float,
+                          double,
+                          int,
+                          int8_t,
+                          int64_t) {}
@@ -259,13 +259,15 @@ PD_REGISTER_STRUCT_KERNEL(lod_reset,
int64_t) {}
#ifdef PADDLE_WITH_XPU
-using XPUCtx = paddle::platform::XPUDeviceContext;
-REGISTER_OP_XPU_KERNEL(lod_reset,
-                       ops::LoDResetKernel<paddle::platform::float16, XPUCtx>,
-                       ops::LoDResetKernel<float, XPUCtx>,
-                       ops::LoDResetKernel<double, XPUCtx>,
-                       ops::LoDResetKernel<int, XPUCtx>,
-                       ops::LoDResetKernel<int64_t, XPUCtx>);
+PD_REGISTER_STRUCT_KERNEL(lod_reset,
+                          XPU,
+                          ALL_LAYOUT,
+                          ops::LoDResetKernel,
+                          plat::float16,
+                          float,
+                          double,
+                          int,
+                          int64_t) {}
#endif
PD_REGISTER_STRUCT_KERNEL(lod_reset_grad,
......
@@ -17,7 +17,7 @@ limitations under the License. */
namespace paddle {
namespace operators {
-template <typename DeviceContext, typename T, typename AttrType = T>
+template <typename T, typename DeviceContext, typename AttrType = T>
class LogLossXPUKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
@@ -37,7 +37,7 @@ class LogLossXPUKernel : public framework::OpKernel<T> {
PADDLE_ENFORCE_XDNN_SUCCESS(r, "log_loss");
}
};
-template <typename DeviceContext, typename T, typename AttrType = T>
+template <typename T, typename DeviceContext, typename AttrType = T>
class LogLossGradXPUKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
@@ -67,10 +67,9 @@ class LogLossGradXPUKernel : public framework::OpKernel<T> {
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
-REGISTER_OP_XPU_KERNEL(
-    log_loss, ops::LogLossXPUKernel<paddle::platform::XPUDeviceContext, float>);
-REGISTER_OP_XPU_KERNEL(
-    log_loss_grad,
-    ops::LogLossGradXPUKernel<paddle::platform::XPUDeviceContext, float>);
+PD_REGISTER_STRUCT_KERNEL(
+    log_loss, XPU, ALL_LAYOUT, ops::LogLossXPUKernel, float) {}
+PD_REGISTER_STRUCT_KERNEL(
+    log_loss_grad, XPU, ALL_LAYOUT, ops::LogLossGradXPUKernel, float) {}
#endif
@@ -21,7 +21,7 @@ limitations under the License. */
namespace paddle {
namespace operators {
-template <typename DeviceContext, typename T>
+template <typename T, typename DeviceContext>
class AccuracyXPUKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
@@ -73,8 +73,6 @@ class AccuracyXPUKernel : public framework::OpKernel<T> {
} // namespace paddle
namespace ops = paddle::operators;
-REGISTER_OP_XPU_KERNEL(
-    accuracy,
-    ops::AccuracyXPUKernel<paddle::platform::XPUDeviceContext, float>);
+PD_REGISTER_STRUCT_KERNEL(
+    accuracy, XPU, ALL_LAYOUT, ops::AccuracyXPUKernel, float) {}
#endif
@@ -20,7 +20,7 @@ limitations under the License. */
namespace paddle {
namespace operators {
-template <typename T>
+template <typename T, typename DeviceContext>
class LarsMomentumOpXPUKernel : public framework::OpKernel<T> {
using XPUType = typename XPUTypeTrait<T>::Type;
@@ -115,7 +115,11 @@ class LarsMomentumOpXPUKernel : public framework::OpKernel<T> {
} // namespace paddle
namespace ops = paddle::operators;
-REGISTER_OP_XPU_KERNEL(lars_momentum,
-                       ops::LarsMomentumOpXPUKernel<paddle::platform::float16>,
-                       ops::LarsMomentumOpXPUKernel<float>);
+namespace plat = paddle::platform;
+PD_REGISTER_STRUCT_KERNEL(lars_momentum,
+                          XPU,
+                          ALL_LAYOUT,
+                          ops::LarsMomentumOpXPUKernel,
+                          float,
+                          plat::float16) {}
#endif
@@ -70,6 +70,10 @@ class XPULogsumexpKernel : public framework::OpKernel<T> {
} // namespace paddle
namespace ops = paddle::operators;
+// This kernel cannot be registered in phi, because the logsumexp op should
+// run phi::LogsumexpKernel rather than this XPULogsumexpKernel. If the xpu
+// logsumexp kernel were registered in phi, the logsumexp op would dispatch
+// to XPULogsumexpKernel here and raise an error.
REGISTER_OP_XPU_KERNEL(
    logsumexp,
    ops::XPULogsumexpKernel<paddle::platform::XPUDeviceContext, float>);
......