From 7b99a1cb2a8cffa4ee7dd4f0c6c42551b38a3929 Mon Sep 17 00:00:00 2001 From: "wangnan39@huawei.com" Date: Mon, 20 Apr 2020 20:10:21 +0800 Subject: [PATCH] fix bug in model predict and eval --- mindspore/train/model.py | 15 ++++++++++----- 1 file changed, 10 insertions(+), 5 deletions(-) diff --git a/mindspore/train/model.py b/mindspore/train/model.py index 46e4f421f..3391cc7f3 100755 --- a/mindspore/train/model.py +++ b/mindspore/train/model.py @@ -108,6 +108,7 @@ class Model: self._train_network = self._build_train_network() self._build_eval_network(metrics, eval_network, eval_indexes) + self._build_predict_network() def _check_kwargs(self, kwargs): for arg in kwargs: @@ -153,6 +154,12 @@ class Model: self._eval_network = nn.WithEvalCell(self._network, self._loss_fn) self._eval_indexes = [0, 1, 2] + def _build_predict_network(self): + """Build the network for prediction.""" + self._predict_network = self._network + if self._parallel_mode in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL): + self._predict_network = _VirtualDatasetCell(self._network) + def _clear_metrics(self): """Clear metrics local values.""" for metric in self._metric_fns.values(): @@ -466,6 +473,7 @@ class Model: dataset_helper = DatasetHelper(valid_dataset, dataset_sink_mode=False) for next_element in dataset_helper: + cb_params.cur_step_num += 1 list_callback.step_begin(run_context) outputs = self._eval_network(*next_element) cb_params.net_outputs = outputs @@ -543,12 +551,9 @@ class Model: >>> model = Model(Net()) >>> model.predict(input_data) """ - if self._parallel_mode in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL): - self._network = _VirtualDatasetCell(self._network) - - self._network.set_train(False) + self._predict_network.set_train(False) check_input_data(*predict_data, data_class=Tensor) - result = self._network(*predict_data) + result = self._predict_network(*predict_data) check_output_data(result) return result -- GitLab