From 199a6a4b5c62583d05d4c3199a13891dcba576c5 Mon Sep 17 00:00:00 2001
From: luotao1
Date: Mon, 10 Oct 2016 11:21:44 +0800
Subject: [PATCH] add weight for cost layer interface (#177)

---
 .../paddle/trainer_config_helpers/layers.py   | 44 +++++++++++++++----
 .../tests/configs/check.md5                   |  1 +
 .../tests/configs/generate_protostr.sh        |  2 +-
 .../configs/test_cost_layers_with_weight.py   | 14 ++++++
 4 files changed, 51 insertions(+), 10 deletions(-)
 create mode 100644 python/paddle/trainer_config_helpers/tests/configs/test_cost_layers_with_weight.py

diff --git a/python/paddle/trainer_config_helpers/layers.py b/python/paddle/trainer_config_helpers/layers.py
index 5e7e66a908..7699c90db7 100644
--- a/python/paddle/trainer_config_helpers/layers.py
+++ b/python/paddle/trainer_config_helpers/layers.py
@@ -2777,29 +2777,49 @@ def beam_search(step, input, bos_id, eos_id, beam_size,
     return tmp
 
 
+def __cost_input__(input, label, weight=None):
+    """
+    inputs and parents for cost layers.
+    """
+    ipts = [Input(input.name), Input(label.name)]
+    parents = [input, label]
+    if weight is not None:
+        assert weight.layer_type == LayerType.DATA
+        ipts.append(Input(weight.name))
+        parents.append(weight)
+    return ipts, parents
+
 @wrap_name_default()
-def regression_cost(input, label, cost='square_error', name=None):
+def regression_cost(input, label, weight=None, cost='square_error', name=None):
     """
     Regression Layer.
 
     TODO(yuyang18): Complete this method.
 
     :param name: layer name.
+    :type name: basestring
     :param input: Network prediction.
+    :type input: LayerOutput
     :param label: Data label.
+    :type label: LayerOutput
+    :param weight: The weight affects the cost, namely the scale of cost.
+                   It is an optional argument.
+    :type weight: LayerOutput
     :param cost: Cost method.
+    :type cost: basestring
     :return: LayerOutput object.
+    :rtype: LayerOutput
     """
-    Layer(inputs=[Input(input.name), Input(label.name)], type=cost, name=name)
-    return LayerOutput(
-        name, LayerType.COST, parents=[input, label]
-    )
+    ipts, parents = __cost_input__(input, label, weight)
+
+    Layer(inputs=ipts, type=cost, name=name)
+    return LayerOutput(name, LayerType.COST, parents=parents)
 
 
 @wrap_name_default("cost")
 @layer_support()
-def classification_cost(input, label, name=None,
+def classification_cost(input, label, weight=None, name=None,
                         cost="multi-class-cross-entropy",
                         evaluator=classification_error_evaluator,
                         layer_attr=None):
@@ -2812,6 +2832,9 @@ def classification_cost(input, label, name=None,
     :type input: LayerOutput
     :param label: label layer name. data_layer often.
     :type label: LayerOutput
+    :param weight: The weight affects the cost, namely the scale of cost.
+                   It is an optional argument.
+    :type weight: LayerOutput
     :param cost: cost method.
     :type cost: basestring
     :param evaluator: Evaluator method.
@@ -2823,7 +2846,10 @@ def classification_cost(input, label, name=None,
     assert input.layer_type != LayerType.DATA
     assert isinstance(input.activation, SoftmaxActivation)
     assert label.layer_type == LayerType.DATA
-    Layer(name=name, type=cost, inputs=[Input(input.name), Input(label.name)],
+
+    ipts, parents = __cost_input__(input, label, weight)
+
+    Layer(name=name, type=cost, inputs=ipts,
           **ExtraLayerAttribute.to_kwargs(layer_attr))
 
     def __add_evaluator__(e):
@@ -2835,7 +2861,7 @@ def classification_cost(input, label, name=None,
         assert isinstance(e.for_classification, bool)
         assert e.for_classification
 
-        e(name=e.__name__, input=input, label=label)
+        e(name=e.__name__, input=input, label=label, weight=weight)
 
     if not isinstance(evaluator, collections.Sequence):
         evaluator = [evaluator]
@@ -2843,7 +2869,7 @@ def classification_cost(input, label, name=None,
     for each_evaluator in evaluator:
         __add_evaluator__(each_evaluator)
 
-    return LayerOutput(name, LayerType.COST, parents=[input, label])
+    return LayerOutput(name, LayerType.COST, parents=parents)
 
 
 def conv_operator(img, filter, filter_size, num_filters,
diff --git a/python/paddle/trainer_config_helpers/tests/configs/check.md5 b/python/paddle/trainer_config_helpers/tests/configs/check.md5
index 359652f3d0..3ecfff2071 100644
--- a/python/paddle/trainer_config_helpers/tests/configs/check.md5
+++ b/python/paddle/trainer_config_helpers/tests/configs/check.md5
@@ -4,6 +4,7 @@ a5d9259ff1fd7ca23d0ef090052cb1f2 last_first_seq.protostr
 5913f87b39cee3b2701fa158270aca26 projections.protostr
 6b39e34beea8dfb782bee9bd3dea9eb5 simple_rnn_layers.protostr
 0fc1409600f1a3301da994ab9d28b0bf test_cost_layers.protostr
+6cd5f28a3416344f20120698470e0a4c test_cost_layers_with_weight.protostr
 144bc6d3a509de74115fa623741797ed test_expand_layer.protostr
 2378518bdb71e8c6e888b1842923df58 test_fc.protostr
 8bb44e1e5072d0c261572307e7672bda test_grumemory_layer.protostr
diff --git a/python/paddle/trainer_config_helpers/tests/configs/generate_protostr.sh b/python/paddle/trainer_config_helpers/tests/configs/generate_protostr.sh
index fc2acbd41e..5514ee65e5 100755
--- a/python/paddle/trainer_config_helpers/tests/configs/generate_protostr.sh
+++ b/python/paddle/trainer_config_helpers/tests/configs/generate_protostr.sh
@@ -8,7 +8,7 @@ configs=(test_fc layer_activations projections test_print_layer
 test_sequence_pooling test_lstmemory_layer test_grumemory_layer
 last_first_seq test_expand_layer test_ntm_layers test_hsigmoid
 img_layers util_layers simple_rnn_layers unused_layers test_cost_layers
-test_rnn_group)
+test_cost_layers_with_weight test_rnn_group)
 
 
 for conf in ${configs[*]}
diff --git a/python/paddle/trainer_config_helpers/tests/configs/test_cost_layers_with_weight.py b/python/paddle/trainer_config_helpers/tests/configs/test_cost_layers_with_weight.py
new file mode 100644
index 0000000000..29749cbb66
--- /dev/null
+++ b/python/paddle/trainer_config_helpers/tests/configs/test_cost_layers_with_weight.py
@@ -0,0 +1,14 @@
+from paddle.trainer_config_helpers import *
+
+settings(
+    learning_rate=1e-4,
+    batch_size=1000
+)
+
+data = data_layer(name='input', size=300)
+lbl = data_layer(name='label', size=1)
+wt = data_layer(name='weight', size=1)
+fc = fc_layer(input=data, size=10, act=SoftmaxActivation())
+
+outputs(classification_cost(input=fc, label=lbl, weight=wt),
+        regression_cost(input=fc, label=lbl, weight=wt))
--
GitLab
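
For context on what the new `weight` input does: per the docstrings added in this patch, it is a per-sample scale on the cost, fed as an extra DATA layer (size 1) alongside the prediction and label. A minimal pure-Python sketch of that idea follows; it is illustrative only, not Paddle's actual cost kernel, and the function name, square-error choice, and mean reduction are assumptions.

    # Sketch: a per-sample weight scales that sample's error before reduction.
    # (Averaging over the batch is an assumption; a framework may sum instead.)
    def weighted_square_error(preds, labels, weights=None):
        if weights is None:
            # Unweighted case: every sample contributes equally.
            weights = [1.0] * len(preds)
        total = sum(w * (p - t) ** 2 for p, t, w in zip(preds, labels, weights))
        return total / len(preds)

    # The second sample, with weight 2.0, contributes four times as much
    # to the cost as it would with weight 0.5:
    print(weighted_square_error([0.2, 0.8], [0.0, 1.0], weights=[0.5, 2.0]))

This matches the shape of the config test above: one weight value per sample, supplied through its own data_layer and passed to classification_cost or regression_cost via the new keyword argument.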