提交 76733ce8 编写于 作者: K kswang

fix cpu multi graph mem error

上级 7cb567eb
...@@ -312,7 +312,7 @@ void SessionBasic::InitInternalOutputParameter(const AnfNodePtr &out_node, const ...@@ -312,7 +312,7 @@ void SessionBasic::InitInternalOutputParameter(const AnfNodePtr &out_node, const
MS_LOG(INFO) << "No kernel info"; MS_LOG(INFO) << "No kernel info";
return; return;
} }
if (!AnfAlgo::OutputAddrExist(ref_real_node, ref_real_node_index)) { if (!opt::IsNopNode(ref_real_node) && !AnfAlgo::OutputAddrExist(ref_real_node, ref_real_node_index)) {
MS_LOG(INFO) << "No kernel address"; MS_LOG(INFO) << "No kernel address";
return; return;
} }
......
...@@ -40,8 +40,7 @@ void CPUKernelRuntime::AssignKernelAddress(session::KernelGraph *kernel_graph) { ...@@ -40,8 +40,7 @@ void CPUKernelRuntime::AssignKernelAddress(session::KernelGraph *kernel_graph) {
AssignValueNodeAddress(kernel_graph); AssignValueNodeAddress(kernel_graph);
AssignInputNodeAddress(kernel_graph); AssignInputNodeAddress(kernel_graph);
AssignKernelOutputAddress(kernel_graph); AssignKernelOutputAddress(kernel_graph);
resource_manager_.MemPlan(kernel_graph); resource_manager_.AssignMemory(kernel_graph);
resource_manager_.MemMalloc(kernel_graph);
} }
void CPUKernelRuntime::AssignValueNodeAddress(session::KernelGraph *kernel_graph) { void CPUKernelRuntime::AssignValueNodeAddress(session::KernelGraph *kernel_graph) {
......
...@@ -34,11 +34,13 @@ void CPUResourceManager::MemFree() { ...@@ -34,11 +34,13 @@ void CPUResourceManager::MemFree() {
dynamic_mem_.clear(); dynamic_mem_.clear();
} }
void CPUResourceManager::MemPlan(const session::KernelGraph *graph) { void CPUResourceManager::AssignMemory(const session::KernelGraph *graph) {
mem_plan_.MemPlan(graph); size_t graph_mem_size = mem_plan_.MemPlan(graph);
size_t graph_mem_size = mem_plan_.GetGraphMemSize(graph);
if (graph_mem_size > mem_size_) { if (graph_mem_size > mem_size_) {
MemFree(); if (mem_size_ > 0) {
dynamic_mem_[mem_ptr_] = mem_size_;
mem_size_ = 0;
}
mem_ptr_ = reinterpret_cast<uint8_t *>(malloc(graph_mem_size)); mem_ptr_ = reinterpret_cast<uint8_t *>(malloc(graph_mem_size));
if (mem_ptr_ != nullptr) { if (mem_ptr_ != nullptr) {
mem_size_ = graph_mem_size; mem_size_ = graph_mem_size;
...@@ -48,9 +50,6 @@ void CPUResourceManager::MemPlan(const session::KernelGraph *graph) { ...@@ -48,9 +50,6 @@ void CPUResourceManager::MemPlan(const session::KernelGraph *graph) {
dynamic_malloc_ = true; dynamic_malloc_ = true;
} }
} }
}
void CPUResourceManager::MemMalloc(const session::KernelGraph *graph) {
if (dynamic_malloc_) { if (dynamic_malloc_) {
return; return;
} }
......
...@@ -17,7 +17,7 @@ ...@@ -17,7 +17,7 @@
#define MINDSPORE_CCSRC_RUNTIME_DEVICE_CPU_CPU_RESOURCE_MANAGER_H_ #define MINDSPORE_CCSRC_RUNTIME_DEVICE_CPU_CPU_RESOURCE_MANAGER_H_
#include <vector> #include <vector>
#include <unordered_map> #include <map>
#include "backend/session/kernel_graph.h" #include "backend/session/kernel_graph.h"
#include "backend/session/session_basic.h" #include "backend/session/session_basic.h"
#include "runtime/device/device_address.h" #include "runtime/device/device_address.h"
...@@ -30,8 +30,7 @@ class CPUResourceManager { ...@@ -30,8 +30,7 @@ class CPUResourceManager {
CPUResourceManager() = default; CPUResourceManager() = default;
~CPUResourceManager(); ~CPUResourceManager();
void MemPlan(const session::KernelGraph *graph); void AssignMemory(const session::KernelGraph *graph);
void MemMalloc(const session::KernelGraph *graph);
void IncreaseAddressRefCount(const session::KernelGraph *graph); void IncreaseAddressRefCount(const session::KernelGraph *graph);
void DecreaseAddressRefCount(const AnfNodePtr &kernel); void DecreaseAddressRefCount(const AnfNodePtr &kernel);
void *MemMalloc(size_t mem_size); void *MemMalloc(size_t mem_size);
...@@ -46,7 +45,7 @@ class CPUResourceManager { ...@@ -46,7 +45,7 @@ class CPUResourceManager {
size_t mem_size_{0}; size_t mem_size_{0};
uint8_t *mem_ptr_{nullptr}; uint8_t *mem_ptr_{nullptr};
bool dynamic_malloc_{false}; bool dynamic_malloc_{false};
std::unordered_map<void *, size_t> dynamic_mem_; std::map<void *, size_t> dynamic_mem_;
}; };
} // namespace cpu } // namespace cpu
} // namespace device } // namespace device
......
...@@ -19,7 +19,7 @@ ...@@ -19,7 +19,7 @@
namespace mindspore { namespace mindspore {
namespace device { namespace device {
namespace cpu { namespace cpu {
void CPUSimpleMemPlan::MemPlan(const session::KernelGraph *graph) { size_t CPUSimpleMemPlan::MemPlan(const session::KernelGraph *graph) {
MS_EXCEPTION_IF_NULL(graph); MS_EXCEPTION_IF_NULL(graph);
size_t total_mem_size = 32; size_t total_mem_size = 32;
auto kernels = graph->execution_order(); auto kernels = graph->execution_order();
...@@ -58,15 +58,8 @@ void CPUSimpleMemPlan::MemPlan(const session::KernelGraph *graph) { ...@@ -58,15 +58,8 @@ void CPUSimpleMemPlan::MemPlan(const session::KernelGraph *graph) {
} }
} }
} }
graph_mem_size_[graph] = total_mem_size;
}
size_t CPUSimpleMemPlan::GetGraphMemSize(const session::KernelGraph *graph) const { return total_mem_size;
auto iter = graph_mem_size_.find(graph);
if (iter != graph_mem_size_.end()) {
return iter->second;
}
return 0;
} }
void CPUSimpleMemPlan::MemAssign(const session::KernelGraph *graph, uint8_t *base_ptr) { void CPUSimpleMemPlan::MemAssign(const session::KernelGraph *graph, uint8_t *base_ptr) {
......
...@@ -17,7 +17,6 @@ ...@@ -17,7 +17,6 @@
#define MINDSPORE_CCSRC_RUNTIME_DEVICE_CPU_CPU_SIMPLE_MEM_PLAN_H_ #define MINDSPORE_CCSRC_RUNTIME_DEVICE_CPU_CPU_SIMPLE_MEM_PLAN_H_
#include <vector> #include <vector>
#include <unordered_map>
#include "backend/session/kernel_graph.h" #include "backend/session/kernel_graph.h"
#include "runtime/device/device_address.h" #include "runtime/device/device_address.h"
...@@ -29,12 +28,8 @@ class CPUSimpleMemPlan { ...@@ -29,12 +28,8 @@ class CPUSimpleMemPlan {
CPUSimpleMemPlan() = default; CPUSimpleMemPlan() = default;
~CPUSimpleMemPlan() = default; ~CPUSimpleMemPlan() = default;
void MemPlan(const session::KernelGraph *graph); size_t MemPlan(const session::KernelGraph *graph);
void MemAssign(const session::KernelGraph *graph, uint8_t *base_ptr); void MemAssign(const session::KernelGraph *graph, uint8_t *base_ptr);
size_t GetGraphMemSize(const session::KernelGraph *graph) const;
private:
std::unordered_map<const session::KernelGraph *, size_t> graph_mem_size_;
}; };
} // namespace cpu } // namespace cpu
} // namespace device } // namespace device
......
...@@ -270,7 +270,7 @@ std::string GetCNodeTarget(const AnfNodePtr &node) { ...@@ -270,7 +270,7 @@ std::string GetCNodeTarget(const AnfNodePtr &node) {
} }
return target; return target;
} }
if (IsPrimitive(node, prim::kPrimMakeTuple)) { if (IsPrimitiveCNode(node, prim::kPrimMakeTuple)) {
return GetMaketupleNodeTarget(cnode); return GetMaketupleNodeTarget(cnode);
} }
return default_target; return default_target;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册