From 99fda6f431969159881c2eb54a74beaca429f61b Mon Sep 17 00:00:00 2001
From: peixu_ren
Date: Wed, 6 May 2020 21:59:23 -0300
Subject: [PATCH] Add LogSigmoid and ReduceLogSumExp

---
 mindspore/nn/layer/__init__.py     |  4 +-
 mindspore/nn/layer/activation.py   | 45 ++++++++++++++++++++
 mindspore/nn/layer/math.py         | 68 ++++++++++++++++++++++++++++++
 tests/ut/python/ops/test_nn_ops.py | 24 +++++++++++
 4 files changed, 140 insertions(+), 1 deletion(-)
 create mode 100644 mindspore/nn/layer/math.py

diff --git a/mindspore/nn/layer/__init__.py b/mindspore/nn/layer/__init__.py
index 20ec8e17e..9999142a4 100644
--- a/mindspore/nn/layer/__init__.py
+++ b/mindspore/nn/layer/__init__.py
@@ -17,7 +17,7 @@ Layer.
 The high-level components(Cells) used to construct the neural network.
 """
-from . import activation, normalization, container, conv, lstm, basic, embedding, pooling, image, quant
+from . import activation, normalization, container, conv, lstm, basic, embedding, pooling, image, quant, math
 from .activation import *
 from .normalization import *
 from .container import *
@@ -28,6 +28,7 @@ from .embedding import *
 from .pooling import *
 from .image import *
 from .quant import *
+from .math import *
 
 __all__ = []
 __all__.extend(activation.__all__)
@@ -40,3 +41,4 @@ __all__.extend(embedding.__all__)
 __all__.extend(pooling.__all__)
 __all__.extend(image.__all__)
 __all__.extend(quant.__all__)
+__all__.extend(math.__all__)
diff --git a/mindspore/nn/layer/activation.py b/mindspore/nn/layer/activation.py
index ef217ea11..b3ebe3af1 100644
--- a/mindspore/nn/layer/activation.py
+++ b/mindspore/nn/layer/activation.py
@@ -35,6 +35,7 @@ __all__ = ['Softmax',
            'HSigmoid',
            'HSwish',
            'ELU',
+           'LogSigmoid',
            ]
 
 
@@ -476,6 +477,49 @@ class HSigmoid(Cell):
         return self.hsigmoid(x)
 
 
+class LogSigmoid(Cell):
+    r"""
+    LogSigmoid activation function.
+
+    Applies the logsigmoid activation element-wise. The input is a Tensor of any valid shape.
+
+    LogSigmoid is defined as:
+
+    .. math::
+        \text{logsigmoid}(x_{i}) = \log\left(\frac{1}{1 + \exp(-x_i)}\right),
+
+    where :math:`x_{i}` is an element of the input.
+
+    Inputs:
+        - **input_data** (Tensor) - The input of LogSigmoid.
+
+    Outputs:
+        Tensor, with the same type and shape as the `input_data`.
+
+    Examples:
+        >>> net = nn.LogSigmoid()
+        >>> input_x = Tensor(np.array([1.0, 2.0, 3.0]), mindspore.float32)
+        >>> logsigmoid = net(input_x)
+        [-3.1326166e-01, -1.2692806e-01, -4.8587345e-02]
+
+    """
+    def __init__(self):
+        super(LogSigmoid, self).__init__()
+        self.mul = P.Mul()
+        self.exp = P.Exp()
+        self.add = P.TensorAdd()
+        self.rec = P.Reciprocal()
+        self.log = P.Log()
+
+    def construct(self, input_x):
+        neg_input = self.mul(input_x, -1)
+        exp_neg_input = self.exp(neg_input)
+        exp_neg_input_1 = self.add(exp_neg_input, 1)
+        rec_exp_neg_input_1 = self.rec(exp_neg_input_1)
+        ret = self.log(rec_exp_neg_input_1)
+        return ret
+
+
 _activation = {
     'softmax': Softmax,
     'logsoftmax': LogSoftmax,
@@ -488,6 +532,7 @@ _activation = {
     'leakyrelu': LeakyReLU,
     'hswish': HSwish,
     'hsigmoid': HSigmoid,
+    'logsigmoid': LogSigmoid,
 }
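A side note on the LogSigmoid composition above: log(1 / (1 + exp(-x))) is exact algebra, but exp(-x) overflows float32 once x drops below about -88 (exp passes the float32 max near exp(88.7)), turning the result into -inf. A minimal NumPy sketch, illustrative only and not part of the patch, comparing the patch's formula with the usual stable rewrite min(x, 0) - log1p(exp(-|x|)):

    import numpy as np

    def log_sigmoid_naive(x):
        # Same composition as the LogSigmoid cell above: log(1 / (1 + exp(-x))).
        return np.log(1.0 / (1.0 + np.exp(-x)))

    def log_sigmoid_stable(x):
        # Algebraically equal, but exp() only ever sees non-positive arguments:
        # log(sigmoid(x)) = min(x, 0) - log1p(exp(-|x|))
        return np.minimum(x, 0.0) - np.log1p(np.exp(-np.abs(x)))

    x = np.array([1.0, 2.0, 3.0], dtype=np.float32)
    print(log_sigmoid_naive(x))                    # [-0.31326166 -0.12692805 -0.04858735]
    print(log_sigmoid_stable(np.float32(-100.0)))  # -100.0; the naive form gives -inf here

Both forms agree with the docstring example to float32 precision; the difference only shows up far in the tails.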
diff --git a/mindspore/nn/layer/math.py b/mindspore/nn/layer/math.py
new file mode 100644
index 000000000..8e44a80f5
--- /dev/null
+++ b/mindspore/nn/layer/math.py
@@ -0,0 +1,68 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""math"""
+from mindspore.ops import operations as P
+from ..cell import Cell
+from ..._checkparam import Validator as validator
+
+__all__ = ['ReduceLogSumExp']
+
+class ReduceLogSumExp(Cell):
+    r"""
+    Reduces a dimension of a tensor by taking the exponential of all elements
+    along the dimension and then the logarithm of their sum.
+
+    The dtype of the tensor to be reduced must be a number type.
+
+    Args:
+        axis (Union[int, tuple(int), list(int)]): The dimensions to reduce;
+            () reduces all dimensions. Only constant values are allowed.
+        keep_dims (bool): If True, keep the reduced dimensions with length 1.
+            If False, don't keep these dimensions. Default: False.
+
+    Inputs:
+        - **input_x** (Tensor[Number]) - The input tensor.
+
+    Outputs:
+        Tensor, has the same dtype as the `input_x`.
+
+        - If axis is (), and keep_dims is False,
+          the output is a 0-D tensor representing the log-sum-exp of
+          all elements in the input tensor.
+        - If axis is int, set as 2, and keep_dims is False,
+          the shape of output is :math:`(x_1, x_3, ..., x_R)`.
+        - If axis is tuple(int), set as (2, 3), and keep_dims is False,
+          the shape of output is :math:`(x_1, x_4, ..., x_R)`.
+
+    Examples:
+        >>> input_x = Tensor(np.random.randn(3, 4, 5, 6).astype(np.float32))
+        >>> op = nn.ReduceLogSumExp(1, keep_dims=True)
+        >>> output = op(input_x)
+    """
+
+    def __init__(self, axis, keep_dims=False):
+        super(ReduceLogSumExp, self).__init__()
+        validator.check_value_type('axis', axis, [int, list, tuple], self.cls_name)
+        validator.check_value_type('keep_dims', keep_dims, [bool], self.cls_name)
+        self.axis = axis
+        self.exp = P.Exp()
+        self.sum = P.ReduceSum(keep_dims)
+        self.log = P.Log()
+
+    def construct(self, input_x):
+        exp = self.exp(input_x)
+        sumexp = self.sum(exp, self.axis)
+        logsumexp = self.log(sumexp)
+        return logsumexp
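The same caveat applies to the new cell: the exp -> ReduceSum -> log chain above overflows as soon as any reduced element exceeds roughly 88 in float32. A NumPy sketch, illustrative and not the patch's implementation, of the max-shift identity commonly used to stabilize log-sum-exp:

    import numpy as np

    def reduce_logsumexp_naive(x, axis):
        # Same composition as the ReduceLogSumExp cell above: log(sum(exp(x))).
        return np.log(np.sum(np.exp(x), axis=axis))

    def reduce_logsumexp_stable(x, axis):
        # Max-shift identity: log(sum(exp(x))) = m + log(sum(exp(x - m))), m = max(x),
        # so exp() only ever sees non-positive arguments.
        m = np.max(x, axis=axis, keepdims=True)
        return np.squeeze(m, axis=axis) + np.log(np.sum(np.exp(x - m), axis=axis))

    x = np.array([[1000.0, 1000.0], [3.0, 4.0]], dtype=np.float32)
    print(reduce_logsumexp_stable(x, axis=1))  # [1000.6931    4.313262]
    print(reduce_logsumexp_naive(x, axis=1))   # [      inf    4.313262]

If stability matters for the intended inputs, the same trick could presumably be composed from the existing P.ReduceMax, P.Sub, P.Exp, P.ReduceSum and P.Log primitives.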
diff --git a/tests/ut/python/ops/test_nn_ops.py b/tests/ut/python/ops/test_nn_ops.py
index 0e9b8d2af..992d7957a 100644
--- a/tests/ut/python/ops/test_nn_ops.py
+++ b/tests/ut/python/ops/test_nn_ops.py
@@ -522,6 +522,16 @@ test_cases = [
         'desc_inputs': [Tensor(np.ones([1, 1, 3, 3], np.float32))],
         'desc_bprop': [Tensor(np.ones([1, 4, 2, 2], np.float32))],
         'skip': ['backward']}),
+    ('LogSigmoid', {
+        'block': nn.LogSigmoid(),
+        'desc_inputs': [Tensor(np.array([1, 2, 3, 4]).astype(np.float32))],
+        'desc_bprop': [Tensor(np.array([1, 2, 3, 4]).astype(np.float32))],
+        'skip': ['backward']}),
+    ('ReduceLogSumExp', {
+        'block': nn.ReduceLogSumExp((0,), False),
+        'desc_inputs': [Tensor(np.array([3, 4, 5, 6]).astype(np.float32))],
+        'desc_bprop': [Tensor(np.array([1, 2, 3, 4]).astype(np.float32))],
+        'skip': ['backward']}),
 ]
 
 test_cases_for_verify_exception = [
@@ -621,6 +631,20 @@ test_cases_for_verify_exception = [
         ),
         'desc_inputs': [Tensor(np.random.randn(32, 3, 112, 112).astype(np.float32).transpose(0, 3, 1, 2))],
     }),
+    ('ReduceLogSumExp_TypeError_1', {
+        'block': (
+            lambda _: nn.ReduceLogSumExp(axis=(0,), keep_dims=2),
+            {'exception': TypeError},
+        ),
+        'desc_inputs': [Tensor(np.array([3, 4, 5, 6]).astype(np.float32))],
+    }),
+    ('ReduceLogSumExp_TypeError_2', {
+        'block': (
+            lambda _: nn.ReduceLogSumExp(axis=1.2, keep_dims=True),
+            {'exception': TypeError},
+        ),
+        'desc_inputs': [Tensor(np.array([3, 4, 5, 6]).astype(np.float32))],
+    }),
 ]
--
GitLab