未验证 提交 ae43bf8e 编写于 作者: W Walter 提交者: GitHub

Merge pull request #1225 from RainFrost1/seed

add seed
...@@ -20,6 +20,8 @@ import paddle ...@@ -20,6 +20,8 @@ import paddle
import paddle.distributed as dist import paddle.distributed as dist
from visualdl import LogWriter from visualdl import LogWriter
from paddle import nn from paddle import nn
import numpy as np
import random
from ppcls.utils.check import check_gpu from ppcls.utils.check import check_gpu
from ppcls.utils.misc import AverageMeter from ppcls.utils.misc import AverageMeter
...@@ -57,6 +59,14 @@ class Engine(object): ...@@ -57,6 +59,14 @@ class Engine(object):
else: else:
self.is_rec = False self.is_rec = False
# set seed
seed = self.config["Global"].get("seed", False)
if seed:
assert isinstance(seed, int), "The 'seed' must be a integer!"
paddle.seed(seed)
np.random.seed(seed)
random.seed(seed)
# init logger # init logger
self.output_dir = self.config['Global']['output_dir'] self.output_dir = self.config['Global']['output_dir']
log_file = os.path.join(self.output_dir, self.config["Arch"]["name"], log_file = os.path.join(self.output_dir, self.config["Arch"]["name"],
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册