diff --git a/mindspore/ccsrc/parallel/auto_parallel/graph_costmodel.h b/mindspore/ccsrc/parallel/auto_parallel/graph_costmodel.h index 4efb77055f484a2940993570ea399b4b9052c244..fde951454003e0661ff1bfbfd783cfbce204c620 100644 --- a/mindspore/ccsrc/parallel/auto_parallel/graph_costmodel.h +++ b/mindspore/ccsrc/parallel/auto_parallel/graph_costmodel.h @@ -34,8 +34,8 @@ namespace parallel { #define OPERATOR_TO_OPERATOR_CONNECTOR "-" #define DEFAULT_DEVICE_MEMORY_CAPACITY (1024.0 * 1024.0 * 1024.0 * 16.0) #define DEFAULT_COST_MODEL_ALPHA 1.0 -#define DEFAULT_COST_MODEL_BETA 65.0 -#define DEFAULT_COST_MODEL_GAMMA 0.02 +#define DEFAULT_COST_MODEL_BETA 260.0 +#define DEFAULT_COST_MODEL_GAMMA 0.001 #define DEFAULT_COST_MODEL_SIMPLIFY_CALCULATION true #define DEFAULT_COST_MODEL_COMMUNI_THRESHOLD 2048.0 #define DEFAULT_COST_MODEL_COMMUNI_CONST 3072.0 diff --git a/mindspore/ccsrc/parallel/step_parallel.cc b/mindspore/ccsrc/parallel/step_parallel.cc index 927acea7055e349a744d77ee1d643e0780059868..21da11ec219d678bb7c2a92a95acd370fa63013c 100644 --- a/mindspore/ccsrc/parallel/step_parallel.cc +++ b/mindspore/ccsrc/parallel/step_parallel.cc @@ -375,6 +375,10 @@ bool IsParallelCareNode(const CNodePtr& cnode) { MS_LOG(INFO) << "Parallel don't care node: " << prim->name(); return false; } + // get_next is not in the forward graph, so we need to mark it as a forward node + if (prim->name() == GET_NEXT) { + return true; + } if ((prim->name() == CAST)) { if ((!attrs.count(STRATEGY)) && (cnode->operator_info() == nullptr)) { return false; diff --git a/mindspore/train/dataset_helper.py b/mindspore/train/dataset_helper.py index de26d72108f760ac74c77b67783d6bd08e5c91dc..6252116efeeb4c1561663861ea3220019a7e55fe 100644 --- a/mindspore/train/dataset_helper.py +++ b/mindspore/train/dataset_helper.py @@ -88,7 +88,7 @@ class _DatasetIter: # times the batch dimension of tensors for run if _get_parallel_mode() in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL): device_num = 
_get_device_num() - dataset_shapes = _to_full_shapes(dataset_shapes, device_num) + self.dataset_shapes = _to_full_shapes(dataset_shapes, device_num) def __iter__(self): self.ind = 0 diff --git a/tests/ut/python/parallel/test_auto_parallel_parameter_cast.py b/tests/ut/python/parallel/test_auto_parallel_parameter_cast.py index 67b8f98fafae2e9e31cdf5acd09a7555abb621ec..b7a3255f7c5eed275dde79f16a67c5c92e7470f5 100644 --- a/tests/ut/python/parallel/test_auto_parallel_parameter_cast.py +++ b/tests/ut/python/parallel/test_auto_parallel_parameter_cast.py @@ -80,9 +80,9 @@ def test_common_parameter(): _executor.compile(net, x, y, z, w, phase='train') strategies = _executor._get_strategy(net) - expected_strategies = {'Default/network-Net/MatMul-op8': [[1, 1], [1, 8]], - 'Default/network-Net/MatMul-op9': [[1, 1], [1, 8]], - 'Default/network-Net/Cast-op10': [[1, 8]], - 'Default/network-Net/MatMul-op0': [[1, 1], [1, 8]], - 'Default/network-Net/Cast-op11': [[1, 8]]} - assert strategies == expected_strategies \ No newline at end of file + expected_strategies = {'Default/network-Net/MatMul-op6': [[8, 1], [1, 1]], + 'Default/network-Net/MatMul-op8': [[8, 1], [1, 1]], + 'Default/network-Net/Cast-op7': [[1, 1]], + 'Default/network-Net/MatMul-op0': [[8, 1], [1, 1]], + 'Default/network-Net/Cast-op9': [[1, 1]]} + assert strategies == expected_strategies diff --git a/tests/ut/python/parallel/test_auto_parallel_two_matmul.py b/tests/ut/python/parallel/test_auto_parallel_two_matmul.py index 5155db41f64b6134b4d1fedc2ec876945eb58732..e7beed384e0bd9d0dff1e12a164d18341be031b8 100644 --- a/tests/ut/python/parallel/test_auto_parallel_two_matmul.py +++ b/tests/ut/python/parallel/test_auto_parallel_two_matmul.py @@ -86,9 +86,9 @@ def test_two_matmul(): costmodel_alpha = cost_model_context.get_cost_model_context("costmodel_alpha") assert costmodel_alpha == 1.0 costmodel_beta = cost_model_context.get_cost_model_context("costmodel_beta") - assert costmodel_beta == 65.0 + assert costmodel_beta == 
260.0 costmodel_gamma = cost_model_context.get_cost_model_context("costmodel_gamma") - assert costmodel_gamma == 0.02 + assert costmodel_gamma == 0.001 costmodel_communi_threshold = cost_model_context.get_cost_model_context("costmodel_communi_threshold") assert costmodel_communi_threshold == 2048.0 costmodel_communi_const = cost_model_context.get_cost_model_context("costmodel_communi_const") @@ -137,4 +137,5 @@ def test_two_matmul(): strategies = _executor._get_strategy(net) expected_strategies = {'Default/network-Net/MatMul-op2': [[16, 1], [1, 1]], 'Default/network-Net/MatMul-op3': [[16, 1], [1, 1]]} - assert strategies == expected_strategies \ No newline at end of file + assert strategies == expected_strategies +