未验证 提交 37489df5 编写于 作者: H huangjiyi 提交者: GitHub

Register fluid xpu kerenls to phi [part 3] (#53189)

* update

* update
上级 af986bd5
...@@ -25,7 +25,7 @@ limitations under the License. */ ...@@ -25,7 +25,7 @@ limitations under the License. */
namespace paddle { namespace paddle {
namespace operators { namespace operators {
template <typename DeviceContext, typename T> template <typename T, typename DeviceContext>
class AffineChannelXPUKernel : public framework::OpKernel<T> { class AffineChannelXPUKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& ctx) const override { void Compute(const framework::ExecutionContext& ctx) const override {
...@@ -85,7 +85,7 @@ class AffineChannelXPUKernel : public framework::OpKernel<T> { ...@@ -85,7 +85,7 @@ class AffineChannelXPUKernel : public framework::OpKernel<T> {
} }
}; };
template <typename DeviceContext, typename T> template <typename T, typename DeviceContext>
class AffineChannelGradXPUKernel : public framework::OpKernel<T> { class AffineChannelGradXPUKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& ctx) const override { void Compute(const framework::ExecutionContext& ctx) const override {
...@@ -189,10 +189,12 @@ class AffineChannelGradXPUKernel : public framework::OpKernel<T> { ...@@ -189,10 +189,12 @@ class AffineChannelGradXPUKernel : public framework::OpKernel<T> {
} // namespace paddle } // namespace paddle
namespace ops = paddle::operators; namespace ops = paddle::operators;
using XPU = paddle::platform::XPUDeviceContext;
REGISTER_OP_XPU_KERNEL(affine_channel, ops::AffineChannelXPUKernel<XPU, float>);
REGISTER_OP_XPU_KERNEL(affine_channel_grad,
ops::AffineChannelGradXPUKernel<XPU, float>);
PD_REGISTER_STRUCT_KERNEL(
affine_channel, XPU, ALL_LAYOUT, ops::AffineChannelXPUKernel, float) {}
PD_REGISTER_STRUCT_KERNEL(affine_channel_grad,
XPU,
ALL_LAYOUT,
ops::AffineChannelGradXPUKernel,
float) {}
#endif #endif
...@@ -293,7 +293,7 @@ static inline void xpu_conv2d_grad(xpu::Context* ctx, ...@@ -293,7 +293,7 @@ static inline void xpu_conv2d_grad(xpu::Context* ctx,
PADDLE_ENFORCE_XDNN_SUCCESS(r, "conv2d_grad"); PADDLE_ENFORCE_XDNN_SUCCESS(r, "conv2d_grad");
} }
template <typename T> template <typename T, typename DeviceContext>
class ResNetBasicBlockXPUKernel : public framework::OpKernel<T> { class ResNetBasicBlockXPUKernel : public framework::OpKernel<T> {
public: public:
using XPUT = typename XPUTypeTrait<T>::Type; using XPUT = typename XPUTypeTrait<T>::Type;
...@@ -696,7 +696,7 @@ class ResNetBasicBlockXPUKernel : public framework::OpKernel<T> { ...@@ -696,7 +696,7 @@ class ResNetBasicBlockXPUKernel : public framework::OpKernel<T> {
} }
}; };
template <typename T> template <typename T, typename DeviceContext>
class ResNetBasicBlockGradXPUKernel : public framework::OpKernel<T> { class ResNetBasicBlockGradXPUKernel : public framework::OpKernel<T> {
public: public:
using XPUT = typename XPUTypeTrait<T>::Type; using XPUT = typename XPUTypeTrait<T>::Type;
...@@ -992,8 +992,14 @@ class ResNetBasicBlockGradXPUKernel : public framework::OpKernel<T> { ...@@ -992,8 +992,14 @@ class ResNetBasicBlockGradXPUKernel : public framework::OpKernel<T> {
namespace ops = paddle::operators; namespace ops = paddle::operators;
namespace plat = paddle::platform; namespace plat = paddle::platform;
REGISTER_OP_XPU_KERNEL(resnet_basic_block, PD_REGISTER_STRUCT_KERNEL(resnet_basic_block,
ops::ResNetBasicBlockXPUKernel<float>); XPU,
REGISTER_OP_XPU_KERNEL(resnet_basic_block_grad, ALL_LAYOUT,
ops::ResNetBasicBlockGradXPUKernel<float>); ops::ResNetBasicBlockXPUKernel,
float) {}
PD_REGISTER_STRUCT_KERNEL(resnet_basic_block_grad,
XPU,
ALL_LAYOUT,
ops::ResNetBasicBlockGradXPUKernel,
float) {}
#endif #endif
...@@ -19,7 +19,7 @@ limitations under the License. */ ...@@ -19,7 +19,7 @@ limitations under the License. */
namespace paddle { namespace paddle {
namespace operators { namespace operators {
template <typename T> template <typename T, typename DeviceContext>
class ResNetUnitXPUKernel : public framework::OpKernel<T> { class ResNetUnitXPUKernel : public framework::OpKernel<T> {
using XPUType = typename XPUTypeTrait<T>::Type; using XPUType = typename XPUTypeTrait<T>::Type;
...@@ -181,7 +181,7 @@ class ResNetUnitXPUKernel : public framework::OpKernel<T> { ...@@ -181,7 +181,7 @@ class ResNetUnitXPUKernel : public framework::OpKernel<T> {
} }
}; };
template <typename T> template <typename T, typename DeviceContext>
class ResNetUnitGradXPUKernel : public framework::OpKernel<T> { class ResNetUnitGradXPUKernel : public framework::OpKernel<T> {
using XPUType = typename XPUTypeTrait<T>::Type; using XPUType = typename XPUTypeTrait<T>::Type;
...@@ -361,9 +361,15 @@ class ResNetUnitGradXPUKernel : public framework::OpKernel<T> { ...@@ -361,9 +361,15 @@ class ResNetUnitGradXPUKernel : public framework::OpKernel<T> {
namespace ops = paddle::operators; namespace ops = paddle::operators;
namespace plat = paddle::platform; namespace plat = paddle::platform;
REGISTER_OP_XPU_KERNEL(resnet_unit, PD_REGISTER_STRUCT_KERNEL(resnet_unit,
ops::ResNetUnitXPUKernel<plat::float16>, XPU,
ops::ResNetUnitXPUKernel<float>); ALL_LAYOUT,
REGISTER_OP_XPU_KERNEL(resnet_unit_grad, ops::ResNetUnitXPUKernel,
ops::ResNetUnitGradXPUKernel<plat::float16>, plat::float16,
ops::ResNetUnitGradXPUKernel<float>); float) {}
PD_REGISTER_STRUCT_KERNEL(resnet_unit_grad,
XPU,
ALL_LAYOUT,
ops::ResNetUnitGradXPUKernel,
plat::float16,
float) {}
...@@ -16,8 +16,6 @@ ...@@ -16,8 +16,6 @@
#include "paddle/fluid/platform/device_context.h" #include "paddle/fluid/platform/device_context.h"
namespace ops = paddle::operators; namespace ops = paddle::operators;
using XPUCtx = paddle::platform::XPUDeviceContext;
REGISTER_OP_XPU_KERNEL(sampling_id, PD_REGISTER_STRUCT_KERNEL(
paddle::operators::SamplingIdKernel<float, XPUCtx>, sampling_id, XPU, ALL_LAYOUT, ops::SamplingIdKernel, float, double) {}
paddle::operators::SamplingIdKernel<double, XPUCtx>);
...@@ -20,7 +20,7 @@ limitations under the License. */ ...@@ -20,7 +20,7 @@ limitations under the License. */
namespace paddle { namespace paddle {
namespace operators { namespace operators {
template <typename DeviceContext, typename T> template <typename T, typename DeviceContext>
class SequenceConvXPUKernel : public framework::OpKernel<T> { class SequenceConvXPUKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& context) const override { void Compute(const framework::ExecutionContext& context) const override {
...@@ -154,7 +154,7 @@ class SequenceConvXPUKernel : public framework::OpKernel<T> { ...@@ -154,7 +154,7 @@ class SequenceConvXPUKernel : public framework::OpKernel<T> {
} }
}; };
template <typename DeviceContext, typename T> template <typename T, typename DeviceContext>
class SequenceConvGradXPUKernel : public framework::OpKernel<T> { class SequenceConvGradXPUKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& context) const override { void Compute(const framework::ExecutionContext& context) const override {
...@@ -363,12 +363,12 @@ class SequenceConvGradXPUKernel : public framework::OpKernel<T> { ...@@ -363,12 +363,12 @@ class SequenceConvGradXPUKernel : public framework::OpKernel<T> {
} // namespace paddle } // namespace paddle
namespace ops = paddle::operators; namespace ops = paddle::operators;
REGISTER_OP_XPU_KERNEL( PD_REGISTER_STRUCT_KERNEL(
sequence_conv, sequence_conv, XPU, ALL_LAYOUT, ops::SequenceConvXPUKernel, float) {}
ops::SequenceConvXPUKernel<paddle::platform::XPUDeviceContext, float>); PD_REGISTER_STRUCT_KERNEL(sequence_conv_grad,
XPU,
REGISTER_OP_XPU_KERNEL( ALL_LAYOUT,
sequence_conv_grad, ops::SequenceConvGradXPUKernel,
ops::SequenceConvGradXPUKernel<paddle::platform::XPUDeviceContext, float>); float) {}
#endif #endif
...@@ -17,7 +17,7 @@ limitations under the License. */ ...@@ -17,7 +17,7 @@ limitations under the License. */
#include "paddle/fluid/operators/sequence_ops/sequence_unpad_op.h" #include "paddle/fluid/operators/sequence_ops/sequence_unpad_op.h"
namespace ops = paddle::operators; namespace ops = paddle::operators;
REGISTER_OP_XPU_KERNEL(sequence_unpad, PD_REGISTER_STRUCT_KERNEL(
ops::SequenceUnpadOpKernel<float, phi::XPUContext>); sequence_unpad, XPU, ALL_LAYOUT, ops::SequenceUnpadOpKernel, float) {}
#endif #endif
...@@ -22,7 +22,7 @@ limitations under the License. */ ...@@ -22,7 +22,7 @@ limitations under the License. */
namespace paddle { namespace paddle {
namespace operators { namespace operators {
template <typename T> template <typename T, typename DeviceContext>
class XPUUniformRandomInplaceKernel : public framework::OpKernel<T> { class XPUUniformRandomInplaceKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext &ctx) const override { void Compute(const framework::ExecutionContext &ctx) const override {
...@@ -71,7 +71,7 @@ class XPUUniformRandomInplaceKernel : public framework::OpKernel<T> { ...@@ -71,7 +71,7 @@ class XPUUniformRandomInplaceKernel : public framework::OpKernel<T> {
} }
}; };
template <typename T> template <typename T, typename DeviceContext>
class XPUUniformRandomInplaceGradKernel : public framework::OpKernel<T> { class XPUUniformRandomInplaceGradKernel : public framework::OpKernel<T> {
public: public:
void Compute(const paddle::framework::ExecutionContext &ctx) const override { void Compute(const paddle::framework::ExecutionContext &ctx) const override {
...@@ -95,10 +95,15 @@ class XPUUniformRandomInplaceGradKernel : public framework::OpKernel<T> { ...@@ -95,10 +95,15 @@ class XPUUniformRandomInplaceGradKernel : public framework::OpKernel<T> {
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
REGISTER_OP_XPU_KERNEL(uniform_random_inplace, PD_REGISTER_STRUCT_KERNEL(uniform_random_inplace,
paddle::operators::XPUUniformRandomInplaceKernel<float>); XPU,
REGISTER_OP_XPU_KERNEL( ALL_LAYOUT,
uniform_random_inplace_grad, ops::XPUUniformRandomInplaceKernel,
paddle::operators::XPUUniformRandomInplaceGradKernel<float>); float) {}
PD_REGISTER_STRUCT_KERNEL(uniform_random_inplace_grad,
XPU,
ALL_LAYOUT,
ops::XPUUniformRandomInplaceGradKernel,
float) {}
#endif // PADDLE_WITH_XPU #endif // PADDLE_WITH_XPU
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册