# edit-mode: -*- python -*-

# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import sys
import os
from paddle.trainer_config_helpers import *


def seq_to_seq_data(data_dir,
                    is_generating,
                    dict_size=30000,
                    train_list='train.list',
                    test_list='test.list',
                    gen_list='gen.list',
                    gen_result='gen_result'):
    """
    Predefined seqToseq train data provider for application

    data_dir: directory containing src.dict, trg.dict and the data list files
    is_generating: whether this config is used for generating
    dict_size: word count of dictionary (not used by this function)
    train_list: a text file containing a list of training data
    test_list: a text file containing a list of testing data
    gen_list: a text file containing a list of generating data
    gen_result: a text file containing generating result

    Registers the python data provider ("dataprovider" module, "process"
    object) via define_py_data_sources2 and returns a dict with the keys
    "src_dict_path", "trg_dict_path" and "gen_result" -- the keys that
    gru_encoder_decoder reads from its data_conf argument.
    Note: the returned "gen_result" is the argument as given; unlike the
    list files and dictionaries it is NOT joined with data_dir.
    """
    src_lang_dict = os.path.join(data_dir, 'src.dict')
    trg_lang_dict = os.path.join(data_dir, 'trg.dict')

    if is_generating:
        # Generation has no training data; the generation list takes the
        # place of the test list.
        train_list = None
        test_list = os.path.join(data_dir, gen_list)
    else:
        train_list = os.path.join(data_dir, train_list)
        test_list = os.path.join(data_dir, test_list)

    define_py_data_sources2(
        train_list,
        test_list,
        module="dataprovider",
        obj="process",
        args={
            "src_dict_path": src_lang_dict,
            "trg_dict_path": trg_lang_dict,
            "is_generating": is_generating
        })

    return {
        "src_dict_path": src_lang_dict,
        "trg_dict_path": trg_lang_dict,
        "gen_result": gen_result
    }
Z
zhangjinchao01 已提交
64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83


def gru_encoder_decoder(data_conf,
                        is_generating,
                        word_vector_dim=512,
                        encoder_size=512,
                        decoder_size=512,
                        beam_size=3,
                        max_length=250):
    """
    A wrapper for an attention version of GRU Encoder-Decoder network

    data_conf: dict providing at least the keys "src_dict_path",
               "trg_dict_path" and "gen_result" (the dict returned by
               seq_to_seq_data has exactly these keys)
    is_generating: whether this config is used for generating
    word_vector_dim: dimension of word vector
    encoder_size: dimension of hidden unit in GRU Encoder network
    decoder_size: dimension of hidden unit in GRU Decoder network
    beam_size: expand width in beam search
    max_length: a stop condition of sequence generation

    When training, the network output is the classification cost of the
    decoder against 'target_language_next_word'. When generating, the output
    is the beam search result, which is also printed to data_conf["gen_result"]
    through seqtext_printer_evaluator.
    """
    # Every data_conf entry is still exported as a module-level global, as the
    # original implementation did, in case other code reads these names from
    # this module. globals().update replaces the Python-2-only iteritems loop.
    globals().update(data_conf)
    src_dict_path = data_conf["src_dict_path"]
    trg_dict_path = data_conf["trg_dict_path"]
    gen_trans_file = data_conf["gen_result"]

    # The dictionary size is the number of lines in each dictionary file.
    # Counting with a `with` block closes the files, which the previous
    # open(...).readlines() calls never did.
    with open(src_dict_path, "r") as src_dict_file:
        source_dict_dim = sum(1 for _ in src_dict_file)
    with open(trg_dict_path, "r") as trg_dict_file:
        target_dict_dim = sum(1 for _ in trg_dict_file)

    # Encoder: bidirectional GRU over the source word embeddings.
    src_word_id = data_layer(name='source_language_word', size=source_dict_dim)
    src_embedding = embedding_layer(
        input=src_word_id,
        size=word_vector_dim,
        param_attr=ParamAttr(name='_source_language_embedding'))
    src_forward = simple_gru(input=src_embedding, size=encoder_size)
    src_backward = simple_gru(
        input=src_embedding, size=encoder_size, reverse=True)
    encoded_vector = concat_layer(input=[src_forward, src_backward])

    with mixed_layer(size=decoder_size) as encoded_proj:
        encoded_proj += full_matrix_projection(input=encoded_vector)

    # The decoder's initial state is derived from the first step of the
    # backward encoder.
    backward_first = first_seq(input=src_backward)
    with mixed_layer(
            size=decoder_size,
            act=TanhActivation(), ) as decoder_boot:
        decoder_boot += full_matrix_projection(input=backward_first)

    def gru_decoder_with_attention(enc_vec, enc_proj, current_word):
        """One decoder step: attention context + GRU update + softmax."""
        # The memory is linked to the gru_step_layer below through the shared
        # name 'gru_decoder'.
        decoder_mem = memory(
            name='gru_decoder', size=decoder_size, boot_layer=decoder_boot)

        context = simple_attention(
            encoded_sequence=enc_vec,
            encoded_proj=enc_proj,
            decoder_state=decoder_mem, )

        # decoder_size * 3 is the input width passed to gru_step_layer.
        with mixed_layer(size=decoder_size * 3) as decoder_inputs:
            decoder_inputs += full_matrix_projection(input=context)
            decoder_inputs += full_matrix_projection(input=current_word)

        gru_step = gru_step_layer(
            name='gru_decoder',
            input=decoder_inputs,
            output_mem=decoder_mem,
            size=decoder_size)

        with mixed_layer(
                size=target_dict_dim, bias_attr=True,
                act=SoftmaxActivation()) as out:
            out += full_matrix_projection(input=gru_step)
        return out

    decoder_group_name = "decoder_group"
    group_inputs = [
        StaticInput(input=encoded_vector, is_seq=True),
        StaticInput(input=encoded_proj, is_seq=True)
    ]

    if not is_generating:
        trg_embedding = embedding_layer(
            input=data_layer(
                name='target_language_word', size=target_dict_dim),
            size=word_vector_dim,
            param_attr=ParamAttr(name='_target_language_embedding'))
        group_inputs.append(trg_embedding)

        # For decoder equipped with attention mechanism, in training,
        # target embedding (the ground truth) is the data input,
        # while encoded source sequence is accessed as an unbounded memory.
        # Here, the StaticInput defines a read-only memory
        # for the recurrent_group.
        decoder = recurrent_group(
            name=decoder_group_name,
            step=gru_decoder_with_attention,
            input=group_inputs)

        lbl = data_layer(name='target_language_next_word', size=target_dict_dim)
        cost = classification_cost(input=decoder, label=lbl)
        outputs(cost)
    else:
        # In generation, the decoder predicts a next target word based on
        # the encoded source sequence and the last generated target word.

        # The encoded source sequence (encoder's output) must be specified by
        # StaticInput, which is a read-only memory.
        # Embedding of the last generated word is automatically obtained by
        # GeneratedInput, which is initialized by a start mark, such as <s>,
        # and must be included in generation.
        # The embedding name matches the training-time
        # '_target_language_embedding' parameter so that it is reused.
        trg_embedding = GeneratedInput(
            size=target_dict_dim,
            embedding_name='_target_language_embedding',
            embedding_size=word_vector_dim)
        group_inputs.append(trg_embedding)

        # bos_id / eos_id are the ids of the start and end marks; they must
        # agree with the dictionary / data provider (not visible here).
        beam_gen = beam_search(
            name=decoder_group_name,
            step=gru_decoder_with_attention,
            input=group_inputs,
            bos_id=0,
            eos_id=1,
            beam_size=beam_size,
            max_length=max_length)

        seqtext_printer_evaluator(
            input=beam_gen,
            id_input=data_layer(
                name="sent_id", size=1),
            dict_file=trg_dict_path,
            result_file=gen_trans_file)
        outputs(beam_gen)