提交 f82e63fe 编写于 作者: M mindspore-ci-bot 提交者: Gitee

!671 Added testcase for sync_wait

Merge pull request !671 from EricZ/master
......@@ -65,8 +65,8 @@ Status BarrierOp::operator()() {
TaskManager::FindMe()->Post();
// create child iterator, right now this barrier is a pipeline operator
int32_t worker_id = 0;
int32_t child_idx = 0;
const int32_t worker_id = 0;
const int32_t child_idx = 0;
child_iterator_ = std::make_unique<ChildIterator>(this, worker_id, child_idx);
// Loop until eof is true
......
......@@ -922,7 +922,7 @@ class Dataset:
def sync_update(self, condition_name, num_batch=None, data=None):
"""
condition_name (str): The condition name that is used to toggle sending next row
step_size (int or None): The number of steps(rows) that are released
num_batch (int or None): The number of batches(rows) that are released
when pass_rows is None, will update the same number as sync_wait specified
data (dict or None): The data passed to the callback
"""
......
......@@ -107,6 +107,7 @@ def test_two_sync():
if count % 2 == 0:
dataset.sync_update(condition_name="every 2 batches")
def test_sync_epoch():
"""
Test sync wait with epochs: test sync with epochs in dataset pipeline
......@@ -130,6 +131,34 @@ def test_sync_epoch():
dataset.sync_update(condition_name="policy", data=data)
def test_multiple_iterators():
"""
Test sync wait with multiple iterators: will start multiple
"""
logger.info("test_sync_epoch")
batch_size = 30
dataset = ds.GeneratorDataset(gen, column_names=["input"])
aug = Augment(0)
dataset = dataset.sync_wait(condition_name="policy", callback=aug.update)
dataset = dataset.map(input_columns=["input"], operations=[aug.preprocess])
dataset = dataset.batch(batch_size, drop_remainder=True)
# 2nd dataset
dataset2 = ds.GeneratorDataset(gen, column_names=["input"])
aug = Augment(0)
dataset2 = dataset2.sync_wait(condition_name="policy", callback=aug.update)
dataset2 = dataset2.map(input_columns=["input"], operations=[aug.preprocess])
dataset2 = dataset2.batch(batch_size, drop_remainder=True)
for item1, item2 in zip(dataset.create_dict_iterator(), dataset2.create_dict_iterator()):
assert (item1["input"][0] == item2["input"][0])
data1 = {"loss": item1["input"][0]}
data2 = {"loss": item2["input"][0]}
dataset.sync_update(condition_name="policy", data=data1)
dataset2.sync_update(condition_name="policy", data=data2)
def test_sync_exception_01():
"""
Test sync: with shuffle in sync mode
......@@ -179,4 +208,5 @@ if __name__ == "__main__":
test_two_sync()
test_sync_exception_01()
test_sync_exception_02()
test_sync_epoch()
\ No newline at end of file
test_sync_epoch()
test_multiple_iterators()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册