diff --git a/demo/nas/darts_cifar10_reader.py b/demo/nas/darts_cifar10_reader.py
new file mode 100644
index 0000000000000000000000000000000000000000..7698c176f7a0eb1c10539f1531d6736bda29e344
--- /dev/null
+++ b/demo/nas/darts_cifar10_reader.py
@@ -0,0 +1,153 @@
+# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import os
+import math
+import random
+import tarfile
+import numpy as np
+from PIL import Image, ImageOps
+import paddle
+try:
+    import cPickle
+except ImportError:
+    import _pickle as cPickle
+
+IMAGE_SIZE = 32
+IMAGE_DEPTH = 3
+CIFAR_MEAN = [0.49139968, 0.48215827, 0.44653124]
+CIFAR_STD = [0.24703233, 0.24348505, 0.26158768]
+
+URL_PREFIX = 'https://www.cs.toronto.edu/~kriz/'
+CIFAR10_URL = URL_PREFIX + 'cifar-10-python.tar.gz'
+CIFAR10_MD5 = 'c58f30108f718f92721af3b95e74349a'
+
+paddle.dataset.common.DATA_HOME = "dataset/"
+
+THREAD = 16
+BUF_SIZE = 10240
+
+num_workers = 4
+use_multiprocess = True
+cutout = True
+cutout_length = 16
+
+
+def preprocess(sample, is_training):
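+    # Decode the flat CIFAR sample into a CHW array, convert to an HWC PIL
+    # image, augment (train only), normalize, and return a CHW float32 array.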
+ image_array = sample.reshape(IMAGE_DEPTH, IMAGE_SIZE, IMAGE_SIZE)
+ rgb_array = np.transpose(image_array, (1, 2, 0))
+ img = Image.fromarray(rgb_array, 'RGB')
+
+ if is_training:
+        # pad, random crop, random horizontal flip
+ img = ImageOps.expand(img, (4, 4, 4, 4), fill=0)
+ left_top = np.random.randint(8, size=2)
+ img = img.crop((left_top[1], left_top[0], left_top[1] + IMAGE_SIZE,
+ left_top[0] + IMAGE_SIZE))
+ if np.random.randint(2):
+ img = img.transpose(Image.FLIP_LEFT_RIGHT)
+ img = np.array(img).astype(np.float32)
+
+ img_float = img / 255.0
+ img = (img_float - CIFAR_MEAN) / CIFAR_STD
+
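+    # Cutout augmentation: zero a cutout_length x cutout_length patch around
+    # a random center, clipped to the image borders. This runs after
+    # normalization, so the patch is set to zero in normalized space.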
+ if is_training and cutout:
+ center = np.random.randint(IMAGE_SIZE, size=2)
+ offset_width = max(0, center[0] - cutout_length // 2)
+ offset_height = max(0, center[1] - cutout_length // 2)
+ target_width = min(center[0] + cutout_length // 2, IMAGE_SIZE)
+ target_height = min(center[1] + cutout_length // 2, IMAGE_SIZE)
+
+ for i in range(offset_height, target_height):
+ for j in range(offset_width, target_width):
+ img[i][j][:] = 0.0
+
+ img = np.transpose(img, (2, 0, 1))
+ return img
+
+
+def reader_generator(datasets, batch_size, is_training, is_shuffle):
+ def read_batch(datasets):
+ if is_shuffle:
+ random.shuffle(datasets)
+ for im, label in datasets:
+ im = preprocess(im, is_training)
+ yield im, [int(label)]
+
+ def reader():
+ batch_data = []
+ batch_label = []
+ for data in read_batch(datasets):
+ batch_data.append(data[0])
+ batch_label.append(data[1])
+ if len(batch_data) == batch_size:
+ batch_data = np.array(batch_data, dtype='float32')
+ batch_label = np.array(batch_label, dtype='int64')
+ batch_out = [batch_data, batch_label]
+ yield batch_out
+ batch_data = []
+ batch_label = []
+
+ return reader
+
+
+def cifar10_reader(file_name, data_name, is_shuffle):
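+    # Extract every pickled batch whose member name contains `data_name`
+    # (e.g. 'data_batch' or 'test_batch') and collect (image, label) pairs.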
+ with tarfile.open(file_name, mode='r') as f:
+ names = [
+ each_item.name for each_item in f if data_name in each_item.name
+ ]
+ names.sort()
+ datasets = []
+ for name in names:
+ print("Reading file " + name)
+            try:
+                batch = cPickle.load(
+                    f.extractfile(name), encoding='iso-8859-1')
+            except TypeError:
+                # Python 2 cPickle.load does not accept an encoding argument.
+                batch = cPickle.load(f.extractfile(name))
+ data = batch['data']
+ labels = batch.get('labels', batch.get('fine_labels', None))
+ assert labels is not None
+ dataset = zip(data, labels)
+ datasets.extend(dataset)
+ if is_shuffle:
+ random.shuffle(datasets)
+ return datasets
+
+
+def train_valid(batch_size, is_train, is_shuffle):
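+    # Download CIFAR-10, split the samples into num_workers chunks, and wrap
+    # each chunk in a reader; with use_multiprocess the chunks are consumed
+    # in parallel via paddle.reader.multiprocess_reader.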
+ name = 'data_batch' if is_train else 'test_batch'
+ datasets = cifar10_reader(
+ paddle.dataset.common.download(CIFAR10_URL, 'cifar', CIFAR10_MD5),
+ name, is_shuffle)
+    n = int(math.ceil(len(datasets) /
+                      num_workers)) if use_multiprocess else len(datasets)
+ datasets_lists = [datasets[i:i + n] for i in range(0, len(datasets), n)]
+ multi_readers = []
+ for pid in range(len(datasets_lists)):
+ multi_readers.append(
+ reader_generator(datasets_lists[pid], batch_size, is_train,
+ is_shuffle))
+ if use_multiprocess:
+ reader = paddle.reader.multiprocess_reader(multi_readers, False)
+ else:
+ reader = multi_readers[0]
+ return reader
diff --git a/demo/nas/darts_nas.py b/demo/nas/darts_nas.py
new file mode 100644
index 0000000000000000000000000000000000000000..43705e8781ab2875e55f7f0b3df12a6123a0f475
--- /dev/null
+++ b/demo/nas/darts_nas.py
@@ -0,0 +1,348 @@
+import os
+import sys
+sys.path.append('..')
+import argparse
+import ast
+import logging
+import numpy as np
+import paddle.fluid as fluid
+from paddleslim.nas import SANAS
+from paddleslim.common import get_logger
+import darts_cifar10_reader as reader
+
+_logger = get_logger(__name__, level=logging.INFO)
+
+auxiliary = True
+auxiliary_weight = 0.4
+trainset_num = 50000
+lr = 0.025
+momentum = 0.9
+weight_decay = 0.0003
+drop_path_probility = 0.2
+
+
+class AvgrageMeter(object):
+ def __init__(self):
+ self.reset()
+
+ def reset(self):
+ self.avg = 0
+ self.sum = 0
+ self.cnt = 0
+
+ def update(self, val, n=1):
+ self.sum += val * n
+ self.cnt += n
+ self.avg = self.sum / self.cnt
+
+
+def count_parameters_in_MB(all_params, prefix='model'):
+ parameters_number = 0
+ for param in all_params:
+ if param.name.startswith(
+ prefix) and param.trainable and 'aux' not in param.name:
+ parameters_number += np.prod(param.shape)
+ return parameters_number / 1e6
+
+
+def create_data_loader(image_shape, is_train, args):
+ image = fluid.data(
+ name="image", shape=[None] + image_shape, dtype="float32")
+ label = fluid.data(name="label", shape=[None, 1], dtype="int64")
+ data_loader = fluid.io.DataLoader.from_generator(
+ feed_list=[image, label],
+ capacity=64,
+ use_double_buffer=True,
+ iterable=True)
+ drop_path_prob = ''
+ drop_path_mask = ''
+ if is_train:
+ drop_path_prob = fluid.data(
+ name="drop_path_prob", shape=[args.batch_size, 1], dtype="float32")
+ drop_path_mask = fluid.data(
+ name="drop_path_mask",
+ shape=[args.batch_size, 20, 4, 2],
+ dtype="float32")
+
+ return data_loader, image, label, drop_path_prob, drop_path_mask
+
+
+def build_program(main_program, startup_program, image_shape, archs, args,
+ is_train):
+ with fluid.program_guard(main_program, startup_program):
+ data_loader, data, label, drop_path_prob, drop_path_mask = create_data_loader(
+ image_shape, is_train, args)
+ logits, logits_aux = archs(data, drop_path_prob, drop_path_mask,
+ is_train, 10)
+ top1 = fluid.layers.accuracy(input=logits, label=label, k=1)
+ top5 = fluid.layers.accuracy(input=logits, label=label, k=5)
+ loss = fluid.layers.reduce_mean(
+ fluid.layers.softmax_with_cross_entropy(logits, label))
+
+ if is_train:
+ if auxiliary:
+ loss_aux = fluid.layers.reduce_mean(
+ fluid.layers.softmax_with_cross_entropy(logits_aux, label))
+ loss = loss + auxiliary_weight * loss_aux
+ step_per_epoch = int(trainset_num / args.batch_size)
+ learning_rate = fluid.layers.cosine_decay(lr, step_per_epoch,
+ args.retain_epoch)
+ fluid.clip.set_gradient_clip(
+ clip=fluid.clip.GradientClipByGlobalNorm(clip_norm=5.0))
+ optimizer = fluid.optimizer.MomentumOptimizer(
+ learning_rate,
+ momentum,
+ regularization=fluid.regularizer.L2DecayRegularizer(
+ weight_decay))
+ optimizer.minimize(loss)
+ outs = [loss, top1, top5, learning_rate]
+ else:
+ outs = [loss, top1, top5]
+ return outs, data_loader
+
+
+def train(main_prog, exe, epoch_id, train_loader, fetch_list, args):
+ loss = AvgrageMeter()
+ top1 = AvgrageMeter()
+ top5 = AvgrageMeter()
+ for step_id, data in enumerate(train_loader()):
+ devices_num = len(data)
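+        # The DataLoader yields one feed dict per device; when drop path is
+        # enabled, rebuild each dict to add the per-sample drop probability
+        # and a pre-sampled binary keep-mask for every cell edge.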
+ if drop_path_probility > 0:
+ feed = []
+ for device_id in range(devices_num):
+ image = data[device_id]['image']
+ label = data[device_id]['label']
+ drop_path_prob = np.array(
+ [[drop_path_probility * epoch_id / args.retain_epoch]
+ for i in range(args.batch_size)]).astype(np.float32)
+ drop_path_mask = 1 - np.random.binomial(
+ 1, drop_path_prob[0],
+ size=[args.batch_size, 20, 4, 2]).astype(np.float32)
+ feed.append({
+ "image": image,
+ "label": label,
+ "drop_path_prob": drop_path_prob,
+ "drop_path_mask": drop_path_mask
+ })
+ else:
+ feed = data
+ loss_v, top1_v, top5_v, lr = exe.run(
+ main_prog, feed=feed, fetch_list=[v.name for v in fetch_list])
+ loss.update(loss_v, args.batch_size)
+ top1.update(top1_v, args.batch_size)
+ top5.update(top5_v, args.batch_size)
+ if step_id % 10 == 0:
+ _logger.info(
+ "Train Epoch {}, Step {}, Lr {:.8f}, loss {:.6f}, acc_1 {:.6f}, acc_5 {:.6f}".
+ format(epoch_id, step_id, lr[0], loss.avg[0], top1.avg[0],
+ top5.avg[0]))
+ return top1.avg[0]
+
+
+def valid(main_prog, exe, epoch_id, valid_loader, fetch_list, args):
+ loss = AvgrageMeter()
+ top1 = AvgrageMeter()
+ top5 = AvgrageMeter()
+ for step_id, data in enumerate(valid_loader()):
+ loss_v, top1_v, top5_v = exe.run(
+ main_prog, feed=data, fetch_list=[v.name for v in fetch_list])
+ loss.update(loss_v, args.batch_size)
+ top1.update(top1_v, args.batch_size)
+ top5.update(top5_v, args.batch_size)
+ if step_id % 10 == 0:
+ _logger.info(
+ "Valid Epoch {}, Step {}, loss {:.6f}, acc_1 {:.6f}, acc_5 {:.6f}".
+ format(epoch_id, step_id, loss.avg[0], top1.avg[0], top5.avg[
+ 0]))
+ return top1.avg[0]
+
+
+def search(config, args, image_size, is_server=True):
+ if is_server:
+ ### start a server and a client
+ sa_nas = SANAS(
+ config,
+ server_addr=(args.server_address, args.port),
+ search_steps=args.search_steps,
+ is_server=True)
+ else:
+ ### start a client
+ sa_nas = SANAS(
+ config,
+ server_addr=(args.server_address, args.port),
+ is_server=False)
+
+ image_shape = [3, image_size, image_size]
+ for step in range(args.search_steps):
+ archs = sa_nas.next_archs()[0]
+
+ train_program = fluid.Program()
+ test_program = fluid.Program()
+ startup_program = fluid.Program()
+ train_fetch_list, train_loader = build_program(
+ train_program,
+ startup_program,
+ image_shape,
+ archs,
+ args,
+ is_train=True)
+
+ current_params = count_parameters_in_MB(
+ train_program.global_block().all_parameters(), 'cifar10')
+ _logger.info('step: {}, current_params: {}M'.format(step,
+ current_params))
+        if current_params > 3.77:
+ continue
+
+ test_fetch_list, test_loader = build_program(
+ test_program,
+ startup_program,
+ image_shape,
+ archs,
+ args,
+ is_train=False)
+ test_program = test_program.clone(for_test=True)
+
+ place = fluid.CUDAPlace(0) if args.use_gpu else fluid.CPUPlace()
+ exe = fluid.Executor(place)
+ exe.run(startup_program)
+
+ train_reader = reader.train_valid(
+ batch_size=args.batch_size, is_train=True, is_shuffle=True)
+ test_reader = reader.train_valid(
+ batch_size=args.batch_size, is_train=False, is_shuffle=False)
+
+ train_loader.set_batch_generator(train_reader, places=place)
+ test_loader.set_batch_generator(test_reader, places=place)
+
+ build_strategy = fluid.BuildStrategy()
+ train_compiled_program = fluid.CompiledProgram(
+ train_program).with_data_parallel(
+ loss_name=train_fetch_list[0].name,
+ build_strategy=build_strategy)
+
+ valid_top1_list = []
+ for epoch_id in range(args.retain_epoch):
+ train_top1 = train(train_compiled_program, exe, epoch_id,
+ train_loader, train_fetch_list, args)
+ _logger.info("TRAIN: step: {}, Epoch {}, train_acc {:.6f}".format(
+ step, epoch_id, train_top1))
+ valid_top1 = valid(test_program, exe, epoch_id, test_loader,
+ test_fetch_list, args)
+ _logger.info("TEST: Epoch {}, valid_acc {:.6f}".format(epoch_id,
+ valid_top1))
+ valid_top1_list.append(valid_top1)
+ sa_nas.reward(float(valid_top1_list[-1] + valid_top1_list[-2]) / 2)
+
+
+def final_test(config, args, image_size, token=None):
+    assert token is not None, "If you want to start a final experiment, you must input a token."
+ sa_nas = SANAS(
+ config, server_addr=(args.server_address, args.port), is_server=True)
+
+ image_shape = [3, image_size, image_size]
+ archs = sa_nas.tokens2arch(token)[0]
+
+ train_program = fluid.Program()
+ test_program = fluid.Program()
+ startup_program = fluid.Program()
+ train_fetch_list, train_loader = build_program(
+ train_program,
+ startup_program,
+ image_shape,
+ archs,
+ args,
+ is_train=True)
+
+ current_params = count_parameters_in_MB(
+ train_program.global_block().all_parameters(), 'cifar10')
+ _logger.info('current_params: {}M'.format(current_params))
+ test_fetch_list, test_loader = build_program(
+ test_program,
+ startup_program,
+ image_shape,
+ archs,
+ args,
+ is_train=False)
+ test_program = test_program.clone(for_test=True)
+
+ place = fluid.CUDAPlace(0) if args.use_gpu else fluid.CPUPlace()
+ exe = fluid.Executor(place)
+ exe.run(startup_program)
+
+    train_reader = reader.train_valid(
+        batch_size=args.batch_size, is_train=True, is_shuffle=True)
+    test_reader = reader.train_valid(
+        batch_size=args.batch_size, is_train=False, is_shuffle=False)
+
+ train_loader.set_batch_generator(train_reader, places=place)
+ test_loader.set_batch_generator(test_reader, places=place)
+
+ build_strategy = fluid.BuildStrategy()
+ train_compiled_program = fluid.CompiledProgram(
+ train_program).with_data_parallel(
+ loss_name=train_fetch_list[0].name, build_strategy=build_strategy)
+
+ valid_top1_list = []
+ for epoch_id in range(args.retain_epoch):
+ train_top1 = train(train_compiled_program, exe, epoch_id, train_loader,
+ train_fetch_list, args)
+ _logger.info("TRAIN: Epoch {}, train_acc {:.6f}".format(epoch_id,
+ train_top1))
+ valid_top1 = valid(test_program, exe, epoch_id, test_loader,
+ test_fetch_list, args)
+ _logger.info("TEST: Epoch {}, valid_acc {:.6f}".format(epoch_id,
+ valid_top1))
+ valid_top1_list.append(valid_top1)
+
+ output_dir = os.path.join('darts_output', str(epoch_id))
+ if not os.path.exists(output_dir):
+ os.makedirs(output_dir)
+ fluid.io.save_persistables(exe, output_dir, main_program=train_program)
+
+
+if __name__ == '__main__':
+
+    parser = argparse.ArgumentParser(
+        description='SA-NAS DARTS CIFAR-10 example')
+ parser.add_argument(
+ '--use_gpu',
+ type=ast.literal_eval,
+ default=True,
+ help='Whether to use GPU in train/test model.')
+ parser.add_argument(
+ '--batch_size', type=int, default=96, help='batch size.')
+ parser.add_argument(
+ '--is_server',
+ type=ast.literal_eval,
+ default=True,
+ help='Whether to start a server.')
+ parser.add_argument(
+ '--server_address', type=str, default="", help='server ip.')
+ parser.add_argument('--port', type=int, default=8881, help='server port')
+ parser.add_argument(
+ '--retain_epoch', type=int, default=30, help='epoch for each token.')
+ parser.add_argument('--token', type=int, nargs='+', help='final token.')
+ parser.add_argument(
+ '--search_steps',
+ type=int,
+ default=200,
+        help='number of search steps.')
+ args = parser.parse_args()
+ print(args)
+
+ image_size = 32
+
+ config = [('DartsSpace')]
+
+    if args.token is None:
+ search(config, args, image_size, is_server=args.is_server)
+ else:
+ final_test(config, args, image_size, token=args.token)
diff --git a/demo/nas/sa_nas_mobilenetv2.py b/demo/nas/sa_nas_mobilenetv2.py
index 3adef1356e650623e7a778cbb84afcce593a661a..a641733024b42d65105064672d6950a64e8fd75c 100644
--- a/demo/nas/sa_nas_mobilenetv2.py
+++ b/demo/nas/sa_nas_mobilenetv2.py
@@ -38,19 +38,22 @@ def build_program(main_program,
args,
is_test=False):
with fluid.program_guard(main_program, startup_program):
- data_loader, data, label = create_data_loader(image_shape)
- output = archs(data)
- output = fluid.layers.fc(input=output, size=args.class_dim)
-
- softmax_out = fluid.layers.softmax(input=output, use_cudnn=False)
- cost = fluid.layers.cross_entropy(input=softmax_out, label=label)
- avg_cost = fluid.layers.mean(cost)
- acc_top1 = fluid.layers.accuracy(input=softmax_out, label=label, k=1)
- acc_top5 = fluid.layers.accuracy(input=softmax_out, label=label, k=5)
-
- if is_test == False:
- optimizer = create_optimizer(args)
- optimizer.minimize(avg_cost)
+ with fluid.unique_name.guard():
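+            # unique_name.guard resets the naming scope so the train and test
+            # programs generate identical parameter names and share weights.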
+ data_loader, data, label = create_data_loader(image_shape)
+ output = archs(data)
+ output = fluid.layers.fc(input=output, size=args.class_dim)
+
+ softmax_out = fluid.layers.softmax(input=output, use_cudnn=False)
+ cost = fluid.layers.cross_entropy(input=softmax_out, label=label)
+ avg_cost = fluid.layers.mean(cost)
+ acc_top1 = fluid.layers.accuracy(
+ input=softmax_out, label=label, k=1)
+ acc_top5 = fluid.layers.accuracy(
+ input=softmax_out, label=label, k=5)
+
+            if not is_test:
+ optimizer = create_optimizer(args)
+ optimizer.minimize(avg_cost)
return data_loader, avg_cost, acc_top1, acc_top5
@@ -169,8 +172,6 @@ def test_search_result(tokens, image_size, args, config):
sa_nas = SANAS(
config,
server_addr=(args.server_address, args.port),
- init_temperature=args.init_temperature,
- reduce_rate=args.reduce_rate,
search_steps=args.search_steps,
is_server=True)
diff --git a/demo/nas/search_space_doc.md b/demo/nas/search_space_doc.md
deleted file mode 100644
index 682b0eac801bae4ae59b523475e8fa3c66586190..0000000000000000000000000000000000000000
--- a/demo/nas/search_space_doc.md
+++ /dev/null
@@ -1,116 +0,0 @@
-# paddleslim.nas 提供的搜索空间:
-
-1. 根据原本模型结构构造搜索空间:
-
- 1.1 MobileNetV2Space
-
- 1.2 MobileNetV1Space
-
- 1.3 ResNetSpace
-
-
-2. 根据相应模型的block构造搜索空间
-
- 2.1 MobileNetV1BlockSpace
-
- 2.2 MobileNetV2BlockSpace
-
- 2.3 ResNetBlockSpace
-
- 2.4 InceptionABlockSpace
-
- 2.5 InceptionCBlockSpace
-
-
-##搜索空间的配置介绍:
-
-**input_size(int|None)**:`input_size`表示输入feature map的大小。
-**output_size(int|None)**:`output_size`表示输出feature map的大小。
-**block_num(int|None)**:`block_num`表示搜索空间中block的数量。
-**block_mask(list|None)**:`block_mask`表示当前的block是一个reduction block还是一个normal block,是一组由0、1组成的列表,0表示当前block是normal block,1表示当前block是reduction block。如果设置了`block_mask`,则主要以`block_mask`为主要配置,`input_size`,`output_size`和`block_num`三种配置是无效的。
-
-**Note:**
-1. reduction block表示经过这个block之后的feature map大小下降为之前的一半,normal block表示经过这个block之后feature map大小不变。
-2. `input_size`和`output_size`用来计算整个模型结构中reduction block数量。
-
-
-##搜索空间示例:
-
-1. 使用paddleslim中提供用原本的模型结构来构造搜索空间的话,仅需要指定搜索空间名字即可。例如:如果使用原本的MobileNetV2的搜索空间进行搜索的话,传入SANAS中的config直接指定为[('MobileNetV2Space')]。
-2. 使用paddleslim中提供的block搜索空间构造搜索空间:
- 2.1 使用`input_size`, `output_size`和`block_num`来构造搜索空间。例如:传入SANAS的config可以指定为[('MobileNetV2BlockSpace', {'input_size': 224, 'output_size': 32, 'block_num': 10})]。
- 2.2 使用`block_mask`构造搜索空间。例如:传入SANAS的config可以指定为[('MobileNetV2BlockSpace', {'block_mask': [0, 1, 1, 1, 1, 0, 1, 0]})]。
-
-
-# 自定义搜索空间(search space)
-
-自定义搜索空间类需要继承搜索空间基类并重写以下几部分:
- 1. 初始化的tokens(`init_tokens`函数),可以设置为自己想要的tokens列表, tokens列表中的每个数字指的是当前数字在相应的搜索列表中的索引。例如本示例中若tokens=[0, 3, 5],则代表当前模型结构搜索到的通道数为[8, 40, 128]。
- 2. token中每个数字的搜索列表长度(`range_table`函数),tokens中每个token的索引范围。
- 3. 根据token产生模型结构(`token2arch`函数),根据搜索到的tokens列表产生模型结构。
-
-以新增reset block为例说明如何构造自己的search space。自定义的search space不能和已有的search space同名。
-
-```python
-### 引入搜索空间基类函数和search space的注册类函数
-from .search_space_base import SearchSpaceBase
-from .search_space_registry import SEARCHSPACE
-import numpy as np
-
-### 需要调用注册函数把自定义搜索空间注册到space space中
-@SEARCHSPACE.register
-### 定义一个继承SearchSpaceBase基类的搜索空间的类函数
-class ResNetBlockSpace2(SearchSpaceBase):
- def __init__(self, input_size, output_size, block_num, block_mask):
- ### 定义一些实际想要搜索的内容,例如:通道数、每个卷积的重复次数、卷积核大小等等
- ### self.filter_num 代表通道数的搜索列表
- self.filter_num = np.array([8, 16, 32, 40, 64, 128, 256, 512])
-
- ### 定义初始化token,初始化token的长度根据传入的block_num或者block_mask的长度来得到的
- def init_tokens(self):
- return [0] * 3 * len(self.block_mask)
-
- ### 定义
- def range_table(self):
- return [len(self.filter_num)] * 3 * len(self.block_mask)
-
- def token2arch(self, tokens=None):
- if tokens == None:
- tokens = self.init_tokens()
-
- self.bottleneck_params_list = []
- for i in range(len(self.block_mask)):
- self.bottleneck_params_list.append(self.filter_num[tokens[i * 3 + 0]],
- self.filter_num[tokens[i * 3 + 1]],
- self.filter_num[tokens[i * 3 + 2]],
- 2 if self.block_mask[i] == 1 else 1)
-
- def net_arch(input):
- for i, layer_setting in enumerate(self.bottleneck_params_list):
- channel_num, stride = layer_setting[:-1], layer_setting[-1]
- input = self._resnet_block(input, channel_num, stride, name='resnet_layer{}'.format(i+1))
-
- return input
-
- return net_arch
-
- ### 构造具体block的操作
- def _resnet_block(self, input, channel_num, stride, name=None):
- shortcut_conv = self._shortcut(input, channel_num[2], stride, name=name)
- input = self._conv_bn_layer(input=input, num_filters=channel_num[0], filter_size=1, act='relu', name=name + '_conv0')
- input = self._conv_bn_layer(input=input, num_filters=channel_num[1], filter_size=3, stride=stride, act='relu', name=name + '_conv1')
- input = self._conv_bn_layer(input=input, num_filters=channel_num[2], filter_size=1, name=name + '_conv2')
- return fluid.layers.elementwise_add(x=shortcut_conv, y=input, axis=0, name=name+'_elementwise_add')
-
- def _shortcut(self, input, channel_num, stride, name=None):
- channel_in = input.shape[1]
- if channel_in != channel_num or stride != 1:
- return self.conv_bn_layer(input, num_filters=channel_num, filter_size=1, stride=stride, name=name+'_shortcut')
- else:
- return input
-
- def _conv_bn_layer(self, input, num_filters, filter_size, stride=1, padding='SAME', act=None, name=None):
- conv = fluid.layers.conv2d(input, num_filters, filter_size, stride, name=name+'_conv')
- bn = fluid.layers.batch_norm(conv, act=act, name=name+'_bn')
- return bn
-```
diff --git a/docs/en/model_zoo_en.md b/docs/en/model_zoo_en.md
index 7212316d4ef3450e7e0bff568a1076e32607cdca..f81d8af0188c3129a55d328740d3ea6f0fb75d6a 100644
--- a/docs/en/model_zoo_en.md
+++ b/docs/en/model_zoo_en.md
@@ -59,8 +59,19 @@
| Model | Method | Top-1/Top-5 Acc | Volume(MB) | GFLOPs | Download |
|:--:|:---:|:--:|:--:|:--:|:--:|
-| MobileNetV2 | - | 72.15%/90.65% | 15 | 0.59 | [model](https://paddle-imagenet-models-name.bj.bcebos.com/MobileNetV2_pretrained.tar) |
-| MobileNetV2 | SANAS | 71.518%/90.208% (-0.632%/-0.442%) | 14 | 0.295 | [model](https://paddlemodels.cdn.bcebos.com/PaddleSlim/MobileNetV2_sanas.tar) |
+| MobileNetV2 | - | 72.15%/90.65% | 15 | 0.59 | [model](https://paddle-imagenet-models-name.bj.bcebos.com/MobileNetV2_pretrained.tar) |
+| MobileNetV2_NAS | SANAS | 71.518%/90.208% (-0.632%/-0.442%) | 14 | 0.295 | [model](https://paddlemodels.cdn.bcebos.com/PaddleSlim/MobileNetV2_sanas.tar) |
+
+Dataset: Cifar10
+
+| Model | Method | Acc | Params(MB) | Download |
+|:---:|:--:|:--:|:--:|:--:|
+| Darts | - | 97.135% | 3.767 | - |
+| Darts_SA (based on Darts) | SANAS | 97.276% (+0.141%) | 3.344 (-11.2%) | - |
+
+!!! note "Note"
+
+ [1]: The token of MobileNetV2_NAS is [4, 4, 5, 1, 1, 2, 1, 1, 0, 2, 6, 2, 0, 3, 4, 5, 0, 4, 5, 5, 1, 4, 8, 0, 0]. The token of Darts_SA is [5, 5, 0, 5, 5, 10, 7, 7, 5, 7, 7, 11, 10, 12, 10, 0, 5, 3, 10, 8].
+
## 2. Object Detection
@@ -135,12 +146,9 @@ Dataset: WIDER-FACE
| :------------: | :---------: | :-------: | :------: | :-----------------------------: | :------------: | :------------: | :----------------------------------------------------------: |
| BlazeFace | - | 8 | 640 | 91.5/89.2/79.7 | 815 | 71.862 | [model](https://paddlemodels.bj.bcebos.com/object_detection/blazeface_original.tar) |
| BlazeFace-NAS | - | 8 | 640 | 83.7/80.7/65.8 | 244 | 21.117 |[model](https://paddlemodels.bj.bcebos.com/object_detection/blazeface_nas.tar) |
-| BlazeFace-NAS1 | SANAS | 8 | 640 | 87.0/83.7/68.5 | 389 | 22.558 | [model](https://paddlemodels.bj.bcebos.com/object_detection/blazeface_nas2.tar) |
-
-!!! note "Note"
-
- [1]: latency is based on latency_855.txt, the file is test on 855 by PaddleLite。
+| BlazeFace-NASV2 | SANAS | 8 | 640 | 87.0/83.7/68.5 | 389 | 22.558 | [model](https://paddlemodels.bj.bcebos.com/object_detection/blazeface_nas2.tar) |
+
+Note: Latency is computed with latency_855.txt, a latency lookup table measured on a Snapdragon 855 chip with PaddleLite. The config of BlazeFace-NASV2 is available [here](https://github.com/PaddlePaddle/PaddleDetection/blob/master/configs/face_detection/blazeface_nas_v2.yml).
## 3. Image Segmentation
diff --git a/docs/zh_cn/model_zoo.md b/docs/zh_cn/model_zoo.md
index 466dd19ef98ced5bd2d457fa5039245a27d7ad37..486fde79c145f29fe1c9ace798c2a49d6d0e1f19 100644
--- a/docs/zh_cn/model_zoo.md
+++ b/docs/zh_cn/model_zoo.md
@@ -91,11 +91,24 @@
### 1.4 Search
+Dataset: ImageNet1000
+
| Model | Method | Top-1/Top-5 Acc | Model Size(MB) | GFLOPs | Download |
|:--:|:---:|:--:|:--:|:--:|:--:|
| MobileNetV2 | - | 72.15%/90.65% | 15 | 0.59 | [Download link](https://paddle-imagenet-models-name.bj.bcebos.com/MobileNetV2_pretrained.tar) |
| MobileNetV2 | SANAS | 71.518%/90.208% (-0.632%/-0.442%) | 14 | 0.295 | [Download link](https://paddlemodels.cdn.bcebos.com/PaddleSlim/MobileNetV2_sanas.tar) |
+
+Dataset: Cifar10
+
+| Model | Method | Acc | Params(MB) | Download |
+|:---:|:--:|:--:|:--:|:--:|
+| Darts | - | 97.135% | 3.767 | - |
+| Darts_SA (based on the Darts search space) | SANAS | 97.276% (+0.141%) | 3.344 (-11.2%) | - |
+
+
+Note: The token of MobileNetV2_NAS is [4, 4, 5, 1, 1, 2, 1, 1, 0, 2, 6, 2, 0, 3, 4, 5, 0, 4, 5, 5, 1, 4, 8, 0, 0]. The token of Darts_SA is [5, 5, 0, 5, 5, 10, 7, 7, 5, 7, 7, 11, 10, 12, 10, 0, 5, 3, 10, 8].
+
+
+
## 2. Object Detection
### 2.1 Quantization
@@ -171,11 +184,9 @@
| :------------: | :---------: | :-------: | :------: | :-----------------------------: | :------------: | :------------: | :----------------------------------------------------------: |
| BlazeFace | - | 8 | 640 | 91.5/89.2/79.7 | 815 | 71.862 | [Download link](https://paddlemodels.bj.bcebos.com/object_detection/blazeface_original.tar) |
| BlazeFace-NAS | - | 8 | 640 | 83.7/80.7/65.8 | 244 | 21.117 |[Download link](https://paddlemodels.bj.bcebos.com/object_detection/blazeface_nas.tar) |
-| BlazeFace-NAS1 | SANAS | 8 | 640 | 87.0/83.7/68.5 | 389 | 22.558 | [下载链接](https://paddlemodels.bj.bcebos.com/object_detection/blazeface_nas2.tar) |
-
-!!! note "Note"
+| BlazeFace-NASV2 | SANAS | 8 | 640 | 87.0/83.7/68.5 | 389 | 22.558 | [Download link](https://paddlemodels.bj.bcebos.com/object_detection/blazeface_nas2.tar) |
- [1]: 硬件延时时间是利用提供的硬件延时表得到的,硬件延时表是在855芯片上基于PaddleLite测试的结果。
+
+Note: The hardware latency is obtained from the provided latency lookup table, which was measured on a Snapdragon 855 chip with PaddleLite. The detailed config of BlazeFace-NASV2 is [here](https://github.com/PaddlePaddle/PaddleDetection/blob/master/configs/face_detection/blazeface_nas_v2.yml).
## 3. Image Segmentation
diff --git a/docs/zh_cn/tutorials/darts_nas_turorial.ipynb b/docs/zh_cn/tutorials/darts_nas_turorial.ipynb
new file mode 100644
index 0000000000000000000000000000000000000000..8cc43df5ca55543d58d49d04e21313045f4c75ec
--- /dev/null
+++ b/docs/zh_cn/tutorials/darts_nas_turorial.ipynb
@@ -0,0 +1,324 @@
+{
+ "cells": [
+ {
+ "cell_type": "code",
+ "execution_count": 19,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import paddle\n",
+ "import paddle.fluid as fluid\n",
+ "from paddleslim.nas import SANAS\n",
+ "import numpy as np\n",
+ "\n",
+ "BATCH_SIZE=96\n",
+ "SERVER_ADDRESS = \"\"\n",
+ "PORT = 8377\n",
+ "SEARCH_STEPS = 300\n",
+ "RETAIN_EPOCH=30\n",
+ "MAX_PARAMS=3.77\n",
+ "IMAGE_SHAPE=[3, 32, 32]\n",
+ "AUXILIARY = True\n",
+ "AUXILIARY_WEIGHT= 0.4\n",
+ "TRAINSET_NUM = 50000\n",
+ "LR = 0.025\n",
+ "MOMENTUM = 0.9\n",
+ "WEIGHT_DECAY = 0.0003\n",
+ "DROP_PATH_PROBILITY = 0.2"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 2,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "2020-02-23 12:28:09,752-INFO: range table: ([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14])\n",
+ "2020-02-23 12:28:09,754-INFO: ControllerServer - listen on: [127.0.0.1:8377]\n",
+ "2020-02-23 12:28:09,756-INFO: Controller Server run...\n"
+ ]
+ }
+ ],
+ "source": [
+ "config = [('DartsSpace')]\n",
+ "sa_nas = SANAS(config, server_addr=(SERVER_ADDRESS, PORT), search_steps=SEARCH_STEPS, is_server=True)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 3,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "def count_parameters_in_MB(all_params, prefix='model'):\n",
+ " parameters_number = 0\n",
+ " for param in all_params:\n",
+ " if param.name.startswith(\n",
+ " prefix) and param.trainable and 'aux' not in param.name:\n",
+ " parameters_number += np.prod(param.shape)\n",
+ " return parameters_number / 1e6"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 4,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "def create_data_loader(IMAGE_SHAPE, is_train):\n",
+ " image = fluid.data(\n",
+ " name=\"image\", shape=[None] + IMAGE_SHAPE, dtype=\"float32\")\n",
+ " label = fluid.data(name=\"label\", shape=[None, 1], dtype=\"int64\")\n",
+ " data_loader = fluid.io.DataLoader.from_generator(\n",
+ " feed_list=[image, label],\n",
+ " capacity=64,\n",
+ " use_double_buffer=True,\n",
+ " iterable=True)\n",
+ " drop_path_prob = ''\n",
+ " drop_path_mask = ''\n",
+ " if is_train:\n",
+ " drop_path_prob = fluid.data(\n",
+ " name=\"drop_path_prob\", shape=[BATCH_SIZE, 1], dtype=\"float32\")\n",
+ " drop_path_mask = fluid.data(\n",
+ " name=\"drop_path_mask\",\n",
+ " shape=[BATCH_SIZE, 20, 4, 2],\n",
+ " dtype=\"float32\")\n",
+ "\n",
+ " return data_loader, image, label, drop_path_prob, drop_path_mask"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 5,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "def build_program(main_program, startup_program, IMAGE_SHAPE, archs, is_train):\n",
+ " with fluid.program_guard(main_program, startup_program):\n",
+ " data_loader, data, label, drop_path_prob, drop_path_mask = create_data_loader(\n",
+ " IMAGE_SHAPE, is_train)\n",
+ " logits, logits_aux = archs(data, drop_path_prob, drop_path_mask,\n",
+ " is_train, 10)\n",
+ " top1 = fluid.layers.accuracy(input=logits, label=label, k=1)\n",
+ " top5 = fluid.layers.accuracy(input=logits, label=label, k=5)\n",
+ " loss = fluid.layers.reduce_mean(\n",
+ " fluid.layers.softmax_with_cross_entropy(logits, label))\n",
+ "\n",
+ " if is_train:\n",
+ " if AUXILIARY:\n",
+ " loss_aux = fluid.layers.reduce_mean(\n",
+ " fluid.layers.softmax_with_cross_entropy(logits_aux, label))\n",
+ " loss = loss + AUXILIARY_WEIGHT * loss_aux\n",
+ " step_per_epoch = int(TRAINSET_NUM / BATCH_SIZE)\n",
+ " learning_rate = fluid.layers.cosine_decay(LR, step_per_epoch, RETAIN_EPOCH)\n",
+ " fluid.clip.set_gradient_clip(\n",
+ " clip=fluid.clip.GradientClipByGlobalNorm(clip_norm=5.0))\n",
+ " optimizer = fluid.optimizer.MomentumOptimizer(\n",
+ " learning_rate,\n",
+ " MOMENTUM,\n",
+ " regularization=fluid.regularizer.L2DecayRegularizer(\n",
+ " WEIGHT_DECAY))\n",
+ " optimizer.minimize(loss)\n",
+ " outs = [loss, top1, top5, learning_rate]\n",
+ " else:\n",
+ " outs = [loss, top1, top5]\n",
+ " return outs, data_loader"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 27,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "def train(main_prog, exe, epoch_id, train_loader, fetch_list):\n",
+ " loss = []\n",
+ " top1 = []\n",
+ " top5 = []\n",
+ " for step_id, data in enumerate(train_loader()):\n",
+ " devices_num = len(data)\n",
+ " if DROP_PATH_PROBILITY > 0:\n",
+ " feed = []\n",
+ " for device_id in range(devices_num):\n",
+ " image = data[device_id]['image']\n",
+ " label = data[device_id]['label']\n",
+ " drop_path_prob = np.array(\n",
+ " [[DROP_PATH_PROBILITY * epoch_id / RETAIN_EPOCH]\n",
+ " for i in range(BATCH_SIZE)]).astype(np.float32)\n",
+ " drop_path_mask = 1 - np.random.binomial(\n",
+ " 1, drop_path_prob[0],\n",
+ " size=[BATCH_SIZE, 20, 4, 2]).astype(np.float32)\n",
+ " feed.append({\n",
+ " \"image\": image,\n",
+ " \"label\": label,\n",
+ " \"drop_path_prob\": drop_path_prob,\n",
+ " \"drop_path_mask\": drop_path_mask\n",
+ " })\n",
+ " else:\n",
+ " feed = data\n",
+ " loss_v, top1_v, top5_v, lr = exe.run(\n",
+ " main_prog, feed=feed, fetch_list=[v.name for v in fetch_list])\n",
+ " loss.append(loss_v)\n",
+ " top1.append(top1_v)\n",
+ " top5.append(top5_v)\n",
+ " if step_id % 10 == 0:\n",
+ " print(\n",
+ " \"Train Epoch {}, Step {}, Lr {:.8f}, loss {:.6f}, acc_1 {:.6f}, acc_5 {:.6f}\".\n",
+ " format(epoch_id, step_id, lr[0], np.mean(loss), np.mean(top1), np.mean(top5)))\n",
+ " return np.mean(top1)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 23,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "def valid(main_prog, exe, epoch_id, valid_loader, fetch_list):\n",
+ " loss = []\n",
+ " top1 = []\n",
+ " top5 = []\n",
+ " for step_id, data in enumerate(valid_loader()):\n",
+ " loss_v, top1_v, top5_v = exe.run(\n",
+ " main_prog, feed=data, fetch_list=[v.name for v in fetch_list])\n",
+ " loss.append(loss_v)\n",
+ " top1.append(top1_v)\n",
+ " top5.append(top5_v)\n",
+ " if step_id % 10 == 0:\n",
+ " print(\n",
+ " \"Valid Epoch {}, Step {}, loss {:.6f}, acc_1 {:.6f}, acc_5 {:.6f}\".\n",
+ " format(epoch_id, step_id, np.mean(loss), np.mean(top1), np.mean(top5)))\n",
+ " return np.mean(top1)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 8,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "2020-02-23 12:28:57,462-INFO: current tokens: [5, 5, 5, 5, 5, 12, 7, 7, 7, 7, 7, 7, 7, 10, 10, 10, 10, 10, 10, 10]\n"
+ ]
+ }
+ ],
+ "source": [
+ "archs = sa_nas.next_archs()[0]"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 9,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "train_program = fluid.Program()\n",
+ "test_program = fluid.Program()\n",
+ "startup_program = fluid.Program()\n",
+ "train_fetch_list, train_loader = build_program(train_program, startup_program, IMAGE_SHAPE, archs, is_train=True)\n",
+ "test_fetch_list, test_loader = build_program(test_program, startup_program, IMAGE_SHAPE, archs, is_train=False)\n",
+ "test_program = test_program.clone(for_test=True)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 10,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "[]"
+ ]
+ },
+ "execution_count": 10,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "place = fluid.CPUPlace()\n",
+ "exe = fluid.Executor(place)\n",
+ "exe.run(startup_program)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 17,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ ""
+ ]
+ },
+ "execution_count": 17,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "train_reader = paddle.batch(paddle.reader.shuffle(paddle.dataset.cifar.train10(cycle=False), buf_size=1024), batch_size=BATCH_SIZE, drop_last=True)\n",
+ "test_reader = paddle.batch(paddle.dataset.cifar.test10(cycle=False), batch_size=BATCH_SIZE, drop_last=False)\n",
+ "train_loader.set_sample_list_generator(train_reader, places=place)\n",
+ "test_loader.set_sample_list_generator(test_reader, places=place)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Train Epoch 0, Step 0, Lr 0.02500000, loss 3.310467, acc_1 0.062500, acc_5 0.468750\n"
+ ]
+ }
+ ],
+ "source": [
+    "valid_top1_list = []\n",
+    "for epoch_id in range(RETAIN_EPOCH):\n",
+ " train_top1 = train(train_program, exe, epoch_id, train_loader, train_fetch_list)\n",
+ " print(\"TRAIN: Epoch {}, train_acc {:.6f}\".format(epoch_id, train_top1))\n",
+ " valid_top1 = valid(test_program, exe, epoch_id, test_loader, test_fetch_list)\n",
+ " print(\"TEST: Epoch {}, valid_acc {:.6f}\".format(epoch_id, valid_top1))\n",
+ " valid_top1_list.append(valid_top1)"
+ ]
+  }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python 2",
+ "language": "python",
+ "name": "python2"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 2
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython2",
+ "version": "2.7.12"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}
diff --git a/docs/zh_cn/tutorials/darts_nas_turorial.md b/docs/zh_cn/tutorials/darts_nas_turorial.md
new file mode 100644
index 0000000000000000000000000000000000000000..7f7acd1e24ec93fa0554b16a5b962f1b8cba41ff
--- /dev/null
+++ b/docs/zh_cn/tutorials/darts_nas_turorial.md
@@ -0,0 +1,275 @@
+# Advanced SANAS Tutorial
+
+## Results
+We build a search space from the final model structure found by DARTS (referred to as DARTS_model below) and run a search experiment with the SANAS method provided by PaddleSlim. The resulting model structure (referred to as DARTS_SA below) improves accuracy by 0.141% over DARTS_model and reduces the model size by 11.2%.
+
+## Search Tutorial
+This tutorial shows how to run a SANAS search experiment on top of DARTS_model and obtain the DARTS_SA result.
+
+The tutorial contains the following steps:
+1. Build the search space
+2. Import dependencies and define global variables
+3. Initialize a SANAS instance
+4. Define a function that counts model parameters
+5. Define a function that creates the network inputs
+6. Define a function that builds the program
+7. Define the training function
+8. Define the evaluation function
+9. Start the search
+  9.1 Get the next model structure
+  9.2 Build the corresponding train and test programs
+  9.3 Add a search constraint
+  9.4 Set up the execution environment
+  9.5 Define the input data
+  9.6 Start training and evaluation
+  9.7 Send the reward of the current model back to SANAS
+10. Start the search with the script under demo
+11. Run the final experiment with the script under demo
+
+### 1. Build the search space
+Before running the search experiment, we first build a search space that matches the structure of DARTS_model. This experiment only searches over the channel numbers of DARTS_model; the goal is a model with higher accuracy and fewer parameters.
+The search space is defined as follows:
+- Channel number `filter_num`: the set of channel counts each convolution operation can take. Its candidate values are `[4, 8, 12, 16, 20, 36, 54, 72, 90, 108, 144, 180, 216, 252]`.
+
+Grouping the cells of DARTS_model by channel number gives 3 blocks: the first block contains only 6 normal cells, and each of the two following blocks contains one reduction cell and 6 normal cells, for 20 cells in total. When building the search space we let all convolution operations inside one cell share the same channel number, so a model is described by 20 tokens; the sketch below makes this mapping concrete.
+
+The complete search space is available at [the DARTS_model-based search space](../../../paddleslim/nas/search_space/darts_space.py).
+
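+A small illustrative sketch (for exposition only; the real mapping lives in `DartsSpace.token2arch`) of how a 20-element token list indexes into `filter_num`:
+
+```python
+import numpy as np
+
+# Candidate channel counts defined by the search space.
+filter_num = np.array([4, 8, 12, 16, 20, 36, 54, 72, 90, 108, 144, 180, 216, 252])
+
+# One token per cell (20 cells in total); each token is an index into filter_num.
+tokens = [5] * 6 + [7] * 7 + [10] * 7  # the default init_tokens of DartsSpace
+channels = [int(filter_num[t]) for t in tokens]
+print(channels)  # [36, 36, 36, 36, 36, 36, 72, 72, 72, 72, 72, 72, 72, 144, ...]
+```
+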
+### 2. Import dependencies and define global variables
+```python
+import numpy as np
+import paddle
+import paddle.fluid as fluid
+from paddleslim.nas import SANAS
+
+BATCH_SIZE=96
+SERVER_ADDRESS = ""
+PORT = 8377
+SEARCH_STEPS = 300
+RETAIN_EPOCH=30
+MAX_PARAMS=3.77
+IMAGE_SHAPE=[3, 32, 32]
+AUXILIARY = True
+AUXILIARY_WEIGHT= 0.4
+TRAINSET_NUM = 50000
+LR = 0.025
+MOMENTUM = 0.9
+WEIGHT_DECAY = 0.0003
+DROP_PATH_PROBILITY = 0.2
+```
+
+### 3. Initialize a SANAS instance
+First, initialize a SANAS instance.
+```python
+config = [('DartsSpace')]
+sa_nas = SANAS(config, server_addr=(SERVER_ADDRESS, PORT), search_steps=SEARCH_STEPS, is_server=True)
+```
+
+### 4. Define a function that counts model parameters
+Compute the parameter count of the current model from the given program. This tutorial uses the parameter count as the search constraint.
+```python
+def count_parameters_in_MB(all_params, prefix='model'):
+ parameters_number = 0
+ for param in all_params:
+ if param.name.startswith(
+ prefix) and param.trainable and 'aux' not in param.name:
+ parameters_number += np.prod(param.shape)
+ return parameters_number / 1e6
+```
+
+### 5. Define a function that creates the network inputs
+Define the network inputs based on the input image size: the image, the label, and, for training, the drop-path probability and the mask used to randomly drop units.
+```python
+def create_data_loader(IMAGE_SHAPE, is_train):
+ image = fluid.data(
+ name="image", shape=[None] + IMAGE_SHAPE, dtype="float32")
+ label = fluid.data(name="label", shape=[None, 1], dtype="int64")
+ data_loader = fluid.io.DataLoader.from_generator(
+ feed_list=[image, label],
+ capacity=64,
+ use_double_buffer=True,
+ iterable=True)
+ drop_path_prob = ''
+ drop_path_mask = ''
+ if is_train:
+ drop_path_prob = fluid.data(
+ name="drop_path_prob", shape=[BATCH_SIZE, 1], dtype="float32")
+ drop_path_mask = fluid.data(
+ name="drop_path_mask",
+ shape=[BATCH_SIZE, 20, 4, 2],
+ dtype="float32")
+
+ return data_loader, image, label, drop_path_prob, drop_path_mask
+```
+
+### 6. Define a function that builds the program
+Build the program from the model structure, the input image size, and whether the program runs in training mode.
+```python
+def build_program(main_program, startup_program, IMAGE_SHAPE, archs, is_train):
+ with fluid.program_guard(main_program, startup_program):
+ data_loader, data, label, drop_path_prob, drop_path_mask = create_data_loader(
+ IMAGE_SHAPE, is_train)
+ logits, logits_aux = archs(data, drop_path_prob, drop_path_mask,
+ is_train, 10)
+ top1 = fluid.layers.accuracy(input=logits, label=label, k=1)
+ top5 = fluid.layers.accuracy(input=logits, label=label, k=5)
+ loss = fluid.layers.reduce_mean(
+ fluid.layers.softmax_with_cross_entropy(logits, label))
+
+ if is_train:
+ if AUXILIARY:
+ loss_aux = fluid.layers.reduce_mean(
+ fluid.layers.softmax_with_cross_entropy(logits_aux, label))
+ loss = loss + AUXILIARY_WEIGHT * loss_aux
+ step_per_epoch = int(TRAINSET_NUM / BATCH_SIZE)
+ learning_rate = fluid.layers.cosine_decay(LR, step_per_epoch, RETAIN_EPOCH)
+ fluid.clip.set_gradient_clip(
+ clip=fluid.clip.GradientClipByGlobalNorm(clip_norm=5.0))
+ optimizer = fluid.optimizer.MomentumOptimizer(
+ learning_rate,
+ MOMENTUM,
+ regularization=fluid.regularizer.L2DecayRegularizer(
+ WEIGHT_DECAY))
+ optimizer.minimize(loss)
+ outs = [loss, top1, top5, learning_rate]
+ else:
+ outs = [loss, top1, top5]
+ return outs, data_loader
+
+```
+
+### 7. Define the training function
+```python
+def train(main_prog, exe, epoch_id, train_loader, fetch_list):
+ loss = []
+ top1 = []
+ top5 = []
+ for step_id, data in enumerate(train_loader()):
+ devices_num = len(data)
+ if DROP_PATH_PROBILITY > 0:
+ feed = []
+ for device_id in range(devices_num):
+ image = data[device_id]['image']
+ label = data[device_id]['label']
+ drop_path_prob = np.array(
+ [[DROP_PATH_PROBILITY * epoch_id / RETAIN_EPOCH]
+ for i in range(BATCH_SIZE)]).astype(np.float32)
+ drop_path_mask = 1 - np.random.binomial(
+ 1, drop_path_prob[0],
+ size=[BATCH_SIZE, 20, 4, 2]).astype(np.float32)
+ feed.append({
+ "image": image,
+ "label": label,
+ "drop_path_prob": drop_path_prob,
+ "drop_path_mask": drop_path_mask
+ })
+ else:
+ feed = data
+ loss_v, top1_v, top5_v, lr = exe.run(
+ main_prog, feed=feed, fetch_list=[v.name for v in fetch_list])
+ loss.append(loss_v)
+ top1.append(top1_v)
+ top5.append(top5_v)
+ if step_id % 10 == 0:
+ print(
+ "Train Epoch {}, Step {}, Lr {:.8f}, loss {:.6f}, acc_1 {:.6f}, acc_5 {:.6f}".
+ format(epoch_id, step_id, lr[0], np.mean(loss), np.mean(top1), np.mean(top5)))
+ return np.mean(top1)
+```
+
+### 8. Define the evaluation function
+```python
+def valid(main_prog, exe, epoch_id, valid_loader, fetch_list):
+ loss = []
+ top1 = []
+ top5 = []
+ for step_id, data in enumerate(valid_loader()):
+ loss_v, top1_v, top5_v = exe.run(
+ main_prog, feed=data, fetch_list=[v.name for v in fetch_list])
+ loss.append(loss_v)
+ top1.append(top1_v)
+ top5.append(top5_v)
+ if step_id % 10 == 0:
+ print(
+ "Valid Epoch {}, Step {}, loss {:.6f}, acc_1 {:.6f}, acc_5 {:.6f}".
+ format(epoch_id, step_id, np.mean(loss), np.mean(top1), np.mean(top5)))
+ return np.mean(top1)
+```
+
+### 9. Start the search experiment
+The following steps break down how to obtain the current model structure and what to do once you have it.
+
+#### 9.1 Get the next model structure
+Get the next model structure from the SANAS instance created above.
+```python
+archs = sa_nas.next_archs()[0]
+```
+
+#### 9.2 Build the train and test programs
+Build the train program and the test program from the model structure obtained in the previous step.
+```python
+train_program = fluid.Program()
+test_program = fluid.Program()
+startup_program = fluid.Program()
+train_fetch_list, train_loader = build_program(train_program, startup_program, IMAGE_SHAPE, archs, is_train=True)
+test_fetch_list, test_loader = build_program(test_program, startup_program, IMAGE_SHAPE, archs, is_train=False)
+test_program = test_program.clone(for_test=True)
+```
+
+#### 9.3 Add a search constraint
+This tutorial uses the parameter count as the constraint. First compute the parameter count of the current program; if it exceeds the limit, stop training this model structure and fetch the next one, as sketched after the code below.
+```python
+current_params = count_parameters_in_MB(
+ train_program.global_block().all_parameters(), 'cifar10')
+```
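+
+In the demo script (`demo/nas/darts_nas.py`) this limit is enforced inside the search loop by skipping any architecture whose parameter count exceeds the limit and requesting the next one; a minimal sketch of that control flow, using `SEARCH_STEPS` and `MAX_PARAMS` from step 2 and `count_parameters_in_MB` from step 4:
+
+```python
+for step in range(SEARCH_STEPS):
+    archs = sa_nas.next_archs()[0]
+    # ... build the train/test programs as in step 9.2 ...
+    current_params = count_parameters_in_MB(
+        train_program.global_block().all_parameters(), 'cifar10')
+    if current_params > MAX_PARAMS:  # MAX_PARAMS = 3.77 (million parameters)
+        continue  # too large: skip training and sample the next architecture
+```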
+
+#### 9.4 Set up the execution environment
+Define the place and executor for the data and model, and initialize the parameters.
+```python
+place = fluid.CPUPlace()
+exe = fluid.Executor(place)
+exe.run(startup_program)
+```
+
+#### 9.5 Define the input data
+This example applies some extra preprocessing to the CIFAR-10 images, so its reader differs from the one in the [quick start](../quick_start/nas_tutorial.md) tutorial: a custom CIFAR-10 reader is required instead of the built-in `paddle.dataset.cifar10` reader. The custom reader is located in [demo/nas](../../../demo/nas/darts_cifar10_reader.py).
+
+**Note:** to keep the code short, this example calls `paddle.dataset.cifar10` directly for the training and test data; actual training should use the custom CIFAR-10 reader.
+```python
+train_reader = paddle.batch(paddle.reader.shuffle(paddle.dataset.cifar.train10(cycle=False), buf_size=1024), batch_size=BATCH_SIZE, drop_last=True)
+test_reader = paddle.batch(paddle.dataset.cifar.test10(cycle=False), batch_size=BATCH_SIZE, drop_last=False)
+train_loader.set_sample_list_generator(train_reader, places=place)
+test_loader.set_sample_list_generator(test_reader, places=place)
+```
+
+#### 9.6 Start training and evaluation
+```python
+valid_top1_list = []
+for epoch_id in range(RETAIN_EPOCH):
+ train_top1 = train(train_program, exe, epoch_id, train_loader, train_fetch_list)
+ print("TRAIN: Epoch {}, train_acc {:.6f}".format(epoch_id, train_top1))
+ valid_top1 = valid(test_program, exe, epoch_id, test_loader, test_fetch_list)
+ print("TEST: Epoch {}, valid_acc {:.6f}".format(epoch_id, valid_top1))
+ valid_top1_list.append(valid_top1)
+```
+
+#### 9.7 Send the reward of the current model back to SANAS
+This tutorial uses the mean accuracy of the last two epochs as the final score and sends it back to SANAS.
+```python
+sa_nas.reward(float(valid_top1_list[-1] + valid_top1_list[-2]) / 2)
+```
+
+
+### 10. Start the search with the script under demo
+
+The search script is located at [darts_sanas_demo](https://github.com/PaddlePaddle/PaddleSlim/tree/develop/demo/nas/darts_nas.py); during the search the model parameter count is limited to at most 3.77M.
+```shell
+cd demo/nas/
+python darts_nas.py
+```
+
+### 11. Run the final experiment with the script under demo
+The final experiment script is also located at [darts_sanas_demo](https://github.com/PaddlePaddle/PaddleSlim/tree/develop/demo/nas/darts_nas.py); the final experiment trains for 600 epochs. The example below passes the token `[5, 5, 0, 5, 5, 10, 7, 7, 5, 7, 7, 11, 10, 12, 10, 0, 5, 3, 10, 8]`.
+```shell
+cd demo/nas/
+python darts_nas.py --token 5 5 0 5 5 10 7 7 5 7 7 11 10 12 10 0 5 3 10 8 --retain_epoch 600
+```
diff --git a/paddleslim/nas/search_space/__init__.py b/paddleslim/nas/search_space/__init__.py
index bd8a3d3141dd3390176cc833101003d1cc2a6351..ba72463f020da484fedb75d8f27443347cbd086c 100644
--- a/paddleslim/nas/search_space/__init__.py
+++ b/paddleslim/nas/search_space/__init__.py
@@ -18,11 +18,12 @@ from .resnet import ResNetSpace
from .mobilenet_block import MobileNetV1BlockSpace, MobileNetV2BlockSpace
from .resnet_block import ResNetBlockSpace
from .inception_block import InceptionABlockSpace, InceptionCBlockSpace
+from .darts_space import DartsSpace
from .search_space_registry import SEARCHSPACE
from .search_space_factory import SearchSpaceFactory
from .search_space_base import SearchSpaceBase
__all__ = [
- 'MobileNetV1Space', 'MobileNetV2Space', 'ResNetSpace',
+ 'MobileNetV1Space', 'MobileNetV2Space', 'ResNetSpace', 'DartsSpace',
'MobileNetV1BlockSpace', 'MobileNetV2BlockSpace', 'ResNetBlockSpace',
'InceptionABlockSpace', 'InceptionCBlockSpace', 'SearchSpaceBase',
'SearchSpaceFactory', 'SEARCHSPACE'
diff --git a/paddleslim/nas/search_space/darts_space.py b/paddleslim/nas/search_space/darts_space.py
new file mode 100644
index 0000000000000000000000000000000000000000..ae2c5e3ec1fe6dfb94e38f84f74923ef92d70de7
--- /dev/null
+++ b/paddleslim/nas/search_space/darts_space.py
@@ -0,0 +1,626 @@
+# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+import paddle.fluid as fluid
+from paddle.fluid.param_attr import ParamAttr
+from paddle.fluid.initializer import UniformInitializer, ConstantInitializer
+from .search_space_base import SearchSpaceBase
+from .base_layer import conv_bn_layer
+from .search_space_registry import SEARCHSPACE
+
+
+@SEARCHSPACE.register
+class DartsSpace(SearchSpaceBase):
+ def __init__(self, input_size, output_size, block_num, block_mask):
+ super(DartsSpace, self).__init__(input_size, output_size, block_num,
+ block_mask)
+ self.filter_num = np.array(
+ [4, 8, 12, 16, 20, 36, 54, 72, 90, 108, 144, 180, 216, 252])
+
+ def init_tokens(self):
+ return [5] * 6 + [7] * 7 + [10] * 7
+
+ def range_table(self):
+ return [len(self.filter_num)] * 20
+
+ def token2arch(self, tokens=None):
+        if tokens is None:
+            tokens = self.init_tokens()
+
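+        # 20 tokens -> 20 cells: three groups of 6 normal cells (stride 1),
+        # with one reduction cell (stride 2) appended after each of the first
+        # two groups. Each token indexes filter_num to pick the cell's width.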
+ self.bottleneck_params_list = []
+ reduction_count = 0
+ for i in range(3):
+ for j in range(6):
+ block_idx = i * 6 + j + reduction_count
+ self.bottleneck_params_list.append(
+ (self.filter_num[tokens[block_idx]], 1))
+ if i < 2:
+ reduction_count += 1
+ block_idx = i * 6 + j + reduction_count
+ self.bottleneck_params_list.append(
+ (self.filter_num[tokens[block_idx]], 2))
+
+ def net_arch(input, drop_prob, drop_path_mask, is_train, num_classes):
+ c_in = 36
+ stem_multiplier = 3
+ c_curr = stem_multiplier * c_in
+ x = self._conv_bn(
+ input,
+ c_curr,
+ kernel_size=3,
+ padding=1,
+ stride=1,
+ name='cifar10_darts_conv0')
+ s0 = s1 = x
+
+ logits_aux = None
+ reduction_prev = False
+
+ for i, layer_setting in enumerate(self.bottleneck_params_list):
+ filter_num, stride = layer_setting[0], layer_setting[1]
+ if stride == 2:
+ reduction = True
+ else:
+ reduction = False
+
+ if is_train:
+ drop_path_cell = drop_path_mask[:, i, :, :]
+ else:
+ drop_path_cell = drop_path_mask
+
+ s0, s1 = s1, self._cell(
+ s0,
+ s1,
+ filter_num,
+ stride,
+ reduction_prev,
+ drop_prob,
+ drop_path_cell,
+ is_train,
+ name='cifar10_darts_layer{}'.format(i + 1))
+ reduction_prev = reduction
+
+ if i == 2 * 20 // 3:
+ if is_train:
+ logits_aux = self._auxiliary_cifar(
+ s1, num_classes,
+ "cifar10_darts_/l" + str(i) + "/aux")
+
+ logits = self._classifier(s1, num_classes, name='cifar10_darts')
+
+ return logits, logits_aux
+
+ return net_arch
+
+ def _classifier(self, x, num_classes, name):
+ out = fluid.layers.pool2d(x, pool_type='avg', global_pooling=True)
+ out = fluid.layers.squeeze(out, axes=[2, 3])
+ k = (1. / out.shape[1])**0.5
+ out = fluid.layers.fc(out,
+ num_classes,
+ param_attr=fluid.ParamAttr(
+ name=name + "/fc_weights",
+ initializer=UniformInitializer(
+ low=-k, high=k)),
+ bias_attr=fluid.ParamAttr(
+ name=name + "/fc_bias",
+ initializer=UniformInitializer(
+ low=-k, high=k)))
+ return out
+
+ def _auxiliary_cifar(self, x, num_classes, name):
+ x = fluid.layers.relu(x)
+ pooled = fluid.layers.pool2d(
+ x, pool_size=5, pool_stride=3, pool_padding=0, pool_type='avg')
+ conv1 = self._conv_bn(
+ x=pooled,
+ c_out=128,
+ kernel_size=1,
+ padding=0,
+ stride=1,
+ name=name + '/conv_bn1')
+ conv1 = fluid.layers.relu(conv1)
+ conv2 = self._conv_bn(
+ x=conv1,
+ c_out=768,
+ kernel_size=2,
+ padding=0,
+ stride=1,
+ name=name + '/conv_bn2')
+ conv2 = fluid.layers.relu(conv2)
+ out = self._classifier(conv2, num_classes, name)
+ return out
+
+ def _cell(self,
+ s0,
+ s1,
+ filter_num,
+ stride,
+ reduction_prev,
+ drop_prob,
+ drop_path_cell,
+ is_train,
+ name=None):
+ if reduction_prev:
+ s0 = self._factorized_reduce(s0, filter_num, name=name + '/s-2')
+ else:
+ s0 = self._relu_conv_bn(
+ s0, filter_num, 1, 1, 0, name=name + '/s-2')
+ s1 = self._relu_conv_bn(s1, filter_num, 1, 1, 0, name=name + '/s-1')
+
+ if stride == 1:
+ out = self._normal_cell(
+ s0,
+ s1,
+ filter_num,
+ drop_prob,
+ drop_path_cell,
+ is_train,
+ name=name)
+ else:
+ out = self._reduction_cell(
+ s0,
+ s1,
+ filter_num,
+ drop_prob,
+ drop_path_cell,
+ is_train,
+ name=name)
+ return out
+
+ def _normal_cell(self,
+ s0,
+ s1,
+ filter_num,
+ drop_prob,
+ drop_path_cell,
+ is_train,
+ name=None):
+ hidden0_0 = self._dil_conv(
+ s0,
+ c_out=filter_num,
+ kernel_size=3,
+ stride=1,
+ padding=2,
+ dilation=2,
+ affine=True,
+ name=name + '_normal_cell_hidden0_0')
+ hidden0_1 = self._sep_conv(
+ s1,
+ c_out=filter_num,
+ kernel_size=3,
+ stride=1,
+ padding=1,
+ affine=True,
+ name=name + '_normal_cell_hidden0_1')
+
+ if is_train:
+ hidden0_0 = self._drop_path(
+ hidden0_0,
+ drop_prob,
+ drop_path_cell[:, 0, 0],
+ name=name + '_normal_cell_hidden0_0')
+ hidden0_1 = self._drop_path(
+ hidden0_1,
+ drop_prob,
+ drop_path_cell[:, 0, 1],
+ name=name + '_normal_cell_hidden0_1')
+ n0 = hidden0_0 + hidden0_1
+
+ hidden1_0 = self._sep_conv(
+ s0,
+ c_out=filter_num,
+ kernel_size=3,
+ stride=1,
+ padding=1,
+ affine=True,
+ name=name + '_normal_cell_hidden1_0')
+ hidden1_1 = self._sep_conv(
+ s1,
+ c_out=filter_num,
+ kernel_size=3,
+ stride=1,
+ padding=1,
+ affine=True,
+ name=name + '_normal_cell_hidden1_1')
+ if is_train:
+ hidden1_0 = self._drop_path(
+ hidden1_0,
+ drop_prob,
+ drop_path_cell[:, 1, 0],
+ name=name + '_normal_cell_hidden1_0')
+ hidden1_1 = self._drop_path(
+ hidden1_1,
+ drop_prob,
+ drop_path_cell[:, 1, 1],
+ name=name + '_normal_cell_hidden1_1')
+ n1 = hidden1_0 + hidden1_1
+
+ hidden2_0 = self._sep_conv(
+ s0,
+ c_out=filter_num,
+ kernel_size=3,
+ stride=1,
+ padding=1,
+ affine=True,
+ name=name + '_normal_cell_hidden2_0')
+ hidden2_1 = self._sep_conv(
+ s1,
+ c_out=filter_num,
+ kernel_size=3,
+ stride=1,
+ padding=1,
+ affine=True,
+ name=name + '_normal_cell_hidden2_1')
+ if is_train:
+ hidden2_0 = self._drop_path(
+ hidden2_0,
+ drop_prob,
+ drop_path_cell[:, 2, 0],
+ name=name + '_normal_cell_hidden2_0')
+ hidden2_1 = self._drop_path(
+ hidden2_1,
+ drop_prob,
+ drop_path_cell[:, 2, 1],
+ name=name + '_normal_cell_hidden2_1')
+ n2 = hidden2_0 + hidden2_1
+
+ ### skip connect => identity
+ hidden3_0 = s0
+ hidden3_1 = self._sep_conv(
+ s1,
+ c_out=filter_num,
+ kernel_size=3,
+ stride=1,
+ padding=1,
+ affine=True,
+ name=name + '_normal_cell_hidden3_1')
+ if is_train:
+ hidden3_1 = self._drop_path(
+ hidden3_1,
+ drop_prob,
+ drop_path_cell[:, 3, 1],
+ name=name + '_normal_cell_hidden3_1')
+ n3 = hidden3_0 + hidden3_1
+
+ out = fluid.layers.concat(
+ input=[n0, n1, n2, n3], axis=1, name=name + '_normal_cell_concat')
+ return out
+
+ def _reduction_cell(self,
+ s0,
+ s1,
+ filter_num,
+ drop_prob,
+ drop_path_cell,
+ is_train,
+ name=None):
+ hidden0_0 = fluid.layers.pool2d(
+ input=s0,
+ pool_size=3,
+ pool_type="max",
+ pool_stride=2,
+ pool_padding=1,
+ name=name + '_reduction_cell_hidden0_0')
+ hidden0_1 = self._factorized_reduce(
+ s1,
+ filter_num,
+ affine=True,
+ name=name + '_reduction_cell_hidden0_1')
+ if is_train:
+ hidden0_0 = self._drop_path(
+ hidden0_0,
+ drop_prob,
+ drop_path_cell[:, 0, 0],
+ name=name + '_reduction_cell_hidden0_0')
+ r0 = hidden0_0 + hidden0_1
+
+ hidden1_0 = fluid.layers.pool2d(
+ input=s1,
+ pool_size=3,
+ pool_type="max",
+ pool_stride=2,
+ pool_padding=1,
+ name=name + '_reduction_cell_hidden1_0')
+ hidden1_1 = r0
+ if is_train:
+ hidden1_0 = self._drop_path(
+ hidden1_0,
+ drop_prob,
+ drop_path_cell[:, 1, 0],
+ name=name + '_reduction_cell_hidden1_0')
+ r1 = hidden1_0 + hidden1_1
+
+ hidden2_0 = r0
+ hidden2_1 = self._dil_conv(
+ r1,
+ c_out=filter_num,
+ kernel_size=5,
+ stride=1,
+ padding=4,
+ dilation=2,
+ affine=True,
+ name=name + '_reduction_cell_hidden2_1')
+ if is_train:
+ hidden2_1 = self._drop_path(
+ hidden2_1,
+ drop_prob,
+ drop_path_cell[:, 2, 0],
+ name=name + '_reduction_cell_hidden2_1')
+ r2 = hidden2_0 + hidden2_1
+
+ hidden3_0 = r0
+ hidden3_1 = fluid.layers.pool2d(
+ input=s1,
+ pool_size=3,
+ pool_type="max",
+ pool_stride=2,
+ pool_padding=1,
+ name=name + '_reduction_cell_hidden3_1')
+ if is_train:
+ hidden3_1 = self._drop_path(
+ hidden3_1,
+ drop_prob,
+ drop_path_cell[:, 3, 0],
+ name=name + '_reduction_cell_hidden3_1')
+ r3 = hidden3_0 + hidden3_1
+
+ out = fluid.layers.concat(
+ input=[r0, r1, r2, r3],
+ axis=1,
+ name=name + '_reduction_cell_concat')
+ return out
+
+ def _conv_bn(self, x, c_out, kernel_size, padding, stride, name):
+ k = (1. / x.shape[1] / kernel_size / kernel_size)**0.5
+ conv1 = fluid.layers.conv2d(
+ x,
+ c_out,
+ kernel_size,
+ stride=stride,
+ padding=padding,
+ param_attr=fluid.ParamAttr(
+ name=name + "/conv",
+ initializer=UniformInitializer(
+ low=-k, high=k)),
+ bias_attr=False)
+ bn1 = fluid.layers.batch_norm(
+ conv1,
+ param_attr=fluid.ParamAttr(
+ name=name + "/bn_scale",
+ initializer=ConstantInitializer(value=1)),
+ bias_attr=fluid.ParamAttr(
+ name=name + "/bn_offset",
+ initializer=ConstantInitializer(value=0)),
+ moving_mean_name=name + "/bn_mean",
+ moving_variance_name=name + "/bn_variance")
+ return bn1
+
+ def _sep_conv(self,
+ x,
+ c_out,
+ kernel_size,
+ stride,
+ padding,
+ affine=True,
+ name=''):
+ c_in = x.shape[1]
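+        # DARTS-style separable conv, applied twice: each half is
+        # ReLU -> depthwise conv (groups=c_in) -> pointwise 1x1 conv -> BN.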
+ x = fluid.layers.relu(x)
+ k = (1. / x.shape[1] / kernel_size / kernel_size)**0.5
+ x = fluid.layers.conv2d(
+ x,
+ c_in,
+ kernel_size,
+ stride=stride,
+ padding=padding,
+ groups=c_in,
+ use_cudnn=False,
+ param_attr=fluid.ParamAttr(
+ name=name + "/sep_conv_1_1",
+ initializer=UniformInitializer(
+ low=-k, high=k)),
+ bias_attr=False)
+ k = (1. / x.shape[1] / 1 / 1)**0.5
+ x = fluid.layers.conv2d(
+ x,
+ c_in,
+ 1,
+ padding=0,
+ param_attr=fluid.ParamAttr(
+ name=name + "/sep_conv_1_2",
+ initializer=UniformInitializer(
+ low=-k, high=k)),
+ bias_attr=False)
+ gama, beta = self._bn_param_config(name, affine, "sep_conv_bn1")
+ x = fluid.layers.batch_norm(
+ x,
+ param_attr=gama,
+ bias_attr=beta,
+ moving_mean_name=name + "/sep_bn1_mean",
+ moving_variance_name=name + "/sep_bn1_variance")
+
+ x = fluid.layers.relu(x)
+ k = (1. / x.shape[1] / kernel_size / kernel_size)**0.5
+ c_in = x.shape[1]
+ x = fluid.layers.conv2d(
+ x,
+ c_in,
+ kernel_size,
+ stride=1,
+ padding=padding,
+ groups=c_in,
+ use_cudnn=False,
+ param_attr=fluid.ParamAttr(
+ name=name + "/sep_conv2_1",
+ initializer=UniformInitializer(
+ low=-k, high=k)),
+ bias_attr=False)
+ k = (1. / x.shape[1] / 1 / 1)**0.5
+ x = fluid.layers.conv2d(
+ x,
+ c_out,
+ 1,
+ padding=0,
+ param_attr=fluid.ParamAttr(
+ name=name + "/sep_conv2_2",
+ initializer=UniformInitializer(
+ low=-k, high=k)),
+ bias_attr=False)
+ gama, beta = self._bn_param_config(name, affine, "sep_conv_bn2")
+ x = fluid.layers.batch_norm(
+ x,
+ param_attr=gama,
+ bias_attr=beta,
+ moving_mean_name=name + "/sep_bn2_mean",
+ moving_variance_name=name + "/sep_bn2_variance")
+ return x
+
+ def _dil_conv(self,
+ x,
+ c_out,
+ kernel_size,
+ stride,
+ padding,
+ dilation,
+ affine=True,
+ name=''):
+ c_in = x.shape[1]
+ x = fluid.layers.relu(x)
+ k = (1. / x.shape[1] / kernel_size / kernel_size)**0.5
+ x = fluid.layers.conv2d(
+ x,
+ c_in,
+ kernel_size,
+ stride=stride,
+ padding=padding,
+ dilation=dilation,
+ groups=c_in,
+ use_cudnn=False,
+ param_attr=fluid.ParamAttr(
+ name=name + "/dil_conv1",
+ initializer=UniformInitializer(
+ low=-k, high=k)),
+ bias_attr=False)
+ k = (1. / x.shape[1] / 1 / 1)**0.5
+ x = fluid.layers.conv2d(
+ x,
+ c_out,
+ 1,
+ padding=0,
+ param_attr=fluid.ParamAttr(
+ name=name + "/dil_conv2",
+ initializer=UniformInitializer(
+ low=-k, high=k)),
+ bias_attr=False)
+ gama, beta = self._bn_param_config(name, affine, "dil_conv_bn")
+ x = fluid.layers.batch_norm(
+ x,
+ param_attr=gama,
+ bias_attr=beta,
+ moving_mean_name=name + "/dil_bn_mean",
+ moving_variance_name=name + "/dil_bn_variance")
+ return x
+
+ def _factorized_reduce(self, x, c_out, affine=True, name=''):
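+        # Halve the spatial size without dropping pixels: two stride-2 1x1
+        # convs, one on x and one on x shifted by one pixel, each producing
+        # c_out // 2 channels, concatenated and batch-normalized.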
+ assert c_out % 2 == 0
+ x = fluid.layers.relu(x)
+ x_sliced = x[:, :, 1:, 1:]
+ k = (1. / x.shape[1] / 1 / 1)**0.5
+ conv1 = fluid.layers.conv2d(
+ x,
+ c_out // 2,
+ 1,
+ stride=2,
+ param_attr=fluid.ParamAttr(
+ name=name + "/fr_conv1",
+ initializer=UniformInitializer(
+ low=-k, high=k)),
+ bias_attr=False)
+ k = (1. / x_sliced.shape[1] / 1 / 1)**0.5
+ conv2 = fluid.layers.conv2d(
+ x_sliced,
+ c_out // 2,
+ 1,
+ stride=2,
+ param_attr=fluid.ParamAttr(
+ name=name + "/fr_conv2",
+ initializer=UniformInitializer(
+ low=-k, high=k)),
+ bias_attr=False)
+ x = fluid.layers.concat(input=[conv1, conv2], axis=1)
+ gama, beta = self._bn_param_config(name, affine, "fr_bn")
+ x = fluid.layers.batch_norm(
+ x,
+ param_attr=gama,
+ bias_attr=beta,
+ moving_mean_name=name + "/fr_mean",
+ moving_variance_name=name + "/fr_variance")
+ return x
+
+ def _relu_conv_bn(self,
+ x,
+ c_out,
+ kernel_size,
+ stride,
+ padding,
+ affine=True,
+ name=''):
+ x = fluid.layers.relu(x)
+ k = (1. / x.shape[1] / kernel_size / kernel_size)**0.5
+ x = fluid.layers.conv2d(
+ x,
+ c_out,
+ kernel_size,
+ stride=stride,
+ padding=padding,
+ param_attr=fluid.ParamAttr(
+ name=name + "/rcb_conv",
+ initializer=UniformInitializer(
+ low=-k, high=k)),
+ bias_attr=False)
+ gama, beta = self._bn_param_config(name, affine, "rcb_bn")
+ x = fluid.layers.batch_norm(
+ x,
+ param_attr=gama,
+ bias_attr=beta,
+ moving_mean_name=name + "/rcb_mean",
+ moving_variance_name=name + "/rcb_variance")
+ return x
+
+ def _bn_param_config(self, name='', affine=False, op=None):
+ gama_name = name + "/" + str(op) + "/gama"
+ beta_name = name + "/" + str(op) + "/beta"
+ gama = ParamAttr(
+ name=gama_name,
+ initializer=ConstantInitializer(value=1),
+ trainable=affine)
+ beta = ParamAttr(
+ name=beta_name,
+ initializer=ConstantInitializer(value=0),
+ trainable=affine)
+ return gama, beta
+
+ def _drop_path(self, x, drop_prob, mask, name=None):
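+        # Inverted drop path: the 0/1 mask zeroes whole samples (axis 0) and
+        # survivors are rescaled by 1 / keep_prob to preserve the expectation.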
+ keep_prob = 1 - drop_prob[0]
+ x = fluid.layers.elementwise_mul(
+ x / keep_prob,
+ mask,
+ axis=0,
+ name=name + '_drop_path_elementwise_mul')
+ return x