未验证 提交 19534daf 编写于 作者: X Xin Pan 提交者: GitHub

Merge pull request #15215 from velconia/local_release_1_2_x_add_huber_regression_loss_op

Add python interface for huber loss
...@@ -197,6 +197,7 @@ paddle.fluid.layers.bilinear_tensor_product ArgSpec(args=['x', 'y', 'size', 'act ...@@ -197,6 +197,7 @@ paddle.fluid.layers.bilinear_tensor_product ArgSpec(args=['x', 'y', 'size', 'act
paddle.fluid.layers.merge_selected_rows ArgSpec(args=['x', 'name'], varargs=None, keywords=None, defaults=(None,)) paddle.fluid.layers.merge_selected_rows ArgSpec(args=['x', 'name'], varargs=None, keywords=None, defaults=(None,))
paddle.fluid.layers.get_tensor_from_selected_rows ArgSpec(args=['x', 'name'], varargs=None, keywords=None, defaults=(None,)) paddle.fluid.layers.get_tensor_from_selected_rows ArgSpec(args=['x', 'name'], varargs=None, keywords=None, defaults=(None,))
paddle.fluid.layers.lstm ArgSpec(args=['input', 'init_h', 'init_c', 'max_len', 'hidden_size', 'num_layers', 'dropout_prob', 'is_bidirec', 'is_test', 'name', 'default_initializer', 'seed'], varargs=None, keywords=None, defaults=(0.0, False, False, None, None, -1)) paddle.fluid.layers.lstm ArgSpec(args=['input', 'init_h', 'init_c', 'max_len', 'hidden_size', 'num_layers', 'dropout_prob', 'is_bidirec', 'is_test', 'name', 'default_initializer', 'seed'], varargs=None, keywords=None, defaults=(0.0, False, False, None, None, -1))
paddle.fluid.layers.huber_loss ArgSpec(args=['input', 'label', 'delta'], varargs=None, keywords=None, defaults=None)
paddle.fluid.layers.data ArgSpec(args=['name', 'shape', 'append_batch_size', 'dtype', 'lod_level', 'type', 'stop_gradient'], varargs=None, keywords=None, defaults=(True, 'float32', 0, VarType.LOD_TENSOR, True)) paddle.fluid.layers.data ArgSpec(args=['name', 'shape', 'append_batch_size', 'dtype', 'lod_level', 'type', 'stop_gradient'], varargs=None, keywords=None, defaults=(True, 'float32', 0, VarType.LOD_TENSOR, True))
paddle.fluid.layers.open_files ArgSpec(args=['filenames', 'shapes', 'lod_levels', 'dtypes', 'thread_num', 'buffer_size', 'pass_num', 'is_test'], varargs=None, keywords=None, defaults=(None, None, 1, None)) paddle.fluid.layers.open_files ArgSpec(args=['filenames', 'shapes', 'lod_levels', 'dtypes', 'thread_num', 'buffer_size', 'pass_num', 'is_test'], varargs=None, keywords=None, defaults=(None, None, 1, None))
paddle.fluid.layers.read_file ArgSpec(args=['reader'], varargs=None, keywords=None, defaults=None) paddle.fluid.layers.read_file ArgSpec(args=['reader'], varargs=None, keywords=None, defaults=None)
......
...@@ -124,8 +124,9 @@ REGISTER_OPERATOR(huber_loss, ops::HuberLossOp, ops::HuberLossOpMaker<float>, ...@@ -124,8 +124,9 @@ REGISTER_OPERATOR(huber_loss, ops::HuberLossOp, ops::HuberLossOpMaker<float>,
paddle::framework::DefaultGradOpDescMaker<true>); paddle::framework::DefaultGradOpDescMaker<true>);
REGISTER_OPERATOR(huber_loss_grad, ops::HuberLossGradOp); REGISTER_OPERATOR(huber_loss_grad, ops::HuberLossGradOp);
REGISTER_OP_CPU_KERNEL( REGISTER_OP_CPU_KERNEL(
huber_loss, huber_loss, ops::HuberLossKernel<paddle::platform::CPUDeviceContext, float>,
ops::HuberLossKernel<paddle::platform::CPUDeviceContext, float>); ops::HuberLossKernel<paddle::platform::CPUDeviceContext, double>);
REGISTER_OP_CPU_KERNEL( REGISTER_OP_CPU_KERNEL(
huber_loss_grad, huber_loss_grad,
ops::HuberLossGradKernel<paddle::platform::CPUDeviceContext, float>); ops::HuberLossGradKernel<paddle::platform::CPUDeviceContext, float>,
ops::HuberLossGradKernel<paddle::platform::CPUDeviceContext, double>);
...@@ -172,6 +172,7 @@ __all__ = [ ...@@ -172,6 +172,7 @@ __all__ = [
'merge_selected_rows', 'merge_selected_rows',
'get_tensor_from_selected_rows', 'get_tensor_from_selected_rows',
'lstm', 'lstm',
'huber_loss',
] ]
...@@ -491,7 +492,7 @@ def lstm(input, ...@@ -491,7 +492,7 @@ def lstm(input,
If Device is GPU, This op will use cudnn LSTM implementation If Device is GPU, This op will use cudnn LSTM implementation
A four-gate Long Short-Term Memory network with no peephole connections. A four-gate Long Short-Term Memory network with no peephole connections.
In the forward pass the output ht and cell output ct for a given iteration can be computed from the recurrent input ht-1, In the forward pass the output ht and cell output ct for a given iteration can be computed from the recurrent input ht-1,
the cell input ct-1 and the previous layer input xt given matrices W, R and biases bW, bR from the following equations: the cell input ct-1 and the previous layer input xt given matrices W, R and biases bW, bR from the following equations:
$$ i_t = \\sigma(W_{ix}x_{t} + W_{ih}h_{t-1} + bx_i + bh_i) $$ $$ i_t = \\sigma(W_{ix}x_{t} + W_{ih}h_{t-1} + bx_i + bh_i) $$
...@@ -518,19 +519,19 @@ def lstm(input, ...@@ -518,19 +519,19 @@ def lstm(input,
- $\tilde{c_t}$ is also called candidate hidden state, - $\tilde{c_t}$ is also called candidate hidden state,
which is computed based on the current input and the previous hidden state. which is computed based on the current input and the previous hidden state.
Where sigmoid is the sigmoid operator: sigmoid(x) = 1 / (1 + e^-x), * represents a point-wise multiplication, Where sigmoid is the sigmoid operator: sigmoid(x) = 1 / (1 + e^-x), * represents a point-wise multiplication,
X represensts a matrix multiplication X represensts a matrix multiplication
Args: Args:
input (Variable): LSTM input tensor, shape MUST be ( seq_len x batch_size x input_size ) input (Variable): LSTM input tensor, shape MUST be ( seq_len x batch_size x input_size )
init_h(Variable): The initial hidden state of the LSTM init_h(Variable): The initial hidden state of the LSTM
This is a tensor with shape ( num_layers x batch_size x hidden_size) This is a tensor with shape ( num_layers x batch_size x hidden_size)
if is_bidirec = True, shape should be ( num_layers*2 x batch_size x hidden_size) if is_bidirec = True, shape should be ( num_layers*2 x batch_size x hidden_size)
init_c(Variable): The initial cell state of the LSTM. init_c(Variable): The initial cell state of the LSTM.
This is a tensor with shape ( num_layers x batch_size x hidden_size ) This is a tensor with shape ( num_layers x batch_size x hidden_size )
if is_bidirec = True, shape should be ( num_layers*2 x batch_size x hidden_size) if is_bidirec = True, shape should be ( num_layers*2 x batch_size x hidden_size)
max_len (int): max length of LSTM. the first dim of input tensor CAN NOT greater than max_len max_len (int): max length of LSTM. the first dim of input tensor CAN NOT greater than max_len
hidden_size (int): hidden size of the LSTM hidden_size (int): hidden size of the LSTM
num_layers (int): total layers number of the LSTM num_layers (int): total layers number of the LSTM
dropout_prob(float|0.0): dropout prob, dropout ONLY work between rnn layers, NOT between time steps dropout_prob(float|0.0): dropout prob, dropout ONLY work between rnn layers, NOT between time steps
...@@ -549,10 +550,10 @@ def lstm(input, ...@@ -549,10 +550,10 @@ def lstm(input,
if is_bidirec set to True, shape will be ( seq_len x batch_sze x hidden_size*2) if is_bidirec set to True, shape will be ( seq_len x batch_sze x hidden_size*2)
last_h(Tensor): the hidden state of the last step of LSTM last_h(Tensor): the hidden state of the last step of LSTM
shape is ( num_layers x batch_size x hidden_size ) shape is ( num_layers x batch_size x hidden_size )
if is_bidirec set to True, shape will be ( num_layers*2 x batch_size x hidden_size) if is_bidirec set to True, shape will be ( num_layers*2 x batch_size x hidden_size)
last_c(Tensor): the cell state of the last step of LSTM last_c(Tensor): the cell state of the last step of LSTM
shape is ( num_layers x batch_size x hidden_size ) shape is ( num_layers x batch_size x hidden_size )
if is_bidirec set to True, shape will be ( num_layers*2 x batch_size x hidden_size) if is_bidirec set to True, shape will be ( num_layers*2 x batch_size x hidden_size)
Examples: Examples:
...@@ -4390,7 +4391,7 @@ def ctc_greedy_decoder(input, blank, name=None): ...@@ -4390,7 +4391,7 @@ def ctc_greedy_decoder(input, blank, name=None):
[0.5, 0.1, 0.3, 0.1]] [0.5, 0.1, 0.3, 0.1]]
input.lod = [[4, 4]] input.lod = [[4, 4]]
Computation: Computation:
step1: Apply argmax to first input sequence which is input.data[0:4]. Then we get: step1: Apply argmax to first input sequence which is input.data[0:4]. Then we get:
...@@ -4423,7 +4424,7 @@ def ctc_greedy_decoder(input, blank, name=None): ...@@ -4423,7 +4424,7 @@ def ctc_greedy_decoder(input, blank, name=None):
Returns: Returns:
Variable: CTC greedy decode result which is a 2-D tensor with shape [Lp, 1]. Variable: CTC greedy decode result which is a 2-D tensor with shape [Lp, 1].
'Lp' is the sum if all output sequences' length. If all the sequences 'Lp' is the sum if all output sequences' length. If all the sequences
in result were empty, the result LoDTensor will be [-1] with in result were empty, the result LoDTensor will be [-1] with
LoD [[]] and dims [1, 1]. LoD [[]] and dims [1, 1].
Examples: Examples:
...@@ -4777,7 +4778,7 @@ def hsigmoid(input, ...@@ -4777,7 +4778,7 @@ def hsigmoid(input,
""" """
The hierarchical sigmoid operator is used to accelerate the training The hierarchical sigmoid operator is used to accelerate the training
process of language model. This operator organizes the classes into a process of language model. This operator organizes the classes into a
complete binary tree, or you can use is_custom to pass your own tree to complete binary tree, or you can use is_custom to pass your own tree to
implement hierarchical. Each leaf node represents a class(a word) and each implement hierarchical. Each leaf node represents a class(a word) and each
internal node acts as a binary classifier. For each word there's a unique internal node acts as a binary classifier. For each word there's a unique
path from root to it's leaf node, hsigmoid calculate the cost for each path from root to it's leaf node, hsigmoid calculate the cost for each
...@@ -4793,7 +4794,7 @@ def hsigmoid(input, ...@@ -4793,7 +4794,7 @@ def hsigmoid(input,
2. build a dict to store word_id -> word's leaf to root path, we call it path_table. 2. build a dict to store word_id -> word's leaf to root path, we call it path_table.
3. build a dict to store word_id -> code of word's leaf to root path, we call it path_code. Code 3. build a dict to store word_id -> code of word's leaf to root path, we call it path_code. Code
means label of each binary classification, using 1 indicate true, 0 indicate false. means label of each binary classification, using 1 indicate true, 0 indicate false.
4. now, each word should has its path and code along the path, you can pass a batch of path and code 4. now, each word should has its path and code along the path, you can pass a batch of path and code
related to the same batch of inputs. related to the same batch of inputs.
...@@ -4803,8 +4804,8 @@ def hsigmoid(input, ...@@ -4803,8 +4804,8 @@ def hsigmoid(input,
and :math:`D` is the feature size. and :math:`D` is the feature size.
label (Variable): The tensor variable contains labels of training data. label (Variable): The tensor variable contains labels of training data.
It's a tensor with shape is :math:`[N \\times 1]`. It's a tensor with shape is :math:`[N \\times 1]`.
num_classes: (int), The number of classes, must not be less than 2. with default tree this has to be set, num_classes: (int), The number of classes, must not be less than 2. with default tree this has to be set,
it should never be None under is_custom=False, but while is_custom is true, it should be non leaf num it should never be None under is_custom=False, but while is_custom is true, it should be non leaf num
which indicates the num of classes using by binary classify. which indicates the num of classes using by binary classify.
param_attr (ParamAttr|None): The parameter attribute for learnable parameters/weights param_attr (ParamAttr|None): The parameter attribute for learnable parameters/weights
of hsigmoid. If it is set to None or one attribute of ParamAttr, hsigmoid of hsigmoid. If it is set to None or one attribute of ParamAttr, hsigmoid
...@@ -4817,15 +4818,15 @@ def hsigmoid(input, ...@@ -4817,15 +4818,15 @@ def hsigmoid(input,
is not set, the bias is initialized zero. Default: None. is not set, the bias is initialized zero. Default: None.
name (str|None): A name for this layer(optional). If set None, the layer name (str|None): A name for this layer(optional). If set None, the layer
will be named automatically. Default: None. will be named automatically. Default: None.
path_table: (Variable|None) this variable can store each batch of samples' path to root, path_table: (Variable|None) this variable can store each batch of samples' path to root,
it should be in leaf -> root order it should be in leaf -> root order
path_table should have the same shape with path_code, and for each sample i path_table[i] indicates a np.array like path_table should have the same shape with path_code, and for each sample i path_table[i] indicates a np.array like
structure and each element in this array is indexes in parent nodes' Weight Matrix. structure and each element in this array is indexes in parent nodes' Weight Matrix.
path_code: (Variable|None) this variable can store each batch of samples' code, path_code: (Variable|None) this variable can store each batch of samples' code,
each code consist with every code of parent nodes. it should be in leaf -> root order each code consist with every code of parent nodes. it should be in leaf -> root order
is_custom: (bool|False)using user defined binary tree instead of default complete binary tree, if costum is is_custom: (bool|False)using user defined binary tree instead of default complete binary tree, if costum is
set you need to set path_table/path_code/num_classes, otherwise num_classes should be set set you need to set path_table/path_code/num_classes, otherwise num_classes should be set
is_sparse: (bool|False)using sparse update instead of dense update, if set, the gradient is_sparse: (bool|False)using sparse update instead of dense update, if set, the gradient
of W and input will be sparse. of W and input will be sparse.
Returns: Returns:
...@@ -9049,3 +9050,42 @@ def get_tensor_from_selected_rows(x, name=None): ...@@ -9049,3 +9050,42 @@ def get_tensor_from_selected_rows(x, name=None):
outputs={'Out': out}, outputs={'Out': out},
attrs={}) attrs={})
return out return out
def huber_loss(input, label, delta):
"""
Huber loss is a loss function used in robust.
Huber loss can evaluate the fitness of input to label.
Different from MSE loss, Huber loss is more robust for outliers.
When the difference between input and label is large than delta
.. math::
huber\_loss = delta * (label - input) - 0.5 * delta * delta
When the difference between input and label is less than delta
.. math::
huber\_loss = 0.5 * (label - input) * (label - input)
Args:
input (Variable): This input is a probability computed by the previous operator.
The first dimension is batch size, and the last dimension is 1.
label (Variable): The groud truth whose first dimension is batch size
and last dimension is 1.
delta (float): The parameter of huber loss, which controls
the range of outliers
Returns:
huber\_loss (Variable): The huber loss with shape [batch_size, 1].
Examples:
.. code-block:: python
predictions = fluid.layers.softmax(x)
loss = fluid.layers.huber_loss(input=predictions, label=label, 1.0)
"""
helper = LayerHelper('huber_loss', **locals())
residual = helper.create_variable_for_type_inference(
dtype=helper.input_dtype())
out = helper.create_variable_for_type_inference(dtype=helper.input_dtype())
helper.append_op(
type='huber_loss',
inputs={'X': input,
'Y': label},
outputs={'Out': out,
'Residual': residual},
attrs={'delta': delta})
return out
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册