def _make_feed(
    self,
    graph,
    outputs,
    input_data,
    repeat,
    silent,
    no_assert,
    maxerr,
    resize_input,
    input_transform,
):
    """Build testcases for ``outputs`` and optionally guard them with asserts.

    The ``Host2DeviceCopy`` vars that ``outputs`` depend on are treated as the
    network inputs. For every input spec a dict ``{input name: ndarray}`` is
    produced. Unless ``no_assert`` is set, the graph is additionally executed
    on that data, the results are stored in the testcase under
    ``"<output name>:expect"`` and every output is replaced by an
    ``AssertEqual`` opr comparing it with a new ``"<output name>:expect"``
    host-to-device input.

    Args:
        graph: graph that owns ``outputs``; used to create the input nodes,
            to compile the groundtruth function and to create the expect vars.
        outputs: output var nodes of the network.
        input_data: a spec string or an iterable of spec strings. A spec is
            ``"var0:file0;var1:file1..."`` or ``"#rand(min,max[,shape...])"``;
            an item starting with ``@`` names a text file holding one spec
            per non-empty line.
        repeat: number of times loaded (non-random) data is repeated along
            axis 0.
        silent: if true, ``AssertEqual`` is created with ``verbose=False``.
        no_assert: if true, no groundtruth is computed and no ``AssertEqual``
            opr is inserted.
        maxerr: ``maxerr`` passed to ``AssertEqual``.
        resize_input: resize loaded images to the input var shape.
        input_transform: python expression evaluated with ``data`` and ``np``
            in scope; its value replaces the loaded data.

    Returns:
        ``{"outputs": <possibly replaced outputs>, "testcases": <list of dict>}``.

    Raises:
        AssertionError: on malformed specs, unknown or missing input names,
            shape rank mismatch, or when no testcase is given.
        ValueError: when the shape part of a ``#rand`` spec is not numeric.
    """

    def auto_reformat_image(path, data, dst_shape):
        """reformat image to target shape

        :param data: image data as numpy array
        :param dst_shape: target shape
        """
        dim3_format = False  # required input format does not contain batch
        hwc_format = False  # required input format is NHWC

        if not dst_shape:  # input tensor shape is not predefined
            if len(data.shape) == 2:
                chl = 1
                h = data.shape[0]
                w = data.shape[1]
            else:
                assert (
                    len(data.shape) == 3
                ), "Input image must be of dimension 2 or 3"
                h, w, chl = data.shape
            dst_shape = (1, chl, h, w)

        if len(dst_shape) == 3:
            dst_shape = (1,) + dst_shape
            dim3_format = True

        assert len(dst_shape) == 4, "bad dst_shape: {}".format(dst_shape)
        chl = dst_shape[1]
        if chl in [1, 3]:
            n, c, h, w = dst_shape
            dst_shape = (n, h, w, c)
        else:
            chl = dst_shape[3]
            assert chl in [
                1,
                3,
            ], "can not infer input format from shape: {}".format(dst_shape)
            hwc_format = True

        # dst_shape has now been normalized to NHWC format

        if resize_input:
            h, w = dst_shape[1:3]
            data = cv2.resize(data, (w, h))
            logger.info("input {} resized to {}".format(path, data.shape))

        if chl == 1:
            data = cv2.cvtColor(data, cv2.COLOR_BGR2GRAY)
            data = data[:, :, np.newaxis]

        assert data.ndim == 3
        data = data[np.newaxis]
        # data normalized to NHWC format

        if not hwc_format:
            data = np.transpose(data, (0, 3, 1, 2))

        if dim3_format:
            data = np.squeeze(data, 0)

        return data

    def read_input_data(dst_shape, dtype, path):
        """Produce the ndarray for one input var from a ``#rand`` spec or a file."""

        def check_shape_equal(dst_shape, data_shape):
            # Rank must match when the var has a known shape; a differing
            # non-batch shape is only reported, not rejected.
            if len(dst_shape):
                assert len(data_shape) == len(
                    dst_shape
                ), "input/data shapes mismatch: {} vs {}".format(
                    dst_shape, data_shape
                )

                if data_shape[1:] != dst_shape[1:]:
                    logger.warning(
                        "dst_shape is {}; data_shape is {}".format(
                            dst_shape, data_shape
                        )
                    )

        if path.startswith("#"):
            assert not resize_input
            assert not input_transform
            spec = path
            m = re.match(
                r"^#rand\(([-0-9.]*)\s*,\s*([-0-9.]*)\s*(,[^\)]+)?\)$", spec
            )
            assert m, "bad spec {}".format(spec)

            rng_min = float(m.group(1))
            rng_max = float(m.group(2))
            if m.group(3):
                shape_str = m.group(3)
                try:
                    shape = shape_str[1:].split(",")
                    if shape[-1].strip() == "...":
                        # "..." stands for the remaining dims of the var shape
                        shape = shape[:-1]
                        shape.extend(list(dst_shape[len(shape) :]))
                    data_shape = tuple(map(int, shape))
                except ValueError as e:
                    raise ValueError("bad spec {}: {}".format(spec, e.args))
            else:
                data_shape = dst_shape

            check_shape_equal(dst_shape, data_shape)
            return np.random.uniform(rng_min, rng_max, data_shape).astype(dtype)

        # try to load image
        data = cv2.imread(path, cv2.IMREAD_COLOR)
        if data is None:
            assert not resize_input
            data = np.load(path)
            assert isinstance(data, np.ndarray)
        else:
            # load image succeeds, so we expect input format is image format
            data = auto_reformat_image(path, data, dst_shape)

        data = np.repeat(data, repeat, axis=0)
        if repeat > 1:
            logger.info(
                "repeat input for {} times, data shape is {}".format(
                    repeat, data.shape
                )
            )

        check_shape_equal(dst_shape, data.shape)

        if input_transform:
            # SECURITY: eval runs arbitrary Python. ``input_transform`` must
            # only ever come from a trusted caller, never from external data.
            data = eval(input_transform, {"data": data, "np": np})

        return data

    def gen_one_testcase(inputs, spec):
        """Turn one spec string into ``{input name: ndarray}``."""
        paths = spec.split(";")
        if len(paths) != len(inputs):
            # a single "#rand(...)" spec is applied to every input
            if len(paths) == 1 and paths[0].startswith("#"):
                paths = ["{}:{}".format(name, paths[0]) for name in inputs.keys()]
        assert len(paths) == len(
            inputs
        ), "required inputs: {}; data paths: {}".format(inputs.keys(), paths)
        if len(paths) == 1 and ":" not in paths[0]:
            # the name of the only input may be omitted
            paths[0] = next(iter(inputs.keys())) + ":" + paths[0]

        ret = {}
        for item in paths:
            # Split only once so that the data part may itself contain ":"
            # (e.g. a drive letter) instead of failing to unpack.
            var, _, path = item.partition(":")
            assert var in inputs, "unknown input var {}; required inputs: {}".format(
                var, inputs.keys()
            )
            ret[var] = read_input_data(inputs[var].shape, inputs[var].dtype, path)
        return ret

    # A bare string would otherwise be iterated character by character.
    if isinstance(input_data, str):
        input_data = [input_data]

    inputs = cgtools.get_dep_vars(outputs, "Host2DeviceCopy")
    inputs = {i.name: i for i in inputs}

    if not no_assert:

        replace_varmap = {}
        inp_map = {}
        # replace var use InputNode
        for name, var in inputs.items():
            inp = G.InputNode(
                device="xpux", dtype=var.dtype, shape=var.shape, graph=graph
            )
            replace_varmap[var] = inp.outputs[0]._node
            inp_map[name] = inp

        new = cgtools.replace_vars(outputs, replace_varmap)
        if isinstance(new, rt.VarNode):
            # a single node must be wrapped, not iterated
            new = [new]

        output_nodes = [G.OutputNode(var) for var in new]
        func = graph.compile(*[node.outputs[0]._node for node in output_nodes])

        def make_dev_tensor(value, dtype=None, device=None):
            return tensor(value, dtype=dtype, device=device)._dev_tensor()

        def calculate(**kwargs):
            """Run the graph on the given named inputs; return output ndarrays."""
            # set inputs value
            for name, var in inputs.items():
                val = kwargs.pop(name, None)
                assert val is not None, "miss input name{}".format(name)
                dev_tensor = make_dev_tensor(val, dtype=var.dtype, device="xpux")
                inp_map[name].set_value(dev_tensor)

            func.execute()

            return [res.get_value().numpy() for res in output_nodes]

    def expect_name(var):
        return "{}:expect".format(var.name)

    testcases = []

    # NOTE(review): this changes numpy's process-wide print options and is
    # never restored; kept as in the original implementation.
    np.set_printoptions(precision=2, threshold=4, suppress=True)

    data_list = []
    for item in input_data:
        if item.startswith("@"):
            with open(item[1:], "r") as f:
                data_list.extend(
                    [line.rstrip() for line in f if line.rstrip() != ""]
                )
        else:
            data_list.append(item)

    # Fail early with the caller's message instead of an IndexError in
    # expect_shp() below.
    assert data_list, "testcases can not be empty"

    for inp_spec in data_list:
        cur_testcase = gen_one_testcase(inputs, inp_spec)
        assert len(cur_testcase) == len(
            inputs
        ), "required inputs: {}; given data: {}".format(
            inputs.keys(), cur_testcase.keys()
        )

        if not no_assert:
            outputs_get = calculate(**cur_testcase)
            for var, val in zip(outputs, outputs_get):
                cur_testcase[expect_name(var)] = val
                logger.info(
                    "generate test groundtruth: var={} shape={} range=({}, {})"
                    " mean={} var={}".format(
                        var,
                        val.shape,
                        val.min(),
                        val.max(),
                        np.mean(val),
                        np.var(val),
                    )
                )
        testcases.append(cur_testcase)
        logger.info(
            "add testcase: \n {}".format(
                "\n ".join(
                    "{}: shape={} dtype={} range=({:.2f},{:.2f}) "
                    "mean={:.2f} sd={:.2f}".format(
                        k, v.shape, v.dtype, v.min(), v.max(), np.mean(v), np.std(v)
                    )
                    for k, v in sorted(cur_testcase.items())
                )
            )
        )

    if not no_assert:

        def expect_shp(var):
            # fall back to the groundtruth shape when the var shape is empty
            ret = var.shape
            if ret:
                return ret
            return testcases[0][expect_name(var)].shape

        def assert_equal(expect, real, **kwargs):
            op = AssertEqual(**kwargs)
            (res,) = G.apply_normal_varnode(op, expect, real)
            return res._node

        verbose = not silent

        outputs_new = []
        for i in outputs:
            device = rt.CompNode("xpux")
            dtype = i.dtype
            name = expect_name(i)
            shape = expect_shp(i)
            # make expect output as one input of model.
            expect_get = rt.make_h2d(graph, device, dtype, shape, name)
            # insert assert opr to check expect and real.
            outputs_new.append(
                assert_equal(expect_get, i, verbose=verbose, maxerr=maxerr)
            )
            inputs[expect_name(i)] = expect_get
        outputs = outputs_new

    return {"outputs": outputs, "testcases": testcases}
If the shape is not + specified, the shape of corresponding input tensors in the network will be used. + If there is only one input var, its name can be omitted. Each data file can either + be an image which can be loaded by opencv, or a numpy.ndarray file loadable by `numpy.load`. Each item + of `input_data` describes one testcase. If you start the data + with the letter @, the rest should be a filename, and each line in the file should + be a single datum in the format described above. *NOTE* If `input_data` is not None, + you can only use load-and-run to run the output file. + repeat: how many times the input image is repeated. Useful when running benchmark for + batch size other than one. Has no effect on randomly generated input data. + silent: whether to set verbose to False in the assert_equal opr. + no_assert: whether to skip inserting the assert_equal opr that checks the result; this is useful for + benchmarking. + maxerr: max error for assert_equal check during runtime. + resize_input: whether to resize the input image to fit the input var shape. + input_transform: a python expression to transform the input data. 
+ Example: data / np.std(data) Keyword Arguments: @@ -778,6 +1099,8 @@ class trace: input for inference on nvidia backend(this optimization pass will result in mismatch of the precision of output of training and inference) + * enable_fuse_preprocess: whether to fuse astype\pad_channel\dimshuffle and + etc opr """ if not self._capture_as_const: raise ValueError( @@ -892,8 +1215,28 @@ class trace: v.name = output_names[i] dest_vars.append(v) + dest_vars = [i._node for i in dest_vars] + + if input_data is not None: + feeds = self._make_feed( + graph, + dest_vars, + input_data, + repeat, + silent, + no_assert, + maxerr, + resize_input, + input_transform, + ) + assert ( + isinstance(feeds, dict) and feeds["testcases"] + ), "testcases can not be empty" + dest_vars = feeds["outputs"] + if optimize_for_inference: dest_vars, optimize_options = G.optimize_for_inference(dest_vars, **kwargs) + dest_vars = [i._node for i in dest_vars] metadata = SerializationMetadata() if enable_metadata: @@ -910,6 +1253,9 @@ class trace: if keep_opr_priority: graph._set_priority_to_id(dest_vars) + if input_data is not None: + file.write(b"mgbtest0") + file.write(struct.pack("I", len(feeds["testcases"]))) dump_content, dump_info = G.dump_graph( dest_vars, keep_var_name=keep_var_name, @@ -921,6 +1267,34 @@ class trace: metadata=metadata, ) file.write(dump_content) + + if input_data is not None: + inputs = cgtools.get_dep_vars(dest_vars, "Host2DeviceCopy") + inputs = sorted((i.name, i.dtype) for i in inputs) + + def make_dev_tensor(value, dtype=None, device=None): + return tensor(value, dtype=dtype, device=device)._dev_tensor() + + for testcase in feeds["testcases"]: + assert isinstance(testcase, dict) + cg = G.Graph() + output_mgbvars = [] + for name, dtype in inputs: + output_mgbvars.append( + cg.make_const( + make_dev_tensor( + testcase.pop(name), dtype=dtype, device="cpux" + ) + ) + ) + assert not testcase, "extra inputs provided in testcase: {}".format( + testcase.keys() + ) + dump_content, 
def test_dump_with_testcase():
    """Dump a traced function together with one random testcase.

    Besides checking that ``dump`` does not raise, verify the header that
    ``dump`` writes when ``input_data`` is given: the ``b"mgbtest0"`` magic
    followed by the number of testcases packed with ``struct.pack("I", ...)``.
    """
    import struct

    @trace(symbolic=True, capture_as_const=True)
    def f(x):
        return exp(x)

    f(tensor(1.0))
    file = io.BytesIO()
    f.dump(file, input_data=["#rand(0, 255, 1)"])

    content = file.getvalue()
    magic = b"mgbtest0"
    assert content.startswith(magic)
    # exactly one spec was given, so exactly one testcase must be recorded
    (num_testcases,) = struct.unpack_from("I", content, len(magic))
    assert num_testcases == 1
    # the graph and the testcase payload follow the header
    assert len(content) > len(magic) + struct.calcsize("I")
-import argparse -import os -import re -import struct - -import cv2 -import numpy as np - -import megengine as mge -import megengine.core._imperative_rt as rt -import megengine.core.tensor.megbrain_graph as G -from megengine import tensor -from megengine.core._imperative_rt.core2 import apply -from megengine.core.ops import builtin -from megengine.utils import comp_graph_tools as cgtools - -logger = mge.get_logger(__name__) - - -def auto_reformat_image(args, path, data, dst_shape): - """reformat image to target shape - - :param data: image data as numpy array - :param dst_shape: target shape - """ - dim3_format = False # required input format does not contain batch - hwc_format = False # required input format is NHWC - - if not dst_shape: # input tensor shape is not predefined - if len(data.shape) == 2: - chl = 1 - h = data.shape[0] - w = data.shape[1] - else: - assert len(data.shape) == 3, "Input image must be of dimension 2 or 3" - h, w, chl = data.shape - dst_shape = (1, chl, h, w) - - if len(dst_shape) == 3: - dst_shape = (1,) + dst_shape - dim3_format = True - - assert len(dst_shape) == 4, "bad dst_shape: {}".format(dst_shape) - chl = dst_shape[1] - if chl in [1, 3]: - n, c, h, w = dst_shape - dst_shape = (n, h, w, c) - else: - chl = dst_shape[3] - assert chl in [1, 3], "can not infer input format from shape: {}".format( - dst_shape - ) - hwc_format = True - - # dst_shape has now been normalized to NHWC format - - if args.resize_input: - h, w = dst_shape[1:3] - data = cv2.resize(data, (w, h)) - logger.info("input {} resized to {}".format(path, data.shape)) - - if chl == 1: - data = cv2.cvtColor(data, cv2.COLOR_BGR2GRAY) - data = data[:, :, np.newaxis] - - assert data.ndim == 3 - data = data[np.newaxis] - # data normalized to NHWC format - - if not hwc_format: - data = np.transpose(data, (0, 3, 1, 2)) - - if dim3_format: - data = np.squeeze(data, 0) - - return data - - -def read_input_data(args, dst_shape, dtype, path, repeat): - def 
check_shape_equal(dst_shape, data_shape): - if len(dst_shape): - assert len(data_shape) == len( - dst_shape - ), "input/data shapes mismatch: {} vs {}".format(dst_shape, data_shape) - - if data_shape[1:] != dst_shape[1:]: - logger.warning( - "dst_shape is {}; data_shape is {}".format(dst_shape, data_shape) - ) - - if path.startswith("#"): - assert not args.resize_input - assert not args.input_transform - spec = path - m = re.match(r"^#rand\(([-0-9.]*)\s*,\s*([-0-9.]*)\s*(,[^\)]+)?\)$", spec) - assert m, "bad spec {}".format(spec) - - rng_min = float(m.group(1)) - rng_max = float(m.group(2)) - if m.group(3): - shape_str = m.group(3) - try: - shape = shape_str[1:].split(",") - if shape[-1].strip() == "...": - shape = shape[:-1] - shape.extend(list(dst_shape[len(shape) :])) - data_shape = tuple(map(int, shape)) - except ValueError as e: - raise ValueError("bad spec {}: {}".format(spec, e.args)) - else: - data_shape = dst_shape - - check_shape_equal(dst_shape, data_shape) - return np.random.uniform(rng_min, rng_max, data_shape).astype(dtype) - - # try to load image - data = cv2.imread(path, cv2.IMREAD_COLOR) - if data is None: - assert not args.resize_input - data = np.load(path) - assert isinstance(data, np.ndarray) - else: - # load image succeeds, so we expect input format is image format - data = auto_reformat_image(args, path, data, dst_shape) - - data = np.repeat(data, repeat, axis=0) - if repeat > 1: - logger.info( - "repeat input for {} times, data shape is {}".format(repeat, data.shape) - ) - - check_shape_equal(dst_shape, data.shape) - - if args.input_transform: - data = eval(args.input_transform, {"data": data, "np": np}) - - return data - - -def gen_one_testcase(args, inputs, spec): - paths = spec.split(";") - if len(paths) != len(inputs): - if len(paths) == 1 and paths[0].startswith("#"): - paths = ["{}:{}".format(name, paths[0]) for name in inputs.keys()] - assert len(paths) == len(inputs), "required inputs: {}; data paths: {}".format( - inputs.keys(), 
paths - ) - if len(paths) == 1 and ":" not in paths[0]: - paths[0] = next(iter(inputs.keys())) + ":" + paths[0] - - ret = {} - for path in paths: - var, path = path.split(":") - if args.repeat: - repeat = args.repeat - else: - repeat = 1 - ret[var] = read_input_data( - args, inputs[var].shape, inputs[var].dtype, path, repeat - ) - return ret - - -def make_feeds(args): - ret = G.load_graph(args.input) - cg_rt, outputs = ret.graph, ret.output_vars_list - inputs = cgtools.get_dep_vars(outputs, "Host2DeviceCopy") - - inputs = {i.name: i for i in inputs} - if not args.no_assert: - - replace_varmap = {} - inp_map = {} - # replace var use InputNode - for name, var in inputs.items(): - inp = G.InputNode( - device="xpux", dtype=var.dtype, shape=var.shape, graph=cg_rt - ) - replace_varmap[var] = inp.outputs[0] - inp_map[name] = inp - - new = cgtools.replace_vars(outputs, replace_varmap) - if isinstance(new, rt.VarNode): - new = list(new) - - output_nodes = [G.OutputNode(var) for var in new] - func = cg_rt.compile([node.outputs[0] for node in output_nodes]) - - def make_dev_tensor(value, dtype=None, device=None): - return tensor(value, dtype=dtype, device=device)._dev_tensor() - - def calculate(*args, **kwargs): - output_val = [] - # set inputs value - for name, var in inputs.items(): - val = kwargs.pop(name, None) - assert val is not None, "miss input name{}".format(name) - dev_tensor = make_dev_tensor(val, dtype=var.dtype, device="xpux") - inp_map[name].set_value(dev_tensor) - - func.execute() - - for res in output_nodes: - output_val.append(res.get_value().numpy()) - return output_val - - def expect_name(var): - return "{}:expect".format(var.name) - - testcases = [] - - np.set_printoptions(precision=2, threshold=4, suppress=True) - - data_list = [] - for item in args.data: - if item.startswith("@"): - with open(item[1:], "r") as f: - data_list.extend([line.rstrip() for line in f if line.rstrip() != ""]) - else: - data_list.append(item) - - for inp_spec in data_list: - 
cur_testcase = gen_one_testcase(args, inputs, inp_spec) - assert len(cur_testcase) == len( - inputs - ), "required inputs: {}; given data: {}".format( - inputs.keys(), cur_testcase.keys() - ) - - if not args.no_assert: - outputs_get = calculate(**cur_testcase) - for var, val in zip(outputs, outputs_get): - cur_testcase[expect_name(var)] = val - logger.info( - "generate test groundtruth: var={} shape={} range=({}, {})" - " mean={} var={}".format( - var, val.shape, val.min(), val.max(), np.mean(val), np.var(val) - ) - ) - testcases.append(cur_testcase) - logger.info( - "add testcase: \n {}".format( - "\n ".join( - "{}: shape={} dtype={} range=({:.2f},{:.2f}) " - "mean={:.2f} sd={:.2f}".format( - k, v.shape, v.dtype, v.min(), v.max(), np.mean(v), np.std(v) - ) - for k, v in sorted(cur_testcase.items()) - ) - ) - ) - - if not args.no_assert: - - def expect_shp(var): - ret = var.shape - if ret: - return ret - return testcases[0][expect_name(var)].shape - - def assert_equal(expect, real, **kwargs): - op = builtin.AssertEqual(**kwargs) - (res,) = apply(op, expect, real) - return res - - verbose = not args.silent - - outputs_new = [] - for i in outputs: - device = rt.CompNode("xpux") - dtype = i.dtype - name = expect_name(i) - shape = expect_shp(i) - # make expect output as one input of model. - expect_get = rt.make_h2d(cg_rt, device, dtype, shape, name) - # insert assert opr to check expect and real. 
- outputs_new.append( - assert_equal( - expect_get, - i, - verbose=verbose, - maxerr=args.maxerr, - ) - ) - inputs[expect_name(i)] = expect_get - outputs = outputs_new - - return {"outputs": outputs, "testcases": testcases} - - -def optimize_for_inference(args, outputs): - args_list = [ - "enable_io16xc32", - "enable_ioc16", - "enable_hwcd4", - "enable_nchw4", - "enable_nchw88", - "enable_nchw44", - "enable_nchw44_dot", - "enable_nchw32", - "enable_chwn4", - "enable_fuse_conv_bias_nonlinearity", - "enable_fuse_conv_bias_with_z", - "enable_fuse_preprocess", - ] - kwargs = {} - for k in args_list: - if getattr(args, k): - kwargs[k] = True - - if args.optimize_for_inference: - outputs = G.optimize_for_inference(outputs, **kwargs) - - return outputs - - -def main(): - parser = argparse.ArgumentParser( - description="Pack computing graph, input values and expected output " - "values into one file for checking correctness. README.md gives more " - "details on the usage", - formatter_class=argparse.ArgumentDefaultsHelpFormatter, - ) - parser.add_argument("input", help="MegEngine dumped model file") - parser.add_argument("-o", "--output", help="output file", required=True) - parser.add_argument( - "-d", - "--data", - default=[], - action="append", - required=True, - help="Given input test data when input file is a network, " - "and current network output would be used as groundtruth. " - "The format is var0:file0;var1:file1... to specify data files for " - "input vars. It can also be #rand(min,max,shape...) for generating " - "random input data, for example, #rand(0,255), " - "#rand(0,255,1,3,224,224) or #rand(0, 255, 1, ...) where `...` means " - "the remaining part of the original shape. " - "If the shape is not specified, the shape of " - "corresponding input tensors in the network will be used. " - "If there is only one input var, its name can be omitted. " - "Each data file can either be an image which can be loaded by opencv, " - "or a pickled numpy.ndarray. 
" - "This option can be given multiple times to add multiple testcases. " - " *NOTE* " - "If you start the data with the letter @, the rest should be a " - "filename, and each line in the file should be a single datum in " - "the format described above. ", - ) - parser.add_argument( - "--repeat", - type=int, - default=1, - help="Specify how many times the input image is repeated. " - "Useful when running benchmark for batch size other than one. " - "Have no effect on randomly generated input data.", - ) - parser.add_argument( - "--silent", - action="store_true", - help="set verbose to False in asserti_equal opr", - ) - parser.add_argument( - "--optimize-for-inference", - action="store_true", - help="enable optimization for inference", - ) - parser.add_argument( - "--no-assert", - action="store_true", - help="do not insert assert_equal opr to check result; " - "this option is useful for benchmarking", - ) - parser.add_argument( - "--maxerr", - type=float, - default=1e-4, - help="max error for assert_equal check during runtime", - ) - parser.add_argument( - "--resize-input", - action="store_true", - help="resize input image to fit input var shape", - ) - parser.add_argument( - "--input-transform", - help="a python expression to transform the input data. 
" - "Example: data / np.std(data)", - ) - parser.add_argument( - "--discard-var-name", - action="store_true", - help="discard variable and param names in the " "generated output", - ) - parser.add_argument( - "--output-strip-info", action="store_true", help="output code strip information" - ) - parser.add_argument( - "--enable-io16xc32", - action="store_true", - help="transform the mode to float16 io float32 compute", - ) - parser.add_argument( - "--enable-ioc16", - action="store_true", - help="transform the dtype of the model to float16 io " "and compute", - ) - parser.add_argument( - "--enable-fuse-conv-bias-nonlinearity", - action="store_true", - help="fuse convolution bias and nonlinearity opr to a " - "conv_bias opr and compute", - ) - parser.add_argument( - "--enable-hwcd4", - action="store_true", - help="transform the model format from NCHW to NHWCD4 " - "for inference; you may need to disable CUDA and set " - "MGB_USE_MEGDNN_DBG=2", - ) - parser.add_argument( - "--enable-nchw4", - action="store_true", - help="transform the model format from NCHW to NCHW4 " "for inference", - ) - parser.add_argument( - "--enable-nchw88", - action="store_true", - help="transform the model format from NCHW to NCHW88 " "for inference", - ) - parser.add_argument( - "--enable-nchw44", - action="store_true", - help="transform the model format from NCHW to NCHW44 " "for inference", - ) - parser.add_argument( - "--enable-nchw44-dot", - action="store_true", - help="transform the model format from NCHW to NCHW44_DOT " - "for optimizing armv8.2 dot in inference", - ) - parser.add_argument( - "--enable-nchw32", - action="store_true", - help="transform the model format from NCHW4 to NCHW32 " - "for inference on nvidia TensoCore", - ) - parser.add_argument( - "--enable-chwn4", - action="store_true", - help="transform the model format to CHWN4 " - "for inference, mainly used for nvidia tensorcore", - ) - parser.add_argument( - "--enable-fuse-conv-bias-with-z", - action="store_true", - 
help="fuse conv_bias with z input for inference on " - "nvidia GPU (this optimization pass will result in mismatch " - "of the precision of output of training and inference)", - ) - parser.add_argument( - "--enable-fuse-preprocess", - action="store_true", - help="fuse astype\pad_channel\dimshuffle and etc opr " - "from h2d opr", - ) - args = parser.parse_args() - - feeds = make_feeds(args) - - assert isinstance(feeds, dict) and feeds["testcases"], "testcases can not be empty" - - output_mgbvars = feeds["outputs"] - output_mgbvars = optimize_for_inference(args, output_mgbvars) - - inputs = cgtools.get_dep_vars(output_mgbvars, "Host2DeviceCopy") - inputs = sorted((i.name, i.dtype) for i in inputs) - - if args.discard_var_name: - sereg_kwargs = dict(keep_var_name=0, keep_param_name=False) - else: - sereg_kwargs = dict(keep_var_name=2, keep_param_name=True) - - strip_info_file = args.output + ".json" if args.output_strip_info else None - - with open(args.output, "wb") as fout: - fout.write(b"mgbtest0") - fout.write(struct.pack("I", len(feeds["testcases"]))) - dump_content, stat = G.dump_graph( - output_mgbvars, - append_json=True, - strip_info_file=strip_info_file, - **sereg_kwargs, - ) - fout.write(dump_content) - - logger.info( - "graph dump sizes: tot_size={:.3f}KiB overhead={:.3f}KiB".format( - stat.tot_bytes / 1024, (stat.tot_bytes - stat.tensor_value_bytes) / 1024 - ) - ) - - def make_dev_tensor(value, dtype=None, device=None): - return tensor(value, dtype=dtype, device=device)._dev_tensor() - - for testcase in feeds["testcases"]: - assert isinstance(testcase, dict) - cg = G.Graph() - output_mgbvars = [] - for name, dtype in inputs: - output_mgbvars.append( - cg.make_const( - make_dev_tensor(testcase.pop(name), dtype=dtype, device="cpux") - ) - ) - assert not testcase, "extra inputs provided in testcase: {}".format( - testcase.keys() - ) - with open(args.output, "ab") as fout: - dump_content, _ = G.dump_graph( - output_mgbvars, strip_info_file=strip_info_file, 
append_json=True - ) - fout.write(dump_content) - - -if __name__ == "__main__": - main() -- GitLab