提交 951ed476 编写于 作者: M Megvii Engine Team

fix(minigraph): supports varnode forwarding

GitOrigin-RevId: 4494106f0a34d05c039211969444946cc6197bae
上级 27d4c4b3
......@@ -20,6 +20,7 @@
#include "./proxy_graph_base.h"
#include <optional>
#include "megbrain/opr/utility.h"
#include "range/v3/all.hpp"
namespace mgb::imperative::proxy_graph {
......@@ -83,7 +84,7 @@ TensorAdaptor(T&) -> TensorAdaptor<T, void>;
template <typename T>
TensorAdaptor(T*) -> TensorAdaptor<T, void>;
SmallVector<Tensor*> to_raw_ptr_array(
inline SmallVector<Tensor*> to_raw_ptr_array(
const SmallVector<TensorPtr>& inputs, bool ensure_storage = true) {
SmallVector<Tensor*> ret;
for (auto&& i : inputs) {
......@@ -243,6 +244,13 @@ public:
vinputs[i] = opr_ref_keeper.back()->output(0);
}
auto ovars = OpDef::apply_on_var_node(opdef, vinputs);
if (!m_opr) {
// identity
mgb_assert(vinputs.size() == 1 && ovars.size() == 1);
mgb_assert(ovars[0] == vinputs[0]);
auto&& input = vinputs[0];
ovars[0] = opr::Identity::make(input).node();
}
mgb_assert(m_opr);
output_data.resize(m_opr->output().size());
for (auto* v : ovars) {
......@@ -343,7 +351,6 @@ public:
} else {
mgb_assert(j < outputs.size());
auto&& tensor = outputs[j];
auto&& layout = tensor->layout();
if (var->m_mem_plan.chunk().owner_var != var) {
tensor->assign_from_dev_tensor(
var->m_dev_tensor); // memory forwarding
......@@ -613,6 +620,7 @@ class ExecMiniGraph : public ProxyGraph::MiniGraph {
busy_oprs.pop_front();
return m_opr;
}
mgb_assert(false);
}
template <bool in_use>
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册