未验证 提交 ae2608f1 编写于 作者: P Peter Pan 提交者: GitHub

rewrite roc curve (#881)

* 增加 roc 曲线

所有更改参照pr曲线更改,在公式计算的地方更改为了 roc 的计算公式

* Update index.ts

语法改正

* Update reducers.ts

修改语法错误

* 生成record_pb2.py

编译生成record_pb2.py
修改 record.proto 序号问题

* Update base_component.py

更正参数

* Update api.py

修正语法错误

* add data/roc-curve

前端页面

* add resource

add resource

* Update lib.py

* docs: remove luckydraw

* feat: download raw data

* fix conflict

* chore: rewrite pr-curve & roc-curve

* fix: unexpected background color in dark mode
Co-authored-by: Niceriver97 <1105107356@qq.com>
上级 e7284deb
# Copyright (c) 2020 VisualDL Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =======================================================================
# coding=utf-8
from visualdl import LogWriter
import numpy as np
with LogWriter("./log/roc_curve_test/train") as writer:
for step in range(3):
labels = np.random.randint(2, size=100)
predictions = np.random.rand(100)
writer.add_roc_curve(tag='roc_curve',
labels=labels,
predictions=predictions,
step=step,
num_thresholds=5)
......@@ -47,6 +47,7 @@
"i18next": "19.8.4",
"i18next-browser-languagedetector": "6.0.1",
"i18next-fetch-backend": "3.0.0",
"jszip": "3.5.0",
"lodash": "4.17.20",
"mime-types": "2.1.27",
"moment": "2.29.1",
......
......@@ -12,8 +12,9 @@
"image": "Image",
"loading": "Please wait while loading data",
"next-page": "Next Page",
"pr-curve": "PR-Curve",
"pr-curve": "PR Curve",
"previous-page": "Prev Page",
"roc-curve": "ROC Curve",
"run": "Run",
"running": "Running",
"runs": "Runs",
......
{
"download-data": "Download data",
"download-image": "Download image",
"ignore-outliers": "Ignore outliers in chart scaling",
"max": "Max.",
......
......@@ -13,6 +13,7 @@
"loading": "数据载入中,请稍等",
"next-page": "下一页",
"pr-curve": "PR曲线",
"roc-curve": "ROC曲线",
"previous-page": "上一页",
"run": "运行",
"running": "运行中",
......
{
"download-data": "下载数据",
"download-image": "下载图片",
"ignore-outliers": "图表缩放时忽略极端值",
"max": "最大值",
......
......@@ -71,7 +71,6 @@ module.exports = {
clean: true
},
installOptions: {
polyfillNode: true,
namedExports: ['file-saver']
polyfillNode: true
}
};
......@@ -14,7 +14,7 @@
* limitations under the License.
*/
import React, {FunctionComponent, useCallback, useState} from 'react';
import React, {FunctionComponent, useCallback, useContext, useMemo, useState} from 'react';
import {WithStyled, em, rem, transitionProps} from '~/utils/style';
import Icon from '~/components/Icon';
......@@ -48,6 +48,28 @@ const ToolboxItem = styled.a<{active?: boolean}>`
}
`;
const ChartToolboxMenu = styled.div`
background-color: var(--background-color);
${transitionProps('background-color')};
> a {
cursor: pointer;
display: block;
padding: ${rem(10)};
background-color: var(--background-color);
${transitionProps(['color', 'background-color'])};
&:hover {
background-color: var(--background-focused-color);
}
}
`;
interface ChartToolboxItemChild {
label: string;
onClick?: () => unknown;
}
type BaseChartToolboxItem = {
icon: Icons;
tooltip?: string;
......@@ -65,12 +87,111 @@ type ToggleChartToolboxItem = {
onClick?: (value: boolean) => unknown;
} & BaseChartToolboxItem;
export type ChartToolboxItem = NormalChartToolboxItem | ToggleChartToolboxItem;
type MenuChartToolboxItem = {
toggle?: false;
tooltip: undefined;
menuList: ChartToolboxItemChild[];
} & BaseChartToolboxItem;
export type ChartToolboxItem = NormalChartToolboxItem | ToggleChartToolboxItem | MenuChartToolboxItem;
const ChartToolboxIcon = React.forwardRef<
HTMLAnchorElement,
{
toggle?: boolean;
icon: Icons;
activeIcon?: Icons;
activeStatus?: boolean;
onClick?: () => unknown;
}
>(({toggle, icon, activeIcon, activeStatus, onClick}, ref) => {
return (
<ToolboxItem ref={ref} active={toggle && !activeIcon && activeStatus} onClick={() => onClick?.()}>
<Icon type={toggle ? (activeStatus && activeIcon) || icon : icon} />
</ToolboxItem>
);
});
ChartToolboxIcon.displayName = 'ChartToolboxIcon';
type ChartToolboxItemProps = {
tooltipPlacement?: 'top' | 'bottom' | 'left' | 'right';
};
type ChartToolboxProps = {
items: ChartToolboxItem[];
reversed?: boolean;
tooltipPlacement?: 'top' | 'bottom' | 'left' | 'right';
} & ChartToolboxItemProps;
const ChartToolboxItemContext = React.createContext<ChartToolboxItemProps>({
tooltipPlacement: 'top'
});
const NormalChartToolbox: FunctionComponent<NormalChartToolboxItem> = ({icon, tooltip, onClick}) => {
const toolboxIcon = useMemo(() => <ChartToolboxIcon icon={icon} onClick={onClick} />, [icon, onClick]);
const {tooltipPlacement} = useContext(ChartToolboxItemContext);
return tooltip ? (
<Tippy content={tooltip} placement={tooltipPlacement || 'top'} theme="tooltip">
{toolboxIcon}
</Tippy>
) : (
toolboxIcon
);
};
const ToggleChartToolbox: FunctionComponent<ToggleChartToolboxItem> = ({
icon,
tooltip,
activeIcon,
activeTooltip,
onClick
}) => {
const [active, setActive] = useState(false);
const click = useCallback(() => {
setActive(a => {
onClick?.(!a);
return !a;
});
}, [onClick]);
const toolboxIcon = useMemo(
() => <ChartToolboxIcon icon={icon} activeIcon={activeIcon} activeStatus={active} toggle onClick={click} />,
[icon, activeIcon, active, click]
);
const {tooltipPlacement} = useContext(ChartToolboxItemContext);
return tooltip ? (
<Tippy content={(active && activeTooltip) || tooltip} placement={tooltipPlacement || 'top'} theme="tooltip">
{toolboxIcon}
</Tippy>
) : (
toolboxIcon
);
};
const MenuChartToolbox: FunctionComponent<MenuChartToolboxItem> = ({icon, menuList}) => {
return (
<Tippy
content={
<ChartToolboxMenu>
{menuList.map((item, index) => (
<a key={index} onClick={() => item.onClick?.()}>
{item.label}
</a>
))}
</ChartToolboxMenu>
}
placement="right-start"
animation="shift-away-subtle"
interactive
hideOnClick={false}
arrow={false}
role="menu"
theme="menu"
>
<ChartToolboxIcon icon={icon} />
</Tippy>
);
};
const ChartToolbox: FunctionComponent<ChartToolboxProps & WithStyled> = ({
......@@ -79,56 +200,25 @@ const ChartToolbox: FunctionComponent<ChartToolboxProps & WithStyled> = ({
reversed,
className
}) => {
const [activeStatus, setActiveStatus] = useState<boolean[]>(new Array(items.length).fill(false));
const onClick = useCallback(
(index: number) => {
const item = items[index];
if (item.toggle) {
item.onClick?.(!activeStatus[index]);
setActiveStatus(m => {
const n = [...m];
n.splice(index, 1, !m[index]);
return n;
});
} else {
item.onClick?.();
}
},
[items, activeStatus]
);
const getToolboxItem = useCallback(
(item: ChartToolboxItem, index: number) => (
<ToolboxItem
key={index}
active={item.toggle && !item.activeIcon && activeStatus[index]}
onClick={() => onClick(index)}
>
<Icon type={item.toggle ? (activeStatus[index] && item.activeIcon) || item.icon : item.icon} />
</ToolboxItem>
),
[activeStatus, onClick]
);
const contextValue = useMemo(() => ({tooltipPlacement}), [tooltipPlacement]);
return (
<>
<Toolbox className={className} size={items.length} reversed={reversed}>
{items.map((item, index) =>
item.tooltip ? (
<Tippy
content={
item.toggle ? (activeStatus[index] && item.activeTooltip) || item.tooltip : item.tooltip
}
placement={tooltipPlacement || 'top'}
theme="tooltip"
key={index}
>
{getToolboxItem(item, index)}
</Tippy>
) : (
getToolboxItem(item, index)
)
)}
<ChartToolboxItemContext.Provider value={contextValue}>
{items.map((item, index) => {
if ((item as MenuChartToolboxItem).menuList) {
const i = item as MenuChartToolboxItem;
return <MenuChartToolbox {...i} key={index} />;
}
if ((item as ToggleChartToolboxItem).toggle) {
const i = item as ToggleChartToolboxItem;
return <ToggleChartToolbox {...i} key={index} />;
}
const i = item as NormalChartToolboxItem;
return <NormalChartToolbox {...i} key={index} />;
})}
</ChartToolboxItemContext.Provider>
</Toolbox>
</>
);
......
......@@ -14,21 +14,16 @@
* limitations under the License.
*/
import ChartPage, {WithChart} from '~/components/ChartPage';
import type {Run as CurveRun, Tag as CurveTag, CurveType, StepInfo} from '~/resource/curves';
import React, {FunctionComponent, useCallback, useEffect, useMemo, useState} from 'react';
import type {Run, StepInfo, Tag} from '~/resource/pr-curve';
import {rem, transitionProps} from '~/utils/style';
import {AsideSection} from '~/components/Aside';
import Content from '~/components/Content';
import Error from '~/components/Error';
import Field from '~/components/Field';
import PRCurveChart from '~/components/PRCurvePage/PRCurveChart';
import RunAside from '~/components/RunAside';
import StepSlider from '~/components/PRCurvePage/StepSlider';
import StepSlider from '~/components/CurvesPage/StepSlider';
import TimeModeSelect from '~/components/TimeModeSelect';
import {TimeType} from '~/resource/pr-curve';
import Title from '~/components/Title';
import {TimeType} from '~/resource/curves';
import {cycleFetcher} from '~/utils/fetch';
import queryString from 'query-string';
import styled from 'styled-components';
......@@ -60,12 +55,29 @@ const StepSliderWrapper = styled.div`
}
`;
const PRCurve: FunctionComponent = () => {
const {t} = useTranslation(['pr-curve', 'common']);
type CurveAsideProps = {
type: CurveType;
onChangeLoading: (loading: boolean) => unknown;
onChangeSteps: (tags: CurveTag[]) => unknown;
onToggleRunning: (running: boolean) => unknown;
};
const CurveAside: FunctionComponent<CurveAsideProps> = ({type, onChangeLoading, onChangeSteps, onToggleRunning}) => {
const {t} = useTranslation('curves');
const [running, setRunning] = useState(true);
const {runs, tags, runsInTags, selectedRuns, onChangeRuns, loading} = useTagFilter('pr-curve', running);
// TODO: remove `as` after ts 4.1
const {runs, tags, runsInTags, selectedRuns, onChangeRuns, loading} = useTagFilter(
`${type}-curve` as 'pr-curve' | 'roc-curve',
running
);
const {data: stepInfo} = useRunningRequest<StepInfo[]>(
runsInTags.map(run => `/${type}-curve/steps?${queryString.stringify({run: run.label})}`),
!!running,
(...urls) => cycleFetcher(urls)
);
const [indexes, setIndexes] = useState<Record<string, number>>({});
const onChangeIndexes = useCallback(
......@@ -89,12 +101,7 @@ const PRCurve: FunctionComponent = () => {
[runsInTags]
);
const {data: stepInfo} = useRunningRequest<StepInfo[]>(
runsInTags.map(run => `/pr-curve/steps?${queryString.stringify({run: run.label})}`),
!!running,
(...urls) => cycleFetcher(urls)
);
const runWithInfo = useMemo<Run[]>(
const curveRun = useMemo<CurveRun[]>(
() =>
runsInTags.map((run, i) => ({
...run,
......@@ -108,70 +115,52 @@ const PRCurve: FunctionComponent = () => {
const [timeType, setTimeType] = useState<TimeType>(TimeType.Step);
const prCurveTags = useMemo<Tag[]>(
() =>
tags.map(tag => ({
useEffect(() => {
onChangeSteps(
tags.map<CurveTag>(tag => ({
...tag,
runs: tag.runs.map(run => ({
...run,
index: 0,
steps: [] as Run['steps'],
wallTimes: [] as Run['wallTimes'],
relatives: [] as Run['relatives'],
...runWithInfo.find(r => r.label === run.label)
steps: [],
wallTimes: [],
relatives: [],
...curveRun.find(r => r.label === run.label)
}))
})),
[tags, runWithInfo]
);
const aside = useMemo(
() =>
runs.length ? (
<RunAside
runs={runs}
selectedRuns={selectedRuns}
onChangeRuns={onChangeRuns}
running={running}
onToggleRunning={setRunning}
>
<AsideSection>
<Field label={t('pr-curve:time-display-type')}>
<TimeModeSelect value={timeType} onChange={setTimeType} />
</Field>
}))
);
}, [tags, curveRun, onChangeSteps]);
useEffect(() => {
onChangeLoading(loading);
}, [loading, onChangeLoading]);
useEffect(() => {
onToggleRunning(running);
}, [onToggleRunning, running]);
return runs.length ? (
<RunAside
runs={runs}
selectedRuns={selectedRuns}
onChangeRuns={onChangeRuns}
running={running}
onToggleRunning={setRunning}
>
<AsideSection>
<Field label={t('curves:time-display-type')}>
<TimeModeSelect value={timeType} onChange={setTimeType} />
</Field>
</AsideSection>
<StepSliderWrapper>
{curveRun.map(run => (
<AsideSection key={run.label}>
<StepSlider run={run} type={timeType} onChange={index => onChangeIndexes(run.label, index)} />
</AsideSection>
<StepSliderWrapper>
{runWithInfo.map(run => (
<AsideSection key={run.label}>
<StepSlider
run={run}
type={timeType}
onChange={index => onChangeIndexes(run.label, index)}
/>
</AsideSection>
))}
</StepSliderWrapper>
</RunAside>
) : null,
[t, onChangeRuns, running, runs, selectedRuns, timeType, runWithInfo, onChangeIndexes]
);
const withChart = useCallback<WithChart<Tag>>(
({label, runs, ...args}) => <PRCurveChart runs={runs} tag={label} {...args} running={running} />,
[running]
);
return (
<>
<Title>{t('common:pr-curve')}</Title>
<Content aside={aside} loading={loading}>
{!loading && !runs.length ? (
<Error />
) : (
<ChartPage items={prCurveTags} withChart={withChart} loading={loading} />
)}
</Content>
</>
);
))}
</StepSliderWrapper>
</RunAside>
) : null;
};
export default PRCurve;
export default CurveAside;
......@@ -14,10 +14,10 @@
* limitations under the License.
*/
import type {CurveType, PRCurveData, Run} from '~/resource/curves';
import LineChart, {LineChartRef} from '~/components/LineChart';
import type {PRCurveData, Run} from '~/resource/pr-curve';
import React, {FunctionComponent, useCallback, useMemo, useRef, useState} from 'react';
import {options as chartOptions, nearestPoint} from '~/resource/pr-curve';
import {options as chartOptions, nearestPoint} from '~/resource/curves';
import {rem, size} from '~/utils/style';
import ChartToolbox from '~/components/ChartToolbox';
......@@ -62,19 +62,20 @@ const Error = styled.div`
`;
type PRCurveChartProps = {
type: CurveType;
cid: symbol;
runs: Run[];
tag: string;
running?: boolean;
};
const PRCurveChart: FunctionComponent<PRCurveChartProps> = ({cid, runs, tag, running}) => {
const {t} = useTranslation(['pr-curve', 'common']);
const PRCurveChart: FunctionComponent<PRCurveChartProps> = ({type, cid, runs, tag, running}) => {
const {t} = useTranslation(['curves', 'common']);
const echart = useRef<LineChartRef>(null);
const {data: dataset, error, loading} = useRunningRequest<PRCurveData[]>(
runs.map(run => `/pr-curve/list?${queryString.stringify({run: run.label, tag})}`),
runs.map(run => `/${type}-curve/list?${queryString.stringify({run: run.label, tag})}`),
!!running,
(...urls) => cycleFetcher(urls)
);
......@@ -136,25 +137,25 @@ const PRCurveChart: FunctionComponent<PRCurveChartProps> = ({cid, runs, tag, run
);
const columns = [
{
label: t('pr-curve:threshold')
label: t('curves:threshold')
},
{
label: t('pr-curve:precision')
label: t('curves:precision')
},
{
label: t('pr-curve:recall')
label: t('curves:recall')
},
{
label: t('pr-curve:true-positives')
label: t('curves:true-positives')
},
{
label: t('pr-curve:false-positives')
label: t('curves:false-positives')
},
{
label: t('pr-curve:true-negatives')
label: t('curves:true-negatives')
},
{
label: t('pr-curve:false-negatives')
label: t('curves:false-negatives')
}
];
const runData = points.reduce<Run[]>((m, runPoints, index) => {
......@@ -208,19 +209,19 @@ const PRCurveChart: FunctionComponent<PRCurveChartProps> = ({cid, runs, tag, run
{
icon: 'maximize',
activeIcon: 'minimize',
tooltip: t('pr-curve:maximize'),
activeTooltip: t('pr-curve:minimize'),
tooltip: t('curves:maximize'),
activeTooltip: t('curves:minimize'),
toggle: true,
onClick: toggleMaximized
},
{
icon: 'restore-size',
tooltip: t('pr-curve:restore'),
tooltip: t('curves:restore'),
onClick: () => echart.current?.restore()
},
{
icon: 'download',
tooltip: t('pr-curve:download-image'),
tooltip: t('curves:download-image'),
onClick: () => echart.current?.saveAsImage()
}
]}
......
......@@ -19,8 +19,8 @@ import {ellipsis, size, transitionProps} from '~/utils/style';
import Field from '~/components/Field';
import RangeSlider from '~/components/RangeSlider';
import type {Run} from '~/resource/pr-curve';
import {TimeType} from '~/resource/pr-curve';
import type {Run} from '~/resource/curves';
import {TimeType} from '~/resource/curves';
import {format} from 'd3-format';
import {formatTime} from '~/utils';
import styled from 'styled-components';
......
......@@ -118,10 +118,17 @@ const HighDimensionalChart = React.forwardRef<HighDimensionalChartRef, HighDimen
useLayoutEffect(() => {
const c = chartElement.current;
if (c) {
let animationId: number | null = null;
const observer = new ResizeObserver(() => {
const rect = c.getBoundingClientRect();
setWidth(rect.width);
setHeight(rect.height);
if (animationId != null) {
cancelAnimationFrame(animationId);
animationId = null;
}
animationId = requestAnimationFrame(() => {
const rect = c.getBoundingClientRect();
setWidth(rect.width);
setHeight(rect.height);
});
});
observer.observe(c);
return () => observer.unobserve(c);
......
......@@ -83,7 +83,7 @@ const StyledAside = styled(Aside)`
}
`;
type RunAsideProps = {
export type RunAsideProps = {
runs?: Run[];
selectedRuns?: Run[];
onChangeRuns?: (runs: Run[]) => unknown;
......
......@@ -23,11 +23,11 @@ import GridLoader from 'react-spinners/GridLoader';
import type {Run} from '~/types';
import StepSlider from '~/components/SamplePage/StepSlider';
import {fetcher} from '~/utils/fetch';
import fileSaver from 'file-saver';
import {formatTime} from '~/utils';
import isEmpty from 'lodash/isEmpty';
import mime from 'mime-types';
import queryString from 'query-string';
import {saveFile} from '~/utils/saveFile';
import styled from 'styled-components';
import useRequest from '~/hooks/useRequest';
import {useRunningRequest} from '~/hooks/useRequest';
......@@ -268,13 +268,7 @@ const SampleChart: FunctionComponent<SampleChartProps> = ({
const download = useCallback(() => {
if (entityData) {
const ext = entityData.type ? mime.extension(entityData.type) : null;
fileSaver.saveAs(
entityData.data,
`${run.label}-${tag}-${steps[step]}-${wallTime.toString().replace(/\./, '_')}`.replace(
/[/\\?%*:|"<>]/g,
'_'
) + (ext ? `.${ext}` : '')
);
saveFile(entityData.data, `${run.label}-${tag}-${steps[step]}-${wallTime}` + (ext ? `.${ext}` : ''));
}
}, [entityData, run.label, tag, steps, step, wallTime]);
......
......@@ -39,6 +39,7 @@ import ee from '~/utils/event';
import {format} from 'd3-format';
import queryString from 'query-string';
import {renderToStaticMarkup} from 'react-dom/server';
import saveFile from '~/utils/saveFile';
import styled from 'styled-components';
import {useRunningRequest} from '~/hooks/useRequest';
import {useTranslation} from 'react-i18next';
......@@ -210,6 +211,49 @@ const ScalarChart: FunctionComponent<ScalarChartProps> = ({
[formatter, ranges, xAxisType, yAxisType]
);
const toolbox = useMemo(
() => [
{
icon: 'maximize',
activeIcon: 'minimize',
tooltip: t('scalar:maximize'),
activeTooltip: t('scalar:minimize'),
toggle: true,
onClick: toggleMaximized
},
{
icon: 'restore-size',
tooltip: t('scalar:restore'),
onClick: () => echart.current?.restore()
},
{
icon: 'log-axis',
tooltip: t('scalar:toggle-log-axis'),
toggle: true,
onClick: toggleYAxisType
},
{
icon: 'download',
menuList: [
{
label: t('scalar:download-image'),
onClick: () => echart.current?.saveAsImage()
},
{
label: t('scalar:download-data'),
onClick: () =>
saveFile(
runs.map(run => `/scalar/data?${queryString.stringify({run: run.label, tag})}`),
runs.map(run => `visualdl-scalar-${run.label}-${tag}.tsv`),
`visualdl-scalar-${tag}.zip`
)
}
]
}
],
[runs, t, tag, toggleMaximized, toggleYAxisType]
);
// display error only on first fetch
if (!data && error) {
return <Error>{t('common:error')}</Error>;
......@@ -218,34 +262,7 @@ const ScalarChart: FunctionComponent<ScalarChartProps> = ({
return (
<Wrapper>
<StyledLineChart ref={echart} title={tag} options={options} data={data} loading={loading} zoom />
<Toolbox
items={[
{
icon: 'maximize',
activeIcon: 'minimize',
tooltip: t('scalar:maximize'),
activeTooltip: t('scalar:minimize'),
toggle: true,
onClick: toggleMaximized
},
{
icon: 'restore-size',
tooltip: t('scalar:restore'),
onClick: () => echart.current?.restore()
},
{
icon: 'log-axis',
tooltip: t('scalar:toggle-log-axis'),
toggle: true,
onClick: toggleYAxisType
},
{
icon: 'download',
tooltip: t('scalar:download-image'),
onClick: () => echart.current?.saveAsImage()
}
]}
/>
<Toolbox items={toolbox} />
</Wrapper>
);
};
......
......@@ -21,7 +21,7 @@ import {position, primaryColor, size} from '~/utils/style';
import type {ECharts} from 'echarts';
import {dataURL2Blob} from '~/utils/image';
import fileSaver from 'file-saver';
import {saveFile} from '~/utils/saveFile';
import styled from 'styled-components';
import {themes} from '~/utils/theme';
import useTheme from '~/hooks/useTheme';
......@@ -118,8 +118,15 @@ const useECharts = <T extends HTMLElement, W extends HTMLElement = HTMLDivElemen
if (options.autoFit) {
const w = wrapper.current;
if (w) {
let animationId: number | null = null;
const observer = new ResizeObserver(() => {
echart?.resize();
if (animationId != null) {
cancelAnimationFrame(animationId);
animationId = null;
}
animationId = requestAnimationFrame(() => {
echart?.resize();
});
});
observer.observe(w);
return () => observer.unobserve(w);
......@@ -131,7 +138,7 @@ const useECharts = <T extends HTMLElement, W extends HTMLElement = HTMLDivElemen
(filename?: string) => {
if (echart) {
const blob = dataURL2Blob(echart.getDataURL({type: 'png', pixelRatio: 2, backgroundColor: '#FFF'}));
fileSaver.saveAs(blob, `${filename?.replace(/[/\\?%*:|"<>]/g, '_') || 'chart'}.png`);
saveFile(blob, `${filename || 'chart'}.png`);
}
},
[echart]
......
......@@ -28,7 +28,8 @@ export const navMap = {
audio: Pages.Audio,
graph: Pages.Graph,
embeddings: Pages.HighDimensional,
pr_curve: Pages.PRCurve
pr_curve: Pages.PRCurve,
roc_curve: Pages.ROCCurve
} as const;
const useNavItems = () => {
......
/**
* Copyright 2020 Baidu Inc. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
import ChartPage, {WithChart} from '~/components/ChartPage';
import React, {FunctionComponent, useCallback, useState} from 'react';
import Content from '~/components/Content';
import CurveAside from '~/components/CurvesPage/CurveAside';
import CurveChart from '~/components/CurvesPage/CurveChart';
import Error from '~/components/Error';
import type {Tag} from '~/resource/curves';
import Title from '~/components/Title';
import {useTranslation} from 'react-i18next';
const PRCurve: FunctionComponent = () => {
const {t} = useTranslation('common');
const [running, setRunning] = useState(true);
const [loading, setLoading] = useState(false);
const [tags, setTags] = useState<Tag[]>([]);
const withChart = useCallback<WithChart<Tag>>(
({label, runs, ...args}) => <CurveChart type="pr" runs={runs} tag={label} {...args} running={running} />,
[running]
);
return (
<>
<Title>{t('common:pr-curve')}</Title>
<Content
aside={
<CurveAside
type="pr"
onChangeLoading={setLoading}
onChangeSteps={setTags}
onToggleRunning={setRunning}
/>
}
loading={loading}
>
{!loading && !tags.length ? (
<Error />
) : (
<ChartPage items={tags} withChart={withChart} loading={loading} />
)}
</Content>
</>
);
};
export default PRCurve;
/**
* Copyright 2020 Baidu Inc. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
import ChartPage, {WithChart} from '~/components/ChartPage';
import React, {FunctionComponent, useCallback, useState} from 'react';
import Content from '~/components/Content';
import CurveAside from '~/components/CurvesPage/CurveAside';
import CurveChart from '~/components/CurvesPage/CurveChart';
import Error from '~/components/Error';
import type {Tag} from '~/resource/curves';
import Title from '~/components/Title';
import {useTranslation} from 'react-i18next';
const ROCCurve: FunctionComponent = () => {
const {t} = useTranslation('common');
const [running, setRunning] = useState(true);
const [loading, setLoading] = useState(false);
const [tags, setTags] = useState<Tag[]>([]);
const withChart = useCallback<WithChart<Tag>>(
({label, runs, ...args}) => <CurveChart type="roc" runs={runs} tag={label} {...args} running={running} />,
[running]
);
return (
<>
<Title>{t('common:roc-curve')}</Title>
<Content
aside={
<CurveAside
type="roc"
onChangeLoading={setLoading}
onChangeSteps={setTags}
onToggleRunning={setRunning}
/>
}
loading={loading}
>
{!loading && !tags.length ? (
<Error />
) : (
<ChartPage items={tags} withChart={withChart} loading={loading} />
)}
</Content>
</>
);
};
export default ROCCurve;
......@@ -32,7 +32,7 @@ const chartSize = {
height: rem(244)
};
const Audio: FunctionComponent = () => {
const AudioSample: FunctionComponent = () => {
const {t} = useTranslation(['sample', 'common']);
const audioContext = useRef<AudioContext>();
......@@ -90,4 +90,4 @@ const Audio: FunctionComponent = () => {
);
};
export default Audio;
export default AudioSample;
......@@ -36,7 +36,7 @@ const chartSize = {
height: rem(406)
};
const Image: FunctionComponent = () => {
const ImageSample: FunctionComponent = () => {
const {t} = useTranslation(['sample', 'common']);
const [running, setRunning] = useState(true);
......@@ -112,4 +112,4 @@ const Image: FunctionComponent = () => {
);
};
export default Image;
export default ImageSample;
......@@ -14,7 +14,7 @@
* limitations under the License.
*/
export type {PRCurveData, Run, StepInfo, Tag} from './types';
export type {CurveType, PRCurveData, Run, StepInfo, Tag} from './types';
export {TimeType} from './types';
export * from './chart';
export * from './data';
......@@ -18,6 +18,8 @@ import {Run as BaseRun, Tag as BaseTag, TimeMode} from '~/types';
export {TimeMode as TimeType};
export type CurveType = 'pr' | 'roc';
type Step = number;
type WallTime = number;
type Relative = number;
......
......@@ -23,7 +23,8 @@ export enum Pages {
Audio = 'audio',
Graph = 'graph',
HighDimensional = 'high-dimensional',
PRCurve = 'pr-curve'
PRCurve = 'pr-curve',
ROCCurve = 'roc-curve'
}
export interface Route {
......@@ -81,7 +82,12 @@ const routes: Route[] = [
{
id: Pages.PRCurve,
path: '/pr-curve',
component: React.lazy(() => import('~/pages/pr-curve'))
component: React.lazy(() => import('~/pages/curves/pr'))
},
{
id: Pages.ROCCurve,
path: '/roc-curve',
component: React.lazy(() => import('~/pages/curves/roc'))
}
];
......
......@@ -23,7 +23,8 @@ const initState: RunsState = {
histogram: [],
image: [],
audio: [],
'pr-curve': []
'pr-curve': [],
'roc-curve': []
};
function runsReducer(state = initState, action: RunsActionTypes): RunsState {
......
......@@ -26,6 +26,7 @@ export interface RunsState {
image: Runs;
audio: Runs;
'pr-curve': Runs;
'roc-curve': Runs;
}
export type Page = keyof RunsState;
......
/**
* Copyright 2020 Baidu Inc. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
import FileSaver from 'file-saver';
import JSZip from 'jszip';
import {fetcher} from '~/utils/fetch';
import isString from 'lodash/isString';
import {toast} from 'react-toastify';
async function getFile(url: string): Promise<string | Blob> {
const result = await fetcher(url);
if (result.data instanceof Blob) {
return result.data;
} else if (!isString(result)) {
return JSON.stringify(result);
}
return result;
}
function normalizeFilename(name: string) {
return name.replace(/[/\\?%*:|"<>]/g, '_');
}
export function saveFile(file: string | Blob, filename: string) {
let blob: Blob;
if (file instanceof Blob) {
blob = file;
} else {
blob = new Blob([file], {type: 'text/plain;charset=utf-8'});
}
FileSaver.saveAs(blob, normalizeFilename(filename));
}
export default async function (url: string | string[], filename: string | string[], zipFilename = 'download.zip') {
if (!url) {
return;
}
let urls: string[] = url as string[];
let filenames: string[] = filename as string[];
if (isString(url)) {
urls = [url];
}
if (isString(filename)) {
filenames = [filename];
}
try {
const data = await Promise.all(urls.map(getFile));
if (data.length === 1) {
saveFile(data[0], filenames[0]);
} else {
const zip = new JSZip();
let basename = '';
let extname = '';
if (filenames.length === 1) {
const filenameArr = filenames[0].split('.');
if (filenameArr.length > 1) {
extname = filenameArr.pop() as string;
}
basename = filenameArr.join('.');
}
data.forEach((file, index) => {
zip.file(normalizeFilename(basename ? `${basename} - ${index}.${extname}` : filenames[index]), file);
});
const zipFile = await zip.generateAsync({type: 'blob'});
saveFile(zipFile, zipFilename);
}
} catch (e) {
toast(e.message, {
position: toast.POSITION.TOP_CENTER,
type: toast.TYPE.ERROR
});
}
}
......@@ -149,7 +149,7 @@ export const GlobalStyle = createGlobalStyle`
html,
body {
height: 100%;
min-height: 100%;
background-color: var(--body-background-color);
color: var(--text-color);
${transitionProps(['background-color', 'color'])}
......
......@@ -27,7 +27,7 @@ import {spawn} from 'child_process';
const host = '127.0.0.1';
const publicPath = '/visualdl';
const pages = ['common', 'scalar', 'histogram', 'image', 'audio', 'graph', 'pr-curve', 'high-dimensional'];
const pages = ['common', 'scalar', 'histogram', 'image', 'audio', 'graph', 'pr-curve', 'roc-curve', 'high-dimensional'];
const dataDir = path.resolve(__dirname, '../data');
async function start() {
......
/**
* Copyright 2020 Baidu Inc. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
import type {Data, Worker} from './types';
const worker: Worker = async io => {
const components = await io.getData<string[]>('/components');
if (!components.includes('roc_curve')) {
return;
}
const {runs, tags} = await io.save<Data>('/roc-curve/tags');
for (const [index, run] of runs.entries()) {
await io.save('/roc-curve/steps', {run});
for (const tag of tags[index]) {
await io.save('/roc-curve/list', {run, tag});
}
}
};
export default worker;
......@@ -14,4 +14,4 @@
* limitations under the License.
*/
export default ['embeddings', 'scalar', 'image', 'audio', 'graph', 'histogram', 'pr_curve'];
export default ['embeddings', 'scalar', 'image', 'audio', 'graph', 'histogram', 'pr_curve', 'roc_curve'];
此差异已折叠。
/**
* Copyright 2020 Baidu Inc. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
import {Request} from 'express';
export default (request: Request) => {
if (request.query.run === 'train') {
return [
[1593069993786.464, 0],
[1593069993787.353, 1],
[1593069993788.1448, 2],
[1593069993788.836, 3],
[1593069993789.4, 4],
[1593069993790.076, 5],
[1593069993790.763, 6],
[1593069993791.473, 7],
[1593069993792.149, 8],
[1593069993792.763, 9]
];
}
return [
[1593069993538.6739, 0],
[1593069993539.396, 1],
[1593069993540.066, 2],
[1593069993540.662, 3],
[1593069993541.333, 4],
[1593069993542.078, 5],
[1593069993543.1821, 6],
[1593069993543.998, 7],
[1593069993544.9128, 8],
[1593069993545.62, 9]
];
};
/**
* Copyright 2020 Baidu Inc. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
export default {
runs: ['train', 'test'],
tags: [
['layer2/biases/summaries/mean', 'test/1234', 'another'],
[
'layer2/biases/summaries/mean',
'layer2/biases/summaries/accuracy',
'layer2/biases/summaries/cost',
'test/431',
'others'
]
]
};
/**
* Copyright 2020 Baidu Inc. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
import {Request, Response} from 'express';
export default (req: Request, res: Response) => {
const {run, tag} = req.query;
res.setHeader('Content-Type', 'text/tab-separated-values');
return `scalar\n${run}\n${tag}`;
};
......@@ -8231,6 +8231,11 @@ ignore@^5.1.4:
resolved "https://registry.yarnpkg.com/ignore/-/ignore-5.1.8.tgz#f150a8b50a34289b33e22f5889abd4d8016f0e57"
integrity sha512-BMpfD7PpiETpBl/A6S498BaIJ6Y/ABT93ETbby2fP00v4EbvPBXWEoaR1UBPKs3iR53pJY7EtZk5KACI57i1Uw==
immediate@~3.0.5:
version "3.0.6"
resolved "https://registry.yarnpkg.com/immediate/-/immediate-3.0.6.tgz#9db1dbd0faf8de6fbe0f5dd5e56bb606280de69b"
integrity sha1-nbHb0Pr43m++D13V5Wu2BigN5ps=
import-fresh@^2.0.0:
version "2.0.0"
resolved "https://registry.yarnpkg.com/import-fresh/-/import-fresh-2.0.0.tgz#d81355c15612d386c61f9ddd3922d4304822a546"
......@@ -9337,6 +9342,16 @@ jsprim@^1.2.2:
array-includes "^3.1.1"
object.assign "^4.1.1"
jszip@3.5.0:
version "3.5.0"
resolved "https://registry.yarnpkg.com/jszip/-/jszip-3.5.0.tgz#b4fd1f368245346658e781fec9675802489e15f6"
integrity sha512-WRtu7TPCmYePR1nazfrtuF216cIVon/3GWOvHS9QR5bIwSbnxtdpma6un3jyGGNhHsKCSzn5Ypk+EkDRvTGiFA==
dependencies:
lie "~3.3.0"
pako "~1.0.2"
readable-stream "~2.3.6"
set-immediate-shim "~1.0.1"
junk@^3.1.0:
version "3.1.0"
resolved "https://registry.yarnpkg.com/junk/-/junk-3.1.0.tgz#31499098d902b7e98c5d9b9c80f43457a88abfa1"
......@@ -9464,6 +9479,13 @@ levn@~0.3.0:
prelude-ls "~1.1.2"
type-check "~0.3.2"
lie@~3.3.0:
version "3.3.0"
resolved "https://registry.yarnpkg.com/lie/-/lie-3.3.0.tgz#dcf82dee545f46074daf200c7c1c5a08e0f40f6a"
integrity sha512-UaiMJzeWRlEujzAuw5LokY1L5ecNQYZKfmyZ9L7wDHb/p5etKaxXhohBcrw0EYby+G/NA52vRSN4N39dxHAIwQ==
dependencies:
immediate "~3.0.5"
lines-and-columns@^1.1.6:
version "1.1.6"
resolved "https://registry.yarnpkg.com/lines-and-columns/-/lines-and-columns-1.1.6.tgz#1c00c743b433cd0a4e80758f7b64a57440d9ff00"
......@@ -11018,7 +11040,7 @@ package-json@^6.3.0:
registry-url "^5.0.0"
semver "^6.2.0"
pako@1.0.11:
pako@1.0.11, pako@~1.0.2:
version "1.0.11"
resolved "https://registry.yarnpkg.com/pako/-/pako-1.0.11.tgz#6c9599d340d54dfd3946380252a35705a6b992bf"
integrity sha512-4hLB8Py4zZce5s4yd9XzopqwVv/yGNhV1Bl8NTmCq1763HeK2+EwVTv+leGeL13Dnh2wfbqowVPXCIO0z4taYw==
......@@ -12683,6 +12705,11 @@ set-blocking@^2.0.0, set-blocking@~2.0.0:
resolved "https://registry.yarnpkg.com/set-blocking/-/set-blocking-2.0.0.tgz#045f9782d011ae9a6803ddd382b24392b3d890f7"
integrity sha1-BF+XgtARrppoA93TgrJDkrPYkPc=
set-immediate-shim@~1.0.1:
version "1.0.1"
resolved "https://registry.yarnpkg.com/set-immediate-shim/-/set-immediate-shim-1.0.1.tgz#4b2b1b27eb808a9f8dcc481a58e5e56f599f3f61"
integrity sha1-SysbJ+uAip+NzEgaWOXlb1mfP2E=
set-value@^2.0.0, set-value@^2.0.1:
version "2.0.1"
resolved "https://registry.yarnpkg.com/set-value/-/set-value-2.0.1.tgz#a18d40530e6f07de4228c7defe4227af8cad005b"
......
# **抽奖规则**
### **方式一:点Star抽奖(20名)**
**参与方式**
[VisualDL](https://github.com/PaddlePaddle/VisualDL)或者 [PaddleX](https://github.com/PaddlePaddle/PaddleX)任意Repo 点星(Star)参与抽奖(也可双重点击累加,提高中奖几率)
**奖项设置**
VDL Github抽取9名,PaddleX Github抽取9名(以下奖项均分)
**一等奖(2名)** :蓝牙键盘、飞桨充电宝
**二等奖(2名)** :价值50元的京东购物卡
**三等奖(4名)** :百度网盘超级会员
**鼓励奖(10名)** :飞桨鸭舌帽/飞桨精美帆布包 任选
**领奖方式**
根据开奖公告,获奖同学凭github账号截图到**VisualDL qq群(1045783368)@Yixin|Visualdl **小姐姐领奖
### **方式二:在AI Studio上创建项目(3名)**
**参与方式**
在AI Studio上应用PaddleX 以及 VisualDL创建「计数」项目,并将项目链接回复至AI Studio的[VDL论坛评论区](https://ai.baidu.com/forum/topic/show/960053):飞桨团队会对项目质量打分并选择得分最高的前三名送出奖品。
**奖项设置**:蓝牙键盘*3
**领奖方式** :依依小姐姐将单独联系获奖的项目同学送出奖品
### **参与时间**:
**9月21日--9月29日(下周二)**
### **开奖方式**:
**获奖名单将在9月30日公布于本页「0921直播抽奖活动」专栏**
![Star示意](./imgs/star.png)
![Star后ID呈现](./imgs/ID.jpg)
# 获奖名单
## 一等奖
### VisualDL Repo
恭喜用户**1084667371**获得**蓝牙键盘**一个!!
### PaddleX Repo
恭喜用户**Calvert97**获得**飞桨充电宝**一个!!
## 二等奖
### VisualDL Repo
恭喜**0-yy-0**获得价值50元的**京东购物卡**!!
### PaddleX Repo
恭喜**AnkerLeng**获得价值50元的**京东购物卡**!!
## 三等奖
### VisualDL Repo
恭喜**JetHong、gylidian**获得**百度网盘超级会员**!!
### PaddleX Repo
恭喜**loyasto、liu824**获得**百度网盘超级会员**!!
## 阳光普照奖
### VisualDL Repo
恭喜**Sqhttwl、ckalong、dongtianqi1125、ralph0813、wjiawei97**获得**飞桨鸭舌帽 OR 飞桨帆布袋**!!
### PaddleX Repo
恭喜**wangluohaima、Yaoxingtian、rango42z、qiceng、mumucai**获得**飞桨鸭舌帽 OR 飞桨帆布袋**!!
*注意:阳光普照奖获奖小伙伴们先领奖可以先选择心仪的奖品(飞桨鸭舌帽 OR 飞桨帆布袋)噢~~
## 领奖方式
**再次恭喜以上获奖小伙伴们,请加入VisualDL官方QQ群1045783368,在群里@依依小姐姐(Yixin|VisualDL)发送Github账号截图领奖~~**
<p align="center">
<img src="https://user-images.githubusercontent.com/48054808/82522691-c2758680-9b5c-11ea-9aee-fca994aba175.png" width="20%"/>
</p>
......@@ -34,6 +34,9 @@ components = {
"pr_curve": {
"enabled": False
},
"roc_curve": {
"enabled": False
},
"meta_data": {
"enabled": False
}
......
......@@ -404,3 +404,125 @@ def pr_curve_raw(tag, tp, fp, tn, fn, precision, recall, step, walltime):
Record.Value(
id=step, tag=tag, timestamp=walltime, pr_curve=prcurve)
])
def compute_roc_curve(labels, predictions, num_thresholds=None, weights=None):
""" Compute ROC curve data by labels and predictions.
Args:
labels (numpy.ndarray or list): Binary labels for each element.
predictions (numpy.ndarray or list): The probability that an element be
classified as true.
num_thresholds (int): Number of thresholds used to draw the curve.
weights (float): Multiple of data to display on the curve.
"""
if isinstance(labels, list):
labels = np.array(labels)
if isinstance(predictions, list):
predictions = np.array(predictions)
_MINIMUM_COUNT = 1e-7
if weights is None:
weights = 1.0
bucket_indices = np.int32(np.floor(predictions * (num_thresholds - 1)))
float_labels = labels.astype(np.float)
histogram_range = (0, num_thresholds - 1)
tp_buckets, _ = np.histogram(
bucket_indices,
bins=num_thresholds,
range=histogram_range,
weights=float_labels * weights)
fp_buckets, _ = np.histogram(
bucket_indices,
bins=num_thresholds,
range=histogram_range,
weights=(1.0 - float_labels) * weights)
# Obtain the reverse cumulative sum.
tp = np.cumsum(tp_buckets[::-1])[::-1]
fp = np.cumsum(fp_buckets[::-1])[::-1]
tn = fp[0] - fp
fn = tp[0] - tp
tpr = tp / np.maximum(_MINIMUM_COUNT, tn + fp)
fpr = fp / np.maximum(_MINIMUM_COUNT, tn + fp)
data = {
'tp': tp.astype(int).tolist(),
'fp': fp.astype(int).tolist(),
'tn': tn.astype(int).tolist(),
'fn': fn.astype(int).tolist(),
'tpr': tpr.astype(float).tolist(),
'fpr': fpr.astype(float).tolist()
}
return data
def roc_curve(tag, labels, predictions, step, walltime, num_thresholds=127,
weights=None):
"""Package data to one roc_curve.
Args:
tag (string): Data identifier
labels (numpy.ndarray or list): Binary labels for each element.
predictions (numpy.ndarray or list): The probability that an element be
classified as true.
step (int): Step of pr_curve
walltime (int): Wall time of pr_curve
num_thresholds (int): Number of thresholds used to draw the curve.
weights (float): Multiple of data to display on the curve.
Return:
Package with format of record_pb2.Record
"""
num_thresholds = min(num_thresholds, 127)
roc_curve_map = compute_roc_curve(labels, predictions, num_thresholds, weights)
return roc_curve_raw(tag=tag,
tp=roc_curve_map['tp'],
fp=roc_curve_map['fp'],
tn=roc_curve_map['tn'],
fn=roc_curve_map['fn'],
tpr=roc_curve_map['tpr'],
fpr=roc_curve_map['fpr'],
step=step,
walltime=walltime)
def roc_curve_raw(tag, tp, fp, tn, fn, tpr, fpr, step, walltime):
"""Package raw data to one roc_curve.
Args:
tag (string): Data identifier
tp (list): True Positive.
fp (list): False Positive.
tn (list): True Negative.
fn (list): False Negative.
tpr (list): true positive rate:
fpr (list): false positive rate.
step (int): Step of roc_curve
walltime (int): Wall time of roc_curve
num_thresholds (int): Number of thresholds used to draw the curve.
weights (float): Multiple of data to display on the curve.
Return:
Package with format of record_pb2.Record
"""
"""
if isinstance(tp, np.ndarray):
tp = tp.astype(int).tolist()
if isinstance(fp, np.ndarray):
fp = fp.astype(int).tolist()
if isinstance(tn, np.ndarray):
tn = tn.astype(int).tolist()
if isinstance(fn, np.ndarray):
fn = fn.astype(int).tolist()
if isinstance(tpr, np.ndarray):
tpr = tpr.astype(int).tolist()
if isinstance(fpr, np.ndarray):
fpr = fpr.astype(int).tolist()
"""
roc_curve = Record.ROC_Curve(TP=tp,
FP=fp,
TN=tn,
FN=fn,
tpr=tpr,
fpr=fpr)
return Record(values=[
Record.Value(
id=step, tag=tag, timestamp=walltime, roc_curve=roc_curve)
])
......@@ -43,6 +43,15 @@ message Record {
repeated double precision = 5;
repeated double recall = 6;
}
message ROC_Curve {
repeated int64 TP = 1 [packed = true];
repeated int64 FP = 2 [packed = true];
repeated int64 TN = 3 [packed = true];
repeated int64 FN = 4 [packed = true];
repeated double tpr = 5;
repeated double fpr = 6;
}
message MetaData {
string display_name = 1;
......@@ -60,6 +69,7 @@ message Record {
Histogram histogram = 8;
PRCurve pr_curve = 9;
MetaData meta_data = 10;
ROC_Curve roc_curve = 11;
}
}
......
此差异已折叠。
......@@ -171,6 +171,8 @@ class LogReader(object):
component = "histogram"
elif "pr_curve" == value_type:
component = "pr_curve"
elif "roc_curve" == value_type:
component = "roc_curve"
elif "meta_data" == value_type:
self.update_meta_data(record)
component = "meta_data"
......
......@@ -112,6 +112,9 @@ class Api(object):
@result()
def pr_curve_tags(self):
return self._get_with_retry('data/plugin/pr_curves/tags', lib.get_pr_curve_tags)
@result()
def roc_curve_tags(self):
return self._get_with_retry('data/plugin/roc_curves/tags', lib.get_roc_curve_tags)
@result()
def scalar_list(self, run, tag):
......@@ -178,11 +181,18 @@ class Api(object):
def pr_curves_pr_curve(self, run, tag):
key = os.path.join('data/plugin/pr_curves/pr_curve', run, tag)
return self._get_with_retry(key, lib.get_pr_curve, run, tag)
@result()
def roc_curves_roc_curve(self, run, tag):
key = os.path.join('data/plugin/roc_curves/roc_curve', run, tag)
return self._get_with_retry(key, lib.get_roc_curve, run, tag)
@result()
def pr_curves_steps(self, run):
key = os.path.join('data/plugin/pr_curves/steps', run)
return self._get_with_retry(key, lib.get_pr_curve_step, run)
@result()
def roc_curves_steps(self, run):
key = os.path.join('data/plugin/roc_curves/steps', run)
return self._get_with_retry(key, lib.get_roc_curve_step, run)
@result('application/octet-stream', lambda s: {"Content-Disposition": 'attachment; filename="%s"' % s.model_name} if len(s.model_name) else None)
def graph_graph(self):
......@@ -203,6 +213,7 @@ def create_api_call(logdir, model, cache_timeout):
'embedding/tags': (api.embedding_tags, []),
'histogram/tags': (api.histogram_tags, []),
'pr-curve/tags': (api.pr_curve_tags, []),
'roc-curve/tags': (api.roc_curve_tags, []),
'scalar/list': (api.scalar_list, ['run', 'tag']),
'scalar/data': (api.scalar_data, ['run', 'tag']),
'image/list': (api.image_list, ['run', 'tag']),
......@@ -216,7 +227,9 @@ def create_api_call(logdir, model, cache_timeout):
'histogram/list': (api.histogram_list, ['run', 'tag']),
'graph/graph': (api.graph_graph, []),
'pr-curve/list': (api.pr_curves_pr_curve, ['run', 'tag']),
'pr-curve/steps': (api.pr_curves_steps, ['run'])
'roc-curve/list': (api.roc_curves_roc_curve, ['run', 'tag']),
'pr-curve/steps': (api.pr_curves_steps, ['run']),
'roc-curve/steps': (api.roc_curves_steps, ['run'])
}
def call(path: str, args):
......
......@@ -24,6 +24,7 @@ DEFAULT_PLUGIN_MAXSIZE = {
"embeddings": 50000000,
"audio": 10,
"pr_curve": 300,
"roc_curve": 300,
"meta_data": 100
}
......@@ -346,6 +347,8 @@ class DataManager(object):
Reservoir(max_size=DEFAULT_PLUGIN_MAXSIZE["audio"]),
"pr_curve":
Reservoir(max_size=DEFAULT_PLUGIN_MAXSIZE["pr_curve"]),
"roc_curve":
Reservoir(max_size=DEFAULT_PLUGIN_MAXSIZE["roc_curve"]),
"meta_data":
Reservoir(max_size=DEFAULT_PLUGIN_MAXSIZE["meta_data"])
}
......
......@@ -198,6 +198,28 @@ def get_pr_curve(log_reader, run, tag):
list(pr_curve.FN),
num_thresholds])
return results
def get_roc_curve(log_reader, run, tag):
run = log_reader.name2tags[run] if run in log_reader.name2tags else run
log_reader.load_new_data()
records = log_reader.data_manager.get_reservoir("roc_curve").get_items(
run, decode_tag(tag))
results = []
for item in records:
roc_curve = item.roc_curve
length = len(roc_curve.tpr)
num_thresholds = [float(v) / length for v in range(1, length + 1)]
results.append([s2ms(item.timestamp),
item.id,
list(roc_curve.tpr),
list(roc_curve.fpr),
list(roc_curve.TP),
list(roc_curve.FP),
list(roc_curve.TN),
list(roc_curve.FN),
num_thresholds])
return results
def get_pr_curve_step(log_reader, run, tag=None):
......@@ -212,6 +234,18 @@ def get_pr_curve_step(log_reader, run, tag=None):
return results
def get_roc_curve_step(log_reader, run, tag=None):
fake_run = run
run = log_reader.name2tags[run] if run in log_reader.name2tags else run
run2tag = get_roc_curve_tags(log_reader)
tag = run2tag['tags'][run2tag['runs'].index(fake_run)][0]
log_reader.load_new_data()
records = log_reader.data_manager.get_reservoir("roc_curve").get_items(
run, decode_tag(tag))
results = [[s2ms(item.timestamp), item.id] for item in records]
return results
def get_embeddings_list(log_reader):
run2tag = get_logs(log_reader, 'embeddings')
......@@ -266,30 +300,6 @@ def get_embedding_tensors(log_reader, name):
return vectors
def get_embeddings(log_reader, run, tag, reduction, dimension=2):
run = log_reader.name2tags[run] if run in log_reader.name2tags else run
log_reader.load_new_data()
records = log_reader.data_manager.get_reservoir("embeddings").get_items(
run, decode_tag(tag))
labels = []
vectors = []
for item in records[0].embeddings.embeddings:
labels.append(item.label)
vectors.append(item.vectors)
vectors = np.array(vectors)
if reduction == 'tsne':
import visualdl.server.tsne as tsne
low_dim_embs = tsne.tsne(
vectors, dimension, initial_dims=50, perplexity=30.0)
elif reduction == 'pca':
low_dim_embs = simple_pca(vectors, dimension)
return {"embedding": low_dim_embs.tolist(), "labels": labels}
def get_histogram(log_reader, run, tag):
run = log_reader.name2tags[run] if run in log_reader.name2tags else run
log_reader.load_new_data()
......
......@@ -16,7 +16,7 @@ import os
import time
import numpy as np
from visualdl.writer.record_writer import RecordFileWriter
from visualdl.component.base_component import scalar, image, embedding, audio, histogram, pr_curve, meta_data
from visualdl.component.base_component import scalar, image, embedding, audio, histogram, pr_curve, roc_curve, meta_data
class DummyFileWriter(object):
......@@ -380,6 +380,48 @@ class LogWriter(object):
num_thresholds=num_thresholds,
weights=weights
))
def add_roc_curve(self,
tag,
labels,
predictions,
step,
num_thresholds=10,
weights=None,
walltime=None):
"""Add an ROC curve to vdl record file.
Args:
tag (string): Data identifier
labels (numpy.ndarray or list): Binary labels for each element.
predictions (numpy.ndarray or list): The probability that an element
be classified as true.
step (int): Step of pr curve.
weights (float): Multiple of data to display on the curve.
num_thresholds (int): Number of thresholds used to draw the curve.
walltime (int): Wall time of pr curve.
Example:
with LogWriter(logdir="./log/roc_curve_test/train") as writer:
for index in range(3):
labels = np.random.randint(2, size=100)
predictions = np.random.rand(100)
writer.add_roc_curve(tag='default',
labels=labels,
predictions=predictions,
step=index)
"""
if '%' in tag:
raise RuntimeError("% can't appear in tag!")
walltime = round(time.time() * 1000) if walltime is None else walltime
self._get_file_writer().add_record(
roc_curve(
tag=tag,
labels=labels,
predictions=predictions,
step=step,
walltime=walltime,
num_thresholds=num_thresholds,
weights=weights
))
def flush(self):
"""Flush all data in cache to disk.
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册