提交 f5fac6bc 编写于 作者: B baiyfbupt

Deployed 405a4686 with MkDocs version: 1.0.4

上级 c8e905e6
......@@ -14,6 +14,7 @@
<link rel="stylesheet" href="/css/theme.css" type="text/css" />
<link rel="stylesheet" href="/css/theme_extra.css" type="text/css" />
<link rel="stylesheet" href="//cdnjs.cloudflare.com/ajax/libs/highlight.js/9.12.0/styles/github.min.css">
<link href="/extra.css" rel="stylesheet">
<script src="/js/jquery-2.1.1.min.js" defer></script>
<script src="/js/modernizr-2.8.3.min.js" defer></script>
......
......@@ -14,6 +14,7 @@
<link rel="stylesheet" href="../../css/theme.css" type="text/css" />
<link rel="stylesheet" href="../../css/theme_extra.css" type="text/css" />
<link rel="stylesheet" href="//cdnjs.cloudflare.com/ajax/libs/highlight.js/9.12.0/styles/github.min.css">
<link href="../../extra.css" rel="stylesheet">
<script>
// Current page data
......
此差异已折叠。
......@@ -14,6 +14,7 @@
<link rel="stylesheet" href="../../css/theme.css" type="text/css" />
<link rel="stylesheet" href="../../css/theme_extra.css" type="text/css" />
<link rel="stylesheet" href="//cdnjs.cloudflare.com/ajax/libs/highlight.js/9.12.0/styles/github.min.css">
<link href="../../extra.css" rel="stylesheet">
<script>
// Current page data
......
......@@ -14,6 +14,7 @@
<link rel="stylesheet" href="../../css/theme.css" type="text/css" />
<link rel="stylesheet" href="../../css/theme_extra.css" type="text/css" />
<link rel="stylesheet" href="//cdnjs.cloudflare.com/ajax/libs/highlight.js/9.12.0/styles/github.min.css">
<link href="../../extra.css" rel="stylesheet">
<script>
// Current page data
......@@ -204,10 +205,13 @@
<p><strong>返回:</strong>
一个SANAS类的实例</p>
<p><strong>示例代码:</strong>
<div class="highlight"><pre><span></span>from paddleslim.nas import SANAS
config = [(&#39;MobileNetV2Space&#39;)]
sanas = SANAS(config=config)
</pre></div></p>
<table class="codehilitetable"><tr><td class="linenos"><div class="linenodiv"><pre><span></span>1
2
3</pre></div></td><td class="code"><div class="codehilite"><pre><span></span><span class="kn">from</span> <span class="nn">paddleslim.nas</span> <span class="kn">import</span> <span class="n">SANAS</span>
<span class="n">config</span> <span class="o">=</span> <span class="p">[(</span><span class="s1">&#39;MobileNetV2Space&#39;</span><span class="p">)]</span>
<span class="n">sanas</span> <span class="o">=</span> <span class="n">SANAS</span><span class="p">(</span><span class="n">config</span><span class="o">=</span><span class="n">config</span><span class="p">)</span>
</pre></div>
</td></tr></table></p>
<dl>
<dt>paddlesim.nas.SANAS.tokens2arch(tokens)</dt>
<dd>通过一组token得到实际的模型结构,一般用来把搜索到最优的token转换为模型结构用来做最后的训练。</dd>
......@@ -221,13 +225,19 @@ tokens是一个列表,token映射到搜索空间转换成相应的网络结构
<p><strong>返回:</strong>
根据传入的token得到一个模型结构实例。</p>
<p><strong>示例代码:</strong>
<div class="highlight"><pre><span></span>import paddle.fluid as fluid
input = fluid.data(name=&#39;input&#39;, shape=[None, 3, 32, 32], dtype=&#39;float32&#39;)
archs = sanas.token2arch(tokens)
for arch in archs:
output = arch(input)
input = output
</pre></div></p>
<table class="codehilitetable"><tr><td class="linenos"><div class="linenodiv"><pre><span></span>1
2
3
4
5
6</pre></div></td><td class="code"><div class="codehilite"><pre><span></span><span class="kn">import</span> <span class="nn">paddle.fluid</span> <span class="k">as</span> <span class="nn">fluid</span>
<span class="nb">input</span> <span class="o">=</span> <span class="n">fluid</span><span class="o">.</span><span class="n">data</span><span class="p">(</span><span class="n">name</span><span class="o">=</span><span class="s1">&#39;input&#39;</span><span class="p">,</span> <span class="n">shape</span><span class="o">=</span><span class="p">[</span><span class="kc">None</span><span class="p">,</span> <span class="mi">3</span><span class="p">,</span> <span class="mi">32</span><span class="p">,</span> <span class="mi">32</span><span class="p">],</span> <span class="n">dtype</span><span class="o">=</span><span class="s1">&#39;float32&#39;</span><span class="p">)</span>
<span class="n">archs</span> <span class="o">=</span> <span class="n">sanas</span><span class="o">.</span><span class="n">token2arch</span><span class="p">(</span><span class="n">tokens</span><span class="p">)</span>
<span class="k">for</span> <span class="n">arch</span> <span class="ow">in</span> <span class="n">archs</span><span class="p">:</span>
<span class="n">output</span> <span class="o">=</span> <span class="n">arch</span><span class="p">(</span><span class="nb">input</span><span class="p">)</span>
<span class="nb">input</span> <span class="o">=</span> <span class="n">output</span>
</pre></div>
</td></tr></table></p>
<dl>
<dt>paddleslim.nas.SANAS.next_archs()</dt>
<dd>获取下一组模型结构。</dd>
......@@ -235,13 +245,19 @@ for arch in archs:
<p><strong>返回:</strong>
返回模型结构实例的列表,形式为list。</p>
<p><strong>示例代码:</strong>
<div class="highlight"><pre><span></span>import paddle.fluid as fluid
input = fluid.data(name=&#39;input&#39;, shape=[None, 3, 32, 32], dtype=&#39;float32&#39;)
archs = sanas.next_archs()
for arch in archs:
output = arch(input)
input = output
</pre></div></p>
<table class="codehilitetable"><tr><td class="linenos"><div class="linenodiv"><pre><span></span>1
2
3
4
5
6</pre></div></td><td class="code"><div class="codehilite"><pre><span></span><span class="kn">import</span> <span class="nn">paddle.fluid</span> <span class="k">as</span> <span class="nn">fluid</span>
<span class="nb">input</span> <span class="o">=</span> <span class="n">fluid</span><span class="o">.</span><span class="n">data</span><span class="p">(</span><span class="n">name</span><span class="o">=</span><span class="s1">&#39;input&#39;</span><span class="p">,</span> <span class="n">shape</span><span class="o">=</span><span class="p">[</span><span class="kc">None</span><span class="p">,</span> <span class="mi">3</span><span class="p">,</span> <span class="mi">32</span><span class="p">,</span> <span class="mi">32</span><span class="p">],</span> <span class="n">dtype</span><span class="o">=</span><span class="s1">&#39;float32&#39;</span><span class="p">)</span>
<span class="n">archs</span> <span class="o">=</span> <span class="n">sanas</span><span class="o">.</span><span class="n">next_archs</span><span class="p">()</span>
<span class="k">for</span> <span class="n">arch</span> <span class="ow">in</span> <span class="n">archs</span><span class="p">:</span>
<span class="n">output</span> <span class="o">=</span> <span class="n">arch</span><span class="p">(</span><span class="nb">input</span><span class="p">)</span>
<span class="nb">input</span> <span class="o">=</span> <span class="n">output</span>
</pre></div>
</td></tr></table></p>
<dl>
<dt>paddleslim.nas.SANAS.reward(score)</dt>
<dd>把当前模型结构的得分情况回传。</dd>
......
此差异已折叠。
此差异已折叠。
此差异已折叠。
.codehilite {
width: 590px;
overflow: auto;
}
......@@ -14,6 +14,7 @@
<link rel="stylesheet" href="css/theme.css" type="text/css" />
<link rel="stylesheet" href="css/theme_extra.css" type="text/css" />
<link rel="stylesheet" href="//cdnjs.cloudflare.com/ajax/libs/highlight.js/9.12.0/styles/github.min.css">
<link href="extra.css" rel="stylesheet">
<script>
// Current page data
......@@ -211,16 +212,20 @@
<ul>
<li>安装develop版本</li>
</ul>
<div class="highlight"><pre><span></span>git clone https://github.com/PaddlePaddle/PaddleSlim.git
cd PaddleSlim
<table class="codehilitetable"><tr><td class="linenos"><div class="linenodiv"><pre><span></span>1
2
3</pre></div></td><td class="code"><div class="codehilite"><pre><span></span>git clone https://github.com/PaddlePaddle/PaddleSlim.git
<span class="nb">cd</span> PaddleSlim
python setup.py install
</pre></div>
</td></tr></table>
<ul>
<li>安装官方发布的最新版本</li>
</ul>
<div class="highlight"><pre><span></span>pip install paddleslim -i https://pypi.org/simple
<table class="codehilitetable"><tr><td class="linenos"><div class="linenodiv"><pre><span></span>1</pre></div></td><td class="code"><div class="codehilite"><pre><span></span>pip install paddleslim -i https://pypi.org/simple
</pre></div>
</td></tr></table>
<ul>
<li>安装历史版本</li>
......@@ -289,5 +294,5 @@ python setup.py install
<!--
MkDocs version : 1.0.4
Build Date UTC : 2020-01-16 06:38:06
Build Date UTC : 2020-01-17 02:21:20
-->
......@@ -14,6 +14,7 @@
<link rel="stylesheet" href="../css/theme.css" type="text/css" />
<link rel="stylesheet" href="../css/theme_extra.css" type="text/css" />
<link rel="stylesheet" href="//cdnjs.cloudflare.com/ajax/libs/highlight.js/9.12.0/styles/github.min.css">
<link href="../extra.css" rel="stylesheet">
<script>
// Current page data
......
......@@ -14,6 +14,7 @@
<link rel="stylesheet" href="./css/theme.css" type="text/css" />
<link rel="stylesheet" href="./css/theme_extra.css" type="text/css" />
<link rel="stylesheet" href="//cdnjs.cloudflare.com/ajax/libs/highlight.js/9.12.0/styles/github.min.css">
<link href="./extra.css" rel="stylesheet">
<script src="./js/jquery-2.1.1.min.js" defer></script>
<script src="./js/modernizr-2.8.3.min.js" defer></script>
......
此差异已折叠。
......@@ -14,6 +14,7 @@
<link rel="stylesheet" href="../css/theme.css" type="text/css" />
<link rel="stylesheet" href="../css/theme_extra.css" type="text/css" />
<link rel="stylesheet" href="//cdnjs.cloudflare.com/ajax/libs/highlight.js/9.12.0/styles/github.min.css">
<link href="../extra.css" rel="stylesheet">
<script>
// Current page data
......@@ -243,13 +244,74 @@
&emsp; 2. token中每个数字的搜索列表长度(<code>range_table</code>函数),tokens中每个token的索引范围。<br>
&emsp; 3. 根据token产生模型结构(<code>token2arch</code>函数),根据搜索到的tokens列表产生模型结构。 <br></p>
<p>以新增reset block为例说明如何构造自己的search space。自定义的search space不能和已有的search space同名。</p>
<div class="highlight"><pre><span></span><span class="c1">### 引入搜索空间基类函数和search space的注册类函数</span>
<table class="codehilitetable"><tr><td class="linenos"><div class="linenodiv"><pre><span></span> 1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62</pre></div></td><td class="code"><div class="codehilite"><pre><span></span><span class="c1">### 引入搜索空间基类函数和search space的注册类函数</span>
<span class="kn">from</span> <span class="nn">.search_space_base</span> <span class="kn">import</span> <span class="n">SearchSpaceBase</span>
<span class="kn">from</span> <span class="nn">.search_space_registry</span> <span class="kn">import</span> <span class="n">SEARCHSPACE</span>
<span class="kn">import</span> <span class="nn">numpy</span> <span class="kn">as</span> <span class="nn">np</span>
<span class="kn">import</span> <span class="nn">numpy</span> <span class="k">as</span> <span class="nn">np</span>
<span class="c1">### 需要调用注册函数把自定义搜索空间注册到space space中</span>
<span class="nd">@SEARCHSPACE.register</span>
<span class="nd">@SEARCHSPACE</span><span class="o">.</span><span class="n">register</span>
<span class="c1">### 定义一个继承SearchSpaceBase基类的搜索空间的类函数</span>
<span class="k">class</span> <span class="nc">ResNetBlockSpace2</span><span class="p">(</span><span class="n">SearchSpaceBase</span><span class="p">):</span>
<span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">input_size</span><span class="p">,</span> <span class="n">output_size</span><span class="p">,</span> <span class="n">block_num</span><span class="p">,</span> <span class="n">block_mask</span><span class="p">):</span>
......@@ -266,8 +328,8 @@
<span class="k">return</span> <span class="p">[</span><span class="nb">len</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">filter_num</span><span class="p">)]</span> <span class="o">*</span> <span class="mi">3</span> <span class="o">*</span> <span class="nb">len</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">block_mask</span><span class="p">)</span>
<span class="c1">### 把token转换成模型结构</span>
<span class="k">def</span> <span class="nf">token2arch</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">tokens</span><span class="o">=</span><span class="bp">None</span><span class="p">):</span>
<span class="k">if</span> <span class="n">tokens</span> <span class="o">==</span> <span class="bp">None</span><span class="p">:</span>
<span class="k">def</span> <span class="nf">token2arch</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">tokens</span><span class="o">=</span><span class="kc">None</span><span class="p">):</span>
<span class="k">if</span> <span class="n">tokens</span> <span class="o">==</span> <span class="kc">None</span><span class="p">:</span>
<span class="n">tokens</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">init_tokens</span><span class="p">()</span>
<span class="bp">self</span><span class="o">.</span><span class="n">bottleneck_params_list</span> <span class="o">=</span> <span class="p">[]</span>
......@@ -280,32 +342,33 @@
<span class="k">def</span> <span class="nf">net_arch</span><span class="p">(</span><span class="nb">input</span><span class="p">):</span>
<span class="k">for</span> <span class="n">i</span><span class="p">,</span> <span class="n">layer_setting</span> <span class="ow">in</span> <span class="nb">enumerate</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">bottleneck_params_list</span><span class="p">):</span>
<span class="n">channel_num</span><span class="p">,</span> <span class="n">stride</span> <span class="o">=</span> <span class="n">layer_setting</span><span class="p">[:</span><span class="o">-</span><span class="mi">1</span><span class="p">],</span> <span class="n">layer_setting</span><span class="p">[</span><span class="o">-</span><span class="mi">1</span><span class="p">]</span>
<span class="nb">input</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_resnet_block</span><span class="p">(</span><span class="nb">input</span><span class="p">,</span> <span class="n">channel_num</span><span class="p">,</span> <span class="n">stride</span><span class="p">,</span> <span class="n">name</span><span class="o">=</span><span class="s1">&#39;resnet_layer{}&#39;</span><span class="o">.</span><span class="n">format</span><span class="p">(</span><span class="n">i</span><span class="o">+</span><span class="mi">1</span><span class="p">))</span>
<span class="nb">input</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_resnet_block</span><span class="p">(</span><span class="nb">input</span><span class="p">,</span> <span class="n">channel_num</span><span class="p">,</span> <span class="n">stride</span><span class="p">,</span> <span class="n">name</span><span class="o">=</span><span class="s1">&#39;resnet_layer</span><span class="si">{}</span><span class="s1">&#39;</span><span class="o">.</span><span class="n">format</span><span class="p">(</span><span class="n">i</span><span class="o">+</span><span class="mi">1</span><span class="p">))</span>
<span class="k">return</span> <span class="nb">input</span>
<span class="k">return</span> <span class="n">net_arch</span>
<span class="c1">### 构造具体block的操作</span>
<span class="k">def</span> <span class="nf">_resnet_block</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="nb">input</span><span class="p">,</span> <span class="n">channel_num</span><span class="p">,</span> <span class="n">stride</span><span class="p">,</span> <span class="n">name</span><span class="o">=</span><span class="bp">None</span><span class="p">):</span>
<span class="k">def</span> <span class="nf">_resnet_block</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="nb">input</span><span class="p">,</span> <span class="n">channel_num</span><span class="p">,</span> <span class="n">stride</span><span class="p">,</span> <span class="n">name</span><span class="o">=</span><span class="kc">None</span><span class="p">):</span>
<span class="n">shortcut_conv</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_shortcut</span><span class="p">(</span><span class="nb">input</span><span class="p">,</span> <span class="n">channel_num</span><span class="p">[</span><span class="mi">2</span><span class="p">],</span> <span class="n">stride</span><span class="p">,</span> <span class="n">name</span><span class="o">=</span><span class="n">name</span><span class="p">)</span>
<span class="nb">input</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_conv_bn_layer</span><span class="p">(</span><span class="nb">input</span><span class="o">=</span><span class="nb">input</span><span class="p">,</span> <span class="n">num_filters</span><span class="o">=</span><span class="n">channel_num</span><span class="p">[</span><span class="mi">0</span><span class="p">],</span> <span class="n">filter_size</span><span class="o">=</span><span class="mi">1</span><span class="p">,</span> <span class="n">act</span><span class="o">=</span><span class="s1">&#39;relu&#39;</span><span class="p">,</span> <span class="n">name</span><span class="o">=</span><span class="n">name</span> <span class="o">+</span> <span class="s1">&#39;_conv0&#39;</span><span class="p">)</span>
<span class="nb">input</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_conv_bn_layer</span><span class="p">(</span><span class="nb">input</span><span class="o">=</span><span class="nb">input</span><span class="p">,</span> <span class="n">num_filters</span><span class="o">=</span><span class="n">channel_num</span><span class="p">[</span><span class="mi">1</span><span class="p">],</span> <span class="n">filter_size</span><span class="o">=</span><span class="mi">3</span><span class="p">,</span> <span class="n">stride</span><span class="o">=</span><span class="n">stride</span><span class="p">,</span> <span class="n">act</span><span class="o">=</span><span class="s1">&#39;relu&#39;</span><span class="p">,</span> <span class="n">name</span><span class="o">=</span><span class="n">name</span> <span class="o">+</span> <span class="s1">&#39;_conv1&#39;</span><span class="p">)</span>
<span class="nb">input</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_conv_bn_layer</span><span class="p">(</span><span class="nb">input</span><span class="o">=</span><span class="nb">input</span><span class="p">,</span> <span class="n">num_filters</span><span class="o">=</span><span class="n">channel_num</span><span class="p">[</span><span class="mi">2</span><span class="p">],</span> <span class="n">filter_size</span><span class="o">=</span><span class="mi">1</span><span class="p">,</span> <span class="n">name</span><span class="o">=</span><span class="n">name</span> <span class="o">+</span> <span class="s1">&#39;_conv2&#39;</span><span class="p">)</span>
<span class="k">return</span> <span class="n">fluid</span><span class="o">.</span><span class="n">layers</span><span class="o">.</span><span class="n">elementwise_add</span><span class="p">(</span><span class="n">x</span><span class="o">=</span><span class="n">shortcut_conv</span><span class="p">,</span> <span class="n">y</span><span class="o">=</span><span class="nb">input</span><span class="p">,</span> <span class="n">axis</span><span class="o">=</span><span class="mi">0</span><span class="p">,</span> <span class="n">name</span><span class="o">=</span><span class="n">name</span><span class="o">+</span><span class="s1">&#39;_elementwise_add&#39;</span><span class="p">)</span>
<span class="k">def</span> <span class="nf">_shortcut</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="nb">input</span><span class="p">,</span> <span class="n">channel_num</span><span class="p">,</span> <span class="n">stride</span><span class="p">,</span> <span class="n">name</span><span class="o">=</span><span class="bp">None</span><span class="p">):</span>
<span class="k">def</span> <span class="nf">_shortcut</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="nb">input</span><span class="p">,</span> <span class="n">channel_num</span><span class="p">,</span> <span class="n">stride</span><span class="p">,</span> <span class="n">name</span><span class="o">=</span><span class="kc">None</span><span class="p">):</span>
<span class="n">channel_in</span> <span class="o">=</span> <span class="nb">input</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span>
<span class="k">if</span> <span class="n">channel_in</span> <span class="o">!=</span> <span class="n">channel_num</span> <span class="ow">or</span> <span class="n">stride</span> <span class="o">!=</span> <span class="mi">1</span><span class="p">:</span>
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">conv_bn_layer</span><span class="p">(</span><span class="nb">input</span><span class="p">,</span> <span class="n">num_filters</span><span class="o">=</span><span class="n">channel_num</span><span class="p">,</span> <span class="n">filter_size</span><span class="o">=</span><span class="mi">1</span><span class="p">,</span> <span class="n">stride</span><span class="o">=</span><span class="n">stride</span><span class="p">,</span> <span class="n">name</span><span class="o">=</span><span class="n">name</span><span class="o">+</span><span class="s1">&#39;_shortcut&#39;</span><span class="p">)</span>
<span class="k">else</span><span class="p">:</span>
<span class="k">return</span> <span class="nb">input</span>
<span class="k">def</span> <span class="nf">_conv_bn_layer</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="nb">input</span><span class="p">,</span> <span class="n">num_filters</span><span class="p">,</span> <span class="n">filter_size</span><span class="p">,</span> <span class="n">stride</span><span class="o">=</span><span class="mi">1</span><span class="p">,</span> <span class="n">padding</span><span class="o">=</span><span class="s1">&#39;SAME&#39;</span><span class="p">,</span> <span class="n">act</span><span class="o">=</span><span class="bp">None</span><span class="p">,</span> <span class="n">name</span><span class="o">=</span><span class="bp">None</span><span class="p">):</span>
<span class="k">def</span> <span class="nf">_conv_bn_layer</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="nb">input</span><span class="p">,</span> <span class="n">num_filters</span><span class="p">,</span> <span class="n">filter_size</span><span class="p">,</span> <span class="n">stride</span><span class="o">=</span><span class="mi">1</span><span class="p">,</span> <span class="n">padding</span><span class="o">=</span><span class="s1">&#39;SAME&#39;</span><span class="p">,</span> <span class="n">act</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="n">name</span><span class="o">=</span><span class="kc">None</span><span class="p">):</span>
<span class="n">conv</span> <span class="o">=</span> <span class="n">fluid</span><span class="o">.</span><span class="n">layers</span><span class="o">.</span><span class="n">conv2d</span><span class="p">(</span><span class="nb">input</span><span class="p">,</span> <span class="n">num_filters</span><span class="p">,</span> <span class="n">filter_size</span><span class="p">,</span> <span class="n">stride</span><span class="p">,</span> <span class="n">name</span><span class="o">=</span><span class="n">name</span><span class="o">+</span><span class="s1">&#39;_conv&#39;</span><span class="p">)</span>
<span class="n">bn</span> <span class="o">=</span> <span class="n">fluid</span><span class="o">.</span><span class="n">layers</span><span class="o">.</span><span class="n">batch_norm</span><span class="p">(</span><span class="n">conv</span><span class="p">,</span> <span class="n">act</span><span class="o">=</span><span class="n">act</span><span class="p">,</span> <span class="n">name</span><span class="o">=</span><span class="n">name</span><span class="o">+</span><span class="s1">&#39;_bn&#39;</span><span class="p">)</span>
<span class="k">return</span> <span class="n">bn</span>
</pre></div>
</td></tr></table>
</div>
</div>
......
......@@ -2,77 +2,77 @@
<urlset xmlns="http://www.sitemaps.org/schemas/sitemap/0.9">
<url>
<loc>None</loc>
<lastmod>2020-01-16</lastmod>
<lastmod>2020-01-17</lastmod>
<changefreq>daily</changefreq>
</url>
<url>
<loc>None</loc>
<lastmod>2020-01-16</lastmod>
<lastmod>2020-01-17</lastmod>
<changefreq>daily</changefreq>
</url>
<url>
<loc>None</loc>
<lastmod>2020-01-16</lastmod>
<lastmod>2020-01-17</lastmod>
<changefreq>daily</changefreq>
</url>
<url>
<loc>None</loc>
<lastmod>2020-01-16</lastmod>
<lastmod>2020-01-17</lastmod>
<changefreq>daily</changefreq>
</url>
<url>
<loc>None</loc>
<lastmod>2020-01-16</lastmod>
<lastmod>2020-01-17</lastmod>
<changefreq>daily</changefreq>
</url>
<url>
<loc>None</loc>
<lastmod>2020-01-16</lastmod>
<lastmod>2020-01-17</lastmod>
<changefreq>daily</changefreq>
</url>
<url>
<loc>None</loc>
<lastmod>2020-01-16</lastmod>
<lastmod>2020-01-17</lastmod>
<changefreq>daily</changefreq>
</url>
<url>
<loc>None</loc>
<lastmod>2020-01-16</lastmod>
<lastmod>2020-01-17</lastmod>
<changefreq>daily</changefreq>
</url>
<url>
<loc>None</loc>
<lastmod>2020-01-16</lastmod>
<lastmod>2020-01-17</lastmod>
<changefreq>daily</changefreq>
</url>
<url>
<loc>None</loc>
<lastmod>2020-01-16</lastmod>
<lastmod>2020-01-17</lastmod>
<changefreq>daily</changefreq>
</url>
<url>
<loc>None</loc>
<lastmod>2020-01-16</lastmod>
<lastmod>2020-01-17</lastmod>
<changefreq>daily</changefreq>
</url>
<url>
<loc>None</loc>
<lastmod>2020-01-16</lastmod>
<lastmod>2020-01-17</lastmod>
<changefreq>daily</changefreq>
</url>
<url>
<loc>None</loc>
<lastmod>2020-01-16</lastmod>
<lastmod>2020-01-17</lastmod>
<changefreq>daily</changefreq>
</url>
<url>
<loc>None</loc>
<lastmod>2020-01-16</lastmod>
<lastmod>2020-01-17</lastmod>
<changefreq>daily</changefreq>
</url>
<url>
<loc>None</loc>
<lastmod>2020-01-16</lastmod>
<lastmod>2020-01-17</lastmod>
<changefreq>daily</changefreq>
</url>
</urlset>
\ No newline at end of file
无法预览此类型文件
......@@ -14,6 +14,7 @@
<link rel="stylesheet" href="../css/theme.css" type="text/css" />
<link rel="stylesheet" href="../css/theme_extra.css" type="text/css" />
<link rel="stylesheet" href="//cdnjs.cloudflare.com/ajax/libs/highlight.js/9.12.0/styles/github.min.css">
<link href="../extra.css" rel="stylesheet">
<script>
// Current page data
......@@ -208,8 +209,9 @@
<p>操作信息字段之间以逗号分割。操作信息与延迟信息之间以制表符分割。</p>
<h3 id="conv2d">conv2d<a class="headerlink" href="#conv2d" title="Permanent link">#</a></h3>
<p><strong>格式</strong></p>
<div class="highlight"><pre><span></span>op_type,flag_bias,flag_relu,n_in,c_in,h_in,w_in,c_out,groups,kernel,padding,stride,dilation\tlatency
<table class="codehilitetable"><tr><td class="linenos"><div class="linenodiv"><pre><span></span>1</pre></div></td><td class="code"><div class="codehilite"><pre><span></span>op_type,flag_bias,flag_relu,n_in,c_in,h_in,w_in,c_out,groups,kernel,padding,stride,dilation\tlatency
</pre></div>
</td></tr></table>
<p><strong>字段解释</strong></p>
<ul>
......@@ -230,8 +232,9 @@
</ul>
<h3 id="activation">activation<a class="headerlink" href="#activation" title="Permanent link">#</a></h3>
<p><strong>格式</strong></p>
<div class="highlight"><pre><span></span>op_type,n_in,c_in,h_in,w_in\tlatency
<table class="codehilitetable"><tr><td class="linenos"><div class="linenodiv"><pre><span></span>1</pre></div></td><td class="code"><div class="codehilite"><pre><span></span>op_type,n_in,c_in,h_in,w_in\tlatency
</pre></div>
</td></tr></table>
<p><strong>字段解释</strong></p>
<ul>
......@@ -244,8 +247,9 @@
</ul>
<h3 id="batch_norm">batch_norm<a class="headerlink" href="#batch_norm" title="Permanent link">#</a></h3>
<p><strong>格式</strong></p>
<div class="highlight"><pre><span></span>op_type,active_type,n_in,c_in,h_in,w_in\tlatency
<table class="codehilitetable"><tr><td class="linenos"><div class="linenodiv"><pre><span></span>1</pre></div></td><td class="code"><div class="codehilite"><pre><span></span>op_type,active_type,n_in,c_in,h_in,w_in\tlatency
</pre></div>
</td></tr></table>
<p><strong>字段解释</strong></p>
<ul>
......@@ -259,8 +263,9 @@
</ul>
<h3 id="eltwise">eltwise<a class="headerlink" href="#eltwise" title="Permanent link">#</a></h3>
<p><strong>格式</strong></p>
<div class="highlight"><pre><span></span>op_type,n_in,c_in,h_in,w_in\tlatency
<table class="codehilitetable"><tr><td class="linenos"><div class="linenodiv"><pre><span></span>1</pre></div></td><td class="code"><div class="codehilite"><pre><span></span>op_type,n_in,c_in,h_in,w_in\tlatency
</pre></div>
</td></tr></table>
<p><strong>字段解释</strong></p>
<ul>
......@@ -273,8 +278,9 @@
</ul>
<h3 id="pooling">pooling<a class="headerlink" href="#pooling" title="Permanent link">#</a></h3>
<p><strong>格式</strong></p>
<div class="highlight"><pre><span></span>op_type,flag_global_pooling,n_in,c_in,h_in,w_in,kernel,padding,stride,ceil_mode,pool_type\tlatency
<table class="codehilitetable"><tr><td class="linenos"><div class="linenodiv"><pre><span></span>1</pre></div></td><td class="code"><div class="codehilite"><pre><span></span>op_type,flag_global_pooling,n_in,c_in,h_in,w_in,kernel,padding,stride,ceil_mode,pool_type\tlatency
</pre></div>
</td></tr></table>
<p><strong>字段解释</strong></p>
<ul>
......@@ -293,8 +299,9 @@
</ul>
<h3 id="softmax">softmax<a class="headerlink" href="#softmax" title="Permanent link">#</a></h3>
<p><strong>格式</strong></p>
<div class="highlight"><pre><span></span>op_type,axis,n_in,c_in,h_in,w_in\tlatency
<table class="codehilitetable"><tr><td class="linenos"><div class="linenodiv"><pre><span></span>1</pre></div></td><td class="code"><div class="codehilite"><pre><span></span>op_type,axis,n_in,c_in,h_in,w_in\tlatency
</pre></div>
</td></tr></table>
<p><strong>字段解释</strong></p>
<ul>
......
......@@ -14,6 +14,7 @@
<link rel="stylesheet" href="../../css/theme.css" type="text/css" />
<link rel="stylesheet" href="../../css/theme_extra.css" type="text/css" />
<link rel="stylesheet" href="//cdnjs.cloudflare.com/ajax/libs/highlight.js/9.12.0/styles/github.min.css">
<link href="../../extra.css" rel="stylesheet">
<script>
// Current page data
......
......@@ -14,6 +14,7 @@
<link rel="stylesheet" href="../../css/theme.css" type="text/css" />
<link rel="stylesheet" href="../../css/theme_extra.css" type="text/css" />
<link rel="stylesheet" href="//cdnjs.cloudflare.com/ajax/libs/highlight.js/9.12.0/styles/github.min.css">
<link href="../../extra.css" rel="stylesheet">
<script>
// Current page data
......@@ -194,28 +195,56 @@
<p>一般情况下,模型参数量越多,结构越复杂,其性能越好,但运算量和资源消耗也越大。<strong>知识蒸馏</strong> 就是一种将大模型学习到的有用信息(Dark Knowledge)压缩进更小更快的模型,而获得可以匹敌大模型结果的方法。</p>
<p>在本示例中精度较高的大模型被称为teacher,精度稍逊但速度更快的小模型被称为student。</p>
<h3 id="1-student_program">1. 定义student_program<a class="headerlink" href="#1-student_program" title="Permanent link">#</a></h3>
<div class="highlight"><pre><span></span><span class="n">student_program</span> <span class="o">=</span> <span class="n">fluid</span><span class="o">.</span><span class="n">Program</span><span class="p">()</span>
<table class="codehilitetable"><tr><td class="linenos"><div class="linenodiv"><pre><span></span> 1
2
3
4
5
6
7
8
9
10
11</pre></div></td><td class="code"><div class="codehilite"><pre><span></span><span class="n">student_program</span> <span class="o">=</span> <span class="n">fluid</span><span class="o">.</span><span class="n">Program</span><span class="p">()</span>
<span class="n">student_startup</span> <span class="o">=</span> <span class="n">fluid</span><span class="o">.</span><span class="n">Program</span><span class="p">()</span>
<span class="k">with</span> <span class="n">fluid</span><span class="o">.</span><span class="n">program_guard</span><span class="p">(</span><span class="n">student_program</span><span class="p">,</span> <span class="n">student_startup</span><span class="p">):</span>
<span class="n">image</span> <span class="o">=</span> <span class="n">fluid</span><span class="o">.</span><span class="n">data</span><span class="p">(</span>
<span class="n">name</span><span class="o">=</span><span class="s1">&#39;image&#39;</span><span class="p">,</span> <span class="n">shape</span><span class="o">=</span><span class="p">[</span><span class="bp">None</span><span class="p">]</span> <span class="o">+</span> <span class="p">[</span><span class="mi">3</span><span class="p">,</span> <span class="mi">224</span><span class="p">,</span> <span class="mi">224</span><span class="p">],</span> <span class="n">dtype</span><span class="o">=</span><span class="s1">&#39;float32&#39;</span><span class="p">)</span>
<span class="n">label</span> <span class="o">=</span> <span class="n">fluid</span><span class="o">.</span><span class="n">data</span><span class="p">(</span><span class="n">name</span><span class="o">=</span><span class="s1">&#39;label&#39;</span><span class="p">,</span> <span class="n">shape</span><span class="o">=</span><span class="p">[</span><span class="bp">None</span><span class="p">,</span> <span class="mi">1</span><span class="p">],</span> <span class="n">dtype</span><span class="o">=</span><span class="s1">&#39;int64&#39;</span><span class="p">)</span>
<span class="n">name</span><span class="o">=</span><span class="s1">&#39;image&#39;</span><span class="p">,</span> <span class="n">shape</span><span class="o">=</span><span class="p">[</span><span class="kc">None</span><span class="p">]</span> <span class="o">+</span> <span class="p">[</span><span class="mi">3</span><span class="p">,</span> <span class="mi">224</span><span class="p">,</span> <span class="mi">224</span><span class="p">],</span> <span class="n">dtype</span><span class="o">=</span><span class="s1">&#39;float32&#39;</span><span class="p">)</span>
<span class="n">label</span> <span class="o">=</span> <span class="n">fluid</span><span class="o">.</span><span class="n">data</span><span class="p">(</span><span class="n">name</span><span class="o">=</span><span class="s1">&#39;label&#39;</span><span class="p">,</span> <span class="n">shape</span><span class="o">=</span><span class="p">[</span><span class="kc">None</span><span class="p">,</span> <span class="mi">1</span><span class="p">],</span> <span class="n">dtype</span><span class="o">=</span><span class="s1">&#39;int64&#39;</span><span class="p">)</span>
<span class="c1"># student model definition</span>
<span class="n">model</span> <span class="o">=</span> <span class="n">MobileNet</span><span class="p">()</span>
<span class="n">out</span> <span class="o">=</span> <span class="n">model</span><span class="o">.</span><span class="n">net</span><span class="p">(</span><span class="nb">input</span><span class="o">=</span><span class="n">image</span><span class="p">,</span> <span class="n">class_dim</span><span class="o">=</span><span class="mi">1000</span><span class="p">)</span>
<span class="n">cost</span> <span class="o">=</span> <span class="n">fluid</span><span class="o">.</span><span class="n">layers</span><span class="o">.</span><span class="n">cross_entropy</span><span class="p">(</span><span class="nb">input</span><span class="o">=</span><span class="n">out</span><span class="p">,</span> <span class="n">label</span><span class="o">=</span><span class="n">label</span><span class="p">)</span>
<span class="n">avg_cost</span> <span class="o">=</span> <span class="n">fluid</span><span class="o">.</span><span class="n">layers</span><span class="o">.</span><span class="n">mean</span><span class="p">(</span><span class="n">x</span><span class="o">=</span><span class="n">cost</span><span class="p">)</span>
</pre></div>
</td></tr></table>
<h3 id="2-teacher_program">2. 定义teacher_program<a class="headerlink" href="#2-teacher_program" title="Permanent link">#</a></h3>
<p>在定义好<code>teacher_program</code>后,可以一并加载训练好的pretrained_model。</p>
<p><code>teacher_program</code>内需要加上<code>with fluid.unique_name.guard():</code>,保证teacher的变量命名不被<code>student_program</code>影响,从而能够正确地加载预训练参数。</p>
<div class="highlight"><pre><span></span><span class="n">teacher_program</span> <span class="o">=</span> <span class="n">fluid</span><span class="o">.</span><span class="n">Program</span><span class="p">()</span>
<table class="codehilitetable"><tr><td class="linenos"><div class="linenodiv"><pre><span></span> 1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18</pre></div></td><td class="code"><div class="codehilite"><pre><span></span><span class="n">teacher_program</span> <span class="o">=</span> <span class="n">fluid</span><span class="o">.</span><span class="n">Program</span><span class="p">()</span>
<span class="n">teacher_startup</span> <span class="o">=</span> <span class="n">fluid</span><span class="o">.</span><span class="n">Program</span><span class="p">()</span>
<span class="k">with</span> <span class="n">fluid</span><span class="o">.</span><span class="n">program_guard</span><span class="p">(</span><span class="n">teacher_program</span><span class="p">,</span> <span class="n">teacher_startup</span><span class="p">):</span>
<span class="k">with</span> <span class="n">fluid</span><span class="o">.</span><span class="n">unique_name</span><span class="o">.</span><span class="n">guard</span><span class="p">():</span>
<span class="n">image</span> <span class="o">=</span> <span class="n">fluid</span><span class="o">.</span><span class="n">data</span><span class="p">(</span>
<span class="n">name</span><span class="o">=</span><span class="s1">&#39;data&#39;</span><span class="p">,</span> <span class="n">shape</span><span class="o">=</span><span class="p">[</span><span class="bp">None</span><span class="p">]</span> <span class="o">+</span> <span class="p">[</span><span class="mi">3</span><span class="p">,</span> <span class="mi">224</span><span class="p">,</span> <span class="mi">224</span><span class="p">],</span> <span class="n">dtype</span><span class="o">=</span><span class="s1">&#39;float32&#39;</span><span class="p">)</span>
<span class="n">name</span><span class="o">=</span><span class="s1">&#39;data&#39;</span><span class="p">,</span> <span class="n">shape</span><span class="o">=</span><span class="p">[</span><span class="kc">None</span><span class="p">]</span> <span class="o">+</span> <span class="p">[</span><span class="mi">3</span><span class="p">,</span> <span class="mi">224</span><span class="p">,</span> <span class="mi">224</span><span class="p">],</span> <span class="n">dtype</span><span class="o">=</span><span class="s1">&#39;float32&#39;</span><span class="p">)</span>
<span class="c1"># teacher model definition</span>
<span class="n">teacher_model</span> <span class="o">=</span> <span class="n">ResNet</span><span class="p">()</span>
<span class="n">predict</span> <span class="o">=</span> <span class="n">teacher_model</span><span class="o">.</span><span class="n">net</span><span class="p">(</span><span class="n">image</span><span class="p">,</span> <span class="n">class_dim</span><span class="o">=</span><span class="mi">1000</span><span class="p">)</span>
......@@ -229,18 +258,36 @@
<span class="n">main_program</span><span class="o">=</span><span class="n">teacher_program</span><span class="p">,</span>
<span class="n">predicate</span><span class="o">=</span><span class="n">if_exist</span><span class="p">)</span>
</pre></div>
</td></tr></table>
<h3 id="3">3.选择特征图<a class="headerlink" href="#3" title="Permanent link">#</a></h3>
<p>定义好<code>student_program</code><code>teacher_program</code>后,我们需要从中两两对应地挑选出若干个特征图,留待后续为其添加知识蒸馏损失函数。</p>
<div class="highlight"><pre><span></span><span class="c1"># get all student variables</span>
<table class="codehilitetable"><tr><td class="linenos"><div class="linenodiv"><pre><span></span> 1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18</pre></div></td><td class="code"><div class="codehilite"><pre><span></span><span class="c1"># get all student variables</span>
<span class="n">student_vars</span> <span class="o">=</span> <span class="p">[]</span>
<span class="k">for</span> <span class="n">v</span> <span class="ow">in</span> <span class="n">student_program</span><span class="o">.</span><span class="n">list_vars</span><span class="p">():</span>
<span class="k">try</span><span class="p">:</span>
<span class="n">student_vars</span><span class="o">.</span><span class="n">append</span><span class="p">((</span><span class="n">v</span><span class="o">.</span><span class="n">name</span><span class="p">,</span> <span class="n">v</span><span class="o">.</span><span class="n">shape</span><span class="p">))</span>
<span class="k">except</span><span class="p">:</span>
<span class="k">pass</span>
<span class="k">print</span><span class="p">(</span><span class="s2">&quot;=&quot;</span><span class="o">*</span><span class="mi">50</span><span class="o">+</span><span class="s2">&quot;student_model_vars&quot;</span><span class="o">+</span><span class="s2">&quot;=&quot;</span><span class="o">*</span><span class="mi">50</span><span class="p">)</span>
<span class="k">print</span><span class="p">(</span><span class="n">student_vars</span><span class="p">)</span>
<span class="nb">print</span><span class="p">(</span><span class="s2">&quot;=&quot;</span><span class="o">*</span><span class="mi">50</span><span class="o">+</span><span class="s2">&quot;student_model_vars&quot;</span><span class="o">+</span><span class="s2">&quot;=&quot;</span><span class="o">*</span><span class="mi">50</span><span class="p">)</span>
<span class="nb">print</span><span class="p">(</span><span class="n">student_vars</span><span class="p">)</span>
<span class="c1"># get all teacher variables</span>
<span class="n">teacher_vars</span> <span class="o">=</span> <span class="p">[]</span>
<span class="k">for</span> <span class="n">v</span> <span class="ow">in</span> <span class="n">teacher_program</span><span class="o">.</span><span class="n">list_vars</span><span class="p">():</span>
......@@ -248,21 +295,32 @@
<span class="n">teacher_vars</span><span class="o">.</span><span class="n">append</span><span class="p">((</span><span class="n">v</span><span class="o">.</span><span class="n">name</span><span class="p">,</span> <span class="n">v</span><span class="o">.</span><span class="n">shape</span><span class="p">))</span>
<span class="k">except</span><span class="p">:</span>
<span class="k">pass</span>
<span class="k">print</span><span class="p">(</span><span class="s2">&quot;=&quot;</span><span class="o">*</span><span class="mi">50</span><span class="o">+</span><span class="s2">&quot;teacher_model_vars&quot;</span><span class="o">+</span><span class="s2">&quot;=&quot;</span><span class="o">*</span><span class="mi">50</span><span class="p">)</span>
<span class="k">print</span><span class="p">(</span><span class="n">teacher_vars</span><span class="p">)</span>
<span class="nb">print</span><span class="p">(</span><span class="s2">&quot;=&quot;</span><span class="o">*</span><span class="mi">50</span><span class="o">+</span><span class="s2">&quot;teacher_model_vars&quot;</span><span class="o">+</span><span class="s2">&quot;=&quot;</span><span class="o">*</span><span class="mi">50</span><span class="p">)</span>
<span class="nb">print</span><span class="p">(</span><span class="n">teacher_vars</span><span class="p">)</span>
</pre></div>
</td></tr></table>
<h3 id="4-programmerge">4. 合并Program(merge)<a class="headerlink" href="#4-programmerge" title="Permanent link">#</a></h3>
<p>PaddlePaddle使用Program来描述计算图,为了同时计算student和teacher两个Program,这里需要将其两者合并(merge)为一个Program。</p>
<p>merge过程操作较多,具体细节请参考<a href="https://paddlepaddle.github.io/PaddleSlim/api/single_distiller_api/#merge">merge API文档</a></p>
<div class="highlight"><pre><span></span><span class="n">data_name_map</span> <span class="o">=</span> <span class="p">{</span><span class="s1">&#39;data&#39;</span><span class="p">:</span> <span class="s1">&#39;image&#39;</span><span class="p">}</span>
<table class="codehilitetable"><tr><td class="linenos"><div class="linenodiv"><pre><span></span>1
2</pre></div></td><td class="code"><div class="codehilite"><pre><span></span><span class="n">data_name_map</span> <span class="o">=</span> <span class="p">{</span><span class="s1">&#39;data&#39;</span><span class="p">:</span> <span class="s1">&#39;image&#39;</span><span class="p">}</span>
<span class="n">merge</span><span class="p">(</span><span class="n">teacher_program</span><span class="p">,</span> <span class="n">student_program</span><span class="p">,</span> <span class="n">data_name_map</span><span class="p">,</span> <span class="n">place</span><span class="p">)</span>
</pre></div>
</td></tr></table>
<h3 id="5loss">5.添加蒸馏loss<a class="headerlink" href="#5loss" title="Permanent link">#</a></h3>
<p>在添加蒸馏loss的过程中,可能还会引入部分变量(Variable),为了避免命名重复这里可以使用<code>with fluid.name_scope("distill"):</code>为新引入的变量加一个命名作用域。</p>
<p>另外需要注意的是,merge过程为<code>teacher_program</code>的变量统一加了名称前缀,默认是<code>"teacher_"</code>, 这里在添加<code>l2_loss</code>时也要为teacher的变量加上这个前缀。</p>
<div class="highlight"><pre><span></span><span class="k">with</span> <span class="n">fluid</span><span class="o">.</span><span class="n">program_guard</span><span class="p">(</span><span class="n">student_program</span><span class="p">,</span> <span class="n">student_startup</span><span class="p">):</span>
<table class="codehilitetable"><tr><td class="linenos"><div class="linenodiv"><pre><span></span>1
2
3
4
5
6
7
8
9</pre></div></td><td class="code"><div class="codehilite"><pre><span></span><span class="k">with</span> <span class="n">fluid</span><span class="o">.</span><span class="n">program_guard</span><span class="p">(</span><span class="n">student_program</span><span class="p">,</span> <span class="n">student_startup</span><span class="p">):</span>
<span class="k">with</span> <span class="n">fluid</span><span class="o">.</span><span class="n">name_scope</span><span class="p">(</span><span class="s2">&quot;distill&quot;</span><span class="p">):</span>
<span class="n">distill_loss</span> <span class="o">=</span> <span class="n">l2_loss</span><span class="p">(</span><span class="s1">&#39;teacher_bn5c_branch2b.output.1.tmp_3&#39;</span><span class="p">,</span>
<span class="s1">&#39;depthwise_conv2d_11.tmp_0&#39;</span><span class="p">,</span> <span class="n">student_program</span><span class="p">)</span>
......@@ -272,6 +330,7 @@
<span class="n">opt</span><span class="o">.</span><span class="n">minimize</span><span class="p">(</span><span class="n">loss</span><span class="p">)</span>
<span class="n">exe</span><span class="o">.</span><span class="n">run</span><span class="p">(</span><span class="n">student_startup</span><span class="p">)</span>
</pre></div>
</td></tr></table>
<p>至此,我们就得到了用于蒸馏训练的<code>student_program</code>,后面就可以使用一个普通program一样对其开始训练和评估。</p>
......
......@@ -14,6 +14,7 @@
<link rel="stylesheet" href="../../css/theme.css" type="text/css" />
<link rel="stylesheet" href="../../css/theme_extra.css" type="text/css" />
<link rel="stylesheet" href="//cdnjs.cloudflare.com/ajax/libs/highlight.js/9.12.0/styles/github.min.css">
<link href="../../extra.css" rel="stylesheet">
<script>
// Current page data
......@@ -182,57 +183,93 @@
<p>请参考。</p>
<h3 id="1">1. 配置搜索空间<a class="headerlink" href="#1" title="Permanent link">#</a></h3>
<p>详细的搜索空间配置可以参考<a href='../../../paddleslim/nas/nas_api.md'>神经网络搜索API文档</a>
<div class="highlight"><pre><span></span>config = [(&#39;MobileNetV2Space&#39;)]
</pre></div></p>
<table class="codehilitetable"><tr><td class="linenos"><div class="linenodiv"><pre><span></span>1</pre></div></td><td class="code"><div class="codehilite"><pre><span></span><span class="n">config</span> <span class="o">=</span> <span class="p">[(</span><span class="s1">&#39;MobileNetV2Space&#39;</span><span class="p">)]</span>
</pre></div>
</td></tr></table></p>
<h3 id="2-sanas">2. 利用搜索空间初始化SANAS实例<a class="headerlink" href="#2-sanas" title="Permanent link">#</a></h3>
<div class="highlight"><pre><span></span>from paddleslim.nas import SANAS
<table class="codehilitetable"><tr><td class="linenos"><div class="linenodiv"><pre><span></span>1
2
3
4
5
6
7
8
9</pre></div></td><td class="code"><div class="codehilite"><pre><span></span><span class="kn">from</span> <span class="nn">paddleslim.nas</span> <span class="kn">import</span> <span class="n">SANAS</span>
sa_nas = SANAS(
config,
server_addr=(&quot;&quot;, 8881),
init_temperature=10.24,
reduce_rate=0.85,
search_steps=300,
is_server=True)
<span class="n">sa_nas</span> <span class="o">=</span> <span class="n">SANAS</span><span class="p">(</span>
<span class="n">config</span><span class="p">,</span>
<span class="n">server_addr</span><span class="o">=</span><span class="p">(</span><span class="s2">&quot;&quot;</span><span class="p">,</span> <span class="mi">8881</span><span class="p">),</span>
<span class="n">init_temperature</span><span class="o">=</span><span class="mf">10.24</span><span class="p">,</span>
<span class="n">reduce_rate</span><span class="o">=</span><span class="mf">0.85</span><span class="p">,</span>
<span class="n">search_steps</span><span class="o">=</span><span class="mi">300</span><span class="p">,</span>
<span class="n">is_server</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
</pre></div>
</td></tr></table>
<h3 id="3-nas">3. 根据实例化的NAS得到当前的网络结构<a class="headerlink" href="#3-nas" title="Permanent link">#</a></h3>
<div class="highlight"><pre><span></span>archs = sa_nas.next_archs()
<table class="codehilitetable"><tr><td class="linenos"><div class="linenodiv"><pre><span></span>1</pre></div></td><td class="code"><div class="codehilite"><pre><span></span><span class="n">archs</span> <span class="o">=</span> <span class="n">sa_nas</span><span class="o">.</span><span class="n">next_archs</span><span class="p">()</span>
</pre></div>
</td></tr></table>
<h3 id="4-program">4. 根据得到的网络结构和输入构造训练和测试program<a class="headerlink" href="#4-program" title="Permanent link">#</a></h3>
<div class="highlight"><pre><span></span>import paddle.fluid as fluid
<table class="codehilitetable"><tr><td class="linenos"><div class="linenodiv"><pre><span></span> 1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20</pre></div></td><td class="code"><div class="codehilite"><pre><span></span><span class="kn">import</span> <span class="nn">paddle.fluid</span> <span class="k">as</span> <span class="nn">fluid</span>
train_program = fluid.Program()
test_program = fluid.Program()
startup_program = fluid.Program()
<span class="n">train_program</span> <span class="o">=</span> <span class="n">fluid</span><span class="o">.</span><span class="n">Program</span><span class="p">()</span>
<span class="n">test_program</span> <span class="o">=</span> <span class="n">fluid</span><span class="o">.</span><span class="n">Program</span><span class="p">()</span>
<span class="n">startup_program</span> <span class="o">=</span> <span class="n">fluid</span><span class="o">.</span><span class="n">Program</span><span class="p">()</span>
with fluid.program_guard(train_program, startup_program):
data = fluid.data(name=&#39;data&#39;, shape=[None, 3, 32, 32], dtype=&#39;float32&#39;)
label = fluid.data(name=&#39;label&#39;, shape=[None, 1], dtype=&#39;int64&#39;)
for arch in archs:
data = arch(data)
output = fluid.layers.fc(data, 10)
softmax_out = fluid.layers.softmax(input=output, use_cudnn=False)
cost = fluid.layers.cross_entropy(input=softmax_out, label=label)
avg_cost = fluid.layers.mean(cost)
acc_top1 = fluid.layers.accuracy(input=softmax_out, label=label, k=1)
<span class="k">with</span> <span class="n">fluid</span><span class="o">.</span><span class="n">program_guard</span><span class="p">(</span><span class="n">train_program</span><span class="p">,</span> <span class="n">startup_program</span><span class="p">):</span>
<span class="n">data</span> <span class="o">=</span> <span class="n">fluid</span><span class="o">.</span><span class="n">data</span><span class="p">(</span><span class="n">name</span><span class="o">=</span><span class="s1">&#39;data&#39;</span><span class="p">,</span> <span class="n">shape</span><span class="o">=</span><span class="p">[</span><span class="kc">None</span><span class="p">,</span> <span class="mi">3</span><span class="p">,</span> <span class="mi">32</span><span class="p">,</span> <span class="mi">32</span><span class="p">],</span> <span class="n">dtype</span><span class="o">=</span><span class="s1">&#39;float32&#39;</span><span class="p">)</span>
<span class="n">label</span> <span class="o">=</span> <span class="n">fluid</span><span class="o">.</span><span class="n">data</span><span class="p">(</span><span class="n">name</span><span class="o">=</span><span class="s1">&#39;label&#39;</span><span class="p">,</span> <span class="n">shape</span><span class="o">=</span><span class="p">[</span><span class="kc">None</span><span class="p">,</span> <span class="mi">1</span><span class="p">],</span> <span class="n">dtype</span><span class="o">=</span><span class="s1">&#39;int64&#39;</span><span class="p">)</span>
<span class="k">for</span> <span class="n">arch</span> <span class="ow">in</span> <span class="n">archs</span><span class="p">:</span>
<span class="n">data</span> <span class="o">=</span> <span class="n">arch</span><span class="p">(</span><span class="n">data</span><span class="p">)</span>
<span class="n">output</span> <span class="o">=</span> <span class="n">fluid</span><span class="o">.</span><span class="n">layers</span><span class="o">.</span><span class="n">fc</span><span class="p">(</span><span class="n">data</span><span class="p">,</span> <span class="mi">10</span><span class="p">)</span>
<span class="n">softmax_out</span> <span class="o">=</span> <span class="n">fluid</span><span class="o">.</span><span class="n">layers</span><span class="o">.</span><span class="n">softmax</span><span class="p">(</span><span class="nb">input</span><span class="o">=</span><span class="n">output</span><span class="p">,</span> <span class="n">use_cudnn</span><span class="o">=</span><span class="kc">False</span><span class="p">)</span>
<span class="n">cost</span> <span class="o">=</span> <span class="n">fluid</span><span class="o">.</span><span class="n">layers</span><span class="o">.</span><span class="n">cross_entropy</span><span class="p">(</span><span class="nb">input</span><span class="o">=</span><span class="n">softmax_out</span><span class="p">,</span> <span class="n">label</span><span class="o">=</span><span class="n">label</span><span class="p">)</span>
<span class="n">avg_cost</span> <span class="o">=</span> <span class="n">fluid</span><span class="o">.</span><span class="n">layers</span><span class="o">.</span><span class="n">mean</span><span class="p">(</span><span class="n">cost</span><span class="p">)</span>
<span class="n">acc_top1</span> <span class="o">=</span> <span class="n">fluid</span><span class="o">.</span><span class="n">layers</span><span class="o">.</span><span class="n">accuracy</span><span class="p">(</span><span class="nb">input</span><span class="o">=</span><span class="n">softmax_out</span><span class="p">,</span> <span class="n">label</span><span class="o">=</span><span class="n">label</span><span class="p">,</span> <span class="n">k</span><span class="o">=</span><span class="mi">1</span><span class="p">)</span>
test_program = train_program.clone(for_test=True)
sgd = fluid.optimizer.SGD(learning_rate=1e-3)
sgd.minimize(avg_cost)
<span class="n">test_program</span> <span class="o">=</span> <span class="n">train_program</span><span class="o">.</span><span class="n">clone</span><span class="p">(</span><span class="n">for_test</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
<span class="n">sgd</span> <span class="o">=</span> <span class="n">fluid</span><span class="o">.</span><span class="n">optimizer</span><span class="o">.</span><span class="n">SGD</span><span class="p">(</span><span class="n">learning_rate</span><span class="o">=</span><span class="mf">1e-3</span><span class="p">)</span>
<span class="n">sgd</span><span class="o">.</span><span class="n">minimize</span><span class="p">(</span><span class="n">avg_cost</span><span class="p">)</span>
</pre></div>
</td></tr></table>
<h3 id="5-program">5. 根据构造的训练program添加限制条件<a class="headerlink" href="#5-program" title="Permanent link">#</a></h3>
<div class="highlight"><pre><span></span>from paddleslim.analysis import flops
<table class="codehilitetable"><tr><td class="linenos"><div class="linenodiv"><pre><span></span>1
2
3
4</pre></div></td><td class="code"><div class="codehilite"><pre><span></span><span class="kn">from</span> <span class="nn">paddleslim.analysis</span> <span class="kn">import</span> <span class="n">flops</span>
if flops(train_program) &gt; 321208544:
continue
<span class="k">if</span> <span class="n">flops</span><span class="p">(</span><span class="n">train_program</span><span class="p">)</span> <span class="o">&gt;</span> <span class="mi">321208544</span><span class="p">:</span>
<span class="k">continue</span>
</pre></div>
</td></tr></table>
<h3 id="6-score">6. 回传score<a class="headerlink" href="#6-score" title="Permanent link">#</a></h3>
<div class="highlight"><pre><span></span>sa_nas.reward(score)
<table class="codehilitetable"><tr><td class="linenos"><div class="linenodiv"><pre><span></span>1</pre></div></td><td class="code"><div class="codehilite"><pre><span></span><span class="n">sa_nas</span><span class="o">.</span><span class="n">reward</span><span class="p">(</span><span class="n">score</span><span class="p">)</span>
</pre></div>
</td></tr></table>
</div>
</div>
......
......@@ -14,6 +14,7 @@
<link rel="stylesheet" href="../../css/theme.css" type="text/css" />
<link rel="stylesheet" href="../../css/theme_extra.css" type="text/css" />
<link rel="stylesheet" href="//cdnjs.cloudflare.com/ajax/libs/highlight.js/9.12.0/styles/github.min.css">
<link href="../../extra.css" rel="stylesheet">
<script>
// Current page data
......@@ -173,16 +174,20 @@
<p>该示例使用了<code>paddleslim.Pruner</code>工具类,用户接口使用介绍请参考:<a href="https://paddlepaddle.github.io/PaddleSlim/api/prune_api/">API文档</a></p>
<h2 id="_3">确定待裁参数<a class="headerlink" href="#_3" title="Permanent link">#</a></h2>
<p>不同模型的参数命名不同,在剪裁前需要确定待裁卷积层的参数名称。可通过以下方法列出所有参数名:</p>
<div class="highlight"><pre><span></span>for param in program.global_block().all_parameters():
print(&quot;param name: {}; shape: {}&quot;.format(param.name, param.shape))
<table class="codehilitetable"><tr><td class="linenos"><div class="linenodiv"><pre><span></span>1
2</pre></div></td><td class="code"><div class="codehilite"><pre><span></span><span class="k">for</span> <span class="n">param</span> <span class="ow">in</span> <span class="n">program</span><span class="o">.</span><span class="n">global_block</span><span class="p">()</span><span class="o">.</span><span class="n">all_parameters</span><span class="p">():</span>
<span class="nb">print</span><span class="p">(</span><span class="s2">&quot;param name: </span><span class="si">{}</span><span class="s2">; shape: </span><span class="si">{}</span><span class="s2">&quot;</span><span class="o">.</span><span class="n">format</span><span class="p">(</span><span class="n">param</span><span class="o">.</span><span class="n">name</span><span class="p">,</span> <span class="n">param</span><span class="o">.</span><span class="n">shape</span><span class="p">))</span>
</pre></div>
</td></tr></table>
<p><code>train.py</code>脚本中,提供了<code>get_pruned_params</code>方法,根据用户设置的选项<code>--model</code>确定要裁剪的参数。</p>
<h2 id="_4">启动裁剪任务<a class="headerlink" href="#_4" title="Permanent link">#</a></h2>
<p>通过以下命令启动裁剪任务:</p>
<div class="highlight"><pre><span></span>export CUDA_VISIBLE_DEVICES=0
python train.py
<table class="codehilitetable"><tr><td class="linenos"><div class="linenodiv"><pre><span></span>1
2</pre></div></td><td class="code"><div class="codehilite"><pre><span></span><span class="n">export</span> <span class="n">CUDA_VISIBLE_DEVICES</span><span class="o">=</span><span class="mi">0</span>
<span class="n">python</span> <span class="n">train</span><span class="o">.</span><span class="n">py</span>
</pre></div>
</td></tr></table>
<p>执行<code>python train.py --help</code>查看更多选项。</p>
<h2 id="_5">注意<a class="headerlink" href="#_5" title="Permanent link">#</a></h2>
......
......@@ -14,6 +14,7 @@
<link rel="stylesheet" href="../../css/theme.css" type="text/css" />
<link rel="stylesheet" href="../../css/theme_extra.css" type="text/css" />
<link rel="stylesheet" href="//cdnjs.cloudflare.com/ajax/libs/highlight.js/9.12.0/styles/github.min.css">
<link href="../../extra.css" rel="stylesheet">
<script>
// Current page data
......@@ -184,62 +185,105 @@
<p>请参考 <a href='../../../paddleslim/quant/quantization_api_doc.md'>量化API文档</a></p>
<h2 id="_3">分类模型的离线量化流程<a class="headerlink" href="#_3" title="Permanent link">#</a></h2>
<h3 id="1">1. 配置量化参数<a class="headerlink" href="#1" title="Permanent link">#</a></h3>
<div class="highlight"><pre><span></span>quant_config = {
&#39;weight_quantize_type&#39;: &#39;abs_max&#39;,
&#39;activation_quantize_type&#39;: &#39;moving_average_abs_max&#39;,
&#39;weight_bits&#39;: 8,
&#39;activation_bits&#39;: 8,
&#39;not_quant_pattern&#39;: [&#39;skip_quant&#39;],
&#39;quantize_op_types&#39;: [&#39;conv2d&#39;, &#39;depthwise_conv2d&#39;, &#39;mul&#39;],
&#39;dtype&#39;: &#39;int8&#39;,
&#39;window_size&#39;: 10000,
&#39;moving_rate&#39;: 0.9,
&#39;quant_weight_only&#39;: False
}
<table class="codehilitetable"><tr><td class="linenos"><div class="linenodiv"><pre><span></span> 1
2
3
4
5
6
7
8
9
10
11
12</pre></div></td><td class="code"><div class="codehilite"><pre><span></span><span class="n">quant_config</span> <span class="o">=</span> <span class="p">{</span>
<span class="s1">&#39;weight_quantize_type&#39;</span><span class="p">:</span> <span class="s1">&#39;abs_max&#39;</span><span class="p">,</span>
<span class="s1">&#39;activation_quantize_type&#39;</span><span class="p">:</span> <span class="s1">&#39;moving_average_abs_max&#39;</span><span class="p">,</span>
<span class="s1">&#39;weight_bits&#39;</span><span class="p">:</span> <span class="mi">8</span><span class="p">,</span>
<span class="s1">&#39;activation_bits&#39;</span><span class="p">:</span> <span class="mi">8</span><span class="p">,</span>
<span class="s1">&#39;not_quant_pattern&#39;</span><span class="p">:</span> <span class="p">[</span><span class="s1">&#39;skip_quant&#39;</span><span class="p">],</span>
<span class="s1">&#39;quantize_op_types&#39;</span><span class="p">:</span> <span class="p">[</span><span class="s1">&#39;conv2d&#39;</span><span class="p">,</span> <span class="s1">&#39;depthwise_conv2d&#39;</span><span class="p">,</span> <span class="s1">&#39;mul&#39;</span><span class="p">],</span>
<span class="s1">&#39;dtype&#39;</span><span class="p">:</span> <span class="s1">&#39;int8&#39;</span><span class="p">,</span>
<span class="s1">&#39;window_size&#39;</span><span class="p">:</span> <span class="mi">10000</span><span class="p">,</span>
<span class="s1">&#39;moving_rate&#39;</span><span class="p">:</span> <span class="mf">0.9</span><span class="p">,</span>
<span class="s1">&#39;quant_weight_only&#39;</span><span class="p">:</span> <span class="kc">False</span>
<span class="p">}</span>
</pre></div>
</td></tr></table>
<h3 id="2-programop">2. 对训练和测试program插入可训练量化op<a class="headerlink" href="#2-programop" title="Permanent link">#</a></h3>
<div class="highlight"><pre><span></span>val_program = quant_aware(val_program, place, quant_config, scope=None, for_test=True)
<table class="codehilitetable"><tr><td class="linenos"><div class="linenodiv"><pre><span></span>1
2
3</pre></div></td><td class="code"><div class="codehilite"><pre><span></span><span class="n">val_program</span> <span class="o">=</span> <span class="n">quant_aware</span><span class="p">(</span><span class="n">val_program</span><span class="p">,</span> <span class="n">place</span><span class="p">,</span> <span class="n">quant_config</span><span class="p">,</span> <span class="n">scope</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="n">for_test</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
compiled_train_prog = quant_aware(train_prog, place, quant_config, scope=None, for_test=False)
<span class="n">compiled_train_prog</span> <span class="o">=</span> <span class="n">quant_aware</span><span class="p">(</span><span class="n">train_prog</span><span class="p">,</span> <span class="n">place</span><span class="p">,</span> <span class="n">quant_config</span><span class="p">,</span> <span class="n">scope</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="n">for_test</span><span class="o">=</span><span class="kc">False</span><span class="p">)</span>
</pre></div>
</td></tr></table>
<h3 id="3build">3.关掉指定build策略<a class="headerlink" href="#3build" title="Permanent link">#</a></h3>
<div class="highlight"><pre><span></span>build_strategy = fluid.BuildStrategy()
build_strategy.fuse_all_reduce_ops = False
build_strategy.sync_batch_norm = False
exec_strategy = fluid.ExecutionStrategy()
compiled_train_prog = compiled_train_prog.with_data_parallel(
loss_name=avg_cost.name,
build_strategy=build_strategy,
exec_strategy=exec_strategy)
<table class="codehilitetable"><tr><td class="linenos"><div class="linenodiv"><pre><span></span>1
2
3
4
5
6
7
8</pre></div></td><td class="code"><div class="codehilite"><pre><span></span><span class="n">build_strategy</span> <span class="o">=</span> <span class="n">fluid</span><span class="o">.</span><span class="n">BuildStrategy</span><span class="p">()</span>
<span class="n">build_strategy</span><span class="o">.</span><span class="n">fuse_all_reduce_ops</span> <span class="o">=</span> <span class="kc">False</span>
<span class="n">build_strategy</span><span class="o">.</span><span class="n">sync_batch_norm</span> <span class="o">=</span> <span class="kc">False</span>
<span class="n">exec_strategy</span> <span class="o">=</span> <span class="n">fluid</span><span class="o">.</span><span class="n">ExecutionStrategy</span><span class="p">()</span>
<span class="n">compiled_train_prog</span> <span class="o">=</span> <span class="n">compiled_train_prog</span><span class="o">.</span><span class="n">with_data_parallel</span><span class="p">(</span>
<span class="n">loss_name</span><span class="o">=</span><span class="n">avg_cost</span><span class="o">.</span><span class="n">name</span><span class="p">,</span>
<span class="n">build_strategy</span><span class="o">=</span><span class="n">build_strategy</span><span class="p">,</span>
<span class="n">exec_strategy</span><span class="o">=</span><span class="n">exec_strategy</span><span class="p">)</span>
</pre></div>
</td></tr></table>
<h3 id="4-freeze-program">4. freeze program<a class="headerlink" href="#4-freeze-program" title="Permanent link">#</a></h3>
<div class="highlight"><pre><span></span>float_program, int8_program = convert(val_program,
place,
quant_config,
scope=None,
save_int8=True)
<table class="codehilitetable"><tr><td class="linenos"><div class="linenodiv"><pre><span></span>1
2
3
4
5</pre></div></td><td class="code"><div class="codehilite"><pre><span></span><span class="n">float_program</span><span class="p">,</span> <span class="n">int8_program</span> <span class="o">=</span> <span class="n">convert</span><span class="p">(</span><span class="n">val_program</span><span class="p">,</span>
<span class="n">place</span><span class="p">,</span>
<span class="n">quant_config</span><span class="p">,</span>
<span class="n">scope</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span>
<span class="n">save_int8</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
</pre></div>
</td></tr></table>
<h3 id="5">5.保存预测模型<a class="headerlink" href="#5" title="Permanent link">#</a></h3>
<div class="highlight"><pre><span></span>fluid.io.save_inference_model(
dirname=float_path,
feeded_var_names=[image.name],
target_vars=[out], executor=exe,
main_program=float_program,
model_filename=float_path + &#39;/model&#39;,
params_filename=float_path + &#39;/params&#39;)
<table class="codehilitetable"><tr><td class="linenos"><div class="linenodiv"><pre><span></span> 1
2
3
4
5
6
7
8
9
10
11
12
13
14
15</pre></div></td><td class="code"><div class="codehilite"><pre><span></span><span class="n">fluid</span><span class="o">.</span><span class="n">io</span><span class="o">.</span><span class="n">save_inference_model</span><span class="p">(</span>
<span class="n">dirname</span><span class="o">=</span><span class="n">float_path</span><span class="p">,</span>
<span class="n">feeded_var_names</span><span class="o">=</span><span class="p">[</span><span class="n">image</span><span class="o">.</span><span class="n">name</span><span class="p">],</span>
<span class="n">target_vars</span><span class="o">=</span><span class="p">[</span><span class="n">out</span><span class="p">],</span> <span class="n">executor</span><span class="o">=</span><span class="n">exe</span><span class="p">,</span>
<span class="n">main_program</span><span class="o">=</span><span class="n">float_program</span><span class="p">,</span>
<span class="n">model_filename</span><span class="o">=</span><span class="n">float_path</span> <span class="o">+</span> <span class="s1">&#39;/model&#39;</span><span class="p">,</span>
<span class="n">params_filename</span><span class="o">=</span><span class="n">float_path</span> <span class="o">+</span> <span class="s1">&#39;/params&#39;</span><span class="p">)</span>
fluid.io.save_inference_model(
dirname=int8_path,
feeded_var_names=[image.name],
target_vars=[out], executor=exe,
main_program=int8_program,
model_filename=int8_path + &#39;/model&#39;,
params_filename=int8_path + &#39;/params&#39;)
<span class="n">fluid</span><span class="o">.</span><span class="n">io</span><span class="o">.</span><span class="n">save_inference_model</span><span class="p">(</span>
<span class="n">dirname</span><span class="o">=</span><span class="n">int8_path</span><span class="p">,</span>
<span class="n">feeded_var_names</span><span class="o">=</span><span class="p">[</span><span class="n">image</span><span class="o">.</span><span class="n">name</span><span class="p">],</span>
<span class="n">target_vars</span><span class="o">=</span><span class="p">[</span><span class="n">out</span><span class="p">],</span> <span class="n">executor</span><span class="o">=</span><span class="n">exe</span><span class="p">,</span>
<span class="n">main_program</span><span class="o">=</span><span class="n">int8_program</span><span class="p">,</span>
<span class="n">model_filename</span><span class="o">=</span><span class="n">int8_path</span> <span class="o">+</span> <span class="s1">&#39;/model&#39;</span><span class="p">,</span>
<span class="n">params_filename</span><span class="o">=</span><span class="n">int8_path</span> <span class="o">+</span> <span class="s1">&#39;/params&#39;</span><span class="p">)</span>
</pre></div>
</td></tr></table>
</div>
</div>
......
......@@ -14,6 +14,7 @@
<link rel="stylesheet" href="../../css/theme.css" type="text/css" />
<link rel="stylesheet" href="../../css/theme_extra.css" type="text/css" />
<link rel="stylesheet" href="//cdnjs.cloudflare.com/ajax/libs/highlight.js/9.12.0/styles/github.min.css">
<link href="../../extra.css" rel="stylesheet">
<script>
// Current page data
......@@ -197,7 +198,16 @@
<p>以下将以 <code>基于skip-gram的word2vector模型</code> 为例来说明如何使用<code>quant_embedding</code>接口。首先介绍 <code>基于skip-gram的word2vector模型</code> 的正常训练和测试流程。</p>
<h2 id="skip-gramword2vector">基于skip-gram的word2vector模型<a class="headerlink" href="#skip-gramword2vector" title="Permanent link">#</a></h2>
<p>以下是本例的简要目录结构及说明:</p>
<div class="highlight"><pre><span></span>.
<table class="codehilitetable"><tr><td class="linenos"><div class="linenodiv"><pre><span></span> 1
2
3
4
5
6
7
8
9
10</pre></div></td><td class="code"><div class="codehilite"><pre><span></span>.
├── cluster_train.py # 分布式训练函数
├── cluster_train.sh # 本地模拟多机脚本
├── train.py # 训练函数
......@@ -208,37 +218,59 @@
├── train.py # 训练函数
└── utils.py # 通用函数
</pre></div>
</td></tr></table>
<h3 id="_1">介绍<a class="headerlink" href="#_1" title="Permanent link">#</a></h3>
<p>本例实现了skip-gram模式的word2vector模型。</p>
<p>同时推荐用户参考<a href="https://aistudio.baidu.com/aistudio/projectDetail/124377"> IPython Notebook demo</a></p>
<h3 id="_2">数据下载<a class="headerlink" href="#_2" title="Permanent link">#</a></h3>
<p>全量数据集使用的是来自1 Billion Word Language Model Benchmark的(<a href="http://www.statmt.org/lm-benchmark">http://www.statmt.org/lm-benchmark</a>) 的数据集.</p>
<div class="highlight"><pre><span></span>mkdir data
<table class="codehilitetable"><tr><td class="linenos"><div class="linenodiv"><pre><span></span>1
2
3
4</pre></div></td><td class="code"><div class="codehilite"><pre><span></span>mkdir data
wget http://www.statmt.org/lm-benchmark/1-billion-word-language-modeling-benchmark-r13output.tar.gz
tar xzvf <span class="m">1</span>-billion-word-language-modeling-benchmark-r13output.tar.gz
mv <span class="m">1</span>-billion-word-language-modeling-benchmark-r13output/training-monolingual.tokenized.shuffled/ data/
</pre></div>
</td></tr></table>
<p>备用数据地址下载命令如下</p>
<div class="highlight"><pre><span></span>mkdir data
<table class="codehilitetable"><tr><td class="linenos"><div class="linenodiv"><pre><span></span>1
2
3
4</pre></div></td><td class="code"><div class="codehilite"><pre><span></span>mkdir data
wget https://paddlerec.bj.bcebos.com/word2vec/1-billion-word-language-modeling-benchmark-r13output.tar
tar xvf <span class="m">1</span>-billion-word-language-modeling-benchmark-r13output.tar
mv <span class="m">1</span>-billion-word-language-modeling-benchmark-r13output/training-monolingual.tokenized.shuffled/ data/
</pre></div>
</td></tr></table>
<p>为了方便快速验证,我们也提供了经典的text8样例数据集,包含1700w个词。 下载命令如下</p>
<div class="highlight"><pre><span></span>mkdir data
<table class="codehilitetable"><tr><td class="linenos"><div class="linenodiv"><pre><span></span>1
2
3
4</pre></div></td><td class="code"><div class="codehilite"><pre><span></span>mkdir data
wget https://paddlerec.bj.bcebos.com/word2vec/text.tar
tar xvf text.tar
mv text data/
</pre></div>
</td></tr></table>
<h3 id="_3">数据预处理<a class="headerlink" href="#_3" title="Permanent link">#</a></h3>
<p>以样例数据集为例进行预处理。全量数据集注意解压后以training-monolingual.tokenized.shuffled 目录为预处理目录,和样例数据集的text目录并列。</p>
<p>词典格式: 词&lt;空格&gt;词频。注意低频词用'UNK'表示</p>
<p>可以按格式自建词典,如果自建词典跳过第一步。
<div class="highlight"><pre><span></span>the 1061396
<table class="codehilitetable"><tr><td class="linenos"><div class="linenodiv"><pre><span></span> 1
2
3
4
5
6
7
8
9
10</pre></div></td><td class="code"><div class="codehilite"><pre><span></span>the 1061396
of 593677
and 416629
one 411764
......@@ -248,43 +280,76 @@ a 325873
to 316376
zero 264975
nine 250430
</pre></div></p>
</pre></div>
</td></tr></table></p>
<p>第一步根据英文语料生成词典,中文语料可以通过修改text_strip方法自定义处理方法。</p>
<div class="highlight"><pre><span></span>python preprocess.py --build_dict --build_dict_corpus_dir data/text/ --dict_path data/test_build_dict
<table class="codehilitetable"><tr><td class="linenos"><div class="linenodiv"><pre><span></span>1</pre></div></td><td class="code"><div class="codehilite"><pre><span></span>python preprocess.py --build_dict --build_dict_corpus_dir data/text/ --dict_path data/test_build_dict
</pre></div>
</td></tr></table>
<p>第二步根据词典将文本转成id, 同时进行downsample,按照概率过滤常见词, 同时生成word和id映射的文件,文件名为词典+"<em>word_to_id</em>"。</p>
<div class="highlight"><pre><span></span>python preprocess.py --filter_corpus --dict_path data/test_build_dict --input_corpus_dir data/text --output_corpus_dir data/convert_text8 --min_count <span class="m">5</span> --downsample <span class="m">0</span>.001
<table class="codehilitetable"><tr><td class="linenos"><div class="linenodiv"><pre><span></span>1</pre></div></td><td class="code"><div class="codehilite"><pre><span></span>python preprocess.py --filter_corpus --dict_path data/test_build_dict --input_corpus_dir data/text --output_corpus_dir data/convert_text8 --min_count <span class="m">5</span> --downsample <span class="m">0</span>.001
</pre></div>
</td></tr></table>
<h3 id="_4">训练<a class="headerlink" href="#_4" title="Permanent link">#</a></h3>
<p>具体的参数配置可运行</p>
<div class="highlight"><pre><span></span>python train.py -h
<table class="codehilitetable"><tr><td class="linenos"><div class="linenodiv"><pre><span></span>1</pre></div></td><td class="code"><div class="codehilite"><pre><span></span>python train.py -h
</pre></div>
</td></tr></table>
<p>单机多线程训练
<div class="highlight"><pre><span></span><span class="nv">OPENBLAS_NUM_THREADS</span><span class="o">=</span><span class="m">1</span> <span class="nv">CPU_NUM</span><span class="o">=</span><span class="m">5</span> python train.py --train_data_dir data/convert_text8 --dict_path data/test_build_dict --num_passes <span class="m">10</span> --batch_size <span class="m">100</span> --model_output_dir v1_cpu5_b100_lr1dir --base_lr <span class="m">1</span>.0 --print_batch <span class="m">1000</span> --with_speed --is_sparse
</pre></div></p>
<table class="codehilitetable"><tr><td class="linenos"><div class="linenodiv"><pre><span></span>1</pre></div></td><td class="code"><div class="codehilite"><pre><span></span><span class="nv">OPENBLAS_NUM_THREADS</span><span class="o">=</span><span class="m">1</span> <span class="nv">CPU_NUM</span><span class="o">=</span><span class="m">5</span> python train.py --train_data_dir data/convert_text8 --dict_path data/test_build_dict --num_passes <span class="m">10</span> --batch_size <span class="m">100</span> --model_output_dir v1_cpu5_b100_lr1dir --base_lr <span class="m">1</span>.0 --print_batch <span class="m">1000</span> --with_speed --is_sparse
</pre></div>
</td></tr></table></p>
<p>本地单机模拟多机训练</p>
<div class="highlight"><pre><span></span>sh cluster_train.sh
<table class="codehilitetable"><tr><td class="linenos"><div class="linenodiv"><pre><span></span>1</pre></div></td><td class="code"><div class="codehilite"><pre><span></span>sh cluster_train.sh
</pre></div>
</td></tr></table>
<p>本示例中按照单机多线程训练的命令进行训练,训练完毕后,可看到在当前文件夹下保存模型的路径为: <code>v1_cpu5_b100_lr1dir</code>, 运行 <code>ls v1_cpu5_b100_lr1dir</code>可看到该文件夹下保存了训练的10个epoch的模型文件。
<div class="highlight"><pre><span></span>pass-0 pass-1 pass-2 pass-3 pass-4 pass-5 pass-6 pass-7 pass-8 pass-9
</pre></div></p>
<table class="codehilitetable"><tr><td class="linenos"><div class="linenodiv"><pre><span></span>1</pre></div></td><td class="code"><div class="codehilite"><pre><span></span>pass-0 pass-1 pass-2 pass-3 pass-4 pass-5 pass-6 pass-7 pass-8 pass-9
</pre></div>
</td></tr></table></p>
<h3 id="_5">预测<a class="headerlink" href="#_5" title="Permanent link">#</a></h3>
<p>测试集下载命令如下</p>
<div class="highlight"><pre><span></span><span class="c1">#全量数据集测试集</span>
<table class="codehilitetable"><tr><td class="linenos"><div class="linenodiv"><pre><span></span>1
2
3
4</pre></div></td><td class="code"><div class="codehilite"><pre><span></span><span class="c1">#全量数据集测试集</span>
wget https://paddlerec.bj.bcebos.com/word2vec/test_dir.tar
<span class="c1">#样本数据集测试集</span>
wget https://paddlerec.bj.bcebos.com/word2vec/test_mid_dir.tar
</pre></div>
</td></tr></table>
<p>预测命令,注意词典名称需要加后缀"<em>word_to_id</em>", 此文件是预处理阶段生成的。
<div class="highlight"><pre><span></span>python infer.py --infer_epoch --test_dir data/test_mid_dir --dict_path data/test_build_dict_word_to_id_ --batch_size <span class="m">20000</span> --model_dir v1_cpu5_b100_lr1dir/ --start_index <span class="m">0</span> --last_index <span class="m">9</span>
<table class="codehilitetable"><tr><td class="linenos"><div class="linenodiv"><pre><span></span>1</pre></div></td><td class="code"><div class="codehilite"><pre><span></span>python infer.py --infer_epoch --test_dir data/test_mid_dir --dict_path data/test_build_dict_word_to_id_ --batch_size <span class="m">20000</span> --model_dir v1_cpu5_b100_lr1dir/ --start_index <span class="m">0</span> --last_index <span class="m">9</span>
</pre></div>
</td></tr></table>
运行该预测命令, 可看到如下输出
<div class="highlight"><pre><span></span>(&#39;start index: &#39;, 0, &#39; last_index:&#39;, 9)
<table class="codehilitetable"><tr><td class="linenos"><div class="linenodiv"><pre><span></span> 1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22</pre></div></td><td class="code"><div class="codehilite"><pre><span></span>(&#39;start index: &#39;, 0, &#39; last_index:&#39;, 9)
(&#39;vocab_size:&#39;, 63642)
step:1 249
epoch:0 acc:0.014
......@@ -306,20 +371,57 @@ step:1 2606
epoch:8 acc:0.146
step:1 2722
epoch:9 acc:0.153
</pre></div></p>
</pre></div>
</td></tr></table></p>
<h2 id="skip-gramword2vector_1">量化<code>基于skip-gram的word2vector模型</code><a class="headerlink" href="#skip-gramword2vector_1" title="Permanent link">#</a></h2>
<p>量化配置为:
<div class="highlight"><pre><span></span>config = {
&#39;params_name&#39;: &#39;emb&#39;,
&#39;quantize_type&#39;: &#39;abs_max&#39;
}
</pre></div></p>
<table class="codehilitetable"><tr><td class="linenos"><div class="linenodiv"><pre><span></span>1
2
3
4</pre></div></td><td class="code"><div class="codehilite"><pre><span></span><span class="n">config</span> <span class="o">=</span> <span class="p">{</span>
<span class="s1">&#39;params_name&#39;</span><span class="p">:</span> <span class="s1">&#39;emb&#39;</span><span class="p">,</span>
<span class="s1">&#39;quantize_type&#39;</span><span class="p">:</span> <span class="s1">&#39;abs_max&#39;</span>
<span class="p">}</span>
</pre></div>
</td></tr></table></p>
<p>运行命令为:</p>
<div class="highlight"><pre><span></span>python infer.py --infer_epoch --test_dir data/test_mid_dir --dict_path data/test_build_dict_word_to_id_ --batch_size <span class="m">20000</span> --model_dir v1_cpu5_b100_lr1dir/ --start_index <span class="m">0</span> --last_index <span class="m">9</span> --emb_quant True
<table class="codehilitetable"><tr><td class="linenos"><div class="linenodiv"><pre><span></span>1</pre></div></td><td class="code"><div class="codehilite"><pre><span></span>python infer.py --infer_epoch --test_dir data/test_mid_dir --dict_path data/test_build_dict_word_to_id_ --batch_size <span class="m">20000</span> --model_dir v1_cpu5_b100_lr1dir/ --start_index <span class="m">0</span> --last_index <span class="m">9</span> --emb_quant True
</pre></div>
</td></tr></table>
<p>运行输出为:</p>
<div class="highlight"><pre><span></span>(&#39;start index: &#39;, 0, &#39; last_index:&#39;, 9)
<table class="codehilitetable"><tr><td class="linenos"><div class="linenodiv"><pre><span></span> 1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32</pre></div></td><td class="code"><div class="codehilite"><pre><span></span>(&#39;start index: &#39;, 0, &#39; last_index:&#39;, 9)
(&#39;vocab_size:&#39;, 63642)
quant_embedding config {&#39;quantize_type&#39;: &#39;abs_max&#39;, &#39;params_name&#39;: &#39;emb&#39;, &#39;quantize_bits&#39;: 8, &#39;dtype&#39;: &#39;int8&#39;}
step:1 253
......@@ -352,6 +454,7 @@ quant_embedding config {&#39;quantize_type&#39;: &#39;abs_max&#39;, &#39;params_
step:1 2719
epoch:9 acc:0.153
</pre></div>
</td></tr></table>
<p>量化后的模型保存在<code>./output_quant</code>中,可看到量化后的参数<code>'emb.int8'</code>的大小为3.9M, 在<code>./v1_cpu5_b100_lr1dir</code>中可看到量化前的参数<code>'emb'</code>的大小为16M。</p>
......
......@@ -14,6 +14,7 @@
<link rel="stylesheet" href="../../css/theme.css" type="text/css" />
<link rel="stylesheet" href="../../css/theme_extra.css" type="text/css" />
<link rel="stylesheet" href="//cdnjs.cloudflare.com/ajax/libs/highlight.js/9.12.0/styles/github.min.css">
<link href="../../extra.css" rel="stylesheet">
<script>
// Current page data
......@@ -195,13 +196,15 @@
<p>在当前文件夹下创建<code>'pretrain'</code>文件夹,将<code>mobilenetv1</code>模型在该文件夹下解压,解压后的目录为<code>pretrain/MobileNetV1_pretrained</code></p>
<h3 id="_6">导出模型<a class="headerlink" href="#_6" title="Permanent link">#</a></h3>
<p>通过运行以下命令可将模型转化为离线量化接口可用的模型:
<div class="highlight"><pre><span></span>python export_model.py --model &quot;MobileNet&quot; --pretrained_model ./pretrain/MobileNetV1_pretrained --data imagenet
<table class="codehilitetable"><tr><td class="linenos"><div class="linenodiv"><pre><span></span>1</pre></div></td><td class="code"><div class="codehilite"><pre><span></span>python export_model.py --model <span class="s2">&quot;MobileNet&quot;</span> --pretrained_model ./pretrain/MobileNetV1_pretrained --data imagenet
</pre></div>
</td></tr></table>
转化之后的模型存储在<code>inference_model/MobileNet/</code>文件夹下,可看到该文件夹下有<code>'model'</code>, <code>'weights'</code>两个文件。</p>
<h3 id="_7">离线量化<a class="headerlink" href="#_7" title="Permanent link">#</a></h3>
<p>接下来对导出的模型文件进行离线量化,离线量化的脚本为<a href="./quant_post.py">quant_post.py</a>,脚本中使用接口<code>paddleslim.quant.quant_post</code>对模型进行离线量化。运行命令为:
<div class="highlight"><pre><span></span>python quant_post.py --model_path ./inference_model/MobileNet --save_path ./quant_model_train/MobileNet --model_filename model --params_filename weights
</pre></div></p>
<table class="codehilitetable"><tr><td class="linenos"><div class="linenodiv"><pre><span></span>1</pre></div></td><td class="code"><div class="codehilite"><pre><span></span>python quant_post.py --model_path ./inference_model/MobileNet --save_path ./quant_model_train/MobileNet --model_filename model --params_filename weights
</pre></div>
</td></tr></table></p>
<ul>
<li><code>model_path</code>: 需要量化的模型坐在的文件夹</li>
<li><code>save_path</code>: 量化后的模型保存的路径</li>
......@@ -215,18 +218,22 @@
<h3 id="_8">测试精度<a class="headerlink" href="#_8" title="Permanent link">#</a></h3>
<p>使用<a href="./eval.py">eval.py</a>脚本对量化前后的模型进行测试,得到模型的分类精度进行对比。</p>
<p>首先测试量化前的模型的精度,运行以下命令:
<div class="highlight"><pre><span></span>python eval.py --model_path ./inference_model/MobileNet --model_name model --params_name weights
<table class="codehilitetable"><tr><td class="linenos"><div class="linenodiv"><pre><span></span>1</pre></div></td><td class="code"><div class="codehilite"><pre><span></span>python eval.py --model_path ./inference_model/MobileNet --model_name model --params_name weights
</pre></div>
</td></tr></table>
精度输出为:
<div class="highlight"><pre><span></span>top1_acc/top5_acc= [0.70913923 0.89548034]
</pre></div></p>
<table class="codehilitetable"><tr><td class="linenos"><div class="linenodiv"><pre><span></span>1</pre></div></td><td class="code"><div class="codehilite"><pre><span></span>top1_acc/top5_acc= [0.70913923 0.89548034]
</pre></div>
</td></tr></table></p>
<p>使用以下命令测试离线量化后的模型的精度:</p>
<div class="highlight"><pre><span></span>python eval.py --model_path ./quant_model_train/MobileNet
<table class="codehilitetable"><tr><td class="linenos"><div class="linenodiv"><pre><span></span>1</pre></div></td><td class="code"><div class="codehilite"><pre><span></span>python eval.py --model_path ./quant_model_train/MobileNet
</pre></div>
</td></tr></table>
<p>精度输出为
<div class="highlight"><pre><span></span>top1_acc/top5_acc= [0.70141864 0.89086477]
<table class="codehilitetable"><tr><td class="linenos"><div class="linenodiv"><pre><span></span>1</pre></div></td><td class="code"><div class="codehilite"><pre><span></span>top1_acc/top5_acc= [0.70141864 0.89086477]
</pre></div>
</td></tr></table>
从以上精度对比可以看出,对<code>mobilenet</code><code>imagenet</code>上的分类模型进行离线量化后 <code>top1</code>精度损失为<code>0.77%</code><code>top5</code>精度损失为<code>0.46%</code>. </p>
</div>
......
......@@ -14,6 +14,7 @@
<link rel="stylesheet" href="../../css/theme.css" type="text/css" />
<link rel="stylesheet" href="../../css/theme_extra.css" type="text/css" />
<link rel="stylesheet" href="//cdnjs.cloudflare.com/ajax/libs/highlight.js/9.12.0/styles/github.min.css">
<link href="../../extra.css" rel="stylesheet">
<script>
// Current page data
......@@ -176,9 +177,11 @@
</ul>
<h2 id="2">2. 运行示例<a class="headerlink" href="#2" title="Permanent link">#</a></h2>
<p>在路径<code>PaddleSlim/demo/sensitive</code>下执行以下代码运行示例:</p>
<div class="highlight"><pre><span></span>export CUDA_VISIBLE_DEVICES=0
python train.py --model &quot;MobileNetV1&quot;
<table class="codehilitetable"><tr><td class="linenos"><div class="linenodiv"><pre><span></span>1
2</pre></div></td><td class="code"><div class="codehilite"><pre><span></span><span class="nb">export</span> <span class="nv">CUDA_VISIBLE_DEVICES</span><span class="o">=</span><span class="m">0</span>
python train.py --model <span class="s2">&quot;MobileNetV1&quot;</span>
</pre></div>
</td></tr></table>
<p>通过<code>python train.py --help</code>查看更多选项。</p>
<h2 id="3">3. 重要步骤说明<a class="headerlink" href="#3" title="Permanent link">#</a></h2>
......@@ -187,35 +190,53 @@ python train.py --model &quot;MobileNetV1&quot;
<p>调用<code>paddleslim.prune.sensitivity</code>接口计算敏感度。敏感度信息会追加到<code>sensitivities_file</code>选项所指定的文件中,如果需要重新计算敏感度,需要先删除<code>sensitivities_file</code>文件。</p>
<p>如果模型评估速度较慢,可以通过多进程的方式加速敏感度计算过程。比如在进程1中设置<code>pruned_ratios=[0.1, 0.2, 0.3, 0.4]</code>,并将敏感度信息存放在文件<code>sensitivities_0.data</code>中,然后在进程2中设置<code>pruned_ratios=[0.5, 0.6, 0.7]</code>,并将敏感度信息存储在文件<code>sensitivities_1.data</code>中。这样每个进程只会计算指定剪切率下的敏感度信息。多进程可以运行在单机多卡,或多机多卡。</p>
<p>代码如下:</p>
<div class="highlight"><pre><span></span># 进程1
sensitivity(
val_program,
place,
params,
test,
sensitivities_file=&quot;sensitivities_0.data&quot;,
pruned_ratios=[0.1, 0.2, 0.3, 0.4])
<table class="codehilitetable"><tr><td class="linenos"><div class="linenodiv"><pre><span></span>1
2
3
4
5
6
7
8</pre></div></td><td class="code"><div class="codehilite"><pre><span></span><span class="c1"># 进程1</span>
<span class="n">sensitivity</span><span class="p">(</span>
<span class="n">val_program</span><span class="p">,</span>
<span class="n">place</span><span class="p">,</span>
<span class="n">params</span><span class="p">,</span>
<span class="n">test</span><span class="p">,</span>
<span class="n">sensitivities_file</span><span class="o">=</span><span class="s2">&quot;sensitivities_0.data&quot;</span><span class="p">,</span>
<span class="n">pruned_ratios</span><span class="o">=</span><span class="p">[</span><span class="mf">0.1</span><span class="p">,</span> <span class="mf">0.2</span><span class="p">,</span> <span class="mf">0.3</span><span class="p">,</span> <span class="mf">0.4</span><span class="p">])</span>
</pre></div>
</td></tr></table>
<div class="highlight"><pre><span></span># 进程2
sensitivity(
val_program,
place,
params,
test,
sensitivities_file=&quot;sensitivities_1.data&quot;,
pruned_ratios=[0.5, 0.6, 0.7])
<table class="codehilitetable"><tr><td class="linenos"><div class="linenodiv"><pre><span></span>1
2
3
4
5
6
7
8</pre></div></td><td class="code"><div class="codehilite"><pre><span></span><span class="c1"># 进程2</span>
<span class="n">sensitivity</span><span class="p">(</span>
<span class="n">val_program</span><span class="p">,</span>
<span class="n">place</span><span class="p">,</span>
<span class="n">params</span><span class="p">,</span>
<span class="n">test</span><span class="p">,</span>
<span class="n">sensitivities_file</span><span class="o">=</span><span class="s2">&quot;sensitivities_1.data&quot;</span><span class="p">,</span>
<span class="n">pruned_ratios</span><span class="o">=</span><span class="p">[</span><span class="mf">0.5</span><span class="p">,</span> <span class="mf">0.6</span><span class="p">,</span> <span class="mf">0.7</span><span class="p">])</span>
</pre></div>
</td></tr></table>
<h3 id="32">3.2 合并敏感度<a class="headerlink" href="#32" title="Permanent link">#</a></h3>
<p>如果用户通过上一节多进程的方式生成了多个存储敏感度信息的文件,可以通过<code>paddleslim.prune.merge_sensitive</code>将其合并,合并后的敏感度信息存储在一个<code>dict</code>中。代码如下:</p>
<div class="highlight"><pre><span></span>sens = merge_sensitive([&quot;./sensitivities_0.data&quot;, &quot;./sensitivities_1.data&quot;])
<table class="codehilitetable"><tr><td class="linenos"><div class="linenodiv"><pre><span></span>1</pre></div></td><td class="code"><div class="codehilite"><pre><span></span><span class="n">sens</span> <span class="o">=</span> <span class="n">merge_sensitive</span><span class="p">([</span><span class="s2">&quot;./sensitivities_0.data&quot;</span><span class="p">,</span> <span class="s2">&quot;./sensitivities_1.data&quot;</span><span class="p">])</span>
</pre></div>
</td></tr></table>
<h3 id="33">3.3 计算剪裁率<a class="headerlink" href="#33" title="Permanent link">#</a></h3>
<p>调用<code>paddleslim.prune.get_ratios_by_loss</code>接口计算一组剪裁率。</p>
<div class="highlight"><pre><span></span>ratios = get_ratios_by_loss(sens, 0.01)
<table class="codehilitetable"><tr><td class="linenos"><div class="linenodiv"><pre><span></span>1</pre></div></td><td class="code"><div class="codehilite"><pre><span></span><span class="n">ratios</span> <span class="o">=</span> <span class="n">get_ratios_by_loss</span><span class="p">(</span><span class="n">sens</span><span class="p">,</span> <span class="mf">0.01</span><span class="p">)</span>
</pre></div>
</td></tr></table>
<p>其中,<code>0.01</code>为一个阈值,对于任意卷积层,其剪裁率为使精度损失低于阈值<code>0.01</code>的最大剪裁率。</p>
<p>用户在计算出一组剪裁率之后可以通过接口<code>paddleslim.prune.Pruner</code>剪裁网络,并用接口<code>paddleslim.analysis.flops</code>计算<code>FLOPs</code>。如果<code>FLOPs</code>不满足要求,调整阈值重新计算出一组剪裁率。</p>
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册