未验证 提交 19e5e52e 编写于 作者: P Peter Pan 提交者: GitHub

bump 2.1.0 (#891)

* style: add python lint & pre-commit

* style: use pre-commit to lint fe codes

* fix: show raw shape of tensor in high-dimensional chart

* feat: implement random sampling in high-dimensional chart

* feat: add fog in scatter chart

* feat: focus hovered point in high-dimensional chart

* chore: hide unfulfilled features

* chore: update dependencies

* fix: fix lost import

* style: fix style lint

* chore: fix typo

* bump 2.1.0
上级 8cff5856
#!/bin/bash
set -e
(cd frontend && npm run lint)
RESULT=$?
[ $RESULT -ne 0 ] && exit 1
exit 0
[flake8] [flake8]
max-line-length = 120 extend-exclude = venv,build,dist,frontend,visualdl/proto/record_pb2.py,visualdl/server/dist
\ No newline at end of file max-line-length = 120
max-doc-length = 120
max-complexity = 15
...@@ -32,7 +32,7 @@ jobs: ...@@ -32,7 +32,7 @@ jobs:
python-version: '3.8' python-version: '3.8'
- name: Install requirements - name: Install requirements
run: | run: |
echo "TODO" pip install --disable-pip-version-check flake8
- name: Lint - name: Lint
run: | run: |
echo "TODO" flake8
- repo: https://github.com/pre-commit/mirrors-yapf.git repos:
sha: v0.16.0 - repo: https://github.com/pre-commit/pre-commit-hooks
rev: v3.4.0
hooks: hooks:
- id: yapf - id: check-added-large-files
files: \.py$ - id: check-case-conflict
- id: check-executables-have-shebangs
- repo: https://github.com/pre-commit/pre-commit-hooks - id: check-merge-conflict
sha: a11d9314b22d8f8c7556443875b731ef05965464 - id: check-json
- id: check-toml
- id: check-yaml
- id: check-symlinks
- id: destroyed-symlinks
- repo: https://gitlab.com/pycqa/flake8
rev: 3.8.4
hooks: hooks:
- id: check-merge-conflict - id: flake8
- id: check-symlinks alias: be
- id: end-of-file-fixer exclude: ^frontend/
- id: trailing-whitespace - repo: local
- id: detect-private-key
- id: check-symlinks
- id: check-added-large-files
- repo: local
hooks: hooks:
- id: python-format-checker - id: fe
name: python-format-checker name: lint-staged
description: Format python files using PEP8 standard entry: cd frontend && lint-staged
entry: flake8 language: node
language: system files: ^frontend/
files: \.(py)$ pass_filenames: false
description: lint frontend code by lint-staged
- repo: local - repo: meta
hooks: hooks:
- id: eslint-format-checker - id: check-hooks-apply
name: eslint-format-checker - id: check-useless-excludes
description: Format files with ESLint.
entry: bash ./.eslint_format.hook
language: system
...@@ -12,7 +12,9 @@ ...@@ -12,7 +12,9 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# ======================================================================= # =======================================================================
# coding=utf-8 # coding=utf-8
from visualdl import LogWriter from visualdl import LogWriter
from scipy.io import wavfile from scipy.io import wavfile
......
...@@ -12,7 +12,9 @@ ...@@ -12,7 +12,9 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# ======================================================================= # =======================================================================
# coding=utf-8 # coding=utf-8
from visualdl import LogWriter from visualdl import LogWriter
......
...@@ -12,7 +12,9 @@ ...@@ -12,7 +12,9 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# ======================================================================= # =======================================================================
# coding=utf-8 # coding=utf-8
from visualdl import LogWriter from visualdl import LogWriter
import numpy as np import numpy as np
......
...@@ -12,7 +12,9 @@ ...@@ -12,7 +12,9 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# ======================================================================= # =======================================================================
# coding=utf-8 # coding=utf-8
import numpy as np import numpy as np
from PIL import Image from PIL import Image
from visualdl import LogWriter from visualdl import LogWriter
......
...@@ -12,7 +12,9 @@ ...@@ -12,7 +12,9 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# ======================================================================= # =======================================================================
# coding=utf-8 # coding=utf-8
from visualdl import LogWriter from visualdl import LogWriter
import numpy as np import numpy as np
......
...@@ -12,7 +12,9 @@ ...@@ -12,7 +12,9 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# ======================================================================= # =======================================================================
# coding=utf-8 # coding=utf-8
from visualdl import LogWriter from visualdl import LogWriter
import numpy as np import numpy as np
...@@ -21,7 +23,7 @@ with LogWriter("./log/roc_curve_test/train") as writer: ...@@ -21,7 +23,7 @@ with LogWriter("./log/roc_curve_test/train") as writer:
labels = np.random.randint(2, size=100) labels = np.random.randint(2, size=100)
predictions = np.random.rand(100) predictions = np.random.rand(100)
writer.add_roc_curve(tag='roc_curve', writer.add_roc_curve(tag='roc_curve',
labels=labels, labels=labels,
predictions=predictions, predictions=predictions,
step=step, step=step,
num_thresholds=5) num_thresholds=5)
...@@ -12,6 +12,9 @@ ...@@ -12,6 +12,9 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# ======================================================================= # =======================================================================
# coding=utf-8
from visualdl import LogWriter from visualdl import LogWriter
if __name__ == '__main__': if __name__ == '__main__':
...@@ -20,4 +23,3 @@ if __name__ == '__main__': ...@@ -20,4 +23,3 @@ if __name__ == '__main__':
for step in range(1000): for step in range(1000):
writer.add_scalar(tag="acc", step=step, value=value[step]) writer.add_scalar(tag="acc", step=step, value=value[step])
writer.add_scalar(tag="loss", step=step, value=1/(value[step] + 1)) writer.add_scalar(tag="loss", step=step, value=1/(value[step] + 1))
...@@ -2,3 +2,4 @@ dist ...@@ -2,3 +2,4 @@ dist
out out
wasm wasm
output output
license-header.js
...@@ -39,7 +39,7 @@ To stop VisualDL server, just type ...@@ -39,7 +39,7 @@ To stop VisualDL server, just type
visualdl stop visualdl stop
``` ```
For more usage infomation, please type For more usage information, please type
```bash ```bash
visualdl -h visualdl -h
......
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
"packages": [ "packages": [
"packages/*" "packages/*"
], ],
"version": "2.0.9", "version": "2.1.0-1",
"npmClient": "yarn", "npmClient": "yarn",
"useWorkspaces": true, "useWorkspaces": true,
"command": { "command": {
......
...@@ -38,15 +38,14 @@ ...@@ -38,15 +38,14 @@
"version": "yarn format && git add -A" "version": "yarn format && git add -A"
}, },
"devDependencies": { "devDependencies": {
"@typescript-eslint/eslint-plugin": "4.10.0", "@typescript-eslint/eslint-plugin": "4.11.0",
"@typescript-eslint/parser": "4.10.0", "@typescript-eslint/parser": "4.11.0",
"eslint": "7.15.0", "eslint": "7.16.0",
"eslint-config-prettier": "7.0.0", "eslint-config-prettier": "7.1.0",
"eslint-plugin-license-header": "0.2.0", "eslint-plugin-license-header": "0.2.0",
"eslint-plugin-prettier": "3.3.0", "eslint-plugin-prettier": "3.3.0",
"eslint-plugin-react": "7.21.5", "eslint-plugin-react": "7.21.5",
"eslint-plugin-react-hooks": "4.2.0", "eslint-plugin-react-hooks": "4.2.0",
"husky": "4.3.6",
"lerna": "3.22.1", "lerna": "3.22.1",
"lint-staged": "10.5.3", "lint-staged": "10.5.3",
"prettier": "2.2.1", "prettier": "2.2.1",
...@@ -57,10 +56,5 @@ ...@@ -57,10 +56,5 @@
"engines": { "engines": {
"node": ">=10", "node": ">=10",
"npm": ">=6" "npm": ">=6"
},
"husky": {
"hooks": {
"pre-commit": "lint-staged"
}
} }
} }
{ {
"name": "@visualdl/cli", "name": "@visualdl/cli",
"version": "2.0.9", "version": "2.1.0-1",
"description": "A platform to visualize the deep learning process and result.", "description": "A platform to visualize the deep learning process and result.",
"keywords": [ "keywords": [
"visualdl", "visualdl",
...@@ -34,14 +34,14 @@ ...@@ -34,14 +34,14 @@
"dist" "dist"
], ],
"dependencies": { "dependencies": {
"@visualdl/server": "2.0.9", "@visualdl/server": "2.1.0-1",
"open": "7.3.0", "open": "7.3.0",
"ora": "5.1.0", "ora": "5.1.0",
"pm2": "4.5.0", "pm2": "4.5.1",
"yargs": "16.2.0" "yargs": "16.2.0"
}, },
"devDependencies": { "devDependencies": {
"@types/node": "14.14.14", "@types/node": "14.14.16",
"@types/yargs": "15.0.12", "@types/yargs": "15.0.12",
"cross-env": "7.0.3", "cross-env": "7.0.3",
"ts-node": "9.1.1", "ts-node": "9.1.1",
......
{ {
"name": "@visualdl/core", "name": "@visualdl/core",
"version": "2.0.9", "version": "2.1.0-1",
"description": "A platform to visualize the deep learning process and result.", "description": "A platform to visualize the deep learning process and result.",
"keywords": [ "keywords": [
"visualdl", "visualdl",
...@@ -35,13 +35,13 @@ ...@@ -35,13 +35,13 @@
], ],
"dependencies": { "dependencies": {
"@tippyjs/react": "4.2.0", "@tippyjs/react": "4.2.0",
"@visualdl/netron": "2.0.9", "@visualdl/netron": "2.1.0-1",
"@visualdl/wasm": "2.0.9", "@visualdl/wasm": "2.1.0-1",
"bignumber.js": "9.0.1", "bignumber.js": "9.0.1",
"d3": "6.3.1", "d3": "6.3.1",
"d3-format": "2.0.0", "d3-format": "2.0.0",
"echarts": "4.9.0", "echarts": "4.9.0",
"echarts-gl": "1.1.1", "echarts-gl": "1.1.2",
"eventemitter3": "4.0.7", "eventemitter3": "4.0.7",
"file-saver": "2.0.5", "file-saver": "2.0.5",
"i18next": "19.8.4", "i18next": "19.8.4",
...@@ -68,8 +68,8 @@ ...@@ -68,8 +68,8 @@
"react-toastify": "6.2.0", "react-toastify": "6.2.0",
"redux": "4.0.5", "redux": "4.0.5",
"styled-components": "5.2.1", "styled-components": "5.2.1",
"swr": "0.3.9", "swr": "0.3.11",
"three": "0.123.0", "three": "0.124.0",
"tippy.js": "6.2.7", "tippy.js": "6.2.7",
"umap-js": "1.3.3" "umap-js": "1.3.3"
}, },
...@@ -93,7 +93,7 @@ ...@@ -93,7 +93,7 @@
"@types/file-saver": "2.0.1", "@types/file-saver": "2.0.1",
"@types/jest": "26.0.19", "@types/jest": "26.0.19",
"@types/loadable__component": "5.13.1", "@types/loadable__component": "5.13.1",
"@types/lodash": "4.14.165", "@types/lodash": "4.14.166",
"@types/mime-types": "2.1.0", "@types/mime-types": "2.1.0",
"@types/nprogress": "0.2.0", "@types/nprogress": "0.2.0",
"@types/numeric": "1.2.1", "@types/numeric": "1.2.1",
...@@ -101,20 +101,20 @@ ...@@ -101,20 +101,20 @@
"@types/react-dom": "17.0.0", "@types/react-dom": "17.0.0",
"@types/react-helmet": "6.1.0", "@types/react-helmet": "6.1.0",
"@types/react-rangeslider": "2.2.3", "@types/react-rangeslider": "2.2.3",
"@types/react-redux": "7.1.12", "@types/react-redux": "7.1.14",
"@types/react-router-dom": "5.1.6", "@types/react-router-dom": "5.1.6",
"@types/snowpack-env": "2.3.2", "@types/snowpack-env": "2.3.3",
"@types/styled-components": "5.1.5", "@types/styled-components": "5.1.7",
"@visualdl/mock": "2.0.9", "@visualdl/mock": "2.1.0-1",
"babel-plugin-styled-components": "1.12.0", "babel-plugin-styled-components": "1.12.0",
"dotenv": "8.2.0", "dotenv": "8.2.0",
"enhanced-resolve": "5.4.0", "enhanced-resolve": "5.4.1",
"express": "4.17.1", "express": "4.17.1",
"fs-extra": "9.0.1", "fs-extra": "9.0.1",
"html-minifier": "4.0.0", "html-minifier": "4.0.0",
"http-proxy-middleware": "1.0.6", "http-proxy-middleware": "1.0.6",
"jest": "26.6.3", "jest": "26.6.3",
"snowpack": "2.18.4", "snowpack": "2.18.5",
"typescript": "4.0.5", "typescript": "4.0.5",
"yargs": "16.2.0" "yargs": "16.2.0"
}, },
......
文件模式从 100755 更改为 100644
文件模式从 100755 更改为 100644
文件模式从 100755 更改为 100644
文件模式从 100755 更改为 100644
文件模式从 100755 更改为 100644
文件模式从 100755 更改为 100644
文件模式从 100755 更改为 100644
文件模式从 100755 更改为 100644
文件模式从 100755 更改为 100644
文件模式从 100755 更改为 100644
文件模式从 100755 更改为 100644
文件模式从 100755 更改为 100644
文件模式从 100755 更改为 100644
文件模式从 100755 更改为 100644
文件模式从 100755 更改为 100644
文件模式从 100755 更改为 100644
文件模式从 100755 更改为 100644
文件模式从 100755 更改为 100644
文件模式从 100755 更改为 100644
文件模式从 100755 更改为 100644
文件模式从 100755 更改为 100644
文件模式从 100755 更改为 100644
文件模式从 100755 更改为 100644
文件模式从 100755 更改为 100644
文件模式从 100755 更改为 100644
文件模式从 100755 更改为 100644
文件模式从 100755 更改为 100644
文件模式从 100755 更改为 100644
...@@ -77,7 +77,7 @@ const ChartOperations: FunctionComponent<ChartOperationsProps> = ({onReset}) => ...@@ -77,7 +77,7 @@ const ChartOperations: FunctionComponent<ChartOperationsProps> = ({onReset}) =>
return ( return (
<Operations> <Operations>
<Tippy content={t('high-dimensional:selection')} placement="bottom" theme="tooltip"> {/* <Tippy content={t('high-dimensional:selection')} placement="bottom" theme="tooltip">
<a> <a>
<span> <span>
<Icon type="selection" /> <Icon type="selection" />
...@@ -90,7 +90,7 @@ const ChartOperations: FunctionComponent<ChartOperationsProps> = ({onReset}) => ...@@ -90,7 +90,7 @@ const ChartOperations: FunctionComponent<ChartOperationsProps> = ({onReset}) =>
<Icon type="three-d" /> <Icon type="three-d" />
</span> </span>
</a> </a>
</Tippy> </Tippy> */}
<Tippy content={t('high-dimensional:reset-zoom')} placement="bottom" theme="tooltip"> <Tippy content={t('high-dimensional:reset-zoom')} placement="bottom" theme="tooltip">
<a onClick={() => onReset?.()}> <a onClick={() => onReset?.()}>
<span> <span>
......
...@@ -14,12 +14,12 @@ ...@@ -14,12 +14,12 @@
* limitations under the License. * limitations under the License.
*/ */
import type {CalculateParams, CalculateResult, Reduction} from '~/resource/high-dimensional'; import type {CalculateParams, CalculateResult, Reduction, Shape} from '~/resource/high-dimensional';
import React, {useCallback, useEffect, useImperativeHandle, useLayoutEffect, useMemo, useRef, useState} from 'react'; import React, {useCallback, useEffect, useImperativeHandle, useLayoutEffect, useMemo, useRef, useState} from 'react';
import ScatterChart, {ScatterChartRef} from '~/components/ScatterChart'; import ScatterChart, {ScatterChartRef} from '~/components/ScatterChart';
import ChartOperations from '~/components/HighDimensionalPage/ChartOperations'; import ChartOperations from '~/components/HighDimensionalPage/ChartOperations';
import type {InfoData} from '~/worker/high-dimensional/tsne'; import type {InfoData} from '~/worker/high-dimensional/calculate';
import type {WithStyled} from '~/utils/style'; import type {WithStyled} from '~/utils/style';
import {rem} from '~/utils/style'; import {rem} from '~/utils/style';
import styled from 'styled-components'; import styled from 'styled-components';
...@@ -67,12 +67,14 @@ const Chart = styled.div` ...@@ -67,12 +67,14 @@ const Chart = styled.div`
type HighDimensionalChartProps = { type HighDimensionalChartProps = {
vectors: Float32Array; vectors: Float32Array;
labels: string[]; labels: string[];
shape: Shape;
dim: number; dim: number;
is3D: boolean; is3D: boolean;
reduction: Reduction; reduction: Reduction;
perplexity: number; perplexity: number;
learningRate: number; learningRate: number;
neighbors: number; neighbors: number;
focusedIndices?: number[];
highlightIndices?: number[]; highlightIndices?: number[];
onCalculate?: () => unknown; onCalculate?: () => unknown;
onCalculated?: (data: CalculateResult) => unknown; onCalculated?: (data: CalculateResult) => unknown;
...@@ -91,12 +93,14 @@ const HighDimensionalChart = React.forwardRef<HighDimensionalChartRef, HighDimen ...@@ -91,12 +93,14 @@ const HighDimensionalChart = React.forwardRef<HighDimensionalChartRef, HighDimen
{ {
vectors, vectors,
labels, labels,
shape,
dim, dim,
is3D, is3D,
reduction, reduction,
perplexity, perplexity,
learningRate, learningRate,
neighbors, neighbors,
focusedIndices,
highlightIndices, highlightIndices,
onCalculate, onCalculate,
onCalculated, onCalculated,
...@@ -113,8 +117,6 @@ const HighDimensionalChart = React.forwardRef<HighDimensionalChartRef, HighDimen ...@@ -113,8 +117,6 @@ const HighDimensionalChart = React.forwardRef<HighDimensionalChartRef, HighDimen
const [width, setWidth] = useState(0); const [width, setWidth] = useState(0);
const [height, setHeight] = useState(0); const [height, setHeight] = useState(0);
const points = useMemo(() => Math.floor(vectors.length / dim), [vectors, dim]);
useLayoutEffect(() => { useLayoutEffect(() => {
const c = chartElement.current; const c = chartElement.current;
if (c) { if (c) {
...@@ -251,11 +253,11 @@ const HighDimensionalChart = React.forwardRef<HighDimensionalChartRef, HighDimen ...@@ -251,11 +253,11 @@ const HighDimensionalChart = React.forwardRef<HighDimensionalChartRef, HighDimen
<div className="info"> <div className="info">
{t('high-dimensional:points')} {t('high-dimensional:points')}
{t('common:colon')} {t('common:colon')}
{points} {shape[0]}
<span className="sep"></span> <span className="sep"></span>
{t('high-dimensional:data-dimension')} {t('high-dimensional:data-dimension')}
{t('common:colon')} {t('common:colon')}
{dim} {shape[1]}
</div> </div>
<ChartOperations onReset={() => chart.current?.reset()} /> <ChartOperations onReset={() => chart.current?.reset()} />
</Toolbar> </Toolbar>
...@@ -268,7 +270,8 @@ const HighDimensionalChart = React.forwardRef<HighDimensionalChartRef, HighDimen ...@@ -268,7 +270,8 @@ const HighDimensionalChart = React.forwardRef<HighDimensionalChartRef, HighDimen
labels={labels} labels={labels}
is3D={is3D} is3D={is3D}
rotate={reduction !== 'tsne'} rotate={reduction !== 'tsne'}
highlightIndices={highlightIndices ?? []} focusedIndices={focusedIndices}
highlightIndices={highlightIndices}
/> />
</Chart> </Chart>
</Wrapper> </Wrapper>
......
...@@ -63,19 +63,20 @@ const List = styled.ul` ...@@ -63,19 +63,20 @@ const List = styled.ul`
type LabelSearchResultProps = { type LabelSearchResultProps = {
list: string[]; list: string[];
onHovered?: (index?: number) => unknown;
}; };
const LabelSearchResult: FunctionComponent<LabelSearchResultProps> = ({list}) => { const LabelSearchResult: FunctionComponent<LabelSearchResultProps> = ({list, onHovered}) => {
const {t} = useTranslation('high-dimensional'); const {t} = useTranslation('high-dimensional');
if (!list.length) { if (!list.length) {
return <Empty>{t('high-dimensional:search-empty')}</Empty>; return <Empty>{t('high-dimensional:search-empty')}</Empty>;
} }
return ( return (
<List> <List onMouseLeave={() => onHovered?.()}>
{list.map((label, index) => ( {list.map((label, index) => (
<li key={index}> <li key={index}>
<a>{label}</a> <a onMouseEnter={() => onHovered?.(index)}>{label}</a>
</li> </li>
))} ))}
</List> </List>
......
...@@ -198,7 +198,7 @@ const SampleChart: FunctionComponent<SampleChartProps> = ({ ...@@ -198,7 +198,7 @@ const SampleChart: FunctionComponent<SampleChartProps> = ({
// clear cache if tag or run changed // clear cache if tag or run changed
useEffect(() => { useEffect(() => {
Object.values(cached.current).forEach(({timer}) => clearTimeout(timer)); Object.values(cached.current).forEach(({timer}) => window.clearTimeout(timer));
cached.current = {}; cached.current = {};
}, [tag, run]); }, [tag, run]);
...@@ -211,7 +211,7 @@ const SampleChart: FunctionComponent<SampleChartProps> = ({ ...@@ -211,7 +211,7 @@ const SampleChart: FunctionComponent<SampleChartProps> = ({
const url = getUrl(type, step, run.label, tag, wallTime); const url = getUrl(type, step, run.label, tag, wallTime);
cached.current[step] = { cached.current[step] = {
src: url, src: url,
timer: setTimeout(() => { timer: window.setTimeout(() => {
((s: number) => delete cached.current[s])(step); ((s: number) => delete cached.current[s])(step);
}, cache) }, cache)
}; };
...@@ -227,10 +227,10 @@ const SampleChart: FunctionComponent<SampleChartProps> = ({ ...@@ -227,10 +227,10 @@ const SampleChart: FunctionComponent<SampleChartProps> = ({
// first load, return immediately // first load, return immediately
cacheSrc(); cacheSrc();
} else { } else {
timer.current = setTimeout(cacheSrc, 500); timer.current = window.setTimeout(cacheSrc, 500);
return () => { return () => {
if (timer.current != null) { if (timer.current != null) {
clearTimeout(timer.current); window.clearTimeout(timer.current);
timer.current = null; timer.current = null;
} }
}; };
......
...@@ -36,7 +36,8 @@ export type ScatterChartProps = { ...@@ -36,7 +36,8 @@ export type ScatterChartProps = {
labels: string[]; labels: string[];
is3D: boolean; is3D: boolean;
rotate?: boolean; rotate?: boolean;
highlightIndices: number[]; focusedIndices?: number[];
highlightIndices?: number[];
}; };
export type ScatterChartRef = { export type ScatterChartRef = {
...@@ -44,7 +45,7 @@ export type ScatterChartRef = { ...@@ -44,7 +45,7 @@ export type ScatterChartRef = {
}; };
const ScatterChart = React.forwardRef<ScatterChartRef, ScatterChartProps & WithStyled>( const ScatterChart = React.forwardRef<ScatterChartRef, ScatterChartProps & WithStyled>(
({width, height, data, labels, is3D, rotate, highlightIndices, className}, ref) => { ({width, height, data, labels, is3D, rotate, focusedIndices, highlightIndices, className}, ref) => {
const theme = useTheme(); const theme = useTheme();
const element = useRef<HTMLDivElement>(null); const element = useRef<HTMLDivElement>(null);
...@@ -77,7 +78,11 @@ const ScatterChart = React.forwardRef<ScatterChartRef, ScatterChartProps & WithS ...@@ -77,7 +78,11 @@ const ScatterChart = React.forwardRef<ScatterChartRef, ScatterChartProps & WithS
}, [data, labels]); }, [data, labels]);
useEffect(() => { useEffect(() => {
chart.current?.setHighLightIndices(highlightIndices); chart.current?.setFocusedPointIndices(focusedIndices ?? []);
}, [focusedIndices]);
useEffect(() => {
chart.current?.setHighLightIndices(highlightIndices ?? []);
}, [highlightIndices]); }, [highlightIndices]);
useEffect(() => { useEffect(() => {
......
...@@ -103,11 +103,15 @@ export default class ScatterChart { ...@@ -103,11 +103,15 @@ export default class ScatterChart {
static ORBIT_ANIMATION_ROTATION_CYCLE_IN_SECONDS = 2; static ORBIT_ANIMATION_ROTATION_CYCLE_IN_SECONDS = 2;
static NUM_POINTS_FOG_THRESHOLD = 5000; static NUM_POINTS_FOG_THRESHOLD = 5000;
static POINT_COLOR_NO_SELECTION = 0x7e7e7e; static POINT_COLOR_DEFAULT = new THREE.Color(0x7e7e7e);
static POINT_COLOR_HOVER = 0x2932e1; static POINT_COLOR_HOVER = new THREE.Color(0x2932e1);
static POINT_COLOR_HIGHLIGHT = new THREE.Color(0x2932e1);
static POINT_COLOR_FOCUS = new THREE.Color(0x2932e1);
static POINT_SCALE_DEFAULT = 1.0; static POINT_SCALE_DEFAULT = 1.0;
static POINT_SCALE_HOVER = 1.2; static POINT_SCALE_HOVER = 1.2;
static POINT_SCALE_HIGHLIGHT = 1.0;
static POINT_SCALE_FOCUS = 1.2;
static PERSP_CAMERA_INIT_POSITION: Point3D = [0.45, 0.9, 1.6]; static PERSP_CAMERA_INIT_POSITION: Point3D = [0.45, 0.9, 1.6];
static ORTHO_CAMERA_INIT_POSITION: Point3D = [0, 0, 4]; static ORTHO_CAMERA_INIT_POSITION: Point3D = [0, 0, 4];
...@@ -124,6 +128,7 @@ export default class ScatterChart { ...@@ -124,6 +128,7 @@ export default class ScatterChart {
private renderer: THREE.WebGLRenderer; private renderer: THREE.WebGLRenderer;
private camera: THREE.OrthographicCamera | THREE.PerspectiveCamera; private camera: THREE.OrthographicCamera | THREE.PerspectiveCamera;
private controls: OrbitControls; private controls: OrbitControls;
private fog: THREE.Fog;
private pickingTexture: THREE.WebGLRenderTarget; private pickingTexture: THREE.WebGLRenderTarget;
private geometry: THREE.BufferGeometry; private geometry: THREE.BufferGeometry;
private renderMaterial: THREE.ShaderMaterial | null = null; private renderMaterial: THREE.ShaderMaterial | null = null;
...@@ -134,6 +139,7 @@ export default class ScatterChart { ...@@ -134,6 +139,7 @@ export default class ScatterChart {
private scaleFactors: Float32Array | null = null; private scaleFactors: Float32Array | null = null;
private axes: THREE.AxesHelper | null = null; private axes: THREE.AxesHelper | null = null;
private points: THREE.Points | null = null; private points: THREE.Points | null = null;
private focusedPointIndices: number[] = [];
private hoveredPointIndices: number[] = []; private hoveredPointIndices: number[] = [];
private label: ScatterChartLabel; private label: ScatterChartLabel;
private highLightPointIndices: number[] = []; private highLightPointIndices: number[] = [];
...@@ -163,6 +169,7 @@ export default class ScatterChart { ...@@ -163,6 +169,7 @@ export default class ScatterChart {
this.camera = this.initCamera(); this.camera = this.initCamera();
this.renderer = this.initRenderer(); this.renderer = this.initRenderer();
this.controls = this.initControls(); this.controls = this.initControls();
this.fog = this.initFog();
this.pickingTexture = this.initRenderTarget(); this.pickingTexture = this.initRenderTarget();
this.geometry = this.createGeometry(); this.geometry = this.createGeometry();
...@@ -190,7 +197,6 @@ export default class ScatterChart { ...@@ -190,7 +197,6 @@ export default class ScatterChart {
const light = new THREE.PointLight(16772287, 1, 0); const light = new THREE.PointLight(16772287, 1, 0);
light.name = 'light'; light.name = 'light';
scene.add(light); scene.add(light);
scene.fog = new THREE.Fog(0xffffff);
return scene; return scene;
} }
...@@ -255,6 +261,10 @@ export default class ScatterChart { ...@@ -255,6 +261,10 @@ export default class ScatterChart {
return controls; return controls;
} }
private initFog() {
return new THREE.Fog(this.background);
}
private initRenderTarget() { private initRenderTarget() {
const renderCanvasSize = new THREE.Vector2(); const renderCanvasSize = new THREE.Vector2();
this.renderer.getSize(renderCanvasSize); this.renderer.getSize(renderCanvasSize);
...@@ -355,17 +365,23 @@ export default class ScatterChart { ...@@ -355,17 +365,23 @@ export default class ScatterChart {
const count = this.data.length; const count = this.data.length;
const colors = new Float32Array(count * 3); const colors = new Float32Array(count * 3);
let dst = 0; let dst = 0;
const color = new THREE.Color(ScatterChart.POINT_COLOR_NO_SELECTION);
const hoveredColor = new THREE.Color(ScatterChart.POINT_COLOR_HOVER);
for (let i = 0; i < count; i++) { for (let i = 0; i < count; i++) {
if (i === this.hoveredPointIndices[0] || this.highLightPointIndices.includes(i)) { if (this.hoveredPointIndices.includes(i)) {
colors[dst++] = hoveredColor.r; colors[dst++] = ScatterChart.POINT_COLOR_HOVER.r;
colors[dst++] = hoveredColor.g; colors[dst++] = ScatterChart.POINT_COLOR_HOVER.g;
colors[dst++] = hoveredColor.b; colors[dst++] = ScatterChart.POINT_COLOR_HOVER.b;
} else if (this.focusedPointIndices.includes(i)) {
colors[dst++] = ScatterChart.POINT_COLOR_FOCUS.r;
colors[dst++] = ScatterChart.POINT_COLOR_FOCUS.g;
colors[dst++] = ScatterChart.POINT_COLOR_FOCUS.b;
} else if (this.highLightPointIndices.includes(i)) {
colors[dst++] = ScatterChart.POINT_COLOR_HIGHLIGHT.r;
colors[dst++] = ScatterChart.POINT_COLOR_HIGHLIGHT.g;
colors[dst++] = ScatterChart.POINT_COLOR_HIGHLIGHT.b;
} else { } else {
colors[dst++] = color.r; colors[dst++] = ScatterChart.POINT_COLOR_DEFAULT.r;
colors[dst++] = color.g; colors[dst++] = ScatterChart.POINT_COLOR_DEFAULT.g;
colors[dst++] = color.b; colors[dst++] = ScatterChart.POINT_COLOR_DEFAULT.b;
} }
} }
return colors; return colors;
...@@ -375,8 +391,12 @@ export default class ScatterChart { ...@@ -375,8 +391,12 @@ export default class ScatterChart {
const count = this.data.length; const count = this.data.length;
const scaleFactor = new Float32Array(count); const scaleFactor = new Float32Array(count);
for (let i = 0; i < count; i++) { for (let i = 0; i < count; i++) {
if (i === this.hoveredPointIndices[0]) { if (this.hoveredPointIndices.includes(i)) {
scaleFactor[i] = ScatterChart.POINT_SCALE_HOVER; scaleFactor[i] = ScatterChart.POINT_SCALE_HOVER;
} else if (this.focusedPointIndices.includes(i)) {
scaleFactor[i] = ScatterChart.POINT_SCALE_FOCUS;
} else if (this.highLightPointIndices.includes(i)) {
scaleFactor[i] = ScatterChart.POINT_SCALE_HIGHLIGHT;
} else { } else {
scaleFactor[i] = ScatterChart.POINT_SCALE_DEFAULT; scaleFactor[i] = ScatterChart.POINT_SCALE_DEFAULT;
} }
...@@ -441,21 +461,26 @@ export default class ScatterChart { ...@@ -441,21 +461,26 @@ export default class ScatterChart {
} }
private updateHoveredLabels() { private updateHoveredLabels() {
if (this.hoveredPointIndices[0] == null || !this.camera || !this.positions) { if (!this.camera || !this.positions) {
return;
}
const indices = this.focusedPointIndices.length ? this.focusedPointIndices : this.hoveredPointIndices;
if (!indices.length) {
this.label.clear(); this.label.clear();
return; return;
} }
const dpr = window.devicePixelRatio || 1; const dpr = window.devicePixelRatio || 1;
const w = this.width; const w = this.width;
const h = this.height; const h = this.height;
const index = this.hoveredPointIndices[0]; const labels = indices.map(index => {
const pi = index * 3; const pi = index * 3;
// eslint-disable-next-line @typescript-eslint/no-non-null-assertion // eslint-disable-next-line @typescript-eslint/no-non-null-assertion
const point = new THREE.Vector3(this.positions![pi], this.positions![pi + 1], this.positions![pi + 2]); const point = new THREE.Vector3(this.positions![pi], this.positions![pi + 1], this.positions![pi + 2]);
const pv = new THREE.Vector3().copy(point).project(this.camera); const pv = new THREE.Vector3().copy(point).project(this.camera);
const coord: Point2D = [((pv.x + 1) / 2) * w * dpr, -(((pv.y - 1) / 2) * h) * dpr]; const coord: Point2D = [((pv.x + 1) / 2) * w * dpr, -(((pv.y - 1) / 2) * h) * dpr];
const labels = [ return {
{
text: this.labels[index] ?? '', text: this.labels[index] ?? '',
fontSize: 40, fontSize: 40,
fillColor: '#000', fillColor: '#000',
...@@ -463,12 +488,75 @@ export default class ScatterChart { ...@@ -463,12 +488,75 @@ export default class ScatterChart {
opacity: 1, opacity: 1,
x: coord[0] + 4, x: coord[0] + 4,
y: coord[1] y: coord[1]
} };
]; });
this.label.render(labels); this.label.render(labels);
} }
private updateLight() {
const light = this.scene.getObjectByName('light') as THREE.PointLight;
const cameraPos = this.camera.position;
const lightPos = cameraPos.clone();
lightPos.x += 1;
lightPos.y += 1;
light.position.set(lightPos.x, lightPos.y, lightPos.z);
}
private updateFog() {
const fog = this.fog;
fog.color = new THREE.Color(this.background);
if (this.is3D && this.positions) {
const cameraPos = this.camera.position;
const cameraTarget = this.controls.target;
let shortestDist = Number.POSITIVE_INFINITY;
let furthestDist = 0;
const camToTarget = new THREE.Vector3().copy(cameraTarget).sub(cameraPos);
const camPlaneNormal = new THREE.Vector3().copy(camToTarget).normalize();
const n = this.positions.length / 3;
let src = 0;
const p = new THREE.Vector3();
const camToPoint = new THREE.Vector3();
for (let i = 0; i < n; i++) {
p.x = this.positions[src++];
p.y = this.positions[src++];
p.z = this.positions[src++];
camToPoint.copy(p).sub(cameraPos);
const dist = camPlaneNormal.dot(camToPoint);
if (dist < 0) {
continue;
}
furthestDist = dist > furthestDist ? dist : furthestDist;
shortestDist = dist < shortestDist ? dist : shortestDist;
}
const multiplier =
2 - Math.min(n, ScatterChart.NUM_POINTS_FOG_THRESHOLD) / ScatterChart.NUM_POINTS_FOG_THRESHOLD;
fog.near = shortestDist;
fog.far = furthestDist * multiplier;
} else {
fog.near = Number.POSITIVE_INFINITY;
fog.far = Number.POSITIVE_INFINITY;
}
if (this.points) {
const material = this.points.material as THREE.ShaderMaterial;
material.uniforms.fogColor.value = fog.color;
material.uniforms.fogNear.value = fog.near;
material.uniforms.fogFar.value = fog.far;
}
this.scene.fog = fog;
}
private bindEventListeners() { private bindEventListeners() {
this.canvas?.addEventListener('mousemove', this.onMouseMoveBindThis); this.canvas?.addEventListener('mousemove', this.onMouseMoveBindThis);
} }
...@@ -498,55 +586,13 @@ export default class ScatterChart { ...@@ -498,55 +586,13 @@ export default class ScatterChart {
} }
private render() { private render() {
const light = this.scene.getObjectByName('light') as THREE.PointLight; this.updateLight();
const cameraPos = this.camera.position; this.updateFog();
const lightPos = cameraPos.clone();
lightPos.x += 1;
lightPos.y += 1;
light.position.set(lightPos.x, lightPos.y, lightPos.z);
// TODO: remake fog
// const fog = this.scene.fog as THREE.Fog | null;
// if (fog) {
// if (this.is3D) {
// const cameraTarget = this.controls.target ?? new THREE.Vector3();
// let shortestDist = Number.POSITIVE_INFINITY;
// let furthestDist = 0;
// const camToTarget = new THREE.Vector3().copy(cameraTarget).sub(cameraPos);
// const camPlaneNormal = new THREE.Vector3().copy(camToTarget).normalize();
// const camToPoint = new THREE.Vector3();
// this.data.forEach(d => {
// camToPoint.copy(new THREE.Vector3(...d)).sub(cameraPos);
// const dist = camPlaneNormal.dot(camToPoint);
// if (dist < 0) {
// return;
// }
// furthestDist = dist > furthestDist ? dist : furthestDist;
// shortestDist = dist < shortestDist ? dist : shortestDist;
// });
// const multiplier =
// 2 -
// Math.min(this.data.length, ScatterChart.NUM_POINTS_FOG_THRESHOLD) /
// ScatterChart.NUM_POINTS_FOG_THRESHOLD;
// fog.near = shortestDist;
// fog.far = furthestDist * multiplier;
// } else {
// fog.near = Number.POSITIVE_INFINITY;
// fog.far = Number.POSITIVE_INFINITY;
// }
// console.log(fog.near, fog.far, fog.color);
// if (this.points) {
// const material = this.points.material as THREE.ShaderMaterial;
// material.uniforms.fogColor.value = fog.color;
// material.uniforms.fogNear.value = fog.near;
// material.uniforms.fogFar.value = fog.far;
// }
// }
this.updateHoveredPoints(); this.updateHoveredPoints();
this.updateHoveredLabels(); this.updateHoveredLabels();
this.updatePointsAttribute(); this.updatePointsAttribute();
if (this.pickingMaterial) { if (this.pickingMaterial) {
if (this.axes) { if (this.axes) {
this.scene.remove(this.axes); this.scene.remove(this.axes);
...@@ -659,6 +705,7 @@ export default class ScatterChart { ...@@ -659,6 +705,7 @@ export default class ScatterChart {
setLabels(labels: string[]) { setLabels(labels: string[]) {
this.labels = labels; this.labels = labels;
this.render();
} }
setHighLightIndices(highLightIndices: number[]) { setHighLightIndices(highLightIndices: number[]) {
...@@ -666,6 +713,11 @@ export default class ScatterChart { ...@@ -666,6 +713,11 @@ export default class ScatterChart {
this.render(); this.render();
} }
setFocusedPointIndices(focusedPointIndices: number[]) {
this.focusedPointIndices = focusedPointIndices;
this.render();
}
dispose() { dispose() {
this.removeEventListeners(); this.removeEventListeners();
this.label.dispose(); this.label.dispose();
......
...@@ -21,6 +21,7 @@ import type { ...@@ -21,6 +21,7 @@ import type {
ParseParams, ParseParams,
ParseResult, ParseResult,
Reduction, Reduction,
Shape,
TSNEResult, TSNEResult,
UMAPResult UMAPResult
} from '~/resource/high-dimensional'; } from '~/resource/high-dimensional';
...@@ -180,6 +181,7 @@ const HighDimensional: FunctionComponent = () => { ...@@ -180,6 +181,7 @@ const HighDimensional: FunctionComponent = () => {
const [metadata, setMetadata] = useState<string[][]>([]); const [metadata, setMetadata] = useState<string[][]>([]);
// dimension of data // dimension of data
const [dim, setDim] = useState<number>(0); const [dim, setDim] = useState<number>(0);
const [rawShape, setRawShape] = useState<Shape>([0, 0]);
const getLabelByLabels = useCallback( const getLabelByLabels = useCallback(
(value: string | undefined) => { (value: string | undefined) => {
if (value != null) { if (value != null) {
...@@ -272,6 +274,7 @@ const HighDimensional: FunctionComponent = () => { ...@@ -272,6 +274,7 @@ const HighDimensional: FunctionComponent = () => {
if (error) { if (error) {
showError(error); showError(error);
} else if (data) { } else if (data) {
setRawShape(data.rawShape);
setDim(data.dimension); setDim(data.dimension);
setVectors(data.vectors); setVectors(data.vectors);
setLabels(data.labels); setLabels(data.labels);
...@@ -339,6 +342,12 @@ const HighDimensional: FunctionComponent = () => { ...@@ -339,6 +342,12 @@ const HighDimensional: FunctionComponent = () => {
}; };
}, [getLabelByLabels, searchResult.labelBy, searchResult.value]); }, [getLabelByLabels, searchResult.labelBy, searchResult.value]);
const [hoveredIndices, setHoveredIndices] = useState<number[]>([]);
const hoverSearchResult = useCallback(
(index?: number) => setHoveredIndices(index == null ? [] : [searchedResult.indices[index]]),
[searchedResult.indices]
);
const detail = useMemo(() => { const detail = useMemo(() => {
switch (reduction) { switch (reduction) {
case 'pca': case 'pca':
...@@ -385,11 +394,9 @@ const HighDimensional: FunctionComponent = () => { ...@@ -385,11 +394,9 @@ const HighDimensional: FunctionComponent = () => {
<Field label={t('high-dimensional:select-label')}> <Field label={t('high-dimensional:select-label')}>
<FullWidthSelect list={labels} value={labelBy} onChange={setLabelBy} /> <FullWidthSelect list={labels} value={labelBy} onChange={setLabelBy} />
</Field> </Field>
{metadataFile && ( {/* <Field label={t('high-dimensional:select-color')}>
<Field label={t('high-dimensional:select-color')}> <FullWidthSelect />
<FullWidthSelect /> </Field> */}
</Field>
)}
<Field> <Field>
<FullWidthButton rounded outline type="primary" onClick={() => setUploadModal(true)}> <FullWidthButton rounded outline type="primary" onClick={() => setUploadModal(true)}>
{t('high-dimensional:upload-data')} {t('high-dimensional:upload-data')}
...@@ -416,7 +423,7 @@ const HighDimensional: FunctionComponent = () => { ...@@ -416,7 +423,7 @@ const HighDimensional: FunctionComponent = () => {
</AsideSection> </AsideSection>
</RightAside> </RightAside>
), ),
[t, dataPath, reduction, dimension, metadataFile, labels, labelBy, embeddingList, selectedEmbeddingName, detail] [t, dataPath, reduction, dimension, labels, labelBy, embeddingList, selectedEmbeddingName, detail]
); );
const leftAside = useMemo( const leftAside = useMemo(
...@@ -438,12 +445,12 @@ const HighDimensional: FunctionComponent = () => { ...@@ -438,12 +445,12 @@ const HighDimensional: FunctionComponent = () => {
</AsideSection> </AsideSection>
<AsideSection className="search-result"> <AsideSection className="search-result">
<Field> <Field>
<LabelSearchResult list={searchedResult.metadata} /> <LabelSearchResult list={searchedResult.metadata} onHovered={hoverSearchResult} />
</Field> </Field>
</AsideSection> </AsideSection>
</LeftAside> </LeftAside>
), ),
[labels, searchResult.value, searchedResult.metadata, t] [hoverSearchResult, labels, searchResult.value, searchedResult.metadata, t]
); );
return ( return (
...@@ -456,12 +463,14 @@ const HighDimensional: FunctionComponent = () => { ...@@ -456,12 +463,14 @@ const HighDimensional: FunctionComponent = () => {
ref={chart} ref={chart}
vectors={vectors} vectors={vectors}
labels={labelByLabels} labels={labelByLabels}
shape={rawShape}
dim={dim} dim={dim}
is3D={is3D} is3D={is3D}
reduction={reduction} reduction={reduction}
perplexity={perplexity} perplexity={perplexity}
learningRate={learningRate} learningRate={learningRate}
neighbors={neighbors} neighbors={neighbors}
focusedIndices={hoveredIndices}
highlightIndices={searchedResult.indices} highlightIndices={searchedResult.indices}
onCalculate={calculate} onCalculate={calculate}
onCalculated={calculated} onCalculated={calculated}
......
...@@ -23,6 +23,7 @@ export type { ...@@ -23,6 +23,7 @@ export type {
PCAParams, PCAParams,
PCAResult, PCAResult,
Reduction, Reduction,
Shape,
TSNEParams, TSNEParams,
TSNEResult, TSNEResult,
UMAPParams, UMAPParams,
......
...@@ -14,7 +14,14 @@ ...@@ -14,7 +14,14 @@
* limitations under the License. * limitations under the License.
*/ */
import type {MetadataResult, ParseFromBlobParams, ParseFromStringParams, ParseResult, VectorResult} from './types'; import type {
MetadataResult,
ParseFromBlobParams,
ParseFromStringParams,
ParseResult,
Shape,
VectorResult
} from './types';
import {safeSplit} from '~/utils'; import {safeSplit} from '~/utils';
...@@ -64,21 +71,65 @@ function alignItems<T>(data: T[][], dimension: number, defaultValue: T): T[][] { ...@@ -64,21 +71,65 @@ function alignItems<T>(data: T[][], dimension: number, defaultValue: T): T[][] {
}); });
} }
function parseVectors(str: string, maxCount?: number, maxDimension?: number): VectorResult { // Fisher-Yates (aka Knuth) Shuffle
function shuffleIndices(length: number) {
const indices = Array.from<number>({length}).map((_, i) => i);
let currentIndex = length;
while (currentIndex) {
const randomIndex = Math.floor(Math.random() * currentIndex);
currentIndex--;
const temporaryValue = indices[currentIndex];
indices[currentIndex] = indices[randomIndex];
indices[randomIndex] = temporaryValue;
}
return indices;
}
function shuffle(data: ParseResult, maxCount?: number, maxDimension?: number): ParseResult {
const {
count: originalCount,
dimension: originalDimension,
vectors: originalVectors,
metadata: originalMetadata
} = data;
const count = maxCount ? Math.min(originalCount, maxCount) : originalCount;
const dimension = maxDimension ? Math.min(originalDimension, maxDimension) : originalDimension;
const shuffledIndices = shuffleIndices(originalCount);
const shuffledDimensionIndices = shuffleIndices(originalDimension);
const vectors = new Float32Array(count * dimension);
const metadata: string[][] = [];
for (let c = 0; c < count; c++) {
const offset = shuffledIndices[c] * originalDimension;
if (dimension < originalDimension) {
for (let i = 0; i < dimension; i++) {
vectors[c + i] = originalVectors[offset + shuffledDimensionIndices[i]];
}
} else {
vectors.set(originalVectors.subarray(offset, offset + dimension), c * dimension);
}
metadata.push(originalMetadata[shuffledIndices[c]]);
}
return {
...data,
count,
dimension,
vectors,
metadata
};
}
function parseVectors(str: string): VectorResult {
if (!str) { if (!str) {
throw new ParserError('Tenser file is empty', ParserError.CODES.TENSER_EMPTY); throw new ParserError('Tenser file is empty', ParserError.CODES.TENSER_EMPTY);
} }
const rawShape: Shape = [0, 0];
let vectors = split(str, Number.parseFloat); let vectors = split(str, Number.parseFloat);
// TODO: random sampling rawShape[0] = vectors.length;
if (maxCount) { const dimension = Math.min(...vectors.map(vector => vector.length));
vectors = vectors.slice(0, maxCount); rawShape[1] = dimension;
}
let dimension = Math.min(...vectors.map(vector => vector.length));
if (maxDimension) {
dimension = Math.min(dimension, maxDimension);
}
vectors = alignItems(vectors, dimension, 0); vectors = alignItems(vectors, dimension, 0);
return { return {
rawShape,
dimension, dimension,
count: vectors.length, count: vectors.length,
vectors: new Float32Array(vectors.flat()) vectors: new Float32Array(vectors.flat())
...@@ -132,6 +183,7 @@ function genMetadataAndLabels(metadata: string, count: number) { ...@@ -132,6 +183,7 @@ function genMetadataAndLabels(metadata: string, count: number) {
export function parseFromString({vectors: v, metadata: m, maxCount, maxDimension}: ParseFromStringParams): ParseResult { export function parseFromString({vectors: v, metadata: m, maxCount, maxDimension}: ParseFromStringParams): ParseResult {
const result: ParseResult = { const result: ParseResult = {
rawShape: [0, 0],
count: 0, count: 0,
dimension: 0, dimension: 0,
vectors: new Float32Array(), vectors: new Float32Array(),
...@@ -139,14 +191,10 @@ export function parseFromString({vectors: v, metadata: m, maxCount, maxDimension ...@@ -139,14 +191,10 @@ export function parseFromString({vectors: v, metadata: m, maxCount, maxDimension
metadata: [] metadata: []
}; };
if (v) { if (v) {
const {dimension, vectors, count} = parseVectors(v, maxCount, maxDimension); Object.assign(result, parseVectors(v));
result.dimension = dimension; Object.assign(result, genMetadataAndLabels(m, result.count));
result.vectors = vectors;
const metadataAndLabels = genMetadataAndLabels(m, count);
metadataAndLabels.metadata = metadataAndLabels.metadata.slice(0, count);
Object.assign(result, metadataAndLabels);
} }
return result; return shuffle(result, maxCount, maxDimension);
} }
export async function parseFromBlob({ export async function parseFromBlob({
...@@ -156,25 +204,20 @@ export async function parseFromBlob({ ...@@ -156,25 +204,20 @@ export async function parseFromBlob({
maxCount, maxCount,
maxDimension maxDimension
}: ParseFromBlobParams): Promise<ParseResult> { }: ParseFromBlobParams): Promise<ParseResult> {
// TODO: random sampling const [count, dimension] = shape;
const [originalCount, originalDimension] = shape; const vectors = new Float32Array(await v.arrayBuffer());
const originalVectors = new Float32Array(await v.arrayBuffer()); if (count * dimension !== vectors.length) {
if (originalCount * originalDimension !== originalVectors.length) {
throw new ParserError('Size of tensor does not match.', ParserError.CODES.SHAPE_MISMATCH); throw new ParserError('Size of tensor does not match.', ParserError.CODES.SHAPE_MISMATCH);
} }
const count = maxCount ? Math.min(originalCount, maxCount) : originalCount; return shuffle(
const dimension = maxDimension ? Math.min(originalDimension, maxDimension) : originalDimension; {
const vectors = new Float32Array(count * dimension); rawShape: shape,
for (let c = 0; c < count; c++) { count,
const offset = c * originalDimension; dimension,
vectors.set(originalVectors.subarray(offset, offset + dimension), c * dimension); vectors,
} ...genMetadataAndLabels(m, count)
const metadataAndLabels = genMetadataAndLabels(m, originalCount); },
metadataAndLabels.metadata = metadataAndLabels.metadata.slice(0, count); maxCount,
return { maxDimension
count, );
dimension,
vectors,
...metadataAndLabels
};
} }
...@@ -50,7 +50,7 @@ ...@@ -50,7 +50,7 @@
* You can find the PDF [here](http://jmlr.csail.mit.edu/papers/volume9/vandermaaten08a/vandermaaten08a.pdf). * You can find the PDF [here](http://jmlr.csail.mit.edu/papers/volume9/vandermaaten08a/vandermaaten08a.pdf).
*/ */
// cSpell:words Maaten Andrej Karpathy Karpathy's Guass randn xtod premult dists gainid // cSpell:words Maaten Andrej Karpathy Karpathy's randn xtod premult dists gainid
export type tSNEOptions = { export type tSNEOptions = {
perplexity?: number; perplexity?: number;
...@@ -58,7 +58,7 @@ export type tSNEOptions = { ...@@ -58,7 +58,7 @@ export type tSNEOptions = {
dimension: number; dimension: number;
}; };
class GuassRandom { class GaussRandom {
private returnV = false; private returnV = false;
private vVal = 0.0; private vVal = 0.0;
...@@ -85,7 +85,7 @@ export default class tSNE { ...@@ -85,7 +85,7 @@ export default class tSNE {
epsilon = 10; epsilon = 10;
perplexity = 30; perplexity = 30;
private Random = new GuassRandom(); private Random = new GaussRandom();
private data = new Float32Array(); private data = new Float32Array();
private P: Float32Array = new Float32Array(); private P: Float32Array = new Float32Array();
......
...@@ -19,7 +19,10 @@ export type Reduction = 'pca' | 'tsne' | 'umap'; ...@@ -19,7 +19,10 @@ export type Reduction = 'pca' | 'tsne' | 'umap';
export type Vectors = [number, number, number][]; export type Vectors = [number, number, number][];
export type Shape = [number, number];
export type VectorResult = { export type VectorResult = {
rawShape: Shape;
dimension: number; dimension: number;
count: number; count: number;
vectors: Float32Array; vectors: Float32Array;
...@@ -42,7 +45,7 @@ export interface ParseFromStringParams extends BaseParseParams { ...@@ -42,7 +45,7 @@ export interface ParseFromStringParams extends BaseParseParams {
} }
export interface ParseFromBlobParams extends BaseParseParams { export interface ParseFromBlobParams extends BaseParseParams {
shape: [number, number]; shape: Shape;
vectors: Blob; vectors: Blob;
} }
...@@ -58,6 +61,7 @@ export type ParseParams = ...@@ -58,6 +61,7 @@ export type ParseParams =
| null; | null;
export type ParseResult = { export type ParseResult = {
rawShape: Shape;
count: number; count: number;
dimension: number; dimension: number;
vectors: Float32Array; vectors: Float32Array;
......
...@@ -9,12 +9,7 @@ ...@@ -9,12 +9,7 @@
], ],
"extends": "@snowpack/app-scripts-react/tsconfig.base.json", "extends": "@snowpack/app-scripts-react/tsconfig.base.json",
"compilerOptions": { "compilerOptions": {
// You can't currently define paths in your 'extends' config,
// so we have to set 'baseUrl' & 'paths' here.
// Don't change these unless you know what you're doing.
// See: https://github.com/microsoft/TypeScript/issues/25430
"baseUrl": "./", "baseUrl": "./",
/* more strict checking for errors that per-file transpilers like `esbuild` would crash */
"isolatedModules": true, "isolatedModules": true,
"paths": { "paths": {
"*": [ "*": [
...@@ -24,8 +19,6 @@ ...@@ -24,8 +19,6 @@
"./src/*" "./src/*"
] ]
}, },
// Feel free to add/edit new config options below:
// ...
"lib": [ "lib": [
"DOM", "DOM",
"DOM.Iterable", "DOM.Iterable",
......
{ {
"name": "@visualdl/demo", "name": "@visualdl/demo",
"version": "2.0.9", "version": "2.1.0-1",
"description": "A platform to visualize the deep learning process and result.", "description": "A platform to visualize the deep learning process and result.",
"keywords": [ "keywords": [
"visualdl", "visualdl",
...@@ -34,7 +34,7 @@ ...@@ -34,7 +34,7 @@
"devDependencies": { "devDependencies": {
"@types/express": "4.17.9", "@types/express": "4.17.9",
"@types/mkdirp": "1.0.1", "@types/mkdirp": "1.0.1",
"@types/node": "14.14.14", "@types/node": "14.14.16",
"@types/node-fetch": "2.5.7", "@types/node-fetch": "2.5.7",
"@types/rimraf": "3.0.0", "@types/rimraf": "3.0.0",
"cpy-cli": "3.1.1", "cpy-cli": "3.1.1",
......
{ {
"name": "@visualdl/mock", "name": "@visualdl/mock",
"version": "2.0.9", "version": "2.1.0-1",
"description": "A platform to visualize the deep learning process and result.", "description": "A platform to visualize the deep learning process and result.",
"keywords": [ "keywords": [
"visualdl", "visualdl",
...@@ -42,7 +42,7 @@ ...@@ -42,7 +42,7 @@
"devDependencies": { "devDependencies": {
"@types/express": "4.17.9", "@types/express": "4.17.9",
"@types/faker": "5.1.5", "@types/faker": "5.1.5",
"@types/node": "14.14.14", "@types/node": "14.14.16",
"cpy-cli": "3.1.1", "cpy-cli": "3.1.1",
"rimraf": "3.0.2", "rimraf": "3.0.2",
"ts-node": "9.1.1", "ts-node": "9.1.1",
......
{ {
"name": "@visualdl/netron", "name": "@visualdl/netron",
"version": "2.0.9", "version": "2.1.0-1",
"description": "A platform to visualize the deep learning process and result.", "description": "A platform to visualize the deep learning process and result.",
"keywords": [ "keywords": [
"visualdl", "visualdl",
...@@ -51,8 +51,8 @@ ...@@ -51,8 +51,8 @@
"sass": "1.30.0", "sass": "1.30.0",
"sass-loader": "10.1.0", "sass-loader": "10.1.0",
"terser": "5.5.1", "terser": "5.5.1",
"webpack": "5.10.3", "webpack": "5.11.0",
"webpack-cli": "4.2.0" "webpack-cli": "4.3.0"
}, },
"engines": { "engines": {
"node": ">=12", "node": ">=12",
......
{ {
"name": "@visualdl/server", "name": "@visualdl/server",
"version": "2.0.9", "version": "2.1.0-1",
"description": "A platform to visualize the deep learning process and result.", "description": "A platform to visualize the deep learning process and result.",
"keywords": [ "keywords": [
"visualdl", "visualdl",
...@@ -36,25 +36,25 @@ ...@@ -36,25 +36,25 @@
"ecosystem.config.d.ts" "ecosystem.config.d.ts"
], ],
"dependencies": { "dependencies": {
"@visualdl/core": "2.0.9", "@visualdl/core": "2.1.0-1",
"dotenv": "8.2.0", "dotenv": "8.2.0",
"enhanced-resolve": "5.4.0", "enhanced-resolve": "5.4.1",
"express": "4.17.1", "express": "4.17.1",
"http-proxy-middleware": "1.0.6", "http-proxy-middleware": "1.0.6",
"pm2": "4.5.0" "pm2": "4.5.1"
}, },
"devDependencies": { "devDependencies": {
"@types/enhanced-resolve": "3.0.6", "@types/enhanced-resolve": "3.0.6",
"@types/express": "4.17.9", "@types/express": "4.17.9",
"@types/node": "14.14.14", "@types/node": "14.14.16",
"@visualdl/mock": "2.0.9", "@visualdl/mock": "2.1.0-1",
"cross-env": "7.0.3", "cross-env": "7.0.3",
"nodemon": "2.0.6", "nodemon": "2.0.6",
"ts-node": "9.1.1", "ts-node": "9.1.1",
"typescript": "4.0.5" "typescript": "4.0.5"
}, },
"optionalDependencies": { "optionalDependencies": {
"@visualdl/demo": "2.0.9" "@visualdl/demo": "2.1.0-1"
}, },
"engines": { "engines": {
"node": ">=12", "node": ">=12",
......
{ {
"name": "@visualdl/wasm", "name": "@visualdl/wasm",
"version": "2.0.9", "version": "2.1.0-1",
"title": "VisualDL", "title": "VisualDL",
"description": "A platform to visualize the deep learning process and result.", "description": "A platform to visualize the deep learning process and result.",
"keywords": [ "keywords": [
......
...@@ -2,6 +2,10 @@ ...@@ -2,6 +2,10 @@
set -e set -e
SCRIPT_DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" && pwd )"
cd $SCRIPT_DIR/..
# rust toolchain # rust toolchain
# https://rustup.rs/ # https://rustup.rs/
if ! hash rustup 2>/dev/null; then if ! hash rustup 2>/dev/null; then
......
此差异已折叠。
# Build script for Windows
# Requires: PowerShell, Python 3, Visual Studio 15 2017
$TOP_DIR=$pwd.PATH
$FRONTEND_DIR="$TOP_DIR/frontend"
$BACKEND_DIR="$TOP_DIR/visualdl"
$BUILD_DIR="$TOP_DIR/build"
mkdir $BUILD_DIR -ErrorAction Ignore
function check_duplicated($filename_format) {
$files = ls dist/$filename_format -ErrorAction Ignore
if (Get-TypeName($files.Length) -ne "FileInfo") {
Write-Error "dist have duplicate file for $filename_format, please clean and rerun"
exit(1)
}
}
function build_frontend_from_source() {
cd $FRONTEND_DIR
# TODO:
# ./scripts/build.sh
}
function build_frontend() {
$PACKAGE_NAME="@visualdl/serverless"
$SRC=npm view $PACKAGE_NAME dist.tarball
Invoke-WebRequest -Uri "$SRC" -OutFile "$BUILD_DIR/$PACKAGE_NAME.tar.gz"
# Need Windows 10 Insider Build 17063 and later
tar -zxf "$BUILD_DIR/$PACKAGE_NAME.tar.gz" -C "$BUILD_DIR"
}
function build_frontend_fake() {
mkdir -p "$BUILD_DIR/package/dist"
}
function build_backend() {
cd $BUILD_DIR
cmake -G "Visual Studio 15 2017 Win64" `
-DCMAKE_BUILD_TYPE=Release `
-DON_RELEASE=ON `
-DWITH_PYTHON3=ON `
-DWITH_TESTING=OFF `
-DBUILD_FOR_HOST:BOOL=YES ..
cmake --build . --config Release
}
function build_onnx_graph() {
$env:PATH = "$BUILD_DIR/third_party/protobuf/src/extern_protobuf-build/Release;" + $env:PATH
cd $TOP_DIR/visualdl/server/model/onnx
protoc onnx.proto --python_out .
cd $TOP_DIR/visualdl/server/model/paddle
protoc framework.proto --python_out .
}
function clean_env() {
rm -Recurse -Force -ErrorAction Ignore $TOP_DIR/visualdl/server/dist
rm -Recurse -Force -ErrorAction Ignore $BUILD_DIR/bdist*
rm -Recurse -Force -ErrorAction Ignore $BUILD_DIR/lib*
rm -Recurse -Force -ErrorAction Ignore $BUILD_DIR/temp*
rm -Recurse -Force -ErrorAction Ignore $BUILD_DIR/scripts*
rm -Recurse -Force -ErrorAction Ignore $BUILD_DIR/*.tar.gz
rm -Recurse -Force -ErrorAction Ignore $BUILD_DIR/package
}
function package() {
cp -Recurse $BUILD_DIR/package/dist $TOP_DIR/visualdl/server/
cp $BUILD_DIR/visualdl/logic/Release/core.pyd $TOP_DIR/visualdl
cp $BUILD_DIR/visualdl/logic/Release/core.pyd $TOP_DIR/visualdl/python/
}
clean_env
build_frontend
build_backend
build_onnx_graph
package
#!/bin/bash
function abort(){
echo "Your change doesn't follow VisualDL's code style." 1>&2
echo "Please use pre-commit to check what is wrong." 1>&2
exit 1
}
trap 'abort' 0
set -e
cd $TRAVIS_BUILD_DIR
export PATH=/usr/bin:$PATH
pre-commit install
clang-format --version
flake8 --version
if ! pre-commit run -a ; then
git diff
exit 1
fi
trap : 0
#!/usr/bin/env bash
exit_code=0
if [[ "$TRAVIS_PULL_REQUEST" != "false" ]]; then exit $exit_code; fi;
if [ "$TRAVIS_BRANCH" == "develop_doc" ]; then
PPO_SCRIPT_BRANCH=develop
elif [[ "$TRAVIS_BRANCH" == "develop" ]]; then
PPO_SCRIPT_BRANCH=master
else
exit $exit_code;
fi
export DEPLOY_DOCS_SH=https://raw.githubusercontent.com/PaddlePaddle/PaddlePaddle.org/$PPO_SCRIPT_BRANCH/scripts/deploy/deploy_docs.sh
docker run -it \
-e CONTENT_DEC_PASSWD=$CONTENT_DEC_PASSWD \
-e TRAVIS_BRANCH=$TRAVIS_BRANCH \
-e DEPLOY_DOCS_SH=$DEPLOY_DOCS_SH \
-e TRAVIS_PULL_REQUEST=$TRAVIS_PULL_REQUEST \
-e PPO_SCRIPT_BRANCH=$PPO_SCRIPT_BRANCH \
-e PADDLE_ROOT=/VisualDL \
-v "$PWD:/VisualDL" \
-w /VisualDL \
paddlepaddle/paddle:latest-dev \
/bin/bash -c 'curl $DEPLOY_DOCS_SH | bash -s $CONTENT_DEC_PASSWD $TRAVIS_BRANCH /VisualDL /VisualDL/build/doc/ $PPO_SCRIPT_BRANCH' || exit_code=$(( exit_code | $? ))
#!/bin/bash
set -e
SCRIPT_DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" && pwd )"
cd $SCRIPT_DIR/..
which python > /dev/null 2>&1
if [ $? -ne 0 ]; then
echo "You need to install Python 3.5+"
exit 1
fi
python --version
which node > /dev/null 2>&1
if [ $? -ne 0 ]; then
echo "You need to install nodejs 14+"
exit 1
fi
node --version
pip install pre-commit && pre-commit install
pip install --disable-pip-version-check -r requirements.txt
(cd frontend && scripts/install.sh)
#!/bin/bash
set -e
SCRIPT_DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" && pwd )"
cd $SCRIPT_DIR/..
mkdir -p build
cd build
which node >/dev/null 2>&1
if [ $? -ne 0 ]; then
echo "You need to install nodejs"
exit 1
fi
echo "Setting up nodejs dependencies"
npm install visualdl@latest --no-package-lock
processors=1
if [ "$(uname)" == "Darwin" ]; then
processors=`sysctl -n hw.ncpu`
elif [ "$(expr substr $(uname -s) 1 5)" == "Linux" ]; then
processors=`nproc`
fi
echo "Building VisualDL SDK"
cmake ..
make -j $processors
export PYTHONPATH=$PYTHONPATH:"$SCRIPT_DIR/.."
#!/bin/bash
set -e
ORIGINAL_ARGS="$@"
ARGS=`getopt --long host:,port: -n 'start_dev_server.sh' -- "$@"`
if [ $? != 0 ]; then
echo "Get arguments failed!" >&2
cd $CWD
exit 1
fi
PORT=8040
HOST="localhost"
eval set -- "$ARGS"
while true
do
case "$1" in
--host)
HOST="$2"
shift 2
;;
--port)
PORT="$2"
shift 2
;;
--)
shift
break
;;
*)
break
;;
esac
done
CURRENT_DIR=`pwd`
SCRIPT_DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" && pwd )"
export PYTHONPATH=$PYTHONPATH:"$SCRIPT_DIR/.."
FRONTEND_PORT=8999
VDL_BIN="./build/node_modules/.bin/visualdl"
$VDL_BIN start --port=$FRONTEND_PORT --host=$HOST --backend="http://$HOST:$PORT"
function finish {
$VDL_BIN stop
}
trap finish EXIT HUP INT QUIT PIPE TERM
cd $CURRENT_DIR
echo "Development server ready on http://$HOST:$FRONTEND_PORT"
# Run the visualDL with local PATH
python ${SCRIPT_DIR}/../visualdl/server/app.py "$ORIGINAL_ARGS"
#!/bin/bash
set -ex
SCRIPT_DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" && pwd )"
cd $SCRIPT_DIR/..
mode=$1
readonly TOP_DIR=$(pwd)
readonly core_path=$TOP_DIR/build/visualdl/logic
readonly python_path=$TOP_DIR/visualdl/python
readonly max_file_size=1000000 # 1MB
# version number follow the rule of https://semver.org/
readonly version_number=`cat VERSION_NUMBER | sed 's/\([0-9]*.[0-9]*.[0-9]*\).*/\1/g'`
sudo="sudo"
pip="pip"
python="python2"
if [[ "$WITH_PYTHON3" == "ON" ]]; then
pip="pip3"
python="python3"
$sudo python3 -m pip install --upgrade pip
fi
if [[ "$TRAVIS_OS_NAME" == "osx" ]]; then sudo=""; fi
if [[ "$TRAVIS_OS_NAME" == "osx" ]]; then
curl -O http://python-distribute.org/distribute_setup.py
$python distribute_setup.py
curl -O https://raw.github.com/pypa/pip/master/contrib/get-pip.py
$python get-pip.py
fi
$sudo $pip install numpy
$sudo $pip install Flask
$sudo $pip install Pillow
$sudo $pip install --ignore-installed six
$sudo $pip install protobuf
export PYTHONPATH="${core_path}:${python_path}"
# install the visualdl wheel first
package() {
cd $TOP_DIR/visualdl/server
# manully install protobuf3
curl -OL https://github.com/google/protobuf/releases/download/v3.5.0/protoc-3.5.0-linux-x86_64.zip
unzip protoc-3.5.0-linux-x86_64.zip -d protoc3
export PATH="$PATH:$(pwd)/protoc3/bin"
chmod +x protoc3/bin/*
cd $TOP_DIR
$python setup.py bdist_wheel
$sudo $pip install dist/visualdl-${version_number}*.whl
}
backend_test() {
cd $TOP_DIR
mkdir -p build
cd build
if [[ "$WITH_PYTHON3" == "ON" ]]; then
cmake -DWITH_PYTHON3=ON ..
else
cmake ..
fi
make
make test
}
server_test() {
$sudo $pip install google
$sudo $pip install protobuf==3.5.1
cd $TOP_DIR/visualdl/server
bash graph_test.sh
cd $TOP_DIR/
$python -m visualdl.server.lib_test
}
# check the size of files in the repo.
# reject PR that has some big data included.
bigfile_reject() {
cd $TOP_DIR
# it failed to exclude .git, remove it first.
rm -rf .git
local largest_file=$(find . -path .git -prune -o -printf '%s %p\n' | sort -nr | grep -v "CycleGAN"| head -n1)
local size=$(echo "$largest_file" | awk '{print $1}')
if [ "$size" -ge "$max_file_size" ]; then
echo $largest_file
echo "file size exceed $max_file_size"
echo "Should not add large data or binary file."
exit -1
fi
}
clean_env() {
rm -rf $TOP_DIR/build
rm -rf $TOP_DIR/dist
rm -rf $TOP_DIR/visualdl/core.so
rm -rf $TOP_DIR/visualdl/python/core.so
rm -rf $TOP_DIR/visualdl/server/protoc3
rm -rf $TOP_DIR/visualdl/server/protoc*.zip
}
echo "mode" $mode
if [ $mode = "backend" ]; then
backend_test
elif [ $mode = "all" ]; then
# Clean before test
clean_env
# bigfile_reject should be tested first, or some files downloaded may fail this test.
bigfile_reject
package
backend_test
server_test
elif [ $mode = "local" ]; then
backend_test
server_test
fi
# This file is used by clang-format to autoformat paddle source code
#
# The clang-format is part of llvm toolchain.
# It need to install llvm and clang to format source code style.
#
# The basic usage is,
# clang-format -i -style=file PATH/TO/SOURCE/CODE
#
# The -style=file implicit use ".clang-format" file located in one of
# parent directory.
# The -i means inplace change.
#
# The document of clang-format is
# http://clang.llvm.org/docs/ClangFormat.html
# http://clang.llvm.org/docs/ClangFormatStyleOptions.html
---
Language: Cpp
BasedOnStyle: Google
IndentWidth: 2
TabWidth: 2
ContinuationIndentWidth: 4
AccessModifierOffset: -2 # The private/protected/public has no indent in class
Standard: Cpp11
AllowAllParametersOfDeclarationOnNextLine: true
BinPackParameters: false
BinPackArguments: false
...
...@@ -17,9 +17,9 @@ from __future__ import absolute_import ...@@ -17,9 +17,9 @@ from __future__ import absolute_import
import os import os
from visualdl.writer.writer import LogWriter # noqa from visualdl.writer.writer import LogWriter # noqa: F401
from visualdl.reader.reader import LogReader from visualdl.reader.reader import LogReader # noqa: F401
from visualdl.version import vdl_version as __version__ from visualdl.version import vdl_version as __version__ # noqa: F401
from visualdl.utils.dir import init_vdl_config from visualdl.utils.dir import init_vdl_config
init_vdl_config() init_vdl_config()
......
...@@ -12,6 +12,7 @@ ...@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# ======================================================================= # =======================================================================
components = { components = {
"scalar": { "scalar": {
"enabled": False "enabled": False
......
...@@ -12,6 +12,7 @@ ...@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# ======================================================================= # =======================================================================
from visualdl.proto.record_pb2 import Record from visualdl.proto.record_pb2 import Record
import numpy as np import numpy as np
from PIL import Image from PIL import Image
...@@ -80,11 +81,11 @@ def imgarray2bytes(np_array): ...@@ -80,11 +81,11 @@ def imgarray2bytes(np_array):
return img_bin return img_bin
def make_grid(I, ncols=8): def make_grid(I, ncols=8): # noqa: E741
assert isinstance( assert isinstance(
I, np.ndarray), 'plugin error, should pass numpy array here' I, np.ndarray), 'plugin error, should pass numpy array here'
if I.shape[1] == 1: if I.shape[1] == 1:
I = np.concatenate([I, I, I], 1) I = np.concatenate([I, I, I], 1) # noqa: E741
assert I.ndim == 4 and I.shape[1] == 3 or I.shape[1] == 4 assert I.ndim == 4 and I.shape[1] == 3 or I.shape[1] == 4
nimg = I.shape[0] nimg = I.shape[0]
H = I.shape[2] H = I.shape[2]
...@@ -404,6 +405,8 @@ def pr_curve_raw(tag, tp, fp, tn, fn, precision, recall, step, walltime): ...@@ -404,6 +405,8 @@ def pr_curve_raw(tag, tp, fp, tn, fn, precision, recall, step, walltime):
Record.Value( Record.Value(
id=step, tag=tag, timestamp=walltime, pr_curve=prcurve) id=step, tag=tag, timestamp=walltime, pr_curve=prcurve)
]) ])
def compute_roc_curve(labels, predictions, num_thresholds=None, weights=None): def compute_roc_curve(labels, predictions, num_thresholds=None, weights=None):
""" Compute ROC curve data by labels and predictions. """ Compute ROC curve data by labels and predictions.
Args: Args:
...@@ -454,8 +457,7 @@ def compute_roc_curve(labels, predictions, num_thresholds=None, weights=None): ...@@ -454,8 +457,7 @@ def compute_roc_curve(labels, predictions, num_thresholds=None, weights=None):
return data return data
def roc_curve(tag, labels, predictions, step, walltime, num_thresholds=127, def roc_curve(tag, labels, predictions, step, walltime, num_thresholds=127, weights=None):
weights=None):
"""Package data to one roc_curve. """Package data to one roc_curve.
Args: Args:
tag (string): Data identifier tag (string): Data identifier
...@@ -473,14 +475,14 @@ def roc_curve(tag, labels, predictions, step, walltime, num_thresholds=127, ...@@ -473,14 +475,14 @@ def roc_curve(tag, labels, predictions, step, walltime, num_thresholds=127,
roc_curve_map = compute_roc_curve(labels, predictions, num_thresholds, weights) roc_curve_map = compute_roc_curve(labels, predictions, num_thresholds, weights)
return roc_curve_raw(tag=tag, return roc_curve_raw(tag=tag,
tp=roc_curve_map['tp'], tp=roc_curve_map['tp'],
fp=roc_curve_map['fp'], fp=roc_curve_map['fp'],
tn=roc_curve_map['tn'], tn=roc_curve_map['tn'],
fn=roc_curve_map['fn'], fn=roc_curve_map['fn'],
tpr=roc_curve_map['tpr'], tpr=roc_curve_map['tpr'],
fpr=roc_curve_map['fpr'], fpr=roc_curve_map['fpr'],
step=step, step=step,
walltime=walltime) walltime=walltime)
def roc_curve_raw(tag, tp, fp, tn, fn, tpr, fpr, step, walltime): def roc_curve_raw(tag, tp, fp, tn, fn, tpr, fpr, step, walltime):
...@@ -516,13 +518,12 @@ def roc_curve_raw(tag, tp, fp, tn, fn, tpr, fpr, step, walltime): ...@@ -516,13 +518,12 @@ def roc_curve_raw(tag, tp, fp, tn, fn, tpr, fpr, step, walltime):
fpr = fpr.astype(int).tolist() fpr = fpr.astype(int).tolist()
""" """
roc_curve = Record.ROC_Curve(TP=tp, roc_curve = Record.ROC_Curve(TP=tp,
FP=fp, FP=fp,
TN=tn, TN=tn,
FN=fn, FN=fn,
tpr=tpr, tpr=tpr,
fpr=fpr) fpr=fpr)
return Record(values=[ return Record(values=[
Record.Value( Record.Value(
id=step, tag=tag, timestamp=walltime, roc_curve=roc_curve) id=step, tag=tag, timestamp=walltime, roc_curve=roc_curve)
]) ])
...@@ -13,4 +13,4 @@ ...@@ -13,4 +13,4 @@
# limitations under the License. # limitations under the License.
# ======================================================================= # =======================================================================
from . import bfile # noqa from . import bfile # noqa: F401
...@@ -387,7 +387,7 @@ class BosFileSystem(object): ...@@ -387,7 +387,7 @@ class BosFileSystem(object):
content_md5=content_md5(file_content), content_md5=content_md5(file_content),
content_length=content_length, content_length=content_length,
offset=offset) offset=offset)
except (exception.BceServerError, exception.BceHttpClientError) as e: except (exception.BceServerError, exception.BceHttpClientError):
init_data = b'' init_data = b''
self.bos_client.append_object(bucket_name=bucket_name, self.bos_client.append_object(bucket_name=bucket_name,
key=object_key, key=object_key,
...@@ -634,7 +634,7 @@ class BFile(object): ...@@ -634,7 +634,7 @@ class BFile(object):
if isinstance(self.fs, BosFileSystem): if isinstance(self.fs, BosFileSystem):
try: try:
self.fs.append(self._filename, b'', self.binary_mode, force=True) self.fs.append(self._filename, b'', self.binary_mode, force=True)
except: except Exception:
pass pass
self.flush() self.flush()
if self.write_temp is not None: if self.write_temp is not None:
......
...@@ -12,8 +12,8 @@ ...@@ -12,8 +12,8 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# ======================================================================= # =======================================================================
import collections import collections
from functools import partial
from visualdl.io import bfile from visualdl.io import bfile
from visualdl.component import components from visualdl.component import components
from visualdl.reader.record_reader import RecordReader from visualdl.reader.record_reader import RecordReader
...@@ -34,7 +34,7 @@ def is_VDLRecord_file(path, check=False): ...@@ -34,7 +34,7 @@ def is_VDLRecord_file(path, check=False):
Returns: Returns:
True if the file is a VDL log file, otherwise false. True if the file is a VDL log file, otherwise false.
""" """
if not "vdlrecords" in path: if "vdlrecords" not in path:
return False return False
if check: if check:
_reader = RecordReader(filepath=path) _reader = RecordReader(filepath=path)
...@@ -72,7 +72,8 @@ class LogReader(object): ...@@ -72,7 +72,8 @@ class LogReader(object):
self.file_readers = {} self.file_readers = {}
# {'run': {'scalar': {'tag1': data, 'tag2': data}}} # {'run': {'scalar': {'tag1': data, 'tag2': data}}}
self._log_datas = collections.defaultdict(lambda: collections.defaultdict(lambda: collections.defaultdict(list))) self._log_datas = collections.defaultdict(
lambda: collections.defaultdict(lambda: collections.defaultdict(list)))
if file_path: if file_path:
self._log_data = collections.defaultdict(lambda: collections.defaultdict(list)) self._log_data = collections.defaultdict(lambda: collections.defaultdict(list))
...@@ -128,7 +129,9 @@ class LogReader(object): ...@@ -128,7 +129,9 @@ class LogReader(object):
return log_tags return log_tags
def get_log_data(self, component, run, tag): def get_log_data(self, component, run, tag):
if run in self._log_datas.keys() and component in self._log_datas[run].keys() and tag in self._log_datas[run][component].keys(): if (run in self._log_datas.keys() and
component in self._log_datas[run].keys() and
tag in self._log_datas[run][component].keys()):
return self._log_datas[run][component][tag] return self._log_datas[run][component][tag]
else: else:
file_path = bfile.join(run, self.walks[run]) file_path = bfile.join(run, self.walks[run])
......
...@@ -12,6 +12,7 @@ ...@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# ======================================================================= # =======================================================================
from visualdl.io import bfile from visualdl.io import bfile
import struct import struct
......
...@@ -112,6 +112,7 @@ class Api(object): ...@@ -112,6 +112,7 @@ class Api(object):
@result() @result()
def pr_curve_tags(self): def pr_curve_tags(self):
return self._get_with_retry('data/plugin/pr_curves/tags', lib.get_pr_curve_tags) return self._get_with_retry('data/plugin/pr_curves/tags', lib.get_pr_curve_tags)
@result() @result()
def roc_curve_tags(self): def roc_curve_tags(self):
return self._get_with_retry('data/plugin/roc_curves/tags', lib.get_roc_curve_tags) return self._get_with_retry('data/plugin/roc_curves/tags', lib.get_roc_curve_tags)
...@@ -181,20 +182,26 @@ class Api(object): ...@@ -181,20 +182,26 @@ class Api(object):
def pr_curves_pr_curve(self, run, tag): def pr_curves_pr_curve(self, run, tag):
key = os.path.join('data/plugin/pr_curves/pr_curve', run, tag) key = os.path.join('data/plugin/pr_curves/pr_curve', run, tag)
return self._get_with_retry(key, lib.get_pr_curve, run, tag) return self._get_with_retry(key, lib.get_pr_curve, run, tag)
@result() @result()
def roc_curves_roc_curve(self, run, tag): def roc_curves_roc_curve(self, run, tag):
key = os.path.join('data/plugin/roc_curves/roc_curve', run, tag) key = os.path.join('data/plugin/roc_curves/roc_curve', run, tag)
return self._get_with_retry(key, lib.get_roc_curve, run, tag) return self._get_with_retry(key, lib.get_roc_curve, run, tag)
@result() @result()
def pr_curves_steps(self, run): def pr_curves_steps(self, run):
key = os.path.join('data/plugin/pr_curves/steps', run) key = os.path.join('data/plugin/pr_curves/steps', run)
return self._get_with_retry(key, lib.get_pr_curve_step, run) return self._get_with_retry(key, lib.get_pr_curve_step, run)
@result() @result()
def roc_curves_steps(self, run): def roc_curves_steps(self, run):
key = os.path.join('data/plugin/roc_curves/steps', run) key = os.path.join('data/plugin/roc_curves/steps', run)
return self._get_with_retry(key, lib.get_roc_curve_step, run) return self._get_with_retry(key, lib.get_roc_curve_step, run)
@result('application/octet-stream', lambda s: {"Content-Disposition": 'attachment; filename="%s"' % s.model_name} if len(s.model_name) else None) @result(
'application/octet-stream',
lambda s: {"Content-Disposition": 'attachment; filename="%s"' % s.model_name} if len(s.model_name) else None
)
def graph_graph(self): def graph_graph(self):
key = os.path.join('data/plugin/graphs/graph') key = os.path.join('data/plugin/graphs/graph')
return self._get_with_retry(key, lib.get_graph) return self._get_with_retry(key, lib.get_graph)
......
...@@ -14,12 +14,12 @@ ...@@ -14,12 +14,12 @@
# ======================================================================= # =======================================================================
from __future__ import absolute_import from __future__ import absolute_import
from functools import partial # noqa: F401
import sys import sys
import time import time
import os import os
import io import io
import csv import csv
from functools import partial
import numpy as np import numpy as np
from visualdl.server.log import logger from visualdl.server.log import logger
from visualdl.io import bfile from visualdl.io import bfile
...@@ -198,8 +198,8 @@ def get_pr_curve(log_reader, run, tag): ...@@ -198,8 +198,8 @@ def get_pr_curve(log_reader, run, tag):
list(pr_curve.FN), list(pr_curve.FN),
num_thresholds]) num_thresholds])
return results return results
def get_roc_curve(log_reader, run, tag): def get_roc_curve(log_reader, run, tag):
run = log_reader.name2tags[run] if run in log_reader.name2tags else run run = log_reader.name2tags[run] if run in log_reader.name2tags else run
log_reader.load_new_data() log_reader.load_new_data()
...@@ -225,7 +225,7 @@ def get_roc_curve(log_reader, run, tag): ...@@ -225,7 +225,7 @@ def get_roc_curve(log_reader, run, tag):
def get_pr_curve_step(log_reader, run, tag=None): def get_pr_curve_step(log_reader, run, tag=None):
fake_run = run fake_run = run
run = log_reader.name2tags[run] if run in log_reader.name2tags else run run = log_reader.name2tags[run] if run in log_reader.name2tags else run
run2tag = get_pr_curve_tags(log_reader) run2tag = get_pr_curve_tags(log_reader) # noqa: F821
tag = run2tag['tags'][run2tag['runs'].index(fake_run)][0] tag = run2tag['tags'][run2tag['runs'].index(fake_run)][0]
log_reader.load_new_data() log_reader.load_new_data()
records = log_reader.data_manager.get_reservoir("pr_curve").get_items( records = log_reader.data_manager.get_reservoir("pr_curve").get_items(
...@@ -237,7 +237,7 @@ def get_pr_curve_step(log_reader, run, tag=None): ...@@ -237,7 +237,7 @@ def get_pr_curve_step(log_reader, run, tag=None):
def get_roc_curve_step(log_reader, run, tag=None): def get_roc_curve_step(log_reader, run, tag=None):
fake_run = run fake_run = run
run = log_reader.name2tags[run] if run in log_reader.name2tags else run run = log_reader.name2tags[run] if run in log_reader.name2tags else run
run2tag = get_roc_curve_tags(log_reader) run2tag = get_roc_curve_tags(log_reader) # noqa: F821
tag = run2tag['tags'][run2tag['runs'].index(fake_run)][0] tag = run2tag['tags'][run2tag['runs'].index(fake_run)][0]
log_reader.load_new_data() log_reader.load_new_data()
records = log_reader.data_manager.get_reservoir("roc_curve").get_items( records = log_reader.data_manager.get_reservoir("roc_curve").get_items(
...@@ -245,7 +245,7 @@ def get_roc_curve_step(log_reader, run, tag=None): ...@@ -245,7 +245,7 @@ def get_roc_curve_step(log_reader, run, tag=None):
results = [[s2ms(item.timestamp), item.id] for item in records] results = [[s2ms(item.timestamp), item.id] for item in records]
return results return results
def get_embeddings_list(log_reader): def get_embeddings_list(log_reader):
run2tag = get_logs(log_reader, 'embeddings') run2tag = get_logs(log_reader, 'embeddings')
...@@ -356,26 +356,3 @@ def cache_get(cache): ...@@ -356,26 +356,3 @@ def cache_get(cache):
return data return data
return _handler return _handler
def simple_pca(x, dimension):
"""
A simple PCA implementation to do the dimension reduction.
"""
# Center the data.
x -= np.mean(x, axis=0)
# Computing the Covariance Matrix
cov = np.cov(x, rowvar=False)
# Get eigenvectors and eigenvalues from the covariance matrix
eigvals, eigvecs = np.linalg.eig(cov)
# Sort the eigvals from high to low
order = np.argsort(eigvals)[::-1]
# Drop the eigenvectors with low eigenvalues
eigvecs = eigvecs[:, order[:dimension]]
return np.dot(x, eigvecs)
...@@ -14,6 +14,7 @@ ...@@ -14,6 +14,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# ======================================================================= # =======================================================================
import requests import requests
import json import json
import os import os
......
#
# tsne.py
#
# Implementation of t-SNE in Python. The implementation was tested on Python
# 2.7.10, and it requires a working installation of NumPy. The implementation
# comes with an example on the MNIST dataset. In order to plot the
# results of this example, a working installation of matplotlib is required.
#
# The example can be run by executing: `ipython tsne.py`
#
#
# Created by Laurens van der Maaten on 20-12-08.
# Copyright (c) 2008 Tilburg University. All rights reserved.
# Editor's note: Thanks https://lvdmaaten.github.io/tsne/
# For allowing it to be free to use, modify, or redistribute for non-commercial purposes
# This file is modified to remove some print out.
import numpy as np
def Hbeta(D=np.array([]), beta=1.0):
"""
Compute the perplexity and the P-row for a specific value of the
precision of a Gaussian distribution.
"""
# Compute P-row and corresponding perplexity
P = np.exp(-D.copy() * beta)
sumP = sum(P)
H = np.log(sumP) + beta * np.sum(D * P) / sumP
P = P / sumP
return H, P
def x2p(X=np.array([]), tol=1e-5, perplexity=30.0):
"""
Performs a binary search to get P-values in such a way that each
conditional Gaussian has the same perplexity.
"""
# Initialize some variables
print("Computing pairwise distances...")
(n, d) = X.shape
sum_X = np.sum(np.square(X), 1)
D = np.add(np.add(-2 * np.dot(X, X.T), sum_X).T, sum_X)
P = np.zeros((n, n))
beta = np.ones((n, 1))
logU = np.log(perplexity)
# Loop over all datapoints
for i in range(n):
# Print progress
# if i % 500 == 0:
# print("Computing P-values for point %d of %d..." % (i, n))
# Compute the Gaussian kernel and entropy for the current precision
betamin = -np.inf
betamax = np.inf
Di = D[i, np.concatenate((np.r_[0:i], np.r_[i + 1:n]))]
(H, thisP) = Hbeta(Di, beta[i])
# Evaluate whether the perplexity is within tolerance
Hdiff = H - logU
tries = 0
while np.abs(Hdiff) > tol and tries < 50:
# If not, increase or decrease precision
if Hdiff > 0:
betamin = beta[i].copy()
if betamax == np.inf or betamax == -np.inf:
beta[i] = beta[i] * 2.
else:
beta[i] = (beta[i] + betamax) / 2.
else:
betamax = beta[i].copy()
if betamin == np.inf or betamin == -np.inf:
beta[i] = beta[i] / 2.
else:
beta[i] = (beta[i] + betamin) / 2.
# Recompute the values
(H, thisP) = Hbeta(Di, beta[i])
Hdiff = H - logU
tries += 1
# Set the final row of P
P[i, np.concatenate((np.r_[0:i], np.r_[i + 1:n]))] = thisP
# Return final P-matrix
print("Mean value of sigma: %f" % np.mean(np.sqrt(1 / beta)))
return P
def pca(X=np.array([]), no_dims=50):
"""
Runs PCA on the NxD array X in order to reduce its dimensionality to
no_dims dimensions.
"""
(n, d) = X.shape
X = X - np.tile(np.mean(X, 0), (n, 1))
(l, M) = np.linalg.eig(np.dot(X.T, X))
Y = np.dot(X, M[:, 0:no_dims])
return Y
def tsne(X=np.array([]), no_dims=2, initial_dims=50, perplexity=30.0):
"""
Runs t-SNE on the dataset in the NxD array X to reduce its
dimensionality to no_dims dimensions. The syntaxis of the function is
`Y = tsne.tsne(X, no_dims, perplexity), where X is an NxD NumPy array.
"""
# Check inputs
if isinstance(no_dims, float):
print("Error: array X should have type float.")
return -1
if round(no_dims) != no_dims:
print("Error: number of dimensions should be an integer.")
return -1
# Initialize variables
X = pca(X, initial_dims).real
(n, d) = X.shape
max_iter = 1000
initial_momentum = 0.5
final_momentum = 0.8
eta = 500
min_gain = 0.01
Y = np.random.randn(n, no_dims)
dY = np.zeros((n, no_dims))
iY = np.zeros((n, no_dims))
gains = np.ones((n, no_dims))
# Compute P-values
P = x2p(X, 1e-5, perplexity)
P = P + np.transpose(P)
P = P / np.sum(P)
P = P * 4. # early exaggeration
P = np.maximum(P, 1e-12)
# Run iterations
for iter in range(max_iter):
# Compute pairwise affinities
sum_Y = np.sum(np.square(Y), 1)
num = -2. * np.dot(Y, Y.T)
num = 1. / (1. + np.add(np.add(num, sum_Y).T, sum_Y))
num[range(n), range(n)] = 0.
Q = num / np.sum(num)
Q = np.maximum(Q, 1e-12)
# Compute gradient
PQ = P - Q
for i in range(n):
dY[i, :] = np.sum(
np.tile(PQ[:, i] * num[:, i], (no_dims, 1)).T * (Y[i, :] - Y),
0)
# Perform the update
if iter < 20:
momentum = initial_momentum
else:
momentum = final_momentum
gains = (gains + 0.2) * ((dY > 0.) != (iY > 0.)) + \
(gains * 0.8) * ((dY > 0.) == (iY > 0.))
gains[gains < min_gain] = min_gain
iY = momentum * iY - eta * (gains * dY)
Y = Y + iY
Y = Y - np.tile(np.mean(Y, 0), (n, 1))
# Compute current value of cost function
# if (iter + 1) % 10 == 0:
# C = np.sum(P * np.log(P / Q))
# print("Iteration %d: error is %f" % (iter + 1, C))
# Stop lying about P-values
if iter == 100:
P = P / 4.
# Return solution
return Y
...@@ -12,6 +12,7 @@ ...@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# ======================================================================= # =======================================================================
import array import array
......
...@@ -13,7 +13,6 @@ ...@@ -13,7 +13,6 @@
# limitations under the License. # limitations under the License.
# ======================================================================= # =======================================================================
def encode_tag(tag): def encode_tag(tag):
return tag.replace("%", "/") return tag.replace("%", "/")
......
...@@ -12,10 +12,10 @@ ...@@ -12,10 +12,10 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# ======================================================================= # =======================================================================
import threading import threading
import hashlib import hashlib
import requests import requests
import json
from visualdl import __version__ from visualdl import __version__
from visualdl.proto.record_pb2 import DESCRIPTOR from visualdl.proto.record_pb2 import DESCRIPTOR
......
...@@ -13,7 +13,6 @@ ...@@ -13,7 +13,6 @@
# limitations under the License. # limitations under the License.
# ======================================================================= # =======================================================================
from visualdl.io import bfile from visualdl.io import bfile
from visualdl.utils.crc32 import masked_crc32c from visualdl.utils.crc32 import masked_crc32c
from visualdl.proto import record_pb2 from visualdl.proto import record_pb2
...@@ -95,7 +94,7 @@ class RecordFileWriter(object): ...@@ -95,7 +94,7 @@ class RecordFileWriter(object):
fn = "vdlrecords.%010d.log%s" % (time.time(), filename_suffix) fn = "vdlrecords.%010d.log%s" % (time.time(), filename_suffix)
self._file_name = bfile.join(logdir, fn) self._file_name = bfile.join(logdir, fn)
print( print(
'Since the log filename should contain `vdlrecords`, the filename is invalid and `{}` will replace `{}`'.format( 'Since the log filename should contain `vdlrecords`, the filename is invalid and `{}` will replace `{}`'.format( # noqa: E501
fn, filename)) fn, filename))
else: else:
self._file_name = bfile.join(logdir, "vdlrecords.%010d.log%s" % ( self._file_name = bfile.join(logdir, "vdlrecords.%010d.log%s" % (
......
...@@ -12,6 +12,7 @@ ...@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# ======================================================================= # =======================================================================
import os import os
import time import time
import numpy as np import numpy as np
...@@ -380,14 +381,15 @@ class LogWriter(object): ...@@ -380,14 +381,15 @@ class LogWriter(object):
num_thresholds=num_thresholds, num_thresholds=num_thresholds,
weights=weights weights=weights
)) ))
def add_roc_curve(self, def add_roc_curve(self,
tag, tag,
labels, labels,
predictions, predictions,
step, step,
num_thresholds=10, num_thresholds=10,
weights=None, weights=None,
walltime=None): walltime=None):
"""Add an ROC curve to vdl record file. """Add an ROC curve to vdl record file.
Args: Args:
tag (string): Data identifier tag (string): Data identifier
...@@ -422,7 +424,6 @@ class LogWriter(object): ...@@ -422,7 +424,6 @@ class LogWriter(object):
weights=weights weights=weights
)) ))
def flush(self): def flush(self):
"""Flush all data in cache to disk. """Flush all data in cache to disk.
""" """
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册