#  Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import numpy
import struct
import traceback

def header_creator():
    """Build the packed header describing the three test slots.

    Layout: the slot count, the sequence flag, then a (type, dim) pair for
    each slot. Every field is a native C int (struct format 'i').

    Returns:
        The packed header as a byte string (``str`` on Python 2, ``bytes``
        on Python 3).
    """
    fields = [
        3,  # slot num
        1,  # sequence flag
        0,  # slot0 dense type
        3,  # slot0 dim
        1,  # slot1 sparse non value type
        7,  # slot1 dim
        3,  # slot2 index type
        2,  # slot2 dim
    ]
    # A single pack of N native ints is byte-identical to N separate
    # struct.pack('i', ...) calls, and it never mixes str with bytes.
    return struct.pack('%di' % len(fields), *fields)

def dense_value_creator(sample_num):
    """Pack the dense slot (slot0) data for ``sample_num`` samples.

    Layout: the sample count as a native C int, then for each sample the
    same 3 floats (1.0, 2.0, 3.0). Three values per sample matches the
    slot0 dim written by header_creator.

    Args:
        sample_num: number of samples to emit.

    Returns:
        The packed slot data as a byte string.
    """
    parts = [struct.pack('i', sample_num)]  # slot0 sample num
    # slot0 value: every sample carries an identical dense vector.
    parts.extend(struct.pack('3f', 1.0, 2.0, 3.0) for _ in range(sample_num))
    return b"".join(parts)

def sparse_value_creator(sample_num):
    """Pack the sparse non-value slot (slot1) data for ``sample_num`` samples.

    Layout (all native C ints):
      * the sample count;
      * one index per sample, equal to ``i * 2``. Each sample contributes
        exactly 2 entries below, so this reads as that sample's start
        offset;
      * the total entry count, ``sample_num * 2``;
      * for each sample, the two entries 1 and 2.

    Args:
        sample_num: number of samples to emit.

    Returns:
        The packed slot data as a byte string.
    """
    parts = [struct.pack('i', sample_num)]  # slot1 sample num
    # slot1 index
    parts.extend(struct.pack('i', i * 2) for i in range(sample_num))
    parts.append(struct.pack('i', sample_num * 2))  # slot1 length
    # slot1 value: the same two entries for every sample
    parts.extend(struct.pack('2i', 1, 2) for _ in range(sample_num))
    return b"".join(parts)

def index_value_creator(sample_num):
    """Pack the index slot (slot2) data for ``sample_num`` samples.

    Layout: the sample count, then the index 0 for every sample. All fields
    are native C ints.

    Args:
        sample_num: number of samples to emit. A value <= 0 emits only the
            count, matching the original ``range`` loop.

    Returns:
        The packed slot data as a byte string.
    """
    # slot2 sample num, then slot2 value (always 0) repeated per sample.
    return struct.pack('i', sample_num) + struct.pack('i', 0) * max(sample_num, 0)

def sequenceStartPositions_creator():
    """Return the packed sequence start positions for the three slots.

    For each slot in turn: the number of positions, then the positions
    themselves. Every field is a native C int (struct format 'i').

    Returns:
        The packed positions as a byte string.
    """
    fields = [
        2,  # slot0 sequence num
        0,  # slot0 sequence value1
        1,  # slot0 sequence value2
        1,  # slot1 sequence num
        0,  # slot1 sequence value1
        2,  # slot2 sequence num
        0,  # slot2 sequence value1
        1,  # slot2 sequence value2
    ]
    return struct.pack('%di' % len(fields), *fields)

def subSequenceStartPositions_creator():
    """Return the packed sub-sequence start positions for the three slots.

    For each slot in turn: the number of positions, then the positions
    themselves. Every field is a native C int (struct format 'i').

    Returns:
        The packed positions as a byte string.
    """
    fields = [
        3,  # slot0 subsequence num
        0,  # slot0 subsequence value1
        1,  # slot0 subsequence value2
        2,  # slot0 subsequence value3
        2,  # slot1 subsequence num
        0,  # slot1 subsequence value1
        1,  # slot1 subsequence value2
        3,  # slot2 subsequence num
        0,  # slot2 subsequence value1
        1,  # slot2 subsequence value2
        2,  # slot2 subsequence value3
    ]
    return struct.pack('%di' % len(fields), *fields)

class SimpleDataProvider:
    """Test provider that emits a fixed, non-nested 2-sample batch.

    The constructor arguments are only stored in ``file_list``; nothing in
    this class reads them.
    """

    def __init__(self, *file_list):
        self.file_list = file_list

    def shuffle(self):
        # No-op: the batch is a hard-coded fixture, so there is nothing to shuffle.
        pass

    def reset(self):
        # No-op: the provider holds no iteration state.
        pass

    def getHeader(self):
        """Return the packed slot header produced by header_creator()."""
        return header_creator()

    def getNextBatch(self, batch_size):
        """Return one packed batch holding 2 samples per slot.

        ``batch_size`` is ignored: the batch is always the same 2-sample
        fixture, followed by the sequence start positions.
        """
        # b"".join keeps the result bytes on Python 3 (it is str on Python 2).
        return b"".join([
            struct.pack('i', 2),  # batch size
            dense_value_creator(2),  # slot0
            sparse_value_creator(2),  # slot1
            index_value_creator(2),  # slot2
            sequenceStartPositions_creator(),
        ])

class SimpleNestDataProvider:
    """Test provider that emits a fixed nested-sequence batch.

    Each slot carries 4 samples. The batch ends with both sequence and
    sub-sequence start positions. The constructor arguments are only stored
    in ``file_list``; nothing in this class reads them.
    """

    def __init__(self, *file_list):
        self.file_list = file_list

    def shuffle(self):
        # No-op: the batch is a hard-coded fixture, so there is nothing to shuffle.
        pass

    def reset(self):
        # No-op: the provider holds no iteration state.
        pass

    def getHeader(self):
        """Return the packed slot header produced by header_creator()."""
        return header_creator()

    def getNextBatch(self, batch_size):
        """Return one packed nested batch.

        ``batch_size`` is ignored: the written batch size is always 2 and
        each slot always carries 4 samples.
        """
        # b"".join keeps the result bytes on Python 3 (it is str on Python 2).
        return b"".join([
            struct.pack('i', 2),  # batch size
            dense_value_creator(4),  # slot0
            sparse_value_creator(4),  # slot1
            index_value_creator(4),  # slot2
            sequenceStartPositions_creator(),
            subSequenceStartPositions_creator(),
        ])

if __name__ == "__main__":
    # Smoke test: print the byte lengths of the header and of one batch from
    # each provider. A single-argument print(...) gives the same output on
    # Python 2 and Python 3, whereas the print statement is a SyntaxError on 3.
    data_provider = SimpleDataProvider('./test_batch')
    print(len(data_provider.getHeader()))
    print(len(data_provider.getNextBatch(2)))

    data_provider = SimpleNestDataProvider('./test_batch')
    print(len(data_provider.getHeader()))
    print(len(data_provider.getNextBatch(2)))