提交 09dd4128 编写于 作者: M mindspore-ci-bot 提交者: Gitee

!3689 fix cpu multi graph mem error

Merge pull request !3689 from kisnwang/r0.6-fix-cpu-multi-graph-memory-error
......@@ -313,7 +313,7 @@ void SessionBasic::InitInternalOutputParameter(const AnfNodePtr &out_node, const
MS_LOG(INFO) << "No kernel info";
return;
}
if (!AnfAlgo::OutputAddrExist(ref_real_node, ref_real_node_index)) {
if (!opt::IsNopNode(ref_real_node) && !AnfAlgo::OutputAddrExist(ref_real_node, ref_real_node_index)) {
MS_LOG(INFO) << "No kernel address";
return;
}
......@@ -1003,6 +1003,7 @@ CNodePtr SessionBasic::ConstructOutput(const AnfNodePtrList &outputs, const std:
break;
}
}
if (internal_output) {
MS_LOG(INFO) << "Internal output1: " << out->DebugString() << "To " << backend_real_kernel.first->DebugString();
graph->AddInternalOutput(out, backend_real_kernel.first);
......
......@@ -40,8 +40,7 @@ void CPUKernelRuntime::AssignKernelAddress(session::KernelGraph *kernel_graph) {
AssignValueNodeAddress(kernel_graph);
AssignInputNodeAddress(kernel_graph);
AssignKernelOutputAddress(kernel_graph);
resource_manager_.MemPlan(kernel_graph);
resource_manager_.MemMalloc(kernel_graph);
resource_manager_.AssignMemory(kernel_graph);
}
void CPUKernelRuntime::AssignValueNodeAddress(session::KernelGraph *kernel_graph) {
......
......@@ -34,11 +34,13 @@ void CPUResourceManager::MemFree() {
dynamic_mem_.clear();
}
void CPUResourceManager::MemPlan(const session::KernelGraph *graph) {
mem_plan_.MemPlan(graph);
size_t graph_mem_size = mem_plan_.GetGraphMemSize(graph);
void CPUResourceManager::AssignMemory(const session::KernelGraph *graph) {
size_t graph_mem_size = mem_plan_.MemPlan(graph);
if (graph_mem_size > mem_size_) {
MemFree();
if (mem_size_ > 0) {
dynamic_mem_[mem_ptr_] = mem_size_;
mem_size_ = 0;
}
mem_ptr_ = reinterpret_cast<uint8_t *>(malloc(graph_mem_size));
if (mem_ptr_ != nullptr) {
mem_size_ = graph_mem_size;
......@@ -48,9 +50,6 @@ void CPUResourceManager::MemPlan(const session::KernelGraph *graph) {
dynamic_malloc_ = true;
}
}
}
void CPUResourceManager::MemMalloc(const session::KernelGraph *graph) {
if (dynamic_malloc_) {
return;
}
......
......@@ -17,7 +17,7 @@
#define MINDSPORE_CCSRC_RUNTIME_DEVICE_CPU_CPU_RESOURCE_MANAGER_H_
#include <vector>
#include <unordered_map>
#include <map>
#include "backend/session/kernel_graph.h"
#include "backend/session/session_basic.h"
#include "runtime/device/device_address.h"
......@@ -30,8 +30,7 @@ class CPUResourceManager {
CPUResourceManager() = default;
~CPUResourceManager();
void MemPlan(const session::KernelGraph *graph);
void MemMalloc(const session::KernelGraph *graph);
void AssignMemory(const session::KernelGraph *graph);
void IncreaseAddressRefCount(const session::KernelGraph *graph);
void DecreaseAddressRefCount(const AnfNodePtr &kernel);
void *MemMalloc(size_t mem_size);
......@@ -46,7 +45,7 @@ class CPUResourceManager {
size_t mem_size_{0};
uint8_t *mem_ptr_{nullptr};
bool dynamic_malloc_{false};
std::unordered_map<void *, size_t> dynamic_mem_;
std::map<void *, size_t> dynamic_mem_;
};
} // namespace cpu
} // namespace device
......
......@@ -19,7 +19,7 @@
namespace mindspore {
namespace device {
namespace cpu {
void CPUSimpleMemPlan::MemPlan(const session::KernelGraph *graph) {
size_t CPUSimpleMemPlan::MemPlan(const session::KernelGraph *graph) {
MS_EXCEPTION_IF_NULL(graph);
size_t total_mem_size = 32;
auto kernels = graph->execution_order();
......@@ -58,15 +58,8 @@ void CPUSimpleMemPlan::MemPlan(const session::KernelGraph *graph) {
}
}
}
graph_mem_size_[graph] = total_mem_size;
}
size_t CPUSimpleMemPlan::GetGraphMemSize(const session::KernelGraph *graph) const {
auto iter = graph_mem_size_.find(graph);
if (iter != graph_mem_size_.end()) {
return iter->second;
}
return 0;
return total_mem_size;
}
void CPUSimpleMemPlan::MemAssign(const session::KernelGraph *graph, uint8_t *base_ptr) {
......
......@@ -17,7 +17,6 @@
#define MINDSPORE_CCSRC_RUNTIME_DEVICE_CPU_CPU_SIMPLE_MEM_PLAN_H_
#include <vector>
#include <unordered_map>
#include "backend/session/kernel_graph.h"
#include "runtime/device/device_address.h"
......@@ -29,12 +28,8 @@ class CPUSimpleMemPlan {
CPUSimpleMemPlan() = default;
~CPUSimpleMemPlan() = default;
void MemPlan(const session::KernelGraph *graph);
size_t MemPlan(const session::KernelGraph *graph);
void MemAssign(const session::KernelGraph *graph, uint8_t *base_ptr);
size_t GetGraphMemSize(const session::KernelGraph *graph) const;
private:
std::unordered_map<const session::KernelGraph *, size_t> graph_mem_size_;
};
} // namespace cpu
} // namespace device
......
......@@ -270,7 +270,7 @@ std::string GetCNodeTarget(const AnfNodePtr &node) {
}
return target;
}
if (IsPrimitive(node, prim::kPrimMakeTuple)) {
if (IsPrimitiveCNode(node, prim::kPrimMakeTuple)) {
return GetMaketupleNodeTarget(cnode);
}
return default_target;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册