Unverified · Commit e8e48533 · Authored by: Bai Yifan · Committed by: GitHub

Add multi-cards search&final support in DARTS (#226) (#257)

* support multi-cards in darts
Parent cb5edebc
@@ -2,9 +2,31 @@
 This example shows how to perform differentiable architecture search with PaddlePaddle. The [DARTS](https://arxiv.org/abs/1806.09055) and [PC-DARTS](https://arxiv.org/abs/1907.05737) methods can be used directly, and other differentiable architecture search algorithms can be used after custom modification.
+
+The directory structure of this example is as follows:
+```
+├── genotypes.py        Architecture genotypes produced by search
+├── model.py            Builds the searched sub-network
+├── model_search.py     Builds the super-network used during search
+├── operations.py       Candidate operations used for search
+├── reader.py           Data reading and augmentation
+├── search.py           Entry point for architecture search
+├── train.py            Entry point for evaluation training on the CIFAR10 dataset
+├── train_imagenet.py   Entry point for evaluation training on the ImageNet dataset
+├── visualize.py        Entry point for architecture visualization
+```
 ## Requirements
-> PaddlePaddle >= 1.7.0, graphviz >= 0.11.1
+> PaddlePaddle >= 1.8.0, PaddleSlim >= 1.1.0, graphviz >= 0.11.1
 ## Dataset
@@ -21,6 +43,14 @@ python search.py --unrolled=True    # DARTS second-order approximation search method
 python search.py --method='PC-DARTS' --batch_size=256 --learning_rate=0.1 --arch_learning_rate=6e-4 --epochs_no_archopt=15    # PC-DARTS search method
 ```
+
+Architecture search can also be run on multiple cards. Taking 4 cards (GPU ids 0-3) as an example, the launch command is:
+```bash
+python -m paddle.distributed.launch --selected_gpus=0,1,2,3 --log_dir ./mylog search.py --use_data_parallel 1
+```
+Because multi-card training enlarges the total batch size n times, where n is the number of cards, scale the initial learning rate by n accordingly to match single-card accuracy (see the arithmetic sketch after this diff).
 Figure 1 shows how the architecture changes over search epochs. Note that the accuracy (Acc) in the figure does not represent the final accuracy of that architecture; to obtain the best accuracy of the current architecture, run evaluation training on the resulting genotype.
 ![networks](images/networks.gif)
@@ -40,6 +70,15 @@ python train.py --arch='PC_DARTS'    # evaluation training of the searched architecture on CIFAR10
 python train_imagenet.py --arch='PC_DARTS'    # evaluation training of the searched architecture on ImageNet
 ```
+
+Evaluation training also supports multiple cards. Taking 4 cards (GPU ids 0-3) as an example, the launch commands are:
+```bash
+python -m paddle.distributed.launch --selected_gpus=0,1,2,3 --log_dir ./mylog train.py --use_data_parallel 1 --arch='DARTS_V2'
+python -m paddle.distributed.launch --selected_gpus=0,1,2,3 --log_dir ./mylog train_imagenet.py --use_data_parallel 1 --arch='DARTS_V2'
+```
+Likewise, multi-card training enlarges the total batch size n times, where n is the number of cards; scale the initial learning rate by n accordingly to match single-card accuracy.
 The results of evaluation training on the searched `DARTS_V1`, `DARTS_V2`, and `PC-DARTS` architectures are as follows:
 | Model | Dataset | Accuracy |
......
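The learning-rate scaling rule stated in the README changes above is simple arithmetic; here is a minimal sketch, assuming the single-card search default of 0.025 (visible in the `DARTSearch` signature later in this commit) and the 4-card setup from the launch examples:

```python
# Linear scaling rule from the README: n cards enlarge the effective batch
# size n times, so the initial learning rate is scaled by n to compensate.
base_lr = 0.025   # single-card search default (see DARTSearch below)
num_cards = 4     # e.g. --selected_gpus=0,1,2,3
scaled_lr = base_lr * num_cards
print(scaled_lr)  # 0.1 -- pass via --learning_rate when launching on 4 cards
```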
@@ -80,6 +80,7 @@ def main(args):
         model,
         train_reader,
         valid_reader,
+        place,
         learning_rate=args.learning_rate,
         batchsize=args.batch_size,
         num_imgs=args.trainset_num,
@@ -87,8 +88,8 @@ def main(args):
         unrolled=args.unrolled,
         num_epochs=args.epochs,
         epochs_no_archopt=args.epochs_no_archopt,
-        use_gpu=args.use_gpu,
         use_data_parallel=args.use_data_parallel,
+        save_dir=args.model_save_dir,
         log_freq=args.log_freq)
     searcher.train()
......
@@ -19,13 +19,14 @@ from __future__ import print_function
 import os
 import sys
 import ast
+import logging
 import argparse
 import functools
-import logging
 import paddle.fluid as fluid
 from paddle.fluid.dygraph.base import to_variable
 from paddleslim.common import AvgrageMeter, get_logger
+from paddleslim.nas.darts import count_parameters_in_MB
 import genotypes
 import reader
@@ -140,9 +141,6 @@ def main(args):
         if args.use_data_parallel else fluid.CUDAPlace(0)
     with fluid.dygraph.guard(place):
-        if args.use_data_parallel:
-            strategy = fluid.dygraph.parallel.prepare_context()
         genotype = eval("genotypes.%s" % args.arch)
         model = Network(
             C=args.init_channels,
@@ -151,7 +149,12 @@ def main(args):
             auxiliary=args.auxiliary,
             genotype=genotype)
-        step_per_epoch = int(args.trainset_num / args.batch_size)
+        logger.info("param size = {:.6f}MB".format(
+            count_parameters_in_MB(model.parameters())))
+
+        device_num = fluid.dygraph.parallel.Env().nranks
+        step_per_epoch = int(args.trainset_num /
+                             (args.batch_size * device_num))
         learning_rate = fluid.dygraph.CosineDecay(args.learning_rate,
                                                   step_per_epoch, args.epochs)
         clip = fluid.clip.GradientClipByGlobalNorm(clip_norm=args.grad_clip)
@@ -163,18 +166,21 @@ def main(args):
             grad_clip=clip)
         if args.use_data_parallel:
+            strategy = fluid.dygraph.parallel.prepare_context()
             model = fluid.dygraph.parallel.DataParallel(model, strategy)
         train_loader = fluid.io.DataLoader.from_generator(
-            capacity=64,
+            capacity=1024,
             use_double_buffer=True,
             iterable=True,
-            return_list=True)
+            return_list=True,
+            use_multiprocess=True)
         valid_loader = fluid.io.DataLoader.from_generator(
-            capacity=64,
+            capacity=1024,
             use_double_buffer=True,
             iterable=True,
-            return_list=True)
+            return_list=True,
+            use_multiprocess=True)
         train_reader = reader.train_valid(
             batch_size=args.batch_size,
@@ -186,13 +192,13 @@ def main(args):
             is_train=False,
             is_shuffle=False,
             args=args)
-        train_loader.set_batch_generator(train_reader, places=place)
-        valid_loader.set_batch_generator(valid_reader, places=place)
         if args.use_data_parallel:
             train_reader = fluid.contrib.reader.distributed_batch_reader(
                 train_reader)
+        train_loader.set_batch_generator(train_reader, places=place)
+        valid_loader.set_batch_generator(valid_reader, places=place)
         save_parameters = (not args.use_data_parallel) or (
             args.use_data_parallel and
             fluid.dygraph.parallel.Env().local_rank == 0)
......
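One detail in the hunk above deserves a note: the `set_batch_generator` calls moved below the `distributed_batch_reader` wrapping because rebinding the `train_reader` name after the loader has already captured the generator has no effect, so the old ordering fed every card the full, unsharded stream. The corrected pattern, excerpted from the diff:

```python
# Wrap the reader for per-card sharding *before* binding it to the loader;
# binding first would leave the loader holding the unsharded generator.
if args.use_data_parallel:
    train_reader = fluid.contrib.reader.distributed_batch_reader(train_reader)
train_loader.set_batch_generator(train_reader, places=place)
valid_loader.set_batch_generator(valid_reader, places=place)
```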
@@ -19,13 +19,15 @@ from __future__ import print_function
 import os
 import sys
 import ast
+import logging
 import argparse
 import functools
-import logging
 import paddle.fluid as fluid
 from paddle.fluid.dygraph.base import to_variable
 from paddleslim.common import AvgrageMeter, get_logger
+from paddleslim.nas.darts import count_parameters_in_MB
 import genotypes
 import reader
 from model import NetworkImageNet as Network
@@ -152,9 +154,6 @@ def main(args):
         if args.use_data_parallel else fluid.CUDAPlace(0)
     with fluid.dygraph.guard(place):
-        if args.use_data_parallel:
-            strategy = fluid.dygraph.parallel.prepare_context()
         genotype = eval("genotypes.%s" % args.arch)
         model = Network(
             C=args.init_channels,
@@ -163,7 +162,12 @@ def main(args):
             auxiliary=args.auxiliary,
             genotype=genotype)
-        step_per_epoch = int(args.trainset_num / args.batch_size)
+        logger.info("param size = {:.6f}MB".format(
+            count_parameters_in_MB(model.parameters())))
+
+        device_num = fluid.dygraph.parallel.Env().nranks
+        step_per_epoch = int(args.trainset_num /
+                             (args.batch_size * device_num))
         learning_rate = fluid.dygraph.ExponentialDecay(
             args.learning_rate,
             step_per_epoch,
@@ -179,6 +183,7 @@ def main(args):
             grad_clip=clip)
         if args.use_data_parallel:
+            strategy = fluid.dygraph.parallel.prepare_context()
             model = fluid.dygraph.parallel.DataParallel(model, strategy)
         train_loader = fluid.io.DataLoader.from_generator(
@@ -199,20 +204,19 @@ def main(args):
         valid_reader = fluid.io.batch(
             reader.imagenet_reader(args.data_dir, 'val'),
             batch_size=args.batch_size)
-        train_loader.set_sample_list_generator(train_reader, places=place)
-        valid_loader.set_sample_list_generator(valid_reader, places=place)
         if args.use_data_parallel:
             train_reader = fluid.contrib.reader.distributed_batch_reader(
                 train_reader)
+        train_loader.set_sample_list_generator(train_reader, places=place)
+        valid_loader.set_sample_list_generator(valid_reader, places=place)
         save_parameters = (not args.use_data_parallel) or (
             args.use_data_parallel and
             fluid.dygraph.parallel.Env().local_rank == 0)
         best_top1 = 0
         for epoch in range(args.epochs):
-            logging.info('Epoch {}, lr {:.6f}'.format(
+            logger.info('Epoch {}, lr {:.6f}'.format(
                 epoch, optimizer.current_step_lr()))
             train_top1, train_top5 = train(model, train_loader, optimizer,
                                            epoch, args)
......
@@ -97,7 +97,7 @@ DARTSearch
   model = SuperNet()
   train_reader = batch_generator_creator()
   valid_reader = batch_generator_creator()
-  searcher = DARTSearch(model, train_reader, valid_reader, unrolled=False)
+  searcher = DARTSearch(model, train_reader, valid_reader, place)
   searcher.train()
..
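Since `use_gpu` is gone from `DARTSearch` and the device place is now supplied by the caller, a fuller usage sketch under the new signature might look as follows (the `SuperNet` and `batch_generator_creator` names come from the docs example above and are placeholders; the place construction assumes the Paddle 1.8 dygraph API):

```python
import paddle.fluid as fluid
from paddleslim.nas.darts import DARTSearch

# Each process picks its own card; Env().dev_id is 0 in single-card runs.
place = fluid.CUDAPlace(fluid.dygraph.parallel.Env().dev_id)

with fluid.dygraph.guard(place):
    model = SuperNet()                        # placeholder super-network
    train_reader = batch_generator_creator()  # placeholder reader factory
    valid_reader = batch_generator_creator()
    searcher = DARTSearch(model, train_reader, valid_reader, place)
    searcher.train()
```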
-# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -21,7 +21,8 @@ from paddle.fluid.dygraph.base import to_variable
 class Architect(object):
-    def __init__(self, model, eta, arch_learning_rate, place, unrolled):
+    def __init__(self, model, eta, arch_learning_rate, place, unrolled,
+                 parallel):
         self.network_momentum = 0.9
         self.network_weight_decay = 3e-4
         self.eta = eta
@@ -34,6 +35,7 @@ class Architect(object):
             parameter_list=self.model.arch_parameters())
         self.place = place
         self.unrolled = unrolled
+        self.parallel = parallel
         if self.unrolled:
             self.unrolled_model = self.model.new()
             self.unrolled_model_params = [
@@ -49,6 +51,17 @@ class Architect(object):
                     self.network_weight_decay),
                 parameter_list=self.unrolled_model_params)
+        if self.parallel:
+            strategy = fluid.dygraph.parallel.prepare_context()
+            self.parallel_model = fluid.dygraph.parallel.DataParallel(
+                self.model, strategy)
+            if self.unrolled:
+                self.parallel_unrolled_model = fluid.dygraph.parallel.DataParallel(
+                    self.unrolled_model, strategy)
+
+    def get_model(self):
+        return self.parallel_model if self.parallel else self.model
+
     def step(self, input_train, target_train, input_valid, target_valid):
         if self.unrolled:
             params_grads = self._backward_step_unrolled(
@@ -61,7 +74,12 @@ class Architect(object):
     def _backward_step(self, input_valid, target_valid):
         loss = self.model._loss(input_valid, target_valid)
-        loss.backward()
+        if self.parallel:
+            loss = self.parallel_model.scale_loss(loss)
+            loss.backward()
+            self.parallel_model.apply_collective_grads()
+        else:
+            loss.backward()
         return loss
     def _backward_step_unrolled(self, input_train, target_train, input_valid,
@@ -69,7 +87,14 @@ class Architect(object):
         self._compute_unrolled_model(input_train, target_train)
         unrolled_loss = self.unrolled_model._loss(input_valid, target_valid)
-        unrolled_loss.backward()
+        if self.parallel:
+            unrolled_loss = self.parallel_unrolled_model.scale_loss(
+                unrolled_loss)
+            unrolled_loss.backward()
+            self.parallel_unrolled_model.apply_collective_grads()
+        else:
+            unrolled_loss.backward()
         vector = [
             to_variable(param._grad_ivar().numpy())
             for param in self.unrolled_model_params
@@ -93,7 +118,13 @@ class Architect(object):
                 self.model.parameters()):
             x.value().get_tensor().set(y.numpy(), self.place)
         loss = self.unrolled_model._loss(input, target)
-        loss.backward()
+        if self.parallel:
+            loss = self.parallel_unrolled_model.scale_loss(loss)
+            loss.backward()
+            self.parallel_unrolled_model.apply_collective_grads()
+        else:
+            loss.backward()
         self.unrolled_optimizer.minimize(loss)
         self.unrolled_model.clear_gradients()
@@ -112,7 +143,13 @@ class Architect(object):
             param_p = param + grad * R
             param.value().get_tensor().set(param_p.numpy(), self.place)
         loss = self.model._loss(input, target)
-        loss.backward()
+        if self.parallel:
+            loss = self.parallel_model.scale_loss(loss)
+            loss.backward()
+            self.parallel_model.apply_collective_grads()
+        else:
+            loss.backward()
         grads_p = [
             to_variable(param._grad_ivar().numpy())
             for param in self.model.arch_parameters()
@@ -124,7 +161,13 @@ class Architect(object):
         self.model.clear_gradients()
         loss = self.model._loss(input, target)
-        loss.backward()
+        if self.parallel:
+            loss = self.parallel_model.scale_loss(loss)
+            loss.backward()
+            self.parallel_model.apply_collective_grads()
+        else:
+            loss.backward()
         grads_n = [
             to_variable(param._grad_ivar().numpy())
             for param in self.model.arch_parameters()
......
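The same guarded backward sequence appears five times in the `Architect` diff above. Distilled into one place, the Paddle 1.x dygraph data-parallel recipe it follows looks like this (a hypothetical helper for illustration, not part of the commit; `scale_loss` and `apply_collective_grads` are the `DataParallel` methods used above):

```python
import paddle.fluid as fluid

def parallel_backward(loss, parallel_model=None):
    """Backward pass following the DataParallel recipe used in Architect.

    With a fluid.dygraph.parallel.DataParallel wrapper, the loss is first
    divided by the number of cards (scale_loss), and after backward the
    gradients are all-reduced across cards (apply_collective_grads).
    """
    if parallel_model is not None:
        loss = parallel_model.scale_loss(loss)
        loss.backward()
        parallel_model.apply_collective_grads()
    else:
        loss.backward()
```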
@@ -16,8 +16,9 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
-__all__ = ['DARTSearch']
+__all__ = ['DARTSearch', 'count_parameters_in_MB']
+import os
 import logging
 import numpy as np
 import paddle.fluid as fluid
@@ -67,19 +68,21 @@ class DARTSearch(object):
                  model,
                  train_reader,
                  valid_reader,
+                 place,
                  learning_rate=0.025,
                  batchsize=64,
                  num_imgs=50000,
                  arch_learning_rate=3e-4,
-                 unrolled='False',
+                 unrolled=False,
                  num_epochs=50,
                  epochs_no_archopt=0,
-                 use_gpu=True,
                  use_data_parallel=False,
+                 save_dir='./',
                  log_freq=50):
         self.model = model
         self.train_reader = train_reader
         self.valid_reader = valid_reader
+        self.place = place
         self.learning_rate = learning_rate
         self.batchsize = batchsize
         self.num_imgs = num_imgs
@@ -87,14 +90,8 @@ class DARTSearch(object):
         self.unrolled = unrolled
         self.epochs_no_archopt = epochs_no_archopt
         self.num_epochs = num_epochs
-        self.use_gpu = use_gpu
         self.use_data_parallel = use_data_parallel
-        if not self.use_gpu:
-            self.place = fluid.CPUPlace()
-        elif not self.use_data_parallel:
-            self.place = fluid.CUDAPlace(0)
-        else:
-            self.place = fluid.CUDAPlace(fluid.dygraph.parallel.Env().dev_id)
+        self.save_dir = save_dir
         self.log_freq = log_freq
     def train_one_epoch(self, train_loader, valid_loader, architect, optimizer,
@@ -187,9 +184,13 @@ class DARTSearch(object):
         ]
         logger.info("param size = {:.6f}MB".format(
             count_parameters_in_MB(model_parameters)))
-        step_per_epoch = int(self.num_imgs * 0.5 / self.batchsize)
+
+        device_num = fluid.dygraph.parallel.Env().nranks
+        step_per_epoch = int(self.num_imgs * 0.5 /
+                             (self.batchsize * device_num))
         if self.unrolled:
             step_per_epoch *= 2
         learning_rate = fluid.dygraph.CosineDecay(
             self.learning_rate, step_per_epoch, self.num_epochs)
@@ -202,30 +203,37 @@ class DARTSearch(object):
             grad_clip=clip)
         if self.use_data_parallel:
-            self.model = fluid.dygraph.parallel.DataParallel(self.model,
-                                                             strategy)
             self.train_reader = fluid.contrib.reader.distributed_batch_reader(
                 self.train_reader)
             self.valid_reader = fluid.contrib.reader.distributed_batch_reader(
                 self.valid_reader)
         train_loader = fluid.io.DataLoader.from_generator(
-            capacity=64,
+            capacity=1024,
             use_double_buffer=True,
             iterable=True,
-            return_list=True)
+            return_list=True,
+            use_multiprocess=True)
         valid_loader = fluid.io.DataLoader.from_generator(
-            capacity=64,
+            capacity=1024,
             use_double_buffer=True,
             iterable=True,
-            return_list=True)
+            return_list=True,
+            use_multiprocess=True)
         train_loader.set_batch_generator(self.train_reader, places=self.place)
         valid_loader.set_batch_generator(self.valid_reader, places=self.place)
-        architect = Architect(self.model, learning_rate,
-                              self.arch_learning_rate, self.place,
-                              self.unrolled)
+        base_model = self.model
+        architect = Architect(
+            model=self.model,
+            eta=learning_rate,
+            arch_learning_rate=self.arch_learning_rate,
+            place=self.place,
+            unrolled=self.unrolled,
+            parallel=self.use_data_parallel)
+        self.model = architect.get_model()
         save_parameters = (not self.use_data_parallel) or (
             self.use_data_parallel and
@@ -234,7 +242,8 @@ class DARTSearch(object):
         for epoch in range(self.num_epochs):
             logger.info('Epoch {}, lr {:.6f}'.format(
                 epoch, optimizer.current_step_lr()))
-            genotype = get_genotype(self.model)
+
+            genotype = get_genotype(base_model)
             logger.info('genotype = %s', genotype)
             train_top1 = self.train_one_epoch(train_loader, valid_loader,
@@ -246,4 +255,6 @@ class DARTSearch(object):
             logger.info("Epoch {}, valid_acc {:.6f}".format(epoch,
                                                             valid_top1))
             if save_parameters:
-                fluid.save_dygraph(self.model.state_dict(), "./weights")
+                fluid.save_dygraph(
+                    self.model.state_dict(),
+                    os.path.join(self.save_dir, str(epoch), "params"))
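For completeness, a checkpoint written with the new `save_dir`/epoch layout above should be loadable with the matching `fluid.load_dygraph` call. A sketch, assuming `model` is the super-network `Layer` instance and using the same path layout:

```python
import os
import paddle.fluid as fluid

save_dir, epoch = './', 49  # mirror DARTSearch's save_dir and a chosen epoch

# The hunk above writes parameters to <save_dir>/<epoch>/params.
param_path = os.path.join(save_dir, str(epoch), "params")
param_state, _ = fluid.load_dygraph(param_path)  # (parameters, optimizer state)
model.set_dict(param_state)  # `model`: the super-network Layer (assumed)
```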