From 440fabb66ddaba9558513f10e7df4ddcdd738e37 Mon Sep 17 00:00:00 2001
From: mapingshuo
Date: Fri, 21 Feb 2020 15:03:19 +0800
Subject: [PATCH] add recompute document

---
 .../performance_improving/index_cn.rst        |   1 +
 .../performance_improving/index_en.rst        |   1 +
 .../gpu_training_with_recompute.rst           | 160 ++++++++++++++
 .../gpu_training_with_recompute_en.rst        | 196 ++++++++++++++++++
 4 files changed, 358 insertions(+)
 create mode 100644 doc/fluid/advanced_guide/performance_improving/multinode_training_improving/gpu_training_with_recompute.rst
 create mode 100644 doc/fluid/advanced_guide/performance_improving/multinode_training_improving/gpu_training_with_recompute_en.rst

diff --git a/doc/fluid/advanced_guide/performance_improving/index_cn.rst b/doc/fluid/advanced_guide/performance_improving/index_cn.rst
index 286a2a561..a8080faf3 100644
--- a/doc/fluid/advanced_guide/performance_improving/index_cn.rst
+++ b/doc/fluid/advanced_guide/performance_improving/index_cn.rst
@@ -11,5 +11,6 @@
     singlenode_training_improving/memory_optimize.rst
     multinode_training_improving/cpu_train_best_practice.rst
     multinode_training_improving/dist_training_gpu.rst
+    multinode_training_improving/gpu_training_with_recompute.rst
     inference_improving/paddle_tensorrt_infer.md
     analysis_tools/index_cn.rst

diff --git a/doc/fluid/advanced_guide/performance_improving/index_en.rst b/doc/fluid/advanced_guide/performance_improving/index_en.rst
index 22a966243..382d0d162 100644
--- a/doc/fluid/advanced_guide/performance_improving/index_en.rst
+++ b/doc/fluid/advanced_guide/performance_improving/index_en.rst
@@ -7,5 +7,6 @@ Practice Improving
     multinode_training_improving/cpu_train_best_practice_en.rst
+    multinode_training_improving/gpu_training_with_recompute_en.rst
     inference_improving/paddle_tensorrt_infer_en.md
     analysis_tools/index_en.rst

diff --git a/doc/fluid/advanced_guide/performance_improving/multinode_training_improving/gpu_training_with_recompute.rst b/doc/fluid/advanced_guide/performance_improving/multinode_training_improving/gpu_training_with_recompute.rst
new file mode 100644
index 000000000..07b880616
--- /dev/null
+++ b/doc/fluid/advanced_guide/performance_improving/multinode_training_improving/gpu_training_with_recompute.rst
@@ -0,0 +1,160 @@

Recompute: a Feature for Large-Batch Training
==============================================

Background
----------

As the amount of training data keeps growing, training larger and deeper deep-learning models has become a mainstream trend. Current training usually has to keep the hidden-layer results of the forward pass, and the number of results that must be kept grows linearly with the number of layers, which is a challenge for the memory available on today's AI chips. Forward Recomputation Backpropagation (FRB) can significantly increase the depth and width of trainable models, as well as the training batch size, at the cost of a small amount of extra computation.

Theory
------

A training iteration of a deep-learning network consists of three steps:

- **Forward computation**: run the forward operators (Operators) to compute the values of the intermediate hidden layers (Variables)
- **Backward computation**: run the backward operators to compute the gradients of the parameters (Parameters)
- **Optimization**: apply an optimization algorithm to update the parameter values

During the forward computation, the forward operators produce a large number of intermediate results, which Paddle stores in Variables. When the model gets deeper, there can be tens of thousands of them, occupying a large amount of memory. Paddle's `garbage collection mechanism `_ removes useless intermediate results in time to save memory. However, some intermediate results are inputs of backward operators; these Variables must stay in memory until the corresponding backward operators finish.

As a simple example, consider a network built from mul operators whose forward computation is:

.. math::

    y = W_1 * x

    z = W_2 * y

where :math:`x, y, z` are vectors and :math:`W_1, W_2` are matrices. It is easy to see that the backward computation for the gradient of :math:`W_2` is:

.. math::

    W_{2}^{'} = z^{'} * y^{T}
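
For completeness, a sketch of the remaining backward steps under the same definitions (here :math:`z^{'}` and :math:`y^{'}` denote the gradients of the loss with respect to :math:`z` and :math:`y`; the exact shapes depend on how the loss is defined):

.. math::

    y^{'} = W_{2}^{T} * z^{'}

    W_{1}^{'} = y^{'} * x^{T}

Computing the gradient of :math:`W_1` therefore additionally requires the forward input :math:`x`, so each layer's forward results pile up until its backward operator has run.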
The backward computation uses the variable :math:`y` produced by the forward pass, so :math:`y` must stay in memory until this backward operator finishes. As the model gets deeper, we accumulate a large number of such ":math:`y`"s, and they occupy a large amount of memory.

The idea of Forward Recomputation Backpropagation (FRB) is to split the deep network into k parts (segments). For each segment, during the forward pass all intermediate results are dropped except for a small number of Variables that must stay in memory (these special Variables are discussed below); during the backward pass, the forward operators of the segment are first re-run to recover the intermediate results, and then the backward operators are executed. In short, compared with a normal training iteration, FRB computes the forward operators one extra time.

The variables at which the network is split are called checkpoints. How should the checkpoints be chosen? Ever since FRB was proposed \ :sup:`[1], [2]`, many researchers have studied this key question. Deep-learning networks are usually built by stacking modules in series; for example, ResNet-50 consists of 16 blocks and Bert-Large of 24 transformers, so the variables between two consecutive modules are a good choice of split points. For non-sequential networks (for example, networks with many shortcut structures), FRB can still split the network, at the cost of a little extra memory (for storing the Variables of the shortcuts). Mitsuru Kusumoto et al. \ :sup:`[3]` proposed a dynamic-programming-based algorithm that automatically searches for suitable checkpoints under a given memory budget and supports a wide variety of network structures.

The figure below shows a network built by stacking 4 fc layers, 3 relu layers, 1 sigmoid layer and 1 log-loss layer in series: the left column is its forward computation, the middle column is the normal forward and backward computation, and the right column is the forward and backward computation after adding FRB. Rectangles denote operators (Operators), red dots denote intermediate results of the forward pass, and blue dots denote checkpoints.

.. image:: images/recompute.png

Note: the complete code of this example is available at `source `_

After adding FRB, the number of intermediate Variables that must be stored during the forward pass drops from 4 (red dots) to 2 (blue dots), saving the corresponding memory. Of course, the recomputed part also produces new intermediate variables, so a trade-off has to be made for the network at hand. The network in this example is shallow; in general, for deeper networks the memory saved by FRB is far larger than the memory it adds.

Usage
-----

We have implemented the FRB algorithm on top of Paddle as RecomputeOptimizer; you can learn more about it from its `source code `_ and `documentation `_. There are two ways to use RecomputeOptimizer: calling it directly, or using it through the Fleet API. For single-GPU or CPU training we recommend calling RecomputeOptimizer directly; for multi-GPU or multi-node training we recommend using Recompute through the Fleet API.

**1. Calling RecomputeOptimizer directly**

Calling RecomputeOptimizer directly is simple: first define a regular Optimizer, such as Adam; then wrap it with RecomputeOptimizer; finally, set the checkpoints.

.. code-block:: python

    import paddle.fluid as fluid

    # define the network
    def mlp(input_x, input_y, hid_dim=128, label_dim=2):
        fc_1 = fluid.layers.fc(input=input_x, size=hid_dim)
        prediction = fluid.layers.fc(input=[fc_1], size=label_dim, act='softmax')
        cost = fluid.layers.cross_entropy(input=prediction, label=input_y)
        sum_cost = fluid.layers.reduce_mean(cost)
        return sum_cost, fc_1, prediction

    input_x = fluid.layers.data(name="x", shape=[32], dtype='float32')
    input_y = fluid.layers.data(name="y", shape=[1], dtype='int64')
    cost, fc_1, pred = mlp(input_x, input_y)

    # define RecomputeOptimizer
    sgd = fluid.optimizer.Adam(learning_rate=0.01)
    sgd = fluid.optimizer.RecomputeOptimizer(sgd)
    # set the checkpoints
    sgd._set_checkpoints([fc_1, pred])
    # run the optimization
    sgd.minimize(cost)

In principle, Recompute works with any Optimizer. (A slightly deeper, block-structured sketch is given at the end of this section, right before the Q&A.)

**2. Using Recompute through the Fleet API**

The `Fleet API `_ is a high-level distributed-training API built on Fluid. Enabling RecomputeOptimizer through the Fleet API takes only two steps:

- set dist_strategy.forward_recompute to True;

- set dist_strategy.recompute_checkpoints.

.. code-block:: python

    from paddle.fluid.incubate.fleet.collective import fleet, DistributedStrategy

    dist_strategy = DistributedStrategy()
    dist_strategy.forward_recompute = True
    dist_strategy.recompute_checkpoints = checkpoints
    optimizer = fleet.distributed_optimizer(optimizer, strategy=dist_strategy)
    optimizer.minimize(loss)

To help you get started with Recompute in the Fleet API quickly, we provide several examples, together with their training speed, accuracy and memory savings:

- Fine-tuning Bert with Recompute: `source `_

- Object detection with Recompute: under development.
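
To make the checkpoint-selection guidance above more concrete, here is a minimal sketch (not taken from the official examples) of a deeper, block-structured network in which the output of each block is chosen as a checkpoint; the helper name ``stacked_mlp`` and the parameter ``num_blocks`` are illustrative only:

.. code-block:: python

    import paddle.fluid as fluid

    # A toy block-structured network: each "block" is fc + relu.
    # The output of every block is a natural boundary between sub-networks,
    # so we collect it as a checkpoint candidate.
    def stacked_mlp(input_x, input_y, hid_dim=128, num_blocks=8, label_dim=2):
        checkpoints = []
        hidden = input_x
        for _ in range(num_blocks):
            hidden = fluid.layers.fc(input=hidden, size=hid_dim, act='relu')
            checkpoints.append(hidden)
        prediction = fluid.layers.fc(input=hidden, size=label_dim, act='softmax')
        cost = fluid.layers.cross_entropy(input=prediction, label=input_y)
        avg_cost = fluid.layers.reduce_mean(cost)
        return avg_cost, checkpoints

    input_x = fluid.layers.data(name="x", shape=[32], dtype='float32')
    input_y = fluid.layers.data(name="y", shape=[1], dtype='int64')
    cost, checkpoints = stacked_mlp(input_x, input_y)

    # Wrap a regular optimizer and register the block outputs as checkpoints.
    optimizer = fluid.optimizer.Adam(learning_rate=0.01)
    optimizer = fluid.optimizer.RecomputeOptimizer(optimizer)
    optimizer._set_checkpoints(checkpoints)
    optimizer.minimize(cost)

Keeping one checkpoint per block is only a starting point; as the Q&A below notes, too many checkpoints can themselves consume a noticeable amount of memory, so the list may need to be thinned for very deep models.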
Q&A
-------

- **Are operators with randomness supported?**

  Currently, the operator with randomness in Paddle is dropout. Recompute supports the dropout operator and guarantees that the recomputed results are consistent with those of the first computation.

- **Are there more official examples of Recompute?**

  More Recompute examples will be added to the `examples `_ and `Fleet `_ repositories; stay tuned.

- **Any advice on adding checkpoints?**

  We suggest adding the variables that connect sub-networks as checkpoints, that is, if a variable splits the network into two separate parts, it is a good candidate. The number of checkpoints affects memory consumption: with too few checkpoints, Recompute has limited effect; with too many, the checkpoints themselves occupy a large amount of memory and the total memory usage may even increase instead of decrease.

  We plan to add a tool that estimates memory usage and visualizes the memory consumed before and after each Operator, to help users locate problems.

[1] Tianqi Chen, Bing Xu, Chiyuan Zhang, and Carlos Guestrin. Training deep nets with sublinear memory cost. arXiv preprint, arXiv:1604.06174, 2016.

[2] Audrunas Gruslys, Rémi Munos, Ivo Danihelka, Marc Lanctot, and Alex Graves. Memory efficient backpropagation through time. In Advances in Neural Information Processing Systems (NIPS), pages 4125-4133, 2016.

[3] Kusumoto, Mitsuru, et al. "A Graph Theoretic Framework of Recomputation Algorithms for Memory-Efficient Backpropagation." arXiv preprint arXiv:1905.11722 (2019).

diff --git a/doc/fluid/advanced_guide/performance_improving/multinode_training_improving/gpu_training_with_recompute_en.rst b/doc/fluid/advanced_guide/performance_improving/multinode_training_improving/gpu_training_with_recompute_en.rst
new file mode 100644
index 000000000..708ab04ff
--- /dev/null
+++ b/doc/fluid/advanced_guide/performance_improving/multinode_training_improving/gpu_training_with_recompute_en.rst
@@ -0,0 +1,196 @@

Recompute: Training with Bigger Batch Sizes
============================================

Context
---------

As the amount of training data grows, training deeper neural network models becomes more and more popular. Current deep-learning training usually keeps the hidden-layer outputs in memory during the forward propagation, and the number of stored outputs grows linearly with the number of model layers, which poses a challenge to the memory capacity of common devices. Forward Recomputation Backpropagation (FRB) addresses this problem by trading a small amount of extra computation for a much lower memory footprint.

Theory
---------

As we know, a training iteration of a deep-learning network contains three steps:

- **Forward Propagation**: run the forward operators and generate temporary variables as output
- **Backward Propagation**: run the backward operators to compute the gradients of the parameters
- **Optimization**: apply an optimization algorithm to update the parameters

When the model becomes deeper, the number of temporary variables generated in the forward propagation can reach tens of thousands, occupying a large amount of memory. The `Garbage Collection mechanism `_ in Paddle deletes useless variables to save memory. However, some variables serve as inputs of backward operators; they must be kept in memory until the corresponding backward operators finish.

Take a simple example: define a network that contains two `mul` operators, whose forward propagation is:

.. math::

    y = W_1 * x

    z = W_2 * y

where :math:`x, y, z` are vectors and :math:`W_1, W_2` are matrices. It is easy to derive that the gradient of :math:`W_2` is:

.. math::

    W_{2}^{'} = z^{'} * y^{T}

We can see that :math:`y` is used in the backward propagation, so it must be kept in memory until the corresponding backward operator finishes. When the network grows deeper, more and more ":math:`y`"s need to be stored, putting more and more pressure on memory.

Forward Recomputation Backpropagation (FRB) splits a deep network into k segments.
For each segment, during forward propagation most of the temporary variables are erased in time, except for some special variables (we will talk about them later); during backward propagation, the forward operators of the segment are recomputed to regenerate these temporary variables before the backward operators run. In short, FRB runs the forward operators twice.

But how should the network be split? Ever since FRB was proposed \ :sup:`[1], [2]`, this question has been studied extensively \ :sup:`[3]`. A deep-learning network usually consists of modules connected in series: ResNet-50 contains 16 blocks and Bert-Large contains 24 transformers. It is a good choice to treat such modules as segments. The variables between segments are called checkpoints.

The following picture shows a network with 4 fc layers, 3 relu layers, 1 sigmoid layer and 1 log-loss layer in series. The left column is the forward propagation, the middle column is the normal backward propagation, and the right column is the FRB. Rectangular boxes represent operators, red dots represent intermediate variables of the forward computation, blue dots represent checkpoints, and arrows represent the dependencies between operators.

.. image:: images/recompute.png

Note: the complete source code of this example: `source `_

After applying FRB, the forward computation only needs to store 2 variables (the blue dots) instead of 4 variables (the red dots), saving the corresponding memory. Note that the recomputed operators also generate new intermediate variables, so a trade-off has to be considered. According to our experiments, for deep networks FRB usually saves far more memory than it adds.

Usage
---------

We have implemented the FRB algorithm in Paddle as "RecomputeOptimizer". More information about this algorithm can be found in the `source code `_ and the `document `_ of RecomputeOptimizer.

There are two ways to apply RecomputeOptimizer in your Paddle program: calling RecomputeOptimizer directly, or using it with the Fleet API. For single-GPU or CPU training, we recommend calling it directly; for multi-GPU training, we recommend using it with the Fleet API.

**1. Directly calling**

Calling RecomputeOptimizer is very easy: first, define a classic optimizer, such as Adam; second, wrap it with RecomputeOptimizer; third, set the checkpoints.

.. code-block:: python

    import paddle.fluid as fluid

    # Define the network
    def mlp(input_x, input_y, hid_dim=128, label_dim=2):
        fc_1 = fluid.layers.fc(input=input_x, size=hid_dim)
        prediction = fluid.layers.fc(input=[fc_1], size=label_dim, act='softmax')
        cost = fluid.layers.cross_entropy(input=prediction, label=input_y)
        sum_cost = fluid.layers.reduce_mean(cost)
        return sum_cost, fc_1, prediction

    input_x = fluid.layers.data(name="x", shape=[32], dtype='float32')
    input_y = fluid.layers.data(name="y", shape=[1], dtype='int64')
    cost, fc_1, pred = mlp(input_x, input_y)

    # define RecomputeOptimizer
    sgd = fluid.optimizer.Adam(learning_rate=0.01)
    sgd = fluid.optimizer.RecomputeOptimizer(sgd)
    # set checkpoints
    sgd._set_checkpoints([fc_1, pred])
    # apply optimization
    sgd.minimize(cost)

In principle, Recompute works with all kinds of optimizers in Paddle.
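
To actually execute the snippet above, the program can be run with a standard executor. The following is a minimal sketch that continues the snippet; the CPU place, the batch size of 4 and the random numpy data are illustrative choices, not part of the official example:

.. code-block:: python

    import numpy as np

    # Initialize parameters once, then feed a random batch and fetch the loss.
    place = fluid.CPUPlace()          # or fluid.CUDAPlace(0) on GPU
    exe = fluid.Executor(place)
    exe.run(fluid.default_startup_program())

    feed = {
        "x": np.random.random(size=(4, 32)).astype('float32'),
        "y": np.random.randint(2, size=(4, 1)).astype('int64'),
    }
    loss_val = exe.run(fluid.default_main_program(),
                       feed=feed,
                       fetch_list=[cost.name])
    print(loss_val)

Apart from wrapping the optimizer and setting the checkpoints, nothing else in the training loop has to change.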
**2. Using Recompute in Fleet API**

`Fleet API `_ is a high-level API for distributed training in Fluid. Enabling RecomputeOptimizer in the Fleet API takes two steps:

- set dist_strategy.forward_recompute to True

- set dist_strategy.recompute_checkpoints

.. code-block:: python

    from paddle.fluid.incubate.fleet.collective import fleet, DistributedStrategy

    dist_strategy = DistributedStrategy()
    dist_strategy.forward_recompute = True
    dist_strategy.recompute_checkpoints = checkpoints
    optimizer = fleet.distributed_optimizer(optimizer, strategy=dist_strategy)
    optimizer.minimize(loss)

We supply some examples of using Recompute with the Fleet API, and also post the corresponding training speed, test results and memory usage for reference:

- Fine-tuning the Bert Large model with recomputing: `source `_

- Training object detection models with recomputing: under development.

Q&A
-------

- **Does RecomputeOptimizer support operators with random outputs?**

  Currently, dropout is the operator with random behaviour in Paddle. RecomputeOptimizer supports the dropout operator and keeps the outputs of the recomputation consistent with those of the first computation.

- **Are there more official examples of Recompute?**

  More examples will be updated at `examples `_ and `Fleet `_ . Feel free to raise issues if you get any problem with these examples.

- **How should I set checkpoints?**

  The position of the checkpoints is important: we suggest setting the variables between sub-networks as checkpoints, that is, a variable is a good checkpoint if it separates the network into two parts without short-cut connections. The number of checkpoints is also important: too few checkpoints reduce the memory saved by recomputing, while too many checkpoints occupy a lot of memory themselves. We will add a tool to estimate the memory usage with specific checkpoints, helping users to choose checkpointing variables.

[1] Tianqi Chen, Bing Xu, Chiyuan Zhang, and Carlos Guestrin. Training deep nets with sublinear memory cost. arXiv preprint, arXiv:1604.06174, 2016.

[2] Audrunas Gruslys, Rémi Munos, Ivo Danihelka, Marc Lanctot, and Alex Graves. Memory efficient backpropagation through time. In Advances in Neural Information Processing Systems (NIPS), pages 4125-4133, 2016.

[3] Kusumoto, Mitsuru, et al. "A Graph Theoretic Framework of Recomputation Algorithms for Memory-Efficient Backpropagation." arXiv preprint arXiv:1905.11722 (2019).
--
GitLab