提交 f2054c8d 编写于 作者: W wuzewu

Update exception notes

上级 e57927b9
...@@ -124,7 +124,7 @@ class BaseTask(object): ...@@ -124,7 +124,7 @@ class BaseTask(object):
def enter_phase(self, phase: str): def enter_phase(self, phase: str):
if phase not in ['train', 'val', 'dev', 'test', 'predict', 'inference']: if phase not in ['train', 'val', 'dev', 'test', 'predict', 'inference']:
raise RuntimeError() raise RuntimeError('Unknown phase {}.'.format(phase))
if phase in ['val', 'dev']: if phase in ['val', 'dev']:
phase = 'dev' phase = 'dev'
elif phase in ['predict', 'inference']: elif phase in ['predict', 'inference']:
...@@ -289,7 +289,7 @@ class BaseTask(object): ...@@ -289,7 +289,7 @@ class BaseTask(object):
@property @property
def loss(self) -> paddle.static.Variable: def loss(self) -> paddle.static.Variable:
if self.is_predict_phase: if self.is_predict_phase:
raise RuntimeError() raise RuntimeError('Loss cannot be obtained in the prediction phase.')
if not self.env.is_inititalized: if not self.env.is_inititalized:
self._build_env() self._build_env()
...@@ -298,7 +298,7 @@ class BaseTask(object): ...@@ -298,7 +298,7 @@ class BaseTask(object):
@property @property
def labels(self) -> List[paddle.static.Variable]: def labels(self) -> List[paddle.static.Variable]:
if self.is_predict_phase: if self.is_predict_phase:
raise RuntimeError() raise RuntimeError('Labels cannot be obtained in the prediction phase.')
if not self.env.is_inititalized: if not self.env.is_inititalized:
self._build_env() self._build_env()
...@@ -313,7 +313,7 @@ class BaseTask(object): ...@@ -313,7 +313,7 @@ class BaseTask(object):
@property @property
def metrics(self) -> List[str]: def metrics(self) -> List[str]:
if self.is_predict_phase: if self.is_predict_phase:
raise RuntimeError() raise RuntimeError('Metrics cannot be obtained in the prediction phase.')
if not self.env.is_inititalized: if not self.env.is_inititalized:
self._build_env() self._build_env()
......
...@@ -314,11 +314,12 @@ class Trainer(object): ...@@ -314,11 +314,12 @@ class Trainer(object):
# process result # process result
if not isinstance(result, dict): if not isinstance(result, dict):
raise RuntimeError() raise RuntimeError('The return value of `trainning_step` in {} is not a dict'.format(self.model.__class__))
loss = result.get('loss', None) loss = result.get('loss', None)
if not loss: if not loss:
raise RuntimeError() raise RuntimeError('Cannot find loss attribute in the return value of `trainning_step` of {}'.format(
self.model.__class__))
metrics = result.get('metrics', {}) metrics = result.get('metrics', {})
......
...@@ -34,7 +34,7 @@ class HubServer(object): ...@@ -34,7 +34,7 @@ class HubServer(object):
elif source_type == 'git': elif source_type == 'git':
source = GitSource(url) source = GitSource(url)
else: else:
raise RuntimeError() raise RuntimeError('Unknown source type {}.'.format(source_type))
return source return source
def _get_source_key(self, url: str): def _get_source_key(self, url: str):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册