Commit 43f7d7b7, authored by luotao1, committed by qingqing01

add interface and unittest for nce layer (#180)

* add interface and unittest for nce layer

* follow comments

Parent: e26f220d
@@ -371,6 +371,12 @@ ctc_layer
     :members: ctc_layer
     :noindex:
 
+nce_layer
+-----------
+.. automodule:: paddle.trainer_config_helpers.layers
+    :members: nce_layer
+    :noindex:
+
 hsigmoid
 ---------
 .. automodule:: paddle.trainer_config_helpers.layers
......
@@ -21,14 +21,18 @@ limitations under the License. */
 namespace paddle {
 
 /**
- * Noise-contrastive estimation
+ * Noise-contrastive estimation.
  * Implements the method in the following paper:
- * A fast and simple algorithm for training neural probabilistic language models
+ * A fast and simple algorithm for training neural probabilistic language models.
+ *
+ * The config file api is nce_layer.
  */
 class NCELayer : public Layer {
   int numClasses_;
-  int numInputs_;  // number of input layers besides labelLayer and weightLayer
+  /// number of input layers besides labelLayer and weightLayer
+  int numInputs_;
   LayerPtr labelLayer_;
+  /// weight layer, can be None
   LayerPtr weightLayer_;
   WeightList weights_;
   std::unique_ptr<Weight> biases_;
@@ -43,7 +47,8 @@ class NCELayer : public Layer {
     real weight;
   };
   std::vector<Sample> samples_;
-  bool prepared_;  // whether samples_ is prepared
+  /// whether samples_ is prepared
+  bool prepared_;
 
   Argument sampleOut_;
   IVectorPtr labelIds_;
......
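For reference, the paper cited in the new header comment is Mnih & Teh (2012). NCE replaces the multi-class softmax with a binary classification between data samples and noise samples. A sketch of the per-example objective, in our own notation (the diff does not show whether NCELayer folds the log(k·q(x)) correction into the score, so treat this as background rather than a spec of the implementation):

    J(\theta) = \log \sigma\bigl(\Delta_\theta(w)\bigr)
              + \sum_{i=1}^{k} \log\bigl(1 - \sigma(\Delta_\theta(\tilde{x}_i))\bigr),
    \qquad \Delta_\theta(x) = s_\theta(x) - \log\bigl(k \, q(x)\bigr)

Here w is the true label, the x̃_i are k negative labels drawn from the noise distribution q, s_θ(x) is the model score for class x, and σ is the logistic sigmoid. In the Python API added below, q corresponds to neg_distribution and k to num_neg_samples.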
@@ -13,157 +13,71 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-#Todo(luotao02) This config is only used for unitest. It is out of date now, and will be updated later.
-
-default_initial_std(0.5)
-
-model_type("nn")
-
-DataLayer(
-    name = "input",
-    size = 3,
-)
-
-DataLayer(
-    name = "weight",
-    size = 1,
-)
-
-Layer(
-    name = "layer1_1",
-    type = "fc",
-    size = 5,
-    active_type = "sigmoid",
-    inputs = "input",
-)
-
-Layer(
-    name = "layer1_2",
-    type = "fc",
-    size = 12,
-    active_type = "linear",
-    inputs = Input("input", parameter_name='sharew'),
-)
-
-Layer(
-    name = "layer1_3",
-    type = "fc",
-    size = 3,
-    active_type = "tanh",
-    inputs = "input",
-)
-
-Layer(
-    name = "layer1_5",
-    type = "fc",
-    size = 3,
-    active_type = "tanh",
-    inputs = Input("input",
-                   learning_rate=0.01,
-                   momentum=0.9,
-                   decay_rate=0.05,
-                   initial_mean=0.0,
-                   initial_std=0.01,
-                   format = "csc",
-                   nnz = 4)
-)
-
-FCLayer(
-    name = "layer1_4",
-    size = 5,
-    active_type = "square",
-    inputs = "input",
-    drop_rate = 0.5,
-)
-
-Layer(
-    name = "pool",
-    type = "pool",
-    inputs = Input("layer1_2",
-                   pool = Pool(pool_type="cudnn-avg-pool",
-                               channels = 1,
-                               size_x = 2,
-                               size_y = 3,
-                               img_width = 3,
-                               padding = 1,
-                               padding_y = 2,
-                               stride = 2,
-                               stride_y = 3))
-)
-
-Layer(
-    name = "concat",
-    type = "concat",
-    inputs = ["layer1_3", "layer1_4"],
-)
-
-MixedLayer(
-    name = "output",
-    size = 3,
-    active_type = "softmax",
-    inputs = [
-        FullMatrixProjection("layer1_1",
-                             learning_rate=0.1),
-        TransposedFullMatrixProjection("layer1_2", parameter_name='sharew'),
-        FullMatrixProjection("concat"),
-        IdentityProjection("layer1_3"),
-    ],
-)
-
-Layer(
-    name = "label",
-    type = "data",
-    size = 1,
-)
-
-Layer(
-    name = "cost",
-    type = "multi-class-cross-entropy",
-    inputs = ["output", "label", "weight"],
-)
-
-Layer(
-    name = "cost2",
-    type = "nce",
-    num_classes = 3,
-    active_type = "sigmoid",
-    neg_sampling_dist = [0.1, 0.3, 0.6],
-    inputs = ["layer1_2", "label", "weight"],
-)
-
-Evaluator(
-    name = "error",
-    type = "classification_error",
-    inputs = ["output", "label", "weight"]
-)
-
-Inputs("input", "label", "weight")
-Outputs("cost", "cost2")
-
-TrainData(
-    ProtoData(
-        files = "dummy_list",
-        constant_slots = [1.0],
-        async_load_data = True,
-    )
-)
-
-TestData(
-    SimpleData(
-        files = "trainer/tests/sample_filelist.txt",
-        feat_dim = 3,
-        context_len = 0,
-        buffer_capacity = 1000000,
-        async_load_data = False,
-    ),
-)
-
-Settings(
-    algorithm = "sgd",
-    num_batches_per_send_parameter = 1,
-    num_batches_per_get_parameter = 1,
-    batch_size = 100,
-    learning_rate = 0.001,
-    learning_rate_decay_a = 1e-5,
-    learning_rate_decay_b = 0.5,
-)
+from paddle.trainer_config_helpers import *
+
+TrainData(ProtoData(
+    files = "dummy_list",
+    constant_slots = [1.0],
+    async_load_data = True))
+
+TestData(SimpleData(
+    files = "trainer/tests/sample_filelist.txt",
+    feat_dim = 3,
+    context_len = 0,
+    buffer_capacity = 1000000,
+    async_load_data = False))
+
+settings(batch_size = 100)
+
+data = data_layer(name='input', size=3)
+
+wt = data_layer(name='weight', size=1)
+
+fc1 = fc_layer(input=data, size=5,
+               bias_attr=True,
+               act=SigmoidActivation())
+
+fc2 = fc_layer(input=data, size=12,
+               bias_attr=True,
+               param_attr=ParamAttr(name='sharew'),
+               act=LinearActivation())
+
+fc3 = fc_layer(input=data, size=3,
+               bias_attr=True,
+               act=TanhActivation())
+
+fc4 = fc_layer(input=data, size=5,
+               bias_attr=True,
+               layer_attr=ExtraAttr(drop_rate=0.5),
+               act=SquareActivation())
+
+pool = img_pool_layer(input=fc2,
+                      pool_size=2,
+                      pool_size_y=3,
+                      num_channels=1,
+                      padding=1,
+                      padding_y=2,
+                      stride=2,
+                      stride_y=3,
+                      img_width=3,
+                      pool_type=CudnnAvgPooling())
+
+concat = concat_layer(input=[fc3, fc4])
+
+with mixed_layer(size=3, act=SoftmaxActivation()) as output:
+    output += full_matrix_projection(input=fc1)
+    output += trans_full_matrix_projection(input=fc2,
+                                           param_attr=ParamAttr(name='sharew'))
+    output += full_matrix_projection(input=concat)
+    output += identity_projection(input=fc3)
+
+lbl = data_layer(name='label', size=1)
+
+cost = classification_cost(input=output, label=lbl, weight=wt,
+                           layer_attr=ExtraAttr(device=-1))
+
+nce = nce_layer(input=fc2, label=lbl, weight=wt,
+                num_classes=3,
+                neg_distribution=[0.1, 0.3, 0.6])
+
+outputs(cost, nce)
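The neg_distribution=[0.1, 0.3, 0.6] passed to nce_layer in the rewritten config is the noise distribution q over the three classes; as the layers.py change below shows, nce_layer asserts its length equals num_classes and that it sums to 1. As a hypothetical illustration (numpy, not code from this commit), the sampling the layer performs internally for each data sample looks roughly like:

    import numpy as np

    # Hypothetical sketch, not code from this commit: draw negative
    # class ids from the noise distribution q = neg_distribution.
    rng = np.random.default_rng(seed=0)
    neg_distribution = [0.1, 0.3, 0.6]  # one entry per class, sums to 1
    num_neg_samples = 10                # the nce_layer default

    negatives = rng.choice(len(neg_distribution),
                           size=num_neg_samples,
                           p=neg_distribution)
    print(negatives)  # class ids in {0, 1, 2}, weighted toward class 2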
@@ -50,6 +50,7 @@ __all__ = ["full_matrix_projection", "AggregateLevel", "ExpandLevel",
            'slope_intercept_layer', 'trans_full_matrix_projection',
            'linear_comb_layer',
            'convex_comb_layer', 'ctc_layer', 'crf_layer', 'crf_decoding_layer',
+           'nce_layer',
            'cross_entropy_with_selfnorm', 'cross_entropy',
            'multi_binary_label_cross_entropy',
            'rank_cost', 'lambda_cost', 'huber_cost',
@@ -115,6 +116,7 @@ class LayerType(object):
     CTC_LAYER = "ctc"
     CRF_LAYER = "crf"
     CRF_DECODING_LAYER = "crf_decoding"
+    NCE_LAYER = 'nce'
 
     RANK_COST = "rank-cost"
     LAMBDA_COST = "lambda_cost"
@@ -168,7 +170,7 @@ class LayerOutput(object):
     :param activation: Layer Activation.
     :type activation: BaseActivation.
     :param parents: Layer's parents.
-    :type parents: list|tuple|collection.Sequence
+    :type parents: list|tuple|collections.Sequence
     """
 
     def __init__(self, name, layer_type, parents=None, activation=None,
@@ -1988,10 +1990,16 @@ def concat_layer(input, act=None, name=None, layer_attr=None):
     Concat all input vector into one huge vector.
     Inputs can be list of LayerOutput or list of projection.
 
+    The example usage is:
+
+    .. code-block:: python
+
+        concat = concat_layer(input=[layer1, layer2])
+
     :param name: Layer name.
     :type name: basestring
     :param input: input layers or projections
-    :type input: list|tuple|collection.Sequence
+    :type input: list|tuple|collections.Sequence
     :param act: Activation type.
     :type act: BaseActivation
     :param layer_attr: Extra Layer Attribute.
@@ -3488,6 +3496,83 @@ def crf_decoding_layer(input, size, label=None, param_attr=None, name=None):
         parents.append(label)
     return LayerOutput(name, LayerType.CRF_DECODING_LAYER, parents, size=size)
 
+
+@wrap_bias_attr_default(has_bias=True)
+@wrap_name_default()
+@layer_support()
+def nce_layer(input, label, num_classes, weight=None,
+              num_neg_samples=10, neg_distribution=None,
+              name=None, bias_attr=None, layer_attr=None):
+    """
+    Noise-contrastive estimation.
+    Implements the method in the following paper:
+    A fast and simple algorithm for training neural probabilistic language models.
+
+    The example usage is:
+
+    .. code-block:: python
+
+       cost = nce_layer(input=layer1, label=layer2, weight=layer3,
+                        num_classes=3, neg_distribution=[0.1,0.3,0.6])
+
+    :param name: layer name
+    :type name: basestring
+    :param input: input layers. It could be a LayerOutput or list/tuple of LayerOutput.
+    :type input: LayerOutput|list|tuple|collections.Sequence
+    :param label: label layer
+    :type label: LayerOutput
+    :param weight: weight layer, can be None(default)
+    :type weight: LayerOutput
+    :param num_classes: number of classes.
+    :type num_classes: int
+    :param num_neg_samples: number of negative samples. Default is 10.
+    :type num_neg_samples: int
+    :param neg_distribution: The distribution for generating the random negative labels.
+                             A uniform distribution will be used if not provided.
+                             If not None, its length must be equal to num_classes.
+    :type neg_distribution: list|tuple|collections.Sequence|None
+    :param bias_attr: Bias parameter attribute. False if no bias.
+    :type bias_attr: ParameterAttribute|None|False
+    :param layer_attr: Extra Layer Attribute.
+    :type layer_attr: ExtraLayerAttribute
+    :return: LayerOutput object.
+    :rtype: LayerOutput
+    """
+    if isinstance(input, LayerOutput):
+        input = [input]
+    assert isinstance(input, collections.Sequence)
+    assert isinstance(label, LayerOutput)
+    assert label.layer_type == LayerType.DATA
+    if neg_distribution is not None:
+        assert isinstance(neg_distribution, collections.Sequence)
+        assert len(neg_distribution) == num_classes
+        assert sum(neg_distribution) == 1
+
+    ipts_for_layer = []
+    parents = []
+    for each_input in input:
+        assert isinstance(each_input, LayerOutput)
+        ipts_for_layer.append(each_input.name)
+        parents.append(each_input)
+    ipts_for_layer.append(label.name)
+    parents.append(label)
+
+    if weight is not None:
+        assert isinstance(weight, LayerOutput)
+        assert weight.layer_type == LayerType.DATA
+        ipts_for_layer.append(weight.name)
+        parents.append(weight)
+
+    Layer(
+        name=name,
+        type=LayerType.NCE_LAYER,
+        num_classes=num_classes,
+        neg_sampling_dist=neg_distribution,
+        num_neg_samples=num_neg_samples,
+        inputs=ipts_for_layer,
+        bias=ParamAttr.to_bias(bias_attr),
+        **ExtraLayerAttribute.to_kwargs(layer_attr)
+    )
+    return LayerOutput(name, LayerType.NCE_LAYER, parents=parents)
+
 """
 following are cost Layers.
......
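A minimal end-to-end sketch of the new API, assuming the trainer_config_helpers package as extended by this commit (layer names and sizes here are hypothetical, chosen only for illustration):

    from paddle.trainer_config_helpers import *

    # Hypothetical illustration of the nce_layer API added in this commit.
    feats = data_layer(name='features', size=128)
    lbl = data_layer(name='label', size=1)  # nce_layer requires a data layer label

    hidden = fc_layer(input=feats, size=64, act=TanhActivation())

    # With neg_distribution omitted, a uniform noise distribution over
    # num_classes is used; num_neg_samples defaults to 10.
    cost = nce_layer(input=hidden, label=lbl, num_classes=1000)

    outputs(cost)

Since nce_layer accepts a list of inputs, several feature layers can feed the same NCE cost, mirroring the loop over input in the implementation above.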