Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
63abd500
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
63abd500
编写于
4月 14, 2021
作者:
X
xingfeng01
提交者:
GitHub
4月 14, 2021
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
softmax reconstruction and optimization (#31821)
上级
8552a182
变更
3
隐藏空白更改
内联
并排
Showing
3 changed file
with
504 addition
and
354 deletion
+504
-354
paddle/fluid/operators/softmax_cudnn_op.cu
paddle/fluid/operators/softmax_cudnn_op.cu
+449
-354
paddle/fluid/operators/softmax_impl.cuh
paddle/fluid/operators/softmax_impl.cuh
+47
-0
paddle/fluid/operators/softmax_op.h
paddle/fluid/operators/softmax_op.h
+8
-0
未找到文件。
paddle/fluid/operators/softmax_cudnn_op.cu
浏览文件 @
63abd500
...
@@ -13,7 +13,9 @@ See the License for the specific language governing permissions and
...
@@ -13,7 +13,9 @@ See the License for the specific language governing permissions and
limitations under the License. */
limitations under the License. */
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/amp/fp16_type_traits.h"
#include "paddle/fluid/operators/math/math_cuda_utils.h"
#include "paddle/fluid/operators/math/math_cuda_utils.h"
#include "paddle/fluid/operators/softmax_impl.cuh"
#include "paddle/fluid/operators/softmax_op.h"
#include "paddle/fluid/operators/softmax_op.h"
#include "paddle/fluid/platform/cuda_device_function.h"
#include "paddle/fluid/platform/cuda_device_function.h"
#ifdef PADDLE_WITH_HIP
#ifdef PADDLE_WITH_HIP
...
@@ -21,7 +23,6 @@ limitations under the License. */
...
@@ -21,7 +23,6 @@ limitations under the License. */
#else
#else
#include "paddle/fluid/platform/cudnn_helper.h"
#include "paddle/fluid/platform/cudnn_helper.h"
#endif
#endif
#include "paddle/fluid/platform/gpu_launch_config.h"
namespace
paddle
{
namespace
paddle
{
namespace
platform
{
namespace
platform
{
...
@@ -37,288 +38,414 @@ using ScopedTensorDescriptor = platform::ScopedTensorDescriptor;
...
@@ -37,288 +38,414 @@ using ScopedTensorDescriptor = platform::ScopedTensorDescriptor;
using
DataLayout
=
platform
::
DataLayout
;
using
DataLayout
=
platform
::
DataLayout
;
using
Tensor
=
framework
::
Tensor
;
using
Tensor
=
framework
::
Tensor
;
#define LAUNCH_SOFTMAX_WARP_FORWARD(Log2Elements) \
// Vectorization trait 4 * sizeof(T)
case Log2Elements: \
template
<
typename
T
>
WarpSoftmaxForward<T, float, Log2Elements><<< \
class
VecT4
{};
blocks, threads, 0, ctx.cuda_device_context().stream()>>>( \
template
<
>
out_data, x->data<T>(), N, dim, dim); \
class
VecT4
<
double
>
{
break;
public:
using
Type
=
long4
;
#define LAUNCH_SOFTMAX_WARP_BACKWARD(Log2Elements) \
case Log2Elements: \
softmax_warp_backward<T, float, Log2Elements><<< \
blocks, threads, 0, ctx.cuda_device_context().stream()>>>( \
dx_data, mul_grad.data<T>(), out->data<T>(), N, dim, dim); \
break;
static
inline
int
SizeOutAxis
(
const
int
axis
,
DDim
dims
)
{
int
size
=
1
;
for
(
int
i
=
axis
+
1
;
i
<
dims
.
size
();
i
++
)
{
size
*=
dims
[
i
];
}
return
size
;
}
int
log2_ceil
(
int
value
)
{
int
log2_value
=
0
;
while
((
1
<<
log2_value
)
<
value
)
++
log2_value
;
return
log2_value
;
}
template
<
typename
T
,
int
VLEN
>
union
vec_t
{
static_assert
(
sizeof
(
T
)
==
-
1
,
"vec_t is only available by specialization."
);
};
};
template
<
>
template
<
>
union
vec_t
<
float
,
4
>
{
class
VecT4
<
float
>
{
float4
s
;
public:
float
v
[
4
];
using
Type
=
int4
;
};
template
<
>
class
VecT4
<
platform
::
float16
>
{
public:
using
Type
=
int2
;
};
};
// Vectorization trait 2 * sizeof(T)
template
<
typename
T
>
class
VecT2
{};
template
<
>
template
<
>
union
vec_t
<
platform
::
float16
,
4
>
{
class
VecT2
<
double
>
{
int2
s
;
public:
platform
::
float16
v
[
4
];
using
Type
=
int4
;
};
template
<
>
class
VecT2
<
float
>
{
public:
using
Type
=
int2
;
};
template
<
>
class
VecT2
<
platform
::
float16
>
{
public:
using
Type
=
int
;
};
};
template
<
typename
T
,
typename
VECT
,
int
VPT
,
int
WARP_PER_BLOCK
>
int
static
inline
log2_ceil
(
int
value
)
{
__global__
void
VecSoftmaxForward
(
T
*
dst
,
const
T
*
src
,
const
int
batch_size
,
int
log2_value
=
0
;
const
int
softmax_ele
)
{
while
((
1
<<
log2_value
)
<
value
)
++
log2_value
;
int
offset
=
blockIdx
.
x
*
softmax_ele
*
WARP_PER_BLOCK
;
return
log2_value
;
int
idx
=
threadIdx
.
x
*
VPT
;
VECT
buf
=
reinterpret_cast
<
const
VECT
*>
(
&
src
[
offset
+
idx
])[
0
];
T
*
bufp
=
reinterpret_cast
<
T
*>
(
&
buf
);
float4
val4
;
float
*
val4p
=
reinterpret_cast
<
float
*>
(
&
val4
);
for
(
int
i
=
0
;
i
<
VPT
;
++
i
)
{
val4p
[
i
]
=
static_cast
<
float
>
(
bufp
[
i
]);
}
float
val
=
val4
.
x
+
val4
.
y
+
val4
.
z
+
val4
.
w
;
float
max_val
=
math
::
warpReduceMax
<
float
>
(
max
(
max
(
val4
.
x
,
val4
.
y
),
max
(
val4
.
z
,
val4
.
w
)),
0xffffffff
);
float4
tmp4
=
make_float4
(
__expf
(
val4
.
x
-
max_val
),
__expf
(
val4
.
y
-
max_val
),
__expf
(
val4
.
z
-
max_val
),
__expf
(
val4
.
w
-
max_val
));
float
*
tmp4p
=
reinterpret_cast
<
float
*>
(
&
tmp4
);
float
invsum
=
1.
f
/
(
math
::
warpReduceSum
<
float
>
(
tmp4
.
x
+
tmp4
.
y
+
tmp4
.
z
+
tmp4
.
w
,
0xffffffff
)
+
1e-6
f
);
for
(
int
i
=
0
;
i
<
VPT
;
++
i
)
{
bufp
[
i
]
=
static_cast
<
T
>
(
tmp4p
[
i
]
*
invsum
);
}
reinterpret_cast
<
VECT
*>
(
&
dst
[
offset
+
idx
])[
0
]
=
buf
;
}
}
template
<
typename
T
,
int
WARP_BATCH
,
int
WARP_SIZE_SOFTMAX
>
/*
__device__
__forceinline__
void
warp_reduce_sum
(
T
*
sum
)
{
Core function of computing softmax forward for axis=-1.
The computation includes
- Compute maximum of batch: maxvalue_{i} = max_j src_{i,j}
- Compute sum of exp batch: s_{i} = sum_{j}{ exp(src_{i,j} - maxvalue_{i} }
- Compute: (a_{i,j} - maxvalue_{i}) / s_{i}
One warp (32 threads) is used to compute 1 or 2 batch (kBatchSize).
For reduction max (sum), firstly compute max (sum) to one warp, then use shuffle
api to compute max (sum) in one warp.
*/
template
<
typename
T
,
typename
VecT
,
typename
AccT
,
int
Log2Elements
,
bool
LogMode
=
false
>
__global__
void
WarpSoftmaxForward
(
T
*
softmax
,
const
T
*
src
,
const
int
batch_size
,
const
int
stride
,
const
int
element_count
)
{
constexpr
int
kDimCeil
=
1
<<
Log2Elements
;
constexpr
int
kWarpSize
=
(
kDimCeil
<
32
)
?
kDimCeil
:
32
;
constexpr
int
kVSize
=
sizeof
(
VecT
)
/
sizeof
(
T
);
constexpr
int
kIterations
=
kDimCeil
/
kWarpSize
;
constexpr
int
kIterationsV
=
(
kIterations
>=
kVSize
)
?
(
kIterations
/
kVSize
)
:
1
;
constexpr
int
kBatchSize
=
(
kDimCeil
<=
32
)
?
2
:
1
;
int
first_batch
=
(
blockDim
.
y
*
blockIdx
.
x
+
threadIdx
.
y
)
*
kBatchSize
;
// max index to read
int
idx_max_v
[
kBatchSize
];
#pragma unroll
#pragma unroll
for
(
int
offset
=
WARP_SIZE_SOFTMAX
/
2
;
offset
>
0
;
offset
/=
2
)
{
for
(
int
i
=
0
;
i
<
kBatchSize
;
i
++
)
{
#pragma unroll
int
idx_max
=
((
i
+
first_batch
)
<
batch_size
)
?
element_count
:
0
;
for
(
int
i
=
0
;
i
<
WARP_BATCH
;
++
i
)
{
idx_max_v
[
i
]
=
idx_max
/
kVSize
;
T
sum_val
=
platform
::
CudaShuffleXorSync
(
0xFFFFFFFF
,
sum
[
i
],
offset
);
sum
[
i
]
=
sum
[
i
]
+
sum_val
;
}
}
}
}
template
<
typename
T
,
int
WARP_BATCH
,
int
WARP_SIZE_SOFTMAX
>
// read data from global memory
__device__
__forceinline__
void
warp_reduce_max
(
T
*
sum
)
{
AccT
srcdata
[
kBatchSize
][
kIterationsV
][
kVSize
];
#pragma unroll
for
(
int
i
=
0
;
i
<
kBatchSize
;
++
i
)
{
// read data
#pragma unroll
for
(
int
it
=
0
;
it
<
kIterationsV
;
++
it
)
{
int
src_idx
=
threadIdx
.
x
+
it
*
kWarpSize
;
if
(
kVSize
==
1
)
{
if
(
src_idx
<
idx_max_v
[
i
])
{
srcdata
[
i
][
it
][
0
]
=
static_cast
<
AccT
>
(
src
[(
first_batch
+
i
)
*
stride
+
src_idx
]);
}
else
{
srcdata
[
i
][
it
][
0
]
=
-
std
::
numeric_limits
<
AccT
>::
infinity
();
}
}
else
{
const
VecT
*
src_v
=
reinterpret_cast
<
const
VecT
*>
(
&
src
[(
first_batch
+
i
)
*
stride
]);
if
(
src_idx
<
idx_max_v
[
i
])
{
VecT
srctmp
=
src_v
[
src_idx
];
const
T
*
srcinptr
=
reinterpret_cast
<
const
T
*>
(
&
srctmp
);
#pragma unroll
#pragma unroll
for
(
int
offset
=
WARP_SIZE_SOFTMAX
/
2
;
offset
>
0
;
offset
/=
2
)
{
for
(
int
s
=
0
;
s
<
kVSize
;
s
++
)
{
srcdata
[
i
][
it
][
s
]
=
static_cast
<
AccT
>
(
srcinptr
[
s
]);
}
}
else
{
#pragma unroll
#pragma unroll
for
(
int
i
=
0
;
i
<
WARP_BATCH
;
++
i
)
{
for
(
int
s
=
0
;
s
<
kVSize
;
s
++
)
{
T
max_val
=
platform
::
CudaShuffleXorSync
(
0xFFFFFFFF
,
sum
[
i
],
offset
);
srcdata
[
i
][
it
][
s
]
=
-
std
::
numeric_limits
<
AccT
>::
infinity
();
sum
[
i
]
=
max
(
sum
[
i
],
max_val
);
}
}
}
}
}
}
}
}
template
<
typename
T
,
typename
AccT
,
int
Log2Elements
>
__global__
void
WarpSoftmaxForward
(
T
*
dst
,
const
T
*
src
,
const
int
batch_size
,
const
int
stride
,
const
int
element_count
)
{
constexpr
int
next_power_of_two
=
1
<<
Log2Elements
;
constexpr
int
warp_size_softmax
=
(
next_power_of_two
<
32
)
?
next_power_of_two
:
32
;
constexpr
int
WARP_ITERATIONS
=
next_power_of_two
/
warp_size_softmax
;
constexpr
int
WARP_BATCH
=
(
next_power_of_two
<=
128
)
?
2
:
1
;
int
first_batch
=
(
blockDim
.
y
*
blockIdx
.
x
+
threadIdx
.
y
)
*
WARP_BATCH
;
// compute max value
AccT
max_value
[
kBatchSize
];
int
local_batches
=
batch_size
-
first_batch
;
#pragma unroll
if
(
local_batches
>
WARP_BATCH
)
{
for
(
int
i
=
0
;
i
<
kBatchSize
;
++
i
)
{
local_batches
=
WARP_BATCH
;
// it = 0
}
AccT
valmax
=
srcdata
[
i
][
0
][
0
];
#pragma unroll
int
local_idx
=
threadIdx
.
x
;
for
(
int
s
=
1
;
s
<
kVSize
;
++
s
)
{
valmax
=
(
valmax
>
srcdata
[
i
][
0
][
s
])
?
valmax
:
srcdata
[
i
][
0
][
s
];
src
+=
first_batch
*
stride
+
local_idx
;
}
dst
+=
first_batch
*
stride
+
local_id
x
;
max_value
[
i
]
=
valma
x
;
// load data from global memory
// it = 1, 2, ...
AccT
elements
[
WARP_BATCH
][
WARP_ITERATIONS
];
#pragma unroll
for
(
int
i
=
0
;
i
<
WARP_BATCH
;
++
i
)
{
for
(
int
it
=
1
;
it
<
kIterationsV
;
++
it
)
{
int
batch_element_count
=
(
i
>=
local_batches
)
?
0
:
element_count
;
AccT
valmax
=
srcdata
[
i
][
it
][
0
];
for
(
int
it
=
0
;
it
<
WARP_ITERATIONS
;
++
it
)
{
#pragma unroll
int
element_index
=
local_idx
+
it
*
warp_size_softmax
;
for
(
int
s
=
1
;
s
<
kVSize
;
++
s
)
{
if
(
element_index
<
batch_element_count
)
{
valmax
=
(
valmax
>
srcdata
[
i
][
it
][
s
])
?
valmax
:
srcdata
[
i
][
it
][
s
];
elements
[
i
][
it
]
=
static_cast
<
float
>
(
src
[
i
*
element_count
+
it
*
warp_size_softmax
]);
}
else
{
elements
[
i
][
it
]
=
-
std
::
numeric_limits
<
AccT
>::
infinity
();
}
}
max_value
[
i
]
=
(
max_value
[
i
]
>
valmax
)
?
max_value
[
i
]
:
valmax
;
}
}
}
}
WarpReduceMax
<
AccT
,
kBatchSize
,
kWarpSize
>
(
max_value
);
// compute
max_value
// compute
sum
AccT
max_value
[
WARP_BATCH
];
AccT
sum
[
kBatchSize
];
#pragma unroll
#pragma unroll
for
(
int
i
=
0
;
i
<
WARP_BATCH
;
++
i
)
{
for
(
int
i
=
0
;
i
<
kBatchSize
;
++
i
)
{
max_value
[
i
]
=
elements
[
i
][
0
];
// it = 0
if
(
LogMode
)
{
sum
[
i
]
=
std
::
exp
(
srcdata
[
i
][
0
][
0
]
-
max_value
[
i
]);
}
else
{
srcdata
[
i
][
0
][
0
]
=
std
::
exp
(
srcdata
[
i
][
0
][
0
]
-
max_value
[
i
]);
sum
[
i
]
=
srcdata
[
i
][
0
][
0
];
}
#pragma unroll
#pragma unroll
for
(
int
it
=
1
;
it
<
WARP_ITERATIONS
;
++
it
)
{
for
(
int
s
=
1
;
s
<
kVSize
;
++
s
)
{
max_value
[
i
]
=
if
(
LogMode
)
{
(
max_value
[
i
]
>
elements
[
i
][
it
])
?
max_value
[
i
]
:
elements
[
i
][
it
];
sum
[
i
]
+=
std
::
exp
(
srcdata
[
i
][
0
][
s
]
-
max_value
[
i
]);
}
else
{
srcdata
[
i
][
0
][
s
]
=
std
::
exp
(
srcdata
[
i
][
0
][
s
]
-
max_value
[
i
]);
sum
[
i
]
+=
srcdata
[
i
][
0
][
s
];
}
}
}
}
warp_reduce_max
<
AccT
,
WARP_BATCH
,
warp_size_softmax
>
(
max_value
);
AccT
sum
[
WARP_BATCH
]{
0.0
f
};
// it = 1, 2, ...
#pragma unroll
#pragma unroll
for
(
int
i
=
0
;
i
<
WARP_BATCH
;
++
i
)
{
for
(
int
it
=
1
;
it
<
kIterationsV
;
++
it
)
{
#pragma unroll
#pragma unroll
for
(
int
it
=
0
;
it
<
WARP_ITERATIONS
;
++
it
)
{
for
(
int
s
=
0
;
s
<
kVSize
;
++
s
)
{
elements
[
i
][
it
]
=
(
std
::
exp
((
elements
[
i
][
it
]
-
max_value
[
i
])));
if
(
LogMode
)
{
sum
[
i
]
+=
elements
[
i
][
it
];
sum
[
i
]
+=
std
::
exp
(
srcdata
[
i
][
it
][
s
]
-
max_value
[
i
]);
}
else
{
srcdata
[
i
][
it
][
s
]
=
std
::
exp
(
srcdata
[
i
][
it
][
s
]
-
max_value
[
i
]);
sum
[
i
]
+=
srcdata
[
i
][
it
][
s
];
}
}
}
}
}
}
warp_reduce_sum
<
AccT
,
WARP_BATCH
,
warp_size_softmax
>
(
sum
);
WarpReduceSum
<
AccT
,
kBatchSize
,
kWarpSize
>
(
sum
);
//
store result
//
write result to global memory
#pragma unroll
#pragma unroll
for
(
int
i
=
0
;
i
<
WARP_BATCH
;
++
i
)
{
for
(
int
i
=
0
;
i
<
kBatchSize
;
++
i
)
{
if
(
i
>=
local_batches
)
break
;
if
(
LogMode
)
{
sum
[
i
]
=
std
::
log
(
sum
[
i
]);
}
#pragma unroll
#pragma unroll
for
(
int
it
=
0
;
it
<
WARP_ITERATIONS
;
++
it
)
{
for
(
int
it
=
0
;
it
<
kIterationsV
;
++
it
)
{
int
element_index
=
local_idx
+
it
*
warp_size_softmax
;
int
idx
=
threadIdx
.
x
+
it
*
kWarpSize
;
if
(
element_index
<
element_count
)
{
if
(
kVSize
==
1
)
{
dst
[
i
*
element_count
+
it
*
warp_size_softmax
]
=
if
(
idx
<
idx_max_v
[
i
])
{
elements
[
i
][
it
]
/
sum
[
i
];
if
(
LogMode
)
{
softmax
[(
first_batch
+
i
)
*
stride
+
idx
]
=
srcdata
[
i
][
it
][
0
]
-
max_value
[
i
]
-
sum
[
i
];
}
else
{
softmax
[(
first_batch
+
i
)
*
stride
+
idx
]
=
srcdata
[
i
][
it
][
0
]
/
sum
[
i
];
}
}
else
{
break
;
}
}
else
{
}
else
{
break
;
VecT
*
softmax_v
=
reinterpret_cast
<
VecT
*>
(
&
softmax
[(
first_batch
+
i
)
*
stride
]);
VecT
tmpdata
;
T
*
tmpptr
=
reinterpret_cast
<
T
*>
(
&
tmpdata
);
#pragma unroll
for
(
int
s
=
0
;
s
<
kVSize
;
++
s
)
{
if
(
LogMode
)
{
tmpptr
[
s
]
=
srcdata
[
i
][
it
][
s
]
-
max_value
[
i
]
-
sum
[
i
];
}
else
{
tmpptr
[
s
]
=
srcdata
[
i
][
it
][
s
]
/
sum
[
i
];
}
}
if
(
idx
<
idx_max_v
[
i
])
{
softmax_v
[
idx
]
=
tmpdata
;
}
else
{
break
;
}
}
}
}
}
}
}
}
}
template
<
typename
T
,
typename
AccT
,
int
Log2Elements
>
/*
__global__
void
softmax_warp_backward
(
T
*
gradInput
,
const
T
*
grad
,
Core function of computing softmax backward for axis=-1.
const
T
*
output
,
int
batch_size
,
The computation includes
int
stride
,
int
element_count
)
{
- Compute sum of exp batch: s_{i} = sum_{j} {src_{i,j} * grad_{i,j}
constexpr
int
next_power_of_two
=
1
<<
Log2Elements
;
- Compute src_{i,j} * ( grad_{i,j}) - s_{i} )
constexpr
int
warp_size_softmax
=
One warp (32 threads) is used to compute 1 or 2 batch (kBatchSize).
(
next_power_of_two
<
32
)
?
next_power_of_two
:
32
;
For reduction max (sum), firstly compute max (sum) to one warp, then use shuffle
constexpr
int
WARP_ITERATIONS
=
next_power_of_two
/
warp_size_softmax
;
api to compute max (sum) in one warp.
constexpr
int
WARP_BATCH
=
(
next_power_of_two
<=
128
)
?
2
:
1
;
*/
template
<
typename
T
,
typename
VecT
,
typename
AccT
,
int
Log2Elements
,
int
first_batch
=
(
blockDim
.
y
*
blockIdx
.
x
+
threadIdx
.
y
)
*
WARP_BATCH
;
bool
LogMode
=
false
>
__global__
void
WarpSoftmaxBackward
(
T
*
dst
,
const
T
*
grad
,
const
T
*
src
,
int
batch_size
,
int
stride
,
int
element_count
)
{
constexpr
int
kVSize
=
sizeof
(
VecT
)
/
sizeof
(
T
);
constexpr
int
kDimCeil
=
1
<<
Log2Elements
;
constexpr
int
kWarpSize
=
(
kDimCeil
<
32
)
?
kDimCeil
:
32
;
constexpr
int
kIterations
=
kDimCeil
/
kWarpSize
;
constexpr
int
kBatchSize
=
(
kDimCeil
<=
128
)
?
2
:
1
;
constexpr
int
kIterationsV
=
(
kIterations
>=
kVSize
)
?
(
kIterations
/
kVSize
)
:
1
;
int
element_count_v
=
element_count
/
kVSize
;
int
first_batch
=
(
blockDim
.
y
*
blockIdx
.
x
+
threadIdx
.
y
)
*
kBatchSize
;
int
local_batches
=
batch_size
-
first_batch
;
int
local_batches
=
batch_size
-
first_batch
;
if
(
local_batches
>
WARP_BATCH
)
{
if
(
local_batches
>
kBatchSize
)
{
local_batches
=
WARP_BATCH
;
local_batches
=
kBatchSize
;
}
}
int
local_idx
=
threadIdx
.
x
%
warp_size_softmax
;
// read data from global memory
VecT
src_reg
[
kBatchSize
][
kIterationsV
];
int
thread_offset
=
first_batch
*
stride
+
local_idx
;
VecT
grad_reg
[
kBatchSize
][
kIterationsV
];
grad
+=
thread_offset
;
output
+=
thread_offset
;
for
(
int
i
=
0
;
i
<
kBatchSize
;
++
i
)
{
gradInput
+=
thread_offset
;
const
VecT
*
src_v
=
reinterpret_cast
<
const
VecT
*>
(
&
src
[(
first_batch
+
i
)
*
stride
]);
// load data from global memory
const
VecT
*
grad_v
=
AccT
grad_reg
[
WARP_BATCH
][
WARP_ITERATIONS
];
reinterpret_cast
<
const
VecT
*>
(
&
grad
[(
first_batch
+
i
)
*
stride
]);
AccT
output_reg
[
WARP_BATCH
][
WARP_ITERATIONS
];
for
(
int
i
=
0
;
i
<
WARP_BATCH
;
++
i
)
{
// max index to read
int
batch_element_count
=
(
i
>=
local_batches
)
?
0
:
element_count
;
int
idx_max
=
(
i
<
local_batches
)
?
element_count
:
0
;
for
(
int
it
=
0
;
it
<
WARP_ITERATIONS
;
++
it
)
{
int
idx_max_v
=
idx_max
/
kVSize
;
int
element_index
=
local_idx
+
it
*
warp_size_softmax
;
if
(
element_index
<
batch_element_count
)
{
// read data
grad_reg
[
i
][
it
]
=
for
(
int
it
=
0
;
it
<
kIterationsV
;
++
it
)
{
static_cast
<
AccT
>
(
grad
[
i
*
element_count
+
it
*
warp_size_softmax
]);
int
src_idx
=
threadIdx
.
x
+
it
*
kWarpSize
;
output_reg
[
i
][
it
]
=
static_cast
<
AccT
>
(
if
(
src_idx
<
idx_max_v
)
{
output
[
i
*
element_count
+
it
*
warp_size_softmax
]);
src_reg
[
i
][
it
]
=
src_v
[
src_idx
];
grad_reg
[
i
][
it
]
=
grad_v
[
src_idx
];
}
else
{
}
else
{
grad_reg
[
i
][
it
]
=
AccT
(
0
);
#pragma unroll
output_reg
[
i
][
it
]
=
AccT
(
0
);
for
(
int
s
=
0
;
s
<
kVSize
;
s
++
)
{
reinterpret_cast
<
T
*>
(
&
src_reg
[
i
][
it
])[
s
]
=
0.0
;
reinterpret_cast
<
T
*>
(
&
grad_reg
[
i
][
it
])[
s
]
=
0.0
;
}
}
}
}
}
}
}
AccT
sum
[
WARP_BATCH
];
// compute sum
AccT
sum
[
kBatchSize
]{
0.0
};
#pragma unroll
for
(
int
i
=
0
;
i
<
kBatchSize
;
++
i
)
{
#pragma unroll
#pragma unroll
for
(
int
i
=
0
;
i
<
WARP_BATCH
;
++
i
)
{
for
(
int
it
=
0
;
it
<
kIterationsV
;
++
it
)
{
sum
[
i
]
=
grad_reg
[
i
][
0
];
T
*
gradptr
=
reinterpret_cast
<
T
*>
(
&
grad_reg
[
i
][
it
]);
T
*
srcptr
=
reinterpret_cast
<
T
*>
(
&
src_reg
[
i
][
it
]);
#pragma unroll
#pragma unroll
for
(
int
it
=
1
;
it
<
WARP_ITERATIONS
;
++
it
)
{
for
(
int
s
=
0
;
s
<
kVSize
;
++
s
)
{
sum
[
i
]
+=
grad_reg
[
i
][
it
];
if
(
LogMode
)
{
sum
[
i
]
+=
static_cast
<
AccT
>
(
gradptr
[
s
]);
}
else
{
sum
[
i
]
+=
static_cast
<
AccT
>
(
gradptr
[
s
]
*
srcptr
[
s
]);
}
}
}
}
}
}
warp_reduce_sum
<
AccT
,
WARP_BATCH
,
warp_size_softmax
>
(
sum
);
WarpReduceSum
<
AccT
,
kBatchSize
,
kWarpSize
>
(
sum
);
//
stor
e result
//
writ
e result
#pragma unroll
#pragma unroll
for
(
int
i
=
0
;
i
<
WARP_BATCH
;
++
i
)
{
for
(
int
i
=
0
;
i
<
kBatchSize
;
++
i
)
{
if
(
i
>=
local_batches
)
break
;
if
(
i
>=
local_batches
)
break
;
VecT
*
dst_v
=
reinterpret_cast
<
VecT
*>
(
&
dst
[(
first_batch
+
i
)
*
stride
]);
// max index to write
int
idx_max
=
(
i
<
local_batches
)
?
element_count
:
0
;
int
idx_max_v
=
idx_max
/
kVSize
;
#pragma unroll
#pragma unroll
for
(
int
it
=
0
;
it
<
WARP_ITERATIONS
;
++
it
)
{
for
(
int
it
=
0
;
it
<
kIterationsV
;
++
it
)
{
int
element_index
=
local_idx
+
it
*
warp_size_softmax
;
VecT
tmpdata
;
if
(
element_index
<
element_count
)
{
T
*
tmpptr
=
reinterpret_cast
<
T
*>
(
&
tmpdata
);
// compute gradients
T
*
gradptr
=
reinterpret_cast
<
T
*>
(
&
grad_reg
[
i
][
it
]);
gradInput
[
i
*
element_count
+
it
*
warp_size_softmax
]
=
T
*
srcptr
=
reinterpret_cast
<
T
*>
(
&
src_reg
[
i
][
it
]);
(
grad_reg
[
i
][
it
]
-
output_reg
[
i
][
it
]
*
sum
[
i
]);
#pragma unroll
for
(
int
s
=
0
;
s
<
kVSize
;
++
s
)
{
if
(
LogMode
)
{
tmpptr
[
s
]
=
static_cast
<
AccT
>
(
gradptr
[
s
])
-
std
::
exp
(
static_cast
<
AccT
>
(
srcptr
[
s
]))
*
sum
[
i
];
}
else
{
tmpptr
[
s
]
=
static_cast
<
AccT
>
(
srcptr
[
s
])
*
(
static_cast
<
AccT
>
(
gradptr
[
s
])
-
sum
[
i
]);
}
}
int
idx
=
threadIdx
.
x
+
it
*
kWarpSize
;
if
(
idx
<
idx_max_v
)
{
dst_v
[
idx
]
=
tmpdata
;
}
}
}
}
}
}
}
}
template
<
typename
T
>
#define SOFTMAX_WARP_FORWARD_CASE(Log2Elements, AccT) \
__global__
void
MultiplyCUDAKernel
(
T
*
C
,
const
T
*
A
,
const
T
*
B
,
int
N
)
{
case Log2Elements: \
CUDA_KERNEL_LOOP
(
i
,
N
)
{
WarpSoftmaxForward< \
C
[
i
]
=
static_cast
<
T
>
(
static_cast
<
float
>
(
A
[
i
])
*
static_cast
<
float
>
(
B
[
i
]));
T, VecT, AccT, Log2Elements, \
LogMode><<<blocks, threads, 0, ctx.cuda_device_context().stream()>>>( \
dst, src, batch_size, stride, element_count); \
break;
/*
Wrapper of softmax formward with template instantiation on size of input.
*/
template
<
typename
T
,
typename
VecT
,
bool
LogMode
>
void
SwitchWarpSoftmaxForward
(
const
int
blocks
,
const
dim3
threads
,
const
framework
::
ExecutionContext
&
ctx
,
T
*
dst
,
const
T
*
src
,
const
int
batch_size
,
const
int
stride
,
const
int
element_count
,
int
Log2Elements
)
{
using
AccT
=
typename
details
::
MPTypeTrait
<
T
>::
Type
;
switch
(
Log2Elements
)
{
SOFTMAX_WARP_FORWARD_CASE
(
0
,
AccT
);
SOFTMAX_WARP_FORWARD_CASE
(
1
,
AccT
);
SOFTMAX_WARP_FORWARD_CASE
(
2
,
AccT
);
SOFTMAX_WARP_FORWARD_CASE
(
3
,
AccT
);
SOFTMAX_WARP_FORWARD_CASE
(
4
,
AccT
);
SOFTMAX_WARP_FORWARD_CASE
(
5
,
AccT
);
SOFTMAX_WARP_FORWARD_CASE
(
6
,
AccT
);
SOFTMAX_WARP_FORWARD_CASE
(
7
,
AccT
);
SOFTMAX_WARP_FORWARD_CASE
(
8
,
AccT
);
SOFTMAX_WARP_FORWARD_CASE
(
9
,
AccT
);
default:
break
;
}
}
}
}
template
<
typename
T
,
int
VPT
,
int
WARP_PER_BLOCK
>
#define SOFTMAX_WARP_BACKWARD_CASE(Log2Elements, AccT) \
__global__
void
VecSoftmaxBackward
(
T
*
dst
,
const
T
*
grad
,
const
T
*
src
,
case Log2Elements: \
const
int
batch_size
,
WarpSoftmaxBackward< \
const
int
softmax_ele
)
{
T, VecT, AccT, Log2Elements, \
const
int
offset
=
LogMode><<<blocks, threads, 0, ctx.cuda_device_context().stream()>>>( \
blockIdx
.
x
*
softmax_ele
*
WARP_PER_BLOCK
+
threadIdx
.
x
*
VPT
;
dst, grad, src, batch_size, stride, element_count); \
break;
float
local_sum_gy
=
0.
f
;
vec_t
<
T
,
VPT
>
local_grad
;
vec_t
<
T
,
VPT
>
local_src
;
local_grad
.
s
=
reinterpret_cast
<
const
decltype
(
local_grad
.
s
)
*>
(
&
grad
[
offset
])[
0
];
local_src
.
s
=
reinterpret_cast
<
const
decltype
(
local_src
.
s
)
*>
(
&
src
[
offset
])[
0
];
for
(
int
i
=
0
;
i
<
VPT
;
++
i
)
{
local_sum_gy
+=
static_cast
<
float
>
(
local_grad
.
v
[
i
])
*
static_cast
<
float
>
(
local_src
.
v
[
i
]);
}
float
sum_gy
=
math
::
warpReduceSum
<
float
>
(
local_sum_gy
,
0xffffffff
);
vec_t
<
T
,
VPT
>
local_dst
;
/*
for
(
int
i
=
0
;
i
<
VPT
;
++
i
)
{
Wrapper of softmax backward with template instantiation on size of input.
local_dst
.
v
[
i
]
=
*/
static_cast
<
T
>
(
static_cast
<
float
>
(
local_src
.
v
[
i
])
*
template
<
typename
T
,
typename
VecT
,
bool
LogMode
>
(
static_cast
<
float
>
(
local_grad
.
v
[
i
])
-
sum_gy
));
void
SwitchWarpSoftmaxBackward
(
const
int
blocks
,
const
dim3
threads
,
const
framework
::
ExecutionContext
&
ctx
,
T
*
dst
,
const
T
*
grad
,
const
T
*
src
,
const
int
batch_size
,
const
int
stride
,
const
int
element_count
,
int
Log2Elements
)
{
using
AccT
=
typename
details
::
MPTypeTrait
<
T
>::
Type
;
switch
(
Log2Elements
)
{
SOFTMAX_WARP_BACKWARD_CASE
(
0
,
AccT
);
SOFTMAX_WARP_BACKWARD_CASE
(
1
,
AccT
);
SOFTMAX_WARP_BACKWARD_CASE
(
2
,
AccT
);
SOFTMAX_WARP_BACKWARD_CASE
(
3
,
AccT
);
SOFTMAX_WARP_BACKWARD_CASE
(
4
,
AccT
);
SOFTMAX_WARP_BACKWARD_CASE
(
5
,
AccT
);
SOFTMAX_WARP_BACKWARD_CASE
(
6
,
AccT
);
SOFTMAX_WARP_BACKWARD_CASE
(
7
,
AccT
);
SOFTMAX_WARP_BACKWARD_CASE
(
8
,
AccT
);
SOFTMAX_WARP_BACKWARD_CASE
(
9
,
AccT
);
default:
break
;
}
}
reinterpret_cast
<
decltype
(
local_dst
.
s
)
*>
(
&
dst
[
offset
])[
0
]
=
local_dst
.
s
;
}
}
template
<
typename
T
>
#undef SOFTMAX_WARP_FORWARD_CASE
#undef SOFTMAX_WARP_BACKWARD_CASE
template
<
typename
T
,
bool
LogMode
=
false
>
class
SoftmaxCUDNNKernel
:
public
framework
::
OpKernel
<
T
>
{
class
SoftmaxCUDNNKernel
:
public
framework
::
OpKernel
<
T
>
{
public:
public:
void
Compute
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
void
Compute
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
...
@@ -335,60 +462,39 @@ class SoftmaxCUDNNKernel : public framework::OpKernel<T> {
...
@@ -335,60 +462,39 @@ class SoftmaxCUDNNKernel : public framework::OpKernel<T> {
const
int
D
=
SizeOutAxis
(
axis
,
dims
);
const
int
D
=
SizeOutAxis
(
axis
,
dims
);
constexpr
int
max_dim
=
320
;
constexpr
int
max_dim
=
320
;
bool
optimize
=
false
;
constexpr
int
warps_per_block
=
4
;
constexpr
int
warps_per_block
=
4
;
if
(
D
==
1
&&
dim
<=
max_dim
&&
sizeof
(
T
)
<=
4
)
{
if
(
D
==
1
&&
dim
<=
max_dim
&&
sizeof
(
T
)
<=
4
)
{
if
(
dim
==
128
&&
N
%
warps_per_block
==
0
)
{
const
int
kDimLog2
=
static_cast
<
int
>
(
log2_ceil
(
dim
));
optimize
=
true
;
const
int
kDimCeil
=
1
<<
kDimLog2
;
// a warp for a batch, 4 elements for a thread, only support the softmax
int
kWarpSize
=
(
kDimCeil
<
32
)
?
kDimCeil
:
32
;
// dim size = 128 currently
int
batches_per_warp
=
(
kDimCeil
<=
32
)
?
2
:
1
;
if
(
sizeof
(
T
)
==
2
)
{
VecSoftmaxForward
<
T
,
int2
,
4
,
warps_per_block
><<<
// use 128 threads per block to maximimize gpu utilization
N
/
warps_per_block
,
warps_per_block
*
WARP_SIZE
,
0
,
constexpr
int
threads_per_block
=
128
;
ctx
.
cuda_device_context
().
stream
()
>>>
(
out_data
,
x
->
data
<
T
>
(),
N
,
dim
);
int
warps_per_block
=
(
threads_per_block
/
kWarpSize
);
}
else
if
(
sizeof
(
T
)
==
4
)
{
int
batches_per_block
=
warps_per_block
*
batches_per_warp
;
VecSoftmaxForward
<
T
,
int4
,
4
,
warps_per_block
><<<
int
blocks
=
(
N
+
batches_per_block
-
1
)
/
batches_per_block
;
N
/
warps_per_block
,
warps_per_block
*
WARP_SIZE
,
0
,
dim3
threads
(
kWarpSize
,
warps_per_block
,
1
);
ctx
.
cuda_device_context
().
stream
()
>>>
(
out_data
,
x
->
data
<
T
>
(),
N
,
dim
);
// vectorization read/write
}
else
{
using
T4
=
typename
VecT4
<
T
>::
Type
;
assert
(
false
&&
"not support"
);
using
T2
=
typename
VecT2
<
T
>::
Type
;
}
if
(
dim
%
4
==
0
)
{
}
else
if
(
dim
<
max_dim
)
{
SwitchWarpSoftmaxForward
<
T
,
T4
,
LogMode
>
(
blocks
,
threads
,
ctx
,
out_data
,
optimize
=
true
;
x
->
data
<
T
>
(),
N
,
dim
,
dim
,
int
log2_elements
=
static_cast
<
int
>
(
log2_ceil
(
dim
));
kDimLog2
);
const
int
next_power_of_two
=
1
<<
log2_elements
;
}
else
if
(
dim
%
2
==
0
)
{
SwitchWarpSoftmaxForward
<
T
,
T2
,
LogMode
>
(
blocks
,
threads
,
ctx
,
out_data
,
int
warp_size
=
(
next_power_of_two
<
32
)
?
next_power_of_two
:
32
;
x
->
data
<
T
>
(),
N
,
dim
,
dim
,
kDimLog2
);
int
batches_per_warp
=
(
next_power_of_two
<=
128
)
?
2
:
1
;
}
else
{
SwitchWarpSoftmaxForward
<
T
,
T
,
LogMode
>
(
blocks
,
threads
,
ctx
,
out_data
,
// use 128 threads per block to maximimize gpu utilization
x
->
data
<
T
>
(),
N
,
dim
,
dim
,
constexpr
int
threads_per_block
=
128
;
kDimLog2
);
int
warps_per_block
=
(
threads_per_block
/
warp_size
);
int
batches_per_block
=
warps_per_block
*
batches_per_warp
;
int
blocks
=
(
N
+
batches_per_block
-
1
)
/
batches_per_block
;
dim3
threads
(
warp_size
,
warps_per_block
,
1
);
switch
(
log2_elements
)
{
LAUNCH_SOFTMAX_WARP_FORWARD
(
0
);
// 1
LAUNCH_SOFTMAX_WARP_FORWARD
(
1
);
// 2
LAUNCH_SOFTMAX_WARP_FORWARD
(
2
);
// 4
LAUNCH_SOFTMAX_WARP_FORWARD
(
3
);
// 8
LAUNCH_SOFTMAX_WARP_FORWARD
(
4
);
// 16
LAUNCH_SOFTMAX_WARP_FORWARD
(
5
);
// 32
LAUNCH_SOFTMAX_WARP_FORWARD
(
6
);
// 64
LAUNCH_SOFTMAX_WARP_FORWARD
(
7
);
// 128
LAUNCH_SOFTMAX_WARP_FORWARD
(
8
);
// 256
LAUNCH_SOFTMAX_WARP_FORWARD
(
9
);
// 512
default:
break
;
}
}
}
}
}
else
{
if
(
!
optimize
)
{
ScopedTensorDescriptor
desc
;
ScopedTensorDescriptor
desc
;
std
::
vector
<
int
>
tensor_dims
=
{
N
,
dim
,
D
,
1
};
std
::
vector
<
int
>
tensor_dims
=
{
N
,
dim
,
D
,
1
};
DataLayout
layout
=
DataLayout
::
kNCHW
;
DataLayout
layout
=
DataLayout
::
kNCHW
;
...
@@ -405,22 +511,37 @@ class SoftmaxCUDNNKernel : public framework::OpKernel<T> {
...
@@ -405,22 +511,37 @@ class SoftmaxCUDNNKernel : public framework::OpKernel<T> {
#ifdef PADDLE_WITH_HIP
#ifdef PADDLE_WITH_HIP
auto
mode
=
axis
==
rank
-
1
?
MIOPEN_SOFTMAX_MODE_INSTANCE
auto
mode
=
axis
==
rank
-
1
?
MIOPEN_SOFTMAX_MODE_INSTANCE
:
MIOPEN_SOFTMAX_MODE_CHANNEL
;
:
MIOPEN_SOFTMAX_MODE_CHANNEL
;
PADDLE_ENFORCE_CUDA_SUCCESS
(
platform
::
dynload
::
miopenSoftmaxForward
(
if
(
LogMode
)
{
handle
,
platform
::
CudnnDataType
<
T
>::
kOne
(),
desc_
,
x
->
data
<
T
>
(),
PADDLE_ENFORCE_CUDA_SUCCESS
(
platform
::
dynload
::
miopenSoftmaxForward_V2
(
platform
::
CudnnDataType
<
T
>::
kZero
(),
desc_
,
out_data
));
handle
,
platform
::
CudnnDataType
<
T
>::
kOne
(),
desc_
,
x
->
data
<
T
>
(),
platform
::
CudnnDataType
<
T
>::
kZero
(),
desc_
,
out_data
,
MIOPEN_SOFTMAX_LOG
,
mode
));
}
else
{
PADDLE_ENFORCE_CUDA_SUCCESS
(
platform
::
dynload
::
miopenSoftmaxForward_V2
(
handle
,
platform
::
CudnnDataType
<
T
>::
kOne
(),
desc_
,
x
->
data
<
T
>
(),
platform
::
CudnnDataType
<
T
>::
kZero
(),
desc_
,
out_data
,
MIOPEN_SOFTMAX_ACCURATE
,
mode
));
}
#else
#else
auto
mode
=
axis
==
rank
-
1
?
CUDNN_SOFTMAX_MODE_INSTANCE
auto
mode
=
axis
==
rank
-
1
?
CUDNN_SOFTMAX_MODE_INSTANCE
:
CUDNN_SOFTMAX_MODE_CHANNEL
;
:
CUDNN_SOFTMAX_MODE_CHANNEL
;
PADDLE_ENFORCE_CUDA_SUCCESS
(
platform
::
dynload
::
cudnnSoftmaxForward
(
if
(
LogMode
)
{
handle
,
CUDNN_SOFTMAX_ACCURATE
,
mode
,
PADDLE_ENFORCE_CUDA_SUCCESS
(
platform
::
dynload
::
cudnnSoftmaxForward
(
platform
::
CudnnDataType
<
T
>::
kOne
(),
desc_
,
x
->
data
<
T
>
(),
handle
,
CUDNN_SOFTMAX_LOG
,
mode
,
platform
::
CudnnDataType
<
T
>::
kOne
(),
platform
::
CudnnDataType
<
T
>::
kZero
(),
desc_
,
out_data
));
desc_
,
x
->
data
<
T
>
(),
platform
::
CudnnDataType
<
T
>::
kZero
(),
desc_
,
out_data
));
}
else
{
PADDLE_ENFORCE_CUDA_SUCCESS
(
platform
::
dynload
::
cudnnSoftmaxForward
(
handle
,
CUDNN_SOFTMAX_ACCURATE
,
mode
,
platform
::
CudnnDataType
<
T
>::
kOne
(),
desc_
,
x
->
data
<
T
>
(),
platform
::
CudnnDataType
<
T
>::
kZero
(),
desc_
,
out_data
));
}
#endif
#endif
}
}
}
}
};
};
template
<
typename
T
>
template
<
typename
T
,
bool
LogMode
=
false
>
class
SoftmaxGradCUDNNKernel
:
public
framework
::
OpKernel
<
T
>
{
class
SoftmaxGradCUDNNKernel
:
public
framework
::
OpKernel
<
T
>
{
public:
public:
void
Compute
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
void
Compute
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
...
@@ -437,78 +558,38 @@ class SoftmaxGradCUDNNKernel : public framework::OpKernel<T> {
...
@@ -437,78 +558,38 @@ class SoftmaxGradCUDNNKernel : public framework::OpKernel<T> {
const
int
N
=
SizeToAxis
(
axis
,
dims
);
const
int
N
=
SizeToAxis
(
axis
,
dims
);
const
int
D
=
SizeOutAxis
(
axis
,
dims
);
const
int
D
=
SizeOutAxis
(
axis
,
dims
);
constexpr
int
max_dim
=
320
;
constexpr
int
warps_per_block
=
4
;
constexpr
int
warps_per_block
=
4
;
constexpr
bool
warp_softmax_available
=
std
::
is_same
<
T
,
float
>::
value
||
if
(
D
==
1
&&
dim
<=
max_dim
&&
sizeof
(
T
)
<=
4
)
{
std
::
is_same
<
T
,
platform
::
float16
>::
value
;
const
int
kDimLog2
=
log2_ceil
(
dim
);
bool
optimize
=
false
;
const
int
kDimCeil
=
1
<<
kDimLog2
;
if
(
D
==
1
&&
warp_softmax_available
)
{
int
kWarpSize
=
(
kDimCeil
<
32
)
?
kDimCeil
:
32
;
if
(
dim
==
128
&&
N
%
warps_per_block
==
0
)
{
int
batches_per_warp
=
(
kDimCeil
<=
128
)
?
2
:
1
;
optimize
=
true
;
constexpr
int
threads_per_block
=
128
;
if
(
std
::
is_same
<
T
,
float
>::
value
)
{
VecSoftmaxBackward
<
float
,
4
,
warps_per_block
><<<
int
warps_per_block
=
(
threads_per_block
/
kWarpSize
);
N
/
warps_per_block
,
warps_per_block
*
WARP_SIZE
,
0
,
int
batches_per_block
=
warps_per_block
*
batches_per_warp
;
ctx
.
cuda_device_context
().
stream
()
>>>
(
dx
->
data
<
float
>
(),
int
blocks
=
(
N
+
batches_per_block
-
1
)
/
batches_per_block
;
dout
->
data
<
float
>
(),
dim3
threads
(
kWarpSize
,
warps_per_block
,
1
);
out
->
data
<
float
>
(),
N
,
dim
);
}
else
if
(
std
::
is_same
<
T
,
platform
::
float16
>::
value
)
{
// vectorization read/write
VecSoftmaxBackward
<
platform
::
float16
,
4
,
warps_per_block
><<<
using
T4
=
typename
VecT4
<
T
>::
Type
;
N
/
warps_per_block
,
warps_per_block
*
WARP_SIZE
,
0
,
using
T2
=
typename
VecT2
<
T
>::
Type
;
ctx
.
cuda_device_context
().
stream
()
>>>
(
if
(
dim
%
4
==
0
)
{
dx
->
data
<
platform
::
float16
>
(),
dout
->
data
<
platform
::
float16
>
(),
SwitchWarpSoftmaxBackward
<
T
,
T4
,
LogMode
>
(
out
->
data
<
platform
::
float16
>
(),
N
,
dim
);
blocks
,
threads
,
ctx
,
dx_data
,
dout
->
data
<
T
>
(),
out
->
data
<
T
>
(),
N
,
}
else
{
dim
,
dim
,
kDimLog2
);
PADDLE_ENFORCE_EQ
(
}
else
if
(
dim
%
2
==
0
)
{
warp_softmax_available
,
true
,
SwitchWarpSoftmaxBackward
<
T
,
T2
,
LogMode
>
(
platform
::
errors
::
Unimplemented
(
blocks
,
threads
,
ctx
,
dx_data
,
dout
->
data
<
T
>
(),
out
->
data
<
T
>
(),
N
,
"Warp softmax backward is only available for fp32 and fp16"
));
dim
,
dim
,
kDimLog2
);
}
}
else
{
}
else
if
(
dim
<
40
&&
dim
%
32
!=
0
)
{
SwitchWarpSoftmaxBackward
<
T
,
T
,
LogMode
>
(
optimize
=
true
;
blocks
,
threads
,
ctx
,
dx_data
,
dout
->
data
<
T
>
(),
out
->
data
<
T
>
(),
N
,
Tensor
mul_grad
;
dim
,
dim
,
kDimLog2
);
int
numel
=
N
*
dim
;
mul_grad
.
mutable_data
<
T
>
({
numel
},
ctx
.
GetPlace
());
auto
stream
=
ctx
.
cuda_device_context
().
stream
();
auto
&
dev_ctx
=
ctx
.
template
device_context
<
platform
::
CUDADeviceContext
>();
auto
config
=
GetGpuLaunchConfig1D
(
dev_ctx
,
numel
);
MultiplyCUDAKernel
<
T
><<<
config
.
block_per_grid
.
x
,
config
.
thread_per_block
.
x
,
0
,
stream
>>>
(
mul_grad
.
data
<
T
>
(),
dout
->
data
<
T
>
(),
out
->
data
<
T
>
(),
numel
);
int
log2_elements
=
log2_ceil
(
dim
);
const
int
next_power_of_two
=
1
<<
log2_elements
;
int
warp_size
=
(
next_power_of_two
<
32
)
?
next_power_of_two
:
32
;
int
batches_per_warp
=
(
next_power_of_two
<=
128
)
?
2
:
1
;
constexpr
int
threads_per_block
=
128
;
int
warps_per_block
=
(
threads_per_block
/
warp_size
);
int
batches_per_block
=
warps_per_block
*
batches_per_warp
;
int
blocks
=
(
N
+
batches_per_block
-
1
)
/
batches_per_block
;
dim3
threads
(
warp_size
,
warps_per_block
,
1
);
switch
(
log2_elements
)
{
LAUNCH_SOFTMAX_WARP_BACKWARD
(
0
);
// 1
LAUNCH_SOFTMAX_WARP_BACKWARD
(
1
);
// 2
LAUNCH_SOFTMAX_WARP_BACKWARD
(
2
);
// 4
LAUNCH_SOFTMAX_WARP_BACKWARD
(
3
);
// 8
LAUNCH_SOFTMAX_WARP_BACKWARD
(
4
);
// 16
LAUNCH_SOFTMAX_WARP_BACKWARD
(
5
);
// 32
LAUNCH_SOFTMAX_WARP_BACKWARD
(
6
);
// 64
LAUNCH_SOFTMAX_WARP_BACKWARD
(
7
);
// 128
LAUNCH_SOFTMAX_WARP_BACKWARD
(
8
);
// 256
LAUNCH_SOFTMAX_WARP_BACKWARD
(
9
);
// 512
default:
break
;
}
}
}
}
}
else
{
if
(
!
optimize
)
{
ScopedTensorDescriptor
desc
;
ScopedTensorDescriptor
desc
;
std
::
vector
<
int
>
tensor_dims
=
{
N
,
dim
,
D
,
1
};
std
::
vector
<
int
>
tensor_dims
=
{
N
,
dim
,
D
,
1
};
DataLayout
layout
=
DataLayout
::
kNCHW
;
DataLayout
layout
=
DataLayout
::
kNCHW
;
...
@@ -525,18 +606,32 @@ class SoftmaxGradCUDNNKernel : public framework::OpKernel<T> {
...
@@ -525,18 +606,32 @@ class SoftmaxGradCUDNNKernel : public framework::OpKernel<T> {
#ifdef PADDLE_WITH_HIP
#ifdef PADDLE_WITH_HIP
auto
mode
=
axis
==
rank
-
1
?
MIOPEN_SOFTMAX_MODE_INSTANCE
auto
mode
=
axis
==
rank
-
1
?
MIOPEN_SOFTMAX_MODE_INSTANCE
:
MIOPEN_SOFTMAX_MODE_CHANNEL
;
:
MIOPEN_SOFTMAX_MODE_CHANNEL
;
PADDLE_ENFORCE_CUDA_SUCCESS
(
platform
::
dynload
::
miopenSoftmaxBackward
(
if
(
LogMode
)
{
handle
,
platform
::
CudnnDataType
<
T
>::
kOne
(),
desc_
,
out
->
data
<
T
>
(),
PADDLE_ENFORCE_CUDA_SUCCESS
(
platform
::
dynload
::
miopenSoftmaxBackward_V2
(
desc_
,
dout
->
data
<
T
>
(),
platform
::
CudnnDataType
<
T
>::
kZero
(),
desc_
,
handle
,
platform
::
CudnnDataType
<
T
>::
kOne
(),
desc_
,
out
->
data
<
T
>
(),
dx_data
));
desc_
,
dout
->
data
<
T
>
(),
platform
::
CudnnDataType
<
T
>::
kZero
(),
desc_
,
dx_data
,
MIOPEN_SOFTMAX_LOG
,
mode
));
}
else
{
PADDLE_ENFORCE_CUDA_SUCCESS
(
platform
::
dynload
::
miopenSoftmaxBackward_V2
(
handle
,
platform
::
CudnnDataType
<
T
>::
kOne
(),
desc_
,
out
->
data
<
T
>
(),
desc_
,
dout
->
data
<
T
>
(),
platform
::
CudnnDataType
<
T
>::
kZero
(),
desc_
,
dx_data
,
MIOPEN_SOFTMAX_ACCURATE
,
mode
));
}
#else
#else
auto
mode
=
axis
==
rank
-
1
?
CUDNN_SOFTMAX_MODE_INSTANCE
auto
mode
=
axis
==
rank
-
1
?
CUDNN_SOFTMAX_MODE_INSTANCE
:
CUDNN_SOFTMAX_MODE_CHANNEL
;
:
CUDNN_SOFTMAX_MODE_CHANNEL
;
PADDLE_ENFORCE_CUDA_SUCCESS
(
platform
::
dynload
::
cudnnSoftmaxBackward
(
if
(
LogMode
)
{
handle
,
CUDNN_SOFTMAX_ACCURATE
,
mode
,
PADDLE_ENFORCE_CUDA_SUCCESS
(
platform
::
dynload
::
cudnnSoftmaxBackward
(
platform
::
CudnnDataType
<
T
>::
kOne
(),
desc_
,
out
->
data
<
T
>
(),
desc_
,
handle
,
CUDNN_SOFTMAX_LOG
,
mode
,
platform
::
CudnnDataType
<
T
>::
kOne
(),
dout
->
data
<
T
>
(),
platform
::
CudnnDataType
<
T
>::
kZero
(),
desc_
,
desc_
,
out
->
data
<
T
>
(),
desc_
,
dout
->
data
<
T
>
(),
dx_data
));
platform
::
CudnnDataType
<
T
>::
kZero
(),
desc_
,
dx_data
));
}
else
{
PADDLE_ENFORCE_CUDA_SUCCESS
(
platform
::
dynload
::
cudnnSoftmaxBackward
(
handle
,
CUDNN_SOFTMAX_ACCURATE
,
mode
,
platform
::
CudnnDataType
<
T
>::
kOne
(),
desc_
,
out
->
data
<
T
>
(),
desc_
,
dout
->
data
<
T
>
(),
platform
::
CudnnDataType
<
T
>::
kZero
(),
desc_
,
dx_data
));
}
#endif
#endif
}
}
}
}
...
...
paddle/fluid/operators/softmax_impl.cuh
0 → 100755
浏览文件 @
63abd500
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/fluid/platform/cuda_device_function.h"
namespace
paddle
{
namespace
operators
{
template
<
typename
T
,
int
BatchSize
,
int
WarpSize
>
__device__
__forceinline__
void
WarpReduceSum
(
T
*
sum
)
{
#pragma unroll
for
(
int
offset
=
WarpSize
/
2
;
offset
>
0
;
offset
/=
2
)
{
#pragma unroll
for
(
int
i
=
0
;
i
<
BatchSize
;
++
i
)
{
T
sum_val
=
platform
::
CudaShuffleXorSync
(
0xFFFFFFFF
,
sum
[
i
],
offset
);
sum
[
i
]
=
sum
[
i
]
+
sum_val
;
}
}
}
template
<
typename
T
,
int
BatchSize
,
int
WarpSize
>
__device__
__forceinline__
void
WarpReduceMax
(
T
*
sum
)
{
#pragma unroll
for
(
int
offset
=
WarpSize
/
2
;
offset
>
0
;
offset
/=
2
)
{
#pragma unroll
for
(
int
i
=
0
;
i
<
BatchSize
;
++
i
)
{
T
max_val
=
platform
::
CudaShuffleXorSync
(
0xFFFFFFFF
,
sum
[
i
],
offset
);
sum
[
i
]
=
max
(
sum
[
i
],
max_val
);
}
}
}
}
// namespace operators
}
// namespace paddle
\ No newline at end of file
paddle/fluid/operators/softmax_op.h
浏览文件 @
63abd500
...
@@ -45,6 +45,14 @@ static inline int SizeFromAxis(const int axis, DDim dims) {
...
@@ -45,6 +45,14 @@ static inline int SizeFromAxis(const int axis, DDim dims) {
return
size
;
return
size
;
}
}
static
inline
int
SizeOutAxis
(
const
int
axis
,
DDim
dims
)
{
int
size
=
1
;
for
(
int
i
=
axis
+
1
;
i
<
dims
.
size
();
i
++
)
{
size
*=
dims
[
i
];
}
return
size
;
}
template
<
typename
DeviceContext
,
typename
T
>
template
<
typename
DeviceContext
,
typename
T
>
class
SoftmaxKernel
:
public
framework
::
OpKernel
<
T
>
{
class
SoftmaxKernel
:
public
framework
::
OpKernel
<
T
>
{
public:
public:
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录