提交 47eb8691 编写于 作者: T Tao Luo 提交者: GitHub

Merge pull request #3571 from luotao1/huber_loss

refine Huber loss, add huber_regression_cost
......@@ -419,9 +419,14 @@ multi_binary_label_cross_entropy_cost
.. autoclass:: paddle.v2.layer.multi_binary_label_cross_entropy_cost
:noindex:
huber_cost
----------
.. autoclass:: paddle.v2.layer.huber_cost
huber_regression_cost
-------------------------
.. autoclass:: paddle.v2.layer.huber_regression_cost
:noindex:
huber_classification_cost
-------------------------
.. autoclass:: paddle.v2.layer.huber_classification_cost
:noindex:
lambda_cost
......
......@@ -572,13 +572,8 @@ void MultiBinaryLabelCrossEntropy::backwardImp(Matrix& output,
}
}
//
// Huber loss for robust 2-classes classification
//
REGISTER_LAYER(huber, HuberTwoClass);
bool HuberTwoClass::init(const LayerMap& layerMap,
const ParameterMap& parameterMap) {
bool HuberCost::init(const LayerMap& layerMap,
const ParameterMap& parameterMap) {
CostLayer::init(layerMap, parameterMap);
if (useGpu_) {
tmpCpuInput_.reserve(inputLayers_.size());
......@@ -589,7 +584,7 @@ bool HuberTwoClass::init(const LayerMap& layerMap,
return true;
}
void HuberTwoClass::forwardImp(Matrix& output, Argument& label, Matrix& cost) {
void HuberCost::forwardImp(Matrix& output, Argument& label, Matrix& cost) {
if (useGpu_) {
for (size_t i = 0; i < inputLayers_.size(); i++) {
tmpCpuInput_[i].resizeAndCopyFrom(
......@@ -597,13 +592,87 @@ void HuberTwoClass::forwardImp(Matrix& output, Argument& label, Matrix& cost) {
}
hl_stream_synchronize(HPPL_STREAM_DEFAULT);
}
forwardImpIn(output, label, cost);
}
void HuberTwoClass::forwardImpIn(Matrix& output,
Argument& label,
Matrix& target) {
//
// Huber loss for robust regression.
//
REGISTER_LAYER(huber_regression, HuberRegressionLoss);
bool HuberRegressionLoss::init(const LayerMap& layerMap,
const ParameterMap& parameterMap) {
HuberCost::init(layerMap, parameterMap);
delta_ = config_.delta();
return true;
}
void HuberRegressionLoss::forwardImp(Matrix& output,
Argument& label,
Matrix& target) {
HuberCost::forwardImp(output, label, target);
size_t numSamples = target.getHeight();
size_t dim = output.getWidth();
CHECK(label.value);
CHECK_EQ((*label.value).getHeight(), numSamples);
CHECK_EQ(output.getHeight(), numSamples);
CHECK_EQ(dim, (*label.value).getWidth());
CHECK_EQ(target.getWidth(), (size_t)1);
real* out = useGpu_ ? tmpCpuInput_[0].value->getData() : output.getData();
real* lbl =
useGpu_ ? tmpCpuInput_[1].value->getData() : (*label.value).getData();
std::vector<real> cost(numSamples, 0);
for (size_t i = 0; i < numSamples; ++i) {
for (size_t j = 0; j < dim; ++j) {
int index = i * dim + j;
real a = std::abs(lbl[index] - out[index]);
if (a <= delta_)
cost[i] += a * a / 2;
else
cost[i] += delta_ * (a - delta_ / 2);
}
}
target.copyFrom(cost.data(), numSamples);
}
void HuberRegressionLoss::backwardImp(Matrix& output,
Argument& label,
Matrix& outputG) {
size_t numSamples = output.getHeight();
size_t dim = output.getWidth();
real* out = useGpu_ ? tmpCpuInput_[0].value->getData() : output.getData();
real* lbl =
useGpu_ ? tmpCpuInput_[1].value->getData() : (*label.value).getData();
real* grad = useGpu_ ? tmpCpuInput_[0].grad->getData() : outputG.getData();
for (size_t i = 0; i < numSamples; ++i) {
for (size_t j = 0; j < dim; ++j) {
int index = i * dim + j;
real a = lbl[index] - out[index];
if (std::abs(a) <= delta_)
grad[index] += -a;
else
grad[index] += a > 0 ? -delta_ : delta_;
}
}
if (useGpu_) outputG.copyFrom(grad, numSamples * dim);
}
//
// Huber loss for robust 2-classes classification
//
REGISTER_LAYER(huber_classification, HuberTwoClassification);
bool HuberTwoClassification::init(const LayerMap& layerMap,
const ParameterMap& parameterMap) {
return HuberCost::init(layerMap, parameterMap);
}
void HuberTwoClassification::forwardImp(Matrix& output,
Argument& label,
Matrix& target) {
HuberCost::forwardImp(output, label, target);
size_t numSamples = target.getHeight();
CHECK(label.ids);
CHECK_EQ((*label.ids).getSize(), numSamples);
CHECK_EQ(output.getHeight(), numSamples);
CHECK_EQ(output.getWidth(), (size_t)1);
......@@ -611,47 +680,35 @@ void HuberTwoClass::forwardImpIn(Matrix& output,
real* out = useGpu_ ? tmpCpuInput_[0].value->getData() : output.getData();
int* lbl = useGpu_ ? tmpCpuInput_[1].ids->getData() : (*label.ids).getData();
std::vector<real> cost(numSamples);
std::vector<real> cost(numSamples, 0);
for (size_t i = 0; i < numSamples; ++i) {
int y = 2 * lbl[i] - 1;
if (out[i] * y < -1)
cost[i] = -4 * out[i] * y;
else if (out[i] * y < 1)
cost[i] = (1 - out[i] * y) * (1 - out[i] * y);
else
cost[i] = 0;
real a = out[i] * y;
if (a < -1)
cost[i] = -4 * a;
else if (a < 1)
cost[i] = (1 - a) * (1 - a);
}
target.copyFrom(cost.data(), numSamples);
}
void HuberTwoClass::backwardImp(Matrix& outputValue,
Argument& label,
Matrix& outputGrad) {
if (useGpu_) {
backwardImpIn(
*tmpCpuInput_[0].value, tmpCpuInput_[1], *tmpCpuInput_[0].grad);
outputGrad.copyFrom(*tmpCpuInput_[0].grad);
} else {
backwardImpIn(outputValue, label, outputGrad);
}
}
void HuberTwoClass::backwardImpIn(Matrix& output,
Argument& label,
Matrix& outputG) {
void HuberTwoClassification::backwardImp(Matrix& output,
Argument& label,
Matrix& outputG) {
size_t numSamples = output.getHeight();
real* out = output.getData();
real* grad = outputG.getData();
int* lbl = (*label.ids).getData();
real* out = useGpu_ ? tmpCpuInput_[0].value->getData() : output.getData();
int* lbl = useGpu_ ? tmpCpuInput_[1].ids->getData() : (*label.ids).getData();
real* grad = useGpu_ ? tmpCpuInput_[0].grad->getData() : outputG.getData();
for (size_t i = 0; i < numSamples; ++i) {
int y = 2 * lbl[i] - 1;
if (y * out[i] < -1)
real a = out[i] * y;
if (a < -1)
grad[i] += -4 * y;
else if (y * out[i] < 1)
grad[i] += -2 * (1 - y * out[i]) * y;
else if (a < 1)
grad[i] += -2 * (1 - a) * y;
}
if (useGpu_) outputG.copyFrom(grad, numSamples);
}
/**
* This cost layer compute the sum of its input as loss.
* \f[
......
......@@ -304,37 +304,68 @@ public:
Matrix& outputGrad) override;
};
/**
* Huber loss for robust 2-classes classification.
*
* For label={0, 1}, let y=2*label-1. Given output f, the loss is:
* \f[
* Loss =
* \left\{\begin{matrix}
* 4 * y * f & \textit{if} \ \ y* f < -1 \\
* (1 - y * f)^2 & \textit{if} \ \ -1 < y * f < 1 \\
* 0 & \textit{otherwise}
* \end{matrix}\right.
* \f]
/*
* A base layer for HuberRegressionLoss and HuberTwoClassification.
*/
class HuberTwoClass : public CostLayer {
class HuberCost : public CostLayer {
public:
std::vector<Argument> tmpCpuInput_;
public:
explicit HuberTwoClass(const LayerConfig& config) : CostLayer(config) {}
explicit HuberCost(const LayerConfig& config) : CostLayer(config) {}
bool init(const LayerMap& layerMap,
const ParameterMap& parameterMap) override;
void forwardImp(Matrix& output, Argument& label, Matrix& cost) override;
void forwardImpIn(Matrix& output, Argument& label, Matrix& cost);
void backwardImp(Matrix& outputValue, Argument& label, Matrix& outputGrad) {}
};
/**
* Huber loss for robust regression.
*
* Given output f(x), label y and delta, the loss is:
* Loss = 0.5 * (1 - y * f)^2, if abs(y - f) <= delta \\
* Loss = delta * abs(y - f) - 0.5 * delta^2, otherwise
*/
class HuberRegressionLoss : public HuberCost {
public:
explicit HuberRegressionLoss(const LayerConfig& config) : HuberCost(config) {}
bool init(const LayerMap& layerMap,
const ParameterMap& parameterMap) override;
void forwardImp(Matrix& output, Argument& label, Matrix& cost) override;
void backwardImp(Matrix& outputValue,
Argument& label,
Matrix& outputGrad) override;
void backwardImpIn(Matrix& outputValue, Argument& label, Matrix& outputGrad);
protected:
real delta_;
};
/**
* Huber loss for robust 2-classes classification.
*
* For label={0, 1}, let y=2*label-1. Given output f(x), the loss is:
* Loss = 4 * y * f, if y* f < -1 \\
* Loss = (1 - y * f)^2, if -1 < y * f < 1 \\
* Loss = 0, otherwise
*/
class HuberTwoClassification : public HuberCost {
public:
explicit HuberTwoClassification(const LayerConfig& config)
: HuberCost(config) {}
bool init(const LayerMap& layerMap,
const ParameterMap& parameterMap) override;
void forwardImp(Matrix& output, Argument& label, Matrix& cost) override;
void backwardImp(Matrix& outputValue,
Argument& label,
Matrix& outputGrad) override;
};
typedef std::shared_ptr<CostLayer> CostLayerPtr;
......
......@@ -850,9 +850,27 @@ TEST(Layer, square_error_weighted) {
}
}
TEST(Layer, huber_regression_loss) {
TestConfig config;
config.layerConfig.set_type("huber_regression");
config.biasSize = 0;
config.inputDefs.push_back({INPUT_DATA, "layer_0", 10, 0});
config.inputDefs.push_back({INPUT_DATA_TARGET, "layer_1", 10, 0});
config.layerConfig.add_inputs();
config.layerConfig.add_inputs();
for (auto useGpu : {false, true}) {
for (auto delta : {1, 3, 5}) {
config.layerConfig.set_delta(delta);
testLayerGrad(config, "huber_regression", 100, /* trans */ false, useGpu);
}
}
}
TEST(Layer, huber_two_class) {
TestConfig config;
config.layerConfig.set_type("huber");
config.layerConfig.set_type("huber_classification");
config.biasSize = 0;
config.inputDefs.push_back({INPUT_DATA, "layer_0", 1, 0});
......@@ -861,7 +879,7 @@ TEST(Layer, huber_two_class) {
config.layerConfig.add_inputs();
for (auto useGpu : {false, true}) {
testLayerGrad(config, "huber", 100, /* trans */ false, useGpu);
testLayerGrad(config, "huber_two_class", 100, /* trans */ false, useGpu);
}
}
......
......@@ -499,6 +499,9 @@ message LayerConfig {
optional int32 axis = 54 [ default = 2 ];
repeated uint32 offset = 55;
repeated uint32 shape = 56;
// for HuberRegressionLoss
optional double delta = 57 [ default = 1.0 ];
}
message EvaluatorConfig {
......
......@@ -2274,7 +2274,7 @@ define_cost('PnpairValidation', 'pnpair-validation')
define_cost('SumOfSquaresCostLayer', 'square_error')
define_cost('MultiBinaryLabelCrossEntropy', 'multi_binary_label_cross_entropy')
define_cost('SoftBinaryClassCrossEntropy', 'soft_binary_class_cross_entropy')
define_cost('HuberTwoClass', 'huber')
define_cost('HuberTwoClassification', 'huber_classification')
define_cost('SumCost', 'sum_cost')
define_cost('SmoothL1Cost', 'smooth_l1')
......@@ -2336,6 +2336,17 @@ class LambdaCost(LayerBase):
self.config.max_sort_size = max_sort_size
@config_layer('huber_regression')
class HuberRegressionLoss(LayerBase):
def __init__(self, name, inputs, delta=1., coeff=1., device=None):
super(HuberRegressionLoss, self).__init__(
name, 'huber_regression', 1, inputs=inputs, device=device)
config_assert(
len(self.inputs) == 2, 'HuberRegression must have 2 inputs')
self.config.delta = delta
self.config.coeff = coeff
@config_layer('nce')
class NCELayer(LayerBase):
def __init__(self,
......
......@@ -110,7 +110,8 @@ __all__ = [
'sum_cost',
'rank_cost',
'lambda_cost',
'huber_cost',
'huber_regression_cost',
'huber_classification_cost',
'block_expand_layer',
'maxout_layer',
'out_prod_layer',
......@@ -220,7 +221,8 @@ class LayerType(object):
RANK_COST = 'rank-cost'
LAMBDA_COST = 'lambda_cost'
HUBER = 'huber'
HUBER_REGRESSION = 'huber_regression'
HUBER_CLASSIFICATION = 'huber_classification'
CROSS_ENTROPY = 'multi-class-cross-entropy'
CROSS_ENTROPY_WITH_SELFNORM = 'multi_class_cross_entropy_with_selfnorm'
SOFT_BIN_CLASS_CROSS_ENTROPY = 'soft_binary_class_cross_entropy'
......@@ -5644,16 +5646,77 @@ def sum_cost(input, name=None, layer_attr=None):
@wrap_name_default()
@layer_support()
def huber_cost(input, label, name=None, coeff=1.0, layer_attr=None):
def huber_regression_cost(input,
label,
name=None,
delta=1.0,
coeff=1.0,
layer_attr=None):
"""
In statistics, the Huber loss is a loss function used in robust regression,
that is less sensitive to outliers in data than the squared error loss.
Given a prediction f(x), a label y and :math:`\delta`, the loss function
is defined as:
.. math:
loss = 0.5*\left ( y-f(x) \right )^2, \left | y-f(x) \right |\leq \delta
loss = \delta \left | y-f(x) \right |-0.5\delta ^2, otherwise
The example usage is:
.. code-block:: python
cost = huber_regression_cost(input=input_layer, label=label_layer)
:param input: The first input layer.
:type input: LayerOutput.
:param label: The input label.
:type input: LayerOutput.
:param name: The name of this layers. It is not necessary.
:type name: None|basestring.
:param delta: The difference between the observed and predicted values.
:type delta: float.
:param coeff: The coefficient affects the gradient in the backward.
:type coeff: float.
:param layer_attr: Extra Layer Attribute.
:type layer_attr: ExtraLayerAttribute
:return: LayerOutput object.
:rtype: LayerOutput.
"""
assert isinstance(input, LayerOutput)
Layer(
name=name,
type=LayerType.HUBER_REGRESSION,
inputs=[input.name, label.name],
delta=delta,
coeff=coeff,
**ExtraLayerAttribute.to_kwargs(layer_attr))
return LayerOutput(
name, LayerType.HUBER_REGRESSION, parents=[input, label], size=1)
@wrap_name_default()
@layer_support()
def huber_classification_cost(input,
label,
name=None,
coeff=1.0,
layer_attr=None):
"""
A loss layer for huber loss.
For classification purposes, a variant of the Huber loss called modified Huber
is sometimes used. Given a prediction f(x) (a real-valued classifier score) and
a true binary class label :math:`y\in \left \{-1, 1 \right \}`, the modified Huber
loss is defined as:
.. math:
loss = \max \left ( 0, 1-yf(x) \right )^2, yf(x)\geq 1
loss = -4yf(x), \text{otherwise}
The example usage is:
.. code-block:: python
cost = huber_cost(input=input_layer,
label=label_layer)
cost = huber_classification_cost(input=input_layer, label=label_layer)
:param input: The first input layer.
:type input: LayerOutput.
......@@ -5673,11 +5736,12 @@ def huber_cost(input, label, name=None, coeff=1.0, layer_attr=None):
assert input.size == 1
Layer(
name=name,
type=LayerType.HUBER,
type=LayerType.HUBER_CLASSIFICATION,
inputs=[input.name, label.name],
coeff=coeff,
**ExtraLayerAttribute.to_kwargs(layer_attr))
return LayerOutput(name, LayerType.HUBER, parents=[input, label], size=1)
return LayerOutput(
name, LayerType.HUBER_CLASSIFICATION, parents=[input, label], size=1)
@wrap_name_default()
......
......@@ -167,6 +167,20 @@ layers {
softmax_selfnorm_alpha: 0.1
coeff: 1.0
}
layers {
name: "__huber_regression_cost_0__"
type: "huber_regression"
size: 1
active_type: ""
inputs {
input_layer_name: "input"
}
inputs {
input_layer_name: "labels"
}
coeff: 1.0
delta: 1.0
}
layers {
name: "huber_probs"
type: "data"
......@@ -180,8 +194,8 @@ layers {
active_type: ""
}
layers {
name: "__huber_cost_0__"
type: "huber"
name: "__huber_classification_cost_0__"
type: "huber_classification"
size: 1
active_type: ""
inputs {
......@@ -300,7 +314,8 @@ output_layer_names: "__rank_cost_0__"
output_layer_names: "__lambda_cost_0__"
output_layer_names: "__cross_entropy_0__"
output_layer_names: "__cross_entropy_with_selfnorm_0__"
output_layer_names: "__huber_cost_0__"
output_layer_names: "__huber_regression_cost_0__"
output_layer_names: "__huber_classification_cost_0__"
output_layer_names: "__multi_binary_label_cross_entropy_0__"
output_layer_names: "__sum_cost_0__"
output_layer_names: "__nce_layer_0__"
......@@ -324,9 +339,10 @@ sub_models {
layer_names: "__lambda_cost_0__"
layer_names: "__cross_entropy_0__"
layer_names: "__cross_entropy_with_selfnorm_0__"
layer_names: "__huber_regression_cost_0__"
layer_names: "huber_probs"
layer_names: "huber_label"
layer_names: "__huber_cost_0__"
layer_names: "__huber_classification_cost_0__"
layer_names: "__multi_binary_label_cross_entropy_0__"
layer_names: "__sum_cost_0__"
layer_names: "__nce_layer_0__"
......@@ -349,7 +365,8 @@ sub_models {
output_layer_names: "__lambda_cost_0__"
output_layer_names: "__cross_entropy_0__"
output_layer_names: "__cross_entropy_with_selfnorm_0__"
output_layer_names: "__huber_cost_0__"
output_layer_names: "__huber_regression_cost_0__"
output_layer_names: "__huber_classification_cost_0__"
output_layer_names: "__multi_binary_label_cross_entropy_0__"
output_layer_names: "__sum_cost_0__"
output_layer_names: "__nce_layer_0__"
......
......@@ -33,7 +33,9 @@ outputs(
input=probs, label=xe_label),
cross_entropy_with_selfnorm(
input=probs, label=xe_label),
huber_cost(
huber_regression_cost(
input=seq_in, label=labels),
huber_classification_cost(
input=data_layer(
name='huber_probs', size=1),
label=data_layer(
......
......@@ -141,12 +141,13 @@ class CostLayerTest(unittest.TestCase):
cost8 = layer.rank_cost(left=score, right=score, label=score)
cost9 = layer.lambda_cost(input=inference, score=score)
cost10 = layer.sum_cost(input=inference)
cost11 = layer.huber_cost(input=score, label=label)
cost11 = layer.huber_regression_cost(input=score, label=label)
cost12 = layer.huber_classification_cost(input=score, label=label)
print layer.parse_network([cost1, cost2])
print layer.parse_network([cost3, cost4])
print layer.parse_network([cost5, cost6])
print layer.parse_network([cost7, cost8, cost9, cost10, cost11])
print layer.parse_network([cost7, cost8, cost9, cost10, cost11, cost12])
crf = layer.crf(input=inference, label=label)
crf_decoding = layer.crf_decoding(input=inference, size=3)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册