提交 a404c508 编写于 作者: M Megvii Engine Team

feat(mge): support dump with specific format

GitOrigin-RevId: 57a7c0de02ec6ee30a67b5cc069dbdd7dc0f6437
上级 fba523a1
......@@ -11,13 +11,12 @@ import json
import os
import weakref
from concurrent.futures import ThreadPoolExecutor
from typing import Dict, List, Optional, Tuple, Union
from typing import Dict, List, Tuple, Union
import numpy as np
from .. import _imperative_rt
from .._imperative_rt import GraphOptimizeOptions
from .._imperative_rt.core2 import apply, set_cpp_apply_backward_varnode
from .._imperative_rt import GraphOptimizeOptions, SerializationFormat
from .._wrap import as_device
from ..ops.builtin import OpDef
......@@ -377,7 +376,8 @@ def dump_graph(
keep_opr_priority: bool = False,
strip_info_file=None,
append_json=False,
metadata=None
metadata=None,
dump_format=None
) -> Tuple[bytes, CompGraphDumpResult]:
r"""serialize the computing graph of `output_vars` and get byte result.
......@@ -398,6 +398,7 @@ def dump_graph(
append_json: will be check when `strip_info_file` is not None. if set
true, the information for code strip will be append to strip_info_file.
if set false, will rewrite strip_info_file
dump_format: using different dump formats.
Note:
The underlying C++ API only accepts a var list. If a dict is given,
......@@ -434,6 +435,12 @@ def dump_graph(
outputs = []
params = []
dump_format_map = {
None: None,
"FBS": SerializationFormat.FBS,
}
dump_format = dump_format_map[dump_format]
dump_content = _imperative_rt.dump_graph(
ov,
keep_var_name,
......@@ -441,6 +448,7 @@ def dump_graph(
keep_param_name,
keep_opr_priority,
metadata,
dump_format,
stat,
inputs,
outputs,
......
......@@ -1008,6 +1008,7 @@ class trace:
maxerr=1e-4,
resize_input=False,
input_transform=None,
dump_format: str = None,
**kwargs
):
r"""Serializes trace to file system.
......@@ -1059,6 +1060,7 @@ class trace:
resize_input: whether resize input image to fit input var shape.
input_transform: a python expression to transform the input data.
Example: data / np.std(data)
dump_format: using different dump formats.
Keyword Arguments:
......@@ -1265,6 +1267,7 @@ class trace:
strip_info_file=strip_info_file,
append_json=append_json,
metadata=metadata,
dump_format=dump_format,
)
file.write(dump_content)
......
......@@ -35,6 +35,7 @@ using _OptimizeForInferenceOptions = mgb::gopt::OptimizeForInferenceOptions;
using _LayoutTransform = _OptimizeForInferenceOptions::LayoutTransform;
using _AlgoStrategy = opr::mixin::AlgoChooserHelper::ExecutionPolicy::Strategy;
using _SerializationMetadata = mgb::serialization::Metadata;
using _SerializationFormat = mgb::serialization::GraphDumpFormat;
namespace {
class _CompGraphProfilerImpl {
......@@ -310,6 +311,10 @@ void init_graph_rt(py::module m) {
.value("NCHW64", _LayoutTransform::NCHW64)
.export_values();
py::enum_<_SerializationFormat>(m, "SerializationFormat")
.value("FBS", _SerializationFormat::FLATBUFFERS)
.export_values();
m.def("optimize_for_inference",
[](const VarNodeArray& dest_vars, const _OptimizeForInferenceOptions& opt) {
SymbolVarArray symvars(dest_vars.begin(), dest_vars.end());
......@@ -380,11 +385,18 @@ void init_graph_rt(py::module m) {
m.def("dump_graph",
[](const std::vector<VarNode*>& dest_vars, int keep_var_name,
bool keep_opr_name, bool keep_param_name, bool keep_opr_priority,
std::optional<_SerializationMetadata> metadata, py::list& stat,
std::optional<_SerializationMetadata> metadata,
std::optional<_SerializationFormat> dump_format, py::list& stat,
py::list& inputs, py::list& outputs, py::list& params) {
std::vector<uint8_t> buf;
auto dumper =
ser::GraphDumper::make(ser::OutputFile::make_vector_proxy(&buf));
ser::GraphDumpFormat format;
if (dump_format.has_value()) {
format = dump_format.value();
} else {
format = {};
}
auto dumper = ser::GraphDumper::make(
ser::OutputFile::make_vector_proxy(&buf), format);
SymbolVarArray symvars(dest_vars.begin(), dest_vars.end());
ser::GraphDumper::DumpConfig config{
......
......@@ -190,7 +190,13 @@ def test_print_in_trace():
np.testing.assert_equal(z, buf)
def test_dump():
@pytest.mark.parametrize(
"dump_format",
[
"FBS",
],
)
def test_dump(dump_format):
@trace(symbolic=True, capture_as_const=True)
def f(a, b):
return a + b
......@@ -205,7 +211,7 @@ def test_dump():
np.testing.assert_equal(f(a, b).numpy(), y)
file = io.BytesIO()
dump_info = f.dump(file)
dump_info = f.dump(file, dump_format=dump_format)
assert dump_info.nr_opr == 3
np.testing.assert_equal(dump_info.inputs, ["arg_0", "arg_1"])
np.testing.assert_equal(dump_info.outputs, ["ADD"])
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册