提交 53075cd3 编写于 作者: M Megvii Engine Team

feat(mge/experimental): add visualization and net stats for python graph

GitOrigin-RevId: a1ab77c20aff8b9205fb3b34532e8f86a2733d69
上级 ae3123b3
# MegEngine Tools
This directory contains executable python files.
Use these files in the following way (replace `xxx` to specific file name, like `network_visualize`):
```
python -m megengine.tools.xxx
```
#! /usr/bin/env python3
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
......@@ -7,12 +8,55 @@
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import argparse
import os
import struct
import textwrap
from pathlib import Path
import numpy as np
from megengine.utils import plugin
def load_tensor_binary(fobj):
"""
Load a tensor dumped by the :class:`BinaryOprIODump` plugin; the actual
tensor value dump is implemented by ``mgb::debug::dump_tensor``.
:param fobj: file object, or a string that contains the file name.
:return: tuple ``(tensor_value, tensor_name)``.
"""
if isinstance(fobj, str):
with open(fobj, "rb") as fin:
return load_tensor_binary(fin)
DTYPE_LIST = {
0: np.float32,
1: np.uint8,
2: np.int8,
3: np.int16,
4: np.int32,
# 5: _mgb.intb1,
# 6: _mgb.intb2,
# 7: _mgb.intb4,
8: None,
9: np.float16,
# quantized dtype start from 100000
# see MEGDNN_PARAMETERIZED_DTYPE_ENUM_BASE in
# dnn/include/megdnn/dtype.h
100000: np.uint8,
100001: np.int32,
100002: np.int8,
}
header_fmt = struct.Struct("III")
name_len, dtype, max_ndim = header_fmt.unpack(fobj.read(header_fmt.size))
assert (
DTYPE_LIST[dtype] is not None
), "Cannot load this tensor: dtype Byte is unsupported."
shape = list(struct.unpack("I" * max_ndim, fobj.read(max_ndim * 4)))
while shape[-1] == 0:
shape.pop(-1)
name = fobj.read(name_len).decode("ascii")
return np.fromfile(fobj, dtype=DTYPE_LIST[dtype]).reshape(shape), name
def check(v0, v1, name, max_err):
......@@ -26,9 +70,9 @@ def check(v0, v1, name, max_err):
)
vdiv = np.max([np.abs(v0), np.abs(v1), np.ones_like(v0)], axis=0)
err = np.abs(v0 - v1) / vdiv
check = err > max_err
if check.sum():
idx = tuple(i[0] for i in np.nonzero(check))
rst = err > max_err
if rst.sum():
idx = tuple(i[0] for i in np.nonzero(rst))
raise AssertionError(
"{} not equal: "
"shape={} nonequal_idx={} v0={} v1={} err={}".format(
......@@ -79,8 +123,8 @@ def main():
files1 = sorted(files1)
for i, j in zip(files0, files1):
val0, name0 = plugin.load_tensor_binary(i)
val1, name1 = plugin.load_tensor_binary(j)
val0, name0 = load_tensor_binary(i)
val1, name1 = load_tensor_binary(j)
name = "{}: \n{}\n{}\n".format(
i, "\n ".join(textwrap.wrap(name0)), "\n ".join(textwrap.wrap(name1))
)
......
#! /usr/bin/env python3
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import argparse
import numpy as np
from megengine.core.tensor.dtype import is_quantize
from megengine.logger import get_logger
from megengine.utils.module_stats import (
print_flops_stats,
print_params_stats,
sizeof_fmt,
)
from megengine.utils.network import Network
logger = get_logger(__name__)
def visualize(
model_path: str,
log_path: str,
bar_length_max: int = 20,
log_params: bool = True,
log_flops: bool = True,
):
r"""
Load megengine dumped model and visualize graph structure with tensorboard log files.
Can also record and print model's statistics like :func:`~.net_stats`
:param model_path: dir path for megengine dumped model.
:param log_path: dir path for tensorboard graph log.
:param bar_length_max: size of bar indicating max flops or parameter size in net stats.
:param log_params: whether print and record params size.
:param log_flops: whether print and record op flops.
"""
try:
from tensorboard.compat.proto.attr_value_pb2 import AttrValue
from tensorboard.compat.proto.config_pb2 import RunMetadata
from tensorboard.compat.proto.graph_pb2 import GraphDef
from tensorboard.compat.proto.node_def_pb2 import NodeDef
from tensorboard.compat.proto.step_stats_pb2 import (
AllocatorMemoryUsed,
DeviceStepStats,
NodeExecStats,
StepStats,
)
from tensorboard.compat.proto.tensor_shape_pb2 import TensorShapeProto
from tensorboard.compat.proto.versions_pb2 import VersionDef
from tensorboardX import SummaryWriter
except ImportError:
logger.error(
"TensorBoard and TensorboardX are required for visualize.", exc_info=True
)
return
graph = Network.load(model_path)
writer = SummaryWriter(log_path)
def process_name(name):
return name.replace(".", "/").encode(encoding="utf-8")
node_list = []
flops_list = []
params_list = []
for node in graph.all_oprs:
if hasattr(node, "output_idx"):
node_oup = node.outputs[node.output_idx]
else:
if len(node.outputs) != 1:
logger.warning(
"OpNode {} has more than one output and not has 'output_idx' attr.".format(
node
)
)
node_oup = node.outputs[0]
inp_list = [process_name(var.owner.name) for var in node.inputs]
attr = {
"_output_shapes": AttrValue(
list=AttrValue.ListValue(
shape=[
TensorShapeProto(
dim=[TensorShapeProto.Dim(size=d) for d in node_oup.shape]
)
]
)
),
}
if hasattr(node, "calc_flops"):
flops_num = node.calc_flops()
# add op flops attr
attr["flops"] = AttrValue(s=sizeof_fmt(flops_num).encode(encoding="utf-8"))
flops_list.append(
dict(
name=node.name,
class_name=node.type,
input_shapes=[i.shape for i in node.inputs],
output_shapes=[o.shape for o in node.outputs],
flops_num=flops_num,
flops_cum=0,
)
)
if node.type == "ImmutableTensor":
param_dim = np.prod(node_oup.shape)
# TODO: consider other quantize dtypes
param_bytes = 1 if is_quantize(node_oup.dtype) else 4
# add tensor size attr
attr["size"] = AttrValue(
s=sizeof_fmt(param_dim * param_bytes).encode(encoding="utf-8")
)
params_list.append(
dict(
name=node.name,
shape=node_oup.shape,
param_dim=param_dim,
bits=param_bytes * 8,
size=param_dim * param_bytes,
size_cum=0,
mean="{:.2g}".format(node.numpy().mean()),
std="{:.2g}".format(node.numpy().std()),
)
)
node_list.append(
NodeDef(
name=process_name(node.name), op=node.type, input=inp_list, attr=attr,
)
)
total_flops, total_params = 0, 0
if log_params:
total_params = print_params_stats(params_list, bar_length_max)
if log_flops:
total_flops = print_flops_stats(flops_list, bar_length_max)
graph_def = GraphDef(node=node_list, versions=VersionDef(producer=22))
device = "/device:CPU:0"
stepstats = RunMetadata(
step_stats=StepStats(dev_stats=[DeviceStepStats(device=device)])
)
writer._get_file_writer().add_graph((graph_def, stepstats))
return total_params, total_flops
def main():
parser = argparse.ArgumentParser(
description="load a megengine dumped model and export log file for tensorboard visualization.",
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
parser.add_argument("model_path", help="dumped model path.")
parser.add_argument("log_path", help="tensorboard log path.")
parser.add_argument(
"--bar_length_max",
type=int,
default=20,
help="size of bar indicating max flops or parameter size in net stats.",
)
parser.add_argument(
"--log_params",
action="store_true",
help="whether print and record params size.",
)
parser.add_argument(
"--log_flops", action="store_true", help="whether print and record op flops.",
)
visualize(**vars(parser.parse_args()))
if __name__ == "__main__":
main()
# -*- coding: utf-8 -*-
#! /usr/bin/env python3
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
......
......@@ -84,8 +84,7 @@ hook_modules = (
)
def net_stats(model, input_size, bar_length_max=20, log_params=True, log_flops=True):
def dict2table(list_of_dict, header):
def dict2table(list_of_dict, header):
table_data = [header]
for d in list_of_dict:
row = []
......@@ -97,7 +96,8 @@ def net_stats(model, input_size, bar_length_max=20, log_params=True, log_flops=T
table_data.append(row)
return table_data
def sizeof_fmt(num, suffix="B"):
def sizeof_fmt(num, suffix="B"):
for unit in ["", "Ki", "Mi", "Gi", "Ti", "Pi", "Ei", "Zi"]:
if abs(num) < 1024.0:
return "{:3.3f} {}{}".format(num, unit, suffix)
......@@ -105,15 +105,8 @@ def net_stats(model, input_size, bar_length_max=20, log_params=True, log_flops=T
sign_str = "-" if num < 0 else ""
return "{}{:.1f} {}{}".format(sign_str, num, "Yi", suffix)
def get_byteswidth(tensor):
if dtype.is_quantize(tensor.dtype):
return 1
# elif dtype.is_bfloat16(tensor.dtype):
# return 2
else:
return 4
def print_flops_stats(flops):
def print_flops_stats(flops, bar_length_max=20):
flops_list = [i["flops_num"] for i in flops]
max_flops_num = max(flops_list + [0])
# calc total flops and set flops_cum
......@@ -147,13 +140,12 @@ def net_stats(model, input_size, bar_length_max=20, log_params=True, log_flops=T
dict(name="total", flops=total_flops_str, output_shapes=total_var_size)
)
logger.info(
"flops stats: \n" + tabulate.tabulate(dict2table(flops, header=header))
)
logger.info("flops stats: \n" + tabulate.tabulate(dict2table(flops, header=header)))
return total_flops_num
def print_params_stats(params):
def print_params_stats(params, bar_length_max=20):
total_param_dims, total_param_size = 0, 0
for d in params:
total_param_dims += int(d["param_dim"])
......@@ -194,6 +186,32 @@ def net_stats(model, input_size, bar_length_max=20, log_params=True, log_flops=T
return total_param_size
def net_stats(
model: m.Module,
input_size: int,
bar_length_max: int = 20,
log_params: bool = True,
log_flops: bool = True,
):
r"""
Calculate and print ``model``'s statistics by adding hook and record Module's inputs outputs size.
:param model: model that need to get stats info.
:param input_size: size of input for running model and calculating stats.
:param bar_length_max: size of bar indicating max flops or parameter size in net stats.
:param log_params: whether print and record params size.
:param log_flops: whether print and record op flops.
"""
def get_byteswidth(tensor):
if dtype.is_quantize(tensor.dtype):
return 1
# elif dtype.is_bfloat16(tensor.dtype):
# return 2
else:
return 4
def net_stats_hook(module, input, output, name=""):
class_name = str(module.__class__).split(".")[-1].split("'")[0]
......@@ -273,8 +291,8 @@ def net_stats(model, input_size, bar_length_max=20, log_params=True, log_flops=T
total_flops, total_params = 0, 0
if log_params:
total_params = print_params_stats(params)
total_params = print_params_stats(params, bar_length_max)
if log_flops:
total_flops = print_flops_stats(flops)
total_flops = print_flops_stats(flops, bar_length_max)
return total_params, total_flops
......@@ -19,9 +19,9 @@ from ..core._imperative_rt import ComputingGraph
from ..core.tensor import megbrain_graph as G
from .comp_graph_tools import get_dep_vars, get_opr_type, get_oprs_seq
from .network_node import (
NetworkNode,
Host2DeviceCopy,
ImmutableTensor,
NetworkNode,
OpNode,
VarNode,
str_to_mge_class,
......@@ -606,9 +606,7 @@ class NodeFilterType(NodeFilter):
_node_type = None
def __init__(self, node_iter, node_type):
assert issubclass(node_type, NetworkNode), "bad opr type: {}".format(
node_type
)
assert issubclass(node_type, NetworkNode), "bad opr type: {}".format(node_type)
super().__init__(node_iter)
self._node_type = node_type
......
......@@ -10,6 +10,8 @@ import json
import sys
from typing import Callable
import numpy as np
from ..core import _imperative_rt as rt
from ..core._wrap import Device
from ..core.ops import builtin
......@@ -52,7 +54,7 @@ class VarNode(NetworkNode):
return self.var.dtype if self.var else None
def set_owner_opr(self, owner_opr):
self.owner_opr = owner_opr
self.owner = owner_opr
class OpNode(NetworkNode):
......@@ -223,6 +225,9 @@ class Elemwise(OpNode):
type = "Elemwise"
opdef = builtin.Elemwise
def calc_flops(self):
return np.prod(self.outputs[0].shape)
class Reduce(OpNode):
type = "Reduce"
......@@ -250,11 +255,21 @@ class MatrixMul(OpNode):
type = "MatrixMul"
opdef = builtin.MatrixMul
def calc_flops(self):
assert len(self.inputs[0].shape) == 2 and len(self.outputs[0].shape) == 2
mid_shape = self.inputs[0].shape[1]
return np.prod(self.outputs[0].shape) * mid_shape
class BatchedMatrixMul(OpNode):
type = "BatchedMatmul"
opdef = builtin.BatchedMatrixMul
def calc_flops(self):
assert len(self.inputs[0].shape) == 3 and len(self.outputs[0].shape) == 3
mid_shape = self.inputs[0].shape[2]
return np.prod(self.outputs[0].shape) * mid_shape
class Dot(OpNode):
type = "Dot"
......@@ -270,6 +285,18 @@ class ConvolutionForward(OpNode):
type = "Convolution"
opdef = builtin.Convolution
def calc_flops(self):
param_W_shape = self.inputs[1].shape
kh = param_W_shape[-2]
kw = param_W_shape[-1]
if len(param_W_shape) == 5:
num_input = param_W_shape[2]
else:
num_input = param_W_shape[1]
NCHW = np.prod(self.outputs[0].shape)
# N x Cout x H x W x (Cin x Kw x Kh)
return NCHW * (num_input * kw * kh)
class ConvolutionBackwardData(OpNode):
type = "ConvTranspose"
......@@ -316,6 +343,18 @@ class ConvBiasForward(OpNode):
obj.params["dtype"] = opr.outputs[0].dtype
return obj
def calc_flops(self):
param_W_shape = self.inputs[1].shape
kh = param_W_shape[-2]
kw = param_W_shape[-1]
if len(param_W_shape) == 5:
num_input = param_W_shape[2]
else:
num_input = param_W_shape[1]
NCHW = np.prod(self.outputs[0].shape)
# N x Cout x H x W x (Cin x Kw x Kh + bias)
return NCHW * (num_input * kw * kh + 1)
class BatchConvBiasForward(OpNode):
type = "BatchConvBias"
......@@ -331,6 +370,7 @@ class BatchConvBiasForward(OpNode):
class BatchNormForward(OpNode):
type = "BatchNorm"
opdef = builtin.BatchNorm
output_idx = -1
class ROIAlignForward(OpNode):
......@@ -622,6 +662,9 @@ class ElemwiseMultiType(OpNode):
obj.params["dtype"] = opr.outputs[0].dtype
return obj
def calc_flops(self):
return np.prod(self.outputs[0].shape)
class CvtColorForward(OpNode):
type = "CvtColor"
......
# -*- coding: utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import struct
import numpy as np
def load_tensor_binary(fobj):
"""
Load a tensor dumped by the :class:`BinaryOprIODump` plugin; the actual
tensor value dump is implemented by ``mgb::debug::dump_tensor``.
Multiple values can be compared by ``tools/compare_binary_iodump.py``.
:param fobj: file object, or a string that contains the file name.
:return: tuple ``(tensor_value, tensor_name)``.
"""
if isinstance(fobj, str):
with open(fobj, "rb") as fin:
return load_tensor_binary(fin)
DTYPE_LIST = {
0: np.float32,
1: np.uint8,
2: np.int8,
3: np.int16,
4: np.int32,
# 5: _mgb.intb1,
# 6: _mgb.intb2,
# 7: _mgb.intb4,
8: None,
9: np.float16,
# quantized dtype start from 100000
# see MEGDNN_PARAMETERIZED_DTYPE_ENUM_BASE in
# dnn/include/megdnn/dtype.h
100000: np.uint8,
100001: np.int32,
100002: np.int8,
}
header_fmt = struct.Struct("III")
name_len, dtype, max_ndim = header_fmt.unpack(fobj.read(header_fmt.size))
assert (
DTYPE_LIST[dtype] is not None
), "Cannot load this tensor: dtype Byte is unsupported."
shape = list(struct.unpack("I" * max_ndim, fobj.read(max_ndim * 4)))
while shape[-1] == 0:
shape.pop(-1)
name = fobj.read(name_len).decode("ascii")
return np.fromfile(fobj, dtype=DTYPE_LIST[dtype]).reshape(shape), name
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册