From 7aa9307742516d8ff88e2cf366ba34741388a5be Mon Sep 17 00:00:00 2001
From: zhongpu <2013000149@qq.com>
Date: Tue, 21 Jan 2020 16:52:13 +0800
Subject: [PATCH] for bad params init, stop train, test=develop (#4227)

---
 dygraph/ptb_lm/ptb_dy.py | 12 ++++++++++--
 1 file changed, 10 insertions(+), 2 deletions(-)

diff --git a/dygraph/ptb_lm/ptb_dy.py b/dygraph/ptb_lm/ptb_dy.py
index d96e4ed3..befa4ad6 100644
--- a/dygraph/ptb_lm/ptb_dy.py
+++ b/dygraph/ptb_lm/ptb_dy.py
@@ -326,7 +326,7 @@ def train_ptb_lm():
             boundaries=bd, values=lr_arr),
         parameter_list=ptb_model.parameters())
     def eval(model, data):
-        print("begion to eval")
+        print("begin to eval")
         total_loss = 0.0
         iters = 0.0
         init_hidden_data = np.zeros(
@@ -404,10 +404,18 @@ def train_ptb_lm():
                           (epoch_id, batch_id, ppl[0],
                            sgd._global_learning_rate().numpy(), out_loss))
 
-            print("one ecpoh finished", epoch_id)
+            print("one epoch finished", epoch_id)
             print("time cost ", time.time() - start_time)
             ppl = np.exp(total_loss / iters)
             print("-- Epoch:[%d]; ppl: %.5f" % (epoch_id, ppl[0]))
+
+            if batch_size <= 20 and epoch_id == 0 and ppl[0] > 1000:
+                # for bad init, after first epoch, the loss is over 1000
+                # no more need to continue
+                print("Parameters are randomly initialized and not good this time because the loss is over 1000 after the first epoch.")
+                print("Abort this training process and please start again.")
+                return
+
            if args.ce:
                 print("kpis\ttrain_ppl\t%0.3f" % ppl[0])
             save_model_dir = os.path.join(args.save_model_dir,
--
GitLab
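
Note: the guard this patch introduces can be read in isolation as the sketch below. It is a minimal, hypothetical rendering, not the patched file itself: the function name should_abort_for_bad_init and the scalar total_loss/iters arguments are illustrative, while the batch_size <= 20 and epoch_id == 0 conditions and the perplexity threshold of 1000 come directly from the patch.

    import numpy as np

    def should_abort_for_bad_init(epoch_id, batch_size, total_loss, iters):
        # Perplexity over the epoch: exp of the mean per-step loss.
        ppl = np.exp(total_loss / iters)
        # With a small batch size, a perplexity still above 1000 after
        # the first epoch (epoch_id == 0) signals a poor random parameter
        # initialization; the run is unlikely to recover, so it is cheaper
        # to abort and restart than to keep training.
        return batch_size <= 20 and epoch_id == 0 and ppl > 1000

Called at the end of each epoch of a training loop, it would drive the same early exit the patch adds, e.g.:

    if should_abort_for_bad_init(epoch_id, batch_size, total_loss, iters):
        print("Abort this training process and please start again.")
        return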