From 40f627370ba0a1ea75f864a9107b5d7a979a911d Mon Sep 17 00:00:00 2001 From: Jackwaterveg <87408988+Jackwaterveg@users.noreply.github.com> Date: Wed, 18 Aug 2021 16:44:38 +0800 Subject: [PATCH] [NPU] Add leaky Relu (#34894) * test=develop * test=develop --- paddle/fluid/operators/activation_op_npu.cc | 53 +++++++ .../unittests/npu/test_leaky_relu_op_npu.py | 141 ++++++++++++++++++ 2 files changed, 194 insertions(+) create mode 100644 python/paddle/fluid/tests/unittests/npu/test_leaky_relu_op_npu.py diff --git a/paddle/fluid/operators/activation_op_npu.cc b/paddle/fluid/operators/activation_op_npu.cc index 8f6af4260dc..d815a3eeb4d 100755 --- a/paddle/fluid/operators/activation_op_npu.cc +++ b/paddle/fluid/operators/activation_op_npu.cc @@ -207,6 +207,47 @@ class SqrtNPUKernel : public framework::OpKernel { } }; +template +class LeakyReluNPUKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + auto* x = ctx.Input("X"); + auto* out = ctx.Output("Out"); + auto alpha = ctx.Attr("alpha"); + + out->mutable_data(ctx.GetPlace()); + + auto stream = + ctx.template device_context() + .stream(); + + const auto& runner = + NpuOpRunner("LeakyRelu", {*x}, {*out}, {{"negative_slope", alpha}}); + runner.Run(stream); + } +}; + +template +class LeakyReluGradNPUKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + auto* x = ctx.Input("X"); + auto* dout = ctx.Input(framework::GradVarName("Out")); + auto* dx = ctx.Output(framework::GradVarName("X")); + auto alpha = ctx.Attr("alpha"); + + auto stream = + ctx.template device_context() + .stream(); + + dx->mutable_data(ctx.GetPlace()); + const auto& runner = NpuOpRunner("LeakyReluGrad", {*dout, *x}, {*dx}, + {{"negative_slope", alpha}}); + + runner.Run(stream); + } +}; + template class SqrtGradNPUKernel : public framework::OpKernel { public: @@ -778,6 +819,18 @@ REGISTER_OP_NPU_KERNEL( ops::Relu6GradNPUKernel); +REGISTER_OP_NPU_KERNEL( + leaky_relu, + ops::LeakyReluNPUKernel, + ops::LeakyReluNPUKernel); + +REGISTER_OP_NPU_KERNEL( + leaky_relu_grad, + ops::LeakyReluGradNPUKernel, + ops::LeakyReluGradNPUKernel); + REGISTER_OP_NPU_KERNEL( sqrt, ops::SqrtNPUKernel, ops::SqrtNPUKernel