# train.py: train an attention-based GRU encoder-decoder on WMT14 (paddle.v2).
import sys
import paddle.v2 as paddle


def seqToseq_net(source_dict_dim,
                 target_dict_dim,
                 word_vector_dim=512,
                 decoder_size=512,
                 encoder_size=512):
    """Build the training topology of an attention-based GRU seq2seq model.

    The encoder runs a forward and a reversed GRU over the source word
    embeddings and concatenates their outputs. The decoder is a GRU step
    inside a ``recurrent_group`` that attends over the encoder output at
    every step and predicts the next target word with a softmax layer.

    Args:
        source_dict_dim: size of the source vocabulary; word ids are read
            from the ``source_language_word`` data layer.
        target_dict_dim: size of the target vocabulary; also the width of
            the decoder's softmax output.
        word_vector_dim: dimension of the source and target word embeddings.
        decoder_size: dimension of the hidden unit of the GRU decoder.
        encoder_size: dimension of the hidden unit of each encoder GRU
            (forward and backward).

    Returns:
        The ``classification_cost`` layer comparing the decoder output with
        the ``target_language_next_word`` data layer.
    """
    #### Encoder
    src_word_id = paddle.layer.data(
        name='source_language_word',
        type=paddle.data_type.integer_value_sequence(source_dict_dim))
    src_embedding = paddle.layer.embedding(
        input=src_word_id,
        size=word_vector_dim,
        param_attr=paddle.attr.ParamAttr(name='_source_language_embedding'))
    # One GRU reads the source left-to-right and one right-to-left; their
    # outputs are concatenated so each position carries both directions.
    src_forward = paddle.networks.simple_gru(
        input=src_embedding, size=encoder_size)
    src_backward = paddle.networks.simple_gru(
        input=src_embedding, size=encoder_size, reverse=True)
    encoded_vector = paddle.layer.concat(input=[src_forward, src_backward])

    #### Decoder
    # Project the encoder output once, outside the recurrent group; it is
    # handed to every decoder step as a static input for the attention.
    with paddle.layer.mixed(size=decoder_size) as encoded_proj:
        encoded_proj += paddle.layer.full_matrix_projection(
            input=encoded_vector)

    # The decoder's initial state is derived from the first position of the
    # backward GRU output (with reverse=True this should summarize the whole
    # source sentence), mapped to decoder_size through a tanh layer.
    backward_first = paddle.layer.first_seq(input=src_backward)

    with paddle.layer.mixed(
            size=decoder_size, act=paddle.activation.Tanh()) as decoder_boot:
        decoder_boot += paddle.layer.full_matrix_projection(
            input=backward_first)

    def gru_decoder_with_attention(enc_vec, enc_proj, current_word):
        """One decoder step: attend over the source, update the GRU, predict.

        Args:
            enc_vec: encoder output sequence (static input of the group).
            enc_proj: projected encoder output (static input of the group).
            current_word: embedding of the current target word.

        Returns:
            Softmax layer over the target vocabulary.
        """
        # The memory shares its name with the gru_step layer below, which is
        # how it is tied to that layer's previous-step output; it starts from
        # decoder_boot.
        decoder_mem = paddle.layer.memory(
            name='gru_decoder', size=decoder_size, boot_layer=decoder_boot)

        context = paddle.networks.simple_attention(
            encoded_sequence=enc_vec,
            encoded_proj=enc_proj,
            decoder_state=decoder_mem)

        # gru_step takes its input pre-projected to 3 * decoder_size
        # (presumably one slice per GRU gate / candidate state).
        with paddle.layer.mixed(size=decoder_size * 3) as decoder_inputs:
            decoder_inputs += paddle.layer.full_matrix_projection(input=context)
            decoder_inputs += paddle.layer.full_matrix_projection(
                input=current_word)

        gru_step = paddle.layer.gru_step(
            name='gru_decoder',
            input=decoder_inputs,
            output_mem=decoder_mem,
            size=decoder_size)

        with paddle.layer.mixed(
                size=target_dict_dim,
                bias_attr=True,
                act=paddle.activation.Softmax()) as out:
            out += paddle.layer.full_matrix_projection(input=gru_step)
        return out

    decoder_group_name = "decoder_group"
    group_input1 = paddle.layer.StaticInputV2(input=encoded_vector, is_seq=True)
    group_input2 = paddle.layer.StaticInputV2(input=encoded_proj, is_seq=True)
    group_inputs = [group_input1, group_input2]

    trg_embedding = paddle.layer.embedding(
        input=paddle.layer.data(
            name='target_language_word',
            type=paddle.data_type.integer_value_sequence(target_dict_dim)),
        size=word_vector_dim,
        param_attr=paddle.attr.ParamAttr(name='_target_language_embedding'))
    group_inputs.append(trg_embedding)

    # For a decoder equipped with an attention mechanism, during training
    # the target embedding (the ground truth) is the data input, while the
    # encoded source sequence is accessed as an unbounded memory.
    # Here, StaticInput defines a read-only memory for the recurrent_group.
    # The order of group_inputs must match the parameters of
    # gru_decoder_with_attention: (enc_vec, enc_proj, current_word).
    decoder = paddle.layer.recurrent_group(
        name=decoder_group_name,
        step=gru_decoder_with_attention,
        input=group_inputs)

    lbl = paddle.layer.data(
        name='target_language_next_word',
        type=paddle.data_type.integer_value_sequence(target_dict_dim))
    cost = paddle.layer.classification_cost(input=decoder, label=lbl)

    return cost


def main():
    """Train the seq2seq model on the WMT14 training set with paddle.v2.

    Runs on CPU with one trainer for two passes. It uses Adam with L2
    regularization and shuffled mini-batches of 5 samples. Progress is
    printed to stdout.
    """
    paddle.init(use_gpu=False, trainer_count=1)

    # source and target dict dim.
    dict_size = 30000
    source_dict_dim = target_dict_dim = dict_size

    # define network topology
    cost = seqToseq_net(source_dict_dim, target_dict_dim)
    parameters = paddle.parameters.create(cost)

    # define optimize method and trainer
    optimizer = paddle.optimizer.Adam(
        learning_rate=5e-5,
        regularization=paddle.optimizer.L2Regularization(rate=8e-4))
    trainer = paddle.trainer.SGD(
        cost=cost, parameters=parameters, update_equation=optimizer)

    # define data reader
    # Maps each data layer name to its field index in a reader sample.
    # NOTE(review): assumes wmt14 samples are ordered
    # (source ids, target ids, next-target ids) -- confirm against
    # paddle.dataset.wmt14.
    feeding = {
        'source_language_word': 0,
        'target_language_word': 1,
        'target_language_next_word': 2
    }

    wmt14_reader = paddle.batch(
        paddle.reader.shuffle(
            paddle.dataset.wmt14.train(dict_size=dict_size), buf_size=8192),
        batch_size=5)

    # define event_handler callback
    def event_handler(event):
        """Print cost/metrics every 10th batch and a progress dot otherwise."""
        if isinstance(event, paddle.event.EndIteration):
            if event.batch_id % 10 == 0:
                # A single parenthesized argument prints the same text under
                # Python 2's print statement and Python 3's print().
                print("\nPass %d, Batch %d, Cost %f, %s" % (
                    event.pass_id, event.batch_id, event.cost, event.metrics))
            else:
                sys.stdout.write('.')
                sys.stdout.flush()

    # start to train
    trainer.train(
        reader=wmt14_reader,
        event_handler=event_handler,
        num_passes=2,
        feeding=feeding)


if __name__ == '__main__':
    # Script entry point: start training only when run directly,
    # not when this module is imported.
    main()