From 1171014d3ffc31aae5c4572c495c110975220444 Mon Sep 17 00:00:00 2001 From: caoying03 Date: Mon, 15 May 2017 11:32:06 +0800 Subject: [PATCH] add param_attr to nce_layer and enable multiple inputs. --- python/paddle/trainer_config_helpers/layers.py | 17 +++++++++++++++-- 1 file changed, 15 insertions(+), 2 deletions(-) diff --git a/python/paddle/trainer_config_helpers/layers.py b/python/paddle/trainer_config_helpers/layers.py index 31652613fb3..52c7a57a1f8 100755 --- a/python/paddle/trainer_config_helpers/layers.py +++ b/python/paddle/trainer_config_helpers/layers.py @@ -4921,12 +4921,14 @@ def crf_decoding_layer(input, @wrap_act_default(act=SigmoidActivation()) @wrap_bias_attr_default(has_bias=True) +@wrap_param_attr_default() @wrap_name_default() @layer_support() def nce_layer(input, label, num_classes, act=None, + param_attr=None, weight=None, num_neg_samples=10, neg_distribution=None, @@ -4957,6 +4959,8 @@ def nce_layer(input, :type num_classes: int :param act: Activation, default is Sigmoid. :type act: BaseActivation + :param param_attr: The Parameter Attribute|list. + :type param_attr: ParameterAttribute :param num_neg_samples: number of negative samples. Default is 10. :type num_neg_samples: int :param neg_distribution: The distribution for generating the random negative labels. @@ -4972,7 +4976,16 @@ def nce_layer(input, """ if isinstance(input, LayerOutput): input = [input] + assert not isinstance(param_attr, collections.Sequence) + param_attr = [param_attr] + else: + if isinstance(param_attr, collections.Sequence): + assert len(input) == len(param_attr) + else: + param_attr = [copy.deepcopy(param_attr) for _ in range(len(input))] + assert isinstance(input, collections.Sequence) + assert isinstance(label, LayerOutput) assert label.layer_type == LayerType.DATA if neg_distribution is not None: @@ -4984,9 +4997,9 @@ def nce_layer(input, ipts_for_layer = [] parents = [] - for each_input in input: + for each_input, attr in zip(input, param_attr): assert isinstance(each_input, LayerOutput) - ipts_for_layer.append(each_input.name) + ipts_for_layer.append(Input(each_input.name, **attr.attr)) parents.append(each_input) ipts_for_layer.append(label.name) parents.append(label) -- GitLab