Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
fff270ea
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
fff270ea
编写于
5月 09, 2019
作者:
Z
Zeng Jinle
提交者:
GitHub
5月 09, 2019
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
follow comments,test=develop (#17273)
上级
7a3bb061
变更
1
显示空白变更内容
内联
并排
Showing
1 changed file
with
8 addition
and
20 deletion
+8
-20
paddle/fluid/operators/softmax_with_cross_entropy_op.cu
paddle/fluid/operators/softmax_with_cross_entropy_op.cu
+8
-20
未找到文件。
paddle/fluid/operators/softmax_with_cross_entropy_op.cu
浏览文件 @
fff270ea
...
...
@@ -10,6 +10,7 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include <cub/cub.cuh>
#include "paddle/fluid/operators/math/cross_entropy.h"
#include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/operators/softmax_with_cross_entropy_op.h"
#include "paddle/fluid/platform/for_range.h"
...
...
@@ -309,12 +310,6 @@ struct HardLabelSoftmaxWithCrossEntropyFunctorWithIgnoreIdx {
int
ignore_idx_
;
};
template
<
typename
T
>
static
__global__
void
SetSoftmaxToOneWhenFeatureSizeIsOne
(
T
*
out
,
int
n
)
{
auto
idx
=
threadIdx
.
x
+
blockIdx
.
x
*
blockDim
.
x
;
if
(
idx
<
n
)
out
[
idx
]
=
static_cast
<
T
>
(
1
);
}
template
<
typename
T
>
static
void
HardLabelSoftmaxWithCrossEntropy
(
const
platform
::
CUDADeviceContext
&
ctx
,
const
T
*
logits_data
,
...
...
@@ -354,13 +349,6 @@ static void HardLabelSoftmaxWithCrossEntropy(
CALL_HARD_LABEL_SOFTMAX_WITH_CROSS_ENTROPY_FUSED_KERNEL
(
8
);
CALL_HARD_LABEL_SOFTMAX_WITH_CROSS_ENTROPY_FUSED_KERNEL
(
4
);
CALL_HARD_LABEL_SOFTMAX_WITH_CROSS_ENTROPY_FUSED_KERNEL
(
2
);
case
1
:
SetSoftmaxToOneWhenFeatureSizeIsOne
<<<
(
grid_dim
+
kMaxBlockDim
-
1
)
/
kMaxBlockDim
,
kMaxBlockDim
,
0
,
stream
>>>
(
softmax_data
,
grid_dim
);
cudaMemsetAsync
(
loss_data
,
0
,
grid_dim
*
sizeof
(
T
),
stream
);
break
;
default:
PADDLE_THROW
(
"BlockDim must be 2^n in softmax_with_cross_entropy_op"
);
break
;
...
...
@@ -401,13 +389,6 @@ static void SoftmaxWithCrossEntropyFusedKernel(const T* logits_data,
CALL_SOFTMAX_WITH_CROSS_ENTROPY_FUSED_KERNEL
(
8
);
CALL_SOFTMAX_WITH_CROSS_ENTROPY_FUSED_KERNEL
(
4
);
CALL_SOFTMAX_WITH_CROSS_ENTROPY_FUSED_KERNEL
(
2
);
case
1
:
SetSoftmaxToOneWhenFeatureSizeIsOne
<<<
(
grid_dim
+
kMaxBlockDim
-
1
)
/
kMaxBlockDim
,
kMaxBlockDim
,
0
,
stream
>>>
(
softmax_data
,
n
);
cudaMemsetAsync
(
loss_data
,
0
,
grid_dim
*
sizeof
(
T
),
stream
);
break
;
default:
PADDLE_THROW
(
"BlockDim must be 2^n in softmax_with_cross_entropy_op"
);
break
;
...
...
@@ -431,6 +412,13 @@ class SoftmaxWithCrossEntropyCUDAKernel : public framework::OpKernel<T> {
const
int
axis
=
CanonicalAxis
(
context
.
Attr
<
int
>
(
"axis"
),
rank
);
int
axis_dim
=
logits
->
dims
()[
axis
];
if
(
axis_dim
==
1
)
{
math
::
SetConstant
<
platform
::
CUDADeviceContext
,
T
>
set_constant
;
set_constant
(
context
.
cuda_device_context
(),
softmax
,
static_cast
<
T
>
(
1
));
set_constant
(
context
.
cuda_device_context
(),
loss
,
static_cast
<
T
>
(
0
));
return
;
}
const
int
n
=
SizeToAxis
(
axis
,
logits
->
dims
());
const
int
d
=
SizeFromAxis
(
axis
,
logits
->
dims
());
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录