diff --git a/ELMo/train.py b/ELMo/train.py index c8726f84962d4111d6dbbff39d20db0faa5ef0bb..c062df304c08e52736f3526e1da8ee06456c8773 100755 --- a/ELMo/train.py +++ b/ELMo/train.py @@ -555,6 +555,7 @@ def train_loop(args, valid_ppl = eval(vocab, infer_progs, dev_count, logger, args) logger.info("valid ppl {}".format(valid_ppl)) if batch_id > 0 and batch_id % args.save_interval == 0: + epoch_id = int(batch_id / n_batches_per_epoch) model_path = os.path.join(args.para_save_dir, str(batch_id + epoch_id)) if not os.path.isdir(model_path):