未验证 提交 9cfcf743 编写于 作者: Reza Yazdani 提交者: GitHub

Add correct memory-allocation at DeepSpeed-Attention (#2474)

Co-authored-by: Jeff Rasley <jerasley@microsoft.com>
Co-authored-by: Connor Holmes <connorholmes@microsoft.com>
上级 a47c3e03
......@@ -849,20 +849,6 @@ at::Tensor ds_linear_layer(at::Tensor& input,
int head_size = input_cont.size(2) / num_heads;
int bsz = input.size(0) * input.size(1);
T* workspace = (T*)Context::Instance().GetWorkSpace();
// Reallocate memory if we received a new prompt
if (!workspace) {
cublasSetStream(Context::Instance().GetCublasHandle(),
Context::Instance().GetCurrentStream());
allocate_workspace<T>(input.size(2),
input.size(0),
input.size(1),
num_layers,
num_heads,
1,
external_cache,
0);
workspace = (T*)Context::Instance().GetWorkSpace();
}
auto output = at::from_blob(workspace, {input.size(0), input.size(1), weight.size(1)}, options);
float alpha = (T)1.0;
......
......@@ -210,6 +210,7 @@ def generic_injection(module, fp16=False):
heads=heads,
fp16=fp16,
triangular_masking=False,
max_out_tokens=4096,
)
attn_module = transformer_inference.DeepSpeedAttention(config)
......
......@@ -233,6 +233,9 @@ class DeepSpeedAttention(nn.Module):
inference_cuda_module.linear_layer_fp32
self.cuda_graph_created = False
self.enable_cuda_graph = False
self.allocate_workspace = inference_cuda_module.allocate_workspace_fp32 if (not config.fp16) else \
inference_cuda_module.allocate_workspace_fp16
self.iter = 0
def _graph_replay(self, *inputs, **kwargs):
for i in range(len(inputs)):
......@@ -275,6 +278,18 @@ class DeepSpeedAttention(nn.Module):
return outputs
def _forward(self, input, context=None, input_mask=None):
# Allocate memory only on first layer forward
if self.config.layer_id == 0 and self.iter == 0:
self.iter += 1
self.allocate_workspace(self.config.hidden_size,
input.size()[0],
input.size()[1],
DeepSpeedAttention.layer_id,
self.config.heads,
self.config.mp_size,
self.config.bigscience_bloom,
0,
self.config.max_out_tokens)
output = DeepSpeedAttentionFunction.apply(input,
context,
input_mask,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册