# forward_demo.py
import os, sys, torch
import numpy as np
np.set_printoptions(precision=4, suppress=True, linewidth=200)

# current_path = os.path.dirname(os.path.abspath(__file__))
# sys.path.append(f'{current_path}/rwkv_pip_package/src')

# Tune these below (test True/False for all of them) to find the fastest setting:
# torch._C._jit_set_profiling_executor(True)
# torch._C._jit_set_profiling_mode(True)
# torch._C._jit_override_can_fuse_on_cpu(True)
# torch._C._jit_override_can_fuse_on_gpu(True)
# torch._C._jit_set_texpr_fuser_enabled(False)
# torch._C._jit_set_nvfuser_enabled(False)

########################################################################################################
#
# Use '/' in model paths instead of '\'. Use ctx4096 models if you need long ctx.
#
# fp16 = good for GPU (!!! DOES NOT support CPU !!!)
# fp32 = good for CPU
# bf16 = worse accuracy, supports CPU
# xxxi8 (example: fp16i8) = xxx with int8 quantization to save 50% VRAM/RAM, slower, slightly less accuracy
#
# Read https://pypi.org/project/rwkv/ for Strategy Guide
#
########################################################################################################
# set these before importing RWKV
os.environ['RWKV_JIT_ON'] = '1'
os.environ["RWKV_CUDA_ON"] = '0' # set to '1' to compile the CUDA kernel for seq mode (much faster)

# from rwkv.model import RWKV # pip install rwkv
from src.rlhf.rwkv.model import RWKV
# model = RWKV(model='./model/rwkv-190.pth', strategy='cpu fp32')
model = RWKV(model='./model/RWKV-4-Pile-169M-20220807-8023.pth', strategy='cpu fp32')

# model = RWKV(model='/fsx/BlinkDL/HF-MODEL/rwkv-4-pile-169m/RWKV-4-Pile-169M-20220807-8023', strategy='cuda fp16')
# model = RWKV(model='/fsx/BlinkDL/HF-MODEL/rwkv-4-pile-169m/RWKV-4-Pile-169M-20220807-8023', strategy='cuda fp16i8')
# model = RWKV(model='/fsx/BlinkDL/HF-MODEL/rwkv-4-pile-169m/RWKV-4-Pile-169M-20220807-8023', strategy='cpu fp32')
# model = RWKV(model='/fsx/BlinkDL/HF-MODEL/rwkv-4-pile-169m/RWKV-4-Pile-169M-20220807-8023', strategy='cpu fp32 *3 -> cuda fp16 *6+')
# model = RWKV(model='/fsx/BlinkDL/HF-MODEL/rwkv-4-pile-1b5/RWKV-4-Pile-1B5-20220903-8040', strategy='cpu fp32')
# model = RWKV(model='/fsx/BlinkDL/HF-MODEL/rwkv-4-pile-1b5/RWKV-4-Pile-1B5-20220903-8040', strategy='cuda fp16')
# model = RWKV(model='/fsx/BlinkDL/HF-MODEL/rwkv-4-pile-1b5/RWKV-4-Pile-1B5-20220903-8040', strategy='cuda fp16 *8 -> cpu fp32')
# model = RWKV(model='/fsx/BlinkDL/HF-MODEL/rwkv-4-pile-1b5/RWKV-4-Pile-1B5-20220903-8040', strategy='cuda:0 fp16 -> cuda:1 fp16 -> cpu fp32 *1')
# model = RWKV(model='/fsx/BlinkDL/HF-MODEL/rwkv-4-pile-1b5/RWKV-4-Pile-1B5-20220903-8040', strategy='cuda fp16 *6+')
# model = RWKV(model='/fsx/BlinkDL/HF-MODEL/rwkv-4-pile-14b/RWKV-4-Pile-14B-20230213-8019', strategy='cuda fp16 *0+ -> cpu fp32 *1')
# model = RWKV(model='/fsx/BlinkDL/HF-MODEL/rwkv-4-pile-3b/RWKV-4-Pile-3B-20221110-ctx4096', strategy='cuda:0 fp16 *25 -> cuda:1 fp16')
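# Rough guide to the layered strategy strings above, per the Strategy Guide linked earlier
# (my reading of it; check the guide for details): '*N' assigns N layers to that device/dtype,
# '->' hands the remaining layers to the next stage, and a trailing '+' streams the
# remaining layers through that device one at a time to save VRAM.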

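# Forward the whole prompt (a list of token ids) in one call. Besides the logits and the
# recurrent state, this repo's modified RWKV also returns the token embeddings.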
out, state, token_embed = model.forward([187, 510, 1563, 310, 247], None)
print(out.detach().cpu().numpy())                   # get logits
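# Quick sanity check on the logits: pick the greedy next-token id with plain torch ops
# (proper sampling with temperature/top-p is what the PIPELINE utilities below provide).
probs = torch.softmax(out.float(), dim=-1)
print('greedy next token id:', int(torch.argmax(probs).item()))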
# out, state, token_embed = model.forward([187, 510], None)
# out, state, token_embed = model.forward([1563], state)      # RNN has state (use deepcopy to clone states; see sketch below)
# out, state, token_embed = model.forward([310, 247], state)
# print(out.detach().cpu().numpy())                            # same result as above
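# Minimal sketch of cloning the state with deepcopy, assuming the same three-value
# forward() as above -- snapshot once, then branch without mutating the snapshot:
# from copy import deepcopy
# saved = deepcopy(state)                                      # snapshot after the prompt
# out_a, state_a, _ = model.forward([310], deepcopy(saved))    # branch A
# out_b, state_b, _ = model.forward([247], deepcopy(saved))    # branch B, saved stays intact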

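# Drop into an interactive debugger to inspect out, state and token_embed by hand.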
import ipdb
ipdb.set_trace()

# print('\n')

# from src.rlhf.rwkv.utils import PIPELINE, PIPELINE_ARGS
# pipeline = PIPELINE(model, "20B_tokenizer.json")

# ctx = "\nIn a shocking finding, scientist discovered a herd of dragons living in a remote, previously unexplored valley, in Tibet. Even more surprising to the researchers was the fact that the dragons spoke perfect Chinese."
# print(ctx, end='')

# def my_print(s):
#     print(s, end='', flush=True)

# # For alpha_frequency and alpha_presence, see "Frequency and presence penalties":
# # https://platform.openai.com/docs/api-reference/parameter-details
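# # Roughly (paraphrasing the OpenAI docs above, not the rwkv pipeline internals):
# # for each token t already generated c[t] times,
# #     logits[t] -= alpha_presence + c[t] * alpha_frequency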

# args = PIPELINE_ARGS(temperature = 1.0, top_p = 0.7,
#                      alpha_frequency = 0.25,
#                      alpha_presence = 0.25,
#                      token_ban = [0], # ban the generation of some tokens
#                      token_stop = []) # stop generation whenever you see any token here

# ########################################################################################################
# # 1. set os.environ["RWKV_CUDA_ON"] = '1' if possible, for faster preprocess of a long ctx.
# # 2. Reuse the state (use deepcopy to clone it) when you are running the same ctx multiple times. 
# pipeline.generate(ctx, token_count=200, args=args, callback=my_print)

# print('\n')