未验证 提交 7a36ec59 编写于 作者: L LiuChiachi 提交者: GitHub

Fix random seed for language model in static mode (#4836)

上级 bc07a010
...@@ -119,7 +119,7 @@ def main(): ...@@ -119,7 +119,7 @@ def main():
main_program = fluid.Program() main_program = fluid.Program()
startup_program = fluid.Program() startup_program = fluid.Program()
if args.enable_ce: if args.enable_ce:
startup_program.random_seed = SEED startup_program.random_seed, main_program.random_seed = SEED, SEED
with fluid.program_guard(main_program, startup_program): with fluid.program_guard(main_program, startup_program):
with fluid.unique_name.guard(): with fluid.unique_name.guard():
res_vars = lm_model.lm_model( res_vars = lm_model.lm_model(
...@@ -154,6 +154,7 @@ def main(): ...@@ -154,6 +154,7 @@ def main():
# define inference program # define inference program
inference_program = fluid.Program() inference_program = fluid.Program()
inference_startup_program = fluid.Program() inference_startup_program = fluid.Program()
inference_program.random_seed, inference_startup_program.radom_seed = SEED, SEED
with fluid.program_guard(inference_program, inference_startup_program): with fluid.program_guard(inference_program, inference_startup_program):
with fluid.unique_name.guard(): with fluid.unique_name.guard():
lm_model.lm_model( lm_model.lm_model(
...@@ -207,6 +208,7 @@ def main(): ...@@ -207,6 +208,7 @@ def main():
else: else:
train_program = fluid.compiler.CompiledProgram(main_program) train_program = fluid.compiler.CompiledProgram(main_program)
train_program.random_seed = SEED
data_path = args.data_path data_path = args.data_path
print("begin to load data") print("begin to load data")
ptb_data = reader.get_ptb_data(data_path) ptb_data = reader.get_ptb_data(data_path)
...@@ -483,3 +485,4 @@ def main(): ...@@ -483,3 +485,4 @@ def main():
if __name__ == '__main__': if __name__ == '__main__':
main() main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册