Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
081e4307
P
Paddle
项目概览
PaddlePaddle
/
Paddle
大约 1 年 前同步成功
通知
2299
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
081e4307
编写于
3月 18, 2022
作者:
Z
Zhang Zheng
提交者:
GitHub
3月 18, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Optimize perf of softmax_with_cross_entropy_bwd (#40643)
* Optimize perf of softmax_with_cross_entropy_bwd * fix * fix
上级
1904572a
变更
1
隐藏空白更改
内联
并排
Showing
1 changed file
with
19 addition
and
12 deletion
+19
-12
paddle/fluid/operators/softmax_with_cross_entropy_op.cu
paddle/fluid/operators/softmax_with_cross_entropy_op.cu
+19
-12
未找到文件。
paddle/fluid/operators/softmax_with_cross_entropy_op.cu
浏览文件 @
081e4307
...
@@ -760,8 +760,9 @@ static void SoftmaxWithCrossEntropyHardLabel(
...
@@ -760,8 +760,9 @@ static void SoftmaxWithCrossEntropyHardLabel(
*/
*/
template
<
typename
T
,
typename
LabelT
>
template
<
typename
T
,
typename
LabelT
>
__global__
void
SoftmaxWithCrossEntropyGradHardLabel
(
__global__
void
SoftmaxWithCrossEntropyGradHardLabel
(
T
*
logits_grad
,
const
T
*
loss_grad
,
const
LabelT
*
labels
,
const
int64_t
n
,
T
*
logits_grad
,
const
T
*
loss_grad
,
const
T
*
softmax
,
const
LabelT
*
labels
,
const
int64_t
dim
,
const
int64_t
d
,
const
int
ignore_index
)
{
const
int64_t
n
,
const
int64_t
dim
,
const
int64_t
d
,
const
int
ignore_index
)
{
int64_t
idx
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
int64_t
idx
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
int64_t
idx_n
=
idx
/
(
d
*
dim
);
int64_t
idx_n
=
idx
/
(
d
*
dim
);
int64_t
idx_dim
=
(
idx
/
d
)
%
dim
;
int64_t
idx_dim
=
(
idx
/
d
)
%
dim
;
...
@@ -773,10 +774,9 @@ __global__ void SoftmaxWithCrossEntropyGradHardLabel(
...
@@ -773,10 +774,9 @@ __global__ void SoftmaxWithCrossEntropyGradHardLabel(
if
(
lbl
==
ignore_index
)
{
if
(
lbl
==
ignore_index
)
{
logits_grad
[
idx
]
=
static_cast
<
T
>
(
0.0
);
logits_grad
[
idx
]
=
static_cast
<
T
>
(
0.0
);
}
else
if
(
lbl
==
idx_dim
)
{
}
else
if
(
lbl
==
idx_dim
)
{
logits_grad
[
idx
]
=
logits_grad
[
idx
]
=
(
softmax
[
idx
]
-
static_cast
<
T
>
(
1.0
))
*
loss_grad
[
ids
];
(
logits_grad
[
idx
]
-
static_cast
<
T
>
(
1.0
))
*
loss_grad
[
ids
];
}
else
{
}
else
{
logits_grad
[
idx
]
*=
loss_grad
[
ids
];
logits_grad
[
idx
]
=
softmax
[
idx
]
*
loss_grad
[
ids
];
}
}
}
}
}
}
...
@@ -1395,11 +1395,20 @@ class SoftmaxWithCrossEntropyGradCUDAKernel : public framework::OpKernel<T> {
...
@@ -1395,11 +1395,20 @@ class SoftmaxWithCrossEntropyGradCUDAKernel : public framework::OpKernel<T> {
Tensor
*
logit_grad
=
Tensor
*
logit_grad
=
context
.
Output
<
Tensor
>
(
framework
::
GradVarName
(
"Logits"
));
context
.
Output
<
Tensor
>
(
framework
::
GradVarName
(
"Logits"
));
const
Tensor
*
softmax
=
context
.
Input
<
Tensor
>
(
"Softmax"
);
const
Tensor
*
softmax
=
context
.
Input
<
Tensor
>
(
"Softmax"
);
if
(
logit_grad
!=
softmax
)
{
auto
stream
=
context
.
cuda_device_context
().
stream
();
auto
ignore_index
=
context
.
Attr
<
int
>
(
"ignore_index"
);
auto
use_softmax
=
context
.
Attr
<
bool
>
(
"use_softmax"
);
T
*
logit_grad_data
=
nullptr
;
bool
copy_flag
=
(
logit_grad
!=
softmax
&&
(
!
use_softmax
||
soft_label
));
if
(
copy_flag
)
{
framework
::
TensorCopy
(
*
softmax
,
context
.
GetPlace
(),
framework
::
TensorCopy
(
*
softmax
,
context
.
GetPlace
(),
context
.
device_context
(),
logit_grad
);
context
.
device_context
(),
logit_grad
);
logit_grad_data
=
logit_grad
->
template
data
<
T
>();
}
else
{
logit_grad_data
=
logit_grad
->
template
mutable_data
<
T
>(
context
.
GetPlace
());
}
}
T
*
logit_grad_data
=
logit_grad
->
template
data
<
T
>();
const
int
rank
=
logit_grad
->
dims
().
size
();
const
int
rank
=
logit_grad
->
dims
().
size
();
const
int
axis
=
phi
::
funcs
::
CanonicalAxis
(
context
.
Attr
<
int
>
(
"axis"
),
rank
);
const
int
axis
=
phi
::
funcs
::
CanonicalAxis
(
context
.
Attr
<
int
>
(
"axis"
),
rank
);
...
@@ -1414,9 +1423,6 @@ class SoftmaxWithCrossEntropyGradCUDAKernel : public framework::OpKernel<T> {
...
@@ -1414,9 +1423,6 @@ class SoftmaxWithCrossEntropyGradCUDAKernel : public framework::OpKernel<T> {
#else
#else
int
block
=
512
;
int
block
=
512
;
#endif
#endif
auto
stream
=
context
.
cuda_device_context
().
stream
();
auto
ignore_index
=
context
.
Attr
<
int
>
(
"ignore_index"
);
auto
use_softmax
=
context
.
Attr
<
bool
>
(
"use_softmax"
);
// do not with softmax op, and input is softmax
// do not with softmax op, and input is softmax
if
(
!
use_softmax
)
{
if
(
!
use_softmax
)
{
...
@@ -1451,11 +1457,12 @@ class SoftmaxWithCrossEntropyGradCUDAKernel : public framework::OpKernel<T> {
...
@@ -1451,11 +1457,12 @@ class SoftmaxWithCrossEntropyGradCUDAKernel : public framework::OpKernel<T> {
SoftCrossEntropyGradientKernel
<
T
><<<
grid
,
block
,
0
,
stream
>>>
(
SoftCrossEntropyGradientKernel
<
T
><<<
grid
,
block
,
0
,
stream
>>>
(
logit_grad_data
,
loss_grad_data
,
label_data
,
n
,
d
,
remain
);
logit_grad_data
,
loss_grad_data
,
label_data
,
n
,
d
,
remain
);
}
else
{
}
else
{
const
T
*
softmax_data
=
softmax
->
template
data
<
T
>();
const
auto
*
label_data
=
labels
.
template
data
<
LabelT
>();
const
auto
*
label_data
=
labels
.
template
data
<
LabelT
>();
int
grid
=
(
n
*
d
+
block
-
1
)
/
block
;
int
grid
=
(
n
*
d
+
block
-
1
)
/
block
;
SoftmaxWithCrossEntropyGradHardLabel
<
T
><<<
grid
,
block
,
0
,
stream
>>>
(
SoftmaxWithCrossEntropyGradHardLabel
<
T
><<<
grid
,
block
,
0
,
stream
>>>
(
logit_grad_data
,
loss_grad_data
,
label_data
,
n
,
d
/
remain
,
remai
n
,
logit_grad_data
,
loss_grad_data
,
softmax_data
,
label_data
,
n
,
ignore_index
);
d
/
remain
,
remain
,
ignore_index
);
}
}
}
}
};
};
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录