Commit 3e4e4c46 authored by Megvii Engine Team, committed by huangxinda

feat(mgb/jit): add graph_opt_config and jit_config interfaces

GitOrigin-RevId: 170d9eeab295d3f1553b9991e507f877af478251
Parent 1c7d0802
......@@ -11,6 +11,7 @@ from ..core._imperative_rt.core2 import (
set_cpp_apply_with_tracing,
)
from .dtr_config import DTRConfig
from .graph_opt_config import GraphOptimizationConfig
from .sublinear_memory_config import SublinearMemoryConfig
from .tracing import (
apply_const_with_tracing,
......
# -*- coding: utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
class GraphOptimizationConfig:
r"""
Configuration for graph optimization: False for OFF, True for ON. The default value
None means that opt_level will decide whther this optimization will be applied or not.
:param jit_fuse_dimshuffle: whether to fuse dimshuffle in JIT optimization
:param jit_fuse_reduce: whether to fuse reduce in JIT optimization
"""
def __init__(self):
self.jit_fuse_dimshuffle = None
self.jit_fuse_reduce = None
def __repr__(self):
val2str = {None: "UNSET", False: "OFF", True: "ON"}
return (
"GraphOptimizationConfig {"
+ " jit_fuse_dimshuffle = "
+ val2str[self.jit_fuse_dimshuffle]
+ ", jit_fuse_reduce = "
+ val2str[self.jit_fuse_reduce]
+ " }"
)
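For reference, a minimal usage sketch of this config together with `trace` (it mirrors the test added later in this commit; assumes a GPU is available for the NVRTC JIT backend and that `tensor` comes from `megengine`):

```python
from megengine import tensor
from megengine.jit import GraphOptimizationConfig, trace

config = GraphOptimizationConfig()
config.jit_fuse_dimshuffle = True   # ON
config.jit_fuse_reduce = None       # UNSET: left to opt_level

# opt_level=1 avoids requesting dimshuffle and reduce fusion at the same time
@trace(opt_level=1, graph_opt_config=config)
def func(x):
    return x + 1

print(func(tensor(2)))
```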
......@@ -38,6 +38,7 @@ from ..core.tensor import megbrain_graph as G
from ..core.tensor.utils import setscalar
from ..utils.naming import AutoNaming
from .dtr_config import DTRConfig
from .graph_opt_config import GraphOptimizationConfig
from .sublinear_memory_config import SublinearMemoryConfig
......@@ -129,6 +130,7 @@ class trace:
If not None, it enables sublinear memory optimization with given setting.
:param profiling: whether to profile compiled trace. Default: False
:param opt_level: optimization level for compiling trace. Default: 2
:param graph_opt_config: configuration for graph optimization. Default: None
:param symbolic_shape: whether to use symbolic shape for tracing. Default: True
"""
......@@ -146,6 +148,7 @@ class trace:
dtr_config: DTRConfig = None,
profiling: bool = False,
opt_level: int = 2,
graph_opt_config: GraphOptimizationConfig = None,
symbolic_shape: bool = True,
):
self.__wrapped__ = function
......@@ -156,6 +159,7 @@ class trace:
self._profiling = profiling
self._profiler = None
self._graph_opt_level = opt_level
self._graph_opt_config = graph_opt_config
self._symbolic_shape = symbolic_shape
self._output_handles = set()
......@@ -502,7 +506,14 @@ class trace:
graph.options.dtr_config.evictee_minimum_size = (
self._dtr_config.evictee_minimum_size
)
# graph optimization
if self._graph_opt_config is not None:
mapping = {None: 0, False: 1, True: 2}
jit_config = graph.options.graph_opt.jit_config
jit_config.fuse_dimshuffle = mapping[
self._graph_opt_config.jit_fuse_dimshuffle
]
jit_config.fuse_reduce = mapping[self._graph_opt_config.jit_fuse_reduce]
# sublinear
if self._sublinear_memory_config is not None:
graph.options.enable_sublinear_memory_opt = True
......
......@@ -421,12 +421,20 @@ void init_graph_rt(py::module m) {
#undef CURRENT_CLASS
#define CURRENT_CLASS cg::ComputingGraph::Options::GraphOpt
py::class_<cg::ComputingGraph::Options::GraphOpt>(PyComputingGraphOptions, "GraphOpt")
auto PyGraphOpt = py::class_<cg::ComputingGraph::Options::GraphOpt>(
PyComputingGraphOptions, "GraphOpt")
DEF_READWRITE(jit)
DEF_READWRITE(jit_config)
DEF_READWRITE(tensorrt);
#undef CURRENT_CLASS
#define CURRENT_CLASS cg::ComputingGraph::Options::GraphOpt::JITConfig
py::class_<cg::ComputingGraph::Options::GraphOpt::JITConfig>(PyGraphOpt, "JITConfig")
DEF_READWRITE(fuse_dimshuffle)
DEF_READWRITE(fuse_reduce);
#undef CURRENT_CLASS
#define CURRENT_CLASS cg::ComputingGraph::Options::SublinearMemConfig
py::class_<cg::ComputingGraph::Options::SublinearMemConfig>(PyComputingGraphOptions, "SublinearMemConfig")
......
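With these bindings, the new options are readable from Python, e.g. on a compiled trace (attribute names as exercised by the test added below; integer values follow JITConfig: 0 = UNSET, 1 = OFF, 2 = ON):

```python
options = func._graph.options  # func: a traced and compiled function
print(options.graph_opt.jit)                          # jit_opt_level
print(options.graph_opt.jit_config.fuse_dimshuffle)   # 0 / 1 / 2
print(options.graph_opt.jit_config.fuse_reduce)       # 0 / 1 / 2
```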
......@@ -25,7 +25,7 @@ from megengine.core.ops import builtin as ops
from megengine.core.ops.builtin import Elemwise
from megengine.core.tensor.utils import isscalar
from megengine.functional import exp, log
from megengine.jit import exclude_from_trace, trace
from megengine.jit import GraphOptimizationConfig, exclude_from_trace, trace
from megengine.module import Module
from megengine.random import normal, uniform
from megengine.utils.naming import AutoNaming
......@@ -605,3 +605,30 @@ def test_trace_advance_indexing(shape_mode):
for _ in range(3):
result_trace = f_traced(**params)
np.testing.assert_equal(expected, result_trace.numpy())
@pytest.mark.require_ngpu(1) # nvrtc backend
def test_trace_jit_config():
def run(fuse_dimshuffle, fuse_reduce):
config = GraphOptimizationConfig()
config.jit_fuse_dimshuffle = fuse_dimshuffle
config.jit_fuse_reduce = fuse_reduce
# set opt_level = 1 to avoid fusing dimshuffle and reduce at the same time
@trace(opt_level=1, graph_opt_config=config)
def func(x):
return x + 1
x = tensor(2)
y = func(x)
func._compile()
options = func._graph.options
mapping = {None: 0, False: 1, True: 2}
assert options.graph_opt.jit == 0
assert options.graph_opt.jit_config.fuse_dimshuffle == mapping[fuse_dimshuffle]
assert options.graph_opt.jit_config.fuse_reduce == mapping[fuse_reduce]
for fuse_dimshuffle in [None, False, True]:
for fuse_reduce in [None, False, True]:
run(fuse_dimshuffle, fuse_reduce)
......@@ -145,6 +145,24 @@ size_t ComputingGraph::prealloc_static_storage(size_t size) {
}
#endif
/* ========================== JITConfig ========================== */
bool ComputingGraph::Options::GraphOpt::JITConfig::enabled() const {
if (fuse_dimshuffle != UNSET) return true;
if (fuse_reduce != UNSET) return true;
return false;
}
void ComputingGraph::Options::GraphOpt::JITConfig::update(
const JITConfig& modifier) {
if (modifier.fuse_dimshuffle != UNSET) {
this->fuse_dimshuffle = modifier.fuse_dimshuffle;
}
if (modifier.fuse_reduce != UNSET) {
this->fuse_reduce = modifier.fuse_reduce;
}
}
/* ========================== CallbackCaller ========================== */
MGB_DEFINE_OPR_CLASS(ComputingGraphImpl::CallbackCaller,
SingleCNOperatorNodeBase) // {
......@@ -538,12 +556,18 @@ ComputingGraphImpl::CompileState ComputingGraphImpl::compile_prepare(
#if MGB_JIT
if (std::abs(options().graph_opt_level) == 0 && options().graph_opt.jit) {
setenv("MGB_JIT_BACKEND","NVRTC",1);
if (std::abs(options().graph_opt_level) == 0 &&
(options().graph_opt.jit || options().graph_opt.jit_config.enabled())) {
// Deprecated usage kept from earlier versions: it allows NVRTC JIT
// optimization when graph_opt_level is 0. This usage is no longer recommended.
mgb_log_warn(
"It is not recommended to enable JIT optimization when "
"graph_opt_level is 0.");
setenv("MGB_JIT_BACKEND", "NVRTC", 1);
gopt::GraphOptimizer optimizer;
optimizer.add_pass<gopt::JITFusionPass>(
sopr_stat.has_virtual_grad,
std::max<uint8_t>(options().graph_opt.jit, 1));
optimizer.add_pass<gopt::JITFusionPass>(sopr_stat.has_virtual_grad,
options().graph_opt.jit,
options().graph_opt.jit_config);
optimizer.apply_inplace(dest_vars);
}
#endif
......
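The merge semantics of JITConfig::update() can be summarized with a small model (illustrative Python sketch only, using the same UNSET = 0, OFF = 1, ON = 2 encoding as the C++ constants):

```python
UNSET, OFF, ON = 0, 1, 2

def update(base, modifier):
    # fields explicitly set in `modifier` (non-UNSET) overwrite `base`
    return {k: (v if v != UNSET else base[k]) for k, v in modifier.items()}

default = {"fuse_dimshuffle": ON, "fuse_reduce": OFF}   # derived from jit_opt_level = 1
custom = {"fuse_dimshuffle": OFF, "fuse_reduce": ON}    # user-provided jit_config
print(update(default, custom))  # {'fuse_dimshuffle': 1, 'fuse_reduce': 2}
```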
......@@ -338,6 +338,20 @@ class ComputingGraph : public std::enable_shared_from_this<ComputingGraph>,
//! this value indicates JIT level: 1 for basic elemwise opr; 2
//! for including reduce oprs
uint8_t jit = 0;
//! jit configurations
struct JITConfig {
static const int UNSET = 0;
static const int OFF = 1;
static const int ON = 2;
int fuse_dimshuffle = UNSET;
int fuse_reduce = UNSET;
bool enabled() const;
void update(const JITConfig& modifier);
} jit_config;
//! whether to enable fine-grained TensorRT opr replace
bool tensorrt = false;
} graph_opt;
......
......@@ -645,11 +645,21 @@ GraphOptimizer& GraphOptimizer::add_preset_passes(
add_pass<RemoveRedundantCopyPass>();
#if MGB_JIT
bool need_jit = false;
if (comp_graph_opt && (std::abs(comp_graph_opt->graph_opt_level) >= 3 ||
comp_graph_opt->graph_opt.jit)) {
need_jit = true;
using JITConfig = cg::ComputingGraph::Options::GraphOpt::JITConfig;
int jit_opt_level = 0;
JITConfig jit_config;
// for more detail on what is happening here, see the comments on the
// constructor of class JITFusionPass in fusion_pass.h
if (comp_graph_opt) {
jit_opt_level = comp_graph_opt->graph_opt.jit;
if (comp_graph_opt->graph_opt_level >= 3) {
jit_opt_level = std::max(jit_opt_level, 1);
}
jit_config = comp_graph_opt->graph_opt.jit_config;
}
bool need_jit = (jit_opt_level > 0) || jit_config.enabled();
if (need_jit && after_grad) {
add_pass<gopt::RecompTypeCvtPass>();
}
......@@ -662,9 +672,7 @@ GraphOptimizer& GraphOptimizer::add_preset_passes(
#if MGB_JIT
if (need_jit) {
add_pass<gopt::JITFusionPass>(
after_grad,
std::max<uint8_t>(comp_graph_opt->graph_opt.jit, 1));
add_pass<gopt::JITFusionPass>(after_grad, jit_opt_level, jit_config);
}
#endif
......
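The decision made in add_preset_passes can be restated compactly (a hedged Python model of the C++ logic above, not part of the codebase):

```python
def decide_jit(graph_opt_level, jit_opt_level, jit_config_enabled):
    # graph_opt_level >= 3 implies at least basic JIT fusion
    if graph_opt_level >= 3:
        jit_opt_level = max(jit_opt_level, 1)
    need_jit = jit_opt_level > 0 or jit_config_enabled
    return need_jit, jit_opt_level
```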
......@@ -428,14 +428,33 @@ bool JITFusionPass::Impl::can_be_fused(cg::OperatorNodeBase* opr) const {
return false;
}
JITFusionPass::JITFusionPass(bool after_grad, int8_t jit_opt_level)
JITFusionPass::JITFusionPass(bool after_grad, int jit_opt_level,
const JITConfig& jit_config)
: m_after_grad{after_grad}, m_feature_bits{JITFeatureBits::NONE} {
// TODO: reduce and dimshuffle cannot coexist now.
if (jit_opt_level >= 2) {
m_feature_bits |= JITFeatureBits::REDUCE;
} else {
// get default config from jit_opt_level
JITConfig config;
if (jit_opt_level == 1) {
config.fuse_dimshuffle = JITConfig::ON;
config.fuse_reduce = JITConfig::OFF;
} else if (jit_opt_level >= 2) {
config.fuse_dimshuffle = JITConfig::OFF;
config.fuse_reduce = JITConfig::ON;
}
// overwrite default config with custom settings
config.update(jit_config);
bool fuse_dimshuffle = config.fuse_dimshuffle == JITConfig::ON;
bool fuse_reduce = config.fuse_reduce == JITConfig::ON;
if (fuse_dimshuffle && fuse_reduce) {
mgb_assert(false, "reduce and dimshuffle can not coexist now");
}
if (fuse_dimshuffle) {
m_feature_bits |= JITFeatureBits::DIMSHUFFLE;
}
if (fuse_reduce) {
m_feature_bits |= JITFeatureBits::REDUCE;
}
}
const char* JITFusionPass::name() const {
......
......@@ -39,7 +39,40 @@ class JITFusionPass final : public Pass {
JITFeatureBits m_feature_bits;
public:
JITFusionPass(bool after_grad = true, int8_t jit_opt_level = 1);
using JITConfig = cg::ComputingGraph::Options::GraphOpt::JITConfig;
/*
* Explanation of how graph_opt_level, jit_opt_level and jit_config
* control the behavior of JIT optimization:
*
* The design of this API is constrained by the legacy jit_opt_level
* interface: the old jit_opt_level and the new jit_config have to be
* supported at the same time.
*
* How JITFusionPass decides its behavior:
* (1) When graph_opt_level is 3, it sets jit_opt_level to 1
* (2) When the user-defined jit_opt_level is greater than 1, it overwrites
* the previous value of jit_opt_level
* (3) We get a default jit_config from jit_opt_level:
* jit_opt_level = 0: JIT optimization OFF
* jit_opt_level = 1: dimshuffle ON, reduce OFF
* jit_opt_level = 2: dimshuffle OFF, reduce ON
* (4) The user-defined jit_config provides more precise control and
* overwrites the default settings defined by jit_opt_level
*
* Situations in which JIT optimization is ON:
* (1) graph_opt_level = 3
* (2) graph_opt_level = 2, jit_opt_level > 0
* (3) graph_opt_level = 2, jit_opt_level = 0, jit_config is set
* (4) graph_opt_level = 0, jit_opt_level > 0 (deprecated usage)
*
* Situations in which JIT optimization is OFF:
* (1) graph_opt_level = 2, jit_opt_level = 0, jit_config is unset
* (2) graph_opt_level = 1
* (3) graph_opt_level = 0, jit_opt_level = 0
*/
JITFusionPass(bool after_grad = true, int jit_opt_level = 0,
const JITConfig& jit_config = {});
const char* name() const override;
void apply(OptState& opt) const override;
};
......
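The behavior documented in the constructor comment can also be modeled end to end (an illustrative Python sketch, assuming it mirrors the C++ constructor above; UNSET/OFF/ON use the same 0/1/2 encoding):

```python
UNSET, OFF, ON = 0, 1, 2

def resolve_fusion(jit_opt_level, jit_config):
    # default config derived from jit_opt_level
    if jit_opt_level == 1:
        default = {"fuse_dimshuffle": ON, "fuse_reduce": OFF}
    elif jit_opt_level >= 2:
        default = {"fuse_dimshuffle": OFF, "fuse_reduce": ON}
    else:
        default = {"fuse_dimshuffle": UNSET, "fuse_reduce": UNSET}
    # user-provided jit_config overwrites the default where it is explicitly set
    for key, value in jit_config.items():
        if value != UNSET:
            default[key] = value
    fuse_dimshuffle = default["fuse_dimshuffle"] == ON
    fuse_reduce = default["fuse_reduce"] == ON
    assert not (fuse_dimshuffle and fuse_reduce), "dimshuffle and reduce can not coexist"
    return fuse_dimshuffle, fuse_reduce

# e.g. jit_opt_level = 2 overridden to fuse dimshuffle instead of reduce
print(resolve_fusion(2, {"fuse_dimshuffle": ON, "fuse_reduce": OFF}))  # (True, False)
```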
......@@ -27,6 +27,8 @@
#include "megbrain/test/helper.h"
#include "megbrain/opr/dnn/convolution.h"
#include "../../core/impl/graph/cg_impl_seq.h"
#if MGB_JIT
using namespace mgb;
......@@ -1455,6 +1457,122 @@ TEST(TestJITNvrtc, DimshuffleGrad) {
}
}
TEST(TestJITNvrtc, JITConfig) {
using JITConfig = cg::ComputingGraph::Options::GraphOpt::JITConfig;
using CompSeq = cg::ComputingGraphImpl::ComputingSequence;
using ReduceMode = opr::Reduce::Param::Mode;
static const int UNSET = JITConfig::UNSET;
static const int OFF = JITConfig::OFF;
static const int ON = JITConfig::ON;
REQUIRE_GPU(1);
set_backend(Backend::NVRTC);
auto cn = CompNode::load("gpu0");
HostTensorGenerator<> gen;
auto run = [&](int graph_opt_level, int jit_opt_level,
const JITConfig& jit_config, bool expect_dimshuffle_fused,
bool expect_reduce_fused, bool expect_jit_enabled) {
auto cg = ComputingGraph::make();
cg->options().graph_opt_level = graph_opt_level;
cg->options().graph_opt.jit = jit_opt_level;
cg->options().graph_opt.jit_config = jit_config;
auto host_x = gen({2, 3, 4, 5}, cn);
auto x = opr::SharedDeviceTensor::make(*cg, *host_x);
// three types of operations to be fused by JIT
x = (2 * x + 3) * (3 * x - 1); // Elemwise
x = opr::Dimshuffle::make(x, {1, 2, 3, 0}); // Dimshuffle
x = opr::Reduce::make(x + 2, {ReduceMode::SUM, 2}); // Reduce
auto func = cg->compile({make_callback_copy(x + 1, *host_x)});
auto comp_seq = dynamic_cast<CompSeq*>(func.get());
ASSERT_TRUE(comp_seq != nullptr);
bool dimshuffle_found = false, reduce_found = false,
jit_executor_found = false;
auto on_opr = [&](cg::OperatorNodeBase* opr) {
if (opr->same_type<opr::Dimshuffle>()) {
dimshuffle_found = true;
} else if (opr->same_type<opr::Reduce>()) {
reduce_found = true;
} else if (opr->same_type<JITExecutor>()) {
jit_executor_found = true;
}
return true;
};
comp_seq->iter_opr_seq(on_opr);
ASSERT_EQ(expect_dimshuffle_fused, !dimshuffle_found);
ASSERT_EQ(expect_reduce_fused, !reduce_found);
ASSERT_EQ(expect_jit_enabled, jit_executor_found);
};
// graph_opt_level = 1, always OFF
for (int jit_opt_level : {0, 1, 2}) {
for (int fuse_dimshuffle : {UNSET, OFF, ON}) {
for (int fuse_reduce : {UNSET, OFF, ON}) {
run(1, jit_opt_level, JITConfig{fuse_dimshuffle, fuse_reduce},
false, false, false);
}
}
}
// some test cases are commented out because dimshuffle and reduce cannot
// be fused at the same time
for (int graph_opt_level : {0, 2}) {
// jit_opt_level = 0, default = {OFF, OFF}
run(graph_opt_level, 0, JITConfig{UNSET, UNSET}, false, false, false);
run(graph_opt_level, 0, JITConfig{UNSET, OFF}, false, false, true);
run(graph_opt_level, 0, JITConfig{UNSET, ON}, false, true, true);
run(graph_opt_level, 0, JITConfig{OFF, UNSET}, false, false, true);
run(graph_opt_level, 0, JITConfig{OFF, OFF}, false, false, true);
run(graph_opt_level, 0, JITConfig{OFF, ON}, false, true, true);
run(graph_opt_level, 0, JITConfig{ON, UNSET}, true, false, true);
run(graph_opt_level, 0, JITConfig{ON, OFF}, true, false, true);
// run(graph_opt_level, 0, JITConfig{ON, ON}, true, true, true);
}
{
// graph_opt_level = 3, jit_opt_level = 0, default = {ON, OFF}
run(3, 0, JITConfig{UNSET, UNSET}, true, false, true);
run(3, 0, JITConfig{UNSET, OFF}, true, false, true);
// run(3, 0, JITConfig{UNSET, ON}, true, true, true);
run(3, 0, JITConfig{OFF, UNSET}, false, false, true);
run(3, 0, JITConfig{OFF, OFF}, false, false, true);
run(3, 0, JITConfig{OFF, ON}, false, true, true);
run(3, 0, JITConfig{ON, UNSET}, true, false, true);
run(3, 0, JITConfig{ON, OFF}, true, false, true);
// run(3, 0, JITConfig{ON, ON}, true, true, true);
}
for (int graph_opt_level : {0, 2, 3}) {
// jit_opt_level = 1, default = {ON, OFF}
run(graph_opt_level, 1, JITConfig{UNSET, UNSET}, true, false, true);
run(graph_opt_level, 1, JITConfig{UNSET, OFF}, true, false, true);
// run(graph_opt_level, 1, JITConfig{UNSET, ON}, true, true, true);
run(graph_opt_level, 1, JITConfig{OFF, UNSET}, false, false, true);
run(graph_opt_level, 1, JITConfig{OFF, OFF}, false, false, true);
run(graph_opt_level, 1, JITConfig{OFF, ON}, false, true, true);
run(graph_opt_level, 1, JITConfig{ON, UNSET}, true, false, true);
run(graph_opt_level, 1, JITConfig{ON, OFF}, true, false, true);
// run(graph_opt_level, 1, JITConfig{ON, ON}, true, true, true);
// jit_opt_level = 2, default = {OFF, ON}
run(graph_opt_level, 2, JITConfig{UNSET, UNSET}, false, true, true);
run(graph_opt_level, 2, JITConfig{UNSET, OFF}, false, false, true);
run(graph_opt_level, 2, JITConfig{UNSET, ON}, false, true, true);
run(graph_opt_level, 2, JITConfig{OFF, UNSET}, false, true, true);
run(graph_opt_level, 2, JITConfig{OFF, OFF}, false, false, true);
run(graph_opt_level, 2, JITConfig{OFF, ON}, false, true, true);
// run(graph_opt_level, 2, JITConfig{ON, UNSET}, true, true, true);
run(graph_opt_level, 2, JITConfig{ON, OFF}, true, false, true);
// run(graph_opt_level, 2, JITConfig{ON, ON}, true, true, true);
}
}
TEST(TestJITExecutor, GradBehavior) {
REQUIRE_GPU(1);
auto cn = CompNode::load("gpu0");
......