未验证 提交 c3e77c50 编写于 作者: W whs 提交者: GitHub

Fix random seed. (#1267)

上级 e56ab89e
......@@ -17,6 +17,7 @@ from paddle.fluid.initializer import init_on_cpu
if 'ce_mode' in os.environ:
np.random.seed(10)
fluid.default_startup_program().random_seed = 90
parser = argparse.ArgumentParser(description=__doc__)
add_arg = functools.partial(add_arguments, argparser=parser)
......@@ -91,9 +92,6 @@ def train(args):
place = fluid.CUDAPlace(0)
exe = fluid.Executor(place)
if 'ce_mode' in os.environ:
fluid.default_startup_program().random_seed = 90
exe.run(fluid.default_startup_program())
if args.init_model is not None:
......@@ -126,9 +124,10 @@ def train(args):
sub124_loss += results[3]
# training log
if iter_id % LOG_PERIOD == 0:
print("Iter[%d]; train loss: %.3f; sub4_loss: %.3f; sub24_loss: %.3f; sub124_loss: %.3f" % (
iter_id, t_loss / LOG_PERIOD, sub4_loss / LOG_PERIOD,
sub24_loss / LOG_PERIOD, sub124_loss / LOG_PERIOD))
print(
"Iter[%d]; train loss: %.3f; sub4_loss: %.3f; sub24_loss: %.3f; sub124_loss: %.3f"
% (iter_id, t_loss / LOG_PERIOD, sub4_loss / LOG_PERIOD,
sub24_loss / LOG_PERIOD, sub124_loss / LOG_PERIOD))
print("kpis train_cost %f" % (t_loss / LOG_PERIOD))
t_loss = 0.
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册