# swicher.py -- helpers for rewriting raw float tensor dumps (NHWC -> NCHW
# transposition, plain copies, copies prefixed with a fixed-size header).
# Last change: "convert", committed by xiebaiyuan.
from array import array


class Swichter:
    """Rewrite raw C-float (``array('f')``) tensor dumps.

    Supports NHWC -> NCHW transposition, plain byte copies, and copies
    prefixed with a fixed-size header taken from a reference file.

    NOTE(review): the header bytes are copied verbatim from ``head_file``
    (``DEFAULT_HEAD_FILE`` unless overridden). Their meaning is not visible
    here; presumably they are the header the downstream loader expects.
    Confirm before changing ``HEAD_SIZE``.
    """

    # Reference file whose leading bytes are reused as the header. The path is
    # relative, so it is resolved against the current working directory.
    DEFAULT_HEAD_FILE = 'yolo/datas/yolo/conv1_biases'
    # Number of leading bytes copied from the reference file.
    HEAD_SIZE = 24

    def __init__(self):
        pass

    @staticmethod
    def _read_nchw(from_file_name, batch, channel, height, width):
        """Read ``batch*channel*height*width`` floats stored NHWC and return
        them as an ``array('f')`` reordered to NCHW.

        Raises EOFError if the file holds fewer floats than requested.
        """
        nhwc = array("f")
        with open(from_file_name, "rb") as from_file:
            nhwc.fromfile(from_file, width * height * batch * channel)

        nchw = array("f")
        image = height * width * channel  # floats per batch entry
        for b in range(batch):
            base = b * image
            for c in range(channel):
                # In NHWC, channel c of batch b lives at base + c + channel*k
                # for k = h*width + w, so one strided slice is exactly one
                # NCHW (h, w) plane, already in row-major order.
                nchw.extend(nhwc[base + c:base + image:channel])
        return nchw

    def nhwc2nchw_one_slice(self, from_file_name, to_file_name, batch, channel, height, width):
        """Transpose the NHWC dump ``from_file_name`` to NCHW and write it,
        without a header, to ``to_file_name``."""
        nchw = self._read_nchw(from_file_name, batch, channel, height, width)
        with open(to_file_name, "wb") as to_file:
            nchw.tofile(to_file)

    def copy(self, from_file_name, to_file_name):
        """Copy ``from_file_name`` byte-for-byte to ``to_file_name``."""
        # Read fully before opening the destination so copying a file onto
        # itself does not truncate it first.
        with open(from_file_name, "rb") as from_file:
            data = from_file.read()
        with open(to_file_name, "wb") as to_file:
            to_file.write(data)

    def nhwc2nchw_one_slice_add_head(self, from_file_name, to_file_name, tmp_file_name,
                                     batch, channel, height, width,
                                     head_file=DEFAULT_HEAD_FILE):
        """Transpose NHWC -> NCHW and write header + data to ``to_file_name``.

        The header-less NCHW data is also written to ``tmp_file_name``, as
        before, so callers relying on that intermediate file still get it.
        ``head_file`` selects the header source (defaults to
        ``DEFAULT_HEAD_FILE``).
        """
        nchw = self._read_nchw(from_file_name, batch, channel, height, width)
        with open(tmp_file_name, "wb") as tmp_file:
            nchw.tofile(tmp_file)

        # Read the header before truncating the output, so a missing header
        # file does not leave an empty output behind.
        head = self.read_head(head_file)
        with open(to_file_name, "wb") as to_file:
            to_file.write(head)
            nchw.tofile(to_file)

    def read_head(self, head_file):
        """Return the first ``HEAD_SIZE`` (24) bytes of ``head_file``.

        Fewer bytes are returned if the file is shorter.
        """
        with open(head_file, "rb") as from_file:
            return from_file.read(self.HEAD_SIZE)

    def copy_add_head(self, from_file_name, to_file_name, tmp_file_name,
                      head_file=DEFAULT_HEAD_FILE):
        """Write header + the full contents of ``from_file_name`` to
        ``to_file_name``.

        ``tmp_file_name`` is unused; it is kept so existing callers still work.
        """
        head = self.read_head(head_file)
        with open(from_file_name, "rb") as from_file:
            body = from_file.read()
        with open(to_file_name, "wb") as to_file:
            to_file.write(head)
            to_file.write(body)

    def copy_padding_add_head(self, from_file_name, to_file_name, tmp_file_name, padding,
                              head_file=DEFAULT_HEAD_FILE):
        """Skip the first ``padding`` bytes of ``from_file_name`` and write
        header + the remainder to ``to_file_name``.

        Prints the padding and the number of bytes kept.
        ``tmp_file_name`` is unused; it is kept so existing callers still work.
        """
        print('padding  = %d' % padding)
        with open(from_file_name, "rb") as from_file:
            from_file.seek(padding, 0)  # absolute seek from start of file
            body = from_file.read()
        print(len(body))

        head = self.read_head(head_file)
        with open(to_file_name, "wb") as to_file:
            to_file.write(head)
            to_file.write(body)

# Example invocations (paths are from the original author's machine):
#
# Swichter().nhwc2nchw_one_slice_add_head(
#     '/Users/xiebaiyuan/PaddleProject/paddle-mobile/python/tools/modeltools/multiobjects/float32s_nhwc/conv1_0.bin',
#     '/Users/xiebaiyuan/PaddleProject/paddle-mobile/python/tools/modeltools/multiobjects/float32s_nchw_with_head/conv1_0',
#     '/Users/xiebaiyuan/PaddleProject/paddle-mobile/python/tools/modeltools/multiobjects/float32s_nchw/.tmp',
#     32,
#     3, 3, 3)
#
# Swichter().read_head('/Users/xiebaiyuan/PaddleProject/paddle-mobile/python/tools/modeltools/yolo/conv1_biases')
#
# Swichter().copy_add_head('datas/model.0.0.weight', 'datas/conv1_0', '')