提交 4b1b599a 编写于 作者: Y Yu Yang 提交者: GitHub

Merge pull request #1762 from reyoung/feature/refine_nce_layer

Refine NCE Layer
...@@ -18,7 +18,7 @@ import inspect ...@@ -18,7 +18,7 @@ import inspect
from paddle.trainer.config_parser import * from paddle.trainer.config_parser import *
from .activations import LinearActivation, SigmoidActivation, TanhActivation, \ from .activations import LinearActivation, SigmoidActivation, TanhActivation, \
ReluActivation, IdentityActivation, SoftmaxActivation ReluActivation, IdentityActivation, SoftmaxActivation, BaseActivation
from .evaluators import * from .evaluators import *
from .poolings import MaxPooling, AvgPooling, BasePoolingType from .poolings import MaxPooling, AvgPooling, BasePoolingType
from .attrs import * from .attrs import *
...@@ -2253,8 +2253,9 @@ def img_pool_layer(input, ...@@ -2253,8 +2253,9 @@ def img_pool_layer(input,
pool_type.name = 'avg' pool_type.name = 'avg'
type_name = pool_type.name + '-projection' \ type_name = pool_type.name + '-projection' \
if (isinstance(pool_type, AvgPooling) or isinstance(pool_type, MaxPooling)) \ if (
else pool_type.name isinstance(pool_type, AvgPooling) or isinstance(pool_type, MaxPooling)) \
else pool_type.name
pool_size_y = pool_size if pool_size_y is None else pool_size_y pool_size_y = pool_size if pool_size_y is None else pool_size_y
stride_y = stride if stride_y is None else stride_y stride_y = stride if stride_y is None else stride_y
...@@ -3294,8 +3295,8 @@ def recurrent_group(step, ...@@ -3294,8 +3295,8 @@ def recurrent_group(step,
assert (targetInlink == None or targetInlink_in_inlinks()) assert (targetInlink == None or targetInlink_in_inlinks())
targetInlinkName = None if targetInlink == None \ targetInlinkName = None if targetInlink == None \
else targetInlink.name if isinstance(targetInlink, LayerOutput) \ else targetInlink.name if isinstance(targetInlink, LayerOutput) \
else targetInlink.input.name else targetInlink.input.name
contains_sub_seq = [False] contains_sub_seq = [False]
...@@ -4807,12 +4808,14 @@ def crf_decoding_layer(input, ...@@ -4807,12 +4808,14 @@ def crf_decoding_layer(input,
return LayerOutput(name, LayerType.CRF_DECODING_LAYER, parents, size=1) return LayerOutput(name, LayerType.CRF_DECODING_LAYER, parents, size=1)
@wrap_act_default(act=SigmoidActivation())
@wrap_bias_attr_default(has_bias=True) @wrap_bias_attr_default(has_bias=True)
@wrap_name_default() @wrap_name_default()
@layer_support() @layer_support()
def nce_layer(input, def nce_layer(input,
label, label,
num_classes, num_classes,
act=None,
weight=None, weight=None,
num_neg_samples=10, num_neg_samples=10,
neg_distribution=None, neg_distribution=None,
...@@ -4841,6 +4844,8 @@ def nce_layer(input, ...@@ -4841,6 +4844,8 @@ def nce_layer(input,
:type weight: LayerOutput :type weight: LayerOutput
:param num_classes: number of classes. :param num_classes: number of classes.
:type num_classes: int :type num_classes: int
:param act: Activation, default is Sigmoid.
:type act: BaseActivation
:param num_neg_samples: number of negative samples. Default is 10. :param num_neg_samples: number of negative samples. Default is 10.
:type num_neg_samples: int :type num_neg_samples: int
:param neg_distribution: The distribution for generating the random negative labels. :param neg_distribution: The distribution for generating the random negative labels.
...@@ -4863,6 +4868,8 @@ def nce_layer(input, ...@@ -4863,6 +4868,8 @@ def nce_layer(input,
assert isinstance(neg_distribution, collections.Sequence) assert isinstance(neg_distribution, collections.Sequence)
assert len(neg_distribution) == num_classes assert len(neg_distribution) == num_classes
assert sum(neg_distribution) == 1 assert sum(neg_distribution) == 1
if not isinstance(act, BaseActivation):
raise TypeError()
ipts_for_layer = [] ipts_for_layer = []
parents = [] parents = []
...@@ -4884,12 +4891,17 @@ def nce_layer(input, ...@@ -4884,12 +4891,17 @@ def nce_layer(input,
type=LayerType.NCE_LAYER, type=LayerType.NCE_LAYER,
num_classes=num_classes, num_classes=num_classes,
neg_sampling_dist=neg_distribution, neg_sampling_dist=neg_distribution,
active_type=act.name,
num_neg_samples=num_neg_samples, num_neg_samples=num_neg_samples,
inputs=ipts_for_layer, inputs=ipts_for_layer,
bias=ParamAttr.to_bias(bias_attr), bias=ParamAttr.to_bias(bias_attr),
**ExtraLayerAttribute.to_kwargs(layer_attr)) **ExtraLayerAttribute.to_kwargs(layer_attr))
return LayerOutput( return LayerOutput(
name, LayerType.NCE_LAYER, parents=parents, size=l.config.size) name,
LayerType.NCE_LAYER,
parents=parents,
size=l.config.size,
activation=act)
""" """
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册