diff --git a/paddle/fluid/operators/activation_op_npu.cc b/paddle/fluid/operators/activation_op_npu.cc
index 8f6af4260dcc965dddf3af3fe82e26e9c7b6cc6d..d815a3eeb4d81c70f7eb6ab729afee6b04ffe12f 100755
--- a/paddle/fluid/operators/activation_op_npu.cc
+++ b/paddle/fluid/operators/activation_op_npu.cc
@@ -207,6 +207,47 @@ class SqrtNPUKernel : public framework::OpKernel<T> {
   }
 };
 
+template <typename DeviceContext, typename T>
+class LeakyReluNPUKernel : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext& ctx) const override {
+    auto* x = ctx.Input<Tensor>("X");
+    auto* out = ctx.Output<Tensor>("Out");
+    auto alpha = ctx.Attr<float>("alpha");
+
+    out->mutable_data<T>(ctx.GetPlace());
+
+    auto stream =
+        ctx.template device_context<paddle::platform::NPUDeviceContext>()
+            .stream();
+
+    const auto& runner =
+        NpuOpRunner("LeakyRelu", {*x}, {*out}, {{"negative_slope", alpha}});
+    runner.Run(stream);
+  }
+};
+
+template <typename DeviceContext, typename T>
+class LeakyReluGradNPUKernel : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext& ctx) const override {
+    auto* x = ctx.Input<Tensor>("X");
+    auto* dout = ctx.Input<Tensor>(framework::GradVarName("Out"));
+    auto* dx = ctx.Output<Tensor>(framework::GradVarName("X"));
+    auto alpha = ctx.Attr<float>("alpha");
+
+    auto stream =
+        ctx.template device_context<paddle::platform::NPUDeviceContext>()
+            .stream();
+
+    dx->mutable_data<T>(ctx.GetPlace());
+    const auto& runner = NpuOpRunner("LeakyReluGrad", {*dout, *x}, {*dx},
+                                     {{"negative_slope", alpha}});
+
+    runner.Run(stream);
+  }
+};
+
 template <typename DeviceContext, typename T>
 class SqrtGradNPUKernel : public framework::OpKernel<T> {
  public:
@@ -778,6 +819,18 @@ REGISTER_OP_NPU_KERNEL(
     ops::Relu6GradNPUKernel<paddle::platform::NPUDeviceContext,
                             paddle::platform::float16>);
 
+REGISTER_OP_NPU_KERNEL(
+    leaky_relu,
+    ops::LeakyReluNPUKernel<paddle::platform::NPUDeviceContext, float>,
+    ops::LeakyReluNPUKernel<paddle::platform::NPUDeviceContext,
+                            paddle::platform::float16>);
+
+REGISTER_OP_NPU_KERNEL(
+    leaky_relu_grad,
+    ops::LeakyReluGradNPUKernel<paddle::platform::NPUDeviceContext, float>,
+    ops::LeakyReluGradNPUKernel<paddle::platform::NPUDeviceContext,
+                                paddle::platform::float16>);
+
 REGISTER_OP_NPU_KERNEL(
     sqrt, ops::SqrtNPUKernel<paddle::platform::NPUDeviceContext, float>,
     ops::SqrtNPUKernel<paddle::platform::NPUDeviceContext,