# main.py
# (scraped code-viewer chrome removed: file-size label, revision links,
# contributor name, and gutter line numbers were not part of the source)


import argparse


def remove_none_params(params_dict):
    """Return a copy of *params_dict* with all None-valued entries dropped.

    Used to assemble keyword-argument dicts from CLI options, so that
    options the user did not supply fall back to the callee's own
    defaults instead of being passed explicitly as None.

    Args:
        params_dict: Mapping of parameter names to values (possibly None).

    Returns:
        A new dict containing only the pairs whose value is not None;
        the input mapping is left unmodified.
    """
    # Dict comprehension replaces the original loop-and-assign; `is not None`
    # (not truthiness) so falsy-but-real values like 0 or "" are kept.
    return {key: value for key, value in params_dict.items() if value is not None}


if __name__ == "__main__":
    # CLI entry point: one required --task_type selector plus shared I/O
    # options. Each spec is (flag, kwargs-for-add_argument).
    arg_specs = [
        ("--task_type", dict(
            type=str,
            required=True,
            help='supported tasks. [generate, translate, evaluate]',
        )),
        ("--input_path", dict(type=str)),
        ("--language_type", dict(type=str, default="python")),
        ("--model_name", dict(
            type=str,
            default="chatgpt",
            help='supported models: [chatgpt, chatglm2, bbt, codegeex2, baichuan].',
        )),
        ("--output_prefix", dict(type=str)),
        ("--problem_folder", dict(
            type=str,
            default="eval_set/humaneval-x/",
            help='for evaluate, problem path.',
        )),
        ("--tmp_dir", dict(
            type=str,
            default="output/tmp/",
            help='for evaluate, temp files path.',
        )),
        ("--n_workers", dict(
            type=int,
            default=1,
            help='for evaluate, number of processes.',
        )),
        ("--timeout", dict(
            type=float,
            default=10.0,
            help='for evaluate, single testing timeout time.',
        )),
    ]

    parser = argparse.ArgumentParser()
    for flag, kwargs in arg_specs:
        parser.add_argument(flag, **kwargs)

    # parse_known_args so extra flags from wrapper scripts are tolerated;
    # only the recognised namespace is used.
    args = parser.parse_known_args()[0]

    if args.task_type == "generate":
        # Imported lazily so evaluate-only runs don't pay the import cost.
        from src.generate_humaneval_x import generate
        # remove_none_params drops unset options so generate() can apply
        # its own defaults.
        generate(**remove_none_params({
            "input_path": args.input_path,
            "language_type": args.language_type,
            "model_name": args.model_name,
            "output_prefix": args.output_prefix,
        }))
    elif args.task_type == "translate":
        print("Not implemented.")
    elif args.task_type == "evaluate":
        from src.evaluate_humaneval_x import evaluate_functional_correctness
        evaluate_functional_correctness(**remove_none_params({
            "input_folder": args.input_path,
            "language_type": args.language_type,
            "model_name": args.model_name,
            "out_dir": args.output_prefix,
            "problem_folder": args.problem_folder,
            "tmp_dir": args.tmp_dir,
            "n_workers": args.n_workers,
            "timeout": args.timeout,
        }))
    else:
        print(f"task_type: {args.task_type} not supported.")