// testUtils.es6 (author: wangqun)
import 'babel-polyfill';
import Paddle from '../../src/paddle/paddle';
import Utils from '../../src/utils/utils';

/**
 * Maps an op name to the JSON model fixture that exercises that op in
 * isolation. Fixtures are loaded from the `test/unitData` folder by run().
 */
const unitPath = Object.freeze({
    'conv2d': 'model.test.conv2d.json',
    'batchnorm': 'model.test.batchnorm.json',
    'mul': 'model.test.mul.json',
    'pool2d': 'model.test.pool2d.json',
    'relu': 'model.test.relu.json',
    'scale': 'model.test.scale.json',
    'softmax': 'model.test.softmax.json',
    'relu6': 'model.test.relu6.json',
    // 'elementwise' and 'elementwise_add' both point at the same fixture.
    'elementwise': 'model.test.elementwise_add.json',
    'depthwise': 'model.test.depthwise_conv2d.json',
    'reshape': 'model.test.reshape.json',
    'bilinear_interp': 'model.test.bilinear_interp.json',
    'transpose': 'model.test.transpose.json',
    'conv2d_transpose': 'model.test.conv2d_transpose.json',
    'elementwise_add': 'model.test.elementwise_add.json',
    'concat': 'model.test.concat.json',
    'split': 'model.test.split.json'
});
// Specify which op to run; must be one of the keys of `unitPath`.
const modelType = 'conv2d';

if (!Object.prototype.hasOwnProperty.call(unitPath, modelType)) {
    // Fail fast: otherwise run() would try to load a model file named "undefined".
    throw new Error(`Unknown modelType "${modelType}"; expected one of: ${Object.keys(unitPath).join(', ')}`);
}
const unitData = unitPath[modelType];

// Both are assigned by run(): `datas` is the loaded graph data and `output`
// is a deep copy of it that getResult()/getoutputs() read from.
let datas;
let output;
/**
 * Loads the single-op test model selected by `modelType`, builds op data for
 * every op except feed/fetch, starts execution from the first entry of
 * `weightMap`, reads the result, and logs it converted from NHWC to NCHW.
 *
 * Side effects: assigns the module-level `datas` and `output` variables.
 *
 * @returns {Promise<void>}
 */
async function run() {
    const path = 'test/unitData';
    const MODEL_CONFIG = {
        dir: `/${path}/`, // folder holding the model files
        main: unitData // main model file
    };

    const paddle = new Paddle({
        urlConf: MODEL_CONFIG,
        options: {
            test: true
        }
    });

    const model = await paddle.load();
    datas = model.graph.data;
    // Snapshot the graph data before building/executing. Presumably this is
    // because those steps may mutate `datas` (confirm). getResult() reads this copy.
    output = deepCopy(datas);

    model.graph.weightMap.forEach(op => {
        const type = op.type;
        if (type !== 'feed' && type !== 'fetch') {
            console.log(type);
            model.graph.buildOpData(op);
        }
    });
    const executor = model.graph.weightMap;
    model.graph.execute_(executor[0]);

    // Raw result, laid out as NHWC.
    const result = await model.graph.inst.read();

    // Convert NHWC -> NCHW using the op's expected NCHW output shape.
    const outputNCHWShape = getOutputShape();
    const outputNHWCShape = nchwShape2nhwcShape(outputNCHWShape);

    const nchwResult = Utils.nhwc2nchw(result, outputNHWCShape);
    const formatData = Utils.formatReadData(nchwResult, outputNCHWShape);

    console.log('NCHW RESULT');
    console.log(formatData);
}

run().catch(err => {
    // Report failures explicitly instead of leaving an unhandled rejection.
    console.error('unit test run failed:', err);
});

/**
 * Deep-copies JSON-serializable data by serializing it and parsing it back.
 * Anything JSON cannot represent (functions, undefined-valued members, ...)
 * is dropped or converted exactly as JSON.stringify does.
 *
 * @param {*} data - JSON-serializable value to copy.
 * @returns {*} a structurally equal, independent copy.
 */
function deepCopy (data) {
    const serialized = JSON.stringify(data);
    return JSON.parse(serialized);
}


/**
 * Finds the first op in `output.ops` whose `type` equals `id` and returns
 * getoutputs() for it, i.e. the var describing that op's output tensor.
 *
 * @param {string} id - op type to look up (run() passes `modelType`).
 * @returns {Object|undefined} the output var found by getoutputs().
 * @throws {Error} if `output` has not been populated by run() yet, or if no
 *     op of the requested type exists.
 */
const getResult = function (id) {
    if (!output) {
        throw new Error('getResult() called before run() loaded the model');
    }
    const op = output.ops.find(item => id === item.type);
    if (op === undefined) {
        throw new Error(`No op of type "${id}" found in the model`);
    }
    return getoutputs(op);
};

/**
 * Looks up a var (tensor descriptor) by name.
 *
 * @param {string} name - value compared against each var's `name`.
 * @param {{vars: Array<Object>}} datas - graph data holding a `vars` array.
 * @returns {Object|undefined} the first var whose `name` matches, or undefined.
 */
const getValue = function (name, datas) {
    return datas.vars.find(item => name === item.name);
};

// Output-slot names, compared case-insensitively, that identify an op's result.
const OUTPUT_KEYS = ['out', 'y', 'output'];

/**
 * Returns the var (from the module-level `output` data) that an op writes
 * as its result.
 *
 * The op's output slot is the first key of `data.outputs` that matches one
 * of OUTPUT_KEYS, ignoring case. If that slot lists several tensor ids, the
 * last one is used.
 *
 * @param {{type: string, outputs: Object<string, string[]>}} data - op description.
 * @returns {Object|undefined} the matching var from `output.vars`, if any.
 * @throws {Error} if the op has no output slot named in OUTPUT_KEYS.
 */
const getoutputs = function (data) {
    const outputKey = Object.keys(data.outputs)
        .find(key => OUTPUT_KEYS.includes(key.toLowerCase()));
    if (outputKey === undefined) {
        throw new Error(`Op "${data.type}" has no output slot named one of: ${OUTPUT_KEYS.join(', ')}`);
    }
    const tensorIds = data.outputs[outputKey];
    const outputTensorId = tensorIds[tensorIds.length - 1];
    return getValue(outputTensorId, output);
};
/**
 * Returns the shape recorded for the output tensor of the op selected by
 * `modelType`. run() treats this shape as NCHW.
 *
 * @returns {number[]} the output tensor's `shape` property.
 */
function getOutputShape () {
    const { shape } = getResult(modelType);
    return shape;
}

/**
 * Converts an NCHW shape to the corresponding NHWC shape.
 *
 * Shapes with fewer than 4 dimensions are first left-padded with 1s, so
 * [C, H, W] is treated as [1, C, H, W]. Only the first four entries of the
 * (padded) shape are used.
 *
 * @param {number[]} nchw - shape in NCHW order; not modified.
 * @returns {number[]} a new 4-element array [N, H, W, C].
 */
function nchwShape2nhwcShape(nchw) {
    // Math.max guards against a negative length for shapes longer than 4.
    const padding = new Array(Math.max(0, 4 - nchw.length)).fill(1);
    const [N, C, H, W] = padding.concat(nchw);
    return [N, H, W, C];
}