Unverified  Commit 1d870c44  authored by LutaoChu, committed by GitHub

fix paddle.nn.loss.L1Loss OP, add paddle.nn.functional.l1_loss OP for API2.0

Parent faf83a7a
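For reference, the reduction semantics that the new tests assert can be sketched in plain NumPy. `l1_loss_ref` below is a hypothetical helper written for this illustration, not code from the commit; the expressions inside it are exactly the ones the tests compare against.

import numpy as np

def l1_loss_ref(x, label, reduction='mean'):
    # Elementwise absolute error, then the requested reduction.
    diff = np.abs(x - label)
    if reduction == 'mean':
        return np.mean(diff)
    if reduction == 'sum':
        return np.sum(diff)
    if reduction == 'none':
        return diff
    raise ValueError("reduction must be 'mean', 'sum' or 'none'")

x = np.array([[1.5, 0.8], [0.2, 1.3]], dtype=np.float32)
label = np.array([[1.7, 1.0], [0.4, 0.5]], dtype=np.float32)
print(l1_loss_ref(x, label))                  # 0.35
print(l1_loss_ref(x, label, reduction='sum')) # 1.4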
@@ -20,111 +20,165 @@ import numpy as np
import unittest
class TestFunctionalL1Loss(unittest.TestCase):
    def setUp(self):
        self.input_np = np.random.random(size=(10, 10, 5)).astype(np.float32)
        self.label_np = np.random.random(size=(10, 10, 5)).astype(np.float32)

    def run_imperative(self):
        input = paddle.to_variable(self.input_np)
        label = paddle.to_variable(self.label_np)
        dy_result = paddle.nn.functional.l1_loss(input, label)
        expected = np.mean(np.abs(self.input_np - self.label_np))
        self.assertTrue(np.allclose(dy_result.numpy(), expected))
        self.assertEqual(dy_result.shape, [1])

        dy_result = paddle.nn.functional.l1_loss(input, label, reduction='sum')
        expected = np.sum(np.abs(self.input_np - self.label_np))
        self.assertTrue(np.allclose(dy_result.numpy(), expected))
        self.assertEqual(dy_result.shape, [1])

        dy_result = paddle.nn.functional.l1_loss(input, label, reduction='none')
        expected = np.abs(self.input_np - self.label_np)
        self.assertTrue(np.allclose(dy_result.numpy(), expected))
        self.assertEqual(dy_result.shape, [10, 10, 5])
def run_static(self, use_gpu=False):
input = paddle.data(name='input', shape=[10, 10, 5], dtype='float32')
label = paddle.data(name='label', shape=[10, 10, 5], dtype='float32')
result0 = paddle.nn.functional.l1_loss(input, label)
result1 = paddle.nn.functional.l1_loss(input, label, reduction='sum')
result2 = paddle.nn.functional.l1_loss(input, label, reduction='none')
y = paddle.nn.functional.l1_loss(input, label, name='aaa')
place = fluid.CUDAPlace(0) if use_gpu else fluid.CPUPlace()
exe = fluid.Executor(place)
exe.run(fluid.default_startup_program())
static_result = exe.run(
feed={"input": self.input_np,
"label": self.label_np},
fetch_list=[result0, result1, result2])
expected = np.mean(np.abs(self.input_np - self.label_np))
self.assertTrue(np.allclose(static_result[0], expected))
expected = np.sum(np.abs(self.input_np - self.label_np))
self.assertTrue(np.allclose(static_result[1], expected))
expected = np.abs(self.input_np - self.label_np)
self.assertTrue(np.allclose(static_result[2], expected))
self.assertTrue('aaa' in y.name)
def test_cpu(self):
paddle.disable_static(place=paddle.fluid.CPUPlace())
self.run_imperative()
paddle.enable_static()
with fluid.program_guard(fluid.Program()):
self.run_static()
def test_gpu(self):
if not fluid.core.is_compiled_with_cuda():
return
paddle.disable_static(place=paddle.fluid.CUDAPlace(0))
self.run_imperative()
paddle.enable_static()
with fluid.program_guard(fluid.Program()):
self.run_static(use_gpu=True)
    # test the error message raised for an invalid reduction
    def test_errors(self):
        def test_value_error():
            input = paddle.data(
                name='input', shape=[10, 10, 5], dtype='float32')
            label = paddle.data(
                name='label', shape=[10, 10, 5], dtype='float32')
            loss = paddle.nn.functional.l1_loss(
                input, label, reduction='reduce_mean')

        self.assertRaises(ValueError, test_value_error)
class TestClassL1Loss(unittest.TestCase):
def setUp(self):
self.input_np = np.random.random(size=(10, 10, 5)).astype(np.float32)
self.label_np = np.random.random(size=(10, 10, 5)).astype(np.float32)
def run_imperative(self):
input = paddle.to_variable(self.input_np)
label = paddle.to_variable(self.label_np)
l1_loss = paddle.nn.loss.L1Loss()
dy_result = l1_loss(input, label)
expected = np.mean(np.abs(self.input_np - self.label_np))
self.assertTrue(np.allclose(dy_result.numpy(), expected))
        self.assertEqual(dy_result.shape, [1])
l1_loss = paddle.nn.loss.L1Loss(reduction='sum')
dy_result = l1_loss(input, label)
expected = np.sum(np.abs(self.input_np - self.label_np))
self.assertTrue(np.allclose(dy_result.numpy(), expected))
        self.assertEqual(dy_result.shape, [1])

        l1_loss = paddle.nn.loss.L1Loss(reduction='none')
        dy_result = l1_loss(input, label)
        expected = np.abs(self.input_np - self.label_np)
        self.assertTrue(np.allclose(dy_result.numpy(), expected))
        self.assertEqual(dy_result.shape, [10, 10, 5])
def run_static(self, use_gpu=False):
input = paddle.data(name='input', shape=[10, 10, 5], dtype='float32')
label = paddle.data(name='label', shape=[10, 10, 5], dtype='float32')
l1_loss = paddle.nn.loss.L1Loss()
result0 = l1_loss(input, label)
l1_loss = paddle.nn.loss.L1Loss(reduction='sum')
result1 = l1_loss(input, label)
l1_loss = paddle.nn.loss.L1Loss(reduction='none')
result2 = l1_loss(input, label)
l1_loss = paddle.nn.loss.L1Loss(name='aaa')
result3 = l1_loss(input, label)
place = fluid.CUDAPlace(0) if use_gpu else fluid.CPUPlace()
exe = fluid.Executor(place)
exe.run(fluid.default_startup_program())
static_result = exe.run(
feed={"input": self.input_np,
"label": self.label_np},
fetch_list=[result0, result1, result2])
expected = np.mean(np.abs(self.input_np - self.label_np))
self.assertTrue(np.allclose(static_result[0], expected))
expected = np.sum(np.abs(self.input_np - self.label_np))
self.assertTrue(np.allclose(static_result[1], expected))
expected = np.abs(self.input_np - self.label_np)
self.assertTrue(np.allclose(static_result[2], expected))
self.assertTrue('aaa' in result3.name)
def test_cpu(self):
paddle.disable_static(place=paddle.fluid.CPUPlace())
self.run_imperative()
paddle.enable_static()
with fluid.program_guard(fluid.Program()):
self.run_static()
def test_gpu(self):
if not fluid.core.is_compiled_with_cuda():
return
paddle.disable_static(place=paddle.fluid.CUDAPlace(0))
self.run_imperative()
paddle.enable_static()
with fluid.program_guard(fluid.Program()):
self.run_static(use_gpu=True)
    # test the error message raised for an invalid reduction
def test_errors(self):
def test_value_error():
loss = paddle.nn.loss.L1Loss(reduction="reduce_mean")
self.assertRaises(ValueError, test_value_error)
if __name__ == "__main__":
    unittest.main()
@@ -127,6 +127,7 @@ from .loss import edit_distance #DEFINE_ALIAS
from .loss import huber_loss #DEFINE_ALIAS
from .loss import iou_similarity #DEFINE_ALIAS
from .loss import kldiv_loss #DEFINE_ALIAS
from .loss import l1_loss #DEFINE_ALIAS
from .loss import log_loss #DEFINE_ALIAS
from .loss import margin_rank_loss #DEFINE_ALIAS
from .loss import mse_loss #DEFINE_ALIAS
@@ -13,6 +13,10 @@
# limitations under the License.
# TODO: define loss functions of neural network
import paddle
import paddle.fluid as fluid
from ...fluid.framework import core, in_dygraph_mode
from ...fluid.layers.nn import _elementwise_op_in_dygraph
from ...fluid.layers import bpr_loss #DEFINE_ALIAS
from ...fluid.layers import center_loss #DEFINE_ALIAS
from ...fluid.layers import cross_entropy #DEFINE_ALIAS
@@ -45,6 +49,7 @@ __all__ = [
'huber_loss',
'iou_similarity',
'kldiv_loss',
'l1_loss',
'log_loss',
'margin_rank_loss',
'mse_loss',
@@ -60,3 +65,92 @@ __all__ = [
'ssd_loss',
'teacher_student_sigmoid_loss'
]
def l1_loss(x, label, reduction='mean', name=None):
"""
This operator computes the L1 Loss of Tensor ``x`` and ``label`` as follows.
    If :attr:`reduction` set to ``'none'``, the loss is:

    .. math::
        Out = \lvert x - label\rvert

    If :attr:`reduction` set to ``'mean'``, the loss is:

    .. math::
        Out = MEAN(\lvert x - label\rvert)

    If :attr:`reduction` set to ``'sum'``, the loss is:

    .. math::
        Out = SUM(\lvert x - label\rvert)
    Parameters:
        x (Tensor): The input tensor. Its shape is [N, *], where N is batch size and `*` means any number of additional dimensions. Its data type should be float32, float64, int32 or int64.
        label (Tensor): The label. Its shape is [N, *], the same as ``x``. Its data type should be float32, float64, int32 or int64.
        reduction (str, optional): Indicate the reduction to apply to the loss,
            the candidates are ``'none'`` | ``'mean'`` | ``'sum'``.
            If :attr:`reduction` is ``'none'``, the unreduced loss is returned;
            If :attr:`reduction` is ``'mean'``, the reduced mean loss is returned.
            If :attr:`reduction` is ``'sum'``, the reduced sum loss is returned.
            Default is ``'mean'``.
        name (str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`.

    Returns:
        Tensor, the L1 Loss of Tensor ``x`` and ``label``.
        If :attr:`reduction` is ``'none'``, the shape of the output loss is [N, *], the same as ``x``.
        If :attr:`reduction` is ``'mean'`` or ``'sum'``, the shape of the output loss is [1], i.e. the output is a scalar.
Examples:
.. code-block:: python
import paddle
import numpy as np
paddle.disable_static()
x_data = np.array([[1.5, 0.8], [0.2, 1.3]]).astype("float32")
label_data = np.array([[1.7, 1], [0.4, 0.5]]).astype("float32")
x = paddle.to_variable(x_data)
label = paddle.to_variable(label_data)
l1_loss = paddle.nn.functional.l1_loss(x, label)
print(l1_loss.numpy())
# [0.35]
l1_loss = paddle.nn.functional.l1_loss(x, label, reduction='none')
print(l1_loss.numpy())
# [[0.20000005 0.19999999]
# [0.2 0.79999995]]
l1_loss = paddle.nn.functional.l1_loss(x, label, reduction='sum')
print(l1_loss.numpy())
# [1.4]
"""
if reduction not in ['sum', 'mean', 'none']:
raise ValueError(
"The value of 'reduction' in L1Loss should be 'sum', 'mean' or 'none', but "
"received %s, which is not allowed." % reduction)
if in_dygraph_mode():
unreduced = _elementwise_op_in_dygraph(
x, label, axis=-1, act='abs', op_name='elementwise_sub')
if reduction == 'mean':
return core.ops.mean(unreduced)
elif reduction == 'sum':
return core.ops.reduce_sum(unreduced, 'dim', [0], 'keep_dim', False,
'reduce_all', True)
else:
return unreduced
fluid.data_feeder.check_variable_and_dtype(
x, 'x', ['float32', 'float64', 'int32', 'int64'], 'l1_loss')
fluid.data_feeder.check_variable_and_dtype(
label, 'label', ['float32', 'float64', 'int32', 'int64'], 'l1_loss')
if reduction == 'sum':
unreduced = paddle.elementwise_sub(x, label, act='abs')
return paddle.sum(unreduced, name=name)
elif reduction == 'mean':
unreduced = paddle.elementwise_sub(x, label, act='abs')
return paddle.mean(unreduced, name=name)
else:
return paddle.elementwise_sub(x, label, act='abs', name=name)
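A quick sanity sketch, not part of the commit, written against the paddle 2.0-era imperative API used above: the ``'mean'`` and ``'sum'`` reductions agree with reducing the ``'none'`` output by hand, which is exactly the equivalence the unit tests check.

import numpy as np
import paddle

paddle.disable_static()
x = paddle.to_variable(np.random.random((4, 3)).astype('float32'))
label = paddle.to_variable(np.random.random((4, 3)).astype('float32'))

per_elem = paddle.nn.functional.l1_loss(x, label, reduction='none')
mean_loss = paddle.nn.functional.l1_loss(x, label)  # default reduction='mean'
sum_loss = paddle.nn.functional.l1_loss(x, label, reduction='sum')

# Reducing the per-element loss manually matches the built-in reductions.
assert np.allclose(mean_loss.numpy(), per_elem.numpy().mean())
assert np.allclose(sum_loss.numpy(), per_elem.numpy().sum())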
@@ -250,27 +250,24 @@ class MSELoss(fluid.dygraph.layers.Layer):
class L1Loss(fluid.dygraph.Layer):
    """
    :alias_main: paddle.nn.L1Loss
    :alias: paddle.nn.L1Loss,paddle.nn.layer.L1Loss,paddle.nn.layer.loss.L1Loss

    This interface is used to construct a callable object of the ``L1Loss`` class.
    The L1Loss layer calculates the L1 Loss of ``x`` and ``label`` as follows.

    If :attr:`reduction` set to ``'none'``, the loss is:

    .. math::
        Out = \lvert x - label\rvert

    If :attr:`reduction` set to ``'mean'``, the loss is:

    .. math::
        Out = MEAN(\lvert x - label\rvert)

    If :attr:`reduction` set to ``'sum'``, the loss is:

    .. math::
        Out = SUM(\lvert x - label\rvert)
Parameters:
reduction (str, optional): Indicate the reduction to apply to the loss,
@@ -279,63 +276,55 @@ class L1Loss(fluid.dygraph.Layer):
If :attr:`reduction` is ``'mean'``, the reduced mean loss is returned.
If :attr:`reduction` is ``'sum'``, the reduced sum loss is returned.
Default is ``'mean'``.
        name (str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`.

    Shape:
        x (Tensor): The input tensor. Its shape is [N, *], where N is batch size and `*` means any number of additional dimensions. Its data type should be float32, float64, int32 or int64.
        label (Tensor): The label. Its shape is [N, *], the same as ``x``. Its data type should be float32, float64, int32 or int64.
        output (Tensor): The L1 Loss of ``x`` and ``label``.
            If :attr:`reduction` is ``'none'``, the shape of the output loss is [N, *], the same as ``x``.
            If :attr:`reduction` is ``'mean'`` or ``'sum'``, the shape of the output loss is [1], i.e. the output is a scalar.
Examples:
.. code-block:: python
            import paddle
            import numpy as np

paddle.disable_static()
x_data = np.array([[1.5, 0.8], [0.2, 1.3]]).astype("float32")
label_data = np.array([[1.7, 1], [0.4, 0.5]]).astype("float32")
x = paddle.to_variable(x_data)
label = paddle.to_variable(label_data)
l1_loss = paddle.nn.loss.L1Loss()
output = l1_loss(x, label)
print(output.numpy())
# [0.35]
l1_loss = paddle.nn.loss.L1Loss(reduction='sum')
output = l1_loss(x, label)
print(output.numpy())
# [1.4]
l1_loss = paddle.nn.loss.L1Loss(reduction='none')
output = l1_loss(x, label)
print(output.numpy())
# [[0.20000005 0.19999999]
# [0.2 0.79999995]]
"""
    def __init__(self, reduction='mean', name=None):
if reduction not in ['sum', 'mean', 'none']:
raise ValueError(
"The value of 'reduction' in L1Loss should be 'sum', 'mean' or 'none', but "
"received %s, which is not allowed." % reduction)
super(L1Loss, self).__init__()
self.reduction = reduction
self.name = name
    def forward(self, x, label):
        return paddle.nn.functional.l1_loss(
            x, label, self.reduction, name=self.name)
class BCELoss(fluid.dygraph.Layer):
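Design note: after this change the `L1Loss` layer is a thin wrapper that delegates to `paddle.nn.functional.l1_loss`, so the layer and functional entry points cannot drift apart. A minimal sketch of that equivalence, assuming the paddle 2.0-era imperative API used elsewhere in this commit:

import numpy as np
import paddle

paddle.disable_static()
x = paddle.to_variable(np.ones((2, 3), dtype='float32'))
label = paddle.to_variable(np.zeros((2, 3), dtype='float32'))

# The layer's forward simply calls the functional API with the stored
# reduction and name, so both paths produce the same value.
layer_out = paddle.nn.loss.L1Loss(reduction='sum')(x, label)
func_out = paddle.nn.functional.l1_loss(x, label, reduction='sum')
assert np.allclose(layer_out.numpy(), func_out.numpy())  # both 6.0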