From 2f68aeb9b639ba3c16d1c1830555441c7820270f Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Fri, 25 Jun 2021 16:27:56 +0800 Subject: [PATCH] feat(imperative/jit): let trace support empty IO GitOrigin-RevId: 97a55242bfe4d23e447ac77842e213ecb21995ee --- imperative/src/impl/opr_utility.cpp | 11 ++- imperative/src/impl/proxy_graph.cpp | 6 +- src/core/impl/graph/var_node.cpp | 12 ++- src/core/include/megbrain/graph/var_node.h | 8 +- src/core/test/graph/misc.cpp | 92 ++++++++++++++++++++++ 5 files changed, 118 insertions(+), 11 deletions(-) diff --git a/imperative/src/impl/opr_utility.cpp b/imperative/src/impl/opr_utility.cpp index 4b6914bc8..7238f58c4 100644 --- a/imperative/src/impl/opr_utility.cpp +++ b/imperative/src/impl/opr_utility.cpp @@ -45,7 +45,10 @@ InputCallback::InputCallback(cg::ComputingGraph& graph, callback_t callback, if(m_use_static_shape){ mgb_assert(m_output_shape.ndim); } - add_output(None)->add_flag(VarNode::Flag::NO_SYS_MEM_ALLOC).dtype(dt); + add_output(None) + ->add_flag(VarNode::Flag::ALLOW_EMPTY_SHAPE) + .add_flag(VarNode::Flag::NO_SYS_MEM_ALLOC) + .dtype(dt); add_output(None) ->add_flag(VarNode::Flag::ALLOW_EMPTY_SHAPE) .add_flag(VarNode::Flag::NO_SYS_MEM_ALLOC) @@ -109,6 +112,11 @@ void InputCallback::scn_do_execute() { if (m_use_static_shape) { mgb_assert(dev_tensor.shape().eq_shape(m_output_shape)); } + if (dev_tensor.empty()) { + auto layout = dev_tensor.layout(); + layout.init_contiguous_stride(); + dev_tensor.reset(dev_tensor.storage(), layout); + } output(0)->reset_dev_tensor_from_tensor(dev_tensor); } @@ -172,6 +180,7 @@ cg::OperatorNodeBase::NodeProp* OutputCallback::do_make_node_prop() const { }; m_use_host_value = m_param.prefer_host_value && host_value_avail(); dep_types[0] = m_use_host_value ? NodeProp::DepType::HOST_VALUE : NodeProp::DepType::DEV_VALUE; + dep_types[0] |= NodeProp::DepType::VALUE_ALLOW_EMPTY; prop->reset_dep_type(input(), dep_types); return prop; } diff --git a/imperative/src/impl/proxy_graph.cpp b/imperative/src/impl/proxy_graph.cpp index 1bcd65917..7b5a02fdc 100644 --- a/imperative/src/impl/proxy_graph.cpp +++ b/imperative/src/impl/proxy_graph.cpp @@ -564,11 +564,7 @@ void ProxyGraph::init_output_tensor(const SmallVector& outputs) { mgb_assert(var->comp_node() == tensor->comp_node() && var->shape().eq_shape(layout) && var->dtype() == layout.dtype); - if (!tensor->layout().is_empty()) { - var->assign_dev_tensor_from_tensor(tensor->dev_tensor()); - } else { - var->m_dev_tensor.storage({var->comp_node()}); - } + var->assign_dev_tensor_from_tensor(tensor->dev_tensor()); ++ j; } chk.mem_alloc_status.set_from_owner_var(); diff --git a/src/core/impl/graph/var_node.cpp b/src/core/impl/graph/var_node.cpp index 7a4272fc7..4a2ab6925 100644 --- a/src/core/impl/graph/var_node.cpp +++ b/src/core/impl/graph/var_node.cpp @@ -361,9 +361,19 @@ VarNode& VarNode::reset_dev_tensor_from_tensor(const DeviceTensorND& value) { } void VarNode::assign_dev_tensor_from_tensor(const DeviceTensorND& value) { - mgb_assert(value.layout().is_contiguous() && + mgb_assert((value.layout().is_contiguous() || value.empty()) && m_dev_tensor.dtype() == value.dtype() && m_dev_tensor.format() == value.format()); + if (value.empty()) { + mgb_assert(value.shape_valid() && value.layout().is_empty()); + bool allow_empty = contain_flag(VarNode::Flag::ALLOW_EMPTY_SHAPE); + auto &&recv = owner_graph()->var_receiver_in_current_comp_seq(this); + mgb_throw_if(!allow_empty || !recv.is_empty_allowed(), + GraphError, + "assign empty tensor to var %s, but allowed=%d, receiver=%s", + cg::dump_var_info({this}).c_str(), + allow_empty, recv.to_string().c_str()); + } if (cg::is_static_var_shape(this)) { mgb_assert(shape().eq_shape(value.shape()), "shape mismatch for static inferrable var when setting dev " diff --git a/src/core/include/megbrain/graph/var_node.h b/src/core/include/megbrain/graph/var_node.h index 0b6190234..0ef2fdb6d 100644 --- a/src/core/include/megbrain/graph/var_node.h +++ b/src/core/include/megbrain/graph/var_node.h @@ -480,8 +480,8 @@ class VarNode final: public GraphNodeBase { * * \param src_var the var node to provide dev tensor, which must have * been initialized, and does not have to be in the same computing - * graph. Its value must be contiguous. It can also be placed on a - * different comp node. + * graph. Its value must be contiguous or empty. It can also be + * placed on a different comp node. * * \return whether memory forwarding succeeds; if false is returned, a * new tensor would be allocated and its value is copied from src @@ -495,8 +495,8 @@ class VarNode final: public GraphNodeBase { * This function should only be called by this var's owner operator and * this var must have NO_SYS_MEM_ALLOC flag * - * \param value the tensor to be used; it must be contiguous and be - * placed on the same comp node of this var. + * \param value the tensor to be used; it must be contiguous or empty + * and be placed on the same comp node of this var. */ VarNode& reset_dev_tensor_from_tensor(const DeviceTensorND &value); diff --git a/src/core/test/graph/misc.cpp b/src/core/test/graph/misc.cpp index 961465768..cf1963a5e 100644 --- a/src/core/test/graph/misc.cpp +++ b/src/core/test/graph/misc.cpp @@ -10,6 +10,7 @@ */ #include "megbrain/opr/io.h" +#include "megbrain/opr/basic_arith.h" #include "megbrain/opr/basic_arith_wrapper.h" #include "megbrain/opr/dnn/convolution.h" #include "megbrain/opr/utility.h" @@ -2336,4 +2337,95 @@ TEST(TestGraph, DynamicOutput) { MGB_ASSERT_TENSOR_NEAR(expect_spl_0_0, result_spl_0_0, 1e-4); } +namespace { +// used for test reset_dev_tensor_from_tensor +MGB_DEFINE_OPR_CLASS(MaybeEmptyTensorOpr, cg::SingleCNOperatorNodeBase)// { + DeviceTensorND m_dv; + + void init_output_comp_node() override { + output(0)->comp_node(m_dv.comp_node()); + comp_node(m_dv.comp_node()); + } + + void scn_do_execute() override { + output(0)->reset_dev_tensor_from_tensor(m_dv); + } + + void init_output_static_infer_desc() override { + using namespace cg::static_infer; + auto &&mgr = owner_graph()->static_infer_manager(); + mgr.register_shape_infer(output(0), + ShapeInferDesc::make_const(m_dv.shape())); + } + + public: + MaybeEmptyTensorOpr(ComputingGraph &graph, + const DeviceTensorND &dv, const OperatorNodeConfig &config): + Super(&graph, config, "", {}), m_dv{dv} { + add_output(None) + ->add_flag(cg::VarNode::Flag::NO_SYS_MEM_ALLOC) + .add_flag(VarNode::Flag::ALLOW_EMPTY_SHAPE) + .dtype(dv.dtype()); + } + + static SymbolVar make(ComputingGraph &graph, const DeviceTensorND &dv, + const OperatorNodeConfig &config = {}) { + return graph.insert_opr(std::make_unique( + graph, dv, config))->output(0); + } +}; + +} // anonymous namespace + +MGB_DYN_TYPE_OBJ_FINAL_IMPL(MaybeEmptyTensorOpr); + +TEST(TestMemReuse, ResetEmptyDevTensor) { + // reciver opr allow empty tensor as input + auto allow_empty = [](const TensorShape& inp_shp) { + HostTensorGenerator<> gen; + auto g = ComputingGraph::make(); + auto host_x1 = gen(inp_shp), + host_x2 = gen(inp_shp); + DeviceTensorND dev_x1, dev_x2; + dev_x1.copy_from(*host_x1), dev_x2.copy_from(*host_x2); + auto x1 = MaybeEmptyTensorOpr::make(*g, dev_x1, {"x1"}), + x2 = MaybeEmptyTensorOpr::make(*g, dev_x2, {"x2"}), + y = x1 + x2; + HostTensorND host_y; + auto func = g->compile({make_callback_copy(y, host_y)}); + auto &&recv = x1.node()->owner_graph()->var_receiver_in_current_comp_seq(x1.node()); + ASSERT_TRUE(recv.is_empty_allowed()); + ASSERT_NO_THROW(func->execute().wait()); + if (inp_shp.is_empty()) { + ASSERT_TRUE(host_y.empty()); + ASSERT_TRUE(host_y.shape().is_empty()); + } + }; + + // reciver opr do not allow empty tensor as input + auto forbid_empty = [](const TensorShape& inp_shp) { + HostTensorGenerator<> gen; + auto g = ComputingGraph::make(); + auto host_x = gen(inp_shp); + DeviceTensorND dev_x; + dev_x.copy_from(*host_x); + auto x = MaybeEmptyTensorOpr::make(*g, dev_x, {"x"}), + y = opr::Reduce::make(x, {opr::Reduce::Mode::MAX, 0}); + HostTensorND host_y; + auto func = g->compile({make_callback_copy(y, host_y)}); + auto &&recv = x.node()->owner_graph()->var_receiver_in_current_comp_seq(x.node()); + ASSERT_TRUE(!recv.is_empty_allowed()); + if (inp_shp.is_empty()) { + ASSERT_ANY_THROW(func->execute().wait()); + } else { + ASSERT_NO_THROW(func->execute().wait()); + } + }; + + allow_empty({2, 3, 4, 5}); + allow_empty({2, 0, 3, 4}); + forbid_empty({4, 5, 6, 7}); + forbid_empty({8, 0, 0, 9}); +} + // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} -- GitLab