未验证 提交 af6dd632 编写于 作者: H haoyuying 提交者: GitHub

Raise an exception when the specified module does not support evaluation.

上级 30aace46
......@@ -68,6 +68,7 @@ class Trainer(object):
if not isinstance(self.model, paddle.nn.Layer):
raise TypeError('The model {} is not a `paddle.nn.Layer` object.'.format(self.model.__name__))
if self.local_rank == 0 and not os.path.exists(self.checkpoint_dir):
os.makedirs(self.checkpoint_dir)
......@@ -178,6 +179,9 @@ class Trainer(object):
collate_fn(callable): function to generate mini-batch data by merging the sample list.
None for only stack each fields of sample in axis 0(same as :attr::`np.stack(..., axis=0)`). Default None
'''
if eval_dataset is not None and not hasattr(self.model, 'validation_step'):
raise NotImplementedError('The specified finetuning model does not support evaluation.')
batch_sampler = paddle.io.DistributedBatchSampler(
train_dataset, batch_size=batch_size, shuffle=True, drop_last=False)
loader = paddle.io.DataLoader(
......@@ -298,6 +302,7 @@ class Trainer(object):
with logger.processing('Evaluation on validation dataset'):
for batch_idx, batch in enumerate(loader):
result = self.validation_step(batch, batch_idx)
loss = result.get('loss', None)
metrics = result.get('metrics', {})
bs = batch[0].shape[0]
......
......@@ -643,21 +643,8 @@ class ImageSegmentationModule(ImageServing, RunModule):
Returns:
results(dict): The model outputs, such as loss.
'''
return self.validation_step(batch, batch_idx)
def validation_step(self, batch: List[paddle.Tensor], batch_idx: int) -> dict:
"""
One step for validation, which should be called as forward computation.
Args:
batch(list[paddle.Tensor]): The one batch data, which contains images and labels.
batch_idx(int): The index of batch.
Returns:
results(dict) : The model outputs, such as metrics.
"""
'''
label = batch[1].astype('int64')
criterionCE = nn.loss.CrossEntropyLoss()
......@@ -666,10 +653,12 @@ class ImageSegmentationModule(ImageServing, RunModule):
for i in range(len(logits)):
logit = logits[i]
if logit.shape[-2:] != label.shape[-2:]:
logit = F.resize_bilinear(logit, label.shape[-2:])
logit = F.interpolate(logit, label.shape[-2:], mode='bilinear')
logit = logit.transpose([0,2,3,1])
loss_ce = criterionCE(logit, label)
loss += loss_ce / len(logits)
return {"loss": loss}
def predict(self, images: Union[str, np.ndarray], batch_size: int = 1, visualization: bool = True, save_path: str = 'seg_result') -> List[np.ndarray]:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册