From cd911f9ab2213edb0c8781bd5fd604c37c020dfb Mon Sep 17 00:00:00 2001 From: Logan Adams <114770087+loadams@users.noreply.github.com> Date: Tue, 13 Jun 2023 16:51:30 -0700 Subject: [PATCH] Fix output transpose dimension bugs (#3747) --- csrc/transformer/inference/csrc/pt_binding.cpp | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/csrc/transformer/inference/csrc/pt_binding.cpp b/csrc/transformer/inference/csrc/pt_binding.cpp index cd3fff9e..97db77bf 100644 --- a/csrc/transformer/inference/csrc/pt_binding.cpp +++ b/csrc/transformer/inference/csrc/pt_binding.cpp @@ -1109,8 +1109,9 @@ at::Tensor ds_linear_layer(at::Tensor& input, int head_size = input_cont.size(2) / num_heads; int bsz = input.size(0) * input.size(1); + int out_size = transposed_mode ? weight.size(0) : weight.size(1); T* workspace = (T*)InferenceContext::Instance().GetWorkSpace(); - auto output = at::from_blob(workspace, {input.size(0), input.size(1), weight.size(1)}, options); + auto output = at::from_blob(workspace, {input.size(0), input.size(1), out_size}, options); float alpha = (T)1.0; float gemm_beta = (T)0.0; @@ -1313,7 +1314,7 @@ at::Tensor ds_vector_matmul(at::Tensor& input, .layout(at::kStrided) .device(at::kCUDA) .requires_grad(false); - int out_size = q_int8 ? weight.size(0) : weight.size(1); + int out_size = (q_int8 || transposed_mode) ? weight.size(0) : weight.size(1); int bsz = input.size(0) * input.size(1); T* workspace = (T*)InferenceContext::Instance().GetWorkSpace(); -- GitLab