<!DOCTYPE html>
<!--[if IE 8]><html class="no-js lt-ie9" lang="en" > <![endif]-->
<!--[if gt IE 8]><!--> <html class="no-js" lang="en" > <!--<![endif]-->
<head>
  <meta charset="utf-8">
  
  <meta name="viewport" content="width=device-width, initial-scale=1.0">
  
  <title>Design Doc: Supporting new Device/Library &mdash; PaddlePaddle documentation</title>
  

  
  

  

  
  
    

  

  
  
    <link rel="stylesheet" href="../_static/css/theme.css" type="text/css" />
  

  
  
        <link rel="index" title="Index"
              href="../genindex.html"/>
        <link rel="search" title="Search" href="../search.html"/>
    <link rel="top" title="PaddlePaddle  documentation" href="../index.html"/> 

  <link rel="stylesheet" href="https://cdn.jsdelivr.net/perfect-scrollbar/0.6.14/css/perfect-scrollbar.min.css" type="text/css" />
  <link rel="stylesheet" href="../_static/css/override.css" type="text/css" />
  <script>
  // Baidu Analytics (hm.js) loader: creates a <script> element for the
  // tracking script and inserts it before the first script on the page.
  var _hmt = _hmt || [];
  (function() {
    var hm = document.createElement("script");
    // Explicit https: URL instead of the protocol-relative "//" form, which
    // resolves to http: on plain-http pages and to file: when opened locally.
    hm.src = "https://hm.baidu.com/hm.js?b9a314ab40d04d805655aab1deee08ba";
    var s = document.getElementsByTagName("script")[0];
    s.parentNode.insertBefore(hm, s);
  })();
  </script>

  

  
  <script src="../_static/js/modernizr.min.js"></script>

</head>

<body class="wy-body-for-nav" role="document">

  
  <header class="site-header">
    <div class="site-logo">
      <a href="/"><img src="../_static/images/PP_w.png" alt="PaddlePaddle home"></a>
    </div>
    <div class="site-nav-links">
      <div class="site-menu">
        <a class="fork-on-github" href="https://github.com/PaddlePaddle/Paddle" target="_blank" rel="noopener noreferrer"><i class="fa fa-github" aria-hidden="true"></i>Fork me on GitHub</a>
        <div class="language-switcher dropdown">
          <a type="button" data-toggle="dropdown">
            <span>English</span>
            <i class="fa fa-angle-up"></i>
            <i class="fa fa-angle-down"></i>
          </a>
          <ul class="dropdown-menu">
            <li><a href="/doc_cn">中文</a></li>
            <li><a href="/doc">English</a></li>
          </ul>
        </div>
        <ul class="site-page-links">
          <li><a href="/">Home</a></li>
        </ul>
      </div>
      <div class="doc-module">
        
        <ul>
<li class="toctree-l1"><a class="reference internal" href="../getstarted/index_en.html">GET STARTED</a></li>
<li class="toctree-l1"><a class="reference internal" href="../build_and_install/index_en.html">Install and Build</a></li>
<li class="toctree-l1"><a class="reference internal" href="../howto/index_en.html">HOW TO</a></li>
<li class="toctree-l1"><a class="reference internal" href="../dev/index_en.html">Development</a></li>
</ul>

        
<div role="search">
  <form id="rtd-search-form" class="wy-form" action="../search.html" method="get">
    <input type="text" name="q" placeholder="Search docs" aria-label="Search docs" />
    <input type="hidden" name="check_keywords" value="yes" />
    <input type="hidden" name="area" value="default" />
  </form>
</div>        
      </div>
    </div>
  </header>
  
  <div class="main-content-wrap">

    
    <nav class="doc-menu-vertical" role="navigation">
        
          
          <ul>
<li class="toctree-l1"><a class="reference internal" href="../getstarted/index_en.html">GET STARTED</a><ul>
<li class="toctree-l2"><a class="reference internal" href="../getstarted/quickstart_en.html">Quick Start</a></li>
</ul>
</li>
<li class="toctree-l1"><a class="reference internal" href="../build_and_install/index_en.html">Install and Build</a><ul>
<li class="toctree-l2"><a class="reference internal" href="../build_and_install/pip_install_en.html">Install Using pip</a></li>
<li class="toctree-l2"><a class="reference internal" href="../build_and_install/docker_install_en.html">Run in Docker Containers</a></li>
<li class="toctree-l2"><a class="reference internal" href="../build_and_install/build_from_source_en.html">Build from Sources</a></li>
</ul>
</li>
<li class="toctree-l1"><a class="reference internal" href="../howto/index_en.html">HOW TO</a><ul>
<li class="toctree-l2"><a class="reference internal" href="../howto/cmd_parameter/index_en.html">Set Command-line Parameters</a><ul>
<li class="toctree-l3"><a class="reference internal" href="../howto/cmd_parameter/use_case_en.html">Use Case</a></li>
<li class="toctree-l3"><a class="reference internal" href="../howto/cmd_parameter/arguments_en.html">Argument Outline</a></li>
<li class="toctree-l3"><a class="reference internal" href="../howto/cmd_parameter/detail_introduction_en.html">Detail Description</a></li>
</ul>
</li>
<li class="toctree-l2"><a class="reference internal" href="../howto/cluster/index_en.html">Distributed Training</a><ul>
<li class="toctree-l3"><a class="reference internal" href="../howto/cluster/preparations_en.html">Preparations</a></li>
<li class="toctree-l3"><a class="reference internal" href="../howto/cluster/cmd_argument_en.html">Command-line arguments</a></li>
<li class="toctree-l3"><a class="reference internal" href="../howto/cluster/multi_cluster/index_en.html">Use different clusters</a><ul>
<li class="toctree-l4"><a class="reference internal" href="../howto/cluster/multi_cluster/fabric_en.html">Cluster Training Using Fabric</a></li>
<li class="toctree-l4"><a class="reference internal" href="../howto/cluster/multi_cluster/openmpi_en.html">Cluster Training Using OpenMPI</a></li>
<li class="toctree-l4"><a class="reference internal" href="../howto/cluster/multi_cluster/k8s_en.html">PaddlePaddle On Kubernetes</a></li>
<li class="toctree-l4"><a class="reference internal" href="../howto/cluster/multi_cluster/k8s_aws_en.html">Distributed PaddlePaddle Training on AWS with Kubernetes</a></li>
</ul>
</li>
</ul>
</li>
<li class="toctree-l2"><a class="reference internal" href="../howto/rnn/index_en.html">RNN Models</a><ul>
<li class="toctree-l3"><a class="reference internal" href="../howto/rnn/rnn_config_en.html">RNN Configuration</a></li>
</ul>
</li>
<li class="toctree-l2"><a class="reference internal" href="../howto/optimization/gpu_profiling_en.html">Tune GPU Performance</a></li>
</ul>
</li>
<li class="toctree-l1"><a class="reference internal" href="../dev/index_en.html">Development</a><ul>
<li class="toctree-l2"><a class="reference internal" href="../dev/new_layer_en.html">Write New Layers</a></li>
<li class="toctree-l2"><a class="reference internal" href="../dev/contribute_to_paddle_en.html">Contribute Code</a></li>
<li class="toctree-l2"><a class="reference internal" href="../dev/write_docs_en.html">Contribute Documentation</a></li>
</ul>
</li>
</ul>

        
    </nav>
    
    <section class="doc-content-wrap">

      

 







<div role="navigation" aria-label="breadcrumbs navigation">
  <ul class="wy-breadcrumbs">
      
    <li>Design Doc: Supporting new Device/Library</li>
  </ul>
</div>
      
      <div class="wy-nav-content" id="doc-content">
        <div class="rst-content">
          <div role="main" class="document" itemscope="itemscope" itemtype="http://schema.org/Article">
           <div itemprop="articleBody">
            
  <div class="section" id="design-doc-supporting-new-device-library">
<h1>Design Doc: Supporting new Device/Library<a class="headerlink" href="#design-doc-supporting-new-device-library" title="Permalink to this headline"></a></h1>
<div class="section" id="background">
<h2>Background<a class="headerlink" href="#background" title="Permalink to this headline"></a></h2>
<p>Deep learning has a high demand for computing resources. New high-performance devices and computing libraries are appearing very frequently. Deep learning frameworks have to integrate these high-performance devices and computing libraries in a flexible and efficient manner.</p>
<p>On one hand, hardware and computing libraries usually do not have a one-to-one correspondence. For example, Intel CPUs support the Eigen and MKL computing libraries, while NVIDIA GPUs support the Eigen and cuDNN computing libraries. We have to implement operator-specific kernels for each computing library.</p>
<p>On the other hand, users usually do not want to deal with the low-level hardware and computing libraries when writing a neural network configuration. In Fluid, <code class="docutils literal"><span class="pre">Layer</span></code> is exposed in <code class="docutils literal"><span class="pre">Python</span></code>, and <code class="docutils literal"><span class="pre">Operator</span></code> is exposed in <code class="docutils literal"><span class="pre">C++</span></code>. Both <code class="docutils literal"><span class="pre">Layer</span></code> and <code class="docutils literal"><span class="pre">Operator</span></code> are hardware independent.</p>
<p>So, supporting a new Device/Library in Fluid becomes a challenge.</p>
</div>
<div class="section" id="basic-integrate-a-new-device-library">
<h2>Basic: Integrate a New Device/Library<a class="headerlink" href="#basic-integrate-a-new-device-library" title="Permalink to this headline"></a></h2>
<p>For a general overview of Fluid, please refer to the <a class="reference external" href="https://github.com/PaddlePaddle/Paddle/blob/develop/doc/howto/read_source.md">overview doc</a>.</p>
<p>There are mainly three parts that we have to consider while integrating a new device/library:</p>
<ul class="simple">
<li>Place and DeviceContext: indicate the device id and manage hardware resources</li>
<li>Memory and Tensor: allocate and free data on a certain device</li>
<li>Math Functor and OpKernel: implement computing units on certain devices/libraries</li>
</ul>
<div class="section" id="place-and-devicecontext">
<h3>Place and DeviceContext<a class="headerlink" href="#place-and-devicecontext" title="Permalink to this headline"></a></h3>
<p>Please note that devices and computing libraries do not have a one-to-one correspondence. A device can support many computing libraries, and a computing library can also support several devices.</p>
<div class="section" id="place">
<h4>Place<a class="headerlink" href="#place" title="Permalink to this headline"></a></h4>
<p>Fluid uses the class <a class="reference external" href="https://github.com/PaddlePaddle/Paddle/blob/develop/paddle/platform/place.h#L55">Place</a> to represent the device memory where data is located. If we add another device, we have to add the corresponding <code class="docutils literal"><span class="pre">DevicePlace</span></code>.</p>
<div class="highlight-default"><div class="highlight"><pre><span></span>        <span class="o">|</span>   <span class="n">CPUPlace</span>
<span class="n">Place</span> <span class="o">--|</span>   <span class="n">CUDAPlace</span>
        <span class="o">|</span>   <span class="n">FPGAPlace</span>
</pre></div>
</div>
<p>And <code class="docutils literal"><span class="pre">Place</span></code> is defined as follows:</p>
<div class="highlight-default"><div class="highlight"><pre><span></span><span class="n">typedef</span> <span class="n">boost</span><span class="p">::</span><span class="n">variant</span><span class="o">&lt;</span><span class="n">CUDAPlace</span><span class="p">,</span> <span class="n">CPUPlace</span><span class="p">,</span> <span class="n">FPGAPlace</span><span class="o">&gt;</span> <span class="n">Place</span><span class="p">;</span>
</pre></div>
</div>
</div>
<div class="section" id="devicecontext">
<h4>DeviceContext<a class="headerlink" href="#devicecontext" title="Permalink to this headline"></a></h4>
<p>Fluid uses the class <a class="reference external" href="https://github.com/PaddlePaddle/Paddle/blob/develop/paddle/platform/device_context.h#L30">DeviceContext</a> to manage the resources of different libraries, such as the CUDA stream in <code class="docutils literal"><span class="pre">CUDADeviceContext</span></code>. There are also inheritance relationships between different kinds of <code class="docutils literal"><span class="pre">DeviceContext</span></code>.</p>
<div class="highlight-default"><div class="highlight"><pre><span></span>                <span class="o">/-&gt;</span>  <span class="n">CPUDeviceContext</span>   
<span class="n">DeviceContext</span> <span class="o">----&gt;</span>  <span class="n">CUDADeviceContext</span>  
                \<span class="o">-&gt;</span>  <span class="n">FPGADeviceContext</span>
</pre></div>
</div>
<p>An example for NVIDIA GPUs is as follows:</p>
<ul class="simple">
<li>DeviceContext</li>
</ul>
<div class="highlight-default"><div class="highlight"><pre><span></span><span class="k">class</span> <span class="nc">DeviceContext</span> <span class="p">{</span>
  <span class="n">virtual</span> <span class="n">Place</span> <span class="n">GetPlace</span><span class="p">()</span> <span class="n">const</span> <span class="o">=</span> <span class="mi">0</span><span class="p">;</span>
<span class="p">};</span>  
</pre></div>
</div>
<ul class="simple">
<li>CUDADeviceContext</li>
</ul>
<div class="highlight-default"><div class="highlight"><pre><span></span><span class="k">class</span> <span class="nc">CUDADeviceContext</span> <span class="p">:</span> <span class="n">public</span> <span class="n">DeviceContext</span> <span class="p">{</span>
  <span class="n">Place</span> <span class="n">GetPlace</span><span class="p">()</span> <span class="n">const</span> <span class="n">override</span> <span class="p">{</span> <span class="k">return</span> <span class="n">place_</span><span class="p">;</span> <span class="p">}</span>
<span class="n">private</span><span class="p">:</span>
  <span class="n">CUDAPlace</span> <span class="n">place_</span><span class="p">;</span>
  <span class="n">cudaStream_t</span> <span class="n">stream_</span><span class="p">;</span> 
  <span class="n">cublasHandle_t</span> <span class="n">cublas_handle_</span><span class="p">;</span>
  <span class="n">std</span><span class="p">::</span><span class="n">unique_ptr</span><span class="o">&lt;</span><span class="n">Eigen</span><span class="p">::</span><span class="n">GpuDevice</span><span class="o">&gt;</span> <span class="n">eigen_device_</span><span class="p">;</span>  <span class="o">//</span> <span class="n">binds</span> <span class="k">with</span> <span class="n">stream_</span>
<span class="p">};</span>
</pre></div>
</div>
</div>
</div>
<div class="section" id="memory-and-tensor">
<h3>Memory and Tensor<a class="headerlink" href="#memory-and-tensor" title="Permalink to this headline"></a></h3>
<div class="section" id="memory-module">
<h4>Memory Module<a class="headerlink" href="#memory-module" title="Permalink to this headline"></a></h4>
<p>Fluid provides the following <a class="reference external" href="https://github.com/PaddlePaddle/Paddle/blob/develop/paddle/memory/memory.h#L36">memory interfaces</a>:</p>
<div class="highlight-default"><div class="highlight"><pre><span></span><span class="n">template</span> <span class="o">&lt;</span><span class="n">typename</span> <span class="n">Place</span><span class="o">&gt;</span>
<span class="n">void</span><span class="o">*</span> <span class="n">Alloc</span><span class="p">(</span><span class="n">Place</span> <span class="n">place</span><span class="p">,</span> <span class="n">size_t</span> <span class="n">size</span><span class="p">);</span>

<span class="n">template</span> <span class="o">&lt;</span><span class="n">typename</span> <span class="n">Place</span><span class="o">&gt;</span>
<span class="n">void</span> <span class="n">Free</span><span class="p">(</span><span class="n">Place</span> <span class="n">place</span><span class="p">,</span> <span class="n">void</span><span class="o">*</span> <span class="n">ptr</span><span class="p">);</span>

<span class="n">template</span> <span class="o">&lt;</span><span class="n">typename</span> <span class="n">Place</span><span class="o">&gt;</span>
<span class="n">size_t</span> <span class="n">Used</span><span class="p">(</span><span class="n">Place</span> <span class="n">place</span><span class="p">);</span>
</pre></div>
</div>
<p>To implement these interfaces, we have to implement a MemoryAllocator for each device.</p>
</div>
<div class="section" id="tensor">
<h4>Tensor<a class="headerlink" href="#tensor" title="Permalink to this headline"></a></h4>
<p><a class="reference external" href="https://github.com/PaddlePaddle/Paddle/blob/develop/paddle/framework/tensor.h#L36">Tensor</a> holds data of a given shape in a specific Place.</p>
<div class="highlight-cpp"><div class="highlight"><pre><span></span><span class="k">class</span> <span class="nc">Tensor</span> <span class="p">{</span>
 <span class="k">public</span><span class="o">:</span>
  <span class="cm">/*! Return a pointer to mutable memory block. */</span>
  <span class="k">template</span> <span class="o">&lt;</span><span class="k">typename</span> <span class="n">T</span><span class="o">&gt;</span>
  <span class="kr">inline</span> <span class="n">T</span><span class="o">*</span> <span class="n">data</span><span class="p">();</span>

  <span class="cm">/**</span>
<span class="cm">   * @brief   Return a pointer to mutable memory block.</span>
<span class="cm">   * @note    If not exist, then allocation.</span>
<span class="cm">   */</span>
  <span class="k">template</span> <span class="o">&lt;</span><span class="k">typename</span> <span class="n">T</span><span class="o">&gt;</span>
  <span class="kr">inline</span> <span class="n">T</span><span class="o">*</span> <span class="n">mutable_data</span><span class="p">(</span><span class="n">platform</span><span class="o">::</span><span class="n">Place</span> <span class="n">place</span><span class="p">);</span>

  <span class="cm">/**</span>
<span class="cm">   * @brief     Return a pointer to mutable memory block.</span>
<span class="cm">   *</span>
<span class="cm">   * @param[in] dims    The dimensions of the memory block.</span>
<span class="cm">   * @param[in] place   The place of the memory block.</span>
<span class="cm">   *</span>
<span class="cm">   * @note      If not exist, then allocation.</span>
<span class="cm">   */</span>
  <span class="k">template</span> <span class="o">&lt;</span><span class="k">typename</span> <span class="n">T</span><span class="o">&gt;</span>
  <span class="kr">inline</span> <span class="n">T</span><span class="o">*</span> <span class="n">mutable_data</span><span class="p">(</span><span class="n">DDim</span> <span class="n">dims</span><span class="p">,</span> <span class="n">platform</span><span class="o">::</span><span class="n">Place</span> <span class="n">place</span><span class="p">);</span>

  <span class="cm">/*! Resize the dimensions of the memory block. */</span>
  <span class="kr">inline</span> <span class="n">Tensor</span><span class="o">&amp;</span> <span class="n">Resize</span><span class="p">(</span><span class="k">const</span> <span class="n">DDim</span><span class="o">&amp;</span> <span class="n">dims</span><span class="p">);</span>

  <span class="cm">/*! Return the dimensions of the memory block. */</span>
  <span class="kr">inline</span> <span class="k">const</span> <span class="n">DDim</span><span class="o">&amp;</span> <span class="n">dims</span><span class="p">()</span> <span class="k">const</span><span class="p">;</span>

 <span class="k">private</span><span class="o">:</span>
  <span class="cm">/*! holds the memory block if allocated. */</span>
  <span class="n">std</span><span class="o">::</span><span class="n">shared_ptr</span><span class="o">&lt;</span><span class="n">Placeholder</span><span class="o">&gt;</span> <span class="n">holder_</span><span class="p">;</span>

  <span class="cm">/*! points to dimensions of memory block. */</span>
  <span class="n">DDim</span> <span class="n">dim_</span><span class="p">;</span>
<span class="p">};</span>
</pre></div>
</div>
<p><code class="docutils literal"><span class="pre">Placeholder</span></code> is used to delay memory allocation; that is, we can first define a tensor, use <code class="docutils literal"><span class="pre">Resize</span></code> to configure its shape, and then call <code class="docutils literal"><span class="pre">mutable_data</span></code> to allocate the actual memory.</p>
<div class="highlight-cpp"><div class="highlight"><pre><span></span><span class="n">paddle</span><span class="o">::</span><span class="n">framework</span><span class="o">::</span><span class="n">Tensor</span> <span class="n">t</span><span class="p">;</span>
<span class="n">paddle</span><span class="o">::</span><span class="n">platform</span><span class="o">::</span><span class="n">CPUPlace</span> <span class="n">place</span><span class="p">;</span>
<span class="c1">// set size first</span>
<span class="n">t</span><span class="p">.</span><span class="n">Resize</span><span class="p">({</span><span class="mi">2</span><span class="p">,</span> <span class="mi">3</span><span class="p">});</span>
<span class="c1">// allocate memory on CPU later</span>
<span class="n">t</span><span class="p">.</span><span class="n">mutable_data</span><span class="p">(</span><span class="n">place</span><span class="p">);</span>
</pre></div>
</div>
</div>
</div>
<div class="section" id="math-functor-and-opkernel">
<h3>Math Functor and OpKernel<a class="headerlink" href="#math-functor-and-opkernel" title="Permalink to this headline"></a></h3>
<p>Fluid implements computing units based on different DeviceContexts. Some computing units are shared between operators. These common parts are put in the operators/math directory as basic functors.</p>
<p>Let&#8217;s take <a class="reference external" href="https://github.com/PaddlePaddle/Paddle/blob/develop/paddle/operators/math/maxouting.h#L27">MaxOutFunctor</a> as an example:</p>
<p>The interface is defined in the header file.</p>
<div class="highlight-default"><div class="highlight"><pre><span></span><span class="n">template</span> <span class="o">&lt;</span><span class="n">typename</span> <span class="n">DeviceContext</span><span class="p">,</span> <span class="n">typename</span> <span class="n">T</span><span class="o">&gt;</span>
<span class="k">class</span> <span class="nc">MaxOutFunctor</span> <span class="p">{</span>
 <span class="n">public</span><span class="p">:</span>
  <span class="n">void</span> <span class="n">operator</span><span class="p">()(</span><span class="n">const</span> <span class="n">DeviceContext</span><span class="o">&amp;</span> <span class="n">context</span><span class="p">,</span> <span class="n">const</span> <span class="n">framework</span><span class="p">::</span><span class="n">Tensor</span><span class="o">&amp;</span> <span class="nb">input</span><span class="p">,</span>
                  <span class="n">framework</span><span class="p">::</span><span class="n">Tensor</span><span class="o">*</span> <span class="n">output</span><span class="p">,</span> <span class="nb">int</span> <span class="n">groups</span><span class="p">);</span>
<span class="p">};</span>
</pre></div>
</div>
<p>The CPU implementation is in the .cc file:</p>
<div class="highlight-default"><div class="highlight"><pre><span></span><span class="n">template</span> <span class="o">&lt;</span><span class="n">typename</span> <span class="n">T</span><span class="o">&gt;</span>
<span class="k">class</span> <span class="nc">MaxOutFunctor</span><span class="o">&lt;</span><span class="n">platform</span><span class="p">::</span><span class="n">CPUDeviceContext</span><span class="p">,</span> <span class="n">T</span><span class="o">&gt;</span> <span class="p">{</span>
  <span class="n">public</span><span class="p">:</span>
  <span class="n">void</span> <span class="n">operator</span><span class="p">()(</span><span class="n">const</span> <span class="n">platform</span><span class="p">::</span><span class="n">CPUDeviceContext</span><span class="o">&amp;</span> <span class="n">context</span><span class="p">,</span>
                  <span class="n">const</span> <span class="n">framework</span><span class="p">::</span><span class="n">Tensor</span><span class="o">&amp;</span> <span class="nb">input</span><span class="p">,</span> <span class="n">framework</span><span class="p">::</span><span class="n">Tensor</span><span class="o">*</span> <span class="n">output</span><span class="p">,</span>
                  <span class="nb">int</span> <span class="n">groups</span><span class="p">)</span> <span class="p">{</span>
                  <span class="o">...</span>
                  <span class="p">}</span>
<span class="p">};</span>
</pre></div>
</div>
<p>The CUDA implementation is in the .cu file:</p>
<div class="highlight-default"><div class="highlight"><pre><span></span><span class="n">template</span> <span class="o">&lt;</span><span class="n">typename</span> <span class="n">T</span><span class="o">&gt;</span>
<span class="k">class</span> <span class="nc">MaxOutFunctor</span><span class="o">&lt;</span><span class="n">platform</span><span class="p">::</span><span class="n">CUDADeviceContext</span><span class="p">,</span> <span class="n">T</span><span class="o">&gt;</span> <span class="p">{</span>
 <span class="n">public</span><span class="p">:</span>
  <span class="n">void</span> <span class="n">operator</span><span class="p">()(</span><span class="n">const</span> <span class="n">platform</span><span class="p">::</span><span class="n">CUDADeviceContext</span><span class="o">&amp;</span> <span class="n">context</span><span class="p">,</span>
                  <span class="n">const</span> <span class="n">framework</span><span class="p">::</span><span class="n">Tensor</span><span class="o">&amp;</span> <span class="nb">input</span><span class="p">,</span> <span class="n">framework</span><span class="p">::</span><span class="n">Tensor</span><span class="o">*</span> <span class="n">output</span><span class="p">,</span>
                  <span class="nb">int</span> <span class="n">groups</span><span class="p">)</span> <span class="p">{</span>
                  <span class="o">...</span>
                  <span class="p">}</span>
<span class="p">};</span>                  
</pre></div>
</div>
<p>We first obtain the computing handle from a concrete DeviceContext and then compute on tensors.</p>
<p>The implementation of <code class="docutils literal"><span class="pre">OpKernel</span></code> is similar to that of math functors; the extra step is to register the OpKernel in a global map.</p>
<p>Fluid provides different registration interfaces in op_registry.h.</p>
<p>Let&#8217;s take <a class="reference external" href="https://github.com/PaddlePaddle/Paddle/blob/develop/paddle/operators/crop_op.cc#L134">Crop</a> operator as an example:</p>
<p>In .cc file:</p>
<div class="highlight-default"><div class="highlight"><pre><span></span><span class="n">REGISTER_OP_CPU_KERNEL</span><span class="p">(</span><span class="n">crop</span><span class="p">,</span> <span class="n">ops</span><span class="p">::</span><span class="n">CropKernel</span><span class="o">&lt;</span><span class="nb">float</span><span class="o">&gt;</span><span class="p">);</span>
<span class="n">REGISTER_OP_CPU_KERNEL</span><span class="p">(</span>
    <span class="n">crop_grad</span><span class="p">,</span> <span class="n">ops</span><span class="p">::</span><span class="n">CropGradKernel</span><span class="o">&lt;</span><span class="n">paddle</span><span class="p">::</span><span class="n">platform</span><span class="p">::</span><span class="n">CPUDeviceContext</span><span class="p">,</span> <span class="nb">float</span><span class="o">&gt;</span><span class="p">);</span>
</pre></div>
</div>
<p>In .cu file:</p>
<div class="highlight-default"><div class="highlight"><pre><span></span><span class="n">REGISTER_OP_CUDA_KERNEL</span><span class="p">(</span><span class="n">crop</span><span class="p">,</span> <span class="n">ops</span><span class="p">::</span><span class="n">CropKernel</span><span class="o">&lt;</span><span class="nb">float</span><span class="o">&gt;</span><span class="p">);</span>
<span class="n">REGISTER_OP_CUDA_KERNEL</span><span class="p">(</span>
    <span class="n">crop_grad</span><span class="p">,</span> <span class="n">ops</span><span class="p">::</span><span class="n">CropGradKernel</span><span class="o">&lt;</span><span class="n">paddle</span><span class="p">::</span><span class="n">platform</span><span class="p">::</span><span class="n">CUDADeviceContext</span><span class="p">,</span> <span class="nb">float</span><span class="o">&gt;</span><span class="p">);</span>
</pre></div>
</div>
</div>
</div>
<div class="section" id="advanced-topics-how-to-switch-between-different-device-library">
<h2>Advanced topics: How to switch between different Device/Library<a class="headerlink" href="#advanced-topics-how-to-switch-between-different-device-library" title="Permalink to this headline"></a></h2>
<p>Generally, we implement an OpKernel for every Device/Library of an Operator. This lets us, for example, easily train a Convolutional Neural Network on a GPU. However, some OpKernels are not suitable for a specific device. For example, the crf operator can only run on the CPU, whereas most other operators can run on the GPU. To achieve high performance in such circumstances, we have to switch between different Devices/Libraries.</p>
<p>For more details, please refer to the following docs:</p>
<ul class="simple">
<li><a class="reference external" href="https://github.com/PaddlePaddle/Paddle/blob/develop/doc/design/operator_kernel_type.md">Operator kernel type design doc</a></li>
<li><a class="reference external" href="https://github.com/PaddlePaddle/Paddle/blob/develop/doc/design/switch_kernel.md">Switch kernel design doc</a></li>
</ul>
</div>
</div>


           </div>
          </div>
          <footer>
  

  <hr/>

  <div role="contentinfo">
    <p>
        &copy; Copyright 2016, PaddlePaddle developers.

    </p>
  </div>
  Built with <a href="http://sphinx-doc.org/">Sphinx</a> using a <a href="https://github.com/snide/sphinx_rtd_theme">theme</a> provided by <a href="https://readthedocs.org">Read the Docs</a>. 

</footer>

        </div>
      </div>

    </section>

  </div>
  


  

    <script type="text/javascript">
        // Page-level Sphinx options, assembled one property at a time.
        // NOTE(review): presumably consumed by ../_static/doctools.js, which this
        // page loads after this script; confirm before renaming any key.
        var DOCUMENTATION_OPTIONS = {};
        // Relative prefix from this page back to the docs root (cf. ../index.html).
        DOCUMENTATION_OPTIONS.URL_ROOT = '../';
        DOCUMENTATION_OPTIONS.VERSION = '';
        DOCUMENTATION_OPTIONS.COLLAPSE_INDEX = false;
        DOCUMENTATION_OPTIONS.FILE_SUFFIX = '.html';
        DOCUMENTATION_OPTIONS.HAS_SOURCE = true;
        DOCUMENTATION_OPTIONS.SOURCELINK_SUFFIX = ".txt";
    </script>
      <script type="text/javascript" src="../_static/jquery.js"></script>
      <script type="text/javascript" src="../_static/underscore.js"></script>
      <script type="text/javascript" src="../_static/doctools.js"></script>
      <script type="text/javascript" src="https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.0/MathJax.js?config=TeX-AMS-MML_HTMLorMML"></script>
       
  

  
  
    <script type="text/javascript" src="../_static/js/theme.js"></script>
  
  
  <script src="https://maxcdn.bootstrapcdn.com/bootstrap/3.3.7/js/bootstrap.min.js" integrity="sha384-Tc5IQib027qvyjSMfHjOMaLkfuWVxZxUPnCJA7l2mCWNIpG9mGCD8wGNIcPD7Txa" crossorigin="anonymous"></script>
  <script src="https://cdn.jsdelivr.net/perfect-scrollbar/0.6.14/js/perfect-scrollbar.jquery.min.js"></script>
  <script src="../_static/js/paddle_doc_init.js"></script> 

</body>
</html>