Commit 54a4d70e authored by Megvii Engine Team, committed by huangxinda

feat(src/serialization): add support of serializing metadata

GitOrigin-RevId: b563c94451b06055d53c99a85bb5689b3f907365
Parent 721091fa
@@ -11,7 +11,7 @@ import json
 import os
 import weakref
 from concurrent.futures import ThreadPoolExecutor
-from typing import Dict, List, Tuple, Union
+from typing import Dict, List, Optional, Tuple, Union

 import numpy as np
@@ -256,6 +256,9 @@ def optimize_for_inference(dest_vars, **kwargs):
         * enable_chwn4 --
             whether to use CHWN4 data layout, currently
             used in nvidia backend with tensorcore.
+        * enable_nchw64 --
+            whether to use NCHW64 data layout, used for fast int4
+            support on Nvidia GPU.
         * enable_fuse_conv_bias_nonlinearity: whether to fuse conv+bias+nonlinearity
           into one opr.
@@ -273,6 +276,7 @@ def optimize_for_inference(dest_vars, **kwargs):
         "enable_nchw44": GraphOptimizeOptions.LayoutTransform.NCHW44,
         "enable_nchw44_dot": GraphOptimizeOptions.LayoutTransform.NCHW44_DOT,
         "enable_chwn4": GraphOptimizeOptions.LayoutTransform.CHWN4,
+        "enable_nchw64": GraphOptimizeOptions.LayoutTransform.NCHW64,
     }
     for k, v in inference_optimize_layout_transform_map.items():
@@ -293,7 +297,46 @@ def optimize_for_inference(dest_vars, **kwargs):
     dest_vars = _unwrap(dest_vars)
     res_vars = _imperative_rt.optimize_for_inference(dest_vars, inference_options)
-    return _wrap(res_vars)
+    return _wrap(res_vars), inference_options.serialize()
+
+
+def deserialize_infer_option(x: int) -> Dict[str, bool]:
+    r"""
+    Deserialize optimize options generated by ``imperative_rt.GraphOptimizeOptions``.
+
+    :param x: inference options represented by int.
+    :return: inference options represented by dict.
+    """
+    inference_options = GraphOptimizeOptions.deserialize(x)
+    inference_optimize_layout_transform_map = {
+        GraphOptimizeOptions.LayoutTransform.NHWCD4: "enable_hwcd4",
+        GraphOptimizeOptions.LayoutTransform.NCHW4: "enable_nchw4",
+        GraphOptimizeOptions.LayoutTransform.NCHW88: "enable_nchw88",
+        GraphOptimizeOptions.LayoutTransform.NCHW32: "enable_nchw32",
+        GraphOptimizeOptions.LayoutTransform.NCHW44: "enable_nchw44",
+        GraphOptimizeOptions.LayoutTransform.NCHW44_DOT: "enable_nchw44_dot",
+        GraphOptimizeOptions.LayoutTransform.CHWN4: "enable_chwn4",
+        GraphOptimizeOptions.LayoutTransform.NCHW64: "enable_nchw64",
+    }
+    ret = dict()
+    layout = inference_options.layout_transform
+    if layout != GraphOptimizeOptions.LayoutTransform.DEFAULT:
+        ret[inference_optimize_layout_transform_map[layout]] = True
+    if inference_options.f16_io_f32_comp:
+        ret["enable_io16xc32"] = True
+    if inference_options.f16_io_comp:
+        ret["enable_ioc16"] = True
+    if inference_options.fuse_conv_bias_nonlinearity:
+        ret["enable_fuse_conv_bias_nonlinearity"] = True
+    if inference_options.fuse_conv_bias_with_z:
+        ret["enable_fuse_conv_bias_with_z"] = True
+    return ret


 def modify_opr_algo_strategy_inplace(dest_vars, strategy: str):
@@ -331,7 +374,8 @@ def dump_graph(
     keep_param_name: bool = False,
     keep_opr_priority: bool = False,
     strip_info_file=None,
-    append_json=False
+    append_json=False,
+    metadata=None
 ) -> Tuple[bytes, CompGraphDumpResult]:
     """
     serialize the computing graph of `output_vars` and get byte result.
@@ -393,6 +437,7 @@ def dump_graph(
         keep_opr_name,
         keep_param_name,
         keep_opr_priority,
+        metadata,
         stat,
         inputs,
         outputs,
@@ -427,7 +472,7 @@ def dump_graph(

 CompGraphLoadResult = collections.namedtuple(
-    "CompGraphLoadResult", ["graph", "output_vars_dict", "output_vars_list"]
+    "CompGraphLoadResult", ["graph", "output_vars_dict", "output_vars_list", "metadata"]
 )
@@ -450,8 +495,8 @@ def load_graph(fpath) -> CompGraphLoadResult:
         buf = open(fpath, "rb").read()
     else:
         buf = fpath.read()
-    cg = _imperative_rt.load_graph(buf, output_vars_map, output_vars_list)
-    return CompGraphLoadResult(cg, dict(output_vars_map), output_vars_list)
+    cg, metadata = _imperative_rt.load_graph(buf, output_vars_map, output_vars_list)
+    return CompGraphLoadResult(cg, dict(output_vars_map), output_vars_list, metadata)


 def _wrap(x):
......
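For orientation between the hunks above: `optimize_for_inference` now returns a pair, and `load_graph` returns a 4-field namedtuple. A minimal sketch of the round trip, assuming a previously dumped model at the hypothetical path "model.mge":

    # Sketch only: "model.mge" is a hypothetical path, not part of this diff.
    import megengine.core.tensor.megbrain_graph as G

    ret = G.load_graph("model.mge")             # CompGraphLoadResult(graph, ..., metadata)
    dest_vars, packed = G.optimize_for_inference(
        ret.output_vars_list, enable_nchw4=True
    )                                           # second element: serialized options (int)
    assert G.deserialize_infer_option(packed) == {"enable_nchw4": True}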
@@ -12,10 +12,12 @@ import functools
 import itertools
 import json
 import os
+import pickle
+from typing import Any

 import numpy as np

-from ..core._imperative_rt import GraphProfiler
+from ..core._imperative_rt import GraphProfiler, SerializationMetadata
 from ..core._imperative_rt.core2 import Tensor as RawTensor
 from ..core._imperative_rt.core2 import (
     TensorWeakRef,
@@ -670,6 +672,8 @@ class trace:
         strip_info_file=None,
         append_json=False,
         optimize_for_inference=True,
+        user_info: Any = None,
+        enable_metadata: bool = True,
         **kwargs
     ):
         r"""
@@ -697,6 +701,8 @@ class trace:
            if set false, will rewrite strip_info_file
        :param optimize_for_inference: enable optimizations,
            will skip all optimize options if this is False. Default: True
+       :param user_info: an object of any type; it will be pickled to bytes.
+       :param enable_metadata: whether to save metadata into the output file.

        :Keyword Arguments:
@@ -729,6 +735,9 @@ class trace:
            * enable_chwn4 --
                whether to use CHWN4 data layout, currently
                used in nvidia backend with tensorcore.
+           * enable_nchw64 --
+               whether to use NCHW64 data layout, used for fast int4
+               support on Nvidia GPU.
            * enable_fuse_conv_bias_nonlinearity: whether to fuse conv+bias+nonlinearity
              into one opr.
@@ -851,7 +860,15 @@ class trace:
             dest_vars.append(v)

         if optimize_for_inference:
-            dest_vars = G.optimize_for_inference(dest_vars, **kwargs)
+            dest_vars, optimize_options = G.optimize_for_inference(dest_vars, **kwargs)
+
+        metadata = SerializationMetadata()
+        if enable_metadata:
+            metadata.user_info = pickle.dumps(user_info)
+            metadata.is_valid = True
+            metadata.graph_modified = False
+            if optimize_for_inference:
+                metadata.optimize_options = optimize_options

         if isinstance(file, str):
             permission = "wb" if append == False else "ab"
@@ -864,6 +881,7 @@ class trace:
             keep_opr_priority=keep_opr_priority,
             strip_info_file=strip_info_file,
             append_json=append_json,
+            metadata=metadata,
         )
         file.write(dump_content)
         return dump_info
......
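Seen from user code, the new `trace.dump` arguments look roughly like the sketch below (the traced function and buffer are illustrative, mirroring the test added later in this commit):

    # Sketch: user_info may be any picklable object; enable_metadata=False omits it.
    import io
    from megengine import Tensor
    from megengine.jit import trace

    @trace(symbolic=True, capture_as_const=True)
    def fwd(x):
        return x + 1

    fwd(Tensor([1.0]))
    buf = io.BytesIO()
    fwd.dump(buf, user_info={"version": 1}, optimize_for_inference=False)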
@@ -411,7 +411,8 @@ def main():
         args.embed_input = True

     logger.info("loading model ...")
-    graph, _, output_vars = G.load_graph(args.net)
+    ret = G.load_graph(args.net)
+    graph, output_vars = ret.graph, ret.output_vars_list
     input_vars = tools.get_dep_vars(output_vars, "Host2DeviceCopy")

     if args.output_name is not None:
......
@@ -391,7 +391,8 @@ class GraphInference:
         optimize_for_inference: bool = False,
         **kwargs
     ):
-        self._graph, _, output_nodes = G.load_graph(file)
+        ret = G.load_graph(file)
+        self._graph, output_nodes = ret.graph, ret.output_vars_list
         if outputs is not None:
             output_nodes = find_vars_by_name(output_nodes, outputs)
         self._origin_outputs = output_nodes
......
@@ -9,14 +9,12 @@
 import collections
 import fnmatch
 import itertools
+import pickle
 import re
 from collections import OrderedDict
-from typing import Dict, List, Sequence
+from typing import Any, Dict, List, Sequence

-import numpy as np
-
-from ..core._imperative_rt import ComputingGraph
+from ..core._imperative_rt import ComputingGraph, SerializationMetadata
 from ..core._imperative_rt.core2 import SymbolVar
 from ..core._trace_option import set_symbolic_shape as _set_symbolic_shape
 from ..core.tensor import megbrain_graph as G
 from ..logger import get_logger
@@ -42,6 +40,30 @@ class Network:
         self.all_oprs_map = OrderedDict()
         self.all_vars_map = OrderedDict()
         self.graph = ComputingGraph()
+        self._metadata = None
+
+    @property
+    def metadata(self):
+        r"""
+        Load metadata as a dict.
+        """
+        if not self._metadata.is_valid:
+            logger.info("metadata is not valid!")
+            return None
+        ret = dict()
+        try:
+            user_info = pickle.loads(self._metadata.user_info)
+        except:  # pylint: disable=bare-except
+            logger.warning(
+                "can't parse user info with pickle, returning the original bytes object!"
+            )
+            user_info = self._metadata.user_info
+        ret["user_info"] = user_info
+        ret["graph_modified"] = self._metadata.graph_modified
+        ret["optimized_for_inference"] = self._metadata.optimized_for_inference
+        if ret["optimized_for_inference"]:
+            ret.update(G.deserialize_infer_option(self._metadata.optimize_options))
+        return ret

     @classmethod
     def load(cls, model_path: str, outspec: List[str] = None):
@@ -51,7 +73,8 @@ class Network:
         :param outspec: only load the subgraph with outspec as its endpoints.
         """
         self = cls()
-        _, _, outputs = G.load_graph(model_path)
+        ret = G.load_graph(model_path)
+        outputs, self._metadata = ret.output_vars_list, ret.metadata
         if outspec is not None:
             output_spec = outspec.copy()
             all_vars = get_dep_vars(outputs) + outputs
@@ -125,6 +148,9 @@ class Network:
            * enable_chwn4 --
                whether to use CHWN4 data layout, currently
                used in nvidia backend with tensorcore.
+           * enable_nchw64 --
+               whether to use NCHW64 data layout, used for fast int4
+               support on Nvidia GPU.
            * enable_fuse_conv_bias_nonlinearity: whether to fuse conv+bias+nonlinearity
              into one opr.
@@ -152,6 +178,8 @@ class Network:
         append_json=False,
         optimize_for_inference=True,
         append=False,
+        user_info: Any = None,
+        enable_metadata=True,
         **kwargs
     ):
         """
@@ -176,6 +204,8 @@ class Network:
            if set false, will rewrite strip_info_file
        :param optimize_for_inference: enable optimizations,
            will skip all optimize options if this is False. Default: True
+       :param user_info: an object of any type; it will be pickled to bytes.
+       :param enable_metadata: whether to save metadata into the output file.

        :Keyword Arguments:
@@ -201,7 +231,15 @@ class Network:
         )
         if optimize_for_inference:
-            out = G.optimize_for_inference(out, **kwargs)
+            out, optimize_options = G.optimize_for_inference(out, **kwargs)
+
+        metadata = SerializationMetadata()
+        if enable_metadata:
+            metadata.is_valid = True
+            metadata.graph_modified = True
+            metadata.user_info = pickle.dumps(user_info)
+            if optimize_for_inference:
+                metadata.optimize_options = optimize_options
+
         dump_content, _ = G.dump_graph(
             out,
@@ -211,6 +249,7 @@ class Network:
             keep_opr_priority=keep_opr_priority,
             strip_info_file=strip_info_file,
             append_json=append_json,
+            metadata=metadata,
         )
         if isinstance(file, str):
             permission = "wb" if append == False else "ab"
......
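Reading the metadata back through `Network` then looks roughly like this (hedged sketch; the model path is hypothetical):

    # Sketch: metadata is None when the file carries no valid metadata block.
    from megengine.utils.network import Network as Net

    net = Net.load("model.mge")      # hypothetical model file
    meta = net.metadata
    if meta is not None:
        print(meta["user_info"], meta["graph_modified"], meta["optimized_for_inference"])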
@@ -34,6 +34,7 @@ namespace ser = mgb::serialization;
 using _OptimizeForInferenceOptions = mgb::gopt::OptimizeForInferenceOptions;
 using _LayoutTransform = _OptimizeForInferenceOptions::LayoutTransform;
 using _AlgoStrategy = opr::mixin::AlgoChooserHelper::ExecutionPolicy::Strategy;
+using _SerializationMetadata = mgb::serialization::Metadata;

 namespace {
 class _CompGraphProfilerImpl {
@@ -240,6 +241,8 @@ void init_graph_rt(py::module m) {
     auto GraphOptimizeOptions = py::class_<_OptimizeForInferenceOptions>(m, "GraphOptimizeOptions")
         .def(py::init())
+        .def("serialize", &_OptimizeForInferenceOptions::serialize)
+        .def_static("deserialize", &_OptimizeForInferenceOptions::deserialize)
         .def_readwrite("f16_io_f32_comp", &_OptimizeForInferenceOptions::f16_io_f32_comp)
         .def_readwrite("f16_io_comp", &_OptimizeForInferenceOptions::f16_io_comp)
         .def_readwrite("fuse_conv_bias_nonlinearity", &_OptimizeForInferenceOptions::fuse_conv_bias_nonlinearity)
@@ -256,6 +259,7 @@ void init_graph_rt(py::module m) {
         .value("NCHW44_DOT", _LayoutTransform::NCHW44_DOT)
         .value("NCHW32", _LayoutTransform::NCHW32)
         .value("CHWN4", _LayoutTransform::CHWN4)
+        .value("NCHW64", _LayoutTransform::NCHW64)
         .export_values()
         ;
@@ -307,12 +311,24 @@ void init_graph_rt(py::module m) {
         })->to_string();
     });

+    py::class_<_SerializationMetadata>(m, "SerializationMetadata")
+        .def(py::init())
+        .def_property("user_info",
+                      [](const _SerializationMetadata& meta) { return py::bytes(meta.get_user_info()); },
+                      &_SerializationMetadata::set_user_info)
+        .def_readonly("optimized_for_inference", &_SerializationMetadata::optimized_for_inference)
+        .def_property("optimize_options", &_SerializationMetadata::get_optimize_options,
+                      &_SerializationMetadata::set_optimize_options)
+        .def_readwrite("graph_modified", &_SerializationMetadata::graph_modified)
+        .def_readwrite("is_valid", &_SerializationMetadata::is_valid)
+        ;
+
     m.def("dump_graph", [](
         const std::vector<VarNode*>& dest_vars,
         int keep_var_name,
         bool keep_opr_name,
         bool keep_param_name,
         bool keep_opr_priority,
+        std::optional<_SerializationMetadata> metadata,
         py::list& stat,
         py::list& inputs,
         py::list& outputs,
@@ -325,7 +341,12 @@ void init_graph_rt(py::module m) {
         ser::GraphDumper::DumpConfig config{keep_var_name, keep_param_name,
                                             keep_opr_priority, keep_opr_name};
-        auto rst = dumper->dump(symvars, config);
+        ser::GraphDumper::DumpResult rst;
+        if (metadata)
+            rst = dumper->dump(symvars, config, *metadata);
+        else
+            rst = dumper->dump(symvars, config);
         for (auto i : rst.inputs) {
             inputs.append(py::cast(i));
         }
@@ -377,8 +398,10 @@ void init_graph_rt(py::module m) {
         for (const auto& var : rst.output_var_list) {
             iter.add(var);
         }
-        return rst.graph;
+        auto ret = py::tuple(2);
+        ret[0] = py::cast(rst.graph);
+        ret[1] = py::cast(rst.metadata);
+        return ret;
     });

 #define CURRENT_CLASS cg::ComputingGraph::Options
......
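These bindings are what the Python layers above call into; a hedged sketch of the exposed surface (the import path is an assumption based on the imports elsewhere in this diff):

    # Sketch only: illustrates the bound names, not a verbatim API reference.
    from megengine.core._imperative_rt import SerializationMetadata

    meta = SerializationMetadata()
    meta.is_valid = True
    meta.user_info = b"owner=team-x"   # stored as a string, read back as py::bytes
    meta.optimize_options = 0          # the setter also flips optimized_for_inference
    assert meta.optimized_for_inference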
@@ -239,8 +239,7 @@ def test_dump_volatile():
     file = io.BytesIO()
     f.dump(file, optimize_for_inference=False)
     file.seek(0)
-    cg, _, outputs = G.load_graph(file)
-    (out,) = outputs
+    (out,) = G.load_graph(file).output_vars_list
     assert (
         cgtools.get_owner_opr_type(cgtools.get_owner_opr_inputs(out)[1])
         == "ImmutableTensor"
@@ -337,12 +336,12 @@ def test_goptions_log_exp():
     f(tensor(1.0))
     _, out = mkstemp()
     f.dump(out, optimize_for_inference=False)
-    *_, outputs = G.load_graph(out)
+    outputs = G.load_graph(out).output_vars_list
     oprs_1 = cgtools.get_oprs_seq(outputs)

     g(tensor(1.0))
     g.dump(out, optimize_for_inference=False)
-    *_, outputs = G.load_graph(out)
+    outputs = G.load_graph(out).output_vars_list
     oprs_2 = cgtools.get_oprs_seq(outputs)
     assert len(oprs_1) - len(oprs_2) == 2
......
@@ -88,7 +88,7 @@ def test_graph_traversal():
     file = io.BytesIO()
     fun.dump(file, optimize_for_inference=False)
     file.seek(0)
-    cg, _, outputs = mgb_graph.load_graph(file)
+    outputs = mgb_graph.load_graph(file).output_vars_list
     _, map_vars, var2oprs, *_ = cgtools.graph_traversal(outputs)
     input_var = map_vars[1]
@@ -101,7 +101,9 @@ def test_load_refcnt():
     graph = mgb_graph.Graph()
     varnode = graph.make_const(0)
     buf, _ = mgb_graph.dump_graph([varnode])
-    graph, _, (varnode,) = mgb_graph.load_graph(io.BytesIO(buf))
+    ret = mgb_graph.load_graph(io.BytesIO(buf))
+    graph, (varnode,) = ret.graph, ret.output_vars_list
+    del ret
     del graph
     varnode.owner
@@ -132,7 +134,7 @@ def test_get_opr_seq():
     file = io.BytesIO()
     func.dump(file, optimize_for_inference=False)
     file.seek(0)
-    *_, outputs = mgb_graph.load_graph(file)
+    outputs = mgb_graph.load_graph(file).output_vars_list
     seq_1 = cgtools.get_oprs_seq(outputs, True)
     assert len(seq_1) == 5
......
@@ -35,7 +35,7 @@ def _dump_and_load(func, symbolic, keep_opr_name=True):
         keep_var_name=2,
     )
     file.seek(0)
-    *_, outputs = G.load_graph(file)
+    outputs = G.load_graph(file).output_vars_list
     ops = cgtools.get_oprs_seq(outputs)
     return ops
@@ -223,7 +223,7 @@ def test_catch_input_name(tensor_name, var_name):
     file = io.BytesIO()
     func.dump(file, optimize_for_inference=False, keep_opr_name=True, keep_var_name=2)
     file.seek(0)
-    *_, outputs = G.load_graph(file)
+    outputs = G.load_graph(file).output_vars_list
     op = cgtools.get_oprs_seq(outputs)[-1]
     assert op.inputs[0].name == var_name
......
@@ -14,6 +14,50 @@ from megengine.utils.network import as_oprnode, set_symbolic_shape
 from megengine.utils.network_node import Host2DeviceCopy, VarNode


+def test_metadata():
+    x = Tensor(0)
+
+    @trace(symbolic=True, capture_as_const=True)
+    def fwd(x):
+        return x * 2
+
+    fwd(x)
+    orig_model = io.BytesIO()
+    fwd.dump(orig_model, user_info="test", optimize_for_inference=False)
+    orig_model.seek(0)
+    graph = Net.load(orig_model)
+    assert graph.metadata == {
+        "user_info": "test",
+        "graph_modified": False,  # False: tracing.dump
+        "optimized_for_inference": False,
+    }
+
+    orig_model.seek(0)
+    graph.dump(
+        orig_model,
+        user_info={"str": "x", "tensor": x, "module": M.Module, "none": None},
+        optimize_for_inference=True,
+        enable_nchw4=True,
+        enable_ioc16=True,
+    )
+    orig_model.seek(0)
+    graph = Net.load(orig_model)
+    assert graph.metadata == {
+        "user_info": {"str": "x", "tensor": x, "module": M.Module, "none": None},
+        "graph_modified": True,  # True: Network.dump
+        "optimized_for_inference": True,
+        "enable_nchw4": True,
+        "enable_ioc16": True,
+    }
+
+    orig_model.seek(0)
+    fwd.dump(orig_model, enable_metadata=False)
+    orig_model.seek(0)
+    graph = Net.load(orig_model)
+    assert graph.metadata is None
+
+
 def test_replace_var():
     a = Tensor([1, 2])
......
@@ -170,7 +170,8 @@ def gen_one_testcase(args, inputs, spec):

 def make_feeds(args):
-    cg_rt, _, outputs = G.load_graph(args.input)
+    ret = G.load_graph(args.input)
+    cg_rt, outputs = ret.graph, ret.output_vars_list
     inputs = cgtools.get_dep_vars(outputs, "Host2DeviceCopy")
     inputs = {i.name: i for i in inputs}
......
@@ -322,7 +322,31 @@ namespace gopt {
         static std::unique_ptr<EnableNchw44DotPass> make_nchw44_dot_converter();
     };

-    struct OptimizeForInferenceOptions : cg::GraphCommonOptimizeOptions {};
+    struct OptimizeForInferenceOptions : cg::GraphCommonOptimizeOptions {
+        uint64_t serialize() {
+            uint64_t ret = 0;
+            ret |= (uint64_t)layout_transform << 32;
+            if (f16_io_f32_comp) ret |= 1u;
+            if (f16_io_comp) ret |= 1u << 1;
+            if (fuse_conv_bias_nonlinearity) ret |= 1u << 2;
+            if (fuse_conv_bias_with_z) ret |= 1u << 3;
+            if (weight_preprocess) ret |= 1u << 4;
+            if (fuse_preprocess) ret |= 1u << 5;
+            return ret;
+        }
+
+        static OptimizeForInferenceOptions deserialize(uint64_t buf) {
+            OptimizeForInferenceOptions ret;
+            ret.f16_io_f32_comp = buf & 1u;
+            ret.f16_io_comp = buf & 1u << 1;
+            ret.fuse_conv_bias_nonlinearity = buf & 1u << 2;
+            ret.fuse_conv_bias_with_z = buf & 1u << 3;
+            ret.weight_preprocess = buf & 1u << 4;
+            ret.fuse_preprocess = buf & 1u << 5;
+            ret.layout_transform = (LayoutTransform)(buf >> 32);
+            return ret;
+        }
+    };

     /*!
      * \brief optimize a computing graph for inference
......
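The layout enum lives in the high 32 bits and the boolean passes in the low bits; the Python sketch below mirrors that packing (flag positions copied from `serialize()` above; it is an illustration, not the shipped code):

    # Bit layout: [63..32] layout_transform enum, [5..0] boolean pass flags.
    FLAG_BITS = {
        "f16_io_f32_comp": 0,
        "f16_io_comp": 1,
        "fuse_conv_bias_nonlinearity": 2,
        "fuse_conv_bias_with_z": 3,
        "weight_preprocess": 4,
        "fuse_preprocess": 5,
    }

    def pack(layout_transform: int, **flags) -> int:
        ret = layout_transform << 32
        for name, on in flags.items():
            if on:
                ret |= 1 << FLAG_BITS[name]
        return ret

    def unpack(buf: int):
        layout = buf >> 32
        return layout, {k: bool(buf & (1 << b)) for k, b in FLAG_BITS.items()}

    assert unpack(pack(3, f16_io_comp=True))[1]["f16_io_comp"]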
@@ -128,6 +128,13 @@ table Operator {
     name:string;
 }

+table Metadata {
+    is_valid:bool;
+    graph_modified:bool;
+    user_info:string;
+    optimize_options:ulong;
+}
+
 struct OutputVar {
     compact_id:uint;
     original_id:uint;
@@ -141,6 +148,7 @@ table Graph {
     nr_shared_tensor:uint;
     oprs:[Operator];
     output_vars_idx:[OutputVar];
+    metadata:Metadata;
 }

 root_type Graph;
@@ -30,6 +30,7 @@
 #include "megbrain/serialization/internal/flatbuffers_helper.h"
 #include "megbrain/serialization/internal/schema_generated.h"
 #include "megbrain/serialization/opr_load_dump.h"
+#include "megbrain/serialization/metadata.h"
 #include "megbrain/serialization/serializer.h"
 #include "megbrain/version.h"
@@ -115,6 +116,7 @@ class GraphDumperOSS final : public GraphDumper, OprDumpContextFlatBuffers {
     std::vector<flatbuffers::Offset<void>> m_cur_opr_param;

     void init_oprs_to_dump(const SymbolVarArray& endpoints);
+    flatbuffers::Offset<fbs::Metadata> build_metadata(const Metadata& metadata);
     flatbuffers::Offset<fbs::Operator> build_single_opr(
             cg::OperatorNodeBase* opr, const OprRegistry* registry);
@@ -123,7 +125,8 @@ class GraphDumperOSS final : public GraphDumper, OprDumpContextFlatBuffers {
 public:
     GraphDumperOSS(std::unique_ptr<OutputFile> file) : m_file{std::move(file)} {}
     DumpResult dump(const SymbolVarArray& output_vars,
-                    const DumpConfig& config = {}) override;
+                    const DumpConfig& config = {},
+                    const Metadata& metadata = {}) override;
     const GraphDumpConfig& config() const override { return m_config; }
     void dump_tensor(const std::string& name, const HostTensorND& tensor,
                      TensorWriteMethod method) override;
@@ -185,6 +188,17 @@ void GraphDumperOSS::init_oprs_to_dump(const SymbolVarArray& endpoints) {
     }
 }

+flatbuffers::Offset<fbs::Metadata> GraphDumperOSS::build_metadata(
+        const Metadata& metadata) {
+    auto user_info = m_builder.CreateSharedString(metadata.user_info);
+    fbs::MetadataBuilder builder(m_builder);
+    builder.add_is_valid(metadata.is_valid);
+    builder.add_graph_modified(metadata.graph_modified);
+    builder.add_user_info(user_info);
+    builder.add_optimize_options(metadata.optimize_options);
+    return builder.Finish();
+}
+
 flatbuffers::Offset<fbs::Operator> GraphDumperOSS::build_single_opr(
         cg::OperatorNodeBase* opr, const OprRegistry* registry) {
     m_cur_opr = opr;
@@ -282,7 +296,8 @@ flatbuffers::Offset<fbs::Operator> GraphDumperOSS::build_single_opr(
 }

 GraphDumper::DumpResult GraphDumperOSS::dump(
-        const SymbolVarArray& output_vars, const DumpConfig& config) {
+        const SymbolVarArray& output_vars,
+        const DumpConfig& config, const Metadata& metadata) {
     mgb_throw_if(output_vars.empty(), SerializationError,
                  "Can't dump empty graph");
@@ -323,6 +338,9 @@ GraphDumper::DumpResult GraphDumperOSS::dump(
     uint64_t offset_to_fbs = 0;
     m_file->write(&offset_to_fbs, sizeof(offset_to_fbs));

+    // Dump metadata
+    auto fbmeta = build_metadata(metadata);
+
     // Dump operators
     init_oprs_to_dump(output_vars);
     std::vector<flatbuffers::Offset<fbs::Operator>> oprs;
@@ -350,6 +368,7 @@ GraphDumper::DumpResult GraphDumperOSS::dump(
     graph.add_oprs(fb_oprs);
     graph.add_output_vars_idx(fb_output_vars);
     graph.add_nr_shared_tensor(m_nr_shared_tensor);
+    graph.add_metadata(fbmeta);
     m_builder.FinishSizePrefixed(graph.Finish(), fbs::GraphIdentifier());

     // Write actual offset_to_fbs
@@ -531,6 +550,7 @@ public:
         mgb_assert(nr == 1);
     }

+    Metadata load_metadata();
     LoadResult load_oprs();
     CompNode load_comp_node(const fbs::CompNode* comp_node);
@@ -700,6 +720,22 @@ GraphLoaderOSS::OprLoadContextImpl::load_tensor_shared() {
     return sh_ptr_ref;
 }

+Metadata GraphLoaderOSS::OprLoadContextImpl::load_metadata() {
+    const auto* fbmeta = m_loader->m_graph->metadata();
+    Metadata ret;
+    ret.is_valid = fbmeta->is_valid();
+    ret.graph_modified = fbmeta->graph_modified();
+    if (fbmeta->user_info()) {
+        ret.user_info = fbmeta->user_info()->str();
+        ret.has_user_info = true;
+    }
+    if (fbmeta->optimize_options()) {
+        ret.optimize_options = fbmeta->optimize_options();
+        ret.optimized_for_inference = true;
+    }
+    return ret;
+}
+
 void GraphLoaderOSS::OprLoadContextImpl::load_single_opr(
         const fbs::Operator* fbopr) {
     m_cur_opr_tensor_cnt = 0;
@@ -872,7 +908,9 @@ GraphLoader::LoadResult GraphLoaderOSS::load(const LoadConfig& config,
     }

     OprLoadContextImpl ctx{this, m_graph->mgb_version()};
+    auto metadata = ctx.load_metadata();
     auto result = ctx.load_oprs();
+    result.metadata = metadata;
     auto fbs_end = tensor_begin + offset_to_fbs + sizeof(size) + size;
     auto cur = m_file->tell();
......
/**
 * \file src/serialization/include/megbrain/serialization/metadata.h
 *
 * This file is part of MegBrain, a deep learning framework developed by Megvii.
 *
 * \brief MegEngine model's metadata
 *
 * \copyright Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
 */
#pragma once

#include <string>

namespace mgb {
namespace serialization {

struct Metadata {
    bool is_valid = false;
    bool graph_modified = false;

    bool has_user_info = false;
    std::string user_info;

    bool optimized_for_inference = false;
    uint64_t optimize_options;

#define ADD_PROPERTY(type, name)                  \
    type get_##name() const { return name; }     \
    void set_##name(type x) {                     \
        name = x;                                 \
        has_##name = true;                        \
    }

    ADD_PROPERTY(std::string, user_info)

#undef ADD_PROPERTY

    uint64_t get_optimize_options() { return optimize_options; }
    void set_optimize_options(uint64_t value) {
        optimized_for_inference = true;
        optimize_options = value;
    }
};

}  // namespace serialization
}  // namespace mgb
\ No newline at end of file
@@ -15,6 +15,7 @@
 #include "megbrain/serialization/dump_format.h"
 #include "megbrain/serialization/file.h"
 #include "megbrain/serialization/load_dump_config.h"
+#include "megbrain/serialization/metadata.h"

 namespace mgb {
 namespace serialization {
@@ -32,6 +33,9 @@ namespace serialization {
             //! explicit dtor decl to reduce binary size
             ~LoadResult() noexcept;

+            //! metadata
+            Metadata metadata;
+
             using TensorMap = std::unordered_map<
                     std::string, std::shared_ptr<HostTensorND>>;
@@ -178,7 +182,8 @@ namespace serialization {
             virtual DumpResult dump(
                     const SymbolVarArray &output_vars,
-                    const DumpConfig &config = {}) = 0;
+                    const DumpConfig &config = {},
+                    const Metadata &metadata = {}) = 0;

             virtual GraphDumpFormat format() const = 0;
     };
......
@@ -92,6 +92,43 @@ TEST(TestSerializer2, MultiGraphDumpLoad) {
     load();
 }

+TEST(TestSerializer2, Metadata) {
+    auto fname = GET_OUTPUT_FILE();
+    TensorShape shape{2, 3};
+
+    auto dump = [&]() {
+        auto cn = CompNode::load("xpu0");
+        auto host_x = std::make_shared<HostTensorND>(cn, shape),
+             host_y = std::make_shared<HostTensorND>(cn, shape);
+        auto graph = ComputingGraph::make();
+        auto x = opr::Host2DeviceCopy::make(*graph, host_x, {"x"}),
+             y = opr::Host2DeviceCopy::make(*graph, host_y, {"y"});
+        using Mode = opr::Elemwise::Mode;
+        auto z = opr::Elemwise::make({x, y}, Mode::ADD, {"add(x, y)"});
+
+        Metadata metadata;
+        metadata.user_info = "TEST_METADATA";
+        metadata.has_user_info = true;
+
+        auto dumper = GraphDumper::make(OutputFile::make_fs(fname.c_str()),
+                                        GraphDumpFormat::FLATBUFFERS);
+        auto rst = dumper->dump({z.rename("z")}, {}, metadata);
+    };
+
+    auto load = [&]() {
+        HostTensorGenerator<> gen;
+        auto loader = GraphLoader::make(InputFile::make_fs(fname.c_str()),
+                                        GraphDumpFormat::FLATBUFFERS);
+        auto rst = loader->load();
+        auto metadata = rst.metadata;
+        int cmp = strcmp(metadata.user_info.c_str(), "TEST_METADATA");
+        EXPECT_EQ(cmp, 0);
+    };
+
+    dump();
+    load();
+}
+
 TEST(TestSerializer2, APlusB) {
     auto fname = GET_OUTPUT_FILE();
     TensorShape shape{2, 3};
......