diff --git a/mindspore/ccsrc/backend/session/session_basic.cc b/mindspore/ccsrc/backend/session/session_basic.cc index ea918c15c08f7f4bd7eee02e452a6956bece1b9f..1de96c166c2ab6c3ad9b31e66ffaae726c5faf67 100644 --- a/mindspore/ccsrc/backend/session/session_basic.cc +++ b/mindspore/ccsrc/backend/session/session_basic.cc @@ -469,6 +469,7 @@ ParameterPtr SessionBasic::CreateNewParameterFromParameter(const AnfNodePtr &anf } TraceManager::EndTrace(); } + new_parameter->IncreaseUsedGraphCount(); graph_inputs->push_back(new_parameter); valid_inputs->push_back(true); return new_parameter; @@ -812,6 +813,7 @@ ParameterPtr SessionBasic::CreateNewParameter(const AnfNodePtr &anf, KernelGraph } TraceManager::EndTrace(); } + new_parameter->IncreaseUsedGraphCount(); return new_parameter; } diff --git a/mindspore/ccsrc/runtime/device/kernel_runtime.cc b/mindspore/ccsrc/runtime/device/kernel_runtime.cc index d3cd2cda98bc8d823ba63729dbf5e65222c647e5..74809a3329f8858a62a8bfb22abdf5f20c0099fd 100644 --- a/mindspore/ccsrc/runtime/device/kernel_runtime.cc +++ b/mindspore/ccsrc/runtime/device/kernel_runtime.cc @@ -803,11 +803,18 @@ void KernelRuntime::ClearOutputAddress(const std::vector &inputs, if (!input_node->isa()) { continue; } + auto parameter = input_node->cast(); + MS_EXCEPTION_IF_NULL(parameter); + parameter->DecreaseUsedGraphCount(); + // Only the parameter has no graph used, then clear the output address. + if (parameter->used_graph_count() != 0) { + continue; + } for (size_t index = 0; index < AnfAlgo::GetOutputTensorNum(input_node); ++index) { if (!AnfAlgo::OutputAddrExist(input_node, index)) { continue; } - AnfAlgo::SetOutputAddr(nullptr, 0, input_node.get()); + AnfAlgo::SetOutputAddr(nullptr, index, input_node.get()); } } // clear input value node output address. diff --git a/mindspore/core/ir/anf.h b/mindspore/core/ir/anf.h index b544beaa5e52ccf28cc59dc8fb01791bb816e2a6..cf4567539fa0c6be7f7edfbbb7cb79439a3dc96b 100644 --- a/mindspore/core/ir/anf.h +++ b/mindspore/core/ir/anf.h @@ -282,7 +282,7 @@ class ANode : public AnfNode { class Parameter : public ANode { public: explicit Parameter(const FuncGraphPtr &func_graph) - : ANode(func_graph), name_(""), has_default_(false), default_param_(nullptr) {} + : ANode(func_graph), name_(""), has_default_(false), default_param_(nullptr), used_graph_count_(0) {} ~Parameter() override = default; MS_DECLARE_PARENT(Parameter, ANode); @@ -300,6 +300,10 @@ class Parameter : public ANode { ValuePtr default_param() const { return default_param_; } ParamInfoPtr param_info() const; + void IncreaseUsedGraphCount() { used_graph_count_++; } + void DecreaseUsedGraphCount() { used_graph_count_--; } + int used_graph_count() const { return used_graph_count_; } + bool operator==(const AnfNode &other) const override { if (!other.isa()) { return false; @@ -315,6 +319,8 @@ class Parameter : public ANode { std::string name_; bool has_default_; ValuePtr default_param_; + // The count of graphs using the parameter. + int used_graph_count_; }; using ParameterPtr = std::shared_ptr;