未验证 提交 5b9dacfc 编写于 作者: R RotPublic 提交者: GitHub

Add dynamic graph frontend

上级 96fb036b
...@@ -27,9 +27,9 @@ module.exports = { ...@@ -27,9 +27,9 @@ module.exports = {
}, },
plugins: ['license-header'], plugins: ['license-header'],
rules: { rules: {
'no-console': 'warn', 'sort-imports': 'warn',
'sort-imports': 'error', 'no-console': 'warn'
'license-header/header': ['error', './license-header.js'] // 'license-header/header': ['error', './license-header.js']
}, },
overrides: [ overrides: [
{ {
......
...@@ -26,6 +26,7 @@ ...@@ -26,6 +26,7 @@
"dev": "snowpack dev", "dev": "snowpack dev",
"dev:reload": "yarn dev --reload", "dev:reload": "yarn dev --reload",
"build": "snowpack build && node builder/post-build.js", "build": "snowpack build && node builder/post-build.js",
"lint": "eslint --ext .tsx,.jsx.ts,.js,.mjs",
"snowpack": "snowpack", "snowpack": "snowpack",
"test": "web-test-runner \"test/**/*.tsx\"" "test": "web-test-runner \"test/**/*.tsx\""
}, },
...@@ -48,6 +49,7 @@ ...@@ -48,6 +49,7 @@
"d3-format": "3.0.1", "d3-format": "3.0.1",
"echarts": "4.9.0", "echarts": "4.9.0",
"echarts-gl": "1.1.2", "echarts-gl": "1.1.2",
"eslint-plugin-simple-import-sort": "^7.0.0",
"eventemitter3": "4.0.7", "eventemitter3": "4.0.7",
"file-saver": "2.0.5", "file-saver": "2.0.5",
"i18next": "20.6.0", "i18next": "20.6.0",
......
...@@ -10,6 +10,8 @@ ...@@ -10,6 +10,8 @@
"empty": "Nothing to display", "empty": "Nothing to display",
"error": "Error occurred", "error": "Error occurred",
"graph": "Graphs", "graph": "Graphs",
"graphDynamic": "dynamic",
"graphStatic": "static",
"high-dimensional": "High Dimensional", "high-dimensional": "High Dimensional",
"histogram": "Histogram", "histogram": "Histogram",
"hyper-parameter": "Hyper Parameters", "hyper-parameter": "Hyper Parameters",
......
...@@ -43,12 +43,15 @@ ...@@ -43,12 +43,15 @@
"type": "Type", "type": "Type",
"version": "Version" "version": "Version"
}, },
"restore-size": "Restore Size", "restore-size": "Fully Shrinked data model",
"expend-size": "Fully expanded Data Model",
"show-attributes": "Show Attributes", "show-attributes": "Show Attributes",
"show-initializers": "Show Initializers", "show-initializers": "Show Initializers",
"show-node-names": "Show Node Names", "show-node-names": "Show Node Names",
"keep-expanded": "keep expanded",
"subgraph": "Select Subgraph", "subgraph": "Select Subgraph",
"supported-model": "Supported models: ", "supported-model": "Supported models: ",
"Choose-model": "Choose a model",
"supported-model-list": "PaddlePaddle, ONNX, Keras, Core ML, Caffe, Caffe2, Darknet, MXNet, ncnn, TensorFlow Lite", "supported-model-list": "PaddlePaddle, ONNX, Keras, Core ML, Caffe, Caffe2, Darknet, MXNet, ncnn, TensorFlow Lite",
"upload-model": "Upload Model", "upload-model": "Upload Model",
"upload-tip": "Click or Drop file here to view neural network models", "upload-tip": "Click or Drop file here to view neural network models",
......
...@@ -10,6 +10,8 @@ ...@@ -10,6 +10,8 @@
"empty": "暂无数据", "empty": "暂无数据",
"error": "发生错误", "error": "发生错误",
"graph": "网络结构", "graph": "网络结构",
"graphDynamic": "动态",
"graphStatic": "静态",
"high-dimensional": "数据降维", "high-dimensional": "数据降维",
"histogram": "直方图", "histogram": "直方图",
"hyper-parameter": "超参可视化", "hyper-parameter": "超参可视化",
......
...@@ -43,16 +43,19 @@ ...@@ -43,16 +43,19 @@
"type": "类型", "type": "类型",
"version": "版本" "version": "版本"
}, },
"restore-size": "重置大小", "restore-size": "全收缩数据模型",
"expend-size": "全展开数据模型",
"show-attributes": "显示参数", "show-attributes": "显示参数",
"show-initializers": "显示初始化参数", "show-initializers": "显示初始化参数",
"show-node-names": "显示节点名称", "show-node-names": "显示节点名称",
"subgraph": "选择子图", "subgraph": "选择子图",
"keep-expanded": "保持展开",
"supported-model": "VisualDL支持:", "supported-model": "VisualDL支持:",
"supported-model-list": "PaddlePaddle、ONNX、Keras、Core ML、Caffe、Caffe2、Darknet、MXNet、ncnn、TensorFlow Lite", "supported-model-list": "PaddlePaddle、ONNX、Keras、Core ML、Caffe、Caffe2、Darknet、MXNet、ncnn、TensorFlow Lite",
"upload-model": "上传模型", "upload-model": "上传模型",
"upload-tip": "点击或拖拽文件到页面上传模型,进行结构展示", "upload-tip": "点击或拖拽文件到页面上传模型,进行结构展示",
"vertical": "垂直", "vertical": "垂直",
"Choose-model": "选择模型",
"zoom-in": "放大", "zoom-in": "放大",
"zoom-out": "缩小" "zoom-out": "缩小"
} }
/**
* Copyright 2020 Baidu Inc. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
import React, {FunctionComponent, useCallback} from 'react';
import {ellipsis, em, half, math, position, sameBorder, size, transitionProps} from '~/utils/style';
import styled from 'styled-components';
const height = em(20);
const checkSize = em(16);
const checkMark =
// eslint-disable-next-line
'data:image/svg+xml;base64,PHN2ZyBoZWlnaHQ9IjgiIHZpZXdCb3g9IjAgMCAxMSA4IiB3aWR0aD0iMTEiIHhtbG5zPSJodHRwOi8vd3d3LnczLm9yZy8yMDAwL3N2ZyI+PHBhdGggZD0ibTkuNDc5NDI3MDggMTAuMTg3NWgtNS4yNXYtMS4zMTI1aDMuOTM3bC4wMDA1LTcuODc1aDEuMzEyNXoiIGZpbGw9IiNmNWY1ZjUiIGZpbGwtcnVsZT0iZXZlbm9kZCIgdHJhbnNmb3JtPSJtYXRyaXgoLjcwNzEwNjc4IC43MDcxMDY3OCAtLjcwNzEwNjc4IC43MDcxMDY3OCA0Ljk2Mjk5NCAtNi4yMDg0NCkiLz48L3N2Zz4=';
const Wrapper = styled.label<{disabled?: boolean}>`
position: relative;
display: inline-flex;
align-items: flex-start;
cursor: ${props => (props.disabled ? 'not-allowed' : 'pointer')};
`;
const Input = styled.input.attrs<{disabled?: boolean}>(props => ({
type: 'checkbox',
disabled: !!props.disabled
}))`
${size(0)}
${position('absolute', 0, null, null, 0)}
opacity: 0;
pointer-events: none;
`;
const Inner = styled.div<{checked?: boolean; size?: string; disabled?: boolean}>`
color: ${props => (props.checked ? 'var(--text-invert-color)' : 'transparent')};
flex-shrink: 0;
${props => size(math(`${checkSize} * ${props.size === 'small' ? 0.875 : 1}`))}
margin: ${half(`${height} - ${checkSize}`)} 0;
margin-right: ${em(10)};
${props =>
sameBorder({color: props.disabled || !props.checked ? 'var(--text-lighter-color)' : 'var(--primary-color)'})};
background-color: ${props =>
props.disabled
? props.checked
? 'var(--text-lighter-color)'
: 'transparent'
: props.checked
? 'var(--primary-color)'
: 'var(--background-color)'};
background-image: ${props => (props.checked ? `url("${checkMark}")` : 'none')};
background-repeat: no-repeat;
background-position: center center;
background-size: ${em(10)} ${em(8)};
position: relative;
${transitionProps(['border-color', 'background-color', 'color'])}
${Wrapper}:hover > & {
border-color: ${props =>
props.disabled
? 'var(--text-lighter-color)'
: props.checked
? 'var(--primary-color)'
: 'var(--text-lighter-color)'};
}
`;
const Content = styled.div<{disabled?: boolean}>`
line-height: ${height};
flex-grow: 1;
${props => (props.disabled ? 'color: var(--text-lighter-color);' : '')}
${transitionProps('color')}
${ellipsis()}
`;
type CheckboxProps = {
value: string;
checked?: boolean;
className?: string;
onChange?: (checked: string) => unknown;
size?: 'small';
title?: string;
disabled?: boolean;
};
const Checkbox: FunctionComponent<CheckboxProps> = ({
value,
checked,
children,
size,
disabled,
className,
title,
onChange
}) => {
const onChangeInput = useCallback(() => {
if (disabled) {
return;
}
if (onChange) {
onChange(value);
}
}, [disabled, onChange]);
return (
<Wrapper disabled={disabled} className={className} title={title}>
<Input onChange={onChangeInput} checked={checked} disabled={disabled} />
<Inner checked={checked} size={size} disabled={disabled} />
<Content disabled={disabled}>{children}</Content>
</Wrapper>
);
};
export default Checkbox;
...@@ -61,7 +61,7 @@ const Wrapper = styled.div` ...@@ -61,7 +61,7 @@ const Wrapper = styled.div`
${transitionProps('color')} ${transitionProps('color')}
&:hover, &:hover,
&:active { &:active {
color: var(--text-light-color); color: var(--text-light-color);
} }
} }
...@@ -96,7 +96,7 @@ const Argument: FunctionComponent<ArgumentProps> = ({value, expand, showNodeDocu ...@@ -96,7 +96,7 @@ const Argument: FunctionComponent<ArgumentProps> = ({value, expand, showNodeDocu
{value.name}: <b>{value.value}</b> {value.name}: <b>{value.value}</b>
</> </>
) : ( ) : (
value.value.split('\n').map((line, index) => ( new String(value.value).split('\n').map((line, index) => (
<React.Fragment key={index}> <React.Fragment key={index}>
{index !== 0 && <br />} {index !== 0 && <br />}
{line} {line}
......
/**
* Copyright 2020 Baidu Inc. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
import type {Documentation, OpenedResult, Properties, SearchItem, SearchResult} from '~/resource/graph/types';
import React, {useCallback, useEffect, useImperativeHandle, useMemo, useRef, useState} from 'react';
import {contentHeight, position, primaryColor, rem, size, transitionProps} from '~/utils/style';
import ChartToolbox from '~/components/ChartToolbox';
import HashLoader from 'react-spinners/HashLoader';
import logo from '~/assets/images/netron.png';
import netron from '@visualdl/netron';
import styled from 'styled-components';
import {toast} from 'react-toastify';
import {fetcher} from '~/utils/fetch';
import useTheme from '~/hooks/useTheme';
import {useTranslation} from 'react-i18next';
const PUBLIC_PATH: string = import.meta.env.SNOWPACK_PUBLIC_PATH;
let IFRAME_HOST = `${window.location.protocol}//${window.location.host}`;
if (PUBLIC_PATH.startsWith('http')) {
const url = new URL(PUBLIC_PATH);
IFRAME_HOST = `${url.protocol}//${url.host}`;
}
const toolboxHeight = rem(40);
const Wrapper = styled.div`
position: relative;
height: ${contentHeight};
background-color: var(--background-color);
display: flex;
flex-direction: column;
justify-content: center;
align-items: center;
${transitionProps('background-color')}
`;
const RenderContent = styled.div<{show: boolean}>`
position: absolute;
top: 0;
left: 0;
${size('100%', '100%')}
opacity: ${props => (props.show ? 1 : 0)};
z-index: ${props => (props.show ? 0 : -1)};
pointer-events: ${props => (props.show ? 'auto' : 'none')};
`;
const Toolbox = styled(ChartToolbox)`
height: ${toolboxHeight};
border-bottom: 1px solid var(--border-color);
padding: 0 ${rem(20)};
${transitionProps('border-color')}
`;
const Content = styled.div`
position: relative;
height: calc(100% - ${toolboxHeight});
> iframe {
${size('100%', '100%')}
border: none;
}
> .powered-by {
display: block;
${position('absolute', null, null, rem(20), rem(30))}
color: var(--graph-copyright-color);
font-size: ${rem(14)};
user-select: none;
img {
height: 1em;
filter: var(--graph-copyright-logo-filter);
vertical-align: middle;
}
}
`;
const Loading = styled.div`
${size('100%', '100%')}
display: flex;
flex-direction: column;
justify-content: center;
align-items: center;
overscroll-behavior: none;
cursor: progress;
font-size: ${rem(16)};
line-height: ${rem(60)};
`;
export type GraphRef = {
export(type: 'svg' | 'png'): void;
changeGraph(name: string): void;
search(value: string): void;
select(item: SearchItem): void;
setSelectItems(data: Theobj): void;
setLoadings(data: boolean): void;
showModelProperties(): void;
showNodeDocumentation(data: Theobj): void;
};
interface Theobj {
[propname: string]: unknown;
}
type GraphProps = {
files: FileList | File[] | null;
uploader: JSX.Element;
showAttributes: boolean;
showInitializers: boolean;
showNames: boolean;
horizontal: boolean;
isKeepData: boolean;
runs: string[] | undefined;
selectedRuns: string;
onRendered?: () => unknown;
onOpened?: (data: OpenedResult) => unknown;
onSearch?: (data: SearchResult) => unknown;
onShowModelProperties?: (data: Properties) => unknown;
onShowNodeProperties?: (data: Properties) => unknown;
onShowNodeDocumentation?: (data: Documentation) => unknown;
};
const Graph = React.forwardRef<GraphRef, GraphProps>(
(
{
uploader,
showAttributes,
runs,
selectedRuns,
isKeepData,
showInitializers,
showNames,
horizontal,
onRendered,
onOpened,
onSearch,
onShowModelProperties,
onShowNodeProperties,
onShowNodeDocumentation
},
ref
) => {
const {t, i18n} = useTranslation('graph');
const language: string = i18n.language;
const theme = useTheme();
const [ready, setReady] = useState(false);
const [rendered, setRendered] = useState(false);
const [loading, setLoading] = useState(true);
const [item, setSelectItem] = useState<Theobj | null>();
const [isExpend, setIsExpend] = useState(0);
const [isRetract, setIsretract] = useState(0);
const [modelDatas, setModelDatas] = useState<Theobj>();
const [allModelDatas, setAllModelDatas] = useState<Theobj>();
const [selectNodeId, setSelectNodeId] = useState();
const [searchNodeId, setSearchNodeId] = useState<Theobj>();
const iframe = useRef<HTMLIFrameElement>(null);
const handler = useCallback(
(event: MessageEvent) => {
if (event.data) {
const {type, data} = event.data;
switch (type) {
case 'status':
switch (data) {
case 'ready':
return setReady(true);
case 'loading':
// return setLoading(true);
return 1;
case 'rendered':
setLoading(false);
setRendered(true);
// changeSvg()
onRendered?.();
return;
}
return;
case 'opened':
return onOpened?.(data);
case 'search':
return onSearch?.(data);
case 'cancel':
return setLoading(false);
case 'error':
toast.error(data);
setLoading(false);
return;
case 'show-model-properties':
return onShowModelProperties?.(data);
case 'show-node-properties':
return onShowNodeProperties?.(data);
case 'show-node-documentation':
return onShowNodeDocumentation?.(data);
case 'nodeId':
return setSelectNodeId?.(data);
case 'selectItem':
return setSelectItem?.(data);
}
}
},
[onRendered, onOpened, onSearch, onShowModelProperties, onShowNodeProperties, onShowNodeDocumentation]
);
const dispatch = useCallback((type: string, data?: unknown) => {
iframe.current?.contentWindow?.postMessage(
{
type,
data
},
IFRAME_HOST
);
}, []);
useEffect(() => {
keydown();
}, []);
useEffect(() => {
window.addEventListener('message', handler);
dispatch('ready');
return () => {
window.removeEventListener('message', handler);
};
}, [handler, dispatch]);
useEffect(() => {
if (selectedRuns) {
setLoading(true);
getGraph();
// getAllGraph()
}
}, [selectedRuns]);
useEffect(() => {
if (isExpend) {
// debugger
setLoading(true);
const refresh = false;
const expand_all = true;
fetcher(
'/graph/graph' + `?run=${selectedRuns}` + `&refresh=${refresh}` + `&expand_all=${expand_all}`
).then((res: Theobj) => {
setSelectItem(null);
setModelDatas(res);
});
}
}, [isExpend]);
useEffect(() => {
if (isRetract) {
// debugger
setLoading(true);
const refresh = true;
const expand_all = false;
fetcher(
'/graph/graph' + `?run=${selectedRuns}` + `&refresh=${refresh}` + `&expand_all=${expand_all}`
).then((res: Theobj) => {
setSelectItem(null);
setModelDatas(res);
});
}
}, [isRetract]);
useEffect(() => {
if (ready) {
dispatch('change-select', item);
}
}, [dispatch, item, ready]);
useEffect(() => {
if (!allModelDatas) {
return;
}
if (ready) {
dispatch('change-allGraph', allModelDatas);
}
}, [dispatch, allModelDatas, ready]);
useEffect(() => {
if (!modelDatas) {
return;
}
if (ready) {
dispatch('change-graph', modelDatas);
}
}, [dispatch, modelDatas, ready]);
useEffect(() => {
if (!selectNodeId) {
return;
}
// debugger;
setLoading(true);
const selectNodeIds: Theobj = selectNodeId;
fetcher(
'/graph/manipulate' +
`?run=${selectedRuns}` +
`&nodeid=${selectNodeIds.nodeId}` +
`&expand=${selectNodeIds.expand}` +
`&keep_state=${isKeepData}`
).then((res: Theobj) => {
setModelDatas(res);
});
}, [selectNodeId]);
useEffect(() => {
if (!searchNodeId) {
return;
}
// debugger
setLoading(true);
const searchNodeIds: Theobj = searchNodeId;
const is_node = searchNodeIds.type === 'node' ? true : false;
fetcher(
'/graph/search' +
`?run=${selectedRuns}` +
`&nodeid=${searchNodeIds.name}` +
`&keep_state=${isKeepData}` +
`&is_node=${is_node}`
).then((res: Theobj) => {
setModelDatas(res);
});
}, [searchNodeId]);
useEffect(
() => (ready && dispatch('toggle-attributes', showAttributes)) || undefined,
[dispatch, showAttributes, ready]
);
useEffect(
() => (ready && dispatch('toggle-initializers', showInitializers)) || undefined,
[dispatch, showInitializers, ready]
);
useEffect(() => (ready && dispatch('toggle-names', showNames)) || undefined, [dispatch, showNames, ready]);
useEffect(
() => (ready && dispatch('toggle-direction', horizontal)) || undefined,
[dispatch, horizontal, ready]
);
useEffect(() => (ready && dispatch('toggle-theme', theme)) || undefined, [dispatch, theme, ready]);
useEffect(() => (ready && dispatch('toggle-Language', language)) || undefined, [dispatch, language, ready]);
useImperativeHandle(ref, () => ({
export(type) {
dispatch('export', type);
},
changeGraph(name) {
dispatch('change-graph', name);
},
search(value) {
dispatch('search', value);
},
setSelectItems(data: Theobj) {
setSelectItem(data);
},
setLoadings(data: boolean) {
setLoading(data);
},
select(item) {
const a = document.querySelector('iframe') as HTMLIFrameElement;
const documents = a.contentWindow?.document as Document;
if (item.type === 'node') {
for (const node of documents.getElementsByClassName('cluster')) {
if (node.getAttribute('id') === `node-${item.name}`) {
dispatch('select', item);
return;
}
}
for (const node of documents.getElementsByClassName('node')) {
if (node.getAttribute('id') === `node-${item.name}`) {
dispatch('select', item);
return;
}
}
} else if (item.type === 'input') {
for (const node of documents.getElementsByClassName('edge-path')) {
if (node.getAttribute('id') === `edge-${item.name}`) {
dispatch('select', item);
return;
}
}
}
setSelectItem(item);
setSearchNodeId(item);
},
showModelProperties() {
dispatch('show-model-properties');
},
showNodeDocumentation(data) {
dispatch('show-node-documentation', data);
}
}));
const keydown = () => {
document.addEventListener('keydown', e => {
if (
e.code === 'MetaLeft' ||
e.code === 'MetaRight' ||
e.code === 'ControlLeft' ||
e.code === 'AltLeft' ||
e.code === 'AltRight'
) {
dispatch('isAlt', true);
}
});
document.addEventListener('keyup', e => {
if (
e.code === 'MetaLeft' ||
e.code === 'MetaRight' ||
e.code === 'ControlLeft' ||
e.code === 'AltLeft' ||
e.code === 'AltRight'
) {
dispatch('isAlt', false);
}
});
};
const getGraph = async () => {
const refresh = true;
const expand_all = false;
const result = await fetcher(
'/graph/graph' + `?run=${selectedRuns}` + `&refresh=${refresh}` + `&expand_all=${expand_all}`
);
const allResult = await fetcher('/graph/get_all_nodes' + `?run=${selectedRuns}`);
// const allResult = await fetcher('/graph/graph' + `?run=${selectedRuns}`);
setSelectItem(null);
if (result) setModelDatas(result);
if (allResult) setAllModelDatas(allResult);
};
const content = useMemo(() => {
if (loading) {
return (
<Loading>
<HashLoader size="60px" color={primaryColor} />
</Loading>
);
}
return null;
}, [loading]);
const uploaderContent = useMemo(() => {
if (!runs && !loading) {
return uploader;
}
}, [runs, loading, uploader]);
const svgContent = useMemo(() => {
return (
<Content>
<iframe
ref={iframe}
src={PUBLIC_PATH + netron}
frameBorder={0}
scrolling="yes"
marginWidth={0}
marginHeight={0}
></iframe>
<a
className="powered-by"
href="https://github.com/lutzroeder/netron"
target="_blank"
rel="noreferrer"
>
Powered by <img src={PUBLIC_PATH + logo} alt="netron" />
</a>
</Content>
);
}, [rendered]);
return (
<Wrapper>
{content}
{uploaderContent}
<RenderContent show={!loading && rendered}>
<Toolbox
items={[
{
icon: 'zoom-in',
tooltip: t('graph:zoom-in'),
onClick: () => dispatch('zoom-in')
},
{
icon: 'zoom-out',
tooltip: t('graph:zoom-out'),
onClick: () => dispatch('zoom-out')
},
{
icon: 'restore-size',
tooltip: t('expend-size'),
onClick: () => {
const id = isExpend + 1;
setIsExpend(id);
}
},
{
icon: 'shrink',
tooltip: t('restore-size'),
onClick: () => {
const id = isRetract + 1;
setIsretract(id);
}
}
]}
reversed
tooltipPlacement="bottom"
/>
{svgContent}
</RenderContent>
</Wrapper>
);
}
);
Graph.displayName = 'Graph';
export default Graph;
/**
* Copyright 2020 Baidu Inc. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
import Aside, {AsideSection} from '~/components/Aside';
import type {Documentation, OpenedResult, Properties, SearchItem, SearchResult} from '~/resource/graph/types';
import GraphComponent, {GraphRef} from '~/components/GraphPage/GraphDynamic';
import React, {FunctionComponent, useCallback, useEffect, useMemo, useRef, useState} from 'react';
import Select, {SelectProps} from '~/components/Select';
import {actions, selectors} from '~/store';
import {primaryColor, rem, size} from '~/utils/style';
import {useDispatch, useSelector} from 'react-redux';
import Check from '~/components/Check';
import Button from '~/components/Button';
import Checkbox from '~/components/Checkbox';
import Content from '~/components/Content';
import Field from '~/components/Field';
import ModelPropertiesDialog from '~/components/GraphPage/ModelPropertiesDialog';
import NodeDocumentationSidebar from '~/components/GraphPage/NodeDocumentationSidebar';
import NodePropertiesSidebar from '~/components/GraphPage/NodePropertiesSidebar';
import RadioButton from '~/components/RadioButton';
import RadioGroup from '~/components/RadioGroup';
import Search from '~/components/GraphPage/Search';
import Title from '~/components/Title';
import Uploader from '~/components/GraphPage/Uploader';
import styled from 'styled-components';
import {fetcher} from '~/utils/fetch';
import {useTranslation} from 'react-i18next';
const FullWidthButton = styled(Button)`
width: 100%;
`;
const FullWidthSelect = styled<React.FunctionComponent<SelectProps<NonNullable<OpenedResult['selected']>>>>(Select)`
width: 100%;
`;
const ExportButtonWrapper = styled.div`
display: flex;
justify-content: space-between;
> * {
flex: 1 1 auto;
&:not(:last-child) {
margin-right: ${rem(20)};
}
}
`;
// TODO: better way to auto fit height
const SearchSection = styled(AsideSection)`
max-height: calc(100% - ${rem(40)});
display: flex;
flex-direction: column;
&:not(:last-child) {
padding-bottom: 0;
}
`;
const Graph: FunctionComponent = () => {
const {t} = useTranslation(['graph', 'common']);
const storeDispatch = useDispatch();
const storeModel = useSelector(selectors.graph.model);
const graph = useRef<GraphRef>(null);
const file = useRef<HTMLInputElement>(null);
const [files, setFiles] = useState<FileList | File[] | null>(storeModel);
const [filesId, setFilesId] = useState(0);
const [runs, setRuns] = useState<string[]>();
const [selectedRuns, setSelectedRuns] = useState<string>('');
const [isKeepData, setIsKeepData] = useState(false);
const setModelFile = useCallback(
(f: FileList | File[]) => {
storeDispatch(actions.graph.setModel(f));
setFiles(f);
},
[storeDispatch]
);
const onClickFile = useCallback(() => {
if (file.current) {
file.current.value = '';
file.current.click();
}
}, []);
const onChangeFile = (e: React.ChangeEvent<HTMLInputElement>) => {
const target = e.target as EventTarget & HTMLInputElement;
const file: FileList | null = target.files as FileList;
if (file[0].name.split('.')[1] !== 'pdmodel') {
alert('该页面只能解析paddle的模型,如需解析请跳转网络结构静态图页面');
return;
}
if (target && target.files && target.files.length) {
fileUploader(target.files);
}
};
const fileUploader = (files: FileList) => {
const formData = new FormData();
// 将文件转二进制
formData.append('file', files[0]);
formData.append('filename', files[0].name);
fetcher('/graph/upload', {
method: 'POST',
body: formData
}).then(
res => {
// debugger
const newFilesId = filesId + 1;
setFilesId(newFilesId);
},
res => {
// debugger
const newFilesId = filesId + 1;
setFilesId(newFilesId);
}
);
};
const [modelGraphs, setModelGraphs] = useState<OpenedResult['graphs']>([]);
const [selectedGraph, setSelectedGraph] = useState<NonNullable<OpenedResult['selected']>>('');
const setOpenedModel = useCallback((data: OpenedResult) => {
setModelGraphs(data.graphs);
setSelectedGraph(data.selected || '');
}, []);
const changeGraph = useCallback((name: string) => {
setSelectedGraph(name);
graph.current?.changeGraph(name);
}, []);
const [search, setSearch] = useState('');
const [searching, setSearching] = useState(false);
const [searchResult, setSearchResult] = useState<SearchResult>({text: '', result: []});
const onSearch = useCallback((value: string) => {
setSearch(value);
graph.current?.search(value);
}, []);
const onSelect = useCallback((item: SearchItem) => {
setSearch(item.name);
graph.current?.select(item);
}, []);
const [showAttributes, setShowAttributes] = useState(false);
const [showInitializers, setShowInitializers] = useState(true);
const [showNames, setShowNames] = useState(false);
const [horizontal, setHorizontal] = useState(false);
const [modelData, setModelData] = useState<Properties | null>(null);
const [nodeData, setNodeData] = useState<Properties | null>(null);
const [nodeDocumentation, setNodeDocumentation] = useState<Documentation | null>(null);
useEffect(() => {
// debugger
fetcher('/graph_runs').then((res: unknown) => {
const result = res as string[];
setRuns(result);
setSelectedRuns(result[0]);
});
}, [filesId]);
useEffect(() => {
setSearch('');
setSearchResult({text: '', result: []});
}, [files, showAttributes, showInitializers, showNames]);
const bottom = useMemo(
() =>
searching ? null : (
<FullWidthButton type="primary" rounded onClick={onClickFile}>
{t('graph:change-model')}
</FullWidthButton>
),
[t, onClickFile, searching]
);
const [rendered, setRendered] = useState(false);
const content = (runs: string) => {
return <div>{runs}</div>;
// return (<p>Content</p>)
};
const aside = useMemo(() => {
if (!rendered) {
return null;
}
if (nodeDocumentation) {
return (
<Aside width={rem(360)}>
<NodeDocumentationSidebar data={nodeDocumentation} onClose={() => setNodeDocumentation(null)} />
</Aside>
);
}
if (nodeData) {
return (
<Aside width={rem(360)}>
<NodePropertiesSidebar
data={nodeData}
onClose={() => setNodeData(null)}
showNodeDocumentation={() => graph.current?.showNodeDocumentation(nodeData)}
/>
</Aside>
);
}
return (
<Aside bottom={bottom}>
<SearchSection>
<Search
text={search}
data={searchResult}
onChange={onSearch}
onSelect={onSelect}
onActive={() => setSearching(true)}
onDeactive={() => setSearching(false)}
/>
</SearchSection>
{!searching && (
<>
<AsideSection>
<FullWidthButton onClick={() => graph.current?.showModelProperties()}>
{t('graph:model-properties')}
</FullWidthButton>
</AsideSection>
{modelGraphs.length > 1 && (
<AsideSection>
<Field label={t('graph:subgraph')}>
<FullWidthSelect list={modelGraphs} value={selectedGraph} onChange={changeGraph} />
</Field>
</AsideSection>
)}
<AsideSection>
<Field label={t('graph:display-data')}>
<div>
<Checkbox checked={showAttributes} onChange={setShowAttributes}>
{t('graph:show-attributes')}
</Checkbox>
</div>
<div>
<Checkbox checked={showInitializers} onChange={setShowInitializers}>
{t('graph:show-initializers')}
</Checkbox>
</div>
<div>
<Checkbox checked={showNames} onChange={setShowNames}>
{t('graph:show-node-names')}
</Checkbox>
</div>
<div>
<Checkbox checked={isKeepData} onChange={setIsKeepData}>
{/* {'保持展开状态'} */}
{t('graph:keep-expanded')}
</Checkbox>
</div>
</Field>
</AsideSection>
<AsideSection>
<Field label={t('graph:direction')}>
<RadioGroup value={horizontal} onChange={setHorizontal}>
<RadioButton value={false}>{t('graph:vertical')}</RadioButton>
<RadioButton value={true}>{t('graph:horizontal')}</RadioButton>
</RadioGroup>
</Field>
</AsideSection>
<AsideSection>
<Field label={t('graph:export-file')}>
<ExportButtonWrapper>
<Button onClick={() => graph.current?.export('png')}>
{t('graph:export-png')}
</Button>
<Button onClick={() => graph.current?.export('svg')}>
{t('graph:export-svg')}
</Button>
</ExportButtonWrapper>
</Field>
</AsideSection>
<AsideSection>
<Field label={t('graph:Choose-model')}>
<div className="run-list">
{runs &&
runs.map((run: string, index: number) => (
<div key={index}>
<Check
checked={selectedRuns === run ? true : false}
value={run}
title={run}
onChange={(value: string) => {
setSelectedRuns(run);
}}
>
{/* <Popover content={content(run)}> */}
<span className="run-item">
{/* <i style={{backgroundColor: run.colors[0]}}></i> */}
{run.split('/')[run.split('/').length - 1]}
</span>
{/* </Popover> */}
</Check>
</div>
))}
</div>
</Field>
</AsideSection>
</>
)}
</Aside>
);
}, [
t,
bottom,
search,
searching,
searchResult,
selectedRuns,
modelGraphs,
selectedGraph,
changeGraph,
onSearch,
onSelect,
showAttributes,
showInitializers,
showNames,
horizontal,
rendered,
nodeData,
nodeDocumentation
]);
const uploader = useMemo(
() => <Uploader onClickUpload={onClickFile} onDropFiles={setModelFile} />,
[onClickFile, setModelFile]
);
return (
<>
<Title>{t('common:graph')}</Title>
<ModelPropertiesDialog data={modelData} onClose={() => setModelData(null)} />
<Content aside={aside}>
<GraphComponent
ref={graph}
files={files}
uploader={uploader}
showAttributes={showAttributes}
showInitializers={showInitializers}
showNames={showNames}
isKeepData={isKeepData}
horizontal={horizontal}
selectedRuns={selectedRuns}
onRendered={() => setRendered(true)}
onOpened={setOpenedModel}
onSearch={data => setSearchResult(data)}
onShowModelProperties={data => setModelData(data)}
runs={runs}
onShowNodeProperties={data => {
setNodeData(data);
setNodeDocumentation(null);
}}
onShowNodeDocumentation={data => setNodeDocumentation(data)}
/>
<input
ref={file}
type="file"
multiple={false}
onChange={onChangeFile}
style={{
display: 'none'
}}
/>
</Content>
</>
);
};
export default Graph;
...@@ -16,7 +16,7 @@ ...@@ -16,7 +16,7 @@
import Aside, {AsideSection} from '~/components/Aside'; import Aside, {AsideSection} from '~/components/Aside';
import type {Documentation, OpenedResult, Properties, SearchItem, SearchResult} from '~/resource/graph/types'; import type {Documentation, OpenedResult, Properties, SearchItem, SearchResult} from '~/resource/graph/types';
import GraphComponent, {GraphRef} from '~/components/GraphPage/Graph'; import GraphComponent, {GraphRef} from '~/components/GraphPage/GraphStatic';
import React, {FunctionComponent, useCallback, useEffect, useMemo, useRef, useState} from 'react'; import React, {FunctionComponent, useCallback, useEffect, useMemo, useRef, useState} from 'react';
import Select, {SelectProps} from '~/components/Select'; import Select, {SelectProps} from '~/components/Select';
import {actions, selectors} from '~/store'; import {actions, selectors} from '~/store';
...@@ -311,7 +311,9 @@ const Graph: FunctionComponent = () => { ...@@ -311,7 +311,9 @@ const Graph: FunctionComponent = () => {
horizontal={horizontal} horizontal={horizontal}
onRendered={() => setRendered(true)} onRendered={() => setRendered(true)}
onOpened={setOpenedModel} onOpened={setOpenedModel}
onSearch={data => setSearchResult(data)} onSearch={data => {
setSearchResult(data);
}}
onShowModelProperties={data => setModelData(data)} onShowModelProperties={data => setModelData(data)}
onShowNodeProperties={data => { onShowNodeProperties={data => {
setNodeData(data); setNodeData(data);
......
...@@ -75,8 +75,18 @@ const routes: Route[] = [ ...@@ -75,8 +75,18 @@ const routes: Route[] = [
}, },
{ {
id: Pages.Graph, id: Pages.Graph,
path: '/graph', children: [
component: React.lazy(() => import('~/pages/graph')) {
id: 'graphDynamic',
path: '/graphDynamic',
component: React.lazy(() => import('~/pages/graphDynamic'))
},
{
id: 'graphStatic',
path: '/graphStatic',
component: React.lazy(() => import('~/pages/graphStatic'))
}
]
}, },
{ {
id: Pages.Histogram, id: Pages.Histogram,
......
<?xml version="1.0" encoding="UTF-8"?>
<svg width="16px" height="16px" viewBox="0 0 16 16" version="1.1" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink">
<title>收起全部</title>
<defs>
<rect id="path-1" x="0" y="0" width="16" height="16"></rect>
</defs>
<g id="页面-1" stroke="none" stroke-width="1" fill="none" fill-rule="evenodd">
<g id="1.2网格展开,右下角出现缩略地图" transform="translate(-698.000000, -91.000000)">
<g id="形状结合" transform="translate(698.000000, 91.000000)">
<mask id="mask-2" fill="white">
<use xlink:href="#path-1"></use>
</mask>
<use id="蒙版" fill="#D8D8D8" opacity="0" xlink:href="#path-1"></use>
<path d="M7.56666006,9.03333994 C7.80600238,8.79399762 8.20070348,8.80070348 8.44831908,9.04831908 L8.44831908,9.04831908 L13.2085352,13.8085352 C13.4561508,14.0561508 13.4628566,14.4508519 13.2235143,14.6901942 C12.984172,14.9295365 12.5894709,14.9228306 12.3418553,14.6752151 L12.3418553,14.6752151 L8.02995827,10.363318 L3.67312385,14.7201525 C3.43385086,14.9594254 3.03908042,14.9527889 2.79146483,14.7051733 C2.54384923,14.4575577 2.53721271,14.0627873 2.77648569,13.8235143 L2.77648569,13.8235143 Z M2.79146483,1.20853517 C3.03908042,0.96091958 3.433753,0.954185191 3.67306012,1.19349231 L3.67306012,1.19349231 L8.03008573,5.55051792 L12.3419827,1.2386209 C12.5895983,0.991005306 12.9842709,0.984270917 13.223578,1.22357804 C13.4628852,1.46288515 13.4561508,1.85755773 13.2085352,2.10517333 L13.2085352,2.10517333 L8.44831908,6.86538942 C8.20070348,7.11300501 7.8060309,7.1197394 7.56672379,6.88043229 L7.56672379,6.88043229 L2.77642196,2.09013046 C2.53711485,1.85082334 2.54384923,1.45615077 2.79146483,1.20853517 Z" fill="#999999" mask="url(#mask-2)"></path>
</g>
</g>
</g>
</svg>
\ No newline at end of file
/**
* Copyright 2020 Baidu Inc. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
export default {status: 0, msg: '', data: ['test_add_graph/', 'test_add_graph/test1']};
...@@ -17,9 +17,8 @@ ...@@ -17,9 +17,8 @@
// cSpell:words actived nextcode // cSpell:words actived nextcode
const view = require('./view'); const view = require('./view');
const view2 = require('./view2');
const host = {}; const host = {};
host.BrowserHost = class { host.BrowserHost = class {
constructor() { constructor() {
window.eval = () => { window.eval = () => {
...@@ -54,7 +53,6 @@ host.BrowserHost = class { ...@@ -54,7 +53,6 @@ host.BrowserHost = class {
this._view = view; this._view = view;
return Promise.resolve(); return Promise.resolve();
} }
start() { start() {
window.addEventListener( window.addEventListener(
'message', 'message',
...@@ -64,12 +62,19 @@ host.BrowserHost = class { ...@@ -64,12 +62,19 @@ host.BrowserHost = class {
const type = originalData.type; const type = originalData.type;
const data = originalData.data; const data = originalData.data;
switch (type) { switch (type) {
// 在此书添加一个this._view的事件传递Graph页面过来的数据
case 'change-files': case 'change-files':
return this._changeFiles(data); return this._changeFiles(data);
case 'zoom-in': case 'zoom-in':
return this._view.zoomIn(); return this._view.zoomIn();
case 'zoom-out': case 'zoom-out':
return this._view.zoomOut(); return this._view.zoomOut();
case 'select-item':
return this._view.selectItem(data);
case 'toggle-Language':
return this._view.toggleLanguage(data);
case 'isAlt':
return this._view.changeAlt(data);
case 'zoom-reset': case 'zoom-reset':
return this._view.resetZoom(); return this._view.resetZoom();
case 'toggle-attributes': case 'toggle-attributes':
...@@ -78,6 +83,8 @@ host.BrowserHost = class { ...@@ -78,6 +83,8 @@ host.BrowserHost = class {
return this._view.toggleInitializers(data); return this._view.toggleInitializers(data);
case 'toggle-names': case 'toggle-names':
return this._view.toggleNames(data); return this._view.toggleNames(data);
case 'toggle-KeepData':
return this._view.toggleKeepData(data);
case 'toggle-direction': case 'toggle-direction':
return this._view.toggleDirection(data); return this._view.toggleDirection(data);
case 'toggle-theme': case 'toggle-theme':
...@@ -86,6 +93,10 @@ host.BrowserHost = class { ...@@ -86,6 +93,10 @@ host.BrowserHost = class {
return this._view.export(`${document.title}.${data}`); return this._view.export(`${document.title}.${data}`);
case 'change-graph': case 'change-graph':
return this._view.changeGraph(data); return this._view.changeGraph(data);
case 'change-allGraph':
return this._view.changeAllGrap(data);
case 'change-select':
return this._view.changeSelect(data);
case 'search': case 'search':
return this._view.find(data); return this._view.find(data);
case 'select': case 'select':
...@@ -116,8 +127,19 @@ host.BrowserHost = class { ...@@ -116,8 +127,19 @@ host.BrowserHost = class {
} }
status(status) { status(status) {
// 反传回去
this.message('status', status); this.message('status', status);
} }
selectNodeId(nodeInfo) {
// 反传回去
console.log('节点点击事件触发了', nodeInfo);
this.message('nodeId', nodeInfo);
}
selectItems(item) {
// 反传回去
console.log('节点点击事件触发了', item);
this.message('selectItem', item);
}
error(message, detail) { error(message, detail) {
this.message('error', (message === 'Error' ? '' : message + ' ') + detail); this.message('error', (message === 'Error' ? '' : message + ' ') + detail);
...@@ -176,6 +198,7 @@ host.BrowserHost = class { ...@@ -176,6 +198,7 @@ host.BrowserHost = class {
} }
_changeFiles(files) { _changeFiles(files) {
console.log('files', files);
if (files && files.length) { if (files && files.length) {
files = Array.from(files); files = Array.from(files);
const file = files.find(file => this._view.accept(file.name)); const file = files.find(file => this._view.accept(file.name));
...@@ -498,4 +521,15 @@ class BrowserFileContext { ...@@ -498,4 +521,15 @@ class BrowserFileContext {
} }
} }
window.__view__ = new view.View(new host.BrowserHost()); function getCaption(obj) {
let index = obj.lastIndexOf('/'); //获取-后边的字符串
let newObj = obj.substring(index + 1, obj.length);
return newObj;
}
const hash = getCaption(document.referrer);
console.log('hash', hash);
if (hash === 'graphStatic') {
window.__view__ = new view2.View(new host.BrowserHost());
} else {
window.__view__ = new view.View(new host.BrowserHost());
}
/**
* Copyright 2020 Baidu Inc. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
export default {
leaf_nodes: [
{
name: 'layer',
schema: {
category: 'container'
}
},
{
name: 'layerlist',
schema: {
category: 'container'
}
},
{
name: 'parameterlist',
schema: {
category: 'container'
}
},
{
name: 'layerdict',
schema: {
category: 'container'
}
},
{
name: 'conv1d',
schema: {
category: 'conv'
}
},
{
name: 'conv1dtranspose',
schema: {
category: 'conv'
}
},
{
name: 'conv2d',
schema: {
category: 'conv'
}
},
{
name: 'conv2dtranspose',
schema: {
category: 'conv'
}
},
{
name: 'conv3d',
schema: {
category: 'conv'
}
},
{
name: 'conv3dtranspose',
schema: {
category: 'conv'
}
},
{
name: 'adaptiveavgpool1d',
schema: {
category: 'pool'
}
},
{
name: 'adaptiveavgpool2d',
schema: {
category: 'pool'
}
},
{
name: 'adaptiveavgpool3d',
schema: {
category: 'pool'
}
},
{
name: 'adaptivemaxpool1d',
schema: {
category: 'pool'
}
},
{
name: 'adaptivemaxpool2d',
schema: {
category: 'pool'
}
},
{
name: 'adaptivemaxpool3d',
schema: {
category: 'pool'
}
},
{
name: 'avgpool1d',
schema: {
category: 'pool'
}
},
{
name: 'avgpool2d',
schema: {
category: 'pool'
}
},
{
name: 'avgpool3d',
schema: {
category: 'pool'
}
},
{
name: 'maxpool1d',
schema: {
category: 'pool'
}
},
{
name: 'maxpool2d',
schema: {
category: 'pool'
}
},
{
name: 'maxpool3d',
schema: {
category: 'pool'
}
},
{
name: 'maxunpool1d',
schema: {
category: 'pool'
}
},
{
name: 'maxunpool2d',
schema: {
category: 'pool'
}
},
{
name: 'maxunpool3d',
schema: {
category: 'pool'
}
},
{
name: 'pad1d',
schema: {
category: 'pad'
}
},
{
name: 'pad2d',
schema: {
category: 'pad'
}
},
{
name: 'pad3d',
schema: {
category: 'pad'
}
},
{
name: 'zeropad2d',
schema: {
category: 'pad'
}
},
{
name: 'celu',
schema: {
category: 'activation'
}
},
{
name: 'elu',
schema: {
category: 'activation'
}
},
{
name: 'gelu',
schema: {
category: 'activation'
}
},
{
name: 'hardshrink',
schema: {
category: 'activation'
}
},
{
name: 'hardsigmoid',
schema: {
category: 'activation'
}
},
{
name: 'hardswish',
schema: {
category: 'activation'
}
},
{
name: 'hardtanh',
schema: {
category: 'activation'
}
},
{
name: 'leakyrelu',
schema: {
category: 'activation'
}
},
{
name: 'logsigmoid',
schema: {
category: 'activation'
}
},
{
name: 'logsoftmax',
schema: {
category: 'activation'
}
},
{
name: 'maxout',
schema: {
category: 'activation'
}
},
{
name: 'prelu',
schema: {
category: 'activation'
}
},
{
name: 'relu',
schema: {
category: 'activation'
}
},
{
name: 'relu6',
schema: {
category: 'activation'
}
},
{
name: 'selu',
schema: {
category: 'activation'
}
},
{
name: 'sigmoid',
schema: {
category: 'activation'
}
},
{
name: 'silu',
schema: {
category: 'activation'
}
},
{
name: 'softmax',
schema: {
category: 'activation'
}
},
{
name: 'softplus',
schema: {
category: 'activation'
}
},
{
name: 'softshrink',
schema: {
category: 'activation'
}
},
{
name: 'softsign',
schema: {
category: 'activation'
}
},
{
name: 'swish',
schema: {
category: 'activation'
}
},
{
name: 'mish',
schema: {
category: 'activation'
}
},
{
name: 'tanh',
schema: {
category: 'activation'
}
},
{
name: 'tanhshrink',
schema: {
category: 'activation'
}
},
{
name: 'thresholdedrelu',
schema: {
category: 'activation'
}
},
{
name: 'batchnorm',
schema: {
category: 'normalization'
}
},
{
name: 'batchnorm1d',
schema: {
category: 'normalization'
}
},
{
name: 'batchnorm2d',
schema: {
category: 'normalization'
}
},
{
name: 'batchnorm3d',
schema: {
category: 'normalization'
}
},
{
name: 'groupnorm',
schema: {
category: 'normalization'
}
},
{
name: 'instancenorm1d',
schema: {
category: 'normalization'
}
},
{
name: 'instancenorm2d',
schema: {
category: 'normalization'
}
},
{
name: 'instancenorm3d',
schema: {
category: 'normalization'
}
},
{
name: 'layernorm',
schema: {
category: 'normalization'
}
},
{
name: 'localresponsenorm',
schema: {
category: 'normalization'
}
},
{
name: 'spectralnorm',
schema: {
category: 'normalization'
}
},
{
name: 'syncbatchnorm',
schema: {
category: 'normalization'
}
},
{
name: 'alphadropout',
schema: {
category: 'normalization'
}
},
{
name: 'dropout',
schema: {
category: 'normalization'
}
},
{
name: 'dropout2d',
schema: {
category: 'normalization'
}
},
{
name: 'dropout3d',
schema: {
category: 'normalization'
}
},
{
name: 'birnn',
schema: {
category: 'sequence'
}
},
{
name: 'gru',
schema: {
category: 'sequence'
}
},
{
name: 'grucell',
schema: {
category: 'sequence'
}
},
{
name: 'lstm',
schema: {
category: 'sequence'
}
},
{
name: 'lstmcell',
schema: {
category: 'sequence'
}
},
{
name: 'rnn',
schema: {
category: 'sequence'
}
},
{
name: 'rnncellbase',
schema: {
category: 'sequence'
}
},
{
name: 'simplernn',
schema: {
category: 'sequence'
}
},
{
name: 'simplernncell',
schema: {
category: 'sequence'
}
},
{
name: 'multiheadattention',
schema: {
category: 'sequence'
}
},
{
name: 'transformer',
schema: {
category: 'sequence'
}
},
{
name: 'transformerdecoder',
schema: {
category: 'sequence'
}
},
{
name: 'transformerdecoderlayer',
schema: {
category: 'sequence'
}
},
{
name: 'transformerencoder',
schema: {
category: 'sequence'
}
},
{
name: 'transformerencoderlayer',
schema: {
category: 'sequence'
}
},
{
name: 'linear',
schema: {
category: 'sequence'
}
},
{
name: 'embedding',
schema: {
category: 'sequence'
}
},
{
name: 'bceloss',
schema: {
category: 'tensor'
}
},
{
name: 'bcewithlogitsloss',
schema: {
category: 'tensor'
}
},
{
name: 'crossentropyloss',
schema: {
category: 'tensor'
}
},
{
name: 'ctcloss',
schema: {
category: 'tensor'
}
},
{
name: 'hsigmoidloss',
schema: {
category: 'tensor'
}
},
{
name: 'kldivloss',
schema: {
category: 'tensor'
}
},
{
name: 'l1loss',
schema: {
category: 'tensor'
}
},
{
name: 'marginrankingloss',
schema: {
category: 'tensor'
}
},
{
name: 'mseloss',
schema: {
category: 'tensor'
}
},
{
name: 'nllloss',
schema: {
category: 'tensor'
}
},
{
name: 'smoothl1loss',
schema: {
category: 'tensor'
}
},
{
name: 'pixelshuffle',
schema: {
category: 'shape'
}
},
{
name: 'upsample',
schema: {
category: 'shape'
}
},
{
name: 'upsamplingbilinear2d',
schema: {
category: 'shape'
}
},
{
name: 'upsamplingnearest2d',
schema: {
category: 'shape'
}
},
{
name: 'clipgradbyglobalnorm',
schema: {
category: 'shape'
}
},
{
name: 'clipgradbynorm',
schema: {
category: 'shape'
}
},
{
name: 'clipgradbyvalue',
schema: {
category: 'shape'
}
},
{
name: 'beamsearchdecoder',
schema: {
category: 'shape'
}
},
{
name: 'cosinesimilarity',
schema: {
category: 'shape'
}
},
{
name: 'dynamic_decode',
schema: {
category: 'shape'
}
},
{
name: 'flatten',
schema: {
category: 'shape'
}
},
{
name: 'pairwisedistance',
schema: {
category: 'shape'
}
},
{
name: 'identity',
schema: {
category: 'shape'
}
},
{
name: 'unfold',
schema: {
category: 'shape'
}
},
{
name: 'fold',
schema: {
category: 'shape'
}
},
{
name: 'conv2d_grad',
schema: {
category: 'conv'
}
},
{
name: 'conv2d_transpose_grad',
schema: {
category: 'conv'
}
},
{
name: 'depthwise_conv2d_transpose',
schema: {
category: 'conv'
}
},
{
name: 'depthwise_conv2d_transpose_grad',
schema: {
category: 'conv'
}
},
{
name: 'deformable_conv_grad',
schema: {
category: 'conv'
}
},
{
name: 'depthwise_conv2d',
schema: {
category: 'conv'
}
},
{
name: 'deformable_conv',
schema: {
category: 'conv'
}
},
{
name: 'conv2d_transpose',
schema: {
category: 'conv'
}
},
{
name: 'depthwise_conv2d_grad',
schema: {
category: 'conv'
}
},
{
name: 'pool2d',
schema: {
category: 'pool'
}
},
{
name: 'pool2d_grad',
schema: {
category: 'pool'
}
},
{
name: 'max_pool2d_with_index',
schema: {
category: 'pool'
}
},
{
name: 'max_pool2d_with_index_grad',
schema: {
category: 'pool'
}
},
{
name: 'pad3d_grad',
schema: {
category: 'pad'
}
},
{
name: 'relu6_grad',
schema: {
category: 'activation'
}
},
{
name: 'leaky_relu',
schema: {
category: 'activation'
}
},
{
name: 'leaky_relu_grad',
schema: {
category: 'activation'
}
},
{
name: 'hard_sigmoid',
schema: {
category: 'activation'
}
},
{
name: 'hard_sigmoid_grad',
schema: {
category: 'activation'
}
},
{
name: 'sigmoid_grad',
schema: {
category: 'activation'
}
},
{
name: 'batch_norm',
schema: {
category: 'normalization'
}
},
{
name: 'batch_norm_grad',
schema: {
category: 'normalization'
}
},
{
name: 'sync_batch_norm',
schema: {
category: 'normalization'
}
},
{
name: 'sync_batch_norm_grad',
schema: {
category: 'normalization'
}
},
{
name: 'norm_grad',
schema: {
category: 'normalization'
}
},
{
name: 'p_norm',
schema: {
category: 'normalization'
}
},
{
name: 'p_norm_grad',
schema: {
category: 'normalization'
}
},
{
name: 'group_norm_grad',
schema: {
category: 'normalization'
}
},
{
name: 'squared_l2_norm',
schema: {
category: 'normalization'
}
},
{
name: 'squared_l2_norm_grad',
schema: {
category: 'normalization'
}
},
{
name: 'group_norm',
schema: {
category: 'normalization'
}
},
{
name: 'norm',
schema: {
category: 'normalization'
}
},
{
name: 'rnn_grad',
schema: {
category: 'sequence'
}
},
{
name: 'sequence_mask',
schema: {
category: 'sequence'
}
},
{
name: 'one_hot',
schema: {
category: 'sequence'
}
},
{
name: 'one_hot_v2',
schema: {
category: 'sequence'
}
},
{
name: 'bce_loss',
schema: {
category: 'tensor'
}
},
{
name: 'bce_loss_grad',
schema: {
category: 'tensor'
}
},
{
name: 'huber_loss',
schema: {
category: 'tensor'
}
},
{
name: 'huber_loss_grad',
schema: {
category: 'tensor'
}
},
{
name: 'log_loss',
schema: {
category: 'tensor'
}
},
{
name: 'log_loss_grad',
schema: {
category: 'tensor'
}
},
{
name: 'smooth_l1_loss',
schema: {
category: 'tensor'
}
},
{
name: 'smooth_l1_loss_grad',
schema: {
category: 'tensor'
}
},
{
name: 'elementwise_add',
schema: {
category: 'tensor'
}
},
{
name: 'elementwise_add_grad',
schema: {
category: 'tensor'
}
},
{
name: 'cumsum',
schema: {
category: 'tensor'
}
},
{
name: 'clip',
schema: {
category: 'tensor'
}
},
{
name: 'clip_grad',
schema: {
category: 'tensor'
}
},
{
name: 'greater_equal',
schema: {
category: 'tensor'
}
},
{
name: 'greater_than',
schema: {
category: 'tensor'
}
},
{
name: 'less_equal',
schema: {
category: 'tensor'
}
},
{
name: 'logical_and',
schema: {
category: 'tensor'
}
},
{
name: 'logical_or',
schema: {
category: 'tensor'
}
},
{
name: 'momentum',
schema: {
category: 'tensor'
}
},
{
name: 'reduce_max',
schema: {
category: 'tensor'
}
},
{
name: 'reduce_mean',
schema: {
category: 'tensor'
}
},
{
name: 'reduce_prod',
schema: {
category: 'tensor'
}
},
{
name: 'seed',
schema: {
category: 'tensor'
}
},
{
name: 'sigmoid_cross_entropy_with_logits',
schema: {
category: 'tensor'
}
},
{
name: 'label_smooth',
schema: {
category: 'tensor'
}
},
{
name: 'where_index',
schema: {
category: 'tensor'
}
},
{
name: 'is_empty',
schema: {
category: 'tensor'
}
},
{
name: 'sigmoid_cross_entropy_with_logits_grad',
schema: {
category: 'tensor'
}
},
{
name: 'target_assign',
schema: {
category: 'tensor'
}
},
{
name: 'gradient_accumulator',
schema: {
category: 'tensor'
}
},
{
name: 'size',
schema: {
category: 'tensor'
}
},
{
name: 'where',
schema: {
category: 'tensor'
}
},
{
name: 'elementwise_pow_grad',
schema: {
category: 'tensor'
}
},
{
name: 'argsort',
schema: {
category: 'tensor'
}
},
{
name: 'argsort_grad',
schema: {
category: 'tensor'
}
},
{
name: 'rmsprop',
schema: {
category: 'tensor'
}
},
{
name: 'atan',
schema: {
category: 'tensor'
}
},
{
name: 'atan_grad',
schema: {
category: 'tensor'
}
},
{
name: 'flatten_contiguous_range',
schema: {
category: 'tensor'
}
},
{
name: 'crop',
schema: {
category: 'tensor'
}
},
{
name: 'eye',
schema: {
category: 'tensor'
}
},
{
name: 'matmul',
schema: {
category: 'tensor'
}
},
{
name: 'set_value',
schema: {
category: 'tensor'
}
},
{
name: 'exp',
schema: {
category: 'tensor'
}
},
{
name: 'exp_grad',
schema: {
category: 'tensor'
}
},
{
name: 'square_grad',
schema: {
category: 'tensor'
}
},
{
name: 'log_softmax',
schema: {
category: 'tensor'
}
},
{
name: 'log_softmax_grad',
schema: {
category: 'tensor'
}
},
{
name: 'matmul_grad',
schema: {
category: 'tensor'
}
},
{
name: 'assign_value',
schema: {
category: 'tensor'
}
},
{
name: 'top_k_v2',
schema: {
category: 'tensor'
}
},
{
name: 'arg_max',
schema: {
category: 'tensor'
}
},
{
name: 'cos',
schema: {
category: 'tensor'
}
},
{
name: 'sin',
schema: {
category: 'tensor'
}
},
{
name: 'index_sample',
schema: {
category: 'shape'
}
},
{
name: 'squeeze_grad',
schema: {
category: 'shape'
}
},
{
name: 'squeeze2_grad',
schema: {
category: 'shape'
}
},
{
name: 'stack_grad',
schema: {
category: 'shape'
}
},
{
name: 'tril_triu',
schema: {
category: 'shape'
}
},
{
name: 'unstack',
schema: {
category: 'shape'
}
},
{
name: 'unstack_grad',
schema: {
category: 'shape'
}
},
{
name: 'bilinear_interp_v2',
schema: {
category: 'shape'
}
},
{
name: 'bilinear_interp_v2_grad',
schema: {
category: 'shape'
}
},
{
name: 'nearest_interp_v2',
schema: {
category: 'shape'
}
},
{
name: 'nearest_interp_v2_grad',
schema: {
category: 'shape'
}
},
{
name: 'randperm',
schema: {
category: 'shape'
}
},
{
name: 'sampling_id',
schema: {
category: 'shape'
}
},
{
name: 'bipartite_match',
schema: {
category: 'shape'
}
},
{
name: 'box_coder',
schema: {
category: 'shape'
}
},
{
name: 'density_prior_box',
schema: {
category: 'shape'
}
},
{
name: 'distribute_fpn_proposals',
schema: {
category: 'shape'
}
},
{
name: 'generate_proposals_v2',
schema: {
category: 'shape'
}
},
{
name: 'meshgrid',
schema: {
category: 'shape'
}
},
{
name: 'mine_hard_examples',
schema: {
category: 'shape'
}
},
{
name: 'yolo_box',
schema: {
category: 'shape'
}
},
{
name: 'warpctc',
schema: {
category: 'shape'
}
},
{
name: 'warpctc_grad',
schema: {
category: 'shape'
}
},
{
name: 'iou_similarity',
schema: {
category: 'shape'
}
},
{
name: 'split',
schema: {
category: 'shape'
}
},
{
name: 'flatten2',
schema: {
category: 'shape'
}
},
{
name: 'flatten2_grad',
schema: {
category: 'shape'
}
},
{
name: 'masked_select_grad',
schema: {
category: 'shape'
}
},
{
name: 'strided_slice',
schema: {
category: 'shape'
}
},
{
name: 'prior_box',
schema: {
category: 'shape'
}
},
{
name: 'elementwise_max_grad',
schema: {
category: 'shape'
}
},
{
name: 'not_equal',
schema: {
category: 'shape'
}
},
{
name: 'strided_slice_grad',
schema: {
category: 'shape'
}
},
{
name: 'fill_any_like',
schema: {
category: 'shape'
}
},
{
name: 'hard_swish',
schema: {
category: 'shape'
}
},
{
name: 'hard_swish_grad',
schema: {
category: 'shape'
}
},
{
name: 'expand_v2',
schema: {
category: 'shape'
}
},
{
name: 'expand_v2_grad',
schema: {
category: 'shape'
}
},
{
name: 'flatten_contiguous_range_grad',
schema: {
category: 'shape'
}
},
{
name: 'gather_nd',
schema: {
category: 'shape'
}
},
{
name: 'gather_nd_grad',
schema: {
category: 'shape'
}
},
{
name: 'reciprocal',
schema: {
category: 'shape'
}
},
{
name: 'reciprocal_grad',
schema: {
category: 'shape'
}
},
{
name: 'index_select',
schema: {
category: 'shape'
}
},
{
name: 'roi_align',
schema: {
category: 'shape'
}
},
{
name: 'roi_align_grad',
schema: {
category: 'shape'
}
},
{
name: 'reduce_mean_grad',
schema: {
category: 'shape'
}
},
{
name: 'masked_select',
schema: {
category: 'shape'
}
},
{
name: 'index_select_grad',
schema: {
category: 'shape'
}
},
{
name: 'elementwise_min_grad',
schema: {
category: 'shape'
}
},
{
name: 'fill_constant_batch_size_like',
schema: {
category: 'shape'
}
},
{
name: 'unsqueeze2_grad',
schema: {
category: 'shape'
}
},
{
name: 'unique',
schema: {
category: 'shape'
}
},
{
name: 'expand_as_v2',
schema: {
category: 'shape'
}
},
{
name: 'tile',
schema: {
category: 'shape'
}
},
{
name: 'nearest_interp_grad',
schema: {
category: 'shape'
}
}
],
non_leaf_nodes: [
{
name: 'Layer',
schema: {
category: 'container'
}
},
{
name: 'LayerList',
schema: {
category: 'container'
}
},
{
name: 'ParameterList',
schema: {
category: 'container'
}
},
{
name: 'LayerDict',
schema: {
category: 'container'
}
},
{
name: 'Conv1D',
schema: {
category: 'conv'
}
},
{
name: 'Conv1DTranspose',
schema: {
category: 'conv'
}
},
{
name: 'Conv2D',
schema: {
category: 'conv'
}
},
{
name: 'Conv2DTranspose',
schema: {
category: 'conv'
}
},
{
name: 'Conv3D',
schema: {
category: 'conv'
}
},
{
name: 'Conv3DTranspose',
schema: {
category: 'conv'
}
},
{
name: 'AdaptiveAvgPool1D',
schema: {
category: 'pool'
}
},
{
name: 'AdaptiveAvgPool2D',
schema: {
category: 'pool'
}
},
{
name: 'AdaptiveAvgPool3D',
schema: {
category: 'pool'
}
},
{
name: 'AdaptiveMaxPool1D',
schema: {
category: 'pool'
}
},
{
name: 'AdaptiveMaxPool2D',
schema: {
category: 'pool'
}
},
{
name: 'AdaptiveMaxPool3D',
schema: {
category: 'pool'
}
},
{
name: 'AvgPool1D',
schema: {
category: 'pool'
}
},
{
name: 'AvgPool2D',
schema: {
category: 'pool'
}
},
{
name: 'AvgPool3D',
schema: {
category: 'pool'
}
},
{
name: 'MaxPool1D',
schema: {
category: 'pool'
}
},
{
name: 'MaxPool2D',
schema: {
category: 'pool'
}
},
{
name: 'MaxPool3D',
schema: {
category: 'pool'
}
},
{
name: 'MaxUnPool1D',
schema: {
category: 'pool'
}
},
{
name: 'MaxUnPool2D',
schema: {
category: 'pool'
}
},
{
name: 'MaxUnPool3D',
schema: {
category: 'pool'
}
},
{
name: 'Pad1D',
schema: {
category: 'pad'
}
},
{
name: 'Pad2D',
schema: {
category: 'pad'
}
},
{
name: 'Pad3D',
schema: {
category: 'pad'
}
},
{
name: 'ZeroPad2D',
schema: {
category: 'pad'
}
},
{
name: 'CELU',
schema: {
category: 'activation'
}
},
{
name: 'ELU',
schema: {
category: 'activation'
}
},
{
name: 'GELU',
schema: {
category: 'activation'
}
},
{
name: 'Hardshrink',
schema: {
category: 'activation'
}
},
{
name: 'Hardsigmoid',
schema: {
category: 'activation'
}
},
{
name: 'Hardswish',
schema: {
category: 'activation'
}
},
{
name: 'Hardtanh',
schema: {
category: 'activation'
}
},
{
name: 'LeakyReLU',
schema: {
category: 'activation'
}
},
{
name: 'LogSigmoid',
schema: {
category: 'activation'
}
},
{
name: 'LogSoftmax',
schema: {
category: 'activation'
}
},
{
name: 'Maxout',
schema: {
category: 'activation'
}
},
{
name: 'PReLU',
schema: {
category: 'activation'
}
},
{
name: 'ReLU',
schema: {
category: 'activation'
}
},
{
name: 'ReLU6',
schema: {
category: 'activation'
}
},
{
name: 'SELU',
schema: {
category: 'activation'
}
},
{
name: 'Sigmoid',
schema: {
category: 'activation'
}
},
{
name: 'Silu',
schema: {
category: 'activation'
}
},
{
name: 'Softmax',
schema: {
category: 'activation'
}
},
{
name: 'Softplus',
schema: {
category: 'activation'
}
},
{
name: 'Softshrink',
schema: {
category: 'activation'
}
},
{
name: 'Softsign',
schema: {
category: 'activation'
}
},
{
name: 'Swish',
schema: {
category: 'activation'
}
},
{
name: 'Mish',
schema: {
category: 'activation'
}
},
{
name: 'Tanh',
schema: {
category: 'activation'
}
},
{
name: 'Tanhshrink',
schema: {
category: 'activation'
}
},
{
name: 'ThresholdedReLU',
schema: {
category: 'activation'
}
},
{
name: 'BatchNorm',
schema: {
category: 'normalization'
}
},
{
name: 'BatchNorm1D',
schema: {
category: 'normalization'
}
},
{
name: 'BatchNorm2D',
schema: {
category: 'normalization'
}
},
{
name: 'BatchNorm3D',
schema: {
category: 'normalization'
}
},
{
name: 'GroupNorm',
schema: {
category: 'normalization'
}
},
{
name: 'InstanceNorm1D',
schema: {
category: 'normalization'
}
},
{
name: 'InstanceNorm2D',
schema: {
category: 'normalization'
}
},
{
name: 'InstanceNorm3D',
schema: {
category: 'normalization'
}
},
{
name: 'LayerNorm',
schema: {
category: 'normalization'
}
},
{
name: 'LocalResponseNorm',
schema: {
category: 'normalization'
}
},
{
name: 'SpectralNorm',
schema: {
category: 'normalization'
}
},
{
name: 'SyncBatchNorm',
schema: {
category: 'normalization'
}
},
{
name: 'AlphaDropout',
schema: {
category: 'normalization'
}
},
{
name: 'Dropout',
schema: {
category: 'normalization'
}
},
{
name: 'Dropout2D',
schema: {
category: 'normalization'
}
},
{
name: 'Dropout3D',
schema: {
category: 'normalization'
}
},
{
name: 'BiRNN',
schema: {
category: 'sequence'
}
},
{
name: 'GRU',
schema: {
category: 'sequence'
}
},
{
name: 'GRUCell',
schema: {
category: 'sequence'
}
},
{
name: 'LSTM',
schema: {
category: 'sequence'
}
},
{
name: 'LSTMCell',
schema: {
category: 'sequence'
}
},
{
name: 'RNN',
schema: {
category: 'sequence'
}
},
{
name: 'RNNCellBase',
schema: {
category: 'sequence'
}
},
{
name: 'SimpleRNN',
schema: {
category: 'sequence'
}
},
{
name: 'SimpleRNNCell',
schema: {
category: 'sequence'
}
},
{
name: 'MultiHeadAttention',
schema: {
category: 'sequence'
}
},
{
name: 'Transformer',
schema: {
category: 'sequence'
}
},
{
name: 'TransformerDecoder',
schema: {
category: 'sequence'
}
},
{
name: 'TransformerDecoderLayer',
schema: {
category: 'sequence'
}
},
{
name: 'TransformerEncoder',
schema: {
category: 'sequence'
}
},
{
name: 'TransformerEncoderLayer',
schema: {
category: 'sequence'
}
},
{
name: 'Linear',
schema: {
category: 'sequence'
}
},
{
name: 'Embedding',
schema: {
category: 'sequence'
}
},
{
name: 'BCELoss',
schema: {
category: 'tensor'
}
},
{
name: 'BCEWithLogitsLoss',
schema: {
category: 'tensor'
}
},
{
name: 'CrossEntropyLoss',
schema: {
category: 'tensor'
}
},
{
name: 'CTCLoss',
schema: {
category: 'tensor'
}
},
{
name: 'HSigmoidLoss',
schema: {
category: 'tensor'
}
},
{
name: 'KLDivLoss',
schema: {
category: 'tensor'
}
},
{
name: 'L1Loss',
schema: {
category: 'tensor'
}
},
{
name: 'MarginRankingLoss',
schema: {
category: 'tensor'
}
},
{
name: 'MSELoss',
schema: {
category: 'tensor'
}
},
{
name: 'NLLLoss',
schema: {
category: 'tensor'
}
},
{
name: 'SmoothL1Loss',
schema: {
category: 'tensor'
}
},
{
name: 'PixelShuffle',
schema: {
category: 'shape'
}
},
{
name: 'Upsample',
schema: {
category: 'shape'
}
},
{
name: 'UpsamplingBilinear2D',
schema: {
category: 'shape'
}
},
{
name: 'UpsamplingNearest2D',
schema: {
category: 'shape'
}
},
{
name: 'ClipGradByGlobalNorm',
schema: {
category: 'shape'
}
},
{
name: 'ClipGradByNorm',
schema: {
category: 'shape'
}
},
{
name: 'ClipGradByValue',
schema: {
category: 'shape'
}
},
{
name: 'BeamSearchDecoder',
schema: {
category: 'shape'
}
},
{
name: 'CosineSimilarity',
schema: {
category: 'shape'
}
},
{
name: 'dynamic_decode',
schema: {
category: 'shape'
}
},
{
name: 'Flatten',
schema: {
category: 'shape'
}
},
{
name: 'PairwiseDistance',
schema: {
category: 'shape'
}
},
{
name: 'Identity',
schema: {
category: 'shape'
}
},
{
name: 'Unfold',
schema: {
category: 'shape'
}
},
{
name: 'Fold',
schema: {
category: 'shape'
}
}
]
};
...@@ -547,7 +547,7 @@ sidebar.ModelSidebar = class { ...@@ -547,7 +547,7 @@ sidebar.ModelSidebar = class {
} }
} }
if (this._model._graphs.length > 1) { if (this._model) {
// let graphSelector = new sidebar.SelectView( // let graphSelector = new sidebar.SelectView(
// this._host, // this._host,
// this._model.graphs.map(g => g.name), // this._model.graphs.map(g => g.name),
...@@ -683,23 +683,98 @@ sidebar.FindSidebar = class { ...@@ -683,23 +683,98 @@ sidebar.FindSidebar = class {
const id = item.id; const id = item.id;
const nodesElement = graphElement.getElementById('nodes'); const nodesElement = graphElement.getElementById('nodes');
let nodeElement = nodesElement.firstChild; if (nodesElement) {
while (nodeElement) { let nodeElement = nodesElement.firstChild;
if (nodeElement.id == id) { while (nodeElement) {
selection.push(nodeElement); if (nodeElement.id == id) {
selection.push(nodeElement);
}
nodeElement = nodeElement.nextSibling;
}
}
const clustersElement = graphElement.getElementById('clusters');
if (clustersElement) {
let clusterElement = clustersElement.firstChild;
while (clusterElement) {
if (clusterElement.id == id) {
selection.push(clusterElement);
}
clusterElement = clusterElement.nextSibling;
} }
nodeElement = nodeElement.nextSibling;
} }
const edgePathsElement = graphElement.getElementById('edge-paths'); const edgePathsElement = graphElement.getElementById('edge-paths');
let edgePathElement = edgePathsElement.firstChild; if (edgePathsElement) {
while (edgePathElement) { let edgePathElement = edgePathsElement.firstChild;
if (edgePathElement.id == id) { while (edgePathElement) {
selection.push(edgePathElement); if (edgePathElement.id === id) {
// console.log('edgePathElement',edgePathElement.getAttribute("fromnode"),item);
// if (item.fromnode && edgePathElement.getAttribute("fromnode") === item.fromnode) {
// selection.push(edgePathElement);
// }
// if (item.tonode && edgePathElement.getAttribute("tonode") === item.tonode) {
// selection.push(edgePathElement);
// }
selection.push(edgePathElement);
}
edgePathElement = edgePathElement.nextSibling;
} }
edgePathElement = edgePathElement.nextSibling;
} }
let initializerElement = graphElement.getElementById(id);
if (initializerElement) {
while (initializerElement.parentElement) {
initializerElement = initializerElement.parentElement;
if (initializerElement.id && initializerElement.id.startsWith('node-')) {
selection.push(initializerElement);
break;
}
}
}
if (selection.length > 0) {
return selection;
}
return null;
}
static selection2(item, graphElement) {
const selection = [];
const id = item.id;
const nodesElement = graphElement.getElementById('nodes');
if (nodesElement) {
let nodeElement = nodesElement.firstChild;
while (nodeElement) {
if (nodeElement.id == id) {
selection.push(nodeElement);
}
nodeElement = nodeElement.nextSibling;
}
}
const clustersElement = graphElement.getElementById('clusters');
if (clustersElement) {
let clusterElement = clustersElement.firstChild;
while (clusterElement) {
if (clusterElement.id == id) {
selection.push(clusterElement);
}
clusterElement = clusterElement.nextSibling;
}
}
const edgePathsElement = graphElement.getElementById('edge-paths');
if (edgePathsElement) {
let edgePathElement = edgePathsElement.firstChild;
while (edgePathElement) {
if (edgePathElement.id === id) {
if (item.fromnode && edgePathElement.getAttribute('fromnode') === item.fromnode) {
selection.push(edgePathElement);
}
if (item.tonode && edgePathElement.getAttribute('tonode') === item.tonode) {
selection.push(edgePathElement);
}
}
edgePathElement = edgePathElement.nextSibling;
}
}
let initializerElement = graphElement.getElementById(id); let initializerElement = graphElement.getElementById(id);
if (initializerElement) { if (initializerElement) {
while (initializerElement.parentElement) { while (initializerElement.parentElement) {
...@@ -725,7 +800,6 @@ sidebar.FindSidebar = class { ...@@ -725,7 +800,6 @@ sidebar.FindSidebar = class {
const edgeMatches = new Set(); const edgeMatches = new Set();
const result = []; const result = [];
for (const node of this._graph.nodes) { for (const node of this._graph.nodes) {
const initializers = []; const initializers = [];
...@@ -744,13 +818,14 @@ sidebar.FindSidebar = class { ...@@ -744,13 +818,14 @@ sidebar.FindSidebar = class {
}); });
edgeMatches.add(argument.name); edgeMatches.add(argument.name);
} else { } else {
initializers.push(argument.initializer); // initializers.push(argument.initializer);
} }
} }
} }
} }
const name = node.name; const name = node.name;
console.log('name', node);
const operator = node.type; const operator = node.type;
if ( if (
!nodeMatches.has(name) && !nodeMatches.has(name) &&
...@@ -759,12 +834,55 @@ sidebar.FindSidebar = class { ...@@ -759,12 +834,55 @@ sidebar.FindSidebar = class {
) { ) {
result.push({ result.push({
type: 'node', type: 'node',
name: node.name, name: name,
id: 'node-' + node.name id: 'node-' + name
}); });
nodeMatches.add(node.name); nodeMatches.add(name);
} }
// let path = node.name.split('/');
// path.pop();
// let groupName = path.join('/');
// console.log('groupName', groupName);
// const clusterNode = name => {
// if (
// !nodeMatches.has(name) &&
// name &&
// (name.toLowerCase().indexOf(text) != -1 || (operator && operator.toLowerCase().indexOf(text) != -1))
// ) {
// result.push({
// type: 'node',
// name: name,
// id: 'node-' + name
// });
// nodeMatches.add(name);
// let path = name.split('/');
// while (path.length > 0) {
// const name = path.join('/');
// path.pop();
// if (name) {
// clusterNode(name);
// }
// }
// }
// };
// if (groupName) {
// clusterNode(groupName);
// // g.setParent(nodeId, groupName);
// }
// clusterNode(node.show_name);
// if (
// !nodeMatches.has(name) &&
// name &&
// (name.toLowerCase().indexOf(text) != -1 || (operator && operator.toLowerCase().indexOf(text) != -1))
// ) {
// result.push({
// type: 'node',
// name: node.name,
// id: 'node-' + node.name
// });
// //
// nodeMatches.add(node.name);
// }
for (const initializer of initializers) { for (const initializer of initializers) {
result.push({ result.push({
type: 'initializer', type: 'initializer',
...@@ -792,7 +910,6 @@ sidebar.FindSidebar = class { ...@@ -792,7 +910,6 @@ sidebar.FindSidebar = class {
} }
} }
} }
return { return {
text: searchText, text: searchText,
result: result result: result
......
...@@ -8,8 +8,8 @@ body { ...@@ -8,8 +8,8 @@ body {
margin: 0; margin: 0;
width: 100vw; width: 100vw;
height: 100vh; height: 100vh;
font-family: -apple-system, BlinkMacSystemFont, 'Segoe WPC', 'Segoe UI', 'Ubuntu', 'Droid Sans', font-family: -apple-system, BlinkMacSystemFont, 'Segoe WPC', 'Segoe UI', 'Ubuntu', 'Droid Sans', sans-serif,
sans-serif, 'PingFang SC'; 'PingFang SC';
font-size: 12px; font-size: 12px;
text-rendering: geometricPrecision; text-rendering: geometricPrecision;
background-color: #fff; background-color: #fff;
...@@ -27,7 +27,7 @@ body { ...@@ -27,7 +27,7 @@ body {
.canvas { .canvas {
display: block; display: block;
position: absolute; // position: absolute;
text-rendering: geometricPrecision; text-rendering: geometricPrecision;
user-select: none; user-select: none;
cursor: grab; cursor: grab;
...@@ -46,13 +46,14 @@ line { ...@@ -46,13 +46,14 @@ line {
} }
text { text {
font-family: -apple-system, BlinkMacSystemFont, "Segoe WPC", "Segoe UI", "Ubuntu", "Droid Sans", sans-serif, "PingFang SC"; font-family: -apple-system, BlinkMacSystemFont, 'Segoe WPC', 'Segoe UI', 'Ubuntu', 'Droid Sans', sans-serif,
'PingFang SC';
font-size: 11px; font-size: 11px;
text-rendering: geometricPrecision; text-rendering: geometricPrecision;
fill: #000; fill: #000;
.dark & { .dark & {
fill: #CFCFD1; fill: #cfcfd1;
} }
} }
...@@ -70,7 +71,7 @@ text { ...@@ -70,7 +71,7 @@ text {
&:hover { &:hover {
path { path {
fill: #2932E1; fill: #2932e1;
fill-opacity: 1; fill-opacity: 1;
} }
...@@ -81,7 +82,7 @@ text { ...@@ -81,7 +82,7 @@ text {
} }
.node-item-function path { .node-item-function path {
fill: #9BB9E8; fill: #9bb9e8;
fill-opacity: 0.7; fill-opacity: 0.7;
} }
...@@ -89,70 +90,88 @@ text { ...@@ -89,70 +90,88 @@ text {
cursor: pointer; cursor: pointer;
path { path {
fill: #8BB8FF; fill: #8bb8ff;
fill-opacity: 0.9; fill-opacity: 0.9;
} }
} }
.node-item-type-constant path { .node-item-type-constant path {
fill: #B4CCB7; fill: #b4ccb7;
} }
.node-item-type-control path { .node-item-type-control path {
fill: #A8E9B8; fill: #a8e9b8;
} }
.node-item-type-layer path { .node-item-type-layer path {
fill: #DB989A; fill: #db989a;
fill-opacity: 0.7;
}
.node-item-type-container path {
fill: #db989a;
fill-opacity: 0.7; fill-opacity: 0.7;
} }
.node-item-type-wrapper path { .node-item-type-wrapper path {
fill: #6DCDE4; fill: #6dcde4;
fill-opacity: 0.7;
}
.node-item-type-conv path {
fill: #6dcde4;
fill-opacity: 0.7; fill-opacity: 0.7;
} }
.node-item-type-activation path { .node-item-type-activation path {
fill: #93C2CA; fill: #93c2ca;
fill-opacity: 0.7; fill-opacity: 0.7;
} }
.node-item-type-pool path { .node-item-type-pool path {
fill: #DE7CCE; fill: #de7cce;
fill-opacity: 0.7; fill-opacity: 0.7;
} }
.node-item-type-normalization path { .node-item-type-normalization path {
fill: #DA96BC; fill: #da96bc;
fill-opacity: 0.7; fill-opacity: 0.7;
} }
.node-item-type-dropout path { .node-item-type-dropout path {
fill: #309E51; fill: #309e51;
fill-opacity: 0.7;
}
.node-item-type-pad path {
fill: #309e51;
fill-opacity: 0.7; fill-opacity: 0.7;
} }
.node-item-type-shape path { .node-item-type-shape path {
fill: #D6C482; fill: #d6c482;
fill-opacity: 0.7; fill-opacity: 0.7;
} }
.node-item-type-tensor path { .node-item-type-tensor path {
fill: #6D7CE4; fill: #6d7ce4;
fill-opacity: 0.7; fill-opacity: 0.7;
} }
.node-item-type-transform path { .node-item-type-transform path {
fill: #CDCB74; fill: #cdcb74;
}
.node-item-type-sequence path {
fill: #cdcb74;
} }
.node-item-type-data path { .node-item-type-data path {
fill: #2576AD; fill: #2576ad;
fill-opacity: 0.7; fill-opacity: 0.7;
} }
.node-item-type-custom path { .node-item-type-custom path {
fill: #E46D6D; fill: #e46d6d;
fill-opacity: 0.7; fill-opacity: 0.7;
} }
...@@ -176,7 +195,7 @@ text { ...@@ -176,7 +195,7 @@ text {
cursor: pointer; cursor: pointer;
path { path {
fill: #CA5353; fill: #ca5353;
fill-opacity: 0.7; fill-opacity: 0.7;
} }
} }
...@@ -203,7 +222,7 @@ text { ...@@ -203,7 +222,7 @@ text {
cursor: pointer; cursor: pointer;
path { path {
fill: #E49D6D; fill: #e49d6d;
fill-opacity: 0.7; fill-opacity: 0.7;
} }
} }
...@@ -212,7 +231,7 @@ text { ...@@ -212,7 +231,7 @@ text {
cursor: pointer; cursor: pointer;
path { path {
fill: #E4E06D; fill: #e4e06d;
fill-opacity: 0.9; fill-opacity: 0.9;
} }
} }
...@@ -235,18 +254,97 @@ text { ...@@ -235,18 +254,97 @@ text {
stroke-dasharray: 3, 2; stroke-dasharray: 3, 2;
} }
.cluster rect { .cluster .clusterGroup {
stroke: #000; fill: #dce9ff;
fill: #000; stroke: #666;
fill-opacity: 0.02;
stroke-opacity: 0.06;
stroke-width: 1px; stroke-width: 1px;
} }
.node-item-function path {
fill: #9bb9e8;
fill-opacity: 0.7;
}
.cluster .clusterGroup-constant {
fill: #e8efe9;
}
.select { .cluster .clusterGroup-control {
fill: #e4f8e9;
}
.cluster .clusterGroup-layer {
fill: #f4e0e0;
}
.cluster .clusterGroup-conv {
fill: #d3f0f6;
}
.cluster .clusterGroup-container {
fill: #f4e0e0;
}
.cluster .clusterGroup-wrapper {
fill: #d3f0f6;
}
.cluster .clusterGroup-activation {
fill: #deecef;
}
.cluster .clusterGroup-pool {
fill: #f5d7f0;
}
.cluster .clusterGroup-normalization {
fill: #f3dfea;
}
.cluster .clusterGroup-dropout {
fill: #c0e1ca;
}
.cluster .clusterGroup-pad {
fill: #c0e1ca;
}
.cluster .clusterGroup-shape {
fill: #f2edd9;
}
.cluster .clusterGroup-tensor {
fill: #d3d7f6;
}
.cluster .clusterGroup-transform {
fill: #f0efd5;
}
.cluster .clusterGroup-sequence {
fill: #f0efd5;
}
.cluster .clusterGroup-data {
fill: #bdd5e6;
}
.cluster .clusterGroup-custom {
fill: #f6d3d3;
}
.cluster .clusterButton {
fill-opacity: 0.3;
fill: #db989a;
stroke: #999;
cursor: pointer;
}
.cluster .button-text {
fill: #999;
}
.cluster.border {
display: none;
}
.select {
&.edge-path { &.edge-path {
stroke: #1527C2; stroke: #1527c2;
stroke-width: 2px; stroke-width: 2px;
stroke-dasharray: 6px 3px; stroke-dasharray: 6px 3px;
stroke-dashoffset: 0; stroke-dashoffset: 0;
...@@ -254,7 +352,15 @@ text { ...@@ -254,7 +352,15 @@ text {
} }
.node.border { .node.border {
stroke: #1527C2; stroke: #1527c2;
stroke-width: 2px;
stroke-dasharray: 6px 3px;
stroke-dashoffset: 0;
animation: pulse 4s infinite linear;
}
.cluster.border {
display: block;
stroke: #1527c2;
stroke-width: 2px; stroke-width: 2px;
stroke-dasharray: 6px 3px; stroke-dasharray: 6px 3px;
stroke-dashoffset: 0; stroke-dashoffset: 0;
......
/**
* Copyright 2020 Baidu Inc. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
var grapher = grapher || {};
var dagre = dagre || require('dagre');
grapher.Renderer = class {
constructor(host, svgElement, view) {
this._document = host.document;
this._svgElement = svgElement;
this._host = host;
this._view = view;
}
render(graph) {
let svgClusterGroup = null;
let svgEdgePathGroup = null;
let svgEdgeLabelGroup = null;
let svgNodeGroup = null;
svgClusterGroup = this.createElement('g');
svgClusterGroup.setAttribute('id', 'clusters');
svgClusterGroup.setAttribute('class', 'clusters');
this._svgElement.appendChild(svgClusterGroup);
svgEdgePathGroup = this.createElement('g');
svgEdgePathGroup.setAttribute('id', 'edge-paths');
svgEdgePathGroup.setAttribute('class', 'edge-paths');
this._svgElement.appendChild(svgEdgePathGroup);
svgEdgeLabelGroup = this.createElement('g');
svgEdgeLabelGroup.setAttribute('id', 'edge-labels');
svgEdgeLabelGroup.setAttribute('class', 'edge-labels');
this._svgElement.appendChild(svgEdgeLabelGroup);
svgNodeGroup = this.createElement('g');
svgNodeGroup.setAttribute('id', 'nodes');
svgNodeGroup.setAttribute('class', 'nodes');
this._svgElement.appendChild(svgNodeGroup);
// } else {
// svgClusterGroup = this._document.getElementById('clusters')
// svgEdgePathGroup = this._document.getElementById('edge-paths')
// svgEdgeLabelGroup = this._document.getElementById('edge-labels')
// svgNodeGroup = this._document.getElementById('nodes')
// }
for (const nodeId of graph.nodes()) {
if (graph.children(nodeId).length == 0) {
const node = graph.node(nodeId);
// 在这里进行缓存的判断
// console.log('this._document', this._document);
// const nodeDom = this._document.getElementById(node.id);
// console.log('nodeDom', nodeDom);
if (this._view._nodes.hasOwnProperty(node.id)) {
// 这个节点存在过
svgNodeGroup.appendChild(this._view._nodes[node.id]);
const nodeBox = this._view._nodes[node.id].getBBox();
node.width = nodeBox.width;
node.height = nodeBox.height;
node.element = this._view._nodes[node.id];
} else {
const element = this.createElement('g');
if (node.id) {
element.setAttribute('id', node.id);
}
element.setAttribute(
'class',
Object.prototype.hasOwnProperty.call(node, 'class') ? 'node ' + node.class : 'node'
);
element.style.opacity = 0;
const container = this.createElement('g');
container.appendChild(node.label);
// node.label 就是fromat 之后的节点
element.appendChild(container);
svgNodeGroup.appendChild(element);
const nodeBox = node.label.getBBox();
const nodeX = -nodeBox.width / 2;
const nodeY = -nodeBox.height / 2;
container.setAttribute('transform', 'translate(' + nodeX + ',' + nodeY + ')');
node.width = nodeBox.width;
node.height = nodeBox.height;
node.element = element;
}
}
}
for (const edgeId of graph.edges()) {
const edge = graph.edge(edgeId);
if (edge.label) {
const tspan = this.createElement('tspan');
tspan.setAttribute('xml:space', 'preserve');
tspan.setAttribute('dy', '1em');
tspan.setAttribute('x', '1');
tspan.appendChild(this._document.createTextNode(edge.label));
const text = this.createElement('text');
text.appendChild(tspan);
const textContainer = this.createElement('g');
textContainer.appendChild(text);
const labelElement = this.createElement('g');
labelElement.style.opacity = 0;
labelElement.setAttribute('class', 'edge-label');
labelElement.appendChild(textContainer);
svgEdgeLabelGroup.appendChild(labelElement);
const edgeBox = textContainer.getBBox();
const edgeX = -edgeBox.width / 2;
const edgeY = -edgeBox.height / 2;
textContainer.setAttribute('transform', 'translate(' + edgeX + ',' + edgeY + ')');
edge.width = edgeBox.width;
edge.height = edgeBox.height;
edge.labelElement = labelElement;
}
}
dagre.layout(graph);
// 实际要变的就是这一块
for (const nodeId of graph.nodes()) {
if (graph.children(nodeId).length == 0) {
const node = graph.node(nodeId);
node.element.setAttribute('transform', 'translate(' + node.x + ',' + node.y + ')');
node.element.style.opacity = 1;
delete node.element;
}
}
for (const edgeId of graph.edges()) {
const edge = graph.edge(edgeId);
if (edge.labelElement) {
edge.labelElement.setAttribute('transform', 'translate(' + edge.x + ',' + edge.y + ')');
edge.labelElement.style.opacity = 1;
delete edge.labelElement;
}
}
const edgePathGroupDefs = this.createElement('defs');
svgEdgePathGroup.appendChild(edgePathGroupDefs);
const marker = this.createElement('marker');
marker.setAttribute('id', 'arrowhead-vee');
marker.setAttribute('viewBox', '0 0 10 10');
marker.setAttribute('refX', 9);
marker.setAttribute('refY', 5);
marker.setAttribute('markerUnits', 'strokeWidth');
marker.setAttribute('markerWidth', 8);
marker.setAttribute('markerHeight', 6);
marker.setAttribute('orient', 'auto');
edgePathGroupDefs.appendChild(marker);
const markerPath = this.createElement('path');
markerPath.setAttribute('d', 'M 0 0 L 10 5 L 0 10 L 4 5 z');
markerPath.style.setProperty('stroke-width', 1);
markerPath.style.setProperty('stroke-dasharray', '1,0');
marker.appendChild(markerPath);
for (const edgeId of graph.edges()) {
const edge = graph.edge(edgeId);
const edgePath = grapher.Renderer._computeCurvePath(edge, graph.node(edgeId.v), graph.node(edgeId.w));
const edgeElement = this.createElement('path');
edgeElement.setAttribute(
'class',
Object.prototype.hasOwnProperty.call(edge, 'class') ? 'edge-path ' + edge.class : 'edge-path'
);
edgeElement.setAttribute('d', edgePath);
edgeElement.setAttribute('marker-end', 'url(#arrowhead-vee)');
if (edge.id) {
edgeElement.setAttribute('id', edge.id);
}
if (edge.fromnode) {
edgeElement.setAttribute('fromnode', edge.fromnode);
}
if (edge.tonode) {
edgeElement.setAttribute('tonode', edge.tonode);
}
svgEdgePathGroup.appendChild(edgeElement);
}
const groupArray = [];
for (const nodeId of graph.nodes()) {
if (!Number(nodeId) && Number(nodeId) !== 0) {
groupArray.push(nodeId);
}
}
const newGroupArray = groupArray.sort((a, b) => {
let level1 = a.split('/').length;
let level2 = b.split('/').length;
return level1 - level2;
});
for (const nodeId of newGroupArray) {
if (graph.children(nodeId).length > 0) {
const node = graph.node(nodeId);
// const nodeDom = this._document.getElementById(`node-${nodeId}`)
// if (this._view._nodes.hasOwnProperty(node.id)) {
// // 这个节点存在过
// svgNodeGroup.appendChild(this._view._nodes[node.id]);
// const nodeBox = this._view._nodes[node.id].getBBox();
// node.width = nodeBox.width;
// node.height = nodeBox.height;
// node.element = this._view._nodes[node.id]
if (this._view._clusters.hasOwnProperty(node.id)) {
const nodeDom = this._view._clusters.hasOwnProperty(node.id);
nodeDom.setAttribute('transform', 'translate(' + node.x + ',' + node.y + ')');
nodeDom.firstChild.setAttribute('x', -node.width / 2);
nodeDom.firstChild.setAttribute('y', -node.height / 2);
nodeDom.firstChild.setAttribute('width', node.width + 10);
nodeDom.firstChild.setAttribute('height', node.height + 10);
} else {
const nodeElement = this.createElement('g');
nodeElement.setAttribute('class', 'cluster');
nodeElement.setAttribute('id', `node-${nodeId}`);
nodeElement.setAttribute('transform', 'translate(' + node.x + ',' + node.y + ')');
const rect = this.createElement('rect');
const tspan = this.createElement('tspan');
const button = this.createElement('circle');
const buttonSign = this.createElement('tspan');
button.setAttribute('r', '6.5');
button.setAttribute('cx', node.width / 2 - 20 + 7.5 + 10);
button.setAttribute('cy', -(node.height / 2) + 5 + 7.5);
buttonSign.setAttribute('x', node.width / 2 - 15 + 9);
buttonSign.setAttribute('y', -(node.height / 2) + 1.3);
buttonSign.setAttribute('xml:space', 'preserve');
buttonSign.setAttribute('dy', '1em');
buttonSign.setAttribute('font-size', '16px');
buttonSign.setAttribute('class', 'button-text');
button.setAttribute('class', 'clusterButton');
tspan.setAttribute('xml:space', 'preserve');
tspan.setAttribute('dy', '1em');
tspan.setAttribute('x', 0);
tspan.setAttribute('y', -(node.height / 2) + 5);
tspan.setAttribute('text-anchor', 'middle');
let name = '';
for (const nodes of this._host._view._allGraph.nodes) {
if (nodes.name === node.nodeId) {
name = nodes.show_name.split('/')[nodes.show_name.split('/').length - 1];
}
}
tspan.appendChild(this._document.createTextNode(name));
buttonSign.appendChild(this._document.createTextNode('-'));
const text = this.createElement('text');
text.appendChild(tspan);
const text2 = this.createElement('text');
text2.appendChild(buttonSign);
rect.setAttribute('class', node.classList.join(' '));
rect.setAttribute('x', -node.width / 2);
rect.setAttribute('y', -node.height / 2);
rect.setAttribute('width', node.width + 10);
rect.setAttribute('height', node.height + 10);
const borderElement = this.createElement('path');
borderElement.setAttribute('class', ['cluster', 'border'].join(' '));
borderElement.setAttribute(
'd',
grapher.NodeElement.roundedRect(
-node.width / 2,
-node.height / 2,
node.width + 10,
node.height + 10,
true,
true,
true,
true
)
);
nodeElement.addEventListener('click', () => {
this._view.select({
id: `node-${nodeId}`,
name: nodeId,
type: 'node'
});
});
text2.addEventListener('click', () => {
this._host.selectNodeId({
nodeId: node.nodeId,
expand: node.expand,
isKeepData: node.isKeepData
});
this._host.selectItems({
id: `node-${node.nodeId}`,
name: node.nodeId,
type: 'node'
});
});
rect.addEventListener('click', () => {
if (this._view.isCtrl) {
for (const nodes of this._view._allGraph.nodes) {
if (nodes.name === node.nodeId) {
for (const type of this._view.non_graphMetadatas) {
if (type.name === nodes.type) {
if (this._view.Language === 'zh') {
window.open(
`https://www.paddlepaddle.org.cn/documentation/docs/zh/api/paddle/nn/${type.name}_cn.html`
);
} else {
window.open(
`https://www.paddlepaddle.org.cn/documentation/docs/en/api/paddle/nn/${type.name}_en.html`
);
}
}
}
this._view.showNodeProperties(nodes);
return;
}
}
} else {
for (const nodes of this._view._allGraph.nodes) {
if (nodes.name === node.nodeId) {
this._view.showNodeProperties(nodes);
return;
}
}
}
});
if (node.rx) {
rect.setAttribute('rx', node.rx);
}
if (node.ry) {
rect.setAttribute('ry', node.ry);
}
nodeElement.appendChild(rect);
nodeElement.appendChild(text);
nodeElement.appendChild(button);
nodeElement.appendChild(text2);
nodeElement.appendChild(borderElement);
svgClusterGroup.appendChild(nodeElement);
}
}
}
}
createElement(name) {
return this._document.createElementNS('http://www.w3.org/2000/svg', name);
}
static _computeCurvePath(edge, tail, head) {
const points = edge.points.slice(1, edge.points.length - 1);
points.unshift(grapher.Renderer.intersectRect(tail, points[0]));
points.push(grapher.Renderer.intersectRect(head, points[points.length - 1]));
const path = new Path();
const curve = new Curve(path);
for (let i = 0; i < points.length; i++) {
const point = points[i];
if (i == 0) {
curve.lineStart();
}
curve.point(point.x, point.y);
if (i == points.length - 1) {
curve.lineEnd();
}
}
return path.data;
}
static intersectRect(node, point) {
const x = node.x;
const y = node.y;
const dx = point.x - x;
const dy = point.y - y;
let w = node.width / 2;
let h = node.height / 2;
let sx;
let sy;
if (Math.abs(dy) * w > Math.abs(dx) * h) {
if (dy < 0) {
h = -h;
}
sx = dy === 0 ? 0 : (h * dx) / dy;
sy = h;
} else {
if (dx < 0) {
w = -w;
}
sx = w;
sy = dx === 0 ? 0 : (w * dy) / dx;
}
return {x: x + sx, y: y + sy};
}
};
grapher.NodeElement = class {
constructor(document) {
this._document = document;
this._blocks = [];
}
block(type) {
this._block = null;
switch (type) {
case 'header':
this._block = new grapher.NodeElement.Header(this._document);
break;
case 'list':
this._block = new grapher.NodeElement.List(this._document);
break;
}
this._blocks.push(this._block);
return this._block;
}
format(contextElement) {
const rootElement = this.createElement('g');
contextElement.appendChild(rootElement);
let width = 0;
let height = 0;
const tops = [];
for (const block of this._blocks) {
tops.push(height);
block.layout(rootElement);
if (width < block.width) {
width = block.width;
}
height = height + block.height;
}
for (let i = 0; i < this._blocks.length; i++) {
// push 进来的header 或者 list
const top = tops.shift();
this._blocks[i].update(rootElement, top, width, i == 0, i == this._blocks.length - 1);
}
const borderElement = this.createElement('path');
borderElement.setAttribute('class', ['node', 'border'].join(' '));
borderElement.setAttribute('d', grapher.NodeElement.roundedRect(0, 0, width, height, true, true, true, true));
rootElement.appendChild(borderElement);
contextElement.innerHTML = '';
return rootElement;
}
static roundedRect(x, y, width, height, r1, r2, r3, r4) {
const radius = 5;
r1 = r1 ? radius : 0;
r2 = r2 ? radius : 0;
r3 = r3 ? radius : 0;
r4 = r4 ? radius : 0;
return (
'M' +
(x + r1) +
',' +
y +
'h' +
(width - r1 - r2) +
'a' +
r2 +
',' +
r2 +
' 0 0 1 ' +
r2 +
',' +
r2 +
'v' +
(height - r2 - r3) +
'a' +
r3 +
',' +
r3 +
' 0 0 1 ' +
-r3 +
',' +
r3 +
'h' +
(r3 + r4 - width) +
'a' +
r4 +
',' +
r4 +
' 0 0 1 ' +
-r4 +
',' +
-r4 +
'v' +
(-height + r4 + r1) +
'a' +
r1 +
',' +
r1 +
' 0 0 1 ' +
r1 +
',' +
-r1 +
'z'
);
}
createElement(name) {
return this._document.createElementNS('http://www.w3.org/2000/svg', name);
}
};
grapher.NodeElement.Header = class {
constructor(document) {
this._document = document;
this._items = [];
}
add(id, classList, content, tooltip, handler) {
this._items.push({
id: id,
classList: classList,
content: content,
tooltip: tooltip,
handler: handler
});
}
layout(parentElement) {
this._width = 0;
this._height = 0;
this._elements = [];
let x = 0;
const y = 0;
for (const item of this._items) {
const yPadding = 4;
const xPadding = 7;
const element = this.createElement('g');
let classList = ['node-item'];
parentElement.appendChild(element);
const pathElement = this.createElement('path');
const textElement = this.createElement('text');
element.appendChild(pathElement);
element.appendChild(textElement);
if (item.classList) {
classList = classList.concat(item.classList);
}
element.setAttribute('class', classList.join(' '));
if (item.id) {
element.setAttribute('id', item.id);
}
if (item.handler) {
element.addEventListener('click', item.handler);
}
if (item.tooltip) {
const titleElement = this.createElement('title');
titleElement.textContent = item.tooltip;
element.appendChild(titleElement);
}
if (item.content) {
textElement.textContent = item.content;
}
const boundingBox = textElement.getBBox();
const width = boundingBox.width + xPadding + xPadding;
const height = boundingBox.height + yPadding + yPadding;
this._elements.push({
group: element,
text: textElement,
path: pathElement,
x: x,
y: y,
width: width,
height: height,
tx: xPadding,
ty: yPadding - boundingBox.y
});
x += width;
if (this._height < height) {
this._height = height;
}
if (x > this._width) {
this._width = x;
}
}
}
get width() {
return this._width;
}
get height() {
return this._height;
}
update(parentElement, top, width, first, last) {
const dx = width - this._width;
let i;
let element;
for (i = 0; i < this._elements.length; i++) {
element = this._elements[i];
if (i == 0) {
element.width = element.width + dx;
} else {
element.x = element.x + dx;
element.tx = element.tx + dx;
}
element.y = element.y + top;
}
for (i = 0; i < this._elements.length; i++) {
element = this._elements[i];
element.group.setAttribute('transform', 'translate(' + element.x + ',' + element.y + ')');
const r1 = i == 0 && first;
const r2 = i == this._elements.length - 1 && first;
const r3 = i == this._elements.length - 1 && last;
const r4 = i == 0 && last;
element.path.setAttribute(
'd',
grapher.NodeElement.roundedRect(0, 0, element.width, element.height, r1, r2, r3, r4)
);
element.text.setAttribute('x', 6);
element.text.setAttribute('y', element.ty);
}
let lineElement;
for (i = 0; i < this._elements.length; i++) {
element = this._elements[i];
if (i != 0) {
lineElement = this.createElement('line');
lineElement.setAttribute('class', 'node');
lineElement.setAttribute('x1', element.x);
lineElement.setAttribute('x2', element.x);
lineElement.setAttribute('y1', top);
lineElement.setAttribute('y2', top + this._height);
parentElement.appendChild(lineElement);
}
}
if (!first) {
lineElement = this.createElement('line');
lineElement.setAttribute('class', 'node');
lineElement.setAttribute('x1', 0);
lineElement.setAttribute('x2', width);
lineElement.setAttribute('y1', top);
lineElement.setAttribute('y2', top);
parentElement.appendChild(lineElement);
}
}
createElement(name) {
return this._document.createElementNS('http://www.w3.org/2000/svg', name);
}
};
grapher.NodeElement.List = class {
constructor(document) {
this._document = document;
this._items = [];
}
add(id, name, value, tooltip, separator) {
this._items.push({id: id, name: name, value: value, tooltip: tooltip, separator: separator});
}
get handler() {
return this._handler;
}
set handler(handler) {
this._handler = handler;
}
layout(parentElement) {
this._width = 0;
this._height = 0;
const x = 0;
const y = 0;
this._element = this.createElement('g');
this._element.setAttribute('class', 'node-attribute');
parentElement.appendChild(this._element);
if (this._handler) {
this._element.addEventListener('click', this._handler);
}
this._backgroundElement = this.createElement('path');
this._element.appendChild(this._backgroundElement);
this._element.setAttribute('transform', 'translate(' + x + ',' + y + ')');
this._height += 3;
for (const item of this._items) {
const yPadding = 1;
const xPadding = 6;
const textElement = this.createElement('text');
if (item.id) {
textElement.setAttribute('id', item.id);
}
textElement.setAttribute('xml:space', 'preserve');
this._element.appendChild(textElement);
if (item.tooltip) {
const titleElement = this.createElement('title');
titleElement.textContent = item.tooltip;
textElement.appendChild(titleElement);
}
const textNameElement = this.createElement('tspan');
textNameElement.textContent = item.name;
if (item.separator.trim() != '=') {
textNameElement.style.fontWeight = 'bold';
}
textElement.appendChild(textNameElement);
const textValueElement = this.createElement('tspan');
textValueElement.textContent = item.separator + item.value;
textElement.appendChild(textValueElement);
const size = textElement.getBBox();
const width = xPadding + size.width + xPadding;
if (this._width < width) {
this._width = width;
}
textElement.setAttribute('x', x + xPadding);
textElement.setAttribute('y', this._height + yPadding - size.y);
this._height += yPadding + size.height + yPadding;
}
this._height += 3;
if (this._width < 100) {
this._width = 100;
}
}
get width() {
return this._width;
}
get height() {
return this._height;
}
update(parentElement, top, width, first, last) {
this._element.setAttribute('transform', 'translate(0,' + top + ')');
const r1 = first;
const r2 = first;
const r3 = last;
const r4 = last;
this._backgroundElement.setAttribute(
'd',
grapher.NodeElement.roundedRect(0, 0, width, this._height, r1, r2, r3, r4)
);
if (!first) {
const lineElement = this.createElement('line');
lineElement.setAttribute('class', 'node');
lineElement.setAttribute('x1', 0);
lineElement.setAttribute('x2', width);
lineElement.setAttribute('y1', 0);
lineElement.setAttribute('y2', 0);
this._element.appendChild(lineElement);
}
}
createElement(name) {
return this._document.createElementNS('http://www.w3.org/2000/svg', name);
}
};
class Path {
constructor() {
this._x0 = null;
this._y0 = null;
this._x1 = null;
this._y1 = null;
this._data = '';
}
moveTo(x, y) {
this._data += 'M' + (this._x0 = this._x1 = +x) + ',' + (this._y0 = this._y1 = +y);
}
lineTo(x, y) {
this._data += 'L' + (this._x1 = +x) + ',' + (this._y1 = +y);
}
bezierCurveTo(x1, y1, x2, y2, x, y) {
this._data += 'C' + +x1 + ',' + +y1 + ',' + +x2 + ',' + +y2 + ',' + (this._x1 = +x) + ',' + (this._y1 = +y);
}
closePath() {
if (this._x1 !== null) {
this._x1 = this._x0;
this._y1 = this._y0;
this._data += 'Z';
}
}
get data() {
return this._data;
}
}
class Curve {
constructor(context) {
this._context = context;
}
lineStart() {
this._x0 = NaN;
this._x1 = NaN;
this._y0 = NaN;
this._y1 = NaN;
this._point = 0;
}
lineEnd() {
switch (this._point) {
case 3:
this.curve(this._x1, this._y1);
this._context.lineTo(this._x1, this._y1);
break;
case 2:
this._context.lineTo(this._x1, this._y1);
break;
}
if (this._line || (this._line !== 0 && this._point === 1)) {
this._context.closePath();
}
this._line = 1 - this._line;
}
point(x, y) {
x = +x;
y = +y;
switch (this._point) {
case 0:
this._point = 1;
if (this._line) {
this._context.lineTo(x, y);
} else {
this._context.moveTo(x, y);
}
break;
case 1:
this._point = 2;
break;
case 2:
this._point = 3;
this._context.lineTo((5 * this._x0 + this._x1) / 6, (5 * this._y0 + this._y1) / 6);
this.curve(x, y);
break;
default:
this.curve(x, y);
break;
}
this._x0 = this._x1;
this._x1 = x;
this._y0 = this._y1;
this._y1 = y;
}
curve(x, y) {
this._context.bezierCurveTo(
(2 * this._x0 + this._x1) / 3,
(2 * this._y0 + this._y1) / 3,
(this._x0 + 2 * this._x1) / 3,
(this._y0 + 2 * this._y1) / 3,
(this._x0 + 4 * this._x1 + x) / 6,
(this._y0 + 4 * this._y1 + y) / 6
);
}
}
if (typeof module !== 'undefined' && typeof module.exports === 'object') {
module.exports.Renderer = grapher.Renderer;
module.exports.NodeElement = grapher.NodeElement;
}
...@@ -20,30 +20,38 @@ const zip = require('netron/src/zip'); ...@@ -20,30 +20,38 @@ const zip = require('netron/src/zip');
const gzip = require('netron/src/gzip'); const gzip = require('netron/src/gzip');
const tar = require('netron/src/tar'); const tar = require('netron/src/tar');
const protobuf = require('netron/src/protobuf'); const protobuf = require('netron/src/protobuf');
const d3 = require('d3'); const d3 = require('d3');
const dagre = require('dagre'); const dagre = require('dagre');
const grapher = require('./view-grapher');
const grapher = require('netron/src/view-grapher');
const sidebar = require('./sidebar'); const sidebar = require('./sidebar');
const view = {}; const view = {};
const graphMetadata = require('./paddle-metadata');
const {style} = require('d3');
view.View = class { view.View = class {
constructor(host) { constructor(host) {
this._host = host; this._host = host;
this._host this._host
.initialize(this) .initialize(this)
.then(() => { .then(() => {
this.typeLayer = {};
this.Language = 'zh';
this.graphMetadatas = [];
this.non_graphMetadatas = [];
this.isCtrl = false;
this._clusters = {};
this._nodeName = {};
this._nodes = {};
this._model = null; this._model = null;
this._selection = []; this._selection = [];
this._selectItem = null;
this._host.start(); this._host.start();
this._showAttributes = false; this._showAttributes = false;
this._showInitializers = true; this._showInitializers = true;
this._showNames = false; this._showNames = false;
this._KeepData = false;
this._showHorizontal = false; this._showHorizontal = false;
this._modelFactoryService = new view.ModelFactoryService(this._host); this._modelFactoryService = new view.ModelFactoryService(this._host);
this._graphNodes = {};
}) })
.catch(err => { .catch(err => {
this.error(err.message, err); this.error(err.message, err);
...@@ -70,16 +78,22 @@ view.View = class { ...@@ -70,16 +78,22 @@ view.View = class {
if (this._activeGraph) { if (this._activeGraph) {
this.clearSelection(); this.clearSelection();
const graphElement = document.getElementById('canvas'); const graphElement = document.getElementById('canvas');
const view = new sidebar.FindSidebar(this._host, graphElement, this._activeGraph); if (this._allGraph) {
this._host.message('search', view.update(value)); const view = new sidebar.FindSidebar(this._host, graphElement, this._allGraph);
this._host.message('search', view.update(value));
} else {
const view = new sidebar.FindSidebar(this._host, graphElement, this._activeGraph);
this._host.message('search', view.update(value));
}
} }
} }
toggleAttributes(toggle) { toggleAttributes(toggle) {
if (toggle != null && !(toggle ^ this._showAttributes)) { if (toggle != null && !(toggle ^ this._showAttributes)) {
return; return;
} }
this._showAttributes = toggle == null ? !this._showAttributes : toggle; this._showAttributes = toggle == null ? !this._showAttributes : toggle;
this._nodes = {};
this._clusters = {};
this._reload(); this._reload();
} }
...@@ -92,6 +106,8 @@ view.View = class { ...@@ -92,6 +106,8 @@ view.View = class {
return; return;
} }
this._showInitializers = toggle == null ? !this._showInitializers : toggle; this._showInitializers = toggle == null ? !this._showInitializers : toggle;
this._nodes = {};
this._clusters = {};
this._reload(); this._reload();
} }
...@@ -104,12 +120,26 @@ view.View = class { ...@@ -104,12 +120,26 @@ view.View = class {
return; return;
} }
this._showNames = toggle == null ? !this._showNames : toggle; this._showNames = toggle == null ? !this._showNames : toggle;
this._nodes = {};
this._clusters = {};
this._reload(); this._reload();
} }
toggleKeepData(toggle) {
if (toggle != null && !(toggle ^ this._KeepData)) {
return;
}
this._KeepData = toggle == null ? !this._KeepData : toggle;
// this._reload();
}
toggleLanguage(data) {
this.Language = data;
}
get showNames() { get showNames() {
return this._showNames; return this._showNames;
} }
get KeepData() {
return this._KeepData;
}
toggleDirection(toggle) { toggleDirection(toggle) {
if (toggle != null && !(toggle ^ this._showHorizontal)) { if (toggle != null && !(toggle ^ this._showHorizontal)) {
...@@ -130,7 +160,7 @@ view.View = class { ...@@ -130,7 +160,7 @@ view.View = class {
_reload() { _reload() {
this._host.status('loading'); this._host.status('loading');
if (this._model && this._activeGraph) { if (this._model && this._activeGraph) {
this._updateGraph(this._model, this._activeGraph).catch(error => { this._updateGraph2(this._model, this._activeGraph).catch(error => {
if (error) { if (error) {
this.error('Graph update failed.', error); this.error('Graph update failed.', error);
} }
...@@ -157,7 +187,9 @@ view.View = class { ...@@ -157,7 +187,9 @@ view.View = class {
this._zoom.scaleBy(d3.select(this._host.document.getElementById('canvas')), 0.8); this._zoom.scaleBy(d3.select(this._host.document.getElementById('canvas')), 0.8);
} }
} }
selectItem(data) {
this._selectItem = data;
}
resetZoom() { resetZoom() {
if (this._zoom) { if (this._zoom) {
this._zoom.scaleTo(d3.select(this._host.document.getElementById('canvas')), 1); this._zoom.scaleTo(d3.select(this._host.document.getElementById('canvas')), 1);
...@@ -166,6 +198,14 @@ view.View = class { ...@@ -166,6 +198,14 @@ view.View = class {
select(item) { select(item) {
this.clearSelection(); this.clearSelection();
if (item.type === 'node') {
for (const nodes of this._allGraph.nodes) {
if (nodes.name === item.name) {
this.showNodeProperties(nodes);
break;
}
}
}
const graphElement = document.getElementById('canvas'); const graphElement = document.getElementById('canvas');
const selection = sidebar.FindSidebar.selection(item, graphElement); const selection = sidebar.FindSidebar.selection(item, graphElement);
if (selection && selection.length > 0) { if (selection && selection.length > 0) {
...@@ -191,6 +231,16 @@ view.View = class { ...@@ -191,6 +231,16 @@ view.View = class {
); );
} }
} }
select2(item) {
const graphElement = document.getElementById('canvas');
const selection = sidebar.FindSidebar.selection2(item, graphElement);
if (selection && selection.length > 0) {
for (const element of selection) {
this._selection.push(element);
element.classList.add('select');
}
}
}
clearSelection() { clearSelection() {
while (this._selection.length > 0) { while (this._selection.length > 0) {
...@@ -216,36 +266,105 @@ view.View = class { ...@@ -216,36 +266,105 @@ view.View = class {
graphs: model.graphs.map(g => g.name || ''), graphs: model.graphs.map(g => g.name || ''),
selected: graph && (graph.name || '') selected: graph && (graph.name || '')
}); });
return this._updateGraph(model, graph); return this._updateGraph2(model, graph);
}); });
}); });
}); });
} }
changeAlt(data) {
this.isCtrl = data;
console.log('isCtrl', this.isCtrl);
}
keydown() {
// 用户按下ctrl后变量isCtrl为true
this._host.document.onkeydown = e => {
if (
e.code === 'MetaLeft' ||
e.code === 'MetaRight' ||
e.code === 'ControlLeft' ||
e.code === 'AltLeft' ||
e.code === 'AltRight'
) {
this.isCtrl = true;
}
};
// 用户松开ctrl后变量isCtrl为false
this._host.document.onkeyup = e => {
if (
e.code === 'MetaLeft' ||
e.code === 'MetaRight' ||
e.code === 'ControlLeft' ||
e.code === 'AltLeft' ||
e.code === 'AltRight'
) {
this.isCtrl = false;
}
};
}
changeGraph(name) { changeGraph(name) {
this._updateActiveGraph(name); this._updateActiveGraph(name);
} }
changeAllGrap(graph) {
_updateActiveGraph(name) { this._allGraph = graph;
if (this._model) { for (const node of graph.nodes) {
const model = this._model; this._nodeName[node.name] = node;
const graph = model.graphs.filter(graph => name == graph.name).shift(); }
if (graph) { // console.log('this._nodeName',this._nodeName);
this._host.status('loading'); }
this._timeout(200).then(() => { changeSelect(selectItem) {
return this._updateGraph(model, graph).catch(error => { this._selectItem = selectItem;
if (error) { }
this.error('Graph update failed.', error); _updateActiveGraph(data) {
} if (data) {
}); this._host.status('loading');
this._timeout(200).then(() => {
return this._updateGraph(data).catch(error => {
if (error) {
this.error('Graph update failed.', error);
}
}); });
} });
} }
} }
_updateGraph(data) {
_updateGraph(model, graph) { return this._timeout(100).then(() => {
// 直接在此处传入模型数据的数据
if (data && data != this._activeGraph) {
const nodes = data.nodes;
if (nodes.length > 1400) {
if (
!this._host.confirm(
'Large model detected.',
'This graph contains a large number of nodes and might take a long time to render. Do you want to continue?'
)
) {
return null;
}
}
}
return this.renderGraph(data, data)
.then(() => {
this._model = data;
this._activeGraph = data;
this._host.status('rendered');
return this._model;
})
.catch(error => {
return this.renderGraph(this._model, this._activeGraph)
.then(() => {
this._host.status('rendered');
})
.finally(() => {
throw error;
});
});
});
}
_updateGraph2(graph) {
return this._timeout(100).then(() => { return this._timeout(100).then(() => {
// 直接在此处传入模型数据的数据
if (graph && graph != this._activeGraph) { if (graph && graph != this._activeGraph) {
this._selectItem = null;
const nodes = graph.nodes; const nodes = graph.nodes;
if (nodes.length > 1400) { if (nodes.length > 1400) {
if ( if (
...@@ -258,9 +377,9 @@ view.View = class { ...@@ -258,9 +377,9 @@ view.View = class {
} }
} }
} }
return this.renderGraph(model, graph) return this.renderGraph(graph, graph)
.then(() => { .then(() => {
this._model = model; this._model = graph;
this._activeGraph = graph; this._activeGraph = graph;
this._host.status('rendered'); this._host.status('rendered');
return this._model; return this._model;
...@@ -269,19 +388,66 @@ view.View = class { ...@@ -269,19 +388,66 @@ view.View = class {
return this.renderGraph(this._model, this._activeGraph) return this.renderGraph(this._model, this._activeGraph)
.then(() => { .then(() => {
this._host.status('rendered'); this._host.status('rendered');
throw error;
}) })
.catch(() => { .finally(() => {
throw error; throw error;
}); });
}); });
}); });
} }
jumpRoute(node) {
console.log('node', node);
if (node.is_leaf) {
console.log('isCtrl', this.isCtrl);
if (this.isCtrl) {
for (const nodes of this._allGraph.nodes) {
if (nodes.name === node.name) {
for (const type of this.non_graphMetadatas) {
console.log('type', type.name.toLowerCase(), node.type);
if (type.name.toLowerCase() === node.type) {
if (this.Language === 'zh') {
window.open(
`https://www.paddlepaddle.org.cn/documentation/docs/zh/api/paddle/nn/${type.name}_cn.html`
);
} else {
window.open(
`https://www.paddlepaddle.org.cn/documentation/docs/en/api/paddle/nn/${type.name}_en.html`
);
}
}
}
}
}
}
} else {
if (this.isCtrl) {
for (const nodes of this._allGraph.nodes) {
if (nodes.name === node.name) {
for (const type of this.non_graphMetadatas) {
if (type.name === node.type) {
window.open(
`https://www.paddlepaddle.org.cn/documentation/docs/zh/api/paddle/nn/${type.name}_cn.html`
);
}
}
}
}
}
}
}
renderGraph(model, graph) { renderGraph(model, graph) {
try { try {
this.keydown();
this.graphMetadatas = graphMetadata.default.leaf_nodes;
this.non_graphMetadatas = graphMetadata.default.non_leaf_nodes;
const typeLayer = {};
for (const type of graphMetadata.default.non_leaf_nodes) {
typeLayer[type.name] = true;
}
this.typeLayer = typeLayer;
const graphElement = this._host.document.getElementById('canvas'); const graphElement = this._host.document.getElementById('canvas');
while (graphElement.lastChild) { while (graphElement.lastChild) {
// 做上一次渲染的的清理动作
graphElement.removeChild(graphElement.lastChild); graphElement.removeChild(graphElement.lastChild);
} }
if (!graph) { if (!graph) {
...@@ -290,163 +456,262 @@ view.View = class { ...@@ -290,163 +456,262 @@ view.View = class {
this._zoom = null; this._zoom = null;
graphElement.style.position = 'absolute'; graphElement.style.position = 'absolute';
graphElement.style.margin = '0'; graphElement.style.margin = '0';
const groups = true;
const groups = graph.groups;
const graphOptions = {}; const graphOptions = {};
graphOptions.nodesep = 25; graphOptions.nodesep = 65;
graphOptions.ranksep = 20; graphOptions.ranksep = 60;
if (this._showHorizontal) { if (this._showHorizontal) {
graphOptions.rankdir = 'LR'; graphOptions.rankdir = 'LR';
} }
const g = new dagre.graphlib.Graph({compound: groups}); const g = new dagre.graphlib.Graph({compound: groups});
g.setGraph(graphOptions); g.setGraph(graphOptions);
g.setDefaultEdgeLabel(() => { g.setDefaultEdgeLabel(() => {
return {}; return {};
}); });
let nodeId = 0; let nodeId = 0;
const edgeMap = {}; const edgeMap = {};
const clusterMap = {}; const clusterMap = {};
const clusterParentMap = {}; const clusterParentMap = {};
let id = new Date().getTime(); let id = new Date().getTime();
const nodes = graph.nodes; let nodes = graph.nodes;
if (nodes.length > 1500) { if (nodes.length > 1500) {
graphOptions.ranker = 'longest-path'; graphOptions.ranker = 'longest-path';
} }
if (groups) { if (groups) {
for (const node of nodes) { for (const node of nodes) {
if (node.group) { let path = node.name.split('/');
const path = node.group.split('/'); path.pop();
while (path.length > 0) { while (path.length > 0) {
const name = path.join('/'); const name = path.join('/');
path.pop(); path.pop();
if (name) {
clusterParentMap[name] = path.join('/'); clusterParentMap[name] = path.join('/');
} }
} }
} }
} }
for (const node of nodes) { for (const node of nodes) {
const element = new grapher.NodeElement(this._host.document); let element = null;
if (!document.getElementById(`node-${node.name}`)) {
element = new grapher.NodeElement(this._host.document);
}
const addNode = (element, node, edges) => { const addNode = (element, node, edges) => {
const header = element.block('header'); if (!document.getElementById(`node-${node.name}`)) {
const styles = ['node-item-type']; const header = element.block('header');
const metadata = node.metadata; const styles = ['node-item-type'];
const category = metadata && metadata.category ? metadata.category : ''; const type = node.type;
if (category) { if (node.is_leaf) {
styles.push('node-item-type-' + category.toLowerCase()); for (const metadatas of this.graphMetadatas) {
} if (node.type === metadatas.name) {
const type = node.type; if (metadatas.schema.category) {
if (typeof type !== 'string' || !type.split) { styles.push('node-item-type-' + metadatas.schema.category.toLowerCase());
// #416 }
throw new ModelError( }
"Unknown node type '" + JSON.stringify(type) + "' in '" + model.format + "'."
);
}
const content = this.showNames && node.name ? node.name : type.split('.').pop();
const tooltip = this.showNames && node.name ? type : node.name;
header.add(null, styles, content, tooltip, () => {
this.showNodeProperties(node);
});
if (node.function) {
header.add(null, ['node-item-function'], '+', null, () => {
// debugger;
});
}
const initializers = [];
let hiddenInitializers = false;
if (this._showInitializers) {
for (const input of node.inputs) {
if (
input.visible &&
input.arguments.length == 1 &&
input.arguments[0].initializer != null
) {
initializers.push(input);
} }
if ( } else {
(!input.visible || input.arguments.length > 1) && for (const metadatas of this.non_graphMetadatas) {
input.arguments.some(argument => argument.initializer != null) if (node.type === metadatas.name) {
) { if (metadatas.schema.category) {
hiddenInitializers = true; styles.push('node-item-type-' + metadatas.schema.category.toLowerCase());
}
}
} }
} }
} if (typeof type !== 'string' || !type.split) {
let sortedAttributes = []; // #416
const attributes = node.attributes; throw new ModelError(
if (this.showAttributes && attributes) { "Unknown node type '" + JSON.stringify(type) + "' in '" + model.format + "'."
sortedAttributes = attributes.filter(attribute => attribute.visible).slice(); );
sortedAttributes.sort((a, b) => { }
const au = a.name.toUpperCase(); const nodeName = node.name.split('/')[node.name.split('/').length - 1];
const bu = b.name.toUpperCase(); const content = this.showNames && node.name ? nodeName : type.split('.').pop();
return au < bu ? -1 : au > bu ? 1 : 0; const tooltip = this.showNames && node.name ? type : nodeName;
}); header.add(null, styles, content, tooltip, () => {
} this.jumpRoute(node);
if (initializers.length > 0 || hiddenInitializers || sortedAttributes.length > 0) {
const block = element.block('list');
block.handler = () => {
this.showNodeProperties(node); this.showNodeProperties(node);
}; this.select({
for (const initializer of initializers) { id: `node-${node.name}`,
const argument = initializer.arguments[0]; name: node.name,
const type = argument.type; type: 'node'
let shape = ''; });
let separator = ''; const inputs = node.inputs;
if ( for (const input of inputs) {
type && for (const argument of input.arguments) {
type.shape && if (argument.name != '' && !argument.initializer) {
type.shape.dimensions && this.select2({
Object.prototype.hasOwnProperty.call(type.shape.dimensions, 'length') id: `edge-${argument.name}`,
) { name: argument.name,
shape = type: 'input',
'\u3008' + tonode: node.name
type.shape.dimensions.map(d => (d ? d : '?')).join('\u00D7') + });
'\u3009'; }
}
}
let outputs = node.outputs;
if (node.chain && node.chain.length > 0) {
const chainOutputs = node.chain[node.chain.length - 1].outputs;
if (chainOutputs.length > 0) {
outputs = chainOutputs;
}
}
for (const output of outputs) {
for (const argument of output.arguments) {
if (argument.name != '') {
this.select2({
id: `edge-${argument.name}`,
name: argument.name,
type: 'input',
fromnode: node.name
});
}
}
}
});
const buttons = ['node-item-buttons'];
if (!node.is_leaf) {
header.add(null, buttons, '+', null, () => {
// debugger;
this._host.selectNodeId({
nodeId: node.name,
expand: true
});
this._host.selectItems({
id: `node-${node.name}`,
name: node.name,
type: 'node'
});
});
}
const initializers = [];
let hiddenInitializers = false;
if (this._showInitializers) {
// 是否显示初始化参数
for (const input of node.inputs) {
if ( if (
type.shape.dimensions.length == 0 && input.visible &&
argument.initializer && input.arguments.length == 1 &&
!argument.initializer.state input.arguments[0].initializer != null
) { ) {
shape = argument.initializer.toString(); initializers.push(input);
if (shape && shape.length > 10) { }
shape = shape.substring(0, 10) + '\u2026'; if (
} (!input.visible || input.arguments.length > 1) &&
separator = ' = '; input.arguments.some(argument => argument.initializer != null)
) {
hiddenInitializers = true;
} }
} }
block.add(
'initializer-' + argument.name,
initializer.name,
shape,
type ? type.toString() : '',
separator
);
} }
if (hiddenInitializers) { let sortedAttributes = [];
block.add(null, '\u3008' + '\u2026' + '\u3009', '', null, ''); const attributes = node.attributes;
if (this.showAttributes && attributes) {
// 是否显示属性参数
sortedAttributes = attributes.filter(attribute => attribute.visible).slice();
sortedAttributes.sort((a, b) => {
const au = a.name.toUpperCase();
const bu = b.name.toUpperCase();
return au < bu ? -1 : au > bu ? 1 : 0;
});
} }
if (initializers.length > 0 || hiddenInitializers || sortedAttributes.length > 0) {
const block = element.block('list');
block.handler = () => {
// 侧边栏点击事件
this.jumpRoute(node);
this.showNodeProperties(node);
this.select({
id: `node-${node.name}`,
name: node.name,
type: 'node'
});
const inputs = node.inputs;
for (const input of inputs) {
for (const argument of input.arguments) {
if (argument.name != '' && !argument.initializer) {
this.select2({
id: `edge-${argument.name}`,
name: argument.name,
type: 'input',
tonode: node.name
});
}
}
}
for (const attribute of sortedAttributes) { let outputs = node.outputs;
if (attribute.visible) { if (node.chain && node.chain.length > 0) {
let attributeValue = sidebar.NodeSidebar.formatAttributeValue( const chainOutputs = node.chain[node.chain.length - 1].outputs;
attribute.value, if (chainOutputs.length > 0) {
attribute.type outputs = chainOutputs;
}
}
for (const output of outputs) {
for (const argument of output.arguments) {
if (argument.name != '') {
this.select2({
id: `edge-${argument.name}`,
name: argument.name,
type: 'input',
fromnode: node.name
});
}
}
}
};
for (const initializer of initializers) {
const argument = initializer.arguments[0];
const type = argument.type;
let shape = '';
let separator = '';
if (
type &&
type.shape &&
type.shape.dimensions &&
Object.prototype.hasOwnProperty.call(type.shape.dimensions, 'length')
) {
shape =
'\u3008' +
type.shape.dimensions.map(d => (d ? d : '?')).join('\u00D7') +
'\u3009';
if (
type.shape.dimensions.length == 0 &&
argument.initializer &&
!argument.initializer.state
) {
shape = argument.initializer.toString();
if (shape && shape.length > 10) {
shape = shape.substring(0, 10) + '\u2026';
}
separator = ' = ';
}
}
block.add(
'initializer-' + argument.name,
initializer.name,
shape,
type ? type.toString() : '',
separator
); );
if (attributeValue && attributeValue.length > 25) { }
attributeValue = attributeValue.substring(0, 25) + '\u2026'; if (hiddenInitializers) {
block.add(null, '\u3008' + '\u2026' + '\u3009', '', null, '');
}
for (const attribute of sortedAttributes) {
if (attribute.visible) {
let attributeValue = sidebar.NodeSidebar.formatAttributeValue(
attribute.value,
attribute.type
);
if (attributeValue && attributeValue.length > 25) {
attributeValue = attributeValue.substring(0, 25) + '\u2026';
}
block.add(null, attribute.name, attributeValue, attribute.type, ' = ');
} }
block.add(null, attribute.name, attributeValue, attribute.type, ' = ');
} }
} }
} }
if (edges) { if (edges) {
const inputs = node.inputs; const inputs = node.inputs;
for (const input of inputs) { for (const input of inputs) {
...@@ -458,12 +723,15 @@ view.View = class { ...@@ -458,12 +723,15 @@ view.View = class {
edgeMap[argument.name] = tuple; edgeMap[argument.name] = tuple;
} }
tuple.to.push({ tuple.to.push({
// 这个节点的id
node: nodeId, node: nodeId,
name: input.name name: input.name,
nodename: node.name
}); });
} }
} }
} }
let outputs = node.outputs; let outputs = node.outputs;
if (node.chain && node.chain.length > 0) { if (node.chain && node.chain.length > 0) {
const chainOutputs = node.chain[node.chain.length - 1].outputs; const chainOutputs = node.chain[node.chain.length - 1].outputs;
...@@ -482,13 +750,13 @@ view.View = class { ...@@ -482,13 +750,13 @@ view.View = class {
tuple.from = { tuple.from = {
node: nodeId, node: nodeId,
name: output.name, name: output.name,
type: argument.type type: argument.type,
nodename: node.name
}; };
} }
} }
} }
} }
if (node.chain && node.chain.length > 0) { if (node.chain && node.chain.length > 0) {
for (const innerNode of node.chain) { for (const innerNode of node.chain) {
addNode(element, innerNode, false); addNode(element, innerNode, false);
...@@ -512,33 +780,63 @@ view.View = class { ...@@ -512,33 +780,63 @@ view.View = class {
tuple.to.push({ tuple.to.push({
node: nodeId, node: nodeId,
name: controlDependency, name: controlDependency,
controlDependency: true controlDependency: true,
nodename: node.name
}); });
} }
} }
const nodeName = node.name; const nodeName = node.name;
if (nodeName) { if (!document.getElementById(`node-${node.name}`)) {
g.setNode(nodeId, {label: element.format(graphElement), id: 'node-' + nodeName}); // 此时图上没有
if (nodeName) {
g.setNode(nodeId, {label: element.format(graphElement), id: 'node-' + nodeName});
} else {
g.setNode(nodeId, {label: element.format(graphElement), id: 'node-' + id.toString()});
id++;
}
} else { } else {
g.setNode(nodeId, {label: element.format(graphElement), id: 'node-' + id.toString()}); g.setNode(nodeId, {label: 'node-' + nodeName, id: 'node-' + nodeName});
id++;
} }
const isKeepData = this._KeepData;
const createCluster = function (name) { const createCluster = (name, node) => {
const non_leaf_nodes = graphMetadata.default.non_leaf_nodes;
const styles = ['clusterGroup'];
const showName = node.show_name.split('/')[node.show_name.split('/').length - 1];
if (this._nodeName[name]) {
for (const non_leaf_node of non_leaf_nodes) {
if (this._nodeName[name].type === non_leaf_node.name) {
styles.push(`clusterGroup-${non_leaf_node.schema.category.toLowerCase()}`);
break;
}
}
}
let newStyle = ['clusterGroup'];
if (styles.length > 1) {
newStyle = ['clusterGroup', styles[1]];
}
if (!clusterMap[name]) { if (!clusterMap[name]) {
g.setNode(name, {rx: 5, ry: 5}); g.setNode(name, {
rx: 10,
ry: 10,
nodeId: name,
showName: showName,
expand: false,
classList: newStyle,
isKeepData: isKeepData
});
clusterMap[name] = true; clusterMap[name] = true;
const parent = clusterParentMap[name]; const parent = clusterParentMap[name]; // 父节点的父节点
if (parent) { if (parent) {
createCluster(parent); createCluster(parent, node);
g.setParent(name, parent); g.setParent(name, parent);
} }
} }
}; };
if (groups) { if (groups) {
let groupName = node.group; let path = node.name.split('/');
path.pop();
let groupName = path.join('/');
if (groupName && groupName.length > 0) { if (groupName && groupName.length > 0) {
if (!Object.prototype.hasOwnProperty.call(clusterParentMap, groupName)) { if (!Object.prototype.hasOwnProperty.call(clusterParentMap, groupName)) {
const lastIndex = groupName.lastIndexOf('/'); const lastIndex = groupName.lastIndexOf('/');
...@@ -552,15 +850,14 @@ view.View = class { ...@@ -552,15 +850,14 @@ view.View = class {
} }
} }
if (groupName) { if (groupName) {
createCluster(groupName); createCluster(groupName, node);
g.setParent(nodeId, groupName); g.setParent(nodeId, groupName);
} }
} }
} }
this._graphNodes[node.name] = element;
nodeId++; nodeId++;
} }
for (const input of graph.inputs) { for (const input of graph.inputs) {
for (const argument of input.arguments) { for (const argument of input.arguments) {
let tuple = edgeMap[argument.name]; let tuple = edgeMap[argument.name];
...@@ -586,7 +883,6 @@ view.View = class { ...@@ -586,7 +883,6 @@ view.View = class {
}); });
g.setNode(nodeId++, {label: inputElement.format(graphElement), class: 'graph-input'}); g.setNode(nodeId++, {label: inputElement.format(graphElement), class: 'graph-input'});
} }
for (const output of graph.outputs) { for (const output of graph.outputs) {
for (const argument of output.arguments) { for (const argument of output.arguments) {
let tuple = edgeMap[argument.name]; let tuple = edgeMap[argument.name];
...@@ -609,7 +905,6 @@ view.View = class { ...@@ -609,7 +905,6 @@ view.View = class {
}); });
g.setNode(nodeId++, {label: outputElement.format(graphElement)}); g.setNode(nodeId++, {label: outputElement.format(graphElement)});
} }
for (const edge of Object.keys(edgeMap)) { for (const edge of Object.keys(edgeMap)) {
const tuple = edgeMap[edge]; const tuple = edgeMap[edge];
if (tuple.from != null) { if (tuple.from != null) {
...@@ -629,21 +924,25 @@ view.View = class { ...@@ -629,21 +924,25 @@ view.View = class {
label: text, label: text,
id: 'edge-' + edge, id: 'edge-' + edge,
arrowhead: 'vee', arrowhead: 'vee',
class: 'edge-path-control-dependency' class: 'edge-path-control-dependency',
fromnode: tuple.from.nodename,
tonode: to.nodename
}); });
} else { } else {
g.setEdge(tuple.from.node, to.node, { g.setEdge(tuple.from.node, to.node, {
label: text, label: text,
id: 'edge-' + edge, id: 'edge-' + edge,
arrowhead: 'vee' arrowhead: 'vee',
fromnode: tuple.from.nodename,
tonode: to.nodename
}); });
} }
} }
} }
} }
// Workaround for Safari background drag/zoom issue: // Workaround for Safari background drag/zoom issue:
// https://stackoverflow.com/questions/40887193/d3-js-zoom-is-not-working-with-mousewheel-in-safari // https://stackoverflow.com/questions/40887193/d3-js-zoom-is-not-working-with-mousewheel-in-safari
// if (!this.secondChange) {
const backgroundElement = this._host.document.createElementNS('http://www.w3.org/2000/svg', 'rect'); const backgroundElement = this._host.document.createElementNS('http://www.w3.org/2000/svg', 'rect');
backgroundElement.setAttribute('id', 'background'); backgroundElement.setAttribute('id', 'background');
backgroundElement.setAttribute('width', '100%'); backgroundElement.setAttribute('width', '100%');
...@@ -651,52 +950,69 @@ view.View = class { ...@@ -651,52 +950,69 @@ view.View = class {
backgroundElement.setAttribute('fill', 'none'); backgroundElement.setAttribute('fill', 'none');
backgroundElement.setAttribute('pointer-events', 'all'); backgroundElement.setAttribute('pointer-events', 'all');
graphElement.appendChild(backgroundElement); graphElement.appendChild(backgroundElement);
const originElement = this._host.document.createElementNS('http://www.w3.org/2000/svg', 'g'); const originElement = this._host.document.createElementNS('http://www.w3.org/2000/svg', 'g');
originElement.setAttribute('id', 'origin'); originElement.setAttribute('id', 'origin');
graphElement.appendChild(originElement); graphElement.appendChild(originElement);
let svg = null; let svg = null;
svg = d3.select(graphElement); svg = d3.select(graphElement);
backgroundElement.addEventListener('click', () => {
this.clearSelection();
});
this._zoom = d3.zoom(); this._zoom = d3.zoom();
this._zoom(svg); this._zoom(svg);
this._zoom.scaleExtent([0.1, 2]); this._zoom.scaleExtent([0.01, 2]); // 缩放的范围
this._zoom.on('zoom', () => { this._zoom.on('zoom', () => {
originElement.setAttribute('transform', d3.event.transform.toString()); originElement.setAttribute('transform', d3.event.transform.toString());
}); });
this._zoom.transform(svg, d3.zoomIdentity); this._zoom.transform(svg, d3.zoomIdentity);
return this._timeout(20).then(() => { return this._timeout(200).then(() => {
const graphRenderer = new grapher.Renderer(this._host.document, originElement); const graphRenderer = new grapher.Renderer(this._host, originElement, this);
graphRenderer.render(g); graphRenderer.render(g);
for (const cluster of document.getElementById('clusters').children) {
this._clusters[cluster.getAttribute('id')] = cluster;
}
for (const node of document.getElementById('nodes').children) {
this._nodes[node.getAttribute('id')] = node;
}
const inputElements = graphElement.getElementsByClassName('graph-input'); const inputElements = graphElement.getElementsByClassName('graph-input');
const svgSize = graphElement.getBoundingClientRect(); const svgSize = graphElement.getBoundingClientRect();
if (inputElements && inputElements.length > 0) { if (inputElements && inputElements.length > 0) {
// Center view based on input elements // Center view based on input elements
const xs = []; if (this._selectItem) {
const ys = []; this.select(this._selectItem);
for (let i = 0; i < inputElements.length; i++) { } else {
const inputTransform = inputElements[i].transform.baseVal.consolidate().matrix; const xs = [];
xs.push(inputTransform.e); const ys = [];
ys.push(inputTransform.f); for (let i = 0; i < inputElements.length; i++) {
} const inputTransform = inputElements[i].transform.baseVal.consolidate().matrix;
let x = xs[0]; xs.push(inputTransform.e);
const y = ys[0]; ys.push(inputTransform.f);
if (ys.every(y => y == ys[0])) { }
x = xs.reduce((a, b) => a + b) / xs.length; let x = xs[0];
const y = ys[0];
if (ys.every(y => y == ys[0])) {
x = xs.reduce((a, b) => a + b) / xs.length;
}
const sx = svgSize.width / (this._showHorizontal ? 4 : 2) - x;
const sy = svgSize.height / (this._showHorizontal ? 2 : 4) - y;
this._zoom.transform(svg, d3.zoomIdentity.translate(sx, sy));
} }
const sx = svgSize.width / (this._showHorizontal ? 4 : 2) - x; // 这里应该触发一次小地图重定位
const sy = svgSize.height / (this._showHorizontal ? 2 : 4) - y;
this._zoom.transform(svg, d3.zoomIdentity.translate(sx, sy));
} else { } else {
this._zoom.transform( if (this._selectItem) {
svg, this.select(this._selectItem);
d3.zoomIdentity.translate( } else {
(svgSize.width - g.graph().width) / 2, this._zoom.transform(
(svgSize.height - g.graph().height) / 2 svg,
) d3.zoomIdentity.translate(
); (svgSize.width - g.graph().width) / 2,
(svgSize.height - g.graph().height) / 2
)
);
}
} }
return; return;
}); });
...@@ -705,7 +1021,6 @@ view.View = class { ...@@ -705,7 +1021,6 @@ view.View = class {
return Promise.reject(error); return Promise.reject(error);
} }
} }
applyStyleSheet(element, name) { applyStyleSheet(element, name) {
let rules = []; let rules = [];
for (let i = 0; i < this._host.document.styleSheets.length; i++) { for (let i = 0; i < this._host.document.styleSheets.length; i++) {
...@@ -807,6 +1122,7 @@ view.View = class { ...@@ -807,6 +1122,7 @@ view.View = class {
if (this._model) { if (this._model) {
const modelSidebar = new sidebar.ModelSidebar(this._host, this._model, this._activeGraph); const modelSidebar = new sidebar.ModelSidebar(this._host, this._model, this._activeGraph);
this._host.message('show-model-properties', modelSidebar.render()); this._host.message('show-model-properties', modelSidebar.render());
// 通信函数
} }
} }
......
/**
* Copyright 2020 Baidu Inc. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
// cSpell:words grapher selectall nodesep ranksep rankdir pbtxt
const zip = require('netron/src/zip');
const gzip = require('netron/src/gzip');
const tar = require('netron/src/tar');
const protobuf = require('netron/src/protobuf');
const d3 = require('d3');
const dagre = require('dagre');
const grapher = require('netron/src/view-grapher');
const sidebar = require('./sidebar');
const view = {};
view.View = class {
constructor(host) {
this._host = host;
this._host
.initialize(this)
.then(() => {
this._model = null;
this._selection = [];
this._host.start();
this._showAttributes = false;
this._showInitializers = true;
this._showNames = false;
this._showHorizontal = false;
this._modelFactoryService = new view.ModelFactoryService(this._host);
})
.catch(err => {
this.error(err.message, err);
});
}
cut() {
this._host.document.execCommand('cut');
}
copy() {
this._host.document.execCommand('copy');
}
paste() {
this._host.document.execCommand('paste');
}
selectAll() {
this._host.document.execCommand('selectall');
}
find(value) {
if (this._activeGraph) {
this.clearSelection();
const graphElement = document.getElementById('canvas');
const view = new sidebar.FindSidebar(this._host, graphElement, this._activeGraph);
this._host.message('search', view.update(value));
}
}
toggleAttributes(toggle) {
if (toggle != null && !(toggle ^ this._showAttributes)) {
return;
}
this._showAttributes = toggle == null ? !this._showAttributes : toggle;
this._reload();
}
get showAttributes() {
return this._showAttributes;
}
toggleInitializers(toggle) {
if (toggle != null && !(toggle ^ this._showInitializers)) {
return;
}
this._showInitializers = toggle == null ? !this._showInitializers : toggle;
this._reload();
}
get showInitializers() {
return this._showInitializers;
}
toggleNames(toggle) {
if (toggle != null && !(toggle ^ this._showNames)) {
return;
}
this._showNames = toggle == null ? !this._showNames : toggle;
this._reload();
}
get showNames() {
return this._showNames;
}
toggleDirection(toggle) {
if (toggle != null && !(toggle ^ this._showHorizontal)) {
return;
}
this._showHorizontal = toggle == null ? !this._showHorizontal : toggle;
this._reload();
}
get showHorizontal() {
return this._showHorizontal;
}
toggleTheme(theme) {
this._host.document.body.className = theme;
}
_reload() {
this._host.status('loading');
if (this._model && this._activeGraph) {
this._updateGraph(this._model, this._activeGraph).catch(error => {
if (error) {
this.error('Graph update failed.', error);
}
});
}
}
_timeout(time) {
return new Promise(resolve => {
setTimeout(() => {
resolve();
}, time);
});
}
zoomIn() {
if (this._zoom) {
this._zoom.scaleBy(d3.select(this._host.document.getElementById('canvas')), 1.2);
}
}
zoomOut() {
if (this._zoom) {
this._zoom.scaleBy(d3.select(this._host.document.getElementById('canvas')), 0.8);
}
}
resetZoom() {
if (this._zoom) {
this._zoom.scaleTo(d3.select(this._host.document.getElementById('canvas')), 1);
}
}
select(item) {
this.clearSelection();
const graphElement = document.getElementById('canvas');
const selection = sidebar.FindSidebar.selection(item, graphElement);
if (selection && selection.length > 0) {
const graphElement = this._host.document.getElementById('canvas');
const graphRect = graphElement.getBoundingClientRect();
let x = 0;
let y = 0;
for (const element of selection) {
element.classList.add('select');
this._selection.push(element);
const transform = element.transform.baseVal.consolidate();
const box = element.getBBox();
const ex = transform ? transform.matrix.e : box.x + box.width / 2;
const ey = transform ? transform.matrix.f : box.y + box.height / 2;
x += ex;
y += ey;
}
x = x / selection.length;
y = y / selection.length;
this._zoom.transform(
d3.select(graphElement),
d3.zoomIdentity.translate(graphRect.width / 2 - x, graphRect.height / 2 - y)
);
}
}
clearSelection() {
while (this._selection.length > 0) {
const element = this._selection.pop();
element.classList.remove('select');
}
}
error(message, err) {
this._host.error(message, err.toString());
}
accept(file) {
return this._modelFactoryService.accept(file);
}
open(context) {
return this._timeout(2).then(() => {
return this._modelFactoryService.open(context).then(model => {
return this._timeout(20).then(() => {
const graph = model.graphs.length > 0 ? model.graphs[0] : null;
this._host.message('opened', {
graphs: model.graphs.map(g => g.name || ''),
selected: graph && (graph.name || '')
});
return this._updateGraph(model, graph);
});
});
});
}
changeGraph(name) {
this._updateActiveGraph(name);
}
_updateActiveGraph(name) {
if (this._model) {
const model = this._model;
const graph = model.graphs.filter(graph => name == graph.name).shift();
if (graph) {
this._host.status('loading');
this._timeout(200).then(() => {
return this._updateGraph(model, graph).catch(error => {
if (error) {
this.error('Graph update failed.', error);
}
});
});
}
}
}
_updateGraph(model, graph) {
return this._timeout(100).then(() => {
if (graph && graph != this._activeGraph) {
const nodes = graph.nodes;
if (nodes.length > 1400) {
if (
!this._host.confirm(
'Large model detected.',
'This graph contains a large number of nodes and might take a long time to render. Do you want to continue?'
)
) {
return null;
}
}
}
return this.renderGraph(model, graph)
.then(() => {
this._model = model;
this._activeGraph = graph;
this._host.status('rendered');
return this._model;
})
.catch(error => {
return this.renderGraph(this._model, this._activeGraph)
.then(() => {
this._host.status('rendered');
throw error;
})
.catch(() => {
throw error;
});
});
});
}
renderGraph(model, graph) {
try {
const graphElement = this._host.document.getElementById('canvas');
while (graphElement.lastChild) {
graphElement.removeChild(graphElement.lastChild);
}
if (!graph) {
return Promise.resolve();
} else {
this._zoom = null;
graphElement.style.position = 'absolute';
graphElement.style.margin = '0';
const groups = graph.groups;
const graphOptions = {};
graphOptions.nodesep = 25;
graphOptions.ranksep = 20;
if (this._showHorizontal) {
graphOptions.rankdir = 'LR';
}
const g = new dagre.graphlib.Graph({compound: groups});
g.setGraph(graphOptions);
g.setDefaultEdgeLabel(() => {
return {};
});
let nodeId = 0;
const edgeMap = {};
const clusterMap = {};
const clusterParentMap = {};
let id = new Date().getTime();
const nodes = graph.nodes;
if (nodes.length > 1500) {
graphOptions.ranker = 'longest-path';
}
if (groups) {
for (const node of nodes) {
if (node.group) {
const path = node.group.split('/');
while (path.length > 0) {
const name = path.join('/');
path.pop();
clusterParentMap[name] = path.join('/');
}
}
}
}
for (const node of nodes) {
const element = new grapher.NodeElement(this._host.document);
const addNode = (element, node, edges) => {
const header = element.block('header');
const styles = ['node-item-type'];
const metadata = node.metadata;
const category = metadata && metadata.category ? metadata.category : '';
if (category) {
styles.push('node-item-type-' + category.toLowerCase());
}
const type = node.type;
if (typeof type !== 'string' || !type.split) {
// #416
throw new ModelError(
"Unknown node type '" + JSON.stringify(type) + "' in '" + model.format + "'."
);
}
const content = this.showNames && node.name ? node.name : type.split('.').pop();
const tooltip = this.showNames && node.name ? type : node.name;
header.add(null, styles, content, tooltip, () => {
this.showNodeProperties(node);
});
if (node.function) {
header.add(null, ['node-item-function'], '+', null, () => {
// debugger;
});
}
const initializers = [];
let hiddenInitializers = false;
if (this._showInitializers) {
for (const input of node.inputs) {
if (
input.visible &&
input.arguments.length == 1 &&
input.arguments[0].initializer != null
) {
initializers.push(input);
}
if (
(!input.visible || input.arguments.length > 1) &&
input.arguments.some(argument => argument.initializer != null)
) {
hiddenInitializers = true;
}
}
}
let sortedAttributes = [];
const attributes = node.attributes;
if (this.showAttributes && attributes) {
sortedAttributes = attributes.filter(attribute => attribute.visible).slice();
sortedAttributes.sort((a, b) => {
const au = a.name.toUpperCase();
const bu = b.name.toUpperCase();
return au < bu ? -1 : au > bu ? 1 : 0;
});
}
if (initializers.length > 0 || hiddenInitializers || sortedAttributes.length > 0) {
const block = element.block('list');
block.handler = () => {
this.showNodeProperties(node);
};
for (const initializer of initializers) {
const argument = initializer.arguments[0];
const type = argument.type;
let shape = '';
let separator = '';
if (
type &&
type.shape &&
type.shape.dimensions &&
Object.prototype.hasOwnProperty.call(type.shape.dimensions, 'length')
) {
shape =
'\u3008' +
type.shape.dimensions.map(d => (d ? d : '?')).join('\u00D7') +
'\u3009';
if (
type.shape.dimensions.length == 0 &&
argument.initializer &&
!argument.initializer.state
) {
shape = argument.initializer.toString();
if (shape && shape.length > 10) {
shape = shape.substring(0, 10) + '\u2026';
}
separator = ' = ';
}
}
block.add(
'initializer-' + argument.name,
initializer.name,
shape,
type ? type.toString() : '',
separator
);
}
if (hiddenInitializers) {
block.add(null, '\u3008' + '\u2026' + '\u3009', '', null, '');
}
for (const attribute of sortedAttributes) {
if (attribute.visible) {
let attributeValue = sidebar.NodeSidebar.formatAttributeValue(
attribute.value,
attribute.type
);
if (attributeValue && attributeValue.length > 25) {
attributeValue = attributeValue.substring(0, 25) + '\u2026';
}
block.add(null, attribute.name, attributeValue, attribute.type, ' = ');
}
}
}
if (edges) {
const inputs = node.inputs;
for (const input of inputs) {
for (const argument of input.arguments) {
if (argument.name != '' && !argument.initializer) {
let tuple = edgeMap[argument.name];
if (!tuple) {
tuple = {from: null, to: []};
edgeMap[argument.name] = tuple;
}
tuple.to.push({
node: nodeId,
name: input.name
});
}
}
}
let outputs = node.outputs;
if (node.chain && node.chain.length > 0) {
const chainOutputs = node.chain[node.chain.length - 1].outputs;
if (chainOutputs.length > 0) {
outputs = chainOutputs;
}
}
for (const output of outputs) {
for (const argument of output.arguments) {
if (argument.name != '') {
let tuple = edgeMap[argument.name];
if (!tuple) {
tuple = {from: null, to: []};
edgeMap[argument.name] = tuple;
}
tuple.from = {
node: nodeId,
name: output.name,
type: argument.type
};
}
}
}
}
if (node.chain && node.chain.length > 0) {
for (const innerNode of node.chain) {
addNode(element, innerNode, false);
}
}
if (node.inner) {
addNode(element, node.inner, false);
}
};
addNode(element, node, true);
if (node.controlDependencies && node.controlDependencies.length > 0) {
for (const controlDependency of node.controlDependencies) {
let tuple = edgeMap[controlDependency];
if (!tuple) {
tuple = {from: null, to: []};
edgeMap[controlDependency] = tuple;
}
tuple.to.push({
node: nodeId,
name: controlDependency,
controlDependency: true
});
}
}
const nodeName = node.name;
if (nodeName) {
g.setNode(nodeId, {label: element.format(graphElement), id: 'node-' + nodeName});
} else {
g.setNode(nodeId, {label: element.format(graphElement), id: 'node-' + id.toString()});
id++;
}
const createCluster = function (name) {
if (!clusterMap[name]) {
g.setNode(name, {rx: 5, ry: 5});
clusterMap[name] = true;
const parent = clusterParentMap[name];
if (parent) {
createCluster(parent);
g.setParent(name, parent);
}
}
};
if (groups) {
let groupName = node.group;
if (groupName && groupName.length > 0) {
if (!Object.prototype.hasOwnProperty.call(clusterParentMap, groupName)) {
const lastIndex = groupName.lastIndexOf('/');
if (lastIndex != -1) {
groupName = groupName.substring(0, lastIndex);
if (!Object.prototype.hasOwnProperty.call(clusterParentMap, groupName)) {
groupName = null;
}
} else {
groupName = null;
}
}
if (groupName) {
createCluster(groupName);
g.setParent(nodeId, groupName);
}
}
}
nodeId++;
}
for (const input of graph.inputs) {
for (const argument of input.arguments) {
let tuple = edgeMap[argument.name];
if (!tuple) {
tuple = {from: null, to: []};
edgeMap[argument.name] = tuple;
}
tuple.from = {
node: nodeId,
type: argument.type
};
}
const types = input.arguments.map(argument => argument.type || '').join('\n');
let inputName = input.name || '';
if (inputName.length > 16) {
inputName = inputName.split('/').pop();
}
const inputElement = new grapher.NodeElement(this._host.document);
const inputHeader = inputElement.block('header');
inputHeader.add(null, ['graph-item-input'], inputName, types, () => {
this.showModelProperties();
});
g.setNode(nodeId++, {label: inputElement.format(graphElement), class: 'graph-input'});
}
for (const output of graph.outputs) {
for (const argument of output.arguments) {
let tuple = edgeMap[argument.name];
if (!tuple) {
tuple = {from: null, to: []};
edgeMap[argument.name] = tuple;
}
tuple.to.push({node: nodeId});
}
const outputTypes = output.arguments.map(argument => argument.type || '').join('\n');
let outputName = output.name || '';
if (outputName.length > 16) {
outputName = outputName.split('/').pop();
}
const outputElement = new grapher.NodeElement(this._host.document);
const outputHeader = outputElement.block('header');
outputHeader.add(null, ['graph-item-output'], outputName, outputTypes, () => {
this.showModelProperties();
});
g.setNode(nodeId++, {label: outputElement.format(graphElement)});
}
for (const edge of Object.keys(edgeMap)) {
const tuple = edgeMap[edge];
if (tuple.from != null) {
for (const to of tuple.to) {
let text = '';
const type = tuple.from.type;
if (type && type.shape && type.shape.dimensions && type.shape.dimensions.length > 0) {
text = type.shape.dimensions.join('\u00D7');
}
if (this._showNames) {
text = edge.split('\n').shift(); // custom argument id
}
if (to.controlDependency) {
g.setEdge(tuple.from.node, to.node, {
label: text,
id: 'edge-' + edge,
arrowhead: 'vee',
class: 'edge-path-control-dependency'
});
} else {
g.setEdge(tuple.from.node, to.node, {
label: text,
id: 'edge-' + edge,
arrowhead: 'vee'
});
}
}
}
}
// Workaround for Safari background drag/zoom issue:
// https://stackoverflow.com/questions/40887193/d3-js-zoom-is-not-working-with-mousewheel-in-safari
const backgroundElement = this._host.document.createElementNS('http://www.w3.org/2000/svg', 'rect');
backgroundElement.setAttribute('id', 'background');
backgroundElement.setAttribute('width', '100%');
backgroundElement.setAttribute('height', '100%');
backgroundElement.setAttribute('fill', 'none');
backgroundElement.setAttribute('pointer-events', 'all');
graphElement.appendChild(backgroundElement);
const originElement = this._host.document.createElementNS('http://www.w3.org/2000/svg', 'g');
originElement.setAttribute('id', 'origin');
graphElement.appendChild(originElement);
let svg = null;
svg = d3.select(graphElement);
this._zoom = d3.zoom();
this._zoom(svg);
this._zoom.scaleExtent([0.1, 2]);
this._zoom.on('zoom', () => {
originElement.setAttribute('transform', d3.event.transform.toString());
});
this._zoom.transform(svg, d3.zoomIdentity);
return this._timeout(20).then(() => {
const graphRenderer = new grapher.Renderer(this._host.document, originElement);
graphRenderer.render(g);
const inputElements = graphElement.getElementsByClassName('graph-input');
const svgSize = graphElement.getBoundingClientRect();
if (inputElements && inputElements.length > 0) {
// Center view based on input elements
const xs = [];
const ys = [];
for (let i = 0; i < inputElements.length; i++) {
const inputTransform = inputElements[i].transform.baseVal.consolidate().matrix;
xs.push(inputTransform.e);
ys.push(inputTransform.f);
}
let x = xs[0];
const y = ys[0];
if (ys.every(y => y == ys[0])) {
x = xs.reduce((a, b) => a + b) / xs.length;
}
const sx = svgSize.width / (this._showHorizontal ? 4 : 2) - x;
const sy = svgSize.height / (this._showHorizontal ? 2 : 4) - y;
this._zoom.transform(svg, d3.zoomIdentity.translate(sx, sy));
} else {
this._zoom.transform(
svg,
d3.zoomIdentity.translate(
(svgSize.width - g.graph().width) / 2,
(svgSize.height - g.graph().height) / 2
)
);
}
return;
});
}
} catch (error) {
return Promise.reject(error);
}
}
applyStyleSheet(element, name) {
let rules = [];
for (let i = 0; i < this._host.document.styleSheets.length; i++) {
const styleSheet = this._host.document.styleSheets[i];
if (styleSheet && styleSheet.href && styleSheet.href.endsWith('/' + name)) {
rules = styleSheet.cssRules;
break;
}
}
const nodes = element.getElementsByTagName('*');
for (let j = 0; j < nodes.length; j++) {
const node = nodes[j];
for (let k = 0; k < rules.length; k++) {
const rule = rules[k];
if (node.matches(rule.selectorText)) {
for (let l = 0; l < rule.style.length; l++) {
const item = rule.style.item(l);
node.style[item] = rule.style[item];
}
}
}
}
}
export(file) {
const lastIndex = file.lastIndexOf('.');
const extension = lastIndex != -1 ? file.substring(lastIndex + 1) : '';
if (this._activeGraph && (extension == 'png' || extension == 'svg')) {
const graphElement = this._host.document.getElementById('canvas');
const exportElement = graphElement.cloneNode(true);
this.applyStyleSheet(exportElement, 'style.css');
exportElement.setAttribute('id', 'export');
exportElement.removeAttribute('width');
exportElement.removeAttribute('height');
exportElement.style.removeProperty('opacity');
exportElement.style.removeProperty('display');
const backgroundElement = exportElement.querySelector('#background');
const originElement = exportElement.querySelector('#origin');
originElement.setAttribute('transform', 'translate(0,0) scale(1)');
backgroundElement.removeAttribute('width');
backgroundElement.removeAttribute('height');
const parentElement = graphElement.parentElement;
parentElement.insertBefore(exportElement, graphElement);
const size = exportElement.getBBox();
parentElement.removeChild(exportElement);
parentElement.removeChild(graphElement);
parentElement.appendChild(graphElement);
const delta = (Math.min(size.width, size.height) / 2.0) * 0.1;
const width = Math.ceil(delta + size.width + delta);
const height = Math.ceil(delta + size.height + delta);
originElement.setAttribute(
'transform',
'translate(' + delta.toString() + ', ' + delta.toString() + ') scale(1)'
);
exportElement.setAttribute('width', width);
exportElement.setAttribute('height', height);
backgroundElement.setAttribute('width', width);
backgroundElement.setAttribute('height', height);
backgroundElement.setAttribute('fill', '#fff');
const data = new XMLSerializer().serializeToString(exportElement);
if (extension === 'svg') {
const blob = new Blob([data], {type: 'image/svg'});
this._host.export(file, blob);
} else if (extension === 'png') {
const imageElement = new Image();
imageElement.onload = () => {
const max = Math.max(width, height);
const scale = max * 2.0 > 24000 ? 24000.0 / max : 2.0;
const canvas = this._host.document.createElement('canvas');
canvas.width = Math.ceil(width * scale);
canvas.height = Math.ceil(height * scale);
const context = canvas.getContext('2d');
context.scale(scale, scale);
context.drawImage(imageElement, 0, 0);
this._host.document.body.removeChild(imageElement);
canvas.toBlob(blob => {
if (blob) {
this._host.export(file, blob);
} else {
const err = new Error();
err.name = 'Error exporting image.';
err.message = 'Image may be too large to render as PNG.';
this._host.exception(err, false);
this._host.error(err.name, err.message);
}
}, 'image/png');
};
imageElement.src = 'data:image/svg+xml;base64,' + window.btoa(unescape(encodeURIComponent(data)));
this._host.document.body.insertBefore(imageElement, this._host.document.body.firstChild);
}
}
}
showModelProperties() {
if (this._model) {
const modelSidebar = new sidebar.ModelSidebar(this._host, this._model, this._activeGraph);
this._host.message('show-model-properties', modelSidebar.render());
}
}
showNodeProperties(node) {
if (node) {
const nodeSidebar = new sidebar.NodeSidebar(this._host, node);
// TODO: export
// nodeSidebar.on('export-tensor', (sender, tensor) => {
// this._host
// .require('./numpy')
// .then(numpy => {
// const defaultPath = tensor.name
// ? tensor.name.split('/').join('_').split(':').join('_').split('.').join('_')
// : 'tensor';
// this._host.save('NumPy Array', 'npy', defaultPath, file => {
// try {
// const dataTypeMap = new Map([
// ['float16', 'f2'],
// ['float32', 'f4'],
// ['float64', 'f8'],
// ['int8', 'i1'],
// ['int16', 'i2'],
// ['int32', 'i4'],
// ['int64', 'i8'],
// ['uint8', 'u1'],
// ['uint16', 'u2'],
// ['uint32', 'u4'],
// ['uint64', 'u8'],
// ['qint8', 'i1'],
// ['qint16', 'i2'],
// ['quint8', 'u1'],
// ['quint16', 'u2']
// ]);
// const array = new numpy.Array();
// array.shape = tensor.type.shape.dimensions;
// array.data = tensor.value;
// array.dataType = dataTypeMap.has(tensor.type.dataType)
// ? dataTypeMap.get(tensor.type.dataType)
// : tensor.type.dataType;
// const blob = new Blob([array.toBuffer()], {type: 'application/octet-stream'});
// this._host.export(file, blob);
// } catch (error) {
// this.error('Error saving NumPy tensor.', error);
// }
// });
// })
// .catch(() => {});
// });
this._host.message('show-node-properties', {...nodeSidebar.render(), metadata: node.metadata});
}
}
showNodeDocumentation(node) {
const metadata = node.metadata;
if (metadata) {
const documentationSidebar = new sidebar.DocumentationSidebar(this._host, metadata);
this._host.message('show-node-documentation', documentationSidebar.render());
}
}
};
class ModelError extends Error {
constructor(message, telemetry) {
super(message);
this.name = 'Error loading model.';
this.telemetry = telemetry;
}
}
class ModelContext {
constructor(context) {
this._context = context;
this._tags = new Map();
this._entries = new Map();
}
request(file, encoding) {
return this._context.request(file, encoding);
}
get identifier() {
return this._context.identifier;
}
get buffer() {
return this._context.buffer;
}
get text() {
if (!this._text) {
this._text = new TextDecoder('utf-8').decode(this.buffer);
}
return this._text;
}
entries(extension) {
let entries = this._entries.get(extension);
if (!entries) {
entries = [];
try {
const buffer = this.buffer;
switch (extension) {
case 'zip': {
if (buffer && buffer.length > 2 && buffer[0] == 0x50 && buffer[1] == 0x4b) {
entries = new zip.Archive(buffer).entries;
}
break;
}
case 'tar': {
if (buffer.length >= 512) {
let sum = 0;
for (let i = 0; i < 512; i++) {
sum += i >= 148 && i < 156 ? 32 : buffer[i];
}
let checksum = '';
for (let i = 148; i < 156 && buffer[i] !== 0x00; i++) {
checksum += String.fromCharCode(buffer[i]);
}
checksum = parseInt(checksum, 8);
if (!isNaN(checksum) && sum == checksum) {
entries = new tar.Archive(buffer).entries;
}
}
break;
}
}
} catch (error) {
entries = [];
}
this._entries.set(extension, entries);
}
return entries;
}
tags(extension) {
let tags = this._tags.get(extension);
if (!tags) {
tags = new Map();
try {
switch (extension) {
case 'pbtxt': {
const b = this.buffer;
const length = b.length;
const signature =
(length >= 3 && b[0] === 0xef && b[1] === 0xbb && b[2] === 0xbf) ||
(length >= 4 && b[0] === 0x00 && b[1] === 0x00 && b[2] === 0xfe && b[3] === 0xff) ||
(length >= 4 && b[0] === 0xff && b[1] === 0xfe && b[2] === 0x00 && b[3] === 0x00) ||
(length >= 4 && b[0] === 0x84 && b[1] === 0x31 && b[2] === 0x95 && b[3] === 0x33) ||
(length >= 2 && b[0] === 0xfe && b[1] === 0xff) ||
(length >= 2 && b[0] === 0xff && b[1] === 0xfe);
if (
!signature &&
b.subarray(0, Math.min(1024, length)).some(c => c < 7 || (c > 14 && c < 32))
) {
break;
}
const reader = protobuf.TextReader.create(this.text);
reader.start(false);
while (!reader.end(false)) {
const tag = reader.tag();
tags.set(tag, true);
reader.skip();
}
break;
}
case 'pb': {
const tagTypes = new Set([0, 1, 2, 3, 5]);
const reader = protobuf.Reader.create(this.buffer);
const end = reader.next();
while (reader.pos < end) {
const tagType = reader.uint32();
tags.set(tagType >>> 3, tagType & 7);
if (!tagTypes.has(tagType & 7)) {
tags = new Map();
break;
}
try {
reader.skipType(tagType & 7);
} catch (err) {
tags = new Map();
break;
}
}
break;
}
}
} catch (error) {
tags = new Map();
}
this._tags.set(extension, tags);
}
return tags;
}
}
class ArchiveContext {
constructor(entries, rootFolder, identifier, buffer) {
this._entries = {};
if (entries) {
for (const entry of entries) {
if (entry.name.startsWith(rootFolder)) {
const name = entry.name.substring(rootFolder.length);
if (identifier.length > 0 && identifier.indexOf('/') < 0) {
this._entries[name] = entry;
}
}
}
}
this._identifier = identifier.substring(rootFolder.length);
this._buffer = buffer;
}
request(file, encoding) {
const entry = this._entries[file];
if (!entry) {
return Promise.reject(new Error('File not found.'));
}
const data = encoding ? new TextDecoder(encoding).decode(entry.data) : entry.data;
return Promise.resolve(data);
}
get identifier() {
return this._identifier;
}
get buffer() {
return this._buffer;
}
}
class ArchiveError extends Error {
constructor(message) {
super(message);
this.name = 'Error loading archive.';
}
}
view.ModelFactoryService = class {
constructor(host) {
this._host = host;
this._extensions = [];
this.register('./onnx', ['.onnx', '.pb', '.pbtxt', '.prototxt']);
this.register('./mxnet', ['.mar', '.model', '.json', '.params']);
this.register('./keras', ['.h5', '.hd5', '.hdf5', '.keras', '.json', '.model', '.pb', '.pth']);
this.register('./coreml', ['.mlmodel']);
this.register('./caffe', ['.caffemodel', '.pbtxt', '.prototxt', '.pt']);
this.register('./caffe2', ['.pb', '.pbtxt', '.prototxt']);
this.register('./pytorch', [
'.pt',
'.pth',
'.pt1',
'.pkl',
'.h5',
'.t7',
'.model',
'.dms',
'.tar',
'.ckpt',
'.bin',
'.pb',
'.zip'
]);
this.register('./torch', ['.t7']);
this.register('./tflite', ['.tflite', '.lite', '.tfl', '.bin', '.pb', '.tmfile', '.h5', '.model', '.json']);
this.register('./tf', ['.pb', '.meta', '.pbtxt', '.prototxt', '.json', '.index', '.ckpt']);
this.register('./mediapipe', ['.pbtxt']);
this.register('./uff', ['.uff', '.pb', '.trt', '.pbtxt', '.uff.txt']);
this.register('./sklearn', ['.pkl', '.pickle', '.joblib', '.model', '.meta', '.pb', '.pt', '.h5']);
this.register('./cntk', ['.model', '.cntk', '.cmf', '.dnn']);
this.register('./paddle', ['.paddle', '.pdmodel', '__model__']);
this.register('./armnn', ['.armnn']);
this.register('./bigdl', ['.model', '.bigdl']);
this.register('./darknet', ['.cfg', '.model']);
this.register('./mnn', ['.mnn']);
this.register('./ncnn', ['.param', '.bin', '.cfg.ncnn', '.weights.ncnn']);
this.register('./tnn', ['.tnnproto', '.tnnmodel']);
this.register('./tengine', ['.tmfile']);
this.register('./barracuda', ['.nn']);
this.register('./openvino', ['.xml', '.bin']);
this.register('./flux', ['.bson']);
this.register('./npz', ['.npz', '.h5', '.hd5', '.hdf5']);
this.register('./dl4j', ['.zip']);
this.register('./mlnet', ['.zip']);
}
register(id, extensions) {
for (const extension of extensions) {
this._extensions.push({extension: extension, id: id});
}
}
open(context) {
return this._openSignature(context).then(context => {
return this._openArchive(context).then(context => {
context = new ModelContext(context);
const identifier = context.identifier;
const extension = identifier.split('.').pop().toLowerCase();
const modules = this._filter(context);
if (modules.length == 0) {
throw new ModelError("Unsupported file extension '." + extension + "'.");
}
const errors = [];
let match = false;
const nextModule = () => {
if (modules.length > 0) {
const id = modules.shift();
return this._host.require(id).then(module => {
if (!module.ModelFactory) {
throw new ModelError("Failed to load module '" + id + "'.");
}
const modelFactory = new module.ModelFactory();
if (!modelFactory.match(context)) {
return nextModule();
}
match++;
return modelFactory
.open(context, this._host)
.then(model => {
return model;
})
.catch(error => {
errors.push(error);
return nextModule();
});
});
} else {
if (match) {
if (errors.length == 1) {
throw errors[0];
}
throw new ModelError(errors.map(err => err.message).join('\n'));
}
const knownUnsupportedIdentifiers = new Set([
'natives_blob.bin',
'v8_context_snapshot.bin',
'snapshot_blob.bin',
'image_net_labels.json',
'package.json',
'models.json',
'LICENSE.meta',
'input_0.pb',
'output_0.pb'
]);
const skip = knownUnsupportedIdentifiers.has(identifier);
const buffer = context.buffer;
const content = Array.from(buffer.subarray(0, Math.min(16, buffer.length)))
.map(c => (c < 16 ? '0' : '') + c.toString(16))
.join('');
throw new ModelError(
'Unsupported file content (' +
content +
") for extension '." +
extension +
"' in '" +
identifier +
"'.",
!skip
);
}
};
return nextModule();
});
});
}
_openArchive(context) {
let archive = null;
let extension;
let identifier = context.identifier;
let buffer = context.buffer;
try {
extension = identifier.split('.').pop().toLowerCase();
if (extension == 'gz' || extension == 'tgz') {
archive = new gzip.Archive(buffer);
if (archive.entries.length == 1) {
const entry = archive.entries[0];
if (entry.name) {
identifier = entry.name;
} else {
identifier = identifier.substring(0, identifier.lastIndexOf('.'));
if (extension == 'tgz') {
identifier += '.tar';
}
}
buffer = entry.data;
}
}
} catch (error) {
const message = error && error.message ? error.message : error.toString();
return Promise.reject(new ArchiveError(message.replace(/\.$/, '') + " in '" + identifier + "'."));
}
try {
extension = identifier.split('.').pop().toLowerCase();
switch (extension) {
case 'tar': {
// handle .pth.tar
const torch = [0x8a, 0x0a, 0x6c, 0xfc, 0x9c, 0x46, 0xf9, 0x20, 0x6a, 0xa8, 0x50, 0x19];
if (
!buffer ||
buffer.length < 14 ||
buffer[0] != 0x80 ||
!torch.every((v, i) => v == buffer[i + 2])
) {
archive = new tar.Archive(buffer);
}
break;
}
case 'zip': {
archive = new zip.Archive(buffer);
// PyTorch Zip archive
if (
archive.entries.some(e => e.name.split('/').pop().split('\\').pop() === 'version') &&
archive.entries.some(e => e.name.split('/').pop().split('\\').pop() === 'data.pkl')
) {
return Promise.resolve(context);
}
// dl4j
if (
archive.entries.some(e => e.name.split('/').pop().split('\\').pop() === 'coefficients.bin') &&
archive.entries.some(e => e.name.split('/').pop().split('\\').pop() === 'configuration.json')
) {
return Promise.resolve(context);
}
break;
}
}
} catch (error) {
const message = error && error.message ? error.message : error.toString();
return Promise.reject(new ArchiveError(message.replace(/\.$/, '') + " in '" + identifier + "'."));
}
if (!archive) {
return Promise.resolve(context);
}
try {
const folders = {};
for (const entry of archive.entries) {
if (entry.name.indexOf('/') != -1) {
folders[entry.name.split('/').shift() + '/'] = true;
} else {
folders['/'] = true;
}
}
if (extension == 'tar') {
delete folders['PaxHeader/'];
}
let rootFolder = Object.keys(folders).length == 1 ? Object.keys(folders)[0] : '';
rootFolder = rootFolder == '/' ? '' : rootFolder;
let matches = [];
const entries = archive.entries.slice();
const nextEntry = () => {
if (entries.length > 0) {
const entry = entries.shift();
if (entry.name.startsWith(rootFolder)) {
const identifier = entry.name.substring(rootFolder.length);
if (identifier.length > 0 && identifier.indexOf('/') < 0 && !identifier.startsWith('.')) {
const context = new ModelContext(
new ArchiveContext(null, rootFolder, entry.name, entry.data)
);
let modules = this._filter(context);
const nextModule = () => {
if (modules.length > 0) {
const id = modules.shift();
return this._host.require(id).then(module => {
if (!module.ModelFactory) {
throw new ArchiveError("Failed to load module '" + id + "'.", null);
}
const factory = new module.ModelFactory();
if (factory.match(context)) {
matches.push(entry);
modules = [];
}
return nextModule();
});
} else {
return nextEntry();
}
};
return nextModule();
}
}
return nextEntry();
} else {
if (matches.length == 0) {
return Promise.resolve(context);
}
// MXNet
if (
matches.length == 2 &&
matches.some(e => e.name.endsWith('.params')) &&
matches.some(e => e.name.endsWith('-symbol.json'))
) {
matches = matches.filter(e => e.name.endsWith('.params'));
}
if (matches.length > 1) {
return Promise.reject(new ArchiveError('Archive contains multiple model files.'));
}
const match = matches[0];
return Promise.resolve(
new ModelContext(new ArchiveContext(archive.entries, rootFolder, match.name, match.data))
);
}
};
return nextEntry();
} catch (error) {
return Promise.reject(new ArchiveError(error.message));
}
}
accept(identifier) {
identifier = identifier.toLowerCase();
for (const extension of this._extensions) {
if (identifier.endsWith(extension.extension)) {
return true;
}
}
if (
identifier.endsWith('.zip') ||
identifier.endsWith('.tar') ||
identifier.endsWith('.tar.gz') ||
identifier.endsWith('.tgz')
) {
return true;
}
return false;
}
_filter(context) {
const identifier = context.identifier.toLowerCase();
const list = this._extensions.filter(entry => identifier.endsWith(entry.extension)).map(entry => entry.id);
return Array.from(new Set(list));
}
_openSignature(context) {
const buffer = context.buffer;
if (context.buffer.length === 0) {
return Promise.reject(new ModelError('File has no content.', true));
}
const list = [
// cSpell:disable
{name: 'ELF executable', value: /^\x7FELF/},
{name: 'Git LFS header', value: /^version https:\/\/git-lfs.github.com\/spec\/v1\n/},
{name: 'Git LFS header', value: /^oid sha256:/},
{name: 'HTML markup', value: /^\s*<html>/},
{name: 'HTML markup', value: /^\s*<!DOCTYPE html>/},
{name: 'HTML markup', value: /^\s*<!DOCTYPE HTML>/},
{name: 'Unity metadata', value: /^fileFormatVersion:/},
{name: 'Vulkan SwiftShader ICD manifest', value: /^{\s*"file_format_version":\s*"1.0.0"\s*,\s*"ICD":/},
{name: 'StringIntLabelMapProto data', value: /^item\s*{\r?\n\s*id:/},
{name: 'StringIntLabelMapProto data', value: /^item\s*{\r?\n\s*name:/},
{name: 'Python source code', value: /^\s*import sys, types, os;/}
// cSpell:enable
];
const text = new TextDecoder().decode(buffer.subarray(0, Math.min(1024, buffer.length)));
for (const item of list) {
if (text.match(item.value)) {
return Promise.reject(new ModelError('Invalid file content. File contains ' + item.name + '.', true));
}
}
return Promise.resolve(context);
}
};
if (typeof module !== 'undefined' && typeof module.exports === 'object') {
module.exports.View = view.View;
module.exports.ModelFactoryService = view.ModelFactoryService;
}
...@@ -6146,6 +6146,11 @@ eslint-plugin-react@7.25.1: ...@@ -6146,6 +6146,11 @@ eslint-plugin-react@7.25.1:
resolve "^2.0.0-next.3" resolve "^2.0.0-next.3"
string.prototype.matchall "^4.0.5" string.prototype.matchall "^4.0.5"
eslint-plugin-simple-import-sort@^7.0.0:
version "7.0.0"
resolved "https://registry.npmmirror.com/eslint-plugin-simple-import-sort/-/eslint-plugin-simple-import-sort-7.0.0.tgz#a1dad262f46d2184a90095a60c66fef74727f0f8"
integrity sha512-U3vEDB5zhYPNfxT5TYR7u01dboFZp+HNpnGhkDB2g/2E4wZ/g1Q9Ton8UwCLfRV9yAKyYqDh62oHOamvkFxsvw==
eslint-scope@5.1.1, eslint-scope@^5.1.1: eslint-scope@5.1.1, eslint-scope@^5.1.1:
version "5.1.1" version "5.1.1"
resolved "https://registry.yarnpkg.com/eslint-scope/-/eslint-scope-5.1.1.tgz#e786e59a66cb92b3f6c1fb0d508aab174848f48c" resolved "https://registry.yarnpkg.com/eslint-scope/-/eslint-scope-5.1.1.tgz#e786e59a66cb92b3f6c1fb0d508aab174848f48c"
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册