提交 394be434 编写于 作者: Y Yi Huaijie

raise RuntimeError when set different mode after Initializer created

上级 b7c92aa8
......@@ -81,8 +81,6 @@ void ParallelContext::set_mirror_mean(bool mirror_mean) { mirror_mean_ = mirror_
void ParallelContext::set_full_batch(bool full_batch) { full_batch_ = full_batch; }
void ParallelContext::set_has_initializer(bool has_initializer) { has_initializer_ = has_initializer; }
void ParallelContext::set_cast_before_mirror(bool cast_before_mirror) { cast_before_mirror_ = cast_before_mirror; }
void ParallelContext::set_loss_repeated_mean(bool loss_repeated_mean) { loss_repeated_mean_ = loss_repeated_mean; }
......
......@@ -58,9 +58,6 @@ class ParallelContext {
void set_full_batch(bool full_batch);
bool full_batch() const { return full_batch_; }
void set_has_initializer(bool has_initializer);
bool has_initializer() const { return has_initializer_; }
void set_cast_before_mirror(bool cast_before_mirror);
bool cast_before_mirror() const { return cast_before_mirror_; }
......@@ -115,7 +112,6 @@ class ParallelContext {
static std::shared_ptr<ParallelContext> inst_context_;
bool mirror_mean_;
bool full_batch_;
bool has_initializer_ = false;
bool cast_before_mirror_;
bool loss_repeated_mean_;
int32_t device_num_;
......
......@@ -198,8 +198,6 @@ PYBIND11_MODULE(_c_expression, m) {
.def("get_strategy_ckpt_save_file", &ParallelContext::strategy_ckpt_save_file, "Get strategy checkpoint save file.")
.def("set_full_batch", &ParallelContext::set_full_batch, "Set whether load full batch on each device.")
.def("get_full_batch", &ParallelContext::full_batch, "Get whether load full batch on each device.")
.def("set_has_initializer", &ParallelContext::set_has_initializer, "Set whether any Initializer has been created.")
.def("get_has_initializer", &ParallelContext::has_initializer, "Get whether any Initializer has been created.")
.def("set_enable_parallel_optimizer", &ParallelContext::set_enable_parallel_optimizer,
"Set enable/disable parallel optimizer.")
.def("get_enable_parallel_optimizer", &ParallelContext::enable_parallel_optimizer,
......
......@@ -24,7 +24,7 @@ from mindspore import log as logger
from .._c_expression import generate_key, Executor_, Tensor, MetaTensor, PynativeExecutor_
from .._c_expression import verify_inputs_signature, init_exec_dataset, _set_dataset_mode_config, init_backend
from .tensor import Tensor as MsTensor
from ..parallel._utils import _get_device_num, _get_global_rank, _need_to_full, _to_full_tensor, _set_has_initializer
from ..parallel._utils import _get_device_num, _get_global_rank, _need_to_full, _to_full_tensor
# store ms_function class compiled pipeline cache
ms_compile_cache = {}
......@@ -383,7 +383,6 @@ class _Executor:
Str, the full phase of the cell.
Bool, if the graph has been compiled before, return False, else return True.
"""
_set_has_initializer(False)
obj.check_names()
args_names, args_list = _generate_pip_args(obj, *args)
dic = dict(zip(args_names, args_list))
......
......@@ -24,7 +24,6 @@ from mindspore import log as logger
from . import dtype as mstype
from .tensor import Tensor
from .._c_expression import random_normal
from ..parallel._utils import _set_has_initializer
_INITIALIZER_ALIAS = dict()
......@@ -43,7 +42,6 @@ class Initializer:
self._kwargs = kwargs
self.shape = None
self.dtype = None
_set_has_initializer(True)
def _initialize(self, *kwargs):
raise NotImplementedError('Must be overridden!')
......
......@@ -90,6 +90,9 @@ class Parameter(MetaTensor):
input_class.__init__(obj, *class_init_args)
# it's better to make the Initializer a kind of metatensor.
obj.init_mode = None
obj.is_default_input_initializer = False
if isinstance(default_input, Initializer):
obj.is_default_input_initializer = True
if not isinstance(obj, Tensor):
obj.init_mode = default_input
return obj
......@@ -118,6 +121,7 @@ class Parameter(MetaTensor):
self.is_param_ps = False
self._cast_type = None
self.init_in_server = False
self.is_in_parallel = _is_in_parallel_mode()
@staticmethod
def _get_base_class(input_class):
......@@ -372,10 +376,17 @@ class Parameter(MetaTensor):
set_sliced (bool): True if the parameter is set sliced after initializing the data.
Default: False.
Raises:
RuntimeError: If it is from Initializer, and parallel mode has changed after the Initializer created.
Returns:
Parameter, the `Parameter` after initializing data. If current `Parameter` was already initialized before,
returns the same initialized `Parameter`.
"""
if self.is_default_input_initializer:
is_current_in_parallel = _is_in_parallel_mode()
if self.is_in_parallel != is_current_in_parallel:
raise RuntimeError("Must set or change parallel mode before any Initializer created.")
if self.init_mode is None:
return self
if layout is not None:
......
......@@ -449,7 +449,7 @@ def set_auto_parallel_context(**kwargs):
next task, interface mindspore.context.reset_auto_parallel_context() needs to be called to reset
the configuration.
Setting or changing parallel modes must be called before any Initializer created, or RuntimeError
will be raised.
may be raised when compile network.
Args:
device_num (int): Available device number, the value must be in [1, 4096]. Default: 1.
......@@ -491,7 +491,6 @@ def set_auto_parallel_context(**kwargs):
Raises:
ValueError: If input key is not attribute in auto parallel context.
RuntimeError: If there is any Initializer created before setting or changing parallel_mode.
Examples:
>>> context.set_auto_parallel_context(device_num=8)
......
......@@ -176,12 +176,8 @@ class _AutoParallelContext:
Raises:
ValueError: If parallel mode is not supported.
RuntimeError: If there is any Initializer created before setting or changing parallel_mode.
"""
self.check_context_handle()
if self.get_has_initializer():
self.set_has_initializer(False)
raise RuntimeError("Must set or change parallel mode before any Initializer created.")
ret = self._context_handle.set_parallel_mode(parallel_mode)
if ret is False:
raise ValueError("Parallel mode does not support {}".format(parallel_mode))
......@@ -253,21 +249,6 @@ class _AutoParallelContext:
self.check_context_handle()
return self._context_handle.get_full_batch()
def set_has_initializer(self, has_initializer):
"""
Set whether any Initializer has been created.
Args:
has_initializer (bool): True if a Initializer created.
"""
self.check_context_handle()
self._context_handle.set_has_initializer(has_initializer)
def get_has_initializer(self):
"""Get whether any Initializer has been created."""
self.check_context_handle()
return self._context_handle.get_has_initializer()
def set_strategy_ckpt_save_file(self, strategy_ckpt_save_file):
"""
Set strategy checkpoint save path.
......@@ -562,7 +543,6 @@ def _set_auto_parallel_context(**kwargs):
Raises:
ValueError: If input key is not attribute in auto parallel context.
RuntimeError: If there is any Initializer created before setting or changing parallel_mode.
"""
for key, value in kwargs.items():
if key not in _set_auto_parallel_context_func_map:
......
......@@ -32,19 +32,6 @@ def _get_full_batch():
"""Get whether to use full_batch."""
return auto_parallel_context().get_full_batch()
def _get_has_initializer():
"""Get whether any Initializer has been created."""
return auto_parallel_context().get_has_initializer()
def _set_has_initializer(has_initializer):
"""
Set whether any Initializer has been created.
Args:
has_initializer (bool): True if a Initializer created.
"""
auto_parallel_context().set_has_initializer(has_initializer)
def _need_to_full():
"""Check whether to convert input to full shape or tensor."""
......
......@@ -24,7 +24,6 @@ import mindspore.nn as nn
from mindspore import Tensor, Model, ParallelMode
from mindspore.nn.optim import Momentum
from mindspore.ops import operations as P
from mindspore.parallel._utils import _set_has_initializer
_current_dir = os.path.dirname(os.path.realpath(__file__)) + "/../test_data"
......@@ -90,4 +89,3 @@ def test_lenet5_train_step_training_pynative():
Model(network=network, loss_fn=loss_fn, optimizer=optimizer)
context.set_context(mode=context.GRAPH_MODE)
context.reset_auto_parallel_context()
_set_has_initializer(False)
......@@ -21,7 +21,6 @@ from mindspore import context, Tensor, Parameter, ParameterTuple
from mindspore._checkparam import _check_str_by_regular
from mindspore.common import dtype as mstype
from mindspore.common.initializer import initializer
from mindspore.parallel._utils import _set_has_initializer
def test_parameter_init():
dat = np.array([[1, 2, 3], [2, 3, 4]])
......@@ -191,7 +190,6 @@ def test_scalar_parameter_update():
def test_parameter_lazy_init():
_set_has_initializer(False)
# support lazy init in SEMI_AUTO_PARALLEL mode
context.reset_auto_parallel_context()
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8)
......
......@@ -20,7 +20,6 @@ from mindspore import context
from mindspore.common.api import _executor
from mindspore.ops import composite as C
from mindspore.ops import operations as P
from mindspore.parallel._utils import _set_has_initializer
from tests.ut.python.ops.test_math_ops import VirtualLoss
......@@ -61,7 +60,6 @@ def compile_net(net, x, y):
def test_add_relu_stride_slice():
_set_has_initializer(False)
context.set_auto_parallel_context(device_num=8, global_rank=7)
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
......@@ -75,7 +73,6 @@ def test_add_relu_stride_slice():
def test_add_relu_all_gather():
_set_has_initializer(False)
context.set_auto_parallel_context(device_num=8, global_rank=7)
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
......
......@@ -23,7 +23,6 @@ from mindspore.nn.optim.momentum import Momentum
from mindspore.parallel import _cost_model_context as cost_model_context
from mindspore.parallel._auto_parallel_context import auto_parallel_context
from mindspore.train import Model, ParallelMode
from mindspore.parallel._utils import _set_has_initializer
from tests.dataset_mock import MindData
......@@ -182,7 +181,6 @@ def test_allreduce_fusion_parameters():
def test_allreduce_fusion1():
_set_has_initializer(False)
cost_model_context.set_cost_model_context(costmodel_allreduce_fusion_algorithm=1)
cost_model_context.set_cost_model_context(costmodel_allreduce_fusion_times=2)
cost_model_context.set_cost_model_context(costmodel_allreduce_fusion_tail_percent=0.5)
......
......@@ -23,7 +23,7 @@ from mindspore.common.parameter import Parameter
from mindspore.nn.loss import SoftmaxCrossEntropyWithLogits
from mindspore.nn.optim.momentum import Momentum
from mindspore.ops import operations as P
from mindspore.parallel._utils import _reset_op_id, _set_has_initializer
from mindspore.parallel._utils import _reset_op_id
from mindspore.train import Model, ParallelMode
from tests.dataset_mock import MindData
......@@ -90,7 +90,6 @@ def all_to_all_common(strategy1):
def test_all_to_all():
_set_has_initializer(False)
strategy1 = ((8, 1),)
context.set_context(mode=context.GRAPH_MODE, save_graphs=False)
_reset_op_id()
......
......@@ -20,7 +20,6 @@ from mindspore import Parameter, Tensor, context
from mindspore.common.api import _executor
from mindspore.ops import composite as C
from mindspore.ops import operations as P
from mindspore.parallel._utils import _set_has_initializer
from tests.ut.python.ops.test_math_ops import VirtualLoss
......@@ -61,7 +60,6 @@ def test_matmul_sub():
out = self.sub(out, b)
return out
_set_has_initializer(False)
context.set_auto_parallel_context(device_num=8, global_rank=0)
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
strategy1 = ((2, 2), (2, 2))
......
......@@ -84,8 +84,20 @@ def test_wrong_order_set_parallel_mode_with_initializer():
net = Net(strategy1, strategy2, weight)
exe = me._executor
x = Tensor(np.ones([32, 32]), dtype=ms.float32)
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
net.set_auto_parallel()
with pytest.raises(RuntimeError):
exe.compile(net, x, auto_parallel_mode=True, phase='train')
def test_wrong_order_set_same_parallel_mode_with_initializer():
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
weight = initializer("Normal", [64, 32], ms.float32)
strategy1 = ((2, 1), (4, 1))
strategy2 = ((2, 4),)
net = Net(strategy1, strategy2, weight)
exe = me._executor
x = Tensor(np.ones([32, 32]), dtype=ms.float32)
context.set_auto_parallel_context(parallel_mode="auto_parallel", device_num=8, global_rank=0)
net.set_auto_parallel()
exe.compile(net, x, auto_parallel_mode=True, phase='train')
......
......@@ -18,7 +18,6 @@ from numpy import allclose
import mindspore.common.initializer as init
import mindspore.nn as nn
from mindspore import Parameter
from mindspore.parallel._utils import _set_has_initializer
parameter_shape = [16, 4]
......@@ -47,7 +46,6 @@ def test_using_same_seed_for_initializer():
np.random.seed(0)
net2 = ParameterNet()
net2.init_parameters_data()
_set_has_initializer(False)
for key in net1.parameters_dict():
if key not in net2.parameters_dict():
assert False
......@@ -62,7 +60,6 @@ def test_using_diffserent_seed_for_initializer():
np.random.seed(1)
net2 = ParameterNet()
net2.init_parameters_data()
_set_has_initializer(False)
for key in net1.parameters_dict():
if key not in net2.parameters_dict():
assert False
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册
反馈
建议
客服 返回
顶部