diff --git a/csrc/transformer/inference/csrc/pt_binding.cpp b/csrc/transformer/inference/csrc/pt_binding.cpp index 65549cdcd71a1633983aa538b7de6ea15819c91b..ba9ecd72777bfa7de8426fd4a3b4f87b4c1214d3 100644 --- a/csrc/transformer/inference/csrc/pt_binding.cpp +++ b/csrc/transformer/inference/csrc/pt_binding.cpp @@ -1083,8 +1083,8 @@ template std::vector ds_mlp_gemm(at::Tensor& input, at::Tensor& residual, at::Tensor& input_bias, - at::Tensor& weight, - at::Tensor& weight1, + at::Tensor& weight_interm, + at::Tensor& weight_out, at::Tensor& bias, at::Tensor& gamma, at::Tensor& beta, @@ -1102,7 +1102,7 @@ std::vector ds_mlp_gemm(at::Tensor& input, .device(at::kCUDA) .requires_grad(false); - int out_size = q_int8 ? weight.size(0) : weight.size(1); + int out_size = q_int8 ? weight_out.size(0) : weight_out.size(1); auto output = at::from_blob((T*)Context::Instance().GetWorkSpace() + torch::numel(input), {input.size(0), input.size(1), out_size}, options); @@ -1113,8 +1113,8 @@ std::vector ds_mlp_gemm(at::Tensor& input, mlp_after_attn ? input : residual, residual, input_bias, - weight, - weight1, + weight_interm, + weight_out, bias, gamma, beta,