# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
import logging
import numpy as np
import paddle.fluid as fluid
from paddle.fluid.dygraph.base import to_variable
from ...common import AvgrageMeter, get_logger
from .architect import Architect
from .get_genotype import get_genotype

__all__ = ['DARTSearch', 'count_parameters_in_MB']

logger = get_logger(__name__, level=logging.INFO)


def count_parameters_in_MB(all_params):
    """Count the parameters in the target list.
    Args:
        all_params(list): List of Variables.

    Returns:
        float: The total count(MB) of target parameter list.
    """

    parameters_number = 0
    for param in all_params:
        if param.trainable and 'aux' not in param.name:
            parameters_number += np.prod(param.shape)
    return parameters_number / 1e6
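
# A minimal usage sketch (hypothetical; `SuperNetwork` is a placeholder for
# any fluid.dygraph.Layer):
#
#     net = SuperNetwork()
#     logger.info("param size = {:.6f}MB".format(
#         count_parameters_in_MB(net.parameters())))
#
# Only trainable parameters are counted, and parameters whose names contain
# 'aux' (e.g. auxiliary classifier heads) are excluded.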


class DARTSearch(object):
    """Used for Differentiable ARchiTecture Search(DARTS)

    Args:
        model(Paddle DyGraph model): Super Network for Search.
        train_reader(Python Generator): Generator to provide training data.
        valid_reader(Python Generator): Generator to provide validation  data.
        place(fluid.CPUPlace()|fluid.CUDAPlace(N)): This parameter represents the executor run on which device.
        learning_rate(float): Model parameter initial learning rate. Default: 0.025.
        batch_size(int): Minibatch size. Default: 64.
        arch_learning_rate(float): Learning rate for arch encoding. Default: 3e-4.
        unrolled(bool): Use one-step unrolled validation loss. Default: False.
        num_epochs(int): Epoch number. Default: 50.
        epochs_no_archopt(int): Epochs skip architecture optimize at begining. Default: 0.
B
Bai Yifan 已提交
62
        use_multiprocess(bool): Whether to use multiprocess in dataloader. Default: False.
B
Bai Yifan 已提交
63 64 65 66 67
        use_data_parallel(bool): Whether to use data parallel mode. Default: False.
        log_freq(int): Log frequency. Default: 50.

    """

    def __init__(self,
                 model,
                 train_reader,
                 valid_reader,
                 place,
                 learning_rate=0.025,
                 batchsize=64,
                 num_imgs=50000,
                 arch_learning_rate=3e-4,
                 unrolled=False,
                 num_epochs=50,
                 epochs_no_archopt=0,
                 use_multiprocess=False,
                 use_data_parallel=False,
                 save_dir='./',
                 log_freq=50):
        self.model = model
        self.train_reader = train_reader
        self.valid_reader = valid_reader
        self.place = place
        self.learning_rate = learning_rate
        self.batchsize = batchsize
        self.num_imgs = num_imgs
        self.arch_learning_rate = arch_learning_rate
        self.unrolled = unrolled
        self.epochs_no_archopt = epochs_no_archopt
        self.num_epochs = num_epochs
        self.use_multiprocess = use_multiprocess
        self.use_data_parallel = use_data_parallel
        self.save_dir = save_dir
        self.log_freq = log_freq

    def train_one_epoch(self, train_loader, valid_loader, architect, optimizer,
                        epoch):
        objs = AvgrageMeter()
        ce_losses = AvgrageMeter()
        kd_losses = AvgrageMeter()
        e_losses = AvgrageMeter()
        self.model.train()

        step_id = 0
        for train_data, valid_data in zip(train_loader(), valid_loader()):
            # Architecture optimization step; skipped for the first
            # epochs_no_archopt epochs.
            if epoch >= self.epochs_no_archopt:
                architect.step(train_data, valid_data)

            loss, ce_loss, kd_loss, e_loss = self.model.loss(train_data)

            if self.use_data_parallel:
                loss = self.model.scale_loss(loss)
                loss.backward()
                self.model.apply_collective_grads()
            else:
                loss.backward()

            optimizer.minimize(loss)
            self.model.clear_gradients()

            batch_size = train_data[0].shape[0]
            objs.update(loss.numpy(), batch_size)
            ce_losses.update(ce_loss.numpy(), batch_size)
            kd_losses.update(kd_loss.numpy(), batch_size)
            e_losses.update(e_loss.numpy(), batch_size)

            if step_id % self.log_freq == 0:
                logger.info(
                    "Train Epoch {}, Step {}, loss {}; ce: {}; kd: {}; e: {}".
                    format(epoch, step_id,
                           loss.numpy(),
                           ce_loss.numpy(), kd_loss.numpy(), e_loss.numpy()))
            step_id += 1
        return objs.avg[0]
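
    # Note: train_one_epoch assumes the super network implements a method like
    #
    #     def loss(self, batch_data):
    #         ...
    #         return total_loss, ce_loss, kd_loss, e_loss
    #
    # i.e. it returns a 4-tuple of scalar loss Variables. The naming (total,
    # cross-entropy, knowledge-distillation, plus an extra component) follows
    # the unpacking in the loop above and is an assumption; the exact
    # decomposition is defined by the model.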

    def valid_one_epoch(self, valid_loader, epoch):
        objs = AvgrageMeter()
        top1 = AvgrageMeter()
        top5 = AvgrageMeter()
        self.model.eval()

        for step_id, valid_data in enumerate(valid_loader):
            image, label = valid_data
            n = image.shape[0]
            logits = self.model(image)
            prec1 = fluid.layers.accuracy(input=logits, label=label, k=1)
            prec5 = fluid.layers.accuracy(input=logits, label=label, k=5)
            loss = fluid.layers.reduce_mean(
                fluid.layers.softmax_with_cross_entropy(logits, label))
            objs.update(loss.numpy(), n)
            top1.update(prec1.numpy(), n)
            top5.update(prec5.numpy(), n)

            if step_id % self.log_freq == 0:
                logger.info(
                    "Valid Epoch {}, Step {}, loss {:.6f}, acc_1 {:.6f}, acc_5 {:.6f}".
                    format(epoch, step_id, objs.avg[0], top1.avg[0], top5.avg[
                        0]))
        return top1.avg[0]

    def train(self):
        """Start search process.

        """

        model_parameters = [
            p for p in self.model.parameters()
            if p.name not in [a.name for a in self.model.arch_parameters()]
        ]
        logger.info("param size = {:.6f}MB".format(
            count_parameters_in_MB(model_parameters)))

        device_num = fluid.dygraph.parallel.Env().nranks
        # DARTS trains the weights on half of the data and uses the other half
        # as validation data for the architecture step, hence the 0.5 factor.
        step_per_epoch = int(self.num_imgs * 0.5 /
                             (self.batchsize * device_num))
        if self.unrolled:
            step_per_epoch *= 2

        learning_rate = fluid.dygraph.CosineDecay(
            self.learning_rate, step_per_epoch, self.num_epochs)

        clip = fluid.clip.GradientClipByGlobalNorm(clip_norm=5.0)
        optimizer = fluid.optimizer.MomentumOptimizer(
            learning_rate,
            0.9,
            regularization=fluid.regularizer.L2DecayRegularizer(3e-4),
            parameter_list=model_parameters,
            grad_clip=clip)

        if self.use_data_parallel:
            self.train_reader = fluid.contrib.reader.distributed_batch_reader(
                self.train_reader)

        train_loader = fluid.io.DataLoader.from_generator(
            capacity=64,
            use_double_buffer=True,
            iterable=True,
            return_list=True,
            use_multiprocess=self.use_multiprocess)
        valid_loader = fluid.io.DataLoader.from_generator(
            capacity=64,
            use_double_buffer=True,
            iterable=True,
            return_list=True,
            use_multiprocess=self.use_multiprocess)

        train_loader.set_batch_generator(self.train_reader, places=self.place)
        valid_loader.set_batch_generator(self.valid_reader, places=self.place)

        base_model = self.model
        architect = Architect(
            model=self.model,
            eta=learning_rate,
            arch_learning_rate=self.arch_learning_rate,
            unrolled=self.unrolled,
            parallel=self.use_data_parallel)

        self.model = architect.get_model()

        save_parameters = (not self.use_data_parallel) or (
            self.use_data_parallel and
            fluid.dygraph.parallel.Env().local_rank == 0)

        for epoch in range(self.num_epochs):
            logger.info('Epoch {}, lr {:.6f}'.format(
                epoch, optimizer.current_step_lr()))

            genotype = get_genotype(base_model)
            logger.info('genotype = %s', genotype)

            self.train_one_epoch(train_loader, valid_loader, architect,
                                 optimizer, epoch)

            if epoch == self.num_epochs - 1:
                valid_top1 = self.valid_one_epoch(valid_loader, epoch)
                logger.info("Epoch {}, valid_acc {:.6f}".format(
                    epoch, valid_top1))
            if save_parameters:
                fluid.save_dygraph(
                    self.model.state_dict(),
                    os.path.join(self.save_dir, str(epoch), "params"))
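
# A minimal end-to-end usage sketch (hypothetical: `SuperNetwork`,
# `train_batch_reader` and `valid_batch_reader` are placeholders for a real
# super network and batch generators yielding (image, label) batches):
#
#     place = fluid.CUDAPlace(0)
#     with fluid.dygraph.guard(place):
#         model = SuperNetwork()
#         searcher = DARTSearch(
#             model,
#             train_reader=train_batch_reader,
#             valid_reader=valid_batch_reader,
#             place=place,
#             batchsize=64,
#             num_epochs=50,
#             unrolled=False)
#         searcher.train()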