// paddle.es6
/* eslint-disable */
import 'babel-polyfill';
import Loader from '../loader/loader';
import Graph from '../graph/graph';
/**
 * @file GraphModel: builds the model network graph (绘制生成model网络)
 * @author wangqun@baidu.com
 */

export default class Paddle {
    /**
     * @param {Object} options - model options. load() reads `options.urlConf` and
     *     `options.options`; the whole object is also passed to `new Graph(...)`.
     */
    constructor(options) {
        this.version = '0.0.1';
        this.loader = '';
        this.options = options;
        // Replaced by preGraph() with the Graph built from the loaded model.
        this.graph = '';
        this.multipart = false;
        // Feed data: the most recent inputs passed to execute().
        this.feed = null;
        this.index = 0;
        this.feedOp = null;
        this.feedItem = null;
        this.test = false;
        this.isExecuted = false;
        // Number of network layers (not read in this file).
        this.iLayer = 0;
        // Transport type: fetch, xhr or jsonp (not read in this file).
        this.params = {type: 'fetch'};
    }

    /**
     * Loads the model described by `this.options` and builds its graph.
     * @returns {Promise<Paddle>} this instance, once the graph has been prepared
     * @throws {Error} when no options were supplied (null or undefined)
     */
    async load() {
        // `== null` also rejects a missing (undefined) options object, which would
        // otherwise surface as an opaque TypeError when reading `options.urlConf`.
        if (this.options == null) {
            // TODO(saniac): refine this error message.
            throw new Error(
                'modelConfig passed to new Paddle() cannot be null. Please provide a url ' +
                'or an IOHandler that loads the model');
        }

        const model = new Loader(this.options.urlConf, this.options.options);
        await model.load();
        this.preGraph(model);
        return this;
    }

    /**
     * Builds a Graph from a loaded model and stores the arranged op map on it
     * as `weightMap`.
     * @param {Object} artifacts - the loaded Loader; its `data` must provide
     *     `vars` (weights) and `ops` (operators).
     */
    preGraph(artifacts) {
        const graph = new Graph(this.options);
        this.graph = graph;
        graph.data = artifacts.data;
        graph.formatWeight(graph.data.vars);
        const opsMap = graph.createOpsMap(graph.data.ops);
        const constructedOpsMap = graph.constructOpsMap(opsMap);
        graph.weightMap = graph.arrangeMap(constructedOpsMap);
    }

    /**
     * Executes inference for the model for the given input tensors.
     * @param {*} inputs - feed data; stored on both this instance and the graph
     * @returns {*} the graph's `inst` after `graph.execute(inputs)` has run
     */
    execute(inputs) {
        this.feed = this.graph.feed = inputs;
        // Build op data (skipping feed/fetch ops) until the graph reports
        // isExecuted; presumably graph.execute() sets that flag — confirm.
        if (!this.graph.isExecuted) {
            this.graph.weightMap.forEach(op => {
                if (op.type !== 'feed' && op.type !== 'fetch') {
                    this.graph.buildOpData(op);
                }
            });
        }
        this.graph.execute(inputs);
        return this.graph.inst;
    }

    /**
     * Copies the data of the first feed input into the graph's feed item.
     */
    updateFeed() {
        this.graph.feedItem.data = this.graph.feed.input[0].data;
    }

    /**
     * Delegates to the graph's dispose().
     */
    dispose() {
        this.graph.dispose();
    }
}
/* eslint-enable */