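"""Generate completions for the HumanEval-X benchmark.

Reads the benchmark prompts for one language, runs each prompt through the
selected model backend (chatgpt or chatglm2), and writes one JSON record per
task to a rank-suffixed JSONL file.
"""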
import argparse
import logging
import time
import json
import torch

from utils import read_dataset, process_extra_prompt
from utils import cleanup_code

logging.getLogger("torch").setLevel(logging.WARNING)


if __name__ == "__main__":
    torch.multiprocessing.set_start_method("spawn")
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--input_path",
        type=str,
        default="eval_set/humaneval-x"
    )
    parser.add_argument(
        "--language_type",
        type=str,
        default="python"
    )
    parser.add_argument(
        "--model_name",
        type=str,
        default="chatgpt",
        help="supported models: chatgpt, chatglm2."
    )
    parser.add_argument(
        "--output_prefix",
        type=str,
        default="chatgpt"
    )
    parser.add_argument(
        "--gen_rank",
        type=int,
        default=0,
        help="rank of this worker; used to suffix the output file."
    )
    args = parser.parse_known_args()[0]

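    # Suffix the output file with this worker's rank so that parallel
    # generation runs do not overwrite each other's results.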
    output_file_path = args.output_prefix + f"_finished_rank{args.gen_rank}.jsonl"

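    # Load the benchmark entries and apply language-specific prompt processing.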
    entries = read_dataset(args.input_path, args.language_type, dataset_type="humaneval")
    for entry in entries.values():
        entry["prompt"] = process_extra_prompt(entry["prompt"], args.language_type)

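    # Import the requested backend lazily so that only its dependencies
    # need to be installed.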
    model_inference = None
    if args.model_name == "chatgpt":
        from inference.chatgpt_inference import ChatGPTInference
        model_inference = ChatGPTInference()
    elif args.model_name == "chatglm2":
        from inference.chatglm2_inference import GLMInference
        model_inference = GLMInference()
    else:
        raise ValueError(f"model {args.model_name} is not supported; expected chatgpt or chatglm2.")

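    # Generate a completion per task and stream results to the JSONL file,
    # flushing after every record so partial progress is preserved.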
    start_time = time.perf_counter()
    num_finished = 0
    with open(output_file_path, "w") as f:
        for entry in entries.values():
            prompt = entry["prompt"]
            generated_tokens = model_inference.inference(prompt)
            generated_code = cleanup_code(generated_tokens, language_type=args.language_type)
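            # "scores" and "finish" are constant placeholders kept for the
            # downstream evaluation format; "output" stores the raw model text.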
            f.write(
                json.dumps(
                    {
                        "task_id": entry['task_id'],
                        "prompt": prompt,
                        "generation": generated_code,
                        "scores": 0.0,
                        "finish": 2,
                        "output": generated_tokens,
                    }
                )
                + "\n"
            )
            f.flush()
            num_finished += 1

            time_elapsed = time.perf_counter() - start_time
            time_per_sample = time_elapsed / num_finished  # num_finished >= 1 here
            print(
                f"finished {num_finished}, "
                f"elapsed {time_elapsed:.4f}s, "
                f"speed {time_per_sample:.4f}s/sample, "
                f"remaining {len(entries) - num_finished}",
                flush=True,
            )