paddle_nccl.html 19.3 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126


<!DOCTYPE html>
<!--[if IE 8]><html class="no-js lt-ie9" lang="en" > <![endif]-->
<!--[if gt IE 8]><!--> <html class="no-js" lang="en" > <!--<![endif]-->
<head>
  <meta charset="utf-8">
  
  <meta name="viewport" content="width=device-width, initial-scale=1.0">
  
  <title>Design Doc: NCCL support in Paddle Fluid &mdash; PaddlePaddle  文档</title>
  

  
  

  

  
  
    

  

  
  
    <link rel="stylesheet" href="../_static/css/theme.css" type="text/css" />
  

  
  
        <link rel="index" title="索引"
              href="../genindex.html"/>
        <link rel="search" title="搜索" href="../search.html"/>
    <link rel="top" title="PaddlePaddle  文档" href="../index.html"/> 

  <link rel="stylesheet" href="https://cdn.jsdelivr.net/perfect-scrollbar/0.6.14/css/perfect-scrollbar.min.css" type="text/css" />
  <link rel="stylesheet" href="../_static/css/override.css" type="text/css" />
  <script>
  // Baidu analytics loader.
  // Reuse an existing global _hmt if one is already defined; otherwise start
  // with an empty array.
  // NOTE(review): presumably _hmt is the command queue consumed by hm.js — confirm.
  var _hmt = _hmt || [];
  (function() {
    // Build the <script> element that fetches the tracker.
    var hm = document.createElement("script");
    // Use an explicit https:// URL instead of a protocol-relative one: a
    // protocol-relative URL resolves to file://hm.baidu.com/... when the docs
    // are opened from disk, and to plain http:// on non-TLS hosts.
    hm.src = "https://hm.baidu.com/hm.js?b9a314ab40d04d805655aab1deee08ba";
    // Insert the tracker before the first <script> element in the document.
    var s = document.getElementsByTagName("script")[0];
    s.parentNode.insertBefore(hm, s);
  })();
  </script>

  

  
  <script src="../_static/js/modernizr.min.js"></script>

</head>

<body class="wy-body-for-nav" role="document">

  
  <header class="site-header">
    <div class="site-logo">
      <a href="/"><img src="../_static/images/PP_w.png" alt="PaddlePaddle home"></a>
    </div>
    <div class="site-nav-links">
      <div class="site-menu">
        <a class="fork-on-github" href="https://github.com/PaddlePaddle/Paddle" target="_blank"><i class="fa fa-github"></i>Fork me on Github</a>
        <div class="language-switcher dropdown">
          <a type="button" data-toggle="dropdown">
            <span>English</span>
            <i class="fa fa-angle-up"></i>
            <i class="fa fa-angle-down"></i>
          </a>
          <ul class="dropdown-menu">
            <li><a href="/doc_cn">中文</a></li>
            <li><a href="/doc">English</a></li>
          </ul>
        </div>
        <ul class="site-page-links">
          <li><a href="/">Home</a></li>
        </ul>
      </div>
      <div class="doc-module">
        
        <ul>
<li class="toctree-l1"><a class="reference internal" href="../getstarted/index_cn.html">新手入门</a></li>
<li class="toctree-l1"><a class="reference internal" href="../howto/index_cn.html">进阶指南</a></li>
<li class="toctree-l1"><a class="reference internal" href="../api/index_cn.html">API</a></li>
<li class="toctree-l1"><a class="reference internal" href="../faq/index_cn.html">FAQ</a></li>
</ul>

        
<div role="search">
  <form id="rtd-search-form" class="wy-form" action="../search.html" method="get">
    <input type="text" name="q" placeholder="Search docs" />
    <input type="hidden" name="check_keywords" value="yes" />
    <input type="hidden" name="area" value="default" />
  </form>
</div>        
      </div>
    </div>
  </header>
  
  <div class="main-content-wrap">

    
    <nav class="doc-menu-vertical" role="navigation">
        
          
          <ul>
<li class="toctree-l1"><a class="reference internal" href="../getstarted/index_cn.html">新手入门</a><ul>
<li class="toctree-l2"><a class="reference internal" href="../getstarted/build_and_install/index_cn.html">安装与编译</a><ul>
<li class="toctree-l3"><a class="reference internal" href="../getstarted/build_and_install/pip_install_cn.html">使用pip安装</a></li>
<li class="toctree-l3"><a class="reference internal" href="../getstarted/build_and_install/docker_install_cn.html">使用Docker安装运行</a></li>
<li class="toctree-l3"><a class="reference internal" href="../howto/dev/build_cn.html">用Docker编译和测试PaddlePaddle</a></li>
<li class="toctree-l3"><a class="reference internal" href="../getstarted/build_and_install/build_from_source_cn.html">从源码编译</a></li>
</ul>
</li>
<li class="toctree-l2"><a class="reference internal" href="../getstarted/concepts/use_concepts_cn.html">基本使用概念</a></li>
</ul>
</li>
<li class="toctree-l1"><a class="reference internal" href="../howto/index_cn.html">进阶指南</a><ul>
<li class="toctree-l2"><a class="reference internal" href="../howto/usage/cmd_parameter/index_cn.html">设置命令行参数</a><ul>
<li class="toctree-l3"><a class="reference internal" href="../howto/usage/cmd_parameter/use_case_cn.html">使用案例</a></li>
<li class="toctree-l3"><a class="reference internal" href="../howto/usage/cmd_parameter/arguments_cn.html">参数概述</a></li>
<li class="toctree-l3"><a class="reference internal" href="../howto/usage/cmd_parameter/detail_introduction_cn.html">细节描述</a></li>
</ul>
</li>
127 128 129 130 131 132
<li class="toctree-l2"><a class="reference internal" href="../howto/usage/cluster/cluster_train_cn.html">分布式训练</a><ul>
<li class="toctree-l3"><a class="reference internal" href="../howto/usage/cluster/fabric_cn.html">fabric集群</a></li>
<li class="toctree-l3"><a class="reference internal" href="../howto/usage/cluster/openmpi_cn.html">openmpi集群</a></li>
<li class="toctree-l3"><a class="reference internal" href="../howto/usage/cluster/k8s_cn.html">kubernetes单机</a></li>
<li class="toctree-l3"><a class="reference internal" href="../howto/usage/cluster/k8s_distributed_cn.html">kubernetes distributed分布式</a></li>
<li class="toctree-l3"><a class="reference internal" href="../howto/usage/cluster/k8s_aws_cn.html">AWS上运行kubernetes集群训练</a></li>
133 134
</ul>
</li>
135 136 137 138 139 140
<li class="toctree-l2"><a class="reference internal" href="../howto/usage/capi/index_cn.html">PaddlePaddle C-API</a><ul>
<li class="toctree-l3"><a class="reference internal" href="../howto/usage/capi/compile_paddle_lib_cn.html">编译 PaddlePaddle 预测库</a></li>
<li class="toctree-l3"><a class="reference internal" href="../howto/usage/capi/organization_of_the_inputs_cn.html">输入/输出数据组织</a></li>
<li class="toctree-l3"><a class="reference internal" href="../howto/usage/capi/workflow_of_capi_cn.html">C-API 使用流程</a></li>
</ul>
</li>
141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170
<li class="toctree-l2"><a class="reference internal" href="../howto/dev/contribute_to_paddle_cn.html">如何贡献代码</a></li>
<li class="toctree-l2"><a class="reference internal" href="../howto/dev/write_docs_cn.html">如何贡献/修改文档</a></li>
<li class="toctree-l2"><a class="reference internal" href="../howto/deep_model/rnn/index_cn.html">RNN相关模型</a><ul>
<li class="toctree-l3"><a class="reference internal" href="../howto/deep_model/rnn/rnn_config_cn.html">RNN配置</a></li>
<li class="toctree-l3"><a class="reference internal" href="../howto/deep_model/rnn/recurrent_group_cn.html">Recurrent Group教程</a></li>
<li class="toctree-l3"><a class="reference internal" href="../howto/deep_model/rnn/hierarchical_layer_cn.html">支持双层序列作为输入的Layer</a></li>
<li class="toctree-l3"><a class="reference internal" href="../howto/deep_model/rnn/hrnn_rnn_api_compare_cn.html">单双层RNN API对比介绍</a></li>
</ul>
</li>
<li class="toctree-l2"><a class="reference internal" href="../howto/optimization/gpu_profiling_cn.html">GPU性能分析与调优</a></li>
</ul>
</li>
<li class="toctree-l1"><a class="reference internal" href="../api/index_cn.html">API</a><ul>
<li class="toctree-l2"><a class="reference internal" href="../api/v2/model_configs.html">模型配置</a><ul>
<li class="toctree-l3"><a class="reference internal" href="../api/v2/config/activation.html">Activation</a></li>
<li class="toctree-l3"><a class="reference internal" href="../api/v2/config/layer.html">Layers</a></li>
<li class="toctree-l3"><a class="reference internal" href="../api/v2/config/evaluators.html">Evaluators</a></li>
<li class="toctree-l3"><a class="reference internal" href="../api/v2/config/optimizer.html">Optimizer</a></li>
<li class="toctree-l3"><a class="reference internal" href="../api/v2/config/pooling.html">Pooling</a></li>
<li class="toctree-l3"><a class="reference internal" href="../api/v2/config/networks.html">Networks</a></li>
<li class="toctree-l3"><a class="reference internal" href="../api/v2/config/attr.html">Parameter Attribute</a></li>
</ul>
</li>
<li class="toctree-l2"><a class="reference internal" href="../api/v2/data.html">数据访问</a><ul>
<li class="toctree-l3"><a class="reference internal" href="../api/v2/data/data_reader.html">Data Reader Interface</a></li>
<li class="toctree-l3"><a class="reference internal" href="../api/v2/data/image.html">Image Interface</a></li>
<li class="toctree-l3"><a class="reference internal" href="../api/v2/data/dataset.html">Dataset</a></li>
</ul>
</li>
<li class="toctree-l2"><a class="reference internal" href="../api/v2/run_logic.html">训练与应用</a></li>
171
<li class="toctree-l2"><a class="reference internal" href="../api/v2/fluid.html">Fluid</a><ul>
172 173 174 175 176 177 178 179 180 181 182
<li class="toctree-l3"><a class="reference internal" href="../api/v2/fluid/layers.html">layers</a></li>
<li class="toctree-l3"><a class="reference internal" href="../api/v2/fluid/data_feeder.html">data_feeder</a></li>
<li class="toctree-l3"><a class="reference internal" href="../api/v2/fluid/executor.html">executor</a></li>
<li class="toctree-l3"><a class="reference internal" href="../api/v2/fluid/initializer.html">initializer</a></li>
<li class="toctree-l3"><a class="reference internal" href="../api/v2/fluid/evaluator.html">evaluator</a></li>
<li class="toctree-l3"><a class="reference internal" href="../api/v2/fluid/nets.html">nets</a></li>
<li class="toctree-l3"><a class="reference internal" href="../api/v2/fluid/optimizer.html">optimizer</a></li>
<li class="toctree-l3"><a class="reference internal" href="../api/v2/fluid/param_attr.html">param_attr</a></li>
<li class="toctree-l3"><a class="reference internal" href="../api/v2/fluid/profiler.html">profiler</a></li>
<li class="toctree-l3"><a class="reference internal" href="../api/v2/fluid/regularizer.html">regularizer</a></li>
<li class="toctree-l3"><a class="reference internal" href="../api/v2/fluid/io.html">io</a></li>
183 184
</ul>
</li>
185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349
</ul>
</li>
<li class="toctree-l1"><a class="reference internal" href="../faq/index_cn.html">FAQ</a><ul>
<li class="toctree-l2"><a class="reference internal" href="../faq/build_and_install/index_cn.html">编译安装与单元测试</a></li>
<li class="toctree-l2"><a class="reference internal" href="../faq/model/index_cn.html">模型配置</a></li>
<li class="toctree-l2"><a class="reference internal" href="../faq/parameter/index_cn.html">参数设置</a></li>
<li class="toctree-l2"><a class="reference internal" href="../faq/local/index_cn.html">本地训练与预测</a></li>
<li class="toctree-l2"><a class="reference internal" href="../faq/cluster/index_cn.html">集群训练与预测</a></li>
</ul>
</li>
</ul>

        
    </nav>
    
    <section class="doc-content-wrap">

      

 







<div role="navigation" aria-label="breadcrumbs navigation">
  <ul class="wy-breadcrumbs">
      
    <li>Design Doc: NCCL support in Paddle Fluid</li>
  </ul>
</div>
      
      <div class="wy-nav-content" id="doc-content">
        <div class="rst-content">
          <div role="main" class="document" itemscope="itemscope" itemtype="http://schema.org/Article">
           <div itemprop="articleBody">
            
  <div class="section" id="design-doc-nccl-support-in-paddle-fluid">
<span id="design-doc-nccl-support-in-paddle-fluid"></span><h1>Design Doc: NCCL support in Paddle Fluid<a class="headerlink" href="#design-doc-nccl-support-in-paddle-fluid" title="永久链接至标题"></a></h1>
<div class="section" id="abstract">
<span id="abstract"></span><h2>Abstract<a class="headerlink" href="#abstract" title="永久链接至标题"></a></h2>
<p>This design doc describes NCCL support in Paddle. We propose an approach to support the NCCL library both on a single machine and across multiple machines. We wrap the NCCL primitives <code class="docutils literal"><span class="pre">Broadcast</span></code>, <code class="docutils literal"><span class="pre">Allreduce</span></code>, and <code class="docutils literal"><span class="pre">Reduce</span></code> as operators so that a single script can utilize the power of multiple GPUs.</p>
</div>
<div class="section" id="motivation">
<span id="motivation"></span><h2>Motivation<a class="headerlink" href="#motivation" title="永久链接至标题"></a></h2>
<p><a class="reference external" href="https://developer.nvidia.com/nccl">NCCL</a> is an NVIDIA library that supports multi-GPU communication and is optimized for NVIDIA GPUs. It provides routines such as all-gather, all-reduce, broadcast, reduce, and reduce-scatter that can achieve high bandwidth over PCIe and NVLink high-speed interconnects. With the NCCL library, we can easily accelerate parallel training.</p>
<ul class="simple">
<li>Pros</li>
</ul>
<ol class="simple">
<li>Easy to plug in with the <a class="reference external" href="https://developer.nvidia.com/nccl">NCCL2</a> library.</li>
<li>High performance on NVIDIA GPUs.</li>
<li>MPI-like primitives, which have a low learning cost for users.</li>
</ol>
<ul class="simple">
<li>Cons</li>
</ul>
<ol class="simple">
<li>Only designed for NVIDIA GPUs; not a general multi-device solution.</li>
<li>Although NCCL1 is open-sourced under the BSD license, NCCL2 is no longer open source.</li>
</ol>
<p>At the beginning of training, the framework needs to distribute the same parameters to every GPU, and to merge the gradients whenever the user requires it.</p>
<p>As a result, during training we need operations for peer-to-peer copy between different GPUs, for aggregating gradients/parameters from GPUs, and for broadcasting parameters to GPUs. Every GPU only needs to run the operator with the correct place information.</p>
<p>Besides, the framework needs interfaces to synchronize model updates across the different GPU cards.</p>
</div>
<div class="section" id="implementation">
<span id="implementation"></span><h2>Implementation<a class="headerlink" href="#implementation" title="永久链接至标题"></a></h2>
<p>As mentioned above, we wrap the NCCL routines as several kinds of operators. Note that NCCL needs to create a Communicator between the GPUs at the beginning, so an NCCLInit operator is created.</p>
<div class="section" id="transpiler">
<span id="transpiler"></span><h3>Transpiler<a class="headerlink" href="#transpiler" title="永久链接至标题"></a></h3>
<p>To be compatible with <a class="reference external" href="https://github.com/PaddlePaddle/Paddle/blob/develop/doc/design/ops/dist_train.md">parameter server design doc</a>, the transpiler compiles the user defined operation graph into sub-graphs to be executed on different devices.</p>
<ol>
<li><p class="first">The user-defined model will be a single device program</p>
</li>
<li><p class="first">Broadcast/Reduce operators between GPUs will be inserted into the program; for multi-node training, <code class="docutils literal"><span class="pre">Send</span></code> and <code class="docutils literal"><span class="pre">Recv</span></code> operators may be inserted as well.</p>
<p><em>Broadcast, AllReduce in a single machine. And Broadcast, AllReduce, <a class="reference external" href="https://github.com/PaddlePaddle/Paddle/blob/develop/doc/design/ops/dist_train.md#graph-converter">Send, Recv</a> in multiple machines</em></p>
<p><img src="images/multigpu_before_convert.png" width="300" alt="Single-device program graph before multi-GPU conversion"/></p>
</li>
</ol>
<p>After compiling, the graph is as shown below:</p>
<p><img src="images/multigpu_allreduce.png" width="1000" alt="Multi-GPU graph after compiling, with AllReduce operators inserted"/></p>
<p>Operators are added to the sub-graphs. Every GPU is assigned a role such as <code class="docutils literal"><span class="pre">rank0</span></code>, <code class="docutils literal"><span class="pre">rank1</span></code>, etc.</p>
<ul class="simple">
<li><strong>Broadcast</strong>. The Broadcast operator distributes initialized parameters to all the GPUs from the GPU that owns them, e.g. from the <code class="docutils literal"><span class="pre">rank0</span></code> GPU.</li>
<li><strong>AllReduce</strong>. The AllReduce operator synchronizes parameters/gradients between GPUs. AllReduce is implemented with the ring-based communication method, which avoids a bottleneck on a single GPU.</li>
</ul>
<p>Note that the AllReduce operator forces the GPUs to synchronize at that point. Whether the whole training process runs in asynchronous or synchronous mode depends on where the AllReduce point is placed in the graph.</p>
<p>As shown in the picture, after each GPU computes the gradient of <code class="docutils literal"><span class="pre">W</span></code>, an <code class="docutils literal"><span class="pre">AllReduce</span></code> operator follows and accumulates <code class="docutils literal"><span class="pre">dW</span></code> over the full batch of data; each GPU then runs the optimization process individually and applies the gradient to its own <code class="docutils literal"><span class="pre">W</span></code>.</p>
<ul class="simple">
<li><strong>AllReduce</strong>
Note that our AllReduce operator is a ring-based AllReduce implementation. If we used the NCCL2 AllReduce primitive, every GPU would optimize the full batch of data, wasting the compute resources of (n-1) GPUs. In addition, the NCCL2 built-in AllReduce only utilizes the communication resources during synchronization, so updating the gradient becomes a subsequent phase. In fact, we can amortize the gradient update time cost into the communication phase. The process is:</li>
</ul>
<ol class="simple">
<li>Every parameter has its root card. That card will be responsible for aggregating the gradients from the GPUs.</li>
<li>The whole model&#8217;s parameters will be hashed to different root cards, to ensure load balance between GPUs.</li>
<li>Logically neighboring cards will start sending the parameter to the next one. After one round, the parameter&#8217;s root card will have aggregated the full gradients.</li>
<li>Then the root card will optimize the parameter.</li>
<li>The root card will send its optimized result to its neighbor, and then the neighbor will send the parameter to its next one.</li>
<li>Finish the synchronization round.</li>
</ol>
<p>The total time cost will be 2 * (n-1) * per-parameter-send-time; in this way we reach the goal of amortizing the update time into the communication phase.</p>
</div>
</div>
</div>


           </div>
          </div>
          <footer>
  

  <hr/>

  <div role="contentinfo">
    <p>
        &copy; Copyright 2016, PaddlePaddle developers.

    </p>
  </div>
  Built with <a href="http://sphinx-doc.org/">Sphinx</a> using a <a href="https://github.com/snide/sphinx_rtd_theme">theme</a> provided by <a href="https://readthedocs.org">Read the Docs</a>. 

</footer>

        </div>
      </div>

    </section>

  </div>
  


  

    <script type="text/javascript">
        // Page-level settings exposed as a global, declared before the
        // ../_static scripts (jquery.js, underscore.js, doctools.js, ...) are loaded.
        // NOTE(review): presumably read by Sphinx's doctools.js and related
        // scripts — confirm which keys are actually consumed.
        var DOCUMENTATION_OPTIONS = {
            // Relative path prefix; same '../' prefix this page uses for its _static assets.
            URL_ROOT:'../',
            // Empty string on this build (no version value is set here).
            VERSION:'',
            COLLAPSE_INDEX:false,
            // Extension of generated pages; this page's internal links end in '.html'.
            FILE_SUFFIX:'.html',
            HAS_SOURCE:  true,
            // NOTE(review): looks like the suffix for page-source links — confirm.
            SOURCELINK_SUFFIX: ".txt",
        };
    </script>
      <script type="text/javascript" src="../_static/jquery.js"></script>
      <script type="text/javascript" src="../_static/underscore.js"></script>
      <script type="text/javascript" src="../_static/doctools.js"></script>
      <script type="text/javascript" src="../_static/translations.js"></script>
      <script type="text/javascript" src="https://cdn.bootcss.com/mathjax/2.7.0/MathJax.js"></script>
       
  

  
  
    <script type="text/javascript" src="../_static/js/theme.js"></script>
  
  
  <script src="https://maxcdn.bootstrapcdn.com/bootstrap/3.3.7/js/bootstrap.min.js" integrity="sha384-Tc5IQib027qvyjSMfHjOMaLkfuWVxZxUPnCJA7l2mCWNIpG9mGCD8wGNIcPD7Txa" crossorigin="anonymous"></script>
  <script src="https://cdn.jsdelivr.net/perfect-scrollbar/0.6.14/js/perfect-scrollbar.jquery.min.js"></script>
  <script src="../_static/js/paddle_doc_init.js"></script> 

</body>
</html>