提交 a404c508 编写于 作者: M Megvii Engine Team

feat(mge): support dump with specific format

GitOrigin-RevId: 57a7c0de02ec6ee30a67b5cc069dbdd7dc0f6437
上级 fba523a1
...@@ -11,13 +11,12 @@ import json ...@@ -11,13 +11,12 @@ import json
import os import os
import weakref import weakref
from concurrent.futures import ThreadPoolExecutor from concurrent.futures import ThreadPoolExecutor
from typing import Dict, List, Optional, Tuple, Union from typing import Dict, List, Tuple, Union
import numpy as np import numpy as np
from .. import _imperative_rt from .. import _imperative_rt
from .._imperative_rt import GraphOptimizeOptions from .._imperative_rt import GraphOptimizeOptions, SerializationFormat
from .._imperative_rt.core2 import apply, set_cpp_apply_backward_varnode
from .._wrap import as_device from .._wrap import as_device
from ..ops.builtin import OpDef from ..ops.builtin import OpDef
...@@ -377,7 +376,8 @@ def dump_graph( ...@@ -377,7 +376,8 @@ def dump_graph(
keep_opr_priority: bool = False, keep_opr_priority: bool = False,
strip_info_file=None, strip_info_file=None,
append_json=False, append_json=False,
metadata=None metadata=None,
dump_format=None
) -> Tuple[bytes, CompGraphDumpResult]: ) -> Tuple[bytes, CompGraphDumpResult]:
r"""serialize the computing graph of `output_vars` and get byte result. r"""serialize the computing graph of `output_vars` and get byte result.
...@@ -398,6 +398,7 @@ def dump_graph( ...@@ -398,6 +398,7 @@ def dump_graph(
append_json: will be check when `strip_info_file` is not None. if set append_json: will be check when `strip_info_file` is not None. if set
true, the information for code strip will be append to strip_info_file. true, the information for code strip will be append to strip_info_file.
if set false, will rewrite strip_info_file if set false, will rewrite strip_info_file
dump_format: using different dump formats.
Note: Note:
The underlying C++ API only accepts a var list. If a dict is given, The underlying C++ API only accepts a var list. If a dict is given,
...@@ -434,6 +435,12 @@ def dump_graph( ...@@ -434,6 +435,12 @@ def dump_graph(
outputs = [] outputs = []
params = [] params = []
dump_format_map = {
None: None,
"FBS": SerializationFormat.FBS,
}
dump_format = dump_format_map[dump_format]
dump_content = _imperative_rt.dump_graph( dump_content = _imperative_rt.dump_graph(
ov, ov,
keep_var_name, keep_var_name,
...@@ -441,6 +448,7 @@ def dump_graph( ...@@ -441,6 +448,7 @@ def dump_graph(
keep_param_name, keep_param_name,
keep_opr_priority, keep_opr_priority,
metadata, metadata,
dump_format,
stat, stat,
inputs, inputs,
outputs, outputs,
......
...@@ -1008,6 +1008,7 @@ class trace: ...@@ -1008,6 +1008,7 @@ class trace:
maxerr=1e-4, maxerr=1e-4,
resize_input=False, resize_input=False,
input_transform=None, input_transform=None,
dump_format: str = None,
**kwargs **kwargs
): ):
r"""Serializes trace to file system. r"""Serializes trace to file system.
...@@ -1059,6 +1060,7 @@ class trace: ...@@ -1059,6 +1060,7 @@ class trace:
resize_input: whether resize input image to fit input var shape. resize_input: whether resize input image to fit input var shape.
input_transform: a python expression to transform the input data. input_transform: a python expression to transform the input data.
Example: data / np.std(data) Example: data / np.std(data)
dump_format: using different dump formats.
Keyword Arguments: Keyword Arguments:
...@@ -1265,6 +1267,7 @@ class trace: ...@@ -1265,6 +1267,7 @@ class trace:
strip_info_file=strip_info_file, strip_info_file=strip_info_file,
append_json=append_json, append_json=append_json,
metadata=metadata, metadata=metadata,
dump_format=dump_format,
) )
file.write(dump_content) file.write(dump_content)
......
...@@ -35,6 +35,7 @@ using _OptimizeForInferenceOptions = mgb::gopt::OptimizeForInferenceOptions; ...@@ -35,6 +35,7 @@ using _OptimizeForInferenceOptions = mgb::gopt::OptimizeForInferenceOptions;
using _LayoutTransform = _OptimizeForInferenceOptions::LayoutTransform; using _LayoutTransform = _OptimizeForInferenceOptions::LayoutTransform;
using _AlgoStrategy = opr::mixin::AlgoChooserHelper::ExecutionPolicy::Strategy; using _AlgoStrategy = opr::mixin::AlgoChooserHelper::ExecutionPolicy::Strategy;
using _SerializationMetadata = mgb::serialization::Metadata; using _SerializationMetadata = mgb::serialization::Metadata;
using _SerializationFormat = mgb::serialization::GraphDumpFormat;
namespace { namespace {
class _CompGraphProfilerImpl { class _CompGraphProfilerImpl {
...@@ -310,6 +311,10 @@ void init_graph_rt(py::module m) { ...@@ -310,6 +311,10 @@ void init_graph_rt(py::module m) {
.value("NCHW64", _LayoutTransform::NCHW64) .value("NCHW64", _LayoutTransform::NCHW64)
.export_values(); .export_values();
py::enum_<_SerializationFormat>(m, "SerializationFormat")
.value("FBS", _SerializationFormat::FLATBUFFERS)
.export_values();
m.def("optimize_for_inference", m.def("optimize_for_inference",
[](const VarNodeArray& dest_vars, const _OptimizeForInferenceOptions& opt) { [](const VarNodeArray& dest_vars, const _OptimizeForInferenceOptions& opt) {
SymbolVarArray symvars(dest_vars.begin(), dest_vars.end()); SymbolVarArray symvars(dest_vars.begin(), dest_vars.end());
...@@ -380,11 +385,18 @@ void init_graph_rt(py::module m) { ...@@ -380,11 +385,18 @@ void init_graph_rt(py::module m) {
m.def("dump_graph", m.def("dump_graph",
[](const std::vector<VarNode*>& dest_vars, int keep_var_name, [](const std::vector<VarNode*>& dest_vars, int keep_var_name,
bool keep_opr_name, bool keep_param_name, bool keep_opr_priority, bool keep_opr_name, bool keep_param_name, bool keep_opr_priority,
std::optional<_SerializationMetadata> metadata, py::list& stat, std::optional<_SerializationMetadata> metadata,
std::optional<_SerializationFormat> dump_format, py::list& stat,
py::list& inputs, py::list& outputs, py::list& params) { py::list& inputs, py::list& outputs, py::list& params) {
std::vector<uint8_t> buf; std::vector<uint8_t> buf;
auto dumper = ser::GraphDumpFormat format;
ser::GraphDumper::make(ser::OutputFile::make_vector_proxy(&buf)); if (dump_format.has_value()) {
format = dump_format.value();
} else {
format = {};
}
auto dumper = ser::GraphDumper::make(
ser::OutputFile::make_vector_proxy(&buf), format);
SymbolVarArray symvars(dest_vars.begin(), dest_vars.end()); SymbolVarArray symvars(dest_vars.begin(), dest_vars.end());
ser::GraphDumper::DumpConfig config{ ser::GraphDumper::DumpConfig config{
......
...@@ -190,7 +190,13 @@ def test_print_in_trace(): ...@@ -190,7 +190,13 @@ def test_print_in_trace():
np.testing.assert_equal(z, buf) np.testing.assert_equal(z, buf)
def test_dump(): @pytest.mark.parametrize(
"dump_format",
[
"FBS",
],
)
def test_dump(dump_format):
@trace(symbolic=True, capture_as_const=True) @trace(symbolic=True, capture_as_const=True)
def f(a, b): def f(a, b):
return a + b return a + b
...@@ -205,7 +211,7 @@ def test_dump(): ...@@ -205,7 +211,7 @@ def test_dump():
np.testing.assert_equal(f(a, b).numpy(), y) np.testing.assert_equal(f(a, b).numpy(), y)
file = io.BytesIO() file = io.BytesIO()
dump_info = f.dump(file) dump_info = f.dump(file, dump_format=dump_format)
assert dump_info.nr_opr == 3 assert dump_info.nr_opr == 3
np.testing.assert_equal(dump_info.inputs, ["arg_0", "arg_1"]) np.testing.assert_equal(dump_info.inputs, ["arg_0", "arg_1"])
np.testing.assert_equal(dump_info.outputs, ["ADD"]) np.testing.assert_equal(dump_info.outputs, ["ADD"])
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册