Commit a5af35c1 authored by Megvii Engine Team

refactor(imperative): remove command buffer

GitOrigin-RevId: 83c8cb6d3bed9b44b0424965fc7c4938b0ae5841
Parent: bdb853ee
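The change in a nutshell: every call site that used to stage commands in `m_buffer` (so a later `Del` could be fused into a pending `ApplyOp`) now submits directly to the worker queue, attaching a profiler id and a stack dump at the call site. Below is a minimal compilable sketch of the before/after shape, using hypothetical stand-ins for `Command`, `Profiler`, and the worker rather than MegEngine's real types:

```cpp
// Sketch only: hypothetical stand-ins, not MegEngine's real classes.
#include <cstdint>
#include <deque>
#include <string>
#include <utility>

struct Command { uint64_t id; std::string data; std::string trace; };

struct Worker {
    std::deque<Command> queue;
    void add_task(Command cmd) { queue.push_back(std::move(cmd)); }
};

struct Profiler { static uint64_t next_id() { static uint64_t i = 0; return ++i; } };

// Before: commands were staged in a buffer so that later commands (e.g. Del)
// could be fused into earlier ones, then flushed to the worker in batches.
struct BufferedChannel {
    Worker worker;
    std::deque<Command> buffer;
    void enqueue(std::string data) {
        buffer.push_back({Profiler::next_id(), std::move(data), "stack"});
        // fusion / flush-position logic lived here
    }
    void flush() {  // had to be called before every sync or wait
        for (auto& cmd : buffer) worker.add_task(std::move(cmd));
        buffer.clear();
    }
};

// After: every call site submits straight to the worker; there is no staging
// window, so no flush() is needed before sync_impl() or wait_tensor().
struct DirectChannel {
    Worker worker;
    void submit(std::string data) {
        worker.add_task({Profiler::next_id(), std::move(data), "stack"});
    }
};

int main() {
    BufferedChannel before;
    before.enqueue("ApplyOp");
    before.flush();
    DirectChannel after;
    after.submit("ApplyOp");  // visible to the worker immediately
    return static_cast<int>(before.worker.queue.size() + after.worker.queue.size());
}
```

With ids and stack dumps already attached per command, the buffer's only remaining job was `Del` fusion, which this commit gives up in exchange for an always-flushed pipeline.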
@@ -120,7 +120,6 @@ def enable():
     r"""Enable to record computing path of tensors and to perform DTR policy."""
     _set_option("enable_dtr_auto_drop", 1)
     _set_option("enable_drop", 1)
-    _set_option("buffer_length", 0)
     _set_option("record_computing_path", 1)
......
@@ -702,10 +702,6 @@ void init_tensor(py::module m) {
     });
     m.def("get_option",
           [channel](std::string name) { return channel->get_option(name); });
-    m.def("set_buffer_length", [channel](int length) {
-        mgb_assert(length >= 0 and length < 100, "buffer_length should be in [0, 100)");
-        channel->set_option("buffer_length", length);
-    });
     m.def("push_scope", [channel](std::string name) {
         Transformation::push_scope(name);
         channel->push_scope(name);
......
@@ -76,8 +76,6 @@ class XORNet(Module):
 def test_training_converge_with_drop():
     set_option("enable_drop", 1)
-    old_buffer_length = get_option("buffer_length")
-    set_option("buffer_length", 0)
     net = XORNet()
     opt = SGD(net.parameters(), lr=0.01, momentum=0.9, weight_decay=5e-4)
     gm = ad.GradManager().attach(net.parameters())
@@ -119,4 +117,3 @@ def test_training_converge_with_drop():
     )
     set_option("enable_drop", 0)
-    set_option("buffer_length", old_buffer_length)
@@ -9,6 +9,7 @@
 import numpy as np
 import pytest
+import megengine as mge
 import megengine.functional as F
 from megengine import Tensor, jit, random
 from megengine.core._imperative_rt import CompNode
@@ -209,9 +210,12 @@ def test_permutation_op():
         assert str(output.device) == str(cn)
         assert output.dtype == dtype
+    # FIXME: remove this sync
+    mge.core.set_option("async_level", 0)
     test_permutation_op_dtype(np.float32)
     test_permutation_op_dtype(np.int32)
     test_permutation_op_dtype(np.int16)
+    mge.core.set_option("async_level", 2)
 @pytest.mark.skipif(
......
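The test above pins `async_level` to 0 so each RNG op finishes (and raises) before its output is checked, then restores the default of 2. Below is a small sketch of how the three levels gate synchronization after a dispatch, following the level semantics documented in the interpreter header later in this diff; the types are hypothetical stand-ins, not MegEngine's real classes:

```cpp
// level 2: fully async; level 1: user-side errors sync; level 0: all sync.
#include <cstdio>

struct Options { int async_level = 2; };

struct Channel {
    Options options;
    void sync_impl() { std::puts("wait for all queued work"); }
    void dispatch(bool validated) {
        std::puts("task submitted to worker");
        if (!validated && options.async_level == 1) {
            sync_impl();   // surface user-side (shape/type) errors now
        } else if (options.async_level == 0) {
            sync_impl();   // fully synchronous: also wait on the device
            std::puts("device sync");
        }                  // level 2: return immediately
    }
};

int main() {
    Channel ch;
    ch.options.async_level = 0;   // what the RNG test pins around its checks
    ch.dispatch(/*validated=*/false);
    ch.options.async_level = 2;   // restore the default
    ch.dispatch(true);
}
```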
@@ -49,14 +49,12 @@ struct ApplyOp {
     std::shared_ptr<OpDef> op;
     SmallVector<TensorInfo*> inputs;
     SmallVector<TensorInfo*> outputs;
-    SmallVector<TensorInfo*> dels;
     template <typename TFunctor>
     void get_props(TFunctor&& functor) const {
         functor("op", op);
         functor("inputs", inputs);
         functor("outputs", outputs);
-        functor("dels", dels);
     }
     const char* get_name() const { return "ApplyOp"; }
......
@@ -156,7 +156,9 @@ TensorInfo* ChannelImpl::put_impl(const HostTensorND& value, bool no_cache) {
         info->desc.value = value.proxy_to_default_cpu();
     }
     info->mem_desc.id = StorageIdentifier::make(++m_storage_id);
-    m_buffer.enqueue(Put{info, value, no_cache});
+    m_worker.add_task(
+            {Profiler::next_id(), Put{info, value, no_cache},
+             get_channel_state().stack_manager.dump()});
     if (m_async_level == 0) {
         sync_impl();
         info->desc.comp_node.sync();
@@ -200,7 +202,8 @@ void ChannelImpl::del_impl(Handle handle) {
     mgb_assert(m_valid_handle.count(handle), "invalid handle: %p", handle);
     auto* info = reinterpret_cast<TensorInfo*>(handle);
     m_valid_handle.erase(handle);
-    m_buffer.enqueue(Del{info});
+    m_worker.add_task(
+            {Profiler::next_id(), Del{info}, get_channel_state().stack_manager.dump()});
 }
 void ChannelImpl::drop(Handle handle) {
@@ -212,7 +215,9 @@ void ChannelImpl::drop(Handle handle) {
                 m_valid_handle.find(handle) != m_valid_handle.end(),
                 "invalid handle: %p", handle);
         auto* info = reinterpret_cast<TensorInfo*>(handle);
-        m_buffer.enqueue(Drop{info});
+        m_worker.add_task(
+                {Profiler::next_id(), Drop{info},
+                 get_channel_state().stack_manager.dump()});
     }
 }
@@ -333,7 +338,9 @@ void ChannelImpl::dispatch_kernel(
     MGB_RECORD_EVENT(
             OpDispatchEvent, cmd.id, name, op_info_getter, tinfo_to_tid(cmd.inputs),
             tinfo_to_tid(cmd.outputs), state.stack_manager.dump());
-    m_buffer.enqueue(std::move(cmd));
+    m_worker.add_task(
+            {Profiler::next_id(), std::move(cmd),
+             get_channel_state().stack_manager.dump()});
     if (!validated && options.async_level == 1) {
         sync_impl();
     } else if (options.async_level == 0) {
@@ -466,7 +473,6 @@ void ChannelImpl::sync() {
 }
 void ChannelImpl::sync_impl() {
-    m_buffer.flush();
     m_worker.wait_all_task_finish();
     MGB_LOCK_GUARD(m_mutex);
     check_worker_exc_unsafe();
@@ -499,7 +505,9 @@ void ChannelImpl::set_option(std::string name, size_t value) {
     mgb_assert(check_available(), "Channel already closed");
     auto& state = get_channel_state();
     state.options.set_option(name, value);
-    m_buffer.enqueue(SetOption{name, value});
+    m_worker.add_task(
+            {Profiler::next_id(), SetOption{name, value},
+             get_channel_state().stack_manager.dump()});
 }
 void ChannelImpl::clear_candidates() {
@@ -604,7 +612,7 @@ void ChannelImpl::real_free(TensorInfo* ptr) {
     m_pool.free(ptr);
 }
-ChannelImpl::ChannelImpl() : m_worker(this), m_buffer(this) {}
+ChannelImpl::ChannelImpl() : m_worker(this) {}
 ChannelImpl::~ChannelImpl() {
     close();
@@ -645,7 +653,7 @@ void ChannelImpl::regenerate(TensorInfo* dest) {
     if (dest->evict_type == EvictType::DROP) {
         auto&& path = dest->producer;
         m_apply_stack.push(
-                {ApplyOp{path->id, path->op, path->inputs, path->outputs, {}}, 0, dest,
+                {ApplyOp{path->id, path->op, path->inputs, path->outputs}, 0, dest,
                  "dtr"});
         if (!m_applying)
             flush_apply_stack();
@@ -748,19 +756,6 @@ void ChannelImpl::do_apply_op(const ApplyOp& cmd, std::string reason) {
         MGB_RECORD_EVENT(TensorUsageEvent, input_id);
         MGB_RECORD_EVENT(OpInputFinishEvent, input_id);
     }
-    // Fused by command buffer. @see: CommandBuffer::fuse_del
-    // Now if dest is inplacable, its refcnt would be decreased to 1 and owned by
-    // tensor_inputs after Del. Note for exprs like 'y = x op x', inplace is unsupported
-    // yet but Del would be also fused.
-    for (auto* del : cmd.dels) {
-        // refcnt --, owners: [tensor_inputs]
-        // if it's decreased to 1, would be detected at @see:
-        // proxy_graph_detail::apply_on_physical_tensor
-        uint64_t del_id = del->id;
-        MGB_RECORD_EVENT(TensorCommandEvent, del_id, TensorCommandKind::Del);
-        free(del);
-        MGB_RECORD_EVENT(TensorCommandFinishEvent, del_id, TensorCommandKind::Del);
-    }
     // Before wait
     // TODO: split operator wait and execute so that OpWait could be correctly recorded.
     // Before execute
@@ -931,7 +926,6 @@ bool ChannelImpl::check_available() {
 }
 TensorPtr ChannelImpl::wait_tensor(TensorInfo* info, TensorProp prop) {
-    m_buffer.flush();
     std::unique_lock<decltype(m_mutex)> lock(m_mutex);
     mgb_assert(!m_waitee, "duplicate waitee");
     m_waitee = info;
@@ -943,8 +937,9 @@ TensorPtr ChannelImpl::wait_tensor(TensorInfo* info, TensorProp prop) {
     if (require_host && !host_available()) {
         // avoid dead lock
         lock.unlock();
-        m_buffer.enqueue(GetValue{info});
-        m_buffer.flush();
+        m_worker.add_task(
+                {Profiler::next_id(), GetValue{info},
+                 get_channel_state().stack_manager.dump()});
         lock.lock();
         wait_host = true;
     }
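Note the ordering in `wait_tensor` above: the channel mutex is released before the `GetValue` task is submitted and re-acquired afterwards, because the worker thread must take the same mutex to publish the value. A self-contained sketch of this unlock/submit/relock pattern, with hypothetical names:

```cpp
// Sketch of the deadlock-avoidance pattern: drop the lock before handing work
// to a thread that needs the same mutex, then relock and wait on a condition.
#include <condition_variable>
#include <mutex>
#include <thread>

std::mutex m;
std::condition_variable cv;
bool ready = false;

void worker() {
    std::lock_guard<std::mutex> lock(m);  // worker takes the same mutex
    ready = true;
    cv.notify_all();
}

int main() {
    std::unique_lock<std::mutex> lock(m);
    if (!ready) {
        lock.unlock();            // avoid deadlock: let the worker acquire m
        std::thread t(worker);    // submit the work
        lock.lock();              // re-acquire before waiting
        cv.wait(lock, [] { return ready; });  // wait() releases m while blocked
        t.join();
    }
}
```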
@@ -1266,141 +1261,25 @@ void ChannelImpl::check_worker_exc_unsafe() {
     }
 }
-void ChannelImpl::CommandBuffer::enqueue(CommandData cmd) {
-    auto& state = m_owner->get_channel_state();
-    if (std::get_if<Del>(&cmd) && fuse_del(std::get<Del>(cmd))) {
-        return;
-    }
-    m_commands.push_back(
-            {Profiler::next_id(), std::move(cmd), state.stack_manager.dump()});
-    auto flush_pos = flush_pos_for(m_commands.back());
-    flush(flush_pos);
-}
-void ChannelImpl::CommandBuffer::flush() {
-    flush(m_commands.end());
-}
-void ChannelImpl::CommandBuffer::flush(Handle pos) {
-    for (auto iter = m_commands.begin(); iter != pos; ++iter) {
-        if (Profiler::is_profiling()) {
-            mgb_log_debug("%s Flushed", to_string(*iter).c_str());
-        }
-        m_owner->m_worker.add_task(std::move(*iter));
-    }
-    m_commands.erase(m_commands.begin(), pos);
-}
-auto ChannelImpl::CommandBuffer::flush_pos_for(const Command& cmd) -> Handle {
-    auto& state = m_owner->get_channel_state();
-    return std::visit(
-            [this, &state](const auto& cmd) {
-                using T = std::decay_t<decltype(cmd)>;
-                if constexpr (std::is_same_v<T, ApplyOp>) {
-                    auto* op_type = cmd.op->dyn_typeinfo();
-                    if (op_type == RemoteRecv::typeinfo() ||
-                        op_type == RemoteSend::typeinfo() ||
-                        op_type == CollectiveComm::typeinfo() ||
-                        op_type == opr::InputCallback::typeinfo() ||
-                        op_type == opr::OutputCallback::typeinfo()) {
-                        return m_commands.end();
-                    }
-                } else if constexpr (std::is_same_v<T, GetValue>) {
-                    return m_commands.end();
-                }
-                size_t buffer_length = state.options.buffer_length;
-                if (m_commands.size() > buffer_length) {
-                    return m_commands.begin() + (m_commands.size() - buffer_length);
-                }
-                return m_commands.begin();
-            },
-            cmd.data);
-}
-/**
- * 1. Find ApplyOp(dest) in buffered commands
- * 2. Check if there are other usages between ApplyOp and Del, return false if so
- * 3. Fuse Del into ApplyOp, return true
- */
-bool ChannelImpl::CommandBuffer::fuse_del(const Del& cmd) {
-    auto* dest = cmd.dest;
-    // TODO: eliminate Puts
-    auto begin = m_commands.begin(), end = m_commands.end();
-    auto apply_iter = std::find_if(begin, end, [dest](const Command& cmd) {
-        if (auto* apply = std::get_if<ApplyOp>(&cmd.data)) {
-            return std::count(apply->inputs.begin(), apply->inputs.end(), dest) > 0;
-        }
-        return false;
-    });
-    if (apply_iter == end || find_last_usage(dest, {apply_iter + 1, end}) != end) {
-        return false;
-    }
-    std::get<ApplyOp>(apply_iter->data).dels.push_back(dest);
-    return true;
-}
-auto ChannelImpl::CommandBuffer::find_last_usage(TensorInfo* dest, Range range)
-        -> Handle {
-    auto found = range[1];
-    for (auto iter = range[0]; iter != range[1]; ++iter) {
-        std::visit(
-                [&](const auto& cmd) {
-                    using T = std::decay_t<decltype(cmd)>;
-                    if constexpr (std::is_same_v<T, ApplyOp>) {
-                        if (std::count(cmd.inputs.begin(), cmd.inputs.end(), dest) >
-                            0) {
-                            found = iter;
-                        }
-                    } else if constexpr (std::is_same_v<T, GetValue>) {
-                        if (cmd.dest == dest) {
-                            found = iter;
-                        }
-                    } else if constexpr (std::is_same_v<T, Drop>) {
-                        // TODO: ignore swap-like commands, just remove them from buffer
-                        if (cmd.dest == dest) {
-                            found = iter;
-                        }
-                    }
-                },
-                iter->data);
-    };
-    return found;
-}
-auto ChannelImpl::CommandBuffer::find_produce(TensorInfo* dest, Range range) -> Handle {
-    return std::find_if(range[0], range[1], [dest](auto& cmd) {
-        return std::visit(
-                [dest](const auto& cmd) {
-                    using T = std::decay_t<decltype(cmd)>;
-                    if constexpr (std::is_same_v<T, ApplyOp>) {
-                        return std::count(
-                                       cmd.outputs.begin(), cmd.outputs.end(), dest) >
-                               0;
-                    } else if constexpr (std::is_same_v<T, Put>) {
-                        return cmd.dest == dest;
-                    }
-                    return false;
-                },
-                cmd.data);
-    });
-}
 void ChannelImpl::start_profile() {
     MGB_LOCK_GUARD(m_spin);
     mgb_assert(check_available(), "Channel already closed");
     auto capture_tensors = collect_valid_tensors();
     if (capture_tensors.size() > 0) {
-        m_buffer.enqueue(StartProfile{std::move(capture_tensors)});
+        m_worker.add_task(
+                {Profiler::next_id(), StartProfile{std::move(capture_tensors)},
+                 get_channel_state().stack_manager.dump()});
     }
 }
 void ChannelImpl::stop_profile() {
     MGB_LOCK_GUARD(m_spin);
     mgb_assert(check_available(), "Channel already closed");
-    m_buffer.flush();
     auto escape_tensors = collect_valid_tensors();
     if (escape_tensors.size() > 0) {
-        m_buffer.enqueue(StopProfile{std::move(escape_tensors)});
+        m_worker.add_task(
+                {Profiler::next_id(), StopProfile{std::move(escape_tensors)},
+                 get_channel_state().stack_manager.dump()});
     }
 }
@@ -1410,7 +1289,9 @@ void ChannelImpl::push_scope(std::string name) {
     auto& state = get_channel_state();
     state.stack_manager.enter(name);
     MGB_RECORD_EVENT(ScopeEvent, name);
-    m_buffer.enqueue(PushScope{name});
+    m_worker.add_task(
+            {Profiler::next_id(), PushScope{name},
+             get_channel_state().stack_manager.dump()});
 }
 void ChannelImpl::pop_scope(std::string name) {
@@ -1419,7 +1300,9 @@ void ChannelImpl::pop_scope(std::string name) {
     auto& state = get_channel_state();
     state.stack_manager.exit(name);
     MGB_RECORD_EVENT(ScopeFinishEvent, name);
-    m_buffer.enqueue(PopScope{name});
+    m_worker.add_task(
+            {Profiler::next_id(), PopScope{name},
+             get_channel_state().stack_manager.dump()});
 }
 void ChannelImpl::assert_in_channel() {
......
@@ -126,11 +126,6 @@ private:
     void assert_in_worker();
     std::thread::id get_worker_tid();
-    // template <typename TCommand>
-    // void enqueue_command(TCommand&& cmd) {
-    //     m_buffer.enqueue(Command{std::forward<TCommand>(cmd)});
-    // }
     void sample_on_device(CompNode device, bool force);
     // valid => status != Deleted
@@ -178,46 +173,6 @@ private:
         ChannelImpl* m_owner;
     } m_worker;
-    /**
-     * Buffer a window of commands for subsequent fusion.
-     * example:
-     * ---------------------------------------------------------------------
-     * | ..., Apply{in: (i0, i1), out: (o0, o1)}, ... + Del{i0} + Del{i1}  |
-     * ---------------------------------------------------------------------
-     * | ..., Apply{in: (i0, i1), out: (o0, o1), del: (i0)}, ... + Del{i1} |
-     * ---------------------------------------------------------------------
-     * | ..., Apply{in: (i0, i1), out: (o0, o1), del: (i0, i1)}, ...       |
-     * ---------------------------------------------------------------------
-     * Then the fused Apply may be invoked inplace. see:
-     * ChannelImpl::process_one_task
-     */
-    struct CommandBuffer {
-        CommandBuffer(ChannelImpl* owner) : m_owner(owner) {}
-        void enqueue(CommandData cmd);
-        bool empty() const { return m_commands.empty(); }
-        void flush();
-    private:
-        ChannelImpl* m_owner;
-        std::deque<Command> m_commands;
-        using Handle = decltype(m_commands)::iterator;
-        // [begin, end)
-        using Range = std::array<Handle, 2>;
-        // Launch commands in range [m_commands.begin(), pos)
-        void flush(Handle pos);
-        // Select flush position for incoming cmd
-        Handle flush_pos_for(const Command& cmd);
-        // Fuse del command into a suitable ApplyOp
-        bool fuse_del(const Del& cmd);
-        // Returns the last position where dest is used within range. If dest is
-        // not used, returns range[1]
-        Handle find_last_usage(TensorInfo* dest, Range range);
-        // Returns the produce position of dest. If not found, returns range[1]
-        Handle find_produce(TensorInfo* dest, Range range);
-    } m_buffer;
     //! config whether to raise error exactly when invoking op.
     //! level 2: both device and user side errors are async;
     //! level 1: user side errors are sync;
......
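For reference, the deleted `CommandBuffer` kept a short window of commands so that a `Del` could be folded into the `ApplyOp` that last consumes the tensor, letting the op potentially run in place. A simplified, compilable sketch of that fusion scan, with hypothetical mini-types rather than MegEngine's:

```cpp
// Sketch of the removed Del-fusion logic; types are illustrative stand-ins.
#include <algorithm>
#include <deque>
#include <variant>
#include <vector>

struct TensorInfo { int id; };
struct ApplyOp {
    std::vector<TensorInfo*> inputs, outputs;
    std::vector<TensorInfo*> dels;  // tensors to free right after the op runs
};
struct Del { TensorInfo* dest; };
using Command = std::variant<ApplyOp, Del>;

// Fold a Del into the buffered ApplyOp that consumes the tensor; only legal
// when no later buffered command uses the tensor again.
bool fuse_del(std::deque<Command>& buffer, const Del& del) {
    auto iter = std::find_if(buffer.begin(), buffer.end(), [&](Command& cmd) {
        auto* apply = std::get_if<ApplyOp>(&cmd);
        return apply && std::count(apply->inputs.begin(), apply->inputs.end(),
                                   del.dest) > 0;
    });
    if (iter == buffer.end()) return false;
    for (auto later = iter + 1; later != buffer.end(); ++later) {
        if (auto* apply = std::get_if<ApplyOp>(&*later)) {
            if (std::count(apply->inputs.begin(), apply->inputs.end(),
                           del.dest) > 0)
                return false;  // still used later; cannot fuse
        }
    }
    std::get<ApplyOp>(*iter).dels.push_back(del.dest);
    return true;  // the op may now reuse dest's storage in place
}

int main() {
    TensorInfo x{0}, y{1};
    std::deque<Command> buffer{ApplyOp{{&x}, {&y}, {}}};
    return fuse_del(buffer, Del{&x}) ? 0 : 1;  // fuses: x dies at the op
}
```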
@@ -40,9 +40,6 @@ public:
     DEF_OPTION(
             catch_worker_execption, "MEGENGINE_CATCH_WORKER_EXEC", 1,
             "catch worker exception if enabled, close it when debugging");
-    DEF_OPTION(
-            buffer_length, "MEGENGINE_COMMAND_BUFFER_LENGTH", 3,
-            "set command buffer length.");
     DEF_OPTION(
             enable_host_compute, "MEGENGINE_HOST_COMPUTE", 1,
             "enable host compute, thus computation may be done in host even if it's "
......
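Each option above is declared through `DEF_OPTION(field, env_var, default, doc)`, so a setting can be overridden from the environment. Below is a hedged sketch of how such a macro-based registry can work; it is a hypothetical reimplementation of the pattern, not the real class, and it reuses only option names visible in this diff:

```cpp
// Sketch: a DEF_OPTION-style registry mapping field -> env override -> default.
#include <cstdlib>
#include <string>
#include <unordered_map>

struct Options {
    std::unordered_map<std::string, size_t*> slots;  // name -> field address

    size_t reg(const char* name, const char* env, size_t dflt, size_t* slot) {
        slots[name] = slot;
        const char* v = std::getenv(env);  // env var wins over the default
        return v ? std::strtoull(v, nullptr, 10) : dflt;
    }
    void set_option(const std::string& name, size_t value) { *slots.at(name) = value; }

// The doc string is accepted but unused in this sketch.
#define DEF_OPTION(name, env, dflt, doc) size_t name = reg(#name, env, dflt, &name);

    DEF_OPTION(catch_worker_execption, "MEGENGINE_CATCH_WORKER_EXEC", 1,
               "catch worker exception if enabled");
    DEF_OPTION(enable_host_compute, "MEGENGINE_HOST_COMPUTE", 1,
               "enable host compute");
#undef DEF_OPTION
};

int main() {
    Options opts;
    opts.set_option("enable_host_compute", 0);
    return static_cast<int>(opts.enable_host_compute);  // 0
}
```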
@@ -626,23 +626,12 @@ cg::OperatorNodeBase* ProxyGraph::get_proxy_opr(
 /*********************** Logical Tensor Impl ***********************/
-size_t ProxyGraph::get_opr_output_size(
-        const OpDef& opdef, const SmallVector<LogicalTensorDesc>& inputs) {
-    return get_proxy_opr(opdef, inputs)->usable_output().size();
-}
 std::tuple<SmallVector<LogicalTensorDesc>, bool> ProxyGraph::
         infer_output_attrs_fallible(
                 const OpDef& opdef, const SmallVector<LogicalTensorDesc>& inputs) {
-    auto opr = get_proxy_opr(opdef, inputs);
-    CUR_OPR_GUARD(opr);
-    SmallVector<LogicalTensorDesc> outputs;
-    bool validated = do_shape_infer(false);
-    for (auto&& i : opr->usable_output()) {
-        outputs.push_back({{i->shape(), i->dtype()}, i->comp_node()});
-    }
-    bool need_check = opr->same_type<opr::Reshape>();
-    return {outputs, validated && !need_check};
+    // this function is just a placeholder
+    // it will be overridden by ProxyGraphTypeI::infer_output_attrs_fallible in minigraph
+    mgb_assert(0);
 }
 std::tuple<SmallVector<MemoryDesc>, SmallVector<MemoryDesc>> ProxyGraph::
@@ -823,12 +812,6 @@ EncodedSubgraph ProxyGraph::make_backward_graph(
     return result;
 }
-cg::OperatorNodeBase* ProxyGraph::get_proxy_opr(
-        const OpDef& opdef, const SmallVector<LogicalTensorDesc>& inputs) {
-    mgb_assert(!m_cur_opr);
-    auto vinputs = make_input_place_holders(inputs);
-    return OpDef::apply_on_var_node(opdef, vinputs)[0]->owner_opr();
-}
 VarNodeArray ProxyGraph::make_input_place_holders(
         const SmallVector<LogicalTensorDesc>& inputs) {
......
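The deleted implementation above returned the inferred output descriptors together with a `validated` flag, and cleared that flag for `Reshape` because its output shape can depend on runtime values (`validated && !need_check`). A tiny sketch of that fallible-inference contract, with hypothetical stand-in types:

```cpp
// Sketch of the (descs, validated) contract; the inference rule is a placeholder.
#include <string>
#include <tuple>
#include <vector>

struct TensorDesc { std::vector<size_t> shape; };

std::tuple<std::vector<TensorDesc>, bool> infer_output_attrs_fallible(
        const std::string& op, const std::vector<TensorDesc>& inputs) {
    bool validated = true;                     // static shape inference succeeded
    std::vector<TensorDesc> outputs = inputs;  // placeholder inference rule
    bool need_check = (op == "Reshape");       // shape depends on a runtime value
    return {outputs, validated && !need_check};
}

int main() {
    auto [descs, validated] = infer_output_attrs_fallible("Reshape", {{{2, 3}}});
    (void)descs;
    return validated ? 1 : 0;  // Reshape: not validated, re-check at runtime
}
```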
@@ -85,9 +85,6 @@ private:
     /********************** Logical Tensor Helper **********************/
-    cg::OperatorNodeBase* get_proxy_opr(
-            const OpDef& opdef, const SmallVector<LogicalTensorDesc>& inputs);
     cg::VarNodeArray make_input_place_holders(
             const SmallVector<LogicalTensorDesc>& inputs);
......