# -*- coding: utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import argparse
import os
import re
import struct

import cv2
import numpy as np

import megengine as mge
18 19
import megengine.core._imperative_rt as rt
import megengine.core.tensor.megbrain_graph as G
20
from megengine import tensor
21
from megengine.core._imperative_rt.core2 import apply
22
from megengine.core.ops import builtin
23
from megengine.core.tensor.megbrain_graph import VarNode
24
from megengine.utils import comp_graph_tools as cgtools
25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43

# Module-level logger, obtained through MegEngine's ``get_logger`` helper.
logger = mge.get_logger(__name__)


def auto_reformat_image(args, path, data, dst_shape):
    """Reformat a decoded image so that it matches the layout of ``dst_shape``.

    The layout is inferred from ``dst_shape``: a 3-dim shape means "no batch
    axis"; the channel axis is the one whose extent is 1 or 3, with position 1
    (NCHW) tried before position 3 (NHWC).

    :param args: parsed command line options; only ``args.resize_input`` is
        read here
    :param path: origin of the image; used only in the log message
    :param data: image data as numpy array, ``(H, W)`` or ``(H, W, C)``
    :param dst_shape: target shape with 3 or 4 dims; when empty, the shape
        ``(1, C, H, W)`` is derived from ``data``
    :return: the image as numpy array laid out like ``dst_shape``
    """
    dim3_format = False  # required input format does not contain batch
    hwc_format = False  # required input format is NHWC

    if not dst_shape:  # input tensor shape is not predefined
        if data.ndim == 2:
            h, w = data.shape
            chl = 1
        else:
            assert len(data.shape) == 3, "Input image must be of dimension 2 or 3"
            h, w, chl = data.shape
        dst_shape = (1, chl, h, w)

    if len(dst_shape) == 3:
        # tuple() so that a list-typed shape can be extended as well
        dst_shape = (1,) + tuple(dst_shape)
        dim3_format = True

    assert len(dst_shape) == 4, "bad dst_shape: {}".format(dst_shape)
    chl = dst_shape[1]
    if chl in [1, 3]:
        n, c, h, w = dst_shape
        dst_shape = (n, h, w, c)
    else:
        chl = dst_shape[3]
        assert chl in [1, 3], "can not infer input format from shape: {}".format(
            dst_shape
        )
        hwc_format = True

    # dst_shape has now been normalized to NHWC format

    if args.resize_input:
        h, w = dst_shape[1:3]
        data = cv2.resize(data, (w, h))
        logger.info("input {} resized to {}".format(path, data.shape))

    if chl == 1:
        # Only multi-channel images need a color conversion; data that is
        # already single-channel (2-D, or with a trailing axis of 1) must not
        # be passed to the BGR->gray conversion.
        if data.ndim == 3 and data.shape[2] != 1:
            data = cv2.cvtColor(data, cv2.COLOR_BGR2GRAY)
        if data.ndim == 2:
            data = data[:, :, np.newaxis]

    assert data.ndim == 3
    data = data[np.newaxis]
    # data normalized to NHWC format

    if not hwc_format:
        data = np.transpose(data, (0, 3, 1, 2))

    if dim3_format:
        data = np.squeeze(data, 0)

    return data


def read_input_data(args, dst_shape, dtype, path, repeat):
    """Produce the numpy data for one input var.

    ``path`` is either a ``#rand(min, max[, shape...])`` spec, for which
    uniformly distributed random data is generated, or a file name. A file is
    first tried as an image via ``cv2.imread``; if that fails it is loaded
    with ``np.load``.

    :param args: parsed command line options; reads ``resize_input`` and
        ``input_transform``
    :param dst_shape: shape of the input var; may be empty if unknown
    :param dtype: dtype used for randomly generated data
    :param path: ``#rand(...)`` spec or data file name
    :param repeat: number of repetitions along axis 0 for data read from a
        file (random data is not repeated)
    :return: the input data as numpy array
    :raises ValueError: if a ``#rand`` spec contains a malformed number
    """

    def check_shape_equal(dst_shape, data_shape):
        # Compare as tuples so that a list-typed shape does not produce a
        # spurious mismatch warning.
        dst_shape = tuple(dst_shape)
        data_shape = tuple(data_shape)
        if not dst_shape:
            return
        assert len(data_shape) == len(
            dst_shape
        ), "input/data shapes mismatch: {} vs {}".format(dst_shape, data_shape)

        # The leading (batch) axis is allowed to differ silently.
        if data_shape[1:] != dst_shape[1:]:
            logger.warning(
                "dst_shape is {}; data_shape is {}".format(dst_shape, data_shape)
            )

    if path.startswith("#"):
        assert not args.resize_input
        assert not args.input_transform
        spec = path
        m = re.match(r"^#rand\(([-0-9.]*)\s*,\s*([-0-9.]*)\s*(,[^\)]+)?\)$", spec)
        assert m, "bad spec {}".format(spec)

        try:
            rng_min = float(m.group(1))
            rng_max = float(m.group(2))
            if m.group(3):
                # group(3) starts with the separating comma
                shape = m.group(3)[1:].split(",")
                if shape[-1].strip() == "...":
                    # `...` stands for the remaining dims of the var shape
                    shape = shape[:-1]
                    shape.extend(list(dst_shape[len(shape) :]))
                data_shape = tuple(map(int, shape))
            else:
                data_shape = dst_shape
        except ValueError as e:
            raise ValueError("bad spec {}: {}".format(spec, e.args)) from e

        check_shape_equal(dst_shape, data_shape)
        return np.random.uniform(rng_min, rng_max, data_shape).astype(dtype)

    # try to load image
    data = cv2.imread(path, cv2.IMREAD_COLOR)
    if data is None:
        assert not args.resize_input
        data = np.load(path)
        assert isinstance(data, np.ndarray)
    else:
        # load image succeeds, so we expect input format is image format
        data = auto_reformat_image(args, path, data, dst_shape)

    data = np.repeat(data, repeat, axis=0)
    if repeat > 1:
        logger.info(
            "repeat input for {} times, data shape is {}".format(repeat, data.shape)
        )

    check_shape_equal(dst_shape, data.shape)

    if args.input_transform:
        # SECURITY: this evaluates an arbitrary python expression taken from
        # the command line; never pass untrusted text as --input-transform.
        data = eval(args.input_transform, {"data": data, "np": np})

    return data


def gen_one_testcase(args, inputs, spec):
    """Build the input data of one testcase from its textual spec.

    ``spec`` has the form ``var0:file0;var1:file1...``. A single ``#...``
    spec is applied to every input var, and the var name may be omitted when
    there is only one entry.

    :param args: parsed command line options; reads ``repeat`` and is passed
        on to :func:`read_input_data`
    :param inputs: mapping from input var name to the var (its ``shape`` and
        ``dtype`` are read)
    :param spec: the testcase spec as described above
    :return: dict mapping input var name to its numpy data
    :raises ValueError: if an entry has no ``var:`` prefix where one is needed
    """
    paths = spec.split(";")
    if len(paths) != len(inputs):
        if len(paths) == 1 and paths[0].startswith("#"):
            paths = ["{}:{}".format(name, paths[0]) for name in inputs.keys()]
    assert len(paths) == len(inputs), "required inputs: {}; data paths: {}".format(
        inputs.keys(), paths
    )
    if len(paths) == 1 and ":" not in paths[0]:
        paths[0] = next(iter(inputs.keys())) + ":" + paths[0]

    repeat = args.repeat if args.repeat else 1

    ret = {}
    for item in paths:
        # Split on the first colon only, so that the data path itself may
        # contain colons.
        var, sep, data_path = item.partition(":")
        if not sep:
            raise ValueError(
                "bad input spec {!r}: expected the form var:path".format(item)
            )
        ret[var] = read_input_data(
            args, inputs[var].shape, inputs[var].dtype, data_path, repeat
        )
    return ret


def make_feeds(args):
    """Load the model and build all testcases for it.

    The graph is loaded from ``args.input`` and its ``Host2DeviceCopy`` vars
    are taken as the model inputs. Unless ``args.no_assert`` is set, the
    graph is executed on every testcase to record the expected outputs, and
    each output is wrapped in an ``AssertEqual`` opr that compares it against
    an extra ``<name>:expect`` input.

    :param args: parsed command line options; reads ``input``, ``data``,
        ``no_assert``, ``silent`` and ``maxerr`` and is passed on to
        :func:`gen_one_testcase`
    :return: dict with key ``"outputs"`` (the output vars, wrapped by the
        assert oprs if enabled) and key ``"testcases"`` (list of dicts
        mapping input name to numpy data)
    """
    cg_rt, _, outputs = G.load_graph(args.input)
    inputs = {
        var.name: var for var in cgtools.get_dep_vars(outputs, "Host2DeviceCopy")
    }

    def expect_name(var):
        return "{}:expect".format(var.name)

    if not args.no_assert:

        replace_varmap = {}
        inp_map = {}
        # replace var use InputNode
        for name, var in inputs.items():
            inp = G.InputNode(
                device="xpux", dtype=var.dtype, shape=var.shape, graph=cg_rt
            )
            replace_varmap[var] = inp.outputs[0]
            inp_map[name] = inp

        new = cgtools.replace_vars(outputs, replace_varmap)
        if isinstance(new, rt.VarNode):
            # a single var has to be wrapped, it can not be iterated
            new = [new]

        output_nodes = [G.OutputNode(var) for var in new]
        func = cg_rt.compile([node.outputs[0] for node in output_nodes])

        def make_dev_tensor(value, dtype=None, device=None):
            return tensor(value, dtype=dtype, device=device)._dev_tensor()

        def calculate(**kwargs):
            # set inputs value
            for name, var in inputs.items():
                val = kwargs.pop(name, None)
                assert val is not None, "miss input name{}".format(name)
                dev_tensor = make_dev_tensor(val, dtype=var.dtype, device="xpux")
                inp_map[name].set_value(dev_tensor)

            func.execute()

            return [res.get_value().numpy() for res in output_nodes]

    testcases = []

    np.set_printoptions(precision=2, threshold=4, suppress=True)

    data_list = []
    for item in args.data:
        if item.startswith("@"):
            # `@file`: every non-empty line of the file is one testcase spec
            with open(item[1:], "r") as f:
                stripped = (line.rstrip() for line in f)
                data_list.extend(line for line in stripped if line != "")
        else:
            data_list.append(item)

    for inp_spec in data_list:
        cur_testcase = gen_one_testcase(args, inputs, inp_spec)
        assert len(cur_testcase) == len(
            inputs
        ), "required inputs: {}; given data: {}".format(
            inputs.keys(), cur_testcase.keys()
        )

        if not args.no_assert:
            outputs_get = calculate(**cur_testcase)
            for var, val in zip(outputs, outputs_get):
                cur_testcase[expect_name(var)] = val
                logger.info(
                    "generate test groundtruth: var={} shape={} range=({}, {})"
                    " mean={} var={}".format(
                        var, val.shape, val.min(), val.max(), np.mean(val), np.var(val)
                    )
                )
        testcases.append(cur_testcase)
        logger.info(
            "add testcase: \n {}".format(
                "\n ".join(
                    "{}: shape={} dtype={} range=({:.2f},{:.2f}) "
                    "mean={:.2f} sd={:.2f}".format(
                        k, v.shape, v.dtype, v.min(), v.max(), np.mean(v), np.std(v)
                    )
                    for k, v in sorted(cur_testcase.items())
                )
            )
        )

    if not args.no_assert:

        def expect_shp(var):
            ret = var.shape
            if ret:
                return ret
            # shape unknown in the graph: fall back to the recorded output
            return testcases[0][expect_name(var)].shape

        def assert_equal(expect, real, **kwargs):
            op = builtin.AssertEqual(**kwargs)
            (res,) = G.apply_normal_varnode(op, expect, real)
            return G.VarNode(res)

        verbose = not args.silent

        outputs_new = []
        for i in outputs:
            device = rt.CompNode("xpux")
            dtype = i.dtype
            name = expect_name(i)
            shape = expect_shp(i)
            # make expect output as one input of model.
            expect_get = rt.make_h2d(cg_rt, device, dtype, shape, name)
            # insert assert opr to check expect and real.
            outputs_new.append(
                assert_equal(
                    G.VarNode(expect_get),
                    G.VarNode(i),
                    verbose=verbose,
                    maxerr=args.maxerr,
                )
            )
            inputs[name] = expect_get
        outputs = outputs_new

    return {"outputs": outputs, "testcases": testcases}


def optimize_for_inference(args, outputs):
    """Apply the inference optimizations selected on the command line.

    Every ``enable_*`` option that is set is translated to the matching
    keyword of ``G.optimize_for_inference``. Such an option is only accepted
    together with ``args.optimize_for_inference``.

    :param args: parsed command line options; reads ``optimize_for_inference``
        and the ``enable_*`` flags (a missing flag counts as not set)
    :param outputs: output vars of the graph
    :return: ``outputs`` unchanged if no optimization is requested, otherwise
        the list of ``_node`` attributes of the optimized vars
    """
    args_map = {
        "enable_io16xc32": "f16_io_f32_comp",
        "enable_ioc16": "f16_io_comp",
        "enable_hwcd4": "use_nhwcd4",
        "enable_nchw4": "use_nchw4",
        "enable_nchw88": "use_nchw88",
        "enable_nchw44": "use_nchw44",
        "enable_nchw44_dot": "use_nchw44_dot",
        "enable_nchw32": "use_nchw32",
        "enable_chwn4": "use_chwn4",
        "enable_fuse_conv_bias_nonlinearity": "fuse_conv_bias_nonlinearity",
        "enable_fuse_conv_bias_with_z": "fuse_conv_bias_with_z",
    }
    kwargs = {}
    for k, v in args_map.items():
        # default to False so that a namespace lacking a flag is accepted
        if getattr(args, k, False):
            assert (
                args.optimize_for_inference
            ), "optimize_for_inference should be set when {} is given".format(k)
            kwargs[v] = True

    if args.optimize_for_inference:
        outputs = [i._node for i in G.optimize_for_inference(outputs, **kwargs)]

    return outputs


def main():
    """Command line entry: pack a model and its testcases into one file.

    The output file starts with the magic ``mgbtest0`` and the number of
    testcases (packed with ``struct`` format ``I``), followed by the dumped
    model graph and then one dumped graph of constants per testcase. The
    constants of a testcase are ordered by sorted input var name.
    """
    parser = argparse.ArgumentParser(
        description="Pack computing graph, input values and expected output "
        "values into one file for checking correctness. README.md gives more "
        "details on the usage",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    parser.add_argument("input", help="MegEngine dumped model file")
    parser.add_argument("-o", "--output", help="output file", required=True)
    parser.add_argument(
        "-d",
        "--data",
        default=[],
        action="append",
        required=True,
        help="Given input test data when input file is a network, "
        "and current network output would be used as groundtruth. "
        "The format is var0:file0;var1:file1... to specify data files for "
        "input vars. It can also be #rand(min,max,shape...) for generating "
        "random input data, for example, #rand(0,255), "
        "#rand(0,255,1,3,224,224) or #rand(0, 255, 1, ...) where `...` means "
        "the remaining part of the original shape. "
        "If the shape is not specified, the shape of "
        "corresponding input tensors in the network will be used. "
        "If there is only one input var, its name can be omitted. "
        "Each data file can either be an image which can be loaded by opencv, "
        "or a pickled numpy.ndarray. "
        "This option can be given multiple times to add multiple testcases. "
        " *NOTE* "
        "If you start the data with the letter @, the rest should be a "
        "filename, and each line in the file should be a single datum in "
        "the format described above. ",
    )
    parser.add_argument(
        "--repeat",
        type=int,
        default=1,
        help="Specify how many times the input image is repeated. "
        "Useful when running benchmark for batch size other than one. "
        "Have no effect on randomly generated input data.",
    )
    parser.add_argument(
        "--silent",
        action="store_true",
        help="set verbose to False in asserti_equal opr",
    )
    parser.add_argument(
        "--optimize-for-inference",
        action="store_true",
        help="enable optimization for inference",
    )
    parser.add_argument(
        "--no-assert",
        action="store_true",
        help="do not insert assert_equal opr to check result; "
        "this option is useful for benchmarking",
    )
    parser.add_argument(
        "--maxerr",
        type=float,
        default=1e-4,
        help="max error for assert_equal check during runtime",
    )
    parser.add_argument(
        "--resize-input",
        action="store_true",
        help="resize input image to fit input var shape",
    )
    parser.add_argument(
        "--input-transform",
        help="a python expression to transform the input data. "
        "Example: data / np.std(data)",
    )
    parser.add_argument(
        "--discard-var-name",
        action="store_true",
        help="discard variable and param names in the " "generated output",
    )
    parser.add_argument(
        "--output-strip-info", action="store_true", help="output code strip information"
    )
    parser.add_argument(
        "--enable-io16xc32",
        action="store_true",
        help="transform the mode to float16 io float32 compute",
    )
    parser.add_argument(
        "--enable-ioc16",
        action="store_true",
        help="transform the dtype of the model to float16 io " "and compute",
    )
    parser.add_argument(
        "--enable-fuse-conv-bias-nonlinearity",
        action="store_true",
        help="fuse convolution bias and nonlinearity opr to a "
        "conv_bias opr and compute",
    )
    parser.add_argument(
        "--enable-hwcd4",
        action="store_true",
        help="transform the model format from NCHW to NHWCD4 "
        "for inference; you may need to disable CUDA and set "
        "MGB_USE_MEGDNN_DBG=2",
    )
    parser.add_argument(
        "--enable-nchw4",
        action="store_true",
        help="transform the model format from NCHW to NCHW4 " "for inference",
    )
    parser.add_argument(
        "--enable-nchw88",
        action="store_true",
        help="transform the model format from NCHW to NCHW88 " "for inference",
    )
    parser.add_argument(
        "--enable-nchw44",
        action="store_true",
        help="transform the model format from NCHW to NCHW44 " "for inference",
    )
    parser.add_argument(
        "--enable-nchw44-dot",
        action="store_true",
        help="transform the model format from NCHW to NCHW44_DOT "
        "for optimizing armv8.2 dot in inference",
    )
    parser.add_argument(
        "--enable-nchw32",
        action="store_true",
        help="transform the model format from NCHW4 to NCHW32 "
        "for inference on nvidia TensoCore",
    )
    parser.add_argument(
        "--enable-chwn4",
        action="store_true",
        help="transform the model format to CHWN4 "
        "for inference, mainly used for nvidia tensorcore",
    )
    parser.add_argument(
        "--enable-fuse-conv-bias-with-z",
        action="store_true",
        help="fuse conv_bias with z input for inference on "
        "nvidia GPU (this optimization pass will result in mismatch "
        "of the precision of output of training and inference)",
    )
    args = parser.parse_args()

    feeds = make_feeds(args)

    assert isinstance(feeds, dict) and feeds["testcases"], "testcases can not be empty"

    output_mgbvars = feeds["outputs"]
    output_mgbvars = optimize_for_inference(args, output_mgbvars)

    # The sorted (name, dtype) pairs fix the order in which the values of a
    # testcase are written below.
    inputs = cgtools.get_dep_vars(output_mgbvars, "Host2DeviceCopy")
    inputs = sorted((i.name, i.dtype) for i in inputs)

    if args.discard_var_name:
        sereg_kwargs = dict(keep_var_name=0, keep_param_name=False)
    else:
        sereg_kwargs = dict(keep_var_name=2, keep_param_name=True)

    strip_info_file = args.output + ".json" if args.output_strip_info else None

    with open(args.output, "wb") as fout:
        fout.write(b"mgbtest0")
        fout.write(struct.pack("I", len(feeds["testcases"])))
        if isinstance(output_mgbvars, dict):
            # iterate the items; iterating the dict itself yields only keys
            wrap_output_vars = {
                key: VarNode(var) for key, var in output_mgbvars.items()
            }
        else:
            wrap_output_vars = [VarNode(i) for i in output_mgbvars]
        dump_content, stat = G.dump_graph(
            wrap_output_vars,
            append_json=True,
            strip_info_file=strip_info_file,
            **sereg_kwargs,
        )
        fout.write(dump_content)

        logger.info(
            "graph dump sizes: tot_size={:.3f}KiB overhead={:.3f}KiB".format(
                stat.tot_bytes / 1024, (stat.tot_bytes - stat.tensor_value_bytes) / 1024
            )
        )

    def make_dev_tensor(value, dtype=None, device=None):
        return tensor(value, dtype=dtype, device=device)._dev_tensor()

    # Append all testcases through one file handle instead of reopening the
    # output file for each of them.
    with open(args.output, "ab") as fout:
        for testcase in feeds["testcases"]:
            assert isinstance(testcase, dict)
            cg = G.Graph()
            output_mgbvars = []
            for name, dtype in inputs:
                output_mgbvars.append(
                    cg.make_const(
                        make_dev_tensor(testcase.pop(name), dtype=dtype, device="cpux")
                    )
                )
            # every value must have been consumed by a graph input
            assert not testcase, "extra inputs provided in testcase: {}".format(
                testcase.keys()
            )
            dump_content, _ = G.dump_graph(
                output_mgbvars, strip_info_file=strip_info_file, append_json=True
            )
            fout.write(dump_content)


# Run the command line tool only when this file is executed as a script.
if __name__ == "__main__":
    main()