From 3d097bb8657fb78c18df2b7e9e56093764b7a564 Mon Sep 17 00:00:00 2001 From: Connor Holmes Date: Thu, 22 Sep 2022 18:04:36 -0700 Subject: [PATCH] Extend scratch buffer for long prompts (#2212) Co-authored-by: Reza Yazdani <44502768+RezaYazdaniAminabadi@users.noreply.github.com> Co-authored-by: Reza Yazdani Co-authored-by: Jeff Rasley --- .../inference/csrc/apply_rotary_pos_emb.cu | 63 +++++-- csrc/transformer/inference/csrc/dequantize.cu | 4 - csrc/transformer/inference/csrc/gelu.cu | 8 +- .../transformer/inference/csrc/pt_binding.cpp | 163 ++++++++++++------ csrc/transformer/inference/csrc/transform.cu | 38 ++-- .../inference/includes/inference_context.h | 62 ++++++- .../includes/inference_cuda_layers.h | 7 +- .../inference/transformer_inference.py | 88 +++++----- 8 files changed, 279 insertions(+), 154 deletions(-) mode change 100755 => 100644 deepspeed/ops/transformer/inference/transformer_inference.py diff --git a/csrc/transformer/inference/csrc/apply_rotary_pos_emb.cu b/csrc/transformer/inference/csrc/apply_rotary_pos_emb.cu index e7279d77..4a91975a 100644 --- a/csrc/transformer/inference/csrc/apply_rotary_pos_emb.cu +++ b/csrc/transformer/inference/csrc/apply_rotary_pos_emb.cu @@ -18,7 +18,8 @@ __global__ void apply_rotary_pos_emb(float* mixed_query, unsigned seq_offset, unsigned num_heads, unsigned head_size, - unsigned total_count) + unsigned total_count, + int max_out_tokens) { cg::thread_block b = cg::this_thread_block(); cg::thread_block_tile g = cg::tiled_partition(b); @@ -31,13 +32,15 @@ __global__ void apply_rotary_pos_emb(float* mixed_query, unsigned offset = head_id * head_size; unsigned seq_id = (head_id / num_heads) % seq_len + seq_offset; + unsigned seq_index = head_id % seq_len; + unsigned k_offset = (seq_index + (head_id / seq_len) * max_out_tokens) * head_size; if (head_id < total_count) { while (lane < rotary_dim) { float inv_freq = (float)((lane / 2) * 2) / (float)rotary_dim; inv_freq = 1.0 / powf(10000.0, inv_freq) * (float)seq_id; float q = mixed_query[offset + lane]; - float k = key_layer[offset + lane]; + float k = key_layer[k_offset + lane]; float rotary_sign = (lane % 2 == 1 ? -1.0 : 1.0); float q_rot = (q * rotary_sign); float k_rot = (k * rotary_sign); @@ -47,7 +50,7 @@ __global__ void apply_rotary_pos_emb(float* mixed_query, k = k * cosf(inv_freq) + k_rot * sinf(inv_freq); mixed_query[offset + lane] = q; - key_layer[offset + lane] = k; + key_layer[k_offset + lane] = k; lane += WARP_SIZE; } @@ -61,7 +64,8 @@ __global__ void apply_rotary_pos_emb(__half* mixed_query, unsigned seq_offset, unsigned num_heads, unsigned head_size, - unsigned total_count) + unsigned total_count, + int max_out_tokens) { #if __CUDA_ARCH__ >= 700 cg::thread_block b = cg::this_thread_block(); @@ -75,13 +79,15 @@ __global__ void apply_rotary_pos_emb(__half* mixed_query, unsigned offset = head_id * head_size; unsigned seq_id = (head_id / num_heads) % seq_len + seq_offset; + unsigned seq_index = head_id % seq_len; + unsigned k_offset = (seq_index + (head_id / seq_len) * max_out_tokens) * head_size; if (head_id < total_count) { while (lane < rotary_dim) { float inv_freq = (float)((lane / 2) * 2) / (float)rotary_dim; inv_freq = 1.0 / powf(10000.0, inv_freq) * (float)seq_id; float q = (float)mixed_query[offset + lane]; - float k = (float)key_layer[offset + lane]; + float k = (float)key_layer[k_offset + lane]; float rotary_sign = (lane % 2 == 1 ? 
-1.0 : 1.0); float q_rot = (q * rotary_sign); float k_rot = (k * rotary_sign); @@ -91,7 +97,7 @@ __global__ void apply_rotary_pos_emb(__half* mixed_query, k = k * cosf(inv_freq) + k_rot * sinf(inv_freq); mixed_query[offset + lane] = (__half)q; - key_layer[offset + lane] = (__half)k; + key_layer[k_offset + lane] = (__half)k; lane += WARP_SIZE; } @@ -105,7 +111,8 @@ __global__ void apply_rotary_pos_emb1(float* mixed_query, unsigned seq_offset, unsigned num_heads, unsigned head_size, - unsigned total_count) + unsigned total_count, + int max_out_tokens) { cg::thread_block b = cg::this_thread_block(); cg::thread_block_tile g = cg::tiled_partition(b); @@ -118,13 +125,15 @@ __global__ void apply_rotary_pos_emb1(float* mixed_query, unsigned offset = head_id * head_size; unsigned seq_id = (head_id / num_heads) % seq_len + seq_offset; + unsigned seq_index = head_id % seq_len; + unsigned k_offset = (seq_index + (head_id / seq_len) * max_out_tokens) * head_size; if (head_id < total_count) { while (lane < rotary_dim) { float inv_freq = (float)((lane / 2) * 2) / (float)rotary_dim; inv_freq = 1.0 / powf(10000.0, inv_freq) * (float)seq_id; float q = mixed_query[offset + lane]; - float k = key_layer[offset + lane]; + float k = key_layer[k_offset + lane]; float rotary_sign = (lane % 2 == 1 ? -1.0 : 1.0); float q_rot = (q * rotary_sign); float k_rot = (k * rotary_sign); @@ -134,7 +143,7 @@ __global__ void apply_rotary_pos_emb1(float* mixed_query, k = k * cosf(inv_freq) + k_rot * sinf(inv_freq); mixed_query[offset + lane] = q; - key_layer[offset + lane] = k; + key_layer[k_offset + lane] = k; lane += WARP_SIZE; } @@ -147,7 +156,8 @@ __global__ void apply_rotary_pos_emb1(__half* mixed_query, unsigned seq_offset, unsigned num_heads, unsigned head_size, - unsigned total_count) + unsigned total_count, + int max_out_tokens) { #if __CUDA_ARCH__ >= 700 cg::thread_block b = cg::this_thread_block(); @@ -160,7 +170,7 @@ __global__ void apply_rotary_pos_emb1(__half* mixed_query, unsigned head_id = blockIdx.x * MAX_WARP_NUM + gid; unsigned seq_index = head_id % seq_len; unsigned offset = head_id * head_size; - unsigned k_offset = (seq_index + (head_id / seq_len) * MAX_OUT_TOKES) * head_size; + unsigned k_offset = (seq_index + (head_id / seq_len) * max_out_tokens) * head_size; constexpr unsigned mask[32] = { 0x1 | 0x1000, 0x2 | 0x2000, 0x4 | 0x4000, 0x8 | 0x8000, 0x10 | 0x10000, @@ -209,17 +219,32 @@ void launch_apply_rotary_pos_emb(T* mixed_query, unsigned batch, bool rotate_half, bool rotate_every_two, - cudaStream_t stream) + cudaStream_t stream, + int max_out_tokens) { int total_count = batch * num_heads * seq_len; dim3 block_dims(1024); dim3 grid_dims((total_count - 1) / MAX_WARP_NUM + 1); // (batch_size); if (rotate_every_two) - apply_rotary_pos_emb<<>>( - mixed_query, key_layer, rotary_dim, seq_len, offset, num_heads, head_size, total_count); + apply_rotary_pos_emb<<>>(mixed_query, + key_layer, + rotary_dim, + seq_len, + offset, + num_heads, + head_size, + total_count, + max_out_tokens); else if (rotate_half) - apply_rotary_pos_emb1<<>>( - mixed_query, key_layer, rotary_dim, seq_len, offset, num_heads, head_size, total_count); + apply_rotary_pos_emb1<<>>(mixed_query, + key_layer, + rotary_dim, + seq_len, + offset, + num_heads, + head_size, + total_count, + max_out_tokens); } template void launch_apply_rotary_pos_emb(float*, @@ -232,7 +257,8 @@ template void launch_apply_rotary_pos_emb(float*, unsigned, bool, bool, - cudaStream_t); + cudaStream_t, + int); template void launch_apply_rotary_pos_emb<__half>(__half*, 
__half*, unsigned, @@ -243,7 +269,8 @@ template void launch_apply_rotary_pos_emb<__half>(__half*, unsigned, bool, bool, - cudaStream_t); + cudaStream_t, + int); /* __global__ void apply_rotary_pos_emb(float* mixed_query, diff --git a/csrc/transformer/inference/csrc/dequantize.cu b/csrc/transformer/inference/csrc/dequantize.cu index 3018845b..3843c2b6 100644 --- a/csrc/transformer/inference/csrc/dequantize.cu +++ b/csrc/transformer/inference/csrc/dequantize.cu @@ -152,10 +152,6 @@ __global__ void dequantize_kernel(__half* output, q_h[1] = __float2half(local_scale * (float)q_int8[1]); q_h[2] = __float2half(local_scale * (float)q_int8[2]); q_h[3] = __float2half(local_scale * (float)q_int8[3]); - // q_h[4] = __float2half(local_scale * (float)q_int8[4]); - // q_h[5] = __float2half(local_scale * (float)q_int8[5]); - // q_h[6] = __float2half(local_scale * (float)q_int8[6]); - // q_h[7] = __float2half(local_scale * (float)q_int8[7]); output_cast[tid] = q_f; tid += blockDim.x; } diff --git a/csrc/transformer/inference/csrc/gelu.cu b/csrc/transformer/inference/csrc/gelu.cu index 3f9bf4ca..8bc58769 100644 --- a/csrc/transformer/inference/csrc/gelu.cu +++ b/csrc/transformer/inference/csrc/gelu.cu @@ -188,7 +188,7 @@ __global__ void fused_bias_residual(float* input, data.z = data.z + out.z + bias_data.z; data.w = data.w + out.w + bias_data.w; } - output_cast[offset] = data; + input_cast[offset] = data; } } @@ -260,7 +260,7 @@ __global__ void fused_bias_residual(__half* input, vals_half[0] = __float22half2_rn(low_data); vals_half[1] = __float22half2_rn(high_data); - output_cast[offset] = vals_vec; + input_cast[offset] = vals_vec; } #endif } @@ -324,7 +324,7 @@ __global__ void gptj_residual_add(float* input, data.z = out.z + res_vec.z + (data.z + bias_data.z) * mp_scale; data.w = out.w + res_vec.w + (data.w + bias_data.w) * mp_scale; - output_cast[offset] = data; + input_cast[offset] = data; } } @@ -390,7 +390,7 @@ __global__ void gptj_residual_add(__half* input, vals_half[0] = __float22half2_rn(low_data); vals_half[1] = __float22half2_rn(high_data); - output_cast[offset] = vals_vec; + input_cast[offset] = vals_vec; } #endif } diff --git a/csrc/transformer/inference/csrc/pt_binding.cpp b/csrc/transformer/inference/csrc/pt_binding.cpp index 18f6767f..65549cdc 100644 --- a/csrc/transformer/inference/csrc/pt_binding.cpp +++ b/csrc/transformer/inference/csrc/pt_binding.cpp @@ -45,7 +45,7 @@ inline auto get_attn_mask_stride(at::Tensor& attn_mask) -> int // Bert style models have always a mask stride of 1. return 1; } else if (trnsfrmr_type == TransformerType::UNKNOWN) { - throw std::runtime_error("Unknown transformer type."); + return 0; } // this is just to make the compiler happy. 
@@ -102,14 +102,14 @@ at::Tensor ds_softmax(at::Tensor& attn_scores, template void allocate_workspace(size_t hidden_dim, - size_t max_seq_len, size_t batch_size, unsigned num_layers, - size_t head_size = 128) + unsigned mp_size = 1, + bool external_cache = false, + unsigned rank = 0) { - size_t _workSpaceSize = 16 * (hidden_dim * batch_size * max_seq_len) + - (num_layers * batch_size * max_seq_len * hidden_dim * 2); // KV-cache - Context::Instance().GenWorkSpace(_workSpaceSize * sizeof(T)); + Context::Instance().GenWorkSpace( + num_layers, batch_size, hidden_dim, mp_size, external_cache, sizeof(T), rank); } template @@ -124,10 +124,13 @@ at::Tensor einsum_sec_sm_ecm(at::Tensor& Q, at::Tensor& W) float alpha = 1; float gemm_beta = 0.0; - if (!workspace) { - allocate_workspace(W.size(1), MAX_OUT_TOKES, Q.size(0), 1); - workspace = (T*)Context::Instance().GetWorkSpace(); + /* + // Reallocate memory if we received a new prompt + if (!workspace || input.size(1) != 1) { + allocate_workspace(W.size(1), Context::Instance().GetMaxTokenLenght(), Q.size(0), 1, + head_size); workspace = (T*)Context::Instance().GetWorkSpace(); } + */ auto O = at::from_blob(workspace, {Q.size(1), Q.size(2), W.size(1)}, options); unsigned m = W.size(1); @@ -349,7 +352,15 @@ void attention_unfused(T* prev_key_cont, float layer_scale = alibi.sizes().size() > 1 ? std::max(1, layer_id) : 1.0; float alpha = norm_factor * norm_factor / layer_scale; float gemm_beta = 0.0; - T* workspace = (T*)output + bsz * seq_len * heads * k; + T* workspace; + if (seq_len == 1) { + workspace = (T*)output + bsz * seq_len * heads * k; + } else { + // If we are doing the prompt, switch to the tail workspace + T* scratch = (T*)Context::Instance().GetWorkSpace(); + workspace = scratch + ((Context::Instance().get_workspace_size() / sizeof(T)) - + bsz * heads * seq_len * soft_len); + } cublasSetStream(Context::Instance().GetCublasHandle(), Context::Instance().GetCurrentStream()); cublas_strided_batched_gemm(Context::Instance().GetCublasHandle(), @@ -363,7 +374,7 @@ void attention_unfused(T* prev_key_cont, workspace, CUBLAS_OP_T, CUBLAS_OP_N, - MAX_OUT_TOKES * k, + Context::Instance().GetMaxTokenLenght() * k, seq_len * k, seq_len * soft_len, bsz * heads, @@ -396,7 +407,7 @@ void attention_unfused(T* prev_key_cont, (T*)output, CUBLAS_OP_N, CUBLAS_OP_N, - MAX_OUT_TOKES * k, + Context::Instance().GetMaxTokenLenght() * k, seq_len * soft_len, seq_len * k, bsz * heads, @@ -444,12 +455,11 @@ std::vector ds_softmax_context(at::Tensor& query_key_value, auto output = torch::from_blob(workspace + 4 * buf_size, {bsz, seq_len, hidden_dim}, options); auto query_cont = workspace + 8 * buf_size; - size_t offset = - 16 * (hidden_dim * bsz * MAX_OUT_TOKES) + layer_id * 2 * bsz * MAX_OUT_TOKES * hidden_dim; - + size_t offset = 16 * (hidden_dim * bsz * Context::Instance().GetMaxTokenLenght()) + + layer_id * 2 * bsz * Context::Instance().GetMaxTokenLenght() * hidden_dim; unsigned all_tokens = soft_len; auto kv_cache = workspace + offset + (hidden_dim / heads) * (is_prompt ? 
0 : soft_len - 1); - size_t value_offset = bsz * MAX_OUT_TOKES * hidden_dim; + size_t value_offset = bsz * Context::Instance().GetMaxTokenLenght() * hidden_dim; T* temp_buf = (T*)output.data_ptr() + at::numel(output); launch_bias_add_transform_0213((T*)query_cont, @@ -467,7 +477,8 @@ std::vector ds_softmax_context(at::Tensor& query_key_value, rotate_half, rotate_every_two, Context::Instance().GetCurrentStream(), - 3); + 3, + Context::Instance().GetMaxTokenLenght()); if (rotary_dim > 0 && rotate_half) launch_apply_rotary_pos_emb(query_cont, kv_cache, @@ -479,7 +490,8 @@ std::vector ds_softmax_context(at::Tensor& query_key_value, bsz, rotate_half, rotate_every_two, - Context::Instance().GetCurrentStream()); + Context::Instance().GetCurrentStream(), + Context::Instance().GetMaxTokenLenght()); attention_unfused(workspace + offset, (T*)query_cont, @@ -614,16 +626,17 @@ void ds_layernorm_internal(T* workspace, } template -void quantized_gemm(at::Tensor& output, +void quantized_gemm(void* output, T* input, at::Tensor& weight, at::Tensor& qscale, int groups, int bsz) { - auto weight16 = at::empty({weight.size(0), weight.size(1)}, output.options()); + T* weight16 = (T*)Context::Instance().GetWorkSpace() + + 12 * Context::Instance().GetMaxTokenLenght() * weight.size(1); - launch_dequantize((T*)weight16.data_ptr(), + launch_dequantize(weight16, (int8_t*)weight.data_ptr(), (float*)qscale.data_ptr(), weight.size(0), @@ -641,9 +654,9 @@ void quantized_gemm(at::Tensor& output, weight.size(1), &alpha, &gemm_beta, - (T*)weight16.data_ptr(), + weight16, (T*)input, - (T*)output.data_ptr(), + (T*)output, #ifdef __HIP_PLATFORM_HCC__ rocblas_gemm_algo_standard); #else @@ -667,10 +680,9 @@ at::Tensor qkv_unfused_cublas(at::Tensor& output, T* workspace = (T*)Context::Instance().GetWorkSpace(); workspace += (3 * bsz * input.size(2)); ds_layernorm_internal(workspace, input, gamma, beta, epsilon); - // cudaEventRecord(Context::Instance().GetCompEvent(1), Context::Instance().GetCurrentStream()); if (q_int8) { - quantized_gemm(output, workspace, weight, q_scale, q_scale.size(0), bsz); + quantized_gemm(output.data_ptr(), workspace, weight, q_scale, q_scale.size(0), bsz); } else { float alpha = (T)1.0; float gemm_beta = (T)0.0; @@ -713,17 +725,20 @@ std::vector ds_qkv_gemm(at::Tensor& input, const float epsilon, bool add_bias, unsigned num_layers, + bool external_cache, + unsigned mp_size, + unsigned rank, bool q_int8) { int bsz = input.size(0) * input.size(1); T* workspace = (T*)Context::Instance().GetWorkSpace(); int out_size = q_int8 ? 
weight.size(0) : weight.size(1); - if (!workspace) { + if (!workspace) cublasSetStream(Context::Instance().GetCublasHandle(), Context::Instance().GetCurrentStream()); - allocate_workspace(input.size(2), MAX_OUT_TOKES, input.size(0), num_layers); - workspace = (T*)Context::Instance().GetWorkSpace(); - } + allocate_workspace(input.size(2), input.size(0), num_layers, mp_size, external_cache, rank); + workspace = (T*)Context::Instance().GetWorkSpace(); + auto options = at::TensorOptions() .dtype(input.options().dtype()) .layout(at::kStrided) @@ -831,10 +846,11 @@ at::Tensor ds_linear_layer(at::Tensor& input, int bsz = input.size(0) * input.size(1); T* workspace = (T*)Context::Instance().GetWorkSpace(); - if (!workspace) { + // Reallocate memory if we received a new prompt + if (!workspace || input.size(1) != 1) { cublasSetStream(Context::Instance().GetCublasHandle(), Context::Instance().GetCurrentStream()); - allocate_workspace(input.size(2), MAX_OUT_TOKES, input.size(0), num_layers); + allocate_workspace(input.size(2), input.size(0), num_layers); workspace = (T*)Context::Instance().GetWorkSpace(); } auto output = at::from_blob(workspace, {input.size(0), input.size(1), weight.size(1)}, options); @@ -902,18 +918,20 @@ at::Tensor ds_vector_matmul(at::Tensor& input, at::Tensor& q_scale, bool q_int8) { - auto input_cont = input.contiguous(); auto options = at::TensorOptions() - .dtype(input_cont.options().dtype()) + .dtype(input.options().dtype()) .layout(at::kStrided) .device(at::kCUDA) .requires_grad(false); int out_size = q_int8 ? weight.size(0) : weight.size(1); - int bsz = input_cont.size(0) * input_cont.size(1); - auto output = at::empty({input_cont.size(0), input_cont.size(1), out_size}, options); + int bsz = input.size(0) * input.size(1); + + T* workspace = (T*)Context::Instance().GetWorkSpace(); + auto output = at::from_blob(workspace, {input.size(0), input.size(1), out_size}, options); if (q_int8) { - quantized_gemm(output, (T*)input_cont.data_ptr(), weight, q_scale, q_scale.size(0), bsz); + quantized_gemm( + output.data_ptr(), (T*)input.data_ptr(), weight, q_scale, q_scale.size(0), bsz); } else { float alpha = (T)1.0; float gemm_beta = (T)0.0; @@ -924,11 +942,11 @@ at::Tensor ds_vector_matmul(at::Tensor& input, CUBLAS_OP_N, weight.size(1), bsz, - input_cont.size(2), + input.size(2), &alpha, &gemm_beta, (T*)weight.data_ptr(), - (T*)input_cont.data_ptr(), + (T*)input.data_ptr(), (T*)output.data_ptr(), #ifdef __HIP_PLATFORM_HCC__ rocblas_gemm_algo_standard); @@ -965,6 +983,7 @@ at::Tensor mlp_unfused_cublas(at::Tensor& output, at::Tensor& residual, at::Tensor& input_bias, at::Tensor& weight, + at::Tensor& weight1, at::Tensor& bias, at::Tensor& gamma, at::Tensor& beta, @@ -972,13 +991,15 @@ at::Tensor mlp_unfused_cublas(at::Tensor& output, bool preLayerNorm, bool mlp_after_attn, at::Tensor& q_scale, + at::Tensor& q_scale1, bool q_int8, ActivationFuncType act_func_type) { int bsz = input.size(0) * input.size(1); - auto inp_norm = at::empty_like(input); - - launch_residual_layer_norm((T*)inp_norm.data_ptr(), + T* inp_norm = + (T*)Context::Instance().GetWorkSpace() + torch::numel(input) + torch::numel(output); + T* intermediate = inp_norm + torch::numel(input); + launch_residual_layer_norm((T*)inp_norm, (T*)nullptr, (T*)input.data_ptr(), (T*)residual.data_ptr(), @@ -993,7 +1014,7 @@ at::Tensor mlp_unfused_cublas(at::Tensor& output, Context::Instance().GetCurrentStream()); if (q_int8) { - quantized_gemm(output, (T*)inp_norm.data_ptr(), weight, q_scale, q_scale.size(0), bsz); + 
quantized_gemm(intermediate, inp_norm, weight, q_scale, q_scale.size(0), bsz); } else { float alpha = (T)1.0; float gemm_beta = (T)0.0; @@ -1008,8 +1029,8 @@ at::Tensor mlp_unfused_cublas(at::Tensor& output, &alpha, &gemm_beta, (T*)weight.data_ptr(), - (T*)inp_norm.data_ptr(), - (T*)output.data_ptr(), + inp_norm, + intermediate, #ifdef __HIP_PLATFORM_HCC__ rocblas_gemm_algo_standard); #else @@ -1017,20 +1038,45 @@ at::Tensor mlp_unfused_cublas(at::Tensor& output, #endif } if (act_func_type == ActivationFuncType::GELU) { - launch_bias_gelu((T*)output.data_ptr(), + launch_bias_gelu(intermediate, (T*)bias.data_ptr(), q_int8 ? weight.size(0) : weight.size(1), bsz, Context::Instance().GetCurrentStream()); } else if (act_func_type == ActivationFuncType::ReLU) { - launch_bias_relu((T*)output.data_ptr(), + launch_bias_relu(intermediate, (T*)bias.data_ptr(), q_int8 ? weight.size(0) : weight.size(1), bsz, Context::Instance().GetCurrentStream()); } + if (q_int8) { + quantized_gemm( + output.data_ptr(), intermediate, weight1, q_scale1, q_scale1.size(0), bsz); + } else { + float alpha = (T)1.0; + float gemm_beta = (T)0.0; + cublasSetStream(Context::Instance().GetCublasHandle(), + Context::Instance().GetCurrentStream()); + cublas_gemm_ex(Context::Instance().GetCublasHandle(), + CUBLAS_OP_N, + CUBLAS_OP_N, + weight1.size(1), + bsz, + weight1.size(0), + &alpha, + &gemm_beta, + (T*)weight1.data_ptr(), + intermediate, + (T*)output.data_ptr(), +#ifdef __HIP_PLATFORM_HCC__ + rocblas_gemm_algo_standard); +#else + CUBLAS_GEMM_DEFAULT_TENSOR_OP); +#endif + } - return inp_norm; + return torch::from_blob(inp_norm, input.sizes(), input.options()); } template @@ -1038,6 +1084,7 @@ std::vector ds_mlp_gemm(at::Tensor& input, at::Tensor& residual, at::Tensor& input_bias, at::Tensor& weight, + at::Tensor& weight1, at::Tensor& bias, at::Tensor& gamma, at::Tensor& beta, @@ -1045,21 +1092,21 @@ std::vector ds_mlp_gemm(at::Tensor& input, bool preLayerNorm, bool mlp_after_attn, at::Tensor& q_scale, + at::Tensor& q_scale1, bool q_int8, int activation_type) { - auto input_cont = input.contiguous(); auto options = at::TensorOptions() - .dtype(input_cont.options().dtype()) + .dtype(input.options().dtype()) .layout(at::kStrided) .device(at::kCUDA) .requires_grad(false); int out_size = q_int8 ? 
weight.size(0) : weight.size(1); - auto output = at::from_blob((T*)Context::Instance().GetWorkSpace(), - {input_cont.size(0), input_cont.size(1), out_size}, + auto output = at::from_blob((T*)Context::Instance().GetWorkSpace() + torch::numel(input), + {input.size(0), input.size(1), out_size}, options); - int bsz = input_cont.size(0) * input_cont.size(1); + int bsz = input.size(0) * input.size(1); auto act_func_type = static_cast(activation_type); auto res_add = mlp_unfused_cublas(output, @@ -1067,6 +1114,7 @@ std::vector ds_mlp_gemm(at::Tensor& input, residual, input_bias, weight, + weight1, bias, gamma, beta, @@ -1074,6 +1122,7 @@ std::vector ds_mlp_gemm(at::Tensor& input, preLayerNorm, mlp_after_attn, q_scale, + q_scale1, q_int8, act_func_type); @@ -1184,7 +1233,7 @@ at::Tensor fused_gemm_gelu(at::Tensor& input, template at::Tensor& residual_add_bias(at::Tensor& hidden_state, - const at::Tensor& residual, + at::Tensor& residual, const at::Tensor& attention_output, const at::Tensor& attention_bias, const at::Tensor& final_bias, @@ -1217,7 +1266,7 @@ at::Tensor& residual_add_bias(at::Tensor& hidden_state, bsz, mp_size, Context::Instance().GetCurrentStream()); - return hidden_state; + return residual; } std::vector apply_rotary_pos_emb(at::Tensor& mixed_query, @@ -1246,7 +1295,8 @@ std::vector apply_rotary_pos_emb(at::Tensor& mixed_query, bsz, rotate_half, rotate_every_two, - Context::Instance().GetCurrentStream()); + Context::Instance().GetCurrentStream(), + Context::Instance().GetMaxTokenLenght()); else launch_apply_rotary_pos_emb<__half>((__half*)query_cont.data_ptr(), (__half*)key_cont.data_ptr(), @@ -1258,7 +1308,8 @@ std::vector apply_rotary_pos_emb(at::Tensor& mixed_query, bsz, rotate_half, rotate_every_two, - Context::Instance().GetCurrentStream()); + Context::Instance().GetCurrentStream(), + Context::Instance().GetMaxTokenLenght()); return {query_cont, key_cont}; } diff --git a/csrc/transformer/inference/csrc/transform.cu b/csrc/transformer/inference/csrc/transform.cu index 92305162..32d2df95 100644 --- a/csrc/transformer/inference/csrc/transform.cu +++ b/csrc/transformer/inference/csrc/transform.cu @@ -22,23 +22,21 @@ __global__ void bias_add_transform_0213(float* output, int rotary_dim, bool rotate_half, bool rotate_every_two, - int head_ext) + int head_ext, + int max_out_tokens) { int d0_stride = hidden_dim * seq_length; int d1_stride = hidden_dim; int d2_stride = hidden_dim / heads; - int d0_out_stride = d0_stride; - int d1_out_stride = d2_stride; - // int d2_out_stride = d2_stride * seq_length; - int d0 = blockIdx.x; // Batch int d1 = blockIdx.y; // Sequence ID (0-127) int cnt = blockIdx.z / head_ext; // Hidden count int d2 = threadIdx.y + (blockIdx.z % head_ext) * (heads / head_ext); // Head (0-11) int d3 = threadIdx.x; // Values (groups of 4) - int d2_out_stride = d2_stride * (cnt == 0 ? seq_length : MAX_OUT_TOKES); + int d2_out_stride = d2_stride * (cnt == 0 ? seq_length : max_out_tokens); + int d0_out_stride = hidden_dim * (cnt == 0 ? 
seq_length : max_out_tokens); const float4* vals_vec = reinterpret_cast(vals); float4* output_vec = @@ -50,7 +48,7 @@ __global__ void bias_add_transform_0213(float* output, vals_vec += (d2 * d2_stride); output_vec += (d1 * d2_stride); - output_vec += (d0 * d0_stride); + output_vec += (d0 * d0_out_stride); output_vec += (d2 * d2_out_stride); unsigned seq_id = d1 + seq_offset; @@ -89,7 +87,8 @@ __global__ void bias_add_transform_0213(__half* output, // q int rotary_dim, bool rotate_half, bool rotate_every_two, - int head_ext) + int head_ext, + int max_out_tokens) { #if __CUDA_ARCH__ >= 700 @@ -104,7 +103,9 @@ __global__ void bias_add_transform_0213(__half* output, // q int d2 = threadIdx.y + (blockIdx.z % head_ext) * (heads / head_ext); // Head (0-11) int d3 = threadIdx.x; // Values (groups of 4) - int d2_out_stride = d2_stride * (cnt == 0 ? seq_length : MAX_OUT_TOKES); + int d2_out_stride = d2_stride * (cnt == 0 ? seq_length : max_out_tokens); + int d0_out_stride = hidden_dim * (cnt == 0 ? seq_length : max_out_tokens); + float4 vals_arr; float4 output_arr; @@ -121,7 +122,7 @@ __global__ void bias_add_transform_0213(__half* output, // q vals_vec += (d2 * d2_stride); output_vec += (d1 * d2_stride); - output_vec += (d0 * d0_stride); + output_vec += (d0 * d0_out_stride); output_vec += (d2 * d2_out_stride); unsigned seq_id = d1 + seq_offset; @@ -166,7 +167,8 @@ void launch_bias_add_transform_0213(float* output, bool rotate_half, bool rotate_every_two, cudaStream_t stream, - int trans_count) + int trans_count, + int max_out_tokens) { hidden_dim >>= 2; int head_ext = (hidden_dim - 1) / MAX_THREADS + 1; @@ -186,7 +188,8 @@ void launch_bias_add_transform_0213(float* output, rotary_dim >> 2, rotate_half, rotate_every_two, - head_ext); + head_ext, + max_out_tokens); } template void launch_bias_add_transform_0213(T* outputs, @@ -204,7 +207,8 @@ void launch_bias_add_transform_0213(T* outputs, bool rotate_half, bool rotate_every_two, cudaStream_t stream, - int trans_count); + int trans_count, + int max_out_tokens); template <> void launch_bias_add_transform_0213<__half>(__half* output, __half* k_cache, @@ -221,12 +225,13 @@ void launch_bias_add_transform_0213<__half>(__half* output, bool rotate_half, bool rotate_every_two, cudaStream_t stream, - int trans_count) + int trans_count, + int max_out_tokens) { hidden_dim >>= 3; int head_ext = 1; // (hidden_dim - 1) / MAX_THREADS + 1; dim3 block_dim(hidden_dim / heads, (heads / head_ext)); - dim3 grid_dim(1, seq_length, (trans_count * head_ext)); + dim3 grid_dim(batch_size, seq_length, (trans_count * head_ext)); bias_add_transform_0213<<>>(output, k_cache, v_cache, @@ -240,7 +245,8 @@ void launch_bias_add_transform_0213<__half>(__half* output, rotary_dim >> 3, rotate_half, rotate_every_two, - head_ext); + head_ext, + max_out_tokens); } // Bias add diff --git a/csrc/transformer/inference/includes/inference_context.h b/csrc/transformer/inference/includes/inference_context.h index a6f6613f..fe0616a7 100644 --- a/csrc/transformer/inference/includes/inference_context.h +++ b/csrc/transformer/inference/includes/inference_context.h @@ -12,6 +12,10 @@ Copyright 2022 The Microsoft DeepSpeed Team #include "cublas_v2.h" #include "cuda.h" +#define MEGABYTE (1024 * 1024) +#define GIGABYTE (1024 * 1024 * 1024) + +#define MAX_OUT_TOKENS 8192 #define WARP_SIZE 32 #define CUDA_CHECK(callstr) \ @@ -43,7 +47,13 @@ inline int DS_GET_BLOCKS(const int N) class Context { public: - Context() : _workspace(nullptr), _seed(42), _curr_offset(0), _stream(0) + Context() + : 
_workspace(nullptr), + _seed(42), + _curr_offset(0), + _stream(0), + _free_memory_size(0), + _num_tokens(1) { if (cublasCreate(&_cublasHandle) != CUBLAS_STATUS_SUCCESS) { auto message = std::string("Fail to create cublas handle."); @@ -75,19 +85,51 @@ public: return _ctx; } - void GenWorkSpace(size_t size) + void GenWorkSpace(const unsigned& num_layers, + const size_t& batch_size, + const size_t& hidden_dim, + const unsigned& mp_size, + const bool& external_cache, + const size_t& elem_size, + const unsigned& rank) { + size_t total_size; + if (!_free_memory_size) { cudaMemGetInfo(&_free_memory_size, &total_size); } + + size_t activation_size = 16 * hidden_dim * batch_size; + size_t cache_size = num_layers * batch_size * (hidden_dim / mp_size) * 2; + _max_seq_len = + (((_free_memory_size - (_free_memory_size > GIGABYTE ? 500 : 100) * MEGABYTE) / + elem_size)) / + (activation_size + cache_size); + size_t workSpaceSize = (external_cache ? activation_size : (activation_size + cache_size)) * + _max_seq_len * elem_size; + _max_seq_len = std::min((size_t)MAX_OUT_TOKENS, _max_seq_len); + if (rank == 0 && !_workspace) + printf( + "Free memory : %lu (Bytes) Total memory: %lu (Bytes) Setting maximum total " + "tokens (input + output) to %lu \n", + _free_memory_size, + total_size, + _max_seq_len); if (!_workspace) { assert(_workspace == nullptr); - cudaMalloc(&_workspace, size); - } else if (_workSpaceSize < size) { + cudaMalloc(&_workspace, workSpaceSize); + } else if (_workSpaceSize < workSpaceSize) { cudaFree(_workspace); - cudaMalloc(&_workspace, size); + cudaMalloc(&_workspace, workSpaceSize); } - if (!_workspace) { throw std::runtime_error("Workspace is null."); } - _workSpaceSize = size; + if (!_workspace) { + printf("Requested:\t%lu\nFree:\t%lu\nTotal:\t%lu\n", + workSpaceSize, + _free_memory_size, + total_size); + throw std::runtime_error("Workspace is null."); + } + _workSpaceSize = workSpaceSize; } + inline size_t GetMaxTokenLenght() const { return _max_seq_len; } cudaEvent_t GetCompEvent(int id) { return id == 1 ? 
_comp1_event : _comp2_event; } @@ -100,7 +142,7 @@ public: return _token_length; } - inline void reset_tokens(unsigned initial_tokens = 0) + inline void reset_tokens(unsigned initial_tokens = 1) { _num_tokens = initial_tokens; } //_token_length = 0; } @@ -160,7 +202,11 @@ private: void* _workspace; uint64_t _seed; uint64_t _curr_offset; + size_t _workSpaceSize; + size_t _free_memory_size; + + size_t _max_seq_len; cudaEvent_t _comp1_event; cudaEvent_t _comp2_event; diff --git a/csrc/transformer/inference/includes/inference_cuda_layers.h b/csrc/transformer/inference/includes/inference_cuda_layers.h index d1e10c51..6302ceb2 100644 --- a/csrc/transformer/inference/includes/inference_cuda_layers.h +++ b/csrc/transformer/inference/includes/inference_cuda_layers.h @@ -21,7 +21,6 @@ Copyright 2022 The Microsoft DeepSpeed Team #include #include -#define MAX_OUT_TOKES 128 #define MAX_WARP_NUM 32 #define WARP_SIZE 32 @@ -142,7 +141,8 @@ void launch_apply_rotary_pos_emb(T* mixed_query, unsigned batch, bool rotate_half, bool rotate_every_two, - cudaStream_t stream); + cudaStream_t stream, + int max_out_tokens); template void launch_moe_res_matmul(T* residual, @@ -178,4 +178,5 @@ void launch_bias_add_transform_0213(T* outputs, bool rotate_half, bool rotate_every_two, cudaStream_t stream, - int trans_count); + int trans_count, + int max_out_tokens); diff --git a/deepspeed/ops/transformer/inference/transformer_inference.py b/deepspeed/ops/transformer/inference/transformer_inference.py old mode 100755 new mode 100644 index 4b73f071..ff50d43d --- a/deepspeed/ops/transformer/inference/transformer_inference.py +++ b/deepspeed/ops/transformer/inference/transformer_inference.py @@ -380,6 +380,8 @@ class DeepSpeedSelfAttentionFunction(Function): ) else 0 sliced_alibi = alibi[offset:batch_heads + offset, :, :] + +# attn_key_value = score_context_func( qkv_out, ((1 - input_mask).to(qkv_out.dype) * @@ -402,6 +404,7 @@ class DeepSpeedSelfAttentionFunction(Function): def selfAttention_fp(): vector_matmul_func = inference_cuda_module.vector_matmul_fp16 if config.fp16 else \ inference_cuda_module.vector_matmul_fp32 + if not config.pre_layer_norm: linear_func = inference_cuda_module.linear_layer_fp16 if config.fp16 else \ inference_cuda_module.linear_layer_fp32 @@ -413,18 +416,19 @@ class DeepSpeedSelfAttentionFunction(Function): else: qkv_func = inference_cuda_module.qkv_gemm_fp16 if config.fp16 else \ inference_cuda_module.qkv_gemm_fp32 - qkv_out = qkv_func( - input, - attn_qkvw, - attn_qkvw.scale, - (attn_qkvb if attn_qkvb is not None else norm_b), - norm_w, - norm_b, - config.epsilon, - (attn_qkvb is not None), - 1 if config.bigscience_bloom else - DeepSpeedTransformerInference.layer_id, - config.q_int8) + qkv_out = qkv_func(input, + attn_qkvw, + attn_qkvw.scale, + (attn_qkvb if attn_qkvb is not None else norm_b), + norm_w, + norm_b, + config.epsilon, + (attn_qkvb is not None), + DeepSpeedTransformerInference.layer_id, + config.bigscience_bloom, + config.mp_size, + dist.get_rank() if dist.is_initialized() else 0, + config.q_int8) context_layer, key_layer, value_layer = compute_attention(qkv_out[0] if isinstance(qkv_out, list) else qkv_out, input_mask) output = vector_matmul_func(context_layer, attn_ow, @@ -495,20 +499,18 @@ class DeepSpeedSelfAttention(nn.Module): self.config.layer_id = DeepSpeedSelfAttention.num_layers DeepSpeedSelfAttention.num_layers = DeepSpeedSelfAttention.num_layers + 1 device = torch.cuda.current_device() if config.bigscience_bloom else 'cpu' - self.attn_qkvw = nn.Parameter(torch.empty( - 
self.config.hidden_size, - (self.config.hidden_size // self.config.mp_size) * 3, - dtype=data_type, - device=device), + qkv_size_per_partition = (self.config.hidden_size // self.config.mp_size) * 3 + self.attn_qkvw = nn.Parameter(torch.empty(self.config.hidden_size, + qkv_size_per_partition, + dtype=data_type, + device=device), requires_grad=False) - self.attn_qkvb = nn.Parameter(torch.empty( - (self.config.hidden_size // self.config.mp_size) * 3, - dtype=data_type_fp, - device=device), + self.attn_qkvb = nn.Parameter(torch.empty(qkv_size_per_partition, + dtype=data_type_fp, + device=device), requires_grad=False) - - self.attn_ow = nn.Parameter(torch.empty(self.config.hidden_size // - self.config.mp_size, + out_size_per_partition = self.config.hidden_size // self.config.mp_size + self.attn_ow = nn.Parameter(torch.empty(out_size_per_partition, self.config.hidden_size, dtype=data_type, device=device), @@ -613,10 +615,11 @@ class DeepSpeedMLPFunction(Function): config.pre_layer_norm, False) else: - intermediate, residual_add = mlp_gemm_func(input, + output, residual_add = mlp_gemm_func(input, residual, bias, inter_w, + output_w, inter_b, attn_nw, attn_nb, @@ -624,17 +627,14 @@ class DeepSpeedMLPFunction(Function): config.pre_layer_norm, config.mlp_after_attn, inter_w.scale, + output_w.scale, config.q_int8, config.mlp_act_func_type) - output = vector_matmul_func(intermediate, - output_w, - False, - output_w.scale, - config.q_int8) - output = residual_add_func( - output, # hidden state - residual if config.pre_layer_norm else residual_add, # residual - input, # attention output + residual = residual if config.pre_layer_norm else residual_add + residual_add_func( + output, # hidden state + residual, # residual + input, # attention output bias if bias is not None else output_b, output_b, config.mp_size, # model parallel size @@ -642,8 +642,8 @@ class DeepSpeedMLPFunction(Function): bias is not None, # whether bias addition is fused config.pre_layer_norm) # whether the layer norm is applied before attention if mp_group is not None and dist.get_world_size(group=mp_group) > 1: - dist.all_reduce(output, group=mp_group) - return output + dist.all_reduce(residual, group=mp_group) + return residual @staticmethod def backward(ctx, grad_output): @@ -673,22 +673,20 @@ class DeepSpeedMLP(nn.Module): dtype=data_type_fp, device=device), requires_grad=False) + intm_size_per_partition = self.config.intermediate_size // self.config.mp_size self.inter_w = nn.Parameter(torch.empty(self.config.hidden_size, - self.config.intermediate_size // - self.config.mp_size, + intm_size_per_partition, dtype=data_type, device=device), requires_grad=False) - self.inter_b = nn.Parameter(torch.empty(self.config.intermediate_size // - self.config.mp_size, + self.inter_b = nn.Parameter(torch.empty(intm_size_per_partition, dtype=data_type_fp, device=device), requires_grad=False) - self.output_w = nn.Parameter(torch.empty( - (self.config.intermediate_size // self.config.mp_size), - self.config.hidden_size, - dtype=data_type, - device=device), + self.output_w = nn.Parameter(torch.empty(intm_size_per_partition, + self.config.hidden_size, + dtype=data_type, + device=device), requires_grad=False) self.output_b = nn.Parameter(torch.empty(self.config.hidden_size, dtype=data_type_fp, -- GitLab
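
Reviewer note (not part of the patch): the heart of this change is that the scratch buffer and KV-cache budget are now derived from free device memory at first allocation instead of the old compile-time MAX_OUT_TOKES = 128 constant. Below is a minimal, standalone C++ sketch of that sizing arithmetic; the helper name estimate_max_tokens and the main() driver are assumptions for illustration, while the constants and formula mirror GenWorkSpace in inference_context.h above.

// Sketch only: mirrors the GenWorkSpace sizing math; names are hypothetical.
#include <algorithm>
#include <cstddef>
#include <cstdio>

constexpr size_t kMegabyte = 1024 * 1024;
constexpr size_t kGigabyte = 1024 * kMegabyte;
constexpr size_t kMaxOutTokens = 8192;  // MAX_OUT_TOKENS ceiling from the patch

size_t estimate_max_tokens(size_t free_bytes,
                           size_t hidden_dim,
                           size_t batch_size,
                           size_t num_layers,
                           size_t mp_size,
                           size_t elem_size)
{
    // Per-token activation scratch: 16 buffers of hidden_dim elements per sequence in the batch.
    const size_t activation_elems = 16 * hidden_dim * batch_size;
    // Per-token KV cache: K and V per layer, sharded across mp_size ranks.
    const size_t cache_elems = num_layers * batch_size * (hidden_dim / mp_size) * 2;
    // Safety margin so weights and cuBLAS workspaces still fit alongside the scratch buffer.
    const size_t reserve = (free_bytes > kGigabyte ? 500 : 100) * kMegabyte;
    const size_t max_tokens =
        ((free_bytes - reserve) / elem_size) / (activation_elems + cache_elems);
    return std::min(max_tokens, kMaxOutTokens);
}

int main()
{
    // Example: ~14 GiB free, 4096 hidden size, 32 layers, batch 1, fp16 elements.
    const size_t tokens = estimate_max_tokens(14ull * kGigabyte, 4096, 1, 32, 1, 2);
    std::printf("max total tokens (input + output): %zu\n", tokens);
    return 0;
}

With roughly 14 GiB free and fp16 activations this caps out at the 8192-token ceiling, while on smaller GPUs the token budget shrinks automatically instead of overflowing the previous fixed 128-token scratch area.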
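
The CUDA kernels touched above (apply_rotary_pos_emb*, bias_add_transform_0213) now stride the key/value cache by the runtime max_out_tokens argument rather than the fixed MAX_OUT_TOKES, so a key written during the prompt stays addressable at the same slot throughout decoding. A small sketch of the offset computation follows; the function and variable names are assumptions, but the formula matches the k_offset expression in the kernels.

// Sketch only: KV-cache addressing used by the kernels; names are hypothetical.
#include <cstddef>
#include <cstdio>

// head_block enumerates (batch * heads) slots, as head_id / seq_len does in the kernels;
// seq_index is the token position within the current chunk.
inline size_t kv_cache_offset(size_t seq_index,
                              size_t head_block,
                              size_t head_size,
                              size_t max_out_tokens)
{
    return (seq_index + head_block * max_out_tokens) * head_size;
}

int main()
{
    const size_t head_size = 64, max_out_tokens = 8192;
    // Token 130 of head-block 3 is still addressable, which a 128-token stride could not express.
    std::printf("offset = %zu\n", kv_cache_offset(130, 3, head_size, max_out_tokens));
    return 0;
}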
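
For orientation when reading ds_softmax_context in pt_binding.cpp: the single workspace allocation is laid out as 16 activation buffers of hidden_dim * bsz * max_tokens elements, followed by one K block and one V block of bsz * max_tokens * hidden_dim elements per layer. A sketch of that offset arithmetic, with assumed helper names:

// Sketch only: workspace layout used to locate a layer's K/V cache; names are hypothetical.
#include <cstddef>
#include <cstdio>

struct CacheOffsets {
    size_t key;    // element offset of this layer's key cache
    size_t value;  // element offset of this layer's value cache
};

inline CacheOffsets layer_cache_offsets(size_t layer_id,
                                        size_t bsz,
                                        size_t hidden_dim,
                                        size_t max_tokens)
{
    const size_t activations = 16 * hidden_dim * bsz * max_tokens;  // activation scratch region
    const size_t per_layer = 2 * bsz * max_tokens * hidden_dim;     // K and V for one layer
    const size_t key = activations + layer_id * per_layer;
    const size_t value = key + bsz * max_tokens * hidden_dim;       // V block follows K block
    return {key, value};
}

int main()
{
    const CacheOffsets off = layer_cache_offsets(5, 1, 4096, 8192);
    std::printf("layer 5: key cache at %zu, value cache at %zu elements\n", off.key, off.value);
    return 0;
}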