未验证 提交 8f659d43 编写于 作者: T Tao Luo 提交者: GitHub

Split some APIs from nn.py to loss.py (#21117)

* Split some APIs from nn.py to loss.py

test=develop

* fix test_detection unit-test

test=develop
上级 4a544762
...@@ -28,6 +28,8 @@ from . import device ...@@ -28,6 +28,8 @@ from . import device
from .device import * from .device import *
from . import math_op_patch from . import math_op_patch
from .math_op_patch import * from .math_op_patch import *
from . import loss
from .loss import *
from . import detection from . import detection
from .detection import * from .detection import *
from . import metric_op from . import metric_op
...@@ -50,6 +52,7 @@ __all__ += metric_op.__all__ ...@@ -50,6 +52,7 @@ __all__ += metric_op.__all__
__all__ += learning_rate_scheduler.__all__ __all__ += learning_rate_scheduler.__all__
__all__ += distributions.__all__ __all__ += distributions.__all__
__all__ += sequence_lod.__all__ __all__ += sequence_lod.__all__
__all__ += loss.__all__
__all__ += rnn.__all__ __all__ += rnn.__all__
from .rnn import * from .rnn import *
...@@ -21,6 +21,7 @@ from .layer_function_generator import generate_layer_fn ...@@ -21,6 +21,7 @@ from .layer_function_generator import generate_layer_fn
from .layer_function_generator import autodoc, templatedoc from .layer_function_generator import autodoc, templatedoc
from ..layer_helper import LayerHelper from ..layer_helper import LayerHelper
from ..framework import Variable from ..framework import Variable
from .loss import softmax_with_cross_entropy
from . import tensor from . import tensor
from . import nn from . import nn
from . import ops from . import ops
...@@ -1540,7 +1541,7 @@ def ssd_loss(location, ...@@ -1540,7 +1541,7 @@ def ssd_loss(location,
target_label = tensor.cast(x=target_label, dtype='int64') target_label = tensor.cast(x=target_label, dtype='int64')
target_label = __reshape_to_2d(target_label) target_label = __reshape_to_2d(target_label)
target_label.stop_gradient = True target_label.stop_gradient = True
conf_loss = nn.softmax_with_cross_entropy(confidence, target_label) conf_loss = softmax_with_cross_entropy(confidence, target_label)
# 3. Mining hard examples # 3. Mining hard examples
actual_shape = nn.slice(conf_shape, axes=[0], starts=[0], ends=[2]) actual_shape = nn.slice(conf_shape, axes=[0], starts=[0], ends=[2])
actual_shape.stop_gradient = True actual_shape.stop_gradient = True
...@@ -1594,7 +1595,7 @@ def ssd_loss(location, ...@@ -1594,7 +1595,7 @@ def ssd_loss(location,
target_label = __reshape_to_2d(target_label) target_label = __reshape_to_2d(target_label)
target_label = tensor.cast(x=target_label, dtype='int64') target_label = tensor.cast(x=target_label, dtype='int64')
conf_loss = nn.softmax_with_cross_entropy(confidence, target_label) conf_loss = softmax_with_cross_entropy(confidence, target_label)
target_conf_weight = __reshape_to_2d(target_conf_weight) target_conf_weight = __reshape_to_2d(target_conf_weight)
conf_loss = conf_loss * target_conf_weight conf_loss = conf_loss * target_conf_weight
......
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import numpy as np
from functools import partial, reduce
from . import nn
from .layer_function_generator import templatedoc
from ..layer_helper import LayerHelper
from ..framework import Variable
from ..data_feeder import convert_dtype
from ..param_attr import ParamAttr
from ..initializer import NumpyArrayInitializer
__all__ = [
'center_loss',
'bpr_loss',
'cross_entropy',
'square_error_cost',
'edit_distance',
'warpctc',
'nce',
'hsigmoid',
'sampled_softmax_with_cross_entropy',
'softmax_with_cross_entropy',
'rank_loss',
'margin_rank_loss',
'sigmoid_cross_entropy_with_logits',
'teacher_student_sigmoid_loss',
'huber_loss',
'kldiv_loss',
'npair_loss',
'mse_loss',
]
kIgnoreIndex = -100
def center_loss(input,
label,
num_classes,
alpha,
param_attr,
update_center=True):
"""
**Center loss Cost layer**
This OP accepts input (deep features,the output of the last hidden layer)
and target label and return the center loss cost. The average of the
distances of each sample in the mini-batch from the center of the
corresponding category is calculated as the center loss.
For deep features, :math:`X`, and target labels, :math:`Y`, the equation is:
.. math::
Out = \\frac{1}{2}(X - Y)^2
Args:
input (Variable): a 2-D tensor with shape[N x M]. Its dtype should be float32 or float64.
label (Variable): the groud truth which is a 2-D tensor
with shape[N x 1],where N is the batch size. Its dtype should be int32.
num_classes (int): the number of classification categories.
alpha (float|Variable): learning rate of centers.
param_attr (ParamAttr): Attribute initializer of centers.
update_center (bool): whether to update value of center.
Returns:
Variable: 2-D tensor with shape [N * 1]
Examples:
.. code-block:: python
import paddle.fluid as fluid
input = fluid.data(name='x',shape=[20,30],dtype='float32')
label = fluid.data(name='y',shape=[20,1],dtype='int64')
num_classes = 1000
alpha = 0.01
param_attr = fluid.initializer.Xavier(uniform=False)
center_loss=fluid.layers.center_loss(input=input,
label=label,
num_classes=1000,
alpha=alpha,
param_attr=fluid.initializer.Xavier(uniform=False),
update_center=True)
"""
helper = LayerHelper('center_loss', **locals())
dtype = helper.input_dtype()
centers_shape = [num_classes, input.shape[1]]
centers_param = helper.create_parameter(
attr=param_attr, shape=centers_shape, dtype=dtype)
centers_param.stop_gradient = True
if isinstance(alpha, Variable):
alpha_param = alpha
else:
assert isinstance(alpha, float)
alpha_param = helper.create_variable(
name="centerloss_alpha",
shape=[1],
dtype="float32",
type=core.VarDesc.VarType.LOD_TENSOR,
persistable=True,
stop_gradient=True,
initializer=Constant(alpha))
centersdiff = helper.create_variable_for_type_inference(dtype=input.dtype)
loss = helper.create_variable_for_type_inference(dtype=input.dtype)
helper.append_op(
type='center_loss',
inputs={
'X': [input],
'Label': [label],
'Centers': [centers_param],
'CenterUpdateRate': [alpha_param]
},
outputs={
'SampleCenterDiff': [centersdiff],
'Loss': [loss],
'CentersOut': [centers_param]
},
attrs={'cluster_num': num_classes,
'need_update': update_center})
return loss
def bpr_loss(input, label, name=None):
"""
**Bayesian Personalized Ranking Loss Operator**
This operator belongs to pairwise ranking loss. Label is the desired item.
The loss at a given point in one session is defined as:
.. math::
Y[i] = 1/(N[i] - 1) * \sum_j{\log(\sigma(X[i, Label[i]]-X[i, j]))}
Learn more details by reading paper <session-based recommendations with recurrent
neural networks>.
Args:
input (Variable|list): a 2-D tensor with shape [N x D], where N is the
batch size and D is the number of positive classes and negative classes
This input is not probability but logits.
label (Variable|list): the ground truth which is a 2-D tensor. `label`
is a tensor<int64> with shape [N x 1].
name (str|None): A name for this layer(optional). If set None, the
layer will be named automatically. Default: None.
Returns:
A 2-D tensor with shape [N x 1], the bpr loss.
Examples:
.. code-block:: python
import paddle.fluid as fluid
neg_size = 10
label = fluid.data(
name="label", shape=[3, 1], dtype="int64")
predict = fluid.data(
name="predict", shape=[3, neg_size + 1], dtype="float32")
cost = fluid.layers.bpr_loss(input=predict, label=label)
"""
helper = LayerHelper('bpr_loss', **locals())
out = helper.create_variable_for_type_inference(dtype=input.dtype)
helper.append_op(
type='bpr_loss',
inputs={'X': [input],
'Label': [label]},
outputs={'Y': [out]})
return out
def cross_entropy(input, label, soft_label=False, ignore_index=kIgnoreIndex):
"""
This operator computes the cross entropy between input and label. It
supports both hard-label and and soft-label cross entropy computation.
1. Hard-label cross entropy: if soft_label=False, :math:`label[i_1, i_2, ..., i_k]`
is the hard label of each sample.
.. math::
output[i_1, i_2, ..., i_k]=-log(input[i_1, i_2, ..., i_k, j]), label[i_1, i_2, ..., i_k] = j, j != ignore\_index
2. Soft-label cross entropy: if soft_label=True, :math:`label[i_1, i_2, ..., i_k, j]`
is the soft label of each sample corresponding to the j-th class.
.. math::
output[i_1, i_2, ..., i_k]= -\sum_{j}label[i_1,i_2,...,i_k,j]*log(input[i_1, i_2, ..., i_k,j])
Args:
input (Variable): a multidimensional Tensor with shape
:math:`[N_1, N_2, ..., N_k, D]`, where the last dimension D is
the class number. The data type should be float32 or float64.
label (Variable): label value corresponding to input. If
soft_label=False, the dimension of label should be :math:`[N_1, N_2, ..., N_k]`
or :math:`[N_1, N_2, ..., N_k, 1]` , and its data type should be int64,
and the value must be inside [0, D). If soft_label=True, the shape,
data type of label should be the same with input, and the sum of
soft label value of each sample should be 1.
soft_label (bool): indicate whether label is soft. Default False, meaning that
the label is hard. If soft_label=True, the label is soft.
ignore_index (int): specify an ignorable label value. The ignored label would be
omitted when computing. If it is a negative integer, no label would
be ignored. Only valid when soft_label=False. Default -100.
Returns:
A Variable holding Tensor representing the cross entropy, whose data type is the same with input.
If soft_label=False, the shape of output is the same with label.
If soft_label=True, the shape of output is :math:`[N_1, N_2, ..., N_k, 1]` .
Examples:
.. code-block:: python
import paddle.fluid as fluid
class_num = 7
x = fluid.data(name='x', shape=[None, 3, 10], dtype='float32')
label = fluid.data(name='label', shape=[None, 1], dtype='int64')
predict = fluid.layers.fc(input=x, size=class_num, act='softmax')
cost = fluid.layers.cross_entropy(input=predict, label=label)
"""
if not isinstance(input, Variable):
raise TypeError(
"The type of 'input' in cross_entropy must be Variable, but received %s"
% (type(input)))
if convert_dtype(input.dtype) in ['float16']:
warnings.warn(
"The data type of 'input' in cross_entropy only support float16 on GPU now."
)
if convert_dtype(input.dtype) not in ['float16', 'float32', 'float64']:
raise TypeError(
"The data type of 'input' in cross_entropy must be float16 or float32 or float64, but received %s."
% (convert_dtype(input.dtype)))
if not soft_label:
return cross_entropy2(input, label, ignore_index)
helper = LayerHelper('cross_entropy', **locals())
out = helper.create_variable_for_type_inference(dtype=input.dtype)
helper.append_op(
type='cross_entropy',
inputs={'X': [input],
'Label': [label]},
outputs={'Y': [out]},
attrs={"soft_label": soft_label,
"ignore_index": ignore_index})
return out
def cross_entropy2(input, label, ignore_index=kIgnoreIndex):
helper = LayerHelper('cross_entropy2', **locals())
out = helper.create_variable_for_type_inference(dtype=input.dtype)
xshape = helper.create_variable_for_type_inference(dtype=input.dtype)
match_x = helper.create_variable_for_type_inference(dtype=input.dtype)
helper.append_op(
type='cross_entropy2',
inputs={'X': [input],
'Label': [label]},
outputs={'Y': [out],
'MatchX': [match_x],
'XShape': [xshape]},
attrs={'ignore_index': ignore_index})
return out
def square_error_cost(input, label):
"""
This op accepts input predictions and target label and returns the
squared error cost.
For predictions label, and target label, the equation is:
.. math::
Out = (input - label)^2
Parameters:
input (Variable): Input tensor, the data type should be float32.
label (Variable): Label tensor, the data type should be float32.
Returns:
The tensor variable storing the element-wise squared error \
difference between input and label.
Return type: Variable.
Examples:
.. code-block:: python
# declarative mode
import paddle.fluid as fluid
import numpy as np
input = fluid.data(name="input", shape=[1])
label = fluid.data(name="label", shape=[1])
output = fluid.layers.square_error_cost(input,label)
place = fluid.CPUPlace()
exe = fluid.Executor(place)
exe.run(fluid.default_startup_program())
input_data = np.array([1.5]).astype("float32")
label_data = np.array([1.7]).astype("float32")
output_data = exe.run(fluid.default_main_program(),
feed={"input":input_data, "label":label_data},
fetch_list=[output],
return_numpy=True)
print(output_data)
# [array([0.04000002], dtype=float32)]
# imperative mode
import paddle.fluid.dygraph as dg
with dg.guard(place) as g:
input = dg.to_variable(input_data)
label = dg.to_variable(label_data)
output = fluid.layers.square_error_cost(input, label)
print(output.numpy())
# [0.04000002]
"""
helper = LayerHelper('square_error_cost', **locals())
minus_out = helper.create_variable_for_type_inference(dtype=input.dtype)
helper.append_op(
type='elementwise_sub',
inputs={'X': [input],
'Y': [label]},
outputs={'Out': [minus_out]})
square_out = helper.create_variable_for_type_inference(dtype=input.dtype)
helper.append_op(
type='square', inputs={'X': [minus_out]},
outputs={'Out': [square_out]})
return square_out
def edit_distance(input,
label,
normalized=True,
ignored_tokens=None,
input_length=None,
label_length=None):
"""
This op computes the edit distances between a batch of
hypothesis strings and their references. Edit distance, also called
Levenshtein distance, measures how dissimilar two strings are by counting
the minimum number of operations to transform one string into anthor.
Here the operations include insertion, deletion, and substitution.
For example, given hypothesis string A = "kitten" and reference
B = "sitting", the edit distance is 3 for A will be transformed into B
at least after two substitutions and one insertion:
"kitten" -> "sitten" -> "sittin" -> "sitting"
The input is a LoDTensor/Tensor consisting of all the hypothesis strings with
the total number denoted by `batch_size`, and the separation is specified
by the LoD information or input_length. And the `batch_size` reference strings are arranged
in order in the same way as `input`.
The output contains the `batch_size` results and each stands for the edit
distance for a pair of strings respectively. If Attr(normalized) is true,
the edit distance will be divided by the length of reference string.
Parameters:
input(Variable): The indices for hypothesis strings, its rank should equals to 2 and its data type should be int64.
label(Variable): The indices for reference strings, its rank should equals to 2 and its data type should be int64.
normalized(bool, default True): Indicated whether to normalize the edit distance by
the length of reference string.
ignored_tokens(list<int>, default None): Tokens that should be removed before
calculating edit distance.
input_length(Variable): The length for each sequence in `input` if it's of Tensor type, it should have shape `[batch_size]` and dtype int64.
label_length(Variable): The length for each sequence in `label` if it's of Tensor type, it should have shape `[batch_size]` and dtype int64.
Returns:
Tuple:
edit_distance_out(Variable): edit distance result in shape [batch_size, 1].
sequence_num(Variable): sequence number in shape [].
Examples:
.. code-block:: python
import paddle.fluid as fluid
# using LoDTensor
x_lod = fluid.data(name='x_lod', shape=[None,1], dtype='int64', lod_level=1)
y_lod = fluid.data(name='y_lod', shape=[None,1], dtype='int64', lod_level=1)
distance_lod, seq_num_lod = fluid.layers.edit_distance(input=x_lod, label=y_lod)
# using Tensor
x_seq_len = 5
y_seq_len = 6
x_pad = fluid.data(name='x_pad', shape=[None,x_seq_len], dtype='int64')
y_pad = fluid.data(name='y_pad', shape=[None,y_seq_len], dtype='int64')
x_len = fluid.data(name='x_len', shape=[None], dtype='int64')
y_len = fluid.data(name='y_len', shape=[None], dtype='int64')
distance_pad, seq_num_pad = fluid.layers.edit_distance(input=x_pad, label=y_pad, input_length=x_len, label_length=y_len)
"""
helper = LayerHelper("edit_distance", **locals())
# remove some tokens from input and labels
if ignored_tokens is not None and len(ignored_tokens) > 0:
erased_input = helper.create_variable_for_type_inference(dtype="int64")
erased_label = helper.create_variable_for_type_inference(dtype="int64")
helper.append_op(
type="sequence_erase",
inputs={"X": [input]},
outputs={"Out": [erased_input]},
attrs={"tokens": ignored_tokens})
input = erased_input
helper.append_op(
type="sequence_erase",
inputs={"X": [label]},
outputs={"Out": [erased_label]},
attrs={"tokens": ignored_tokens})
label = erased_label
this_inputs = {"Hyps": [input], "Refs": [label]}
if input_length and label_length:
this_inputs['HypsLength'] = [input_length]
this_inputs['RefsLength'] = [label_length]
# edit distance op
edit_distance_out = helper.create_variable_for_type_inference(dtype="int64")
sequence_num = helper.create_variable_for_type_inference(dtype="int64")
helper.append_op(
type="edit_distance",
inputs=this_inputs,
outputs={"Out": [edit_distance_out],
"SequenceNum": [sequence_num]},
attrs={"normalized": normalized})
return edit_distance_out, sequence_num
def warpctc(input,
label,
blank=0,
norm_by_times=False,
input_length=None,
label_length=None):
"""
An operator integrating the open source Warp-CTC library
(https://github.com/baidu-research/warp-ctc)
to compute Connectionist Temporal Classification (CTC) loss.
It can be aliased as softmax with CTC, since a native softmax activation is
interated to the Warp-CTC library to normlize values for each row of the
input tensor.
Args:
input (Variable): The unscaled probabilities of variable-length sequences,
which is a 2-D Tensor with LoD information, or a 3-D Tensor without Lod
information. When it is a 2-D LodTensor, it's shape is
[Lp, num_classes + 1], where Lp is the sum of all input
sequences' length and num_classes is the true number of classes.
(not including the blank label). When it is a 3-D Tensor, it's shape
is [max_logit_length, batch_size, num_classes + 1],
where max_logit_length is the length of the longest
input logit sequence. The data type must be float32.
label (Variable): The ground truth of variable-length sequence,
which is a 2-D Tensor with LoD information or a 2-D Tensor without
LoD information. When it is a 2-D LoDTensor or 2-D Tensor,
it is of the shape [Lg, 1], where Lg is th sum of all labels' length.
The data type must be int32.
blank (int, default 0): The blank label index of Connectionist
Temporal Classification (CTC) loss, which is in the
half-opened interval [0, num_classes + 1). The data type must be int32.
norm_by_times(bool, default false): Whether to normalize the gradients
by the number of time-step, which is also the sequence's length.
There is no need to normalize the gradients if warpctc layer was
follewed by a mean_op.
input_length(Variable): The length for each input sequence if it is
of Tensor type, it should have shape `[batch_size]` and dtype int64.
label_length(Variable): The length for each label sequence if it is
of Tensor type, it should have shape `[batch_size]` and dtype int64.
Returns:
Variable: The Connectionist Temporal Classification (CTC) loss,
which is a 2-D Tensor with the shape [batch_size, 1].
The date type is the same as input.
Examples:
.. code-block:: python
# using LoDTensor
import paddle.fluid as fluid
import numpy as np
predict = fluid.data(name='predict',
shape=[None, 5],
dtype='float32',lod_level=1)
label = fluid.data(name='label', shape=[None, 1],
dtype='int32', lod_level=1)
cost = fluid.layers.warpctc(input=predict, label=label)
place = fluid.CPUPlace()
x=fluid.LoDTensor()
data = np.random.rand(8, 5).astype("float32")
x.set(data, place)
x.set_lod([[0,4,8]])
y=fluid.LoDTensor()
data = np.random.randint(0, 5, [4, 1]).astype("int32")
y.set(data, place)
y.set_lod([[0,2,4]])
exe = fluid.Executor(place)
exe.run(fluid.default_startup_program())
output= exe.run(feed={"predict": x,"label": y},
fetch_list=[cost.name])
print output
.. code-block:: python
# using Tensor
import paddle.fluid as fluid
import numpy as np
# length of the longest logit sequence
max_seq_length = 5
# number of logit sequences
batch_size = None
logits = fluid.data(name='logits',
shape=[max_seq_length, batch_size, 5],
dtype='float32')
logits_length = fluid.data(name='logits_length', shape=[None],
dtype='int64')
label = fluid.layers.data(name='label', shape=[None, 1],
dtype='int32')
label_length = fluid.layers.data(name='labels_length', shape=[None],
dtype='int64')
cost = fluid.layers.warpctc(input=logits, label=label,
input_length=logits_length,
label_length=label_length)
place = fluid.CPUPlace()
batch_size = 2
x = np.random.rand(max_seq_length, batch_size, 5).astype("float32")
y = np.random.randint(0, 5, [max_seq_length * batch_size, 1]).astype("int32")
exe = fluid.Executor(place)
exe.run(fluid.default_startup_program())
output= exe.run(feed={"logits": x,
"label": y,
"logits_length": np.array([5, 4]).astype("int64"),
"labels_length": np.array([3, 2]).astype("int64")},
fetch_list=[cost.name])
print(output)
"""
helper = LayerHelper('warpctc', **locals())
this_inputs = {'Logits': [input], 'Label': [label]}
if input_length and label_length:
this_inputs['LogitsLength'] = [input_length]
this_inputs['LabelLength'] = [label_length]
loss_out = helper.create_variable_for_type_inference(dtype=input.dtype)
grad_out = helper.create_variable_for_type_inference(dtype=input.dtype)
helper.append_op(
type='warpctc',
inputs=this_inputs,
outputs={'WarpCTCGrad': [grad_out],
'Loss': [loss_out]},
attrs={
'blank': blank,
'norm_by_times': norm_by_times,
})
return loss_out
# FIXME(wuyi): let docstring_checker.py understand @autodoc.
# For now, the comments in c++ use types like Tensor, but in python side
# the type is often "Variable", and arguments may vary.
@templatedoc(op_type="nce")
def nce(input,
label,
num_total_classes,
sample_weight=None,
param_attr=None,
bias_attr=None,
num_neg_samples=None,
name=None,
sampler="uniform",
custom_dist=None,
seed=0,
is_sparse=False):
"""
${comment}
Args:
input (Variable): Input variable, 2-D tensor with shape [batch_size, dim],
and data type is float32 or float64.
label (Variable): Input label, 2-D tensor with shape [batch_size, num_true_class],
and data type is int64.
num_total_classes (int):${num_total_classes_comment}.
sample_weight (Variable|None): A Variable of shape [batch_size, 1]
storing a weight for each sample. The default weight for each
sample is 1.0.
param_attr (ParamAttr|None): To specify the weight parameter attribute.
Default: None, which means the default weight parameter property is
used. See usage for details in :ref:`api_fluid_ParamAttr` .
bias_attr (ParamAttr|None): To specify the bias parameter attribute.
Default: None, which means the default bias parameter property is
used. See usage for details in :ref:`api_fluid_ParamAttr` .
num_neg_samples (int): ${num_neg_samples_comment}.
name(str|None): For detailed information, please refer to
:ref:`api_guide_Name` . Usually name is no need to set and None by default.
sampler (str, optional): The sampler used to sample class from negtive classes.
It can be 'uniform', 'log_uniform' or 'custom_dist'.
default: 'uniform'.
custom_dist (nd.array|None): A numpy ndarray with size=num_total_classes.
It is used when sampler is set to 'custom_dist'.
custom_dist[i] is the probsbility of i-th class to be sampled.
default: None.
seed (int, optional): The seed used in sampler. Default 0, means no random seed.
is_sparse(bool, optional): The flag indicating whether to use sparse update,
the weight@GRAD and bias@GRAD will be changed to SelectedRows. Default False.
Returns:
Variable: The output nce loss.
Examples:
.. code-block:: python
import paddle.fluid as fluid
import numpy as np
window_size = 5
words = []
for i in xrange(window_size):
words.append(fluid.data(
name='word_{0}'.format(i), shape=[-1, 1], dtype='int64'))
dict_size = 10000
label_word = int(window_size / 2) + 1
embs = []
for i in xrange(window_size):
if i == label_word:
continue
emb = fluid.layers.embedding(input=words[i], size=[dict_size, 32],
param_attr='embed', is_sparse=True)
embs.append(emb)
embs = fluid.layers.concat(input=embs, axis=1)
loss = fluid.layers.nce(input=embs, label=words[label_word],
num_total_classes=dict_size, param_attr='nce.w_0',
bias_attr='nce.b_0')
#or use custom distribution
dist = np.array([0.05,0.5,0.1,0.3,0.05])
loss = fluid.layers.nce(input=embs, label=words[label_word],
num_total_classes=5, param_attr='nce.w_1',
bias_attr='nce.b_1',
num_neg_samples=3,
sampler="custom_dist",
custom_dist=dist)
"""
helper = LayerHelper('nce', **locals())
if not isinstance(input, Variable):
raise TypeError(
"The type of 'input' in nce layer must be Variable, but received %s"
% (type(input)))
if not isinstance(label, Variable):
raise TypeError(
"The type of 'label' in nce layer must be Variable, but received %s"
% (type(label)))
if convert_dtype(input.dtype) not in ['float32', 'float64']:
raise TypeError(
"The data type of 'input' in nce layer must be float32 or float64, but received %s."
% (convert_dtype(input.dtype)))
if convert_dtype(label.dtype) not in ['int64']:
raise TypeError(
"The data type of 'label' in nce layer must be int64, but received %s."
% (convert_dtype(label.dtype)))
dim = input.shape[1]
num_true_class = label.shape[1]
w = helper.create_parameter(
attr=helper.param_attr,
shape=[num_total_classes, dim],
is_bias=False,
dtype=input.dtype)
inputs = {}
if helper.bias_attr:
b = helper.create_parameter(
attr=helper.bias_attr,
shape=[num_total_classes, 1],
is_bias=True,
dtype=input.dtype)
inputs['Bias'] = b
cost = helper.create_variable_for_type_inference(dtype=input.dtype)
sample_logits = helper.create_variable_for_type_inference(dtype=input.dtype)
sample_labels = helper.create_variable_for_type_inference(dtype=label.dtype)
inputs['Input'] = input
inputs['Label'] = label
inputs['Weight'] = w
inputs['SampleWeight'] = sample_weight if sample_weight is not None else []
if sampler == "uniform":
sampler = 0
elif sampler == "log_uniform":
sampler = 1
elif sampler == "custom_dist":
assert custom_dist is not None
# assert isinstance(custom_dist, Variable)
custom_dist_len = num_total_classes
alias_probs_ = [0] * custom_dist_len
alias_ = [0] * custom_dist_len
bigs = []
littles = []
for i in range(custom_dist_len):
normal_prob = custom_dist[i] * custom_dist_len
if normal_prob - 1.0 > 0:
bigs.append((i, normal_prob))
elif 1.0 - normal_prob > 0:
littles.append((i, normal_prob))
else:
alias_probs_[i] = normal_prob
alias_[i] = -1
while len(bigs) and len(littles):
big = bigs.pop(0)
little = littles.pop(0)
big_idx = big[0]
big_prob = big[1]
alias_probs_[little[0]] = little[1]
alias_[little[0]] = big_idx
big_left = big[1] + little[1] - 1
if big_left - 1.0 > 0:
bigs.append((big_idx, big_left))
elif 1.0 - big_left > 0:
littles.append((big_idx, big_left))
else:
alias_probs_[big_idx] = big_left
alias_[big_idx] = -1
if len(bigs):
big = bigs.pop(0)
alias_probs_[big[0]] = 1.0
alias_[big[0]] = -1
if len(littles):
little = littles.pop(0)
alias_probs_[little[0]] = 1.0
alias_[little[0]] = -1
def _init_by_numpy_array(numpy_array):
ret = helper.create_parameter(
attr=ParamAttr(),
shape=numpy_array.shape,
dtype=numpy_array.dtype,
default_initializer=NumpyArrayInitializer(numpy_array))
ret.stop_gradient = True
return ret
inputs['CustomDistProbs'] = _init_by_numpy_array(
np.array(custom_dist).astype('float32'))
inputs['CustomDistAlias'] = _init_by_numpy_array(
np.array(alias_).astype('int32'))
inputs['CustomDistAliasProbs'] = _init_by_numpy_array(
np.array(alias_probs_).astype('float32'))
sampler = 2
else:
raise Exception("Unsupported sampler type.")
if num_neg_samples is None:
num_neg_samples = 10
else:
num_neg_samples = int(num_neg_samples)
remote_prefetch = is_sparse
print(
"With sparse mode, if your models has only small parameter prefetch may cause speed down"
)
attrs = {
'num_total_classes': int(num_total_classes),
'num_neg_samples': num_neg_samples,
'seed': seed,
'sampler': sampler,
'is_sparse': is_sparse,
'remote_prefetch': remote_prefetch
}
helper.append_op(
type='nce',
inputs=inputs,
outputs={
'Cost': cost,
'SampleLogits': sample_logits,
'SampleLabels': sample_labels
},
attrs=attrs)
return cost / (num_neg_samples + 1)
def hsigmoid(input,
label,
num_classes,
param_attr=None,
bias_attr=None,
name=None,
path_table=None,
path_code=None,
is_custom=False,
is_sparse=False):
"""
The hierarchical sigmoid organizes the classes into a complete binary tree to reduce the computational complexity
and speed up the model training, especially the training of language model.
Each leaf node of the complete binary tree represents a class(word) and each non-leaf node acts as a binary classifier.
For each class(word), there's a unique path from root to itself, hsigmoid calculate the cost for each non-leaf node on
the path, and sum them to get a total cost.
Comparing to softmax, the OP can reduce the computational complexity from :math:`O(N)` to :math:`O(logN)`, where :math:`N`
represents the number of classes or the size of word dict.
The OP supports default tree and custom tree. For the default tree, you can refer to `Hierarchical Probabilistic Neural
Network Language Model <http://www.iro.umontreal.ca/~lisa/pointeurs/hierarchical-nnlm-aistats05.pdf>`. For the custom
tree, you need to set :attr:`is_custom` to True, and do the following steps (take the language model as an example):
1. Using a custom word dict to build a binary tree, each leaf node should be an word in the word dict.
2. Creating a dict map word_id -> path that from the word to the root node, we call it path_table.
3. Creating a dict map word_id -> code of path that from the word to the root node, we call it path_code.
Code means the label of each binary classifier, 1 indicate true, 0 indicate false.
4. Now, each word should has its path and code along the path, you can pass a batch of path and code related
to the same batch of inputs.
Parameters:
input (Variable): A tensor with the shape [N, D], where N is the size of mini-batch,
and D is the feature size. Its data type supports float32 and float64.
label (Variable): A tensor contains the labels of training data. Its shape is [N, 1]
and data type is int64.
num_classes (int): The number of classes or the size of word dict, must be greater than 2.
If the default tree is used (:attr:`is_custom` is set to False), :attr:`num_classes`
should not be None. If the custom tree is used (:attr:`is_custom` is set to True),
:attr:`num_classes` should be the number of non-leaf nodes, which indicates the num of
classes using by the binary classifier.
param_attr (ParamAttr, optional): The parameter attribute for the learnable parameters/weights
of hsigmoid. If it is set to None or one attribute of ParamAttr, hsigmoid will create a
ParamAttr as param_attr. If the Initializer of the param_attr is not set, the parameter is
initialized with Xavier. Default: None.
bias_attr (ParamAttr|bool, optional): The parameter attribute for the bias of hsigmoid. If it
is set to False, no bias will be added. If it is set to None or one attribute of ParamAttr,
hsigmoid will create a ParamAttr as bias_attr. If the Initializer of the bias_attr is not
set, the bias is initialized zero. Default: None.
name (str, optional): Normally there is no need for user to set this property. For more information,
please refer to :ref:`api_guide_Name`. Default: None.
path_table (Variable, optional): A tensor that stores each batch of samples' path from leaf to root
node, its shape is [N, L] and data type is int64, where L is the length of path. For each sample i,
path_table[i] is a np.array like structure and each element in this array is the indexes in parent
nodes' weight matrix. Default: None.
path_code (Variable, optional): A tensor that stores each batch of samples' code of path from leaf
to root node, its shape is [N, L] and data type is int64, which is the same as :attr:`path_table`.
Each code of path is consisted with the code of nodes from leaf to root node. Default: None.
is_custom (bool, optional): Whether use custom binary tree. If it's True, :attr:`path_table`,
:attr:`path_code` and :attr:`num_classes` should be set, otherwise :attr:`num_classes` should
be set. Default: False.
is_sparse (bool, optional): Whether use sparse updating instead of dense updating, if it's True, the
gradient of W and input will be sparse. Default: False.
Returns:
Variable: A tensor with the cost of hierarchical sigmoid, its shape is [N, 1] and data type is the same as :attr:`input`.
Examples:
.. code-block:: python
import paddle.fluid as fluid
x = fluid.layers.fill_constant(shape=[4, 3], value=0.9, dtype='float32')
# x = [[0.9, 0.9, 0.9], [0.9, 0.9, 0.9], [0.9, 0.9, 0.9], [0.9, 0.9, 0.9]]
y = fluid.layers.fill_constant(
shape=[4, 1], value=1, dtype='int64')
# y = [[1], [1], [1], [1]]
out = fluid.layers.hsigmoid(input=x, label=y, num_classes=2, param_attr=fluid.initializer.Constant(
value=0.05), bias_attr=fluid.initializer.Constant(value=.0))
# out = [[0.62792355], [0.62792355], [0.62792355], [0.62792355]]
"""
helper = LayerHelper('hierarchical_sigmoid', **locals())
dtype = helper.input_dtype()
out = helper.create_variable_for_type_inference(dtype)
pre_out = helper.create_variable_for_type_inference(dtype)
dim = input.shape[1]
if ((num_classes is None) or (num_classes < 2)) and (not is_custom):
raise ValueError(
"num_classes must not be less than 2 with default tree")
if (not is_custom) and (is_sparse):
print("Sparse mode should not be used without custom tree")
is_sparse = False
if (not is_custom) and ((path_table is not None) or
(path_code is not None)):
raise ValueError(
"only num_classes should be passed without custom tree")
if (is_custom) and (path_code is None):
raise ValueError("path_code should not be None with custom tree")
elif (is_custom) and (path_table is None):
raise ValueError("path_table should not be None with custom tree")
elif (is_custom) and (num_classes is None):
raise ValueError("num_classes should not be None with custom tree")
else:
pass
weights = None
remote_prefetch = is_sparse
print(
"With sparse mode, if your models has only small parameter prefetch may cause speed down"
)
if not is_custom:
weights = helper.create_parameter(
attr=helper.param_attr,
shape=[num_classes - 1, dim],
is_bias=False,
dtype=input.dtype)
else:
weights = helper.create_parameter(
attr=helper.param_attr,
shape=[num_classes, dim],
is_bias=False,
dtype=input.dtype)
inputs = {
"X": input,
"W": weights,
"PathTable": path_table,
"PathCode": path_code,
"Label": label
}
if helper.bias_attr:
if not is_custom:
bias = helper.create_parameter(
attr=helper.bias_attr,
shape=[num_classes - 1, 1],
is_bias=True,
dtype=input.dtype)
inputs['Bias'] = bias
else:
bias = helper.create_parameter(
attr=helper.bias_attr,
shape=[num_classes, 1],
is_bias=True,
dtype=input.dtype)
inputs['Bias'] = bias
helper.append_op(
type="hierarchical_sigmoid",
inputs=inputs,
outputs={"Out": out,
"PreOut": pre_out,
"W_Out": weights},
attrs={
"num_classes": num_classes,
"is_sparse": is_sparse,
"remote_prefetch": remote_prefetch
})
return out
def sampled_softmax_with_cross_entropy(logits,
label,
num_samples,
num_true=1,
remove_accidental_hits=True,
use_customized_samples=False,
customized_samples=None,
customized_probabilities=None,
seed=0):
"""
**Sampled Softmax With Cross Entropy Operator.**
Cross entropy loss with sampled softmax is used as the output layer for
larger output classes extensively. This operator samples a number of samples
for all examples, and computes the softmax normalized values for each
row of the sampled tensor, after which cross-entropy loss is computed.
Because this operator performs a softmax on logits internally, it expects
unscaled logits. This operator should not be used with the output of
softmax operator since that would produce incorrect results.
For examples with T true labels (T >= 1), we assume that each true label has
a probability of 1/T. For each sample, S samples are generated using a
log uniform distribution. True labels are concatenated with these samples to
form T + S samples for each example. So, assume the shape of logits is
[N x K], the shape for samples is [N x (T+S)]. For each sampled label, a
probability is calculated, which corresponds to the Q(y|x) in
[Jean et al., 2014](http://arxiv.org/abs/1412.2007).
Logits are sampled according to the sampled labels. Then if
remove_accidental_hits is True, if a sample[i, j] accidentally hits true
labels, then the corresponding sampled_logits[i, j] is minus by 1e20 to
make its softmax result close to zero. Then sampled logits are subtracted by
logQ(y|x), these sampled logits and re-indexed labels are used to compute
a softmax with cross entropy.
Args:
logits (Variable): The unscaled log probabilities, which is a 2-D tensor
with shape [N x K]. N is the batch_size, and K is the class number.
label (Variable): The ground truth which is a 2-D tensor. Label is a
Tensor<int64> with shape [N x T], where T is the number of true
labels per example.
num_samples (int): The number for each example, num_samples should be
less than the number of class.
num_true(int): The number of target classes per training example.
remove_accidental_hits (bool): A flag indicating whether to remove
accidental hits when sampling. If True and if a sample[i, j]
accidentally hits true labels, then the corresponding
sampled_logits[i, j] is minus by 1e20 to make its softmax result
close to zero. Default is True.
use_customized_samples (bool): Whether to use custom samples and probabities to sample
logits.
customized_samples (Variable): User defined samples, which is a 2-D tensor
with shape [N, T + S]. S is the num_samples, and T is the number of true
labels per example.
customized_probabilities (Variable): User defined probabilities of samples,
a 2-D tensor which has the same shape with customized_samples.
seed (int): The random seed for generating random number, which is used
in the process of sampling. Default is 0.
Returns:
Variable: Return the cross entropy loss which is a 2-D tensor with shape
[N x 1].
Examples:
.. code-block:: python
import paddle.fluid as fluid
input = fluid.layers.data(name='data', shape=[256], dtype='float32')
label = fluid.layers.data(name='label', shape=[1], dtype='int64')
fc = fluid.layers.fc(input=input, size=100)
out = fluid.layers.sampled_softmax_with_cross_entropy(
logits=fc, label=label, num_samples=25)
"""
helper = LayerHelper('sample_logits', **locals())
samples = helper.create_variable_for_type_inference(dtype='int64')
probabilities = helper.create_variable_for_type_inference(
dtype=logits.dtype)
sampled_logits \
= helper.create_variable_for_type_inference(dtype=logits.dtype)
sampled_label = helper.create_variable_for_type_inference(dtype='int64')
sampled_softlabel = helper.create_variable_for_type_inference(
dtype=logits.dtype)
logits_dim = helper.create_variable_for_type_inference(dtype=logits.dtype)
labels_dim = helper.create_variable_for_type_inference(dtype=label.type)
helper.append_op(
type='sample_logits',
inputs={
'Logits': logits,
'Labels': label,
'CustomizedSamples': customized_samples,
'CustomizedProbabilities': customized_probabilities
},
outputs={
'Samples': samples,
'Probabilities': probabilities,
'SampledLabels': sampled_label,
'SampledLogits': sampled_logits,
'LogitsDim': logits_dim,
'LabelsDim': labels_dim
},
attrs={
'use_customized_samples': use_customized_samples,
'uniq': True,
'remove_accidental_hits': remove_accidental_hits,
'num_samples': num_samples,
'seed': seed
})
loss = helper.create_variable_for_type_inference(dtype=logits.dtype)
softmax = helper.create_variable_for_type_inference(dtype=logits.dtype)
helper.append_op(
type='one_hot',
inputs={'X': sampled_label},
attrs={'depth': num_samples + 1},
outputs={'Out': sampled_softlabel})
helper.append_op(
type='softmax_with_cross_entropy',
inputs={'Logits': sampled_logits,
'Label': sampled_softlabel},
outputs={'Softmax': softmax,
'Loss': loss},
attrs={
'soft_label': True,
'ignore_index': False,
'numeric_stable_mode': False
})
return loss / num_true
def softmax_with_cross_entropy(logits,
label,
soft_label=False,
ignore_index=kIgnoreIndex,
numeric_stable_mode=True,
return_softmax=False,
axis=-1):
"""
This operator implements the cross entropy loss function with softmax. This function
combines the calculation of the softmax operation and the cross entropy loss function
to provide a more numerically stable gradient.
Because this operator performs a softmax on logits internally, it expects
unscaled logits. This operator should not be used with the output of
softmax operator since that would produce incorrect results.
When the attribute :attr:`soft_label` is set :attr:`False`, this operators
expects mutually exclusive hard labels, each sample in a batch is in exactly
one class with a probability of 1.0. Each sample in the batch will have a
single label.
The equation is as follows:
1) Hard label (one-hot label, so every sample has exactly one class)
.. math::
loss_j = -\\text{logits}_{label_j} +
\\log\\left(\\sum_{i=0}^{K}\\exp(\\text{logits}_i)\\right), j = 1,..., K
2) Soft label (each sample can have a distribution over all classes)
.. math::
loss_j = -\\sum_{i=0}^{K}\\text{label}_i
\\left(\\text{logits}_i - \\log\\left(\\sum_{i=0}^{K}
\\exp(\\text{logits}_i)\\right)\\right), j = 1,...,K
3) If :attr:`numeric_stable_mode` is :attr:`True`, softmax is calculated first by:
.. math::
max_j &= \\max_{i=0}^{K}{\\text{logits}_i}
log\\_max\\_sum_j &= \\log\\sum_{i=0}^{K}\\exp(logits_i - max_j)
softmax_j &= \\exp(logits_j - max_j - {log\\_max\\_sum}_j)
and then cross entropy loss is calculated by softmax and label.
Args:
logits (Variable): A multi-dimension ``Tensor`` , and the data type is float32 or float64. The input tensor of unscaled log probabilities.
label (Variable): The ground truth ``Tensor`` , data type is the same
as the ``logits`` . If :attr:`soft_label` is set to :attr:`True`,
Label is a ``Tensor`` in the same shape with :attr:`logits`.
If :attr:`soft_label` is set to :attr:`True`, Label is a ``Tensor``
in the same shape with :attr:`logits` expect shape in dimension :attr:`axis` as 1.
soft_label (bool, optional): A flag to indicate whether to interpretate the given
labels as soft labels. Default False.
ignore_index (int, optional): Specifies a target value that is ignored and does
not contribute to the input gradient. Only valid
if :attr:`soft_label` is set to :attr:`False`.
Default: kIgnoreIndex(-100).
numeric_stable_mode (bool, optional): A flag to indicate whether to use a more
numerically stable algorithm. Only valid
when :attr:`soft_label` is :attr:`False`
and GPU is used. When :attr:`soft_label`
is :attr:`True` or CPU is used, the
algorithm is always numerically stable.
Note that the speed may be slower when use
stable algorithm. Default: True.
return_softmax (bool, optional): A flag indicating whether to return the softmax
along with the cross entropy loss. Default: False.
axis (int, optional): The index of dimension to perform softmax calculations. It
should be in range :math:`[-1, rank - 1]`, while :math:`rank`
is the rank of input :attr:`logits`. Default: -1.
Returns:
``Variable`` or Tuple of two ``Variable`` : Return the cross entropy loss if \
`return_softmax` is False, otherwise the tuple \
(loss, softmax), softmax is in the same shape \
with input logits and cross entropy loss is in \
the same shape with input logits except shape \
in dimension :attr:`axis` as 1.
Examples:
.. code-block:: python
import paddle.fluid as fluid
data = fluid.data(name='data', shape=[-1, 128], dtype='float32')
label = fluid.data(name='label', shape=[-1, 1], dtype='int64')
fc = fluid.layers.fc(input=data, size=100)
out = fluid.layers.softmax_with_cross_entropy(
logits=fc, label=label)
"""
helper = LayerHelper('softmax_with_cross_entropy', **locals())
softmax = helper.create_variable_for_type_inference(dtype=logits.dtype)
loss = helper.create_variable_for_type_inference(dtype=logits.dtype)
helper.append_op(
type='softmax_with_cross_entropy',
inputs={'Logits': logits,
'Label': label},
outputs={'Softmax': softmax,
'Loss': loss},
attrs={
'soft_label': soft_label,
'ignore_index': ignore_index,
'numeric_stable_mode': numeric_stable_mode,
'axis': axis
})
if return_softmax:
return loss, softmax
return loss
def rank_loss(label, left, right, name=None):
"""
This operator implements the sort loss layer in the RankNet model. RankNet is a pairwise ranking model
with a training sample consisting of a pair of documents (A and B), The label (P)
indicates whether A is ranked higher than B or not. Please refer to more details:
`RankNet <http://icml.cc/2015/wp-content/uploads/2015/06/icml_ranking.pdf>`_
Rank loss layer takes three inputs: left ( :math:`o_i` ), right ( :math:`o_j` ) and
label ( :math:`P_{i,j}` ). The inputs respectively represent RankNet's output scores
for documents A and B and the value of label P. Rank loss layer takes batch inputs
with size batch_size (batch_size >= 1), P = {0, 1} or {0, 0.5, 1},
where 0.5 means that there is no information about the rank of the input pair.
The following equation computes rank loss C_{i,j} from the inputs:
.. math::
C_{i,j} &= -\\tilde{P_{ij}} * o_{i,j} + \log(1 + e^{o_{i,j}}) \\\\
.. math::
o_{i,j} &= o_i - o_j \\\\
.. math::
\\tilde{P_{i,j}} &= \\left \{0, 0.5, 1 \\right \} \ or \ \\left \{0, 1 \\right \}
Parameters:
label (Variable): 2-D ``Tensor`` with the shape of :math:`[batch,1]`, the data type is float32, batch indicates the size of the data. Indicats whether A ranked higher than B or not.
left (Variable): 2-D ``Tensor`` with the shape of :math:`[batch,1]`, the data type is float32. RankNet's output score for doc A.
right (Variable): 2-D ``Tensor`` with the shape of :math:`[batch,1]`, the data type is float32. RankNet's output score for doc B.
name(str|None): The default value is None. Normally there is no need for user to set this property. For more information, please refer to :ref:`api_guide_Name` .
Returns:
Variable: ``Tensor`` indicating the output value of the sort loss layer, the data type is float32, and the return value's shape is :math:`[batch,1]` .
Raises:
ValueError: Any of label, left, and right is not a ``Variable`` .
Examples:
.. code-block:: python
import paddle.fluid as fluid
label = fluid.data(name="label", shape=[-1, 1], dtype="float32")
left = fluid.data(name="left", shape=[-1, 1], dtype="float32")
right = fluid.data(name="right", shape=[-1, 1], dtype="float32")
out = fluid.layers.rank_loss(label, left, right)
"""
helper = LayerHelper('rank_loss', **locals())
if not (isinstance(label, Variable)):
raise ValueError("The label should be a Variable")
if not (isinstance(left, Variable)):
raise ValueError("The left should be a Variable")
if not (isinstance(right, Variable)):
raise ValueError("The right should be a Variable")
out = helper.create_variable_for_type_inference("float32")
helper.append_op(
type='rank_loss',
inputs={"Label": label,
"Left": left,
"Right": right},
outputs={'Out': out})
return out
def margin_rank_loss(label, left, right, margin=0.1, name=None):
"""
Margin Ranking Loss Layer for ranking problem,
which compares left score and right score passed in.
The ranking loss can be defined as following equation:
.. math::
rank\_loss = max(0, -label * (left - right) + margin)
Args:
label (Variable): Indicates whether the left is ranked higher than the right or not.
Data type is float32.
left (Variable): Ranking score for left. Data type float32.
right (Variable): Ranking score for right. Data type float32.
margin (float): Indicates the given margin.
name(str|None): For detailed information, please refer to
:ref:`api_guide_Name` . Usually name is no need to set and None by default.
Returns:
Variable: The ranking loss.
Raises:
ValueError: Any of label, left, and right is not a Variable.
Examples:
.. code-block:: python
import paddle.fluid as fluid
label = fluid.data(name="label", shape=[-1, 1], dtype="float32")
left = fluid.data(name="left", shape=[-1, 1], dtype="float32")
right = fluid.data(name="right", shape=[-1, 1], dtype="float32")
out = fluid.layers.margin_rank_loss(label, left, right)
"""
helper = LayerHelper('margin_rank_loss', **locals())
if not isinstance(label, Variable):
raise ValueError("The label should be a Variable.")
if not isinstance(left, Variable):
raise ValueError("The left should be a Variable.")
if not isinstance(right, Variable):
raise ValueError("The right should be a Variable.")
out = helper.create_variable_for_type_inference(left.dtype)
act = helper.create_variable_for_type_inference(left.dtype)
helper.append_op(
type='margin_rank_loss',
inputs={"Label": label,
"X1": left,
"X2": right},
outputs={'Out': out,
'Activated': act},
attrs={'margin': margin})
return out
@templatedoc()
def sigmoid_cross_entropy_with_logits(x,
label,
ignore_index=kIgnoreIndex,
name=None,
normalize=False):
"""
${comment}
Args:
x(${x_type}): ${x_comment}
label(${label_type}): ${label_comment}
ignore_index(int): ${ignore_index_comment}
name(str|None): The default value is None. Normally there is
no need for user to set this property. For more information,
please refer to :ref:`api_guide_Name`
normalize(bool): If true, divide the output by the number of
targets != ignore_index.
Returns:
out(${out_type}): ${out_comment}
Examples:
.. code-block:: python
import paddle.fluid as fluid
input = fluid.data(
name='data', shape=[10], dtype='float32')
label = fluid.data(
name='data', shape=[10], dtype='float32')
loss = fluid.layers.sigmoid_cross_entropy_with_logits(
x=input,
label=label,
ignore_index=-1,
normalize=True) # or False
# loss = fluid.layers.reduce_sum(loss) # summation of loss
"""
helper = LayerHelper("sigmoid_cross_entropy_with_logits", **locals())
if name is None:
out = helper.create_variable_for_type_inference(dtype=x.dtype)
else:
out = helper.create_variable(
name=name, dtype=x.dtype, persistable=False)
helper.append_op(
type="sigmoid_cross_entropy_with_logits",
inputs={"X": x,
"Label": label},
attrs={"ignore_index": ignore_index,
'normalize': normalize},
outputs={"Out": out})
return out
def teacher_student_sigmoid_loss(input,
label,
soft_max_up_bound=15.0,
soft_max_lower_bound=-15.0):
"""
**Teacher Student Log Loss Layer**
This layer accepts input predictions and target label and returns the
teacher_student loss. Z is click or not, z' is value of teacher loss, label = {-2, -1, [0, 2]}
when z' is not exist, clk = 0 : label = -2; when z' is not exist, clk = 1 : label = -1;
when z' is exist , clk = 0 : label = 0 + z'; when z' is exist , clk = 1 : label = 1 + z'
.. math::
loss = max(x, 0) - x * z + log(1 + exp(-abs(x))) + max(x, 0) - x * z' + log(1 + exp(-abs(x)))
Args:
input (Variable|list): a 2-D tensor with shape [N x 1], where N is the
batch size. This input is a probability computed
by the previous operator.
label (Variable|list): the ground truth which is a 2-D tensor with
shape [N x 1], where N is the batch size.
soft_max_up_bound (float): if input > soft_max_up_bound, will be bound
soft_max_lower_bound (float): if input < soft_max_lower_bound, will be bound
Returns:
Variable: A 2-D tensor with shape [N x 1], the teacher_student_sigmoid_loss.
Examples:
.. code-block:: python
import paddle.fluid as fluid
batch_size = 64
label = fluid.data(
name="label", shape=[batch_size, 1], dtype="int64")
similarity = fluid.data(
name="similarity", shape=[batch_size, 1], dtype="float32")
cost = fluid.layers.teacher_student_sigmoid_loss(input=similarity, label=label)
"""
helper = LayerHelper('teacher_student_sigmoid_loss', **locals())
out = helper.create_variable(dtype=input.dtype)
helper.append_op(
type='teacher_student_sigmoid_loss',
inputs={'X': [input],
'Label': [label]},
outputs={'Y': [out]},
attrs={"soft_max_lower_bound": float(soft_max_lower_bound), \
"soft_max_up_bound": float(soft_max_up_bound)})
return out
def huber_loss(input, label, delta):
"""
This operator computes the Huber loss between input and label.
Huber loss is commonly used in regression tasks. Compared to square_error_cost, Huber loss is more robust and less sensitivity to outliers.
When the absolute difference between input and label is greater than delta, the linear error is calculated:
.. math::
huber\_loss = delta * (label - input) - 0.5 * delta * delta
When the absolute difference between input and label is greater than delta, the square error is calculated:
.. math::
huber\_loss = 0.5 * (label - input) * (label - input)
Args:
input (Variable): Predicted data, 2D-Tensor with the shape of [batch_size, 1]. The data type should be float32 or float64.
label (Variable): Ground truth label, 2D-Tensor with the shape of [batch_size, 1]. The data type should be float32 or float64.
delta (float): The threshold for Huber loss, which is used to control the balance between the linear error and square error. The data type should be float32.
Returns:
Variable: The huber loss, a tensor with the same shape and data type as input.
Examples:
.. code-block:: python
import paddle.fluid as fluid
import numpy as np
DATATYPE='float32'
input_data = np.array([[1.],[2.],[3.],[4.]]).astype(DATATYPE)
label_data = np.array([[3.],[3.],[4.],[4.]]).astype(DATATYPE)
x = fluid.data(name='input', shape=[None, 1], dtype=DATATYPE)
y = fluid.data(name='label', shape=[None, 1], dtype=DATATYPE)
loss = fluid.layers.huber_loss(input=x, label=y, delta=1.0)
place = fluid.CPUPlace()
#place = fluid.CUDAPlace(0)
exe = fluid.Executor(place)
HuberLoss, = exe.run(feed={'input':input_data ,'label':label_data}, fetch_list=[loss.name])
print(HuberLoss) #[[1.5], [0.5], [0.5], [0. ]], dtype=float32
"""
helper = LayerHelper('huber_loss', **locals())
residual = helper.create_variable_for_type_inference(
dtype=helper.input_dtype())
out = helper.create_variable_for_type_inference(dtype=helper.input_dtype())
helper.append_op(
type='huber_loss',
inputs={'X': input,
'Y': label},
outputs={'Out': out,
'Residual': residual},
attrs={'delta': delta})
return out
@templatedoc()
def kldiv_loss(x, target, reduction='mean', name=None):
"""
${comment}
Args:
x (Variable): ${x_comment}
target (Variable): ${target_comment}
reduction (Variable): ${reduction_comment}
name(str, optional): For detailed information, please refer
to :ref:`api_guide_Name`. Usually name is no need to set and
None by default.
Returns:
Variable(Tensor): The KL divergence loss. The data type is same as input tensor
Examples:
.. code-block:: python
import paddle.fluid as fluid
x = fluid.data(name='x', shape=[None,4,2,2], dtype='float32')
target = fluid.layers.data(name='target', shape=[4,2,2], dtype='float32')
loss = fluid.layers.kldiv_loss(x=x, target=target, reduction='batchmean')
"""
helper = LayerHelper('kldiv_loss', **locals())
loss = helper.create_variable_for_type_inference(dtype=x.dtype)
helper.append_op(
type='kldiv_loss',
inputs={'X': x,
'Target': target},
outputs={'Loss': loss},
attrs={'reduction': reduction})
return loss
from .ops import square
from .control_flow import equal
def npair_loss(anchor, positive, labels, l2_reg=0.002):
'''
**Npair Loss Layer**
Read `Improved Deep Metric Learning with Multi class N pair Loss Objective\
<http://www.nec-labs.com/uploads/images/Department-Images/MediaAnalytics/\
papers/nips16_npairmetriclearning.pdf>`_ .
Npair loss requires paired data. Npair loss has two parts: the first part is L2
regularizer on the embedding vector; the second part is cross entropy loss which
takes the similarity matrix of anchor and positive as logits.
Args:
anchor(Variable): embedding vector for the anchor image. shape=[batch_size, embedding_dims],
the data type is float32 or float64.
positive(Variable): embedding vector for the positive image. shape=[batch_size, embedding_dims],
the data type is float32 or float64.
labels(Variable): 1-D tensor. shape=[batch_size], the data type is float32 or float64 or int64.
l2_reg(float32): L2 regularization term on embedding vector, default: 0.002.
Returns:
A Variable holding Tensor representing the npair loss, the data type is the same as
anchor, the shape is [1].
Examples:
.. code-block:: python
import paddle.fluid as fluid
anchor = fluid.data(
name = 'anchor', shape = [18, 6], dtype = 'float32')
positive = fluid.data(
name = 'positive', shape = [18, 6], dtype = 'float32')
labels = fluid.data(
name = 'labels', shape = [18], dtype = 'float32')
npair_loss = fluid.layers.npair_loss(anchor, positive, labels, l2_reg = 0.002)
'''
Beta = 0.25
batch_size = labels.shape[0]
labels = nn.reshape(labels, shape=[batch_size, 1], inplace=True)
labels = nn.expand(labels, expand_times=[1, batch_size])
labels = equal(labels, nn.transpose(labels, perm=[1, 0])).astype('float32')
labels = labels / nn.reduce_sum(labels, dim=1, keep_dim=True)
l2loss = nn.reduce_mean(nn.reduce_sum(square(anchor), 1)) \
+ nn.reduce_mean(nn.reduce_sum(square(positive), 1))
l2loss = l2loss * Beta * l2_reg
similarity_matrix = nn.matmul(
anchor, positive, transpose_x=False, transpose_y=True)
softmax_ce = softmax_with_cross_entropy(
logits=similarity_matrix, label=labels, soft_label=True)
cross_entropy = nn.reduce_sum(labels * softmax_ce, 0)
celoss = nn.reduce_mean(cross_entropy)
return l2loss + celoss
def mse_loss(input, label):
"""
This op accepts input predications and target label and returns the mean square error.
The loss can be described as:
.. math::
Out = MEAN((input - label)^2)
Parameters:
input (Variable): Input tensor, the data type should be float32.
label (Variable): Label tensor, the data type shoulf be float32.
Returns:
Variable: The tensor variable storing the mean square error difference of input and label.
Return type: Variable.
Examples:
.. code-block:: python
# declarative mode
import paddle.fluid as fluid
import numpy as np
input = fluid.data(name="input", shape=[1])
label = fluid.data(name="label", shape=[1])
output = fluid.layers.mse_loss(input,label)
place = fluid.CPUPlace()
exe = fluid.Executor(place)
exe.run(fluid.default_startup_program())
input_data = np.array([1.5]).astype("float32")
label_data = np.array([1.7]).astype("float32")
output_data = exe.run(fluid.default_main_program(),
feed={"input":input_data, "label":label_data},
fetch_list=[output],
return_numpy=True)
print(output_data)
# [array([0.04000002], dtype=float32)]
# imperative mode
import paddle.fluid.dygraph as dg
with dg.guard(place) as g:
input = dg.to_variable(input_data)
label = dg.to_variable(label_data)
output = fluid.layers.mse_loss(input, label)
print(output.numpy())
# [0.04000002]
"""
return nn.reduce_mean(square_error_cost(input, label))
...@@ -38,14 +38,10 @@ from ..data_feeder import convert_dtype ...@@ -38,14 +38,10 @@ from ..data_feeder import convert_dtype
__all__ = [ __all__ = [
'fc', 'fc',
'center_loss',
'embedding', 'embedding',
'linear_chain_crf', 'linear_chain_crf',
'crf_decoding', 'crf_decoding',
'cos_sim', 'cos_sim',
'cross_entropy',
'bpr_loss',
'square_error_cost',
'chunk_eval', 'chunk_eval',
'conv2d', 'conv2d',
'conv3d', 'conv3d',
...@@ -69,22 +65,16 @@ __all__ = [ ...@@ -69,22 +65,16 @@ __all__ = [
'dropout', 'dropout',
'split', 'split',
'ctc_greedy_decoder', 'ctc_greedy_decoder',
'edit_distance',
'l2_normalize', 'l2_normalize',
'matmul', 'matmul',
'topk', 'topk',
'warpctc',
'transpose', 'transpose',
'im2sequence', 'im2sequence',
'nce',
'sampled_softmax_with_cross_entropy',
'hsigmoid',
'row_conv', 'row_conv',
'multiplex', 'multiplex',
'layer_norm', 'layer_norm',
'group_norm', 'group_norm',
'spectral_norm', 'spectral_norm',
'softmax_with_cross_entropy',
'smooth_l1', 'smooth_l1',
'one_hot', 'one_hot',
'autoincreased_step_counter', 'autoincreased_step_counter',
...@@ -117,8 +107,6 @@ __all__ = [ ...@@ -117,8 +107,6 @@ __all__ = [
'log', 'log',
'crop', 'crop',
'crop_tensor', 'crop_tensor',
'rank_loss',
'margin_rank_loss',
'elu', 'elu',
'relu6', 'relu6',
'pow', 'pow',
...@@ -165,7 +153,6 @@ __all__ = [ ...@@ -165,7 +153,6 @@ __all__ = [
'clip_by_norm', 'clip_by_norm',
'mean', 'mean',
'mul', 'mul',
'sigmoid_cross_entropy_with_logits',
'maxout', 'maxout',
'space_to_depth', 'space_to_depth',
'affine_grid', 'affine_grid',
...@@ -183,10 +170,6 @@ __all__ = [ ...@@ -183,10 +170,6 @@ __all__ = [
'py_func', 'py_func',
'psroi_pool', 'psroi_pool',
'prroi_pool', 'prroi_pool',
'teacher_student_sigmoid_loss',
'huber_loss',
'kldiv_loss',
'npair_loss',
'pixel_shuffle', 'pixel_shuffle',
'fsp_matrix', 'fsp_matrix',
'continuous_value_model', 'continuous_value_model',
...@@ -199,12 +182,9 @@ __all__ = [ ...@@ -199,12 +182,9 @@ __all__ = [
'shard_index', 'shard_index',
'hard_swish', 'hard_swish',
'gather_tree', 'gather_tree',
'mse_loss',
'uniform_random', 'uniform_random',
] ]
kIgnoreIndex = -100
def fc(input, def fc(input,
size, size,
...@@ -375,95 +355,6 @@ def fc(input, ...@@ -375,95 +355,6 @@ def fc(input,
return helper.append_activation(pre_activation) return helper.append_activation(pre_activation)
def center_loss(input,
label,
num_classes,
alpha,
param_attr,
update_center=True):
"""
**Center loss Cost layer**
This OP accepts input (deep features,the output of the last hidden layer)
and target label and return the center loss cost. The average of the
distances of each sample in the mini-batch from the center of the
corresponding category is calculated as the center loss.
For deep features, :math:`X`, and target labels, :math:`Y`, the equation is:
.. math::
Out = \\frac{1}{2}(X - Y)^2
Args:
input (Variable): a 2-D tensor with shape[N x M]. Its dtype should be float32 or float64.
label (Variable): the groud truth which is a 2-D tensor
with shape[N x 1],where N is the batch size. Its dtype should be int32.
num_classes (int): the number of classification categories.
alpha (float|Variable): learning rate of centers.
param_attr (ParamAttr): Attribute initializer of centers.
update_center (bool): whether to update value of center.
Returns:
Variable: 2-D tensor with shape [N * 1]
Examples:
.. code-block:: python
import paddle.fluid as fluid
input = fluid.data(name='x',shape=[20,30],dtype='float32')
label = fluid.data(name='y',shape=[20,1],dtype='int64')
num_classes = 1000
alpha = 0.01
param_attr = fluid.initializer.Xavier(uniform=False)
center_loss=fluid.layers.center_loss(input=input,
label=label,
num_classes=1000,
alpha=alpha,
param_attr=fluid.initializer.Xavier(uniform=False),
update_center=True)
"""
helper = LayerHelper('center_loss', **locals())
dtype = helper.input_dtype()
centers_shape = [num_classes, input.shape[1]]
centers_param = helper.create_parameter(
attr=param_attr, shape=centers_shape, dtype=dtype)
centers_param.stop_gradient = True
if isinstance(alpha, Variable):
alpha_param = alpha
else:
assert isinstance(alpha, float)
alpha_param = helper.create_variable(
name="centerloss_alpha",
shape=[1],
dtype="float32",
type=core.VarDesc.VarType.LOD_TENSOR,
persistable=True,
stop_gradient=True,
initializer=Constant(alpha))
centersdiff = helper.create_variable_for_type_inference(dtype=input.dtype)
loss = helper.create_variable_for_type_inference(dtype=input.dtype)
helper.append_op(
type='center_loss',
inputs={
'X': [input],
'Label': [label],
'Centers': [centers_param],
'CenterUpdateRate': [alpha_param]
},
outputs={
'SampleCenterDiff': [centersdiff],
'Loss': [loss],
'CentersOut': [centers_param]
},
attrs={'cluster_num': num_classes,
'need_update': update_center})
return loss
def embedding(input, def embedding(input,
size, size,
is_sparse=False, is_sparse=False,
...@@ -975,216 +866,6 @@ def dropout(x, ...@@ -975,216 +866,6 @@ def dropout(x,
return out return out
def cross_entropy(input, label, soft_label=False, ignore_index=kIgnoreIndex):
"""
This operator computes the cross entropy between input and label. It
supports both hard-label and and soft-label cross entropy computation.
1. Hard-label cross entropy: if soft_label=False, :math:`label[i_1, i_2, ..., i_k]`
is the hard label of each sample.
.. math::
output[i_1, i_2, ..., i_k]=-log(input[i_1, i_2, ..., i_k, j]), label[i_1, i_2, ..., i_k] = j, j != ignore\_index
2. Soft-label cross entropy: if soft_label=True, :math:`label[i_1, i_2, ..., i_k, j]`
is the soft label of each sample corresponding to the j-th class.
.. math::
output[i_1, i_2, ..., i_k]= -\sum_{j}label[i_1,i_2,...,i_k,j]*log(input[i_1, i_2, ..., i_k,j])
Args:
input (Variable): a multidimensional Tensor with shape
:math:`[N_1, N_2, ..., N_k, D]`, where the last dimension D is
the class number. The data type should be float32 or float64.
label (Variable): label value corresponding to input. If
soft_label=False, the dimension of label should be :math:`[N_1, N_2, ..., N_k]`
or :math:`[N_1, N_2, ..., N_k, 1]` , and its data type should be int64,
and the value must be inside [0, D). If soft_label=True, the shape,
data type of label should be the same with input, and the sum of
soft label value of each sample should be 1.
soft_label (bool): indicate whether label is soft. Default False, meaning that
the label is hard. If soft_label=True, the label is soft.
ignore_index (int): specify an ignorable label value. The ignored label would be
omitted when computing. If it is a negative integer, no label would
be ignored. Only valid when soft_label=False. Default -100.
Returns:
A Variable holding Tensor representing the cross entropy, whose data type is the same with input.
If soft_label=False, the shape of output is the same with label.
If soft_label=True, the shape of output is :math:`[N_1, N_2, ..., N_k, 1]` .
Examples:
.. code-block:: python
import paddle.fluid as fluid
class_num = 7
x = fluid.data(name='x', shape=[None, 3, 10], dtype='float32')
label = fluid.data(name='label', shape=[None, 1], dtype='int64')
predict = fluid.layers.fc(input=x, size=class_num, act='softmax')
cost = fluid.layers.cross_entropy(input=predict, label=label)
"""
if not isinstance(input, Variable):
raise TypeError(
"The type of 'input' in cross_entropy must be Variable, but received %s"
% (type(input)))
if convert_dtype(input.dtype) in ['float16']:
warnings.warn(
"The data type of 'input' in cross_entropy only support float16 on GPU now."
)
if convert_dtype(input.dtype) not in ['float16', 'float32', 'float64']:
raise TypeError(
"The data type of 'input' in cross_entropy must be float16 or float32 or float64, but received %s."
% (convert_dtype(input.dtype)))
if not soft_label:
return cross_entropy2(input, label, ignore_index)
helper = LayerHelper('cross_entropy', **locals())
out = helper.create_variable_for_type_inference(dtype=input.dtype)
helper.append_op(
type='cross_entropy',
inputs={'X': [input],
'Label': [label]},
outputs={'Y': [out]},
attrs={"soft_label": soft_label,
"ignore_index": ignore_index})
return out
def cross_entropy2(input, label, ignore_index=kIgnoreIndex):
helper = LayerHelper('cross_entropy2', **locals())
out = helper.create_variable_for_type_inference(dtype=input.dtype)
xshape = helper.create_variable_for_type_inference(dtype=input.dtype)
match_x = helper.create_variable_for_type_inference(dtype=input.dtype)
helper.append_op(
type='cross_entropy2',
inputs={'X': [input],
'Label': [label]},
outputs={'Y': [out],
'MatchX': [match_x],
'XShape': [xshape]},
attrs={'ignore_index': ignore_index})
return out
def bpr_loss(input, label, name=None):
"""
**Bayesian Personalized Ranking Loss Operator**
This operator belongs to pairwise ranking loss. Label is the desired item.
The loss at a given point in one session is defined as:
.. math::
Y[i] = 1/(N[i] - 1) * \sum_j{\log(\sigma(X[i, Label[i]]-X[i, j]))}
Learn more details by reading paper <session-based recommendations with recurrent
neural networks>.
Args:
input (Variable|list): a 2-D tensor with shape [N x D], where N is the
batch size and D is the number of positive classes and negative classes
This input is not probability but logits.
label (Variable|list): the ground truth which is a 2-D tensor. `label`
is a tensor<int64> with shape [N x 1].
name (str|None): A name for this layer(optional). If set None, the
layer will be named automatically. Default: None.
Returns:
A 2-D tensor with shape [N x 1], the bpr loss.
Examples:
.. code-block:: python
import paddle.fluid as fluid
neg_size = 10
label = fluid.data(
name="label", shape=[3, 1], dtype="int64")
predict = fluid.data(
name="predict", shape=[3, neg_size + 1], dtype="float32")
cost = fluid.layers.bpr_loss(input=predict, label=label)
"""
helper = LayerHelper('bpr_loss', **locals())
out = helper.create_variable_for_type_inference(dtype=input.dtype)
helper.append_op(
type='bpr_loss',
inputs={'X': [input],
'Label': [label]},
outputs={'Y': [out]})
return out
def square_error_cost(input, label):
"""
This op accepts input predictions and target label and returns the
squared error cost.
For predictions label, and target label, the equation is:
.. math::
Out = (input - label)^2
Parameters:
input (Variable): Input tensor, the data type should be float32.
label (Variable): Label tensor, the data type should be float32.
Returns:
The tensor variable storing the element-wise squared error \
difference between input and label.
Return type: Variable.
Examples:
.. code-block:: python
# declarative mode
import paddle.fluid as fluid
import numpy as np
input = fluid.data(name="input", shape=[1])
label = fluid.data(name="label", shape=[1])
output = fluid.layers.square_error_cost(input,label)
place = fluid.CPUPlace()
exe = fluid.Executor(place)
exe.run(fluid.default_startup_program())
input_data = np.array([1.5]).astype("float32")
label_data = np.array([1.7]).astype("float32")
output_data = exe.run(fluid.default_main_program(),
feed={"input":input_data, "label":label_data},
fetch_list=[output],
return_numpy=True)
print(output_data)
# [array([0.04000002], dtype=float32)]
# imperative mode
import paddle.fluid.dygraph as dg
with dg.guard(place) as g:
input = dg.to_variable(input_data)
label = dg.to_variable(label_data)
output = fluid.layers.square_error_cost(input, label)
print(output.numpy())
# [0.04000002]
"""
helper = LayerHelper('square_error_cost', **locals())
minus_out = helper.create_variable_for_type_inference(dtype=input.dtype)
helper.append_op(
type='elementwise_sub',
inputs={'X': [input],
'Y': [label]},
outputs={'Out': [minus_out]})
square_out = helper.create_variable_for_type_inference(dtype=input.dtype)
helper.append_op(
type='square', inputs={'X': [minus_out]},
outputs={'Out': [square_out]})
return square_out
@templatedoc() @templatedoc()
def chunk_eval(input, def chunk_eval(input,
label, label,
...@@ -4931,111 +4612,6 @@ def topk(input, k, name=None): ...@@ -4931,111 +4612,6 @@ def topk(input, k, name=None):
return values, indices return values, indices
def edit_distance(input,
label,
normalized=True,
ignored_tokens=None,
input_length=None,
label_length=None):
"""
This op computes the edit distances between a batch of
hypothesis strings and their references. Edit distance, also called
Levenshtein distance, measures how dissimilar two strings are by counting
the minimum number of operations to transform one string into anthor.
Here the operations include insertion, deletion, and substitution.
For example, given hypothesis string A = "kitten" and reference
B = "sitting", the edit distance is 3 for A will be transformed into B
at least after two substitutions and one insertion:
"kitten" -> "sitten" -> "sittin" -> "sitting"
The input is a LoDTensor/Tensor consisting of all the hypothesis strings with
the total number denoted by `batch_size`, and the separation is specified
by the LoD information or input_length. And the `batch_size` reference strings are arranged
in order in the same way as `input`.
The output contains the `batch_size` results and each stands for the edit
distance for a pair of strings respectively. If Attr(normalized) is true,
the edit distance will be divided by the length of reference string.
Parameters:
input(Variable): The indices for hypothesis strings, its rank should equals to 2 and its data type should be int64.
label(Variable): The indices for reference strings, its rank should equals to 2 and its data type should be int64.
normalized(bool, default True): Indicated whether to normalize the edit distance by
the length of reference string.
ignored_tokens(list<int>, default None): Tokens that should be removed before
calculating edit distance.
input_length(Variable): The length for each sequence in `input` if it's of Tensor type, it should have shape `[batch_size]` and dtype int64.
label_length(Variable): The length for each sequence in `label` if it's of Tensor type, it should have shape `[batch_size]` and dtype int64.
Returns:
Tuple:
edit_distance_out(Variable): edit distance result in shape [batch_size, 1].
sequence_num(Variable): sequence number in shape [].
Examples:
.. code-block:: python
import paddle.fluid as fluid
# using LoDTensor
x_lod = fluid.data(name='x_lod', shape=[None,1], dtype='int64', lod_level=1)
y_lod = fluid.data(name='y_lod', shape=[None,1], dtype='int64', lod_level=1)
distance_lod, seq_num_lod = fluid.layers.edit_distance(input=x_lod, label=y_lod)
# using Tensor
x_seq_len = 5
y_seq_len = 6
x_pad = fluid.data(name='x_pad', shape=[None,x_seq_len], dtype='int64')
y_pad = fluid.data(name='y_pad', shape=[None,y_seq_len], dtype='int64')
x_len = fluid.data(name='x_len', shape=[None], dtype='int64')
y_len = fluid.data(name='y_len', shape=[None], dtype='int64')
distance_pad, seq_num_pad = fluid.layers.edit_distance(input=x_pad, label=y_pad, input_length=x_len, label_length=y_len)
"""
helper = LayerHelper("edit_distance", **locals())
# remove some tokens from input and labels
if ignored_tokens is not None and len(ignored_tokens) > 0:
erased_input = helper.create_variable_for_type_inference(dtype="int64")
erased_label = helper.create_variable_for_type_inference(dtype="int64")
helper.append_op(
type="sequence_erase",
inputs={"X": [input]},
outputs={"Out": [erased_input]},
attrs={"tokens": ignored_tokens})
input = erased_input
helper.append_op(
type="sequence_erase",
inputs={"X": [label]},
outputs={"Out": [erased_label]},
attrs={"tokens": ignored_tokens})
label = erased_label
this_inputs = {"Hyps": [input], "Refs": [label]}
if input_length and label_length:
this_inputs['HypsLength'] = [input_length]
this_inputs['RefsLength'] = [label_length]
# edit distance op
edit_distance_out = helper.create_variable_for_type_inference(dtype="int64")
sequence_num = helper.create_variable_for_type_inference(dtype="int64")
helper.append_op(
type="edit_distance",
inputs=this_inputs,
outputs={"Out": [edit_distance_out],
"SequenceNum": [sequence_num]},
attrs={"normalized": normalized})
return edit_distance_out, sequence_num
def ctc_greedy_decoder(input, def ctc_greedy_decoder(input,
blank, blank,
input_length=None, input_length=None,
...@@ -5198,531 +4774,6 @@ def ctc_greedy_decoder(input, ...@@ -5198,531 +4774,6 @@ def ctc_greedy_decoder(input,
return ctc_out, ctc_out_len return ctc_out, ctc_out_len
def warpctc(input,
label,
blank=0,
norm_by_times=False,
input_length=None,
label_length=None):
"""
An operator integrating the open source Warp-CTC library
(https://github.com/baidu-research/warp-ctc)
to compute Connectionist Temporal Classification (CTC) loss.
It can be aliased as softmax with CTC, since a native softmax activation is
interated to the Warp-CTC library to normlize values for each row of the
input tensor.
Args:
input (Variable): The unscaled probabilities of variable-length sequences,
which is a 2-D Tensor with LoD information, or a 3-D Tensor without Lod
information. When it is a 2-D LodTensor, it's shape is
[Lp, num_classes + 1], where Lp is the sum of all input
sequences' length and num_classes is the true number of classes.
(not including the blank label). When it is a 3-D Tensor, it's shape
is [max_logit_length, batch_size, num_classes + 1],
where max_logit_length is the length of the longest
input logit sequence. The data type must be float32.
label (Variable): The ground truth of variable-length sequence,
which is a 2-D Tensor with LoD information or a 2-D Tensor without
LoD information. When it is a 2-D LoDTensor or 2-D Tensor,
it is of the shape [Lg, 1], where Lg is th sum of all labels' length.
The data type must be int32.
blank (int, default 0): The blank label index of Connectionist
Temporal Classification (CTC) loss, which is in the
half-opened interval [0, num_classes + 1). The data type must be int32.
norm_by_times(bool, default false): Whether to normalize the gradients
by the number of time-step, which is also the sequence's length.
There is no need to normalize the gradients if warpctc layer was
follewed by a mean_op.
input_length(Variable): The length for each input sequence if it is
of Tensor type, it should have shape `[batch_size]` and dtype int64.
label_length(Variable): The length for each label sequence if it is
of Tensor type, it should have shape `[batch_size]` and dtype int64.
Returns:
Variable: The Connectionist Temporal Classification (CTC) loss,
which is a 2-D Tensor with the shape [batch_size, 1].
The date type is the same as input.
Examples:
.. code-block:: python
# using LoDTensor
import paddle.fluid as fluid
import numpy as np
predict = fluid.data(name='predict',
shape=[None, 5],
dtype='float32',lod_level=1)
label = fluid.data(name='label', shape=[None, 1],
dtype='int32', lod_level=1)
cost = fluid.layers.warpctc(input=predict, label=label)
place = fluid.CPUPlace()
x=fluid.LoDTensor()
data = np.random.rand(8, 5).astype("float32")
x.set(data, place)
x.set_lod([[0,4,8]])
y=fluid.LoDTensor()
data = np.random.randint(0, 5, [4, 1]).astype("int32")
y.set(data, place)
y.set_lod([[0,2,4]])
exe = fluid.Executor(place)
exe.run(fluid.default_startup_program())
output= exe.run(feed={"predict": x,"label": y},
fetch_list=[cost.name])
print output
.. code-block:: python
# using Tensor
import paddle.fluid as fluid
import numpy as np
# length of the longest logit sequence
max_seq_length = 5
# number of logit sequences
batch_size = None
logits = fluid.data(name='logits',
shape=[max_seq_length, batch_size, 5],
dtype='float32')
logits_length = fluid.data(name='logits_length', shape=[None],
dtype='int64')
label = fluid.layers.data(name='label', shape=[None, 1],
dtype='int32')
label_length = fluid.layers.data(name='labels_length', shape=[None],
dtype='int64')
cost = fluid.layers.warpctc(input=logits, label=label,
input_length=logits_length,
label_length=label_length)
place = fluid.CPUPlace()
batch_size = 2
x = np.random.rand(max_seq_length, batch_size, 5).astype("float32")
y = np.random.randint(0, 5, [max_seq_length * batch_size, 1]).astype("int32")
exe = fluid.Executor(place)
exe.run(fluid.default_startup_program())
output= exe.run(feed={"logits": x,
"label": y,
"logits_length": np.array([5, 4]).astype("int64"),
"labels_length": np.array([3, 2]).astype("int64")},
fetch_list=[cost.name])
print(output)
"""
helper = LayerHelper('warpctc', **locals())
this_inputs = {'Logits': [input], 'Label': [label]}
if input_length and label_length:
this_inputs['LogitsLength'] = [input_length]
this_inputs['LabelLength'] = [label_length]
loss_out = helper.create_variable_for_type_inference(dtype=input.dtype)
grad_out = helper.create_variable_for_type_inference(dtype=input.dtype)
helper.append_op(
type='warpctc',
inputs=this_inputs,
outputs={'WarpCTCGrad': [grad_out],
'Loss': [loss_out]},
attrs={
'blank': blank,
'norm_by_times': norm_by_times,
})
return loss_out
# FIXME(wuyi): let docstring_checker.py understand @autodoc.
# For now, the comments in c++ use types like Tensor, but in python side
# the type is often "Variable", and arguments may vary.
@templatedoc(op_type="nce")
def nce(input,
label,
num_total_classes,
sample_weight=None,
param_attr=None,
bias_attr=None,
num_neg_samples=None,
name=None,
sampler="uniform",
custom_dist=None,
seed=0,
is_sparse=False):
"""
${comment}
Args:
input (Variable): Input variable, 2-D tensor with shape [batch_size, dim],
and data type is float32 or float64.
label (Variable): Input label, 2-D tensor with shape [batch_size, num_true_class],
and data type is int64.
num_total_classes (int):${num_total_classes_comment}.
sample_weight (Variable|None): A Variable of shape [batch_size, 1]
storing a weight for each sample. The default weight for each
sample is 1.0.
param_attr (ParamAttr|None): To specify the weight parameter attribute.
Default: None, which means the default weight parameter property is
used. See usage for details in :ref:`api_fluid_ParamAttr` .
bias_attr (ParamAttr|None): To specify the bias parameter attribute.
Default: None, which means the default bias parameter property is
used. See usage for details in :ref:`api_fluid_ParamAttr` .
num_neg_samples (int): ${num_neg_samples_comment}.
name(str|None): For detailed information, please refer to
:ref:`api_guide_Name` . Usually name is no need to set and None by default.
sampler (str, optional): The sampler used to sample class from negtive classes.
It can be 'uniform', 'log_uniform' or 'custom_dist'.
default: 'uniform'.
custom_dist (nd.array|None): A numpy ndarray with size=num_total_classes.
It is used when sampler is set to 'custom_dist'.
custom_dist[i] is the probsbility of i-th class to be sampled.
default: None.
seed (int, optional): The seed used in sampler. Default 0, means no random seed.
is_sparse(bool, optional): The flag indicating whether to use sparse update,
the weight@GRAD and bias@GRAD will be changed to SelectedRows. Default False.
Returns:
Variable: The output nce loss.
Examples:
.. code-block:: python
import paddle.fluid as fluid
import numpy as np
window_size = 5
words = []
for i in xrange(window_size):
words.append(fluid.data(
name='word_{0}'.format(i), shape=[-1, 1], dtype='int64'))
dict_size = 10000
label_word = int(window_size / 2) + 1
embs = []
for i in xrange(window_size):
if i == label_word:
continue
emb = fluid.layers.embedding(input=words[i], size=[dict_size, 32],
param_attr='embed', is_sparse=True)
embs.append(emb)
embs = fluid.layers.concat(input=embs, axis=1)
loss = fluid.layers.nce(input=embs, label=words[label_word],
num_total_classes=dict_size, param_attr='nce.w_0',
bias_attr='nce.b_0')
#or use custom distribution
dist = np.array([0.05,0.5,0.1,0.3,0.05])
loss = fluid.layers.nce(input=embs, label=words[label_word],
num_total_classes=5, param_attr='nce.w_1',
bias_attr='nce.b_1',
num_neg_samples=3,
sampler="custom_dist",
custom_dist=dist)
"""
helper = LayerHelper('nce', **locals())
if not isinstance(input, Variable):
raise TypeError(
"The type of 'input' in nce layer must be Variable, but received %s"
% (type(input)))
if not isinstance(label, Variable):
raise TypeError(
"The type of 'label' in nce layer must be Variable, but received %s"
% (type(label)))
if convert_dtype(input.dtype) not in ['float32', 'float64']:
raise TypeError(
"The data type of 'input' in nce layer must be float32 or float64, but received %s."
% (convert_dtype(input.dtype)))
if convert_dtype(label.dtype) not in ['int64']:
raise TypeError(
"The data type of 'label' in nce layer must be int64, but received %s."
% (convert_dtype(label.dtype)))
dim = input.shape[1]
num_true_class = label.shape[1]
w = helper.create_parameter(
attr=helper.param_attr,
shape=[num_total_classes, dim],
is_bias=False,
dtype=input.dtype)
inputs = {}
if helper.bias_attr:
b = helper.create_parameter(
attr=helper.bias_attr,
shape=[num_total_classes, 1],
is_bias=True,
dtype=input.dtype)
inputs['Bias'] = b
cost = helper.create_variable_for_type_inference(dtype=input.dtype)
sample_logits = helper.create_variable_for_type_inference(dtype=input.dtype)
sample_labels = helper.create_variable_for_type_inference(dtype=label.dtype)
inputs['Input'] = input
inputs['Label'] = label
inputs['Weight'] = w
inputs['SampleWeight'] = sample_weight if sample_weight is not None else []
if sampler == "uniform":
sampler = 0
elif sampler == "log_uniform":
sampler = 1
elif sampler == "custom_dist":
assert custom_dist is not None
# assert isinstance(custom_dist, Variable)
custom_dist_len = num_total_classes
alias_probs_ = [0] * custom_dist_len
alias_ = [0] * custom_dist_len
bigs = []
littles = []
for i in range(custom_dist_len):
normal_prob = custom_dist[i] * custom_dist_len
if normal_prob - 1.0 > 0:
bigs.append((i, normal_prob))
elif 1.0 - normal_prob > 0:
littles.append((i, normal_prob))
else:
alias_probs_[i] = normal_prob
alias_[i] = -1
while len(bigs) and len(littles):
big = bigs.pop(0)
little = littles.pop(0)
big_idx = big[0]
big_prob = big[1]
alias_probs_[little[0]] = little[1]
alias_[little[0]] = big_idx
big_left = big[1] + little[1] - 1
if big_left - 1.0 > 0:
bigs.append((big_idx, big_left))
elif 1.0 - big_left > 0:
littles.append((big_idx, big_left))
else:
alias_probs_[big_idx] = big_left
alias_[big_idx] = -1
if len(bigs):
big = bigs.pop(0)
alias_probs_[big[0]] = 1.0
alias_[big[0]] = -1
if len(littles):
little = littles.pop(0)
alias_probs_[little[0]] = 1.0
alias_[little[0]] = -1
def _init_by_numpy_array(numpy_array):
ret = helper.create_parameter(
attr=ParamAttr(),
shape=numpy_array.shape,
dtype=numpy_array.dtype,
default_initializer=NumpyArrayInitializer(numpy_array))
ret.stop_gradient = True
return ret
inputs['CustomDistProbs'] = _init_by_numpy_array(
np.array(custom_dist).astype('float32'))
inputs['CustomDistAlias'] = _init_by_numpy_array(
np.array(alias_).astype('int32'))
inputs['CustomDistAliasProbs'] = _init_by_numpy_array(
np.array(alias_probs_).astype('float32'))
sampler = 2
else:
raise Exception("Unsupported sampler type.")
if num_neg_samples is None:
num_neg_samples = 10
else:
num_neg_samples = int(num_neg_samples)
remote_prefetch = is_sparse
print(
"With sparse mode, if your models has only small parameter prefetch may cause speed down"
)
attrs = {
'num_total_classes': int(num_total_classes),
'num_neg_samples': num_neg_samples,
'seed': seed,
'sampler': sampler,
'is_sparse': is_sparse,
'remote_prefetch': remote_prefetch
}
helper.append_op(
type='nce',
inputs=inputs,
outputs={
'Cost': cost,
'SampleLogits': sample_logits,
'SampleLabels': sample_labels
},
attrs=attrs)
return cost / (num_neg_samples + 1)
def hsigmoid(input,
label,
num_classes,
param_attr=None,
bias_attr=None,
name=None,
path_table=None,
path_code=None,
is_custom=False,
is_sparse=False):
"""
The hierarchical sigmoid organizes the classes into a complete binary tree to reduce the computational complexity
and speed up the model training, especially the training of language model.
Each leaf node of the complete binary tree represents a class(word) and each non-leaf node acts as a binary classifier.
For each class(word), there's a unique path from root to itself, hsigmoid calculate the cost for each non-leaf node on
the path, and sum them to get a total cost.
Comparing to softmax, the OP can reduce the computational complexity from :math:`O(N)` to :math:`O(logN)`, where :math:`N`
represents the number of classes or the size of word dict.
The OP supports default tree and custom tree. For the default tree, you can refer to `Hierarchical Probabilistic Neural
Network Language Model <http://www.iro.umontreal.ca/~lisa/pointeurs/hierarchical-nnlm-aistats05.pdf>`. For the custom
tree, you need to set :attr:`is_custom` to True, and do the following steps (take the language model as an example):
1. Using a custom word dict to build a binary tree, each leaf node should be an word in the word dict.
2. Creating a dict map word_id -> path that from the word to the root node, we call it path_table.
3. Creating a dict map word_id -> code of path that from the word to the root node, we call it path_code.
Code means the label of each binary classifier, 1 indicate true, 0 indicate false.
4. Now, each word should has its path and code along the path, you can pass a batch of path and code related
to the same batch of inputs.
Parameters:
input (Variable): A tensor with the shape [N, D], where N is the size of mini-batch,
and D is the feature size. Its data type supports float32 and float64.
label (Variable): A tensor contains the labels of training data. Its shape is [N, 1]
and data type is int64.
num_classes (int): The number of classes or the size of word dict, must be greater than 2.
If the default tree is used (:attr:`is_custom` is set to False), :attr:`num_classes`
should not be None. If the custom tree is used (:attr:`is_custom` is set to True),
:attr:`num_classes` should be the number of non-leaf nodes, which indicates the num of
classes using by the binary classifier.
param_attr (ParamAttr, optional): The parameter attribute for the learnable parameters/weights
of hsigmoid. If it is set to None or one attribute of ParamAttr, hsigmoid will create a
ParamAttr as param_attr. If the Initializer of the param_attr is not set, the parameter is
initialized with Xavier. Default: None.
bias_attr (ParamAttr|bool, optional): The parameter attribute for the bias of hsigmoid. If it
is set to False, no bias will be added. If it is set to None or one attribute of ParamAttr,
hsigmoid will create a ParamAttr as bias_attr. If the Initializer of the bias_attr is not
set, the bias is initialized zero. Default: None.
name (str, optional): Normally there is no need for user to set this property. For more information,
please refer to :ref:`api_guide_Name`. Default: None.
path_table (Variable, optional): A tensor that stores each batch of samples' path from leaf to root
node, its shape is [N, L] and data type is int64, where L is the length of path. For each sample i,
path_table[i] is a np.array like structure and each element in this array is the indexes in parent
nodes' weight matrix. Default: None.
path_code (Variable, optional): A tensor that stores each batch of samples' code of path from leaf
to root node, its shape is [N, L] and data type is int64, which is the same as :attr:`path_table`.
Each code of path is consisted with the code of nodes from leaf to root node. Default: None.
is_custom (bool, optional): Whether use custom binary tree. If it's True, :attr:`path_table`,
:attr:`path_code` and :attr:`num_classes` should be set, otherwise :attr:`num_classes` should
be set. Default: False.
is_sparse (bool, optional): Whether use sparse updating instead of dense updating, if it's True, the
gradient of W and input will be sparse. Default: False.
Returns:
Variable: A tensor with the cost of hierarchical sigmoid, its shape is [N, 1] and data type is the same as :attr:`input`.
Examples:
.. code-block:: python
import paddle.fluid as fluid
x = fluid.layers.fill_constant(shape=[4, 3], value=0.9, dtype='float32')
# x = [[0.9, 0.9, 0.9], [0.9, 0.9, 0.9], [0.9, 0.9, 0.9], [0.9, 0.9, 0.9]]
y = fluid.layers.fill_constant(
shape=[4, 1], value=1, dtype='int64')
# y = [[1], [1], [1], [1]]
out = fluid.layers.hsigmoid(input=x, label=y, num_classes=2, param_attr=fluid.initializer.Constant(
value=0.05), bias_attr=fluid.initializer.Constant(value=.0))
# out = [[0.62792355], [0.62792355], [0.62792355], [0.62792355]]
"""
helper = LayerHelper('hierarchical_sigmoid', **locals())
dtype = helper.input_dtype()
out = helper.create_variable_for_type_inference(dtype)
pre_out = helper.create_variable_for_type_inference(dtype)
dim = input.shape[1]
if ((num_classes is None) or (num_classes < 2)) and (not is_custom):
raise ValueError(
"num_classes must not be less than 2 with default tree")
if (not is_custom) and (is_sparse):
print("Sparse mode should not be used without custom tree")
is_sparse = False
if (not is_custom) and ((path_table is not None) or
(path_code is not None)):
raise ValueError(
"only num_classes should be passed without custom tree")
if (is_custom) and (path_code is None):
raise ValueError("path_code should not be None with custom tree")
elif (is_custom) and (path_table is None):
raise ValueError("path_table should not be None with custom tree")
elif (is_custom) and (num_classes is None):
raise ValueError("num_classes should not be None with custom tree")
else:
pass
weights = None
remote_prefetch = is_sparse
print(
"With sparse mode, if your models has only small parameter prefetch may cause speed down"
)
if not is_custom:
weights = helper.create_parameter(
attr=helper.param_attr,
shape=[num_classes - 1, dim],
is_bias=False,
dtype=input.dtype)
else:
weights = helper.create_parameter(
attr=helper.param_attr,
shape=[num_classes, dim],
is_bias=False,
dtype=input.dtype)
inputs = {
"X": input,
"W": weights,
"PathTable": path_table,
"PathCode": path_code,
"Label": label
}
if helper.bias_attr:
if not is_custom:
bias = helper.create_parameter(
attr=helper.bias_attr,
shape=[num_classes - 1, 1],
is_bias=True,
dtype=input.dtype)
inputs['Bias'] = bias
else:
bias = helper.create_parameter(
attr=helper.bias_attr,
shape=[num_classes, 1],
is_bias=True,
dtype=input.dtype)
inputs['Bias'] = bias
helper.append_op(
type="hierarchical_sigmoid",
inputs=inputs,
outputs={"Out": out,
"PreOut": pre_out,
"W_Out": weights},
attrs={
"num_classes": num_classes,
"is_sparse": is_sparse,
"remote_prefetch": remote_prefetch
})
return out
def transpose(x, perm, name=None): def transpose(x, perm, name=None):
""" """
Permute the data dimensions of `input` according to `perm`. Permute the data dimensions of `input` according to `perm`.
...@@ -6054,270 +5105,20 @@ def multiplex(inputs, index): ...@@ -6054,270 +5105,20 @@ def multiplex(inputs, index):
res = exe.run(fluid.default_main_program(), feed={'x1':img1, 'x2':img2, 'index':index}, fetch_list=[out]) res = exe.run(fluid.default_main_program(), feed={'x1':img1, 'x2':img2, 'index':index}, fetch_list=[out])
print(res) # [array([[5., 6.], [3., 4.]], dtype=float32)] print(res) # [array([[5., 6.], [3., 4.]], dtype=float32)]
""" """
helper = LayerHelper('multiplex', **locals()) helper = LayerHelper('multiplex', **locals())
if not isinstance(inputs, list) and len(inputs) < 2:
raise ValueError("inputs should be a list object and contains at least "
"2 elements.")
out = helper.create_variable_for_type_inference(inputs[0].dtype)
helper.append_op(
type='multiplex',
inputs={'X': inputs,
'Ids': index},
outputs={'Out': [out]})
return out
def softmax_with_cross_entropy(logits,
label,
soft_label=False,
ignore_index=kIgnoreIndex,
numeric_stable_mode=True,
return_softmax=False,
axis=-1):
"""
This operator implements the cross entropy loss function with softmax. This function
combines the calculation of the softmax operation and the cross entropy loss function
to provide a more numerically stable gradient.
Because this operator performs a softmax on logits internally, it expects
unscaled logits. This operator should not be used with the output of
softmax operator since that would produce incorrect results.
When the attribute :attr:`soft_label` is set :attr:`False`, this operators
expects mutually exclusive hard labels, each sample in a batch is in exactly
one class with a probability of 1.0. Each sample in the batch will have a
single label.
The equation is as follows:
1) Hard label (one-hot label, so every sample has exactly one class)
.. math::
loss_j = -\\text{logits}_{label_j} +
\\log\\left(\\sum_{i=0}^{K}\\exp(\\text{logits}_i)\\right), j = 1,..., K
2) Soft label (each sample can have a distribution over all classes)
.. math::
loss_j = -\\sum_{i=0}^{K}\\text{label}_i
\\left(\\text{logits}_i - \\log\\left(\\sum_{i=0}^{K}
\\exp(\\text{logits}_i)\\right)\\right), j = 1,...,K
3) If :attr:`numeric_stable_mode` is :attr:`True`, softmax is calculated first by:
.. math::
max_j &= \\max_{i=0}^{K}{\\text{logits}_i}
log\\_max\\_sum_j &= \\log\\sum_{i=0}^{K}\\exp(logits_i - max_j)
softmax_j &= \\exp(logits_j - max_j - {log\\_max\\_sum}_j)
and then cross entropy loss is calculated by softmax and label.
Args:
logits (Variable): A multi-dimension ``Tensor`` , and the data type is float32 or float64. The input tensor of unscaled log probabilities.
label (Variable): The ground truth ``Tensor`` , data type is the same
as the ``logits`` . If :attr:`soft_label` is set to :attr:`True`,
Label is a ``Tensor`` in the same shape with :attr:`logits`.
If :attr:`soft_label` is set to :attr:`True`, Label is a ``Tensor``
in the same shape with :attr:`logits` expect shape in dimension :attr:`axis` as 1.
soft_label (bool, optional): A flag to indicate whether to interpretate the given
labels as soft labels. Default False.
ignore_index (int, optional): Specifies a target value that is ignored and does
not contribute to the input gradient. Only valid
if :attr:`soft_label` is set to :attr:`False`.
Default: kIgnoreIndex(-100).
numeric_stable_mode (bool, optional): A flag to indicate whether to use a more
numerically stable algorithm. Only valid
when :attr:`soft_label` is :attr:`False`
and GPU is used. When :attr:`soft_label`
is :attr:`True` or CPU is used, the
algorithm is always numerically stable.
Note that the speed may be slower when use
stable algorithm. Default: True.
return_softmax (bool, optional): A flag indicating whether to return the softmax
along with the cross entropy loss. Default: False.
axis (int, optional): The index of dimension to perform softmax calculations. It
should be in range :math:`[-1, rank - 1]`, while :math:`rank`
is the rank of input :attr:`logits`. Default: -1.
Returns:
``Variable`` or Tuple of two ``Variable`` : Return the cross entropy loss if \
`return_softmax` is False, otherwise the tuple \
(loss, softmax), softmax is in the same shape \
with input logits and cross entropy loss is in \
the same shape with input logits except shape \
in dimension :attr:`axis` as 1.
Examples:
.. code-block:: python
import paddle.fluid as fluid
data = fluid.data(name='data', shape=[-1, 128], dtype='float32')
label = fluid.data(name='label', shape=[-1, 1], dtype='int64')
fc = fluid.layers.fc(input=data, size=100)
out = fluid.layers.softmax_with_cross_entropy(
logits=fc, label=label)
"""
helper = LayerHelper('softmax_with_cross_entropy', **locals())
softmax = helper.create_variable_for_type_inference(dtype=logits.dtype)
loss = helper.create_variable_for_type_inference(dtype=logits.dtype)
helper.append_op(
type='softmax_with_cross_entropy',
inputs={'Logits': logits,
'Label': label},
outputs={'Softmax': softmax,
'Loss': loss},
attrs={
'soft_label': soft_label,
'ignore_index': ignore_index,
'numeric_stable_mode': numeric_stable_mode,
'axis': axis
})
if return_softmax:
return loss, softmax
return loss
def sampled_softmax_with_cross_entropy(logits,
label,
num_samples,
num_true=1,
remove_accidental_hits=True,
use_customized_samples=False,
customized_samples=None,
customized_probabilities=None,
seed=0):
"""
**Sampled Softmax With Cross Entropy Operator.**
Cross entropy loss with sampled softmax is used as the output layer for
larger output classes extensively. This operator samples a number of samples
for all examples, and computes the softmax normalized values for each
row of the sampled tensor, after which cross-entropy loss is computed.
Because this operator performs a softmax on logits internally, it expects
unscaled logits. This operator should not be used with the output of
softmax operator since that would produce incorrect results.
For examples with T true labels (T >= 1), we assume that each true label has
a probability of 1/T. For each sample, S samples are generated using a
log uniform distribution. True labels are concatenated with these samples to
form T + S samples for each example. So, assume the shape of logits is
[N x K], the shape for samples is [N x (T+S)]. For each sampled label, a
probability is calculated, which corresponds to the Q(y|x) in
[Jean et al., 2014](http://arxiv.org/abs/1412.2007).
Logits are sampled according to the sampled labels. Then if
remove_accidental_hits is True, if a sample[i, j] accidentally hits true
labels, then the corresponding sampled_logits[i, j] is minus by 1e20 to
make its softmax result close to zero. Then sampled logits are subtracted by
logQ(y|x), these sampled logits and re-indexed labels are used to compute
a softmax with cross entropy.
Args:
logits (Variable): The unscaled log probabilities, which is a 2-D tensor
with shape [N x K]. N is the batch_size, and K is the class number.
label (Variable): The ground truth which is a 2-D tensor. Label is a
Tensor<int64> with shape [N x T], where T is the number of true
labels per example.
num_samples (int): The number for each example, num_samples should be
less than the number of class.
num_true(int): The number of target classes per training example.
remove_accidental_hits (bool): A flag indicating whether to remove
accidental hits when sampling. If True and if a sample[i, j]
accidentally hits true labels, then the corresponding
sampled_logits[i, j] is minus by 1e20 to make its softmax result
close to zero. Default is True.
use_customized_samples (bool): Whether to use custom samples and probabities to sample
logits.
customized_samples (Variable): User defined samples, which is a 2-D tensor
with shape [N, T + S]. S is the num_samples, and T is the number of true
labels per example.
customized_probabilities (Variable): User defined probabilities of samples,
a 2-D tensor which has the same shape with customized_samples.
seed (int): The random seed for generating random number, which is used
in the process of sampling. Default is 0.
Returns:
Variable: Return the cross entropy loss which is a 2-D tensor with shape
[N x 1].
Examples:
.. code-block:: python
import paddle.fluid as fluid
input = fluid.layers.data(name='data', shape=[256], dtype='float32')
label = fluid.layers.data(name='label', shape=[1], dtype='int64')
fc = fluid.layers.fc(input=input, size=100)
out = fluid.layers.sampled_softmax_with_cross_entropy(
logits=fc, label=label, num_samples=25)
"""
helper = LayerHelper('sample_logits', **locals())
samples = helper.create_variable_for_type_inference(dtype='int64')
probabilities = helper.create_variable_for_type_inference(
dtype=logits.dtype)
sampled_logits \
= helper.create_variable_for_type_inference(dtype=logits.dtype)
sampled_label = helper.create_variable_for_type_inference(dtype='int64')
sampled_softlabel = helper.create_variable_for_type_inference(
dtype=logits.dtype)
logits_dim = helper.create_variable_for_type_inference(dtype=logits.dtype)
labels_dim = helper.create_variable_for_type_inference(dtype=label.type)
helper.append_op( if not isinstance(inputs, list) and len(inputs) < 2:
type='sample_logits', raise ValueError("inputs should be a list object and contains at least "
inputs={ "2 elements.")
'Logits': logits,
'Labels': label,
'CustomizedSamples': customized_samples,
'CustomizedProbabilities': customized_probabilities
},
outputs={
'Samples': samples,
'Probabilities': probabilities,
'SampledLabels': sampled_label,
'SampledLogits': sampled_logits,
'LogitsDim': logits_dim,
'LabelsDim': labels_dim
},
attrs={
'use_customized_samples': use_customized_samples,
'uniq': True,
'remove_accidental_hits': remove_accidental_hits,
'num_samples': num_samples,
'seed': seed
})
loss = helper.create_variable_for_type_inference(dtype=logits.dtype)
softmax = helper.create_variable_for_type_inference(dtype=logits.dtype)
helper.append_op(
type='one_hot',
inputs={'X': sampled_label},
attrs={'depth': num_samples + 1},
outputs={'Out': sampled_softlabel})
out = helper.create_variable_for_type_inference(inputs[0].dtype)
helper.append_op( helper.append_op(
type='softmax_with_cross_entropy', type='multiplex',
inputs={'Logits': sampled_logits, inputs={'X': inputs,
'Label': sampled_softlabel}, 'Ids': index},
outputs={'Softmax': softmax, outputs={'Out': [out]})
'Loss': loss}, return out
attrs={
'soft_label': True,
'ignore_index': False,
'numeric_stable_mode': False
})
return loss / num_true
def smooth_l1(x, y, inside_weight=None, outside_weight=None, sigma=None): def smooth_l1(x, y, inside_weight=None, outside_weight=None, sigma=None):
...@@ -9468,127 +8269,6 @@ def affine_grid(theta, out_shape, name=None): ...@@ -9468,127 +8269,6 @@ def affine_grid(theta, out_shape, name=None):
return out return out
def rank_loss(label, left, right, name=None):
"""
This operator implements the sort loss layer in the RankNet model. RankNet is a pairwise ranking model
with a training sample consisting of a pair of documents (A and B), The label (P)
indicates whether A is ranked higher than B or not. Please refer to more details:
`RankNet <http://icml.cc/2015/wp-content/uploads/2015/06/icml_ranking.pdf>`_
Rank loss layer takes three inputs: left ( :math:`o_i` ), right ( :math:`o_j` ) and
label ( :math:`P_{i,j}` ). The inputs respectively represent RankNet's output scores
for documents A and B and the value of label P. Rank loss layer takes batch inputs
with size batch_size (batch_size >= 1), P = {0, 1} or {0, 0.5, 1},
where 0.5 means that there is no information about the rank of the input pair.
The following equation computes rank loss C_{i,j} from the inputs:
.. math::
C_{i,j} &= -\\tilde{P_{ij}} * o_{i,j} + \log(1 + e^{o_{i,j}}) \\\\
.. math::
o_{i,j} &= o_i - o_j \\\\
.. math::
\\tilde{P_{i,j}} &= \\left \{0, 0.5, 1 \\right \} \ or \ \\left \{0, 1 \\right \}
Parameters:
label (Variable): 2-D ``Tensor`` with the shape of :math:`[batch,1]`, the data type is float32, batch indicates the size of the data. Indicats whether A ranked higher than B or not.
left (Variable): 2-D ``Tensor`` with the shape of :math:`[batch,1]`, the data type is float32. RankNet's output score for doc A.
right (Variable): 2-D ``Tensor`` with the shape of :math:`[batch,1]`, the data type is float32. RankNet's output score for doc B.
name(str|None): The default value is None. Normally there is no need for user to set this property. For more information, please refer to :ref:`api_guide_Name` .
Returns:
Variable: ``Tensor`` indicating the output value of the sort loss layer, the data type is float32, and the return value's shape is :math:`[batch,1]` .
Raises:
ValueError: Any of label, left, and right is not a ``Variable`` .
Examples:
.. code-block:: python
import paddle.fluid as fluid
label = fluid.data(name="label", shape=[-1, 1], dtype="float32")
left = fluid.data(name="left", shape=[-1, 1], dtype="float32")
right = fluid.data(name="right", shape=[-1, 1], dtype="float32")
out = fluid.layers.rank_loss(label, left, right)
"""
helper = LayerHelper('rank_loss', **locals())
if not (isinstance(label, Variable)):
raise ValueError("The label should be a Variable")
if not (isinstance(left, Variable)):
raise ValueError("The left should be a Variable")
if not (isinstance(right, Variable)):
raise ValueError("The right should be a Variable")
out = helper.create_variable_for_type_inference("float32")
helper.append_op(
type='rank_loss',
inputs={"Label": label,
"Left": left,
"Right": right},
outputs={'Out': out})
return out
def margin_rank_loss(label, left, right, margin=0.1, name=None):
"""
Margin Ranking Loss Layer for ranking problem,
which compares left score and right score passed in.
The ranking loss can be defined as following equation:
.. math::
rank\_loss = max(0, -label * (left - right) + margin)
Args:
label (Variable): Indicates whether the left is ranked higher than the right or not.
Data type is float32.
left (Variable): Ranking score for left. Data type float32.
right (Variable): Ranking score for right. Data type float32.
margin (float): Indicates the given margin.
name(str|None): For detailed information, please refer to
:ref:`api_guide_Name` . Usually name is no need to set and None by default.
Returns:
Variable: The ranking loss.
Raises:
ValueError: Any of label, left, and right is not a Variable.
Examples:
.. code-block:: python
import paddle.fluid as fluid
label = fluid.data(name="label", shape=[-1, 1], dtype="float32")
left = fluid.data(name="left", shape=[-1, 1], dtype="float32")
right = fluid.data(name="right", shape=[-1, 1], dtype="float32")
out = fluid.layers.margin_rank_loss(label, left, right)
"""
helper = LayerHelper('margin_rank_loss', **locals())
if not isinstance(label, Variable):
raise ValueError("The label should be a Variable.")
if not isinstance(left, Variable):
raise ValueError("The left should be a Variable.")
if not isinstance(right, Variable):
raise ValueError("The right should be a Variable.")
out = helper.create_variable_for_type_inference(left.dtype)
act = helper.create_variable_for_type_inference(left.dtype)
helper.append_op(
type='margin_rank_loss',
inputs={"Label": label,
"X1": left,
"X2": right},
outputs={'Out': out,
'Activated': act},
attrs={'margin': margin})
return out
def pad2d(input, def pad2d(input,
paddings=[0, 0, 0, 0], paddings=[0, 0, 0, 0],
mode='constant', mode='constant',
...@@ -12627,62 +11307,6 @@ def mul(x, y, x_num_col_dims=1, y_num_col_dims=1, name=None): ...@@ -12627,62 +11307,6 @@ def mul(x, y, x_num_col_dims=1, y_num_col_dims=1, name=None):
return out return out
@templatedoc()
def sigmoid_cross_entropy_with_logits(x,
label,
ignore_index=kIgnoreIndex,
name=None,
normalize=False):
"""
${comment}
Args:
x(${x_type}): ${x_comment}
label(${label_type}): ${label_comment}
ignore_index(int): ${ignore_index_comment}
name(str|None): The default value is None. Normally there is
no need for user to set this property. For more information,
please refer to :ref:`api_guide_Name`
normalize(bool): If true, divide the output by the number of
targets != ignore_index.
Returns:
out(${out_type}): ${out_comment}
Examples:
.. code-block:: python
import paddle.fluid as fluid
input = fluid.data(
name='data', shape=[10], dtype='float32')
label = fluid.data(
name='data', shape=[10], dtype='float32')
loss = fluid.layers.sigmoid_cross_entropy_with_logits(
x=input,
label=label,
ignore_index=-1,
normalize=True) # or False
# loss = fluid.layers.reduce_sum(loss) # summation of loss
"""
helper = LayerHelper("sigmoid_cross_entropy_with_logits", **locals())
if name is None:
out = helper.create_variable_for_type_inference(dtype=x.dtype)
else:
out = helper.create_variable(
name=name, dtype=x.dtype, persistable=False)
helper.append_op(
type="sigmoid_cross_entropy_with_logits",
inputs={"X": x,
"Label": label},
attrs={"ignore_index": ignore_index,
'normalize': normalize},
outputs={"Out": out})
return out
@templatedoc() @templatedoc()
def maxout(x, groups, name=None, axis=1): def maxout(x, groups, name=None, axis=1):
""" """
...@@ -13239,58 +11863,6 @@ def log_loss(input, label, epsilon=1e-4, name=None): ...@@ -13239,58 +11863,6 @@ def log_loss(input, label, epsilon=1e-4, name=None):
return loss return loss
def teacher_student_sigmoid_loss(input,
label,
soft_max_up_bound=15.0,
soft_max_lower_bound=-15.0):
"""
**Teacher Student Log Loss Layer**
This layer accepts input predictions and target label and returns the
teacher_student loss. Z is click or not, z' is value of teacher loss, label = {-2, -1, [0, 2]}
when z' is not exist, clk = 0 : label = -2; when z' is not exist, clk = 1 : label = -1;
when z' is exist , clk = 0 : label = 0 + z'; when z' is exist , clk = 1 : label = 1 + z'
.. math::
loss = max(x, 0) - x * z + log(1 + exp(-abs(x))) + max(x, 0) - x * z' + log(1 + exp(-abs(x)))
Args:
input (Variable|list): a 2-D tensor with shape [N x 1], where N is the
batch size. This input is a probability computed
by the previous operator.
label (Variable|list): the ground truth which is a 2-D tensor with
shape [N x 1], where N is the batch size.
soft_max_up_bound (float): if input > soft_max_up_bound, will be bound
soft_max_lower_bound (float): if input < soft_max_lower_bound, will be bound
Returns:
Variable: A 2-D tensor with shape [N x 1], the teacher_student_sigmoid_loss.
Examples:
.. code-block:: python
import paddle.fluid as fluid
batch_size = 64
label = fluid.data(
name="label", shape=[batch_size, 1], dtype="int64")
similarity = fluid.data(
name="similarity", shape=[batch_size, 1], dtype="float32")
cost = fluid.layers.teacher_student_sigmoid_loss(input=similarity, label=label)
"""
helper = LayerHelper('teacher_student_sigmoid_loss', **locals())
out = helper.create_variable(dtype=input.dtype)
helper.append_op(
type='teacher_student_sigmoid_loss',
inputs={'X': [input],
'Label': [label]},
outputs={'Y': [out]},
attrs={"soft_max_lower_bound": float(soft_max_lower_bound), \
"soft_max_up_bound": float(soft_max_up_bound)})
return out
def add_position_encoding(input, alpha, beta, name=None): def add_position_encoding(input, alpha, beta, name=None):
""" """
This operator performs weighted sum of input feature at each position This operator performs weighted sum of input feature at each position
...@@ -13945,165 +12517,6 @@ def prroi_pool(input, ...@@ -13945,165 +12517,6 @@ def prroi_pool(input,
return out return out
def huber_loss(input, label, delta):
"""
This operator computes the Huber loss between input and label.
Huber loss is commonly used in regression tasks. Compared to square_error_cost, Huber loss is more robust and less sensitivity to outliers.
When the absolute difference between input and label is greater than delta, the linear error is calculated:
.. math::
huber\_loss = delta * (label - input) - 0.5 * delta * delta
When the absolute difference between input and label is greater than delta, the square error is calculated:
.. math::
huber\_loss = 0.5 * (label - input) * (label - input)
Args:
input (Variable): Predicted data, 2D-Tensor with the shape of [batch_size, 1]. The data type should be float32 or float64.
label (Variable): Ground truth label, 2D-Tensor with the shape of [batch_size, 1]. The data type should be float32 or float64.
delta (float): The threshold for Huber loss, which is used to control the balance between the linear error and square error. The data type should be float32.
Returns:
Variable: The huber loss, a tensor with the same shape and data type as input.
Examples:
.. code-block:: python
import paddle.fluid as fluid
import numpy as np
DATATYPE='float32'
input_data = np.array([[1.],[2.],[3.],[4.]]).astype(DATATYPE)
label_data = np.array([[3.],[3.],[4.],[4.]]).astype(DATATYPE)
x = fluid.data(name='input', shape=[None, 1], dtype=DATATYPE)
y = fluid.data(name='label', shape=[None, 1], dtype=DATATYPE)
loss = fluid.layers.huber_loss(input=x, label=y, delta=1.0)
place = fluid.CPUPlace()
#place = fluid.CUDAPlace(0)
exe = fluid.Executor(place)
HuberLoss, = exe.run(feed={'input':input_data ,'label':label_data}, fetch_list=[loss.name])
print(HuberLoss) #[[1.5], [0.5], [0.5], [0. ]], dtype=float32
"""
helper = LayerHelper('huber_loss', **locals())
residual = helper.create_variable_for_type_inference(
dtype=helper.input_dtype())
out = helper.create_variable_for_type_inference(dtype=helper.input_dtype())
helper.append_op(
type='huber_loss',
inputs={'X': input,
'Y': label},
outputs={'Out': out,
'Residual': residual},
attrs={'delta': delta})
return out
@templatedoc()
def kldiv_loss(x, target, reduction='mean', name=None):
"""
${comment}
Args:
x (Variable): ${x_comment}
target (Variable): ${target_comment}
reduction (Variable): ${reduction_comment}
name(str, optional): For detailed information, please refer
to :ref:`api_guide_Name`. Usually name is no need to set and
None by default.
Returns:
Variable(Tensor): The KL divergence loss. The data type is same as input tensor
Examples:
.. code-block:: python
import paddle.fluid as fluid
x = fluid.data(name='x', shape=[None,4,2,2], dtype='float32')
target = fluid.layers.data(name='target', shape=[4,2,2], dtype='float32')
loss = fluid.layers.kldiv_loss(x=x, target=target, reduction='batchmean')
"""
helper = LayerHelper('kldiv_loss', **locals())
loss = helper.create_variable_for_type_inference(dtype=x.dtype)
helper.append_op(
type='kldiv_loss',
inputs={'X': x,
'Target': target},
outputs={'Loss': loss},
attrs={'reduction': reduction})
return loss
from .ops import square
from .control_flow import equal
def npair_loss(anchor, positive, labels, l2_reg=0.002):
'''
**Npair Loss Layer**
Read `Improved Deep Metric Learning with Multi class N pair Loss Objective\
<http://www.nec-labs.com/uploads/images/Department-Images/MediaAnalytics/\
papers/nips16_npairmetriclearning.pdf>`_ .
Npair loss requires paired data. Npair loss has two parts: the first part is L2
regularizer on the embedding vector; the second part is cross entropy loss which
takes the similarity matrix of anchor and positive as logits.
Args:
anchor(Variable): embedding vector for the anchor image. shape=[batch_size, embedding_dims],
the data type is float32 or float64.
positive(Variable): embedding vector for the positive image. shape=[batch_size, embedding_dims],
the data type is float32 or float64.
labels(Variable): 1-D tensor. shape=[batch_size], the data type is float32 or float64 or int64.
l2_reg(float32): L2 regularization term on embedding vector, default: 0.002.
Returns:
A Variable holding Tensor representing the npair loss, the data type is the same as
anchor, the shape is [1].
Examples:
.. code-block:: python
import paddle.fluid as fluid
anchor = fluid.data(
name = 'anchor', shape = [18, 6], dtype = 'float32')
positive = fluid.data(
name = 'positive', shape = [18, 6], dtype = 'float32')
labels = fluid.data(
name = 'labels', shape = [18], dtype = 'float32')
npair_loss = fluid.layers.npair_loss(anchor, positive, labels, l2_reg = 0.002)
'''
Beta = 0.25
batch_size = labels.shape[0]
labels = reshape(labels, shape=[batch_size, 1], inplace=True)
labels = expand(labels, expand_times=[1, batch_size])
labels = equal(labels, transpose(labels, perm=[1, 0])).astype('float32')
labels = labels / reduce_sum(labels, dim=1, keep_dim=True)
l2loss = reduce_mean(reduce_sum(square(anchor), 1)) \
+ reduce_mean(reduce_sum(square(positive), 1))
l2loss = l2loss * Beta * l2_reg
similarity_matrix = matmul(
anchor, positive, transpose_x=False, transpose_y=True)
softmax_ce = softmax_with_cross_entropy(
logits=similarity_matrix, label=labels, soft_label=True)
cross_entropy = reduce_sum(labels * softmax_ce, 0)
celoss = reduce_mean(cross_entropy)
return l2loss + celoss
def pixel_shuffle(x, upscale_factor): def pixel_shuffle(x, upscale_factor):
""" """
...@@ -15160,62 +13573,6 @@ def gather_tree(ids, parents): ...@@ -15160,62 +13573,6 @@ def gather_tree(ids, parents):
return out return out
def mse_loss(input, label):
"""
This op accepts input predications and target label and returns the mean square error.
The loss can be described as:
.. math::
Out = MEAN((input - label)^2)
Parameters:
input (Variable): Input tensor, the data type should be float32.
label (Variable): Label tensor, the data type shoulf be float32.
Returns:
Variable: The tensor variable storing the mean square error difference of input and label.
Return type: Variable.
Examples:
.. code-block:: python
# declarative mode
import paddle.fluid as fluid
import numpy as np
input = fluid.data(name="input", shape=[1])
label = fluid.data(name="label", shape=[1])
output = fluid.layers.mse_loss(input,label)
place = fluid.CPUPlace()
exe = fluid.Executor(place)
exe.run(fluid.default_startup_program())
input_data = np.array([1.5]).astype("float32")
label_data = np.array([1.7]).astype("float32")
output_data = exe.run(fluid.default_main_program(),
feed={"input":input_data, "label":label_data},
fetch_list=[output],
return_numpy=True)
print(output_data)
# [array([0.04000002], dtype=float32)]
# imperative mode
import paddle.fluid.dygraph as dg
with dg.guard(place) as g:
input = dg.to_variable(input_data)
label = dg.to_variable(label_data)
output = fluid.layers.mse_loss(input, label)
print(output.numpy())
# [0.04000002]
"""
return reduce_mean(square_error_cost(input, label))
@templatedoc() @templatedoc()
def uniform_random(shape, dtype='float32', min=-1.0, max=1.0, seed=0): def uniform_random(shape, dtype='float32', min=-1.0, max=1.0, seed=0):
""" """
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册