Commit 59792d34 authored by CSDN-Ada助手

init repo

Parent
data
.DS_Store
.idea
\ No newline at end of file
This diff has been collapsed.
########################################################################################################
# The RWKV Language Model - https://github.com/BlinkDL/RWKV-LM
########################################################################################################
print('Loading...')
import numpy as np
import os, copy, types, gc, sys
import torch
from src.utils import TOKENIZER
try:
os.environ["CUDA_VISIBLE_DEVICES"] = sys.argv[1]
except:
pass
torch.backends.cudnn.benchmark = True
torch.backends.cudnn.allow_tf32 = True
torch.backends.cuda.matmul.allow_tf32 = True
np.set_printoptions(precision=4, suppress=True, linewidth=200)
os.environ["RWKV_JIT_ON"] = '1'
CHAT_LANG = 'Chinese' # English Chinese
from src.model_run import RWKV_RNN
WORD_NAME = [
"20B_tokenizer.json",
"20B_tokenizer.json",
] # [vocab, vocab] for Pile model
UNKNOWN_CHAR = None
tokenizer = TOKENIZER(WORD_NAME, UNKNOWN_CHAR=UNKNOWN_CHAR)
args = types.SimpleNamespace()
args.RUN_DEVICE = "cuda" # 'cpu' (already very fast) // 'cuda'
args.FLOAT_MODE = "fp16" # fp32 (good for CPU) // fp16 (recommended for GPU) // bf16 (less accurate)
args.vocab_size = 50277
args.head_qk = 0
args.pre_ffn = 0
args.grad_cp = 0
args.my_pos_emb = 0
args.MODEL_NAME = 'RWKV-4-Pile-1B5-EngChn-test4-20230115'
args.n_layer = 24
args.n_embd = 2048
args.ctx_len = 1024
# args.MODEL_NAME = '/fsx/BlinkDL/HF-MODEL/rwkv-4-pile-7b/RWKV-4-Pile-7B-20221115-8047'
# args.n_layer = 32
# args.n_embd = 4096
# args.ctx_len = 1024
# args.MODEL_NAME = '/fsx/BlinkDL/HF-MODEL/rwkv-4-pile-3b/RWKV-4-Pile-3B-20221008-8023'
# args.n_layer = 32
# args.n_embd = 2560
# args.ctx_len = 1024
if CHAT_LANG == 'English':
user = "User"
bot = "Bot"
interface = ":"
# The following is a verbose and detailed conversation between an AI assistant called {bot}, and a human user called {user}. {bot} is intelligent, knowledgeable, wise and polite.
# The following is a conversation between a highly knowledgeable and intelligent AI called {bot}, and a human called {user}. In the following interactions, {user} and {bot} converse in natural language, and {bot} do its best to answer {user}'s questions. {bot} is respectful, polite and inclusive. {bot} knows a lot, and always tells the truth.
init_prompt = f'''
The following is a verbose and detailed conversation between an AI assistant called {bot}, and a human user called {user}. {bot} is intelligent, knowledgeable, wise and polite.
{user}{interface} french revolution what year
{bot}{interface} The French Revolution started in 1789, and lasted 10 years until 1799.
{user}{interface} 3+5=?
{bot}{interface} The answer is 8.
{user}{interface} guess i marry who ?
{bot}{interface} Only if you tell me more about yourself - what are your interests?
{user}{interface} solve for a: 9-a=2
{bot}{interface} The answer is a = 7, because 9 - 7 = 2.
{user}{interface} wat is lhc
{bot}{interface} LHC is a high-energy particle collider, built by CERN, and completed in 2008. They used it to confirm the existence of the Higgs boson in 2012.
'''
HELP_MSG = '''Commands:
say something --> chat with bot. use \\n for new line.
+alt --> alternate chat reply
+reset --> reset chat
+gen YOUR PROMPT --> free generation with any prompt. use \\n for new line.
+qa YOUR QUESTION --> free generation - ask any question (just ask the question). use \\n for new line.
+more --> continue last free generation (only for +gen / +qa)
+retry --> retry last free generation (only for +gen / +qa)
Now talk with the bot and enjoy. Remember to +reset periodically to clean up the bot's memory. Use RWKV-4 14B for best results.
This is not instruct-tuned for conversation yet, so don't expect good quality. Better use +gen for free generation.
'''
elif CHAT_LANG == 'Chinese':
args.MODEL_NAME = 'RWKV-4-Pile-1B5-EngChn-test4-20230115'
args.n_layer = 24
args.n_embd = 2048
args.ctx_len = 1024
user = "Q"
bot = "A"
interface = ":"
init_prompt = '''
Q: 企鹅会飞吗?
A: 企鹅是不会飞的。它们的翅膀主要用于游泳和平衡,而不是飞行。
Q: 西瓜是什么
A: 西瓜是一种常见的水果,是一种多年生蔓生藤本植物。西瓜的果实呈圆形或卵形,通常是绿色的,里面有红色或黄色的肉和很多的籽。西瓜味甜,多吃可以增加水分,是夏季非常受欢迎的水果之一。
'''
HELP_MSG = '''指令:
直接输入内容 --> 和机器人聊天,用\\n代表换行
+alt --> 让机器人换个回答
+reset --> 重置对话
+gen 某某内容 --> 续写任何中英文内容,用\\n代表换行
+qa 某某问题 --> 问独立的问题(忽略上下文),用\\n代表换行
+more --> 继续 +gen / +qa 的回答
+retry --> 换个 +gen / +qa 的回答
现在可以输入内容和机器人聊天(注意它不怎么懂中文,它可能更懂英文)。请经常使用 +reset 重置机器人记忆。
'''
# Load Model
os.environ["RWKV_RUN_DEVICE"] = args.RUN_DEVICE
MODEL_NAME = args.MODEL_NAME
print(f'loading... {MODEL_NAME}')
model = RWKV_RNN(args)
model_tokens = []
current_state = None
########################################################################################################
def run_rnn(tokens, newline_adj = 0):
global model_tokens, current_state
for i in range(len(tokens)):
model_tokens += [int(tokens[i])]
if i == len(tokens) - 1:
out, current_state = model.forward(model_tokens, current_state)
else:
current_state = model.forward(model_tokens, current_state, preprocess_only=True)
# print(f'### model ###\n[{tokenizer.tokenizer.decode(model_tokens)}]')
out[0] = -999999999 # disable <|endoftext|>
out[187] += newline_adj
# if newline_adj > 0:
# out[15] += newline_adj / 2 # '.'
return out
all_state = {}
def save_all_stat(srv, name, last_out):
n = f'{name}_{srv}'
all_state[n] = {}
all_state[n]['out'] = last_out
all_state[n]['rnn'] = copy.deepcopy(current_state)
all_state[n]['token'] = copy.deepcopy(model_tokens)
def load_all_stat(srv, name):
global model_tokens, current_state
n = f'{name}_{srv}'
current_state = copy.deepcopy(all_state[n]['rnn'])
model_tokens = copy.deepcopy(all_state[n]['token'])
return all_state[n]['out']
########################################################################################################
# Run inference
print(f'\nRun prompt...')
out = run_rnn(tokenizer.tokenizer.encode(init_prompt))
gc.collect()
torch.cuda.empty_cache()
save_all_stat('', 'chat_init', out)
srv_list = ['dummy_server']
for s in srv_list:
save_all_stat(s, 'chat', out)
print(f'### prompt ###\n[{tokenizer.tokenizer.decode(model_tokens)}]\n')
def reply_msg(msg):
print(f'{bot}{interface} {msg}\n')
def on_message(message):
global model_tokens, current_state
srv = 'dummy_server'
msg = message.replace('\\n','\n').strip()
if len(msg) > 1000:
reply_msg('your message is too long (max 1000 tokens)')
return
x_temp = 1.0
x_top_p = 0.85
if ("-temp=" in msg):
x_temp = float(msg.split("-temp=")[1].split(" ")[0])
msg = msg.replace("-temp="+f'{x_temp:g}', "")
# print(f"temp: {x_temp}")
if ("-top_p=" in msg):
x_top_p = float(msg.split("-top_p=")[1].split(" ")[0])
msg = msg.replace("-top_p="+f'{x_top_p:g}', "")
# print(f"top_p: {x_top_p}")
if x_temp <= 0.2:
x_temp = 0.2
if x_temp >= 5:
x_temp = 5
if x_top_p <= 0:
x_top_p = 0
if msg == '+reset':
out = load_all_stat('', 'chat_init')
save_all_stat(srv, 'chat', out)
reply_msg("Chat reset.")
return
elif msg[:5].lower() == '+gen ' or msg[:4].lower() == '+qa ' or msg.lower() == '+more' or msg.lower() == '+retry':
if msg[:5].lower() == '+gen ':
new = '\n' + msg[5:].strip()
# print(f'### prompt ###\n[{new}]')
current_state = None
out = run_rnn(tokenizer.tokenizer.encode(new))
save_all_stat(srv, 'gen_0', out)
elif msg[:4].lower() == '+qa ':
out = load_all_stat('', 'chat_init')
real_msg = msg[4:].strip()
new = f"{user}{interface} {real_msg}\n\n{bot}{interface}"
# print(f'### qa ###\n[{new}]')
out = run_rnn(tokenizer.tokenizer.encode(new))
save_all_stat(srv, 'gen_0', out)
# new = f"\nThe following is an excellent Q&A session consists of detailed and factual information.\n\nQ: What is 3+5?\nA: The answer is 8.\n\nQ: {msg[9:].strip()}\nA:"
# print(f'### prompt ###\n[{new}]')
# current_state = None
# out = run_rnn(tokenizer.tokenizer.encode(new))
# save_all_stat(srv, 'gen_0', out)
elif msg.lower() == '+more':
try:
out = load_all_stat(srv, 'gen_1')
save_all_stat(srv, 'gen_0', out)
except:
return
elif msg.lower() == '+retry':
try:
out = load_all_stat(srv, 'gen_0')
except:
return
begin = len(model_tokens)
out_last = begin
for i in range(150):
token = tokenizer.sample_logits(
out,
model_tokens,
args.ctx_len,
temperature=x_temp,
top_p_usual=x_top_p,
top_p_newline=x_top_p,
)
if msg[:4].lower() == '+qa ':
out = run_rnn([token], newline_adj=-1)
else:
out = run_rnn([token])
xxx = tokenizer.tokenizer.decode(model_tokens[out_last:])
if '\ufffd' not in xxx:
print(xxx, end='', flush=True)
out_last = begin + i + 1
print('\n')
# send_msg = tokenizer.tokenizer.decode(model_tokens[begin:]).strip()
# print(f'### send ###\n[{send_msg}]')
# reply_msg(send_msg)
save_all_stat(srv, 'gen_1', out)
else:
if msg.lower() == '+alt':
try:
out = load_all_stat(srv, 'chat_pre')
except:
return
else:
out = load_all_stat(srv, 'chat')
new = f"{user}{interface} {msg}\n\n{bot}{interface}"
# print(f'### add ###\n[{new}]')
out = run_rnn(tokenizer.tokenizer.encode(new), newline_adj=-999999999)
save_all_stat(srv, 'chat_pre', out)
begin = len(model_tokens)
out_last = begin
print(f'{bot}{interface}', end='', flush=True)
for i in range(999):
if i <= 0:
newline_adj = -999999999
elif i <= 30:
newline_adj = (i - 30) / 10
elif i <= 130:
newline_adj = 0
else:
newline_adj = (i - 130) * 0.25 # MUST END THE GENERATION
token = tokenizer.sample_logits(
out,
model_tokens,
args.ctx_len,
temperature=x_temp,
top_p_usual=x_top_p,
top_p_newline=x_top_p,
)
out = run_rnn([token], newline_adj=newline_adj)
xxx = tokenizer.tokenizer.decode(model_tokens[out_last:])
if '\ufffd' not in xxx:
print(xxx, end='', flush=True)
out_last = begin + i + 1
send_msg = tokenizer.tokenizer.decode(model_tokens[begin:])
if '\n\n' in send_msg:
send_msg = send_msg.strip()
break
# send_msg = tokenizer.tokenizer.decode(model_tokens[begin:]).strip()
# if send_msg.endswith(f'{user}{interface}'): # warning: needs to fix state too !!!
# send_msg = send_msg[:-len(f'{user}{interface}')].strip()
# break
# if send_msg.endswith(f'{bot}{interface}'):
# send_msg = send_msg[:-len(f'{bot}{interface}')].strip()
# break
# print(f'{model_tokens}')
# print(f'[{tokenizer.tokenizer.decode(model_tokens)}]')
# print(f'### send ###\n[{send_msg}]')
# reply_msg(send_msg)
save_all_stat(srv, 'chat', out)
print(HELP_MSG)
while True:
msg = input(f'{user}{interface} ')
if len(msg.strip()) > 0:
on_message(msg)
else:
print('Error: please say something')
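For reference, the inline sampling flags handled by on_message above can be exercised on their own; the following is a minimal standalone sketch (a hypothetical helper, not part of chat.py) that mirrors the -temp= / -top_p= parsing and clamping:

def parse_sampling_flags(msg, default_temp=1.0, default_top_p=0.85):
    # Mirrors on_message: pull "-temp=X" / "-top_p=Y" out of the message,
    # strip them from the text, and clamp to the same ranges as above.
    x_temp, x_top_p = default_temp, default_top_p
    if "-temp=" in msg:
        x_temp = float(msg.split("-temp=")[1].split(" ")[0])
        msg = msg.replace("-temp=" + f"{x_temp:g}", "")
    if "-top_p=" in msg:
        x_top_p = float(msg.split("-top_p=")[1].split(" ")[0])
        msg = msg.replace("-top_p=" + f"{x_top_p:g}", "")
    x_temp = min(max(x_temp, 0.2), 5)
    x_top_p = max(x_top_p, 0)
    return msg.strip(), x_temp, x_top_p

# e.g. parse_sampling_flags("+gen hello -temp=1.5 -top_p=0.7") -> ("+gen hello", 1.5, 0.7)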
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @Time : 2023/2/21 16:09
# @Author : clong
# @File : clean_data.py
import base64
import html
import json
import re
def decode_base64(context):
if not isinstance(context, str):
return ""
return base64.b64decode(context).decode(encoding="utf-8")
def clean_text(text):
"""清理内容"""
if text is None:
return ""
pattern = re.compile(r'<[^>]+>|&#.*?;', re.S)
result = pattern.sub('', text)
result = html.unescape(result)
return result
def clean_blog_data():
# raw data is read from ODPS
raw_data_path = "./data/raw_data.txt"
data_path = "./data/data.json"
with open(raw_data_path) as file_r:
with open(data_path, "w") as file_w:
index = 1
for line in file_r:
try:
articleid, content, title, tags, username, createtime = line.split("\t")
content = decode_base64(content)
content = clean_text(content)
meta = {"ID": index}
ss = json.dumps({"meta":meta,"text":content}, check_circular=False)
file_w.write(ss + "\n")
index += 1
except Exception as e:
print(str(e))
continue
def clean_ask_data():
# raw data is read from ODPS
import pandas as pd
raw_data_path = "./data/ask.csv"
data_path = "./data/ask.jsonl"
df = pd.read_csv(raw_data_path)
with open(data_path, "w") as file_w:
index = 1
for row in df.itertuples():
title = row[2]
question_body = row[3]
answer_body = row[4]
content = str(title) + "\n" + str(question_body) + "\n" + str(answer_body)
content = clean_text(content)
meta = {"ID": index}
ss = json.dumps({"meta": meta, "text": content}, check_circular=False)
file_w.write(ss + "\n")
index += 1
clean_ask_data()
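For reference, a few illustrative sanity checks of the two helpers above (sample strings only, not real data): decode_base64 round-trips a UTF-8 string, and clean_text drops tags and numeric entities before unescaping named entities.

import base64
assert decode_base64(base64.b64encode("你好 CSDN".encode("utf-8")).decode()) == "你好 CSDN"
assert clean_text("<p>Hello &amp; world&#33;</p>") == "Hello & world"
assert clean_text(None) == ""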
#include <stdio.h>
#include <assert.h>
#define MIN_VALUE (-1e38)
template <typename F>
__global__ void kernel_forward(const int B, const int T, const int C,
const F *__restrict__ const _w, const F *__restrict__ const _u, const F *__restrict__ const _k, const F *__restrict__ const _v,
F *__restrict__ const _y) {
const int idx = blockIdx.x * blockDim.x + threadIdx.x;
const int _b = idx / C;
const int _c = idx % C;
const int _offset = _b * T * C + _c;
F u = _u[_c];
F w = _w[_c];
const F *__restrict__ const k = _k + _offset;
const F *__restrict__ const v = _v + _offset;
F *__restrict__ const y = _y + _offset;
F p = 0, q = 0, o = MIN_VALUE;
// p and q are running sums divided by exp(o) (to avoid overflows)
for (int i = 0; i < T; i++) {
const int ii = i * C;
F no = max(o, u + k[ii]);
F A = exp(o - no);
F B = exp(u + k[ii] - no);
y[ii] = (A * p + B * v[ii]) / (A * q + B);
no = max(w + o, k[ii]);
A = exp(w + o - no);
B = exp(k[ii] - no);
p = A * p + B * v[ii];
q = A * q + B;
o = no;
}
}
template <typename F>
__global__ void kernel_backward(const int B, const int T, const int C,
const F *__restrict__ const _w, const F *__restrict__ const _u, const F *__restrict__ const _k, const F *__restrict__ const _v, const F *__restrict__ const _gy,
F *__restrict__ const _gw, F *__restrict__ const _gu, F *__restrict__ const _gk, F *__restrict__ const _gv) {
const int idx = blockIdx.x * blockDim.x + threadIdx.x;
const int _b = idx / C;
const int _c = idx % C;
const int _offset = _b * T * C + _c;
F u = _u[_c];
F w = _w[_c];
const F *__restrict__ const k = _k + _offset;
const F *__restrict__ const v = _v + _offset;
const F *__restrict__ const gy = _gy + _offset;
F *__restrict__ const gk = _gk + _offset;
F *__restrict__ const gv = _gv + _offset;
F y[Tmax], z[Tmax], zexp[Tmax];
F gw = 0, gu = 0;
F p = 0, q = 0;
F dpdw = 0, dqdw = 0;
F o = MIN_VALUE;
for (int i = 0; i < T; i++) {
const int ii = i * C;
F no = max(o, k[ii] + u);
F A = exp(o - no);
F B = exp(k[ii] + u - no);
F num = A * p + B * v[ii];
F iden = 1 / (A * q + B);
y[i] = num * iden;
z[i] = iden;
zexp[i] = k[ii] + u - no;
gw += gy[ii] * (dpdw - dqdw * y[i]) * iden * A;
gu += gy[ii] * (v[ii] - y[i]) * B * iden;
no = max(w + o, k[ii]);
A = exp(w + o - no);
B = exp(k[ii] - no);
dpdw = A * (p + dpdw);
dqdw = A * (q + dqdw);
p = A * p + B * v[ii];
q = A * q + B;
o = no;
}
F gp = 0, gq = 0;
o = MIN_VALUE;
for (int i = T - 1; i >= 0; i--) {
const int ii = i * C;
F A = gy[ii] * z[i] * exp(zexp[i]);
F B = exp(k[ii] + o);
gk[ii] = A * (v[ii] - y[i]) + B * (gp * v[ii] + gq);
gv[ii] = A + B * gp;
F no = max(w + o, zexp[i] - k[ii] - u);
A = exp(w + o - no);
B = gy[ii] * z[i] * exp(zexp[i] - k[ii] - u - no);
gp = A * gp + B;
gq = A * gq - B * y[i];
o = no;
}
// Multiply by w because the w -> -exp(w) preprocessing is halfway in the backwards pass, even though it's not in the forward pass
const int _offsetBC = _b * C + _c;
_gw[_offsetBC] += gw * _w[_c];
_gu[_offsetBC] += gu;
}
void cuda_forward(int B, int T, int C, float *w, float *u, float *k, float *v, float *y) {
dim3 threadsPerBlock( min(C, 32) ); // requires --maxrregcount 60 for optimal performance
assert(B * C % threadsPerBlock.x == 0);
dim3 numBlocks(B * C / threadsPerBlock.x);
kernel_forward<<<numBlocks, threadsPerBlock>>>(B, T, C, w, u, k, v, y);
}
void cuda_backward(int B, int T, int C, float *w, float *u, float *k, float *v, float *gy, float *gw, float *gu, float *gk, float *gv) {
dim3 threadsPerBlock( min(C, 32) ); // requires --maxrregcount 60 for optimal performance
assert(B * C % threadsPerBlock.x == 0);
dim3 numBlocks(B * C / threadsPerBlock.x);
kernel_backward<<<numBlocks, threadsPerBlock>>>(B, T, C, w, u, k, v, gy, gw, gu, gk, gv);
}
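For readers following kernel_forward, here is a per-channel NumPy reference of the same numerically stable recurrence (p and q are the running exp-weighted sums of v and 1, kept scaled by exp(o) to avoid overflow). This is an illustrative sketch, not code from the repository:

import numpy as np

def wkv_forward_ref(w, u, k, v):
    # w, u: scalars for one channel; k, v: arrays of shape (T,).
    # Mirrors kernel_forward above, including the exp(o) rescaling trick.
    T = len(k)
    y = np.empty(T)
    p, q, o = 0.0, 0.0, -1e38  # MIN_VALUE
    for i in range(T):
        no = max(o, u + k[i])
        A = np.exp(o - no)
        B = np.exp(u + k[i] - no)
        y[i] = (A * p + B * v[i]) / (A * q + B)
        no = max(w + o, k[i])
        A = np.exp(w + o - no)
        B = np.exp(k[i] - no)
        p = A * p + B * v[i]
        q = A * q + B
        o = no
    return y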
#include <torch/extension.h>
void cuda_forward(int B, int T, int C, float *w, float *u, float *k, float *v, float *y);
void cuda_backward(int B, int T, int C, float *w, float *u, float *k, float *v, float *gy, float *gw, float *gu, float *gk, float *gv);
void forward(int64_t B, int64_t T, int64_t C, torch::Tensor &w, torch::Tensor &u, torch::Tensor &k, torch::Tensor &v, torch::Tensor &y) {
cuda_forward(B, T, C, w.data_ptr<float>(), u.data_ptr<float>(), k.data_ptr<float>(), v.data_ptr<float>(), y.data_ptr<float>());
}
void backward(int64_t B, int64_t T, int64_t C, torch::Tensor &w, torch::Tensor &u, torch::Tensor &k, torch::Tensor &v, torch::Tensor &gy, torch::Tensor &gw, torch::Tensor &gu, torch::Tensor &gk, torch::Tensor &gv) {
cuda_backward(B, T, C, w.data_ptr<float>(), u.data_ptr<float>(), k.data_ptr<float>(), v.data_ptr<float>(), gy.data_ptr<float>(), gw.data_ptr<float>(), gu.data_ptr<float>(), gk.data_ptr<float>(), gv.data_ptr<float>());
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("forward", &forward, "wkv forward");
m.def("backward", &backward, "wkv backward");
}
TORCH_LIBRARY(wkv, m) {
m.def("forward", forward);
m.def("backward", backward);
}
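These bindings are typically compiled at runtime with torch.utils.cpp_extension.load. A minimal sketch follows; the source paths and T_MAX value are assumptions, not taken from this commit. Note that kernel_backward above allocates y[Tmax] arrays, so a -DTmax=... compile flag is required, and the launcher comments suggest --maxrregcount 60:

from torch.utils.cpp_extension import load

T_MAX = 1024  # assumed; must be >= the ctx_len you run with, since kernel_backward uses F y[Tmax]
wkv_cuda = load(
    name="wkv",
    sources=["cuda/wkv_op.cpp", "cuda/wkv_cuda.cu"],  # assumed paths
    verbose=True,
    extra_cuda_cflags=["--use_fast_math", "-O3", "--maxrregcount 60", f"-DTmax={T_MAX}"],
)
# After loading, wkv_cuda.forward(B, T, C, w, u, k, v, y) and
# wkv_cuda.backward(...) invoke the kernels bound above.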
########################################################################################################
# The RWKV Language Model - https://github.com/BlinkDL/RWKV-LM
########################################################################################################
import torch, types, os
import numpy as np
from PIL import Image
import torch.nn as nn
from torch.nn import functional as F
import torchvision as vision
import torchvision.transforms as transforms
np.set_printoptions(precision=4, suppress=True, linewidth=200)
print(f'loading...')
########################################################################################################
model_prefix = 'test/image_trained/out-v7c_d8_256-224-13bit-OB32x0.5-201'
input_img = 'test/img_ae_test/test0.png'
########################################################################################################
class ToBinary(torch.autograd.Function):
@staticmethod
def forward(ctx, x):
return torch.floor(x + 0.5) # no need for noise when we have plenty of data
@staticmethod
def backward(ctx, grad_output):
return grad_output.clone() # pass-through
class R_ENCODER(nn.Module):
def __init__(self, args):
super().__init__()
self.args = args
dd = 8
self.Bxx = nn.BatchNorm2d(dd*64)
self.CIN = nn.Conv2d(3, dd, kernel_size=3, padding=1)
self.Cx0 = nn.Conv2d(dd, 32, kernel_size=3, padding=1)
self.Cx1 = nn.Conv2d(32, dd, kernel_size=3, padding=1)
self.B00 = nn.BatchNorm2d(dd*4)
self.C00 = nn.Conv2d(dd*4, 256, kernel_size=3, padding=1)
self.C01 = nn.Conv2d(256, dd*4, kernel_size=3, padding=1)
self.C02 = nn.Conv2d(dd*4, 256, kernel_size=3, padding=1)
self.C03 = nn.Conv2d(256, dd*4, kernel_size=3, padding=1)
self.B10 = nn.BatchNorm2d(dd*16)
self.C10 = nn.Conv2d(dd*16, 256, kernel_size=3, padding=1)
self.C11 = nn.Conv2d(256, dd*16, kernel_size=3, padding=1)
self.C12 = nn.Conv2d(dd*16, 256, kernel_size=3, padding=1)
self.C13 = nn.Conv2d(256, dd*16, kernel_size=3, padding=1)
self.B20 = nn.BatchNorm2d(dd*64)
self.C20 = nn.Conv2d(dd*64, 256, kernel_size=3, padding=1)
self.C21 = nn.Conv2d(256, dd*64, kernel_size=3, padding=1)
self.C22 = nn.Conv2d(dd*64, 256, kernel_size=3, padding=1)
self.C23 = nn.Conv2d(256, dd*64, kernel_size=3, padding=1)
self.COUT = nn.Conv2d(dd*64, args.my_img_bit, kernel_size=3, padding=1)
def forward(self, img):
ACT = F.mish
x = self.CIN(img)
xx = self.Bxx(F.pixel_unshuffle(x, 8))
x = x + self.Cx1(ACT(self.Cx0(x)))
x = F.pixel_unshuffle(x, 2)
x = x + self.C01(ACT(self.C00(ACT(self.B00(x)))))
x = x + self.C03(ACT(self.C02(x)))
x = F.pixel_unshuffle(x, 2)
x = x + self.C11(ACT(self.C10(ACT(self.B10(x)))))
x = x + self.C13(ACT(self.C12(x)))
x = F.pixel_unshuffle(x, 2)
x = x + self.C21(ACT(self.C20(ACT(self.B20(x)))))
x = x + self.C23(ACT(self.C22(x)))
x = self.COUT(x + xx)
return torch.sigmoid(x)
class R_DECODER(nn.Module):
def __init__(self, args):
super().__init__()
self.args = args
dd = 8
self.CIN = nn.Conv2d(args.my_img_bit, dd*64, kernel_size=3, padding=1)
self.B00 = nn.BatchNorm2d(dd*64)
self.C00 = nn.Conv2d(dd*64, 256, kernel_size=3, padding=1)
self.C01 = nn.Conv2d(256, dd*64, kernel_size=3, padding=1)
self.C02 = nn.Conv2d(dd*64, 256, kernel_size=3, padding=1)
self.C03 = nn.Conv2d(256, dd*64, kernel_size=3, padding=1)
self.B10 = nn.BatchNorm2d(dd*16)
self.C10 = nn.Conv2d(dd*16, 256, kernel_size=3, padding=1)
self.C11 = nn.Conv2d(256, dd*16, kernel_size=3, padding=1)
self.C12 = nn.Conv2d(dd*16, 256, kernel_size=3, padding=1)
self.C13 = nn.Conv2d(256, dd*16, kernel_size=3, padding=1)
self.B20 = nn.BatchNorm2d(dd*4)
self.C20 = nn.Conv2d(dd*4, 256, kernel_size=3, padding=1)
self.C21 = nn.Conv2d(256, dd*4, kernel_size=3, padding=1)
self.C22 = nn.Conv2d(dd*4, 256, kernel_size=3, padding=1)
self.C23 = nn.Conv2d(256, dd*4, kernel_size=3, padding=1)
self.Cx0 = nn.Conv2d(dd, 32, kernel_size=3, padding=1)
self.Cx1 = nn.Conv2d(32, dd, kernel_size=3, padding=1)
self.COUT = nn.Conv2d(dd, 3, kernel_size=3, padding=1)
def forward(self, code):
ACT = F.mish
x = self.CIN(code)
x = x + self.C01(ACT(self.C00(ACT(self.B00(x)))))
x = x + self.C03(ACT(self.C02(x)))
x = F.pixel_shuffle(x, 2)
x = x + self.C11(ACT(self.C10(ACT(self.B10(x)))))
x = x + self.C13(ACT(self.C12(x)))
x = F.pixel_shuffle(x, 2)
x = x + self.C21(ACT(self.C20(ACT(self.B20(x)))))
x = x + self.C23(ACT(self.C22(x)))
x = F.pixel_shuffle(x, 2)
x = x + self.Cx1(ACT(self.Cx0(x)))
x = self.COUT(x)
return torch.sigmoid(x)
########################################################################################################
print(f'building model...')
args = types.SimpleNamespace()
args.my_img_bit = 13
encoder = R_ENCODER(args).eval().cuda()
decoder = R_DECODER(args).eval().cuda()
zpow = torch.tensor([2**i for i in range(0,13)]).reshape(13,1,1).cuda().long()
encoder.load_state_dict(torch.load(f'{model_prefix}-E.pth'))
decoder.load_state_dict(torch.load(f'{model_prefix}-D.pth'))
########################################################################################################
print(f'test image...')
img_transform = transforms.Compose([
transforms.PILToTensor(),
transforms.ConvertImageDtype(torch.float),
transforms.Resize((224, 224))
])
with torch.no_grad():
img = img_transform(Image.open(input_img)).unsqueeze(0).cuda()
z = encoder(img)
z = ToBinary.apply(z)
zz = torch.sum(z.squeeze().long() * zpow, dim=0)
print(f'Code shape = {zz.shape}\n{zz.cpu().numpy()}\n')
out = decoder(z)
vision.utils.save_image(out, f"{input_img.split('.')[0]}-out-13bit.jpg")
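As a side note, the packed code zz above is just the 13 binary planes of z weighted by powers of two, so it can be unpacked again. A small illustrative check (not part of the original script):

# Recover the 13 binary planes from the packed code and compare with z.
bit_idx = torch.arange(13, device=zz.device).view(13, 1, 1)
bits = (zz.unsqueeze(0) >> bit_idx) & 1
assert torch.equal(bits.float(), z.squeeze())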
git+https://github.com/EleutherAI/DeeperSpeed.git@eb7f5cff36678625d23db8a8fe78b4a93e5d2c75#egg=deepspeed
torch==1.13.1
tokenizers>=0.13.2
lm_dataformat==0.0.20
ftfy==6.1.1
tensorboardX==2.6
shortuuid==1.0.11
wandb==0.10.28
tiktoken==0.1.2
\ No newline at end of file
########################################################################################################
# The RWKV Language Model - https://github.com/BlinkDL/RWKV-LM
########################################################################################################
import numpy as np
import math, os, sys, types, time, gc
import torch
from src.utils import TOKENIZER
try:
os.environ["CUDA_VISIBLE_DEVICES"] = sys.argv[1]
except:
pass
torch.backends.cudnn.benchmark = True
torch.backends.cudnn.allow_tf32 = True
torch.backends.cuda.matmul.allow_tf32 = True
np.set_printoptions(precision=4, suppress=True, linewidth=200)
args = types.SimpleNamespace()
########################################################################################################
# Step 1: set model & config (use v4 to run your trained-from-scratch models. v4 and v4neo are compatible)
########################################################################################################
args.RUN_DEVICE = "cuda" # 'cuda' // 'cpu' (already fast)
args.FLOAT_MODE = "fp16" # fp16 (good for GPU, does not work for CPU) // fp32 (good for CPU) // bf16 (less accurate, but works for CPU)
# if args.RUN_DEVICE == "cuda":
# os.environ["RWKV_RUN_BACKEND"] = 'nvfuser' # !!!BUGGY!!! wrong output
os.environ["RWKV_JIT_ON"] = '1' # '1' or '0'. very useful for GPU/CPU fp32, but might be harmful for GPU fp16. please benchmark !!!
TOKEN_MODE = "pile"
WORD_NAME = [
"20B_tokenizer.json",
"20B_tokenizer.json",
] # [vocab, vocab] for Pile model
UNKNOWN_CHAR = None
vocab_size = 50277
# Download Pile models: https://huggingface.co/BlinkDL
# or, set MODEL_NAME to your fine-tuned model
# MODEL_NAME = "/fsx/BlinkDL/rwkv-release/RWKV-4-Pile-169M-20220807-8023"
# n_layer = 12
# n_embd = 768
# ctx_len = 1024
# MODEL_NAME = '/fsx/BlinkDL/rwkv-release/RWKV-4-Pile-430M-20220808-8066'
# n_layer = 24
# n_embd = 1024
# ctx_len = 1024
# MODEL_NAME = '/fsx/BlinkDL/HF-MODEL/rwkv-4-pile-1b5/RWKV-4-Pile-1B5-20220903-8040'
# n_layer = 24
# n_embd = 2048
# ctx_len = 1024
# MODEL_NAME = '/fsx/BlinkDL/HF-MODEL/rwkv-4-pile-3b/RWKV-4-Pile-3B-20221008-8023'
# n_layer = 32
# n_embd = 2560
# ctx_len = 1024
MODEL_NAME = '/fsx/BlinkDL/HF-MODEL/rwkv-4-pile-7b/RWKV-4-Pile-7B-20221115-8047'
n_layer = 32
n_embd = 4096
ctx_len = 1024
args.MODEL_NAME = MODEL_NAME
args.n_layer = n_layer
args.n_embd = n_embd
args.ctx_len = ctx_len
args.vocab_size = vocab_size
args.head_qk = 0
args.pre_ffn = 0
args.grad_cp = 0
args.my_pos_emb = 0
os.environ["RWKV_RUN_DEVICE"] = args.RUN_DEVICE
########################################################################################################
# Step 2: set prompt & sampling settings
########################################################################################################
# context = 'A'
# context = "\nIn the"
# context = '\nSugar:'
context = "\nIn a shocking finding, scientist discovered a herd of dragons living in a remote, previously unexplored valley, in Tibet. Even more surprising to the researchers was the fact that the dragons spoke perfect Chinese."
# context = "\n深圳是" # test Chinese
# context = "\n東京は" # test Japanese
# ###### A good prompt for Q&A ######
# context = '''
# Questions & Helpful Answers
# Ask Research Experts
# Question:
# Can penguins fly?
# Full Answer:
# '''
# ###### A good prompt for chatbot ######
# context = '''
# The following is a conversation between a highly knowledgeable and intelligent AI assistant called Bot, and a human user called User. In the following interactions, User and Bot converse in natural language, and Bot always answer User's questions. Bot is very smart, polite and humorous. Bot knows a lot, and always tells the truth. The conversation begins.
# User: who is president of usa?
# Bot: It’s Joe Biden; he was sworn in earlier this year.
# User: french revolution what year
# Bot: It started in 1789, but it lasted 10 years until 1799.
# User: guess i marry who ?
# Bot: Only if you tell me more about yourself - what are your interests?
# User: wat is lhc
# Bot: It’s a large and very expensive piece of science equipment. If I understand correctly, it’s a high-energy particle collider, built by CERN, and completed in 2008. They used it to confirm the existence of the Higgs boson in 2012.
# User:''' # type your question here
NUM_TRIALS = 999
LENGTH_PER_TRIAL = 333
TEMPERATURE = 1.0
top_p = 0.8
top_p_newline = 0.9 # only used in TOKEN_MODE = char
DEBUG_DEBUG = False # True False --> show softmax output
########################################################################################################
print(f'\nUsing {args.RUN_DEVICE.upper()}. Loading {MODEL_NAME}...')
from src.model_run import RWKV_RNN
model = RWKV_RNN(args)
print(f'\nOptimizing speed...')
out, _ = model.forward([187], None)
# print(out)
gc.collect()
torch.cuda.empty_cache()
# input(0)
print(f'\nLoading tokenizer {WORD_NAME}...')
tokenizer = TOKENIZER(WORD_NAME, UNKNOWN_CHAR=UNKNOWN_CHAR)
if TOKEN_MODE == "pile":
assert tokenizer.tokenizer.decode([187]) == '\n'
########################################################################################################
if tokenizer.charMode:
context = tokenizer.refine_context(context)
ctx = [tokenizer.stoi.get(s, tokenizer.UNKNOWN_CHAR) for s in context]
else:
ctx = tokenizer.tokenizer.encode(context)
src_len = len(ctx)
src_ctx = ctx.copy()
print("\nYour prompt has " + str(src_len) + " tokens.")
print(
"Note: currently the first run takes a while if your prompt is long, as we are using RNN to preprocess the prompt. Use GPT to build the hidden state for better speed.\n"
)
time_slot = {}
time_ref = time.time_ns()
def record_time(name):
if name not in time_slot:
time_slot[name] = 1e20
tt = (time.time_ns() - time_ref) / 1e9
if tt < time_slot[name]:
time_slot[name] = tt
init_state = None
init_out = None
state = None
out = None
for TRIAL in range(1 if DEBUG_DEBUG else NUM_TRIALS):
print(("-" * 50) + '\n' + context, end="")
time_ref = time.time_ns()
ctx = src_ctx.copy()
if TRIAL == 0:
for i in range(src_len):
x = ctx[: i + 1]
if i == src_len - 1:
init_out, init_state = model.forward(x, init_state)
else:
init_state = model.forward(x, init_state, preprocess_only=True)
gc.collect()
torch.cuda.empty_cache()
record_time('preprocess')
out_last = src_len
for i in range(src_len, src_len + (1 if DEBUG_DEBUG else LENGTH_PER_TRIAL)):
x = ctx[: i + 1]
x = x[-ctx_len:]
if i == src_len:
out = init_out.clone()
state = init_state.clone()
else:
out, state = model.forward(x, state)
if DEBUG_DEBUG:
print("model", np.array(x), "==>", np.array(out), np.max(out.cpu().numpy()), np.min(out.cpu().numpy()))
if TOKEN_MODE == "pile":
out[0] = -999999999 # disable <|endoftext|>
ttt = tokenizer.sample_logits(
out,
x,
ctx_len,
temperature=TEMPERATURE,
top_p_usual=top_p,
top_p_newline=top_p_newline,
)
ctx += [ttt]
if tokenizer.charMode:
char = tokenizer.itos[ttt]
print(char, end="", flush=True)
else:
char = tokenizer.tokenizer.decode(ctx[out_last:])
if '\ufffd' not in char: # is valid utf8 string?
print(char, end="", flush=True)
out_last = i+1
record_time('total')
# print(f'\n\n{time_slot}\n\n')
print(
f"\n\n--- preprocess {round(time_slot['preprocess'], 2)}s, generation {round(time_slot['total']-time_slot['preprocess'], 2)}s ", end = ''
)
print(("-" * 50) + '\n')
python train.py --load_model "RWKV-4-Pile-1B5-EngChn-test4-20230115.pth" --wandb "" --proj_dir "out" \
--data_file "data/ask_text_document" --data_type "binidx" --vocab_size 50277 \
--ctx_len 1024 --epoch_steps 200 --epoch_count 1000 --epoch_begin 0 --epoch_save 10 \
--micro_bsz 8 --n_layer 24 --n_embd 2048 --pre_ffn 0 --head_qk 0 \
--lr_init 1e-5 --lr_final 1e-5 --warmup_steps 0 --beta1 0.9 --beta2 0.999 --adam_eps 1e-8 \
--accelerator gpu --devices 1 --precision bf16 --strategy deepspeed_stage_2_offload --grad_cp 1
python preprocess_data.py \
--input ../data/data.txt \
--output-prefix ../data/blog \
--vocab ../20B_tokenizer.json \
--dataset-impl mmap \
--tokenizer-type HFTokenizer \
--append-eod
python preprocess_data.py \
--input ../data/ask.json \
--output-prefix ../data/ask \
--vocab ../20B_tokenizer.json \
--dataset-impl mmap \
--tokenizer-type HFTokenizer \
--append-eod
\ No newline at end of file
from lib2to3.pgen2 import token
import os
import torch
import numpy as np
import shutil
import struct
from functools import lru_cache
from itertools import accumulate
def print_rank_0(*message):
pass
# """If distributed is initialized print only on rank 0."""
# if torch.distributed.is_initialized():
# if torch.distributed.get_rank() == 0:
# print(*message, flush=True)
# else:
# print(*message, flush=True)
def _warmup_mmap_file(path):
pass
# with open(path, "rb") as stream:
# while stream.read(100 * 1024 * 1024):
# pass
dtypes = {
1: np.uint8,
2: np.int8,
3: np.int16,
4: np.int32,
5: np.int64,
6: np.float64,  # was np.float (alias of Python float); removed in NumPy 1.24
7: np.double,
8: np.uint16,
}
def code(dtype):
for k in dtypes.keys():
if dtypes[k] == dtype:
return k
raise ValueError(dtype)
def index_file_path(prefix_path):
return prefix_path + ".idx"
def data_file_path(prefix_path):
return prefix_path + ".bin"
class MMapIndexedDataset(torch.utils.data.Dataset):
class Index(object):
_HDR_MAGIC = b"MMIDIDX\x00\x00"
@classmethod
def writer(cls, path, dtype):
class _Writer(object):
def __enter__(self):
self._file = open(path, "wb")
# Write Magic string so we can check the file format when opening it again.
self._file.write(cls._HDR_MAGIC)
# Write version number
# Little endian unsigned 64 Bit integer
self._file.write(struct.pack("<Q", 1))
# Little endian unsigned 8 Bit integer
self._file.write(struct.pack("<B", code(dtype)))
return self
@staticmethod
def _get_pointers(sizes):
dtype_size = dtype().itemsize
address = 0
pointers = []
for size in sizes:
pointers.append(address)
address += size * dtype_size
return pointers
def write(self, sizes, doc_idx):
pointers = self._get_pointers(sizes)
# Little endian unsigned 64 Bit integer
self._file.write(struct.pack("<Q", len(sizes)))
# Little endian unsigned 64 Bit integer
self._file.write(struct.pack("<Q", len(doc_idx)))
sizes = np.array(sizes, dtype=np.int32)
self._file.write(sizes.tobytes(order="C"))
del sizes
pointers = np.array(pointers, dtype=np.int64)
self._file.write(pointers.tobytes(order="C"))
del pointers
doc_idx = np.array(doc_idx, dtype=np.int64)
self._file.write(doc_idx.tobytes(order="C"))
def __exit__(self, exc_type, exc_val, exc_tb):
self._file.close()
return _Writer()
def __init__(self, path, skip_warmup=False):
with open(path, "rb") as stream:
magic_test = stream.read(9)
assert self._HDR_MAGIC == magic_test, (
"Index file doesn't match expected format. "
"Make sure that --dataset-impl is configured properly."
)
# Little endian unsigned 64 Bit integer
version = struct.unpack("<Q", stream.read(8))
assert (1,) == version
# Little endian unsigned 8 Bit integer
(dtype_code,) = struct.unpack("<B", stream.read(1))
self._dtype = dtypes[dtype_code]
self._dtype_size = self._dtype().itemsize
self._len = struct.unpack("<Q", stream.read(8))[0]
self._doc_count = struct.unpack("<Q", stream.read(8))[0]
offset = stream.tell()
if not skip_warmup:
print_rank_0(" warming up index mmap file...")
_warmup_mmap_file(path)
self._bin_buffer_mmap = np.memmap(path, mode="r", order="C")
self._bin_buffer = memoryview(self._bin_buffer_mmap)
print_rank_0(" reading sizes...")
self._sizes = np.frombuffer(
self._bin_buffer, dtype=np.int32, count=self._len, offset=offset
)
print_rank_0(" reading pointers...")
self._pointers = np.frombuffer(
self._bin_buffer,
dtype=np.int64,
count=self._len,
offset=offset + self._sizes.nbytes,
)
print_rank_0(" reading document index...")
self._doc_idx = np.frombuffer(
self._bin_buffer,
dtype=np.int64,
count=self._doc_count,
offset=offset + self._sizes.nbytes + self._pointers.nbytes,
)
def __del__(self):
self._bin_buffer_mmap._mmap.close()
del self._bin_buffer_mmap
@property
def dtype(self):
return self._dtype
@property
def sizes(self):
return self._sizes
@property
def doc_idx(self):
return self._doc_idx
@lru_cache(maxsize=8)
def __getitem__(self, i):
return self._pointers[i], self._sizes[i]
def __len__(self):
return self._len
def __init__(self, path, skip_warmup=False):
super().__init__()
self._path = None
self._index = None
self._bin_buffer = None
self._do_init(path, skip_warmup)
def __getstate__(self):
return self._path
def __setstate__(self, state):
self._do_init(state)
def _do_init(self, path, skip_warmup):
self._path = path
self._index = self.Index(index_file_path(self._path), skip_warmup)
if not skip_warmup:
print_rank_0(" warming up data mmap file...")
_warmup_mmap_file(data_file_path(self._path))
print_rank_0(" creating numpy buffer of mmap...")
self._bin_buffer_mmap = np.memmap(
data_file_path(self._path), mode="r", order="C"
)
print_rank_0(" creating memory view of numpy buffer...")
self._bin_buffer = memoryview(self._bin_buffer_mmap)
def __del__(self):
self._bin_buffer_mmap._mmap.close()
del self._bin_buffer_mmap
del self._index
def __len__(self):
return len(self._index)
# @lru_cache(maxsize=8)
def __getitem__(self, idx):
if isinstance(idx, int):
ptr, size = self._index[idx]
np_array = np.frombuffer(
self._bin_buffer, dtype=self._index.dtype, count=size, offset=ptr
)
return np_array
elif isinstance(idx, slice):
start, stop, step = idx.indices(len(self))
if step != 1:
raise ValueError(
"Slices into indexed_dataset must be contiguous")
ptr = self._index._pointers[start]
sizes = self._index._sizes[idx]
offsets = list(accumulate(sizes))
total_size = sum(sizes)
np_array = np.frombuffer(
self._bin_buffer, dtype=self._index.dtype, count=total_size, offset=ptr
)
sents = np.split(np_array, offsets[:-1])
return sents
def get(self, idx, offset=0, length=None):
"""Retrieves a single item from the dataset with the option to only
return a portion of the item.
get(idx) is the same as [idx] but get() does not support slicing.
"""
ptr, size = self._index[idx]
if length is None:
length = size - offset
ptr += offset * np.dtype(self._index.dtype).itemsize
np_array = np.frombuffer(
self._bin_buffer, dtype=self._index.dtype, count=length, offset=ptr
)
return np_array
@property
def sizes(self):
return self._index.sizes
@property
def doc_idx(self):
return self._index.doc_idx
def get_doc_idx(self):
return self._index._doc_idx
def set_doc_idx(self, doc_idx_):
self._index._doc_idx = doc_idx_
@property
def supports_prefetch(self):
return False
@staticmethod
def exists(path):
return os.path.exists(index_file_path(path)) and os.path.exists(
data_file_path(path)
)
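For orientation, a minimal usage sketch of MMapIndexedDataset; the path below matches the --data_file used in the training command earlier and is only an assumption about where the preprocessed .bin/.idx pair lives:

# Reads the .bin / .idx pair written by preprocess_data.py (path assumed).
ds = MMapIndexedDataset("data/ask_text_document")
print(len(ds))                              # number of sequences (documents) in the index
print(ds.sizes[:5])                         # token counts of the first sequences
print(ds.get(idx=0, offset=0, length=16))   # first 16 token ids of sequence 0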
########################################################################################################
# The RWKV Language Model - https://github.com/BlinkDL/RWKV-LM
########################################################################################################
import json, math, random, os, sys
import numpy as np
import torch
from torch.utils.data import Dataset
from pytorch_lightning.utilities import rank_zero_info
from .binidx import MMapIndexedDataset
from .utils import MaybeIsPrime
class MyDataset(Dataset):
def __init__(self, args):
self.args = args
if args.data_type == "binidx":
self.vocab_size = args.vocab_size
rank_zero_info(f"Current vocab size = {self.vocab_size} (make sure it's correct)")
if args.data_file.endswith('/'):
d_all = []
for p in os.listdir(args.data_file):
if p.endswith(".idx"):
d_all += [p[:-4]]
d_all.sort()
rank_zero_info(d_all)
exit(0)
else:
self.data = MMapIndexedDataset(args.data_file)
self.data_size = len(self.data._bin_buffer) // 2
rank_zero_info(f"Data has {self.data_size} tokens.")
if args.my_qa_mask > 0:
self.data_pile = MMapIndexedDataset('/fsx/BlinkDL/pile/pile_20B_tokenizer_text_document')
self.data_pile_size = len(self.data_pile._bin_buffer) // 2
if args.my_pile_stage > 0:
# assert self.data_size == 332115325534 and self.vocab_size == 50277
self.samples_per_epoch = args.epoch_steps * args.real_bsz
assert self.samples_per_epoch == 40320
rank_zero_info(f"########## Pile 20b-tokenized stage {args.my_pile_stage} ##########")
dataset_slot = self.data_size // args.ctx_len
assert MaybeIsPrime(args.magic_prime)
assert args.magic_prime % 3 == 2
assert args.magic_prime / dataset_slot > 0.99 and args.magic_prime / dataset_slot <= 1
elif args.data_type == "numpy":
self.data = np.load(args.data_file).astype("int")
self.vocab_size = args.vocab_size
rank_zero_info("Current vocab size =", self.vocab_size, "(make sure it's correct)")
self.data_size = len(self.data)
rank_zero_info(f"Data has {self.data_size} tokens.")
elif args.data_type == "uint16":
self.data = np.fromfile(args.data_file, dtype=np.uint16).astype("int32").reshape(-1, args.my_sample_len)
self.vocab_size = args.vocab_size
rank_zero_info("Current vocab size =", self.vocab_size, "(make sure it's correct)")
self.data_size = self.data.shape[0]
rank_zero_info(f"Data has {self.data_size} samples.")
elif args.data_type == "wds_img":
self.vocab_size = -1
self.data_size = -1
self.data = None
self.error_count = 0
else:
if args.data_type == "dummy":
rank_zero_info("Building dummy data...")
self.data = ""
for i in range(100000):
aa = (i) % 10000
bb = (i * i) % 10000
cc = aa + bb
self.data += f".{aa}+{bb}={cc}."
else:
self.data = open(args.data_file, "r", encoding=args.data_type).read()
rank_zero_info("Building token list...")
unique = sorted(list(set(self.data)))
self.vocab_size = len(unique)
# rank_zero_info()
# for u in unique:
# print(u, end=' ')
# rank_zero_info('\n\n')
xx = 0
xxObj = {}
for u in unique:
xxObj[xx] = u
xx += 1
with open(f"{args.proj_dir}/vocab.json", "w", encoding="utf-16le") as vocab_file:
vocab_file.write(json.dumps(xxObj, ensure_ascii=False))
self.data_size = len(self.data)
rank_zero_info(f"Data has {self.data_size} tokens, {self.vocab_size} vocab size.")
self.stoi = {ch: i for i, ch in enumerate(unique)}
self.itos = {i: ch for i, ch in enumerate(unique)}
def __len__(self):
return self.args.epoch_steps * self.args.micro_bsz
def __getitem__(self, idx):
args = self.args
rank = self.global_rank
epoch = self.real_epoch
world_size = self.world_size
# print(f"epoch {epoch} idx {idx} rank {rank}/{world_size}")
if args.data_type == "wds_img":
def init_wds(self, bias=0):
def identity(x):
return x
import webdataset as wds
import torchvision.transforms as transforms
# img_transform = transforms.Compose(
# [transforms.CenterCrop(256)]
# )
img_transform = transforms.Compose([
transforms.CenterCrop(512),
transforms.Resize((args.my_img_size))
])
self.data_raw = wds.WebDataset(args.data_file, resampled=True).shuffle(10000, initial=1000, rng=random.Random(epoch*100000+rank+bias*1e9)).decode("torchrgb").to_tuple("jpg", "json", "txt").map_tuple(img_transform, identity, identity)
for pp in self.data_raw.pipeline:
if 'Resampled' in str(pp):
pp.deterministic = True
def worker_seed():
return rank*100000+epoch+bias*1e9
pp.worker_seed = worker_seed
self.data = iter(self.data_raw)
# print(f"WebDataset loaded for rank {rank} epoch {epoch}")
if self.data is None:
init_wds(self)
trial = 0
while trial < 10:
try:
dd = next(self.data) # jpg, json, txt
break
except:
print(f'[dataloader error - epoch {epoch} rank {rank} - trying a new shuffle]')
self.error_count += 1
init_wds(self, self.error_count)
trial += 1
pass
# print(f"epoch {epoch} idx {idx} rank {rank}/{world_size} {dd[2]}")
# with open(f"sample_{rank}.txt", "a", encoding="utf-8") as tmp:
# tmp.write(f"epoch {epoch} idx {idx} rank {rank}/{world_size} {int(dd[1]['key'])}\n")
return dd[0], dd[2]
else:
if args.data_type == "uint16":
i = np.random.randint(0, self.data_size-1)
dix = self.data[i]
x = torch.tensor(dix[:-1], dtype=torch.long)
y = torch.tensor(dix[1:], dtype=torch.long)
else:
ctx_len = args.ctx_len
req_len = ctx_len + 1
magic_prime = args.magic_prime
data = self.data
if args.my_pile_stage > 0:
ii = 1 + epoch * self.samples_per_epoch + (idx * world_size) + rank
if args.my_qa_mask > 0:
ii_orig = ii
if ii % 2 == 0:
ii = (ii // 2) * args.magic_prime
magic_prime = 324331313
data = self.data_pile
else:
ii = ii // 2
factor = (math.sqrt(5) - 1) / 2
factor = int(magic_prime * factor)
i = ((factor * ii * ii * ii) % magic_prime) * ctx_len
if (args.my_qa_mask == 0) or (data == self.data_pile):
i = i + args.my_pile_shift
# print(f"epoch {epoch} idx {idx} rank {rank}/{world_size} ii {ii} pos {round(i / self.data_size, 3)}")
else:
# cheat: pick a random spot in dataset
i = np.random.randint(0, self.data_size - req_len)
if args.data_type == "binidx":
dix = data.get(idx=0, offset=i, length=req_len).astype(int)
elif args.data_type == "numpy":
dix = data[i : i + req_len]
else:
dix = [self.stoi[s] for s in data[i : i + req_len]]
if args.my_qa_mask == 1:
if data == self.data_pile:
z = [1] * ctx_len
else:
z = [0] * ctx_len
z_sum = 0
isGood = False
for i in range(3, ctx_len):
if dix[i] == 27 and dix[i-1] == 34 and dix[i-2] == 187 and dix[i-3] == 187:
isGood = True
if dix[i] == 0:
isGood = False
if isGood:
z[i] = 1
z_sum += 1
if z_sum == 0:
z = [1] * ctx_len
i = np.random.randint(0, self.data_pile_size - req_len)
dix = self.data_pile.get(idx=0, offset=i, length=req_len).astype(int)
z = torch.tensor(z, dtype=torch.bfloat16)
x = torch.tensor(dix[:-1], dtype=torch.long)
y = torch.tensor(dix[1:], dtype=torch.long)
# if ii_orig < 50:
# # if rank == 1:
# print('rank', rank, 'i', ii_orig, ii, i, 'x', x[:5], '...', x[-5:])
# else:
# exit(0)
if args.my_qa_mask == 1:
return x, y, z
return x, y
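The line i = ((factor * ii * ii * ii) % magic_prime) * ctx_len relies on ii -> factor * ii^3 (mod p) being a permutation of the slots, which holds for a prime p with p % 3 == 2 (then 3 is coprime to p - 1, so cubing is a bijection mod p); this is exactly what the magic_prime asserts above enforce. A small illustrative check with a toy prime:

import math

magic_prime = 2063  # toy prime with magic_prime % 3 == 2, for demonstration only
factor = int(magic_prime * (math.sqrt(5) - 1) / 2)
visited = {(factor * ii * ii * ii) % magic_prime for ii in range(magic_prime)}
assert len(visited) == magic_prime  # every ctx_len-aligned slot is hit exactly once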
This diff has been collapsed.
########################################################################################################
# The RWKV Language Model - https://github.com/BlinkDL/RWKV-LM
########################################################################################################
import numpy as np
import os, math, gc
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision as vision
import pytorch_lightning as pl
from pytorch_lightning.utilities import rank_zero_info, rank_zero_only
from pytorch_lightning.strategies import DeepSpeedStrategy
import deepspeed
from deepspeed.ops.adam import DeepSpeedCPUAdam, FusedAdam
# from pytorch_msssim import MS_SSIM
def __nop(ob):
return ob
MyModule = torch.jit.ScriptModule
# MyFunction = __nop
MyFunction = torch.jit.script_method
import clip
from transformers import CLIPModel
class L2pooling(nn.Module):
def __init__(self, filter_size=5, stride=2, channels=None, pad_off=0):
super(L2pooling, self).__init__()
self.padding = (filter_size - 2) // 2
self.stride = stride
self.channels = channels
a = np.hanning(filter_size)[1:-1]
g = torch.Tensor(a[:, None] * a[None, :])
g = g / torch.sum(g)
self.register_buffer(
"filter", g[None, None, :, :].repeat((self.channels, 1, 1, 1))
)
def forward(self, input):
input = input**2
out = F.conv2d(
input,
self.filter,
stride=self.stride,
padding=self.padding,
groups=input.shape[1],
)
return (out + 1e-12).sqrt()
class DISTS(torch.nn.Module):
def __init__(self, load_weights=True):
super(DISTS, self).__init__()
vgg_pretrained_features = vision.models.vgg16(
weights="VGG16_Weights.IMAGENET1K_V1"
).features
self.stage1 = torch.nn.Sequential()
self.stage2 = torch.nn.Sequential()
self.stage3 = torch.nn.Sequential()
self.stage4 = torch.nn.Sequential()
self.stage5 = torch.nn.Sequential()
for x in range(0, 4):
self.stage1.add_module(str(x), vgg_pretrained_features[x])
self.stage2.add_module(str(4), L2pooling(channels=64))
for x in range(5, 9):
self.stage2.add_module(str(x), vgg_pretrained_features[x])
self.stage3.add_module(str(9), L2pooling(channels=128))
for x in range(10, 16):
self.stage3.add_module(str(x), vgg_pretrained_features[x])
self.stage4.add_module(str(16), L2pooling(channels=256))
for x in range(17, 23):
self.stage4.add_module(str(x), vgg_pretrained_features[x])
self.stage5.add_module(str(23), L2pooling(channels=512))
for x in range(24, 30):
self.stage5.add_module(str(x), vgg_pretrained_features[x])
self.register_buffer(
"mean", torch.tensor([0.485, 0.456, 0.406]).view(1, -1, 1, 1)
)
self.register_buffer(
"std", torch.tensor([0.229, 0.224, 0.225]).view(1, -1, 1, 1)
)
self.chns = [3, 64, 128, 256, 512, 512]
self.register_buffer(
"alpha", nn.Parameter(torch.randn(1, sum(self.chns), 1, 1))
)
self.register_buffer("beta", nn.Parameter(torch.randn(1, sum(self.chns), 1, 1)))
self.alpha.data.normal_(0.1, 0.01)
self.beta.data.normal_(0.1, 0.01)
weights = torch.load("test/DISTS_weights.pt")
self.alpha.data = weights["alpha"]
self.beta.data = weights["beta"]
for param in self.parameters():
param.requires_grad = False
def forward_once(self, x):
h = (x - self.mean) / self.std
h = self.stage1(h)
h_relu1_2 = h
h = self.stage2(h)
h_relu2_2 = h
h = self.stage3(h)
h_relu3_3 = h
h = self.stage4(h)
h_relu4_3 = h
h = self.stage5(h)
h_relu5_3 = h
return [x, h_relu1_2, h_relu2_2, h_relu3_3, h_relu4_3, h_relu5_3]
def forward(self, x, y, require_grad=False, batch_average=False):
if require_grad:
feats0 = self.forward_once(x)
feats1 = self.forward_once(y)
else:
with torch.no_grad():
feats0 = self.forward_once(x)
feats1 = self.forward_once(y)
dist1 = 0
dist2 = 0
c1 = 1e-6
c2 = 1e-6
w_sum = self.alpha.sum() + self.beta.sum()
alpha = torch.split(self.alpha / w_sum, self.chns, dim=1)
beta = torch.split(self.beta / w_sum, self.chns, dim=1)
for k in range(len(self.chns)):
x_mean = feats0[k].mean([2, 3], keepdim=True)
y_mean = feats1[k].mean([2, 3], keepdim=True)
S1 = (2 * x_mean * y_mean + c1) / (x_mean**2 + y_mean**2 + c1)
dist1 = dist1 + (alpha[k] * S1).sum(1, keepdim=True)
x_var = ((feats0[k] - x_mean) ** 2).mean([2, 3], keepdim=True)
y_var = ((feats1[k] - y_mean) ** 2).mean([2, 3], keepdim=True)
xy_cov = (feats0[k] * feats1[k]).mean(
[2, 3], keepdim=True
) - x_mean * y_mean
S2 = (2 * xy_cov + c2) / (x_var + y_var + c2)
dist2 = dist2 + (beta[k] * S2).sum(1, keepdim=True)
score = 1 - (dist1 + dist2).squeeze()
if batch_average:
return score.mean()
else:
return score
class ToBinary(torch.autograd.Function):
@staticmethod
def forward(ctx, x):#, noise_scale):
# if noise_scale > 0:
# noise_min = 0.5 - noise_scale / 2
# noise_max = 0.5 + noise_scale / 2
# return torch.floor(x + torch.empty_like(x).uniform_(noise_min, noise_max))
# else:
return torch.floor(x + 0.5) # no need for noise when we have plenty of data
@staticmethod
def backward(ctx, grad_output):
return grad_output.clone()#, None
########################################################################################################
class R_ENCODER(MyModule):
def __init__(self, args):
super().__init__()
self.args = args
dd = 8
self.Bxx = nn.BatchNorm2d(dd*64)
self.CIN = nn.Conv2d(3, dd, kernel_size=3, padding=1)
self.Cx0 = nn.Conv2d(dd, 32, kernel_size=3, padding=1)
self.Cx1 = nn.Conv2d(32, dd, kernel_size=3, padding=1)
self.B00 = nn.BatchNorm2d(dd*4)
self.C00 = nn.Conv2d(dd*4, 256, kernel_size=3, padding=1)
self.C01 = nn.Conv2d(256, dd*4, kernel_size=3, padding=1)
self.C02 = nn.Conv2d(dd*4, 256, kernel_size=3, padding=1)
self.C03 = nn.Conv2d(256, dd*4, kernel_size=3, padding=1)
self.B10 = nn.BatchNorm2d(dd*16)
self.C10 = nn.Conv2d(dd*16, 256, kernel_size=3, padding=1)
self.C11 = nn.Conv2d(256, dd*16, kernel_size=3, padding=1)
self.C12 = nn.Conv2d(dd*16, 256, kernel_size=3, padding=1)
self.C13 = nn.Conv2d(256, dd*16, kernel_size=3, padding=1)
self.B20 = nn.BatchNorm2d(dd*64)
self.C20 = nn.Conv2d(dd*64, 256, kernel_size=3, padding=1)
self.C21 = nn.Conv2d(256, dd*64, kernel_size=3, padding=1)
self.C22 = nn.Conv2d(dd*64, 256, kernel_size=3, padding=1)
self.C23 = nn.Conv2d(256, dd*64, kernel_size=3, padding=1)
# self.B21 = nn.BatchNorm2d(dd*64)
# self.C24 = nn.Conv2d(dd*64, 256, kernel_size=3, padding=1)
# self.C25 = nn.Conv2d(256, dd*64, kernel_size=3, padding=1)
# self.C26 = nn.Conv2d(dd*64, 256, kernel_size=3, padding=1)
# self.C27 = nn.Conv2d(256, dd*64, kernel_size=3, padding=1)
self.COUT = nn.Conv2d(dd*64, args.my_img_bit, kernel_size=3, padding=1)
@MyFunction
def forward(self, img):
ACT = F.mish
x = self.CIN(img)
xx = self.Bxx(F.pixel_unshuffle(x, 8))
x = x + self.Cx1(ACT(self.Cx0(x)))
x = F.pixel_unshuffle(x, 2)
x = x + self.C01(ACT(self.C00(ACT(self.B00(x)))))
x = x + self.C03(ACT(self.C02(x)))
x = F.pixel_unshuffle(x, 2)
x = x + self.C11(ACT(self.C10(ACT(self.B10(x)))))
x = x + self.C13(ACT(self.C12(x)))
x = F.pixel_unshuffle(x, 2)
x = x + self.C21(ACT(self.C20(ACT(self.B20(x)))))
x = x + self.C23(ACT(self.C22(x)))
# x = x + self.C25(ACT(self.C24(ACT(self.B21(x)))))
# x = x + self.C27(ACT(self.C26(x)))
x = self.COUT(x + xx)
return torch.sigmoid(x)
########################################################################################################
class R_DECODER(MyModule):
def __init__(self, args):
super().__init__()
self.args = args
dd = 8
self.CIN = nn.Conv2d(args.my_img_bit, dd*64, kernel_size=3, padding=1)
self.B00 = nn.BatchNorm2d(dd*64)
self.C00 = nn.Conv2d(dd*64, 256, kernel_size=3, padding=1)
self.C01 = nn.Conv2d(256, dd*64, kernel_size=3, padding=1)
self.C02 = nn.Conv2d(dd*64, 256, kernel_size=3, padding=1)
self.C03 = nn.Conv2d(256, dd*64, kernel_size=3, padding=1)
# self.B01 = nn.BatchNorm2d(dd*64)
# self.C04 = nn.Conv2d(dd*64, 256, kernel_size=3, padding=1)
# self.C05 = nn.Conv2d(256, dd*64, kernel_size=3, padding=1)
# self.C06 = nn.Conv2d(dd*64, 256, kernel_size=3, padding=1)
# self.C07 = nn.Conv2d(256, dd*64, kernel_size=3, padding=1)
self.B10 = nn.BatchNorm2d(dd*16)
self.C10 = nn.Conv2d(dd*16, 256, kernel_size=3, padding=1)
self.C11 = nn.Conv2d(256, dd*16, kernel_size=3, padding=1)
self.C12 = nn.Conv2d(dd*16, 256, kernel_size=3, padding=1)
self.C13 = nn.Conv2d(256, dd*16, kernel_size=3, padding=1)
self.B20 = nn.BatchNorm2d(dd*4)
self.C20 = nn.Conv2d(dd*4, 256, kernel_size=3, padding=1)
self.C21 = nn.Conv2d(256, dd*4, kernel_size=3, padding=1)
self.C22 = nn.Conv2d(dd*4, 256, kernel_size=3, padding=1)
self.C23 = nn.Conv2d(256, dd*4, kernel_size=3, padding=1)
self.Cx0 = nn.Conv2d(dd, 32, kernel_size=3, padding=1)
self.Cx1 = nn.Conv2d(32, dd, kernel_size=3, padding=1)
self.COUT = nn.Conv2d(dd, 3, kernel_size=3, padding=1)
@MyFunction
def forward(self, code):
ACT = F.mish
x = self.CIN(code)
x = x + self.C01(ACT(self.C00(ACT(self.B00(x)))))
x = x + self.C03(ACT(self.C02(x)))
# x = x + self.C05(ACT(self.C04(ACT(self.B01(x)))))
# x = x + self.C07(ACT(self.C06(x)))
x = F.pixel_shuffle(x, 2)
x = x + self.C11(ACT(self.C10(ACT(self.B10(x)))))
x = x + self.C13(ACT(self.C12(x)))
x = F.pixel_shuffle(x, 2)
x = x + self.C21(ACT(self.C20(ACT(self.B20(x)))))
x = x + self.C23(ACT(self.C22(x)))
x = F.pixel_shuffle(x, 2)
x = x + self.Cx1(ACT(self.Cx0(x)))
x = self.COUT(x)
return torch.sigmoid(x)
########################################################################################################
def cosine_loss(x, y):
x = F.normalize(x, dim=-1)
y = F.normalize(y, dim=-1)
return 1 - torch.einsum('ij,ij->i',[x,y])
class RWKV_IMG(pl.LightningModule):
def __init__(self, args):
super().__init__()
self.args = args
self.encoder = R_ENCODER(args)
self.decoder = R_DECODER(args)
self.clip_model = None
clip_name = args.my_img_clip
if clip_name == 'B32':
clip_name = 'ViT-B/32'
elif clip_name == 'B16':
clip_name = 'ViT-B/16'
elif clip_name == 'L14':
clip_name = 'ViT-L/14'
elif clip_name == 'OB32':
clip_name = "laion/CLIP-ViT-B-32-laion2B-s34B-b79K"
self.clip_model = CLIPModel.from_pretrained(clip_name)
self.clip_model.encode_image = self.clip_model.get_image_features
if self.clip_model is None:
self.clip_model, _ = clip.load(clip_name, jit = True)
self.register_buffer(
"clip_mean", torch.tensor([0.48145466, 0.4578275, 0.40821073]).view(1, 3, 1, 1)
)
self.register_buffer(
"clip_std", torch.tensor([0.26862954, 0.26130258, 0.27577711]).view(1, 3, 1, 1)
)
for n, p in self.named_parameters():
if 'clip_model' in n:
p.requires_grad = False
self.loss_dists = DISTS()
# self.loss_ssim = MS_SSIM(data_range=1, size_average=True, channel=3)
def configure_optimizers(self):
args = self.args
optim_groups = [
{"params": [p for n, p in self.named_parameters()], "weight_decay": 0.0},
]
if self.deepspeed_offload:
return DeepSpeedCPUAdam(
optim_groups,
lr=self.args.lr_init,
betas=self.args.betas,
eps=self.args.adam_eps,
bias_correction=True,
adamw_mode=False,
weight_decay=0,
amsgrad=False,
)
return FusedAdam(
optim_groups,
lr=self.args.lr_init,
betas=self.args.betas,
eps=self.args.adam_eps,
bias_correction=True,
adam_w_mode=False,
weight_decay=0,
amsgrad=False,
)
# return ZeroOneAdam(optim_groups, lr=self.args.lr_init, betas=self.args.betas, eps=self.args.adam_eps, bias_correction=True, weight_decay=0, amsgrad=False, cuda_aware=False)
@property
def deepspeed_offload(self) -> bool:
strategy = self.trainer.strategy
if isinstance(strategy, DeepSpeedStrategy):
config = strategy.config["zero_optimization"]
return config.get("offload_optimizer") or config.get("offload_param")
return False
def forward(self, img):
z = self.encoder(img)
z = ToBinary.apply(z)#, self.args.my_img_noise_scale)
out = self.decoder(z)
return out
def training_step(self, batch, batch_idx):
args = self.args
img, txt = batch
out = self(img)
if self.trainer.is_global_zero:
if (self.trainer.global_step + 1) % (100 * int(args.devices)) == 0:
img_dir = f"test/image_model/{args.run_name}"
if not os.path.exists(img_dir):
os.makedirs(img_dir)
vision.utils.save_image(
img[:4], f"{img_dir}/{self.trainer.global_step}-src.jpg"#, padding=0
)
vision.utils.save_image(
out[:4], f"{img_dir}/{self.trainer.global_step}-out.jpg"#, padding=0
)
# loss_ssim = 1 - self.loss_ssim(out, img)
loss_dists = self.loss_dists(out, img, require_grad=True, batch_average=True)
iii = self.clip_model.encode_image((img - self.clip_mean) / self.clip_std)
ooo = self.clip_model.encode_image((out - self.clip_mean) / self.clip_std)
loss_clip = torch.mean(cosine_loss(iii, ooo))
if args.my_img_l1_scale > 0:
loss_l1 = F.l1_loss(out, img)
return loss_dists + loss_clip * args.my_img_clip_scale + loss_l1 * args.my_img_l1_scale
else:
return loss_dists + loss_clip * args.my_img_clip_scale
def training_step_end(self, batch_parts):
all = self.all_gather(batch_parts)
if self.trainer.is_global_zero:
self.trainer.my_loss_all = all
def generate_init_weight(self):
print(
f"""
############################################################################
#
# Init model weight (slow for large models)...
#
############################################################################
"""
)
m = {}
for n in self.state_dict():
scale = 1
p = self.state_dict()[n]
shape = p.shape
ss = n.split('.')
# if ss[0] in ['encoder', 'decoder']:
# if ss[2] == 'bias':
# scale = 0
# # elif n == 'encoder.CIN.weight':
# # nn.init.dirac_(p)
# else:
# try:
# if ss[1][0] == 'C' and (int(ss[1][2]) % 2 == 1):
# scale = 0
# except:
# pass
# m[n] = p * scale
m[n] = p
m[n] = m[n].cpu()
if os.environ["RWKV_FLOAT_MODE"] == "fp16":
m[n] = m[n].half()
elif os.environ["RWKV_FLOAT_MODE"] == "bf16":
m[n] = m[n].bfloat16()
gc.collect()
torch.cuda.empty_cache()
return m
########################################################################################################
# The RWKV Language Model - https://github.com/BlinkDL/RWKV-LM
########################################################################################################
import types
import torch
import math, os, gc
from torch.nn import functional as F
import torch.nn as nn
from typing import List, Dict
MyModule = nn.Module
def __nop(ob):
return ob
MyFunction = __nop
# # try torchdynamo
# import torchdynamo
# MyFunction = torchdynamo.optimize(os.environ["RWKV_RUN_BACKEND"]) # !!!BUGGY!!! wrong output
# try torch jit --> faster for fp32, slower for fp16 (why?)
if os.environ["RWKV_JIT_ON"] == "1":
MyModule = torch.jit.ScriptModule
MyFunction = torch.jit.script_method
RWKV_HEAD_QK_DIM = 0
print(f'\nRWKV_HEAD_QK_DIM {RWKV_HEAD_QK_DIM} RWKV_JIT_ON {os.environ["RWKV_JIT_ON"]}\n')
DEBUG_TIME = False # True False - show trained time-coeffs
RWKV_RESCALE_LAYER = 6 # set x=x/2 every X layer
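# Worked example of the rescaling trick (illustrative comment): with RWKV_RESCALE_LAYER = 6 and
# a 24-layer model, the loader below divides att.output.weight and ffn.value.weight of block i
# by 2 ** (i // 6) (blocks 0-5 unscaled, 6-11 halved, 12-17 quartered, 18-23 divided by 8),
# while forward() halves the hidden state after every 6th block. The two adjustments cancel out
# mathematically but keep fp16 activations from overflowing in deep models.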
############################################################################################################
class RWKV_RNN(MyModule):
def __init__(self, args):
super().__init__()
self.args = args
self.FLOAT_MODE = args.FLOAT_MODE
self.RUN_DEVICE = args.RUN_DEVICE
with torch.no_grad():
w = torch.load(args.MODEL_NAME + '.pth', map_location='cpu')
# refine weights and send to correct device
keys = list(w.keys())
if 'pos_emb_x' in keys:
w['pos_emb'] = (w['pos_emb_x'] + w['pos_emb_y']).reshape(args.ctx_len+1, -1)[:-1,:]
keys = list(w.keys())
print_need_newline = False
for x in keys:
block_id = 0
if 'blocks.' in x:
block_id = int(x.split('.')[1])
if 'att.output.weight' in x:
w[x] = w[x] / (2 ** int(block_id // RWKV_RESCALE_LAYER))
if 'ffn.value.weight' in x:
w[x] = w[x] / (2 ** int(block_id // RWKV_RESCALE_LAYER))
if '.time_' in x:
w[x] = w[x].squeeze()
if DEBUG_TIME:
print(x, w[x].numpy())
if '.time_decay' in x:
w[x] = w[x].float()
w[x] = -torch.exp(w[x])
elif '.time_first' in x:
w[x] = w[x].float()
else:
if self.FLOAT_MODE == "fp32":
w[x] = w[x].float()
elif self.FLOAT_MODE == "bf16":
w[x] = w[x].bfloat16()
elif self.FLOAT_MODE == "fp16":
w[x] = w[x].half()
w[x].requires_grad = False
if args.RUN_DEVICE == 'cuda' and x != 'emb.weight':
w[x] = w[x].cuda()
if ('blocks.' not in x) or ('blocks.0.' in x):
if print_need_newline:
print('\n', end = '')
print_need_newline = False
print(x.ljust(40), str(w[x].dtype).replace('torch.', '').ljust(10), w[x].device)
else:
print_need_newline = True
print('.', end = '', flush = True)
# store weights in self.w
keys = list(w.keys())
self.w = types.SimpleNamespace()
for x in keys:
xx = x.split('.')
here = self.w
for i in range(len(xx)):
if xx[i].isdigit():
ii = int(xx[i])
if ii not in here:
here[ii] = types.SimpleNamespace()
here = here[ii]
else:
if i == len(xx) - 1:
setattr(here, xx[i], w[x])
elif not hasattr(here, xx[i]):
if xx[i+1].isdigit():
setattr(here, xx[i], {})
else:
setattr(here, xx[i], types.SimpleNamespace())
here = getattr(here, xx[i])
self.eval()
gc.collect()
torch.cuda.empty_cache()
def LN(self, x, w):
return F.layer_norm(x, (self.args.n_embd,), weight=w.weight, bias=w.bias)
# state[] 0=ffn_xx 1=att_xx 2=att_aa 3=att_bb 4=att_pp
@MyFunction
def FF(self, x, state, i:int, time_mix_k, time_mix_r, kw, vw, rw):
if self.FLOAT_MODE == "bf16":
xk = x * time_mix_k + state[5*i+0].type(torch.bfloat16) * (1 - time_mix_k)
xr = x * time_mix_r + state[5*i+0].type(torch.bfloat16) * (1 - time_mix_r)
state[5*i+0] = x.float()
elif self.FLOAT_MODE == "fp16":
xk = x * time_mix_k + state[5*i+0].half() * (1 - time_mix_k)
xr = x * time_mix_r + state[5*i+0].half() * (1 - time_mix_r)
state[5*i+0] = x.float()
else:
xk = x * time_mix_k + state[5*i+0] * (1 - time_mix_k)
xr = x * time_mix_r + state[5*i+0] * (1 - time_mix_r)
state[5*i+0] = x
r = torch.sigmoid(rw @ xr)
k = torch.square(torch.relu(kw @ xk))
kv = vw @ k
return r * kv
@MyFunction
def SA(self, x, state, i:int, time_mix_k, time_mix_v, time_mix_r, time_first, time_decay, kw, vw, rw, ow):
if self.FLOAT_MODE == "bf16":
xk = x * time_mix_k + state[5*i+1].type(torch.bfloat16) * (1 - time_mix_k)
xv = x * time_mix_v + state[5*i+1].type(torch.bfloat16) * (1 - time_mix_v)
xr = x * time_mix_r + state[5*i+1].type(torch.bfloat16) * (1 - time_mix_r)
state[5*i+1] = x.float()
elif self.FLOAT_MODE == "fp16":
xk = x * time_mix_k + state[5*i+1].half() * (1 - time_mix_k)
xv = x * time_mix_v + state[5*i+1].half() * (1 - time_mix_v)
xr = x * time_mix_r + state[5*i+1].half() * (1 - time_mix_r)
state[5*i+1] = x.float()
else:
xk = x * time_mix_k + state[5*i+1] * (1 - time_mix_k)
xv = x * time_mix_v + state[5*i+1] * (1 - time_mix_v)
xr = x * time_mix_r + state[5*i+1] * (1 - time_mix_r)
state[5*i+1] = x
r = torch.sigmoid(rw @ xr)
k = kw @ xk
v = vw @ xv
if '16' in self.FLOAT_MODE:
kk = k.float()
vv = v.float()
else:
kk = k
vv = v
aa = state[5*i+2]
bb = state[5*i+3]
pp = state[5*i+4]
ww = time_first + kk
p = torch.maximum(pp, ww)
e1 = torch.exp(pp - p)
e2 = torch.exp(ww - p)
a = e1 * aa + e2 * vv
b = e1 * bb + e2
ww = pp + time_decay
p = torch.maximum(ww, kk)
e1 = torch.exp(ww - p)
e2 = torch.exp(kk - p)
state[5*i+2] = e1 * aa + e2 * vv
state[5*i+3] = e1 * bb + e2
state[5*i+4] = p
if self.FLOAT_MODE == "bf16":
wkv = (a / b).type(torch.bfloat16)
elif self.FLOAT_MODE == "fp16":
wkv = (a / b).half()
else:
wkv = a / b
return ow @ (r * wkv)
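    # Numerical note on SA() above (comment only): conceptually the attention output is
    #   wkv_t = (sum_{i<t} e^{(i-t+1)*w + k_i} * v_i + e^{u + k_t} * v_t)
    #           / (sum_{i<t} e^{(i-t+1)*w + k_i}      + e^{u + k_t})
    # with u = time_first and time_decay = -w (forced negative via -exp() at load time).
    # Rather than forming these sums directly, the state keeps (aa, bb, pp), where pp is the
    # largest exponent seen so far and aa / bb are the numerator / denominator divided by
    # exp(pp); each torch.maximum / torch.exp pair above re-bases onto the new maximum so the
    # exponentials stay <= 1 and never overflow, even though a and b are kept in fp32.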
def forward(self, ctx, state, preprocess_only = False):
with torch.no_grad():
w = self.w
args = self.args
x = w.emb.weight[ctx[-1]]
if self.RUN_DEVICE == 'cuda':
x = x.cuda()
try:
pos_emb = w.pos_emb[len(ctx)-1]
x = x + pos_emb
except:
pass
            if state is None:
state = torch.zeros(args.n_layer * 5, args.n_embd, device=self.RUN_DEVICE)
for i in range(args.n_layer):
state[5*i+4] -= 1e30
for i in range(args.n_layer):
if i == 0:
x = self.LN(x, w.blocks[i].ln0)
ww = w.blocks[i].att
x = x + self.SA(self.LN(x, w.blocks[i].ln1), state, i,
ww.time_mix_k, ww.time_mix_v, ww.time_mix_r, ww.time_first, ww.time_decay,
ww.key.weight, ww.value.weight, ww.receptance.weight, ww.output.weight)
ww = w.blocks[i].ffn
x = x + self.FF(self.LN(x, w.blocks[i].ln2), state, i,
ww.time_mix_k, ww.time_mix_r,
ww.key.weight, ww.value.weight, ww.receptance.weight)
if (i+1) % RWKV_RESCALE_LAYER == 0:
x = x / 2
if preprocess_only:
return state
x = self.LN(x, w.ln_out)
x = w.head.weight @ x
return x.float(), state
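# Minimal usage sketch (comment only; `args` and `tokens` are placeholders for the values built
# by the chat / generation scripts): the RNN consumes one token per call and threads its
# recurrent state explicitly.
#
#   model = RWKV_RNN(args)
#   state = None
#   for i in range(len(tokens)):
#       if i < len(tokens) - 1:
#           # warm up on the prompt without computing logits
#           state = model.forward(tokens[:i + 1], state, preprocess_only=True)
#       else:
#           logits, state = model.forward(tokens[:i + 1], state)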
import os, math, time, datetime, subprocess
import torch
from torch.utils.data import DataLoader
import pytorch_lightning as pl
from pytorch_lightning.utilities import rank_zero_info, rank_zero_only
def my_save(dd, ff):
if '14b-run1' not in ff:
torch.save(dd, ff)
else:
fn = ff.split('/')[-1]
fff = '/dev/shm/' + fn
torch.save(dd, fff)
subprocess.Popen(f" aws s3 mv {fff} s3://rwkv-14b/{fn} --quiet", shell=True)
class train_callback(pl.Callback):
def __init__(self, args):
super().__init__()
self.args = args
def on_train_batch_start(self, trainer, pl_module, batch, batch_idx):
args = self.args
# if args.cuda_cleanup > 0:
# torch.cuda.empty_cache()
real_step = trainer.global_step + args.epoch_begin * args.epoch_steps
# LR schedule
w_step = args.warmup_steps
if args.lr_final == args.lr_init or args.epoch_count == 0:
lr = args.lr_init
else:
decay_step = real_step - args.my_pile_edecay * args.epoch_steps
decay_total = (args.epoch_count - args.my_pile_edecay) * args.epoch_steps
progress = (decay_step - w_step + 1) / (decay_total - w_step)
progress = min(1, max(0, progress))
if args.lr_final == 0 or args.lr_init == 0: # linear decay
lr = args.lr_init + (args.lr_final - args.lr_init) * progress
else: # exp decay
lr = args.lr_init * math.exp(math.log(args.lr_final / args.lr_init) * pow(progress, 1))
if trainer.global_step < w_step:
lr = lr * (0.2 + 0.8 * trainer.global_step / w_step)
# if trainer.is_global_zero:
# print(trainer.global_step, decay_step, decay_total, w_step, progress, lr)
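        # Worked example (comment only): with e.g. lr_init=6e-4, lr_final=1e-5 and no
        # my_pile_edecay, progress rises roughly linearly from 0 to 1 over
        # epoch_count * epoch_steps steps, and lr = lr_init * exp(log(lr_final / lr_init) *
        # progress) decays smoothly from 6e-4 down to 1e-5. During the first warmup_steps
        # steps the result is additionally scaled from 0.2x up to 1.0x.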
for param_group in trainer.optimizers[0].param_groups:
if args.layerwise_lr > 0:
param_group["lr"] = lr * param_group["my_lr_scale"]
# print(param_group["lr"], param_group["my_lr_scale"])
else:
param_group["lr"] = lr
trainer.my_lr = lr
# rank_zero_info(f"{real_step} {lr}")
if trainer.global_step == 0:
if trainer.is_global_zero: # logging
trainer.my_loss_sum = 0
trainer.my_loss_count = 0
trainer.my_log = open(args.proj_dir + "/train_log.txt", "a")
trainer.my_log.write(f"NEW RUN {args.my_timestamp}\n{vars(self.args)}\n")
try:
print(f"\n{trainer.strategy.config}\n")
trainer.my_log.write(f"{trainer.strategy.config}\n")
except:
pass
trainer.my_log.flush()
if len(args.wandb) > 0:
print("Login to wandb...")
import wandb
wandb.init(
project=args.wandb,
name=args.run_name + " " + args.my_timestamp,
config=args,
save_code=False,
)
trainer.my_wandb = wandb
def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
args = self.args
if trainer.is_global_zero: # logging
t_now = time.time_ns()
token_per_step = args.ctx_len * args.real_bsz
real_step = trainer.global_step + args.epoch_begin * args.epoch_steps
kt_s = 0
try:
t_cost = (t_now - trainer.my_time_ns) / 1e9
kt_s = token_per_step / t_cost / 1000
self.log("REAL it/s", 1.0 / t_cost, prog_bar=True, on_step=True)
self.log("Kt/s", kt_s, prog_bar=True, on_step=True)
except:
pass
trainer.my_time_ns = t_now
trainer.my_loss = trainer.my_loss_all.float().mean().item()
trainer.my_loss_sum += trainer.my_loss
trainer.my_loss_count += 1
trainer.my_epoch_loss = trainer.my_loss_sum / trainer.my_loss_count
self.log("lr", trainer.my_lr, prog_bar=True, on_step=True)
self.log("loss", trainer.my_epoch_loss, prog_bar=True, on_step=True)
# self.log("s", real_step, prog_bar=True, on_step=True)
if len(args.wandb) > 0:
lll = {"loss": trainer.my_loss, "lr": trainer.my_lr, "Gtokens": real_step * token_per_step / 1e9}
if kt_s > 0:
lll["kt/s"] = kt_s
trainer.my_wandb.log(lll, step=int(real_step))
if args.magic_prime > 0:
if int(real_step) == int(args.magic_prime * (1 + args.my_qa_mask) // args.real_bsz) - 1:
to_save_dict = pl_module.state_dict()
my_save(
to_save_dict,
f"{args.proj_dir}/rwkv-final.pth",
)
def on_train_epoch_start(self, trainer, pl_module):
args = self.args
dataset = trainer.train_dataloader.dataset.datasets
assert "MyDataset" in str(dataset)
dataset.global_rank = trainer.global_rank
dataset.real_epoch = int(args.epoch_begin + trainer.current_epoch)
dataset.world_size = trainer.world_size
# print(f'########## world_size {dataset.world_size} global_rank {dataset.global_rank} real_epoch {dataset.real_epoch} ##########')
def on_train_epoch_end(self, trainer, pl_module):
args = self.args
if trainer.is_global_zero: # logging & save state_dict
if (args.epoch_save > 0 and trainer.current_epoch % args.epoch_save == 0) or trainer.current_epoch == args.epoch_count - 1:
if args.data_type == 'wds_img':
raw_dict = pl_module.state_dict()
to_save_dict = {}
for k in raw_dict:
if k.startswith('encoder.') or k.startswith('decoder.'):
to_save_dict[k] = raw_dict[k]
else:
to_save_dict = pl_module.state_dict()
try:
my_save(
to_save_dict,
f"{args.proj_dir}/rwkv-{args.epoch_begin + trainer.current_epoch}.pth",
)
except Exception as e:
print('Error\n\n', e, '\n\n')
trainer.my_log.write(f"{args.epoch_begin + trainer.current_epoch} {trainer.my_epoch_loss:.6f} {math.exp(trainer.my_epoch_loss):.4f} {trainer.my_lr:.8f} {datetime.datetime.now()} {trainer.current_epoch}\n")
trainer.my_log.flush()
trainer.my_loss_sum = 0
trainer.my_loss_count = 0
@rank_zero_only
def generate_init_weight(model, init_weight_name):
mm = model.generate_init_weight()
if model.args.my_pile_stage == 1:
if len(model.args.load_model) > 0:
print(f"Combine weights from {model.args.load_model}...")
load_dict = torch.load(model.args.load_model, map_location="cpu")
for k in load_dict:
assert k in mm
src = load_dict[k]
try:
mm[k] = src.reshape(mm[k].shape)
except:
tmp = mm[k].squeeze().clone()
print(k, src.shape, '-->', mm[k].shape)
ss = src.shape[0]
dd = tmp.shape[0]
for i in range(dd):
pos = i / dd * ss
if pos >= ss - 1:
tmp[i] = src[ss-1]
else:
p0 = int(math.floor(pos))
ii = pos - p0
tmp[i] = src[p0] * (1-ii) + src[p0+1] * (ii)
mm[k] = tmp.reshape(mm[k].shape)
sss = src.squeeze().float().cpu().numpy()
print(sss[:10], '...', sss[-10:])
mmm = mm[k].squeeze().float().cpu().numpy()
print(mmm[:10], '...', mmm[-10:])
print(f"Save to {init_weight_name}...")
torch.save(mm, init_weight_name)
if model.args.my_pile_stage == 1:
print("Done. Now go for stage 2.")
exit(0)
import json, time, random, os
import numpy as np
import torch
from torch.nn import functional as F
time_slot = {}
time_ref = time.time_ns()
def record_time(name):
if name not in time_slot:
time_slot[name] = 1e20
tt = (time.time_ns() - time_ref) / 1e9
if tt < time_slot[name]:
time_slot[name] = tt
class TOKENIZER():
def __init__(self, WORD_NAME, UNKNOWN_CHAR='\ue083'):
if 'list' in str(type(WORD_NAME)):
self.charMode = False
if WORD_NAME[0] == WORD_NAME[1]:
from transformers import PreTrainedTokenizerFast
self.tokenizer = PreTrainedTokenizerFast(tokenizer_file=WORD_NAME[0])
else:
from transformers import GPT2TokenizerFast
self.tokenizer = GPT2TokenizerFast(WORD_NAME[0], WORD_NAME[1])
self.vocab_size = len(self.tokenizer)
else:
self.charMode = True
with open(WORD_NAME + '.json', "r", encoding="utf-16le") as result_file:
self.word_table = json.load(result_file)
self.vocab_size = len(self.word_table)
self.stoi = {v: int(k) for k, v in self.word_table.items()}
self.itos = {int(k): v for k, v in self.word_table.items()}
self.UNKNOWN_CHAR = self.stoi[UNKNOWN_CHAR]
def refine_context(self, context):
context = context.strip().split('\n')
for c in range(len(context)):
context[c] = context[c].strip().strip('\u3000').strip('\r')
context = list(filter(lambda c: c != '', context))
context = '\n' + ('\n'.join(context)).strip()
if context == '':
context = '\n'
return context
def sample_logits(self, out, x, ctx_len, temperature=1.0, top_p_usual=None, top_p_newline=None):
# out[self.UNKNOWN_CHAR] = -float('Inf')
lastChar = int(x[-1])
probs = F.softmax(out, dim=-1)
if self.charMode:
if self.itos[lastChar] == '\n':
top_p = top_p_newline
else:
top_p = top_p_usual
else:
top_p = top_p_usual
if os.environ["RWKV_RUN_DEVICE"] == "cpu":
probs = probs.numpy()
sorted_probs = np.sort(probs)[::-1]
cumulative_probs = np.cumsum(sorted_probs)
cutoff = float(sorted_probs[np.argmax(cumulative_probs > top_p)])
probs[probs < cutoff] = 0
if temperature != 1.0:
                probs = np.power(probs, 1.0 / temperature)  # probs is a numpy array here, not a torch tensor
probs = probs / np.sum(probs)
out = np.random.choice(a=len(probs), p=probs)
return out
else:
sorted_probs = torch.sort(probs, descending=True)[0]
cumulative_probs = torch.cumsum(sorted_probs, dim=-1).cpu().numpy()
cutoff = float(sorted_probs[np.argmax(cumulative_probs > top_p)])
probs[probs < cutoff] = 0
if temperature != 1.0:
probs = probs.pow(1.0 / temperature)
out = torch.multinomial(probs, num_samples=1)[0]
return out
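# Worked example of the top-p cutoff in TOKENIZER.sample_logits above (comment only): if the
# softmax yields probabilities [0.6, 0.25, 0.1, 0.05] and top_p = 0.8, the sorted cumulative
# sums are [0.6, 0.85, 0.95, 1.0]; the first entry greater than 0.8 sits at index 1, so the
# cutoff is 0.25 and the 0.1 / 0.05 tokens are zeroed before sampling. In char mode a lower
# top_p_newline therefore makes sampling more conservative right after a newline.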
def MaybeIsPrime(number):
if FermatPrimalityTest(number) and MillerRabinPrimalityTest(number):
return True
else:
return False
def FermatPrimalityTest(number):
if number > 1:
for time in range(3):
randomNumber = random.randint(2, number) - 1
if pow(randomNumber, number - 1, number) != 1:
return False
return True
else:
return False
def MillerRabinPrimalityTest(number):
if number == 2:
return True
elif number == 1 or number % 2 == 0:
return False
oddPartOfNumber = number - 1
timesTwoDividNumber = 0
while oddPartOfNumber % 2 == 0:
oddPartOfNumber = oddPartOfNumber // 2
timesTwoDividNumber = timesTwoDividNumber + 1
for time in range(3):
while True:
randomNumber = random.randint(2, number) - 1
if randomNumber != 0 and randomNumber != 1:
break
randomNumberWithPower = pow(randomNumber, oddPartOfNumber, number)
if (randomNumberWithPower != 1) and (randomNumberWithPower != number - 1):
iterationNumber = 1
while (iterationNumber <= timesTwoDividNumber - 1) and (randomNumberWithPower != number - 1):
randomNumberWithPower = pow(randomNumberWithPower, 2, number)
iterationNumber = iterationNumber + 1
if randomNumberWithPower != (number - 1):
return False
return True
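# Usage sketch (comment only): MaybeIsPrime combines a 3-round Fermat test with a 3-round
# Miller-Rabin test, so it is probabilistic rather than a proof of primality.
#
#   MaybeIsPrime(997)   # -> True  (997 is prime)
#   MaybeIsPrime(1000)  # -> False (even numbers are rejected immediately by Miller-Rabin)
#
# The magic_prime value referenced by the training callback above is presumably expected to
# pass this kind of check.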
'''
Copyright 2020 The Microsoft DeepSpeed Team
'''
import sys
import types
from . import ops
from .runtime.engine import DeepSpeedEngine
from .runtime.engine import ADAM_OPTIMIZER, LAMB_OPTIMIZER
from .runtime.pipe.engine import PipelineEngine
from .runtime.lr_schedules import add_tuning_arguments
from .runtime.config import DeepSpeedConfig, DeepSpeedConfigError
from .runtime.activation_checkpointing import checkpointing
from .ops.transformer import DeepSpeedTransformerLayer, DeepSpeedTransformerConfig
from .utils import log_dist
from .utils.distributed import init_distributed
from .runtime import zero
from .pipe import PipelineModule
from .git_version_info import version, git_hash, git_branch
def _parse_version(version_str):
'''Parse a version string and extract the major, minor, and patch versions.'''
import re
    matched = re.search(r'^(\d+)\.(\d+)\.(\d+)', version_str)
return int(matched.group(1)), int(matched.group(2)), int(matched.group(3))
# Export version information
__version__ = version
__version_major__, __version_minor__, __version_patch__ = _parse_version(__version__)
__git_hash__ = git_hash
__git_branch__ = git_branch
# Provide backwards compatibility with old deepspeed.pt module structure, should hopefully not be used
pt = types.ModuleType('pt', 'dummy pt module for backwards compatibility')
deepspeed = sys.modules[__name__]
setattr(deepspeed, 'pt', pt)
setattr(deepspeed.pt, 'deepspeed_utils', deepspeed.runtime.utils)
sys.modules['deepspeed.pt'] = deepspeed.pt
sys.modules['deepspeed.pt.deepspeed_utils'] = deepspeed.runtime.utils
setattr(deepspeed.pt, 'deepspeed_config', deepspeed.runtime.config)
sys.modules['deepspeed.pt.deepspeed_config'] = deepspeed.runtime.config
setattr(deepspeed.pt, 'loss_scaler', deepspeed.runtime.fp16.loss_scaler)
sys.modules['deepspeed.pt.loss_scaler'] = deepspeed.runtime.fp16.loss_scaler
def initialize(args=None,
model=None,
optimizer=None,
model_parameters=None,
training_data=None,
lr_scheduler=None,
mpu=None,
dist_init_required=None,
collate_fn=None,
config_params=None):
"""Initialize the DeepSpeed Engine.
Arguments:
args: an object containing local_rank and deepspeed_config fields. This is optional if `config_params` is passed.
        model: Required: nn.Module instance before applying any wrappers
optimizer: Optional: a user defined optimizer, this is typically used instead of defining
an optimizer in the DeepSpeed json config.
model_parameters: Optional: An iterable of torch.Tensors or dicts.
Specifies what Tensors should be optimized.
training_data: Optional: Dataset of type torch.utils.data.Dataset
        lr_scheduler: Optional: Learning Rate Scheduler Object. It should define get_lr(),
            step(), state_dict(), and load_state_dict() methods
mpu: Optional: A model parallelism unit object that implements
get_{model,data}_parallel_{rank,group,world_size}()
dist_init_required: Optional: None will auto-initialize torch.distributed if needed,
otherwise the user can force it to be initialized or not via boolean.
collate_fn: Optional: Merges a list of samples to form a
mini-batch of Tensor(s). Used when using batched loading from a
map-style dataset.
        config_params: Optional: Instead of requiring args.deepspeed_config, you can pass your DeepSpeed config
            as a dictionary.
Returns:
A tuple of ``engine``, ``optimizer``, ``training_dataloader``, ``lr_scheduler``
* ``engine``: DeepSpeed runtime engine which wraps the client model for distributed training.
* ``optimizer``: Wrapped optimizer if a user defined ``optimizer`` is supplied, or if
optimizer is specified in json config else ``None``.
* ``training_dataloader``: DeepSpeed dataloader if ``training_data`` was supplied,
otherwise ``None``.
* ``lr_scheduler``: Wrapped lr scheduler if user ``lr_scheduler`` is passed, or
if ``lr_scheduler`` specified in JSON configuration. Otherwise ``None``.
"""
log_dist("DeepSpeed info: version={}, git-hash={}, git-branch={}".format(
__version__,
__git_hash__,
__git_branch__),
ranks=[0])
assert model is not None, "deepspeed.initialize requires a model"
if not isinstance(model, PipelineModule):
engine = DeepSpeedEngine(args=args,
model=model,
optimizer=optimizer,
model_parameters=model_parameters,
training_data=training_data,
lr_scheduler=lr_scheduler,
mpu=mpu,
dist_init_required=dist_init_required,
collate_fn=collate_fn,
config_params=config_params)
else:
assert mpu is None, "mpu must be None with pipeline parallelism"
engine = PipelineEngine(args=args,
model=model,
optimizer=optimizer,
model_parameters=model_parameters,
training_data=training_data,
lr_scheduler=lr_scheduler,
mpu=model.mpu(),
dist_init_required=dist_init_required,
collate_fn=collate_fn,
config_params=config_params)
return_items = [
engine,
engine.optimizer,
engine.training_dataloader,
engine.lr_scheduler
]
return tuple(return_items)
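# Usage sketch (comment only; the model and argument names are illustrative):
#
#   engine, optimizer, dataloader, lr_scheduler = deepspeed.initialize(
#       args=args,                          # carries local_rank / deepspeed_config
#       model=my_model,
#       model_parameters=my_model.parameters(),
#   )
#   outputs = engine(inputs)                # forward pass through the wrapped model
#   engine.backward(loss)                   # loss computed from outputs by the user
#   engine.step()
#
# Passing a PipelineModule instead of a plain nn.Module routes construction through
# PipelineEngine above instead of DeepSpeedEngine.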
def _add_core_arguments(parser):
r"""Helper (internal) function to update an argument parser with an argument group of the core DeepSpeed arguments.
The core set of DeepSpeed arguments include the following:
1) --deepspeed: boolean flag to enable DeepSpeed
2) --deepspeed_config <json file path>: path of a json configuration file to configure DeepSpeed runtime.
This is a helper function to the public add_config_arguments()
Arguments:
parser: argument parser
Return:
parser: Updated Parser
"""
group = parser.add_argument_group('DeepSpeed', 'DeepSpeed configurations')
group.add_argument(
'--deepspeed',
default=False,
action='store_true',
help=
'Enable DeepSpeed (helper flag for user code, no impact on DeepSpeed backend)')
group.add_argument('--deepspeed_config',
default=None,
type=str,
help='DeepSpeed json configuration file.')
group.add_argument(
'--deepscale',
default=False,
action='store_true',
help=
'Deprecated enable DeepSpeed (helper flag for user code, no impact on DeepSpeed backend)'
)
group.add_argument('--deepscale_config',
default=None,
type=str,
help='Deprecated DeepSpeed json configuration file.')
group.add_argument(
'--deepspeed_mpi',
default=False,
action='store_true',
help=
"Run via MPI, this will attempt to discover the necessary variables to initialize torch "
"distributed from the MPI environment")
return parser
def add_config_arguments(parser):
r"""Update the argument parser to enabling parsing of DeepSpeed command line arguments.
The set of DeepSpeed arguments include the following:
1) --deepspeed: boolean flag to enable DeepSpeed
2) --deepspeed_config <json file path>: path of a json configuration file to configure DeepSpeed runtime.
Arguments:
parser: argument parser
Return:
parser: Updated Parser
"""
parser = _add_core_arguments(parser)
return parser
'''
Copyright 2020 The Microsoft DeepSpeed Team
'''
from datetime import timedelta
#############################################
# Torch distributed constants
#############################################
TORCH_DISTRIBUTED_DEFAULT_PORT = 29500
# Default process group wide timeout, if applicable.
# This only applies to the gloo and nccl backends
# (only if NCCL_BLOCKING_WAIT or NCCL_ASYNC_ERROR_HANDLING is set to 1).
# To make an attempt at backwards compatibility with THD, we use an
# extraordinarily high default timeout, given that THD did not have timeouts.
default_pg_timeout = timedelta(minutes=30)
from .elasticity import compute_elastic_config, elasticity_enabled, ensure_immutable_elastic_config
"""
Copyright 2020 The Microsoft DeepSpeed Team
"""
import json
from .constants import *
class ElasticityError(Exception):
"""
Base exception for all elasticity related errors
"""
pass
class ElasticityConfigError(ElasticityError):
"""
Elasticity configuration error
"""
pass
class ElasticityIncompatibleWorldSize(ElasticityError):
"""
Attempting to run a world size that is incompatible with a given elastic config
"""
pass
class ElasticityConfig:
"""
Elastic config object, constructed from a param dictionary that only contains elastic
config parameters, example below:
If elasticity is enabled, user must specify (at least) max_train_batch_size
and micro_batch_sizes.
{
"enabled": true,
"max_train_batch_size": 2000,
"micro_batch_sizes": [2,4,6],
"min_gpus": 1,
"max_gpus" : 10000
"min_time": 20
"ignore_non_elastic_batch_info": false
"version": 0.1
}
"""
def __init__(self, param_dict):
self.enabled = param_dict.get(ENABLED, ENABLED_DEFAULT)
if self.enabled:
if MAX_ACCEPTABLE_BATCH_SIZE in param_dict:
self.max_acceptable_batch_size = param_dict[MAX_ACCEPTABLE_BATCH_SIZE]
else:
raise ElasticityConfigError(
f"Elasticity config missing {MAX_ACCEPTABLE_BATCH_SIZE}")
if MICRO_BATCHES in param_dict:
self.micro_batches = param_dict[MICRO_BATCHES]
else:
raise ElasticityConfigError(f"Elasticity config missing {MICRO_BATCHES}")
else:
self.max_acceptable_batch_size = param_dict.get(
MAX_ACCEPTABLE_BATCH_SIZE,
MAX_ACCEPTABLE_BATCH_SIZE_DEFAULT)
self.micro_batches = param_dict.get(MICRO_BATCHES, MICRO_BATCHES_DEFAULT)
if not isinstance(self.micro_batches, list):
raise ElasticityConfigError(
f"Elasticity expected value of {MICRO_BATCHES} to be a "
f"list of micro batches, instead is: {type(self.micro_batches)}, containing: {self.micro_batches}"
)
if not all(map(lambda m: isinstance(m, int), self.micro_batches)):
raise ElasticityConfigError(
f"Elasticity expected {MICRO_BATCHES} to only contain a list of integers, "
f"instead contains: f{self.micro_batches}")
if not all(map(lambda m: m > 0, self.micro_batches)):
raise ElasticityConfigError(
f"Elasticity expected {MICRO_BATCHES} to only contain positive integers, "
f"instead contains: f{self.micro_batches}")
self.min_gpus = param_dict.get(MIN_GPUS, MIN_GPUS_DEFAULT)
self.max_gpus = param_dict.get(MAX_GPUS, MAX_GPUS_DEFAULT)
if self.min_gpus < 1 or self.max_gpus < 1:
raise ElasticityConfigError(
"Elasticity min/max gpus must be > 0, "
f"given min_gpus: {self.min_gpus}, max_gpus: {self.max_gpus}")
if self.max_gpus < self.min_gpus:
raise ElasticityConfigError(
"Elasticity min_gpus cannot be greater than max_gpus, "
f"given min_gpus: {self.min_gpus}, max_gpus: {self.max_gpus}")
self.min_time = param_dict.get(MIN_TIME, MIN_TIME_DEFAULT)
if self.min_time < 0:
raise ElasticityConfigError(
f"Elasticity min time needs to be >= 0: given {self.min_time}")
self.version = param_dict.get(VERSION, VERSION_DEFAULT)
self.prefer_larger_batch_size = param_dict.get(PREFER_LARGER_BATCH,
PREFER_LARGER_BATCH_DEFAULT)
self.ignore_non_elastic_batch_info = param_dict.get(
IGNORE_NON_ELASTIC_BATCH_INFO,
IGNORE_NON_ELASTIC_BATCH_INFO_DEFAULT)
def repr(self):
return self.__dict__
def __repr__(self):
return json.dumps(self.__dict__, sort_keys=True, indent=4)
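# Usage sketch (comment only), constructing the object from the kind of dictionary shown in the
# class docstring above:
#
#   cfg = ElasticityConfig({
#       "enabled": True,
#       "max_train_batch_size": 2000,
#       "micro_batch_sizes": [2, 4, 6],
#       "min_gpus": 1,
#       "max_gpus": 10000,
#   })
#   print(cfg)   # __repr__ dumps the validated fields as indented JSON
#
# Omitting max_train_batch_size or micro_batch_sizes while "enabled" is true raises
# ElasticityConfigError, as enforced in __init__ above.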
"""
Copyright 2020 The Microsoft DeepSpeed Team
"""
#########################################
# Elasticity
#########################################
''' Elasticity Utility in DeepSpeed can be used to create highly elastic jobs compatible
with a large number of GPUs. For elastic jobs, DeepSpeed will provide a batch size that
can support a large number of GPUs based on the user specified parameters
'''
FORMAT = '''
Elasticity should be enabled as:
"elasticity": {
"enabled": true,
"max_train_batch_size": 2000,
"micro_batch_sizes": [2,4,6],
"min_gpus": 1,
"max_gpus" : 10000
"min_time": 20,
"prefer_larger_batch": true,
"ignore_non_elastic_batch_info": false,
"version": 0.1
}
'''
ELASTICITY = 'elasticity'
# Current elasticity version
LATEST_ELASTICITY_VERSION = 0.1
ENABLED = 'enabled'
ENABLED_DEFAULT = False
# Max acceptable train_batch_size
MAX_ACCEPTABLE_BATCH_SIZE = 'max_train_batch_size'
MAX_ACCEPTABLE_BATCH_SIZE_DEFAULT = 2000
# Acceptable micro batch sizes, same as train_micro_batch_size_per_gpu
MICRO_BATCHES = 'micro_batch_sizes'
MICRO_BATCHES_DEFAULT = [2, 4, 6]
# Min/max of GPUs to search over
MIN_GPUS = 'min_gpus'
MIN_GPUS_DEFAULT = 1
MAX_GPUS = 'max_gpus'
MAX_GPUS_DEFAULT = 10000
# Minimum running time (minutes) before the scheduler will scale us, 0 implies it's unknown
MIN_TIME = "min_time"
MIN_TIME_DEFAULT = 0
# When finding a suitable batch size, attempt to find one that is closest
# to the max train batch size given.
PREFER_LARGER_BATCH = 'prefer_larger_batch'
PREFER_LARGER_BATCH_DEFAULT = True
# In order to reduce confusion, if elastic mode is enabled we
# require (via assert) that no batch info is set outside of the
# elastic config. You can turn off this assert via this config
# but keep in mind that all batch info defined outside the
# elastic mode *will be ignored*.
IGNORE_NON_ELASTIC_BATCH_INFO = 'ignore_non_elastic_batch_info'
IGNORE_NON_ELASTIC_BATCH_INFO_DEFAULT = False
# Version of elastic logic to use
VERSION = "version"
VERSION_DEFAULT = LATEST_ELASTICITY_VERSION
# Minimum deepspeed version to use elasticity
MINIMUM_DEEPSPEED_VERSION = "0.3.8"
# Environment variable storing elastic config from resource scheduler
DEEPSPEED_ELASTICITY_CONFIG = "DEEPSPEED_ELASTICITY_CONFIG"
"""
Copyright 2020 The Microsoft DeepSpeed Team
"""
import os
import re
import json
import numpy as np
from .config import ElasticityConfig, ElasticityConfigError, ElasticityError, \
ElasticityIncompatibleWorldSize
from .constants import ELASTICITY, ENABLED, ENABLED_DEFAULT, LATEST_ELASTICITY_VERSION, \
MINIMUM_DEEPSPEED_VERSION, IGNORE_NON_ELASTIC_BATCH_INFO, \
IGNORE_NON_ELASTIC_BATCH_INFO_DEFAULT, DEEPSPEED_ELASTICITY_CONFIG
from ..git_version_info import version as __version__
from ..utils import logger
# Thirty eight smallest highly composite numbers. The list should
# be enough to support up to 720K batch size.
HCN_LIST = [
1,
2,
4,
6,
12,
24,
36,
48,
60,
120,
180,
240,
360,
720,
840,
1260,
1680,
2520,
5040,
7560,
10080,
15120,
20160,
25200,
27720,
45360,
50400,
55440,
83160,
110880,
166320,
221760,
277200,
332640,
498960,
554400,
665280,
720720
]
def get_candidate_batch_sizes(base_list, max_acceptable_batch_size):
candidate_batch_size = []
#brute force is fine here. We are working with very small lists
for base in base_list:
batch_size = base
for hcn in HCN_LIST:
new_batch_size = base * hcn
if new_batch_size > max_acceptable_batch_size:
break
batch_size = new_batch_size
candidate_batch_size.append(batch_size)
return list(set(candidate_batch_size))
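# Worked example (comment only): with base_list = [2, 4, 6, 12] (the micro batches plus their
# LCM) and max_acceptable_batch_size = 2000, each base is scaled by the largest highly composite
# number that keeps it under the limit: 12*120 = 1440, 4*360 = 1440, 6*240 = 1440 and
# 2*840 = 1680, so the deduplicated candidates are {1440, 1680}.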
def get_valid_gpus(batch_size, micro_batches, min_valid_gpus, max_valid_gpus):
valid_gpus = []
for micro_batch in micro_batches:
if batch_size % micro_batch == 0:
max_gpus = batch_size // micro_batch
if max_gpus >= min_valid_gpus and max_gpus <= max_valid_gpus:
valid_gpus.append(max_gpus)
for i in range(1, max_gpus // 2 + 1):
if max_gpus % i == 0:
if i >= min_valid_gpus and i <= max_valid_gpus:
valid_gpus.append(i)
valid_gpus = set(valid_gpus)
valid_gpus = sorted(list(valid_gpus))
return valid_gpus
def get_best_candidates(candidate_batch_sizes,
micro_batches,
min_gpus,
max_gpus,
prefer_larger):
max_valid_gpus = 0
valid_gpus = None
final_batch_size = int(min(micro_batches))
for batch_size in candidate_batch_sizes:
current_valid_gpus = get_valid_gpus(batch_size,
micro_batches,
min_gpus,
max_gpus)
if (len(current_valid_gpus) > max_valid_gpus
or (len(current_valid_gpus) == max_valid_gpus and
((prefer_larger and batch_size > final_batch_size) or
(not prefer_larger and batch_size < final_batch_size)))):
max_valid_gpus = len(current_valid_gpus)
valid_gpus = current_valid_gpus
final_batch_size = batch_size
return final_batch_size, valid_gpus
def _get_compatible_gpus_v01(micro_batches,
max_acceptable_batch_size,
min_gpus=None,
max_gpus=None,
prefer_larger=True):
'''We use two heuristics to compute the batch size
1. We use the Lowest Common Multiple of the micro-batches
as the base batch size and scale it by a HCN such that the result is
the largest batch size less than the max_acceptable batch size
2. We use each of the micro batches as a base and scale it
by a HCN such that the result is the largest batch size less than the
max_acceptable batch size.
    We then use brute force to count the number of compatible GPU counts for
    each of the aforementioned cases, and return the batch size with the most
    compatible GPU counts in the min-max GPU range if provided, otherwise
    we return the batch size with the most total compatible GPU counts.
Returns:
final_batch_size
valid_gpus
'''
if min_gpus is None:
min_gpus = int(1)
if max_gpus is None:
max_gpus = int(max_acceptable_batch_size / min(micro_batches))
assert all(mb <= max_acceptable_batch_size for mb in micro_batches ), \
f"All micro batches must be less than \
or equal to max_acceptable_batch_size: {max_acceptable_batch_size}"
lcm = np.lcm.reduce(micro_batches)
base_list = []
base_list.extend(micro_batches)
base_list.append(lcm)
candidate_batch_sizes = get_candidate_batch_sizes(base_list,
max_acceptable_batch_size)
final_batch_size, valid_gpus = get_best_candidates(
candidate_batch_sizes,
micro_batches,
min_gpus,
max_gpus,
prefer_larger)
return final_batch_size, valid_gpus
def _parse_version(version_str):
'''Parse a version string and extract the major and minor versions (and possibly patch version).'''
    matched = re.search(r'^(\d+)\.(\d+)\.(\d+)', version_str)
if matched:
return int(matched.group(1)), int(matched.group(2)), int(matched.group(3))
else:
        matched = re.search(r'^(\d+)\.(\d+)', version_str)
        assert matched is not None, "Unable to parse version number, expecting " \
            f"major.minor[.patch] format but received {version_str}"
return int(matched.group(1)), int(matched.group(2)), 0
def _compatible_ds_version_check(target_deepspeed_version: str):
min_major, min_minor, min_patch = _parse_version(MINIMUM_DEEPSPEED_VERSION)
trg_major, trg_minor, trg_patch = _parse_version(target_deepspeed_version)
err_str = f"Target deepspeed version of {target_deepspeed_version} is not compatible " \
f"with minimum version {MINIMUM_DEEPSPEED_VERSION} supporting elasticity."
if trg_major < min_major:
raise ElasticityError(err_str)
if trg_minor < min_minor:
raise ElasticityError(err_str)
if trg_patch < min_patch:
raise ElasticityError(err_str)
return True
def elasticity_enabled(ds_config: dict):
if ELASTICITY not in ds_config:
return False
return ds_config[ELASTICITY].get(ENABLED, ENABLED_DEFAULT)
def ensure_immutable_elastic_config(runtime_elastic_config_dict: dict):
"""
Ensure the resource scheduler saw the same elastic config we are using at runtime
"""
if DEEPSPEED_ELASTICITY_CONFIG in os.environ:
scheduler_elastic_config_dict = json.loads(
os.environ[DEEPSPEED_ELASTICITY_CONFIG])
scheduler_elastic_config = ElasticityConfig(scheduler_elastic_config_dict)
runtime_elastic_config = ElasticityConfig(runtime_elastic_config_dict)
err_str = "Elastic config '{}={}' seen by resource scheduler does not match config passed to runtime {}={}"
if runtime_elastic_config.max_acceptable_batch_size != scheduler_elastic_config.max_acceptable_batch_size:
raise ElasticityConfigError(
err_str.format('max_acceptable_batch_size',
scheduler_elastic_config.max_acceptable_batch_size,
'max_acceptable_batch_size',
runtime_elastic_config.max_acceptable_batch_size))
if runtime_elastic_config.micro_batches != scheduler_elastic_config.micro_batches:
raise ElasticityConfigError(
err_str.format('micro_batches',
scheduler_elastic_config.micro_batches,
'micro_batches',
runtime_elastic_config.micro_batches))
if runtime_elastic_config.version != scheduler_elastic_config.version:
raise ElasticityConfigError(
err_str.format('version',
scheduler_elastic_config.version,
'version',
runtime_elastic_config.version))
else:
logger.warning("Unable to find DEEPSPEED_ELASTICITY_CONFIG environment variable, cannot " \
"guarantee resource scheduler will scale this job using compatible GPU counts.")
def compute_elastic_config(ds_config: dict, target_deepspeed_version: str, world_size=0):
"""Core deepspeed elasticity API. Given an elastic config (similar to the example below)
    DeepSpeed will compute a total train batch size and corresponding valid GPU count list that
provides a high level of elasticity. Elasticity in this case means we are safe to scale
the training job up/down across the GPU count list *without* any negative impacts on
training convergence. This is achievable primarily due to DeepSpeed's gradient accumulation
feature which allows us to decompose a global training batch size into:
micro-batch-size * gradient-accumulation-steps * world-size.
"elasticity": {
"enabled": true,
"max_train_batch_size": 2000,
"micro_batch_sizes": [2,4,6],
"min_gpus": 1,
"max_gpus" : 10000
"min_time": 20
"version": 0.1
}
Intended to be called both by scheduling infrastructure and deepspeed runtime.
For the same `ds_config` we should return deterministic results.
Args:
ds_config (dict): DeepSpeed config dictionary/json
target_deepspeed_version (str): When called from scheduling
infrastructure we want to ensure that the target deepspeed version is
compatible with the elasticity version used in the backend.
world_size (int, optional): Intended/current world size, will do some sanity
checks to ensure world size is actually valid with the config.
Raises:
ElasticityConfigError: Missing required elasticity config or elasticity disabled
ElasticityError: If target deepspeed version is not compatible with current version
Returns:
final_batch_size (int): total batch size used for training
valid_gpus (list(int)): list of valid GPU counts with this config
micro_batch_size (int, optional): if world_size is provided will return
specific micro batch size
"""
if not isinstance(ds_config, dict):
raise ValueError("Expected ds_config to be a dictionary but received " \
f"a {type(ds_config)}, containing: {ds_config}")
if ELASTICITY not in ds_config:
raise ElasticityConfigError(f"'{ELASTICITY}' is missing from config json," \
" please add it if running an elastic training job.")
elastic_config_dict = ds_config[ELASTICITY]
if not elastic_config_dict.get(ENABLED, ENABLED_DEFAULT):
raise ElasticityConfigError("Elasticity is disabled, please enable it " \
"('enabled':true) if running an elastic training job.")
elastic_config = ElasticityConfig(elastic_config_dict)
if float(elastic_config.version) > LATEST_ELASTICITY_VERSION:
raise ElasticityConfigError("Attempting to run elasticity version " \
f"{elastic_config.version} but runtime only supports up " \
f"to {LATEST_ELASTICITY_VERSION}")
# Ensure target deepspeed version works with intended elasticity version
if not _compatible_ds_version_check(target_deepspeed_version):
raise ElasticityError("Unable to run elasticity on target deepspeed version of" \
f" {target_deepspeed_version}, currently {__version__}")
if float(elastic_config.version) == 0.1:
final_batch_size, valid_gpus = _get_compatible_gpus_v01(
micro_batches=elastic_config.micro_batches,
max_acceptable_batch_size=elastic_config.max_acceptable_batch_size,
min_gpus=elastic_config.min_gpus,
max_gpus=elastic_config.max_gpus,
prefer_larger=elastic_config.prefer_larger_batch_size)
# ensure batch size is int dtype
final_batch_size = int(final_batch_size)
else:
raise NotImplementedError(
f"Unable to find elastic logic for version: {elastic_config.version}")
if world_size > 0:
if world_size not in valid_gpus:
raise ElasticityIncompatibleWorldSize(f"World size ({world_size}) is not valid " \
f"with the current list of valid GPU counts: {valid_gpus}")
# Pick largest valid micro batch size
micro_batch_size = None
for mbsz in sorted(list(set(elastic_config.micro_batches)), reverse=True):
if final_batch_size // world_size % mbsz == 0:
micro_batch_size = mbsz
break
assert micro_batch_size is not None, "Unable to find divisible micro batch size" \
f" world_size={world_size}, final_batch_size={final_batch_size}, and " \
f" micro_batches={elastic_config.micro_batches}."
return final_batch_size, valid_gpus, micro_batch_size
return final_batch_size, valid_gpus
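# Usage sketch (comment only; values are illustrative):
#
#   ds_config = {"elasticity": {"enabled": True,
#                               "max_train_batch_size": 2000,
#                               "micro_batch_sizes": [2, 4, 6]}}
#   batch, gpus = compute_elastic_config(ds_config, target_deepspeed_version=__version__)
#   batch, gpus, micro = compute_elastic_config(ds_config, __version__, world_size=4)
#
# When world_size is given, the call additionally returns the largest micro batch size that
# divides final_batch_size // world_size, or raises ElasticityIncompatibleWorldSize if that
# world size is not in the valid GPU list.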
import torch
import deepspeed
import subprocess
from .ops.op_builder import ALL_OPS
from .git_version_info import installed_ops, torch_info
from .ops import __compatible_ops__ as compatible_ops
GREEN = '\033[92m'
RED = '\033[91m'
YELLOW = '\033[93m'
END = '\033[0m'
SUCCESS = f"{GREEN} [SUCCESS] {END}"
OKAY = f"{GREEN}[OKAY]{END}"
WARNING = f"{YELLOW}[WARNING]{END}"
FAIL = f'{RED}[FAIL]{END}'
INFO = '[INFO]'
color_len = len(GREEN) + len(END)
okay = f"{GREEN}[OKAY]{END}"
warning = f"{YELLOW}[WARNING]{END}"
def op_report():
max_dots = 23
max_dots2 = 11
h = ["op name", "installed", "compatible"]
print("-" * (max_dots + max_dots2 + len(h[0]) + len(h[1])))
print("DeepSpeed C++/CUDA extension op report")
print("-" * (max_dots + max_dots2 + len(h[0]) + len(h[1])))
print("NOTE: Ops not installed will be just-in-time (JIT) compiled at\n"
" runtime if needed. Op compatibility means that your system\n"
" meet the required dependencies to JIT install the op.")
print("-" * (max_dots + max_dots2 + len(h[0]) + len(h[1])))
print("JIT compiled ops requires ninja")
ninja_status = OKAY if ninja_installed() else FAIL
print('ninja', "." * (max_dots - 5), ninja_status)
print("-" * (max_dots + max_dots2 + len(h[0]) + len(h[1])))
print(h[0], "." * (max_dots - len(h[0])), h[1], "." * (max_dots2 - len(h[1])), h[2])
print("-" * (max_dots + max_dots2 + len(h[0]) + len(h[1])))
installed = f"{GREEN}[YES]{END}"
no = f"{YELLOW}[NO]{END}"
for op_name, builder in ALL_OPS.items():
dots = "." * (max_dots - len(op_name))
is_compatible = OKAY if builder.is_compatible() else no
is_installed = installed if installed_ops[op_name] else no
dots2 = '.' * ((len(h[1]) + (max_dots2 - len(h[1]))) -
(len(is_installed) - color_len))
print(op_name, dots, is_installed, dots2, is_compatible)
print("-" * (max_dots + max_dots2 + len(h[0]) + len(h[1])))
def ninja_installed():
try:
import ninja
except ImportError:
return False
return True
def nvcc_version():
import torch.utils.cpp_extension
cuda_home = torch.utils.cpp_extension.CUDA_HOME
if cuda_home is None:
return f"{RED} [FAIL] cannot find CUDA_HOME via torch.utils.cpp_extension.CUDA_HOME={torch.utils.cpp_extension.CUDA_HOME} {END}"
try:
output = subprocess.check_output([cuda_home + "/bin/nvcc",
"-V"],
universal_newlines=True)
except FileNotFoundError:
return f"{RED} [FAIL] nvcc missing {END}"
output_split = output.split()
release_idx = output_split.index("release")
release = output_split[release_idx + 1].replace(',', '').split(".")
return ".".join(release)
def debug_report():
max_dots = 33
report = [
("torch install path",
torch.__path__),
("torch version",
torch.__version__),
("torch cuda version",
torch.version.cuda),
("nvcc version",
nvcc_version()),
("deepspeed install path",
deepspeed.__path__),
("deepspeed info",
f"{deepspeed.__version__}, {deepspeed.__git_hash__}, {deepspeed.__git_branch__}"
),
("deepspeed wheel compiled w.",
f"torch {torch_info['version']}, cuda {torch_info['cuda_version']}"),
]
print("DeepSpeed general environment info:")
for name, value in report:
print(name, "." * (max_dots - len(name)), value)
def main():
op_report()
debug_report()
if __name__ == "__main__":
main()
try:
# This is populated by setup.py
from .git_version_info_installed import *
except ModuleNotFoundError:
import os
if os.path.isfile('version.txt'):
# Will be missing from checkouts that haven't been installed (e.g., readthedocs)
version = open('version.txt', 'r').read().strip()
else:
version = "0.0.0"
git_hash = '[none]'
git_branch = '[none]'
from .ops.op_builder import ALL_OPS
installed_ops = dict.fromkeys(ALL_OPS.keys(), False)
compatible_ops = dict.fromkeys(ALL_OPS.keys(), False)
torch_info = {'version': "0.0", "cuda_version": "0.0"}
# Copyright 2020 The Microsoft DeepSpeed Team
PDSH_LAUNCHER = 'pdsh'
PDSH_MAX_FAN_OUT = 1024
OPENMPI_LAUNCHER = 'openmpi'
SLURM_LAUNCHER = 'slurm'
MOSAICML_LAUNCHER = 'mosaicml'
MVAPICH_LAUNCHER = 'mvapich'
MVAPICH_TMP_HOSTFILE = '/tmp/deepspeed_mvapich_hostfile'
from subprocess import check_output
import re
CONNECTION_TYPES = ["X", "SYS", "NODE", "PHB", "PXB", "PIX", r"NV[\d]+"]
def get_topology_str():
return check_output(["nvidia-smi", "topo", "-m"]).decode()
def contains_nvlinks(topology):
return any([is_nvlink(item) for sublist in topology for item in sublist])
def is_nvlink(connection_type):
return re.search(CONNECTION_TYPES[-1], connection_type)
def get_nvlink_pairs(topology):
"""
    takes a topology matrix and returns the set of GPU index pairs bridged by NVLink
"""
out = set()
for device_idx1, item1 in enumerate(topology):
for device_idx2, item2 in enumerate(item1):
if is_nvlink(item2):
if (device_idx2, device_idx1) not in out:
out.add((device_idx1, device_idx2))
return out
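# Worked example (comment only): for a 2-GPU box whose `nvidia-smi topo -m` matrix parses to
#   [['X', 'NV1'],
#    ['NV1', 'X']]
# is_nvlink matches the 'NV1' entries, so get_nvlink_pairs returns {(0, 1)} and
# get_cuda_visible_device_mapping turns that into the remapping string "0,1".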
def get_cuda_visible_device_mapping(nvlink_pairs, local_gpu_ids=None):
nvlink_pairs = [item for sublist in sorted(nvlink_pairs) for item in sublist]
if local_gpu_ids is not None:
nvlink_pairs = [item for item in nvlink_pairs if item in local_gpu_ids]
    # deduplicate in case there's more than one pair per GPU
deduped = []
for item in nvlink_pairs:
if item not in deduped:
deduped.append(item)
return_string = ",".join(map(str, deduped))
return return_string
def topology_from_string(string):
output_per_gpu = string.strip().split('Legend:')[0].strip().split('\n')
headers = output_per_gpu.pop(0)
headers = headers.strip().split()
    headers = [i for i in headers if re.search(r'GPU[\d]+', i)]
num_gpus = len(headers)
topology = []
for output in output_per_gpu:
output = output.strip().split()
gpu_id = output.pop(0)
output = output[:num_gpus]
if 'GPU' in gpu_id:
links = []
for idx, i in enumerate(output):
if idx >= num_gpus:
break
links.append(i.strip())
topology.append(links)
# check for consistency
assert all([len(i) == len(topology) for i in topology])
return topology
def detect_nvlink_pairs_and_map_visible_devices(rank, local_gpu_ids):
string = get_topology_str()
topology = topology_from_string(string)
if contains_nvlinks(topology):
pairs = get_nvlink_pairs(topology)
remapping = get_cuda_visible_device_mapping(pairs, local_gpu_ids)
return remapping
else:
print(f'No NVLINK detected on rank {rank}')
return None
if __name__ == "__main__":
detect_nvlink_pairs_and_map_visible_devices()
# Copyright 2020 The Microsoft DeepSpeed Team
"""
DeepSpeed launcher, this is similar to torch.distributed.launch but supports
additional features such as arbitrary GPU exclusion.
deepspeed.launcher.launch is intended to be run on a single worker node and
will spawn several worker sub-processes depending on how many devices/ranks
are on the worker.
"""
import sys
import subprocess
import os
import json
import base64
import time
import signal
from collections import defaultdict
from argparse import ArgumentParser, REMAINDER
from ..constants import TORCH_DISTRIBUTED_DEFAULT_PORT
from ..utils import logger
from deepspeed.launcher.gpu_topology import detect_nvlink_pairs_and_map_visible_devices
def parse_args():
parser = ArgumentParser(description="DeepSpeed distributed training launch"
" utility that creates multiple distributed"
" processes on a single node")
# Optional arguments for the launch helper
parser.add_argument("--node_rank",
type=int,
default=0,
help="The rank of the node for multi-node distributed "
"training")
parser.add_argument("--master_addr",
default="127.0.0.1",
type=str,
help="Master node (rank 0)'s address, should be either"
" the IP address or the hostname of node 0, for"
" single node multi-proc training, the"
" --master_addr can simply be 127.0.0.1")
parser.add_argument("--master_port",
default=TORCH_DISTRIBUTED_DEFAULT_PORT,
type=int,
help="Master node (rank 0)'s free port that needs to "
"be used for communication during distributed "
"training")
parser.add_argument("--world_info",
default="None",
type=str,
help="world info base64 encoded dictionary")
parser.add_argument("--detect_nvlink_pairs", action="store_true",
help="autodetects nvlink pairs and remaps CUDA_VISIBLE_DEVICES along the fastest connections")
# positional
parser.add_argument("training_script",
type=str,
help="The full path to the single GPU training "
"program/script to be launched in parallel, "
"followed by all the arguments for the "
"training script")
# rest from the training program
parser.add_argument('training_script_args', nargs=REMAINDER)
return parser.parse_args()
def main():
args = parse_args()
current_env = os.environ.copy()
for k in current_env.keys():
if "NCCL" in k:
logger.info("%s %s %s", args.node_rank, k, current_env[k])
world_info = None
assert args.world_info != "None", "must provide world info dict"
world_info = base64.urlsafe_b64decode(args.world_info)
world_info = json.loads(world_info)
logger.info("WORLD INFO DICT: {}".format(world_info))
node_list = list(world_info.keys())
args.nnodes = len(node_list)
local_node = node_list[args.node_rank]
local_gpu_ids = world_info[local_node]
num_local_procs = len(local_gpu_ids)
logger.info(
"nnodes={}, num_local_procs={}, node_rank={}".format(args.nnodes,
num_local_procs,
args.node_rank),
)
global_rank_mapping = defaultdict(list)
curr_global_rank = 0
dist_world_size = 0
for node_id in node_list:
gids = world_info[node_id]
dist_world_size += len(gids)
for gid in gids:
global_rank_mapping[node_id].append(curr_global_rank)
curr_global_rank += 1
logger.info("global_rank_mapping={}".format(global_rank_mapping))
logger.info("dist_world_size={}".format(dist_world_size))
CUDA_VISIBLE_DEVICES = None
if args.detect_nvlink_pairs:
logger.info("Autodetecting nvlink pairs...")
CUDA_VISIBLE_DEVICES = detect_nvlink_pairs_and_map_visible_devices(args.node_rank, local_gpu_ids)
if CUDA_VISIBLE_DEVICES is None:
CUDA_VISIBLE_DEVICES = ",".join(map(str, local_gpu_ids))
current_env["CUDA_VISIBLE_DEVICES"] = CUDA_VISIBLE_DEVICES
logger.info("Setting CUDA_VISIBLE_DEVICES={}".format(
current_env["CUDA_VISIBLE_DEVICES"]))
exclusion_counts_per_node = None
# set PyTorch distributed related environmental variables
current_env["MASTER_ADDR"] = args.master_addr
current_env["MASTER_PORT"] = str(args.master_port)
current_env["WORLD_SIZE"] = str(dist_world_size)
processes = []
for local_rank in range(0, num_local_procs):
# each process's rank
dist_rank = global_rank_mapping[local_node][local_rank]
current_env["RANK"] = str(dist_rank)
current_env["LOCAL_RANK"] = str(local_rank)
# spawn the processes
cmd = [
sys.executable,
"-u",
args.training_script,
"--local_rank={}".format(local_rank)
] + args.training_script_args
sig_names = {2: "SIGINT", 15: "SIGTERM"}
last_return_code = None
def sigkill_handler(signum, frame):
for process in processes:
print(f"Killing subprocess {process.pid}")
try:
process.kill()
except Exception as e:
pass
if last_return_code is not None:
raise subprocess.CalledProcessError(returncode=last_return_code, cmd=cmd)
if signum in sig_names:
print(f"Main process received {sig_names[signum]}, exiting")
sys.exit(1)
# pass SIGINT/SIGTERM to children if the parent is being terminated
signal.signal(signal.SIGINT, sigkill_handler)
signal.signal(signal.SIGTERM, sigkill_handler)
process = subprocess.Popen(cmd, env=current_env)
processes.append(process)
alive_processes = set(processes)
while len(alive_processes):
finished_processes = []
for process in alive_processes:
if process.poll() is None:
# the process is still running
continue
else:
if process.returncode != 0:
last_return_code = process.returncode # for sigkill_handler
sigkill_handler(signal.SIGTERM, None) # not coming back
else:
# exited cleanly
finished_processes.append(process)
alive_processes = set(alive_processes) - set(finished_processes)
time.sleep(1)
if __name__ == "__main__":
main()
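# Usage sketch (comment only; host name and GPU ids are illustrative): --world_info is a
# base64-encoded JSON dict mapping hostnames to GPU id lists, e.g. for one node with 2 GPUs:
#
#   world_info = base64.urlsafe_b64encode(json.dumps({"localhost": [0, 1]}).encode())
#   python -m deepspeed.launcher.launch --world_info=<that value> \
#          --master_addr=127.0.0.1 --master_port=29500 train.py --my_training_args ...
#
# The launcher then spawns one subprocess per local GPU with RANK, LOCAL_RANK, WORLD_SIZE,
# MASTER_ADDR and MASTER_PORT set in its environment.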
import base64
import json
import os
import sys
import shutil
import subprocess
import warnings
from abc import ABC, abstractmethod
from ..utils import logger
from .constants import PDSH_MAX_FAN_OUT, MVAPICH_TMP_HOSTFILE
class MultiNodeRunner(ABC):
def __init__(self, args, world_info_base64):
self.args = args
self.user_arguments = self.parse_user_args()
self.user_script = args.user_script
self.world_info_base64 = world_info_base64
self.exports = {}
@abstractmethod
def backend_exists(self):
pass
@abstractmethod
def get_cmd(self, environment, active_resources):
pass
def add_export(self, key, var):
self.exports[key.strip()] = var.strip()
def parse_user_args(self):
return self.args.user_args
class PDSHRunner(MultiNodeRunner):
def __init__(self, args, world_info_base64):
super().__init__(args, world_info_base64)
def backend_exists(self):
return shutil.which('pdsh')
def parse_user_args(self):
return list(
map(lambda x: x if x.startswith("-") else "'{}'".format(x),
self.args.user_args))
def get_cmd(self, environment, active_resources):
environment['PDSH_RCMD_TYPE'] = 'ssh'
active_workers = ",".join(active_resources.keys())
logger.info("Running on the following workers: %s" % active_workers)
# PDSH flags for max node fan out and specific hosts to launch on
# See https://linux.die.net/man/1/pdsh for flag details
pdsh_cmd_args = ['pdsh', '-f', str(PDSH_MAX_FAN_OUT), '-w', active_workers]
exports = ""
for key, val in self.exports.items():
exports += "export {}={}; ".format(key, val)
deepspeed_launch = [
exports,
"cd {};".format(os.path.abspath('.')),
sys.executable,
"-u",
"-m",
"deepspeed.launcher.launch",
'--world_info={}'.format(self.world_info_base64),
"--node_rank=%n",
"--master_addr={}".format(self.args.master_addr),
"--master_port={}".format(self.args.master_port)
]
if self.args.detect_nvlink_pairs:
deepspeed_launch += ["--detect_nvlink_pairs"]
return pdsh_cmd_args + deepspeed_launch + [self.user_script
] + self.user_arguments
class OpenMPIRunner(MultiNodeRunner):
def __init__(self, args, world_info_base64, resource_pool):
super().__init__(args, world_info_base64)
self.resource_pool = resource_pool
self.add_export('UCX_TLS', 'tcp')
def backend_exists(self):
        #TODO: if IB is available we should suggest mvapich
return shutil.which('ompi_info')
def get_cmd(self, environment, active_resources):
#TODO: Allow for include/exclude at node-level but not gpu-level
assert self.args.include == "" and self.args.exclude == "", 'openmpi backend does not support worker include/exclusion'
assert self.args.num_nodes == -1 and self.args.num_gpus == -1, 'openmpi backend does not support limiting num nodes/gpus'
assert not self.args.detect_nvlink_pairs, "openmpi backend does not support remapping visible devices"
total_process_count = sum(self.resource_pool.values())
allow_run_as_root = os.environ.get('RUN_MPI_AS_ROOT', False)
mpirun_cmd = [
'mpirun',
'-n',
f'{total_process_count}',
'-hostfile',
f'{self.args.hostfile}',
'--mca',
'btl',
'^openib',
'--mca',
'btl_tcp_if_include',
'eth0',
]
if allow_run_as_root:
mpirun_cmd.insert(1, '--allow-run-as-root')
export_cmd = []
for k, v in self.exports.items():
export_cmd += ['-x', f'{k}={v}']
python_exec = [sys.executable, "-u"]
return mpirun_cmd + export_cmd + python_exec + [self.user_script
] + self.user_arguments
class SlurmRunner(MultiNodeRunner):
def __init__(self, args, world_info_base64, resource_pool):
super().__init__(args, world_info_base64)
self.resource_pool = resource_pool
def backend_exists(self):
return shutil.which('sinfo')
def parse_user_args(self):
user_args = []
for arg in self.args.user_args:
if arg.startswith('{') and arg.endswith('}'):
try:
arg_dict = json.loads(arg)
if 'config_files' in arg_dict:
config_files = {}
for k, v in arg_dict.get('config_files', {}).items():
config_files[k] = json.loads(v)
arg_dict['config_files'] = config_files
except json.JSONDecodeError as jde:
raise ValueError('SLURM is picky and needs you to use plain json for your configs. Check for comments and lowercase trues') from jde
arg = json.dumps(arg_dict, separators=(',', ':'))
user_args.append(arg)
return user_args
def get_cmd(self, environment, active_resources):
assert not getattr(self.args, 'detect_nvlink_pairs', False), "slurm backend does not support remapping visible devices"
total_process_count = sum(self.resource_pool.values())
srun_cmd = [
'srun',
'-n',
f'{total_process_count}',
]
if self.args.comment != '':
srun_cmd += ['--comment', self.args.comment]
if self.args.include != "":
srun_cmd.append('--include')
srun_cmd.append(f'{self.args.include}')
if self.args.exclude != "":
srun_cmd.append('--exclude')
srun_cmd.append(f'{self.args.exclude}')
if self.args.num_nodes > 0:
srun_cmd.append('--nodes')
srun_cmd.append(f'{self.args.num_nodes}')
if self.args.num_gpus > 0:
srun_cmd.append('--gpus')
srun_cmd.append(f'{self.args.num_gpus}')
exports = '--export=ALL'
for key, val in self.exports.items():
exports += f",{key}={val}"
python_exec = [sys.executable, "-u"]
command = srun_cmd + [exports] + python_exec + [self.user_script] + self.user_arguments
return command
class MVAPICHRunner(MultiNodeRunner):
def __init__(self, args, world_info_base64, resource_pool):
super().__init__(args, world_info_base64)
self.resource_pool = resource_pool
# Disable the CMA kernel module, not available on Ubuntu systems
self.add_export('MV2_SMP_USE_CMA', '0')
# If we fail this will output more verbose logging
self.add_export('MV2_DEBUG_SHOW_BACKTRACE', '1')
        # Enable cuda-aware communication
self.add_export('MV2_USE_CUDA', '1')
# Support deep learning frameworks: http://hidl.cse.ohio-state.edu/userguide/horovod/
self.add_export('MV2_SUPPORT_DL', '1')
# Support MPI_THREAD_MULTIPLE
self.add_export('MV2_ENABLE_AFFINITY', '0')
# Performance tuning flags for allgather
self.add_export('MV2_INTER_ALLGATHER_TUNING', '5')
self.add_export('MV2_CUDA_USE_NAIVE', '0')
def backend_exists(self):
#TODO: if IB is available we should suggest mvapich
mpiname_exists = shutil.which('mpiname')
exists = False
if not mpiname_exists:
warnings.warn("mpiname does not exist, mvapich is not installed properly")
else:
results = subprocess.check_output('mpiname', shell=True)
mpiname_results = results.decode('utf-8').strip()
if "MVAPICH2-GDR" in mpiname_results:
exists = True
else:
warnings.warn(
f"Expected MVAPICH2-GDR as return for mpiname but received {mpiname_results}"
)
return exists
def get_cmd(self, environment, active_resources):
#TODO: Allow for include/exclude at node-level but not gpu-level
assert self.args.include == "" and self.args.exclude == "", 'mvapich backend does not support worker include/exclusion'
assert self.args.num_nodes == -1 and self.args.num_gpus == -1, 'mvapich backend does not support limiting num nodes/gpus'
assert not self.args.detect_nvlink_pairs, "openmpi backend does not support remapping visible devices"
devices_per_node = self.resource_pool.values()
total_process_count = sum(devices_per_node)
process_per_node = list(devices_per_node)[0]
assert all([n == process_per_node for n in devices_per_node]), "mvapich requires same number of devices per node"
with open(MVAPICH_TMP_HOSTFILE, 'w') as fd:
for host in self.resource_pool.keys():
fd.write(f'{host}\n')
mpirun_cmd = [
'mpirun',
'-np',
f'{total_process_count}',
'-ppn',
f'{process_per_node}',
'--hostfile',
f'{MVAPICH_TMP_HOSTFILE}',
]
export_cmd = []
for k, v in self.exports.items():
export_cmd += ['-env', f'{k}={v}']
python_exec = [sys.executable, "-u"]
return mpirun_cmd + export_cmd + python_exec + [self.user_script] + self.user_arguments
class MosaicMLRunner(MultiNodeRunner):
def __init__(self, args, world_info_base64):
super().__init__(args, world_info_base64)
def backend_exists(self):
return True
def parse_user_args(self):
user_args = []
for arg in self.args.user_args:
if arg.startswith('{') and arg.endswith('}'):
try:
arg_dict = json.loads(arg)
if 'config_files' in arg_dict:
config_files = {}
for k, v in arg_dict.get('config_files', {}).items():
config_files[k] = json.loads(v)
arg_dict['config_files'] = config_files
except json.JSONDecodeError as jde:
raise ValueError('Please use plain JSON for your configs; remove any comments and use lowercase true/false') from jde
arg = json.dumps(arg_dict, separators=(',', ':'))
user_args.append(arg)
return user_args
def get_cmd(self, environment, active_resources):
deepspeed_launch = [
sys.executable,
"-u",
"-m",
"deepspeed.launcher.launch",
'--world_info={}'.format(self.world_info_base64),
"--node_rank={}".format(os.environ['NODE_RANK']),
"--master_addr={}".format(os.environ['MASTER_ADDR']),
"--master_port={}".format(os.environ['MASTER_PORT']),
]
return deepspeed_launch + [self.user_script] + self.user_arguments
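# Note: this runner assumes NODE_RANK, MASTER_ADDR and MASTER_PORT are already set
# in the environment by the scheduling platform. Illustrative command shape
# (hypothetical values):
#   python -u -m deepspeed.launcher.launch --world_info=<base64> \
#       --node_rank=0 --master_addr=10.0.0.1 --master_port=29500 train.py <user args>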
# Copyright 2020 The Microsoft DeepSpeed Team
"""
DeepSpeed runner is the main front-end to launching multi-worker
training jobs with DeepSpeed. By default this uses pdsh to SSH into
multiple worker nodes in parallel and launch all the necessary processes
per rank for training.
"""
import os
import sys
import json
import shutil
import base64
import argparse
import subprocess
import collections
from copy import deepcopy
import torch.cuda
from .multinode_runner import PDSHRunner, OpenMPIRunner, MVAPICHRunner, SlurmRunner, MosaicMLRunner
from .constants import PDSH_LAUNCHER, OPENMPI_LAUNCHER, MVAPICH_LAUNCHER, SLURM_LAUNCHER, MOSAICML_LAUNCHER
from ..constants import TORCH_DISTRIBUTED_DEFAULT_PORT
from ..utils import logger
DLTS_HOSTFILE = "/job/hostfile"
EXPORT_ENVS = ["NCCL", "PYTHON", "MV2", 'UCX']
DEEPSPEED_ENVIRONMENT_NAME = ".deepspeed_env"
DEEPSPEED_ENVIRONMENT_PATHS = [os.path.expanduser("~"), '.']
PDSH_MAX_FAN_OUT = 1024
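# Example .deepspeed_env contents (hypothetical values); each KEY=VALUE line found in
# DEEPSPEED_ENVIRONMENT_PATHS is forwarded to every worker by the selected launcher:
#   NCCL_SOCKET_IFNAME=eth0
#   UCX_TLS=tcp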
def parse_args(args=None):
parser = argparse.ArgumentParser(
description="DeepSpeed runner to help launch distributed "
"multi-node/multi-gpu training jobs.")
parser.add_argument("-H",
"--hostfile",
type=str,
default=DLTS_HOSTFILE,
help="Hostfile path (in MPI style) that defines the "
"resource pool available to the job (e.g., "
"worker-0 slots=4)")
parser.add_argument("-i",
"--include",
type=str,
default="",
help='''Specify hardware resources to use during execution.
String format is
NODE_SPEC[@NODE_SPEC ...],
where
NODE_SPEC=NAME[:SLOT[,SLOT ...]].
If :SLOT is omitted, include all slots on that host.
Example: -i "worker-0@worker-1:0,2" will use all slots
on worker-0 and slots [0, 2] on worker-1.
''')
parser.add_argument("-e",
"--exclude",
type=str,
default="",
help='''Specify hardware resources to NOT use during execution.
Mutually exclusive with --include. Resource formatting
is the same as --include.
Example: -e "worker-1:0" will use all available
resources except slot 0 on worker-1.
''')
parser.add_argument("--num_nodes",
type=int,
default=-1,
help="Total number of worker nodes to run on, this will use "
"the top N hosts from the given hostfile.")
parser.add_argument("--num_gpus",
type=int,
default=-1,
help="Max number of GPUs to use on each node, will use "
"[0:N) GPU ids on each node.")
parser.add_argument("--master_port",
default=TORCH_DISTRIBUTED_DEFAULT_PORT,
type=int,
help="(optional) Port used by PyTorch distributed for "
"communication during training.")
parser.add_argument("--master_addr",
default="",
type=str,
help="(optional) IP address of node 0, will be "
"inferred via 'hostname -I' if not specified.")
parser.add_argument("--launcher",
default=PDSH_LAUNCHER,
type=str,
help="(optional) choose launcher backend for multi-node "
"training. Options currently include PDSH, OpenMPI, MVAPICH.")
parser.add_argument("--launcher_args",
default="",
type=str,
help="(optional) pass launcher specific arguments as a "
"single quoted argument.")
parser.add_argument("--force_multi",
action="store_true",
help="Force multi-node launcher mode, helps in cases where user "
"wants to launch on single remote node.")
parser.add_argument("--comment",
default="",
type=str,
help="A comment for the run that can provide metadata. Is passed to the SlurmLauncher, if using")
parser.add_argument("--detect_nvlink_pairs", action="store_true",
help="(optional) autodetects nvlink pairs and remaps CUDA_VISIBLE_DEVICES along the "
"fastest connections")
parser.add_argument("user_script",
type=str,
help="User script to launch, followed by any required "
"arguments.")
parser.add_argument('user_args', nargs=argparse.REMAINDER)
return parser.parse_args(args=args)
def fetch_hostfile(hostfile_path):
if not os.path.isfile(hostfile_path):
logger.warning("Unable to find hostfile, will proceed with training "
"with local resources only.")
return None
# e.g., worker-0 slots=16
with open(hostfile_path, 'r') as fd:
resource_pool = collections.OrderedDict()
for line in fd.readlines():
line = line.strip()
if line == '':
# skip empty lines
continue
try:
hostname, slots = line.split()
_, slot_count = slots.split("=")
slot_count = int(slot_count)
except ValueError as err:
logger.error("Hostfile is not formatted correctly, unable to "
"proceed with training.")
raise err
if hostname in resource_pool:
logger.error("Hostfile contains duplicate hosts, unable to "
"proceed with training.")
raise ValueError("host {} is already defined".format(hostname))
resource_pool[hostname] = slot_count
return resource_pool
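# Example hostfile contents (hypothetical host names), matching the
# "worker-0 slots=16" format parsed above:
#   worker-0 slots=4
#   worker-1 slots=4
# For this file, fetch_hostfile returns OrderedDict([('worker-0', 4), ('worker-1', 4)]).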
def parse_resource_filter(host_info, include_str="", exclude_str=""):
'''Parse an inclusion or exclusion string and filter a hostfile dictionary.
String format is NODE_SPEC[@NODE_SPEC ...], where
NODE_SPEC = NAME[:SLOT[,SLOT ...]].
If :SLOT is omitted, include/exclude all slots on that host.
Examples:
include_str="worker-0@worker-1:0,2" will use all slots on worker-0 and
slots [0, 2] on worker-1.
exclude_str="worker-1:0" will use all available resources except
slot 0 on worker-1.
'''
# Constants that define our syntax
NODE_SEP = '@'
SLOT_LIST_START = ':'
SLOT_SEP = ','
# Ensure include/exclude are mutually exclusive
if (include_str != "") and (exclude_str != ""):
raise ValueError('include_str and exclude_str are mutually exclusive.')
# no-op
if (include_str == "") and (exclude_str == ""):
return host_info
# Either build from scratch or remove items
filtered_hosts = dict()
if include_str:
parse_str = include_str
if exclude_str != "":
filtered_hosts = deepcopy(host_info)
parse_str = exclude_str
# foreach node in the list
for node_config in parse_str.split(NODE_SEP):
# Node can either be alone or node:slot,slot,slot
if SLOT_LIST_START in node_config:
hostname, slots = node_config.split(SLOT_LIST_START)
slots = [int(x) for x in slots.split(SLOT_SEP)]
# sanity checks
if hostname not in host_info:
raise ValueError("Hostname '{}' not found in hostfile".format(hostname))
for s in slots:
if s not in host_info[hostname]:
raise ValueError("No slot '{}' specified on host '{}'".format(
s,
hostname))
# If include string, build the list from here
if include_str:
filtered_hosts[hostname] = slots
elif exclude_str:
for s in slots:
logger.info('removing {} from {}'.format(s, hostname))
filtered_hosts[hostname].remove(s)
# User just specified the whole node
else:
hostname = node_config
# sanity check hostname
if hostname not in host_info:
raise ValueError("Hostname '{}' not found in hostfile".format(hostname))
if include_str:
filtered_hosts[hostname] = host_info[hostname]
elif exclude_str:
filtered_hosts[hostname] = []
# Post-processing to remove duplicates and empty nodes
del_keys = []
for hostname in filtered_hosts:
# Remove duplicates
filtered_hosts[hostname] = list(set(filtered_hosts[hostname]))
# Remove empty hosts
if len(filtered_hosts[hostname]) == 0:
del_keys.append(hostname)
for name in del_keys:
del filtered_hosts[name]
# Lastly, go over filtered_hosts and convert to a OrderedDict() to ensure
# we map ranks to nodes correctly by maintaining host_info ordering.
ordered_hosts = collections.OrderedDict()
for host in host_info:
if host in filtered_hosts:
ordered_hosts[host] = filtered_hosts[host]
return ordered_hosts
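# Illustrative use of the filter semantics documented above (hypothetical pool);
# slot order within a host is not guaranteed because duplicates are removed via set():
#   pool = {'worker-0': [0, 1, 2, 3], 'worker-1': [0, 1, 2, 3]}
#   parse_resource_filter(pool, include_str="worker-0@worker-1:0,2")
#   # -> OrderedDict([('worker-0', [0, 1, 2, 3]), ('worker-1', [0, 2])])
#   parse_resource_filter(pool, exclude_str="worker-1:0")
#   # -> OrderedDict([('worker-0', [0, 1, 2, 3]), ('worker-1', [1, 2, 3])])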
def parse_inclusion_exclusion(resource_pool, inclusion, exclusion):
active_resources = collections.OrderedDict()
for hostname, slots in resource_pool.items():
active_resources[hostname] = list(range(slots))
return parse_resource_filter(active_resources,
include_str=inclusion,
exclude_str=exclusion)
def encode_world_info(world_info):
world_info_json = json.dumps(world_info).encode('utf-8')
world_info_base64 = base64.urlsafe_b64encode(world_info_json).decode('utf-8')
return world_info_base64
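# Illustrative round trip (hypothetical resources); the launcher on each node decodes
# the string back into the resource map:
#   info = encode_world_info({'worker-0': [0, 1], 'worker-1': [0, 1]})
#   json.loads(base64.urlsafe_b64decode(info))  # -> {'worker-0': [0, 1], 'worker-1': [0, 1]}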
def main(args=None):
args = parse_args(args)
if args.num_nodes >= 0 or args.num_gpus >= 0:
if args.include != "" or args.exclude != "":
raise ValueError("Cannot specify num_nodes/gpus with include/exclude")
multi_node_exec = True
resource_pool = fetch_hostfile(args.hostfile)
if not resource_pool:
resource_pool = {}
device_count = torch.cuda.device_count()
if device_count == 0:
raise RuntimeError("Unable to proceed, no GPU resources available")
resource_pool['localhost'] = device_count
args.master_addr = "127.0.0.1"
multi_node_exec = False
if not multi_node_exec and args.num_nodes > 1:
raise ValueError("Num nodes is >1 but no extra nodes available via hostfile")
active_resources = parse_inclusion_exclusion(resource_pool,
args.include,
args.exclude)
env = os.environ.copy()
if not args.master_addr:
first_host = list(active_resources.keys())[0]
hostname_cmd = ["ssh {} hostname -I".format(first_host)]
result = subprocess.check_output(hostname_cmd, shell=True)
args.master_addr = result.decode('utf-8').split()[0]
logger.info("Using IP address of {} for node {}".format(
args.master_addr,
first_host))
if args.num_nodes > 0:
updated_active_resources = collections.OrderedDict()
for count, hostname in enumerate(active_resources.keys()):
if args.num_nodes == count:
break
updated_active_resources[hostname] = active_resources[hostname]
active_resources = updated_active_resources
if args.num_gpus > 0:
updated_active_resources = collections.OrderedDict()
for hostname in active_resources.keys():
updated_active_resources[hostname] = list(range(args.num_gpus))
active_resources = updated_active_resources
# encode world info as base64 to make it easier to pass via command line
world_info_base64 = encode_world_info(active_resources)
multi_node_exec = args.force_multi or len(active_resources) > 1
if not multi_node_exec:
deepspeed_launch = [
sys.executable,
"-u",
"-m",
"deepspeed.launcher.launch",
"--world_info={}".format(world_info_base64),
"--master_addr={}".format(args.master_addr),
"--master_port={}".format(args.master_port)
]
if args.detect_nvlink_pairs:
deepspeed_launch += ["--detect_nvlink_pairs"]
cmd = deepspeed_launch + [args.user_script] + args.user_args
else:
args.launcher = args.launcher.lower()
if args.launcher == PDSH_LAUNCHER:
runner = PDSHRunner(args, world_info_base64)
elif args.launcher == OPENMPI_LAUNCHER:
runner = OpenMPIRunner(args, world_info_base64, resource_pool)
elif args.launcher == MVAPICH_LAUNCHER:
runner = MVAPICHRunner(args, world_info_base64, resource_pool)
elif args.launcher == SLURM_LAUNCHER:
runner = SlurmRunner(args, world_info_base64, resource_pool)
elif args.launcher == MOSAICML_LAUNCHER:
runner = MosaicMLRunner(args, world_info_base64)
else:
raise NotImplementedError(f"Unknown launcher {args.launcher}")
if not runner.backend_exists():
raise RuntimeError(f"launcher '{args.launcher}' not installed.")
curr_path = os.path.abspath('.')
if 'PYTHONPATH' in env:
env['PYTHONPATH'] = curr_path + ":" + env['PYTHONPATH']
else:
env['PYTHONPATH'] = curr_path
exports = ""
for var in env.keys():
if any([var.startswith(name) for name in EXPORT_ENVS]):
runner.add_export(var, env[var])
for environ_path in DEEPSPEED_ENVIRONMENT_PATHS:
environ_file = os.path.join(environ_path, DEEPSPEED_ENVIRONMENT_NAME)
if os.path.isfile(environ_file):
with open(environ_file, 'r') as fd:
for var in fd.readlines():
key, val = var.strip().split('=', 1)  # split on the first '=' and drop the trailing newline
runner.add_export(key, val)
cmd = runner.get_cmd(env, active_resources)
logger.info("cmd = {}".format(' '.join(cmd)))
result = subprocess.Popen(cmd, env=dict(env, **runner.exports))
result.wait()
# In case of failure must propagate the error-condition back to the caller (usually shell). The
# actual error and traceback should have been printed in the subprocess, so in order to avoid
# unnecessary noise we just quietly exit here with the same code as the subprocess
if result.returncode > 0:
sys.exit(result.returncode)
if __name__ == "__main__":
main()
from .replace_module import replace_transformer_layer
from . import adam
from . import lamb
from . import sparse_attention
from . import transformer
from .transformer import DeepSpeedTransformerLayer, DeepSpeedTransformerConfig
from .module_inject import replace_module
from ..git_version_info import compatible_ops as __compatible_ops__
from .cpu_adam import DeepSpeedCPUAdam
from .fused_adam import FusedAdam
'''
Copyright 2020 The Microsoft DeepSpeed Team
Copyright NVIDIA/apex
This file is adapted from NVIDIA/apex, commit a109f85
'''
import torch
class MultiTensorApply(object):
def __init__(self, chunk_size):
self.chunk_size = chunk_size
def __call__(self, op, noop_flag_buffer, tensor_lists, *args):
return op(self.chunk_size, noop_flag_buffer, tensor_lists, *args)
'''
Copyright 2020 The Microsoft DeepSpeed Team.
Licensed under the MIT license.
'''
from ..op_builder import AsyncIOBuilder
../../csrc
from .fused_lamb import FusedLamb
../../op_builder
from .sparsity_config import SparsityConfig, DenseSparsityConfig, FixedSparsityConfig, VariableSparsityConfig, BigBirdSparsityConfig, BSLongformerSparsityConfig, LocalSlidingWindowSparsityConfig
from .softmax import Softmax
from .matmul import MatMul
from .sparse_self_attention import SparseSelfAttention
from .bert_sparse_self_attention import BertSparseSelfAttention
from .sparse_attention_utils import SparseAttentionUtils
from .transformer import DeepSpeedTransformerLayer, DeepSpeedTransformerConfig
============================= test session starts ==============================
platform linux -- Python 3.6.9, pytest-6.0.1, py-1.9.0, pluggy-0.13.1
rootdir: /home/chengli1/projects/DeepSpeed
plugins: forked-1.3.0, hypothesis-5.41.3, xdist-2.1.0, cov-2.10.1
collected 0 items
============================ no tests ran in 0.01s =============================
from ..runtime.pipe import PipelineModule, LayerSpec, TiedLayerSpec