提交 b8a65d43 编写于 作者: L LDOUBLEV

fix eval bug

上级 0742f5c5
......@@ -55,9 +55,9 @@ class DetMetric(object):
result = self.evaluator.evaluate_image(gt_info_list, det_info_list)
self.results.append(result)
metircs = self.evaluator.combine_results(self.results)
self.reset()
return metircs
# metircs = self.evaluator.combine_results(self.results)
# self.reset()
# return metircs
def get_metric(self):
"""
......
......@@ -24,8 +24,8 @@ from .cls_metric import ClsMetric
class DistillationMetric(object):
def __init__(self,
key=None,
base_metric_name="RecMetric",
main_indicator='acc',
base_metric_name=None,
main_indicator=None,
**kwargs):
self.main_indicator = main_indicator
self.key = key
......@@ -42,16 +42,13 @@ class DistillationMetric(object):
main_indicator=self.main_indicator, **self.kwargs)
self.metrics[key].reset()
def __call__(self, preds, *args, **kwargs):
def __call__(self, preds, batch, **kwargs):
assert isinstance(preds, dict)
if self.metrics is None:
self._init_metrcis(preds)
output = dict()
for key in preds:
metric = self.metrics[key].__call__(preds[key], *args, **kwargs)
for sub_key in metric:
output["{}_{}".format(key, sub_key)] = metric[sub_key]
return output
self.metrics[key].__call__(preds[key], batch, **kwargs)
def get_metric(self):
"""
......
......@@ -46,7 +46,7 @@ class DistillationModel(nn.Layer):
pretrained = model_config.pop("pretrained")
model = BaseModel(model_config)
if pretrained is not None:
load_pretrained_params(model, pretrained)
model = load_pretrained_params(model, pretrained)
if freeze_params:
for param in model.parameters():
param.trainable = False
......
......@@ -189,29 +189,27 @@ class DBPostProcess(object):
return boxes_batch
class DistillationDBPostProcess(DBPostProcess):
def __init__(self,
model_name=["student"],
class DistillationDBPostProcess(object):
def __init__(self, model_name=["student"],
key=None,
thresh=0.3,
box_thresh=0.7,
box_thresh=0.6,
max_candidates=1000,
unclip_ratio=2.0,
unclip_ratio=1.5,
use_dilation=False,
score_mode="fast",
**kwargs):
super().__init__()
if not isinstance(model_name, list):
model_name = [model_name]
self.model_name = model_name
self.key = key
self.post_process = DBPostProcess(thresh=thresh,
box_thresh=box_thresh,
max_candidates=max_candidates,
unclip_ratio=unclip_ratio,
use_dilation=use_dilation,
score_mode=score_mode)
def __call__(self, predicts, shape_list):
results = {}
for name in self.model_name:
pred = predicts[name]
if self.key is not None:
pred = pred[self.key]
results[name] = super().__call__(pred, shape_list=shape_list)
for k in self.model_name:
results[k] = self.post_process(predicts[k], shape_list=shape_list)
return results
......@@ -136,7 +136,7 @@ def load_pretrained_params(model, path):
)
model.set_state_dict(new_state_dict)
print(f"load pretrain successful from {path}")
return True
return model
def save_model(model,
optimizer,
......
......@@ -27,7 +27,7 @@ from ppocr.data import build_dataloader
from ppocr.modeling.architectures import build_model
from ppocr.postprocess import build_post_process
from ppocr.metrics import build_metric
from ppocr.utils.save_load import init_model
from ppocr.utils.save_load import init_model, load_pretrained_params
from ppocr.utils.utility import print_dict
import tools.program as program
......@@ -59,7 +59,8 @@ def main():
model_type = config['Architecture']['model_type']
else:
model_type = None
best_model_dict = init_model(config, model)
best_model_dict = init_model(config, model, model_type)
if len(best_model_dict):
logger.info('metric in ckpt ***************')
for k, v in best_model_dict.items():
......@@ -70,7 +71,7 @@ def main():
# start eval
metric = program.eval(model, valid_dataloader, post_process_class,
eval_class, model_type, use_srn)
eval_class, model_type, use_srn)
logger.info('metric eval ***************')
for k, v in metric.items():
logger.info('{}:{}'.format(k, v))
......
......@@ -374,6 +374,7 @@ def eval(model,
eval_class(preds, batch)
else:
post_result = post_process_class(preds, batch[1])
# post_result = post_result_["Student"]
eval_class(post_result, batch)
pbar.update(1)
total_frame += len(images)
......
......@@ -97,8 +97,8 @@ def main(config, device, logger, vdl_writer):
# build metric
eval_class = build_metric(config['Metric'])
# load pretrain model
pre_best_model_dict = load_dygraph_params(config, model, logger, optimizer)
#pre_best_model_dict = load_dygraph_params(config, model, logger, optimizer)
pre_best_model_dict = {}
logger.info('train dataloader has {} iters'.format(len(train_dataloader)))
if valid_dataloader is not None:
logger.info('valid dataloader has {} iters'.format(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册