diff --git a/labml_nn/neox/model.py b/labml_nn/neox/model.py
index 66b68b28eb8abbbd5a56aea2aeabf4a66d50aec5..1a8fbede768078433a4120ff162b91bd5e03d169 100644
--- a/labml_nn/neox/model.py
+++ b/labml_nn/neox/model.py
@@ -521,18 +521,18 @@ class LayerGenerator:
 
         #
         with monit.section('Covert to int8'):
-            layer.attention.output = make_llm_int8_linear(layer.attention.output,
-                                                          device=self.device,
-                                                          threshold=self.llm_int8_threshold)
-            layer.attention.qkv_lin = make_llm_int8_linear(layer.attention.qkv_lin,
-                                                           device=self.device,
-                                                           threshold=self.llm_int8_threshold)
-            layer.ffn.dense_h_h4 = make_llm_int8_linear(layer.ffn.dense_h_h4,
-                                                        device=self.device,
-                                                        threshold=self.llm_int8_threshold)
-            layer.ffn.dense_h4_h = make_llm_int8_linear(layer.ffn.dense_h4_h,
-                                                        device=self.device,
-                                                        threshold=self.llm_int8_threshold)
+            layer.attention._modules['output'] = make_llm_int8_linear(layer.attention.output,
+                                                                      device=self.device,
+                                                                      threshold=self.llm_int8_threshold)
+            layer.attention._modules['qkv_lin'] = make_llm_int8_linear(layer.attention.qkv_lin,
+                                                                       device=self.device,
+                                                                       threshold=self.llm_int8_threshold)
+            layer.ffn._modules['dense_h_h4'] = make_llm_int8_linear(layer.ffn.dense_h_h4,
+                                                                    device=self.device,
+                                                                    threshold=self.llm_int8_threshold)
+            layer.ffn._modules['dense_h4_h'] = make_llm_int8_linear(layer.ffn.dense_h4_h,
+                                                                    device=self.device,
+                                                                    threshold=self.llm_int8_threshold)
 
         #
         return layer