Unverified commit afdc7287, authored by Reza Yazdani, committed by GitHub

Ds-inference Int8 support through ZeroQuant technology (#2217)

Co-authored-by: Jeff Rasley <jerasley@microsoft.com>
Parent 088212a7
......@@ -108,3 +108,90 @@ template void launch_dequantize<__half>(__half*,
unsigned,
unsigned,
cudaStream_t);
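// Added for INT8 inference (ZeroQuant): per-group dequantization kernels. Each thread
// block dequantizes a slice of rows belonging to one quantization group (blockIdx.x),
// using that group's FP32 scale from qscale; the __half path reads four packed int8
// values per float load and writes four __half results per float2 store. The float
// overload below is left as a stub, since the INT8 inference path dequantizes into __half.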
__global__ void dequantize_kernel(float* output,
const int8_t* input,
const float* qscale,
int hidden_dim,
unsigned merge_hidden,
int cnt)
{
}
__global__ void dequantize_kernel(__half* output,
const int8_t* input,
const float* qscale,
unsigned hidden_dim,
unsigned merge_hidden,
int cnt)
{
unsigned bid = blockIdx.x * gridDim.y + blockIdx.y;
unsigned tid = threadIdx.x;
float local_scale = qscale[blockIdx.x];
const float* input_cast = reinterpret_cast<const float*>(input);
float2* output_cast = reinterpret_cast<float2*>(output);
input_cast += bid * merge_hidden;
output_cast += bid * merge_hidden;
for (int c = 0; c < cnt; c++) {
if (tid < merge_hidden) {
float q = input_cast[tid];
int8_t* q_int8 = (int8_t*)&q;
float2 q_f;
__half* q_h = (__half*)&q_f;
q_h[0] = __float2half(local_scale * (float)q_int8[0]);
q_h[1] = __float2half(local_scale * (float)q_int8[1]);
q_h[2] = __float2half(local_scale * (float)q_int8[2]);
q_h[3] = __float2half(local_scale * (float)q_int8[3]);
// q_h[4] = __float2half(local_scale * (float)q_int8[4]);
// q_h[5] = __float2half(local_scale * (float)q_int8[5]);
// q_h[6] = __float2half(local_scale * (float)q_int8[6]);
// q_h[7] = __float2half(local_scale * (float)q_int8[7]);
output_cast[tid] = q_f;
tid += blockDim.x;
}
}
}
template <typename T>
void launch_dequantize(T* output,
const int8_t* input,
const float* qscale,
unsigned output_size,
unsigned hidden_dim,
unsigned groups,
cudaStream_t stream)
{
unsigned threads = 1024;
hidden_dim /= 4;
unsigned hid_cnt = threads / hidden_dim;
unsigned thd_cnt = (hidden_dim - 1) / threads + 1;
hid_cnt = hid_cnt > 0 ? hid_cnt : 1;
unsigned blocks = output_size / hid_cnt / groups;
dim3 block_dims(threads);
dim3 grid_dims(groups, blocks);
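    // Example launch arithmetic (assumed sizes): with hidden_dim = 4096 and threads = 1024,
    // hidden_dim /= 4 leaves 1024 vectorized loads per row, so hid_cnt = 1 row per block,
    // thd_cnt = 1 iteration per thread, and blocks = output_size / groups, i.e. the grid
    // is (groups, output_size / groups).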
dequantize_kernel<<<grid_dims, block_dims, 0, stream>>>(
output, input, qscale, hidden_dim, hid_cnt * hidden_dim, thd_cnt);
}
template void launch_dequantize<float>(float*,
const int8_t*,
const float*,
unsigned,
unsigned,
unsigned,
cudaStream_t);
template void launch_dequantize<__half>(__half*,
const int8_t*,
const float*,
unsigned,
unsigned,
unsigned,
cudaStream_t);
#include "custom_cuda_layers.h"
namespace cg = cooperative_groups;
#define MAX_CAP 4
#define MAX_SEQ 2048
......
......@@ -558,15 +558,55 @@ void ds_layernorm_internal(T* workspace,
Context::Instance().GetCurrentStream());
}
template <typename T>
void quantized_gemm(at::Tensor& output,
T* input,
at::Tensor& weight,
at::Tensor& qscale,
int groups,
int bsz)
{
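    // Dequantize the INT8 weight into a temporary 16-bit buffer of the same shape, then
    // run a regular cuBLAS GEMM (weight transposed) against the activations.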
auto weight16 = at::empty({weight.size(0), weight.size(1)}, output.options());
launch_dequantize((T*)weight16.data_ptr(),
(int8_t*)weight.data_ptr(),
(float*)qscale.data_ptr(),
weight.size(0),
weight.size(1),
groups,
Context::Instance().GetCurrentStream());
float alpha = (T)1.0;
float gemm_beta = (T)0.0;
cublas_gemm_ex(Context::Instance().GetCublasHandle(),
CUBLAS_OP_T,
CUBLAS_OP_N,
weight.size(0),
bsz,
weight.size(1),
&alpha,
&gemm_beta,
(T*)weight16.data_ptr(),
(T*)input,
(T*)output.data_ptr(),
#ifdef __HIP_PLATFORM_HCC__
rocblas_gemm_algo_standard);
#else
CUBLAS_GEMM_DEFAULT_TENSOR_OP);
#endif
}
template <typename T>
at::Tensor qkv_unfused_cublas(at::Tensor& output,
at::Tensor& input,
at::Tensor& weight,
at::Tensor& q_scale,
at::Tensor& bias,
at::Tensor& gamma,
at::Tensor& beta,
const float epsilon,
bool add_bias)
bool add_bias,
bool q_int8)
{
int bsz = input.size(0) * input.size(1);
T* workspace = (T*)Context::Instance().GetWorkSpace();
......@@ -574,48 +614,55 @@ at::Tensor qkv_unfused_cublas(at::Tensor& output,
ds_layernorm_internal<T>(workspace, input, gamma, beta, epsilon);
// cudaEventRecord(Context::Instance().GetCompEvent(1), Context::Instance().GetCurrentStream());
float alpha = (T)1.0;
float gemm_beta = (T)0.0;
if (q_int8) {
quantized_gemm<T>(output, workspace, weight, q_scale, q_scale.size(0), bsz);
} else {
float alpha = (T)1.0;
float gemm_beta = (T)0.0;
cublasSetStream(Context::Instance().GetCublasHandle(), Context::Instance().GetCurrentStream());
cublas_gemm_ex(Context::Instance().GetCublasHandle(),
CUBLAS_OP_N,
CUBLAS_OP_N,
weight.size(1),
bsz,
input.size(2),
&alpha,
&gemm_beta,
(T*)weight.data_ptr(),
workspace,
(T*)output.data_ptr(),
cublasSetStream(Context::Instance().GetCublasHandle(),
Context::Instance().GetCurrentStream());
cublas_gemm_ex(Context::Instance().GetCublasHandle(),
CUBLAS_OP_N,
CUBLAS_OP_N,
weight.size(1),
bsz,
input.size(2),
&alpha,
&gemm_beta,
(T*)weight.data_ptr(),
workspace,
(T*)output.data_ptr(),
#ifdef __HIP_PLATFORM_HCC__
rocblas_gemm_algo_standard);
rocblas_gemm_algo_standard);
#else
CUBLAS_GEMM_DEFAULT_TENSOR_OP);
CUBLAS_GEMM_DEFAULT_TENSOR_OP);
#endif
}
if (add_bias)
launch_bias_add((T*)output.data_ptr(),
(T*)bias.data_ptr(),
weight.size(1),
q_int8 ? weight.size(0) : weight.size(1),
bsz,
Context::Instance().GetCurrentStream());
return torch::from_blob(workspace, input.sizes(), input.options());
}
template <typename T>
std::vector<at::Tensor> ds_qkv_gemm(at::Tensor& input,
at::Tensor& weight,
at::Tensor& q_scale,
at::Tensor& bias,
at::Tensor& gamma,
at::Tensor& beta,
const float epsilon,
bool add_bias,
unsigned num_layers)
unsigned num_layers,
bool q_int8)
{
int bsz = input.size(0) * input.size(1);
T* workspace = (T*)Context::Instance().GetWorkSpace();
int out_size = q_int8 ? weight.size(0) : weight.size(1);
if (!workspace) {
cublasSetStream(Context::Instance().GetCublasHandle(),
Context::Instance().GetCurrentStream());
......@@ -628,9 +675,9 @@ std::vector<at::Tensor> ds_qkv_gemm(at::Tensor& input,
.device(at::kCUDA)
.requires_grad(false);
auto output = at::from_blob(workspace, {input.size(0), input.size(1), weight.size(1)}, options);
auto inp_norm =
qkv_unfused_cublas<T>(output, input, weight, bias, gamma, beta, epsilon, add_bias);
auto output = at::from_blob(workspace, {input.size(0), input.size(1), out_size}, options);
auto inp_norm = qkv_unfused_cublas<T>(
output, input, weight, q_scale, bias, gamma, beta, epsilon, add_bias, q_int8);
return {output, inp_norm};
}
......@@ -654,20 +701,18 @@ void quantized_gemm(at::Tensor& output,
launch_dequantize((T*)weight16.data_ptr(),
(int8_t*)weight.data_ptr(),
(float*)qscale.data_ptr(),
weight.size(1),
weight.size(0),
weight.size(1),
groups,
merge_count,
Context::Instance().GetCurrentStream());
cublasSetStream(Context::Instance().GetCublasHandle(), Context::Instance().GetCurrentStream());
float alpha = (T)1.0;
float gemm_beta = (T)0.0;
cublas_gemm_ex(Context::Instance().GetCublasHandle(),
CUBLAS_OP_T,
CUBLAS_OP_N,
CUBLAS_OP_N,
weight.size(1),
weight.size(0),
bsz,
input.size(2),
&alpha,
......@@ -796,7 +841,11 @@ at::Tensor ds_linear_layer_int8(at::Tensor& input,
}
template <typename T>
at::Tensor ds_vector_matmul(at::Tensor& input, at::Tensor& weight, bool async_op)
at::Tensor ds_vector_matmul(at::Tensor& input,
at::Tensor& weight,
bool async_op,
at::Tensor& q_scale,
bool q_int8)
{
auto input_cont = input.contiguous();
auto options = at::TensorOptions()
......@@ -805,28 +854,33 @@ at::Tensor ds_vector_matmul(at::Tensor& input, at::Tensor& weight, bool async_op
.device(at::kCUDA)
.requires_grad(false);
auto output = at::empty({input_cont.size(0), input_cont.size(1), weight.size(1)}, options);
int out_size = q_int8 ? weight.size(0) : weight.size(1);
int bsz = input_cont.size(0) * input_cont.size(1);
float alpha = (T)1.0;
float gemm_beta = (T)0.0;
cublasSetStream(Context::Instance().GetCublasHandle(),
Context::Instance().GetCurrentStream(async_op));
cublas_gemm_ex(Context::Instance().GetCublasHandle(),
CUBLAS_OP_N,
CUBLAS_OP_N,
weight.size(1),
bsz,
input_cont.size(2),
&alpha,
&gemm_beta,
(T*)weight.data_ptr(),
(T*)input_cont.data_ptr(),
(T*)output.data_ptr(),
auto output = at::empty({input_cont.size(0), input_cont.size(1), out_size}, options);
if (q_int8) {
quantized_gemm<T>(output, (T*)input_cont.data_ptr(), weight, q_scale, q_scale.size(0), bsz);
} else {
float alpha = (T)1.0;
float gemm_beta = (T)0.0;
cublasSetStream(Context::Instance().GetCublasHandle(),
Context::Instance().GetCurrentStream(async_op));
cublas_gemm_ex(Context::Instance().GetCublasHandle(),
CUBLAS_OP_N,
CUBLAS_OP_N,
weight.size(1),
bsz,
input_cont.size(2),
&alpha,
&gemm_beta,
(T*)weight.data_ptr(),
(T*)input_cont.data_ptr(),
(T*)output.data_ptr(),
#ifdef __HIP_PLATFORM_HCC__
rocblas_gemm_algo_standard);
rocblas_gemm_algo_standard);
#else
CUBLAS_GEMM_DEFAULT_TENSOR_OP);
CUBLAS_GEMM_DEFAULT_TENSOR_OP);
#endif
}
return output;
}
......@@ -862,6 +916,8 @@ at::Tensor mlp_unfused_cublas(at::Tensor& output,
const float epsilon,
bool preLayerNorm,
bool mlp_after_attn,
at::Tensor& q_scale,
bool q_int8,
ActivationFuncType act_func_type)
{
int bsz = input.size(0) * input.size(1);
......@@ -881,36 +937,40 @@ at::Tensor mlp_unfused_cublas(at::Tensor& output,
mlp_after_attn,
Context::Instance().GetCurrentStream());
float alpha = (T)1.0;
float gemm_beta = (T)0.0;
cublasSetStream(Context::Instance().GetCublasHandle(), Context::Instance().GetCurrentStream());
cublas_gemm_ex(Context::Instance().GetCublasHandle(),
CUBLAS_OP_N,
CUBLAS_OP_N,
weight.size(1),
bsz,
input.size(2),
&alpha,
&gemm_beta,
(T*)weight.data_ptr(),
(T*)inp_norm.data_ptr(),
(T*)output.data_ptr(),
if (q_int8) {
quantized_gemm<T>(output, (T*)inp_norm.data_ptr(), weight, q_scale, q_scale.size(0), bsz);
} else {
float alpha = (T)1.0;
float gemm_beta = (T)0.0;
cublasSetStream(Context::Instance().GetCublasHandle(),
Context::Instance().GetCurrentStream());
cublas_gemm_ex(Context::Instance().GetCublasHandle(),
CUBLAS_OP_N,
CUBLAS_OP_N,
weight.size(1),
bsz,
input.size(2),
&alpha,
&gemm_beta,
(T*)weight.data_ptr(),
(T*)inp_norm.data_ptr(),
(T*)output.data_ptr(),
#ifdef __HIP_PLATFORM_HCC__
rocblas_gemm_algo_standard);
rocblas_gemm_algo_standard);
#else
CUBLAS_GEMM_DEFAULT_TENSOR_OP);
CUBLAS_GEMM_DEFAULT_TENSOR_OP);
#endif
}
if (act_func_type == ActivationFuncType::GELU) {
launch_bias_gelu((T*)output.data_ptr(),
(T*)bias.data_ptr(),
weight.size(1),
q_int8 ? weight.size(0) : weight.size(1),
bsz,
Context::Instance().GetCurrentStream());
} else if (act_func_type == ActivationFuncType::ReLU) {
launch_bias_relu((T*)output.data_ptr(),
(T*)bias.data_ptr(),
weight.size(1),
q_int8 ? weight.size(0) : weight.size(1),
bsz,
Context::Instance().GetCurrentStream());
}
......@@ -929,6 +989,8 @@ std::vector<at::Tensor> ds_mlp_gemm(at::Tensor& input,
const float epsilon,
bool preLayerNorm,
bool mlp_after_attn,
at::Tensor& q_scale,
bool q_int8,
int activation_type)
{
auto input_cont = input.contiguous();
......@@ -938,7 +1000,10 @@ std::vector<at::Tensor> ds_mlp_gemm(at::Tensor& input,
.device(at::kCUDA)
.requires_grad(false);
auto output = at::empty({input_cont.size(0), input_cont.size(1), weight.size(1)}, options);
int out_size = q_int8 ? weight.size(0) : weight.size(1);
auto output = at::from_blob((T*)Context::Instance().GetWorkSpace(),
{input_cont.size(0), input_cont.size(1), out_size},
options);
int bsz = input_cont.size(0) * input_cont.size(1);
auto act_func_type = static_cast<ActivationFuncType>(activation_type);
......@@ -953,6 +1018,8 @@ std::vector<at::Tensor> ds_mlp_gemm(at::Tensor& input,
epsilon,
preLayerNorm,
mlp_after_attn,
q_scale,
q_int8,
act_func_type);
return {output, res_add};
......@@ -984,20 +1051,6 @@ std::vector<at::Tensor> ds_mlp_gemm_int8(at::Tensor& input,
auto inp_norm = at::empty_like(input_cont);
auto residual_add = (preLayerNorm ? at::empty_like(input_cont) : inp_norm);
// computing the blocking across K dimension
// launch_residual_layer_norm((T*)inp_norm.data_ptr(),
// (T*)residual_add.data_ptr(),
// (T*)input_cont.data_ptr(),
// (T*)residual.data_ptr(),
// (T*)input_bias.data_ptr(),
// (T*)gamma.data_ptr(),
// (T*)beta.data_ptr(),
// epsilon,
// bsz,
// input_cont.size(2),
// preLayerNorm,
// Context::Instance().GetCurrentStream());
quantized_gemm<T>(output, inp_norm, weight, q_scale, groups, 0);
launch_bias_gelu((T*)output.data_ptr(),
(T*)bias.data_ptr(),
......
......@@ -109,6 +109,14 @@ void launch_dequantize(T* output,
cudaStream_t stream);
template <typename T>
void launch_dequantize(T* output,
const int8_t* input,
const float* qscale,
unsigned output_size,
unsigned hidden_dim,
unsigned groups,
cudaStream_t stream);
template <typename T>
void launch_gptj_residual_add(T* input,
T* output,
T* attn,
......
......@@ -242,7 +242,8 @@ def init_inference(model,
moe_type='standard',
args=None,
enable_cuda_graph=False,
save_mp_checkpoint_path=None):
save_mp_checkpoint_path=None,
base_dir=""):
"""Initialize the DeepSpeed InferenceEngine.
Arguments:
......@@ -278,7 +279,19 @@ def init_inference(model,
of groups used in quantization. A tuple is passed in if we want to indicate that there is extra grouping
for the MLP part of a Transformer layer (e.g. (True, 8) means we quantize the model using 8 groups for
all of the network except the MLP part, for which we use 8 extra groups).
replace_with_kernel_inject: If set we inject kernel as we initialize the inference-engine
replace_with_kernel_inject: this flag needs to be set to True to inject inference kernels for models such as BERT, GPT2, GPT-Neo and GPT-J. Otherwise,
the injection_dict provides the names of two linear layers as a tuple: (attention_output projection, transformer output projection)
return_tuple: Specify whether the transformer layers need to return a tuple or a Tensor. It is set to True by default (returning a tuple).
ep_size: The expert-parallelism size, which is used for partitioning the experts across the GPUs in the expert-parallel group.
moe: Specify whether the Transformer is an MoE model. It is set to False by default.
moe_experts: The global number of experts used in an MoE layer.
moe_type: Specify the type of MoE layer. There are two types of MoE layer: 'Standard' and 'Residual'. It is set to 'Standard' by default.
args: All the arguments used for launching the inference API that may be useful to the inference engine for injecting the optimizations.
enable_cuda_graph: use this flag to capture the CUDA graph of the inference ops, so that they can run faster using the graph-replay method.
It is set to False by default.
save_mp_checkpoint_path: The path to which we want to save the loaded model as a checkpoint. This feature is used for adjusting the
parallelism degree to help alleviate the model-loading overhead. It does not save any new checkpoint if no path is passed.
base_dir: The root directory under which all the checkpoint files exist. This can also be passed through the JSON config.
Returns:
A deepspeed.InferenceEngine wrapped model.
......@@ -309,6 +322,7 @@ def init_inference(model,
moe_type,
args,
enable_cuda_graph,
save_mp_checkpoint_path)
save_mp_checkpoint_path,
base_dir)
return engine
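
For orientation, a minimal, illustrative sketch of driving the updated API with the INT8 path; model, the paths, and the tensor-parallel size below are placeholders rather than values from this PR:

    import torch
    import deepspeed

    engine = deepspeed.init_inference(
        model,                                   # HuggingFace model instance (placeholder)
        mp_size=2,                               # tensor-parallel degree (placeholder)
        dtype=torch.int8,                        # selects the ZeroQuant INT8 kernels
        replace_with_kernel_inject=True,
        save_mp_checkpoint_path="/tmp/ds_int8",  # optionally save the TP-sharded checkpoint
        base_dir="/path/to/checkpoints")         # root dir of the original checkpoint files
    # engine.module now holds the kernel-injected model with INT8 weights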
......@@ -51,7 +51,8 @@ class InferenceEngine(Module):
moe_type='standard',
config=None,
enable_cuda_graph=False,
save_mp_checkpoint_path=None):
save_mp_checkpoint_path=None,
base_dir=""):
"""
Args:
model: torch.nn.Module
......@@ -100,6 +101,9 @@ class InferenceEngine(Module):
self.checkpoint_engine = TorchCheckpointEngine()
self._init_quantization_setting(quantization_setting)
# This is a hack to remove the prepare_mask function on HF side for BLOOM architecture
self.remove_mask_prepare_for_bloom()
if enable_cuda_graph:
assert pkg_version.parse(torch.__version__) >= pkg_version.parse("1.10"), \
"If you want to use cuda graph, please upgrade torch to at least v1.10"
......@@ -135,7 +139,8 @@ class InferenceEngine(Module):
moe_type,
training_mp_size,
self.checkpoint if replace_with_kernel_inject else None,
save_mp_checkpoint_path=save_mp_checkpoint_path)
save_mp_checkpoint_path=save_mp_checkpoint_path,
base_dir=base_dir)
elif replace_method == 'auto':
self._apply_injection_policy(
return_tuple=return_tuple,
......@@ -145,7 +150,8 @@ class InferenceEngine(Module):
moe_type=moe_type,
training_mp_size=training_mp_size,
checkpoint_dir=self.checkpoint if replace_with_kernel_inject else None,
save_mp_checkpoint_path=save_mp_checkpoint_path)
save_mp_checkpoint_path=save_mp_checkpoint_path,
base_dir=base_dir)
device = torch.cuda.current_device()
self.module.to(device)
......@@ -165,6 +171,11 @@ class InferenceEngine(Module):
self.config = getattr(self.module, 'config', None) if config is None else config
self.generate = getattr(self.module, 'generate', None)
def remove_mask_prepare_for_bloom(self):
if hasattr(self.module, 'transformer'):
if hasattr(self.module.transformer, '_prepare_attn_mask'):
self.module.transformer._prepare_attn_mask = lambda attention_mask, *args, **kwargs: attention_mask
def _create_model_parallel_group(self):
# Call the init process
if InferenceEngine.inference_mp_group is None:
......@@ -326,36 +337,37 @@ class InferenceEngine(Module):
moe_type='standard',
training_mp_size=1,
checkpoint_dir=None,
save_mp_checkpoint_path=False):
save_mp_checkpoint_path=False,
base_dir=""):
checkpoint = SDLoaderFactory.get_sd_loader_json(
checkpoint_dir,
self.checkpoint_engine) if checkpoint_dir is not None else None
replace_transformer_layer(
client_module,
self.module,
triangular_masking=self.triangular_masking,
policy=injection_policy,
mp_size=self.mp_world_size,
mp_group=self.mp_group,
ep_group=self.ep_group,
expert_mp_group=self.expert_mp_group,
config=self.config,
fp16=(self.dtype == torch.half),
training=False,
return_tuple=return_tuple,
quantize=(self.dtype == torch.int8),
quantize_settings=(self.quantization_scales,
self.quantize_merge_count,
self.mlp_extra_grouping,
self.quantize_groups),
replace_with_kernel_inject=replace_with_kernel_inject,
moe=moe,
moe_experts=moe_experts,
moe_type=moe_type,
training_mp_size=training_mp_size,
checkpoint_dict=checkpoint,
save_mp_checkpoint_path=save_mp_checkpoint_path,
)
replace_transformer_layer(client_module,
self.module,
triangular_masking=self.triangular_masking,
policy=injection_policy,
mp_size=self.mp_world_size,
mp_group=self.mp_group,
ep_group=self.ep_group,
expert_mp_group=self.expert_mp_group,
config=self.config,
fp16=(self.dtype == torch.half)
or (self.dtype == torch.int8),
training=False,
return_tuple=return_tuple,
quantize=(self.dtype == torch.int8),
quantize_settings=(self.quantization_scales,
self.quantize_merge_count,
self.mlp_extra_grouping,
self.quantize_groups),
replace_with_kernel_inject=replace_with_kernel_inject,
moe=moe,
moe_experts=moe_experts,
moe_type=moe_type,
training_mp_size=training_mp_size,
checkpoint_dict=checkpoint,
save_mp_checkpoint_path=save_mp_checkpoint_path,
base_dir=base_dir)
def _get_all_ckpt_names(self, checkpoints_path, tag):
ckpt_file_pattern = self._get_ckpt_name(checkpoints_path,
......@@ -450,7 +462,7 @@ class InferenceEngine(Module):
return 'model'
def _convert_to_dtype(self):
if self.dtype is torch.int8 and self.quantization_scales is None:
if False: #self.dtype is torch.int8 and self.quantization_scales is None:
quantizer = WeightQuantization(mlp_extra_grouping=self.mlp_extra_grouping)
model, self.quantization_scales = quantizer.model_quantize(self.module,
self.injection_dict,
......
......@@ -3,9 +3,15 @@ import deepspeed.ops.transformer as transformer_inference
from ..runtime.zero import GatheredParameters
from .layers import LinearLayer, Normalize, EmbeddingLayer
import torch
import gc
def load_model_with_checkpoint(r_module, sd, mp_replace, ckpt_type, rank=0):
def load_model_with_checkpoint(r_module,
sd,
mp_replace,
ckpt_type,
weight_quantizer=None,
rank=0):
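    # Note: sd is now a list of state dicts (one per checkpoint shard handled by this rank,
    # so several TP shards can be merged), and weight_quantizer is a GroupQuantizer that
    # quantizes FP16 weights on the fly when q_int8 is enabled.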
error_msgs = []
def transpose(data):
......@@ -15,7 +21,7 @@ def load_model_with_checkpoint(r_module, sd, mp_replace, ckpt_type, rank=0):
return data.reshape(data.shape[-1], data.shape[-2])
def load(module, prefix):
args = (sd, prefix, {}, True, [], [], error_msgs)
args = (sd[0], prefix, {}, True, [], [], error_msgs)
if len(list(module.parameters())) > 0 and list(
module.parameters())[0].numel() == 0:
......@@ -25,81 +31,142 @@ def load_model_with_checkpoint(r_module, sd, mp_replace, ckpt_type, rank=0):
else:
if hasattr(module, 'weight'):
module.weight = mp_replace.copy(module.weight.data,
sd[prefix + 'weight'])
if prefix + 'bias' in sd.keys():
module.bias = mp_replace.copy(module.bias.data, sd[prefix + 'bias'])
sd[0][prefix + 'weight'])
if prefix + 'bias' in sd[0].keys():
module.bias = mp_replace.copy(module.bias.data, sd[0][prefix + 'bias'])
args = None
gc.collect()
def load_transformer_layer(module, prefix):
if ckpt_type == "tp":
def load_parameters(module, prefix):
for n, p in module.named_parameters():
if len(n.split('.')) == 1:
src_shape = sd[prefix + n].shape
if prefix + n in sd[0] and len(n.split('.')) == 1:
if type(sd[0][prefix + n]) is list:
tmp_data, scale = sd[0][prefix + n]
tmp_data = tmp_data
scale = scale.to(torch.cuda.current_device())
else:
tmp_data = sd[0][prefix + n].to(torch.cuda.current_device())
scale = None
src_shape = tmp_data.shape
dst_shape = p.shape
inner_dim = 1 if tmp_data.dtype == torch.int8 else 0
outer_dim = 0 if tmp_data.dtype == torch.int8 else 1
if (len(src_shape) == 2 and len(dst_shape) == 2):
if src_shape[0] == dst_shape[0] and src_shape[
1] == dst_shape[1]:
p.data.copy_(sd[prefix + n])
if (src_shape[inner_dim] == dst_shape[0]
and src_shape[outer_dim] == dst_shape[1]):
if tmp_data.dtype != torch.int8:
p = weight_quantizer.quantize(
transpose(tmp_data) if weight_quantizer.
q_int8 else tmp_data)
else:
p = torch.nn.parameter.Parameter(tmp_data,
requires_grad=False)
p.scale = scale
setattr(module, n, p)
else:
if src_shape[0] != dst_shape[0]:
weight_split = torch.split(
sd[prefix + n],
dst_shape[0],
dim=0)[rank].to(
torch.cuda.current_device()).contiguous()
dim = inner_dim if src_shape[inner_dim] != dst_shape[
0] else outer_dim
dim1 = 0 if src_shape[inner_dim] != dst_shape[0] else 1
if src_shape[dim] > dst_shape[dim1]:
weight_partition = torch.split(
tmp_data,
dst_shape[dim1],
dim=dim)[rank].to(torch.cuda.current_device())
assert tmp_data.dtype != torch.int8 or scale.numel() > weight_quantizer.num_groups * (rank+1), \
'''ERROR: We require the quantization scales for a larger TP size when loading an INT8 checkpoint!\
Please use the FP16 checkpoint to generate the INT8 checkpoint with the sharding parameters!'''
scale = scale.view(
-1)[weight_quantizer.num_groups *
(rank + 1):].reshape(
weight_quantizer.num_groups,
-1).contiguous()
else:
weight_split = torch.split(
sd[prefix + n],
dst_shape[1],
dim=1)[rank].to(
torch.cuda.current_device()).contiguous()
p.data.copy_(weight_split.contiguous())
assert tmp_data.dtype != torch.int8, \
'''Merging of the checkpoints is not supported when using an INT8 checkpoint! \
Please use as many GPUs as the TP size for the checkpoint'''
all_data = [
sd[j][prefix +
n] if type(sd[j][prefix + n]) is list else
sd[j][prefix + n].to(torch.cuda.current_device())
for j in range(len(sd))
]
weight_partition = torch.cat([
ad[0].to(torch.cuda.current_device())
if type(ad) is list else ad for ad in all_data
],
dim=dim)
if tmp_data.dtype == torch.int8:
scale = torch.cat([
ad[1].to(torch.cuda.current_device())
for ad in all_data
],
dim=dim)
if tmp_data.dtype != torch.int8:
weight_partition = weight_quantizer.quantize(
transpose(weight_partition), \
parallel_dim=(0 if dim == 1 else 1)) if weight_quantizer.q_int8 else \
weight_quantizer.quantize(weight_partition)
else:
weight_partition = torch.nn.parameter.Parameter(
weight_partition,
requires_grad=False)
weight_partition.scale = scale
setattr(module, n, weight_partition)
else:
if src_shape[0] == dst_shape[0]:
p.data.copy_(sd[prefix + n])
p.data.copy_(tmp_data)
else:
bias_split = torch.split(
sd[prefix + n],
dst_shape[-1])[rank].to(
torch.cuda.current_device()).contiguous()
p.data.copy_(bias_split)
if src_shape[0] > dst_shape[0]:
bias_split = torch.split(
tmp_data,
dst_shape[-1])[rank].to(
torch.cuda.current_device()).contiguous()
p.data.copy_(bias_split)
else:
p.data.copy_(
torch.cat(
[sd[j][prefix + n] for j in range(len(sd))],
dim=0).to(torch.cuda.current_device()).
contiguous())
load_parameters(module, prefix)
for n, child in module.named_children():
load_parameters(child, prefix + n + '.')
else:
module.norm_w.data.copy_(sd[prefix + 'input_layernorm.' + 'weight'])
module.norm_b.data.copy_(sd[prefix + 'input_layernorm.' + 'bias'])
module.attention.attn_qkvw = mp_replace.copy(
module.attention.attn_qkvw.data,
transpose(sd[prefix + 'self_attention.query_key_value.' + 'weight']))
module.norm_w.data.copy_(sd[0][prefix + 'input_layernorm.' + 'weight'])
module.norm_b.data.copy_(sd[0][prefix + 'input_layernorm.' + 'bias'])
module.attention.attn_qkvw = mp_replace.copy(module.attention.attn_qkvw,
weight_quantizer.quantize(sd[0][prefix + 'self_attention.query_key_value.' + 'weight']) if weight_quantizer.q_int8 else \
weight_quantizer.quantize(transpose(sd[0][prefix + 'self_attention.query_key_value.' + 'weight'])))
module.attention.attn_qkvb = mp_replace.copy(
module.attention.attn_qkvb.data,
sd[prefix + 'self_attention.query_key_value.' + 'bias'])
module.attention.attn_ow = mp_replace.copy(
module.attention.attn_ow.data,
transpose(sd[prefix + 'self_attention.dense.' + 'weight']))
sd[0][prefix + 'self_attention.query_key_value.' + 'bias'])
module.attention.attn_ow = mp_replace.copy(module.attention.attn_ow,
weight_quantizer.quantize(sd[0][prefix + 'self_attention.dense.' + 'weight']) if weight_quantizer.q_int8 else \
weight_quantizer.quantize(transpose(sd[0][prefix + 'self_attention.dense.' + 'weight'])))
module.attention.attn_ob = mp_replace.copy(
module.attention.attn_ob.data,
sd[prefix + 'self_attention.dense.' + 'bias'])
module.mlp.attn_nw.data.copy_(sd[prefix + 'post_attention_layernorm.' +
'weight'])
module.mlp.attn_nb.data.copy_(sd[prefix + 'post_attention_layernorm.' +
'bias'])
module.mlp.inter_w = mp_replace.copy(
module.mlp.inter_w.data,
transpose(sd[prefix + 'mlp.dense_h_to_4h.' + 'weight']))
sd[0][prefix + 'self_attention.dense.' + 'bias'])
module.mlp.attn_nw.data.copy_(sd[0][prefix + 'post_attention_layernorm.' +
'weight'])
module.mlp.attn_nb.data.copy_(sd[0][prefix + 'post_attention_layernorm.' +
'bias'])
module.mlp.inter_w = mp_replace.copy(module.mlp.inter_w,
weight_quantizer.quantize(sd[0][prefix + 'mlp.dense_h_to_4h.' + 'weight']) if weight_quantizer.q_int8 else \
weight_quantizer.quantize(transpose(sd[0][prefix + 'mlp.dense_h_to_4h.' + 'weight'])))
module.mlp.inter_b = mp_replace.copy(
module.mlp.inter_b.data,
sd[prefix + 'mlp.dense_h_to_4h.' + 'bias'])
module.mlp.output_w = mp_replace.copy(
module.mlp.output_w.data,
transpose(sd[prefix + 'mlp.dense_4h_to_h.' + 'weight']))
sd[0][prefix + 'mlp.dense_h_to_4h.' + 'bias'])
module.mlp.output_w = mp_replace.copy(module.mlp.output_w,
weight_quantizer.quantize(sd[0][prefix + 'mlp.dense_4h_to_h.' + 'weight']) if weight_quantizer.q_int8 else \
weight_quantizer.quantize(transpose(sd[0][prefix + 'mlp.dense_4h_to_h.' + 'weight'])))
module.mlp.output_b = mp_replace.copy(
module.mlp.output_b.data,
sd[prefix + 'mlp.dense_4h_to_h.' + 'bias'])
sd[0][prefix + 'mlp.dense_4h_to_h.' + 'bias'])
layer_policies = {
nn.Linear: load,
......@@ -117,7 +184,7 @@ def load_model_with_checkpoint(r_module, sd, mp_replace, ckpt_type, rank=0):
for name, child in module.named_children():
if child.__class__ in layer_policies:
checking_key = prefix + name + '.'
if not any(checking_key in item for item in sd.keys()):
if not any(checking_key in item for item in sd[0].keys()):
if hasattr(child, 'weight') and \
(hasattr(child.weight, 'ds_id') and \
child.weight.ds_id in all_ds_ids):
......@@ -168,6 +235,7 @@ def load_model_with_checkpoint(r_module, sd, mp_replace, ckpt_type, rank=0):
embedding_weight = p
assert hasattr(r_module, 'lm_head'), "attempting to set lm_head but it doesn't exist"
r_module.lm_head.weight = embedding_weight
del sd
for sd_ in sd:
del sd_
sd = None
gc.collect()
......@@ -5,7 +5,7 @@ import deepspeed
import deepspeed.ops.transformer as transformer_inference
from .replace_policy import HFBertLayerPolicy, HFGPT2LayerPolicy, BLOOMLayerPolicy
from .replace_policy import replace_policies
from ..runtime.weight_quantizer import WeightQuantization
#from ..runtime.weight_quantizer import WeightQuantization
from deepspeed import comm as dist
from torch import nn
......@@ -115,8 +115,10 @@ class ReplaceWithTensorSlicing:
dst_shape[-1])[self.gpu_index].to(
torch.cuda.current_device()).contiguous()
dst.data.copy_(bias_split)
return torch.nn.parameter.Parameter(dst, requires_grad=False)
dst = torch.nn.parameter.Parameter(dst, requires_grad=False)
if hasattr(src, 'scale'):
dst.scale = src.scale
return dst
def get_transformer_name(replaced_module):
......@@ -134,6 +136,57 @@ def get_transformer_name(replaced_module):
return transformer_name
class GroupQuantizer:
def __init__(self, q_int8=True, num_groups=32, group_size=32, num_bits=8):
self.num_groups = num_groups
self.group_size = group_size
self.num_bits = num_bits
self.q_int8 = q_int8
def quantize(self, inputs, qkv=True, count=1, parallel_dim=0):
if not self.q_int8 or not qkv:
inputs = torch.nn.Parameter(inputs, requires_grad=False)
inputs.scale = torch.empty(1)
return inputs
q_range = 2**self.num_bits
inputs = inputs.to(torch.cuda.current_device())
input_flat = inputs.reshape(self.num_groups, -1).contiguous()
input_min = torch.min(input_flat, dim=1, keepdim=True)[0].float()
input_max = torch.max(input_flat, dim=1, keepdim=True)[0].float()
scale = torch.max(input_min.abs(), input_max.abs()) * 2.0 / (q_range)
input_flat = (input_flat / scale).round().clamp(-q_range // 2, q_range // 2 - 1)
inputs_q = input_flat.reshape(inputs.shape).to(torch.int8).contiguous()
out = torch.nn.Parameter(inputs_q, requires_grad=False)
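        # Besides the full-tensor group scales, also compute scales for each half of the
        # tensor along parallel_dim; these extra scales appear to be what the checkpoint
        # loader uses when re-sharding a quantized weight across tensor-parallel ranks.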
#print(inputs.shape)
inputs_split = inputs.split(inputs.shape[parallel_dim] // 2, dim=parallel_dim)
input_flat = [
inputs_split[i].reshape(self.num_groups,
-1).contiguous() for i in range(2)
]
input_min = [
torch.min(input_flat[i],
dim=1,
keepdim=True)[0].float() for i in range(2)
]
input_max = [
torch.max(input_flat[i],
dim=1,
keepdim=True)[0].float() for i in range(2)
]
scale1 = [
(torch.max(input_min[i].abs(),
input_max[i].abs()) * 2.0 / (q_range)).squeeze().unsqueeze(0)
for i in range(2)
]
out.scale = torch.cat([scale.squeeze().unsqueeze(0),
scale1[0],
scale1[1]],
dim=0).reshape(self.num_groups,
-1).contiguous()
return out
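
A minimal usage sketch of the quantizer above (assuming a CUDA device and an arbitrary 2-D FP16 weight; the sizes are illustrative):

    quantizer = GroupQuantizer(q_int8=True, num_groups=32)
    w = torch.randn(4096, 4096, dtype=torch.half)
    w_q = quantizer.quantize(w)   # int8 nn.Parameter on the current CUDA device
    # w_q.scale has shape (num_groups, 3): scales for the full tensor plus its two
    # halves along parallel_dim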
def replace_transformer_layer(orig_layer_impl,
model,
policy=None,
......@@ -161,7 +214,8 @@ def replace_transformer_layer(orig_layer_impl,
moe_experts=1,
moe_type='standard',
checkpoint_dict=None,
save_mp_checkpoint_path=None):
save_mp_checkpoint_path=None,
base_dir=""):
""" Replace bert-style transformer layers with DeepSpeed's transformer layer
Arguments:
orig_layer_impl (torch.nn.Module): the original transformer layer implementation to look for,
......@@ -225,7 +279,7 @@ def replace_transformer_layer(orig_layer_impl,
_res_h4h_w, _res_h4h_b, _res_4hh_w, _res_4hh_b, _res_coef = policy.mlp(moe_type)
attn_nw, attn_nb, input_nw, input_nb = policy.layerNorm()
if quantize:
if False:
if policy_cls is not HFBertLayerPolicy:
qkvw = qkvw.to(torch.int8)
dense_w = dense_w.to(torch.int8)
......@@ -257,6 +311,7 @@ def replace_transformer_layer(orig_layer_impl,
#expert_mp_replace = ReplaceWithTensorSlicing(mp_group=expert_mp_group)
quantizer = GroupQuantizer(q_int8=quantize)
if inference:
if moe:
ep_world_size = dist.get_world_size()
......@@ -329,21 +384,21 @@ def replace_transformer_layer(orig_layer_impl,
new_module = transformer_inference.DeepSpeedTransformerInference(
transformer_config,
mp_group=mp_group,
quantize_scales=quantization_scales[layer_id],
#quantize_scales=quantization_scales[layer_id],
quantize_groups=quantize_groups,
merge_count=merge_count,
mlp_extra_grouping=mlp_extra_grouping,
qkv_merging=(policy_cls is HFBertLayerPolicy))
if quantize and qkvw.dtype != torch.int8:
quantize_bits = 8
quantizer = WeightQuantization()
if policy_cls is HFBertLayerPolicy:
data_quantized, _ = quantizer.quantize_data(qkvw.data, quantize_bits, quantize_groups * 3)
else:
data_quantized, _ = quantizer.quantize_data(qkvw.data, quantize_bits, quantize_groups)
qkvw.data.copy_(data_quantized)
qkvw.data = qkvw.data.to(torch.int8)
#if quantize and qkvw.dtype != torch.int8:
# quantize_bits = 8
# quantizer = WeightQuantization()
# if policy_cls is HFBertLayerPolicy:
# data_quantized, _ = quantizer.quantize_data(qkvw.data, quantize_bits, quantize_groups * 3)
# else:
# data_quantized, _ = quantizer.quantize_data(qkvw.data, quantize_bits, quantize_groups)
# qkvw.data.copy_(data_quantized)
# qkvw.data = qkvw.data.to(torch.int8)
else:
if moe:
......@@ -478,18 +533,17 @@ def replace_transformer_layer(orig_layer_impl,
attn_block.attn_ow = mp_replace.copy(attn_block.attn_ow, dense_w)
attn_block.attn_ob = mp_replace.copy(attn_block.attn_ob, dense_b)
else:
if bigscience_bloom:
attn_block.attn_qkvw = mp_replace.copy(attn_block.attn_qkvw, qkvw)
attn_block.attn_qkvb = mp_replace.copy(attn_block.attn_qkvb, qkvb)
else:
attn_block.attn_qkvw = mp_replace.qkv_copy(
attn_block.attn_qkvw,
qkvw)
attn_block.attn_qkvb = mp_replace.qkv_copy(
attn_block.attn_qkvb,
qkvb)
attn_block.attn_ow = mp_replace.copy(attn_block.attn_ow, dense_w)
attn_block.attn_qkvw = quantizer.quantize(
mp_replace.copy(attn_block.attn_qkvw, qkvw) if bigscience_bloom else \
mp_replace.qkv_copy(attn_block.attn_qkvw, qkvw))
attn_block.attn_qkvb = \
mp_replace.copy(attn_block.attn_qkvb, qkvb) if bigscience_bloom else \
mp_replace.qkv_copy(attn_block.attn_qkvb, qkvb)
attn_block.attn_ow = quantizer.quantize(
mp_replace.copy(attn_block.attn_ow,
dense_w))
attn_block.attn_ob = mp_replace.copy(attn_block.attn_ob, dense_b)
if moe:
......@@ -545,9 +599,13 @@ def replace_transformer_layer(orig_layer_impl,
mpl_block.output_b,
_4hh_b)
else:
mpl_block.inter_w = mp_replace.copy(mpl_block.inter_w, _h4h_w)
mpl_block.inter_w = quantizer.quantize(
mp_replace.copy(mpl_block.inter_w,
_h4h_w))
mpl_block.inter_b = mp_replace.copy(mpl_block.inter_b, _h4h_b)
mpl_block.output_w = mp_replace.copy(mpl_block.output_w, _4hh_w)
mpl_block.output_w = quantizer.quantize(
mp_replace.copy(mpl_block.output_w,
_4hh_w))
mpl_block.output_b = mp_replace.copy(mpl_block.output_b, _4hh_b)
if attn_nw is None:
......@@ -782,50 +840,92 @@ def replace_transformer_layer(orig_layer_impl,
replace_fn=replace_fn,
_replace_policy=policy)
quantizer = GroupQuantizer(q_int8=quantize)
world_size = dist.get_world_size() if dist.is_initialized() else 1
rank = dist.get_rank() if dist.is_initialized() else 0
if checkpoint_dict is not None:
start_time = time.time()
checkpoint = checkpoint_dict['checkpoints']
ckpt_list = checkpoint["tp"] if type(checkpoint) is dict else checkpoint
ckpt_type = checkpoint_dict.get('parallelization', 'pp')
ckpt_mp_size = checkpoint_dict.get('mp_size', mp_size)
base_dir = checkpoint_dict.get('base_dir', '')
ckpt_mp_size = checkpoint_dict.get('tp_size', len(ckpt_list))
ckpt_mp_size = checkpoint_dict.get('mp_size', ckpt_mp_size)
base_dir1 = checkpoint_dict.get('base_dir', base_dir)
if ckpt_type == 'pp':
if ckpt_type == 'pp' and type(checkpoint) is list:
pbar = tqdm.tqdm(total=len(checkpoint),
desc=f"Loading {len(checkpoint)} checkpoint shards")
for i in range(len(checkpoint)):
if not deepspeed.comm.is_initialized() or deepspeed.comm.get_rank() == 0:
pbar.update(1)
sd = torch.load(checkpoint[i], map_location='cpu')
load_model_with_checkpoint(replaced_module, sd, mp_replace, ckpt_type)
sd = [
torch.load(os.path.join(base_dir1,
checkpoint[i]),
map_location='cpu')
]
load_model_with_checkpoint(
replaced_module,
sd,
mp_replace,
ckpt_type,
quantizer,
)
else:
num_checkpoints = len(checkpoint) // ckpt_mp_size
assert world_size >= ckpt_mp_size,\
"Currently, merging checkpoints is not supported (when world_size is smaller than #checkpoints)!"
checkpoint_stride = world_size // ckpt_mp_size
if not deepspeed.comm.is_initialized() or deepspeed.comm.get_rank() == 0:
pbar = tqdm.tqdm(total=num_checkpoints,
desc=f"Loading {num_checkpoints} checkpoint shards")
import gc
num_checkpoints = len(ckpt_list) // ckpt_mp_size
tp_split_size = (world_size / ckpt_mp_size)
sd_offset = int(rank / tp_split_size)
sd_count = int((rank + max(1, tp_split_size)) / tp_split_size) - sd_offset
pbar = tqdm.tqdm(total=num_checkpoints,
desc=f"Loading {num_checkpoints} checkpoint shards")
for i in range(num_checkpoints):
if not deepspeed.comm.is_initialized() or deepspeed.comm.get_rank() == 0:
pbar.update(1)
ckpt_index = i * ckpt_mp_size + (rank // checkpoint_stride)
ckpt_file = os.path.join(
base_dir,
checkpoint[ckpt_index]) if base_dir else checkpoint[ckpt_index]
sd = torch.load(ckpt_file, map_location='cpu')
pbar.update(1)
ckpt_index = i * ckpt_mp_size + sd_offset
ckpt_files = [
os.path.join(base_dir1,
ckpt_list[ckpt_index +
j]) if base_dir1 else ckpt_list[ckpt_index +
j]
for j in range(sd_count)
]
sds = [
torch.load(ckpt_file,
map_location='cpu') for ckpt_file in ckpt_files
]
load_model_with_checkpoint(replaced_module,
sd,
sds,
mp_replace,
ckpt_type,
rank % (world_size // ckpt_mp_size))
quantizer,
int(rank % tp_split_size))
sds = [None for _ in sds]
gc.collect()
if "non_tp" in checkpoint:
pbar = tqdm.tqdm(
total=len(checkpoint["non_tp"]),
desc=f"Loading {len(checkpoint['non_tp'])} checkpoint shards")
for i in range(len(checkpoint["non_tp"])):
pbar.update(1)
ckpt_file = os.path.join(base_dir1,
checkpoint["non_tp"][i]
) if base_dir1 else checkpoint["non_tp"][i]
sds = [torch.load(ckpt_file, map_location='cpu')]
load_model_with_checkpoint(replaced_module,
sds,
mp_replace,
ckpt_type,
quantizer,
int(rank % tp_split_size))
sds = [None for _ in sds]
gc.collect()
print(f"checkpoint loading time at rank {rank}: {time.time()-start_time} sec")
if save_mp_checkpoint_path is not None:
from collections import OrderedDict
import json
num_partitions = 8
if checkpoint_dict is None:
ckpt_name = "ds_model"
......@@ -840,8 +940,8 @@ def replace_transformer_layer(orig_layer_impl,
if dist.is_initialized():
dist.barrier()
transformer_name = get_transformer_name(replaced_module)
non_tp_ckpt_name = f'{ckpt_name}-non-tp.pt'
ckpt_files = [non_tp_ckpt_name] * world_size
non_tp_ckpt_name = f'non-tp.pt'
ckpt_files = [non_tp_ckpt_name]
os.makedirs(save_mp_checkpoint_path, exist_ok=True)
if not dist.is_initialized() or dist.get_rank() == 0:
print("Saving tp-sharded checkpoints")
......@@ -853,25 +953,47 @@ def replace_transformer_layer(orig_layer_impl,
if transformer_name not in k
}),
f'{save_mp_checkpoint_path}/{non_tp_ckpt_name}')
ckpt_files += [f'{ckpt_name}-tp_{r:0>2d}.pt' for r in range(world_size)]
config = json.dumps({
'type': ckpt_name,
'base_dir': f'{save_mp_checkpoint_path}',
'checkpoints': ckpt_files,
'version': 1.0,
'parallelization': 'tp',
'mp_size': world_size
'type':
ckpt_name,
'base_dir':
f'{save_mp_checkpoint_path}',
'checkpoints': {
"non_tp":
ckpt_files,
"tp": [
f'tp_{r:0>2d}_{m:0>2d}.pt' for m in range(num_partitions)
for r in range(world_size)
]
},
'version':
1.0,
'parallelization':
'tp',
'tp_size':
world_size,
'dtype':
'int8' if quantize else ('float16' if fp16 else 'float32')
})
with open(f"{save_mp_checkpoint_path}/{ckpt_name}_ds-inference_config.json",
"w") as cfg:
with open(f"{save_mp_checkpoint_path}/ds-inference_config.json", "w") as cfg:
cfg.write(config)
torch.save(
OrderedDict({
k: v
for k,
v in dict(replaced_module.state_dict()).items() if transformer_name in k
}),
f'{save_mp_checkpoint_path}/{ckpt_name}-tp_{rank:0>2d}.pt')
rep_sd = replaced_module.state_dict()
for n, p in replaced_module.named_parameters():
if hasattr(p, 'scale'):
rep_sd[n] = [p, p.scale]
keys = list(rep_sd.keys())
partition_size = (len(keys) // num_partitions + 1)
for m in range(num_partitions):
torch.save(
OrderedDict({
k: [rep_sd[k],
rep_sd[k].scale] if hasattr(rep_sd[k],
'scale') else rep_sd[k]
for k in keys[m * partition_size:(m + 1) * partition_size]
if transformer_name in k
}),
f'{save_mp_checkpoint_path}/tp_{rank:0>2d}_{m:0>2d}.pt')
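        # Resulting on-disk layout (illustrative, e.g. world_size = 2, num_partitions = 8):
        #   save_mp_checkpoint_path/
        #     ds-inference_config.json      (type, base_dir, checkpoints {non_tp, tp}, tp_size, dtype)
        #     non-tp.pt                     (parameters outside the transformer layers)
        #     tp_00_00.pt ... tp_01_07.pt   (transformer parameters, sharded by rank and partition)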
return replaced_module
......
......@@ -206,16 +206,6 @@ class DeepSpeedSelfAttentionFunction(Function):
value_layer) = split_tensor_along_last_dim(mixed_x_layer,
3)
if layer_past is not None:
past_key, past_value = layer_past
# concatenate along seq_length dimension -> [batch_size, qk_length, num_heads, head_dim]
key_layer = torch.cat((past_key.type_as(key_layer), key_layer), dim=1)
value_layer = torch.cat((past_value.type_as(value_layer),
value_layer),
dim=1)
presents = (key_layer, value_layer)
# [batch_size, head_dim, q_length, k_length]
output_size = (query_layer.size(0),
query_layer.size(2),
......@@ -223,24 +213,37 @@ class DeepSpeedSelfAttentionFunction(Function):
key_layer.size(1))
# [batch_size, q_length, num_heads, head_dim] -> [q_length, batch_size * num_heads, head_dim]
query_layer = query_layer.transpose(1,
0).reshape(
output_size[2],
2).reshape(
output_size[0] * output_size[1],
output_size[2],
-1)
# [batch_size, k_length, num_heads, head_dim] -> [k_length, batch_size * num_heads, head_dim]
key_layer = key_layer.transpose(1,
0).reshape(output_size[3],
output_size[0] * output_size[1],
-1)
2).reshape(output_size[0] * output_size[1],
output_size[3],
-1).transpose(-1,
-2)
value_layer = value_layer.transpose(1,
2).reshape(
output_size[0] * output_size[1],
output_size[3],
-1)
if layer_past is not None:
past_key, past_value = layer_past
# concatenate along seq_length dimension -> [batch_size, qk_length, num_heads, head_dim]
key_layer = torch.cat((past_key.type_as(key_layer), key_layer), dim=-1)
value_layer = torch.cat((past_value.type_as(value_layer),
value_layer),
dim=-2)
presents = (key_layer, value_layer)
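                # With this layout the KV cache is kept as key: [b*h, head_dim, k_len] and
                # value: [b*h, k_len, head_dim], which is why past keys are concatenated on
                # dim=-1 and past values on dim=-2.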
# Raw attention scores. [batch_size * num_heads, q_length, k_length]
matmul_result = torch.matmul(query_layer.transpose(1,
0),
key_layer.transpose(1,
0).transpose(1,
2))
matmul_result = torch.matmul(query_layer, key_layer)
# change view to [batch_size, num_heads, q_length, k_length]
attention_scores = matmul_result.view(*output_size)
attention_scores = matmul_result.view(output_size[0],
output_size[1],
output_size[2],
-1)
offset = dist.get_rank(
) * num_attention_heads_per_partition if dist.is_initialized() else 0
......@@ -261,12 +264,7 @@ class DeepSpeedSelfAttentionFunction(Function):
attention_probs_reshaped = attention_probs.view(*matmul_result.shape)
# matmul: [batch_size * num_heads, q_length, head_dim]
context_layer = torch.bmm(
attention_probs_reshaped,
value_layer.transpose(1,
2).reshape(-1,
value_layer.size(1),
value_layer.size(3)))
context_layer = torch.bmm(attention_probs_reshaped, value_layer)
# change view [batch_size, num_heads, q_length, head_dim]
context_layer = context_layer.view(
......@@ -418,15 +416,21 @@ class DeepSpeedSelfAttentionFunction(Function):
qkv_out = qkv_func(
input,
attn_qkvw,
attn_qkvw.scale,
(attn_qkvb if attn_qkvb is not None else norm_b),
norm_w,
norm_b,
config.epsilon,
(attn_qkvb is not None),
1 if config.bigscience_bloom else
DeepSpeedTransformerInference.layer_id)
DeepSpeedTransformerInference.layer_id,
config.q_int8)
context_layer, key_layer, value_layer = compute_attention(qkv_out[0] if isinstance(qkv_out, list) else qkv_out, input_mask)
output = vector_matmul_func(context_layer, attn_ow, False)
output = vector_matmul_func(context_layer,
attn_ow,
False,
attn_ow.scale,
config.q_int8)
return output, key_layer, value_layer, context_layer, qkv_out[-1]
......@@ -458,7 +462,7 @@ class DeepSpeedSelfAttentionFunction(Function):
(merge_count))
return output, key_layer, value_layer, context_layer
if config.q_int8:
if False: #config.q_int8:
output, key_layer, value_layer, context_layer = selfAttention_int8()
else:
output, key_layer, value_layer, context_layer, inp_norm = selfAttention_fp()
......@@ -486,30 +490,34 @@ class DeepSpeedSelfAttention(nn.Module):
qkv_merging=False):
super(DeepSpeedSelfAttention, self).__init__()
self.config = config
data_type = torch.half if config.fp16 else torch.float
data_type = torch.int8 if config.q_int8 else torch.half if config.fp16 else torch.float
data_type_fp = torch.half if config.fp16 else torch.float
self.config.layer_id = DeepSpeedSelfAttention.num_layers
DeepSpeedSelfAttention.num_layers = DeepSpeedSelfAttention.num_layers + 1
device = torch.cuda.current_device() if config.bigscience_bloom else 'cpu'
self.attn_qkvw = nn.Parameter(
torch.empty(self.config.hidden_size,
(self.config.hidden_size // self.config.mp_size) * 3,
dtype=data_type,
device=device))
self.attn_qkvb = nn.Parameter(
torch.empty((self.config.hidden_size // self.config.mp_size) * 3,
dtype=data_type,
device=device))
self.attn_ow = nn.Parameter(
torch.empty(self.config.hidden_size // self.config.mp_size,
self.config.hidden_size,
dtype=data_type,
device=device))
self.attn_ob = nn.Parameter(
torch.empty(self.config.hidden_size,
dtype=data_type,
device=device))
self.attn_qkvw = nn.Parameter(torch.empty(
self.config.hidden_size,
(self.config.hidden_size // self.config.mp_size) * 3,
dtype=data_type,
device=device),
requires_grad=False)
self.attn_qkvb = nn.Parameter(torch.empty(
(self.config.hidden_size // self.config.mp_size) * 3,
dtype=data_type_fp,
device=device),
requires_grad=False)
self.attn_ow = nn.Parameter(torch.empty(self.config.hidden_size //
self.config.mp_size,
self.config.hidden_size,
dtype=data_type,
device=device),
requires_grad=False)
self.attn_ob = nn.Parameter(torch.empty(self.config.hidden_size,
dtype=data_type_fp,
device=device),
requires_grad=False)
self.num_attention_heads_per_partition = self.config.heads // self.config.mp_size
self.hidden_size_per_partition = self.config.hidden_size // self.config.mp_size
......@@ -595,36 +603,16 @@ class DeepSpeedMLPFunction(Function):
bias_residual_func,
activation_func_type=ActivationFuncType.GELU):
if config.q_int8:
(intermediate,
residual_add) = inference_cuda_module.mlp_gemm_int8(
input,
residual,
bias,
inter_w,
inter_b,
attn_nw,
attn_nb,
config.epsilon,
q_scales[2],
(q_groups * (2**merge_count)),
config.pre_layer_norm)
output = inference_cuda_module.vector_matmul_int8(intermediate,
output_w,
q_scales[3],
q_groups,
(merge_count))
if attn_nw is None:
output = fused_gemm_gelu(residual_norm,
inter_w,
inter_b,
output_w,
config.epsilon,
config.pre_layer_norm,
False)
else:
if attn_nw is None:
output = fused_gemm_gelu(residual_norm,
inter_w,
inter_b,
output_w,
config.epsilon,
config.pre_layer_norm,
False)
else:
intermediate, residual_add = mlp_gemm_func(input,
intermediate, residual_add = mlp_gemm_func(input,
residual,
bias,
inter_w,
......@@ -634,9 +622,14 @@ class DeepSpeedMLPFunction(Function):
config.epsilon,
config.pre_layer_norm,
config.mlp_after_attn,
inter_w.scale,
config.q_int8,
config.mlp_act_func_type)
output = vector_matmul_func(intermediate, output_w, False)
output = vector_matmul_func(intermediate,
output_w,
False,
output_w.scale,
config.q_int8)
inference_cuda_module.residual_add(
output,
residual if config.pre_layer_norm else residual_add,
......@@ -668,34 +661,38 @@ class DeepSpeedMLP(nn.Module):
super(DeepSpeedMLP, self).__init__()
self.config = config
data_type = torch.half if config.fp16 else torch.float
data_type = torch.int8 if config.q_int8 else torch.half if config.fp16 else torch.float
data_type_fp = torch.half if config.fp16 else torch.float
device = torch.cuda.current_device() if config.bigscience_bloom else 'cpu'
self.attn_nw = nn.Parameter(
torch.empty(self.config.hidden_size,
dtype=data_type,
device=device))
self.attn_nb = nn.Parameter(
torch.empty(self.config.hidden_size,
dtype=data_type,
device=device))
self.inter_w = nn.Parameter(
torch.empty(self.config.hidden_size,
self.config.intermediate_size // self.config.mp_size,
dtype=data_type,
device=device))
self.inter_b = nn.Parameter(
torch.empty(self.config.intermediate_size // self.config.mp_size,
dtype=data_type,
device=device))
self.output_w = nn.Parameter(
torch.empty((self.config.intermediate_size // self.config.mp_size),
self.config.hidden_size,
dtype=data_type,
device=device))
self.output_b = nn.Parameter(
torch.empty(self.config.hidden_size,
dtype=data_type,
device=device))
self.attn_nw = nn.Parameter(torch.empty(self.config.hidden_size,
dtype=data_type_fp,
device=device),
requires_grad=False)
self.attn_nb = nn.Parameter(torch.empty(self.config.hidden_size,
dtype=data_type_fp,
device=device),
requires_grad=False)
self.inter_w = nn.Parameter(torch.empty(self.config.hidden_size,
self.config.intermediate_size //
self.config.mp_size,
dtype=data_type,
device=device),
requires_grad=False)
self.inter_b = nn.Parameter(torch.empty(self.config.intermediate_size //
self.config.mp_size,
dtype=data_type_fp,
device=device),
requires_grad=False)
self.output_w = nn.Parameter(torch.empty(
(self.config.intermediate_size // self.config.mp_size),
self.config.hidden_size,
dtype=data_type,
device=device),
requires_grad=False)
self.output_b = nn.Parameter(torch.empty(self.config.hidden_size,
dtype=data_type_fp,
device=device),
requires_grad=False)
# used for quantization
self.q_scales = q_scales
......@@ -790,14 +787,14 @@ class DeepSpeedTransformerInference(nn.Module):
mlp_extra_grouping)
device = torch.cuda.current_device() if config.bigscience_bloom else 'cpu'
self.norm_w = nn.Parameter(
torch.empty(self.config.hidden_size,
dtype=data_type,
device=device))
self.norm_b = nn.Parameter(
torch.empty(self.config.hidden_size,
dtype=data_type,
device=device))
self.norm_w = nn.Parameter(torch.empty(self.config.hidden_size,
dtype=data_type,
device=device),
requires_grad=False)
self.norm_b = nn.Parameter(torch.empty(self.config.hidden_size,
dtype=data_type,
device=device),
requires_grad=False)
self.layer_past = None
def forward(
......@@ -826,7 +823,6 @@ class DeepSpeedTransformerInference(nn.Module):
# We set the prev key/value to None when there is a prompt
if input.shape[1] > 1:
self.layer_past = None
layer_past = layer_past if layer_past is not None else self.layer_past
head_mask = layer_head_mask if layer_head_mask is not None else head_mask
......