From 968f74ce8856ce2a7eb4504b6645254874c78ecb Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Mon, 31 Aug 2020 16:47:28 +0800 Subject: [PATCH] chore(mgb): add no_force_inplace option to ComputingGraph GitOrigin-RevId: 350c90fb8613217b2fbc07134b3a65c3341efd94 --- imperative/src/impl/proxy_graph.cpp | 1 + src/core/include/megbrain/graph/cg.h | 11 +++++++++++ src/opr/impl/dnn/batch_norm.cpp | 2 +- 3 files changed, 13 insertions(+), 1 deletion(-) diff --git a/imperative/src/impl/proxy_graph.cpp b/imperative/src/impl/proxy_graph.cpp index a449edff..f04a7119 100644 --- a/imperative/src/impl/proxy_graph.cpp +++ b/imperative/src/impl/proxy_graph.cpp @@ -376,6 +376,7 @@ public: ProxyGraphImpl(ProxyGraph* owner) : m_owner(owner) { options().imperative_proxy_graph = true; + options().no_force_inplace = true; options().log_level = 0; m_var_receiver_info.dev_value = 1; m_var_receiver_info.allow_empty_value = 1; diff --git a/src/core/include/megbrain/graph/cg.h b/src/core/include/megbrain/graph/cg.h index 7949b1c6..41c702ed 100644 --- a/src/core/include/megbrain/graph/cg.h +++ b/src/core/include/megbrain/graph/cg.h @@ -464,6 +464,17 @@ class ComputingGraph : public std::enable_shared_from_this, bool imperative_proxy_graph = false; + /*! + * Request that operators should not force update their inputs. + * + * THIS FLAG IS RESERVED FOR INTERNAL USE + * + * When this flag is set, operators like AddUpdate and BatchNorm + * will still attempt to inplace update their inputs, but failing + * to do so will not be considered as an error. + */ + bool no_force_inplace = false; + //! add extra deps for the comp seq if a specific var is dependent ThinHashMap extra_vardeps; diff --git a/src/opr/impl/dnn/batch_norm.cpp b/src/opr/impl/dnn/batch_norm.cpp index a8daf566..6461573b 100644 --- a/src/opr/impl/dnn/batch_norm.cpp +++ b/src/opr/impl/dnn/batch_norm.cpp @@ -40,7 +40,7 @@ BatchNormForward::BatchNormForward(VarNode *x, Super{x->owner_graph(), config, "batch_norm", {x, scale, bias, mean, variance}} { - if(owner_graph()->options().imperative_proxy_graph) { + if(owner_graph()->options().no_force_inplace) { m_force_inplace = false; } -- GitLab