From 2afceb41879ea29a3f893b5da700cf610c545ddc Mon Sep 17 00:00:00 2001
From: Megvii Engine Team
Date: Thu, 30 Jul 2020 00:31:29 +0800
Subject: [PATCH] fix(mgb/atlas): use dyn output alloc if enable dynamic
 batchsize

GitOrigin-RevId: 45a6c6ad518de9172101fef4003988b24e86b1a3
---
 src/core/impl/graph/var_node.cpp           | 4 ++--
 src/core/impl/graph/var_node_mem_mgr.cpp   | 9 +++++++--
 src/core/impl/graph/var_node_mem_mgr.h     | 8 +++++++-
 src/core/include/megbrain/graph/var_node.h | 4 +++-
 4 files changed, 19 insertions(+), 6 deletions(-)

diff --git a/src/core/impl/graph/var_node.cpp b/src/core/impl/graph/var_node.cpp
index 6cfae871a..e6a5a1ff5 100644
--- a/src/core/impl/graph/var_node.cpp
+++ b/src/core/impl/graph/var_node.cpp
@@ -307,7 +307,7 @@ VarNode& VarNode::shape(const TensorShape &shape) {
     return *this;
 }
 
-VarNode& VarNode::shape_alloc(const TensorShape &shape) {
+VarNode& VarNode::shape_alloc(const TensorShape &shape, size_t size_req) {
     mgb_assert(shape.ndim, "got empty shape in shape_alloc: "
             "var=%s owner_opr=%s{%s}",
             cname(), owner_opr()->cname(), owner_opr()->dyn_typeinfo()->name);
@@ -316,7 +316,7 @@
             " NO_SYS_MEM_ALLOC flag; actual var: %s",
             cg::dump_var_info({this}).c_str());
     ComputingGraphImpl::downcast(owner_graph())
-            ->var_node_mem_manager().var_alloc_with_shape(this, shape);
+            ->var_node_mem_manager().var_alloc_with_shape(this, shape, size_req);
     return *this;
 }
 
diff --git a/src/core/impl/graph/var_node_mem_mgr.cpp b/src/core/impl/graph/var_node_mem_mgr.cpp
index 381436b08..3e8bc615e 100644
--- a/src/core/impl/graph/var_node_mem_mgr.cpp
+++ b/src/core/impl/graph/var_node_mem_mgr.cpp
@@ -1239,13 +1239,18 @@ void VarNodeMemManager::make_dev_tensor_from_mem_plan_single(
 }
 
 void VarNodeMemManager::var_alloc_with_shape(VarNode* var,
-                                             const TensorShape& shape) {
+                                             const TensorShape& shape,
+                                             size_t size_req) {
     mgb_assert(var->format().is_default(),
                "dynamic shape is currently only supported for var with "
                "default format; got %s",
                var->format().to_string().c_str());
     var->shape(shape);
-    auto size_req = var->dtype().size(shape.total_nr_elems());
+    if (size_req != 0) {
+        mgb_assert(var->dtype().size(shape.total_nr_elems()) <= size_req);
+    } else {
+        size_req = var->dtype().size(shape.total_nr_elems());
+    }
 
     auto&& mplan = var->m_mem_plan;
     if (!mplan.valid() || mplan.chunk().owner_var != var)
diff --git a/src/core/impl/graph/var_node_mem_mgr.h b/src/core/impl/graph/var_node_mem_mgr.h
index 2f2e99717..00e19cd63 100644
--- a/src/core/impl/graph/var_node_mem_mgr.h
+++ b/src/core/impl/graph/var_node_mem_mgr.h
@@ -294,7 +294,13 @@ class VarNodeMemManager {
     void add_layout_constraint_level(
             VarNode *dest, LayoutConstraintLevel level);
 
-    void var_alloc_with_shape(VarNode *var, const TensorShape &shape);
+    /**
+     * \brief alloc var memory with shape.
+     *
+     * Alloc memory of size_req if size_req != 0.
+     */
+    void var_alloc_with_shape(VarNode* var, const TensorShape& shape,
+                              size_t size_req = 0);
 
     /*!
      * \brief initialize mem plan for a single var
diff --git a/src/core/include/megbrain/graph/var_node.h b/src/core/include/megbrain/graph/var_node.h
index f9a5d4d26..31e4fef2b 100644
--- a/src/core/include/megbrain/graph/var_node.h
+++ b/src/core/include/megbrain/graph/var_node.h
@@ -462,8 +462,10 @@ class VarNode final: public GraphNodeBase<VarNode> {
      * this var must have NO_SYS_MEM_ALLOC flag; if shape does not increase
      * and original tensor storage is valid, it is guaranteed that old data
      * would be retained.
+     *
+     * \warning Alloc size_req memory if size_req != 0.
      */
-    VarNode& shape_alloc(const TensorShape &shape);
+    VarNode& shape_alloc(const TensorShape &shape, size_t size_req = 0);
 
     /*!
      * \brief directly reset device tensor from another var
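Context for reviewers: the new size_req parameter lets a caller reserve more device
memory than the current shape strictly needs; with size_req != 0 the memory manager
allocates size_req bytes and asserts that the bytes implied by the shape do not exceed
it. This is what enables dynamic batch size on Atlas: the chunk is sized once for the
maximum batch while each run sets the actual (smaller) shape. The sketch below is
illustrative only and not part of this patch; the function and variable names
(alloc_dyn_batch_output, max_batch, cur_batch) are hypothetical, and only
VarNode::shape_alloc(), TensorShape::total_nr_elems() and DType::size() come from the
codebase.

    // Hypothetical caller sketch -- not part of this patch.
    #include "megbrain/graph/var_node.h"

    using mgb::TensorShape;
    using mgb::cg::VarNode;

    // Allocate the output of a dynamic-batch operator: the tensor shape
    // reflects the current batch, but the underlying chunk is sized for
    // the maximum batch so any batch up to that bound fits in it.
    void alloc_dyn_batch_output(VarNode* out, const TensorShape& cur_shape,
                                size_t max_batch, size_t cur_batch) {
        // bytes needed if the batch dimension were at its maximum
        size_t max_bytes = out->dtype().size(
                cur_shape.total_nr_elems() / cur_batch * max_batch);
        // shape_alloc(shape, size_req): with size_req != 0, the chunk is
        // allocated with max_bytes, and var_alloc_with_shape() asserts that
        // cur_shape's byte size does not exceed it (see the .cpp hunk above).
        out->shape_alloc(cur_shape, max_bytes);
    }

A design note on the 0 default: size_req == 0 preserves the old behavior (allocate
exactly what the shape requires), so every existing call site compiles and behaves
unchanged, which is why the parameter is appended with a default rather than added
as an overload.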