dataFeed.es6 1013 字节
Newer Older
Y
yangmingming 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42
/**
 * @file 直接数据输入
 * @author hantianjiao@baidu.com
 */

/**
 * Loads numeric input as JSON from a URL/path and exposes it through a
 * reusable Float32Array buffer.
 */
export default class dataFeed {
    /**
     * Copy `data` element-by-element into the reusable `this.f32Arr` buffer.
     * Values beyond the buffer's capacity are dropped. Buffer slots past the
     * end of `data` are reset to 0, so values from an earlier, longer input
     * cannot leak into the current one.
     *
     * @param {ArrayLike<number>} data - values to copy; presumably a flat
     *     numeric array matching the requested shape (TODO confirm with callers).
     */
    toFloat32Array(data) {
        const buffer = this.f32Arr;
        const count = Math.min(data.length, buffer.length);
        for (let i = 0; i < count; i++) {
            buffer[i] = data[i];
        }
        buffer.fill(0, count);
    }

    /**
     * Number of elements described by a shape.
     *
     * @param {number[]} shape - dimension sizes.
     * @return {number} product of all dimensions; 1 for an empty (scalar) shape.
     */
    getLengthFromShape(shape) {
        // Seed with 1 so an empty shape does not make reduce() throw a TypeError.
        return shape.reduce((a, b) => a * b, 1);
    }

    /**
     * Fetch `this.dataPath` and parse the response body as JSON.
     *
     * @return {Promise<*>} the parsed JSON payload.
     * @throws {Error} when the HTTP response status is not in the 2xx range.
     */
    async loadData() {
        const res = await fetch(this.dataPath);
        // fetch() only rejects on network failure; surface HTTP errors explicitly
        // instead of failing later with an opaque JSON parse error.
        if (!res.ok) {
            throw new Error(`dataFeed: failed to load ${this.dataPath} (HTTP ${res.status})`);
        }
        return res.json();
    }

    /**
     * Load the data, copy it into the buffer, and wrap it as a single input named 'x'.
     *
     * @return {Promise<Array<{data: Float32Array, shape: number[], name: string}>>}
     *     one entry whose `data` is a view of exactly `this.len` elements. The view
     *     shares memory with the reusable buffer, so a later call overwrites it.
     */
    async getOutput() {
        const data = await this.loadData();
        this.toFloat32Array(data);
        return [{
            // The buffer is only ever grown (see process), so it may be larger than
            // the current shape; expose just the part belonging to this input.
            data: this.f32Arr.subarray(0, this.len),
            shape: this.shape,
            name: 'x'
        }];
    }

    /**
     * Size the buffer for `input.shape`, then load the data found at `input.input`.
     *
     * @param {{shape: number[], input: string}} input - target shape and the
     *     URL/path of the JSON data to load.
     * @return {Promise<Array<{data: Float32Array, shape: number[], name: string}>>}
     */
    async process(input) {
        const {shape, input: dataPath} = input;
        this.len = this.getLengthFromShape(shape);
        // Reuse the existing buffer when it is big enough; reallocate only to grow.
        if (!this.f32Arr || this.len > this.f32Arr.length) {
            this.f32Arr = new Float32Array(this.len);
        }
        this.shape = shape;
        this.dataPath = dataPath;
        return this.getOutput();
    }
}