Commit d86ed426, authored by Megvii Engine Team

fix(dtr): simulate the system stack to avoid stack overflow during recomputing

GitOrigin-RevId: cb73e62b19588c870f7ac5479fad2ec0e9a41d97
Parent c25125e3
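The problem this commit fixes: in the DTR (Dynamic Tensor Rematerialization) interpreter, `do_apply_op` called `regenerate` for every evicted input, `regenerate` called `recompute` on the producing op, and `recompute` called back into `do_apply_op`. Each dropped ancestor therefore cost one native stack frame, so on a deep network under aggressive dropping the recomputation chain could overflow the system stack. The fix simulates that recursion with an explicit `std::stack` of pending ops (`m_apply_stack`), each entry carrying a resume index, drained iteratively by the new `flush_apply_stack`. The following is a minimal self-contained sketch of the same recursion-to-explicit-stack transformation; `Node`, `in_memory`, and both function names are illustrative stand-ins, not MegEngine APIs.

// Minimal sketch (not MegEngine code) of the transformation in this commit.
#include <cstddef>
#include <stack>
#include <utility>
#include <vector>

struct Node {
    std::vector<Node*> inputs;
    bool in_memory = false;
};

// Before: one native stack frame per evicted ancestor, so a recomputation
// chain of depth N needs N frames and can overflow the system stack.
void recompute_recursive(Node* n) {
    for (Node* i : n->inputs) {
        if (!i->in_memory) {
            recompute_recursive(i);
        }
    }
    n->in_memory = true;  // all inputs present: apply the op
}

// After: an explicit stack whose entries carry a resume index, mirroring the
// (ApplyOp, size_t, TensorInfo*) tuples pushed onto m_apply_stack.
void recompute_iterative(Node* root) {
    std::stack<std::pair<Node*, std::size_t>> stk;
    stk.push({root, 0});
    while (!stk.empty()) {
        auto& [n, idx] = stk.top();
        bool pushed = false;
        for (std::size_t i = idx; i < n->inputs.size(); ++i) {
            if (!n->inputs[i]->in_memory) {
                idx = i + 1;  // resume after this input on the next visit
                stk.push({n->inputs[i], 0});
                pushed = true;
                break;
            }
        }
        if (pushed) {
            continue;  // recompute the missing input first
        }
        n->in_memory = true;  // all inputs present: apply the op
        stk.pop();
    }
}

The resume index is the essential trick: when an entry finds an input missing, it records where to continue scanning, pushes the producer, and yields; when it becomes the top of the stack again, it resumes the scan instead of restarting it.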
import numpy as np
import pytest

import megengine as mge
import megengine.functional as F
import megengine.module as M
import megengine.optimizer as optim
import megengine.tensor as tensor
from megengine.autodiff import GradManager
from megengine.data import DataLoader, RandomSampler, transform
from megengine.data.dataset import CIFAR10


def _weights_init(m):
    classname = m.__class__.__name__
    if isinstance(m, M.Linear) or isinstance(m, M.Conv2d):
        M.init.msra_normal_(m.weight)


mean = [125.3, 123.0, 113.9]
std = [63.0, 62.1, 66.7]


class BasicBlock(M.Module):
    expansion = 1

    def __init__(self, in_planes, planes, stride=1):
        super(BasicBlock, self).__init__()
        self.conv1 = M.Conv2d(
            in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False
        )
        self.bn1 = M.BatchNorm2d(planes)
        self.conv2 = M.Conv2d(
            planes, planes, kernel_size=3, stride=1, padding=1, bias=False
        )
        self.bn2 = M.BatchNorm2d(planes)
        self.shortcut = M.Sequential()
        if stride != 1 or in_planes != planes:
            self.shortcut = M.Sequential(
                M.Conv2d(
                    in_planes,
                    self.expansion * planes,
                    kernel_size=1,
                    stride=stride,
                    bias=False,
                ),
                M.BatchNorm2d(self.expansion * planes),
            )

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        out += self.shortcut(x)
        out = F.relu(out)
        return out


class ResNet(M.Module):
    def __init__(self, block, num_blocks, num_classes=10):
        super(ResNet, self).__init__()
        self.in_planes = 16
        self.conv1 = M.Conv2d(3, 16, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = M.BatchNorm2d(16)
        self.layer1 = self._make_layer(block, 16, num_blocks[0], stride=1)
        self.layer2 = self._make_layer(block, 32, num_blocks[1], stride=2)
        self.layer3 = self._make_layer(block, 64, num_blocks[2], stride=2)
        self.linear = M.Linear(64, num_classes)
        self.apply(_weights_init)

    def _make_layer(self, block, planes, num_blocks, stride):
        strides = [stride] + [1] * (num_blocks - 1)
        layers = []
        for stride in strides:
            layers.append(block(self.in_planes, planes, stride))
            self.in_planes = planes * block.expansion
        return M.Sequential(*layers)

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        out = out.mean(3).mean(2)
        out = self.linear(out)
        return out


@pytest.mark.require_ngpu(1)
def test_dtr_resnet1202():
    batch_size = 64
    resnet1202 = ResNet(BasicBlock, [200, 200, 200])
    opt = optim.SGD(resnet1202.parameters(), lr=0.05, momentum=0.9, weight_decay=1e-4)
    gm = GradManager().attach(resnet1202.parameters())

    def train_func(data, label, *, net, gm):
        net.train()
        with gm:
            pred = net(data)
            loss = F.loss.cross_entropy(pred, label)
            gm.backward(loss)
        return pred, loss

    mge.dtr.enable()

    data = np.random.randn(batch_size, 3, 32, 32).astype("float32")
    label = np.random.randint(0, 10, size=(batch_size,)).astype("int32")

    for step in range(10):
        opt.clear_grad()
        _, loss = train_func(mge.tensor(data), mge.tensor(label), net=resnet1202, gm=gm)
        opt.step()
        loss.item()
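A note on the test: `ResNet(BasicBlock, [200, 200, 200])` is the 1202-layer CIFAR-style ResNet (hence the test name), so with `mge.dtr.enable()` the evictor drops long chains of activations and regeneration must walk them back. Before this fix, that walk recursed once per missing ancestor; the test simply asserts that ten training steps complete, with `loss.item()` forcing each step's pending computation to finish.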
@@ -615,13 +615,15 @@ void ChannelImpl::release_tensor(TensorInfo* dest) {
 }
 
 void ChannelImpl::regenerate(TensorInfo* dest) {
-    RECORD_EVENT(TensorCommandEvent, dest->id, TensorCommandEvent::ReGen);
     if (dest->evict_type == EvictType::DROP) {
-        recompute(dest->producer);
+        auto &&path = dest->producer;
+        m_apply_stack.push({ApplyOp{path->id, path->op, path->inputs, path->outputs, {}}, 0, dest});
+        if (!m_applying) flush_apply_stack();
     } else if (dest->evict_type == EvictType::SWAP) {
+        RECORD_EVENT(TensorCommandEvent, dest->id, TensorCommandEvent::ReGen);
         produce_tensor(dest, Tensor::make(dest->h_value));
+        RECORD_EVENT(TensorCommandFinishEvent, dest->id, TensorCommandFinishEvent::ReGen);
     }
-    RECORD_EVENT(TensorCommandFinishEvent, dest->id, TensorCommandFinishEvent::ReGen);
 }
 
 void ChannelImpl::do_apply_op(const ApplyOp& cmd) {
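In the new `regenerate`, a DROP-evicted tensor is no longer recomputed on the spot: its producer's `ApplyOp` is pushed onto `m_apply_stack` with the tensor recorded as the regeneration target, and the stack is drained here only when `m_applying` is false, i.e. when `regenerate` was reached from outside `flush_apply_stack`; nested calls made while the stack is draining just push and return. The ReGen profiling events for the DROP path move into `flush_apply_stack` accordingly, while the SWAP path keeps them inline around `produce_tensor`.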
@@ -635,17 +637,6 @@ void ChannelImpl::do_apply_op(const ApplyOp& cmd) {
         MemoryDesc desc;
     };
     SmallVector<TensorWithDesc> inputs;
-    // SmallVector<TensorPtr> tensor_inputs;
-    if (state.options.enable_dtr_auto_drop) {
-        m_dtr.pin(cmd.inputs);
-    }
-    for (auto i : cmd.inputs) {
-        if (!i->ptr && i->evict_type != EvictType::NONE) {
-            regenerate(i);
-        }
-        m_dtr.update_used_time(i);
-    }
-    // tensor_inputs.reserve(cmd.inputs.size());
     inputs.reserve(cmd.inputs.size());
     // refcnt == 1, owners: [TensorInfo::ptr]
     for (auto i : cmd.inputs) {
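The input pinning and regeneration that used to sit at the top of `do_apply_op` (removed above) now happen in `flush_apply_stack` before the command is popped, so `do_apply_op` can assume every input is already resident.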
@@ -781,20 +772,48 @@ void ChannelImpl::do_apply_op(const ApplyOp& cmd) {
     // End profiling operator
 }
 
-void ChannelImpl::recompute(TensorInfo::ComputePath* path) {
+void ChannelImpl::flush_apply_stack() {
+    m_applying = true;
     auto& state = get_worker_state();
-    do_apply_op(ApplyOp{path->id, path->op, path->inputs, path->outputs, {}});
-    for (size_t i = 0;i < path->outputs.size();i ++) {
-        auto&& o = path->outputs[i];
-        if (o) {
-            o->recompute_times ++;
-            if (!o->ptr) {
-                if (state.options.enable_dtr_auto_drop) {
-                    m_dtr.update_dsu_after_recompute(o);
-                }
-            }
-        }
-    }
-}
+    while (!m_apply_stack.empty()) {
+        auto& [cmd, idx, recomp] = m_apply_stack.top(); // cmd.inputs[0~idx-1] is in memory
+        if (idx == 0) {
+            if (state.options.enable_dtr_auto_drop) {
+                m_dtr.pin(cmd.inputs);
+            }
+            if (recomp) {
+                RECORD_EVENT(TensorCommandEvent, recomp->id, TensorCommandEvent::ReGen);
+            }
+        }
+        bool regen = false;
+        for (size_t i = idx; i < cmd.inputs.size(); i ++) {
+            auto&& p = cmd.inputs[i];
+            if (state.options.enable_dtr_auto_drop) {
+                m_dtr.update_used_time(p);
+            }
+            if (!p->ptr && p->evict_type != EvictType::NONE) {
+                idx = i + 1;
+                regenerate(p); // add ApplyOp to the stack
+                regen = true;
+                break;
+            }
+        }
+        if (regen) continue;
+        // the required input tensors are already in memory
+        auto cmd_backup = cmd;
+        auto recomp_backup = recomp;
+        m_apply_stack.pop();
+        do_apply_op(cmd_backup);
+        if (recomp_backup) {
+            RECORD_EVENT(TensorCommandFinishEvent, recomp_backup->id, TensorCommandFinishEvent::ReGen);
+            for (auto o : cmd_backup.outputs) {
+                if (o) {
+                    m_dtr.update_dsu_after_recompute(o);
+                }
+            }
+        }
+    }
+    m_applying = false;
+}
 
 bool ChannelImpl::auto_evict(size_t force_num) {
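`flush_apply_stack` is the simulated call stack. Each entry is `(ApplyOp, idx, recomp)`: the op to apply, the number of inputs already known to be in memory, and, when the op was pushed by `regenerate`, the tensor whose ReGen events should bracket the work (`nullptr` for a normal first execution). On first touch (`idx == 0`) the inputs are pinned so they cannot be evicted mid-scan. When a missing input is found, the entry saves `idx = i + 1` before `regenerate` pushes the producer, so the next visit resumes past that input. The `cmd_backup`/`recomp_backup` copies exist because `m_apply_stack.pop()` invalidates the references obtained from `top()`; the entry is copied out first and `do_apply_op` runs on the copy.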
@@ -997,7 +1016,8 @@ void ChannelImpl::process_one_task(IdentifiedCommand& icmd) {
             RECORD_EVENT(TensorCommandFinishEvent, cmd.dest->id, TensorCommandFinishEvent::Put);
             sample_on_device(cmd.dest->desc.comp_node, false);
         } else if constexpr (std::is_same_v<T, ApplyOp>) {
-            do_apply_op(cmd);
+            m_apply_stack.push({cmd, 0, nullptr});
+            flush_apply_stack();
             for (size_t i = 0; i < cmd.outputs.size(); ++i) {
                 auto output = cmd.outputs[i];
                 if (output == nullptr) {
......
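Ordinary `ApplyOp` commands from the worker queue now go through the same machinery: `process_one_task` pushes the command with a zero resume index and a null regeneration target (so no ReGen events are emitted), then drains the stack.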
......@@ -14,10 +14,10 @@
#include <deque>
#include <future>
#include <list>
#include <stack>
#include <thread>
#include <unordered_set>
#include <variant>
#include "megbrain/comp_node.h"
#include "megbrain/utils/mempool.h"
#include "megbrain/imperative/interpreter.h"
@@ -103,8 +103,8 @@ private:
     void release_tensor(TensorInfo* dest);
 
     void regenerate(TensorInfo* dest);
-    void recompute(TensorInfo::ComputePath* path);
     void do_apply_op(const ApplyOp& cmd);
+    void flush_apply_stack();
 
     std::tuple<SmallVector<MemoryDesc>, SmallVector<TensorPtr>, SmallVector<TensorPtr>> init_output_and_workspace(
         const OpDef& def,
......@@ -149,7 +149,8 @@ private:
std::exception_ptr m_worker_exc;
std::function<void(std::string, std::string)> m_profile_dump_callback;
size_t m_storage_id = 0;
std::stack<std::tuple<ApplyOp, size_t, TensorInfo*>> m_apply_stack;
bool m_applying = false;
bool m_closed = false;
struct WorkQueue : AsyncQueueSC<IdentifiedCommand, WorkQueue> {
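Both new members are, as far as the surrounding code shows, touched only from the single worker thread that processes commands, so no locking is added; `m_applying` is a plain re-entrancy flag telling `regenerate` whether a `flush_apply_stack` drain loop is already active.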
......