From a404c508e9c844da52b36d573ea9d8939e0657be Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Fri, 29 Oct 2021 19:33:57 +0800 Subject: [PATCH] feat(mge): support dump with specific format GitOrigin-RevId: 57a7c0de02ec6ee30a67b5cc069dbdd7dc0f6437 --- .../megengine/core/tensor/megbrain_graph.py | 16 ++++++++++++---- imperative/python/megengine/jit/tracing.py | 3 +++ imperative/python/src/graph_rt.cpp | 18 +++++++++++++++--- .../python/test/unit/jit/test_tracing.py | 10 ++++++++-- 4 files changed, 38 insertions(+), 9 deletions(-) diff --git a/imperative/python/megengine/core/tensor/megbrain_graph.py b/imperative/python/megengine/core/tensor/megbrain_graph.py index 2dac010af..1f7f2907f 100644 --- a/imperative/python/megengine/core/tensor/megbrain_graph.py +++ b/imperative/python/megengine/core/tensor/megbrain_graph.py @@ -11,13 +11,12 @@ import json import os import weakref from concurrent.futures import ThreadPoolExecutor -from typing import Dict, List, Optional, Tuple, Union +from typing import Dict, List, Tuple, Union import numpy as np from .. import _imperative_rt -from .._imperative_rt import GraphOptimizeOptions -from .._imperative_rt.core2 import apply, set_cpp_apply_backward_varnode +from .._imperative_rt import GraphOptimizeOptions, SerializationFormat from .._wrap import as_device from ..ops.builtin import OpDef @@ -377,7 +376,8 @@ def dump_graph( keep_opr_priority: bool = False, strip_info_file=None, append_json=False, - metadata=None + metadata=None, + dump_format=None ) -> Tuple[bytes, CompGraphDumpResult]: r"""serialize the computing graph of `output_vars` and get byte result. @@ -398,6 +398,7 @@ def dump_graph( append_json: will be check when `strip_info_file` is not None. if set true, the information for code strip will be append to strip_info_file. if set false, will rewrite strip_info_file + dump_format: using different dump formats. Note: The underlying C++ API only accepts a var list. If a dict is given, @@ -434,6 +435,12 @@ def dump_graph( outputs = [] params = [] + dump_format_map = { + None: None, + "FBS": SerializationFormat.FBS, + } + dump_format = dump_format_map[dump_format] + dump_content = _imperative_rt.dump_graph( ov, keep_var_name, @@ -441,6 +448,7 @@ def dump_graph( keep_param_name, keep_opr_priority, metadata, + dump_format, stat, inputs, outputs, diff --git a/imperative/python/megengine/jit/tracing.py b/imperative/python/megengine/jit/tracing.py index d9d867fac..b93fa027c 100644 --- a/imperative/python/megengine/jit/tracing.py +++ b/imperative/python/megengine/jit/tracing.py @@ -1008,6 +1008,7 @@ class trace: maxerr=1e-4, resize_input=False, input_transform=None, + dump_format: str = None, **kwargs ): r"""Serializes trace to file system. @@ -1059,6 +1060,7 @@ class trace: resize_input: whether resize input image to fit input var shape. input_transform: a python expression to transform the input data. Example: data / np.std(data) + dump_format: using different dump formats. Keyword Arguments: @@ -1265,6 +1267,7 @@ class trace: strip_info_file=strip_info_file, append_json=append_json, metadata=metadata, + dump_format=dump_format, ) file.write(dump_content) diff --git a/imperative/python/src/graph_rt.cpp b/imperative/python/src/graph_rt.cpp index afaf90d4f..3b5ee6015 100644 --- a/imperative/python/src/graph_rt.cpp +++ b/imperative/python/src/graph_rt.cpp @@ -35,6 +35,7 @@ using _OptimizeForInferenceOptions = mgb::gopt::OptimizeForInferenceOptions; using _LayoutTransform = _OptimizeForInferenceOptions::LayoutTransform; using _AlgoStrategy = opr::mixin::AlgoChooserHelper::ExecutionPolicy::Strategy; using _SerializationMetadata = mgb::serialization::Metadata; +using _SerializationFormat = mgb::serialization::GraphDumpFormat; namespace { class _CompGraphProfilerImpl { @@ -310,6 +311,10 @@ void init_graph_rt(py::module m) { .value("NCHW64", _LayoutTransform::NCHW64) .export_values(); + py::enum_<_SerializationFormat>(m, "SerializationFormat") + .value("FBS", _SerializationFormat::FLATBUFFERS) + .export_values(); + m.def("optimize_for_inference", [](const VarNodeArray& dest_vars, const _OptimizeForInferenceOptions& opt) { SymbolVarArray symvars(dest_vars.begin(), dest_vars.end()); @@ -380,11 +385,18 @@ void init_graph_rt(py::module m) { m.def("dump_graph", [](const std::vector& dest_vars, int keep_var_name, bool keep_opr_name, bool keep_param_name, bool keep_opr_priority, - std::optional<_SerializationMetadata> metadata, py::list& stat, + std::optional<_SerializationMetadata> metadata, + std::optional<_SerializationFormat> dump_format, py::list& stat, py::list& inputs, py::list& outputs, py::list& params) { std::vector buf; - auto dumper = - ser::GraphDumper::make(ser::OutputFile::make_vector_proxy(&buf)); + ser::GraphDumpFormat format; + if (dump_format.has_value()) { + format = dump_format.value(); + } else { + format = {}; + } + auto dumper = ser::GraphDumper::make( + ser::OutputFile::make_vector_proxy(&buf), format); SymbolVarArray symvars(dest_vars.begin(), dest_vars.end()); ser::GraphDumper::DumpConfig config{ diff --git a/imperative/python/test/unit/jit/test_tracing.py b/imperative/python/test/unit/jit/test_tracing.py index a228a3b38..de2f2da9c 100644 --- a/imperative/python/test/unit/jit/test_tracing.py +++ b/imperative/python/test/unit/jit/test_tracing.py @@ -190,7 +190,13 @@ def test_print_in_trace(): np.testing.assert_equal(z, buf) -def test_dump(): +@pytest.mark.parametrize( + "dump_format", + [ + "FBS", + ], +) +def test_dump(dump_format): @trace(symbolic=True, capture_as_const=True) def f(a, b): return a + b @@ -205,7 +211,7 @@ def test_dump(): np.testing.assert_equal(f(a, b).numpy(), y) file = io.BytesIO() - dump_info = f.dump(file) + dump_info = f.dump(file, dump_format=dump_format) assert dump_info.nr_opr == 3 np.testing.assert_equal(dump_info.inputs, ["arg_0", "arg_1"]) np.testing.assert_equal(dump_info.outputs, ["ADD"]) -- GitLab