提交 912285c1 编写于 作者: C cuicheng01

Merge branch 'add_person_demo' of http://github.com/cuicheng01/PaddleClas into add_person_demo

...@@ -3,30 +3,32 @@ distill_config_file: ppcls/configs/cls_demo/person/Distillation/PPLCNet_x1_0_dis ...@@ -3,30 +3,32 @@ distill_config_file: ppcls/configs/cls_demo/person/Distillation/PPLCNet_x1_0_dis
gpus: 0,1,2,3 gpus: 0,1,2,3
output_dir: output/search_person output_dir: output/search_person
search_times: 3
search_dict: search_dict:
lrs: - search_key: lrs
replace_config: replace_config:
- Optimizer.lr.learning_rate - Optimizer.lr.learning_rate
search_values: [0.0075, 0.01, 0.0125] search_values: [0.0075, 0.01, 0.0125]
resolutions: - search_key: resolutions
replace_config: replace_config:
- DataLoader.Train.dataset.transform_ops.1.RandCropImage.size - DataLoader.Train.dataset.transform_ops.1.RandCropImage.size
- DataLoader.Train.dataset.transform_ops.3.TimmAutoAugment.img_size - DataLoader.Train.dataset.transform_ops.3.TimmAutoAugment.img_size
search_values: [176, 192, 224] search_values: [176, 192, 224]
ra_probs: - search_key: ra_probs
replace_config: replace_config:
- DataLoader.Train.dataset.transform_ops.3.TimmAutoAugment.prob - DataLoader.Train.dataset.transform_ops.3.TimmAutoAugment.prob
search_values: [0.0, 0.1, 0.5] search_values: [0.0, 0.1, 0.5]
re_probs: - search_key: re_probs
replace_config: replace_config:
- DataLoader.Train.dataset.transform_ops.5.RandomErasing.EPSILON - DataLoader.Train.dataset.transform_ops.5.RandomErasing.EPSILON
search_values: [0.0, 0.1, 0.5] search_values: [0.0, 0.1, 0.5]
lr_mult_list: - search_key: lr_mult_list
replace_config: replace_config:
- Arch.lr_mult_list - Arch.lr_mult_list
search_values: search_values:
- [0.0, 0.2, 0.4, 0.6, 0.8, 1.0] - [0.0, 0.2, 0.4, 0.6, 0.8, 1.0]
- [0.0, 0.4, 0.4, 0.8, 0.8, 1.0] - [0.0, 0.4, 0.4, 0.8, 0.8, 1.0]
- [1.0. 1.0. 1.0. 1.0. 1.0. 1.0]
teacher: teacher:
rm_keys: rm_keys:
- Arch.lr_mult_list - Arch.lr_mult_list
......
...@@ -7,6 +7,8 @@ __dir__ = os.path.dirname(os.path.abspath(__file__)) ...@@ -7,6 +7,8 @@ __dir__ = os.path.dirname(os.path.abspath(__file__))
sys.path.append(os.path.abspath(os.path.join(__dir__, '../'))) sys.path.append(os.path.abspath(os.path.join(__dir__, '../')))
import subprocess import subprocess
import numpy as np
from ppcls.utils import config from ppcls.utils import config
...@@ -18,7 +20,8 @@ def get_result(log_dir): ...@@ -18,7 +20,8 @@ def get_result(log_dir):
return res return res
def search_train(search_list, base_program, base_output_dir, search_key, config_replace_value, model_name): def search_train(search_list, base_program, base_output_dir, search_key,
config_replace_value, model_name, search_times=1):
best_res = 0. best_res = 0.
best = search_list[0] best = search_list[0]
all_result = {} all_result = {}
...@@ -28,15 +31,19 @@ def search_train(search_list, base_program, base_output_dir, search_key, config_ ...@@ -28,15 +31,19 @@ def search_train(search_list, base_program, base_output_dir, search_key, config_
program += ["-o", "{}={}".format(v, search_i)] program += ["-o", "{}={}".format(v, search_i)]
if v == "Arch.name": if v == "Arch.name":
model_name = search_i model_name = search_i
output_dir = "{}/{}_{}".format(base_output_dir, search_key, search_i).replace(".", "_") res_list = []
program += ["-o", "Global.output_dir={}".format(output_dir)] for j in range(search_times):
process = subprocess.Popen(program) output_dir = "{}/{}_{}_{}".format(base_output_dir, search_key, search_i, j).replace(".", "_")
process.communicate() program += ["-o", "Global.output_dir={}".format(output_dir)]
res = get_result("{}/{}".format(output_dir, model_name)) process = subprocess.Popen(program)
all_result[str(search_i)] = res process.communicate()
if res > best_res: res = get_result("{}/{}".format(output_dir, model_name))
res_list.append(res)
all_result[str(search_i)] = res_list
if np.mean(res_list) > best_res:
best = search_i best = search_i
best_res = res best_res = np.mean(res_list)
all_result["best"] = best all_result["best"] = best
return all_result return all_result
...@@ -52,12 +59,15 @@ def search_strategy(): ...@@ -52,12 +59,15 @@ def search_strategy():
base_program = ["python3.7", "-m", "paddle.distributed.launch", "--gpus={}".format(gpus), base_program = ["python3.7", "-m", "paddle.distributed.launch", "--gpus={}".format(gpus),
"tools/train.py", "-c", base_config_file] "tools/train.py", "-c", base_config_file]
base_output_dir = configs["output_dir"] base_output_dir = configs["output_dir"]
search_times = configs["search_times"]
search_dict = configs.get("search_dict") search_dict = configs.get("search_dict")
all_results = {} all_results = {}
for search_key in search_dict: for search_i in search_dict:
search_values = search_dict[search_key]["search_values"] search_key = search_i["search_key"]
replace_config = search_dict[search_key]["replace_config"] search_values = search_i["search_values"]
res = search_train(search_values, base_program, base_output_dir, search_key, replace_config, model_name) replace_config = search_i["replace_config"]
res = search_train(search_values, base_program, base_output_dir,
search_key, replace_config, model_name, search_times)
all_results[search_key] = res all_results[search_key] = res
best = res.get("best") best = res.get("best")
for v in replace_config: for v in replace_config:
...@@ -73,7 +83,6 @@ def search_strategy(): ...@@ -73,7 +83,6 @@ def search_strategy():
for ind, ki in enumerate(base_program): for ind, ki in enumerate(base_program):
if rm_k in ki: if rm_k in ki:
rm_indices.append(ind) rm_indices.append(ind)
print(rm_indices)
for rm_index in rm_indices[::-1]: for rm_index in rm_indices[::-1]:
teacher_program.pop(rm_index) teacher_program.pop(rm_index)
teacher_program.pop(rm_index-1) teacher_program.pop(rm_index-1)
...@@ -94,9 +103,9 @@ def search_strategy(): ...@@ -94,9 +103,9 @@ def search_strategy():
v = final_replace[k] v = final_replace[k]
base_program[i] = base_program[i].replace(k, v) base_program[i] = base_program[i].replace(k, v)
print(all_results, base_program)
process = subprocess.Popen(base_program) process = subprocess.Popen(base_program)
process.communicate() process.communicate()
print(all_results, base_program)
if __name__ == '__main__': if __name__ == '__main__':
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册