Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
30a2e7f0
P
Paddle
项目概览
PaddlePaddle
/
Paddle
大约 1 年 前同步成功
通知
2299
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
30a2e7f0
编写于
2月 23, 2021
作者:
Z
Zhong Hui
提交者:
GitHub
2月 23, 2021
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
[cherry-pick] Fix softmax cross entropy integer overflow. (#30590) (#31134)
[BUG FIX] Fix softmax cross entropy overflow problem.
上级
3a72408f
变更
4
显示空白变更内容
内联
并排
Showing
4 changed file
with
96 addition
and
93 deletion
+96
-93
paddle/fluid/operators/log_softmax_op.h
paddle/fluid/operators/log_softmax_op.h
+4
-4
paddle/fluid/operators/softmax_with_cross_entropy_op.cu
paddle/fluid/operators/softmax_with_cross_entropy_op.cu
+82
-82
paddle/fluid/platform/cuda_helper.h
paddle/fluid/platform/cuda_helper.h
+5
-2
paddle/fluid/platform/for_range.h
paddle/fluid/platform/for_range.h
+5
-5
未找到文件。
paddle/fluid/operators/log_softmax_op.h
浏览文件 @
30a2e7f0
...
@@ -29,16 +29,16 @@ static inline int CanonicalAxis(const int axis, const int rank) {
...
@@ -29,16 +29,16 @@ static inline int CanonicalAxis(const int axis, const int rank) {
return
axis
;
return
axis
;
}
}
static
inline
in
t
SizeToAxis
(
const
int
axis
,
const
framework
::
DDim
dims
)
{
static
inline
size_
t
SizeToAxis
(
const
int
axis
,
const
framework
::
DDim
dims
)
{
in
t
size
=
1
;
size_
t
size
=
1
;
for
(
int
i
=
0
;
i
<
axis
;
i
++
)
{
for
(
int
i
=
0
;
i
<
axis
;
i
++
)
{
size
*=
dims
[
i
];
size
*=
dims
[
i
];
}
}
return
size
;
return
size
;
}
}
static
inline
in
t
SizeFromAxis
(
const
int
axis
,
const
framework
::
DDim
dims
)
{
static
inline
size_
t
SizeFromAxis
(
const
int
axis
,
const
framework
::
DDim
dims
)
{
in
t
size
=
1
;
size_
t
size
=
1
;
for
(
int
i
=
axis
;
i
<
dims
.
size
();
i
++
)
{
for
(
int
i
=
axis
;
i
<
dims
.
size
();
i
++
)
{
size
*=
dims
[
i
];
size
*=
dims
[
i
];
}
}
...
...
paddle/fluid/operators/softmax_with_cross_entropy_op.cu
浏览文件 @
30a2e7f0
...
@@ -22,27 +22,27 @@ using Tensor = framework::Tensor;
...
@@ -22,27 +22,27 @@ using Tensor = framework::Tensor;
namespace
{
namespace
{
template
<
typename
T
>
template
<
typename
T
>
__global__
void
CrossEntropyGrad
(
T
*
logit_grad
,
const
int64_t
*
labels
,
__global__
void
CrossEntropyGrad
(
T
*
logit_grad
,
const
int64_t
*
labels
,
const
int
n
,
const
int
d
,
const
int
remain
,
const
int
64_t
n
,
const
int64_t
d
,
const
int
ignore_index
)
{
const
int
64_t
remain
,
const
int
ignore_index
)
{
CUDA_KERNEL_LOOP
(
index
,
n
*
remain
)
{
CUDA_KERNEL_LOOP
_TYPE
(
index
,
n
*
remain
,
int64_t
)
{
int
idx_n
=
index
/
remain
;
int
64_t
idx_n
=
index
/
remain
;
int
idx_remain
=
index
%
remain
;
int
64_t
idx_remain
=
index
%
remain
;
int
tmp
=
labels
[
index
];
int
64_t
tmp
=
labels
[
index
];
if
(
ignore_index
!=
tmp
)
{
if
(
ignore_index
!=
tmp
)
{
int
idx
=
idx_n
*
d
+
tmp
*
remain
+
idx_remain
;
int
64_t
idx
=
idx_n
*
d
+
tmp
*
remain
+
idx_remain
;
logit_grad
[
idx
]
-=
static_cast
<
T
>
(
1.
);
logit_grad
[
idx
]
-=
static_cast
<
T
>
(
1.
);
}
}
}
}
}
}
template
<
typename
T
>
template
<
typename
T
>
__global__
void
Scale
(
T
*
logit_grad
,
const
T
*
loss_grad
,
const
int
num
,
__global__
void
Scale
(
T
*
logit_grad
,
const
T
*
loss_grad
,
const
int
64_t
num
,
const
int
d
,
const
int
remain
,
const
int64_t
*
labels
,
const
int
64_t
d
,
const
int64_t
remain
,
const
int
ignore_index
)
{
const
int
64_t
*
labels
,
const
int
ignore_index
)
{
CUDA_KERNEL_LOOP
(
index
,
num
)
{
CUDA_KERNEL_LOOP
_TYPE
(
index
,
num
,
int64_t
)
{
int
idx_n
=
index
/
d
;
int
64_t
idx_n
=
index
/
d
;
int
idx_remain
=
index
%
remain
;
int
64_t
idx_remain
=
index
%
remain
;
int
idx_lbl
=
idx_n
*
remain
+
idx_remain
;
int
64_t
idx_lbl
=
idx_n
*
remain
+
idx_remain
;
if
(
labels
[
idx_lbl
]
==
ignore_index
)
{
if
(
labels
[
idx_lbl
]
==
ignore_index
)
{
logit_grad
[
index
]
=
static_cast
<
T
>
(
0.
);
logit_grad
[
index
]
=
static_cast
<
T
>
(
0.
);
}
else
{
}
else
{
...
@@ -54,13 +54,14 @@ __global__ void Scale(T* logit_grad, const T* loss_grad, const int num,
...
@@ -54,13 +54,14 @@ __global__ void Scale(T* logit_grad, const T* loss_grad, const int num,
template
<
typename
T
>
template
<
typename
T
>
__global__
void
SoftCrossEntropyGradientKernel
(
T
*
logit_grad
,
__global__
void
SoftCrossEntropyGradientKernel
(
T
*
logit_grad
,
const
T
*
loss_grad
,
const
T
*
loss_grad
,
const
T
*
labels
,
const
int
n
,
const
T
*
labels
,
const
int64_t
n
,
const
int
d
,
const
int
remain
)
{
const
int64_t
d
,
int
ids
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
const
int64_t
remain
)
{
int64_t
ids
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
if
(
ids
<
n
*
d
)
{
if
(
ids
<
n
*
d
)
{
int
idx_n
=
ids
/
d
;
int
64_t
idx_n
=
ids
/
d
;
int
idx_remain
=
ids
%
remain
;
int
64_t
idx_remain
=
ids
%
remain
;
int
idx_loss
=
idx_n
*
remain
+
idx_remain
;
int
64_t
idx_loss
=
idx_n
*
remain
+
idx_remain
;
logit_grad
[
ids
]
=
loss_grad
[
idx_loss
]
*
(
logit_grad
[
ids
]
-
labels
[
ids
]);
logit_grad
[
ids
]
=
loss_grad
[
idx_loss
]
*
(
logit_grad
[
ids
]
-
labels
[
ids
]);
}
}
}
}
...
@@ -132,19 +133,19 @@ using BlockReduceTempStorage = typename BlockReduce<T, BlockDim>::TempStorage;
...
@@ -132,19 +133,19 @@ using BlockReduceTempStorage = typename BlockReduce<T, BlockDim>::TempStorage;
// This kernel is used to calculate the max element of each row
// This kernel is used to calculate the max element of each row
template
<
typename
T
,
int
BlockDim
>
template
<
typename
T
,
int
BlockDim
>
static
__global__
void
RowReductionForMax
(
const
T
*
logits_data
,
T
*
max_data
,
static
__global__
void
RowReductionForMax
(
const
T
*
logits_data
,
T
*
max_data
,
int
d
,
int
axis_dim
)
{
int
64_t
d
,
int
axis_dim
)
{
__shared__
BlockReduceTempStorage
<
T
,
BlockDim
>
temp_storage
;
__shared__
BlockReduceTempStorage
<
T
,
BlockDim
>
temp_storage
;
// logits_data view as [n, axis_dim, remain]
// logits_data view as [n, axis_dim, remain]
// max_data view as [n, 1, remain]
// max_data view as [n, 1, remain]
// blockDim = n * remain, split blockIdx to idx_n and idx_remain
// blockDim = n * remain, split blockIdx to idx_n and idx_remain
int
remain
=
d
/
axis_dim
;
int
64_t
remain
=
d
/
axis_dim
;
int
idx_n
=
blockIdx
.
x
/
remain
;
int
64_t
idx_n
=
blockIdx
.
x
/
remain
;
int
idx_remain
=
blockIdx
.
x
%
remain
;
int
64_t
idx_remain
=
blockIdx
.
x
%
remain
;
int
beg_idx
=
idx_n
*
d
+
threadIdx
.
x
*
remain
+
idx_remain
;
int
64_t
beg_idx
=
idx_n
*
d
+
threadIdx
.
x
*
remain
+
idx_remain
;
int
end_idx
=
(
idx_n
+
1
)
*
d
;
int
64_t
end_idx
=
(
idx_n
+
1
)
*
d
;
int
step
=
BlockDim
*
remain
;
int
64_t
step
=
BlockDim
*
remain
;
T
cur_max
=
logits_data
[
beg_idx
];
T
cur_max
=
logits_data
[
beg_idx
];
beg_idx
+=
step
;
beg_idx
+=
step
;
while
(
beg_idx
<
end_idx
)
{
while
(
beg_idx
<
end_idx
)
{
...
@@ -162,21 +163,21 @@ static __global__ void RowReductionForMax(const T* logits_data, T* max_data,
...
@@ -162,21 +163,21 @@ static __global__ void RowReductionForMax(const T* logits_data, T* max_data,
// Make sure that BlockDim <= axis_dim
// Make sure that BlockDim <= axis_dim
template
<
typename
T
,
int
BlockDim
,
bool
CalculateLogSoftmax
=
false
>
template
<
typename
T
,
int
BlockDim
,
bool
CalculateLogSoftmax
=
false
>
static
__global__
void
RowReductionForDiffMaxSum
(
const
T
*
logits_data
,
static
__global__
void
RowReductionForDiffMaxSum
(
const
T
*
logits_data
,
T
*
max_data
,
T
*
softmax
,
int
d
,
T
*
max_data
,
T
*
softmax
,
int
axis_dim
)
{
int
64_t
d
,
int
axis_dim
)
{
__shared__
BlockReduceTempStorage
<
T
,
BlockDim
>
temp_storage
;
__shared__
BlockReduceTempStorage
<
T
,
BlockDim
>
temp_storage
;
// logits, softmax data view as [n, axis_dim, remain]
// logits, softmax data view as [n, axis_dim, remain]
// max_data view as [n, 1, remain]
// max_data view as [n, 1, remain]
// blockDim = n * remain, split blockIdx to idx_n and idx_remain
// blockDim = n * remain, split blockIdx to idx_n and idx_remain
int
remain
=
d
/
axis_dim
;
int
64_t
remain
=
d
/
axis_dim
;
int
idx_n
=
blockIdx
.
x
/
remain
;
int
64_t
idx_n
=
blockIdx
.
x
/
remain
;
int
idx_remain
=
blockIdx
.
x
%
remain
;
int
64_t
idx_remain
=
blockIdx
.
x
%
remain
;
int
beg_idx
=
idx_n
*
d
+
threadIdx
.
x
*
remain
+
idx_remain
;
int
64_t
beg_idx
=
idx_n
*
d
+
threadIdx
.
x
*
remain
+
idx_remain
;
int
end_idx
=
(
idx_n
+
1
)
*
d
;
int
64_t
end_idx
=
(
idx_n
+
1
)
*
d
;
auto
block_max
=
max_data
[
blockIdx
.
x
];
auto
block_max
=
max_data
[
blockIdx
.
x
];
int
step
=
BlockDim
*
remain
;
int
64_t
step
=
BlockDim
*
remain
;
// In numeric stable mode softmax_with_loss, we calc loss with
// In numeric stable mode softmax_with_loss, we calc loss with
// tmp_i_j = x_i_j - max_i - logDiffMaxSum_i, instead of
// tmp_i_j = x_i_j - max_i - logDiffMaxSum_i, instead of
...
@@ -216,25 +217,25 @@ static __global__ void RowReductionForDiffMaxSum(const T* logits_data,
...
@@ -216,25 +217,25 @@ static __global__ void RowReductionForDiffMaxSum(const T* logits_data,
// Make sure that BlockDim <= axis_dim
// Make sure that BlockDim <= axis_dim
template
<
typename
T
,
int
BlockDim
>
template
<
typename
T
,
int
BlockDim
>
static
__global__
void
RowReductionForSoftmaxAndCrossEntropy
(
static
__global__
void
RowReductionForSoftmaxAndCrossEntropy
(
const
T
*
logits_data
,
const
T
*
labels_data
,
T
*
loss_data
,
T
*
softmax
,
int
d
,
const
T
*
logits_data
,
const
T
*
labels_data
,
T
*
loss_data
,
T
*
softmax
,
int
axis_dim
)
{
int
64_t
d
,
int
axis_dim
)
{
__shared__
BlockReduceTempStorage
<
T
,
BlockDim
>
temp_storage
;
__shared__
BlockReduceTempStorage
<
T
,
BlockDim
>
temp_storage
;
// logits, softmax, labels data view as [n, axis_dim, remain]
// logits, softmax, labels data view as [n, axis_dim, remain]
// loss_data view as [n, 1, remain]
// loss_data view as [n, 1, remain]
// blockDim = n * remain, split blockIdx to idx_n and idx_remain
// blockDim = n * remain, split blockIdx to idx_n and idx_remain
int
remain
=
d
/
axis_dim
;
int
64_t
remain
=
d
/
axis_dim
;
int
idx_n
=
blockIdx
.
x
/
remain
;
int
64_t
idx_n
=
blockIdx
.
x
/
remain
;
int
idx_remain
=
blockIdx
.
x
%
remain
;
int
64_t
idx_remain
=
blockIdx
.
x
%
remain
;
int
beg_idx
=
idx_n
*
d
+
threadIdx
.
x
*
remain
+
idx_remain
;
int
64_t
beg_idx
=
idx_n
*
d
+
threadIdx
.
x
*
remain
+
idx_remain
;
int
end_idx
=
(
idx_n
+
1
)
*
d
;
int
64_t
end_idx
=
(
idx_n
+
1
)
*
d
;
// log_diff_max_sum shares memory with loss
// log_diff_max_sum shares memory with loss
auto
block_log_diff_max_sum
=
loss_data
[
blockIdx
.
x
];
auto
block_log_diff_max_sum
=
loss_data
[
blockIdx
.
x
];
auto
tmp
=
softmax
[
beg_idx
]
-
block_log_diff_max_sum
;
auto
tmp
=
softmax
[
beg_idx
]
-
block_log_diff_max_sum
;
softmax
[
beg_idx
]
=
exp_on_device
(
tmp
);
softmax
[
beg_idx
]
=
exp_on_device
(
tmp
);
auto
loss
=
-
labels_data
[
beg_idx
]
*
tmp
;
auto
loss
=
-
labels_data
[
beg_idx
]
*
tmp
;
int
step
=
BlockDim
*
remain
;
int
64_t
step
=
BlockDim
*
remain
;
beg_idx
+=
step
;
beg_idx
+=
step
;
while
(
beg_idx
<
end_idx
)
{
while
(
beg_idx
<
end_idx
)
{
tmp
=
softmax
[
beg_idx
]
-
block_log_diff_max_sum
;
tmp
=
softmax
[
beg_idx
]
-
block_log_diff_max_sum
;
...
@@ -251,21 +252,22 @@ template <typename T>
...
@@ -251,21 +252,22 @@ template <typename T>
struct
HardLabelSoftmaxWithCrossEntropyFunctor
{
struct
HardLabelSoftmaxWithCrossEntropyFunctor
{
public:
public:
HardLabelSoftmaxWithCrossEntropyFunctor
(
const
int64_t
*
labels
,
T
*
loss
,
HardLabelSoftmaxWithCrossEntropyFunctor
(
const
int64_t
*
labels
,
T
*
loss
,
T
*
log_softmax
,
int
d
,
int
axis_dim
)
T
*
log_softmax
,
int64_t
d
,
int
axis_dim
)
:
labels_
(
labels
),
:
labels_
(
labels
),
loss_
(
loss
),
loss_
(
loss
),
log_softmax_
(
log_softmax
),
log_softmax_
(
log_softmax
),
d_
(
d
),
d_
(
d
),
axis_dim_
(
axis_dim
)
{}
axis_dim_
(
axis_dim
)
{}
__device__
void
operator
()(
int
idx
)
const
{
__device__
void
operator
()(
int
64_t
idx
)
const
{
// logits view as [n, axis_dim, remain], where d = axis_dim * remain
// logits view as [n, axis_dim, remain], where d = axis_dim * remain
int
remain
=
d_
/
axis_dim_
;
int
64_t
remain
=
d_
/
axis_dim_
;
int
idx_n
=
idx
/
d_
;
int
64_t
idx_n
=
idx
/
d_
;
int
idx_axis
=
(
idx
%
d_
)
/
remain
;
int
64_t
idx_axis
=
(
idx
%
d_
)
/
remain
;
int
idx_remain
=
idx
%
remain
;
int
64_t
idx_remain
=
idx
%
remain
;
// labels, loss view as [n, remain]
// labels, loss view as [n, remain]
int
idx_lbl
=
idx_n
*
remain
+
idx_remain
;
int
64_t
idx_lbl
=
idx_n
*
remain
+
idx_remain
;
// It also would ignore labels not in range(class_num).
// It also would ignore labels not in range(class_num).
if
(
idx_axis
!=
labels_
[
idx_lbl
])
{
if
(
idx_axis
!=
labels_
[
idx_lbl
])
{
log_softmax_
[
idx
]
=
exp_on_device
(
log_softmax_
[
idx
]);
log_softmax_
[
idx
]
=
exp_on_device
(
log_softmax_
[
idx
]);
...
@@ -280,7 +282,7 @@ struct HardLabelSoftmaxWithCrossEntropyFunctor {
...
@@ -280,7 +282,7 @@ struct HardLabelSoftmaxWithCrossEntropyFunctor {
const
int64_t
*
labels_
;
const
int64_t
*
labels_
;
T
*
loss_
;
T
*
loss_
;
T
*
log_softmax_
;
T
*
log_softmax_
;
int
d_
;
int
64_t
d_
;
int
axis_dim_
;
int
axis_dim_
;
};
};
...
@@ -289,7 +291,7 @@ struct HardLabelSoftmaxWithCrossEntropyFunctorWithIgnoreIdx {
...
@@ -289,7 +291,7 @@ struct HardLabelSoftmaxWithCrossEntropyFunctorWithIgnoreIdx {
public:
public:
HardLabelSoftmaxWithCrossEntropyFunctorWithIgnoreIdx
(
const
int64_t
*
labels
,
HardLabelSoftmaxWithCrossEntropyFunctorWithIgnoreIdx
(
const
int64_t
*
labels
,
T
*
loss
,
T
*
log_softmax
,
T
*
loss
,
T
*
log_softmax
,
int
d
,
int
axis_dim
,
int
64_t
d
,
int
axis_dim
,
int
ignore_idx
)
int
ignore_idx
)
:
labels_
(
labels
),
:
labels_
(
labels
),
loss_
(
loss
),
loss_
(
loss
),
...
@@ -298,14 +300,14 @@ struct HardLabelSoftmaxWithCrossEntropyFunctorWithIgnoreIdx {
...
@@ -298,14 +300,14 @@ struct HardLabelSoftmaxWithCrossEntropyFunctorWithIgnoreIdx {
axis_dim_
(
axis_dim
),
axis_dim_
(
axis_dim
),
ignore_idx_
(
ignore_idx
)
{}
ignore_idx_
(
ignore_idx
)
{}
__device__
void
operator
()(
int
idx
)
const
{
__device__
void
operator
()(
int
64_t
idx
)
const
{
// logits view as [n, axis_dim, remain], where d = axis_dim * remain
// logits view as [n, axis_dim, remain], where d = axis_dim * remain
int
remain
=
d_
/
axis_dim_
;
int
64_t
remain
=
d_
/
axis_dim_
;
int
idx_n
=
idx
/
d_
;
int
64_t
idx_n
=
idx
/
d_
;
int
idx_axis
=
(
idx
%
d_
)
/
remain
;
int
64_t
idx_axis
=
(
idx
%
d_
)
/
remain
;
int
idx_remain
=
idx
%
remain
;
int
64_t
idx_remain
=
idx
%
remain
;
// labels, loss view as [n, remain]
// labels, loss view as [n, remain]
int
idx_lbl
=
idx_n
*
remain
+
idx_remain
;
int
64_t
idx_lbl
=
idx_n
*
remain
+
idx_remain
;
if
(
idx_axis
!=
labels_
[
idx_lbl
]
||
idx_axis
==
ignore_idx_
)
{
if
(
idx_axis
!=
labels_
[
idx_lbl
]
||
idx_axis
==
ignore_idx_
)
{
log_softmax_
[
idx
]
=
exp_on_device
(
log_softmax_
[
idx
]);
log_softmax_
[
idx
]
=
exp_on_device
(
log_softmax_
[
idx
]);
}
else
{
}
else
{
...
@@ -319,7 +321,7 @@ struct HardLabelSoftmaxWithCrossEntropyFunctorWithIgnoreIdx {
...
@@ -319,7 +321,7 @@ struct HardLabelSoftmaxWithCrossEntropyFunctorWithIgnoreIdx {
const
int64_t
*
labels_
;
const
int64_t
*
labels_
;
T
*
loss_
;
T
*
loss_
;
T
*
log_softmax_
;
T
*
log_softmax_
;
int
d_
;
int
64_t
d_
;
int
axis_dim_
;
int
axis_dim_
;
int
ignore_idx_
;
int
ignore_idx_
;
};
};
...
@@ -327,13 +329,13 @@ struct HardLabelSoftmaxWithCrossEntropyFunctorWithIgnoreIdx {
...
@@ -327,13 +329,13 @@ struct HardLabelSoftmaxWithCrossEntropyFunctorWithIgnoreIdx {
template
<
typename
T
>
template
<
typename
T
>
static
void
HardLabelSoftmaxWithCrossEntropy
(
static
void
HardLabelSoftmaxWithCrossEntropy
(
const
platform
::
CUDADeviceContext
&
ctx
,
const
T
*
logits_data
,
const
platform
::
CUDADeviceContext
&
ctx
,
const
T
*
logits_data
,
const
int64_t
*
labels_data
,
T
*
loss_data
,
T
*
softmax_data
,
int
n
,
int
d
,
const
int64_t
*
labels_data
,
T
*
loss_data
,
T
*
softmax_data
,
int
64_t
n
,
int
axis_dim
,
int
ignore_idx
)
{
int
64_t
d
,
int
axis_dim
,
int
ignore_idx
)
{
constexpr
int
kMaxBlockDim
=
512
;
constexpr
int
kMaxBlockDim
=
512
;
int
block_dim
=
axis_dim
>=
kMaxBlockDim
int
64_t
block_dim
=
axis_dim
>=
kMaxBlockDim
?
kMaxBlockDim
?
kMaxBlockDim
:
(
1
<<
static_cast
<
int
>
(
std
::
log2
(
axis_dim
)));
:
(
1
<<
static_cast
<
int
>
(
std
::
log2
(
axis_dim
)));
int
grid_dim
=
n
*
d
/
axis_dim
;
int
64_t
grid_dim
=
n
*
d
/
axis_dim
;
auto
stream
=
ctx
.
stream
();
auto
stream
=
ctx
.
stream
();
#define CALL_HARD_LABEL_SOFTMAX_WITH_CROSS_ENTROPY_FUSED_KERNEL(BlockDim) \
#define CALL_HARD_LABEL_SOFTMAX_WITH_CROSS_ENTROPY_FUSED_KERNEL(BlockDim) \
...
@@ -372,16 +374,14 @@ static void HardLabelSoftmaxWithCrossEntropy(
...
@@ -372,16 +374,14 @@ static void HardLabelSoftmaxWithCrossEntropy(
}
}
template
<
typename
T
>
template
<
typename
T
>
static
void
SoftmaxWithCrossEntropyFusedKernel
(
const
T
*
logits_data
,
static
void
SoftmaxWithCrossEntropyFusedKernel
(
const
T
*
labels_data
,
const
T
*
logits_data
,
const
T
*
labels_data
,
T
*
softmax_data
,
T
*
loss_data
,
T
*
softmax_data
,
T
*
loss_data
,
int64_t
n
,
int64_t
d
,
int
axis_dim
,
cudaStream_t
stream
)
{
int
n
,
int
d
,
int
axis_dim
,
cudaStream_t
stream
)
{
constexpr
int
kMaxBlockDim
=
512
;
constexpr
int
kMaxBlockDim
=
512
;
int
block_dim
=
axis_dim
>=
kMaxBlockDim
int
64_t
block_dim
=
axis_dim
>=
kMaxBlockDim
?
kMaxBlockDim
?
kMaxBlockDim
:
(
1
<<
static_cast
<
int
>
(
std
::
log2
(
axis_dim
)));
:
(
1
<<
static_cast
<
int
>
(
std
::
log2
(
axis_dim
)));
int
grid_dim
=
n
*
d
/
axis_dim
;
int
64_t
grid_dim
=
n
*
d
/
axis_dim
;
#define CALL_SOFTMAX_WITH_CROSS_ENTROPY_FUSED_KERNEL(BlockDim) \
#define CALL_SOFTMAX_WITH_CROSS_ENTROPY_FUSED_KERNEL(BlockDim) \
case BlockDim: \
case BlockDim: \
...
@@ -430,8 +430,8 @@ class SoftmaxWithCrossEntropyCUDAKernel : public framework::OpKernel<T> {
...
@@ -430,8 +430,8 @@ class SoftmaxWithCrossEntropyCUDAKernel : public framework::OpKernel<T> {
const
int
axis
=
CanonicalAxis
(
context
.
Attr
<
int
>
(
"axis"
),
rank
);
const
int
axis
=
CanonicalAxis
(
context
.
Attr
<
int
>
(
"axis"
),
rank
);
int
axis_dim
=
logits
->
dims
()[
axis
];
int
axis_dim
=
logits
->
dims
()[
axis
];
const
int
n
=
SizeToAxis
(
axis
,
logits
->
dims
());
const
int
64_t
n
=
SizeToAxis
(
axis
,
logits
->
dims
());
const
int
d
=
SizeFromAxis
(
axis
,
logits
->
dims
());
const
int
64_t
d
=
SizeFromAxis
(
axis
,
logits
->
dims
());
auto
*
softmax_data
=
softmax
->
mutable_data
<
T
>
(
context
.
GetPlace
());
auto
*
softmax_data
=
softmax
->
mutable_data
<
T
>
(
context
.
GetPlace
());
auto
*
loss_data
=
loss
->
mutable_data
<
T
>
(
context
.
GetPlace
());
auto
*
loss_data
=
loss
->
mutable_data
<
T
>
(
context
.
GetPlace
());
...
@@ -500,24 +500,24 @@ class SoftmaxWithCrossEntropyGradCUDAKernel : public framework::OpKernel<T> {
...
@@ -500,24 +500,24 @@ class SoftmaxWithCrossEntropyGradCUDAKernel : public framework::OpKernel<T> {
const
int
axis
=
CanonicalAxis
(
context
.
Attr
<
int
>
(
"axis"
),
rank
);
const
int
axis
=
CanonicalAxis
(
context
.
Attr
<
int
>
(
"axis"
),
rank
);
int
axis_dim
=
logit_grad
->
dims
()[
axis
];
int
axis_dim
=
logit_grad
->
dims
()[
axis
];
const
int
n
=
SizeToAxis
(
axis
,
logit_grad
->
dims
());
const
int
64_t
n
=
SizeToAxis
(
axis
,
logit_grad
->
dims
());
const
int
d
=
SizeFromAxis
(
axis
,
logit_grad
->
dims
());
const
int
64_t
d
=
SizeFromAxis
(
axis
,
logit_grad
->
dims
());
const
int
remain
=
d
/
axis_dim
;
const
int
64_t
remain
=
d
/
axis_dim
;
int
block
=
512
;
int
block
=
512
;
auto
stream
=
context
.
cuda_device_context
().
stream
();
auto
stream
=
context
.
cuda_device_context
().
stream
();
auto
ignore_index
=
context
.
Attr
<
int
>
(
"ignore_index"
);
auto
ignore_index
=
context
.
Attr
<
int
>
(
"ignore_index"
);
if
(
context
.
Attr
<
bool
>
(
"soft_label"
))
{
if
(
context
.
Attr
<
bool
>
(
"soft_label"
))
{
int
grid
=
(
n
*
d
+
block
-
1
)
/
block
;
int
64_t
grid
=
(
n
*
d
+
block
-
1
)
/
block
;
const
T
*
label_data
=
labels
->
data
<
T
>
();
const
T
*
label_data
=
labels
->
data
<
T
>
();
SoftCrossEntropyGradientKernel
<
T
><<<
grid
,
block
,
0
,
stream
>>>
(
SoftCrossEntropyGradientKernel
<
T
><<<
grid
,
block
,
0
,
stream
>>>
(
logit_grad_data
,
loss_grad_data
,
label_data
,
n
,
d
,
remain
);
logit_grad_data
,
loss_grad_data
,
label_data
,
n
,
d
,
remain
);
}
else
{
}
else
{
int
grid
=
(
n
*
remain
+
block
-
1
)
/
block
;
int
64_t
grid
=
(
n
*
remain
+
block
-
1
)
/
block
;
const
int64_t
*
label_data
=
labels
->
data
<
int64_t
>
();
const
int64_t
*
label_data
=
labels
->
data
<
int64_t
>
();
CrossEntropyGrad
<
T
><<<
grid
,
block
,
0
,
stream
>>>
(
CrossEntropyGrad
<
T
><<<
grid
,
block
,
0
,
stream
>>>
(
logit_grad_data
,
label_data
,
n
,
d
,
remain
,
ignore_index
);
logit_grad_data
,
label_data
,
n
,
d
,
remain
,
ignore_index
);
int
num
=
n
*
d
;
int
64_t
num
=
n
*
d
;
grid
=
(
num
+
block
-
1
)
/
block
;
grid
=
(
num
+
block
-
1
)
/
block
;
Scale
<
T
><<<
grid
,
block
,
0
,
stream
>>>
(
logit_grad_data
,
loss_grad_data
,
num
,
Scale
<
T
><<<
grid
,
block
,
0
,
stream
>>>
(
logit_grad_data
,
loss_grad_data
,
num
,
d
,
remain
,
label_data
,
ignore_index
);
d
,
remain
,
label_data
,
ignore_index
);
...
...
paddle/fluid/platform/cuda_helper.h
浏览文件 @
30a2e7f0
...
@@ -70,11 +70,14 @@ namespace platform {
...
@@ -70,11 +70,14 @@ namespace platform {
* }
* }
*
*
*/
*/
#define CUDA_KERNEL_LOOP(i, num) \
#define CUDA_KERNEL_LOOP_TYPE(i, num, index_type) \
int64_t __index__ = blockIdx.x * blockDim.x + threadIdx.x; \
int64_t __index__ = blockIdx.x * blockDim.x + threadIdx.x; \
for (in
t i = __index__; __index__ < (num);
\
for (in
dex_type i = __index__; __index__ < (num);
\
__index__ += blockDim.x * gridDim.x, i = __index__)
__index__ += blockDim.x * gridDim.x, i = __index__)
#define CUDA_KERNEL_LOOP(i, num) CUDA_KERNEL_LOOP_TYPE(i, num, int)
class
CublasHandleHolder
{
class
CublasHandleHolder
{
public:
public:
CublasHandleHolder
(
cudaStream_t
stream
,
cublasMath_t
math_type
)
{
CublasHandleHolder
(
cudaStream_t
stream
,
cublasMath_t
math_type
)
{
...
...
paddle/fluid/platform/for_range.h
浏览文件 @
30a2e7f0
...
@@ -48,7 +48,7 @@ __global__ static void ForRangeElemwiseOpGridIsOne(Function func) {
...
@@ -48,7 +48,7 @@ __global__ static void ForRangeElemwiseOpGridIsOne(Function func) {
}
}
template
<
typename
Function
>
template
<
typename
Function
>
__global__
static
void
ForRangeElemwiseOp
(
Function
func
,
in
t
limit
)
{
__global__
static
void
ForRangeElemwiseOp
(
Function
func
,
size_
t
limit
)
{
size_t
idx
=
static_cast
<
size_t
>
(
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
);
size_t
idx
=
static_cast
<
size_t
>
(
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
);
if
(
idx
<
limit
)
{
if
(
idx
<
limit
)
{
func
(
idx
);
func
(
idx
);
...
@@ -58,13 +58,13 @@ __global__ static void ForRangeElemwiseOp(Function func, int limit) {
...
@@ -58,13 +58,13 @@ __global__ static void ForRangeElemwiseOp(Function func, int limit) {
template
<
>
template
<
>
struct
ForRange
<
CUDADeviceContext
>
{
struct
ForRange
<
CUDADeviceContext
>
{
ForRange
(
const
CUDADeviceContext
&
dev_ctx
,
size_t
limit
)
ForRange
(
const
CUDADeviceContext
&
dev_ctx
,
size_t
limit
)
:
dev_ctx_
(
dev_ctx
),
limit_
(
static_cast
<
in
t
>
(
limit
))
{}
:
dev_ctx_
(
dev_ctx
),
limit_
(
static_cast
<
size_
t
>
(
limit
))
{}
template
<
typename
Function
>
template
<
typename
Function
>
inline
void
operator
()(
Function
func
)
const
{
inline
void
operator
()(
Function
func
)
const
{
constexpr
int
num_threads
=
1024
;
constexpr
int
num_threads
=
1024
;
in
t
block_size
=
limit_
<=
num_threads
?
limit_
:
num_threads
;
size_
t
block_size
=
limit_
<=
num_threads
?
limit_
:
num_threads
;
in
t
grid_size
=
(
limit_
+
num_threads
-
1
)
/
num_threads
;
size_
t
grid_size
=
(
limit_
+
num_threads
-
1
)
/
num_threads
;
if
(
grid_size
==
1
)
{
if
(
grid_size
==
1
)
{
ForRangeElemwiseOpGridIsOne
<<<
1
,
block_size
,
0
,
dev_ctx_
.
stream
()
>>>
(
ForRangeElemwiseOpGridIsOne
<<<
1
,
block_size
,
0
,
dev_ctx_
.
stream
()
>>>
(
...
@@ -76,7 +76,7 @@ struct ForRange<CUDADeviceContext> {
...
@@ -76,7 +76,7 @@ struct ForRange<CUDADeviceContext> {
}
}
const
CUDADeviceContext
&
dev_ctx_
;
const
CUDADeviceContext
&
dev_ctx_
;
in
t
limit_
;
size_
t
limit_
;
};
};
#endif
#endif
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录