From 6dd9901b3c3cc54bf053d63fed6ff8024e008e7e Mon Sep 17 00:00:00 2001
From: zhupengyang
Date: Sat, 11 Apr 2020 08:37:13 +0800
Subject: [PATCH] add activation ops under paddle.nn and paddle.nn.functional:
 ReLU, LogSoftmax (#23258)

---
 .../tests/unittests/test_activation_op.py     |  74 ++++++-
 .../fluid/tests/unittests/test_log_softmax.py | 107 ++++++++++
 python/paddle/nn/__init__.py                  |   8 +-
 python/paddle/nn/functional/__init__.py       |   5 +-
 python/paddle/nn/functional/activation.py     | 188 +++++++++++++++---
 python/paddle/nn/layer/__init__.py            |   2 +
 python/paddle/nn/layer/activation.py          | 105 +++++++++-
 python/setup.py.in                            |   2 -
 8 files changed, 454 insertions(+), 37 deletions(-)
 create mode 100644 python/paddle/fluid/tests/unittests/test_log_softmax.py

diff --git a/python/paddle/fluid/tests/unittests/test_activation_op.py b/python/paddle/fluid/tests/unittests/test_activation_op.py
index 5dc89591e77..cb498ce94d4 100644
--- a/python/paddle/fluid/tests/unittests/test_activation_op.py
+++ b/python/paddle/fluid/tests/unittests/test_activation_op.py
@@ -21,6 +21,8 @@ from op_test import OpTest
 from scipy.special import expit, erf
 import paddle
 import paddle.fluid as fluid
+import paddle.nn as nn
+import paddle.nn.functional as functional
 from paddle.fluid import compiler, Program, program_guard
 
 
@@ -759,9 +761,6 @@ class TestPow_factor_tensor(TestActivation):
         self.check_grad(['X'], 'Out')
 
     def test_api(self):
-        import paddle
-        import paddle.fluid as fluid
-
         input = np.random.uniform(1, 2, [11, 17]).astype("float32")
         x = fluid.layers.data(
             name="x", shape=[11, 17], append_batch_size=False, dtype="float32")
@@ -1003,5 +1002,74 @@ create_test_act_fp16_class(TestHardSigmoid)
 create_test_act_fp16_class(TestSwish)
 create_test_act_fp16_class(TestHardSwish)
 
+
+class TestNNReluAPI(unittest.TestCase):
+    def setUp(self):
+        self.init_data()
+
+    def init_data(self):
+        self.x_shape = [10, 12]
+        self.x = np.random.uniform(-1, 1, self.x_shape).astype(np.float32)
+        self.y = self.ref_forward(self.x)
+
+    def ref_forward(self, x):
+        return np.maximum(x, 0)
+
+    def ref_backward(self, y, dy):
+        y_t = y.copy()
+        y_t[y_t > 0] = 1
+        return y_t * dy
+
+    def check_api(self, place=fluid.CPUPlace(), inplace=False):
+        main_program = Program()
+        myrelu = nn.ReLU(inplace)
+        with fluid.program_guard(main_program):
+            x = fluid.data(name='x', shape=self.x_shape)
+            x.stop_gradient = False
+            y = myrelu(x)
+            fluid.backward.append_backward(fluid.layers.mean(y))
+        exe = fluid.Executor(place)
+        out = exe.run(main_program,
+                      feed={'x': self.x},
+                      fetch_list=[y, y.grad_name, x.grad_name])
+        self.assertTrue(np.allclose(out[0], self.y))
+        self.assertTrue(np.allclose(out[2], self.ref_backward(self.y, out[1])))
+
+        with fluid.dygraph.guard(place):
+            x = fluid.dygraph.to_variable(self.x)
+            y = myrelu(x)
+            self.assertTrue(np.allclose(y.numpy(), self.y))
+
+    def test_check_api(self):
+        places = [fluid.CPUPlace()]
+        if core.is_compiled_with_cuda():
+            places.append(fluid.CUDAPlace(0))
+        for place in places:
+            for inplace in [True, False]:
+                self.check_api(place, inplace)
+
+
+class TestNNFunctionalReluAPI(unittest.TestCase):
+    def setUp(self):
+        self.init_data()
+
+    def init_data(self):
+        self.x_shape = [10, 12]
+        self.x = np.random.uniform(-1, 1, self.x_shape).astype(np.float32)
+        self.y = self.ref_forward(self.x)
+
+    def ref_forward(self, x):
+        return np.maximum(x, 0)
+
+    def test_check_api(self):
+        main_program = Program()
+        with fluid.program_guard(main_program):
+            x = fluid.data(name='x', shape=self.x_shape)
+            y = functional.relu(x)
+        exe = fluid.Executor(fluid.CPUPlace())
+        out = exe.run(main_program, feed={'x': self.x}, fetch_list=[y])
+        self.assertTrue(np.allclose(out[0], self.y))
+
+
 if __name__ == "__main__":
     unittest.main()
diff --git a/python/paddle/fluid/tests/unittests/test_log_softmax.py b/python/paddle/fluid/tests/unittests/test_log_softmax.py
new file mode 100644
index 00000000000..2b77624734d
--- /dev/null
+++ b/python/paddle/fluid/tests/unittests/test_log_softmax.py
@@ -0,0 +1,107 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+import numpy as np
+import paddle.fluid as fluid
+import paddle.fluid.core as core
+import paddle.nn as nn
+import paddle.nn.functional as functional
+
+
+def stable_softmax(x):
+    shiftx = (x - np.max(x))
+    exps = np.exp(shiftx)
+    return exps / np.sum(exps)
+
+
+def ref_log_softmax(x, axis=None, dtype=None):
+    x_t = x.copy()
+    if dtype is not None:
+        x_t = x_t.astype(dtype)
+    if axis is None:
+        axis = -1
+    out = np.apply_along_axis(stable_softmax, axis, x_t)
+    return np.log(out)
+
+
+class TestNNLogSoftmaxAPI(unittest.TestCase):
+    def setUp(self):
+        self.init_data()
+
+    def init_data(self):
+        self.x_shape = [2, 3, 4, 5]
+        self.x = np.random.uniform(-1, 1, self.x_shape).astype(np.float32)
+
+    def check_api(self, place=fluid.CPUPlace(), axis=None):
+        ref_out = ref_log_softmax(self.x, axis)
+
+        main_program = fluid.Program()
+        mylogsoftmax = nn.LogSoftmax(axis)
+        with fluid.program_guard(main_program):
+            x = fluid.data(name='x', shape=self.x_shape)
+            y = mylogsoftmax(x)
+        exe = fluid.Executor(place)
+        out = exe.run(main_program, feed={'x': self.x}, fetch_list=[y])
+        self.assertTrue(np.allclose(out[0], ref_out))
+
+        with fluid.dygraph.guard(place):
+            x = fluid.dygraph.to_variable(self.x)
+            y = mylogsoftmax(x)
+            self.assertTrue(np.allclose(y.numpy(), ref_out))
+
+    def test_check_api(self):
+        places = [fluid.CPUPlace()]
+        if core.is_compiled_with_cuda():
+            places.append(fluid.CUDAPlace(0))
+        for place in places:
+            for axis in [None, 2]:
+                self.check_api(place, axis)
+
+
+class TestNNFunctionalLogSoftmaxAPI(unittest.TestCase):
+    def setUp(self):
+        self.init_data()
+
+    def init_data(self):
+        self.x_shape = [2, 3, 4, 5]
+        self.x = np.random.uniform(-1, 1, self.x_shape).astype(np.float32)
+
+    def check_api(self, place=fluid.CPUPlace(), axis=None, dtype=None):
+        ref_out = ref_log_softmax(self.x, axis, dtype)
+        main_program = fluid.Program()
+        mylogsoftmax = nn.LogSoftmax(axis)
+        with fluid.program_guard(main_program):
+            x = fluid.data(name='x', shape=self.x_shape)
+            y = functional.log_softmax(x, axis, dtype)
+        exe = fluid.Executor(place)
+        out = exe.run(main_program, feed={'x': self.x}, fetch_list=[y])
+        self.assertTrue(np.allclose(out[0], ref_out))
+
+        with fluid.dygraph.guard(place):
+            x = fluid.dygraph.to_variable(self.x)
+            y = functional.log_softmax(x, axis, dtype)
+            self.assertTrue(np.allclose(y.numpy(), ref_out))
+
+    def test_check_api(self):
+        places = [fluid.CPUPlace()]
+        if core.is_compiled_with_cuda():
+            places.append(fluid.CUDAPlace(0))
+        for place in places:
+            self.check_api(place, None, None)
+            self.check_api(place, None, np.float64)
+
+
+if __name__ == "__main__":
+    unittest.main()
diff --git a/python/paddle/nn/__init__.py b/python/paddle/nn/__init__.py
index ee43e8633a3..e5cfd360780 100644
--- a/python/paddle/nn/__init__.py
+++ b/python/paddle/nn/__init__.py
@@ -81,10 +81,10 @@ from .layer.conv import Conv2D, Conv2DTranspose, Conv3D, Conv3DTranspose #DEFIN
 from .layer.norm import InstanceNorm #DEFINE_ALIAS
 # from .layer.norm import SpectralNorm #DEFINE_ALIAS
 # from .layer.activation import PReLU #DEFINE_ALIAS
-# from .layer.activation import ReLU #DEFINE_ALIAS
+from .layer.activation import ReLU #DEFINE_ALIAS
 # from .layer.activation import Sigmoid #DEFINE_ALIAS
 # from .layer.activation import Softmax #DEFINE_ALIAS
-# from .layer.activation import LogSoftmax #DEFINE_ALIAS
+from .layer.activation import LogSoftmax #DEFINE_ALIAS
 # from .layer.rnn import RNNCell #DEFINE_ALIAS
 # from .layer.rnn import GRUCell #DEFINE_ALIAS
 # from .layer.rnn import LSTMCell #DEFINE_ALIAS
@@ -189,7 +189,7 @@ from .functional.conv import conv3d_transpose #DEFINE_ALIAS
 # from .functional.activation import logsigmoid #DEFINE_ALIAS
 # from .functional.activation import maxout #DEFINE_ALIAS
 # from .functional.activation import prelu #DEFINE_ALIAS
-# from .functional.activation import relu #DEFINE_ALIAS
+from .functional.activation import relu #DEFINE_ALIAS
 # from .functional.activation import relu6 #DEFINE_ALIAS
 # from .functional.activation import selu #DEFINE_ALIAS
 # from .functional.activation import sigmoid #DEFINE_ALIAS
@@ -201,7 +201,7 @@ from .functional.conv import conv3d_transpose #DEFINE_ALIAS
 # from .functional.activation import swish #DEFINE_ALIAS
 # from .functional.activation import tanh_shrink #DEFINE_ALIAS
 # from .functional.activation import thresholded_relu #DEFINE_ALIAS
-# from .functional.activation import log_softmax #DEFINE_ALIAS
+from .functional.activation import log_softmax #DEFINE_ALIAS
 # from .functional.extension import add_position_encoding #DEFINE_ALIAS
 # from .functional.extension import autoincreased_step_counter #DEFINE_ALIAS
 # from .functional.extension import continuous_value_model #DEFINE_ALIAS
diff --git a/python/paddle/nn/functional/__init__.py b/python/paddle/nn/functional/__init__.py
index 347ccc0e376..784a2ec9d23 100644
--- a/python/paddle/nn/functional/__init__.py
+++ b/python/paddle/nn/functional/__init__.py
@@ -102,6 +102,7 @@ from .conv import conv3d_transpose #DEFINE_ALIAS
 # from .vision import space_to_depth #DEFINE_ALIAS
 # from .vision import yolo_box #DEFINE_ALIAS
 # from .vision import yolov3_loss #DEFINE_ALIAS
+from . import activation
 # from .activation import brelu #DEFINE_ALIAS
 # from .activation import elu #DEFINE_ALIAS
 # from .activation import erf #DEFINE_ALIAS
@@ -114,7 +115,7 @@ from .conv import conv3d_transpose #DEFINE_ALIAS
 # from .activation import logsigmoid #DEFINE_ALIAS
 # from .activation import maxout #DEFINE_ALIAS
 # from .activation import prelu #DEFINE_ALIAS
-# from .activation import relu #DEFINE_ALIAS
+from .activation import relu #DEFINE_ALIAS
 # from .activation import relu6 #DEFINE_ALIAS
 # from .activation import selu #DEFINE_ALIAS
 # from .activation import sigmoid #DEFINE_ALIAS
@@ -126,7 +127,7 @@ from .conv import conv3d_transpose #DEFINE_ALIAS
 # from .activation import swish #DEFINE_ALIAS
 # from .activation import tanh_shrink #DEFINE_ALIAS
 # from .activation import thresholded_relu #DEFINE_ALIAS
-# from .activation import log_softmax #DEFINE_ALIAS
+from .activation import log_softmax #DEFINE_ALIAS
 # from .extension import add_position_encoding #DEFINE_ALIAS
 # from .extension import autoincreased_step_counter #DEFINE_ALIAS
 # from .extension import continuous_value_model #DEFINE_ALIAS
diff --git a/python/paddle/nn/functional/activation.py b/python/paddle/nn/functional/activation.py
index 4a505c22576..900f1aa33c1 100644
--- a/python/paddle/nn/functional/activation.py
+++ b/python/paddle/nn/functional/activation.py
@@ -12,29 +12,167 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import warnings
+from ...fluid.layer_helper import LayerHelper
+from ...fluid.framework import in_dygraph_mode, convert_np_dtype_to_dtype_
+from ...fluid import core
+
 # TODO: define activation functions of neural network
-# __all__ = ['brelu',
-#            'elu',
-#            'erf',
-#            'gelu',
-#            'hard_shrink',
-#            'hard_sigmoid',
-#            'hard_swish',
-#            'hsigmoid',
-#            'leaky_relu',
-#            'logsigmoid',
-#            'maxout',
-#            'prelu',
-#            'relu',
-#            'relu6',
-#            'selu',
-#            'sigmoid',
-#            'soft_relu',
-#            'softmax',
-#            'softplus',
-#            'softshrink',
-#            'softsign',
-#            'swish',
-#            'tanh_shrink',
-#            'thresholded_relu',
-#            'log_softmax']
+__all__ = [
+    # 'brelu',
+    # 'elu',
+    # 'erf',
+    # 'gelu',
+    # 'hard_shrink',
+    # 'hard_sigmoid',
+    # 'hard_swish',
+    # 'hsigmoid',
+    # 'leaky_relu',
+    # 'logsigmoid',
+    # 'maxout',
+    # 'prelu',
+    'relu',
+    # 'relu6',
+    # 'selu',
+    # 'sigmoid',
+    # 'soft_relu',
+    # 'softmax',
+    # 'softplus',
+    # 'softshrink',
+    # 'softsign',
+    # 'swish',
+    # 'tanh_shrink',
+    # 'thresholded_relu',
+    'log_softmax',
+]
+
+
+def relu(input, inplace=False, name=None):
+    """
+    ReLU Activation.
+
+    .. math::
+
+        out = max(x, 0)
+
+    Parameters:
+        input (Variable): The input variable. A multi-dimension Tensor with type float16, float32, or float64.
+        inplace (bool, optional): If inplace is True, the input and output of ``ReLU`` are the same variable.
+            Otherwise, the input and output of ``ReLU`` are different variables. Default: False. Note that if x is
+            more than one OPs' input, inplace must be False.
+        name (str, optional): The default value is None. Normally there is no need for user to set this property.
+            For more information, please refer to :ref:`api_guide_Name` .
+
+    Returns:
+        Output of relu operator, a Tensor with the same shape as the input.
+
+    Examples:
+        .. code-block:: python
+
+            import paddle.fluid as fluid
+            import paddle.nn.functional as functional
+            import numpy as np
+
+            data = np.array([-2, 0, 1]).astype('float32')
+            with fluid.dygraph.guard():
+                data = fluid.dygraph.to_variable(data)
+                res = functional.relu(data)  # [0, 0, 1]
+    """
+
+    if in_dygraph_mode():
+        if inplace:
+            warnings.warn(
+                "Inplace on ReLU is not allowed and will be discarded in dygraph mode currently."
+            )
+        return core.ops.relu(input)
+
+    helper = LayerHelper('relu', **locals())
+
+    outs = input if inplace else helper.create_variable_for_type_inference(
+        input.dtype)
+    helper.append_op(type='relu', inputs={'X': [input]}, outputs={'Out': outs})
+    return outs
+
+
+def log_softmax(input, axis=None, dtype=None, name=None):
+    """
+    This operator implements the log_softmax layer. The calculation process is as follows:
+
+    .. math::
+
+        Out[i, j] = log(softmax(x))
+                  = log(\\frac{\exp(X[i, j])}{\sum_j(\exp(X[i, j]))})
+
+    Parameters:
+        input (Variable): The input variable. A multi-dimension Tensor with type float32 or float64.
+        axis (int, optional): The index of dimension to perform softmax calculations, it should be in
+            range :math:`[-1, rank-1]`, while :math:`rank` is the rank of input variable. Default: None.
+            None and -1 means the last dimension.
+        dtype (np.dtype|core.VarDesc.VarType|str): The desired data type of returned tensor. If specified,
+            the input tensor is casted to dtype before the operation is performed. This is useful for
+            preventing data type overflows. Default: None. Supported dtype: float32 or float64.
+        name (str, optional): The default value is None. Normally there is no need for user to set this property.
+            For more information, please refer to :ref:`api_guide_Name` .
+
+    Returns:
+        Variable: ``Tensor`` indicates the output of log_softmax. The data type and shape are the same as ``input``.
+
+    Examples:
+        .. code-block:: python
+
+            import paddle.fluid as fluid
+            import paddle.nn.functional as F
+            import numpy as np
+
+            data = np.array([[[-2.0, 3.0, -4.0, 5.0],
+                              [3.0, -4.0, 5.0, -6.0],
+                              [-7.0, -8.0, 8.0, 9.0]],
+                             [[1.0, -2.0, -3.0, 4.0],
+                              [-5.0, 6.0, 7.0, -8.0],
+                              [6.0, 7.0, 8.0, 9.0]]]).astype('float32')
+            with fluid.dygraph.guard():
+                data = fluid.dygraph.to_variable(data)
+                res = F.log_softmax(data, -1)
+                # [[[ -7.1278396   -2.1278396   -9.127839    -0.12783948]
+                #   [ -2.1270514   -9.127051    -0.12705144 -11.127051  ]
+                #   [-16.313261   -17.313261    -1.3132617   -0.31326184]]
+                #  [[ -3.0518122   -6.051812    -7.051812    -0.051812  ]
+                #   [-12.313267    -1.3132664   -0.3132665  -15.313267  ]
+                #   [ -3.4401896   -2.4401896   -1.4401896   -0.44018966]]]
+    """
+
+    axis = -1 if axis is None else axis
+    dtype = convert_np_dtype_to_dtype_(dtype) if dtype is not None else dtype
+
+    if in_dygraph_mode():
+        outs_cast = input if dtype is None \
+            else core.ops.cast(input, 'in_dtype', input.dtype, 'out_dtype', dtype)
+        outs_softmax = core.ops.softmax(outs_cast, 'axis', axis, 'use_cudnn',
+                                        False)
+        return core.ops.log(outs_softmax)
+
+    helper = LayerHelper("log_softmax", **locals())
+
+    outs_cast = input
+    if dtype is not None:
+        outs_cast = helper.create_variable_for_type_inference(dtype)
+        helper.append_op(
+            type='cast',
+            inputs={'X': input},
+            outputs={'Out': outs_cast},
+            attrs={'in_dtype': input.dtype,
+                   'out_dtype': dtype})
+
+    outs_softmax = helper.create_variable_for_type_inference(outs_cast.dtype)
+    helper.append_op(
+        type='softmax',
+        inputs={'X': outs_cast},
+        outputs={'Out': outs_softmax},
+        attrs={'axis': axis,
+               'use_cudnn': False})
+
+    outs_log = helper.create_variable_for_type_inference(outs_softmax.dtype)
+    helper.append_op(
+        type='log', inputs={'X': outs_softmax}, outputs={'Out': outs_log})
+
+    return outs_log
diff --git a/python/paddle/nn/layer/__init__.py b/python/paddle/nn/layer/__init__.py
index c61fd56a931..689cc857ef6 100644
--- a/python/paddle/nn/layer/__init__.py
+++ b/python/paddle/nn/layer/__init__.py
@@ -14,10 +14,12 @@
 
 # TODO: define activation functions of neural network
 
+from . import activation
 from . import loss
 from . import conv
 from . import norm
 
+from .activation import *
 from .loss import *
 from .conv import *
 from .norm import *
diff --git a/python/paddle/nn/layer/activation.py b/python/paddle/nn/layer/activation.py
index aec08e6de76..f94c8c7f980 100644
--- a/python/paddle/nn/layer/activation.py
+++ b/python/paddle/nn/layer/activation.py
@@ -12,5 +12,108 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from ...fluid.dygraph import layers
+from ...fluid import core
+from ...fluid.framework import in_dygraph_mode
+from .. import functional
+
 # TODO: define activation functions of neural network
-# __all__ = ['PReLU', 'ReLU', 'Sigmoid', 'Softmax', 'LogSoftmax']
+__all__ = [
+    # 'PReLU',
+    'ReLU',
+    # 'Sigmoid',
+    # 'Softmax',
+    'LogSoftmax',
+]
+
+
+class ReLU(layers.Layer):
+    """
+    ReLU Activation.
+
+    .. math::
+
+        out = max(x, 0)
+
+    Parameters:
+        inplace (bool, optional): If inplace is True, the input and output of
+            ``ReLU`` are the same variable. Otherwise, the input and output of
+            ``ReLU`` are different variables. Default False. Note that if x is
+            more than one OPs' input, inplace must be False.
+
+    Returns:
+        None
+
+    Examples:
+        .. code-block:: python
+
+            import paddle.fluid as fluid
+            import paddle.nn as nn
+            import numpy as np
+
+            data = np.array([-2, 0, 1]).astype('float32')
+            my_relu = nn.ReLU()
+            with fluid.dygraph.guard():
+                data = fluid.dygraph.to_variable(data)
+                res = my_relu(data)  # [0, 0, 1]
+    """
+
+    def __init__(self, inplace=False):
+        super(ReLU, self).__init__()
+        self._inplace = inplace
+
+    def forward(self, input):
+        return functional.relu(input, self._inplace)
+
+
+class LogSoftmax(layers.Layer):
+    """
+    This operator implements the log_softmax layer. The calculation process is as follows:
+
+    .. math::
+
+        Out[i, j] = log(softmax(x))
+                  = log(\\frac{\exp(X[i, j])}{\sum_j(\exp(X[i, j]))})
+
+    Parameters:
+        axis (int, optional): The index of dimension to perform softmax calculations, it should be in
+            range :math:`[-1, rank-1]`, while :math:`rank` is the rank of input variable. Default: None.
+            None and -1 means the last dimension.
+        dtype (np.dtype|core.VarDesc.VarType|str): The desired data type of returned tensor. If specified,
+            the input tensor is casted to dtype before the operation is performed. This is useful for
+            preventing data type overflows. Default: None. Supported dtype: float32 or float64.
+
+    Returns:
+        None
+
+    Examples:
+        .. code-block:: python
+
+            import paddle.fluid as fluid
+            import paddle.nn as nn
+            import numpy as np
+
+            data = np.array([[[-2.0, 3.0, -4.0, 5.0],
+                              [3.0, -4.0, 5.0, -6.0],
+                              [-7.0, -8.0, 8.0, 9.0]],
+                             [[1.0, -2.0, -3.0, 4.0],
+                              [-5.0, 6.0, 7.0, -8.0],
+                              [6.0, 7.0, 8.0, 9.0]]]).astype('float32')
+            my_log_softmax = nn.LogSoftmax()
+            with fluid.dygraph.guard():
+                data = fluid.dygraph.to_variable(data)
+                res = my_log_softmax(data)
+                # [[[ -7.1278396   -2.1278396   -9.127839    -0.12783948]
+                #   [ -2.1270514   -9.127051    -0.12705144 -11.127051  ]
+                #   [-16.313261   -17.313261    -1.3132617   -0.31326184]]
+                #  [[ -3.0518122   -6.051812    -7.051812    -0.051812  ]
+                #   [-12.313267    -1.3132664   -0.3132665  -15.313267  ]
+                #   [ -3.4401896   -2.4401896   -1.4401896   -0.44018966]]]
+    """
+
+    def __init__(self, axis=None):
+        super(LogSoftmax, self).__init__()
+        self._axis = axis
+
+    def forward(self, input):
+        return functional.log_softmax(input, self._axis)
diff --git a/python/setup.py.in b/python/setup.py.in
index 23799e6189f..24ca37a8daa 100644
--- a/python/setup.py.in
+++ b/python/setup.py.in
@@ -105,8 +105,6 @@ write_version_py(filename='@PADDLE_BINARY_DIR@/python/paddle/version.py')
 
 packages=['paddle',
-          'paddle.nn',
-          'paddle.nn.layer',
           'paddle.libs',
           'paddle.utils',
           'paddle.dataset',
--
GitLab
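
For quick reference, below is a minimal dygraph usage sketch of the two APIs this patch exposes
(paddle.nn.ReLU and paddle.nn.functional.log_softmax). It assumes a Paddle build that already
contains this change; the input values and the results noted in comments are illustrative only.

    import numpy as np
    import paddle.fluid as fluid
    import paddle.nn as nn
    import paddle.nn.functional as F

    x = np.array([[-1.0, 0.0, 2.0],
                  [3.0, -4.0, 5.0]]).astype('float32')

    with fluid.dygraph.guard():
        x_var = fluid.dygraph.to_variable(x)

        # Layer-style API: negative entries are clamped to zero.
        my_relu = nn.ReLU()
        y_relu = my_relu(x_var)  # [[0., 0., 2.], [3., 0., 5.]]

        # Functional API: log(softmax(x)) along the last axis; applying exp()
        # to each row of the result gives values that sum to 1.
        y_lsm = F.log_softmax(x_var, axis=-1)

        print(y_relu.numpy())
        print(y_lsm.numpy())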