Commit 478f2c51 authored by: Megvii Engine Team

feat(mge/jit): add trace/dump options

GitOrigin-RevId: 0f43c14599e7ca45aee12c4402bf8d178218a620
Parent 44d0b5da
...@@ -14,6 +14,7 @@ from concurrent.futures import Future, ThreadPoolExecutor
import numpy as np
from .. import _imperative_rt
from .._imperative_rt import GraphOptimizeOptions
from .._imperative_rt.ops import BackwardGraph
from .._wrap import device as as_device
from ..ops.builtin import OpDef
...@@ -83,6 +84,84 @@ class Graph(_imperative_rt.ComputingGraph):
return self._wrap(_imperative_rt.make_h2d(self, device, dtype, shape, name))
def optimize_for_inference(dest_vars, **kwargs):
r"""Applies optimize_for_inference pass for computing graph.
:param dest_vars: list of output vars in the computing graph
:Keyword Arguments:
* enable_io16xc32 --
whether to use float16 for I/O between oprs and use
float32 as internal computation precision. Note the output var would be
changed to float16.
* enable_ioc16 --
whether to use float16 for both I/O and computation
precision.
* enable_hwcd4 --
whether to use NHWCD4 data layout. This is faster on some
OpenCL backend.
* enable_nchw88 --
whether to use NCHW88 data layout, currently
used in X86 AVX backend.
* enable_nchw44 --
whether to use NCHW44 data layout, currently
used in arm backend.
* enable_nchw44_dot --
whether to use NCHW44_dot data layout, currently
used in armv8.2+dotprod backend.
* enable_nchw4 --
whether to use NCHW4 data layout, currently
used in nvidia backend(based on cudnn).
* enable_nchw32 --
whether to use NCHW32 data layout, currently
used in nvidia backend with tensorcore(based on cudnn).
* enable_chwn4 --
whether to use CHWN4 data layout, currently
used in nvidia backend with tensorcore.
* enable_fuse_conv_bias_nonlinearity: whether to fuse conv+bias+nonlinearty
into one opr.
* enable_fuse_conv_bias_with_z: whether to fuse conv_bias with z
input for inference on nvidia backend(this optimization pass will
result in mismatch of the precision of output of training and
inference)
"""
inference_options = GraphOptimizeOptions()
if optimize_for_inference:
inference_optimize_layout_transform_map = {
"enable_hwcd4": GraphOptimizeOptions.LayoutTransform.NHWCD4,
"enable_nchw4": GraphOptimizeOptions.LayoutTransform.NCHW4,
"enable_nchw88": GraphOptimizeOptions.LayoutTransform.NCHW88,
"enable_nchw32": GraphOptimizeOptions.LayoutTransform.NCHW32,
"enable_nchw44": GraphOptimizeOptions.LayoutTransform.NCHW44,
"enable_nchw44_dot": GraphOptimizeOptions.LayoutTransform.NCHW44_DOT,
"enable_chwn4": GraphOptimizeOptions.LayoutTransform.CHWN4,
}
for k, v in inference_optimize_layout_transform_map.items():
if kwargs.pop(k, False):
inference_options.layout_transform = v
if kwargs.pop("enable_io16xc32", False):
inference_options.f16_io_f32_comp = True
if kwargs.pop("enable_ioc16", False):
inference_options.f16_io_comp = True
if kwargs.pop("enable_fuse_conv_bias_nonlinearity", False):
inference_options.fuse_conv_bias_nonlinearity = True
if kwargs.pop("enable_fuse_conv_bias_with_z", False):
inference_options.fuse_conv_bias_with_z = True
if kwargs:
raise ValueError("unknown options: %s" % list(kwargs))
res_vars = _imperative_rt.optimize_for_inference(
[i._node for i in dest_vars], inference_options
)
return [VarNode(i) for i in res_vars]
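Below is a minimal usage sketch (not part of the diff) of the new helper, assuming `output_vars` is a list of VarNode objects collected from a computing graph; only keyword options documented above are used, and the file name is illustrative:

# Sketch: request float16 I/O plus the NHWCD4 layout, then serialize.
# `output_vars` is a hypothetical list of output VarNodes of a graph.
optimized_vars = optimize_for_inference(
    output_vars,
    enable_io16xc32=True,
    enable_hwcd4=True,
)
blob = dump(*optimized_vars)  # serialized graph bytes
with open("model.mge", "wb") as fout:
    fout.write(blob)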
def dump(*args):
return _imperative_rt.dump_graph([i._node for i in args])
...
...@@ -11,6 +11,7 @@ import numpy as np
from ..core._imperative_rt import GraphProfiler
from ..core._imperative_rt.ops import OprAttr
from ..core._trace_option import set_tensor_shape
from ..core.ops.special import Const
from ..core.tensor import megbrain_graph as G
from ..core.tensor.core import OpBase, TensorBase, TensorWrapperBase, apply
...@@ -76,6 +77,22 @@ class TensorInfo:
class trace:
"""
Wraps a callable and provide:
* tracing via :meth:`.trace` and :meth:`.dump`
* accelerated evalutaion via :meth:`.__call__`
:param function: the function will be traced.
:param symbolic: whether to apply symbolic execution for tracing. Default: False
:param capture_as_const: capture global vars or closures as const value. Default: False
:param sublinear_memory_config: configuration for sublinear memory optimization.
If not None, it enables sublinear memory optimization with given setting.
:param profiling: whether to profile compiled trace. Default: False
:param opt_level: optimization level for compiling trace.
:param symbolic_shape: whether to use symbolic shape for tracing. Default: True
"""
def __new__(cls, *args, **kwargs):
if not args:
return functools.partial(cls, **kwargs)
...@@ -88,6 +105,8 @@ class trace:
capture_as_const=False,
sublinear_memory_config: SublinearMemoryConfig = None,
profiling: bool = False,
opt_level: int = None,
tensor_shape: bool = True,
):
self.__wrapped__ = function
self._symbolic = symbolic
...@@ -95,6 +114,8 @@ class trace:
self._sublinear_memory_config = sublinear_memory_config
self._profiling = profiling
self._profiler = None
self._graph_opt_level = opt_level
self._tensor_shape = tensor_shape
self._untraced = True
self._tinfo = []  # handle -> TensorInfo
...@@ -112,6 +133,8 @@ class trace:
self._output_bindings = None
self._output_names = None
set_tensor_shape(self._tensor_shape)
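As a hedged illustration of the new constructor options (not part of the diff), a trace could be created roughly as follows; `fun` and `data` are placeholder names, and `data` is assumed to be a megengine tensor:

# Sketch: symbolic tracing with graph optimization disabled and
# profiling enabled; the first call records the graph, later calls
# replay the compiled trace.
from megengine.jit import trace

@trace(symbolic=True, opt_level=0, profiling=True, capture_as_const=True)
def fun(data):
    return data * 2 + 1

out = fun(data)            # first call traces and compiles the graph
stats = fun.get_profile()  # available because profiling=True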
def _new_handle(self):
handle = len(self._tinfo)
info = TensorInfo()
...@@ -307,6 +330,9 @@ class trace:
def _apply_graph_options(self, graph):
graph.options.seq_opt.enable_seq_comp_node_opt = False
# graph opt level
if self._graph_opt_level is not None:
graph.options.graph_opt_level = self._graph_opt_level
# sublinear
if self._sublinear_memory_config is not None:
graph.options.enable_sublinear_memory_opt = True
...@@ -320,6 +346,7 @@
)
sublinear_config.thresh_nr_try = self._sublinear_memory_config.thresh_nr_try
sublinear_config.num_worker = self._sublinear_memory_config.num_worker
# profile
if self._profiling:
self._profiler = GraphProfiler(graph)
...@@ -416,7 +443,55 @@ class trace:
self._process_outputs(outputs)
return outputs
def dump(self, file, *, arg_names=None, output_names=None, append=False, **kwargs):
r"""Serializes trace to file system.
:param file: output file, could be file object or filename.
:param arg_names: names of the input tensors in the traced function.
:param output_names: names of the output tensors in the traced function,
use the default name if not specified.
:param append: whether output is appended to ``file``.
Only works when ``file`` is str.
:Keyword Arguments:
* enable_io16xc32 --
whether to use float16 for I/O between oprs and use
float32 as internal computation precision. Note the output var would be
changed to float16.
* enable_ioc16 --
whether to use float16 for both I/O and computation
precision.
* enable_hwcd4 --
whether to use NHWCD4 data layout. This is faster on some
OpenCL backend.
* enable_nchw88 --
whether to use NCHW88 data layout, currently
used in X86 AVX backend.
* enable_nchw44 --
whether to use NCHW44 data layout, currently
used in arm backend.
* enable_nchw44_dot --
whether to use NCHW44_dot data layout, currently
used in armv8.2+dotprod backend.
* enable_nchw4 --
whether to use NCHW4 data layout, currently
used in nvidia backend(based on cudnn).
* enable_nchw32 --
whether to use NCHW32 data layout, currently
used in nvidia backend with tensorcore(based on cudnn).
* enable_chwn4 --
whether to use CHWN4 data layout, currently
used in nvidia backend with tensorcore.
* enable_fuse_conv_bias_nonlinearity: whether to fuse conv+bias+nonlinearty
into one opr.
* enable_fuse_conv_bias_with_z: whether to fuse conv_bias with z
input for inference on nvidia backend(this optimization pass will
result in mismatch of the precision of output of training and
inference)
"""
if not self._capture_as_const:
raise ValueError(
"you must specify capture_as_const=True at __init__ to use dump"
...@@ -482,8 +557,11 @@ class trace:
v.name = output_names[i]
dest_vars.append(v)
dest_vars = G.optimize_for_inference(dest_vars, **kwargs)
if isinstance(file, str):
permission = "ab" if append else "wb"
file = open(file, permission)
file.write(G.dump(*dest_vars))
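A hedged sketch of how the extended signature might be exercised (not part of the diff); `f` is assumed to be a function traced with capture_as_const=True that has already been called once, and the file name is illustrative:

# Sketch: dump the traced function twice into the same file; the second
# dump appends and additionally enables float16 I/O.
f.dump("model.mge", arg_names=["data"], output_names=["pred"])
f.dump("model.mge", append=True, enable_io16xc32=True)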
def _process_inputs(self, *args, **kwargs):
...
...@@ -20,12 +20,17 @@
#include "./helper.h"
#include "megbrain/plugin/profiler.h"
#include "./common.h"
#include "megbrain/gopt/inference.h"
namespace py = pybind11;
using namespace mgb;
using namespace imperative;
using _OptimizeForInferenceOptions = mgb::gopt::OptimizeForInferenceOptions;
using _LayoutTransform = _OptimizeForInferenceOptions::LayoutTransform;
namespace {
class _CompGraphProfilerImpl {
std::shared_ptr<ComputingGraph> m_comp_graph;
...@@ -138,6 +143,37 @@ void init_graph_rt(py::module m) {
return py::bytes(reinterpret_cast<const char*>(&buf[0]), buf.size());
});
auto GraphOptimizeOptions = py::class_<_OptimizeForInferenceOptions>(m, "GraphOptimizeOptions")
.def(py::init())
.def_readwrite("f16_io_f32_comp", &_OptimizeForInferenceOptions::f16_io_f32_comp)
.def_readwrite("f16_io_comp", &_OptimizeForInferenceOptions::f16_io_comp)
.def_readwrite("fuse_conv_bias_nonlinearity", &_OptimizeForInferenceOptions::fuse_conv_bias_nonlinearity)
.def_readwrite("fuse_conv_bias_with_z", &_OptimizeForInferenceOptions::fuse_conv_bias_with_z)
.def_readwrite("layout_transform", &_OptimizeForInferenceOptions::layout_transform)
;
py::enum_<_LayoutTransform>(GraphOptimizeOptions, "LayoutTransform")
.value("DEFAULT", _LayoutTransform::DEFAULT)
.value("NCHW4", _LayoutTransform::NCHW4)
.value("NHWCD4", _LayoutTransform::NHWCD4)
.value("NCHW88", _LayoutTransform::NCHW88)
.value("NCHW44", _LayoutTransform::NCHW44)
.value("NCHW44_DOT", _LayoutTransform::NCHW44_DOT)
.value("NCHW32", _LayoutTransform::NCHW32)
.value("CHWN4", _LayoutTransform::CHWN4)
.export_values()
;
m.def("optimize_for_inference", [](const VarNodeArray& dest_vars, const _OptimizeForInferenceOptions& opt) {
SymbolVarArray symvars(dest_vars.begin(), dest_vars.end());
auto res_symvars = mgb::gopt::optimize_for_inference(symvars, opt);
VarNodeArray vars;
for (auto& si: res_symvars)
vars.push_back(si.node());
return vars;
});
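For reference, the binding above could be driven directly from Python roughly like this (a sketch; the exact import path and the availability of `var_nodes`, a list of VarNode objects, are assumptions):

# Sketch: build GraphOptimizeOptions by hand and call the raw binding.
from megengine.core import _imperative_rt

options = _imperative_rt.GraphOptimizeOptions()
options.f16_io_f32_comp = True
options.layout_transform = _imperative_rt.GraphOptimizeOptions.LayoutTransform.NCHW88
optimized = _imperative_rt.optimize_for_inference(var_nodes, options)  # list of VarNode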
#define CURRENT_CLASS cg::ComputingGraph::Options
auto PyComputingGraphOptions = py::class_<cg::ComputingGraph::Options>(PyComputingGraph, "Options")
...
# -*- coding: utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import pytest
from megengine.core import Tensor
# from megengine.core.interpreter.hints import function
@pytest.mark.skip(reason="under rewrite")
def test_1():
@function
def f(x, p):
x = x + 1
if p:
return x * x
return x * 2
x = Tensor(0)
for _ in range(5):
assert f(x, 0).numpy() == 2
assert f(x, 1).numpy() == 1
# -*- coding: utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import io
from tempfile import mkstemp
import numpy as np
import pytest
from megengine import tensor
from megengine.core.ops import builtin as ops
from megengine.core.tensor import megbrain_graph as G
from megengine.core.tensor.core import apply
from megengine.core.tensor.raw_tensor import as_raw_tensor
from megengine.functional import exp, log
from megengine.jit import exclude_from_trace, trace
...@@ -101,3 +114,85 @@ def test_trace_profiler():
out = f.get_profile()
assert out.get("profiler")
@pytest.mark.skip(reason="eq_to_unit failed in inplace.cpp")
def test_goptions_div_zero():
@trace(symbolic=True, opt_level=0)
def f(x):
return x / x
@trace(symbolic=True, opt_level=1)
def g(x):
return x / x
out = f(tensor(0.0))
if out == out:
raise ValueError("actual result should be nan")
out = g(tensor(0.0))
if out != out:
raise ValueError("actual result should be 1")
@pytest.mark.skip(reason="cast to Elemwise failed in inplace.cpp")
def test_goptions_log_exp():
@trace(symbolic=True, opt_level=0, capture_as_const=True)
def f(x):
return log(exp(x))
@trace(symbolic=True, opt_level=1, capture_as_const=True)
def g(x):
return log(exp(x))
f(tensor(1.0))
_, out = mkstemp()
f.dump(out)
*_, outputs = G.load_comp_graph_from_file(out)
oprs_1 = cgtools.get_oprs_seq(outputs)
g(tensor(1.0))
g.dump(out)
*_, outputs = G.load_comp_graph_from_file(out)
oprs_2 = cgtools.get_oprs_seq(outputs)
assert len(oprs_1) - len(oprs_2) == 2
@pytest.mark.skip(reason="need cgtools to check final oprs")
def test_goptions_log_sum_exp():
@trace(symbolic=True, opt_level=0, capture_as_const=True)
def f(x, y):
return log(exp(x) + exp(y))
@trace(symbolic=True, opt_level=1, capture_as_const=True)
def g(x, y):
return log(exp(x) + exp(y))
f(tensor(1.0), tensor(2.0))
_, out = mkstemp()
f.dump(out)
*_, outputs = G.load_comp_graph_from_file(out)
oprs_1 = cgtools.get_oprs_seq(outputs)
g(tensor(1.0), tensor(2.0))
g.dump(out)
*_, outputs = G.load_comp_graph_from_file(out)
oprs_2 = cgtools.get_oprs_seq(outputs)
assert len(oprs_1) - len(oprs_2) == 2
@pytest.mark.skip(reason="need cgtools to check computing input dtype")
def test_optimize_for_inference():
@trace(symbolic=True, capture_as_const=True)
def f(x):
return exp(x)
_, out = mkstemp()
f(tensor(5.0))
f.dump(out, enable_io16xc32=True)
res = G.load_comp_graph_from_file(out)
computing_input = res.output_vars_list[0].owner.inputs[0]
assert computing_input.dtype == np.float16