Commit 2afceb41 authored by Megvii Engine Team

fix(mgb/atlas): use dyn output alloc if dynamic batchsize is enabled

GitOrigin-RevId: 45a6c6ad518de9172101fef4003988b24e86b1a3
Parent 6bcc6fae
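For context, the new size_req parameter lets a caller ask for a larger allocation than the current shape implies. That is what dynamic batch size on Atlas needs: the output buffer is allocated once for the maximum batch, while each run publishes its actual (possibly smaller) shape. A minimal standalone sketch of the size-resolution rule follows; resolve_alloc_size and the byte counts are illustrative names, not part of this patch.

```cpp
#include <cassert>
#include <cstddef>

// Sketch (not MegEngine code): pick the number of bytes to allocate for a
// var, mirroring the logic added to var_alloc_with_shape() below.
//   shape_bytes -- bytes implied by dtype and shape
//   size_req    -- caller-requested size; 0 means "derive from the shape"
std::size_t resolve_alloc_size(std::size_t shape_bytes, std::size_t size_req) {
    if (size_req != 0) {
        // A requested buffer must be large enough to hold the tensor.
        assert(shape_bytes <= size_req);
        return size_req;
    }
    return shape_bytes;  // default: allocate exactly what the shape needs
}

int main() {
    // Dynamic batch: buffer sized for the max batch while the current run
    // only uses a smaller batch; the numbers are illustrative.
    std::size_t cur_batch_bytes = 4 * 1024;   // bytes of the actual shape
    std::size_t max_batch_bytes = 32 * 1024;  // bytes needed at max batch
    assert(resolve_alloc_size(cur_batch_bytes, max_batch_bytes) == 32 * 1024);
    assert(resolve_alloc_size(cur_batch_bytes, 0) == 4 * 1024);
    return 0;
}
```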
@@ -307,7 +307,7 @@ VarNode& VarNode::shape(const TensorShape &shape) {
     return *this;
 }
 
-VarNode& VarNode::shape_alloc(const TensorShape &shape) {
+VarNode& VarNode::shape_alloc(const TensorShape &shape, size_t size_req) {
     mgb_assert(shape.ndim, "got empty shape in shape_alloc: "
             "var=%s owner_opr=%s{%s}", cname(), owner_opr()->cname(),
             owner_opr()->dyn_typeinfo()->name);
@@ -316,7 +316,7 @@ VarNode& VarNode::shape_alloc(const TensorShape &shape) {
             " NO_SYS_MEM_ALLOC flag; actual var: %s",
             cg::dump_var_info({this}).c_str());
     ComputingGraphImpl::downcast(owner_graph())
-            ->var_node_mem_manager().var_alloc_with_shape(this, shape);
+            ->var_node_mem_manager().var_alloc_with_shape(this, shape, size_req);
     return *this;
 }
......
@@ -1239,13 +1239,18 @@ void VarNodeMemManager::make_dev_tensor_from_mem_plan_single(
 }
 
 void VarNodeMemManager::var_alloc_with_shape(VarNode* var,
-                                             const TensorShape& shape) {
+                                             const TensorShape& shape,
+                                             size_t size_req) {
     mgb_assert(var->format().is_default(),
                "dynamic shape is currently only supported for var with "
                "default format; got %s",
                var->format().to_string().c_str());
     var->shape(shape);
-    auto size_req = var->dtype().size(shape.total_nr_elems());
+    if (size_req != 0) {
+        mgb_assert(var->dtype().size(shape.total_nr_elems()) <= size_req);
+    } else {
+        size_req = var->dtype().size(shape.total_nr_elems());
+    }
 
     auto&& mplan = var->m_mem_plan;
     if (!mplan.valid() || mplan.chunk().owner_var != var)
......
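Note on the contract above: a nonzero size_req must be at least the byte size implied by the dtype and shape (the new mgb_assert enforces this), so a buffer sized for the maximum batch can always hold a smaller one. Passing size_req = 0 falls back to the old behavior of allocating exactly what the shape needs, leaving existing callers unchanged.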
@@ -294,7 +294,13 @@ class VarNodeMemManager {
         void add_layout_constraint_level(
                 VarNode *dest, LayoutConstraintLevel level);
 
-        void var_alloc_with_shape(VarNode *var, const TensorShape &shape);
+        /**
+         * \brief alloc var memory with shape
+         *
+         * Alloc memory of size_req bytes if size_req != 0.
+         */
+        void var_alloc_with_shape(VarNode* var, const TensorShape& shape,
+                                  size_t size_req = 0);
 
         /*!
          * \brief initialize mem plan for a single var
......
@@ -462,8 +462,10 @@ class VarNode final: public GraphNodeBase {
      * this var must have NO_SYS_MEM_ALLOC flag; if shape does not increase
      * and original tensor storage is valid, it is guaranteed that old data
      * would be retained.
+     *
+     * \warning Alloc size_req bytes if size_req != 0.
      */
-    VarNode& shape_alloc(const TensorShape &shape);
+    VarNode& shape_alloc(const TensorShape &shape, size_t size_req = 0);
 
     /*!
      * \brief directly reset device tensor from another var
......
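The \warning added to shape_alloc combines with the existing retention guarantee: if the first call reserves size_req bytes for the maximum batch, later calls with smaller shapes keep the same storage and thus the old data. A standalone sketch of that grow-only behavior, under stated assumptions (GrowOnlyBuffer is illustrative, not MegEngine's allocator):

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Sketch: a grow-only buffer modeling why shape_alloc(shape, max_bytes)
// suits dynamic batch size. After reserving max_bytes once, later calls
// with smaller shapes reuse the same storage, so no per-run reallocation.
struct GrowOnlyBuffer {
    std::vector<unsigned char> storage;
    void alloc(std::size_t shape_bytes, std::size_t size_req = 0) {
        std::size_t need = size_req ? size_req : shape_bytes;
        assert(shape_bytes <= need);  // same contract as the patch
        if (need > storage.size())
            storage.resize(need);  // reallocate only when growing
    }
};

int main() {
    GrowOnlyBuffer buf;
    buf.alloc(4 * 1024, 32 * 1024);  // first run: reserve for the max batch
    auto* data = buf.storage.data();
    buf.alloc(8 * 1024);             // later, larger batch: still fits
    assert(buf.storage.data() == data && buf.storage.size() == 32 * 1024);
    return 0;
}
```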