Unverified commit 0025e0d8, authored by zhupengyang, committed by GitHub

refine APIs: brelu, hardsigmoid, hardswish, maxout (#27658)

Parent 5098891f
......@@ -83,6 +83,18 @@ class MaxOutOp : public framework::OperatorWithKernel {
"Attr(groups) of Op(maxout) should be "
"larger than 1. But received %d.",
groups));
PADDLE_ENFORCE_EQ(
axis == 1 || axis == -1 || axis == 3, true,
platform::errors::InvalidArgument(
"axis only supported 1, -1 or 3, but recevied axis is: %d", axis));
PADDLE_ENFORCE_EQ(in_x_dims.size(), 4,
platform::errors::InvalidArgument(
"x's dims should be 4, but received x's dims is: %d",
in_x_dims.size()));
if (axis < 0) {
axis += in_x_dims.size();
}
PADDLE_ENFORCE_EQ(
in_x_dims[axis] % groups, 0,
platform::errors::InvalidArgument(
......
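For readers of the hunk above, here is a minimal NumPy sketch of the validation rules that MaxOutOp::InferShape enforces: axis must be 1, -1 or 3, x must be 4-D, a negative axis is shifted by the rank, and the channel count at axis must be divisible by groups. The helper name check_maxout_args is made up for illustration and is not a Paddle API.

# Illustrative sketch only -- mirrors the checks in MaxOutOp::InferShape above.
import numpy as np

def check_maxout_args(x, groups, axis):  # hypothetical helper, not part of the PR
    if axis not in (1, -1, 3):
        raise ValueError("axis only supports 1, -1 or 3, but received axis is: %d" % axis)
    if x.ndim != 4:
        raise ValueError("x's dims should be 4, but received x's dims: %d" % x.ndim)
    if axis < 0:
        axis += x.ndim  # -1 becomes 3 for a 4-D input
    if x.shape[axis] % groups != 0:
        raise ValueError("the channel number at axis must be divisible by groups")
    return axis

print(check_maxout_args(np.zeros((2, 4, 4, 6), dtype='float32'), groups=2, axis=-1))  # 3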
......@@ -31,6 +31,9 @@ class MaxOutKernel : public framework::OpKernel<T> {
Tensor* out = context.Output<Tensor>("Out");
int groups = context.template Attr<int>("groups");
int axis = context.template Attr<int>("axis");
if (axis < 0) {
axis += in_x->dims().size();
}
math::MaxOutFunctor<DeviceContext, T> maxout_forward;
maxout_forward(context.template device_context<DeviceContext>(), *in_x, out,
......@@ -49,6 +52,10 @@ class MaxOutGradKernel : public framework::OpKernel<T> {
Tensor* in_x_grad = context.Output<Tensor>(framework::GradVarName("X"));
int groups = context.template Attr<int>("groups");
int axis = context.template Attr<int>("axis");
if (axis < 0) {
axis += in_x->dims().size();
}
auto& device_ctx = context.template device_context<DeviceContext>();
math::SetConstant<DeviceContext, T> zero;
if (in_x_grad) {
......
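As a reference for the kernel changes above, a NumPy sketch of what the forward pass computes once a negative axis has been normalized. The name maxout_ref is illustrative; this is not the Paddle kernel.

# Illustrative sketch only -- not the Paddle kernel.
import numpy as np

def maxout_ref(x, groups, axis=1):
    if axis < 0:
        axis += x.ndim  # same normalization as the kernel above
    c = x.shape[axis]
    new_shape = x.shape[:axis] + (c // groups, groups) + x.shape[axis + 1:]
    return x.reshape(new_shape).max(axis=axis + 1)  # max within each group

x = np.random.rand(2, 6, 4, 4).astype('float32')         # NCHW
print(maxout_ref(x, groups=2, axis=1).shape)              # (2, 3, 4, 4)
print(maxout_ref(x.transpose(0, 2, 3, 1), 2, -1).shape)   # NHWC: (2, 4, 4, 3)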
......@@ -9592,10 +9592,6 @@ def stanh(x, scale_a=0.67, scale_b=1.7159, name=None):
@templatedoc()
def hard_sigmoid(x, slope=0.2, offset=0.5, name=None):
"""
:alias_main: paddle.nn.functional.hard_sigmoid
:alias: paddle.nn.functional.hard_sigmoid,paddle.nn.functional.activation.hard_sigmoid
:old_api: paddle.fluid.layers.hard_sigmoid
${comment}
Parameters:
x (${x_type}): ${x_comment}
......@@ -9613,9 +9609,15 @@ def hard_sigmoid(x, slope=0.2, offset=0.5, name=None):
.. code-block:: python
import paddle.fluid as fluid
import paddle
paddle.enable_static()
data = fluid.layers.fill_constant(shape=[3, 2], value=0.5, dtype='float32') # [[0.5, 0.5], [0.5, 0.5], [0.5, 0.5]]
result = fluid.layers.hard_sigmoid(data) # [[0.6, 0.6], [0.6, 0.6], [0.6, 0.6]]
"""
if in_dygraph_mode():
return core.ops.hard_sigmoid(x, 'slope', slope, 'offset', offset)
check_variable_and_dtype(x, 'x', ['float16', 'float32', 'float64'],
'hard_sigmoid')
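For context, a NumPy sketch of the formula hard_sigmoid applies, assuming hard_sigmoid(x) = clip(slope * x + offset, 0, 1) with the defaults slope=0.2 and offset=0.5; the name hard_sigmoid_ref is illustrative.

# Illustrative sketch only.
import numpy as np

def hard_sigmoid_ref(x, slope=0.2, offset=0.5):
    return np.clip(slope * x + offset, 0.0, 1.0)

print(hard_sigmoid_ref(np.full((3, 2), 0.5, dtype='float32')))
# [[0.6 0.6]
#  [0.6 0.6]
#  [0.6 0.6]]  -- matches the fluid.layers.hard_sigmoid example above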
......@@ -9802,10 +9804,6 @@ def prelu(x, mode, param_attr=None, name=None):
@templatedoc()
def brelu(x, t_min=0.0, t_max=24.0, name=None):
"""
:alias_main: paddle.nn.functional.brelu
:alias: paddle.nn.functional.brelu,paddle.nn.functional.activation.brelu
:old_api: paddle.fluid.layers.brelu
${comment}
Args:
x(${x_type}): ${x_comment}
......@@ -9821,7 +9819,9 @@ def brelu(x, t_min=0.0, t_max=24.0, name=None):
.. code-block:: python
import paddle.fluid as fluid
import paddle
import numpy as np
paddle.enable_static()
input_brelu = np.array([[-1,6],[1,15.6]])
with fluid.dygraph.guard():
......@@ -9831,6 +9831,9 @@ def brelu(x, t_min=0.0, t_max=24.0, name=None):
#[[ 1. 6.]
#[ 1. 10.]]
"""
if in_dygraph_mode():
return core.ops.brelu(x, 't_min', t_min, 't_max', t_max)
check_variable_and_dtype(x, 'x', ['float16', 'float32', 'float64'], 'brelu')
helper = LayerHelper('brelu', **locals())
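Likewise, a NumPy sketch of the bounded-ReLU math, assuming brelu(x) = min(max(x, t_min), t_max); brelu_ref is an illustrative name, not a Paddle API.

# Illustrative sketch only.
import numpy as np

def brelu_ref(x, t_min=0.0, t_max=24.0):
    return np.clip(x, t_min, t_max)

print(brelu_ref(np.array([[-1., 6.], [1., 15.6]]), t_min=1.0, t_max=10.0))
# [[ 1.  6.]
#  [ 1. 10.]]  -- matches the docstring example above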
......@@ -12564,13 +12567,10 @@ def mul(x, y, x_num_col_dims=1, y_num_col_dims=1, name=None):
return out
@deprecated(since="2.0.0", update_to="paddle.nn.functional.maxout")
@templatedoc()
def maxout(x, groups, name=None, axis=1):
"""
:alias_main: paddle.nn.functional.maxout
:alias: paddle.nn.functional.maxout,paddle.nn.functional.activation.maxout
:old_api: paddle.fluid.layers.maxout
${comment}
Args:
......@@ -12592,31 +12592,16 @@ def maxout(x, groups, name=None, axis=1):
.. code-block:: python
import paddle.fluid as fluid
import paddle
paddle.enable_static()
input = fluid.data(
name='data',
shape=[None, 256, 32, 32],
dtype='float32')
out = fluid.layers.maxout(input, groups=2)
"""
check_variable_and_dtype(x, 'x', ['float32', 'float64'], 'maxout')
helper = LayerHelper("maxout", **locals())
if axis not in [1, -1, 3]:
raise ValueError(
"Attr(axis) should be 1 when data format is NCHW, -1 or 3 when data format is NHWC. Received "
"Attr(axis): %s." % str(axis))
if axis == -1:
axis = 3
out = helper.create_variable_for_type_inference(dtype=x.dtype)
helper.append_op(
type="maxout",
inputs={"X": x},
attrs={"groups": groups,
"axis": axis},
outputs={"Out": out})
return out
return paddle.nn.functional.maxout(**locals())
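The old implementation is removed and fluid.layers.maxout now simply forwards its arguments to the new API. A static-graph sketch of the equivalence, assuming Paddle 2.0 with this PR merged:

# Illustrative sketch only.
import paddle
import paddle.nn.functional as F

paddle.enable_static()
x = paddle.static.data(name='x', shape=[None, 256, 32, 32], dtype='float32')
out_old = paddle.fluid.layers.maxout(x, groups=2)  # deprecated alias, forwards to F.maxout
out_new = F.maxout(x, groups=2, axis=1)            # preferred API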
def space_to_depth(x, blocksize, name=None):
......@@ -14877,10 +14862,6 @@ def shard_index(input, index_num, nshards, shard_id, ignore_value=-1):
@templatedoc()
def hard_swish(x, threshold=6.0, scale=6.0, offset=3.0, name=None):
"""
:alias_main: paddle.nn.functional.hard_swish
:alias: paddle.nn.functional.hard_swish,paddle.nn.functional.activation.hard_swish
:old_api: paddle.fluid.layers.hard_swish
This operator implements the hard_swish activation function.
Hard_swish is proposed in MobileNetV3, and performs better in computational stability and efficiency compared to the swish function.
For more details please refer to: https://arxiv.org/pdf/1905.02244.pdf
......@@ -14911,7 +14892,9 @@ def hard_swish(x, threshold=6.0, scale=6.0, offset=3.0, name=None):
.. code-block:: python
import paddle.fluid as fluid
import paddle
import numpy as np
paddle.enable_static()
DATATYPE='float32'
......@@ -14926,6 +14909,10 @@ def hard_swish(x, threshold=6.0, scale=6.0, offset=3.0, name=None):
out, = exe.run(feed={'x':x_data}, fetch_list=[y.name])
print(out) # [[0.66666667, 1.66666667,3., 4.]]
"""
if in_dygraph_mode():
return core.ops.hard_swish(x, 'threshold', threshold, 'scale', scale,
'offset', offset)
check_variable_and_dtype(x, 'x', ['float16', 'float32', 'float64'],
'hard_swish')
......
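For reference, a NumPy sketch of the hard_swish formula exercised above, assuming hard_swish(x) = x * clip(x + offset, 0, threshold) / scale; hard_swish_ref is an illustrative name.

# Illustrative sketch only.
import numpy as np

def hard_swish_ref(x, threshold=6.0, scale=6.0, offset=3.0):
    return x * np.clip(x + offset, 0.0, threshold) / scale

print(hard_swish_ref(np.array([1., 2., 3., 4.], dtype='float32')))
# ~[0.667, 1.667, 3., 4.]  -- matches the docstring output above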
......@@ -1657,21 +1657,6 @@ class TestLayer(LayerTest):
with self.assertRaises(TypeError):
layers.eye(num_rows=3, batch_shape=[-1])
def test_hard_swish(self):
with self.static_graph():
t = layers.data(name='t', shape=[3, 3], dtype='float32')
ret = layers.hard_swish(t)
static_ret = self.get_static_graph_result(
feed={'t': np.ones(
[3, 3], dtype='float32')}, fetch_list=[ret])[0]
with self.dynamic_graph():
t = np.ones([3, 3], dtype='float32')
dy_ret = layers.hard_swish(base.to_variable(t))
dy_ret_rlt = dy_ret.numpy()
self.assertTrue(np.allclose(static_ret, dy_ret_rlt))
def test_while_loop(self):
with self.static_graph():
i = layers.fill_constant(shape=[1], dtype='int64', value=0)
......@@ -2563,13 +2548,6 @@ class TestBook(LayerTest):
output = layers.l2_normalize(x, axis=1)
return output
def make_maxout(self):
with program_guard(fluid.default_main_program(),
fluid.default_startup_program()):
data = self._get_data(name='x', shape=[8, 6, 6], dtype="float32")
output = layers.maxout(x=data, groups=2)
return (output)
def make_crop(self):
with program_guard(fluid.default_main_program(),
fluid.default_startup_program()):
......@@ -2656,13 +2634,6 @@ class TestBook(LayerTest):
name='prelu')
return (out)
def make_brelu(self):
with program_guard(fluid.default_main_program(),
fluid.default_startup_program()):
input = self._get_data(name="input", shape=[16], dtype="float32")
out = layers.brelu(input, t_min=1.0, t_max=20.0, name='brelu')
return (out)
def make_soft_relu(self):
with program_guard(fluid.default_main_program(),
fluid.default_startup_program()):
......
......@@ -16,32 +16,43 @@ from __future__ import print_function
import unittest
import numpy as np
import paddle
import paddle.fluid as fluid
from paddle.fluid import Program, program_guard
import paddle.fluid.core as core
import paddle.nn.functional as F
from op_test import OpTest
paddle.enable_static()
np.random.seed(1)
-def maxout_forward_naive(input, groups, channel_axis):
-    s0, s1, s2, s3 = input.shape
-    if channel_axis == 3:
-        return np.ndarray([s0, s1, s2, s3 // groups, groups], \
-            buffer = input, dtype=input.dtype).max(axis=(4))
-    return np.ndarray([s0, s1 // groups, groups, s2, s3], \
-        buffer = input, dtype=input.dtype).max(axis=(2))
+def maxout_forward_naive(x, groups, channel_axis):
+    s0, s1, s2, s3 = x.shape
+    if channel_axis == 1:
+        return np.ndarray([s0, s1 // groups, groups, s2, s3], \
+            buffer = x, dtype=x.dtype).max(axis=2)
+    return np.ndarray([s0, s1, s2, s3 // groups, groups], \
+        buffer = x, dtype=x.dtype).max(axis=4)
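Editor's note: the np.ndarray(..., buffer=x) construction in the new reference above is just a zero-copy reshape. A quick sanity check, illustrative only and not part of the PR:

# Illustrative sketch only -- uses the new maxout_forward_naive defined above.
import numpy as np

x = np.random.rand(2, 6, 3, 3)
ref = maxout_forward_naive(x, groups=2, channel_axis=1)
alt = x.reshape(2, 3, 2, 3, 3).max(axis=2)
print(np.allclose(ref, alt))  # True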
class TestMaxOutOp(OpTest):
def setUp(self):
self.op_type = "maxout"
-        self.init_test_case()
-        input = np.random.random(self.shape)
-        output = self.MaxOut_forward_naive(input, self.groups, self.axis)
-        self.inputs = {'X': input}
+        self.dtype = 'float64'
+        self.shape = [3, 6, 2, 4]
+        self.groups = 2
+        self.axis = 1
+        self.set_attrs()
+        x = np.random.uniform(-1, 1, self.shape).astype(self.dtype)
+        out = maxout_forward_naive(x, self.groups, self.axis)
+        self.inputs = {'X': x}
         self.attrs = {'groups': self.groups, 'axis': self.axis}
-        self.outputs = {'Out': output}
+        self.outputs = {'Out': out}
+
+    def set_attrs(self):
+        pass
def test_check_output(self):
self.check_output()
......@@ -49,65 +60,89 @@ class TestMaxOutOp(OpTest):
def test_check_grad(self):
self.check_grad(['X'], 'Out')
-    def init_test_case(self):
-        self.MaxOut_forward_naive = maxout_forward_naive
-        self.shape = [100, 6, 2, 2]
-        self.groups = 2
-        self.axis = 1
-
-class TestMaxOutOpAxis(TestMaxOutOp):
-    def init_test_case(self):
-        self.MaxOut_forward_naive = maxout_forward_naive
-        self.shape = [100, 2, 2, 6]  # NHWC format
-        self.groups = 2
-        self.axis = 3
-
-class TestMaxOutOpAxisAPI(unittest.TestCase):
-    def test_axis(self):
-        data1 = fluid.data(name='data1', shape=[3, 6, 2, 2], dtype='float32')
-        data2 = fluid.data(name='data2', shape=[3, 2, 2, 6], dtype='float32')
-        out1 = fluid.layers.maxout(data1, groups=2, axis=1)
-        out2 = fluid.layers.maxout(data2, groups=2, axis=-1)
-        data1_np = np.random.random((3, 6, 2, 2)).astype("float32")
-        data2_np = np.transpose(data1_np, [0, 2, 3, 1])
-        if core.is_compiled_with_cuda():
-            place = core.CUDAPlace(0)
-        else:
-            place = core.CPUPlace()
-        exe = fluid.Executor(place)
-        exe.run(fluid.default_startup_program())
-        results = exe.run(fluid.default_main_program(),
-                          feed={"data1": data1_np,
-                                "data2": data2_np},
-                          fetch_list=[out1, out2],
-                          return_numpy=True)
-        self.assertTrue(
-            np.allclose(results[0], np.transpose(results[1], (0, 3, 1, 2))))
-
-    def test_exception(self):
-        input = fluid.data(name="input", shape=[2, 4, 6, 6], dtype="float32")
-
-        def _attr_axis():
-            out = fluid.layers.maxout(input, groups=2, axis=2)
-
-        self.assertRaises(ValueError, _attr_axis)
+class TestMaxOutOpAxis0(TestMaxOutOp):
+    def set_attrs(self):
+        self.axis = -1
+
+class TestMaxOutOpAxis1(TestMaxOutOp):
+    def set_attrs(self):
+        self.axis = 3
+
+class TestMaxOutOpFP32(TestMaxOutOp):
+    def set_attrs(self):
+        self.dtype = 'float32'
+
+class TestMaxOutOpGroups(TestMaxOutOp):
+    def set_attrs(self):
+        self.groups = 3
class TestMaxoutAPI(unittest.TestCase):
# test paddle.nn.Maxout, paddle.nn.functional.maxout
def setUp(self):
self.x_np = np.random.uniform(-1, 1, [2, 6, 5, 4]).astype(np.float64)
self.groups = 2
self.axis = 1
self.place=paddle.CUDAPlace(0) if core.is_compiled_with_cuda() \
else paddle.CPUPlace()
def test_static_api(self):
with paddle.static.program_guard(paddle.static.Program()):
x = paddle.data('X', self.x_np.shape, self.x_np.dtype)
out1 = F.maxout(x, self.groups, self.axis)
m = paddle.nn.Maxout(self.groups, self.axis)
out2 = m(x)
exe = paddle.static.Executor(self.place)
res = exe.run(feed={'X': self.x_np}, fetch_list=[out1, out2])
out_ref = maxout_forward_naive(self.x_np, self.groups, self.axis)
for r in res:
self.assertTrue(np.allclose(out_ref, r))
def test_dygraph_api(self):
paddle.disable_static(self.place)
x = paddle.to_tensor(self.x_np)
out1 = F.maxout(x, self.groups, self.axis)
m = paddle.nn.Maxout(self.groups, self.axis)
out2 = m(x)
out_ref = maxout_forward_naive(self.x_np, self.groups, self.axis)
for r in [out1, out2]:
self.assertTrue(np.allclose(out_ref, r.numpy()))
out3 = F.maxout(x, self.groups, -1)
out3_ref = maxout_forward_naive(self.x_np, self.groups, -1)
self.assertTrue(np.allclose(out3_ref, out3.numpy()))
paddle.enable_static()
def test_fluid_api(self):
with fluid.program_guard(fluid.Program()):
x = fluid.data('X', self.x_np.shape, self.x_np.dtype)
out = fluid.layers.maxout(x, groups=self.groups, axis=self.axis)
exe = fluid.Executor(self.place)
res = exe.run(feed={'X': self.x_np}, fetch_list=[out])
out_ref = maxout_forward_naive(self.x_np, self.groups, self.axis)
self.assertTrue(np.allclose(out_ref, res[0]))
paddle.disable_static(self.place)
x = paddle.to_tensor(self.x_np)
out = paddle.fluid.layers.maxout(x, groups=self.groups, axis=self.axis)
self.assertTrue(np.allclose(out_ref, out.numpy()))
paddle.enable_static()
class TestMaxOutOpError(unittest.TestCase):
def test_errors(self):
with program_guard(Program()):
with paddle.static.program_guard(paddle.static.Program()):
# The input type must be Variable.
self.assertRaises(TypeError, fluid.layers.maxout, 1, 2)
self.assertRaises(TypeError, F.maxout, 1)
# The input dtype must be float16, float32, float64.
x_int32 = fluid.data(name='x_int32', shape=[12, 10], dtype='int32')
self.assertRaises(TypeError, fluid.layers.maxout, x_int32, 2)
# support the input dtype is float32
x_fp32 = fluid.data(name='x_fp32', shape=[12, 10], dtype='float32')
fluid.layers.maxout(x_fp32, 2)
x_int32 = paddle.data(
name='x_int32', shape=[2, 4, 6, 8], dtype='int32')
self.assertRaises(TypeError, F.maxout, x_int32)
x_float32 = paddle.data(name='x_float32', shape=[2, 4, 6, 8])
self.assertRaises(ValueError, F.maxout, x_float32, 2, 2)
if __name__ == '__main__':
......
......@@ -55,6 +55,7 @@ from .layer.activation import ELU #DEFINE_ALIAS
from .layer.activation import GELU #DEFINE_ALIAS
from .layer.activation import Tanh #DEFINE_ALIAS
from .layer.activation import Hardshrink #DEFINE_ALIAS
from .layer.activation import Hardswish #DEFINE_ALIAS
from .layer.activation import Hardtanh #DEFINE_ALIAS
from .layer.activation import PReLU #DEFINE_ALIAS
from .layer.activation import ReLU #DEFINE_ALIAS
......@@ -62,6 +63,7 @@ from .layer.activation import ReLU6 #DEFINE_ALIAS
from .layer.activation import SELU #DEFINE_ALIAS
from .layer.activation import LeakyReLU #DEFINE_ALIAS
from .layer.activation import Sigmoid #DEFINE_ALIAS
from .layer.activation import Hardsigmoid #DEFINE_ALIAS
from .layer.activation import LogSigmoid
from .layer.activation import Softmax #DEFINE_ALIAS
from .layer.activation import Softplus #DEFINE_ALIAS
......@@ -70,6 +72,7 @@ from .layer.activation import Softsign #DEFINE_ALIAS
from .layer.activation import Tanhshrink #DEFINE_ALIAS
from .layer.activation import LogSoftmax #DEFINE_ALIAS
from .layer.activation import HSigmoid #DEFINE_ALIAS
from .layer.activation import Maxout #DEFINE_ALIAS
from .layer.common import BilinearTensorProduct #DEFINE_ALIAS
from .layer.common import Pool2D #DEFINE_ALIAS
from .layer.common import Pad2D #DEFINE_ALIAS
......
......@@ -29,14 +29,13 @@ from . import pooling
__all__ += pooling.__all__
from . import loss
__all__ += loss.__all__
from .activation import brelu #DEFINE_ALIAS
from .activation import elu #DEFINE_ALIAS
from .activation import erf #DEFINE_ALIAS
from .activation import gelu #DEFINE_ALIAS
from .activation import hardshrink #DEFINE_ALIAS
from .activation import hardtanh #DEFINE_ALIAS
from .activation import hard_sigmoid #DEFINE_ALIAS
from .activation import hard_swish #DEFINE_ALIAS
from .activation import hardsigmoid #DEFINE_ALIAS
from .activation import hardswish #DEFINE_ALIAS
from .activation import hsigmoid #DEFINE_ALIAS
from .activation import leaky_relu #DEFINE_ALIAS
from .activation import log_sigmoid #DEFINE_ALIAS
......
......@@ -13,11 +13,7 @@
# limitations under the License.
# TODO: define activation functions of neural network
from ...fluid.layers import brelu #DEFINE_ALIAS
from ...fluid.layers import erf #DEFINE_ALIAS
from ...fluid.layers import hard_sigmoid #DEFINE_ALIAS
from ...fluid.layers import hard_swish #DEFINE_ALIAS
from ...fluid.layers import maxout #DEFINE_ALIAS
from ...fluid.layers import soft_relu #DEFINE_ALIAS
from ...fluid.layers import swish #DEFINE_ALIAS
from ...fluid.layers import sigmoid #DEFINE_ALIAS
......@@ -25,14 +21,13 @@ from ...fluid.layers import thresholded_relu #DEFINE_ALIAS
from ...tensor.math import tanh #DEFINE_ALIAS
__all__ = [
'brelu',
'elu',
'erf',
'gelu',
'hardshrink',
'hardtanh',
'hard_sigmoid',
'hard_swish',
'hardsigmoid',
'hardswish',
'hsigmoid',
'leaky_relu',
'log_sigmoid',
......@@ -265,6 +260,109 @@ def hardtanh(x, min=-1.0, max=1.0, name=None):
return out
def hardsigmoid(x, name=None):
"""
hardsigmoid activation.
A 3-part piecewise linear approximation of sigmoid (https://arxiv.org/abs/1603.00391),
which is much faster than sigmoid.
.. math::
hardsigmoid(x)=
\\left\\{
\\begin{aligned}
&0, & & \\text{if } x \\leq -3 \\\\
&1, & & \\text{if } x \\geq 3 \\\\
&x/6 + 1/2, & & \\text{otherwise}
\\end{aligned}
\\right.
Parameters:
x (Tensor): The input Tensor with data type float32, float64.
name (str, optional): Name for the operation (optional, default is None).
For more information, please refer to :ref:`api_guide_Name`.
Returns:
A Tensor with the same data type and shape as ``x`` .
Examples:
.. code-block:: python
import paddle
import paddle.nn.functional as F
x = paddle.to_tensor([-4., 5., 1.])
out = F.hardsigmoid(x) # [0., 1., 0.666667]
"""
if in_dygraph_mode():
return core.ops.hard_sigmoid(x, 'slope', 0.1666666666666667, 'offset',
0.5)
check_variable_and_dtype(x, 'x', ['float16', 'float32', 'float64'],
'hardsigmoid')
helper = LayerHelper('hardsigmoid', **locals())
out = helper.create_variable_for_type_inference(x.dtype)
helper.append_op(
type='hard_sigmoid',
inputs={'X': x},
outputs={'Out': out},
attrs={'slope': 0.1666666666666667,
'offset': 0.5})
return out
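Note why slope is hard-coded to 0.1666666666666667 here: with offset 0.5 this is exactly the x/6 + 1/2 branch of the piecewise definition above. A NumPy check, illustrative only:

# Illustrative sketch only.
import numpy as np

def hardsigmoid_ref(x):
    return np.clip(x / 6.0 + 0.5, 0.0, 1.0)

print(hardsigmoid_ref(np.array([-4., 5., 1.])))
# [0. 1. 0.66666667]  -- matches F.hardsigmoid in the docstring above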
def hardswish(x, name=None):
"""
hardswish activation
hardswish is proposed in MobileNetV3, and performs better in computational stability
and efficiency compared to the swish function. For more details please refer
to: https://arxiv.org/pdf/1905.02244.pdf
.. math::
hardswish(x)=
\\left\\{
\\begin{aligned}
&0, & & \\text{if } x \\leq -3 \\\\
&x, & & \\text{if } x \\geq 3 \\\\
&\\frac{x(x+3)}{6}, & & \\text{otherwise}
\\end{aligned}
\\right.
Parameters:
x (Tensor): The input Tensor with data type float32, float64.
name (str, optional): Name for the operation (optional, default is None).
For more information, please refer to :ref:`api_guide_Name`.
Returns:
A Tensor with the same data type and shape as ``x`` .
Examples:
.. code-block:: python
import paddle
import paddle.nn.functional as F
x = paddle.to_tensor([-4., 5., 1.])
out = F.hardswish(x) # [0., 5., 0.666667]
"""
if in_dygraph_mode():
return core.ops.hard_swish(x)
check_variable_and_dtype(x, 'x', ['float16', 'float32', 'float64'],
'hardswish')
helper = LayerHelper('hardswish', **locals())
out = helper.create_variable_for_type_inference(x.dtype)
helper.append_op(type='hard_swish', inputs={'X': x}, outputs={'Out': out})
return out
def hsigmoid(input,
label,
weight,
......@@ -591,6 +689,81 @@ def log_sigmoid(x, name=None):
return out
def maxout(x, groups, axis=1, name=None):
"""
maxout activation.
Assume the input shape is (N, Ci, H, W) and the output shape is (N, Co, H, W).
Then Co = Ci / groups and the operator formula is as follows:
.. math::
&out_{si+j} = \\max_{k} x_{gsi + sk + j} \\\\
&g = groups \\\\
&s = \\frac{input.size}{num\\_channels} \\\\
&0 \\le i < \\frac{num\\_channels}{groups} \\\\
&0 \\le j < s \\\\
&0 \\le k < groups
Parameters:
x (Tensor): The input is 4-D Tensor with shape [N, C, H, W] or [N, H, W, C], the data type
of input is float32 or float64.
groups (int): The group number of maxout. The number of channels pointed to by
``axis`` must be divisible by ``groups``; the output channel number is the
input channel number divided by ``groups``.
axis (int, optional): The axis along which to perform maxout calculations.
It should be 1 when data format is NCHW, be -1 or 3 when data format
is NHWC. If ``axis`` < 0, it works the same way as :math:`axis + D` ,
where D is the dimensions of ``x`` . ``axis`` only supports 1, 3 or -1.
Default is 1.
name (str, optional): Name for the operation (optional, default is None).
For more information, please refer to :ref:`api_guide_Name`.
Returns:
A Tensor with the same data type as ``x`` .
Examples:
.. code-block:: python
import paddle
import paddle.nn.functional as F
x = paddle.rand([1, 2, 3, 4])
# [[[[0.5002636 0.22272532 0.17402348 0.2874594 ]
# [0.95313174 0.6228939 0.7129065 0.7087491 ]
# [0.02879342 0.88725346 0.61093384 0.38833922]]
# [[0.5231306 0.03807496 0.91661984 0.15602879]
# [0.666127 0.616567 0.30741522 0.24044901]
# [0.7142536 0.7351477 0.31588817 0.23782359]]]]
out = F.maxout(x, groups=2)
# [[[[0.5231306 0.22272532 0.91661984 0.2874594 ]
# [0.95313174 0.6228939 0.7129065 0.7087491 ]
# [0.7142536 0.88725346 0.61093384 0.38833922]]]]
"""
if in_dygraph_mode():
return core.ops.maxout(x, 'groups', groups, 'axis', axis)
check_variable_and_dtype(x, 'x', ['float32', 'float64'], 'maxout')
if axis not in [1, -1, 3]:
raise ValueError(
"Attr(axis) should be 1 when data format is NCHW, -1 or 3 when data format is NHWC. Received "
"Attr(axis): %s." % str(axis))
if axis == -1:
axis = 3
helper = LayerHelper('maxout', **locals())
out = helper.create_variable_for_type_inference(x.dtype)
helper.append_op(
type='maxout',
inputs={'X': x},
outputs={'Out': out},
attrs={'groups': groups,
'axis': axis})
return out
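A short dygraph sketch, assuming Paddle 2.0 with this PR, showing that axis=-1 on an NHWC tensor matches axis=1 on the equivalent NCHW tensor, which is the behaviour the axis handling above provides:

# Illustrative sketch only.
import numpy as np
import paddle
import paddle.nn.functional as F

x_nchw = paddle.rand([2, 6, 4, 4])
x_nhwc = paddle.transpose(x_nchw, [0, 2, 3, 1])
out_nchw = F.maxout(x_nchw, groups=2, axis=1)
out_nhwc = F.maxout(x_nhwc, groups=2, axis=-1)
print(np.allclose(out_nchw.numpy(),
                  paddle.transpose(out_nhwc, [0, 3, 1, 2]).numpy()))  # True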
def relu6(x, name=None):
"""
relu6 activation
......
......@@ -18,6 +18,7 @@ __all__ = [
'ELU',
'GELU',
'Hardshrink',
'Hardswish',
'Tanh',
'Hardtanh',
'PReLU',
......@@ -26,6 +27,7 @@ __all__ = [
'SELU',
'LeakyReLU',
'Sigmoid',
'Hardsigmoid',
'Softmax',
'Softplus',
'Softshrink',
......@@ -33,6 +35,7 @@ __all__ = [
'Tanhshrink',
'LogSigmoid',
'LogSoftmax',
'Maxout',
'HSigmoid',
]
......@@ -184,6 +187,52 @@ class Hardshrink(layers.Layer):
return F.hardshrink(x, self._threshold, self._name)
class Hardswish(layers.Layer):
"""
Hardswish activation
Hardswish is proposed in MobileNetV3, and performs better in computational stability
and efficiency compared to the swish function. For more details please refer
to: https://arxiv.org/pdf/1905.02244.pdf
.. math::
Hardswish(x)=
\\left\\{
\\begin{aligned}
&0, & & \\text{if } x \\leq -3 \\\\
&x, & & \\text{if } x \\geq 3 \\\\
&\\frac{x(x+3)}{6}, & & \\text{otherwise}
\\end{aligned}
\\right.
Parameters:
name (str, optional): Name for the operation (optional, default is None).
For more information, please refer to :ref:`api_guide_Name`.
Shape:
- input: Tensor with any shape.
- output: Tensor with the same shape as input.
Examples:
.. code-block:: python
import paddle
x = paddle.to_tensor([-4., 5., 1.])
m = paddle.nn.Hardswish()
out = m(x) # [0., 5., 0.666667]
"""
def __init__(self, name=None):
super(Hardswish, self).__init__()
self._name = name
def forward(self, x):
return F.hardswish(x, self._name)
class Tanh(layers.Layer):
"""
Tanh Activation.
......@@ -680,6 +729,53 @@ class Sigmoid(layers.Layer):
return F.sigmoid(x, self.name)
class Hardsigmoid(layers.Layer):
"""
This interface is used to construct a callable object of the ``Hardsigmoid`` class.
This layer calculates the `hardsigmoid` of input x.
A 3-part piecewise linear approximation of sigmoid (https://arxiv.org/abs/1603.00391),
which is much faster than sigmoid.
.. math::
Hardsigmoid(x)=
\\left\\{
\\begin{aligned}
&0, & & \\text{if } x \\leq -3 \\\\
&1, & & \\text{if } x \\geq 3 \\\\
&x/6 + 1/2, & & \\text{otherwise}
\\end{aligned}
\\right.
Parameters:
name (str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`.
Shape:
x: N-D tensor, available dtype is float32, float64.
Returns:
A callable object of Hardsigmoid.
Examples:
.. code-block:: python
import paddle
m = paddle.nn.Hardsigmoid()
x = paddle.to_tensor([-4., 5., 1.])
out = m(x)  # [0., 1., 0.666667]
"""
def __init__(self, name=None):
super(Hardsigmoid, self).__init__()
self.name = name
def forward(self, x):
return F.hardsigmoid(x, self.name)
class Softplus(layers.Layer):
"""
Softplus Activation
......@@ -1060,3 +1156,64 @@ class LogSoftmax(layers.Layer):
def forward(self, x):
return F.log_softmax(x, self._axis)
class Maxout(layers.Layer):
"""
Maxout Activation.
Assume the input shape is (N, Ci, H, W) and the output shape is (N, Co, H, W).
Then Co = Ci / groups and the operator formula is as follows:
.. math::
&out_{si+j} = \\max_{k} x_{gsi + sk + j} \\\\
&g = groups \\\\
&s = \\frac{input.size}{num\\_channels} \\\\
&0 \\le i < \\frac{num\\_channels}{groups} \\\\
&0 \\le j < s \\\\
&0 \\le k < groups
Parameters:
groups (int): The group number of maxout. The number of channels pointed to by
``axis`` must be divisible by ``groups``; the output channel number is the
input channel number divided by ``groups``.
axis (int, optional): The axis along which to perform maxout calculations.
It should be 1 when data format is NCHW, be -1 or 3 when data format
is NHWC. If ``axis`` < 0, it works the same way as :math:`axis + D` ,
where D is the dimensions of ``x`` . Default is 1.
name (str, optional): Name for the operation (optional, default is None).
For more information, please refer to :ref:`api_guide_Name`.
Shape:
- input: :math:`(N, C_{in}, H_{in}, W_{in})`
- output: :math:`(N, C_{out}, H_{out}, W_{out})`
Examples:
.. code-block:: python
import paddle
x = paddle.rand([1, 2, 3, 4])
# [[[[0.5002636 0.22272532 0.17402348 0.2874594 ]
# [0.95313174 0.6228939 0.7129065 0.7087491 ]
# [0.02879342 0.88725346 0.61093384 0.38833922]]
# [[0.5231306 0.03807496 0.91661984 0.15602879]
# [0.666127 0.616567 0.30741522 0.24044901]
# [0.7142536 0.7351477 0.31588817 0.23782359]]]]
m = paddle.nn.Maxout(groups=2)
out = m(x)
# [[[[0.5231306 0.22272532 0.91661984 0.2874594 ]
# [0.95313174 0.6228939 0.7129065 0.7087491 ]
# [0.7142536 0.88725346 0.61093384 0.38833922]]]]
"""
def __init__(self, groups, axis=1, name=None):
super(Maxout, self).__init__()
self._groups = groups
self._axis = axis
self._name = name
def forward(self, x):
return F.maxout(x, self._groups, self._axis, self._name)
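To tie the refined APIs together, a quick dygraph smoke test, illustrative only and assuming Paddle 2.0 with this PR merged:

# Illustrative sketch only.
import paddle
import paddle.nn.functional as F

x = paddle.to_tensor([-4., 5., 1.])
print(F.hardsigmoid(x).numpy())          # [0.  1.  0.6666667]
print(paddle.nn.Hardswish()(x).numpy())  # [0.  5.  0.6666667]

x4d = paddle.rand([1, 4, 3, 3])
print(paddle.nn.Maxout(groups=2)(x4d).shape)  # [1, 2, 3, 3]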