// executor.es6 (last committed by wangqun)
/* eslint-disable */
/**
 * @file GraphExecutor: wraps a single executable unit (one op of the model graph)
 * @author wangqun@baidu.com
 */
// const fileDownload = require('js-file-download');
// Timestamp (ms since epoch) of the most recent 'feed' op execution; written by
// GraphExecutor#execute. Nothing in this module reads it; it is kept as a hook for
// ad-hoc profiling.
let start;

// Op types whose primary input names live under `inputs.X`.
const X_INPUT_TYPES = new Set([
    'feed',
    'batchnorm',
    'batch_norm',
    'relu',
    'leaky_relu',
    'pool2d',
    'mul',
    'softmax',
    'scale',
    'fetch'
]);

// Op types whose primary input names live under `inputs.Input`.
const INPUT_SLOT_TYPES = new Set(['conv2d', 'depthwise_conv2d', 'conv2d_transpose']);

// Op types whose output names live under `outputs.Output`.
const OUTPUT_SLOT_TYPES = new Set(['conv2d', 'depthwise_conv2d']);

// Both spellings of the batch-norm op type; their `Y` output is renamed to `out`.
const BATCH_NORM_TYPES = new Set(['batchnorm', 'batch_norm']);

/**
 * Wraps one op of a model graph as an executable unit and exposes the names of
 * its input and output tensors.
 */
export default class GraphExecutor {

    /**
     * @param {Object} model op description. Reads `inputs`, `outputs`, `attrs`
     *     (falling back to `model['sub-attrs']`), `type` and `idx`.
     */
    constructor(model) {
        this.inputs = model.inputs;
        this.outputs = model.outputs;
        this.attrs = model.attrs || model['sub-attrs'];
        this.type = model.type;
        // NOTE(review): `finish`, `next` and `opData` are only initialised here;
        // presumably they are maintained by the graph scheduler — confirm with callers.
        this.finish = false;
        this.next = null;
        this.opData = null;
        // Weakly-unique id: timestamp + op type + random digit in [1, 10] + op index,
        // concatenated as a string.
        this.id = Date.now() + model.type + Math.floor(Math.random() * 10 + 1) + model.idx;
    }

    /**
     * Names of the tensors this op consumes, chosen according to the op type.
     * For 'elementwise_add' the X and Y input lists are concatenated.
     * Unknown op types fall back to `inputs.Input`, then `inputs.X`.
     * @returns {*} the selected input-name list (undefined if the slot is absent)
     */
    get inputsName() {
        const {type, inputs} = this;
        if (X_INPUT_TYPES.has(type)) {
            return inputs.X;
        }
        if (INPUT_SLOT_TYPES.has(type)) {
            return inputs.Input;
        }
        if (type === 'elementwise_add') {
            return inputs.X.concat(inputs.Y);
        }
        return inputs.Input || inputs.X;
    }

    /**
     * Names of the tensors this op produces, chosen according to the op type.
     *
     * Side effect: mutates `this.outputs`.
     * - For batch-norm ops, `Y` is moved to `out` (`Y` is deleted).
     * - For other ops that have a truthy `Y`, `out` is added as an alias of `Y`.
     *
     * Repeated reads return the same value.
     * @returns {*} the selected output-name list (undefined if the slot is absent)
     */
    get outputsName() {
        const outputs = this.outputs;
        if (OUTPUT_SLOT_TYPES.has(this.type)) {
            return outputs.Output;
        }
        if (BATCH_NORM_TYPES.has(this.type)) {
            // Only rename while Y is still present. Otherwise a second read would
            // overwrite `out` with undefined (Y was already deleted by the first read).
            if (outputs.Y !== undefined) {
                outputs.out = outputs.Y;
                delete outputs.Y;
            }
            return outputs.out;
        }
        if (outputs.Y) {
            // Expose Y under the `out` alias as well; Y itself is kept.
            outputs.out = outputs.Y;
            return outputs.out;
        }
        return outputs.Out || outputs.Output;
    }

    /**
     * Runs this op by delegating to the runtime, passing it the op's prepared data.
     * 'feed' ops are not run; they only record the current time in the
     * module-level `start` variable.
     * @param {Object} runtime object exposing `run(type, opData, isRendered)`
     * @param {boolean} isRendered forwarded unchanged to `runtime.run`
     */
    execute(runtime, isRendered) {
        if (this.type === 'feed') {
            start = Date.now();
            return;
        }
        runtime.run(this.type, this.opData, isRendered);
    }
}

/* eslint-enable */