# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from paddle.trainer.PyDataProvider2 import *

# Index substituted for any word that is missing from a dictionary.
UNK_IDX = 2
# Marker tokens looked up in the dictionaries; their ids are placed at the
# beginning (START) and end (END) of the id sequences built below.
START = "<s>"
END = "<e>"


L
Luo Tao 已提交
22 23
def hook(settings, src_dict_path, trg_dict_path, is_generating, file_list,
         **kwargs):
    """Initialise the provider settings: load dictionaries and declare slots.

    Args:
        settings: Provider settings object. ``job_mode``, ``src_dict``,
            ``trg_dict`` and ``slots`` are assigned on it, and
            ``settings.logger`` is used for logging.
        src_dict_path: Path of the source dictionary file; every stripped
            line becomes a key mapped to its zero-based line number.
        trg_dict_path: Path of the target dictionary file, same format.
        is_generating: Truthy selects generating mode, falsy training mode.
        file_list: Sequence of file paths. Only ``file_list[0]`` is read
            here, and only in generating mode, to count its lines.
        **kwargs: Accepted and ignored.
    """
    # job_mode = 1: training mode
    # job_mode = 0: generating mode
    settings.job_mode = not is_generating

    def load_dict(dict_path):
        # Map each stripped line of the file to its zero-based line number.
        with open(dict_path, "r") as fin:
            return {line.strip(): index for index, line in enumerate(fin)}

    settings.src_dict = load_dict(src_dict_path)
    settings.trg_dict = load_dict(trg_dict_path)

    settings.logger.info("src dict len : %d" % (len(settings.src_dict)))

    if settings.job_mode:
        settings.slots = {
            'source_language_word':
            integer_value_sequence(len(settings.src_dict)),
            'target_language_word':
            integer_value_sequence(len(settings.trg_dict)),
            'target_language_next_word':
            integer_value_sequence(len(settings.trg_dict))
        }
        settings.logger.info("trg dict len : %d" % (len(settings.trg_dict)))
    else:
        # Count the lines of the first input file without leaking the file
        # handle and without loading the whole file into memory.
        with open(file_list[0], "r") as fin:
            sentence_count = sum(1 for _ in fin)
        settings.slots = {
            'source_language_word':
            integer_value_sequence(len(settings.src_dict)),
            'sent_id':
            integer_value_sequence(sentence_count)
        }
Z
zhangjinchao01 已提交
59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80


def _get_ids(s, dictionary):
    """Convert a whitespace-separated sentence into a list of word ids.

    The result starts with the id of START and ends with the id of END;
    words absent from ``dictionary`` are replaced by UNK_IDX.
    """
    tokens = s.strip().split()
    ids = [dictionary[START]]
    ids.extend(dictionary.get(token, UNK_IDX) for token in tokens)
    ids.append(dictionary[END])
    return ids


@provider(init_hook=hook, pool_size=50000)
def process(settings, file_name):
    """Yield one sample dict per usable line of ``file_name``.

    Lines are split on tabs. In training mode (``settings.job_mode`` truthy)
    a line must have exactly two fields, source and target, and the sample
    carries the source ids, the target ids prefixed with START and the
    target ids suffixed with END. In generating mode only the first field
    is used and the sample carries the source ids plus the line number.
    """
    with open(file_name, 'r') as corpus:
        for sent_index, raw_line in enumerate(corpus):
            fields = raw_line.strip().split('\t')

            if not settings.job_mode:
                # Generating mode: source ids and the sentence's line number.
                source_ids = _get_ids(fields[0], settings.src_dict)
                yield {
                    'source_language_word': source_ids,
                    'sent_id': [sent_index]
                }
                continue

            # Training mode: skip lines that are not a source/target pair.
            if len(fields) != 2:
                continue

            source_ids = _get_ids(fields[0], settings.src_dict)
            target_ids = [
                settings.trg_dict.get(token, UNK_IDX)
                for token in fields[1].split()
            ]

            # remove sequence whose length > 80 in training mode
            if len(source_ids) > 80 or len(target_ids) > 80:
                continue

            next_word_ids = target_ids + [settings.trg_dict[END]]
            shifted_target_ids = [settings.trg_dict[START]] + target_ids
            yield {
                'source_language_word': source_ids,
                'target_language_word': shifted_target_ids,
                'target_language_next_word': next_word_ids
            }