diff --git a/frontend/packages/core/package.json b/frontend/packages/core/package.json index e539a99d7ae5db8a57983a89abdd95e95489854f..80be45150bc5417e21fb0924f736a7812b72cc1f 100644 --- a/frontend/packages/core/package.json +++ b/frontend/packages/core/package.json @@ -51,6 +51,7 @@ "mime-types": "2.1.27", "moment": "2.29.1", "nprogress": "0.2.0", + "numeric": "1.2.6", "polished": "4.0.5", "query-string": "6.13.7", "react": "17.0.1", @@ -94,6 +95,7 @@ "@types/lodash": "4.14.165", "@types/mime-types": "2.1.0", "@types/nprogress": "0.2.0", + "@types/numeric": "1.2.1", "@types/react": "17.0.0", "@types/react-dom": "17.0.0", "@types/react-helmet": "6.1.0", diff --git a/frontend/packages/core/src/components/HighDimensionalPage/HighDimensionalChart.tsx b/frontend/packages/core/src/components/HighDimensionalPage/HighDimensionalChart.tsx index b48a6cdc5627f94ff2378dc103e654bed6d3c3a4..13801336283349e9d3325842f62f4697b7655b35 100644 --- a/frontend/packages/core/src/components/HighDimensionalPage/HighDimensionalChart.tsx +++ b/frontend/packages/core/src/components/HighDimensionalPage/HighDimensionalChart.tsx @@ -14,7 +14,7 @@ * limitations under the License. */ -import type {PCAResult, Reduction, TSNEResult, UMAPResult} from '~/resource/high-dimensional'; +import type {CalculateParams, CalculateResult, Reduction} from '~/resource/high-dimensional'; import React, {useCallback, useEffect, useImperativeHandle, useLayoutEffect, useMemo, useRef, useState} from 'react'; import ScatterChart, {ScatterChartRef} from '~/components/ScatterChart'; @@ -24,76 +24,8 @@ import type {WithStyled} from '~/utils/style'; import {rem} from '~/utils/style'; import styled from 'styled-components'; import {useTranslation} from 'react-i18next'; -import useWebAssembly from '~/hooks/useWebAssembly'; import useWorker from '~/hooks/useWorker'; -function useComputeHighDimensional( - reduction: Reduction, - vectors: Float32Array, - dim: number, - is3D: boolean, - perplexity: number, - learningRate: number, - neighbors: number -) { - const pcaParams = useMemo(() => { - if (reduction === 'pca') { - return [Array.from(vectors), dim, 3] as const; - } - return [[], 0, 3]; - }, [reduction, vectors, dim]); - - const tsneInitParams = useRef({perplexity, epsilon: learningRate}); - const tsneParams = useMemo(() => { - if (reduction === 'tsne') { - return { - input: vectors, - dim, - n: is3D ? 3 : 2, - ...tsneInitParams.current - }; - } - return { - input: new Float32Array(), - dim: 0, - n: is3D ? 3 : 2, - perplexity: 5 - }; - }, [reduction, vectors, dim, is3D]); - - const umapParams = useMemo(() => { - if (reduction === 'umap') { - return { - input: vectors, - dim, - n: is3D ? 3 : 2, - neighbors - }; - } - return { - input: new Float32Array(), - dim: 0, - n: is3D ? 3 : 2, - neighbors: 15 - }; - }, [reduction, vectors, dim, is3D, neighbors]); - - const pcaResult = useWebAssembly('high_dimensional_pca', pcaParams); - const tsneResult = useWorker('high-dimensional/tsne', tsneParams); - const umapResult = useWorker('high-dimensional/umap', umapParams); - - if (reduction === 'pca') { - return pcaResult; - } - if (reduction === 'tsne') { - return tsneResult; - } - if (reduction === 'umap') { - return umapResult; - } - return null as never; -} - const Wrapper = styled.div` height: 100%; display: flex; @@ -143,7 +75,7 @@ type HighDimensionalChartProps = { neighbors: number; highlightIndices?: number[]; onCalculate?: () => unknown; - onCalculated?: (data: PCAResult | TSNEResult | UMAPResult) => unknown; + onCalculated?: (data: CalculateResult) => unknown; onError?: (e: Error) => unknown; }; @@ -196,15 +128,42 @@ const HighDimensionalChart = React.forwardRef(() => { + const result = { + input: vectors, + dim, + n: is3D ? 3 : 2 + }; + switch (reduction) { + case 'pca': + return { + reduction, + params: { + ...result + } + }; + case 'tsne': + return { + reduction, + params: { + perplexity, + epsilon: learningRate, + ...result + } + }; + case 'umap': + return { + reduction, + params: { + neighbors, + ...result + } + }; + default: + return null as never; + } + }, [dim, is3D, learningRate, neighbors, perplexity, reduction, vectors]); + const {data, error, worker} = useWorker('high-dimensional/calculate', params); const iterationId = useRef(null); const iteration = useCallback(() => { diff --git a/frontend/packages/core/src/components/HighDimensionalPage/PCADetail.tsx b/frontend/packages/core/src/components/HighDimensionalPage/PCADetail.tsx index e1e9ed0056356c5b82aba0ffc6e6169bba54fd8f..bacc00637eeeefa7337b85eb5cf64e51479bcf08 100644 --- a/frontend/packages/core/src/components/HighDimensionalPage/PCADetail.tsx +++ b/frontend/packages/core/src/components/HighDimensionalPage/PCADetail.tsx @@ -31,13 +31,13 @@ const Wrapper = styled(Field)` export type PCADetailProps = { dimension: Dimension; variance: number[]; + totalVariance: number; }; -const PCADetail: FunctionComponent = ({dimension, variance}) => { +const PCADetail: FunctionComponent = ({dimension, variance, totalVariance}) => { const {t} = useTranslation(['high-dimensional', 'common']); const dim = useMemo(() => (dimension === '3d' ? 3 : 2), [dimension]); - const totalVariance = useMemo(() => variance.reduce((s, c) => s + c, 0), [variance]); return ( diff --git a/frontend/packages/core/src/pages/high-dimensional.tsx b/frontend/packages/core/src/pages/high-dimensional.tsx index 36e76aa7a91e06c95f4ed484ab6dc57baadd7fab..5705d49d8529db74b61345c64c593fb69442345a 100644 --- a/frontend/packages/core/src/pages/high-dimensional.tsx +++ b/frontend/packages/core/src/pages/high-dimensional.tsx @@ -53,6 +53,18 @@ import useWorker from '~/hooks/useWorker'; const MODE = import.meta.env.MODE; +const MAX_COUNT: Record = { + pca: 50000, + tsne: 10000, + umap: 5000 +} as const; + +const MAX_DIMENSION: Record = { + pca: 200, + tsne: undefined, + umap: undefined +}; + const AsideTitle = styled.div` font-size: ${rem(16)}; line-height: ${rem(16)}; @@ -182,6 +194,12 @@ const HighDimensional: FunctionComponent = () => { ); const labelByLabels = useMemo(() => getLabelByLabels(labelBy), [getLabelByLabels, labelBy]); + // dimension of display + const [dimension, setDimension] = useState('3d'); + const [reduction, setReduction] = useState('pca'); + + const is3D = useMemo(() => dimension === '3d', [dimension]); + const readFile = useCallback( (phase: string, file: File | null, setter: React.Dispatch>) => { if (file) { @@ -221,12 +239,17 @@ const HighDimensional: FunctionComponent = () => { }, []); const params = useMemo(() => { + const maxValues = { + maxCount: MAX_COUNT[reduction], + maxDimension: MAX_DIMENSION[reduction] + }; if (vectorContent) { return { from: 'string', params: { vectors: vectorContent, - metadata: metadataContent + metadata: metadataContent, + ...maxValues } }; } @@ -236,12 +259,13 @@ const HighDimensional: FunctionComponent = () => { params: { shape: selectedEmbedding.shape, vectors: tensorData.data, - metadata: metadataData ?? '' + metadata: metadataData ?? '', + ...maxValues } }; } return null; - }, [vectorContent, metadataContent, selectedEmbedding, tensorData, metadataData]); + }, [reduction, vectorContent, selectedEmbedding, tensorData, metadataContent, metadataData]); const result = useWorker('high-dimensional/parse-data', params); useEffect(() => { const {error, data} = result; @@ -264,12 +288,6 @@ const HighDimensional: FunctionComponent = () => { selectedEmbedding ]); - // dimension of display - const [dimension, setDimension] = useState('3d'); - const [reduction, setReduction] = useState('pca'); - - const is3D = useMemo(() => dimension === '3d', [dimension]); - const [perplexity, setPerplexity] = useState(5); const [learningRate, setLearningRate] = useState(10); @@ -324,7 +342,13 @@ const HighDimensional: FunctionComponent = () => { const detail = useMemo(() => { switch (reduction) { case 'pca': - return ; + return ( + + ); case 'tsne': return ( (str: string, processer?: (item: string) => T): T[][] { +function split(str: string, handler?: (item: string) => T): T[][] { return safeSplit(str, '\n') - .map(r => safeSplit(r, '\t').map(n => (processer ? processer(n) : n) as T)) + .map(r => safeSplit(r, '\t').map(n => (handler ? handler(n) : n) as T)) .filter(r => r.length); } @@ -64,13 +64,19 @@ function alignItems(data: T[][], dimension: number, defaultValue: T): T[][] { }); } -function parseVectors(str: string): VectorResult { +function parseVectors(str: string, maxCount?: number, maxDimension?: number): VectorResult { if (!str) { throw new ParserError('Tenser file is empty', ParserError.CODES.TENSER_EMPTY); } let vectors = split(str, Number.parseFloat); - // TODO: sampling - const dimension = Math.min(...vectors.map(vector => vector.length)); + // TODO: random sampling + if (maxCount) { + vectors = vectors.slice(0, maxCount); + } + let dimension = Math.min(...vectors.map(vector => vector.length)); + if (maxDimension) { + dimension = Math.min(dimension, maxDimension); + } vectors = alignItems(vectors, dimension, 0); return { dimension, @@ -124,31 +130,51 @@ function genMetadataAndLabels(metadata: string, count: number) { }; } -export function parseFromString({vectors: v, metadata: m}: ParseFromStringParams): ParseResult { +export function parseFromString({vectors: v, metadata: m, maxCount, maxDimension}: ParseFromStringParams): ParseResult { const result: ParseResult = { + count: 0, dimension: 0, vectors: new Float32Array(), labels: [], metadata: [] }; if (v) { - const {dimension, vectors, count} = parseVectors(v); + const {dimension, vectors, count} = parseVectors(v, maxCount, maxDimension); result.dimension = dimension; result.vectors = vectors; - Object.assign(result, genMetadataAndLabels(m, count)); + const metadataAndLabels = genMetadataAndLabels(m, count); + metadataAndLabels.metadata = metadataAndLabels.metadata.slice(0, count); + Object.assign(result, metadataAndLabels); } return result; } -export async function parseFromBlob({shape, vectors: v, metadata: m}: ParseFromBlobParams): Promise { - const [count, dimension] = shape; - const vectors = new Float32Array(await v.arrayBuffer()); - if (count * dimension !== vectors.length) { +export async function parseFromBlob({ + shape, + vectors: v, + metadata: m, + maxCount, + maxDimension +}: ParseFromBlobParams): Promise { + // TODO: random sampling + const [originalCount, originalDimension] = shape; + const originalVectors = new Float32Array(await v.arrayBuffer()); + if (originalCount * originalDimension !== originalVectors.length) { throw new ParserError('Size of tensor does not match.', ParserError.CODES.SHAPE_MISMATCH); } + const count = maxCount ? Math.min(originalCount, maxCount) : originalCount; + const dimension = maxDimension ? Math.min(originalDimension, maxDimension) : originalDimension; + const vectors = new Float32Array(count * dimension); + for (let c = 0; c < count; c++) { + const offset = c * originalDimension; + vectors.set(originalVectors.subarray(offset, offset + dimension), c * dimension); + } + const metadataAndLabels = genMetadataAndLabels(m, originalCount); + metadataAndLabels.metadata = metadataAndLabels.metadata.slice(0, count); return { + count, dimension, vectors, - ...genMetadataAndLabels(m, count) + ...metadataAndLabels }; } diff --git a/frontend/packages/core/src/resource/high-dimensional/pca.ts b/frontend/packages/core/src/resource/high-dimensional/pca.ts new file mode 100644 index 0000000000000000000000000000000000000000..3fecf95767947d662d3fec9d63cee8c112c8c4c8 --- /dev/null +++ b/frontend/packages/core/src/resource/high-dimensional/pca.ts @@ -0,0 +1,55 @@ +/** + * Copyright 2020 Baidu Inc. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import numeric from 'numeric'; + +export default (input: Float32Array, dim: number, nComponents: number) => { + const n = input.length / dim; + const vectors: number[][] = []; + for (let i = 0; i < n; i++) { + vectors.push(Array.from(input.subarray(i * dim, i * dim + dim))); + } + const {dot, transpose, svd: numericSvd, div} = numeric; + const scalar = dot(transpose(vectors), vectors); + const sigma = div(scalar as number[][], n); + const svd = numericSvd(sigma); + const variances: number[] = svd.S; + let totalVariance = 0; + for (let i = 0; i < variances.length; i++) { + totalVariance += variances[i]; + } + for (let i = 0; i < variances.length; i++) { + variances[i] /= totalVariance; + } + const U: number[][] = svd.U; + const pcaVectors = vectors.map(vector => { + const newV = new Float32Array(nComponents); + for (let newDim = 0; newDim < nComponents; newDim++) { + let dot = 0; + for (let oldDim = 0; oldDim < vector.length; oldDim++) { + dot += vector[oldDim] * U[oldDim][newDim]; + } + newV[newDim] = dot; + } + return Array.from(newV); + }); + const variance = variances.slice(0, nComponents); + return { + vectors: pcaVectors, + variance, + totalVariance: variance.reduce((s, c) => s + c, 0) + }; +}; diff --git a/frontend/packages/core/src/resource/high-dimensional/types.ts b/frontend/packages/core/src/resource/high-dimensional/types.ts index de16e2b586ad5ef1bbcb2c735cfd79024904e801..633da24bb2752a5d1135a6a5112282be933c6fa0 100644 --- a/frontend/packages/core/src/resource/high-dimensional/types.ts +++ b/frontend/packages/core/src/resource/high-dimensional/types.ts @@ -31,16 +31,20 @@ export type MetadataResult = { metadata: string[][]; }; -export type ParseFromStringParams = { - vectors: string; +interface BaseParseParams { metadata: string; -}; + maxCount?: number; + maxDimension?: number; +} + +export interface ParseFromStringParams extends BaseParseParams { + vectors: string; +} -export type ParseFromBlobParams = { +export interface ParseFromBlobParams extends BaseParseParams { shape: [number, number]; vectors: Blob; - metadata: string; -}; +} export type ParseParams = | { @@ -54,14 +58,15 @@ export type ParseParams = | null; export type ParseResult = { + count: number; dimension: number; vectors: Float32Array; labels: string[]; metadata: string[][]; }; -export type PcaParams = { - input: number[]; +export type PCAParams = { + input: Float32Array; dim: number; n: number; }; @@ -69,6 +74,7 @@ export type PcaParams = { export type PCAResult = { vectors: Vectors; variance: number[]; + totalVariance: number; }; export type TSNEParams = { @@ -96,3 +102,19 @@ export type UMAPResult = { epoch: number; nEpochs: number; }; + +export type CalculateParams = + | { + reduction: 'pca'; + params: PCAParams; + } + | { + reduction: 'tsne'; + params: TSNEParams; + } + | { + reduction: 'umap'; + params: UMAPParams; + }; + +export type CalculateResult = PCAResult | TSNEResult | UMAPResult; diff --git a/frontend/packages/core/src/worker/high-dimensional/tsne.ts b/frontend/packages/core/src/worker/high-dimensional/calculate.ts similarity index 63% rename from frontend/packages/core/src/worker/high-dimensional/tsne.ts rename to frontend/packages/core/src/worker/high-dimensional/calculate.ts index 21474f93bbdf61c80c6edd551a89919c7f21f5b3..e61bf226249e85b35a4e1d9be1fec69a4c76179c 100644 --- a/frontend/packages/core/src/worker/high-dimensional/tsne.ts +++ b/frontend/packages/core/src/worker/high-dimensional/calculate.ts @@ -14,10 +14,19 @@ * limitations under the License. */ -import type {TSNEParams, TSNEResult} from '~/resource/high-dimensional'; +import type { + CalculateParams, + PCAParams, + PCAResult, + TSNEParams, + TSNEResult, + UMAPParams, + UMAPResult, + Vectors +} from '~/resource/high-dimensional'; +import {PCA, UMAP, tSNE} from '~/resource/high-dimensional'; import {WorkerSelf} from '~/worker'; -import {tSNE} from '~/resource/high-dimensional'; import type {tSNEOptions} from '~/resource/high-dimensional/tsne'; type InfoStepData = { @@ -34,7 +43,29 @@ export type InfoData = InfoStepData | InfoResetData | InfoParamsData; const workerSelf = new WorkerSelf(); workerSelf.emit('INITIALIZED'); -workerSelf.on('RUN', data => { +workerSelf.on('RUN', ({reduction, params}) => { + switch (reduction) { + case 'pca': + return pca(params as PCAParams); + case 'tsne': + return tsne(params as TSNEParams); + case 'umap': + return umap(params as UMAPParams); + default: + return null as never; + } +}); + +function pca(data: PCAParams) { + const {vectors, variance, totalVariance} = PCA(data.input, data.dim, data.n); + workerSelf.emit('RESULT', { + vectors: vectors as Vectors, + variance, + totalVariance + }); +} + +function tsne(data: TSNEParams) { const t_sne = new tSNE({ dimension: data.n, perplexity: data.perplexity, @@ -43,7 +74,7 @@ workerSelf.on('RUN', data => { const reset = () => { t_sne.setData(data.input, data.dim); - return workerSelf.emit('RESULT', { + workerSelf.emit('RESULT', { vectors: t_sne.solution as [number, number, number][], step: t_sne.step }); @@ -78,4 +109,18 @@ workerSelf.on('RUN', data => { return null as never; } }); -}); +} + +function umap(data: UMAPParams) { + const result = UMAP(data.n, data.neighbors, data.input, data.dim); + if (result) { + workerSelf.emit('RESULT', { + vectors: result.embedding as [number, number, number][], + epoch: result.nEpochs, + nEpochs: result.nEpochs + }); + } + workerSelf.on('INFO', () => { + workerSelf.emit('INITIALIZED'); + }); +} diff --git a/frontend/packages/core/src/worker/high-dimensional/umap.ts b/frontend/packages/core/src/worker/high-dimensional/umap.ts deleted file mode 100644 index 1ef3fa8db098692fac394609ff477b93c5561fb0..0000000000000000000000000000000000000000 --- a/frontend/packages/core/src/worker/high-dimensional/umap.ts +++ /dev/null @@ -1,36 +0,0 @@ -/** - * Copyright 2020 Baidu Inc. All Rights Reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -import type {UMAPParams, UMAPResult} from '~/resource/high-dimensional'; - -import {UMAP} from '~/resource/high-dimensional'; -import {WorkerSelf} from '~/worker'; - -const workerSelf = new WorkerSelf(); -workerSelf.emit('INITIALIZED'); -workerSelf.on('RUN', data => { - const result = UMAP(data.n, data.neighbors, data.input, data.dim); - if (result) { - workerSelf.emit('RESULT', { - vectors: result.embedding as [number, number, number][], - epoch: result.nEpochs, - nEpochs: result.nEpochs - }); - } -}); -workerSelf.on('INFO', () => { - workerSelf.emit('INITIALIZED'); -}); diff --git a/frontend/yarn.lock b/frontend/yarn.lock index a24f2d75708f2e8f1f64c0de90ef7f93665301a6..ef2a62ced780f4d7bfa1f8da6eefd80db433308d 100644 --- a/frontend/yarn.lock +++ b/frontend/yarn.lock @@ -3325,6 +3325,11 @@ resolved "https://registry.yarnpkg.com/@types/nprogress/-/nprogress-0.2.0.tgz#86c593682d4199212a0509cc3c4d562bbbd6e45f" integrity sha512-1cYJrqq9GezNFPsWTZpFut/d4CjpZqA0vhqDUPFWYKF1oIyBz5qnoYMzR+0C/T96t3ebLAC1SSnwrVOm5/j74A== +"@types/numeric@1.2.1": + version "1.2.1" + resolved "https://registry.yarnpkg.com/@types/numeric/-/numeric-1.2.1.tgz#6bce5d0c4f1b20f2cbd4a3d47922b8fe6e36ad56" + integrity sha512-30gQPisgZW5+ErkDVTZkoVKmwIWdjf2O6HmgKr3E1FJBdMYFldOPSJlQYP2VMafHuhOKvbLFA4Hf+ohvArz1+w== + "@types/parse-json@^4.0.0": version "4.0.0" resolved "https://registry.yarnpkg.com/@types/parse-json/-/parse-json-4.0.0.tgz#2f8bb441434d163b35fb8ffdccd7138927ffb8c0" @@ -10601,6 +10606,11 @@ number-is-nan@^1.0.0: resolved "https://registry.yarnpkg.com/number-is-nan/-/number-is-nan-1.0.1.tgz#097b602b53422a522c1afb8790318336941a011d" integrity sha1-CXtgK1NCKlIsGvuHkDGDNpQaAR0= +numeric@1.2.6: + version "1.2.6" + resolved "https://registry.yarnpkg.com/numeric/-/numeric-1.2.6.tgz#765b02bef97988fcf880d4eb3f36b80fa31335aa" + integrity sha1-dlsCvvl5iPz4gNTrPza4D6MTNao= + nwsapi@^2.2.0: version "2.2.0" resolved "https://registry.yarnpkg.com/nwsapi/-/nwsapi-2.2.0.tgz#204879a9e3d068ff2a55139c2c772780681a38b7"