From b6a9f5d2daea025a96084e9f100684065c7185ae Mon Sep 17 00:00:00 2001 From: WenmuZhou Date: Mon, 27 Sep 2021 19:43:36 +0800 Subject: [PATCH] add pse to windows_not_support_list --- ppocr/postprocess/__init__.py | 7 ++++++- tools/program.py | 29 ++++++++++++++++++----------- 2 files changed, 24 insertions(+), 12 deletions(-) diff --git a/ppocr/postprocess/__init__.py b/ppocr/postprocess/__init__.py index 3eb5e28d..80d926a2 100644 --- a/ppocr/postprocess/__init__.py +++ b/ppocr/postprocess/__init__.py @@ -18,8 +18,10 @@ from __future__ import print_function from __future__ import unicode_literals import copy +import platform __all__ = ['build_post_process'] +from ppocr.utils.logging import get_logger from .db_postprocess import DBPostProcess, DistillationDBPostProcess from .east_postprocess import EASTPostProcess @@ -28,7 +30,10 @@ from .rec_postprocess import CTCLabelDecode, AttnLabelDecode, SRNLabelDecode, Di TableLabelDecode, SARLabelDecode from .cls_postprocess import ClsPostProcess from .pg_postprocess import PGPostProcess -from .pse_postprocess import PSEPostProcess + +if platform.system() != "Windows": + # pse is not support in Windows + from .pse_postprocess import PSEPostProcess def build_post_process(config, global_config=None): diff --git a/tools/program.py b/tools/program.py index f484cf4a..10eb246a 100755 --- a/tools/program.py +++ b/tools/program.py @@ -395,6 +395,18 @@ def preprocess(is_train=False): config = load_config(FLAGS.config) merge_config(FLAGS.opt) + if is_train: + # save_config + save_model_dir = config['Global']['save_model_dir'] + os.makedirs(save_model_dir, exist_ok=True) + with open(os.path.join(save_model_dir, 'config.yml'), 'w') as f: + yaml.dump( + dict(config), f, default_flow_style=False, sort_keys=False) + log_file = '{}/train.log'.format(save_model_dir) + else: + log_file = None + logger = get_logger(name='root', log_file=log_file) + # check if set use_gpu=True in paddlepaddle cpu version use_gpu = config['Global']['use_gpu'] check_gpu(use_gpu) @@ -404,22 +416,17 @@ def preprocess(is_train=False): 'EAST', 'DB', 'SAST', 'Rosetta', 'CRNN', 'STARNet', 'RARE', 'SRN', 'CLS', 'PGNet', 'Distillation', 'NRTR', 'TableAttn', 'SAR', 'PSE' ] + windows_not_support_list = ['PSE'] + if platform.system() == "Windows" and alg in windows_not_support_list: + logger.warning('{} is not support in Windows now'.format( + windows_not_support_list)) + sys.exit() device = 'gpu:{}'.format(dist.ParallelEnv().dev_id) if use_gpu else 'cpu' device = paddle.set_device(device) config['Global']['distributed'] = dist.get_world_size() != 1 - if is_train: - # save_config - save_model_dir = config['Global']['save_model_dir'] - os.makedirs(save_model_dir, exist_ok=True) - with open(os.path.join(save_model_dir, 'config.yml'), 'w') as f: - yaml.dump( - dict(config), f, default_flow_style=False, sort_keys=False) - log_file = '{}/train.log'.format(save_model_dir) - else: - log_file = None - logger = get_logger(name='root', log_file=log_file) + if config['Global']['use_visualdl']: from visualdl import LogWriter save_model_dir = config['Global']['save_model_dir'] -- GitLab