提交 c67e5623 编写于 作者: L lizhenyu

refine GPU memory swap performance

上级 28c8a5cc
......@@ -46,7 +46,7 @@ struct KernelExecutionInfo {
size_t swap_in_task_num_{0};
// Key: output index, value: topo orders of node users
std::map<size_t, std::vector<size_t>> node_users_map_;
// Key: output idx, value: (host addr, dirty or not)
// Key: output index, value: pair (host addr, dirty or not)
std::map<size_t, std::pair<HostAddress, bool>> host_addrs_;
KernelExecutionInfo() {}
......@@ -105,7 +105,12 @@ class MemCopyManager {
virtual void AddMemSwapOutTask(const DeviceAddressPtr &device_address, const HostAddress &host_addr) {}
virtual void AddMemSwapInTask(const DeviceAddressPtr &device_address, const HostAddress &host_addr) {}
virtual void AddMemSwapInTask(const DeviceAddressPtr &device_address, const HostAddress &host_addr, bool profiling,
float *cost_time) {}
virtual void AddMemSwapOutTaskMock(const DeviceAddressPtr &device_address) {}
virtual void AddMemSwapInTaskMock(const DeviceAddressPtr &device_address) {}
virtual bool SyncMemCopyStream(SwapKind swap_kind) { return true; }
......@@ -113,11 +118,17 @@ class MemCopyManager {
virtual DeviceAddressPtr UpdateSwapInQueue() { return nullptr; }
virtual DeviceAddressPtr UpdateSwapOutQueueMock() { return nullptr; }
virtual DeviceAddressPtr UpdateSwapInQueueMock() { return nullptr; }
virtual bool AllocHostPinnedMem(size_t size, void **addr) const { return true; }
virtual void FreeHostPinnedMem(void *addr) const {}
virtual void ClearSwapQueue() {}
virtual void ClearSwapQueueMock() {}
};
using MemCopyManagerPtr = std::shared_ptr<MemCopyManager>;
using MemSwapInfoSet = std::set<MemSwapInfo, SwapInfoComp>;
......
......@@ -147,6 +147,30 @@ bool MemSwapManager::CheckDistanceBetweenKernels(const TensorInfo &tensor_info)
return false;
}
std::vector<std::pair<size_t, size_t>> MemSwapManager::CheckDistanceBetweenKernelsWithIdx(
const TensorInfo &tensor_info) const {
const AnfNodePtr &kernel = tensor_info.kernel_;
auto &kernel_exec_info = SearchKernelExecutionInfo(kernel);
auto &node_users_map = kernel_exec_info.node_users_map_;
std::vector<std::pair<size_t, size_t>> need_swap_topo_pair_list;
auto iter = node_users_map.find(tensor_info.output_idx_);
if (iter == node_users_map.end()) {
return need_swap_topo_pair_list;
}
auto &node_users = iter->second;
if (node_users.front() - kernel_exec_info.topo_order_ > distance_threshold_) {
need_swap_topo_pair_list.emplace_back(kernel_exec_info.topo_order_, node_users.front());
}
for (size_t i = 1; i < node_users.size(); ++i) {
if (node_users[i] - node_users[i - 1] > distance_threshold_) {
need_swap_topo_pair_list.emplace_back(node_users[i - 1], node_users[i]);
}
}
return need_swap_topo_pair_list;
}
bool MemSwapManager::IsCommunicationRelevantOp(const AnfNodePtr &kernel) const {
MS_EXCEPTION_IF_NULL(kernel);
if (AnfAlgo::IsCommunicationOp(kernel)) {
......@@ -201,56 +225,55 @@ void MemSwapManager::AddSwapInfo() {
break;
}
size_t output_idx = tensor.output_idx_;
const AnfNodePtr &kernel = tensor.kernel_;
if (IsCommunicationRelevantOp(kernel)) {
continue;
}
auto &kernel_exec_info = SearchKernelExecutionInfo(kernel);
auto &node_users_map = kernel_exec_info.node_users_map_;
auto iter = node_users_map.find(output_idx);
if (iter == node_users_map.end()) {
continue;
}
auto &node_users = iter->second;
bool need_swap = (node_users.size() == 1 && node_users[0] - kernel_exec_info.topo_order_ >= distance_threshold_) ||
(node_users.size() > 1 && node_users[1] - node_users[0] >= distance_threshold_);
if (!need_swap) {
auto need_swap_topo_pair_list = CheckDistanceBetweenKernelsWithIdx(tensor);
if (need_swap_topo_pair_list.empty()) {
continue;
}
HostAddress host_addr;
host_addr.size = tensor_size;
auto ret = AllocHostPinnedMem(tensor_size, reinterpret_cast<void **>(&host_addr.addr));
if (!ret) {
MS_LOG(EXCEPTION) << "Alloc host pinned memory[" << tensor_size << "] failed.";
}
host_addr.addr = nullptr;
size_t output_idx = tensor.output_idx_;
auto &kernel_exec_info = SearchKernelExecutionInfo(kernel);
kernel_exec_info.host_addrs_[output_idx] = std::make_pair(host_addr, true);
MemSwapInfo mem_swap_out_info = {SwapKind::kDeviceToHost, kernel_exec_info.topo_order_, output_idx, 0};
if (node_users.size() > 1) {
AddKernelMemSwapInfo(execution_order_[node_users[0]], mem_swap_out_info);
} else {
AddKernelMemSwapInfo(kernel, mem_swap_out_info);
}
size_t swap_in_order = node_users.size() == 1 ? node_users[0] - 1 : node_users[1] - 1;
if (swap_in_order <= kernel_exec_info.topo_order_) {
MS_LOG(EXCEPTION) << "Select swap in point failed for op[" << AnfAlgo::GetCNodeName(kernel) << "]";
for (auto &swap_topo_pair : need_swap_topo_pair_list) {
size_t swap_out_order = swap_topo_pair.first;
MemSwapInfo mem_swap_out_info = {SwapKind::kDeviceToHost, kernel_exec_info.topo_order_, output_idx,
swap_out_order};
AddKernelMemSwapInfo(execution_order_[swap_out_order], mem_swap_out_info);
size_t swap_in_order = swap_topo_pair.second - 1;
MemSwapInfo mem_swap_in_info = {SwapKind::kHostToDevice, kernel_exec_info.topo_order_, output_idx,
swap_out_order};
if (swap_in_order <= swap_out_order) {
MS_LOG(EXCEPTION) << "Select swap in point failed for op[" << AnfAlgo::GetCNodeName(kernel) << "]";
}
AddKernelMemSwapInfo(execution_order_[swap_in_order], mem_swap_in_info);
}
auto swap_in_kernel = execution_order_[swap_in_order];
MemSwapInfo mem_swap_in_info = {SwapKind::kHostToDevice, kernel_exec_info.topo_order_, output_idx, 0};
AddKernelMemSwapInfo(swap_in_kernel, mem_swap_in_info);
host_addrs_list_.push_back(host_addr);
}
}
void MemSwapManager::AddMemSwapTask(SwapKind swap_kind, const DeviceAddressPtr &device_address,
const HostAddress &host_address) const {
const HostAddress &host_address, bool mock, bool profiling,
float *cost_time) const {
if (!mock) {
if (swap_kind == SwapKind::kDeviceToHost) {
mem_copy_manager_->AddMemSwapOutTask(device_address, host_address);
} else if (swap_kind == SwapKind::kHostToDevice) {
mem_copy_manager_->AddMemSwapInTask(device_address, host_address, profiling, cost_time);
}
}
if (swap_kind == SwapKind::kDeviceToHost) {
mem_copy_manager_->AddMemSwapOutTask(device_address, host_address);
mem_copy_manager_->AddMemSwapOutTaskMock(device_address);
} else if (swap_kind == SwapKind::kHostToDevice) {
mem_copy_manager_->AddMemSwapInTask(device_address, host_address);
mem_copy_manager_->AddMemSwapInTaskMock(device_address);
}
}
......@@ -258,11 +281,19 @@ bool MemSwapManager::SyncMemCopyStream(SwapKind swap_kind) const {
return mem_copy_manager_->SyncMemCopyStream(swap_kind);
}
DeviceAddressPtr MemSwapManager::UpdateSwapQueue(SwapKind swap_kind) const {
DeviceAddressPtr MemSwapManager::UpdateSwapQueue(SwapKind swap_kind, bool mock) const {
if (!mock) {
if (swap_kind == SwapKind::kDeviceToHost) {
return mem_copy_manager_->UpdateSwapOutQueue();
} else {
return mem_copy_manager_->UpdateSwapInQueue();
}
}
if (swap_kind == SwapKind::kDeviceToHost) {
return mem_copy_manager_->UpdateSwapOutQueue();
return mem_copy_manager_->UpdateSwapOutQueueMock();
} else {
return mem_copy_manager_->UpdateSwapInQueue();
return mem_copy_manager_->UpdateSwapInQueueMock();
}
}
......@@ -273,19 +304,7 @@ bool MemSwapManager::RetreatSwapInfo() {
}
if (swap_info_already_set_) {
ResetSwapInfo();
if (distance_threshold_ >= kDistanceLowerBound) {
auto distance_decay_step = execution_order_.size() / kDistanceInitFactor / tensor_size_num_;
distance_threshold_ -= (distance_decay_step > 1 ? distance_decay_step : 1);
}
while (tensor_size_threshold_idx_ < ordered_tensors_.size() - 1) {
++tensor_size_threshold_idx_;
if (tensor_size_threshold_ > ordered_tensors_[tensor_size_threshold_idx_].tensor_size_) {
tensor_size_threshold_ = ordered_tensors_[tensor_size_threshold_idx_].tensor_size_;
break;
}
}
RetreatSwapThreshold();
if (tensor_size_threshold_idx_ == ordered_tensors_.size() - 1 && distance_threshold_ < kDistanceLowerBound) {
MS_LOG(ERROR) << "Retreat swap info failed";
return false;
......@@ -373,7 +392,7 @@ bool MemSwapManager::QueryFirstTimeMovePos(const AnfNodePtr &kernel, size_t inde
}
size_t MemSwapManager::BestSwapInPerformPos(const AnfNodePtr &trigger_kernel, const MemSwapInfo &mem_swap_info) const {
auto need_swap_kernel = QueryKerneByTopoOrder(mem_swap_info.topo_order_);
auto need_swap_kernel = QueryKernelByTopoOrder(mem_swap_info.topo_order_);
const PerformPair &perform_pair = QueryKernelSwapPerform(need_swap_kernel, mem_swap_info.output_idx_);
float swap_in_cost_time = perform_pair.second;
size_t swap_out_pos = mem_swap_info.swap_out_pos_;
......@@ -383,11 +402,11 @@ size_t MemSwapManager::BestSwapInPerformPos(const AnfNodePtr &trigger_kernel, co
size_t pos = trigger_kernel_pos;
for (; pos > swap_out_pos + 1; pos--) {
auto kernel = QueryKerneByTopoOrder(pos - 1);
auto kernel = QueryKernelByTopoOrder(pos - 1);
if (QueryKernelTriggerSwapIn(kernel)) {
return pos;
}
kernel_execution_time += QueryKernelExecutionPerform(QueryKerneByTopoOrder(pos));
kernel_execution_time += QueryKernelExecutionPerform(QueryKernelByTopoOrder(pos));
if (kernel_execution_time >= swap_in_cost_time) {
return pos - 1;
}
......@@ -399,8 +418,8 @@ void MemSwapManager::MoveSwapInfoPos(size_t des_pos, size_t src_pos, const MemSw
if (des_pos == src_pos) {
MS_LOG(EXCEPTION) << "destination pos can not equal source pos";
}
auto des_kernel = QueryKerneByTopoOrder(des_pos);
auto src_kernel = QueryKerneByTopoOrder(src_pos);
auto des_kernel = QueryKernelByTopoOrder(des_pos);
auto src_kernel = QueryKernelByTopoOrder(src_pos);
AddKernelMemSwapInfo(des_kernel, mem_swap_info);
RemoveKernelMemSwapInfo(src_kernel, mem_swap_info);
}
......@@ -422,7 +441,10 @@ void MemSwapManager::AddKernelExecutionPerform(const AnfNodePtr &kernel, float p
void MemSwapManager::AddKernelSwapPerform(const AnfNodePtr &kernel, size_t output_idx,
const std::pair<float, float> &perform) {
MS_EXCEPTION_IF_NULL(kernel);
kernel_swap_perform_[kernel.get()][output_idx] = perform;
auto iter = kernel_swap_perform_.find(kernel.get());
if (iter == kernel_swap_perform_.end()) {
kernel_swap_perform_[kernel.get()][output_idx] = perform;
}
}
void MemSwapManager::AddKernelMemSwapInfo(const AnfNodePtr &kernel, const MemSwapInfo &mem_swap_info) {
......@@ -485,13 +507,18 @@ size_t MemSwapManager::QueryKernelTriggerSwapInTaskNum(const AnfNodePtr &kernel)
return kernel_exec_info.swap_in_task_num_;
}
const AnfNodePtr MemSwapManager::QueryKerneByTopoOrder(size_t index) const {
const AnfNodePtr MemSwapManager::QueryKernelByTopoOrder(size_t index) const {
if (index >= execution_order_.size()) {
MS_LOG(EXCEPTION) << "Index [" << index << "] out of range";
}
return execution_order_[index];
}
size_t MemSwapManager::QueryKernelTopoOrder(const AnfNodePtr &kernel) const {
const auto &kernel_exec_info = SearchKernelExecutionInfo(kernel);
return kernel_exec_info.topo_order_;
}
const PerformPair &MemSwapManager::QueryKernelSwapPerform(const AnfNodePtr &kernel, size_t output_idx) const {
MS_EXCEPTION_IF_NULL(kernel);
auto iter_kernel = kernel_swap_perform_.find(kernel.get());
......@@ -572,13 +599,6 @@ void MemSwapManager::ResetHostAddrIsDirty() {
}
}
void MemSwapManager::InsertSwapInBlackList(const void *device_ptr) { swap_in_blacklist_.insert(device_ptr); }
bool MemSwapManager::FindInSwapInBlackList(const void *device_ptr) const {
auto iter = swap_in_blacklist_.find(device_ptr);
return iter != swap_in_blacklist_.end();
}
bool MemSwapManager::AllocHostPinnedMem(size_t size, void **addr) const {
return mem_copy_manager_->AllocHostPinnedMem(size, addr);
}
......@@ -592,10 +612,16 @@ void MemSwapManager::ReleaseHostPinnedMem() {
host_addrs_list_.clear();
}
void MemSwapManager::ClearSwapQueue() const { mem_copy_manager_->ClearSwapQueue(); }
void MemSwapManager::ClearSwapQueue(bool mock) const {
if (!mock) {
mem_copy_manager_->ClearSwapQueue();
} else {
mem_copy_manager_->ClearSwapQueueMock();
}
}
void MemSwapManager::ResetSwapInfo() {
ClearSwapQueue();
ClearSwapQueue(true);
for (auto &kernel_exec_info_pair : kernel_execution_info_) {
auto &kernel_exec_info = kernel_exec_info_pair.second;
kernel_exec_info.trigger_swap_out_ = false;
......@@ -603,10 +629,53 @@ void MemSwapManager::ResetSwapInfo() {
kernel_exec_info.swap_in_task_num_ = 0;
kernel_exec_info.host_addrs_.clear();
}
ReleaseHostPinnedMem();
swap_in_blacklist_.clear();
mem_swap_info_map_.clear();
}
void MemSwapManager::DumpSwapInfo() const {
for (auto &kernel : execution_order_) {
if (!QueryKernelTriggerSwap(kernel)) {
continue;
}
auto &kernel_exec_info = SearchKernelExecutionInfo(kernel);
MS_LOG(WARNING) << "Trigger kernel topo order[" << kernel_exec_info.topo_order_ << "] , op name["
<< AnfAlgo::GetCNodeName(kernel) << "]";
const MemSwapInfoSet &mem_swap_info_set = QueryKernelMemSwapInfo(kernel);
for (auto &mem_swap_info : mem_swap_info_set) {
if (mem_swap_info.swap_kind_ == SwapKind::kDeviceToHost) {
MS_LOG(WARNING) << " Swap Out Task: swapped kernel topo order[" << mem_swap_info.topo_order_ << "], op name["
<< AnfAlgo::GetCNodeName(QueryKernelByTopoOrder(mem_swap_info.topo_order_)) << "], output idx["
<< mem_swap_info.output_idx_ << "]";
} else {
MS_LOG(WARNING) << " Swap In Task: swapped kernel topo order[" << mem_swap_info.topo_order_ << "], op name["
<< AnfAlgo::GetCNodeName(QueryKernelByTopoOrder(mem_swap_info.topo_order_)) << "], output idx["
<< mem_swap_info.output_idx_ << "]";
}
}
}
}
void MemSwapManager::DumpUserNodes() const {
for (auto &kernel : execution_order_) {
const auto &kernel_exec_info = SearchKernelExecutionInfo(kernel);
const auto &node_users_map = kernel_exec_info.node_users_map_;
MS_LOG(WARNING) << "Kernel topo order[" << kernel_exec_info.topo_order_ << "], op name["
<< AnfAlgo::GetCNodeName(kernel) << "]";
if (node_users_map.empty()) {
MS_LOG(WARNING) << " Kernel does not own any user node";
}
for (auto &item : node_users_map) {
size_t output_idx = item.first;
auto &node_users = item.second;
for (auto &order : node_users) {
MS_LOG(WARNING) << " Output index[" << output_idx << "] tensor is used by kernel["
<< AnfAlgo::GetCNodeName(QueryKernelByTopoOrder(order)) << "], topo order[" << order << "]";
}
}
}
}
} // namespace memswap
} // namespace device
} // namespace mindspore
......@@ -48,12 +48,12 @@ class MemSwapManager {
bool Init(const mindspore::session::KernelGraph *kernel_graph, size_t swap_mem_size = 0);
void AddMemSwapTask(SwapKind swap_kind, const DeviceAddressPtr &device_address,
const HostAddress &host_address) const;
void AddMemSwapTask(SwapKind swap_kind, const DeviceAddressPtr &device_address, const HostAddress &host_address,
bool mock, bool profiling = false, float *cost_time = nullptr) const;
bool SyncMemCopyStream(SwapKind swap_kind) const;
DeviceAddressPtr UpdateSwapQueue(SwapKind swap_kind) const;
DeviceAddressPtr UpdateSwapQueue(SwapKind swap_kind, bool mock) const;
bool RetreatSwapInfo();
......@@ -63,8 +63,6 @@ class MemSwapManager {
bool mem_swap_init() const { return mem_swap_initialized_; }
KernelExecutionInfo &SearchKernelExecutionInfo(const AnfNodePtr &kernel) const;
void AddKernelExecutionPerform(const AnfNodePtr &kernel, float perform);
float QueryKernelExecutionPerform(const AnfNodePtr &kernel) const;
......@@ -79,7 +77,9 @@ class MemSwapManager {
size_t QueryKernelTriggerSwapInTaskNum(const AnfNodePtr &kernel) const;
const AnfNodePtr QueryKerneByTopoOrder(size_t index) const;
const AnfNodePtr QueryKernelByTopoOrder(size_t index) const;
size_t QueryKernelTopoOrder(const AnfNodePtr &kernel) const;
const MemSwapInfoSet &QueryKernelMemSwapInfo(const AnfNodePtr &kernel) const;
......@@ -93,17 +93,19 @@ class MemSwapManager {
void ResetHostAddrIsDirty();
void InsertSwapInBlackList(const void *device_ptr);
bool FindInSwapInBlackList(const void *device_ptr) const;
bool AllocHostPinnedMem(size_t size, void **addr) const;
void ReleaseHostPinnedMem();
void ClearSwapQueue() const;
void ClearSwapQueue(bool mock) const;
void DumpSwapInfo() const;
void DumpUserNodes() const;
private:
KernelExecutionInfo &SearchKernelExecutionInfo(const AnfNodePtr &kernel) const;
void AddSwapInfo();
void ResetSwapInfo();
......@@ -130,6 +132,8 @@ class MemSwapManager {
bool CheckDistanceBetweenKernels(const TensorInfo &tensor_info) const;
std::vector<std::pair<size_t, size_t>> CheckDistanceBetweenKernelsWithIdx(const TensorInfo &tensor_info) const;
bool IsCommunicationRelevantOp(const AnfNodePtr &kernel) const;
std::vector<CNodePtr> execution_order_;
......@@ -139,7 +143,6 @@ class MemSwapManager {
// Key: trigger swap kernel, value: MemSwapInfoSet of kernel need to be swapped
std::unordered_map<void *, MemSwapInfoSet> mem_swap_info_map_;
std::vector<HostAddress> host_addrs_list_;
std::unordered_set<const void *> swap_in_blacklist_;
// Key: cache kernel address, value: lists of first time move pos or not
std::map<void *, std::vector<bool>> kernel_first_move_cache_map_;
......
......@@ -112,7 +112,7 @@ void GPUKernelRuntime::ReleaseDeviceRes() {
auto &mem_swap_manager = item.second;
MS_EXCEPTION_IF_NULL(mem_swap_manager);
if (mem_swap_manager->trigger_swap()) {
mem_swap_manager->ClearSwapQueue();
mem_swap_manager->ClearSwapQueue(false);
mem_swap_manager->ReleaseHostPinnedMem();
}
}
......@@ -141,6 +141,7 @@ void GPUKernelRuntime::AssignMemory(session::KernelGraph *graph) {
InitMemorySwapInfo(graph);
InitKernelOutputAddress(graph);
InitKernelWorkspaceAddress(graph);
SaveGraphOutputNode(graph);
} else {
AssignDynamicMemory(graph);
}
......@@ -168,12 +169,8 @@ bool GPUKernelRuntime::Run(session::KernelGraph *graph) {
}
mem_reuse_util_ = mem_reuse_iter->second;
MS_EXCEPTION_IF_NULL(mem_reuse_util_);
while (!LaunchKernelDynamic(graph)) {
MS_LOG(WARNING) << "Run out of memory and try memory swapping, it may take some time, please wait a moment.";
if (!UpdateMemorySwapInfo(graph)) {
return false;
}
}
ret = RunOneStep(graph);
} else {
ret = LaunchKernel(graph);
}
......@@ -185,7 +182,29 @@ bool GPUKernelRuntime::Run(session::KernelGraph *graph) {
return ret;
}
bool GPUKernelRuntime::RunOneStep(const session::KernelGraph *graph) {
bool ret = true;
auto graph_id = graph->graph_id();
if (!is_first_step_map_[graph_id]) {
// Normally run graph
ret = LaunchKernelDynamic(graph);
} else {
// Mock run first step
ret = LaunchKernelDynamic(graph, true, false);
if (ret) {
// Normally run graph
ret = LaunchKernelDynamic(graph);
} else {
// Trigger memory swap
ret = SearchMemSwapScheme(graph);
}
is_first_step_map_[graph_id] = false;
}
return ret;
}
bool GPUKernelRuntime::SearchMemSwapScheme(const session::KernelGraph *graph) {
MS_LOG(WARNING) << "Run out of memory and try memory swapping, it may take some time, please wait a moment.";
bool ret = false;
ClearKernelOldOutputAndWorkspace(graph);
if (!mem_swap_manager_->mem_swap_init()) {
......@@ -214,6 +233,7 @@ bool GPUKernelRuntime::SearchMemSwapScheme(const session::KernelGraph *graph) {
}
bool GPUKernelRuntime::RefineMemSwapScheme(const session::KernelGraph *graph) {
MS_LOG(WARNING) << "Refine memory swap scheme, it may take some time, please wait a moment.";
auto &kernels = graph->execution_order();
for (const auto &kernel : kernels) {
if (!mem_swap_manager_->QueryKernelTriggerSwapIn(kernel)) {
......@@ -228,6 +248,7 @@ bool GPUKernelRuntime::RefineMemSwapScheme(const session::KernelGraph *graph) {
ret = LaunchKernelDynamic(graph, true, false);
if (!ret) {
ClearKernelOldOutputAndWorkspace(graph);
ClearSwapInfo(true);
}
}
}
......@@ -297,6 +318,26 @@ void GPUKernelRuntime::InitKernelWorkspaceAddress(const session::KernelGraph *gr
}
}
void GPUKernelRuntime::SaveGraphOutputNode(const session::KernelGraph *graph) {
MS_EXCEPTION_IF_NULL(graph);
auto graph_id = graph->graph_id();
const auto &output_nodes = AnfAlgo::GetAllOutput(graph->output(), {prim::kPrimTupleGetItem});
for (const auto &node : output_nodes) {
graph_output_map_[graph_id].insert(node);
}
}
bool GPUKernelRuntime::IsGraphOutput(const session::KernelGraph *graph, const mindspore::AnfNodePtr &kernel) const {
MS_EXCEPTION_IF_NULL(graph);
auto graph_id = graph->graph_id();
auto iter = graph_output_map_.find(graph_id);
if (iter == graph_output_map_.end()) {
MS_LOG(EXCEPTION) << "Find graph output info failed.";
}
auto &graph_output_set = iter->second;
return (graph_output_set.find(kernel) != graph_output_set.end());
}
void GPUKernelRuntime::ClearKernelOldOutputAndWorkspace(const session::KernelGraph *graph) {
ClearKernelOutputAddress(graph);
ClearKernelWorkspaceAddress(graph);
......@@ -306,6 +347,9 @@ void GPUKernelRuntime::ClearKernelOutputAddress(const session::KernelGraph *grap
MS_EXCEPTION_IF_NULL(graph);
auto &kernels = graph->execution_order();
for (const auto &kernel : kernels) {
if (IsGraphOutput(graph, kernel)) {
continue;
}
auto kernel_mod = AnfAlgo::GetKernelMod(kernel);
MS_EXCEPTION_IF_NULL(kernel_mod);
auto output_sizes = kernel_mod->GetOutputSizeList();
......@@ -354,18 +398,27 @@ bool GPUKernelRuntime::LaunchKernelDynamic(const session::KernelGraph *graph, bo
AddressPtrList kernel_inputs;
AddressPtrList kernel_workspaces;
AddressPtrList kernel_outputs;
auto ret = AllocKernelDynamicRes(*kernel_mod, kernel, &kernel_inputs, &kernel_workspaces, &kernel_outputs);
auto ret = AllocKernelDynamicRes(*kernel_mod, kernel, &kernel_inputs, &kernel_workspaces, &kernel_outputs, mock);
if (!ret) {
return false;
}
if (!kernel_mod->Launch(kernel_inputs, kernel_workspaces, kernel_outputs, stream_)) {
MS_LOG(EXCEPTION) << "Launch kernel failed.";
if (!mock) {
if (!profiling) {
CHECK_OP_RET_WITH_EXCEPT(kernel_mod->Launch(kernel_inputs, kernel_workspaces, kernel_outputs, stream_),
"Launch kernel failed.");
} else {
LaunchKernelWithTimeProfiling(kernel, kernel_inputs, kernel_workspaces, kernel_outputs);
}
}
FreeKernelDynamicRes(kernel);
UpdateMemorySwapTask(kernel);
if (!UpdateMemorySwapTask(kernel, mock, profiling)) {
return false;
}
}
if (!mock) {
CHECK_OP_RET_WITH_EXCEPT(SyncStream(), "SyncStream failed.");
}
CHECK_OP_RET_WITH_EXCEPT(SyncStream(), "SyncStream failed.");
ClearSwapQueue();
ClearSwapInfo(mock);
return true;
}
......@@ -393,29 +446,38 @@ void GPUKernelRuntime::LaunchKernelWithTimeProfiling(const AnfNodePtr &kernel, c
CHECK_OP_RET_WITH_EXCEPT(CudaDriver::DestroyEvent(end), "Failed to destroy event.");
}
bool GPUKernelRuntime::AddMemorySwapTask(const AnfNodePtr &kernel) {
bool GPUKernelRuntime::AddMemorySwapTask(const AnfNodePtr &kernel, bool mock, bool profiling) {
MS_EXCEPTION_IF_NULL(mem_swap_manager_);
const MemSwapInfoSet &mem_swap_info_set = mem_swap_manager_->QueryKernelMemSwapInfo(kernel);
for (auto &mem_swap_info : mem_swap_info_set) {
auto need_swap_kernel = mem_swap_manager_->QueryKerneByTopoOrder(mem_swap_info.topo_order_);
auto need_swap_kernel = mem_swap_manager_->QueryKernelByTopoOrder(mem_swap_info.topo_order_);
MS_EXCEPTION_IF_NULL(need_swap_kernel);
const HostAddress &host_address =
mem_swap_manager_->QueryKernelHostAddr(need_swap_kernel, mem_swap_info.output_idx_);
auto device_address = AnfAlgo::GetMutableOutputAddr(need_swap_kernel, mem_swap_info.output_idx_, false);
if (mem_swap_info.swap_kind_ == SwapKind::kDeviceToHost) {
mem_swap_manager_->AddMemSwapTask(SwapKind::kDeviceToHost, device_address, host_address);
if (mem_swap_manager_->QueryKernelHostAddrIsDirty(need_swap_kernel, mem_swap_info.output_idx_)) {
mem_swap_manager_->AddMemSwapTask(SwapKind::kDeviceToHost, device_address, host_address, mock);
mem_swap_manager_->AddKernelHostAddrIsDirty(need_swap_kernel, mem_swap_info.output_idx_, false);
} else {
mem_manager_->FreeMemFromMemPool(device_address);
device_address->set_status(DeviceAddressStatus::kInHost);
}
} else if (mem_swap_info.swap_kind_ == SwapKind::kHostToDevice) {
auto status = device_address->status();
if (status == DeviceAddressStatus::kInDeviceToHost) {
mem_swap_manager_->InsertSwapInBlackList(device_address->ptr_);
device_address->set_status(DeviceAddressStatus::kInDevice);
} else if (status == DeviceAddressStatus::kInHost) {
if (!device_address->ptr_ && !AttemptMallocMem(device_address, device_address->size_)) {
if (!device_address->ptr_ && !AttemptMallocMem(device_address, device_address->size_, mock)) {
return false;
}
if (!mem_swap_manager_->FindInSwapInBlackList(device_address->ptr_)) {
mem_swap_manager_->AddMemSwapTask(SwapKind::kHostToDevice, device_address, host_address);
float cost_time = 0;
mem_swap_manager_->AddMemSwapTask(SwapKind::kHostToDevice, device_address, host_address, mock, profiling,
&cost_time);
if (profiling) {
mem_swap_manager_->AddKernelSwapPerform(need_swap_kernel, mem_swap_info.output_idx_,
std::make_pair(0, cost_time));
}
}
}
......@@ -423,87 +485,81 @@ bool GPUKernelRuntime::AddMemorySwapTask(const AnfNodePtr &kernel) {
return true;
}
bool GPUKernelRuntime::UpdateMemorySwapInfo(const session::KernelGraph *graph) {
MS_EXCEPTION_IF_NULL(mem_swap_manager_);
ClearKernelOldOutputAndWorkspace(graph);
if (!mem_swap_manager_->mem_swap_init()) {
if (!mem_swap_manager_->Init(graph)) {
return false;
}
}
return mem_swap_manager_->RetreatSwapInfo();
}
bool GPUKernelRuntime::UpdateMemorySwapTask(const AnfNodePtr &kernel) {
bool GPUKernelRuntime::UpdateMemorySwapTask(const AnfNodePtr &kernel, bool mock, bool profiling) {
MS_EXCEPTION_IF_NULL(mem_swap_manager_);
if (!mem_swap_manager_->trigger_swap()) {
return true;
}
if (mem_swap_manager_->QueryKernelTriggerSwap(kernel)) {
CHECK_OP_RET_WITH_EXCEPT(SyncStream(), "SyncStream failed.");
if (!AddMemorySwapTask(kernel)) {
if (!mock) {
CHECK_OP_RET_WITH_EXCEPT(SyncStream(), "SyncStream failed.");
}
if (!AddMemorySwapTask(kernel, mock, profiling)) {
return false;
}
if (!mock) {
CHECK_OP_RET_WITH_EXCEPT(mem_swap_manager_->SyncMemCopyStream(SwapKind::kDeviceToHost), "SyncCopyStream failed.");
}
}
CHECK_OP_RET_WITH_EXCEPT(mem_swap_manager_->SyncMemCopyStream(SwapKind::kDeviceToHost), "SyncCopyStream failed.");
return true;
}
void GPUKernelRuntime::UpdateHostSwapQueue(const DeviceAddressPtr device_address) {
void GPUKernelRuntime::UpdateHostSwapInQueue(const DeviceAddressPtr device_address, bool mock) {
MS_EXCEPTION_IF_NULL(mem_swap_manager_);
if (!mem_swap_manager_->trigger_swap()) {
return;
}
while (auto device_address_swap_in = mem_swap_manager_->UpdateSwapQueue(SwapKind::kHostToDevice)) {
while (auto device_address_swap_in = mem_swap_manager_->UpdateSwapQueue(SwapKind::kHostToDevice, mock)) {
device_address_swap_in->set_status(DeviceAddressStatus::kInDevice);
}
auto status = device_address->status();
switch (status) {
case DeviceAddressStatus::kInDevice:
break;
case DeviceAddressStatus::kInDeviceToHost: {
mem_swap_manager_->InsertSwapInBlackList(device_address->ptr_);
device_address->set_status(DeviceAddressStatus::kInDevice);
break;
}
case DeviceAddressStatus::kInHostToDevice: {
while (device_address->status() != DeviceAddressStatus::kInDevice) {
while (auto device_address_swap_in = mem_swap_manager_->UpdateSwapQueue(SwapKind::kHostToDevice)) {
while (auto device_address_swap_in = mem_swap_manager_->UpdateSwapQueue(SwapKind::kHostToDevice, mock)) {
device_address_swap_in->set_status(DeviceAddressStatus::kInDevice);
}
}
break;
}
case DeviceAddressStatus::kInHost:
MS_LOG(ERROR) << "Invaild device address status:" << status;
MS_LOG(WARNING) << "Unexpected device address status: " << status;
break;
default:
MS_LOG(EXCEPTION) << "Invaild device address status:" << status;
MS_LOG(EXCEPTION) << "Invaild device address status: " << status;
}
}
void GPUKernelRuntime::UpdateDeviceSwapQueue() {
void GPUKernelRuntime::UpdateHostSwapOutQueue(bool mock) {
MS_EXCEPTION_IF_NULL(mem_swap_manager_);
if (!mem_swap_manager_->trigger_swap()) {
return;
}
while (auto device_address_swap_out = mem_swap_manager_->UpdateSwapQueue(SwapKind::kDeviceToHost)) {
if (!mem_swap_manager_->FindInSwapInBlackList(device_address_swap_out->ptr_) && device_address_swap_out->ptr_) {
while (auto device_address_swap_out = mem_swap_manager_->UpdateSwapQueue(SwapKind::kDeviceToHost, mock)) {
if (device_address_swap_out->status() == DeviceAddressStatus::kInDeviceToHost && device_address_swap_out->ptr_) {
device_address_swap_out->set_status(DeviceAddressStatus::kInHost);
mem_manager_->FreeMemFromMemPool(device_address_swap_out);
}
}
}
void GPUKernelRuntime::ClearSwapQueue() {
void GPUKernelRuntime::ClearSwapInfo(bool mock) {
MS_EXCEPTION_IF_NULL(mem_swap_manager_);
if (!mem_swap_manager_->trigger_swap()) {
return;
}
mem_swap_manager_->ClearSwapQueue();
mem_swap_manager_->ClearSwapQueue(mock);
mem_swap_manager_->ResetHostAddrIsDirty();
}
bool GPUKernelRuntime::AttemptMallocMem(const DeviceAddressPtr &device_address, size_t size) {
bool GPUKernelRuntime::AttemptMallocMem(const DeviceAddressPtr &device_address, size_t size, bool mock) {
MS_EXCEPTION_IF_NULL(mem_manager_);
MS_EXCEPTION_IF_NULL(mem_swap_manager_);
auto ret = mem_manager_->MallocMemFromMemPool(device_address, size);
......@@ -511,13 +567,11 @@ bool GPUKernelRuntime::AttemptMallocMem(const DeviceAddressPtr &device_address,
if (!mem_swap_manager_->trigger_swap()) {
return false;
}
mem_swap_manager_->SyncMemCopyStream(SwapKind::kDeviceToHost);
while (auto device_address_swap_out = mem_swap_manager_->UpdateSwapQueue(SwapKind::kDeviceToHost)) {
if (!mem_swap_manager_->FindInSwapInBlackList(device_address_swap_out->ptr_) && device_address_swap_out->ptr_) {
device_address_swap_out->set_status(DeviceAddressStatus::kInHost);
mem_manager_->FreeMemFromMemPool(device_address_swap_out);
}
if (!mock) {
mem_swap_manager_->SyncMemCopyStream(SwapKind::kDeviceToHost);
}
UpdateHostSwapOutQueue(mock);
ret = mem_manager_->MallocMemFromMemPool(device_address, size);
if (!ret) {
return false;
......@@ -528,20 +582,22 @@ bool GPUKernelRuntime::AttemptMallocMem(const DeviceAddressPtr &device_address,
bool GPUKernelRuntime::AllocKernelDynamicRes(const mindspore::kernel::KernelMod &kernel_mod,
const mindspore::AnfNodePtr &kernel, AddressPtrList *kernel_inputs,
AddressPtrList *kernel_workspaces, AddressPtrList *kernel_outputs) {
if (!AllocKernelInputDynamicRes(kernel, kernel_inputs)) {
AddressPtrList *kernel_workspaces, AddressPtrList *kernel_outputs,
bool mock) {
if (!AllocKernelInputDynamicRes(kernel, kernel_inputs, mock)) {
return false;
}
if (!AllocKernelOutputDynamicRes(kernel_mod, kernel, kernel_outputs)) {
if (!AllocKernelOutputDynamicRes(kernel_mod, kernel, kernel_outputs, mock)) {
return false;
}
if (!AllocKernelWorkspaceDynamicRes(kernel_mod, kernel, kernel_workspaces)) {
if (!AllocKernelWorkspaceDynamicRes(kernel_mod, kernel, kernel_workspaces, mock)) {
return false;
}
return true;
}
bool GPUKernelRuntime::AllocKernelInputDynamicRes(const mindspore::AnfNodePtr &kernel, AddressPtrList *kernel_inputs) {
bool GPUKernelRuntime::AllocKernelInputDynamicRes(const mindspore::AnfNodePtr &kernel, AddressPtrList *kernel_inputs,
bool mock) {
MS_EXCEPTION_IF_NULL(kernel);
MS_EXCEPTION_IF_NULL(kernel_inputs);
MS_EXCEPTION_IF_NULL(mem_reuse_util_);
......@@ -555,7 +611,7 @@ bool GPUKernelRuntime::AllocKernelInputDynamicRes(const mindspore::AnfNodePtr &k
device_address = AnfAlgo::GetPrevNodeMutableOutputAddr(kernel, i, true);
}
MS_EXCEPTION_IF_NULL(device_address);
UpdateHostSwapQueue(device_address);
UpdateHostSwapInQueue(device_address, mock);
MS_EXCEPTION_IF_NULL(device_address->ptr_);
kernel::AddressPtr input = std::make_shared<kernel::Address>();
MS_EXCEPTION_IF_NULL(input);
......@@ -567,16 +623,16 @@ bool GPUKernelRuntime::AllocKernelInputDynamicRes(const mindspore::AnfNodePtr &k
}
bool GPUKernelRuntime::AllocKernelOutputDynamicRes(const mindspore::kernel::KernelMod &kernel_mod,
const mindspore::AnfNodePtr &kernel,
AddressPtrList *kernel_outputs) {
const mindspore::AnfNodePtr &kernel, AddressPtrList *kernel_outputs,
bool mock) {
MS_EXCEPTION_IF_NULL(kernel);
MS_EXCEPTION_IF_NULL(kernel_outputs);
UpdateDeviceSwapQueue();
UpdateHostSwapOutQueue(mock);
auto output_sizes = kernel_mod.GetOutputSizeList();
for (size_t i = 0; i < output_sizes.size(); ++i) {
auto device_address = AnfAlgo::GetMutableOutputAddr(kernel, i, false);
MS_EXCEPTION_IF_NULL(device_address);
if (device_address->ptr_ == nullptr && !AttemptMallocMem(device_address, output_sizes[i])) {
if (device_address->ptr_ == nullptr && !AttemptMallocMem(device_address, output_sizes[i], mock)) {
return false;
}
kernel::AddressPtr output = std::make_shared<kernel::Address>();
......@@ -590,7 +646,7 @@ bool GPUKernelRuntime::AllocKernelOutputDynamicRes(const mindspore::kernel::Kern
bool GPUKernelRuntime::AllocKernelWorkspaceDynamicRes(const mindspore::kernel::KernelMod &kernel_mod,
const mindspore::AnfNodePtr &kernel,
AddressPtrList *kernel_workspaces) {
AddressPtrList *kernel_workspaces, bool mock) {
MS_EXCEPTION_IF_NULL(kernel);
MS_EXCEPTION_IF_NULL(kernel_workspaces);
auto workspace_sizes = kernel_mod.GetWorkspaceSizeList();
......@@ -600,7 +656,7 @@ bool GPUKernelRuntime::AllocKernelWorkspaceDynamicRes(const mindspore::kernel::K
continue;
}
auto device_address = AnfAlgo::GetMutableWorkspaceAddr(kernel, i);
if (device_address->ptr_ == nullptr && !AttemptMallocMem(device_address, workspace_sizes[i])) {
if (device_address->ptr_ == nullptr && !AttemptMallocMem(device_address, workspace_sizes[i], mock)) {
return false;
}
kernel::AddressPtr workspace = std::make_shared<kernel::Address>();
......
......@@ -20,6 +20,7 @@
#include <string>
#include <memory>
#include <vector>
#include <set>
#include <utility>
#include <unordered_map>
#include "runtime/device/kernel_runtime.h"
......@@ -55,23 +56,27 @@ class GPUKernelRuntime : public KernelRuntime {
void InitKernelOutputAddress(const session::KernelGraph *graph);
void InitKernelWorkspaceAddress(const session::KernelGraph *graph);
void InitMemorySwapInfo(const session::KernelGraph *graph);
void SaveGraphOutputNode(const session::KernelGraph *graph);
bool IsGraphOutput(const session::KernelGraph *graph, const mindspore::AnfNodePtr &kernel) const;
void ClearKernelOutputAddress(const session::KernelGraph *graph);
void ClearKernelWorkspaceAddress(const session::KernelGraph *graph);
void ClearKernelOldOutputAndWorkspace(const session::KernelGraph *graph);
bool RunOneStep(const session::KernelGraph *graph);
bool SearchMemSwapScheme(const session::KernelGraph *graph);
bool RefineMemSwapScheme(const session::KernelGraph *graph);
bool LaunchKernelDynamic(const session::KernelGraph *graph, bool mock = false, bool profiling = false);
void LaunchKernelWithTimeProfiling(const AnfNodePtr &kernel, const AddressPtrList &inputs,
const AddressPtrList &workspace, const AddressPtrList &outputs);
bool AttemptMallocMem(const DeviceAddressPtr &device_address, size_t size);
bool AttemptMallocMem(const DeviceAddressPtr &device_address, size_t size, bool mock);
bool AllocKernelDynamicRes(const mindspore::kernel::KernelMod &kernel_mod, const mindspore::AnfNodePtr &kernel,
AddressPtrList *kernel_inputs, AddressPtrList *kernel_workspaces,
AddressPtrList *kernel_outputs);
bool AllocKernelInputDynamicRes(const mindspore::AnfNodePtr &kernel, AddressPtrList *kernel_inputs);
AddressPtrList *kernel_outputs, bool mock);
bool AllocKernelInputDynamicRes(const mindspore::AnfNodePtr &kernel, AddressPtrList *kernel_inputs, bool mock);
bool AllocKernelOutputDynamicRes(const mindspore::kernel::KernelMod &kernel_mod, const mindspore::AnfNodePtr &kernel,
AddressPtrList *kernel_outputs);
AddressPtrList *kernel_outputs, bool mock);
bool AllocKernelWorkspaceDynamicRes(const mindspore::kernel::KernelMod &kernel_mod,
const mindspore::AnfNodePtr &kernel, AddressPtrList *kernel_workspaces);
const mindspore::AnfNodePtr &kernel, AddressPtrList *kernel_workspaces,
bool mock);
void AllocCommunicationOpDynamicRes(const session::KernelGraph *graph);
void AllocCommunicationOpInputDynamicRes(const mindspore::AnfNodePtr &kernel);
void AllocCommunicationOpOutputDynamicRes(const mindspore::AnfNodePtr &kernel);
......@@ -79,15 +84,16 @@ class GPUKernelRuntime : public KernelRuntime {
const DeviceAddressPtrList addr_list, size_t total_size,
std::vector<size_t> size_list);
void FreeKernelDynamicRes(const mindspore::AnfNodePtr &kernel);
bool AddMemorySwapTask(const AnfNodePtr &kernel);
bool UpdateMemorySwapInfo(const session::KernelGraph *graph);
bool UpdateMemorySwapTask(const AnfNodePtr &kernel);
void UpdateHostSwapQueue(const DeviceAddressPtr device_address);
void UpdateDeviceSwapQueue();
void ClearSwapQueue();
bool UpdateMemorySwapTask(const AnfNodePtr &kernel, bool mock, bool profiling);
bool AddMemorySwapTask(const AnfNodePtr &kernel, bool mock, bool profiling);
void UpdateHostSwapInQueue(const DeviceAddressPtr device_address, bool mock);
void UpdateHostSwapOutQueue(bool mock);
void ClearSwapInfo(bool mock);
std::unordered_map<uint32_t, MemReuseUtilPtr> mem_reuse_util_map_;
std::unordered_map<uint32_t, MemSwapManagerPtr> mem_swap_map_;
std::unordered_map<uint32_t, bool> is_first_step_map_;
std::unordered_map<uint32_t, std::set<AnfNodePtr>> graph_output_map_;
MemReuseUtilPtr mem_reuse_util_{nullptr};
MemSwapManagerPtr mem_swap_manager_{nullptr};
};
......
......@@ -47,11 +47,20 @@ void GPUMemCopyManager::AddMemSwapOutTask(const DeviceAddressPtr &device_address
swap_out_queue_.emplace(device_address, event);
}
void GPUMemCopyManager::AddMemSwapInTask(const DeviceAddressPtr &device_address, const HostAddress &host_addr) {
void GPUMemCopyManager::AddMemSwapInTask(const DeviceAddressPtr &device_address, const HostAddress &host_addr,
bool profiling, float *cost_time) {
MS_EXCEPTION_IF_NULL(device_address);
MS_EXCEPTION_IF_NULL(host_addr.addr);
DeviceEvent event = nullptr;
CHECK_OP_RET_WITH_EXCEPT(CudaDriver::CreateEvent(&event, cudaEventDisableTiming), "Failed to create CUDA event.");
DeviceEvent start = nullptr;
DeviceEvent end = nullptr;
if (profiling) {
CHECK_OP_RET_WITH_EXCEPT(CudaDriver::CreateEvent(&start), "Failed to create CUDA event.");
CHECK_OP_RET_WITH_EXCEPT(CudaDriver::CreateEvent(&end), "Failed to create CUDA event.");
CHECK_OP_RET_WITH_EXCEPT(CudaDriver::RecordEvent(start, swap_in_stream_),
"Failed to record CUDA event to swap in stream.");
} else {
CHECK_OP_RET_WITH_EXCEPT(CudaDriver::CreateEvent(&end, cudaEventDisableTiming), "Failed to create CUDA event.");
}
DeviceMemPtr device_ptr = const_cast<DeviceMemPtr>(device_address->GetPtr());
MS_EXCEPTION_IF_NULL(device_ptr);
device_address->set_status(DeviceAddressStatus::kInHostToDevice);
......@@ -59,9 +68,27 @@ void GPUMemCopyManager::AddMemSwapInTask(const DeviceAddressPtr &device_address,
CHECK_OP_RET_WITH_EXCEPT(
CudaDriver::CopyHostMemToDeviceAsync(device_ptr, host_addr.addr, host_addr.size, swap_in_stream_),
"Failed to copy host memory to device.");
CHECK_OP_RET_WITH_EXCEPT(CudaDriver::RecordEvent(event, swap_in_stream_),
CHECK_OP_RET_WITH_EXCEPT(CudaDriver::RecordEvent(end, swap_in_stream_),
"Failed to record CUDA event to swap in stream.");
swap_in_queue_.emplace(device_address, event);
if (profiling) {
CHECK_OP_RET_WITH_EXCEPT(CudaDriver::SyncEvent(start), "Failed to sync event.");
CHECK_OP_RET_WITH_EXCEPT(CudaDriver::SyncEvent(end), "Failed to sync event.");
CHECK_OP_RET_WITH_EXCEPT(CudaDriver::ElapsedTime(cost_time, start, end), "Failed to record elapsed time.");
CHECK_OP_RET_WITH_EXCEPT(CudaDriver::DestroyEvent(start), "Failed to destroy event.");
}
swap_in_queue_.emplace(device_address, end);
}
void GPUMemCopyManager::AddMemSwapOutTaskMock(const DeviceAddressPtr &device_address) {
MS_EXCEPTION_IF_NULL(device_address);
device_address->set_status(DeviceAddressStatus::kInDeviceToHost);
swap_out_queue_mock_.emplace(device_address);
}
void GPUMemCopyManager::AddMemSwapInTaskMock(const DeviceAddressPtr &device_address) {
MS_EXCEPTION_IF_NULL(device_address);
device_address->set_status(DeviceAddressStatus::kInHostToDevice);
swap_in_queue_mock_.emplace(device_address);
}
bool GPUMemCopyManager::SyncMemCopyStream(SwapKind swap_kind) {
......@@ -104,6 +131,24 @@ DeviceAddressPtr GPUMemCopyManager::UpdateSwapInQueue() {
return device_address;
}
DeviceAddressPtr GPUMemCopyManager::UpdateSwapOutQueueMock() {
if (swap_out_queue_mock_.empty()) {
return nullptr;
}
auto device_address = swap_out_queue_mock_.front();
swap_out_queue_mock_.pop();
return device_address;
}
DeviceAddressPtr GPUMemCopyManager::UpdateSwapInQueueMock() {
if (swap_in_queue_mock_.empty()) {
return nullptr;
}
auto device_address = swap_in_queue_mock_.front();
swap_in_queue_mock_.pop();
return device_address;
}
bool GPUMemCopyManager::AllocHostPinnedMem(size_t size, void **addr) const {
auto alloc_size = CudaDriver::AllocHostPinnedMem(size, addr);
return alloc_size == size;
......@@ -126,6 +171,15 @@ void GPUMemCopyManager::ClearSwapQueue() {
swap_in_queue_.pop();
}
}
void GPUMemCopyManager::ClearSwapQueueMock() {
while (!swap_out_queue_mock_.empty()) {
swap_out_queue_mock_.pop();
}
while (!swap_in_queue_mock_.empty()) {
swap_in_queue_mock_.pop();
}
}
} // namespace gpu
} // namespace device
} // namespace mindspore
......@@ -40,7 +40,12 @@ class GPUMemCopyManager : public MemCopyManager {
void AddMemSwapOutTask(const DeviceAddressPtr &device_address, const HostAddress &host_addr) override;
void AddMemSwapInTask(const DeviceAddressPtr &device_address, const HostAddress &host_addr) override;
void AddMemSwapInTask(const DeviceAddressPtr &device_address, const HostAddress &host_addr, bool profiling,
float *cost_time) override;
void AddMemSwapOutTaskMock(const DeviceAddressPtr &device_address) override;
void AddMemSwapInTaskMock(const DeviceAddressPtr &device_address) override;
bool SyncMemCopyStream(SwapKind swap_kind) override;
......@@ -48,17 +53,25 @@ class GPUMemCopyManager : public MemCopyManager {
DeviceAddressPtr UpdateSwapInQueue() override;
DeviceAddressPtr UpdateSwapOutQueueMock() override;
DeviceAddressPtr UpdateSwapInQueueMock() override;
bool AllocHostPinnedMem(size_t size, void **addr) const override;
void FreeHostPinnedMem(void *addr) const override;
void ClearSwapQueue() override;
void ClearSwapQueueMock() override;
private:
DeviceStream swap_out_stream_{nullptr};
DeviceStream swap_in_stream_{nullptr};
std::queue<std::pair<DeviceAddressPtr, DeviceEvent>> swap_out_queue_;
std::queue<std::pair<DeviceAddressPtr, DeviceEvent>> swap_in_queue_;
std::queue<DeviceAddressPtr> swap_out_queue_mock_;
std::queue<DeviceAddressPtr> swap_in_queue_mock_;
};
using GPUMemCopyManagerPtr = std::shared_ptr<GPUMemCopyManager>;
} // namespace gpu
......
......@@ -355,7 +355,7 @@ def test_trainTensor(num_classes=10, epoch=8, batch_size=1):
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_trainTensor_big_batchSize(num_classes=10, epoch=8, batch_size=170):
def test_trainTensor_big_batchSize(num_classes=10, epoch=8, batch_size=338):
net = resnet50(num_classes)
lr = 0.1
momentum = 0.9
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册