提交 a0d14b18 编写于 作者: S sneaxiy

Turn on keep_order=True for test, test=develop

上级 a4951843
...@@ -135,7 +135,6 @@ class OrderedMultiDeviceLoDTensorBlockingQueue { ...@@ -135,7 +135,6 @@ class OrderedMultiDeviceLoDTensorBlockingQueue {
for (auto& item : queues_) { for (auto& item : queues_) {
item->Close(); item->Close();
} }
data_index_ = 0;
} }
inline void Kill() { inline void Kill() {
...@@ -157,6 +156,7 @@ class OrderedMultiDeviceLoDTensorBlockingQueue { ...@@ -157,6 +156,7 @@ class OrderedMultiDeviceLoDTensorBlockingQueue {
auto cap = (capacity_ + dev_cnt - 1) / dev_cnt; auto cap = (capacity_ + dev_cnt - 1) / dev_cnt;
item.reset(new LoDTensorBlockingQueue(cap, speed_test_mode_)); item.reset(new LoDTensorBlockingQueue(cap, speed_test_mode_));
} }
data_index_ = 0;
} }
inline void SetResetMethod(size_t idx, inline void SetResetMethod(size_t idx,
...@@ -171,8 +171,7 @@ class OrderedMultiDeviceLoDTensorBlockingQueue { ...@@ -171,8 +171,7 @@ class OrderedMultiDeviceLoDTensorBlockingQueue {
private: private:
const std::shared_ptr<LoDTensorBlockingQueue>& CurQueue() { const std::shared_ptr<LoDTensorBlockingQueue>& CurQueue() {
EnforceIsInited(); return queues_[(data_index_++) % queues_.size()];
return queues_[data_index_.fetch_add(1) % queues_.size()];
} }
private: private:
...@@ -183,7 +182,7 @@ class OrderedMultiDeviceLoDTensorBlockingQueue { ...@@ -183,7 +182,7 @@ class OrderedMultiDeviceLoDTensorBlockingQueue {
private: private:
std::vector<std::shared_ptr<LoDTensorBlockingQueue>> queues_; std::vector<std::shared_ptr<LoDTensorBlockingQueue>> queues_;
mutable std::atomic<uint64_t> data_index_{0}; mutable uint64_t data_index_{0};
size_t dev_cnt_{0}; size_t dev_cnt_{0};
const size_t capacity_; const size_t capacity_;
......
...@@ -89,7 +89,7 @@ class DataLoader(object): ...@@ -89,7 +89,7 @@ class DataLoader(object):
return_list=False, return_list=False,
use_multiprocess=False, use_multiprocess=False,
drop_last=True, drop_last=True,
keep_order=False): keep_order=True):
""" """
Create a DataLoader object for loading data from Python generator. Create a DataLoader object for loading data from Python generator.
Data would be prefetched using Python thread and be pushed Data would be prefetched using Python thread and be pushed
...@@ -633,7 +633,7 @@ class GeneratorLoader(DataLoaderBase): ...@@ -633,7 +633,7 @@ class GeneratorLoader(DataLoaderBase):
iterable=True, iterable=True,
return_list=False, return_list=False,
drop_last=True, drop_last=True,
keep_order=False): keep_order=True):
self._tensor_reader = None self._tensor_reader = None
self._places = None self._places = None
self._thread = None self._thread = None
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册