提交 951ed476 编写于 作者: M Megvii Engine Team

fix(minigraph): supports varnode forwarding

GitOrigin-RevId: 4494106f0a34d05c039211969444946cc6197bae
上级 27d4c4b3
...@@ -20,6 +20,7 @@ ...@@ -20,6 +20,7 @@
#include "./proxy_graph_base.h" #include "./proxy_graph_base.h"
#include <optional> #include <optional>
#include "megbrain/opr/utility.h"
#include "range/v3/all.hpp" #include "range/v3/all.hpp"
namespace mgb::imperative::proxy_graph { namespace mgb::imperative::proxy_graph {
...@@ -83,7 +84,7 @@ TensorAdaptor(T&) -> TensorAdaptor<T, void>; ...@@ -83,7 +84,7 @@ TensorAdaptor(T&) -> TensorAdaptor<T, void>;
template <typename T> template <typename T>
TensorAdaptor(T*) -> TensorAdaptor<T, void>; TensorAdaptor(T*) -> TensorAdaptor<T, void>;
SmallVector<Tensor*> to_raw_ptr_array( inline SmallVector<Tensor*> to_raw_ptr_array(
const SmallVector<TensorPtr>& inputs, bool ensure_storage = true) { const SmallVector<TensorPtr>& inputs, bool ensure_storage = true) {
SmallVector<Tensor*> ret; SmallVector<Tensor*> ret;
for (auto&& i : inputs) { for (auto&& i : inputs) {
...@@ -243,6 +244,13 @@ public: ...@@ -243,6 +244,13 @@ public:
vinputs[i] = opr_ref_keeper.back()->output(0); vinputs[i] = opr_ref_keeper.back()->output(0);
} }
auto ovars = OpDef::apply_on_var_node(opdef, vinputs); auto ovars = OpDef::apply_on_var_node(opdef, vinputs);
if (!m_opr) {
// identity
mgb_assert(vinputs.size() == 1 && ovars.size() == 1);
mgb_assert(ovars[0] == vinputs[0]);
auto&& input = vinputs[0];
ovars[0] = opr::Identity::make(input).node();
}
mgb_assert(m_opr); mgb_assert(m_opr);
output_data.resize(m_opr->output().size()); output_data.resize(m_opr->output().size());
for (auto* v : ovars) { for (auto* v : ovars) {
...@@ -343,7 +351,6 @@ public: ...@@ -343,7 +351,6 @@ public:
} else { } else {
mgb_assert(j < outputs.size()); mgb_assert(j < outputs.size());
auto&& tensor = outputs[j]; auto&& tensor = outputs[j];
auto&& layout = tensor->layout();
if (var->m_mem_plan.chunk().owner_var != var) { if (var->m_mem_plan.chunk().owner_var != var) {
tensor->assign_from_dev_tensor( tensor->assign_from_dev_tensor(
var->m_dev_tensor); // memory forwarding var->m_dev_tensor); // memory forwarding
...@@ -613,6 +620,7 @@ class ExecMiniGraph : public ProxyGraph::MiniGraph { ...@@ -613,6 +620,7 @@ class ExecMiniGraph : public ProxyGraph::MiniGraph {
busy_oprs.pop_front(); busy_oprs.pop_front();
return m_opr; return m_opr;
} }
mgb_assert(false);
} }
template <bool in_use> template <bool in_use>
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册