You need to sign in or sign up before continuing.
rnn_data_provider.py 3.2 KB
Newer Older
Y
ying 已提交
1
#  Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
2
#
Y
ying 已提交
3 4 5
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
6
#
Y
ying 已提交
7
#    http://www.apache.org/licenses/LICENSE-2.0
8
#
Y
ying 已提交
9 10 11 12 13
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.
14 15
from paddle.trainer.PyDataProvider2 import *

16 17 18
# Note that each config should have an independent provider
# in the current design of PyDataProvider2.
#######################################################
19 20 21 22 23
# Samples iterated by process_subseq, process_seq, process_subseq2 and
# process_seq2. Every entry is a two-element list: a list of integer
# sub-sequences followed by a single integer.
data = [
    [
        [[1, 3, 2], [4, 5, 2]],
        0,
    ],
    [
        [[0, 2], [2, 5], [0, 1, 2]],
        1,
    ],
]

24

25
# Used for sequence_nest_rnn.conf
26 27 28
@provider(
    input_types=[integer_value_sub_sequence(10), integer_value(3)],
    should_shuffle=False)
def process_subseq(settings, file_name):
    """Yield every entry of the module-level ``data`` list unchanged.

    Each yielded item is the original two-element list: a list of integer
    sub-sequences followed by one integer, in the same order as the two
    ``input_types`` declared on the decorator.

    ``settings`` and ``file_name`` are accepted but not read; the samples
    are hard-coded in ``data``.
    """
    for sample in data:
        yield sample

33

34
# Used for sequence_rnn.conf
35 36 37
@provider(
    input_types=[integer_value_sequence(10), integer_value(3)],
    should_shuffle=False)
def process_seq(settings, file_name):
    """Yield the entries of ``data`` with their sub-sequences flattened.

    For each entry, the sub-sequences stored in its first element are
    concatenated, in order, into one flat list of integers. The result is
    yielded as a ``(flat_list, second_element)`` tuple.

    ``settings`` and ``file_name`` are accepted but not read; the samples
    are hard-coded in ``data``.
    """
    for sample in data:
        # Concatenate the nested lists into a fresh list so the shared
        # ``data`` entries are never modified.
        flat_ids = [item for sub_ids in sample[0] for item in sub_ids]
        yield flat_ids, sample[1]
44

45

46
# Used for sequence_nest_rnn_multi_input.conf
47 48 49
@provider(
    input_types=[integer_value_sub_sequence(10), integer_value(3)],
    should_shuffle=False)
def process_subseq2(settings, file_name):
    """Yield every entry of the module-level ``data`` list unchanged.

    Same body as ``process_subseq``; it is registered as a separate
    provider. Each yielded item is the original two-element list: a list
    of integer sub-sequences followed by one integer.

    ``settings`` and ``file_name`` are accepted but not read; the samples
    are hard-coded in ``data``.
    """
    for sample in data:
        yield sample

54

55
# Used for sequence_rnn_multi_input.conf
56 57 58
@provider(
    input_types=[integer_value_sequence(10), integer_value(3)],
    should_shuffle=False)
def process_seq2(settings, file_name):
    """Yield the entries of ``data`` with their sub-sequences flattened.

    Same body as ``process_seq``; it is registered as a separate provider.
    For each entry, the sub-sequences in its first element are
    concatenated, in order, into one flat list, and the result is yielded
    as a ``(flat_list, second_element)`` tuple.

    ``settings`` and ``file_name`` are accepted but not read; the samples
    are hard-coded in ``data``.
    """
    for sample in data:
        # Concatenate the nested lists into a fresh list so the shared
        # ``data`` entries are never modified.
        flat_ids = [item for sub_ids in sample[0] for item in sub_ids]
        yield flat_ids, sample[1]

66

67
###########################################################
68
# Samples iterated by process_unequalength_subseq and
# process_unequalength_seq. Every entry is a three-element list: two lists
# of integer sub-sequences followed by a single integer. Within an entry
# both nested lists hold the same number of sub-sequences, but the
# sub-sequences themselves differ in length.
data2 = [
    [[[1, 2], [4, 5, 2]], [[5, 4, 1], [3, 1]], 0],
    [[[0, 2], [2, 5], [0, 1, 2]], [[1, 5], [4], [2, 3, 6, 1]], 1],
]

73

74
# Used for sequence_nest_rnn_multi_unequalength_inputs.conf
75 76 77 78 79 80
@provider(
    input_types=[
        integer_value_sub_sequence(10), integer_value_sub_sequence(10),
        integer_value(2)
    ],
    should_shuffle=False)
def process_unequalength_subseq(settings, file_name):
    """Yield every entry of the module-level ``data2`` list unchanged.

    Each yielded item is the original three-element list: two lists of
    integer sub-sequences followed by one integer, in the same order as
    the three ``input_types`` declared on the decorator.

    ``settings`` and ``file_name`` are accepted but not read; the samples
    are hard-coded in ``data2``.
    """
    for sample in data2:
        yield sample


86
# Used for sequence_rnn_multi_unequalength_inputs.conf
87 88 89 90 91
@provider(
    input_types=[
        integer_value_sequence(10), integer_value_sequence(10), integer_value(2)
    ],
    should_shuffle=False)
def process_unequalength_seq(settings, file_name):
    """Yield the entries of ``data2`` with both nested inputs flattened.

    For each entry, the sub-sequences of the first and of the second
    element are each concatenated, in order, into a flat list of integers.
    The result is yielded as ``(words1, words2, third_element)``.

    ``settings`` and ``file_name`` are accepted but not read; the samples
    are hard-coded in ``data2``.
    """
    for sample in data2:
        # Flatten with comprehensions rather than ``reduce``: ``reduce`` is
        # a builtin only on Python 2 and raises TypeError on an empty list
        # when no initializer is given. This also matches the flattening
        # done by process_seq and always builds a fresh list.
        words1 = [item for sub_ids in sample[0] for item in sub_ids]
        words2 = [item for sub_ids in sample[1] for item in sub_ids]
        yield words1, words2, sample[2]
97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115


###########################################################
# Samples iterated by process_mixed. Every entry is a three-element list:
# a list of integer sub-sequences, a flat list of integers, and a single
# integer. Within an entry the flat list has one item per sub-sequence.
data3 = [
    [
        [[1, 2], [4, 5, 2]],
        [1, 2],
        0,
    ],
    [
        [[0, 2], [2, 5], [0, 1, 2]],
        [2, 3, 0],
        1,
    ],
]


# Used for sequence_nest_mixed_inputs.conf
@provider(
    input_types=[
        integer_value_sub_sequence(10), integer_value_sequence(10),
        integer_value(2)
    ],
    should_shuffle=False)
def process_mixed(settings, file_name):
    """Yield every entry of the module-level ``data3`` list unchanged.

    Each yielded item is the original three-element list: a list of
    integer sub-sequences, a flat list of integers, and one integer, in
    the same order as the three ``input_types`` declared on the decorator.

    ``settings`` and ``file_name`` are accepted but not read; the samples
    are hard-coded in ``data3``.
    """
    for sample in data3:
        yield sample