import argparse


def remove_none_params(params_dict: dict) -> dict:
    """Return a new dict holding only the entries whose value is not None.

    Falsy values other than ``None`` (``0``, ``""``, ``False``, empty
    containers) are kept. The input mapping is not modified.

    Args:
        params_dict: Mapping of parameter names to values.

    Returns:
        A fresh dict with the ``None``-valued entries dropped, in the same
        key order as the input.
    """
    return {
        key: value
        for key, value in params_dict.items()
        if value is not None
    }


def _build_parser() -> argparse.ArgumentParser:
    """Create the command-line parser for the generate/translate/evaluate tasks."""
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--task_type",
        type=str,
        required=True,
        help="supported tasks. [generate, translate, evaluate]",
    )
    parser.add_argument(
        "--input_path",
        type=str,
    )
    parser.add_argument(
        "--language_type",
        type=str,
        default="python",
    )
    parser.add_argument(
        "--model_name",
        type=str,
        default="chatgpt",
        help="supported models: [chatgpt, chatglm2, bbt, codegeex2, baichuan].",
    )
    parser.add_argument(
        "--output_prefix",
        type=str,
    )
    parser.add_argument(
        "--problem_folder",
        type=str,
        default="eval_set/humaneval-x/",
        help="for evaluate, problem path.",
    )
    parser.add_argument(
        "--tmp_dir",
        type=str,
        default="output/tmp/",
        help="for evaluate, temp files path.",
    )
    parser.add_argument(
        "--n_workers",
        type=int,
        default=1,
        help="for evaluate, number of processes.",
    )
    parser.add_argument(
        "--timeout",
        type=float,
        default=10.0,
        help="for evaluate, single testing timeout time.",
    )
    return parser


def main(argv=None) -> None:
    """Parse command-line options and dispatch to the selected task.

    Args:
        argv: Optional list of argument strings. When ``None`` (the default)
            argparse reads the arguments from ``sys.argv``.
    """
    # parse_known_args returns (namespace, leftover_args); only the namespace
    # is used, so unrecognized options are ignored rather than rejected.
    args = _build_parser().parse_known_args(argv)[0]
    task_type = args.task_type

    if task_type == "generate":
        # Imported inside the branch so the module is only loaded when this
        # task is actually selected.
        from src.generate_humaneval_x import generate

        # Options left as None are not forwarded to the callee
        # (presumably so its own defaults apply -- confirm against callee).
        params = remove_none_params({
            "input_path": args.input_path,
            "language_type": args.language_type,
            "model_name": args.model_name,
            "output_prefix": args.output_prefix,
        })
        generate(**params)
    elif task_type == "translate":
        print("Not implemented.")
    elif task_type == "evaluate":
        # Imported inside the branch so the module is only loaded when this
        # task is actually selected.
        from src.evaluate_humaneval_x import evaluate_functional_correctness

        # Note the renames: --input_path is passed as "input_folder" and
        # --output_prefix as "out_dir".
        params = remove_none_params({
            "input_folder": args.input_path,
            "language_type": args.language_type,
            "model_name": args.model_name,
            "out_dir": args.output_prefix,
            "problem_folder": args.problem_folder,
            "tmp_dir": args.tmp_dir,
            "n_workers": args.n_workers,
            "timeout": args.timeout,
        })
        evaluate_functional_correctness(**params)
    else:
        print(f"task_type: {task_type} not supported.")


if __name__ == "__main__":
    main()