提交 b2b1adff 编写于 作者: M mindspore-ci-bot 提交者: Gitee

!191 format memory manager interface

Merge pull request !191 from kisnwang/format-memory-manager-interface
...@@ -262,8 +262,8 @@ AscendDeviceAddress::~AscendDeviceAddress() { ...@@ -262,8 +262,8 @@ AscendDeviceAddress::~AscendDeviceAddress() {
if (ptr_ == nullptr) { if (ptr_ == nullptr) {
return; return;
} }
if (mem_dynamic_alloc_) { if (from_mem_pool_) {
AscendMemoryAllocator::GetInstance().FreeTensorMem(ptr_); AscendMemoryPool::GetInstance().FreeTensorMem(ptr_);
ptr_ = nullptr; ptr_ = nullptr;
} }
} }
......
...@@ -21,7 +21,7 @@ ...@@ -21,7 +21,7 @@
#include <vector> #include <vector>
#include <memory> #include <memory>
#include "device/device_address.h" #include "device/device_address.h"
#include "device/ascend/ascend_memory_allocator.h" #include "device/ascend/ascend_memory_pool.h"
#include "ir/dtype.h" #include "ir/dtype.h"
namespace mindspore { namespace mindspore {
......
...@@ -29,7 +29,7 @@ ...@@ -29,7 +29,7 @@
#include "hccl/hcom.h" #include "hccl/hcom.h"
#include "runtime/context.h" #include "runtime/context.h"
#include "device/ascend/ascend_stream_assign.h" #include "device/ascend/ascend_stream_assign.h"
#include "device/ascend/ascend_memory_allocator.h" #include "device/ascend/ascend_memory_pool.h"
#include "framework/ge_runtime/model_runner.h" #include "framework/ge_runtime/model_runner.h"
#include "device/ascend/tasksink/task_generator.h" #include "device/ascend/tasksink/task_generator.h"
#include "session/anf_runtime_algorithm.h" #include "session/anf_runtime_algorithm.h"
......
...@@ -15,29 +15,31 @@ ...@@ -15,29 +15,31 @@
*/ */
#include "device/ascend/ascend_memory_manager.h" #include "device/ascend/ascend_memory_manager.h"
#include "device/ascend/ascend_memory_allocator.h" #include "device/ascend/ascend_memory_pool.h"
#include "utils/context/ms_context.h" #include "utils/context/ms_context.h"
#include "runtime/mem.h" #include "runtime/mem.h"
namespace mindspore { namespace mindspore {
namespace device { namespace device {
namespace ascend { namespace ascend {
static const uint64_t ASCEND_MEM_SIZE = 20; const uint64_t kAscendDeviceMemGB = 20;
static const uint64_t ASCEND_MEM_SIZE_BYTE = (ASCEND_MEM_SIZE << 30); const uint64_t kAscendMemPoolGB = 5;
const uint64_t kAscendDeviceMemSize = (kAscendDeviceMemGB << 30);
const uint64_t kAscendMemPoolSize = (kAscendMemPoolGB << 30);
void AscendMemoryManager::MallocDeviceMemory() { void AscendMemoryManager::MallocDeviceMemory() {
device_mem_size_ = ASCEND_MEM_SIZE_BYTE; device_mem_size_ = kAscendDeviceMemSize;
static_mem_offset_ = FloatToSize(device_mem_size_ * GRAPH_INIT_ASCEND_MEM_RATIO); static_mem_offset_ = device_mem_size_;
auto ret = rtMalloc(reinterpret_cast<void **>(&device_mem_base_), static_mem_offset_, RT_MEMORY_HBM); auto ret = rtMalloc(reinterpret_cast<void **>(&device_mem_base_), static_mem_offset_, RT_MEMORY_HBM);
if (ret != RT_ERROR_NONE) { if (ret != RT_ERROR_NONE) {
MS_EXCEPTION(DeviceProcessError) << "rtMalloc mem size[" << static_mem_offset_ << "] fail, ret[" << ret << "]"; MS_EXCEPTION(DeviceProcessError) << "rtMalloc mem size[" << static_mem_offset_ << "] fail, ret[" << ret << "]";
} }
device_mem_pool_size_ = FloatToSize(device_mem_size_ * (1 - GRAPH_INIT_ASCEND_MEM_RATIO)); device_mem_pool_size_ = kAscendMemPoolSize;
ret = rtMalloc(reinterpret_cast<void **>(&device_mem_pool_base_), device_mem_pool_size_, RT_MEMORY_HBM); ret = rtMalloc(reinterpret_cast<void **>(&device_mem_pool_base_), device_mem_pool_size_, RT_MEMORY_HBM);
if (ret != RT_ERROR_NONE) { if (ret != RT_ERROR_NONE) {
MS_EXCEPTION(DeviceProcessError) << "rtMalloc mem size[" << device_mem_pool_size_ << "] fail, ret[" << ret << "]"; MS_EXCEPTION(DeviceProcessError) << "rtMalloc mem size[" << device_mem_pool_size_ << "] fail, ret[" << ret << "]";
} }
AscendMemoryAllocator::GetInstance().set_device_mem_pool_base(device_mem_pool_base_); AscendMemoryPool::GetInstance().set_device_mem_pool_base(device_mem_pool_base_);
AscendMemoryAllocator::GetInstance().set_device_mem_pool_size(device_mem_pool_size_); AscendMemoryPool::GetInstance().set_device_mem_pool_size(device_mem_pool_size_);
} }
void AscendMemoryManager::FreeDeviceMemory() { void AscendMemoryManager::FreeDeviceMemory() {
...@@ -57,8 +59,8 @@ void AscendMemoryManager::FreeDeviceMemory() { ...@@ -57,8 +59,8 @@ void AscendMemoryManager::FreeDeviceMemory() {
} }
} }
void *AscendMemoryManager::AllocTensorMemDynamic(size_t size) { void *AscendMemoryManager::MallocMemFromMemPool(size_t size) {
return AscendMemoryAllocator::GetInstance().AllocTensorMem(size); return AscendMemoryPool::GetInstance().AllocTensorMem(size);
} }
} // namespace ascend } // namespace ascend
} // namespace device } // namespace device
......
...@@ -27,7 +27,11 @@ class AscendMemoryManager : public MemoryManager { ...@@ -27,7 +27,11 @@ class AscendMemoryManager : public MemoryManager {
void MallocDeviceMemory() override; void MallocDeviceMemory() override;
void FreeDeviceMemory() override; void FreeDeviceMemory() override;
void *AllocTensorMemDynamic(size_t size) override; void *MallocMemFromMemPool(size_t size) override;
private:
uint8_t *device_mem_pool_base_{nullptr};
uint64_t device_mem_pool_size_{0};
}; };
} // namespace ascend } // namespace ascend
} // namespace device } // namespace device
......
...@@ -14,24 +14,15 @@ ...@@ -14,24 +14,15 @@
* limitations under the License. * limitations under the License.
*/ */
#include "device/ascend/ascend_memory_allocator.h" #include "device/ascend/ascend_memory_pool.h"
#include "device/ascend/ascend_kernel_runtime.h" #include "device/ascend/ascend_kernel_runtime.h"
#include "utils/log_adapter.h" #include "utils/log_adapter.h"
namespace mindspore { namespace mindspore {
namespace device { namespace device {
namespace ascend { namespace ascend {
const uint64_t MEM_SIZE = 20; size_t AscendMemoryPool::AllocDeviceMem(size_t size, DeviceMemPtr* addr) {
const uint64_t MEM_SIZE_BYTE = (MEM_SIZE << 30); if (has_malloc_) {
AscendMemoryAllocator::AscendMemoryAllocator() {
hasMalloc_ = false;
free_mem_size_ = FloatToSize(MEM_SIZE_BYTE * (1 - GRAPH_INIT_ASCEND_MEM_RATIO));
total_mem_size_ = free_mem_size_;
}
size_t AscendMemoryAllocator::AllocDeviceMem(size_t size, DeviceMemPtr* addr) {
if (hasMalloc_) {
MS_LOG(EXCEPTION) << "Has alloc memory pool memory !"; MS_LOG(EXCEPTION) << "Has alloc memory pool memory !";
} }
if (size == 0 || size > free_mem_size_) { if (size == 0 || size > free_mem_size_) {
...@@ -41,35 +32,35 @@ size_t AscendMemoryAllocator::AllocDeviceMem(size_t size, DeviceMemPtr* addr) { ...@@ -41,35 +32,35 @@ size_t AscendMemoryAllocator::AllocDeviceMem(size_t size, DeviceMemPtr* addr) {
if (*addr == nullptr) { if (*addr == nullptr) {
MS_LOG(EXCEPTION) << "Device memory pool base is nullptr, failed to alloc memory pool memory!"; MS_LOG(EXCEPTION) << "Device memory pool base is nullptr, failed to alloc memory pool memory!";
} }
hasMalloc_ = true; has_malloc_ = true;
free_mem_size_ -= size; free_mem_size_ -= size;
return size; return size;
} }
bool AscendMemoryAllocator::FreeDeviceMem(const DeviceMemPtr& addr) { bool AscendMemoryPool::FreeDeviceMem(const DeviceMemPtr& addr) {
MS_EXCEPTION_IF_NULL(addr); MS_EXCEPTION_IF_NULL(addr);
hasMalloc_ = false; has_malloc_ = false;
free_mem_size_ = total_mem_size_; free_mem_size_ = total_mem_size_;
return true; return true;
} }
size_t AscendMemoryAllocator::AlignMemorySize(size_t size) const { size_t AscendMemoryPool::AlignMemorySize(size_t size) const {
if (size == 0) { if (size == 0) {
return DYNAMIC_MEM_ALIGN_SIZE; return DYNAMIC_MEM_ALIGN_SIZE;
} }
return ((size + DYNAMIC_MEM_ALIGN_SIZE + 31) / DYNAMIC_MEM_ALIGN_SIZE) * DYNAMIC_MEM_ALIGN_SIZE; return ((size + DYNAMIC_MEM_ALIGN_SIZE + 31) / DYNAMIC_MEM_ALIGN_SIZE) * DYNAMIC_MEM_ALIGN_SIZE;
} }
size_t AscendMemoryAllocator::mem_alloc_unit_size() const { return free_mem_size_ - 512; } size_t AscendMemoryPool::mem_alloc_unit_size() const { return free_mem_size_ - 512; }
void AscendMemoryAllocator::set_device_mem_pool_base(uint8_t* device_mem_pool_base) { void AscendMemoryPool::set_device_mem_pool_base(uint8_t* device_mem_pool_base) {
MS_EXCEPTION_IF_NULL(device_mem_pool_base); MS_EXCEPTION_IF_NULL(device_mem_pool_base);
device_mem_pool_base_ = device_mem_pool_base; device_mem_pool_base_ = device_mem_pool_base;
} }
size_t AscendMemoryAllocator::free_mem_size() { return free_mem_size_; } size_t AscendMemoryPool::free_mem_size() { return free_mem_size_; }
size_t AscendMemoryAllocator::total_mem_size() { return total_mem_size_; } size_t AscendMemoryPool::total_mem_size() { return total_mem_size_; }
} // namespace ascend } // namespace ascend
} // namespace device } // namespace device
} // namespace mindspore } // namespace mindspore
...@@ -14,8 +14,8 @@ ...@@ -14,8 +14,8 @@
* limitations under the License. * limitations under the License.
*/ */
#ifndef MINDSPORE_CCSRC_DEVICE_ASCEND_ASCEND_MEMORY_ALLOCATOR_H_ #ifndef MINDSPORE_CCSRC_DEVICE_ASCEND_ASCEND_MEMORY_POOL_H_
#define MINDSPORE_CCSRC_DEVICE_ASCEND_ASCEND_MEMORY_ALLOCATOR_H_ #define MINDSPORE_CCSRC_DEVICE_ASCEND_ASCEND_MEMORY_POOL_H_
#include <memory> #include <memory>
#include "pre_activate/mem_reuse/mem_dynamic_allocator.h" #include "pre_activate/mem_reuse/mem_dynamic_allocator.h"
...@@ -23,22 +23,23 @@ ...@@ -23,22 +23,23 @@
namespace mindspore { namespace mindspore {
namespace device { namespace device {
namespace ascend { namespace ascend {
// The fraction of total ascend memory used to compute the graph. class AscendMemoryPool : public DynamicMemPoolBestFit {
static const float GRAPH_INIT_ASCEND_MEM_RATIO = 0.8;
class AscendMemoryAllocator : public DynamicMemPoolBestFit {
public: public:
~AscendMemoryAllocator() override = default; ~AscendMemoryPool() override = default;
size_t AllocDeviceMem(size_t size, DeviceMemPtr* addr) override; size_t AllocDeviceMem(size_t size, DeviceMemPtr* addr) override;
bool FreeDeviceMem(const DeviceMemPtr& addr) override; bool FreeDeviceMem(const DeviceMemPtr& addr) override;
void set_device_mem_pool_base(uint8_t* device_mem_pool_base); void set_device_mem_pool_base(uint8_t* device_mem_pool_base);
void set_device_mem_pool_size(uint64_t device_mem_pool_size) { device_mem_pool_size_ = device_mem_pool_size; } void set_device_mem_pool_size(uint64_t device_mem_pool_size) {
device_mem_pool_size_ = device_mem_pool_size;
free_mem_size_ = device_mem_pool_size_;
total_mem_size_ = free_mem_size_;
}
size_t free_mem_size() override; size_t free_mem_size() override;
size_t total_mem_size() override; size_t total_mem_size() override;
static AscendMemoryAllocator& GetInstance() { static AscendMemoryPool& GetInstance() {
static AscendMemoryAllocator instance; static AscendMemoryPool instance;
return instance; return instance;
} }
...@@ -49,10 +50,10 @@ class AscendMemoryAllocator : public DynamicMemPoolBestFit { ...@@ -49,10 +50,10 @@ class AscendMemoryAllocator : public DynamicMemPoolBestFit {
size_t mem_alloc_unit_size() const override; size_t mem_alloc_unit_size() const override;
private: private:
AscendMemoryAllocator(); AscendMemoryPool() = default;
AscendMemoryAllocator(const AscendMemoryAllocator&) = delete; AscendMemoryPool(const AscendMemoryPool&) = delete;
AscendMemoryAllocator& operator=(const AscendMemoryAllocator&) = delete; AscendMemoryPool& operator=(const AscendMemoryPool&) = delete;
bool hasMalloc_; bool has_malloc_{false};
uint8_t* device_mem_pool_base_{nullptr}; uint8_t* device_mem_pool_base_{nullptr};
uint64_t device_mem_pool_size_{0}; uint64_t device_mem_pool_size_{0};
size_t free_mem_size_; size_t free_mem_size_;
...@@ -62,4 +63,4 @@ class AscendMemoryAllocator : public DynamicMemPoolBestFit { ...@@ -62,4 +63,4 @@ class AscendMemoryAllocator : public DynamicMemPoolBestFit {
} // namespace device } // namespace device
} // namespace mindspore } // namespace mindspore
#endif // MINDSPORE_CCSRC_DEVICE_ASCEND_ASCEND_MEMORY_ALLOCATOR_H_ #endif // MINDSPORE_CCSRC_DEVICE_ASCEND_ASCEND_MEMORY_POOL_H_
...@@ -70,7 +70,7 @@ class DeviceAddress { ...@@ -70,7 +70,7 @@ class DeviceAddress {
size_t ref_count_{0}; size_t ref_count_{0};
string format_{"DefaultFormat"}; string format_{"DefaultFormat"};
TypeId type_id_{kNumberTypeFloat16}; TypeId type_id_{kNumberTypeFloat16};
bool mem_dynamic_alloc_{false}; bool from_mem_pool_{false};
friend class KernelRuntime; friend class KernelRuntime;
friend class MemoryManager; friend class MemoryManager;
friend class mindspore::device::ascend::tasksink::TaskGenerator; friend class mindspore::device::ascend::tasksink::TaskGenerator;
......
...@@ -46,7 +46,7 @@ GPUDeviceAddress::~GPUDeviceAddress() { ...@@ -46,7 +46,7 @@ GPUDeviceAddress::~GPUDeviceAddress() {
} }
auto ms_context = MsContext::GetInstance(); auto ms_context = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(ms_context); MS_EXCEPTION_IF_NULL(ms_context);
if (mem_dynamic_alloc_) { if (from_mem_pool_) {
GPUMemoryAllocator::GetInstance().FreeTensorMem(ptr_); GPUMemoryAllocator::GetInstance().FreeTensorMem(ptr_);
ptr_ = nullptr; ptr_ = nullptr;
} }
......
...@@ -227,7 +227,7 @@ void GPUKernelRuntime::AllocKernelDynamicRes(const mindspore::kernel::KernelMod ...@@ -227,7 +227,7 @@ void GPUKernelRuntime::AllocKernelDynamicRes(const mindspore::kernel::KernelMod
MS_EXCEPTION_IF_NULL(device_address); MS_EXCEPTION_IF_NULL(device_address);
auto device_ptr = device_address->ptr_; auto device_ptr = device_address->ptr_;
if (device_ptr == nullptr) { if (device_ptr == nullptr) {
device_ptr = mem_manager_->AllocTensorMemDynamic(output_sizes[i]); device_ptr = mem_manager_->MallocMemFromMemPool(output_sizes[i]);
MS_EXCEPTION_IF_NULL(device_ptr); MS_EXCEPTION_IF_NULL(device_ptr);
device_address->ptr_ = device_ptr; device_address->ptr_ = device_ptr;
} }
...@@ -244,7 +244,7 @@ void GPUKernelRuntime::AllocKernelDynamicRes(const mindspore::kernel::KernelMod ...@@ -244,7 +244,7 @@ void GPUKernelRuntime::AllocKernelDynamicRes(const mindspore::kernel::KernelMod
kernel_workspaces->emplace_back(nullptr); kernel_workspaces->emplace_back(nullptr);
continue; continue;
} }
auto device_ptr = mem_manager_->AllocTensorMemDynamic(workspace_sizes[i]); auto device_ptr = mem_manager_->MallocMemFromMemPool(workspace_sizes[i]);
MS_EXCEPTION_IF_NULL(device_ptr); MS_EXCEPTION_IF_NULL(device_ptr);
kernel::AddressPtr workspace = std::make_shared<kernel::Address>(); kernel::AddressPtr workspace = std::make_shared<kernel::Address>();
MS_EXCEPTION_IF_NULL(workspace); MS_EXCEPTION_IF_NULL(workspace);
...@@ -292,7 +292,7 @@ void GPUKernelRuntime::AllocCommunicationOpInputDynamicRes(const mindspore::AnfN ...@@ -292,7 +292,7 @@ void GPUKernelRuntime::AllocCommunicationOpInputDynamicRes(const mindspore::AnfN
addr_size.emplace_back(device_address.get(), output_size); addr_size.emplace_back(device_address.get(), output_size);
} }
auto device_mem_ptr = mem_manager_->AllocTensorMemDynamic(total); auto device_mem_ptr = mem_manager_->MallocMemFromMemPool(total);
MS_EXCEPTION_IF_NULL(device_mem_ptr); MS_EXCEPTION_IF_NULL(device_mem_ptr);
for (const auto &iter : addr_size) { for (const auto &iter : addr_size) {
MS_EXCEPTION_IF_NULL(iter.first); MS_EXCEPTION_IF_NULL(iter.first);
...@@ -328,7 +328,7 @@ void GPUKernelRuntime::AllocCommunicationOpOutputDynamicRes(const mindspore::Anf ...@@ -328,7 +328,7 @@ void GPUKernelRuntime::AllocCommunicationOpOutputDynamicRes(const mindspore::Anf
addr_size.emplace_back(device_address.get(), output_sizes[i]); addr_size.emplace_back(device_address.get(), output_sizes[i]);
} }
auto device_mem_ptr = mem_manager_->AllocTensorMemDynamic(total); auto device_mem_ptr = mem_manager_->MallocMemFromMemPool(total);
MS_EXCEPTION_IF_NULL(device_mem_ptr); MS_EXCEPTION_IF_NULL(device_mem_ptr);
for (const auto &iter : addr_size) { for (const auto &iter : addr_size) {
MS_EXCEPTION_IF_NULL(iter.first); MS_EXCEPTION_IF_NULL(iter.first);
...@@ -361,7 +361,7 @@ void GPUKernelRuntime::FreeKernelDynamicRes(const mindspore::AnfNodePtr &kernel, ...@@ -361,7 +361,7 @@ void GPUKernelRuntime::FreeKernelDynamicRes(const mindspore::AnfNodePtr &kernel,
auto device_address = AnfAlgo::GetPrevNodeMutableOutputAddr(kernel, i); auto device_address = AnfAlgo::GetPrevNodeMutableOutputAddr(kernel, i);
MS_EXCEPTION_IF_NULL(device_address); MS_EXCEPTION_IF_NULL(device_address);
MS_EXCEPTION_IF_NULL(device_address->ptr_); MS_EXCEPTION_IF_NULL(device_address->ptr_);
mem_manager_->FreeTensorMemDynamic(device_address->ptr_); mem_manager_->FreeMemFromMemPool(device_address->ptr_);
device_address->ptr_ = nullptr; device_address->ptr_ = nullptr;
} }
} }
...@@ -372,7 +372,7 @@ void GPUKernelRuntime::FreeKernelDynamicRes(const mindspore::AnfNodePtr &kernel, ...@@ -372,7 +372,7 @@ void GPUKernelRuntime::FreeKernelDynamicRes(const mindspore::AnfNodePtr &kernel,
auto workspace = kernel_workspaces[i]; auto workspace = kernel_workspaces[i];
if (workspace != nullptr) { if (workspace != nullptr) {
MS_EXCEPTION_IF_NULL(workspace->addr); MS_EXCEPTION_IF_NULL(workspace->addr);
mem_manager_->FreeTensorMemDynamic(workspace->addr); mem_manager_->FreeMemFromMemPool(workspace->addr);
workspace->addr = nullptr; workspace->addr = nullptr;
} }
} }
...@@ -389,7 +389,7 @@ void GPUKernelRuntime::FreeCommunicationOpDynamicRes(const mindspore::AnfNodePtr ...@@ -389,7 +389,7 @@ void GPUKernelRuntime::FreeCommunicationOpDynamicRes(const mindspore::AnfNodePtr
auto device_address = AnfAlgo::GetPrevNodeMutableOutputAddr(kernel, 0); auto device_address = AnfAlgo::GetPrevNodeMutableOutputAddr(kernel, 0);
MS_EXCEPTION_IF_NULL(device_address); MS_EXCEPTION_IF_NULL(device_address);
MS_EXCEPTION_IF_NULL(device_address->ptr_); MS_EXCEPTION_IF_NULL(device_address->ptr_);
mem_manager_->FreeTensorMemDynamic(device_address->ptr_); mem_manager_->FreeMemFromMemPool(device_address->ptr_);
device_address->ptr_ = nullptr; device_address->ptr_ = nullptr;
} }
*is_communication_op = true; *is_communication_op = true;
...@@ -411,7 +411,7 @@ void GPUKernelRuntime::FreeCommunicationOpDynamicRes(const mindspore::AnfNodePtr ...@@ -411,7 +411,7 @@ void GPUKernelRuntime::FreeCommunicationOpDynamicRes(const mindspore::AnfNodePtr
auto device_address = AnfAlgo::GetMutableOutputAddr(kernel_input.first, 0); auto device_address = AnfAlgo::GetMutableOutputAddr(kernel_input.first, 0);
MS_EXCEPTION_IF_NULL(device_address); MS_EXCEPTION_IF_NULL(device_address);
MS_EXCEPTION_IF_NULL(device_address->ptr_); MS_EXCEPTION_IF_NULL(device_address->ptr_);
mem_manager_->FreeTensorMemDynamic(device_address->ptr_); mem_manager_->FreeMemFromMemPool(device_address->ptr_);
device_address->ptr_ = nullptr; device_address->ptr_ = nullptr;
} }
*is_communication_op = true; *is_communication_op = true;
......
...@@ -21,11 +21,11 @@ ...@@ -21,11 +21,11 @@
namespace mindspore { namespace mindspore {
namespace device { namespace device {
namespace gpu { namespace gpu {
void *GPUMemoryManager::AllocTensorMemDynamic(size_t size) { void *GPUMemoryManager::MallocMemFromMemPool(size_t size) {
return GPUMemoryAllocator::GetInstance().AllocTensorMem(size); return GPUMemoryAllocator::GetInstance().AllocTensorMem(size);
} }
void GPUMemoryManager::FreeTensorMemDynamic(void *device_ptr) { void GPUMemoryManager::FreeMemFromMemPool(void *device_ptr) {
GPUMemoryAllocator::GetInstance().FreeTensorMem(device_ptr); GPUMemoryAllocator::GetInstance().FreeTensorMem(device_ptr);
} }
...@@ -34,7 +34,7 @@ void GPUMemoryManager::MallocDeviceMemory() { ...@@ -34,7 +34,7 @@ void GPUMemoryManager::MallocDeviceMemory() {
MS_EXCEPTION_IF_NULL(context_ptr); MS_EXCEPTION_IF_NULL(context_ptr);
// If use the dynamic memory pool, then alloc the first memory block to init. // If use the dynamic memory pool, then alloc the first memory block to init.
if (context_ptr->enable_dynamic_mem_pool()) { if (context_ptr->enable_dynamic_mem_pool()) {
auto device_addr = AllocTensorMemDynamic(1); auto device_addr = MallocMemFromMemPool(1);
if (!device_addr) { if (!device_addr) {
MS_LOG(ERROR) << "Dynamic memory pool init error."; MS_LOG(ERROR) << "Dynamic memory pool init error.";
} }
...@@ -62,7 +62,7 @@ uint8_t *GPUMemoryManager::MallocStaticMem(size_t size, bool) { ...@@ -62,7 +62,7 @@ uint8_t *GPUMemoryManager::MallocStaticMem(size_t size, bool) {
auto context_ptr = MsContext::GetInstance(); auto context_ptr = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(context_ptr); MS_EXCEPTION_IF_NULL(context_ptr);
if (context_ptr->enable_dynamic_mem_pool()) { if (context_ptr->enable_dynamic_mem_pool()) {
auto device_ptr = AllocTensorMemDynamic(size); auto device_ptr = MallocMemFromMemPool(size);
MS_EXCEPTION_IF_NULL(device_ptr); MS_EXCEPTION_IF_NULL(device_ptr);
return AddressOffset(device_ptr, 0); return AddressOffset(device_ptr, 0);
} }
......
...@@ -28,11 +28,11 @@ class GPUMemoryManager : public MemoryManager { ...@@ -28,11 +28,11 @@ class GPUMemoryManager : public MemoryManager {
void MallocDeviceMemory() override; void MallocDeviceMemory() override;
void FreeDeviceMemory() override; void FreeDeviceMemory() override;
void *AllocTensorMemDynamic(size_t size) override; void *MallocMemFromMemPool(size_t size) override;
void FreeTensorMemDynamic(void *device_ptr) override; void FreeMemFromMemPool(void *device_ptr) override;
protected: protected:
uint8_t *MallocStaticMem(size_t size, bool communication_mem); uint8_t *MallocStaticMem(size_t size, bool communication_mem) override;
}; };
} // namespace gpu } // namespace gpu
} // namespace device } // namespace device
......
...@@ -169,7 +169,7 @@ void KernelRuntime::RunOpAssignInputMemory(const std::vector<tensor::TensorPtr> ...@@ -169,7 +169,7 @@ void KernelRuntime::RunOpAssignInputMemory(const std::vector<tensor::TensorPtr>
auto device_address = auto device_address =
CreateDeviceAddress(nullptr, tensor_size, AnfAlgo::GetOutputFormat(item, index), output_type_id); CreateDeviceAddress(nullptr, tensor_size, AnfAlgo::GetOutputFormat(item, index), output_type_id);
MS_EXCEPTION_IF_NULL(device_address); MS_EXCEPTION_IF_NULL(device_address);
mem_manager_->MallocOpMemory(device_address, tensor_size); mem_manager_->MallocMemFromMemPool(device_address, tensor_size);
AnfAlgo::SetOutputAddr(device_address, index, item.get()); AnfAlgo::SetOutputAddr(device_address, index, item.get());
} }
} }
...@@ -198,7 +198,7 @@ void KernelRuntime::RunOpAssignOutputMemory(const AnfNodePtr &kernel) { ...@@ -198,7 +198,7 @@ void KernelRuntime::RunOpAssignOutputMemory(const AnfNodePtr &kernel) {
auto output_type = AnfAlgo::GetOutputDeviceDataType(kernel, i); auto output_type = AnfAlgo::GetOutputDeviceDataType(kernel, i);
auto device_address = CreateDeviceAddress(nullptr, output_sizes[i], output_format, output_type); auto device_address = CreateDeviceAddress(nullptr, output_sizes[i], output_format, output_type);
MS_EXCEPTION_IF_NULL(device_address); MS_EXCEPTION_IF_NULL(device_address);
mem_manager_->MallocOpMemory(device_address, output_sizes[i]); mem_manager_->MallocMemFromMemPool(device_address, output_sizes[i]);
AnfAlgo::SetOutputAddr(device_address, i, kernel.get()); AnfAlgo::SetOutputAddr(device_address, i, kernel.get());
} }
} }
...@@ -213,7 +213,7 @@ void KernelRuntime::RunOpAssignWorkSpaceMemory(const AnfNodePtr &kernel) { ...@@ -213,7 +213,7 @@ void KernelRuntime::RunOpAssignWorkSpaceMemory(const AnfNodePtr &kernel) {
for (size_t i = 0; i < workspace_lists.size(); ++i) { for (size_t i = 0; i < workspace_lists.size(); ++i) {
auto device_address = CreateDeviceAddress(nullptr, workspace_lists[i], "", kTypeUnknown); auto device_address = CreateDeviceAddress(nullptr, workspace_lists[i], "", kTypeUnknown);
MS_EXCEPTION_IF_NULL(device_address); MS_EXCEPTION_IF_NULL(device_address);
mem_manager_->MallocOpMemory(device_address, workspace_lists[i]); mem_manager_->MallocMemFromMemPool(device_address, workspace_lists[i]);
AnfAlgo::SetWorkspaceAddr(device_address, i, kernel.get()); AnfAlgo::SetWorkspaceAddr(device_address, i, kernel.get());
} }
} }
...@@ -457,7 +457,7 @@ void KernelRuntime::AssignDynamicMemory(session::KernelGraph *graph) { ...@@ -457,7 +457,7 @@ void KernelRuntime::AssignDynamicMemory(session::KernelGraph *graph) {
bool is_enable_mem_reuse = context_ptr->enable_mem_reuse(); bool is_enable_mem_reuse = context_ptr->enable_mem_reuse();
auto mem_flag = kDynamicMem; auto mem_flag = kDynamicMem;
if (is_enable_mem_reuse) { if (is_enable_mem_reuse) {
mem_manager_->InitReuseDynamicMemory(graph); mem_manager_->MallocReusedDynamicMem(graph);
mem_flag = kReuseDynamicMem; mem_flag = kReuseDynamicMem;
} }
auto &kernels = graph->execution_order(); auto &kernels = graph->execution_order();
......
...@@ -33,7 +33,6 @@ ...@@ -33,7 +33,6 @@
#include "utils/context/ms_context.h" #include "utils/context/ms_context.h"
#include "device/memory_manager.h" #include "device/memory_manager.h"
// using mindspore::session::KernelGraph;
using mindspore::tensor::Tensor; using mindspore::tensor::Tensor;
using TensorPtr = std::shared_ptr<Tensor>; using TensorPtr = std::shared_ptr<Tensor>;
using mindspore::kernel::AddressPtr; using mindspore::kernel::AddressPtr;
......
...@@ -21,12 +21,6 @@ using mindspore::memreuse::BestFitMemReuse; ...@@ -21,12 +21,6 @@ using mindspore::memreuse::BestFitMemReuse;
using mindspore::memreuse::MemReuseUtilPtr; using mindspore::memreuse::MemReuseUtilPtr;
namespace mindspore { namespace mindspore {
namespace device { namespace device {
MemoryManager::~MemoryManager() {
device_mem_base_ = nullptr;
device_mem_pool_base_ = nullptr;
mem_reuse_util_ptr_ = nullptr;
}
size_t MemoryManager::GetCommonAlignSize(size_t input_size) const { size_t MemoryManager::GetCommonAlignSize(size_t input_size) const {
return (input_size + kMemAlignSize + 31) / kMemAlignSize * kMemAlignSize; return (input_size + kMemAlignSize + 31) / kMemAlignSize * kMemAlignSize;
} }
...@@ -35,7 +29,7 @@ size_t MemoryManager::GetCommunicationAlignSize(size_t input_size) const { ...@@ -35,7 +29,7 @@ size_t MemoryManager::GetCommunicationAlignSize(size_t input_size) const {
return (input_size + kMemAlignSize - 1) / kMemAlignSize * kMemAlignSize + 2 * kMemAlignSize; return (input_size + kMemAlignSize - 1) / kMemAlignSize * kMemAlignSize + 2 * kMemAlignSize;
} }
void MemoryManager::InitReuseDynamicMemory(session::KernelGraph *graph) { void MemoryManager::MallocReusedDynamicMem(session::KernelGraph *graph) {
MS_EXCEPTION_IF_NULL(graph); MS_EXCEPTION_IF_NULL(graph);
MemReuseUtilPtr mem_reuse_util_ptr = std::make_shared<memreuse::MemReuseUtil>(); MemReuseUtilPtr mem_reuse_util_ptr = std::make_shared<memreuse::MemReuseUtil>();
MS_EXCEPTION_IF_NULL(mem_reuse_util_ptr); MS_EXCEPTION_IF_NULL(mem_reuse_util_ptr);
...@@ -147,23 +141,23 @@ uint8_t *MemoryManager::MallocDynamicMem(size_t size, bool communication_mem) { ...@@ -147,23 +141,23 @@ uint8_t *MemoryManager::MallocDynamicMem(size_t size, bool communication_mem) {
} }
} }
void MemoryManager::MallocOpMemory(const DeviceAddressPtr address, size_t size) { void MemoryManager::MallocMemFromMemPool(const DeviceAddressPtr address, size_t size) {
auto device_ptr = AllocTensorMemDynamic(size); auto device_ptr = MallocMemFromMemPool(size);
MS_EXCEPTION_IF_NULL(device_ptr); MS_EXCEPTION_IF_NULL(device_ptr);
address->ptr_ = device_ptr; address->ptr_ = device_ptr;
address->mem_dynamic_alloc_ = true; address->from_mem_pool_ = true;
} }
void *MemoryManager::AllocTensorMemDynamic(size_t size) { void *MemoryManager::MallocMemFromMemPool(size_t size) {
if (size == 0) { if (size == 0) {
MS_LOG(ERROR) << "AllocTensorMemDynamic size is 0."; MS_LOG(ERROR) << "MallocMemFromMemPool size is 0.";
} }
return nullptr; return nullptr;
} }
void MemoryManager::FreeTensorMemDynamic(void *device_ptr) { void MemoryManager::FreeMemFromMemPool(void *device_ptr) {
if (device_ptr == nullptr) { if (device_ptr == nullptr) {
MS_LOG(ERROR) << "FreeTensorMemDynamic device_ptr is null."; MS_LOG(ERROR) << "FreeMemFromMemPool device_ptr is null.";
} }
} }
} // namespace device } // namespace device
......
...@@ -31,7 +31,7 @@ using MemReuseUtilPtr = mindspore::memreuse::MemReuseUtilPtr; ...@@ -31,7 +31,7 @@ using MemReuseUtilPtr = mindspore::memreuse::MemReuseUtilPtr;
class MemoryManager { class MemoryManager {
public: public:
MemoryManager() = default; MemoryManager() = default;
virtual ~MemoryManager(); virtual ~MemoryManager() = default;
virtual void MallocDeviceMemory() = 0; virtual void MallocDeviceMemory() = 0;
virtual void FreeDeviceMemory() = 0; virtual void FreeDeviceMemory() = 0;
...@@ -40,16 +40,15 @@ class MemoryManager { ...@@ -40,16 +40,15 @@ class MemoryManager {
dynamic_mem_offset_ = 0; dynamic_mem_offset_ = 0;
} }
void InitReuseDynamicMemory(session::KernelGraph *graph); void MallocReusedDynamicMem(session::KernelGraph *graph);
uint8_t *MallocOutputMem(const AnfNodePtr &node, size_t index, int flag, size_t size); uint8_t *MallocOutputMem(const AnfNodePtr &node, size_t index, int flag, size_t size);
uint8_t *MallocWorkSpaceMem(const AnfNodePtr &node, size_t index, int flag, size_t size); uint8_t *MallocWorkSpaceMem(const AnfNodePtr &node, size_t index, int flag, size_t size);
virtual uint8_t *MallocMem(int flag, size_t size); virtual uint8_t *MallocMem(int flag, size_t size);
// Alloc memory use the dynamic memory pool. virtual void MallocMemFromMemPool(const DeviceAddressPtr address, size_t size);
virtual void *AllocTensorMemDynamic(size_t size); virtual void *MallocMemFromMemPool(size_t size);
// Free memory use the dynamic memory pool. virtual void FreeMemFromMemPool(void *device_ptr);
virtual void FreeTensorMemDynamic(void *device_ptr);
virtual void MallocOpMemory(const DeviceAddressPtr address, size_t size);
size_t GetCommonAlignSize(size_t input_size) const; size_t GetCommonAlignSize(size_t input_size) const;
size_t GetCommunicationAlignSize(size_t input_size) const; size_t GetCommunicationAlignSize(size_t input_size) const;
...@@ -57,9 +56,7 @@ class MemoryManager { ...@@ -57,9 +56,7 @@ class MemoryManager {
virtual uint8_t *MallocStaticMem(size_t size, bool communication_mem); virtual uint8_t *MallocStaticMem(size_t size, bool communication_mem);
virtual uint8_t *MallocDynamicMem(size_t size, bool communication_mem); virtual uint8_t *MallocDynamicMem(size_t size, bool communication_mem);
uint8_t *device_mem_base_{nullptr}; uint8_t *device_mem_base_{nullptr};
uint8_t *device_mem_pool_base_{nullptr};
uint64_t device_mem_size_{0}; uint64_t device_mem_size_{0};
uint64_t device_mem_pool_size_{0};
uint64_t dynamic_mem_offset_{0}; uint64_t dynamic_mem_offset_{0};
uint64_t static_mem_offset_{0}; uint64_t static_mem_offset_{0};
size_t total_static_size_ = 0; size_t total_static_size_ = 0;
......
...@@ -95,7 +95,7 @@ file(GLOB_RECURSE MINDSPORE_SRC_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} ...@@ -95,7 +95,7 @@ file(GLOB_RECURSE MINDSPORE_SRC_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR}
"../../../mindspore/ccsrc/device/ascend/ascend_kernel_runtime.cc" "../../../mindspore/ccsrc/device/ascend/ascend_kernel_runtime.cc"
"../../../mindspore/ccsrc/device/ascend/ascend_memory_manager.cc" "../../../mindspore/ccsrc/device/ascend/ascend_memory_manager.cc"
"../../../mindspore/ccsrc/device/ascend/ascend_device_address.cc" "../../../mindspore/ccsrc/device/ascend/ascend_device_address.cc"
"../../../mindspore/ccsrc/device/ascend/ascend_memory_allocator.cc" "../../../mindspore/ccsrc/device/ascend/ascend_memory_pool.cc"
"../../../mindspore/ccsrc/predict/generator/utils/ir_model_util.cc" "../../../mindspore/ccsrc/predict/generator/utils/ir_model_util.cc"
"../../../mindspore/ccsrc/predict/predict.cc" "../../../mindspore/ccsrc/predict/predict.cc"
"../../../mindspore/ccsrc/predict/converter/*.cc" "../../../mindspore/ccsrc/predict/converter/*.cc"
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册