diff --git a/configs/vqa/re/layoutxlm.yml b/configs/vqa/re/layoutxlm.yml index 29005bfbd4c085cc2514cb9d5eca1825c6dc48a4..ca6b0d29db534eb1189e305d1f033ece24c368b9 100644 --- a/configs/vqa/re/layoutxlm.yml +++ b/configs/vqa/re/layoutxlm.yml @@ -1,6 +1,6 @@ Global: use_gpu: True - epoch_num: 200 + epoch_num: &epoch_num 200 log_smooth_window: 10 print_batch_step: 10 save_model_dir: ./output/re_layoutxlm/ @@ -10,6 +10,7 @@ Global: cal_metric_during_train: False save_inference_dir: use_visualdl: False + seed: 2022 infer_img: doc/vqa/input/zh_val_21.jpg save_res_path: ./output/re/ diff --git a/configs/vqa/ser/layoutlm.yml b/configs/vqa/ser/layoutlm.yml index 805a3993ffd9f76716755cf7fa2cfc5d440462e5..29cb46885799c061d715def5bcacb068775930d0 100644 --- a/configs/vqa/ser/layoutlm.yml +++ b/configs/vqa/ser/layoutlm.yml @@ -10,6 +10,7 @@ Global: cal_metric_during_train: False save_inference_dir: use_visualdl: False + seed: 2022 infer_img: doc/vqa/input/zh_val_0.jpg save_res_path: ./output/ser/predicts_layoutlm.txt diff --git a/configs/vqa/ser/layoutxlm.yml b/configs/vqa/ser/layoutxlm.yml index 54b1899c68a7e7b07fd13d69d49ece302662d00c..eb1cca5a215dd65ef9c302441d05b482f2622a79 100644 --- a/configs/vqa/ser/layoutxlm.yml +++ b/configs/vqa/ser/layoutxlm.yml @@ -10,6 +10,7 @@ Global: cal_metric_during_train: False save_inference_dir: use_visualdl: False + seed: 2022 infer_img: doc/vqa/input/zh_val_42.jpg save_res_path: ./output/ser diff --git a/ppocr/utils/utility.py b/ppocr/utils/utility.py index a560deb60c78548fe402833232a5632953075595..76484dfd3d3caaa03731368cf4eace1715121874 100755 --- a/ppocr/utils/utility.py +++ b/ppocr/utils/utility.py @@ -16,6 +16,9 @@ import logging import os import imghdr import cv2 +import random +import numpy as np +import paddle def print_dict(d, logger, delimiter=0): @@ -96,3 +99,9 @@ def load_vqa_bio_label_maps(label_map_path): label2id_map = {label: idx for idx, label in enumerate(labels)} id2label_map = {idx: label for idx, label in enumerate(labels)} return label2id_map, id2label_map + + +def set_seed(seed=1024): + random.seed(seed) + np.random.seed(seed) + paddle.seed(seed) diff --git a/tools/train.py b/tools/train.py index c96298dd6645b4fc95afc64b370990e1417dbd46..506e0f7fa87fe8afc82cbb12d553a8da4ba298e2 100755 --- a/tools/train.py +++ b/tools/train.py @@ -27,8 +27,6 @@ import yaml import paddle import paddle.distributed as dist -paddle.seed(2) - from ppocr.data import build_dataloader from ppocr.modeling.architectures import build_model from ppocr.losses import build_loss @@ -36,6 +34,7 @@ from ppocr.optimizer import build_optimizer from ppocr.postprocess import build_post_process from ppocr.metrics import build_metric from ppocr.utils.save_load import load_model +from ppocr.utils.utility import set_seed import tools.program as program dist.get_world_size() @@ -146,5 +145,7 @@ def test_reader(config, device, logger): if __name__ == '__main__': config, device, logger, vdl_writer = program.preprocess(is_train=True) + seed = config['Global']['seed'] if 'seed' in config['Global'] else 1024 + set_seed(seed) main(config, device, logger, vdl_writer) # test_reader(config, device, logger)