Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
magicwindyyd
mindspore
提交
75af5464
M
mindspore
项目概览
magicwindyyd
/
mindspore
与 Fork 源项目一致
Fork自
MindSpore / mindspore
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
mindspore
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
75af5464
编写于
8月 24, 2020
作者:
M
mindspore-ci-bot
提交者:
Gitee
8月 24, 2020
浏览文件
操作
浏览文件
下载
差异文件
!4954 Fix GPU non-sparse cross-entropy op returning all zeros
Merge pull request !4954 from tom_chen/cross_entropy
上级
af45133a
8fa4422d
变更
1
隐藏空白更改
内联
并排
Showing
1 changed file
with
15 addition
and
7 deletion
+15
-7
mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/cross_entropy_impl.cu
...ckend/kernel_compiler/gpu/cuda_impl/cross_entropy_impl.cu
+15
-7
未找到文件。
mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/cross_entropy_impl.cu
浏览文件 @
75af5464
...
...
@@ -52,12 +52,18 @@ __global__ void CrossEntropyGradWithSparseKernel(const T *logits, const S *label
}
template
<
typename
T
,
typename
S
>
__global__
void
CrossEntropyKernel
(
const
T
*
logits
,
const
S
*
labels
,
const
size_t
class_num
,
T
*
losses
,
T
*
dlogits
)
{
losses
[
threadIdx
.
x
]
=
0
;
T
epsilon
=
1e-6
;
for
(
int
i
=
threadIdx
.
x
*
class_num
;
i
<
(
threadIdx
.
x
+
1
)
*
class_num
;
++
i
)
{
losses
[
threadIdx
.
x
]
-=
logf
((
logits
[
i
]
<=
0
?
epsilon
:
logits
[
i
]))
*
labels
[
i
];
dlogits
[
i
]
=
logits
[
i
]
-
labels
[
i
];
__global__
void
CrossEntropyKernel
(
const
T
*
logits
,
const
S
*
labels
,
const
size_t
batch_size
,
const
size_t
class_num
,
T
epsilon
,
T
*
losses
,
T
*
dlogits
)
{
for
(
size_t
index
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
index
<
batch_size
;
index
+=
blockDim
.
x
*
gridDim
.
x
)
{
losses
[
index
]
=
0
;
const
int
start
=
index
*
class_num
;
const
int
end
=
(
index
+
1
)
*
class_num
;
for
(
int
i
=
start
;
i
<
end
;
++
i
)
{
losses
[
index
]
-=
logf
((
logits
[
i
]
<=
0
?
epsilon
:
logits
[
i
]))
*
labels
[
i
];
dlogits
[
i
]
=
logits
[
i
]
-
labels
[
i
];
}
}
}
...
...
@@ -79,7 +85,9 @@ void CrossEntropyGradWithSparse(const T *logits, const S *labels, const size_t b
template
<
typename
T
,
typename
S
>
void
CrossEntropy
(
const
T
*
logits
,
const
S
*
labels
,
const
size_t
batch_size
,
const
size_t
class_num
,
T
*
losses
,
T
*
dlogits
,
cudaStream_t
cuda_stream
)
{
CrossEntropyKernel
<<<
1
,
batch_size
,
0
,
cuda_stream
>>>
(
logits
,
labels
,
class_num
,
losses
,
dlogits
);
T
epsilon
=
1e-6
;
CrossEntropyKernel
<<<
GET_BLOCKS
(
batch_size
),
GET_THREADS
,
0
,
cuda_stream
>>>
(
logits
,
labels
,
batch_size
,
class_num
,
epsilon
,
losses
,
dlogits
);
}
template
void
CrossEntropyWithSparse
<
float
,
int
>(
const
float
*
logits
,
const
int
*
labels
,
const
size_t
batch_size
,
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录