diff --git a/ppcls/engine/engine.py b/ppcls/engine/engine.py index 2bed1031455c274c71c8df25867639bb5acc7c67..00f5014e440824b434b3203f265d744ac887f760 100644 --- a/ppcls/engine/engine.py +++ b/ppcls/engine/engine.py @@ -20,6 +20,8 @@ import paddle import paddle.distributed as dist from visualdl import LogWriter from paddle import nn +import numpy as np +import random from ppcls.utils.check import check_gpu from ppcls.utils.misc import AverageMeter @@ -57,6 +59,14 @@ class Engine(object): else: self.is_rec = False + # set seed + seed = self.config["Global"].get("seed", False) + if seed: + assert isinstance(seed, int), "The 'seed' must be a integer!" + paddle.seed(seed) + np.random.seed(seed) + random.seed(seed) + # init logger self.output_dir = self.config['Global']['output_dir'] log_file = os.path.join(self.output_dir, self.config["Arch"]["name"],