diff --git a/paddle/fluid/operators/activation_op_xpu.cc b/paddle/fluid/operators/activation_op_xpu.cc index 613eea90a6500dab4db55d6d44fb15cd1ed50e39..0e7136b9f6ce8fd59e031c23c3b4e5b94b0a300d 100644 --- a/paddle/fluid/operators/activation_op_xpu.cc +++ b/paddle/fluid/operators/activation_op_xpu.cc @@ -166,6 +166,24 @@ struct XPUReluGradFunctor : public BaseActivationFunctor { } }; +template +struct XPURelu6Functor : public BaseActivationFunctor { + using XPUType = typename XPUTypeTrait::Type; + void operator()(const framework::ExecutionContext &ctx) const { + xpu_activation_forward( + ctx, xpu::relu6); + } +}; + +template +struct XPURelu6GradFunctor : public BaseActivationFunctor { + using XPUType = typename XPUTypeTrait::Type; + void operator()(const framework::ExecutionContext &ctx) const { + xpu_activation_backward( + ctx, xpu::relu6_grad); + } +}; + template struct XPUSigmoidFunctor : public BaseActivationFunctor { using XPUType = typename XPUTypeTrait::Type; @@ -548,6 +566,10 @@ REGISTER_OP_XPU_KERNEL( ops::XPUActivationGradKernel>, ops::XPUActivationGradKernel< ops::XPUReluGradFunctor>); +REGISTER_OP_XPU_KERNEL(relu6, + ops::XPUActivationKernel>); +REGISTER_OP_XPU_KERNEL( + relu6_grad, ops::XPUActivationKernel>); REGISTER_OP_XPU_KERNEL( tanh, ops::XPUActivationKernel>, diff --git a/paddle/fluid/platform/device/xpu/xpu2_op_list.h b/paddle/fluid/platform/device/xpu/xpu2_op_list.h index 9f07f05ff7fa6dac5c8b90ab069a820e4b9cdb99..e7570de695f281b77479a94fad4af654035d0543 100644 --- a/paddle/fluid/platform/device/xpu/xpu2_op_list.h +++ b/paddle/fluid/platform/device/xpu/xpu2_op_list.h @@ -349,6 +349,8 @@ XPUOpMap& get_kl2_ops() { {"reduce_sum_grad", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, {"reduce_sum", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, + {"relu6", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, + {"relu6_grad", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, {"relu_grad", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()), pOpKernelType(vartype::FP16, XPUPlace())})},