未验证 提交 0a2ae2ef 编写于 作者: A Arash Bakhtiari 提交者: GitHub

Fix the MLP output tensor's shape (#2380)

上级 ff427438
......@@ -1083,8 +1083,8 @@ template <typename T>
std::vector<at::Tensor> ds_mlp_gemm(at::Tensor& input,
at::Tensor& residual,
at::Tensor& input_bias,
- at::Tensor& weight,
- at::Tensor& weight1,
+ at::Tensor& weight_interm,
+ at::Tensor& weight_out,
at::Tensor& bias,
at::Tensor& gamma,
at::Tensor& beta,
......@@ -1102,7 +1102,7 @@ std::vector<at::Tensor> ds_mlp_gemm(at::Tensor& input,
.device(at::kCUDA)
.requires_grad(false);
- int out_size = q_int8 ? weight.size(0) : weight.size(1);
+ int out_size = q_int8 ? weight_out.size(0) : weight_out.size(1);
auto output = at::from_blob((T*)Context::Instance().GetWorkSpace() + torch::numel(input),
{input.size(0), input.size(1), out_size},
options);
......@@ -1113,8 +1113,8 @@ std::vector<at::Tensor> ds_mlp_gemm(at::Tensor& input,
mlp_after_attn ? input : residual,
residual,
input_bias,
- weight,
- weight1,
+ weight_interm,
+ weight_out,
bias,
gamma,
beta,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册