# edit-mode: -*- python -*-

# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import sys
import os
from paddle.trainer_config_helpers import *


def seq_to_seq_data(data_dir,
                    is_generating,
                    dict_size=30000,
                    train_list='train.list',
                    test_list='test.list',
                    gen_list='gen.list',
                    gen_result='gen_result'):
    """
    Predefined seqToseq train data provider for application
    is_generating: whether this config is used for generating
    dict_size: word count of dictionary
    train_list: a text file containing a list of training data
    test_list: a text file containing a list of testing data
    gen_list: a text file containing a list of generating data
    gen_result: a text file containing generating result
    """
    src_lang_dict = os.path.join(data_dir, 'src.dict')
    trg_lang_dict = os.path.join(data_dir, 'trg.dict')

    if is_generating:
        train_list = None
        test_list = os.path.join(data_dir, gen_list)
    else:
        train_list = os.path.join(data_dir, train_list)
        test_list = os.path.join(data_dir, test_list)

    define_py_data_sources2(
        train_list,
        test_list,
        module="dataprovider",
        obj="process",
        args={
            "src_dict_path": src_lang_dict,
            "trg_dict_path": trg_lang_dict,
            "is_generating": is_generating
        })

    return {
        "src_dict_path": src_lang_dict,
        "trg_dict_path": trg_lang_dict,
        "gen_result": gen_result
    }


def gru_encoder_decoder(data_conf,
                        is_generating,
                        word_vector_dim=512,
                        encoder_size=512,
                        decoder_size=512,
                        beam_size=3,
                        max_length=250,
                        error_clipping=50):
    """
    A wrapper for an attention version of GRU Encoder-Decoder network
    is_generating: whether this config is used for generating
    encoder_size: dimension of hidden unit in GRU Encoder network
    decoder_size: dimension of hidden unit in GRU Decoder network
    word_vector_dim: dimension of word vector
    beam_size: expand width in beam search
    max_length: a stop condition of sequence generation
    """
    for k, v in data_conf.items():
        globals()[k] = v
    with open(src_dict_path, "r") as f:
        source_dict_dim = len(f.readlines())
    with open(trg_dict_path, "r") as f:
        target_dict_dim = len(f.readlines())
    gen_trans_file = gen_result

    src_word_id = data_layer(name='source_language_word', size=source_dict_dim)
    src_embedding = embedding_layer(
        input=src_word_id,
        size=word_vector_dim,
        param_attr=ParamAttr(name='_source_language_embedding'))
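    # Bidirectional encoder: one GRU reads the source embeddings
    # left-to-right, a second reads them right-to-left.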
    src_forward = simple_gru(
        input=src_embedding,
        size=encoder_size,
        naive=True,
        gru_layer_attr=ExtraLayerAttribute(
            error_clipping_threshold=error_clipping))
    src_backward = simple_gru(
        input=src_embedding,
        size=encoder_size,
        reverse=True,
        naive=True,
        gru_layer_attr=ExtraLayerAttribute(
            error_clipping_threshold=error_clipping))
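    # Concatenating the two directions gives a 2 * encoder_size annotation
    # vector for every source position.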
    encoded_vector = concat_layer(input=[src_forward, src_backward])

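    # Pre-project the annotations used by the attention score, so the
    # projection is computed once per source position rather than once per
    # decoder step.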
    with mixed_layer(size=decoder_size) as encoded_proj:
        encoded_proj += full_matrix_projection(input=encoded_vector)

    backward_first = first_seq(input=src_backward)
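    # The first element of the backward encoder has seen the entire source
    # sentence; a tanh transform of it bootstraps the decoder state.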
    with mixed_layer(
            size=decoder_size, act=TanhActivation()) as decoder_boot:
        decoder_boot += full_matrix_projection(input=backward_first)

    def gru_decoder_with_attention(enc_vec, enc_proj, current_word):
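        # memory() reads the previous step's 'gru_decoder' output; the first
        # step is bootstrapped by decoder_boot.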
        decoder_mem = memory(
            name='gru_decoder', size=decoder_size, boot_layer=decoder_boot)

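        # Attention: score every source annotation against the previous
        # decoder state and return their weighted sum as the context vector.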
        context = simple_attention(
            encoded_sequence=enc_vec,
            encoded_proj=enc_proj,
            decoder_state=decoder_mem)

        with mixed_layer(size=decoder_size * 3) as decoder_inputs:
            decoder_inputs += full_matrix_projection(input=context)
            decoder_inputs += full_matrix_projection(input=current_word)

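        # The "naive" GRU step is composed of basic layers, which is what
        # allows the error-clipping attribute below to take effect inside
        # the step.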
        gru_step = gru_step_naive_layer(
            name='gru_decoder',
            input=decoder_inputs,
            output_mem=decoder_mem,
            size=decoder_size,
            layer_attr=ExtraLayerAttribute(
                error_clipping_threshold=error_clipping))

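        # Project the decoder state onto the target vocabulary; softmax turns
        # the scores into next-word probabilities.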
        with mixed_layer(
                size=target_dict_dim, bias_attr=True,
                act=SoftmaxActivation()) as out:
            out += full_matrix_projection(input=gru_step)
        return out

    decoder_group_name = "decoder_group"
    group_inputs = [
        StaticInput(input=encoded_vector, is_seq=True),
        StaticInput(input=encoded_proj, is_seq=True)
    ]

    if not is_generating:
        trg_embedding = embedding_layer(
            input=data_layer(
                name='target_language_word', size=target_dict_dim),
            size=word_vector_dim,
            param_attr=ParamAttr(name='_target_language_embedding'))
        group_inputs.append(trg_embedding)

        # For a decoder equipped with an attention mechanism, the target
        # embedding (the ground truth) is the data input during training,
        # while the encoded source sequence is accessed as an unbounded
        # memory. Here, the StaticInput defines a read-only memory
        # for the recurrent_group.
        decoder = recurrent_group(
            name=decoder_group_name,
            step=gru_decoder_with_attention,
            input=group_inputs)

        lbl = data_layer(name='target_language_next_word', size=target_dict_dim)
        cost = classification_cost(input=decoder, label=lbl)
        outputs(cost)
    else:
        # In generation, the decoder predicts the next target word based on
        # the encoded source sequence and the last generated target word.

        # The encoded source sequence (encoder's output) must be specified by
        # StaticInput, which is a read-only memory.
        # The embedding of the last generated word is automatically fetched
        # by GeneratedInput, which is initialized with a start mark, such as
        # <s>, and must be included in generation.

        trg_embedding = GeneratedInput(
            size=target_dict_dim,
            embedding_name='_target_language_embedding',
            embedding_size=word_vector_dim)
        group_inputs.append(trg_embedding)

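        # Beam search keeps the beam_size best partial translations per step.
        # bos_id=0 / eos_id=1 assume the demo's dictionary convention that
        # <s> and <e> occupy the first two dictionary entries.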
        beam_gen = beam_search(
            name=decoder_group_name,
            step=gru_decoder_with_attention,
            input=group_inputs,
            bos_id=0,
            eos_id=1,
            beam_size=beam_size,
            max_length=max_length)

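        # Map the generated word ids back to text with the target dictionary
        # and write one translation per input sentence into gen_result.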
        seqtext_printer_evaluator(
            input=beam_gen,
            id_input=data_layer(
                name="sent_id", size=1),
            dict_file=trg_dict_path,
            result_file=gen_trans_file)
        outputs(beam_gen)
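

# Illustrative usage (a sketch kept as a comment; the data directory below is
# an assumption, not part of this module). A trainer config built on these
# helpers typically looks like:
#
#     is_generating = get_config_arg("is_generating", bool, False)
#     data = seq_to_seq_data(data_dir='./data/pre-wmt14',
#                            is_generating=is_generating)
#     gru_encoder_decoder(data_conf=data, is_generating=is_generating)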