diff --git a/mindspore/ccsrc/frontend/parallel/context.cc b/mindspore/ccsrc/frontend/parallel/context.cc index 0ba918575503e0df45a2ac3c4039ff6d6a58d529..20445f9c6a75e970d023f7b5e2261b0c5bbddecf 100644 --- a/mindspore/ccsrc/frontend/parallel/context.cc +++ b/mindspore/ccsrc/frontend/parallel/context.cc @@ -81,8 +81,6 @@ void ParallelContext::set_mirror_mean(bool mirror_mean) { mirror_mean_ = mirror_ void ParallelContext::set_full_batch(bool full_batch) { full_batch_ = full_batch; } -void ParallelContext::set_has_initializer(bool has_initializer) { has_initializer_ = has_initializer; } - void ParallelContext::set_cast_before_mirror(bool cast_before_mirror) { cast_before_mirror_ = cast_before_mirror; } void ParallelContext::set_loss_repeated_mean(bool loss_repeated_mean) { loss_repeated_mean_ = loss_repeated_mean; } diff --git a/mindspore/ccsrc/frontend/parallel/context.h b/mindspore/ccsrc/frontend/parallel/context.h index e32ef855e3e3c1765833c7182edb6e1c9c485328..34363726411dd9b154e6e03522679eccaeaa6659 100644 --- a/mindspore/ccsrc/frontend/parallel/context.h +++ b/mindspore/ccsrc/frontend/parallel/context.h @@ -58,9 +58,6 @@ class ParallelContext { void set_full_batch(bool full_batch); bool full_batch() const { return full_batch_; } - void set_has_initializer(bool has_initializer); - bool has_initializer() const { return has_initializer_; } - void set_cast_before_mirror(bool cast_before_mirror); bool cast_before_mirror() const { return cast_before_mirror_; } @@ -115,7 +112,6 @@ class ParallelContext { static std::shared_ptr inst_context_; bool mirror_mean_; bool full_batch_; - bool has_initializer_ = false; bool cast_before_mirror_; bool loss_repeated_mean_; int32_t device_num_; diff --git a/mindspore/ccsrc/pipeline/jit/init.cc b/mindspore/ccsrc/pipeline/jit/init.cc index 3919fd10614c4fd4a2132ae568eed3d2483c06c7..d89dc444481a88d1e06fa61cfe3c68bce28792f7 100644 --- a/mindspore/ccsrc/pipeline/jit/init.cc +++ b/mindspore/ccsrc/pipeline/jit/init.cc @@ -198,8 +198,6 @@ PYBIND11_MODULE(_c_expression, m) { .def("get_strategy_ckpt_save_file", &ParallelContext::strategy_ckpt_save_file, "Get strategy checkpoint save file.") .def("set_full_batch", &ParallelContext::set_full_batch, "Set whether load full batch on each device.") .def("get_full_batch", &ParallelContext::full_batch, "Get whether load full batch on each device.") - .def("set_has_initializer", &ParallelContext::set_has_initializer, "Set whether any Initializer has been created.") - .def("get_has_initializer", &ParallelContext::has_initializer, "Get whether any Initializer has been created.") .def("set_enable_parallel_optimizer", &ParallelContext::set_enable_parallel_optimizer, "Set enable/disable parallel optimizer.") .def("get_enable_parallel_optimizer", &ParallelContext::enable_parallel_optimizer, diff --git a/mindspore/common/api.py b/mindspore/common/api.py index b30d7ed0d12258935af1eb98fe952ee0dc5b5faf..b827ffe3455e92584052db9399abbad962b3b502 100644 --- a/mindspore/common/api.py +++ b/mindspore/common/api.py @@ -24,7 +24,7 @@ from mindspore import log as logger from .._c_expression import generate_key, Executor_, Tensor, MetaTensor, PynativeExecutor_ from .._c_expression import verify_inputs_signature, init_exec_dataset, _set_dataset_mode_config, init_backend from .tensor import Tensor as MsTensor -from ..parallel._utils import _get_device_num, _get_global_rank, _need_to_full, _to_full_tensor, _set_has_initializer +from ..parallel._utils import _get_device_num, _get_global_rank, _need_to_full, _to_full_tensor # store ms_function class compiled pipeline cache ms_compile_cache = {} @@ -383,7 +383,6 @@ class _Executor: Str, the full phase of the cell. Bool, if the graph has been compiled before, return False, else return True. """ - _set_has_initializer(False) obj.check_names() args_names, args_list = _generate_pip_args(obj, *args) dic = dict(zip(args_names, args_list)) diff --git a/mindspore/common/initializer.py b/mindspore/common/initializer.py index 4982243f044d20a78ebf90b52793e04fb72ff8a5..546d1e99b150e6155c3a38c85ff2e4e869ca4af4 100644 --- a/mindspore/common/initializer.py +++ b/mindspore/common/initializer.py @@ -24,7 +24,6 @@ from mindspore import log as logger from . import dtype as mstype from .tensor import Tensor from .._c_expression import random_normal -from ..parallel._utils import _set_has_initializer _INITIALIZER_ALIAS = dict() @@ -43,7 +42,6 @@ class Initializer: self._kwargs = kwargs self.shape = None self.dtype = None - _set_has_initializer(True) def _initialize(self, *kwargs): raise NotImplementedError('Must be overridden!') diff --git a/mindspore/common/parameter.py b/mindspore/common/parameter.py index e5dfbc23bdff73acf099b2a9f35252a19f40769a..ee41fdd1656e474480d78021adfdc06d2caf0a45 100644 --- a/mindspore/common/parameter.py +++ b/mindspore/common/parameter.py @@ -90,6 +90,9 @@ class Parameter(MetaTensor): input_class.__init__(obj, *class_init_args) # it's better to make the Initializer a kind of metatensor. obj.init_mode = None + obj.is_default_input_initializer = False + if isinstance(default_input, Initializer): + obj.is_default_input_initializer = True if not isinstance(obj, Tensor): obj.init_mode = default_input return obj @@ -118,6 +121,7 @@ class Parameter(MetaTensor): self.is_param_ps = False self._cast_type = None self.init_in_server = False + self.is_in_parallel = _is_in_parallel_mode() @staticmethod def _get_base_class(input_class): @@ -372,10 +376,17 @@ class Parameter(MetaTensor): set_sliced (bool): True if the parameter is set sliced after initializing the data. Default: False. + Raises: + RuntimeError: If it is from Initializer, and parallel mode has changed after the Initializer created. + Returns: Parameter, the `Parameter` after initializing data. If current `Parameter` was already initialized before, returns the same initialized `Parameter`. """ + if self.is_default_input_initializer: + is_current_in_parallel = _is_in_parallel_mode() + if self.is_in_parallel != is_current_in_parallel: + raise RuntimeError("Must set or change parallel mode before any Initializer created.") if self.init_mode is None: return self if layout is not None: diff --git a/mindspore/context.py b/mindspore/context.py index c92a9985142293cd0975797fb5712e1595ce2169..985270d1facba5de4d907537989af8033b5f377b 100644 --- a/mindspore/context.py +++ b/mindspore/context.py @@ -449,7 +449,7 @@ def set_auto_parallel_context(**kwargs): next task, interface mindspore.context.reset_auto_parallel_context() needs to be called to reset the configuration. Setting or changing parallel modes must be called before any Initializer created, or RuntimeError - will be raised. + may be raised when compile network. Args: device_num (int): Available device number, the value must be in [1, 4096]. Default: 1. @@ -491,7 +491,6 @@ def set_auto_parallel_context(**kwargs): Raises: ValueError: If input key is not attribute in auto parallel context. - RuntimeError: If there is any Initializer created before setting or changing parallel_mode. Examples: >>> context.set_auto_parallel_context(device_num=8) diff --git a/mindspore/parallel/_auto_parallel_context.py b/mindspore/parallel/_auto_parallel_context.py index 7da25b8fd3724baa313151efabc63d3a681af60c..e2369c4aa69ddeddfc31206b42a6fef0a1001311 100644 --- a/mindspore/parallel/_auto_parallel_context.py +++ b/mindspore/parallel/_auto_parallel_context.py @@ -176,12 +176,8 @@ class _AutoParallelContext: Raises: ValueError: If parallel mode is not supported. - RuntimeError: If there is any Initializer created before setting or changing parallel_mode. """ self.check_context_handle() - if self.get_has_initializer(): - self.set_has_initializer(False) - raise RuntimeError("Must set or change parallel mode before any Initializer created.") ret = self._context_handle.set_parallel_mode(parallel_mode) if ret is False: raise ValueError("Parallel mode does not support {}".format(parallel_mode)) @@ -253,21 +249,6 @@ class _AutoParallelContext: self.check_context_handle() return self._context_handle.get_full_batch() - def set_has_initializer(self, has_initializer): - """ - Set whether any Initializer has been created. - - Args: - has_initializer (bool): True if a Initializer created. - """ - self.check_context_handle() - self._context_handle.set_has_initializer(has_initializer) - - def get_has_initializer(self): - """Get whether any Initializer has been created.""" - self.check_context_handle() - return self._context_handle.get_has_initializer() - def set_strategy_ckpt_save_file(self, strategy_ckpt_save_file): """ Set strategy checkpoint save path. @@ -562,7 +543,6 @@ def _set_auto_parallel_context(**kwargs): Raises: ValueError: If input key is not attribute in auto parallel context. - RuntimeError: If there is any Initializer created before setting or changing parallel_mode. """ for key, value in kwargs.items(): if key not in _set_auto_parallel_context_func_map: diff --git a/mindspore/parallel/_utils.py b/mindspore/parallel/_utils.py index 3ed2416b359a30fabe4d5e3b26925be2f2ef3db6..ff1fbcc6c2bb96a62c298d749d090839c5ae9f7a 100644 --- a/mindspore/parallel/_utils.py +++ b/mindspore/parallel/_utils.py @@ -32,19 +32,6 @@ def _get_full_batch(): """Get whether to use full_batch.""" return auto_parallel_context().get_full_batch() -def _get_has_initializer(): - """Get whether any Initializer has been created.""" - return auto_parallel_context().get_has_initializer() - -def _set_has_initializer(has_initializer): - """ - Set whether any Initializer has been created. - - Args: - has_initializer (bool): True if a Initializer created. - """ - auto_parallel_context().set_has_initializer(has_initializer) - def _need_to_full(): """Check whether to convert input to full shape or tensor.""" diff --git a/tests/ut/python/communication/test_data_parallel_lenet.py b/tests/ut/python/communication/test_data_parallel_lenet.py index 42fc122adec3f179ce06dca868b515124d60a0d1..7a5062b9410cf0b39eb3aa77a69f83ee5ec5b6c1 100755 --- a/tests/ut/python/communication/test_data_parallel_lenet.py +++ b/tests/ut/python/communication/test_data_parallel_lenet.py @@ -24,7 +24,6 @@ import mindspore.nn as nn from mindspore import Tensor, Model, ParallelMode from mindspore.nn.optim import Momentum from mindspore.ops import operations as P -from mindspore.parallel._utils import _set_has_initializer _current_dir = os.path.dirname(os.path.realpath(__file__)) + "/../test_data" @@ -90,4 +89,3 @@ def test_lenet5_train_step_training_pynative(): Model(network=network, loss_fn=loss_fn, optimizer=optimizer) context.set_context(mode=context.GRAPH_MODE) context.reset_auto_parallel_context() - _set_has_initializer(False) diff --git a/tests/ut/python/nn/test_parameter.py b/tests/ut/python/nn/test_parameter.py index 2447112bd89e7f96228c8e89f2d76f55545119c6..ec0d771075844ad8f3a34979bd0b58b55ea579c4 100644 --- a/tests/ut/python/nn/test_parameter.py +++ b/tests/ut/python/nn/test_parameter.py @@ -21,7 +21,6 @@ from mindspore import context, Tensor, Parameter, ParameterTuple from mindspore._checkparam import _check_str_by_regular from mindspore.common import dtype as mstype from mindspore.common.initializer import initializer -from mindspore.parallel._utils import _set_has_initializer def test_parameter_init(): dat = np.array([[1, 2, 3], [2, 3, 4]]) @@ -191,7 +190,6 @@ def test_scalar_parameter_update(): def test_parameter_lazy_init(): - _set_has_initializer(False) # support lazy init in SEMI_AUTO_PARALLEL mode context.reset_auto_parallel_context() context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8) diff --git a/tests/ut/python/parallel/test_add_relu_redistribution.py b/tests/ut/python/parallel/test_add_relu_redistribution.py index 9a8a43b9bf82ed9f8547509ac2104bab9222af8e..c5c5cc0220df2a3e1747d5cadfa84a0aaed87576 100644 --- a/tests/ut/python/parallel/test_add_relu_redistribution.py +++ b/tests/ut/python/parallel/test_add_relu_redistribution.py @@ -20,7 +20,6 @@ from mindspore import context from mindspore.common.api import _executor from mindspore.ops import composite as C from mindspore.ops import operations as P -from mindspore.parallel._utils import _set_has_initializer from tests.ut.python.ops.test_math_ops import VirtualLoss @@ -61,7 +60,6 @@ def compile_net(net, x, y): def test_add_relu_stride_slice(): - _set_has_initializer(False) context.set_auto_parallel_context(device_num=8, global_rank=7) context.set_auto_parallel_context(parallel_mode="semi_auto_parallel") @@ -75,7 +73,6 @@ def test_add_relu_stride_slice(): def test_add_relu_all_gather(): - _set_has_initializer(False) context.set_auto_parallel_context(device_num=8, global_rank=7) context.set_auto_parallel_context(parallel_mode="semi_auto_parallel") diff --git a/tests/ut/python/parallel/test_allreduce_fusion.py b/tests/ut/python/parallel/test_allreduce_fusion.py index bd1f88f85f82e3f1b9a9786f6fc6318204c427fe..c93df7ffb1bcc3467a3af8873e03ad0af761b206 100644 --- a/tests/ut/python/parallel/test_allreduce_fusion.py +++ b/tests/ut/python/parallel/test_allreduce_fusion.py @@ -23,7 +23,6 @@ from mindspore.nn.optim.momentum import Momentum from mindspore.parallel import _cost_model_context as cost_model_context from mindspore.parallel._auto_parallel_context import auto_parallel_context from mindspore.train import Model, ParallelMode -from mindspore.parallel._utils import _set_has_initializer from tests.dataset_mock import MindData @@ -182,7 +181,6 @@ def test_allreduce_fusion_parameters(): def test_allreduce_fusion1(): - _set_has_initializer(False) cost_model_context.set_cost_model_context(costmodel_allreduce_fusion_algorithm=1) cost_model_context.set_cost_model_context(costmodel_allreduce_fusion_times=2) cost_model_context.set_cost_model_context(costmodel_allreduce_fusion_tail_percent=0.5) diff --git a/tests/ut/python/parallel/test_alltoall.py b/tests/ut/python/parallel/test_alltoall.py index 26f19d722bbd268dd7053e8c37cc46bec0e8cf31..96ff84350462b689b8216e3c15da6ee59863ec5a 100644 --- a/tests/ut/python/parallel/test_alltoall.py +++ b/tests/ut/python/parallel/test_alltoall.py @@ -23,7 +23,7 @@ from mindspore.common.parameter import Parameter from mindspore.nn.loss import SoftmaxCrossEntropyWithLogits from mindspore.nn.optim.momentum import Momentum from mindspore.ops import operations as P -from mindspore.parallel._utils import _reset_op_id, _set_has_initializer +from mindspore.parallel._utils import _reset_op_id from mindspore.train import Model, ParallelMode from tests.dataset_mock import MindData @@ -90,7 +90,6 @@ def all_to_all_common(strategy1): def test_all_to_all(): - _set_has_initializer(False) strategy1 = ((8, 1),) context.set_context(mode=context.GRAPH_MODE, save_graphs=False) _reset_op_id() diff --git a/tests/ut/python/parallel/test_arithmetic.py b/tests/ut/python/parallel/test_arithmetic.py index 134685a620bba8db72d74c029fda092b35e625f8..311e1425ead7897abd2454dc75862cf9f145aa9f 100644 --- a/tests/ut/python/parallel/test_arithmetic.py +++ b/tests/ut/python/parallel/test_arithmetic.py @@ -20,7 +20,6 @@ from mindspore import Parameter, Tensor, context from mindspore.common.api import _executor from mindspore.ops import composite as C from mindspore.ops import operations as P -from mindspore.parallel._utils import _set_has_initializer from tests.ut.python.ops.test_math_ops import VirtualLoss @@ -61,7 +60,6 @@ def test_matmul_sub(): out = self.sub(out, b) return out - _set_has_initializer(False) context.set_auto_parallel_context(device_num=8, global_rank=0) context.set_auto_parallel_context(parallel_mode="semi_auto_parallel") strategy1 = ((2, 2), (2, 2)) diff --git a/tests/ut/python/parallel/test_initializer_weight_slice.py b/tests/ut/python/parallel/test_initializer_weight_slice.py index 7065087901da11b3edc23d4747b42f8cfd85a42a..098a55c5a2cda7868e088aecf8652fcbea6f7100 100644 --- a/tests/ut/python/parallel/test_initializer_weight_slice.py +++ b/tests/ut/python/parallel/test_initializer_weight_slice.py @@ -84,11 +84,23 @@ def test_wrong_order_set_parallel_mode_with_initializer(): net = Net(strategy1, strategy2, weight) exe = me._executor x = Tensor(np.ones([32, 32]), dtype=ms.float32) + context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0) + net.set_auto_parallel() with pytest.raises(RuntimeError): - context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0) - net.set_auto_parallel() exe.compile(net, x, auto_parallel_mode=True, phase='train') +def test_wrong_order_set_same_parallel_mode_with_initializer(): + context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0) + weight = initializer("Normal", [64, 32], ms.float32) + strategy1 = ((2, 1), (4, 1)) + strategy2 = ((2, 4),) + net = Net(strategy1, strategy2, weight) + exe = me._executor + x = Tensor(np.ones([32, 32]), dtype=ms.float32) + context.set_auto_parallel_context(parallel_mode="auto_parallel", device_num=8, global_rank=0) + net.set_auto_parallel() + exe.compile(net, x, auto_parallel_mode=True, phase='train') + def test_wrong_order_set_parallel_mode_without_initializer(): weight = Tensor(np.ones([64, 32]), ms.float32) strategy1 = ((2, 1), (4, 1)) diff --git a/tests/ut/python/parallel/test_using_seed_for_initializer.py b/tests/ut/python/parallel/test_using_seed_for_initializer.py index a8426ebf5874a8e0599a0735205aff599eaeb17c..9e601efccddcc45a4fe15eb2486ec95b65bc6a70 100644 --- a/tests/ut/python/parallel/test_using_seed_for_initializer.py +++ b/tests/ut/python/parallel/test_using_seed_for_initializer.py @@ -18,7 +18,6 @@ from numpy import allclose import mindspore.common.initializer as init import mindspore.nn as nn from mindspore import Parameter -from mindspore.parallel._utils import _set_has_initializer parameter_shape = [16, 4] @@ -47,7 +46,6 @@ def test_using_same_seed_for_initializer(): np.random.seed(0) net2 = ParameterNet() net2.init_parameters_data() - _set_has_initializer(False) for key in net1.parameters_dict(): if key not in net2.parameters_dict(): assert False @@ -62,7 +60,6 @@ def test_using_diffserent_seed_for_initializer(): np.random.seed(1) net2 = ParameterNet() net2.init_parameters_data() - _set_has_initializer(False) for key in net1.parameters_dict(): if key not in net2.parameters_dict(): assert False