未验证 提交 0a2ae2ef 编写于 作者: A Arash Bakhtiari 提交者: GitHub

Fix the MLP output tensor's shape (#2380)

上级 ff427438
@@ -1083,8 +1083,8 @@ template <typename T>
 std::vector<at::Tensor> ds_mlp_gemm(at::Tensor& input,
                                     at::Tensor& residual,
                                     at::Tensor& input_bias,
-                                    at::Tensor& weight,
-                                    at::Tensor& weight1,
+                                    at::Tensor& weight_interm,
+                                    at::Tensor& weight_out,
                                     at::Tensor& bias,
                                     at::Tensor& gamma,
                                     at::Tensor& beta,
@@ -1102,7 +1102,7 @@ std::vector<at::Tensor> ds_mlp_gemm(at::Tensor& input,
                                  .device(at::kCUDA)
                                  .requires_grad(false);
-    int out_size = q_int8 ? weight.size(0) : weight.size(1);
+    int out_size = q_int8 ? weight_out.size(0) : weight_out.size(1);
     auto output = at::from_blob((T*)Context::Instance().GetWorkSpace() + torch::numel(input),
                                 {input.size(0), input.size(1), out_size},
                                 options);
@@ -1113,8 +1113,8 @@ std::vector<at::Tensor> ds_mlp_gemm(at::Tensor& input,
                                          mlp_after_attn ? input : residual,
                                          residual,
                                          input_bias,
-                                         weight,
-                                         weight1,
+                                         weight_interm,
+                                         weight_out,
                                          bias,
                                          gamma,
                                          beta,
...
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册