# swicher.py - helpers for converting raw C-float buffer files (author: xiebaiyuan)
import os
import shutil
# array("f") is used below to read/write raw C-float buffers.
from array import array


class Swichter:
    """Helpers for converting and re-heading raw C-float buffer files.

    Float files are read and written with ``array('f')``, i.e. as flat
    native-byte-order C ``float`` values with no framing. The ``*_add_head``
    methods prefix their output with the first bytes of a separate "head"
    file (see ``read_head``).
    """

    # Head file used by nhwc2nchw_one_slice_add_head / copy_padding_add_head
    # when no ``head_file`` is given (the path previously hard-coded there).
    DEFAULT_HEAD_FILE = 'yolo/datas/yolo/head'
    # Head file used by copy_add_head when no ``head_file`` is given (the path
    # previously hard-coded there).
    DEFAULT_MOBILENET_HEAD_FILE = (
        '/Users/xiebaiyuan/PaddleProject/paddle-mobile/tools/python/'
        'modeltools/mobilenet/datas/sourcemodels/head/head')
    # Number of bytes read_head copies from a head file by default.
    HEAD_SIZE = 24

    def __init__(self):
        pass

    @staticmethod
    def _read_floats(file_name, count):
        """Read exactly ``count`` C floats from ``file_name`` into an array('f').

        Raises EOFError (from ``array.fromfile``) if the file holds fewer.
        """
        values = array("f")
        with open(file_name, "rb") as from_file:
            values.fromfile(from_file, count)
        return values

    @staticmethod
    def _nhwc_to_nchw(src, batch, channel, height, width):
        """Return a new array('f') with ``src`` reordered from NHWC to NCHW.

        ``src`` is indexed as [b][h][w][c]; the result is [b][c][h][w].
        """
        per_batch = height * width * channel
        out = array("f")
        for b in range(batch):
            start = b * per_batch
            stop = start + per_batch
            for c in range(channel):
                # Within one batch item, channel c's values sit every `channel`
                # floats and are already in (h, w) row-major order, so one
                # strided slice yields the whole NCHW plane for that channel.
                out.extend(src[start + c:stop:channel])
        return out

    def nhwc2nchw_one_slice(self, from_file_name, to_file_name, batch, channel, height, width):
        """Convert a raw NHWC float file to NCHW order.

        Reads batch*channel*height*width C floats from ``from_file_name`` and
        writes them to ``to_file_name`` reordered as [b][c][h][w].

        Raises EOFError if ``from_file_name`` holds too few floats.
        """
        src = self._read_floats(from_file_name, width * height * batch * channel)
        converted = self._nhwc_to_nchw(src, batch, channel, height, width)
        with open(to_file_name, "wb") as to_file:
            converted.tofile(to_file)

    def copy(self, from_file_name, to_file_name):
        """Copy ``from_file_name`` to ``to_file_name`` byte for byte.

        Unlike a naive open-both-then-read copy, this refuses (shutil.Error)
        to copy a file onto itself instead of truncating it.
        """
        shutil.copyfile(from_file_name, to_file_name)

    def nhwc2nchw_one_slice_add_head(self, from_file_name, to_file_name, tmp_file_name,
                                     batch, channel, height, width, head_file=None):
        """Convert an NHWC float file to NCHW and prefix the result with a head.

        ``tmp_file_name`` receives the converted floats without a head.
        ``to_file_name`` receives ``read_head(head_file)`` followed by the
        same converted floats. ``head_file`` defaults to DEFAULT_HEAD_FILE.
        """
        if head_file is None:
            head_file = self.DEFAULT_HEAD_FILE
        src = self._read_floats(from_file_name, width * height * batch * channel)
        converted = self._nhwc_to_nchw(src, batch, channel, height, width)
        # Still written so existing callers that look at the tmp file keep
        # working; the final output no longer needs to re-read it.
        with open(tmp_file_name, "wb") as tmp_file:
            converted.tofile(tmp_file)
        head = self.read_head(head_file)
        with open(to_file_name, "wb") as to_file:
            to_file.write(head)
            converted.tofile(to_file)

    def read_head(self, head_file, size=None):
        """Return the first ``size`` bytes of ``head_file``.

        ``size`` defaults to HEAD_SIZE (24). Fewer bytes are returned if the
        file is shorter.
        """
        if size is None:
            size = self.HEAD_SIZE
        with open(head_file, "rb") as from_file:
            return from_file.read(size)

    def copy_add_head(self, from_file_name, to_file_name, head_file=None):
        """Write ``read_head(head_file)`` followed by all of ``from_file_name``.

        ``head_file`` defaults to DEFAULT_MOBILENET_HEAD_FILE.
        """
        if head_file is None:
            head_file = self.DEFAULT_MOBILENET_HEAD_FILE
        head = self.read_head(head_file)
        with open(from_file_name, "rb") as from_file:
            body = from_file.read()
        with open(to_file_name, "wb") as to_file:
            to_file.write(head)
            to_file.write(body)

    def copy_padding_add_head(self, from_file_name, to_file_name, tmp_file_name, padding,
                              head_file=None):
        """Drop the first ``padding`` bytes of a file and prefix it with a head.

        Writes ``read_head(head_file)`` followed by ``from_file_name`` starting
        at byte offset ``padding``. Prints the padding and the number of bytes
        kept. ``tmp_file_name`` is unused; it is kept for call compatibility.
        ``head_file`` defaults to DEFAULT_HEAD_FILE.
        """
        if head_file is None:
            head_file = self.DEFAULT_HEAD_FILE
        # Single-argument print(...) behaves the same on Python 2 and 3.
        print('padding  = %d' % padding)
        with open(from_file_name, "rb") as from_file:
            from_file.seek(padding, 0)
            body = from_file.read()
        print(len(body))
        head = self.read_head(head_file)
        with open(to_file_name, "wb") as to_file:
            to_file.write(head)
            to_file.write(body)

# Swichter().nhwc2nchw_one_slice_add_head(
#     '/Users/xiebaiyuan/PaddleProject/paddle-mobile/python/tools/modeltools/multiobjects/float32s_nhwc/conv1_0.bin',
#     '/Users/xiebaiyuan/PaddleProject/paddle-mobile/python/tools/modeltools/multiobjects/float32s_nchw_with_head/conv1_0',
#     '/Users/xiebaiyuan/PaddleProject/paddle-mobile/python/tools/modeltools/multiobjects/float32s_nchw/.tmp',
#     32,
#     3, 3, 3)

# Swichter().read_head('/Users/xiebaiyuan/PaddleProject/paddle-mobile/python/tools/modeltools/yolo/head')

# NOTE(review): the trailing '' argument below looks stale - check it against copy_add_head's signature.
# Swichter().copy_add_head('datas/model.0.0.weight', 'datas/conv1_0', '')