Unverified commit fd92d949 authored by zyfncg, committed by GitHub

Support NPU op hard_swish and hard_swish_grad (#34608)

* Support NPU OP hard_swish and hard_swish_grad

* Support NPU OP hard_swish and hard_swish_grad

* Add a unit test to compare the results between NPU and CPU

* Format the exception message

* Replace the Min and Max ops with the ClipByValue op

* Fix the precision problem for fp16

* Use HardtanhGrad to improve performance
Parent ad6c3b92
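For reference, here is a minimal NumPy sketch (not part of the patch) of the forward and backward formulas that the kernels below compose from basic NPU operators; the defaults threshold=6.0, scale=6.0, offset=3.0 match Paddle's hard_swish attributes:

import numpy as np

def hard_swish(x, threshold=6.0, scale=6.0, offset=3.0):
    # forward: out = x * clip(x + offset, 0, threshold) / scale
    return x * np.clip(x + offset, 0.0, threshold) / scale

def hard_swish_grad(dout, x, threshold=6.0, scale=6.0, offset=3.0):
    # backward: slope is 0 for x + offset <= 0, (2x + offset) / scale in the
    # linear region, and 1 once x + offset reaches threshold
    linear = ((x + offset) > 0) & ((x + offset) < threshold)
    saturated = ((x + offset) >= threshold).astype(x.dtype)
    return dout * (linear.astype(x.dtype) * (2 * x + offset) / scale + saturated)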
@@ -16,6 +16,7 @@ limitations under the Licnse. */
#include <string>
#include "paddle/fluid/framework/ddim.h"
#include "paddle/fluid/framework/framework.pb.h"
#include "paddle/fluid/framework/tensor_util.h"
#include "paddle/fluid/operators/activation_op.h"
#include "paddle/fluid/operators/npu_op_runner.h"
@@ -388,6 +389,155 @@ class SigmoidGradNPUKernel : public framework::OpKernel<T> {
}
};
// HardSwish = min(max(0, x+offset), threshold) * x / scale
template <typename T>
class HardSwishNPUKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto* x = ctx.Input<Tensor>("X");
auto* out = ctx.Output<Tensor>("Out");
float threshold = ctx.Attr<float>("threshold");
float scale = ctx.Attr<float>("scale");
float offset = ctx.Attr<float>("offset");
auto place = ctx.GetPlace();
out->mutable_data<T>(place);
auto stream =
ctx.template device_context<paddle::platform::NPUDeviceContext>()
.stream();
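    // add_offset_val = x + offset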
Tensor tensor_offset(x->type());
tensor_offset.mutable_data<T>({1}, place);
FillNpuTensorWithConstant<T>(&tensor_offset, static_cast<T>(offset));
Tensor add_offset_val(x->type());
add_offset_val.mutable_data<T>(x->dims(), place);
const auto& runner_add =
NpuOpRunner("AddV2", {*x, tensor_offset}, {add_offset_val});
runner_add.Run(stream);
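    // clip_val = clip(x + offset, 0, threshold)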
Tensor tensor_threshold(x->type());
tensor_threshold.mutable_data<T>({1}, place);
FillNpuTensorWithConstant<T>(&tensor_threshold, static_cast<T>(threshold));
Tensor tensor_zero(x->type());
tensor_zero.mutable_data<T>({1}, place);
FillNpuTensorWithConstant<T>(&tensor_zero, static_cast<T>(0.0));
Tensor clip_val(x->type());
clip_val.mutable_data<T>(x->dims(), place);
const auto& runner_clip = NpuOpRunner(
"ClipByValue", {add_offset_val, tensor_zero, tensor_threshold},
{clip_val});
runner_clip.Run(stream);
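    // FillD broadcasts the scalar scale to x's shape so Div is element-wise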
Tensor tensor_scale_tmp(x->type());
tensor_scale_tmp.mutable_data<T>({1}, place);
FillNpuTensorWithConstant<T>(&tensor_scale_tmp, static_cast<T>(scale));
Tensor tensor_scale(x->type());
tensor_scale.mutable_data<T>(x->dims(), place);
const auto& runner_fill =
NpuOpRunner("FillD", {tensor_scale_tmp}, {tensor_scale},
{{"dims", framework::vectorize(x->dims())}});
runner_fill.Run(stream);
Tensor div_val(x->type());
div_val.mutable_data<T>(x->dims(), place);
const auto& runner_div =
NpuOpRunner("Div", {clip_val, tensor_scale}, {div_val});
runner_div.Run(stream);
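    // out = x * (clip(x + offset, 0, threshold) / scale)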
const auto& runner_mul = NpuOpRunner("Mul", {*x, div_val}, {*out});
runner_mul.Run(stream);
}
};
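// Gradient of HardSwish with respect to x:
//   d(out)/d(x) = 0                          if x + offset <= 0
//               = (2 * x + offset) / scale   if 0 < x + offset < threshold
//               = 1                          if x + offset >= threshold
// Paddle's hard_swish gradient treats the saturated region as having slope 1;
// the kernel below reproduces this with Power, HardtanhGrad, Less and Cast.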
template <typename T>
class HardSwishGradNPUKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto* x = ctx.Input<Tensor>("X");
auto* dout = ctx.Input<Tensor>(framework::GradVarName("Out"));
auto* dx = ctx.Output<Tensor>(framework::GradVarName("X"));
float threshold = ctx.Attr<float>("threshold");
float scale = ctx.Attr<float>("scale");
float offset = ctx.Attr<float>("offset");
auto place = ctx.GetPlace();
dx->mutable_data<T>(place);
auto stream =
ctx.template device_context<paddle::platform::NPUDeviceContext>()
.stream();
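    // add_offset_val = x + offset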
Tensor tensor_offset(x->type());
tensor_offset.mutable_data<T>({1}, place);
FillNpuTensorWithConstant<T>(&tensor_offset, static_cast<T>(offset));
Tensor add_offset_val(x->type());
add_offset_val.mutable_data<T>(x->dims(), place);
const auto& runner_add =
NpuOpRunner("AddV2", {*x, tensor_offset}, {add_offset_val});
runner_add.Run(stream);
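    // tmp1 = 2 * x + offset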
Tensor tmp1(x->type());
tmp1.mutable_data<T>(x->dims(), place);
const auto& runner_pow1 = NpuOpRunner("Power", {*x}, {tmp1},
{{"scale", 2.0f}, {"shift", offset}});
runner_pow1.Run(stream);
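    // tmp2 = tmp1 where 0 < x + offset < threshold, 0 elsewhere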
Tensor tmp2(x->type());
tmp2.mutable_data<T>(x->dims(), place);
const auto& runner_ht_grad =
NpuOpRunner("HardtanhGrad", {add_offset_val, tmp1}, {tmp2},
{{"min_val", 0.0f}, {"max_val", threshold}});
runner_ht_grad.Run(stream);
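    // tmp3 = tmp2 / scale + 1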
Tensor tmp3(x->type());
tmp3.mutable_data<T>(x->dims(), place);
const auto& runner_pow2 = NpuOpRunner(
"Power", {tmp2}, {tmp3}, {{"scale", 1.0f / scale}, {"shift", 1.0f}});
runner_pow2.Run(stream);
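    // tmp4 = 1 where x + offset < threshold, 0 elsewhere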
Tensor tensor_threshold_tmp(x->type());
tensor_threshold_tmp.mutable_data<T>({1}, place);
FillNpuTensorWithConstant<T>(&tensor_threshold_tmp,
static_cast<T>(threshold));
Tensor tensor_threshold(x->type());
tensor_threshold.mutable_data<T>(x->dims(), place);
const auto& runner_fill =
NpuOpRunner("FillD", {tensor_threshold_tmp}, {tensor_threshold},
{{"dims", framework::vectorize(x->dims())}});
runner_fill.Run(stream);
Tensor tmp_bool(framework::proto::VarType::BOOL);
tmp_bool.mutable_data<bool>(x->dims(), place);
const auto& runner_less =
NpuOpRunner("Less", {add_offset_val, tensor_threshold}, {tmp_bool});
runner_less.Run(stream);
Tensor tmp4(x->type());
tmp4.mutable_data<T>(x->dims(), place);
auto dst_dtype = ConvertToNpuDtype(x->type());
const auto& runner_cast =
NpuOpRunner("Cast", {tmp_bool}, {tmp4},
{{"dst_type", static_cast<int>(dst_dtype)}});
runner_cast.Run(stream);
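    // tmp5 = tmp3 - tmp4, i.e. the piecewise derivative of hard_swish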
Tensor tmp5(x->type());
tmp5.mutable_data<T>(x->dims(), place);
const auto& runner_sub = NpuOpRunner("Sub", {tmp3, tmp4}, {tmp5});
runner_sub.Run(stream);
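    // dx = tmp5 * dout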
const auto& runner_final = NpuOpRunner("Mul", {tmp5, *dout}, {*dx});
runner_final.Run(stream);
}
};
template <typename DeviceContext, typename T>
class HardSigmoidNPUKernel : public framework::OpKernel<T> {
public:
@@ -677,6 +827,12 @@ REGISTER_OP_NPU_KERNEL(
ops::SigmoidGradNPUKernel<paddle::platform::NPUDeviceContext,
paddle::platform::float16>);
REGISTER_OP_NPU_KERNEL(hard_swish, ops::HardSwishNPUKernel<float>,
ops::HardSwishNPUKernel<paddle::platform::float16>);
REGISTER_OP_NPU_KERNEL(hard_swish_grad, ops::HardSwishGradNPUKernel<float>,
ops::HardSwishGradNPUKernel<paddle::platform::float16>);
REGISTER_OP_NPU_KERNEL(
hard_sigmoid,
ops::HardSigmoidNPUKernel<paddle::platform::NPUDeviceContext, float>,
......
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import numpy as np
import unittest
import sys
sys.path.append("..")
from op_test import OpTest
import paddle
import paddle.fluid as fluid
import paddle.nn.functional as F
def ref_hard_swish_grad(x, threshold=6.0, scale=6.0, offset=3.0):
dout = np.full_like(x, fill_value=1. / x.size)
tmp = ((x + offset) < threshold).astype(x.dtype)
dx = dout * (((x + offset) > 0).astype(x.dtype) *
(2 * x + offset) * tmp / scale + 1.0 - tmp)
return dx
class TestHardSwishNPU(OpTest):
def setUp(self):
paddle.enable_static()
self.set_npu()
self.op_type = "hard_swish"
self.place = paddle.NPUPlace(0)
self.init_dtype()
x = np.random.uniform(-6, 6, [10, 12]).astype(self.dtype)
threshold = 6.0
scale = 6.0
offset = 3.0
        # Nudge inputs away from the non-differentiable points (the same as in TestAbs)
x[np.abs(x + offset) < 0.005] = 0.02
x[np.abs(x - threshold + offset) < 0.005] = threshold - offset + 0.02
out = (x * (np.minimum(np.maximum(x + offset, 0.), threshold) /
scale)).astype(self.dtype)
self.x_grad = ref_hard_swish_grad(x, threshold, scale, offset)
self.inputs = {'X': x}
self.attrs = {'threshold': threshold, 'scale': scale, 'offset': offset}
self.outputs = {'Out': out}
def set_npu(self):
self.__class__.use_npu = True
def init_dtype(self):
self.dtype = np.float32
def test_check_output(self):
self.check_output_with_place(self.place)
def test_check_grad(self):
if self.dtype == np.float16:
return
        # The float32 grad result cannot meet the default precision
        # requirement when compared with numeric_grads, but the results
        # on NPU and CPU are the same (verified in TestHardSwishNPUWithCPU),
        # so user_defined_grads is used instead.
self.check_grad_with_place(
self.place, ['X'], 'Out', user_defined_grads=[self.x_grad])
class TestHardSwishNPUFp16(TestHardSwishNPU):
def test_check_output(self):
self.check_output_with_place(self.place)
def init_dtype(self):
self.dtype = np.float16
# test the result of hard_swish and hard_swish_grad on CPU and NPU
class TestHardSwishNPUWithCPU(unittest.TestCase):
def setUp(self):
paddle.disable_static()
self.place = paddle.NPUPlace(0)
self.dtype = np.float32
self.x = np.random.uniform(-6, 10, [8, 15]).astype(self.dtype)
paddle.set_device('cpu')
data = paddle.to_tensor(self.x, stop_gradient=False)
y = F.hardswish(data)
y.sum().backward()
self.out_g = data.grad
self.out_y = y
def test_check_output_and_grad_npu(self):
paddle.set_device('npu')
data = paddle.to_tensor(self.x, stop_gradient=False)
y = F.hardswish(data)
y.sum().backward()
self.assertTrue(
np.allclose(self.out_y.numpy(), y.numpy()),
"Output of NPU HardSwish forward has diff at " + str(self.place) +
"\nExpect " + str(self.out_y) + "\n" + "But Got" + str(y) +
" in class " + self.__class__.__name__ + ".")
self.assertTrue(
np.allclose(self.out_g.numpy(), data.grad.numpy()),
"Output of NPU HardSwish backward has diff at " + str(self.place) +
"\nExpect " + str(self.out_g) + "\n" + "But Got" + str(data.grad) +
" in class " + self.__class__.__name__ + ".")
if __name__ == '__main__':
unittest.main()