未验证 提交 37489df5 编写于 作者: H huangjiyi 提交者: GitHub

Register fluid xpu kernels to phi [part 3] (#53189)

* update

* update
上级 af986bd5
......@@ -25,7 +25,7 @@ limitations under the License. */
namespace paddle {
namespace operators {
template <typename DeviceContext, typename T>
template <typename T, typename DeviceContext>
class AffineChannelXPUKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
......@@ -85,7 +85,7 @@ class AffineChannelXPUKernel : public framework::OpKernel<T> {
}
};
template <typename DeviceContext, typename T>
template <typename T, typename DeviceContext>
class AffineChannelGradXPUKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
......@@ -189,10 +189,12 @@ class AffineChannelGradXPUKernel : public framework::OpKernel<T> {
} // namespace paddle
namespace ops = paddle::operators;
using XPU = paddle::platform::XPUDeviceContext;
REGISTER_OP_XPU_KERNEL(affine_channel, ops::AffineChannelXPUKernel<XPU, float>);
REGISTER_OP_XPU_KERNEL(affine_channel_grad,
ops::AffineChannelGradXPUKernel<XPU, float>);
PD_REGISTER_STRUCT_KERNEL(
affine_channel, XPU, ALL_LAYOUT, ops::AffineChannelXPUKernel, float) {}
PD_REGISTER_STRUCT_KERNEL(affine_channel_grad,
XPU,
ALL_LAYOUT,
ops::AffineChannelGradXPUKernel,
float) {}
#endif
......@@ -293,7 +293,7 @@ static inline void xpu_conv2d_grad(xpu::Context* ctx,
PADDLE_ENFORCE_XDNN_SUCCESS(r, "conv2d_grad");
}
template <typename T>
template <typename T, typename DeviceContext>
class ResNetBasicBlockXPUKernel : public framework::OpKernel<T> {
public:
using XPUT = typename XPUTypeTrait<T>::Type;
......@@ -696,7 +696,7 @@ class ResNetBasicBlockXPUKernel : public framework::OpKernel<T> {
}
};
template <typename T>
template <typename T, typename DeviceContext>
class ResNetBasicBlockGradXPUKernel : public framework::OpKernel<T> {
public:
using XPUT = typename XPUTypeTrait<T>::Type;
......@@ -992,8 +992,14 @@ class ResNetBasicBlockGradXPUKernel : public framework::OpKernel<T> {
namespace ops = paddle::operators;
namespace plat = paddle::platform;
REGISTER_OP_XPU_KERNEL(resnet_basic_block,
ops::ResNetBasicBlockXPUKernel<float>);
REGISTER_OP_XPU_KERNEL(resnet_basic_block_grad,
ops::ResNetBasicBlockGradXPUKernel<float>);
PD_REGISTER_STRUCT_KERNEL(resnet_basic_block,
XPU,
ALL_LAYOUT,
ops::ResNetBasicBlockXPUKernel,
float) {}
PD_REGISTER_STRUCT_KERNEL(resnet_basic_block_grad,
XPU,
ALL_LAYOUT,
ops::ResNetBasicBlockGradXPUKernel,
float) {}
#endif
......@@ -19,7 +19,7 @@ limitations under the License. */
namespace paddle {
namespace operators {
template <typename T>
template <typename T, typename DeviceContext>
class ResNetUnitXPUKernel : public framework::OpKernel<T> {
using XPUType = typename XPUTypeTrait<T>::Type;
......@@ -181,7 +181,7 @@ class ResNetUnitXPUKernel : public framework::OpKernel<T> {
}
};
template <typename T>
template <typename T, typename DeviceContext>
class ResNetUnitGradXPUKernel : public framework::OpKernel<T> {
using XPUType = typename XPUTypeTrait<T>::Type;
......@@ -361,9 +361,15 @@ class ResNetUnitGradXPUKernel : public framework::OpKernel<T> {
namespace ops = paddle::operators;
namespace plat = paddle::platform;
REGISTER_OP_XPU_KERNEL(resnet_unit,
ops::ResNetUnitXPUKernel<plat::float16>,
ops::ResNetUnitXPUKernel<float>);
REGISTER_OP_XPU_KERNEL(resnet_unit_grad,
ops::ResNetUnitGradXPUKernel<plat::float16>,
ops::ResNetUnitGradXPUKernel<float>);
PD_REGISTER_STRUCT_KERNEL(resnet_unit,
XPU,
ALL_LAYOUT,
ops::ResNetUnitXPUKernel,
plat::float16,
float) {}
PD_REGISTER_STRUCT_KERNEL(resnet_unit_grad,
XPU,
ALL_LAYOUT,
ops::ResNetUnitGradXPUKernel,
plat::float16,
float) {}
......@@ -16,8 +16,6 @@
#include "paddle/fluid/platform/device_context.h"
namespace ops = paddle::operators;
using XPUCtx = paddle::platform::XPUDeviceContext;
REGISTER_OP_XPU_KERNEL(sampling_id,
paddle::operators::SamplingIdKernel<float, XPUCtx>,
paddle::operators::SamplingIdKernel<double, XPUCtx>);
PD_REGISTER_STRUCT_KERNEL(
sampling_id, XPU, ALL_LAYOUT, ops::SamplingIdKernel, float, double) {}
......@@ -20,7 +20,7 @@ limitations under the License. */
namespace paddle {
namespace operators {
template <typename DeviceContext, typename T>
template <typename T, typename DeviceContext>
class SequenceConvXPUKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
......@@ -154,7 +154,7 @@ class SequenceConvXPUKernel : public framework::OpKernel<T> {
}
};
template <typename DeviceContext, typename T>
template <typename T, typename DeviceContext>
class SequenceConvGradXPUKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
......@@ -363,12 +363,12 @@ class SequenceConvGradXPUKernel : public framework::OpKernel<T> {
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP_XPU_KERNEL(
sequence_conv,
ops::SequenceConvXPUKernel<paddle::platform::XPUDeviceContext, float>);
REGISTER_OP_XPU_KERNEL(
sequence_conv_grad,
ops::SequenceConvGradXPUKernel<paddle::platform::XPUDeviceContext, float>);
PD_REGISTER_STRUCT_KERNEL(
sequence_conv, XPU, ALL_LAYOUT, ops::SequenceConvXPUKernel, float) {}
PD_REGISTER_STRUCT_KERNEL(sequence_conv_grad,
XPU,
ALL_LAYOUT,
ops::SequenceConvGradXPUKernel,
float) {}
#endif
......@@ -17,7 +17,7 @@ limitations under the License. */
#include "paddle/fluid/operators/sequence_ops/sequence_unpad_op.h"
namespace ops = paddle::operators;
REGISTER_OP_XPU_KERNEL(sequence_unpad,
ops::SequenceUnpadOpKernel<float, phi::XPUContext>);
PD_REGISTER_STRUCT_KERNEL(
sequence_unpad, XPU, ALL_LAYOUT, ops::SequenceUnpadOpKernel, float) {}
#endif
......@@ -22,7 +22,7 @@ limitations under the License. */
namespace paddle {
namespace operators {
template <typename T>
template <typename T, typename DeviceContext>
class XPUUniformRandomInplaceKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext &ctx) const override {
......@@ -71,7 +71,7 @@ class XPUUniformRandomInplaceKernel : public framework::OpKernel<T> {
}
};
template <typename T>
template <typename T, typename DeviceContext>
class XPUUniformRandomInplaceGradKernel : public framework::OpKernel<T> {
public:
void Compute(const paddle::framework::ExecutionContext &ctx) const override {
......@@ -95,10 +95,15 @@ class XPUUniformRandomInplaceGradKernel : public framework::OpKernel<T> {
} // namespace operators
} // namespace paddle
REGISTER_OP_XPU_KERNEL(uniform_random_inplace,
paddle::operators::XPUUniformRandomInplaceKernel<float>);
REGISTER_OP_XPU_KERNEL(
uniform_random_inplace_grad,
paddle::operators::XPUUniformRandomInplaceGradKernel<float>);
PD_REGISTER_STRUCT_KERNEL(uniform_random_inplace,
XPU,
ALL_LAYOUT,
ops::XPUUniformRandomInplaceKernel,
float) {}
PD_REGISTER_STRUCT_KERNEL(uniform_random_inplace_grad,
XPU,
ALL_LAYOUT,
ops::XPUUniformRandomInplaceGradKernel,
float) {}
#endif // PADDLE_WITH_XPU
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册