Commit 612cd3f3 authored by Helin Wang

replace old code with new code

Parent 97fee5d0
#!/bin/bash
set -e
# Download Mikolov's PTB "simple-examples" dataset and unpack it.
wget http://www.fit.vutbr.cz/~imikolov/rnnlm/simple-examples.tgz
tar -zxf simple-examples.tgz
# Build the file lists the trainer reads: the PTB train split for training
# and the PTB valid split as the test list.
echo `pwd`/simple-examples/data/ptb.train.txt > train.list
echo `pwd`/simple-examples/data/ptb.valid.txt > test.list
# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from paddle.trainer.PyDataProvider2 import *
import collections
import logging

logging.basicConfig(
    format='[%(levelname)s %(asctime)s %(filename)s:%(lineno)s] %(message)s', )
logger = logging.getLogger('paddle')
logger.setLevel(logging.INFO)
N = 5  # build 5-grams: four context words predict the fifth
cutoff = 50  # only words with frequency > cutoff enter the dictionary
def build_dict(ftrain, fdict):
    """Build the word dictionary from the training text, write the words to
    fdict, and return a word -> index mapping."""
    sentences = []
    with open(ftrain) as fin:
        for line in fin:
            line = ['<s>'] + line.strip().split() + ['<e>']
            sentences += line
    wordfreq = collections.Counter(sentences)
    # Keep only words above the frequency cutoff, most frequent first.
    wordfreq = filter(lambda x: x[1] > cutoff, wordfreq.items())
    dictionary = sorted(wordfreq, key=lambda x: (-x[1], x[0]))
    words, _ = list(zip(*dictionary))
    for word in words:
        print >> fdict, word
    word_idx = dict(zip(words, xrange(len(words))))
    logger.info("Dictionary size=%s" % len(words))
    return word_idx
def initializer(settings, srcText, dictfile, **xargs):
    with open(dictfile, 'w') as fdict:
        settings.dicts = build_dict(srcText, fdict)
    # Each of the N slots is an integer word id over the same dictionary.
    input_types = []
    for i in xrange(N):
        input_types.append(integer_value(len(settings.dicts)))
    settings.input_types = input_types
@provider(init_hook=initializer)
def process(settings, filename):
    UNKID = settings.dicts['<unk>']
    with open(filename) as fin:
        for line in fin:
            # Pad the sentence start so the first real word already has
            # N - 1 context words, and map OOV words to <unk>.
            line = ['<s>'] * (N - 1) + line.strip().split() + ['<e>']
            line = [settings.dicts.get(w, UNKID) for w in line]
            for i in range(N, len(line) + 1):
                yield line[i - N:i]
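For reference, a minimal standalone sketch (not part of the commit; the toy dictionary below is made up) of the sliding 5-gram window that process() yields:

```python
# Toy stand-in for the dictionary build_dict() derives from ptb.train.txt.
N = 5
dicts = {'<s>': 0, '<e>': 1, '<unk>': 2, 'the': 3, 'cat': 4, 'sat': 5}
UNKID = dicts['<unk>']

line = ['<s>'] * (N - 1) + 'the cat sat'.split() + ['<e>']
line = [dicts.get(w, UNKID) for w in line]
for i in range(N, len(line) + 1):
    print line[i - N:i]
# [0, 0, 0, 0, 3]
# [0, 0, 0, 3, 4]
# [0, 0, 3, 4, 5]
# [0, 3, 4, 5, 1]
```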
# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from paddle.trainer_config_helpers import *
import math
#################### Data Configuration ####################
args = {
    'srcText': 'data/simple-examples/data/ptb.train.txt',
    'dictfile': 'data/vocabulary.txt'
}
define_py_data_sources2(
    train_list="data/train.list",
    test_list="data/test.list",
    module="dataprovider",
    obj="process",
    args=args)

settings(
    batch_size=100,
    regularization=L2Regularization(8e-4),
    learning_rate=3e-3)
dictsize = 1953  # must match the dictionary size the data provider builds
embsize = 32
hiddensize = 256
firstword = data_layer(name="firstw", size=dictsize)
secondword = data_layer(name="secondw", size=dictsize)
thirdword = data_layer(name="thirdw", size=dictsize)
fourthword = data_layer(name="fourthw", size=dictsize)
nextword = data_layer(name="fifthw", size=dictsize)
# construct a word embedding for each data layer; the shared parameter
# name "_proj" makes all four projections use one embedding table
def wordemb(inlayer):
    wordemb = table_projection(
        input=inlayer,
        size=embsize,
        param_attr=ParamAttr(
            name="_proj",
            initial_std=0.001,
            learning_rate=1,
            l2_rate=0, ))
    return wordemb
Efirst = wordemb(firstword)
Esecond = wordemb(secondword)
Ethird = wordemb(thirdword)
Efourth = wordemb(fourthword)
# concatenate the Ngram embeddings into one context embedding
contextemb = concat_layer(input=[Efirst, Esecond, Ethird, Efourth])
hidden1 = fc_layer(
    input=contextemb,
    size=hiddensize,
    act=SigmoidActivation(),
    layer_attr=ExtraAttr(drop_rate=0.5),
    bias_attr=ParamAttr(learning_rate=2),
    param_attr=ParamAttr(
        initial_std=1. / math.sqrt(embsize * 8), learning_rate=1))

# use the context embedding to predict the next word
predictword = fc_layer(
    input=hidden1,
    size=dictsize,
    bias_attr=ParamAttr(learning_rate=2),
    act=SoftmaxActivation())

cost = classification_cost(input=predictword, label=nextword)
# network output
outputs(cost)
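To make the shapes concrete: the four 32-d embeddings are concatenated into a 128-d context vector, passed through a 256-d sigmoid hidden layer, and projected to a softmax over the 1953-word dictionary. Below is a minimal numpy sketch of that forward pass (an illustration only, not part of the commit; the random initialization merely mimics the ParamAttr settings above):

```python
import numpy as np

dictsize, embsize, hiddensize = 1953, 32, 256
emb = np.random.randn(dictsize, embsize) * 0.001        # shared "_proj" table
W1 = np.random.randn(4 * embsize, hiddensize) / np.sqrt(embsize * 8)
W2 = np.random.randn(hiddensize, dictsize) * 0.01

context = [7, 42, 3, 99]                                # four context word ids
contextemb = np.concatenate([emb[w] for w in context])  # (128,)
hidden1 = 1.0 / (1.0 + np.exp(-contextemb.dot(W1)))     # sigmoid, (256,)
logits = hidden1.dot(W2)
probs = np.exp(logits - logits.max())
probs /= probs.sum()                                    # softmax over dictsize
```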
import math

import paddle.v2 as paddle

embsize = 32
hiddensize = 256
N = 5


def wordemb(inlayer):
    wordemb = paddle.layer.table_projection(
        input=inlayer,
        size=embsize,
        param_attr=paddle.attr.Param(
            name="_proj",
            initial_std=0.001,
            learning_rate=1,
            l2_rate=0, ))
    return wordemb
def main():
    paddle.init(use_gpu=False, trainer_count=1)
    word_dict = paddle.dataset.imikolov.build_dict()
    dict_size = len(word_dict)
    firstword = paddle.layer.data(
        name="firstw", type=paddle.data_type.integer_value(dict_size))
    secondword = paddle.layer.data(
        name="secondw", type=paddle.data_type.integer_value(dict_size))
    thirdword = paddle.layer.data(
        name="thirdw", type=paddle.data_type.integer_value(dict_size))
    fourthword = paddle.layer.data(
        name="fourthw", type=paddle.data_type.integer_value(dict_size))
    nextword = paddle.layer.data(
        name="fifthw", type=paddle.data_type.integer_value(dict_size))

    Efirst = wordemb(firstword)
    Esecond = wordemb(secondword)
    Ethird = wordemb(thirdword)
    Efourth = wordemb(fourthword)

    contextemb = paddle.layer.concat(input=[Efirst, Esecond, Ethird, Efourth])
    hidden1 = paddle.layer.fc(input=contextemb,
                              size=hiddensize,
                              act=paddle.activation.Sigmoid(),
                              layer_attr=paddle.attr.Extra(drop_rate=0.5),
                              bias_attr=paddle.attr.Param(learning_rate=2),
                              param_attr=paddle.attr.Param(
                                  initial_std=1. / math.sqrt(embsize * 8),
                                  learning_rate=1))
    predictword = paddle.layer.fc(input=hidden1,
                                  size=dict_size,
                                  bias_attr=paddle.attr.Param(learning_rate=2),
                                  act=paddle.activation.Softmax())

    def event_handler(event):
        # Periodically evaluate on the test set while training.
        if isinstance(event, paddle.event.EndIteration):
            if event.batch_id % 100 == 0:
                result = trainer.test(
                    paddle.batch(
                        paddle.dataset.imikolov.test(word_dict, N), 32))
                print "Pass %d, Batch %d, Cost %f, %s, Testing metrics %s" % (
                    event.pass_id, event.batch_id, event.cost, event.metrics,
                    result.metrics)

    cost = paddle.layer.classification_cost(input=predictword, label=nextword)
    parameters = paddle.parameters.create(cost)
    adam_optimizer = paddle.optimizer.Adam(
        learning_rate=3e-3,
        regularization=paddle.optimizer.L2Regularization(8e-4))
    trainer = paddle.trainer.SGD(cost, parameters, adam_optimizer)
    trainer.train(
        paddle.batch(paddle.dataset.imikolov.train(word_dict, N), 32),
        num_passes=30,
        event_handler=event_handler)


if __name__ == '__main__':
    main()
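Unlike the old config, the v2 script gets its data from the built-in paddle.dataset.imikolov package instead of a custom data provider. A quick sketch (not part of the commit; assumes paddle.v2 is installed) to inspect what that reader feeds the five data layers:

```python
import paddle.v2 as paddle

word_dict = paddle.dataset.imikolov.build_dict()
reader = paddle.dataset.imikolov.train(word_dict, 5)  # a reader creator
sample = next(reader())
# Expect a 5-tuple of integer word ids: four context words plus the label.
print sample
```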
#!/bin/bash
set -e
# Launch the old-style trainer against the ngram.py config: print a
# progress dot every 100 batches, write a log line every 3000 batches,
# save checkpoints under ./model, and train for 30 passes.
paddle train \
       --config ngram.py \
       --use_gpu=1 \
       --dot_period=100 \
       --log_period=3000 \
       --test_period=0 \
       --save_dir=model \
       --num_passes=30