提交 e4b404e8 编写于 作者: M mindspore-ci-bot 提交者: Gitee

!32 auto-enable-dynamic-mem-pool

Merge pull request !32 from JoyLvliang/master
......@@ -239,22 +239,11 @@ DeviceAddressPtr AscendKernelRuntime::CreateDeviceAddress(void *device_ptr, size
return std::make_shared<AscendDeviceAddress>(device_ptr, device_size, format, type_id);
}
void AscendKernelRuntime::MallocOpMemory(const DeviceAddressPtr address, size_t size, int flag) {
MS_EXCEPTION_IF_NULL(MsContext::GetInstance());
if (MsContext::GetInstance()->enable_dynamic_mem_pool()) {
auto device_ptr = AscendMemoryAllocator::GetInstance().AllocTensorMem(size);
MS_EXCEPTION_IF_NULL(device_ptr);
address->ptr_ = device_ptr;
address->mem_dynamic_alloc_ = true;
return;
}
if (flag == kStaticMem) {
address->ptr_ = MallocStaticMem(size, false);
} else if (flag == kDynamicMem) {
address->ptr_ = MallocDynamicMem(size, false);
} else {
MS_LOG(EXCEPTION) << "Unknown memory type!";
}
void AscendKernelRuntime::MallocOpMemory(const DeviceAddressPtr address, size_t size, int) {
auto device_ptr = AscendMemoryAllocator::GetInstance().AllocTensorMem(size);
MS_EXCEPTION_IF_NULL(device_ptr);
address->ptr_ = device_ptr;
address->mem_dynamic_alloc_ = true;
}
bool AscendKernelRuntime::GenTask(const session::KernelGraph *graph) {
......@@ -488,23 +477,18 @@ bool AscendKernelRuntime::DestroyHccl() {
bool AscendKernelRuntime::MallocDeviceMemory() {
device_mem_size_ = ASCEND_MEM_SIZE_BYTE;
MS_EXCEPTION_IF_NULL(MsContext::GetInstance());
if (MsContext::GetInstance()->enable_dynamic_mem_pool()) {
static_mem_offset_ = FloatToSize(device_mem_size_ * GRAPH_INIT_DAVINCI_MEM_RATIO);
device_mem_pool_size_ = FloatToSize(device_mem_size_ * (1 - GRAPH_INIT_DAVINCI_MEM_RATIO));
auto ret = rtMalloc(reinterpret_cast<void **>(&device_mem_pool_base_), device_mem_pool_size_, RT_MEMORY_HBM);
if (ret != RT_ERROR_NONE) {
MS_EXCEPTION(DeviceProcessError) << "rtMalloc mem size[" << device_mem_pool_size_ << "] fail, ret[" << ret << "]";
}
AscendMemoryAllocator::GetInstance().set_device_mem_pool_base(device_mem_pool_base_);
AscendMemoryAllocator::GetInstance().set_device_mem_pool_size(device_mem_pool_size_);
} else {
static_mem_offset_ = device_mem_size_;
static_mem_offset_ = FloatToSize(device_mem_size_ * GRAPH_INIT_ASCEND_MEM_RATIO);
auto ret = rtMalloc(reinterpret_cast<void **>(&device_mem_base_), static_mem_offset_, RT_MEMORY_HBM);
if (ret != RT_ERROR_NONE) {
MS_EXCEPTION(DeviceProcessError) << "rtMalloc mem size[" << static_mem_offset_ << "] fail, ret[" << ret << "]";
}
auto ret = rtMalloc(reinterpret_cast<void **>(&device_mem_base_), device_mem_size_, RT_MEMORY_HBM);
device_mem_pool_size_ = FloatToSize(device_mem_size_ * (1 - GRAPH_INIT_ASCEND_MEM_RATIO));
ret = rtMalloc(reinterpret_cast<void **>(&device_mem_pool_base_), device_mem_pool_size_, RT_MEMORY_HBM);
if (ret != RT_ERROR_NONE) {
MS_EXCEPTION(DeviceProcessError) << "rtMalloc mem size[" << device_mem_size_ << "] fail, ret[" << ret << "]";
MS_EXCEPTION(DeviceProcessError) << "rtMalloc mem size[" << device_mem_pool_size_ << "] fail, ret[" << ret << "]";
}
AscendMemoryAllocator::GetInstance().set_device_mem_pool_base(device_mem_pool_base_);
AscendMemoryAllocator::GetInstance().set_device_mem_pool_size(device_mem_pool_size_);
return true;
}
......
......@@ -26,7 +26,7 @@ const uint64_t MEM_SIZE_BYTE = (MEM_SIZE << 30);
AscendMemoryAllocator::AscendMemoryAllocator() {
hasMalloc_ = false;
free_mem_size_ = FloatToSize(MEM_SIZE_BYTE * (1 - GRAPH_INIT_DAVINCI_MEM_RATIO));
free_mem_size_ = FloatToSize(MEM_SIZE_BYTE * (1 - GRAPH_INIT_ASCEND_MEM_RATIO));
total_mem_size_ = free_mem_size_;
}
......
......@@ -24,7 +24,7 @@ namespace mindspore {
namespace device {
namespace ascend {
// The fraction of total ascend memory used to compute the graph.
static const float GRAPH_INIT_DAVINCI_MEM_RATIO = 0.8;
static const float GRAPH_INIT_ASCEND_MEM_RATIO = 0.8;
class AscendMemoryAllocator : public DynamicMemPoolBestFit {
public:
......
......@@ -497,7 +497,7 @@ void SessionBasic::LoadInputData(const std::shared_ptr<KernelGraph> &kernel_grap
auto device_address = AnfAlgo::GetMutableOutputAddr(pk_node, 0);
bool need_sync = false;
if (ms_context->enable_pynative_infer()) {
if (tensor->device_address().get() == nullptr) {
if (tensor->device_address().get() == nullptr || tensor->device_address() != device_address) {
need_sync = true;
}
} else {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册