Commit b0095f63 authored by mindspore-ci-bot, committed by Gitee

!3264 Enable mem pool manage pynative and graph static mem

Merge pull request !3264 from JoyLvliang/enable-mem-pool-manage-pynative-and-graph-static-mem
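The change backs both allocators with a single HBM block: the memory pool serves PyNative and graph static tensors bottom-up from device_mem_base_, while graph dynamic memory is carved top-down from just below the end of the block (minus kReservedMemorySize), and allocation fails once the two offsets meet. The snippet below is a minimal standalone model of that two-ended layout, written only to illustrate the idea; the names echo device_mem_pool_offset_ and graph_dynamic_mem_offset_ from the diff, but nothing here is MindSpore code.

// Minimal standalone model of the two-ended layout introduced by this commit
// (illustration only, not MindSpore code). The pool offset grows upward and
// the graph dynamic offset grows downward; allocation fails once they meet.
#include <cstdint>
#include <iostream>
#include <stdexcept>

class TwoEndedArena {
 public:
  explicit TwoEndedArena(uint64_t size, uint64_t reserved)
      : pool_offset_(0), dynamic_offset_(size - reserved) {}

  // Pool side: bump upward from the base (static / PyNative tensors).
  uint64_t AllocPool(uint64_t size) {
    if (pool_offset_ + size >= dynamic_offset_) throw std::runtime_error("out of memory");
    uint64_t addr = pool_offset_;
    pool_offset_ += size;
    return addr;
  }

  // Dynamic side: carve downward from the top (graph dynamic memory).
  uint64_t AllocDynamic(uint64_t size) {
    if (dynamic_offset_ < size || dynamic_offset_ - size <= pool_offset_) {
      throw std::runtime_error("out of memory");
    }
    dynamic_offset_ -= size;
    return dynamic_offset_;
  }

  uint64_t free_size() const { return dynamic_offset_ - pool_offset_; }

 private:
  uint64_t pool_offset_;
  uint64_t dynamic_offset_;
};

int main() {
  TwoEndedArena arena(30ULL << 30, 10 * 1024 * 1024);  // 30 GB minus a 10 MB reserved tail
  std::cout << "pool block at offset " << arena.AllocPool(1 << 20) << "\n";
  std::cout << "dynamic block at offset " << arena.AllocDynamic(1 << 20) << "\n";
  std::cout << "free between offsets: " << arena.free_size() << " bytes\n";
  return 0;
}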
@@ -618,7 +618,12 @@ AscendDeviceAddress::~AscendDeviceAddress() {
     return;
   }
   if (from_mem_pool_) {
-    AscendMemoryPool::GetInstance().FreeTensorMem(ptr_);
+    if (communication_ptr_ != nullptr) {
+      AscendMemoryPool::GetInstance().FreeTensorMem(communication_ptr_);
+      communication_ptr_ = nullptr;
+    } else {
+      AscendMemoryPool::GetInstance().FreeTensorMem(ptr_);
+    }
     ptr_ = nullptr;
   }
 }
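For communication operators the pool hands out a block padded by kMemAlignSize on each side (the protect area), so the DeviceAddress ends up holding two pointers: ptr_ for the data and communication_ptr_ for the real allocation base one alignment unit below it, and the destructor above has to return communication_ptr_ rather than ptr_. Here is a minimal standalone model of that ownership rule, assuming a 512-byte alignment and using plain malloc in place of the Ascend pool; it is an illustration, not MindSpore code.

// Standalone model of why the destructor frees communication_ptr_ first
// (illustration only: FakePool and the 512-byte alignment are stand-ins).
#include <cstdint>
#include <cstdlib>
#include <iostream>

constexpr size_t kMemAlignSize = 512;

struct FakePool {
  static void *AllocTensorMem(size_t size) { return std::malloc(size); }
  static void FreeTensorMem(void *ptr) { std::free(ptr); }
};

struct FakeDeviceAddress {
  uint8_t *ptr_{nullptr};                // data pointer used by kernels
  uint8_t *communication_ptr_{nullptr};  // allocation base including the protect area
  bool from_mem_pool_{false};
  ~FakeDeviceAddress() {
    if (!from_mem_pool_) {
      return;
    }
    if (communication_ptr_ != nullptr) {
      FakePool::FreeTensorMem(communication_ptr_);  // free the padded base, not ptr_
    } else {
      FakePool::FreeTensorMem(ptr_);
    }
  }
};

int main() {
  FakeDeviceAddress address;
  size_t data_size = 4096;
  // Communication tensors get kMemAlignSize of padding on both sides.
  auto *base = static_cast<uint8_t *>(FakePool::AllocTensorMem(data_size + 2 * kMemAlignSize));
  address.communication_ptr_ = base;
  address.ptr_ = base + kMemAlignSize;
  address.from_mem_pool_ = true;
  std::cout << "data sits " << (address.ptr_ - address.communication_ptr_)
            << " bytes above the allocation base\n";
  return 0;  // the destructor returns the padded block to the fake pool
}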
......
@@ -21,32 +21,23 @@
 namespace mindspore {
 namespace device {
 namespace ascend {
-constexpr uint64_t kAscendDeviceMemGB = 26;
-constexpr uint64_t kAscendMemPoolGB = 4;
+constexpr uint64_t kAscendDeviceMemGB = 30;
 constexpr uint64_t kMemSizeGB = 30;
-constexpr uint64_t kMaxMemSizeGB = 30;
 constexpr uint64_t kAscendDeviceMemSize = (kAscendDeviceMemGB << kMemSizeGB);
-constexpr uint64_t kAscendMemPoolSize = (kAscendMemPoolGB << kMemSizeGB);
+constexpr uint64_t kReservedMemorySize = 10 * 1024 * 1024;
 
 void AscendMemoryManager::MallocDeviceMemory() {
   auto context_mem = GetDeviceMemSizeFromContext();
   device_mem_size_ = context_mem == 0 ? kAscendDeviceMemSize : context_mem;
-  static_mem_offset_ = device_mem_size_;
-  auto ret = rtMalloc(reinterpret_cast<void **>(&device_mem_base_), static_mem_offset_, RT_MEMORY_HBM);
+  auto ret = rtMalloc(reinterpret_cast<void **>(&device_mem_base_), device_mem_size_, RT_MEMORY_HBM);
   if (ret != RT_ERROR_NONE) {
-    MS_EXCEPTION(DeviceProcessError) << "rtMalloc mem size[" << static_mem_offset_ << "] fail, ret[" << ret << "]";
+    MS_EXCEPTION(DeviceProcessError) << "rtMalloc mem size[" << device_mem_size_ << "] fail, ret[" << ret << "]";
   }
-  if (context_mem == 0) {
-    device_mem_pool_size_ = kAscendMemPoolSize;
-    ret = rtMalloc(reinterpret_cast<void **>(&device_mem_pool_base_), device_mem_pool_size_, RT_MEMORY_HBM);
-    if (ret != RT_ERROR_NONE) {
-      MS_EXCEPTION(DeviceProcessError) << "rtMalloc mem size[" << device_mem_pool_size_ << "] fail, ret[" << ret << "]";
-    }
-    AscendMemoryPool::GetInstance().set_device_mem_pool_base(device_mem_pool_base_);
-    AscendMemoryPool::GetInstance().set_device_mem_pool_size(device_mem_pool_size_);
-  }
+  dynamic_mem_offset_ = device_mem_size_ - kReservedMemorySize;
+  AscendMemoryPool::GetInstance().set_device_mem_pool_base(device_mem_base_);
+  AscendMemoryPool::GetInstance().set_graph_dynamic_mem_offset(dynamic_mem_offset_);
 }
 
 uint64_t AscendMemoryManager::GetDeviceMemSizeFromContext() {
@@ -64,7 +55,7 @@ uint64_t AscendMemoryManager::GetDeviceMemSizeFromContext() {
   auto gb_str = variable_memory_max_size.substr(0, pos);
   auto gb_var = std::stoull(gb_str);
   MS_LOG(INFO) << "variable_memory_max_size(GB):" << gb_var;
-  if (gb_var > kMaxMemSizeGB || gb_var == 0) {
+  if (gb_var > kAscendDeviceMemGB || gb_var == 0) {
     MS_LOG(EXCEPTION) << "Invalid allocate memory size:" << gb_var << " which should be in (0-30]GB";
   }
   return gb_var << kMemSizeGB;
@@ -87,8 +78,60 @@ void AscendMemoryManager::FreeDeviceMemory() {
   }
 }
 
+void AscendMemoryManager::ResetDynamicMemory() {
+  total_dynamic_size_ = 0;
+  dynamic_mem_offset_ = device_mem_size_ - kReservedMemorySize;
+  AscendMemoryPool::GetInstance().set_graph_dynamic_mem_offset(dynamic_mem_offset_);
+}
+
 void *AscendMemoryManager::MallocMemFromMemPool(size_t size) {
-  return AscendMemoryPool::GetInstance().AllocTensorMem(size);
+  auto align_size = GetCommonAlignSize(size);
+  return AscendMemoryPool::GetInstance().AllocTensorMem(align_size);
 }
+
+uint8_t *AscendMemoryManager::MallocStaticMem(size_t size, bool communication_mem) {
+  size_t align_size = 0;
+  if (communication_mem) {
+    align_size = GetCommunicationAlignSize(size);
+  } else {
+    align_size = GetCommonAlignSize(size);
+  }
+  if (communication_mem) {
+    // create protect area [kMemAlignSize -- data -- kMemAlignSize]
+    uint8_t *alloc_address = reinterpret_cast<uint8_t *>(AscendMemoryPool::GetInstance().AllocTensorMem(align_size));
+    return alloc_address + kMemAlignSize;
+  } else {
+    return reinterpret_cast<uint8_t *>(AscendMemoryPool::GetInstance().AllocTensorMem(align_size));
+  }
+}
+
+uint8_t *AscendMemoryManager::MallocDynamicMem(size_t size, bool communication_mem) {
+  size_t align_size = 0;
+  if (communication_mem) {
+    align_size = GetCommunicationAlignSize(size);
+  } else {
+    align_size = GetCommonAlignSize(size);
+  }
+  if (dynamic_mem_offset_ < align_size) {
+    MS_LOG(EXCEPTION) << "Out of memory!!! total[" << device_mem_size_ << "] (dynamic[" << total_dynamic_size_
+                      << "]) malloc [" << align_size << "] failed!";
+  }
+  auto new_offset = dynamic_mem_offset_ - align_size;
+  auto device_mem_pool_offset = AscendMemoryPool::GetInstance().device_mem_pool_offset();
+  if (new_offset <= device_mem_pool_offset) {
+    MS_LOG(EXCEPTION) << "Out of memory!!! total[" << device_mem_size_ << "] (dynamic[" << total_dynamic_size_
+                      << "] memory pool[" << device_mem_pool_offset << "])"
+                      << " malloc [" << align_size << "] failed!";
+  }
+  total_dynamic_size_ += align_size;
+  dynamic_mem_offset_ = new_offset;
+  AscendMemoryPool::GetInstance().set_graph_dynamic_mem_offset(dynamic_mem_offset_);
+  if (communication_mem) {
+    // create protect area [kMemAlignSize -- data -- kMemAlignSize]
+    return device_mem_base_ + dynamic_mem_offset_ + kMemAlignSize;
+  } else {
+    return device_mem_base_ + dynamic_mem_offset_;
+  }
+}
 }  // namespace ascend
 }  // namespace device
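A short worked trace of the top-down carving in MallocDynamicMem above, with made-up numbers so the arithmetic is visible: the offset is decremented first, so the returned block spans [new_offset, new_offset + align_size) relative to device_mem_base_, and the request fails once it would cross the pool offset. This is a standalone illustration, not MindSpore code.

// Worked trace of the top-down carving in MallocDynamicMem, with made-up
// numbers (illustration only).
#include <cstdint>
#include <iostream>

int main() {
  const uint64_t device_mem_size = 30ULL << 30;   // 30 GB, as in kAscendDeviceMemSize
  const uint64_t reserved = 10 * 1024 * 1024;     // kReservedMemorySize
  uint64_t dynamic_mem_offset = device_mem_size - reserved;
  uint64_t device_mem_pool_offset = 2ULL << 30;   // pretend the pool already holds 2 GB

  uint64_t align_size = 1ULL << 20;               // one aligned 1 MB request
  uint64_t new_offset = dynamic_mem_offset - align_size;
  if (new_offset <= device_mem_pool_offset) {
    std::cout << "would be out of memory\n";
    return 1;
  }
  dynamic_mem_offset = new_offset;
  std::cout << "block occupies [" << dynamic_mem_offset << ", "
            << dynamic_mem_offset + align_size << ") relative to device_mem_base_\n";
  return 0;
}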
......
@@ -27,8 +27,13 @@ class AscendMemoryManager : public MemoryManager {
   void MallocDeviceMemory() override;
   void FreeDeviceMemory() override;
+  void ResetDynamicMemory() override;
   void *MallocMemFromMemPool(size_t size) override;
 
+ protected:
+  uint8_t *MallocStaticMem(size_t size, bool communication_mem) override;
+  uint8_t *MallocDynamicMem(size_t size, bool communication_mem) override;
+
  private:
   uint8_t *device_mem_pool_base_{nullptr};
   uint64_t device_mem_pool_size_{0};
......
@@ -22,51 +22,56 @@ namespace mindspore {
 namespace device {
 namespace ascend {
 size_t AscendMemoryPool::AllocDeviceMem(size_t size, DeviceMemPtr *addr) {
-  if (has_malloc_) {
-    MS_LOG(EXCEPTION) << "Memory pool has been allocated memory resource!";
+  if (size == 0) {
+    MS_LOG(EXCEPTION) << "Failed to alloc memory pool resource, the size is zero!";
   }
-  if (size == 0 || size > free_mem_size_) {
-    MS_LOG(EXCEPTION) << "Failed to alloc memory pool resource, the size is zero or large than free mem size!";
+  if (device_mem_pool_offset_ + size >= graph_dynamic_mem_offset_) {
+    MS_LOG(EXCEPTION) << "Failed to alloc memory pool memory, the current device_mem_pool_offset_ ["
+                      << device_mem_pool_offset_ << "], current graph_dynamic_mem_offset_ " << graph_dynamic_mem_offset_
+                      << "], need memory size [" << size << "]";
   }
-  *addr = device_mem_pool_base_;
+  *addr = device_mem_pool_base_ + device_mem_pool_offset_;
+  device_mem_pool_offset_ += size;
   if (*addr == nullptr) {
-    MS_LOG(EXCEPTION) << "Device memory pool base address is nullptr, failed to alloc memory pool resource!";
+    MS_LOG(EXCEPTION) << "Alloc device memory pool address is nullptr, failed to alloc memory pool resource!";
   }
-  has_malloc_ = true;
-  free_mem_size_ -= size;
   return size;
 }
 
 bool AscendMemoryPool::FreeDeviceMem(const DeviceMemPtr &addr) {
   MS_EXCEPTION_IF_NULL(addr);
-  has_malloc_ = false;
-  free_mem_size_ = total_mem_size_;
   return true;
 }
 
 size_t AscendMemoryPool::AlignMemorySize(size_t size) const {
   if (size == 0) {
-    return DYNAMIC_MEM_ALIGN_SIZE;
+    MS_LOG(EXCEPTION) << "The align memory size is a zero !";
   }
-  return ((size + DYNAMIC_MEM_ALIGN_SIZE + 31) / DYNAMIC_MEM_ALIGN_SIZE) * DYNAMIC_MEM_ALIGN_SIZE;
+  return size;
 }
 
-size_t AscendMemoryPool::mem_alloc_unit_size() const { return free_mem_size_ - DYNAMIC_MEM_ALIGN_SIZE; }
+size_t AscendMemoryPool::mem_alloc_unit_size() const { return DYNAMIC_MEM_ALLOC_UNIT_SIZE / 2; }
 
 void AscendMemoryPool::set_device_mem_pool_base(uint8_t *device_mem_pool_base) {
   MS_EXCEPTION_IF_NULL(device_mem_pool_base);
   device_mem_pool_base_ = device_mem_pool_base;
 }
 
-void AscendMemoryPool::set_device_mem_pool_size(uint64_t device_mem_pool_size) {
-  device_mem_pool_size_ = device_mem_pool_size;
-  free_mem_size_ = device_mem_pool_size_;
-  total_mem_size_ = free_mem_size_;
+void AscendMemoryPool::set_graph_dynamic_mem_offset(uint64_t graph_dynamic_mem_offset) {
+  graph_dynamic_mem_offset_ = graph_dynamic_mem_offset;
 }
 
-size_t AscendMemoryPool::free_mem_size() { return free_mem_size_; }
+uint64_t AscendMemoryPool::device_mem_pool_offset() const { return device_mem_pool_offset_; }
 
-size_t AscendMemoryPool::total_mem_size() { return total_mem_size_; }
+size_t AscendMemoryPool::free_mem_size() {
+  if (graph_dynamic_mem_offset_ <= device_mem_pool_offset_) {
+    MS_LOG(EXCEPTION) << "graph dynamic mem offset [" << graph_dynamic_mem_offset_
+                      << "] less than or equal to device mem pool offset [" << device_mem_pool_offset_ << "]!";
+  }
+  return graph_dynamic_mem_offset_ - device_mem_pool_offset_;
+}
+
+size_t AscendMemoryPool::total_mem_size() { return graph_dynamic_mem_offset_ == 0 ? 0 : graph_dynamic_mem_offset_ - 1; }
 }  // namespace ascend
 }  // namespace device
 }  // namespace mindspore
@@ -32,8 +32,9 @@ class AscendMemoryPool : public DynamicMemPoolBestFit {
   size_t AllocDeviceMem(size_t size, DeviceMemPtr *addr) override;
   bool FreeDeviceMem(const DeviceMemPtr &addr) override;
   void set_device_mem_pool_base(uint8_t *device_mem_pool_base);
-  void set_device_mem_pool_size(uint64_t device_mem_pool_size);
+  void set_graph_dynamic_mem_offset(uint64_t graph_dynamic_mem_offset);
+  uint64_t device_mem_pool_offset() const;
 
   size_t free_mem_size() override;
   size_t total_mem_size() override;
@@ -50,11 +51,9 @@ class AscendMemoryPool : public DynamicMemPoolBestFit {
  private:
   AscendMemoryPool() = default;
-  bool has_malloc_{false};
   uint8_t *device_mem_pool_base_{nullptr};
-  uint64_t device_mem_pool_size_{0};
-  size_t free_mem_size_{0};
-  size_t total_mem_size_{0};
+  uint64_t device_mem_pool_offset_{0};
+  uint64_t graph_dynamic_mem_offset_{0};
 };
 }  // namespace ascend
 }  // namespace device
......
@@ -76,6 +76,7 @@ class DeviceAddress : public mindspore::DeviceSync {
   string format_{"DefaultFormat"};
   TypeId type_id_{kNumberTypeFloat16};
   bool from_mem_pool_{false};
+  uint8_t *communication_ptr_{nullptr};
   std::vector<int> host_shape_{};
   friend class KernelRuntime;
   friend class MemoryManager;
......
@@ -335,8 +335,10 @@ void KernelRuntime::AssignStaticMemoryInput(const session::KernelGraph *graph) {
       output_type_id = AnfAlgo::GetOutputInferDataType(item, index);
     }
     auto tensor_size = CountNodeDeviceMemorySize(item, index);
-    auto ptr = mem_manager_->MallocMem(kStaticMem, tensor_size);
-    auto address = CreateDeviceAddress(ptr, tensor_size, AnfAlgo::GetOutputFormat(item, index), output_type_id);
+    auto address = CreateDeviceAddress(nullptr, tensor_size, AnfAlgo::GetOutputFormat(item, index), output_type_id);
+    if (mem_manager_->MallocMem(address, kStaticMem, tensor_size) == nullptr) {
+      MS_LOG(EXCEPTION) << "Cannot alloc address when flag is: " << kStaticMem << ", tensor size is: " << tensor_size;
+    }
     AnfAlgo::SetOutputAddr(address, index, item.get());
   }
 }
@@ -434,11 +436,18 @@ void KernelRuntime::AssignCommunicationNodeOutputMem(MemType type, const AnfNode
     // reuse communication op's all outputs' memory
     type = kReuseDynamicCommMem;
   }
-  uint8_t *output_ptr = mem_manager_->MallocOutputMem(node, 0, type, total_size);
+  uint8_t *output_ptr = nullptr;
   for (size_t j = 0; j < align_size_list.size(); ++j) {
     std::string output_format = AnfAlgo::GetOutputFormat(node, j);
     auto output_type = AnfAlgo::GetOutputDeviceDataType(node, j);
-    auto address = CreateDeviceAddress(output_ptr, output_sizes[j], output_format, output_type);
+    auto address = CreateDeviceAddress(nullptr, output_sizes[j], output_format, output_type);
+    MS_EXCEPTION_IF_NULL(address);
+    if (output_ptr == nullptr) {
+      output_ptr = mem_manager_->MallocMem(address, type, total_size, std::pair<AnfNodePtr, size_t>(node, 0));
+      MS_EXCEPTION_IF_NULL(output_ptr);
+    } else {
+      address->set_ptr(output_ptr);
+    }
     AnfAlgo::SetOutputAddr(address, j, node.get());
     output_ptr += align_size_list[j];
   }
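With this hunk only the first output of a communication node actually allocates: the pool hands back one contiguous block of total_size, the first DeviceAddress owns it (including the protect-area base recorded by MallocMem), and the remaining outputs just point into the block at their aligned offsets. The snippet below is a minimal standalone sketch of that bookkeeping, with hypothetical sizes and plain malloc standing in for the device pool; it is not MindSpore code.

// Sketch of carving one contiguous block across all outputs of a communication
// node (hypothetical sizes; a plain malloc stands in for the device pool).
#include <cstdint>
#include <cstdlib>
#include <iostream>
#include <numeric>
#include <vector>

int main() {
  std::vector<size_t> align_size_list = {1024, 2048, 512};  // per-output aligned sizes
  size_t total_size = std::accumulate(align_size_list.begin(), align_size_list.end(), size_t(0));

  uint8_t *output_ptr = nullptr;
  std::vector<uint8_t *> output_addrs;
  for (size_t j = 0; j < align_size_list.size(); ++j) {
    if (output_ptr == nullptr) {
      // First output: allocate the whole contiguous block once.
      output_ptr = static_cast<uint8_t *>(std::malloc(total_size));
    }
    output_addrs.push_back(output_ptr);  // each output points into the same block
    output_ptr += align_size_list[j];
  }
  std::cout << output_addrs.size() << " outputs share one block of " << total_size << " bytes\n";
  std::free(output_addrs.front());  // the first address owns the allocation
  return 0;
}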
@@ -464,7 +473,7 @@ void KernelRuntime::AssignCommunicationNodeInputMem(MemType type, const AnfNodeP
   MS_EXCEPTION_IF_NULL(node);
   MS_EXCEPTION_IF_NULL(mem_manager_);
   size_t total_size = 0;
-  std::vector<std::pair<mindspore::device::DeviceAddress *, size_t>> addr_size;
+  std::vector<std::pair<DeviceAddressPtr, size_t>> addr_size;
   for (size_t i = 0; i < AnfAlgo::GetInputTensorNum(node); ++i) {
     auto input_node_with_index = AnfAlgo::GetPrevNodeOutput(node, i);
     auto input_node = input_node_with_index.first;
@@ -477,9 +486,13 @@ void KernelRuntime::AssignCommunicationNodeInputMem(MemType type, const AnfNodeP
     MS_EXCEPTION_IF_NULL(address);
     auto mem_size = mem_manager_->GetCommonAlignSize(address->size());
     total_size += mem_size;
-    addr_size.emplace_back(address.get(), mem_size);
+    addr_size.emplace_back(address, mem_size);
   }
-  uint8_t *input_ptr = mem_manager_->MallocOutputMem(node, 0, type, total_size);
+  if (addr_size.empty()) {
+    return;
+  }
+  uint8_t *input_ptr =
+    mem_manager_->MallocMem(addr_size[0].first, type, total_size, std::pair<AnfNodePtr, size_t>(node, 0));
   for (const auto &iter : addr_size) {
     MS_EXCEPTION_IF_NULL(iter.first);
     iter.first->set_ptr(input_ptr);
@@ -509,15 +522,13 @@ void KernelRuntime::AssignNodeOutputMem(MemType type, const AnfNodePtr &node, in
       MS_LOG(INFO) << "Already malloc index:" << i;
       continue;
     }
-    auto ptr = mem_manager_->MallocOutputMem(node, i, type, output_sizes[i]);
-    if (ptr == nullptr) {
-      // reused ptr, no need alloc, continue;
-      continue;
-    }
     std::string output_format = AnfAlgo::GetOutputFormat(node, i);
     auto output_type = AnfAlgo::GetOutputDeviceDataType(node, i);
-    auto device_address = CreateDeviceAddress(ptr, output_sizes[i], output_format, output_type);
+    auto device_address = CreateDeviceAddress(nullptr, output_sizes[i], output_format, output_type);
     MS_EXCEPTION_IF_NULL(device_address);
+    uint8_t *ptr =
+      mem_manager_->MallocMem(device_address, type, output_sizes[i], std::pair<AnfNodePtr, size_t>(node, i));
+    MS_EXCEPTION_IF_NULL(ptr);
     device_address->set_host_shape(trans::GetRuntimePaddingShape(node, i));
     AnfAlgo::SetOutputAddr(device_address, i, node.get());
   }
@@ -543,16 +554,12 @@ void KernelRuntime::AssignValueNodeTensor(const ValueNodePtr &value_node, const
   }
   auto output_format = AnfAlgo::GetOutputFormat(value_node, output_idx);
   DeviceAddressPtr address = nullptr;
-  if (ms_context->enable_pynative_infer()) {
-    address = CreateDeviceAddress(nullptr, node_size, output_format, output_type_id);
-    MS_EXCEPTION_IF_NULL(address);
-    if (!mem_manager_->MallocMemFromMemPool(address, node_size)) {
-      MS_LOG(EXCEPTION) << "Malloc value node device memory failed !";
-    }
-  } else {
-    auto ptr = mem_manager_->MallocMem(kStaticMem, node_size);
-    address = CreateDeviceAddress(ptr, node_size, output_format, output_type_id);
-    MS_EXCEPTION_IF_NULL(address);
+  address = CreateDeviceAddress(nullptr, node_size, output_format, output_type_id);
+  MS_EXCEPTION_IF_NULL(address);
+  if (ms_context->enable_pynative_infer() && !mem_manager_->MallocMemFromMemPool(address, node_size)) {
+    MS_LOG(EXCEPTION) << "Cannot alloc address from memory pool when tensor size is: " << node_size;
+  } else if (mem_manager_->MallocMem(address, kStaticMem, node_size) == nullptr) {
+    MS_LOG(EXCEPTION) << "Cannot alloc address when flag is: " << kStaticMem << ", tensor size is: " << node_size;
   }
   AnfAlgo::SetOutputAddr(address, output_idx, value_node.get());
   if (!address->SyncHostToDevice(trans::GetRuntimePaddingShape(value_node, 0), tensor_size, tensor->data_type(),
@@ -582,16 +589,12 @@ void KernelRuntime::AssignStaticMemoryValueNode(session::KernelGraph *graph) {
     auto value = GetValue<std::string>(node_value);
     size_t tensor_size = value.size();
     DeviceAddressPtr address = nullptr;
-    if (ms_context->enable_pynative_infer()) {
-      address = CreateDeviceAddress(nullptr, tensor_size, kOpFormat_DEFAULT, kNumberTypeUInt8);
-      MS_EXCEPTION_IF_NULL(address);
-      if (!mem_manager_->MallocMemFromMemPool(address, tensor_size)) {
-        MS_LOG(EXCEPTION) << "Malloc value node device memory failed !";
-      }
-    } else {
-      auto ptr = mem_manager_->MallocMem(kStaticMem, tensor_size);
-      address = CreateDeviceAddress(ptr, tensor_size, kOpFormat_DEFAULT, kNumberTypeUInt8);
-      MS_EXCEPTION_IF_NULL(address);
+    address = CreateDeviceAddress(nullptr, tensor_size, kOpFormat_DEFAULT, kNumberTypeUInt8);
+    MS_EXCEPTION_IF_NULL(address);
+    if (ms_context->enable_pynative_infer() && !mem_manager_->MallocMemFromMemPool(address, tensor_size)) {
+      MS_LOG(EXCEPTION) << "Cannot alloc address from memory pool when tensor size is: " << tensor_size;
+    } else if (mem_manager_->MallocMem(address, kStaticMem, tensor_size) == nullptr) {
+      MS_LOG(EXCEPTION) << "Cannot alloc address when flag is: " << kStaticMem << ", tensor size is: " << tensor_size;
     }
     AnfAlgo::SetOutputAddr(address, 0, value_node.get());
     std::vector<int> shape = {1, SizeToInt(tensor_size)};
......
@@ -95,6 +95,31 @@ uint8_t *MemoryManager::MallocMem(MemType type, size_t size) {
   return ptr;
 }
 
+uint8_t *MemoryManager::MallocMem(const DeviceAddressPtr &address, MemType flag, size_t size,
+                                  const session::KernelWithIndex &node_with_index) {
+  MS_EXCEPTION_IF_NULL(address);
+  auto context_ptr = MsContext::GetInstance();
+  MS_EXCEPTION_IF_NULL(context_ptr);
+  uint8_t *ptr = nullptr;
+  if (node_with_index.first != nullptr) {
+    ptr = MallocOutputMem(node_with_index.first, node_with_index.second, flag, size);
+    MS_EXCEPTION_IF_NULL(ptr);
+    if (AnfAlgo::IsCommunicationOp(node_with_index.first) && context_ptr->enable_hccl()) {
+      address->communication_ptr_ = ptr - kMemAlignSize;
+    }
+  } else {
+    ptr = MallocMem(flag, size);
+    MS_EXCEPTION_IF_NULL(ptr);
+  }
+  address->ptr_ = ptr;
+  if (flag == kStaticMem) {
+    address->from_mem_pool_ = true;
+  }
+  return ptr;
+}
+
 uint8_t *MemoryManager::MallocStaticMem(size_t size, bool communication_mem) {
   size_t align_size = 0;
   if (communication_mem) {
......
@@ -18,6 +18,7 @@
 #define MINDSPORE_CCSRC_RUNTIME_DEVICE_MEMORY_MANAGER_H_
 #include <memory>
 #include <vector>
+#include <utility>
 #include "backend/optimizer/mem_reuse/mem_reuse.h"
 #include "backend/optimizer/mem_reuse/mem_reuse_allocator.h"
 namespace mindspore {
@@ -34,7 +35,7 @@ class MemoryManager {
   virtual void MallocDeviceMemory() = 0;
   virtual void FreeDeviceMemory() = 0;
-  void ResetDynamicMemory() {
+  virtual void ResetDynamicMemory() {
     total_dynamic_size_ = 0;
     dynamic_mem_offset_ = 0;
   }
@@ -42,6 +43,8 @@ class MemoryManager {
   void MallocReusedDynamicMem(const session::KernelGraph *graph);
   uint8_t *MallocOutputMem(const AnfNodePtr &node, size_t index, MemType type, size_t size);
   uint8_t *MallocWorkSpaceMem(const AnfNodePtr &node, size_t index, MemType type, size_t size);
+  uint8_t *MallocMem(const DeviceAddressPtr &address, MemType flag, size_t size,
+                     const session::KernelWithIndex &node_with_index = std::pair<AnfNodePtr, size_t>(nullptr, 0));
   virtual uint8_t *MallocMem(MemType type, size_t size);
   virtual bool MallocMemFromMemPool(const DeviceAddressPtr address, size_t size);
......