From 63022e66de304789492cc2e7b0825104334c4143 Mon Sep 17 00:00:00 2001 From: Varuna Jayasiri Date: Fri, 19 Aug 2022 09:58:29 +0530 Subject: [PATCH] fix --- labml_nn/neox/model.py | 24 ++++++++++++------------ 1 file changed, 12 insertions(+), 12 deletions(-) diff --git a/labml_nn/neox/model.py b/labml_nn/neox/model.py index 66b68b28..1a8fbede 100644 --- a/labml_nn/neox/model.py +++ b/labml_nn/neox/model.py @@ -521,18 +521,18 @@ class LayerGenerator: # with monit.section('Covert to int8'): - layer.attention.output = make_llm_int8_linear(layer.attention.output, - device=self.device, - threshold=self.llm_int8_threshold) - layer.attention.qkv_lin = make_llm_int8_linear(layer.attention.qkv_lin, - device=self.device, - threshold=self.llm_int8_threshold) - layer.ffn.dense_h_h4 = make_llm_int8_linear(layer.ffn.dense_h_h4, - device=self.device, - threshold=self.llm_int8_threshold) - layer.ffn.dense_h4_h = make_llm_int8_linear(layer.ffn.dense_h4_h, - device=self.device, - threshold=self.llm_int8_threshold) + layer.attention._modules['output'] = make_llm_int8_linear(layer.attention.output, + device=self.device, + threshold=self.llm_int8_threshold) + layer.attention._modules['qkv_lin'] = make_llm_int8_linear(layer.attention.qkv_lin, + device=self.device, + threshold=self.llm_int8_threshold) + layer.ffn._modules['dense_h_h4'] = make_llm_int8_linear(layer.ffn.dense_h_h4, + device=self.device, + threshold=self.llm_int8_threshold) + layer.ffn._modules['dense_h4_h'] = make_llm_int8_linear(layer.ffn.dense_h4_h, + device=self.device, + threshold=self.llm_int8_threshold) # return layer -- GitLab