Commit 941f49eb authored by CSDN-Ada助手

add codegeex2

Parent a554db77
......@@ -89,11 +89,13 @@ python evaluate_humaneval_x.py --language_type python
Limited by model inference speed, only the pass@1 metric has been measured so far (see the estimator sketch after the table).
| | python | java | cpp | js | go |
|--------------|--------|--------|--------|--------|---------|
| chatgpt | 64.02% | 15.85% | 26.22% | 47.00% | 31.70% |
| bbt-7B | 0.61% | 1.83% | 1.22% | 1.83% | 0% |
| bbt-13B | 2.49% | 0% | 1.90% | 1.83% | 0.61% |
| chatglm2-7B | 7.93% | 5.45% | 0.61% | 6.70% | 1.83% |
| codegeex2-7B | 29.90% | 27.43% | 6.70% | 24.40% | 17.68% |
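
For reference, pass@1 here presumably follows the unbiased estimator from the HumanEval paper, pass@k = 1 - C(n-c, k) / C(n, k) for n samples of which c pass the tests; a minimal sketch (numpy assumed, not part of this repo):

```python
import numpy as np

def pass_at_k(n: int, c: int, k: int) -> float:
    """Unbiased pass@k estimator: 1 - C(n-c, k) / C(n, k)."""
    if n - c < k:
        return 1.0  # fewer than k failing samples, so any draw of k contains a pass
    return 1.0 - np.prod(1.0 - k / np.arange(n - c + 1, n + 1))

# With n == 1, pass@1 degenerates to the fraction of problems solved on the first try.
print(pass_at_k(n=1, c=1, k=1))  # 1.0
```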
......
{
"url": "https://codebbt.ssymmetry.com/test3/code_bbt",
"url": "https://codebbt.ssymmetry.com/chat/code_bbt",
"user_id": "csdnTest_0001",
"stream": 0,
"timeout": 200,
......
{
"model_path": "../llm_set/models/codegeex2-6b",
"device": "cuda:0",
"quantize": false,
"max_length": 1024,
"min_length": 10
}
\ No newline at end of file
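
The inference class added below reads this file via get_params and merges it with shared base parameters; a minimal standalone sketch of that loading step (the relative path assumes the working directory of the evaluation scripts):

```python
import json
import os

params_url = "../llm_set/params/codegeex2.json"
if not os.path.exists(params_url):
    raise FileNotFoundError(params_url)
with open(params_url) as f:
    params = json.load(f)

# Fields consumed by CodeGeex2Inference below.
print(params["model_path"], params["device"], params["quantize"])
```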
python generate_humaneval_x.py --model_name bbt --language_type python
python generate_humaneval_x.py --model_name bbt --language_type java
python generate_humaneval_x.py --model_name bbt --language_type cpp
python generate_humaneval_x.py --model_name bbt --language_type js
python generate_humaneval_x.py --model_name bbt --language_type go
\ No newline at end of file
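
Since generate_humaneval_x.py below now accepts codegeex2 as a model name, the analogous runs for the new model would presumably be:

python generate_humaneval_x.py --model_name codegeex2 --language_type python
python generate_humaneval_x.py --model_name codegeex2 --language_type java
python generate_humaneval_x.py --model_name codegeex2 --language_type cpp
python generate_humaneval_x.py --model_name codegeex2 --language_type js
python generate_humaneval_x.py --model_name codegeex2 --language_type go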
......@@ -30,7 +30,7 @@ if __name__ == "__main__":
"--model_name",
type=str,
default="chatgpt",
help='supported models in [chatgpt, chatglm2, bbt, codegeex2].'
)
parser.add_argument(
"--output_prefix",
......@@ -55,6 +55,9 @@ if __name__ == "__main__":
elif args.model_name == "bbt":
from inference.bbt_inference import BBTInference
model_inference = BBTInference()
elif args.model_name == "codegeex2":
from inference.codegeex2_inference import CodeGeex2Inference
model_inference = CodeGeex2Inference()
else:
print(f"{args.model_name} not supported.")
......
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @Time : 2023/7/25 16:19
# @Author : clong
# @File : codegeex2_inference.py
import os
import json
import logging

import torch
from transformers import AutoModel, AutoTokenizer

from .inference import Inference

logger = logging.getLogger(__name__)


class CodeGeex2Inference(Inference):
    """Inference wrapper for the local codegeex2-6b checkpoint."""

    def __init__(self):
        super(CodeGeex2Inference, self).__init__()
        self.params_url = "../llm_set/params/codegeex2.json"
        self.paras_dict = self.get_params()
        self.paras_dict.update(self.paras_base_dict)

        self.temperature = self.paras_dict.get("temperature")
        self.quantize = self.paras_dict.get("quantize")
        self.device = self.paras_dict.get("device")
        self.model_path = self.paras_dict.get("model_path")

        if self.device == "cpu":
            # CPU fallback: load the weights in float32.
            self.model = AutoModel.from_pretrained(self.model_path, trust_remote_code=True,
                                                   ignore_mismatched_sizes=True).float()
            self.DEV = torch.device("cpu")
        else:
            # Pin the process to the configured GPU, e.g. "cuda:0" -> device index "0".
            os.environ["CUDA_VISIBLE_DEVICES"] = self.device.split(":")[1]
            self.DEV = torch.device(self.device)
            if self.quantize:
                # Optional 8-bit quantization to cut GPU memory usage.
                self.model = AutoModel.from_pretrained(self.model_path, trust_remote_code=True).quantize(8).cuda()
            else:
                self.model = AutoModel.from_pretrained(self.model_path, trust_remote_code=True).cuda()

        self.tokenizer = AutoTokenizer.from_pretrained(self.model_path, trust_remote_code=True)
        self.model = self.model.eval()

        self.max_length = self.paras_dict.get("max_length")
        self.min_length = self.paras_dict.get("min_length")
        self.top_p = self.paras_dict.get("top_p")
        self.top_k = self.paras_dict.get("top_k")

    def get_params(self):
        # Fail fast with a clear error instead of crashing later on open().
        if not os.path.exists(self.params_url):
            logger.error(f"params_url: {self.params_url} does not exist.")
            raise FileNotFoundError(self.params_url)
        with open(self.params_url) as f:
            return json.load(f)

    def inference(self, message):
        input_ids = self.tokenizer.encode(message, return_tensors="pt").to(self.DEV)
        with torch.no_grad():
            generation_output = self.model.generate(
                input_ids,
                do_sample=True,
                min_length=self.min_length,
                max_length=self.max_length,
                top_p=self.top_p,
                top_k=self.top_k,
                temperature=self.temperature
            )
        # decode() accepts the tensor of generated ids directly.
        output = self.tokenizer.decode(generation_output[0])
        return output.strip()
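
A hypothetical smoke test for the new wrapper (assumes the codegeex2.json params file and local model weights are in place; the prompt follows CodeGeeX2's "# language: ..." convention):

```python
from inference.codegeex2_inference import CodeGeex2Inference

model = CodeGeex2Inference()
prompt = "# language: Python\n# write a function that returns the sum of two numbers\n"
print(model.inference(prompt))
```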
......@@ -281,7 +281,6 @@ def cleanup_code(code: str, sample, language_type: str = None):
if method_lines:
method_body = "\n".join(method_lines[1:])
print(method_body)
elif language_type.lower() == "cpp":
method_lines = []
......