"""Generate HumanEval-X completions with a selected model backend and write
one JSON record per finished task to a JSONL output file."""

import argparse
import json
import logging
import os
import time

import torch

from utils import read_dataset, process_extra_prompt
from utils import cleanup_code

# Silence verbose torch log output during generation.
logging.getLogger("torch").setLevel(logging.WARNING)

if __name__ == "__main__":
    # CUDA state does not survive fork(); spawn fresh worker processes instead.
    torch.multiprocessing.set_start_method("spawn")

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--input_path",
        type=str,
        default="../eval_set/humaneval-x",
    )
    parser.add_argument(
        "--language_type",
        type=str,
        default="python",
    )
    parser.add_argument(
        "--model_name",
        type=str,
        default="chatgpt",
        help="supported models in [chatgpt, chatglm2, bbt, codegeex2].",
    )
    parser.add_argument(
        "--output_prefix",
        type=str,
        default="../output/humaneval",
    )
    # parse_known_args() tolerates extra flags passed through by a launcher.
    args = parser.parse_known_args()[0]

    output_file_path = (
        args.output_prefix
        + f"_{args.model_name}_{args.language_type}_finished.jsonl"
    )
    os.makedirs(os.path.dirname(output_file_path) or ".", exist_ok=True)

    # Load the HumanEval-X tasks and attach the language-specific extra prompt.
    entries = read_dataset(
        data_folder=args.input_path,
        language_type=args.language_type,
        dataset_type="humaneval",
    )
    for entry in entries.values():
        entry["prompt"] = process_extra_prompt(entry["prompt"], args.language_type)

    # Import the inference backend lazily so that only the selected model's
    # dependencies are loaded.
    if args.model_name == "chatgpt":
        from inference.chatgpt_inference import ChatGPTInference
        model_inference = ChatGPTInference()
    elif args.model_name == "chatglm2":
        from inference.chatglm2_inference import GLMInference
        model_inference = GLMInference()
    elif args.model_name == "bbt":
        from inference.bbt_inference import BBTInference
        model_inference = BBTInference()
    elif args.model_name == "codegeex2":
        from inference.codegeex2_inference import CodeGeex2Inference
        model_inference = CodeGeex2Inference()
    else:
        # Abort here instead of crashing later on a None backend.
        raise SystemExit(f"{args.model_name} not supported.")

    start_time = time.perf_counter()
    num_finished = 0
    with open(output_file_path, "w") as f:
        for entry in entries.values():
            prompt = entry["prompt"]

            # Retry up to three more times on failure (a None result), waiting
            # 10 seconds between attempts to ride out transient errors.
            retry_times = 3
            generated_tokens = model_inference.inference(prompt)
            while generated_tokens is None and retry_times > 0:
                time.sleep(10)
                generated_tokens = model_inference.inference(prompt)
                retry_times -= 1
            if generated_tokens is None:
                print(f"task_id: {entry['task_id']} generate failed!")
                continue

            # Post-process the raw model output into code for the target language.
            generated_code = cleanup_code(
                generated_tokens, entry, language_type=args.language_type
            )
            f.write(
                json.dumps(
                    {
                        "task_id": entry["task_id"],
                        "prompt": prompt,
                        "generation": generated_code,
                        "scores": 0.0,
                        "finish": 2,
                        "output": generated_tokens,
                    }
                )
                + "\n"
            )
            # Flush after every record so partial progress survives a crash.
            f.flush()
            num_finished += 1

            time_elapsed = time.perf_counter() - start_time
            time_per_sample = time_elapsed / num_finished
            print(
                f"finished {num_finished}, "
                f"elapsed {time_elapsed:.4f}s, "
                f"speed {time_per_sample:.4f}s/sample, "
                f"remaining {len(entries) - num_finished}",
                flush=True,
            )
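
# --- Backend interface sketch (illustrative only, not part of this repo) ---
# Every backend selected above is assumed to expose a single method,
# `inference(prompt)`, returning the generated text or None on failure;
# a None result is what drives the retry loop in the main loop. Under that
# assumption, a hypothetical new backend would only need to satisfy the same
# contract. The class and `call_my_model` names below are stand-ins, not
# real APIs:
#
#     class MyModelInference:
#         def inference(self, prompt: str):
#             try:
#                 return call_my_model(prompt)  # stand-in for a real client call
#             except Exception:
#                 return None  # None tells the caller to retry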