# generate_humaneval_x.py
import logging
import time
import json
import torch

from utils import read_dataset, process_extra_prompt
from utils import cleanup_code

logging.getLogger("torch").setLevel(logging.WARNING)


def generate(input_path="eval_set/humaneval-x",
             language_type="python",
             model_name="chatgpt",
             output_prefix="output/humaneval"):
    """Run a code model over the HumanEval-X dataset and dump completions.

    Loads the dataset from ``input_path``, applies the language-specific
    extra prompt to every task, runs the selected inference backend over
    each prompt (with retries), and appends one JSON line per finished
    sample to ``{output_prefix}_{model_name}_{language_type}_finished.jsonl``.

    Args:
        input_path: Folder holding the HumanEval-X data files.
        language_type: Task language (e.g. "python", "java").
        model_name: Inference backend to use; one of "chatgpt",
            "chatglm2", "bbt", "codegeex2", "llama2", "baichuan".
        output_prefix: Path prefix of the output ``.jsonl`` file.

    Returns:
        None. Results are written to the output file as a side effect.
    """
    torch.multiprocessing.set_start_method("spawn")

    output_file_path = output_prefix + f"_{model_name}_{language_type}_finished.jsonl"

    entries = read_dataset(data_folder=input_path, language_type=language_type, dataset_type="humaneval")
    for entry in entries.values():
        entry["prompt"] = process_extra_prompt(entry["prompt"], language_type)

    # Import only the selected backend lazily so that the heavy model
    # dependencies of the other backends are never loaded.
    model_inference = None
    if model_name == "chatgpt":
        from inference.chatgpt_inference import ChatGPTInference
        model_inference = ChatGPTInference()
    elif model_name == "chatglm2":
        from inference.chatglm2_inference import GLMInference
        model_inference = GLMInference()
    elif model_name == "bbt":
        from inference.bbt_inference import BBTInference
        model_inference = BBTInference()
    elif model_name == "codegeex2":
        from inference.codegeex2_inference import CodeGeex2Inference
        model_inference = CodeGeex2Inference()
    elif model_name == "llama2":
        from inference.llama2_inference import LLAMA2Inference
        model_inference = LLAMA2Inference()
    elif model_name == "baichuan":
        from inference.baichuan_inference import BaiChuanInference
        model_inference = BaiChuanInference()
    else:
        # BUG FIX: the original only printed and fell through, then crashed
        # later with AttributeError on model_inference.inference(...).
        # Bail out explicitly for an unsupported model name.
        print(f"{model_name} not supported.")
        return

    start_time = time.perf_counter()
    num_finished = 0
    with open(output_file_path, "w") as f:
        for entry in entries.values():
            prompt = entry["prompt"]
            # Backends signal transient failure by returning None; retry a
            # few times with a pause before giving up on this task.
            retry_times = 3
            generated_tokens = model_inference.inference(prompt)
            while generated_tokens is None and retry_times > 0:
                time.sleep(10)
                generated_tokens = model_inference.inference(prompt)
                retry_times -= 1
            if generated_tokens is None:
                print(f"task_id: {entry['task_id']} generate failed!")
                continue

            generated_code = cleanup_code(generated_tokens, entry, language_type=language_type)
            f.write(
                json.dumps(
                    {
                        "task_id": entry['task_id'],
                        "prompt": prompt,
                        "generation": generated_code,
                        "scores": 0.0,
                        "finish": 2,
                        "output": generated_tokens,
                    }
                )
                + "\n"
            )
            # Flush per sample so partial progress survives an interruption.
            f.flush()
            num_finished += 1

            time_elapsed = time.perf_counter() - start_time
            # num_finished was just incremented, so it is always >= 1 here;
            # the original's "0.0 if num_finished == 0" guard was dead code.
            time_per_sample = time_elapsed / num_finished
            print(
                f"finished {num_finished}, "
                f"elapsed {time_elapsed:.4f}",
                f"speed {time_per_sample:.4f}s/sample",
                # FORMAT FIX: "remaining" is an integer count; the original's
                # ":.4f" printed it as e.g. "5.0000".
                f"remaining {len(entries) - num_finished}",
                flush=True,
            )