import json
import logging
import time

import torch

from utils import cleanup_code, process_extra_prompt, read_dataset

logging.getLogger("torch").setLevel(logging.WARNING)


def generate(input_path="eval_set/humaneval-x",
             language_type="python",
             model_name="chatgpt",
             output_prefix="output/humaneval"):
    torch.multiprocessing.set_start_method("spawn")

    output_file_path = output_prefix + f"_{model_name}_{language_type}_finished.jsonl"

    entries = read_dataset(data_folder=input_path,
                           language_type=language_type,
                           dataset_type="humaneval")
    for entry in entries.values():
        entry["prompt"] = process_extra_prompt(entry["prompt"], language_type)

    # Lazily import and instantiate the requested inference backend.
    model_inference = None
    if model_name == "chatgpt":
        from inference.chatgpt_inference import ChatGPTInference
        model_inference = ChatGPTInference()
    elif model_name == "chatglm2":
        from inference.chatglm2_inference import GLMInference
        model_inference = GLMInference()
    elif model_name == "bbt":
        from inference.bbt_inference import BBTInference
        model_inference = BBTInference()
    elif model_name == "codegeex2":
        from inference.codegeex2_inference import CodeGeex2Inference
        model_inference = CodeGeex2Inference()
    elif model_name == "llama2":
        from inference.llama2_inference import LLAMA2Inference
        model_inference = LLAMA2Inference()
    elif model_name == "baichuan":
        from inference.baichuan_inference import BaiChuanInference
        model_inference = BaiChuanInference()
    else:
        # Bail out early instead of failing later with an AttributeError on None.
        print(f"{model_name} not supported.")
        return

    start_time = time.perf_counter()
    num_finished = 0
    with open(output_file_path, "w") as f:
        for entry in entries.values():
            prompt = entry["prompt"]

            # Retry up to 3 times when the backend returns None (e.g. a transient failure).
            retry_times = 3
            generated_tokens = model_inference.inference(prompt)
            while generated_tokens is None and retry_times > 0:
                time.sleep(10)
                generated_tokens = model_inference.inference(prompt)
                retry_times -= 1
            if generated_tokens is None:
                print(f"task_id: {entry['task_id']} generate failed!")
                continue

            generated_code = cleanup_code(generated_tokens, entry, language_type=language_type)
            f.write(
                json.dumps(
                    {
                        "task_id": entry["task_id"],
                        "prompt": prompt,
                        "generation": generated_code,
                        "scores": 0.0,
                        "finish": 2,
                        "output": generated_tokens,
                    }
                )
                + "\n"
            )
            f.flush()
            num_finished += 1

            # Per-sample progress report.
            time_elapsed = time.perf_counter() - start_time
            time_per_sample = 0.0 if num_finished == 0 else time_elapsed / num_finished
            print(
                f"finished {num_finished}, "
                f"elapsed {time_elapsed:.4f}s, "
                f"speed {time_per_sample:.4f}s/sample, "
                f"remaining {len(entries) - num_finished}",
                flush=True,
            )
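

# A minimal sketch of a command-line entry point, assuming this module is run
# directly; the original source does not show how generate() is invoked, so the
# argparse flags below are illustrative, not part of the repository's CLI.
if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(description="Generate HumanEval-X completions.")
    parser.add_argument("--input_path", default="eval_set/humaneval-x")
    parser.add_argument("--language_type", default="python")
    parser.add_argument("--model_name", default="chatgpt")
    parser.add_argument("--output_prefix", default="output/humaneval")
    args = parser.parse_args()

    generate(input_path=args.input_path,
             language_type=args.language_type,
             model_name=args.model_name,
             output_prefix=args.output_prefix)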