From 1ebf98b795366b0a3ac1d6110431f220ab8914f6 Mon Sep 17 00:00:00 2001
From: "wangnan39@huawei.com" <wangnan39@huawei.com>
Date: Tue, 19 May 2020 11:04:50 +0800
Subject: [PATCH] add model init api to compile df graph before exec

---
 mindspore/common/api.py                |   4 +
 mindspore/nn/cell.py                   |  25 ++++-
 mindspore/train/model.py               | 148 +++++++++++++++++++------
 tests/ut/python/train/test_training.py |  25 ++++-
 4 files changed, 162 insertions(+), 40 deletions(-)

diff --git a/mindspore/common/api.py b/mindspore/common/api.py
index 16df9a00e..b16db3156 100644
--- a/mindspore/common/api.py
+++ b/mindspore/common/api.py
@@ -383,6 +383,10 @@ class _Executor:
             obj.parameter_layout_dict = self._executor.get_parameter_layout(phase)
             obj.load_parameter_slice(params)
 
+        # set parallel inputs in sink mode
+        if auto_parallel_mode and (args and isinstance(args[0], Tensor) and args[0].virtual_flag):
+            obj.set_parallel_input_with_inputs(*args)
+
         # the following GE init process is not needed when use vm or ms backend
         if enable_ge:
             # decide whether to sink based on whether the inputs is virtual or not
diff --git a/mindspore/nn/cell.py b/mindspore/nn/cell.py
index 13dac375e..d5f697744 100755
--- a/mindspore/nn/cell.py
+++ b/mindspore/nn/cell.py
@@ -288,6 +288,15 @@ class Cell:
                 parallel_inputs_run.append(new_tensor)
         return tuple(parallel_inputs_run)
 
+    def set_parallel_input_with_inputs(self, *inputs):
+        """
+        Slice input tensors by parallel strategies, and set the sliced inputs to `_parallel_inputs_run`.
+
+        Args:
+            inputs (tuple): Inputs of the construct method.
+        """
+        self._parallel_inputs_run = self._load_inputs(*inputs)
+
     def _get_construct_inputs_number_and_name(self):
         """Compute self._construct_inputs_names and self._construct_inputs_num"""
         import inspect
@@ -304,6 +313,15 @@ class Cell:
         self._construct_inputs_names = self._construct_inputs_names[1:self._construct_inputs_num]
         self._construct_inputs_num = self._construct_inputs_num - 1
 
+    def compile(self, *inputs):
+        """
+        Compiles cell.
+
+        Args:
+            inputs (tuple): Input parameters.
+        """
+        _executor.compile(self, *inputs, phase=self.phase, auto_parallel_mode=self._auto_parallel_mode)
+
     def compile_and_run(self, *inputs):
         """
         Compiles and runs cell.
@@ -314,13 +332,14 @@ class Cell:
         Returns:
             Object, the result of executing.
""" - _, compile_flag = _executor.compile(self, *inputs, phase=self.phase, - auto_parallel_mode=self._auto_parallel_mode) + _executor.compile(self, *inputs, phase=self.phase, auto_parallel_mode=self._auto_parallel_mode) if self._auto_parallel_mode: - if inputs and isinstance(inputs[0], Tensor) and inputs[0].virtual_flag and (not compile_flag): + if inputs and isinstance(inputs[0], Tensor) and inputs[0].virtual_flag: + # get parallel inputs in sink mode, parallel inputs set in _executor.compile parallel_inputs_run = self._parallel_inputs_run else: + # set parallel inputs in normal mode self._parallel_inputs_run = self._load_inputs(*inputs) parallel_inputs_run = self._parallel_inputs_run return _executor(self, *parallel_inputs_run, phase=self.phase) diff --git a/mindspore/train/model.py b/mindspore/train/model.py index b4faecbe4..427e5a29c 100755 --- a/mindspore/train/model.py +++ b/mindspore/train/model.py @@ -217,6 +217,94 @@ class Model: scaling_sens /= self._device_number return scaling_sens + def _exec_preprocess(self, network, is_train, phase, dataset, dataset_sink_mode): + """Initializes dataset.""" + need_wrap = False + if dataset_sink_mode: + # remove later to deal with loop sink + if not hasattr(dataset, '__ME_INITED__') and context.get_context("device_target") == "Ascend" \ + and not context.get_context("enable_ge"): + need_wrap = True + + if not is_train: + dataset.__loop_size__ = 1 + + dataset_helper = DatasetHelper(dataset, dataset_sink_mode) + + # remove later to deal with loop sink + if need_wrap: + network = nn.DataWrapper(network, *(dataset_helper.types_shapes()), dataset.__ME_INITED__) + network.set_train(is_train) + network.phase = phase + + return dataset_helper, network + + def init(self, train_dataset=None, valid_dataset=None): + """ + Initializes compute graphs and data graphs with sink mode. + + Note: + Pre-init process only supports `GRAPH_MODE` and `Ascend` target currently. + + Args: + train_dataset (Dataset): A training dataset iterator. If define `train_dataset`, training graphs will be + initialized. Default: None. + valid_dataset (Dataset): A evaluating dataset iterator. If define `valid_dataset`, evaluation graphs will + be initialized, and `metrics` in `Model` can not be None. Default: None. 
+
+        Examples:
+            >>> train_dataset = get_train_dataset()
+            >>> valid_dataset = get_valid_dataset()
+            >>> net = Net()
+            >>> loss = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True)
+            >>> optim = Momentum(params=net.trainable_params(), learning_rate=0.1, momentum=0.9)
+            >>> model = Model(net, loss_fn=loss, optimizer=optim, metrics={'acc'})
+            >>> model.init(train_dataset, valid_dataset)
+            >>> model.train(2, train_dataset)
+            >>> model.eval(valid_dataset)
+        """
+        if context.get_context("mode") != context.GRAPH_MODE or context.get_context("device_target") != "Ascend":
+            raise RuntimeError('Pre-init process only supports GRAPH MODE and Ascend target currently.')
+
+        if not train_dataset and not valid_dataset:
+            raise ValueError('train_dataset and valid_dataset can not both be None or empty.')
+
+        _device_number_check(self._parallel_mode, self._device_number)
+
+        if train_dataset:
+            _parameter_broadcast_check(self._parallel_mode, self._parameter_broadcast)
+            self._train_network.set_train()
+            self._train_network.phase = 'train'
+
+            if self._parameter_broadcast:
+                self._train_network.set_broadcast_flag()
+
+            train_dataset_helper, train_network = self._exec_preprocess(self._train_network,
+                                                                         is_train=True,
+                                                                         phase='train',
+                                                                         dataset=train_dataset,
+                                                                         dataset_sink_mode=True)
+            self._train_network = train_network
+            for inputs in train_dataset_helper:
+                self._train_network.compile(*inputs)
+                break
+
+        if valid_dataset:
+            if not self._metric_fns:
+                raise RuntimeError('If `valid_dataset` is defined, the metric fn can not be None or empty.')
+
+            self._eval_network.set_train(False)
+            self._eval_network.phase = 'eval'
+            valid_dataset_helper, eval_network = self._exec_preprocess(self._eval_network,
+                                                                       is_train=False,
+                                                                       phase='eval',
+                                                                       dataset=valid_dataset,
+                                                                       dataset_sink_mode=True)
+            self._eval_network = eval_network
+            for inputs in valid_dataset_helper:
+                self._eval_network.compile(*inputs)
+                break
+
     def _train(self, epoch, train_dataset, callbacks=None, dataset_sink_mode=True):
         """
         Training.
@@ -277,21 +365,15 @@ class Model:
             list_callback (_ListCallback): Executor of callback list. Default: None.
             cb_params (_InternalCallbackParam): Callback parameters. Default: None.
         """
-        # remove later to deal with loop sink
-        need_wrap = False
-        if not hasattr(train_dataset, '__ME_INITED__') and context.get_context("device_target") == "Ascend" \
-                and not context.get_context("enable_ge"):
-            need_wrap = True
-
-        dataset_helper = DatasetHelper(train_dataset)
-        # remove later to deal with loop sink
-        if need_wrap:
-            self._train_network = nn.DataWrapper(self._train_network, *(dataset_helper.types_shapes()),
-                                                 train_dataset.__ME_INITED__)
-        cb_params.train_network = self._train_network
-        self._train_network.set_train()
-
+        dataset_helper, train_network = self._exec_preprocess(self._train_network,
+                                                               is_train=True,
+                                                               phase='train',
+                                                               dataset=train_dataset,
+                                                               dataset_sink_mode=True)
+        self._train_network = train_network
+        cb_params.train_network = self._train_network
         cb_params.cur_step_num = 0
+        loop_size = dataset_helper.loop_size()
         run_context = RunContext(cb_params)
         list_callback.begin(run_context)
 
@@ -331,7 +413,11 @@ class Model:
             list_callback (_ListCallback): Executor of callback list. Default: None.
            cb_params (_InternalCallbackParam): Callback parameters. Default: None.
""" - dataset_helper = DatasetHelper(train_dataset, dataset_sink_mode=False) + dataset_helper, _ = self._exec_preprocess(self._train_network, + is_train=True, + phase='train', + dataset=train_dataset, + dataset_sink_mode=False) cb_params.cur_step_num = 0 run_context = RunContext(cb_params) list_callback.begin(run_context) @@ -437,26 +523,15 @@ class Model: Returns: Dict, returns the loss value & metrics values for the model in test mode. """ - _device_number_check(self._parallel_mode, self._device_number) - run_context = RunContext(cb_params) - # remove later to deal with loop sink - need_wrap = False - if not hasattr(valid_dataset, '__ME_INITED__') and context.get_context("device_target") == "Ascend" \ - and not context.get_context("enable_ge"): - need_wrap = True - - valid_dataset.__loop_size__ = 1 - dataset_helper = DatasetHelper(valid_dataset) - - # remove later to deal with loop sink - if need_wrap: - self._eval_network = nn.DataWrapper(self._eval_network, *(dataset_helper.types_shapes()), - valid_dataset.__ME_INITED__) - self._eval_network.set_train(mode=False) - self._eval_network.phase = 'eval' - + dataset_helper, eval_network = self._exec_preprocess(self._eval_network, + is_train=False, + phase='eval', + dataset=valid_dataset, + dataset_sink_mode=True) + self._eval_network = eval_network + cb_params.eval_network = self._eval_network list_callback.begin(run_context) for inputs in dataset_helper: @@ -490,7 +565,11 @@ class Model: run_context = RunContext(cb_params) list_callback.begin(run_context) - dataset_helper = DatasetHelper(valid_dataset, dataset_sink_mode=False) + dataset_helper, _ = self._exec_preprocess(self._eval_network, + is_train=False, + phase='eval', + dataset=valid_dataset, + dataset_sink_mode=False) for next_element in dataset_helper: cb_params.cur_step_num += 1 list_callback.step_begin(run_context) @@ -532,6 +611,7 @@ class Model: >>> model.eval(dataset) """ check_bool(dataset_sink_mode) + _device_number_check(self._parallel_mode, self._device_number) if not self._metric_fns: raise ValueError("metric fn can not be None or empty.") diff --git a/tests/ut/python/train/test_training.py b/tests/ut/python/train/test_training.py index b56a0868c..5d7c90043 100644 --- a/tests/ut/python/train/test_training.py +++ b/tests/ut/python/train/test_training.py @@ -68,12 +68,12 @@ class LossNet(nn.Cell): return out -def get_model(): +def get_model(metrics=None): """ get_model """ net = Net() loss = nn.SoftmaxCrossEntropyWithLogits() optim = Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9) - model = Model(net, loss_fn=loss, optimizer=optim, metrics=None) + model = Model(net, loss_fn=loss, optimizer=optim, metrics=metrics) return model @@ -215,8 +215,27 @@ def test_model_build_abnormal_string(): assert err -def test_model_init_error(): +def test_model_init(): """ test_model_init_error """ + train_dataset = get_dataset() + eval_dataset = get_dataset() + + with pytest.raises(RuntimeError): + context.set_context(mode=context.PYNATIVE_MODE) + get_model().init(train_dataset) + + context.set_context(mode=context.GRAPH_MODE) + get_model().init(train_dataset) + get_model(metrics={'acc'}).init(eval_dataset) + + with pytest.raises(RuntimeError): + get_model().init(train_dataset, eval_dataset) + with pytest.raises(ValueError): + get_model().init() + + +def test_init_model_error(): + """ test_init_model_error """ net = nn.ReLU() loss = nn.SoftmaxCrossEntropyWithLogits() with pytest.raises(KeyError): -- GitLab