提交 d6afe165 编写于 作者: 文幕地方's avatar 文幕地方

add seed

上级 524303c3
Global: Global:
use_gpu: True use_gpu: True
epoch_num: 200 epoch_num: &epoch_num 200
log_smooth_window: 10 log_smooth_window: 10
print_batch_step: 10 print_batch_step: 10
save_model_dir: ./output/re_layoutxlm/ save_model_dir: ./output/re_layoutxlm/
...@@ -10,6 +10,7 @@ Global: ...@@ -10,6 +10,7 @@ Global:
cal_metric_during_train: False cal_metric_during_train: False
save_inference_dir: save_inference_dir:
use_visualdl: False use_visualdl: False
seed: 2022
infer_img: doc/vqa/input/zh_val_21.jpg infer_img: doc/vqa/input/zh_val_21.jpg
save_res_path: ./output/re/ save_res_path: ./output/re/
......
...@@ -10,6 +10,7 @@ Global: ...@@ -10,6 +10,7 @@ Global:
cal_metric_during_train: False cal_metric_during_train: False
save_inference_dir: save_inference_dir:
use_visualdl: False use_visualdl: False
seed: 2022
infer_img: doc/vqa/input/zh_val_0.jpg infer_img: doc/vqa/input/zh_val_0.jpg
save_res_path: ./output/ser/predicts_layoutlm.txt save_res_path: ./output/ser/predicts_layoutlm.txt
......
...@@ -10,6 +10,7 @@ Global: ...@@ -10,6 +10,7 @@ Global:
cal_metric_during_train: False cal_metric_during_train: False
save_inference_dir: save_inference_dir:
use_visualdl: False use_visualdl: False
seed: 2022
infer_img: doc/vqa/input/zh_val_42.jpg infer_img: doc/vqa/input/zh_val_42.jpg
save_res_path: ./output/ser save_res_path: ./output/ser
......
...@@ -16,6 +16,9 @@ import logging ...@@ -16,6 +16,9 @@ import logging
import os import os
import imghdr import imghdr
import cv2 import cv2
import random
import numpy as np
import paddle
def print_dict(d, logger, delimiter=0): def print_dict(d, logger, delimiter=0):
...@@ -96,3 +99,9 @@ def load_vqa_bio_label_maps(label_map_path): ...@@ -96,3 +99,9 @@ def load_vqa_bio_label_maps(label_map_path):
label2id_map = {label: idx for idx, label in enumerate(labels)} label2id_map = {label: idx for idx, label in enumerate(labels)}
id2label_map = {idx: label for idx, label in enumerate(labels)} id2label_map = {idx: label for idx, label in enumerate(labels)}
return label2id_map, id2label_map return label2id_map, id2label_map
def set_seed(seed=1024):
random.seed(seed)
np.random.seed(seed)
paddle.seed(seed)
...@@ -27,8 +27,6 @@ import yaml ...@@ -27,8 +27,6 @@ import yaml
import paddle import paddle
import paddle.distributed as dist import paddle.distributed as dist
paddle.seed(2)
from ppocr.data import build_dataloader from ppocr.data import build_dataloader
from ppocr.modeling.architectures import build_model from ppocr.modeling.architectures import build_model
from ppocr.losses import build_loss from ppocr.losses import build_loss
...@@ -36,6 +34,7 @@ from ppocr.optimizer import build_optimizer ...@@ -36,6 +34,7 @@ from ppocr.optimizer import build_optimizer
from ppocr.postprocess import build_post_process from ppocr.postprocess import build_post_process
from ppocr.metrics import build_metric from ppocr.metrics import build_metric
from ppocr.utils.save_load import load_model from ppocr.utils.save_load import load_model
from ppocr.utils.utility import set_seed
import tools.program as program import tools.program as program
dist.get_world_size() dist.get_world_size()
...@@ -146,5 +145,7 @@ def test_reader(config, device, logger): ...@@ -146,5 +145,7 @@ def test_reader(config, device, logger):
if __name__ == '__main__': if __name__ == '__main__':
config, device, logger, vdl_writer = program.preprocess(is_train=True) config, device, logger, vdl_writer = program.preprocess(is_train=True)
seed = config['Global']['seed'] if 'seed' in config['Global'] else 1024
set_seed(seed)
main(config, device, logger, vdl_writer) main(config, device, logger, vdl_writer)
# test_reader(config, device, logger) # test_reader(config, device, logger)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册