Commit ef61da86 (unverified), authored on Sep 08, 2021 by Li Min, committed via GitHub on Sep 08, 2021.
Refactor softmax_cudnn kernel impl for code reuse. (#35350)
Parent commit: 1159f753
Showing 3 changed files with 649 additions and 644 deletions (+649, -644).
paddle/fluid/operators/softmax_cudnn_op.cu      +7    -597
paddle/fluid/operators/softmax_cudnn_op.cu.h    +642  -0
paddle/fluid/operators/softmax_impl.cuh         +0    -47
paddle/fluid/operators/softmax_cudnn_op.cu
@@ -13,438 +13,11 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/amp/fp16_type_traits.h"
#include "paddle/fluid/operators/math/math_cuda_utils.h"
#include "paddle/fluid/operators/softmax_impl.cuh"
#include "paddle/fluid/operators/softmax_op.h"
#include "paddle/fluid/platform/cuda_device_function.h"
#ifdef PADDLE_WITH_HIP
#include "paddle/fluid/platform/miopen_helper.h"
#else
#include "paddle/fluid/platform/cudnn_helper.h"
#endif
namespace paddle {
namespace platform {
struct CUDAPlace;
struct float16;
}  // namespace platform
}  // namespace paddle
#include "paddle/fluid/operators/softmax_cudnn_op.cu.h"
namespace paddle {
namespace operators {

[Deleted in this hunk: the ScopedTensorDescriptor/DataLayout/Tensor aliases, the VecT4/VecT2 vectorization traits, log2_ceil, the WarpSoftmaxForward and WarpSoftmaxBackward kernels, the SOFTMAX_WARP_FORWARD_CASE / SOFTMAX_WARP_BACKWARD_CASE launch macros, and the SwitchWarpSoftmaxForward / SwitchWarpSoftmaxBackward wrappers. All of this moves essentially unchanged into the new header paddle/fluid/operators/softmax_cudnn_op.cu.h, reproduced in full below; the only notable difference there is that the launch macros and switch wrappers take a platform::CUDADeviceContext directly (launching on dev_ctx.stream()) instead of going through the framework::ExecutionContext.]
template <typename T, bool LogMode = false>
class SoftmaxCUDNNKernel : public framework::OpKernel<T> {
 public:
@@ -452,92 +25,10 @@ class SoftmaxCUDNNKernel : public framework::OpKernel<T> {
    auto* x = ctx.Input<Tensor>("X");
    auto* out = ctx.Output<Tensor>("Out");
    out->mutable_data<T>(ctx.GetPlace());

[Deleted in this hunk: the roughly 80 lines that computed out_data, dims, rank, axis, dim, N and D and chose between the WarpSoftmaxForward launch and the cuDNN/MIOpen softmax call. That dispatch logic now lives in SoftmaxForwardCUDAKernelDriver in softmax_cudnn_op.cu.h, shown below. The body is replaced by:]

    int input_axis = ctx.Attr<int>("axis");
    auto& dev_ctx =
        ctx.template device_context<platform::CUDADeviceContext>();
    SoftmaxForwardCUDAKernelDriver<T>(dev_ctx, *x, input_axis, out);
  }
};
@@ -549,91 +40,10 @@ class SoftmaxGradCUDNNKernel : public framework::OpKernel<T> {
    auto* dout = ctx.Input<Tensor>(framework::GradVarName("Out"));
    auto* dx = ctx.Output<Tensor>(framework::GradVarName("X"));
    dx->mutable_data<T>(ctx.GetPlace());

[Deleted in this hunk: the corresponding ~80-line dispatch between the WarpSoftmaxBackward launch and the cuDNN/MIOpen softmax-backward call; it now lives in SoftmaxBackwardCUDAKernelDriver in softmax_cudnn_op.cu.h, shown below. The body is replaced by:]

    int input_axis = ctx.Attr<int>("axis");
    auto& dev_ctx =
        ctx.template device_context<platform::CUDADeviceContext>();
    SoftmaxBackwardCUDAKernelDriver<T>(dev_ctx, *out, *dout, input_axis, dx);
  }
};
paddle/fluid/operators/softmax_cudnn_op.cu.h (new file, mode 100644)
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/fluid/operators/amp/fp16_type_traits.h"
#include "paddle/fluid/operators/math/math_cuda_utils.h"
#include "paddle/fluid/operators/softmax_op.h"
#include "paddle/fluid/platform/cuda_device_function.h"
#ifdef PADDLE_WITH_HIP
#include "paddle/fluid/platform/miopen_helper.h"
#else
#include "paddle/fluid/platform/cudnn_helper.h"
#endif
namespace paddle {
namespace operators {

using ScopedTensorDescriptor = platform::ScopedTensorDescriptor;
using DataLayout = platform::DataLayout;
using Tensor = framework::Tensor;
// Vectorization trait 4 * sizeof(T)
template <typename T>
class VecT4 {};
template <>
class VecT4<double> {
 public:
  using Type = long4;
};
template <>
class VecT4<float> {
 public:
  using Type = int4;
};
template <>
class VecT4<platform::float16> {
 public:
  using Type = int2;
};

// Vectorization trait 2 * sizeof(T)
template <typename T>
class VecT2 {};
template <>
class VecT2<double> {
 public:
  using Type = int4;
};
template <>
class VecT2<float> {
 public:
  using Type = int2;
};
template <>
class VecT2<platform::float16> {
 public:
  using Type = int;
};
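The trait classes above simply pick a plain CUDA built-in vector type whose byte width equals four (VecT4) or two (VecT2) elements of T, so the kernels can issue 16- or 8-byte vectorized loads and stores and then reinterpret the payload back to T. A quick size check (editor's illustration, not part of the commit; the double/long4 pairing additionally assumes an LP64 platform where sizeof(long) == 8):

static_assert(sizeof(int4) == 4 * sizeof(float), "VecT4<float> carries 4 floats");
static_assert(sizeof(int2) == 4 * sizeof(paddle::platform::float16), "VecT4<float16> carries 4 halves");
static_assert(sizeof(int2) == 2 * sizeof(float), "VecT2<float> carries 2 floats");
static_assert(sizeof(int) == 2 * sizeof(paddle::platform::float16), "VecT2<float16> carries 2 halves");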
static inline int log2_ceil(int value) {
  int log2_value = 0;
  while ((1 << log2_value) < value) ++log2_value;
  return log2_value;
}
template <typename T, int BatchSize, int WarpSize>
__device__ __forceinline__ void WarpReduceSum(T* sum) {
#pragma unroll
  for (int offset = WarpSize / 2; offset > 0; offset /= 2) {
#pragma unroll
    for (int i = 0; i < BatchSize; ++i) {
      T sum_val = platform::CudaShuffleXorSync(0xFFFFFFFF, sum[i], offset);
      sum[i] = sum[i] + sum_val;
    }
  }
}

template <typename T, int BatchSize, int WarpSize>
__device__ __forceinline__ void WarpReduceMax(T* sum) {
#pragma unroll
  for (int offset = WarpSize / 2; offset > 0; offset /= 2) {
#pragma unroll
    for (int i = 0; i < BatchSize; ++i) {
      T max_val = platform::CudaShuffleXorSync(0xFFFFFFFF, sum[i], offset);
      sum[i] = max(sum[i], max_val);
    }
  }
}
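To make the compile-time constants in the kernels below concrete, here is the arithmetic for one illustrative case (editor's example, not part of the commit): a row of dim = 320 float elements dispatched with float4-width vectorization (kVSize = 4) in WarpSoftmaxForward:

  Log2Elements = log2_ceil(320) = 9, so kDimCeil = 1 << 9 = 512
  kWarpSize    = 32                  (kDimCeil >= 32)
  kIterations  = 512 / 32 = 16       scalar elements handled per thread
  kIterationsV = 16 / 4 = 4          vectorized loads per thread
  kBatchSize   = 1                   (kDimCeil > 32, so one row per warp)

Lanes that fall past the real 320-element row read -infinity in the forward kernel (and zeros in the backward kernel), so the padding up to kDimCeil does not change the reductions.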
/*
  Core function of computing softmax forward for axis=-1.
  The computation includes:
    - Compute the maximum of each batch row: maxvalue_{i} = max_j src_{i,j}
    - Compute the sum of exponentials: s_{i} = sum_{j} exp(src_{i,j} - maxvalue_{i})
    - Compute the output: exp(src_{i,j} - maxvalue_{i}) / s_{i}
  One warp (32 threads) computes 1 or 2 batch rows (kBatchSize).
  For the max (sum) reduction, each thread first reduces the elements it holds,
  then the warp-wide max (sum) is combined with the shuffle API.
*/
template <typename T, typename VecT, typename AccT, int Log2Elements,
          bool LogMode = false>
__global__ void WarpSoftmaxForward(T* softmax, const T* src,
                                   const int batch_size, const int stride,
                                   const int element_count) {
  constexpr int kDimCeil = 1 << Log2Elements;
  constexpr int kWarpSize = (kDimCeil < 32) ? kDimCeil : 32;
  constexpr int kVSize = sizeof(VecT) / sizeof(T);
  constexpr int kIterations = kDimCeil / kWarpSize;
  constexpr int kIterationsV =
      (kIterations >= kVSize) ? (kIterations / kVSize) : 1;
  constexpr int kBatchSize = (kDimCeil <= 32) ? 2 : 1;

  int first_batch = (blockDim.y * blockIdx.x + threadIdx.y) * kBatchSize;

  // max index to read
  int idx_max_v[kBatchSize];
#pragma unroll
  for (int i = 0; i < kBatchSize; i++) {
    int idx_max = ((i + first_batch) < batch_size) ? element_count : 0;
    idx_max_v[i] = idx_max / kVSize;
  }

  // read data from global memory
  AccT srcdata[kBatchSize][kIterationsV][kVSize];
#pragma unroll
  for (int i = 0; i < kBatchSize; ++i) {
// read data
#pragma unroll
    for (int it = 0; it < kIterationsV; ++it) {
      int src_idx = threadIdx.x + it * kWarpSize;
      if (kVSize == 1) {
        if (src_idx < idx_max_v[i]) {
          srcdata[i][it][0] =
              static_cast<AccT>(src[(first_batch + i) * stride + src_idx]);
        } else {
          srcdata[i][it][0] = -std::numeric_limits<AccT>::infinity();
        }
      } else {
        const VecT* src_v =
            reinterpret_cast<const VecT*>(&src[(first_batch + i) * stride]);
        if (src_idx < idx_max_v[i]) {
          VecT srctmp = src_v[src_idx];
          const T* srcinptr = reinterpret_cast<const T*>(&srctmp);
#pragma unroll
          for (int s = 0; s < kVSize; s++) {
            srcdata[i][it][s] = static_cast<AccT>(srcinptr[s]);
          }
        } else {
#pragma unroll
          for (int s = 0; s < kVSize; s++) {
            srcdata[i][it][s] = -std::numeric_limits<AccT>::infinity();
          }
        }
      }
    }
  }

  // compute max value
  AccT max_value[kBatchSize];
#pragma unroll
  for (int i = 0; i < kBatchSize; ++i) {
    // it = 0
    AccT valmax = srcdata[i][0][0];
#pragma unroll
    for (int s = 1; s < kVSize; ++s) {
      valmax = (valmax > srcdata[i][0][s]) ? valmax : srcdata[i][0][s];
    }
    max_value[i] = valmax;

// it = 1, 2, ...
#pragma unroll
    for (int it = 1; it < kIterationsV; ++it) {
      AccT valmax = srcdata[i][it][0];
#pragma unroll
      for (int s = 1; s < kVSize; ++s) {
        valmax = (valmax > srcdata[i][it][s]) ? valmax : srcdata[i][it][s];
      }
      max_value[i] = (max_value[i] > valmax) ? max_value[i] : valmax;
    }
  }
  WarpReduceMax<AccT, kBatchSize, kWarpSize>(max_value);

  // compute sum
  AccT sum[kBatchSize];
#pragma unroll
  for (int i = 0; i < kBatchSize; ++i) {
    // it = 0
    if (LogMode) {
      sum[i] = std::exp(srcdata[i][0][0] - max_value[i]);
    } else {
      srcdata[i][0][0] = std::exp(srcdata[i][0][0] - max_value[i]);
      sum[i] = srcdata[i][0][0];
    }
#pragma unroll
    for (int s = 1; s < kVSize; ++s) {
      if (LogMode) {
        sum[i] += std::exp(srcdata[i][0][s] - max_value[i]);
      } else {
        srcdata[i][0][s] = std::exp(srcdata[i][0][s] - max_value[i]);
        sum[i] += srcdata[i][0][s];
      }
    }

// it = 1, 2, ...
#pragma unroll
    for (int it = 1; it < kIterationsV; ++it) {
#pragma unroll
      for (int s = 0; s < kVSize; ++s) {
        if (LogMode) {
          sum[i] += std::exp(srcdata[i][it][s] - max_value[i]);
        } else {
          srcdata[i][it][s] = std::exp(srcdata[i][it][s] - max_value[i]);
          sum[i] += srcdata[i][it][s];
        }
      }
    }
  }
  WarpReduceSum<AccT, kBatchSize, kWarpSize>(sum);

// write result to global memory
#pragma unroll
  for (int i = 0; i < kBatchSize; ++i) {
    if (LogMode) {
      sum[i] = std::log(sum[i]);
    }
#pragma unroll
    for (int it = 0; it < kIterationsV; ++it) {
      int idx = threadIdx.x + it * kWarpSize;
      if (kVSize == 1) {
        if (idx < idx_max_v[i]) {
          if (LogMode) {
            softmax[(first_batch + i) * stride + idx] =
                srcdata[i][it][0] - max_value[i] - sum[i];
          } else {
            softmax[(first_batch + i) * stride + idx] =
                srcdata[i][it][0] / sum[i];
          }
        } else {
          break;
        }
      } else {
        VecT* softmax_v =
            reinterpret_cast<VecT*>(&softmax[(first_batch + i) * stride]);
        VecT tmpdata;
        T* tmpptr = reinterpret_cast<T*>(&tmpdata);
#pragma unroll
        for (int s = 0; s < kVSize; ++s) {
          if (LogMode) {
            tmpptr[s] = srcdata[i][it][s] - max_value[i] - sum[i];
          } else {
            tmpptr[s] = srcdata[i][it][s] / sum[i];
          }
        }
        if (idx < idx_max_v[i]) {
          softmax_v[idx] = tmpdata;
        } else {
          break;
        }
      }
    }
  }
}
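For reference, the math the forward kernel implements per row \(i\) (editor's restatement of the comment above):

\[
m_i = \max_j x_{ij}, \qquad s_i = \sum_j e^{x_{ij} - m_i},
\]
\[
\text{softmax: } y_{ij} = \frac{e^{x_{ij} - m_i}}{s_i}, \qquad
\text{log-softmax (LogMode): } y_{ij} = (x_{ij} - m_i) - \log s_i .
\]

Subtracting the per-row maximum before exponentiating is the standard trick that keeps the exponentials from overflowing; in LogMode the kernel keeps the shifted values and takes a single log of the sum at write-back time.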
/*
  Core function of computing softmax backward for axis=-1.
  The computation includes:
    - Compute the per-row sum: s_{i} = sum_{j} src_{i,j} * grad_{i,j}
    - Compute the gradient: src_{i,j} * (grad_{i,j} - s_{i})
  One warp (32 threads) computes 1 or 2 batch rows (kBatchSize).
  The sum reduction is done per thread first, then combined across the warp
  with the shuffle API.
*/
template <typename T, typename VecT, typename AccT, int Log2Elements,
          bool LogMode = false>
__global__ void WarpSoftmaxBackward(T* dst, const T* grad, const T* src,
                                    int batch_size, int stride,
                                    int element_count) {
  constexpr int kVSize = sizeof(VecT) / sizeof(T);
  constexpr int kDimCeil = 1 << Log2Elements;
  constexpr int kWarpSize = (kDimCeil < 32) ? kDimCeil : 32;
  constexpr int kIterations = kDimCeil / kWarpSize;
  constexpr int kBatchSize = (kDimCeil <= 128) ? 2 : 1;
  constexpr int kIterationsV =
      (kIterations >= kVSize) ? (kIterations / kVSize) : 1;
  int element_count_v = element_count / kVSize;
  int first_batch = (blockDim.y * blockIdx.x + threadIdx.y) * kBatchSize;
  int local_batches = batch_size - first_batch;
  if (local_batches > kBatchSize) {
    local_batches = kBatchSize;
  }

  // read data from global memory
  VecT src_reg[kBatchSize][kIterationsV];
  VecT grad_reg[kBatchSize][kIterationsV];

  for (int i = 0; i < kBatchSize; ++i) {
    const VecT* src_v =
        reinterpret_cast<const VecT*>(&src[(first_batch + i) * stride]);
    const VecT* grad_v =
        reinterpret_cast<const VecT*>(&grad[(first_batch + i) * stride]);
    // max index to read
    int idx_max = (i < local_batches) ? element_count : 0;
    int idx_max_v = idx_max / kVSize;

    // read data
    for (int it = 0; it < kIterationsV; ++it) {
      int src_idx = threadIdx.x + it * kWarpSize;
      if (src_idx < idx_max_v) {
        src_reg[i][it] = src_v[src_idx];
        grad_reg[i][it] = grad_v[src_idx];
      } else {
#pragma unroll
        for (int s = 0; s < kVSize; s++) {
          reinterpret_cast<T*>(&src_reg[i][it])[s] = 0.0;
          reinterpret_cast<T*>(&grad_reg[i][it])[s] = 0.0;
        }
      }
    }
  }

  // compute sum
  AccT sum[kBatchSize]{0.0};
#pragma unroll
  for (int i = 0; i < kBatchSize; ++i) {
#pragma unroll
    for (int it = 0; it < kIterationsV; ++it) {
      T* gradptr = reinterpret_cast<T*>(&grad_reg[i][it]);
      T* srcptr = reinterpret_cast<T*>(&src_reg[i][it]);
#pragma unroll
      for (int s = 0; s < kVSize; ++s) {
        if (LogMode) {
          sum[i] += static_cast<AccT>(gradptr[s]);
        } else {
          sum[i] += static_cast<AccT>(gradptr[s] * srcptr[s]);
        }
      }
    }
  }
  WarpReduceSum<AccT, kBatchSize, kWarpSize>(sum);

// write result
#pragma unroll
  for (int i = 0; i < kBatchSize; ++i) {
    if (i >= local_batches) break;

    VecT* dst_v = reinterpret_cast<VecT*>(&dst[(first_batch + i) * stride]);

    // max index to write
    int idx_max = (i < local_batches) ? element_count : 0;
    int idx_max_v = idx_max / kVSize;

#pragma unroll
    for (int it = 0; it < kIterationsV; ++it) {
      VecT tmpdata;
      T* tmpptr = reinterpret_cast<T*>(&tmpdata);
      T* gradptr = reinterpret_cast<T*>(&grad_reg[i][it]);
      T* srcptr = reinterpret_cast<T*>(&src_reg[i][it]);
#pragma unroll
      for (int s = 0; s < kVSize; ++s) {
        if (LogMode) {
          tmpptr[s] = static_cast<AccT>(gradptr[s]) -
                      std::exp(static_cast<AccT>(srcptr[s])) * sum[i];
        } else {
          tmpptr[s] = static_cast<AccT>(srcptr[s]) *
                      (static_cast<AccT>(gradptr[s]) - sum[i]);
        }
      }
      int idx = threadIdx.x + it * kWarpSize;
      if (idx < idx_max_v) {
        dst_v[idx] = tmpdata;
      }
    }
  }
}
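Likewise, the backward kernel computes, per row \(i\) (editor's restatement; \(y\) is the forward output passed in as src, \(g\) the upstream gradient passed in as grad):

\[
\text{softmax: } s_i = \sum_j g_{ij}\, y_{ij}, \qquad
\frac{\partial L}{\partial x_{ij}} = y_{ij}\left(g_{ij} - s_i\right),
\]
\[
\text{log-softmax (LogMode): } s_i = \sum_j g_{ij}, \qquad
\frac{\partial L}{\partial x_{ij}} = g_{ij} - e^{y_{ij}}\, s_i ,
\]

where in LogMode \(e^{y_{ij}}\) recovers the softmax value from the stored log-softmax output.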
#define SOFTMAX_WARP_FORWARD_CASE(Log2Elements, AccT) \
  case Log2Elements: \
    WarpSoftmaxForward<T, VecT, AccT, Log2Elements, \
                       LogMode><<<blocks, threads, 0, dev_ctx.stream()>>>( \
        dst, src, batch_size, stride, element_count); \
    break;
/*
Wrapper of softmax forward with template instantiation on the size of the input.
*/
template <typename T, typename VecT, bool LogMode>
void SwitchWarpSoftmaxForward(const int blocks, const dim3 threads,
                              const platform::CUDADeviceContext& dev_ctx,
                              T* dst, const T* src, const int batch_size,
                              const int stride, const int element_count,
                              int Log2Elements) {
  using AccT = typename details::MPTypeTrait<T>::Type;
  switch (Log2Elements) {
    SOFTMAX_WARP_FORWARD_CASE(0, AccT);
    SOFTMAX_WARP_FORWARD_CASE(1, AccT);
    SOFTMAX_WARP_FORWARD_CASE(2, AccT);
    SOFTMAX_WARP_FORWARD_CASE(3, AccT);
    SOFTMAX_WARP_FORWARD_CASE(4, AccT);
    SOFTMAX_WARP_FORWARD_CASE(5, AccT);
    SOFTMAX_WARP_FORWARD_CASE(6, AccT);
    SOFTMAX_WARP_FORWARD_CASE(7, AccT);
    SOFTMAX_WARP_FORWARD_CASE(8, AccT);
    SOFTMAX_WARP_FORWARD_CASE(9, AccT);
    default:
      break;
  }
}
#define SOFTMAX_WARP_BACKWARD_CASE(Log2Elements, AccT) \
  case Log2Elements: \
    WarpSoftmaxBackward<T, VecT, AccT, Log2Elements, \
                        LogMode><<<blocks, threads, 0, dev_ctx.stream()>>>( \
        dst, grad, src, batch_size, stride, element_count); \
    break;
/*
Wrapper of softmax backward with template instantiation on the size of the input.
*/
template <typename T, typename VecT, bool LogMode>
void SwitchWarpSoftmaxBackward(const int blocks, const dim3 threads,
                               const platform::CUDADeviceContext& dev_ctx,
                               T* dst, const T* grad, const T* src,
                               const int batch_size, const int stride,
                               const int element_count, int Log2Elements) {
  using AccT = typename details::MPTypeTrait<T>::Type;
  switch (Log2Elements) {
    SOFTMAX_WARP_BACKWARD_CASE(0, AccT);
    SOFTMAX_WARP_BACKWARD_CASE(1, AccT);
    SOFTMAX_WARP_BACKWARD_CASE(2, AccT);
    SOFTMAX_WARP_BACKWARD_CASE(3, AccT);
    SOFTMAX_WARP_BACKWARD_CASE(4, AccT);
    SOFTMAX_WARP_BACKWARD_CASE(5, AccT);
    SOFTMAX_WARP_BACKWARD_CASE(6, AccT);
    SOFTMAX_WARP_BACKWARD_CASE(7, AccT);
    SOFTMAX_WARP_BACKWARD_CASE(8, AccT);
    SOFTMAX_WARP_BACKWARD_CASE(9, AccT);
    default:
      break;
  }
}
#undef SOFTMAX_WARP_FORWARD_CASE
#undef SOFTMAX_WARP_BACKWARD_CASE
template <typename T, bool LogMode = false>
void SoftmaxForwardCUDAKernelDriver(const platform::CUDADeviceContext& dev_ctx,
                                    const Tensor& x, const int input_axis,
                                    Tensor* out) {
  auto* out_data = out->data<T>();

  auto dims = x.dims();
  const int rank = dims.size();
  const int axis = CanonicalAxis(input_axis, rank);
  const int dim = dims[axis];
  const int N = SizeToAxis(axis, dims);
  const int D = SizeOutAxis(axis, dims);

  constexpr int max_dim = 320;
  constexpr int warps_per_block = 4;

  if (D == 1 && dim <= max_dim && sizeof(T) <= 4) {
    const int kDimLog2 = static_cast<int>(log2_ceil(dim));
    const int kDimCeil = 1 << kDimLog2;
    int kWarpSize = (kDimCeil < 32) ? kDimCeil : 32;
    int batches_per_warp = (kDimCeil <= 32) ? 2 : 1;
    // use 128 threads per block to maximize gpu utilization
    constexpr int threads_per_block = 128;

    int warps_per_block = (threads_per_block / kWarpSize);
    int batches_per_block = warps_per_block * batches_per_warp;
    int blocks = (N + batches_per_block - 1) / batches_per_block;
    dim3 threads(kWarpSize, warps_per_block, 1);

    // vectorization read/write
    using T4 = typename VecT4<T>::Type;
    using T2 = typename VecT2<T>::Type;
    if (dim % 4 == 0) {
      SwitchWarpSoftmaxForward<T, T4, LogMode>(blocks, threads, dev_ctx,
                                               out_data, x.data<T>(), N, dim,
                                               dim, kDimLog2);
    } else if (dim % 2 == 0) {
      SwitchWarpSoftmaxForward<T, T2, LogMode>(blocks, threads, dev_ctx,
                                               out_data, x.data<T>(), N, dim,
                                               dim, kDimLog2);
    } else {
      SwitchWarpSoftmaxForward<T, T, LogMode>(blocks, threads, dev_ctx,
                                              out_data, x.data<T>(), N, dim,
                                              dim, kDimLog2);
    }
  } else {
    ScopedTensorDescriptor desc;
    std::vector<int> tensor_dims = {N, dim, D, 1};
    DataLayout layout = DataLayout::kNCHW;
#ifdef PADDLE_WITH_HIP
    miopenTensorDescriptor_t desc_ = desc.descriptor<T>(layout, tensor_dims);
#else
    cudnnTensorDescriptor_t desc_ = desc.descriptor<T>(layout, tensor_dims);
#endif

    auto handle = dev_ctx.cudnn_handle();

#ifdef PADDLE_WITH_HIP
    auto mode = axis == rank - 1 ? MIOPEN_SOFTMAX_MODE_INSTANCE
                                 : MIOPEN_SOFTMAX_MODE_CHANNEL;
    if (LogMode) {
      PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::miopenSoftmaxForward_V2(
          handle, platform::CudnnDataType<T>::kOne(), desc_, x.data<T>(),
          platform::CudnnDataType<T>::kZero(), desc_, out_data,
          MIOPEN_SOFTMAX_LOG, mode));
    } else {
      PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::miopenSoftmaxForward_V2(
          handle, platform::CudnnDataType<T>::kOne(), desc_, x.data<T>(),
          platform::CudnnDataType<T>::kZero(), desc_, out_data,
          MIOPEN_SOFTMAX_ACCURATE, mode));
    }
#else
    auto mode = axis == rank - 1 ? CUDNN_SOFTMAX_MODE_INSTANCE
                                 : CUDNN_SOFTMAX_MODE_CHANNEL;
    if (LogMode) {
      PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cudnnSoftmaxForward(
          handle, CUDNN_SOFTMAX_LOG, mode, platform::CudnnDataType<T>::kOne(),
          desc_, x.data<T>(), platform::CudnnDataType<T>::kZero(), desc_,
          out_data));
    } else {
      PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cudnnSoftmaxForward(
          handle, CUDNN_SOFTMAX_ACCURATE, mode,
          platform::CudnnDataType<T>::kOne(), desc_, x.data<T>(),
          platform::CudnnDataType<T>::kZero(), desc_, out_data));
    }
#endif
  }
}
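The dispatch rule above is worth spelling out (editor's note): the hand-written warp kernels are used only when the softmax axis is the innermost one (D == 1), the reduced dimension is small (dim <= max_dim = 320) and T is at most 4 bytes (float or float16); everything else falls back to cuDNN/MIOpen. For example, softmax over the last axis of a [128, 256] float tensor gives N = 128, dim = 256, D = 1 and takes the warp path (with 4-wide vectorization, since 256 % 4 == 0), while a [128, 1024] float tensor (dim = 1024 > 320), a double input, or softmax over a non-innermost axis (D > 1) goes through cudnnSoftmaxForward / miopenSoftmaxForward_V2.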
template <typename T, bool LogMode = false>
void SoftmaxBackwardCUDAKernelDriver(
    const platform::CUDADeviceContext& dev_ctx, const Tensor& out,
    const Tensor& dout, const int input_axis, Tensor* dx) {
  auto* dx_data = dx->data<T>();

  auto dims = out.dims();
  const int rank = dims.size();
  const int axis = CanonicalAxis(input_axis, rank);
  const int dim = dims[axis];
  const int N = SizeToAxis(axis, dims);
  const int D = SizeOutAxis(axis, dims);

  constexpr int max_dim = 320;
  constexpr int warps_per_block = 4;

  if (D == 1 && dim <= max_dim && sizeof(T) <= 4) {
    const int kDimLog2 = log2_ceil(dim);
    const int kDimCeil = 1 << kDimLog2;
    int kWarpSize = (kDimCeil < 32) ? kDimCeil : 32;
    int batches_per_warp = (kDimCeil <= 128) ? 2 : 1;
    constexpr int threads_per_block = 128;

    int warps_per_block = (threads_per_block / kWarpSize);
    int batches_per_block = warps_per_block * batches_per_warp;
    int blocks = (N + batches_per_block - 1) / batches_per_block;
    dim3 threads(kWarpSize, warps_per_block, 1);

    // vectorization read/write
    using T4 = typename VecT4<T>::Type;
    using T2 = typename VecT2<T>::Type;
    if (dim % 4 == 0) {
      SwitchWarpSoftmaxBackward<T, T4, LogMode>(blocks, threads, dev_ctx,
                                                dx_data, dout.data<T>(),
                                                out.data<T>(), N, dim, dim,
                                                kDimLog2);
    } else if (dim % 2 == 0) {
      SwitchWarpSoftmaxBackward<T, T2, LogMode>(blocks, threads, dev_ctx,
                                                dx_data, dout.data<T>(),
                                                out.data<T>(), N, dim, dim,
                                                kDimLog2);
    } else {
      SwitchWarpSoftmaxBackward<T, T, LogMode>(blocks, threads, dev_ctx,
                                               dx_data, dout.data<T>(),
                                               out.data<T>(), N, dim, dim,
                                               kDimLog2);
    }
  } else {
    ScopedTensorDescriptor desc;
    std::vector<int> tensor_dims = {N, dim, D, 1};
    DataLayout layout = DataLayout::kNCHW;
#ifdef PADDLE_WITH_HIP
    miopenTensorDescriptor_t desc_ = desc.descriptor<T>(layout, tensor_dims);
#else
    cudnnTensorDescriptor_t desc_ = desc.descriptor<T>(layout, tensor_dims);
#endif

    auto handle = dev_ctx.cudnn_handle();

#ifdef PADDLE_WITH_HIP
    auto mode = axis == rank - 1 ? MIOPEN_SOFTMAX_MODE_INSTANCE
                                 : MIOPEN_SOFTMAX_MODE_CHANNEL;
    if (LogMode) {
      PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::miopenSoftmaxBackward_V2(
          handle, platform::CudnnDataType<T>::kOne(), desc_, out.data<T>(),
          desc_, dout.data<T>(), platform::CudnnDataType<T>::kZero(), desc_,
          dx_data, MIOPEN_SOFTMAX_LOG, mode));
    } else {
      PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::miopenSoftmaxBackward_V2(
          handle, platform::CudnnDataType<T>::kOne(), desc_, out.data<T>(),
          desc_, dout.data<T>(), platform::CudnnDataType<T>::kZero(), desc_,
          dx_data, MIOPEN_SOFTMAX_ACCURATE, mode));
    }
#else
    auto mode = axis == rank - 1 ? CUDNN_SOFTMAX_MODE_INSTANCE
                                 : CUDNN_SOFTMAX_MODE_CHANNEL;
    if (LogMode) {
      PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cudnnSoftmaxBackward(
          handle, CUDNN_SOFTMAX_LOG, mode, platform::CudnnDataType<T>::kOne(),
          desc_, out.data<T>(), desc_, dout.data<T>(),
          platform::CudnnDataType<T>::kZero(), desc_, dx_data));
    } else {
      PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cudnnSoftmaxBackward(
          handle, CUDNN_SOFTMAX_ACCURATE, mode,
          platform::CudnnDataType<T>::kOne(), desc_, out.data<T>(), desc_,
          dout.data<T>(), platform::CudnnDataType<T>::kZero(), desc_,
          dx_data));
    }
#endif
  }
}

}  // namespace operators
}  // namespace paddle
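The point of the new header, per the commit title, is that other CUDA operators can now reuse these drivers instead of re-implementing the dispatch. A minimal sketch of a hypothetical caller (editor's illustration only; the op name and file are not part of this commit, but the two driver signatures are the ones defined above):

#include "paddle/fluid/operators/softmax_cudnn_op.cu.h"

template <typename T>
class MyFusedOpKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    auto* x = ctx.Input<Tensor>("X");
    auto* out = ctx.Output<Tensor>("Out");
    out->mutable_data<T>(ctx.GetPlace());
    auto& dev_ctx =
        ctx.template device_context<platform::CUDADeviceContext>();
    // Softmax over the last axis; pass LogMode=true for log-softmax.
    SoftmaxForwardCUDAKernelDriver<T>(dev_ctx, *x, /*input_axis=*/-1, out);
  }
};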
paddle/fluid/operators/softmax_impl.cuh (deleted; was mode 100755)
[This deleted header contained only the Apache-2.0 license banner, "#pragma once", an include of paddle/fluid/platform/cuda_device_function.h, and, inside namespace paddle::operators, the WarpReduceSum and WarpReduceMax warp-shuffle reductions. Both functions move verbatim into paddle/fluid/operators/softmax_cudnn_op.cu.h, shown above.]