Commit d909950f authored by Megvii Engine Team

refactor(imperative): remove swap in dtr

GitOrigin-RevId: 5c9e42f74a758cf9394dd84b69861f47b1206e49
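In user terms, this commit deletes the private swap controls (`_set_swap_flag`, `Tensor._swap_in`, `Tensor._swap_out`): eviction becomes drop-only, meaning a dropped tensor is regenerated from its recorded computing path instead of being swapped back from host memory. A minimal drop-only sketch, assembled from the updated tests further down in this diff (these are private, internal APIs, shown for illustration only):

```python
import numpy as np
import megengine as mge
from megengine.core._imperative_rt.core2 import _set_drop_flag

_set_drop_flag(True)

x = mge.tensor(np.ones((3, 3)), dtype=np.float32)
y = mge.tensor(np.ones((3, 3)), dtype=np.float32)
z = x + y

z._drop()         # release z's device storage; DTR keeps its computing path
print(z.numpy())  # reading the value forces z to be regenerated (recomputed)

_set_drop_flag(False)
```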
Parent bcf69d8f
@@ -118,7 +118,6 @@ def enable_sqrt_sampling(mod, value: bool):
 def enable():
     r"""Enable to record computing path of tensors and to perform DTR policy."""
-    _set_defrag(True)
     _set_option("enable_dtr_auto_drop", 1)
     _set_option("enable_drop", 1)
     _set_option("buffer_length", 0)
@@ -127,7 +126,6 @@ def enable():
 def disable():
     r"""Stop recording computing path of tensors and performing DTR policy."""
-    _set_defrag(False)
     _set_option("enable_dtr_auto_drop", 0)
     _set_option("enable_drop", 0)
     _set_option("record_computing_path", 0)
@@ -605,14 +605,6 @@ PyObject* TensorWrapper::_dev_tensor() {
     return py::cast(dev_tensor).release().ptr();
 }
 
-void TensorWrapper::_swap_out() {
-    interpreter_for_py->swap_out(m_tensor->m_handle.get());
-}
-
-void TensorWrapper::_swap_in() {
-    interpreter_for_py->swap_in(m_tensor->m_handle.get());
-}
-
 void TensorWrapper::_drop() {
     interpreter_for_py->drop(m_tensor->m_handle.get());
 }
@@ -931,8 +923,6 @@ void init_tensor(py::module m) {
            .def<&TensorWrapper::unsetscalar>("_unsetscalar")
            .def<&TensorWrapper::detach>("detach")
            .def<&TensorWrapper::_dev_tensor>("_dev_tensor")
-           .def<&TensorWrapper::_swap_out>("_swap_out")
-           .def<&TensorWrapper::_swap_in>("_swap_in")
            .def<&TensorWrapper::_drop>("_drop")
            .def<&TensorWrapper::reset_varnode>("_reset_varnode")
            .def<&TensorWrapper::_use_cnt>("_use_cnt")
@@ -1032,8 +1022,6 @@ void init_tensor(py::module m) {
     });
     m.def("get_option",
           [](std::string name) { return interpreter_for_py->get_option(name); });
-    m.def("_set_swap_flag",
-          [](bool flag) { interpreter_for_py->set_option("enable_swap", flag); });
     m.def("_set_drop_flag",
           [](bool flag) { interpreter_for_py->set_option("enable_drop", flag); });
     m.def("config_async_level", [](int level) {
......
@@ -194,8 +194,6 @@ struct TensorWrapper {
     void setscalar();
     void unsetscalar();
     PyObject* _dev_tensor();
-    void _swap_in();
-    void _swap_out();
     void _drop();
     PyObject* varnode();
     void reset_varnode();
......
@@ -14,12 +14,7 @@ import megengine as mge
 import megengine.autodiff as ad
 import megengine.functional as F
 from megengine import Tensor
-from megengine.core._imperative_rt.core2 import (
-    _set_drop_flag,
-    _set_swap_flag,
-    get_option,
-    set_option,
-)
+from megengine.core._imperative_rt.core2 import _set_drop_flag, get_option, set_option
 from megengine.module import Linear, Module
 from megengine.optimizer import SGD
@@ -70,7 +65,6 @@ class XORNet(Module):
     def forward(self, x):
         y = self.fc0(x)
-        x._swap_out()
         x = F.tanh(y)
         y = self.fc1(x)
         x = F.tanh(y)
@@ -80,8 +74,7 @@ class XORNet(Module):
         return y
 
-def test_training_converge_with_swap_and_drop():
-    _set_swap_flag(True)
+def test_training_converge_with_drop():
     _set_drop_flag(True)
     old_buffer_length = get_option("buffer_length")
     set_option("buffer_length", 0)
@@ -125,6 +118,5 @@ def test_training_converge_with_swap_and_drop():
         precision
     )
-    _set_swap_flag(False)
     _set_drop_flag(False)
     set_option("buffer_length", old_buffer_length)
@@ -9,7 +9,6 @@ import megengine.functional as F
 from megengine.core._imperative_rt.core2 import (
     AsyncError,
     _set_drop_flag,
-    _set_swap_flag,
     config_async_level,
     get_async_level,
 )
@@ -61,24 +60,20 @@ def test_host_compute_elemwise():
     d = F.reshape(a, c)
 
-def test_swap_drop_basic():
-    _set_swap_flag(True)
+def test_drop_basic():
     _set_drop_flag(True)
     # test xpu compute
     x = mge.tensor(np.ones((3, 3)), dtype=np.float32)
     y = mge.tensor(np.ones((3, 3)), dtype=np.float32)
     z = x + y
-    x._swap_out()
     z._drop()
     z.numpy()
     # test host value compute
     x = mge.tensor(np.ones((2, 2)), dtype=np.float32)
     y = mge.tensor(np.ones((2, 2)), dtype=np.float32)
     z = x + y
-    x._swap_out()
     z._drop()
     z.numpy()
-    _set_swap_flag(False)
     _set_drop_flag(False)
......
@@ -84,28 +84,6 @@ struct GetValue {
     const char* get_name() const { return "GetValue"; }
 };
 
-struct SwapIn {
-    TensorInfo* dest;
-
-    template <typename TFunctor>
-    void get_props(TFunctor&& functor) const {
-        functor("dest", dest);
-    }
-
-    const char* get_name() const { return "SwapIn"; }
-};
-
-struct SwapOut {
-    TensorInfo* dest;
-
-    template <typename TFunctor>
-    void get_props(TFunctor&& functor) const {
-        functor("dest", dest);
-    }
-
-    const char* get_name() const { return "SwapOut"; }
-};
-
 struct Drop {
     TensorInfo* dest;
@@ -171,8 +149,8 @@ struct PopScope {
 };
 
 using CommandData = std::variant<
-        Put, ApplyOp, Del, GetValue, SwapIn, SwapOut, Drop, SetOption, StartProfile,
-        StopProfile, PushScope, PopScope>;
+        Put, ApplyOp, Del, GetValue, Drop, SetOption, StartProfile, StopProfile,
+        PushScope, PopScope>;
 
 struct Command {
     uint64_t id;
......
@@ -197,32 +197,6 @@ void ChannelImpl::del_impl(Handle handle) {
     m_buffer.enqueue(Del{info});
 }
 
-void ChannelImpl::swap_in(Handle handle) {
-    MGB_LOCK_GUARD(m_spin);
-    mgb_assert(check_available(), "Channel already closed");
-    auto& state = get_channel_state();
-    if (state.options.enable_swap) {
-        mgb_assert(
-                m_valid_handle.find(handle) != m_valid_handle.end(),
-                "invalid handle: %p", handle);
-        auto* info = reinterpret_cast<TensorInfo*>(handle);
-        m_buffer.enqueue(SwapIn{info});
-    }
-}
-
-void ChannelImpl::swap_out(Handle handle) {
-    MGB_LOCK_GUARD(m_spin);
-    mgb_assert(check_available(), "Channel already closed");
-    auto& state = get_channel_state();
-    if (state.options.enable_swap) {
-        mgb_assert(
-                m_valid_handle.find(handle) != m_valid_handle.end(),
-                "invalid handle: %p", handle);
-        auto* info = reinterpret_cast<TensorInfo*>(handle);
-        m_buffer.enqueue(SwapOut{info});
-    }
-}
-
 void ChannelImpl::drop(Handle handle) {
     MGB_LOCK_GUARD(m_spin);
     mgb_assert(check_available(), "Channel already closed");
@@ -266,7 +240,7 @@ void ChannelImpl::dispatch_default_cpu(
             input_tensornds.emplace_back(
                     info->ptr->get_value().proxy_to_default_cpu());
         } else {
-            // It's OK for SwapOut. We assign h_value before drop ptr
+            // We assign h_value before drop ptr
             mgb_assert(!info->h_value.empty(), "inp->h_value is empty!");
             input_tensornds.emplace_back(info->h_value.proxy_to_default_cpu());
         }
@@ -660,10 +634,6 @@ void ChannelImpl::regenerate(TensorInfo* dest) {
                 "dtr"});
         if (!m_applying)
            flush_apply_stack();
-    } else if (dest->evict_type == EvictType::SWAP) {
-        MGB_RECORD_EVENT(TensorCommandEvent, dest->id, TensorCommandKind::ReGen);
-        produce_tensor(dest, Tensor::make(dest->h_value));
-        MGB_RECORD_EVENT(TensorCommandFinishEvent, dest->id, TensorCommandKind::ReGen);
     }
 }
@@ -1185,29 +1155,6 @@ void ChannelImpl::process_one_task(Command& icmd) {
             MGB_LOCK_GUARD(m_mutex);
             notify_tensor_unsafe(cmd.dest);
             imperative_log_profile_end("GetValue");
-        } else if constexpr (std::is_same_v<T, SwapIn>) {
-            if (cmd.dest->invalid)
-                return;
-            MGB_RECORD_EVENT(
-                    TensorCommandEvent, cmd.dest->id, TensorCommandKind::SwapIn);
-            produce_tensor(cmd.dest, Tensor::make(cmd.dest->h_value));
-            MGB_RECORD_EVENT(
-                    TensorCommandFinishEvent, cmd.dest->id, TensorCommandKind::SwapIn);
-            sample_on_device(cmd.dest->desc.comp_node, false);
-        } else if constexpr (std::is_same_v<T, SwapOut>) {
-            if (cmd.dest->invalid)
-                return;
-            MGB_RECORD_EVENT(
-                    TensorCommandEvent, cmd.dest->id, TensorCommandKind::SwapOut);
-            cmd.dest->h_value = cmd.dest->ptr->get_value();
-            if (cmd.dest->evict_type == EvictType::NONE) {
-                cmd.dest->evict_type = EvictType::SWAP;
-                cmd.dest->status = TensorInfo::Swapped;
-                release_tensor(cmd.dest);
-            }
-            MGB_RECORD_EVENT(
-                    TensorCommandFinishEvent, cmd.dest->id, TensorCommandKind::SwapOut);
-            sample_on_device(cmd.dest->desc.comp_node, false);
         } else if constexpr (std::is_same_v<T, Drop>) {
             if (cmd.dest->invalid)
                 return;
@@ -1223,7 +1170,7 @@ void ChannelImpl::process_one_task(Command& icmd) {
         for (auto* info : cmd.capture_tensors) {
             MGB_RECORD_EVENT(TensorDeclareEvent, info->id, info->name);
             if (info->status == TensorInfo::Produced) {
-                // TODO: handle swap/drop
+                // TODO: handle drop
                 MGB_RECORD_EVENT(
                         TensorProduceEvent, info->id, info->desc.layout,
                         info->desc.comp_node, info->ptr->dev_tensor().raw_ptr());
@@ -1387,9 +1334,7 @@ auto ChannelImpl::CommandBuffer::find_last_usage(TensorInfo* dest, Range range)
             if (cmd.dest == dest) {
                 found = iter;
             }
-        } else if constexpr (
-                std::is_same_v<T, SwapIn> || std::is_same_v<T, SwapOut> ||
-                std::is_same_v<T, Drop>) {
+        } else if constexpr (std::is_same_v<T, Drop>) {
             // TODO: ignore swap-like commands, just remove them from buffer
             if (cmd.dest == dest) {
                 found = iter;
......
@@ -46,8 +46,6 @@ struct ChannelImpl : Interpreter::Channel {
     Handle put(const DeviceTensorND& value, const HostTensorND& hvalue) override;
     void del(Handle) override;
-    void swap_in(Handle) override;
-    void swap_out(Handle) override;
     void drop(Handle) override;
     SmallVector<Handle> apply_op(
......
@@ -35,7 +35,6 @@ public:
             "level 2: both device and user side errors are async;\n"
             "level 1: user side errors are sync;\n"
             "level 0: both sync.");
-    DEF_OPTION(enable_swap, "MEGENGINE_ENABLE_SWAP", 0, "");
     DEF_OPTION(enable_drop, "MEGENGINE_ENABLE_DROP", 0, "");
     DEF_OPTION(max_recompute_time, "MEGENGINE_MAX_RECOMP_TIME", 1, "");
     DEF_OPTION(
......
@@ -21,8 +21,7 @@ namespace interpreter::intl {
 enum EvictType {
     NONE = 0,
-    SWAP = 1,
-    DROP = 2,
+    DROP = 1,
 };
 
 /*!
@@ -49,7 +48,6 @@ struct TensorInfo {
         InvalidStatus,
         Allocated,
         Produced,
-        Swapped,
         Dropped,
         Deleted,
     };
@@ -75,9 +73,7 @@ struct TensorInfo {
     // Status should be only modified in worker thread
     Status status = InvalidStatus;
 
-    // Used by HostCompute and Memory Swap.
-    // HostCompute and Swap does not happen in one thread.
-    // Maybe a barrier is needed.
+    // Used by HostCompute
     HostTensorND h_value;
 
     // reserved for auto drop
......
@@ -232,10 +232,6 @@ struct ChromeTimelineEventVisitor : EventVisitor<ChromeTimelineEventVisitor> {
                 return "Drop";
             case TensorCommandKind::Del:
                 return "Del";
-            case TensorCommandKind::SwapIn:
-                return "SwapIn";
-            case TensorCommandKind::SwapOut:
-                return "SwapOut";
             case TensorCommandKind::RecFree:
                 return "RecFree";
             case TensorCommandKind::ReGen:
......
@@ -156,16 +156,7 @@ DEF_DUR_EVENT(StartProfile, { size_t capture_count; });
 DEF_DUR_EVENT(StopProfile, { size_t escape_count; });
 
-enum class TensorCommandKind {
-    Put,
-    Del,
-    SwapIn,
-    SwapOut,
-    Drop,
-    ReGen,
-    RecFree,
-    GetValue
-};
+enum class TensorCommandKind { Put, Del, Drop, ReGen, RecFree, GetValue };
 
 DEF_DUR_EVENT(TensorCommand, {
     using Kind = TensorCommandKind;
......
@@ -39,8 +39,6 @@ struct Interpreter {
     virtual Handle put(const DeviceTensorND& value, const HostTensorND& hvalue) = 0;
     virtual void del(Handle) = 0;
-    virtual void swap_in(Handle) = 0;
-    virtual void swap_out(Handle) = 0;
     virtual void drop(Handle) = 0;
     virtual SmallVector<Handle> apply_op(
......