# pyDataProvider.py
# Copyright (c) 2016 Baidu, Inc. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import numpy
import struct
import traceback

def header_creator():
    """Build the binary header describing the slot layout of every batch.

    The header is a flat run of native ``int`` values: the slot count, a
    sequence flag, then one ``(type, dim)`` pair for each of the three slots.

    Returns:
        bytes: the 8 packed native ints.
    """
    fields = (
        3,     # slot num
        1,     # sequence flag
        0, 3,  # slot0: dense type, dim
        1, 7,  # slot1: sparse non value type, dim
        3, 2,  # slot2: index type, dim
    )
    # Join into a bytes object. The old ``"" + struct.pack(...)`` only worked
    # on Python 2, where str and bytes are the same type.
    return b"".join(struct.pack('i', field) for field in fields)

def dense_value_creator(sample_num):
    """Pack the dense payload for slot0.

    Layout: ``sample_num`` as a native int, followed by ``sample_num`` rows.
    Each row is the three native floats ``(1.0, 2.0, 3.0)``; slot0 is
    declared with dim 3 in the header.

    Args:
        sample_num: number of rows to emit. 0 (or a negative value) emits
            only the count.

    Returns:
        bytes: the packed slot payload.
    """
    # Every row is identical, so pack it once and repeat the bytes.
    row = struct.pack('fff', 1.0, 2.0, 3.0)
    return struct.pack('i', sample_num) + row * sample_num

def sparse_value_creator(sample_num):
    """Pack the sparse (non-value) payload for slot1.

    Layout, all native ints:
      * ``sample_num``;
      * one entry ``i * 2`` per sample ``i`` (0, 2, 4, ...). NOTE(review):
        these are presumably each sample's start offset into the id list
        below; confirm against the reader;
      * the total id count ``sample_num * 2``;
      * the ids ``1, 2``, repeated once per sample.

    Args:
        sample_num: number of samples to emit.

    Returns:
        bytes: the packed slot payload.
    """
    chunks = [struct.pack('i', sample_num)]  # slot1 sample num
    chunks.extend(struct.pack('i', i * 2) for i in range(sample_num))  # slot1 index
    chunks.append(struct.pack('i', sample_num * 2))  # slot1 length
    chunks.append(struct.pack('ii', 1, 2) * sample_num)  # slot1 value
    return b"".join(chunks)

def index_value_creator(sample_num):
    """Pack the index payload for slot2.

    Layout: ``sample_num`` as a native int, then the value ``0`` (native
    int) once per sample.

    Args:
        sample_num: number of samples to emit.

    Returns:
        bytes: the packed slot payload.
    """
    return struct.pack('i', sample_num) + struct.pack('i', 0) * sample_num

def sequenceStartPositions_creator():
    """Pack the fixed sequence start positions for all three slots.

    For each slot in order: a native int count, followed by that many native
    int positions.

    Returns:
        bytes: 8 packed native ints.
    """
    positions = (
        2, 0, 1,  # slot0: sequence num, then values 0 and 1
        1, 0,     # slot1: sequence num, then value 0
        2, 0, 1,  # slot2: sequence num, then values 0 and 1
    )
    return b"".join(struct.pack('i', pos) for pos in positions)

def subSequenceStartPositions_creator():
    """Pack the fixed sub-sequence start positions for all three slots.

    For each slot in order: a native int count, followed by that many native
    int positions.

    Returns:
        bytes: 11 packed native ints.
    """
    positions = (
        3, 0, 1, 2,  # slot0: subsequence num, then values 0, 1, 2
        2, 0, 1,     # slot1: subsequence num, then values 0, 1
        3, 0, 1, 2,  # slot2: subsequence num, then values 0, 1, 2
    )
    return b"".join(struct.pack('i', pos) for pos in positions)

class SimpleDataProvider:
    """Test data provider that always serves the same hard-coded batch.

    The file list passed to the constructor is stored but never read. Every
    batch is synthesized by the module-level ``*_creator`` helpers.
    """

    def __init__(self, *file_list):
        # Kept as the tuple of positional arguments; nothing here opens it.
        self.file_list = file_list

    def shuffle(self):
        """Do nothing: the served data is fixed."""

    def reset(self):
        """Do nothing: the provider keeps no iteration state."""

    def getHeader(self):
        """Return the packed slot-layout header bytes."""
        return header_creator()

    def getNextBatch(self, batch_size):
        """Return one packed batch that always contains 2 samples.

        ``batch_size`` is ignored. The payload is the batch size (2), the
        three slot payloads for 2 samples each, then the sequence start
        positions.
        """
        # Join into bytes. ``"" + ...`` only worked on Python 2.
        return b"".join([
            struct.pack('i', 2),  # batch size
            dense_value_creator(2),  # slot0
            sparse_value_creator(2),  # slot1
            index_value_creator(2),  # slot2
            sequenceStartPositions_creator(),
        ])

class SimpleNestDataProvider:
    """Test data provider for nested sequences with one hard-coded batch.

    The batch is built like ``SimpleDataProvider``'s, with two differences:
    each slot carries 4 samples, and sub-sequence start positions are
    appended. The file list is stored but never read.
    """

    def __init__(self, *file_list):
        # Kept as the tuple of positional arguments; nothing here opens it.
        self.file_list = file_list

    def shuffle(self):
        """Do nothing: the served data is fixed."""

    def reset(self):
        """Do nothing: the provider keeps no iteration state."""

    def getHeader(self):
        """Return the packed slot-layout header bytes."""
        return header_creator()

    def getNextBatch(self, batch_size):
        """Return one packed nested-sequence batch.

        ``batch_size`` is ignored. The batch declares size 2 while each slot
        payload carries 4 samples. NOTE(review): confirm this mismatch is
        what the reader's nested format expects.
        """
        # Join into bytes. ``"" + ...`` only worked on Python 2.
        return b"".join([
            struct.pack('i', 2),  # batch size
            dense_value_creator(4),  # slot0
            sparse_value_creator(4),  # slot1
            index_value_creator(4),  # slot2
            sequenceStartPositions_creator(),
            subSequenceStartPositions_creator(),
        ])

if __name__ == "__main__":
    # Smoke test: print the byte length of the header and of one batch from
    # each provider. The path is only stored by the providers, never opened.
    # print(...) with a single argument behaves the same on Python 2 and 3.
    for provider_cls in (SimpleDataProvider, SimpleNestDataProvider):
        data_provider = provider_cls('./test_batch')
        print(len(data_provider.getHeader()))
        print(len(data_provider.getNextBatch(2)))