Commit 91efd67d authored by Megvii Engine Team

refactor(mge/jit): change dump options, add test

GitOrigin-RevId: fbc0d51c2be1fd51aaea121f6afa48b25abf661a
Parent 099ffeac
......@@ -130,32 +130,31 @@ def optimize_for_inference(dest_vars, **kwargs):
inference)
"""
inference_options = GraphOptimizeOptions()
if optimize_for_inference:
inference_optimize_layout_transform_map = {
"enable_hwcd4": GraphOptimizeOptions.LayoutTransform.NHWCD4,
"enable_nchw4": GraphOptimizeOptions.LayoutTransform.NCHW4,
"enable_nchw88": GraphOptimizeOptions.LayoutTransform.NCHW88,
"enable_nchw32": GraphOptimizeOptions.LayoutTransform.NCHW32,
"enable_nchw44": GraphOptimizeOptions.LayoutTransform.NCHW44,
"enable_nchw44_dot": GraphOptimizeOptions.LayoutTransform.NCHW44_DOT,
"enable_chwn4": GraphOptimizeOptions.LayoutTransform.CHWN4,
}
for k, v in inference_optimize_layout_transform_map.items():
if kwargs.pop(k, False):
inference_options.layout_transform = v
if kwargs.pop("enable_io16xc32", False):
inference_options.f16_io_f32_comp = True
if kwargs.pop("enable_ioc16", False):
inference_options.f16_io_comp = True
if kwargs.pop("enable_fuse_conv_bias_nonlinearity", False):
inference_options.fuse_conv_bias_nonlinearity = True
if kwargs.pop("enable_fuse_conv_bias_with_z", False):
inference_options.fuse_conv_bias_with_z = True
if kwargs:
raise ValueError("unknown options: %s" % list(kwargs))
inference_optimize_layout_transform_map = {
"enable_hwcd4": GraphOptimizeOptions.LayoutTransform.NHWCD4,
"enable_nchw4": GraphOptimizeOptions.LayoutTransform.NCHW4,
"enable_nchw88": GraphOptimizeOptions.LayoutTransform.NCHW88,
"enable_nchw32": GraphOptimizeOptions.LayoutTransform.NCHW32,
"enable_nchw44": GraphOptimizeOptions.LayoutTransform.NCHW44,
"enable_nchw44_dot": GraphOptimizeOptions.LayoutTransform.NCHW44_DOT,
"enable_chwn4": GraphOptimizeOptions.LayoutTransform.CHWN4,
}
for k, v in inference_optimize_layout_transform_map.items():
if kwargs.pop(k, False):
inference_options.layout_transform = v
if kwargs.pop("enable_io16xc32", False):
inference_options.f16_io_f32_comp = True
if kwargs.pop("enable_ioc16", False):
inference_options.f16_io_comp = True
if kwargs.pop("enable_fuse_conv_bias_nonlinearity", False):
inference_options.fuse_conv_bias_nonlinearity = True
if kwargs.pop("enable_fuse_conv_bias_with_z", False):
inference_options.fuse_conv_bias_with_z = True
if kwargs:
raise ValueError("unknown options: %s" % list(kwargs))
res_vars = _imperative_rt.optimize_for_inference(
[i._node for i in dest_vars], inference_options
......
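The parsing above follows a common pattern: pop each recognized flag out of **kwargs, then treat any leftovers as errors so a misspelled option fails loudly instead of being silently ignored. A self-contained sketch of the same pattern (function name hypothetical, mapping abridged to two layouts):

def parse_options(**kwargs):
    options = {}
    # abridged stand-in for inference_optimize_layout_transform_map
    layout_map = {"enable_nchw4": "NCHW4", "enable_nchw88": "NCHW88"}
    for key, layout in layout_map.items():
        if kwargs.pop(key, False):
            options["layout_transform"] = layout
    if kwargs.pop("enable_io16xc32", False):
        options["f16_io_f32_comp"] = True
    if kwargs:
        # anything left over was not a recognized option
        raise ValueError("unknown options: %s" % list(kwargs))
    return options

print(parse_options(enable_nchw88=True))  # {'layout_transform': 'NCHW88'}
# parse_options(enable_nchw8=True)        # would raise: unknown options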
......@@ -458,7 +458,16 @@ class trace:
self._process_outputs(outputs)
return outputs
def dump(self, file, *, arg_names=None, output_names=None, append=False, **kwargs):
def dump(
self,
file,
*,
arg_names=None,
output_names=None,
append=False,
optimize_for_inference=True,
**kwargs
):
r"""Serializes trace to file system.
:param file: output file, could be file object or filename.
......@@ -467,6 +476,8 @@ class trace:
use the default name if not specified.
:param append: whether output is appended to ``file``.
Only works when ``file`` is str.
:param optimize_for_inference: enable optimizations; all
optimize options will be skipped if this is False. Default: True
:Keyword Arguments:
......@@ -572,7 +583,8 @@ class trace:
v.name = output_names[i]
dest_vars.append(v)
dest_vars = G.optimize_for_inference(dest_vars, **kwargs)
if optimize_for_inference:
dest_vars = G.optimize_for_inference(dest_vars, **kwargs)
if isinstance(file, str):
permission = "wb" if append == False else "ab"
......
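With the new signature, optimization flags go straight to ``dump`` as keyword arguments, and ``optimize_for_inference=False`` skips the optimization pass entirely. A usage sketch mirroring the updated tests below (the traced function itself is illustrative):

import io

import megengine
from megengine.jit import trace

@trace(symbolic=True, capture_as_const=True)
def f(x):
    return x * 2

f(megengine.tensor(5.0))

# Dump the raw traced graph, skipping all inference optimizations.
file = io.BytesIO()
f.dump(file, optimize_for_inference=False)

# Dump with optimizations enabled (the default); optimization flags
# are now passed directly as keyword arguments instead of a dict.
f.dump(io.BytesIO(), enable_io16xc32=True)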
......@@ -155,6 +155,9 @@ void init_graph_rt(py::module m) {
})
.def_property_readonly("id",[](cg::VarNode* v){
return (v->id());
})
.def("__repr__", [](cg::VarNode* v) {
return "Var:" + v->name();
});
py::class_<cg::OperatorNodeBase, GraphNodePtr<cg::OperatorNodeBase>>(m, "OperatorNode")
......@@ -175,6 +178,9 @@ void init_graph_rt(py::module m) {
})
.def_property_readonly("type",[](cg::OperatorNodeBase* opr){
return opr->dyn_typeinfo()->name;
})
.def("__repr__", [](cg::OperatorNodeBase* opr){
return "Opr:" + opr->name();
});
......
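With the new ``__repr__`` bindings, ``VarNode`` and ``OperatorNode`` objects print as ``Var:<name>`` and ``Opr:<name>``, which makes a loaded graph easier to inspect interactively. A hypothetical session, reusing the BytesIO from the dump sketch above and assuming a var exposes its operator via ``owner`` as elsewhere in these bindings; actual names depend on the dumped graph:

import megengine.core.tensor.megbrain_graph as G

file.seek(0)  # the BytesIO written in the dump sketch above
cg, _, outputs = G.load_graph(file)
var = outputs[0]
print(repr(var))        # e.g. "Var:<variable name>"
print(repr(var.owner))  # e.g. "Opr:<operator name>"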
......@@ -67,7 +67,6 @@ def test_replace_oprs():
np.testing.assert_equal(res, np.array([5.0 * 5.0 * 1.25]))
@pytest.mark.skip(reason="Please check opr index")
def test_graph_traversal():
net = M.Conv2d(3, 32, 3)
......@@ -77,11 +76,11 @@ def test_graph_traversal():
return x
data = np.random.random([1, 3, 224, 224]).astype(np.float32)
for i in range(3):
for _ in range(3):
fun(megengine.tensor(data))
file = io.BytesIO()
fun.dump(file)
fun.dump(file, optimize_for_inference=False)
file.seek(0)
cg, _, outputs = mgb_graph.load_graph(file)
......
......@@ -13,7 +13,6 @@ import numpy as np
import pytest
import megengine
import megengine.core.tensor.megbrain_graph as G
import megengine.module as M
from megengine import cgtools, tensor
from megengine.core._trace_option import set_tensor_shape
......@@ -150,7 +149,6 @@ def test_capture_dump():
np.testing.assert_equal(result[0], y)
@pytest.mark.skip(reason="get MultipleDeviceTensorHolder instead of SharedDeviceTensor")
def test_dump_volatile():
p = as_raw_tensor([2])
......@@ -167,7 +165,7 @@ def test_dump_volatile():
np.testing.assert_equal(f(as_raw_tensor(x)).numpy(), y)
file = io.BytesIO()
f.dump(file)
f.dump(file, optimize_for_inference=False)
file.seek(0)
cg, _, outputs = G.load_graph(file)
(out,) = outputs
......@@ -196,26 +194,7 @@ def test_trace_profiler():
assert out.get("profiler")
@pytest.mark.skip(reason="eq_to_unit failed in inplace.cpp")
def test_goptions_div_zero():
@trace(symbolic=True, opt_level=0)
def f(x):
return x / x
@trace(symbolic=True, opt_level=1)
def g(x):
return x / x
out = f(tensor(0.0))
if out == out:
raise ValueError("actual result should be nan")
out = g(tensor(0.0))
if out != out:
raise ValueError("actual result should be 1")
@pytest.mark.skip(reason="cast to Elemwise failed in inplace.cpp")
@pytest.mark.skip(reason="could not disable opt_level")
def test_goptions_log_exp():
@trace(symbolic=True, opt_level=0, capture_as_const=True)
def f(x):
......@@ -227,19 +206,19 @@ def test_goptions_log_exp():
f(tensor(1.0))
_, out = mkstemp()
f.dump(out)
*_, outputs = G.load_comp_graph_from_file(out)
f.dump(out, optimize_for_inference=False)
*_, outputs = G.load_graph(out)
oprs_1 = cgtools.get_oprs_seq(outputs)
g(tensor(1.0))
g.dump(out)
*_, outputs = G.load_comp_graph_from_file(out)
g.dump(out, optimize_for_inference=False)
*_, outputs = G.load_graph(out)
oprs_2 = cgtools.get_oprs_seq(outputs)
assert len(oprs_1) - len(oprs_2) == 2
@pytest.mark.skip(reason="need cgtools to check final oprs")
@pytest.mark.skip(reason="could not disable opt_level")
def test_goptions_log_sum_exp():
@trace(symbolic=True, opt_level=0, capture_as_const=True)
def f(x, y):
......@@ -251,19 +230,18 @@ def test_goptions_log_sum_exp():
f(tensor(1.0), tensor(2.0))
_, out = mkstemp()
f.dump(out)
*_, outputs = G.load_comp_graph_from_file(out)
f.dump(out, optimize_for_inference=False)
*_, outputs = G.load_graph(out)
oprs_1 = cgtools.get_oprs_seq(outputs)
g(tensor(1.0), tensor(2.0))
g.dump(out)
*_, outputs = G.load_comp_graph_from_file(out)
g.dump(out, optimize_for_inference=False)
*_, outputs = G.load_graph(out)
oprs_2 = cgtools.get_oprs_seq(outputs)
assert len(oprs_1) - len(oprs_2) == 2
@pytest.mark.skip(reason="need cgtools to check computing input dtype")
def test_optimize_for_inference():
@trace(symbolic=True, capture_as_const=True)
def f(x):
......@@ -271,9 +249,9 @@ def test_optimize_for_inference():
_, out = mkstemp()
f(tensor(5.0))
f.dump(out, optimize_for_inference=True, optimize_options={"enable_io16xc32": True})
f.dump(out, enable_io16xc32=True)
res = G.load_comp_graph_from_file(out)
res = G.load_graph(out)
computing_input = res.output_vars_list[0].owner.inputs[0]
assert computing_input.dtype == np.float16
......