train.py 2.2 KB
Newer Older
C
caoying03 已提交
1
#!/usr/bin/env python
2 3 4
import gzip
import logging
import os

import paddle.v2 as paddle
C
caoying03 已提交
5

6
from network_conf import seq2seq_net
C
caoying03 已提交
7

8 9
# Raise the shared "paddle" logger to INFO so the per-batch progress lines
# emitted during training are not filtered out.
logging.getLogger("paddle").setLevel(logging.INFO)
logger = logging.getLogger("paddle")
C
caoying03 已提交
10

11 12

def train(save_dir_path, source_dict_dim, target_dict_dim):
    '''
    Train the NMT network built by ``seq2seq_net`` on the WMT14 training set.

    Model parameters are written as gzip'd tar archives into ``save_dir_path``
    every 100 batches, and the cost/metrics are logged every 10 batches.

    :param save_dir_path: path of the directory to save the trained models.
    :type save_dir_path: str
    :param source_dict_dim: size of source dictionary
    :type source_dict_dim: int
    :param target_dict_dim: size of target dictionary
    :type target_dict_dim: int
    '''
    # os.makedirs (unlike os.mkdir) also creates missing parent directories,
    # so nested paths such as "out/models" work.
    if not os.path.exists(save_dir_path):
        os.makedirs(save_dir_path)

    # initialize PaddlePaddle
    paddle.init(use_gpu=False, trainer_count=1)

    cost = seq2seq_net(source_dict_dim, target_dict_dim)
    parameters = paddle.parameters.create(cost)

    # define optimization method and the trainer instance
    optimizer = paddle.optimizer.RMSProp(
        learning_rate=1e-3,
        gradient_clipping_threshold=10.0,
        regularization=paddle.optimizer.L2Regularization(rate=8e-4))
    trainer = paddle.trainer.SGD(
        cost=cost, parameters=parameters, update_equation=optimizer)

    # define data reader: shuffle with a buffer of 8192, then batch by 8
    wmt14_reader = paddle.batch(
        paddle.reader.shuffle(
            paddle.dataset.wmt14.train(source_dict_dim), buf_size=8192),
        batch_size=8)

    # define the event_handler callback
    def event_handler(event):
        if isinstance(event, paddle.event.EndIteration):
            # Checkpoint every 100 batches; batch 0 is skipped.
            if event.batch_id and not event.batch_id % 100:
                # Both ids must be given to % as one tuple; passing them as
                # separate arguments to os.path.join raised a TypeError.
                model_path = os.path.join(
                    save_dir_path,
                    "nmt_without_att_%05d_batch_%05d.tar.gz" %
                    (event.pass_id, event.batch_id))
                with gzip.open(model_path, "w") as f:
                    parameters.to_tar(f)

            # Report progress every 10 batches; batch 0 is skipped.
            if event.batch_id and not event.batch_id % 10:
                logger.info("Pass %d, Batch %d, Cost %f, %s" % (
                    event.pass_id, event.batch_id, event.cost, event.metrics))

    # start training
    trainer.train(
        reader=wmt14_reader, event_handler=event_handler, num_passes=2)


if __name__ == '__main__':
    # Script entry point: train with 3000-entry source/target dictionaries
    # and write checkpoints under ./models.
    train(save_dir_path="models", source_dict_dim=3000, target_dict_dim=3000)