Unverified commit 4cc0337f, authored on Jan 14, 2021 by GaoWei8, committed via GitHub on Jan 14, 2021.
Softmax backward optimize (#30249) (#30400)
* softmax backward optimize
Parent: 9fb5a3e5
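For context, both backward kernels added in this diff compute the standard softmax gradient, dx = y * (dy - sum(dy * y)), taken row-wise along the softmax axis. Below is a minimal CPU reference sketch of that formula (a hypothetical helper for illustration only, not part of this commit), assuming a contiguous [N, dim] layout, i.e. the D == 1 case the new code targets:

#include <vector>

// Reference softmax backward on the CPU: for each of the N rows of length
// `dim`, dx[j] = y[j] * (dy[j] - dot), where dot = sum_k dy[k] * y[k].
void SoftmaxGradRef(const std::vector<float>& y, const std::vector<float>& dy,
                    std::vector<float>* dx, int N, int dim) {
  for (int n = 0; n < N; ++n) {
    float dot = 0.f;
    for (int k = 0; k < dim; ++k) dot += dy[n * dim + k] * y[n * dim + k];
    for (int j = 0; j < dim; ++j)
      (*dx)[n * dim + j] = y[n * dim + j] * (dy[n * dim + j] - dot);
  }
}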
Showing 1 changed file with 152 additions and 21 deletions.
paddle/fluid/operators/softmax_cudnn_op.cu (+152, -21)
@@ -17,6 +17,7 @@ limitations under the License. */
 #include "paddle/fluid/operators/softmax_op.h"
 #include "paddle/fluid/platform/cuda_device_function.h"
 #include "paddle/fluid/platform/cudnn_helper.h"
+#include "paddle/fluid/platform/gpu_launch_config.h"
 namespace paddle {
 namespace platform {
@@ -39,6 +40,13 @@ using Tensor = framework::Tensor;
       out_data, x->data<T>(), N, dim, dim);                                  \
     break;
+#define LAUNCH_SOFTMAX_WARP_BACKWARD(Log2Elements)                           \
+  case Log2Elements:                                                         \
+    softmax_warp_backward<T, float, Log2Elements><<<                         \
+        blocks, threads, 0, ctx.cuda_device_context().stream()>>>(           \
+        dx_data, mul_grad.data<T>(), out->data<T>(), N, dim, dim);           \
+    break;
+
 static inline int SizeOutAxis(const int axis, DDim dims) {
   int size = 1;
   for (int i = axis + 1; i < dims.size(); i++) {
@@ -199,6 +207,83 @@ __global__ void WarpSoftmaxForward(T* dst, const T* src, const int batch_size,
   }
 }

+template <typename T, typename AccT, int Log2Elements>
+__global__ void softmax_warp_backward(T* gradInput, const T* grad,
+                                      const T* output, int batch_size,
+                                      int stride, int element_count) {
+  constexpr int next_power_of_two = 1 << Log2Elements;
+  constexpr int warp_size_softmax =
+      (next_power_of_two < 32) ? next_power_of_two : 32;
+  constexpr int WARP_ITERATIONS = next_power_of_two / warp_size_softmax;
+  constexpr int WARP_BATCH = (next_power_of_two <= 128) ? 2 : 1;
+
+  int first_batch = (blockDim.y * blockIdx.x + threadIdx.y) * WARP_BATCH;
+  int local_batches = batch_size - first_batch;
+  if (local_batches > WARP_BATCH) {
+    local_batches = WARP_BATCH;
+  }
+
+  int local_idx = threadIdx.x % warp_size_softmax;
+
+  int thread_offset = first_batch * stride + local_idx;
+  grad += thread_offset;
+  output += thread_offset;
+  gradInput += thread_offset;
+
+  // load data from global memory
+  AccT grad_reg[WARP_BATCH][WARP_ITERATIONS];
+  AccT output_reg[WARP_BATCH][WARP_ITERATIONS];
+  for (int i = 0; i < WARP_BATCH; ++i) {
+    int batch_element_count = (i >= local_batches) ? 0 : element_count;
+    for (int it = 0; it < WARP_ITERATIONS; ++it) {
+      int element_index = local_idx + it * warp_size_softmax;
+      if (element_index < batch_element_count) {
+        grad_reg[i][it] = static_cast<AccT>(
+            grad[i * element_count + it * warp_size_softmax]);
+        output_reg[i][it] = static_cast<AccT>(
+            output[i * element_count + it * warp_size_softmax]);
+      } else {
+        grad_reg[i][it] = AccT(0);
+        output_reg[i][it] = AccT(0);
+      }
+    }
+  }
+
+  AccT sum[WARP_BATCH];
+#pragma unroll
+  for (int i = 0; i < WARP_BATCH; ++i) {
+    sum[i] = grad_reg[i][0];
+#pragma unroll
+    for (int it = 1; it < WARP_ITERATIONS; ++it) {
+      sum[i] += grad_reg[i][it];
+    }
+  }
+  warp_reduce_sum<AccT, WARP_BATCH, warp_size_softmax>(sum);
+
+// store result
+#pragma unroll
+  for (int i = 0; i < WARP_BATCH; ++i) {
+    if (i >= local_batches) break;
+#pragma unroll
+    for (int it = 0; it < WARP_ITERATIONS; ++it) {
+      int element_index = local_idx + it * warp_size_softmax;
+      if (element_index < element_count) {
+        // compute gradients
+        gradInput[i * element_count + it * warp_size_softmax] =
+            (grad_reg[i][it] - output_reg[i][it] * sum[i]);
+      }
+    }
+  }
+}
+
+template <typename T>
+__global__ void MultiplyCUDAKernel(T* C, const T* A, const T* B, int N) {
+  CUDA_KERNEL_LOOP(i, N) {
+    C[i] = static_cast<T>(static_cast<float>(A[i]) * static_cast<float>(B[i]));
+  }
+}
+
 template <typename T, int VPT, int WARP_PER_BLOCK>
 __global__ void VecSoftmaxBackward(T* dst, const T* grad, const T* src,
                                    const int batch_size,
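How the two new kernels compose (a reading of the code above and of the dispatch hunk below, not a statement from the commit itself): on the small-dim path, MultiplyCUDAKernel first forms mul_grad = dout * out elementwise, and LAUNCH_SOFTMAX_WARP_BACKWARD passes that product to softmax_warp_backward as `grad`. Each warp then reduces sum = sum_k grad[k] over its row and stores gradInput[j] = grad[j] - output[j] * sum, which expands to out[j] * (dout[j] - sum_k(dout[k] * out[k])), the same quantity as the reference formula given earlier.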
@@ -340,28 +425,74 @@ class SoftmaxGradCUDNNKernel : public framework::OpKernel<T> {
     constexpr bool warp_softmax_available =
         std::is_same<T, float>::value ||
         std::is_same<T, platform::float16>::value;
-    if (D == 1 && dim == 128 && N % warps_per_block == 0 &&
-        warp_softmax_available) {
-      if (std::is_same<T, float>::value) {
-        VecSoftmaxBackward<
-            float, 4,
-            warps_per_block><<<N / warps_per_block, warps_per_block * WARP_SIZE,
-                               0, ctx.cuda_device_context().stream()>>>(
-            dx->data<float>(), dout->data<float>(), out->data<float>(), N, dim);
-      } else if (std::is_same<T, platform::float16>::value) {
-        VecSoftmaxBackward<
-            platform::float16, 4,
-            warps_per_block><<<N / warps_per_block, warps_per_block * WARP_SIZE,
-                               0, ctx.cuda_device_context().stream()>>>(
-            dx->data<platform::float16>(), dout->data<platform::float16>(),
-            out->data<platform::float16>(), N, dim);
-      } else {
-        PADDLE_ENFORCE_EQ(
-            warp_softmax_available, true,
-            platform::errors::Unimplemented(
-                "Warp softmax backward is only available for fp32 and fp16"));
-      }
-    } else {
+    bool optimize = false;
+    if (D == 1 && warp_softmax_available) {
+      if (dim == 128 && N % warps_per_block == 0) {
+        optimize = true;
+        if (std::is_same<T, float>::value) {
+          VecSoftmaxBackward<float, 4, warps_per_block><<<
+              N / warps_per_block, warps_per_block * WARP_SIZE, 0,
+              ctx.cuda_device_context().stream()>>>(
+              dx->data<float>(), dout->data<float>(), out->data<float>(), N,
+              dim);
+        } else if (std::is_same<T, platform::float16>::value) {
+          VecSoftmaxBackward<platform::float16, 4, warps_per_block><<<
+              N / warps_per_block, warps_per_block * WARP_SIZE, 0,
+              ctx.cuda_device_context().stream()>>>(
+              dx->data<platform::float16>(), dout->data<platform::float16>(),
+              out->data<platform::float16>(), N, dim);
+        } else {
+          PADDLE_ENFORCE_EQ(
+              warp_softmax_available, true,
+              platform::errors::Unimplemented(
+                  "Warp softmax backward is only available for fp32 and fp16"));
+        }
+      } else if (dim < 40 && dim % 32 != 0) {
+        optimize = true;
+        Tensor mul_grad;
+        int numel = N * dim;
+        mul_grad.mutable_data<T>({numel}, ctx.GetPlace());
+
+        auto stream = ctx.cuda_device_context().stream();
+        auto& dev_ctx =
+            ctx.template device_context<platform::CUDADeviceContext>();
+        auto config = GetGpuLaunchConfig1D(dev_ctx, numel);
+
+        MultiplyCUDAKernel<T><<<config.block_per_grid.x,
+                                config.thread_per_block.x, 0, stream>>>(
+            mul_grad.data<T>(), dout->data<T>(), out->data<T>(), numel);
+
+        int log2_elements = log2_ceil(dim);
+        const int next_power_of_two = 1 << log2_elements;
+
+        int warp_size = (next_power_of_two < 32) ? next_power_of_two : 32;
+
+        int batches_per_warp = (next_power_of_two <= 128) ? 2 : 1;
+
+        constexpr int threads_per_block = 128;
+
+        int warps_per_block = (threads_per_block / warp_size);
+        int batches_per_block = warps_per_block * batches_per_warp;
+        int blocks = (N + batches_per_block - 1) / batches_per_block;
+        dim3 threads(warp_size, warps_per_block, 1);
+
+        switch (log2_elements) {
+          LAUNCH_SOFTMAX_WARP_BACKWARD(0);  // 1
+          LAUNCH_SOFTMAX_WARP_BACKWARD(1);  // 2
+          LAUNCH_SOFTMAX_WARP_BACKWARD(2);  // 4
+          LAUNCH_SOFTMAX_WARP_BACKWARD(3);  // 8
+          LAUNCH_SOFTMAX_WARP_BACKWARD(4);  // 16
+          LAUNCH_SOFTMAX_WARP_BACKWARD(5);  // 32
+          LAUNCH_SOFTMAX_WARP_BACKWARD(6);  // 64
+          LAUNCH_SOFTMAX_WARP_BACKWARD(7);  // 128
+          LAUNCH_SOFTMAX_WARP_BACKWARD(8);  // 256
+          LAUNCH_SOFTMAX_WARP_BACKWARD(9);  // 512
+          default:
+            break;
+        }
+      }
+    }
+    if (!optimize) {
       ScopedTensorDescriptor desc;
       std::vector<int> tensor_dims = {N, dim, D, 1};
       DataLayout layout = DataLayout::kNCHW;
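As a worked example of the new dim < 40 branch (a hypothetical shape, for illustration only): with dim = 30, log2_ceil(30) = 5, so next_power_of_two = 32, warp_size = 32, and batches_per_warp = 2; with threads_per_block = 128 this gives warps_per_block = 4, batches_per_block = 8, blocks = (N + 7) / 8, and a dim3 threads(32, 4, 1) launch, dispatched through LAUNCH_SOFTMAX_WARP_BACKWARD(5). That instantiation of softmax_warp_backward has warp_size_softmax = 32, WARP_ITERATIONS = 1, and WARP_BATCH = 2, so each warp handles two rows with one element per lane.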