未验证 提交 2d44a71b 编写于 作者: L Lin Manhui 提交者: GitHub

Toward Devkit Consistency (#10150)

* Accommodate UAPI

* Fix signal handler

* Save model.pdopt

* Change variable name

* Update vdl dir
上级 15abbcc4
...@@ -25,7 +25,7 @@ sys.path.append(os.path.join(__dir__, '..', '..', '..')) ...@@ -25,7 +25,7 @@ sys.path.append(os.path.join(__dir__, '..', '..', '..'))
sys.path.append(os.path.join(__dir__, '..', '..', '..', 'tools')) sys.path.append(os.path.join(__dir__, '..', '..', '..', 'tools'))
import paddle import paddle
from ppocr.data import build_dataloader from ppocr.data import build_dataloader, set_signal_handlers
from ppocr.modeling.architectures import build_model from ppocr.modeling.architectures import build_model
from ppocr.postprocess import build_post_process from ppocr.postprocess import build_post_process
...@@ -39,6 +39,7 @@ def main(config, device, logger, vdl_writer): ...@@ -39,6 +39,7 @@ def main(config, device, logger, vdl_writer):
global_config = config['Global'] global_config = config['Global']
# build dataloader # build dataloader
set_signal_handlers()
valid_dataloader = build_dataloader(config, 'Eval', device, logger) valid_dataloader = build_dataloader(config, 'Eval', device, logger)
# build post process # build post process
......
...@@ -26,7 +26,7 @@ sys.path.append(os.path.join(__dir__, '..', '..', '..', 'tools')) ...@@ -26,7 +26,7 @@ sys.path.append(os.path.join(__dir__, '..', '..', '..', 'tools'))
import paddle import paddle
import paddle.distributed as dist import paddle.distributed as dist
from ppocr.data import build_dataloader from ppocr.data import build_dataloader, set_signal_handlers
from ppocr.modeling.architectures import build_model from ppocr.modeling.architectures import build_model
from ppocr.losses import build_loss from ppocr.losses import build_loss
from ppocr.optimizer import build_optimizer from ppocr.optimizer import build_optimizer
...@@ -57,6 +57,7 @@ def main(config, device, logger, vdl_writer): ...@@ -57,6 +57,7 @@ def main(config, device, logger, vdl_writer):
global_config = config['Global'] global_config = config['Global']
# build dataloader # build dataloader
set_signal_handlers()
train_dataloader = build_dataloader(config, 'Train', device, logger) train_dataloader = build_dataloader(config, 'Train', device, logger)
if config['Eval']: if config['Eval']:
valid_dataloader = build_dataloader(config, 'Eval', device, logger) valid_dataloader = build_dataloader(config, 'Eval', device, logger)
......
...@@ -34,7 +34,7 @@ from tools.program import load_config, merge_config, ArgsParser ...@@ -34,7 +34,7 @@ from tools.program import load_config, merge_config, ArgsParser
from ppocr.metrics import build_metric from ppocr.metrics import build_metric
import tools.program as program import tools.program as program
from paddleslim.dygraph.quant import QAT from paddleslim.dygraph.quant import QAT
from ppocr.data import build_dataloader from ppocr.data import build_dataloader, set_signal_handlers
from tools.export_model import export_single_model from tools.export_model import export_single_model
...@@ -134,6 +134,7 @@ def main(): ...@@ -134,6 +134,7 @@ def main():
eval_class = build_metric(config['Metric']) eval_class = build_metric(config['Metric'])
# build dataloader # build dataloader
set_signal_handlers()
valid_dataloader = build_dataloader(config, 'Eval', device, logger) valid_dataloader = build_dataloader(config, 'Eval', device, logger)
use_srn = config['Architecture']['algorithm'] == "SRN" use_srn = config['Architecture']['algorithm'] == "SRN"
......
...@@ -31,7 +31,7 @@ import paddle.distributed as dist ...@@ -31,7 +31,7 @@ import paddle.distributed as dist
paddle.seed(2) paddle.seed(2)
from ppocr.data import build_dataloader from ppocr.data import build_dataloader, set_signal_handlers
from ppocr.modeling.architectures import build_model from ppocr.modeling.architectures import build_model
from ppocr.losses import build_loss from ppocr.losses import build_loss
from ppocr.optimizer import build_optimizer from ppocr.optimizer import build_optimizer
...@@ -95,6 +95,7 @@ def main(config, device, logger, vdl_writer): ...@@ -95,6 +95,7 @@ def main(config, device, logger, vdl_writer):
global_config = config['Global'] global_config = config['Global']
# build dataloader # build dataloader
set_signal_handlers()
train_dataloader = build_dataloader(config, 'Train', device, logger) train_dataloader = build_dataloader(config, 'Train', device, logger)
if config['Eval']: if config['Eval']:
valid_dataloader = build_dataloader(config, 'Eval', device, logger) valid_dataloader = build_dataloader(config, 'Eval', device, logger)
......
...@@ -31,7 +31,7 @@ import paddle.distributed as dist ...@@ -31,7 +31,7 @@ import paddle.distributed as dist
paddle.seed(2) paddle.seed(2)
from ppocr.data import build_dataloader from ppocr.data import build_dataloader, set_signal_handlers
from ppocr.modeling.architectures import build_model from ppocr.modeling.architectures import build_model
from ppocr.losses import build_loss from ppocr.losses import build_loss
from ppocr.optimizer import build_optimizer from ppocr.optimizer import build_optimizer
...@@ -117,6 +117,7 @@ def main(config, device, logger, vdl_writer): ...@@ -117,6 +117,7 @@ def main(config, device, logger, vdl_writer):
global_config = config['Global'] global_config = config['Global']
# build dataloader # build dataloader
set_signal_handlers()
config['Train']['loader']['num_workers'] = 0 config['Train']['loader']['num_workers'] = 0
is_layoutxlm_ser = config['Architecture']['model_type'] =='kie' and config['Architecture']['Backbone']['name'] == 'LayoutXLMForSer' is_layoutxlm_ser = config['Architecture']['model_type'] =='kie' and config['Architecture']['Backbone']['name'] == 'LayoutXLMForSer'
train_dataloader = build_dataloader(config, 'Train', device, logger) train_dataloader = build_dataloader(config, 'Train', device, logger)
......
...@@ -39,7 +39,7 @@ from ppocr.data.pgnet_dataset import PGDataSet ...@@ -39,7 +39,7 @@ from ppocr.data.pgnet_dataset import PGDataSet
from ppocr.data.pubtab_dataset import PubTabDataSet from ppocr.data.pubtab_dataset import PubTabDataSet
from ppocr.data.multi_scale_sampler import MultiScaleSampler from ppocr.data.multi_scale_sampler import MultiScaleSampler
__all__ = ['build_dataloader', 'transform', 'create_operators'] __all__ = ['build_dataloader', 'transform', 'create_operators', 'set_signal_handlers']
def term_mp(sig_num, frame): def term_mp(sig_num, frame):
...@@ -51,6 +51,21 @@ def term_mp(sig_num, frame): ...@@ -51,6 +51,21 @@ def term_mp(sig_num, frame):
os.killpg(pgid, signal.SIGKILL) os.killpg(pgid, signal.SIGKILL)
def set_signal_handlers():
    """Install SIGINT/SIGTERM handlers that terminate the process group.

    Both signals are routed to ``term_mp``, which SIGKILLs the entire
    process group so dataloader worker processes do not linger after
    Ctrl+C or ``kill``.

    XXX: because ``term_mp`` kills *all* processes in the group — which in
    some cases includes the parent of the current process — the handlers
    are installed only when the current process is the group leader.
    A better future fix is to kill only descendants of this process.
    """
    pid = os.getpid()
    # Reuse `pid` rather than calling os.getpid() a second time.
    pgid = os.getpgid(pid)
    if pid == pgid:
        # Group leader: safe to install the group-wide kill handlers.
        # Supports clean exit using Ctrl+C (SIGINT) and SIGTERM.
        signal.signal(signal.SIGINT, term_mp)
        signal.signal(signal.SIGTERM, term_mp)
def build_dataloader(config, mode, device, logger, seed=None): def build_dataloader(config, mode, device, logger, seed=None):
config = copy.deepcopy(config) config = copy.deepcopy(config)
...@@ -109,8 +124,4 @@ def build_dataloader(config, mode, device, logger, seed=None): ...@@ -109,8 +124,4 @@ def build_dataloader(config, mode, device, logger, seed=None):
use_shared_memory=use_shared_memory, use_shared_memory=use_shared_memory,
collate_fn=collate_fn) collate_fn=collate_fn)
# support exit using ctrl+c
signal.signal(signal.SIGINT, term_mp)
signal.signal(signal.SIGTERM, term_mp)
return data_loader return data_loader
...@@ -197,13 +197,26 @@ def save_model(model, ...@@ -197,13 +197,26 @@ def save_model(model,
""" """
_mkdir_if_not_exist(model_path, logger) _mkdir_if_not_exist(model_path, logger)
model_prefix = os.path.join(model_path, prefix) model_prefix = os.path.join(model_path, prefix)
if prefix == 'best_accuracy':
best_model_path = os.path.join(model_path, 'best_model')
_mkdir_if_not_exist(best_model_path, logger)
paddle.save(optimizer.state_dict(), model_prefix + '.pdopt') paddle.save(optimizer.state_dict(), model_prefix + '.pdopt')
if prefix == 'best_accuracy':
paddle.save(optimizer.state_dict(),
os.path.join(best_model_path, 'model.pdopt'))
is_nlp_model = config['Architecture']["model_type"] == 'kie' and config[ is_nlp_model = config['Architecture']["model_type"] == 'kie' and config[
"Architecture"]["algorithm"] not in ["SDMGR"] "Architecture"]["algorithm"] not in ["SDMGR"]
if is_nlp_model is not True: if is_nlp_model is not True:
paddle.save(model.state_dict(), model_prefix + '.pdparams') paddle.save(model.state_dict(), model_prefix + '.pdparams')
metric_prefix = model_prefix metric_prefix = model_prefix
if prefix == 'best_accuracy':
paddle.save(model.state_dict(),
os.path.join(best_model_path, 'model.pdparams'))
else: # for kie system, we follow the save/load rules in NLP else: # for kie system, we follow the save/load rules in NLP
if config['Global']['distributed']: if config['Global']['distributed']:
arch = model._layers arch = model._layers
...@@ -213,6 +226,10 @@ def save_model(model, ...@@ -213,6 +226,10 @@ def save_model(model,
arch = arch.Student arch = arch.Student
arch.backbone.model.save_pretrained(model_prefix) arch.backbone.model.save_pretrained(model_prefix)
metric_prefix = os.path.join(model_prefix, 'metric') metric_prefix = os.path.join(model_prefix, 'metric')
if prefix == 'best_accuracy':
arch.backbone.model.save_pretrained(best_model_path)
# save metric and config # save metric and config
with open(metric_prefix + '.states', 'wb') as f: with open(metric_prefix + '.states', 'wb') as f:
pickle.dump(kwargs, f, protocol=2) pickle.dump(kwargs, f, protocol=2)
......
...@@ -24,7 +24,7 @@ sys.path.insert(0, __dir__) ...@@ -24,7 +24,7 @@ sys.path.insert(0, __dir__)
sys.path.insert(0, os.path.abspath(os.path.join(__dir__, '..'))) sys.path.insert(0, os.path.abspath(os.path.join(__dir__, '..')))
import paddle import paddle
from ppocr.data import build_dataloader from ppocr.data import build_dataloader, set_signal_handlers
from ppocr.modeling.architectures import build_model from ppocr.modeling.architectures import build_model
from ppocr.postprocess import build_post_process from ppocr.postprocess import build_post_process
from ppocr.metrics import build_metric from ppocr.metrics import build_metric
...@@ -35,6 +35,7 @@ import tools.program as program ...@@ -35,6 +35,7 @@ import tools.program as program
def main(): def main():
global_config = config['Global'] global_config = config['Global']
# build dataloader # build dataloader
set_signal_handlers()
valid_dataloader = build_dataloader(config, 'Eval', device, logger) valid_dataloader = build_dataloader(config, 'Eval', device, logger)
# build post process # build post process
......
...@@ -24,7 +24,7 @@ __dir__ = os.path.dirname(os.path.abspath(__file__)) ...@@ -24,7 +24,7 @@ __dir__ = os.path.dirname(os.path.abspath(__file__))
sys.path.append(__dir__) sys.path.append(__dir__)
sys.path.append(os.path.abspath(os.path.join(__dir__, '..'))) sys.path.append(os.path.abspath(os.path.join(__dir__, '..')))
from ppocr.data import build_dataloader from ppocr.data import build_dataloader, set_signal_handlers
from ppocr.modeling.architectures import build_model from ppocr.modeling.architectures import build_model
from ppocr.postprocess import build_post_process from ppocr.postprocess import build_post_process
from ppocr.utils.save_load import load_model from ppocr.utils.save_load import load_model
...@@ -40,6 +40,7 @@ def main(): ...@@ -40,6 +40,7 @@ def main():
'data_dir'] 'data_dir']
config['Eval']['dataset']['label_file_list'] = config['Train']['dataset'][ config['Eval']['dataset']['label_file_list'] = config['Train']['dataset'][
'label_file_list'] 'label_file_list']
set_signal_handlers()
eval_dataloader = build_dataloader(config, 'Eval', device, logger) eval_dataloader = build_dataloader(config, 'Eval', device, logger)
# build post process # build post process
......
...@@ -40,17 +40,16 @@ import tools.program as program ...@@ -40,17 +40,16 @@ import tools.program as program
def draw_det_res(dt_boxes, config, img, img_name, save_path): def draw_det_res(dt_boxes, config, img, img_name, save_path):
if len(dt_boxes) > 0: import cv2
import cv2 src_im = img
src_im = img for box in dt_boxes:
for box in dt_boxes: box = np.array(box).astype(np.int32).reshape((-1, 1, 2))
box = np.array(box).astype(np.int32).reshape((-1, 1, 2)) cv2.polylines(src_im, [box], True, color=(255, 255, 0), thickness=2)
cv2.polylines(src_im, [box], True, color=(255, 255, 0), thickness=2) if not os.path.exists(save_path):
if not os.path.exists(save_path): os.makedirs(save_path)
os.makedirs(save_path) save_path = os.path.join(save_path, os.path.basename(img_name))
save_path = os.path.join(save_path, os.path.basename(img_name)) cv2.imwrite(save_path, src_im)
cv2.imwrite(save_path, src_im) logger.info("The detected Image saved in {}".format(save_path))
logger.info("The detected Image saved in {}".format(save_path))
@paddle.no_grad() @paddle.no_grad()
......
...@@ -683,7 +683,7 @@ def preprocess(is_train=False): ...@@ -683,7 +683,7 @@ def preprocess(is_train=False):
if 'use_visualdl' in config['Global'] and config['Global']['use_visualdl']: if 'use_visualdl' in config['Global'] and config['Global']['use_visualdl']:
save_model_dir = config['Global']['save_model_dir'] save_model_dir = config['Global']['save_model_dir']
vdl_writer_path = '{}/vdl/'.format(save_model_dir) vdl_writer_path = save_model_dir
log_writer = VDLLogger(vdl_writer_path) log_writer = VDLLogger(vdl_writer_path)
loggers.append(log_writer) loggers.append(log_writer)
if ('use_wandb' in config['Global'] and if ('use_wandb' in config['Global'] and
......
...@@ -27,7 +27,7 @@ import yaml ...@@ -27,7 +27,7 @@ import yaml
import paddle import paddle
import paddle.distributed as dist import paddle.distributed as dist
from ppocr.data import build_dataloader from ppocr.data import build_dataloader, set_signal_handlers
from ppocr.modeling.architectures import build_model from ppocr.modeling.architectures import build_model
from ppocr.losses import build_loss from ppocr.losses import build_loss
from ppocr.optimizer import build_optimizer from ppocr.optimizer import build_optimizer
...@@ -49,6 +49,7 @@ def main(config, device, logger, vdl_writer): ...@@ -49,6 +49,7 @@ def main(config, device, logger, vdl_writer):
global_config = config['Global'] global_config = config['Global']
# build dataloader # build dataloader
set_signal_handlers()
train_dataloader = build_dataloader(config, 'Train', device, logger) train_dataloader = build_dataloader(config, 'Train', device, logger)
if len(train_dataloader) == 0: if len(train_dataloader) == 0:
logger.error( logger.error(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册