gather_output.py 1.8 KB
Newer Older
CSDN-Ada助手's avatar
CSDN-Ada助手 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56
import os
import sys
import fire
import glob


def _list_outputs(output_dir, output_prefix):
    """Return the sorted paths in *output_dir* whose names start with *output_prefix*.

    With ``output_prefix=None`` every entry of *output_dir* is returned.
    """
    pattern = "*" if output_prefix is None else output_prefix + "*"
    return sorted(glob.glob(os.path.join(output_dir, pattern)))


def _rank_order(path):
    """Sort key that orders ``<stem>_rank<N>...`` paths by the integer N.

    Paths whose rank part is not a plain number sort after the numbered ones,
    by name, so the overall order is always deterministic.
    """
    tail = os.path.basename(path).rpartition("_rank")[2]
    digits = tail.split(".", 1)[0]
    if digits.isdecimal():
        return (0, int(digits), path)
    return (1, 0, path)


def _concatenate(sources, destination):
    """Copy every line of *sources*, in order, into *destination* (overwritten)."""
    # Bytes are copied verbatim so the merge never depends on the locale's
    # default text encoding.
    with open(destination, "wb") as f_out:
        for source in sources:
            with open(source, "rb") as f_in:
                for line in f_in:
                    # A last line without a newline would otherwise be fused
                    # with the first record of the next file.
                    f_out.write(line if line.endswith(b"\n") else line + b"\n")


def gather_output(
    output_dir: str = "./output",
    output_prefix: str = None,
    if_remove_rank_files: int = 0,
):
    """Merge per-rank ``.jsonl`` shards, then build the ``*_all.jsonl`` files.

    Stage 1: for every ``<stem>_rank0.jsonl`` found, all ``<stem>_rank*`` files
    are concatenated in ascending rank order into ``<stem>.jsonl``.

    Stage 2: for every ``<stem>_finished.jsonl`` found (including those just
    written by stage 1), that file followed by ``<stem>_unfinished.jsonl``
    (when it exists) is concatenated into ``<stem>_all.jsonl``.

    File names are matched on their suffix, so directory or prefix names that
    happen to contain "rank", "all" or "_result" do not affect the result.

    Args:
        output_dir: Directory that is scanned for output files.
        output_prefix: If given, only files whose name starts with this prefix
            are considered; otherwise every file in ``output_dir`` is.
        if_remove_rank_files: When truthy, the per-rank files are deleted once
            their merged file has been written completely.
    """
    rank0_suffix = "_rank0.jsonl"
    for output_file in _list_outputs(output_dir, output_prefix):
        if not os.path.basename(output_file).endswith(rank0_suffix):
            continue
        stem = output_file[: -len(rank0_suffix)]
        # The stem comes from a real path, so escape it before globbing.
        rank_files = sorted(glob.glob(glob.escape(stem) + "_rank*"), key=_rank_order)
        _concatenate(rank_files, stem + ".jsonl")
        if if_remove_rank_files:
            # Delete only after the merged file is complete and every input
            # has been closed.
            for rank_file in rank_files:
                os.remove(rank_file)
                print(f"Removing {rank_file}...")

    finished_suffix = "_finished.jsonl"
    # List again: stage 1 may have just created new "*_finished.jsonl" files.
    for output_file in _list_outputs(output_dir, output_prefix):
        if not os.path.basename(output_file).endswith(finished_suffix):
            continue
        stem = output_file[: -len(finished_suffix)]
        parts = [output_file]
        unfinished_file = stem + "_unfinished.jsonl"
        # A missing "unfinished" file is treated as having no records.
        if os.path.isfile(unfinished_file):
            parts.append(unfinished_file)
        merged_file = stem + "_all.jsonl"
        _concatenate(parts, merged_file)

        print("Gathering finished. Saved in {}".format(merged_file))


def main():
    """Expose ``gather_output`` as a command-line interface through Fire."""
    fire.Fire(component=gather_output)


# Script entry point: run the CLI and exit with whatever main() returns.
if "__main__" == __name__:
    raise SystemExit(main())