提交 31a51c17 编写于 作者: M mindspore-ci-bot 提交者: Gitee

!5227 dataset_return_single_value

Merge pull request !5227 from lijiaqi/dataset_return_single_value
......@@ -37,6 +37,16 @@ from .dataset_helper import DatasetHelper
from . import amp
def _transfer_tensor_to_tuple(inputs):
"""
If the input is a tensor, convert it to a tuple. If not, the output is unchanged.
"""
if isinstance(inputs, Tensor):
return (inputs,)
return inputs
class Model:
"""
High-Level API for Training or Testing.
......@@ -476,6 +486,7 @@ class Model:
for next_element in dataset_helper:
len_element = len(next_element)
next_element = _transfer_tensor_to_tuple(next_element)
if self._loss_fn and len_element != 2:
raise ValueError("when loss_fn is not None, train_dataset should"
"return two elements, but got {}".format(len_element))
......@@ -630,6 +641,7 @@ class Model:
for next_element in dataset_helper:
cb_params.cur_step_num += 1
list_callback.step_begin(run_context)
next_element = _transfer_tensor_to_tuple(next_element)
outputs = self._eval_network(*next_element)
cb_params.net_outputs = outputs
list_callback.step_end(run_context)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册