提交 1171014d 编写于 作者: C caoying03

add param_attr to nce_layer and enable multiple inputs.

上级 1ba82069
...@@ -4921,12 +4921,14 @@ def crf_decoding_layer(input, ...@@ -4921,12 +4921,14 @@ def crf_decoding_layer(input,
@wrap_act_default(act=SigmoidActivation()) @wrap_act_default(act=SigmoidActivation())
@wrap_bias_attr_default(has_bias=True) @wrap_bias_attr_default(has_bias=True)
@wrap_param_attr_default()
@wrap_name_default() @wrap_name_default()
@layer_support() @layer_support()
def nce_layer(input, def nce_layer(input,
label, label,
num_classes, num_classes,
act=None, act=None,
param_attr=None,
weight=None, weight=None,
num_neg_samples=10, num_neg_samples=10,
neg_distribution=None, neg_distribution=None,
...@@ -4957,6 +4959,8 @@ def nce_layer(input, ...@@ -4957,6 +4959,8 @@ def nce_layer(input,
:type num_classes: int :type num_classes: int
:param act: Activation, default is Sigmoid. :param act: Activation, default is Sigmoid.
:type act: BaseActivation :type act: BaseActivation
:param param_attr: The Parameter Attribute|list.
:type param_attr: ParameterAttribute
:param num_neg_samples: number of negative samples. Default is 10. :param num_neg_samples: number of negative samples. Default is 10.
:type num_neg_samples: int :type num_neg_samples: int
:param neg_distribution: The distribution for generating the random negative labels. :param neg_distribution: The distribution for generating the random negative labels.
...@@ -4972,7 +4976,16 @@ def nce_layer(input, ...@@ -4972,7 +4976,16 @@ def nce_layer(input,
""" """
if isinstance(input, LayerOutput): if isinstance(input, LayerOutput):
input = [input] input = [input]
assert not isinstance(param_attr, collections.Sequence)
param_attr = [param_attr]
else:
if isinstance(param_attr, collections.Sequence):
assert len(input) == len(param_attr)
else:
param_attr = [copy.deepcopy(param_attr) for _ in range(len(input))]
assert isinstance(input, collections.Sequence) assert isinstance(input, collections.Sequence)
assert isinstance(label, LayerOutput) assert isinstance(label, LayerOutput)
assert label.layer_type == LayerType.DATA assert label.layer_type == LayerType.DATA
if neg_distribution is not None: if neg_distribution is not None:
...@@ -4984,9 +4997,9 @@ def nce_layer(input, ...@@ -4984,9 +4997,9 @@ def nce_layer(input,
ipts_for_layer = [] ipts_for_layer = []
parents = [] parents = []
for each_input in input: for each_input, attr in zip(input, param_attr):
assert isinstance(each_input, LayerOutput) assert isinstance(each_input, LayerOutput)
ipts_for_layer.append(each_input.name) ipts_for_layer.append(Input(each_input.name, **attr.attr))
parents.append(each_input) parents.append(each_input)
ipts_for_layer.append(label.name) ipts_for_layer.append(label.name)
parents.append(label) parents.append(label)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册