diff --git a/frontend/components/LineChart.tsx b/frontend/components/LineChart.tsx index 0540c0235c24af79f7b2679d03a47704b8d87431..3f4436dc0c4cffedaab7b341162a76b7dbcc8fd7 100644 --- a/frontend/components/LineChart.tsx +++ b/frontend/components/LineChart.tsx @@ -29,7 +29,9 @@ const LineChart: FunctionComponent = ({ loading, className }) => { - const [ref, echart] = useECharts(!!loading); + const [ref, echart] = useECharts({ + loading: !!loading + }); const xAxisFormatter = useCallback( (value: number) => (type === 'time' ? new Date(value).toLocaleTimeString() : value), diff --git a/frontend/components/ScatterChart.tsx b/frontend/components/ScatterChart.tsx index 94f802a5271cb4d78f0904a39c7ee01d1a84af92..3cf4c0603a2756256ba3bca95f89748e73834985 100644 --- a/frontend/components/ScatterChart.tsx +++ b/frontend/components/ScatterChart.tsx @@ -1,77 +1,144 @@ -import React, {FunctionComponent, useEffect} from 'react'; +import React, {FunctionComponent, useEffect, useMemo} from 'react'; import {EChartOption} from 'echarts'; import {WithStyled, primaryColor} from '~/utils/style'; import useECharts from '~/hooks/useECharts'; import {Dimension} from '~/types'; +const SYMBOL_SIZE = 12; + +type Point = { + name: string; + value: [number, number] | [number, number, number]; +}; + type ScatterChartProps = { - data?: ([number, number] | [number, number, number])[]; - labels?: string[]; - loading?: boolean; - dimension?: Dimension; + keyword: string; + loading: boolean; + points: Point[]; + dimension: Dimension; +}; + +const assemble2D = (points: {highlighted: Point[]; others: Point[]}, label: EChartOption.SeriesBar['label']) => { + // eslint-disable-next-line + const createSeries = (name: string, data: Point[], patch?: {[k: string]: any}) => ({ + name, + symbolSize: SYMBOL_SIZE, + data, + type: 'scatter', + label, + ...(patch || {}) + }); + + return { + xAxis: {}, + yAxis: {}, + toolbox: { + show: true, + showTitle: false, + itemSize: 0, + + feature: { + dataZoom: {}, + restore: {}, + saveAsImage: {} + } + }, + series: [ + createSeries('highlighted', points.highlighted), + createSeries('others', points.others, {color: primaryColor}) + ] + }; +}; + +const assemble3D = (points: {highlighted: Point[]; others: Point[]}, label: EChartOption.SeriesBar['label']) => { + // eslint-disable-next-line + const createSeries = (name: string, data: Point[], patch?: {[k: string]: any}) => ({ + name, + symbolSize: SYMBOL_SIZE, + data, + type: 'scatter3D', + label, + ...(patch || {}) + }); + + return { + grid3D: {}, + xAxis3D: {}, + yAxis3D: {}, + zAxis3D: {}, + series: [ + createSeries('highlighted', points.highlighted), + createSeries('others', points.others, {color: primaryColor}) + ] + }; +}; + +const getChartOptions = ( + settings: Pick & {points: {highlighted: Point[]; others: Point[]}} +) => { + const {dimension, points} = settings; + const label = { + show: true, + position: 'top', + formatter: (params: {data: {name: string; showing: boolean}}) => (params.data.showing ? params.data.name : '') + }; + + const assemble = dimension === '2d' ? assemble2D : assemble3D; + return assemble(points, label); +}; + +const dividePoints = (points: Point[], keyword: string): [Point[], Point[]] => { + if (!keyword) { + return [[], points]; + } + + const matched: Point[] = []; + const missing: Point[] = []; + points.forEach(point => { + if (point.name.includes(keyword)) { + matched.push(point); + return; + } + missing.push(point); + }); + + return [matched, missing]; }; const ScatterChart: FunctionComponent = ({ - data, - labels, + points, + keyword, loading, dimension, className }) => { - const [ref, echart] = useECharts(!!loading); + const [ref, echart] = useECharts({ + loading, + gl: true + }); + const [highlighted, others] = useMemo(() => dividePoints(points, keyword), [points, keyword]); + const chartOptions = useMemo( + () => + getChartOptions({ + dimension, + points: { + highlighted, + others + } + }), + [dimension, highlighted, others] + ); useEffect(() => { - if (process.browser) { - (async () => { - const is3D = dimension === '3d'; - if (is3D) { - await import('echarts-gl'); - } - echart.current?.setOption( - { - ...(is3D - ? { - yAxis3D: {}, - xAxis3D: {}, - zAxis3D: {}, - grid3D: {} - } - : { - xAxis: {}, - yAxis: {} - }), - series: [ - { - data, - label: { - show: true, - position: 'top', - formatter: ( - params: EChartOption.Tooltip.Format | EChartOption.Tooltip.Format[] - ) => { - if (!labels) { - return ''; - } - const {dataIndex: index} = Array.isArray(params) ? params[0] : params; - if (index == null) { - return ''; - } - return labels[index] ?? ''; - } - }, - symbolSize: 12, - itemStyle: { - color: primaryColor - }, - type: is3D ? 'scatter3D' : 'scatter' - } - ] - }, - {notMerge: true} - ); - })(); + if (!process.browser) { + return; } - }, [data, labels, dimension, echart]); + + echart.current?.setOption( + chartOptions, + true // not merged + ); + }, [chartOptions, echart]); return
; }; diff --git a/frontend/hooks/useECharts.ts b/frontend/hooks/useECharts.ts index 62ff29b8c30aa2ae9b73394ea17b1a5f99aee4b0..b68b3080d1a3fe4e2bbdd88c34766280a8d9bad3 100644 --- a/frontend/hooks/useECharts.ts +++ b/frontend/hooks/useECharts.ts @@ -2,16 +2,20 @@ import {useRef, useEffect, useCallback, MutableRefObject} from 'react'; import echarts, {ECharts} from 'echarts'; import {useTranslation} from 'react-i18next'; -const useECharts = ( - loading: boolean -): [MutableRefObject, MutableRefObject] => { +const useECharts = (options: { + loading: boolean; + gl?: boolean; +}): [MutableRefObject, MutableRefObject] => { const {t} = useTranslation('common'); const ref = useRef(null); const echart = useRef(null as ECharts | null); const createChart = useCallback(() => { - echart.current = echarts.init((ref.current as unknown) as HTMLDivElement); - }, []); + const loadExtension = options.gl ? import('echarts-gl') : Promise.resolve(); + loadExtension.then(() => { + echart.current = echarts.init((ref.current as unknown) as HTMLDivElement); + }); + }, [options.gl]); const destroyChart = useCallback(() => { echart.current?.dispose(); @@ -26,7 +30,7 @@ const useECharts = ( useEffect(() => { if (process.browser) { - if (loading) { + if (options.loading) { echart.current?.showLoading('default', { text: t('loading'), color: '#c23531', @@ -38,7 +42,7 @@ const useECharts = ( echart.current?.hideLoading(); } } - }, [t, loading]); + }, [t, options.loading]); return [ref, echart]; }; diff --git a/frontend/pages/high-dimensional.tsx b/frontend/pages/high-dimensional.tsx index 7179a5be15b9cd846e96b1bcd01e20efdda572b9..3c0b5f10003f058c2413f6c1fa530bbf91344317 100644 --- a/frontend/pages/high-dimensional.tsx +++ b/frontend/pages/high-dimensional.tsx @@ -1,4 +1,4 @@ -import React, {useState, useEffect} from 'react'; +import React, {useState, useEffect, useMemo} from 'react'; import styled from 'styled-components'; import useSWR from 'swr'; import {NextPage} from 'next'; @@ -58,6 +58,7 @@ const HighDimensional: NextPage = () => { const [dimension, setDimension] = useState(dimensions[0] as Dimension); const [reduction, setReduction] = useState(reductions[0]); const [running, setRunning] = useState(true); + const [labelVisibility, setLabelVisibility] = useState(true); const {data, error} = useSWR( `/embeddings/embeddings?run=${encodeURIComponent(run ?? '')}&dimension=${Number.parseInt( @@ -67,6 +68,21 @@ const HighDimensional: NextPage = () => { refreshInterval: running ? 15 * 1000 : 0 } ); + const points = useMemo(() => { + if (!data) { + return []; + } + + const {embedding, labels} = data; + return embedding.map((value, i) => { + const name = labels[i] || ''; + return { + name, + showing: labelVisibility, + value + }; + }); + }, [data, labelVisibility]); const aside = (
@@ -81,7 +97,9 @@ const HighDimensional: NextPage = () => { - {t('display-all-label')} + + {t('display-all-label')} + @@ -119,12 +137,7 @@ const HighDimensional: NextPage = () => { <> {t('common:high-dimensional')} - + );