提交 7b99a1cb 编写于 作者: W wangnan39@huawei.com

fix bug in model predict and eval

上级 2d31ae97
......@@ -108,6 +108,7 @@ class Model:
self._train_network = self._build_train_network()
self._build_eval_network(metrics, eval_network, eval_indexes)
self._build_predict_network()
def _check_kwargs(self, kwargs):
for arg in kwargs:
......@@ -153,6 +154,12 @@ class Model:
self._eval_network = nn.WithEvalCell(self._network, self._loss_fn)
self._eval_indexes = [0, 1, 2]
def _build_predict_network(self):
"""Build the network for prediction."""
self._predict_network = self._network
if self._parallel_mode in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL):
self._predict_network = _VirtualDatasetCell(self._network)
def _clear_metrics(self):
"""Clear metrics local values."""
for metric in self._metric_fns.values():
......@@ -466,6 +473,7 @@ class Model:
dataset_helper = DatasetHelper(valid_dataset, dataset_sink_mode=False)
for next_element in dataset_helper:
cb_params.cur_step_num += 1
list_callback.step_begin(run_context)
outputs = self._eval_network(*next_element)
cb_params.net_outputs = outputs
......@@ -543,12 +551,9 @@ class Model:
>>> model = Model(Net())
>>> model.predict(input_data)
"""
if self._parallel_mode in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL):
self._network = _VirtualDatasetCell(self._network)
self._network.set_train(False)
self._predict_network.set_train(False)
check_input_data(*predict_data, data_class=Tensor)
result = self._network(*predict_data)
result = self._predict_network(*predict_data)
check_output_data(result)
return result
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册