From fd92d949c48137d13d0c4aa1f0dfcf806ebedc4a Mon Sep 17 00:00:00 2001 From: zyfncg Date: Mon, 16 Aug 2021 11:32:56 +0800 Subject: [PATCH] Support npu op hard_swish and hard_swish_grad (#34608) * Support NPU OP hard_swish and hard_swish_grad * Support NPU OP hard_swish and hard_swish_grad * add the unittest to compare the result between npu ans cpu * format the prompt of exception * replace Min and Max op by ClipByValue op * fix the precision problem for fp16 * Using HardtanhGrad to improve performace --- paddle/fluid/operators/activation_op_npu.cc | 156 ++++++++++++++++++ .../unittests/npu/test_hard_swish_op_npu.py | 126 ++++++++++++++ 2 files changed, 282 insertions(+) create mode 100644 python/paddle/fluid/tests/unittests/npu/test_hard_swish_op_npu.py diff --git a/paddle/fluid/operators/activation_op_npu.cc b/paddle/fluid/operators/activation_op_npu.cc index 5cf70cc391..8f6af4260d 100755 --- a/paddle/fluid/operators/activation_op_npu.cc +++ b/paddle/fluid/operators/activation_op_npu.cc @@ -16,6 +16,7 @@ limitations under the Licnse. */ #include #include "paddle/fluid/framework/ddim.h" +#include "paddle/fluid/framework/framework.pb.h" #include "paddle/fluid/framework/tensor_util.h" #include "paddle/fluid/operators/activation_op.h" #include "paddle/fluid/operators/npu_op_runner.h" @@ -388,6 +389,155 @@ class SigmoidGradNPUKernel : public framework::OpKernel { } }; +// HardSwish = min(max(0, x+offset), threshold) * x / scale +template +class HardSwishNPUKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + auto* x = ctx.Input("X"); + auto* out = ctx.Output("Out"); + + float threshold = ctx.Attr("threshold"); + float scale = ctx.Attr("scale"); + float offset = ctx.Attr("offset"); + + auto place = ctx.GetPlace(); + + out->mutable_data(place); + + auto stream = + ctx.template device_context() + .stream(); + + Tensor tensor_offset(x->type()); + tensor_offset.mutable_data({1}, place); + FillNpuTensorWithConstant(&tensor_offset, static_cast(offset)); + + Tensor add_offset_val(x->type()); + add_offset_val.mutable_data(x->dims(), place); + const auto& runner_add = + NpuOpRunner("AddV2", {*x, tensor_offset}, {add_offset_val}); + runner_add.Run(stream); + + Tensor tensor_threshold(x->type()); + tensor_threshold.mutable_data({1}, place); + FillNpuTensorWithConstant(&tensor_threshold, static_cast(threshold)); + + Tensor tensor_zero(x->type()); + tensor_zero.mutable_data({1}, place); + FillNpuTensorWithConstant(&tensor_zero, static_cast(0.0)); + + Tensor clip_val(x->type()); + clip_val.mutable_data(x->dims(), place); + const auto& runner_clip = NpuOpRunner( + "ClipByValue", {add_offset_val, tensor_zero, tensor_threshold}, + {clip_val}); + runner_clip.Run(stream); + + Tensor tensor_scale_tmp(x->type()); + tensor_scale_tmp.mutable_data({1}, place); + FillNpuTensorWithConstant(&tensor_scale_tmp, static_cast(scale)); + Tensor tensor_scale(x->type()); + tensor_scale.mutable_data(x->dims(), place); + const auto& runner_fill = + NpuOpRunner("FillD", {tensor_scale_tmp}, {tensor_scale}, + {{"dims", framework::vectorize(x->dims())}}); + runner_fill.Run(stream); + + Tensor div_val(x->type()); + div_val.mutable_data(x->dims(), place); + const auto& runner_div = + NpuOpRunner("Div", {clip_val, tensor_scale}, {div_val}); + runner_div.Run(stream); + + const auto& runner_mul = NpuOpRunner("Mul", {*x, div_val}, {*out}); + runner_mul.Run(stream); + } +}; + +template +class HardSwishGradNPUKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + auto* x = ctx.Input("X"); + auto* dout = ctx.Input(framework::GradVarName("Out")); + auto* dx = ctx.Output(framework::GradVarName("X")); + + float threshold = ctx.Attr("threshold"); + float scale = ctx.Attr("scale"); + float offset = ctx.Attr("offset"); + + auto place = ctx.GetPlace(); + + dx->mutable_data(place); + + auto stream = + ctx.template device_context() + .stream(); + + Tensor tensor_offset(x->type()); + tensor_offset.mutable_data({1}, place); + FillNpuTensorWithConstant(&tensor_offset, static_cast(offset)); + + Tensor add_offset_val(x->type()); + add_offset_val.mutable_data(x->dims(), place); + const auto& runner_add = + NpuOpRunner("AddV2", {*x, tensor_offset}, {add_offset_val}); + runner_add.Run(stream); + + Tensor tmp1(x->type()); + tmp1.mutable_data(x->dims(), place); + const auto& runner_pow1 = NpuOpRunner("Power", {*x}, {tmp1}, + {{"scale", 2.0f}, {"shift", offset}}); + runner_pow1.Run(stream); + + Tensor tmp2(x->type()); + tmp2.mutable_data(x->dims(), place); + const auto& runner_ht_grad = + NpuOpRunner("HardtanhGrad", {add_offset_val, tmp1}, {tmp2}, + {{"min_val", 0.0f}, {"max_val", threshold}}); + runner_ht_grad.Run(stream); + + Tensor tmp3(x->type()); + tmp3.mutable_data(x->dims(), place); + const auto& runner_pow2 = NpuOpRunner( + "Power", {tmp2}, {tmp3}, {{"scale", 1.0f / scale}, {"shift", 1.0f}}); + runner_pow2.Run(stream); + + Tensor tensor_threshold_tmp(x->type()); + tensor_threshold_tmp.mutable_data({1}, place); + FillNpuTensorWithConstant(&tensor_threshold_tmp, + static_cast(threshold)); + Tensor tensor_threshold(x->type()); + tensor_threshold.mutable_data(x->dims(), place); + const auto& runner_fill = + NpuOpRunner("FillD", {tensor_threshold_tmp}, {tensor_threshold}, + {{"dims", framework::vectorize(x->dims())}}); + runner_fill.Run(stream); + + Tensor tmp_bool(framework::proto::VarType::BOOL); + tmp_bool.mutable_data(x->dims(), place); + const auto& runner_less = + NpuOpRunner("Less", {add_offset_val, tensor_threshold}, {tmp_bool}); + runner_less.Run(stream); + Tensor tmp4(x->type()); + tmp4.mutable_data(x->dims(), place); + auto dst_dtype = ConvertToNpuDtype(x->type()); + const auto& runner_cast = + NpuOpRunner("Cast", {tmp_bool}, {tmp4}, + {{"dst_type", static_cast(dst_dtype)}}); + runner_cast.Run(stream); + + Tensor tmp5(x->type()); + tmp5.mutable_data(x->dims(), place); + const auto& runner_sub = NpuOpRunner("Sub", {tmp3, tmp4}, {tmp5}); + runner_sub.Run(stream); + + const auto& runner_final = NpuOpRunner("Mul", {tmp5, *dout}, {*dx}); + runner_final.Run(stream); + } +}; + template class HardSigmoidNPUKernel : public framework::OpKernel { public: @@ -677,6 +827,12 @@ REGISTER_OP_NPU_KERNEL( ops::SigmoidGradNPUKernel); +REGISTER_OP_NPU_KERNEL(hard_swish, ops::HardSwishNPUKernel, + ops::HardSwishNPUKernel); + +REGISTER_OP_NPU_KERNEL(hard_swish_grad, ops::HardSwishGradNPUKernel, + ops::HardSwishGradNPUKernel); + REGISTER_OP_NPU_KERNEL( hard_sigmoid, ops::HardSigmoidNPUKernel, diff --git a/python/paddle/fluid/tests/unittests/npu/test_hard_swish_op_npu.py b/python/paddle/fluid/tests/unittests/npu/test_hard_swish_op_npu.py new file mode 100644 index 0000000000..32042ba83a --- /dev/null +++ b/python/paddle/fluid/tests/unittests/npu/test_hard_swish_op_npu.py @@ -0,0 +1,126 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import numpy as np +import unittest +import sys +sys.path.append("..") +from op_test import OpTest +import paddle +import paddle.fluid as fluid +import paddle.nn.functional as F + + +def ref_hard_swish_grad(x, threshold=6.0, scale=6.0, offset=3.0): + dout = np.full_like(x, fill_value=1. / x.size) + tmp = ((x + offset) < threshold).astype(x.dtype) + dx = dout * (((x + offset) > 0).astype(x.dtype) * + (2 * x + offset) * tmp / scale + 1.0 - tmp) + return dx + + +class TestHardSwishNPU(OpTest): + def setUp(self): + paddle.enable_static() + + self.set_npu() + self.op_type = "hard_swish" + self.place = paddle.NPUPlace(0) + self.init_dtype() + + x = np.random.uniform(-6, 6, [10, 12]).astype(self.dtype) + threshold = 6.0 + scale = 6.0 + offset = 3.0 + #the same with TestAbs + x[np.abs(x + offset) < 0.005] = 0.02 + x[np.abs(x - threshold + offset) < 0.005] = threshold - offset + 0.02 + out = (x * (np.minimum(np.maximum(x + offset, 0.), threshold) / + scale)).astype(self.dtype) + self.x_grad = ref_hard_swish_grad(x, threshold, scale, offset) + + self.inputs = {'X': x} + self.attrs = {'threshold': threshold, 'scale': scale, 'offset': offset} + self.outputs = {'Out': out} + + def set_npu(self): + self.__class__.use_npu = True + + def init_dtype(self): + self.dtype = np.float32 + + def test_check_output(self): + self.check_output_with_place(self.place) + + def test_check_grad(self): + if self.dtype == np.float16: + return + # There is a problem that precision of grad result using float32 + # can't satisfy the default precision requirement + # when compared with numeric_grads, but the results on + # NPU and CPU are same (verified in TestHardSwishNPUWithCPU) + self.check_grad_with_place( + self.place, ['X'], 'Out', user_defined_grads=[self.x_grad]) + + +class TestHardSwishNPUFp16(TestHardSwishNPU): + def test_check_output(self): + self.check_output_with_place(self.place) + + def init_dtype(self): + self.dtype = np.float16 + + +# test the result of hard_swish and hard_swish_grad on CPU and NPU +class TestHardSwishNPUWithCPU(unittest.TestCase): + def setUp(self): + paddle.disable_static() + + self.place = paddle.NPUPlace(0) + self.dtype = np.float32 + + self.x = np.random.uniform(-6, 10, [8, 15]).astype(self.dtype) + + paddle.set_device('cpu') + + data = paddle.to_tensor(self.x, stop_gradient=False) + y = F.hardswish(data) + y.sum().backward() + + self.out_g = data.grad + self.out_y = y + + def test_check_output_and_grad_npu(self): + paddle.set_device('npu') + + data = paddle.to_tensor(self.x, stop_gradient=False) + y = F.hardswish(data) + y.sum().backward() + + self.assertTrue( + np.allclose(self.out_y.numpy(), y.numpy()), + "Output of NPU HardSwish forward has diff at " + str(self.place) + + "\nExpect " + str(self.out_y) + "\n" + "But Got" + str(y) + + " in class " + self.__class__.__name__ + ".") + self.assertTrue( + np.allclose(self.out_g.numpy(), data.grad.numpy()), + "Output of NPU HardSwish backward has diff at " + str(self.place) + + "\nExpect " + str(self.out_g) + "\n" + "But Got" + str(data.grad) + + " in class " + self.__class__.__name__ + ".") + + +if __name__ == '__main__': + unittest.main() -- GitLab