From fb6eed23aeefbec5a8885498fb626e2d972158ca Mon Sep 17 00:00:00 2001 From: Xiaoda Zhang Date: Wed, 1 Apr 2020 14:16:52 +0800 Subject: [PATCH] refining strategy-checking for resnet50 --- .../ut/python/parallel/test_auto_parallel_resnet.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/tests/ut/python/parallel/test_auto_parallel_resnet.py b/tests/ut/python/parallel/test_auto_parallel_resnet.py index 667e3873a..1e0e3570b 100644 --- a/tests/ut/python/parallel/test_auto_parallel_resnet.py +++ b/tests/ut/python/parallel/test_auto_parallel_resnet.py @@ -295,11 +295,11 @@ def test_train_32k_8p(epoch_size=3, batch_size=32, num_classes=32768): #1048576 model.train(5, dataset, dataset_sink_mode=False) strategies = _executor._get_strategy(model._train_network) for (k, v) in strategies.items(): - if re.match(k, 'Conv2D-op') is not None: + if re.search('Conv2D-op', k) is not None: assert v[0][0] == dev_num - elif re.match(k, 'MatMul-op') is not None: + elif re.search('MatMul-op', k) is not None: assert v == [[dev_num, 1], [1, 1]] - elif re.match(k, 'ReduceSum-op') is not None: + elif re.search('ReduceSum-op', k) is not None: assert v == [[dev_num, 1]] allreduce_fusion_dict = _executor._get_allreduce_fusion(model._train_network) @@ -490,9 +490,9 @@ def test_train_64k_8p(epoch_size=3, batch_size=32, num_classes=65536): #1048576 model.train(5, dataset, dataset_sink_mode=False) strategies = _executor._get_strategy(model._train_network) for (k, v) in strategies.items(): - if re.match(k, 'Conv2D-op') is not None: + if re.search('Conv2D-op', k ) is not None: assert v[0][0] == dev_num - elif re.match(k, 'MatMul-op') is not None: + elif re.search('MatMul-op', k) is not None: assert v == [[1, 1], [dev_num, 1]] - elif re.match(k, 'ReduceSum-op') is not None: + elif re.search('ReduceSum-op', k) is not None: assert v == [[1, dev_num]] -- GitLab