Commit e7936ded authored by mindspore-ci-bot, committed by Gitee

!1480 gpu iterator weak ref support

Merge pull request !1480 from panfengfeng/iterator_gpu_weak_ref
@@ -26,10 +26,6 @@
 #include "dataset/util/task_manager.h"
 #include "dataset/engine/opt/pass.h"
-#ifdef ENABLE_TDTQUE
-#include "tdt/tsd_client.h"
-#endif
 namespace mindspore {
 namespace dataset {
 DeviceQueueOp::DeviceQueueOp(std::string channel_name, DeviceType device_type, int32_t device_id, int32_t prefetch_size,
@@ -167,9 +163,15 @@ Status DeviceQueueOp::SendDataToGPU() {
           is_break_loop = true;
         }
       }
-      RETURN_IF_NOT_OK(GetNextInput(&current_buffer));
+      if (!TaskManager::FindMe()->Interrupted())
+        RETURN_IF_NOT_OK(GetNextInput(&current_buffer));
+      else
+        is_break_loop = true;
     }
-    RETURN_IF_NOT_OK(GetNextInput(&current_buffer));
+    if (!TaskManager::FindMe()->Interrupted())
+      RETURN_IF_NOT_OK(GetNextInput(&current_buffer));
+    else
+      is_break_loop = true;
   }

   MS_LOG(INFO) << "Device queue total batch is " << total_batch << ", number of batches is " << num_batch_ << ".";
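
These two hunks make the GPU send loop interrupt-aware: before each potentially blocking GetNextInput() call, the worker first asks its TaskManager whether the task has been interrupted, and exits via is_break_loop instead of blocking forever during shutdown. A minimal sketch of the same pattern in Python threading, with illustrative names (this is not MindSpore API):

    import threading

    interrupted = threading.Event()  # stands in for TaskManager::FindMe()->Interrupted()

    def send_loop(get_next_buffer, push_to_device):
        """Fetch-and-push loop that re-checks the interrupt flag before every
        blocking fetch, so shutdown breaks the loop instead of hanging."""
        is_break_loop = False
        while not is_break_loop:
            if interrupted.is_set():    # mirrors the added Interrupted() check
                is_break_loop = True    # mirrors `is_break_loop = true;`
                continue
            buffer = get_next_buffer()  # may block; only reached when not interrupted
            if buffer is None:          # end of stream
                break
            push_to_device(buffer)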
@@ -191,7 +193,7 @@ Status DeviceQueueOp::RetryPushGPUData(const std::vector<size_t> &data_size, con
     items.push_back(data_item);
   }
-  while (!GpuBufferMgr::GetInstance().IsClosed()) {
+  while (!GpuBufferMgr::GetInstance().IsClosed() && !TaskManager::FindMe()->Interrupted()) {
     RETURN_IF_NOT_OK(MallocForGPUData(&items, curr_row));
     auto ret = GpuBufferMgr::GetInstance().Push(handle, items, WAIT_TIME);
     if (ret) {
...
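
Similarly, the widened while condition means a push retry also gives up when the worker task itself has been interrupted, not only when the GPU buffer manager is closed. A rough Python equivalent of that retry loop (the queue and event names are hypothetical):

    import queue
    import threading

    closed = threading.Event()       # set by the close-notify side
    interrupted = threading.Event()  # set when this worker task is cancelled

    def retry_push(q, item, wait_time=0.1):
        """Keep retrying the push until it succeeds, the buffer is closed,
        or the task is interrupted -- mirrors the widened while condition."""
        while not closed.is_set() and not interrupted.is_set():
            try:
                q.put(item, timeout=wait_time)
                return True
            except queue.Full:
                continue  # push timed out; re-check exit conditions, retry
        return False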
@@ -172,9 +172,7 @@ bool GpuBufferMgr::CloseNotify() {
   {
     std::lock_guard<std::mutex> lk(close_mutex_);
     // set closed_ to be true, all the dataset retry can be jumped out of the while
     closed_ = true;
-    // notify all the waiting dataset threads
-    close_confirm_cond_.notify_all();
   }
   // wait for the dataset threads' ack
@@ -188,16 +186,6 @@ bool GpuBufferMgr::CloseNotify() {
   return result;
 }

-void GpuBufferMgr::CloseConfirm() {
-  // lock scope
-  {
-    std::unique_lock<std::mutex> lk(close_mutex_);
-    // dataset threads wait for the closed_ flag from false to true
-    close_confirm_cond_.wait(lk, [this] { return closed_; });
-  }
-  sema.Signal();
-}
+void GpuBufferMgr::CloseConfirm() { sema.Signal(); }

 }  // namespace device
 }  // namespace mindspore
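
With the condition variable gone, the shutdown handshake becomes one-directional: CloseNotify() sets closed_ under the lock (so the retry loops above fall out of their while conditions) and then waits on the semaphore for one acknowledgement per dataset thread, while CloseConfirm() shrinks to a bare sema.Signal(). A small Python sketch of that handshake, assuming one ack per worker thread:

    import threading

    class CloseHandshake:
        """Sketch of the simplified CloseNotify/CloseConfirm protocol."""

        def __init__(self, num_workers):
            self.closed = False
            self.close_mutex = threading.Lock()
            self.sema = threading.Semaphore(0)  # counts worker acks
            self.num_workers = num_workers

        def close_notify(self):
            with self.close_mutex:
                self.closed = True              # lets retry loops exit their while
            for _ in range(self.num_workers):   # wait for every worker's ack
                self.sema.acquire()

        def close_confirm(self):                # called by each worker on exit
            self.sema.release()                 # CloseConfirm() { sema.Signal(); }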
@@ -119,7 +119,6 @@ class GpuBufferMgr {
   bool closed_;
   std::mutex mutex_;
   std::mutex close_mutex_;
-  std::condition_variable close_confirm_cond_;
   // how many queues opened by dataset
   int open_by_dataset_;
   Semaphore sema;
...
@@ -17,7 +17,6 @@
 from abc import abstractmethod
 import copy
 import weakref
-from importlib import import_module

 from mindspore._c_dataengine import DEPipeline
 from mindspore._c_dataengine import OpName
@@ -25,10 +24,6 @@ from mindspore._c_dataengine import OpName
 from mindspore import log as logger
 from . import datasets as de

-try:
-    context = import_module("mindspore.context")
-except ModuleNotFoundError:
-    context = None

 ITERATORS_LIST = list()
@@ -36,18 +31,9 @@ ITERATORS_LIST = list()

 def _cleanup():
     """Release all the Iterator."""
     for itr_ref in ITERATORS_LIST:
-        if context:
-            device_type = context.get_context("device_target")
-            if device_type == "GPU":
-                itr_ref.release()
-            else:
-                itr = itr_ref()
-                if itr is not None:
-                    itr.release()
-        else:
-            itr = itr_ref()
-            if itr is not None:
-                itr.release()
+        itr = itr_ref()
+        if itr is not None:
+            itr.release()


 def alter_tree(node):
@@ -101,14 +87,7 @@ class Iterator:
     """

     def __init__(self, dataset):
-        if context:
-            device_type = context.get_context("device_target")
-            if device_type == "GPU":
-                ITERATORS_LIST.append(self)
-            else:
-                ITERATORS_LIST.append(weakref.ref(self))
-        else:
-            ITERATORS_LIST.append(weakref.ref(self))
+        ITERATORS_LIST.append(weakref.ref(self))
         # create a copy of tree and work on it.
         self.dataset = copy.deepcopy(dataset)
         self.dataset = alter_tree(self.dataset)
...
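
This is the core of the change: previously, on GPU targets the iterator object itself was appended to ITERATORS_LIST, so the module-level list kept every iterator alive, and both _cleanup() and __init__() had to probe mindspore.context for the device target. Storing a weakref.ref unconditionally removes the special case, and _cleanup() simply skips iterators that have already been garbage-collected. The pattern in isolation (class and function names here are illustrative, not MindSpore API):

    import weakref

    _REGISTRY = []  # plays the role of ITERATORS_LIST

    class DemoIterator:
        """Hypothetical iterator; only the registration pattern matters."""
        def __init__(self):
            # A weak reference never keeps the iterator alive on its own.
            _REGISTRY.append(weakref.ref(self))

        def release(self):
            print("releasing pipeline resources")

    def cleanup():
        """Shape of the new _cleanup(): dereference, release survivors."""
        for ref in _REGISTRY:
            itr = ref()          # None once the iterator has been collected
            if itr is not None:
                itr.release()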