<!DOCTYPE html>
<!--[if IE 8]><html class="no-js lt-ie9" lang="en" > <![endif]-->
<!--[if gt IE 8]><!--> <html class="no-js" lang="en" > <!--<![endif]-->
<head>
  <meta charset="utf-8">
  <meta http-equiv="X-UA-Compatible" content="IE=edge">
  <meta name="viewport" content="width=device-width, initial-scale=1.0">
  
  
  <link rel="shortcut icon" href="../../img/favicon.ico">
  <title>Embedding量化 - PaddleSlim Docs</title>
  <link href='https://fonts.googleapis.com/css?family=Lato:400,700|Roboto+Slab:400,700|Inconsolata:400,700' rel='stylesheet' type='text/css'>

  <link rel="stylesheet" href="../../css/theme.css" type="text/css" />
  <link rel="stylesheet" href="../../css/theme_extra.css" type="text/css" />
  <link rel="stylesheet" href="https://cdnjs.cloudflare.com/ajax/libs/highlight.js/9.12.0/styles/github.min.css">

  <link href="../../extra.css" rel="stylesheet">
  
  <script>
    // Page metadata injected by MkDocs; read by the theme's deferred JS
    // (e.g. search result highlighting and relative-URL resolution).
    var mkdocs_page_name = "Embedding\u91cf\u5316";
    var mkdocs_page_input_path = "tutorials/quant_embedding_demo.md";
    var mkdocs_page_url = null;
  </script>
  
  <script src="../../js/jquery-2.1.1.min.js" defer></script>
  <script src="../../js/modernizr-2.8.3.min.js" defer></script>
  <script src="https://cdnjs.cloudflare.com/ajax/libs/highlight.js/9.12.0/highlight.min.js"></script>
  <script>hljs.initHighlightingOnLoad();</script>
  
</head>

<body class="wy-body-for-nav" role="document">

  <div class="wy-grid-for-nav">

    
    <nav data-toggle="wy-nav-shift" class="wy-nav-side stickynav">
      <div class="wy-side-nav-search">
        <a href="../.." class="icon icon-home"> PaddleSlim Docs</a>
        <div role="search">
  <form id="rtd-search-form" class="wy-form" action="../../search.html" method="get">
    <input type="text" name="q" placeholder="Search docs" title="Type search term here" />
  </form>
</div>
      </div>

      <div class="wy-menu wy-menu-vertical" data-spy="affix" role="navigation" aria-label="main navigation">
	<ul class="current">
	  
          
            <li class="toctree-l1">
		
    <a class="" href="../..">Home</a>
	    </li>
          
            <li class="toctree-l1">
		
    <a class="" href="../../model_zoo/">模型库</a>
	    </li>
          
            <li class="toctree-l1">
		
    <span class="caption-text">教程</span>
    <ul class="subnav">
                <li class="">
                    
    <a class="" href="../quant_post_demo/">离线量化</a>
                </li>
                <li class="">
                    
    <a class="" href="../quant_aware_demo/">量化训练</a>
                </li>
                <li class=" current">
                    
    <a class="current" href="./">Embedding量化</a>
    <ul class="subnav">
            
    <li class="toctree-l3"><a href="#embedding">Embedding量化示例</a></li>
    
        <ul>
        
            <li><a class="toctree-l4" href="#skip-gramword2vector">基于skip-gram的word2vector模型</a></li>
        
            <li><a class="toctree-l4" href="#skip-gramword2vector_1">量化基于skip-gram的word2vector模型</a></li>
        
        </ul>
    

    </ul>
                </li>
                <li class="">
                    
    <a class="" href="../nas_demo/">SA搜索</a>
                </li>
                <li class="">
                    
    <a class="" href="../../search_space/">搜索空间</a>
                </li>
                <li class="">
                    
    <a class="" href="../distillation_demo/">知识蒸馏</a>
                </li>
    </ul>
	    </li>
          
            <li class="toctree-l1">
		
    <span class="caption-text">API</span>
    <ul class="subnav">
                <li class="">
                    
    <a class="" href="../../api/quantization_api/">量化</a>
                </li>
                <li class="">
                    
    <a class="" href="../../api/prune_api/">剪枝与敏感度</a>
                </li>
                <li class="">
                    
    <a class="" href="../../api/analysis_api/">模型分析</a>
                </li>
                <li class="">
                    
    <a class="" href="../../api/single_distiller_api/">知识蒸馏</a>
                </li>
                <li class="">
                    
    <a class="" href="../../api/nas_api/">SA搜索</a>
                </li>
                <li class="">
                    
    <a class="" href="../../table_latency/">硬件延时评估表</a>
                </li>
    </ul>
	    </li>
          
            <li class="toctree-l1">
		
    <a class="" href="../../algo/algo/">算法原理</a>
	    </li>
          
        </ul>
      </div>
      &nbsp;
    </nav>

    <section data-toggle="wy-nav-shift" class="wy-nav-content-wrap">

      
      <nav class="wy-nav-top" role="navigation" aria-label="top navigation">
        <i data-toggle="wy-nav-top" class="fa fa-bars"></i>
        <a href="../..">PaddleSlim Docs</a>
      </nav>

      
      <div class="wy-nav-content">
        <div class="rst-content">
          <div role="navigation" aria-label="breadcrumbs navigation">
  <ul class="wy-breadcrumbs">
    <li><a href="../..">Docs</a> &raquo;</li>
    
      
        
          <li>教程 &raquo;</li>
        
      
    
    <li>Embedding量化</li>
    <li class="wy-breadcrumbs-aside">
      
        <a href="https://github.com/PaddlePaddle/PaddleSlim/edit/master/docs/tutorials/quant_embedding_demo.md"
          class="icon icon-github"> Edit on GitHub</a>
      
    </li>
  </ul>
  <hr/>
</div>
          <div role="main">
            <div class="section">
              
                <h1 id="embedding">Embedding量化示例<a class="headerlink" href="#embedding" title="Permanent link">#</a></h1>
<p>本示例介绍如何使用Embedding量化的接口 <a href="">paddleslim.quant.quant_embedding</a>。<code>quant_embedding</code>接口将网络中的Embedding参数从<code>float32</code>类型量化到 <code>8-bit</code>整数类型,在几乎不损失模型精度的情况下减少模型的存储空间和显存占用。</p>
<p>接口介绍请参考 <a href='../../../paddleslim/quant/quantization_api_doc.md'>量化API文档</a></p>
<p>该接口对program的修改:</p>
<p>量化前:</p>
<p align="center">
<img src="./image/before.png" alt="量化前的模型结构" height="200" width="100" hspace="10"> <br>
<strong>图1:量化前的模型结构</strong>
</p>

<p>量化后:</p>
<p align="center">
<img src="./image/after.png" alt="量化后的模型结构" height="300" width="300" hspace="10"> <br>
<strong>图2: 量化后的模型结构</strong>
</p>

<p>以下将以 <code>基于skip-gram的word2vector模型</code> 为例来说明如何使用<code>quant_embedding</code>接口。首先介绍 <code>基于skip-gram的word2vector模型</code> 的正常训练和测试流程。</p>
<h2 id="skip-gramword2vector">基于skip-gram的word2vector模型<a class="headerlink" href="#skip-gramword2vector" title="Permanent link">#</a></h2>
<p>以下是本例的简要目录结构及说明:</p>
<div class="codehilite"><pre><span></span>.
├── cluster_train.py    # 分布式训练函数
├── cluster_train.sh    # 本地模拟多机脚本
├── train.py            # 训练函数
├── infer.py            # 预测脚本
├── net.py              # 网络结构
├── preprocess.py       # 预处理脚本,包括构建词典和预处理文本
├── reader.py           # 训练阶段的文本读写
├── train.py            # 训练函数
└── utils.py            # 通用函数
</pre></div>

<h3 id="_1">介绍<a class="headerlink" href="#_1" title="Permanent link">#</a></h3>
<p>本例实现了skip-gram模式的word2vector模型。</p>
<p>同时推荐用户参考<a href="https://aistudio.baidu.com/aistudio/projectDetail/124377"> IPython Notebook demo</a></p>
<h3 id="_2">数据下载<a class="headerlink" href="#_2" title="Permanent link">#</a></h3>
<p>全量数据集使用的是来自1 Billion Word Language Model Benchmark的(<a href="http://www.statmt.org/lm-benchmark">http://www.statmt.org/lm-benchmark</a>) 的数据集.</p>
<div class="codehilite"><pre><span></span>mkdir data
wget http://www.statmt.org/lm-benchmark/1-billion-word-language-modeling-benchmark-r13output.tar.gz
tar xzvf <span class="m">1</span>-billion-word-language-modeling-benchmark-r13output.tar.gz
mv <span class="m">1</span>-billion-word-language-modeling-benchmark-r13output/training-monolingual.tokenized.shuffled/ data/
</pre></div>

<p>备用数据地址下载命令如下</p>
<div class="codehilite"><pre><span></span>mkdir data
wget https://paddlerec.bj.bcebos.com/word2vec/1-billion-word-language-modeling-benchmark-r13output.tar
tar xvf <span class="m">1</span>-billion-word-language-modeling-benchmark-r13output.tar
mv <span class="m">1</span>-billion-word-language-modeling-benchmark-r13output/training-monolingual.tokenized.shuffled/ data/
</pre></div>

<p>为了方便快速验证,我们也提供了经典的text8样例数据集,包含1700w个词。 下载命令如下</p>
<div class="codehilite"><pre><span></span>mkdir data
wget https://paddlerec.bj.bcebos.com/word2vec/text.tar
tar xvf text.tar
mv text data/
</pre></div>

<h3 id="_3">数据预处理<a class="headerlink" href="#_3" title="Permanent link">#</a></h3>
<p>以样例数据集为例进行预处理。全量数据集注意解压后以training-monolingual.tokenized.shuffled 目录为预处理目录,和样例数据集的text目录并列。</p>
<p>词典格式: 词&lt;空格&gt;词频。注意低频词用'UNK'表示</p>
<p>可以按格式自建词典,如果自建词典跳过第一步。
<div class="codehilite"><pre><span></span>the 1061396
of 593677
and 416629
one 411764
in 372201
a 325873
&lt;UNK&gt; 324608
to 316376
zero 264975
nine 250430
</pre></div></p>
<p>第一步根据英文语料生成词典,中文语料可以通过修改text_strip方法自定义处理方法。</p>
<div class="codehilite"><pre><span></span>python preprocess.py --build_dict --build_dict_corpus_dir data/text/ --dict_path data/test_build_dict
</pre></div>

<p>第二步根据词典将文本转成id, 同时进行downsample,按照概率过滤常见词, 同时生成word和id映射的文件,文件名为词典+"<em>word_to_id</em>"。</p>
<div class="codehilite"><pre><span></span>python preprocess.py --filter_corpus --dict_path data/test_build_dict --input_corpus_dir data/text --output_corpus_dir data/convert_text8 --min_count <span class="m">5</span> --downsample <span class="m">0</span>.001
</pre></div>

<h3 id="_4">训练<a class="headerlink" href="#_4" title="Permanent link">#</a></h3>
<p>具体的参数配置可运行</p>
<div class="codehilite"><pre><span></span>python train.py -h
</pre></div>

<p>单机多线程训练
<div class="codehilite"><pre><span></span><span class="nv">OPENBLAS_NUM_THREADS</span><span class="o">=</span><span class="m">1</span> <span class="nv">CPU_NUM</span><span class="o">=</span><span class="m">5</span> python train.py --train_data_dir data/convert_text8 --dict_path data/test_build_dict --num_passes <span class="m">10</span> --batch_size <span class="m">100</span> --model_output_dir v1_cpu5_b100_lr1dir --base_lr <span class="m">1</span>.0 --print_batch <span class="m">1000</span> --with_speed --is_sparse
</pre></div></p>
<p>本地单机模拟多机训练</p>
<div class="codehilite"><pre><span></span>sh cluster_train.sh
</pre></div>

<p>本示例中按照单机多线程训练的命令进行训练,训练完毕后,可看到在当前文件夹下保存模型的路径为:     <code>v1_cpu5_b100_lr1dir</code>, 运行 <code>ls v1_cpu5_b100_lr1dir</code>可看到该文件夹下保存了训练的10个epoch的模型文件。
<div class="codehilite"><pre><span></span>pass-0  pass-1  pass-2  pass-3  pass-4  pass-5  pass-6  pass-7  pass-8  pass-9
</pre></div></p>
<h3 id="_5">预测<a class="headerlink" href="#_5" title="Permanent link">#</a></h3>
<p>测试集下载命令如下</p>
<div class="codehilite"><pre><span></span><span class="c1">#全量数据集测试集</span>
wget https://paddlerec.bj.bcebos.com/word2vec/test_dir.tar
<span class="c1">#样本数据集测试集</span>
wget https://paddlerec.bj.bcebos.com/word2vec/test_mid_dir.tar
</pre></div>

<p>预测命令,注意词典名称需要加后缀"<em>word_to_id</em>", 此文件是预处理阶段生成的。
<div class="codehilite"><pre><span></span>python infer.py --infer_epoch --test_dir data/test_mid_dir --dict_path data/test_build_dict_word_to_id_ --batch_size <span class="m">20000</span> --model_dir v1_cpu5_b100_lr1dir/  --start_index <span class="m">0</span> --last_index <span class="m">9</span>
</pre></div>
运行该预测命令, 可看到如下输出
<div class="codehilite"><pre><span></span>(&#39;start index: &#39;, 0, &#39; last_index:&#39;, 9)
(&#39;vocab_size:&#39;, 63642)
step:1 249
epoch:0          acc:0.014
step:1 590
epoch:1          acc:0.033
step:1 982
epoch:2          acc:0.055
step:1 1338
epoch:3          acc:0.075
step:1 1653
epoch:4          acc:0.093
step:1 1914
epoch:5          acc:0.107
step:1 2204
epoch:6          acc:0.124
step:1 2416
epoch:7          acc:0.136
step:1 2606
epoch:8          acc:0.146
step:1 2722
epoch:9          acc:0.153
</pre></div></p>
<h2 id="skip-gramword2vector_1">量化<code>基于skip-gram的word2vector模型</code><a class="headerlink" href="#skip-gramword2vector_1" title="Permanent link">#</a></h2>
<p>量化配置为:
<div class="codehilite"><pre><span></span><span class="n">config</span> <span class="o">=</span> <span class="p">{</span>
        <span class="s1">&#39;params_name&#39;</span><span class="p">:</span> <span class="s1">&#39;emb&#39;</span><span class="p">,</span>
        <span class="s1">&#39;quantize_type&#39;</span><span class="p">:</span> <span class="s1">&#39;abs_max&#39;</span>
        <span class="p">}</span>
</pre></div></p>
<p>运行命令为:</p>
<div class="codehilite"><pre><span></span>python infer.py --infer_epoch --test_dir data/test_mid_dir --dict_path data/test_build_dict_word_to_id_ --batch_size <span class="m">20000</span> --model_dir v1_cpu5_b100_lr1dir/  --start_index <span class="m">0</span> --last_index <span class="m">9</span> --emb_quant True
</pre></div>

<p>运行输出为:</p>
<div class="codehilite"><pre><span></span>(&#39;start index: &#39;, 0, &#39; last_index:&#39;, 9)
(&#39;vocab_size:&#39;, 63642)
quant_embedding config {&#39;quantize_type&#39;: &#39;abs_max&#39;, &#39;params_name&#39;: &#39;emb&#39;, &#39;quantize_bits&#39;: 8, &#39;dtype&#39;: &#39;int8&#39;}
step:1 253
epoch:0          acc:0.014
quant_embedding config {&#39;quantize_type&#39;: &#39;abs_max&#39;, &#39;params_name&#39;: &#39;emb&#39;, &#39;quantize_bits&#39;: 8, &#39;dtype&#39;: &#39;int8&#39;}
step:1 586
epoch:1          acc:0.033
quant_embedding config {&#39;quantize_type&#39;: &#39;abs_max&#39;, &#39;params_name&#39;: &#39;emb&#39;, &#39;quantize_bits&#39;: 8, &#39;dtype&#39;: &#39;int8&#39;}
step:1 970
epoch:2          acc:0.054
quant_embedding config {&#39;quantize_type&#39;: &#39;abs_max&#39;, &#39;params_name&#39;: &#39;emb&#39;, &#39;quantize_bits&#39;: 8, &#39;dtype&#39;: &#39;int8&#39;}
step:1 1364
epoch:3          acc:0.077
quant_embedding config {&#39;quantize_type&#39;: &#39;abs_max&#39;, &#39;params_name&#39;: &#39;emb&#39;, &#39;quantize_bits&#39;: 8, &#39;dtype&#39;: &#39;int8&#39;}
step:1 1642
epoch:4          acc:0.092
quant_embedding config {&#39;quantize_type&#39;: &#39;abs_max&#39;, &#39;params_name&#39;: &#39;emb&#39;, &#39;quantize_bits&#39;: 8, &#39;dtype&#39;: &#39;int8&#39;}
step:1 1936
epoch:5          acc:0.109
quant_embedding config {&#39;quantize_type&#39;: &#39;abs_max&#39;, &#39;params_name&#39;: &#39;emb&#39;, &#39;quantize_bits&#39;: 8, &#39;dtype&#39;: &#39;int8&#39;}
step:1 2216
epoch:6          acc:0.124
quant_embedding config {&#39;quantize_type&#39;: &#39;abs_max&#39;, &#39;params_name&#39;: &#39;emb&#39;, &#39;quantize_bits&#39;: 8, &#39;dtype&#39;: &#39;int8&#39;}
step:1 2419
epoch:7          acc:0.136
quant_embedding config {&#39;quantize_type&#39;: &#39;abs_max&#39;, &#39;params_name&#39;: &#39;emb&#39;, &#39;quantize_bits&#39;: 8, &#39;dtype&#39;: &#39;int8&#39;}
step:1 2603
epoch:8          acc:0.146
quant_embedding config {&#39;quantize_type&#39;: &#39;abs_max&#39;, &#39;params_name&#39;: &#39;emb&#39;, &#39;quantize_bits&#39;: 8, &#39;dtype&#39;: &#39;int8&#39;}
step:1 2719
epoch:9          acc:0.153
</pre></div>

<p>量化后的模型保存在<code>./output_quant</code>中,可看到量化后的参数<code>'emb.int8'</code>的大小为3.9M, 在<code>./v1_cpu5_b100_lr1dir</code>中可看到量化前的参数<code>'emb'</code>的大小为16M。</p>
              
            </div>
          </div>
          <footer>
  
    <div class="rst-footer-buttons" role="navigation" aria-label="footer navigation">
      
        <a href="../nas_demo/" class="btn btn-neutral float-right" title="SA搜索">Next <span class="icon icon-circle-arrow-right"></span></a>
      
      
        <a href="../quant_aware_demo/" class="btn btn-neutral" title="量化训练"><span class="icon icon-circle-arrow-left"></span> Previous</a>
      
    </div>
  

  <hr/>

  <div role="contentinfo">
    <!-- Copyright etc -->
    
  </div>

  Built with <a href="http://www.mkdocs.org">MkDocs</a> using a <a href="https://github.com/snide/sphinx_rtd_theme">theme</a> provided by <a href="https://readthedocs.org">Read the Docs</a>.
</footer>
      
        </div>
      </div>

    </section>

  </div>

  <div class="rst-versions" role="note" style="cursor: pointer">
    <span class="rst-current-version" data-toggle="rst-current-version">
      
          <a href="https://github.com/PaddlePaddle/PaddleSlim/" class="fa fa-github" style="float: left; color: #fcfcfc"> GitHub</a>
      
      
        <span><a href="../quant_aware_demo/" style="color: #fcfcfc;">&laquo; Previous</a></span>
      
      
        <span style="margin-left: 15px"><a href="../nas_demo/" style="color: #fcfcfc">Next &raquo;</a></span>
      
    </span>
</div>
    <script>var base_url = '../..';</script>
    <script src="../../js/theme.js" defer></script>
      <script src="../../mathjax-config.js" defer></script>
      <script src="https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.0/MathJax.js?config=TeX-AMS-MML_HTMLorMML" defer></script>
      <script src="../../search/main.js" defer></script>

</body>
</html>