search_strategy.py 4.3 KB
Newer Older
W
dbg  
weishengyu 已提交
1 2 3 4 5 6 7 8
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import sys
__dir__ = os.path.dirname(os.path.abspath(__file__))
sys.path.append(os.path.abspath(os.path.join(__dir__, '../')))

W
weisy11 已提交
9
import subprocess
W
weisy11 已提交
10 11
import numpy as np

W
weisy11 已提交
12 13 14 15 16 17 18
from ppcls.utils import config


def get_result(log_dir):
    """Return the metric from the last ``best metric: <value>]`` entry of a training log.

    Args:
        log_dir: directory containing the ``train.log`` written by a training run.

    Returns:
        The value following the LAST occurrence of ``"best metric: "`` in
        ``<log_dir>/train.log``, up to the next ``"]"``, as a float.

    Raises:
        ValueError: if the log contains no ``"best metric: "`` entry (e.g. the
            run crashed before the first evaluation), or the value is not a number.
    """
    marker = "best metric: "
    log_file = "{}/train.log".format(log_dir)
    with open(log_file, "r", encoding="utf-8") as f:
        raw = f.read()
    if marker not in raw:
        # Without this check the whole log would be handed to float(), which
        # hides the real problem behind an unreadable conversion error.
        raise ValueError("no '{}' entry found in {}".format(marker, log_file))
    return float(raw.rsplit(marker, 1)[-1].split("]", 1)[0])


W
weisy11 已提交
23 24
def search_train(search_list, base_program, base_output_dir, search_key,
                 config_replace_value, model_name, search_times=1):
    """Train once per candidate value and report which one scores best.

    For every value in ``search_list`` the training command ``base_program``
    is extended with ``-o <key>=<value>`` for each key in
    ``config_replace_value``. That command is launched ``search_times`` times,
    each run in its own output directory, and the metric of each run is read
    back with ``get_result``.

    Args:
        search_list: candidate values to try; must be non-empty.
        base_program: argv list of the training command; it is not modified.
        base_output_dir: parent directory for the per-run output directories.
        search_key: label used as the prefix of the per-run directory names.
        config_replace_value: config keys overridden with each candidate.
        model_name: sub-directory (under each run's output dir) holding
            ``train.log``. It is replaced by the candidate when ``Arch.name``
            is one of the overridden keys.
        search_times: number of repeated runs per candidate.

    Returns:
        dict mapping ``str(candidate)`` to its list of per-run metrics, plus the
        key ``"best"`` holding the candidate with the highest mean metric
        (ties keep the earlier candidate).
    """
    best = search_list[0]
    # -inf (not 0.) so the best candidate is still found if all metrics are <= 0.
    best_res = float("-inf")
    all_result = {}
    for search_i in search_list:
        program = list(base_program)
        for v in config_replace_value:
            program += ["-o", "{}={}".format(v, search_i)]
            if v == "Arch.name":
                model_name = search_i
        res_list = []
        for j in range(search_times):
            # Sanitise only the leaf directory name (e.g. lr 0.01 -> "lr_0_01");
            # dots in base_output_dir such as "./output" must be kept intact.
            run_name = "{}_{}_{}".format(search_key, search_i, j).replace(".", "_")
            output_dir = "{}/{}".format(base_output_dir, run_name)
            # Per-run copy so Global.output_dir overrides do not pile up across repetitions.
            run_program = program + ["-o", "Global.output_dir={}".format(output_dir)]
            process = subprocess.Popen(run_program)
            process.communicate()
            res_list.append(get_result("{}/{}".format(output_dir, model_name)))
        all_result[str(search_i)] = res_list

        mean_res = np.mean(res_list)
        if mean_res > best_res:
            best = search_i
            best_res = mean_res
    all_result["best"] = best
    return all_result


def _strip_overrides(program, rm_keys):
    """Return a copy of ``program`` without the arguments that mention any of ``rm_keys``.

    Each argument that contains one of the keys as a substring is dropped
    together with the argument right before it (normally its ``-o`` flag).
    Indices are gathered into a set first. That way an argument matching
    several keys is removed only once, and removals cannot shift the
    positions of later matches.
    """
    drop = set()
    for idx, arg in enumerate(program):
        if any(key in arg for key in rm_keys):
            drop.update((idx - 1, idx))
    return [arg for idx, arg in enumerate(program) if idx not in drop]


def search_strategy():
    """Run the staged hyper-parameter search, then the final (distillation) training.

    All settings come from the config parsed from the command line:
      1. For each entry of ``search_dict`` (in order), try every value in
         ``search_values``. Pin the best one into the command used by all
         later stages.
      2. If a ``teacher`` section exists, train each candidate architecture
         with the ``rm_keys`` overrides removed. Point the final command at
         the best teacher's ``best_model``.
      3. Swap ``base_config_file`` for ``distill_config_file`` and apply the
         ``final_replace`` substring substitutions to every argument. Then
         launch the final run into ``<output_dir>/search_res``.
    Finally prints the collected search results and the final command.
    """
    args = config.parse_args()
    configs = config.get_config(args.config, overrides=args.override, show=False)
    base_config_file = configs["base_config_file"]
    distill_config_file = configs["distill_config_file"]
    model_name = config.get_config(base_config_file)["Arch"]["name"]

    gpus = configs["gpus"]
    # Accept a list of ids, a single id, or a ready-made "0,1" string; joining a
    # bare string character by character would produce "0,,,1".
    if not isinstance(gpus, (list, tuple)):
        gpus = [gpus]
    gpus = ",".join(str(i) for i in gpus)

    base_program = ["python3.7", "-m", "paddle.distributed.launch", "--gpus={}".format(gpus),
                    "tools/train.py", "-c", base_config_file]
    base_output_dir = configs["output_dir"]
    search_times = configs["search_times"]
    # A missing section means "nothing to search" rather than a TypeError.
    search_dict = configs.get("search_dict") or []
    all_results = {}

    for search_i in search_dict:
        search_key = search_i["search_key"]
        search_values = search_i["search_values"]
        replace_config = search_i["replace_config"]
        res = search_train(search_values, base_program, base_output_dir,
                           search_key, replace_config, model_name, search_times)
        all_results[search_key] = res
        best = res.get("best")
        # Pin the winner so every later stage trains with it.
        for v in replace_config:
            base_program += ["-o", "{}={}".format(v, best)]

    teacher_configs = configs.get("teacher", None)
    if teacher_configs is not None:
        # remove overrides that are incompatible with the teacher architectures
        teacher_program = _strip_overrides(base_program, teacher_configs["rm_keys"])
        teacher_list = teacher_configs["search_values"]
        res = search_train(teacher_list, teacher_program, base_output_dir, "teacher",
                           ["Arch.name"], model_name)
        all_results["teacher"] = res
        best = res.get("best")
        # Same directory naming as the teacher's first run (dots in the leaf
        # name become underscores).
        run_dir = "teacher_{}_0".format(best).replace(".", "_")
        t_pretrained = "{}/{}/{}/best_model".format(base_output_dir, run_dir, best)
        base_program += ["-o", "Arch.models.0.Teacher.name={}".format(best),
                         "-o", "Arch.models.0.Teacher.pretrained={}".format(t_pretrained)]
    output_dir = "{}/search_res".format(base_output_dir)
    base_program += ["-o", "Global.output_dir={}".format(output_dir)]

    final_replace = configs.get('final_replace') or {}
    for i, arg in enumerate(base_program):
        arg = arg.replace(base_config_file, distill_config_file)
        for k, v in final_replace.items():
            arg = arg.replace(k, v)
        base_program[i] = arg

    process = subprocess.Popen(base_program)
    process.communicate()
    print(all_results, base_program)
W
weisy11 已提交
109 110 111 112


if __name__ == '__main__':
    # Script entry point: parse the CLI config and run the whole search pipeline.
    search_strategy()