diff --git a/src/trainer.py b/src/trainer.py index b43becd86c3f1e46212eb05cc11062953bf88da1..257d2a5ae2648e4262fe7146df6adadabb1ecd09 100644 --- a/src/trainer.py +++ b/src/trainer.py @@ -248,7 +248,7 @@ class rm_train_callback(pl.Callback): to_save_dict = pl_module.state_dict() my_save( to_save_dict, - f"{args.proj_dir}/rwkv-final.pth", + f"{args.proj_dir}/rm-final.pth", ) @@ -276,7 +276,7 @@ class rm_train_callback(pl.Callback): try: my_save( to_save_dict, - f"{args.proj_dir}/rwkv-{args.epoch_begin + trainer.current_epoch}.pth", + f"{args.proj_dir}/rm-{args.epoch_begin + trainer.current_epoch}.pth", ) except Exception as e: print('Error\n\n', e, '\n\n')