diff --git a/mindspore/parallel/__init__.py b/mindspore/parallel/__init__.py
index 79d8e67a8dcd2e4f5f1a9ea0c8d47bcd7651c8fc..170418fc929e85acd13b8495058afa8c743f0555 100644
--- a/mindspore/parallel/__init__.py
+++ b/mindspore/parallel/__init__.py
@@ -17,5 +17,7 @@ This interface is ONLY used in Auto-parallel procedure.
 """
 from .algo_parameter_config import get_algo_parameters, reset_algo_parameters, \
     set_algo_parameters
+from ._cost_model_context import set_multi_subgraphs, get_multi_subgraphs
 
-__all__ = ["get_algo_parameters", "reset_algo_parameters", "set_algo_parameters"]
+__all__ = ["set_multi_subgraphs", "get_multi_subgraphs",
+           "get_algo_parameters", "reset_algo_parameters", "set_algo_parameters"]
diff --git a/mindspore/parallel/_cost_model_context.py b/mindspore/parallel/_cost_model_context.py
index dda68e4f2d99442e1feb9e2101a6054e297f18c2..3b278caaffbc448c9e03ea6784b7614ec798f67f 100644
--- a/mindspore/parallel/_cost_model_context.py
+++ b/mindspore/parallel/_cost_model_context.py
@@ -479,7 +479,6 @@ set_cost_model_context_func_map = {
     "costmodel_communi_threshold": cost_model_context().set_costmodel_communi_threshold,
     "costmodel_communi_const": cost_model_context().set_costmodel_communi_const,
     "costmodel_communi_bias": cost_model_context().set_costmodel_communi_bias,
-    "multi_subgraphs": cost_model_context().set_multi_subgraphs,
     "run_phase": cost_model_context().set_run_phase,
     "costmodel_allreduce_fusion_algorithm": cost_model_context().set_costmodel_allreduce_fusion_algorithm,
     "costmodel_allreduce_fusion_times": cost_model_context().set_costmodel_allreduce_fusion_times,
@@ -501,7 +500,6 @@ get_cost_model_context_func_map = {
     "costmodel_communi_threshold": cost_model_context().get_costmodel_communi_threshold,
     "costmodel_communi_const": cost_model_context().get_costmodel_communi_const,
     "costmodel_communi_bias": cost_model_context().get_costmodel_communi_bias,
-    "multi_subgraphs": cost_model_context().get_multi_subgraphs,
     "run_phase": cost_model_context().get_run_phase,
     "costmodel_allreduce_fusion_algorithm": cost_model_context().get_costmodel_allreduce_fusion_algorithm,
     "costmodel_allreduce_fusion_times": cost_model_context().get_costmodel_allreduce_fusion_times,
@@ -538,7 +536,6 @@ def set_cost_model_context(**kwargs):
         costmodel_communi_threshold (float): A parameter used in adjusting communication calculation for practice.
         costmodel_communi_const (float): A parameter used in adjusting communication calculation for practice.
         costmodel_communi_bias (float): A parameter used in adjusting communication calculation for practice.
-        multi_subgraphs (bool): A parameter used in marking the flag of ANF graph containing multiple subgraphs.
         run_phase (int): A parameter indicating which phase is running: training (0) or inference (1). Default: 0.
         costmodel_allreduce_fusion_algorithm (int): The allreduce fusion algorithm.
             0: bypass allreduce fusion;
@@ -591,3 +588,18 @@ def get_cost_model_context(attr_key):
 def reset_cost_model_context():
     """Reset cost model context attributes."""
     cost_model_context().reset_cost_model()
+
+def set_multi_subgraphs(multi_subgraph=True):
+    """
+    Set the flag of ANF graph containing multiple subgraphs.
+
+    Args:
+        multi_subgraph (bool): A parameter used in marking the multi-subgraphs flag.
+    """
+    cost_model_context().set_multi_subgraphs(multi_subgraph)
+
+def get_multi_subgraphs():
+    """
+    Get the flag of ANF graph containing multiple subgraphs.
+ """ + cost_model_context().get_multi_subgraphs() diff --git a/model_zoo/official/recommend/wide_and_deep/src/wide_and_deep.py b/model_zoo/official/recommend/wide_and_deep/src/wide_and_deep.py index 8010a843e7cdcff19bbdf16af86840e48dfa64f0..8bef1821b3edefdbbd28966451dece2937cb841e 100644 --- a/model_zoo/official/recommend/wide_and_deep/src/wide_and_deep.py +++ b/model_zoo/official/recommend/wide_and_deep/src/wide_and_deep.py @@ -14,7 +14,7 @@ # ============================================================================ """wide and deep model""" import numpy as np -from mindspore import nn +from mindspore import nn, context from mindspore import Parameter, ParameterTuple import mindspore.common.dtype as mstype from mindspore.ops import functional as F @@ -22,10 +22,7 @@ from mindspore.ops import composite as C from mindspore.ops import operations as P from mindspore.nn import Dropout from mindspore.nn.optim import Adam, FTRL, LazyAdam -# from mindspore.nn.metrics import Metric from mindspore.common.initializer import Uniform, initializer -# from mindspore.train.callback import ModelCheckpoint, CheckpointConfig -from mindspore.parallel._utils import _get_device_num, _get_parallel_mode, _get_mirror_mean from mindspore.train.parallel_utils import ParallelMode from mindspore.nn.wrap.grad_reducer import DistributedGradReducer from mindspore.communication.management import get_group_size @@ -142,7 +139,7 @@ class WideDeepModel(nn.Cell): self.batch_size = config.batch_size host_device_mix = bool(config.host_device_mix) parameter_server = bool(config.parameter_server) - parallel_mode = _get_parallel_mode() + parallel_mode = context.get_auto_parallel_context("parallel_mode") is_auto_parallel = parallel_mode in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL) if is_auto_parallel: self.batch_size = self.batch_size * get_group_size() @@ -259,7 +256,7 @@ class NetWithLossClass(nn.Cell): super(NetWithLossClass, self).__init__(auto_prefix=False) host_device_mix = bool(config.host_device_mix) parameter_server = bool(config.parameter_server) - parallel_mode = _get_parallel_mode() + parallel_mode = context.get_auto_parallel_context("parallel_mode") is_auto_parallel = parallel_mode in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL) self.no_l2loss = (is_auto_parallel if host_device_mix else parameter_server) self.network = network @@ -312,7 +309,7 @@ class TrainStepWrap(nn.Cell): def __init__(self, network, sens=1024.0, host_device_mix=False, parameter_server=False): super(TrainStepWrap, self).__init__() - parallel_mode = _get_parallel_mode() + parallel_mode = context.get_auto_parallel_context("parallel_mode") is_auto_parallel = parallel_mode in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL) self.network = network self.network.set_train() @@ -351,12 +348,11 @@ class TrainStepWrap(nn.Cell): self.reducer_flag = False self.grad_reducer_w = None self.grad_reducer_d = None - parallel_mode = _get_parallel_mode() self.reducer_flag = parallel_mode in (ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL) if self.reducer_flag: - mean = _get_mirror_mean() - degree = _get_device_num() + mean = context.get_auto_parallel_context("mirror_mean") + degree = context.get_auto_parallel_context("device_num") self.grad_reducer_w = DistributedGradReducer(self.optimizer_w.parameters, mean, degree) self.grad_reducer_d = DistributedGradReducer(self.optimizer_d.parameters, mean, degree) diff --git a/model_zoo/official/recommend/wide_and_deep/train_and_eval_auto_parallel.py 
index a47b9e040e11337fc5aa73a1b79b3d78f6536b18..4af6366e42d761136d3b151bbc2d1208c1577bad 100644
--- a/model_zoo/official/recommend/wide_and_deep/train_and_eval_auto_parallel.py
+++ b/model_zoo/official/recommend/wide_and_deep/train_and_eval_auto_parallel.py
@@ -22,7 +22,7 @@ from mindspore import Model, context
 from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, TimeMonitor
 from mindspore.train import ParallelMode
 from mindspore.communication.management import get_rank, get_group_size, init
-from mindspore.parallel import _cost_model_context as cost_model_context
+from mindspore.parallel import set_multi_subgraphs
 from mindspore.nn.wrap.cell_wrapper import VirtualDatasetCellTriple
 
 from src.wide_and_deep import PredictWithSigmoid, TrainStepWrap, NetWithLossClass, WideDeepModel
@@ -127,7 +127,7 @@ if __name__ == "__main__":
     context.set_context(mode=context.GRAPH_MODE, device_target=wide_deep_config.device_target, save_graphs=True)
     context.set_context(variable_memory_max_size="24GB")
     context.set_context(enable_sparse=True)
-    cost_model_context.set_cost_model_context(multi_subgraphs=True)
+    set_multi_subgraphs()
     if wide_deep_config.device_target == "Ascend":
         init("hccl")
     elif wide_deep_config.device_target == "GPU":
diff --git a/model_zoo/official/recommend/wide_and_deep_multitable/src/wide_and_deep.py b/model_zoo/official/recommend/wide_and_deep_multitable/src/wide_and_deep.py
index 189faadbc56e731aae6e34c49a32b6312d1a9cf9..ba358dd723812c48bbfa7286f41100c492251b78 100644
--- a/model_zoo/official/recommend/wide_and_deep_multitable/src/wide_and_deep.py
+++ b/model_zoo/official/recommend/wide_and_deep_multitable/src/wide_and_deep.py
@@ -16,7 +16,7 @@
 import numpy as np
 
 import mindspore.common.dtype as mstype
-from mindspore import nn
+from mindspore import nn, context
 from mindspore import Tensor, Parameter, ParameterTuple
 from mindspore.ops import functional as F
 from mindspore.ops import composite as C
@@ -24,7 +24,6 @@ from mindspore.ops import operations as P
 from mindspore.nn import Dropout, Flatten
 from mindspore.nn.optim import Adam, FTRL
 from mindspore.common.initializer import Uniform, initializer
-from mindspore.parallel._utils import _get_device_num, _get_parallel_mode, _get_mirror_mean
 from mindspore.train.parallel_utils import ParallelMode
 from mindspore.nn.wrap.grad_reducer import DistributedGradReducer
 
@@ -552,13 +551,13 @@ class TrainStepWrap(nn.Cell):
         self.reducer_flag = False
         self.grad_reducer_w = None
         self.grad_reducer_d = None
-        parallel_mode = _get_parallel_mode()
+        parallel_mode = context.get_auto_parallel_context("parallel_mode")
         if parallel_mode in (ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL):
             self.reducer_flag = True
         if self.reducer_flag:
-            mean = _get_mirror_mean()
-            degree = _get_device_num()
+            mean = context.get_auto_parallel_context("mirror_mean")
+            degree = context.get_auto_parallel_context("device_num")
             self.grad_reducer_w = DistributedGradReducer(
                 self.optimizer_w.parameters, mean, degree)
             self.grad_reducer_d = DistributedGradReducer(
                 self.optimizer_d.parameters, mean, degree)
diff --git a/tests/st/model_zoo_tests/wide_and_deep/python_file_for_ci/train_and_test_multinpu_ci.py b/tests/st/model_zoo_tests/wide_and_deep/python_file_for_ci/train_and_test_multinpu_ci.py
index 930d7d6aaa7c7a35ffce6a32df876d002ad4576e..073c23423e88e3f9b6865ddbfa509d54e4c3b4aa 100644
--- a/tests/st/model_zoo_tests/wide_and_deep/python_file_for_ci/train_and_test_multinpu_ci.py
+++ b/tests/st/model_zoo_tests/wide_and_deep/python_file_for_ci/train_and_test_multinpu_ci.py
@@ -21,7 +21,7 @@ from mindspore import Model, context
 from mindspore.train.callback import TimeMonitor
 from mindspore.train import ParallelMode
 from mindspore.communication.management import get_rank, get_group_size, init
-from mindspore.parallel import _cost_model_context as cost_model_context
+from mindspore.parallel import set_multi_subgraphs
 from mindspore.nn.wrap.cell_wrapper import VirtualDatasetCellTriple
 
 from src.wide_and_deep import PredictWithSigmoid, TrainStepWrap, NetWithLossClass, WideDeepModel
@@ -33,7 +33,7 @@ from src.config import WideDeepConfig
 
 sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
 
 context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=True)
 context.set_auto_parallel_context(parallel_mode=ParallelMode.SEMI_AUTO_PARALLEL, mirror_mean=True)
-cost_model_context.set_cost_model_context(multi_subgraphs=True)
+set_multi_subgraphs()
 init()
diff --git a/tests/ut/python/parallel/test_auto_parallel_double_subgraphs.py b/tests/ut/python/parallel/test_auto_parallel_double_subgraphs.py
index 6003c0cc28ab9f0b0da8ae35bdebdff0206e6374..eb9c397abc618255cd866fbbedb9d261fa8278dc 100644
--- a/tests/ut/python/parallel/test_auto_parallel_double_subgraphs.py
+++ b/tests/ut/python/parallel/test_auto_parallel_double_subgraphs.py
@@ -23,7 +23,7 @@ from mindspore.nn.optim import Adam, FTRL
 from mindspore.ops import composite as C
 from mindspore.ops import functional as F
 from mindspore.ops import operations as P
-from mindspore.parallel import _cost_model_context as cost_model_context
+from mindspore.parallel import set_multi_subgraphs
 from mindspore.parallel._utils import _reset_op_id as reset_op_id
 
 
@@ -103,7 +103,7 @@ class TrainStepWarp(nn.Cell):
 
 
 def test_double_subgraphs():
-    cost_model_context.set_cost_model_context(multi_subgraphs=True)
+    set_multi_subgraphs()
     context.set_context(save_graphs=True)
     context.set_auto_parallel_context(device_num=8, global_rank=0)
     context.set_auto_parallel_context(parallel_mode="auto_parallel")
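
Migration note (not part of the diff): a minimal usage sketch of the new public helpers for other callers, assuming a MindSpore build that includes this change.

    # Before: from mindspore.parallel import _cost_model_context as cost_model_context
    #         cost_model_context.set_cost_model_context(multi_subgraphs=True)
    from mindspore.parallel import set_multi_subgraphs, get_multi_subgraphs

    set_multi_subgraphs()         # mark the ANF graph as containing multiple subgraphs
    flag = get_multi_subgraphs()  # read the flag back from the cost model context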