From d6afe16579577d7ef81fc7fd7fdcdc21ca042600 Mon Sep 17 00:00:00 2001 From: WenmuZhou <572459439@qq.com> Date: Fri, 7 Jan 2022 04:56:45 +0000 Subject: [PATCH] add seed --- configs/vqa/re/layoutxlm.yml | 3 ++- configs/vqa/ser/layoutlm.yml | 1 + configs/vqa/ser/layoutxlm.yml | 1 + ppocr/utils/utility.py | 9 +++++++++ tools/train.py | 5 +++-- 5 files changed, 16 insertions(+), 3 deletions(-) diff --git a/configs/vqa/re/layoutxlm.yml b/configs/vqa/re/layoutxlm.yml index 29005bfb..ca6b0d29 100644 --- a/configs/vqa/re/layoutxlm.yml +++ b/configs/vqa/re/layoutxlm.yml @@ -1,6 +1,6 @@ Global: use_gpu: True - epoch_num: 200 + epoch_num: &epoch_num 200 log_smooth_window: 10 print_batch_step: 10 save_model_dir: ./output/re_layoutxlm/ @@ -10,6 +10,7 @@ Global: cal_metric_during_train: False save_inference_dir: use_visualdl: False + seed: 2022 infer_img: doc/vqa/input/zh_val_21.jpg save_res_path: ./output/re/ diff --git a/configs/vqa/ser/layoutlm.yml b/configs/vqa/ser/layoutlm.yml index 805a3993..29cb4688 100644 --- a/configs/vqa/ser/layoutlm.yml +++ b/configs/vqa/ser/layoutlm.yml @@ -10,6 +10,7 @@ Global: cal_metric_during_train: False save_inference_dir: use_visualdl: False + seed: 2022 infer_img: doc/vqa/input/zh_val_0.jpg save_res_path: ./output/ser/predicts_layoutlm.txt diff --git a/configs/vqa/ser/layoutxlm.yml b/configs/vqa/ser/layoutxlm.yml index 54b1899c..eb1cca5a 100644 --- a/configs/vqa/ser/layoutxlm.yml +++ b/configs/vqa/ser/layoutxlm.yml @@ -10,6 +10,7 @@ Global: cal_metric_during_train: False save_inference_dir: use_visualdl: False + seed: 2022 infer_img: doc/vqa/input/zh_val_42.jpg save_res_path: ./output/ser diff --git a/ppocr/utils/utility.py b/ppocr/utils/utility.py index a560deb6..76484dfd 100755 --- a/ppocr/utils/utility.py +++ b/ppocr/utils/utility.py @@ -16,6 +16,9 @@ import logging import os import imghdr import cv2 +import random +import numpy as np +import paddle def print_dict(d, logger, delimiter=0): @@ -96,3 +99,9 @@ def load_vqa_bio_label_maps(label_map_path): label2id_map = {label: idx for idx, label in enumerate(labels)} id2label_map = {idx: label for idx, label in enumerate(labels)} return label2id_map, id2label_map + + +def set_seed(seed=1024): + random.seed(seed) + np.random.seed(seed) + paddle.seed(seed) diff --git a/tools/train.py b/tools/train.py index c96298dd..506e0f7f 100755 --- a/tools/train.py +++ b/tools/train.py @@ -27,8 +27,6 @@ import yaml import paddle import paddle.distributed as dist -paddle.seed(2) - from ppocr.data import build_dataloader from ppocr.modeling.architectures import build_model from ppocr.losses import build_loss @@ -36,6 +34,7 @@ from ppocr.optimizer import build_optimizer from ppocr.postprocess import build_post_process from ppocr.metrics import build_metric from ppocr.utils.save_load import load_model +from ppocr.utils.utility import set_seed import tools.program as program dist.get_world_size() @@ -146,5 +145,7 @@ def test_reader(config, device, logger): if __name__ == '__main__': config, device, logger, vdl_writer = program.preprocess(is_train=True) + seed = config['Global']['seed'] if 'seed' in config['Global'] else 1024 + set_seed(seed) main(config, device, logger, vdl_writer) # test_reader(config, device, logger) -- GitLab