diff --git a/paddle/fluid/operators/activation_op_mlu.cc b/paddle/fluid/operators/activation_op_mlu.cc
index e19ce87e7c8ecdb2c212e19055451b876b55c054..6ba86351e6af55d2d4bf8a8bbb5fb28e30da596d 100644
--- a/paddle/fluid/operators/activation_op_mlu.cc
+++ b/paddle/fluid/operators/activation_op_mlu.cc
@@ -256,6 +256,149 @@ class ExpGradMLUKernel : public framework::OpKernel<T> {
   }
 };
 
+template <typename T>
+class HardSwishMLUKernel : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext& ctx) const override {
+    auto* input = ctx.Input<Tensor>("X");
+    auto* output = ctx.Output<Tensor>("Out");
+    output->mutable_data<T>(ctx.GetPlace());
+    float threshold = ctx.Attr<float>("threshold");
+    float scale = ctx.Attr<float>("scale");
+    float offset = ctx.Attr<float>("offset");
+    PADDLE_ENFORCE_EQ(threshold,
+                      6.0f,
+                      platform::errors::External(
+                          "Not support threshold [%f] in MLU", threshold));
+    PADDLE_ENFORCE_EQ(
+        scale,
+        6.0f,
+        platform::errors::External("Not support scale [%f] in MLU", scale));
+    PADDLE_ENFORCE_EQ(
+        offset,
+        3.0f,
+        platform::errors::External("Not support offset [%f] in MLU", offset));
+
+    MLUCnnlActivationDesc act_desc(CNNL_ACTIVATION_HARDSWISH,
+                                   1.0f /*ceof useless*/);
+    MLUCnnlTensorDesc input_desc(*input);
+    MLUCnnlTensorDesc output_desc(*output);
+
+    MLUCnnl::Active(ctx,
+                    act_desc.get(),
+                    input_desc.get(),
+                    GetBasePtr(input),
+                    output_desc.get(),
+                    GetBasePtr(output));
+  }
+};
+
+template <typename T>
+class HardSwishGradMLUKernel : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext& ctx) const override {
+    float threshold = ctx.Attr<float>("threshold");
+    float scale = ctx.Attr<float>("scale");
+    float offset = ctx.Attr<float>("offset");
+    PADDLE_ENFORCE_EQ(threshold,
+                      6.0f,
+                      platform::errors::External(
+                          "Not support threshold [%f] in MLU", threshold));
+    PADDLE_ENFORCE_EQ(
+        scale,
+        6.0f,
+        platform::errors::External("Not support scale [%f] in MLU", scale));
+    PADDLE_ENFORCE_EQ(
+        offset,
+        3.0f,
+        platform::errors::External("Not support offset [%f] in MLU", offset));
+    auto* out = ctx.Input<Tensor>("X");
+    auto* dout = ctx.Input<Tensor>(framework::GradVarName("Out"));
+    auto* dx = ctx.Output<Tensor>(framework::GradVarName("X"));
+
+    dx->mutable_data<T>(ctx.GetPlace());
+
+    MLUCnnlTensorDesc out_desc(*out);
+    MLUCnnlTensorDesc dout_desc(*dout);
+    MLUCnnlTensorDesc dx_desc(*dx);
+    MLUCnnlActivationDesc act_desc(CNNL_ACTIVATION_HARDSWISH,
+                                   1.0f /*ceof useless*/);
+    MLUCnnl::ActiveGrad(ctx,
+                        act_desc.get(),
+                        nullptr,
+                        nullptr,
+                        nullptr,
+                        nullptr,
+                        dout_desc.get(),
+                        GetBasePtr(dout),
+                        out_desc.get(),
+                        GetBasePtr(out),
+                        dx_desc.get(),
+                        GetBasePtr(dx));
+  }
+};
+
+template <typename T>
+class HardSigmoidMLUKernel : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext& ctx) const override {
+    auto* input = ctx.Input<Tensor>("X");
+    auto* output = ctx.Output<Tensor>("Out");
+    float slope = ctx.Attr<float>("slope");
+    float offset = ctx.Attr<float>("offset");
+    output->mutable_data<T>(ctx.GetPlace());
+
+    MLUCnnlActivationDesc act_desc(CNNL_ACTIVATION_HARDSIGMOID,
+                                   1.0f /*ceof useless*/,
+                                   1.0f /*sliced_dim useless*/,
+                                   slope,
+                                   offset);
+    MLUCnnlTensorDesc input_desc(*input);
+    MLUCnnlTensorDesc output_desc(*output);
+
+    MLUCnnl::Active(ctx,
+                    act_desc.get(),
+                    input_desc.get(),
+                    GetBasePtr(input),
+                    output_desc.get(),
+                    GetBasePtr(output));
+  }
+};
+
+template <typename T>
+class HardSigmoidGradMLUKernel : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext& ctx) const override {
+    auto* dout = ctx.Input<Tensor>(framework::GradVarName("Out"));
+    auto* out = ctx.Input<Tensor>("Out");
+    auto* dx = ctx.Output<Tensor>(framework::GradVarName("X"));
+    float slope = ctx.Attr<float>("slope");
+    float offset = ctx.Attr<float>("offset");
+    dx->mutable_data<T>(ctx.GetPlace());
+
+    MLUCnnlActivationDesc act_desc(CNNL_ACTIVATION_HARDSIGMOID,
+                                   1.0f /*ceof useless*/,
+                                   1.0f /*sliced_dim useless*/,
+                                   slope,
+                                   offset);
+    MLUCnnlTensorDesc out_desc(*out);
+    MLUCnnlTensorDesc dout_desc(*dout);
+    MLUCnnlTensorDesc dx_desc(*dx);
+    MLUCnnl::ActiveGrad(ctx,
+                        act_desc.get(),
+                        nullptr,
+                        nullptr,
+                        nullptr,
+                        nullptr,
+                        dout_desc.get(),
+                        GetBasePtr(dout),
+                        out_desc.get(),
+                        GetBasePtr(out),
+                        dx_desc.get(),
+                        GetBasePtr(dx));
+  }
+};
+
 }  // namespace operators
 }  // namespace paddle
@@ -359,3 +502,20 @@ REGISTER_OP_MLU_KERNEL(exp,
 REGISTER_OP_MLU_KERNEL(exp_grad,
                        ops::ExpGradMLUKernel<float>,
                        ops::ExpGradMLUKernel<paddle::platform::float16>);
+
+REGISTER_OP_MLU_KERNEL(hard_swish,
+                       ops::HardSwishMLUKernel<float>,
+                       ops::HardSwishMLUKernel<paddle::platform::float16>);
+
+REGISTER_OP_MLU_KERNEL(hard_swish_grad,
+                       ops::HardSwishGradMLUKernel<float>,
+                       ops::HardSwishGradMLUKernel<paddle::platform::float16>);
+
+REGISTER_OP_MLU_KERNEL(hard_sigmoid,
+                       ops::HardSigmoidMLUKernel<float>,
+                       ops::HardSigmoidMLUKernel<paddle::platform::float16>);
+
+REGISTER_OP_MLU_KERNEL(
+    hard_sigmoid_grad,
+    ops::HardSigmoidGradMLUKernel<float>,
+    ops::HardSigmoidGradMLUKernel<paddle::platform::float16>);
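
For reference, the element-wise definitions behind these kernels are hard_swish(x) = x * min(max(x + offset, 0), threshold) / scale and hard_sigmoid(x) = clip(slope * x + offset, 0, 1); the PADDLE_ENFORCE_EQ checks above reflect that only the default hard_swish attributes (threshold=6, scale=6, offset=3) are supported on MLU. A minimal NumPy sketch of those definitions (mirroring the reference functions used in the tests below, not the CNNL implementation itself; the helper names are illustrative only):

```python
import numpy as np


def hard_swish_ref(x, threshold=6.0, scale=6.0, offset=3.0):
    # hard_swish(x) = x * clip(x + offset, 0, threshold) / scale
    return x * np.clip(x + offset, 0.0, threshold) / scale


def hard_sigmoid_ref(x, slope=0.166666666666667, offset=0.5):
    # hard_sigmoid(x) = clip(slope * x + offset, 0, 1)
    return np.clip(slope * x + offset, 0.0, 1.0)


if __name__ == "__main__":
    x = np.random.uniform(-5, 5, [10, 12]).astype(np.float32)
    # Cross-check against the min/max formulation used elsewhere in the tests.
    assert np.allclose(hard_swish_ref(x),
                       x * np.minimum(np.maximum(x + 3.0, 0.0), 6.0) / 6.0)
    assert np.allclose(hard_sigmoid_ref(x),
                       np.maximum(np.minimum(x / 6.0 + 0.5, 1.0), 0.0))
```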
diff --git a/python/paddle/fluid/tests/unittests/mlu/test_hard_sigmoid_op_mlu.py b/python/paddle/fluid/tests/unittests/mlu/test_hard_sigmoid_op_mlu.py
new file mode 100644
index 0000000000000000000000000000000000000000..a38c12c9004708ededbfc7f8d246b8db01b3b32e
--- /dev/null
+++ b/python/paddle/fluid/tests/unittests/mlu/test_hard_sigmoid_op_mlu.py
@@ -0,0 +1,194 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import print_function
+
+import numpy as np
+import unittest
+import sys
+
+sys.path.append("..")
+from op_test import OpTest
+import paddle
+import paddle.fluid as fluid
+import paddle.nn.functional as F
+
+paddle.enable_static()
+SEED = 2021
+
+
+def ref_hardsigmoid(x, slope=0.166666666666667, offset=0.5):
+    return np.maximum(np.minimum(x * slope + offset, 1.), 0.).astype(x.dtype)
+
+
+class TestMLUHardSigmoid(OpTest):
+
+    def setUp(self):
+        paddle.enable_static()
+
+        self.op_type = "hard_sigmoid"
+        self.set_mlu()
+        self.init_dtype()
+        self.set_attrs()
+
+        x = np.random.uniform(-5, 5, [10, 12]).astype(self.dtype)
+        lower_threshold = -self.offset / self.slope
+        upper_threshold = (1. - self.offset) / self.slope
+
+        # Same reason as TestAbs
+        delta = 0.005
+        x[np.abs(x - lower_threshold) < delta] = lower_threshold - 0.02
+        x[np.abs(x - upper_threshold) < delta] = upper_threshold - 0.02
+
+        out = ref_hardsigmoid(x, self.slope, self.offset)
+
+        self.attrs = {'slope': self.slope, 'offset': self.offset}
+        self.inputs = {'X': x}
+        self.outputs = {'Out': out}
+
+    def test_check_output(self):
+        self.check_output_with_place(self.place)
+
+    def test_check_grad(self):
+        self.check_grad_with_place(self.place, ['X'], 'Out')
+
+    def set_mlu(self):
+        self.__class__.use_mlu = True
+        self.place = paddle.MLUPlace(0)
+
+    def init_dtype(self):
+        self.dtype = np.float32
+
+    def set_attrs(self):
+        self.slope = 0.166666666666667
+        self.offset = 0.5
+
+
+class TestMLUHardSigmoid2(TestMLUHardSigmoid):
+
+    def set_attrs(self):
+        self.slope = 0.2
+        self.offset = 0.5
+
+
+class TestMLUHardSigmoid3(TestMLUHardSigmoid):
+
+    def set_attrs(self):
+        self.slope = 0.2
+        self.offset = 0.4
+
+
+class TestMLUHardSigmoidFp16(unittest.TestCase):
+
+    def setUp(self):
+        paddle.disable_static()
+
+        self.place = paddle.MLUPlace(0)
+        self.dtype = np.float32
+
+        # float32
+        self.float32_x = np.random.uniform(-5, 5, [10, 12]).astype(np.float32)
+        paddle.set_device('cpu')
+        data = paddle.to_tensor(self.float32_x, stop_gradient=True)
+        self.float32_y = F.hardsigmoid(data)
+
+        # float16
+        self.float16_x = self.float32_x.astype(np.float16)
+        self.float16_y = ref_hardsigmoid(self.float16_x)
+
+    def test_check_output_and_grad_mlu(self):
+        # mlu float16
+        paddle.set_device('mlu')
+        data = paddle.to_tensor(self.float16_x, stop_gradient=True)
+        mlu_float16_y = F.hardsigmoid(data)
+
+        cpu_diff_1 = np.divide(
+            np.sum(np.abs(self.float32_y.numpy() - self.float16_y)),
+            np.sum(np.abs(self.float32_y.numpy())))
+        mlu_diff_1 = np.divide(
+            np.sum(np.abs(self.float32_y.numpy() - mlu_float16_y.numpy())),
+            np.sum(np.abs(self.float32_y.numpy())))
+
+        cpu_diff_2 = np.divide(
+            np.sum(np.square(self.float32_y.numpy() - self.float16_y)),
+            np.sum(np.square(self.float32_y.numpy())))
+        mlu_diff_2 = np.divide(
+            np.sum(np.square(self.float32_y.numpy() - mlu_float16_y.numpy())),
+            np.sum(np.square(self.float32_y.numpy())))
+        assert mlu_diff_1 <= cpu_diff_1
+        assert mlu_diff_2 <= cpu_diff_2
+
+
+class TestHardsigmoidAPI(unittest.TestCase):
+    # test paddle.nn.Hardsigmoid, paddle.nn.functional.hardsigmoid
+    def setUp(self):
+        self.x_np = np.random.uniform(-1, 1, [10, 12]).astype(np.float32)
+        self.place = paddle.MLUPlace(0)
+
+    def test_static_api(self):
+        with paddle.static.program_guard(paddle.static.Program()):
+            x = paddle.static.data('X', self.x_np.shape, self.x_np.dtype)
+            out1 = F.hardsigmoid(x)
+            m = paddle.nn.Hardsigmoid()
+            out2 = m(x)
+            exe = paddle.static.Executor(self.place)
+            res = exe.run(feed={'X': self.x_np}, fetch_list=[out1, out2])
+        out_ref = ref_hardsigmoid(self.x_np)
+        for r in res:
+            self.assertTrue(np.allclose(out_ref, r))
+
+    def test_dygraph_api(self):
+        paddle.disable_static(self.place)
+        x = paddle.to_tensor(self.x_np)
+        out1 = F.hardsigmoid(x)
+        m = paddle.nn.Hardsigmoid()
+        out2 = m(x)
+        out_ref = ref_hardsigmoid(self.x_np)
+        for r in [out1, out2]:
+            self.assertTrue(np.allclose(out_ref, r.numpy()))
+        paddle.enable_static()
+
+    def test_fluid_api(self):
+        with fluid.program_guard(fluid.Program()):
+            x = fluid.data('X', self.x_np.shape, self.x_np.dtype)
+            out = fluid.layers.hard_sigmoid(x)
+            exe = fluid.Executor(self.place)
+            res = exe.run(feed={'X': self.x_np}, fetch_list=[out])
+        out_ref = ref_hardsigmoid(self.x_np, 0.2, 0.5)
+        self.assertTrue(np.allclose(out_ref, res[0]))
+
+        paddle.disable_static(self.place)
+        x = paddle.to_tensor(self.x_np)
+        out = paddle.fluid.layers.hard_sigmoid(x)
+        self.assertTrue(np.allclose(out_ref, out.numpy()))
+        paddle.enable_static()
+
+    def test_errors(self):
+        with paddle.static.program_guard(paddle.static.Program()):
+            # The input type must be Variable.
+            self.assertRaises(TypeError, F.hardsigmoid, 1)
+            # The input dtype must be float16, float32, float64.
+            x_int32 = paddle.fluid.data(name='x_int32',
+                                        shape=[12, 10],
+                                        dtype='int32')
+            self.assertRaises(TypeError, F.hardsigmoid, x_int32)
+            # support the input dtype is float16
+            x_fp16 = paddle.fluid.data(name='x_fp16',
+                                       shape=[12, 10],
+                                       dtype='float16')
+            F.hardsigmoid(x_fp16)
+
+
+if __name__ == '__main__':
+    unittest.main()
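
Both float16 tests in this PR (TestMLUHardSigmoidFp16 above and TestHardSwishMLUWithCPUFloat16 below) use the same acceptance criterion: measured against the CPU float32 result, the MLU float16 output must be at least as accurate as a plain NumPy float16 reference, under both an L1-style and an L2-style relative error. A small helper capturing the metric the tests inline (hypothetical, for illustration only):

```python
import numpy as np


def relative_errors(ref_fp32, approx):
    # L1- and L2-style relative errors of `approx` against a float32 reference.
    ref = np.asarray(ref_fp32, dtype=np.float32)
    diff = ref - np.asarray(approx, dtype=np.float32)
    l1 = np.sum(np.abs(diff)) / np.sum(np.abs(ref))
    l2 = np.sum(np.square(diff)) / np.sum(np.square(ref))
    return l1, l2
```

With such a helper, each assertion reduces to `mlu_l1 <= cpu_l1` and `mlu_l2 <= cpu_l2`, applied to the forward output (and, for hard_swish, also to the gradient).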
diff --git a/python/paddle/fluid/tests/unittests/mlu/test_hard_swish_op_mlu.py b/python/paddle/fluid/tests/unittests/mlu/test_hard_swish_op_mlu.py
new file mode 100644
index 0000000000000000000000000000000000000000..e0ae182b41d19e21b852a317a70022c4275c1b6b
--- /dev/null
+++ b/python/paddle/fluid/tests/unittests/mlu/test_hard_swish_op_mlu.py
@@ -0,0 +1,165 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import print_function
+
+import paddle.nn.functional as F
+import paddle.fluid as fluid
+import paddle
+
+import numpy as np
+import unittest
+import sys
+
+sys.path.append("..")
+from op_test import OpTest
+
+paddle.enable_static()
+SEED = 2020
+
+
+def scalarToType(val, data_type):
+    converted_val = np.array([val]).astype(data_type)[0]
+    print("converted_val type: ", type(converted_val))
+    return converted_val
+
+
+def ref_hard_swish_grad(x, threshold, scale, offset, data_type):
+    threshold = scalarToType(threshold, data_type)
+    scale = scalarToType(scale, data_type)
+    offset = scalarToType(offset, data_type)
+    dout = np.full_like(x, fill_value=1. / x.size)
+    tmp = ((x + offset) < threshold).astype(x.dtype)
+    dx = dout * (((x + offset) > 0).astype(x.dtype) *
+                 (2 * x + offset) * tmp / scale + 1.0 - tmp)
+    return dx
+
+
+class TestHardSwishMLU(OpTest):
+
+    def setUp(self):
+        paddle.enable_static()
+
+        self.op_type = "hard_swish"
+        self.place = paddle.MLUPlace(0)
+        self.init_dtype()
+
+        x = np.random.uniform(-2, 2, [10, 12]).astype(self.dtype)
+        threshold = 6.0
+        scale = 6.0
+        offset = 3.0
+
+        x[np.abs(x + offset) < 0.005] = 0.02
+        x[np.abs(x - threshold + offset) < 0.005] = threshold - offset + 0.02
+
+        out = (x * (np.minimum(np.maximum(x + offset, 0.), threshold) /
+                    scale)).astype(self.dtype)
+        self.x_grad = ref_hard_swish_grad(x, threshold, scale, offset,
+                                          self.dtype)
+        self.set_mlu()
+        self.inputs = {'X': x}
+        self.attrs = {'threshold': threshold, 'scale': scale, 'offset': offset}
+        self.outputs = {'Out': out}
+
+    def set_mlu(self):
+        self.__class__.use_mlu = True
+
+    def init_dtype(self):
+        self.dtype = np.float32
+
+    def test_check_output(self):
+        self.check_output_with_place(self.place)
+
+    def test_check_grad(self):
+        self.check_grad_with_place(self.place, ['X'], 'Out')
+
+
+class TestHardSwishMLUWithCPUFloat16(unittest.TestCase):
+
+    def setUp(self):
+        paddle.disable_static()
+
+        self.place = paddle.MLUPlace(0)
+        self.dtype = np.float32
+
+        # float32
+        self.float32_x = np.random.uniform(-6, 10, [8, 15]).astype(np.float32)
+        paddle.set_device('cpu')
+        data = paddle.to_tensor(self.float32_x, stop_gradient=False)
+        self.float32_y = F.hardswish(data)
+        self.float32_y.sum().backward()
+        self.float32_grad = data.grad
+
+        # float16
+        self.float16_x = self.float32_x.astype(np.float16)
+        threshold = 6.0
+        scale = 6.0
+        offset = 3.0
+
+        threshold = scalarToType(threshold, np.float16)
+        scale = scalarToType(scale, np.float16)
+        offset = scalarToType(offset, np.float16)
+        self.float16_y = (self.float16_x * (np.minimum(
+            np.maximum(self.float16_x + offset, scalarToType(0., np.float16)),
+            threshold) / scale)).astype(np.float16)
+        self.float16_grad = ref_hard_swish_grad(self.float16_x, threshold,
+                                                scale, offset, np.float16)
+
+    def test_check_output_and_grad_mlu(self):
+        # mlu float16
+        paddle.set_device('mlu')
+        data = paddle.to_tensor(self.float16_x, stop_gradient=False)
+        mlu_float16_y = F.hardswish(data)
+        mlu_float16_y.sum().backward()
+        mlu_float16_grad = data.grad
+
+        cpu_diff_1 = np.divide(
+            np.sum(np.abs(self.float32_y.numpy() - self.float16_y)),
+            np.sum(np.abs(self.float32_y.numpy())))
+        mlu_diff_1 = np.divide(
+            np.sum(np.abs(self.float32_y.numpy() - mlu_float16_y.numpy())),
+            np.sum(np.abs(self.float32_y.numpy())))
+
+        cpu_diff_2 = np.divide(
+            np.sum(np.square(self.float32_y.numpy() - self.float16_y)),
+            np.sum(np.square(self.float32_y.numpy())))
+        mlu_diff_2 = np.divide(
+            np.sum(np.square(self.float32_y.numpy() - mlu_float16_y.numpy())),
+            np.sum(np.square(self.float32_y.numpy())))
+        assert mlu_diff_1 <= cpu_diff_1
+        assert mlu_diff_2 <= cpu_diff_2
+
+        cpu_diff_1 = np.divide(
+            np.sum(np.abs(self.float32_grad.numpy() - self.float16_grad)),
+            np.sum(np.abs(self.float32_grad.numpy())))
+        mlu_diff_1 = np.divide(
+            np.sum(
+                np.abs(self.float32_grad.numpy() - mlu_float16_grad.numpy())),
+            np.sum(np.abs(self.float32_grad.numpy())))
+
+        cpu_diff_2 = np.divide(
+            np.sum(np.square(self.float32_grad.numpy() - self.float16_grad)),
+            np.sum(np.square(self.float32_grad.numpy())))
+        mlu_diff_2 = np.divide(
+            np.sum(
+                np.square(self.float32_grad.numpy() -
+                          mlu_float16_grad.numpy())),
+            np.sum(np.square(self.float32_grad.numpy())))
+        assert mlu_diff_1 <= cpu_diff_1
+        assert mlu_diff_2 <= cpu_diff_2
+
+
+if __name__ == '__main__':
+    unittest.main()
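
`ref_hard_swish_grad` above encodes the piecewise derivative of hard_swish: 0 where x + offset < 0, (2x + offset) / scale in the linear region 0 <= x + offset < threshold, and 1 (threshold / scale with the default attributes) above it, scaled by the upstream gradient `dout`. A standalone sanity check of that formula against central finite differences (hypothetical, assuming the default threshold=6, scale=6, offset=3; not part of this PR):

```python
import numpy as np


def hard_swish(x, threshold=6.0, scale=6.0, offset=3.0):
    return x * np.minimum(np.maximum(x + offset, 0.0), threshold) / scale


def analytic_grad(x, threshold=6.0, scale=6.0, offset=3.0):
    # Same piecewise form as ref_hard_swish_grad with dout == 1.
    tmp = ((x + offset) < threshold).astype(x.dtype)
    return ((x + offset) > 0).astype(x.dtype) * (
        2 * x + offset) * tmp / scale + 1.0 - tmp


def numeric_grad(x, eps=1e-4):
    return (hard_swish(x + eps) - hard_swish(x - eps)) / (2.0 * eps)


if __name__ == "__main__":
    x = np.random.uniform(-5, 5, [10, 12]).astype(np.float64)
    # Stay away from the kinks at x = -offset and x = threshold - offset.
    x = x[np.minimum(np.abs(x + 3.0), np.abs(x - 3.0)) > 1e-2]
    assert np.allclose(numeric_grad(x), analytic_grad(x), atol=1e-6)
```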