Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
b3090ad4
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
b3090ad4
编写于
12月 03, 2019
作者:
L
Leo Chen
提交者:
GitHub
12月 03, 2019
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
fix synchronization problem in softmax_with_cross_entropy_op, test=develop (#21480)
上级
01fa4ead
变更
1
隐藏空白更改
内联
并排
Showing
1 changed file
with
4 addition
and
0 deletion
+4
-0
paddle/fluid/operators/softmax_with_cross_entropy_op.cu
paddle/fluid/operators/softmax_with_cross_entropy_op.cu
+4
-0
未找到文件。
paddle/fluid/operators/softmax_with_cross_entropy_op.cu
浏览文件 @
b3090ad4
...
@@ -200,6 +200,10 @@ static __global__ void RowReductionForDiffMaxSum(const T* logits_data,
...
@@ -200,6 +200,10 @@ static __global__ void RowReductionForDiffMaxSum(const T* logits_data,
softmax
[
beg_idx
]
-=
diff_max_sum
;
softmax
[
beg_idx
]
-=
diff_max_sum
;
beg_idx
+=
step
;
beg_idx
+=
step
;
}
}
// Note(zhiqiu): since different threads may use max_data[blockIdx.x] to
// calculate diff_max_sum, __syncthreads() is needed here.
__syncthreads
();
if
(
threadIdx
.
x
==
0
)
max_data
[
blockIdx
.
x
]
=
0
;
if
(
threadIdx
.
x
==
0
)
max_data
[
blockIdx
.
x
]
=
0
;
}
}
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录