diff --git a/imperative/python/megengine/jit/__init__.py b/imperative/python/megengine/jit/__init__.py
index 5f5db441472d3451d6325ed84ae52d574a7a7c26..bd50925e49f40b7a95aaf46ccdf456f0fd44a873
--- a/imperative/python/megengine/jit/__init__.py
+++ b/imperative/python/megengine/jit/__init__.py
@@ -11,6 +11,7 @@ from ..core._imperative_rt.core2 import (
     set_cpp_apply_with_tracing,
 )
 from .dtr_config import DTRConfig
+from .graph_opt_config import GraphOptimizationConfig
 from .sublinear_memory_config import SublinearMemoryConfig
 from .tracing import (
     apply_const_with_tracing,
diff --git a/imperative/python/megengine/jit/graph_opt_config.py b/imperative/python/megengine/jit/graph_opt_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..9cea66b540f4acddf8980ab781a8cf455fb25607
--- /dev/null
+++ b/imperative/python/megengine/jit/graph_opt_config.py
@@ -0,0 +1,33 @@
+# -*- coding: utf-8 -*-
+# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
+#
+# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+
+
+class GraphOptimizationConfig:
+    r"""
+    Configuration for graph optimization: False for OFF, True for ON. The default value
+    None means that opt_level will decide whether this optimization will be applied or not.
+
+    :param jit_fuse_dimshuffle: whether to fuse dimshuffle in JIT optimization
+    :param jit_fuse_reduce: whether to fuse reduce in JIT optimization
+    """
+
+    def __init__(self):
+        self.jit_fuse_dimshuffle = None
+        self.jit_fuse_reduce = None
+
+    def __repr__(self):
+        val2str = {None: "UNSET", False: "OFF", True: "ON"}
+        return (
+            "GraphOptimizationConfig {"
+            + " jit_fuse_dimshuffle = "
+            + val2str[self.jit_fuse_dimshuffle]
+            + ", jit_fuse_reduce = "
+            + val2str[self.jit_fuse_reduce]
+            + " }"
+        )
diff --git a/imperative/python/megengine/jit/tracing.py b/imperative/python/megengine/jit/tracing.py
index 31127281085120c03c804a5a0728f36e9139a6ee..8d57776e0211b1f2d427aff575f0ec0bf4962dfd 100644
--- a/imperative/python/megengine/jit/tracing.py
+++ b/imperative/python/megengine/jit/tracing.py
@@ -38,6 +38,7 @@ from ..core.tensor import megbrain_graph as G
 from ..core.tensor.utils import setscalar
 from ..utils.naming import AutoNaming
 from .dtr_config import DTRConfig
+from .graph_opt_config import GraphOptimizationConfig
 from .sublinear_memory_config import SublinearMemoryConfig
 
 
@@ -129,6 +130,7 @@ class trace:
         If not None, it enables sublinear memory optimization with given setting.
     :param profiling: whether to profile compiled trace. Default: False
     :param opt_level: optimization level for compiling trace. Default: 2
+    :param graph_opt_config: configuration for graph optimization. Default: None
     :param symbolic_shape: whether to use symbolic shape for tracing.
         Default: True
     """
@@ -146,6 +148,7 @@
         dtr_config: DTRConfig = None,
         profiling: bool = False,
         opt_level: int = 2,
+        graph_opt_config: GraphOptimizationConfig = None,
         symbolic_shape: bool = True,
     ):
         self.__wrapped__ = function
@@ -156,6 +159,7 @@
         self._profiling = profiling
         self._profiler = None
         self._graph_opt_level = opt_level
+        self._graph_opt_config = graph_opt_config
         self._symbolic_shape = symbolic_shape
         self._output_handles = set()
 
@@ -502,7 +506,14 @@
             graph.options.dtr_config.evictee_minimum_size = (
                 self._dtr_config.evictee_minimum_size
             )
-
+        # graph optimization
+        if self._graph_opt_config is not None:
+            mapping = {None: 0, False: 1, True: 2}
+            jit_config = graph.options.graph_opt.jit_config
+            jit_config.fuse_dimshuffle = mapping[
+                self._graph_opt_config.jit_fuse_dimshuffle
+            ]
+            jit_config.fuse_reduce = mapping[self._graph_opt_config.jit_fuse_reduce]
         # sublinear
         if self._sublinear_memory_config is not None:
             graph.options.enable_sublinear_memory_opt = True
diff --git a/imperative/python/src/graph_rt.cpp b/imperative/python/src/graph_rt.cpp
index 60b391b26bc36584a921e87773e51354f403a8d3..a43f756f2090f2e90601821c5394ed6ecf45d1cd 100644
--- a/imperative/python/src/graph_rt.cpp
+++ b/imperative/python/src/graph_rt.cpp
@@ -421,12 +421,20 @@ void init_graph_rt(py::module m) {
 
 #undef CURRENT_CLASS
 #define CURRENT_CLASS cg::ComputingGraph::Options::GraphOpt
 
-    py::class_<cg::ComputingGraph::Options::GraphOpt>(PyComputingGraphOptions, "GraphOpt")
+    auto PyGraphOpt = py::class_<cg::ComputingGraph::Options::GraphOpt>(
+            PyComputingGraphOptions, "GraphOpt")
         DEF_READWRITE(jit)
+        DEF_READWRITE(jit_config)
         DEF_READWRITE(tensorrt);
 
 #undef CURRENT_CLASS
+#define CURRENT_CLASS cg::ComputingGraph::Options::GraphOpt::JITConfig
+    py::class_<cg::ComputingGraph::Options::GraphOpt::JITConfig>(PyGraphOpt, "JITConfig")
+        DEF_READWRITE(fuse_dimshuffle)
+        DEF_READWRITE(fuse_reduce);
+
+#undef CURRENT_CLASS
 #define CURRENT_CLASS cg::ComputingGraph::Options::SublinearMemConfig
 
     py::class_<cg::ComputingGraph::Options::SublinearMemConfig>(PyComputingGraphOptions, "SublinearMemConfig")
diff --git a/imperative/python/test/unit/jit/test_tracing.py b/imperative/python/test/unit/jit/test_tracing.py
index 7fee2ab0907b99583cb275c4c6a47b876e140e98..f6a650f8f9f342c8158c734fbd7357d2bc551424 100644
--- a/imperative/python/test/unit/jit/test_tracing.py
+++ b/imperative/python/test/unit/jit/test_tracing.py
@@ -25,7 +25,7 @@ from megengine.core.ops import builtin as ops
 from megengine.core.ops.builtin import Elemwise
 from megengine.core.tensor.utils import isscalar
 from megengine.functional import exp, log
-from megengine.jit import exclude_from_trace, trace
+from megengine.jit import GraphOptimizationConfig, exclude_from_trace, trace
 from megengine.module import Module
 from megengine.random import normal, uniform
 from megengine.utils.naming import AutoNaming
@@ -605,3 +605,30 @@ def test_trace_advance_indexing(shape_mode):
     for _ in range(3):
         result_trace = f_traced(**params)
         np.testing.assert_equal(expected, result_trace.numpy())
+
+
+@pytest.mark.require_ngpu(1)  # nvrtc backend
+def test_trace_jit_config():
+    def run(fuse_dimshuffle, fuse_reduce):
+        config = GraphOptimizationConfig()
+        config.jit_fuse_dimshuffle = fuse_dimshuffle
+        config.jit_fuse_reduce = fuse_reduce
+
+        # set opt_level = 1 to avoid fusing dimshuffle and reduce at the same time
+        @trace(opt_level=1, graph_opt_config=config)
+        def func(x):
+            return x + 1
+
+        x = tensor(2)
+        y = func(x)
+        func._compile()
+
+        options = func._graph.options
+        mapping = {None: 0, False: 1, True: 2}
+        assert options.graph_opt.jit == 0
+        assert options.graph_opt.jit_config.fuse_dimshuffle == mapping[fuse_dimshuffle]
+        assert options.graph_opt.jit_config.fuse_reduce == mapping[fuse_reduce]
+
+    for fuse_dimshuffle in [None, False, True]:
+        for fuse_reduce in [None, False, True]:
+            run(fuse_dimshuffle, fuse_reduce)
diff --git a/src/core/impl/graph/cg_impl.cpp b/src/core/impl/graph/cg_impl.cpp
index af1707193bb03fa9f82c2fe0052a91f077dae38d..43b9a5a4e356860735dcfdad29d8e8037f301e23 100644
--- a/src/core/impl/graph/cg_impl.cpp
+++ b/src/core/impl/graph/cg_impl.cpp
@@ -145,6 +145,24 @@ size_t ComputingGraph::prealloc_static_storage(size_t size) {
 }
 #endif
 
+/* ========================== JITConfig ========================== */
+
+bool ComputingGraph::Options::GraphOpt::JITConfig::enabled() const {
+    if (fuse_dimshuffle != UNSET) return true;
+    if (fuse_reduce != UNSET) return true;
+    return false;
+}
+
+void ComputingGraph::Options::GraphOpt::JITConfig::update(
+        const JITConfig& modifier) {
+    if (modifier.fuse_dimshuffle != UNSET) {
+        this->fuse_dimshuffle = modifier.fuse_dimshuffle;
+    }
+    if (modifier.fuse_reduce != UNSET) {
+        this->fuse_reduce = modifier.fuse_reduce;
+    }
+}
+
 /* ========================== CallbackCaller ========================== */
 MGB_DEFINE_OPR_CLASS(ComputingGraphImpl::CallbackCaller, SingleCNOperatorNodeBase) // {
@@ -538,12 +556,18 @@ ComputingGraphImpl::CompileState ComputingGraphImpl::compile_prepare(
 
 #if MGB_JIT
-    if (std::abs(options().graph_opt_level) == 0 && options().graph_opt.jit) {
-        setenv("MGB_JIT_BACKEND","NVRTC",1);
+    if (std::abs(options().graph_opt_level) == 0 &&
+        (options().graph_opt.jit || options().graph_opt.jit_config.enabled())) {
+        // Deprecated usage added previously. It allows NVRTC JIT optimization
+        // when graph_opt_level is 0. This usage is not recommended any more.
+        mgb_log_warn(
+                "It is not recommended to enable JIT optimization when "
+                "graph_opt_level is 0.");
+        setenv("MGB_JIT_BACKEND", "NVRTC", 1);
         gopt::GraphOptimizer optimizer;
-        optimizer.add_pass<gopt::JITFusionPass>(
-                sopr_stat.has_virtual_grad,
-                std::max(options().graph_opt.jit, 1));
+        optimizer.add_pass<gopt::JITFusionPass>(sopr_stat.has_virtual_grad,
+                                                options().graph_opt.jit,
+                                                options().graph_opt.jit_config);
         optimizer.apply_inplace(dest_vars);
     }
 #endif
diff --git a/src/core/include/megbrain/graph/cg.h b/src/core/include/megbrain/graph/cg.h
index ce4524fa959a8aeae5e24ca457b1b5532f0a4550..8ca976e46e7601f5f5d1f7bcb01805b7189be441 100644
--- a/src/core/include/megbrain/graph/cg.h
+++ b/src/core/include/megbrain/graph/cg.h
@@ -338,6 +338,20 @@ class ComputingGraph : public std::enable_shared_from_this<ComputingGraph>,
             //! this value indicates JIT level: 1 for basic elemwise opr; 2
             //! for including reduce oprs
             uint8_t jit = 0;
+
+            //! jit configurations
+            struct JITConfig {
+                static const int UNSET = 0;
+                static const int OFF = 1;
+                static const int ON = 2;
+
+                int fuse_dimshuffle = UNSET;
+                int fuse_reduce = UNSET;
+
+                bool enabled() const;
+                void update(const JITConfig& modifier);
+            } jit_config;
+
             //! whether to enable fine-grained TensorRT opr replace
             bool tensorrt = false;
         } graph_opt;
diff --git a/src/gopt/impl/framework.cpp b/src/gopt/impl/framework.cpp
index 2eebbe6396bf0ca620dd35c38009789dc2c4b28c..9a914afc930c0ea256be387a68736706a58b2e14 100644
--- a/src/gopt/impl/framework.cpp
+++ b/src/gopt/impl/framework.cpp
@@ -645,11 +645,21 @@ GraphOptimizer& GraphOptimizer::add_preset_passes(
     add_pass();
 
 #if MGB_JIT
-    bool need_jit = false;
-    if (comp_graph_opt && (std::abs(comp_graph_opt->graph_opt_level) >= 3 ||
-                           comp_graph_opt->graph_opt.jit)) {
-        need_jit = true;
+    using JITConfig = cg::ComputingGraph::Options::GraphOpt::JITConfig;
+    int jit_opt_level = 0;
+    JITConfig jit_config;
+
+    // for more detail on what is happening here, see comments on the
+    // constructor of class JITFusionPass in fusion_pass.h
+    if (comp_graph_opt) {
+        jit_opt_level = comp_graph_opt->graph_opt.jit;
+        if (comp_graph_opt->graph_opt_level >= 3) {
+            jit_opt_level = std::max(jit_opt_level, 1);
+        }
+        jit_config = comp_graph_opt->graph_opt.jit_config;
     }
+    bool need_jit = (jit_opt_level > 0) || jit_config.enabled();
+
     if (need_jit && after_grad) {
         add_pass();
     }
@@ -662,9 +672,7 @@ GraphOptimizer& GraphOptimizer::add_preset_passes(
 
 #if MGB_JIT
     if (need_jit) {
-        add_pass<JITFusionPass>(
-                after_grad,
-                std::max(comp_graph_opt->graph_opt.jit, 1));
+        add_pass<JITFusionPass>(after_grad, jit_opt_level, jit_config);
     }
 #endif
diff --git a/src/jit/impl/fusion_pass.cpp b/src/jit/impl/fusion_pass.cpp
index 0f6712f576c52b1663f0a5af9c27d2558822b40f..a64f41f9db4e12ede4764e3bee6754b442ef2207 100644
--- a/src/jit/impl/fusion_pass.cpp
+++ b/src/jit/impl/fusion_pass.cpp
@@ -428,14 +428,33 @@ bool JITFusionPass::Impl::can_be_fused(cg::OperatorNodeBase* opr) const {
     return false;
 }
 
-JITFusionPass::JITFusionPass(bool after_grad, int8_t jit_opt_level)
+JITFusionPass::JITFusionPass(bool after_grad, int jit_opt_level,
+                             const JITConfig& jit_config)
         : m_after_grad{after_grad}, m_feature_bits{JITFeatureBits::NONE} {
-    // TODO reduce and dimshuffle can not coexsit now.
-    if (jit_opt_level >= 2) {
-        m_feature_bits |= JITFeatureBits::REDUCE;
-    } else {
+    // get default config from jit_opt_level
+    JITConfig config;
+    if (jit_opt_level == 1) {
+        config.fuse_dimshuffle = JITConfig::ON;
+        config.fuse_reduce = JITConfig::OFF;
+    } else if (jit_opt_level >= 2) {
+        config.fuse_dimshuffle = JITConfig::OFF;
+        config.fuse_reduce = JITConfig::ON;
+    }
+
+    // overwrite default config with custom settings
+    config.update(jit_config);
+    bool fuse_dimshuffle = config.fuse_dimshuffle == JITConfig::ON;
+    bool fuse_reduce = config.fuse_reduce == JITConfig::ON;
+
+    if (fuse_dimshuffle && fuse_reduce) {
+        mgb_assert(false, "reduce and dimshuffle can not coexist now");
+    }
+    if (fuse_dimshuffle) {
         m_feature_bits |= JITFeatureBits::DIMSHUFFLE;
     }
+    if (fuse_reduce) {
+        m_feature_bits |= JITFeatureBits::REDUCE;
+    }
 }
 
 const char* JITFusionPass::name() const {
diff --git a/src/jit/include/megbrain/jit/fusion_pass.h b/src/jit/include/megbrain/jit/fusion_pass.h
index 91a7af9d35d1afeededf839107141b0035acd84a..c7d4a02c2d4e83adb3926bc9d628ad47dde401a5 100644
--- a/src/jit/include/megbrain/jit/fusion_pass.h
+++ b/src/jit/include/megbrain/jit/fusion_pass.h
@@ -39,7 +39,40 @@ class JITFusionPass final : public Pass {
     JITFeatureBits m_feature_bits;
 
 public:
-    JITFusionPass(bool after_grad = true, int8_t jit_opt_level = 1);
+    using JITConfig = cg::ComputingGraph::Options::GraphOpt::JITConfig;
+
+    /*
+     * Explanation of how graph_opt_level, jit_opt_level and jit_config
+     * control the behavior of JIT optimization:
+     *
+     * The design of this API is restricted by the historical burden of
+     * jit_opt_level and we have to support the old interface jit_opt_level and
+     * the new interface jit_config at the same time.
+     *
+     * How JITFusionPass decides its behavior:
+     * (1) When graph_opt_level is 3, it sets jit_opt_level to 1
+     * (2) When the user-defined jit_opt_level is greater than 1, it overwrites
+     *     the previous value of jit_opt_level
+     * (3) We get a default jit_config from jit_opt_level:
+     *     jit_opt_level = 0: JIT optimization OFF
+     *     jit_opt_level = 1: dimshuffle ON, reduce OFF
+     *     jit_opt_level = 2: dimshuffle OFF, reduce ON
+     * (4) The user-defined jit_config provides more precise control and
+     *     overwrites the default settings defined by jit_opt_level
+     *
+     * Situations in which JIT optimization is ON:
+     * (1) graph_opt_level = 3
+     * (2) graph_opt_level = 2, jit_opt_level > 0
+     * (3) graph_opt_level = 2, jit_opt_level = 0, jit_config is set
+     * (4) graph_opt_level = 0, jit_opt_level > 0 (deprecated usage)
+     *
+     * Situations in which JIT optimization is OFF:
+     * (1) graph_opt_level = 2, jit_opt_level = 0, jit_config is unset
+     * (2) graph_opt_level = 1
+     * (3) graph_opt_level = 0, jit_opt_level = 0
+     */
+    JITFusionPass(bool after_grad = true, int jit_opt_level = 0,
+                  const JITConfig& jit_config = {});
 
     const char* name() const override;
     void apply(OptState& opt) const override;
 };
diff --git a/src/jit/test/fusion.cpp b/src/jit/test/fusion.cpp
index a91e095cf3a28e45d07616215f4d972c4347a0f4..852f6a10e4aa04a2f8a382e6d1cc71c2a70654b0 100644
--- a/src/jit/test/fusion.cpp
+++ b/src/jit/test/fusion.cpp
@@ -27,6 +27,8 @@
 #include "megbrain/test/helper.h"
 #include "megbrain/opr/dnn/convolution.h"
 
+#include "../../core/impl/graph/cg_impl_seq.h"
+
 #if MGB_JIT
 using namespace mgb;
 
@@ -1455,6 +1457,122 @@ TEST(TestJITNvrtc, DimshuffleGrad) {
     }
 }
 
+TEST(TestJITNvrtc, JITConfig) {
+    using JITConfig = cg::ComputingGraph::Options::GraphOpt::JITConfig;
+    using CompSeq = cg::ComputingGraphImpl::ComputingSequence;
+    using ReduceMode = opr::Reduce::Param::Mode;
+    static const int UNSET = JITConfig::UNSET;
+    static const int OFF = JITConfig::OFF;
+    static const int ON = JITConfig::ON;
+
+    REQUIRE_GPU(1);
+    set_backend(Backend::NVRTC);
+    auto cn = CompNode::load("gpu0");
+    HostTensorGenerator<> gen;
+
+    auto run = [&](int graph_opt_level, int jit_opt_level,
+                   const JITConfig& jit_config, bool expect_dimshuffle_fused,
+                   bool expect_reduce_fused, bool expect_jit_enabled) {
+        auto cg = ComputingGraph::make();
+        cg->options().graph_opt_level = graph_opt_level;
+        cg->options().graph_opt.jit = jit_opt_level;
+        cg->options().graph_opt.jit_config = jit_config;
+
+        auto host_x = gen({2, 3, 4, 5}, cn);
+        auto x = opr::SharedDeviceTensor::make(*cg, *host_x);
+
+        // three types of operations to be fused by JIT
+        x = (2 * x + 3) * (3 * x - 1);                        // Elemwise
+        x = opr::Dimshuffle::make(x, {1, 2, 3, 0});           // Dimshuffle
+        x = opr::Reduce::make(x + 2, {ReduceMode::SUM, 2});   // Reduce
+
+        auto func = cg->compile({make_callback_copy(x + 1, *host_x)});
+        auto comp_seq = dynamic_cast<CompSeq*>(func.get());
+        ASSERT_TRUE(comp_seq != nullptr);
+
+        bool dimshuffle_found = false, reduce_found = false,
+             jit_executor_found = false;
+        auto on_opr = [&](cg::OperatorNodeBase* opr) {
+            if (opr->same_type<opr::Dimshuffle>()) {
+                dimshuffle_found = true;
+            } else if (opr->same_type<opr::Reduce>()) {
+                reduce_found = true;
+            } else if (opr->same_type<jit::JITExecutor>()) {
+                jit_executor_found = true;
+            }
+            return true;
+        };
+        comp_seq->iter_opr_seq(on_opr);
+
+        ASSERT_EQ(expect_dimshuffle_fused, !dimshuffle_found);
+        ASSERT_EQ(expect_reduce_fused, !reduce_found);
+        ASSERT_EQ(expect_jit_enabled, jit_executor_found);
+    };
+
+    // graph_opt_level = 1, always OFF
+    for (int jit_opt_level : {0, 1, 2}) {
+        for (int fuse_dimshuffle : {UNSET, OFF, ON}) {
+            for (int fuse_reduce : {UNSET, OFF, ON}) {
+                run(1, jit_opt_level, JITConfig{fuse_dimshuffle, fuse_reduce},
+                    false, false, false);
+            }
+        }
+    }
+
+    // some test cases are commented because dimshuffle and reduce can not be
+    // fused at the same time
+
+    for (int graph_opt_level : {0, 2}) {
+        // jit_opt_level = 0, default = {OFF, OFF}
+        run(graph_opt_level, 0, JITConfig{UNSET, UNSET}, false, false, false);
+        run(graph_opt_level, 0, JITConfig{UNSET, OFF}, false, false, true);
+        run(graph_opt_level, 0, JITConfig{UNSET, ON}, false, true, true);
+        run(graph_opt_level, 0, JITConfig{OFF, UNSET}, false, false, true);
+        run(graph_opt_level, 0, JITConfig{OFF, OFF}, false, false, true);
+        run(graph_opt_level, 0, JITConfig{OFF, ON}, false, true, true);
+        run(graph_opt_level, 0, JITConfig{ON, UNSET}, true, false, true);
+        run(graph_opt_level, 0, JITConfig{ON, OFF}, true, false, true);
+        // run(graph_opt_level, 0, JITConfig{ON, ON}, true, true, true);
+    }
+
+    {
+        // graph_opt_level = 3, jit_opt_level = 0, default = {ON, OFF}
+        run(3, 0, JITConfig{UNSET, UNSET}, true, false, true);
+        run(3, 0, JITConfig{UNSET, OFF}, true, false, true);
+        // run(3, 0, JITConfig{UNSET, ON}, true, true, true);
+        run(3, 0, JITConfig{OFF, UNSET}, false, false, true);
+        run(3, 0, JITConfig{OFF, OFF}, false, false, true);
+        run(3, 0, JITConfig{OFF, ON}, false, true, true);
+        run(3, 0, JITConfig{ON, UNSET}, true, false, true);
+        run(3, 0, JITConfig{ON, OFF}, true, false, true);
+        // run(3, 0, JITConfig{ON, ON}, true, true, true);
+    }
+
+    for (int graph_opt_level : {0, 2, 3}) {
+        // jit_opt_level = 1, default = {ON, OFF}
+        run(graph_opt_level, 1, JITConfig{UNSET, UNSET}, true, false, true);
+        run(graph_opt_level, 1, JITConfig{UNSET, OFF}, true, false, true);
+        // run(graph_opt_level, 1, JITConfig{UNSET, ON}, true, true, true);
+        run(graph_opt_level, 1, JITConfig{OFF, UNSET}, false, false, true);
+        run(graph_opt_level, 1, JITConfig{OFF, OFF}, false, false, true);
+        run(graph_opt_level, 1, JITConfig{OFF, ON}, false, true, true);
+        run(graph_opt_level, 1, JITConfig{ON, UNSET}, true, false, true);
+        run(graph_opt_level, 1, JITConfig{ON, OFF}, true, false, true);
+        // run(graph_opt_level, 1, JITConfig{ON, ON}, true, true, true);
+
+        // jit_opt_level = 2, default = {OFF, ON}
+        run(graph_opt_level, 2, JITConfig{UNSET, UNSET}, false, true, true);
+        run(graph_opt_level, 2, JITConfig{UNSET, OFF}, false, false, true);
+        run(graph_opt_level, 2, JITConfig{UNSET, ON}, false, true, true);
+        run(graph_opt_level, 2, JITConfig{OFF, UNSET}, false, true, true);
+        run(graph_opt_level, 2, JITConfig{OFF, OFF}, false, false, true);
+        run(graph_opt_level, 2, JITConfig{OFF, ON}, false, true, true);
+        // run(graph_opt_level, 2, JITConfig{ON, UNSET}, true, true, true);
+        run(graph_opt_level, 2, JITConfig{ON, OFF}, true, false, true);
+        // run(graph_opt_level, 2, JITConfig{ON, ON}, true, true, true);
+    }
+}
+
 TEST(TestJITExecutor, GradBehavior) {
     REQUIRE_GPU(1);
     auto cn = CompNode::load("gpu0");
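
For context, a minimal usage sketch of the Python API added by this patch (GraphOptimizationConfig passed to trace via graph_opt_config). This example is illustrative and not part of the patch: the input shape and the exp/sum computation are arbitrary, and the JIT fusion only actually kicks in on a CUDA device where the NVRTC backend is available.

import numpy as np

import megengine.functional as F
from megengine import tensor
from megengine.jit import GraphOptimizationConfig, trace

# Turn reduce fusion ON explicitly; leave dimshuffle fusion UNSET (None),
# so opt_level decides whether dimshuffle fusion is applied.
config = GraphOptimizationConfig()
config.jit_fuse_reduce = True

@trace(opt_level=2, graph_opt_config=config)
def step(x):
    # Elemwise + Reduce pattern that the JIT fusion pass can fuse
    return F.sum(F.exp(x) * 2 + 1, axis=0)

x = tensor(np.random.rand(4, 8).astype("float32"))
for _ in range(3):
    y = step(x)

print(config)  # GraphOptimizationConfig { jit_fuse_dimshuffle = UNSET, jit_fuse_reduce = ON }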