# api_train_v2.py
import sys

import paddle.v2 as paddle


def seqToseq_net(source_dict_dim, target_dict_dim, is_generating=False):
    ### Network Architecture
    word_vector_dim = 512  # dimension of word vector
    decoder_size = 512  # dimension of hidden unit in GRU Decoder network
    encoder_size = 512  # dimension of hidden unit in GRU Encoder network

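    # These are only used in generation (is_generating=True): beam search
    # keeps the beam_size best partial translations at each step and stops
    # extending a hypothesis after max_length generated words.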
    beam_size = 3
    max_length = 250

    #### Encoder
    src_word_id = paddle.layer.data(
        name='source_language_word',
        type=paddle.data_type.integer_value_sequence(source_dict_dim))
    src_embedding = paddle.layer.embedding(
        input=src_word_id,
        size=word_vector_dim,
        param_attr=paddle.attr.ParamAttr(name='_source_language_embedding'))
    src_forward = paddle.networks.simple_gru(
        name='src_forward_gru', input=src_embedding, size=encoder_size)
    src_backward = paddle.networks.simple_gru(
        name='src_backward_gru',
        input=src_embedding,
        size=encoder_size,
        reverse=True)
    encoded_vector = paddle.layer.concat(input=[src_forward, src_backward])
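    # Each time step of encoded_vector now concatenates the forward and
    # backward GRU states (2 * encoder_size = 1024 values), so every
    # source position carries context from both directions.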

    #### Decoder
    with paddle.layer.mixed(size=decoder_size) as encoded_proj:
        encoded_proj += paddle.layer.full_matrix_projection(
            input=encoded_vector)

    backward_first = paddle.layer.first_seq(input=src_backward)

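    # The first time step of the backward GRU has already read the whole
    # source sentence (right to left), so a tanh projection of it serves
    # as the initial hidden state ("boot") of the decoder.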
    with paddle.layer.mixed(
            name="decoder_boot_mixed",
            size=decoder_size,
            act=paddle.activation.Tanh()) as decoder_boot:
        decoder_boot += paddle.layer.full_matrix_projection(
            input=backward_first)

    def gru_decoder_with_attention(enc_vec, enc_proj, current_word):

        decoder_mem = paddle.layer.memory(
            name='gru_decoder', size=decoder_size, boot_layer=decoder_boot)

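        # simple_attention scores every encoder position against the
        # current decoder state and returns the attention-weighted sum
        # of the encoded sequence as the context vector for this step.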
        context = paddle.networks.simple_attention(
            name="simple_attention",
            encoded_sequence=enc_vec,
            encoded_proj=enc_proj,
            decoder_state=decoder_mem)

        with paddle.layer.mixed(
                name="input_recurrent",
                size=decoder_size * 3,
                # enable error clipping: gradients backpropagated
                # through this layer are clipped to [-100, 100]
                layer_attr=paddle.attr.ExtraAttr(
                    error_clipping_threshold=100.0)) as decoder_inputs:
            decoder_inputs += paddle.layer.full_matrix_projection(input=context)
            decoder_inputs += paddle.layer.full_matrix_projection(
                input=current_word)

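        # gru_step shares the name 'gru_decoder' with the memory layer
        # above, so the state produced at step t is what decoder_mem
        # reads back at step t + 1.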
        gru_step = paddle.layer.gru_step(
            name='gru_decoder',
            input=decoder_inputs,
            output_mem=decoder_mem,
            # uncomment to enable local threshold for gradient clipping
            # param_attr=paddle.attr.ParamAttr(gradient_clipping_threshold=9.9),
            size=decoder_size)

        with paddle.layer.mixed(
                name="gru_step_output",
                size=target_dict_dim,
                bias_attr=True,
                act=paddle.activation.Softmax()) as out:
            out += paddle.layer.full_matrix_projection(input=gru_step)
        return out

    decoder_group_name = "decoder_group"
    group_input1 = paddle.layer.StaticInputV2(input=encoded_vector, is_seq=True)
    group_input2 = paddle.layer.StaticInputV2(input=encoded_proj, is_seq=True)
    group_inputs = [group_input1, group_input2]

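    # The training decoder (recurrent_group) and the generation decoder
    # (beam_search) below use the same group name and step function, so
    # the parameters learned in training are reused for generation.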
    if not is_generating:
        trg_embedding = paddle.layer.embedding(
            input=paddle.layer.data(
                name='target_language_word',
                type=paddle.data_type.integer_value_sequence(target_dict_dim)),
            size=word_vector_dim,
            param_attr=paddle.attr.ParamAttr(name='_target_language_embedding'))
        group_inputs.append(trg_embedding)

        # For a decoder equipped with an attention mechanism, the target
        # embedding (the ground truth) is the data input during training,
        # while the encoded source sequence is accessed as an unbounded
        # memory. Here, the StaticInput defines a read-only memory
        # for the recurrent_group.
        decoder = paddle.layer.recurrent_group(
            name=decoder_group_name,
            step=gru_decoder_with_attention,
            input=group_inputs)

        lbl = paddle.layer.data(
            name='target_language_next_word',
            type=paddle.data_type.integer_value_sequence(target_dict_dim))
        cost = paddle.layer.classification_cost(input=decoder, label=lbl)

        return cost
    else:
        # In generation, the decoder predicts the next target word based on
        # the encoded source sequence and the last generated target word.

        # The encoded source sequence (the encoder's output) must be passed
        # in through StaticInput, which acts as a read-only memory.
        # The embedding of the last generated word is fetched automatically
        # by GeneratedInput, which is initialized with a start mark such as
        # <s> and must be included in generation.

        trg_embedding = paddle.layer.GeneratedInputV2(
            size=target_dict_dim,
            embedding_name='_target_language_embedding',
            embedding_size=word_vector_dim)
        group_inputs.append(trg_embedding)

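        # In the wmt14 word dictionaries the start mark <s> is expected
        # to have id 0 and the end mark <e> id 1, hence bos_id=0 and
        # eos_id=1 below.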
        beam_gen = paddle.layer.beam_search(
            name=decoder_group_name,
            step=gru_decoder_with_attention,
            input=group_inputs,
            bos_id=0,
            eos_id=1,
            beam_size=beam_size,
            max_length=max_length)

        return beam_gen


def main():
    paddle.init(
        use_gpu=False,
        trainer_count=1,
        # log gradient clipping info
        log_clipping=True,
        # log error clipping info
        log_error_clipping=True)
    is_generating = False

    # source and target dictionary sizes
    dict_size = 30000
    source_dict_dim = target_dict_dim = dict_size
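    # note: the source and target vocabularies are two separate
    # dictionaries that happen to share the same size here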

    # train the network
    if not is_generating:
        cost = seqToseq_net(source_dict_dim, target_dict_dim)
        parameters = paddle.parameters.create(cost)

        # define the optimization method and the trainer
        optimizer = paddle.optimizer.Adam(
            learning_rate=5e-5,
            # uncomment to enable global threshold for gradient clipping
            # gradient_clipping_threshold=10.0,
            regularization=paddle.optimizer.L2Regularization(rate=8e-4))
        trainer = paddle.trainer.SGD(cost=cost,
                                     parameters=parameters,
                                     update_equation=optimizer)
        # define data reader
        wmt14_reader = paddle.batch(
            paddle.reader.shuffle(
                paddle.dataset.wmt14.train(dict_size), buf_size=8192),
            batch_size=5)
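        # the shuffle wrapper above buffers 8192 samples and yields them
        # in random order; paddle.batch then groups every 5 samples into
        # one mini-batch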

        # define event_handler callback
        def event_handler(event):
            if isinstance(event, paddle.event.EndIteration):
                if event.batch_id % 10 == 0:
                    print "\nPass %d, Batch %d, Cost %f, %s" % (
                        event.pass_id, event.batch_id, event.cost,
                        event.metrics)
                else:
                    sys.stdout.write('.')
                    sys.stdout.flush()

        # start training
        trainer.train(
            reader=wmt14_reader, event_handler=event_handler, num_passes=2)

    # generate translations for a few sample source sequences
    else:
        # use the first 3 samples for generation
        gen_creator = paddle.dataset.wmt14.gen(dict_size)
        gen_data = []
        gen_num = 3
        for item in gen_creator():
            gen_data.append((item[0], ))
            if len(gen_data) == gen_num:
                break

        beam_gen = seqToseq_net(source_dict_dim, target_dict_dim, is_generating)
        # load the pretrained model, whose BLEU = 26.92
        parameters = paddle.dataset.wmt14.model()
        # 'prob' holds the prediction probabilities and 'id' holds the
        # predicted word ids
        beam_result = paddle.infer(
            output_layer=beam_gen,
            parameters=parameters,
            input=gen_data,
            field=['prob', 'id'])

        # get the dictionary
        src_dict, trg_dict = paddle.dataset.wmt14.get_dict(dict_size)

        # generated sequences in the 'id' field are delimited by -1, and
        # the first element of each generated sequence is its length
        seq_list = []
        seq = []
        for w in beam_result[1]:
            if w != -1:
                seq.append(w)
            else:
                seq_list.append(' '.join([trg_dict.get(w) for w in seq[1:]]))
                seq = []

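        # seq_list now holds gen_num * beam_size translations in
        # sample-major order: hypothesis j of sample i sits at index
        # i * beam_size + j, which is how the print loop below reads it.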
        prob = beam_result[0]
        beam_size = 3  # must match the beam_size used in seqToseq_net
        for i in xrange(gen_num):
            print "\n*******************************************************\n"
            print "src:", ' '.join(
                [src_dict.get(w) for w in gen_data[i][0]]), "\n"
            for j in xrange(beam_size):
                print "prob = %f:" % (prob[i][j]), seq_list[i * beam_size + j]


if __name__ == '__main__':
    main()