diff --git a/imperative/python/megengine/core/tensor/megbrain_graph.py b/imperative/python/megengine/core/tensor/megbrain_graph.py index 3887c167fab54d959f74d556f291b84727db05b9..0799c726e1001a38d35d40956fc5923630f8ba34 100644 --- a/imperative/python/megengine/core/tensor/megbrain_graph.py +++ b/imperative/python/megengine/core/tensor/megbrain_graph.py @@ -11,7 +11,7 @@ import json import os import weakref from concurrent.futures import ThreadPoolExecutor -from typing import Dict, List, Tuple, Union +from typing import Dict, List, Optional, Tuple, Union import numpy as np @@ -256,6 +256,9 @@ def optimize_for_inference(dest_vars, **kwargs): * enable_chwn4 -- whether to use CHWN4 data layout, currently used in nvidia backend with tensorcore. + * enable_nchw64 -- + whether to use NCHW64 data layout, used for fast int4 + support on Nvidia GPU. * enable_fuse_conv_bias_nonlinearity: whether to fuse conv+bias+nonlinearty into one opr. @@ -273,6 +276,7 @@ def optimize_for_inference(dest_vars, **kwargs): "enable_nchw44": GraphOptimizeOptions.LayoutTransform.NCHW44, "enable_nchw44_dot": GraphOptimizeOptions.LayoutTransform.NCHW44_DOT, "enable_chwn4": GraphOptimizeOptions.LayoutTransform.CHWN4, + "enable_nchw64": GraphOptimizeOptions.LayoutTransform.NCHW64, } for k, v in inference_optimize_layout_transform_map.items(): @@ -293,7 +297,46 @@ def optimize_for_inference(dest_vars, **kwargs): dest_vars = _unwrap(dest_vars) res_vars = _imperative_rt.optimize_for_inference(dest_vars, inference_options) - return _wrap(res_vars) + return _wrap(res_vars), inference_options.serialize() + + +def deserialize_infer_option(x: int) -> Dict[str, bool]: + r""" + Deserialize optimize options generated by ``imperative_rt.GraphOptimizeOptions``. + + :param x: inference options represented by int. + :return: inference options represented by dict. 
+ """ + + inference_options = GraphOptimizeOptions.deserialize(x) + + inference_optimize_layout_transform_map = { + GraphOptimizeOptions.LayoutTransform.NHWCD4: "enable_hwcd4", + GraphOptimizeOptions.LayoutTransform.NCHW4: "enable_nchw4", + GraphOptimizeOptions.LayoutTransform.NCHW88: "enable_nchw88", + GraphOptimizeOptions.LayoutTransform.NCHW32: "enable_nchw32", + GraphOptimizeOptions.LayoutTransform.NCHW44: "enable_nchw44", + GraphOptimizeOptions.LayoutTransform.NCHW44_DOT: "enable_nchw44_dot", + GraphOptimizeOptions.LayoutTransform.CHWN4: "enable_chwn4", + GraphOptimizeOptions.LayoutTransform.NCHW64: "enable_nchw64", + } + + ret = dict() + + layout = inference_options.layout_transform + if layout != GraphOptimizeOptions.LayoutTransform.DEFAULT: + ret[inference_optimize_layout_transform_map[layout]] = True + + if inference_options.f16_io_f32_comp: + ret["enable_io16xc32"] = True + if inference_options.f16_io_comp: + ret["enable_ioc16"] = True + if inference_options.fuse_conv_bias_nonlinearity: + ret["enable_fuse_conv_bias_nonlinearity"] = True + if inference_options.fuse_conv_bias_with_z: + ret["enable_fuse_conv_bias_with_z"] = True + + return ret def modify_opr_algo_strategy_inplace(dest_vars, strategy: str): @@ -331,7 +374,8 @@ def dump_graph( keep_param_name: bool = False, keep_opr_priority: bool = False, strip_info_file=None, - append_json=False + append_json=False, + metadata=None ) -> Tuple[bytes, CompGraphDumpResult]: """ serialize the computing graph of `output_vars` and get byte result. 
@@ -393,6 +437,7 @@ def dump_graph( keep_opr_name, keep_param_name, keep_opr_priority, + metadata, stat, inputs, outputs, @@ -427,7 +472,7 @@ def dump_graph( CompGraphLoadResult = collections.namedtuple( - "CompGraphLoadResult", ["graph", "output_vars_dict", "output_vars_list"] + "CompGraphLoadResult", ["graph", "output_vars_dict", "output_vars_list", "metadata"] ) @@ -450,8 +495,8 @@ def load_graph(fpath) -> CompGraphLoadResult: buf = open(fpath, "rb").read() else: buf = fpath.read() - cg = _imperative_rt.load_graph(buf, output_vars_map, output_vars_list) - return CompGraphLoadResult(cg, dict(output_vars_map), output_vars_list) + cg, metadata = _imperative_rt.load_graph(buf, output_vars_map, output_vars_list) + return CompGraphLoadResult(cg, dict(output_vars_map), output_vars_list, metadata) def _wrap(x): diff --git a/imperative/python/megengine/jit/tracing.py b/imperative/python/megengine/jit/tracing.py index 8d57776e0211b1f2d427aff575f0ec0bf4962dfd..31999f1cac2ca8bd43565e5c21a7da1714f90862 100644 --- a/imperative/python/megengine/jit/tracing.py +++ b/imperative/python/megengine/jit/tracing.py @@ -12,10 +12,12 @@ import functools import itertools import json import os +import pickle +from typing import Any import numpy as np -from ..core._imperative_rt import GraphProfiler +from ..core._imperative_rt import GraphProfiler, SerializationMetadata from ..core._imperative_rt.core2 import Tensor as RawTensor from ..core._imperative_rt.core2 import ( TensorWeakRef, @@ -670,6 +672,8 @@ class trace: strip_info_file=None, append_json=False, optimize_for_inference=True, + user_info: Any = None, + enable_metadata: bool = True, **kwargs ): r""" @@ -697,6 +701,8 @@ class trace: if set false, will rewrite strip_info_file :param optimize_for_inference: enbale optmizations, will skip all optimize options if this is False. Default: True + :param user_info: any type object, which will be pickled to bytes. + :param enable_metadata: whether to save metadata into output file. 
:Keyword Arguments: @@ -729,6 +735,9 @@ class trace: * enable_chwn4 -- whether to use CHWN4 data layout, currently used in nvidia backend with tensorcore. + * enable_nchw64 -- + whether to use NCHW64 data layout, used for fast int4 + support on Nvidia GPU. * enable_fuse_conv_bias_nonlinearity: whether to fuse conv+bias+nonlinearty into one opr. @@ -851,7 +860,15 @@ class trace: dest_vars.append(v) if optimize_for_inference: - dest_vars = G.optimize_for_inference(dest_vars, **kwargs) + dest_vars, optimize_options = G.optimize_for_inference(dest_vars, **kwargs) + + metadata = SerializationMetadata() + if enable_metadata: + metadata.user_info = pickle.dumps(user_info) + metadata.is_valid = True + metadata.graph_modified = False + if optimize_for_inference: + metadata.optimize_options = optimize_options if isinstance(file, str): permission = "wb" if append == False else "ab" @@ -864,6 +881,7 @@ class trace: keep_opr_priority=keep_opr_priority, strip_info_file=strip_info_file, append_json=append_json, + metadata=metadata, ) file.write(dump_content) return dump_info diff --git a/imperative/python/megengine/tools/load_network_and_run.py b/imperative/python/megengine/tools/load_network_and_run.py index ba9cd76fc01c6897b4ee2fa883ae703ae2676417..46a288476a594f3ef8c875dfafdac6ec7f1857ed 100755 --- a/imperative/python/megengine/tools/load_network_and_run.py +++ b/imperative/python/megengine/tools/load_network_and_run.py @@ -411,7 +411,8 @@ def main(): args.embed_input = True logger.info("loading model ...") - graph, _, output_vars = G.load_graph(args.net) + ret = G.load_graph(args.net) + graph, output_vars = ret.graph, ret.output_vars_list input_vars = tools.get_dep_vars(output_vars, "Host2DeviceCopy") if args.output_name is not None: diff --git a/imperative/python/megengine/utils/comp_graph_tools.py b/imperative/python/megengine/utils/comp_graph_tools.py index 614a8812aae61140b914209ccbc49ef02e11ebce..379b116745bd7143eedc5cdf1bd85d46ae111262 100644 --- 
a/imperative/python/megengine/utils/comp_graph_tools.py +++ b/imperative/python/megengine/utils/comp_graph_tools.py @@ -391,7 +391,8 @@ class GraphInference: optimize_for_inference: bool = False, **kwargs ): - self._graph, _, output_nodes = G.load_graph(file) + ret = G.load_graph(file) + self._graph, output_nodes = ret.graph, ret.output_vars_list if outputs is not None: output_nodes = find_vars_by_name(output_nodes, outputs) self._origin_outputs = output_nodes diff --git a/imperative/python/megengine/utils/network.py b/imperative/python/megengine/utils/network.py index bc56f00bb8ef3500bfd7fceacacce7288d1f35cd..4bb167e652884f1c9b84e2f1eacfb95083c1a3fc 100644 --- a/imperative/python/megengine/utils/network.py +++ b/imperative/python/megengine/utils/network.py @@ -9,14 +9,12 @@ import collections import fnmatch import itertools +import pickle import re from collections import OrderedDict -from typing import Dict, List, Sequence +from typing import Any, Dict, List, Sequence -import numpy as np - -from ..core._imperative_rt import ComputingGraph -from ..core._imperative_rt.core2 import SymbolVar +from ..core._imperative_rt import ComputingGraph, SerializationMetadata from ..core._trace_option import set_symbolic_shape as _set_symbolic_shape from ..core.tensor import megbrain_graph as G from ..logger import get_logger @@ -42,6 +40,30 @@ class Network: self.all_oprs_map = OrderedDict() self.all_vars_map = OrderedDict() self.graph = ComputingGraph() + self._metadata = None + + @property + def metadata(self): + r""" + Load metadata as a dict. + """ + if not self._metadata.is_valid: + logger.info("metadata is not valid!") + return None + ret = dict() + try: + user_info = pickle.loads(self._metadata.user_info) + except: # pylint: disable=bare-except + logger.warning( + "can't parse user info by pickle, so return the original bytes object!" 
+ ) + user_info = self._metadata.user_info + ret["user_info"] = user_info + ret["graph_modified"] = self._metadata.graph_modified + ret["optimized_for_inference"] = self._metadata.optimized_for_inference + if ret["optimized_for_inference"]: + ret.update(G.deserialize_infer_option(self._metadata.optimize_options)) + return ret @classmethod def load(cls, model_path: str, outspec: List[str] = None): @@ -51,7 +73,8 @@ class Network: :param outspec: only load the subgraph with outspec as its endpoints. """ self = cls() - _, _, outputs = G.load_graph(model_path) + ret = G.load_graph(model_path) + outputs, self._metadata = ret.output_vars_list, ret.metadata if outspec is not None: output_spec = outspec.copy() all_vars = get_dep_vars(outputs) + outputs @@ -125,6 +148,9 @@ class Network: * enable_chwn4 -- whether to use CHWN4 data layout, currently used in nvidia backend with tensorcore. + * enable_nchw64 -- + whether to use NCHW64 data layout, used for fast int4 + support on Nvidia GPU. * enable_fuse_conv_bias_nonlinearity: whether to fuse conv+bias+nonlinearty into one opr. @@ -152,6 +178,8 @@ class Network: append_json=False, optimize_for_inference=True, append=False, + user_info: Any = None, + enable_metadata=True, **kwargs ): """ @@ -176,6 +204,8 @@ class Network: if set false, will rewrite strip_info_file :param optimize_for_inference: enbale optmizations, will skip all optimize options if this is False. Default: True + :param user_info: any type object, which will be pickled to bytes. + :param enable_metadata: whether to save metadata into output file. 
:Keyword Arguments: @@ -201,7 +231,15 @@ class Network: ) if optimize_for_inference: - out = G.optimize_for_inference(out, **kwargs) + out, optimize_options = G.optimize_for_inference(out, **kwargs) + + metadata = SerializationMetadata() + if enable_metadata: + metadata.is_valid = True + metadata.graph_modified = True + metadata.user_info = pickle.dumps(user_info) + if optimize_for_inference: + metadata.optimize_options = optimize_options dump_content, _ = G.dump_graph( out, @@ -211,6 +249,7 @@ class Network: keep_opr_priority=keep_opr_priority, strip_info_file=strip_info_file, append_json=append_json, + metadata=metadata, ) if isinstance(file, str): permission = "wb" if append == False else "ab" diff --git a/imperative/python/src/graph_rt.cpp b/imperative/python/src/graph_rt.cpp index a43f756f2090f2e90601821c5394ed6ecf45d1cd..837bc91f9d7cd274f1c35b5a5da3d1073d875517 100644 --- a/imperative/python/src/graph_rt.cpp +++ b/imperative/python/src/graph_rt.cpp @@ -34,6 +34,7 @@ namespace ser = mgb::serialization; using _OptimizeForInferenceOptions = mgb::gopt::OptimizeForInferenceOptions; using _LayoutTransform = _OptimizeForInferenceOptions::LayoutTransform; using _AlgoStrategy = opr::mixin::AlgoChooserHelper::ExecutionPolicy::Strategy; +using _SerializationMetadata = mgb::serialization::Metadata; namespace { class _CompGraphProfilerImpl { @@ -240,6 +241,8 @@ void init_graph_rt(py::module m) { auto GraphOptimizeOptions = py::class_<_OptimizeForInferenceOptions>(m, "GraphOptimizeOptions") .def(py::init()) + .def("serialize", &_OptimizeForInferenceOptions::serialize) + .def_static("deserialize", &_OptimizeForInferenceOptions::deserialize) .def_readwrite("f16_io_f32_comp", &_OptimizeForInferenceOptions::f16_io_f32_comp) .def_readwrite("f16_io_comp", &_OptimizeForInferenceOptions::f16_io_comp) .def_readwrite("fuse_conv_bias_nonlinearity", &_OptimizeForInferenceOptions::fuse_conv_bias_nonlinearity) @@ -256,6 +259,7 @@ void init_graph_rt(py::module m) { .value("NCHW44_DOT", 
_LayoutTransform::NCHW44_DOT) .value("NCHW32", _LayoutTransform::NCHW32) .value("CHWN4", _LayoutTransform::CHWN4) + .value("NCHW64", _LayoutTransform::NCHW64) .export_values() ; @@ -307,12 +311,24 @@ void init_graph_rt(py::module m) { })->to_string(); }); + py::class_<_SerializationMetadata>(m, "SerializationMetadata") + .def(py::init()) + .def_property("user_info", [](const _SerializationMetadata& meta){return py::bytes(meta.get_user_info()); }, + &_SerializationMetadata::set_user_info) + .def_readonly("optimized_for_inference", &_SerializationMetadata::optimized_for_inference) + .def_property("optimize_options", &_SerializationMetadata::get_optimize_options, + &_SerializationMetadata::set_optimize_options) + .def_readwrite("graph_modified", &_SerializationMetadata::graph_modified) + .def_readwrite("is_valid", &_SerializationMetadata::is_valid) + ; + m.def("dump_graph", []( const std::vector& dest_vars, int keep_var_name, bool keep_opr_name, bool keep_param_name, bool keep_opr_priority, + std::optional<_SerializationMetadata> metadata, py::list& stat, py::list& inputs, py::list& outputs, @@ -325,7 +341,12 @@ void init_graph_rt(py::module m) { ser::GraphDumper::DumpConfig config{keep_var_name, keep_param_name, keep_opr_priority, keep_opr_name}; - auto rst = dumper->dump(symvars, config); + ser::GraphDumper::DumpResult rst; + if (metadata) + rst = dumper->dump(symvars, config, *metadata); + else + rst = dumper->dump(symvars, config); + for (auto i : rst.inputs) { inputs.append(py::cast(i)); } @@ -377,8 +398,10 @@ void init_graph_rt(py::module m) { for (const auto& var : rst.output_var_list) { iter.add(var); } - return rst.graph; - + auto ret = py::tuple(2); + ret[0] = py::cast(rst.graph); + ret[1] = py::cast(rst.metadata); + return ret; }); #define CURRENT_CLASS cg::ComputingGraph::Options diff --git a/imperative/python/test/unit/jit/test_tracing.py b/imperative/python/test/unit/jit/test_tracing.py index 
f6a650f8f9f342c8158c734fbd7357d2bc551424..7bd10051b51c78e88071d47124a027b2ad11f6cc 100644 --- a/imperative/python/test/unit/jit/test_tracing.py +++ b/imperative/python/test/unit/jit/test_tracing.py @@ -239,8 +239,7 @@ def test_dump_volatile(): file = io.BytesIO() f.dump(file, optimize_for_inference=False) file.seek(0) - cg, _, outputs = G.load_graph(file) - (out,) = outputs + (out,) = G.load_graph(file).output_vars_list assert ( cgtools.get_owner_opr_type(cgtools.get_owner_opr_inputs(out)[1]) == "ImmutableTensor" @@ -337,12 +336,12 @@ def test_goptions_log_exp(): f(tensor(1.0)) _, out = mkstemp() f.dump(out, optimize_for_inference=False) - *_, outputs = G.load_graph(out) + outputs = G.load_graph(out).output_vars_list oprs_1 = cgtools.get_oprs_seq(outputs) g(tensor(1.0)) g.dump(out, optimize_for_inference=False) - *_, outputs = G.load_graph(out) + outputs = G.load_graph(out).output_vars_list oprs_2 = cgtools.get_oprs_seq(outputs) assert len(oprs_1) - len(oprs_2) == 2 diff --git a/imperative/python/test/unit/utils/test_cgtools.py b/imperative/python/test/unit/utils/test_cgtools.py index 406a9c084429075d593db4f426f6eb4374e92282..c8b51daa6895231bbbfc8ce3c03eafa4489db8c9 100644 --- a/imperative/python/test/unit/utils/test_cgtools.py +++ b/imperative/python/test/unit/utils/test_cgtools.py @@ -88,7 +88,7 @@ def test_graph_traversal(): file = io.BytesIO() fun.dump(file, optimize_for_inference=False) file.seek(0) - cg, _, outputs = mgb_graph.load_graph(file) + outputs = mgb_graph.load_graph(file).output_vars_list _, map_vars, var2oprs, *_ = cgtools.graph_traversal(outputs) input_var = map_vars[1] @@ -101,7 +101,9 @@ def test_load_refcnt(): graph = mgb_graph.Graph() varnode = graph.make_const(0) buf, _ = mgb_graph.dump_graph([varnode]) - graph, _, (varnode,) = mgb_graph.load_graph(io.BytesIO(buf)) + ret = mgb_graph.load_graph(io.BytesIO(buf)) + graph, (varnode,) = ret.graph, ret.output_vars_list + del ret del graph varnode.owner @@ -132,7 +134,7 @@ def test_get_opr_seq(): 
file = io.BytesIO() func.dump(file, optimize_for_inference=False) file.seek(0) - *_, outputs = mgb_graph.load_graph(file) + outputs = mgb_graph.load_graph(file).output_vars_list seq_1 = cgtools.get_oprs_seq(outputs, True) assert len(seq_1) == 5 diff --git a/imperative/python/test/unit/utils/test_dump_naming.py b/imperative/python/test/unit/utils/test_dump_naming.py index a546c338ab7ee3da70e4dddc9e4e0cee203647d0..f154fc82302d7f3ee83debf6022cac3a5817f59f 100644 --- a/imperative/python/test/unit/utils/test_dump_naming.py +++ b/imperative/python/test/unit/utils/test_dump_naming.py @@ -35,7 +35,7 @@ def _dump_and_load(func, symbolic, keep_opr_name=True): keep_var_name=2, ) file.seek(0) - *_, outputs = G.load_graph(file) + outputs = G.load_graph(file).output_vars_list ops = cgtools.get_oprs_seq(outputs) return ops @@ -223,7 +223,7 @@ def test_catch_input_name(tensor_name, var_name): file = io.BytesIO() func.dump(file, optimize_for_inference=False, keep_opr_name=True, keep_var_name=2) file.seek(0) - *_, outputs = G.load_graph(file) + outputs = G.load_graph(file).output_vars_list op = cgtools.get_oprs_seq(outputs)[-1] assert op.inputs[0].name == var_name diff --git a/imperative/python/test/unit/utils/test_network.py b/imperative/python/test/unit/utils/test_network.py index a3935e362b31e8b7d4158a1b9618ad6becdfd39f..b77608725a1c28eeaa56d7335322f66ca0bcea82 100644 --- a/imperative/python/test/unit/utils/test_network.py +++ b/imperative/python/test/unit/utils/test_network.py @@ -14,6 +14,50 @@ from megengine.utils.network import as_oprnode, set_symbolic_shape from megengine.utils.network_node import Host2DeviceCopy, VarNode +def test_metadata(): + x = Tensor(0) + + @trace(symbolic=True, capture_as_const=True) + def fwd(x): + return x * 2 + + fwd(x) + + orig_model = io.BytesIO() + fwd.dump(orig_model, user_info="test", optimize_for_inference=False) + orig_model.seek(0) + graph = Net.load(orig_model) + assert graph.metadata == { + "user_info": "test", + "graph_modified": False, 
# False: tracing.dump + "optimized_for_inference": False, + } + + orig_model.seek(0) + graph.dump( + orig_model, + user_info={"str": "x", "tensor": x, "module": M.Module, "none": None}, + optimize_for_inference=True, + enable_nchw4=True, + enable_ioc16=True, + ) + orig_model.seek(0) + graph = Net.load(orig_model) + assert graph.metadata == { + "user_info": {"str": "x", "tensor": x, "module": M.Module, "none": None}, + "graph_modified": True, # True: Network.dump + "optimized_for_inference": True, + "enable_nchw4": True, + "enable_ioc16": True, + } + + orig_model.seek(0) + fwd.dump(orig_model, enable_metadata=False) + orig_model.seek(0) + graph = Net.load(orig_model) + assert graph.metadata is None + + def test_replace_var(): a = Tensor([1, 2]) diff --git a/sdk/load-and-run/dump_with_testcase_mge.py b/sdk/load-and-run/dump_with_testcase_mge.py index 36c2bde90d3870c041c61739c1c2e55c84b76247..bc6f4a2ba01037e7364db9108cf188b2699c87fc 100755 --- a/sdk/load-and-run/dump_with_testcase_mge.py +++ b/sdk/load-and-run/dump_with_testcase_mge.py @@ -170,7 +170,8 @@ def gen_one_testcase(args, inputs, spec): def make_feeds(args): - cg_rt, _, outputs = G.load_graph(args.input) + ret = G.load_graph(args.input) + cg_rt, outputs = ret.graph, ret.output_vars_list inputs = cgtools.get_dep_vars(outputs, "Host2DeviceCopy") inputs = {i.name: i for i in inputs} diff --git a/src/gopt/include/megbrain/gopt/inference.h b/src/gopt/include/megbrain/gopt/inference.h index 4a31372fce54c2d4f5c6313f4e95a8a1948dc204..d034fd8a9db099edb2382172132368085b785e7e 100644 --- a/src/gopt/include/megbrain/gopt/inference.h +++ b/src/gopt/include/megbrain/gopt/inference.h @@ -322,7 +322,31 @@ namespace gopt { static std::unique_ptr make_nchw44_dot_converter(); }; - struct OptimizeForInferenceOptions : cg::GraphCommonOptimizeOptions {}; + struct OptimizeForInferenceOptions : cg::GraphCommonOptimizeOptions { + uint64_t serialize() { + uint64_t ret = 0; + ret |= (uint64_t)layout_transform << 32; + if 
(f16_io_f32_comp) ret |= 1u; + if (f16_io_comp) ret |= 1u << 1; + if (fuse_conv_bias_nonlinearity) ret |= 1u << 2; + if (fuse_conv_bias_with_z) ret |= 1u << 3; + if (weight_preprocess) ret |= 1u << 4; + if (fuse_preprocess) ret |= 1u << 5; + return ret; + } + + static OptimizeForInferenceOptions deserialize(uint64_t buf) { + OptimizeForInferenceOptions ret; + ret.f16_io_f32_comp = buf & 1u; + ret.f16_io_comp = buf & 1u << 1; + ret.fuse_conv_bias_nonlinearity = buf & 1u << 2; + ret.fuse_conv_bias_with_z = buf & 1u << 3; + ret.weight_preprocess = buf & 1u << 4; + ret.fuse_preprocess = buf & 1u << 5; + ret.layout_transform = (LayoutTransform)(buf >> 32); + return ret; + } + }; /*! * \brief optimize a computing graph for inference diff --git a/src/serialization/impl/schema.fbs b/src/serialization/impl/schema.fbs index 7b3f9847491eb5237ca5808a3df60a21024c3865..4696cd7a1613bbd2066abf40377ce1a346854303 100644 --- a/src/serialization/impl/schema.fbs +++ b/src/serialization/impl/schema.fbs @@ -128,6 +128,13 @@ table Operator { name:string; } +table Metadata { + is_valid:bool; + graph_modified:bool; + user_info:string; + optimize_options:ulong; +} + struct OutputVar { compact_id:uint; original_id:uint; @@ -141,6 +148,7 @@ table Graph { nr_shared_tensor:uint; oprs:[Operator]; output_vars_idx:[OutputVar]; + metadata:Metadata; } root_type Graph; diff --git a/src/serialization/impl/serializer_oss.cpp b/src/serialization/impl/serializer_oss.cpp index d2c0324c837422130b9078234804cf2688c99c50..dd9bafc83b5ec8e3d67d979fd9601df7a229ee31 100644 --- a/src/serialization/impl/serializer_oss.cpp +++ b/src/serialization/impl/serializer_oss.cpp @@ -30,6 +30,7 @@ #include "megbrain/serialization/internal/flatbuffers_helper.h" #include "megbrain/serialization/internal/schema_generated.h" #include "megbrain/serialization/opr_load_dump.h" +#include "megbrain/serialization/metadata.h" #include "megbrain/serialization/serializer.h" #include "megbrain/version.h" @@ -115,6 +116,7 @@ class 
GraphDumperOSS final : public GraphDumper, OprDumpContextFlatBuffers { std::vector> m_cur_opr_param; void init_oprs_to_dump(const SymbolVarArray& endpoints); + flatbuffers::Offset build_metadata(const Metadata& metadata); flatbuffers::Offset build_single_opr( cg::OperatorNodeBase* opr, const OprRegistry* registry); @@ -123,7 +125,8 @@ class GraphDumperOSS final : public GraphDumper, OprDumpContextFlatBuffers { public: GraphDumperOSS(std::unique_ptr file) : m_file{std::move(file)} {} DumpResult dump(const SymbolVarArray& output_vars, - const DumpConfig& config = {}) override; + const DumpConfig& config = {}, + const Metadata& metadata = {}) override; const GraphDumpConfig& config() const override { return m_config; } void dump_tensor(const std::string& name, const HostTensorND& tensor, TensorWriteMethod method) override; @@ -185,6 +188,17 @@ void GraphDumperOSS::init_oprs_to_dump(const SymbolVarArray& endpoints) { } } +flatbuffers::Offset GraphDumperOSS::build_metadata( + const Metadata& metadata) { + auto user_info = m_builder.CreateSharedString(metadata.user_info); + fbs::MetadataBuilder builder(m_builder); + builder.add_is_valid(metadata.is_valid); + builder.add_graph_modified(metadata.graph_modified); + builder.add_user_info(user_info); + builder.add_optimize_options(metadata.optimize_options); + return builder.Finish(); +} + flatbuffers::Offset GraphDumperOSS::build_single_opr( cg::OperatorNodeBase* opr, const OprRegistry* registry) { m_cur_opr = opr; @@ -282,7 +296,8 @@ flatbuffers::Offset GraphDumperOSS::build_single_opr( } GraphDumper::DumpResult GraphDumperOSS::dump( - const SymbolVarArray& output_vars, const DumpConfig& config) { + const SymbolVarArray& output_vars, + const DumpConfig& config, const Metadata& metadata) { mgb_throw_if(output_vars.empty(), SerializationError, "Can't dump empty graph"); @@ -323,6 +338,9 @@ GraphDumper::DumpResult GraphDumperOSS::dump( uint64_t offset_to_fbs = 0; m_file->write(&offset_to_fbs, sizeof(offset_to_fbs)); + // Dump 
metadata + auto fbmeta = build_metadata(metadata); + // Dump operators init_oprs_to_dump(output_vars); std::vector> oprs; @@ -350,6 +368,7 @@ GraphDumper::DumpResult GraphDumperOSS::dump( graph.add_oprs(fb_oprs); graph.add_output_vars_idx(fb_output_vars); graph.add_nr_shared_tensor(m_nr_shared_tensor); + graph.add_metadata(fbmeta); m_builder.FinishSizePrefixed(graph.Finish(), fbs::GraphIdentifier()); // Write actual offset_to_fbs @@ -531,6 +550,7 @@ public: mgb_assert(nr == 1); } + Metadata load_metadata(); LoadResult load_oprs(); CompNode load_comp_node(const fbs::CompNode* comp_node); @@ -700,6 +720,22 @@ GraphLoaderOSS::OprLoadContextImpl::load_tensor_shared() { return sh_ptr_ref; } +Metadata GraphLoaderOSS::OprLoadContextImpl::load_metadata() { + const auto* fbmeta = m_loader->m_graph->metadata(); + Metadata ret; + ret.is_valid = fbmeta->is_valid(); + ret.graph_modified = fbmeta->graph_modified(); + if (fbmeta->user_info()) { + ret.user_info = fbmeta->user_info()->str(); + ret.has_user_info = true; + } + if (fbmeta->optimize_options()) { + ret.optimize_options = fbmeta->optimize_options(); + ret.optimized_for_inference = true; + } + return ret; +} + void GraphLoaderOSS::OprLoadContextImpl::load_single_opr( const fbs::Operator* fbopr) { m_cur_opr_tensor_cnt = 0; @@ -872,7 +908,9 @@ GraphLoader::LoadResult GraphLoaderOSS::load(const LoadConfig& config, } OprLoadContextImpl ctx{this, m_graph->mgb_version()}; + auto metadata = ctx.load_metadata(); auto result = ctx.load_oprs(); + result.metadata = metadata; auto fbs_end = tensor_begin + offset_to_fbs + sizeof(size) + size; auto cur = m_file->tell(); diff --git a/src/serialization/include/megbrain/serialization/metadata.h b/src/serialization/include/megbrain/serialization/metadata.h new file mode 100644 index 0000000000000000000000000000000000000000..7110669b08cd8fd6971d557bc7b6fed8cccbcd71 --- /dev/null +++ b/src/serialization/include/megbrain/serialization/metadata.h @@ -0,0 +1,46 @@ +/** + * \file 
src/serialization/include/megbrain/serialization/metadata.h + * + * This file is part of MegBrain, a deep learning framework developed by Megvii. + * + * \brief MegEngine model's metadata + * + * \copyright Copyright (c) 2014-2021 Megvii Inc. All rights reserved. + * + */ +#pragma once + +#include + +namespace mgb { +namespace serialization { + +struct Metadata { + bool is_valid = false; + + bool graph_modified = false; + + bool has_user_info = false; + std::string user_info; + + bool optimized_for_inference = false; + uint64_t optimize_options; + +#define ADD_PROPERTY(type, name) \ + type get_##name() const { return name; } \ + void set_##name(type x) { \ + name = x; \ + has_##name = true; \ + } +ADD_PROPERTY(std::string, user_info) +#undef ADD_PROPERTY + + uint64_t get_optimize_options() { return optimize_options; } + void set_optimize_options(uint64_t value) { + optimized_for_inference = true; + optimize_options = value; + } +}; + +} // namespace serialization +} // namespace mgb \ No newline at end of file diff --git a/src/serialization/include/megbrain/serialization/serializer.h b/src/serialization/include/megbrain/serialization/serializer.h index 752fe7409f5901f69928b708acd0a8883460f963..d195aa7b875ac2dcdf437d28398d3ccbea676d64 100644 --- a/src/serialization/include/megbrain/serialization/serializer.h +++ b/src/serialization/include/megbrain/serialization/serializer.h @@ -15,6 +15,7 @@ #include "megbrain/serialization/dump_format.h" #include "megbrain/serialization/file.h" #include "megbrain/serialization/load_dump_config.h" +#include "megbrain/serialization/metadata.h" namespace mgb { namespace serialization { @@ -32,6 +33,9 @@ namespace serialization { //! expliit dtor decl to reduce binary size ~LoadResult() noexcept; + //! 
metadata + Metadata metadata; + using TensorMap = std::unordered_map< std::string, std::shared_ptr>; @@ -178,7 +182,8 @@ namespace serialization { virtual DumpResult dump( const SymbolVarArray &output_vars, - const DumpConfig &config = {}) = 0; + const DumpConfig &config = {}, + const Metadata &metadata = {}) = 0; virtual GraphDumpFormat format() const = 0; }; diff --git a/src/serialization/test/serializer_oss.cpp b/src/serialization/test/serializer_oss.cpp index 6df3bb4ab75407d00730f0dad51291d2a5b11926..345af596dc1f53bed512b5871ad60c6904849eef 100644 --- a/src/serialization/test/serializer_oss.cpp +++ b/src/serialization/test/serializer_oss.cpp @@ -92,6 +92,43 @@ TEST(TestSerializer2, MultiGraphDumpLoad) { load(); } +TEST(TestSerializer2, Metadata) { + auto fname = GET_OUTPUT_FILE(); + TensorShape shape{2, 3}; + + auto dump = [&]() { + auto cn = CompNode::load("xpu0"); + auto host_x = std::make_shared(cn, shape), + host_y = std::make_shared(cn, shape); + auto graph = ComputingGraph::make(); + auto x = opr::Host2DeviceCopy::make(*graph, host_x, {"x"}), + y = opr::Host2DeviceCopy::make(*graph, host_y, {"y"}); + using Mode = opr::Elemwise::Mode; + auto z = opr::Elemwise::make({x, y}, Mode::ADD, {"add(x, y)"}); + + Metadata metadata; + metadata.user_info = "TEST_METADATA"; + metadata.has_user_info = true; + + auto dumper = GraphDumper::make(OutputFile::make_fs(fname.c_str()), + GraphDumpFormat::FLATBUFFERS); + auto rst = dumper->dump({z.rename("z")}, {}, metadata); + }; + + auto load = [&]() { + HostTensorGenerator<> gen; + auto loader = GraphLoader::make(InputFile::make_fs(fname.c_str()), + GraphDumpFormat::FLATBUFFERS); + auto rst = loader->load(); + auto metadata = rst.metadata; + int cmp = strcmp(metadata.user_info.c_str(), "TEST_METADATA"); + EXPECT_EQ(cmp, 0); + }; + + dump(); + load(); +} + TEST(TestSerializer2, APlusB) { auto fname = GET_OUTPUT_FILE(); TensorShape shape{2, 3};