未验证 提交 84d76531 编写于 作者: B Bai Yifan 提交者: GitHub

add epochs_no_optarch to replace method in darts/train_search (#215)

上级 7d1ec566
...@@ -18,7 +18,7 @@ ...@@ -18,7 +18,7 @@
``` bash ``` bash
python search.py # DARTS一阶近似搜索方法 python search.py # DARTS一阶近似搜索方法
python search.py --unrolled=True # DARTS的二阶近似搜索方法 python search.py --unrolled=True # DARTS的二阶近似搜索方法
python search.py --method='PC-DARTS' # PC-DARTS搜索方法 python search.py --method='PC-DARTS' --batch_size=256 --learning_rate=0.1 --arch_learning_rate=6e-4 --epochs_no_archopt=15 # PC-DARTS搜索方法
``` ```
模型结构随搜索轮数的变化如图1所示。需要注意的是,图中准确率Acc并不代表该结构最终准确率,为了获得当前结构的最佳准确率,请对得到的genotype做网络结构评估训练。 模型结构随搜索轮数的变化如图1所示。需要注意的是,图中准确率Acc并不代表该结构最终准确率,为了获得当前结构的最佳准确率,请对得到的genotype做网络结构评估训练。
...@@ -86,4 +86,4 @@ def train_search(batch_size, train_portion, is_shuffle, args): ...@@ -86,4 +86,4 @@ def train_search(batch_size, train_portion, is_shuffle, args):
python visualize.py PC-DARTS python visualize.py PC-DARTS
``` ```
`PC-DARTS`代表某个Genotype结构,需要预先添加到genotype.py中 `PC-DARTS`代表某个Genotype结构,需要预先添加到genotype.py中
\ No newline at end of file
...@@ -50,7 +50,8 @@ add_arg('trainset_num', int, 50000, "images number of trainset. ...@@ -50,7 +50,8 @@ add_arg('trainset_num', int, 50000, "images number of trainset.
add_arg('model_save_dir', str, 'search_cifar', "The path to save model.") add_arg('model_save_dir', str, 'search_cifar', "The path to save model.")
add_arg('grad_clip', float, 5, "Gradient clipping.") add_arg('grad_clip', float, 5, "Gradient clipping.")
add_arg('arch_learning_rate',float, 3e-4, "Learning rate for arch encoding.") add_arg('arch_learning_rate',float, 3e-4, "Learning rate for arch encoding.")
add_arg('method', str, 'DARTS', "The search method you would like to use") add_arg('method', str, 'DARTS', "The search method you would like to use")
add_arg('epochs_no_archopt', int, 0, "Epochs not optimize the arch params")
add_arg('cutout_length', int, 16, "Cutout length.") add_arg('cutout_length', int, 16, "Cutout length.")
add_arg('cutout', ast.literal_eval, False, "Whether use cutout.") add_arg('cutout', ast.literal_eval, False, "Whether use cutout.")
add_arg('unrolled', ast.literal_eval, False, "Use one-step unrolled validation loss") add_arg('unrolled', ast.literal_eval, False, "Use one-step unrolled validation loss")
...@@ -84,8 +85,9 @@ def main(args): ...@@ -84,8 +85,9 @@ def main(args):
num_imgs=args.trainset_num, num_imgs=args.trainset_num,
arch_learning_rate=args.arch_learning_rate, arch_learning_rate=args.arch_learning_rate,
unrolled=args.unrolled, unrolled=args.unrolled,
method=args.method,
num_epochs=args.epochs, num_epochs=args.epochs,
epochs_no_archopt=args.epochs_no_archopt,
use_gpu=args.use_gpu,
use_data_parallel=args.use_data_parallel, use_data_parallel=args.use_data_parallel,
log_freq=args.log_freq) log_freq=args.log_freq)
searcher.train() searcher.train()
......
...@@ -26,8 +26,6 @@ from ...common import AvgrageMeter, get_logger ...@@ -26,8 +26,6 @@ from ...common import AvgrageMeter, get_logger
from .architect import Architect from .architect import Architect
logger = get_logger(__name__, level=logging.INFO) logger = get_logger(__name__, level=logging.INFO)
SUPPORTED_METHODS = ["PC-DARTS", "DARTS"]
def count_parameters_in_MB(all_params): def count_parameters_in_MB(all_params):
parameters_number = 0 parameters_number = 0
...@@ -47,8 +45,8 @@ class DARTSearch(object): ...@@ -47,8 +45,8 @@ class DARTSearch(object):
num_imgs=50000, num_imgs=50000,
arch_learning_rate=3e-4, arch_learning_rate=3e-4,
unrolled='False', unrolled='False',
method='DARTS',
num_epochs=50, num_epochs=50,
epochs_no_archopt=0,
use_gpu=True, use_gpu=True,
use_data_parallel=False, use_data_parallel=False,
log_freq=50): log_freq=50):
...@@ -60,9 +58,7 @@ class DARTSearch(object): ...@@ -60,9 +58,7 @@ class DARTSearch(object):
self.num_imgs = num_imgs self.num_imgs = num_imgs
self.arch_learning_rate = arch_learning_rate self.arch_learning_rate = arch_learning_rate
self.unrolled = unrolled self.unrolled = unrolled
self.method = method self.epochs_no_archopt = epochs_no_archopt
assert (self.method in SUPPORTED_METHODS
), "Currently only support PC-DARTS, DARTS two methods"
self.num_epochs = num_epochs self.num_epochs = num_epochs
self.use_gpu = use_gpu self.use_gpu = use_gpu
self.use_data_parallel = use_data_parallel self.use_data_parallel = use_data_parallel
...@@ -94,7 +90,7 @@ class DARTSearch(object): ...@@ -94,7 +90,7 @@ class DARTSearch(object):
valid_label.stop_gradient = True valid_label.stop_gradient = True
n = train_image.shape[0] n = train_image.shape[0]
if not (self.method == "PC-DARTS" and epoch < 15): if epoch >= self.epochs_no_archopt:
architect.step(train_image, train_label, valid_image, architect.step(train_image, train_label, valid_image,
valid_label) valid_label)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册