From 43f7d7b7684b8c4cee4e396f57c4c841f41b2dbe Mon Sep 17 00:00:00 2001 From: luotao1 Date: Thu, 13 Oct 2016 15:11:52 +0800 Subject: [PATCH] add interface and unittest for nce layer (#180) * add interface and unittest for nce layer * follow comments --- doc/ui/api/trainer_config_helpers/layers.rst | 6 + paddle/gserver/layers/NCELayer.cpp | 13 +- paddle/trainer/tests/test_config.conf | 222 ++++++------------ .../paddle/trainer_config_helpers/layers.py | 89 ++++++- 4 files changed, 170 insertions(+), 160 deletions(-) diff --git a/doc/ui/api/trainer_config_helpers/layers.rst b/doc/ui/api/trainer_config_helpers/layers.rst index c1d7a7ce81..5271262d20 100644 --- a/doc/ui/api/trainer_config_helpers/layers.rst +++ b/doc/ui/api/trainer_config_helpers/layers.rst @@ -371,6 +371,12 @@ ctc_layer :members: ctc_layer :noindex: +nce_layer +----------- +.. automodule:: paddle.trainer_config_helpers.layers + :members: nce_layer + :noindex: + hsigmoid --------- .. automodule:: paddle.trainer_config_helpers.layers diff --git a/paddle/gserver/layers/NCELayer.cpp b/paddle/gserver/layers/NCELayer.cpp index a896e16a60..4faebe5d2a 100644 --- a/paddle/gserver/layers/NCELayer.cpp +++ b/paddle/gserver/layers/NCELayer.cpp @@ -21,14 +21,18 @@ limitations under the License. */ namespace paddle { /** - * Noise-contrastive estimation + * Noise-contrastive estimation. * Implements the method in the following paper: - * A fast and simple algorithm for training neural probabilistic language models + * A fast and simple algorithm for training neural probabilistic language models. + * + * The config file api is nce_layer. */ class NCELayer : public Layer { int numClasses_; - int numInputs_; // number of input layer besides labelLayer and weightLayer + /// number of input layer besides labelLayer and weightLayer + int numInputs_; LayerPtr labelLayer_; + /// weight layer, can be None LayerPtr weightLayer_; WeightList weights_; std::unique_ptr biases_; @@ -43,7 +47,8 @@ class NCELayer : public Layer { real weight; }; std::vector samples_; - bool prepared_; // whether samples_ is prepared + /// whether samples_ is prepared + bool prepared_; Argument sampleOut_; IVectorPtr labelIds_; diff --git a/paddle/trainer/tests/test_config.conf b/paddle/trainer/tests/test_config.conf index 5d2e2ba9df..664e18cb98 100644 --- a/paddle/trainer/tests/test_config.conf +++ b/paddle/trainer/tests/test_config.conf @@ -13,157 +13,71 @@ # See the License for the specific language governing permissions and # limitations under the License. -#Todo(luotao02) This config is only used for unitest. It is out of date now, and will be updated later. - -default_initial_std(0.5) - -model_type("nn") - -DataLayer( - name = "input", - size = 3, -) - -DataLayer( - name = "weight", - size = 1, -) - -Layer( - name = "layer1_1", - type = "fc", - size = 5, - active_type = "sigmoid", - inputs = "input", -) - -Layer( - name = "layer1_2", - type = "fc", - size = 12, - active_type = "linear", - inputs = Input("input", parameter_name='sharew'), -) - -Layer( - name = "layer1_3", - type = "fc", - size = 3, - active_type = "tanh", - inputs = "input", -) - -Layer( - name = "layer1_5", - type = "fc", - size = 3, - active_type = "tanh", - inputs = Input("input", - learning_rate=0.01, - momentum=0.9, - decay_rate=0.05, - initial_mean=0.0, - initial_std=0.01, - format = "csc", - nnz = 4) -) - -FCLayer( - name = "layer1_4", - size = 5, - active_type = "square", - inputs = "input", - drop_rate = 0.5, -) - -Layer( - name = "pool", - type = "pool", - inputs = Input("layer1_2", - pool = Pool(pool_type="cudnn-avg-pool", - channels = 1, - size_x = 2, - size_y = 3, - img_width = 3, - padding = 1, - padding_y = 2, - stride = 2, - stride_y = 3)) -) - -Layer( - name = "concat", - type = "concat", - inputs = ["layer1_3", "layer1_4"], -) - -MixedLayer( - name = "output", - size = 3, - active_type = "softmax", - inputs = [ - FullMatrixProjection("layer1_1", - learning_rate=0.1), - TransposedFullMatrixProjection("layer1_2", parameter_name='sharew'), - FullMatrixProjection("concat"), - IdentityProjection("layer1_3"), - ], -) - -Layer( - name = "label", - type = "data", - size = 1, -) - -Layer( - name = "cost", - type = "multi-class-cross-entropy", - inputs = ["output", "label", "weight"], -) - -Layer( - name = "cost2", - type = "nce", - num_classes = 3, - active_type = "sigmoid", - neg_sampling_dist = [0.1, 0.3, 0.6], - inputs = ["layer1_2", "label", "weight"], -) - -Evaluator( - name = "error", - type = "classification_error", - inputs = ["output", "label", "weight"] -) - -Inputs("input", "label", "weight") -Outputs("cost", "cost2") - -TrainData( - ProtoData( - files = "dummy_list", - constant_slots = [1.0], - async_load_data = True, - ) -) - -TestData( - SimpleData( - files = "trainer/tests/sample_filelist.txt", - feat_dim = 3, - context_len = 0, - buffer_capacity = 1000000, - async_load_data = False, - ), -) - -Settings( - algorithm = "sgd", - num_batches_per_send_parameter = 1, - num_batches_per_get_parameter = 1, - batch_size = 100, - learning_rate = 0.001, - learning_rate_decay_a = 1e-5, - learning_rate_decay_b = 0.5, -) +from paddle.trainer_config_helpers import * + +TrainData(ProtoData( + files = "dummy_list", + constant_slots = [1.0], + async_load_data = True)) + +TestData(SimpleData( + files = "trainer/tests/sample_filelist.txt", + feat_dim = 3, + context_len = 0, + buffer_capacity = 1000000, + async_load_data = False)) + +settings(batch_size = 100) + +data = data_layer(name='input', size=3) + +wt = data_layer(name='weight', size=1) + +fc1 = fc_layer(input=data, size=5, + bias_attr=True, + act=SigmoidActivation()) + +fc2 = fc_layer(input=data, size=12, + bias_attr=True, + param_attr=ParamAttr(name='sharew'), + act=LinearActivation()) + +fc3 = fc_layer(input=data, size=3, + bias_attr=True, + act=TanhActivation()) + +fc4 = fc_layer(input=data, size=5, + bias_attr=True, + layer_attr=ExtraAttr(drop_rate=0.5), + act=SquareActivation()) + +pool = img_pool_layer(input=fc2, + pool_size=2, + pool_size_y=3, + num_channels=1, + padding=1, + padding_y=2, + stride=2, + stride_y=3, + img_width=3, + pool_type=CudnnAvgPooling()) + +concat = concat_layer(input=[fc3, fc4]) + +with mixed_layer(size=3, act=SoftmaxActivation()) as output: + output += full_matrix_projection(input=fc1) + output += trans_full_matrix_projection(input=fc2, + param_attr=ParamAttr(name='sharew')) + output += full_matrix_projection(input=concat) + output += identity_projection(input=fc3) + +lbl = data_layer(name='label', size=1) + +cost = classification_cost(input=output, label=lbl, weight=wt, + layer_attr=ExtraAttr(device=-1)) + +nce = nce_layer(input=fc2, label=lbl, weight=wt, + num_classes=3, + neg_distribution=[0.1, 0.3, 0.6]) + +outputs(cost, nce) diff --git a/python/paddle/trainer_config_helpers/layers.py b/python/paddle/trainer_config_helpers/layers.py index 7699c90db7..745e61b2eb 100644 --- a/python/paddle/trainer_config_helpers/layers.py +++ b/python/paddle/trainer_config_helpers/layers.py @@ -50,6 +50,7 @@ __all__ = ["full_matrix_projection", "AggregateLevel", "ExpandLevel", 'slope_intercept_layer', 'trans_full_matrix_projection', 'linear_comb_layer', 'convex_comb_layer', 'ctc_layer', 'crf_layer', 'crf_decoding_layer', + 'nce_layer', 'cross_entropy_with_selfnorm', 'cross_entropy', 'multi_binary_label_cross_entropy', 'rank_cost', 'lambda_cost', 'huber_cost', @@ -115,6 +116,7 @@ class LayerType(object): CTC_LAYER = "ctc" CRF_LAYER = "crf" CRF_DECODING_LAYER = "crf_decoding" + NCE_LAYER = 'nce' RANK_COST = "rank-cost" LAMBDA_COST = "lambda_cost" @@ -168,7 +170,7 @@ class LayerOutput(object): :param activation: Layer Activation. :type activation: BaseActivation. :param parents: Layer's parents. - :type parents: list|tuple|collection.Sequence + :type parents: list|tuple|collections.Sequence """ def __init__(self, name, layer_type, parents=None, activation=None, @@ -1988,10 +1990,16 @@ def concat_layer(input, act=None, name=None, layer_attr=None): Concat all input vector into one huge vector. Inputs can be list of LayerOutput or list of projection. + The example usage is: + + .. code-block:: python + + concat = concat_layer(input=[layer1, layer2]) + :param name: Layer name. :type name: basestring :param input: input layers or projections - :type input: list|tuple|collection.Sequence + :type input: list|tuple|collections.Sequence :param act: Activation type. :type act: BaseActivation :param layer_attr: Extra Layer Attribute. @@ -3488,6 +3496,83 @@ def crf_decoding_layer(input, size, label=None, param_attr=None, name=None): parents.append(label) return LayerOutput(name, LayerType.CRF_DECODING_LAYER, parents, size=size) +@wrap_bias_attr_default(has_bias=True) +@wrap_name_default() +@layer_support() +def nce_layer(input, label, num_classes, weight=None, + num_neg_samples=10, neg_distribution=None, + name=None, bias_attr=None, layer_attr=None): + """ + Noise-contrastive estimation. + Implements the method in the following paper: + A fast and simple algorithm for training neural probabilistic language models. + + The example usage is: + + .. code-block:: python + + cost = nce_layer(input=layer1, label=layer2, weight=layer3, + num_classes=3, neg_distribution=[0.1,0.3,0.6]) + + :param name: layer name + :type name: basestring + :param input: input layers. It could be a LayerOutput of list/tuple of LayerOutput. + :type input: LayerOutput|list|tuple|collections.Sequence + :param label: label layer + :type label: LayerOutput + :param weight: weight layer, can be None(default) + :type weight: LayerOutput + :param num_classes: number of classes. + :type num_classes: int + :param num_neg_samples: number of negative samples. Default is 10. + :type num_neg_samples: int + :param neg_distribution: The distribution for generating the random negative labels. + A uniform distribution will be used if not provided. + If not None, its length must be equal to num_classes. + :type neg_distribution: list|tuple|collections.Sequence|None + :param bias_attr: Bias parameter attribute. True if no bias. + :type bias_attr: ParameterAttribute|None|False + :param layer_attr: Extra Layer Attribute. + :type layer_attr: ExtraLayerAttribute + :return: layer name. + :rtype: LayerOutput + """ + if isinstance(input, LayerOutput): + input = [input] + assert isinstance(input, collections.Sequence) + assert isinstance(label, LayerOutput) + assert label.layer_type == LayerType.DATA + if neg_distribution is not None: + assert isinstance(neg_distribution, collections.Sequence) + assert len(neg_distribution) == num_classes + assert sum(neg_distribution) == 1 + + ipts_for_layer = [] + parents = [] + for each_input in input: + assert isinstance(each_input, LayerOutput) + ipts_for_layer.append(each_input.name) + parents.append(each_input) + ipts_for_layer.append(label.name) + parents.append(label) + + if weight is not None: + assert isinstance(weight, LayerOutput) + assert weight.layer_type == LayerType.DATA + ipts_for_layer.append(weight.name) + parents.append(weight) + + Layer( + name=name, + type=LayerType.NCE_LAYER, + num_classes=num_classes, + neg_sampling_dist=neg_distribution, + num_neg_samples=num_neg_samples, + inputs=ipts_for_layer, + bias=ParamAttr.to_bias(bias_attr), + **ExtraLayerAttribute.to_kwargs(layer_attr) + ) + return LayerOutput(name, LayerType.NCE_LAYER, parents=parents) """ following are cost Layers. -- GitLab