提交 8a0dd240 编写于 作者: Q qingqing01 提交者: Yi Wang

Expose softmax_with_cross_entropy and smooth_l1 into Python API. (#8375)

* Add softmax_with_cross_entropy and smooth_l1 in Python API.

* Fix doc format.
上级 51912a7a
......@@ -44,7 +44,6 @@ class SmoothL1LossOp : public framework::OperatorWithKernel {
}
};
template <typename AttrType>
class SmoothL1LossOpMaker : public framework::OpProtoAndCheckerMaker {
public:
SmoothL1LossOpMaker(OpProto* proto, OpAttrChecker* op_checker)
......@@ -73,10 +72,10 @@ class SmoothL1LossOpMaker : public framework::OpProtoAndCheckerMaker {
AddOutput("Out",
"(Tensor, default Tensor<float>) A tensor with rank be 2. "
"The output smooth l1 loss with shape [batch_size, 1].");
AddAttr<AttrType>("sigma",
"Hyper parameter of smooth l1 loss op."
"A float scalar with default value 3.0.")
.SetDefault(3.0);
AddAttr<float>("sigma",
"Hyper parameter of smooth l1 loss op."
"A float scalar with default value 3.0.")
.SetDefault(1.0);
AddComment(R"DOC(
Smooth L1 Loss Operator.
......@@ -133,9 +132,8 @@ class SmoothL1LossGradOp : public framework::OperatorWithKernel {
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP(smooth_l1_loss, ops::SmoothL1LossOp,
ops::SmoothL1LossOpMaker<float>, smooth_l1_loss_grad,
ops::SmoothL1LossGradOp);
REGISTER_OP(smooth_l1_loss, ops::SmoothL1LossOp, ops::SmoothL1LossOpMaker,
smooth_l1_loss_grad, ops::SmoothL1LossGradOp);
REGISTER_OP_CPU_KERNEL(
smooth_l1_loss,
ops::SmoothL1LossKernel<paddle::platform::CPUDeviceContext, float>);
......
......@@ -66,6 +66,8 @@ __all__ = [
'row_conv',
'multiplex',
'layer_norm',
'softmax_with_cross_entropy',
'smooth_l1',
]
......@@ -3091,3 +3093,122 @@ def multiplex(inputs, index):
'Ids': index},
outputs={'Out': [out]})
return out
def softmax_with_cross_entropy(logits, label, soft_label=False):
"""
**Softmax With Cross Entropy Operator.**
Cross entropy loss with softmax is used as the output layer extensively. This
operator computes the softmax normalized values for each row of the input
tensor, after which cross-entropy loss is computed. This provides a more
numerically stable gradient.
Because this operator performs a softmax on logits internally, it expects
unscaled logits. This operator should not be used with the output of
softmax operator since that would produce incorrect results.
When the attribute soft_label is set false, this operators expects mutually
exclusive hard labels, each sample in a batch is in exactly one class with a
probability of 1.0. Each sample in the batch will have a single label.
The equation is as follows:
1) Hard label (one-hot label, so every sample has exactly one class)
.. math::
loss_j = -\\text{logit}_{label_j} +
\\log\\left(\\sum_{i=0}^{K}\\exp(\\text{logit}_i)\\right), j = 1,..., K
2) Soft label (each sample can have a distribution over all classes)
.. math::
loss_j = -\\sum_{i=0}^{K}\\text{label}_i
\\left(\\text{logit}_i - \\log\\left(\\sum_{i=0}^{K}
\\exp(\\text{logit}_i)\\right)\\right), j = 1,...,K
Args:
logits (Variable): The unscaled log probabilities, which is a 2-D tensor
with shape [N x K]. N is the batch_size, and K is the class number.
label (Variable): The ground truth which is a 2-D tensor. If soft_label
is set to false, Label is a Tensor<int64> with shape [N x 1]. If
soft_label is set to true, Label is a Tensor<float/double> with
soft_label (bool): A flag to indicate whether to interpretate the given
labels as soft labels. By default, `soft_label` is set to False.
Returns:
Variable: The cross entropy loss is a 2-D tensor with shape [N x 1].
Examples:
.. code-block:: python
data = fluid.layers.data(name='data', shape=[128], dtype='float32')
label = fluid.layers.data(name='label', shape=[1], dtype='int64')
fc = fluid.layers.fc(input=data, size=100)
out = fluid.layers.softmax_with_cross_entropy(logits=fc, label=label)
"""
helper = LayerHelper('softmax_with_cross_entropy', **locals())
softmax = helper.create_tmp_variable(dtype=logits.dtype)
loss = helper.create_tmp_variable(dtype=logits.dtype)
helper.append_op(
type='softmax_with_cross_entropy',
inputs={'Logits': logits,
'Label': label},
outputs={'Softmax': softmax,
'Loss': loss},
attrs={'soft_label': soft_label})
return loss
def smooth_l1(x, y, inside_weight=None, outside_weight=None, sigma=None):
"""
**Smooth L1 Loss Operator. **
This operator computes the smooth l1 loss for X and Y.
The operator takes the first dimension of X and Y as batch size.
For each instance, it computes the smooth l1 loss element by element first
and then sums all the losses. So the shape of Out is [batch_size, 1].
Args:
x (Variable): A tensor with rank at least 2. The input value of smooth
l1 loss op with shape [batch_size, dim1, ..., dimN].
y (Variable): A tensor with rank at least 2. The target value of smooth
l1 loss op with same shape as x.
inside_weight (Variable|None): A tensor with rank at least 2. This
input is optional and should have same shape with x. If provided,
the result of (x - y) will be multiplied by this tensor element by
element.
outside_weight (Variable|None): A tensor with rank at least 2. This
input is optional and should have same shape with x. If provided,
the out smooth l1 loss will be multiplied by this tensor element
by element.
sigma (float|None): Hyper parameter of smooth l1 loss op. A float scalar
with default value 1.0.
Returns:
Variable: A tensor with rank be 2. The output smooth l1 loss with
shape [batch_size, 1].
Examples:
.. code-block:: python
data = fluid.layers.data(name='data', shape=[128], dtype='float32')
label = fluid.layers.data(name='label', shape=[100], dtype='int64')
fc = fluid.layers.fc(input=data, size=100)
out = fluid.layers.smooth_l1(logits=fc, label=label)
"""
helper = LayerHelper('smooth_l1_loss', **locals())
diff = helper.create_tmp_variable(dtype=x.dtype)
loss = helper.create_tmp_variable(dtype=x.dtype)
helper.append_op(
type='smooth_l1_loss',
inputs={
'X': x,
'Y': y,
'InsideWeight': inside_weight,
'OutsideWeight': outside_weight
},
outputs={'Diff': diff,
'Out': loss},
attrs={'sigma': sigma})
return loss
......@@ -309,6 +309,24 @@ class TestBook(unittest.TestCase):
self.assertIsNotNone(out)
print(str(program))
def test_softmax_with_cross_entropy(self):
program = Program()
with program_guard(program):
x = layers.data(name='x', shape=[16], dtype='float32')
y = layers.data(name='label', shape=[1], dtype='int64')
loss = layers.softmax_with_cross_entropy(x, y)
self.assertIsNotNone(loss)
print(str(program))
def test_smooth_l1(self):
program = Program()
with program_guard(program):
x = layers.data(name='x', shape=[4], dtype='float32')
y = layers.data(name='label', shape=[4], dtype='float32')
loss = layers.smooth_l1(x, y)
self.assertIsNotNone(loss)
print(str(program))
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册