diff --git a/paddle/fluid/operators/activation_op_npu.cc b/paddle/fluid/operators/activation_op_npu.cc
index 1ccd99c71f339a8711744c75f21fabde8fa3e6ad..ce629fa3a4190408515919b8bdbcaa0f4385b7ef 100644
--- a/paddle/fluid/operators/activation_op_npu.cc
+++ b/paddle/fluid/operators/activation_op_npu.cc
@@ -144,6 +144,47 @@ class ReluGradNPUKernel : public framework::OpKernel<T> {
   }
 };
 
+// Forward kernel for relu6: Out = min(max(X, 0), 6), computed on the NPU by
+// the CANN "Relu6" operator. Follows the same structure as ReluNPUKernel.
+template <typename DeviceContext, typename T>
+class Relu6NPUKernel : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext& ctx) const override {
+    auto* x = ctx.Input<Tensor>("X");
+    auto* out = ctx.Output<Tensor>("Out");
+
+    out->mutable_data<T>(ctx.GetPlace());
+
+    const auto& runner = NpuOpRunner("Relu6",
+                                     {
+                                         *x,
+                                     },
+                                     {*out}, {});
+
+    auto stream =
+        ctx.template device_context<paddle::platform::NPUDeviceContext>()
+            .stream();
+    runner.Run(stream);
+  }
+};
+
+// Backward kernel for relu6: dX = Relu6Grad(dOut, Out). The NPU "Relu6Grad"
+// operator takes the upstream gradient and the forward output (the gradient
+// is non-zero only where 0 < Out < 6).
+template <typename DeviceContext, typename T>
+class Relu6GradNPUKernel : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext& ctx) const override {
+    auto* out = ctx.Input<Tensor>("Out");
+    auto* dout = ctx.Input<Tensor>(framework::GradVarName("Out"));
+    auto* dx = ctx.Output<Tensor>(framework::GradVarName("X"));
+
+    auto stream =
+        ctx.template device_context<paddle::platform::NPUDeviceContext>()
+            .stream();
+
+    dx->mutable_data<T>(ctx.GetPlace());
+    const auto& runner = NpuOpRunner("Relu6Grad", {*dout, *out}, {*dx}, {});
+
+    runner.Run(stream);
+  }
+};
+
 template <typename DeviceContext, typename T>
 class SqrtNPUKernel : public framework::OpKernel<T> {
  public:
@@ -457,6 +498,17 @@ REGISTER_OP_NPU_KERNEL(
     ops::ReluGradNPUKernel<paddle::platform::NPUDeviceContext,
                            paddle::platform::float16>);
 
+REGISTER_OP_NPU_KERNEL(
+    relu6, ops::Relu6NPUKernel<paddle::platform::NPUDeviceContext, float>,
+    ops::Relu6NPUKernel<paddle::platform::NPUDeviceContext,
+                        paddle::platform::float16>);
+
+REGISTER_OP_NPU_KERNEL(
+    relu6_grad,
+    ops::Relu6GradNPUKernel<paddle::platform::NPUDeviceContext, float>,
+    ops::Relu6GradNPUKernel<paddle::platform::NPUDeviceContext,
+                            paddle::platform::float16>);
+
 REGISTER_OP_NPU_KERNEL(
     sqrt, ops::SqrtNPUKernel<paddle::platform::NPUDeviceContext, float>,
     ops::SqrtNPUKernel<paddle::platform::NPUDeviceContext, float16>