提交 e7936ded 编写于 作者: M mindspore-ci-bot 提交者: Gitee

!1480 gpu iterator weak ref support

Merge pull request !1480 from panfengfeng/iterator_gpu_weak_ref
......@@ -26,10 +26,6 @@
#include "dataset/util/task_manager.h"
#include "dataset/engine/opt/pass.h"
#ifdef ENABLE_TDTQUE
#include "tdt/tsd_client.h"
#endif
namespace mindspore {
namespace dataset {
DeviceQueueOp::DeviceQueueOp(std::string channel_name, DeviceType device_type, int32_t device_id, int32_t prefetch_size,
......@@ -167,9 +163,15 @@ Status DeviceQueueOp::SendDataToGPU() {
is_break_loop = true;
}
}
RETURN_IF_NOT_OK(GetNextInput(&current_buffer));
if (!TaskManager::FindMe()->Interrupted())
RETURN_IF_NOT_OK(GetNextInput(&current_buffer));
else
is_break_loop = true;
}
RETURN_IF_NOT_OK(GetNextInput(&current_buffer));
if (!TaskManager::FindMe()->Interrupted())
RETURN_IF_NOT_OK(GetNextInput(&current_buffer));
else
is_break_loop = true;
}
MS_LOG(INFO) << "Device queue total batch is " << total_batch << ", number of batches is " << num_batch_ << ".";
......@@ -191,7 +193,7 @@ Status DeviceQueueOp::RetryPushGPUData(const std::vector<size_t> &data_size, con
items.push_back(data_item);
}
while (!GpuBufferMgr::GetInstance().IsClosed()) {
while (!GpuBufferMgr::GetInstance().IsClosed() && !TaskManager::FindMe()->Interrupted()) {
RETURN_IF_NOT_OK(MallocForGPUData(&items, curr_row));
auto ret = GpuBufferMgr::GetInstance().Push(handle, items, WAIT_TIME);
if (ret) {
......
......@@ -172,9 +172,7 @@ bool GpuBufferMgr::CloseNotify() {
{
std::lock_guard<std::mutex> lk(close_mutex_);
// set closed_ to be true, all the dataset retry can be jumped out of the while
closed_ = true; // set closed_ to be true, all the dataset retry can be jumped out of the while
// notify all the waiting dataset threads
close_confirm_cond_.notify_all(); // notify all the waiting dataset threads
closed_ = true;
}
// wait for the dataset threads' ack
......@@ -188,16 +186,6 @@ bool GpuBufferMgr::CloseNotify() {
return result;
}
// Blocks the caller until closed_ is true, then signals sema once.
// Per the inline comment below, the callers are the dataset threads.
// NOTE(review): the Signal() is presumably the ack that CloseNotify waits for — confirm.
void GpuBufferMgr::CloseConfirm() {
// lock scope: lk is released at the closing brace, before sema.Signal() runs
{
std::unique_lock<std::mutex> lk(close_mutex_);
// dataset threads wait for the closed_ flag from false to true
close_confirm_cond_.wait(
lk, [this] { return closed_; });  // predicate overload: returns only once closed_ is true
}
sema.Signal();
}
// Signals sema once and returns immediately; this version does not wait on closed_.
// NOTE(review): presumably the ack that CloseNotify waits for — confirm.
void GpuBufferMgr::CloseConfirm() { sema.Signal(); }
} // namespace device
} // namespace mindspore
......@@ -119,7 +119,6 @@ class GpuBufferMgr {
bool closed_;
std::mutex mutex_;
std::mutex close_mutex_;
std::condition_variable close_confirm_cond_;
// how many queues opened by dataset
int open_by_dataset_;
Semaphore sema;
......
......@@ -17,7 +17,6 @@
from abc import abstractmethod
import copy
import weakref
from importlib import import_module
from mindspore._c_dataengine import DEPipeline
from mindspore._c_dataengine import OpName
......@@ -25,10 +24,6 @@ from mindspore._c_dataengine import OpName
from mindspore import log as logger
from . import datasets as de
try:
context = import_module("mindspore.context")
except ModuleNotFoundError:
context = None
ITERATORS_LIST = list()
......@@ -36,18 +31,9 @@ ITERATORS_LIST = list()
def _cleanup():
"""Release all the Iterator."""
for itr_ref in ITERATORS_LIST:
if context:
device_type = context.get_context("device_target")
if device_type == "GPU":
itr_ref.release()
else:
itr = itr_ref()
if itr is not None:
itr.release()
else:
itr = itr_ref()
if itr is not None:
itr.release()
itr = itr_ref()
if itr is not None:
itr.release()
def alter_tree(node):
......@@ -101,14 +87,7 @@ class Iterator:
"""
def __init__(self, dataset):
if context:
device_type = context.get_context("device_target")
if device_type == "GPU":
ITERATORS_LIST.append(self)
else:
ITERATORS_LIST.append(weakref.ref(self))
else:
ITERATORS_LIST.append(weakref.ref(self))
ITERATORS_LIST.append(weakref.ref(self))
# create a copy of tree and work on it.
self.dataset = copy.deepcopy(dataset)
self.dataset = alter_tree(self.dataset)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册