diff --git a/README.md b/README.md index d5ad0ab5049cc2d21c3ed49450feb067abc8c6b8..c604c6ad99c246464f6b4f71faa993604474f68f 100644 --- a/README.md +++ b/README.md @@ -315,10 +315,18 @@ Developers can compare multiple experiments by specifying and uploading the path ### Graph -**Graph** enables developers to visualize model structures by only one click. Moreover, **Graph** allows developers to explore model attributes, node information, node input and output. aiding them analyze model structures quickly and understand the direction of data flow easily. +**Graph** enables developers to visualize model structures by only one click. Moreover, **Graph** allows developers to explore model attributes, node information, node input and output. aiding them analyze model structures quickly and understand the direction of data flow easily. Additionally, Graph supports the visualization of dynamic and static model graph respectively. + +- dynamic graph + +

+ +

+ +- static graph

- +

diff --git a/README_CN.md b/README_CN.md index fdb3ad5d780dfaf4b49af85df68bbdb9f8361a13..04e3b92764ff83357c3f5be8515a524ca0820522 100644 --- a/README_CN.md +++ b/README_CN.md @@ -337,12 +337,18 @@ value: 3.1297709941864014 ### Graph -一键可视化模型的网络结构。可查看模型属性、节点信息、节点输入输出等,并支持节点搜索,辅助用户快速分析模型结构与了解数据流向。 +一键可视化模型的网络结构。可查看模型属性、节点信息、节点输入输出等,并支持节点搜索,辅助用户快速分析模型结构与了解数据流向,覆盖动态图与静态图两种格式。 +- 动态图

- +

+- 静态图 + +

+ +

### Histogram diff --git a/docs/README.md b/docs/README.md index e8527a859485a1c330fd7d6e66968c82c35c6348..71fcbb412cdd86aad463ef4db721248f5150c0a4 100644 --- a/docs/README.md +++ b/docs/README.md @@ -239,12 +239,21 @@ Developers can compare with multiple experiments by specifying and uploading the ### Graph -**Graph** enables developers to visualize model structures by only one click. Moreover, **Graph** allows Developers to explore model attributes, node information, node input and output. aiding them analyze model structure quickly and understand the direction of data flow easily. +**Graph** enables developers to visualize model structures by only one click. Moreover, **Graph** allows developers to explore model attributes, node information, node input and output. aiding them analyze model structures quickly and understand the direction of data flow easily. Additionally, Graph supports the visualization of dynamic and static model graph respectively. + +- dynamic graph + +

+ +

+ +- static graph

- +

+ ### Histogram **Histogram** displays how the trend of tensors (weight, bias, gradient, etc.) changes during the training process in the form of histogram. Developers can adjust the model structures accurately by having an in-depth understanding of the effect of each layer. diff --git a/docs/README_CN.md b/docs/README_CN.md index 38a659e25078e097bccd23280b2fa1d609ff7c63..9a3741338bb22c9a3e9b38260ff3d09e03b6242d 100644 --- a/docs/README_CN.md +++ b/docs/README_CN.md @@ -250,9 +250,17 @@ app.run(logdir="./log") ### Graph -一键可视化模型的网络结构。可查看模型属性、节点信息、节点输入输出等,并支持节点搜索,辅助用户快速分析模型结构与了解数据流向。 +一键可视化模型的网络结构。可查看模型属性、节点信息、节点输入输出等,并支持节点搜索,辅助用户快速分析模型结构与了解数据流向,覆盖动态图与静态图两种格式。 +- 动态图 + +

+ +

+ +- 静态图 +

- +

### Histogram diff --git a/docs/components/README.md b/docs/components/README.md index 9ed73fc395d0b0847f61975d2b874ab0866ff2ca..29c509e5321be553b5709d38e82aaccf196e334f 100644 --- a/docs/components/README.md +++ b/docs/components/README.md @@ -454,44 +454,114 @@ Then, open the browser and enter the address`http://127.0.0.1:8080` to view: Graph can visualize the network structure of the model by one click. It enables developers to view the model attributes, node information, searching node and so on. These functions help developers analyze model structures and understand the directions of data flow quickly. +### Record Interface + +The interface of the Graph is shown as follows: + +```python +add_graph(model, input_spec, verbose=False): +``` + +The interface parameters are described as follows: + +| parameter | format | meaning | +| -------------- | --------------------- | ------------------------------------------- | +| model | paddle.nn.Layer | Dynamic model of paddle | +| input_spec | list\[paddle.static.InputSpec\|Tensor\] | Describes the input of the saved model's [forward arguments](https://www.paddlepaddle.org.cn/documentation/docs/zh/api/paddle/static/InputSpec_cn.html) | +| verbose | bool | Whether to print graph statistic information in console. | + +**Note** + +If you want to use add_graph interface, paddle package is required. Please refer to website of [PaddlePaddle](https://www.paddlepaddle.org.cn/install/quick?docurl=/documentation/docs/en/install/pip/linux-pip_en.html)。 + ### Demo +The following shows an example of how to use Graph component, and script can be found in [Graph Demo](https://github.com/PaddlePaddle/VisualDL/blob/develop/demo/components/graph_test.py) There are two methods to launch this component: -- By the front end: +```python +import paddle +import paddle.nn as nn +import paddle.nn.functional as F - - If developers only need to use Graph, developers can launch VisualDL (Graph) by executing `visualdl`on the command line. - - If developers need to use Graph and other functions at the same time, they need to specify the log file path (using `./log` as an example): +from visualdl import LogWriter - ```shell - visualdl --logdir ./log --port 8080 - ``` +class MyNet(nn.Layer): + def __init__(self): + super(MyNet, self).__init__() + self.conv1 = nn.Conv2D( + in_channels=1, out_channels=20, kernel_size=5, stride=1, padding=2) + self.max_pool1 = nn.MaxPool2D(kernel_size=2, stride=2) + self.conv2 = nn.Conv2D( + in_channels=20, + out_channels=20, + kernel_size=5, + stride=1, + padding=2) + self.max_pool2 = nn.MaxPool2D(kernel_size=2, stride=2) + self.fc = nn.Linear(in_features=980, out_features=10) + + def forward(self, inputs): + x = self.conv1(inputs) + x = F.relu(x) + x = self.max_pool1(x) + x = self.conv2(x) + x = F.relu(x) + x = self.max_pool2(x) + x = paddle.reshape(x, [x.shape[0], -1]) + x = self.fc(x) + return x + + +net = MyNet() +with LogWriter(logdir="./log/graph_test/") as writer: + writer.add_graph( + model=net, + input_spec=[paddle.static.InputSpec([-1, 1, 28, 28], 'float32')], + verbose=True) +``` -- By the backend: - - Add the parameter `--model` and specify the **model file** path (not the folder path) to launch the panel: - ```shell - visualdl --model ./log/model --port 8080 - ``` +After running the above program, developers can launch the panel by: +```shell +visualdl --logdir ./log/graph_test/ --port 8080 +``` -After the launch, developers can view the network structure: +Then, open the browser and enter the address`http://127.0.0.1:8080` to view:

- +

+**Note** + +We provide option --model to specify model structure file in previous versions, and this option is still supported now. You can specify model exported by `add_graph` interface ("vdlgraph" contained in filename), which will be shown in dynamic graph page, and we use string "manual_input_model" in the page to denote the model you specify by this option. Other supported file formats are presented in static graph page. + +For example +```shell +visualdl --model ./log/model.pdmodel --port 8080 +``` +which will be shown in static graph page. And +```shell +visualdl --model ./log/vdlgraph.1655783158.log --port 8080 +``` +shown in dynamic graph page. + ### Functional Instructions -- Upload the model file by one-click - - Supported model:PaddlePaddle、ONNX、Keras、Core ML、Caffe、Caffe2、Darknet、MXNet、ncnn、TensorFlow Lite - - Experimental supported model:TorchScript、PyTorch、Torch、 ArmNN、BigDL、Chainer、CNTK、Deeplearning4j、MediaPipe、ML.NET、MNN、OpenVINO、Scikit-learn、Tengine、TensorFlow.js、TensorFlow +Graph page is divided into dynamic and static version currently. Dynamic version is used to visualize dynamic model of paddle, which is exported by add_graph interface. +The other is used to visualize static model of paddle, which is exported by [paddle.jit.save](https://www.paddlepaddle.org.cn/documentation/docs/en/api/paddle/jit/save_en.html) interface and other supported formats. +

- +

+**Common functions** + + - Developers are allowed to drag the model up and down,left and right,zoom in and zoom out.

@@ -534,6 +604,44 @@ After the launch, developers can view the network structure:

+**Specific feature in dynamic version** + +- Fold and unfold one node +

+ +

+

+ +

+ +- Fold and unfold all nodes +

+ +

+

+ +

+ +- Link api specification of paddle + + If you use paddle.nn components to construct your network model, you can use alt+click mouse to direct to corresponding api specification. +

+ +

+

+ +

+ +**Specific feature in static version** + +- Upload the model file by one-click + - Supported model:PaddlePaddle、ONNX、Keras、Core ML、Caffe、Caffe2、Darknet、MXNet、ncnn、TensorFlow Lite + - Experimental supported model:TorchScript、PyTorch、Torch、 ArmNN、BigDL、Chainer、CNTK、Deeplearning4j、MediaPipe、ML.NET、MNN、OpenVINO、Scikit-learn、Tengine、TensorFlow.js、TensorFlow + +

+ +

+ ## Histogram--Distribution of Tensors ### Introduction diff --git a/docs/components/README_CN.md b/docs/components/README_CN.md index ac1f095734a82d3fd98d6a6c0f4b6f91f2eba9a5..3093d67e19e90eb1ebcb7b0c03ef716ff3dde0e3 100644 --- a/docs/components/README_CN.md +++ b/docs/components/README_CN.md @@ -511,49 +511,119 @@ visualdl --logdir ./log --port 8080 ## Graph--网络结构组件 + ### 介绍 -Graph组件一键可视化模型的网络结构。用于查看模型属性、节点信息、节点输入输出等,并进行节点搜索,协助开发者们快速分析模型结构与了解数据流向。 +Graph组件一键可视化模型的网络结构。用于查看模型属性、节点信息、节点输入输出等,并进行节点搜索,协助开发者们快速分析模型结构与了解数据流向,覆盖动态图与静态图两种格式。 + +### 记录接口 + +Graph组件的记录接口如下: + +```python +add_graph(model, input_spec, verbose=False): +``` + +接口参数说明如下: + +| 参数 | 格式 | 含义 | +| -------------- | --------------------- | ------------------------------------------- | +| model | paddle.nn.Layer | Paddle的动态图模型 | +| input_spec | list\[paddle.static.InputSpec\|Tensor\] | 用于描述模型[输入的参数](https://www.paddlepaddle.org.cn/documentation/docs/zh/api/paddle/static/InputSpec_cn.html) | +| verbose | bool | 是否在终端打印模型的节点统计信息 | + +**注意** + +使用add_graph接口需要安装飞桨paddlepaddle, 安装步骤请参考[飞桨官方网站](https://www.paddlepaddle.org.cn/install/quick?docurl=/documentation/docs/zh/install/pip/linux-pip.html)。 ### Demo -共有两种启动方式: +下面展示了使用 Graph 组件记录飞桨动态图模型的示例,代码见[Graph组件](https://github.com/PaddlePaddle/VisualDL/blob/develop/demo/components/graph_test.py) -- 前端启动Graph: +```python +import paddle +import paddle.nn as nn +import paddle.nn.functional as F - - 如只需使用Graph,无需添加任何参数,在命令行执行`visualdl`后即可启动。 - - 如果同时需使用其他功能,在命令行指定日志文件路径(以`./log`为例),即可启动: +from visualdl import LogWriter - ```shell - visualdl --logdir ./log --port 8080 - ``` +class MyNet(nn.Layer): + def __init__(self): + super(MyNet, self).__init__() + self.conv1 = nn.Conv2D( + in_channels=1, out_channels=20, kernel_size=5, stride=1, padding=2) + self.max_pool1 = nn.MaxPool2D(kernel_size=2, stride=2) + self.conv2 = nn.Conv2D( + in_channels=20, + out_channels=20, + kernel_size=5, + stride=1, + padding=2) + self.max_pool2 = nn.MaxPool2D(kernel_size=2, stride=2) + self.fc = nn.Linear(in_features=980, out_features=10) + + def forward(self, inputs): + x = self.conv1(inputs) + x = F.relu(x) + x = self.max_pool1(x) + x = self.conv2(x) + x = F.relu(x) + x = self.max_pool2(x) + x = paddle.reshape(x, [x.shape[0], -1]) + x = self.fc(x) + return x + + +net = MyNet() +with LogWriter(logdir="./log/graph_test/") as writer: + writer.add_graph( + model=net, + input_spec=[paddle.static.InputSpec([-1, 1, 28, 28], 'float32')], + verbose=True) +``` -- 后端启动Graph: +运行上述程序后,在命令行执行 - - 在命令行加入参数`--model`并指定**模型文件**路径(非文件夹路径),即可启动: +```shell +visualdl --logdir ./log/graph_test/ --port 8080 +``` - ```shell - visualdl --model ./log/model --port 8080 - ``` -*Graph目前只支持可视化网络结构格式的模型文件(如__model__(注意此处为两个下划线'_')) +接着在浏览器打开`http://127.0.0.1:8080`,即可查看Graph -启动后即可查看网络结构可视化: +启动后即可查看飞桨动态图网络结构可视化:

- +

-### 功能操作说明 +**注意** + +VisualDL之前的版本支持通过--model参数直接指定模型结构文件,现在仍然保持这一选项, +通过`add_graph`接口导出的动态图模型文件(文件名包含"vdlgraph"), 在动态图页面展示, +并在页面中以'manual_input_model'来表示通过该参数指定的模型。其余所支持的文件格式在静态图页面中展示。 + +例如 +```shell +visualdl --model ./log/model.pdmodel --port 8080 +``` +将展示在静态图页面。 + +```shell +visualdl --model ./log/vdlgraph.1655783158.log --port 8080 +``` +将展示在动态图页面。 -- 一键上传模型 - - 支持模型格式:PaddlePaddle、ONNX、Keras、Core ML、Caffe、Caffe2、Darknet、MXNet、ncnn、TensorFlow Lite - - 实验性支持模型格式:TorchScript、PyTorch、Torch、 ArmNN、BigDL、Chainer、CNTK、Deeplearning4j、MediaPipe、ML.NET、MNN、OpenVINO、Scikit-learn、Tengine、TensorFlow.js、TensorFlow +### 功能操作说明 + +当前Graph页面分为动态图和静态图两个页面。其中动态图页面用来展示通过add_graph接口导出的飞桨动态图模型结构,静态图页面用来展示飞桨静态图模型结构(通过飞桨的[paddle.jit.save](https://www.paddlepaddle.org.cn/documentation/docs/zh/api/paddle/jit/save_cn.html)导出的后缀名为pdmodel的文件)及其它可支持框架的模型。

- +

+**通用功能**: + - 支持上下左右任意拖拽模型、放大和缩小模型

@@ -600,6 +670,44 @@ Graph组件一键可视化模型的网络结构。用于查看模型属性、节

+**动态图页面特有功能**: + +- 展开和折叠指定节点 +

+ +

+

+ +

+ +- 一键全展开和全折叠 +

+ +

+

+ +

+ +- 飞桨API链接功能 + + 对于使用paddle.nn中的组件搭建的节点,可以使用alt+鼠标点击的方式跳转到官网的API说明文档。 +

+ +

+

+ +

+ +**静态图页面特有功能**: + +- 一键上传模型 + - 支持模型格式:PaddlePaddle、ONNX、Keras、Core ML、Caffe、Caffe2、Darknet、MXNet、ncnn、TensorFlow Lite + - 实验性支持模型格式:TorchScript、PyTorch、Torch、 ArmNN、BigDL、Chainer、CNTK、Deeplearning4j、MediaPipe、ML.NET、MNN、OpenVINO、Scikit-learn、Tengine、TensorFlow.js、TensorFlow + +

+ +

+ ## Histogram--直方图组件 ### 介绍 diff --git a/visualdl/reader/graph_reader.py b/visualdl/reader/graph_reader.py index 6a59b3285264f86be57baaef9c7b025544e60817..90621d0d2a12338f1c526f911ce6799e42efe931 100644 --- a/visualdl/reader/graph_reader.py +++ b/visualdl/reader/graph_reader.py @@ -14,7 +14,6 @@ # ======================================================================= import json import os -import tempfile from visualdl.component.graph import analyse_model from visualdl.component.graph import Model @@ -56,7 +55,6 @@ class GraphReader(object): self.runs2displayname = {} self.graph_buffer = {} self.walks_buffer = {} - self.tempfile = None @property def logdir(self): @@ -64,11 +62,8 @@ class GraphReader(object): def get_all_walk(self): flush_walks = {} - if 'manual_input_model' in self.walks: - flush_walks['manual_input_model'] = [ - self.walks['manual_input_model'] - ] for dir in self.dir: + dir = os.path.realpath(dir) for root, dirs, files in bfile.walk(dir): flush_walks.update({root: files}) return flush_walks @@ -99,7 +94,10 @@ class GraphReader(object): def runs(self, update=True): self.graphs(update=update) - return list(self.walks.keys()) + graph_runs = list(self.walks.keys( + )) if 'manual_input_model' not in self.graph_buffer else list( + self.walks.keys()) + ['manual_input_model'] + return sorted(graph_runs) def get_graph(self, run, @@ -108,6 +106,12 @@ class GraphReader(object): keep_state=False, expand_all=False, refresh=False): + if run == 'manual_input_model' and run in self.graph_buffer: + graph_model = self.graph_buffer[run] + if nodeid is not None: + graph_model.adjust_visible(nodeid, expand, keep_state) + return graph_model.make_graph( + refresh=refresh, expand_all=expand_all) if run in self.walks: if run in self.walks_buffer: if self.walks[run] == self.walks_buffer[run]: @@ -118,7 +122,10 @@ class GraphReader(object): refresh=refresh, expand_all=expand_all) data = bfile.BFile(bfile.join(run, self.walks[run]), 'rb').read() - graph_model = Model(json.loads(data.decode())) + if 'pdmodel' in self.walks[run]: + graph_model = Model(analyse_model(data)) + else: + graph_model = Model(json.loads(data.decode())) self.graph_buffer[run] = graph_model self.walks_buffer[run] = self.walks[run] if nodeid is not None: @@ -127,6 +134,11 @@ class GraphReader(object): refresh=refresh, expand_all=expand_all) def search_graph_node(self, run, nodeid, keep_state=False, is_node=True): + if run == 'manual_input_model' and run in self.graph_buffer: + graph_model = self.graph_buffer[run] + graph_model.adjust_search_node_visible( + nodeid, keep_state=keep_state, is_node=is_node) + return graph_model.make_graph(refresh=False, expand_all=False) if run in self.walks: if run in self.walks_buffer: if self.walks[run] == self.walks_buffer[run]: @@ -137,7 +149,10 @@ class GraphReader(object): refresh=False, expand_all=False) data = bfile.BFile(bfile.join(run, self.walks[run]), 'rb').read() - graph_model = Model(json.loads(data.decode())) + if 'pdmodel' in self.walks[run]: + graph_model = Model(analyse_model(data)) + else: + graph_model = Model(json.loads(data.decode())) self.graph_buffer[run] = graph_model self.walks_buffer[run] = self.walks[run] graph_model.adjust_search_node_visible( @@ -145,6 +160,9 @@ class GraphReader(object): return graph_model.make_graph(refresh=False, expand_all=False) def get_all_nodes(self, run): + if run == 'manual_input_model' and run in self.graph_buffer: + graph_model = self.graph_buffer[run] + return graph_model.get_all_leaf_nodes() if run in self.walks: if run in self.walks_buffer: if self.walks[run] == self.walks_buffer[run]: @@ -152,7 +170,10 @@ class GraphReader(object): return graph_model.get_all_leaf_nodes() data = bfile.BFile(bfile.join(run, self.walks[run]), 'rb').read() - graph_model = Model(json.loads(data.decode())) + if 'pdmodel' in self.walks[run]: + graph_model = Model(analyse_model(data)) + else: + graph_model = Model(json.loads(data.decode())) self.graph_buffer[run] = graph_model self.walks_buffer[run] = self.walks[run] return graph_model.get_all_leaf_nodes() @@ -167,10 +188,6 @@ class GraphReader(object): def __exit__(self, exc_type, exc_val, exc_tb): pass - def __del__(self): - if self.tempfile: - os.unlink(self.tempfile.name) - def set_input_graph(self, content, file_type='pdmodel'): if isinstance(content, str): if not is_VDLGraph_file(content): @@ -184,23 +201,10 @@ class GraphReader(object): if file_type == 'pdmodel': data = analyse_model(content) self.graph_buffer['manual_input_model'] = Model(data) - temp = tempfile.NamedTemporaryFile(suffix='.pdmodel', delete=False) - temp.write(json.dumps(data).encode()) - temp.close() elif file_type == 'vdlgraph': self.graph_buffer['manual_input_model'] = Model( json.loads(content.decode())) - temp = tempfile.NamedTemporaryFile( - suffix='.log', prefix='vdlgraph.', delete=False) - temp.write(content) - temp.close() else: return - - if self.tempfile: - os.unlink(self.tempfile.name) - self.tempfile = temp - self.walks['manual_input_model'] = temp.name - self.walks_buffer['manual_input_model'] = temp.name diff --git a/visualdl/server/api.py b/visualdl/server/api.py index c03d68b2b281f5f4289e226912cb8d12ed315432..94d99d6accce262e3abbf8ab7eab7dafd05845ce 100644 --- a/visualdl/server/api.py +++ b/visualdl/server/api.py @@ -73,8 +73,9 @@ class Api(object): self._graph_reader = GraphReader(logdir) self._graph_reader.set_displayname(self._reader) if model: + if 'vdlgraph' in model: + self._graph_reader.set_input_graph(model) self._reader.model = model - self._graph_reader.set_input_graph(model) self.model_name = os.path.basename(model) else: self.model_name = '' @@ -104,8 +105,7 @@ class Api(object): def graph_runs(self): client_ip = request.remote_addr graph_reader = self.graph_reader_client_manager.get_data(client_ip) - return self._get_with_reader('data/graph_runs', lib.get_graph_runs, - graph_reader) + return lib.get_graph_runs(graph_reader) @result() def tags(self): @@ -280,6 +280,13 @@ class Api(object): key = os.path.join('data/plugin/roc_curves/steps', run) return self._get_with_retry(key, lib.get_roc_curve_step, run) + @result('application/octet-stream', lambda s: { + "Content-Disposition": 'attachment; filename="%s"' % s.model_name + } if len(s.model_name) else None) + def graph_static_graph(self): + key = os.path.join('data/plugin/graphs/static_graph') + return self._get_with_retry(key, lib.get_static_graph) + @result() def graph_graph(self, run, expand_all, refresh): client_ip = request.remote_addr @@ -395,6 +402,7 @@ def create_api_call(logdir, model, cache_timeout): 'embedding/metadata': (api.embedding_metadata, ['name']), 'histogram/list': (api.histogram_list, ['run', 'tag']), 'graph/graph': (api.graph_graph, ['run', 'expand_all', 'refresh']), + 'graph/static_graph': (api.graph_static_graph, []), 'graph/upload': (api.graph_upload, []), 'graph/search': (api.graph_search, ['run', 'nodeid', 'keep_state', 'is_node']), diff --git a/visualdl/server/lib.py b/visualdl/server/lib.py index fe6e6c887d874b6acb5fb704fe9812aa87abe431..b01f8a4cc511b450ae0d415082dcaf6953f83065 100644 --- a/visualdl/server/lib.py +++ b/visualdl/server/lib.py @@ -563,6 +563,14 @@ def get_histogram(log_reader, run, tag): return results +def get_static_graph(log_reader): + result = b"" + if log_reader.model: + with bfile.BFile(log_reader.model, 'rb') as bfp: + result = bfp.read_file(log_reader.model) + return result + + def get_graph(graph_reader, run, nodeid=None, diff --git a/visualdl/version.py b/visualdl/version.py index af9c267e9e25f3e5a5b355aa332d90235d418f2a..50f56ec6caa9398de7c9640f96175890aeb4ade3 100644 --- a/visualdl/version.py +++ b/visualdl/version.py @@ -13,4 +13,4 @@ # limitations under the License. # ======================================================================= -vdl_version = '2.2.3' +vdl_version = '2.3.0'