提交 60cefda3 编写于 作者: B baiyfbupt

Deployed 09f99c9e with MkDocs version: 1.0.4

上级 cf4e366c
......@@ -187,7 +187,59 @@
</li>
</ul>
<p><strong>示例:</strong></p>
<div class="codehilite"><pre><span></span><span class="kn">import</span> <span class="nn">paddle.fluid</span> <span class="kn">as</span> <span class="nn">fluid</span>
<table class="codehilitetable"><tr><td class="linenos"><div class="linenodiv"><pre><span></span> 1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53</pre></div></td><td class="code"><div class="codehilite"><pre><span></span><span class="kn">import</span> <span class="nn">paddle.fluid</span> <span class="kn">as</span> <span class="nn">fluid</span>
<span class="kn">from</span> <span class="nn">paddle.fluid.param_attr</span> <span class="kn">import</span> <span class="n">ParamAttr</span>
<span class="kn">from</span> <span class="nn">paddleslim.analysis</span> <span class="kn">import</span> <span class="n">flops</span>
......@@ -241,7 +293,7 @@
<span class="k">print</span><span class="p">(</span><span class="s2">&quot;FLOPS: {}&quot;</span><span class="o">.</span><span class="n">format</span><span class="p">(</span><span class="n">flops</span><span class="p">(</span><span class="n">main_program</span><span class="p">)))</span>
</pre></div>
</td></tr></table>
<h2 id="model_size">model_size<a class="headerlink" href="#model_size" title="Permanent link">#</a></h2>
<blockquote>
......@@ -257,7 +309,51 @@
<li><strong>model_size(int):</strong> 整个网络的参数数量。</li>
</ul>
<p><strong>示例:</strong></p>
<div class="codehilite"><pre><span></span><span class="kn">import</span> <span class="nn">paddle.fluid</span> <span class="kn">as</span> <span class="nn">fluid</span>
<table class="codehilitetable"><tr><td class="linenos"><div class="linenodiv"><pre><span></span> 1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45</pre></div></td><td class="code"><div class="codehilite"><pre><span></span><span class="kn">import</span> <span class="nn">paddle.fluid</span> <span class="kn">as</span> <span class="nn">fluid</span>
<span class="kn">from</span> <span class="nn">paddle.fluid.param_attr</span> <span class="kn">import</span> <span class="n">ParamAttr</span>
<span class="kn">from</span> <span class="nn">paddleslim.analysis</span> <span class="kn">import</span> <span class="n">model_size</span>
......@@ -303,7 +399,7 @@
<span class="k">print</span><span class="p">(</span><span class="s2">&quot;FLOPS: {}&quot;</span><span class="o">.</span><span class="n">format</span><span class="p">(</span><span class="n">model_size</span><span class="p">(</span><span class="n">main_program</span><span class="p">)))</span>
</pre></div>
</td></tr></table>
<h2 id="tablelatencyevaluator">TableLatencyEvaluator<a class="headerlink" href="#tablelatencyevaluator" title="Permanent link">#</a></h2>
<blockquote>
......
......@@ -179,13 +179,14 @@
- <strong>is_server(bool):</strong> 当前实例是否要启动一个server。默认:True。</p>
<p><strong>返回:</strong>
一个SANAS类的实例</p>
<p><strong>示例代码:</strong></p>
<div class="codehilite"><pre><span></span><span class="kn">from</span> <span class="nn">paddleslim.nas</span> <span class="kn">import</span> <span class="n">SANAS</span>
<p><strong>示例代码:</strong>
<table class="codehilitetable"><tr><td class="linenos"><div class="linenodiv"><pre><span></span>1
2
3</pre></div></td><td class="code"><div class="codehilite"><pre><span></span><span class="kn">from</span> <span class="nn">paddleslim.nas</span> <span class="kn">import</span> <span class="n">SANAS</span>
<span class="n">config</span> <span class="o">=</span> <span class="p">[(</span><span class="s1">&#39;MobileNetV2Space&#39;</span><span class="p">)]</span>
<span class="n">sanas</span> <span class="o">=</span> <span class="n">SANAS</span><span class="p">(</span><span class="n">config</span><span class="o">=</span><span class="n">config</span><span class="p">)</span>
</pre></div>
</td></tr></table></p>
<hr />
<blockquote>
<p>tokens2arch(tokens)
......@@ -195,16 +196,20 @@
- <strong>tokens(list):</strong> 一组token。</p>
<p><strong>返回</strong>
返回一个模型结构实例。</p>
<p><strong>示例代码:</strong></p>
<div class="codehilite"><pre><span></span><span class="kn">import</span> <span class="nn">paddle.fluid</span> <span class="kn">as</span> <span class="nn">fluid</span>
<p><strong>示例代码:</strong>
<table class="codehilitetable"><tr><td class="linenos"><div class="linenodiv"><pre><span></span>1
2
3
4
5
6</pre></div></td><td class="code"><div class="codehilite"><pre><span></span><span class="kn">import</span> <span class="nn">paddle.fluid</span> <span class="kn">as</span> <span class="nn">fluid</span>
<span class="nb">input</span> <span class="o">=</span> <span class="n">fluid</span><span class="o">.</span><span class="n">data</span><span class="p">(</span><span class="n">name</span><span class="o">=</span><span class="s1">&#39;input&#39;</span><span class="p">,</span> <span class="n">shape</span><span class="o">=</span><span class="p">[</span><span class="bp">None</span><span class="p">,</span> <span class="mi">3</span><span class="p">,</span> <span class="mi">32</span><span class="p">,</span> <span class="mi">32</span><span class="p">],</span> <span class="n">dtype</span><span class="o">=</span><span class="s1">&#39;float32&#39;</span><span class="p">)</span>
<span class="n">archs</span> <span class="o">=</span> <span class="n">sanas</span><span class="o">.</span><span class="n">token2arch</span><span class="p">(</span><span class="n">tokens</span><span class="p">)</span>
<span class="k">for</span> <span class="n">arch</span> <span class="ow">in</span> <span class="n">archs</span><span class="p">:</span>
<span class="n">output</span> <span class="o">=</span> <span class="n">arch</span><span class="p">(</span><span class="nb">input</span><span class="p">)</span>
<span class="nb">input</span> <span class="o">=</span> <span class="n">output</span>
</pre></div>
</td></tr></table></p>
<hr />
<blockquote>
<p>next_archs():
......@@ -212,16 +217,20 @@
</blockquote>
<p><strong>返回</strong>
返回模型结构实例的列表,形式为list。</p>
<p><strong>示例代码:</strong></p>
<div class="codehilite"><pre><span></span><span class="kn">import</span> <span class="nn">paddle.fluid</span> <span class="kn">as</span> <span class="nn">fluid</span>
<p><strong>示例代码:</strong>
<table class="codehilitetable"><tr><td class="linenos"><div class="linenodiv"><pre><span></span>1
2
3
4
5
6</pre></div></td><td class="code"><div class="codehilite"><pre><span></span><span class="kn">import</span> <span class="nn">paddle.fluid</span> <span class="kn">as</span> <span class="nn">fluid</span>
<span class="nb">input</span> <span class="o">=</span> <span class="n">fluid</span><span class="o">.</span><span class="n">data</span><span class="p">(</span><span class="n">name</span><span class="o">=</span><span class="s1">&#39;input&#39;</span><span class="p">,</span> <span class="n">shape</span><span class="o">=</span><span class="p">[</span><span class="bp">None</span><span class="p">,</span> <span class="mi">3</span><span class="p">,</span> <span class="mi">32</span><span class="p">,</span> <span class="mi">32</span><span class="p">],</span> <span class="n">dtype</span><span class="o">=</span><span class="s1">&#39;float32&#39;</span><span class="p">)</span>
<span class="n">archs</span> <span class="o">=</span> <span class="n">sanas</span><span class="o">.</span><span class="n">next_archs</span><span class="p">()</span>
<span class="k">for</span> <span class="n">arch</span> <span class="ow">in</span> <span class="n">archs</span><span class="p">:</span>
<span class="n">output</span> <span class="o">=</span> <span class="n">arch</span><span class="p">(</span><span class="nb">input</span><span class="p">)</span>
<span class="nb">input</span> <span class="o">=</span> <span class="n">output</span>
</pre></div>
</td></tr></table></p>
<hr />
<blockquote>
<p>reward(score):
......@@ -231,8 +240,104 @@
<strong>score<float>:</strong> 当前模型的得分,分数越大越好。</p>
<p><strong>返回</strong>
模型结构更新成功或者失败,成功则返回<code>True</code>,失败则返回<code>False</code></p>
<p><strong>代码示例</strong></p>
<div class="codehilite"><pre><span></span><span class="kn">import</span> <span class="nn">numpy</span> <span class="kn">as</span> <span class="nn">np</span>
<p><strong>代码示例</strong>
<table class="codehilitetable"><tr><td class="linenos"><div class="linenodiv"><pre><span></span> 1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97</pre></div></td><td class="code"><div class="codehilite"><pre><span></span><span class="kn">import</span> <span class="nn">numpy</span> <span class="kn">as</span> <span class="nn">np</span>
<span class="kn">import</span> <span class="nn">paddle</span>
<span class="kn">import</span> <span class="nn">paddle.fluid</span> <span class="kn">as</span> <span class="nn">fluid</span>
<span class="kn">from</span> <span class="nn">paddleslim.nas</span> <span class="kn">import</span> <span class="n">SANAS</span>
......@@ -330,6 +435,7 @@
<span class="c1">### 回传score</span>
<span class="n">sa_nas</span><span class="o">.</span><span class="n">reward</span><span class="p">(</span><span class="nb">float</span><span class="p">(</span><span class="n">finally_reward</span><span class="p">[</span><span class="mi">1</span><span class="p">]))</span>
</pre></div>
</td></tr></table></p>
</div>
</div>
......
......@@ -179,10 +179,11 @@
</ul>
<p><strong>返回:</strong> 一个Pruner类的实例</p>
<p><strong>示例代码:</strong></p>
<div class="codehilite"><pre><span></span><span class="kn">from</span> <span class="nn">paddleslim.prune</span> <span class="kn">import</span> <span class="n">Pruner</span>
<table class="codehilitetable"><tr><td class="linenos"><div class="linenodiv"><pre><span></span>1
2</pre></div></td><td class="code"><div class="codehilite"><pre><span></span><span class="kn">from</span> <span class="nn">paddleslim.prune</span> <span class="kn">import</span> <span class="n">Pruner</span>
<span class="n">pruner</span> <span class="o">=</span> <span class="n">Pruner</span><span class="p">()</span>
</pre></div>
</td></tr></table>
<hr />
<blockquote>
......@@ -198,16 +199,15 @@
<p><strong>scope(paddle.fluid.Scope):</strong> 要裁剪的权重所在的<code>scope</code>,Paddle中用<code>scope</code>实例存放模型参数和运行时变量的值。Scope中的参数值会被<code>inplace</code>的裁剪。更多介绍请参考<a href="">Scope概念介绍</a></p>
</li>
<li>
<p><strong>params(list<str>):</strong> 需要被裁剪的卷积层的参数的名称列表。可以通过以下方式查看模型中所有参数的名称:</p>
</li>
</ul>
<div class="codehilite"><pre><span></span><span class="k">for</span> <span class="nv">block</span> <span class="nv">in</span> <span class="nv">program</span>.<span class="nv">blocks</span>:
<p><strong>params(list<str>):</strong> 需要被裁剪的卷积层的参数的名称列表。可以通过以下方式查看模型中所有参数的名称:
<table class="codehilitetable"><tr><td class="linenos"><div class="linenodiv"><pre><span></span>1
2
3</pre></div></td><td class="code"><div class="codehilite"><pre><span></span><span class="k">for</span> <span class="nv">block</span> <span class="nv">in</span> <span class="nv">program</span>.<span class="nv">blocks</span>:
<span class="k">for</span> <span class="nv">param</span> <span class="nv">in</span> <span class="nv">block</span>.<span class="nv">all_parameters</span><span class="ss">()</span>:
<span class="nv">print</span><span class="ss">(</span><span class="s2">&quot;</span><span class="s">param: {}; shape: {}</span><span class="s2">&quot;</span>.<span class="nv">format</span><span class="ss">(</span><span class="nv">param</span>.<span class="nv">name</span>, <span class="nv">param</span>.<span class="nv">shape</span><span class="ss">))</span>
</pre></div>
<ul>
</td></tr></table></p>
</li>
<li>
<p><strong>ratios(list<float>):</strong> 用于裁剪<code>params</code>的剪切率,类型为列表。该列表长度必须与<code>params</code>的长度一致。</p>
</li>
......@@ -240,8 +240,78 @@
</li>
</ul>
<p><strong>示例:</strong></p>
<p>点击<a href="https://aistudio.baidu.com/aistudio/projectDetail/200786">AIStudio</a>执行以下示例代码。</p>
<div class="codehilite"><pre><span></span><span class="kn">import</span> <span class="nn">paddle.fluid</span> <span class="kn">as</span> <span class="nn">fluid</span>
<p>点击<a href="https://aistudio.baidu.com/aistudio/projectDetail/200786">AIStudio</a>执行以下示例代码。
<table class="codehilitetable"><tr><td class="linenos"><div class="linenodiv"><pre><span></span> 1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71</pre></div></td><td class="code"><div class="codehilite"><pre><span></span><span class="kn">import</span> <span class="nn">paddle.fluid</span> <span class="kn">as</span> <span class="nn">fluid</span>
<span class="kn">from</span> <span class="nn">paddle.fluid.param_attr</span> <span class="kn">import</span> <span class="n">ParamAttr</span>
<span class="kn">from</span> <span class="nn">paddleslim.prune</span> <span class="kn">import</span> <span class="n">Pruner</span>
......@@ -313,8 +383,7 @@
<span class="k">if</span> <span class="s2">&quot;weights&quot;</span> <span class="ow">in</span> <span class="n">param</span><span class="o">.</span><span class="n">name</span><span class="p">:</span>
<span class="k">print</span><span class="p">(</span><span class="s2">&quot;param name: {}; param shape: {}&quot;</span><span class="o">.</span><span class="n">format</span><span class="p">(</span><span class="n">param</span><span class="o">.</span><span class="n">name</span><span class="p">,</span> <span class="n">param</span><span class="o">.</span><span class="n">shape</span><span class="p">))</span>
</pre></div>
</td></tr></table></p>
<hr />
<h2 id="sensitivity">sensitivity<a class="headerlink" href="#sensitivity" title="Permanent link">#</a></h2>
<blockquote>
......@@ -333,11 +402,13 @@
<p><strong>param_names(list<str>):</strong> 待分析的卷积层的参数的名称列表。可以通过以下方式查看模型中所有参数的名称:</p>
</li>
</ul>
<div class="codehilite"><pre><span></span><span class="k">for</span> <span class="nv">block</span> <span class="nv">in</span> <span class="nv">program</span>.<span class="nv">blocks</span>:
<table class="codehilitetable"><tr><td class="linenos"><div class="linenodiv"><pre><span></span>1
2
3</pre></div></td><td class="code"><div class="codehilite"><pre><span></span><span class="k">for</span> <span class="nv">block</span> <span class="nv">in</span> <span class="nv">program</span>.<span class="nv">blocks</span>:
<span class="k">for</span> <span class="nv">param</span> <span class="nv">in</span> <span class="nv">block</span>.<span class="nv">all_parameters</span><span class="ss">()</span>:
<span class="nv">print</span><span class="ss">(</span><span class="s2">&quot;</span><span class="s">param: {}; shape: {}</span><span class="s2">&quot;</span>.<span class="nv">format</span><span class="ss">(</span><span class="nv">param</span>.<span class="nv">name</span>, <span class="nv">param</span>.<span class="nv">shape</span><span class="ss">))</span>
</pre></div>
</td></tr></table>
<ul>
<li>
......@@ -354,7 +425,15 @@
<ul>
<li><strong>sensitivities(dict):</strong> 存放敏感度信息的dict,其格式为:</li>
</ul>
<div class="codehilite"><pre><span></span><span class="err">{</span><span class="ss">&quot;weight_0&quot;</span><span class="p">:</span>
<table class="codehilitetable"><tr><td class="linenos"><div class="linenodiv"><pre><span></span>1
2
3
4
5
6
7
8
9</pre></div></td><td class="code"><div class="codehilite"><pre><span></span><span class="err">{</span><span class="ss">&quot;weight_0&quot;</span><span class="p">:</span>
<span class="err">{</span><span class="mi">0</span><span class="p">.</span><span class="mi">1</span><span class="p">:</span> <span class="mi">0</span><span class="p">.</span><span class="mi">22</span><span class="p">,</span>
<span class="mi">0</span><span class="p">.</span><span class="mi">2</span><span class="p">:</span> <span class="mi">0</span><span class="p">.</span><span class="mi">33</span>
<span class="err">}</span><span class="p">,</span>
......@@ -364,12 +443,102 @@
<span class="err">}</span>
<span class="err">}</span>
</pre></div>
</td></tr></table>
<p>其中,<code>weight_0</code>是卷积层参数的名称,sensitivities['weight_0']的<code>value</code>为剪裁比例,<code>value</code>为精度损失的比例。</p>
<p><strong>示例:</strong></p>
<p>点击<a href="https://aistudio.baidu.com/aistudio/projectdetail/201401">AIStudio</a>运行以下示例代码。</p>
<div class="codehilite"><pre><span></span><span class="kn">import</span> <span class="nn">paddle</span>
<table class="codehilitetable"><tr><td class="linenos"><div class="linenodiv"><pre><span></span> 1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91</pre></div></td><td class="code"><div class="codehilite"><pre><span></span><span class="kn">import</span> <span class="nn">paddle</span>
<span class="kn">import</span> <span class="nn">numpy</span> <span class="kn">as</span> <span class="nn">np</span>
<span class="kn">import</span> <span class="nn">paddle.fluid</span> <span class="kn">as</span> <span class="nn">fluid</span>
<span class="kn">from</span> <span class="nn">paddle.fluid.param_attr</span> <span class="kn">import</span> <span class="n">ParamAttr</span>
......@@ -461,7 +630,7 @@
<span class="n">pruned_ratios</span><span class="o">=</span><span class="p">[</span><span class="mf">0.1</span><span class="p">,</span> <span class="mf">0.2</span><span class="p">,</span> <span class="mf">0.3</span><span class="p">])</span>
<span class="k">print</span><span class="p">(</span><span class="n">sensitivities</span><span class="p">)</span>
</pre></div>
</td></tr></table>
<h2 id="merge_sensitive">merge_sensitive<a class="headerlink" href="#merge_sensitive" title="Permanent link">#</a></h2>
<blockquote>
......@@ -476,7 +645,15 @@
<ul>
<li><strong>sensitivities(dict):</strong> 合并后的敏感度信息。其格式为:</li>
</ul>
<div class="codehilite"><pre><span></span><span class="err">{</span><span class="ss">&quot;weight_0&quot;</span><span class="p">:</span>
<table class="codehilitetable"><tr><td class="linenos"><div class="linenodiv"><pre><span></span>1
2
3
4
5
6
7
8
9</pre></div></td><td class="code"><div class="codehilite"><pre><span></span><span class="err">{</span><span class="ss">&quot;weight_0&quot;</span><span class="p">:</span>
<span class="err">{</span><span class="mi">0</span><span class="p">.</span><span class="mi">1</span><span class="p">:</span> <span class="mi">0</span><span class="p">.</span><span class="mi">22</span><span class="p">,</span>
<span class="mi">0</span><span class="p">.</span><span class="mi">2</span><span class="p">:</span> <span class="mi">0</span><span class="p">.</span><span class="mi">33</span>
<span class="err">}</span><span class="p">,</span>
......@@ -486,7 +663,7 @@
<span class="err">}</span>
<span class="err">}</span>
</pre></div>
</td></tr></table>
<p>其中,<code>weight_0</code>是卷积层参数的名称,sensitivities['weight_0']的<code>value</code>为剪裁比例,<code>value</code>为精度损失的比例。</p>
<p>示例:</p>
......@@ -501,7 +678,7 @@
</ul>
<p>返回:</p>
<ul>
<li><strong>sensitivities(dict)</strong>敏感度信息。</li>
<li>**sensitivities(dict)**敏感度信息。</li>
</ul>
<p>示例:</p>
<h2 id="get_ratios_by_losssensitivities-loss">get_ratios_by_loss(sensitivities, loss)<a class="headerlink" href="#get_ratios_by_losssensitivities-loss" title="Permanent link">#</a></h2>
......@@ -520,8 +697,9 @@
<li>ratios(dict): 一组剪切率。<code>key</code>是待剪裁参数的名称。<code>value</code>是对应参数的剪裁率。</li>
</ul>
<p>示例:</p>
<div class="codehilite"><pre><span></span>
<table class="codehilitetable"><tr><td class="linenos"><div class="linenodiv"><pre><span></span>1</pre></div></td><td class="code"><div class="codehilite"><pre><span></span>
</pre></div>
</td></tr></table>
</div>
</div>
......
......@@ -165,7 +165,26 @@
<h1 id="paddleslimquant-api">paddleslim.quant API文档<a class="headerlink" href="#paddleslimquant-api" title="Permanent link">#</a></h1>
<h2 id="api">量化训练API<a class="headerlink" href="#api" title="Permanent link">#</a></h2>
<h3 id="_1">量化配置<a class="headerlink" href="#_1" title="Permanent link">#</a></h3>
<div class="codehilite"><pre><span></span><span class="nv">quant_config_default</span> <span class="o">=</span> {
<p><table class="codehilitetable"><tr><td class="linenos"><div class="linenodiv"><pre><span></span> 1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20</pre></div></td><td class="code"><div class="codehilite"><pre><span></span><span class="nv">quant_config_default</span> <span class="o">=</span> {
<span class="s1">&#39;</span><span class="s">weight_quantize_type</span><span class="s1">&#39;</span>: <span class="s1">&#39;</span><span class="s">abs_max</span><span class="s1">&#39;</span>,
<span class="s1">&#39;</span><span class="s">activation_quantize_type</span><span class="s1">&#39;</span>: <span class="s1">&#39;</span><span class="s">abs_max</span><span class="s1">&#39;</span>,
<span class="s1">&#39;</span><span class="s">weight_bits</span><span class="s1">&#39;</span>: <span class="mi">8</span>,
......@@ -186,9 +205,8 @@
<span class="s1">&#39;</span><span class="s">quant_weight_only</span><span class="s1">&#39;</span>: <span class="nv">False</span>
}
</pre></div>
<p>设置量化训练需要的配置。</p>
</td></tr></table>
设置量化训练需要的配置。</p>
<p><strong>参数:</strong></p>
<ul>
<li><strong>weight_quantize_type(str)</strong> - 参数量化方式。可选<code>'abs_max'</code>, <code>'channel_wise_abs_max'</code>, <code>'range_abs_max'</code>, <code>'moving_average_abs_max'</code>。 默认<code>'abs_max'</code></li>
......@@ -242,7 +260,43 @@
<p><strong>注意事项</strong></p>
<p>因为该接口会对<code>op</code><code>Variable</code>做相应的删除和修改,所以此接口只能在训练完成之后调用。如果想转化训练的中间模型,可加载相应的参数之后再使用此接口。</p>
<p><strong>代码示例</strong></p>
<div class="codehilite"><pre><span></span><span class="c1">#encoding=utf8</span>
<table class="codehilitetable"><tr><td class="linenos"><div class="linenodiv"><pre><span></span> 1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37</pre></div></td><td class="code"><div class="codehilite"><pre><span></span><span class="c1">#encoding=utf8</span>
<span class="kn">import</span> <span class="nn">paddle.fluid</span> <span class="kn">as</span> <span class="nn">fluid</span>
<span class="kn">import</span> <span class="nn">paddleslim.quant</span> <span class="kn">as</span> <span class="nn">quant</span>
......@@ -280,11 +334,21 @@
<span class="n">inference_prog</span> <span class="o">=</span> <span class="n">quant</span><span class="o">.</span><span class="n">convert</span><span class="p">(</span><span class="n">quant_eval_program</span><span class="p">,</span> <span class="n">place</span><span class="p">,</span> <span class="n">config</span><span class="p">)</span>
</pre></div>
</td></tr></table>
<p>更详细的用法请参考 <a href='../../demo/quant/quant_aware/README.md'>量化训练demo</a></p>
<h2 id="api_1">离线量化API<a class="headerlink" href="#api_1" title="Permanent link">#</a></h2>
<div class="codehilite"><pre><span></span><span class="n">paddleslim</span><span class="p">.</span><span class="n">quant</span><span class="p">.</span><span class="n">quant_post</span><span class="p">(</span><span class="n">executor</span><span class="p">,</span>
<p><table class="codehilitetable"><tr><td class="linenos"><div class="linenodiv"><pre><span></span> 1
2
3
4
5
6
7
8
9
10
11</pre></div></td><td class="code"><div class="codehilite"><pre><span></span><span class="n">paddleslim</span><span class="p">.</span><span class="n">quant</span><span class="p">.</span><span class="n">quant_post</span><span class="p">(</span><span class="n">executor</span><span class="p">,</span>
<span class="n">model_dir</span><span class="p">,</span>
<span class="n">quantize_model_path</span><span class="p">,</span>
<span class="n">sample_generator</span><span class="p">,</span>
......@@ -296,9 +360,8 @@
<span class="n">algo</span><span class="o">=</span><span class="s1">&#39;KL&#39;</span><span class="p">,</span>
<span class="n">quantizable_op_type</span><span class="o">=</span><span class="p">[</span><span class="ss">&quot;conv2d&quot;</span><span class="p">,</span> <span class="ss">&quot;depthwise_conv2d&quot;</span><span class="p">,</span> <span class="ss">&quot;mul&quot;</span><span class="p">])</span>
</pre></div>
<p>对保存在<code>${model_dir}</code>下的模型进行量化,使用<code>sample_generator</code>的数据进行参数校正。</p>
</td></tr></table>
对保存在<code>${model_dir}</code>下的模型进行量化,使用<code>sample_generator</code>的数据进行参数校正。</p>
<p><strong>参数:</strong>
- <strong>executor (fluid.Executor)</strong> - 执行模型的executor,可以在cpu或者gpu上执行。
- <strong>model_dir(str)</strong> - 需要量化的模型所在的文件夹。
......@@ -319,7 +382,23 @@
<blockquote>
<p>注: 此示例不能直接运行,因为需要加载<code>${model_dir}</code>下的模型,所以不能直接运行。</p>
</blockquote>
<div class="codehilite"><pre><span></span><span class="kn">import</span> <span class="nn">paddle.fluid</span> <span class="kn">as</span> <span class="nn">fluid</span>
<p><table class="codehilitetable"><tr><td class="linenos"><div class="linenodiv"><pre><span></span> 1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17</pre></div></td><td class="code"><div class="codehilite"><pre><span></span><span class="kn">import</span> <span class="nn">paddle.fluid</span> <span class="kn">as</span> <span class="nn">fluid</span>
<span class="kn">import</span> <span class="nn">paddle.dataset.mnist</span> <span class="kn">as</span> <span class="nn">reader</span>
<span class="kn">from</span> <span class="nn">paddleslim.quant</span> <span class="kn">import</span> <span class="n">quant_post</span>
<span class="n">val_reader</span> <span class="o">=</span> <span class="n">reader</span><span class="o">.</span><span class="n">train</span><span class="p">()</span>
......@@ -337,15 +416,13 @@
<span class="n">batch_size</span><span class="o">=</span><span class="mi">16</span><span class="p">,</span>
<span class="n">batch_nums</span><span class="o">=</span><span class="mi">10</span><span class="p">)</span>
</pre></div>
<p>更详细的用法请参考 <a href='../../demo/quant/quant_post/README.md'>离线量化demo</a></p>
</td></tr></table>
更详细的用法请参考 <a href='../../demo/quant/quant_post/README.md'>离线量化demo</a></p>
<h2 id="embeddingapi">Embedding量化API<a class="headerlink" href="#embeddingapi" title="Permanent link">#</a></h2>
<div class="codehilite"><pre><span></span><span class="n">paddleslim</span><span class="p">.</span><span class="n">quant</span><span class="p">.</span><span class="n">quant_embedding</span><span class="p">(</span><span class="n">program</span><span class="p">,</span> <span class="n">place</span><span class="p">,</span> <span class="n">config</span><span class="p">,</span> <span class="k">scope</span><span class="o">=</span><span class="k">None</span><span class="p">)</span>
<p><table class="codehilitetable"><tr><td class="linenos"><div class="linenodiv"><pre><span></span>1</pre></div></td><td class="code"><div class="codehilite"><pre><span></span><span class="n">paddleslim</span><span class="p">.</span><span class="n">quant</span><span class="p">.</span><span class="n">quant_embedding</span><span class="p">(</span><span class="n">program</span><span class="p">,</span> <span class="n">place</span><span class="p">,</span> <span class="n">config</span><span class="p">,</span> <span class="k">scope</span><span class="o">=</span><span class="k">None</span><span class="p">)</span>
</pre></div>
<p><code>Embedding</code>参数进行量化。</p>
</td></tr></table>
<code>Embedding</code>参数进行量化。</p>
<p><strong>参数:</strong>
- <strong>program(fluid.Program)</strong> - 需要量化的program
- <strong>scope(fluid.Scope, optional)</strong> - 用来获取和写入<code>Variable</code>, 如果设置为<code>None</code>,则使用<code>fluid.global_scope()</code>.
......@@ -360,8 +437,29 @@
<p>量化之后的program</p>
<p><strong>返回类型</strong></p>
<p><code>fluid.Program</code></p>
<p><strong>代码示例</strong></p>
<div class="codehilite"><pre><span></span><span class="kn">import</span> <span class="nn">paddle.fluid</span> <span class="kn">as</span> <span class="nn">fluid</span>
<p><strong>代码示例</strong>
<table class="codehilitetable"><tr><td class="linenos"><div class="linenodiv"><pre><span></span> 1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22</pre></div></td><td class="code"><div class="codehilite"><pre><span></span><span class="kn">import</span> <span class="nn">paddle.fluid</span> <span class="kn">as</span> <span class="nn">fluid</span>
<span class="kn">import</span> <span class="nn">paddleslim.quant</span> <span class="kn">as</span> <span class="nn">quant</span>
<span class="n">train_program</span> <span class="o">=</span> <span class="n">fluid</span><span class="o">.</span><span class="n">Program</span><span class="p">()</span>
......@@ -384,8 +482,7 @@
<span class="n">config</span> <span class="o">=</span> <span class="p">{</span><span class="s1">&#39;params_name&#39;</span><span class="p">:</span> <span class="s1">&#39;emb&#39;</span><span class="p">,</span> <span class="s1">&#39;quantize_type&#39;</span><span class="p">:</span> <span class="s1">&#39;abs_max&#39;</span><span class="p">}</span>
<span class="n">quant_program</span> <span class="o">=</span> <span class="n">quant</span><span class="o">.</span><span class="n">quant_embedding</span><span class="p">(</span><span class="n">infer_program</span><span class="p">,</span> <span class="n">place</span><span class="p">,</span> <span class="n">config</span><span class="p">)</span>
</pre></div>
</td></tr></table></p>
<p>更详细的用法请参考 <a href='../../demo/quant/quant_embedding/README.md'>Embedding量化demo</a></p>
</div>
......
......@@ -199,7 +199,67 @@
2. token中每个数字的搜索列表长度(<code>range_table</code>函数),tokens中每个token的索引范围。
3. 根据token产生模型结构(<code>token2arch</code>函数),根据搜索到的tokens列表产生模型结构。</p>
<p>以新增reset block为例说明如何构造自己的search space。自定义的search space不能和已有的search space同名。</p>
<div class="codehilite"><pre><span></span><span class="c1">### 引入搜索空间基类函数和search space的注册类函数</span>
<table class="codehilitetable"><tr><td class="linenos"><div class="linenodiv"><pre><span></span> 1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61</pre></div></td><td class="code"><div class="codehilite"><pre><span></span><span class="c1">### 引入搜索空间基类函数和search space的注册类函数</span>
<span class="kn">from</span> <span class="nn">.search_space_base</span> <span class="kn">import</span> <span class="n">SearchSpaceBase</span>
<span class="kn">from</span> <span class="nn">.search_space_registry</span> <span class="kn">import</span> <span class="n">SEARCHSPACE</span>
<span class="kn">import</span> <span class="nn">numpy</span> <span class="kn">as</span> <span class="nn">np</span>
......@@ -261,6 +321,7 @@
<span class="n">bn</span> <span class="o">=</span> <span class="n">fluid</span><span class="o">.</span><span class="n">layers</span><span class="o">.</span><span class="n">batch_norm</span><span class="p">(</span><span class="n">conv</span><span class="p">,</span> <span class="n">act</span><span class="o">=</span><span class="n">act</span><span class="p">,</span> <span class="n">name</span><span class="o">=</span><span class="n">name</span><span class="o">+</span><span class="s1">&#39;_bn&#39;</span><span class="p">)</span>
<span class="k">return</span> <span class="n">bn</span>
</pre></div>
</td></tr></table>
</div>
</div>
......
此差异已折叠。
......@@ -197,18 +197,20 @@
<ul>
<li>安装develop版本</li>
</ul>
<div class="codehilite"><pre><span></span><span class="n">git</span> <span class="n">clone</span> <span class="n">https</span><span class="p">:</span><span class="o">//</span><span class="n">github</span><span class="p">.</span><span class="n">com</span><span class="o">/</span><span class="n">PaddlePaddle</span><span class="o">/</span><span class="n">PaddleSlim</span><span class="p">.</span><span class="n">git</span>
<table class="codehilitetable"><tr><td class="linenos"><div class="linenodiv"><pre><span></span>1
2
3</pre></div></td><td class="code"><div class="codehilite"><pre><span></span><span class="n">git</span> <span class="n">clone</span> <span class="n">https</span><span class="p">:</span><span class="o">//</span><span class="n">github</span><span class="p">.</span><span class="n">com</span><span class="o">/</span><span class="n">PaddlePaddle</span><span class="o">/</span><span class="n">PaddleSlim</span><span class="p">.</span><span class="n">git</span>
<span class="n">cd</span> <span class="n">PaddleSlim</span>
<span class="n">python</span> <span class="n">setup</span><span class="p">.</span><span class="n">py</span> <span class="n">install</span>
</pre></div>
</td></tr></table>
<ul>
<li>安装官方发布的最新版本</li>
</ul>
<div class="codehilite"><pre><span></span><span class="n">pip</span> <span class="n">install</span> <span class="n">paddleslim</span> <span class="o">-</span><span class="n">i</span> <span class="n">https</span><span class="p">:</span><span class="o">//</span><span class="n">pypi</span><span class="p">.</span><span class="n">org</span><span class="o">/</span><span class="k">simple</span>
<table class="codehilitetable"><tr><td class="linenos"><div class="linenodiv"><pre><span></span>1</pre></div></td><td class="code"><div class="codehilite"><pre><span></span><span class="n">pip</span> <span class="n">install</span> <span class="n">paddleslim</span> <span class="o">-</span><span class="n">i</span> <span class="n">https</span><span class="p">:</span><span class="o">//</span><span class="n">pypi</span><span class="p">.</span><span class="n">org</span><span class="o">/</span><span class="k">simple</span>
</pre></div>
</td></tr></table>
<ul>
<li>安装历史版本</li>
......@@ -277,5 +279,5 @@
<!--
MkDocs version : 1.0.4
Build Date UTC : 2019-12-20 02:14:25
Build Date UTC : 2019-12-20 03:48:22
-->
因为 它太大了无法显示 source diff 。你可以改为 查看blob
无法预览此类型文件
......@@ -168,9 +168,9 @@
<p>操作信息字段之间以逗号分割。操作信息与延迟信息之间以制表符分割。</p>
<h3 id="conv2d">conv2d<a class="headerlink" href="#conv2d" title="Permanent link">#</a></h3>
<p><strong>格式</strong></p>
<div class="codehilite"><pre><span></span><span class="n">op_type</span><span class="p">,</span><span class="n">flag_bias</span><span class="p">,</span><span class="n">flag_relu</span><span class="p">,</span><span class="n">n_in</span><span class="p">,</span><span class="n">c_in</span><span class="p">,</span><span class="n">h_in</span><span class="p">,</span><span class="n">w_in</span><span class="p">,</span><span class="n">c_out</span><span class="p">,</span><span class="n">groups</span><span class="p">,</span><span class="n">kernel</span><span class="p">,</span><span class="n">padding</span><span class="p">,</span><span class="n">stride</span><span class="p">,</span><span class="n">dilation</span><span class="err">\</span><span class="n">tlatency</span>
<table class="codehilitetable"><tr><td class="linenos"><div class="linenodiv"><pre><span></span>1</pre></div></td><td class="code"><div class="codehilite"><pre><span></span><span class="n">op_type</span><span class="p">,</span><span class="n">flag_bias</span><span class="p">,</span><span class="n">flag_relu</span><span class="p">,</span><span class="n">n_in</span><span class="p">,</span><span class="n">c_in</span><span class="p">,</span><span class="n">h_in</span><span class="p">,</span><span class="n">w_in</span><span class="p">,</span><span class="n">c_out</span><span class="p">,</span><span class="n">groups</span><span class="p">,</span><span class="n">kernel</span><span class="p">,</span><span class="n">padding</span><span class="p">,</span><span class="n">stride</span><span class="p">,</span><span class="n">dilation</span><span class="err">\</span><span class="n">tlatency</span>
</pre></div>
</td></tr></table>
<p><strong>字段解释</strong></p>
<ul>
......@@ -191,9 +191,9 @@
</ul>
<h3 id="activation">activation<a class="headerlink" href="#activation" title="Permanent link">#</a></h3>
<p><strong>格式</strong></p>
<div class="codehilite"><pre><span></span><span class="n">op_type</span><span class="p">,</span><span class="n">n_in</span><span class="p">,</span><span class="n">c_in</span><span class="p">,</span><span class="n">h_in</span><span class="p">,</span><span class="n">w_in</span><span class="err">\</span><span class="n">tlatency</span>
<table class="codehilitetable"><tr><td class="linenos"><div class="linenodiv"><pre><span></span>1</pre></div></td><td class="code"><div class="codehilite"><pre><span></span><span class="n">op_type</span><span class="p">,</span><span class="n">n_in</span><span class="p">,</span><span class="n">c_in</span><span class="p">,</span><span class="n">h_in</span><span class="p">,</span><span class="n">w_in</span><span class="err">\</span><span class="n">tlatency</span>
</pre></div>
</td></tr></table>
<p><strong>字段解释</strong></p>
<ul>
......@@ -206,9 +206,9 @@
</ul>
<h3 id="batch_norm">batch_norm<a class="headerlink" href="#batch_norm" title="Permanent link">#</a></h3>
<p><strong>格式</strong></p>
<div class="codehilite"><pre><span></span><span class="n">op_type</span><span class="p">,</span><span class="n">active_type</span><span class="p">,</span><span class="n">n_in</span><span class="p">,</span><span class="n">c_in</span><span class="p">,</span><span class="n">h_in</span><span class="p">,</span><span class="n">w_in</span><span class="err">\</span><span class="n">tlatency</span>
<table class="codehilitetable"><tr><td class="linenos"><div class="linenodiv"><pre><span></span>1</pre></div></td><td class="code"><div class="codehilite"><pre><span></span><span class="n">op_type</span><span class="p">,</span><span class="n">active_type</span><span class="p">,</span><span class="n">n_in</span><span class="p">,</span><span class="n">c_in</span><span class="p">,</span><span class="n">h_in</span><span class="p">,</span><span class="n">w_in</span><span class="err">\</span><span class="n">tlatency</span>
</pre></div>
</td></tr></table>
<p><strong>字段解释</strong></p>
<ul>
......@@ -222,9 +222,9 @@
</ul>
<h3 id="eltwise">eltwise<a class="headerlink" href="#eltwise" title="Permanent link">#</a></h3>
<p><strong>格式</strong></p>
<div class="codehilite"><pre><span></span><span class="n">op_type</span><span class="p">,</span><span class="n">n_in</span><span class="p">,</span><span class="n">c_in</span><span class="p">,</span><span class="n">h_in</span><span class="p">,</span><span class="n">w_in</span><span class="err">\</span><span class="n">tlatency</span>
<table class="codehilitetable"><tr><td class="linenos"><div class="linenodiv"><pre><span></span>1</pre></div></td><td class="code"><div class="codehilite"><pre><span></span><span class="n">op_type</span><span class="p">,</span><span class="n">n_in</span><span class="p">,</span><span class="n">c_in</span><span class="p">,</span><span class="n">h_in</span><span class="p">,</span><span class="n">w_in</span><span class="err">\</span><span class="n">tlatency</span>
</pre></div>
</td></tr></table>
<p><strong>字段解释</strong></p>
<ul>
......@@ -237,9 +237,9 @@
</ul>
<h3 id="pooling">pooling<a class="headerlink" href="#pooling" title="Permanent link">#</a></h3>
<p><strong>格式</strong></p>
<div class="codehilite"><pre><span></span><span class="n">op_type</span><span class="p">,</span><span class="n">flag_global_pooling</span><span class="p">,</span><span class="n">n_in</span><span class="p">,</span><span class="n">c_in</span><span class="p">,</span><span class="n">h_in</span><span class="p">,</span><span class="n">w_in</span><span class="p">,</span><span class="n">kernel</span><span class="p">,</span><span class="n">padding</span><span class="p">,</span><span class="n">stride</span><span class="p">,</span><span class="n">ceil_mode</span><span class="p">,</span><span class="n">pool_type</span><span class="err">\</span><span class="n">tlatency</span>
<table class="codehilitetable"><tr><td class="linenos"><div class="linenodiv"><pre><span></span>1</pre></div></td><td class="code"><div class="codehilite"><pre><span></span><span class="n">op_type</span><span class="p">,</span><span class="n">flag_global_pooling</span><span class="p">,</span><span class="n">n_in</span><span class="p">,</span><span class="n">c_in</span><span class="p">,</span><span class="n">h_in</span><span class="p">,</span><span class="n">w_in</span><span class="p">,</span><span class="n">kernel</span><span class="p">,</span><span class="n">padding</span><span class="p">,</span><span class="n">stride</span><span class="p">,</span><span class="n">ceil_mode</span><span class="p">,</span><span class="n">pool_type</span><span class="err">\</span><span class="n">tlatency</span>
</pre></div>
</td></tr></table>
<p><strong>字段解释</strong></p>
<ul>
......@@ -258,9 +258,9 @@
</ul>
<h3 id="softmax">softmax<a class="headerlink" href="#softmax" title="Permanent link">#</a></h3>
<p><strong>格式</strong></p>
<div class="codehilite"><pre><span></span><span class="n">op_type</span><span class="p">,</span><span class="n">axis</span><span class="p">,</span><span class="n">n_in</span><span class="p">,</span><span class="n">c_in</span><span class="p">,</span><span class="n">h_in</span><span class="p">,</span><span class="n">w_in</span><span class="err">\</span><span class="n">tlatency</span>
<table class="codehilitetable"><tr><td class="linenos"><div class="linenodiv"><pre><span></span>1</pre></div></td><td class="code"><div class="codehilite"><pre><span></span><span class="n">op_type</span><span class="p">,</span><span class="n">axis</span><span class="p">,</span><span class="n">n_in</span><span class="p">,</span><span class="n">c_in</span><span class="p">,</span><span class="n">h_in</span><span class="p">,</span><span class="n">w_in</span><span class="err">\</span><span class="n">tlatency</span>
</pre></div>
</td></tr></table>
<p><strong>字段解释</strong></p>
<ul>
......
......@@ -163,13 +163,20 @@
<h2 id="_2">接口介绍<a class="headerlink" href="#_2" title="Permanent link">#</a></h2>
<p>请参考。</p>
<h3 id="1">1. 配置搜索空间<a class="headerlink" href="#1" title="Permanent link">#</a></h3>
<p>详细的搜索空间配置可以参考<a href='../../../paddleslim/nas/nas_api.md'>神经网络搜索API文档</a></p>
<div class="codehilite"><pre><span></span><span class="n">config</span> <span class="o">=</span> <span class="p">[(</span><span class="s1">&#39;MobileNetV2Space&#39;</span><span class="p">)]</span>
<p>详细的搜索空间配置可以参考<a href='../../../paddleslim/nas/nas_api.md'>神经网络搜索API文档</a>
<table class="codehilitetable"><tr><td class="linenos"><div class="linenodiv"><pre><span></span>1</pre></div></td><td class="code"><div class="codehilite"><pre><span></span><span class="n">config</span> <span class="o">=</span> <span class="p">[(</span><span class="s1">&#39;MobileNetV2Space&#39;</span><span class="p">)]</span>
</pre></div>
</td></tr></table></p>
<h3 id="2-sanas">2. 利用搜索空间初始化SANAS实例<a class="headerlink" href="#2-sanas" title="Permanent link">#</a></h3>
<div class="codehilite"><pre><span></span><span class="kn">from</span> <span class="nn">paddleslim.nas</span> <span class="kn">import</span> <span class="n">SANAS</span>
<table class="codehilitetable"><tr><td class="linenos"><div class="linenodiv"><pre><span></span>1
2
3
4
5
6
7
8
9</pre></div></td><td class="code"><div class="codehilite"><pre><span></span><span class="kn">from</span> <span class="nn">paddleslim.nas</span> <span class="kn">import</span> <span class="n">SANAS</span>
<span class="n">sa_nas</span> <span class="o">=</span> <span class="n">SANAS</span><span class="p">(</span>
<span class="n">config</span><span class="p">,</span>
......@@ -179,15 +186,34 @@
<span class="n">search_steps</span><span class="o">=</span><span class="mi">300</span><span class="p">,</span>
<span class="n">is_server</span><span class="o">=</span><span class="bp">True</span><span class="p">)</span>
</pre></div>
</td></tr></table>
<h3 id="3-nas">3. 根据实例化的NAS得到当前的网络结构<a class="headerlink" href="#3-nas" title="Permanent link">#</a></h3>
<div class="codehilite"><pre><span></span><span class="n">archs</span> <span class="o">=</span> <span class="n">sa_nas</span><span class="p">.</span><span class="n">next_archs</span><span class="p">()</span>
<table class="codehilitetable"><tr><td class="linenos"><div class="linenodiv"><pre><span></span>1</pre></div></td><td class="code"><div class="codehilite"><pre><span></span><span class="n">archs</span> <span class="o">=</span> <span class="n">sa_nas</span><span class="p">.</span><span class="n">next_archs</span><span class="p">()</span>
</pre></div>
</td></tr></table>
<h3 id="4-program">4. 根据得到的网络结构和输入构造训练和测试program<a class="headerlink" href="#4-program" title="Permanent link">#</a></h3>
<div class="codehilite"><pre><span></span><span class="kn">import</span> <span class="nn">paddle.fluid</span> <span class="kn">as</span> <span class="nn">fluid</span>
<table class="codehilitetable"><tr><td class="linenos"><div class="linenodiv"><pre><span></span> 1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20</pre></div></td><td class="code"><div class="codehilite"><pre><span></span><span class="kn">import</span> <span class="nn">paddle.fluid</span> <span class="kn">as</span> <span class="nn">fluid</span>
<span class="n">train_program</span> <span class="o">=</span> <span class="n">fluid</span><span class="o">.</span><span class="n">Program</span><span class="p">()</span>
<span class="n">test_program</span> <span class="o">=</span> <span class="n">fluid</span><span class="o">.</span><span class="n">Program</span><span class="p">()</span>
......@@ -208,19 +234,23 @@
<span class="n">sgd</span> <span class="o">=</span> <span class="n">fluid</span><span class="o">.</span><span class="n">optimizer</span><span class="o">.</span><span class="n">SGD</span><span class="p">(</span><span class="n">learning_rate</span><span class="o">=</span><span class="mf">1e-3</span><span class="p">)</span>
<span class="n">sgd</span><span class="o">.</span><span class="n">minimize</span><span class="p">(</span><span class="n">avg_cost</span><span class="p">)</span>
</pre></div>
</td></tr></table>
<h3 id="5-program">5. 根据构造的训练program添加限制条件<a class="headerlink" href="#5-program" title="Permanent link">#</a></h3>
<div class="codehilite"><pre><span></span><span class="kn">from</span> <span class="nn">paddleslim.analysis</span> <span class="kn">import</span> <span class="n">flops</span>
<table class="codehilitetable"><tr><td class="linenos"><div class="linenodiv"><pre><span></span>1
2
3
4</pre></div></td><td class="code"><div class="codehilite"><pre><span></span><span class="kn">from</span> <span class="nn">paddleslim.analysis</span> <span class="kn">import</span> <span class="n">flops</span>
<span class="k">if</span> <span class="n">flops</span><span class="p">(</span><span class="n">train_program</span><span class="p">)</span> <span class="o">&gt;</span> <span class="mi">321208544</span><span class="p">:</span>
<span class="k">continue</span>
</pre></div>
</td></tr></table>
<h3 id="6-score">6. 回传score<a class="headerlink" href="#6-score" title="Permanent link">#</a></h3>
<div class="codehilite"><pre><span></span><span class="n">sa_nas</span><span class="p">.</span><span class="n">reward</span><span class="p">(</span><span class="n">score</span><span class="p">)</span>
<table class="codehilitetable"><tr><td class="linenos"><div class="linenodiv"><pre><span></span>1</pre></div></td><td class="code"><div class="codehilite"><pre><span></span><span class="n">sa_nas</span><span class="p">.</span><span class="n">reward</span><span class="p">(</span><span class="n">score</span><span class="p">)</span>
</pre></div>
</td></tr></table>
</div>
</div>
......
......@@ -166,7 +166,18 @@
<p>请参考 <a href='../../../paddleslim/quant/quantization_api_doc.md'>量化API文档</a></p>
<h2 id="_3">分类模型的离线量化流程<a class="headerlink" href="#_3" title="Permanent link">#</a></h2>
<h3 id="1">1. 配置量化参数<a class="headerlink" href="#1" title="Permanent link">#</a></h3>
<div class="codehilite"><pre><span></span><span class="n">quant_config</span> <span class="o">=</span> <span class="err">{</span>
<table class="codehilitetable"><tr><td class="linenos"><div class="linenodiv"><pre><span></span> 1
2
3
4
5
6
7
8
9
10
11
12</pre></div></td><td class="code"><div class="codehilite"><pre><span></span><span class="n">quant_config</span> <span class="o">=</span> <span class="err">{</span>
<span class="s1">&#39;weight_quantize_type&#39;</span><span class="p">:</span> <span class="s1">&#39;abs_max&#39;</span><span class="p">,</span>
<span class="s1">&#39;activation_quantize_type&#39;</span><span class="p">:</span> <span class="s1">&#39;moving_average_abs_max&#39;</span><span class="p">,</span>
<span class="s1">&#39;weight_bits&#39;</span><span class="p">:</span> <span class="mi">8</span><span class="p">,</span>
......@@ -179,17 +190,26 @@
<span class="s1">&#39;quant_weight_only&#39;</span><span class="p">:</span> <span class="k">False</span>
<span class="err">}</span>
</pre></div>
</td></tr></table>
<h3 id="2-programop">2. 对训练和测试program插入可训练量化op<a class="headerlink" href="#2-programop" title="Permanent link">#</a></h3>
<div class="codehilite"><pre><span></span><span class="n">val_program</span> <span class="o">=</span> <span class="n">quant_aware</span><span class="p">(</span><span class="n">val_program</span><span class="p">,</span> <span class="n">place</span><span class="p">,</span> <span class="n">quant_config</span><span class="p">,</span> <span class="k">scope</span><span class="o">=</span><span class="k">None</span><span class="p">,</span> <span class="n">for_test</span><span class="o">=</span><span class="k">True</span><span class="p">)</span>
<table class="codehilitetable"><tr><td class="linenos"><div class="linenodiv"><pre><span></span>1
2
3</pre></div></td><td class="code"><div class="codehilite"><pre><span></span><span class="n">val_program</span> <span class="o">=</span> <span class="n">quant_aware</span><span class="p">(</span><span class="n">val_program</span><span class="p">,</span> <span class="n">place</span><span class="p">,</span> <span class="n">quant_config</span><span class="p">,</span> <span class="k">scope</span><span class="o">=</span><span class="k">None</span><span class="p">,</span> <span class="n">for_test</span><span class="o">=</span><span class="k">True</span><span class="p">)</span>
<span class="n">compiled_train_prog</span> <span class="o">=</span> <span class="n">quant_aware</span><span class="p">(</span><span class="n">train_prog</span><span class="p">,</span> <span class="n">place</span><span class="p">,</span> <span class="n">quant_config</span><span class="p">,</span> <span class="k">scope</span><span class="o">=</span><span class="k">None</span><span class="p">,</span> <span class="n">for_test</span><span class="o">=</span><span class="k">False</span><span class="p">)</span>
</pre></div>
</td></tr></table>
<h3 id="3build">3.关掉指定build策略<a class="headerlink" href="#3build" title="Permanent link">#</a></h3>
<div class="codehilite"><pre><span></span><span class="n">build_strategy</span> <span class="o">=</span> <span class="n">fluid</span><span class="p">.</span><span class="n">BuildStrategy</span><span class="p">()</span>
<table class="codehilitetable"><tr><td class="linenos"><div class="linenodiv"><pre><span></span>1
2
3
4
5
6
7
8</pre></div></td><td class="code"><div class="codehilite"><pre><span></span><span class="n">build_strategy</span> <span class="o">=</span> <span class="n">fluid</span><span class="p">.</span><span class="n">BuildStrategy</span><span class="p">()</span>
<span class="n">build_strategy</span><span class="p">.</span><span class="n">fuse_all_reduce_ops</span> <span class="o">=</span> <span class="k">False</span>
<span class="n">build_strategy</span><span class="p">.</span><span class="n">sync_batch_norm</span> <span class="o">=</span> <span class="k">False</span>
<span class="n">exec_strategy</span> <span class="o">=</span> <span class="n">fluid</span><span class="p">.</span><span class="n">ExecutionStrategy</span><span class="p">()</span>
......@@ -198,19 +218,37 @@
<span class="n">build_strategy</span><span class="o">=</span><span class="n">build_strategy</span><span class="p">,</span>
<span class="n">exec_strategy</span><span class="o">=</span><span class="n">exec_strategy</span><span class="p">)</span>
</pre></div>
</td></tr></table>
<h3 id="4-freeze-program">4. freeze program<a class="headerlink" href="#4-freeze-program" title="Permanent link">#</a></h3>
<div class="codehilite"><pre><span></span><span class="n">float_program</span><span class="p">,</span> <span class="n">int8_program</span> <span class="o">=</span> <span class="k">convert</span><span class="p">(</span><span class="n">val_program</span><span class="p">,</span>
<table class="codehilitetable"><tr><td class="linenos"><div class="linenodiv"><pre><span></span>1
2
3
4
5</pre></div></td><td class="code"><div class="codehilite"><pre><span></span><span class="n">float_program</span><span class="p">,</span> <span class="n">int8_program</span> <span class="o">=</span> <span class="k">convert</span><span class="p">(</span><span class="n">val_program</span><span class="p">,</span>
<span class="n">place</span><span class="p">,</span>
<span class="n">quant_config</span><span class="p">,</span>
<span class="k">scope</span><span class="o">=</span><span class="k">None</span><span class="p">,</span>
<span class="n">save_int8</span><span class="o">=</span><span class="k">True</span><span class="p">)</span>
</pre></div>
</td></tr></table>
<h3 id="5">5.保存预测模型<a class="headerlink" href="#5" title="Permanent link">#</a></h3>
<div class="codehilite"><pre><span></span><span class="nv">fluid</span>.<span class="nv">io</span>.<span class="nv">save_inference_model</span><span class="ss">(</span>
<table class="codehilitetable"><tr><td class="linenos"><div class="linenodiv"><pre><span></span> 1
2
3
4
5
6
7
8
9
10
11
12
13
14
15</pre></div></td><td class="code"><div class="codehilite"><pre><span></span><span class="nv">fluid</span>.<span class="nv">io</span>.<span class="nv">save_inference_model</span><span class="ss">(</span>
<span class="k">dirname</span><span class="o">=</span><span class="nv">float_path</span>,
<span class="nv">feeded_var_names</span><span class="o">=</span>[<span class="nv">image</span>.<span class="nv">name</span>],
<span class="nv">target_vars</span><span class="o">=</span>[<span class="nv">out</span>], <span class="nv">executor</span><span class="o">=</span><span class="nv">exe</span>,
......@@ -226,6 +264,7 @@
<span class="nv">model_filename</span><span class="o">=</span><span class="nv">int8_path</span> <span class="o">+</span> <span class="s1">&#39;</span><span class="s">/model</span><span class="s1">&#39;</span>,
<span class="nv">params_filename</span><span class="o">=</span><span class="nv">int8_path</span> <span class="o">+</span> <span class="s1">&#39;</span><span class="s">/params</span><span class="s1">&#39;</span><span class="ss">)</span>
</pre></div>
</td></tr></table>
</div>
</div>
......
......@@ -179,7 +179,16 @@
<p>以下将以 <code>基于skip-gram的word2vector模型</code> 为例来说明如何使用<code>quant_embedding</code>接口。首先介绍 <code>基于skip-gram的word2vector模型</code> 的正常训练和测试流程。</p>
<h2 id="skip-gramword2vector">基于skip-gram的word2vector模型<a class="headerlink" href="#skip-gramword2vector" title="Permanent link">#</a></h2>
<p>以下是本例的简要目录结构及说明:</p>
<div class="codehilite"><pre><span></span>.
<table class="codehilitetable"><tr><td class="linenos"><div class="linenodiv"><pre><span></span> 1
2
3
4
5
6
7
8
9
10</pre></div></td><td class="code"><div class="codehilite"><pre><span></span>.
├── cluster_train.py # 分布式训练函数
├── cluster_train.sh # 本地模拟多机脚本
├── train.py # 训练函数
......@@ -190,41 +199,59 @@
├── train.py # 训练函数
└── utils.py # 通用函数
</pre></div>
</td></tr></table>
<h3 id="_1">介绍<a class="headerlink" href="#_1" title="Permanent link">#</a></h3>
<p>本例实现了skip-gram模式的word2vector模型。</p>
<p>同时推荐用户参考<a href="https://aistudio.baidu.com/aistudio/projectDetail/124377"> IPython Notebook demo</a></p>
<h3 id="_2">数据下载<a class="headerlink" href="#_2" title="Permanent link">#</a></h3>
<p>全量数据集使用的是来自1 Billion Word Language Model Benchmark的(http://www.statmt.org/lm-benchmark) 的数据集.</p>
<div class="codehilite"><pre><span></span>mkdir data
<p>全量数据集使用的是来自1 Billion Word Language Model Benchmark的(<a href="http://www.statmt.org/lm-benchmark">http://www.statmt.org/lm-benchmark</a>) 的数据集.</p>
<table class="codehilitetable"><tr><td class="linenos"><div class="linenodiv"><pre><span></span>1
2
3
4</pre></div></td><td class="code"><div class="codehilite"><pre><span></span>mkdir data
wget http://www.statmt.org/lm-benchmark/1-billion-word-language-modeling-benchmark-r13output.tar.gz
tar xzvf <span class="m">1</span>-billion-word-language-modeling-benchmark-r13output.tar.gz
mv <span class="m">1</span>-billion-word-language-modeling-benchmark-r13output/training-monolingual.tokenized.shuffled/ data/
</pre></div>
</td></tr></table>
<p>备用数据地址下载命令如下</p>
<div class="codehilite"><pre><span></span>mkdir data
<table class="codehilitetable"><tr><td class="linenos"><div class="linenodiv"><pre><span></span>1
2
3
4</pre></div></td><td class="code"><div class="codehilite"><pre><span></span>mkdir data
wget https://paddlerec.bj.bcebos.com/word2vec/1-billion-word-language-modeling-benchmark-r13output.tar
tar xvf <span class="m">1</span>-billion-word-language-modeling-benchmark-r13output.tar
mv <span class="m">1</span>-billion-word-language-modeling-benchmark-r13output/training-monolingual.tokenized.shuffled/ data/
</pre></div>
</td></tr></table>
<p>为了方便快速验证,我们也提供了经典的text8样例数据集,包含1700w个词。 下载命令如下</p>
<div class="codehilite"><pre><span></span>mkdir data
<table class="codehilitetable"><tr><td class="linenos"><div class="linenodiv"><pre><span></span>1
2
3
4</pre></div></td><td class="code"><div class="codehilite"><pre><span></span>mkdir data
wget https://paddlerec.bj.bcebos.com/word2vec/text.tar
tar xvf text.tar
mv text data/
</pre></div>
</td></tr></table>
<h3 id="_3">数据预处理<a class="headerlink" href="#_3" title="Permanent link">#</a></h3>
<p>以样例数据集为例进行预处理。全量数据集注意解压后以training-monolingual.tokenized.shuffled 目录为预处理目录,和样例数据集的text目录并列。</p>
<p>词典格式: 词&lt;空格&gt;词频。注意低频词用'UNK'表示</p>
<p>可以按格式自建词典,如果自建词典跳过第一步。</p>
<div class="codehilite"><pre><span></span><span class="n">the</span> <span class="mi">1061396</span>
<p>可以按格式自建词典,如果自建词典跳过第一步。
<table class="codehilitetable"><tr><td class="linenos"><div class="linenodiv"><pre><span></span> 1
2
3
4
5
6
7
8
9
10</pre></div></td><td class="code"><div class="codehilite"><pre><span></span><span class="n">the</span> <span class="mi">1061396</span>
<span class="k">of</span> <span class="mi">593677</span>
<span class="k">and</span> <span class="mi">416629</span>
<span class="n">one</span> <span class="mi">411764</span>
......@@ -235,55 +262,75 @@ mv text data/
<span class="n">zero</span> <span class="mi">264975</span>
<span class="n">nine</span> <span class="mi">250430</span>
</pre></div>
</td></tr></table></p>
<p>第一步根据英文语料生成词典,中文语料可以通过修改text_strip方法自定义处理方法。</p>
<div class="codehilite"><pre><span></span>python preprocess.py --build_dict --build_dict_corpus_dir data/text/ --dict_path data/test_build_dict
<table class="codehilitetable"><tr><td class="linenos"><div class="linenodiv"><pre><span></span>1</pre></div></td><td class="code"><div class="codehilite"><pre><span></span>python preprocess.py --build_dict --build_dict_corpus_dir data/text/ --dict_path data/test_build_dict
</pre></div>
</td></tr></table>
<p>第二步根据词典将文本转成id, 同时进行downsample,按照概率过滤常见词, 同时生成word和id映射的文件,文件名为词典+"<em>word_to_id</em>"。</p>
<div class="codehilite"><pre><span></span>python preprocess.py --filter_corpus --dict_path data/test_build_dict --input_corpus_dir data/text --output_corpus_dir data/convert_text8 --min_count <span class="m">5</span> --downsample <span class="m">0</span>.001
<table class="codehilitetable"><tr><td class="linenos"><div class="linenodiv"><pre><span></span>1</pre></div></td><td class="code"><div class="codehilite"><pre><span></span>python preprocess.py --filter_corpus --dict_path data/test_build_dict --input_corpus_dir data/text --output_corpus_dir data/convert_text8 --min_count <span class="m">5</span> --downsample <span class="m">0</span>.001
</pre></div>
</td></tr></table>
<h3 id="_4">训练<a class="headerlink" href="#_4" title="Permanent link">#</a></h3>
<p>具体的参数配置可运行</p>
<div class="codehilite"><pre><span></span>python train.py -h
<table class="codehilitetable"><tr><td class="linenos"><div class="linenodiv"><pre><span></span>1</pre></div></td><td class="code"><div class="codehilite"><pre><span></span>python train.py -h
</pre></div>
</td></tr></table>
<p>单机多线程训练</p>
<div class="codehilite"><pre><span></span><span class="nv">OPENBLAS_NUM_THREADS</span><span class="o">=</span><span class="m">1</span> <span class="nv">CPU_NUM</span><span class="o">=</span><span class="m">5</span> python train.py --train_data_dir data/convert_text8 --dict_path data/test_build_dict --num_passes <span class="m">10</span> --batch_size <span class="m">100</span> --model_output_dir v1_cpu5_b100_lr1dir --base_lr <span class="m">1</span>.0 --print_batch <span class="m">1000</span> --with_speed --is_sparse
<p>单机多线程训练
<table class="codehilitetable"><tr><td class="linenos"><div class="linenodiv"><pre><span></span>1</pre></div></td><td class="code"><div class="codehilite"><pre><span></span><span class="nv">OPENBLAS_NUM_THREADS</span><span class="o">=</span><span class="m">1</span> <span class="nv">CPU_NUM</span><span class="o">=</span><span class="m">5</span> python train.py --train_data_dir data/convert_text8 --dict_path data/test_build_dict --num_passes <span class="m">10</span> --batch_size <span class="m">100</span> --model_output_dir v1_cpu5_b100_lr1dir --base_lr <span class="m">1</span>.0 --print_batch <span class="m">1000</span> --with_speed --is_sparse
</pre></div>
</td></tr></table></p>
<p>本地单机模拟多机训练</p>
<div class="codehilite"><pre><span></span>sh cluster_train.sh
<table class="codehilitetable"><tr><td class="linenos"><div class="linenodiv"><pre><span></span>1</pre></div></td><td class="code"><div class="codehilite"><pre><span></span>sh cluster_train.sh
</pre></div>
</td></tr></table>
<p>本示例中按照单机多线程训练的命令进行训练,训练完毕后,可看到在当前文件夹下保存模型的路径为: <code>v1_cpu5_b100_lr1dir</code>, 运行 <code>ls v1_cpu5_b100_lr1dir</code>可看到该文件夹下保存了训练的10个epoch的模型文件。</p>
<div class="codehilite"><pre><span></span><span class="n">pass</span><span class="o">-</span><span class="mi">0</span> <span class="n">pass</span><span class="o">-</span><span class="mi">1</span> <span class="n">pass</span><span class="o">-</span><span class="mi">2</span> <span class="n">pass</span><span class="o">-</span><span class="mi">3</span> <span class="n">pass</span><span class="o">-</span><span class="mi">4</span> <span class="n">pass</span><span class="o">-</span><span class="mi">5</span> <span class="n">pass</span><span class="o">-</span><span class="mi">6</span> <span class="n">pass</span><span class="o">-</span><span class="mi">7</span> <span class="n">pass</span><span class="o">-</span><span class="mi">8</span> <span class="n">pass</span><span class="o">-</span><span class="mi">9</span>
<p>本示例中按照单机多线程训练的命令进行训练,训练完毕后,可看到在当前文件夹下保存模型的路径为: <code>v1_cpu5_b100_lr1dir</code>, 运行 <code>ls v1_cpu5_b100_lr1dir</code>可看到该文件夹下保存了训练的10个epoch的模型文件。
<table class="codehilitetable"><tr><td class="linenos"><div class="linenodiv"><pre><span></span>1</pre></div></td><td class="code"><div class="codehilite"><pre><span></span><span class="n">pass</span><span class="o">-</span><span class="mi">0</span> <span class="n">pass</span><span class="o">-</span><span class="mi">1</span> <span class="n">pass</span><span class="o">-</span><span class="mi">2</span> <span class="n">pass</span><span class="o">-</span><span class="mi">3</span> <span class="n">pass</span><span class="o">-</span><span class="mi">4</span> <span class="n">pass</span><span class="o">-</span><span class="mi">5</span> <span class="n">pass</span><span class="o">-</span><span class="mi">6</span> <span class="n">pass</span><span class="o">-</span><span class="mi">7</span> <span class="n">pass</span><span class="o">-</span><span class="mi">8</span> <span class="n">pass</span><span class="o">-</span><span class="mi">9</span>
</pre></div>
</td></tr></table></p>
<h3 id="_5">预测<a class="headerlink" href="#_5" title="Permanent link">#</a></h3>
<p>测试集下载命令如下</p>
<div class="codehilite"><pre><span></span><span class="c1">#全量数据集测试集</span>
<table class="codehilitetable"><tr><td class="linenos"><div class="linenodiv"><pre><span></span>1
2
3
4</pre></div></td><td class="code"><div class="codehilite"><pre><span></span><span class="c1">#全量数据集测试集</span>
wget https://paddlerec.bj.bcebos.com/word2vec/test_dir.tar
<span class="c1">#样本数据集测试集</span>
wget https://paddlerec.bj.bcebos.com/word2vec/test_mid_dir.tar
</pre></div>
</td></tr></table>
<p>预测命令,注意词典名称需要加后缀"<em>word_to_id</em>", 此文件是预处理阶段生成的。</p>
<div class="codehilite"><pre><span></span>python infer.py --infer_epoch --test_dir data/test_mid_dir --dict_path data/test_build_dict_word_to_id_ --batch_size <span class="m">20000</span> --model_dir v1_cpu5_b100_lr1dir/ --start_index <span class="m">0</span> --last_index <span class="m">9</span>
<p>预测命令,注意词典名称需要加后缀"<em>word_to_id</em>", 此文件是预处理阶段生成的。
<table class="codehilitetable"><tr><td class="linenos"><div class="linenodiv"><pre><span></span>1</pre></div></td><td class="code"><div class="codehilite"><pre><span></span>python infer.py --infer_epoch --test_dir data/test_mid_dir --dict_path data/test_build_dict_word_to_id_ --batch_size <span class="m">20000</span> --model_dir v1_cpu5_b100_lr1dir/ --start_index <span class="m">0</span> --last_index <span class="m">9</span>
</pre></div>
<p>运行该预测命令, 可看到如下输出</p>
<div class="codehilite"><pre><span></span><span class="p">(</span><span class="s1">&#39;start index: &#39;</span><span class="p">,</span> <span class="mi">0</span><span class="p">,</span> <span class="s1">&#39; last_index:&#39;</span><span class="p">,</span> <span class="mi">9</span><span class="p">)</span>
</td></tr></table>
运行该预测命令, 可看到如下输出
<table class="codehilitetable"><tr><td class="linenos"><div class="linenodiv"><pre><span></span> 1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22</pre></div></td><td class="code"><div class="codehilite"><pre><span></span><span class="p">(</span><span class="s1">&#39;start index: &#39;</span><span class="p">,</span> <span class="mi">0</span><span class="p">,</span> <span class="s1">&#39; last_index:&#39;</span><span class="p">,</span> <span class="mi">9</span><span class="p">)</span>
<span class="p">(</span><span class="s1">&#39;vocab_size:&#39;</span><span class="p">,</span> <span class="mi">63642</span><span class="p">)</span>
<span class="n">step</span><span class="p">:</span><span class="mi">1</span> <span class="mi">249</span>
<span class="n">epoch</span><span class="p">:</span><span class="mi">0</span> <span class="n">acc</span><span class="p">:</span><span class="mi">0</span><span class="p">.</span><span class="mi">014</span>
......@@ -306,24 +353,56 @@ wget https://paddlerec.bj.bcebos.com/word2vec/test_mid_dir.tar
<span class="n">step</span><span class="p">:</span><span class="mi">1</span> <span class="mi">2722</span>
<span class="n">epoch</span><span class="p">:</span><span class="mi">9</span> <span class="n">acc</span><span class="p">:</span><span class="mi">0</span><span class="p">.</span><span class="mi">153</span>
</pre></div>
</td></tr></table></p>
<h2 id="skip-gramword2vector_1">量化<code>基于skip-gram的word2vector模型</code><a class="headerlink" href="#skip-gramword2vector_1" title="Permanent link">#</a></h2>
<p>量化配置为:</p>
<div class="codehilite"><pre><span></span><span class="n">config</span> <span class="o">=</span> <span class="err">{</span>
<p>量化配置为:
<table class="codehilitetable"><tr><td class="linenos"><div class="linenodiv"><pre><span></span>1
2
3
4</pre></div></td><td class="code"><div class="codehilite"><pre><span></span><span class="n">config</span> <span class="o">=</span> <span class="err">{</span>
<span class="s1">&#39;params_name&#39;</span><span class="p">:</span> <span class="s1">&#39;emb&#39;</span><span class="p">,</span>
<span class="s1">&#39;quantize_type&#39;</span><span class="p">:</span> <span class="s1">&#39;abs_max&#39;</span>
<span class="err">}</span>
</pre></div>
</td></tr></table></p>
<p>运行命令为:</p>
<div class="codehilite"><pre><span></span>python infer.py --infer_epoch --test_dir data/test_mid_dir --dict_path data/test_build_dict_word_to_id_ --batch_size <span class="m">20000</span> --model_dir v1_cpu5_b100_lr1dir/ --start_index <span class="m">0</span> --last_index <span class="m">9</span> --emb_quant True
<table class="codehilitetable"><tr><td class="linenos"><div class="linenodiv"><pre><span></span>1</pre></div></td><td class="code"><div class="codehilite"><pre><span></span>python infer.py --infer_epoch --test_dir data/test_mid_dir --dict_path data/test_build_dict_word_to_id_ --batch_size <span class="m">20000</span> --model_dir v1_cpu5_b100_lr1dir/ --start_index <span class="m">0</span> --last_index <span class="m">9</span> --emb_quant True
</pre></div>
</td></tr></table>
<p>运行输出为:</p>
<div class="codehilite"><pre><span></span><span class="p">(</span><span class="s1">&#39;start index: &#39;</span><span class="p">,</span> <span class="mi">0</span><span class="p">,</span> <span class="s1">&#39; last_index:&#39;</span><span class="p">,</span> <span class="mi">9</span><span class="p">)</span>
<table class="codehilitetable"><tr><td class="linenos"><div class="linenodiv"><pre><span></span> 1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32</pre></div></td><td class="code"><div class="codehilite"><pre><span></span><span class="p">(</span><span class="s1">&#39;start index: &#39;</span><span class="p">,</span> <span class="mi">0</span><span class="p">,</span> <span class="s1">&#39; last_index:&#39;</span><span class="p">,</span> <span class="mi">9</span><span class="p">)</span>
<span class="p">(</span><span class="s1">&#39;vocab_size:&#39;</span><span class="p">,</span> <span class="mi">63642</span><span class="p">)</span>
<span class="n">quant_embedding</span> <span class="n">config</span> <span class="err">{</span><span class="s1">&#39;quantize_type&#39;</span><span class="p">:</span> <span class="s1">&#39;abs_max&#39;</span><span class="p">,</span> <span class="s1">&#39;params_name&#39;</span><span class="p">:</span> <span class="s1">&#39;emb&#39;</span><span class="p">,</span> <span class="s1">&#39;quantize_bits&#39;</span><span class="p">:</span> <span class="mi">8</span><span class="p">,</span> <span class="s1">&#39;dtype&#39;</span><span class="p">:</span> <span class="s1">&#39;int8&#39;</span><span class="err">}</span>
<span class="n">step</span><span class="p">:</span><span class="mi">1</span> <span class="mi">253</span>
......@@ -356,7 +435,7 @@ wget https://paddlerec.bj.bcebos.com/word2vec/test_mid_dir.tar
<span class="n">step</span><span class="p">:</span><span class="mi">1</span> <span class="mi">2719</span>
<span class="n">epoch</span><span class="p">:</span><span class="mi">9</span> <span class="n">acc</span><span class="p">:</span><span class="mi">0</span><span class="p">.</span><span class="mi">153</span>
</pre></div>
</td></tr></table>
<p>量化后的模型保存在<code>./output_quant</code>中,可看到量化后的参数<code>'emb.int8'</code>的大小为3.9M, 在<code>./v1_cpu5_b100_lr1dir</code>中可看到量化前的参数<code>'emb'</code>的大小为16M。</p>
......
......@@ -176,18 +176,16 @@
<p>首先在<a href="https://github.com/PaddlePaddle/models/tree/develop/PaddleCV/image_classification#%E5%B7%B2%E5%8F%91%E5%B8%83%E6%A8%A1%E5%9E%8B%E5%8F%8A%E5%85%B6%E6%80%A7%E8%83%BD">imagenet分类模型</a>中下载训练好的<code>mobilenetv1</code>模型。</p>
<p>在当前文件夹下创建<code>'pretrain'</code>文件夹,将<code>mobilenetv1</code>模型在该文件夹下解压,解压后的目录为<code>pretrain/MobileNetV1_pretrained</code></p>
<h3 id="_6">导出模型<a class="headerlink" href="#_6" title="Permanent link">#</a></h3>
<p>通过运行以下命令可将模型转化为离线量化接口可用的模型:</p>
<div class="codehilite"><pre><span></span><span class="n">python</span> <span class="n">export_model</span><span class="p">.</span><span class="n">py</span> <span class="c1">--model &quot;MobileNet&quot; --pretrained_model ./pretrain/MobileNetV1_pretrained --data imagenet</span>
<p>通过运行以下命令可将模型转化为离线量化接口可用的模型:
<table class="codehilitetable"><tr><td class="linenos"><div class="linenodiv"><pre><span></span>1</pre></div></td><td class="code"><div class="codehilite"><pre><span></span><span class="n">python</span> <span class="n">export_model</span><span class="p">.</span><span class="n">py</span> <span class="c1">--model &quot;MobileNet&quot; --pretrained_model ./pretrain/MobileNetV1_pretrained --data imagenet</span>
</pre></div>
<p>转化之后的模型存储在<code>inference_model/MobileNet/</code>文件夹下,可看到该文件夹下有<code>'model'</code>, <code>'weights'</code>两个文件。</p>
</td></tr></table>
转化之后的模型存储在<code>inference_model/MobileNet/</code>文件夹下,可看到该文件夹下有<code>'model'</code>, <code>'weights'</code>两个文件。</p>
<h3 id="_7">离线量化<a class="headerlink" href="#_7" title="Permanent link">#</a></h3>
<p>接下来对导出的模型文件进行离线量化,离线量化的脚本为<a href="./quant_post.py">quant_post.py</a>,脚本中使用接口<code>paddleslim.quant.quant_post</code>对模型进行离线量化。运行命令为:</p>
<div class="codehilite"><pre><span></span><span class="n">python</span> <span class="n">quant_post</span><span class="p">.</span><span class="n">py</span> <span class="c1">--model_path ./inference_model/MobileNet --save_path ./quant_model_train/MobileNet --model_filename model --params_filename weights</span>
<p>接下来对导出的模型文件进行离线量化,离线量化的脚本为<a href="./quant_post.py">quant_post.py</a>,脚本中使用接口<code>paddleslim.quant.quant_post</code>对模型进行离线量化。运行命令为:
<table class="codehilitetable"><tr><td class="linenos"><div class="linenodiv"><pre><span></span>1</pre></div></td><td class="code"><div class="codehilite"><pre><span></span><span class="n">python</span> <span class="n">quant_post</span><span class="p">.</span><span class="n">py</span> <span class="c1">--model_path ./inference_model/MobileNet --save_path ./quant_model_train/MobileNet --model_filename model --params_filename weights</span>
</pre></div>
</td></tr></table></p>
<ul>
<li><code>model_path</code>: 需要量化的模型坐在的文件夹</li>
<li><code>save_path</code>: 量化后的模型保存的路径</li>
......@@ -200,27 +198,24 @@
</blockquote>
<h3 id="_8">测试精度<a class="headerlink" href="#_8" title="Permanent link">#</a></h3>
<p>使用<a href="./eval.py">eval.py</a>脚本对量化前后的模型进行测试,得到模型的分类精度进行对比。</p>
<p>首先测试量化前的模型的精度,运行以下命令:</p>
<div class="codehilite"><pre><span></span><span class="n">python</span> <span class="n">eval</span><span class="p">.</span><span class="n">py</span> <span class="c1">--model_path ./inference_model/MobileNet --model_name model --params_name weights</span>
<p>首先测试量化前的模型的精度,运行以下命令:
<table class="codehilitetable"><tr><td class="linenos"><div class="linenodiv"><pre><span></span>1</pre></div></td><td class="code"><div class="codehilite"><pre><span></span><span class="n">python</span> <span class="n">eval</span><span class="p">.</span><span class="n">py</span> <span class="c1">--model_path ./inference_model/MobileNet --model_name model --params_name weights</span>
</pre></div>
<p>精度输出为:</p>
<div class="codehilite"><pre><span></span><span class="n">top1_acc</span><span class="o">/</span><span class="n">top5_acc</span><span class="o">=</span> <span class="p">[</span><span class="mi">0</span><span class="p">.</span><span class="mi">70913923</span> <span class="mi">0</span><span class="p">.</span><span class="mi">89548034</span><span class="p">]</span>
</td></tr></table>
精度输出为:
<table class="codehilitetable"><tr><td class="linenos"><div class="linenodiv"><pre><span></span>1</pre></div></td><td class="code"><div class="codehilite"><pre><span></span><span class="n">top1_acc</span><span class="o">/</span><span class="n">top5_acc</span><span class="o">=</span> <span class="p">[</span><span class="mi">0</span><span class="p">.</span><span class="mi">70913923</span> <span class="mi">0</span><span class="p">.</span><span class="mi">89548034</span><span class="p">]</span>
</pre></div>
</td></tr></table></p>
<p>使用以下命令测试离线量化后的模型的精度:</p>
<div class="codehilite"><pre><span></span><span class="n">python</span> <span class="n">eval</span><span class="p">.</span><span class="n">py</span> <span class="c1">--model_path ./quant_model_train/MobileNet</span>
<table class="codehilitetable"><tr><td class="linenos"><div class="linenodiv"><pre><span></span>1</pre></div></td><td class="code"><div class="codehilite"><pre><span></span><span class="n">python</span> <span class="n">eval</span><span class="p">.</span><span class="n">py</span> <span class="c1">--model_path ./quant_model_train/MobileNet</span>
</pre></div>
</td></tr></table>
<p>精度输出为</p>
<div class="codehilite"><pre><span></span><span class="n">top1_acc</span><span class="o">/</span><span class="n">top5_acc</span><span class="o">=</span> <span class="p">[</span><span class="mi">0</span><span class="p">.</span><span class="mi">70141864</span> <span class="mi">0</span><span class="p">.</span><span class="mi">89086477</span><span class="p">]</span>
<p>精度输出为
<table class="codehilitetable"><tr><td class="linenos"><div class="linenodiv"><pre><span></span>1</pre></div></td><td class="code"><div class="codehilite"><pre><span></span><span class="n">top1_acc</span><span class="o">/</span><span class="n">top5_acc</span><span class="o">=</span> <span class="p">[</span><span class="mi">0</span><span class="p">.</span><span class="mi">70141864</span> <span class="mi">0</span><span class="p">.</span><span class="mi">89086477</span><span class="p">]</span>
</pre></div>
<p>从以上精度对比可以看出,对<code>mobilenet</code><code>imagenet</code>上的分类模型进行离线量化后 <code>top1</code>精度损失为<code>0.77%</code><code>top5</code>精度损失为<code>0.46%</code>. </p>
</td></tr></table>
从以上精度对比可以看出,对<code>mobilenet</code><code>imagenet</code>上的分类模型进行离线量化后 <code>top1</code>精度损失为<code>0.77%</code><code>top5</code>精度损失为<code>0.46%</code>. </p>
</div>
</div>
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册