From 20f18930ae463f4eba1f8c0b682fb7db5ddbce33 Mon Sep 17 00:00:00 2001
From: huangjun12 <2399845970@qq.com>
Date: Mon, 12 Aug 2019 20:05:11 +0800
Subject: [PATCH] Add hard swish op (new op) (#19001)

* add hard_swish activation op (new op) test=develop

* remove redundancy files

* modify document content of HardSwish OP

* add API test in test_layers.py

* add dynamic_graph for test_hard_swish
---
 paddle/fluid/API.spec                         |  1 +
 paddle/fluid/operators/activation_op.cc       | 26 ++++++++++
 paddle/fluid/operators/activation_op.h        | 48 ++++++++++++++++++-
 python/paddle/fluid/layers/nn.py              | 36 ++++++++++++++
 .../tests/unittests/test_activation_op.py     | 25 ++++++++++
 .../fluid/tests/unittests/test_layers.py      | 14 ++++++
 6 files changed, 149 insertions(+), 1 deletion(-)
 mode change 100755 => 100644 paddle/fluid/operators/activation_op.cc

diff --git a/paddle/fluid/API.spec b/paddle/fluid/API.spec
index 3b73291ca5..bbb541a757 100644
--- a/paddle/fluid/API.spec
+++ b/paddle/fluid/API.spec
@@ -270,6 +270,7 @@ paddle.fluid.layers.unfold (ArgSpec(args=['x', 'kernel_sizes', 'strides', 'paddi
 paddle.fluid.layers.deformable_roi_pooling (ArgSpec(args=['input', 'rois', 'trans', 'no_trans', 'spatial_scale', 'group_size', 'pooled_height', 'pooled_width', 'part_size', 'sample_per_part', 'trans_std', 'position_sensitive', 'name'], varargs=None, keywords=None, defaults=(False, 1.0, [1, 1], 1, 1, None, 1, 0.1, False, None)), ('document', '99c03e3f249e36854f87dedaa17c8f35'))
 paddle.fluid.layers.var_conv_2d (ArgSpec(args=['input', 'row', 'col', 'input_channel', 'output_channel', 'filter_size', 'stride', 'param_attr', 'act', 'dtype', 'name'], varargs=None, keywords=None, defaults=(1, None, None, 'float32', None)), ('document', '7a8b8ade5512c95f9ea30261d33ded6c'))
 paddle.fluid.layers.shard_index (ArgSpec(args=['input', 'index_num', 'nshards', 'shard_id', 'ignore_value'], varargs=None, keywords=None, defaults=(-1,)), ('document', '5786fdbba6753ecd6cbce5e6b0889924'))
+paddle.fluid.layers.hard_swish (ArgSpec(args=['x', 'threshold', 'scale', 'offset', 'name'], varargs=None, keywords=None, defaults=(6.0, 6.0, 3.0, None)), ('document', '6a5152a7015c62cb8278fc24cb456459'))
 paddle.fluid.layers.data (ArgSpec(args=['name', 'shape', 'append_batch_size', 'dtype', 'lod_level', 'type', 'stop_gradient'], varargs=None, keywords=None, defaults=(True, 'float32', 0, VarType.LOD_TENSOR, True)), ('document', '9d7806e31bdf727c1a23b8782a09b545'))
 paddle.fluid.layers.open_files (ArgSpec(args=['filenames', 'shapes', 'lod_levels', 'dtypes', 'thread_num', 'buffer_size', 'pass_num', 'is_test'], varargs=None, keywords=None, defaults=(None, None, 1, None)), ('document', 'cccb6eb5410c822e5307c947aca2c899'))
 paddle.fluid.layers.read_file (ArgSpec(args=['reader'], varargs=None, keywords=None, defaults=None), ('document', '32181f6037e387fb6e68a5beaafe33b6'))
diff --git a/paddle/fluid/operators/activation_op.cc b/paddle/fluid/operators/activation_op.cc
old mode 100755
new mode 100644
index 7db6c6e676..75e7e240eb
--- a/paddle/fluid/operators/activation_op.cc
+++ b/paddle/fluid/operators/activation_op.cc
@@ -573,6 +573,32 @@ $$out = \\frac{x}{1 + e^{- \beta \ x}}$$
 }
 };
 
+class HardSwishOpMaker : public framework::OpProtoAndCheckerMaker {
+ public:
+  void Make() override {
+    AddInput("X", "Input of HardSwish operator");
+    AddOutput("Out", "Output of HardSwish operator");
+    AddAttr<float>("threshold", "The threshold parameter of HardSwish operator")
+        .SetDefault(6.0f);
+    AddAttr<float>("scale", "The scale parameter of HardSwish operator")
+        .SetDefault(6.0f);
+    AddAttr<float>("offset", "The offset parameter of HardSwish operator")
"The offset parameter of HardSwish operator") + .SetDefault(3.0f); + AddComment(R"DOC( +HardSwish Activation Operator. + +The hard version of swish(https://arxiv.org/pdf/1905.02244.pdf). + +$out = \frac{x * (min(max(0, x+offset), threshold))}{scale}$ + +The threshold and scale should be positive. The offset can be either positive or negative. +The default parameters are set according to the above reference. +It is recommended to use the defaults for this activation. + +)DOC"); + } +}; + REGISTER_ACTIVATION_OP_MAKER(Sigmoid, SigmoidDoc); REGISTER_ACTIVATION_OP_MAKER(LogSigmoid, LogSigmoidDoc); REGISTER_ACTIVATION_OP_MAKER(Exp, ExpDoc); diff --git a/paddle/fluid/operators/activation_op.h b/paddle/fluid/operators/activation_op.h index b4d01dfc6b..7afa7be253 100644 --- a/paddle/fluid/operators/activation_op.h +++ b/paddle/fluid/operators/activation_op.h @@ -919,6 +919,51 @@ struct Relu6GradFunctor : public BaseActivationFunctor { static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepOut; } }; +// HardSwish = min(max(0, x+3), 6) * x / 6 +template +struct HardSwishFunctor : public BaseActivationFunctor { + float threshold; + float scale; + float offset; + + typename BaseActivationFunctor::AttrPair GetAttrs() { + return {{"threshold", &threshold}, {"scale", &scale}, {"offset", &offset}}; + } + + template + void operator()(Device d, X x, Out out) const { + out.device(d) = (x + static_cast(offset)) + .cwiseMax(static_cast(0)) + .cwiseMin(static_cast(threshold)) * + x / static_cast(scale); + } +}; + +template +struct HardSwishGradFunctor : public BaseActivationFunctor { + float threshold; + float scale; + float offset; + + typename BaseActivationFunctor::AttrPair GetAttrs() { + return {{"threshold", &threshold}, {"scale", &scale}, {"offset", &offset}}; + } + template + void operator()(Device d, X x, Out out, dOut dout, dX dx) const { + auto tmp = ((x + static_cast(offset)) < static_cast(threshold)) + .template cast(); + dx.device(d) = + dout * + (((x + static_cast(offset)) > static_cast(0)).template cast() * + (static_cast(2) * x + static_cast(offset)) / + static_cast(scale) * tmp + + static_cast(1) * (static_cast(1) - tmp)); + } + + static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; } +}; + // softplus(x) = log(1 + exp(x)) // When x is a very large positive number, exp(x) may explode to inf, // Using trick below for numerical stability @@ -1580,4 +1625,5 @@ class SqrtDoubleGradKernel HardSigmoidGradFunctor); \ __macro(swish, Swish, SwishFunctor, SwishGradFunctor); \ __macro(thresholded_relu, ThresholdedRelu, ThresholdedReluFunctor, \ - ThresholdedReluGradFunctor); + ThresholdedReluGradFunctor); \ + __macro(hard_swish, HardSwish, HardSwishFunctor, HardSwishGradFunctor); diff --git a/python/paddle/fluid/layers/nn.py b/python/paddle/fluid/layers/nn.py index b269b01ee7..f01435ab92 100644 --- a/python/paddle/fluid/layers/nn.py +++ b/python/paddle/fluid/layers/nn.py @@ -213,6 +213,7 @@ __all__ = [ 'deformable_roi_pooling', 'var_conv_2d', 'shard_index', + 'hard_swish', ] kIgnoreIndex = -100 @@ -13100,3 +13101,38 @@ def shard_index(input, index_num, nshards, shard_id, ignore_value=-1): }, stop_gradient=True) return out + + +@templatedoc() +def hard_swish(x, threshold=6.0, scale=6.0, offset=3.0, name=None): + """ + ${comment} + Args: + x(Varaible): Input of HardSwish operator. + threshold(float): The threshold parameter of HardSwish operator. Default:threshold=6.0 + scale(float): The scale parameter of HardSwish operator. 
+        offset(float): The offset parameter of HardSwish operator. Default:offset=3.0
+        name(str|None): A name for this layer (optional). If set to None, the layer
+            will be named automatically.
+
+    Returns:
+        Variable: The output tensor with the same shape as input.
+
+    Examples:
+
+    .. code-block:: python
+
+        import paddle.fluid as fluid
+        x = fluid.layers.data(name="x", shape=[3,10,32,32], dtype="float32")
+        y = fluid.layers.hard_swish(x)
+    """
+    helper = LayerHelper('hard_swish', **locals())
+    out = helper.create_variable_for_type_inference(dtype=x.dtype)
+    helper.append_op(
+        type='hard_swish',
+        inputs={'X': x},
+        outputs={'Out': out},
+        attrs={'threshold': threshold,
+               'scale': scale,
+               'offset': offset})
+    return out
diff --git a/python/paddle/fluid/tests/unittests/test_activation_op.py b/python/paddle/fluid/tests/unittests/test_activation_op.py
index 0a4f2bf179..ff210d1f20 100644
--- a/python/paddle/fluid/tests/unittests/test_activation_op.py
+++ b/python/paddle/fluid/tests/unittests/test_activation_op.py
@@ -450,6 +450,30 @@ class TestRelu6(TestActivation):
         self.check_grad(['X'], 'Out', max_relative_error=0.02)
 
 
+class TestHardSwish(TestActivation):
+    def setUp(self):
+        self.op_type = 'hard_swish'
+        self.init_dtype()
+
+        x = np.random.uniform(-6, 6, [4, 4]).astype(self.dtype)
+        threshold = 6.0
+        scale = 6.0
+        offset = 3.0
+        # the same as TestAbs
+        x[np.abs(x + offset) < 0.005] = 0.02
+        x[np.abs(x - threshold + offset) < 0.005] = threshold - offset + 0.02
+        out = x * np.minimum(np.maximum(x + offset, 0), threshold) / scale
+
+        self.inputs = {'X': OpTest.np_dtype_to_fluid_dtype(x)}
+        self.attrs = {'threshold': threshold, 'scale': scale, 'offset': offset}
+        self.outputs = {'Out': out}
+
+    def test_check_grad(self):
+        if self.dtype == np.float16:
+            return
+        self.check_grad(['X'], 'Out', max_relative_error=0.02)
+
+
 class TestSoftRelu(TestActivation):
     def setUp(self):
         self.op_type = "soft_relu"
@@ -773,6 +797,7 @@ create_test_act_fp16_class(TestSoftsign)
 create_test_act_fp16_class(TestThresholdedRelu)
 create_test_act_fp16_class(TestHardSigmoid)
 create_test_act_fp16_class(TestSwish)
+create_test_act_fp16_class(TestHardSwish)
 
 if __name__ == "__main__":
     unittest.main()
diff --git a/python/paddle/fluid/tests/unittests/test_layers.py b/python/paddle/fluid/tests/unittests/test_layers.py
index ce1305bfc2..06fe57bbb4 100644
--- a/python/paddle/fluid/tests/unittests/test_layers.py
+++ b/python/paddle/fluid/tests/unittests/test_layers.py
@@ -903,6 +903,20 @@ class TestLayer(LayerTest):
         with self.assertRaises(TypeError):
             layers.eye(num_rows=3, batch_shape=[-1])
 
+    def test_hard_swish(self):
+        with self.static_graph():
+            t = layers.data(name='t', shape=[3, 3], dtype='float32')
+            ret = layers.hard_swish(t)
+            static_ret = self.get_static_graph_result(
+                feed={'t': np.ones(
+                    [3, 3], dtype='float32')}, fetch_list=[ret])[0]
+
+        with self.dynamic_graph():
+            t = np.ones([3, 3], dtype='float32')
+            dy_ret = layers.hard_swish(base.to_variable(t))
+
+        self.assertTrue(np.allclose(static_ret, dy_ret.numpy()))
+
 
 class TestBook(LayerTest):
     def test_all_layers(self):
-- 
GitLab
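
As a quick reference for reviewers, below is a minimal NumPy sketch of the computation this patch adds: the forward pass follows the documented formula, and the backward pass mirrors HardSwishGradFunctor. The helper names hard_swish_np and hard_swish_grad_np are illustrative only and are not part of the patch.

import numpy as np


def hard_swish_np(x, threshold=6.0, scale=6.0, offset=3.0):
    # Forward pass: out = x * min(max(0, x + offset), threshold) / scale,
    # the same reference formula used in TestHardSwish above.
    return x * np.minimum(np.maximum(x + offset, 0.0), threshold) / scale


def hard_swish_grad_np(x, dout, threshold=6.0, scale=6.0, offset=3.0):
    # Backward pass mirroring HardSwishGradFunctor above:
    # tmp is 1 where x + offset < threshold (not yet saturated), else 0.
    tmp = ((x + offset) < threshold).astype(x.dtype)
    active = ((x + offset) > 0).astype(x.dtype)
    return dout * (active * (2.0 * x + offset) / scale * tmp + (1.0 - tmp))


if __name__ == "__main__":
    x = np.random.uniform(-6, 6, [4, 4]).astype("float32")
    out = hard_swish_np(x)
    dx = hard_swish_grad_np(x, np.ones_like(out))
    print(out)
    print(dx)

With the default attributes (threshold=6.0, scale=6.0, offset=3.0) the forward pass reduces to x * min(max(0, x + 3), 6) / 6, the form quoted in the header comment of activation_op.h.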