diff --git a/python/paddle/__init__.py b/python/paddle/__init__.py index f242fd6a97a1d13f759f9f0d7d687201e4c851ee..0df6ae5be5f8ffd9c6c8e8f38873abc08238e0ea 100644 --- a/python/paddle/__init__.py +++ b/python/paddle/__init__.py @@ -33,6 +33,7 @@ import paddle.compat import paddle.distributed batch = batch.batch import paddle.sysconfig +import paddle.nn #TODO: define alias in tensor and framework directory # from .tensor.creation import create_.tensor #DEFINE_ALIAS diff --git a/python/paddle/fluid/tests/unittests/test_l1_loss.py b/python/paddle/fluid/tests/unittests/test_l1_loss.py new file mode 100644 index 0000000000000000000000000000000000000000..d7e801a666f110471180d704a6fda6dc0f9aeb1e --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_l1_loss.py @@ -0,0 +1,131 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
class TestL1Loss(unittest.TestCase):
    """Verify paddle.nn.loss.L1Loss against a numpy reference.

    Each case runs the layer in static-graph and dygraph modes with the
    same random inputs and checks both results against the numpy-computed
    expectation for the given reduction ('mean', 'sum' or 'none').
    """

    def _run_both_modes(self, input_np, label_np, reduction):
        """Return (static_result, dy_result) for the given reduction."""
        prog = fluid.Program()
        startup_prog = fluid.Program()
        # Prefer the GPU when this build supports it.
        place = fluid.CUDAPlace(0) if fluid.core.is_compiled_with_cuda(
        ) else fluid.CPUPlace()
        shape = list(input_np.shape)
        with fluid.program_guard(prog, startup_prog):
            input = fluid.layers.data(
                name='input', shape=shape, dtype='float32')
            label = fluid.layers.data(
                name='label', shape=shape, dtype='float32')
            l1_loss = paddle.nn.loss.L1Loss(reduction=reduction)
            ret = l1_loss(input, label)

        exe = fluid.Executor(place)
        static_result = exe.run(
            prog,
            feed={"input": input_np,
                  "label": label_np},
            fetch_list=[ret])

        with fluid.dygraph.guard():
            l1_loss = paddle.nn.loss.L1Loss(reduction=reduction)
            dy_ret = l1_loss(
                fluid.dygraph.to_variable(input_np),
                fluid.dygraph.to_variable(label_np))
            dy_result = dy_ret.numpy()
        return static_result, dy_result

    def _check_close(self, static_result, dy_result, expected):
        """Assert static, dygraph and numpy results all agree."""
        self.assertTrue(np.allclose(static_result, expected))
        self.assertTrue(np.allclose(static_result, dy_result))
        self.assertTrue(np.allclose(dy_result, expected))

    def test_L1Loss_mean(self):
        input_np = np.random.random(size=(10, 1)).astype(np.float32)
        label_np = np.random.random(size=(10, 1)).astype(np.float32)
        static_result, dy_result = self._run_both_modes(
            input_np, label_np, 'mean')
        self._check_close(static_result, dy_result,
                          np.mean(np.abs(input_np - label_np)))
        # BUG FIX: the original `assertTrue(dy_result.shape, [1])` passed the
        # shape as the *message* argument and therefore always succeeded.
        self.assertEqual(dy_result.shape, (1, ))

    def test_L1Loss_sum(self):
        input_np = np.random.random(size=(10, 10, 5)).astype(np.float32)
        label_np = np.random.random(size=(10, 10, 5)).astype(np.float32)
        static_result, dy_result = self._run_both_modes(
            input_np, label_np, 'sum')
        self._check_close(static_result, dy_result,
                          np.sum(np.abs(input_np - label_np)))
        # BUG FIX: was assertTrue(shape, msg) — always true; assert for real.
        self.assertEqual(dy_result.shape, (1, ))

    def test_L1Loss_none(self):
        input_np = np.random.random(size=(10, 5)).astype(np.float32)
        label_np = np.random.random(size=(10, 5)).astype(np.float32)
        static_result, dy_result = self._run_both_modes(
            input_np, label_np, 'none')
        self._check_close(static_result, dy_result,
                          np.abs(input_np - label_np))
        # BUG FIX: the original compared against the static-graph Variable
        # `input` (wrong object, and inside assertTrue's message slot);
        # 'none' must preserve the elementwise input shape.
        self.assertEqual(dy_result.shape, input_np.shape)
# TODO: define neural network layers (loss, conv, norm, ...) under this package

from . import loss  #DEFINE_ALIAS

# BUG FIX: __all__ must contain *names* (strings), not module objects.
# The original `__all__ = [loss]` made `from paddle.nn.layer import *`
# raise "TypeError: attribute name must be string".
__all__ = ['loss']
class L1Loss(fluid.dygraph.Layer):
    """
    This interface is used to construct a callable object of the ``L1Loss`` class.
    The L1Loss layer calculates the L1 Loss of input predictions and target
    labels as follows.

    If :attr:`reduction` set to ``'none'``, the unreduced loss is:

    .. math::

        Out = |input - label|

    If :attr:`reduction` set to ``'mean'``, the reduced mean loss is:

    .. math::

        Out = MEAN(|input - label|)

    If :attr:`reduction` set to ``'sum'``, the reduced sum loss is:

    .. math::

        Out = SUM(|input - label|)

    The shape of input predictions and target labels are [N, *], where N is
    batch_size and `*` means any number of additional dimensions.
    If :attr:`reduction` is ``'none'``, the shape of output loss is [N, *],
    the same as input.
    If :attr:`reduction` is ``'mean'`` or ``'sum'``, the shape of output loss
    is [1], which means the output is a scalar.

    Parameters:
        reduction (str, optional): Indicate the reduction to apply to the loss,
            the candidates are ``'none'`` | ``'mean'`` | ``'sum'``.
            If :attr:`reduction` is ``'none'``, the unreduced loss is returned;
            If :attr:`reduction` is ``'mean'``, the reduced mean loss is returned.
            If :attr:`reduction` is ``'sum'``, the reduced sum loss is returned.
            Default is ``'mean'``.

    Returns:
        A callable object of L1Loss.

    Examples:
        .. code-block:: python

            # declarative mode
            import paddle.fluid as fluid
            import numpy as np
            import paddle

            input = fluid.data(name="input", shape=[1])
            label = fluid.data(name="label", shape=[1])
            l1_loss = paddle.nn.loss.L1Loss(reduction='mean')
            output = l1_loss(input,label)
            place = fluid.CPUPlace()
            exe = fluid.Executor(place)
            exe.run(fluid.default_startup_program())

            input_data = np.array([1.5]).astype("float32")
            label_data = np.array([1.7]).astype("float32")
            output_data = exe.run(fluid.default_main_program(),
                feed={"input":input_data, "label":label_data},
                fetch_list=[output],
                return_numpy=True)

            print(output_data)  # [array([0.2], dtype=float32)]

            # imperative mode
            import paddle.fluid.dygraph as dg
            with dg.guard(place) as g:
                input = dg.to_variable(input_data)
                label = dg.to_variable(label_data)
                l1_loss = paddle.nn.loss.L1Loss(reduction='mean')
                output = l1_loss(input,label)
                print(output.numpy())  # [0.2]
    """

    def __init__(self, reduction='mean'):
        # Validate eagerly so a bad configuration fails at construction
        # time rather than at the first forward pass.
        if reduction not in ['sum', 'mean', 'none']:
            raise ValueError(
                "The value of 'reduction' in L1Loss should be 'sum', 'mean' or 'none', but "
                "received %s, which is not allowed." % reduction)
        super(L1Loss, self).__init__()
        self.reduction = reduction

    def forward(self, input, label):
        """Return the (optionally reduced) elementwise absolute error of
        `input` against `label`; both must be float32/float64/int32/int64
        Variables of the same shape."""
        fluid.data_feeder.check_variable_and_dtype(
            input, 'input', ['float32', 'float64', 'int32', 'int64'], 'l1_loss')
        fluid.data_feeder.check_variable_and_dtype(
            label, 'label', ['float32', 'float64', 'int32', 'int64'], 'l1_loss')

        # |input - label| in a single fused op via the 'abs' activation.
        unreduced = fluid.layers.elementwise_sub(input, label, act='abs')

        if self.reduction == 'sum':
            return fluid.layers.reduce_sum(unreduced)
        elif self.reduction == 'mean':
            return fluid.layers.reduce_mean(unreduced)
        else:
            # reduction == 'none': keep the elementwise loss tensor as-is.
            return unreduced