diff --git a/mindinsight/datavisual/processors/train_task_manager.py b/mindinsight/datavisual/processors/train_task_manager.py index 03cf16fbde213a2e1ff1012662fff0dba32c8933..b1aeb876494d74ffea01083c61ca8552a73d74f7 100644 --- a/mindinsight/datavisual/processors/train_task_manager.py +++ b/mindinsight/datavisual/processors/train_task_manager.py @@ -14,6 +14,7 @@ # ============================================================================ """Train task manager.""" +from mindinsight.utils.exceptions import ParamTypeError from mindinsight.datavisual.common.log import logger from mindinsight.datavisual.common import exceptions from mindinsight.datavisual.common.enums import PluginNameEnum @@ -141,9 +142,20 @@ class TrainTaskManager(BaseProcessor): Returns: dict, indicates train job ID and its current cache status. + + Raises: + ParamTypeError, if the given train_ids parameter is not in valid type. """ + if not isinstance(train_ids, list): + logger.error("train_ids must be list.") + raise ParamTypeError('train_ids', list) + cache_result = [] for train_id in train_ids: + if not isinstance(train_id, str): + logger.error("train_id must be str.") + raise ParamTypeError('train_id', str) + try: train_job = self._data_manager.get_train_job(train_id) except exceptions.TrainJobNotExistError: