提交 24f00cc6 编写于 作者: M mindspore-ci-bot 提交者: Gitee

!5617 clear graph output address in graph destructor

Merge pull request !5617 from limingqi107/master
...@@ -46,7 +46,7 @@ const AnfNodePtr RemoveFormatTransformPair::Process(const FuncGraphPtr &graph, c ...@@ -46,7 +46,7 @@ const AnfNodePtr RemoveFormatTransformPair::Process(const FuncGraphPtr &graph, c
MS_LOG(EXCEPTION) << "The pattern is not transpose pair, " MS_LOG(EXCEPTION) << "The pattern is not transpose pair, "
<< "node:" << AnfAlgo::GetCNodeName(node) << " node input:" << AnfAlgo::GetCNodeName(input_node); << "node:" << AnfAlgo::GetCNodeName(node) << " node input:" << AnfAlgo::GetCNodeName(input_node);
} }
// If transpose operator used by more than one other operators, it cant not be deleted directly. // If transpose operator used by more than one other operators, it cant not be deleted directly.
if (IsUsedByOthers(graph, input_node)) { if (IsUsedByOthers(graph, input_node)) {
MS_LOG(DEBUG) << "The transpose node [" << input_node->fullname_with_scope() MS_LOG(DEBUG) << "The transpose node [" << input_node->fullname_with_scope()
<< "] is used by more than one other operators."; << "] is used by more than one other operators.";
......
...@@ -397,8 +397,8 @@ void GPUKernelRuntime::ReleaseDeviceRes() { ...@@ -397,8 +397,8 @@ void GPUKernelRuntime::ReleaseDeviceRes() {
bin_map->RemoveKernelCache(); bin_map->RemoveKernelCache();
} }
void GPUKernelRuntime::ClearGraphRuntimeResource(uint32_t graph_id, const std::vector<AnfNodePtr> &, void GPUKernelRuntime::ClearGraphRuntimeResource(uint32_t graph_id, const std::vector<AnfNodePtr> &inputs,
const std::unordered_set<ValueNodePtr> &, const std::unordered_set<ValueNodePtr> &value_nodes,
const std::vector<CNodePtr> &execution_order) { const std::vector<CNodePtr> &execution_order) {
MS_LOG(INFO) << "Clear graph:" << graph_id << " GPU runtime resource"; MS_LOG(INFO) << "Clear graph:" << graph_id << " GPU runtime resource";
// Release the kernel resource. // Release the kernel resource.
...@@ -409,6 +409,8 @@ void GPUKernelRuntime::ClearGraphRuntimeResource(uint32_t graph_id, const std::v ...@@ -409,6 +409,8 @@ void GPUKernelRuntime::ClearGraphRuntimeResource(uint32_t graph_id, const std::v
} }
kernel_mod->ReleaseResource(); kernel_mod->ReleaseResource();
} }
// Clear the output address of graph.
ClearOutputAddress(inputs, value_nodes, execution_order);
} }
void GPUKernelRuntime::AssignMemory(session::KernelGraph *graph) { void GPUKernelRuntime::AssignMemory(session::KernelGraph *graph) {
......
...@@ -854,6 +854,40 @@ void KernelRuntime::ClearGraphRuntimeResource(uint32_t graph_id, const std::vect ...@@ -854,6 +854,40 @@ void KernelRuntime::ClearGraphRuntimeResource(uint32_t graph_id, const std::vect
MS_LOG(INFO) << "Clear graph:" << graph_id << " runtime resource"; MS_LOG(INFO) << "Clear graph:" << graph_id << " runtime resource";
} }
void KernelRuntime::ClearOutputAddress(const std::vector<AnfNodePtr> &inputs,
const std::unordered_set<ValueNodePtr> &value_nodes,
const std::vector<CNodePtr> &execution_order) {
// clear input parameter output address.
for (const auto &input_node : inputs) {
MS_EXCEPTION_IF_NULL(input_node);
if (!input_node->isa<Parameter>()) {
continue;
}
for (size_t index = 0; index < AnfAlgo::GetOutputTensorNum(input_node); ++index) {
if (!AnfAlgo::OutputAddrExist(input_node, index)) {
continue;
}
AnfAlgo::SetOutputAddr(nullptr, 0, input_node.get());
}
}
// clear input value node output address.
for (const auto &value_node : value_nodes) {
if (!AnfAlgo::OutputAddrExist(value_node, 0)) {
continue;
}
AnfAlgo::SetOutputAddr(nullptr, 0, value_node.get());
}
// clear cnode output address.
for (const auto &cnode : execution_order) {
for (size_t index = 0; index < AnfAlgo::GetOutputTensorNum(cnode); ++index) {
if (!AnfAlgo::OutputAddrExist(cnode, index)) {
continue;
}
AnfAlgo::SetOutputAddr(nullptr, index, cnode.get());
}
}
}
bool KernelRuntime::LaunchTaskBasedOnSingleKernel(kernel::KernelModPtr kernel_mod_ptr, bool KernelRuntime::LaunchTaskBasedOnSingleKernel(kernel::KernelModPtr kernel_mod_ptr,
const AddressPtrList &kernel_inputs, const AddressPtrList &kernel_inputs,
const AddressPtrList &kernel_outputs, const AddressPtrList &kernel_outputs,
......
...@@ -72,6 +72,9 @@ class KernelRuntime { ...@@ -72,6 +72,9 @@ class KernelRuntime {
virtual void ClearGraphRuntimeResource(uint32_t graph_id, const std::vector<AnfNodePtr> &inputs, virtual void ClearGraphRuntimeResource(uint32_t graph_id, const std::vector<AnfNodePtr> &inputs,
const std::unordered_set<ValueNodePtr> &value_nodes, const std::unordered_set<ValueNodePtr> &value_nodes,
const std::vector<CNodePtr> &execution_order); const std::vector<CNodePtr> &execution_order);
virtual void ClearOutputAddress(const std::vector<AnfNodePtr> &inputs,
const std::unordered_set<ValueNodePtr> &value_nodes,
const std::vector<CNodePtr> &execution_order);
virtual bool SyncStream() = 0; virtual bool SyncStream() = 0;
#ifdef ENABLE_DUMP_E2E #ifdef ENABLE_DUMP_E2E
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册