Unverified commit 2d44a71b, authored by Lin Manhui and committed by GitHub

Toward Devkit Consistency (#10150)

* Accommodate UAPI

* Fix signal handler

* Save model.pdopt

* Change variable name

* Update vdl dir
Parent 15abbcc4
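The recurring change across the entry points in this commit is that signal-handler registration moves out of build_dataloader and into an explicit set_signal_handlers() call made once before any dataloader is built. A minimal sketch of the new calling pattern (config, device, and logger are placeholders, not taken from this diff):

```python
# Minimal sketch of the pattern this commit applies to every training/eval
# entry point: register the SIGINT/SIGTERM handlers once, explicitly, before
# building dataloaders. `config`, `device`, and `logger` are placeholders.
from ppocr.data import build_dataloader, set_signal_handlers


def main(config, device, logger):
    set_signal_handlers()  # new explicit call (was implicit in build_dataloader)
    train_loader = build_dataloader(config, 'Train', device, logger)
    eval_loader = build_dataloader(config, 'Eval', device, logger)
    return train_loader, eval_loader
```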
@@ -25,7 +25,7 @@ sys.path.append(os.path.join(__dir__, '..', '..', '..'))
sys.path.append(os.path.join(__dir__, '..', '..', '..', 'tools'))
import paddle
from ppocr.data import build_dataloader
from ppocr.data import build_dataloader, set_signal_handlers
from ppocr.modeling.architectures import build_model
from ppocr.postprocess import build_post_process
@@ -39,6 +39,7 @@ def main(config, device, logger, vdl_writer):
global_config = config['Global']
# build dataloader
set_signal_handlers()
valid_dataloader = build_dataloader(config, 'Eval', device, logger)
# build post process
@@ -26,7 +26,7 @@ sys.path.append(os.path.join(__dir__, '..', '..', '..', 'tools'))
import paddle
import paddle.distributed as dist
from ppocr.data import build_dataloader
from ppocr.data import build_dataloader, set_signal_handlers
from ppocr.modeling.architectures import build_model
from ppocr.losses import build_loss
from ppocr.optimizer import build_optimizer
@@ -57,6 +57,7 @@ def main(config, device, logger, vdl_writer):
global_config = config['Global']
# build dataloader
set_signal_handlers()
train_dataloader = build_dataloader(config, 'Train', device, logger)
if config['Eval']:
valid_dataloader = build_dataloader(config, 'Eval', device, logger)
@@ -34,7 +34,7 @@ from tools.program import load_config, merge_config, ArgsParser
from ppocr.metrics import build_metric
import tools.program as program
from paddleslim.dygraph.quant import QAT
from ppocr.data import build_dataloader
from ppocr.data import build_dataloader, set_signal_handlers
from tools.export_model import export_single_model
@@ -134,6 +134,7 @@ def main():
eval_class = build_metric(config['Metric'])
# build dataloader
set_signal_handlers()
valid_dataloader = build_dataloader(config, 'Eval', device, logger)
use_srn = config['Architecture']['algorithm'] == "SRN"
@@ -31,7 +31,7 @@ import paddle.distributed as dist
paddle.seed(2)
from ppocr.data import build_dataloader
from ppocr.data import build_dataloader, set_signal_handlers
from ppocr.modeling.architectures import build_model
from ppocr.losses import build_loss
from ppocr.optimizer import build_optimizer
@@ -95,6 +95,7 @@ def main(config, device, logger, vdl_writer):
global_config = config['Global']
# build dataloader
set_signal_handlers()
train_dataloader = build_dataloader(config, 'Train', device, logger)
if config['Eval']:
valid_dataloader = build_dataloader(config, 'Eval', device, logger)
@@ -31,7 +31,7 @@ import paddle.distributed as dist
paddle.seed(2)
from ppocr.data import build_dataloader
from ppocr.data import build_dataloader, set_signal_handlers
from ppocr.modeling.architectures import build_model
from ppocr.losses import build_loss
from ppocr.optimizer import build_optimizer
@@ -117,6 +117,7 @@ def main(config, device, logger, vdl_writer):
global_config = config['Global']
# build dataloader
set_signal_handlers()
config['Train']['loader']['num_workers'] = 0
is_layoutxlm_ser = config['Architecture']['model_type'] =='kie' and config['Architecture']['Backbone']['name'] == 'LayoutXLMForSer'
train_dataloader = build_dataloader(config, 'Train', device, logger)
@@ -39,7 +39,7 @@ from ppocr.data.pgnet_dataset import PGDataSet
from ppocr.data.pubtab_dataset import PubTabDataSet
from ppocr.data.multi_scale_sampler import MultiScaleSampler
__all__ = ['build_dataloader', 'transform', 'create_operators']
__all__ = ['build_dataloader', 'transform', 'create_operators', 'set_signal_handlers']
def term_mp(sig_num, frame):
@@ -51,6 +51,21 @@ def term_mp(sig_num, frame):
    os.killpg(pgid, signal.SIGKILL)


def set_signal_handlers():
    pid = os.getpid()
    pgid = os.getpgid(os.getpid())
    # XXX: `term_mp` kills all processes in the process group, which in
    # some cases includes the parent process of current process and may
    # cause unexpected results. To solve this problem, we set signal
    # handlers only when current process is the group leader. In the
    # future, it would be better to consider killing only descendants of
    # the current process.
    if pid == pgid:
        # support exit using ctrl+c
        signal.signal(signal.SIGINT, term_mp)
        signal.signal(signal.SIGTERM, term_mp)


def build_dataloader(config, mode, device, logger, seed=None):
    config = copy.deepcopy(config)
@@ -109,8 +124,4 @@ def build_dataloader(config, mode, device, logger, seed=None):
        use_shared_memory=use_shared_memory,
        collate_fn=collate_fn)

    # support exit using ctrl+c
    signal.signal(signal.SIGINT, term_mp)
    signal.signal(signal.SIGTERM, term_mp)

    return data_loader
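For context on the guard above: on POSIX a process leads its process group exactly when its PID equals its PGID, and a child spawned by a wrapper (for example a devkit/UAPI launcher) normally stays in the wrapper's group, so it is not the group leader and the killpg-based handlers are simply not installed. A standalone sketch of that behaviour (not part of the diff):

```python
# Standalone sketch (not part of the diff): a child spawned without a new
# session shares its parent's process group, so inside the child
# os.getpid() != os.getpgid(0) and the guarded handlers are skipped, which
# keeps os.killpg from ever reaching the wrapper process.
import subprocess

child = subprocess.Popen(
    ["python", "-c", "import os; print(os.getpid() == os.getpgid(0))"])
child.wait()  # prints False: the child is not a process-group leader
```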
@@ -197,13 +197,26 @@ def save_model(model,
    """
    _mkdir_if_not_exist(model_path, logger)
    model_prefix = os.path.join(model_path, prefix)

    if prefix == 'best_accuracy':
        best_model_path = os.path.join(model_path, 'best_model')
        _mkdir_if_not_exist(best_model_path, logger)

    paddle.save(optimizer.state_dict(), model_prefix + '.pdopt')
    if prefix == 'best_accuracy':
        paddle.save(optimizer.state_dict(),
                    os.path.join(best_model_path, 'model.pdopt'))

    is_nlp_model = config['Architecture']["model_type"] == 'kie' and config[
        "Architecture"]["algorithm"] not in ["SDMGR"]
    if is_nlp_model is not True:
        paddle.save(model.state_dict(), model_prefix + '.pdparams')
        metric_prefix = model_prefix

        if prefix == 'best_accuracy':
            paddle.save(model.state_dict(),
                        os.path.join(best_model_path, 'model.pdparams'))

    else:  # for kie system, we follow the save/load rules in NLP
        if config['Global']['distributed']:
            arch = model._layers
@@ -213,6 +226,10 @@ def save_model(model,
            arch = arch.Student
        arch.backbone.model.save_pretrained(model_prefix)
        metric_prefix = os.path.join(model_prefix, 'metric')

        if prefix == 'best_accuracy':
            arch.backbone.model.save_pretrained(best_model_path)

    # save metric and config
    with open(metric_prefix + '.states', 'wb') as f:
        pickle.dump(kwargs, f, protocol=2)
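With this change a best_accuracy checkpoint is additionally duplicated under a best_model/ directory: the optimizer state as best_model/model.pdopt and, for non-KIE models, the weights as best_model/model.pdparams (KIE models call save_pretrained into the same directory instead). A hedged sketch of consuming that copy (the output directory is illustrative, not from this diff):

```python
# Hypothetical consumer of the duplicated "best" checkpoint that save_model
# now writes when prefix == 'best_accuracy'. The directory is illustrative.
import os

import paddle

save_model_dir = "./output/rec_example"  # assumed Global.save_model_dir
best_dir = os.path.join(save_model_dir, "best_model")

state_dict = paddle.load(os.path.join(best_dir, "model.pdparams"))
optim_state = paddle.load(os.path.join(best_dir, "model.pdopt"))
# model.set_state_dict(state_dict) would then restore the weights on a model
# built from the same config.
```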
@@ -24,7 +24,7 @@ sys.path.insert(0, __dir__)
sys.path.insert(0, os.path.abspath(os.path.join(__dir__, '..')))
import paddle
from ppocr.data import build_dataloader
from ppocr.data import build_dataloader, set_signal_handlers
from ppocr.modeling.architectures import build_model
from ppocr.postprocess import build_post_process
from ppocr.metrics import build_metric
@@ -35,6 +35,7 @@ import tools.program as program
def main():
global_config = config['Global']
# build dataloader
set_signal_handlers()
valid_dataloader = build_dataloader(config, 'Eval', device, logger)
# build post process
@@ -24,7 +24,7 @@ __dir__ = os.path.dirname(os.path.abspath(__file__))
sys.path.append(__dir__)
sys.path.append(os.path.abspath(os.path.join(__dir__, '..')))
from ppocr.data import build_dataloader
from ppocr.data import build_dataloader, set_signal_handlers
from ppocr.modeling.architectures import build_model
from ppocr.postprocess import build_post_process
from ppocr.utils.save_load import load_model
@@ -40,6 +40,7 @@ def main():
'data_dir']
config['Eval']['dataset']['label_file_list'] = config['Train']['dataset'][
'label_file_list']
set_signal_handlers()
eval_dataloader = build_dataloader(config, 'Eval', device, logger)
# build post process
@@ -40,17 +40,16 @@ import tools.program as program


def draw_det_res(dt_boxes, config, img, img_name, save_path):
    if len(dt_boxes) > 0:
        import cv2
        src_im = img
        for box in dt_boxes:
            box = np.array(box).astype(np.int32).reshape((-1, 1, 2))
            cv2.polylines(src_im, [box], True, color=(255, 255, 0), thickness=2)
        if not os.path.exists(save_path):
            os.makedirs(save_path)
        save_path = os.path.join(save_path, os.path.basename(img_name))
        cv2.imwrite(save_path, src_im)
        logger.info("The detected Image saved in {}".format(save_path))
    import cv2
    src_im = img
    for box in dt_boxes:
        box = np.array(box).astype(np.int32).reshape((-1, 1, 2))
        cv2.polylines(src_im, [box], True, color=(255, 255, 0), thickness=2)
    if not os.path.exists(save_path):
        os.makedirs(save_path)
    save_path = os.path.join(save_path, os.path.basename(img_name))
    cv2.imwrite(save_path, src_im)
    logger.info("The detected Image saved in {}".format(save_path))
@paddle.no_grad()
@@ -683,7 +683,7 @@ def preprocess(is_train=False):
    if 'use_visualdl' in config['Global'] and config['Global']['use_visualdl']:
        save_model_dir = config['Global']['save_model_dir']
        vdl_writer_path = '{}/vdl/'.format(save_model_dir)
        vdl_writer_path = save_model_dir
        log_writer = VDLLogger(vdl_writer_path)
        loggers.append(log_writer)
    if ('use_wandb' in config['Global'] and
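The practical effect of the change above is that VisualDL event files now land directly in Global.save_model_dir rather than in a vdl/ subdirectory, so the viewer should be pointed at the model directory itself. A minimal sketch (the directory name is illustrative):

```python
# Illustrative only: where VisualDL logs end up before and after this commit,
# assuming Global.save_model_dir is "./output/rec". Start the viewer with
# `visualdl --logdir <path>` against the new location.
save_model_dir = "./output/rec"
old_logdir = "{}/vdl/".format(save_model_dir)  # previous behaviour
new_logdir = save_model_dir                    # after this commit
```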
@@ -27,7 +27,7 @@ import yaml
import paddle
import paddle.distributed as dist
from ppocr.data import build_dataloader
from ppocr.data import build_dataloader, set_signal_handlers
from ppocr.modeling.architectures import build_model
from ppocr.losses import build_loss
from ppocr.optimizer import build_optimizer
@@ -49,6 +49,7 @@ def main(config, device, logger, vdl_writer):
global_config = config['Global']
# build dataloader
set_signal_handlers()
train_dataloader = build_dataloader(config, 'Train', device, logger)
if len(train_dataloader) == 0:
logger.error(