index.html 21.4 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16
<!DOCTYPE html>
<!--[if IE 8]><html class="no-js lt-ie9" lang="en" > <![endif]-->
<!--[if gt IE 8]><!--> <html class="no-js" lang="en" > <!--<![endif]-->
<head>
  <meta charset="utf-8">
  <meta http-equiv="X-UA-Compatible" content="IE=edge">
  <meta name="viewport" content="width=device-width, initial-scale=1.0">
  
  
  <link rel="shortcut icon" href="../../img/favicon.ico">
  <title>SA搜索 - PaddleSlim Docs</title>
  <link href='https://fonts.googleapis.com/css?family=Lato:400,700|Roboto+Slab:400,700|Inconsolata:400,700' rel='stylesheet' type='text/css'>

  <link rel="stylesheet" href="../../css/theme.css" type="text/css" />
  <link rel="stylesheet" href="../../css/theme_extra.css" type="text/css" />
  <link rel="stylesheet" href="//cdnjs.cloudflare.com/ajax/libs/highlight.js/9.12.0/styles/github.min.css">
17
  <link href="../../extra.css" rel="stylesheet">
18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58
  
  <script>
    // Current page data
    var mkdocs_page_name = "SA\u641c\u7d22";
    var mkdocs_page_input_path = "api/nas_api.md";
    var mkdocs_page_url = null;
  </script>
  
  <script src="../../js/jquery-2.1.1.min.js" defer></script>
  <script src="../../js/modernizr-2.8.3.min.js" defer></script>
  <script src="//cdnjs.cloudflare.com/ajax/libs/highlight.js/9.12.0/highlight.min.js"></script>
  <script>hljs.initHighlightingOnLoad();</script> 
  
</head>

<body class="wy-body-for-nav" role="document">

  <div class="wy-grid-for-nav">

    
    <nav data-toggle="wy-nav-shift" class="wy-nav-side stickynav">
      <div class="wy-side-nav-search">
        <a href="../.." class="icon icon-home"> PaddleSlim Docs</a>
        <div role="search">
  <form id ="rtd-search-form" class="wy-form" action="../../search.html" method="get">
    <input type="text" name="q" placeholder="Search docs" title="Type search term here" />
  </form>
</div>
      </div>

      <div class="wy-menu wy-menu-vertical" data-spy="affix" role="navigation" aria-label="main navigation">
	<ul class="current">
	  
          
            <li class="toctree-l1">
		
    <a class="" href="../..">Home</a>
	    </li>
          
            <li class="toctree-l1">
		
59 60 61 62 63
    <a class="" href="../../model_zoo/">模型库</a>
	    </li>
          
            <li class="toctree-l1">
		
64 65 66 67
    <span class="caption-text">教程</span>
    <ul class="subnav">
                <li class="">
                    
68 69 70 71
    <a class="" href="../../tutorials/pruning_tutorial/">图像分类模型通道剪裁-快速开始</a>
                </li>
                <li class="">
                    
72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87
    <a class="" href="../../tutorials/quant_post_demo/">离线量化</a>
                </li>
                <li class="">
                    
    <a class="" href="../../tutorials/quant_aware_demo/">量化训练</a>
                </li>
                <li class="">
                    
    <a class="" href="../../tutorials/quant_embedding_demo/">Embedding量化</a>
                </li>
                <li class="">
                    
    <a class="" href="../../tutorials/nas_demo/">SA搜索</a>
                </li>
                <li class="">
                    
88
    <a class="" href="../../search_space/">搜索空间</a>
89 90 91
                </li>
                <li class="">
                    
92 93 94 95 96 97 98
    <a class="" href="../../tutorials/distillation_demo/">知识蒸馏</a>
                </li>
    </ul>
	    </li>
          
            <li class="toctree-l1">
		
99 100 101 102 103 104 105 106
    <span class="caption-text">API</span>
    <ul class="subnav">
                <li class="">
                    
    <a class="" href="../quantization_api/">量化</a>
                </li>
                <li class="">
                    
107
    <a class="" href="../prune_api/">剪枝与敏感度</a>
108 109 110
                </li>
                <li class="">
                    
111
    <a class="" href="../analysis_api/">模型分析</a>
112 113 114
                </li>
                <li class="">
                    
115
    <a class="" href="../single_distiller_api/">知识蒸馏</a>
116 117 118 119 120 121
                </li>
                <li class=" current">
                    
    <a class="current" href="./">SA搜索</a>
    <ul class="subnav">
            
122
    <li class="toctree-l3"><a href="#_1">搜索空间参数的配置</a></li>
123
    
124 125

    <li class="toctree-l3"><a href="#sanas">SANAS</a></li>
126 127 128 129 130 131
    

    </ul>
                </li>
                <li class="">
                    
132 133
    <a class="" href="../../table_latency/">硬件延时评估表</a>
                </li>
134 135 136
    </ul>
	    </li>
          
137 138 139 140 141
            <li class="toctree-l1">
		
    <a class="" href="../../algo/algo/">算法原理</a>
	    </li>
          
142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170
        </ul>
      </div>
      &nbsp;
    </nav>

    <section data-toggle="wy-nav-shift" class="wy-nav-content-wrap">

      
      <nav class="wy-nav-top" role="navigation" aria-label="top navigation">
        <i data-toggle="wy-nav-top" class="fa fa-bars"></i>
        <a href="../..">PaddleSlim Docs</a>
      </nav>

      
      <div class="wy-nav-content">
        <div class="rst-content">
          <div role="navigation" aria-label="breadcrumbs navigation">
  <ul class="wy-breadcrumbs">
    <li><a href="../..">Docs</a> &raquo;</li>
    
      
        
          <li>API &raquo;</li>
        
      
    
    <li>SA搜索</li>
    <li class="wy-breadcrumbs-aside">
      
171
        <a href="https://github.com/PaddlePaddle/PaddleSlim/edit/master/docs/api/nas_api.md"
172 173 174 175 176 177 178 179 180
          class="icon icon-github"> Edit on GitHub</a>
      
    </li>
  </ul>
  <hr/>
</div>
          <div role="main">
            <div class="section">
              
181 182 183 184
                <h2 id="_1">搜索空间参数的配置<a class="headerlink" href="#_1" title="Permanent link">#</a></h2>
<p>通过参数配置搜索空间。更多搜索空间的使用可以参考<a href="../../search_space/">search_space</a></p>
<p><strong>参数:</strong></p>
<ul>
185 186
<li><strong>input_size(int|None)</strong>:- <code>input_size</code>表示输入feature map的大小。<code>input_size</code><code>output_size</code>用来计算整个模型结构中下采样次数。</li>
<li><strong>output_size(int|None)</strong>:- <code>output_size</code>表示输出feature map的大小。<code>input_size</code><code>output_size</code>用来计算整个模型结构中下采样次数。</li>
187
<li><strong>block_num(int|None)</strong>:- <code>block_num</code>表示搜索空间中block的数量。</li>
188
<li><strong>block_mask(list|None)</strong>:- <code>block_mask</code>是一组由0、1组成的列表,0表示当前block是normal block,1表示当前block是reduction block。reduction block表示经过这个block之后的feature map大小下降为之前的一半,normal block表示经过这个block之后feature map大小不变。如果设置了<code>block_mask</code>,则主要以<code>block_mask</code>为主要配置,<code>input_size</code><code>output_size</code><code>block_num</code>三种配置是无效的。</li>
189 190 191
</ul>
<h2 id="sanas">SANAS<a class="headerlink" href="#sanas" title="Permanent link">#</a></h2>
<dl>
192
<dt>paddleslim.nas.SANAS(configs, server_addr=("", 8881), init_temperature=None, reduce_rate=0.85, init_tokens=None, search_steps=300, save_checkpoint='./nas_checkpoint', load_checkpoint=None, is_server=True)<a href="https://github.com/PaddlePaddle/PaddleSlim/blob/develop/paddleslim/nas/sa_nas.py#L36">源代码</a></dt>
193 194 195 196
<dd>SANAS(Simulated Annealing Neural Architecture Search)是基于模拟退火算法进行模型结构搜索的算法,一般用于离散搜索任务。</dd>
</dl>
<p><strong>参数:</strong></p>
<ul>
197
<li><strong>configs(list<tuple>)</strong> - 搜索空间配置列表,格式是<code>[(key, {input_size, output_size, block_num, block_mask})]</code>或者<code>[(key)]</code>(MobileNetV2、MobilenetV1和ResNet的搜索空间使用和原本网络结构相同的搜索空间,所以仅需指定<code>key</code>即可), <code>input_size</code><code>output_size</code>表示输入和输出的特征图的大小,<code>block_num</code>是指搜索网络中的block数量,<code>block_mask</code>是一组由0和1组成的列表,0代表不进行下采样的block,1代表下采样的block。 更多paddleslim提供的搜索空间配置可以参考<a href="../../search_space/">Search Space</a></li>
198
<li><strong>server_addr(tuple)</strong> - SANAS的地址,包括server的ip地址和端口号,如果ip地址为None或者为""的话则默认使用本机ip。默认:("", 8881)。</li>
199 200 201
<li><strong>init_temperature(float)</strong> - 基于模拟退火进行搜索的初始温度。如果init_template为None而且init_tokens为None,则默认初始温度为10.0,如果init_template为None且init_tokens不为None,则默认初始温度为1.0。详细的温度设置可以参考下面的Note。默认:None。</li>
<li><strong>reduce_rate(float)</strong> - 基于模拟退火进行搜索的衰减率。详细的退火率设置可以参考下面的Note。默认:0.85。</li>
<li><strong>init_tokens(list|None)</strong> - 初始化token,若init_tokens为空,则SA算法随机生成初始化tokens。默认:None。</li>
202 203 204 205 206 207
<li><strong>search_steps(int)</strong> - 搜索过程迭代的次数。默认:300。</li>
<li><strong>save_checkpoint(str|None)</strong> - 保存checkpoint的文件目录,如果设置为None的话则不保存checkpoint。默认:<code>./nas_checkpoint</code></li>
<li><strong>load_checkpoint(str|None)</strong> - 加载checkpoint的文件目录,如果设置为None的话则不加载checkpoint。默认:None。</li>
<li><strong>is_server(bool)</strong> - 当前实例是否要启动一个server。默认:True。</li>
</ul>
<p><strong>返回:</strong>
208
一个SANAS类的实例</p>
209
<p><strong>示例代码:</strong>
210
<div class="codehilite"><pre><span></span><span class="kn">from</span> <span class="nn">paddleslim.nas</span> <span class="kn">import</span> <span class="n">SANAS</span>
211
<span class="n">config</span> <span class="o">=</span> <span class="p">[(</span><span class="s1">&#39;MobileNetV2Space&#39;</span><span class="p">)]</span>
212
<span class="n">sanas</span> <span class="o">=</span> <span class="n">SANAS</span><span class="p">(</span><span class="n">configs</span><span class="o">=</span><span class="n">config</span><span class="p">)</span>
213
</pre></div></p>
214 215 216 217 218 219
<div class="admonition note">
<p class="admonition-title">Note</p>
</div>
<ul>
<li>
<p>初始化温度和退火率的意义: <br></p>
220
<ul>
221 222 223 224 225 226 227 228 229 230 231 232 233 234
<li>SA算法内部会保存一个基础token(初始化token可以自己传入也可以随机生成)和基础score(初始化score为-1),下一个token会在当前SA算法保存的token的基础上产生。在SA的搜索过程中,如果本轮的token训练得到的score大于SA算法中保存的score,则本轮的token一定会被SA算法接收保存为下一轮token产生的基础token。<br></li>
<li>初始温度越高表示SA算法当前处的阶段越不稳定,本轮的token训练得到的score小于SA算法中保存的score的话,本轮的token和score被SA算法接收的可能性越大。<br></li>
<li>初始温度越低表示SA算法当前处的阶段越稳定,本轮的token训练得到的score小于SA算法中保存的score的话,本轮的token和score被SA算法接收的可能性越小。<br></li>
<li>退火率越大,表示SA算法收敛的越慢,即SA算法越慢到稳定阶段。<br></li>
<li>退火率越低,表示SA算法收敛的越快,即SA算法越快到稳定阶段。<br></li>
</ul>
</li>
<li>
<p>初始化温度和退火率的设置: <br></p>
<ul>
<li>如果原本就有一个较好的初始化token,想要基于这个较好的token来进行搜索的话,SA算法可以处于一个较为稳定的状态进行搜索r这种情况下初始温度可以设置的低一些,例如设置为1.0,退火率设置的大一些,例如设置为0.85。如果想要基于这个较好的token利用贪心算法进行搜索,即只有当本轮token训练得到的score大于SA算法中保存的score,SA算法才接收本轮token,则退火率可设置为一个极小的数字,例如设置为0.85 ** 10。<br></li>
<li>初始化token如果是随机生成的话,代表初始化token是一个比较差的token,SA算法可以处于一种不稳定的阶段进行搜索,尽可能的随机探索所有可能得token,从而找到一个较好的token。初始温度可以设置的高一些,例如设置为1000,退火率相对设置的小一些。</li>
</ul>
</li>
235 236 237 238 239 240
</ul>
<dl>
<dt>paddleslim.nas.SANAS.next_archs()</dt>
<dd>获取下一组模型结构。</dd>
</dl>
<p><strong>返回:</strong>
241
返回模型结构实例的列表,形式为list。</p>
242
<p><strong>示例代码:</strong>
243
<div class="codehilite"><pre><span></span><span class="kn">import</span> <span class="nn">paddle.fluid</span> <span class="k">as</span> <span class="nn">fluid</span>
244 245 246
<span class="kn">from</span> <span class="nn">paddleslim.nas</span> <span class="kn">import</span> <span class="n">SANAS</span>
<span class="n">config</span> <span class="o">=</span> <span class="p">[(</span><span class="s1">&#39;MobileNetV2Space&#39;</span><span class="p">)]</span>
<span class="n">sanas</span> <span class="o">=</span> <span class="n">SANAS</span><span class="p">(</span><span class="n">configs</span><span class="o">=</span><span class="n">config</span><span class="p">)</span>
247
<span class="nb">input</span> <span class="o">=</span> <span class="n">fluid</span><span class="o">.</span><span class="n">data</span><span class="p">(</span><span class="n">name</span><span class="o">=</span><span class="s1">&#39;input&#39;</span><span class="p">,</span> <span class="n">shape</span><span class="o">=</span><span class="p">[</span><span class="kc">None</span><span class="p">,</span> <span class="mi">3</span><span class="p">,</span> <span class="mi">32</span><span class="p">,</span> <span class="mi">32</span><span class="p">],</span> <span class="n">dtype</span><span class="o">=</span><span class="s1">&#39;float32&#39;</span><span class="p">)</span>
248 249 250 251
<span class="n">archs</span> <span class="o">=</span> <span class="n">sanas</span><span class="o">.</span><span class="n">next_archs</span><span class="p">()</span>
<span class="k">for</span> <span class="n">arch</span> <span class="ow">in</span> <span class="n">archs</span><span class="p">:</span>
    <span class="n">output</span> <span class="o">=</span> <span class="n">arch</span><span class="p">(</span><span class="nb">input</span><span class="p">)</span>
    <span class="nb">input</span> <span class="o">=</span> <span class="n">output</span>
252
<span class="nb">print</span><span class="p">(</span><span class="n">output</span><span class="p">)</span>
253
</pre></div></p>
254 255 256 257 258 259 260 261 262
<dl>
<dt>paddleslim.nas.SANAS.reward(score)</dt>
<dd>把当前模型结构的得分情况回传。</dd>
</dl>
<p><strong>参数:</strong></p>
<ul>
<li><strong>score<float>:</strong> - 当前模型的得分,分数越大越好。</li>
</ul>
<p><strong>返回:</strong>
263
模型结构更新成功或者失败,成功则返回<code>True</code>,失败则返回<code>False</code></p>
264
<p><strong>示例代码:</strong>
265
<div class="codehilite"><pre><span></span><span class="kn">import</span> <span class="nn">paddle.fluid</span> <span class="k">as</span> <span class="nn">fluid</span>
266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285
<span class="kn">from</span> <span class="nn">paddleslim.nas</span> <span class="kn">import</span> <span class="n">SANAS</span>
<span class="n">config</span> <span class="o">=</span> <span class="p">[(</span><span class="s1">&#39;MobileNetV2Space&#39;</span><span class="p">)]</span>
<span class="n">sanas</span> <span class="o">=</span> <span class="n">SANAS</span><span class="p">(</span><span class="n">configs</span><span class="o">=</span><span class="n">config</span><span class="p">)</span>
<span class="n">archs</span> <span class="o">=</span> <span class="n">sanas</span><span class="o">.</span><span class="n">next_archs</span><span class="p">()</span>

<span class="c1">### 假设网络计算出来的score是1,实际代码中使用时需要返回真实score。</span>
<span class="n">score</span><span class="o">=</span><span class="nb">float</span><span class="p">(</span><span class="mf">1.0</span><span class="p">)</span>
<span class="n">sanas</span><span class="o">.</span><span class="n">reward</span><span class="p">(</span><span class="nb">float</span><span class="p">(</span><span class="n">score</span><span class="p">))</span>
</pre></div></p>
<dl>
<dt>paddlesim.nas.SANAS.tokens2arch(tokens)</dt>
<dd>通过一组tokens得到实际的模型结构,一般用来把搜索到最优的token转换为模型结构用来做最后的训练。tokens的形式是一个列表,tokens映射到搜索空间转换成相应的网络结构,一组tokens对应唯一的一个网络结构。</dd>
</dl>
<p><strong>参数:</strong></p>
<ul>
<li><strong>tokens(list):</strong> - 一组tokens。tokens的长度和范取决于搜索空间。</li>
</ul>
<p><strong>返回:</strong>
根据传入的token得到一个模型结构实例。</p>
<p><strong>示例代码:</strong>
286
<div class="codehilite"><pre><span></span><span class="kn">import</span> <span class="nn">paddle.fluid</span> <span class="k">as</span> <span class="nn">fluid</span>
287 288 289
<span class="kn">from</span> <span class="nn">paddleslim.nas</span> <span class="kn">import</span> <span class="n">SANAS</span>
<span class="n">config</span> <span class="o">=</span> <span class="p">[(</span><span class="s1">&#39;MobileNetV2Space&#39;</span><span class="p">)]</span>
<span class="n">sanas</span> <span class="o">=</span> <span class="n">SANAS</span><span class="p">(</span><span class="n">configs</span><span class="o">=</span><span class="n">config</span><span class="p">)</span>
290
<span class="nb">input</span> <span class="o">=</span> <span class="n">fluid</span><span class="o">.</span><span class="n">data</span><span class="p">(</span><span class="n">name</span><span class="o">=</span><span class="s1">&#39;input&#39;</span><span class="p">,</span> <span class="n">shape</span><span class="o">=</span><span class="p">[</span><span class="kc">None</span><span class="p">,</span> <span class="mi">3</span><span class="p">,</span> <span class="mi">32</span><span class="p">,</span> <span class="mi">32</span><span class="p">],</span> <span class="n">dtype</span><span class="o">=</span><span class="s1">&#39;float32&#39;</span><span class="p">)</span>
291 292
<span class="n">tokens</span> <span class="o">=</span> <span class="p">([</span><span class="mi">0</span><span class="p">]</span> <span class="o">*</span> <span class="mi">25</span><span class="p">)</span>
<span class="n">archs</span> <span class="o">=</span> <span class="n">sanas</span><span class="o">.</span><span class="n">tokens2arch</span><span class="p">(</span><span class="n">tokens</span><span class="p">)[</span><span class="mi">0</span><span class="p">]</span>
293
<span class="nb">print</span><span class="p">(</span><span class="n">archs</span><span class="p">(</span><span class="nb">input</span><span class="p">))</span>
294
</pre></div></p>
295 296 297 298 299 300
<dl>
<dt>paddleslim.nas.SANAS.current_info()</dt>
<dd>返回当前token和搜索过程中最好的token和reward。</dd>
</dl>
<p><strong>返回:</strong>
搜索过程中最好的token,reward和当前训练的token,形式为dict。</p>
301
<p><strong>示例代码:</strong>
302
<div class="codehilite"><pre><span></span><span class="kn">import</span> <span class="nn">paddle.fluid</span> <span class="k">as</span> <span class="nn">fluid</span>
303 304 305
<span class="kn">from</span> <span class="nn">paddleslim.nas</span> <span class="kn">import</span> <span class="n">SANAS</span>
<span class="n">config</span> <span class="o">=</span> <span class="p">[(</span><span class="s1">&#39;MobileNetV2Space&#39;</span><span class="p">)]</span>
<span class="n">sanas</span> <span class="o">=</span> <span class="n">SANAS</span><span class="p">(</span><span class="n">configs</span><span class="o">=</span><span class="n">config</span><span class="p">)</span>
306
<span class="nb">print</span><span class="p">(</span><span class="n">sanas</span><span class="o">.</span><span class="n">current_info</span><span class="p">())</span>
307
</pre></div></p>
308 309 310 311 312 313 314
              
            </div>
          </div>
          <footer>
  
    <div class="rst-footer-buttons" role="navigation" aria-label="footer navigation">
      
315
        <a href="../../table_latency/" class="btn btn-neutral float-right" title="硬件延时评估表">Next <span class="icon icon-circle-arrow-right"></span></a>
316 317
      
      
318
        <a href="../single_distiller_api/" class="btn btn-neutral" title="知识蒸馏"><span class="icon icon-circle-arrow-left"></span> Previous</a>
319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348
      
    </div>
  

  <hr/>

  <div role="contentinfo">
    <!-- Copyright etc -->
    
  </div>

  Built with <a href="http://www.mkdocs.org">MkDocs</a> using a <a href="https://github.com/snide/sphinx_rtd_theme">theme</a> provided by <a href="https://readthedocs.org">Read the Docs</a>.
</footer>
      
        </div>
      </div>

    </section>

  </div>

  <div class="rst-versions" role="note" style="cursor: pointer">
    <span class="rst-current-version" data-toggle="rst-current-version">
      
          <a href="https://github.com/PaddlePaddle/PaddleSlim/" class="fa fa-github" style="float: left; color: #fcfcfc"> GitHub</a>
      
      
        <span><a href="../single_distiller_api/" style="color: #fcfcfc;">&laquo; Previous</a></span>
      
      
349
        <span style="margin-left: 15px"><a href="../../table_latency/" style="color: #fcfcfc">Next &raquo;</a></span>
350 351 352 353 354 355
      
    </span>
</div>
    <script>var base_url = '../..';</script>
    <script src="../../js/theme.js" defer></script>
      <script src="../../mathjax-config.js" defer></script>
356
      <script src="https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.0/MathJax.js?config=TeX-AMS-MML_HTMLorMML" defer></script>
357 358 359 360
      <script src="../../search/main.js" defer></script>

</body>
</html>