From def286bac8ff62e79ff9c510291d30acabbb5eca Mon Sep 17 00:00:00 2001
From: littletomatodonkey
Date: Tue, 7 Jun 2022 17:25:22 +0800
Subject: [PATCH] add support for no dist (#1989)

---
 docs/zh_CN/PULC/PULC_traffic_sign.md        |  2 +-
 ppcls/configs/PULC/traffic_sign/search.yaml |  1 +
 ppcls/configs/PULC/vehicle_attr/search.yaml |  1 +
 ppcls/data/preprocess/ops/operators.py      |  3 -
 tools/search_strategy.py                    | 67 +++++++++++++++------
 5 files changed, 51 insertions(+), 23 deletions(-)

diff --git a/docs/zh_CN/PULC/PULC_traffic_sign.md b/docs/zh_CN/PULC/PULC_traffic_sign.md
index 342bd67f..74a7d97a 100644
--- a/docs/zh_CN/PULC/PULC_traffic_sign.md
+++ b/docs/zh_CN/PULC/PULC_traffic_sign.md
@@ -131,7 +131,7 @@ cd path_to_PaddleClas
 
 ```shell
 cd dataset
-wget https://paddleclas.bj.bcebos.com/data/cls_demo/traffic_sign.tar
+wget https://paddleclas.bj.bcebos.com/data/PULC/traffic_sign.tar
 tar -xf traffic_sign.tar
 cd ../
 ```
diff --git a/ppcls/configs/PULC/traffic_sign/search.yaml b/ppcls/configs/PULC/traffic_sign/search.yaml
index 755ed201..029d042d 100644
--- a/ppcls/configs/PULC/traffic_sign/search.yaml
+++ b/ppcls/configs/PULC/traffic_sign/search.yaml
@@ -30,6 +30,7 @@ search_dict:
       - [0.0, 0.4, 0.4, 0.8, 0.8, 1.0]
       - [1.0, 1.0, 1.0, 1.0, 1.0, 1.0]
 teacher:
+  algorithm: "skl-ugi"
   rm_keys:
     - Arch.lr_mult_list
   search_values:
diff --git a/ppcls/configs/PULC/vehicle_attr/search.yaml b/ppcls/configs/PULC/vehicle_attr/search.yaml
index d5f41a3c..2a16266b 100644
--- a/ppcls/configs/PULC/vehicle_attr/search.yaml
+++ b/ppcls/configs/PULC/vehicle_attr/search.yaml
@@ -25,6 +25,7 @@ search_dict:
       - [0.0, 0.4, 0.4, 0.8, 0.8, 1.0]
       - [1.0, 1.0, 1.0, 1.0, 1.0, 1.0]
 teacher:
+  algorithm: "skl-ugi"
   rm_keys:
     - Arch.lr_mult_list
   search_values:
diff --git a/ppcls/data/preprocess/ops/operators.py b/ppcls/data/preprocess/ops/operators.py
index d31ec4b8..344675fd 100644
--- a/ppcls/data/preprocess/ops/operators.py
+++ b/ppcls/data/preprocess/ops/operators.py
@@ -365,9 +365,6 @@ class RandomCropImage(object):
         j = random.randint(0, w - tw)
 
         img = img[i:i + th, j:j + tw, :]
-        if img.shape[0] != 256 or img.shape[1] != 192:
-            raise ValueError('sample: ', h, w, i, j, th, tw, img.shape)
-
         return img
 
 
diff --git a/tools/search_strategy.py b/tools/search_strategy.py
index 15f4aa71..2a391c19 100644
--- a/tools/search_strategy.py
+++ b/tools/search_strategy.py
@@ -20,8 +20,13 @@ def get_result(log_dir):
     return res
 
 
-def search_train(search_list, base_program, base_output_dir, search_key,
-                 config_replace_value, model_name, search_times=1):
+def search_train(search_list,
+                 base_program,
+                 base_output_dir,
+                 search_key,
+                 config_replace_value,
+                 model_name,
+                 search_times=1):
     best_res = 0.
     best = search_list[0]
     all_result = {}
@@ -33,7 +38,8 @@ def search_train(search_list, base_program, base_output_dir, search_key,
             model_name = search_i
         res_list = []
         for j in range(search_times):
-            output_dir = "{}/{}_{}_{}".format(base_output_dir, search_key, search_i, j).replace(".", "_")
+            output_dir = "{}/{}_{}_{}".format(base_output_dir, search_key,
+                                              search_i, j).replace(".", "_")
             program += ["-o", "Global.output_dir={}".format(output_dir)]
             process = subprocess.Popen(program)
             process.communicate()
@@ -50,14 +56,17 @@
 
 def search_strategy():
     args = config.parse_args()
-    configs = config.get_config(args.config, overrides=args.override, show=False)
+    configs = config.get_config(
+        args.config, overrides=args.override, show=False)
     base_config_file = configs["base_config_file"]
     distill_config_file = configs["distill_config_file"]
     model_name = config.get_config(base_config_file)["Arch"]["name"]
     gpus = configs["gpus"]
     gpus = ",".join([str(i) for i in gpus])
-    base_program = ["python3.7", "-m", "paddle.distributed.launch", "--gpus={}".format(gpus),
-                    "tools/train.py", "-c", base_config_file]
+    base_program = [
+        "python3.7", "-m", "paddle.distributed.launch",
+        "--gpus={}".format(gpus), "tools/train.py", "-c", base_config_file
+    ]
     base_output_dir = configs["output_dir"]
     search_times = configs["search_times"]
     search_dict = configs.get("search_dict")
@@ -67,41 +76,61 @@ def search_strategy():
         search_values = search_i["search_values"]
         replace_config = search_i["replace_config"]
         res = search_train(search_values, base_program, base_output_dir,
-                           search_key, replace_config, model_name, search_times)
+                           search_key, replace_config, model_name,
+                           search_times)
         all_results[search_key] = res
         best = res.get("best")
         for v in replace_config:
             base_program += ["-o", "{}={}".format(v, best)]
 
     teacher_configs = configs.get("teacher", None)
-    if teacher_configs is not None:
+    if teacher_configs is None:
+        print(all_results, base_program)
+        return
+
+    algo = teacher_configs.get("algorithm", "skl-ugi")
+    supported_list = ["skl-ugi", "udml"]
+    assert algo in supported_list, f"algorithm must be in {supported_list} but got {algo}"
+    if algo == "skl-ugi":
         teacher_program = base_program.copy()
         # remove incompatible keys
         teacher_rm_keys = teacher_configs["rm_keys"]
         rm_indices = []
         for rm_k in teacher_rm_keys:
             for ind, ki in enumerate(base_program):
-                if rm_k in ki:
-                    rm_indices.append(ind)
+                if rm_k in ki:
+                    rm_indices.append(ind)
         for rm_index in rm_indices[::-1]:
             teacher_program.pop(rm_index)
-            teacher_program.pop(rm_index-1)
+            teacher_program.pop(rm_index - 1)
         replace_config = ["Arch.name"]
         teacher_list = teacher_configs["search_values"]
-        res = search_train(teacher_list, teacher_program, base_output_dir, "teacher", replace_config, model_name)
+        res = search_train(teacher_list, teacher_program, base_output_dir,
+                           "teacher", replace_config, model_name)
         all_results["teacher"] = res
         best = res.get("best")
-        t_pretrained = "{}/{}_{}_0/{}/best_model".format(base_output_dir, "teacher", best, best)
-        base_program += ["-o", "Arch.models.0.Teacher.name={}".format(best),
-                         "-o", "Arch.models.0.Teacher.pretrained={}".format(t_pretrained)]
+        t_pretrained = "{}/{}_{}_0/{}/best_model".format(base_output_dir,
                                                         "teacher", best, best)
+        base_program += [
+            "-o", "Arch.models.0.Teacher.name={}".format(best), "-o",
+            "Arch.models.0.Teacher.pretrained={}".format(t_pretrained)
+        ]
+    elif algo == "udml":
+        if "lr_mult_list" in all_results:
+            base_program += [
+                "-o", "Arch.models.0.Teacher.lr_mult_list={}".format(
+                    all_results["lr_mult_list"]["best"])
+            ]
+
     output_dir = "{}/search_res".format(base_output_dir)
     base_program += ["-o", "Global.output_dir={}".format(output_dir)]
     final_replace = configs.get('final_replace')
     for i in range(len(base_program)):
-        base_program[i] = base_program[i].replace(base_config_file, distill_config_file)
-    for k in final_replace:
-        v = final_replace[k]
-        base_program[i] = base_program[i].replace(k, v)
+        base_program[i] = base_program[i].replace(base_config_file,
                                                  distill_config_file)
+        for k in final_replace:
+            v = final_replace[k]
+            base_program[i] = base_program[i].replace(k, v)
 
     process = subprocess.Popen(base_program)
     process.communicate()
-- 
GitLab