提交 0a795c24 编写于 作者: M mindspore-ci-bot 提交者: Gitee

!1245 add model init api to compile df graph before train

Merge pull request !1245 from wangnan39/add_model_init_api_to_compile_df_graph_before_train_and_eval
...@@ -390,6 +390,10 @@ class _Executor: ...@@ -390,6 +390,10 @@ class _Executor:
if auto_parallel_mode and "train" in phase: if auto_parallel_mode and "train" in phase:
obj.load_parameter_slice(params) obj.load_parameter_slice(params)
# set parallel inputs in sink mode
if auto_parallel_mode and (args and isinstance(args[0], Tensor) and args[0].virtual_flag):
obj.set_parallel_input_with_inputs(*args)
# the following GE init process is not needed when use vm or ms backend # the following GE init process is not needed when use vm or ms backend
if enable_ge: if enable_ge:
# decide whether to sink based on whether the inputs is virtual or not # decide whether to sink based on whether the inputs is virtual or not
......
...@@ -294,6 +294,15 @@ class Cell: ...@@ -294,6 +294,15 @@ class Cell:
parallel_inputs_run.append(new_tensor) parallel_inputs_run.append(new_tensor)
return tuple(parallel_inputs_run) return tuple(parallel_inputs_run)
def set_parallel_input_with_inputs(self, *inputs):
"""
Slice inputs tensors by parallel strategies, and set the sliced inputs to `_parallel_input_run`
Args:
inputs (tuple): inputs of construct method.
"""
self._parallel_inputs_run = self._load_inputs(*inputs)
def _get_construct_inputs_number_and_name(self): def _get_construct_inputs_number_and_name(self):
"""Compute self._construct_inputs_names and self._construct_inputs_num""" """Compute self._construct_inputs_names and self._construct_inputs_num"""
import inspect import inspect
...@@ -310,6 +319,15 @@ class Cell: ...@@ -310,6 +319,15 @@ class Cell:
self._construct_inputs_names = self._construct_inputs_names[1:self._construct_inputs_num] self._construct_inputs_names = self._construct_inputs_names[1:self._construct_inputs_num]
self._construct_inputs_num = self._construct_inputs_num - 1 self._construct_inputs_num = self._construct_inputs_num - 1
def compile(self, *inputs):
"""
Compiles cell.
Args:
inputs (tuple): Input parameters.
"""
_executor.compile(self, *inputs, phase=self.phase, auto_parallel_mode=self._auto_parallel_mode)
def compile_and_run(self, *inputs): def compile_and_run(self, *inputs):
""" """
Compiles and runs cell. Compiles and runs cell.
...@@ -320,13 +338,14 @@ class Cell: ...@@ -320,13 +338,14 @@ class Cell:
Returns: Returns:
Object, the result of executing. Object, the result of executing.
""" """
_, compile_flag = _executor.compile(self, *inputs, phase=self.phase, _executor.compile(self, *inputs, phase=self.phase, auto_parallel_mode=self._auto_parallel_mode)
auto_parallel_mode=self._auto_parallel_mode)
if self._auto_parallel_mode: if self._auto_parallel_mode:
if inputs and isinstance(inputs[0], Tensor) and inputs[0].virtual_flag and (not compile_flag): if inputs and isinstance(inputs[0], Tensor) and inputs[0].virtual_flag:
# get parallel inputs in sink mode, parallel inputs set in _executor.compile
parallel_inputs_run = self._parallel_inputs_run parallel_inputs_run = self._parallel_inputs_run
else: else:
# set parallel inputs in normal mode
self._parallel_inputs_run = self._load_inputs(*inputs) self._parallel_inputs_run = self._load_inputs(*inputs)
parallel_inputs_run = self._parallel_inputs_run parallel_inputs_run = self._parallel_inputs_run
return _executor(self, *parallel_inputs_run, phase=self.phase) return _executor(self, *parallel_inputs_run, phase=self.phase)
......
...@@ -217,6 +217,94 @@ class Model: ...@@ -217,6 +217,94 @@ class Model:
scaling_sens /= self._device_number scaling_sens /= self._device_number
return scaling_sens return scaling_sens
def _exec_preprocess(self, network, is_train, phase, dataset, dataset_sink_mode):
"""Initializes dataset."""
need_wrap = False
if dataset_sink_mode:
# remove later to deal with loop sink
if not hasattr(dataset, '__ME_INITED__') and context.get_context("device_target") == "Ascend" \
and not context.get_context("enable_ge"):
need_wrap = True
if not is_train:
dataset.__loop_size__ = 1
dataset_helper = DatasetHelper(dataset, dataset_sink_mode)
# remove later to deal with loop sink
if need_wrap:
network = nn.DataWrapper(network, *(dataset_helper.types_shapes()), dataset.__ME_INITED__)
network.set_train(is_train)
network.phase = phase
return dataset_helper, network
def init(self, train_dataset=None, valid_dataset=None):
"""
Initializes compute graphs and data graphs with sink mode.
Note:
Pre-init process only supports `GRAPH_MODE` and `Ascend` target currently.
Args:
train_dataset (Dataset): A training dataset iterator. If define `train_dataset`, training graphs will be
initialized. Default: None.
valid_dataset (Dataset): A evaluating dataset iterator. If define `valid_dataset`, evaluation graphs will
be initialized, and `metrics` in `Model` can not be None. Default: None.
Examples:
>>> train_dataset = get_train_dataset()
>>> valid_dataset = get_valid_dataset()
>>> net = Net()
>>> loss = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True)
>>> optim = Momentum(params=net.trainable_params(), learning_rate=0.1, momentum=0.9)
>>> model = Model(net, loss_fn=loss, optimizer=optim, metrics={'acc'})
>>> model.init(train_dataset, valid_dataset)
>>> model.train(2, train_dataset)
>>> model.eval(valid_dataset)
"""
if context.get_context("mode") != context.GRAPH_MODE or context.get_context("device_target") != "Ascend":
raise RuntimeError('Pre-init process only supports GRAPH MODE and Ascend target currently.')
if not train_dataset and not valid_dataset:
raise ValueError('Both train_dataset and valid_dataset can not be None or empty.')
_device_number_check(self._parallel_mode, self._device_number)
if train_dataset:
_parameter_broadcast_check(self._parallel_mode, self._parameter_broadcast)
self._train_network.set_train()
self._train_network.phase = 'train'
if self._parameter_broadcast:
self._train_network.set_broadcast_flag()
train_dataset_helper, train_network = self._exec_preprocess(self._train_network,
is_train=True,
phase='train',
dataset=train_dataset,
dataset_sink_mode=True)
self._train_network = train_network
for inputs in train_dataset_helper:
self._train_network.compile(*inputs)
break
if valid_dataset:
if not self._metric_fns:
raise RuntimeError('If define `valid_dataset`, metric fn can not be None or empty.')
self._eval_network.set_train(False)
self._eval_network.phase = 'eval'
valid_dataset_helper, eval_network = self._exec_preprocess(self._eval_network,
is_train=False,
phase='eval',
dataset=valid_dataset,
dataset_sink_mode=True)
self._eval_network = eval_network
for inputs in valid_dataset_helper:
self._eval_network.compile(*inputs)
break
def _train(self, epoch, train_dataset, callbacks=None, dataset_sink_mode=True): def _train(self, epoch, train_dataset, callbacks=None, dataset_sink_mode=True):
""" """
Training. Training.
...@@ -277,21 +365,15 @@ class Model: ...@@ -277,21 +365,15 @@ class Model:
list_callback (_ListCallback): Executor of callback list. Default: None. list_callback (_ListCallback): Executor of callback list. Default: None.
cb_params (_InternalCallbackParam): Callback parameters. Default: None. cb_params (_InternalCallbackParam): Callback parameters. Default: None.
""" """
# remove later to deal with loop sink dataset_helper, train_network = self._exec_preprocess(self._train_network,
need_wrap = False is_train=True,
if not hasattr(train_dataset, '__ME_INITED__') and context.get_context("device_target") == "Ascend" \ phase='train',
and not context.get_context("enable_ge"): dataset=train_dataset,
need_wrap = True dataset_sink_mode=True)
self._train_network = train_network
dataset_helper = DatasetHelper(train_dataset) cb_params.train_network = self._train_network
# remove later to deal with loop sink
if need_wrap:
self._train_network = nn.DataWrapper(self._train_network, *(dataset_helper.types_shapes()),
train_dataset.__ME_INITED__)
cb_params.train_network = self._train_network
self._train_network.set_train()
cb_params.cur_step_num = 0 cb_params.cur_step_num = 0
loop_size = dataset_helper.loop_size() loop_size = dataset_helper.loop_size()
run_context = RunContext(cb_params) run_context = RunContext(cb_params)
list_callback.begin(run_context) list_callback.begin(run_context)
...@@ -331,7 +413,11 @@ class Model: ...@@ -331,7 +413,11 @@ class Model:
list_callback (_ListCallback): Executor of callback list. Default: None. list_callback (_ListCallback): Executor of callback list. Default: None.
cb_params (_InternalCallbackParam): Callback parameters. Default: None. cb_params (_InternalCallbackParam): Callback parameters. Default: None.
""" """
dataset_helper = DatasetHelper(train_dataset, dataset_sink_mode=False) dataset_helper, _ = self._exec_preprocess(self._train_network,
is_train=True,
phase='train',
dataset=train_dataset,
dataset_sink_mode=False)
cb_params.cur_step_num = 0 cb_params.cur_step_num = 0
run_context = RunContext(cb_params) run_context = RunContext(cb_params)
list_callback.begin(run_context) list_callback.begin(run_context)
...@@ -437,26 +523,15 @@ class Model: ...@@ -437,26 +523,15 @@ class Model:
Returns: Returns:
Dict, returns the loss value & metrics values for the model in test mode. Dict, returns the loss value & metrics values for the model in test mode.
""" """
_device_number_check(self._parallel_mode, self._device_number)
run_context = RunContext(cb_params) run_context = RunContext(cb_params)
# remove later to deal with loop sink dataset_helper, eval_network = self._exec_preprocess(self._eval_network,
need_wrap = False is_train=False,
if not hasattr(valid_dataset, '__ME_INITED__') and context.get_context("device_target") == "Ascend" \ phase='eval',
and not context.get_context("enable_ge"): dataset=valid_dataset,
need_wrap = True dataset_sink_mode=True)
self._eval_network = eval_network
valid_dataset.__loop_size__ = 1 cb_params.eval_network = self._eval_network
dataset_helper = DatasetHelper(valid_dataset)
# remove later to deal with loop sink
if need_wrap:
self._eval_network = nn.DataWrapper(self._eval_network, *(dataset_helper.types_shapes()),
valid_dataset.__ME_INITED__)
self._eval_network.set_train(mode=False)
self._eval_network.phase = 'eval'
list_callback.begin(run_context) list_callback.begin(run_context)
for inputs in dataset_helper: for inputs in dataset_helper:
...@@ -490,7 +565,11 @@ class Model: ...@@ -490,7 +565,11 @@ class Model:
run_context = RunContext(cb_params) run_context = RunContext(cb_params)
list_callback.begin(run_context) list_callback.begin(run_context)
dataset_helper = DatasetHelper(valid_dataset, dataset_sink_mode=False) dataset_helper, _ = self._exec_preprocess(self._eval_network,
is_train=False,
phase='eval',
dataset=valid_dataset,
dataset_sink_mode=False)
for next_element in dataset_helper: for next_element in dataset_helper:
cb_params.cur_step_num += 1 cb_params.cur_step_num += 1
list_callback.step_begin(run_context) list_callback.step_begin(run_context)
...@@ -532,6 +611,7 @@ class Model: ...@@ -532,6 +611,7 @@ class Model:
>>> model.eval(dataset) >>> model.eval(dataset)
""" """
check_bool(dataset_sink_mode) check_bool(dataset_sink_mode)
_device_number_check(self._parallel_mode, self._device_number)
if not self._metric_fns: if not self._metric_fns:
raise ValueError("metric fn can not be None or empty.") raise ValueError("metric fn can not be None or empty.")
......
...@@ -68,12 +68,12 @@ class LossNet(nn.Cell): ...@@ -68,12 +68,12 @@ class LossNet(nn.Cell):
return out return out
def get_model(): def get_model(metrics=None):
""" get_model """ """ get_model """
net = Net() net = Net()
loss = nn.SoftmaxCrossEntropyWithLogits() loss = nn.SoftmaxCrossEntropyWithLogits()
optim = Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9) optim = Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9)
model = Model(net, loss_fn=loss, optimizer=optim, metrics=None) model = Model(net, loss_fn=loss, optimizer=optim, metrics=metrics)
return model return model
...@@ -215,8 +215,27 @@ def test_model_build_abnormal_string(): ...@@ -215,8 +215,27 @@ def test_model_build_abnormal_string():
assert err assert err
def test_model_init_error(): def test_model_init():
""" test_model_init_error """ """ test_model_init_error """
train_dataset = get_dataset()
eval_dataset = get_dataset()
with pytest.raises(RuntimeError):
context.set_context(mode=context.PYNATIVE_MODE)
get_model().init(train_dataset)
context.set_context(mode=context.GRAPH_MODE)
get_model().init(train_dataset)
get_model(metrics={'acc'}).init(eval_dataset)
with pytest.raises(RuntimeError):
get_model().init(train_dataset, eval_dataset)
with pytest.raises(ValueError):
get_model().init()
def test_init_model_error():
""" test_init_model_error """
net = nn.ReLU() net = nn.ReLU()
loss = nn.SoftmaxCrossEntropyWithLogits() loss = nn.SoftmaxCrossEntropyWithLogits()
with pytest.raises(KeyError): with pytest.raises(KeyError):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册