From d712ac0da0496216e063f8bd4ea82015df751bfc Mon Sep 17 00:00:00 2001 From: limingqi107 Date: Tue, 8 Sep 2020 19:18:31 +0800 Subject: [PATCH] add count of graphs using the parameter --- mindspore/ccsrc/backend/session/session_basic.cc | 2 ++ mindspore/ccsrc/runtime/device/kernel_runtime.cc | 9 ++++++++- mindspore/core/ir/anf.h | 8 +++++++- 3 files changed, 17 insertions(+), 2 deletions(-) diff --git a/mindspore/ccsrc/backend/session/session_basic.cc b/mindspore/ccsrc/backend/session/session_basic.cc index ea918c15c..1de96c166 100644 --- a/mindspore/ccsrc/backend/session/session_basic.cc +++ b/mindspore/ccsrc/backend/session/session_basic.cc @@ -469,6 +469,7 @@ ParameterPtr SessionBasic::CreateNewParameterFromParameter(const AnfNodePtr &anf } TraceManager::EndTrace(); } + new_parameter->IncreaseUsedGraphCount(); graph_inputs->push_back(new_parameter); valid_inputs->push_back(true); return new_parameter; @@ -812,6 +813,7 @@ ParameterPtr SessionBasic::CreateNewParameter(const AnfNodePtr &anf, KernelGraph } TraceManager::EndTrace(); } + new_parameter->IncreaseUsedGraphCount(); return new_parameter; } diff --git a/mindspore/ccsrc/runtime/device/kernel_runtime.cc b/mindspore/ccsrc/runtime/device/kernel_runtime.cc index d3cd2cda9..74809a332 100644 --- a/mindspore/ccsrc/runtime/device/kernel_runtime.cc +++ b/mindspore/ccsrc/runtime/device/kernel_runtime.cc @@ -803,11 +803,18 @@ void KernelRuntime::ClearOutputAddress(const std::vector &inputs, if (!input_node->isa()) { continue; } + auto parameter = input_node->cast(); + MS_EXCEPTION_IF_NULL(parameter); + parameter->DecreaseUsedGraphCount(); + // Only the parameter has no graph used, then clear the output address. + if (parameter->used_graph_count() != 0) { + continue; + } for (size_t index = 0; index < AnfAlgo::GetOutputTensorNum(input_node); ++index) { if (!AnfAlgo::OutputAddrExist(input_node, index)) { continue; } - AnfAlgo::SetOutputAddr(nullptr, 0, input_node.get()); + AnfAlgo::SetOutputAddr(nullptr, index, input_node.get()); } } // clear input value node output address. diff --git a/mindspore/core/ir/anf.h b/mindspore/core/ir/anf.h index b544beaa5..cf4567539 100644 --- a/mindspore/core/ir/anf.h +++ b/mindspore/core/ir/anf.h @@ -282,7 +282,7 @@ class ANode : public AnfNode { class Parameter : public ANode { public: explicit Parameter(const FuncGraphPtr &func_graph) - : ANode(func_graph), name_(""), has_default_(false), default_param_(nullptr) {} + : ANode(func_graph), name_(""), has_default_(false), default_param_(nullptr), used_graph_count_(0) {} ~Parameter() override = default; MS_DECLARE_PARENT(Parameter, ANode); @@ -300,6 +300,10 @@ class Parameter : public ANode { ValuePtr default_param() const { return default_param_; } ParamInfoPtr param_info() const; + void IncreaseUsedGraphCount() { used_graph_count_++; } + void DecreaseUsedGraphCount() { used_graph_count_--; } + int used_graph_count() const { return used_graph_count_; } + bool operator==(const AnfNode &other) const override { if (!other.isa()) { return false; @@ -315,6 +319,8 @@ class Parameter : public ANode { std::string name_; bool has_default_; ValuePtr default_param_; + // The count of graphs using the parameter. + int used_graph_count_; }; using ParameterPtr = std::shared_ptr; -- GitLab