提交 6d21ecef 编写于 作者: Z Zrachel 提交者: luotao1

add cost_type constraint to weighted_cost interface (#206)

上级 c13bdb15
......@@ -1715,7 +1715,6 @@ def define_cost(class_name, cost_type):
g_cost_map[cost_type] = cls
define_cost('MultiClassCrossEntropy', 'multi-class-cross-entropy')
define_cost('ClassificationErrorLayer', 'classification_error')
define_cost('RankingCost', 'rank-cost')
define_cost('AucValidation', 'auc-validation')
define_cost('PnpairValidation', 'pnpair-validation')
......
......@@ -2799,7 +2799,7 @@ def __cost_input__(input, label, weight=None):
@wrap_name_default()
def regression_cost(input, label, weight=None, cost='square_error', name=None):
def regression_cost(input, label, weight=None, name=None):
"""
Regression Layer.
......@@ -2814,21 +2814,18 @@ def regression_cost(input, label, weight=None, cost='square_error', name=None):
:param weight: The weight affects the cost, namely the scale of cost.
It is an optional argument.
:type weight: LayerOutput
:param cost: Cost method.
:type cost: basestring
:return: LayerOutput object.
:rtype: LayerOutput
"""
ipts, parents = __cost_input__(input, label, weight)
Layer(inputs=ipts, type=cost, name=name)
Layer(inputs=ipts, type="square_error", name=name)
return LayerOutput(name, LayerType.COST, parents=parents)
@wrap_name_default("cost")
@layer_support()
def classification_cost(input, label, weight=None, name=None,
cost="multi-class-cross-entropy",
evaluator=classification_error_evaluator,
layer_attr=None):
"""
......@@ -2843,8 +2840,6 @@ def classification_cost(input, label, weight=None, name=None,
:param weight: The weight affects the cost, namely the scale of cost.
It is an optional argument.
:type weight: LayerOutput
:param cost: cost method.
:type cost: basestring
:param evaluator: Evaluator method.
:param layer_attr: layer's extra attribute.
:type layer_attr: ExtraLayerAttribute
......@@ -2857,7 +2852,7 @@ def classification_cost(input, label, weight=None, name=None,
ipts, parents = __cost_input__(input, label, weight)
Layer(name=name, type=cost, inputs=ipts,
Layer(name=name, type="multi-class-cross-entropy", inputs=ipts,
**ExtraLayerAttribute.to_kwargs(layer_attr))
def __add_evaluator__(e):
......@@ -3819,8 +3814,8 @@ def multi_binary_label_cross_entropy(input, label, name=None, coeff=1.0):
if input.activation is None or \
not isinstance(input.activation, SigmoidActivation):
logger.log(logging.WARN,
"%s is not recommend for batch normalization's activation, "
"maybe the relu is better" % repr(input.activation))
"%s is not recommend for multi_binary_label_cross_entropy's activation, "
"maybe the sigmoid is better" % repr(input.activation))
Layer(name=name,
type=LayerType.MULTI_BIN_LABEL_CROSS_ENTROPY,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册