Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
cf799a6a
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
cf799a6a
编写于
8月 10, 2018
作者:
S
sneaxiy
提交者:
GitHub
8月 10, 2018
浏览文件
操作
浏览文件
下载
差异文件
Merge pull request #12553 from sneaxiy/refine_softmax_with_cross_entropy
Refine softmax_with_cross_entropy op
上级
772ceee3
1b4515f6
变更
1
显示空白变更内容
内联
并排
Showing
1 changed file
with
209 addition
and
9 deletion
+209
-9
paddle/fluid/operators/softmax_with_cross_entropy_op.cu
paddle/fluid/operators/softmax_with_cross_entropy_op.cu
+209
-9
未找到文件。
paddle/fluid/operators/softmax_with_cross_entropy_op.cu
浏览文件 @
cf799a6a
/* Copyright (c) 201
6
PaddlePaddle Authors. All Rights Reserved.
/* Copyright (c) 201
8
PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
you may not use this file except in compliance with the License.
...
@@ -14,6 +14,8 @@ limitations under the License. */
...
@@ -14,6 +14,8 @@ limitations under the License. */
#define EIGEN_USE_GPU
#define EIGEN_USE_GPU
#include <cub/cub.cuh>
#include "paddle/fluid/operators/math/cross_entropy.h"
#include "paddle/fluid/operators/softmax_with_cross_entropy_op.h"
#include "paddle/fluid/operators/softmax_with_cross_entropy_op.h"
namespace
paddle
{
namespace
paddle
{
...
@@ -53,8 +55,196 @@ __global__ void SoftCrossEntropyGradientKernel(T* logit_grad,
...
@@ -53,8 +55,196 @@ __global__ void SoftCrossEntropyGradientKernel(T* logit_grad,
logit_grad
[
ids
]
=
loss_grad
[
row_ids
]
*
(
logit_grad
[
ids
]
-
labels
[
ids
]);
logit_grad
[
ids
]
=
loss_grad
[
row_ids
]
*
(
logit_grad
[
ids
]
-
labels
[
ids
]);
}
}
}
}
}
// namespace
}
// namespace
static
__device__
__forceinline__
float
real_exp
(
float
x
)
{
return
expf
(
x
);
}
static
__device__
__forceinline__
double
real_exp
(
double
x
)
{
return
exp
(
x
);
}
static
__device__
__forceinline__
float
real_log
(
float
x
)
{
return
math
::
TolerableValue
<
float
>
()(
logf
(
x
));
}
static
__device__
__forceinline__
double
real_log
(
double
x
)
{
return
math
::
TolerableValue
<
double
>
()(
log
(
x
));
}
/** In the following codes, 3 CUDA kernels are implemented to calculate softmax
* and loss **/
/*
Supposing the x is `logits` and y is `labels`, the equations are as
followings:
cross\_entropy_i = \sum_{j}[- y_i_j * log({e^{x_i_j}/\sum_{j}e^{x_i_j}})]
= \sum_{j}[- y_i_j * log({e^{x_i_j - max_i}/\sum_{j}e^{x_i_j-max_i}})]
= \sum_{j}[-y_i_j * (x_i_j - max_i - log\sum_{j}e^{x_i_j - max_i})]
= \sum_{j}[-y_i_j * (x_i_j - max_i - logDiffMaxSum_i)]
= \sum_{j}(-y_i_j * tmp_i_j)
softmax_i_j = e^{tmp_i_j}
where:
max_i = \max_{j}{x_i_j}
logDiffMaxSum_i = log\sum_{j}e^{x_i_j - max_i}
tmp_i_j = x_i_j - max_i - logDiffMaxSum_i
Therefore, the calculation can be separated into 3 steps:
Step 1: row-wise operation to calculate max_i
Step 2: row-wise operation to calculate logDiffMaxSum_i
Step 3: caculate tmp_i_j, and finally get softmax_i_j and cross\_entropy_i
To save memory, we can share memory among max_i, logDiffMaxSum_i and
cross\_entropy_i.
In this way, the 3 steps should be changed to:
Step 1 (RowReductionForMax): row-wise operation to calculate max_i
Step 2 (RowReductionForDiffMaxSum): calculate immediate result of softmax'_i_j =
x_i_j - max_i, and row-wise operation to calculate logDiffMaxSum_i
Step 3 (RowReductionForSoftmaxAndCrossEntropy): calculate tmp_i_j = softmax'_i_j
- logDiffMaxSum_i, and finally get softmax_i_j and cross\_entropy_i
*/
// There are 3 kinds of reduce algorithms in cub:
// BLOCK_REDUCE_RAKING_COMMUTATIVE_ONLY
// BLOCK_REDUCE_RAKING
// BLOCK_REDUCE_WARP_REDUCTIONS (default)
template
<
typename
T
,
int
BlockDim
>
using
BlockReduce
=
cub
::
BlockReduce
<
T
,
BlockDim
/*, cub::BLOCK_REDUCE_WARP_REDUCTIONS*/
>
;
template
<
typename
T
,
int
BlockDim
>
using
BlockReduceTempStorage
=
typename
BlockReduce
<
T
,
BlockDim
>::
TempStorage
;
// Make sure that BlockDim <= feature_size
// This kernel is used to calculate the max element of each row
template
<
typename
T
,
int
BlockDim
>
__global__
void
RowReductionForMax
(
const
T
*
logits_data
,
T
*
max_data
,
int
feature_size
)
{
__shared__
BlockReduceTempStorage
<
T
,
BlockDim
>
temp_storage
;
auto
beg_idx
=
feature_size
*
blockIdx
.
x
+
threadIdx
.
x
;
auto
end_idx
=
feature_size
*
(
blockIdx
.
x
+
1
);
T
cur_max
=
logits_data
[
beg_idx
];
beg_idx
+=
BlockDim
;
while
(
beg_idx
<
end_idx
)
{
if
(
cur_max
<
logits_data
[
beg_idx
])
{
cur_max
=
logits_data
[
beg_idx
];
}
beg_idx
+=
BlockDim
;
}
cur_max
=
BlockReduce
<
T
,
BlockDim
>
(
temp_storage
).
Reduce
(
cur_max
,
cub
::
Max
());
if
(
threadIdx
.
x
==
0
)
{
max_data
[
blockIdx
.
x
]
=
cur_max
<
-
64
?
-
64
:
cur_max
;
}
}
// Make sure that BlockDim <= feature_size
template
<
typename
T
,
int
BlockDim
>
__global__
void
RowReductionForDiffMaxSum
(
const
T
*
logits_data
,
T
*
max_data
,
T
*
softmax
,
int
feature_size
)
{
__shared__
BlockReduceTempStorage
<
T
,
BlockDim
>
temp_storage
;
auto
beg_idx
=
feature_size
*
blockIdx
.
x
+
threadIdx
.
x
;
auto
end_idx
=
feature_size
*
(
blockIdx
.
x
+
1
);
auto
block_max
=
max_data
[
blockIdx
.
x
];
softmax
[
beg_idx
]
=
logits_data
[
beg_idx
]
-
block_max
;
T
diff_max_sum
=
real_exp
(
softmax
[
beg_idx
]);
beg_idx
+=
BlockDim
;
while
(
beg_idx
<
end_idx
)
{
softmax
[
beg_idx
]
=
logits_data
[
beg_idx
]
-
block_max
;
diff_max_sum
+=
real_exp
(
softmax
[
beg_idx
]);
beg_idx
+=
BlockDim
;
}
diff_max_sum
=
BlockReduce
<
T
,
BlockDim
>
(
temp_storage
).
Reduce
(
diff_max_sum
,
cub
::
Sum
());
if
(
threadIdx
.
x
==
0
)
max_data
[
blockIdx
.
x
]
=
real_log
(
diff_max_sum
);
}
// Make sure that BlockDim <= feature_size
template
<
typename
T
,
int
BlockDim
>
__global__
void
RowReductionForSoftmaxAndCrossEntropy
(
const
T
*
logits_data
,
const
T
*
labels_data
,
T
*
loss_data
,
T
*
softmax
,
int
feature_size
)
{
__shared__
BlockReduceTempStorage
<
T
,
BlockDim
>
temp_storage
;
auto
beg_idx
=
feature_size
*
blockIdx
.
x
+
threadIdx
.
x
;
auto
end_idx
=
feature_size
*
(
blockIdx
.
x
+
1
);
// log_diff_max_sum shares memory with loss
auto
block_log_diff_max_sum
=
loss_data
[
blockIdx
.
x
];
auto
tmp
=
softmax
[
beg_idx
]
-
block_log_diff_max_sum
;
softmax
[
beg_idx
]
=
real_exp
(
tmp
);
auto
loss
=
-
labels_data
[
beg_idx
]
*
tmp
;
beg_idx
+=
BlockDim
;
while
(
beg_idx
<
end_idx
)
{
tmp
=
softmax
[
beg_idx
]
-
block_log_diff_max_sum
;
softmax
[
beg_idx
]
=
real_exp
(
tmp
);
loss
-=
(
labels_data
[
beg_idx
]
*
tmp
);
beg_idx
+=
BlockDim
;
}
loss
=
BlockReduce
<
T
,
BlockDim
>
(
temp_storage
).
Reduce
(
loss
,
cub
::
Sum
());
if
(
threadIdx
.
x
==
0
)
loss_data
[
blockIdx
.
x
]
=
loss
;
}
template
<
typename
T
>
__global__
void
SetSoftmaxToOneWhenFeatureSizeIsOne
(
T
*
out
,
int
batch_size
)
{
auto
idx
=
threadIdx
.
x
+
blockIdx
.
x
*
blockDim
.
x
;
if
(
idx
<
batch_size
)
out
[
idx
]
=
static_cast
<
T
>
(
1
);
}
template
<
typename
T
>
static
void
SoftmaxWithCrossEntropyFusedKernel
(
const
T
*
logits_data
,
const
T
*
labels_data
,
T
*
softmax_data
,
T
*
loss_data
,
int
batch_size
,
int
feature_size
,
cudaStream_t
stream
)
{
constexpr
int
kMaxBlockDim
=
512
;
int
block_dim
=
feature_size
>=
kMaxBlockDim
?
kMaxBlockDim
:
(
1
<<
static_cast
<
int
>
(
std
::
log2
(
feature_size
)));
#define CALL_SOFTMAX_WITH_CROSS_ENTROPY_FUSED_KERNEL(BlockDim) \
case BlockDim: \
RowReductionForMax<T, BlockDim><<<batch_size, BlockDim, 0, stream>>>( \
logits_data, loss_data, feature_size); \
RowReductionForDiffMaxSum<T, \
BlockDim><<<batch_size, BlockDim, 0, stream>>>( \
logits_data, loss_data, softmax_data, feature_size); \
RowReductionForSoftmaxAndCrossEntropy< \
T, BlockDim><<<batch_size, BlockDim, 0, stream>>>( \
logits_data, labels_data, loss_data, softmax_data, feature_size); \
break
switch
(
block_dim
)
{
CALL_SOFTMAX_WITH_CROSS_ENTROPY_FUSED_KERNEL
(
512
);
CALL_SOFTMAX_WITH_CROSS_ENTROPY_FUSED_KERNEL
(
256
);
CALL_SOFTMAX_WITH_CROSS_ENTROPY_FUSED_KERNEL
(
128
);
CALL_SOFTMAX_WITH_CROSS_ENTROPY_FUSED_KERNEL
(
64
);
CALL_SOFTMAX_WITH_CROSS_ENTROPY_FUSED_KERNEL
(
32
);
CALL_SOFTMAX_WITH_CROSS_ENTROPY_FUSED_KERNEL
(
16
);
CALL_SOFTMAX_WITH_CROSS_ENTROPY_FUSED_KERNEL
(
8
);
CALL_SOFTMAX_WITH_CROSS_ENTROPY_FUSED_KERNEL
(
4
);
CALL_SOFTMAX_WITH_CROSS_ENTROPY_FUSED_KERNEL
(
2
);
case
1
:
SetSoftmaxToOneWhenFeatureSizeIsOne
<<<
(
batch_size
+
kMaxBlockDim
-
1
)
/
kMaxBlockDim
,
kMaxBlockDim
,
0
,
stream
>>>
(
softmax_data
,
batch_size
);
cudaMemsetAsync
(
loss_data
,
0
,
batch_size
,
stream
);
break
;
default:
PADDLE_THROW
(
"BlockDim must be 2^n in softmax_with_cross_entropy_op"
);
break
;
}
#undef CALL_SOFTMAX_WITH_CROSS_ENTROPY_FUSED_KERNEL
}
template
<
typename
T
>
template
<
typename
T
>
class
SoftmaxWithCrossEntropyCUDAKernel
:
public
framework
::
OpKernel
<
T
>
{
class
SoftmaxWithCrossEntropyCUDAKernel
:
public
framework
::
OpKernel
<
T
>
{
public:
public:
...
@@ -66,14 +256,24 @@ class SoftmaxWithCrossEntropyCUDAKernel : public framework::OpKernel<T> {
...
@@ -66,14 +256,24 @@ class SoftmaxWithCrossEntropyCUDAKernel : public framework::OpKernel<T> {
Tensor
*
softmax
=
context
.
Output
<
Tensor
>
(
"Softmax"
);
Tensor
*
softmax
=
context
.
Output
<
Tensor
>
(
"Softmax"
);
Tensor
*
loss
=
context
.
Output
<
Tensor
>
(
"Loss"
);
Tensor
*
loss
=
context
.
Output
<
Tensor
>
(
"Loss"
);
softmax
->
mutable_data
<
T
>
(
context
.
GetPlace
());
auto
*
softmax_data
=
softmax
->
mutable_data
<
T
>
(
context
.
GetPlace
());
loss
->
mutable_data
<
T
>
(
context
.
GetPlace
());
auto
*
loss_data
=
loss
->
mutable_data
<
T
>
(
context
.
GetPlace
());
math
::
SoftmaxFunctor
<
platform
::
CUDADeviceContext
,
T
>
()(
auto
soft_label
=
context
.
Attr
<
bool
>
(
"soft_label"
);
context
.
cuda_device_context
(),
logits
,
softmax
);
if
(
soft_label
)
{
int
batch_size
=
logits
->
dims
()[
0
];
int
feature_size
=
logits
->
dims
()[
1
];
auto
*
logits_data
=
logits
->
data
<
T
>
();
auto
*
labels_data
=
labels
->
data
<
T
>
();
SoftmaxWithCrossEntropyFusedKernel
(
logits_data
,
labels_data
,
softmax_data
,
loss_data
,
batch_size
,
feature_size
,
context
.
cuda_device_context
().
stream
());
}
else
{
math
::
SoftmaxCUDNNFunctor
<
T
>
()(
context
.
cuda_device_context
(),
logits
,
softmax
);
math
::
CrossEntropyFunctor
<
platform
::
CUDADeviceContext
,
T
>
()(
math
::
CrossEntropyFunctor
<
platform
::
CUDADeviceContext
,
T
>
()(
context
.
cuda_device_context
(),
loss
,
softmax
,
labels
,
context
.
cuda_device_context
(),
loss
,
softmax
,
labels
,
false
);
context
.
Attr
<
bool
>
(
"soft_label"
));
}
}
}
};
};
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录