Commit 396653df authored by Megvii Engine Team, committed by Xinran Xu

feat(mge/api): expose sublinear related parameters at mge api level

GitOrigin-RevId: 7a47f5d0d5a4d402b9666df587f962fe7276a9cd
Parent: c985204b
@@ -18,6 +18,7 @@ import megengine._internal as mgb
 from megengine._internal.plugin import CompGraphProfiler
 from ..core import Tensor, graph, tensor
+from .sublinear_memory_config import SublinearMemConfig


 def sideeffect(f):
@@ -78,10 +79,12 @@ class trace:
     * accelerated evaluation via :meth:`.__call__`

     :param func: Positional only argument.
-    :param symbolic: Whether to use symbolic tensor.
+    :param symbolic: Whether to use symbolic tensor. Default: False
     :param opt_level: Optimization level for compiling trace.
     :param log_level: Log level.
-    :param profiling: Whether to profile compiled trace.
+    :param enable_sublinear: Enable sublinear memory optimization. Default: False
+    :param sublinear_mem_config: Configuration for sublinear memory optimization.
+    :param profiling: Whether to profile compiled trace. Default: False
     """

     _active_instance = None
@@ -103,12 +106,16 @@ class trace:
         symbolic: bool = False,
         opt_level: int = None,
         log_level: int = None,
+        enable_sublinear: bool = False,
+        sublinear_mem_config: SublinearMemConfig = None,
         profiling: bool = False
     ):
         self.__wrapped__ = func
         self._symbolic = symbolic
         self._graph_opt_level = opt_level
         self._log_level = log_level
+        self._enable_sublinear = enable_sublinear
+        self._sublinear_mem_config = sublinear_mem_config
         self._status = self._UNSTARTED
         self._args = None
         self._kwargs = None
@@ -280,11 +287,35 @@ class trace:
     def _apply_graph_options(self, cg):
         # graph opt level
-        if not self._graph_opt_level is None:
+        if not (self._graph_opt_level is None):
             cg.set_option("graph_opt_level", self._graph_opt_level)
         # log level
-        if not self._log_level is None:
+        if not (self._log_level is None):
             cg.set_option("log_level", self._log_level)
+        # sublinear
+        if self._enable_sublinear:
+            cg.set_option("enable_sublinear_memory_opt", True)
+            if not (self._sublinear_mem_config is None):
+                cg.set_option(
+                    "sublinear_mem_cofig.lb_memory",
+                    self._sublinear_mem_config.lb_memory,
+                )
+                cg.set_option(
+                    "sublinear_mem_cofig.genetic_nr_iter",
+                    self._sublinear_mem_config.genetic_nr_iter,
+                )
+                cg.set_option(
+                    "sublinear_mem_cofig.genetic_pool_size",
+                    self._sublinear_mem_config.genetic_pool_size,
+                )
+                cg.set_option(
+                    "sublinear_mem_cofig.thresh_nr_try",
+                    self._sublinear_mem_config.thresh_nr_try,
+                )
+                cg.set_option(
+                    "sublinear_mem_cofig.num_worker",
+                    self._sublinear_mem_config.num_worker,
+                )
         # profile
         if self._profiling:
             self._profiler = CompGraphProfiler(cg)
......
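Taken together, the new arguments let a user opt in to sublinear memory from the decorator alone. A minimal usage sketch based on the API added in this commit (the traced function body is a placeholder; `genetic_nr_iter=10` mirrors the value the tests below use):

    from megengine import jit
    from megengine.jit import SublinearMemConfig

    config = SublinearMemConfig(genetic_nr_iter=10)

    @jit.trace(symbolic=True, enable_sublinear=True, sublinear_mem_config=config)
    def train_step(data, label):
        ...  # forward, loss and backward as usual
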
# -*- coding: utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from ..core.device import get_device_count


class SublinearMemConfig:
    r"""
    Configuration for sublinear memory optimization.

    :param thresh_nr_try: number of samples, both for searching in linear space
        and around the current thresh, in sublinear memory optimization. Default: 10.
        It can also be set through the environment variable 'MGB_SUBLINEAR_MEMORY_THRESH_NR_TRY'.
    :param genetic_nr_iter: number of iterations for finding the best checkpoints with the
        genetic algorithm. Default: 0.
        It can also be set through the environment variable 'MGB_SUBLINEAR_MEMORY_GENETIC_NR_ITER'.
    :param genetic_pool_size: number of samples for the crossover random selection
        during genetic optimization. Default: 20.
        It can also be set through the environment variable 'MGB_SUBLINEAR_MEMORY_GENETIC_POOL_SIZE'.
    :param lb_memory: lower bound, in MB, of the bottleneck memory size in sublinear
        memory optimization. It can be used to trade off memory against speed manually. Default: 0.
        It can also be set through the environment variable 'MGB_SUBLINEAR_MEMORY_LOWER_BOUND_MB'.
    :param num_worker: number of worker threads used to search for the optimal checkpoints
        in sublinear memory optimization. Default: half the number of CPUs in the system.
        It can also be set through the environment variable 'MGB_SUBLINEAR_MEMORY_WORKERS'.
    """
    def __init__(
        self,
        thresh_nr_try: int = 10,
        genetic_nr_iter: int = 0,
        genetic_pool_size: int = 20,
        lb_memory: int = 0,
        num_worker: int = get_device_count("cpu") // 2,
    ):
        self.thresh_nr_try = thresh_nr_try
        self.genetic_nr_iter = genetic_nr_iter
        self.genetic_pool_size = genetic_pool_size
        self.lb_memory = lb_memory
        self.num_worker = num_worker
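Every field of `SublinearMemConfig` has an environment-variable counterpart, and the C++ searcher further down reads the environment after the config has been installed, so the variable wins when both are set. A sketch of the two equivalent routes for raising the genetic iteration count, assuming the environment is set before the traced function is first compiled:

    import os

    from megengine.jit import SublinearMemConfig

    # route 1: the config object introduced by this commit
    config = SublinearMemConfig(genetic_nr_iter=50)

    # route 2: the environment variable, which overwrites the config value
    os.environ["MGB_SUBLINEAR_MEMORY_GENETIC_NR_ITER"] = "50"
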
@@ -42,7 +42,8 @@ bool _config::set_comp_graph_option(
                 std::is_same<decltype(opt.name_chk), bool>::value ||     \
                 std::is_same<decltype(opt.name_chk), uint8_t>::value ||  \
                 std::is_same<decltype(opt.name_chk), int16_t>::value ||  \
-                std::is_same<decltype(opt.name_chk), uint16_t>::value,   \
+                std::is_same<decltype(opt.name_chk), uint16_t>::value || \
+                std::is_same<decltype(opt.name_chk), int32_t>::value,    \
                 "not bool/int opt");                                     \
         if (name == #name_chk) {                                         \
             auto ret = opt.name_chk;                                     \
@@ -66,6 +67,11 @@ bool _config::set_comp_graph_option(
     SET_CG_OPTION(allocate_static_mem_after_graph_compile);
     SET_CG_OPTION(log_level);
     SET_CG_OPTION(enable_sublinear_memory_opt);
+    SET_CG_OPTION(sublinear_mem_cofig.lb_memory);
+    SET_CG_OPTION(sublinear_mem_cofig.genetic_nr_iter);
+    SET_CG_OPTION(sublinear_mem_cofig.genetic_pool_size);
+    SET_CG_OPTION(sublinear_mem_cofig.thresh_nr_try);
+    SET_CG_OPTION(sublinear_mem_cofig.num_worker);
     SET_CG_OPTION(enable_var_mem_defragment);
     SET_CG_OPTION(eager_evaluation);
     SET_CG_OPTION(enable_memory_swap);
......
@@ -17,6 +17,7 @@ import megengine as mge
 import megengine.functional as F
 from megengine import jit, tensor
 from megengine.functional.debug_param import set_conv_execution_strategy
+from megengine.jit import SublinearMemConfig
 from megengine.module import AvgPool2d, BatchNorm2d, Conv2d, Linear, Module
 from megengine.optimizer import SGD
 from megengine.test import assertTensorClose
@@ -130,7 +131,14 @@ def update_model(model_path):
     mge.save(checkpoint, model_path)


-def run_test(model_path, use_jit, use_symbolic):
+def run_test(
+    model_path,
+    use_jit,
+    use_symbolic,
+    enable_sublinear=False,
+    sublinear_mem_config=None,
+    max_err=None,
+):
     """
     Load the model with test cases and run the training for one iter.
@@ -152,11 +160,17 @@ def run_test(model_path, use_jit, use_symbolic):
     data.set_value(checkpoint["data"])
     label.set_value(checkpoint["label"])

-    max_err = 1e-5
+    if max_err is None:
+        max_err = 1e-5

     train_func = train
     if use_jit:
-        train_func = jit.trace(train_func, symbolic=use_symbolic)
+        train_func = jit.trace(
+            train_func,
+            symbolic=use_symbolic,
+            enable_sublinear=enable_sublinear,
+            sublinear_mem_config=sublinear_mem_config,
+        )

     opt.zero_grad()
     loss = train_func(data, label, net=net, opt=opt)
@@ -183,3 +197,14 @@ def test_correctness():
     run_test(model_path, False, False)
     run_test(model_path, True, False)
     run_test(model_path, True, True)
+
+    # sublinear
+    config = SublinearMemConfig(genetic_nr_iter=10)
+    run_test(
+        model_path,
+        True,
+        True,
+        enable_sublinear=True,
+        sublinear_mem_config=config,
+        max_err=1e-5,
+    )
@@ -18,6 +18,7 @@ import megengine._internal as mgb
 import megengine.module as M
 from megengine import jit, tensor
 from megengine.core.tensor import Tensor
+from megengine.jit import SublinearMemConfig
 from megengine.test import assertTensorClose
@@ -185,3 +186,14 @@ def test_dump_bn_fused():
         mgb.cgtools.get_type(inputs[0]) == "MultipleDeviceTensorHolder"
         and mgb.cgtools.get_type(inputs[1]) == "ConvolutionForward"
     )
+
+
+# Simply verify the options passed down
+def test_sublinear():
+    config = SublinearMemConfig(genetic_nr_iter=10)
+
+    @jit.trace(symbolic=True, enable_sublinear=True, sublinear_mem_config=config)
+    def f(x):
+        return x + x
+
+    f([0.0])
@@ -217,7 +217,8 @@ ComputingGraphImpl::Components::Components(ComputingGraphImpl* owner)
           static_infer_comp_seq_manager{owner},
           grad_manager{owner},
 #if MGB_ENABLE_SUBLINEAR
-          seq_modifier_for_sublinear_memory{owner},
+          seq_modifier_for_sublinear_memory{owner,
+                                            &(owner->options().sublinear_mem_cofig)},
 #endif
 #if MGB_ENABLE_MEMORY_SWAP
           memory_swap_support{owner},
......
@@ -681,14 +681,6 @@ class SeqModifierForSublinearMemory::ActionSearcherSingleCN {
     std::vector<std::future<void>> m_futures;
     std::mutex m_mtx;

-    struct Config {
-        size_t thresh_nr_try = 10;
-        size_t genetic_nr_iter = 0;
-        size_t genetic_pool_size = 20;
-        double lb_memory = 0;
-    };
-    Config m_config;
-
     /*!
      * \brief check given thresh, and update states
      * \return bottleneck value for given thresh
@@ -725,20 +717,22 @@ class SeqModifierForSublinearMemory::ActionSearcherSingleCN {
 public:
     ActionSearcherSingleCN(SeqModifierForSublinearMemory* par)
             : m_par_modifier{par} {
+        auto& m_config = m_par_modifier->m_config;
+        //! allow environmental variable to overwrite the setting
         if (auto env = MGB_GETENV("MGB_SUBLINEAR_MEMORY_THRESH_NR_TRY")) {
-            m_config.thresh_nr_try = std::stoi(env);
+            m_config->thresh_nr_try = std::stoi(env);
         }
         if (auto env = MGB_GETENV("MGB_SUBLINEAR_MEMORY_GENETIC_NR_ITER")) {
-            m_config.genetic_nr_iter = std::stoi(env);
+            m_config->genetic_nr_iter = std::stoi(env);
         }
         if (auto env = MGB_GETENV("MGB_SUBLINEAR_MEMORY_GENETIC_POOL_SIZE")) {
             auto psize = static_cast<size_t>(std::stoi(env));
-            mgb_assert(psize > 0 || m_config.genetic_nr_iter == 0,
+            mgb_assert(psize > 0 || m_config->genetic_nr_iter == 0,
                        "invalid pool size %zu in genetic algorithm,", psize);
-            m_config.genetic_pool_size = psize;
+            m_config->genetic_pool_size = psize;
         }
         if (auto env = MGB_GETENV("MGB_SUBLINEAR_MEMORY_LOWER_BOUND_MB")) {
-            m_config.lb_memory = std::stod(env) * 1024 * 1024;
+            m_config->lb_memory = std::stoi(env) * 1024 * 1024;
         }
     }
@@ -812,7 +806,7 @@ void SeqModifierForSublinearMemory::ActionSearcherSingleCN::search_preset() {
         invoke_search(thresh);
     }

-    size_t NR_TRY = m_config.thresh_nr_try;
+    size_t NR_TRY = m_par_modifier->m_config->thresh_nr_try;
     // search in linear space
     auto step = init_thresh / (NR_TRY + 1);
@@ -833,8 +827,8 @@ void SeqModifierForSublinearMemory::ActionSearcherSingleCN::search_preset() {
 void SeqModifierForSublinearMemory::ActionSearcherSingleCN::search_genetic() {
     RNGxorshf rng(2333);
-    size_t POOL_SIZE = m_config.genetic_pool_size;
-    size_t NR_ITER = m_config.genetic_nr_iter;
+    size_t POOL_SIZE = m_par_modifier->m_config->genetic_pool_size;
+    size_t NR_ITER = m_par_modifier->m_config->genetic_nr_iter;
     auto mutation = [&](const SplitPointSet& sps) {
         auto s = *sps;
         size_t length = s.size();
@@ -953,7 +947,7 @@ void SeqModifierForSublinearMemory::ActionSearcherSingleCN::search_genetic() {
 }

 void SeqModifierForSublinearMemory::ActionSearcherSingleCN::search_refine() {
-    size_t lower_bound = m_config.lb_memory;
+    size_t lower_bound = m_par_modifier->m_config->lb_memory;
     if (m_min_bottleneck >= lower_bound)
         return;
     OprFootprint footprint;
@@ -1052,7 +1046,7 @@ SeqModifierForSublinearMemory::ActionSearcherSingleCN::search(
     msg.push_back('\n');
     msg.append(ssprintf("m_min_bottleneck: %-10.2f\n",
                         m_min_bottleneck * SIZE2MB));
-    if(!m_config.genetic_nr_iter) {
+    if(!m_par_modifier->m_config->genetic_nr_iter) {
         msg.append(ssprintf(
                 "\nGenetic algorithm is currently DISABLED, "
                 "set MGB_SUBLINEAR_MEMORY_GENETIC_NR_ITER [default = 0]"
@@ -1124,7 +1118,7 @@ SeqModifierForSublinearMemory::search_action(
                    "invalid planner concurrency: %zu", set);
         planner_concur = set;
     } else {
-        planner_concur = sys::get_cpu_count() / 2;
+        planner_concur = m_config->num_worker;
     }
     mgb_log_debug("use %zu threads to search for sublinear memory plan; "
@@ -1350,8 +1344,8 @@ SeqModifierForSublinearMemory::prev_min_bottleneck() {
 }

 SeqModifierForSublinearMemory::SeqModifierForSublinearMemory(
-        ComputingGraphImpl* owner)
-        : m_mem_opt(owner), m_owner_graph(owner) {}
+        ComputingGraphImpl* owner, Config* config_p)
+        : m_config(config_p), m_mem_opt(owner), m_owner_graph(owner) {}

 #endif  // !MGB_ENABLE_SUBLINEAR
......
@@ -12,6 +12,7 @@
 #pragma once

 #include "./memory_optimizer.h"
+#include "megbrain/graph/cg.h"
 #include "megbrain/utils/async_worker.h"

 #if MGB_ENABLE_SUBLINEAR
@@ -31,6 +32,10 @@ class SeqModifierForSublinearMemory {
     using SeqModifyAction = std::unordered_map<OperatorNodeBase*, OprNodeArray>;
     using SplitPointSet = std::shared_ptr<std::vector<size_t>>;

+    //! Config options
+    using Config = mgb::cg::ComputingGraph::Options::SublinearMemConfig;
+    Config* m_config;
+
     //! get modifications to be taken under some specific constraints
     class ModifyActionPlanner;
@@ -104,7 +109,7 @@ class SeqModifierForSublinearMemory {
     }

 public:
-    SeqModifierForSublinearMemory(ComputingGraphImpl* owner);
+    SeqModifierForSublinearMemory(ComputingGraphImpl* owner, Config* config_g);

     //! see memory_optimizer set_priority_before_opt
     void set_priority_before_opt(const VarNodeArray& endpoints) {
......
@@ -16,6 +16,7 @@
 #include "megbrain/graph/static_infer.h"
 #include "megbrain/graph/seq_comp_node_opt.h"
 #include "megbrain/utils/event.h"
+#include "megbrain/system.h"

 #if MGB_ENABLE_JSON
 #include "megbrain/utils/json.h"
@@ -300,6 +301,15 @@ class ComputingGraph : public std::enable_shared_from_this<ComputingGraph>,
             //! whether to enable sublinear memory optimization
             bool enable_sublinear_memory_opt = false;

+            //! Control parameter for sublinear memory optimization
+            struct SublinearMemConfig {
+                int thresh_nr_try = 10;
+                int genetic_nr_iter = 0;
+                int genetic_pool_size = 20;
+                int lb_memory = 0;
+                int num_worker = sys::get_cpu_count() / 2;
+            } sublinear_mem_cofig;
+
             //! do not re-profile to select best impl algo when input shape
             //! changes (use previous algo)
             bool no_profiling_on_shape_change = false;
......
@@ -504,10 +504,6 @@ TEST(TestSublinearMemory, DepsInTopoSort) {
 }

 TEST(TestSublinearMemory, BadOpr) {
-    constexpr const char* KEY = "MGB_SUBLINEAR_MEMORY_GENETIC_NR_ITER";
-    auto old_value = getenv(KEY);
-    setenv(KEY, "50", 1);
-    MGB_TRY {
     HostTensorGenerator<> gen;
     auto cn = CompNode::load("xpu0");
     constexpr size_t N = 1024, Scale = 2;
@@ -526,6 +522,7 @@ TEST(TestSublinearMemory, BadOpr) {
         set_priority(z, 3);
         graph->options().graph_opt_level = 0;
         graph->options().enable_sublinear_memory_opt = 1;
+        graph->options().sublinear_mem_cofig.genetic_nr_iter = 50;
         auto func = graph->compile({{y, {}}, {z, {}}});
         auto&& results = static_cast<cg::ComputingGraphImpl*>(graph.get())
                 ->seq_modifier_for_sublinear_memory().prev_min_bottleneck();
@@ -548,13 +545,6 @@ TEST(TestSublinearMemory, BadOpr) {
             func->iter_opr_seq(count_up);
             ASSERT_EQ(nr_bad_opr, bad ? 2 : 3);
         }
-    } MGB_FINALLY(
-        if (old_value) {
-            setenv(KEY, old_value, 1);
-        } else {
-            unsetenv(KEY);
-        }
-    );
 }

 #else
......