未验证 提交 def286ba 编写于 作者: littletomatodonkey's avatar littletomatodonkey 提交者: GitHub

add support for no dist (#1989)

上级 787f91b6
...@@ -131,7 +131,7 @@ cd path_to_PaddleClas ...@@ -131,7 +131,7 @@ cd path_to_PaddleClas
```shell ```shell
cd dataset cd dataset
wget https://paddleclas.bj.bcebos.com/data/cls_demo/traffic_sign.tar wget https://paddleclas.bj.bcebos.com/data/PULC/traffic_sign.tar
tar -xf traffic_sign.tar tar -xf traffic_sign.tar
cd ../ cd ../
``` ```
......
...@@ -30,6 +30,7 @@ search_dict: ...@@ -30,6 +30,7 @@ search_dict:
- [0.0, 0.4, 0.4, 0.8, 0.8, 1.0] - [0.0, 0.4, 0.4, 0.8, 0.8, 1.0]
- [1.0, 1.0, 1.0, 1.0, 1.0, 1.0] - [1.0, 1.0, 1.0, 1.0, 1.0, 1.0]
teacher: teacher:
algorithm: "skl-ugi"
rm_keys: rm_keys:
- Arch.lr_mult_list - Arch.lr_mult_list
search_values: search_values:
......
...@@ -25,6 +25,7 @@ search_dict: ...@@ -25,6 +25,7 @@ search_dict:
- [0.0, 0.4, 0.4, 0.8, 0.8, 1.0] - [0.0, 0.4, 0.4, 0.8, 0.8, 1.0]
- [1.0, 1.0, 1.0, 1.0, 1.0, 1.0] - [1.0, 1.0, 1.0, 1.0, 1.0, 1.0]
teacher: teacher:
algorithm: "skl-ugi"
rm_keys: rm_keys:
- Arch.lr_mult_list - Arch.lr_mult_list
search_values: search_values:
......
...@@ -365,9 +365,6 @@ class RandomCropImage(object): ...@@ -365,9 +365,6 @@ class RandomCropImage(object):
j = random.randint(0, w - tw) j = random.randint(0, w - tw)
img = img[i:i + th, j:j + tw, :] img = img[i:i + th, j:j + tw, :]
if img.shape[0] != 256 or img.shape[1] != 192:
raise ValueError('sample: ', h, w, i, j, th, tw, img.shape)
return img return img
......
...@@ -20,8 +20,13 @@ def get_result(log_dir): ...@@ -20,8 +20,13 @@ def get_result(log_dir):
return res return res
def search_train(search_list, base_program, base_output_dir, search_key, def search_train(search_list,
config_replace_value, model_name, search_times=1): base_program,
base_output_dir,
search_key,
config_replace_value,
model_name,
search_times=1):
best_res = 0. best_res = 0.
best = search_list[0] best = search_list[0]
all_result = {} all_result = {}
...@@ -33,7 +38,8 @@ def search_train(search_list, base_program, base_output_dir, search_key, ...@@ -33,7 +38,8 @@ def search_train(search_list, base_program, base_output_dir, search_key,
model_name = search_i model_name = search_i
res_list = [] res_list = []
for j in range(search_times): for j in range(search_times):
output_dir = "{}/{}_{}_{}".format(base_output_dir, search_key, search_i, j).replace(".", "_") output_dir = "{}/{}_{}_{}".format(base_output_dir, search_key,
search_i, j).replace(".", "_")
program += ["-o", "Global.output_dir={}".format(output_dir)] program += ["-o", "Global.output_dir={}".format(output_dir)]
process = subprocess.Popen(program) process = subprocess.Popen(program)
process.communicate() process.communicate()
...@@ -50,14 +56,17 @@ def search_train(search_list, base_program, base_output_dir, search_key, ...@@ -50,14 +56,17 @@ def search_train(search_list, base_program, base_output_dir, search_key,
def search_strategy(): def search_strategy():
args = config.parse_args() args = config.parse_args()
configs = config.get_config(args.config, overrides=args.override, show=False) configs = config.get_config(
args.config, overrides=args.override, show=False)
base_config_file = configs["base_config_file"] base_config_file = configs["base_config_file"]
distill_config_file = configs["distill_config_file"] distill_config_file = configs["distill_config_file"]
model_name = config.get_config(base_config_file)["Arch"]["name"] model_name = config.get_config(base_config_file)["Arch"]["name"]
gpus = configs["gpus"] gpus = configs["gpus"]
gpus = ",".join([str(i) for i in gpus]) gpus = ",".join([str(i) for i in gpus])
base_program = ["python3.7", "-m", "paddle.distributed.launch", "--gpus={}".format(gpus), base_program = [
"tools/train.py", "-c", base_config_file] "python3.7", "-m", "paddle.distributed.launch",
"--gpus={}".format(gpus), "tools/train.py", "-c", base_config_file
]
base_output_dir = configs["output_dir"] base_output_dir = configs["output_dir"]
search_times = configs["search_times"] search_times = configs["search_times"]
search_dict = configs.get("search_dict") search_dict = configs.get("search_dict")
...@@ -67,14 +76,22 @@ def search_strategy(): ...@@ -67,14 +76,22 @@ def search_strategy():
search_values = search_i["search_values"] search_values = search_i["search_values"]
replace_config = search_i["replace_config"] replace_config = search_i["replace_config"]
res = search_train(search_values, base_program, base_output_dir, res = search_train(search_values, base_program, base_output_dir,
search_key, replace_config, model_name, search_times) search_key, replace_config, model_name,
search_times)
all_results[search_key] = res all_results[search_key] = res
best = res.get("best") best = res.get("best")
for v in replace_config: for v in replace_config:
base_program += ["-o", "{}={}".format(v, best)] base_program += ["-o", "{}={}".format(v, best)]
teacher_configs = configs.get("teacher", None) teacher_configs = configs.get("teacher", None)
if teacher_configs is not None: if teacher_configs is None:
print(all_results, base_program)
return
algo = teacher_configs.get("algorithm", "skl-ugi")
supported_list = ["skl-ugi", "udml"]
assert algo in supported_list, f"algorithm must be in {supported_list} but got {algo}"
if algo == "skl-ugi":
teacher_program = base_program.copy() teacher_program = base_program.copy()
# remove incompatible keys # remove incompatible keys
teacher_rm_keys = teacher_configs["rm_keys"] teacher_rm_keys = teacher_configs["rm_keys"]
...@@ -85,20 +102,32 @@ def search_strategy(): ...@@ -85,20 +102,32 @@ def search_strategy():
rm_indices.append(ind) rm_indices.append(ind)
for rm_index in rm_indices[::-1]: for rm_index in rm_indices[::-1]:
teacher_program.pop(rm_index) teacher_program.pop(rm_index)
teacher_program.pop(rm_index-1) teacher_program.pop(rm_index - 1)
replace_config = ["Arch.name"] replace_config = ["Arch.name"]
teacher_list = teacher_configs["search_values"] teacher_list = teacher_configs["search_values"]
res = search_train(teacher_list, teacher_program, base_output_dir, "teacher", replace_config, model_name) res = search_train(teacher_list, teacher_program, base_output_dir,
"teacher", replace_config, model_name)
all_results["teacher"] = res all_results["teacher"] = res
best = res.get("best") best = res.get("best")
t_pretrained = "{}/{}_{}_0/{}/best_model".format(base_output_dir, "teacher", best, best) t_pretrained = "{}/{}_{}_0/{}/best_model".format(base_output_dir,
base_program += ["-o", "Arch.models.0.Teacher.name={}".format(best), "teacher", best, best)
"-o", "Arch.models.0.Teacher.pretrained={}".format(t_pretrained)] base_program += [
"-o", "Arch.models.0.Teacher.name={}".format(best), "-o",
"Arch.models.0.Teacher.pretrained={}".format(t_pretrained)
]
elif algo == "udml":
if "lr_mult_list" in all_results:
base_program += [
"-o", "Arch.models.0.Teacher.lr_mult_list={}".format(
all_results["lr_mult_list"]["best"])
]
output_dir = "{}/search_res".format(base_output_dir) output_dir = "{}/search_res".format(base_output_dir)
base_program += ["-o", "Global.output_dir={}".format(output_dir)] base_program += ["-o", "Global.output_dir={}".format(output_dir)]
final_replace = configs.get('final_replace') final_replace = configs.get('final_replace')
for i in range(len(base_program)): for i in range(len(base_program)):
base_program[i] = base_program[i].replace(base_config_file, distill_config_file) base_program[i] = base_program[i].replace(base_config_file,
distill_config_file)
for k in final_replace: for k in final_replace:
v = final_replace[k] v = final_replace[k]
base_program[i] = base_program[i].replace(k, v) base_program[i] = base_program[i].replace(k, v)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册