Skip to content

  • 体验新版
    • 正在加载...
  • 登录
  • PaddlePaddle
  • Paddle
  • Issue
  • #24731

P
Paddle
  • 项目概览

PaddlePaddle / Paddle
接近 2 年 前同步成功

通知 2320
Star 20933
Fork 5424
  • 代码
    • 文件
    • 提交
    • 分支
    • Tags
    • 贡献者
    • 分支图
    • Diff
  • Issue 1423
    • 列表
    • 看板
    • 标记
    • 里程碑
  • 合并请求 543
  • Wiki 0
    • Wiki
  • 分析
    • 仓库
    • DevOps
  • 项目成员
  • Pages
P
Paddle
  • 项目概览
    • 项目概览
    • 详情
    • 发布
  • 仓库
    • 仓库
    • 文件
    • 提交
    • 分支
    • 标签
    • 贡献者
    • 分支图
    • 比较
  • Issue 1,423
    • Issue 1,423
    • 列表
    • 看板
    • 标记
    • 里程碑
  • 合并请求 543
    • 合并请求 543
  • Pages
  • 分析
    • 分析
    • 仓库分析
    • DevOps
  • Wiki 0
    • Wiki
  • 成员
    • 成员
  • 收起侧边栏
  • 动态
  • 分支图
  • 创建新Issue
  • 提交
  • Issue看板
已关闭
开放中
Opened 5月 25, 2020 by saxon_zh@saxon_zhGuest

[Paddle-TRT] Different behaviors between blockReduce and warpReduce

Created by: zlsh80826

  1. Paddle Version: develop
  2. GPU/cuda-10.2/cudnn7.6.5
  3. Ubuntu16.04

Hello! There are different behaviors between blockReduceSum and warpReduceSum. warpReduceSum makes each thread in the warp has a copy of summation. However, blockReduceSum only ensures the first warp has the summation, which results in one more block synchronization like here.

I suggest two solutions,

  1. Change the behavior of blockReduceSum to AllReduce, which computes the summation and ensures each thread in the block have the summation copy. We only need to change one line to do this. Change the code val = (threadIdx.x < block_span) ? shared[lane] : static_cast<T>(0.0f); to val = (lane < block_span) ? shared[lane] : static_cast<T>(0.0f);. By doing this, each warp has a copy of first round result. Thus, all threads in the block have the reduced values after the second warpReduce.
  2. In blockReduceSum, add a condition which lets only the first warp in the block does the second reduction. Because we only need to ensure threadIdx.x has the correct reduced value (SoftmaxKernelWithEltadd is the only caller), so only the first warp needs to do the second reduction. i.e.
if (threadIdx.x < warpSize) {  
    val = (threadIdx.x < block_span) ? shared[lane] : static_cast<T>(0.0f);
    val = warpReduceSum<T>(val, mask);
}

In my experiments, the first one has higher performance, because we can delete the unnecessary shared memory copy and __syncthreads after calling the blockReduceXXX function.

BTW, the same issue exists in the blockReduceMax too.

指派人
分配到
无
里程碑
无
分配里程碑
工时统计
无
截止日期
无
标识: paddlepaddle/Paddle#24731
渝ICP备2023009037号

京公网安备11010502055752号

网络110报警服务 Powered by GitLab CE v13.7
开源知识
Git 入门 Pro Git 电子书 在线学 Git
Markdown 基础入门 IT 技术知识开源图谱
帮助
使用手册 反馈建议 博客
《GitCode 隐私声明》 《GitCode 服务条款》 关于GitCode
Powered by GitLab CE v13.7