# generate_humaneval_x.py
# Last committed by: CSDN-Ada助手
import argparse
import logging
import os
import random
import socket
import time
import json
import torch

from utils import read_dataset, process_extra_prompt
from utils import cleanup_code

logging.getLogger("torch").setLevel(logging.WARNING)

# After the first attempt returns None, retry up to this many more times,
# sleeping RETRY_DELAY_SECONDS before each retry.
MAX_RETRIES = 3
RETRY_DELAY_SECONDS = 10

# --model_name -> (module path, class name). Backends are imported lazily so
# that only the module for the selected model is ever loaded.
_BACKENDS = {
    "chatgpt": ("inference.chatgpt_inference", "ChatGPTInference"),
    "chatglm2": ("inference.chatglm2_inference", "GLMInference"),
    "bbt": ("inference.bbt_inference", "BBTInference"),
    "codegeex2": ("inference.codegeex2_inference", "CodeGeex2Inference"),
}


def _parse_args():
    """Parse the command-line options; unrecognized options are ignored."""
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--input_path",
        type=str,
        default="../eval_set/humaneval-x"
    )
    parser.add_argument(
        "--language_type",
        type=str,
        default="python"
    )
    parser.add_argument(
        "--model_name",
        type=str,
        default="chatgpt",
        help='supported models in [chatgpt, chatglm2, bbt, codegeex2].'
    )
    parser.add_argument(
        "--output_prefix",
        type=str,
        default="../output/humaneval"
    )
    # parse_known_args discards unknown flags instead of erroring on them.
    return parser.parse_known_args()[0]


def _load_model(model_name):
    """Import and instantiate the inference backend registered as ``model_name``.

    Exits the process with status 1 when the name is not supported, rather
    than continuing without a model and crashing on the first prompt.
    """
    import importlib

    if model_name not in _BACKENDS:
        print(f"{model_name} not supported.")
        raise SystemExit(1)
    module_name, class_name = _BACKENDS[model_name]
    backend_cls = getattr(importlib.import_module(module_name), class_name)
    return backend_cls()


def _generate_with_retry(model_inference, prompt, retries=MAX_RETRIES, delay=RETRY_DELAY_SECONDS):
    """Return ``model_inference.inference(prompt)``, retrying while it returns None.

    A ``None`` result is treated as a failed attempt. Up to ``retries`` further
    attempts are made, each preceded by a ``delay``-second sleep. Returns None
    if every attempt failed.
    """
    generated = model_inference.inference(prompt)
    attempts_left = retries
    while generated is None and attempts_left > 0:
        time.sleep(delay)
        generated = model_inference.inference(prompt)
        attempts_left -= 1
    return generated


def main():
    """Generate one completion per HumanEval-X task and write them as JSON Lines."""
    args = _parse_args()

    output_file_path = args.output_prefix + f"_{args.model_name}_{args.language_type}_finished.jsonl"

    entries = read_dataset(data_folder=args.input_path, language_type=args.language_type, dataset_type="humaneval")
    for entry in entries.values():
        entry["prompt"] = process_extra_prompt(entry["prompt"], args.language_type)

    # Resolve the backend before touching the output file, so an invalid
    # --model_name does not truncate a previous run's results.
    model_inference = _load_model(args.model_name)

    # The default prefix points into ../output/, which may not exist yet.
    output_dir = os.path.dirname(output_file_path)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)

    start_time = time.perf_counter()
    num_finished = 0
    num_failed = 0
    with open(output_file_path, "w", encoding="utf-8") as f:
        for num_processed, entry in enumerate(entries.values(), start=1):
            prompt = entry["prompt"]
            generated_tokens = _generate_with_retry(model_inference, prompt)
            if generated_tokens is None:
                num_failed += 1
                print(f"task_id: {entry['task_id']} generate failed!")
                continue
            generated_code = cleanup_code(generated_tokens, entry, language_type=args.language_type)
            f.write(
                json.dumps(
                    {
                        "task_id": entry['task_id'],
                        "prompt": prompt,
                        "generation": generated_code,
                        # NOTE(review): "scores"/"finish" look like placeholder
                        # fields for a later evaluation step — confirm meaning.
                        "scores": 0.0,
                        "finish": 2,
                        "output": generated_tokens,
                    }
                )
                + "\n"
            )
            # Flush per record so partial results survive an interrupted run.
            f.flush()
            num_finished += 1

            time_elapsed = time.perf_counter() - start_time
            # num_finished >= 1 here, so the division is always safe.
            time_per_sample = time_elapsed / num_finished
            print(
                f"finished {num_finished}, "
                f"elapsed {time_elapsed:.4f}",
                f"speed {time_per_sample:.4f}s/sample",
                # Count processed tasks (including failures) so this reaches 0.
                f"remaining {len(entries) - num_processed}",
                flush=True,
            )

    if num_failed:
        print(f"{num_failed} task(s) failed to generate and were skipped.")


if __name__ == "__main__":
    torch.multiprocessing.set_start_method("spawn")
    main()