diff --git a/imperative/python/megengine/core/tensor/megbrain_graph.py b/imperative/python/megengine/core/tensor/megbrain_graph.py
index 5cddd7734989ebeabcd388e7990c7f207edd50fb..48ab55c2362b857e00e402b067422c1a0c57d6d2 100644
--- a/imperative/python/megengine/core/tensor/megbrain_graph.py
+++ b/imperative/python/megengine/core/tensor/megbrain_graph.py
@@ -8,6 +8,7 @@
 # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 import collections
 import json
+import os
 import threading
 import weakref
 from concurrent.futures import Future, ThreadPoolExecutor
@@ -274,7 +275,8 @@ def dump_graph(
     keep_var_name: int = 1,
     keep_param_name: bool = False,
     keep_opr_priority: bool = False,
-    strip_info_file=None
+    strip_info_file=None,
+    append_json=False
 ):
     """serialize the computing graph of `output_vars` and get byte result.
 
@@ -295,6 +297,9 @@
     :param keep_opr_priority: whether to keep priority setting for operators
     :param strip_info_file: a string for path or a file handler. if is not None,
         then the dump information for code strip would be written to ``strip_info_file``
+    :param append_json: only takes effect when ``strip_info_file`` is not None. If set
+        to True, the code-strip information will be appended to ``strip_info_file``;
+        if set to False, ``strip_info_file`` will be overwritten.
     :return: dump result as byte string, and an instance of namedtuple
         :class:`CompGraphDumpResult`, whose fields are:
 
@@ -342,10 +347,25 @@
     if strip_info_file is not None:
         if isinstance(strip_info_file, str):
-            strip_info_file = open(strip_info_file, "w")
-        strip_info = json.loads(_imperative_rt.get_info_for_strip(ov))
-        strip_info["hash"] = dump_info.content_hash
-        json.dump(strip_info, strip_info_file)
+            if not os.path.exists(strip_info_file):
+                os.mknod(strip_info_file)
+            strip_info_file = open(strip_info_file, "r+")
+        new_strip_dict = json.loads(_imperative_rt.get_info_for_strip(ov))
+        ori_strip_dict = new_strip_dict
+        json_content = strip_info_file.read()
+        if append_json and len(json_content) != 0:
+            # if the json file already has content, read it first and merge in the new information
+            ori_strip_dict = json.loads(json_content)
+            for k in ori_strip_dict:
+                new_strip_dict_v = new_strip_dict.get(k)
+                if new_strip_dict_v is not None:
+                    for value in new_strip_dict_v:
+                        if value not in ori_strip_dict[k]:
+                            ori_strip_dict[k].append(value)
+        ori_strip_dict["hash"] = dump_info.content_hash
+        strip_info_file.seek(0)
+        strip_info_file.truncate()
+        json.dump(ori_strip_dict, strip_info_file)
 
     return dump_content, dump_info
 
diff --git a/imperative/python/src/graph_rt.cpp b/imperative/python/src/graph_rt.cpp
index d6a9bba8aeea5635f6f5b8c7bdfea4ef661b9768..f8c5fed77398c10adc7e228b33d5bcb0e492ef45 100644
--- a/imperative/python/src/graph_rt.cpp
+++ b/imperative/python/src/graph_rt.cpp
@@ -267,7 +267,7 @@ void init_graph_rt(py::module m) {
             {"opr_types", to_json(opr_types)},
             {"dtypes", to_json(dtype_names)},
             {"elemwise_modes", to_json(elemwise_modes)},
-        });
+        })->to_string();
     });
 
     m.def("dump_graph", [](
diff --git a/sdk/load-and-run/dump_with_testcase_imperative.py b/sdk/load-and-run/dump_with_testcase_imperative.py
index b7a347ad690ecc7171e6a0290885ae471ee458f3..0b3cce776d236f7e4643b5df5d0e85b20d75c2d6 100755
--- a/sdk/load-and-run/dump_with_testcase_imperative.py
+++ b/sdk/load-and-run/dump_with_testcase_imperative.py
@@ -17,10 +17,10 @@ import numpy as np
 import megengine as mge
 import megengine.core._imperative_rt as rt
 import megengine.core.tensor.megbrain_graph as G
-from megengine.core.tensor.megbrain_graph import VarNode
 from megengine import cgtools
 from megengine.core.ops import builtin
 from megengine.core.tensor.core import apply
+from megengine.core.tensor.megbrain_graph import VarNode
 from megengine.core.tensor.raw_tensor import as_raw_tensor
 
 logger = mge.get_logger(__name__)
@@ -485,13 +485,30 @@ def main():
         sereg_kwargs = dict(keep_var_name=0, keep_param_name=False)
     else:
         sereg_kwargs = dict(keep_var_name=2, keep_param_name=True)
+
+    strip_info_file = args.output + ".json" if args.output_strip_info else None
 
     with open(args.output, "wb") as fout:
         fout.write(b"mgbtest0")
         fout.write(struct.pack("I", len(feeds["testcases"])))
-    dump_content, _ = G.dump_graph([VarNode(i) for i in output_mgbvars])
+    if isinstance(output_mgbvars, dict):
+        wrap_output_vars = {name: VarNode(var) for name, var in output_mgbvars.items()}
+    else:
+        wrap_output_vars = [VarNode(i) for i in output_mgbvars]
+    dump_content, stat = G.dump_graph(
+        wrap_output_vars,
+        append_json=True,
+        strip_info_file=strip_info_file,
+        **sereg_kwargs
+    )
     fout.write(dump_content)
+    logger.info(
+        "graph dump sizes: tot_size={:.3f}KiB overhead={:.3f}KiB".format(
+            stat.tot_bytes / 1024, (stat.tot_bytes - stat.tensor_value_bytes) / 1024
+        )
+    )
 
     def make_dev_tensor(value, dtype=None, device=None):
         return as_raw_tensor(value, dtype=dtype, device=device)._dev_tensor()
@@ -509,8 +526,11 @@
             testcase.keys()
         )
         with open(args.output, "ab") as fout:
-            dump_content, _ = G.dump_graph(output_mgbvars)
-            fout.write(dump_content)
+            dump_content, _ = G.dump_graph(
+                output_mgbvars,
+                strip_info_file=strip_info_file,
+                append_json=True,
+            )
+            fout.write(dump_content)
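
For readers skimming the patch, the following is a standalone sketch of the merge rule that the new append_json=True path applies to an existing strip-info JSON file: per-category entry lists are unioned, and the content hash is replaced by the hash of the latest dump. The helper merge_strip_info and the sample JSON strings are hypothetical, written only to illustrate the loop in dump_graph above; they are not part of megbrain_graph.py.

    # merge_strip_info is an illustrative helper, not part of the patch.
    import json

    def merge_strip_info(existing_json: str, new_json: str, content_hash: str) -> str:
        ori = json.loads(existing_json)
        new = json.loads(new_json)
        for k in ori:
            # union the per-category lists; keys present only in `new`
            # are ignored, mirroring the merge loop in dump_graph
            for value in new.get(k) or []:
                if value not in ori[k]:
                    ori[k].append(value)
        ori["hash"] = content_hash  # the hash always reflects the latest dump
        return json.dumps(ori)

    first = '{"opr_types": ["Elemwise"], "dtypes": ["Float32"], "elemwise_modes": ["ADD"]}'
    second = '{"opr_types": ["Elemwise", "Reduce"], "dtypes": ["Float32"], "elemwise_modes": ["MUL"]}'
    print(merge_strip_info(first, second, "0123abcd"))
    # {"opr_types": ["Elemwise", "Reduce"], "dtypes": ["Float32"], "elemwise_modes": ["ADD", "MUL"], "hash": "0123abcd"}

This is why dump_with_testcase_imperative.py can dump the graph once and then dump each testcase with append_json=True into the same strip_info_file: repeated dumps only add new operator types, dtypes, and elemwise modes instead of overwriting the file.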