提交 be236642 编写于 作者: M Megvii Engine Team

feat(sdk/load_and_run): add output-strip-info for dump with testcase imperative

GitOrigin-RevId: 337d95c7c24ad514bfccdb92b63eb6d97659e743
上级 026af620
......@@ -8,6 +8,7 @@
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import collections
import json
import os
import threading
import weakref
from concurrent.futures import Future, ThreadPoolExecutor
......@@ -274,7 +275,8 @@ def dump_graph(
keep_var_name: int = 1,
keep_param_name: bool = False,
keep_opr_priority: bool = False,
strip_info_file=None
strip_info_file=None,
append_json=False
):
"""serialize the computing graph of `output_vars` and get byte result.
......@@ -295,6 +297,9 @@ def dump_graph(
:param keep_opr_priority: whether to keep priority setting for operators
:param strip_info_file: a string for path or a file handler. if is not None,
then the dump information for code strip would be written to ``strip_info_file``
:param append_json: will be check when `strip_info_file` is not None. if set
true, the information for code strip will be append to strip_info_file.
if set false, will rewrite strip_info_file
:return: dump result as byte string, and an instance of namedtuple
:class:`CompGraphDumpResult`, whose fields are:
......@@ -342,10 +347,25 @@ def dump_graph(
if strip_info_file is not None:
if isinstance(strip_info_file, str):
strip_info_file = open(strip_info_file, "w")
strip_info = json.loads(_imperative_rt.get_info_for_strip(ov))
strip_info["hash"] = dump_info.content_hash
json.dump(strip_info, strip_info_file)
if not os.path.exists(strip_info_file):
os.mknod(strip_info_file)
strip_info_file = open(strip_info_file, "r+")
new_strip_dict = json.loads(_imperative_rt.get_info_for_strip(ov))
ori_strip_dict = new_strip_dict
json_content = strip_info_file.read()
if append_json and len(json_content) != 0:
# if there are contents in json file. Read them first and then append new information
ori_strip_dict = json.loads(json_content)
for k in ori_strip_dict:
new_strip_dict_v = new_strip_dict.get(k)
if new_strip_dict_v is not None:
for value in new_strip_dict_v:
if not value in ori_strip_dict[k]:
ori_strip_dict[k].append(value)
ori_strip_dict["hash"] = dump_info.content_hash
strip_info_file.seek(0)
strip_info_file.truncate()
json.dump(ori_strip_dict, strip_info_file)
return dump_content, dump_info
......
......@@ -267,7 +267,7 @@ void init_graph_rt(py::module m) {
{"opr_types", to_json(opr_types)},
{"dtypes", to_json(dtype_names)},
{"elemwise_modes", to_json(elemwise_modes)},
});
})->to_string();
});
m.def("dump_graph", [](
......
......@@ -17,10 +17,10 @@ import numpy as np
import megengine as mge
import megengine.core._imperative_rt as rt
import megengine.core.tensor.megbrain_graph as G
from megengine.core.tensor.megbrain_graph import VarNode
from megengine import cgtools
from megengine.core.ops import builtin
from megengine.core.tensor.core import apply
from megengine.core.tensor.megbrain_graph import VarNode
from megengine.core.tensor.raw_tensor import as_raw_tensor
logger = mge.get_logger(__name__)
......@@ -485,13 +485,30 @@ def main():
sereg_kwargs = dict(keep_var_name=0, keep_param_name=False)
else:
sereg_kwargs = dict(keep_var_name=2, keep_param_name=True)
strip_info_file = args.output + '.json' if args.output_strip_info else None
with open(args.output, "wb") as fout:
fout.write(b"mgbtest0")
fout.write(struct.pack("I", len(feeds["testcases"])))
dump_content, _ = G.dump_graph([VarNode(i) for i in output_mgbvars])
if isinstance(output_mgbvars, dict):
wrap_output_vars = dict([(i,VarNode(j)) for i,j in output_mgbvars])
else:
wrap_output_vars = [VarNode(i) for i in output_mgbvars]
dump_content, stat = G.dump_graph(
wrap_output_vars,
append_json=True,
strip_info_file=strip_info_file,
**sereg_kwargs)
fout.write(dump_content)
logger.info(
'graph dump sizes: tot_size={:.3f}KiB overhead={:.3f}KiB'.format(
stat.tot_bytes / 1024, (stat.tot_bytes - stat.tensor_value_bytes) / 1024
)
)
def make_dev_tensor(value, dtype=None, device=None):
return as_raw_tensor(value, dtype=dtype, device=device)._dev_tensor()
......@@ -509,8 +526,11 @@ def main():
testcase.keys()
)
with open(args.output, "ab") as fout:
dump_content, _ = G.dump_graph(output_mgbvars)
fout.write(dump_content)
dump_content, _ = G.dump_graph(
output_mgbvars,
strip_info_file = strip_info_file,
append_json=True)
fout.write(dump_content)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册