未验证 提交 9a9cf241 编写于 作者: B Bai Yifan 提交者: GitHub

Add darts api doc (#250)

* add darts en api doc

* add darts ch api doc

* fix example code
上级 83d5d128
...@@ -15,6 +15,7 @@ API Documents ...@@ -15,6 +15,7 @@ API Documents
paddleslim.quant.rst paddleslim.quant.rst
paddleslim.nas.rst paddleslim.nas.rst
paddleslim.nas.one_shot.rst paddleslim.nas.one_shot.rst
paddleslim.nas.darts.rst
paddleslim.pantheon.rst paddleslim.pantheon.rst
search_space_en.md search_space_en.md
table_latency_en.md table_latency_en.md
paddleslim\.nas\.darts package
==============================
.. automodule:: paddleslim.nas.darts
:members:
:undoc-members:
:show-inheritance:
...@@ -39,6 +39,7 @@ PaddleSlim also provides auxiliary and primitive API for developer and researche ...@@ -39,6 +39,7 @@ PaddleSlim also provides auxiliary and primitive API for developer and researche
- Neural architecture search based on evolution strategy. - Neural architecture search based on evolution strategy.
- Support distributed search. - Support distributed search.
- One-Shot neural architecture search. - One-Shot neural architecture search.
- Differentiable Architecture Search.
- Support FLOPs and latency constrained search. - Support FLOPs and latency constrained search.
- Support the latency estimation on different hardware and platforms. - Support the latency estimation on different hardware and platforms.
......
可微分模型架构搜索DARTS
=========
DARTSearch
---------
.. py:class:: paddleslim.nas.DARTSearch(model, train_reader, valid_reader, place, learning_rate=0.025, batchsize=64, num_imgs=50000, arch_learning_rate=3e-4, unrolled=False, num_epochs=50, epochs_no_archopt=0, use_data_parallel=False, save_dir='./', log_freq=50)
`源代码 <https://github.com/PaddlePaddle/PaddleSlim/blob/release/1.1.0/paddleslim/nas/darts/train_search.py>`_
定义一个DARTS搜索示例,用于在特定数据集和搜索空间上启动模型架构搜索。
**参数:**
- **model** (Paddle Dygraph model)-用于搜索的超网络模型,需要以PaddlePaddle动态图的形式定义。
- **train_reader** (Python Generator)-输入train数据的 `batch generator <https://www.paddlepaddle.org.cn/documentation/docs/zh/develop/api_cn/io_cn/DataLoader_cn.html>`_
- **valid_reader** (Python Generator)-输入valid数据的 `batch generator <https://www.paddlepaddle.org.cn/documentation/docs/zh/develop/api_cn/io_cn/DataLoader_cn.html>`_
- **place** (fluid.CPUPlace()|fluid.CUDAPlace(N))-该参数表示程序运行在何种设备上,这里的NGPU对应的ID
- **learning_rate** (float)-模型参数的初始学习率。默认值:0.025
- **batchsize** (int)-搜索过程数据的批大小。默认值:64
- **arch_learning_rate** (float)-架构参数的学习率。默认值:3e-4
- **unrolled** (bool)-是否使用二阶搜索算法。默认值:False
- **num_epochs** (int)-搜索训练的轮数。默认值:50
- **epochs_no_archopt** (int)-跳过前若干轮的模型架构参数优化。默认值:0
- **use_data_parallel** (bool)-是否使用数据并行的多卡训练。默认值:False
- **log_freq** (int)-每多少步输出一条log。默认值:50
.. py:method:: paddleslim.nas.DARTSearch.train()
对以上定义好的目标网络和数据进行DARTS搜索
**使用示例:**
.. code-block:: python
import paddle
import paddle.fluid as fluid
import numpy as np
from paddleslim.nas.darts import DARTSearch
class SuperNet(fluid.dygraph.Layer):
def __init__(self):
super(SuperNet, self).__init__()
self._method = 'DARTS'
self._steps = 1
self.stem=fluid.dygraph.nn.Conv2D(
num_channels=1,
num_filters=3,
filter_size=3,
padding=1)
self.classifier = fluid.dygraph.nn.Linear(
input_dim=3072,
output_dim=10)
self._multiplier = 4
self._primitives = ['none', 'max_pool_3x3', 'avg_pool_3x3', 'skip_connect', 'sep_conv_3x3', 'sep_conv_5x5', 'dil_conv_3x3', 'dil_conv_5x5']
self._initialize_alphas()
def _initialize_alphas(self):
self.alphas_normal = fluid.layers.create_parameter(
shape=[14, 8],
dtype="float32")
self.alphas_reduce = fluid.layers.create_parameter(
shape=[14, 8],
dtype="float32")
self._arch_parameters = [
self.alphas_normal,
self.alphas_reduce,
]
def arch_parameters(self):
return self._arch_parameters
def forward(self, input):
out = self.stem(input) * self.alphas_normal[0][0] * self.alphas_reduce[0][0]
out = fluid.layers.reshape(out, [0, -1])
logits = self.classifier(out)
return logits
def _loss(self, input, label):
logits = self.forward(input)
return fluid.layers.reduce_mean(fluid.layers.softmax_with_cross_entropy(logits, label))
def batch_generator_creator():
def __reader__():
for _ in range(1024):
batch_image = np.random.random(size=[64, 1, 32, 32]).astype('float32')
batch_label = np.random.random(size=[64, 1]).astype('int64')
yield batch_image, batch_label
return __reader__
place = fluid.CUDAPlace(0)
with fluid.dygraph.guard(place):
model = SuperNet()
train_reader = batch_generator_creator()
valid_reader = batch_generator_creator()
searcher = DARTSearch(model, train_reader, valid_reader, unrolled=False)
searcher.train()
..
...@@ -12,6 +12,7 @@ API文档 ...@@ -12,6 +12,7 @@ API文档
analysis_api.rst analysis_api.rst
nas_api.rst nas_api.rst
one_shot_api.rst one_shot_api.rst
darts.rst
pantheon_api.md pantheon_api.md
prune_api.rst prune_api.rst
quantization_api.rst quantization_api.rst
......
...@@ -28,6 +28,7 @@ PaddleSlim会从底层能力、技术咨询合作和业务场景等角度支持 ...@@ -28,6 +28,7 @@ PaddleSlim会从底层能力、技术咨询合作和业务场景等角度支持
- 神经网络结构自动搜索(NAS) - 神经网络结构自动搜索(NAS)
- 支持基于进化算法的轻量神经网络结构自动搜索 - 支持基于进化算法的轻量神经网络结构自动搜索
- 支持One-Shot网络结构自动搜索 - 支持One-Shot网络结构自动搜索
- 支持基于梯度的DARTS网络结构自动搜索
- 支持 FLOPS / 硬件延时约束 - 支持 FLOPS / 硬件延时约束
- 支持多平台模型延时评估 - 支持多平台模型延时评估
- 支持用户自定义搜索算法和搜索空间 - 支持用户自定义搜索算法和搜索空间
......
...@@ -29,6 +29,14 @@ logger = get_logger(__name__, level=logging.INFO) ...@@ -29,6 +29,14 @@ logger = get_logger(__name__, level=logging.INFO)
def count_parameters_in_MB(all_params): def count_parameters_in_MB(all_params):
"""Count the parameters in the target list.
Args:
all_params(list): List of Variables.
Returns:
float: The total count(MB) of target parameter list.
"""
parameters_number = 0 parameters_number = 0
for param in all_params: for param in all_params:
if param.trainable and 'aux' not in param.name: if param.trainable and 'aux' not in param.name:
...@@ -37,6 +45,24 @@ def count_parameters_in_MB(all_params): ...@@ -37,6 +45,24 @@ def count_parameters_in_MB(all_params):
class DARTSearch(object): class DARTSearch(object):
"""Used for Differentiable ARchiTecture Search(DARTS)
Args:
model(Paddle DyGraph model): Super Network for Search.
train_reader(Python Generator): Generator to provide training data.
valid_reader(Python Generator): Generator to provide validation data.
place(fluid.CPUPlace()|fluid.CUDAPlace(N)): This parameter represents the executor run on which device.
learning_rate(float): Model parameter initial learning rate. Default: 0.025.
batch_size(int): Minibatch size. Default: 64.
arch_learning_rate(float): Learning rate for arch encoding. Default: 3e-4.
unrolled(bool): Use one-step unrolled validation loss. Default: False.
num_epochs(int): Epoch number. Default: 50.
epochs_no_archopt(int): Epochs skip architecture optimize at begining. Default: 0.
use_data_parallel(bool): Whether to use data parallel mode. Default: False.
log_freq(int): Log frequency. Default: 50.
"""
def __init__(self, def __init__(self,
model, model,
train_reader, train_reader,
...@@ -149,6 +175,10 @@ class DARTSearch(object): ...@@ -149,6 +175,10 @@ class DARTSearch(object):
return top1.avg[0] return top1.avg[0]
def train(self): def train(self):
"""Start search process.
"""
if self.use_data_parallel: if self.use_data_parallel:
strategy = fluid.dygraph.parallel.prepare_context() strategy = fluid.dygraph.parallel.prepare_context()
model_parameters = [ model_parameters = [
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册