From 53075cd3dadd9aad4da70ef7227fa9c3a5391f7b Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Thu, 25 Feb 2021 11:17:34 +0800 Subject: [PATCH] feat(mge/experimental): add visualization and net stats for python graph GitOrigin-RevId: a1ab77c20aff8b9205fb3b34532e8f86a2733d69 --- imperative/python/megengine/tools/README.md | 8 + imperative/python/megengine/tools/__init__.py | 0 .../{utils => tools}/compare_binary_iodump.py | 56 ++++- .../megengine/tools/network_visualize.py | 176 ++++++++++++++ .../{utils => tools}/profile_analyze.py | 2 +- .../utils/{net_stats.py => module_stats.py} | 224 ++++++++++-------- imperative/python/megengine/utils/network.py | 6 +- .../python/megengine/utils/network_node.py | 45 +++- imperative/python/megengine/utils/plugin.py | 57 ----- 9 files changed, 402 insertions(+), 172 deletions(-) create mode 100644 imperative/python/megengine/tools/README.md create mode 100644 imperative/python/megengine/tools/__init__.py rename imperative/python/megengine/{utils => tools}/compare_binary_iodump.py (64%) create mode 100755 imperative/python/megengine/tools/network_visualize.py rename imperative/python/megengine/{utils => tools}/profile_analyze.py (99%) rename imperative/python/megengine/utils/{net_stats.py => module_stats.py} (58%) delete mode 100644 imperative/python/megengine/utils/plugin.py diff --git a/imperative/python/megengine/tools/README.md b/imperative/python/megengine/tools/README.md new file mode 100644 index 000000000..4e38fe4c0 --- /dev/null +++ b/imperative/python/megengine/tools/README.md @@ -0,0 +1,8 @@ +# MegEngine Tools + +This directory contains executable python files. +Use these files in the following way (replace `xxx` to specific file name, like `network_visualize`): + +``` +python -m megengine.tools.xxx +``` diff --git a/imperative/python/megengine/tools/__init__.py b/imperative/python/megengine/tools/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/imperative/python/megengine/utils/compare_binary_iodump.py b/imperative/python/megengine/tools/compare_binary_iodump.py similarity index 64% rename from imperative/python/megengine/utils/compare_binary_iodump.py rename to imperative/python/megengine/tools/compare_binary_iodump.py index b6be0ced0..9a4ef87a5 100755 --- a/imperative/python/megengine/utils/compare_binary_iodump.py +++ b/imperative/python/megengine/tools/compare_binary_iodump.py @@ -1,3 +1,4 @@ +#! /usr/bin/env python3 # MegEngine is Licensed under the Apache License, Version 2.0 (the "License") # # Copyright (c) 2014-2021 Megvii Inc. All rights reserved. @@ -7,12 +8,55 @@ # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. import argparse import os +import struct import textwrap from pathlib import Path import numpy as np -from megengine.utils import plugin + +def load_tensor_binary(fobj): + """ + Load a tensor dumped by the :class:`BinaryOprIODump` plugin; the actual + tensor value dump is implemented by ``mgb::debug::dump_tensor``. + + :param fobj: file object, or a string that contains the file name. + :return: tuple ``(tensor_value, tensor_name)``. + """ + if isinstance(fobj, str): + with open(fobj, "rb") as fin: + return load_tensor_binary(fin) + + DTYPE_LIST = { + 0: np.float32, + 1: np.uint8, + 2: np.int8, + 3: np.int16, + 4: np.int32, + # 5: _mgb.intb1, + # 6: _mgb.intb2, + # 7: _mgb.intb4, + 8: None, + 9: np.float16, + # quantized dtype start from 100000 + # see MEGDNN_PARAMETERIZED_DTYPE_ENUM_BASE in + # dnn/include/megdnn/dtype.h + 100000: np.uint8, + 100001: np.int32, + 100002: np.int8, + } + + header_fmt = struct.Struct("III") + name_len, dtype, max_ndim = header_fmt.unpack(fobj.read(header_fmt.size)) + assert ( + DTYPE_LIST[dtype] is not None + ), "Cannot load this tensor: dtype Byte is unsupported." + + shape = list(struct.unpack("I" * max_ndim, fobj.read(max_ndim * 4))) + while shape[-1] == 0: + shape.pop(-1) + name = fobj.read(name_len).decode("ascii") + return np.fromfile(fobj, dtype=DTYPE_LIST[dtype]).reshape(shape), name def check(v0, v1, name, max_err): @@ -26,9 +70,9 @@ def check(v0, v1, name, max_err): ) vdiv = np.max([np.abs(v0), np.abs(v1), np.ones_like(v0)], axis=0) err = np.abs(v0 - v1) / vdiv - check = err > max_err - if check.sum(): - idx = tuple(i[0] for i in np.nonzero(check)) + rst = err > max_err + if rst.sum(): + idx = tuple(i[0] for i in np.nonzero(rst)) raise AssertionError( "{} not equal: " "shape={} nonequal_idx={} v0={} v1={} err={}".format( @@ -79,8 +123,8 @@ def main(): files1 = sorted(files1) for i, j in zip(files0, files1): - val0, name0 = plugin.load_tensor_binary(i) - val1, name1 = plugin.load_tensor_binary(j) + val0, name0 = load_tensor_binary(i) + val1, name1 = load_tensor_binary(j) name = "{}: \n{}\n{}\n".format( i, "\n ".join(textwrap.wrap(name0)), "\n ".join(textwrap.wrap(name1)) ) diff --git a/imperative/python/megengine/tools/network_visualize.py b/imperative/python/megengine/tools/network_visualize.py new file mode 100755 index 000000000..181a30272 --- /dev/null +++ b/imperative/python/megengine/tools/network_visualize.py @@ -0,0 +1,176 @@ +#! /usr/bin/env python3 +# MegEngine is Licensed under the Apache License, Version 2.0 (the "License") +# +# Copyright (c) 2014-2021 Megvii Inc. All rights reserved. +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +import argparse + +import numpy as np + +from megengine.core.tensor.dtype import is_quantize +from megengine.logger import get_logger +from megengine.utils.module_stats import ( + print_flops_stats, + print_params_stats, + sizeof_fmt, +) +from megengine.utils.network import Network + +logger = get_logger(__name__) + + +def visualize( + model_path: str, + log_path: str, + bar_length_max: int = 20, + log_params: bool = True, + log_flops: bool = True, +): + r""" + Load megengine dumped model and visualize graph structure with tensorboard log files. + Can also record and print model's statistics like :func:`~.net_stats` + + :param model_path: dir path for megengine dumped model. + :param log_path: dir path for tensorboard graph log. + :param bar_length_max: size of bar indicating max flops or parameter size in net stats. + :param log_params: whether print and record params size. + :param log_flops: whether print and record op flops. + """ + try: + from tensorboard.compat.proto.attr_value_pb2 import AttrValue + from tensorboard.compat.proto.config_pb2 import RunMetadata + from tensorboard.compat.proto.graph_pb2 import GraphDef + from tensorboard.compat.proto.node_def_pb2 import NodeDef + from tensorboard.compat.proto.step_stats_pb2 import ( + AllocatorMemoryUsed, + DeviceStepStats, + NodeExecStats, + StepStats, + ) + from tensorboard.compat.proto.tensor_shape_pb2 import TensorShapeProto + from tensorboard.compat.proto.versions_pb2 import VersionDef + from tensorboardX import SummaryWriter + except ImportError: + logger.error( + "TensorBoard and TensorboardX are required for visualize.", exc_info=True + ) + return + + graph = Network.load(model_path) + writer = SummaryWriter(log_path) + + def process_name(name): + return name.replace(".", "/").encode(encoding="utf-8") + + node_list = [] + flops_list = [] + params_list = [] + for node in graph.all_oprs: + if hasattr(node, "output_idx"): + node_oup = node.outputs[node.output_idx] + else: + if len(node.outputs) != 1: + logger.warning( + "OpNode {} has more than one output and not has 'output_idx' attr.".format( + node + ) + ) + node_oup = node.outputs[0] + + inp_list = [process_name(var.owner.name) for var in node.inputs] + attr = { + "_output_shapes": AttrValue( + list=AttrValue.ListValue( + shape=[ + TensorShapeProto( + dim=[TensorShapeProto.Dim(size=d) for d in node_oup.shape] + ) + ] + ) + ), + } + if hasattr(node, "calc_flops"): + flops_num = node.calc_flops() + # add op flops attr + attr["flops"] = AttrValue(s=sizeof_fmt(flops_num).encode(encoding="utf-8")) + flops_list.append( + dict( + name=node.name, + class_name=node.type, + input_shapes=[i.shape for i in node.inputs], + output_shapes=[o.shape for o in node.outputs], + flops_num=flops_num, + flops_cum=0, + ) + ) + if node.type == "ImmutableTensor": + param_dim = np.prod(node_oup.shape) + # TODO: consider other quantize dtypes + param_bytes = 1 if is_quantize(node_oup.dtype) else 4 + # add tensor size attr + attr["size"] = AttrValue( + s=sizeof_fmt(param_dim * param_bytes).encode(encoding="utf-8") + ) + params_list.append( + dict( + name=node.name, + shape=node_oup.shape, + param_dim=param_dim, + bits=param_bytes * 8, + size=param_dim * param_bytes, + size_cum=0, + mean="{:.2g}".format(node.numpy().mean()), + std="{:.2g}".format(node.numpy().std()), + ) + ) + node_list.append( + NodeDef( + name=process_name(node.name), op=node.type, input=inp_list, attr=attr, + ) + ) + + total_flops, total_params = 0, 0 + if log_params: + total_params = print_params_stats(params_list, bar_length_max) + if log_flops: + total_flops = print_flops_stats(flops_list, bar_length_max) + + graph_def = GraphDef(node=node_list, versions=VersionDef(producer=22)) + + device = "/device:CPU:0" + stepstats = RunMetadata( + step_stats=StepStats(dev_stats=[DeviceStepStats(device=device)]) + ) + writer._get_file_writer().add_graph((graph_def, stepstats)) + return total_params, total_flops + + +def main(): + parser = argparse.ArgumentParser( + description="load a megengine dumped model and export log file for tensorboard visualization.", + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + ) + parser.add_argument("model_path", help="dumped model path.") + parser.add_argument("log_path", help="tensorboard log path.") + parser.add_argument( + "--bar_length_max", + type=int, + default=20, + help="size of bar indicating max flops or parameter size in net stats.", + ) + parser.add_argument( + "--log_params", + action="store_true", + help="whether print and record params size.", + ) + parser.add_argument( + "--log_flops", action="store_true", help="whether print and record op flops.", + ) + visualize(**vars(parser.parse_args())) + + +if __name__ == "__main__": + main() diff --git a/imperative/python/megengine/utils/profile_analyze.py b/imperative/python/megengine/tools/profile_analyze.py similarity index 99% rename from imperative/python/megengine/utils/profile_analyze.py rename to imperative/python/megengine/tools/profile_analyze.py index 6722670bf..071e06606 100755 --- a/imperative/python/megengine/utils/profile_analyze.py +++ b/imperative/python/megengine/tools/profile_analyze.py @@ -1,4 +1,4 @@ -# -*- coding: utf-8 -*- +#! /usr/bin/env python3 # MegEngine is Licensed under the Apache License, Version 2.0 (the "License") # # Copyright (c) 2014-2021 Megvii Inc. All rights reserved. diff --git a/imperative/python/megengine/utils/net_stats.py b/imperative/python/megengine/utils/module_stats.py similarity index 58% rename from imperative/python/megengine/utils/net_stats.py rename to imperative/python/megengine/utils/module_stats.py index 68eced78e..46753cbe9 100644 --- a/imperative/python/megengine/utils/net_stats.py +++ b/imperative/python/megengine/utils/module_stats.py @@ -84,26 +84,125 @@ hook_modules = ( ) -def net_stats(model, input_size, bar_length_max=20, log_params=True, log_flops=True): - def dict2table(list_of_dict, header): - table_data = [header] - for d in list_of_dict: - row = [] - for h in header: - v = "" - if h in d: - v = d[h] - row.append(v) - table_data.append(row) - return table_data - - def sizeof_fmt(num, suffix="B"): - for unit in ["", "Ki", "Mi", "Gi", "Ti", "Pi", "Ei", "Zi"]: - if abs(num) < 1024.0: - return "{:3.3f} {}{}".format(num, unit, suffix) - num /= 1024.0 - sign_str = "-" if num < 0 else "" - return "{}{:.1f} {}{}".format(sign_str, num, "Yi", suffix) +def dict2table(list_of_dict, header): + table_data = [header] + for d in list_of_dict: + row = [] + for h in header: + v = "" + if h in d: + v = d[h] + row.append(v) + table_data.append(row) + return table_data + + +def sizeof_fmt(num, suffix="B"): + for unit in ["", "Ki", "Mi", "Gi", "Ti", "Pi", "Ei", "Zi"]: + if abs(num) < 1024.0: + return "{:3.3f} {}{}".format(num, unit, suffix) + num /= 1024.0 + sign_str = "-" if num < 0 else "" + return "{}{:.1f} {}{}".format(sign_str, num, "Yi", suffix) + + +def print_flops_stats(flops, bar_length_max=20): + flops_list = [i["flops_num"] for i in flops] + max_flops_num = max(flops_list + [0]) + # calc total flops and set flops_cum + total_flops_num = 0 + for d in flops: + total_flops_num += int(d["flops_num"]) + d["flops_cum"] = sizeof_fmt(total_flops_num, suffix="OPs") + + for i in flops: + f = i["flops_num"] + i["flops"] = sizeof_fmt(f, suffix="OPs") + r = i["ratio"] = f / total_flops_num + i["percentage"] = "{:.2f}%".format(r * 100) + bar_length = int(f / max_flops_num * bar_length_max) + i["bar"] = "#" * bar_length + + header = [ + "name", + "class_name", + "input_shapes", + "output_shapes", + "flops", + "flops_cum", + "percentage", + "bar", + ] + + total_flops_str = sizeof_fmt(total_flops_num, suffix="OPs") + total_var_size = sum(sum(s[1] for s in i["output_shapes"]) for i in flops) + flops.append( + dict(name="total", flops=total_flops_str, output_shapes=total_var_size) + ) + + logger.info("flops stats: \n" + tabulate.tabulate(dict2table(flops, header=header))) + + return total_flops_num + + +def print_params_stats(params, bar_length_max=20): + total_param_dims, total_param_size = 0, 0 + for d in params: + total_param_dims += int(d["param_dim"]) + total_param_size += int(d["size"]) + d["size"] = sizeof_fmt(d["size"]) + d["size_cum"] = sizeof_fmt(total_param_size) + + for d in params: + ratio = d["param_dim"] / total_param_dims + d["ratio"] = ratio + d["percentage"] = "{:.2f}%".format(ratio * 100) + + # construct bar + max_ratio = max([d["ratio"] for d in params]) + for d in params: + bar_length = int(d["ratio"] / max_ratio * bar_length_max) + d["size_bar"] = "#" * bar_length + + param_size = sizeof_fmt(total_param_size) + params.append(dict(name="total", param_dim=total_param_dims, size=param_size,)) + + header = [ + "name", + "shape", + "mean", + "std", + "param_dim", + "bits", + "size", + "size_cum", + "percentage", + "size_bar", + ] + + logger.info( + "param stats: \n" + tabulate.tabulate(dict2table(params, header=header)) + ) + + return total_param_size + + +def net_stats( + model: m.Module, + input_size: int, + bar_length_max: int = 20, + log_params: bool = True, + log_flops: bool = True, +): + r""" + Calculate and print ``model``'s statistics by adding hook and record Module's inputs outputs size. + + :param model: model that need to get stats info. + :param input_size: size of input for running model and calculating stats. + :param bar_length_max: size of bar indicating max flops or parameter size in net stats. + :param log_params: whether print and record params size. + :param log_flops: whether print and record op flops. + """ def get_byteswidth(tensor): if dtype.is_quantize(tensor.dtype): @@ -113,87 +212,6 @@ def net_stats(model, input_size, bar_length_max=20, log_params=True, log_flops=T else: return 4 - def print_flops_stats(flops): - flops_list = [i["flops_num"] for i in flops] - max_flops_num = max(flops_list + [0]) - # calc total flops and set flops_cum - total_flops_num = 0 - for d in flops: - total_flops_num += int(d["flops_num"]) - d["flops_cum"] = sizeof_fmt(total_flops_num, suffix="OPs") - - for i in flops: - f = i["flops_num"] - i["flops"] = sizeof_fmt(f, suffix="OPs") - r = i["ratio"] = f / total_flops_num - i["percentage"] = "{:.2f}%".format(r * 100) - bar_length = int(f / max_flops_num * bar_length_max) - i["bar"] = "#" * bar_length - - header = [ - "name", - "class_name", - "input_shapes", - "output_shapes", - "flops", - "flops_cum", - "percentage", - "bar", - ] - - total_flops_str = sizeof_fmt(total_flops_num, suffix="OPs") - total_var_size = sum(sum(s[1] for s in i["output_shapes"]) for i in flops) - flops.append( - dict(name="total", flops=total_flops_str, output_shapes=total_var_size) - ) - - logger.info( - "flops stats: \n" + tabulate.tabulate(dict2table(flops, header=header)) - ) - - return total_flops_num - - def print_params_stats(params): - total_param_dims, total_param_size = 0, 0 - for d in params: - total_param_dims += int(d["param_dim"]) - total_param_size += int(d["size"]) - d["size"] = sizeof_fmt(d["size"]) - d["size_cum"] = sizeof_fmt(total_param_size) - - for d in params: - ratio = d["param_dim"] / total_param_dims - d["ratio"] = ratio - d["percentage"] = "{:.2f}%".format(ratio * 100) - - # construct bar - max_ratio = max([d["ratio"] for d in params]) - for d in params: - bar_length = int(d["ratio"] / max_ratio * bar_length_max) - d["size_bar"] = "#" * bar_length - - param_size = sizeof_fmt(total_param_size) - params.append(dict(name="total", param_dim=total_param_dims, size=param_size,)) - - header = [ - "name", - "shape", - "mean", - "std", - "param_dim", - "bits", - "size", - "size_cum", - "percentage", - "size_bar", - ] - - logger.info( - "param stats: \n" + tabulate.tabulate(dict2table(params, header=header)) - ) - - return total_param_size - def net_stats_hook(module, input, output, name=""): class_name = str(module.__class__).split(".")[-1].split("'")[0] @@ -273,8 +291,8 @@ def net_stats(model, input_size, bar_length_max=20, log_params=True, log_flops=T total_flops, total_params = 0, 0 if log_params: - total_params = print_params_stats(params) + total_params = print_params_stats(params, bar_length_max) if log_flops: - total_flops = print_flops_stats(flops) + total_flops = print_flops_stats(flops, bar_length_max) return total_params, total_flops diff --git a/imperative/python/megengine/utils/network.py b/imperative/python/megengine/utils/network.py index f26e8086b..8bcf4d892 100644 --- a/imperative/python/megengine/utils/network.py +++ b/imperative/python/megengine/utils/network.py @@ -19,9 +19,9 @@ from ..core._imperative_rt import ComputingGraph from ..core.tensor import megbrain_graph as G from .comp_graph_tools import get_dep_vars, get_opr_type, get_oprs_seq from .network_node import ( - NetworkNode, Host2DeviceCopy, ImmutableTensor, + NetworkNode, OpNode, VarNode, str_to_mge_class, @@ -606,9 +606,7 @@ class NodeFilterType(NodeFilter): _node_type = None def __init__(self, node_iter, node_type): - assert issubclass(node_type, NetworkNode), "bad opr type: {}".format( - node_type - ) + assert issubclass(node_type, NetworkNode), "bad opr type: {}".format(node_type) super().__init__(node_iter) self._node_type = node_type diff --git a/imperative/python/megengine/utils/network_node.py b/imperative/python/megengine/utils/network_node.py index ada2df466..58a5982ea 100644 --- a/imperative/python/megengine/utils/network_node.py +++ b/imperative/python/megengine/utils/network_node.py @@ -10,6 +10,8 @@ import json import sys from typing import Callable +import numpy as np + from ..core import _imperative_rt as rt from ..core._wrap import Device from ..core.ops import builtin @@ -52,7 +54,7 @@ class VarNode(NetworkNode): return self.var.dtype if self.var else None def set_owner_opr(self, owner_opr): - self.owner_opr = owner_opr + self.owner = owner_opr class OpNode(NetworkNode): @@ -223,6 +225,9 @@ class Elemwise(OpNode): type = "Elemwise" opdef = builtin.Elemwise + def calc_flops(self): + return np.prod(self.outputs[0].shape) + class Reduce(OpNode): type = "Reduce" @@ -250,11 +255,21 @@ class MatrixMul(OpNode): type = "MatrixMul" opdef = builtin.MatrixMul + def calc_flops(self): + assert len(self.inputs[0].shape) == 2 and len(self.outputs[0].shape) == 2 + mid_shape = self.inputs[0].shape[1] + return np.prod(self.outputs[0].shape) * mid_shape + class BatchedMatrixMul(OpNode): type = "BatchedMatmul" opdef = builtin.BatchedMatrixMul + def calc_flops(self): + assert len(self.inputs[0].shape) == 3 and len(self.outputs[0].shape) == 3 + mid_shape = self.inputs[0].shape[2] + return np.prod(self.outputs[0].shape) * mid_shape + class Dot(OpNode): type = "Dot" @@ -270,6 +285,18 @@ class ConvolutionForward(OpNode): type = "Convolution" opdef = builtin.Convolution + def calc_flops(self): + param_W_shape = self.inputs[1].shape + kh = param_W_shape[-2] + kw = param_W_shape[-1] + if len(param_W_shape) == 5: + num_input = param_W_shape[2] + else: + num_input = param_W_shape[1] + NCHW = np.prod(self.outputs[0].shape) + # N x Cout x H x W x (Cin x Kw x Kh) + return NCHW * (num_input * kw * kh) + class ConvolutionBackwardData(OpNode): type = "ConvTranspose" @@ -316,6 +343,18 @@ class ConvBiasForward(OpNode): obj.params["dtype"] = opr.outputs[0].dtype return obj + def calc_flops(self): + param_W_shape = self.inputs[1].shape + kh = param_W_shape[-2] + kw = param_W_shape[-1] + if len(param_W_shape) == 5: + num_input = param_W_shape[2] + else: + num_input = param_W_shape[1] + NCHW = np.prod(self.outputs[0].shape) + # N x Cout x H x W x (Cin x Kw x Kh + bias) + return NCHW * (num_input * kw * kh + 1) + class BatchConvBiasForward(OpNode): type = "BatchConvBias" @@ -331,6 +370,7 @@ class BatchConvBiasForward(OpNode): class BatchNormForward(OpNode): type = "BatchNorm" opdef = builtin.BatchNorm + output_idx = -1 class ROIAlignForward(OpNode): @@ -622,6 +662,9 @@ class ElemwiseMultiType(OpNode): obj.params["dtype"] = opr.outputs[0].dtype return obj + def calc_flops(self): + return np.prod(self.outputs[0].shape) + class CvtColorForward(OpNode): type = "CvtColor" diff --git a/imperative/python/megengine/utils/plugin.py b/imperative/python/megengine/utils/plugin.py deleted file mode 100644 index a50634c23..000000000 --- a/imperative/python/megengine/utils/plugin.py +++ /dev/null @@ -1,57 +0,0 @@ -# -*- coding: utf-8 -*- -# MegEngine is Licensed under the Apache License, Version 2.0 (the "License") -# -# Copyright (c) 2014-2021 Megvii Inc. All rights reserved. -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -import struct - -import numpy as np - - -def load_tensor_binary(fobj): - """ - Load a tensor dumped by the :class:`BinaryOprIODump` plugin; the actual - tensor value dump is implemented by ``mgb::debug::dump_tensor``. - - Multiple values can be compared by ``tools/compare_binary_iodump.py``. - - :param fobj: file object, or a string that contains the file name. - :return: tuple ``(tensor_value, tensor_name)``. - """ - if isinstance(fobj, str): - with open(fobj, "rb") as fin: - return load_tensor_binary(fin) - - DTYPE_LIST = { - 0: np.float32, - 1: np.uint8, - 2: np.int8, - 3: np.int16, - 4: np.int32, - # 5: _mgb.intb1, - # 6: _mgb.intb2, - # 7: _mgb.intb4, - 8: None, - 9: np.float16, - # quantized dtype start from 100000 - # see MEGDNN_PARAMETERIZED_DTYPE_ENUM_BASE in - # dnn/include/megdnn/dtype.h - 100000: np.uint8, - 100001: np.int32, - 100002: np.int8, - } - - header_fmt = struct.Struct("III") - name_len, dtype, max_ndim = header_fmt.unpack(fobj.read(header_fmt.size)) - assert ( - DTYPE_LIST[dtype] is not None - ), "Cannot load this tensor: dtype Byte is unsupported." - - shape = list(struct.unpack("I" * max_ndim, fobj.read(max_ndim * 4))) - while shape[-1] == 0: - shape.pop(-1) - name = fobj.read(name_len).decode("ascii") - return np.fromfile(fobj, dtype=DTYPE_LIST[dtype]).reshape(shape), name -- GitLab