Unverified · Commit 915f58de authored by Peter Pan, committed by GitHub

chore: update pca algorithm (#870)

Parent f501709b
......@@ -51,6 +51,7 @@
"mime-types": "2.1.27",
"moment": "2.29.1",
"nprogress": "0.2.0",
"numeric": "1.2.6",
"polished": "4.0.5",
"query-string": "6.13.7",
"react": "17.0.1",
......@@ -94,6 +95,7 @@
"@types/lodash": "4.14.165",
"@types/mime-types": "2.1.0",
"@types/nprogress": "0.2.0",
"@types/numeric": "1.2.1",
"@types/react": "17.0.0",
"@types/react-dom": "17.0.0",
"@types/react-helmet": "6.1.0",
......
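The new "numeric" dependency provides the plain-JavaScript linear-algebra routines (dot, transpose, div, svd) that the reworked PCA below relies on. A minimal sketch of those calls, not taken from the diff and using made-up data:

import numeric from 'numeric';

const X = [
    [1, 2],
    [3, 4],
    [5, 6]
];
// Covariance-like matrix (Xᵀ · X) / n, mirroring the approach taken in pca.ts further down.
const sigma = numeric.div(numeric.dot(numeric.transpose(X), X) as number[][], X.length) as number[][];
// svd returns {U, S, V}; S holds the singular values in descending order.
const {S} = numeric.svd(sigma);
console.log(S);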
......@@ -14,7 +14,7 @@
* limitations under the License.
*/
import type {PCAResult, Reduction, TSNEResult, UMAPResult} from '~/resource/high-dimensional';
import type {CalculateParams, CalculateResult, Reduction} from '~/resource/high-dimensional';
import React, {useCallback, useEffect, useImperativeHandle, useLayoutEffect, useMemo, useRef, useState} from 'react';
import ScatterChart, {ScatterChartRef} from '~/components/ScatterChart';
......@@ -24,76 +24,8 @@ import type {WithStyled} from '~/utils/style';
import {rem} from '~/utils/style';
import styled from 'styled-components';
import {useTranslation} from 'react-i18next';
import useWebAssembly from '~/hooks/useWebAssembly';
import useWorker from '~/hooks/useWorker';
function useComputeHighDimensional(
reduction: Reduction,
vectors: Float32Array,
dim: number,
is3D: boolean,
perplexity: number,
learningRate: number,
neighbors: number
) {
const pcaParams = useMemo(() => {
if (reduction === 'pca') {
return [Array.from(vectors), dim, 3] as const;
}
return [[], 0, 3];
}, [reduction, vectors, dim]);
const tsneInitParams = useRef({perplexity, epsilon: learningRate});
const tsneParams = useMemo(() => {
if (reduction === 'tsne') {
return {
input: vectors,
dim,
n: is3D ? 3 : 2,
...tsneInitParams.current
};
}
return {
input: new Float32Array(),
dim: 0,
n: is3D ? 3 : 2,
perplexity: 5
};
}, [reduction, vectors, dim, is3D]);
const umapParams = useMemo(() => {
if (reduction === 'umap') {
return {
input: vectors,
dim,
n: is3D ? 3 : 2,
neighbors
};
}
return {
input: new Float32Array(),
dim: 0,
n: is3D ? 3 : 2,
neighbors: 15
};
}, [reduction, vectors, dim, is3D, neighbors]);
const pcaResult = useWebAssembly<PCAResult>('high_dimensional_pca', pcaParams);
const tsneResult = useWorker<TSNEResult>('high-dimensional/tsne', tsneParams);
const umapResult = useWorker<UMAPResult>('high-dimensional/umap', umapParams);
if (reduction === 'pca') {
return pcaResult;
}
if (reduction === 'tsne') {
return tsneResult;
}
if (reduction === 'umap') {
return umapResult;
}
return null as never;
}
const Wrapper = styled.div`
height: 100%;
display: flex;
......@@ -143,7 +75,7 @@ type HighDimensionalChartProps = {
neighbors: number;
highlightIndices?: number[];
onCalculate?: () => unknown;
onCalculated?: (data: PCAResult | TSNEResult | UMAPResult) => unknown;
onCalculated?: (data: CalculateResult) => unknown;
onError?: (e: Error) => unknown;
};
......@@ -196,15 +128,42 @@ const HighDimensionalChart = React.forwardRef<HighDimensionalChartRef, HighDimen
}
}, []);
const {data, error, worker} = useComputeHighDimensional(
reduction,
vectors,
dim,
is3D,
perplexity,
learningRate,
neighbors
);
const params = useMemo<CalculateParams>(() => {
const result = {
input: vectors,
dim,
n: is3D ? 3 : 2
};
switch (reduction) {
case 'pca':
return {
reduction,
params: {
...result
}
};
case 'tsne':
return {
reduction,
params: {
perplexity,
epsilon: learningRate,
...result
}
};
case 'umap':
return {
reduction,
params: {
neighbors,
...result
}
};
default:
return null as never;
}
}, [dim, is3D, learningRate, neighbors, perplexity, reduction, vectors]);
const {data, error, worker} = useWorker<CalculateResult>('high-dimensional/calculate', params);
const iterationId = useRef<number | null>(null);
const iteration = useCallback(() => {
......
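Because every branch above returns the reduction literal next to its params object, the memoized value is a member of the CalculateParams discriminated union, so a later switch on reduction narrows params without casts. A minimal illustration, not part of the diff:

import type {CalculateParams} from '~/resource/high-dimensional';

function describeParams(p: CalculateParams) {
    switch (p.reduction) {
        case 'pca':
            return `PCA on ${p.params.dim}-d input`; // p.params is PCAParams here
        case 'tsne':
            return `t-SNE with perplexity ${p.params.perplexity}`; // TSNEParams
        case 'umap':
            return `UMAP with ${p.params.neighbors} neighbors`; // UMAPParams
    }
}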
......@@ -31,13 +31,13 @@ const Wrapper = styled(Field)`
export type PCADetailProps = {
dimension: Dimension;
variance: number[];
totalVariance: number;
};
const PCADetail: FunctionComponent<PCADetailProps> = ({dimension, variance}) => {
const PCADetail: FunctionComponent<PCADetailProps> = ({dimension, variance, totalVariance}) => {
const {t} = useTranslation(['high-dimensional', 'common']);
const dim = useMemo(() => (dimension === '3d' ? 3 : 2), [dimension]);
const totalVariance = useMemo(() => variance.reduce((s, c) => s + c, 0), [variance]);
return (
<Wrapper>
......
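For reference, a small worked example with made-up numbers: if the retained components explain fractions 0.46, 0.29 and 0.10 of the variance, PCADetail now receives variance={[0.46, 0.29, 0.1]} and totalVariance={0.85}, i.e. the retained components jointly explain 85% of the variance.

// Hypothetical values, only to illustrate the props PCADetail receives after this change.
const variance = [0.46, 0.29, 0.1];
const totalVariance = variance.reduce((s, c) => s + c, 0); // 0.85
console.log(`${(totalVariance * 100).toFixed(1)}% of total variance explained`);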
......@@ -53,6 +53,18 @@ import useWorker from '~/hooks/useWorker';
const MODE = import.meta.env.MODE;
const MAX_COUNT: Record<Reduction, number | undefined> = {
pca: 50000,
tsne: 10000,
umap: 5000
} as const;
const MAX_DIMENSION: Record<Reduction, number | undefined> = {
pca: 200,
tsne: undefined,
umap: undefined
};
const AsideTitle = styled.div`
font-size: ${rem(16)};
line-height: ${rem(16)};
......@@ -182,6 +194,12 @@ const HighDimensional: FunctionComponent = () => {
);
const labelByLabels = useMemo(() => getLabelByLabels(labelBy), [getLabelByLabels, labelBy]);
// dimension of display
const [dimension, setDimension] = useState<Dimension>('3d');
const [reduction, setReduction] = useState<Reduction>('pca');
const is3D = useMemo(() => dimension === '3d', [dimension]);
const readFile = useCallback(
(phase: string, file: File | null, setter: React.Dispatch<React.SetStateAction<string>>) => {
if (file) {
......@@ -221,12 +239,17 @@ const HighDimensional: FunctionComponent = () => {
}, []);
const params = useMemo<ParseParams>(() => {
const maxValues = {
maxCount: MAX_COUNT[reduction],
maxDimension: MAX_DIMENSION[reduction]
};
if (vectorContent) {
return {
from: 'string',
params: {
vectors: vectorContent,
metadata: metadataContent
metadata: metadataContent,
...maxValues
}
};
}
......@@ -236,12 +259,13 @@ const HighDimensional: FunctionComponent = () => {
params: {
shape: selectedEmbedding.shape,
vectors: tensorData.data,
metadata: metadataData ?? ''
metadata: metadataData ?? '',
...maxValues
}
};
}
return null;
}, [vectorContent, metadataContent, selectedEmbedding, tensorData, metadataData]);
}, [reduction, vectorContent, selectedEmbedding, tensorData, metadataContent, metadataData]);
const result = useWorker<ParseResult, ParseParams>('high-dimensional/parse-data', params);
useEffect(() => {
const {error, data} = result;
......@@ -264,12 +288,6 @@ const HighDimensional: FunctionComponent = () => {
selectedEmbedding
]);
// dimension of display
const [dimension, setDimension] = useState<Dimension>('3d');
const [reduction, setReduction] = useState<Reduction>('pca');
const is3D = useMemo(() => dimension === '3d', [dimension]);
const [perplexity, setPerplexity] = useState(5);
const [learningRate, setLearningRate] = useState(10);
......@@ -324,7 +342,13 @@ const HighDimensional: FunctionComponent = () => {
const detail = useMemo(() => {
switch (reduction) {
case 'pca':
return <PCADetail dimension={dimension} variance={(data as PCAResult)?.variance ?? []} />;
return (
<PCADetail
dimension={dimension}
variance={(data as PCAResult)?.variance ?? []}
totalVariance={(data as PCAResult)?.totalVariance ?? 0}
/>
);
case 'tsne':
return (
<TSNEDetail
......
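For reference, a hedged sketch (contents are placeholders) of the parse parameters the memo above produces for a pasted string when t-SNE is selected: only the point-count cap applies, since t-SNE has no dimension cap in the tables above.

import type {ParseParams} from '~/resource/high-dimensional';

// What the params memo would produce for reduction === 'tsne'
// with vectorContent and metadataContent already loaded.
const params: ParseParams = {
    from: 'string',
    params: {
        vectors: '1\t2\t3\n4\t5\t6',
        metadata: 'label\nA\nB',
        maxCount: 10000,        // MAX_COUNT.tsne
        maxDimension: undefined // MAX_DIMENSION.tsne
    }
};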
......@@ -15,20 +15,23 @@
*/
export type {
CalculateParams,
CalculateResult,
Dimension,
Reduction,
PcaParams,
PCAResult,
Vectors,
ParseParams,
ParseResult,
PCAParams,
PCAResult,
Reduction,
TSNEParams,
TSNEResult,
UMAPParams,
UMAPResult
UMAPResult,
Vectors
} from './types';
export {parseFromBlob, parseFromString, ParserError} from './parser';
export {default as PCA} from './pca';
export {default as tSNE} from './tsne';
export {default as UMAP} from './umap';
......@@ -45,9 +45,9 @@ export class ParserError extends Error {
}
}
function split<T = string>(str: string, processer?: (item: string) => T): T[][] {
function split<T = string>(str: string, handler?: (item: string) => T): T[][] {
return safeSplit(str, '\n')
.map(r => safeSplit(r, '\t').map(n => (processer ? processer(n) : n) as T))
.map(r => safeSplit(r, '\t').map(n => (handler ? handler(n) : n) as T))
.filter(r => r.length);
}
......@@ -64,13 +64,19 @@ function alignItems<T>(data: T[][], dimension: number, defaultValue: T): T[][] {
});
}
function parseVectors(str: string): VectorResult {
function parseVectors(str: string, maxCount?: number, maxDimension?: number): VectorResult {
if (!str) {
throw new ParserError('Tensor file is empty', ParserError.CODES.TENSER_EMPTY);
}
let vectors = split(str, Number.parseFloat);
// TODO: sampling
const dimension = Math.min(...vectors.map(vector => vector.length));
// TODO: random sampling
if (maxCount) {
vectors = vectors.slice(0, maxCount);
}
let dimension = Math.min(...vectors.map(vector => vector.length));
if (maxDimension) {
dimension = Math.min(dimension, maxDimension);
}
vectors = alignItems(vectors, dimension, 0);
return {
dimension,
......@@ -124,31 +130,51 @@ function genMetadataAndLabels(metadata: string, count: number) {
};
}
export function parseFromString({vectors: v, metadata: m}: ParseFromStringParams): ParseResult {
export function parseFromString({vectors: v, metadata: m, maxCount, maxDimension}: ParseFromStringParams): ParseResult {
const result: ParseResult = {
count: 0,
dimension: 0,
vectors: new Float32Array(),
labels: [],
metadata: []
};
if (v) {
const {dimension, vectors, count} = parseVectors(v);
const {dimension, vectors, count} = parseVectors(v, maxCount, maxDimension);
result.dimension = dimension;
result.vectors = vectors;
Object.assign(result, genMetadataAndLabels(m, count));
const metadataAndLabels = genMetadataAndLabels(m, count);
metadataAndLabels.metadata = metadataAndLabels.metadata.slice(0, count);
Object.assign(result, metadataAndLabels);
}
return result;
}
export async function parseFromBlob({shape, vectors: v, metadata: m}: ParseFromBlobParams): Promise<ParseResult> {
const [count, dimension] = shape;
const vectors = new Float32Array(await v.arrayBuffer());
if (count * dimension !== vectors.length) {
export async function parseFromBlob({
shape,
vectors: v,
metadata: m,
maxCount,
maxDimension
}: ParseFromBlobParams): Promise<ParseResult> {
// TODO: random sampling
const [originalCount, originalDimension] = shape;
const originalVectors = new Float32Array(await v.arrayBuffer());
if (originalCount * originalDimension !== originalVectors.length) {
throw new ParserError('Size of tensor does not match.', ParserError.CODES.SHAPE_MISMATCH);
}
const count = maxCount ? Math.min(originalCount, maxCount) : originalCount;
const dimension = maxDimension ? Math.min(originalDimension, maxDimension) : originalDimension;
const vectors = new Float32Array(count * dimension);
for (let c = 0; c < count; c++) {
const offset = c * originalDimension;
vectors.set(originalVectors.subarray(offset, offset + dimension), c * dimension);
}
const metadataAndLabels = genMetadataAndLabels(m, originalCount);
metadataAndLabels.metadata = metadataAndLabels.metadata.slice(0, count);
return {
count,
dimension,
vectors,
...genMetadataAndLabels(m, count)
...metadataAndLabels
};
}
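A self-contained sketch of the truncation technique parseFromBlob uses above: vectors are stored row-major in a single Float32Array, so keeping the first maxCount rows and the first maxDimension columns is a per-row subarray copy into a smaller array. The helper name below is illustrative, not from the diff.

function truncate(vectors: Float32Array, dim: number, maxCount: number, maxDimension: number): Float32Array {
    const count = Math.min(Math.floor(vectors.length / dim), maxCount);
    const newDim = Math.min(dim, maxDimension);
    const result = new Float32Array(count * newDim);
    for (let c = 0; c < count; c++) {
        const offset = c * dim;
        // copy the first newDim values of row c into the compacted array
        result.set(vectors.subarray(offset, offset + newDim), c * newDim);
    }
    return result;
}

// 3 rows of 4 dims, keep 2 rows × 2 dims → Float32Array [1, 2, 5, 6]
truncate(new Float32Array([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]), 4, 2, 2);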
......@@ -14,23 +14,42 @@
* limitations under the License.
*/
import type {UMAPParams, UMAPResult} from '~/resource/high-dimensional';
import numeric from 'numeric';
import {UMAP} from '~/resource/high-dimensional';
import {WorkerSelf} from '~/worker';
const workerSelf = new WorkerSelf();
workerSelf.emit('INITIALIZED');
workerSelf.on<UMAPParams>('RUN', data => {
const result = UMAP(data.n, data.neighbors, data.input, data.dim);
if (result) {
workerSelf.emit<UMAPResult>('RESULT', {
vectors: result.embedding as [number, number, number][],
epoch: result.nEpochs,
nEpochs: result.nEpochs
});
export default (input: Float32Array, dim: number, nComponents: number) => {
const n = input.length / dim;
const vectors: number[][] = [];
for (let i = 0; i < n; i++) {
vectors.push(Array.from(input.subarray(i * dim, i * dim + dim)));
}
const {dot, transpose, svd: numericSvd, div} = numeric;
const scalar = dot(transpose(vectors), vectors);
const sigma = div(scalar as number[][], n);
const svd = numericSvd(sigma);
const variances: number[] = svd.S;
let totalVariance = 0;
for (let i = 0; i < variances.length; i++) {
totalVariance += variances[i];
}
for (let i = 0; i < variances.length; i++) {
variances[i] /= totalVariance;
}
});
workerSelf.on('INFO', () => {
workerSelf.emit('INITIALIZED');
});
const U: number[][] = svd.U;
const pcaVectors = vectors.map(vector => {
const newV = new Float32Array(nComponents);
for (let newDim = 0; newDim < nComponents; newDim++) {
let dot = 0;
for (let oldDim = 0; oldDim < vector.length; oldDim++) {
dot += vector[oldDim] * U[oldDim][newDim];
}
newV[newDim] = dot;
}
return Array.from(newV);
});
const variance = variances.slice(0, nComponents);
return {
vectors: pcaVectors,
variance,
totalVariance: variance.reduce((s, c) => s + c, 0)
};
};
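A small usage sketch of the PCA routine above, with made-up input: the flattened array is read as length / dim row vectors, projected onto the top nComponents singular vectors of XᵀX / n (the input is used as-is, without a mean-centering step), and the explained-variance ratios come back alongside the projected points.

import {PCA} from '~/resource/high-dimensional';

// 4 points in 3-d space, flattened row-major; reduce to 2 components.
const input = new Float32Array([1, 0, 0, 0, 1, 0, 0, 0, 1, 1, 1, 1]);
const {vectors, variance, totalVariance} = PCA(input, 3, 2);

console.log(vectors.length);  // 4 projected points, each with 2 components
console.log(variance);        // explained-variance ratio of the 2 retained components
console.log(totalVariance);   // their sum, shown as a percentage by PCADetail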
......@@ -31,16 +31,20 @@ export type MetadataResult = {
metadata: string[][];
};
export type ParseFromStringParams = {
vectors: string;
interface BaseParseParams {
metadata: string;
};
maxCount?: number;
maxDimension?: number;
}
export interface ParseFromStringParams extends BaseParseParams {
vectors: string;
}
export type ParseFromBlobParams = {
export interface ParseFromBlobParams extends BaseParseParams {
shape: [number, number];
vectors: Blob;
metadata: string;
};
}
export type ParseParams =
| {
......@@ -54,14 +58,15 @@ export type ParseParams =
| null;
export type ParseResult = {
count: number;
dimension: number;
vectors: Float32Array;
labels: string[];
metadata: string[][];
};
export type PcaParams = {
input: number[];
export type PCAParams = {
input: Float32Array;
dim: number;
n: number;
};
......@@ -69,6 +74,7 @@ export type PcaParams = {
export type PCAResult = {
vectors: Vectors;
variance: number[];
totalVariance: number;
};
export type TSNEParams = {
......@@ -96,3 +102,19 @@ export type UMAPResult = {
epoch: number;
nEpochs: number;
};
export type CalculateParams =
| {
reduction: 'pca';
params: PCAParams;
}
| {
reduction: 'tsne';
params: TSNEParams;
}
| {
reduction: 'umap';
params: UMAPParams;
};
export type CalculateResult = PCAResult | TSNEResult | UMAPResult;
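CalculateResult is an undiscriminated union, so consumers either cast (as the chart page does with data as PCAResult) or narrow structurally. A minimal sketch of a structural type guard, not taken from the diff:

import type {CalculateResult, PCAResult} from '~/resource/high-dimensional';

function isPCAResult(result: CalculateResult): result is PCAResult {
    // only the PCA result carries explained-variance information
    return 'variance' in result && 'totalVariance' in result;
}

function describe(result: CalculateResult) {
    if (isPCAResult(result)) {
        return `PCA explains ${(result.totalVariance * 100).toFixed(1)}% of variance`;
    }
    return 'iterative reduction still running';
}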
......@@ -14,10 +14,19 @@
* limitations under the License.
*/
import type {TSNEParams, TSNEResult} from '~/resource/high-dimensional';
import type {
CalculateParams,
PCAParams,
PCAResult,
TSNEParams,
TSNEResult,
UMAPParams,
UMAPResult,
Vectors
} from '~/resource/high-dimensional';
import {PCA, UMAP, tSNE} from '~/resource/high-dimensional';
import {WorkerSelf} from '~/worker';
import {tSNE} from '~/resource/high-dimensional';
import type {tSNEOptions} from '~/resource/high-dimensional/tsne';
type InfoStepData = {
......@@ -34,7 +43,29 @@ export type InfoData = InfoStepData | InfoResetData | InfoParamsData;
const workerSelf = new WorkerSelf();
workerSelf.emit('INITIALIZED');
workerSelf.on<TSNEParams>('RUN', data => {
workerSelf.on<CalculateParams>('RUN', ({reduction, params}) => {
switch (reduction) {
case 'pca':
return pca(params as PCAParams);
case 'tsne':
return tsne(params as TSNEParams);
case 'umap':
return umap(params as UMAPParams);
default:
return null as never;
}
});
function pca(data: PCAParams) {
const {vectors, variance, totalVariance} = PCA(data.input, data.dim, data.n);
workerSelf.emit<PCAResult>('RESULT', {
vectors: vectors as Vectors,
variance,
totalVariance
});
}
function tsne(data: TSNEParams) {
const t_sne = new tSNE({
dimension: data.n,
perplexity: data.perplexity,
......@@ -43,7 +74,7 @@ workerSelf.on<TSNEParams>('RUN', data => {
const reset = () => {
t_sne.setData(data.input, data.dim);
return workerSelf.emit<TSNEResult>('RESULT', {
workerSelf.emit<TSNEResult>('RESULT', {
vectors: t_sne.solution as [number, number, number][],
step: t_sne.step
});
......@@ -78,4 +109,18 @@ workerSelf.on<TSNEParams>('RUN', data => {
return null as never;
}
});
});
}
function umap(data: UMAPParams) {
const result = UMAP(data.n, data.neighbors, data.input, data.dim);
if (result) {
workerSelf.emit<UMAPResult>('RESULT', {
vectors: result.embedding as [number, number, number][],
epoch: result.nEpochs,
nEpochs: result.nEpochs
});
}
workerSelf.on('INFO', () => {
workerSelf.emit('INITIALIZED');
});
}
......@@ -3325,6 +3325,11 @@
resolved "https://registry.yarnpkg.com/@types/nprogress/-/nprogress-0.2.0.tgz#86c593682d4199212a0509cc3c4d562bbbd6e45f"
integrity sha512-1cYJrqq9GezNFPsWTZpFut/d4CjpZqA0vhqDUPFWYKF1oIyBz5qnoYMzR+0C/T96t3ebLAC1SSnwrVOm5/j74A==
"@types/numeric@1.2.1":
version "1.2.1"
resolved "https://registry.yarnpkg.com/@types/numeric/-/numeric-1.2.1.tgz#6bce5d0c4f1b20f2cbd4a3d47922b8fe6e36ad56"
integrity sha512-30gQPisgZW5+ErkDVTZkoVKmwIWdjf2O6HmgKr3E1FJBdMYFldOPSJlQYP2VMafHuhOKvbLFA4Hf+ohvArz1+w==
"@types/parse-json@^4.0.0":
version "4.0.0"
resolved "https://registry.yarnpkg.com/@types/parse-json/-/parse-json-4.0.0.tgz#2f8bb441434d163b35fb8ffdccd7138927ffb8c0"
......@@ -10601,6 +10606,11 @@ number-is-nan@^1.0.0:
resolved "https://registry.yarnpkg.com/number-is-nan/-/number-is-nan-1.0.1.tgz#097b602b53422a522c1afb8790318336941a011d"
integrity sha1-CXtgK1NCKlIsGvuHkDGDNpQaAR0=
numeric@1.2.6:
version "1.2.6"
resolved "https://registry.yarnpkg.com/numeric/-/numeric-1.2.6.tgz#765b02bef97988fcf880d4eb3f36b80fa31335aa"
integrity sha1-dlsCvvl5iPz4gNTrPza4D6MTNao=
nwsapi@^2.2.0:
version "2.2.0"
resolved "https://registry.yarnpkg.com/nwsapi/-/nwsapi-2.2.0.tgz#204879a9e3d068ff2a55139c2c772780681a38b7"
......