提交 a0d14b18 编写于 作者: S sneaxiy

Turn on keep_order=True for test, test=develop

上级 a4951843
......@@ -135,7 +135,6 @@ class OrderedMultiDeviceLoDTensorBlockingQueue {
for (auto& item : queues_) {
item->Close();
}
data_index_ = 0;
}
inline void Kill() {
......@@ -157,6 +156,7 @@ class OrderedMultiDeviceLoDTensorBlockingQueue {
auto cap = (capacity_ + dev_cnt - 1) / dev_cnt;
item.reset(new LoDTensorBlockingQueue(cap, speed_test_mode_));
}
data_index_ = 0;
}
inline void SetResetMethod(size_t idx,
......@@ -171,8 +171,7 @@ class OrderedMultiDeviceLoDTensorBlockingQueue {
private:
const std::shared_ptr<LoDTensorBlockingQueue>& CurQueue() {
EnforceIsInited();
return queues_[data_index_.fetch_add(1) % queues_.size()];
return queues_[(data_index_++) % queues_.size()];
}
private:
......@@ -183,7 +182,7 @@ class OrderedMultiDeviceLoDTensorBlockingQueue {
private:
std::vector<std::shared_ptr<LoDTensorBlockingQueue>> queues_;
mutable std::atomic<uint64_t> data_index_{0};
mutable uint64_t data_index_{0};
size_t dev_cnt_{0};
const size_t capacity_;
......
......@@ -89,7 +89,7 @@ class DataLoader(object):
return_list=False,
use_multiprocess=False,
drop_last=True,
keep_order=False):
keep_order=True):
"""
Create a DataLoader object for loading data from Python generator.
Data would be prefetched using Python thread and be pushed
......@@ -633,7 +633,7 @@ class GeneratorLoader(DataLoaderBase):
iterable=True,
return_list=False,
drop_last=True,
keep_order=False):
keep_order=True):
self._tensor_reader = None
self._places = None
self._thread = None
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册