diff --git a/csrc/transformer/inference/csrc/pt_binding.cpp b/csrc/transformer/inference/csrc/pt_binding.cpp index cd3fff9e81d1bae8ea07008d376465cc59962e57..97db77bff8a24f6e4e4b9a728d153866f9a79bcd 100644 --- a/csrc/transformer/inference/csrc/pt_binding.cpp +++ b/csrc/transformer/inference/csrc/pt_binding.cpp @@ -1109,8 +1109,9 @@ at::Tensor ds_linear_layer(at::Tensor& input, int head_size = input_cont.size(2) / num_heads; int bsz = input.size(0) * input.size(1); + int out_size = transposed_mode ? weight.size(0) : weight.size(1); T* workspace = (T*)InferenceContext::Instance().GetWorkSpace(); - auto output = at::from_blob(workspace, {input.size(0), input.size(1), weight.size(1)}, options); + auto output = at::from_blob(workspace, {input.size(0), input.size(1), out_size}, options); float alpha = (T)1.0; float gemm_beta = (T)0.0; @@ -1313,7 +1314,7 @@ at::Tensor ds_vector_matmul(at::Tensor& input, .layout(at::kStrided) .device(at::kCUDA) .requires_grad(false); - int out_size = q_int8 ? weight.size(0) : weight.size(1); + int out_size = (q_int8 || transposed_mode) ? weight.size(0) : weight.size(1); int bsz = input.size(0) * input.size(1); T* workspace = (T*)InferenceContext::Instance().GetWorkSpace();