Commit 54a4d70e authored by Megvii Engine Team, committed by huangxinda

feat(src/serialization): add support for serializing metadata

GitOrigin-RevId: b563c94451b06055d53c99a85bb5689b3f907365
Parent 721091fa
......@@ -11,7 +11,7 @@ import json
import os
import weakref
from concurrent.futures import ThreadPoolExecutor
from typing import Dict, List, Tuple, Union
from typing import Dict, List, Optional, Tuple, Union
import numpy as np
......@@ -256,6 +256,9 @@ def optimize_for_inference(dest_vars, **kwargs):
* enable_chwn4 --
whether to use CHWN4 data layout, currently
used in nvidia backend with tensorcore.
* enable_nchw64 --
whether to use NCHW64 data layout, used for fast int4
support on Nvidia GPU.
* enable_fuse_conv_bias_nonlinearity: whether to fuse conv+bias+nonlinearity
into one opr.
......@@ -273,6 +276,7 @@ def optimize_for_inference(dest_vars, **kwargs):
"enable_nchw44": GraphOptimizeOptions.LayoutTransform.NCHW44,
"enable_nchw44_dot": GraphOptimizeOptions.LayoutTransform.NCHW44_DOT,
"enable_chwn4": GraphOptimizeOptions.LayoutTransform.CHWN4,
"enable_nchw64": GraphOptimizeOptions.LayoutTransform.NCHW64,
}
for k, v in inference_optimize_layout_transform_map.items():
......@@ -293,7 +297,46 @@ def optimize_for_inference(dest_vars, **kwargs):
dest_vars = _unwrap(dest_vars)
res_vars = _imperative_rt.optimize_for_inference(dest_vars, inference_options)
return _wrap(res_vars)
return _wrap(res_vars), inference_options.serialize()
def deserialize_infer_option(x: int) -> Dict[str, bool]:
r"""
Deserialize optimize options generated by ``imperative_rt.GraphOptimizeOptions``.
:param x: inference options represented as an int.
:return: inference options represented as a dict.
"""
inference_options = GraphOptimizeOptions.deserialize(x)
inference_optimize_layout_transform_map = {
GraphOptimizeOptions.LayoutTransform.NHWCD4: "enable_hwcd4",
GraphOptimizeOptions.LayoutTransform.NCHW4: "enable_nchw4",
GraphOptimizeOptions.LayoutTransform.NCHW88: "enable_nchw88",
GraphOptimizeOptions.LayoutTransform.NCHW32: "enable_nchw32",
GraphOptimizeOptions.LayoutTransform.NCHW44: "enable_nchw44",
GraphOptimizeOptions.LayoutTransform.NCHW44_DOT: "enable_nchw44_dot",
GraphOptimizeOptions.LayoutTransform.CHWN4: "enable_chwn4",
GraphOptimizeOptions.LayoutTransform.NCHW64: "enable_nchw64",
}
ret = dict()
layout = inference_options.layout_transform
if layout != GraphOptimizeOptions.LayoutTransform.DEFAULT:
ret[inference_optimize_layout_transform_map[layout]] = True
if inference_options.f16_io_f32_comp:
ret["enable_io16xc32"] = True
if inference_options.f16_io_comp:
ret["enable_ioc16"] = True
if inference_options.fuse_conv_bias_nonlinearity:
ret["enable_fuse_conv_bias_nonlinearity"] = True
if inference_options.fuse_conv_bias_with_z:
ret["enable_fuse_conv_bias_with_z"] = True
return ret
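A round-trip sketch of the two helpers above, assuming `dest_vars` already holds the output VarNodes of a graph (the import path mirrors this module):

import megengine.core.tensor.megbrain_graph as G

# optimize_for_inference now also returns the chosen options packed as an int
dest_vars, packed = G.optimize_for_inference(dest_vars, enable_nchw4=True)

# recover the options as a {kwarg_name: True} dict
opts = G.deserialize_infer_option(packed)
# opts should be {"enable_nchw4": True} here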
def modify_opr_algo_strategy_inplace(dest_vars, strategy: str):
......@@ -331,7 +374,8 @@ def dump_graph(
keep_param_name: bool = False,
keep_opr_priority: bool = False,
strip_info_file=None,
append_json=False
append_json=False,
metadata=None
) -> Tuple[bytes, CompGraphDumpResult]:
"""
Serialize the computing graph of `output_vars` and return the serialized bytes.
......@@ -393,6 +437,7 @@ def dump_graph(
keep_opr_name,
keep_param_name,
keep_opr_priority,
metadata,
stat,
inputs,
outputs,
......@@ -427,7 +472,7 @@ def dump_graph(
CompGraphLoadResult = collections.namedtuple(
"CompGraphLoadResult", ["graph", "output_vars_dict", "output_vars_list"]
"CompGraphLoadResult", ["graph", "output_vars_dict", "output_vars_list", "metadata"]
)
......@@ -450,8 +495,8 @@ def load_graph(fpath) -> CompGraphLoadResult:
buf = open(fpath, "rb").read()
else:
buf = fpath.read()
cg = _imperative_rt.load_graph(buf, output_vars_map, output_vars_list)
return CompGraphLoadResult(cg, dict(output_vars_map), output_vars_list)
cg, metadata = _imperative_rt.load_graph(buf, output_vars_map, output_vars_list)
return CompGraphLoadResult(cg, dict(output_vars_map), output_vars_list, metadata)
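Callers now unpack four fields from the returned named tuple; a minimal sketch (file name illustrative):

import megengine.core.tensor.megbrain_graph as G

ret = G.load_graph("model.mge")  # hypothetical model file
graph, outputs, metadata = ret.graph, ret.output_vars_list, ret.metadata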
def _wrap(x):
......
......@@ -12,10 +12,12 @@ import functools
import itertools
import json
import os
import pickle
from typing import Any
import numpy as np
from ..core._imperative_rt import GraphProfiler
from ..core._imperative_rt import GraphProfiler, SerializationMetadata
from ..core._imperative_rt.core2 import Tensor as RawTensor
from ..core._imperative_rt.core2 import (
TensorWeakRef,
......@@ -670,6 +672,8 @@ class trace:
strip_info_file=None,
append_json=False,
optimize_for_inference=True,
user_info: Any = None,
enable_metadata: bool = True,
**kwargs
):
r"""
......@@ -697,6 +701,8 @@ class trace:
if set false, will rewrite strip_info_file
:param optimize_for_inference: enable optimizations,
will skip all optimize options if this is False. Default: True
:param user_info: an object of any type, which will be pickled to bytes.
:param enable_metadata: whether to save metadata into the output file.
:Keyword Arguments:
......@@ -729,6 +735,9 @@ class trace:
* enable_chwn4 --
whether to use CHWN4 data layout, currently
used in nvidia backend with tensorcore.
* enable_nchw64 --
whether to use NCHW64 data layout, used for fast int4
support on Nvidia GPU.
* enable_fuse_conv_bias_nonlinearity: whether to fuse conv+bias+nonlinearity
into one opr.
......@@ -851,7 +860,15 @@ class trace:
dest_vars.append(v)
if optimize_for_inference:
dest_vars = G.optimize_for_inference(dest_vars, **kwargs)
dest_vars, optimize_options = G.optimize_for_inference(dest_vars, **kwargs)
metadata = SerializationMetadata()
if enable_metadata:
metadata.user_info = pickle.dumps(user_info)
metadata.is_valid = True
metadata.graph_modified = False
if optimize_for_inference:
metadata.optimize_options = optimize_options
if isinstance(file, str):
permission = "wb" if append == False else "ab"
......@@ -864,6 +881,7 @@ class trace:
keep_opr_priority=keep_opr_priority,
strip_info_file=strip_info_file,
append_json=append_json,
metadata=metadata,
)
file.write(dump_content)
return dump_info
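A sketch of the new dump parameters in use, attaching arbitrary user data to a traced function (values illustrative):

import io
from megengine import Tensor
from megengine.jit import trace

@trace(symbolic=True, capture_as_const=True)
def fwd(x):
    return x * 2

fwd(Tensor(1.0))
buf = io.BytesIO()
# user_info is pickled and stored in the metadata section of the model
fwd.dump(buf, user_info={"version": 1}, optimize_for_inference=False)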
......
......@@ -411,7 +411,8 @@ def main():
args.embed_input = True
logger.info("loading model ...")
graph, _, output_vars = G.load_graph(args.net)
ret = G.load_graph(args.net)
graph, output_vars = ret.graph, ret.output_vars_list
input_vars = tools.get_dep_vars(output_vars, "Host2DeviceCopy")
if args.output_name is not None:
......
......@@ -391,7 +391,8 @@ class GraphInference:
optimize_for_inference: bool = False,
**kwargs
):
self._graph, _, output_nodes = G.load_graph(file)
ret = G.load_graph(file)
self._graph, output_nodes = ret.graph, ret.output_vars_list
if outputs is not None:
output_nodes = find_vars_by_name(output_nodes, outputs)
self._origin_outputs = output_nodes
......
......@@ -9,14 +9,12 @@
import collections
import fnmatch
import itertools
import pickle
import re
from collections import OrderedDict
from typing import Dict, List, Sequence
from typing import Any, Dict, List, Sequence
import numpy as np
from ..core._imperative_rt import ComputingGraph
from ..core._imperative_rt.core2 import SymbolVar
from ..core._imperative_rt import ComputingGraph, SerializationMetadata
from ..core._trace_option import set_symbolic_shape as _set_symbolic_shape
from ..core.tensor import megbrain_graph as G
from ..logger import get_logger
......@@ -42,6 +40,30 @@ class Network:
self.all_oprs_map = OrderedDict()
self.all_vars_map = OrderedDict()
self.graph = ComputingGraph()
self._metadata = None
@property
def metadata(self):
r"""
Load metadata as a dict.
"""
if self._metadata is None or not self._metadata.is_valid:
logger.info("metadata is not valid!")
return None
ret = dict()
try:
user_info = pickle.loads(self._metadata.user_info)
except: # pylint: disable=bare-except
logger.warning(
"can't parse user info by pickle, so return the original bytes object!"
)
user_info = self._metadata.user_info
ret["user_info"] = user_info
ret["graph_modified"] = self._metadata.graph_modified
ret["optimized_for_inference"] = self._metadata.optimized_for_inference
if ret["optimized_for_inference"]:
ret.update(G.deserialize_infer_option(self._metadata.optimize_options))
return ret
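Reading the metadata back through Network, a minimal sketch (file name illustrative):

from megengine.utils.network import Network as Net

net = Net.load("model.mge")  # hypothetical model file
meta = net.metadata  # None when the model carries no valid metadata
if meta is not None:
    print(meta["user_info"], meta["graph_modified"])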
@classmethod
def load(cls, model_path: str, outspec: List[str] = None):
......@@ -51,7 +73,8 @@ class Network:
:param outspec: only load the subgraph with outspec as its endpoints.
"""
self = cls()
_, _, outputs = G.load_graph(model_path)
ret = G.load_graph(model_path)
outputs, self._metadata = ret.output_vars_list, ret.metadata
if outspec is not None:
output_spec = outspec.copy()
all_vars = get_dep_vars(outputs) + outputs
......@@ -125,6 +148,9 @@ class Network:
* enable_chwn4 --
whether to use CHWN4 data layout, currently
used in nvidia backend with tensorcore.
* enable_nchw64 --
whether to use NCHW64 data layout, used for fast int4
support on Nvidia GPU.
* enable_fuse_conv_bias_nonlinearity: whether to fuse conv+bias+nonlinearity
into one opr.
......@@ -152,6 +178,8 @@ class Network:
append_json=False,
optimize_for_inference=True,
append=False,
user_info: Any = None,
enable_metadata=True,
**kwargs
):
"""
......@@ -176,6 +204,8 @@ class Network:
if set false, will rewrite strip_info_file
:param optimize_for_inference: enable optimizations,
will skip all optimize options if this is False. Default: True
:param user_info: an object of any type, which will be pickled to bytes.
:param enable_metadata: whether to save metadata into the output file.
:Keyword Arguments:
......@@ -201,7 +231,15 @@ class Network:
)
if optimize_for_inference:
out = G.optimize_for_inference(out, **kwargs)
out, optimize_options = G.optimize_for_inference(out, **kwargs)
metadata = SerializationMetadata()
if enable_metadata:
metadata.is_valid = True
metadata.graph_modified = True
metadata.user_info = pickle.dumps(user_info)
if optimize_for_inference:
metadata.optimize_options = optimize_options
dump_content, _ = G.dump_graph(
out,
......@@ -211,6 +249,7 @@ class Network:
keep_opr_priority=keep_opr_priority,
strip_info_file=strip_info_file,
append_json=append_json,
metadata=metadata,
)
if isinstance(file, str):
permission = "wb" if append == False else "ab"
......
......@@ -34,6 +34,7 @@ namespace ser = mgb::serialization;
using _OptimizeForInferenceOptions = mgb::gopt::OptimizeForInferenceOptions;
using _LayoutTransform = _OptimizeForInferenceOptions::LayoutTransform;
using _AlgoStrategy = opr::mixin::AlgoChooserHelper::ExecutionPolicy::Strategy;
using _SerializationMetadata = mgb::serialization::Metadata;
namespace {
class _CompGraphProfilerImpl {
......@@ -240,6 +241,8 @@ void init_graph_rt(py::module m) {
auto GraphOptimizeOptions = py::class_<_OptimizeForInferenceOptions>(m, "GraphOptimizeOptions")
.def(py::init())
.def("serialize", &_OptimizeForInferenceOptions::serialize)
.def_static("deserialize", &_OptimizeForInferenceOptions::deserialize)
.def_readwrite("f16_io_f32_comp", &_OptimizeForInferenceOptions::f16_io_f32_comp)
.def_readwrite("f16_io_comp", &_OptimizeForInferenceOptions::f16_io_comp)
.def_readwrite("fuse_conv_bias_nonlinearity", &_OptimizeForInferenceOptions::fuse_conv_bias_nonlinearity)
......@@ -256,6 +259,7 @@ void init_graph_rt(py::module m) {
.value("NCHW44_DOT", _LayoutTransform::NCHW44_DOT)
.value("NCHW32", _LayoutTransform::NCHW32)
.value("CHWN4", _LayoutTransform::CHWN4)
.value("NCHW64", _LayoutTransform::NCHW64)
.export_values()
;
......@@ -307,12 +311,24 @@ void init_graph_rt(py::module m) {
})->to_string();
});
py::class_<_SerializationMetadata>(m, "SerializationMetadata")
.def(py::init())
.def_property("user_info", [](const _SerializationMetadata& meta){return py::bytes(meta.get_user_info()); },
&_SerializationMetadata::set_user_info)
.def_readonly("optimized_for_inference", &_SerializationMetadata::optimized_for_inference)
.def_property("optimize_options", &_SerializationMetadata::get_optimize_options,
&_SerializationMetadata::set_optimize_options)
.def_readwrite("graph_modified", &_SerializationMetadata::graph_modified)
.def_readwrite("is_valid", &_SerializationMetadata::is_valid)
;
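From Python the binding behaves like a plain record; a sketch of filling it by hand, matching how trace.dump uses it (values illustrative):

import pickle
from megengine.core._imperative_rt import SerializationMetadata

meta = SerializationMetadata()
meta.user_info = pickle.dumps({"author": "me"})  # stored as raw bytes
meta.graph_modified = False
meta.is_valid = True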
m.def("dump_graph", [](
const std::vector<VarNode*>& dest_vars,
int keep_var_name,
bool keep_opr_name,
bool keep_param_name,
bool keep_opr_priority,
std::optional<_SerializationMetadata> metadata,
py::list& stat,
py::list& inputs,
py::list& outputs,
......@@ -325,7 +341,12 @@ void init_graph_rt(py::module m) {
ser::GraphDumper::DumpConfig config{keep_var_name, keep_param_name,
keep_opr_priority, keep_opr_name};
auto rst = dumper->dump(symvars, config);
ser::GraphDumper::DumpResult rst;
if (metadata)
rst = dumper->dump(symvars, config, *metadata);
else
rst = dumper->dump(symvars, config);
for (auto i : rst.inputs) {
inputs.append(py::cast(i));
}
......@@ -377,8 +398,10 @@ void init_graph_rt(py::module m) {
for (const auto& var : rst.output_var_list) {
iter.add(var);
}
return rst.graph;
auto ret = py::tuple(2);
ret[0] = py::cast(rst.graph);
ret[1] = py::cast(rst.metadata);
return ret;
});
#define CURRENT_CLASS cg::ComputingGraph::Options
......
......@@ -239,8 +239,7 @@ def test_dump_volatile():
file = io.BytesIO()
f.dump(file, optimize_for_inference=False)
file.seek(0)
cg, _, outputs = G.load_graph(file)
(out,) = outputs
(out,) = G.load_graph(file).output_vars_list
assert (
cgtools.get_owner_opr_type(cgtools.get_owner_opr_inputs(out)[1])
== "ImmutableTensor"
......@@ -337,12 +336,12 @@ def test_goptions_log_exp():
f(tensor(1.0))
_, out = mkstemp()
f.dump(out, optimize_for_inference=False)
*_, outputs = G.load_graph(out)
outputs = G.load_graph(out).output_vars_list
oprs_1 = cgtools.get_oprs_seq(outputs)
g(tensor(1.0))
g.dump(out, optimize_for_inference=False)
*_, outputs = G.load_graph(out)
outputs = G.load_graph(out).output_vars_list
oprs_2 = cgtools.get_oprs_seq(outputs)
assert len(oprs_1) - len(oprs_2) == 2
......
......@@ -88,7 +88,7 @@ def test_graph_traversal():
file = io.BytesIO()
fun.dump(file, optimize_for_inference=False)
file.seek(0)
cg, _, outputs = mgb_graph.load_graph(file)
outputs = mgb_graph.load_graph(file).output_vars_list
_, map_vars, var2oprs, *_ = cgtools.graph_traversal(outputs)
input_var = map_vars[1]
......@@ -101,7 +101,9 @@ def test_load_refcnt():
graph = mgb_graph.Graph()
varnode = graph.make_const(0)
buf, _ = mgb_graph.dump_graph([varnode])
graph, _, (varnode,) = mgb_graph.load_graph(io.BytesIO(buf))
ret = mgb_graph.load_graph(io.BytesIO(buf))
graph, (varnode,) = ret.graph, ret.output_vars_list
del ret
del graph
varnode.owner
......@@ -132,7 +134,7 @@ def test_get_opr_seq():
file = io.BytesIO()
func.dump(file, optimize_for_inference=False)
file.seek(0)
*_, outputs = mgb_graph.load_graph(file)
outputs = mgb_graph.load_graph(file).output_vars_list
seq_1 = cgtools.get_oprs_seq(outputs, True)
assert len(seq_1) == 5
......
......@@ -35,7 +35,7 @@ def _dump_and_load(func, symbolic, keep_opr_name=True):
keep_var_name=2,
)
file.seek(0)
*_, outputs = G.load_graph(file)
outputs = G.load_graph(file).output_vars_list
ops = cgtools.get_oprs_seq(outputs)
return ops
......@@ -223,7 +223,7 @@ def test_catch_input_name(tensor_name, var_name):
file = io.BytesIO()
func.dump(file, optimize_for_inference=False, keep_opr_name=True, keep_var_name=2)
file.seek(0)
*_, outputs = G.load_graph(file)
outputs = G.load_graph(file).output_vars_list
op = cgtools.get_oprs_seq(outputs)[-1]
assert op.inputs[0].name == var_name
......
......@@ -14,6 +14,50 @@ from megengine.utils.network import as_oprnode, set_symbolic_shape
from megengine.utils.network_node import Host2DeviceCopy, VarNode
def test_metadata():
x = Tensor(0)
@trace(symbolic=True, capture_as_const=True)
def fwd(x):
return x * 2
fwd(x)
orig_model = io.BytesIO()
fwd.dump(orig_model, user_info="test", optimize_for_inference=False)
orig_model.seek(0)
graph = Net.load(orig_model)
assert graph.metadata == {
"user_info": "test",
"graph_modified": False, # False: tracing.dump
"optimized_for_inference": False,
}
orig_model.seek(0)
graph.dump(
orig_model,
user_info={"str": "x", "tensor": x, "module": M.Module, "none": None},
optimize_for_inference=True,
enable_nchw4=True,
enable_ioc16=True,
)
orig_model.seek(0)
graph = Net.load(orig_model)
assert graph.metadata == {
"user_info": {"str": "x", "tensor": x, "module": M.Module, "none": None},
"graph_modified": True, # True: Network.dump
"optimized_for_inference": True,
"enable_nchw4": True,
"enable_ioc16": True,
}
orig_model.seek(0)
fwd.dump(orig_model, enable_metadata=False)
orig_model.seek(0)
graph = Net.load(orig_model)
assert graph.metadata is None
def test_replace_var():
a = Tensor([1, 2])
......
......@@ -170,7 +170,8 @@ def gen_one_testcase(args, inputs, spec):
def make_feeds(args):
cg_rt, _, outputs = G.load_graph(args.input)
ret = G.load_graph(args.input)
cg_rt, outputs = ret.graph, ret.output_vars_list
inputs = cgtools.get_dep_vars(outputs, "Host2DeviceCopy")
inputs = {i.name: i for i in inputs}
......
......@@ -322,7 +322,31 @@ namespace gopt {
static std::unique_ptr<EnableNchw44DotPass> make_nchw44_dot_converter();
};
struct OptimizeForInferenceOptions : cg::GraphCommonOptimizeOptions {};
struct OptimizeForInferenceOptions : cg::GraphCommonOptimizeOptions {
uint64_t serialize() {
uint64_t ret = 0;
ret |= (uint64_t)layout_transform << 32;
if (f16_io_f32_comp) ret |= 1u;
if (f16_io_comp) ret |= 1u << 1;
if (fuse_conv_bias_nonlinearity) ret |= 1u << 2;
if (fuse_conv_bias_with_z) ret |= 1u << 3;
if (weight_preprocess) ret |= 1u << 4;
if (fuse_preprocess) ret |= 1u << 5;
return ret;
}
static OptimizeForInferenceOptions deserialize(uint64_t buf) {
OptimizeForInferenceOptions ret;
ret.f16_io_f32_comp = buf & 1u;
ret.f16_io_comp = buf & 1u << 1;
ret.fuse_conv_bias_nonlinearity = buf & 1u << 2;
ret.fuse_conv_bias_with_z = buf & 1u << 3;
ret.weight_preprocess = buf & 1u << 4;
ret.fuse_preprocess = buf & 1u << 5;
ret.layout_transform = (LayoutTransform)(buf >> 32);
return ret;
}
};
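The packed word keeps the boolean flags in the low six bits and the LayoutTransform enum value in the high 32 bits. A Python mirror of the same layout, shown for illustration only (names here are hypothetical):

# Bit positions follow OptimizeForInferenceOptions::serialize above.
_FLAG_BITS = {
    "f16_io_f32_comp": 0,
    "f16_io_comp": 1,
    "fuse_conv_bias_nonlinearity": 2,
    "fuse_conv_bias_with_z": 3,
    "weight_preprocess": 4,
    "fuse_preprocess": 5,
}

def unpack_options(buf: int) -> dict:
    opts = {name: bool(buf & (1 << bit)) for name, bit in _FLAG_BITS.items()}
    opts["layout_transform"] = buf >> 32  # enum value from the high 32 bits
    return opts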
/*!
* \brief optimize a computing graph for inference
......
......@@ -128,6 +128,13 @@ table Operator {
name:string;
}
table Metadata {
is_valid:bool;
graph_modified:bool;
user_info:string;
optimize_options:ulong;
}
struct OutputVar {
compact_id:uint;
original_id:uint;
......@@ -141,6 +148,7 @@ table Graph {
nr_shared_tensor:uint;
oprs:[Operator];
output_vars_idx:[OutputVar];
metadata:Metadata;
}
root_type Graph;
......@@ -30,6 +30,7 @@
#include "megbrain/serialization/internal/flatbuffers_helper.h"
#include "megbrain/serialization/internal/schema_generated.h"
#include "megbrain/serialization/opr_load_dump.h"
#include "megbrain/serialization/metadata.h"
#include "megbrain/serialization/serializer.h"
#include "megbrain/version.h"
......@@ -115,6 +116,7 @@ class GraphDumperOSS final : public GraphDumper, OprDumpContextFlatBuffers {
std::vector<flatbuffers::Offset<void>> m_cur_opr_param;
void init_oprs_to_dump(const SymbolVarArray& endpoints);
flatbuffers::Offset<fbs::Metadata> build_metadata(const Metadata& metadata);
flatbuffers::Offset<fbs::Operator> build_single_opr(
cg::OperatorNodeBase* opr, const OprRegistry* registry);
......@@ -123,7 +125,8 @@ class GraphDumperOSS final : public GraphDumper, OprDumpContextFlatBuffers {
public:
GraphDumperOSS(std::unique_ptr<OutputFile> file) : m_file{std::move(file)} {}
DumpResult dump(const SymbolVarArray& output_vars,
const DumpConfig& config = {}) override;
const DumpConfig& config = {},
const Metadata& metadata = {}) override;
const GraphDumpConfig& config() const override { return m_config; }
void dump_tensor(const std::string& name, const HostTensorND& tensor,
TensorWriteMethod method) override;
......@@ -185,6 +188,17 @@ void GraphDumperOSS::init_oprs_to_dump(const SymbolVarArray& endpoints) {
}
}
flatbuffers::Offset<fbs::Metadata> GraphDumperOSS::build_metadata(
const Metadata& metadata) {
auto user_info = m_builder.CreateSharedString(metadata.user_info);
fbs::MetadataBuilder builder(m_builder);
builder.add_is_valid(metadata.is_valid);
builder.add_graph_modified(metadata.graph_modified);
builder.add_user_info(user_info);
builder.add_optimize_options(metadata.optimize_options);
return builder.Finish();
}
flatbuffers::Offset<fbs::Operator> GraphDumperOSS::build_single_opr(
cg::OperatorNodeBase* opr, const OprRegistry* registry) {
m_cur_opr = opr;
......@@ -282,7 +296,8 @@ flatbuffers::Offset<fbs::Operator> GraphDumperOSS::build_single_opr(
}
GraphDumper::DumpResult GraphDumperOSS::dump(
const SymbolVarArray& output_vars, const DumpConfig& config) {
const SymbolVarArray& output_vars,
const DumpConfig& config, const Metadata& metadata) {
mgb_throw_if(output_vars.empty(), SerializationError,
"Can't dump empty graph");
......@@ -323,6 +338,9 @@ GraphDumper::DumpResult GraphDumperOSS::dump(
uint64_t offset_to_fbs = 0;
m_file->write(&offset_to_fbs, sizeof(offset_to_fbs));
// Dump metadata
auto fbmeta = build_metadata(metadata);
// Dump operators
init_oprs_to_dump(output_vars);
std::vector<flatbuffers::Offset<fbs::Operator>> oprs;
......@@ -350,6 +368,7 @@ GraphDumper::DumpResult GraphDumperOSS::dump(
graph.add_oprs(fb_oprs);
graph.add_output_vars_idx(fb_output_vars);
graph.add_nr_shared_tensor(m_nr_shared_tensor);
graph.add_metadata(fbmeta);
m_builder.FinishSizePrefixed(graph.Finish(), fbs::GraphIdentifier());
// Write actual offset_to_fbs
......@@ -531,6 +550,7 @@ public:
mgb_assert(nr == 1);
}
Metadata load_metadata();
LoadResult load_oprs();
CompNode load_comp_node(const fbs::CompNode* comp_node);
......@@ -700,6 +720,22 @@ GraphLoaderOSS::OprLoadContextImpl::load_tensor_shared() {
return sh_ptr_ref;
}
Metadata GraphLoaderOSS::OprLoadContextImpl::load_metadata() {
const auto* fbmeta = m_loader->m_graph->metadata();
Metadata ret;
// models dumped before metadata support carry no Metadata table;
// return the default (invalid) metadata instead of dereferencing null
if (!fbmeta) return ret;
ret.is_valid = fbmeta->is_valid();
ret.graph_modified = fbmeta->graph_modified();
if (fbmeta->user_info()) {
ret.user_info = fbmeta->user_info()->str();
ret.has_user_info = true;
}
if (fbmeta->optimize_options()) {
ret.optimize_options = fbmeta->optimize_options();
ret.optimized_for_inference = true;
}
return ret;
}
void GraphLoaderOSS::OprLoadContextImpl::load_single_opr(
const fbs::Operator* fbopr) {
m_cur_opr_tensor_cnt = 0;
......@@ -872,7 +908,9 @@ GraphLoader::LoadResult GraphLoaderOSS::load(const LoadConfig& config,
}
OprLoadContextImpl ctx{this, m_graph->mgb_version()};
auto metadata = ctx.load_metadata();
auto result = ctx.load_oprs();
result.metadata = metadata;
auto fbs_end = tensor_begin + offset_to_fbs + sizeof(size) + size;
auto cur = m_file->tell();
......
/**
* \file src/serialization/include/megbrain/serialization/metadata.h
*
* This file is part of MegBrain, a deep learning framework developed by Megvii.
*
* \brief MegEngine model's metadata
*
* \copyright Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
*/
#pragma once
#include <string>
namespace mgb {
namespace serialization {
struct Metadata {
bool is_valid = false;
bool graph_modified = false;
bool has_user_info = false;
std::string user_info;
bool optimized_for_inference = false;
uint64_t optimize_options = 0;
#define ADD_PROPERTY(type, name) \
type get_##name() const { return name; } \
void set_##name(type x) { \
name = x; \
has_##name = true; \
}
ADD_PROPERTY(std::string, user_info)
#undef ADD_PROPERTY
uint64_t get_optimize_options() { return optimize_options; }
void set_optimize_options(uint64_t value) {
optimized_for_inference = true;
optimize_options = value;
}
};
} // namespace serialization
} // namespace mgb
\ No newline at end of file
......@@ -15,6 +15,7 @@
#include "megbrain/serialization/dump_format.h"
#include "megbrain/serialization/file.h"
#include "megbrain/serialization/load_dump_config.h"
#include "megbrain/serialization/metadata.h"
namespace mgb {
namespace serialization {
......@@ -32,6 +33,9 @@ namespace serialization {
//! explicit dtor decl to reduce binary size
~LoadResult() noexcept;
//! metadata
Metadata metadata;
using TensorMap = std::unordered_map<
std::string, std::shared_ptr<HostTensorND>>;
......@@ -178,7 +182,8 @@ namespace serialization {
virtual DumpResult dump(
const SymbolVarArray &output_vars,
const DumpConfig &config = {}) = 0;
const DumpConfig &config = {},
const Metadata &metadata = {}) = 0;
virtual GraphDumpFormat format() const = 0;
};
......
......@@ -92,6 +92,43 @@ TEST(TestSerializer2, MultiGraphDumpLoad) {
load();
}
TEST(TestSerializer2, Metadata) {
auto fname = GET_OUTPUT_FILE();
TensorShape shape{2, 3};
auto dump = [&]() {
auto cn = CompNode::load("xpu0");
auto host_x = std::make_shared<HostTensorND>(cn, shape),
host_y = std::make_shared<HostTensorND>(cn, shape);
auto graph = ComputingGraph::make();
auto x = opr::Host2DeviceCopy::make(*graph, host_x, {"x"}),
y = opr::Host2DeviceCopy::make(*graph, host_y, {"y"});
using Mode = opr::Elemwise::Mode;
auto z = opr::Elemwise::make({x, y}, Mode::ADD, {"add(x, y)"});
Metadata metadata;
metadata.user_info = "TEST_METADATA";
metadata.has_user_info = true;
auto dumper = GraphDumper::make(OutputFile::make_fs(fname.c_str()),
GraphDumpFormat::FLATBUFFERS);
auto rst = dumper->dump({z.rename("z")}, {}, metadata);
};
auto load = [&]() {
HostTensorGenerator<> gen;
auto loader = GraphLoader::make(InputFile::make_fs(fname.c_str()),
GraphDumpFormat::FLATBUFFERS);
auto rst = loader->load();
auto metadata = rst.metadata;
int cmp = strcmp(metadata.user_info.c_str(), "TEST_METADATA");
EXPECT_EQ(cmp, 0);
};
dump();
load();
}
TEST(TestSerializer2, APlusB) {
auto fname = GET_OUTPUT_FILE();
TensorShape shape{2, 3};
......