Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
magicwindyyd
mindspore
提交
58488c5d
M
mindspore
项目概览
magicwindyyd
/
mindspore
与 Fork 源项目一致
Fork自
MindSpore / mindspore
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
mindspore
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
58488c5d
编写于
4月 21, 2020
作者:
M
mindspore-ci-bot
提交者:
Gitee
4月 21, 2020
浏览文件
操作
浏览文件
下载
差异文件
!496 fix bug in cross entropy error
Merge pull request !496 from SanjayChan/cross_entropy
上级
d9e4dcc3
65a23763
变更
5
隐藏空白更改
内联
并排
Showing
5 changed files
with
17 additions
and
134 deletions
+17
-134
mindspore/ccsrc/kernel/gpu/cuda_impl/cross_entropy_cuda_impl.cu
...ore/ccsrc/kernel/gpu/cuda_impl/cross_entropy_cuda_impl.cu
+0
-47
mindspore/ccsrc/kernel/gpu/cuda_impl/cross_entropy_cuda_impl.cuh
...re/ccsrc/kernel/gpu/cuda_impl/cross_entropy_cuda_impl.cuh
+0
-26
mindspore/ccsrc/kernel/gpu/cuda_impl/cross_entropy_impl.cu
mindspore/ccsrc/kernel/gpu/cuda_impl/cross_entropy_impl.cu
+11
-48
mindspore/ccsrc/kernel/gpu/cuda_impl/cross_entropy_impl.cuh
mindspore/ccsrc/kernel/gpu/cuda_impl/cross_entropy_impl.cuh
+2
-7
mindspore/ccsrc/kernel/gpu/nn/softmax_cross_entropy_with_logits_gpu_kernel.h
...nel/gpu/nn/softmax_cross_entropy_with_logits_gpu_kernel.h
+4
-6
未找到文件。
mindspore/ccsrc/kernel/gpu/cuda_impl/cross_entropy_cuda_impl.cu
已删除
100644 → 0
浏览文件 @
d9e4dcc3
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <stdint.h>
#include "cross_entropy_cuda_impl.cuh"
#include "include/cuda_runtime.h"
// Computes the mean softmax cross-entropy loss and its gradient for one batch.
// Launch contract: a single block with blockDim.x == batch_size and
// batch_size * sizeof(float) bytes of dynamic shared memory; each thread
// handles the num_classes-wide row of exactly one sample.
//   *loss  <- (1/batch_size) * sum_i( -sum_c labels[i][c] * log_softmax[i][c] )
//   dx     <- (softmax_logits - labels) / batch_size
__global__ void CalCrossEntropyWithGradKernel(const float *softmax_logits, const float *log_softmax_logits,
                                              const float *labels, const int batch_size, const int num_classes,
                                              float *loss, float *dx) {
  // One per-sample partial-loss slot per thread in dynamic shared memory.
  extern __shared__ float sample_loss[];
  const int sample = threadIdx.x;
  const float mean_scale = 1.0f / static_cast<float>(batch_size);

  // Accumulate this sample's negative log-likelihood and emit its gradient row.
  sample_loss[sample] = 0;
  const int row_begin = sample * num_classes;
  const int row_end = (sample + 1) * num_classes;
  for (int idx = row_begin; idx < row_end; ++idx) {
    sample_loss[sample] -= log_softmax_logits[idx] * labels[idx];
    dx[idx] = (softmax_logits[idx] - labels[idx]) * mean_scale;
  }

  // Every partial loss must be written before thread 0 reduces them.
  __syncthreads();
  if (sample == 0) {
    // Serial reduction in thread 0, summing in sample order (keeps the
    // floating-point accumulation order deterministic).
    *loss = 0;
    for (int i = 0; i < batch_size; i++) {
      *loss += sample_loss[i];
    }
    *loss *= mean_scale;
  }
}
// Host-side launcher for CalCrossEntropyWithGradKernel.
// Enqueues the fused loss+gradient kernel on cuda_stream: one block of
// batch_size threads with batch_size * sizeof(float) bytes of dynamic shared
// memory (one partial-loss slot per sample). All pointers are device pointers.
// NOTE(review): no guard that batch_size <= the device's max threads per block
// and no cudaGetLastError() after the launch — an oversized batch would fail
// silently. Confirm callers bound batch_size appropriately.
void CalCrossEntropyWithGrad(const float *softmax_logits, const float *log_softmax_logits, const float *labels,
                             const int batch_size, const int num_classes, float *loss, float *dx,
                             cudaStream_t cuda_stream) {
  CalCrossEntropyWithGradKernel<<<1, batch_size, batch_size * sizeof(float), cuda_stream>>>(
    softmax_logits, log_softmax_logits, labels, batch_size, num_classes, loss, dx);
}
mindspore/ccsrc/kernel/gpu/cuda_impl/cross_entropy_cuda_impl.cuh
已删除
100644 → 0
浏览文件 @
d9e4dcc3
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_CROSSENTROPYCUDAIMPL_H_
#define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_CROSSENTROPYCUDAIMPL_H_

#include "device/gpu/cuda_common.h"

// Asynchronously computes, on cuda_stream, the mean softmax cross-entropy
// loss (*loss, a single device float) and its gradient (dx, shape
// batch_size x num_classes) from precomputed softmax and log-softmax
// activations plus dense labels. All pointers are device pointers.
void CalCrossEntropyWithGrad(const float *softmax_logits, const float *log_softmax_logits, const float *labels,
                             const int batch_size, const int num_classes, float *loss, float *dx,
                             cudaStream_t cuda_stream);
#endif  // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_CROSSENTROPYCUDAIMPL_H_
mindspore/ccsrc/kernel/gpu/cuda_impl/cross_entropy_impl.cu
浏览文件 @
58488c5d
...
...
@@ -52,38 +52,12 @@ __global__ void CrossEntropyGradWithSparseKernel(const T *logits, const S *label
}
template
<
typename
T
,
typename
S
>
__global__
void
CrossEntropyWithoutSparseKernel
(
const
T
*
logits
,
const
S
*
labels
,
const
size_t
batch_size
,
const
size_t
class_num
,
T
*
losses
)
{
T
epsilon
=
1e-6
;
for
(
size_t
i
=
0
;
i
<
batch_size
;
++
i
)
{
T
logit
=
0.0
;
for
(
size_t
j
=
0
;
j
<
class_num
;
j
++
)
{
if
(
fabs
(
labels
[
i
*
class_num
+
j
]
-
1.0
)
<=
1e-8
)
{
logit
=
logits
[
i
*
class_num
+
j
];
break
;
}
}
if
(
logit
<=
0
)
{
logit
+=
epsilon
;
}
losses
[
i
]
=
-
logf
(
logit
);
// Fused cross-entropy loss + gradient for dense (non-sparse) labels.
// Each thread owns one sample: thread t reads row t (class_num entries) of
// logits/labels, writes losses[t] = -sum_c labels[c] * log(logits[c]) and
// dlogits = logits - labels for that row. Launch must use a single block
// whose blockDim.x equals the batch size (see the CrossEntropy launcher).
// NOTE(review): the `template <typename T, typename S>` header is not visible
// in this diff view (the explicit instantiation below shows it is a
// template) — confirm against the full file.
// NOTE(review): `i` is int while class_num is size_t — signed/unsigned mixing
// and potential overflow for very large class counts; TODO confirm ranges.
__global__ void CrossEntropyKernel(const T *logits, const S *labels, const size_t class_num, T *losses, T *dlogits) {
  losses[threadIdx.x] = 0;
  for (int i = threadIdx.x * class_num; i < (threadIdx.x + 1) * class_num; ++i) {
    // logf is the float overload — appropriate for T = float instantiations.
    losses[threadIdx.x] -= logf(logits[i]) * labels[i];
    dlogits[i] = logits[i] - labels[i];
  }
  return;
}
// Gradient of dense-label cross entropy, mean-reduced over the batch:
// for the class whose label is (within 1e-8 of) 1, grad = (logit - 1) / batch_size;
// for every other class, grad = logit / batch_size.
// Threads grid-stride over the class dimension; every thread walks the batch
// dimension serially.
// NOTE(review): callers appear to pass softmax probabilities as `logits`
// (see the softmax GPU kernel) — confirm.
// NOTE(review): `labels[...] - 1.0` promotes to double and calls double
// fabs(); fabsf/1.0f would avoid the promotion for float instantiations —
// confirm the promotion is intended.
template <typename T, typename S>
__global__ void CrossEntropyGradWithoutSparseKernel(const T *logits, const S *labels, const size_t batch_size,
                                                    const size_t class_num, T *grad) {
  for (size_t i = 0; i < batch_size; i++) {
    for (size_t j = blockIdx.x * blockDim.x + threadIdx.x; j < class_num; j += blockDim.x * gridDim.x) {
      // A label within 1e-8 of 1.0 marks the true class for sample i.
      if (fabs(labels[i * class_num + j] - 1.0) <= 1e-8) {
        grad[i * class_num + j] = (logits[i * class_num + j] - 1) / batch_size;
      } else {
        grad[i * class_num + j] = logits[i * class_num + j] / batch_size;
      }
    }
  }
  return;
}
template
<
typename
T
,
typename
S
>
...
...
@@ -102,18 +76,9 @@ void CrossEntropyGradWithSparse(const T *logits, const S *labels, const size_t b
}
// Host launcher for the dense-label loss kernel.
// NOTE(review): launches <<<1, 1>>> — the entire batch is computed by a single
// GPU thread (the kernel loops serially over batch and classes). Correct but
// serial; the diff summary indicates this path is replaced by CrossEntropy()
// in this commit.
template <typename T, typename S>
void CrossEntropyWithoutSparse(const T *logits, const S *labels, const size_t batch_size, const size_t class_num,
                               T *losses, cudaStream_t cuda_stream) {
  CrossEntropyWithoutSparseKernel<<<1, 1, 0, cuda_stream>>>(logits, labels, batch_size, class_num, losses);
  return;
}
template
<
typename
T
,
typename
S
>
void
CrossEntropyGradWithoutSparse
(
const
T
*
logits
,
const
S
*
labels
,
const
size_t
batch_size
,
const
size_t
class_num
,
T
*
grad
,
cudaStream_t
cuda_stream
)
{
CrossEntropyGradWithoutSparseKernel
<<<
GET_BLOCKS
(
class_num
),
GET_THREADS
,
0
,
cuda_stream
>>>
(
logits
,
labels
,
batch_size
,
class_num
,
grad
);
return
;
// Fused host launcher: computes per-sample losses and dlogits in one kernel
// launch on cuda_stream — one block, one thread per sample (batch_size
// threads); batch_size is used only as the launch thread count.
// NOTE(review): the `template <typename T, typename S>` header is not visible
// in this diff view (the explicit instantiation below confirms it is a
// template) — confirm against the full file.
// NOTE(review): no check that batch_size <= max threads per block, and no
// cudaGetLastError() after the launch.
void CrossEntropy(const T *logits, const S *labels, const size_t batch_size, const size_t class_num, T *losses,
                  T *dlogits, cudaStream_t cuda_stream) {
  CrossEntropyKernel<<<1, batch_size, 0, cuda_stream>>>(logits, labels, class_num, losses, dlogits);
}
template
void
CrossEntropyWithSparse
<
float
,
int
>(
const
float
*
logits
,
const
int
*
labels
,
const
size_t
batch_size
,
...
...
@@ -126,8 +91,6 @@ template void CrossEntropyGradWithSparse<float, int>(const float *logits, const
// Explicit template instantiations (float data; int64_t or float labels)
// exported to the rest of the GPU backend.
template void CrossEntropyGradWithSparse<float, int64_t>(const float *logits, const int64_t *labels,
                                                         const size_t batch_size, const size_t class_num, float *grad,
                                                         cudaStream_t cuda_stream);
template void CrossEntropyWithoutSparse<float, float>(const float *logits, const float *labels,
                                                      const size_t batch_size, const size_t class_num, float *losses,
                                                      cudaStream_t cuda_stream);
template void CrossEntropyGradWithoutSparse<float, float>(const float *logits, const float *labels,
                                                          const size_t batch_size, const size_t class_num, float *grad,
                                                          cudaStream_t cuda_stream);
template void CrossEntropy<float, float>(const float *logits, const float *labels, const size_t batch_size,
                                         const size_t class_num, float *losses, float *dlogits,
                                         cudaStream_t cuda_stream);
mindspore/ccsrc/kernel/gpu/cuda_impl/cross_entropy_impl.cuh
浏览文件 @
58488c5d
...
...
@@ -28,11 +28,6 @@ void CrossEntropyGradWithSparse(const T *logits, const S *labels, const size_t b
T
*
grad
,
cudaStream_t
cuda_stream
);
// Loss for dense (non-sparse) labels; implemented in cross_entropy_impl.cu.
template <typename T, typename S>
void CrossEntropyWithoutSparse(const T *logits, const S *labels, const size_t batch_size, const size_t class_num,
                               T *losses, cudaStream_t cuda_stream);

// Gradient for dense (non-sparse) labels.
template <typename T, typename S>
void CrossEntropyGradWithoutSparse(const T *logits, const S *labels, const size_t batch_size, const size_t class_num,
                                   T *grad, cudaStream_t cuda_stream);

// Fused loss + gradient for dense labels.
// NOTE(review): the `template <typename T, typename S>` header is not visible
// in this diff view — presumably elided by the hunk; confirm against the
// full file.
void CrossEntropy(const T *logits, const S *labels, const size_t batch_size, const size_t class_num, T *losses,
                  T *dlogits, cudaStream_t cuda_stream);
#endif  // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_CROSSENTROPY_H_
mindspore/ccsrc/kernel/gpu/nn/softmax_cross_entropy_with_logits_gpu_kernel.h
浏览文件 @
58488c5d
...
...
@@ -58,8 +58,8 @@ class SoftmaxCrossEntropyWithLogitsGpuKernel : public GpuKernel {
}
T
*
logits_addr
=
GetDeviceAddress
<
T
>
(
inputs
,
0
);
S
*
labels_addr
=
GetDeviceAddress
<
S
>
(
inputs
,
1
);
T
*
output1
_addr
=
GetDeviceAddress
<
T
>
(
outputs
,
0
);
T
*
output2
_addr
=
GetDeviceAddress
<
T
>
(
outputs
,
1
);
T
*
loss
_addr
=
GetDeviceAddress
<
T
>
(
outputs
,
0
);
T
*
dlogits
_addr
=
GetDeviceAddress
<
T
>
(
outputs
,
1
);
T
*
softmax_output_logits
=
GetDeviceAddress
<
T
>
(
workspace
,
0
);
const
float
alpha
=
1
;
...
...
@@ -69,10 +69,8 @@ class SoftmaxCrossEntropyWithLogitsGpuKernel : public GpuKernel {
softmax_output_descriptor_
,
softmax_output_logits
),
"cudnnSoftmaxForward failed."
);
CrossEntropyWithoutSparse
(
softmax_output_logits
,
labels_addr
,
batch_size_
,
channel_size_
,
output1_addr
,
reinterpret_cast
<
cudaStream_t
>
(
stream_ptr
));
CrossEntropyGradWithoutSparse
(
softmax_output_logits
,
labels_addr
,
batch_size_
,
channel_size_
,
output2_addr
,
reinterpret_cast
<
cudaStream_t
>
(
stream_ptr
));
CrossEntropy
(
softmax_output_logits
,
labels_addr
,
batch_size_
,
channel_size_
,
loss_addr
,
dlogits_addr
,
reinterpret_cast
<
cudaStream_t
>
(
stream_ptr
));
return
true
;
}
bool
Init
(
const
CNodePtr
&
kernel_node
)
override
{
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录