提交 731006f1 编写于 作者: K kangguangli 提交者: cuicheng01

set seed by configs

上级 ee36c40d
......@@ -23,6 +23,9 @@ __dir__ = os.path.dirname(os.path.abspath(__file__))
sys.path.append(__dir__)
sys.path.append(os.path.abspath(os.path.join(__dir__, '../../')))
import numpy as np
import random
import paddle
from paddle.distributed import fleet
from visualdl import LogWriter
......@@ -65,6 +68,15 @@ def main(args):
all the config of training paradigm should be in config["Global"]
"""
config = get_config(args.config, overrides=args.override, show=False)
# set seed
seed = config["Global"].get("seed", False)
if seed or seed == 0:
assert isinstance(seed, int), "The 'seed' must be a integer!"
paddle.seed(seed)
np.random.seed(seed)
random.seed(seed)
global_config = config["Global"]
mode = "train"
......@@ -207,6 +219,5 @@ def main(args):
if __name__ == '__main__':
paddle.enable_static()
paddle.seed(0)
args = parse_args()
main(args)
......@@ -47,7 +47,7 @@ function _train(){
log_file=${profiling_log_file}
fi
train_cmd="${config_file} -o DataLoader.Train.sampler.batch_size=${base_batch_size} -o Global.epochs=${max_epochs} -o DataLoader.Train.loader.num_workers=${num_workers} ${profiling_config} -o Global.eval_during_train=False -o fuse_elewise_add_act_ops=True -o enable_addto=True"
train_cmd="${config_file} -o DataLoader.Train.sampler.batch_size=${base_batch_size} -o Global.seed=1234 -o Global.epochs=${max_epochs} -o DataLoader.Train.loader.num_workers=${num_workers} ${profiling_config} -o Global.eval_during_train=False -o fuse_elewise_add_act_ops=True -o enable_addto=True"
# 以下为通用执行命令,无特殊可不用修改
case ${run_mode} in
DP) if [[ ${device_num} = "N1C1" ]];then
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册