From edea528b40cf3c21dc9cb86d0492cc29651c1d9c Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Tue, 23 Mar 2021 16:57:26 +0800 Subject: [PATCH] feat(mge/tools): set network_visualize's log_path as optional flag GitOrigin-RevId: a74bdc08ba86d431a1a0cc9d1fc665d897ecd16f --- .../megengine/tools/network_visualize.py | 125 +++++++++++------- 1 file changed, 77 insertions(+), 48 deletions(-) diff --git a/imperative/python/megengine/tools/network_visualize.py b/imperative/python/megengine/tools/network_visualize.py index 210343f81..d155755d4 100755 --- a/imperative/python/megengine/tools/network_visualize.py +++ b/imperative/python/megengine/tools/network_visualize.py @@ -40,30 +40,31 @@ def visualize( :param log_params: whether print and record params size. :param log_flops: whether print and record op flops. """ - try: - from tensorboard.compat.proto.attr_value_pb2 import AttrValue - from tensorboard.compat.proto.config_pb2 import RunMetadata - from tensorboard.compat.proto.graph_pb2 import GraphDef - from tensorboard.compat.proto.node_def_pb2 import NodeDef - from tensorboard.compat.proto.step_stats_pb2 import ( - AllocatorMemoryUsed, - DeviceStepStats, - NodeExecStats, - StepStats, - ) - from tensorboard.compat.proto.tensor_shape_pb2 import TensorShapeProto - from tensorboard.compat.proto.versions_pb2 import VersionDef - from tensorboardX import SummaryWriter - except ImportError: - logger.error( - "TensorBoard and TensorboardX are required for visualize.", exc_info=True - ) - return + if log_path: + try: + from tensorboard.compat.proto.attr_value_pb2 import AttrValue + from tensorboard.compat.proto.config_pb2 import RunMetadata + from tensorboard.compat.proto.graph_pb2 import GraphDef + from tensorboard.compat.proto.node_def_pb2 import NodeDef + from tensorboard.compat.proto.step_stats_pb2 import ( + AllocatorMemoryUsed, + DeviceStepStats, + NodeExecStats, + StepStats, + ) + from tensorboard.compat.proto.tensor_shape_pb2 import TensorShapeProto + from tensorboard.compat.proto.versions_pb2 import VersionDef + from tensorboardX import SummaryWriter + except ImportError: + logger.error( + "TensorBoard and TensorboardX are required for visualize.", + exc_info=True, + ) + return # FIXME: remove this after resolving "span dist too large" warning old_level = set_mgb_log_level(logging.ERROR) graph = Network.load(model_path) - writer = SummaryWriter(log_path) def process_name(name): return name.replace(".", "/").encode(encoding="utf-8") @@ -84,21 +85,27 @@ def visualize( node_oup = node.outputs[0] inp_list = [process_name(var.owner.name) for var in node.inputs] - attr = { - "_output_shapes": AttrValue( - list=AttrValue.ListValue( - shape=[ - TensorShapeProto( - dim=[TensorShapeProto.Dim(size=d) for d in node_oup.shape] - ) - ] - ) - ), - } + if log_path: + attr = { + "_output_shapes": AttrValue( + list=AttrValue.ListValue( + shape=[ + TensorShapeProto( + dim=[ + TensorShapeProto.Dim(size=d) for d in node_oup.shape + ] + ) + ] + ) + ), + } if hasattr(node, "calc_flops"): flops_num = node.calc_flops() # add op flops attr - attr["flops"] = AttrValue(s=sizeof_fmt(flops_num).encode(encoding="utf-8")) + if log_path: + attr["flops"] = AttrValue( + s=sizeof_fmt(flops_num).encode(encoding="utf-8") + ) flops_list.append( dict( name=node.name, @@ -114,9 +121,10 @@ def visualize( # TODO: consider other quantize dtypes param_bytes = 1 if is_quantize(node_oup.dtype) else 4 # add tensor size attr - attr["size"] = AttrValue( - s=sizeof_fmt(param_dim * param_bytes).encode(encoding="utf-8") - ) + if log_path: + attr["size"] = AttrValue( + s=sizeof_fmt(param_dim * param_bytes).encode(encoding="utf-8") + ) params_list.append( dict( name=node.name, @@ -132,25 +140,33 @@ def visualize( # FIXME(MGE-2165): nodes outside network module may lead to unknown display bug if not len(node.name.split(".")) > 2 and not node in graph.input_vars: continue - node_list.append( - NodeDef( - name=process_name(node.name), op=node.type, input=inp_list, attr=attr, + if log_path: + node_list.append( + NodeDef( + name=process_name(node.name), + op=node.type, + input=inp_list, + attr=attr, + ) ) - ) - total_flops, total_params = 0, 0 + total_flops, total_params = None, None if log_params: total_params = print_params_stats(params_list, bar_length_max) if log_flops: total_flops = print_flops_stats(flops_list, bar_length_max) - graph_def = GraphDef(node=node_list, versions=VersionDef(producer=22)) + if log_path: + graph_def = GraphDef(node=node_list, versions=VersionDef(producer=22)) - device = "/device:CPU:0" - stepstats = RunMetadata( - step_stats=StepStats(dev_stats=[DeviceStepStats(device=device)]) - ) - writer._get_file_writer().add_graph((graph_def, stepstats)) + device = "/device:CPU:0" + stepstats = RunMetadata( + step_stats=StepStats(dev_stats=[DeviceStepStats(device=device)]) + ) + writer = SummaryWriter(log_path) + writer._get_file_writer().add_graph((graph_def, stepstats)) + + # summary # FIXME: remove this after resolving "span dist too large" warning _imperative_rt_logger.set_log_level(old_level) @@ -164,7 +180,7 @@ def main(): formatter_class=argparse.ArgumentDefaultsHelpFormatter, ) parser.add_argument("model_path", help="dumped model path.") - parser.add_argument("log_path", help="tensorboard log path.") + parser.add_argument("--log_path", help="tensorboard log path.") parser.add_argument( "--bar_length_max", type=int, @@ -179,7 +195,20 @@ def main(): parser.add_argument( "--log_flops", action="store_true", help="whether print and record op flops.", ) - visualize(**vars(parser.parse_args())) + parser.add_argument( + "--all", + action="store_true", + help="whether print all stats. Tensorboard logs will be placed in './log' if not specified.", + ) + args = parser.parse_args() + if args.all: + args.log_params = True + args.log_flops = True + if not args.log_path: + args.log_path = "./log" + kwargs = vars(args) + kwargs.pop("all") + visualize(**kwargs) if __name__ == "__main__": -- GitLab