From d86ed426ee2def8f41697844bceb0727096c869c Mon Sep 17 00:00:00 2001
From: Megvii Engine Team
Date: Mon, 16 Aug 2021 18:00:53 +0800
Subject: [PATCH] fix(dtr): simulate the system stack to avoid stack overflow
 during recomputing

GitOrigin-RevId: cb73e62b19588c870f7ac5479fad2ec0e9a41d97
---
 .../python/test/integration/test_dtr.py          | 114 ++++++++++++++++++
 .../src/impl/interpreter/interpreter_impl.cpp    |  66 ++++++----
 .../src/impl/interpreter/interpreter_impl.h      |   7 +-
 3 files changed, 161 insertions(+), 26 deletions(-)
 create mode 100644 imperative/python/test/integration/test_dtr.py

diff --git a/imperative/python/test/integration/test_dtr.py b/imperative/python/test/integration/test_dtr.py
new file mode 100644
index 000000000..25ad054f4
--- /dev/null
+++ b/imperative/python/test/integration/test_dtr.py
@@ -0,0 +1,114 @@
+import numpy as np
+import pytest
+
+import megengine as mge
+import megengine.functional as F
+import megengine.module as M
+import megengine.optimizer as optim
+import megengine.tensor as tensor
+from megengine.autodiff import GradManager
+from megengine.data import DataLoader, RandomSampler, transform
+from megengine.data.dataset import CIFAR10
+
+
+def _weights_init(m):
+    classname = m.__class__.__name__
+    if isinstance(m, M.Linear) or isinstance(m, M.Conv2d):
+        M.init.msra_normal_(m.weight)
+
+
+mean = [125.3, 123.0, 113.9]
+std = [63.0, 62.1, 66.7]
+
+
+class BasicBlock(M.Module):
+    expansion = 1
+
+    def __init__(self, in_planes, planes, stride=1):
+        super(BasicBlock, self).__init__()
+        self.conv1 = M.Conv2d(
+            in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False
+        )
+        self.bn1 = M.BatchNorm2d(planes)
+        self.conv2 = M.Conv2d(
+            planes, planes, kernel_size=3, stride=1, padding=1, bias=False
+        )
+        self.bn2 = M.BatchNorm2d(planes)
+        self.shortcut = M.Sequential()
+        if stride != 1 or in_planes != planes:
+            self.shortcut = M.Sequential(
+                M.Conv2d(
+                    in_planes,
+                    self.expansion * planes,
+                    kernel_size=1,
+                    stride=stride,
+                    bias=False,
+                ),
+                M.BatchNorm2d(self.expansion * planes),
+            )
+
+    def forward(self, x):
+        out = F.relu(self.bn1(self.conv1(x)))
+        out = self.bn2(self.conv2(out))
+        out += self.shortcut(x)
+        out = F.relu(out)
+        return out
+
+
+class ResNet(M.Module):
+    def __init__(self, block, num_blocks, num_classes=10):
+        super(ResNet, self).__init__()
+        self.in_planes = 16
+
+        self.conv1 = M.Conv2d(3, 16, kernel_size=3, stride=1, padding=1, bias=False)
+        self.bn1 = M.BatchNorm2d(16)
+        self.layer1 = self._make_layer(block, 16, num_blocks[0], stride=1)
+        self.layer2 = self._make_layer(block, 32, num_blocks[1], stride=2)
+        self.layer3 = self._make_layer(block, 64, num_blocks[2], stride=2)
+        self.linear = M.Linear(64, num_classes)
+
+        self.apply(_weights_init)
+
+    def _make_layer(self, block, planes, num_blocks, stride):
+        strides = [stride] + [1] * (num_blocks - 1)
+        layers = []
+        for stride in strides:
+            layers.append(block(self.in_planes, planes, stride))
+            self.in_planes = planes * block.expansion
+
+        return M.Sequential(*layers)
+
+    def forward(self, x):
+        out = F.relu(self.bn1(self.conv1(x)))
+        out = self.layer1(out)
+        out = self.layer2(out)
+        out = self.layer3(out)
+        out = out.mean(3).mean(2)
+        out = self.linear(out)
+        return out
+
+
+@pytest.mark.require_ngpu(1)
+def test_dtr_resnet1202():
+    batch_size = 64
+    resnet1202 = ResNet(BasicBlock, [200, 200, 200])
+    opt = optim.SGD(resnet1202.parameters(), lr=0.05, momentum=0.9, weight_decay=1e-4)
+    gm = GradManager().attach(resnet1202.parameters())
+
+    def train_func(data, label, *, net, gm):
+        net.train()
+        with gm:
+            pred = net(data)
+            loss = F.loss.cross_entropy(pred, label)
+            gm.backward(loss)
+        return pred, loss
+
+    mge.dtr.enable()
+
+    data = np.random.randn(batch_size, 3, 32, 32).astype("float32")
+    label = np.random.randint(0, 10, size=(batch_size,)).astype("int32")
+    for step in range(10):
+        opt.clear_grad()
+        _, loss = train_func(mge.tensor(data), mge.tensor(label), net=resnet1202, gm=gm)
+        opt.step()
+        loss.item()
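Note (editorial commentary, not part of the patch): the test above is the regression check for the overflow named in the subject line. ResNet-1202 stacks 3 x 200 = 600 BasicBlocks, so once mge.dtr.enable() turns on dynamic tensor rematerialization, regenerating an evicted activation near the loss can require replaying a producer chain thousands of operators long. Before this patch that replay ran as mutual recursion (regenerate -> recompute -> do_apply_op -> regenerate -> ...), so chain length became native stack depth on the interpreter's worker thread; the interpreter changes below turn that depth into explicit heap state.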
diff --git a/imperative/src/impl/interpreter/interpreter_impl.cpp b/imperative/src/impl/interpreter/interpreter_impl.cpp
index 55f401fa6..d57430fed 100644
--- a/imperative/src/impl/interpreter/interpreter_impl.cpp
+++ b/imperative/src/impl/interpreter/interpreter_impl.cpp
@@ -615,13 +615,15 @@ void ChannelImpl::release_tensor(TensorInfo* dest) {
 }
 
 void ChannelImpl::regenerate(TensorInfo* dest) {
-    RECORD_EVENT(TensorCommandEvent, dest->id, TensorCommandEvent::ReGen);
     if (dest->evict_type == EvictType::DROP) {
-        recompute(dest->producer);
+        auto &&path = dest->producer;
+        m_apply_stack.push({ApplyOp{path->id, path->op, path->inputs, path->outputs, {}}, 0, dest});
+        if (!m_applying) flush_apply_stack();
     } else if (dest->evict_type == EvictType::SWAP) {
+        RECORD_EVENT(TensorCommandEvent, dest->id, TensorCommandEvent::ReGen);
         produce_tensor(dest, Tensor::make(dest->h_value));
+        RECORD_EVENT(TensorCommandFinishEvent, dest->id, TensorCommandFinishEvent::ReGen);
     }
-    RECORD_EVENT(TensorCommandFinishEvent, dest->id, TensorCommandFinishEvent::ReGen);
 }
 
 void ChannelImpl::do_apply_op(const ApplyOp& cmd) {
@@ -635,17 +637,6 @@ void ChannelImpl::do_apply_op(const ApplyOp& cmd) {
         MemoryDesc desc;
     };
     SmallVector<TensorWithDesc> inputs;
-    // SmallVector<TensorPtr> tensor_inputs;
-    if (state.options.enable_dtr_auto_drop) {
-        m_dtr.pin(cmd.inputs);
-    }
-    for (auto i : cmd.inputs) {
-        if (!i->ptr && i->evict_type != EvictType::NONE) {
-            regenerate(i);
-        }
-        m_dtr.update_used_time(i);
-    }
-    // tensor_inputs.reserve(cmd.inputs.size());
     inputs.reserve(cmd.inputs.size());
     // refcnt == 1, owners: [TensorInfo::ptr]
     for (auto i : cmd.inputs) {
@@ -781,20 +772,48 @@ void ChannelImpl::do_apply_op(const ApplyOp& cmd) {
     // End profiling operator
 }
 
-void ChannelImpl::recompute(TensorInfo::ComputePath* path) {
+void ChannelImpl::flush_apply_stack() {
+    m_applying = true;
     auto& state = get_worker_state();
-    do_apply_op(ApplyOp{path->id, path->op, path->inputs, path->outputs, {}});
-    for (size_t i = 0;i < path->outputs.size();i ++) {
-        auto&& o = path->outputs[i];
-        if (o) {
-            o->recompute_times ++;
-            if (!o->ptr) {
-                if (state.options.enable_dtr_auto_drop) {
+    while (!m_apply_stack.empty()) {
+        auto& [cmd, idx, recomp] = m_apply_stack.top(); // cmd.inputs[0~idx-1] is in memory
+        if (idx == 0) {
+            if (state.options.enable_dtr_auto_drop) {
+                m_dtr.pin(cmd.inputs);
+            }
+            if (recomp) {
+                RECORD_EVENT(TensorCommandEvent, recomp->id, TensorCommandEvent::ReGen);
+            }
+        }
+        bool regen = false;
+        for (size_t i = idx; i < cmd.inputs.size(); i ++) {
+            auto&& p = cmd.inputs[i];
+            if (state.options.enable_dtr_auto_drop) {
+                m_dtr.update_used_time(p);
+            }
+            if (!p->ptr && p->evict_type != EvictType::NONE) {
+                idx = i + 1;
+                regenerate(p); // add ApplyOp to the stack
+                regen = true;
+                break;
+            }
+        }
+        if (regen) continue;
+        // the required input tensors are already in memory
+        auto cmd_backup = cmd;
+        auto recomp_backup = recomp;
+        m_apply_stack.pop();
+        do_apply_op(cmd_backup);
+        if (recomp_backup) {
+            RECORD_EVENT(TensorCommandFinishEvent, recomp_backup->id, TensorCommandFinishEvent::ReGen);
+            for (auto o : cmd_backup.outputs) {
+                if (o) {
                     m_dtr.update_dsu_after_recompute(o);
                 }
             }
         }
     }
+    m_applying = false;
 }
 
 bool ChannelImpl::auto_evict(size_t force_num) {
@@ -997,7 +1016,8 @@ void ChannelImpl::process_one_task(IdentifiedCommand& icmd) {
             RECORD_EVENT(TensorCommandFinishEvent, cmd.dest->id, TensorCommandFinishEvent::Put);
             sample_on_device(cmd.dest->desc.comp_node, false);
         } else if constexpr (std::is_same_v<T, ApplyOp>) {
-            do_apply_op(cmd);
+            m_apply_stack.push({cmd, 0, nullptr});
+            flush_apply_stack();
             for (size_t i = 0; i < cmd.outputs.size(); ++i) {
                 auto output = cmd.outputs[i];
                 if (output == nullptr) {
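Note (editorial commentary, not part of the patch): flush_apply_stack is a textbook recursion-to-iteration rewrite. Each m_apply_stack entry is (cmd, idx, recomp): the operator to replay, the index of the next input to examine so the scan resumes where it was interrupted, and the tensor whose regeneration this entry serves, if any. When a missing input is found, the entry stays on the stack with idx advanced and its producer is pushed on top; when all inputs are present, the entry is popped and executed. Below is a minimal self-contained sketch of the same transformation on toy types; Node, regenerate_recursive, and regenerate_iterative are illustrative names, not MegEngine code.

#include <cstddef>
#include <cstdio>
#include <memory>
#include <stack>
#include <utility>
#include <vector>

// Toy stand-ins for MegEngine's TensorInfo/ApplyOp: a node is "in
// memory" once materialized; recomputing it needs all inputs in memory.
struct Node {
    std::vector<Node*> inputs;  // producers required for recomputation
    bool materialized = false;
};

// The shape of the old code path: recursion depth equals the length of
// the recompute chain, so a deep chain overflows the native stack.
void regenerate_recursive(Node* n) {
    for (Node* in : n->inputs)
        if (!in->materialized) regenerate_recursive(in);
    n->materialized = true;  // all inputs present: "recompute" the node
}

// The shape of the new code path: an explicit stack of
// (node, index of next input to examine), mirroring m_apply_stack's
// (ApplyOp, size_t, TensorInfo*) entries; depth lives on the heap.
void regenerate_iterative(Node* root) {
    std::stack<std::pair<Node*, std::size_t>> stk;
    stk.push({root, 0});
    while (!stk.empty()) {
        auto& [n, idx] = stk.top();
        bool pushed_input = false;
        for (std::size_t i = idx; i < n->inputs.size(); ++i) {
            if (!n->inputs[i]->materialized) {
                idx = i + 1;                  // resume after this input later
                stk.push({n->inputs[i], 0});  // materialize the input first
                pushed_input = true;
                break;
            }
        }
        if (pushed_input) continue;
        n->materialized = true;  // inputs ready: "recompute", then pop
        stk.pop();
    }
}

int main() {
    // A linear chain of one million unmaterialized nodes: far beyond any
    // native stack limit, but trivial for the explicit stack.
    const std::size_t n = 1000000;
    std::vector<std::unique_ptr<Node>> nodes;
    nodes.reserve(n);
    for (std::size_t i = 0; i < n; ++i) {
        nodes.push_back(std::make_unique<Node>());
        if (i > 0) nodes[i]->inputs.push_back(nodes[i - 1].get());
    }
    regenerate_iterative(nodes.back().get());  // the recursive version would crash here
    std::printf("root materialized: %d\n", (int)nodes.back()->materialized);
    return 0;
}

The resume index is what keeps the loop linear: each (node, input) pair is examined at most once, exactly as in the recursive version.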
diff --git a/imperative/src/impl/interpreter/interpreter_impl.h b/imperative/src/impl/interpreter/interpreter_impl.h
index 1396c753f..bea3cfeb7 100644
--- a/imperative/src/impl/interpreter/interpreter_impl.h
+++ b/imperative/src/impl/interpreter/interpreter_impl.h
@@ -14,10 +14,10 @@
 #include <deque>
 #include <future>
 #include <list>
+#include <stack>
 #include <thread>
 #include <unordered_set>
 #include <variant>
-
 #include "megbrain/comp_node.h"
 #include "megbrain/utils/mempool.h"
 #include "megbrain/imperative/interpreter.h"
@@ -103,8 +103,8 @@ private:
     void release_tensor(TensorInfo* dest);
 
     void regenerate(TensorInfo* dest);
-    void recompute(TensorInfo::ComputePath* path);
     void do_apply_op(const ApplyOp& cmd);
+    void flush_apply_stack();
 
     std::tuple<SmallVector<MemoryDesc>, SmallVector<TensorPtr>, SmallVector<TensorPtr>> init_output_and_workspace(
             const OpDef& def,
@@ -149,7 +149,8 @@
     std::exception_ptr m_worker_exc;
     std::function<void(std::string, std::string)> m_profile_dump_callback;
    size_t m_storage_id = 0;
-
+    std::stack<std::tuple<ApplyOp, size_t, TensorInfo*>> m_apply_stack;
+    bool m_applying = false;
     bool m_closed = false;
 
     struct WorkQueue : AsyncQueueSC<IdentifiedCommand, WorkQueue> {
-- 
GitLab
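Note (editorial commentary, not part of the patch): the m_applying flag guards against re-entry. regenerate can be reached both from process_one_task's ApplyOp branch, which pushes an entry and drains the stack, and from inside the drain loop itself; with the flag set, a nested call merely pushes onto m_apply_stack and lets the outermost flush_apply_stack loop pick the work up, so a nested drain, and with it fresh native recursion, is never started. Below is a minimal sketch of this enqueue-or-drain pattern with made-up types; Scheduler, request, and drain are illustrative names, not the engine's API.

#include <cstdio>
#include <stack>

// Minimal sketch of the guard behind m_applying / flush_apply_stack():
// nested requests only enqueue work; the single outermost drain() loop
// executes it. Illustrative types only.
struct Scheduler {
    std::stack<int> work;
    bool draining = false;  // plays the role of m_applying
    long executed = 0;

    void request(int job) {
        work.push(job);
        // mirrors regenerate(): if (!m_applying) flush_apply_stack();
        if (!draining) drain();
    }

    void drain() {
        draining = true;
        while (!work.empty()) {
            int job = work.top();
            work.pop();
            run(job);  // may call request() again; it will only enqueue
        }
        draining = false;
    }

    void run(int job) {
        ++executed;
        if (job > 0) request(job - 1);  // nested request: no nested drain
    }
};

int main() {
    Scheduler s;
    s.request(100000);  // native call depth stays constant throughout
    std::printf("jobs executed: %ld\n", s.executed);  // prints 100001
    return 0;
}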