From 3ff9f74886bf24da86b05503a216b6eef0d7fbad Mon Sep 17 00:00:00 2001 From: MaoXianxin Date: Thu, 3 Jun 2021 17:17:56 +0800 Subject: [PATCH] =?UTF-8?q?=E5=90=8E=E5=90=91=E9=87=8D=E8=AE=A1=E7=AE=97?= =?UTF-8?q?=E5=9C=A8OneFlow=E4=B8=AD=E7=9A=84=E5=AE=9E=E7=8E=B0=EF=BC=9A?= =?UTF-8?q?=E4=BB=A5=E6=97=B6=E9=97=B4=E6=8D=A2=E7=A9=BA=E9=97=B4=EF=BC=8C?= =?UTF-8?q?=E5=A4=A7=E5=B9=85=E9=99=8D=E4=BD=8E=E6=98=BE=E5=AD=98=E5=8D=A0?= =?UTF-8?q?=E7=94=A8?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- ...76\345\255\230\345\215\240\347\224\250.md" | 368 ++++++++++++++++++ 1 file changed, 368 insertions(+) create mode 100644 "\345\220\216\345\220\221\351\207\215\350\256\241\347\256\227\345\234\250OneFlow\344\270\255\347\232\204\345\256\236\347\216\260\357\274\232\344\273\245\346\227\266\351\227\264\346\215\242\347\251\272\351\227\264\357\274\214\345\244\247\345\271\205\351\231\215\344\275\216\346\230\276\345\255\230\345\215\240\347\224\250.md" diff --git "a/\345\220\216\345\220\221\351\207\215\350\256\241\347\256\227\345\234\250OneFlow\344\270\255\347\232\204\345\256\236\347\216\260\357\274\232\344\273\245\346\227\266\351\227\264\346\215\242\347\251\272\351\227\264\357\274\214\345\244\247\345\271\205\351\231\215\344\275\216\346\230\276\345\255\230\345\215\240\347\224\250.md" "b/\345\220\216\345\220\221\351\207\215\350\256\241\347\256\227\345\234\250OneFlow\344\270\255\347\232\204\345\256\236\347\216\260\357\274\232\344\273\245\346\227\266\351\227\264\346\215\242\347\251\272\351\227\264\357\274\214\345\244\247\345\271\205\351\231\215\344\275\216\346\230\276\345\255\230\345\215\240\347\224\250.md" new file mode 100644 index 0000000..f4835e0 --- /dev/null +++ "b/\345\220\216\345\220\221\351\207\215\350\256\241\347\256\227\345\234\250OneFlow\344\270\255\347\232\204\345\256\236\347\216\260\357\274\232\344\273\245\346\227\266\351\227\264\346\215\242\347\251\272\351\227\264\357\274\214\345\244\247\345\271\205\351\231\215\344\275\216\346\230\276\345\255\230\345\215\240\347\224\250.md" @@ -0,0 +1,368 @@ +# 后向重计算在OneFlow中的实现:以时间换空间,大幅降低显存占用 + +2016年,陈天奇团队提出了亚线性内存优化相关的“gradient/activation checkpointing(后向重计算)”等技术[1],旨在降低深度学习训练过程中的中间激活(activation)带来的显存占用。Checkpointing技术属于亚线性内存优化的一种,除此之外还有CPU offload等技术(CPU offload在微软Deepspeed框架中被广泛使用)。 + +![](https://maoxianxin1996.oss-accelerate.aliyuncs.com/codechina/20210603170530.png) + +CPU offload将暂时用不到的GPU内存换入到CPU内存中存储,待需要时再取出,主要开销是来自于CPU和GPU之间的拷贝,会占用传输带宽(PCIE带宽),属于以传输换空间。而Checkpointing核心在于以时间换空间:**通过计算图分析技术来实施Inplace操作以及内存共享优化(Memory sharing),在每个mini-batch的前向过程中删除一些暂时用不到的中间激活特征以降低内存占用,并在后向过程中需要时借助额外的前向计算恢复它们。** + +> 在OneFlow中,Checkpointing的实现主要是通过静态内存复用的方式,前向Tensor的生命周期结束后,其余Tensor可以复用这块内存,从而起到内存复用、节省内存的效果。 + +OneFlow目前支持了“gradient/activation checkpointing”(后向重计算)以实现亚线性内存优化,且对算法开发者非常友好,使用方式很简单:**针对需要优化的网络部分,用一行代码将其包裹在“Checkpointing”的scope范围内即可,系统内部会针对此scope区域内的网络做分析并在训练过程中自动进行Checkpointing内存优化。** + +本文主要内容为以下3点: + +- 1.亚线性内存优化的用法 +- 2.亚线性内存优化的设计 +- 3.代码解读 + +其中:1.将介绍如何在OneFlow中开启试用亚线性内存优化;2.将介绍OneFlow中亚线性内存优化是如何设计的及其原理;3.将从代码入手,剖析具体实现过程。 + +## 亚线性内存优化的用法 + +OneFlow中开启亚线性内存优化的方式如下: + +``` +# 用法: +with flow.experimental.scope.config(checkpointing=True): + # your net work, such as : + # input layernorm + norm1 = layernorm("layernorm_1", h) + # attention + h = h + self.attn(norm1) + # output layernorm + norm2 = layernorm("layernorm_2", h) + # mlp + h = h + self.mlp(norm2) +``` + +用上述代码包裹后,此scope区域内的网络,在整个前向过程中只会保存一份input tensor的内存,从input到最后输出h,这之间所有中间特征tensor的内存都不会被保存,后向过程需要时从input开始进行(前向的)重计算。 + +我们在多个网络上进行了开启/关闭checkpointing的显存占用测试,以GPT-2为例,具体是在每个Transformer Layer内都使用`checkpointing = True` scope标记重计算的部分。 + +![](https://maoxianxin1996.oss-accelerate.aliyuncs.com/codechina/20210603170637.png) + +可以看见,开启checkpointing后会大幅降低GPT-2训练时的显存占用,在batch size = 4 时,内存节省超过50+%。 + +## 亚线性内存优化的设计 + +在系列文章《深度解析:让你掌握OneFlow框架的系统设计(上篇、中篇、下篇)》中,我们介绍了OneFlow中的OpNode/OpGragh抽象以及建立在这之上的Actor、SBP抽象等系统设计,正是这些良好的系统设计和抽象使得OneFlow在多种任务下都有着优秀的表现。 + +OneFlow的Job任务在逻辑图编译期会基于由OpNode构成的Job逻辑图(OpGragh),进行一系列pass的系统优化过程,每个pass对逻辑图进行了一次图修改/重写(对逻辑图中的节点和连边进行了增删操作),这些优化对性能的提升至关重要。 + +Activation Checkpointing 在OneFlow中的实现,也是通过一个Checkpointing的pass对Job逻辑图实现修改/重写来实现的(见[https://github.com/Oneflow-Inc/oneflow/pull/3976](https://github.com/Oneflow-Inc/oneflow/pull/3976))。 + +![](https://maoxianxin1996.oss-accelerate.aliyuncs.com/codechina/20210603170743.png) + +
+主要原理 +
+ +如图所示: + +1.上半部分为正常情况下的逻辑子图。T1、T2为Transformer Layer的前向计算部分、子图中每个op计算完成后得到的中间激活特征将持续占用内存,当计算进行到反向时(T1_grad、T2_grad),再利用这些中间激活进行反向的计算; + +2.下半部分为开启Activation Checkpointing后的逻辑子图。可以看到,中间部分增加了虚线框住,用于重计算的fake子图,由于fake子图的存在,正常forward子图在进行前向时,就无须保存中间激活了,当backward计算需要用到时,再临时根据fake子图进行前向的重计算。 + +在OneFlow中,Activation Checkpointing的细节流程如下: + +1.收集checkpointing作用域包裹下的所有前向pass下的ops + +2.收集ops下所有的子图subgraphs + +3.遍历子图subgraphs,并对所有需要做后向的subgraph做如下操作: + +- 生成fake子图,并将其作为后向消费者的输入(而不是真实子图) +- 在fake子图中增加由end op连向所有源节点source nodes的控制边 +- 将fake子图添加至job builder(被其管理) + +4.在job builder中更新所有后向消费者ops + +**代码实现:** + +``` +https://github.com/Oneflow-Inc/oneflow/blob/master/oneflow/core/job/checkpointing_config_def.cpp +https://github.com/Oneflow-Inc/oneflow/blob/master/oneflow/core/job/job_build_and_infer_ctx.cpp#L989 +https://github.com/Oneflow-Inc/oneflow/blob/master/oneflow/core/job_rewriter/checkpointing_pass.cpp +``` + +## 代码解读 + +### 收集所有前向pass下的ops + +由于activation checkpointing的设计主要是节省前向计算过程中的内存,即在checkpointing作用范围内,将前向计算过程中op nodes产生的activation显存释放掉,后向backward时,重新进行此部分的前向计算,得到所需的activation。 + +由于我们所有的操作是在逻辑图层面,操作的对象为每个op node节点,所以首先需要标记、筛选出checkpointing作用范围内所有前向的op nodes。此部分主要通过CollectAllCheckpointingOpsInForwardPass()方法实现: + +``` +https://github.com/Oneflow-Inc/oneflow/blob/master/oneflow/core/job_rewriter/checkpointing_pass.cpp#L65 +``` + +``` +void CollectAllCheckpointingOpsInForwardPass( + // 收集所有属于前向pass下,且符合条件的op nodes,存放至HashMap中 + const OpGraph& op_graph, HashMap* checkpointing_op_name2op_node) { + // NOTE(chengcheng): + // ignore batch_norm ops because of recompute bn will repeat the calculation of 'm' and 'v'. + // in the future, we need to support the recomputation version of batch_norm which do NOT + // update forward variables. + HashSet ignore_op_type_names = {"normalization", "normalization_add_relu", + "cudnn_fused_normalization_add_relu"}; + op_graph.ForEachNode([&](const OpNode* op_node) { + const OperatorConf& op_conf = op_node->op().op_conf(); + // 跳过不包含user_conf以及ignore_op_type_names指定的op_node + if (!op_conf.has_user_conf()) { return; } + if (ignore_op_type_names.find(op_conf.user_conf().op_type_name()) + != ignore_op_type_names.end()) { + return; + } + // 对scope范围内开启了checkpointing且标记为ForwardPass的op_node,则为目标node,将其插入HashMap中 + if (IsForwardPass7CheckpointingScope(Scope4OpNode(op_node))) { + CHECK(checkpointing_op_name2op_node->emplace(op_conf.name(), op_node).second); + } + }); +} +``` + +其中,主要通过`IsForwardPass7CheckpointingScope()`方法,来对符合条件的op node进行筛选: + +``` +bool IsForwardPassScope(const Scope& scope) { + // scope中,calculation_pass_name属性为kForwardPass的node,则为参与前向计算的目标node + return scope.scope_proto().calculation_pass_name() == kForwardPass; +} + +bool IsForwardPass7CheckpointingScope(const Scope& scope) { + // True if 属性为kForwardPass的node且scope开启了checkpointing + return IsForwardPassScope(scope) && scope.Bool("checkpointing"); +} +``` + +`IsForwardPass7CheckpointingScope()`方法通过node的scope来判断该op node是否属于直接参与前向计算的node(scope中包含kForwardPass),且是否开启了“checkpointing”,同时满足则为目标node,将其插入hashmap(checkpointing_op_name2op_node)中。 + +### 收集ops下所有的subgraphs + +筛选出checkpointing作用区域内所有的op nodes后,需要根据这些nodes生成所有子图subgraghs,这些子图有些是和后向重计算无关、有些则是后向重计算所需的目标子图,它们的输出作为后向op node的输入被消费,这些子图是实现activation checkpointing设计中前向重计算的最小单位。 + +生成子图的代码如下: + +``` +// 根据ops生成所有subgraphs子图,并将其存放在vector中 +// step 2. get all connected subgraphs in checkpointing ops. +std::vector> checkpointing_subgraphs; +GenConnectedCheckpointingSubgraphs(checkpointing_op_name2op_node, &checkpointing_subgraphs); +``` + +其中,主要通过GenConnectedCheckpointingSubgraphs()方法生成subgraphs: + +``` +void GenConnectedCheckpointingSubgraphs( + // 生成Subgraphs子图 + const HashMap& checkpointing_op_name2op_node, + std::vector>* checkpointing_subgraphs) { + HashSet visited_nodes; + for (const auto& pair : checkpointing_op_name2op_node) { + const OpNode* node = pair.second; + if (visited_nodes.find(node) != visited_nodes.end()) { continue; } + + // new subgraph + checkpointing_subgraphs->push_back(HashSet()); + CHECK(!checkpointing_subgraphs->empty()); + auto& subgraph = checkpointing_subgraphs->back(); + CHECK(subgraph.empty()); + + // bfs search all node in checkpointing ops + CHECK(visited_nodes.insert(node).second); + std::queue queued_nodes; + queued_nodes.push(node); + while (!queued_nodes.empty()) { + const OpNode* cur_node = queued_nodes.front(); + queued_nodes.pop(); + + CHECK(subgraph.insert(cur_node).second); + cur_node->ForEachNodeOnInOutEdge([&](const OpNode* next_node) { + const std::string& next_op_name = next_node->op().op_name(); + if (checkpointing_op_name2op_node.find(next_op_name) != checkpointing_op_name2op_node.end() + && cur_node->parallel_desc() == next_node->parallel_desc() + && visited_nodes.find(next_node) == visited_nodes.end()) { + queued_nodes.push(next_node); + CHECK(visited_nodes.insert(next_node).second); + } + }); + } + } +} +``` + +根据当前节点(cur_node),找到下一个子图节点(next_node),采用的是BFS搜索,搜索逻辑为:以cur_node为起点,遍历其输入/输入边上有消费关系的节点next_node;对于不属于checkpointing op && 没有被当作子图node访问过 && 并行方式和cur_node一致的node;作为subgraph中的目标node(next_node),插入subgraph队列中,并将该node标记为已访问,放置到visited_nodes Set中。 + +### 遍历子图subgraphs + +经历上述过程后生成了子图vector(),我们需要对其进行遍历,筛选出和activation checkpointing相关的子图subgraghs,并做如下几件事: + +- 生成fake子图,并将其作为后向消费者的输入(而不是真实子图) +- 在fake子图中增加由end op连向所有源节点source nodes的控制边 +- 将fake子图添加至job builder(被其管理) + +对子图subgraghs的遍历主要在: + +``` +https://github.com/Oneflow-Inc/oneflow/blob/master/oneflow/core/job_rewriter/checkpointing_pass.cpp#L148-L290 +``` + +**过滤和activation checkpointing无关的subgragh子图** + +在`[for (auto& subgraph : checkpointing_subgraphs) {}]()`遍历循环的一开始,就会跳过不符合activation checkpointing条件的subgragh + +``` +for (auto& subgraph : checkpointing_subgraphs) { + // step 3.1 ignore this subgraph if there is no direct edge to backward pass op. + HashSet bw_consumers; + for (const OpNode* node : subgraph) { + node->ForEachNodeOnOutEdge([&](const OpNode* out_node) { + if (!IsForwardPassScope(Scope4OpNode(out_node))) { + bw_consumers.insert(out_node); + CHECK(subgraph.find(out_node) == subgraph.end()); + } + }); + } + if (bw_consumers.empty()) { continue; } +``` + +具体条件,即遍历subgraph子图中的所有node节点、判断node节点的所有出边out_edges是否有出边连到后向backward消费者op,如果subgraph中所有节点均没有连到后向backward消费者,则跳过该子图(表明该子图只有只和forward有关而和backward无关,即不是activation checkpointing优化的目标子图。 + +**生成fake子图,并将其作为后向消费者的输入(而不是真实子图)** + +过滤掉无效子图后,对于和activation checkpointing直接相关的目标子图,我们需要生成fake子图,其中的每个节点由fake op构成。 + +**fake子图,即重计算的最小单位,其作用即用于取代原有真实的子图、并在后面替换这些真实子图,用于被后向op nodes消费。通过将fake子图中fake op的scope属性从kForwardPass变为kBackwardPass,实现当计算进行到该fake op时,重新运行前向计算以产生backward所需的activation数据。** + +生成fake 子图的主要代码在:L168-L222 + +``` +https://github.com/Oneflow-Inc/oneflow/blob/master/oneflow/core/job_rewriter/checkpointing_pass.cpp#L168-L222 +``` + +生成fake子图后,将其设置为backward消费者的输入,主要代码如下: + +``` +const OpNode* first_bw_consumer = nullptr; + int32_t first_bw_order = std::numeric_limits::max(); + // 将backward消费者的input更改为fake子图op(而不是真实子图) + // step 3.3 change bw consumers input from subgraph to fake subgraph + for (const OpNode* node : bw_consumers) { + std::string bw_consumer_name = node->op().op_name(); + OperatorConf bw_consumer_op_conf; + // NOTE(chengcheng): + // reuse bw conumer op conf if it has been existed in map. + if (total_bw_consumers_op_name2conf.find(bw_consumer_name) + != total_bw_consumers_op_name2conf.end()) { + bw_consumer_op_conf = total_bw_consumers_op_name2conf.at(bw_consumer_name); + } else { + bw_consumer_op_conf = node->op().op_conf(); + } + CHECK_EQ(bw_consumer_name, bw_consumer_op_conf.name()); + + auto* user_conf = bw_consumer_op_conf.mutable_user_conf(); + // 修改和subgragh相关的backward op输入的blob name + // change input lbns if in subgraph + for (auto& pair : *(user_conf->mutable_input())) { + auto& list_s = pair.second; + for (int i = 0; i < list_s.s_size(); ++i) { + std::string old_lbn = list_s.s(i); + LogicalBlobId old_lbi = GenLogicalBlobId(old_lbn); + + std::string old_input_op_name = old_lbi.op_name(); + if (subgraph_op_name2op_node.find(old_input_op_name) != subgraph_op_name2op_node.end()) { + list_s.set_s(i, kCheckpointingFakeOpNamePrefix + old_lbn); + } + } + } + + // NOTE(chengcheng): + // emplace maybe repeated, so do not check the return value + total_bw_consumers_op_name2conf.emplace(bw_consumer_name, bw_consumer_op_conf); + + CHECK(op_node2order.find(node) != op_node2order.end()); + int32_t this_order = op_node2order.at(node); + if (this_order < first_bw_order) { + first_bw_consumer = node; + first_bw_order = this_order; + } + } +``` + +**在fake子图中为所有source node—end node添加控制边** + +这一步操作的目的主要将子图subgraph中所有和backward op相连的node(source node),添加一条控制边。控制边的添加是人为控制node间执行的时序,控制边保证了fake子图的计算尽可能晚的发生,这样才能缩短生命周期,保证内存复用的效率。 + +添加控制边相关的代码在L267-L284: + +``` +https://github.com/Oneflow-Inc/oneflow/blob/master/oneflow/core/job_rewriter/checkpointing_pass.cpp#L267-L284 +``` + +``` +// step 3.4 add control edge from End Op to all source node in fake subgraph + CHECK(first_bw_consumer != nullptr); + std::string end_op_name = kCheckpointingBadOpName; + int32_t end_order = -1; + first_bw_consumer->ForEachNodeOnInEdge([&](const OpNode* end_node) { + CHECK(op_node2order.find(end_node) != op_node2order.end()); + int32_t this_order = op_node2order.at(end_node); + if (this_order > end_order) { + end_order = this_order; + end_op_name = end_node->op().op_name(); + } + }); + CHECK_NE(end_order, -1); + CHECK_NE(end_op_name, kCheckpointingBadOpName); + CHECK_LT(end_order, first_bw_order); + for (const auto& source_op_name : source_node_in_fake_subgraph) { + fake_op_name2conf.at(source_op_name).add_ctrl_in_op_name(end_op_name); + } +``` + +**将fake子图添加至job build(被其管理)** + +fake subgraphs的生成、以及控制边的添加,实际上对原有job逻辑图产生了改动。改动后,需要将fake subgraphs中这些新生成的fake op nodes添加至job builder管理,正式完成了逻辑图的图改写。 + +主要代码如下: + +``` +// 将fake subgraph所包含的ops加入至job_builder管理(图改写) +// step 3.5 add fake subgraph ops to job builder +std::vector fake_op_confs; +for (auto& pair : fake_op_name2conf) { fake_op_confs.push_back(pair.second); } +job_builder->AddOps(parallel_conf, fake_op_confs); +``` + +### 更新所有后向消费者ops + +最后,由于fake op nodes更新了backward op nodes的输入输出等属性,需要将更新后的backward op nodes同步至job_builder管理: + +``` +// 在job builder中更新所有backward ops +// step 4. update bw consumers in job builder only once +std::vector total_bw_consumer_op_confs; +for (auto& pair : total_bw_consumers_op_name2conf) { + total_bw_consumer_op_confs.push_back(pair.second); +} +job_builder->MutOpsOnlyOnce(total_bw_consumer_op_confs); + +return Maybe::Ok(); +``` + +至此,通过这些fake subgraphs的插入、输入输出连边的变动等完成了整个job逻辑图的改写,改写后的逻辑图执行时即自动支持了activation checkpointing。 + +OneFlow最近复现了GPT-3相关的工作,其中就使用了activation checkpointing的技术,**代码在OneFlow-Benchmark已开源,欢迎在GitHub下载试用:** + +``` +https://github.com/Oneflow-Inc/OneFlow-Benchmark/tree/master/LanguageModeling/GPT +``` + +注:题图源自insspirito,pixabay + +参考文献 + +[1] Tianqi Chen, Bing Xu, Chiyuan Zhang, and Carlos Guestrin. Training Deep Nets with Sublinear Memory Cost. arXiv preprint arXiv:1604.06174, 2016. \ No newline at end of file -- GitLab