ckpt_inspector.py 3.9 KB
Newer Older
C
chenxuyi 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import sys
import os
import struct
import logging
import argparse
import numpy as np
import collections
from distutils import dir_util
import pickle

#from utils import print_arguments 
import paddle.fluid as F
from paddle.fluid.proto import framework_pb2

# Module-level logger for this script. It accepts every record from DEBUG
# upwards and emits it through a single stream handler, using a format that
# carries the level, timestamp, source file and line number.
log = logging.getLogger(__name__)
log.setLevel(logging.DEBUG)

console = logging.StreamHandler()
formatter = logging.Formatter(
    fmt='[%(levelname)s] %(asctime)s [%(filename)12s:%(lineno)5d]:\t%(message)s'
)
console.setFormatter(formatter)
log.addHandler(console)


def gen_arr(data, dtype):
    """Decode a raw byte buffer into a tuple of values of one struct type.

    Args:
        data: bytes holding back-to-back values encoded as `dtype`.
        dtype: a `struct` format string for a single item (e.g. 'f', 'q').

    Returns:
        A tuple with one decoded value per item found in `data`.
    """
    item_size = struct.calcsize(dtype)
    count = len(data) // item_size
    # Repeat-count format, e.g. '12f' for twelve floats.
    layout = '{}{}'.format(count, dtype)
    return struct.unpack(layout, data)


def parse(filename):
    """Read one serialized tensor file into a numpy array.

    The file is consumed in this order: an unsigned int version field
    (ignored), an unsigned 64-bit LoD size, a second unsigned int version
    field (ignored), a signed int giving the byte length of a serialized
    ``TensorDesc`` protobuf, the protobuf itself, and finally the raw
    tensor data, which is taken to run to the end of the file.

    Args:
        filename: path of the file to read.

    Returns:
        A numpy array reshaped to the dims recorded in the tensor
        description, or None when the LoD size is non-zero (such files are
        skipped with a warning).

    Raises:
        RuntimeError: if the tensor's data type is not one of FP32, INT64,
            INT32 or INT8.
    """
    var_type = framework_pb2.VarType
    # data type -> (struct format character, numpy dtype)
    layouts = {
        var_type.FP32: ('f', np.float32),
        var_type.INT64: ('q', np.int64),
        var_type.INT32: ('i', np.int32),
        # INT8 is signed, so it must be decoded with 'b'. Decoding with the
        # unsigned 'B' yields values above 127 that do not fit in np.int8.
        var_type.INT8: ('b', np.int8),
    }

    with open(filename, 'rb') as f:

        def read(fmt):
            # Read exactly as many bytes as `fmt` describes and decode them.
            return struct.unpack(fmt, f.read(struct.calcsize(fmt)))

        _, = read('I')  # version
        lodsize, = read('Q')
        if lodsize != 0:
            log.warning('shit, it is LOD tensor!!! skipped!!')
            return None
        _, = read('I')  # version
        pbsize, = read('i')
        proto = var_type.TensorDesc()
        proto.ParseFromString(f.read(pbsize))
        log.info('type: [%s] dim %s', proto.data_type, proto.dims)

        if proto.data_type not in layouts:
            raise RuntimeError('Unknown dtype %s' % proto.data_type)
        fmt, np_dtype = layouts[proto.data_type]
        # Everything after the descriptor is tensor payload.
        values = gen_arr(f.read(), fmt)
        return np.array(values, dtype=np_dtype).reshape(proto.dims)


def show(arr):
    """Print the ``repr`` of `arr` on standard output."""
    rendered = repr(arr)
    print(rendered)


def dump(arr, path, to=None):
    """Pickle `arr` to `path` underneath the output directory.

    Args:
        arr: object to serialize (pickle protocol 4).
        path: destination relative to the output directory. Leading path
            separators are ignored so the result always stays under the
            output directory.
        to: output directory. When omitted, the module-level ``args.to``
            parsed from the command line is used.

    Raises:
        ValueError: if `path` is empty once leading separators are removed.
    """
    if to is None:
        to = args.to

    # os.path.join discards every earlier component when it meets an
    # absolute one, so a path such as '/sub/w' would silently escape `to`.
    separators = os.sep + (os.altsep or '')
    relative = path.lstrip(separators)
    if not relative:
        raise ValueError('empty dump path: %r' % path)

    target = os.path.join(to, relative)
    log.info('dump to %s' % target)

    parent = os.path.dirname(target)
    if parent:
        os.makedirs(parent, exist_ok=True)
    # Context manager guarantees the file is flushed and closed.
    with open(target, 'wb') as f:
        pickle.dump(arr, f, protocol=4)


def list_dir(dir_or_file):
    """Return the list of files designated by `dir_or_file`.

    Args:
        dir_or_file: path to a regular file or to a directory.

    Returns:
        ``[dir_or_file]`` when it is a regular file; otherwise the paths of
        all files found by walking `dir_or_file` recursively.
    """
    if os.path.isfile(dir_or_file):
        return [dir_or_file]

    found = []
    for root, _, names in os.walk(dir_or_file):
        for name in names:
            found.append(os.path.join(root, name))
    return found


# Command-line entry point: parse every file under `file_or_dir` and either
# print the resulting arrays ('show') or pickle them under --to ('dump').
if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('mode', choices=['show', 'dump'], type=str)
    parser.add_argument('file_or_dir', type=str)
    parser.add_argument('-t', "--to", type=str, default=None)
    # NOTE: --verbose is accepted but not read anywhere in this script.
    parser.add_argument('-v', "--verbose", action='store_true')
    # `args` must stay a module-level name: dump() reads args.to.
    args = parser.parse_args()

    # Fail before any file is parsed if dump has no destination.
    if args.mode == 'dump' and args.to is None:
        raise ValueError('--to dir_name not specified')

    single_file = os.path.isfile(args.file_or_dir)
    for path in list_dir(args.file_or_dir):
        arr = parse(path)
        if arr is None:
            continue
        if args.mode == 'show':
            show(arr)
            continue
        # Build the destination relative to the input root. A plain
        # str.replace would strip every occurrence of the root string,
        # leave a leading separator for directories (which makes
        # os.path.join ignore --to) and yield '' for a single input file.
        if single_file:
            relative = os.path.basename(path)
        else:
            relative = os.path.relpath(path, args.file_or_dir)
        dump(arr, relative)