Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
OPTHREE
Paddle
提交
63abd500
P
Paddle
项目概览
OPTHREE
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
体验新版 GitCode,发现更多精彩内容 >>
未验证
提交
63abd500
编写于
4月 14, 2021
作者:
X
xingfeng01
提交者:
GitHub
4月 14, 2021
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
softmax reconstruction and optimization (#31821)
上级
8552a182
变更
3
隐藏空白更改
内联
并排
Showing
3 changed file
with
504 addition
and
354 deletion
+504
-354
paddle/fluid/operators/softmax_cudnn_op.cu
paddle/fluid/operators/softmax_cudnn_op.cu
+449
-354
paddle/fluid/operators/softmax_impl.cuh
paddle/fluid/operators/softmax_impl.cuh
+47
-0
paddle/fluid/operators/softmax_op.h
paddle/fluid/operators/softmax_op.h
+8
-0
未找到文件。
paddle/fluid/operators/softmax_cudnn_op.cu
浏览文件 @
63abd500
...
...
@@ -13,7 +13,9 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/amp/fp16_type_traits.h"
#include "paddle/fluid/operators/math/math_cuda_utils.h"
#include "paddle/fluid/operators/softmax_impl.cuh"
#include "paddle/fluid/operators/softmax_op.h"
#include "paddle/fluid/platform/cuda_device_function.h"
#ifdef PADDLE_WITH_HIP
...
...
@@ -21,7 +23,6 @@ limitations under the License. */
#else
#include "paddle/fluid/platform/cudnn_helper.h"
#endif
#include "paddle/fluid/platform/gpu_launch_config.h"
namespace
paddle
{
namespace
platform
{
...
...
@@ -37,288 +38,414 @@ using ScopedTensorDescriptor = platform::ScopedTensorDescriptor;
using
DataLayout
=
platform
::
DataLayout
;
using
Tensor
=
framework
::
Tensor
;
#define LAUNCH_SOFTMAX_WARP_FORWARD(Log2Elements) \
case Log2Elements: \
WarpSoftmaxForward<T, float, Log2Elements><<< \
blocks, threads, 0, ctx.cuda_device_context().stream()>>>( \
out_data, x->data<T>(), N, dim, dim); \
break;
#define LAUNCH_SOFTMAX_WARP_BACKWARD(Log2Elements) \
case Log2Elements: \
softmax_warp_backward<T, float, Log2Elements><<< \
blocks, threads, 0, ctx.cuda_device_context().stream()>>>( \
dx_data, mul_grad.data<T>(), out->data<T>(), N, dim, dim); \
break;
static
inline
int
SizeOutAxis
(
const
int
axis
,
DDim
dims
)
{
int
size
=
1
;
for
(
int
i
=
axis
+
1
;
i
<
dims
.
size
();
i
++
)
{
size
*=
dims
[
i
];
}
return
size
;
}
int
log2_ceil
(
int
value
)
{
int
log2_value
=
0
;
while
((
1
<<
log2_value
)
<
value
)
++
log2_value
;
return
log2_value
;
}
template
<
typename
T
,
int
VLEN
>
union
vec_t
{
static_assert
(
sizeof
(
T
)
==
-
1
,
"vec_t is only available by specialization."
);
// Vectorization trait 4 * sizeof(T)
template
<
typename
T
>
class
VecT4
{};
template
<
>
class
VecT4
<
double
>
{
public:
using
Type
=
long4
;
};
template
<
>
union
vec_t
<
float
,
4
>
{
float4
s
;
float
v
[
4
];
class
VecT4
<
float
>
{
public:
using
Type
=
int4
;
};
template
<
>
class
VecT4
<
platform
::
float16
>
{
public:
using
Type
=
int2
;
};
// Vectorization trait 2 * sizeof(T)
template
<
typename
T
>
class
VecT2
{};
template
<
>
union
vec_t
<
platform
::
float16
,
4
>
{
int2
s
;
platform
::
float16
v
[
4
];
class
VecT2
<
double
>
{
public:
using
Type
=
int4
;
};
template
<
>
class
VecT2
<
float
>
{
public:
using
Type
=
int2
;
};
template
<
>
class
VecT2
<
platform
::
float16
>
{
public:
using
Type
=
int
;
};
template
<
typename
T
,
typename
VECT
,
int
VPT
,
int
WARP_PER_BLOCK
>
__global__
void
VecSoftmaxForward
(
T
*
dst
,
const
T
*
src
,
const
int
batch_size
,
const
int
softmax_ele
)
{
int
offset
=
blockIdx
.
x
*
softmax_ele
*
WARP_PER_BLOCK
;
int
idx
=
threadIdx
.
x
*
VPT
;
VECT
buf
=
reinterpret_cast
<
const
VECT
*>
(
&
src
[
offset
+
idx
])[
0
];
T
*
bufp
=
reinterpret_cast
<
T
*>
(
&
buf
);
float4
val4
;
float
*
val4p
=
reinterpret_cast
<
float
*>
(
&
val4
);
for
(
int
i
=
0
;
i
<
VPT
;
++
i
)
{
val4p
[
i
]
=
static_cast
<
float
>
(
bufp
[
i
]);
}
float
val
=
val4
.
x
+
val4
.
y
+
val4
.
z
+
val4
.
w
;
float
max_val
=
math
::
warpReduceMax
<
float
>
(
max
(
max
(
val4
.
x
,
val4
.
y
),
max
(
val4
.
z
,
val4
.
w
)),
0xffffffff
);
float4
tmp4
=
make_float4
(
__expf
(
val4
.
x
-
max_val
),
__expf
(
val4
.
y
-
max_val
),
__expf
(
val4
.
z
-
max_val
),
__expf
(
val4
.
w
-
max_val
));
float
*
tmp4p
=
reinterpret_cast
<
float
*>
(
&
tmp4
);
float
invsum
=
1.
f
/
(
math
::
warpReduceSum
<
float
>
(
tmp4
.
x
+
tmp4
.
y
+
tmp4
.
z
+
tmp4
.
w
,
0xffffffff
)
+
1e-6
f
);
for
(
int
i
=
0
;
i
<
VPT
;
++
i
)
{
bufp
[
i
]
=
static_cast
<
T
>
(
tmp4p
[
i
]
*
invsum
);
}
reinterpret_cast
<
VECT
*>
(
&
dst
[
offset
+
idx
])[
0
]
=
buf
;
int
static
inline
log2_ceil
(
int
value
)
{
int
log2_value
=
0
;
while
((
1
<<
log2_value
)
<
value
)
++
log2_value
;
return
log2_value
;
}
template
<
typename
T
,
int
WARP_BATCH
,
int
WARP_SIZE_SOFTMAX
>
__device__
__forceinline__
void
warp_reduce_sum
(
T
*
sum
)
{
/*
Core function of computing softmax forward for axis=-1.
The computation includes
- Compute maximum of batch: maxvalue_{i} = max_j src_{i,j}
- Compute sum of exp batch: s_{i} = sum_{j}{ exp(src_{i,j} - maxvalue_{i} }
- Compute: (a_{i,j} - maxvalue_{i}) / s_{i}
One warp (32 threads) is used to compute 1 or 2 batch (kBatchSize).
For reduction max (sum), firstly compute max (sum) to one warp, then use shuffle
api to compute max (sum) in one warp.
*/
template
<
typename
T
,
typename
VecT
,
typename
AccT
,
int
Log2Elements
,
bool
LogMode
=
false
>
__global__
void
WarpSoftmaxForward
(
T
*
softmax
,
const
T
*
src
,
const
int
batch_size
,
const
int
stride
,
const
int
element_count
)
{
constexpr
int
kDimCeil
=
1
<<
Log2Elements
;
constexpr
int
kWarpSize
=
(
kDimCeil
<
32
)
?
kDimCeil
:
32
;
constexpr
int
kVSize
=
sizeof
(
VecT
)
/
sizeof
(
T
);
constexpr
int
kIterations
=
kDimCeil
/
kWarpSize
;
constexpr
int
kIterationsV
=
(
kIterations
>=
kVSize
)
?
(
kIterations
/
kVSize
)
:
1
;
constexpr
int
kBatchSize
=
(
kDimCeil
<=
32
)
?
2
:
1
;
int
first_batch
=
(
blockDim
.
y
*
blockIdx
.
x
+
threadIdx
.
y
)
*
kBatchSize
;
// max index to read
int
idx_max_v
[
kBatchSize
];
#pragma unroll
for
(
int
offset
=
WARP_SIZE_SOFTMAX
/
2
;
offset
>
0
;
offset
/=
2
)
{
#pragma unroll
for
(
int
i
=
0
;
i
<
WARP_BATCH
;
++
i
)
{
T
sum_val
=
platform
::
CudaShuffleXorSync
(
0xFFFFFFFF
,
sum
[
i
],
offset
);
sum
[
i
]
=
sum
[
i
]
+
sum_val
;
}
for
(
int
i
=
0
;
i
<
kBatchSize
;
i
++
)
{
int
idx_max
=
((
i
+
first_batch
)
<
batch_size
)
?
element_count
:
0
;
idx_max_v
[
i
]
=
idx_max
/
kVSize
;
}
}
template
<
typename
T
,
int
WARP_BATCH
,
int
WARP_SIZE_SOFTMAX
>
__device__
__forceinline__
void
warp_reduce_max
(
T
*
sum
)
{
// read data from global memory
AccT
srcdata
[
kBatchSize
][
kIterationsV
][
kVSize
];
#pragma unroll
for
(
int
i
=
0
;
i
<
kBatchSize
;
++
i
)
{
// read data
#pragma unroll
for
(
int
it
=
0
;
it
<
kIterationsV
;
++
it
)
{
int
src_idx
=
threadIdx
.
x
+
it
*
kWarpSize
;
if
(
kVSize
==
1
)
{
if
(
src_idx
<
idx_max_v
[
i
])
{
srcdata
[
i
][
it
][
0
]
=
static_cast
<
AccT
>
(
src
[(
first_batch
+
i
)
*
stride
+
src_idx
]);
}
else
{
srcdata
[
i
][
it
][
0
]
=
-
std
::
numeric_limits
<
AccT
>::
infinity
();
}
}
else
{
const
VecT
*
src_v
=
reinterpret_cast
<
const
VecT
*>
(
&
src
[(
first_batch
+
i
)
*
stride
]);
if
(
src_idx
<
idx_max_v
[
i
])
{
VecT
srctmp
=
src_v
[
src_idx
];
const
T
*
srcinptr
=
reinterpret_cast
<
const
T
*>
(
&
srctmp
);
#pragma unroll
for
(
int
offset
=
WARP_SIZE_SOFTMAX
/
2
;
offset
>
0
;
offset
/=
2
)
{
for
(
int
s
=
0
;
s
<
kVSize
;
s
++
)
{
srcdata
[
i
][
it
][
s
]
=
static_cast
<
AccT
>
(
srcinptr
[
s
]);
}
}
else
{
#pragma unroll
for
(
int
i
=
0
;
i
<
WARP_BATCH
;
++
i
)
{
T
max_val
=
platform
::
CudaShuffleXorSync
(
0xFFFFFFFF
,
sum
[
i
],
offset
);
sum
[
i
]
=
max
(
sum
[
i
],
max_val
);
for
(
int
s
=
0
;
s
<
kVSize
;
s
++
)
{
srcdata
[
i
][
it
][
s
]
=
-
std
::
numeric_limits
<
AccT
>::
infinity
();
}
}
}
}
}
}
template
<
typename
T
,
typename
AccT
,
int
Log2Elements
>
__global__
void
WarpSoftmaxForward
(
T
*
dst
,
const
T
*
src
,
const
int
batch_size
,
const
int
stride
,
const
int
element_count
)
{
constexpr
int
next_power_of_two
=
1
<<
Log2Elements
;
constexpr
int
warp_size_softmax
=
(
next_power_of_two
<
32
)
?
next_power_of_two
:
32
;
constexpr
int
WARP_ITERATIONS
=
next_power_of_two
/
warp_size_softmax
;
constexpr
int
WARP_BATCH
=
(
next_power_of_two
<=
128
)
?
2
:
1
;
int
first_batch
=
(
blockDim
.
y
*
blockIdx
.
x
+
threadIdx
.
y
)
*
WARP_BATCH
;
int
local_batches
=
batch_size
-
first_batch
;
if
(
local_batches
>
WARP_BATCH
)
{
local_batches
=
WARP_BATCH
;
}
int
local_idx
=
threadIdx
.
x
;
src
+=
first_batch
*
stride
+
local_idx
;
dst
+=
first_batch
*
stride
+
local_id
x
;
// compute max value
AccT
max_value
[
kBatchSize
];
#pragma unroll
for
(
int
i
=
0
;
i
<
kBatchSize
;
++
i
)
{
// it = 0
AccT
valmax
=
srcdata
[
i
][
0
][
0
];
#pragma unroll
for
(
int
s
=
1
;
s
<
kVSize
;
++
s
)
{
valmax
=
(
valmax
>
srcdata
[
i
][
0
][
s
])
?
valmax
:
srcdata
[
i
][
0
][
s
];
}
max_value
[
i
]
=
valma
x
;
// load data from global memory
AccT
elements
[
WARP_BATCH
][
WARP_ITERATIONS
];
for
(
int
i
=
0
;
i
<
WARP_BATCH
;
++
i
)
{
int
batch_element_count
=
(
i
>=
local_batches
)
?
0
:
element_count
;
for
(
int
it
=
0
;
it
<
WARP_ITERATIONS
;
++
it
)
{
int
element_index
=
local_idx
+
it
*
warp_size_softmax
;
if
(
element_index
<
batch_element_count
)
{
elements
[
i
][
it
]
=
static_cast
<
float
>
(
src
[
i
*
element_count
+
it
*
warp_size_softmax
]);
}
else
{
elements
[
i
][
it
]
=
-
std
::
numeric_limits
<
AccT
>::
infinity
();
// it = 1, 2, ...
#pragma unroll
for
(
int
it
=
1
;
it
<
kIterationsV
;
++
it
)
{
AccT
valmax
=
srcdata
[
i
][
it
][
0
];
#pragma unroll
for
(
int
s
=
1
;
s
<
kVSize
;
++
s
)
{
valmax
=
(
valmax
>
srcdata
[
i
][
it
][
s
])
?
valmax
:
srcdata
[
i
][
it
][
s
];
}
max_value
[
i
]
=
(
max_value
[
i
]
>
valmax
)
?
max_value
[
i
]
:
valmax
;
}
}
WarpReduceMax
<
AccT
,
kBatchSize
,
kWarpSize
>
(
max_value
);
// compute
max_value
AccT
max_value
[
WARP_BATCH
];
// compute
sum
AccT
sum
[
kBatchSize
];
#pragma unroll
for
(
int
i
=
0
;
i
<
WARP_BATCH
;
++
i
)
{
max_value
[
i
]
=
elements
[
i
][
0
];
for
(
int
i
=
0
;
i
<
kBatchSize
;
++
i
)
{
// it = 0
if
(
LogMode
)
{
sum
[
i
]
=
std
::
exp
(
srcdata
[
i
][
0
][
0
]
-
max_value
[
i
]);
}
else
{
srcdata
[
i
][
0
][
0
]
=
std
::
exp
(
srcdata
[
i
][
0
][
0
]
-
max_value
[
i
]);
sum
[
i
]
=
srcdata
[
i
][
0
][
0
];
}
#pragma unroll
for
(
int
it
=
1
;
it
<
WARP_ITERATIONS
;
++
it
)
{
max_value
[
i
]
=
(
max_value
[
i
]
>
elements
[
i
][
it
])
?
max_value
[
i
]
:
elements
[
i
][
it
];
for
(
int
s
=
1
;
s
<
kVSize
;
++
s
)
{
if
(
LogMode
)
{
sum
[
i
]
+=
std
::
exp
(
srcdata
[
i
][
0
][
s
]
-
max_value
[
i
]);
}
else
{
srcdata
[
i
][
0
][
s
]
=
std
::
exp
(
srcdata
[
i
][
0
][
s
]
-
max_value
[
i
]);
sum
[
i
]
+=
srcdata
[
i
][
0
][
s
];
}
}
}
warp_reduce_max
<
AccT
,
WARP_BATCH
,
warp_size_softmax
>
(
max_value
);
AccT
sum
[
WARP_BATCH
]{
0.0
f
};
// it = 1, 2, ...
#pragma unroll
for
(
int
i
=
0
;
i
<
WARP_BATCH
;
++
i
)
{
for
(
int
it
=
1
;
it
<
kIterationsV
;
++
it
)
{
#pragma unroll
for
(
int
it
=
0
;
it
<
WARP_ITERATIONS
;
++
it
)
{
elements
[
i
][
it
]
=
(
std
::
exp
((
elements
[
i
][
it
]
-
max_value
[
i
])));
sum
[
i
]
+=
elements
[
i
][
it
];
for
(
int
s
=
0
;
s
<
kVSize
;
++
s
)
{
if
(
LogMode
)
{
sum
[
i
]
+=
std
::
exp
(
srcdata
[
i
][
it
][
s
]
-
max_value
[
i
]);
}
else
{
srcdata
[
i
][
it
][
s
]
=
std
::
exp
(
srcdata
[
i
][
it
][
s
]
-
max_value
[
i
]);
sum
[
i
]
+=
srcdata
[
i
][
it
][
s
];
}
}
}
}
warp_reduce_sum
<
AccT
,
WARP_BATCH
,
warp_size_softmax
>
(
sum
);
WarpReduceSum
<
AccT
,
kBatchSize
,
kWarpSize
>
(
sum
);
//
store result
//
write result to global memory
#pragma unroll
for
(
int
i
=
0
;
i
<
WARP_BATCH
;
++
i
)
{
if
(
i
>=
local_batches
)
break
;
for
(
int
i
=
0
;
i
<
kBatchSize
;
++
i
)
{
if
(
LogMode
)
{
sum
[
i
]
=
std
::
log
(
sum
[
i
]);
}
#pragma unroll
for
(
int
it
=
0
;
it
<
WARP_ITERATIONS
;
++
it
)
{
int
element_index
=
local_idx
+
it
*
warp_size_softmax
;
if
(
element_index
<
element_count
)
{
dst
[
i
*
element_count
+
it
*
warp_size_softmax
]
=
elements
[
i
][
it
]
/
sum
[
i
];
for
(
int
it
=
0
;
it
<
kIterationsV
;
++
it
)
{
int
idx
=
threadIdx
.
x
+
it
*
kWarpSize
;
if
(
kVSize
==
1
)
{
if
(
idx
<
idx_max_v
[
i
])
{
if
(
LogMode
)
{
softmax
[(
first_batch
+
i
)
*
stride
+
idx
]
=
srcdata
[
i
][
it
][
0
]
-
max_value
[
i
]
-
sum
[
i
];
}
else
{
softmax
[(
first_batch
+
i
)
*
stride
+
idx
]
=
srcdata
[
i
][
it
][
0
]
/
sum
[
i
];
}
}
else
{
break
;
}
}
else
{
break
;
VecT
*
softmax_v
=
reinterpret_cast
<
VecT
*>
(
&
softmax
[(
first_batch
+
i
)
*
stride
]);
VecT
tmpdata
;
T
*
tmpptr
=
reinterpret_cast
<
T
*>
(
&
tmpdata
);
#pragma unroll
for
(
int
s
=
0
;
s
<
kVSize
;
++
s
)
{
if
(
LogMode
)
{
tmpptr
[
s
]
=
srcdata
[
i
][
it
][
s
]
-
max_value
[
i
]
-
sum
[
i
];
}
else
{
tmpptr
[
s
]
=
srcdata
[
i
][
it
][
s
]
/
sum
[
i
];
}
}
if
(
idx
<
idx_max_v
[
i
])
{
softmax_v
[
idx
]
=
tmpdata
;
}
else
{
break
;
}
}
}
}
}
template
<
typename
T
,
typename
AccT
,
int
Log2Elements
>
__global__
void
softmax_warp_backward
(
T
*
gradInput
,
const
T
*
grad
,
const
T
*
output
,
int
batch_size
,
int
stride
,
int
element_count
)
{
constexpr
int
next_power_of_two
=
1
<<
Log2Elements
;
constexpr
int
warp_size_softmax
=
(
next_power_of_two
<
32
)
?
next_power_of_two
:
32
;
constexpr
int
WARP_ITERATIONS
=
next_power_of_two
/
warp_size_softmax
;
constexpr
int
WARP_BATCH
=
(
next_power_of_two
<=
128
)
?
2
:
1
;
int
first_batch
=
(
blockDim
.
y
*
blockIdx
.
x
+
threadIdx
.
y
)
*
WARP_BATCH
;
/*
Core function of computing softmax backward for axis=-1.
The computation includes
- Compute sum of exp batch: s_{i} = sum_{j} {src_{i,j} * grad_{i,j}
- Compute src_{i,j} * ( grad_{i,j}) - s_{i} )
One warp (32 threads) is used to compute 1 or 2 batch (kBatchSize).
For reduction max (sum), firstly compute max (sum) to one warp, then use shuffle
api to compute max (sum) in one warp.
*/
template
<
typename
T
,
typename
VecT
,
typename
AccT
,
int
Log2Elements
,
bool
LogMode
=
false
>
__global__
void
WarpSoftmaxBackward
(
T
*
dst
,
const
T
*
grad
,
const
T
*
src
,
int
batch_size
,
int
stride
,
int
element_count
)
{
constexpr
int
kVSize
=
sizeof
(
VecT
)
/
sizeof
(
T
);
constexpr
int
kDimCeil
=
1
<<
Log2Elements
;
constexpr
int
kWarpSize
=
(
kDimCeil
<
32
)
?
kDimCeil
:
32
;
constexpr
int
kIterations
=
kDimCeil
/
kWarpSize
;
constexpr
int
kBatchSize
=
(
kDimCeil
<=
128
)
?
2
:
1
;
constexpr
int
kIterationsV
=
(
kIterations
>=
kVSize
)
?
(
kIterations
/
kVSize
)
:
1
;
int
element_count_v
=
element_count
/
kVSize
;
int
first_batch
=
(
blockDim
.
y
*
blockIdx
.
x
+
threadIdx
.
y
)
*
kBatchSize
;
int
local_batches
=
batch_size
-
first_batch
;
if
(
local_batches
>
WARP_BATCH
)
{
local_batches
=
WARP_BATCH
;
if
(
local_batches
>
kBatchSize
)
{
local_batches
=
kBatchSize
;
}
int
local_idx
=
threadIdx
.
x
%
warp_size_softmax
;
int
thread_offset
=
first_batch
*
stride
+
local_idx
;
grad
+=
thread_offset
;
output
+=
thread_offset
;
gradInput
+=
thread_offset
;
// load data from global memory
AccT
grad_reg
[
WARP_BATCH
][
WARP_ITERATIONS
];
AccT
output_reg
[
WARP_BATCH
][
WARP_ITERATIONS
];
for
(
int
i
=
0
;
i
<
WARP_BATCH
;
++
i
)
{
int
batch_element_count
=
(
i
>=
local_batches
)
?
0
:
element_count
;
for
(
int
it
=
0
;
it
<
WARP_ITERATIONS
;
++
it
)
{
int
element_index
=
local_idx
+
it
*
warp_size_softmax
;
if
(
element_index
<
batch_element_count
)
{
grad_reg
[
i
][
it
]
=
static_cast
<
AccT
>
(
grad
[
i
*
element_count
+
it
*
warp_size_softmax
]);
output_reg
[
i
][
it
]
=
static_cast
<
AccT
>
(
output
[
i
*
element_count
+
it
*
warp_size_softmax
]);
// read data from global memory
VecT
src_reg
[
kBatchSize
][
kIterationsV
];
VecT
grad_reg
[
kBatchSize
][
kIterationsV
];
for
(
int
i
=
0
;
i
<
kBatchSize
;
++
i
)
{
const
VecT
*
src_v
=
reinterpret_cast
<
const
VecT
*>
(
&
src
[(
first_batch
+
i
)
*
stride
]);
const
VecT
*
grad_v
=
reinterpret_cast
<
const
VecT
*>
(
&
grad
[(
first_batch
+
i
)
*
stride
]);
// max index to read
int
idx_max
=
(
i
<
local_batches
)
?
element_count
:
0
;
int
idx_max_v
=
idx_max
/
kVSize
;
// read data
for
(
int
it
=
0
;
it
<
kIterationsV
;
++
it
)
{
int
src_idx
=
threadIdx
.
x
+
it
*
kWarpSize
;
if
(
src_idx
<
idx_max_v
)
{
src_reg
[
i
][
it
]
=
src_v
[
src_idx
];
grad_reg
[
i
][
it
]
=
grad_v
[
src_idx
];
}
else
{
grad_reg
[
i
][
it
]
=
AccT
(
0
);
output_reg
[
i
][
it
]
=
AccT
(
0
);
#pragma unroll
for
(
int
s
=
0
;
s
<
kVSize
;
s
++
)
{
reinterpret_cast
<
T
*>
(
&
src_reg
[
i
][
it
])[
s
]
=
0.0
;
reinterpret_cast
<
T
*>
(
&
grad_reg
[
i
][
it
])[
s
]
=
0.0
;
}
}
}
}
AccT
sum
[
WARP_BATCH
];
// compute sum
AccT
sum
[
kBatchSize
]{
0.0
};
#pragma unroll
for
(
int
i
=
0
;
i
<
kBatchSize
;
++
i
)
{
#pragma unroll
for
(
int
i
=
0
;
i
<
WARP_BATCH
;
++
i
)
{
sum
[
i
]
=
grad_reg
[
i
][
0
];
for
(
int
it
=
0
;
it
<
kIterationsV
;
++
it
)
{
T
*
gradptr
=
reinterpret_cast
<
T
*>
(
&
grad_reg
[
i
][
it
]);
T
*
srcptr
=
reinterpret_cast
<
T
*>
(
&
src_reg
[
i
][
it
]);
#pragma unroll
for
(
int
it
=
1
;
it
<
WARP_ITERATIONS
;
++
it
)
{
sum
[
i
]
+=
grad_reg
[
i
][
it
];
for
(
int
s
=
0
;
s
<
kVSize
;
++
s
)
{
if
(
LogMode
)
{
sum
[
i
]
+=
static_cast
<
AccT
>
(
gradptr
[
s
]);
}
else
{
sum
[
i
]
+=
static_cast
<
AccT
>
(
gradptr
[
s
]
*
srcptr
[
s
]);
}
}
}
}
warp_reduce_sum
<
AccT
,
WARP_BATCH
,
warp_size_softmax
>
(
sum
);
WarpReduceSum
<
AccT
,
kBatchSize
,
kWarpSize
>
(
sum
);
//
stor
e result
//
writ
e result
#pragma unroll
for
(
int
i
=
0
;
i
<
WARP_BATCH
;
++
i
)
{
for
(
int
i
=
0
;
i
<
kBatchSize
;
++
i
)
{
if
(
i
>=
local_batches
)
break
;
VecT
*
dst_v
=
reinterpret_cast
<
VecT
*>
(
&
dst
[(
first_batch
+
i
)
*
stride
]);
// max index to write
int
idx_max
=
(
i
<
local_batches
)
?
element_count
:
0
;
int
idx_max_v
=
idx_max
/
kVSize
;
#pragma unroll
for
(
int
it
=
0
;
it
<
WARP_ITERATIONS
;
++
it
)
{
int
element_index
=
local_idx
+
it
*
warp_size_softmax
;
if
(
element_index
<
element_count
)
{
// compute gradients
gradInput
[
i
*
element_count
+
it
*
warp_size_softmax
]
=
(
grad_reg
[
i
][
it
]
-
output_reg
[
i
][
it
]
*
sum
[
i
]);
for
(
int
it
=
0
;
it
<
kIterationsV
;
++
it
)
{
VecT
tmpdata
;
T
*
tmpptr
=
reinterpret_cast
<
T
*>
(
&
tmpdata
);
T
*
gradptr
=
reinterpret_cast
<
T
*>
(
&
grad_reg
[
i
][
it
]);
T
*
srcptr
=
reinterpret_cast
<
T
*>
(
&
src_reg
[
i
][
it
]);
#pragma unroll
for
(
int
s
=
0
;
s
<
kVSize
;
++
s
)
{
if
(
LogMode
)
{
tmpptr
[
s
]
=
static_cast
<
AccT
>
(
gradptr
[
s
])
-
std
::
exp
(
static_cast
<
AccT
>
(
srcptr
[
s
]))
*
sum
[
i
];
}
else
{
tmpptr
[
s
]
=
static_cast
<
AccT
>
(
srcptr
[
s
])
*
(
static_cast
<
AccT
>
(
gradptr
[
s
])
-
sum
[
i
]);
}
}
int
idx
=
threadIdx
.
x
+
it
*
kWarpSize
;
if
(
idx
<
idx_max_v
)
{
dst_v
[
idx
]
=
tmpdata
;
}
}
}
}
template
<
typename
T
>
__global__
void
MultiplyCUDAKernel
(
T
*
C
,
const
T
*
A
,
const
T
*
B
,
int
N
)
{
CUDA_KERNEL_LOOP
(
i
,
N
)
{
C
[
i
]
=
static_cast
<
T
>
(
static_cast
<
float
>
(
A
[
i
])
*
static_cast
<
float
>
(
B
[
i
]));
#define SOFTMAX_WARP_FORWARD_CASE(Log2Elements, AccT) \
case Log2Elements: \
WarpSoftmaxForward< \
T, VecT, AccT, Log2Elements, \
LogMode><<<blocks, threads, 0, ctx.cuda_device_context().stream()>>>( \
dst, src, batch_size, stride, element_count); \
break;
/*
Wrapper of softmax formward with template instantiation on size of input.
*/
template
<
typename
T
,
typename
VecT
,
bool
LogMode
>
void
SwitchWarpSoftmaxForward
(
const
int
blocks
,
const
dim3
threads
,
const
framework
::
ExecutionContext
&
ctx
,
T
*
dst
,
const
T
*
src
,
const
int
batch_size
,
const
int
stride
,
const
int
element_count
,
int
Log2Elements
)
{
using
AccT
=
typename
details
::
MPTypeTrait
<
T
>::
Type
;
switch
(
Log2Elements
)
{
SOFTMAX_WARP_FORWARD_CASE
(
0
,
AccT
);
SOFTMAX_WARP_FORWARD_CASE
(
1
,
AccT
);
SOFTMAX_WARP_FORWARD_CASE
(
2
,
AccT
);
SOFTMAX_WARP_FORWARD_CASE
(
3
,
AccT
);
SOFTMAX_WARP_FORWARD_CASE
(
4
,
AccT
);
SOFTMAX_WARP_FORWARD_CASE
(
5
,
AccT
);
SOFTMAX_WARP_FORWARD_CASE
(
6
,
AccT
);
SOFTMAX_WARP_FORWARD_CASE
(
7
,
AccT
);
SOFTMAX_WARP_FORWARD_CASE
(
8
,
AccT
);
SOFTMAX_WARP_FORWARD_CASE
(
9
,
AccT
);
default:
break
;
}
}
template
<
typename
T
,
int
VPT
,
int
WARP_PER_BLOCK
>
__global__
void
VecSoftmaxBackward
(
T
*
dst
,
const
T
*
grad
,
const
T
*
src
,
const
int
batch_size
,
const
int
softmax_ele
)
{
const
int
offset
=
blockIdx
.
x
*
softmax_ele
*
WARP_PER_BLOCK
+
threadIdx
.
x
*
VPT
;
float
local_sum_gy
=
0.
f
;
vec_t
<
T
,
VPT
>
local_grad
;
vec_t
<
T
,
VPT
>
local_src
;
local_grad
.
s
=
reinterpret_cast
<
const
decltype
(
local_grad
.
s
)
*>
(
&
grad
[
offset
])[
0
];
local_src
.
s
=
reinterpret_cast
<
const
decltype
(
local_src
.
s
)
*>
(
&
src
[
offset
])[
0
];
for
(
int
i
=
0
;
i
<
VPT
;
++
i
)
{
local_sum_gy
+=
static_cast
<
float
>
(
local_grad
.
v
[
i
])
*
static_cast
<
float
>
(
local_src
.
v
[
i
]);
}
float
sum_gy
=
math
::
warpReduceSum
<
float
>
(
local_sum_gy
,
0xffffffff
);
#define SOFTMAX_WARP_BACKWARD_CASE(Log2Elements, AccT) \
case Log2Elements: \
WarpSoftmaxBackward< \
T, VecT, AccT, Log2Elements, \
LogMode><<<blocks, threads, 0, ctx.cuda_device_context().stream()>>>( \
dst, grad, src, batch_size, stride, element_count); \
break;
vec_t
<
T
,
VPT
>
local_dst
;
for
(
int
i
=
0
;
i
<
VPT
;
++
i
)
{
local_dst
.
v
[
i
]
=
static_cast
<
T
>
(
static_cast
<
float
>
(
local_src
.
v
[
i
])
*
(
static_cast
<
float
>
(
local_grad
.
v
[
i
])
-
sum_gy
));
/*
Wrapper of softmax backward with template instantiation on size of input.
*/
template
<
typename
T
,
typename
VecT
,
bool
LogMode
>
void
SwitchWarpSoftmaxBackward
(
const
int
blocks
,
const
dim3
threads
,
const
framework
::
ExecutionContext
&
ctx
,
T
*
dst
,
const
T
*
grad
,
const
T
*
src
,
const
int
batch_size
,
const
int
stride
,
const
int
element_count
,
int
Log2Elements
)
{
using
AccT
=
typename
details
::
MPTypeTrait
<
T
>::
Type
;
switch
(
Log2Elements
)
{
SOFTMAX_WARP_BACKWARD_CASE
(
0
,
AccT
);
SOFTMAX_WARP_BACKWARD_CASE
(
1
,
AccT
);
SOFTMAX_WARP_BACKWARD_CASE
(
2
,
AccT
);
SOFTMAX_WARP_BACKWARD_CASE
(
3
,
AccT
);
SOFTMAX_WARP_BACKWARD_CASE
(
4
,
AccT
);
SOFTMAX_WARP_BACKWARD_CASE
(
5
,
AccT
);
SOFTMAX_WARP_BACKWARD_CASE
(
6
,
AccT
);
SOFTMAX_WARP_BACKWARD_CASE
(
7
,
AccT
);
SOFTMAX_WARP_BACKWARD_CASE
(
8
,
AccT
);
SOFTMAX_WARP_BACKWARD_CASE
(
9
,
AccT
);
default:
break
;
}
reinterpret_cast
<
decltype
(
local_dst
.
s
)
*>
(
&
dst
[
offset
])[
0
]
=
local_dst
.
s
;
}
template
<
typename
T
>
#undef SOFTMAX_WARP_FORWARD_CASE
#undef SOFTMAX_WARP_BACKWARD_CASE
template
<
typename
T
,
bool
LogMode
=
false
>
class
SoftmaxCUDNNKernel
:
public
framework
::
OpKernel
<
T
>
{
public:
void
Compute
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
...
...
@@ -335,60 +462,39 @@ class SoftmaxCUDNNKernel : public framework::OpKernel<T> {
const
int
D
=
SizeOutAxis
(
axis
,
dims
);
constexpr
int
max_dim
=
320
;
bool
optimize
=
false
;
constexpr
int
warps_per_block
=
4
;
if
(
D
==
1
&&
dim
<=
max_dim
&&
sizeof
(
T
)
<=
4
)
{
if
(
dim
==
128
&&
N
%
warps_per_block
==
0
)
{
optimize
=
true
;
// a warp for a batch, 4 elements for a thread, only support the softmax
// dim size = 128 currently
if
(
sizeof
(
T
)
==
2
)
{
VecSoftmaxForward
<
T
,
int2
,
4
,
warps_per_block
><<<
N
/
warps_per_block
,
warps_per_block
*
WARP_SIZE
,
0
,
ctx
.
cuda_device_context
().
stream
()
>>>
(
out_data
,
x
->
data
<
T
>
(),
N
,
dim
);
}
else
if
(
sizeof
(
T
)
==
4
)
{
VecSoftmaxForward
<
T
,
int4
,
4
,
warps_per_block
><<<
N
/
warps_per_block
,
warps_per_block
*
WARP_SIZE
,
0
,
ctx
.
cuda_device_context
().
stream
()
>>>
(
out_data
,
x
->
data
<
T
>
(),
N
,
dim
);
}
else
{
assert
(
false
&&
"not support"
);
}
}
else
if
(
dim
<
max_dim
)
{
optimize
=
true
;
int
log2_elements
=
static_cast
<
int
>
(
log2_ceil
(
dim
));
const
int
next_power_of_two
=
1
<<
log2_elements
;
int
warp_size
=
(
next_power_of_two
<
32
)
?
next_power_of_two
:
32
;
int
batches_per_warp
=
(
next_power_of_two
<=
128
)
?
2
:
1
;
// use 128 threads per block to maximimize gpu utilization
constexpr
int
threads_per_block
=
128
;
int
warps_per_block
=
(
threads_per_block
/
warp_size
);
int
batches_per_block
=
warps_per_block
*
batches_per_warp
;
int
blocks
=
(
N
+
batches_per_block
-
1
)
/
batches_per_block
;
dim3
threads
(
warp_size
,
warps_per_block
,
1
);
switch
(
log2_elements
)
{
LAUNCH_SOFTMAX_WARP_FORWARD
(
0
);
// 1
LAUNCH_SOFTMAX_WARP_FORWARD
(
1
);
// 2
LAUNCH_SOFTMAX_WARP_FORWARD
(
2
);
// 4
LAUNCH_SOFTMAX_WARP_FORWARD
(
3
);
// 8
LAUNCH_SOFTMAX_WARP_FORWARD
(
4
);
// 16
LAUNCH_SOFTMAX_WARP_FORWARD
(
5
);
// 32
LAUNCH_SOFTMAX_WARP_FORWARD
(
6
);
// 64
LAUNCH_SOFTMAX_WARP_FORWARD
(
7
);
// 128
LAUNCH_SOFTMAX_WARP_FORWARD
(
8
);
// 256
LAUNCH_SOFTMAX_WARP_FORWARD
(
9
);
// 512
default:
break
;
}
const
int
kDimLog2
=
static_cast
<
int
>
(
log2_ceil
(
dim
));
const
int
kDimCeil
=
1
<<
kDimLog2
;
int
kWarpSize
=
(
kDimCeil
<
32
)
?
kDimCeil
:
32
;
int
batches_per_warp
=
(
kDimCeil
<=
32
)
?
2
:
1
;
// use 128 threads per block to maximimize gpu utilization
constexpr
int
threads_per_block
=
128
;
int
warps_per_block
=
(
threads_per_block
/
kWarpSize
);
int
batches_per_block
=
warps_per_block
*
batches_per_warp
;
int
blocks
=
(
N
+
batches_per_block
-
1
)
/
batches_per_block
;
dim3
threads
(
kWarpSize
,
warps_per_block
,
1
);
// vectorization read/write
using
T4
=
typename
VecT4
<
T
>::
Type
;
using
T2
=
typename
VecT2
<
T
>::
Type
;
if
(
dim
%
4
==
0
)
{
SwitchWarpSoftmaxForward
<
T
,
T4
,
LogMode
>
(
blocks
,
threads
,
ctx
,
out_data
,
x
->
data
<
T
>
(),
N
,
dim
,
dim
,
kDimLog2
);
}
else
if
(
dim
%
2
==
0
)
{
SwitchWarpSoftmaxForward
<
T
,
T2
,
LogMode
>
(
blocks
,
threads
,
ctx
,
out_data
,
x
->
data
<
T
>
(),
N
,
dim
,
dim
,
kDimLog2
);
}
else
{
SwitchWarpSoftmaxForward
<
T
,
T
,
LogMode
>
(
blocks
,
threads
,
ctx
,
out_data
,
x
->
data
<
T
>
(),
N
,
dim
,
dim
,
kDimLog2
);
}
}
if
(
!
optimize
)
{
}
else
{
ScopedTensorDescriptor
desc
;
std
::
vector
<
int
>
tensor_dims
=
{
N
,
dim
,
D
,
1
};
DataLayout
layout
=
DataLayout
::
kNCHW
;
...
...
@@ -405,22 +511,37 @@ class SoftmaxCUDNNKernel : public framework::OpKernel<T> {
#ifdef PADDLE_WITH_HIP
auto
mode
=
axis
==
rank
-
1
?
MIOPEN_SOFTMAX_MODE_INSTANCE
:
MIOPEN_SOFTMAX_MODE_CHANNEL
;
PADDLE_ENFORCE_CUDA_SUCCESS
(
platform
::
dynload
::
miopenSoftmaxForward
(
handle
,
platform
::
CudnnDataType
<
T
>::
kOne
(),
desc_
,
x
->
data
<
T
>
(),
platform
::
CudnnDataType
<
T
>::
kZero
(),
desc_
,
out_data
));
if
(
LogMode
)
{
PADDLE_ENFORCE_CUDA_SUCCESS
(
platform
::
dynload
::
miopenSoftmaxForward_V2
(
handle
,
platform
::
CudnnDataType
<
T
>::
kOne
(),
desc_
,
x
->
data
<
T
>
(),
platform
::
CudnnDataType
<
T
>::
kZero
(),
desc_
,
out_data
,
MIOPEN_SOFTMAX_LOG
,
mode
));
}
else
{
PADDLE_ENFORCE_CUDA_SUCCESS
(
platform
::
dynload
::
miopenSoftmaxForward_V2
(
handle
,
platform
::
CudnnDataType
<
T
>::
kOne
(),
desc_
,
x
->
data
<
T
>
(),
platform
::
CudnnDataType
<
T
>::
kZero
(),
desc_
,
out_data
,
MIOPEN_SOFTMAX_ACCURATE
,
mode
));
}
#else
auto
mode
=
axis
==
rank
-
1
?
CUDNN_SOFTMAX_MODE_INSTANCE
:
CUDNN_SOFTMAX_MODE_CHANNEL
;
PADDLE_ENFORCE_CUDA_SUCCESS
(
platform
::
dynload
::
cudnnSoftmaxForward
(
handle
,
CUDNN_SOFTMAX_ACCURATE
,
mode
,
platform
::
CudnnDataType
<
T
>::
kOne
(),
desc_
,
x
->
data
<
T
>
(),
platform
::
CudnnDataType
<
T
>::
kZero
(),
desc_
,
out_data
));
if
(
LogMode
)
{
PADDLE_ENFORCE_CUDA_SUCCESS
(
platform
::
dynload
::
cudnnSoftmaxForward
(
handle
,
CUDNN_SOFTMAX_LOG
,
mode
,
platform
::
CudnnDataType
<
T
>::
kOne
(),
desc_
,
x
->
data
<
T
>
(),
platform
::
CudnnDataType
<
T
>::
kZero
(),
desc_
,
out_data
));
}
else
{
PADDLE_ENFORCE_CUDA_SUCCESS
(
platform
::
dynload
::
cudnnSoftmaxForward
(
handle
,
CUDNN_SOFTMAX_ACCURATE
,
mode
,
platform
::
CudnnDataType
<
T
>::
kOne
(),
desc_
,
x
->
data
<
T
>
(),
platform
::
CudnnDataType
<
T
>::
kZero
(),
desc_
,
out_data
));
}
#endif
}
}
};
template
<
typename
T
>
template
<
typename
T
,
bool
LogMode
=
false
>
class
SoftmaxGradCUDNNKernel
:
public
framework
::
OpKernel
<
T
>
{
public:
void
Compute
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
...
...
@@ -437,78 +558,38 @@ class SoftmaxGradCUDNNKernel : public framework::OpKernel<T> {
const
int
N
=
SizeToAxis
(
axis
,
dims
);
const
int
D
=
SizeOutAxis
(
axis
,
dims
);
constexpr
int
max_dim
=
320
;
constexpr
int
warps_per_block
=
4
;
constexpr
bool
warp_softmax_available
=
std
::
is_same
<
T
,
float
>::
value
||
std
::
is_same
<
T
,
platform
::
float16
>::
value
;
bool
optimize
=
false
;
if
(
D
==
1
&&
warp_softmax_available
)
{
if
(
dim
==
128
&&
N
%
warps_per_block
==
0
)
{
optimize
=
true
;
if
(
std
::
is_same
<
T
,
float
>::
value
)
{
VecSoftmaxBackward
<
float
,
4
,
warps_per_block
><<<
N
/
warps_per_block
,
warps_per_block
*
WARP_SIZE
,
0
,
ctx
.
cuda_device_context
().
stream
()
>>>
(
dx
->
data
<
float
>
(),
dout
->
data
<
float
>
(),
out
->
data
<
float
>
(),
N
,
dim
);
}
else
if
(
std
::
is_same
<
T
,
platform
::
float16
>::
value
)
{
VecSoftmaxBackward
<
platform
::
float16
,
4
,
warps_per_block
><<<
N
/
warps_per_block
,
warps_per_block
*
WARP_SIZE
,
0
,
ctx
.
cuda_device_context
().
stream
()
>>>
(
dx
->
data
<
platform
::
float16
>
(),
dout
->
data
<
platform
::
float16
>
(),
out
->
data
<
platform
::
float16
>
(),
N
,
dim
);
}
else
{
PADDLE_ENFORCE_EQ
(
warp_softmax_available
,
true
,
platform
::
errors
::
Unimplemented
(
"Warp softmax backward is only available for fp32 and fp16"
));
}
}
else
if
(
dim
<
40
&&
dim
%
32
!=
0
)
{
optimize
=
true
;
Tensor
mul_grad
;
int
numel
=
N
*
dim
;
mul_grad
.
mutable_data
<
T
>
({
numel
},
ctx
.
GetPlace
());
auto
stream
=
ctx
.
cuda_device_context
().
stream
();
auto
&
dev_ctx
=
ctx
.
template
device_context
<
platform
::
CUDADeviceContext
>();
auto
config
=
GetGpuLaunchConfig1D
(
dev_ctx
,
numel
);
MultiplyCUDAKernel
<
T
><<<
config
.
block_per_grid
.
x
,
config
.
thread_per_block
.
x
,
0
,
stream
>>>
(
mul_grad
.
data
<
T
>
(),
dout
->
data
<
T
>
(),
out
->
data
<
T
>
(),
numel
);
int
log2_elements
=
log2_ceil
(
dim
);
const
int
next_power_of_two
=
1
<<
log2_elements
;
int
warp_size
=
(
next_power_of_two
<
32
)
?
next_power_of_two
:
32
;
int
batches_per_warp
=
(
next_power_of_two
<=
128
)
?
2
:
1
;
constexpr
int
threads_per_block
=
128
;
int
warps_per_block
=
(
threads_per_block
/
warp_size
);
int
batches_per_block
=
warps_per_block
*
batches_per_warp
;
int
blocks
=
(
N
+
batches_per_block
-
1
)
/
batches_per_block
;
dim3
threads
(
warp_size
,
warps_per_block
,
1
);
switch
(
log2_elements
)
{
LAUNCH_SOFTMAX_WARP_BACKWARD
(
0
);
// 1
LAUNCH_SOFTMAX_WARP_BACKWARD
(
1
);
// 2
LAUNCH_SOFTMAX_WARP_BACKWARD
(
2
);
// 4
LAUNCH_SOFTMAX_WARP_BACKWARD
(
3
);
// 8
LAUNCH_SOFTMAX_WARP_BACKWARD
(
4
);
// 16
LAUNCH_SOFTMAX_WARP_BACKWARD
(
5
);
// 32
LAUNCH_SOFTMAX_WARP_BACKWARD
(
6
);
// 64
LAUNCH_SOFTMAX_WARP_BACKWARD
(
7
);
// 128
LAUNCH_SOFTMAX_WARP_BACKWARD
(
8
);
// 256
LAUNCH_SOFTMAX_WARP_BACKWARD
(
9
);
// 512
default:
break
;
}
if
(
D
==
1
&&
dim
<=
max_dim
&&
sizeof
(
T
)
<=
4
)
{
const
int
kDimLog2
=
log2_ceil
(
dim
);
const
int
kDimCeil
=
1
<<
kDimLog2
;
int
kWarpSize
=
(
kDimCeil
<
32
)
?
kDimCeil
:
32
;
int
batches_per_warp
=
(
kDimCeil
<=
128
)
?
2
:
1
;
constexpr
int
threads_per_block
=
128
;
int
warps_per_block
=
(
threads_per_block
/
kWarpSize
);
int
batches_per_block
=
warps_per_block
*
batches_per_warp
;
int
blocks
=
(
N
+
batches_per_block
-
1
)
/
batches_per_block
;
dim3
threads
(
kWarpSize
,
warps_per_block
,
1
);
// vectorization read/write
using
T4
=
typename
VecT4
<
T
>::
Type
;
using
T2
=
typename
VecT2
<
T
>::
Type
;
if
(
dim
%
4
==
0
)
{
SwitchWarpSoftmaxBackward
<
T
,
T4
,
LogMode
>
(
blocks
,
threads
,
ctx
,
dx_data
,
dout
->
data
<
T
>
(),
out
->
data
<
T
>
(),
N
,
dim
,
dim
,
kDimLog2
);
}
else
if
(
dim
%
2
==
0
)
{
SwitchWarpSoftmaxBackward
<
T
,
T2
,
LogMode
>
(
blocks
,
threads
,
ctx
,
dx_data
,
dout
->
data
<
T
>
(),
out
->
data
<
T
>
(),
N
,
dim
,
dim
,
kDimLog2
);
}
else
{
SwitchWarpSoftmaxBackward
<
T
,
T
,
LogMode
>
(
blocks
,
threads
,
ctx
,
dx_data
,
dout
->
data
<
T
>
(),
out
->
data
<
T
>
(),
N
,
dim
,
dim
,
kDimLog2
);
}
}
if
(
!
optimize
)
{
}
else
{
ScopedTensorDescriptor
desc
;
std
::
vector
<
int
>
tensor_dims
=
{
N
,
dim
,
D
,
1
};
DataLayout
layout
=
DataLayout
::
kNCHW
;
...
...
@@ -525,18 +606,32 @@ class SoftmaxGradCUDNNKernel : public framework::OpKernel<T> {
#ifdef PADDLE_WITH_HIP
auto
mode
=
axis
==
rank
-
1
?
MIOPEN_SOFTMAX_MODE_INSTANCE
:
MIOPEN_SOFTMAX_MODE_CHANNEL
;
PADDLE_ENFORCE_CUDA_SUCCESS
(
platform
::
dynload
::
miopenSoftmaxBackward
(
handle
,
platform
::
CudnnDataType
<
T
>::
kOne
(),
desc_
,
out
->
data
<
T
>
(),
desc_
,
dout
->
data
<
T
>
(),
platform
::
CudnnDataType
<
T
>::
kZero
(),
desc_
,
dx_data
));
if
(
LogMode
)
{
PADDLE_ENFORCE_CUDA_SUCCESS
(
platform
::
dynload
::
miopenSoftmaxBackward_V2
(
handle
,
platform
::
CudnnDataType
<
T
>::
kOne
(),
desc_
,
out
->
data
<
T
>
(),
desc_
,
dout
->
data
<
T
>
(),
platform
::
CudnnDataType
<
T
>::
kZero
(),
desc_
,
dx_data
,
MIOPEN_SOFTMAX_LOG
,
mode
));
}
else
{
PADDLE_ENFORCE_CUDA_SUCCESS
(
platform
::
dynload
::
miopenSoftmaxBackward_V2
(
handle
,
platform
::
CudnnDataType
<
T
>::
kOne
(),
desc_
,
out
->
data
<
T
>
(),
desc_
,
dout
->
data
<
T
>
(),
platform
::
CudnnDataType
<
T
>::
kZero
(),
desc_
,
dx_data
,
MIOPEN_SOFTMAX_ACCURATE
,
mode
));
}
#else
auto
mode
=
axis
==
rank
-
1
?
CUDNN_SOFTMAX_MODE_INSTANCE
:
CUDNN_SOFTMAX_MODE_CHANNEL
;
PADDLE_ENFORCE_CUDA_SUCCESS
(
platform
::
dynload
::
cudnnSoftmaxBackward
(
handle
,
CUDNN_SOFTMAX_ACCURATE
,
mode
,
platform
::
CudnnDataType
<
T
>::
kOne
(),
desc_
,
out
->
data
<
T
>
(),
desc_
,
dout
->
data
<
T
>
(),
platform
::
CudnnDataType
<
T
>::
kZero
(),
desc_
,
dx_data
));
if
(
LogMode
)
{
PADDLE_ENFORCE_CUDA_SUCCESS
(
platform
::
dynload
::
cudnnSoftmaxBackward
(
handle
,
CUDNN_SOFTMAX_LOG
,
mode
,
platform
::
CudnnDataType
<
T
>::
kOne
(),
desc_
,
out
->
data
<
T
>
(),
desc_
,
dout
->
data
<
T
>
(),
platform
::
CudnnDataType
<
T
>::
kZero
(),
desc_
,
dx_data
));
}
else
{
PADDLE_ENFORCE_CUDA_SUCCESS
(
platform
::
dynload
::
cudnnSoftmaxBackward
(
handle
,
CUDNN_SOFTMAX_ACCURATE
,
mode
,
platform
::
CudnnDataType
<
T
>::
kOne
(),
desc_
,
out
->
data
<
T
>
(),
desc_
,
dout
->
data
<
T
>
(),
platform
::
CudnnDataType
<
T
>::
kZero
(),
desc_
,
dx_data
));
}
#endif
}
}
...
...
paddle/fluid/operators/softmax_impl.cuh
0 → 100755
浏览文件 @
63abd500
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/fluid/platform/cuda_device_function.h"
namespace
paddle
{
namespace
operators
{
template
<
typename
T
,
int
BatchSize
,
int
WarpSize
>
__device__
__forceinline__
void
WarpReduceSum
(
T
*
sum
)
{
#pragma unroll
for
(
int
offset
=
WarpSize
/
2
;
offset
>
0
;
offset
/=
2
)
{
#pragma unroll
for
(
int
i
=
0
;
i
<
BatchSize
;
++
i
)
{
T
sum_val
=
platform
::
CudaShuffleXorSync
(
0xFFFFFFFF
,
sum
[
i
],
offset
);
sum
[
i
]
=
sum
[
i
]
+
sum_val
;
}
}
}
template
<
typename
T
,
int
BatchSize
,
int
WarpSize
>
__device__
__forceinline__
void
WarpReduceMax
(
T
*
sum
)
{
#pragma unroll
for
(
int
offset
=
WarpSize
/
2
;
offset
>
0
;
offset
/=
2
)
{
#pragma unroll
for
(
int
i
=
0
;
i
<
BatchSize
;
++
i
)
{
T
max_val
=
platform
::
CudaShuffleXorSync
(
0xFFFFFFFF
,
sum
[
i
],
offset
);
sum
[
i
]
=
max
(
sum
[
i
],
max_val
);
}
}
}
}
// namespace operators
}
// namespace paddle
\ No newline at end of file
paddle/fluid/operators/softmax_op.h
浏览文件 @
63abd500
...
...
@@ -45,6 +45,14 @@ static inline int SizeFromAxis(const int axis, DDim dims) {
return
size
;
}
static
inline
int
SizeOutAxis
(
const
int
axis
,
DDim
dims
)
{
int
size
=
1
;
for
(
int
i
=
axis
+
1
;
i
<
dims
.
size
();
i
++
)
{
size
*=
dims
[
i
];
}
return
size
;
}
template
<
typename
DeviceContext
,
typename
T
>
class
SoftmaxKernel
:
public
framework
::
OpKernel
<
T
>
{
public:
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录