未验证 提交 845a74b8 编写于 作者: W wuzewu 提交者: GitHub

fix #512 (#513)

* add paddle framework.proto and support show paddle-fluid model in web page

* modify the graph drawing policy, only draw operator node in default

* Resolution of code style issues raised by Superjomn
上级 08cd0cb7
......@@ -44,8 +44,10 @@ function build_backend() {
function build_onnx_graph() {
$env:PATH = "$BUILD_DIR/third_party/protobuf/src/extern_protobuf-build/Release;" + $env:PATH
cd $TOP_DIR/visualdl/server/onnx
cd $TOP_DIR/visualdl/server/model/onnx
protoc onnx.proto --python_out .
cd $TOP_DIR/visualdl/server/model/paddle
protoc framework.proto --python_out .
}
function clean_env() {
......
......@@ -49,8 +49,10 @@ build_backend() {
build_onnx_graph() {
export PATH="$BUILD_DIR/third_party/protobuf/src/extern_protobuf-build/:$PATH"
cd $TOP_DIR/visualdl/server/onnx
cd $TOP_DIR/visualdl/server/model/onnx
protoc onnx.proto --python_out .
cd $TOP_DIR/visualdl/server/model/paddle
protoc framework.proto --python_out .
}
clean_env() {
......
......@@ -47,6 +47,12 @@ export default {
graphZoom: null,
svgSelection: null,
zoomScale: null,
graphConfig: {
drawInputNode: false,
drawOutputNode: false,
drawTempNode: false,
inputMap: {},
},
};
},
watch: {
......@@ -69,6 +75,144 @@ export default {
},
},
methods: {
// get the relate operator of a variable
getVarRelateDict(graphData) {
let dic = {};
for (let i = 0; i < graphData['node'].length; ++i) {
let curOperatorNode = graphData['node'][i];
let nodeKey = 'opNode_' + i;
// record all relate operator of a variable
if (has(curOperatorNode, 'output') && isArrayLike(curOperatorNode['output'])) {
for (let j = 0; j < curOperatorNode['output'].length; ++j) {
let outputData = curOperatorNode['output'][j];
if (!dic[outputData]) {
dic[outputData] = {'input': [], 'output': []};
}
let arr = dic[outputData]['input'];
arr[arr.length] = nodeKey;
}
}
if (has(curOperatorNode, 'input') && isArrayLike(curOperatorNode['input'])) {
for (let j = 0; j < curOperatorNode['input'].length; ++j) {
let inputData = curOperatorNode['input'][j];
if (!dic[inputData]) {
dic[inputData] = {'input': [], 'output': []};
}
let arr = dic[inputData]['output'];
arr[arr.length] = nodeKey;
}
}
}
return dic;
},
getInputNodeStyle() {
return 'opacity: 0.1; ' +
'stroke-width: 3px; ' +
'stroke: #333; ' +
'stroke-color: #41b3a3; ' +
'fill: #6c648b; ';
},
getOutputNodeStyle() {
return 'opacity: 0.1;' +
'stroke-width: 3px; ' +
'stroke-dasharray: 5, 5;' +
'stroke: #333;' +
'stroke-color: #41b3a3; ' +
'fill: #015249; ';
},
getTempNodeStyle() {
return getOutputNodeStyle();
},
getOpNodeStyle() {
return 'stroke-width: 3px; ' +
'opacity: 0.1; ' +
'rx: 10; ' +
'ry: 10; ' +
'stroke: #333; ' +
'stroke-color: #41b3a3; ' +
'fill: #008c99; ';
},
setGraphNode(graph, nodeKey, labelVal, shapeVal, className, styleVal) {
graph.setNode(
nodeKey,
{
label: labelVal,
shape: shapeVal,
class: className,
style: styleVal,
}
);
},
buildInputNodeLabel(inputNode) {
// TODO(daming-lu): need more complex compound node
let nodeLabel = 'id: ' + inputNode['name'] + '\n'
+ 'type: ' + inputNode['data_type'] + '\n'
+ 'dims: ' + inputNode['shape'].join(' x ');
return nodeLabel;
},
buildGraph(graph, graphData) {
// add operator node
for (let i = 0; i < graphData['node'].length; ++i) {
let curOperatorNode = graphData['node'][i];
let nodeKey = 'opNode_' + i;
// add operator node
let curOpLabel = curOperatorNode['opType'];
curOpLabel = curOpLabel + ' '.repeat(Math.floor(curOpLabel.length/5));
this.setGraphNode(graph, nodeKey, curOpLabel, 'rect', 'operator', this.getOpNodeStyle());
}
let dic = this.getVarRelateDict(graphData);
for (let obj in dic) {
if (!dic.hasOwnProperty(obj)) continue;
// add input node
if (dic[obj]['input'].length === 0 && this.graphConfig.drawInputNode === true) {
let temp = obj.indexOf('@');
let nodeKey = obj;
if (temp > 0) {
nodeKey = obj.substr(0, temp);
}
let index = this.graphConfig.inputMap[nodeKey];
let curInputNode = graphData['input'][index];
this.setGraphNode(graph, nodeKey,
this.buildInputNodeLabel(curInputNode), 'rect', 'input', this.getInputNodeStyle());
for (let output in dic[obj]['output']) {
if (!dic[obj]['output'].hasOwnProperty(output)) continue;
graph.setEdge(nodeKey, dic[obj]['output'][output]);
}
}
// add output node
if (dic[obj]['output'].length === 0 && this.graphConfig.drawOutputNode === true) {
let nodeKey = obj;
let outputPadding = ' '.repeat(Math.floor(nodeKey.length/2));
this.setGraphNode(graph, nodeKey, nodeKey + outputPadding, 'diamond', 'output', this.getOutputNodeStyle());
for (let input in dic[obj]['input']) {
if (!dic[obj]['input'].hasOwnProperty(input)) continue;
graph.setEdge(nodeKey, dic[obj]['input'][input]);
}
}
for (let input in dic[obj]['input']) {
if (!dic[obj]['input'].hasOwnProperty(input)) continue;
for (let output in dic[obj]['output']) {
if (!dic[obj]['output'].hasOwnProperty(output)) continue;
if (this.graphConfig.drawTempNode === true) {
let nodeKey = obj;
let outputPadding = ' '.repeat(Math.floor(nodeKey.length/2));
this.setGraphNode(graph,
nodeKey, nodeKey + outputPadding, 'diamond', 'output', this.getOutputNodeStyle());
graph.setEdge(dic[obj]['input'][input], nodeKey);
graph.setEdge(nodeKey, dic[obj]['output'][output]);
} else {
graph.setEdge(dic[obj]['input'][input], dic[obj]['output'][output]);
}
}
}
}
},
restoreImage(thenDownload) {
let chartScope = this;
let svg = d3.select('svg');
......@@ -99,6 +243,8 @@ export default {
this.zoomScale = linearScale;
},
mounted() {
// some model is too large to draw in dagred3, so don't draw input and output node in default
let chartScope = this;
getPluginGraphsGraph().then(({errno, data}) => {
if (has(data, 'data') === false) {
......@@ -106,6 +252,9 @@ export default {
}
let graphData = data.data;
if (has(graphData, 'node') === false) {
return;
}
// d3 svg drawing
let g = new dagreD3.graphlib.Graph()
......@@ -116,119 +265,29 @@ export default {
// eslint-disable-next-line
let render = new dagreD3.render();
let nodeKeys = [];
let buildInputNodeLabel = function(inputNode) {
// TODO(daming-lu): need more complex compound node
let nodeLabel = 'id: ' + inputNode['name'] + '\n'
+ 'type: ' + inputNode['data_type'] + '\n'
+ 'dims: ' + inputNode['shape'].join(' x ');
return nodeLabel;
};
// add input nodes
if (has(graphData, 'input') === false) {
return;
}
let inputIdToIndex = {};
for (let i=0; i<graphData['input'].length; ++i) {
let curInputNode = graphData['input'][i];
let nodeKey = curInputNode['name'];
inputIdToIndex[nodeKey] = i;
g.setNode(
nodeKey,
{
label: buildInputNodeLabel(curInputNode),
style: 'opacity: 0.1; ' +
'stroke-width: 3px; ' +
'stroke: #333; ' +
'stroke-color: #41b3a3; ' +
'fill: #6c648b; ',
class: 'input',
labelStyle: 'font-size: 0.8em;',
}
);
nodeKeys.push(nodeKey);
}
// add operator nodes then add edges from inputs to operator and from operator to output
if (has(graphData, 'node') === false) {
return;
}
for (let i=0; i<graphData['node'].length; ++i) {
let curOperatorNode = graphData['node'][i];
let nodeKey = 'opNode_' + i;
// add operator node
let curOpLabel = curOperatorNode['opType'];
g.setNode(
nodeKey,
{
label: curOpLabel + ' '.repeat(Math.floor(curOpLabel.length/5)),
shape: 'rect',
class: 'operator',
style: 'stroke-width: 3px; ' +
'opacity: 0.1; ' +
'rx: 10; ' +
'ry: 10; ' +
'stroke: #333; ' +
'stroke-color: #41b3a3; ' +
'fill: #008c99; ',
}
);
nodeKeys.push(nodeKey);
// add output node
if (has(graphData, 'output') === false) {
return;
}
let outputNodeKey = curOperatorNode['output'][0];
let outputPadding = ' '.repeat(Math.floor(outputNodeKey.length/2));
g.setNode(
outputNodeKey,
{
label: outputNodeKey + outputPadding,
class: 'output',
style: 'opacity: 0.1;' +
'stroke-width: 3px; ' +
'stroke-dasharray: 5, 5;' +
'stroke: #333;' +
'stroke-color: #41b3a3; ' +
'fill: #015249; ',
shape: 'diamond',
}
);
nodeKeys.push(outputNodeKey);
// add edges from inputs to node and from node to output
if (has(curOperatorNode, 'input') && isArrayLike(curOperatorNode['input'])) {
for (let e = 0; e < curOperatorNode['input'].length; ++e) {
g.setEdge(curOperatorNode['input'][e], nodeKey);
}
}
if (has(curOperatorNode, 'output') && isArrayLike(curOperatorNode['output'])) {
g.setEdge(nodeKey, curOperatorNode['output'][0], {
style: 'stroke: #333;stroke-width: 1.5px',
});
if (has(graphData, 'input') === true && this.graphConfig.drawInputNode === true) {
for (let i = 0; i < graphData['input'].length; ++i) {
let name = graphData['input'][i]['name'];
this.graphConfig.inputMap[name] = i;
}
}
this.buildGraph(g, graphData);
let svg = d3.select('svg')
.attr('font-family', 'sans-serif')
.attr('font-size', '28px');
render(d3.select('svg g'), g);
let graphSelection = d3.select('svg g');
let graphWidth = g.graph().width;
let graphHeight = g.graph().height;
svg.attr('viewBox', '0 0 ' + graphWidth + ' ' + graphHeight);
this.imageWidth = graphWidth;
this.imageWidth = graphWidth * 1.1;
this.imageHeight = graphHeight;
this.originImageWidth = graphWidth;
this.originImageWidth = graphWidth * 1.1;
this.originImageHeight = graphHeight;
// zooming
......@@ -255,7 +314,7 @@ export default {
let opIndex = d.slice(7); // remove prefix "opNode_"
nodeInfo = graphData.node[opIndex];
} else if (nodeType === 'input') {
nodeInfo = graphData.input[inputIdToIndex[d]];
nodeInfo = graphData.input[this.graphConfig.inputMap[d]];
} else {
nodeInfo = 'output';
}
......
......@@ -96,7 +96,9 @@ packages = [
'visualdl.python',
'visualdl.server',
'visualdl.server.mock',
'visualdl.server.onnx',
'visualdl.server.model',
'visualdl.server.model.onnx',
'visualdl.server.model.paddle',
]
libraries = ['core.so']
......
......@@ -20,7 +20,8 @@ import os
from google.protobuf.json_format import MessageToJson
from . import graphviz_graph as gg
from . import onnx
from .model import onnx
from .model import paddle
def debug_print(json_obj):
......@@ -486,11 +487,17 @@ class GraphPreviewGenerator(object):
return self.graph.edge(source, target, **kwargs)
def draw_graph(model_pb_path):
def draw_onnx_graph(model_pb_path):
    """Load an ONNX model file and return its graph JSON for the web UI."""
    return load_model(model_pb_path)
def draw_paddle_graph(model_pb_path):
    """Parse a Paddle model directory and return its graph data dict."""
    model = paddle.PaddleModel(model_pb_path)
    return model.to_graph_data()
if __name__ == '__main__':
import sys
current_path = os.path.abspath(os.path.dirname(sys.argv[0]))
......
from __future__ import absolute_import

from .paddle.paddle2graph import is_paddle_model
from .onnx import is_onnx_model

# Bug fix: __all__ entries must be strings. Listing the function objects
# themselves makes `from visualdl.server.model import *` raise TypeError.
__all__ = ['is_paddle_model', 'is_onnx_model']
......@@ -3,11 +3,19 @@ from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals
import google.protobuf.message
import os
from .onnx_pb2 import ModelProto
def is_onnx_model(model_path):
    """Return True when model_path is a regular file that parses as a
    non-empty ONNX ModelProto."""
    # Directories and missing paths can never be ONNX models.
    if not os.path.isfile(model_path):
        return False
    parsed = load(model_path)
    if parsed is None:
        return False
    # An empty protobuf stringifies to "" and is treated as "not a model".
    return str(parsed) != ""
def load(obj):
'''
Loads a binary protobuf that stores onnx model
......@@ -37,7 +45,9 @@ def load_from_string(s):
decoded = model.ParseFromString(s)
# in python implementation ParseFromString returns None
if decoded is not None and decoded != len(s):
raise google.protobuf.message.DecodeError(
"Protobuf decoding consumed too few bytes: {} out of {}".format(
decoded, len(s)))
return None
# raise google.protobuf.message.DecodeError(
# "Protobuf decoding consumed too few bytes: {} out of {}".format(
# decoded, len(s)))
return model
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals

from .paddle2graph import PaddleModel

# Bug fix: __all__ entries must be strings. A bare class object here
# makes `from ...model.paddle import *` raise TypeError.
__all__ = ['PaddleModel']
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
// This file is borrowed from the PaddlePaddle project to better support the models produced by it.
syntax = "proto2";
option optimize_for = LITE_RUNTIME;
package paddle.framework.proto;
// Any incompatible changes to ProgramDesc and its dependencies should
// raise the version defined version.h.
//
// Serailization and Deserialization codes should be modified in a way
// that supports old versions following the version and compatibility policy.
message Version { optional int64 version = 1 [ default = 0 ]; }
enum AttrType {
INT = 0;
FLOAT = 1;
STRING = 2;
INTS = 3;
FLOATS = 4;
STRINGS = 5;
BOOLEAN = 6;
BOOLEANS = 7;
BLOCK = 8;
LONG = 9;
BLOCKS = 10;
LONGS = 11;
}
// OpDesc describes an instance of a C++ framework::OperatorBase
// derived class type.
message OpDesc {
message Attr {
required string name = 1;
required AttrType type = 2;
optional int32 i = 3;
optional float f = 4;
optional string s = 5;
repeated int32 ints = 6;
repeated float floats = 7;
repeated string strings = 8;
optional bool b = 10;
repeated bool bools = 11;
optional int32 block_idx = 12;
optional int64 l = 13;
repeated int32 blocks_idx = 14;
repeated int64 longs = 15;
};
message Var {
required string parameter = 1;
repeated string arguments = 2;
};
required string type = 3;
repeated Var inputs = 1;
repeated Var outputs = 2;
repeated Attr attrs = 4;
optional bool is_target = 5 [ default = false ];
};
// OpProto describes a C++ framework::OperatorBase derived class.
message OpProto {
// VarProto describes the C++ type framework::Variable.
message Var {
required string name = 1;
required string comment = 2;
optional bool duplicable = 3 [ default = false ];
optional bool intermediate = 4 [ default = false ];
optional bool dispensable = 5 [ default = false ];
}
// AttrProto describes the C++ type Attribute.
message Attr {
required string name = 1;
required AttrType type = 2;
required string comment = 3;
// If that attribute is generated, it means the Paddle third
// language binding has responsibility to fill that
// attribute. End-User should not set that attribute.
optional bool generated = 4 [ default = false ];
}
required string type = 1;
repeated Var inputs = 2;
repeated Var outputs = 3;
repeated Attr attrs = 4;
required string comment = 5;
}
message VarType {
enum Type {
// Pod Types
BOOL = 0;
INT16 = 1;
INT32 = 2;
INT64 = 3;
FP16 = 4;
FP32 = 5;
FP64 = 6;
// Tensor<size_t> is used in C++.
SIZE_T = 19;
UINT8 = 20;
INT8 = 21;
// Other types that may need additional descriptions
LOD_TENSOR = 7;
SELECTED_ROWS = 8;
FEED_MINIBATCH = 9;
FETCH_LIST = 10;
STEP_SCOPES = 11;
LOD_RANK_TABLE = 12;
LOD_TENSOR_ARRAY = 13;
PLACE_LIST = 14;
READER = 15;
// Any runtime decided variable type is raw
// raw variables should manage their own allocations
// in operators like nccl_op
RAW = 17;
TUPLE = 18;
}
required Type type = 1;
message TensorDesc {
// Should only be PODType. Is enforced in C++
required Type data_type = 1;
repeated int64 dims = 2; // [UNK, 640, 480] is saved as [-1, 640, 480]
}
optional TensorDesc selected_rows = 2;
message LoDTensorDesc {
required TensorDesc tensor = 1;
optional int32 lod_level = 2 [ default = 0 ];
}
optional LoDTensorDesc lod_tensor = 3;
message LoDTensorArrayDesc {
required TensorDesc tensor = 1;
optional int32 lod_level = 2 [ default = 0 ];
}
optional LoDTensorArrayDesc tensor_array = 4;
message ReaderDesc { repeated LoDTensorDesc lod_tensor = 1; }
optional ReaderDesc reader = 5;
message Tuple { repeated Type element_type = 1; }
optional Tuple tuple = 7;
}
message VarDesc {
required string name = 1;
required VarType type = 2;
optional bool persistable = 3 [ default = false ];
}
message BlockDesc {
required int32 idx = 1;
required int32 parent_idx = 2;
repeated VarDesc vars = 3;
repeated OpDesc ops = 4;
optional int32 forward_block_idx = 5 [ default = -1 ];
}
// Please refer to
// https://github.com/PaddlePaddle/Paddle/blob/develop/doc/design/program.md
// for more details.
// TODO(panyx0718): A model can have multiple programs. Need a
// way to distinguish them. Maybe ID or name?
message ProgramDesc {
repeated BlockDesc blocks = 1;
optional Version version = 2;
}
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals
from .framework_pb2 import ProgramDesc
import os
def is_paddle_model(model_path):
    """Return True when model_path is a directory holding a parseable
    Paddle "__model__" protobuf file."""
    candidate = os.path.join(model_path, "__model__")
    if not os.path.exists(candidate):
        return False
    # load_from_model_file yields None when the bytes are not a ProgramDesc.
    return PaddleModel.load_from_model_file(candidate) is not None
class PaddleOp:
    """Wrapper around a ProgramDesc op that assigns a unique operator name
    ("<type>_<index>") and versioned variable names ("<var>@<count>") so
    repeated writes to the same variable do not create cycles in the graph.
    """

    def __init__(self, op, op_count_dict, var_count_dict):
        # op: object exposing .type, .inputs and .outputs (each input/output
        #     exposes .arguments, a list of variable names).
        # op_count_dict / var_count_dict: shared counters owned by the caller
        #     (PaddleModel); both are mutated here.
        self.type = op.type
        self.index = op_count_dict.setdefault(self.type, 0)
        op_count_dict[self.type] += 1
        self.name = "%s_%d" % (self.type, self.index)
        self.inputs = {}
        self.outputs = {}
        # rename var to avoid cycle
        for input in op.inputs:
            for argument in input.arguments:
                # A read refers to the variable's current version.
                cnt = var_count_dict.setdefault(argument, 0)
                self.inputs[argument] = "%s@%d" % (argument, cnt)
        for output in op.outputs:
            for argument in output.arguments:
                # A write bumps the version first, so every write creates
                # a fresh node name.
                var_count_dict.setdefault(argument, 0)
                var_count_dict[argument] += 1
                cnt = var_count_dict[argument]
                self.outputs[argument] = "%s@%d" % (argument, cnt)

    def set_input_name(self, key, name):
        self.inputs[key] = name

    def set_output_name(self, key, name):
        self.outputs[key] = name

    def get_input_name(self, key):
        return self.inputs[key]

    def get_output_name(self, key):
        return self.outputs[key]

    def get_input_set(self):
        # Unversioned names of the variables this op reads.
        # Idiom fix: set(dict) iterates keys; the list comprehension the
        # original wrapped in set() added nothing.
        return set(self.inputs)

    def get_output_set(self):
        # Unversioned names of the variables this op writes.
        return set(self.outputs)

    def get_name(self):
        return self.name

    def is_compute_op(self):
        # Anything except the feed/fetch bookkeeping operators.
        return not self.is_fetch_op() and not self.is_feed_op()

    def is_fetch_op(self):
        return self.type in ["fetch"]

    def is_feed_op(self):
        return self.type in ["feed"]

    def get_info(self):
        # Note: "opType" intentionally carries the unique operator name,
        # which is what the front-end displays and keys on.
        return {
            "opType": self.name,
            "input": list(self.inputs.values()),
            "output": list(self.outputs.values())
        }
class PaddleVar:
    """Wrapper around a ProgramDesc variable that extracts the tensor
    shape/dtype and tracks the operators that produce and consume it."""

    def __init__(self, var):
        # var: object exposing .name and .type.lod_tensor.tensor, where the
        # tensor exposes .dims and .data_type.
        tensor = var.type.lod_tensor.tensor
        # Idiom fix: list(tensor.dims) over a manual comprehension.
        # A variable with no recorded dims gets the sentinel shape [-1].
        self.shape = list(tensor.dims) if len(tensor.dims) != 0 else [-1]
        self.name = var.name
        self.data_type = tensor.data_type
        # Names of operators that write (input_ops) / read (output_ops) this var.
        self.input_ops = set()
        self.output_ops = set()
        self.info = {
            "shape": self.shape,
            "data_type": self.data_type,
            "name": self.name
        }

    def add_input_op(self, input):
        self.input_ops.add(input)

    def add_output_op(self, output):
        self.output_ops.add(output)

    def get_op_intersection(self):
        # Operators that both write and read this variable.
        return self.input_ops.intersection(self.output_ops)

    def get_op_difference(self):
        # Operators that write this variable without reading it.
        return self.input_ops.difference(self.output_ops)

    def get_info(self):
        return self.info
class PaddleModel:
    """In-memory representation of a Paddle inference model.

    Parses the serialized ProgramDesc protobuf found at
    ``<model_path>/__model__`` and indexes its variables (PaddleVar) and
    operators (PaddleOp) so the graph can be exported as the JSON
    structure the front-end renders.
    """

    # load from a pb model
    @staticmethod
    def load_from_model_file(model_path):
        # Parse the file at model_path as a ProgramDesc protobuf.
        # Returns None when decoding consumed fewer bytes than the file
        # holds (i.e. the file is not a valid ProgramDesc).
        with open(model_path, "rb") as file:
            string = file.read()
            program = ProgramDesc()
            decoded = program.ParseFromString(string)
            # The C++ protobuf backend returns the byte count consumed;
            # the pure-Python backend returns None, so both are handled.
            if decoded is not None and decoded != len(string):
                return None
            return program

    def __init__(self, model_path, name=None):
        # model_path: directory containing the serialized "__model__" file.
        # name: optional display name (defaults to "PaddleGraph").
        model_path = os.path.join(model_path, "__model__")
        self.name = name if name is not None else "PaddleGraph"
        self.program = PaddleModel.load_from_model_file(model_path)
        # Map every variable name (across all blocks) to its wrapper.
        self.vars = {
            var.name: PaddleVar(var)
            for block in self.program.blocks for var in block.vars
        }
        self.ops = {}
        self.input_set = set()
        self.output_set = set()
        # Shared counters PaddleOp uses to build unique op names and
        # versioned "var@N" names that break cycles in the drawn graph.
        self.op_count_dict = {}
        self.var_count_dict = {}
        for block in self.program.blocks:
            for op in block.ops:
                op_obj = PaddleOp(op, self.op_count_dict, self.var_count_dict)
                self.ops[op_obj.get_name()] = op_obj
                # record all input
                for input in op.inputs:
                    for argument in input.arguments:
                        self.vars[argument].add_input_op(op_obj.get_name())
                        # Alias the versioned name ("var@N") to the same
                        # PaddleVar so lookups by either name succeed.
                        savename = op_obj.get_input_name(argument)
                        self.input_set.add(savename)
                        self.vars[savename] = self.vars[argument]
                for output in op.outputs:
                    for argument in output.arguments:
                        self.vars[argument].add_output_op(op_obj.get_name())
                        savename = op_obj.get_output_name(argument)
                        self.output_set.add(savename)
                        self.vars[savename] = self.vars[argument]

    # return data to front-end
    def to_graph_data(self):
        """Build the {"node", "input", "output"} dict the web UI consumes."""
        # record the input and output var in a graph
        inputs = []
        outputs = []
        # remove vars bind with feed operator
        for key, op in self.ops.items():
            if op.is_feed_op():
                # Outputs of "feed" ops are the true graph inputs.
                self.input_set = self.input_set.union(op.get_output_set())
                self.output_set = self.output_set.difference(
                    op.get_output_set())
                self.input_set = self.input_set.difference(op.get_input_set())
        # remove vars bind with fetch operator
        for key, op in self.ops.items():
            if op.is_fetch_op():
                # Inputs of "fetch" ops are the true graph outputs.
                self.output_set = self.output_set.union(op.get_input_set())
                self.input_set = self.input_set.difference(op.get_input_set())
                self.output_set = self.output_set.difference(
                    op.get_output_set())
        for input in self.input_set:
            # A var appearing on both sides is intermediate, not an input.
            if input in self.output_set:
                continue
            inputs.append(self.vars[input].get_info())
        for output in self.output_set:
            if output in self.input_set:
                continue
            outputs.append(self.vars[output].get_info())
        return {
            "node": [
                op.get_info() for _, op in self.ops.items()
                if op.is_compute_op()
            ],
            "input":
            inputs,
            "output":
            outputs
        }
if __name__ == "__main__":
    # CLI helper: dump the graph JSON for the model directory in argv[1].
    import json
    import sys

    print(json.dumps(PaddleModel(sys.argv[1]).to_graph_data()))
......@@ -27,6 +27,7 @@ from flask import (Flask, Response, redirect, request, send_file,
import visualdl
import visualdl.server
import visualdl.server.graph as vdl_graph
import visualdl.server.model as model
from visualdl.server import lib
from visualdl.server.log import logger
from visualdl.server.mock import data as mock_data
......@@ -86,7 +87,7 @@ def parse_args():
"--model_pb",
type=str,
action="store",
help="model proto in ONNX format")
help="model proto in ONNX format or in Paddle framework format")
parser.add_argument(
"--logdir",
required=True,
......@@ -319,7 +320,12 @@ def individual_audio():
@app.route('/data/plugin/graphs/graph')
def graph():
json_str = vdl_graph.draw_graph(args.model_pb)
if model.is_onnx_model(args.model_pb):
json_str = vdl_graph.draw_onnx_graph(args.model_pb)
elif model.is_paddle_model(args.model_pb):
json_str = vdl_graph.draw_paddle_graph(args.model_pb)
else:
json_str = {}
data = {'data': json_str}
result = gen_result(0, "", data)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册