diff --git a/mindspore/train/model.py b/mindspore/train/model.py index 324b09b453a77991a2b6de68ca701c51b473627f..7e28eb97048910a4b3e53a91432cbaed4ca4254a 100755 --- a/mindspore/train/model.py +++ b/mindspore/train/model.py @@ -37,6 +37,16 @@ from .dataset_helper import DatasetHelper from . import amp +def _transfer_tensor_to_tuple(inputs): + """ + If the input is a tensor, convert it to a tuple. If not, the output is unchanged. + """ + if isinstance(inputs, Tensor): + return (inputs,) + + return inputs + + class Model: """ High-Level API for Training or Testing. @@ -476,6 +486,7 @@ class Model: for next_element in dataset_helper: len_element = len(next_element) + next_element = _transfer_tensor_to_tuple(next_element) if self._loss_fn and len_element != 2: raise ValueError("when loss_fn is not None, train_dataset should" "return two elements, but got {}".format(len_element)) @@ -630,6 +641,7 @@ class Model: for next_element in dataset_helper: cb_params.cur_step_num += 1 list_callback.step_begin(run_context) + next_element = _transfer_tensor_to_tuple(next_element) outputs = self._eval_network(*next_element) cb_params.net_outputs = outputs list_callback.step_end(run_context)