train_search.py
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

__all__ = ['DARTSearch', 'count_parameters_in_MB']

import os
import logging
import numpy as np
import paddle.fluid as fluid
from paddle.fluid.dygraph.base import to_variable
from ...common import AvgrageMeter, get_logger
from .architect import Architect
from .get_genotype import get_genotype
logger = get_logger(__name__, level=logging.INFO)


def count_parameters_in_MB(all_params):
    """Count the parameters in the target list.
    Args:
        all_params(list): List of Variables.

    Returns:
        float: The total count(MB) of target parameter list.
    """

    parameters_number = 0
    for param in all_params:
        if param.trainable and 'aux' not in param.name:
            parameters_number += np.prod(param.shape)
    return parameters_number / 1e6
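
# Illustrative usage of count_parameters_in_MB (a sketch, not part of the
# original file; `net` stands for any fluid.dygraph.Layer with parameters):
#
#     size_in_millions = count_parameters_in_MB(net.parameters())
#     logger.info("param size = {:.6f}MB".format(size_in_millions))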


class DARTSearch(object):
    """Used for Differentiable ARchiTecture Search(DARTS)

    Args:
        model(Paddle DyGraph model): Super Network for Search.
        train_reader(Python Generator): Generator to provide training data.
        valid_reader(Python Generator): Generator to provide validation  data.
        place(fluid.CPUPlace()|fluid.CUDAPlace(N)): This parameter represents the executor run on which device.
        learning_rate(float): Model parameter initial learning rate. Default: 0.025.
        batch_size(int): Minibatch size. Default: 64.
        arch_learning_rate(float): Learning rate for arch encoding. Default: 3e-4.
        unrolled(bool): Use one-step unrolled validation loss. Default: False.
        num_epochs(int): Epoch number. Default: 50.
        epochs_no_archopt(int): Epochs skip architecture optimize at begining. Default: 0.
B
Bai Yifan 已提交
62
        use_multiprocess(bool): Whether to use multiprocess in dataloader. Default: False.
B
Bai Yifan 已提交
63 64 65 66 67
        use_data_parallel(bool): Whether to use data parallel mode. Default: False.
        log_freq(int): Log frequency. Default: 50.

    """

    def __init__(self,
                 model,
                 train_reader,
                 valid_reader,
                 place,
                 learning_rate=0.025,
                 batchsize=64,
                 num_imgs=50000,
                 arch_learning_rate=3e-4,
                 unrolled=False,
                 num_epochs=50,
                 epochs_no_archopt=0,
                 use_multiprocess=False,
                 use_data_parallel=False,
                 save_dir='./',
                 log_freq=50):
        self.model = model
        self.train_reader = train_reader
        self.valid_reader = valid_reader
        self.place = place
        self.learning_rate = learning_rate
        self.batchsize = batchsize
        self.num_imgs = num_imgs
        self.arch_learning_rate = arch_learning_rate
        self.unrolled = unrolled
        self.epochs_no_archopt = epochs_no_archopt
        self.num_epochs = num_epochs
        self.use_multiprocess = use_multiprocess
        self.use_data_parallel = use_data_parallel
        self.save_dir = save_dir
        self.log_freq = log_freq

    def train_one_epoch(self, train_loader, valid_loader, architect, optimizer,
                        epoch):
        objs = AvgrageMeter()
        top1 = AvgrageMeter()
        top5 = AvgrageMeter()
        self.model.train()

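        # Each search step consumes one training batch and one validation
        # batch: the validation batch drives the architecture update in
        # architect.step, the training batch drives the weight update.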
        for step_id, (
                train_data,
                valid_data) in enumerate(zip(train_loader(), valid_loader())):
            train_image, train_label = train_data
            valid_image, valid_label = valid_data
            train_image = to_variable(train_image)
            train_label = to_variable(train_label)
            train_label.stop_gradient = True
            valid_image = to_variable(valid_image)
            valid_label = to_variable(valid_label)
            valid_label.stop_gradient = True
            n = train_image.shape[0]

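            # Architecture optimization is skipped for the first
            # `epochs_no_archopt` epochs so the weights can warm up first.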
            if epoch >= self.epochs_no_archopt:
                architect.step(train_image, train_label, valid_image,
                               valid_label)

            logits = self.model(train_image)
            prec1 = fluid.layers.accuracy(input=logits, label=train_label, k=1)
            prec5 = fluid.layers.accuracy(input=logits, label=train_label, k=5)
            loss = fluid.layers.reduce_mean(
                fluid.layers.softmax_with_cross_entropy(logits, train_label))

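            # In data parallel mode, scale the loss before backward and
            # average (all-reduce) the gradients across trainers afterwards.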
            if self.use_data_parallel:
                loss = self.model.scale_loss(loss)
                loss.backward()
                self.model.apply_collective_grads()
            else:
                loss.backward()

            optimizer.minimize(loss)
            self.model.clear_gradients()

            objs.update(loss.numpy(), n)
            top1.update(prec1.numpy(), n)
            top5.update(prec5.numpy(), n)

            if step_id % self.log_freq == 0:
                logger.info(
                    "Train Epoch {}, Step {}, loss {:.6f}, acc_1 {:.6f}, acc_5 {:.6f}".
                    format(epoch, step_id, objs.avg[0], top1.avg[0], top5.avg[
                        0]))
        return top1.avg[0]

    def valid_one_epoch(self, valid_loader, epoch):
        objs = AvgrageMeter()
        top1 = AvgrageMeter()
        top5 = AvgrageMeter()
        self.model.eval()

        for step_id, (image, label) in enumerate(valid_loader):
            image = to_variable(image)
            label = to_variable(label)
            n = image.shape[0]
            logits = self.model(image)
            prec1 = fluid.layers.accuracy(input=logits, label=label, k=1)
            prec5 = fluid.layers.accuracy(input=logits, label=label, k=5)
            loss = fluid.layers.reduce_mean(
                fluid.layers.softmax_with_cross_entropy(logits, label))
            objs.update(loss.numpy(), n)
            top1.update(prec1.numpy(), n)
            top5.update(prec5.numpy(), n)

            if step_id % self.log_freq == 0:
                logger.info(
                    "Valid Epoch {}, Step {}, loss {:.6f}, acc_1 {:.6f}, acc_5 {:.6f}".
                    format(epoch, step_id, objs.avg[0], top1.avg[0], top5.avg[
                        0]))
        return top1.avg[0]

    def train(self):
        """Start search process.

        """

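        # Exclude the architecture parameters (the alphas) from the weight
        # optimizer; they are updated separately through the Architect.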
        model_parameters = [
            p for p in self.model.parameters()
            if p.name not in [a.name for a in self.model.arch_parameters()]
        ]
        logger.info("param size = {:.6f}MB".format(
            count_parameters_in_MB(model_parameters)))

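        # DARTS uses half of the training images for weight updates and the
        # other half for architecture updates, hence the factor 0.5.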
        device_num = fluid.dygraph.parallel.Env().nranks
        step_per_epoch = int(self.num_imgs * 0.5 /
                             (self.batchsize * device_num))
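        # The one-step unrolled objective performs an extra virtual weight
        # update per step, so the decay schedule is given twice as many steps.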
        if self.unrolled:
            step_per_epoch *= 2

        learning_rate = fluid.dygraph.CosineDecay(
            self.learning_rate, step_per_epoch, self.num_epochs)

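        # Weight optimizer as in DARTS: momentum SGD with cosine-decayed
        # learning rate, L2 weight decay and global-norm gradient clipping.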
        clip = fluid.clip.GradientClipByGlobalNorm(clip_norm=5.0)
        optimizer = fluid.optimizer.MomentumOptimizer(
            learning_rate,
            0.9,
            regularization=fluid.regularizer.L2DecayRegularizer(3e-4),
            parameter_list=model_parameters,
            grad_clip=clip)

        if self.use_data_parallel:
            self.train_reader = fluid.contrib.reader.distributed_batch_reader(
                self.train_reader)

        train_loader = fluid.io.DataLoader.from_generator(
            capacity=64,
            use_double_buffer=True,
            iterable=True,
            return_list=True,
            use_multiprocess=self.use_multiprocess)
        valid_loader = fluid.io.DataLoader.from_generator(
            capacity=64,
            use_double_buffer=True,
            iterable=True,
            return_list=True,
            use_multiprocess=self.use_multiprocess)

        train_loader.set_batch_generator(self.train_reader, places=self.place)
        valid_loader.set_batch_generator(self.valid_reader, places=self.place)

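        # The Architect encapsulates the architecture-parameter update and may
        # wrap the model (e.g. for data parallelism); keep a handle to the
        # unwrapped network for genotype extraction and fetch the possibly
        # wrapped one back with get_model().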
        base_model = self.model
        architect = Architect(
            model=self.model,
            eta=learning_rate,
            arch_learning_rate=self.arch_learning_rate,
            unrolled=self.unrolled,
            parallel=self.use_data_parallel)

        self.model = architect.get_model()

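        # In data parallel mode only the trainer with local_rank 0 saves
        # parameters, to avoid every process writing the same checkpoint.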
        save_parameters = (not self.use_data_parallel) or (
            self.use_data_parallel and
            fluid.dygraph.parallel.Env().local_rank == 0)

        for epoch in range(self.num_epochs):
            logger.info('Epoch {}, lr {:.6f}'.format(
                epoch, optimizer.current_step_lr()))

            genotype = get_genotype(base_model)
            logger.info('genotype = %s', genotype)

            train_top1 = self.train_one_epoch(train_loader, valid_loader,
                                              architect, optimizer, epoch)
            logger.info("Epoch {}, train_acc {:.6f}".format(epoch, train_top1))

            if epoch == self.num_epochs - 1:
                valid_top1 = self.valid_one_epoch(valid_loader, epoch)
                logger.info("Epoch {}, valid_acc {:.6f}".format(epoch,
                                                                valid_top1))
            if save_parameters:
                fluid.save_dygraph(
                    self.model.state_dict(),
                    os.path.join(self.save_dir, str(epoch), "params"))
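
# ---------------------------------------------------------------------------
# Minimal usage sketch (illustration only, not part of the original file).
# `SuperNetwork`, `train_batches` and `valid_batches` are hypothetical
# placeholders for a user-defined super network and batch generators that
# yield (image, label) numpy batches.
#
#     import paddle.fluid as fluid
#
#     place = fluid.CUDAPlace(0)
#     with fluid.dygraph.guard(place):
#         model = SuperNetwork()
#         searcher = DARTSearch(
#             model,
#             train_reader=train_batches,
#             valid_reader=valid_batches,
#             place=place,
#             batchsize=64,
#             num_epochs=50)
#         searcher.train()
# ---------------------------------------------------------------------------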