# sa_nas_mobilenetv2.py
# SA-NAS (paddleslim.nas.SANAS) architecture search over the MobileNetV2
# search space on CIFAR-10 or ImageNet.
# Author (per repository blame): iamWHTWD.
import sys
sys.path.append('..')
import numpy as np
import argparse
import ast
import time
import argparse
import ast
import logging
import paddle
import paddle.nn as nn
import paddle.static as static
import paddle.nn.functional as F
import paddle.vision.transforms as T
from paddleslim.analysis import flops
from paddleslim.nas import SANAS
from paddleslim.common import get_logger
from optimizer import create_optimizer
import imagenet_reader

# Module-level logger used for the TRAIN/TEST progress lines in this file;
# level INFO so those `_logger.info(...)` calls are emitted.
_logger = get_logger(__name__, level=logging.INFO)


def build_program(main_program,
                  startup_program,
                  image_shape,
                  dataset,
                  archs,
                  args,
                  places,
                  is_test=False):
    """Build the classification graph for one candidate architecture.

    Inside ``main_program`` (initializers go to ``startup_program``) this adds
    the ``data``/``label`` inputs, a DataLoader over ``dataset`` feeding them,
    the backbone returned by ``archs`` followed by a fully connected layer with
    ``args.class_dim`` outputs, the cross-entropy loss and top-1 / top-5
    accuracy. For training graphs an optimizer from ``create_optimizer(args)``
    is attached to minimize the loss.

    Args:
        main_program: static ``Program`` receiving the network ops.
        startup_program: static ``Program`` receiving parameter initializers.
        image_shape: per-sample input shape; a leading ``None`` batch
            dimension is prepended.
        dataset: dataset whose samples feed ``[data, label]``.
        archs: callable mapping the input variable to the backbone output.
        args: parsed arguments; reads ``data``, ``batch_size``, ``class_dim``
            (plus whatever ``create_optimizer`` reads for training graphs).
        places: place(s) handed to the DataLoader.
        is_test: if True build an evaluation graph (no shuffle, keep the last
            partial batch, no optimizer).

    Returns:
        ``(data_loader, avg_cost, acc_top1, acc_top5)``.
    """
    with static.program_guard(main_program, startup_program):
        with paddle.utils.unique_name.guard():
            data_shape = [None] + image_shape
            data = static.data(name='data', shape=data_shape, dtype='float32')
            label = static.data(name='label', shape=[None, 1], dtype='int64')
            if args.data == 'cifar10':
                # Force the label into an [N, 1] column.
                paddle.assign(paddle.reshape(label, [-1, 1]), label)
            if is_test:
                data_loader = paddle.io.DataLoader(
                    dataset,
                    places=places,
                    feed_list=[data, label],
                    drop_last=False,
                    batch_size=args.batch_size,
                    return_list=False,
                    shuffle=False)
            else:
                data_loader = paddle.io.DataLoader(
                    dataset,
                    places=places,
                    feed_list=[data, label],
                    drop_last=True,
                    batch_size=args.batch_size,
                    return_list=False,
                    shuffle=True,
                    use_shared_memory=True,
                    num_workers=4)
            output = archs(data)
            logits = static.nn.fc(x=output, size=args.class_dim)

            softmax_out = F.softmax(logits)
            # F.cross_entropy applies softmax itself (use_softmax=True by
            # default), so it must get the raw logits; passing softmax_out
            # would apply softmax twice and squash the loss and gradients.
            cost = F.cross_entropy(logits, label=label)
            avg_cost = paddle.mean(cost)
            acc_top1 = paddle.metric.accuracy(
                input=softmax_out, label=label, k=1)
            acc_top5 = paddle.metric.accuracy(
                input=softmax_out, label=label, k=5)

            if not is_test:
                optimizer = create_optimizer(args)
                optimizer.minimize(avg_cost)
    return data_loader, avg_cost, acc_top1, acc_top5


def search_mobilenetv2(config,
                       args,
                       image_size,
                       is_server=True,
                       max_flops=321208544):
    """Run the SANAS search loop over the configured search space.

    Each of ``args.search_steps`` steps asks the controller for a candidate
    architecture. A candidate whose ``flops(train_program)`` exceeds
    ``max_flops`` is skipped. Any other candidate is:

    * trained for ``args.retain_epoch`` epochs;
    * evaluated on the validation set;
    * scored by its mean top-1 accuracy, which is reported to the controller
      via ``sa_nas.reward``.

    Args:
        config: search-space configuration passed to ``SANAS``.
        args: parsed command line arguments.
        image_size: input height and width; inputs are
            ``[3, image_size, image_size]``.
        is_server: start the SANAS controller server in this process (truthy)
            or only connect to an existing one as a client (falsy).
        max_flops: FLOPs budget; candidates above it are skipped without
            training or reward. Defaults to the previously hard-coded limit.

    Raises:
        NotImplementedError: if ``args.data`` is neither 'cifar10' nor
            'imagenet'.
    """
    image_shape = [3, image_size, image_size]
    if args.data == 'cifar10':
        transform = T.Compose([T.Transpose(), T.Normalize([127.5], [127.5])])
        train_dataset = paddle.vision.datasets.Cifar10(
            mode='train', transform=transform, backend='cv2')
        val_dataset = paddle.vision.datasets.Cifar10(
            mode='test', transform=transform, backend='cv2')
    elif args.data == 'imagenet':
        train_dataset = imagenet_reader.ImageNetDataset(mode='train')
        val_dataset = imagenet_reader.ImageNetDataset(mode='val')
    else:
        raise NotImplementedError(
            'data must in [cifar10, imagenet], but received: {}'.format(
                args.data))

    places = static.cuda_places() if args.use_gpu else static.cpu_places()
    place = places[0]
    # With is_server the process starts a server and a client; otherwise it
    # only starts a client.
    sa_nas = SANAS(
        config,
        server_addr=(args.server_address, args.port),
        search_steps=args.search_steps,
        is_server=bool(is_server))

    for step in range(args.search_steps):
        archs = sa_nas.next_archs()[0]

        train_program = static.Program()
        test_program = static.Program()
        startup_program = static.Program()
        train_loader, avg_cost, acc_top1, acc_top5 = build_program(
            train_program, startup_program, image_shape, train_dataset, archs,
            args, places)

        current_flops = flops(train_program)
        print('step: {}, current_flops: {}'.format(step, current_flops))
        if current_flops > max_flops:
            # Over budget: skip this candidate (no training, no reward).
            continue

        test_loader, test_avg_cost, test_acc_top1, test_acc_top5 = build_program(
            test_program,
            startup_program,
            image_shape,
            val_dataset,
            archs,
            args,
            place,
            is_test=True)
        test_program = test_program.clone(for_test=True)

        exe = static.Executor(place)
        exe.run(startup_program)

        build_strategy = static.BuildStrategy()
        train_compiled_program = static.CompiledProgram(
            train_program, build_strategy=build_strategy)

        fetches = [avg_cost.name]
        for epoch_id in range(args.retain_epoch):
            for batch_id, data in enumerate(train_loader()):
                s_time = time.time()
                outs = exe.run(
                    train_compiled_program, feed=data[0], fetch_list=fetches)
                # time.time() is in seconds; the log line reports milliseconds.
                batch_time = (time.time() - s_time) * 1000
                if batch_id % 10 == 0:
                    _logger.info(
                        'TRAIN: steps: {}, epoch: {}, batch: {}, cost: {}, batch_time: {}ms'.
                        format(step, epoch_id, batch_id, outs, batch_time))

        test_fetches = [
            test_avg_cost.name, test_acc_top1.name, test_acc_top5.name
        ]
        reward = []
        for batch_id, data in enumerate(test_loader()):
            batch_reward = exe.run(
                test_program, feed=data, fetch_list=test_fetches)
            # Reduce each fetched metric to a scalar; works whether the
            # fetched arrays have shape [1] or are 0-D.
            reward_avg = np.array([np.mean(r) for r in batch_reward])
            reward.append(reward_avg)

            _logger.info(
                'TEST: step: {}, batch: {}, avg_cost: {}, acc_top1: {}, acc_top5: {}'.
                format(step, batch_id, batch_reward[0], batch_reward[1],
                       batch_reward[2]))

        # Columns: [avg_cost, acc_top1, acc_top5] averaged over test batches.
        finally_reward = np.mean(np.array(reward), axis=0)
        _logger.info(
            'FINAL TEST: avg_cost: {}, acc_top1: {}, acc_top5: {}'.format(
                finally_reward[0], finally_reward[1], finally_reward[2]))

        # Top-1 accuracy is the search reward.
        sa_nas.reward(float(finally_reward[1]))


def test_search_result(tokens, image_size, args, config):
    """Retrain and evaluate the architecture encoded by ``tokens``.

    A ``SANAS`` instance (started with ``is_server=True``) decodes
    ``tokens`` into an architecture via ``tokens2arch``. The model is then
    trained for ``args.retain_epoch`` epochs and evaluated on the validation
    set after every epoch. Results are only logged; no reward is reported.

    Args:
        tokens: architecture tokens understood by ``SANAS.tokens2arch``.
        image_size: input height and width; inputs are
            ``[3, image_size, image_size]``.
        args: parsed command line arguments.
        config: search-space configuration passed to ``SANAS``.

    Raises:
        NotImplementedError: if ``args.data`` is neither 'cifar10' nor
            'imagenet'.
    """
    places = static.cuda_places() if args.use_gpu else static.cpu_places()
    place = places[0]

    sa_nas = SANAS(
        config,
        server_addr=(args.server_address, args.port),
        search_steps=args.search_steps,
        is_server=True)

    image_shape = [3, image_size, image_size]
    if args.data == 'cifar10':
        transform = T.Compose([T.Transpose(), T.Normalize([127.5], [127.5])])
        train_dataset = paddle.vision.datasets.Cifar10(
            mode='train', transform=transform, backend='cv2')
        val_dataset = paddle.vision.datasets.Cifar10(
            mode='test', transform=transform, backend='cv2')
    elif args.data == 'imagenet':
        train_dataset = imagenet_reader.ImageNetDataset(mode='train')
        val_dataset = imagenet_reader.ImageNetDataset(mode='val')
    else:
        raise NotImplementedError(
            'data must in [cifar10, imagenet], but received: {}'.format(
                args.data))

    archs = sa_nas.tokens2arch(tokens)[0]

    train_program = static.Program()
    test_program = static.Program()
    startup_program = static.Program()
    train_loader, avg_cost, acc_top1, acc_top5 = build_program(
        train_program, startup_program, image_shape, train_dataset, archs, args,
        places)

    current_flops = flops(train_program)
    print('current_flops: {}'.format(current_flops))
    test_loader, test_avg_cost, test_acc_top1, test_acc_top5 = build_program(
        test_program,
        startup_program,
        image_shape,
        val_dataset,
        archs,
        args,
        place,
        is_test=True)

    test_program = test_program.clone(for_test=True)

    exe = static.Executor(place)
    exe.run(startup_program)

    build_strategy = static.BuildStrategy()
    train_compiled_program = static.CompiledProgram(
        train_program, build_strategy=build_strategy)

    fetches = [avg_cost.name]
    test_fetches = [
        test_avg_cost.name, test_acc_top1.name, test_acc_top5.name
    ]
    for epoch_id in range(args.retain_epoch):
        for batch_id, data in enumerate(train_loader()):
            s_time = time.time()
            outs = exe.run(
                train_compiled_program, feed=data, fetch_list=fetches)[0]
            # time.time() is in seconds; the log line reports milliseconds.
            batch_time = (time.time() - s_time) * 1000
            if batch_id % 10 == 0:
                # np.mean yields the scalar loss for both [1]-shaped and 0-D
                # fetch results (indexing a 0-D array would raise).
                _logger.info(
                    'TRAIN: epoch: {}, batch: {}, cost: {}, batch_time: {}ms'.
                    format(epoch_id, batch_id, np.mean(outs), batch_time))

        # Evaluate after every epoch.
        reward = []
        for batch_id, data in enumerate(test_loader()):
            batch_reward = exe.run(
                test_program, feed=data, fetch_list=test_fetches)
            # Reduce each fetched metric to a scalar; works whether the
            # fetched arrays have shape [1] or are 0-D.
            reward_avg = np.array([np.mean(r) for r in batch_reward])
            reward.append(reward_avg)

            _logger.info(
                'TEST: batch: {}, avg_cost: {}, acc_top1: {}, acc_top5: {}'.
                format(batch_id, batch_reward[0], batch_reward[1], batch_reward[
                    2]))

        # Columns: [avg_cost, acc_top1, acc_top5] averaged over test batches.
        finally_reward = np.mean(np.array(reward), axis=0)
        _logger.info(
            'FINAL TEST: avg_cost: {}, acc_top1: {}, acc_top5: {}'.format(
                finally_reward[0], finally_reward[1], finally_reward[2]))


if __name__ == '__main__':

    parser = argparse.ArgumentParser(
        description='SA NAS MobileNetV2 search on cifar10 / imagenet')
    parser.add_argument(
        '--use_gpu',
        type=ast.literal_eval,
        default=True,
        help='Whether to use GPU in train/test model.')
    parser.add_argument(
        '--batch_size', type=int, default=256, help='batch size.')
    parser.add_argument(
        '--class_dim',
        type=int,
        default=None,
        help='classify number. Defaults to 10 for cifar10 and 1000 for '
        'imagenet.')
    parser.add_argument(
        '--data',
        type=str,
        default='cifar10',
        choices=['cifar10', 'imagenet'],
        help='dataset to search on.')
    parser.add_argument(
        '--is_server',
        type=ast.literal_eval,
        default=True,
        help='Whether to start a server.')
    parser.add_argument(
        '--search_steps',
        type=int,
        default=100,
        help='number of search steps.')
    parser.add_argument(
        '--server_address', type=str, default="", help='server ip.')
    parser.add_argument('--port', type=int, default=8881, help='server port')
    parser.add_argument(
        '--retain_epoch', type=int, default=5, help='epoch for each token.')
    parser.add_argument('--lr', type=float, default=0.1, help='learning rate.')
    args = parser.parse_args()

    if args.data == 'cifar10':
        image_size = 32
        default_class_dim = 10
    elif args.data == 'imagenet':
        image_size = 224
        default_class_dim = 1000
    else:
        raise NotImplementedError(
            'data must in [cifar10, imagenet], but received: {}'.format(
                args.data))
    # The classifier width must cover every label id of the chosen dataset;
    # a fixed default of 10 is wrong for imagenet.
    if args.class_dim is None:
        args.class_dim = default_class_dim
    print(args)

    # A single search space, referenced by name (a plain string entry).
    config = ['MobileNetV2Space']
    paddle.enable_static()
    search_mobilenetv2(config, args, image_size, is_server=args.is_server)