diff --git a/imperative/src/impl/proxy_graph.cpp b/imperative/src/impl/proxy_graph.cpp index a449edffa61721737a4e583aaa676d7d3f90458f..f04a71190bcbd0a49ff28daf8b2604cb1d3a965b 100644 --- a/imperative/src/impl/proxy_graph.cpp +++ b/imperative/src/impl/proxy_graph.cpp @@ -376,6 +376,7 @@ public: ProxyGraphImpl(ProxyGraph* owner) : m_owner(owner) { options().imperative_proxy_graph = true; + options().no_force_inplace = true; options().log_level = 0; m_var_receiver_info.dev_value = 1; m_var_receiver_info.allow_empty_value = 1; diff --git a/src/core/include/megbrain/graph/cg.h b/src/core/include/megbrain/graph/cg.h index 7949b1c66ab1b67dde05f0ff67db15cce0f3a7fa..41c702eda3407abec789677f285e73e4951de577 100644 --- a/src/core/include/megbrain/graph/cg.h +++ b/src/core/include/megbrain/graph/cg.h @@ -464,6 +464,17 @@ class ComputingGraph : public std::enable_shared_from_this, bool imperative_proxy_graph = false; + /*! + * Request that operators should not force update their inputs. + * + * THIS FLAG IS RESERVED FOR INTERNAL USE + * + * When this flag is set, operators like AddUpdate and BatchNorm + * will still attempt to inplace update their inputs, but failing + * to do so will not be considered as an error. + */ + bool no_force_inplace = false; + //! add extra deps for the comp seq if a specific var is dependent ThinHashMap extra_vardeps; diff --git a/src/opr/impl/dnn/batch_norm.cpp b/src/opr/impl/dnn/batch_norm.cpp index a8daf566670f66709fd46d02e26f682ad03c8ab6..6461573bb38f0043c0774d711eb27d9086f08ba1 100644 --- a/src/opr/impl/dnn/batch_norm.cpp +++ b/src/opr/impl/dnn/batch_norm.cpp @@ -40,7 +40,7 @@ BatchNormForward::BatchNormForward(VarNode *x, Super{x->owner_graph(), config, "batch_norm", {x, scale, bias, mean, variance}} { - if(owner_graph()->options().imperative_proxy_graph) { + if(owner_graph()->options().no_force_inplace) { m_force_inplace = false; }