Unverified commit 0025e0d8, authored by: Z zhupengyang, committed by: GitHub

refine APIs: brelu, hardsigmoid, hardswish, maxout (#27658)

Parent 5098891f
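In short: hardsigmoid, hardswish and maxout gain entry points under paddle.nn and paddle.nn.functional, the fluid.layers versions of brelu, hard_sigmoid and hard_swish gain dygraph fast paths, and fluid.layers.maxout is deprecated in favor of paddle.nn.functional.maxout. A minimal usage sketch of the refined APIs (assuming a Paddle build that contains this change; expected values follow the docstrings added in the diff below):

    import paddle
    import paddle.nn.functional as F

    paddle.disable_static()

    x = paddle.to_tensor([-4., 5., 1.])
    print(F.hardsigmoid(x).numpy())            # approx [0., 1., 0.666667]
    print(F.hardswish(x).numpy())              # approx [0., 5., 0.666667]
    print(paddle.nn.Hardsigmoid()(x).numpy())  # layer form of hardsigmoid
    print(paddle.nn.Hardswish()(x).numpy())    # layer form of hardswish

    # maxout expects a 4-D tensor; axis=1 for NCHW, -1 or 3 for NHWC.
    y = paddle.rand([1, 4, 3, 5])
    print(F.maxout(y, groups=2, axis=1).shape)          # [1, 2, 3, 5]
    print(paddle.nn.Maxout(groups=2, axis=1)(y).shape)  # same result as F.maxout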
...@@ -83,6 +83,18 @@ class MaxOutOp : public framework::OperatorWithKernel {
"Attr(groups) of Op(maxout) should be "
"larger than 1. But received %d.",
groups));
PADDLE_ENFORCE_EQ(
axis == 1 || axis == -1 || axis == 3, true,
platform::errors::InvalidArgument(
"axis only supported 1, -1 or 3, but recevied axis is: %d", axis));
PADDLE_ENFORCE_EQ(in_x_dims.size(), 4,
platform::errors::InvalidArgument(
"x's dims should be 4, but received x's dims is: %d",
in_x_dims.size()));
if (axis < 0) {
axis += in_x_dims.size();
}
PADDLE_ENFORCE_EQ(
in_x_dims[axis] % groups, 0,
platform::errors::InvalidArgument(
...
...@@ -31,6 +31,9 @@ class MaxOutKernel : public framework::OpKernel<T> {
Tensor* out = context.Output<Tensor>("Out");
int groups = context.template Attr<int>("groups");
int axis = context.template Attr<int>("axis");
if (axis < 0) {
axis += in_x->dims().size();
}
math::MaxOutFunctor<DeviceContext, T> maxout_forward;
maxout_forward(context.template device_context<DeviceContext>(), *in_x, out,
...@@ -49,6 +52,10 @@ class MaxOutGradKernel : public framework::OpKernel<T> {
Tensor* in_x_grad = context.Output<Tensor>(framework::GradVarName("X"));
int groups = context.template Attr<int>("groups");
int axis = context.template Attr<int>("axis");
if (axis < 0) {
axis += in_x->dims().size();
}
auto& device_ctx = context.template device_context<DeviceContext>();
math::SetConstant<DeviceContext, T> zero;
if (in_x_grad) {
...
...@@ -9592,10 +9592,6 @@ def stanh(x, scale_a=0.67, scale_b=1.7159, name=None):
@templatedoc()
def hard_sigmoid(x, slope=0.2, offset=0.5, name=None):
"""
:alias_main: paddle.nn.functional.hard_sigmoid
:alias: paddle.nn.functional.hard_sigmoid,paddle.nn.functional.activation.hard_sigmoid
:old_api: paddle.fluid.layers.hard_sigmoid
${comment}
Parameters:
x (${x_type}): ${x_comment}
...@@ -9613,9 +9609,15 @@ def hard_sigmoid(x, slope=0.2, offset=0.5, name=None):
.. code-block:: python
import paddle.fluid as fluid
import paddle
paddle.enable_static()
data = fluid.layers.fill_constant(shape=[3, 2], value=0.5, dtype='float32') # [[0.5, 0.5], [0.5, 0.5], [0.5, 0.5]]
result = fluid.layers.hard_sigmoid(data) # [[0.6, 0.6], [0.6, 0.6], [0.6, 0.6]]
"""
if in_dygraph_mode():
return core.ops.hard_sigmoid(x, 'slope', slope, 'offset', offset)
check_variable_and_dtype(x, 'x', ['float16', 'float32', 'float64'],
'hard_sigmoid')
...@@ -9802,10 +9804,6 @@ def prelu(x, mode, param_attr=None, name=None):
@templatedoc()
def brelu(x, t_min=0.0, t_max=24.0, name=None):
"""
:alias_main: paddle.nn.functional.brelu
:alias: paddle.nn.functional.brelu,paddle.nn.functional.activation.brelu
:old_api: paddle.fluid.layers.brelu
${comment}
Args:
x(${x_type}): ${x_comment}
...@@ -9821,7 +9819,9 @@ def brelu(x, t_min=0.0, t_max=24.0, name=None):
.. code-block:: python
import paddle.fluid as fluid
import paddle
import numpy as np
paddle.enable_static()
input_brelu = np.array([[-1,6],[1,15.6]])
with fluid.dygraph.guard():
...@@ -9831,6 +9831,9 @@ def brelu(x, t_min=0.0, t_max=24.0, name=None):
#[[ 1. 6.]
#[ 1. 10.]]
"""
if in_dygraph_mode():
return core.ops.brelu(x, 't_min', t_min, 't_max', t_max)
check_variable_and_dtype(x, 'x', ['float16', 'float32', 'float64'], 'brelu')
helper = LayerHelper('brelu', **locals())
...@@ -12564,13 +12567,10 @@ def mul(x, y, x_num_col_dims=1, y_num_col_dims=1, name=None):
return out
@deprecated(since="2.0.0", update_to="paddle.nn.functional.maxout")
@templatedoc()
def maxout(x, groups, name=None, axis=1):
"""
:alias_main: paddle.nn.functional.maxout
:alias: paddle.nn.functional.maxout,paddle.nn.functional.activation.maxout
:old_api: paddle.fluid.layers.maxout
${comment}
Args:
...@@ -12592,31 +12592,16 @@ def maxout(x, groups, name=None, axis=1):
.. code-block:: python
import paddle.fluid as fluid
import paddle
paddle.enable_static()
input = fluid.data(
name='data',
shape=[None, 256, 32, 32],
dtype='float32')
out = fluid.layers.maxout(input, groups=2)
"""
check_variable_and_dtype(x, 'x', ['float32', 'float64'], 'maxout')
helper = LayerHelper("maxout", **locals())
if axis not in [1, -1, 3]:
raise ValueError(
"Attr(axis) should be 1 when data format is NCHW, -1 or 3 when data format is NHWC. Received "
"Attr(axis): %s." % str(axis))
if axis == -1:
axis = 3
out = helper.create_variable_for_type_inference(dtype=x.dtype)
helper.append_op(
type="maxout",
inputs={"X": x},
attrs={"groups": groups,
"axis": axis},
outputs={"Out": out})
return out
return paddle.nn.functional.maxout(**locals())
def space_to_depth(x, blocksize, name=None):
...@@ -14877,10 +14862,6 @@ def shard_index(input, index_num, nshards, shard_id, ignore_value=-1):
@templatedoc()
def hard_swish(x, threshold=6.0, scale=6.0, offset=3.0, name=None):
"""
:alias_main: paddle.nn.functional.hard_swish
:alias: paddle.nn.functional.hard_swish,paddle.nn.functional.activation.hard_swish
:old_api: paddle.fluid.layers.hard_swish
This operator implements the hard_swish activation function.
Hard_swish is proposed in MobileNetV3, and performs better in computational stability and efficiency compared to swish function.
For more details please refer to: https://arxiv.org/pdf/1905.02244.pdf
...@@ -14911,7 +14892,9 @@ def hard_swish(x, threshold=6.0, scale=6.0, offset=3.0, name=None):
.. code-block:: python
import paddle.fluid as fluid
import paddle
import numpy as np
paddle.enable_static()
DATATYPE='float32'
...@@ -14926,6 +14909,10 @@ def hard_swish(x, threshold=6.0, scale=6.0, offset=3.0, name=None):
out, = exe.run(feed={'x':x_data}, fetch_list=[y.name])
print(out) # [[0.66666667, 1.66666667,3., 4.]]
"""
if in_dygraph_mode():
return core.ops.hard_swish(x, 'threshold', threshold, 'scale', scale,
'offset', offset)
check_variable_and_dtype(x, 'x', ['float16', 'float32', 'float64'],
'hard_swish')
...
...@@ -1657,21 +1657,6 @@ class TestLayer(LayerTest):
with self.assertRaises(TypeError):
layers.eye(num_rows=3, batch_shape=[-1])
def test_hard_swish(self):
with self.static_graph():
t = layers.data(name='t', shape=[3, 3], dtype='float32')
ret = layers.hard_swish(t)
static_ret = self.get_static_graph_result(
feed={'t': np.ones(
[3, 3], dtype='float32')}, fetch_list=[ret])[0]
with self.dynamic_graph():
t = np.ones([3, 3], dtype='float32')
dy_ret = layers.hard_swish(base.to_variable(t))
dy_ret_rlt = dy_ret.numpy()
self.assertTrue(np.allclose(static_ret, dy_ret_rlt))
def test_while_loop(self):
with self.static_graph():
i = layers.fill_constant(shape=[1], dtype='int64', value=0)
...@@ -2563,13 +2548,6 @@ class TestBook(LayerTest):
output = layers.l2_normalize(x, axis=1)
return output
def make_maxout(self):
with program_guard(fluid.default_main_program(),
fluid.default_startup_program()):
data = self._get_data(name='x', shape=[8, 6, 6], dtype="float32")
output = layers.maxout(x=data, groups=2)
return (output)
def make_crop(self):
with program_guard(fluid.default_main_program(),
fluid.default_startup_program()):
...@@ -2656,13 +2634,6 @@ class TestBook(LayerTest):
name='prelu')
return (out)
def make_brelu(self):
with program_guard(fluid.default_main_program(),
fluid.default_startup_program()):
input = self._get_data(name="input", shape=[16], dtype="float32")
out = layers.brelu(input, t_min=1.0, t_max=20.0, name='brelu')
return (out)
def make_soft_relu(self):
with program_guard(fluid.default_main_program(),
fluid.default_startup_program()):
...
...@@ -16,32 +16,43 @@ from __future__ import print_function
import unittest
import numpy as np
import paddle
import paddle.fluid as fluid
from paddle.fluid import Program, program_guard
import paddle.fluid.core as core
import paddle.nn.functional as F
from op_test import OpTest
paddle.enable_static()
np.random.seed(1)
def maxout_forward_naive(input, groups, channel_axis):
s0, s1, s2, s3 = input.shape
if channel_axis == 3:
return np.ndarray([s0, s1, s2, s3 // groups, groups], \
buffer = input, dtype=input.dtype).max(axis=(4))
return np.ndarray([s0, s1 // groups, groups, s2, s3], \
buffer = input, dtype=input.dtype).max(axis=(2))

def maxout_forward_naive(x, groups, channel_axis):
s0, s1, s2, s3 = x.shape
if channel_axis == 1:
return np.ndarray([s0, s1 // groups, groups, s2, s3], \
buffer = x, dtype=x.dtype).max(axis=2)
return np.ndarray([s0, s1, s2, s3 // groups, groups], \
buffer = x, dtype=x.dtype).max(axis=4)
class TestMaxOutOp(OpTest):
def setUp(self):
self.op_type = "maxout"
self.init_test_case()
input = np.random.random(self.shape)
output = self.MaxOut_forward_naive(input, self.groups, self.axis)
self.inputs = {'X': input}
self.attrs = {'groups': self.groups, 'axis': self.axis}
self.outputs = {'Out': output}

self.dtype = 'float64'
self.shape = [3, 6, 2, 4]
self.groups = 2
self.axis = 1
self.set_attrs()
x = np.random.uniform(-1, 1, self.shape).astype(self.dtype)
out = maxout_forward_naive(x, self.groups, self.axis)
self.inputs = {'X': x}
self.attrs = {'groups': self.groups, 'axis': self.axis}
self.outputs = {'Out': out}

def set_attrs(self):
pass
def test_check_output(self):
self.check_output()
...@@ -49,65 +60,89 @@ class TestMaxOutOp(OpTest):
def test_check_grad(self):
self.check_grad(['X'], 'Out')
def init_test_case(self):
self.MaxOut_forward_naive = maxout_forward_naive
self.shape = [100, 6, 2, 2]
self.groups = 2
self.axis = 1
class TestMaxOutOpAxis(TestMaxOutOp):
def init_test_case(self):
self.MaxOut_forward_naive = maxout_forward_naive
self.shape = [100, 2, 2, 6] # NHWC format
self.groups = 2
self.axis = 3

class TestMaxOutOpAxisAPI(unittest.TestCase):
def test_axis(self):
data1 = fluid.data(name='data1', shape=[3, 6, 2, 2], dtype='float32')
data2 = fluid.data(name='data2', shape=[3, 2, 2, 6], dtype='float32')
out1 = fluid.layers.maxout(data1, groups=2, axis=1)
out2 = fluid.layers.maxout(data2, groups=2, axis=-1)
data1_np = np.random.random((3, 6, 2, 2)).astype("float32")
data2_np = np.transpose(data1_np, [0, 2, 3, 1])
if core.is_compiled_with_cuda():
place = core.CUDAPlace(0)
else:
place = core.CPUPlace()
exe = fluid.Executor(place)
exe.run(fluid.default_startup_program())
results = exe.run(fluid.default_main_program(),
feed={"data1": data1_np,
"data2": data2_np},
fetch_list=[out1, out2],
return_numpy=True)
self.assertTrue(
np.allclose(results[0], np.transpose(results[1], (0, 3, 1, 2))))

def test_exception(self):
input = fluid.data(name="input", shape=[2, 4, 6, 6], dtype="float32")
def _attr_axis():
out = fluid.layers.maxout(input, groups=2, axis=2)
self.assertRaises(ValueError, _attr_axis)

class TestMaxOutOpAxis0(TestMaxOutOp):
def set_attrs(self):
self.axis = -1

class TestMaxOutOpAxis1(TestMaxOutOp):
def set_attrs(self):
self.axis = 3

class TestMaxOutOpFP32(TestMaxOutOp):
def set_attrs(self):
self.dtype = 'float32'

class TestMaxOutOpGroups(TestMaxOutOp):
def set_attrs(self):
self.groups = 3
class TestMaxoutAPI(unittest.TestCase):
# test paddle.nn.Maxout, paddle.nn.functional.maxout
def setUp(self):
self.x_np = np.random.uniform(-1, 1, [2, 6, 5, 4]).astype(np.float64)
self.groups = 2
self.axis = 1
self.place=paddle.CUDAPlace(0) if core.is_compiled_with_cuda() \
else paddle.CPUPlace()
def test_static_api(self):
with paddle.static.program_guard(paddle.static.Program()):
x = paddle.data('X', self.x_np.shape, self.x_np.dtype)
out1 = F.maxout(x, self.groups, self.axis)
m = paddle.nn.Maxout(self.groups, self.axis)
out2 = m(x)
exe = paddle.static.Executor(self.place)
res = exe.run(feed={'X': self.x_np}, fetch_list=[out1, out2])
out_ref = maxout_forward_naive(self.x_np, self.groups, self.axis)
for r in res:
self.assertTrue(np.allclose(out_ref, r))
def test_dygraph_api(self):
paddle.disable_static(self.place)
x = paddle.to_tensor(self.x_np)
out1 = F.maxout(x, self.groups, self.axis)
m = paddle.nn.Maxout(self.groups, self.axis)
out2 = m(x)
out_ref = maxout_forward_naive(self.x_np, self.groups, self.axis)
for r in [out1, out2]:
self.assertTrue(np.allclose(out_ref, r.numpy()))
out3 = F.maxout(x, self.groups, -1)
out3_ref = maxout_forward_naive(self.x_np, self.groups, -1)
self.assertTrue(np.allclose(out3_ref, out3.numpy()))
paddle.enable_static()
def test_fluid_api(self):
with fluid.program_guard(fluid.Program()):
x = fluid.data('X', self.x_np.shape, self.x_np.dtype)
out = fluid.layers.maxout(x, groups=self.groups, axis=self.axis)
exe = fluid.Executor(self.place)
res = exe.run(feed={'X': self.x_np}, fetch_list=[out])
out_ref = maxout_forward_naive(self.x_np, self.groups, self.axis)
self.assertTrue(np.allclose(out_ref, res[0]))
paddle.disable_static(self.place)
x = paddle.to_tensor(self.x_np)
out = paddle.fluid.layers.maxout(x, groups=self.groups, axis=self.axis)
self.assertTrue(np.allclose(out_ref, out.numpy()))
paddle.enable_static()
class TestMaxOutOpError(unittest.TestCase):
def test_errors(self):
with program_guard(Program()):
# The input type must be Variable.
self.assertRaises(TypeError, fluid.layers.maxout, 1, 2)
# The input dtype must be float16, float32, float64.
x_int32 = fluid.data(name='x_int32', shape=[12, 10], dtype='int32')
self.assertRaises(TypeError, fluid.layers.maxout, x_int32, 2)
# support the input dtype is float32
x_fp32 = fluid.data(name='x_fp32', shape=[12, 10], dtype='float32')
fluid.layers.maxout(x_fp32, 2)

def test_errors(self):
with paddle.static.program_guard(paddle.static.Program()):
# The input type must be Variable.
self.assertRaises(TypeError, F.maxout, 1)
# The input dtype must be float16, float32, float64.
x_int32 = paddle.data(
name='x_int32', shape=[2, 4, 6, 8], dtype='int32')
self.assertRaises(TypeError, F.maxout, x_int32)
x_float32 = paddle.data(name='x_float32', shape=[2, 4, 6, 8])
self.assertRaises(ValueError, F.maxout, x_float32, 2, 2)

if __name__ == '__main__':
...
...@@ -55,6 +55,7 @@ from .layer.activation import ELU #DEFINE_ALIAS
from .layer.activation import GELU #DEFINE_ALIAS
from .layer.activation import Tanh #DEFINE_ALIAS
from .layer.activation import Hardshrink #DEFINE_ALIAS
from .layer.activation import Hardswish #DEFINE_ALIAS
from .layer.activation import Hardtanh #DEFINE_ALIAS
from .layer.activation import PReLU #DEFINE_ALIAS
from .layer.activation import ReLU #DEFINE_ALIAS
...@@ -62,6 +63,7 @@ from .layer.activation import ReLU6 #DEFINE_ALIAS
from .layer.activation import SELU #DEFINE_ALIAS
from .layer.activation import LeakyReLU #DEFINE_ALIAS
from .layer.activation import Sigmoid #DEFINE_ALIAS
from .layer.activation import Hardsigmoid #DEFINE_ALIAS
from .layer.activation import LogSigmoid
from .layer.activation import Softmax #DEFINE_ALIAS
from .layer.activation import Softplus #DEFINE_ALIAS
...@@ -70,6 +72,7 @@ from .layer.activation import Softsign #DEFINE_ALIAS
from .layer.activation import Tanhshrink #DEFINE_ALIAS
from .layer.activation import LogSoftmax #DEFINE_ALIAS
from .layer.activation import HSigmoid #DEFINE_ALIAS
from .layer.activation import Maxout #DEFINE_ALIAS
from .layer.common import BilinearTensorProduct #DEFINE_ALIAS
from .layer.common import Pool2D #DEFINE_ALIAS
from .layer.common import Pad2D #DEFINE_ALIAS
...
...@@ -29,14 +29,13 @@ from . import pooling
__all__ += pooling.__all__
from . import loss
__all__ += loss.__all__
from .activation import brelu #DEFINE_ALIAS
from .activation import elu #DEFINE_ALIAS
from .activation import erf #DEFINE_ALIAS
from .activation import gelu #DEFINE_ALIAS
from .activation import hardshrink #DEFINE_ALIAS
from .activation import hardtanh #DEFINE_ALIAS
from .activation import hard_sigmoid #DEFINE_ALIAS
from .activation import hard_swish #DEFINE_ALIAS
from .activation import hardsigmoid #DEFINE_ALIAS
from .activation import hardswish #DEFINE_ALIAS
from .activation import hsigmoid #DEFINE_ALIAS
from .activation import leaky_relu #DEFINE_ALIAS
from .activation import log_sigmoid #DEFINE_ALIAS
...
...@@ -13,11 +13,7 @@
# limitations under the License.
# TODO: define activation functions of neural network
from ...fluid.layers import brelu #DEFINE_ALIAS
from ...fluid.layers import erf #DEFINE_ALIAS
from ...fluid.layers import hard_sigmoid #DEFINE_ALIAS
from ...fluid.layers import hard_swish #DEFINE_ALIAS
from ...fluid.layers import maxout #DEFINE_ALIAS
from ...fluid.layers import soft_relu #DEFINE_ALIAS
from ...fluid.layers import swish #DEFINE_ALIAS
from ...fluid.layers import sigmoid #DEFINE_ALIAS
...@@ -25,14 +21,13 @@ from ...fluid.layers import thresholded_relu #DEFINE_ALIAS
from ...tensor.math import tanh #DEFINE_ALIAS
__all__ = [
'brelu',
'elu',
'erf',
'gelu',
'hardshrink',
'hardtanh',
'hard_sigmoid',
'hard_swish',
'hardsigmoid',
'hardswish',
'hsigmoid',
'leaky_relu',
'log_sigmoid',
...@@ -75,10 +70,10 @@ def elu(x, alpha=1.0, name=None):
alpha (float, optional): The 'alpha' value of the ELU formulation. Default is 1.0.
name (str, optional): Name for the operation (optional, default is None).
For more information, please refer to :ref:`api_guide_Name`.
Returns:
A Tensor with the same data type and shape as ``x`` .
Examples:
.. code-block:: python
...@@ -89,7 +84,7 @@ def elu(x, alpha=1.0, name=None):
paddle.disable_static()
x = paddle.to_tensor(np.array([[-1,6],[1,15.6]]))
out = F.elu(x, alpha=0.2)
# [[-0.12642411 6. ]
# [ 1. 15.6 ]]
"""
...@@ -123,16 +118,16 @@ def gelu(x, approximate=False, name=None):
.. math::
gelu(x) = 0.5 * x * (1 + erf(\\frac{x}{\\sqrt{2}}))
Parameters:
x (Tensor): The input Tensor with data type float32, float64.
approximate (bool, optional): Wether to enable approximation. Default is False.
name (str, optional): Name for the operation (optional, default is None).
For more information, please refer to :ref:`api_guide_Name`.
Returns:
A Tensor with the same data type and shape as ``x`` .
Examples:
.. code-block:: python
...@@ -265,6 +260,109 @@ def hardtanh(x, min=-1.0, max=1.0, name=None):
return out
def hardsigmoid(x, name=None):
"""
hardsigmoid activation.
A 3-part piecewise linear approximation of sigmoid(https://arxiv.org/abs/1603.00391),
which is much faster than sigmoid.
.. math::
hardsigmoid(x)=
\\left\\{
\\begin{aligned}
&0, & & \\text{if } x \\leq -3 \\\\
&1, & & \\text{if } x \\geq 3 \\\\
&x/6 + 1/2, & & \\text{otherwise}
\\end{aligned}
\\right.
Parameters:
x (Tensor): The input Tensor with data type float32, float64.
name (str, optional): Name for the operation (optional, default is None).
For more information, please refer to :ref:`api_guide_Name`.
Returns:
A Tensor with the same data type and shape as ``x`` .
Examples:
.. code-block:: python
import paddle
import paddle.nn.functional as F
x = paddle.to_tensor([-4., 5., 1.])
out = F.hardsigmoid(x) # [0., 1., 0.666667]
"""
if in_dygraph_mode():
return core.ops.hard_sigmoid(x, 'slope', 0.1666666666666667, 'offset',
0.5)
check_variable_and_dtype(x, 'x', ['float16', 'float32', 'float64'],
'hardsigmoid')
helper = LayerHelper('hardsigmoid', **locals())
out = helper.create_variable_for_type_inference(x.dtype)
helper.append_op(
type='hard_sigmoid',
inputs={'X': x},
outputs={'Out': out},
attrs={'slope': 0.1666666666666667,
'offset': 0.5})
return out
def hardswish(x, name=None):
"""
hardswish activation
hardswish is proposed in MobileNetV3, and performs better in computational stability
and efficiency compared to swish function. For more details please refer
to: https://arxiv.org/pdf/1905.02244.pdf
.. math::
hardswish(x)=
\\left\\{
\\begin{aligned}
&0, & & \\text{if } x \\leq -3 \\\\
&x, & & \\text{if } x \\geq 3 \\\\
&\\frac{x(x+3)}{6}, & & \\text{otherwise}
\\end{aligned}
\\right.
Parameters:
x (Tensor): The input Tensor with data type float32, float64.
name (str, optional): Name for the operation (optional, default is None).
For more information, please refer to :ref:`api_guide_Name`.
Returns:
A Tensor with the same data type and shape as ``x`` .
Examples:
.. code-block:: python
import paddle
import paddle.nn.functional as F
x = paddle.to_tensor([-4., 5., 1.])
out = F.hardswish(x) # [0., 5., 0.666667]
"""
if in_dygraph_mode():
return core.ops.hard_swish(x)
check_variable_and_dtype(x, 'x', ['float16', 'float32', 'float64'],
'hardswish')
helper = LayerHelper('hardswish', **locals())
out = helper.create_variable_for_type_inference(x.dtype)
helper.append_op(type='hard_swish', inputs={'X': x}, outputs={'Out': out})
return out
def hsigmoid(input,
label,
weight,
...@@ -489,7 +587,7 @@ def prelu(x, weight, name=None):
assert len(weight.shape
) == 1, "The dim count of weight shape should be 1 in prelu()."
# NOTE(): The input of this API should be ``N,C,...`` format,
# which means x.shape[0] is batch_size and x.shape[0] is channel.
mode = 'all'
if weight.shape[0] > 1:
...@@ -559,15 +657,15 @@ def log_sigmoid(x, name=None):
.. math::
log\\_sigmoid(x) = log \\frac{1}{1 + e^{-x}}
Parameters:
x (Tensor): The input Tensor with data type float32, float64.
name (str, optional): Name for the operation (optional, default is None).
For more information, please refer to :ref:`api_guide_Name`.
Returns:
A Tensor with the same data type and shape as ``x`` .
Examples:
.. code-block:: python
...@@ -591,6 +689,81 @@ def log_sigmoid(x, name=None):
return out
def maxout(x, groups, axis=1, name=None):
"""
maxout activation.
Assumed the input shape is (N, Ci, H, W).
The output shape is (N, Co, H, W).
Then Co = Ci/groups and the operator formula is as follows:
.. math::
&out_{si+j} = \\max_{k} x_{gsi + sk + j} \\\\
&g = groups \\\\
&s = \\frac{input.size}{num\\_channels} \\\\
&0 \\le i < \\frac{num\\_channels}{groups} \\\\
&0 \\le j < s \\\\
&0 \\le k < groups
Parameters:
x (Tensor): The input is 4-D Tensor with shape [N, C, H, W] or [N, H, W, C], the data type
of input is float32 or float64.
groups (int, optional): The groups number of maxout. `groups` specifies the
index of channel dimension where maxout will be performed. This must be
a factor of number of features. Default is 1.
axis (int, optional): The axis along which to perform maxout calculations.
It should be 1 when data format is NCHW, be -1 or 3 when data format
is NHWC. If ``axis`` < 0, it works the same way as :math:`axis + D` ,
where D is the dimensions of ``x`` . ``axis`` only supports 1, 3 or -1.
Default is 1.
name (str, optional): Name for the operation (optional, default is None).
For more information, please refer to :ref:`api_guide_Name`.
Returns:
A Tensor with the same data type as ``x`` .
Examples:
.. code-block:: python
import paddle
import paddle.nn.functional as F
x = paddle.rand([1, 2, 3, 4])
# [[[[0.5002636 0.22272532 0.17402348 0.2874594 ]
# [0.95313174 0.6228939 0.7129065 0.7087491 ]
# [0.02879342 0.88725346 0.61093384 0.38833922]]
# [[0.5231306 0.03807496 0.91661984 0.15602879]
# [0.666127 0.616567 0.30741522 0.24044901]
# [0.7142536 0.7351477 0.31588817 0.23782359]]]]
out = F.maxout(x, groups=2)
# [[[[0.5231306 0.22272532 0.91661984 0.2874594 ]
# [0.95313174 0.6228939 0.7129065 0.7087491 ]
# [0.7142536 0.88725346 0.61093384 0.38833922]]]]
"""
if in_dygraph_mode():
return core.ops.maxout(x, 'groups', groups, 'axis', axis)
check_variable_and_dtype(x, 'x', ['float32', 'float64'], 'maxout')
if axis not in [1, -1, 3]:
raise ValueError(
"Attr(axis) should be 1 when data format is NCHW, -1 or 3 when data format is NHWC. Received "
"Attr(axis): %s." % str(axis))
if axis == -1:
axis = 3
helper = LayerHelper('maxout', **locals())
out = helper.create_variable_for_type_inference(x.dtype)
helper.append_op(
type='maxout',
inputs={'X': x},
outputs={'Out': out},
attrs={'groups': groups,
'axis': axis})
return out
def relu6(x, name=None):
"""
relu6 activation
...@@ -778,7 +951,7 @@ def softmax(x, axis=-1, dtype=None, name=None):
:math:`axis + D` . Default is -1.
dtype (str|np.dtype|core.VarDesc.VarType, optional): The desired data
type of the output tensor. If dtype is specified, ``x`` is casted
to ``dtype`` before the operation is performed. This is useful for
preventing data type overflows. Supported dtype: float32, float64.
If ``dtype`` is None, the output Tensor has the same dtype as x.
Default is None.
...@@ -1051,13 +1224,13 @@ def log_softmax(x, axis=-1, dtype=None, name=None):
:math:`axis + D` . Default is -1.
dtype (str|np.dtype|core.VarDesc.VarType, optional): The desired data
type of the output tensor. If dtype is specified, ``x`` is casted
to ``dtype`` before the operation is performed. This is useful for
preventing data type overflows. Supported dtype: float32, float64.
If ``dtype`` is None, the output Tensor has the same dtype as x.
Default is None.
name (str, optional): Name for the operation (optional, default is None).
For more information, please refer to :ref:`api_guide_Name`.
Returns:
A Tensor with the same shape and data type (use ``dtype`` if it is
specified) as x.
...
...@@ -18,6 +18,7 @@ __all__ = [
'ELU',
'GELU',
'Hardshrink',
'Hardswish',
'Tanh',
'Hardtanh',
'PReLU',
...@@ -26,6 +27,7 @@ __all__ = [
'SELU',
'LeakyReLU',
'Sigmoid',
'Hardsigmoid',
'Softmax',
'Softplus',
'Softshrink',
...@@ -33,6 +35,7 @@ __all__ = [
'Tanhshrink',
'LogSigmoid',
'LogSoftmax',
'Maxout',
'HSigmoid',
]
...@@ -50,18 +53,18 @@ class ELU(layers.Layer):
ELU Activation.
.. math::
ELU(x) = max(0, x) + min(0, \\alpha * (e^{x}-1))
Parameters:
alpha (float, optional): The 'alpha' value of the ELU formulation. Default is 1.0.
name (str, optional): Name for the operation (optional, default is None).
For more information, please refer to :ref:`api_guide_Name`.
Shape:
- input: Tensor with any shape.
- output: Tensor with the same shape as input.
Examples:
.. code-block:: python
...@@ -106,11 +109,11 @@ class GELU(layers.Layer):
approximate (bool, optional): Wether to enable approximation. Default is False.
name (str, optional): Name for the operation (optional, default is None).
For more information, please refer to :ref:`api_guide_Name`.
Shape:
- input: Tensor with any shape.
- output: Tensor with the same shape as input.
Examples:
.. code-block:: python
...@@ -120,7 +123,7 @@ class GELU(layers.Layer):
paddle.disable_static()
x = paddle.to_tensor(np.array([[-1, 0.5],[1, 1.5]]))
m = paddle.nn.GELU()
out = m(x) # [-0.158655 0.345731 0.841345 1.39979]
...@@ -184,6 +187,52 @@ class Hardshrink(layers.Layer):
return F.hardshrink(x, self._threshold, self._name)
class Hardswish(layers.Layer):
"""
Hardswish activation
Hardswish is proposed in MobileNetV3, and performs better in computational stability
and efficiency compared to swish function. For more details please refer
to: https://arxiv.org/pdf/1905.02244.pdf
.. math::
Hardswish(x)=
\\left\\{
\\begin{aligned}
&0, & & \\text{if } x \\leq -3 \\\\
&x, & & \\text{if } x \\geq 3 \\\\
&\\frac{x(x+3)}{6}, & & \\text{otherwise}
\\end{aligned}
\\right.
Parameters:
name (str, optional): Name for the operation (optional, default is None).
For more information, please refer to :ref:`api_guide_Name`.
Shape:
- input: Tensor with any shape.
- output: Tensor with the same shape as input.
Examples:
.. code-block:: python
import paddle
x = paddle.to_tensor([-4., 5., 1.])
m = paddle.nn.Hardswish()
out = m(x) # [0., 5., 0.666667]
"""
def __init__(self, name=None):
super(Hardswish, self).__init__()
self._name = name
def forward(self, x):
return F.hardswish(x, self._name)
class Tanh(layers.Layer):
"""
Tanh Activation.
...@@ -240,11 +289,11 @@ class Hardtanh(layers.Layer):
max (float, optional): The value of max for Hardtanh. Default is 1.
name (str, optional): Name for the operation (optional, default is None).
For more information, please refer to :ref:`api_guide_Name`.
Shape:
- input: Tensor with any shape.
- output: Tensor with the same shape as input.
Examples:
.. code-block:: python
...@@ -274,7 +323,7 @@ class HSigmoid(layers.Layer):
:alias: paddle.nn.HSigmoid,paddle.nn.layer.HSigmoid,paddle.nn.layer.activation.HSigmoid
Hierarchical Sigmoid Layer.
The hierarchical sigmoid organizes the classes into a complete binary tree to reduce the computational complexity
and speed up the model training, especially the training of language model.
Each leaf node of the complete binary tree represents a class(word) and each non-leaf node acts as a binary classifier.
...@@ -309,7 +358,7 @@ class HSigmoid(layers.Layer):
is set to False, no bias will be added. If it is set to None or one attribute of ParamAttr,
hsigmoid will create a ParamAttr as bias_attr. If the Initializer of the bias_attr is not
set, the bias is initialized zero. Default: None.
is_custom (bool, optional): Whether use custom binary tree. If it's True, `path_table` and
`path_code` should be passed to its forward method, otherwise `path_table` and `path_code`
should not be passed to its forward method. Default: False.
is_sparse (bool, optional): Whether use sparse updating instead of dense updating, if it's True, the
...@@ -414,19 +463,19 @@ class PReLU(layers.Layer):
Parameters:
num_parameters (int, optional): Number of `weight` to learn. The supported values are:
1 - a single parameter `alpha` is used for all input channels;
Number of channels - a seperate `alpha` is used for each input channel.
Default is 1.
init (float, optional): Init value of learnable `weight`. Default is 0.25.
weight_attr(ParamAttr, optional): The parameter attribute for the learnable `weight`.
Default is None. For more information, please refer to :ref:`api_fluid_ParamAttr`.
name (str, optional): Name for the operation (optional, default is None).
For more information, please refer to :ref:`api_guide_Name`.
Shape:
- input: Tensor with any shape. Default dtype is float32.
- output: Tensor with the same shape as input.
Examples:
.. code-block:: python
...@@ -487,7 +536,7 @@ class ReLU(layers.Layer):
Shape:
- input: Tensor with any shape.
- output: Tensor with the same shape as input.
Examples:
.. code-block:: python
...@@ -613,11 +662,11 @@ class LeakyReLU(layers.Layer):
:math:`x < 0` . Default is 0.01.
name (str, optional): Name for the operation (optional, default is None).
For more information, please refer to :ref:`api_guide_Name`.
Shape:
- input: Tensor with any shape.
- output: Tensor with the same shape as input.
Examples:
.. code-block:: python
...@@ -643,11 +692,11 @@ class LeakyReLU(layers.Layer):
class Sigmoid(layers.Layer):
"""
this interface is used to construct a callable object of the ``Sigmoid`` class. This layer calcluate the `sigmoid` of input x.
.. math::
Sigmoid(x) = \frac{1}{1 + e^{-x}}
Parameters:
name (str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`.
...@@ -656,7 +705,7 @@ class Sigmoid(layers.Layer):
Returns:
A callable object of Sigmoid.
Examples:
.. code-block:: python
...@@ -680,6 +729,53 @@ class Sigmoid(layers.Layer):
return F.sigmoid(x, self.name)
class Hardsigmoid(layers.Layer):
"""
This interface is used to construct a callable object of the ``Hardsigmoid`` class.
This layer calcluate the `hardsigmoid` of input x.
A 3-part piecewise linear approximation of sigmoid(https://arxiv.org/abs/1603.00391),
which is much faster than sigmoid.
.. math::
Hardsigmoid(x)=
\\left\\{
\\begin{aligned}
&0, & & \\text{if } x \\leq -3 \\\\
&1, & & \\text{if } x \\geq 3 \\\\
&x/6 + 1/2, & & \\text{otherwise}
\\end{aligned}
\\right.
Parameters:
name (str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`.
Shape:
x: N-D tensor, available dtype is float32, float64.
Returns:
A callable object of Hardsigmoid.
Examples:
.. code-block:: python
import paddle
m = paddle.nn.Hardsigmoid()
x = paddle.to_tensor([-4., 5., 1.])
out = m(x) # [0., 1, 0.666667]
"""
def __init__(self, name=None):
super(Hardsigmoid, self).__init__()
self.name = name
def forward(self, x):
return F.hardsigmoid(x, self.name)
class Softplus(layers.Layer):
"""
Softplus Activation
...@@ -842,7 +938,7 @@ class Tanhshrink(layers.Layer):
class LogSigmoid(layers.Layer):
"""
LogSigmoid Activation.
.. math::
LogSigmoid(x) = log \\frac{1}{1 + e^{-x}}
...@@ -851,11 +947,11 @@ class LogSigmoid(layers.Layer):
x (Tensor): The input Tensor with data type float32, or float64.
name (str, optional): Name for the operation (optional, default is None).
For more information, please refer to :ref:`api_guide_Name`.
Shape:
- input: Tensor with any shape.
- output: Tensor with the same shape as input.
Examples:
.. code-block:: python
...@@ -961,7 +1057,7 @@ class Softmax(layers.Layer):
:math:`axis + D` . Default is -1.
dtype (str|np.dtype|core.VarDesc.VarType, optional): The desired data
type of the output tensor. If dtype is specified, ``x`` is casted
to ``dtype`` before the operation is performed. This is useful for
preventing data type overflows. Supported dtype: float32, float64.
If ``dtype`` is None, the output Tensor has the same dtype as x.
Default is None.
...@@ -1013,7 +1109,7 @@ class LogSoftmax(layers.Layer):
.. math::
Out[i, j] = log(softmax(x))
= log(\\frac{\exp(X[i, j])}{\\sum_j(exp(X[i, j])})
Parameters:
...@@ -1023,7 +1119,7 @@ class LogSoftmax(layers.Layer):
same way as :math:`axis + D` . Default is -1.
name (str, optional): Name for the operation (optional, default is None).
For more information, please refer to :ref:`api_guide_Name`.
Shape:
- input: Tensor with any shape.
- output: Tensor with the same shape as input.
...@@ -1060,3 +1156,64 @@ class LogSoftmax(layers.Layer):
def forward(self, x):
return F.log_softmax(x, self._axis)
class Maxout(layers.Layer):
"""
Maxout Activation.
Assumed the input shape is (N, Ci, H, W).
The output shape is (N, Co, H, W).
Then Co = Ci/groups and the operator formula is as follows:
.. math::
&out_{si+j} = \max_{k} x_{gsi + sk + j} \\\\
&g = groups \\\\
&s = \\frac{input.size}{num\\_channels} \\\\
&0 \\le i < \\frac{num\\_channels}{groups} \\\\
&0 \\le j < s \\\\
&0 \\le k < groups
Parameters:
groups (int, optional): The groups number of maxout. `groups` specifies the
index of channel dimension where maxout will be performed. This must be
a factor of number of features. Default is 1.
axis (int, optional): The axis along which to perform maxout calculations.
It should be 1 when data format is NCHW, be -1 or 3 when data format
is NHWC. If ``axis`` < 0, it works the same way as :math:`axis + D` ,
where D is the dimensions of ``x`` . Default is 1.
name (str, optional): Name for the operation (optional, default is None).
For more information, please refer to :ref:`api_guide_Name`.
Shape:
- input: :math:`(N, C_{in}, H_{in}, W_{in})`
- output: :math:`(N, C_{out}, H_{out}, W_{out})`
Examples:
.. code-block:: python
import paddle
x = paddle.rand([1, 2, 3, 4])
# [[[[0.5002636 0.22272532 0.17402348 0.2874594 ]
# [0.95313174 0.6228939 0.7129065 0.7087491 ]
# [0.02879342 0.88725346 0.61093384 0.38833922]]
# [[0.5231306 0.03807496 0.91661984 0.15602879]
# [0.666127 0.616567 0.30741522 0.24044901]
# [0.7142536 0.7351477 0.31588817 0.23782359]]]]
m = paddle.nn.Maxout(groups=2)
out = m(x)
# [[[[0.5231306 0.22272532 0.91661984 0.2874594 ]
# [0.95313174 0.6228939 0.7129065 0.7087491 ]
# [0.7142536 0.88725346 0.61093384 0.38833922]]]]
"""
def __init__(self, groups, axis=1, name=None):
super(Maxout, self).__init__()
self._groups = groups
self._axis = axis
self._name = name
def forward(self, x):
return F.maxout(x, self._groups, self._axis, self._name)
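As a quick semantic cross-check of the new maxout against plain NumPy, the following sketch mirrors the maxout_forward_naive reference used in the updated unit test (assumes a Paddle build that contains this change):

    import numpy as np
    import paddle
    import paddle.nn.functional as F

    paddle.disable_static()

    x_np = np.random.uniform(-1, 1, (2, 6, 3, 4)).astype('float32')

    # NumPy reference for axis=1 (NCHW): split the channel dim into groups and take the max.
    n, c, h, w = x_np.shape
    groups = 2
    ref = x_np.reshape(n, c // groups, groups, h, w).max(axis=2)

    out = F.maxout(paddle.to_tensor(x_np), groups=groups, axis=1)
    print(np.allclose(ref, out.numpy()))  # expected: True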