diff --git a/tools/eval.py b/tools/eval.py index 07181ee75d44e3f5f0676f55f89330b215ef5220..16cfe532aae49ce98bc9503ca73e009bf206caa7 100755 --- a/tools/eval.py +++ b/tools/eval.py @@ -23,12 +23,8 @@ __dir__ = os.path.dirname(os.path.abspath(__file__)) sys.path.append(__dir__) sys.path.append(os.path.abspath(os.path.join(__dir__, '..'))) -import paddle -# paddle.manual_seed(2) - -from ppocr.utils.logging import get_logger from ppocr.data import build_dataloader -from ppocr.modeling import build_model +from ppocr.modeling.architectures import build_model from ppocr.postprocess import build_post_process from ppocr.metrics import build_metric from ppocr.utils.save_load import init_model @@ -39,8 +35,7 @@ import tools.program as program def main(): global_config = config['Global'] # build dataloader - eval_loader, _ = build_dataloader(config['EVAL'], device, False, - global_config) + valid_dataloader = build_dataloader(config, 'Eval', device, logger) # build post process post_process_class = build_post_process(config['PostProcess'], @@ -63,16 +58,13 @@ def main(): eval_class = build_metric(config['Metric']) # start eval - metirc = program.eval(model, eval_loader, post_process_class, eval_class) + metirc = program.eval(model, valid_dataloader, post_process_class, + eval_class) logger.info('metric eval ***************') for k, v in metirc.items(): logger.info('{}:{}'.format(k, v)) if __name__ == '__main__': - device, config = program.preprocess() - paddle.disable_static(device) - - logger = get_logger() - print_dict(config, logger) + config, device, logger, vdl_writer = program.preprocess() main()