diff --git a/python_module/megengine/jit/__init__.py b/python_module/megengine/jit/__init__.py
index d112e37e9968727e0a94574030409d5890aee1be..cd76bbb0f08f19f447c8140e8c325e1c8a08c657 100644
--- a/python_module/megengine/jit/__init__.py
+++ b/python_module/megengine/jit/__init__.py
@@ -18,6 +18,7 @@ import megengine._internal as mgb
 from megengine._internal.plugin import CompGraphProfiler
 
 from ..core import Tensor, graph, tensor
+from .sublinear_memory_config import SublinearMemConfig
 
 
 def sideeffect(f):
@@ -78,10 +79,28 @@ class trace:
     * accelerated evaluation via :meth:`.__call__`
 
     :param func: Positional only argument.
-    :param symbolic: Whether to use symbolic tensor.
+    :param symbolic: Whether to use symbolic tensor. Default: False
     :param opt_level: Optimization level for compiling trace.
     :param log_level: Log level.
-    :param profiling: Whether to profile compiled trace.
+    :param enable_sublinear: Enable sublinear memory optimization. Default: False
+    :param sublinear_mem_config: Configuration for sublinear memory optimization;
+        only effective when enable_sublinear is True.
+    :param profiling: Whether to profile compiled trace. Default: False
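+
+    Example: enabling sublinear memory optimization (a minimal sketch; see
+    :class:`SublinearMemConfig` for the tunables):
+
+    .. code-block:: python
+
+        from megengine.jit import SublinearMemConfig, trace
+
+        config = SublinearMemConfig(genetic_nr_iter=10)
+
+        @trace(symbolic=True, enable_sublinear=True, sublinear_mem_config=config)
+        def f(x):
+            return x + x
+
+        f([0.0])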
     """
 
     _active_instance = None
@@ -103,12 +106,16 @@ class trace:
         symbolic: bool = False,
         opt_level: int = None,
         log_level: int = None,
+        enable_sublinear: bool = False,
+        sublinear_mem_config: SublinearMemConfig = None,
         profiling: bool = False
     ):
         self.__wrapped__ = func
         self._symbolic = symbolic
         self._graph_opt_level = opt_level
         self._log_level = log_level
+        self._enable_sublinear = enable_sublinear
+        self._sublinear_mem_config = sublinear_mem_config
         self._status = self._UNSTARTED
         self._args = None
         self._kwargs = None
@@ -280,11 +287,35 @@ class trace:
 
     def _apply_graph_options(self, cg):
         # graph opt level
-        if not self._graph_opt_level is None:
+        if self._graph_opt_level is not None:
             cg.set_option("graph_opt_level", self._graph_opt_level)
         # log level
-        if not self._log_level is None:
+        if self._log_level is not None:
             cg.set_option("log_level", self._log_level)
+        # sublinear
+        if self._enable_sublinear:
+            cg.set_option("enable_sublinear_memory_opt", True)
+            if self._sublinear_mem_config is not None:
+                cg.set_option(
+                    "sublinear_mem_config.lb_memory",
+                    self._sublinear_mem_config.lb_memory,
+                )
+                cg.set_option(
+                    "sublinear_mem_config.genetic_nr_iter",
+                    self._sublinear_mem_config.genetic_nr_iter,
+                )
+                cg.set_option(
+                    "sublinear_mem_config.genetic_pool_size",
+                    self._sublinear_mem_config.genetic_pool_size,
+                )
+                cg.set_option(
+                    "sublinear_mem_config.thresh_nr_try",
+                    self._sublinear_mem_config.thresh_nr_try,
+                )
+                cg.set_option(
+                    "sublinear_mem_config.num_worker",
+                    self._sublinear_mem_config.num_worker,
+                )
         # profile
         if self._profiling:
             self._profiler = CompGraphProfiler(cg)
diff --git a/python_module/megengine/jit/sublinear_memory_config.py b/python_module/megengine/jit/sublinear_memory_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..7603a843ef77242d3a32625c306cfd6b6989fd73
--- /dev/null
+++ b/python_module/megengine/jit/sublinear_memory_config.py
@@ -0,0 +1,58 @@
+# -*- coding: utf-8 -*-
+# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
+#
+# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+
+from ..core.device import get_device_count
+
+
+class SublinearMemConfig:
+    r"""
+    Configuration for sublinear memory optimization.
+
+    :param thresh_nr_try: number of samples to try, both when searching in linear
+        space and around the current thresh, in sublinear memory optimization. Default: 10.
+        It can also be set through the environment variable 'MGB_SUBLINEAR_MEMORY_THRESH_NR_TRY'.
+    :param genetic_nr_iter: number of iterations used to find the best checkpoints
+        with the genetic algorithm. Default: 0.
+        It can also be set through the environment variable 'MGB_SUBLINEAR_MEMORY_GENETIC_NR_ITER'.
+    :param genetic_pool_size: number of samples for the crossover random selection
+        during genetic optimization. Default: 20.
+        It can also be set through the environment variable 'MGB_SUBLINEAR_MEMORY_GENETIC_POOL_SIZE'.
+    :param lb_memory: lower bound of the bottleneck memory size, in MB, for sublinear
+        memory optimization. It can be used to trade memory for speed manually. Default: 0.
+        It can also be set through the environment variable 'MGB_SUBLINEAR_MEMORY_LOWER_BOUND_MB'.
+    :param num_worker: number of worker threads used to search for the optimal
+        checkpoints in sublinear memory optimization. Default: half of the CPU count.
+        It can also be set through the environment variable 'MGB_SUBLINEAR_MEMORY_WORKERS'.
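+
+    A rough usage sketch (values are illustrative only). Note that the
+    MGB_SUBLINEAR_MEMORY_* environment variables listed above, when set,
+    overwrite these fields at graph compile time:
+
+    .. code-block:: python
+
+        from megengine.jit import SublinearMemConfig
+
+        config = SublinearMemConfig(genetic_nr_iter=50, lb_memory=1024)
+        # run 50 iterations of the genetic search and keep the memory
+        # bottleneck above 1024 MB (both values are arbitrary here)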
+    """
+
+    def __init__(
+        self,
+        thresh_nr_try: int = 10,
+        genetic_nr_iter: int = 0,
+        genetic_pool_size: int = 20,
+        lb_memory: int = 0,
+        num_worker: int = get_device_count("cpu") // 2,
+    ):
+        self.thresh_nr_try = thresh_nr_try
+        self.genetic_nr_iter = genetic_nr_iter
+        self.genetic_pool_size = genetic_pool_size
+        self.lb_memory = lb_memory
+        self.num_worker = num_worker
diff --git a/python_module/src/cpp/megbrain_config.cpp b/python_module/src/cpp/megbrain_config.cpp
index 2b3bb9790ab81b4264fb449ca96e8faf06b530ae..ed2b40ac3339072d44d3c3930d5c9bcf6aefcd7a 100644
--- a/python_module/src/cpp/megbrain_config.cpp
+++ b/python_module/src/cpp/megbrain_config.cpp
@@ -42,7 +42,8 @@ bool _config::set_comp_graph_option(
                 std::is_same::value || \
                 std::is_same::value || \
                 std::is_same::value || \
-                std::is_same::value, \
+                std::is_same::value || \
+                std::is_same::value, \
                 "not bool/int opt"); \
         if (name == #name_chk) { \
             auto ret = opt.name_chk; \
@@ -66,6 +67,11 @@ bool _config::set_comp_graph_option(
         SET_CG_OPTION(allocate_static_mem_after_graph_compile);
         SET_CG_OPTION(log_level);
         SET_CG_OPTION(enable_sublinear_memory_opt);
+        SET_CG_OPTION(sublinear_mem_config.lb_memory);
+        SET_CG_OPTION(sublinear_mem_config.genetic_nr_iter);
+        SET_CG_OPTION(sublinear_mem_config.genetic_pool_size);
+        SET_CG_OPTION(sublinear_mem_config.thresh_nr_try);
+        SET_CG_OPTION(sublinear_mem_config.num_worker);
         SET_CG_OPTION(enable_var_mem_defragment);
         SET_CG_OPTION(eager_evaluation);
         SET_CG_OPTION(enable_memory_swap);
diff --git a/python_module/test/integration/test_correctness.py b/python_module/test/integration/test_correctness.py
index ed17d0136738bce0450d9ffa3b193a7ab1fdbf16..282818317fc9a8f4f2d115365e8a6cdd10e6d184 100644
--- a/python_module/test/integration/test_correctness.py
+++ b/python_module/test/integration/test_correctness.py
@@ -17,6 +17,7 @@ import megengine as mge
 import megengine.functional as F
 from megengine import jit, tensor
 from megengine.functional.debug_param import set_conv_execution_strategy
+from megengine.jit import SublinearMemConfig
 from megengine.module import AvgPool2d, BatchNorm2d, Conv2d, Linear, Module
 from megengine.optimizer import SGD
 from megengine.test import assertTensorClose
@@ -130,7 +131,14 @@ def update_model(model_path):
     mge.save(checkpoint, model_path)
 
 
-def run_test(model_path, use_jit, use_symbolic):
+def run_test(
+    model_path,
+    use_jit,
+    use_symbolic,
+    enable_sublinear=False,
+    sublinear_mem_config=None,
+    max_err=None,
+):
     """
     Load the model with test cases and run the training for one iter.
 
@@ -152,11 +160,17 @@ def run_test(model_path, use_jit, use_symbolic):
     data.set_value(checkpoint["data"])
     label.set_value(checkpoint["label"])
 
-    max_err = 1e-5
+    if max_err is None:
+        max_err = 1e-5
 
     train_func = train
     if use_jit:
-        train_func = jit.trace(train_func, symbolic=use_symbolic)
+        train_func = jit.trace(
+            train_func,
+            symbolic=use_symbolic,
+            enable_sublinear=enable_sublinear,
+            sublinear_mem_config=sublinear_mem_config,
+        )
 
     opt.zero_grad()
     loss = train_func(data, label, net=net, opt=opt)
@@ -183,3 +197,14 @@ def test_correctness():
     run_test(model_path, False, False)
     run_test(model_path, True, False)
     run_test(model_path, True, True)
+
+    # sublinear
+    config = SublinearMemConfig(genetic_nr_iter=10)
+    run_test(
+        model_path,
+        True,
+        True,
+        enable_sublinear=True,
+        sublinear_mem_config=config,
+        max_err=1e-5,
+    )
diff --git a/python_module/test/unit/jit/test_jit.py b/python_module/test/unit/jit/test_jit.py
index ebb7fc6cb11164c3e56339ea346ff7abbf56df56..d5104037fc0e51cc3a11d2bdbe3bf181e6aef541 100644
--- a/python_module/test/unit/jit/test_jit.py
+++ b/python_module/test/unit/jit/test_jit.py
@@ -18,6 +18,7 @@ import megengine._internal as mgb
 import megengine.module as M
 from megengine import jit, tensor
 from megengine.core.tensor import Tensor
+from megengine.jit import SublinearMemConfig
 from megengine.test import assertTensorClose
 
 
@@ -185,3 +186,14 @@ def test_dump_bn_fused():
         mgb.cgtools.get_type(inputs[0]) == "MultipleDeviceTensorHolder"
         and mgb.cgtools.get_type(inputs[1]) == "ConvolutionForward"
     )
+
+
+# Simply verify that the options are passed down
+def test_sublinear():
+    config = SublinearMemConfig(genetic_nr_iter=10)
+
+    @jit.trace(symbolic=True, enable_sublinear=True, sublinear_mem_config=config)
+    def f(x):
+        return x + x
+
+    f([0.0])
diff --git a/src/core/impl/graph/cg_impl.cpp b/src/core/impl/graph/cg_impl.cpp
index 59e48417945aa8cb515aff4e88663285e66e24e6..97aaaebbbd98698d420e3b240523d77e1e07ba12 100644
--- a/src/core/impl/graph/cg_impl.cpp
+++ b/src/core/impl/graph/cg_impl.cpp
@@ -217,7 +217,8 @@ ComputingGraphImpl::Components::Components(ComputingGraphImpl* owner)
           static_infer_comp_seq_manager{owner},
           grad_manager{owner},
 #if MGB_ENABLE_SUBLINEAR
-          seq_modifier_for_sublinear_memory{owner},
+          seq_modifier_for_sublinear_memory{owner,
+                  &(owner->options().sublinear_mem_config)},
 #endif
 #if MGB_ENABLE_MEMORY_SWAP
           memory_swap_support{owner},
diff --git a/src/core/impl/graph/seq_sublinear_memory.cpp b/src/core/impl/graph/seq_sublinear_memory.cpp
index 4f4553531b0db15fb1acb6cf25712ff17b08ec08..16fc4d2f1866f58f4f3b0a57530b4e27d2cd8b16 100644
--- a/src/core/impl/graph/seq_sublinear_memory.cpp
+++ b/src/core/impl/graph/seq_sublinear_memory.cpp
@@ -681,14 +681,6 @@ class SeqModifierForSublinearMemory::ActionSearcherSingleCN {
     std::vector> m_futures;
     std::mutex m_mtx;
 
-    struct Config {
-        size_t thresh_nr_try = 10;
-        size_t genetic_nr_iter = 0;
-        size_t genetic_pool_size = 20;
-        double lb_memory = 0;
-    };
-    Config m_config;
-
     /*!
      * \brief check given thresh, and update states
      * \return bottleneck value for given thresh
@@ -725,20 +717,24 @@ class SeqModifierForSublinearMemory::ActionSearcherSingleCN {
 public:
     ActionSearcherSingleCN(SeqModifierForSublinearMemory* par)
             : m_par_modifier{par} {
+        auto& config = m_par_modifier->m_config;
+
         //! allow environmental variable to overwrite the setting
         if (auto env = MGB_GETENV("MGB_SUBLINEAR_MEMORY_THRESH_NR_TRY")) {
-            m_config.thresh_nr_try = std::stoi(env);
+            config->thresh_nr_try = std::stoi(env);
         }
         if (auto env = MGB_GETENV("MGB_SUBLINEAR_MEMORY_GENETIC_NR_ITER")) {
-            m_config.genetic_nr_iter = std::stoi(env);
+            config->genetic_nr_iter = std::stoi(env);
         }
        if (auto env = MGB_GETENV("MGB_SUBLINEAR_MEMORY_GENETIC_POOL_SIZE")) {
             auto psize = static_cast<size_t>(std::stoi(env));
-            mgb_assert(psize > 0 || m_config.genetic_nr_iter == 0,
+            mgb_assert(psize > 0 || config->genetic_nr_iter == 0,
                        "invalid pool size %zu in genetic algorithm,", psize);
-            m_config.genetic_pool_size = psize;
+            config->genetic_pool_size = psize;
         }
         if (auto env = MGB_GETENV("MGB_SUBLINEAR_MEMORY_LOWER_BOUND_MB")) {
-            m_config.lb_memory = std::stod(env) * 1024 * 1024;
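+            // lb_memory is kept in MB to match the env var; it is converted
+            // to bytes where it is used (see search_refine)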
+            config->lb_memory = std::stoi(env);
         }
     }
@@ -812,7 +806,7 @@ void SeqModifierForSublinearMemory::ActionSearcherSingleCN::search_preset() {
         invoke_search(thresh);
     }
 
-    size_t NR_TRY = m_config.thresh_nr_try;
+    size_t NR_TRY = m_par_modifier->m_config->thresh_nr_try;
 
     // search in linear space
     auto step = init_thresh / (NR_TRY + 1);
@@ -833,8 +827,8 @@
 void SeqModifierForSublinearMemory::ActionSearcherSingleCN::search_genetic() {
     RNGxorshf rng(2333);
 
-    size_t POOL_SIZE = m_config.genetic_pool_size;
-    size_t NR_ITER = m_config.genetic_nr_iter;
+    size_t POOL_SIZE = m_par_modifier->m_config->genetic_pool_size;
+    size_t NR_ITER = m_par_modifier->m_config->genetic_nr_iter;
     auto mutation = [&](const SplitPointSet& sps) {
         auto s = *sps;
         size_t length = s.size();
@@ -953,7 +947,8 @@
 }
 
 void SeqModifierForSublinearMemory::ActionSearcherSingleCN::search_refine() {
-    size_t lower_bound = m_config.lb_memory;
+    size_t lower_bound =
+            static_cast<size_t>(m_par_modifier->m_config->lb_memory) * 1024 * 1024;
     if (m_min_bottleneck >= lower_bound)
         return;
     OprFootprint footprint;
@@ -1052,7 +1046,7 @@
     msg.push_back('\n');
     msg.append(ssprintf("m_min_bottleneck: %-10.2f\n",
                         m_min_bottleneck * SIZE2MB));
-    if(!m_config.genetic_nr_iter) {
+    if (!m_par_modifier->m_config->genetic_nr_iter) {
         msg.append(ssprintf(
                 "\nGenetic algorithm is currently DISABLED, "
                 "set MGB_SUBLINEAR_MEMORY_GENETIC_NR_ITER [default = 0]"
@@ -1124,7 +1118,7 @@ SeqModifierForSublinearMemory::search_action(
                    "invalid planner concurrency: %zu", set);
         planner_concur = set;
     } else {
-        planner_concur = sys::get_cpu_count() / 2;
+        planner_concur = m_config->num_worker;
     }
 
     mgb_log_debug("use %zu threads to search for sublinear memory plan; "
@@ -1350,8 +1344,8 @@ SeqModifierForSublinearMemory::prev_min_bottleneck() {
 }
 
 SeqModifierForSublinearMemory::SeqModifierForSublinearMemory(
-        ComputingGraphImpl* owner)
-    : m_mem_opt(owner), m_owner_graph(owner) {}
+        ComputingGraphImpl* owner, Config* config_p)
+    : m_config(config_p), m_mem_opt(owner), m_owner_graph(owner) {}
 
 #endif  // !MGB_ENABLE_SUBLINEAR
 
diff --git a/src/core/impl/graph/seq_sublinear_memory.h b/src/core/impl/graph/seq_sublinear_memory.h
index 63c0cdd155cd4a5c91ee6c30c1a5b39018a8c143..4335ef6040728d1dee5d443d932726d3cec3656d 100644
--- a/src/core/impl/graph/seq_sublinear_memory.h
+++ b/src/core/impl/graph/seq_sublinear_memory.h
@@ -12,6 +12,7 @@
 #pragma once
 
 #include "./memory_optimizer.h"
+#include "megbrain/graph/cg.h"
 #include "megbrain/utils/async_worker.h"
 
 #if MGB_ENABLE_SUBLINEAR
@@ -31,6 +32,11 @@ class SeqModifierForSublinearMemory {
     using SeqModifyAction = std::unordered_map;
     using SplitPointSet = std::shared_ptr>;
 
+    //! config options, pointing into the owner graph's options();
+    //! set by the constructor and shared for the graph's lifetime
+    using Config = mgb::cg::ComputingGraph::Options::SublinearMemConfig;
+    Config* m_config;
+
     //! get modifications to be taken under some specific constraints
     class ModifyActionPlanner;
 
@@ -104,7 +109,7 @@ class SeqModifierForSublinearMemory {
     }
 
 public:
-    SeqModifierForSublinearMemory(ComputingGraphImpl* owner);
+    SeqModifierForSublinearMemory(ComputingGraphImpl* owner, Config* config_p);
 
     //! see memory_optimizer set_priority_before_opt
     void set_priority_before_opt(const VarNodeArray& endpoints) {
diff --git a/src/core/include/megbrain/graph/cg.h b/src/core/include/megbrain/graph/cg.h
index 6a17d9e50820656a6fc888031d90908c3840af0f..5283be523d4a6bc57cb4447717f4320c91180259 100644
--- a/src/core/include/megbrain/graph/cg.h
+++ b/src/core/include/megbrain/graph/cg.h
@@ -16,6 +16,7 @@
 #include "megbrain/graph/static_infer.h"
 #include "megbrain/graph/seq_comp_node_opt.h"
 #include "megbrain/utils/event.h"
+#include "megbrain/system.h"
 
 #if MGB_ENABLE_JSON
 #include "megbrain/utils/json.h"
@@ -300,6 +301,19 @@ class ComputingGraph : public std::enable_shared_from_this<ComputingGraph>,
             //! whether to enable sublinear memory optimization
             bool enable_sublinear_memory_opt = false;
 
+            //! Control parameters for sublinear memory optimization
+            struct SublinearMemConfig {
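+                //! these fields mirror megengine.jit.SublinearMemConfig
+                //! (sublinear_memory_config.py); lb_memory is in MB, and
+                //! num_worker bounds the planner concurrency when searching
+                //! for checkpoints (see SeqModifierForSublinearMemory)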
+                int thresh_nr_try = 10;
+                int genetic_nr_iter = 0;
+                int genetic_pool_size = 20;
+                int lb_memory = 0;
+                int num_worker = sys::get_cpu_count() / 2;
+            } sublinear_mem_config;
+
             //! do not re-profile to select best impl algo when input shape
             //! changes (use previous algo)
             bool no_profiling_on_shape_change = false;
diff --git a/src/core/test/sublinear_memory.cpp b/src/core/test/sublinear_memory.cpp
index d918e469c2016c06a75f6a583a487d4b733075ef..2e16bee3cae856e058895a1d1416d4bb43cf7e5d 100644
--- a/src/core/test/sublinear_memory.cpp
+++ b/src/core/test/sublinear_memory.cpp
@@ -504,57 +504,50 @@ TEST(TestSublinearMemory, DepsInTopoSort) {
 }
 
 TEST(TestSublinearMemory, BadOpr) {
-    constexpr const char* KEY = "MGB_SUBLINEAR_MEMORY_GENETIC_NR_ITER";
-    auto old_value = getenv(KEY);
-    setenv(KEY, "50", 1);
-    MGB_TRY {
-        HostTensorGenerator<> gen;
-        auto cn = CompNode::load("xpu0");
-        constexpr size_t N = 1024, Scale = 2;
-        auto host_x = gen({N}, cn);
-        for (bool bad : {false, true}) {
-            auto graph = ComputingGraph::make();
-            auto x = opr::Host2DeviceCopy::make_no_fwd(*graph, host_x),
-                 bad_var = SublinearBadOpr::make(x, bad, Scale),
-                 y0 = opr::reduce_sum(bad_var, x.make_scalar_dt(1)),
-                 y1 = SublinearBadOpr::make(y0, false, N * Scale),
-                 y = y1 + 1,
-                 z = opr::reduce_max(bad_var, x.make_scalar_dt(1));
-            set_priority(y0, 0);
-            set_priority(y1, 1);
-            set_priority(y, 2);
-            set_priority(z, 3);
-            graph->options().graph_opt_level = 0;
-            graph->options().enable_sublinear_memory_opt = 1;
-            auto func = graph->compile({{y, {}}, {z, {}}});
-            auto&& results = static_cast<cg::ComputingGraphImpl*>(graph.get())
-                    ->seq_modifier_for_sublinear_memory().prev_min_bottleneck();
-            // bottleneck:
-            //  if bad : y = y1 + 1, bad_var should be saved to calculate
-            //      z later, total memory usage is
-            //      N * sclae * 2(bad_var and y1) + 1 (immutable tensor 1)
-            //  else : bad_var = BadOpr(x), total memory usage is
-            //      N(x) + N * scale(bad_var), bad_var would be recomputed
-            //      when calculate z = reduce(bad_var)
-            size_t expect = bad ? N * Scale * 2 + 1 : N * Scale + N;
-            ASSERT_EQ(results.at(cn), expect * host_x->dtype().size());
-            size_t nr_bad_opr = 0;
-            auto count_up = [&nr_bad_opr](cg::OperatorNodeBase* op) {
-                if (op->dyn_typeinfo() == SublinearBadOpr::typeinfo()) {
-                    ++ nr_bad_opr;
-                }
-                return true;
-            };
-            func->iter_opr_seq(count_up);
-            ASSERT_EQ(nr_bad_opr, bad ? 2 : 3);
-        }
-    } MGB_FINALLY(
-        if (old_value) {
-            setenv(KEY, old_value, 1);
-        } else {
-            unsetenv(KEY);
-        }
-    );
+    HostTensorGenerator<> gen;
+    auto cn = CompNode::load("xpu0");
+    constexpr size_t N = 1024, Scale = 2;
+    auto host_x = gen({N}, cn);
+    for (bool bad : {false, true}) {
+        auto graph = ComputingGraph::make();
+        auto x = opr::Host2DeviceCopy::make_no_fwd(*graph, host_x),
+             bad_var = SublinearBadOpr::make(x, bad, Scale),
+             y0 = opr::reduce_sum(bad_var, x.make_scalar_dt(1)),
+             y1 = SublinearBadOpr::make(y0, false, N * Scale),
+             y = y1 + 1,
+             z = opr::reduce_max(bad_var, x.make_scalar_dt(1));
+        set_priority(y0, 0);
+        set_priority(y1, 1);
+        set_priority(y, 2);
+        set_priority(z, 3);
+        graph->options().graph_opt_level = 0;
+        graph->options().enable_sublinear_memory_opt = 1;
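+        // enable the genetic search via graph options rather than the
+        // MGB_SUBLINEAR_MEMORY_GENETIC_NR_ITER env var, so the test no
+        // longer needs setenv()/MGB_FINALLY cleanup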
+        graph->options().sublinear_mem_config.genetic_nr_iter = 50;
+        auto func = graph->compile({{y, {}}, {z, {}}});
+        auto&& results = static_cast<cg::ComputingGraphImpl*>(graph.get())
+                ->seq_modifier_for_sublinear_memory().prev_min_bottleneck();
+        // bottleneck:
+        //  if bad : y = y1 + 1, bad_var should be saved to calculate
+        //      z later, total memory usage is
+        //      N * Scale * 2 (bad_var and y1) + 1 (immutable tensor 1)
+        //  else : bad_var = BadOpr(x), total memory usage is
+        //      N(x) + N * Scale (bad_var), bad_var would be recomputed
+        //      when calculating z = reduce(bad_var)
+        size_t expect = bad ? N * Scale * 2 + 1 : N * Scale + N;
+        ASSERT_EQ(results.at(cn), expect * host_x->dtype().size());
+        size_t nr_bad_opr = 0;
+        auto count_up = [&nr_bad_opr](cg::OperatorNodeBase* op) {
+            if (op->dyn_typeinfo() == SublinearBadOpr::typeinfo()) {
+                ++nr_bad_opr;
+            }
+            return true;
+        };
+        func->iter_opr_seq(count_up);
+        ASSERT_EQ(nr_bad_opr, bad ? 2 : 3);
+    }
 }
 
 #else