提交 4f6c5d8f 编写于 作者: M Megvii Engine Team

feat(mge/dump): enable jit.dump to dump with testcase

GitOrigin-RevId: 5dce3564529c9a04f118a599637237f68a101e77
上级 182ca25d
......@@ -13,10 +13,17 @@ import itertools
import json
import os
import pickle
import re
import struct
from typing import Any
import cv2
import numpy as np
from megengine.logger import get_logger
from .. import tensor
from ..core import _imperative_rt as rt
from ..core._imperative_rt import GraphProfiler, GraphProfiler2, SerializationMetadata
from ..core._imperative_rt.core2 import Tensor as RawTensor
from ..core._imperative_rt.core2 import (
......@@ -38,12 +45,15 @@ from ..core._wrap import as_device
from ..core.ops.builtin import BatchNorm, OpDef
from ..core.tensor import megbrain_graph as G
from ..core.tensor.utils import setscalar
from ..utils import comp_graph_tools as cgtools
from ..utils.naming import AutoNaming
from ..utils.profiler import is_profiling
from .dtr_config import DTRConfig
from .graph_opt_config import GraphOptimizationConfig
from .sublinear_memory_config import SublinearMemoryConfig
logger = get_logger(__name__)
def _input_node_use_static_shape():
return os.environ.get("MEGENGINE_INPUT_NODE_USE_STATIC_SHAPE") is not None
......@@ -692,6 +702,289 @@ class trace:
self._process_outputs(outputs)
return outputs
def _make_feed(
self,
graph,
outputs,
input_data,
repeat,
silent,
no_assert,
maxerr,
resize_input,
input_transform,
):
def auto_reformat_image(path, data, dst_shape):
"""reformat image to target shape
:param data: image data as numpy array
:param dst_shape: target shape
"""
dim3_format = False # required input format does not contain batch
hwc_format = False # required input format is NHWC
if not dst_shape: # input tensor shape is not predefined
if len(data.shape) == 2:
chl = 1
h = data.shape[0]
w = data.shape[1]
else:
assert (
len(data.shape) == 3
), "Input image must be of dimension 2 or 3"
h, w, chl = data.shape
dst_shape = (1, chl, h, w)
if len(dst_shape) == 3:
dst_shape = (1,) + dst_shape
dim3_format = True
assert len(dst_shape) == 4, "bad dst_shape: {}".format(dst_shape)
chl = dst_shape[1]
if chl in [1, 3]:
n, c, h, w = dst_shape
dst_shape = (n, h, w, c)
else:
chl = dst_shape[3]
assert chl in [
1,
3,
], "can not infer input format from shape: {}".format(dst_shape)
hwc_format = True
# dst_shape has now been normalized to NHWC format
if resize_input:
h, w = dst_shape[1:3]
data = cv2.resize(data, (w, h))
logger.info("input {} resized to {}".format(path, data.shape))
if chl == 1:
data = cv2.cvtColor(data, cv2.COLOR_BGR2GRAY)
data = data[:, :, np.newaxis]
assert data.ndim == 3
data = data[np.newaxis]
# data normalized to NHWC format
if not hwc_format:
data = np.transpose(data, (0, 3, 1, 2))
if dim3_format:
data = np.squeeze(data, 0)
return data
def read_input_data(dst_shape, dtype, path):
def check_shape_equal(dst_shape, data_shape):
if len(dst_shape):
assert len(data_shape) == len(
dst_shape
), "input/data shapes mismatch: {} vs {}".format(
dst_shape, data_shape
)
if data_shape[1:] != dst_shape[1:]:
logger.warning(
"dst_shape is {}; data_shape is {}".format(
dst_shape, data_shape
)
)
if path.startswith("#"):
assert not resize_input
assert not input_transform
spec = path
m = re.match(
r"^#rand\(([-0-9.]*)\s*,\s*([-0-9.]*)\s*(,[^\)]+)?\)$", spec
)
assert m, "bad spec {}".format(spec)
rng_min = float(m.group(1))
rng_max = float(m.group(2))
if m.group(3):
shape_str = m.group(3)
try:
shape = shape_str[1:].split(",")
if shape[-1].strip() == "...":
shape = shape[:-1]
shape.extend(list(dst_shape[len(shape) :]))
data_shape = tuple(map(int, shape))
except ValueError as e:
raise ValueError("bad spec {}: {}".format(spec, e.args))
else:
data_shape = dst_shape
check_shape_equal(dst_shape, data_shape)
return np.random.uniform(rng_min, rng_max, data_shape).astype(dtype)
# try to load image
data = cv2.imread(path, cv2.IMREAD_COLOR)
if data is None:
assert not resize_input
data = np.load(path)
assert isinstance(data, np.ndarray)
else:
# load image succeeds, so we expect input format is image format
data = auto_reformat_image(path, data, dst_shape)
data = np.repeat(data, repeat, axis=0)
if repeat > 1:
logger.info(
"repeat input for {} times, data shape is {}".format(
repeat, data.shape
)
)
check_shape_equal(dst_shape, data.shape)
if input_transform:
data = eval(input_transform, {"data": data, "np": np})
return data
def gen_one_testcase(inputs, spec):
paths = spec.split(";")
if len(paths) != len(inputs):
if len(paths) == 1 and paths[0].startswith("#"):
paths = ["{}:{}".format(name, paths[0]) for name in inputs.keys()]
assert len(paths) == len(
inputs
), "required inputs: {}; data paths: {}".format(inputs.keys(), paths)
if len(paths) == 1 and ":" not in paths[0]:
paths[0] = next(iter(inputs.keys())) + ":" + paths[0]
ret = {}
for path in paths:
var, path = path.split(":")
ret[var] = read_input_data(inputs[var].shape, inputs[var].dtype, path)
return ret
inputs = cgtools.get_dep_vars(outputs, "Host2DeviceCopy")
inputs = {i.name: i for i in inputs}
if not no_assert:
replace_varmap = {}
inp_map = {}
# replace var use InputNode
for name, var in inputs.items():
inp = G.InputNode(
device="xpux", dtype=var.dtype, shape=var.shape, graph=graph
)
replace_varmap[var] = inp.outputs[0]._node
inp_map[name] = inp
new = cgtools.replace_vars(outputs, replace_varmap)
if isinstance(new, rt.VarNode):
new = list(new)
output_nodes = [G.OutputNode(var) for var in new]
func = graph.compile(*[node.outputs[0]._node for node in output_nodes])
def make_dev_tensor(value, dtype=None, device=None):
return tensor(value, dtype=dtype, device=device)._dev_tensor()
def calculate(*args, **kwargs):
output_val = []
# set inputs value
for name, var in inputs.items():
val = kwargs.pop(name, None)
assert val is not None, "miss input name{}".format(name)
dev_tensor = make_dev_tensor(val, dtype=var.dtype, device="xpux")
inp_map[name].set_value(dev_tensor)
func.execute()
for res in output_nodes:
output_val.append(res.get_value().numpy())
return output_val
def expect_name(var):
return "{}:expect".format(var.name)
testcases = []
np.set_printoptions(precision=2, threshold=4, suppress=True)
data_list = []
for item in input_data:
if item.startswith("@"):
with open(item[1:], "r") as f:
data_list.extend(
[line.rstrip() for line in f if line.rstrip() != ""]
)
else:
data_list.append(item)
for inp_spec in data_list:
cur_testcase = gen_one_testcase(inputs, inp_spec)
assert len(cur_testcase) == len(
inputs
), "required inputs: {}; given data: {}".format(
inputs.keys(), cur_testcase.keys()
)
if not no_assert:
outputs_get = calculate(**cur_testcase)
for var, val in zip(outputs, outputs_get):
cur_testcase[expect_name(var)] = val
logger.info(
"generate test groundtruth: var={} shape={} range=({}, {})"
" mean={} var={}".format(
var,
val.shape,
val.min(),
val.max(),
np.mean(val),
np.var(val),
)
)
testcases.append(cur_testcase)
logger.info(
"add testcase: \n {}".format(
"\n ".join(
"{}: shape={} dtype={} range=({:.2f},{:.2f}) "
"mean={:.2f} sd={:.2f}".format(
k, v.shape, v.dtype, v.min(), v.max(), np.mean(v), np.std(v)
)
for k, v in sorted(cur_testcase.items())
)
)
)
if not no_assert:
def expect_shp(var):
ret = var.shape
if ret:
return ret
return testcases[0][expect_name(var)].shape
def assert_equal(expect, real, **kwargs):
op = AssertEqual(**kwargs)
(res,) = G.apply_normal_varnode(op, expect, real)
return res._node
verbose = not silent
outputs_new = []
for i in outputs:
device = rt.CompNode("xpux")
dtype = i.dtype
name = expect_name(i)
shape = expect_shp(i)
# make expect output as one input of model.
expect_get = rt.make_h2d(graph, device, dtype, shape, name)
# insert assert opr to check expect and real.
outputs_new.append(
assert_equal(expect_get, i, verbose=verbose, maxerr=maxerr,)
)
inputs[expect_name(i)] = expect_get
outputs = outputs_new
return {"outputs": outputs, "testcases": testcases}
def dump(
self,
file,
......@@ -708,6 +1001,13 @@ class trace:
optimize_for_inference=True,
user_info: Any = None,
enable_metadata: bool = True,
input_data=None,
repeat=1,
silent=False,
no_assert=False,
maxerr=1e-4,
resize_input=False,
input_transform=None,
**kwargs
):
r"""Serializes trace to file system.
......@@ -738,6 +1038,27 @@ class trace:
will skip all optimize options if this is False. Default: True
user_info: any type object, which will be pickled to bytes.
enable_metadata: whether to save metadata into output file.
input_data: input test data and current network output would be used as groundtruth.
The format is "var0:file0;var1:file1..." to specify data files for input vars.
It can also be "#rand(min,max,shape...)" for generating random input data, for
example, "#rand(0,255)", "#rand(0,255,1,3,224,224)" or "#rand(0, 255, 1, ...)"
where `...` means the remaining part of the original shape. If the shape is not
specified, the shape of corresponding input tensors in the network will be used.
If there is only one input var, its name can be omitted. Each data file can either
be an image which can be loaded by opencv, or a pickled numpy.ndarray. This option
can be given multiple times to add multiple testcases. If you start the data
with the letter @, the rest should be a filename, and each line in the file should
be a single datum in the format described above. *NOTE* If `input_data` is not None,
you can only use load-and-run to run the output file.
repeat: how many times the input image is repeated. Useful when running benchmark for
batch size other than one. Have no effect on randomly generated input data.
silent: whether set verbose to False in assert_equal opr.
no_assert: whether insert assert_equal opr to check result; this option is useful for
benchmarking.
maxerr: max error for assert_equal check during runtime.
resize_input: whether resize input image to fit input var shape.
input_transform: a python expression to transform the input data.
Example: data / np.std(data)
Keyword Arguments:
......@@ -778,6 +1099,8 @@ class trace:
input for inference on nvidia backend(this optimization pass will
result in mismatch of the precision of output of training and
inference)
* enable_fuse_preprocess: whether to fuse astype\pad_channel\dimshuffle and
etc opr
"""
if not self._capture_as_const:
raise ValueError(
......@@ -892,8 +1215,28 @@ class trace:
v.name = output_names[i]
dest_vars.append(v)
dest_vars = [i._node for i in dest_vars]
if input_data is not None:
feeds = self._make_feed(
graph,
dest_vars,
input_data,
repeat,
silent,
no_assert,
maxerr,
resize_input,
input_transform,
)
assert (
isinstance(feeds, dict) and feeds["testcases"]
), "testcases can not be empty"
dest_vars = feeds["outputs"]
if optimize_for_inference:
dest_vars, optimize_options = G.optimize_for_inference(dest_vars, **kwargs)
dest_vars = [i._node for i in dest_vars]
metadata = SerializationMetadata()
if enable_metadata:
......@@ -910,6 +1253,9 @@ class trace:
if keep_opr_priority:
graph._set_priority_to_id(dest_vars)
if input_data is not None:
file.write(b"mgbtest0")
file.write(struct.pack("I", len(feeds["testcases"])))
dump_content, dump_info = G.dump_graph(
dest_vars,
keep_var_name=keep_var_name,
......@@ -921,6 +1267,34 @@ class trace:
metadata=metadata,
)
file.write(dump_content)
if input_data is not None:
inputs = cgtools.get_dep_vars(dest_vars, "Host2DeviceCopy")
inputs = sorted((i.name, i.dtype) for i in inputs)
def make_dev_tensor(value, dtype=None, device=None):
return tensor(value, dtype=dtype, device=device)._dev_tensor()
for testcase in feeds["testcases"]:
assert isinstance(testcase, dict)
cg = G.Graph()
output_mgbvars = []
for name, dtype in inputs:
output_mgbvars.append(
cg.make_const(
make_dev_tensor(
testcase.pop(name), dtype=dtype, device="cpux"
)
)
)
assert not testcase, "extra inputs provided in testcase: {}".format(
testcase.keys()
)
dump_content, _ = G.dump_graph(
output_mgbvars, strip_info_file=strip_info_file, append_json=True,
)
file.write(dump_content)
return dump_info
def _process_inputs(self, *args, **kwargs):
......
......@@ -287,6 +287,16 @@ def test_dump_backward_graph():
np.testing.assert_equal(results[1], dx0)
def test_dump_with_testcase():
@trace(symbolic=True, capture_as_const=True)
def f(x):
return exp(x)
f(tensor(1.0))
file = io.BytesIO()
f.dump(file, input_data=["#rand(0, 255, 1)"])
@pytest.mark.parametrize("trace_mode", [False, True])
def test_trace_profiler(trace_mode):
@trace(symbolic=trace_mode, profiling=True)
......
# -*- coding: utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import argparse
import os
import re
import struct
import cv2
import numpy as np
import megengine as mge
import megengine.core._imperative_rt as rt
import megengine.core.tensor.megbrain_graph as G
from megengine import tensor
from megengine.core._imperative_rt.core2 import apply
from megengine.core.ops import builtin
from megengine.utils import comp_graph_tools as cgtools
logger = mge.get_logger(__name__)
def auto_reformat_image(args, path, data, dst_shape):
"""reformat image to target shape
:param data: image data as numpy array
:param dst_shape: target shape
"""
dim3_format = False # required input format does not contain batch
hwc_format = False # required input format is NHWC
if not dst_shape: # input tensor shape is not predefined
if len(data.shape) == 2:
chl = 1
h = data.shape[0]
w = data.shape[1]
else:
assert len(data.shape) == 3, "Input image must be of dimension 2 or 3"
h, w, chl = data.shape
dst_shape = (1, chl, h, w)
if len(dst_shape) == 3:
dst_shape = (1,) + dst_shape
dim3_format = True
assert len(dst_shape) == 4, "bad dst_shape: {}".format(dst_shape)
chl = dst_shape[1]
if chl in [1, 3]:
n, c, h, w = dst_shape
dst_shape = (n, h, w, c)
else:
chl = dst_shape[3]
assert chl in [1, 3], "can not infer input format from shape: {}".format(
dst_shape
)
hwc_format = True
# dst_shape has now been normalized to NHWC format
if args.resize_input:
h, w = dst_shape[1:3]
data = cv2.resize(data, (w, h))
logger.info("input {} resized to {}".format(path, data.shape))
if chl == 1:
data = cv2.cvtColor(data, cv2.COLOR_BGR2GRAY)
data = data[:, :, np.newaxis]
assert data.ndim == 3
data = data[np.newaxis]
# data normalized to NHWC format
if not hwc_format:
data = np.transpose(data, (0, 3, 1, 2))
if dim3_format:
data = np.squeeze(data, 0)
return data
def read_input_data(args, dst_shape, dtype, path, repeat):
def check_shape_equal(dst_shape, data_shape):
if len(dst_shape):
assert len(data_shape) == len(
dst_shape
), "input/data shapes mismatch: {} vs {}".format(dst_shape, data_shape)
if data_shape[1:] != dst_shape[1:]:
logger.warning(
"dst_shape is {}; data_shape is {}".format(dst_shape, data_shape)
)
if path.startswith("#"):
assert not args.resize_input
assert not args.input_transform
spec = path
m = re.match(r"^#rand\(([-0-9.]*)\s*,\s*([-0-9.]*)\s*(,[^\)]+)?\)$", spec)
assert m, "bad spec {}".format(spec)
rng_min = float(m.group(1))
rng_max = float(m.group(2))
if m.group(3):
shape_str = m.group(3)
try:
shape = shape_str[1:].split(",")
if shape[-1].strip() == "...":
shape = shape[:-1]
shape.extend(list(dst_shape[len(shape) :]))
data_shape = tuple(map(int, shape))
except ValueError as e:
raise ValueError("bad spec {}: {}".format(spec, e.args))
else:
data_shape = dst_shape
check_shape_equal(dst_shape, data_shape)
return np.random.uniform(rng_min, rng_max, data_shape).astype(dtype)
# try to load image
data = cv2.imread(path, cv2.IMREAD_COLOR)
if data is None:
assert not args.resize_input
data = np.load(path)
assert isinstance(data, np.ndarray)
else:
# load image succeeds, so we expect input format is image format
data = auto_reformat_image(args, path, data, dst_shape)
data = np.repeat(data, repeat, axis=0)
if repeat > 1:
logger.info(
"repeat input for {} times, data shape is {}".format(repeat, data.shape)
)
check_shape_equal(dst_shape, data.shape)
if args.input_transform:
data = eval(args.input_transform, {"data": data, "np": np})
return data
def gen_one_testcase(args, inputs, spec):
paths = spec.split(";")
if len(paths) != len(inputs):
if len(paths) == 1 and paths[0].startswith("#"):
paths = ["{}:{}".format(name, paths[0]) for name in inputs.keys()]
assert len(paths) == len(inputs), "required inputs: {}; data paths: {}".format(
inputs.keys(), paths
)
if len(paths) == 1 and ":" not in paths[0]:
paths[0] = next(iter(inputs.keys())) + ":" + paths[0]
ret = {}
for path in paths:
var, path = path.split(":")
if args.repeat:
repeat = args.repeat
else:
repeat = 1
ret[var] = read_input_data(
args, inputs[var].shape, inputs[var].dtype, path, repeat
)
return ret
def make_feeds(args):
ret = G.load_graph(args.input)
cg_rt, outputs = ret.graph, ret.output_vars_list
inputs = cgtools.get_dep_vars(outputs, "Host2DeviceCopy")
inputs = {i.name: i for i in inputs}
if not args.no_assert:
replace_varmap = {}
inp_map = {}
# replace var use InputNode
for name, var in inputs.items():
inp = G.InputNode(
device="xpux", dtype=var.dtype, shape=var.shape, graph=cg_rt
)
replace_varmap[var] = inp.outputs[0]
inp_map[name] = inp
new = cgtools.replace_vars(outputs, replace_varmap)
if isinstance(new, rt.VarNode):
new = list(new)
output_nodes = [G.OutputNode(var) for var in new]
func = cg_rt.compile([node.outputs[0] for node in output_nodes])
def make_dev_tensor(value, dtype=None, device=None):
return tensor(value, dtype=dtype, device=device)._dev_tensor()
def calculate(*args, **kwargs):
output_val = []
# set inputs value
for name, var in inputs.items():
val = kwargs.pop(name, None)
assert val is not None, "miss input name{}".format(name)
dev_tensor = make_dev_tensor(val, dtype=var.dtype, device="xpux")
inp_map[name].set_value(dev_tensor)
func.execute()
for res in output_nodes:
output_val.append(res.get_value().numpy())
return output_val
def expect_name(var):
return "{}:expect".format(var.name)
testcases = []
np.set_printoptions(precision=2, threshold=4, suppress=True)
data_list = []
for item in args.data:
if item.startswith("@"):
with open(item[1:], "r") as f:
data_list.extend([line.rstrip() for line in f if line.rstrip() != ""])
else:
data_list.append(item)
for inp_spec in data_list:
cur_testcase = gen_one_testcase(args, inputs, inp_spec)
assert len(cur_testcase) == len(
inputs
), "required inputs: {}; given data: {}".format(
inputs.keys(), cur_testcase.keys()
)
if not args.no_assert:
outputs_get = calculate(**cur_testcase)
for var, val in zip(outputs, outputs_get):
cur_testcase[expect_name(var)] = val
logger.info(
"generate test groundtruth: var={} shape={} range=({}, {})"
" mean={} var={}".format(
var, val.shape, val.min(), val.max(), np.mean(val), np.var(val)
)
)
testcases.append(cur_testcase)
logger.info(
"add testcase: \n {}".format(
"\n ".join(
"{}: shape={} dtype={} range=({:.2f},{:.2f}) "
"mean={:.2f} sd={:.2f}".format(
k, v.shape, v.dtype, v.min(), v.max(), np.mean(v), np.std(v)
)
for k, v in sorted(cur_testcase.items())
)
)
)
if not args.no_assert:
def expect_shp(var):
ret = var.shape
if ret:
return ret
return testcases[0][expect_name(var)].shape
def assert_equal(expect, real, **kwargs):
op = builtin.AssertEqual(**kwargs)
(res,) = apply(op, expect, real)
return res
verbose = not args.silent
outputs_new = []
for i in outputs:
device = rt.CompNode("xpux")
dtype = i.dtype
name = expect_name(i)
shape = expect_shp(i)
# make expect output as one input of model.
expect_get = rt.make_h2d(cg_rt, device, dtype, shape, name)
# insert assert opr to check expect and real.
outputs_new.append(
assert_equal(
expect_get,
i,
verbose=verbose,
maxerr=args.maxerr,
)
)
inputs[expect_name(i)] = expect_get
outputs = outputs_new
return {"outputs": outputs, "testcases": testcases}
def optimize_for_inference(args, outputs):
args_list = [
"enable_io16xc32",
"enable_ioc16",
"enable_hwcd4",
"enable_nchw4",
"enable_nchw88",
"enable_nchw44",
"enable_nchw44_dot",
"enable_nchw32",
"enable_chwn4",
"enable_fuse_conv_bias_nonlinearity",
"enable_fuse_conv_bias_with_z",
"enable_fuse_preprocess",
]
kwargs = {}
for k in args_list:
if getattr(args, k):
kwargs[k] = True
if args.optimize_for_inference:
outputs = G.optimize_for_inference(outputs, **kwargs)
return outputs
def main():
parser = argparse.ArgumentParser(
description="Pack computing graph, input values and expected output "
"values into one file for checking correctness. README.md gives more "
"details on the usage",
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
parser.add_argument("input", help="MegEngine dumped model file")
parser.add_argument("-o", "--output", help="output file", required=True)
parser.add_argument(
"-d",
"--data",
default=[],
action="append",
required=True,
help="Given input test data when input file is a network, "
"and current network output would be used as groundtruth. "
"The format is var0:file0;var1:file1... to specify data files for "
"input vars. It can also be #rand(min,max,shape...) for generating "
"random input data, for example, #rand(0,255), "
"#rand(0,255,1,3,224,224) or #rand(0, 255, 1, ...) where `...` means "
"the remaining part of the original shape. "
"If the shape is not specified, the shape of "
"corresponding input tensors in the network will be used. "
"If there is only one input var, its name can be omitted. "
"Each data file can either be an image which can be loaded by opencv, "
"or a pickled numpy.ndarray. "
"This option can be given multiple times to add multiple testcases. "
" *NOTE* "
"If you start the data with the letter @, the rest should be a "
"filename, and each line in the file should be a single datum in "
"the format described above. ",
)
parser.add_argument(
"--repeat",
type=int,
default=1,
help="Specify how many times the input image is repeated. "
"Useful when running benchmark for batch size other than one. "
"Have no effect on randomly generated input data.",
)
parser.add_argument(
"--silent",
action="store_true",
help="set verbose to False in asserti_equal opr",
)
parser.add_argument(
"--optimize-for-inference",
action="store_true",
help="enable optimization for inference",
)
parser.add_argument(
"--no-assert",
action="store_true",
help="do not insert assert_equal opr to check result; "
"this option is useful for benchmarking",
)
parser.add_argument(
"--maxerr",
type=float,
default=1e-4,
help="max error for assert_equal check during runtime",
)
parser.add_argument(
"--resize-input",
action="store_true",
help="resize input image to fit input var shape",
)
parser.add_argument(
"--input-transform",
help="a python expression to transform the input data. "
"Example: data / np.std(data)",
)
parser.add_argument(
"--discard-var-name",
action="store_true",
help="discard variable and param names in the " "generated output",
)
parser.add_argument(
"--output-strip-info", action="store_true", help="output code strip information"
)
parser.add_argument(
"--enable-io16xc32",
action="store_true",
help="transform the mode to float16 io float32 compute",
)
parser.add_argument(
"--enable-ioc16",
action="store_true",
help="transform the dtype of the model to float16 io " "and compute",
)
parser.add_argument(
"--enable-fuse-conv-bias-nonlinearity",
action="store_true",
help="fuse convolution bias and nonlinearity opr to a "
"conv_bias opr and compute",
)
parser.add_argument(
"--enable-hwcd4",
action="store_true",
help="transform the model format from NCHW to NHWCD4 "
"for inference; you may need to disable CUDA and set "
"MGB_USE_MEGDNN_DBG=2",
)
parser.add_argument(
"--enable-nchw4",
action="store_true",
help="transform the model format from NCHW to NCHW4 " "for inference",
)
parser.add_argument(
"--enable-nchw88",
action="store_true",
help="transform the model format from NCHW to NCHW88 " "for inference",
)
parser.add_argument(
"--enable-nchw44",
action="store_true",
help="transform the model format from NCHW to NCHW44 " "for inference",
)
parser.add_argument(
"--enable-nchw44-dot",
action="store_true",
help="transform the model format from NCHW to NCHW44_DOT "
"for optimizing armv8.2 dot in inference",
)
parser.add_argument(
"--enable-nchw32",
action="store_true",
help="transform the model format from NCHW4 to NCHW32 "
"for inference on nvidia TensoCore",
)
parser.add_argument(
"--enable-chwn4",
action="store_true",
help="transform the model format to CHWN4 "
"for inference, mainly used for nvidia tensorcore",
)
parser.add_argument(
"--enable-fuse-conv-bias-with-z",
action="store_true",
help="fuse conv_bias with z input for inference on "
"nvidia GPU (this optimization pass will result in mismatch "
"of the precision of output of training and inference)",
)
parser.add_argument(
"--enable-fuse-preprocess",
action="store_true",
help="fuse astype\pad_channel\dimshuffle and etc opr "
"from h2d opr",
)
args = parser.parse_args()
feeds = make_feeds(args)
assert isinstance(feeds, dict) and feeds["testcases"], "testcases can not be empty"
output_mgbvars = feeds["outputs"]
output_mgbvars = optimize_for_inference(args, output_mgbvars)
inputs = cgtools.get_dep_vars(output_mgbvars, "Host2DeviceCopy")
inputs = sorted((i.name, i.dtype) for i in inputs)
if args.discard_var_name:
sereg_kwargs = dict(keep_var_name=0, keep_param_name=False)
else:
sereg_kwargs = dict(keep_var_name=2, keep_param_name=True)
strip_info_file = args.output + ".json" if args.output_strip_info else None
with open(args.output, "wb") as fout:
fout.write(b"mgbtest0")
fout.write(struct.pack("I", len(feeds["testcases"])))
dump_content, stat = G.dump_graph(
output_mgbvars,
append_json=True,
strip_info_file=strip_info_file,
**sereg_kwargs,
)
fout.write(dump_content)
logger.info(
"graph dump sizes: tot_size={:.3f}KiB overhead={:.3f}KiB".format(
stat.tot_bytes / 1024, (stat.tot_bytes - stat.tensor_value_bytes) / 1024
)
)
def make_dev_tensor(value, dtype=None, device=None):
return tensor(value, dtype=dtype, device=device)._dev_tensor()
for testcase in feeds["testcases"]:
assert isinstance(testcase, dict)
cg = G.Graph()
output_mgbvars = []
for name, dtype in inputs:
output_mgbvars.append(
cg.make_const(
make_dev_tensor(testcase.pop(name), dtype=dtype, device="cpux")
)
)
assert not testcase, "extra inputs provided in testcase: {}".format(
testcase.keys()
)
with open(args.output, "ab") as fout:
dump_content, _ = G.dump_graph(
output_mgbvars, strip_info_file=strip_info_file, append_json=True
)
fout.write(dump_content)
if __name__ == "__main__":
main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册