Unverified commit 3d097bb8, authored by Connor Holmes, committed by GitHub

Extend scratch buffer for long prompts (#2212)

Co-authored-by: Reza Yazdani <44502768+RezaYazdaniAminabadi@users.noreply.github.com>
Co-authored-by: Reza Yazdani <reyazda@microsoft.com>
Co-authored-by: Jeff Rasley <jerasley@microsoft.com>
Parent b76e0f4f
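
For orientation, a minimal sketch (not taken from this commit; the helper name is hypothetical) of the key-cache addressing the kernels below switch to: each flattened (batch, head) index owns a contiguous slab of max_out_tokens positions, where max_out_tokens was previously the compile-time constant MAX_OUT_TOKES (128) and is now passed in at runtime from the context, which sizes the scratch buffer against free GPU memory.

// Hypothetical helper, illustration only: mirrors the k_offset expression used
// by the rotary-embedding kernels in this diff. flat_head corresponds to
// head_id / seq_len in the kernels (a combined batch/head index).
__host__ __device__ inline unsigned key_cache_offset(unsigned flat_head,
                                                     unsigned seq_index,
                                                     unsigned head_size,
                                                     int max_out_tokens)
{
    // Each (batch, head) slab holds max_out_tokens token slots of head_size elements.
    return (seq_index + flat_head * max_out_tokens) * head_size;
}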
......@@ -18,7 +18,8 @@ __global__ void apply_rotary_pos_emb(float* mixed_query,
unsigned seq_offset,
unsigned num_heads,
unsigned head_size,
unsigned total_count)
unsigned total_count,
int max_out_tokens)
{
cg::thread_block b = cg::this_thread_block();
cg::thread_block_tile<WARP_SIZE> g = cg::tiled_partition<WARP_SIZE>(b);
......@@ -31,13 +32,15 @@ __global__ void apply_rotary_pos_emb(float* mixed_query,
unsigned offset = head_id * head_size;
unsigned seq_id = (head_id / num_heads) % seq_len + seq_offset;
unsigned seq_index = head_id % seq_len;
unsigned k_offset = (seq_index + (head_id / seq_len) * max_out_tokens) * head_size;
if (head_id < total_count) {
while (lane < rotary_dim) {
float inv_freq = (float)((lane / 2) * 2) / (float)rotary_dim;
inv_freq = 1.0 / powf(10000.0, inv_freq) * (float)seq_id;
float q = mixed_query[offset + lane];
float k = key_layer[offset + lane];
float k = key_layer[k_offset + lane];
float rotary_sign = (lane % 2 == 1 ? -1.0 : 1.0);
float q_rot = (q * rotary_sign);
float k_rot = (k * rotary_sign);
......@@ -47,7 +50,7 @@ __global__ void apply_rotary_pos_emb(float* mixed_query,
k = k * cosf(inv_freq) + k_rot * sinf(inv_freq);
mixed_query[offset + lane] = q;
key_layer[offset + lane] = k;
key_layer[k_offset + lane] = k;
lane += WARP_SIZE;
}
......@@ -61,7 +64,8 @@ __global__ void apply_rotary_pos_emb(__half* mixed_query,
unsigned seq_offset,
unsigned num_heads,
unsigned head_size,
unsigned total_count)
unsigned total_count,
int max_out_tokens)
{
#if __CUDA_ARCH__ >= 700
cg::thread_block b = cg::this_thread_block();
......@@ -75,13 +79,15 @@ __global__ void apply_rotary_pos_emb(__half* mixed_query,
unsigned offset = head_id * head_size;
unsigned seq_id = (head_id / num_heads) % seq_len + seq_offset;
unsigned seq_index = head_id % seq_len;
unsigned k_offset = (seq_index + (head_id / seq_len) * max_out_tokens) * head_size;
if (head_id < total_count) {
while (lane < rotary_dim) {
float inv_freq = (float)((lane / 2) * 2) / (float)rotary_dim;
inv_freq = 1.0 / powf(10000.0, inv_freq) * (float)seq_id;
float q = (float)mixed_query[offset + lane];
float k = (float)key_layer[offset + lane];
float k = (float)key_layer[k_offset + lane];
float rotary_sign = (lane % 2 == 1 ? -1.0 : 1.0);
float q_rot = (q * rotary_sign);
float k_rot = (k * rotary_sign);
......@@ -91,7 +97,7 @@ __global__ void apply_rotary_pos_emb(__half* mixed_query,
k = k * cosf(inv_freq) + k_rot * sinf(inv_freq);
mixed_query[offset + lane] = (__half)q;
key_layer[offset + lane] = (__half)k;
key_layer[k_offset + lane] = (__half)k;
lane += WARP_SIZE;
}
......@@ -105,7 +111,8 @@ __global__ void apply_rotary_pos_emb1(float* mixed_query,
unsigned seq_offset,
unsigned num_heads,
unsigned head_size,
unsigned total_count)
unsigned total_count,
int max_out_tokens)
{
cg::thread_block b = cg::this_thread_block();
cg::thread_block_tile<WARP_SIZE> g = cg::tiled_partition<WARP_SIZE>(b);
......@@ -118,13 +125,15 @@ __global__ void apply_rotary_pos_emb1(float* mixed_query,
unsigned offset = head_id * head_size;
unsigned seq_id = (head_id / num_heads) % seq_len + seq_offset;
unsigned seq_index = head_id % seq_len;
unsigned k_offset = (seq_index + (head_id / seq_len) * max_out_tokens) * head_size;
if (head_id < total_count) {
while (lane < rotary_dim) {
float inv_freq = (float)((lane / 2) * 2) / (float)rotary_dim;
inv_freq = 1.0 / powf(10000.0, inv_freq) * (float)seq_id;
float q = mixed_query[offset + lane];
float k = key_layer[offset + lane];
float k = key_layer[k_offset + lane];
float rotary_sign = (lane % 2 == 1 ? -1.0 : 1.0);
float q_rot = (q * rotary_sign);
float k_rot = (k * rotary_sign);
......@@ -134,7 +143,7 @@ __global__ void apply_rotary_pos_emb1(float* mixed_query,
k = k * cosf(inv_freq) + k_rot * sinf(inv_freq);
mixed_query[offset + lane] = q;
key_layer[offset + lane] = k;
key_layer[k_offset + lane] = k;
lane += WARP_SIZE;
}
......@@ -147,7 +156,8 @@ __global__ void apply_rotary_pos_emb1(__half* mixed_query,
unsigned seq_offset,
unsigned num_heads,
unsigned head_size,
unsigned total_count)
unsigned total_count,
int max_out_tokens)
{
#if __CUDA_ARCH__ >= 700
cg::thread_block b = cg::this_thread_block();
......@@ -160,7 +170,7 @@ __global__ void apply_rotary_pos_emb1(__half* mixed_query,
unsigned head_id = blockIdx.x * MAX_WARP_NUM + gid;
unsigned seq_index = head_id % seq_len;
unsigned offset = head_id * head_size;
unsigned k_offset = (seq_index + (head_id / seq_len) * MAX_OUT_TOKES) * head_size;
unsigned k_offset = (seq_index + (head_id / seq_len) * max_out_tokens) * head_size;
constexpr unsigned mask[32] = {
0x1 | 0x1000, 0x2 | 0x2000, 0x4 | 0x4000, 0x8 | 0x8000, 0x10 | 0x10000,
......@@ -209,17 +219,32 @@ void launch_apply_rotary_pos_emb(T* mixed_query,
unsigned batch,
bool rotate_half,
bool rotate_every_two,
cudaStream_t stream)
cudaStream_t stream,
int max_out_tokens)
{
int total_count = batch * num_heads * seq_len;
dim3 block_dims(1024);
dim3 grid_dims((total_count - 1) / MAX_WARP_NUM + 1); // (batch_size);
if (rotate_every_two)
apply_rotary_pos_emb<<<grid_dims, block_dims, 0, stream>>>(
mixed_query, key_layer, rotary_dim, seq_len, offset, num_heads, head_size, total_count);
apply_rotary_pos_emb<<<grid_dims, block_dims, 0, stream>>>(mixed_query,
key_layer,
rotary_dim,
seq_len,
offset,
num_heads,
head_size,
total_count,
max_out_tokens);
else if (rotate_half)
apply_rotary_pos_emb1<<<grid_dims, block_dims, 0, stream>>>(
mixed_query, key_layer, rotary_dim, seq_len, offset, num_heads, head_size, total_count);
apply_rotary_pos_emb1<<<grid_dims, block_dims, 0, stream>>>(mixed_query,
key_layer,
rotary_dim,
seq_len,
offset,
num_heads,
head_size,
total_count,
max_out_tokens);
}
template void launch_apply_rotary_pos_emb<float>(float*,
......@@ -232,7 +257,8 @@ template void launch_apply_rotary_pos_emb<float>(float*,
unsigned,
bool,
bool,
cudaStream_t);
cudaStream_t,
int);
template void launch_apply_rotary_pos_emb<__half>(__half*,
__half*,
unsigned,
......@@ -243,7 +269,8 @@ template void launch_apply_rotary_pos_emb<__half>(__half*,
unsigned,
bool,
bool,
cudaStream_t);
cudaStream_t,
int);
/*
__global__ void apply_rotary_pos_emb(float* mixed_query,
......
......@@ -152,10 +152,6 @@ __global__ void dequantize_kernel(__half* output,
q_h[1] = __float2half(local_scale * (float)q_int8[1]);
q_h[2] = __float2half(local_scale * (float)q_int8[2]);
q_h[3] = __float2half(local_scale * (float)q_int8[3]);
// q_h[4] = __float2half(local_scale * (float)q_int8[4]);
// q_h[5] = __float2half(local_scale * (float)q_int8[5]);
// q_h[6] = __float2half(local_scale * (float)q_int8[6]);
// q_h[7] = __float2half(local_scale * (float)q_int8[7]);
output_cast[tid] = q_f;
tid += blockDim.x;
}
......
......@@ -188,7 +188,7 @@ __global__ void fused_bias_residual(float* input,
data.z = data.z + out.z + bias_data.z;
data.w = data.w + out.w + bias_data.w;
}
output_cast[offset] = data;
input_cast[offset] = data;
}
}
......@@ -260,7 +260,7 @@ __global__ void fused_bias_residual(__half* input,
vals_half[0] = __float22half2_rn(low_data);
vals_half[1] = __float22half2_rn(high_data);
output_cast[offset] = vals_vec;
input_cast[offset] = vals_vec;
}
#endif
}
......@@ -324,7 +324,7 @@ __global__ void gptj_residual_add(float* input,
data.z = out.z + res_vec.z + (data.z + bias_data.z) * mp_scale;
data.w = out.w + res_vec.w + (data.w + bias_data.w) * mp_scale;
output_cast[offset] = data;
input_cast[offset] = data;
}
}
......@@ -390,7 +390,7 @@ __global__ void gptj_residual_add(__half* input,
vals_half[0] = __float22half2_rn(low_data);
vals_half[1] = __float22half2_rn(high_data);
output_cast[offset] = vals_vec;
input_cast[offset] = vals_vec;
}
#endif
}
......
......@@ -45,7 +45,7 @@ inline auto get_attn_mask_stride(at::Tensor& attn_mask) -> int
// Bert style models have always a mask stride of 1.
return 1;
} else if (trnsfrmr_type == TransformerType::UNKNOWN) {
throw std::runtime_error("Unknown transformer type.");
return 0;
}
// this is just to make the compiler happy.
......@@ -102,14 +102,14 @@ at::Tensor ds_softmax(at::Tensor& attn_scores,
template <typename T>
void allocate_workspace(size_t hidden_dim,
size_t max_seq_len,
size_t batch_size,
unsigned num_layers,
size_t head_size = 128)
unsigned mp_size = 1,
bool external_cache = false,
unsigned rank = 0)
{
size_t _workSpaceSize = 16 * (hidden_dim * batch_size * max_seq_len) +
(num_layers * batch_size * max_seq_len * hidden_dim * 2); // KV-cache
Context::Instance().GenWorkSpace(_workSpaceSize * sizeof(T));
Context::Instance().GenWorkSpace(
num_layers, batch_size, hidden_dim, mp_size, external_cache, sizeof(T), rank);
}
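// Illustration only (not one of this diff's hunks): the new call shape, mirroring
// the ds_qkv_gemm call further down in this file; the argument values are placeholders.
//
//   allocate_workspace<T>(/*hidden_dim=*/     input.size(2),
//                         /*batch_size=*/     input.size(0),
//                         /*num_layers=*/     num_layers,
//                         /*mp_size=*/        mp_size,
//                         /*external_cache=*/ external_cache,
//                         /*rank=*/           rank);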
template <typename T>
......@@ -124,10 +124,13 @@ at::Tensor einsum_sec_sm_ecm(at::Tensor& Q, at::Tensor& W)
float alpha = 1;
float gemm_beta = 0.0;
if (!workspace) {
allocate_workspace<T>(W.size(1), MAX_OUT_TOKES, Q.size(0), 1);
workspace = (T*)Context::Instance().GetWorkSpace();
/*
// Reallocate memory if we received a new prompt
if (!workspace || input.size(1) != 1) {
allocate_workspace<T>(W.size(1), Context::Instance().GetMaxTokenLenght(), Q.size(0), 1,
head_size); workspace = (T*)Context::Instance().GetWorkSpace();
}
*/
auto O = at::from_blob(workspace, {Q.size(1), Q.size(2), W.size(1)}, options);
unsigned m = W.size(1);
......@@ -349,7 +352,15 @@ void attention_unfused(T* prev_key_cont,
float layer_scale = alibi.sizes().size() > 1 ? std::max(1, layer_id) : 1.0;
float alpha = norm_factor * norm_factor / layer_scale;
float gemm_beta = 0.0;
T* workspace = (T*)output + bsz * seq_len * heads * k;
T* workspace;
if (seq_len == 1) {
workspace = (T*)output + bsz * seq_len * heads * k;
} else {
// If we are doing the prompt, switch to the tail workspace
T* scratch = (T*)Context::Instance().GetWorkSpace();
workspace = scratch + ((Context::Instance().get_workspace_size() / sizeof(T)) -
bsz * heads * seq_len * soft_len);
}
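// Illustrative helper (hypothetical name), equivalent to the prompt branch above:
// the raw attention scores are carved from the tail of the global scratch buffer
// so they cannot collide with the per-layer activation and KV-cache regions.
//
//   template <typename T>
//   T* prompt_score_buffer(size_t bsz, size_t heads, size_t seq_len, size_t soft_len)
//   {
//       T* scratch = (T*)Context::Instance().GetWorkSpace();
//       size_t total_elems = Context::Instance().get_workspace_size() / sizeof(T);
//       size_t score_elems = bsz * heads * seq_len * soft_len;  // one score per (query, key) pair
//       return scratch + (total_elems - score_elems);
//   }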
cublasSetStream(Context::Instance().GetCublasHandle(), Context::Instance().GetCurrentStream());
cublas_strided_batched_gemm(Context::Instance().GetCublasHandle(),
......@@ -363,7 +374,7 @@ void attention_unfused(T* prev_key_cont,
workspace,
CUBLAS_OP_T,
CUBLAS_OP_N,
MAX_OUT_TOKES * k,
Context::Instance().GetMaxTokenLenght() * k,
seq_len * k,
seq_len * soft_len,
bsz * heads,
......@@ -396,7 +407,7 @@ void attention_unfused(T* prev_key_cont,
(T*)output,
CUBLAS_OP_N,
CUBLAS_OP_N,
MAX_OUT_TOKES * k,
Context::Instance().GetMaxTokenLenght() * k,
seq_len * soft_len,
seq_len * k,
bsz * heads,
......@@ -444,12 +455,11 @@ std::vector<at::Tensor> ds_softmax_context(at::Tensor& query_key_value,
auto output = torch::from_blob(workspace + 4 * buf_size, {bsz, seq_len, hidden_dim}, options);
auto query_cont = workspace + 8 * buf_size;
size_t offset =
16 * (hidden_dim * bsz * MAX_OUT_TOKES) + layer_id * 2 * bsz * MAX_OUT_TOKES * hidden_dim;
size_t offset = 16 * (hidden_dim * bsz * Context::Instance().GetMaxTokenLenght()) +
layer_id * 2 * bsz * Context::Instance().GetMaxTokenLenght() * hidden_dim;
unsigned all_tokens = soft_len;
auto kv_cache = workspace + offset + (hidden_dim / heads) * (is_prompt ? 0 : soft_len - 1);
size_t value_offset = bsz * MAX_OUT_TOKES * hidden_dim;
size_t value_offset = bsz * Context::Instance().GetMaxTokenLenght() * hidden_dim;
T* temp_buf = (T*)output.data_ptr() + at::numel(output);
launch_bias_add_transform_0213<T>((T*)query_cont,
......@@ -467,7 +477,8 @@ std::vector<at::Tensor> ds_softmax_context(at::Tensor& query_key_value,
rotate_half,
rotate_every_two,
Context::Instance().GetCurrentStream(),
3);
3,
Context::Instance().GetMaxTokenLenght());
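// Scratch-buffer map assumed by the offsets above (reconstructed from this diff,
// illustrative only), with max_tok = Context::Instance().GetMaxTokenLenght():
//
//   size_t activations = 16 * hidden_dim * bsz * max_tok;         // shared activation buffers
//   size_t per_layer   = 2 * bsz * max_tok * hidden_dim;          // K block followed by V block
//   T* k_cache_l       = workspace + activations + layer_id * per_layer;
//   T* v_cache_l       = k_cache_l + bsz * max_tok * hidden_dim;  // i.e. k_cache_l + value_offset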
if (rotary_dim > 0 && rotate_half)
launch_apply_rotary_pos_emb(query_cont,
kv_cache,
......@@ -479,7 +490,8 @@ std::vector<at::Tensor> ds_softmax_context(at::Tensor& query_key_value,
bsz,
rotate_half,
rotate_every_two,
Context::Instance().GetCurrentStream());
Context::Instance().GetCurrentStream(),
Context::Instance().GetMaxTokenLenght());
attention_unfused<T>(workspace + offset,
(T*)query_cont,
......@@ -614,16 +626,17 @@ void ds_layernorm_internal(T* workspace,
}
template <typename T>
void quantized_gemm(at::Tensor& output,
void quantized_gemm(void* output,
T* input,
at::Tensor& weight,
at::Tensor& qscale,
int groups,
int bsz)
{
auto weight16 = at::empty({weight.size(0), weight.size(1)}, output.options());
T* weight16 = (T*)Context::Instance().GetWorkSpace() +
12 * Context::Instance().GetMaxTokenLenght() * weight.size(1);
launch_dequantize((T*)weight16.data_ptr(),
launch_dequantize(weight16,
(int8_t*)weight.data_ptr(),
(float*)qscale.data_ptr(),
weight.size(0),
......@@ -641,9 +654,9 @@ void quantized_gemm(at::Tensor& output,
weight.size(1),
&alpha,
&gemm_beta,
(T*)weight16.data_ptr(),
weight16,
(T*)input,
(T*)output.data_ptr(),
(T*)output,
#ifdef __HIP_PLATFORM_HCC__
rocblas_gemm_algo_standard);
#else
......@@ -667,10 +680,9 @@ at::Tensor qkv_unfused_cublas(at::Tensor& output,
T* workspace = (T*)Context::Instance().GetWorkSpace();
workspace += (3 * bsz * input.size(2));
ds_layernorm_internal<T>(workspace, input, gamma, beta, epsilon);
// cudaEventRecord(Context::Instance().GetCompEvent(1), Context::Instance().GetCurrentStream());
if (q_int8) {
quantized_gemm<T>(output, workspace, weight, q_scale, q_scale.size(0), bsz);
quantized_gemm<T>(output.data_ptr(), workspace, weight, q_scale, q_scale.size(0), bsz);
} else {
float alpha = (T)1.0;
float gemm_beta = (T)0.0;
......@@ -713,17 +725,20 @@ std::vector<at::Tensor> ds_qkv_gemm(at::Tensor& input,
const float epsilon,
bool add_bias,
unsigned num_layers,
bool external_cache,
unsigned mp_size,
unsigned rank,
bool q_int8)
{
int bsz = input.size(0) * input.size(1);
T* workspace = (T*)Context::Instance().GetWorkSpace();
int out_size = q_int8 ? weight.size(0) : weight.size(1);
if (!workspace) {
if (!workspace)
cublasSetStream(Context::Instance().GetCublasHandle(),
Context::Instance().GetCurrentStream());
allocate_workspace<T>(input.size(2), MAX_OUT_TOKES, input.size(0), num_layers);
workspace = (T*)Context::Instance().GetWorkSpace();
}
allocate_workspace<T>(input.size(2), input.size(0), num_layers, mp_size, external_cache, rank);
workspace = (T*)Context::Instance().GetWorkSpace();
auto options = at::TensorOptions()
.dtype(input.options().dtype())
.layout(at::kStrided)
......@@ -831,10 +846,11 @@ at::Tensor ds_linear_layer(at::Tensor& input,
int bsz = input.size(0) * input.size(1);
T* workspace = (T*)Context::Instance().GetWorkSpace();
if (!workspace) {
// Reallocate memory if we received a new prompt
if (!workspace || input.size(1) != 1) {
cublasSetStream(Context::Instance().GetCublasHandle(),
Context::Instance().GetCurrentStream());
allocate_workspace<T>(input.size(2), MAX_OUT_TOKES, input.size(0), num_layers);
allocate_workspace<T>(input.size(2), input.size(0), num_layers);
workspace = (T*)Context::Instance().GetWorkSpace();
}
auto output = at::from_blob(workspace, {input.size(0), input.size(1), weight.size(1)}, options);
......@@ -902,18 +918,20 @@ at::Tensor ds_vector_matmul(at::Tensor& input,
at::Tensor& q_scale,
bool q_int8)
{
auto input_cont = input.contiguous();
auto options = at::TensorOptions()
.dtype(input_cont.options().dtype())
.dtype(input.options().dtype())
.layout(at::kStrided)
.device(at::kCUDA)
.requires_grad(false);
int out_size = q_int8 ? weight.size(0) : weight.size(1);
int bsz = input_cont.size(0) * input_cont.size(1);
auto output = at::empty({input_cont.size(0), input_cont.size(1), out_size}, options);
int bsz = input.size(0) * input.size(1);
T* workspace = (T*)Context::Instance().GetWorkSpace();
auto output = at::from_blob(workspace, {input.size(0), input.size(1), out_size}, options);
if (q_int8) {
quantized_gemm<T>(output, (T*)input_cont.data_ptr(), weight, q_scale, q_scale.size(0), bsz);
quantized_gemm<T>(
output.data_ptr(), (T*)input.data_ptr(), weight, q_scale, q_scale.size(0), bsz);
} else {
float alpha = (T)1.0;
float gemm_beta = (T)0.0;
......@@ -924,11 +942,11 @@ at::Tensor ds_vector_matmul(at::Tensor& input,
CUBLAS_OP_N,
weight.size(1),
bsz,
input_cont.size(2),
input.size(2),
&alpha,
&gemm_beta,
(T*)weight.data_ptr(),
(T*)input_cont.data_ptr(),
(T*)input.data_ptr(),
(T*)output.data_ptr(),
#ifdef __HIP_PLATFORM_HCC__
rocblas_gemm_algo_standard);
......@@ -965,6 +983,7 @@ at::Tensor mlp_unfused_cublas(at::Tensor& output,
at::Tensor& residual,
at::Tensor& input_bias,
at::Tensor& weight,
at::Tensor& weight1,
at::Tensor& bias,
at::Tensor& gamma,
at::Tensor& beta,
......@@ -972,13 +991,15 @@ at::Tensor mlp_unfused_cublas(at::Tensor& output,
bool preLayerNorm,
bool mlp_after_attn,
at::Tensor& q_scale,
at::Tensor& q_scale1,
bool q_int8,
ActivationFuncType act_func_type)
{
int bsz = input.size(0) * input.size(1);
auto inp_norm = at::empty_like(input);
launch_residual_layer_norm((T*)inp_norm.data_ptr(),
T* inp_norm =
(T*)Context::Instance().GetWorkSpace() + torch::numel(input) + torch::numel(output);
T* intermediate = inp_norm + torch::numel(input);
launch_residual_layer_norm((T*)inp_norm,
(T*)nullptr,
(T*)input.data_ptr(),
(T*)residual.data_ptr(),
......@@ -993,7 +1014,7 @@ at::Tensor mlp_unfused_cublas(at::Tensor& output,
Context::Instance().GetCurrentStream());
if (q_int8) {
quantized_gemm<T>(output, (T*)inp_norm.data_ptr(), weight, q_scale, q_scale.size(0), bsz);
quantized_gemm<T>(intermediate, inp_norm, weight, q_scale, q_scale.size(0), bsz);
} else {
float alpha = (T)1.0;
float gemm_beta = (T)0.0;
......@@ -1008,8 +1029,8 @@ at::Tensor mlp_unfused_cublas(at::Tensor& output,
&alpha,
&gemm_beta,
(T*)weight.data_ptr(),
(T*)inp_norm.data_ptr(),
(T*)output.data_ptr(),
inp_norm,
intermediate,
#ifdef __HIP_PLATFORM_HCC__
rocblas_gemm_algo_standard);
#else
......@@ -1017,20 +1038,45 @@ at::Tensor mlp_unfused_cublas(at::Tensor& output,
#endif
}
if (act_func_type == ActivationFuncType::GELU) {
launch_bias_gelu((T*)output.data_ptr(),
launch_bias_gelu(intermediate,
(T*)bias.data_ptr(),
q_int8 ? weight.size(0) : weight.size(1),
bsz,
Context::Instance().GetCurrentStream());
} else if (act_func_type == ActivationFuncType::ReLU) {
launch_bias_relu((T*)output.data_ptr(),
launch_bias_relu(intermediate,
(T*)bias.data_ptr(),
q_int8 ? weight.size(0) : weight.size(1),
bsz,
Context::Instance().GetCurrentStream());
}
if (q_int8) {
quantized_gemm<T>(
output.data_ptr(), intermediate, weight1, q_scale1, q_scale1.size(0), bsz);
} else {
float alpha = (T)1.0;
float gemm_beta = (T)0.0;
cublasSetStream(Context::Instance().GetCublasHandle(),
Context::Instance().GetCurrentStream());
cublas_gemm_ex(Context::Instance().GetCublasHandle(),
CUBLAS_OP_N,
CUBLAS_OP_N,
weight1.size(1),
bsz,
weight1.size(0),
&alpha,
&gemm_beta,
(T*)weight1.data_ptr(),
intermediate,
(T*)output.data_ptr(),
#ifdef __HIP_PLATFORM_HCC__
rocblas_gemm_algo_standard);
#else
CUBLAS_GEMM_DEFAULT_TENSOR_OP);
#endif
}
return inp_norm;
return torch::from_blob(inp_norm, input.sizes(), input.options());
}
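// Buffer carve-out assumed by mlp_unfused_cublas above (offsets taken from this
// diff; the contents of workspace[0 .. numel(input)) are left to the caller):
//
//   T* output       = workspace + numel(input);                  // second GEMM (weight1) result
//   T* inp_norm     = workspace + numel(input) + numel(output);  // layernorm output, returned as the residual
//   T* intermediate = inp_norm  + numel(input);                  // first GEMM + fused bias + GELU/ReLU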
template <typename T>
......@@ -1038,6 +1084,7 @@ std::vector<at::Tensor> ds_mlp_gemm(at::Tensor& input,
at::Tensor& residual,
at::Tensor& input_bias,
at::Tensor& weight,
at::Tensor& weight1,
at::Tensor& bias,
at::Tensor& gamma,
at::Tensor& beta,
......@@ -1045,21 +1092,21 @@ std::vector<at::Tensor> ds_mlp_gemm(at::Tensor& input,
bool preLayerNorm,
bool mlp_after_attn,
at::Tensor& q_scale,
at::Tensor& q_scale1,
bool q_int8,
int activation_type)
{
auto input_cont = input.contiguous();
auto options = at::TensorOptions()
.dtype(input_cont.options().dtype())
.dtype(input.options().dtype())
.layout(at::kStrided)
.device(at::kCUDA)
.requires_grad(false);
int out_size = q_int8 ? weight.size(0) : weight.size(1);
auto output = at::from_blob((T*)Context::Instance().GetWorkSpace(),
{input_cont.size(0), input_cont.size(1), out_size},
auto output = at::from_blob((T*)Context::Instance().GetWorkSpace() + torch::numel(input),
{input.size(0), input.size(1), out_size},
options);
int bsz = input_cont.size(0) * input_cont.size(1);
int bsz = input.size(0) * input.size(1);
auto act_func_type = static_cast<ActivationFuncType>(activation_type);
auto res_add = mlp_unfused_cublas<T>(output,
......@@ -1067,6 +1114,7 @@ std::vector<at::Tensor> ds_mlp_gemm(at::Tensor& input,
residual,
input_bias,
weight,
weight1,
bias,
gamma,
beta,
......@@ -1074,6 +1122,7 @@ std::vector<at::Tensor> ds_mlp_gemm(at::Tensor& input,
preLayerNorm,
mlp_after_attn,
q_scale,
q_scale1,
q_int8,
act_func_type);
......@@ -1184,7 +1233,7 @@ at::Tensor fused_gemm_gelu(at::Tensor& input,
template <typename T>
at::Tensor& residual_add_bias(at::Tensor& hidden_state,
const at::Tensor& residual,
at::Tensor& residual,
const at::Tensor& attention_output,
const at::Tensor& attention_bias,
const at::Tensor& final_bias,
......@@ -1217,7 +1266,7 @@ at::Tensor& residual_add_bias(at::Tensor& hidden_state,
bsz,
mp_size,
Context::Instance().GetCurrentStream());
return hidden_state;
return residual;
}
std::vector<at::Tensor> apply_rotary_pos_emb(at::Tensor& mixed_query,
......@@ -1246,7 +1295,8 @@ std::vector<at::Tensor> apply_rotary_pos_emb(at::Tensor& mixed_query,
bsz,
rotate_half,
rotate_every_two,
Context::Instance().GetCurrentStream());
Context::Instance().GetCurrentStream(),
Context::Instance().GetMaxTokenLenght());
else
launch_apply_rotary_pos_emb<__half>((__half*)query_cont.data_ptr(),
(__half*)key_cont.data_ptr(),
......@@ -1258,7 +1308,8 @@ std::vector<at::Tensor> apply_rotary_pos_emb(at::Tensor& mixed_query,
bsz,
rotate_half,
rotate_every_two,
Context::Instance().GetCurrentStream());
Context::Instance().GetCurrentStream(),
Context::Instance().GetMaxTokenLenght());
return {query_cont, key_cont};
}
......
......@@ -22,23 +22,21 @@ __global__ void bias_add_transform_0213(float* output,
int rotary_dim,
bool rotate_half,
bool rotate_every_two,
int head_ext)
int head_ext,
int max_out_tokens)
{
int d0_stride = hidden_dim * seq_length;
int d1_stride = hidden_dim;
int d2_stride = hidden_dim / heads;
int d0_out_stride = d0_stride;
int d1_out_stride = d2_stride;
// int d2_out_stride = d2_stride * seq_length;
int d0 = blockIdx.x; // Batch
int d1 = blockIdx.y; // Sequence ID (0-127)
int cnt = blockIdx.z / head_ext; // Hidden count
int d2 = threadIdx.y + (blockIdx.z % head_ext) * (heads / head_ext); // Head (0-11)
int d3 = threadIdx.x; // Values (groups of 4)
int d2_out_stride = d2_stride * (cnt == 0 ? seq_length : MAX_OUT_TOKES);
int d2_out_stride = d2_stride * (cnt == 0 ? seq_length : max_out_tokens);
int d0_out_stride = hidden_dim * (cnt == 0 ? seq_length : max_out_tokens);
const float4* vals_vec = reinterpret_cast<const float4*>(vals);
float4* output_vec =
......@@ -50,7 +48,7 @@ __global__ void bias_add_transform_0213(float* output,
vals_vec += (d2 * d2_stride);
output_vec += (d1 * d2_stride);
output_vec += (d0 * d0_stride);
output_vec += (d0 * d0_out_stride);
output_vec += (d2 * d2_out_stride);
unsigned seq_id = d1 + seq_offset;
......@@ -89,7 +87,8 @@ __global__ void bias_add_transform_0213(__half* output, // q
int rotary_dim,
bool rotate_half,
bool rotate_every_two,
int head_ext)
int head_ext,
int max_out_tokens)
{
#if __CUDA_ARCH__ >= 700
......@@ -104,7 +103,9 @@ __global__ void bias_add_transform_0213(__half* output, // q
int d2 = threadIdx.y + (blockIdx.z % head_ext) * (heads / head_ext); // Head (0-11)
int d3 = threadIdx.x; // Values (groups of 4)
int d2_out_stride = d2_stride * (cnt == 0 ? seq_length : MAX_OUT_TOKES);
int d2_out_stride = d2_stride * (cnt == 0 ? seq_length : max_out_tokens);
int d0_out_stride = hidden_dim * (cnt == 0 ? seq_length : max_out_tokens);
float4 vals_arr;
float4 output_arr;
......@@ -121,7 +122,7 @@ __global__ void bias_add_transform_0213(__half* output, // q
vals_vec += (d2 * d2_stride);
output_vec += (d1 * d2_stride);
output_vec += (d0 * d0_stride);
output_vec += (d0 * d0_out_stride);
output_vec += (d2 * d2_out_stride);
unsigned seq_id = d1 + seq_offset;
......@@ -166,7 +167,8 @@ void launch_bias_add_transform_0213<float>(float* output,
bool rotate_half,
bool rotate_every_two,
cudaStream_t stream,
int trans_count)
int trans_count,
int max_out_tokens)
{
hidden_dim >>= 2;
int head_ext = (hidden_dim - 1) / MAX_THREADS + 1;
......@@ -186,7 +188,8 @@ void launch_bias_add_transform_0213<float>(float* output,
rotary_dim >> 2,
rotate_half,
rotate_every_two,
head_ext);
head_ext,
max_out_tokens);
}
template <typename T>
void launch_bias_add_transform_0213(T* outputs,
......@@ -204,7 +207,8 @@ void launch_bias_add_transform_0213(T* outputs,
bool rotate_half,
bool rotate_every_two,
cudaStream_t stream,
int trans_count);
int trans_count,
int max_out_tokens);
template <>
void launch_bias_add_transform_0213<__half>(__half* output,
__half* k_cache,
......@@ -221,12 +225,13 @@ void launch_bias_add_transform_0213<__half>(__half* output,
bool rotate_half,
bool rotate_every_two,
cudaStream_t stream,
int trans_count)
int trans_count,
int max_out_tokens)
{
hidden_dim >>= 3;
int head_ext = 1; // (hidden_dim - 1) / MAX_THREADS + 1;
dim3 block_dim(hidden_dim / heads, (heads / head_ext));
dim3 grid_dim(1, seq_length, (trans_count * head_ext));
dim3 grid_dim(batch_size, seq_length, (trans_count * head_ext));
bias_add_transform_0213<<<grid_dim, block_dim, 0, stream>>>(output,
k_cache,
v_cache,
......@@ -240,7 +245,8 @@ void launch_bias_add_transform_0213<__half>(__half* output,
rotary_dim >> 3,
rotate_half,
rotate_every_two,
head_ext);
head_ext,
max_out_tokens);
}
// Bias add
......
......@@ -12,6 +12,10 @@ Copyright 2022 The Microsoft DeepSpeed Team
#include "cublas_v2.h"
#include "cuda.h"
#define MEGABYTE (1024 * 1024)
#define GIGABYTE (1024 * 1024 * 1024)
#define MAX_OUT_TOKENS 8192
#define WARP_SIZE 32
#define CUDA_CHECK(callstr) \
......@@ -43,7 +47,13 @@ inline int DS_GET_BLOCKS(const int N)
class Context {
public:
Context() : _workspace(nullptr), _seed(42), _curr_offset(0), _stream(0)
Context()
: _workspace(nullptr),
_seed(42),
_curr_offset(0),
_stream(0),
_free_memory_size(0),
_num_tokens(1)
{
if (cublasCreate(&_cublasHandle) != CUBLAS_STATUS_SUCCESS) {
auto message = std::string("Fail to create cublas handle.");
......@@ -75,19 +85,51 @@ public:
return _ctx;
}
void GenWorkSpace(size_t size)
void GenWorkSpace(const unsigned& num_layers,
const size_t& batch_size,
const size_t& hidden_dim,
const unsigned& mp_size,
const bool& external_cache,
const size_t& elem_size,
const unsigned& rank)
{
size_t total_size;
if (!_free_memory_size) { cudaMemGetInfo(&_free_memory_size, &total_size); }
size_t activation_size = 16 * hidden_dim * batch_size;
size_t cache_size = num_layers * batch_size * (hidden_dim / mp_size) * 2;
_max_seq_len =
(((_free_memory_size - (_free_memory_size > GIGABYTE ? 500 : 100) * MEGABYTE) /
elem_size)) /
(activation_size + cache_size);
size_t workSpaceSize = (external_cache ? activation_size : (activation_size + cache_size)) *
_max_seq_len * elem_size;
_max_seq_len = std::min((size_t)MAX_OUT_TOKENS, _max_seq_len);
if (rank == 0 && !_workspace)
printf(
"Free memory : %lu (Bytes) Total memory: %lu (Bytes) Setting maximum total "
"tokens (input + output) to %lu \n",
_free_memory_size,
total_size,
_max_seq_len);
if (!_workspace) {
assert(_workspace == nullptr);
cudaMalloc(&_workspace, size);
} else if (_workSpaceSize < size) {
cudaMalloc(&_workspace, workSpaceSize);
} else if (_workSpaceSize < workSpaceSize) {
cudaFree(_workspace);
cudaMalloc(&_workspace, size);
cudaMalloc(&_workspace, workSpaceSize);
}
if (!_workspace) { throw std::runtime_error("Workspace is null."); }
_workSpaceSize = size;
if (!_workspace) {
printf("Requested:\t%lu\nFree:\t%lu\nTotal:\t%lu\n",
workSpaceSize,
_free_memory_size,
total_size);
throw std::runtime_error("Workspace is null.");
}
_workSpaceSize = workSpaceSize;
}
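// Worked example (illustrative values, not from this diff): on a GPU with roughly
// 40 GB free, fp16 weights/activations (elem_size = 2), hidden_dim = 4096,
// batch_size = 1, num_layers = 32, mp_size = 1:
//
//   activation_size = 16 * 4096 * 1         =  65,536 elements per token
//   cache_size      = 32 * 1 * (4096/1) * 2 = 262,144 elements per token
//   _max_seq_len    = (40 GB - 500 MB) / 2 / (65,536 + 262,144) ≈ 64,700 tokens
//
// which the std::min above then clamps to MAX_OUT_TOKENS (8192).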
inline size_t GetMaxTokenLenght() const { return _max_seq_len; }
cudaEvent_t GetCompEvent(int id) { return id == 1 ? _comp1_event : _comp2_event; }
......@@ -100,7 +142,7 @@ public:
return _token_length;
}
inline void reset_tokens(unsigned initial_tokens = 0)
inline void reset_tokens(unsigned initial_tokens = 1)
{
_num_tokens = initial_tokens;
} //_token_length = 0; }
......@@ -160,7 +202,11 @@ private:
void* _workspace;
uint64_t _seed;
uint64_t _curr_offset;
size_t _workSpaceSize;
size_t _free_memory_size;
size_t _max_seq_len;
cudaEvent_t _comp1_event;
cudaEvent_t _comp2_event;
......
......@@ -21,7 +21,6 @@ Copyright 2022 The Microsoft DeepSpeed Team
#include <cassert>
#include <iostream>
#define MAX_OUT_TOKES 128
#define MAX_WARP_NUM 32
#define WARP_SIZE 32
......@@ -142,7 +141,8 @@ void launch_apply_rotary_pos_emb(T* mixed_query,
unsigned batch,
bool rotate_half,
bool rotate_every_two,
cudaStream_t stream);
cudaStream_t stream,
int max_out_tokens);
template <typename T>
void launch_moe_res_matmul(T* residual,
......@@ -178,4 +178,5 @@ void launch_bias_add_transform_0213(T* outputs,
bool rotate_half,
bool rotate_every_two,
cudaStream_t stream,
int trans_count);
int trans_count,
int max_out_tokens);
......@@ -380,6 +380,8 @@ class DeepSpeedSelfAttentionFunction(Function):
) else 0
sliced_alibi = alibi[offset:batch_heads + offset, :, :]
#
attn_key_value = score_context_func(
qkv_out,
((1 - input_mask).to(qkv_out.dype) *
......@@ -402,6 +404,7 @@ class DeepSpeedSelfAttentionFunction(Function):
def selfAttention_fp():
vector_matmul_func = inference_cuda_module.vector_matmul_fp16 if config.fp16 else \
inference_cuda_module.vector_matmul_fp32
if not config.pre_layer_norm:
linear_func = inference_cuda_module.linear_layer_fp16 if config.fp16 else \
inference_cuda_module.linear_layer_fp32
......@@ -413,18 +416,19 @@ class DeepSpeedSelfAttentionFunction(Function):
else:
qkv_func = inference_cuda_module.qkv_gemm_fp16 if config.fp16 else \
inference_cuda_module.qkv_gemm_fp32
qkv_out = qkv_func(
input,
attn_qkvw,
attn_qkvw.scale,
(attn_qkvb if attn_qkvb is not None else norm_b),
norm_w,
norm_b,
config.epsilon,
(attn_qkvb is not None),
1 if config.bigscience_bloom else
DeepSpeedTransformerInference.layer_id,
config.q_int8)
qkv_out = qkv_func(input,
attn_qkvw,
attn_qkvw.scale,
(attn_qkvb if attn_qkvb is not None else norm_b),
norm_w,
norm_b,
config.epsilon,
(attn_qkvb is not None),
DeepSpeedTransformerInference.layer_id,
config.bigscience_bloom,
config.mp_size,
dist.get_rank() if dist.is_initialized() else 0,
config.q_int8)
context_layer, key_layer, value_layer = compute_attention(qkv_out[0] if isinstance(qkv_out, list) else qkv_out, input_mask)
output = vector_matmul_func(context_layer,
attn_ow,
......@@ -495,20 +499,18 @@ class DeepSpeedSelfAttention(nn.Module):
self.config.layer_id = DeepSpeedSelfAttention.num_layers
DeepSpeedSelfAttention.num_layers = DeepSpeedSelfAttention.num_layers + 1
device = torch.cuda.current_device() if config.bigscience_bloom else 'cpu'
self.attn_qkvw = nn.Parameter(torch.empty(
self.config.hidden_size,
(self.config.hidden_size // self.config.mp_size) * 3,
dtype=data_type,
device=device),
qkv_size_per_partition = (self.config.hidden_size // self.config.mp_size) * 3
self.attn_qkvw = nn.Parameter(torch.empty(self.config.hidden_size,
qkv_size_per_partition,
dtype=data_type,
device=device),
requires_grad=False)
self.attn_qkvb = nn.Parameter(torch.empty(
(self.config.hidden_size // self.config.mp_size) * 3,
dtype=data_type_fp,
device=device),
self.attn_qkvb = nn.Parameter(torch.empty(qkv_size_per_partition,
dtype=data_type_fp,
device=device),
requires_grad=False)
self.attn_ow = nn.Parameter(torch.empty(self.config.hidden_size //
self.config.mp_size,
out_size_per_partition = self.config.hidden_size // self.config.mp_size
self.attn_ow = nn.Parameter(torch.empty(out_size_per_partition,
self.config.hidden_size,
dtype=data_type,
device=device),
......@@ -613,10 +615,11 @@ class DeepSpeedMLPFunction(Function):
config.pre_layer_norm,
False)
else:
intermediate, residual_add = mlp_gemm_func(input,
output, residual_add = mlp_gemm_func(input,
residual,
bias,
inter_w,
output_w,
inter_b,
attn_nw,
attn_nb,
......@@ -624,17 +627,14 @@ class DeepSpeedMLPFunction(Function):
config.pre_layer_norm,
config.mlp_after_attn,
inter_w.scale,
output_w.scale,
config.q_int8,
config.mlp_act_func_type)
output = vector_matmul_func(intermediate,
output_w,
False,
output_w.scale,
config.q_int8)
output = residual_add_func(
output, # hidden state
residual if config.pre_layer_norm else residual_add, # residual
input, # attention output
residual = residual if config.pre_layer_norm else residual_add
residual_add_func(
output, # hidden state
residual, # residual
input, # attention output
bias if bias is not None else output_b,
output_b,
config.mp_size, # model parallel size
......@@ -642,8 +642,8 @@ class DeepSpeedMLPFunction(Function):
bias is not None, # whether bias addition is fused
config.pre_layer_norm) # whether the layer norm is applied before attention
if mp_group is not None and dist.get_world_size(group=mp_group) > 1:
dist.all_reduce(output, group=mp_group)
return output
dist.all_reduce(residual, group=mp_group)
return residual
@staticmethod
def backward(ctx, grad_output):
......@@ -673,22 +673,20 @@ class DeepSpeedMLP(nn.Module):
dtype=data_type_fp,
device=device),
requires_grad=False)
intm_size_per_partition = self.config.intermediate_size // self.config.mp_size
self.inter_w = nn.Parameter(torch.empty(self.config.hidden_size,
self.config.intermediate_size //
self.config.mp_size,
intm_size_per_partition,
dtype=data_type,
device=device),
requires_grad=False)
self.inter_b = nn.Parameter(torch.empty(self.config.intermediate_size //
self.config.mp_size,
self.inter_b = nn.Parameter(torch.empty(intm_size_per_partition,
dtype=data_type_fp,
device=device),
requires_grad=False)
self.output_w = nn.Parameter(torch.empty(
(self.config.intermediate_size // self.config.mp_size),
self.config.hidden_size,
dtype=data_type,
device=device),
self.output_w = nn.Parameter(torch.empty(intm_size_per_partition,
self.config.hidden_size,
dtype=data_type,
device=device),
requires_grad=False)
self.output_b = nn.Parameter(torch.empty(self.config.hidden_size,
dtype=data_type_fp,
......