未验证 提交 84d76531 编写于 作者: B Bai Yifan 提交者: GitHub

add epochs_no_archopt to replace method in darts/train_search (#215)

上级 7d1ec566
......@@ -18,7 +18,7 @@
``` bash
python search.py # DARTS一阶近似搜索方法
python search.py --unrolled=True # DARTS的二阶近似搜索方法
python search.py --method='PC-DARTS' # PC-DARTS搜索方法
python search.py --method='PC-DARTS' --batch_size=256 --learning_rate=0.1 --arch_learning_rate=6e-4 --epochs_no_archopt=15 # PC-DARTS搜索方法
```
模型结构随搜索轮数的变化如图1所示。需要注意的是,图中准确率Acc并不代表该结构最终准确率,为了获得当前结构的最佳准确率,请对得到的genotype做网络结构评估训练。
......@@ -86,4 +86,4 @@ def train_search(batch_size, train_portion, is_shuffle, args):
python visualize.py PC-DARTS
```
`PC-DARTS`代表某个Genotype结构,需要预先添加到genotype.py中
\ No newline at end of file
`PC-DARTS`代表某个Genotype结构,需要预先添加到genotype.py中
......@@ -50,7 +50,8 @@ add_arg('trainset_num', int, 50000, "images number of trainset.
add_arg('model_save_dir', str, 'search_cifar', "The path to save model.")
add_arg('grad_clip', float, 5, "Gradient clipping.")
add_arg('arch_learning_rate',float, 3e-4, "Learning rate for arch encoding.")
add_arg('method', str, 'DARTS', "The search method you would like to use")
add_arg('method', str, 'DARTS', "The search method you would like to use")
add_arg('epochs_no_archopt', int, 0, "Epochs not optimize the arch params")
add_arg('cutout_length', int, 16, "Cutout length.")
add_arg('cutout', ast.literal_eval, False, "Whether use cutout.")
add_arg('unrolled', ast.literal_eval, False, "Use one-step unrolled validation loss")
......@@ -84,8 +85,9 @@ def main(args):
num_imgs=args.trainset_num,
arch_learning_rate=args.arch_learning_rate,
unrolled=args.unrolled,
method=args.method,
num_epochs=args.epochs,
epochs_no_archopt=args.epochs_no_archopt,
use_gpu=args.use_gpu,
use_data_parallel=args.use_data_parallel,
log_freq=args.log_freq)
searcher.train()
......
......@@ -26,8 +26,6 @@ from ...common import AvgrageMeter, get_logger
from .architect import Architect
logger = get_logger(__name__, level=logging.INFO)
SUPPORTED_METHODS = ["PC-DARTS", "DARTS"]
def count_parameters_in_MB(all_params):
parameters_number = 0
......@@ -47,8 +45,8 @@ class DARTSearch(object):
num_imgs=50000,
arch_learning_rate=3e-4,
unrolled='False',
method='DARTS',
num_epochs=50,
epochs_no_archopt=0,
use_gpu=True,
use_data_parallel=False,
log_freq=50):
......@@ -60,9 +58,7 @@ class DARTSearch(object):
self.num_imgs = num_imgs
self.arch_learning_rate = arch_learning_rate
self.unrolled = unrolled
self.method = method
assert (self.method in SUPPORTED_METHODS
), "Currently only support PC-DARTS, DARTS two methods"
self.epochs_no_archopt = epochs_no_archopt
self.num_epochs = num_epochs
self.use_gpu = use_gpu
self.use_data_parallel = use_data_parallel
......@@ -94,7 +90,7 @@ class DARTSearch(object):
valid_label.stop_gradient = True
n = train_image.shape[0]
if not (self.method == "PC-DARTS" and epoch < 15):
if epoch >= self.epochs_no_archopt:
architect.step(train_image, train_label, valid_image,
valid_label)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册