diff --git a/imperative/python/megengine/utils/tensorboard.py b/imperative/python/megengine/utils/tensorboard.py new file mode 100644 index 0000000000000000000000000000000000000000..c99ee041c42b7163fff91e7ce22dd945a63a6af8 --- /dev/null +++ b/imperative/python/megengine/utils/tensorboard.py @@ -0,0 +1,236 @@ +#!/usr/bin/env python +# -*-coding=utf-8-*- + +from megengine.logger import get_logger + +logger = get_logger(__name__) + +try: + from tensorboardX import SummaryWriter + from tensorboardX.proto.attr_value_pb2 import AttrValue + from tensorboardX.proto.graph_pb2 import GraphDef + from tensorboardX.proto.node_def_pb2 import NodeDef + from tensorboardX.proto.plugin_text_pb2 import TextPluginData + from tensorboardX.proto.step_stats_pb2 import ( + DeviceStepStats, + RunMetadata, + StepStats, + ) + from tensorboardX.proto.summary_pb2 import Summary, SummaryMetadata + from tensorboardX.proto.tensor_pb2 import TensorProto + from tensorboardX.proto.tensor_shape_pb2 import TensorShapeProto + from tensorboardX.proto.versions_pb2 import VersionDef +except ImportError: + logger.error( + "TensorBoard and TensorboardX are required for visualize.", exc_info=True, + ) + + +def tensor_shape_proto(shape): + """Creates an object matching + https://github.com/tensorflow/tensorboard/blob/master/tensorboard/compat/proto/tensor_shape.proto + """ + return TensorShapeProto(dim=[TensorShapeProto.Dim(size=d) for d in shape]) + + +def attr_value_proto(shape, dtype, attr): + """Creates a dict of objects matching + https://github.com/tensorflow/tensorboard/blob/master/tensorboard/compat/proto/attr_value.proto + specifically designed for a NodeDef. The values have been + reverse engineered from standard TensorBoard logged data. + """ + attr_proto = {} + if shape is not None: + shapeproto = tensor_shape_proto(shape) + attr_proto["_output_shapes"] = AttrValue( + list=AttrValue.ListValue(shape=[shapeproto]) + ) + if dtype is not None: + attr_proto["dtype"] = AttrValue(s=dtype.encode(encoding="utf-8")) + if attr is not None: + for key in attr.keys(): + attr_proto[key] = AttrValue(s=attr[key].encode(encoding="utf-8")) + + return attr_proto + + +def node_proto( + name, op="UnSpecified", input=None, outputshape=None, dtype=None, attributes={} +): + """Creates an object matching + https://github.com/tensorflow/tensorboard/blob/master/tensorboard/compat/proto/node_def.proto + """ + if input is None: + input = [] + if not isinstance(input, list): + input = [input] + return NodeDef( + name=name.encode(encoding="utf_8"), + op=op, + input=input, + attr=attr_value_proto(outputshape, dtype, attributes), + ) + + +def node( + name, op="UnSpecified", input=None, outputshape=None, dtype=None, attributes={} +): + return node_proto(name, op, input, outputshape, dtype, attributes) + + +def graph(node_list): + graph_def = GraphDef(node=node_list, versions=VersionDef(producer=22)) + stepstats = RunMetadata( + step_stats=StepStats(dev_stats=[DeviceStepStats(device="/device:CPU:0")]) + ) + return graph_def, stepstats + + +def text(tag, text): + plugin_data = SummaryMetadata.PluginData( + plugin_name="text", content=TextPluginData(version=0).SerializeToString() + ) + smd = SummaryMetadata(plugin_data=plugin_data) + string_val = [] + for item in text: + string_val.append(item.encode(encoding="utf_8")) + tensor = TensorProto( + dtype="DT_STRING", + string_val=string_val, + tensor_shape=TensorShapeProto(dim=[TensorShapeProto.Dim(size=len(text))]), + ) + + return Summary(value=[Summary.Value(tag=tag, metadata=smd, tensor=tensor)]) + + +class NodeRaw: + def __init__(self, name, op, input, outputshape, dtype, attributes): + self.name = name + self.op = op + self.input = input + self.outputshape = outputshape + self.dtype = dtype + self.attributes = attributes + + +class SummaryWriterExtend(SummaryWriter): + def __init__( + self, + logdir=None, + comment="", + purge_step=None, + max_queue=10, + flush_secs=120, + filename_suffix="", + write_to_disk=True, + log_dir=None, + **kwargs + ): + self.node_raw_dict = {} + super().__init__( + logdir, + comment, + purge_step, + max_queue, + flush_secs, + filename_suffix, + write_to_disk, + log_dir, + **kwargs, + ) + + def add_text(self, tag, text_string_list, global_step=None, walltime=None): + """Add text data to summary. + Args: + tag (string): Data identifier + text_string_list (string list): String to save + global_step (int): Global step value to record + walltime (float): Optional override default walltime (time.time()) + seconds after epoch of event + Examples:: + # text can be divided into three levels by tag and global_step + from writer import SummaryWriterExtend + writer = SummaryWriterExtend() + + writer.add_text('level1.0/level2.0', ['text0'], 0) + writer.add_text('level1.0/level2.0', ['text1'], 1) + writer.add_text('level1.0/level2.1', ['text2']) + writer.add_text('level1.1', ['text3']) + """ + + self._get_file_writer().add_summary( + text(tag, text_string_list), global_step, walltime + ) + + def add_node_raw( + self, + name, + op="UnSpecified", + input=[], + outputshape=None, + dtype=None, + attributes={}, + ): + """Add node raw datas that can help build graph.After add all nodes, call + add_graph_by_node_raw_list() to build graph and add graph data to summary. + Args: + name (string): opr name. + op (string): opr class name. + input (string list): input opr name. + outputshape (list): output shape. + dtype (string): output data dtype. + attributes (dict): attributes info. + Examples:: + from writer import SummaryWriterExtend + writer = SummaryWriterExtend() + + writer.add_node_raw('node1', 'opr1', outputshape=[6, 2, 3], dtype="float32", attributes={ + "peak_size": "12MB", "mmory_alloc": "2MB, percent: 16.7%"}) + writer.add_node_raw('node2', 'opr2', outputshape=[6, 2, 3], dtype="float32", input="node1", attributes={ + "peak_size": "12MB", "mmory_alloc": "2MB, percent: 16.7%"}) + writer.add_graph_by_node_raw_list() + + """ + # self.node_raw_list.append( + # node(name, op, input, outputshape, dtype, attributes)) + self.node_raw_dict[name] = NodeRaw( + name, op, input, outputshape, dtype, dict(attributes) + ) + + def add_node_raw_name_suffix(self, name, suffix): + """Give node name suffix in order to finding this node by 'search nodes' + Args: + name (string): opr name. + suffix (string): nam suffix. + """ + old_name = self.node_raw_dict[name].name + new_name = old_name + suffix + # self.node_raw_dict[new_name] = self.node_raw_dict.pop(name) + self.node_raw_dict[name].name = new_name + for node_name, node in self.node_raw_dict.items(): + node.input = [new_name if x == old_name else x for x in node.input] + + def add_node_raw_attributes(self, name, attributes): + """ + Args: + name (string): opr name. + attributes (dict): attributes info that need to be added. + """ + for key, value in attributes.items(): + self.node_raw_dict[name].attributes[key] = value + + def add_graph_by_node_raw_list(self): + """Build graph and add graph data to summary.""" + node_raw_list = [] + for key, value in self.node_raw_dict.items(): + node_raw_list.append( + node( + value.name, + value.op, + value.input, + value.outputshape, + value.dtype, + value.attributes, + ) + ) + self._get_file_writer().add_graph(graph(node_raw_list))