Unverified commit cd911f9a, authored by Logan Adams, committed by GitHub

Fix output transpose dimension bugs (#3747)

Parent 45466afa
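Background (not part of the commit itself): both `ds_linear_layer` and `ds_vector_matmul` size their output buffer from the weight tensor. With the conventional [in_features, out_features] layout the output width is `weight.size(1)`, but when the weight is stored transposed as [out_features, in_features] the width must come from `weight.size(0)`; picking the wrong dimension sizes the output view incorrectly for the GEMM result. The sketch below is a minimal, self-contained illustration assuming libtorch; the names and shapes (`pick_out_size`, the 2x4x8 input) are illustrative only, not DeepSpeed's API.

```cpp
// Minimal sketch (assumed libtorch, hypothetical names) of why the output width
// must come from weight.size(0) when the weight is stored transposed.
#include <torch/torch.h>
#include <iostream>

int main() {
    const int64_t batch = 2, seq = 4, in_features = 8, out_features = 16;

    auto input = torch::randn({batch, seq, in_features});

    // Conventional layout: [in_features, out_features].
    auto weight = torch::randn({in_features, out_features});
    // Transposed layout: [out_features, in_features].
    auto weight_t = weight.t().contiguous();

    bool transposed_mode = true;

    // The dimension choice the commit fixes: in transposed mode the output
    // width is weight.size(0), otherwise weight.size(1).
    auto pick_out_size = [&](const torch::Tensor& w) {
        return transposed_mode ? w.size(0) : w.size(1);
    };

    int64_t out_size = pick_out_size(weight_t);        // 16, not 8
    auto output = torch::matmul(input, weight_t.t());  // [batch, seq, out_features]

    std::cout << "out_size = " << out_size
              << ", output last dim = " << output.size(-1) << std::endl;
    // Both print 16; sizing the buffer with weight_t.size(1) would give 8,
    // too small for what the GEMM actually writes.
    return 0;
}
```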
...
@@ -1109,8 +1109,9 @@ at::Tensor ds_linear_layer(at::Tensor& input,
     int head_size = input_cont.size(2) / num_heads;
     int bsz = input.size(0) * input.size(1);
+    int out_size = transposed_mode ? weight.size(0) : weight.size(1);
     T* workspace = (T*)InferenceContext::Instance().GetWorkSpace();
-    auto output = at::from_blob(workspace, {input.size(0), input.size(1), weight.size(1)}, options);
+    auto output = at::from_blob(workspace, {input.size(0), input.size(1), out_size}, options);
     float alpha = (T)1.0;
     float gemm_beta = (T)0.0;
...
@@ -1313,7 +1314,7 @@ at::Tensor ds_vector_matmul(at::Tensor& input,
                         .layout(at::kStrided)
                         .device(at::kCUDA)
                         .requires_grad(false);
-    int out_size = q_int8 ? weight.size(0) : weight.size(1);
+    int out_size = (q_int8 || transposed_mode) ? weight.size(0) : weight.size(1);
     int bsz = input.size(0) * input.size(1);
     T* workspace = (T*)InferenceContext::Instance().GetWorkSpace();
...
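Read together, the `ds_vector_matmul` change folds `transposed_mode` into the existing int8 branch, which already took the output width from `weight.size(0)`. A hedged restatement of that predicate as a standalone helper (hypothetical name, assuming libtorch):

```cpp
#include <torch/torch.h>

// Hypothetical helper mirroring the corrected predicate in ds_vector_matmul:
// the output width comes from weight.size(0) whenever the weight is stored
// output-major, which the diff treats as the case both for int8-quantized
// weights (q_int8) and for transposed_mode.
int64_t vector_matmul_out_size(const torch::Tensor& weight,
                               bool q_int8,
                               bool transposed_mode) {
    return (q_int8 || transposed_mode) ? weight.size(0) : weight.size(1);
}
```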