<!DOCTYPE html>
<!--[if IE 8]><html class="no-js lt-ie9" lang="en" > <![endif]-->
<!--[if gt IE 8]><!--> <html class="no-js" lang="en" > <!--<![endif]-->
<head>
  <meta charset="utf-8">
  
  <meta name="viewport" content="width=device-width, initial-scale=1.0">
  
  <title>Design Doc: Gradient Operators Registration &mdash; PaddlePaddle Documentation</title>
  

  
  

  

  
  
    

  

  
  
    <link rel="stylesheet" href="../_static/css/theme.css" type="text/css" />
  

  
  
        <link rel="index" title="Index"
              href="../genindex.html"/>
        <link rel="search" title="Search" href="../search.html"/>
    <link rel="top" title="PaddlePaddle Documentation" href="../index.html"/> 

  
  <script src="../_static/js/modernizr.min.js"></script>

</head>

<body class="wy-body-for-nav" role="document">

  <div class="wy-grid-for-nav">

    
    <nav data-toggle="wy-nav-shift" class="wy-nav-side">
      <div class="wy-side-scroll">
        <div class="wy-side-nav-search">
          

          
            <a href="../index_cn.html" class="icon icon-home"> PaddlePaddle
          

          
          </a>

          
            
            
          

          
<div role="search">
  <form id="rtd-search-form" class="wy-form" action="../search.html" method="get">
    <input type="text" name="q" placeholder="Search docs" />
    <input type="hidden" name="check_keywords" value="yes" />
    <input type="hidden" name="area" value="default" />
  </form>
</div>

          
        </div>

        <div class="wy-menu wy-menu-vertical" data-spy="affix" role="navigation" aria-label="main navigation">
          
            
            
                <ul>
<li class="toctree-l1"><a class="reference internal" href="../getstarted/index_cn.html">Getting Started</a></li>
<li class="toctree-l1"><a class="reference internal" href="../build_and_install/index_cn.html">Installation and Compilation</a></li>
<li class="toctree-l1"><a class="reference internal" href="../howto/index_cn.html">Advanced Usage</a></li>
<li class="toctree-l1"><a class="reference internal" href="../dev/index_cn.html">Development Standards</a></li>
<li class="toctree-l1"><a class="reference internal" href="../faq/index_cn.html">FAQ</a></li>
</ul>

            
          
        </div>
      </div>
    </nav>

    <section data-toggle="wy-nav-shift" class="wy-nav-content-wrap">

      
      <nav class="wy-nav-top" role="navigation" aria-label="top navigation">
        <i data-toggle="wy-nav-top" class="fa fa-bars"></i>
        <a href="../index_cn.html">PaddlePaddle</a>
      </nav>


      
      <div class="wy-nav-content">
        <div class="rst-content">
          



          <div role="main" class="document" itemscope="itemscope" itemtype="http://schema.org/Article">
           <div itemprop="articleBody">
            
  <div class="section" id="design-doc-gradient-operators-registration">
<span id="design-doc-gradient-operators-registration"></span><h1>Design Doc: Gradient Operators Registration<a class="headerlink" href="#design-doc-gradient-operators-registration" title="Permalink to this headline"></a></h1>
<div class="section" id="the-problem-posed">
<span id="the-problem-posed"></span><h2>The Problem Posed<a class="headerlink" href="#the-problem-posed" title="Permalink to this headline"></a></h2>
<p>Currently, for each C++ operator class definition, a <em>gradient operator creator</em> function is registered, which takes as input a C++ operator instance and returns the corresponding gradient operator instance.</p>
<p>However, we noticed two problems with the current design:</p>
<ol class="simple">
<li>As we decided to separate the <em>compilation</em> and the <em>execution</em> phases, we need to change the creator to take an <code class="docutils literal"><span class="pre">OpDesc</span></code> protobuf message in a <code class="docutils literal"><span class="pre">ProgramDesc</span></code> and insert the corresponding gradient <code class="docutils literal"><span class="pre">OpDesc</span></code> messages into that <code class="docutils literal"><span class="pre">ProgramDesc</span></code> message, instead of operating on C++ operator instances.</li>
<li>For some operators, the gradient computation can be written in terms of existing operators. For example, for the <em>minus</em> operator (Out = X - Y), the gradient consists of two existing operators: an <em>identity</em> operator, which passes the output gradient through as the gradient of X, and a <em>scale</em> operator, which multiplies the output gradient by -1 to produce the gradient of Y. Hence the registration mechanism needs to support mapping an operator to a set of operators for the gradient computation.</li>
</ol>
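<p>The decomposition in item 2 can be sketched as a function from a forward operator description to a list of gradient operator descriptions. This is a minimal sketch under assumptions: <code class="docutils literal"><span class="pre">OpDescBind</span></code> is a simplified, hypothetical struct (the real class wraps the <code class="docutils literal"><span class="pre">OpDesc</span></code> protobuf message), <code class="docutils literal"><span class="pre">GradVarName</span></code> and <code class="docutils literal"><span class="pre">MinusGradOpDescs</span></code> are hypothetical helpers, the <code class="docutils literal"><span class="pre">@GRAD</span></code> suffix is an assumed naming convention, and the slot names and operator type strings are illustrative.</p>

```cpp
#include <map>
#include <memory>
#include <string>
#include <utility>
#include <vector>

// Hypothetical, simplified stand-in for OpDescBind: operator type, named
// input/output slots holding variable names, and float attributes.
struct OpDescBind {
  std::string type;
  std::map<std::string, std::vector<std::string>> inputs;
  std::map<std::string, std::vector<std::string>> outputs;
  std::map<std::string, float> attrs;
};

// Hypothetical helper: assumes the gradient of variable v is named v@GRAD.
std::string GradVarName(const std::string& var) { return var + "@GRAD"; }

// For Out = X - Y we have dX = dOut and dY = -dOut, so the gradient of
// "minus" is expressed with two existing operators instead of a new kernel.
std::vector<std::unique_ptr<OpDescBind>> MinusGradOpDescs(const OpDescBind& fwd) {
  const std::string d_out = GradVarName(fwd.outputs.at("Out")[0]);
  std::vector<std::unique_ptr<OpDescBind>> grad_ops;

  // identity: X@GRAD = Out@GRAD
  std::unique_ptr<OpDescBind> identity(new OpDescBind);
  identity->type = "identity";
  identity->inputs["X"] = {d_out};
  identity->outputs["Out"] = {GradVarName(fwd.inputs.at("X")[0])};
  grad_ops.push_back(std::move(identity));

  // scale: Y@GRAD = -1 * Out@GRAD
  std::unique_ptr<OpDescBind> scale(new OpDescBind);
  scale->type = "scale";
  scale->inputs["X"] = {d_out};
  scale->outputs["Out"] = {GradVarName(fwd.inputs.at("Y")[0])};
  scale->attrs["scale"] = -1.0f;
  grad_ops.push_back(std::move(scale));

  return grad_ops;
}
```

<p>Because the result is a list, one registry entry per forward operator suffices no matter how many operators its gradient expands into.</p>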
</div>
<div class="section" id="the-current-implementation">
<span id="the-current-implementation"></span><h2>The Current Implementation<a class="headerlink" href="#the-current-implementation" title="Permalink to this headline"></a></h2>
<p>Instances of the C++ class <code class="docutils literal"><span class="pre">OpInfo</span></code> are stored in an associative map whose key is the operator type. The field <code class="docutils literal"><span class="pre">grad_op_type_</span></code> holds the type of the associated gradient operator. To create its gradient operator, an operator looks up this type and invokes the <code class="docutils literal"><span class="pre">OpInfo::creator_</span></code> registered for the gradient operator. The pseudo code is as follows:</p>
<div class="highlight-cpp"><div class="highlight"><pre><span></span><span class="k">struct</span> <span class="n">OpInfo</span> <span class="p">{</span>
  <span class="n">std</span><span class="o">::</span><span class="n">function</span><span class="o">&lt;</span><span class="n">OperatorBase</span><span class="o">*</span><span class="p">(...)</span><span class="o">&gt;</span> <span class="n">creator_</span><span class="p">;</span>
  <span class="n">std</span><span class="o">::</span><span class="n">string</span> <span class="n">grad_op_type_</span><span class="p">;</span>
  <span class="p">...</span>
<span class="p">};</span>

<span class="n">map</span><span class="o">&lt;</span><span class="n">string</span><span class="p">,</span> <span class="n">OpInfo</span><span class="o">&gt;</span> <span class="n">OpInfoMap</span><span class="p">;</span>

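<span class="n">OperatorBase</span><span class="o">*</span> <span class="nf">CreateGradientOperator</span><span class="p">(</span><span class="k">const</span> <span class="n">OperatorBase</span><span class="o">&amp;</span> <span class="n">op</span><span class="p">)</span> <span class="p">{</span>
  <span class="c1">// Look up the gradient operator type registered for the forward operator,</span>
  <span class="k">const</span> <span class="n">std</span><span class="o">::</span><span class="n">string</span><span class="o">&amp;</span> <span class="n">grad_type</span> <span class="o">=</span> <span class="n">OpInfoMap</span><span class="p">.</span><span class="n">at</span><span class="p">(</span><span class="n">op</span><span class="p">.</span><span class="n">Type</span><span class="p">()).</span><span class="n">grad_op_type_</span><span class="p">;</span>
  <span class="c1">// then invoke the creator registered for that gradient operator type.</span>
  <span class="k">return</span> <span class="n">OpInfoMap</span><span class="p">.</span><span class="n">at</span><span class="p">(</span><span class="n">grad_type</span><span class="p">).</span><span class="n">creator_</span><span class="p">(...);</span>
<span class="p">}</span>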
</pre></div>
</div>
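<p>For reference, the two-step lookup of the current design can be sketched as compilable C++. The sketch makes several assumptions: <code class="docutils literal"><span class="pre">OperatorBase</span></code> is reduced to a type name, <code class="docutils literal"><span class="pre">creator_</span></code> takes no arguments and returns a <code class="docutils literal"><span class="pre">std::unique_ptr</span></code> instead of a raw pointer, and <code class="docutils literal"><span class="pre">OpInfoMap</span></code> is wrapped in a function-local static to avoid static-initialization-order problems.</p>

```cpp
#include <functional>
#include <map>
#include <memory>
#include <string>
#include <utility>

// Hypothetical minimal operator: only its type name is modelled here.
class OperatorBase {
 public:
  explicit OperatorBase(std::string type) : type_(std::move(type)) {}
  virtual ~OperatorBase() = default;
  const std::string& Type() const { return type_; }

 private:
  std::string type_;
};

struct OpInfo {
  // Creates an instance of the registered operator type (arguments omitted).
  std::function<std::unique_ptr<OperatorBase>()> creator_;
  // Type of the associated gradient operator, e.g. "minus_grad".
  std::string grad_op_type_;
};

// Registry keyed by operator type.
std::map<std::string, OpInfo>& OpInfoMap() {
  static std::map<std::string, OpInfo> op_info_map;
  return op_info_map;
}

// Find the gradient type of the forward op, then call that type's creator_.
std::unique_ptr<OperatorBase> CreateGradientOperator(const OperatorBase& op) {
  const std::string& grad_type = OpInfoMap().at(op.Type()).grad_op_type_;
  return OpInfoMap().at(grad_type).creator_();
}
```

<p>The forward entry only names a single gradient type, which is why this design cannot map one operator to several gradient operators.</p>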
</div>
<div class="section" id="proposed-solution">
<span id="proposed-solution"></span><h2>Proposed Solution<a class="headerlink" href="#proposed-solution" title="Permalink to this headline"></a></h2>
<p>The mapping from an operator to its gradient operators is a function with the following interface:</p>
<div class="highlight-cpp"><div class="highlight"><pre><span></span><span class="c1">// (OpDesc) --&gt; vector&lt;OpDesc&gt;</span>
<span class="n">std</span><span class="o">::</span><span class="n">function</span><span class="o">&lt;</span><span class="n">std</span><span class="o">::</span><span class="n">vector</span><span class="o">&lt;</span><span class="n">std</span><span class="o">::</span><span class="n">unique_ptr</span><span class="o">&lt;</span><span class="n">OpDescBind</span><span class="o">&gt;&gt;</span><span class="p">(</span><span class="k">const</span> <span class="n">OpDescBind</span><span class="o">&amp;</span><span class="p">)</span><span class="o">&gt;</span><span class="p">;</span>
</pre></div>
</div>
<p>The function takes the <code class="docutils literal"><span class="pre">OpDescBind</span></code> of the forward operator and returns one or more gradient operator descriptions. <code class="docutils literal"><span class="pre">OpDescBind</span></code> is a C++ wrapper around the protobuf message <code class="docutils literal"><span class="pre">OpDesc</span></code> that makes <code class="docutils literal"><span class="pre">OpDesc</span></code> convenient to manipulate.</p>
<p>A <code class="docutils literal"><span class="pre">GradOpDescMaker</span></code> will be registered in <code class="docutils literal"><span class="pre">OpInfo</span></code>, replacing the <code class="docutils literal"><span class="pre">grad_op_type_</span></code> field. <code class="docutils literal"><span class="pre">OpInfo</span></code> will then look like this:</p>
<div class="highlight-cpp"><div class="highlight"><pre><span></span><span class="k">struct</span> <span class="n">OpInfo</span> <span class="p">{</span>
  <span class="n">std</span><span class="o">::</span><span class="n">function</span><span class="o">&lt;</span><span class="n">std</span><span class="o">::</span><span class="n">vector</span><span class="o">&lt;</span><span class="n">std</span><span class="o">::</span><span class="n">unique_ptr</span><span class="o">&lt;</span><span class="n">OpDescBind</span><span class="o">&gt;&gt;</span><span class="p">(</span><span class="k">const</span> <span class="n">OpDescBind</span><span class="o">&amp;</span><span class="p">)</span><span class="o">&gt;</span>  <span class="n">grad_op_maker_</span><span class="p">;</span>
  <span class="p">...</span>
<span class="p">};</span>
</pre></div>
</div>
<p>The <code class="docutils literal"><span class="pre">grad_op_maker_</span></code> is a <code class="docutils literal"><span class="pre">nullptr</span></code> if the operator does not have any associated gradient operators.</p>
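<p>A minimal sketch of how <code class="docutils literal"><span class="pre">grad_op_maker_</span></code> could be stored and queried, assuming a simplified, hypothetical <code class="docutils literal"><span class="pre">OpDescBind</span></code>, a hypothetical <code class="docutils literal"><span class="pre">GradOpMakerFN</span></code> alias and <code class="docutils literal"><span class="pre">MakeGradOpDescs</span></code> helper, and a function-local static registry. An empty <code class="docutils literal"><span class="pre">std::function</span></code> compares equal to <code class="docutils literal"><span class="pre">nullptr</span></code>, which is how an operator without a gradient is recognized.</p>

```cpp
#include <functional>
#include <map>
#include <memory>
#include <string>
#include <vector>

// Hypothetical, simplified stand-in for OpDescBind.
struct OpDescBind {
  std::string type;
  std::vector<std::string> inputs;
  std::vector<std::string> outputs;
};

// Hypothetical alias for the maker signature stored in OpInfo.
using GradOpMakerFN =
    std::function<std::vector<std::unique_ptr<OpDescBind>>(const OpDescBind&)>;

struct OpInfo {
  GradOpMakerFN grad_op_maker_;  // empty (== nullptr) if there is no gradient
};

std::map<std::string, OpInfo>& OpInfoMap() {
  static std::map<std::string, OpInfo> op_info_map;
  return op_info_map;
}

// Returns the gradient op descriptions for fwd, or none if the operator
// has no registered gradient maker.
std::vector<std::unique_ptr<OpDescBind>> MakeGradOpDescs(const OpDescBind& fwd) {
  const OpInfo& info = OpInfoMap().at(fwd.type);
  if (info.grad_op_maker_ == nullptr) return {};
  return info.grad_op_maker_(fwd);
}
```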
<p>We propose a base class called <code class="docutils literal"><span class="pre">GradOpDescMakerBase</span></code> that makes it easy for operator developers to generate gradient operator descriptions. Its public interface is:</p>
<div class="highlight-cpp"><div class="highlight"><pre><span></span><span class="k">class</span> <span class="nc">GradOpDescMakerBase</span> <span class="p">{</span>
<span class="k">public</span><span class="o">:</span>
  <span class="k">explicit</span> <span class="n">GradOpDescMakerBase</span><span class="p">(</span><span class="k">const</span> <span class="n">OpDescBind</span><span class="o">&amp;</span> <span class="n">fwd_op</span><span class="p">);</span>
  <span class="k">virtual</span> <span class="o">~</span><span class="n">GradOpDescMakerBase</span><span class="p">()</span> <span class="o">=</span> <span class="k">default</span><span class="p">;</span>
  <span class="k">virtual</span> <span class="n">std</span><span class="o">::</span><span class="n">vector</span><span class="o">&lt;</span><span class="n">std</span><span class="o">::</span><span class="n">unique_ptr</span><span class="o">&lt;</span><span class="n">OpDescBind</span><span class="o">&gt;&gt;</span> <span class="k">operator</span><span class="p">()()</span> <span class="k">const</span> <span class="o">=</span> <span class="mi">0</span><span class="p">;</span>
<span class="p">};</span>
</pre></div>
</div>
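<p>To illustrate the interface, here is a hedged sketch of a concrete maker for the <em>minus</em> operator (Out = X - Y), which emits an identity operator for the gradient of X and a scale operator for the gradient of Y. It assumes a simplified, hypothetical <code class="docutils literal"><span class="pre">OpDescBind</span></code> with positional input/output lists and no attributes (so the scale factor of -1 is not recorded), a hypothetical <code class="docutils literal"><span class="pre">fwd_op_</span></code> member holding a reference to the forward operator, and an assumed <code class="docutils literal"><span class="pre">@GRAD</span></code> suffix for gradient variables.</p>

```cpp
#include <memory>
#include <string>
#include <vector>

// Hypothetical, simplified stand-in for OpDescBind (no attributes).
struct OpDescBind {
  std::string type;
  std::vector<std::string> inputs;   // e.g. {X, Y}
  std::vector<std::string> outputs;  // e.g. {Out}
};

class GradOpDescMakerBase {
 public:
  explicit GradOpDescMakerBase(const OpDescBind& fwd_op) : fwd_op_(fwd_op) {}
  virtual ~GradOpDescMakerBase() = default;
  virtual std::vector<std::unique_ptr<OpDescBind>> operator()() const = 0;

 protected:
  // The forward operator; a maker must not outlive it.
  const OpDescBind& fwd_op_;
};

// Sketch of a maker for Out = X - Y: dX = dOut (identity), dY = -dOut (scale).
class MinusGradOpDescMaker : public GradOpDescMakerBase {
 public:
  using GradOpDescMakerBase::GradOpDescMakerBase;

  std::vector<std::unique_ptr<OpDescBind>> operator()() const override {
    const std::string d_out = fwd_op_.outputs[0] + "@GRAD";
    std::vector<std::unique_ptr<OpDescBind>> ops;
    ops.emplace_back(new OpDescBind{"identity", {d_out}, {fwd_op_.inputs[0] + "@GRAD"}});
    ops.emplace_back(new OpDescBind{"scale", {d_out}, {fwd_op_.inputs[1] + "@GRAD"}});
    return ops;
  }
};
```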
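<p>Any subclass of <code class="docutils literal"><span class="pre">GradOpDescMakerBase</span></code> can be converted into a <code class="docutils literal"><span class="pre">std::function&lt;std::vector&lt;std::unique_ptr&lt;OpDescBind&gt;&gt;(const</span> <span class="pre">OpDescBind&amp;)&gt;</span></code> as follows, where <code class="docutils literal"><span class="pre">GradOpMaker</span></code> stands for the concrete maker class:</p>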
<div class="highlight-cpp"><div class="highlight"><pre><span></span><span class="k">using</span> <span class="n">GradOpMaker</span> <span class="o">=</span> <span class="p">...;</span>
<span class="n">std</span><span class="o">::</span><span class="n">function</span><span class="o">&lt;</span><span class="n">std</span><span class="o">::</span><span class="n">vector</span><span class="o">&lt;</span><span class="n">std</span><span class="o">::</span><span class="n">unique_ptr</span><span class="o">&lt;</span><span class="n">OpDescBind</span><span class="o">&gt;&gt;</span><span class="p">(</span><span class="k">const</span> <span class="n">OpDescBind</span><span class="o">&amp;</span><span class="p">)</span><span class="o">&gt;</span> <span class="n">func</span><span class="p">;</span>
<span class="n">func</span> <span class="o">=</span> <span class="p">[]</span> <span class="p">(</span><span class="k">const</span> <span class="n">OpDescBind</span><span class="o">&amp;</span> <span class="n">fwd_op</span><span class="p">)</span> <span class="p">{</span>
  <span class="n">GradOpMaker</span> <span class="n">maker</span><span class="p">(</span><span class="n">fwd_op</span><span class="p">);</span>
  <span class="k">return</span> <span class="nf">maker</span><span class="p">();</span>
<span class="p">};</span>
</pre></div>
</div>
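<p>The lambda above can be generalized into a function template that wraps any subclass of <code class="docutils literal"><span class="pre">GradOpDescMakerBase</span></code>. This is a sketch under assumptions: <code class="docutils literal"><span class="pre">ToGradOpMakerFN</span></code>, <code class="docutils literal"><span class="pre">GradOpMakerFN</span></code> and <code class="docutils literal"><span class="pre">DefaultGradOpDescMaker</span></code> are hypothetical names, and <code class="docutils literal"><span class="pre">OpDescBind</span></code> is reduced to an operator type. A <code class="docutils literal"><span class="pre">static_assert</span></code> rejects types that do not derive from the base class at compile time.</p>

```cpp
#include <functional>
#include <memory>
#include <string>
#include <type_traits>
#include <vector>

// Hypothetical, simplified stand-in for OpDescBind.
struct OpDescBind {
  std::string type;
};

class GradOpDescMakerBase {
 public:
  explicit GradOpDescMakerBase(const OpDescBind& fwd_op) : fwd_op_(fwd_op) {}
  virtual ~GradOpDescMakerBase() = default;
  virtual std::vector<std::unique_ptr<OpDescBind>> operator()() const = 0;

 protected:
  const OpDescBind& fwd_op_;
};

using GradOpMakerFN =
    std::function<std::vector<std::unique_ptr<OpDescBind>>(const OpDescBind&)>;

// Wraps a maker class into the std::function stored in OpInfo: each call
// constructs a fresh maker from the forward op and returns its result.
template <typename GradOpMaker>
GradOpMakerFN ToGradOpMakerFN() {
  static_assert(std::is_base_of<GradOpDescMakerBase, GradOpMaker>::value,
                "GradOpMaker must derive from GradOpDescMakerBase");
  return [](const OpDescBind& fwd_op) {
    GradOpMaker maker(fwd_op);
    return maker();
  };
}

// Illustrative maker that emits a single "<type>_grad" operator.
class DefaultGradOpDescMaker : public GradOpDescMakerBase {
 public:
  using GradOpDescMakerBase::GradOpDescMakerBase;

  std::vector<std::unique_ptr<OpDescBind>> operator()() const override {
    std::vector<std::unique_ptr<OpDescBind>> ops;
    ops.emplace_back(new OpDescBind{fwd_op_.type + "_grad"});
    return ops;
  }
};
```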
<p>Because <code class="docutils literal"><span class="pre">GradOpDescMakerBase</span></code> is a class, we can give it helper methods. The basic helpers return the variables of the <code class="docutils literal"><span class="pre">Input</span></code>, <code class="docutils literal"><span class="pre">Output</span></code>, <code class="docutils literal"><span class="pre">InputGradient</span></code>, and <code class="docutils literal"><span class="pre">OutputGradient</span></code> of the forward operator.</p>
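<p>A hedged sketch of such helpers follows. The method names <code class="docutils literal"><span class="pre">Input</span></code>, <code class="docutils literal"><span class="pre">Output</span></code>, <code class="docutils literal"><span class="pre">InputGrad</span></code> and <code class="docutils literal"><span class="pre">OutputGrad</span></code>, the map-based <code class="docutils literal"><span class="pre">OpDescBind</span></code>, the <code class="docutils literal"><span class="pre">@GRAD</span></code> suffix convention, and the <code class="docutils literal"><span class="pre">mul_grad</span></code> slot names are all assumptions for illustration. Keeping the forward operator private and exposing only the helpers means a concrete maker just states which variables each gradient operator reads and writes.</p>

```cpp
#include <map>
#include <memory>
#include <string>
#include <utility>
#include <vector>

// Hypothetical, simplified stand-in for OpDescBind with named slots.
struct OpDescBind {
  std::string type;
  std::map<std::string, std::vector<std::string>> inputs;
  std::map<std::string, std::vector<std::string>> outputs;
};

class GradOpDescMakerBase {
 public:
  explicit GradOpDescMakerBase(const OpDescBind& fwd_op) : fwd_op_(fwd_op) {}
  virtual ~GradOpDescMakerBase() = default;
  virtual std::vector<std::unique_ptr<OpDescBind>> operator()() const = 0;

 protected:
  // Variables bound to a slot of the forward operator.
  std::vector<std::string> Input(const std::string& name) const {
    return fwd_op_.inputs.at(name);
  }
  std::vector<std::string> Output(const std::string& name) const {
    return fwd_op_.outputs.at(name);
  }
  // Gradients of those variables (assumed "@GRAD" suffix).
  std::vector<std::string> InputGrad(const std::string& name) const {
    return ToGrad(Input(name));
  }
  std::vector<std::string> OutputGrad(const std::string& name) const {
    return ToGrad(Output(name));
  }

 private:
  static std::vector<std::string> ToGrad(std::vector<std::string> vars) {
    for (auto& var : vars) var += "@GRAD";
    return vars;
  }
  const OpDescBind& fwd_op_;
};

// Illustrative maker for Out = X * Y that relies only on the helpers.
class MulGradOpDescMaker : public GradOpDescMakerBase {
 public:
  using GradOpDescMakerBase::GradOpDescMakerBase;

  std::vector<std::unique_ptr<OpDescBind>> operator()() const override {
    std::unique_ptr<OpDescBind> grad(new OpDescBind);
    grad->type = "mul_grad";
    grad->inputs["X"] = Input("X");
    grad->inputs["Y"] = Input("Y");
    grad->inputs["Out@GRAD"] = OutputGrad("Out");
    grad->outputs["X@GRAD"] = InputGrad("X");
    grad->outputs["Y@GRAD"] = InputGrad("Y");
    std::vector<std::unique_ptr<OpDescBind>> ops;
    ops.push_back(std::move(grad));
    return ops;
  }
};
```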
<p>The registration macros must change at the same time. In the current solution, there is no difference between forward operators and backward operators, so <code class="docutils literal"><span class="pre">REGISTER_OP</span></code> registers just one operator. When an operator has an <code class="docutils literal"><span class="pre">OpProtoAndCheckerMaker</span></code> and a <code class="docutils literal"><span class="pre">GradOpDescMaker</span></code>, we simply list them in the same <code class="docutils literal"><span class="pre">REGISTER_OPERATOR</span></code> macro, which can be implemented as a variadic macro using <code class="docutils literal"><span class="pre">__VA_ARGS__</span></code>.</p>
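<p>A minimal sketch of such a variadic macro, assuming hypothetical <code class="docutils literal"><span class="pre">OperatorRegistrar</span></code> and <code class="docutils literal"><span class="pre">OpInfoFiller</span></code> helpers that are not part of this design. The operator class is folded into <code class="docutils literal"><span class="pre">__VA_ARGS__</span></code> so the argument list is never empty; each extra type is then dispatched on its base class, filling <code class="docutils literal"><span class="pre">grad_op_maker_</span></code> for a <code class="docutils literal"><span class="pre">GradOpDescMakerBase</span></code> subclass and otherwise treating it as the proto maker. Registering the operator creator is omitted, and the placeholder classes mirror the <em>minus</em> example of this document.</p>

```cpp
#include <functional>
#include <map>
#include <memory>
#include <string>
#include <type_traits>
#include <vector>

// Hypothetical, simplified stand-ins for the framework types.
struct OpDescBind {
  std::string type;
};

class GradOpDescMakerBase {
 public:
  explicit GradOpDescMakerBase(const OpDescBind& fwd_op) : fwd_op_(fwd_op) {}
  virtual ~GradOpDescMakerBase() = default;
  virtual std::vector<std::unique_ptr<OpDescBind>> operator()() const = 0;

 protected:
  const OpDescBind& fwd_op_;
};

struct OpInfo {
  bool has_proto_maker_ = false;
  std::function<std::vector<std::unique_ptr<OpDescBind>>(const OpDescBind&)>
      grad_op_maker_;
};

std::map<std::string, OpInfo>& OpInfoMap() {
  static std::map<std::string, OpInfo> op_info_map;
  return op_info_map;
}

// Any extra argument that is not a gradient maker is treated as the
// OpProtoAndCheckerMaker (its real work is omitted in this sketch).
template <typename T,
          bool kIsGradMaker = std::is_base_of<GradOpDescMakerBase, T>::value>
struct OpInfoFiller {
  static void Fill(OpInfo* info) { info->has_proto_maker_ = true; }
};

// A GradOpDescMakerBase subclass is wrapped into grad_op_maker_.
template <typename T>
struct OpInfoFiller<T, true> {
  static void Fill(OpInfo* info) {
    info->grad_op_maker_ = [](const OpDescBind& fwd_op) {
      T maker(fwd_op);
      return maker();
    };
  }
};

template <typename OpClass, typename... ExtraArgs>
struct OperatorRegistrar {
  explicit OperatorRegistrar(const std::string& op_type) {
    OpInfo& info = OpInfoMap()[op_type];  // creator for OpClass omitted
    // Apply Fill to every extra argument, in order.
    int expand[] = {0, (OpInfoFiller<ExtraArgs>::Fill(&info), 0)...};
    (void)expand;
  }
};

// The operator class is the first variadic argument, so __VA_ARGS__ is
// never empty, even when no makers are listed.
#define REGISTER_OPERATOR(op_type, ...) \
  static OperatorRegistrar<__VA_ARGS__> op_registrar_##op_type(#op_type)

// Placeholder classes mirroring the minus example.
struct MinusOp {};
struct MinusGradOp {};
struct MinusOpProtoAndCheckerMaker {};
class MinusOpGradMaker : public GradOpDescMakerBase {
 public:
  using GradOpDescMakerBase::GradOpDescMakerBase;
  std::vector<std::unique_ptr<OpDescBind>> operator()() const override {
    std::vector<std::unique_ptr<OpDescBind>> ops;
    ops.emplace_back(new OpDescBind{"identity"});
    ops.emplace_back(new OpDescBind{"scale"});
    return ops;
  }
};

REGISTER_OPERATOR(minus, MinusOp, MinusOpProtoAndCheckerMaker, MinusOpGradMaker);
REGISTER_OPERATOR(minus_grad, MinusGradOp);
```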
<p>The user interface should look like this:</p>
<div class="highlight-cpp"><div class="highlight"><pre><span></span><span class="k">class</span> <span class="nc">MinusOpGradMaker</span> <span class="o">:</span> <span class="k">public</span> <span class="n">GradOpDescMakerBase</span> <span class="p">{...};</span>
<span class="n">REGISTER_OPERATOR</span><span class="p">(</span><span class="n">minus</span><span class="p">,</span> <span class="n">MinusOp</span><span class="p">,</span> <span class="n">MinusOpProtoAndCheckerMaker</span><span class="p">,</span> <span class="n">MinusOpGradMaker</span><span class="p">);</span>
<span class="c1">// Developers can still manually implement gradient operator.</span>
<span class="n">REGISTER_OPERATOR</span><span class="p">(</span><span class="n">minus_grad</span><span class="p">,</span> <span class="n">MinusGradOp</span><span class="p">);</span>
</pre></div>
</div>
<p>The interface of the current <code class="docutils literal"><span class="pre">REGISTER_OP</span></code> macro will remain unchanged. Internally, <code class="docutils literal"><span class="pre">REGISTER_OP</span></code> invokes <code class="docutils literal"><span class="pre">REGISTER_OPERATOR</span></code> twice, once for the forward operator and once for the gradient operator, and generates a <code class="docutils literal"><span class="pre">GradOpDescMaker</span></code> for the forward operator.</p>
<div class="highlight-cpp"><div class="highlight"><pre><span></span><span class="n">REGISTER_OP</span><span class="p">(</span><span class="n">minus</span><span class="p">,</span> <span class="n">MinusOp</span><span class="p">,</span> <span class="n">MinusOpProtoAndCheckerMaker</span><span class="p">,</span> <span class="n">minus_grad</span><span class="p">,</span> <span class="n">MinusGradOp</span><span class="p">);</span>
</pre></div>
</div>
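<p>The following sketch shows how <code class="docutils literal"><span class="pre">REGISTER_OP</span></code> could keep its five-argument interface while expanding into two <code class="docutils literal"><span class="pre">REGISTER_OPERATOR</span></code> calls plus a generated maker. To stay self-contained it re-declares simplified, hypothetical versions of <code class="docutils literal"><span class="pre">OpDescBind</span></code>, <code class="docutils literal"><span class="pre">GradOpDescMakerBase</span></code>, <code class="docutils literal"><span class="pre">OpInfo</span></code> and a registrar. The generated maker (named <code class="docutils literal"><span class="pre">&lt;op_type&gt;_GradOpDescMaker</span></code> here by assumption) only emits a single operator of type <code class="docutils literal"><span class="pre">grad_op_type</span></code>; it does not wire up inputs, outputs or gradients as a real implementation would.</p>

```cpp
#include <functional>
#include <map>
#include <memory>
#include <string>
#include <vector>

// Hypothetical, simplified stand-ins for the framework types.
struct OpDescBind {
  std::string type;
};

class GradOpDescMakerBase {
 public:
  explicit GradOpDescMakerBase(const OpDescBind& fwd_op) : fwd_op_(fwd_op) {}
  virtual ~GradOpDescMakerBase() = default;
  virtual std::vector<std::unique_ptr<OpDescBind>> operator()() const = 0;

 protected:
  const OpDescBind& fwd_op_;
};

struct OpInfo {
  std::function<std::vector<std::unique_ptr<OpDescBind>>(const OpDescBind&)>
      grad_op_maker_;
};

std::map<std::string, OpInfo>& OpInfoMap() {
  static std::map<std::string, OpInfo> op_info_map;
  return op_info_map;
}

// Wraps GradMaker into OpInfo; the void specialization means "no gradient".
template <typename GradMaker>
void SetGradMaker(OpInfo* info) {
  info->grad_op_maker_ = [](const OpDescBind& fwd_op) {
    GradMaker maker(fwd_op);
    return maker();
  };
}
template <>
void SetGradMaker<void>(OpInfo*) {}

// Simplified registrar: operator creator and proto maker handling omitted.
template <typename OpClass, typename ProtoMaker = void, typename GradMaker = void>
struct OperatorRegistrar {
  explicit OperatorRegistrar(const std::string& op_type) {
    SetGradMaker<GradMaker>(&OpInfoMap()[op_type]);
  }
};

#define REGISTER_OPERATOR(op_type, ...) \
  static OperatorRegistrar<__VA_ARGS__> op_registrar_##op_type(#op_type)

// Same five arguments as before: generate a maker that emits grad_op_type,
// then register the forward operator and the gradient operator.
#define REGISTER_OP(op_type, op_class, op_maker_class, grad_op_type,       \
                    grad_op_class)                                         \
  class op_type##_GradOpDescMaker : public GradOpDescMakerBase {           \
   public:                                                                 \
    using GradOpDescMakerBase::GradOpDescMakerBase;                        \
    std::vector<std::unique_ptr<OpDescBind>> operator()() const override { \
      std::vector<std::unique_ptr<OpDescBind>> ops;                        \
      ops.emplace_back(new OpDescBind{#grad_op_type});                     \
      return ops;                                                          \
    }                                                                      \
  };                                                                       \
  REGISTER_OPERATOR(op_type, op_class, op_maker_class,                     \
                    op_type##_GradOpDescMaker);                            \
  REGISTER_OPERATOR(grad_op_type, grad_op_class)

// Placeholder classes mirroring the minus example.
struct MinusOp {};
struct MinusGradOp {};
struct MinusOpProtoAndCheckerMaker {};

REGISTER_OP(minus, MinusOp, MinusOpProtoAndCheckerMaker, minus_grad, MinusGradOp);
```

<p>Existing call sites therefore keep compiling, while every operator ends up with a gradient maker in its <code class="docutils literal"><span class="pre">OpInfo</span></code>.</p>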
</div>
</div>


           </div>
          </div>
          <footer>
  

  <hr/>

  <div role="contentinfo">
    <p>
        &copy; Copyright 2016, PaddlePaddle developers.

    </p>
  </div>
  Built with <a href="http://sphinx-doc.org/">Sphinx</a> using a <a href="https://github.com/snide/sphinx_rtd_theme">theme</a> provided by <a href="https://readthedocs.org">Read the Docs</a>. 

</footer>

        </div>
      </div>

    </section>

  </div>
  


  

    <script type="text/javascript">
        var DOCUMENTATION_OPTIONS = {
            URL_ROOT:'../',
            VERSION:'',
            COLLAPSE_INDEX:false,
            FILE_SUFFIX:'.html',
            HAS_SOURCE:  true
        };
    </script>
      <script type="text/javascript" src="../_static/jquery.js"></script>
      <script type="text/javascript" src="../_static/underscore.js"></script>
      <script type="text/javascript" src="../_static/doctools.js"></script>
      <script type="text/javascript" src="../_static/translations.js"></script>
      <script type="text/javascript" src="https://cdn.bootcss.com/mathjax/2.7.0/MathJax.js"></script>

  

  
  
    <script type="text/javascript" src="../_static/js/theme.js"></script>
  
  
  <script type="text/javascript">
      jQuery(function () {
          SphinxRtdTheme.StickyNav.enable();
      });
  </script>
   

</body>
</html>