Skip to content

  • 体验新版
    • 正在加载...
  • 登录
  • PaddlePaddle
  • Paddle
  • Issue
  • #23518

P
Paddle
  • 项目概览

PaddlePaddle / Paddle
大约 2 年 前同步成功

通知 2325
Star 20933
Fork 5424
  • 代码
    • 文件
    • 提交
    • 分支
    • Tags
    • 贡献者
    • 分支图
    • Diff
  • Issue 1423
    • 列表
    • 看板
    • 标记
    • 里程碑
  • 合并请求 543
  • Wiki 0
    • Wiki
  • 分析
    • 仓库
    • DevOps
  • 项目成员
  • Pages
P
Paddle
  • 项目概览
    • 项目概览
    • 详情
    • 发布
  • 仓库
    • 仓库
    • 文件
    • 提交
    • 分支
    • 标签
    • 贡献者
    • 分支图
    • 比较
  • Issue 1,423
    • Issue 1,423
    • 列表
    • 看板
    • 标记
    • 里程碑
  • 合并请求 543
    • 合并请求 543
  • Pages
  • 分析
    • 分析
    • 仓库分析
    • DevOps
  • Wiki 0
    • Wiki
  • 成员
    • 成员
  • 收起侧边栏
  • 动态
  • 分支图
  • 创建新Issue
  • 提交
  • Issue看板
已关闭
开放中
Opened 4月 07, 2020 by saxon_zh@saxon_zhGuest

bert_encoder_functor在V100上编译错误

Created by: luotao1

develop分支,V100机器,开发镜像docker.io/paddlepaddle/paddle_manylinux_devel:cuda9.0_cudnn7,出现编译错误

/Paddle/paddle/fluid/operators/math/bert_encoder_functor.cu(258): error: calling a __device__ function("__half") from a __host__ function("operator()") is not allowed
          detected during instantiation of "void paddle::operators::math::MultiHeadGPUComputeFunctor<T>::operator()(const paddle::platform::CUDADeviceContext &, int, int, int, int, T *, const T *, T *, T, T) [with T=half]"
(264): here

/Paddle/paddle/fluid/operators/math/bert_encoder_functor.cu(317): error: calling a __device__ function("operator float") from a __host__ function("operator()") is not allowed
          detected during instantiation of "void paddle::operators::math::SkipLayerNormFunctor<T>::operator()(int, int, const T *, const T *, const float *, const float *, T *, T, cudaStream_t) [with T=half]"
(336): here

/Paddle/paddle/fluid/operators/math/bert_encoder_functor.cu(321): error: calling a __device__ function("operator float") from a __host__ function("operator()") is not allowed
          detected during instantiation of "void paddle::operators::math::SkipLayerNormFunctor<T>::operator()(int, int, const T *, const T *, const float *, const float *, T *, T, cudaStream_t) [with T=half]"
(336): here

/Paddle/paddle/fluid/operators/math/bert_encoder_functor.cu(325): error: calling a __device__ function("operator float") from a __host__ function("operator()") is not allowed
          detected during instantiation of "void paddle::operators::math::SkipLayerNormFunctor<T>::operator()(int, int, const T *, const T *, const float *, const float *, T *, T, cudaStream_t) [with T=half]"
(336): here

/Paddle/paddle/fluid/operators/math/bert_encoder_functor.cu(329): error: calling a __device__ function("operator float") from a __host__ function("operator()") is not allowed
          detected during instantiation of "void paddle::operators::math::SkipLayerNormFunctor<T>::operator()(int, int, const T *, const T *, const float *, const float *, T *, T, cudaStream_t) [with T=half]"
(336): here

5 errors detected in the compilation of "/tmp/tmpxft_0000be09_00000000-4_bert_encoder_functor.cpp4.ii".
CMake Error at bert_encoder_functor_generated_bert_encoder_functor.cu.o.cmake:262 (message):
  Error generating file
  /Paddle/build/paddle/fluid/operators/math/CMakeFiles/bert_encoder_functor.dir//./bert_encoder_functor_generated_bert_encoder_functor.cu.o


make[2]: *** [paddle/fluid/operators/math/CMakeFiles/bert_encoder_functor.dir/bert_encoder_functor_generated_bert_encoder_functor.cu.o] Error 1
make[1]: *** [paddle/fluid/operators/math/CMakeFiles/bert_encoder_functor.dir/all] Error 2
指派人
分配到
无
里程碑
无
分配里程碑
工时统计
无
截止日期
无
标识: paddlepaddle/Paddle#23518
渝ICP备2023009037号

京公网安备11010502055752号

网络110报警服务 Powered by GitLab CE v13.7
开源知识
Git 入门 Pro Git 电子书 在线学 Git
Markdown 基础入门 IT 技术知识开源图谱
帮助
使用手册 反馈建议 博客
《GitCode 隐私声明》 《GitCode 服务条款》 关于GitCode
Powered by GitLab CE v13.7