提交 d4b52ac5 编写于 作者: M mindspore-ci-bot 提交者: Gitee

!3489 use kernelruntime::mem_manager to reduce rtMalloc and rtFree time in trans data format

Merge pull request !3489 from lvchangquan/master
...@@ -153,6 +153,16 @@ bool SyncDeviceToHostAndFloatToFloat64(void *dst, size_t dst_size, const void *s ...@@ -153,6 +153,16 @@ bool SyncDeviceToHostAndFloatToFloat64(void *dst, size_t dst_size, const void *s
return true; return true;
} }
// Obtains launch memory for a single op from the kernel runtime registered under kAscendDevice
// for the device id reported by the current MsContext.
// The request (size, format, type) is forwarded as-is to KernelRuntime::AssignSingleOpLaunchMemory
// and its result is returned unchanged. Both the context and the runtime are checked with
// MS_EXCEPTION_IF_NULL before they are used.
DeviceAddressPtr AssignLaunchMemory(size_t size, const std::string &format, TypeId type) {
  auto context = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(context);
  auto ascend_device_id = context->device_id();
  auto kernel_runtime = device::KernelRuntimeManager::Instance().GetKernelRuntime(kAscendDevice, ascend_device_id);
  MS_EXCEPTION_IF_NULL(kernel_runtime);
  return kernel_runtime->AssignSingleOpLaunchMemory(size, format, type);
}
size_t GetCommonAlignSize(size_t input_size) { size_t GetCommonAlignSize(size_t input_size) {
return (input_size + kMemAlignSize + 31) / kMemAlignSize * kMemAlignSize; return (input_size + kMemAlignSize + 31) / kMemAlignSize * kMemAlignSize;
} }
...@@ -325,18 +335,15 @@ void AscendDeviceAddress::LaunchTransData(kernel::KernelModPtr kernel_mod_ptr, v ...@@ -325,18 +335,15 @@ void AscendDeviceAddress::LaunchTransData(kernel::KernelModPtr kernel_mod_ptr, v
AddressPtrList kernel_inputs = {input_address}; AddressPtrList kernel_inputs = {input_address};
AddressPtrList kernel_outputs = {output_address}; AddressPtrList kernel_outputs = {output_address};
AddressPtrList kernel_workspaces; AddressPtrList kernel_workspaces;
std::vector<void *> workspaces_address_ptr(workspace_size_list.size(), nullptr);
if (!workspace_size_list.empty()) { if (!workspace_size_list.empty()) {
for (size_t i = 0; i < workspace_size_list.size(); ++i) { for (size_t i = 0; i < workspace_size_list.size(); ++i) {
auto workspace_size = GetCommonAlignSize(workspace_size_list[i]); auto workspace_size = GetCommonAlignSize(workspace_size_list[i]);
auto ret_malloc = rtMalloc(&workspaces_address_ptr[i], workspace_size, RT_MEMORY_HBM); auto workspace_address_ptr = AssignLaunchMemory(workspace_size, "", kTypeUnknown);
if (ret_malloc != RT_ERROR_NONE) { MS_EXCEPTION_IF_NULL(workspace_address_ptr);
MS_LOG(ERROR) << "Failed to rtMalloc memory";
}
auto workspace_address = std::make_shared<kernel::Address>(); auto workspace_address = std::make_shared<kernel::Address>();
MS_EXCEPTION_IF_NULL(workspace_address); MS_EXCEPTION_IF_NULL(workspace_address);
workspace_address->addr = workspaces_address_ptr[i]; workspace_address->addr = workspace_address_ptr->GetMutablePtr();
workspace_address->size = workspace_size; workspace_address->size = workspace_address_ptr->GetSize();
kernel_workspaces.push_back(workspace_address); kernel_workspaces.push_back(workspace_address);
} }
} }
...@@ -350,15 +357,6 @@ void AscendDeviceAddress::LaunchTransData(kernel::KernelModPtr kernel_mod_ptr, v ...@@ -350,15 +357,6 @@ void AscendDeviceAddress::LaunchTransData(kernel::KernelModPtr kernel_mod_ptr, v
if (!ret) { if (!ret) {
MS_LOG(ERROR) << "Launch kernel failed."; MS_LOG(ERROR) << "Launch kernel failed.";
} }
SyncStream();
if (!workspace_size_list.empty()) {
for (size_t i = 0; i < workspace_size_list.size(); ++i) {
auto ret_free = rtFree(workspaces_address_ptr[i]);
if (ret_free != RT_ERROR_NONE) {
MS_LOG(ERROR) << "Failed to rtFree memory";
}
}
}
} }
kernel::KernelModPtr AscendDeviceAddress::CompileTransDataAndObtainKernelMod(const nlohmann::json &kernel_json) const { kernel::KernelModPtr AscendDeviceAddress::CompileTransDataAndObtainKernelMod(const nlohmann::json &kernel_json) const {
...@@ -418,19 +416,17 @@ bool AscendDeviceAddress::SyncDeviceToHostAndConvertFormatBasedOnTransData(const ...@@ -418,19 +416,17 @@ bool AscendDeviceAddress::SyncDeviceToHostAndConvertFormatBasedOnTransData(const
size = device_dtype_size * shape_size; size = device_dtype_size * shape_size;
} }
size = GetCommonAlignSize(size); size = GetCommonAlignSize(size);
void *output_address_ptr = nullptr; auto output_address = AssignLaunchMemory(size, kOpFormat_NCHW, type_id_);
auto ret_malloc = rtMalloc(&output_address_ptr, size, RT_MEMORY_HBM); MS_EXCEPTION_IF_NULL(output_address);
if (ret_malloc != RT_ERROR_NONE) {
MS_LOG(ERROR) << "Failed to rtMalloc memory";
}
auto workspace_size_list = GetWorkspaceSizeList(kernel_json); auto workspace_size_list = GetWorkspaceSizeList(kernel_json);
// launch // launch
LaunchTransData(kernel_mod_ptr, output_address_ptr, size, workspace_size_list); LaunchTransData(kernel_mod_ptr, output_address->GetMutablePtr(), output_address->GetSize(), workspace_size_list);
SyncStream();
if (type_id_ == type) { if (type_id_ == type) {
SyncMemory(host_ptr, output_address_ptr, host_size, RT_MEMCPY_DEVICE_TO_HOST); SyncMemory(host_ptr, output_address->GetPtr(), host_size, RT_MEMCPY_DEVICE_TO_HOST);
} else { } else {
auto host = std::vector<uint8_t>(size); auto host = std::vector<uint8_t>(size);
SyncMemory(host.data(), output_address_ptr, size, RT_MEMCPY_DEVICE_TO_HOST); SyncMemory(host.data(), output_address->GetPtr(), size, RT_MEMCPY_DEVICE_TO_HOST);
auto shape_size = trans::ShapeSize(host_shape); auto shape_size = trans::ShapeSize(host_shape);
const trans::TypeIdArgs type_args{host.data(), shape_size, type_id_, type, host_size}; const trans::TypeIdArgs type_args{host.data(), shape_size, type_id_, type, host_size};
sync_ok = trans::TransDataType(type_args, host_ptr); sync_ok = trans::TransDataType(type_args, host_ptr);
...@@ -439,10 +435,6 @@ bool AscendDeviceAddress::SyncDeviceToHostAndConvertFormatBasedOnTransData(const ...@@ -439,10 +435,6 @@ bool AscendDeviceAddress::SyncDeviceToHostAndConvertFormatBasedOnTransData(const
return false; return false;
} }
} }
auto ret_free = rtFree(output_address_ptr);
if (ret_free != RT_ERROR_NONE) {
MS_LOG(ERROR) << "Failed to rtFree memory";
}
return sync_ok; return sync_ok;
} }
......
...@@ -842,9 +842,10 @@ void KernelRuntime::ClearGraphRuntimeResource(uint32_t graph_id) { ...@@ -842,9 +842,10 @@ void KernelRuntime::ClearGraphRuntimeResource(uint32_t graph_id) {
MS_LOG(INFO) << "Clear graph:" << graph_id << " runtime resource"; MS_LOG(INFO) << "Clear graph:" << graph_id << " runtime resource";
} }
bool KernelRuntime::LaunchTaskBasedOnSingleKernel(kernel::KernelModPtr kernel_mod_ptr, AddressPtrList kernel_inputs, bool KernelRuntime::LaunchTaskBasedOnSingleKernel(kernel::KernelModPtr kernel_mod_ptr,
AddressPtrList kernel_outputs, const AddressPtrList &kernel_inputs,
AddressPtrList kernel_workspaces) const { const AddressPtrList &kernel_outputs,
const AddressPtrList &kernel_workspaces) const {
MS_EXCEPTION_IF_NULL(kernel_mod_ptr); MS_EXCEPTION_IF_NULL(kernel_mod_ptr);
auto ret = kernel_mod_ptr->Launch(kernel_inputs, kernel_workspaces, kernel_outputs, stream_); auto ret = kernel_mod_ptr->Launch(kernel_inputs, kernel_workspaces, kernel_outputs, stream_);
if (!ret) { if (!ret) {
...@@ -854,6 +855,15 @@ bool KernelRuntime::LaunchTaskBasedOnSingleKernel(kernel::KernelModPtr kernel_mo ...@@ -854,6 +855,15 @@ bool KernelRuntime::LaunchTaskBasedOnSingleKernel(kernel::KernelModPtr kernel_mo
return true; return true;
} }
// Creates a device address describing `size` bytes with the given format and type, then backs it
// with memory obtained from mem_manager_->MallocMem(kDynamicMem, size).
//
// size:   number of bytes to request; passed both to CreateDeviceAddress and to MallocMem.
// format: format string recorded on the created device address.
// type:   element type id recorded on the created device address.
//
// Returns the populated device address. The created address, mem_manager_ and the allocated
// pointer are each checked with MS_EXCEPTION_IF_NULL, so a failure in any step is reported here
// instead of surfacing later as a launch on a null device pointer.
DeviceAddressPtr KernelRuntime::AssignSingleOpLaunchMemory(size_t size, const std::string &format, TypeId type) {
  auto device_address = CreateDeviceAddress(nullptr, size, format, type);
  MS_EXCEPTION_IF_NULL(device_address);
  MS_EXCEPTION_IF_NULL(mem_manager_);
  auto base_ptr = mem_manager_->MallocMem(kDynamicMem, size);
  // Callers only null-check the returned DeviceAddressPtr, not the raw pointer it wraps, so a
  // failed allocation has to be rejected before it is stored.
  MS_EXCEPTION_IF_NULL(base_ptr);
  device_address->set_ptr(base_ptr);
  return device_address;
}
#ifdef ENABLE_DUMP_E2E #ifdef ENABLE_DUMP_E2E
bool KernelRuntime::SetDumpConf() { bool KernelRuntime::SetDumpConf() {
dump_conf_ptr_ = std::make_shared<Dump>(); dump_conf_ptr_ = std::make_shared<Dump>();
......
...@@ -65,8 +65,9 @@ class KernelRuntime { ...@@ -65,8 +65,9 @@ class KernelRuntime {
virtual bool RunTask(const session::KernelGraph *graph); virtual bool RunTask(const session::KernelGraph *graph);
virtual bool GenTask(const session::KernelGraph *graph); virtual bool GenTask(const session::KernelGraph *graph);
bool LaunchKernel(const session::KernelGraph *graph); bool LaunchKernel(const session::KernelGraph *graph);
bool LaunchTaskBasedOnSingleKernel(kernel::KernelModPtr kernel_mod_ptr, AddressPtrList kernel_inputs, bool LaunchTaskBasedOnSingleKernel(kernel::KernelModPtr kernel_mod_ptr, const AddressPtrList &kernel_inputs,
AddressPtrList kernel_outputs, AddressPtrList kernel_workspaces) const; const AddressPtrList &kernel_outputs,
const AddressPtrList &kernel_workspaces) const;
virtual void AssignStaticMemoryInput(const session::KernelGraph *graph); virtual void AssignStaticMemoryInput(const session::KernelGraph *graph);
virtual void AssignStaticMemoryValueNode(session::KernelGraph *graph); virtual void AssignStaticMemoryValueNode(session::KernelGraph *graph);
virtual void ClearGraphRuntimeResource(uint32_t graph_id); virtual void ClearGraphRuntimeResource(uint32_t graph_id);
...@@ -79,6 +80,7 @@ class KernelRuntime { ...@@ -79,6 +80,7 @@ class KernelRuntime {
// for GPU and D to impl // for GPU and D to impl
virtual void ReleaseDeviceRes() {} virtual void ReleaseDeviceRes() {}
void set_device_id(uint32_t device_id) { device_id_ = device_id; } void set_device_id(uint32_t device_id) { device_id_ = device_id; }
DeviceAddressPtr AssignSingleOpLaunchMemory(size_t size, const std::string &format, TypeId type);
protected: protected:
virtual DeviceAddressPtr CreateDeviceAddress(void *device_ptr, size_t device_size, const string &format, virtual DeviceAddressPtr CreateDeviceAddress(void *device_ptr, size_t device_size, const string &format,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册