From 145c3a75916cc39de9048e6c5b415fac6d634896 Mon Sep 17 00:00:00 2001
From: Connor Holmes
Date: Thu, 20 Apr 2023 17:38:02 -0700
Subject: [PATCH] Fix missing scale attributes for GPTJ (#3256)

Co-authored-by: Olatunji Ruwase
Co-authored-by: Michael Wyatt
---
 csrc/transformer/inference/csrc/pt_binding.cpp              | 4 ++--
 deepspeed/ops/transformer/inference/op_binding/gelu_gemm.py | 4 +++-
 2 files changed, 5 insertions(+), 3 deletions(-)

diff --git a/csrc/transformer/inference/csrc/pt_binding.cpp b/csrc/transformer/inference/csrc/pt_binding.cpp
index 3de59e11..88d4201b 100644
--- a/csrc/transformer/inference/csrc/pt_binding.cpp
+++ b/csrc/transformer/inference/csrc/pt_binding.cpp
@@ -462,9 +462,9 @@ std::vector<at::Tensor> ds_softmax_context(at::Tensor& query_key_value,
 
     T* workspace = (T*)InferenceContext::Instance().GetWorkSpace();
     size_t buf_size = bsz * seq_len * hidden_dim;
-    auto output = torch::from_blob(workspace + 3 * buf_size, {bsz, seq_len, hidden_dim}, options);
+    auto output = torch::from_blob(workspace + 4 * buf_size, {bsz, seq_len, hidden_dim}, options);
 
-    auto query_cont = workspace + 4 * buf_size;
+    auto query_cont = workspace + 5 * buf_size;
     size_t offset = 10 * (hidden_dim * bsz * InferenceContext::Instance().GetMaxTokenLenght()) +
                     layer_id * 2 * bsz * InferenceContext::Instance().GetMaxTokenLenght() * hidden_dim;
 
diff --git a/deepspeed/ops/transformer/inference/op_binding/gelu_gemm.py b/deepspeed/ops/transformer/inference/op_binding/gelu_gemm.py
index 89ef0b51..6fcba9a5 100644
--- a/deepspeed/ops/transformer/inference/op_binding/gelu_gemm.py
+++ b/deepspeed/ops/transformer/inference/op_binding/gelu_gemm.py
@@ -23,7 +23,9 @@ class GELUGemmOp(BaseOp):
                 bias: torch.Tensor,
                 weight_out: torch.Tensor,
                 async_op: bool = False):
-        output = self.fused_gemm_gelu(input, weight, weight.scale, bias, weight_out, weight_out.scale,
+        output = self.fused_gemm_gelu(input, weight, weight.scale if hasattr(weight, "scale") else torch.empty(1),
+                                      bias, weight_out,
+                                      weight_out.scale if hasattr(weight_out, "scale") else torch.empty(1),
                                       self.config.epsilon, self.config.pre_layer_norm, self.config.q_int8,
                                       async_op, self.config.transposed_mode)
        return output
--
GitLab
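
Note on the Python-side change: quantized weights in DeepSpeed carry a `scale` attribute attached to the tensor, while plain fp16/fp32 weights (e.g. an unquantized GPT-J checkpoint) do not, so the unguarded `weight.scale` raised AttributeError. Below is a minimal runnable sketch of the guard pattern, not DeepSpeed's actual kernel: `fake_fused_gemm_gelu` is a hypothetical stand-in, and only the `hasattr(...) else torch.empty(1)` expression comes from the patch itself.

    import torch

    def fake_fused_gemm_gelu(inp, weight, weight_scale, *args):
        # Hypothetical stand-in for the real fused kernel. A real int8 path
        # would use weight_scale to dequantize; here we only show the call
        # surviving weights that lack a .scale attribute.
        return torch.nn.functional.gelu(inp @ weight.t())

    weight = torch.randn(8, 8)  # plain fp32 weight: no .scale attribute
    # The patched guard: weight.scale would raise AttributeError here, so a
    # one-element placeholder tensor is passed instead.
    scale = weight.scale if hasattr(weight, "scale") else torch.empty(1)
    out = fake_fused_gemm_gelu(torch.randn(2, 8), weight, scale)
    print(out.shape)  # torch.Size([2, 8])

The placeholder is presumably never read on the non-quantized path, since dequantization is gated by `self.config.q_int8`; it only exists to satisfy the fused kernel's fixed argument list.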