Commit a5af35c1 authored by Megvii Engine Team

refactor(imperative): remove command buffer

GitOrigin-RevId: 83c8cb6d3bed9b44b0424965fc7c4938b0ae5841
Parent: bdb853ee
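As an editorial summary: every ChannelImpl entry point that used to stage its Command in the now-deleted CommandBuffer submits it directly to the worker thread after this commit. The sketch below is a hypothetical, self-contained model of that before/after; the types are simplified stand-ins, not MegEngine's, and the real task additionally carries Profiler::next_id() and a stack-manager dump, as the hunks show.

```cpp
// Hypothetical model only: simplified stand-ins for the diff's types.
#include <deque>
#include <iostream>
#include <variant>
#include <vector>

struct ApplyOp { std::vector<int> inputs, dels; };
struct Del { int dest; };
using Command = std::variant<ApplyOp, Del>;

struct Worker {  // stands in for ChannelImpl::m_worker
    void add_task(Command) { std::cout << "task submitted\n"; }
};

// Old path: commands were staged so a later Del could be fused into a
// pending ApplyOp that consumes the same tensor (cf. CommandBuffer::fuse_del).
struct CommandBuffer {
    Worker* worker;
    std::deque<Command> pending;
    void enqueue(Command cmd) {
        if (auto* del = std::get_if<Del>(&cmd)) {
            for (auto& staged : pending)
                if (auto* apply = std::get_if<ApplyOp>(&staged))
                    for (int in : apply->inputs)
                        if (in == del->dest) {
                            apply->dels.push_back(in);  // fuse, drop the Del
                            return;
                        }
        }
        pending.push_back(std::move(cmd));
    }
    void flush() {
        for (auto& cmd : pending) worker->add_task(std::move(cmd));
        pending.clear();
    }
};

int main() {
    Worker worker;
    // After this commit there is no staging step: every command immediately
    // becomes a worker task, and fuse_del's optimization is gone with it.
    worker.add_task(Command{ApplyOp{{0, 1}, {}}});
    worker.add_task(Command{Del{0}});
}
```

The cost of the simplification is visible in the model too: with no staging window, the Del-into-ApplyOp fusion has nowhere to happen, which is why the dels field and its handling are deleted in the hunks below.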
@@ -120,7 +120,6 @@ def enable():
r"""Enable to record computing path of tensors and to perform DTR policy."""
_set_option("enable_dtr_auto_drop", 1)
_set_option("enable_drop", 1)
_set_option("buffer_length", 0)
_set_option("record_computing_path", 1)
......
@@ -702,10 +702,6 @@ void init_tensor(py::module m) {
});
m.def("get_option",
[channel](std::string name) { return channel->get_option(name); });
m.def("set_buffer_length", [channel](int length) {
mgb_assert(length >= 0 and length < 100, "buffer_length should be in [0, 100)");
channel->set_option("buffer_length", length);
});
m.def("push_scope", [channel](std::string name) {
Transformation::push_scope(name);
channel->push_scope(name);
......
@@ -76,8 +76,6 @@ class XORNet(Module):
def test_training_converge_with_drop():
set_option("enable_drop", 1)
old_buffer_length = get_option("buffer_length")
set_option("buffer_length", 0)
net = XORNet()
opt = SGD(net.parameters(), lr=0.01, momentum=0.9, weight_decay=5e-4)
gm = ad.GradManager().attach(net.parameters())
@@ -119,4 +117,3 @@ def test_training_converge_with_drop():
)
set_option("enable_drop", 0)
set_option("buffer_length", old_buffer_length)
@@ -9,6 +9,7 @@
import numpy as np
import pytest
import megengine as mge
import megengine.functional as F
from megengine import Tensor, jit, random
from megengine.core._imperative_rt import CompNode
@@ -209,9 +210,12 @@ def test_permutation_op():
assert str(output.device) == str(cn)
assert output.dtype == dtype
# FIXME: remove this sync
mge.core.set_option("async_level", 0)
test_permutation_op_dtype(np.float32)
test_permutation_op_dtype(np.int32)
test_permutation_op_dtype(np.int16)
mge.core.set_option("async_level", 2)
@pytest.mark.skipif(
......
@@ -49,14 +49,12 @@ struct ApplyOp {
std::shared_ptr<OpDef> op;
SmallVector<TensorInfo*> inputs;
SmallVector<TensorInfo*> outputs;
SmallVector<TensorInfo*> dels;
template <typename TFunctor>
void get_props(TFunctor&& functor) const {
functor("op", op);
functor("inputs", inputs);
functor("outputs", outputs);
functor("dels", dels);
}
const char* get_name() const { return "ApplyOp"; }
......
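As background for the hunk above: get_props is a small reflection visitor, letting generic code (for example the profiler's command dump) walk each command's named fields without per-command serializers, so removing dels from both the struct and get_props keeps the two in sync. A minimal, hypothetical sketch of the idiom — names and fields here are illustrative, not MegEngine's:

```cpp
// Hypothetical sketch of the get_props visitor idiom; not MegEngine's types.
#include <iostream>
#include <string>
#include <vector>

struct ApplyOp {
    std::string op;
    std::vector<int> inputs, outputs;
    template <typename TFunctor>
    void get_props(TFunctor&& functor) const {
        functor("op", op);
        functor("inputs", inputs);
        functor("outputs", outputs);
    }
    const char* get_name() const { return "ApplyOp"; }
};

int main() {
    ApplyOp cmd{"add", {1, 2}, {3}};
    std::cout << cmd.get_name() << " fields:\n";
    cmd.get_props([](const char* name, const auto& /*value*/) {
        std::cout << "  " << name << '\n';  // a profiler would dump value here
    });
}
```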
@@ -156,7 +156,9 @@ TensorInfo* ChannelImpl::put_impl(const HostTensorND& value, bool no_cache) {
info->desc.value = value.proxy_to_default_cpu();
}
info->mem_desc.id = StorageIdentifier::make(++m_storage_id);
m_buffer.enqueue(Put{info, value, no_cache});
m_worker.add_task(
{Profiler::next_id(), Put{info, value, no_cache},
get_channel_state().stack_manager.dump()});
if (m_async_level == 0) {
sync_impl();
info->desc.comp_node.sync();
@@ -200,7 +202,8 @@ void ChannelImpl::del_impl(Handle handle) {
mgb_assert(m_valid_handle.count(handle), "invalid handle: %p", handle);
auto* info = reinterpret_cast<TensorInfo*>(handle);
m_valid_handle.erase(handle);
m_buffer.enqueue(Del{info});
m_worker.add_task(
{Profiler::next_id(), Del{info}, get_channel_state().stack_manager.dump()});
}
void ChannelImpl::drop(Handle handle) {
@@ -212,7 +215,9 @@ void ChannelImpl::drop(Handle handle) {
m_valid_handle.find(handle) != m_valid_handle.end(),
"invalid handle: %p", handle);
auto* info = reinterpret_cast<TensorInfo*>(handle);
m_buffer.enqueue(Drop{info});
m_worker.add_task(
{Profiler::next_id(), Drop{info},
get_channel_state().stack_manager.dump()});
}
}
@@ -333,7 +338,9 @@ void ChannelImpl::dispatch_kernel(
MGB_RECORD_EVENT(
OpDispatchEvent, cmd.id, name, op_info_getter, tinfo_to_tid(cmd.inputs),
tinfo_to_tid(cmd.outputs), state.stack_manager.dump());
m_buffer.enqueue(std::move(cmd));
m_worker.add_task(
{Profiler::next_id(), std::move(cmd),
get_channel_state().stack_manager.dump()});
if (!validated && options.async_level == 1) {
sync_impl();
} else if (options.async_level == 0) {
@@ -466,7 +473,6 @@ void ChannelImpl::sync() {
}
void ChannelImpl::sync_impl() {
m_buffer.flush();
m_worker.wait_all_task_finish();
MGB_LOCK_GUARD(m_mutex);
check_worker_exc_unsafe();
@@ -499,7 +505,9 @@ void ChannelImpl::set_option(std::string name, size_t value) {
mgb_assert(check_available(), "Channel already closed");
auto& state = get_channel_state();
state.options.set_option(name, value);
m_buffer.enqueue(SetOption{name, value});
m_worker.add_task(
{Profiler::next_id(), SetOption{name, value},
get_channel_state().stack_manager.dump()});
}
void ChannelImpl::clear_candidates() {
@@ -604,7 +612,7 @@ void ChannelImpl::real_free(TensorInfo* ptr) {
m_pool.free(ptr);
}
ChannelImpl::ChannelImpl() : m_worker(this), m_buffer(this) {}
ChannelImpl::ChannelImpl() : m_worker(this) {}
ChannelImpl::~ChannelImpl() {
close();
@@ -645,7 +653,7 @@ void ChannelImpl::regenerate(TensorInfo* dest) {
if (dest->evict_type == EvictType::DROP) {
auto&& path = dest->producer;
m_apply_stack.push(
{ApplyOp{path->id, path->op, path->inputs, path->outputs, {}}, 0, dest,
{ApplyOp{path->id, path->op, path->inputs, path->outputs}, 0, dest,
"dtr"});
if (!m_applying)
flush_apply_stack();
@@ -748,19 +756,6 @@ void ChannelImpl::do_apply_op(const ApplyOp& cmd, std::string reason) {
MGB_RECORD_EVENT(TensorUsageEvent, input_id);
MGB_RECORD_EVENT(OpInputFinishEvent, input_id);
}
// Fused by the command buffer. @see: CommandBuffer::fuse_del
// If dest is inplaceable, its refcnt will have been decreased to 1 and it will be
// owned by tensor_inputs after the Del. Note that for exprs like 'y = x op x',
// inplace is not supported yet, but the Del is fused all the same.
for (auto* del : cmd.dels) {
// refcnt --, owners: [tensor_inputs]
// if it's decreased to 1, would be detected at @see:
// proxy_graph_detail::apply_on_physical_tensor
uint64_t del_id = del->id;
MGB_RECORD_EVENT(TensorCommandEvent, del_id, TensorCommandKind::Del);
free(del);
MGB_RECORD_EVENT(TensorCommandFinishEvent, del_id, TensorCommandKind::Del);
}
// Before wait
// TODO: split operator wait and execute so that OpWait can be correctly recorded.
// Before execute
@@ -931,7 +926,6 @@ bool ChannelImpl::check_available() {
}
TensorPtr ChannelImpl::wait_tensor(TensorInfo* info, TensorProp prop) {
m_buffer.flush();
std::unique_lock<decltype(m_mutex)> lock(m_mutex);
mgb_assert(!m_waitee, "duplicate waitee");
m_waitee = info;
@@ -943,8 +937,9 @@ TensorPtr ChannelImpl::wait_tensor(TensorInfo* info, TensorProp prop) {
if (require_host && !host_available()) {
// avoid dead lock
lock.unlock();
m_buffer.enqueue(GetValue{info});
m_buffer.flush();
m_worker.add_task(
{Profiler::next_id(), GetValue{info},
get_channel_state().stack_manager.dump()});
lock.lock();
wait_host = true;
}
@@ -1266,141 +1261,25 @@ void ChannelImpl::check_worker_exc_unsafe() {
}
}
void ChannelImpl::CommandBuffer::enqueue(CommandData cmd) {
auto& state = m_owner->get_channel_state();
if (std::get_if<Del>(&cmd) && fuse_del(std::get<Del>(cmd))) {
return;
}
m_commands.push_back(
{Profiler::next_id(), std::move(cmd), state.stack_manager.dump()});
auto flush_pos = flush_pos_for(m_commands.back());
flush(flush_pos);
}
void ChannelImpl::CommandBuffer::flush() {
flush(m_commands.end());
}
void ChannelImpl::CommandBuffer::flush(Handle pos) {
for (auto iter = m_commands.begin(); iter != pos; ++iter) {
if (Profiler::is_profiling()) {
mgb_log_debug("%s Flushed", to_string(*iter).c_str());
}
m_owner->m_worker.add_task(std::move(*iter));
}
m_commands.erase(m_commands.begin(), pos);
}
auto ChannelImpl::CommandBuffer::flush_pos_for(const Command& cmd) -> Handle {
auto& state = m_owner->get_channel_state();
return std::visit(
[this, &state](const auto& cmd) {
using T = std::decay_t<decltype(cmd)>;
if constexpr (std::is_same_v<T, ApplyOp>) {
auto* op_type = cmd.op->dyn_typeinfo();
if (op_type == RemoteRecv::typeinfo() ||
op_type == RemoteSend::typeinfo() ||
op_type == CollectiveComm::typeinfo() ||
op_type == opr::InputCallback::typeinfo() ||
op_type == opr::OutputCallback::typeinfo()) {
return m_commands.end();
}
} else if constexpr (std::is_same_v<T, GetValue>) {
return m_commands.end();
}
size_t buffer_length = state.options.buffer_length;
if (m_commands.size() > buffer_length) {
return m_commands.begin() + (m_commands.size() - buffer_length);
}
return m_commands.begin();
},
cmd.data);
}
/**
 * 1. Find an ApplyOp that consumes dest among the buffered commands
 * 2. Check that dest has no other usages after that ApplyOp; return false if it does
 * 3. Fuse the Del into the ApplyOp and return true
 */
bool ChannelImpl::CommandBuffer::fuse_del(const Del& cmd) {
auto* dest = cmd.dest;
// TODO: eliminate Puts
auto begin = m_commands.begin(), end = m_commands.end();
auto apply_iter = std::find_if(begin, end, [dest](const Command& cmd) {
if (auto* apply = std::get_if<ApplyOp>(&cmd.data)) {
return std::count(apply->inputs.begin(), apply->inputs.end(), dest) > 0;
}
return false;
});
if (apply_iter == end || find_last_usage(dest, {apply_iter + 1, end}) != end) {
return false;
}
std::get<ApplyOp>(apply_iter->data).dels.push_back(dest);
return true;
}
auto ChannelImpl::CommandBuffer::find_last_usage(TensorInfo* dest, Range range)
-> Handle {
auto found = range[1];
for (auto iter = range[0]; iter != range[1]; ++iter) {
std::visit(
[&](const auto& cmd) {
using T = std::decay_t<decltype(cmd)>;
if constexpr (std::is_same_v<T, ApplyOp>) {
if (std::count(cmd.inputs.begin(), cmd.inputs.end(), dest) >
0) {
found = iter;
}
} else if constexpr (std::is_same_v<T, GetValue>) {
if (cmd.dest == dest) {
found = iter;
}
} else if constexpr (std::is_same_v<T, Drop>) {
// TODO: ignore swap-like commands, just remove them from buffer
if (cmd.dest == dest) {
found = iter;
}
}
},
iter->data);
};
return found;
}
auto ChannelImpl::CommandBuffer::find_produce(TensorInfo* dest, Range range) -> Handle {
return std::find_if(range[0], range[1], [dest](auto& cmd) {
return std::visit(
[dest](const auto& cmd) {
using T = std::decay_t<decltype(cmd)>;
if constexpr (std::is_same_v<T, ApplyOp>) {
return std::count(
cmd.outputs.begin(), cmd.outputs.end(), dest) >
0;
} else if constexpr (std::is_same_v<T, Put>) {
return cmd.dest == dest;
}
return false;
},
cmd.data);
});
}
void ChannelImpl::start_profile() {
MGB_LOCK_GUARD(m_spin);
mgb_assert(check_available(), "Channel already closed");
auto capture_tensors = collect_valid_tensors();
if (capture_tensors.size() > 0) {
m_buffer.enqueue(StartProfile{std::move(capture_tensors)});
m_worker.add_task(
{Profiler::next_id(), StartProfile{std::move(capture_tensors)},
get_channel_state().stack_manager.dump()});
}
}
void ChannelImpl::stop_profile() {
MGB_LOCK_GUARD(m_spin);
mgb_assert(check_available(), "Channel already closed");
m_buffer.flush();
auto escape_tensors = collect_valid_tensors();
if (escape_tensors.size() > 0) {
m_buffer.enqueue(StopProfile{std::move(escape_tensors)});
m_worker.add_task(
{Profiler::next_id(), StopProfile{std::move(escape_tensors)},
get_channel_state().stack_manager.dump()});
}
}
@@ -1410,7 +1289,9 @@ void ChannelImpl::push_scope(std::string name) {
auto& state = get_channel_state();
state.stack_manager.enter(name);
MGB_RECORD_EVENT(ScopeEvent, name);
m_buffer.enqueue(PushScope{name});
m_worker.add_task(
{Profiler::next_id(), PushScope{name},
get_channel_state().stack_manager.dump()});
}
void ChannelImpl::pop_scope(std::string name) {
@@ -1419,7 +1300,9 @@ void ChannelImpl::pop_scope(std::string name) {
auto& state = get_channel_state();
state.stack_manager.exit(name);
MGB_RECORD_EVENT(ScopeFinishEvent, name);
m_buffer.enqueue(PopScope{name});
m_worker.add_task(
{Profiler::next_id(), PopScope{name},
get_channel_state().stack_manager.dump()});
}
void ChannelImpl::assert_in_channel() {
......
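For context on the large deletion above: flush_pos_for encoded the retired buffering policy. At most options.buffer_length commands stayed staged, while a GetValue or a distributed/callback op (RemoteSend/RemoteRecv, CollectiveComm, InputCallback, OutputCallback) flushed everything, since a consumer would block on it. A hypothetical distillation, with the variant dispatch reduced to an enum:

```cpp
// Hypothetical distillation of the removed flush_pos_for policy.
#include <cassert>
#include <cstddef>

enum class Kind { ApplyOp, GetValue, BlockingApplyOp /* e.g. RemoteSend */ };

// Returns how many of the oldest staged commands to hand to the worker
// once `incoming` has been appended to the staging queue.
std::size_t num_to_flush(std::size_t staged, Kind incoming, std::size_t buffer_length) {
    if (incoming == Kind::GetValue || incoming == Kind::BlockingApplyOp)
        return staged;  // flush everything: a consumer will wait on this command
    if (staged > buffer_length)
        return staged - buffer_length;  // keep the fusion window bounded
    return 0;
}

int main() {
    assert(num_to_flush(5, Kind::ApplyOp, 3) == 2);   // a window of 3 stays staged
    assert(num_to_flush(5, Kind::GetValue, 3) == 5);  // a reader forces a full flush
}
```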
@@ -126,11 +126,6 @@ private:
void assert_in_worker();
std::thread::id get_worker_tid();
// template <typename TCommand>
// void enqueue_command(TCommand&& cmd) {
// m_buffer.enqueue(Command{std::forward<TCommand>(cmd)});
// }
void sample_on_device(CompNode device, bool force);
// valid => status != Deleted
@@ -178,46 +173,6 @@ private:
ChannelImpl* m_owner;
} m_worker;
/**
* Buffer a window of commands for subsequent fusion.
* Example:
* ---------------------------------------------------------------------
* | ..., Apply{in: (i0, i1), out: (o0, o1)}, ... + Del{i0} + Del{i1} |
* ---------------------------------------------------------------------
* | ..., Apply{in: (i0, i1), out: (o0, o1), del: (i0)}, ... + Del{i1} |
* ---------------------------------------------------------------------
* | ..., Apply{in: (i0, i1), out: (o0, o1), del: (i0, i1)}, ... |
* ---------------------------------------------------------------------
* Then the fused Apply may be executed in place. @see:
* ChannelImpl::process_one_task
*/
struct CommandBuffer {
CommandBuffer(ChannelImpl* owner) : m_owner(owner) {}
void enqueue(CommandData cmd);
bool empty() const { return m_commands.empty(); }
void flush();
private:
ChannelImpl* m_owner;
std::deque<Command> m_commands;
using Handle = decltype(m_commands)::iterator;
// [begin, end)
using Range = std::array<Handle, 2>;
// Launch commands in range [m_commands.begin(), pos)
void flush(Handle pos);
// Select flush position for incoming cmd
Handle flush_pos_for(const Command& cmd);
// Fuse del command into suitable ApplyOp
bool fuse_del(const Del& cmd);
// Returns the last position in range at which dest is used. If dest is not used,
// returns range[1]
Handle find_last_usage(TensorInfo* dest, Range range);
// Returns the produce position of dest. If not found, returns range[1]
Handle find_produce(TensorInfo* dest, Range range);
} m_buffer;
//! configures whether errors are raised exactly when an op is invoked.
//! level 2: both device and user side errors are async;
//! level 1: user side errors are sync;
......
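The async_level levels documented above (the level-0 line is cut off by the truncation) are still honored by the new direct-dispatch paths: dispatch_kernel syncs the worker at level 1 when validation fails, and put_impl additionally syncs the comp node at level 0. A hypothetical stand-alone restatement of that gate, following the hunks earlier in this commit:

```cpp
// Hypothetical model of the async_level gate (cf. dispatch_kernel/put_impl above).
#include <iostream>

struct Options { int async_level = 2; };

void sync_worker() { std::cout << "worker drained\n"; }
void sync_device() { std::cout << "device queue drained\n"; }

void after_dispatch(const Options& opts, bool validated) {
    if (!validated && opts.async_level == 1) {
        sync_worker();      // level 1: user-side errors surface synchronously
    } else if (opts.async_level == 0) {
        sync_worker();      // level 0: fully synchronous, wait for the worker
        sync_device();      // ...and for the device queue as well
    }                       // level 2: both kinds of errors stay asynchronous
}

int main() { after_dispatch(Options{0}, /*validated=*/true); }
```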
@@ -40,9 +40,6 @@ public:
DEF_OPTION(
catch_worker_execption, "MEGENGINE_CATCH_WORKER_EXEC", 1,
"catch worker exception if enabled, close it when debugging");
DEF_OPTION(
buffer_length, "MEGENGINE_COMMAND_BUFFER_LENGTH", 3,
"set command buffer length.");
DEF_OPTION(
enable_host_compute, "MEGENGINE_HOST_COMPUTE", 1,
"enable host compute, thus computation may be done in host event if it's "
......
@@ -626,23 +626,12 @@ cg::OperatorNodeBase* ProxyGraph::get_proxy_opr(
/*********************** Logical Tensor Impl ***********************/
size_t ProxyGraph::get_opr_output_size(
const OpDef& opdef, const SmallVector<LogicalTensorDesc>& inputs) {
return get_proxy_opr(opdef, inputs)->usable_output().size();
}
std::tuple<SmallVector<LogicalTensorDesc>, bool> ProxyGraph::
infer_output_attrs_fallible(
const OpDef& opdef, const SmallVector<LogicalTensorDesc>& inputs) {
auto opr = get_proxy_opr(opdef, inputs);
CUR_OPR_GUARD(opr);
SmallVector<LogicalTensorDesc> outputs;
bool validated = do_shape_infer(false);
for (auto&& i : opr->usable_output()) {
outputs.push_back({{i->shape(), i->dtype()}, i->comp_node()});
}
bool need_check = opr->same_type<opr::Reshape>();
return {outputs, validated && !need_check};
// This function is just a placeholder;
// it will be overridden by ProxyGraphTypeI::infer_output_attrs_fallible in the minigraph.
mgb_assert(0);
}
std::tuple<SmallVector<MemoryDesc>, SmallVector<MemoryDesc>> ProxyGraph::
@@ -823,12 +812,6 @@ EncodedSubgraph ProxyGraph::make_backward_graph(
return result;
}
cg::OperatorNodeBase* ProxyGraph::get_proxy_opr(
const OpDef& opdef, const SmallVector<LogicalTensorDesc>& inputs) {
mgb_assert(!m_cur_opr);
auto vinputs = make_input_place_holders(inputs);
return OpDef::apply_on_var_node(opdef, vinputs)[0]->owner_opr();
}
VarNodeArray ProxyGraph::make_input_place_holders(
const SmallVector<LogicalTensorDesc>& inputs) {
......
@@ -85,9 +85,6 @@ private:
/********************** Logical Tensor Helper **********************/
cg::OperatorNodeBase* get_proxy_opr(
const OpDef& opdef, const SmallVector<LogicalTensorDesc>& inputs);
cg::VarNodeArray make_input_place_holders(
const SmallVector<LogicalTensorDesc>& inputs);
......
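The last hunk turns ProxyGraph::infer_output_attrs_fallible into an unconditional assert: per its new comment, the working implementation is ProxyGraphTypeI::infer_output_attrs_fallible in the minigraph, and this body should never be reached. A hypothetical sketch of that "asserting placeholder" pattern — the class relationship here is illustrative only, not MegEngine's actual hierarchy:

```cpp
// Hypothetical illustration of an asserting placeholder kept only to satisfy
// an interface; the real logic lives elsewhere (here, in a derived class).
#include <cassert>
#include <cstdio>

struct LegacyPath {
    virtual ~LegacyPath() = default;
    virtual int infer_output_count() {
        assert(0 && "placeholder: superseded by the minigraph path");
        return 0;
    }
};

struct MinigraphPath : LegacyPath {
    int infer_output_count() override { return 1; }  // the live implementation
};

int main() {
    MinigraphPath path;
    std::printf("%d\n", path.infer_output_count());
}
```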