From 68399947faa96e73324b403383bad7a09c2dde4a Mon Sep 17 00:00:00 2001 From: WJJ1995 Date: Thu, 5 Aug 2021 18:00:12 +0800 Subject: [PATCH] [NPU] Add relu6 and relu6_grad npu op (#34596) * Add relu6 and relu6_grad npu op * fixed pre-commit-config.yaml * fixed for CI --- paddle/fluid/operators/activation_op_npu.cc | 52 ++++++ .../tests/unittests/npu/test_relu6_op_npu.py | 166 ++++++++++++++++++ 2 files changed, 218 insertions(+) create mode 100644 python/paddle/fluid/tests/unittests/npu/test_relu6_op_npu.py diff --git a/paddle/fluid/operators/activation_op_npu.cc b/paddle/fluid/operators/activation_op_npu.cc index 1ccd99c71f3..ce629fa3a41 100644 --- a/paddle/fluid/operators/activation_op_npu.cc +++ b/paddle/fluid/operators/activation_op_npu.cc @@ -144,6 +144,47 @@ class ReluGradNPUKernel : public framework::OpKernel { } }; +template +class Relu6NPUKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + auto* x = ctx.Input("X"); + auto* out = ctx.Output("Out"); + + out->mutable_data(ctx.GetPlace()); + + const auto& runner = NpuOpRunner("Relu6", + { + *x, + }, + {*out}, {}); + + auto stream = + ctx.template device_context() + .stream(); + runner.Run(stream); + } +}; + +template +class Relu6GradNPUKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + auto* out = ctx.Input("Out"); + auto* dout = ctx.Input(framework::GradVarName("Out")); + auto* dx = ctx.Output(framework::GradVarName("X")); + + auto stream = + ctx.template device_context() + .stream(); + + dx->mutable_data(ctx.GetPlace()); + const auto& runner = NpuOpRunner("Relu6Grad", {*dout, *out}, {*dx}, {}); + + runner.Run(stream); + } +}; + template class SqrtNPUKernel : public framework::OpKernel { public: @@ -457,6 +498,17 @@ REGISTER_OP_NPU_KERNEL( ops::ReluGradNPUKernel); +REGISTER_OP_NPU_KERNEL( + relu6, ops::Relu6NPUKernel, + ops::Relu6NPUKernel); + +REGISTER_OP_NPU_KERNEL( + relu6_grad, + ops::Relu6GradNPUKernel, + ops::Relu6GradNPUKernel); + REGISTER_OP_NPU_KERNEL( sqrt, ops::SqrtNPUKernel, ops::SqrtNPUKernel