Unverified commit fd2f970b authored by Reza Yazdani, committed by GitHub

Transformer-kernel - supporting arbitrary sequence lengths (#587)

Co-authored-by: Jeff Rasley <jerasley@microsoft.com>
Parent 6380ee35
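
What the diffs below add up to: the layer is now created with init_seq_length = 128 instead of a user-supplied max_seq_length, the C++ binding compares each incoming tensor's sequence dimension against the layer's cached length (GetSeqLength) and calls SetSeqLength when it changes, and the Python wrapper pads any sequence whose length is not a multiple of 16 up to the next multiple of 16 (extending the additive attention mask with -10000 so the padded keys are ignored) and narrows the output back afterwards. Below is a minimal pure-PyTorch sketch of that pad/unpad round trip; the helper names are illustrative only, not part of the DeepSpeed API.

import torch

def pad_to_multiple_of_16(hidden_states, attention_mask, mask_value=-10000.0):
    """Pad dim=1 of the activations (and dim=3 of the additive attention mask)
    to the next multiple of 16; padded key positions get a large negative mask
    value so the softmax gives them ~zero weight."""
    bsz, seq_len, hidden = hidden_states.size()
    pad = (16 - seq_len % 16) % 16
    if pad == 0:
        return hidden_states, attention_mask, seq_len
    filler = torch.randn(bsz, pad, hidden,
                         device=hidden_states.device, dtype=hidden_states.dtype)
    hidden_states = torch.cat((hidden_states, filler), dim=1)
    mask_pad = torch.full((bsz, attention_mask.size(1), attention_mask.size(2), pad),
                          mask_value,
                          device=attention_mask.device, dtype=attention_mask.dtype)
    attention_mask = torch.cat((attention_mask, mask_pad), dim=3)
    return hidden_states, attention_mask, seq_len

def unpad(output, orig_seq_len):
    """Drop the padded positions again (mirrors the torch.narrow call in forward)."""
    return torch.narrow(output, 1, 0, orig_seq_len)

For example, an input of shape (8, 119, 1024) is padded to (8, 128, 1024), the kernel runs on the padded tensors, and unpad(output, 119) recovers the original shape.
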
Subproject commit fa1d1a71c48623db8a091d9cf636a5fe3b8f43c7
Subproject commit abb270641ca8c33476282bde29916c395a060ae9
@@ -14,6 +14,8 @@
static std::unordered_map<int, std::shared_ptr<void>> s_transformer_layers;
const int init_seq_length = 128;
// C++ interface
template <typename T>
@@ -591,7 +593,6 @@ int create_transformer_layer(int layer_id,
int hidden_dim,
int num_heads,
int intermediate_size,
int seq_length,
float attn_dropout_ratio,
float hidden_dropout_ratio,
int seed,
@@ -604,14 +605,14 @@ int create_transformer_layer(int layer_id,
{
Context::Instance().SetSeed(seed);
Context::Instance().TestGemmFP16(
test_gemm, batch_size, seq_length, num_heads, hidden_dim / num_heads);
test_gemm, batch_size, init_seq_length, num_heads, hidden_dim / num_heads);
auto layer = std::make_shared<BertTransformerLayer<T>>(layer_id,
batch_size,
hidden_dim,
num_heads,
intermediate_size,
seq_length,
init_seq_length,
attn_dropout_ratio,
hidden_dropout_ratio,
pre_or_postLayerNorm,
@@ -873,6 +874,12 @@ std::vector<torch::Tensor> ds_transformer_backward(int layer_id,
std::shared_ptr<BertTransformerLayer<T>> layer =
std::static_pointer_cast<BertTransformerLayer<T>>(s_transformer_layers[layer_id]);
int seq_len = layer->GetSeqLength();
if (g_output.size(1) != seq_len) {
seq_len = g_output.size(1);
layer->SetSeqLength(seq_len, bsz);
}
auto grad_input = torch::empty_like(input);
auto grad_attn_qkvw = torch::empty_like(attn_qkvw);
auto grad_attn_qkvb = torch::empty_like(attn_qkvb);
......
@@ -80,7 +80,8 @@ __global__ void attn_softmax(float* vals,
#endif
int iters = warp_num;
if (seq_length < iteration_stride) iters = warp_num / (iteration_stride / seq_length);
if (seq_length < iteration_stride)
iters = warp_num / (iteration_stride / max_threads_in_sequence);
for (int i = 1; i < iters; i *= 2) {
auto temp = g.shfl_xor(max_val, i);
@@ -113,7 +114,8 @@ __global__ void attn_softmax(float* vals,
#endif
int iters = warp_num;
if (seq_length < iteration_stride) iters = warp_num / (iteration_stride / seq_length);
if (seq_length < iteration_stride)
iters = warp_num / (iteration_stride / max_threads_in_sequence);
for (int i = 1; i < iters; i *= 2) { sum += g.shfl_xor(sum, i); }
@@ -216,7 +218,8 @@ __global__ void attn_softmax(__half* vals,
#endif
int iters = warp_num;
if (seq_length < iteration_stride) iters = warp_num / (iteration_stride / seq_length);
if (seq_length < iteration_stride)
iters = warp_num / (iteration_stride / max_threads_in_sequence);
for (int i = 1; i < iters; i *= 2) {
auto temp = g.shfl_xor(max_val, i);
@@ -252,7 +255,8 @@ __global__ void attn_softmax(__half* vals,
#endif
int iters = warp_num;
if (seq_length < iteration_stride) iters = warp_num / (iteration_stride / seq_length);
if (seq_length < iteration_stride)
iters = warp_num / (iteration_stride / max_threads_in_sequence);
for (int i = 1; i < iters; i *= 2) { sum += g.shfl_xor(sum, i); }
@@ -339,7 +343,9 @@ void launch_attn_softmax<float>(float* vals,
dim3 block_dim(seq_length4 > threads ? ((sequence_length + subblock_max_workload - 1) /
subblock_max_workload * threads)
: threads);
iterations =
(sequence_length < subblock_max_workload ? (seq_length4 + threads - 1) / threads
: MAX_THREAD_ITERATIONS);
if (sequence_length <= 512)
attn_softmax<32, (threads / 128), 128><<<grid_dim, block_dim, 0, stream>>>(
vals, attn_mask, heads, seq_length4, iterations);
@@ -408,7 +414,9 @@ void launch_attn_softmax<__half>(__half* vals,
dim3 block_dim(seq_length4 > threads ? ((sequence_length + subblock_max_workload - 1) /
subblock_max_workload * threads)
: threads);
iterations =
(sequence_length < subblock_max_workload ? (seq_length4 + threads - 1) / threads
: MAX_THREAD_ITERATIONS);
if (sequence_length <= 512)
attn_softmax<32, (threads / 128), 128><<<grid_dim, block_dim, 0, stream>>>(
vals, attn_mask, heads, seq_length4, iterations);
......
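
Note on the launcher hunks above: the per-thread iteration count for attn_softmax is now derived with a ceiling division over the vectorized sequence length, so lengths that do not divide evenly by the block's thread count are still fully covered, and the in-kernel shuffle reductions size their loop from max_threads_in_sequence rather than the raw seq_length. A rough Python restatement of the iteration arithmetic follows; the constant values (threads per block, MAX_THREAD_ITERATIONS, the factor-of-4 vectorization) are assumptions for illustration, not taken from the kernel headers.

# Assumed constants, chosen only to make the arithmetic concrete.
THREADS = 128                # assumed threads per block
MAX_THREAD_ITERATIONS = 8    # assumed cap on per-thread iterations

def softmax_iterations(sequence_length, subblock_max_workload):
    # seq_length4 mirrors the launcher's vectorized length (assumed factor of 4)
    seq_length4 = sequence_length // 4
    if sequence_length < subblock_max_workload:
        # ceiling division: ceil(seq_length4 / THREADS)
        return (seq_length4 + THREADS - 1) // THREADS
    return MAX_THREAD_ITERATIONS

print(softmax_iterations(120, subblock_max_workload=4096))  # -> 1
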
@@ -18,7 +18,6 @@ stochastic_transformer_cuda_module = None
class TransformerConfig():
def __init__(self,
batch_size,
max_seq_length,
hidden_size,
intermediate_size,
heads,
@@ -30,7 +29,6 @@ class TransformerConfig():
self.batch_size = batch_size
self.hidden_size = hidden_size
self.intermediate_size = intermediate_size
self.max_seq_length = max_seq_length
self.heads = heads
self.attn_dropout_ratio = attn_dropout_ratio
self.hidden_dropout_ratio = hidden_dropout_ratio
@@ -92,7 +90,6 @@ class DeepSpeedTransformerConfig(TransformerConfig):
"""
def __init__(self,
batch_size=-1,
max_seq_length=-1,
hidden_size=-1,
intermediate_size=-1,
heads=-1,
@@ -112,7 +109,6 @@ class DeepSpeedTransformerConfig(TransformerConfig):
super(DeepSpeedTransformerConfig,
self).__init__(
batch_size,
max_seq_length,
hidden_size,
(intermediate_size if intermediate_size > 0 else 4 * hidden_size),
heads,
@@ -142,7 +138,7 @@ class DeepSpeedTransformerConfig(TransformerConfig):
@classmethod
def from_json_file(cls, json_file):
with open(json_file, "r", encoding='utf-8') as reader:
with open(json_file, "r", encoding='utf-16') as reader:
text = reader.read()
return cls.from_dict(json.loads(text))
@@ -177,6 +173,18 @@ class DeepSpeedTransformerFunction(Function):
cuda_module = stochastic_transformer_cuda_module if config.stochastic_mode else transformer_cuda_module
forward_func = cuda_module.forward_fp16 if config.fp16 else cuda_module.forward_fp32
inp_size = input.size()
if inp_size[1] % 16 != 0:
input = torch.cat((input,
torch.randn((inp_size[0],
(16 - (inp_size[1] % 16)),
inp_size[2]),
device=input.device,
dtype=input.dtype)),
1)
input_mask = torch.cat((input_mask, torch.ones((inp_size[0], input_mask.shape[1], input_mask.shape[2], \
(16 - (inp_size[1] % 16))), device=input_mask.device, dtype=input_mask.dtype) * -10000), 3)
(output,
inp_norm,
qkv_tf,
@@ -303,11 +311,17 @@ class DeepSpeedTransformerFunction(Function):
ctx.attn_layer_norm_var = attn_layer_norm_var
ctx.layer_norm_var = layer_norm_var
if inp_size[1] % 16 != 0:
output = torch.narrow(output, 1, 0, inp_size[1])
return output
@staticmethod
def backward(ctx, grad_output):
bsz = grad_output.shape[0]
grad_output_shape = grad_output.size()
if grad_output_shape[1] % 16 != 0:
grad_output = torch.cat((grad_output, torch.zeros((bsz, (16 - (grad_output_shape[1] % 16)), \
grad_output_shape[2]), device=grad_output.device, dtype=grad_output.dtype)), 1)
if bsz > ctx.config.batch_size:
raise ValueError('grad_output batch size exceeds the limit.')
@@ -398,6 +412,9 @@ class DeepSpeedTransformerFunction(Function):
norm_w,
norm_b)
if grad_output_shape[1] % 16 != 0:
grad_input = torch.narrow(grad_input, 1, 0, grad_output_shape[1])
return (grad_input,
None,
None,
@@ -501,7 +518,6 @@ class DeepSpeedTransformerLayer(nn.Module):
self.config.hidden_size,
self.config.heads,
self.config.intermediate_size,
self.config.max_seq_length,
self.config.attn_dropout_ratio,
self.config.hidden_dropout_ratio,
self.config.seed,
......
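
The backward path in the hunks above mirrors the forward padding: a grad_output whose sequence length is not a multiple of 16 is zero-padded along dim=1 before the kernel call, and the resulting grad_input is narrowed back to the original length (forward pads the activations with random filler, backward pads the gradients with zeros). A matching self-contained sketch, again with illustrative helper names rather than the DeepSpeed API:

import torch

def pad_grad_output(grad_output):
    """Zero-pad dim=1 up to the next multiple of 16, mirroring the forward padding."""
    bsz, seq_len, hidden = grad_output.size()
    pad = (16 - seq_len % 16) % 16
    if pad:
        zeros = torch.zeros(bsz, pad, hidden,
                            device=grad_output.device, dtype=grad_output.dtype)
        grad_output = torch.cat((grad_output, zeros), dim=1)
    return grad_output, seq_len

def unpad_grad_input(grad_input, orig_seq_len):
    """Keep only the gradient rows for real (unpadded) positions."""
    return torch.narrow(grad_input, 1, 0, orig_seq_len)
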
@@ -150,7 +150,7 @@ def create_models(ds_config):
hidden_act="gelu",
hidden_dropout_prob=ds_config.hidden_dropout_ratio,
attention_probs_dropout_prob=ds_config.attn_dropout_ratio,
max_position_embeddings=ds_config.max_seq_length,
max_position_embeddings=512,
type_vocab_size=2,
initializer_range=ds_config.initializer_range)
@@ -210,25 +210,18 @@ def set_seed(seed):
torch.manual_seed(seed)
def run_backward(ds_config, atol=1e-2, verbose=False):
def run_backward(ds_config, seq_len, atol=1e-2, verbose=False):
set_seed(123)
bert_encoder, ds_encoder = create_models(ds_config)
# prepare test data
kwargs = kwargs_fp16 if ds_config.fp16 else kwargs_fp32
hidden_states = torch.randn(ds_config.batch_size,
ds_config.max_seq_length,
seq_len,
ds_config.hidden_size,
**kwargs)
input_mask = torch.randn(ds_config.batch_size,
1,
1,
ds_config.max_seq_length,
**kwargs)
Y = torch.randn(ds_config.batch_size,
ds_config.max_seq_length,
ds_config.hidden_size,
**kwargs)
input_mask = torch.randn(ds_config.batch_size, 1, 1, seq_len, **kwargs)
Y = torch.randn(ds_config.batch_size, seq_len, ds_config.hidden_size, **kwargs)
# run baseline
base_results = bert_encoder(hidden_states,
@@ -257,12 +250,12 @@ def run_backward(ds_config, atol=1e-2, verbose=False):
#test_backward[3-1024-120-16-24-True-True-0.05]
@pytest.mark.parametrize('batch_size, hidden_size, seq_len, heads, num_layers, is_preln, use_fp16, atol',
[
(3,1024,120,16,24,True,False, 0.05),
(3,1024,120,16,24,True,True, 0.05),
(3,1024,56,16,24,False,False, 0.1),
(3,1024,56,16,24,False,True, 0.2),
(3,128,56,2,24,False,False, 0.1),
(3,128,56,2,24,False,True, 0.2),
(3,1024,119,16,24,True,False, 0.05),
(3,1024,115,16,24,True,True, 0.05),
(1024,128,10,2,2,False,False, 0.1),
(3,1024,52,16,24,False,True, 0.2),
(3,128,51,2,24,False,False, 0.1),
(3,128,54,2,24,False,True, 0.2),
]) # yapf: disable
def test_backward(batch_size,
hidden_size,
@@ -282,7 +275,6 @@ def test_backward(batch_size,
ds_config.batch_size = batch_size
ds_config.hidden_size = hidden_size
ds_config.intermediate_size = hidden_size
ds_config.max_seq_length = seq_len
ds_config.heads = heads
ds_config.attn_dropout_ratio = 0.0
ds_config.hidden_dropout_ratio = 0.0
@@ -291,7 +283,7 @@ def test_backward(batch_size,
ds_config.initializer_range = 0.02
ds_config.fp16 = use_fp16
run_backward(ds_config, atol=atol)
run_backward(ds_config, seq_len, atol=atol)
#@pytest.mark.parametrize('batch_size, hidden_size, seq_len, heads, num_layers, is_preln, use_fp16, atol',
......
@@ -117,7 +117,7 @@ def create_models(ds_config):
hidden_act="gelu",
hidden_dropout_prob=ds_config.hidden_dropout_ratio,
attention_probs_dropout_prob=ds_config.attn_dropout_ratio,
max_position_embeddings=ds_config.max_seq_length,
max_position_embeddings=512,
type_vocab_size=2,
initializer_range=ds_config.initializer_range,
fp16=ds_config.fp16)
@@ -186,13 +186,8 @@ def run_forward(ds_config, seq_len, atol=1e-2, verbose=False, test_bsz=None):
# prepare test data
kwargs = kwargs_fp16 if ds_config.fp16 else kwargs_fp32
hidden_states = torch.randn(bsz,
seq_len, #ds_config.max_seq_length,
ds_config.hidden_size,
**kwargs)
input_mask = torch.randn(bsz, 1, 1,
seq_len, #ds_config.max_seq_length,
**kwargs)
hidden_states = torch.randn(bsz, seq_len, ds_config.hidden_size, **kwargs)
input_mask = torch.randn(bsz, 1, 1, seq_len, **kwargs)
# run baseline
base_results = bert_encoder(hidden_states,
@@ -213,25 +208,25 @@ def run_forward(ds_config, seq_len, atol=1e-2, verbose=False, test_bsz=None):
# FP16 test cases can only run on the devices support FP16.
@pytest.mark.parametrize('batch_size, hidden_size, seq_len, heads, num_layers, is_preln, use_fp16',
[
(8,256,128,4,3,True,False),
(8,256,128,4,3,True,True),
(64,1024,128,16,3,True,False),
(64,1024,128,16,3,True,True),
(8,1024,384,16,3,True,False),
(8,256,53,4,3,True,False),
(8,256,52,4,3,True,True),
(3,1024,51,16,3,True,False),
(3,1024,54,16,3,True,True),
(8,1024,381,16,3,True,False),
(8,1024,384,16,3,True,True),
(8,1024,384,16,3,True,True),
(8,1024,120,16,3,True,False),
(8,1024,119,16,3,True,False),
(8,1024,120,16,3,True,True),
(8,1024,512,16,3,True,False),
(8,1024,509,16,3,True,False),
(8,1024,512,16,3,True,True),
(64,1024,56,16,3,False,False),
(64,1024,56,16,3,False,True),
(64,1024,53,16,3,False,True),
(64,1024,24,16,3,False,False),
(64,1024,24,16,3,False,True),
(64,1024,21,16,3,False,True),
(8,1024,384,16,3,False,False),
(8,1024,384,16,3,False,True),
(8,1024,512,16,3,False,False),
(8,1024,512,16,3,False,True),
(8,1024,511,16,3,False,True),
(8,1536,128,24,3,False,False),
(8,1536,128,24,3,False,True),
(8,2048,128,32,3,False,False),
@@ -259,7 +254,6 @@ def test_forward(batch_size,
ds_config.layer_id = None
ds_config.batch_size = batch_size
ds_config.hidden_size = hidden_size
ds_config.max_seq_length = 128 #seq_len
ds_config.intermediate_size = 4 * hidden_size
ds_config.heads = heads
ds_config.attn_dropout_ratio = 0.0
@@ -297,7 +291,6 @@ def test_forward_with_small_bsz(batch_size,
ds_config.batch_size = batch_size
ds_config.hidden_size = hidden_size
ds_config.intermediate_size = 4 * hidden_size
ds_config.max_seq_length = seq_len
ds_config.heads = heads
ds_config.attn_dropout_ratio = 0.0
ds_config.hidden_dropout_ratio = 0.0
@@ -332,7 +325,6 @@ def test_forward_stochastic(batch_size,
ds_config.batch_size = batch_size
ds_config.hidden_size = hidden_size
ds_config.intermediate_size = 4 * hidden_size
ds_config.max_seq_length = seq_len
ds_config.heads = heads
ds_config.attn_dropout_ratio = 0.0
ds_config.hidden_dropout_ratio = 0.0
......