提交 b8a65d43 编写于 作者: L LDOUBLEV

fix eval bug

上级 0742f5c5
...@@ -55,9 +55,9 @@ class DetMetric(object): ...@@ -55,9 +55,9 @@ class DetMetric(object):
result = self.evaluator.evaluate_image(gt_info_list, det_info_list) result = self.evaluator.evaluate_image(gt_info_list, det_info_list)
self.results.append(result) self.results.append(result)
metircs = self.evaluator.combine_results(self.results) # metircs = self.evaluator.combine_results(self.results)
self.reset() # self.reset()
return metircs # return metircs
def get_metric(self): def get_metric(self):
""" """
......
...@@ -24,8 +24,8 @@ from .cls_metric import ClsMetric ...@@ -24,8 +24,8 @@ from .cls_metric import ClsMetric
class DistillationMetric(object): class DistillationMetric(object):
def __init__(self, def __init__(self,
key=None, key=None,
base_metric_name="RecMetric", base_metric_name=None,
main_indicator='acc', main_indicator=None,
**kwargs): **kwargs):
self.main_indicator = main_indicator self.main_indicator = main_indicator
self.key = key self.key = key
...@@ -42,16 +42,13 @@ class DistillationMetric(object): ...@@ -42,16 +42,13 @@ class DistillationMetric(object):
main_indicator=self.main_indicator, **self.kwargs) main_indicator=self.main_indicator, **self.kwargs)
self.metrics[key].reset() self.metrics[key].reset()
def __call__(self, preds, *args, **kwargs): def __call__(self, preds, batch, **kwargs):
assert isinstance(preds, dict) assert isinstance(preds, dict)
if self.metrics is None: if self.metrics is None:
self._init_metrcis(preds) self._init_metrcis(preds)
output = dict() output = dict()
for key in preds: for key in preds:
metric = self.metrics[key].__call__(preds[key], *args, **kwargs) self.metrics[key].__call__(preds[key], batch, **kwargs)
for sub_key in metric:
output["{}_{}".format(key, sub_key)] = metric[sub_key]
return output
def get_metric(self): def get_metric(self):
""" """
......
...@@ -46,7 +46,7 @@ class DistillationModel(nn.Layer): ...@@ -46,7 +46,7 @@ class DistillationModel(nn.Layer):
pretrained = model_config.pop("pretrained") pretrained = model_config.pop("pretrained")
model = BaseModel(model_config) model = BaseModel(model_config)
if pretrained is not None: if pretrained is not None:
load_pretrained_params(model, pretrained) model = load_pretrained_params(model, pretrained)
if freeze_params: if freeze_params:
for param in model.parameters(): for param in model.parameters():
param.trainable = False param.trainable = False
......
...@@ -189,29 +189,27 @@ class DBPostProcess(object): ...@@ -189,29 +189,27 @@ class DBPostProcess(object):
return boxes_batch return boxes_batch
class DistillationDBPostProcess(DBPostProcess): class DistillationDBPostProcess(object):
def __init__(self, def __init__(self, model_name=["student"],
model_name=["student"],
key=None, key=None,
thresh=0.3, thresh=0.3,
box_thresh=0.7, box_thresh=0.6,
max_candidates=1000, max_candidates=1000,
unclip_ratio=2.0, unclip_ratio=1.5,
use_dilation=False, use_dilation=False,
score_mode="fast", score_mode="fast",
**kwargs): **kwargs):
super().__init__()
if not isinstance(model_name, list):
model_name = [model_name]
self.model_name = model_name self.model_name = model_name
self.key = key self.key = key
self.post_process = DBPostProcess(thresh=thresh,
box_thresh=box_thresh,
max_candidates=max_candidates,
unclip_ratio=unclip_ratio,
use_dilation=use_dilation,
score_mode=score_mode)
def __call__(self, predicts, shape_list): def __call__(self, predicts, shape_list):
results = {} results = {}
for name in self.model_name: for k in self.model_name:
pred = predicts[name] results[k] = self.post_process(predicts[k], shape_list=shape_list)
if self.key is not None:
pred = pred[self.key]
results[name] = super().__call__(pred, shape_list=shape_list)
return results return results
...@@ -136,7 +136,7 @@ def load_pretrained_params(model, path): ...@@ -136,7 +136,7 @@ def load_pretrained_params(model, path):
) )
model.set_state_dict(new_state_dict) model.set_state_dict(new_state_dict)
print(f"load pretrain successful from {path}") print(f"load pretrain successful from {path}")
return True return model
def save_model(model, def save_model(model,
optimizer, optimizer,
......
...@@ -27,7 +27,7 @@ from ppocr.data import build_dataloader ...@@ -27,7 +27,7 @@ from ppocr.data import build_dataloader
from ppocr.modeling.architectures import build_model from ppocr.modeling.architectures import build_model
from ppocr.postprocess import build_post_process from ppocr.postprocess import build_post_process
from ppocr.metrics import build_metric from ppocr.metrics import build_metric
from ppocr.utils.save_load import init_model from ppocr.utils.save_load import init_model, load_pretrained_params
from ppocr.utils.utility import print_dict from ppocr.utils.utility import print_dict
import tools.program as program import tools.program as program
...@@ -59,7 +59,8 @@ def main(): ...@@ -59,7 +59,8 @@ def main():
model_type = config['Architecture']['model_type'] model_type = config['Architecture']['model_type']
else: else:
model_type = None model_type = None
best_model_dict = init_model(config, model)
best_model_dict = init_model(config, model, model_type)
if len(best_model_dict): if len(best_model_dict):
logger.info('metric in ckpt ***************') logger.info('metric in ckpt ***************')
for k, v in best_model_dict.items(): for k, v in best_model_dict.items():
......
...@@ -374,6 +374,7 @@ def eval(model, ...@@ -374,6 +374,7 @@ def eval(model,
eval_class(preds, batch) eval_class(preds, batch)
else: else:
post_result = post_process_class(preds, batch[1]) post_result = post_process_class(preds, batch[1])
# post_result = post_result_["Student"]
eval_class(post_result, batch) eval_class(post_result, batch)
pbar.update(1) pbar.update(1)
total_frame += len(images) total_frame += len(images)
......
...@@ -97,8 +97,8 @@ def main(config, device, logger, vdl_writer): ...@@ -97,8 +97,8 @@ def main(config, device, logger, vdl_writer):
# build metric # build metric
eval_class = build_metric(config['Metric']) eval_class = build_metric(config['Metric'])
# load pretrain model # load pretrain model
pre_best_model_dict = load_dygraph_params(config, model, logger, optimizer) #pre_best_model_dict = load_dygraph_params(config, model, logger, optimizer)
pre_best_model_dict = {}
logger.info('train dataloader has {} iters'.format(len(train_dataloader))) logger.info('train dataloader has {} iters'.format(len(train_dataloader)))
if valid_dataloader is not None: if valid_dataloader is not None:
logger.info('valid dataloader has {} iters'.format( logger.info('valid dataloader has {} iters'.format(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册