# dataprovider.py -- recovered from a repository blame view
# (lines committed by zhangjinchao01 and Luo Tao).
# Copyright (c) 2016 Baidu, Inc. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from paddle.trainer.PyDataProvider2 import *

# Id used for any word that is missing from a dictionary.
# NOTE(review): assumes id 2 (the third dictionary line) is the
# unknown-word token in both vocabulary files -- confirm against the data.
UNK_IDX = 2
# Sentence boundary markers. They are looked up with [] (no fallback), so
# they must be present in the dictionaries.
START, END = "<s>", "<e>"


def _load_dict(dict_path):
    """Map each whitespace-stripped line of ``dict_path`` to its 0-based
    line number (a later duplicate line overrides an earlier one)."""
    with open(dict_path, "r") as fin:
        return {
            line.strip(): line_count
            for line_count, line in enumerate(fin)
        }


def _count_lines(path):
    """Count the lines of ``path`` by streaming it, closing the file after."""
    with open(path, "r") as fin:
        return sum(1 for _ in fin)


def hook(settings, src_dict_path, trg_dict_path, is_generating, file_list,
         **kwargs):
    """Initialize the provider ``settings`` before any data is read.

    Loads both vocabularies, records the job mode, and declares the slots
    that ``process`` yields.

    :param settings: provider settings object. This hook sets ``job_mode``,
        ``src_dict``, ``trg_dict``, ``sample_count`` and ``slots`` on it and
        logs through ``settings.logger``.
    :param src_dict_path: source vocabulary file. Each line is one token,
        and a token's id is its 0-based line number.
    :param trg_dict_path: target vocabulary file, in the same format.
    :param is_generating: truthy for generating mode, falsy for training.
    :param file_list: data files. In generating mode, the line count of
        ``file_list[0]`` sizes the ``sent_id`` slot.
    :param kwargs: unused here.
    """
    # job_mode True: training mode
    # job_mode False: generating mode
    settings.job_mode = not is_generating
    settings.src_dict = _load_dict(src_dict_path)
    settings.trg_dict = _load_dict(trg_dict_path)

    settings.logger.info("src dict len : %d" % (len(settings.src_dict)))
    settings.sample_count = 0

    if settings.job_mode:
        settings.slots = {
            'source_language_word':
            integer_value_sequence(len(settings.src_dict)),
            'target_language_word':
            integer_value_sequence(len(settings.trg_dict)),
            'target_language_next_word':
            integer_value_sequence(len(settings.trg_dict))
        }
        settings.logger.info("trg dict len : %d" % (len(settings.trg_dict)))
    else:
        # NOTE(review): sent_id is sized from file_list[0] only, while
        # process() emits per-file line indices. Confirm that generation
        # always runs over a single file.
        settings.slots = {
            'source_language_word':
            integer_value_sequence(len(settings.src_dict)),
            'sent_id':
            integer_value_sequence(_count_lines(file_list[0]))
        }


def _get_ids(s, dictionary):
    """Return the id list for sentence ``s`` under ``dictionary``.

    ``s`` is stripped and split on whitespace. Each token maps to its id,
    or to ``UNK_IDX`` when it is absent. The list is framed by the ids of
    ``START`` and ``END``, which are looked up without a fallback (a
    ``KeyError`` is raised if either is missing).
    """
    tokens = s.strip().split()
    ids = [dictionary[START]]
    ids.extend(dictionary.get(tok, UNK_IDX) for tok in tokens)
    ids.append(dictionary[END])
    return ids


@provider(init_hook=hook, pool_size=50000)
def process(settings, file_name):
    """Yield one sample dict per usable line of ``file_name``.

    Every line is stripped and split on tab characters.

    Training mode (``settings.job_mode`` truthy):
      - A line must have exactly two fields (source and target sentence);
        otherwise it is skipped.
      - A pair is also skipped when the source id list or the bare target
        id list is longer than 80. The source list already contains the
        <s>/<e> ids; the bare target list does not.
      - Each sample holds the source ids, the target ids prefixed with the
        <s> id, and the target ids suffixed with the <e> id.

    Generating mode: only the first field is used. Each sample holds the
    source ids and ``sent_id``, a one-element list with the line's 0-based
    index within this file.
    """
    with open(file_name, 'r') as stream:
        for idx, raw in enumerate(stream):
            fields = raw.strip().split('\t')
            if settings.job_mode and len(fields) != 2:
                continue
            src_ids = _get_ids(fields[0], settings.src_dict)

            if not settings.job_mode:
                yield {'source_language_word': src_ids, 'sent_id': [idx]}
                continue

            trg_dict = settings.trg_dict
            # Unlike src_ids, these carry no boundary markers yet.
            bare_ids = [trg_dict.get(tok, UNK_IDX) for tok in fields[1].split()]

            # Training mode drops over-long pairs.
            if len(src_ids) > 80 or len(bare_ids) > 80:
                continue

            next_ids = bare_ids + [trg_dict[END]]
            yield {
                'source_language_word': src_ids,
                'target_language_word': [trg_dict[START]] + bare_ids,
                'target_language_next_word': next_ids,
            }