Skip to content

  • 体验新版
    • 正在加载...
  • 登录
  • PaddlePaddle
  • Paddle
  • 合并请求
  • !27112

P
Paddle
  • 项目概览

PaddlePaddle / Paddle
大约 2 年 前同步成功

通知 2325
Star 20933
Fork 5424
  • 代码
    • 文件
    • 提交
    • 分支
    • Tags
    • 贡献者
    • 分支图
    • Diff
  • Issue 1423
    • 列表
    • 看板
    • 标记
    • 里程碑
  • 合并请求 543
  • Wiki 0
    • Wiki
  • 分析
    • 仓库
    • DevOps
  • 项目成员
  • Pages
P
Paddle
  • 项目概览
    • 项目概览
    • 详情
    • 发布
  • 仓库
    • 仓库
    • 文件
    • 提交
    • 分支
    • 标签
    • 贡献者
    • 分支图
    • 比较
  • Issue 1,423
    • Issue 1,423
    • 列表
    • 看板
    • 标记
    • 里程碑
  • 合并请求 543
    • 合并请求 543
  • Pages
  • 分析
    • 分析
    • 仓库分析
    • DevOps
  • Wiki 0
    • Wiki
  • 成员
    • 成员
  • 收起侧边栏
  • 动态
  • 分支图
  • 创建新Issue
  • 提交
  • Issue看板

[Feature] Enhance inplace addto strategy for gradient accumulation in static graph !27112

  • Report abuse
!27112 已合并 9月 07, 2020 由 saxon_zh@saxon_zh 创建
#<User:0x00007f7e10532ee8>
  • 概览 0
  • 提交 13
  • 变更 25

Created by: zhiqiu

PR types

New features

PR changes

Others

Describe

New feature, support inplace addto strategy for gradient accumulation, which can improve the performance of gradient accumulation.

Background

In back-propagation,if the gradients of a tensor are generated by more than one operation, these gradients should be accumulated before back-propagating to the next layer. For example,

If the forward network contains the following part,

y = conv2d(x, w)
z = add(x, y)

Then in backward,

x_grad_0 = add_grad(z_grad)
x_grad_1 = conv2d_grad(y_grad, ...)
x_grad = sum(x_grad_0, x_grad_1)

Traditionally, if the gradients of a tensor are generated by n operation, then after these gradients generated, a sum operation is used to sum up these gradients. x_grad = sum(x_grad_0, x_grag_1, ..., x_grad_n)

However, we can improve the performance by adds up these gradients one by one in their backward operations.

For example, the backward API of conv in cudnn cudnnConvolutionBackwardData contains a beta arguments, image

which means, dstValue = alpha[0]*result + beta[0]*priorDstValue

So, in the above case, if the gradient x_grad_1 is set to x_grad_0 before conv2d_grad's execution, then we can set beta = 1 and the result x_grad_1 will be exactly x_grad_0 + x_grad_1, which is the gradient accumulation we want to do.

Implementation

  • In the backward stage, when appending grad operators automatically, use several grad_add ops instead of sum op.
    g = sum(g_0, g_1,..., g_n)
 ==>
    g_sum_0 = g_0
    g_sum_1 = grad_add(g_sum_0, g_1)
    g_sum_2 = grad_add(g_sum_1, g_2)
     ...
    g_sum_n = grad_add(g_sum_n-1, g_n)
  • For each grad_add op with the form out = grad_add(left, right)
    • Make the two input tensor and the output tensor share the same memory address.
    • Make the right input of grad_add do addto calculation.
    • Make the grad_add skip execution (since its work is done by addto)

Usage

To train with inplace addto strategy, there are two steps.

  • Set FLAGS_max_inplace_grad_add to a positive number, for example, 8. It means if the number gradients that need to sum up is less than 8, the grad_add will be used.
  • Set build_strategy.enable_addto=True, it enables the inplace addto strategy.

Performance

The performance of ResNet50 with amp enabled and batch_size=128 on V100 single card.

  • before: 1014 img/s
  • after: 1078 img/s, 6.3% imporved.
指派人
分配到
审核者
Request review from
无
里程碑
无
分配里程碑
工时统计
标识: paddlepaddle/Paddle!27112
Source branch: github/fork/zhiqiu/feat/imporve_gradient_accumulation
渝ICP备2023009037号

京公网安备11010502055752号

网络110报警服务 Powered by GitLab CE v13.7
开源知识
Git 入门 Pro Git 电子书 在线学 Git
Markdown 基础入门 IT 技术知识开源图谱
帮助
使用手册 反馈建议 博客
《GitCode 隐私声明》 《GitCode 服务条款》 关于GitCode
Powered by GitLab CE v13.7