From d19266ad26d9a6807a8228e76fdaafe6cec51835 Mon Sep 17 00:00:00 2001 From: chenlong Date: Wed, 1 Mar 2023 15:52:06 +0800 Subject: [PATCH] sft --- src/dataset.py | 38 +++++++ train_sft.py | 270 +++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 308 insertions(+) create mode 100644 train_sft.py diff --git a/src/dataset.py b/src/dataset.py index c66f4cf..aa8a65d 100644 --- a/src/dataset.py +++ b/src/dataset.py @@ -4,9 +4,11 @@ import json, math, random, os, sys import numpy as np +import pandas as pd import torch from torch.utils.data import Dataset from pytorch_lightning.utilities import rank_zero_info +from src.utils import TOKENIZER from .binidx import MMapIndexedDataset from .utils import MaybeIsPrime @@ -216,3 +218,39 @@ class MyDataset(Dataset): return x, y, z return x, y + + +class S2SDataset(Dataset): + def __init__(self, args): + self.args = args + self.vocab_size = args.vocab_size + WORD_NAME = [ + "20B_tokenizer.json", + "20B_tokenizer.json", + ] # [vocab, vocab] for Pile model + + self.tokenizer = TOKENIZER(WORD_NAME) + pf = pd.read_csv(args.data_file) + data_list = [] + for index, row in pf.iterrows(): + question = row["question"] + answer = row["answer"] + data_list.append((self.tokenizer.tokenizer.encode(question), self.tokenizer.tokenizer.encode(answer))) + self.data = data_list + + def __len__(self): + return len(self.data) + + def __getitem__(self, index): + ctx_len = self.args.ctx_len + req_len = ctx_len + 1 + question, answer = self.data[index] + text = question + answer + text = text[:req_len] + x = torch.tensor(text, dtype=torch.long) + y = torch.tensor(answer, dtype=torch.long) + + z = [1] * len(question) + [0] * (req_len - len(question)) + z = torch.tensor(z, dtype=torch.long) + + return x, y, z diff --git a/train_sft.py b/train_sft.py new file mode 100644 index 0000000..fa06b1d --- /dev/null +++ b/train_sft.py @@ -0,0 +1,270 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# @Time : 2023/3/1 11:54 +# @Author : clong +# @File : train_sft.py + + + +######################################################################################################## +# The RWKV Language Model - https://github.com/BlinkDL/RWKV-LM +######################################################################################################## + +if __name__ == "__main__": + from argparse import ArgumentParser + from pytorch_lightning import Trainer + from pytorch_lightning.utilities import rank_zero_info, rank_zero_only + + rank_zero_info("########## work in progress ##########") + + ######################################################################################################## + # + # example: train a simple L12-D768 RWKV on dummy data + # + # python train.py --load_model "" --wandb "" --proj_dir "out" \ + # --data_file "" --data_type "dummy" --vocab_size 0 \ + # --ctx_len 128 --epoch_steps 1000 --epoch_count 20 --epoch_begin 0 --epoch_save 10 \ + # --micro_bsz 16 --n_layer 12 --n_embd 768 --pre_ffn 0 --head_qk 0 \ + # --lr_init 6e-4 --lr_final 1e-5 --warmup_steps 0 --beta1 0.9 --beta2 0.99 --adam_eps 1e-8 \ + # --accelerator gpu --devices 1 --precision bf16 --strategy ddp_find_unused_parameters_false --grad_cp 0 + + # example: train a simple L6-D512 RWKV from scratch on enwik8 + # + # python train.py --load_model "" --wandb "" --proj_dir "out" \ + # --data_file "../data/enwik8" --data_type "utf-8" --vocab_size 0 \ + # --ctx_len 512 --epoch_steps 5000 --epoch_count 500 --epoch_begin 0 --epoch_save 5 \ + # --micro_bsz 12 --n_layer 6 --n_embd 512 --pre_ffn 0 --head_qk 0 \ + # --lr_init 8e-4 --lr_final 1e-5 --warmup_steps 0 --beta1 0.9 --beta2 0.99 --adam_eps 1e-8 \ + # --accelerator gpu --devices 1 --precision bf16 --strategy ddp_find_unused_parameters_false --grad_cp 0 + + # example: fine-tune RWKV 1.5B using 8xA100 40G = 1.76it/s = 115k token/s, VRAM 37477M + # + # python train.py --load_model "/fsx/BlinkDL/CODE/FP16/out_1b2/all-8040.pth" --wandb "" --proj_dir "out" \ + # --data_file "../data/train.npy" --data_type "numpy" --vocab_size 50277 \ + # --ctx_len 1024 --epoch_steps 1000 --epoch_count 1000 --epoch_begin 0 --epoch_save 5 \ + # --micro_bsz 8 --n_layer 24 --n_embd 2048 --pre_ffn 0 --head_qk 0 \ + # --lr_init 1e-5 --lr_final 1e-5 --warmup_steps 0 --beta1 0.9 --beta2 0.999 --adam_eps 1e-8 \ + # --accelerator gpu --devices 8 --precision bf16 --strategy deepspeed_stage_2 --grad_cp 0 + + # example: fine-tune RWKV 1.5B using 1 GPU fp16 (VRAM 16G) NOTE: fp16 might overflow + # + # python train.py --load_model "/fsx/BlinkDL/CODE/FP16/out_1b2/all-8040.pth" --wandb "" --proj_dir "out" \ + # --data_file "../data/train.npy" --data_type "numpy" --vocab_size 50277 \ + # --ctx_len 1024 --epoch_steps 200 --epoch_count 1000 --epoch_begin 0 --epoch_save 1 \ + # --micro_bsz 11 --n_layer 24 --n_embd 2048 --pre_ffn 0 --head_qk 0 \ + # --lr_init 1e-5 --lr_final 1e-5 --warmup_steps 0 --beta1 0.9 --beta2 0.999 --adam_eps 1e-8 \ + # --accelerator gpu --devices 1 --precision fp16 --strategy deepspeed_stage_2_offload --grad_cp 1 + + parser = ArgumentParser() + + parser.add_argument("--load_model", default="", type=str) # full path, with .pth + parser.add_argument("--wandb", default="", type=str) # wandb project name. if "" then don't use wandb + parser.add_argument("--proj_dir", default="out", type=str) + parser.add_argument("--random_seed", default="-1", type=int) + + parser.add_argument("--data_file", default="", type=str) + parser.add_argument("--data_type", default="utf-8", type=str) + parser.add_argument("--vocab_size", default=0, type=int) # vocab_size = 0 means auto (for char-level LM and .txt data) + + parser.add_argument("--ctx_len", default=1024, type=int) + parser.add_argument("--epoch_steps", default=1000, type=int) # a mini "epoch" has [epoch_steps] steps + parser.add_argument("--epoch_count", default=500, type=int) # train for this many "epochs". will continue afterwards with lr = lr_final + parser.add_argument("--epoch_begin", default=0, type=int) # if you load a model trained for x "epochs", set epoch_begin = x + parser.add_argument("--epoch_save", default=5, type=int) # save the model every [epoch_save] "epochs" + + parser.add_argument("--micro_bsz", default=12, type=int) # micro batch size (batch size per GPU) + parser.add_argument("--n_layer", default=6, type=int) + parser.add_argument("--n_embd", default=512, type=int) + parser.add_argument("--dim_att", default=0, type=int) + parser.add_argument("--dim_ffn", default=0, type=int) + parser.add_argument("--pre_ffn", default=0, type=int) # replace first att layer by ffn (sometimes better) + parser.add_argument("--head_qk", default=0, type=int) # my headQK trick + parser.add_argument("--tiny_att_dim", default=0, type=int) # tiny attention dim + parser.add_argument("--tiny_att_layer", default=-999, type=int) # tiny attention @ which layer + + parser.add_argument("--lr_init", default=6e-4, type=float) # 6e-4 for L12-D768, 4e-4 for L24-D1024, 3e-4 for L24-D2048 + parser.add_argument("--lr_final", default=1e-5, type=float) + parser.add_argument("--warmup_steps", default=0, type=int) # try 50 if you load a model + parser.add_argument("--beta1", default=0.9, type=float) + parser.add_argument("--beta2", default=0.99, type=float) # use 0.999 when your model is close to convergence + parser.add_argument("--adam_eps", default=1e-8, type=float) + + parser.add_argument("--grad_cp", default=0, type=int) # gradient checkpt: saves VRAM, but slower + parser.add_argument("--my_pile_stage", default=0, type=int) # my special pile mode + parser.add_argument("--my_pile_shift", default=-1, type=int) # my special pile mode - text shift + parser.add_argument("--my_pile_edecay", default=0, type=int) + parser.add_argument("--layerwise_lr", default=1, type=int) # layerwise lr for faster convergence (but slower it/s) + parser.add_argument("--ds_bucket_mb", default=200, type=int) # deepspeed bucket size in MB. 200 seems enough + # parser.add_argument("--cuda_cleanup", default=0, type=int) # extra cuda cleanup (sometimes helpful) + + parser.add_argument("--my_img_version", default=0, type=str) + parser.add_argument("--my_img_size", default=0, type=int) + parser.add_argument("--my_img_bit", default=0, type=int) + parser.add_argument("--my_img_clip", default='x', type=str) + parser.add_argument("--my_img_clip_scale", default=1, type=float) + parser.add_argument("--my_img_l1_scale", default=0, type=float) + parser.add_argument("--my_img_encoder", default='x', type=str) + # parser.add_argument("--my_img_noise_scale", default=0, type=float) + parser.add_argument("--my_sample_len", default=0, type=int) + parser.add_argument("--my_ffn_shift", default=1, type=int) + parser.add_argument("--my_att_shift", default=1, type=int) + parser.add_argument("--my_pos_emb", default=0, type=int) + parser.add_argument("--load_partial", default=0, type=int) + parser.add_argument("--magic_prime", default=0, type=int) + parser.add_argument("--my_qa_mask", default=0, type=int) + parser.add_argument("--my_testing", default='', type=str) + + parser = Trainer.add_argparse_args(parser) + args = parser.parse_args() + + ######################################################################################################## + + import os, warnings, math, datetime, sys, time + import numpy as np + import torch + from torch.utils.data import DataLoader + import deepspeed + import pytorch_lightning as pl + from pytorch_lightning import seed_everything + + if args.random_seed >= 0: + print(f"########## WARNING: GLOBAL SEED {args.random_seed} THIS WILL AFFECT MULTIGPU SAMPLING ##########\n" * 3) + seed_everything(args.random_seed) + + np.set_printoptions(precision=4, suppress=True, linewidth=200) + warnings.filterwarnings("ignore", ".*Consider increasing the value of the `num_workers` argument*") + warnings.filterwarnings("ignore", ".*The progress bar already tracks a metric with the*") + # os.environ["WDS_SHOW_SEED"] = "1" + + args.my_timestamp = datetime.datetime.today().strftime("%Y-%m-%d-%H-%M-%S") + args.enable_checkpointing = False + args.replace_sampler_ddp = False + args.logger = False + args.gradient_clip_val = 1.0 + args.num_sanity_val_steps = 0 + args.check_val_every_n_epoch = int(1e20) + args.log_every_n_steps = int(1e20) + args.max_epochs = -1 # continue forever + args.betas = (args.beta1, args.beta2) + args.real_bsz = int(args.num_nodes) * int(args.devices) * args.micro_bsz + os.environ["RWKV_T_MAX"] = str(args.ctx_len) + os.environ["RWKV_MY_TESTING"] = args.my_testing + if args.dim_att <= 0: + args.dim_att = args.n_embd + if args.dim_ffn <= 0: + args.dim_ffn = args.n_embd * 4 + + args.run_name = f"{args.vocab_size} ctx{args.ctx_len} L{args.n_layer} D{args.n_embd}" + if not os.path.exists(args.proj_dir): + os.makedirs(args.proj_dir) + + samples_per_epoch = args.epoch_steps * args.real_bsz + tokens_per_epoch = samples_per_epoch * args.ctx_len + rank_zero_info( + f""" +############################################################################ +# +# RWKV-4 {args.precision.upper()} on {args.num_nodes}x{args.devices} {args.accelerator.upper()}, bsz {args.num_nodes}x{args.devices}x{args.micro_bsz}={args.real_bsz}, {args.strategy} {'with grad_cp' if args.grad_cp > 0 else ''} +# +# Data = {args.data_file} ({args.data_type}), ProjDir = {args.proj_dir} +# +# Epoch = {args.epoch_begin} to {args.epoch_begin + args.epoch_count - 1} (will continue afterwards), save every {args.epoch_save} epoch +# +# Each "epoch" = {args.epoch_steps} steps, {samples_per_epoch} samples, {tokens_per_epoch} tokens +# +# Model = {args.n_layer} n_layer, {args.n_embd} n_embd, {args.ctx_len} ctx_len +# +# Adam = lr {args.lr_init} to {args.lr_final}, warmup {args.warmup_steps} steps, beta {args.betas}, eps {args.adam_eps} +# +# Found torch {torch.__version__}, recommend 1.12.1+cu116 or newer +# Found deepspeed {deepspeed.__version__}, recommend 0.7.0 (faster than newer versions) +# Found pytorch_lightning {pl.__version__}, recommend 1.7.4 or newer +# +############################################################################ +""" + ) + rank_zero_info(str(vars(args)) + "\n") + + assert args.data_type in ["utf-8", "utf-16le", "numpy", "binidx", "dummy", "wds_img", "uint16"] + + if args.lr_final == 0 or args.lr_init == 0: + rank_zero_info("\n\nNote: lr_final = 0 or lr_init = 0. Using linear LR schedule instead.\n\n") + + assert args.precision in ["fp32", "tf32", "fp16", "bf16"] + os.environ["RWKV_FLOAT_MODE"] = args.precision + if args.precision == "fp32": + rank_zero_info("\n\nNote: you are using fp32 (very slow). Try bf16 / tf32 for faster training.\n\n") + if args.precision == "fp16": + rank_zero_info("\n\nNote: you are using fp16 (might overflow). Try bf16 / tf32 for stable training.\n\n") + + os.environ["RWKV_JIT_ON"] = "1" + if "deepspeed_stage_3" in args.strategy: + os.environ["RWKV_JIT_ON"] = "0" + + torch.backends.cudnn.benchmark = True + torch.backends.cudnn.enabled = True + if args.precision == "fp32": + torch.backends.cudnn.allow_tf32 = False + torch.backends.cuda.matmul.allow_tf32 = False + else: + torch.backends.cudnn.allow_tf32 = True + torch.backends.cuda.matmul.allow_tf32 = True + + if "32" in args.precision: + args.precision = 32 + elif args.precision == "fp16": + args.precision = 16 + else: + args.precision = "bf16" + + ######################################################################################################## + + from src.trainer import train_callback, generate_init_weight + from src.dataset import S2SDataset + + train_data = S2SDataset(args) + args.vocab_size = train_data.vocab_size + + from src.model import RWKV + model = RWKV(args) + + if len(args.load_model) == 0: + rank_zero_info(f"SFT must load model, please input ") + exit(1) + + rank_zero_info(f"########## Loading {args.load_model}... ##########") + try: + load_dict = torch.load(args.load_model, map_location="cpu") + except: + rank_zero_info(f"Bad checkpoint {args.load_model}") + exit(1) + + if args.load_partial == 1: + load_keys = load_dict.keys() + for k in model.state_dict(): + if k not in load_keys: + load_dict[k] = model.state_dict()[k] + model.load_state_dict(load_dict) + + trainer = Trainer.from_argparse_args( + args, + callbacks=[train_callback(args)], + ) + + if trainer.global_rank == 0: + for n in model.state_dict(): + shape = model.state_dict()[n].shape + shape = [i for i in shape if i != 1] + if len(shape) > 1: + print(f"{str(shape[0]).ljust(5)} {str(shape[1]).ljust(5)} {n}") + else: + print(f"{str(shape[0]).ljust(5)} {n}") + + if "deepspeed" in args.strategy: + trainer.strategy.config["zero_optimization"]["allgather_bucket_size"] = args.ds_bucket_mb * 1000 * 1000 + trainer.strategy.config["zero_optimization"]["reduce_bucket_size"] = args.ds_bucket_mb * 1000 * 1000 + + # must set shuffle=False, persistent_workers=False (because worker is in another thread) + data_loader = DataLoader(train_data, shuffle=False, pin_memory=True, batch_size=args.micro_bsz, num_workers=1, persistent_workers=False, drop_last=True) + + trainer.fit(model, data_loader) -- GitLab