Commit 53a059d6, authored by baiyfbupt

Deployed 0f9c3ceb with MkDocs version: 1.0.4

Parent 1f1df4df
......@@ -66,6 +66,10 @@
<a class="" href="/tutorials/nas_demo/">SA Search</a>
</li>
<li class="">
<a class="" href="/tutorials/distillation_demo/">Knowledge Distillation</a>
</li>
</ul>
</li>
......@@ -79,15 +83,15 @@
</li>
<li class="">
<a class="" href="/api/prune_api/">Pruning</a>
<a class="" href="/api/prune_api/">Pruning and Sensitivity</a>
</li>
<li class="">
<a class="" href="/api/analysis_api/">Sensitivity Analysis</a>
<a class="" href="/api/analysis_api/">Model Analysis</a>
</li>
<li class="">
<a class="" href="/api/single_distiller_api/">Distillation</a>
<a class="" href="/api/single_distiller_api/">Knowledge Distillation</a>
</li>
<li class="">
......@@ -97,9 +101,18 @@
<a class="" href="/api/search_space/">Search Space</a>
</li>
<li class="">
<a class="" href="/table_latency/">Hardware Latency Table</a>
</li>
</ul>
</li>
<li class="toctree-l1">
<a class="" href="/algo/algo/">Algorithm Principles</a>
</li>
</ul>
</div>
&nbsp;
......@@ -170,7 +183,7 @@
<script>var base_url = '/';</script>
<script src="/js/theme.js" defer></script>
<script src="/mathjax-config.js" defer></script>
<script src="/MathJax.js?config=TeX-AMS-MML_HTMLorMML" defer></script>
<script src="https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.0/MathJax.js?config=TeX-AMS-MML_HTMLorMML" defer></script>
<script src="/search/main.js" defer></script>
</body>
......
<!DOCTYPE html>
<!--[if IE 8]><html class="no-js lt-ie9" lang="en" > <![endif]-->
<!--[if gt IE 8]><!--> <html class="no-js" lang="en" > <!--<![endif]-->
<head>
<meta charset="utf-8">
<meta http-equiv="X-UA-Compatible" content="IE=edge">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<link rel="shortcut icon" href="../../img/favicon.ico">
<title>Algorithm Principles - PaddleSlim Docs</title>
<link href='https://fonts.googleapis.com/css?family=Lato:400,700|Roboto+Slab:400,700|Inconsolata:400,700' rel='stylesheet' type='text/css'>
<link rel="stylesheet" href="../../css/theme.css" type="text/css" />
<link rel="stylesheet" href="../../css/theme_extra.css" type="text/css" />
<link rel="stylesheet" href="//cdnjs.cloudflare.com/ajax/libs/highlight.js/9.12.0/styles/github.min.css">
<script>
// Current page data
var mkdocs_page_name = "\u7b97\u6cd5\u539f\u7406";
var mkdocs_page_input_path = "algo/algo.md";
var mkdocs_page_url = null;
</script>
<script src="../../js/jquery-2.1.1.min.js" defer></script>
<script src="../../js/modernizr-2.8.3.min.js" defer></script>
<script src="//cdnjs.cloudflare.com/ajax/libs/highlight.js/9.12.0/highlight.min.js"></script>
<script>hljs.initHighlightingOnLoad();</script>
</head>
<body class="wy-body-for-nav" role="document">
<div class="wy-grid-for-nav">
<nav data-toggle="wy-nav-shift" class="wy-nav-side stickynav">
<div class="wy-side-nav-search">
<a href="../.." class="icon icon-home"> PaddleSlim Docs</a>
<div role="search">
<form id ="rtd-search-form" class="wy-form" action="../../search.html" method="get">
<input type="text" name="q" placeholder="Search docs" title="Type search term here" />
</form>
</div>
</div>
<div class="wy-menu wy-menu-vertical" data-spy="affix" role="navigation" aria-label="main navigation">
<ul class="current">
<li class="toctree-l1">
<a class="" href="../..">Home</a>
</li>
<li class="toctree-l1">
<span class="caption-text">Tutorials</span>
<ul class="subnav">
<li class="">
<a class="" href="../../tutorials/quant_post_demo/">Post-training Quantization</a>
</li>
<li class="">
<a class="" href="../../tutorials/quant_aware_demo/">Quantization Aware Training</a>
</li>
<li class="">
<a class="" href="../../tutorials/quant_embedding_demo/">Embedding Quantization</a>
</li>
<li class="">
<a class="" href="../../tutorials/nas_demo/">SA Search</a>
</li>
<li class="">
<a class="" href="../../tutorials/distillation_demo/">Knowledge Distillation</a>
</li>
</ul>
</li>
<li class="toctree-l1">
<span class="caption-text">API</span>
<ul class="subnav">
<li class="">
<a class="" href="../../api/quantization_api/">Quantization</a>
</li>
<li class="">
<a class="" href="../../api/prune_api/">Pruning and Sensitivity</a>
</li>
<li class="">
<a class="" href="../../api/analysis_api/">Model Analysis</a>
</li>
<li class="">
<a class="" href="../../api/single_distiller_api/">Knowledge Distillation</a>
</li>
<li class="">
<a class="" href="../../api/nas_api/">SA Search</a>
</li>
<li class="">
<a class="" href="../../api/search_space/">Search Space</a>
</li>
<li class="">
<a class="" href="../../table_latency/">Hardware Latency Table</a>
</li>
</ul>
</li>
<li class="toctree-l1 current">
<a class="current" href="./">Algorithm Principles</a>
<ul class="subnav">
<li class="toctree-l2"><a href="#_1">Contents</a></li>
<li class="toctree-l2"><a href="#1-quantization-aware-training">1. Introduction to Quantization Aware Training</a></li>
<ul>
<li><a class="toctree-l3" href="#11">1.1 Background</a></li>
<li><a class="toctree-l3" href="#12">1.2 Quantization principles</a></li>
</ul>
<li class="toctree-l2"><a href="#2">2. Filter pruning principles</a></li>
<ul>
<li><a class="toctree-l3" href="#21">2.1 Pruning filters</a></li>
<li><a class="toctree-l3" href="#22-uniform">2.2 Uniform pruning of convolutional networks</a></li>
<li><a class="toctree-l3" href="#23">2.3 Sensitivity-based pruning of convolutional networks</a></li>
</ul>
<li class="toctree-l2"><a href="#3">3. Distillation</a></li>
<li class="toctree-l2"><a href="#4">4. Lightweight neural architecture search</a></li>
<ul>
<li><a class="toctree-l3" href="#41">4.1 Search strategy</a></li>
<li><a class="toctree-l3" href="#42">4.2 Search space</a></li>
<li><a class="toctree-l3" href="#43">4.3 Model latency evaluation</a></li>
</ul>
<li class="toctree-l2"><a href="#5">5. References</a></li>
</ul>
</li>
</ul>
</div>
&nbsp;
</nav>
<section data-toggle="wy-nav-shift" class="wy-nav-content-wrap">
<nav class="wy-nav-top" role="navigation" aria-label="top navigation">
<i data-toggle="wy-nav-top" class="fa fa-bars"></i>
<a href="../..">PaddleSlim Docs</a>
</nav>
<div class="wy-nav-content">
<div class="rst-content">
<div role="navigation" aria-label="breadcrumbs navigation">
<ul class="wy-breadcrumbs">
<li><a href="../..">Docs</a> &raquo;</li>
<li>Algorithm Principles</li>
<li class="wy-breadcrumbs-aside">
<a href="https://github.com/PaddlePaddle/PaddleSlim/edit/master/docs/algo/algo.md"
class="icon icon-github"> Edit on GitHub</a>
</li>
</ul>
<hr/>
</div>
<div role="main">
<div class="section">
<h2 id="_1">Contents<a class="headerlink" href="#_1" title="Permanent link">#</a></h2>
<ul>
<li><a href="#1-quantization-aware-training">Quantization principles</a></li>
<li><a href="#2">Pruning principles</a></li>
<li><a href="#3">Distillation principles</a></li>
<li><a href="#4">Lightweight neural architecture search principles</a></li>
</ul>
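<h2 id="1-quantization-aware-training">1. Introduction to Quantization Aware Training<a class="headerlink" href="#1-quantization-aware-training" title="Permanent link">#</a></h2>
<h3 id="11">1.1 Background<a class="headerlink" href="#11" title="Permanent link">#</a></h3>
<p>In recent years, fixed-point quantization, which represents the weights and activations of a neural network with fewer bits (for example 8-bit, 3-bit, or 2-bit), has been shown to be effective. Its advantages include lower memory bandwidth, lower power consumption, lower compute resource usage, and smaller model storage.</p>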
<p align="center">
<img src="https://raw.githubusercontent.com/PaddlePaddle/PaddleSlim/develop/docs/docs/images/algo/quan_table_0.png" height=258 width=600 hspace='10'/> <br />
<strong>Table 1: Cost comparison of different types of operations</strong>
</p>
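<p>As Table 1 shows, low-precision fixed-point operations need several orders of magnitude less hardware area and energy than high-precision floating-point operations. Fixed-point quantization brings a 4x reduction in model size, a 4x improvement in memory bandwidth, and more efficient cache use (on many hardware devices, memory access dominates energy consumption). Computation is also faster, typically by 2x to 3x. As Table 2 shows, in many scenarios fixed-point quantization causes no loss of accuracy. Fixed-point quantization is also essential for neural network inference on embedded devices.</p>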
<p align="center">
<img src="https://raw.githubusercontent.com/PaddlePaddle/PaddleSlim/develop/docs/docs/images/algo/quan_table_1.png" height=155 width=500 hspace='10'/> <br />
<strong>Table 2: Model accuracy before and after quantization</strong>
</p>
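<p>Academia currently divides quantization into two main categories: <code>Post Training Quantization</code> and <code>Quantization Aware Training</code>. <code>Post Training Quantization</code> refers to fixed-point quantization methods that determine the quantization parameters with techniques such as KL divergence or moving averages, without retraining. <code>Quantization Aware Training</code> models quantization during training to determine the quantization parameters, and it can provide higher prediction accuracy than <code>Post Training Quantization</code>.</p>
<h3 id="12">1.2 Quantization principles<a class="headerlink" href="#12" title="Permanent link">#</a></h3>
<h4 id="121">1.2.1 Quantization methods<a class="headerlink" href="#121" title="Permanent link">#</a></h4>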
<p>There are many ways to quantize floating-point numbers into fixed-point numbers. One example is <code>min-max</code> quantization:
<script type="math/tex; mode=display"> r = min(max(x, a), b)</script>
<script type="math/tex; mode=display"> s = \frac{b - a}{n - 1} </script>
<script type="math/tex; mode=display"> q = \left \lfloor \frac{r - a}{s} \right \rceil </script>
Here, <span><span class="MathJax_Preview">x</span><script type="math/tex">x</script></span> is the floating-point value to be quantized, <span><span class="MathJax_Preview">[a, b]</span><script type="math/tex">[a, b]</script></span> is the quantization range, <span><span class="MathJax_Preview">a</span><script type="math/tex">a</script></span> is the minimum of the floating-point values to be quantized, and <span><span class="MathJax_Preview">b</span><script type="math/tex">b</script></span> is their maximum. <span><span class="MathJax_Preview">\left \lfloor \right \rceil</span><script type="math/tex">\left \lfloor \right \rceil</script></span> denotes rounding to the nearest integer. If the quantization level is <span><span class="MathJax_Preview">k</span><script type="math/tex">k</script></span>, then <span><span class="MathJax_Preview">n</span><script type="math/tex">n</script></span> is <span><span class="MathJax_Preview">2^k</span><script type="math/tex">2^k</script></span>. For example, if <span><span class="MathJax_Preview">k</span><script type="math/tex">k</script></span> is 8, then <span><span class="MathJax_Preview">n</span><script type="math/tex">n</script></span> is 256. <span><span class="MathJax_Preview">q</span><script type="math/tex">q</script></span> is the resulting quantized integer.
The quantization method used in the PaddleSlim framework is maximum absolute value quantization (<code>max-abs</code>), described as follows:
<script type="math/tex; mode=display"> M = max(abs(x)) </script>
<script type="math/tex; mode=display"> q = \left \lfloor \frac{x}{M} * (n - 1) \right \rceil </script>
Here, <span><span class="MathJax_Preview">x</span><script type="math/tex">x</script></span> is the floating-point value to be quantized, and <span><span class="MathJax_Preview">M</span><script type="math/tex">M</script></span> is the largest absolute value among the floating-point values to be quantized. <span><span class="MathJax_Preview">\left \lfloor \right \rceil</span><script type="math/tex">\left \lfloor \right \rceil</script></span> denotes rounding to the nearest integer. For 8-bit quantization, PaddleSlim uses <code>int8_t</code>, that is, <span><span class="MathJax_Preview">n=2^7=128</span><script type="math/tex">n=2^7=128</script></span>. <span><span class="MathJax_Preview">q</span><script type="math/tex">q</script></span> is the resulting quantized integer.
Both <code>min-max</code> quantization and <code>max-abs</code> quantization can be written in the following form:
<span><span class="MathJax_Preview">q = scale * r + b</span><script type="math/tex">q = scale * r + b</script></span>
where <code>min-max</code> and <code>max-abs</code> are referred to as the quantization parameters, the quantization scale, or the quantization range.</p>
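<p>The two quantization formulas above can be sketched in plain Python. This is an illustrative sketch, not PaddleSlim's implementation: the function names are hypothetical, ties are rounded away from zero (the text only says "round to the nearest integer"), and the input is assumed to contain at least two distinct values (min-max) or one nonzero value (max-abs).</p>

```python
def round_half_away(v):
    """Round to the nearest integer; ties go away from zero (an assumption)."""
    return int(v + 0.5) if v >= 0 else -int(-v + 0.5)


def min_max_quantize(values, k=8):
    """Min-max quantization: map [a, b] onto the integers 0 .. n - 1."""
    a, b = min(values), max(values)
    n = 2 ** k
    s = (b - a) / (n - 1)  # step size; assumes a != b
    # r = min(max(x, a), b) clamps x into the range [a, b]
    return [round_half_away((min(max(x, a), b) - a) / s) for x in values]


def max_abs_quantize(values, n=128):
    """Max-abs quantization: map [-M, M] onto -(n - 1) .. (n - 1)."""
    m = max(abs(x) for x in values)  # assumes at least one nonzero value
    return [round_half_away(x / m * (n - 1)) for x in values]


def max_abs_dequantize(q_values, m, n=128):
    """Recover approximate floating-point values from max-abs integers."""
    return [q / (n - 1) * m for q in q_values]


weights = [-1.0, -0.5, 0.0, 0.25, 2.0]
q = max_abs_quantize(weights)
print(q)
print(max_abs_dequantize(q, m=2.0))
```

<p>With <code>n = 128</code> the integers stay within the range of <code>int8_t</code>, which matches the 8-bit setting described above.</p>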
<h4 id="122">1.2.2 Quantization training<a class="headerlink" href="#122" title="Permanent link">#</a></h4>
<h5 id="1221">1.2.2.1 Forward propagation<a class="headerlink" href="#1221" title="Permanent link">#</a></h5>
<p>The forward pass uses simulated quantization, described as follows:</p>
<p align="center">
<img src="https://raw.githubusercontent.com/PaddlePaddle/PaddleSlim/develop/docs/docs/images/algo/quan_forward.png" height=433 width=335 hspace='10'/> <br />
<strong>Figure 1: Forward pass of simulated quantization training</strong>
</p>
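<p>As Figure 1 shows, the forward pass of simulated quantization training consists of four steps:</p>
<ol>
<li>
<p>The inputs and weights are quantized to 8-bit integers.</p>
</li>
<li>
<p>Matrix multiplication or convolution is performed on the 8-bit integers.</p>
</li>
<li>
<p>The output of the matrix multiplication or convolution is dequantized to 32-bit floating-point data.</p>
</li>
<li>
<p>Bias addition is performed on the 32-bit floating-point data. The bias is not quantized.</p>
</li>
</ol>
<p>For general matrix multiplication (<code>GEMM</code>), the quantization of the input <span><span class="MathJax_Preview">X</span><script type="math/tex">X</script></span> and the weight <span><span class="MathJax_Preview">W</span><script type="math/tex">W</script></span> can be expressed as follows, where <span><span class="MathJax_Preview">X_m</span><script type="math/tex">X_m</script></span> and <span><span class="MathJax_Preview">W_m</span><script type="math/tex">W_m</script></span> are the maximum absolute values of <span><span class="MathJax_Preview">X</span><script type="math/tex">X</script></span> and <span><span class="MathJax_Preview">W</span><script type="math/tex">W</script></span>:
<script type="math/tex; mode=display"> X_q = \left \lfloor \frac{X}{X_m} * (n - 1) \right \rceil </script>
<script type="math/tex; mode=display"> W_q = \left \lfloor \frac{W}{W_m} * (n - 1) \right \rceil </script>
Perform the general matrix multiplication:
<script type="math/tex; mode=display"> Y_q = X_q * W_q </script>
Dequantize the quantized product <span><span class="MathJax_Preview">Y_q</span><script type="math/tex">Y_q</script></span>:
<script type="math/tex; mode=display">
\begin{align}
Y_{dq} &= \frac{Y_q}{(n - 1) * (n - 1)} * X_m * W_m \\
&= \frac{X_q * W_q}{(n - 1) * (n - 1)} * X_m * W_m \\
&= (\frac{X_q}{n - 1} * X_m) * (\frac{W_q}{n - 1} * W_m)
\end{align}
</script>
The formula above shows that the dequantization can be moved before the <code>GEMM</code>: <span><span class="MathJax_Preview">X_q</span><script type="math/tex">X_q</script></span> and <span><span class="MathJax_Preview">W_q</span><script type="math/tex">W_q</script></span> are dequantized first, and the <code>GEMM</code> is then performed. The forward workflow can therefore also be expressed as follows:</p>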
<p align="center">
<img src="https://raw.githubusercontent.com/PaddlePaddle/PaddleSlim/develop/docs/docs/images/algo/quan_fwd_1.png" height=435 width=341 hspace='10'/> <br />
<strong>Figure 2: Equivalent workflow of the forward pass in simulated quantization training</strong>
</p>
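<p>During training, PaddleSlim uses the equivalent workflow shown in Figure 2. By design, the quantization Pass inserts quantize ops and dequantize ops into the IrGraph. Because the data is still 32-bit floating point after a consecutive quantize and dequantize operation, the quantization approach used by the PaddleSlim quantization training framework is called simulated quantization.</p>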
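<p>The equivalence between the two forward workflows can be checked numerically on a small example. The sketch below is illustrative only and uses plain Python lists; the helper names are hypothetical, and Python's built-in <code>round</code> (ties to even) stands in for the rounding operator.</p>

```python
def quantize(mat, n=128):
    """Max-abs quantize a matrix; return (integer matrix, max-abs value)."""
    m = max(abs(v) for row in mat for v in row)
    return [[round(v / m * (n - 1)) for v in row] for row in mat], m


def matmul(a, b):
    """Plain matrix product of two nested-list matrices."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]


def dequantize(mat, m, n=128):
    """Map quantized integers back to floating point."""
    return [[v / (n - 1) * m for v in row] for row in mat]


n = 128
X = [[0.5, -1.0], [0.25, 0.75]]
W = [[1.0, -2.0], [0.5, 1.5]]
Xq, Xm = quantize(X, n)
Wq, Wm = quantize(W, n)

# Workflow of Figure 1: integer GEMM first, then dequantize the product.
Yq = matmul(Xq, Wq)
Y_after = [[v / ((n - 1) * (n - 1)) * Xm * Wm for v in row] for row in Yq]

# Workflow of Figure 2: dequantize both operands, then a floating-point GEMM.
Y_before = matmul(dequantize(Xq, Xm, n), dequantize(Wq, Wm, n))

print(Y_after)
print(Y_before)
```

<p>Both results agree up to floating-point rounding, and both approximate the unquantized product of <code>X</code> and <code>W</code>.</p>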
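<h5 id="1222">1.2.2.2 Backward propagation<a class="headerlink" href="#1222" title="Permanent link">#</a></h5>
<p>As Figure 3 shows, the gradients needed for the weight update can be computed from the quantized weights and the quantized activations. All inputs and outputs of the backward pass are 32-bit floating-point data. Note that the update must be applied to the original weights: the computed gradients are added to the original weights, not to the quantized or dequantized weights.</p>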
<p align="center">
<img src="https://raw.githubusercontent.com/PaddlePaddle/PaddleSlim/develop/docs/docs/images/algo/quan_bwd.png" height=300 width=650 hspace='10'/> <br />
<strong>Figure 3: Backward propagation and weight update in simulated quantization training</strong>
</p>
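<p>Therefore, the quantization Pass also changes some inputs of the corresponding backward operators.</p>
<h5 id="1223">1.2.2.3 Determining the quantization scale<a class="headerlink" href="#1223" title="Permanent link">#</a></h5>
<p>There are two strategies for computing the quantization scale: dynamic and static. The dynamic strategy computes the quantization scale in every iteration. The static strategy uses the same quantization scale for different inputs.
For weights, the dynamic strategy is used during training. In other words, the quantization scale is recomputed in every iteration until training ends.
For activations, either the dynamic or the static strategy can be chosen. With the static strategy, the quantization scale is estimated during training and used during inference, where it stays the same for all inputs. During training, the static quantization scale can be estimated in the following three ways:</p>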
<ol>
<li>
<p>Compute the average of the activations' maximum absolute values within a window.</p>
</li>
<li>
<p>Compute the maximum of the activations' maximum absolute values within a window.</p>
</li>
<li>
<p>Compute the moving average of the activations' maximum absolute values within a window, using the following formula:</p>
</li>
</ol>
<div>
<div class="MathJax_Preview"> V_t = (1 - k) * V + k * V_{t-1} </div>
<script type="math/tex; mode=display"> V_t = (1 - k) * V + k * V_{t-1} </script>
</div>
<p>Here, <span><span class="MathJax_Preview">V</span><script type="math/tex">V</script></span> is the maximum absolute value of the current batch, <span><span class="MathJax_Preview">V_t</span><script type="math/tex">V_t</script></span> is the moving average, and <span><span class="MathJax_Preview">k</span><script type="math/tex">k</script></span> is a factor, for example 0.9.</p>
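<p>The moving-average estimate can be sketched as follows. This is a minimal illustration with a hypothetical function name; initializing the average with the first batch's value is an assumption, since the text does not say how the first value is set.</p>

```python
def moving_average_abs_max(batches, k=0.9):
    """Track V_t = (1 - k) * V + k * V_{t-1} over a stream of batches."""
    v_t = None
    history = []
    for batch in batches:
        v = max(abs(x) for x in batch)  # max-abs of the current batch
        # Assumption: the first batch initializes the average directly.
        v_t = v if v_t is None else (1 - k) * v + k * v_t
        history.append(v_t)
    return history


print(moving_average_abs_max([[1.0, -2.0], [0.5, 4.0]], k=0.5))
```

<p>A larger <code>k</code> makes the scale react more slowly to a single batch with unusually large activations.</p>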
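<h4 id="124">1.2.4 Post-training quantization<a class="headerlink" href="#124" title="Permanent link">#</a></h4>
<p>Post-training quantization computes the quantization scale factors from sampled data, using methods such as KL divergence. Compared with quantization training, it needs no retraining and produces a quantized model quickly.</p>
<p>The goal of post-training quantization is to obtain the quantization scale factor. There are two main methods: the non-saturating method (No Saturation) and the saturating method (Saturation). The non-saturating method computes the maximum absolute value <code>abs_max</code> of an FP32 Tensor and maps it to 127, so the quantization scale factor equals <code>abs_max/127</code>. The saturating method uses KL divergence to compute a suitable threshold <code>T</code> (<code>0&lt;T&lt;abs_max</code>) and maps it to 127, so the quantization scale factor equals <code>T/127</code>. In general, the weight Tensors of the ops to be quantized use the non-saturating method, and their activation Tensors (both inputs and outputs) use the saturating method.</p>
<p>Post-training quantization is carried out in the following steps:</p>
<ul>
<li>Load the pretrained FP32 model and configure the <code>DataLoader</code>.</li>
<li>Read sample data, run the forward inference of the model, and save the values of the activation Tensors of the ops to be quantized.</li>
<li>Based on the sampled activation Tensor data, compute each activation's quantization scale factor with the saturating method.</li>
<li>Keep the weight Tensor data unchanged, and use the non-saturating method to compute the maximum absolute value of each channel as that channel's quantization scale factor.</li>
<li>Convert the FP32 model into an INT8 model and save it.</li>
</ul>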
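<p>As an illustration of the non-saturating method applied per channel, the sketch below computes the scale factor <code>abs_max/127</code> for each output channel of a weight and quantizes the weight with it. The function names and the flat per-channel list layout are assumptions made for this example, every channel is assumed to contain a nonzero value, and the KL-based saturating method is not shown.</p>

```python
def per_channel_scales(weight, levels=127):
    """Non-saturating method: scale = abs_max / 127 for every output channel."""
    return [max(abs(v) for v in channel) / levels for channel in weight]


def quantize_weight(weight, scales):
    """Map each FP32 weight to an integer in [-127, 127] with its channel scale."""
    return [[round(v / s) for v in channel] for channel, s in zip(weight, scales)]


# Two output channels, each flattened into a list of FP32 values.
weight = [[1.27, -0.5], [0.0, 2.54]]
scales = per_channel_scales(weight)
print(scales)
print(quantize_weight(weight, scales))
```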
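<h2 id="2">2. Filter pruning principles<a class="headerlink" href="#2" title="Permanent link">#</a></h2>
<p>This strategy follows the paper <a href="https://arxiv.org/pdf/1608.08710.pdf">Pruning Filters for Efficient ConvNets</a>.</p>
<p>It reduces model size and computational complexity by reducing the number of filters in the convolution layers.</p>
<h3 id="21">2.1 Pruning filters<a class="headerlink" href="#21" title="Permanent link">#</a></h3>
<p><strong>Pruning note 1</strong>
Pruning a filter of a conv layer requires modifying the filters of the following conv layer. As <strong>Figure 4</strong> shows, pruning one filter of <span><span class="MathJax_Preview">X_i</span><script type="math/tex">X_i</script></span> removes one channel from <span><span class="MathJax_Preview">X_{i+1}</span><script type="math/tex">X_{i+1}</script></span>, so the filters of <span><span class="MathJax_Preview">X_{i+1}</span><script type="math/tex">X_{i+1}</script></span> must also shrink by 1 in the input_channel dimension.</p>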
<p align="center">
<img src="https://raw.githubusercontent.com/PaddlePaddle/PaddleSlim/develop/docs/docs/images/algo/pruning_0.png" height=200 width=600 hspace='10'/> <br />
<strong>Figure 4</strong>
</p>
<p><strong>Pruning note 2</strong></p>
<p>As <strong>Figure 5</strong> shows, after pruning <span><span class="MathJax_Preview">X_i</span><script type="math/tex">X_i</script></span>, note 1 requires deleting one row (the blue row in the figure) from the filters of <span><span class="MathJax_Preview">X_{i+1}</span><script type="math/tex">X_{i+1}</script></span>. When computing the l1_norm of the filters of <span><span class="MathJax_Preview">X_{i+1}</span><script type="math/tex">X_{i+1}</script></span> (the green column in the figure), there are two options:
include the deleted row: independent pruning;
exclude the deleted row: greedy pruning.</p>
<p align="center">
<img src="https://raw.githubusercontent.com/PaddlePaddle/PaddleSlim/develop/docs/docs/images/algo/pruning_1.png" height=200 width=450 hspace='10'/> <br />
<strong>Figure 5</strong>
</p>
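<p><strong>Pruning note 3</strong>
When pruning complex networks such as ResNet, the effect of modifying the current convolution layer on the previous convolution layer must also be considered.
As <strong>Figure 6</strong> shows, when pruning a residual block, how the <span><span class="MathJax_Preview">X_{i+1}</span><script type="math/tex">X_{i+1}</script></span> layer is pruned depends on the pruning result of the projection shortcut, because the output of the projection shortcut and the output of <span><span class="MathJax_Preview">X_{i+1}</span><script type="math/tex">X_{i+1}</script></span> must still be combined correctly.</p>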
<p align="center">
<img src="https://raw.githubusercontent.com/PaddlePaddle/PaddleSlim/develop/docs/docs/images/algo/pruning_2.png" height=240 width=600 hspace='10'/> <br />
<strong>Figure 6</strong>
</p>
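<h3 id="22-uniform">2.2 Uniform pruning of convolutional networks<a class="headerlink" href="#22-uniform" title="Permanent link">#</a></h3>
<p>Every layer is pruned by the same ratio of filters.
Before a layer is pruned, its filters are sorted by l1_norm from high to low. The further back a filter is in this order, the less important it is, so the filters at the back are pruned first.</p>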
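<p>The l1_norm ranking, together with pruning note 1 (removing the matching input channels of the next layer), can be sketched as below. This is a simplified illustration on nested lists, not the PaddleSlim pruner: the function names are hypothetical and each filter is stored as <code>[input_channel][kernel_values]</code>.</p>

```python
def l1_norm(filt):
    """Sum of absolute values of one filter, stored as [in_channel][kernel]."""
    return sum(abs(v) for channel in filt for v in channel)


def prune_layer(filters, ratio):
    """Drop the `ratio` fraction of filters with the smallest l1_norm.

    Returns the kept filters and the sorted indices of the pruned ones.
    """
    num_pruned = int(len(filters) * ratio)
    # Sort filter indices by l1_norm from high to low; the tail is pruned.
    order = sorted(range(len(filters)), key=lambda i: l1_norm(filters[i]), reverse=True)
    pruned = sorted(order[len(filters) - num_pruned:])
    kept = [f for i, f in enumerate(filters) if i not in pruned]
    return kept, pruned


def drop_input_channels(next_filters, pruned):
    """Remove the matching input channels from the next layer's filters."""
    return [[ch for i, ch in enumerate(f) if i not in pruned] for f in next_filters]


layer = [[[1.0, -1.0]], [[0.1, 0.1]], [[3.0, 0.0]], [[0.2, -0.1]]]
next_layer = [[[1], [2], [3], [4]], [[5], [6], [7], [8]]]
kept, pruned = prune_layer(layer, ratio=0.5)
print(pruned)
print(drop_input_channels(next_layer, pruned))
```

<p>Uniform pruning applies the same <code>ratio</code> to every layer, while the sensitivity-based approach chooses a different <code>ratio</code> per layer.</p>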
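<h3 id="23">2.3 Sensitivity-based pruning of convolutional networks<a class="headerlink" href="#23" title="Permanent link">#</a></h3>
<p>Different ratios of filters are pruned from each convolution layer according to its sensitivity.</p>
<h4 id="_2">Two assumptions<a class="headerlink" href="#_2" title="Permanent link">#</a></h4>
<ul>
<li>Within the parameters of a conv layer, filters are sorted by l1_norm from high to low; the further back a filter is, the less important it is.</li>
<li>When the same ratio of filters is pruned from two layers, the layer whose pruning affects model accuracy more is said to have the higher sensitivity.</li>
</ul>
<h4 id="filter">Guidelines for pruning filters<a class="headerlink" href="#filter" title="Permanent link">#</a></h4>
<ul>
<li>The pruning ratio of a layer is inversely related to its sensitivity.</li>
<li>Within a layer, filters with a relatively low l1_norm are pruned first.</li>
</ul>
<h4 id="_3">Understanding sensitivity<a class="headerlink" href="#_3" title="Permanent link">#</a></h4>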
<p align="center">
<img src="https://raw.githubusercontent.com/PaddlePaddle/PaddleSlim/develop/docs/docs/images/algo/pruning_3.png" height=200 width=400 hspace='10'/> <br />
<strong>Figure 7</strong>
</p>
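<p>As <strong>Figure 7</strong> shows, the horizontal axis is the ratio of filters pruned, the vertical axis is the accuracy loss, and each colored dashed line represents one convolution layer of the network.
Each convolution layer is pruned <strong>on its own</strong> at different pruning ratios, the accuracy loss on the validation set is observed, and the result is plotted as a dashed line in <strong>Figure 7</strong>. A line that rises more slowly corresponds to a less sensitive convolution layer, and the filters of less sensitive layers are pruned first.</p>
<h4 id="_4">Choosing the best combination of pruning ratios<a class="headerlink" href="#_4" title="Permanent link">#</a></h4>
<p>The polylines in <strong>Figure 7</strong> are fitted to the curves in <strong>Figure 8</strong>. Each accuracy loss value chosen on the vertical axis corresponds to a set of pruning ratios on the horizontal axis, as shown by the solid black line in <strong>Figure 8</strong>.
Given an overall pruning ratio for the model from the user, we move the solid black line in <strong>Figure 8</strong> to find a valid set of pruning ratios that satisfies the requirement.</p>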
<p align="center">
<img src="https://raw.githubusercontent.com/PaddlePaddle/PaddleSlim/develop/docs/docs/images/algo/pruning_4.png" height=200 width=400 hspace='10'/> <br />
<strong>Figure 8</strong>
</p>
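<p>Moving the black line of Figure 8 can be read as a one-dimensional search over the accuracy loss threshold. The sketch below is one possible interpretation, not the PaddleSlim implementation. It assumes each layer's sensitivity curve is given as piecewise-linear <code>(ratio, loss)</code> points with non-decreasing loss, and that the overall pruning ratio is the parameter-weighted average of the per-layer ratios; all names are hypothetical.</p>

```python
def ratio_at_loss(curve, loss):
    """Largest pruning ratio whose interpolated accuracy loss stays <= loss.

    `curve` is a list of (ratio, loss) points sorted by ratio, with
    non-decreasing loss, as measured for a single layer.
    """
    if loss < curve[0][1]:
        return 0.0
    for (r0, l0), (r1, l1) in zip(curve, curve[1:]):
        if l0 <= loss < l1:
            return r0 + (r1 - r0) * (loss - l0) / (l1 - l0)
    return curve[-1][0]


def overall_ratio(curves, sizes, loss):
    """Parameter-weighted pruned fraction of the whole model at a loss level."""
    pruned = sum(sizes[name] * ratio_at_loss(curve, loss) for name, curve in curves.items())
    return pruned / sum(sizes.values())


def search_ratios(curves, sizes, target, steps=50):
    """Bisect the loss threshold until the overall pruned fraction reaches target."""
    lo, hi = 0.0, max(curve[-1][1] for curve in curves.values())
    for _ in range(steps):
        mid = (lo + hi) / 2
        if overall_ratio(curves, sizes, mid) < target:
            lo = mid
        else:
            hi = mid
    return {name: ratio_at_loss(curve, hi) for name, curve in curves.items()}


curves = {
    "conv1": [(0.0, 0.0), (0.5, 0.1), (0.9, 0.5)],   # more sensitive layer
    "conv2": [(0.0, 0.0), (0.5, 0.02), (0.9, 0.1)],  # less sensitive layer
}
sizes = {"conv1": 100, "conv2": 100}
print(search_ratios(curves, sizes, target=0.5))
```

<p>In this toy setting the less sensitive layer <code>conv2</code> receives the larger pruning ratio, in line with the guidelines above.</p>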
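<h4 id="_5">Iterative pruning<a class="headerlink" href="#_5" title="Permanent link">#</a></h4>
<p>Because convolution layers are correlated, modifying one layer may change the sensitivity of others. We therefore prune in multiple rounds, as follows:</p>
<ul>
<li>step1: Collect the sensitivity information of every convolution layer.</li>
<li>step2: Based on the current sensitivity information, prune a small number of filters from each convolution layer and compute the FLOPs. If the FLOPs meet the requirement, go to step4; otherwise go to step3.</li>
<li>step3: Briefly fine-tune the network, then go to step1.</li>
<li>step4: Fine-tune until convergence.</li>
</ul>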
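<h2 id="3">3. Distillation<a class="headerlink" href="#3" title="Permanent link">#</a></h2>
<p>In general, the more parameters a model has and the more complex its structure, the better it performs, but the more redundant its parameters are and the more computation and resources it consumes. Model distillation extracts the useful information from a complex network and transfers it to a smaller network. Our toolkit supports two distillation methods.
The first is the traditional distillation method (reference paper: <a href="https://arxiv.org/pdf/1503.02531.pdf">Distilling the Knowledge in a Neural Network</a>).
A complex network is used as the teacher model to supervise the training of a student model with fewer parameters and less computation. The teacher can be one or more pretrained high-performance models. Training the student has two objectives. One is the original objective function, the cross entropy between the class probabilities output by the student and the label, called the hard target. The other is the cross entropy between the class probabilities output by the student and those output by the teacher, called the soft target. The two losses are weighted and summed into the final training loss, which jointly supervises the training of the student model.
The second is the FSP-based distillation method (reference paper: <a href="http://openaccess.thecvf.com/content_cvpr_2017/papers/Yim_A_Gift_From_CVPR_2017_paper.pdf">A Gift from Knowledge Distillation:
Fast Optimization, Network Minimization and Transfer Learning</a>).
Whereas traditional distillation makes the small model fit the outputs of the large model directly, this method makes the small model fit the transformations between the features of different layers of the large model. It uses an FSP matrix (the inner product of features) to represent the relationship between the features of different layers. Several FSP matrices are obtained between layers of the large model and of the small model, and an L2 loss then drives each FSP matrix of the small model toward the FSP matrix of the corresponding layers of the large model, as shown in the figure below. The advantage of this method can be explained informally: if distillation is seen as a teacher (large model) teaching a student (small model) to solve a problem, traditional distillation tells the small model the answer directly, while learning FSP matrices lets the small model learn the intermediate process and method of solving the problem, so it learns more information.</p>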
<p align="center">
<img src="https://raw.githubusercontent.com/PaddlePaddle/PaddleSlim/develop/docs/docs/images/algo/distillation_0.png" height=300 width=600 hspace='10'/> <br />
<strong>Figure 9</strong>
</p>
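<p>Because the small model is supervised by the large model through an L2 loss, the two FSP matrices must have the same dimensions. An FSP matrix has dimensions M*N, where M and N are the channel counts of the input and output features, so the FSP matrices of the large model and the small model must correspond one to one.</p>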
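<p>Both distillation losses can be sketched in plain Python. The code below is an illustrative sketch, not the PaddleSlim API: the weight <code>alpha</code>, the function names, the absence of a softmax temperature, and the <code>[position][channel]</code> layout of the feature maps are all assumptions made for the example.</p>

```python
import math


def softmax(logits):
    """Numerically stable softmax over a list of logits."""
    peak = max(logits)
    exps = [math.exp(v - peak) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]


def cross_entropy(target_probs, predicted_probs):
    """Cross entropy between a target and a predicted distribution."""
    return -sum(t * math.log(p) for t, p in zip(target_probs, predicted_probs) if t > 0)


def distillation_loss(student_logits, teacher_logits, label, alpha=0.5):
    """Weighted sum of the hard-target and soft-target cross entropies."""
    student = softmax(student_logits)
    one_hot = [1.0 if i == label else 0.0 for i in range(len(student))]
    hard = cross_entropy(one_hot, student)                   # student vs. label
    soft = cross_entropy(softmax(teacher_logits), student)   # student vs. teacher
    return alpha * hard + (1 - alpha) * soft


def fsp_matrix(feat_a, feat_b):
    """FSP matrix (M x N) of two feature maps stored as [position][channel].

    Entry (i, j) is the inner product of channel i of feat_a and channel j
    of feat_b over all spatial positions, averaged over the positions.
    """
    positions = len(feat_a)
    m, n = len(feat_a[0]), len(feat_b[0])
    return [[sum(feat_a[p][i] * feat_b[p][j] for p in range(positions)) / positions
             for j in range(n)] for i in range(m)]


def fsp_l2_loss(student_fsp, teacher_fsp):
    """Mean squared difference between two FSP matrices of equal shape."""
    diffs = [(s - t) ** 2 for rs, rt in zip(student_fsp, teacher_fsp) for s, t in zip(rs, rt)]
    return sum(diffs) / len(diffs)


print(distillation_loss([2.0, 0.5], [3.0, 0.1], label=0))
print(fsp_matrix([[1.0, 2.0], [3.0, 4.0]], [[1.0], [1.0]]))
```

<p>The FSP loss only needs the channel counts of the paired layers to match, since the spatial positions are summed out when the matrix is built.</p>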
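<h2 id="4">4. Lightweight neural architecture search<a class="headerlink" href="#4" title="Permanent link">#</a></h2>
<p>Deep learning models have achieved good results on many tasks, and the quality of the network structure has a very important effect on the final model. Designing networks by hand requires rich experience and many attempts, and the large number of hyperparameters and network structure parameters leads to a combinatorial explosion that makes ordinary random search nearly infeasible. Automatic model search (Neural Architecture Search) has therefore become a research focus in recent years. Unlike traditional NAS, we focus on searching for model structures that are both accurate and fast, and we refer to this capability as Light-NAS.</p>
<h3 id="41">4.1 Search strategy<a class="headerlink" href="#41" title="Permanent link">#</a></h3>
<p>The search strategy defines which algorithm is used to find the best network structure configuration quickly and accurately. Common search methods include reinforcement learning, Bayesian optimization, evolutionary algorithms, and gradient-based algorithms. Our current implementation is based mainly on simulated annealing.</p>
<h4 id="411">4.1.1 Simulated annealing<a class="headerlink" href="#411" title="Permanent link">#</a></h4>
<p>Simulated annealing derives from the annealing of solids. A solid is heated to a sufficiently high temperature and then cooled slowly. During heating, the particles inside the solid become disordered as the temperature rises and the internal energy increases. During slow cooling, the particles become ordered again, reach equilibrium at each temperature, and finally reach the ground state at room temperature, where the internal energy is minimal.</p>
<p>Given the similarity between the physical annealing of solids and general combinatorial optimization problems, we apply it to network structure search.</p>
<p>The process of searching for a model with simulated annealing is as follows:</p>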
<div>
<div class="MathJax_Preview">
T_k = T_0*\theta^k
</div>
<script type="math/tex; mode=display">
T_k = T_0*\theta^k
</script>
</div>
<div>
<div class="MathJax_Preview">\begin{equation}
P(r_k) =
\begin{cases}
e^{\frac{(r_k-r)}{T_k}} &amp; r_k &lt; r\\
1 &amp; r_k&gt;=r
\end{cases}
\end{equation}</div>
<script type="math/tex; mode=display">\begin{equation}
P(r_k) =
\begin{cases}
e^{\frac{(r_k-r)}{T_k}} & r_k < r\\
1 & r_k>=r
\end{cases}
\end{equation}</script>
</div>
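<p>In iteration k, the network found is <span><span class="MathJax_Preview">N_k</span><script type="math/tex">N_k</script></span>. After <span><span class="MathJax_Preview">N_k</span><script type="math/tex">N_k</script></span> is trained for several epochs, it obtains a reward <span><span class="MathJax_Preview">r_k</span><script type="math/tex">r_k</script></span> on the test set. <span><span class="MathJax_Preview">r_k</span><script type="math/tex">r_k</script></span> is accepted with probability <span><span class="MathJax_Preview">P(r_k)</span><script type="math/tex">P(r_k)</script></span>, that is, <span><span class="MathJax_Preview">r=r_k</span><script type="math/tex">r=r_k</script></span> is executed. <span><span class="MathJax_Preview">r</span><script type="math/tex">r</script></span> is initialized to 0 at the start of the search. <span><span class="MathJax_Preview">T_0</span><script type="math/tex">T_0</script></span> is the initial temperature, <span><span class="MathJax_Preview">\theta</span><script type="math/tex">\theta</script></span> is the temperature decay coefficient, and <span><span class="MathJax_Preview">T_k</span><script type="math/tex">T_k</script></span> is the temperature in iteration k.</p>
<p>In our NAS task, unlike RL, which regenerates a complete network every time, we map the network structure to an encoding. The encoding is initialized randomly the first time. In each later iteration, part of the encoding (corresponding to part of the network structure) is modified randomly to produce a new encoding, which is then mapped back to a network structure. The reward, which guides the convergence of the annealing algorithm, combines the accuracy reached after training for a certain number of epochs on the training set with the network latency.</p>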
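<p>The annealing loop described above can be sketched as follows. This is a toy illustration rather than the Light-NAS code: <code>mutate_one</code> and <code>toy_reward</code> are hypothetical stand-ins for the real encoding mutation and for the reward that would come from training and latency measurement.</p>

```python
import math
import random


def temperature(k, t0=1.0, theta=0.9):
    """T_k = T_0 * theta^k."""
    return t0 * theta ** k


def accept_probability(r_k, r, t_k):
    """P(r_k): 1 if r_k >= r, otherwise exp((r_k - r) / T_k)."""
    return 1.0 if r_k >= r else math.exp((r_k - r) / t_k)


def anneal(init_tokens, mutate, reward_fn, iterations, t0=1.0, theta=0.9, seed=0):
    """Search over encodings; return the best encoding seen and its reward."""
    rng = random.Random(seed)
    current, r = list(init_tokens), 0.0  # r starts at 0, as in the text
    best, best_reward = list(init_tokens), float("-inf")
    for k in range(iterations):
        candidate = mutate(current, rng)  # modify part of the encoding
        r_k = reward_fn(candidate)
        if r_k > best_reward:
            best, best_reward = candidate, r_k
        if rng.random() < accept_probability(r_k, r, temperature(k, t0, theta)):
            current, r = candidate, r_k   # accept: r = r_k
    return best, best_reward


def mutate_one(tokens, rng):
    """Hypothetical mutation: re-sample one randomly chosen token in 0..7."""
    new_tokens = list(tokens)
    new_tokens[rng.randrange(len(new_tokens))] = rng.randrange(8)
    return new_tokens


def toy_reward(tokens):
    """Hypothetical reward in (0, 1]; it peaks when every token equals 3."""
    return 1.0 / (1.0 + sum((t - 3) ** 2 for t in tokens))


best, best_reward = anneal([0, 0, 0, 0], mutate_one, toy_reward, iterations=200)
print(best, best_reward)
```

<p>As the temperature decays, worse candidates are accepted less often, so the search gradually shifts from exploration to refinement.</p>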
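<h3 id="42">4.2 Search space<a class="headerlink" href="#42" title="Permanent link">#</a></h3>
<p>The search space defines the variables of the optimization problem, and the number of variables determines the difficulty and duration of the search. Defining a reasonable search space is therefore essential for speeding up the search. In Light-NAS, to speed up the search, we divide a network into multiple blocks. The blocks are first stacked manually in a chain-like hierarchical structure, and the search algorithm then automatically searches the internal structure of each block.</p>
<p>Because the goal is to find models that run fast on mobile devices, we follow the Linear Bottlenecks and Inverted residuals structures of MobileNetV2 and search the specific parameters of each Inverted residual, including kernel size, channel expansion ratio, number of repeats, and number of channels, as shown in Figure 10:</p>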
<p align="center">
<img src="https://raw.githubusercontent.com/PaddlePaddle/PaddleSlim/develop/docs/docs/images/algo/light-nas-block.png" height=300 width=600 hspace='10'/> <br />
<strong>Figure 10</strong>
</p>
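<h3 id="43">4.3 Model latency evaluation<a class="headerlink" href="#43" title="Permanent link">#</a></h3>
<p>The search process supports FLOPs constraints and model latency constraints. On hardware platforms such as Android/iOS mobile devices and development boards, repeatedly measuring model latency during the iterative search is both time-consuming and inconvenient, so we developed a model latency evaluator to estimate the latency of the models found by the search. The latency estimated by the evaluator deviates from the latency measured on the actual model by less than 10%.</p>
<p>The latency evaluator works in two stages: configuring the hardware latency evaluator and evaluating model latency. Configuring the hardware latency evaluator needs to be done only once, while model latency evaluation runs repeatedly during the search on each model found.</p>
<ul>
<li>
<p>Configuring the hardware latency evaluator</p>
<ol>
<li>Collect all distinct ops and their parameters in the search space.</li>
<li>Obtain the latency of each op and parameter combination.</li>
</ol>
</li>
<li>
<p>Evaluating model latency</p>
<ol>
<li>Collect all ops and their parameters of the given model.</li>
<li>Based on all ops and parameters of the given model, use the latency evaluator to estimate the latency of the model.</li>
</ol>
</li>
</ul>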
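<p>The two stages can be pictured as building a lookup table once and then querying it for every candidate model. The sketch below is a simplified illustration with hypothetical names. It assumes the model latency is estimated as the sum of the per-op latencies, which the text does not state, and it is not the <code>TableLatencyEvaluator</code> API.</p>

```python
def build_latency_table(op_configs, measure):
    """Stage 1 (run once): measure every distinct (op, params) combination."""
    return {config: measure(config) for config in set(op_configs)}


def estimate_latency(model_ops, table):
    """Stage 2: estimate model latency as the sum of its ops' table entries."""
    return sum(table[config] for config in model_ops)


def fake_measure(config):
    """Hypothetical stand-in for a real on-device measurement, in milliseconds."""
    return 2.0 if config[0] == "conv2d" else 0.5


conv = ("conv2d", 3, 16)  # (op type, kernel size, channels), an assumed encoding
relu = ("relu",)
table = build_latency_table([conv, conv, relu], fake_measure)
print(estimate_latency([conv, relu, conv], table))
```

<p>Because the table is keyed by op and parameters, each distinct configuration is measured on the device only once, no matter how many searched models contain it.</p>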
<h2 id="5">5. References<a class="headerlink" href="#5" title="Permanent link">#</a></h2>
<ol>
<li>
<p><a href="https://media.nips.cc/Conferences/2015/tutorialslides/Dally-NIPS-Tutorial-2015.pdf">High-Performance Hardware for Machine Learning</a></p>
</li>
<li>
<p><a href="https://arxiv.org/pdf/1806.08342.pdf">Quantizing deep convolutional networks for efficient inference: A whitepaper</a></p>
</li>
<li>
<p><a href="https://arxiv.org/pdf/1608.08710.pdf">Pruning Filters for Efficient ConvNets</a></p>
</li>
<li>
<p><a href="https://arxiv.org/pdf/1503.02531.pdf">Distilling the Knowledge in a Neural Network</a></p>
</li>
<li>
<p><a href="http://openaccess.thecvf.com/content_cvpr_2017/papers/Yim_A_Gift_From_CVPR_2017_paper.pdf">A Gift from Knowledge Distillation: Fast Optimization, Network Minimization and Transfer Learning</a></p>
</li>
</ol>
</div>
</div>
<footer>
<div class="rst-footer-buttons" role="navigation" aria-label="footer navigation">
<a href="../../table_latency/" class="btn btn-neutral" title="Hardware Latency Table"><span class="icon icon-circle-arrow-left"></span> Previous</a>
</div>
<hr/>
<div role="contentinfo">
<!-- Copyright etc -->
</div>
Built with <a href="http://www.mkdocs.org">MkDocs</a> using a <a href="https://github.com/snide/sphinx_rtd_theme">theme</a> provided by <a href="https://readthedocs.org">Read the Docs</a>.
</footer>
</div>
</div>
</section>
</div>
<div class="rst-versions" role="note" style="cursor: pointer">
<span class="rst-current-version" data-toggle="rst-current-version">
<a href="https://github.com/PaddlePaddle/PaddleSlim/" class="fa fa-github" style="float: left; color: #fcfcfc"> GitHub</a>
<span><a href="../../table_latency/" style="color: #fcfcfc;">&laquo; Previous</a></span>
</span>
</div>
<script>var base_url = '../..';</script>
<script src="../../js/theme.js" defer></script>
<script src="../../mathjax-config.js" defer></script>
<script src="https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.0/MathJax.js?config=TeX-AMS-MML_HTMLorMML" defer></script>
<script src="../../search/main.js" defer></script>
</body>
</html>
......@@ -8,7 +8,7 @@
<link rel="shortcut icon" href="../../img/favicon.ico">
<title>Sensitivity Analysis - PaddleSlim Docs</title>
<title>Model Analysis - PaddleSlim Docs</title>
<link href='https://fonts.googleapis.com/css?family=Lato:400,700|Roboto+Slab:400,700|Inconsolata:400,700' rel='stylesheet' type='text/css'>
<link rel="stylesheet" href="../../css/theme.css" type="text/css" />
......@@ -17,7 +17,7 @@
<script>
// Current page data
var mkdocs_page_name = "\u654f\u611f\u5ea6\u5206\u6790";
var mkdocs_page_name = "\u6a21\u578b\u5206\u6790";
var mkdocs_page_input_path = "api/analysis_api.md";
var mkdocs_page_url = null;
</script>
......@@ -73,6 +73,10 @@
<a class="" href="../../tutorials/nas_demo/">SA Search</a>
</li>
<li class="">
<a class="" href="../../tutorials/distillation_demo/">Knowledge Distillation</a>
</li>
</ul>
</li>
......@@ -86,31 +90,27 @@
</li>
<li class="">
<a class="" href="../prune_api/">Pruning</a>
<a class="" href="../prune_api/">Pruning and Sensitivity</a>
</li>
<li class=" current">
<a class="current" href="./">Sensitivity Analysis</a>
<a class="current" href="./">Model Analysis</a>
<ul class="subnav">
<li class="toctree-l3"><a href="#api">Model Analysis API Documentation</a></li>
<li class="toctree-l3"><a href="#flops">FLOPs</a></li>
<ul>
<li><a class="toctree-l4" href="#flops">flops</a></li>
<li class="toctree-l3"><a href="#model_size">model_size</a></li>
<li><a class="toctree-l4" href="#model_size">model_size</a></li>
<li><a class="toctree-l4" href="#tablelatencyevaluator">TableLatencyEvaluator</a></li>
</ul>
<li class="toctree-l3"><a href="#tablelatencyevaluator">TableLatencyEvaluator</a></li>
</ul>
</li>
<li class="">
<a class="" href="../single_distiller_api/">Distillation</a>
<a class="" href="../single_distiller_api/">Knowledge Distillation</a>
</li>
<li class="">
......@@ -120,9 +120,18 @@
<a class="" href="../search_space/">Search Space</a>
</li>
<li class="">
<a class="" href="../../table_latency/">Hardware Latency Table</a>
</li>
</ul>
</li>
<li class="toctree-l1">
<a class="" href="../../algo/algo/">Algorithm Principles</a>
</li>
</ul>
</div>
&nbsp;
......@@ -149,7 +158,7 @@
<li>Sensitivity Analysis</li>
<li>Model Analysis</li>
<li class="wy-breadcrumbs-aside">
<a href="https://github.com/PaddlePaddle/PaddleSlim/edit/master/docs/api/analysis_api.md"
......@@ -162,28 +171,32 @@
<div role="main">
<div class="section">
<h1 id="api">Model Analysis API Documentation<a class="headerlink" href="#api" title="Permanent link">#</a></h1>
<h2 id="flops">flops<a class="headerlink" href="#flops" title="Permanent link">#</a></h2>
<blockquote>
<p>paddleslim.analysis.flops(program, detail=False) <a href="">Source code</a></p>
</blockquote>
<p>Get the floating-point operations per second (FLOPS) of the given network.</p>
<h2 id="flops">FLOPs<a class="headerlink" href="#flops" title="Permanent link">#</a></h2>
<dl>
<dt>paddleslim.analysis.flops(program, detail=False) <a href="https://github.com/PaddlePaddle/PaddleSlim/blob/develop/paddleslim/analysis/flops.py">Source code</a></dt>
<dd>
<p>Get the number of floating-point operations (FLOPs) of the given network.</p>
</dd>
</dl>
<p><strong>Parameters:</strong></p>
<ul>
<li>
<p><strong>program(paddle.fluid.Program):</strong> The target network to analyze. For more about Program, see <a href="https://www.paddlepaddle.org.cn/documentation/docs/zh/api_cn/fluid_cn/Program_cn.html#program">Introduction to the Program concept</a></p>
<p><strong>program(paddle.fluid.Program)</strong> - The target network to analyze. For more about Program, see <a href="https://www.paddlepaddle.org.cn/documentation/docs/zh/api_cn/fluid_cn/Program_cn.html#program">Introduction to the Program concept</a></p>
</li>
<li>
<p><strong>detail(bool)</strong> - Whether to return the FLOPs of each convolution layer. Defaults to False.</p>
</li>
<li>
<p><strong>detail(bool):</strong> Whether to return the FLOPS of each convolution layer. Defaults to False</p>
<p><strong>only_conv(bool)</strong> - If set to True, only the FLOPs of convolution layers and fully connected layers are counted, that is, the number of floating-point multiply-add (multiplication-adds) operations. If set to False, the FLOPs of operations other than convolution and fully connected layers are counted as well</p>
</li>
</ul>
<p><strong>Returns:</strong></p>
<ul>
<li>
<p><strong>flops(float):</strong> The FLOPS of the whole network</p>
<p><strong>flops(float)</strong> - The FLOPs of the whole network</p>
</li>
<li>
<p><strong>params2flops(dict):</strong> The FLOPS of each convolution layer, where the key is the parameter name of the convolution layer and the value is the FLOPS value.</p>
<p><strong>params2flops(dict)</strong> - The FLOPs of each convolution layer, where the key is the parameter name of the convolution layer and the value is the FLOPs value.</p>
</li>
</ul>
<p><strong>Example:</strong></p>
......@@ -291,22 +304,20 @@
<span class="n">conv5</span> <span class="o">=</span> <span class="n">conv_bn_layer</span><span class="p">(</span><span class="n">sum2</span><span class="p">,</span> <span class="mi">8</span><span class="p">,</span> <span class="mi">3</span><span class="p">,</span> <span class="s2">&quot;conv5&quot;</span><span class="p">)</span>
<span class="n">conv6</span> <span class="o">=</span> <span class="n">conv_bn_layer</span><span class="p">(</span><span class="n">conv5</span><span class="p">,</span> <span class="mi">8</span><span class="p">,</span> <span class="mi">3</span><span class="p">,</span> <span class="s2">&quot;conv6&quot;</span><span class="p">)</span>
<span class="k">print</span><span class="p">(</span><span class="s2">&quot;FLOPS: {}&quot;</span><span class="o">.</span><span class="n">format</span><span class="p">(</span><span class="n">flops</span><span class="p">(</span><span class="n">main_program</span><span class="p">)))</span>
<span class="k">print</span><span class="p">(</span><span class="s2">&quot;FLOPs: {}&quot;</span><span class="o">.</span><span class="n">format</span><span class="p">(</span><span class="n">flops</span><span class="p">(</span><span class="n">main_program</span><span class="p">)))</span>
</pre></div>
</td></tr></table>
<h2 id="model_size">model_size<a class="headerlink" href="#model_size" title="Permanent link">#</a></h2>
<blockquote>
<p>paddleslim.analysis.model_size(program) <a href="">Source code</a></p>
</blockquote>
<p>paddleslim.analysis.model_size(program) <a href="https://github.com/PaddlePaddle/PaddleSlim/blob/develop/paddleslim/analysis/model_size.py">Source code</a></p>
<p>Get the number of parameters of the given network.</p>
<p><strong>Parameters:</strong></p>
<ul>
<li><strong>program(paddle.fluid.Program):</strong> The target network to analyze. For more about Program, see <a href="https://www.paddlepaddle.org.cn/documentation/docs/zh/api_cn/fluid_cn/Program_cn.html#program">Introduction to the Program concept</a></li>
<li><strong>program(paddle.fluid.Program)</strong> - The target network to analyze. For more about Program, see <a href="https://www.paddlepaddle.org.cn/documentation/docs/zh/api_cn/fluid_cn/Program_cn.html#program">Introduction to the Program concept</a></li>
</ul>
<p><strong>Returns:</strong></p>
<ul>
<li><strong>model_size(int):</strong> The number of parameters of the whole network.</li>
<li><strong>model_size(int)</strong> - The number of parameters of the whole network.</li>
</ul>
<p><strong>Example:</strong></p>
<table class="codehilitetable"><tr><td class="linenos"><div class="linenodiv"><pre><span></span> 1
......@@ -397,39 +408,43 @@
<span class="n">conv5</span> <span class="o">=</span> <span class="n">conv_layer</span><span class="p">(</span><span class="n">sum2</span><span class="p">,</span> <span class="mi">8</span><span class="p">,</span> <span class="mi">3</span><span class="p">,</span> <span class="s2">&quot;conv5&quot;</span><span class="p">)</span>
<span class="n">conv6</span> <span class="o">=</span> <span class="n">conv_layer</span><span class="p">(</span><span class="n">conv5</span><span class="p">,</span> <span class="mi">8</span><span class="p">,</span> <span class="mi">3</span><span class="p">,</span> <span class="s2">&quot;conv6&quot;</span><span class="p">)</span>
<span class="k">print</span><span class="p">(</span><span class="s2">&quot;FLOPS: {}&quot;</span><span class="o">.</span><span class="n">format</span><span class="p">(</span><span class="n">model_size</span><span class="p">(</span><span class="n">main_program</span><span class="p">)))</span>
<span class="k">print</span><span class="p">(</span><span class="s2">&quot;model_size: {}&quot;</span><span class="o">.</span><span class="n">format</span><span class="p">(</span><span class="n">model_size</span><span class="p">(</span><span class="n">main_program</span><span class="p">)))</span>
</pre></div>
</td></tr></table>
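<p>To make "number of parameters" concrete, the following is a minimal pure-Python sketch that sums the element counts of a set of parameter shapes. The helper <code>count_parameters</code> and the shapes in <code>example_shapes</code> are hypothetical and only illustrate the idea; this is not how <code>paddleslim.analysis.model_size</code> is implemented.</p>

```python
from functools import reduce
from operator import mul


def count_parameters(param_shapes):
    """Sum the number of elements over all parameter shapes.

    param_shapes maps a parameter name to its shape (a list of ints).
    """
    total = 0
    for shape in param_shapes.values():
        # The element count of one parameter is the product of its dimensions.
        total += reduce(mul, shape, 1)
    return total


# Hypothetical shapes of two 3x3 convolution weights,
# laid out as [output_channel, input_channel, kernel_size, kernel_size].
example_shapes = {
    "conv1_weights": [8, 3, 3, 3],
    "conv2_weights": [8, 8, 3, 3],
}

print("model_size: {}".format(count_parameters(example_shapes)))
```

<p>Bias terms and other non-convolution parameters would be counted the same way if their shapes were added to the dict.</p>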
<h2 id="tablelatencyevaluator">TableLatencyEvaluator<a class="headerlink" href="#tablelatencyevaluator" title="Permanent link">#</a></h2>
<blockquote>
<p>paddleslim.analysis.TableLatencyEvaluator(table_file, delimiter=",") <a href="">Source code</a></p>
</blockquote>
<dl>
<dt>paddleslim.analysis.TableLatencyEvaluator(table_file, delimiter=",") <a href="https://github.com/PaddlePaddle/PaddleSlim/blob/develop/paddleslim/analysis/latency.py">Source code</a></dt>
<dd>
<p>A model latency evaluator based on a hardware latency table.</p>
</dd>
</dl>
<p><strong>Parameters:</strong></p>
<ul>
<li>
<p><strong>table_file(str):</strong> Absolute path of the latency table to use. For the table format, see <a href="../paddleslim/analysis/table_latency.md">PaddleSlim hardware latency table format</a>.</p>
<p><strong>table_file(str)</strong> - Absolute path of the latency table to use. For the table format, see <a href="../paddleslim/analysis/table_latency.md">PaddleSlim hardware latency table format</a>.</p>
</li>
<li>
<p><strong>delimiter(str):</strong> The delimiter used between the fields of the operator information in the hardware latency table. Defaults to a comma.</p>
<p><strong>delimiter(str)</strong> - The delimiter used between the fields of the operator information in the hardware latency table. Defaults to a comma.</p>
</li>
</ul>
<p><strong>Returns:</strong></p>
<ul>
<li><strong>Evaluator:</strong> An instance of the hardware latency evaluator.</li>
<li><strong>Evaluator</strong> - An instance of the hardware latency evaluator.</li>
</ul>
<blockquote>
<p>paddleslim.analysis.TableLatencyEvaluator.latency(graph) <a href="">Source code</a></p>
</blockquote>
<dl>
<dt>paddleslim.analysis.TableLatencyEvaluator.latency(graph) <a href="https://github.com/PaddlePaddle/PaddleSlim/blob/develop/paddleslim/analysis/latency.py">Source code</a></dt>
<dd>
<p>Returns the estimated latency of the given network.</p>
</dd>
</dl>
<p><strong>Parameters:</strong></p>
<ul>
<li><strong>graph(Program):</strong> The target network whose latency is to be estimated.</li>
<li><strong>graph(Program)</strong> - The target network whose latency is to be estimated.</li>
</ul>
<p><strong>Returns:</strong></p>
<ul>
<li><strong>latency:</strong> The estimated latency of the target network.</li>
<li><strong>latency</strong> - The estimated latency of the target network.</li>
</ul>
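<p>The idea of a table-based latency estimate can be sketched in a few lines of pure Python: parse each table row into an operator key and a latency, then sum the latencies of the operators in a network. The row layout used here (operator fields followed by a latency value) and the helpers <code>load_latency_table</code> and <code>estimate_latency</code> are assumptions made for illustration; they are not the real table format or the <code>TableLatencyEvaluator</code> implementation.</p>

```python
def load_latency_table(lines, delimiter=","):
    """Parse rows of the assumed form 'field1,field2,...,latency'."""
    table = {}
    for line in lines:
        line = line.strip()
        if not line:
            continue
        # Everything except the last field identifies the operator.
        *op_fields, latency = line.split(delimiter)
        table[tuple(op_fields)] = float(latency)
    return table


def estimate_latency(ops, table):
    """Sum the table latency of every operator in the network."""
    return sum(table[tuple(op)] for op in ops)


# Hypothetical table rows: op_type, active_type, n_in, c_in, latency.
rows = [
    "conv2d,relu,1,3,0.50",
    "conv2d,relu,1,8,1.25",
]
table = load_latency_table(rows)

# A hypothetical network made of three operators found in the table.
network_ops = [
    ("conv2d", "relu", "1", "3"),
    ("conv2d", "relu", "1", "8"),
    ("conv2d", "relu", "1", "8"),
]
total = estimate_latency(network_ops, table)
print("estimated latency: {}".format(total))
```

<p>In this sketch an operator that is missing from the table raises a <code>KeyError</code>, which makes gaps in the table visible instead of silently underestimating the latency.</p>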
</div>
......@@ -438,10 +453,10 @@
<div class="rst-footer-buttons" role="navigation" aria-label="footer navigation">
<a href="../single_distiller_api/" class="btn btn-neutral float-right" title="蒸馏">Next <span class="icon icon-circle-arrow-right"></span></a>
<a href="../single_distiller_api/" class="btn btn-neutral float-right" title="知识蒸馏">Next <span class="icon icon-circle-arrow-right"></span></a>
<a href="../prune_api/" class="btn btn-neutral" title="剪枝"><span class="icon icon-circle-arrow-left"></span> Previous</a>
<a href="../prune_api/" class="btn btn-neutral" title="剪枝与敏感度"><span class="icon icon-circle-arrow-left"></span> Previous</a>
</div>
......@@ -479,7 +494,7 @@
<script>var base_url = '../..';</script>
<script src="../../js/theme.js" defer></script>
<script src="../../mathjax-config.js" defer></script>
<script src="../../MathJax.js?config=TeX-AMS-MML_HTMLorMML" defer></script>
<script src="https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.0/MathJax.js?config=TeX-AMS-MML_HTMLorMML" defer></script>
<script src="../../search/main.js" defer></script>
</body>
......
......@@ -73,6 +73,10 @@
<a class="" href="../../tutorials/nas_demo/">SA搜索</a>
</li>
<li class="">
<a class="" href="../../tutorials/distillation_demo/">知识蒸馏</a>
</li>
</ul>
</li>
......@@ -86,15 +90,15 @@
</li>
<li class="">
<a class="" href="../prune_api/">剪枝</a>
<a class="" href="../prune_api/">剪枝与敏感度</a>
</li>
<li class="">
<a class="" href="../analysis_api/">敏感度分析</a>
<a class="" href="../analysis_api/">模型分析</a>
</li>
<li class="">
<a class="" href="../single_distiller_api/">蒸馏</a>
<a class="" href="../single_distiller_api/">知识蒸馏</a>
</li>
<li class="">
......@@ -104,9 +108,18 @@
<a class="" href="../search_space/">搜索空间</a>
</li>
<li class="">
<a class="" href="../../table_latency/">硬件延时评估表</a>
</li>
</ul>
</li>
<li class="toctree-l1">
<a class="" href="../../algo/algo/">算法原理</a>
</li>
</ul>
</div>
&nbsp;
......@@ -206,7 +219,7 @@
<script>var base_url = '../..';</script>
<script src="../../js/theme.js" defer></script>
<script src="../../mathjax-config.js" defer></script>
<script src="../../MathJax.js?config=TeX-AMS-MML_HTMLorMML" defer></script>
<script src="https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.0/MathJax.js?config=TeX-AMS-MML_HTMLorMML" defer></script>
<script src="../../search/main.js" defer></script>
</body>
......
......@@ -73,6 +73,10 @@
<a class="" href="../../tutorials/nas_demo/">SA搜索</a>
</li>
<li class="">
<a class="" href="../../tutorials/distillation_demo/">知识蒸馏</a>
</li>
</ul>
</li>
......@@ -86,15 +90,15 @@
</li>
<li class="">
<a class="" href="../prune_api/">剪枝</a>
<a class="" href="../prune_api/">剪枝与敏感度</a>
</li>
<li class="">
<a class="" href="../analysis_api/">敏感度分析</a>
<a class="" href="../analysis_api/">模型分析</a>
</li>
<li class="">
<a class="" href="../single_distiller_api/">蒸馏</a>
<a class="" href="../single_distiller_api/">知识蒸馏</a>
</li>
<li class=" current">
......@@ -118,9 +122,18 @@
<a class="" href="../search_space/">搜索空间</a>
</li>
<li class="">
<a class="" href="../../table_latency/">硬件延时评估表</a>
</li>
</ul>
</li>
<li class="toctree-l1">
<a class="" href="../../algo/algo/">算法原理</a>
</li>
</ul>
</div>
&nbsp;
......@@ -446,7 +459,7 @@
<a href="../search_space/" class="btn btn-neutral float-right" title="搜索空间">Next <span class="icon icon-circle-arrow-right"></span></a>
<a href="../single_distiller_api/" class="btn btn-neutral" title="蒸馏"><span class="icon icon-circle-arrow-left"></span> Previous</a>
<a href="../single_distiller_api/" class="btn btn-neutral" title="知识蒸馏"><span class="icon icon-circle-arrow-left"></span> Previous</a>
</div>
......@@ -484,7 +497,7 @@
<script>var base_url = '../..';</script>
<script src="../../js/theme.js" defer></script>
<script src="../../mathjax-config.js" defer></script>
<script src="../../MathJax.js?config=TeX-AMS-MML_HTMLorMML" defer></script>
<script src="https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.0/MathJax.js?config=TeX-AMS-MML_HTMLorMML" defer></script>
<script src="../../search/main.js" defer></script>
</body>
......
......@@ -8,7 +8,7 @@
<link rel="shortcut icon" href="../../img/favicon.ico">
<title>剪枝 - PaddleSlim Docs</title>
<title>剪枝与敏感度 - PaddleSlim Docs</title>
<link href='https://fonts.googleapis.com/css?family=Lato:400,700|Roboto+Slab:400,700|Inconsolata:400,700' rel='stylesheet' type='text/css'>
<link rel="stylesheet" href="../../css/theme.css" type="text/css" />
......@@ -17,7 +17,7 @@
<script>
// Current page data
var mkdocs_page_name = "\u526a\u679d";
var mkdocs_page_name = "\u526a\u679d\u4e0e\u654f\u611f\u5ea6";
var mkdocs_page_input_path = "api/prune_api.md";
var mkdocs_page_url = null;
</script>
......@@ -73,6 +73,10 @@
<a class="" href="../../tutorials/nas_demo/">SA搜索</a>
</li>
<li class="">
<a class="" href="../../tutorials/distillation_demo/">知识蒸馏</a>
</li>
</ul>
</li>
......@@ -86,35 +90,33 @@
</li>
<li class=" current">
<a class="current" href="./">剪枝</a>
<a class="current" href="./">剪枝与敏感度</a>
<ul class="subnav">
<li class="toctree-l3"><a href="#api">卷积通道剪裁API文档</a></li>
<li class="toctree-l3"><a href="#pruner">Pruner</a></li>
<ul>
<li><a class="toctree-l4" href="#class-pruner">class Pruner</a></li>
<li class="toctree-l3"><a href="#sensitivity">sensitivity</a></li>
<li><a class="toctree-l4" href="#sensitivity">sensitivity</a></li>
<li><a class="toctree-l4" href="#merge_sensitive">merge_sensitive</a></li>
<li class="toctree-l3"><a href="#merge_sensitive">merge_sensitive</a></li>
<li><a class="toctree-l4" href="#load_sensitivities">load_sensitivities</a></li>
<li><a class="toctree-l4" href="#get_ratios_by_losssensitivities-loss">get_ratios_by_loss(sensitivities, loss)</a></li>
<li class="toctree-l3"><a href="#load_sensitivities">load_sensitivities</a></li>
</ul>
<li class="toctree-l3"><a href="#get_ratios_by_loss">get_ratios_by_loss</a></li>
</ul>
</li>
<li class="">
<a class="" href="../analysis_api/">敏感度分析</a>
<a class="" href="../analysis_api/">模型分析</a>
</li>
<li class="">
<a class="" href="../single_distiller_api/">蒸馏</a>
<a class="" href="../single_distiller_api/">知识蒸馏</a>
</li>
<li class="">
......@@ -124,9 +126,18 @@
<a class="" href="../search_space/">搜索空间</a>
</li>
<li class="">
<a class="" href="../../table_latency/">硬件延时评估表</a>
</li>
</ul>
</li>
<li class="toctree-l1">
<a class="" href="../../algo/algo/">算法原理</a>
</li>
</ul>
</div>
&nbsp;
......@@ -153,7 +164,7 @@
<li>剪枝</li>
<li>剪枝与敏感度</li>
<li class="wy-breadcrumbs-aside">
<a href="https://github.com/PaddlePaddle/PaddleSlim/edit/master/docs/api/prune_api.md"
......@@ -166,16 +177,16 @@
<div role="main">
<div class="section">
<h1 id="api">Convolution Channel Pruning API Documentation<a class="headerlink" href="#api" title="Permanent link">#</a></h1>
<h2 id="class-pruner">class Pruner<a class="headerlink" href="#class-pruner" title="Permanent link">#</a></h2>
<hr />
<blockquote>
<p>paddleslim.prune.Pruner(criterion="l1_norm")<a href="">Source code</a></p>
</blockquote>
<h2 id="pruner">Pruner<a class="headerlink" href="#pruner" title="Permanent link">#</a></h2>
<dl>
<dt>paddleslim.prune.Pruner(criterion="l1_norm")<a href="https://github.com/PaddlePaddle/PaddleSlim/blob/develop/paddleslim/prune/pruner.py#L28">Source code</a></dt>
<dd>
<p>Prunes the channels of a convolutional network once. Pruning the channels of a convolution layer means pruning that layer's output channels. The layer's weight has shape <code>[output_channel, input_channel, kernel_size, kernel_size]</code>; the number of output channels is reduced by pruning the first dimension of this weight.</p>
</dd>
</dl>
<p><strong>Parameters:</strong></p>
<ul>
<li><strong>criterion:</strong> The metric used to evaluate the importance of channels within a convolution layer. Currently only <code>l1_norm</code> is supported. Defaults to <code>l1_norm</code>.</li>
<li><strong>criterion</strong> - The metric used to evaluate the importance of channels within a convolution layer. Currently only <code>l1_norm</code> is supported. Defaults to <code>l1_norm</code>.</li>
</ul>
<p><strong>Returns:</strong> An instance of the Pruner class.</p>
<p><strong>Example code:</strong></p>
......@@ -185,21 +196,22 @@
</pre></div>
</td></tr></table>
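<p>As an illustration of the <code>l1_norm</code> criterion, the sketch below ranks the output channels of a weight shaped <code>[output_channel, input_channel, kernel_size, kernel_size]</code> by the sum of absolute values in each channel and selects the weakest ones. The helper names, the nested-list weight, and the rounding of the number of pruned channels are assumptions for illustration only, not PaddleSlim's actual code.</p>

```python
def l1_norm_per_output_channel(weight):
    """Return the L1 norm of each output channel.

    weight is a nested list shaped [out_c][in_c][kernel_h][kernel_w].
    """
    norms = []
    for out_channel in weight:
        norms.append(
            sum(abs(v) for kernel in out_channel for row in kernel for v in row)
        )
    return norms


def channels_to_prune(weight, ratio):
    """Pick the indices of the output channels with the smallest L1 norm."""
    norms = l1_norm_per_output_channel(weight)
    # Assumption: the number of pruned channels is rounded to the nearest int.
    num_pruned = int(round(len(norms) * ratio))
    order = sorted(range(len(norms)), key=lambda i: norms[i])
    return sorted(order[:num_pruned])


# Hypothetical weight: 4 output channels, 1 input channel, 1x2 kernel.
example_weight = [
    [[[1.0, -1.0]]],
    [[[0.1, 0.0]]],
    [[[3.0, 2.0]]],
    [[[-0.2, 0.1]]],
]

print(channels_to_prune(example_weight, 0.5))
```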
<hr />
<blockquote>
<p>prune(program, scope, params, ratios, place=None, lazy=False, only_graph=False, param_backup=False, param_shape_backup=False)</p>
</blockquote>
<dl>
<dt>paddleslim.prune.Pruner.prune(program, scope, params, ratios, place=None, lazy=False, only_graph=False, param_backup=False, param_shape_backup=False)<a href="https://github.com/PaddlePaddle/PaddleSlim/blob/develop/paddleslim/prune/pruner.py#L36">Source code</a></dt>
<dd>
<p>Prunes the weights of a group of convolution layers in the target network.</p>
</dd>
</dl>
<p><strong>Parameters:</strong></p>
<ul>
<li>
<p><strong>program(paddle.fluid.Program):</strong> The target network to prune. For more about Program, see <a href="https://www.paddlepaddle.org.cn/documentation/docs/zh/api_cn/fluid_cn/Program_cn.html#program">Introduction to the Program concept</a>.</p>
<p><strong>program(paddle.fluid.Program)</strong> - The target network to prune. For more about Program, see <a href="https://www.paddlepaddle.org.cn/documentation/docs/zh/api_cn/fluid_cn/Program_cn.html#program">Introduction to the Program concept</a>.</p>
</li>
<li>
<p><strong>scope(paddle.fluid.Scope):</strong> The <code>scope</code> that holds the weights to prune. In Paddle, a <code>scope</code> instance stores the values of model parameters and runtime variables. Parameter values in the scope are pruned <code>inplace</code>. For details, see <a href="">Introduction to the Scope concept</a>.</p>
<p><strong>scope(paddle.fluid.Scope)</strong> - The <code>scope</code> that holds the weights to prune. In Paddle, a <code>scope</code> instance stores the values of model parameters and runtime variables. Parameter values in the scope are pruned <code>inplace</code>. For details, see <a href="">Introduction to the Scope concept</a>.</p>
</li>
<li>
<p><strong>params(list&lt;str&gt;):</strong> Names of the parameters of the convolution layers to prune. The names of all parameters in a model can be listed as follows:
<p><strong>params(list&lt;str&gt;)</strong> - Names of the parameters of the convolution layers to prune. The names of all parameters in a model can be listed as follows:
<table class="codehilitetable"><tr><td class="linenos"><div class="linenodiv"><pre><span></span>1
2
3</pre></div></td><td class="code"><div class="codehilite"><pre><span></span><span class="k">for</span> <span class="nv">block</span> <span class="nv">in</span> <span class="nv">program</span>.<span class="nv">blocks</span>:
......@@ -209,34 +221,34 @@
</td></tr></table></p>
</li>
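<li>
<p><strong>ratios(list&lt;float&gt;):</strong> The pruning ratios applied to <code>params</code>, given as a list. Its length must equal the length of <code>params</code>.</p>
<p><strong>ratios(list&lt;float&gt;)</strong> - The pruning ratios applied to <code>params</code>, given as a list. Its length must equal the length of <code>params</code>.</p>
</li>
<li>
<p><strong>place(paddle.fluid.Place):</strong> The device on which the parameters to prune are located; can be <code>CUDAPlace</code> or <code>CPUPlace</code>. See <a href="">Introduction to the Place concept</a>.</p>
<p><strong>place(paddle.fluid.Place)</strong> - The device on which the parameters to prune are located; can be <code>CUDAPlace</code> or <code>CPUPlace</code>. See <a href="">Introduction to the Place concept</a>.</p>
</li>
<li>
<p><strong>lazy(bool):</strong> When <code>lazy</code> is True, pruning is done by setting the parameters of the selected channels to zero, and the parameter's <code>shape</code> stays unchanged. When <code>lazy</code> is False, the parameters of the pruned channels are removed outright and the parameter's <code>shape</code> changes.</p>
<p><strong>lazy(bool)</strong> - When <code>lazy</code> is True, pruning is done by setting the parameters of the selected channels to zero, and the parameter's <code>shape</code> stays unchanged. When <code>lazy</code> is False, the parameters of the pruned channels are removed outright and the parameter's <code>shape</code> changes.</p>
</li>
<li>
<p><strong>only_graph(bool):</strong> Whether to prune only the network structure. In Paddle, a Program defines the network structure and a Scope stores the parameter values. One Scope instance can be shared by several Programs; for example, the Program that defines the training network and the Program that defines the test network use the same Scope instance. When <code>only_graph</code> is True, only the channels of the convolutions defined in the Program are pruned. When <code>only_graph</code> is False, the values of the convolution parameters in the Scope are pruned as well. Defaults to False.</p>
<p><strong>only_graph(bool)</strong> - Whether to prune only the network structure. In Paddle, a Program defines the network structure and a Scope stores the parameter values. One Scope instance can be shared by several Programs; for example, the Program that defines the training network and the Program that defines the test network use the same Scope instance. When <code>only_graph</code> is True, only the channels of the convolutions defined in the Program are pruned. When <code>only_graph</code> is False, the values of the convolution parameters in the Scope are pruned as well. Defaults to False.</p>
</li>
<li>
<p><strong>param_backup(bool):</strong> Whether to return a backup of the parameter values. Defaults to False.</p>
<p><strong>param_backup(bool)</strong> - Whether to return a backup of the parameter values. Defaults to False.</p>
</li>
<li>
<p><strong>param_shape_backup(bool):</strong> Whether to return a backup of the parameter <code>shape</code>s. Defaults to False.</p>
<p><strong>param_shape_backup(bool)</strong> - Whether to return a backup of the parameter <code>shape</code>s. Defaults to False.</p>
</li>
</ul>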
<p><strong>Returns:</strong></p>
<ul>
<li>
<p><strong>pruned_program(paddle.fluid.Program):</strong> The pruned Program.</p>
<p><strong>pruned_program(paddle.fluid.Program)</strong> - The pruned Program.</p>
</li>
<li>
<p><strong>param_backup(dict):</strong> A backup of the parameter values, used to restore the parameter values in the Scope.</p>
<p><strong>param_backup(dict)</strong> - A backup of the parameter values, used to restore the parameter values in the Scope.</p>
</li>
<li>
<p><strong>param_shape_backup(dict):</strong> A backup of the parameter shapes.</p>
<p><strong>param_shape_backup(dict)</strong> - A backup of the parameter shapes.</p>
</li>
</ul>
<p><strong>Example:</strong></p>
......@@ -386,20 +398,22 @@
</td></tr></table></p>
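<p>The difference between <code>lazy=True</code> and <code>lazy=False</code> described in the parameter list can be sketched on a plain nested list. The helper <code>prune_output_channels</code> and the example weight are hypothetical and stand in for a real convolution parameter; the sketch shows only the zeroing-versus-removal behaviour, not how <code>Pruner.prune</code> updates a Program or Scope.</p>

```python
def prune_output_channels(weight, pruned_channels, lazy=False):
    """Prune output channels of a weight shaped [out_c][in_c][kh][kw].

    lazy=True zeroes the selected channels and keeps the shape.
    lazy=False removes the selected channels, so the shape shrinks.
    """
    pruned = set(pruned_channels)
    if lazy:
        result = []
        for index, channel in enumerate(weight):
            if index in pruned:
                # Replace every value of this output channel with zero.
                channel = [[[0.0 for _ in row] for row in kernel] for kernel in channel]
            result.append(channel)
        return result
    return [channel for index, channel in enumerate(weight) if index not in pruned]


# Hypothetical weight: 4 output channels, 1 input channel, 1x2 kernel.
weight = [
    [[[1.0, -1.0]]],
    [[[0.1, 0.0]]],
    [[[3.0, 2.0]]],
    [[[-0.2, 0.1]]],
]

zeroed = prune_output_channels(weight, [1, 3], lazy=True)
removed = prune_output_channels(weight, [1, 3], lazy=False)
print(len(zeroed), len(removed))
```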
<hr />
<h2 id="sensitivity">sensitivity<a class="headerlink" href="#sensitivity" title="Permanent link">#</a></h2>
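<blockquote>
<p>paddleslim.prune.sensitivity(program, place, param_names, eval_func, sensitivities_file=None, pruned_ratios=None) <a href="">Source code</a></p>
</blockquote>
<dl>
<dt>paddleslim.prune.sensitivity(program, place, param_names, eval_func, sensitivities_file=None, pruned_ratios=None) <a href="https://github.com/PaddlePaddle/PaddleSlim/blob/develop/paddleslim/prune/sensitive.py#L34">Source code</a></dt>
<dd>
<p>Computes the sensitivity of each convolution layer in the network. The sensitivity of a layer is measured as follows: prune different ratios of the layer's output channels in turn, and compute the accuracy loss on the test set after each pruning. Once the sensitivities are available, the pruning ratio of each convolution layer can be chosen by inspection or by other means.</p>
</dd>
</dl>
<p><strong>Parameters:</strong></p>
<ul>
<li>
<p><strong>program(paddle.fluid.Program):</strong> The target network to evaluate. For more about Program, see <a href="https://www.paddlepaddle.org.cn/documentation/docs/zh/api_cn/fluid_cn/Program_cn.html#program">Introduction to the Program concept</a>.</p>
<p><strong>program(paddle.fluid.Program)</strong> - The target network to evaluate. For more about Program, see <a href="https://www.paddlepaddle.org.cn/documentation/docs/zh/api_cn/fluid_cn/Program_cn.html#program">Introduction to the Program concept</a>.</p>
</li>
<li>
<p><strong>place(paddle.fluid.Place):</strong> The device on which the parameters to analyze are located; can be <code>CUDAPlace</code> or <code>CPUPlace</code>. See <a href="">Introduction to the Place concept</a>.</p>
<p><strong>place(paddle.fluid.Place)</strong> - The device on which the parameters to analyze are located; can be <code>CUDAPlace</code> or <code>CPUPlace</code>. See <a href="">Introduction to the Place concept</a>.</p>
</li>
<li>
<p><strong>param_names(list&lt;str&gt;):</strong> Names of the parameters of the convolution layers to analyze. The names of all parameters in a model can be listed as follows:</p>
<p><strong>param_names(list&lt;str&gt;)</strong> - Names of the parameters of the convolution layers to analyze. The names of all parameters in a model can be listed as follows:</p>
</li>
</ul>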
<table class="codehilitetable"><tr><td class="linenos"><div class="linenodiv"><pre><span></span>1
......@@ -412,18 +426,18 @@
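<ul>
<li>
<p><strong>eval_func(function):</strong> Callback used to evaluate the pruned model. It takes the pruned <code>program</code> as its argument and returns a score representing the accuracy of that program, which is used to compute the accuracy loss caused by the current pruning.</p>
<p><strong>eval_func(function)</strong> - Callback used to evaluate the pruned model. It takes the pruned <code>program</code> as its argument and returns a score representing the accuracy of that program, which is used to compute the accuracy loss caused by the current pruning.</p>
</li>
<li>
<p><strong>sensitivities_file(str):</strong> A file on the local file system in which the sensitivities are saved. During the computation, newly computed sensitivities are continuously appended to this file. After the task is restarted, sensitivities already present in the file are not recomputed. The file can be loaded with <code>pickle</code>.</p>
<p><strong>sensitivities_file(str)</strong> - A file on the local file system in which the sensitivities are saved. During the computation, newly computed sensitivities are continuously appended to this file. After the task is restarted, sensitivities already present in the file are not recomputed. The file can be loaded with <code>pickle</code>.</p>
</li>
<li>
<p><strong>pruned_ratios(list&lt;float&gt;):</strong> The ratios of channels pruned in turn when computing the sensitivity of a convolution layer. Defaults to [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9].</p>
<p><strong>pruned_ratios(list&lt;float&gt;)</strong> - The ratios of channels pruned in turn when computing the sensitivity of a convolution layer. Defaults to [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9].</p>
</li>
</ul>
<p><strong>Returns:</strong></p>
<ul>
<li><strong>sensitivities(dict):</strong> A dict holding the sensitivities, in the following format:</li>
<li><strong>sensitivities(dict)</strong> - A dict holding the sensitivities, in the following format:</li>
</ul>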
<table class="codehilitetable"><tr><td class="linenos"><div class="linenodiv"><pre><span></span>1
2
......@@ -633,17 +647,19 @@
</td></tr></table>
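<p>The measurement procedure described for <code>sensitivity</code> (prune one layer at each ratio, evaluate, record the relative accuracy loss) can be sketched without Paddle. Everything below is hypothetical: <code>compute_sensitivities</code> is an illustrative helper, and <code>toy_evaluate</code> replaces "prune the layer and run <code>eval_func</code>" with a made-up formula so that the loop can run on its own.</p>

```python
def compute_sensitivities(param_names, pruned_ratios, evaluate):
    """Build {param_name: {ratio: relative accuracy loss}}.

    evaluate(name, ratio) is assumed to return the accuracy after pruning
    `ratio` of the output channels of `name`; evaluate(None, 0.0) is assumed
    to return the accuracy of the unpruned network.
    """
    baseline = evaluate(None, 0.0)
    sensitivities = {}
    for name in param_names:
        sensitivities[name] = {}
        for ratio in pruned_ratios:
            pruned_accuracy = evaluate(name, ratio)
            # Relative accuracy loss caused by pruning this layer alone.
            sensitivities[name][ratio] = (baseline - pruned_accuracy) / baseline
    return sensitivities


# Made-up per-layer fragility used by the toy evaluation below.
FRAGILITY = {"conv1_weights": 0.5, "conv2_weights": 0.1}


def toy_evaluate(name, ratio):
    """Stand-in for pruning and evaluating a real model."""
    if name is None:
        return 0.8
    return 0.8 * (1.0 - ratio * FRAGILITY[name])


toy_sensitivities = compute_sensitivities(
    ["conv1_weights", "conv2_weights"], [0.1, 0.5], toy_evaluate
)
print(toy_sensitivities)
```

<p>The result has the same nesting as the format shown above: parameter name, then pruning ratio, then accuracy loss.</p>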
<h2 id="merge_sensitive">merge_sensitive<a class="headerlink" href="#merge_sensitive" title="Permanent link">#</a></h2>
<blockquote>
<p>merge_sensitive(sensitivities)</p>
</blockquote>
<dl>
<dt>paddleslim.prune.merge_sensitive(sensitivities)<a href="https://github.com/PaddlePaddle/PaddleSlim/blob/develop/paddleslim/prune/sensitive.py#L161">Source code</a></dt>
<dd>
<p>Merges multiple sets of sensitivities.</p>
</dd>
</dl>
<p>Parameters:</p>
<ul>
<li><strong>sensitivities(list&lt;dict&gt; | list&lt;str&gt;):</strong> The sensitivities to merge; either a list of dicts or a list of paths to files that store sensitivities.</li>
<li><strong>sensitivities(list&lt;dict&gt; | list&lt;str&gt;)</strong> - The sensitivities to merge; either a list of dicts or a list of paths to files that store sensitivities.</li>
</ul>
<p>Returns:</p>
<ul>
<li><strong>sensitivities(dict)</strong> The merged sensitivities, in the following format:</li>
<li><strong>sensitivities(dict)</strong> - The merged sensitivities, in the following format:</li>
</ul>
<table class="codehilitetable"><tr><td class="linenos"><div class="linenodiv"><pre><span></span>1
2
......@@ -668,38 +684,41 @@
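<p>Here, <code>weight_0</code> is the name of a convolution layer's parameter. In sensitivities['weight_0'], each <code>key</code> is a pruning ratio and each <code>value</code> is the corresponding ratio of accuracy loss.</p>
<p>Example:</p>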
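<p>A minimal sketch of merging, assuming the inputs are already dicts of the form parameter name, then pruning ratio, then accuracy loss. The helper <code>merge_sensitivity_dicts</code> is hypothetical and does not cover the file-path form of the argument; it illustrates the merge only and is not the <code>merge_sensitive</code> implementation.</p>

```python
def merge_sensitivity_dicts(sensitivities):
    """Merge a list of {param_name: {ratio: loss}} dicts into one dict.

    If the same (param_name, ratio) pair appears more than once,
    the entry from the later dict wins.
    """
    merged = {}
    for sensitivity in sensitivities:
        for param_name, losses in sensitivity.items():
            merged.setdefault(param_name, {}).update(losses)
    return merged


# Hypothetical partial results, e.g. computed by two separate runs.
part_a = {"weight_0": {0.1: 0.02}}
part_b = {"weight_0": {0.2: 0.05}, "weight_1": {0.1: 0.01}}

merged_example = merge_sensitivity_dicts([part_a, part_b])
print(merged_example)
```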
<h2 id="load_sensitivities">load_sensitivities<a class="headerlink" href="#load_sensitivities" title="Permanent link">#</a></h2>
<blockquote>
<p>load_sensitivities(sensitivities_file)</p>
</blockquote>
<dl>
<dt>paddleslim.prune.load_sensitivities(sensitivities_file)<a href="https://github.com/PaddlePaddle/PaddleSlim/blob/develop/paddleslim/prune/sensitive.py#L184">Source code</a></dt>
<dd>
<p>Loads sensitivities from a file.</p>
</dd>
</dl>
<p>Parameters:</p>
<ul>
<li><strong>sensitivities_file(str):</strong> The local file that stores the sensitivities.</li>
<li><strong>sensitivities_file(str)</strong> - The local file that stores the sensitivities.</li>
</ul>
<p>Returns:</p>
<ul>
<li>**sensitivities(dict)**The sensitivities.</li>
<li><strong>sensitivities(dict)</strong> - The sensitivities.</li>
</ul>
<p>Example:</p>
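<p>Since the sensitivities file is described as loadable with <code>pickle</code>, a round trip can be sketched with the standard library alone. The helper <code>load_sensitivity_file</code>, the file name, and the choice to return an empty dict for a missing file are assumptions for illustration; the sketch also assumes the file holds a single pickled dict.</p>

```python
import os
import pickle
import tempfile


def load_sensitivity_file(path):
    """Load a pickled sensitivities dict, or {} if the file is missing."""
    if not os.path.exists(path):
        return {}
    with open(path, "rb") as f:
        return pickle.load(f)


# Hypothetical sensitivities: parameter name -> {pruning ratio: accuracy loss}.
saved = {"weight_0": {0.1: 0.02, 0.2: 0.05}}

with tempfile.TemporaryDirectory() as tmp_dir:
    path = os.path.join(tmp_dir, "sensitivities.data")
    with open(path, "wb") as f:
        pickle.dump(saved, f)
    loaded = load_sensitivity_file(path)
    missing = load_sensitivity_file(os.path.join(tmp_dir, "absent.data"))

print(loaded)
```

<p>Only unpickle files from a trusted source, because <code>pickle.load</code> can execute arbitrary code.</p>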
<h2 id="get_ratios_by_losssensitivities-loss">get_ratios_by_loss(sensitivities, loss)<a class="headerlink" href="#get_ratios_by_losssensitivities-loss" title="Permanent link">#</a></h2>
<h2 id="get_ratios_by_loss">get_ratios_by_loss<a class="headerlink" href="#get_ratios_by_loss" title="Permanent link">#</a></h2>
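<dl>
<dt>paddleslim.prune.get_ratios_by_loss(sensitivities, loss)<a href="https://github.com/PaddlePaddle/PaddleSlim/blob/develop/paddleslim/prune/sensitive.py#L206">Source code</a></dt>
<dd>
<p>Computes a set of pruning ratios from the sensitivities and an accuracy-loss threshold. For a parameter <code>w</code>, its pruning ratio is the largest ratio that keeps the accuracy loss below <code>loss</code>.</p>
</dd>
</dl>
<p>Parameters:</p>
<ul>
<li>
<p><strong>sensitivities(dict):</strong> The sensitivities.</p>
<p><strong>sensitivities(dict)</strong> - The sensitivities.</p>
</li>
<li>
<p><strong>loss:</strong> The accuracy-loss threshold.</p>
<p><strong>loss</strong> - The accuracy-loss threshold.</p>
</li>
</ul>
<p>Returns:</p>
<ul>
<li>ratios(dict): A set of pruning ratios. Each <code>key</code> is the name of a parameter to prune, and each <code>value</code> is that parameter's pruning ratio.</li>
<li><strong>ratios(dict)</strong> - A set of pruning ratios. Each <code>key</code> is the name of a parameter to prune, and each <code>value</code> is that parameter's pruning ratio.</li>
</ul>
<p>Example:</p>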
<table class="codehilitetable"><tr><td class="linenos"><div class="linenodiv"><pre><span></span>1</pre></div></td><td class="code"><div class="codehilite"><pre><span></span>
</pre></div>
</td></tr></table>
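<p>The selection rule stated for <code>get_ratios_by_loss</code> (for each parameter, the largest pruning ratio whose accuracy loss stays below the threshold) can be sketched as follows. The helper <code>ratios_below_loss</code> is hypothetical: it considers only the ratios that were actually measured and leaves out parameters for which no measured ratio qualifies, which may differ from what the library does.</p>

```python
def ratios_below_loss(sensitivities, loss):
    """For each parameter, pick the largest measured ratio with loss < threshold.

    sensitivities maps param_name -> {pruning ratio: accuracy loss}.
    Parameters with no qualifying ratio are omitted from the result.
    """
    ratios = {}
    for param_name, losses in sensitivities.items():
        candidates = [ratio for ratio, value in losses.items() if value < loss]
        if candidates:
            ratios[param_name] = max(candidates)
    return ratios


# Hypothetical sensitivities for two convolution parameters.
example_sensitivities = {
    "weight_0": {0.1: 0.01, 0.2: 0.04, 0.3: 0.20},
    "weight_1": {0.1: 0.30},
}

example_ratios = ratios_below_loss(example_sensitivities, 0.05)
print(example_ratios)
```

<p>With a threshold of 0.05, <code>weight_0</code> can be pruned by 0.2, while <code>weight_1</code> is left out because even its smallest measured ratio exceeds the threshold.</p>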
</div>
</div>
......@@ -707,7 +726,7 @@
<div class="rst-footer-buttons" role="navigation" aria-label="footer navigation">
<a href="../analysis_api/" class="btn btn-neutral float-right" title="敏感度分析">Next <span class="icon icon-circle-arrow-right"></span></a>
<a href="../analysis_api/" class="btn btn-neutral float-right" title="模型分析">Next <span class="icon icon-circle-arrow-right"></span></a>
<a href="../quantization_api/" class="btn btn-neutral" title="量化"><span class="icon icon-circle-arrow-left"></span> Previous</a>
......@@ -748,7 +767,7 @@
<script>var base_url = '../..';</script>
<script src="../../js/theme.js" defer></script>
<script src="../../mathjax-config.js" defer></script>
<script src="../../MathJax.js?config=TeX-AMS-MML_HTMLorMML" defer></script>
<script src="https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.0/MathJax.js?config=TeX-AMS-MML_HTMLorMML" defer></script>
<script src="../../search/main.js" defer></script>
</body>
......
......@@ -73,6 +73,10 @@
<a class="" href="../../tutorials/nas_demo/">SA搜索</a>
</li>
<li class="">
<a class="" href="../../tutorials/distillation_demo/">知识蒸馏</a>
</li>
</ul>
</li>
......@@ -102,15 +106,15 @@
</li>
<li class="">
<a class="" href="../prune_api/">剪枝</a>
<a class="" href="../prune_api/">剪枝与敏感度</a>
</li>
<li class="">
<a class="" href="../analysis_api/">敏感度分析</a>
<a class="" href="../analysis_api/">模型分析</a>
</li>
<li class="">
<a class="" href="../single_distiller_api/">蒸馏</a>
<a class="" href="../single_distiller_api/">知识蒸馏</a>
</li>
<li class="">
......@@ -120,9 +124,18 @@
<a class="" href="../search_space/">搜索空间</a>
</li>
<li class="">
<a class="" href="../../table_latency/">硬件延时评估表</a>
</li>
</ul>
</li>
<li class="toctree-l1">
<a class="" href="../../algo/algo/">算法原理</a>
</li>
</ul>
</div>
&nbsp;
......@@ -491,10 +504,10 @@
<div class="rst-footer-buttons" role="navigation" aria-label="footer navigation">
<a href="../prune_api/" class="btn btn-neutral float-right" title="剪枝">Next <span class="icon icon-circle-arrow-right"></span></a>
<a href="../prune_api/" class="btn btn-neutral float-right" title="剪枝与敏感度">Next <span class="icon icon-circle-arrow-right"></span></a>
<a href="../../tutorials/nas_demo/" class="btn btn-neutral" title="SA搜索"><span class="icon icon-circle-arrow-left"></span> Previous</a>
<a href="../../tutorials/distillation_demo/" class="btn btn-neutral" title="知识蒸馏"><span class="icon icon-circle-arrow-left"></span> Previous</a>
</div>
......@@ -522,7 +535,7 @@
<a href="https://github.com/PaddlePaddle/PaddleSlim/" class="fa fa-github" style="float: left; color: #fcfcfc"> GitHub</a>
<span><a href="../../tutorials/nas_demo/" style="color: #fcfcfc;">&laquo; Previous</a></span>
<span><a href="../../tutorials/distillation_demo/" style="color: #fcfcfc;">&laquo; Previous</a></span>
<span style="margin-left: 15px"><a href="../prune_api/" style="color: #fcfcfc">Next &raquo;</a></span>
......@@ -532,7 +545,7 @@
<script>var base_url = '../..';</script>
<script src="../../js/theme.js" defer></script>
<script src="../../mathjax-config.js" defer></script>
<script src="../../MathJax.js?config=TeX-AMS-MML_HTMLorMML" defer></script>
<script src="https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.0/MathJax.js?config=TeX-AMS-MML_HTMLorMML" defer></script>
<script src="../../search/main.js" defer></script>
</body>
......
......@@ -73,6 +73,10 @@
<a class="" href="../../tutorials/nas_demo/">SA搜索</a>
</li>
<li class="">
<a class="" href="../../tutorials/distillation_demo/">知识蒸馏</a>
</li>
</ul>
</li>
......@@ -86,15 +90,15 @@
</li>
<li class="">
<a class="" href="../prune_api/">剪枝</a>
<a class="" href="../prune_api/">剪枝与敏感度</a>
</li>
<li class="">
<a class="" href="../analysis_api/">敏感度分析</a>
<a class="" href="../analysis_api/">模型分析</a>
</li>
<li class="">
<a class="" href="../single_distiller_api/">蒸馏</a>
<a class="" href="../single_distiller_api/">知识蒸馏</a>
</li>
<li class="">
......@@ -121,9 +125,18 @@
</ul>
</li>
<li class="">
<a class="" href="../../table_latency/">硬件延时评估表</a>
</li>
</ul>
</li>
<li class="toctree-l1">
<a class="" href="../../algo/algo/">算法原理</a>
</li>
</ul>
</div>
&nbsp;
......@@ -329,6 +342,8 @@
<div class="rst-footer-buttons" role="navigation" aria-label="footer navigation">
<a href="../../table_latency/" class="btn btn-neutral float-right" title="硬件延时评估表">Next <span class="icon icon-circle-arrow-right"></span></a>
<a href="../nas_api/" class="btn btn-neutral" title="SA搜索"><span class="icon icon-circle-arrow-left"></span> Previous</a>
......@@ -361,12 +376,14 @@
<span><a href="../nas_api/" style="color: #fcfcfc;">&laquo; Previous</a></span>
<span style="margin-left: 15px"><a href="../../table_latency/" style="color: #fcfcfc">Next &raquo;</a></span>
</span>
</div>
<script>var base_url = '../..';</script>
<script src="../../js/theme.js" defer></script>
<script src="../../mathjax-config.js" defer></script>
<script src="../../MathJax.js?config=TeX-AMS-MML_HTMLorMML" defer></script>
<script src="https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.0/MathJax.js?config=TeX-AMS-MML_HTMLorMML" defer></script>
<script src="../../search/main.js" defer></script>
</body>
......
......@@ -8,7 +8,7 @@
<link rel="shortcut icon" href="../../img/favicon.ico">
<title>蒸馏 - PaddleSlim Docs</title>
<title>知识蒸馏 - PaddleSlim Docs</title>
<link href='https://fonts.googleapis.com/css?family=Lato:400,700|Roboto+Slab:400,700|Inconsolata:400,700' rel='stylesheet' type='text/css'>
<link rel="stylesheet" href="../../css/theme.css" type="text/css" />
......@@ -17,7 +17,7 @@
<script>
// Current page data
var mkdocs_page_name = "\u84b8\u998f";
var mkdocs_page_name = "\u77e5\u8bc6\u84b8\u998f";
var mkdocs_page_input_path = "api/single_distiller_api.md";
var mkdocs_page_url = null;
</script>
......@@ -73,6 +73,10 @@
<a class="" href="../../tutorials/nas_demo/">SA搜索</a>
</li>
<li class="">
<a class="" href="../../tutorials/distillation_demo/">知识蒸馏</a>
</li>
</ul>
</li>
......@@ -86,15 +90,15 @@
</li>
<li class="">
<a class="" href="../prune_api/">剪枝</a>
<a class="" href="../prune_api/">剪枝与敏感度</a>
</li>
<li class="">
<a class="" href="../analysis_api/">敏感度分析</a>
<a class="" href="../analysis_api/">模型分析</a>
</li>
<li class=" current">
<a class="current" href="./">蒸馏</a>
<a class="current" href="./">知识蒸馏</a>
<ul class="subnav">
<li class="toctree-l3"><a href="#merge">merge</a></li>
......@@ -122,9 +126,18 @@
<a class="" href="../search_space/">搜索空间</a>
</li>
<li class="">
<a class="" href="../../table_latency/">硬件延时评估表</a>
</li>
</ul>
</li>
<li class="toctree-l1">
<a class="" href="../../algo/algo/">算法原理</a>
</li>
</ul>
</div>
&nbsp;
......@@ -151,7 +164,7 @@
<li>蒸馏</li>
<li>知识蒸馏</li>
<li class="wy-breadcrumbs-aside">
<a href="https://github.com/PaddlePaddle/PaddleSlim/edit/master/docs/api/single_distiller_api.md"
......@@ -484,7 +497,7 @@
<a href="../nas_api/" class="btn btn-neutral float-right" title="SA搜索">Next <span class="icon icon-circle-arrow-right"></span></a>
<a href="../analysis_api/" class="btn btn-neutral" title="敏感度分析"><span class="icon icon-circle-arrow-left"></span> Previous</a>
<a href="../analysis_api/" class="btn btn-neutral" title="模型分析"><span class="icon icon-circle-arrow-left"></span> Previous</a>
</div>
......@@ -522,7 +535,7 @@
<script>var base_url = '../..';</script>
<script src="../../js/theme.js" defer></script>
<script src="../../mathjax-config.js" defer></script>
<script src="../../MathJax.js?config=TeX-AMS-MML_HTMLorMML" defer></script>
<script src="https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.0/MathJax.js?config=TeX-AMS-MML_HTMLorMML" defer></script>
<script src="../../search/main.js" defer></script>
</body>
......
......@@ -91,6 +91,10 @@
<a class="" href="tutorials/nas_demo/">SA搜索</a>
</li>
<li class="">
<a class="" href="tutorials/distillation_demo/">知识蒸馏</a>
</li>
</ul>
</li>
......@@ -104,15 +108,15 @@
</li>
<li class="">
<a class="" href="api/prune_api/">剪枝</a>
<a class="" href="api/prune_api/">剪枝与敏感度</a>
</li>
<li class="">
<a class="" href="api/analysis_api/">敏感度分析</a>
<a class="" href="api/analysis_api/">模型分析</a>
</li>
<li class="">
<a class="" href="api/single_distiller_api/">蒸馏</a>
<a class="" href="api/single_distiller_api/">知识蒸馏</a>
</li>
<li class="">
......@@ -122,9 +126,18 @@
<a class="" href="api/search_space/">搜索空间</a>
</li>
<li class="">
<a class="" href="table_latency/">硬件延时评估表</a>
</li>
</ul>
</li>
<li class="toctree-l1">
<a class="" href="algo/algo/">算法原理</a>
</li>
</ul>
</div>
&nbsp;
......@@ -267,7 +280,7 @@
<script>var base_url = '.';</script>
<script src="js/theme.js" defer></script>
<script src="mathjax-config.js" defer></script>
<script src="MathJax.js?config=TeX-AMS-MML_HTMLorMML" defer></script>
<script src="https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.0/MathJax.js?config=TeX-AMS-MML_HTMLorMML" defer></script>
<script src="search/main.js" defer></script>
</body>
......@@ -275,5 +288,5 @@
<!--
MkDocs version : 1.0.4
Build Date UTC : 2019-12-23 12:36:00
Build Date UTC : 2019-12-30 07:49:13
-->
......@@ -66,6 +66,10 @@
<a class="" href="./tutorials/nas_demo/">SA搜索</a>
</li>
<li class="">
<a class="" href="./tutorials/distillation_demo/">知识蒸馏</a>
</li>
</ul>
</li>
......@@ -79,15 +83,15 @@
</li>
<li class="">
<a class="" href="./api/prune_api/">剪枝</a>
<a class="" href="./api/prune_api/">剪枝与敏感度</a>
</li>
<li class="">
<a class="" href="./api/analysis_api/">敏感度分析</a>
<a class="" href="./api/analysis_api/">模型分析</a>
</li>
<li class="">
<a class="" href="./api/single_distiller_api/">蒸馏</a>
<a class="" href="./api/single_distiller_api/">知识蒸馏</a>
</li>
<li class="">
......@@ -97,9 +101,18 @@
<a class="" href="./api/search_space/">搜索空间</a>
</li>
<li class="">
<a class="" href="./table_latency/">硬件延时评估表</a>
</li>
</ul>
</li>
<li class="toctree-l1">
<a class="" href="./algo/algo/">算法原理</a>
</li>
</ul>
</div>
&nbsp;
......@@ -177,7 +190,7 @@
<script>var base_url = '.';</script>
<script src="./js/theme.js" defer></script>
<script src="./mathjax-config.js" defer></script>
<script src="./MathJax.js?config=TeX-AMS-MML_HTMLorMML" defer></script>
<script src="https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.0/MathJax.js?config=TeX-AMS-MML_HTMLorMML" defer></script>
<script src="./search/main.js" defer></script>
</body>
......
......@@ -2,57 +2,72 @@
<urlset xmlns="http://www.sitemaps.org/schemas/sitemap/0.9">
<url>
<loc>None</loc>
<lastmod>2019-12-23</lastmod>
<lastmod>2019-12-30</lastmod>
<changefreq>daily</changefreq>
</url>
<url>
<loc>None</loc>
<lastmod>2019-12-23</lastmod>
<lastmod>2019-12-30</lastmod>
<changefreq>daily</changefreq>
</url>
<url>
<loc>None</loc>
<lastmod>2019-12-23</lastmod>
<lastmod>2019-12-30</lastmod>
<changefreq>daily</changefreq>
</url>
<url>
<loc>None</loc>
<lastmod>2019-12-23</lastmod>
<lastmod>2019-12-30</lastmod>
<changefreq>daily</changefreq>
</url>
<url>
<loc>None</loc>
<lastmod>2019-12-23</lastmod>
<lastmod>2019-12-30</lastmod>
<changefreq>daily</changefreq>
</url>
<url>
<loc>None</loc>
<lastmod>2019-12-23</lastmod>
<lastmod>2019-12-30</lastmod>
<changefreq>daily</changefreq>
</url>
<url>
<loc>None</loc>
<lastmod>2019-12-23</lastmod>
<lastmod>2019-12-30</lastmod>
<changefreq>daily</changefreq>
</url>
<url>
<loc>None</loc>
<lastmod>2019-12-23</lastmod>
<lastmod>2019-12-30</lastmod>
<changefreq>daily</changefreq>
</url>
<url>
<loc>None</loc>
<lastmod>2019-12-23</lastmod>
<lastmod>2019-12-30</lastmod>
<changefreq>daily</changefreq>
</url>
<url>
<loc>None</loc>
<lastmod>2019-12-23</lastmod>
<lastmod>2019-12-30</lastmod>
<changefreq>daily</changefreq>
</url>
<url>
<loc>None</loc>
<lastmod>2019-12-23</lastmod>
<lastmod>2019-12-30</lastmod>
<changefreq>daily</changefreq>
</url>
<url>
<loc>None</loc>
<lastmod>2019-12-30</lastmod>
<changefreq>daily</changefreq>
</url>
<url>
<loc>None</loc>
<lastmod>2019-12-30</lastmod>
<changefreq>daily</changefreq>
</url>
<url>
<loc>None</loc>
<lastmod>2019-12-30</lastmod>
<changefreq>daily</changefreq>
</url>
</urlset>
\ No newline at end of file
......@@ -73,6 +73,10 @@
<a class="" href="../tutorials/nas_demo/">SA搜索</a>
</li>
<li class="">
<a class="" href="../tutorials/distillation_demo/">知识蒸馏</a>
</li>
</ul>
</li>
......@@ -86,15 +90,15 @@
</li>
<li class="">
<a class="" href="../api/prune_api/">剪枝</a>
<a class="" href="../api/prune_api/">剪枝与敏感度</a>
</li>
<li class="">
<a class="" href="../api/analysis_api/">敏感度分析</a>
<a class="" href="../api/analysis_api/">模型分析</a>
</li>
<li class="">
<a class="" href="../api/single_distiller_api/">蒸馏</a>
<a class="" href="../api/single_distiller_api/">知识蒸馏</a>
</li>
<li class="">
......@@ -104,9 +108,36 @@
<a class="" href="../api/search_space/">搜索空间</a>
</li>
<li class=" current">
<a class="current" href="./">硬件延时评估表</a>
<ul class="subnav">
<li class="toctree-l3"><a href="#_1">硬件延时评估表</a></li>
<ul>
<li><a class="toctree-l4" href="#_2">概述</a></li>
<li><a class="toctree-l4" href="#_3">整体格式</a></li>
<li><a class="toctree-l4" href="#_4">版本信息</a></li>
<li><a class="toctree-l4" href="#_5">操作信息</a></li>
</ul>
</ul>
</li>
</ul>
</li>
<li class="toctree-l1">
<a class="" href="../algo/algo/">算法原理</a>
</li>
</ul>
</div>
&nbsp;
......@@ -129,6 +160,10 @@
<li>API &raquo;</li>
<li>硬件延时评估表</li>
<li class="wy-breadcrumbs-aside">
......@@ -213,7 +248,7 @@
<p><strong>Field descriptions</strong></p>
<ul>
<li><strong>op_type(str)</strong> - The type of the current op.</li>
<li><strong>active_type (string)</strong> - Activation function type, including: relu, prelu, sigmoid, relu6, tanh.</li>
<li><strong>active_type (string|None)</strong> - Activation function type, including: relu, prelu, sigmoid, relu6, tanh.</li>
<li><strong>n_in (int)</strong> - Batch size of the input Tensor.</li>
<li><strong>c_in (int)</strong> - Number of channels of the input Tensor.</li>
<li><strong>h_in (int)</strong> - Feature height of the input Tensor.</li>
......@@ -277,6 +312,15 @@
</div>
<footer>
<div class="rst-footer-buttons" role="navigation" aria-label="footer navigation">
<a href="../algo/algo/" class="btn btn-neutral float-right" title="算法原理">Next <span class="icon icon-circle-arrow-right"></span></a>
<a href="../api/search_space/" class="btn btn-neutral" title="搜索空间"><span class="icon icon-circle-arrow-left"></span> Previous</a>
</div>
<hr/>
......@@ -301,13 +345,17 @@
<a href="https://github.com/PaddlePaddle/PaddleSlim/" class="fa fa-github" style="float: left; color: #fcfcfc"> GitHub</a>
<span><a href="../api/search_space/" style="color: #fcfcfc;">&laquo; Previous</a></span>
<span style="margin-left: 15px"><a href="../algo/algo/" style="color: #fcfcfc">Next &raquo;</a></span>
</span>
</div>
<script>var base_url = '..';</script>
<script src="../js/theme.js" defer></script>
<script src="../mathjax-config.js" defer></script>
<script src="../MathJax.js?config=TeX-AMS-MML_HTMLorMML" defer></script>
<script src="https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.0/MathJax.js?config=TeX-AMS-MML_HTMLorMML" defer></script>
<script src="../search/main.js" defer></script>
</body>
......
......@@ -73,6 +73,10 @@
<a class="" href="../nas_demo/">SA搜索</a>
</li>
<li class="">
<a class="" href="../distillation_demo/">知识蒸馏</a>
</li>
</ul>
</li>
......@@ -86,15 +90,15 @@
</li>
<li class="">
<a class="" href="../../api/prune_api/">剪枝</a>
<a class="" href="../../api/prune_api/">剪枝与敏感度</a>
</li>
<li class="">
<a class="" href="../../api/analysis_api/">敏感度分析</a>
<a class="" href="../../api/analysis_api/">模型分析</a>
</li>
<li class="">
<a class="" href="../../api/single_distiller_api/">蒸馏</a>
<a class="" href="../../api/single_distiller_api/">知识蒸馏</a>
</li>
<li class="">
......@@ -104,9 +108,18 @@
<a class="" href="../../api/search_space/">搜索空间</a>
</li>
<li class="">
<a class="" href="../../table_latency/">硬件延时评估表</a>
</li>
</ul>
</li>
<li class="toctree-l1">
<a class="" href="../../algo/algo/">算法原理</a>
</li>
</ul>
</div>
&nbsp;
......@@ -187,7 +200,7 @@
<script>var base_url = '../..';</script>
<script src="../../js/theme.js" defer></script>
<script src="../../mathjax-config.js" defer></script>
<script src="../../MathJax.js?config=TeX-AMS-MML_HTMLorMML" defer></script>
<script src="https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.0/MathJax.js?config=TeX-AMS-MML_HTMLorMML" defer></script>
<script src="../../search/main.js" defer></script>
</body>
......
<!DOCTYPE html>
<!--[if IE 8]><html class="no-js lt-ie9" lang="en" > <![endif]-->
<!--[if gt IE 8]><!--> <html class="no-js" lang="en" > <!--<![endif]-->
<head>
<meta charset="utf-8">
<meta http-equiv="X-UA-Compatible" content="IE=edge">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<link rel="shortcut icon" href="../../img/favicon.ico">
<title>Knowledge Distillation - PaddleSlim Docs</title>
<link href='https://fonts.googleapis.com/css?family=Lato:400,700|Roboto+Slab:400,700|Inconsolata:400,700' rel='stylesheet' type='text/css'>
<link rel="stylesheet" href="../../css/theme.css" type="text/css" />
<link rel="stylesheet" href="../../css/theme_extra.css" type="text/css" />
<link rel="stylesheet" href="//cdnjs.cloudflare.com/ajax/libs/highlight.js/9.12.0/styles/github.min.css">
<script>
// Current page data
var mkdocs_page_name = "\u77e5\u8bc6\u84b8\u998f";
var mkdocs_page_input_path = "tutorials/distillation_demo.md";
var mkdocs_page_url = null;
</script>
<script src="../../js/jquery-2.1.1.min.js" defer></script>
<script src="../../js/modernizr-2.8.3.min.js" defer></script>
<script src="//cdnjs.cloudflare.com/ajax/libs/highlight.js/9.12.0/highlight.min.js"></script>
<script>hljs.initHighlightingOnLoad();</script>
</head>
<body class="wy-body-for-nav" role="document">
<div class="wy-grid-for-nav">
<nav data-toggle="wy-nav-shift" class="wy-nav-side stickynav">
<div class="wy-side-nav-search">
<a href="../.." class="icon icon-home"> PaddleSlim Docs</a>
<div role="search">
<form id ="rtd-search-form" class="wy-form" action="../../search.html" method="get">
<input type="text" name="q" placeholder="Search docs" title="Type search term here" />
</form>
</div>
</div>
<div class="wy-menu wy-menu-vertical" data-spy="affix" role="navigation" aria-label="main navigation">
<ul class="current">
<li class="toctree-l1">
<a class="" href="../..">Home</a>
</li>
<li class="toctree-l1">
<span class="caption-text">教程</span>
<ul class="subnav">
<li class="">
<a class="" href="../quant_post_demo/">离线量化</a>
</li>
<li class="">
<a class="" href="../quant_aware_demo/">量化训练</a>
</li>
<li class="">
<a class="" href="../quant_embedding_demo/">Embedding量化</a>
</li>
<li class="">
<a class="" href="../nas_demo/">SA搜索</a>
</li>
<li class=" current">
<a class="current" href="./">知识蒸馏</a>
<ul class="subnav">
<li class="toctree-l3"><a href="#_1">Interface overview</a></li>
<li class="toctree-l3"><a href="#paddleslim">PaddleSlim distillation training workflow</a></li>
<ul>
<li><a class="toctree-l4" href="#1-student_program">1. Define student_program</a></li>
<li><a class="toctree-l4" href="#2-teacher_program">2. Define teacher_program</a></li>
<li><a class="toctree-l4" href="#3">3. Select feature maps</a></li>
<li><a class="toctree-l4" href="#4-merge">4. Merge the student and teacher Programs (merge)</a></li>
<li><a class="toctree-l4" href="#5loss">5. Add the distillation loss</a></li>
</ul>
</ul>
</li>
</ul>
</li>
<li class="toctree-l1">
<span class="caption-text">API</span>
<ul class="subnav">
<li class="">
<a class="" href="../../api/quantization_api/">量化</a>
</li>
<li class="">
<a class="" href="../../api/prune_api/">剪枝与敏感度</a>
</li>
<li class="">
<a class="" href="../../api/analysis_api/">模型分析</a>
</li>
<li class="">
<a class="" href="../../api/single_distiller_api/">知识蒸馏</a>
</li>
<li class="">
<a class="" href="../../api/nas_api/">SA搜索</a>
</li>
<li class="">
<a class="" href="../../api/search_space/">搜索空间</a>
</li>
<li class="">
<a class="" href="../../table_latency/">硬件延时评估表</a>
</li>
</ul>
</li>
<li class="toctree-l1">
<a class="" href="../../algo/algo/">算法原理</a>
</li>
</ul>
</div>
&nbsp;
</nav>
<section data-toggle="wy-nav-shift" class="wy-nav-content-wrap">
<nav class="wy-nav-top" role="navigation" aria-label="top navigation">
<i data-toggle="wy-nav-top" class="fa fa-bars"></i>
<a href="../..">PaddleSlim Docs</a>
</nav>
<div class="wy-nav-content">
<div class="rst-content">
<div role="navigation" aria-label="breadcrumbs navigation">
<ul class="wy-breadcrumbs">
<li><a href="../..">Docs</a> &raquo;</li>
<li>教程 &raquo;</li>
<li>知识蒸馏</li>
<li class="wy-breadcrumbs-aside">
<a href="https://github.com/PaddlePaddle/PaddleSlim/edit/master/docs/tutorials/distillation_demo.md"
class="icon icon-github"> Edit on GitHub</a>
</li>
</ul>
<hr/>
</div>
<div role="main">
<div class="section">
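<p>This example shows how to use the PaddleSlim distillation API to train a model with knowledge distillation.</p>
<h2 id="_1">Interface overview<a class="headerlink" href="#_1" title="Permanent link">#</a></h2>
<p>See the <a href="https://paddlepaddle.github.io/PaddleSlim/api/single_distiller_api/">distillation API documentation</a>.</p>
<h2 id="paddleslim">PaddleSlim distillation training workflow<a class="headerlink" href="#paddleslim" title="Permanent link">#</a></h2>
<p>In general, the more parameters a model has and the more complex its structure, the better it performs, but the more computation and resources it consumes. <strong>Knowledge distillation</strong> is a method for compressing the useful information learned by a large model (its "dark knowledge") into a smaller, faster model, so that the small model achieves results comparable to the large one.</p>
<p>In this example, the larger, more accurate model is called the teacher, and the smaller model, slightly less accurate but faster, is called the student.</p>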
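<p>To make the idea of "dark knowledge" concrete, here is a minimal pure-Python sketch of the soft-label view of distillation: a temperature flattens the teacher's output distribution so that the relative scores of the wrong classes become visible to the student. This is an illustration only. The logits, the temperature value, and the helper names are hypothetical, and the tutorial below uses a feature-map L2 loss instead of soft labels.</p>

```python
import math


def softmax(logits, temperature=1.0):
    """Softmax over a list of logits, softened by a temperature."""
    scaled = [x / temperature for x in logits]
    peak = max(scaled)  # subtract the max for numerical stability
    exps = [math.exp(x - peak) for x in scaled]
    total = sum(exps)
    return [e / total for e in exps]


def soft_cross_entropy(teacher_probs, student_probs):
    """Cross-entropy of the student distribution against teacher soft labels."""
    return -sum(t * math.log(s) for t, s in zip(teacher_probs, student_probs))


# Hypothetical logits for one sample over three classes.
teacher_logits = [8.0, 2.0, 0.5]
student_logits = [5.0, 1.0, 1.0]

hard = softmax(teacher_logits, temperature=1.0)  # nearly one-hot
soft = softmax(teacher_logits, temperature=4.0)  # exposes class similarities
soft_loss = soft_cross_entropy(soft, softmax(student_logits, temperature=4.0))

print("teacher, T=1:", [round(p, 4) for p in hard])
print("teacher, T=4:", [round(p, 4) for p in soft])
print("soft-label loss:", round(soft_loss, 4))
```

<p>With a higher temperature, the probability the teacher assigns to the second class grows noticeably, which is the extra signal a student can learn from.</p>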
<h3 id="1-student_program">1. Define student_program<a class="headerlink" href="#1-student_program" title="Permanent link">#</a></h3>
<table class="codehilitetable"><tr><td class="linenos"><div class="linenodiv"><pre><span></span> 1
2
3
4
5
6
7
8
9
10
11</pre></div></td><td class="code"><div class="codehilite"><pre><span></span><span class="n">student_program</span> <span class="o">=</span> <span class="n">fluid</span><span class="o">.</span><span class="n">Program</span><span class="p">()</span>
<span class="n">student_startup</span> <span class="o">=</span> <span class="n">fluid</span><span class="o">.</span><span class="n">Program</span><span class="p">()</span>
<span class="k">with</span> <span class="n">fluid</span><span class="o">.</span><span class="n">program_guard</span><span class="p">(</span><span class="n">student_program</span><span class="p">,</span> <span class="n">student_startup</span><span class="p">):</span>
    <span class="n">image</span> <span class="o">=</span> <span class="n">fluid</span><span class="o">.</span><span class="n">data</span><span class="p">(</span>
        <span class="n">name</span><span class="o">=</span><span class="s1">&#39;image&#39;</span><span class="p">,</span> <span class="n">shape</span><span class="o">=</span><span class="p">[</span><span class="bp">None</span><span class="p">]</span> <span class="o">+</span> <span class="p">[</span><span class="mi">3</span><span class="p">,</span> <span class="mi">224</span><span class="p">,</span> <span class="mi">224</span><span class="p">],</span> <span class="n">dtype</span><span class="o">=</span><span class="s1">&#39;float32&#39;</span><span class="p">)</span>
    <span class="n">label</span> <span class="o">=</span> <span class="n">fluid</span><span class="o">.</span><span class="n">data</span><span class="p">(</span><span class="n">name</span><span class="o">=</span><span class="s1">&#39;label&#39;</span><span class="p">,</span> <span class="n">shape</span><span class="o">=</span><span class="p">[</span><span class="bp">None</span><span class="p">,</span> <span class="mi">1</span><span class="p">],</span> <span class="n">dtype</span><span class="o">=</span><span class="s1">&#39;int64&#39;</span><span class="p">)</span>
    <span class="c1"># student model definition</span>
    <span class="n">model</span> <span class="o">=</span> <span class="n">MobileNet</span><span class="p">()</span>
    <span class="n">out</span> <span class="o">=</span> <span class="n">model</span><span class="o">.</span><span class="n">net</span><span class="p">(</span><span class="nb">input</span><span class="o">=</span><span class="n">image</span><span class="p">,</span> <span class="n">class_dim</span><span class="o">=</span><span class="mi">1000</span><span class="p">)</span>
    <span class="n">cost</span> <span class="o">=</span> <span class="n">fluid</span><span class="o">.</span><span class="n">layers</span><span class="o">.</span><span class="n">cross_entropy</span><span class="p">(</span><span class="nb">input</span><span class="o">=</span><span class="n">out</span><span class="p">,</span> <span class="n">label</span><span class="o">=</span><span class="n">label</span><span class="p">)</span>
    <span class="n">avg_cost</span> <span class="o">=</span> <span class="n">fluid</span><span class="o">.</span><span class="n">layers</span><span class="o">.</span><span class="n">mean</span><span class="p">(</span><span class="n">x</span><span class="o">=</span><span class="n">cost</span><span class="p">)</span>
</pre></div>
</td></tr></table>
<h3 id="2-teacher_program">2. Define teacher_program<a class="headerlink" href="#2-teacher_program" title="Permanent link">#</a></h3>
<p>After defining teacher_program, you can load the trained pretrained_model in the same step.</p>
<p>Inside teacher_program, add <code>with fluid.unique_name.guard():</code> so that the teacher's variable names are not affected by student_program and the pretrained parameters load correctly.</p>
<table class="codehilitetable"><tr><td class="linenos"><div class="linenodiv"><pre><span></span> 1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18</pre></div></td><td class="code"><div class="codehilite"><pre><span></span><span class="n">teacher_program</span> <span class="o">=</span> <span class="n">fluid</span><span class="o">.</span><span class="n">Program</span><span class="p">()</span>
<span class="n">teacher_startup</span> <span class="o">=</span> <span class="n">fluid</span><span class="o">.</span><span class="n">Program</span><span class="p">()</span>
<span class="k">with</span> <span class="n">fluid</span><span class="o">.</span><span class="n">program_guard</span><span class="p">(</span><span class="n">teacher_program</span><span class="p">,</span> <span class="n">teacher_startup</span><span class="p">):</span>
    <span class="k">with</span> <span class="n">fluid</span><span class="o">.</span><span class="n">unique_name</span><span class="o">.</span><span class="n">guard</span><span class="p">():</span>
        <span class="n">image</span> <span class="o">=</span> <span class="n">fluid</span><span class="o">.</span><span class="n">data</span><span class="p">(</span>
            <span class="n">name</span><span class="o">=</span><span class="s1">&#39;data&#39;</span><span class="p">,</span> <span class="n">shape</span><span class="o">=</span><span class="p">[</span><span class="bp">None</span><span class="p">]</span> <span class="o">+</span> <span class="p">[</span><span class="mi">3</span><span class="p">,</span> <span class="mi">224</span><span class="p">,</span> <span class="mi">224</span><span class="p">],</span> <span class="n">dtype</span><span class="o">=</span><span class="s1">&#39;float32&#39;</span><span class="p">)</span>
        <span class="c1"># teacher model definition</span>
        <span class="n">teacher_model</span> <span class="o">=</span> <span class="n">ResNet</span><span class="p">()</span>
        <span class="n">predict</span> <span class="o">=</span> <span class="n">teacher_model</span><span class="o">.</span><span class="n">net</span><span class="p">(</span><span class="n">image</span><span class="p">,</span> <span class="n">class_dim</span><span class="o">=</span><span class="mi">1000</span><span class="p">)</span>
<span class="n">exe</span><span class="o">.</span><span class="n">run</span><span class="p">(</span><span class="n">teacher_startup</span><span class="p">)</span>
<span class="k">def</span> <span class="nf">if_exist</span><span class="p">(</span><span class="n">var</span><span class="p">):</span>
    <span class="k">return</span> <span class="n">os</span><span class="o">.</span><span class="n">path</span><span class="o">.</span><span class="n">exists</span><span class="p">(</span>
        <span class="n">os</span><span class="o">.</span><span class="n">path</span><span class="o">.</span><span class="n">join</span><span class="p">(</span><span class="s2">&quot;./pretrained&quot;</span><span class="p">,</span> <span class="n">var</span><span class="o">.</span><span class="n">name</span><span class="p">))</span>
<span class="n">fluid</span><span class="o">.</span><span class="n">io</span><span class="o">.</span><span class="n">load_vars</span><span class="p">(</span>
    <span class="n">exe</span><span class="p">,</span>
    <span class="s2">&quot;./pretrained&quot;</span><span class="p">,</span>
    <span class="n">main_program</span><span class="o">=</span><span class="n">teacher_program</span><span class="p">,</span>
    <span class="n">predicate</span><span class="o">=</span><span class="n">if_exist</span><span class="p">)</span>
</pre></div>
</td></tr></table>
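<p>The <code>if_exist</code> predicate above decides, variable by variable, whether a matching parameter file exists on disk. The following standalone sketch reproduces only that filtering logic with the standard library, so it can be run without PaddlePaddle. The variable names and the temporary directory are hypothetical, and the predicate here takes a plain name string, whereas the tutorial's version receives a variable object and reads its <code>name</code> attribute.</p>

```python
import os
import tempfile


def make_if_exist(pretrained_dir):
    """Build a predicate that is true when a parameter file exists for a name."""
    def if_exist(var_name):
        return os.path.exists(os.path.join(pretrained_dir, var_name))
    return if_exist


with tempfile.TemporaryDirectory() as pretrained_dir:
    # Pretend that two parameters were saved by an earlier training run.
    for saved_name in ("conv1_weights", "fc_0.w_0"):
        open(os.path.join(pretrained_dir, saved_name), "wb").close()

    if_exist = make_if_exist(pretrained_dir)

    # Hypothetical variables of a program: parameters plus temporaries.
    program_vars = ["conv1_weights", "fc_0.w_0", "data", "fc_0.tmp_1"]
    loadable = [name for name in program_vars if if_exist(name)]

print(loadable)
```

<p>Only names that have a file in the directory pass the predicate, so inputs and intermediate variables are skipped rather than causing a load error.</p>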
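<h3 id="3">3. Select feature maps<a class="headerlink" href="#3" title="Permanent link">#</a></h3>
<p>With student_program and teacher_program defined, pick several feature maps from them in corresponding pairs. The knowledge distillation loss will be added to these pairs later.</p>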
<table class="codehilitetable"><tr><td class="linenos"><div class="linenodiv"><pre><span></span> 1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18</pre></div></td><td class="code"><div class="codehilite"><pre><span></span><span class="c1"># get all student variables</span>
<span class="n">student_vars</span> <span class="o">=</span> <span class="p">[]</span>
<span class="k">for</span> <span class="n">v</span> <span class="ow">in</span> <span class="n">student_program</span><span class="o">.</span><span class="n">list_vars</span><span class="p">():</span>
    <span class="k">try</span><span class="p">:</span>
        <span class="n">student_vars</span><span class="o">.</span><span class="n">append</span><span class="p">((</span><span class="n">v</span><span class="o">.</span><span class="n">name</span><span class="p">,</span> <span class="n">v</span><span class="o">.</span><span class="n">shape</span><span class="p">))</span>
    <span class="k">except</span><span class="p">:</span>
        <span class="k">pass</span>
<span class="k">print</span><span class="p">(</span><span class="s2">&quot;=&quot;</span><span class="o">*</span><span class="mi">50</span><span class="o">+</span><span class="s2">&quot;student_model_vars&quot;</span><span class="o">+</span><span class="s2">&quot;=&quot;</span><span class="o">*</span><span class="mi">50</span><span class="p">)</span>
<span class="k">print</span><span class="p">(</span><span class="n">student_vars</span><span class="p">)</span>
<span class="c1"># get all teacher variables</span>
<span class="n">teacher_vars</span> <span class="o">=</span> <span class="p">[]</span>
<span class="k">for</span> <span class="n">v</span> <span class="ow">in</span> <span class="n">teacher_program</span><span class="o">.</span><span class="n">list_vars</span><span class="p">():</span>
    <span class="k">try</span><span class="p">:</span>
        <span class="n">teacher_vars</span><span class="o">.</span><span class="n">append</span><span class="p">((</span><span class="n">v</span><span class="o">.</span><span class="n">name</span><span class="p">,</span> <span class="n">v</span><span class="o">.</span><span class="n">shape</span><span class="p">))</span>
    <span class="k">except</span><span class="p">:</span>
        <span class="k">pass</span>
<span class="k">print</span><span class="p">(</span><span class="s2">&quot;=&quot;</span><span class="o">*</span><span class="mi">50</span><span class="o">+</span><span class="s2">&quot;teacher_model_vars&quot;</span><span class="o">+</span><span class="s2">&quot;=&quot;</span><span class="o">*</span><span class="mi">50</span><span class="p">)</span>
<span class="k">print</span><span class="p">(</span><span class="n">teacher_vars</span><span class="p">)</span>
</pre></div>
</td></tr></table>
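<p>Once both lists of <code>(name, shape)</code> tuples are printed, one possible way to narrow down candidates is to keep only teacher/student pairs whose shapes match exactly, since an element-wise loss needs tensors of the same shape. The sketch below is a hypothetical helper, not part of PaddleSlim, and the shapes in the sample data are made up for illustration.</p>

```python
def candidate_pairs(teacher_vars, student_vars):
    """Return (teacher_name, student_name) pairs whose shapes match exactly."""
    students_by_shape = {}
    for name, shape in student_vars:
        students_by_shape.setdefault(tuple(shape), []).append(name)

    pairs = []
    for teacher_name, teacher_shape in teacher_vars:
        for student_name in students_by_shape.get(tuple(teacher_shape), []):
            pairs.append((teacher_name, student_name))
    return pairs


# Hypothetical (name, shape) lists in the format printed by the code above.
teacher_vars = [
    ("res5c.tmp_0", (-1, 2048, 7, 7)),
    ("bn5c_branch2b.output.1.tmp_3", (-1, 512, 7, 7)),
]
student_vars = [
    ("depthwise_conv2d_11.tmp_0", (-1, 512, 7, 7)),
    ("conv7_1.tmp_0", (-1, 1024, 7, 7)),
]

pairs = candidate_pairs(teacher_vars, student_vars)
print(pairs)
```

<p>Shape equality is only a first filter. Which of the matching pairs are worth distilling still depends on where the layers sit in each network.</p>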
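<h3 id="4-merge">4. Merge the student and teacher Programs (merge)<a class="headerlink" href="#4-merge" title="Permanent link">#</a></h3>
<p>PaddlePaddle uses a Program to describe a computation graph. To run the student and teacher Programs together, they need to be merged into a single Program.</p>
<p>The merge step involves a number of operations. For details, see the <a href="https://paddlepaddle.github.io/PaddleSlim/api/single_distiller_api/#merge">merge API documentation</a>.</p>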
<table class="codehilitetable"><tr><td class="linenos"><div class="linenodiv"><pre><span></span>1
2</pre></div></td><td class="code"><div class="codehilite"><pre><span></span><span class="n">data_name_map</span> <span class="o">=</span> <span class="p">{</span><span class="s1">&#39;data&#39;</span><span class="p">:</span> <span class="s1">&#39;image&#39;</span><span class="p">}</span>
<span class="n">student_program</span> <span class="o">=</span> <span class="n">merge</span><span class="p">(</span><span class="n">teacher_program</span><span class="p">,</span> <span class="n">student_program</span><span class="p">,</span> <span class="n">data_name_map</span><span class="p">,</span> <span class="n">place</span><span class="p">)</span>
</pre></div>
</td></tr></table>
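<p>As a rough mental model of what the merge does to variable names, the sketch below uses plain dictionaries in place of Programs. It assumes that teacher variables receive a <code>teacher_</code> prefix (inferred from the variable name used in step 5) and that teacher inputs listed in <code>data_name_map</code> are dropped in favor of the student input they map to. The function and the sample data are hypothetical and greatly simplified compared with the real <code>merge</code>.</p>

```python
def merge_names(teacher_vars, student_vars, data_name_map, name_prefix="teacher_"):
    """Simplified model of merging: prefix teacher names, share mapped inputs."""
    for teacher_input, student_input in data_name_map.items():
        if student_input not in student_vars:
            raise KeyError("student has no input named %r" % student_input)

    merged = dict(student_vars)
    for name, shape in teacher_vars.items():
        if name in data_name_map:
            # The teacher reads the student's input, so no copy is added.
            continue
        merged[name_prefix + name] = shape
    return merged


# Hypothetical variables: name -> shape.
teacher_vars = {"data": (-1, 3, 224, 224), "fc_0.tmp_1": (-1, 1000)}
student_vars = {"image": (-1, 3, 224, 224), "fc_0.tmp_1": (-1, 1000)}

merged = merge_names(teacher_vars, student_vars, {"data": "image"})
print(sorted(merged))
```

<p>The prefix keeps the two <code>fc_0.tmp_1</code> variables apart, which is why the teacher feature map is referred to by a prefixed name when the loss is added.</p>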
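<h3 id="5loss">5. Add the distillation loss<a class="headerlink" href="#5loss" title="Permanent link">#</a></h3>
<p>Adding the distillation loss may introduce some new variables (Variable). To avoid name collisions, you can wrap the code in <code>with fluid.name_scope("distill"):</code> to give the newly introduced variables their own name scope.</p>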
<table class="codehilitetable"><tr><td class="linenos"><div class="linenodiv"><pre><span></span>1
2
3
4
5
6
7
8</pre></div></td><td class="code"><div class="codehilite"><pre><span></span><span class="k">with</span> <span class="n">fluid</span><span class="o">.</span><span class="n">program_guard</span><span class="p">(</span><span class="n">student_program</span><span class="p">,</span> <span class="n">student_startup</span><span class="p">):</span>
    <span class="k">with</span> <span class="n">fluid</span><span class="o">.</span><span class="n">name_scope</span><span class="p">(</span><span class="s2">&quot;distill&quot;</span><span class="p">):</span>
        <span class="n">distill_loss</span> <span class="o">=</span> <span class="n">l2_loss</span><span class="p">(</span><span class="s1">&#39;teacher_bn5c_branch2b.output.1.tmp_3&#39;</span><span class="p">,</span> <span class="s1">&#39;depthwise_conv2d_11.tmp_0&#39;</span><span class="p">,</span> <span class="n">student_program</span><span class="p">)</span>
        <span class="n">distill_weight</span> <span class="o">=</span> <span class="mi">1</span>
        <span class="n">loss</span> <span class="o">=</span> <span class="n">avg_cost</span> <span class="o">+</span> <span class="n">distill_loss</span> <span class="o">*</span> <span class="n">distill_weight</span>
        <span class="n">opt</span> <span class="o">=</span> <span class="n">create_optimizer</span><span class="p">()</span>
        <span class="n">opt</span><span class="o">.</span><span class="n">minimize</span><span class="p">(</span><span class="n">loss</span><span class="p">)</span>
<span class="n">exe</span><span class="o">.</span><span class="n">run</span><span class="p">(</span><span class="n">student_startup</span><span class="p">)</span>
</pre></div>
</td></tr></table>
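<p>To show how the total loss is assembled, here is a small pure-Python sketch. It assumes the L2 distillation loss is the mean of squared differences between the teacher and student feature maps, which may differ in detail from the library's <code>l2_loss</code>. The flattened feature values and the task loss are hypothetical numbers.</p>

```python
def l2_distill_loss(teacher_feat, student_feat):
    """Mean squared difference between two flattened feature maps."""
    if len(teacher_feat) != len(student_feat):
        raise ValueError("feature maps must have the same number of elements")
    squared = [(s - t) ** 2 for t, s in zip(teacher_feat, student_feat)]
    return sum(squared) / len(squared)


# Hypothetical flattened feature maps and task loss.
teacher_feat = [0.5, 1.0, -1.0, 2.0]
student_feat = [0.0, 1.0, 0.0, 1.0]
avg_cost = 0.7        # stands in for the student's cross-entropy loss
distill_weight = 1    # same weighting as in the code above

distill_loss = l2_distill_loss(teacher_feat, student_feat)
loss = avg_cost + distill_loss * distill_weight

print(distill_loss, loss)
```

<p>Raising <code>distill_weight</code> makes the student follow the teacher's features more closely, while lowering it gives more weight to the labels.</p>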
<p>At this point we have the student_program for distillation training, and it can be trained and evaluated just like an ordinary program.</p>
</div>
</div>
<footer>
<div class="rst-footer-buttons" role="navigation" aria-label="footer navigation">
<a href="../../api/quantization_api/" class="btn btn-neutral float-right" title="量化">Next <span class="icon icon-circle-arrow-right"></span></a>
<a href="../nas_demo/" class="btn btn-neutral" title="SA搜索"><span class="icon icon-circle-arrow-left"></span> Previous</a>
</div>
<hr/>
<div role="contentinfo">
<!-- Copyright etc -->
</div>
Built with <a href="http://www.mkdocs.org">MkDocs</a> using a <a href="https://github.com/snide/sphinx_rtd_theme">theme</a> provided by <a href="https://readthedocs.org">Read the Docs</a>.
</footer>
</div>
</div>
</section>
</div>
<div class="rst-versions" role="note" style="cursor: pointer">
<span class="rst-current-version" data-toggle="rst-current-version">
<a href="https://github.com/PaddlePaddle/PaddleSlim/" class="fa fa-github" style="float: left; color: #fcfcfc"> GitHub</a>
<span><a href="../nas_demo/" style="color: #fcfcfc;">&laquo; Previous</a></span>
<span style="margin-left: 15px"><a href="../../api/quantization_api/" style="color: #fcfcfc">Next &raquo;</a></span>
</span>
</div>
<script>var base_url = '../..';</script>
<script src="../../js/theme.js" defer></script>
<script src="../../mathjax-config.js" defer></script>
<script src="https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.0/MathJax.js?config=TeX-AMS-MML_HTMLorMML" defer></script>
<script src="../../search/main.js" defer></script>
</body>
</html>
......@@ -85,6 +85,10 @@
</ul>
</li>
<li class="">
<a class="" href="../distillation_demo/">知识蒸馏</a>
</li>
</ul>
</li>
......@@ -98,15 +102,15 @@
</li>
<li class="">
<a class="" href="../../api/prune_api/">剪枝</a>
<a class="" href="../../api/prune_api/">剪枝与敏感度</a>
</li>
<li class="">
<a class="" href="../../api/analysis_api/">敏感度分析</a>
<a class="" href="../../api/analysis_api/">模型分析</a>
</li>
<li class="">
<a class="" href="../../api/single_distiller_api/">蒸馏</a>
<a class="" href="../../api/single_distiller_api/">知识蒸馏</a>
</li>
<li class="">
......@@ -116,9 +120,18 @@
<a class="" href="../../api/search_space/">搜索空间</a>
</li>
<li class="">
<a class="" href="../../table_latency/">硬件延时评估表</a>
</li>
</ul>
</li>
<li class="toctree-l1">
<a class="" href="../../algo/algo/">算法原理</a>
</li>
</ul>
</div>
&nbsp;
......@@ -258,7 +271,7 @@
<div class="rst-footer-buttons" role="navigation" aria-label="footer navigation">
<a href="../../api/quantization_api/" class="btn btn-neutral float-right" title="量化">Next <span class="icon icon-circle-arrow-right"></span></a>
<a href="../distillation_demo/" class="btn btn-neutral float-right" title="知识蒸馏">Next <span class="icon icon-circle-arrow-right"></span></a>
<a href="../quant_embedding_demo/" class="btn btn-neutral" title="Embedding量化"><span class="icon icon-circle-arrow-left"></span> Previous</a>
......@@ -292,14 +305,14 @@
<span><a href="../quant_embedding_demo/" style="color: #fcfcfc;">&laquo; Previous</a></span>
<span style="margin-left: 15px"><a href="../../api/quantization_api/" style="color: #fcfcfc">Next &raquo;</a></span>
<span style="margin-left: 15px"><a href="../distillation_demo/" style="color: #fcfcfc">Next &raquo;</a></span>
</span>
</div>
<script>var base_url = '../..';</script>
<script src="../../js/theme.js" defer></script>
<script src="../../mathjax-config.js" defer></script>
<script src="../../MathJax.js?config=TeX-AMS-MML_HTMLorMML" defer></script>
<script src="https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.0/MathJax.js?config=TeX-AMS-MML_HTMLorMML" defer></script>
<script src="../../search/main.js" defer></script>
</body>
......
......@@ -87,6 +87,10 @@
<a class="" href="../nas_demo/">SA搜索</a>
</li>
<li class="">
<a class="" href="../distillation_demo/">知识蒸馏</a>
</li>
</ul>
</li>
......@@ -100,15 +104,15 @@
</li>
<li class="">
<a class="" href="../../api/prune_api/">剪枝</a>
<a class="" href="../../api/prune_api/">剪枝与敏感度</a>
</li>
<li class="">
<a class="" href="../../api/analysis_api/">敏感度分析</a>
<a class="" href="../../api/analysis_api/">模型分析</a>
</li>
<li class="">
<a class="" href="../../api/single_distiller_api/">蒸馏</a>
<a class="" href="../../api/single_distiller_api/">知识蒸馏</a>
</li>
<li class="">
......@@ -118,9 +122,18 @@
<a class="" href="../../api/search_space/">搜索空间</a>
</li>
<li class="">
<a class="" href="../../table_latency/">硬件延时评估表</a>
</li>
</ul>
</li>
<li class="toctree-l1">
<a class="" href="../../algo/algo/">算法原理</a>
</li>
</ul>
</div>
&nbsp;
......@@ -313,7 +326,7 @@
<script>var base_url = '../..';</script>
<script src="../../js/theme.js" defer></script>
<script src="../../mathjax-config.js" defer></script>
<script src="../../MathJax.js?config=TeX-AMS-MML_HTMLorMML" defer></script>
<script src="https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.0/MathJax.js?config=TeX-AMS-MML_HTMLorMML" defer></script>
<script src="../../search/main.js" defer></script>
</body>
......
......@@ -87,6 +87,10 @@
<a class="" href="../nas_demo/">SA搜索</a>
</li>
<li class="">
<a class="" href="../distillation_demo/">知识蒸馏</a>
</li>
</ul>
</li>
......@@ -100,15 +104,15 @@
</li>
<li class="">
<a class="" href="../../api/prune_api/">剪枝</a>
<a class="" href="../../api/prune_api/">剪枝与敏感度</a>
</li>
<li class="">
<a class="" href="../../api/analysis_api/">敏感度分析</a>
<a class="" href="../../api/analysis_api/">模型分析</a>
</li>
<li class="">
<a class="" href="../../api/single_distiller_api/">蒸馏</a>
<a class="" href="../../api/single_distiller_api/">知识蒸馏</a>
</li>
<li class="">
......@@ -118,9 +122,18 @@
<a class="" href="../../api/search_space/">搜索空间</a>
</li>
<li class="">
<a class="" href="../../table_latency/">硬件延时评估表</a>
</li>
</ul>
</li>
<li class="toctree-l1">
<a class="" href="../../algo/algo/">算法原理</a>
</li>
</ul>
</div>
&nbsp;
......@@ -486,7 +499,7 @@ wget https://paddlerec.bj.bcebos.com/word2vec/test_mid_dir.tar
<script>var base_url = '../..';</script>
<script src="../../js/theme.js" defer></script>
<script src="../../mathjax-config.js" defer></script>
<script src="../../MathJax.js?config=TeX-AMS-MML_HTMLorMML" defer></script>
<script src="https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.0/MathJax.js?config=TeX-AMS-MML_HTMLorMML" defer></script>
<script src="../../search/main.js" defer></script>
</body>
......
......@@ -87,6 +87,10 @@
<a class="" href="../nas_demo/">SA搜索</a>
</li>
<li class="">
<a class="" href="../distillation_demo/">知识蒸馏</a>
</li>
</ul>
</li>
......@@ -100,15 +104,15 @@
</li>
<li class="">
<a class="" href="../../api/prune_api/">剪枝</a>
<a class="" href="../../api/prune_api/">剪枝与敏感度</a>
</li>
<li class="">
<a class="" href="../../api/analysis_api/">敏感度分析</a>
<a class="" href="../../api/analysis_api/">模型分析</a>
</li>
<li class="">
<a class="" href="../../api/single_distiller_api/">蒸馏</a>
<a class="" href="../../api/single_distiller_api/">知识蒸馏</a>
</li>
<li class="">
......@@ -118,9 +122,18 @@
<a class="" href="../../api/search_space/">搜索空间</a>
</li>
<li class="">
<a class="" href="../../table_latency/">硬件延时评估表</a>
</li>
</ul>
</li>
<li class="toctree-l1">
<a class="" href="../../algo/algo/">算法原理</a>
</li>
</ul>
</div>
&nbsp;
......@@ -264,7 +277,7 @@
<script>var base_url = '../..';</script>
<script src="../../js/theme.js" defer></script>
<script src="../../mathjax-config.js" defer></script>
<script src="../../MathJax.js?config=TeX-AMS-MML_HTMLorMML" defer></script>
<script src="https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.0/MathJax.js?config=TeX-AMS-MML_HTMLorMML" defer></script>
<script src="../../search/main.js" defer></script>
</body>
......
<!DOCTYPE html>
<!--[if IE 8]><html class="no-js lt-ie9" lang="en" > <![endif]-->
<!--[if gt IE 8]><!--> <html class="no-js" lang="en" > <!--<![endif]-->
<head>
<meta charset="utf-8">
<meta http-equiv="X-UA-Compatible" content="IE=edge">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<link rel="shortcut icon" href="../../img/favicon.ico">
<title>Sensitivity demo - PaddleSlim Docs</title>
<link href='https://fonts.googleapis.com/css?family=Lato:400,700|Roboto+Slab:400,700|Inconsolata:400,700' rel='stylesheet' type='text/css'>
<link rel="stylesheet" href="../../css/theme.css" type="text/css" />
<link rel="stylesheet" href="../../css/theme_extra.css" type="text/css" />
<link rel="stylesheet" href="//cdnjs.cloudflare.com/ajax/libs/highlight.js/9.12.0/styles/github.min.css">
<script>
// Current page data
var mkdocs_page_name = "Sensitivity demo";
var mkdocs_page_input_path = "tutorials/sensitivity_demo.md";
var mkdocs_page_url = null;
</script>
<script src="../../js/jquery-2.1.1.min.js" defer></script>
<script src="../../js/modernizr-2.8.3.min.js" defer></script>
<script src="//cdnjs.cloudflare.com/ajax/libs/highlight.js/9.12.0/highlight.min.js"></script>
<script>hljs.initHighlightingOnLoad();</script>
</head>
<body class="wy-body-for-nav" role="document">
<div class="wy-grid-for-nav">
<nav data-toggle="wy-nav-shift" class="wy-nav-side stickynav">
<div class="wy-side-nav-search">
<a href="../.." class="icon icon-home"> PaddleSlim Docs</a>
<div role="search">
<form id ="rtd-search-form" class="wy-form" action="../../search.html" method="get">
<input type="text" name="q" placeholder="Search docs" title="Type search term here" />
</form>
</div>
</div>
<div class="wy-menu wy-menu-vertical" data-spy="affix" role="navigation" aria-label="main navigation">
<ul class="current">
<li class="toctree-l1">
<a class="" href="../..">Home</a>
</li>
<li class="toctree-l1">
<span class="caption-text">Tutorials</span>
<ul class="subnav">
<li class="">
<a class="" href="../quant_post_demo/">Post-training quantization</a>
</li>
<li class="">
<a class="" href="../quant_aware_demo/">Quantization-aware training</a>
</li>
<li class="">
<a class="" href="../quant_embedding_demo/">Embedding quantization</a>
</li>
<li class="">
<a class="" href="../nas_demo/">SA search</a>
</li>
<li class="">
<a class="" href="../distillation_demo/">Knowledge distillation</a>
</li>
</ul>
</li>
<li class="toctree-l1">
<span class="caption-text">API</span>
<ul class="subnav">
<li class="">
<a class="" href="../../api/quantization_api/">Quantization</a>
</li>
<li class="">
<a class="" href="../../api/prune_api/">Pruning and sensitivity</a>
</li>
<li class="">
<a class="" href="../../api/analysis_api/">Model analysis</a>
</li>
<li class="">
<a class="" href="../../api/single_distiller_api/">Knowledge distillation</a>
</li>
<li class="">
<a class="" href="../../api/nas_api/">SA search</a>
</li>
<li class="">
<a class="" href="../../api/search_space/">Search space</a>
</li>
<li class="">
<a class="" href="../../table_latency/">Hardware latency table</a>
</li>
</ul>
</li>
<li class="toctree-l1">
<a class="" href="../../algo/algo/">Algorithm principles</a>
</li>
</ul>
</div>
&nbsp;
</nav>
<section data-toggle="wy-nav-shift" class="wy-nav-content-wrap">
<nav class="wy-nav-top" role="navigation" aria-label="top navigation">
<i data-toggle="wy-nav-top" class="fa fa-bars"></i>
<a href="../..">PaddleSlim Docs</a>
</nav>
<div class="wy-nav-content">
<div class="rst-content">
<div role="navigation" aria-label="breadcrumbs navigation">
<ul class="wy-breadcrumbs">
<li><a href="../..">Docs</a> &raquo;</li>
<li>Sensitivity demo</li>
<li class="wy-breadcrumbs-aside">
<a href="https://github.com/PaddlePaddle/PaddleSlim/edit/master/docs/tutorials/sensitivity_demo.md"
class="icon icon-github"> Edit on GitHub</a>
</li>
</ul>
<hr/>
</div>
<div role="main">
<div class="section">
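<p>This example shows how to analyze the sensitivity of each convolutional layer in a convolutional network, and how to choose a suitable set of pruning ratios from the computed sensitivities.
By default, the example automatically downloads and uses the MNIST dataset. The following models are supported:</p>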
<ul>
<li>MobileNetV1</li>
<li>MobileNetV2</li>
<li>ResNet50</li>
</ul>
<h2 id="1">1. API overview<a class="headerlink" href="#1" title="Permanent link">#</a></h2>
<p>This example uses the following APIs:</p>
<ul>
<li><a href="https://paddlepaddle.github.io/PaddleSlim/api/prune_api/#sensitivity">paddleslim.prune.sensitivity</a></li>
<li><a href="https://paddlepaddle.github.io/PaddleSlim/api/prune_api/#merge_sensitive">paddleslim.prune.merge_sensitive</a></li>
<li><a href="https://paddlepaddle.github.io/PaddleSlim/api/prune_api/#get_ratios_by_losssensitivities-loss">paddleslim.prune.get_ratios_by_loss</a></li>
</ul>
<h2 id="2">2. Running the example<a class="headerlink" href="#2" title="Permanent link">#</a></h2>
<p>Run the example by executing the following commands in the <code>PaddleSlim/demo/sensitive</code> directory:</p>
<table class="codehilitetable"><tr><td class="code"><div class="codehilite"><pre><span></span>export CUDA_VISIBLE_DEVICES=0
python train.py --model &quot;MobileNetV1&quot;
</pre></div>
</td></tr></table>
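<p>Run <code>python train.py --help</code> to see more options.</p>
<h2 id="3">3. Key steps<a class="headerlink" href="#3" title="Permanent link">#</a></h2>
<h3 id="31">3.1 Computing sensitivities<a class="headerlink" href="#31" title="Permanent link">#</a></h3>
<p>Before computing sensitivities, you need to build the network used for evaluation and implement a callback function that evaluates model accuracy.</p>
<p>Call <code>paddleslim.prune.sensitivity</code> to compute the sensitivities. The results are appended to the file specified by the <code>sensitivities_file</code> option; to recompute sensitivities from scratch, delete that file first.</p>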
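<p>As a rough illustration of what a sensitivity measurement amounts to, the sketch below computes, for each parameter and pruning ratio, the relative accuracy loss against an unpruned baseline. It does not call PaddleSlim: <code>compute_sensitivities</code>, <code>toy_eval</code>, and the parameter names are hypothetical, and the relative-loss definition is an assumption made for illustration rather than a statement of the library's internals.</p>

```python
def compute_sensitivities(param_names, pruned_ratios, eval_fn):
    """Return {param_name: {ratio: relative_accuracy_loss}}.

    eval_fn(param_name, ratio) is a hypothetical callback returning the
    accuracy after pruning `ratio` of the filters of `param_name`;
    eval_fn(None, 0.0) returns the unpruned baseline accuracy.
    """
    baseline = eval_fn(None, 0.0)
    sensitivities = {}
    for name in param_names:
        sensitivities[name] = {}
        for ratio in pruned_ratios:
            pruned_acc = eval_fn(name, ratio)
            # Assumed definition: loss relative to the baseline accuracy
            sensitivities[name][ratio] = (baseline - pruned_acc) / baseline
    return sensitivities


def toy_eval(param_name, ratio):
    # Toy stand-in for a real evaluation run: accuracy drops linearly
    # with the pruning ratio, at a different rate per layer.
    drop_per_unit = {"conv1_weights": 0.5, "conv2_weights": 0.1}
    if param_name is None:
        return 0.8
    return 0.8 - drop_per_unit[param_name] * ratio


sens_demo = compute_sensitivities(
    ["conv1_weights", "conv2_weights"], [0.1, 0.2], toy_eval)
```

<p>In this toy setup <code>conv1_weights</code> loses accuracy faster than <code>conv2_weights</code>, so it is the more sensitive layer and would tolerate a smaller pruning ratio.</p>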
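<p>If model evaluation is slow, the sensitivity computation can be sped up with multiple processes. For example, set <code>pruned_ratios=[0.1, 0.2, 0.3, 0.4]</code> in process 1 and store the results in <code>sensitivities_0.data</code>, then set <code>pruned_ratios=[0.5, 0.6, 0.7]</code> in process 2 and store the results in <code>sensitivities_1.data</code>. Each process then computes sensitivities only for its assigned pruning ratios. The processes can run on a single machine with multiple GPUs or across multiple machines.</p>
<p>The code is as follows:</p>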
<table class="codehilitetable"><tr><td class="code"><div class="codehilite"><pre><span></span># Process 1
from paddleslim.prune import sensitivity

sensitivity(
    val_program,  # network used for evaluation
    place,
    params,
    test,  # callback that evaluates model accuracy
    sensitivities_file=&quot;sensitivities_0.data&quot;,
    pruned_ratios=[0.1, 0.2, 0.3, 0.4])
</pre></div>
</td></tr></table>
<table class="codehilitetable"><tr><td class="code"><div class="codehilite"><pre><span></span># Process 2
from paddleslim.prune import sensitivity

sensitivity(
    val_program,  # network used for evaluation
    place,
    params,
    test,  # callback that evaluates model accuracy
    sensitivities_file=&quot;sensitivities_1.data&quot;,
    pruned_ratios=[0.5, 0.6, 0.7])
</pre></div>
</td></tr></table>
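<p>When more than two processes are used, the ratio list can be partitioned programmatically instead of by hand. The following is a minimal sketch in plain Python; <code>split_ratios</code> and the <code>jobs</code> layout are hypothetical helpers for illustration, not part of PaddleSlim.</p>

```python
def split_ratios(ratios, num_workers):
    """Split `ratios` into `num_workers` contiguous chunks of near-equal size."""
    base, extra = divmod(len(ratios), num_workers)
    chunks, start = [], 0
    for worker in range(num_workers):
        # The first `extra` workers take one additional ratio each
        size = base + (1 if worker < extra else 0)
        chunks.append(ratios[start:start + size])
        start += size
    return chunks


all_ratios = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7]

# One job description per process: its own output file and its own ratios
jobs = [
    {"sensitivities_file": "sensitivities_%d.data" % i, "pruned_ratios": chunk}
    for i, chunk in enumerate(split_ratios(all_ratios, 2))
]
```

<p>With two workers this reproduces the split used above: the first job gets <code>[0.1, 0.2, 0.3, 0.4]</code> and the second gets <code>[0.5, 0.6, 0.7]</code>. Each job's values would then be passed as the <code>sensitivities_file</code> and <code>pruned_ratios</code> arguments in its own process.</p>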
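<h3 id="32">3.2 Merging sensitivities<a class="headerlink" href="#32" title="Permanent link">#</a></h3>
<p>If you generated multiple sensitivity files with the multi-process approach from the previous section, you can merge them with <code>paddleslim.prune.merge_sensitive</code>. The merged sensitivities are stored in a <code>dict</code>. The code is as follows:</p>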
<table class="codehilitetable"><tr><td class="code"><div class="codehilite"><pre><span></span>from paddleslim.prune import merge_sensitive
sens = merge_sensitive([&quot;./sensitivities_0.data&quot;, &quot;./sensitivities_1.data&quot;])
</pre></div>
</td></tr></table>
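<p>To make the merge step concrete, here is a small sketch of combining partial results that are already loaded in memory. It assumes the sensitivities are nested dictionaries of the form <code>{parameter_name: {pruning_ratio: loss}}</code>; that layout, the function name <code>merge_sensitivity_dicts</code>, and the sample values are assumptions for illustration and do not describe how <code>merge_sensitive</code> is implemented.</p>

```python
def merge_sensitivity_dicts(dicts):
    """Merge several {param_name: {ratio: loss}} dicts into a single dict.

    If the same (param_name, ratio) pair appears more than once,
    the value from the later dict wins.
    """
    merged = {}
    for sens in dicts:
        for name, losses in sens.items():
            merged.setdefault(name, {}).update(losses)
    return merged


# Made-up partial results, as two processes might have produced them
part_0 = {"conv1_weights": {0.1: 0.01, 0.2: 0.03}}
part_1 = {"conv1_weights": {0.5: 0.20}, "conv2_weights": {0.5: 0.04}}

merged = merge_sensitivity_dicts([part_0, part_1])
```

<p>After the merge, each parameter carries the losses for every pruning ratio measured by any process.</p>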
<h3 id="33">3.3 Computing pruning ratios<a class="headerlink" href="#33" title="Permanent link">#</a></h3>
<p>Call <code>paddleslim.prune.get_ratios_by_loss</code> to compute a set of pruning ratios.</p>
<table class="codehilitetable"><tr><td class="code"><div class="codehilite"><pre><span></span>from paddleslim.prune import get_ratios_by_loss
ratios = get_ratios_by_loss(sens, 0.01)
</pre></div>
</td></tr></table>
<p>Here, <code>0.01</code> is a threshold: for each convolutional layer, the pruning ratio is the largest ratio that keeps the accuracy loss below <code>0.01</code>.</p>
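<p>The selection rule can be sketched in a few lines of plain Python. This is a simplified illustration that only looks at the measured ratios and assumes a <code>{parameter_name: {pruning_ratio: loss}}</code> layout; <code>max_ratio_under_loss</code> and the sample values are hypothetical, and the real <code>get_ratios_by_loss</code> may behave differently (for example, by interpolating between measured points).</p>

```python
def max_ratio_under_loss(sensitivities, loss_threshold):
    """For each parameter, pick the largest measured ratio with loss < threshold.

    Parameters with no qualifying ratio are left out, i.e. not pruned.
    """
    ratios = {}
    for name, losses in sensitivities.items():
        acceptable = [r for r, loss in losses.items() if loss < loss_threshold]
        if acceptable:
            ratios[name] = max(acceptable)
    return ratios


# Made-up sensitivities for three layers
sens_example = {
    "conv1_weights": {0.1: 0.005, 0.2: 0.02},
    "conv2_weights": {0.1: 0.001, 0.2: 0.004, 0.3: 0.009},
    "conv3_weights": {0.1: 0.05},
}

selected = max_ratio_under_loss(sens_example, 0.01)
# selected == {"conv1_weights": 0.1, "conv2_weights": 0.3}
```

<p>In this example <code>conv3_weights</code> is too sensitive at every measured ratio, so it is not assigned a pruning ratio.</p>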
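<p>After computing a set of pruning ratios, you can prune the network with <code>paddleslim.prune.Pruner</code> and compute its <code>FLOPs</code> with <code>paddleslim.analysis.flops</code>. If the <code>FLOPs</code> do not meet your requirement, adjust the threshold and compute a new set of pruning ratios.</p>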
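<p>The adjust-and-retry loop can be sketched as a simple search over candidate thresholds. In the sketch below, <code>find_threshold</code> is a hypothetical helper, and <code>flops_for_threshold</code> stands in for the real work of computing ratios, pruning the network, and measuring its FLOPs; here it is replaced by a made-up lookup table so that the example runs on its own.</p>

```python
def find_threshold(candidate_thresholds, flops_for_threshold, target_flops):
    """Return the smallest threshold whose pruned model meets `target_flops`.

    flops_for_threshold(threshold) is a caller-supplied function that returns
    the FLOPs of the model pruned with the ratios derived from `threshold`.
    Returns None if no candidate meets the target.
    """
    for threshold in sorted(candidate_thresholds):
        if flops_for_threshold(threshold) <= target_flops:
            return threshold
    return None


# Made-up FLOPs (in millions) per threshold, standing in for real measurements
toy_flops_table = {0.01: 95, 0.02: 90, 0.05: 75, 0.1: 50}

chosen = find_threshold(toy_flops_table.keys(), toy_flops_table.__getitem__, 80)
# chosen == 0.05
```

<p>Trying thresholds in increasing order favors the smallest accuracy loss that still reaches the FLOPs target.</p>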
</div>
</div>
<footer>
<hr/>
<div role="contentinfo">
<!-- Copyright etc -->
</div>
Built with <a href="http://www.mkdocs.org">MkDocs</a> using a <a href="https://github.com/snide/sphinx_rtd_theme">theme</a> provided by <a href="https://readthedocs.org">Read the Docs</a>.
</footer>
</div>
</div>
</section>
</div>
<div class="rst-versions" role="note" style="cursor: pointer">
<span class="rst-current-version" data-toggle="rst-current-version">
<a href="https://github.com/PaddlePaddle/PaddleSlim/" class="fa fa-github" style="float: left; color: #fcfcfc"> GitHub</a>
</span>
</div>
<script>var base_url = '../..';</script>
<script src="../../js/theme.js" defer></script>
<script src="../../mathjax-config.js" defer></script>
<script src="https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.0/MathJax.js?config=TeX-AMS-MML_HTMLorMML" defer></script>
<script src="../../search/main.js" defer></script>
</body>
</html>