From 941f49eb14e4c8182ef9ec424f19c5e8d1c7d7ee Mon Sep 17 00:00:00 2001
From: chenlong
Date: Wed, 26 Jul 2023 09:42:57 +0800
Subject: [PATCH] add codegeex2

---
 README.md                            | 12 +++--
 llm_set/params/bbt.json              |  2 +-
 llm_set/params/codegeex2.json        |  7 +++
 src/generate_batch.sh                |  5 ++
 src/generate_humaneval_x.py          |  5 +-
 src/inference/codegeex2_inference.py | 72 ++++++++++++++++++++++++++++
 src/utils.py                         |  1 -
 7 files changed, 96 insertions(+), 8 deletions(-)
 create mode 100644 llm_set/params/codegeex2.json
 create mode 100644 src/generate_batch.sh
 create mode 100644 src/inference/codegeex2_inference.py

diff --git a/README.md b/README.md
index 7c158de..4fe910a 100644
--- a/README.md
+++ b/README.md
@@ -89,11 +89,13 @@ python evaluate_humaneval_x.py --language_type python
 
 Limited by model inference speed, only the pass@1 metric has been measured so far.
 
-|             | python | java   | cpp    | js     | go      |
-|-------------|--------|--------|--------|--------|---------|
-| chatgpt     | 64.02% | 15.85% | 26.22% | 47.00% | 31.70%  |
-| bbt-7B      | 0.61%  | 1.83%  | 1.22%  | 1.83%  | 0%      |
-| chatglm2-7B | 7.93%  | 5.45%  | 0.61%  | 6.70%  | 1.83%   |
+|              | python | java   | cpp    | js     | go      |
+|--------------|--------|--------|--------|--------|---------|
+| chatgpt      | 64.02% | 15.85% | 26.22% | 47.00% | 31.70%  |
+| bbt-7B       | 0.61%  | 1.83%  | 1.22%  | 1.83%  | 0%      |
+| bbt-13B      | 2.49%  | 0%     | 1.90%  | 1.83%  | 0.61%   |
+| chatglm2-7B  | 7.93%  | 5.45%  | 0.61%  | 6.70%  | 1.83%   |
+| codegeex2-7B | 29.90% | 27.43% | 6.70%  | 24.40% | 17.68%  |
 
 
 
diff --git a/llm_set/params/bbt.json b/llm_set/params/bbt.json
index ab458f3..8a23338 100644
--- a/llm_set/params/bbt.json
+++ b/llm_set/params/bbt.json
@@ -1,5 +1,5 @@
 {
-  "url": "https://codebbt.ssymmetry.com/test3/code_bbt",
+  "url": "https://codebbt.ssymmetry.com/chat/code_bbt",
   "user_id": "csdnTest_0001",
   "stream": 0,
   "timeout": 200,
diff --git a/llm_set/params/codegeex2.json b/llm_set/params/codegeex2.json
new file mode 100644
index 0000000..c6b103e
--- /dev/null
+++ b/llm_set/params/codegeex2.json
@@ -0,0 +1,7 @@
+{
+  "model_path": "../llm_set/models/codegeex2-6b",
+  "device": "cuda:0",
+  "quantize": false,
+  "max_length": 1024,
+  "min_length": 10
+}
\ No newline at end of file
diff --git a/src/generate_batch.sh b/src/generate_batch.sh
new file mode 100644
index 0000000..e5491e8
--- /dev/null
+++ b/src/generate_batch.sh
@@ -0,0 +1,5 @@
+python generate_humaneval_x.py --model_name bbt --language_type python
+python generate_humaneval_x.py --model_name bbt --language_type java
+python generate_humaneval_x.py --model_name bbt --language_type cpp
+python generate_humaneval_x.py --model_name bbt --language_type js
+python generate_humaneval_x.py --model_name bbt --language_type go
\ No newline at end of file
diff --git a/src/generate_humaneval_x.py b/src/generate_humaneval_x.py
index cdffae7..e18598f 100644
--- a/src/generate_humaneval_x.py
+++ b/src/generate_humaneval_x.py
@@ -30,7 +30,7 @@ if __name__ == "__main__":
         "--model_name",
         type=str,
         default="chatgpt",
-        help='supported models in [chatgpt, chatglm2, bbt].'
+        help='supported models in [chatgpt, chatglm2, bbt, codegeex2].'
     )
     parser.add_argument(
         "--output_prefix",
@@ -55,6 +55,9 @@ if __name__ == "__main__":
     elif args.model_name == "bbt":
         from inference.bbt_inference import BBTInference
         model_inference = BBTInference()
+    elif args.model_name == "codegeex2":
+        from inference.codegeex2_inference import CodeGeex2Inference
+        model_inference = CodeGeex2Inference()
     else:
         print(f"{args.model_name} not supported.")
 
diff --git a/src/inference/codegeex2_inference.py b/src/inference/codegeex2_inference.py
new file mode 100644
index 0000000..1504bc7
--- /dev/null
+++ b/src/inference/codegeex2_inference.py
@@ -0,0 +1,72 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+# @Time    : 2023/7/25 16:19
+# @Author  : clong
+# @File    : codegeex2_inference.py
+
+
+
+import os
+import json
+import torch
+import logging
+from transformers import AutoModel, AutoTokenizer
+from .inference import Inference
+
+logger = logging.getLogger(__name__)
+
+
+class CodeGeex2Inference(Inference):
+
+    def __init__(self):
+        super(CodeGeex2Inference, self).__init__()
+        self.params_url = "../llm_set/params/codegeex2.json"
+        self.paras_dict = self.get_params()
+        self.paras_dict.update(self.paras_base_dict)
+
+        self.temperature = self.paras_dict.get("temperature")
+        self.quantize = self.paras_dict.get("quantize")
+        self.device = self.paras_dict.get("device")
+        self.model_path = self.paras_dict.get("model_path")
+        if self.device == "cpu":
+            self.model = AutoModel.from_pretrained(self.model_path, trust_remote_code=True,
+                                                   ignore_mismatched_sizes=True).float()
+            self.DEV = torch.device('cpu')
+        else:
+            os.environ["CUDA_VISIBLE_DEVICES"] = self.device.split(":")[1]
+            self.DEV = torch.device(self.device)
+            if self.quantize:
+                self.model = AutoModel.from_pretrained(self.model_path, trust_remote_code=True).quantize(8).cuda()
+            else:
+                self.model = AutoModel.from_pretrained(self.model_path, trust_remote_code=True).cuda()
+
+        self.tokenizer = AutoTokenizer.from_pretrained(self.model_path, trust_remote_code=True)
+        self.model = self.model.eval()
+        self.max_length = self.paras_dict.get("max_length")
+        self.min_length = self.paras_dict.get("min_length")
+        self.top_p = self.paras_dict.get("top_p")
+        self.top_k = self.paras_dict.get("top_k")
+
+    def get_params(self):
+        if not os.path.exists(self.params_url):
+            logger.error(f"params_url: {self.params_url} does not exist.")
+        content = open(self.params_url).read()
+        return json.loads(content)
+
+    def inference(self, message):
+        input_ids = self.tokenizer.encode(message, return_tensors="pt").to(self.DEV)
+
+        sentence = None
+        with torch.no_grad():
+            generation_output = self.model.generate(
+                input_ids,
+                do_sample=True,
+                min_length=self.min_length,
+                max_length=self.max_length,
+                top_p=self.top_p,
+                top_k=self.top_k,
+                temperature=self.temperature
+            )
+        output = self.tokenizer.decode([el.item() for el in generation_output[0]])
+        sentence = output.strip()
+        return sentence
diff --git a/src/utils.py b/src/utils.py
index 1847600..2043e70 100644
--- a/src/utils.py
+++ b/src/utils.py
@@ -281,7 +281,6 @@ def cleanup_code(code: str, sample, language_type: str = None):
 
         if method_lines:
             method_body = "\n".join(method_lines[1:])
-            print(method_body)
 
     elif language_type.lower() == "cpp":
         method_lines = []
--
GitLab
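
Usage sketch (not part of the patch itself): the snippet below shows how the CodeGeex2Inference class added above can be driven directly, mirroring what generate_humaneval_x.py does for --model_name codegeex2. It assumes the patch is applied, the working directory is src/ (so the relative params path resolves), and that the codegeex2-6b weights have been downloaded to llm_set/models/codegeex2-6b. The "# language:" prompt prefix follows the convention CodeGeeX2's model card recommends, and the example prompt itself is hypothetical.

# Hypothetical driver for the CodeGeex2Inference class added in this
# patch; run from src/ after downloading the codegeex2-6b weights.
from inference.codegeex2_inference import CodeGeex2Inference

model_inference = CodeGeex2Inference()

# CodeGeeX2 is a base code model, so it is steered with a comment-style
# language tag plus a code prefix rather than a chat-style instruction
# (tag format per the model card; assumed here, not taken from this repo).
prompt = (
    "# language: Python\n"
    "# Return the sum of a list of integers.\n"
    "def sum_list(numbers):\n"
)
print(model_inference.inference(prompt))

Because generate() returns the input tokens along with the continuation and inference() decodes the whole sequence, the returned string still contains the prompt prefix; cleanup_code() in src/utils.py is what trims generations down during evaluation.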