sequenceGen.py 2.3 KB
Newer Older
1
#  Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Z
zhangjinchao01 已提交
2
#
Y
ying 已提交
3 4 5
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
Z
zhangjinchao01 已提交
6
#
Y
ying 已提交
7
#    http://www.apache.org/licenses/LICENSE-2.0
Z
zhangjinchao01 已提交
8
#
Y
ying 已提交
9 10 11 12 13
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.
Z
zhangjinchao01 已提交
14 15 16
import os
import sys

17
from paddle.trainer.PyDataProvider2 import *
Z
zhangjinchao01 已提交
18

19

20 21
def hook(settings, dict_file, **kwargs):
    """Initialise provider settings for the flat sequence data provider.

    Stores ``dict_file`` as ``settings.word_dict`` and declares two input
    slots on ``settings.input_types``: ``integer_value_sequence`` sized by
    the dictionary length, followed by ``integer_value(3)``.

    Args:
        settings: Provider settings object. ``word_dict`` and
            ``input_types`` are assigned on it and its ``logger`` is used.
        dict_file: Object stored as the word dictionary. Despite the name
            it is used as a container here: it must support ``len()``.
        **kwargs: Accepted and ignored.
    """
    settings.word_dict = dict_file
    dict_size = len(settings.word_dict)

    word_slot = integer_value_sequence(dict_size)
    label_slot = integer_value(3)
    settings.input_types = [word_slot, label_slot]

    settings.logger.info('dict len : %d' % (len(settings.word_dict)))
Z
zhangjinchao01 已提交
26

27 28

@provider(init_hook=hook, should_shuffle=False)
def process(settings, file_name):
    """Yield one ``(word_ids, label)`` sample per line of ``file_name``.

    Every line must consist of exactly two tab-separated fields: a label
    and a whitespace-separated comment. Whitespace inside the label field
    is removed before it is converted with ``int``. Words that are not in
    ``settings.word_dict`` are dropped; the remaining words are replaced
    by their ``settings.word_dict`` values.

    Args:
        settings: Provider settings object carrying ``word_dict``.
        file_name: Path of the text file to read.

    Yields:
        A ``(word_ids, label)`` pair for each line.

    Raises:
        ValueError: If a line does not split into exactly two fields on
            tabs, or the label field is not an integer.
    """
    with open(file_name, 'r') as fdata:
        for raw_line in fdata:
            fields = raw_line.strip().split('\t')
            label_field, comment = fields
            label = int(''.join(label_field.split()))

            word_ids = []
            for token in comment.split():
                # Out-of-vocabulary tokens are skipped rather than mapped.
                if token in settings.word_dict:
                    word_ids.append(settings.word_dict[token])

            yield word_ids, label
Z
zhangjinchao01 已提交
39

40

Z
zhangjinchao01 已提交
41
## for hierarchical sequence network
42 43
def hook2(settings, dict_file, **kwargs):
    """Initialise provider settings for the hierarchical sequence provider.

    Stores ``dict_file`` as ``settings.word_dict`` and declares two input
    slots on ``settings.input_types``: ``integer_value_sub_sequence`` sized
    by the dictionary length, followed by ``integer_value_sequence(3)``.

    Args:
        settings: Provider settings object. ``word_dict`` and
            ``input_types`` are assigned on it and its ``logger`` is used.
        dict_file: Object stored as the word dictionary. Despite the name
            it is used as a container here: it must support ``len()``.
        **kwargs: Accepted and ignored.
    """
    settings.word_dict = dict_file
    dict_size = len(settings.word_dict)

    sentences_slot = integer_value_sub_sequence(dict_size)
    labels_slot = integer_value_sequence(3)
    settings.input_types = [sentences_slot, labels_slot]

    settings.logger.info('dict len : %d' % (len(settings.word_dict)))

50 51

@provider(init_hook=hook2, should_shuffle=False)
def process2(settings, file_name):
    """Yield grouped ``(sentences, labels)`` samples read from ``file_name``.

    Consecutive non-blank lines form one group. Every such line must
    consist of exactly two tab-separated fields: a label and a
    whitespace-separated comment. Whitespace inside the label field is
    removed before it is converted with ``int``; words that are not in
    ``settings.word_dict`` are dropped and the rest are replaced by their
    ``settings.word_dict`` values.

    A line of length <= 1 (an empty line) ends the current group and
    yields it. A final group that is not followed by a blank line is
    yielded as well once the file is exhausted, so it is no longer lost.

    Args:
        settings: Provider settings object carrying ``word_dict``.
        file_name: Path of the text file to read.

    Yields:
        ``(sentences, labels)`` where ``sentences`` is a list of word-id
        lists and ``labels`` is the parallel list of integer labels.

    Raises:
        ValueError: If a non-blank line does not split into exactly two
            fields on tabs, or the label field is not an integer.
    """
    with open(file_name) as fdata:
        labels = []
        sentences = []
        for line in fdata:
            if len(line) > 1:
                label, comment = line.strip().split('\t')
                label = int(''.join(label.split()))
                words = [
                    settings.word_dict[w] for w in comment.split()
                    if w in settings.word_dict
                ]
                labels.append(label)
                sentences.append(words)
            else:
                # Blank separator line: emit the group collected so far.
                yield sentences, labels
                labels = []
                sentences = []
        # Flush the trailing group when the file does not end with a blank
        # separator line. Nothing is emitted if the last group was already
        # yielded, so files that end with a blank line behave as before.
        if labels:
            yield sentences, labels