PaddlePaddle / Paddle
Opened May 02, 2018 by saxon_zh (@saxon_zh)

Potential memcpy and malloc errors on CUDA

Created by: reyoung

According to the CUDA manual, cudaMallocHost is the recommended API for allocating CPU memory used for data exchange between GPU and CPU:

Allocates size bytes of host memory that is page-locked and accessible to the device. The driver tracks the virtual memory ranges allocated with this function and automatically accelerates calls to functions such as cudaMemcpy*(). Since the memory can be accessed directly by the device, it can be read or written with much higher bandwidth than pageable memory obtained with functions such as malloc(). Allocating excessive amounts of memory with cudaMallocHost() may degrade system performance, since it reduces the amount of memory available to the system for paging. As a result, this function is best used sparingly to allocate staging areas for data exchange between host and device.

We should use cudaMallocHost for CPUAllocator, because:

  1. In the memory module, we use BuddyAllocator to wrap each device's allocator. Even though cudaMallocHost is much slower than malloc, we invoke cudaMallocHost only once to allocate a large memory buffer and let the buddy allocator manage all subsequent allocations.
  2. All CPU memory for tensors is already locked into physical memory with mlock. This is our current memory strategy, and introducing cudaMallocHost will not change it. See https://github.com/PaddlePaddle/Paddle/blob/9a8be9dacac4e1286482e3e0289985a19382beb9/paddle/fluid/memory/detail/system_allocator.cc#L59-L60
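The allocation strategy in point 1 (one expensive allocation up front, then cheap sub-allocations) can be sketched as a minimal buddy allocator. This is a hypothetical illustration, not Paddle's BuddyAllocator: the class name `BuddyArena` and the parameters `min_block` and `max_order` are made up for this sketch, and `std::malloc` stands in for the single `cudaMallocHost` call so the example runs without a GPU.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdlib>
#include <set>
#include <unordered_map>
#include <vector>

// Hypothetical sketch of a buddy allocator over one up-front buffer.
// Block sizes are min_block << order, for order in [0, max_order].
class BuddyArena {
 public:
  BuddyArena(std::size_t min_block, int max_order)
      : min_block_(min_block), max_order_(max_order), free_(max_order + 1) {
    // The only "system" allocation. Under the proposal in this issue this
    // would be one cudaMallocHost call; std::malloc is a stand-in here.
    base_ = static_cast<char*>(std::malloc(min_block_ << max_order_));
    if (base_ != nullptr) free_[max_order_].insert(0);
  }
  ~BuddyArena() { std::free(base_); }
  BuddyArena(const BuddyArena&) = delete;
  BuddyArena& operator=(const BuddyArena&) = delete;

  // Returns nullptr if the request is too large or the arena is exhausted.
  void* Alloc(std::size_t size) {
    int order = 0;
    while ((min_block_ << order) < size) {
      if (++order > max_order_) return nullptr;
    }
    // Find the smallest free block that is large enough.
    int k = order;
    while (k <= max_order_ && free_[k].empty()) ++k;
    if (k > max_order_) return nullptr;
    std::size_t off = *free_[k].begin();
    free_[k].erase(free_[k].begin());
    // Split down to the requested order; each upper half becomes free.
    while (k > order) {
      --k;
      free_[k].insert(off + (min_block_ << k));
    }
    used_[off] = order;
    return base_ + off;
  }

  void Free(void* p) {
    std::size_t off = static_cast<std::size_t>(static_cast<char*>(p) - base_);
    auto it = used_.find(off);
    assert(it != used_.end() && "pointer was not allocated by this arena");
    if (it == used_.end()) return;
    int k = it->second;
    used_.erase(it);
    // Merge with the buddy block for as long as the buddy is also free.
    while (k < max_order_) {
      std::size_t buddy = off ^ (min_block_ << k);
      auto b = free_[k].find(buddy);
      if (b == free_[k].end()) break;
      free_[k].erase(b);
      if (buddy < off) off = buddy;
      ++k;
    }
    free_[k].insert(off);
  }

  // True when no block is in use and the arena is one free block again.
  bool FullyFree() const {
    return used_.empty() && free_[max_order_].size() == 1;
  }

 private:
  char* base_ = nullptr;
  std::size_t min_block_;
  int max_order_;
  std::vector<std::set<std::size_t>> free_;    // free block offsets per order
  std::unordered_map<std::size_t, int> used_;  // live offset -> order
};
```

With this shape, the cost of the slow pinned allocation would be paid once per arena rather than once per tensor, and every later Alloc/Free only touches the in-process free lists.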

However, when we changed malloc to cudaMallocHost, we found several bugs in Fluid (see PR #10136):

  1. There are many places in our code that invoke cudaMemcpyAsync but never actually wait or synchronize afterwards.

    • When cudaMemcpyAsync is used with malloc'ed (pageable) host memory, the CUDA driver first copies the data into an internal staging area and then performs the asynchronous copy from that staging area. This step is synchronous on the CPU side and slow, which can mask missing waits; once the host memory is pinned with cudaMallocHost, the copies become truly asynchronous and those missing waits surface as bugs. If we want asynchronous memcpy, we should not use malloc'ed memory for data exchange.
  2. In our experiments, when we called cudaMemcpyAsync on a tiny data buffer and then immediately called cudaStreamSynchronize(), the memcpy sometimes did not appear to be synchronized in a multi-thread, multi-stream configuration. This is strange, because cudaStreamSynchronize() should be able to synchronize it. However, since using malloc'ed memory for data exchange is not best practice, and the results are correct when using cudaMallocHost, we will use cudaMallocHost for CPUAllocator to avoid this bug.

Assignee: None
Milestone: None
Time tracking: None
Due date: None
Reference: paddlepaddle/Paddle#10328