Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
95e33481
P
Paddle
项目概览
PaddlePaddle
/
Paddle
1 年多 前同步成功
通知
2302
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
95e33481
编写于
12月 09, 2020
作者:
Z
zlsh80826
提交者:
GitHub
12月 09, 2020
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Softmax vectorization (#29404)
* vec softmax fw * vec softmax bw * add a message argument for compiler compatibility
上级
a136c9cd
变更
1
隐藏空白更改
内联
并排
Showing
1 changed file
with
247 addition
and
0 deletion
+247
-0
paddle/fluid/operators/softmax_cudnn_op.cu
paddle/fluid/operators/softmax_cudnn_op.cu
+247
-0
未找到文件。
paddle/fluid/operators/softmax_cudnn_op.cu
.cc
→
paddle/fluid/operators/softmax_cudnn_op.cu
浏览文件 @
95e33481
...
...
@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/math/math_cuda_utils.h"
#include "paddle/fluid/operators/softmax_op.h"
#include "paddle/fluid/platform/cudnn_helper.h"
...
...
@@ -38,6 +39,81 @@ static inline int SizeOutAxis(const int axis, DDim dims) {
return
size
;
}
template
<
typename
T
,
int
VLEN
>
union
vec_t
{
static_assert
(
sizeof
(
T
)
==
-
1
,
"vec_t is only available by specialization."
);
};
template
<
>
union
vec_t
<
float
,
4
>
{
float4
s
;
float
v
[
4
];
};
template
<
>
union
vec_t
<
platform
::
float16
,
4
>
{
int2
s
;
platform
::
float16
v
[
4
];
};
template
<
typename
T
,
typename
VECT
,
int
VPT
,
int
WARP_PER_BLOCK
>
__global__
void
VecSoftmaxForward
(
T
*
dst
,
const
T
*
src
,
const
int
batch_size
,
const
int
softmax_ele
)
{
int
offset
=
blockIdx
.
x
*
softmax_ele
*
WARP_PER_BLOCK
;
int
idx
=
threadIdx
.
x
*
VPT
;
VECT
buf
=
reinterpret_cast
<
const
VECT
*>
(
&
src
[
offset
+
idx
])[
0
];
T
*
bufp
=
reinterpret_cast
<
T
*>
(
&
buf
);
float4
val4
;
float
*
val4p
=
reinterpret_cast
<
float
*>
(
&
val4
);
for
(
int
i
=
0
;
i
<
VPT
;
++
i
)
{
val4p
[
i
]
=
static_cast
<
float
>
(
bufp
[
i
]);
}
float
val
=
val4
.
x
+
val4
.
y
+
val4
.
z
+
val4
.
w
;
float
max_val
=
math
::
warpReduceMax
<
float
>
(
max
(
max
(
val4
.
x
,
val4
.
y
),
max
(
val4
.
z
,
val4
.
w
)),
0xffffffff
);
float4
tmp4
=
make_float4
(
__expf
(
val4
.
x
-
max_val
),
__expf
(
val4
.
y
-
max_val
),
__expf
(
val4
.
z
-
max_val
),
__expf
(
val4
.
w
-
max_val
));
float
*
tmp4p
=
reinterpret_cast
<
float
*>
(
&
tmp4
);
float
invsum
=
1.
f
/
(
math
::
warpReduceSum
<
float
>
(
tmp4
.
x
+
tmp4
.
y
+
tmp4
.
z
+
tmp4
.
w
,
0xffffffff
)
+
1e-6
f
);
for
(
int
i
=
0
;
i
<
VPT
;
++
i
)
{
bufp
[
i
]
=
static_cast
<
T
>
(
tmp4p
[
i
]
*
invsum
);
}
reinterpret_cast
<
VECT
*>
(
&
dst
[
offset
+
idx
])[
0
]
=
buf
;
}
template
<
typename
T
,
int
VPT
,
int
WARP_PER_BLOCK
>
__global__
void
VecSoftmaxBackward
(
T
*
dst
,
const
T
*
grad
,
const
T
*
src
,
const
int
batch_size
,
const
int
softmax_ele
)
{
const
int
offset
=
blockIdx
.
x
*
softmax_ele
*
WARP_PER_BLOCK
+
threadIdx
.
x
*
VPT
;
float
local_sum_gy
=
0.
f
;
vec_t
<
T
,
VPT
>
local_grad
;
vec_t
<
T
,
VPT
>
local_src
;
local_grad
.
s
=
reinterpret_cast
<
const
decltype
(
local_grad
.
s
)
*>
(
&
grad
[
offset
])[
0
];
local_src
.
s
=
reinterpret_cast
<
const
decltype
(
local_src
.
s
)
*>
(
&
src
[
offset
])[
0
];
for
(
int
i
=
0
;
i
<
VPT
;
++
i
)
{
local_sum_gy
+=
static_cast
<
float
>
(
local_grad
.
v
[
i
])
*
static_cast
<
float
>
(
local_src
.
v
[
i
]);
}
float
sum_gy
=
math
::
warpReduceSum
<
float
>
(
local_sum_gy
,
0xffffffff
);
vec_t
<
T
,
VPT
>
local_dst
;
for
(
int
i
=
0
;
i
<
VPT
;
++
i
)
{
local_dst
.
v
[
i
]
=
static_cast
<
T
>
(
static_cast
<
float
>
(
local_src
.
v
[
i
])
*
(
static_cast
<
float
>
(
local_grad
.
v
[
i
])
-
sum_gy
));
}
reinterpret_cast
<
decltype
(
local_dst
.
s
)
*>
(
&
dst
[
offset
])[
0
]
=
local_dst
.
s
;
}
template
<
typename
T
>
class
SoftmaxCUDNNKernel
:
public
framework
::
OpKernel
<
T
>
{
public:
...
...
@@ -54,20 +130,42 @@ class SoftmaxCUDNNKernel : public framework::OpKernel<T> {
const
int
N
=
SizeToAxis
(
axis
,
dims
);
const
int
D
=
SizeOutAxis
(
axis
,
dims
);
ScopedTensorDescriptor
desc
;
std
::
vector
<
int
>
tensor_dims
=
{
N
,
dim
,
D
,
1
};
DataLayout
layout
=
DataLayout
::
kNCHW
;
cudnnTensorDescriptor_t
desc_
=
desc
.
descriptor
<
T
>
(
layout
,
tensor_dims
);
auto
&
dev_ctx
=
ctx
.
template
device_context
<
platform
::
CUDADeviceContext
>();
auto
handle
=
dev_ctx
.
cudnn_handle
();
auto
mode
=
axis
==
rank
-
1
?
CUDNN_SOFTMAX_MODE_INSTANCE
:
CUDNN_SOFTMAX_MODE_CHANNEL
;
PADDLE_ENFORCE_CUDA_SUCCESS
(
platform
::
dynload
::
cudnnSoftmaxForward
(
handle
,
CUDNN_SOFTMAX_ACCURATE
,
mode
,
platform
::
CudnnDataType
<
T
>::
kOne
(),
desc_
,
x
->
data
<
T
>
(),
platform
::
CudnnDataType
<
T
>::
kZero
(),
desc_
,
out_data
));
constexpr
int
warps_per_block
=
4
;
if
(
D
==
1
&&
dim
==
128
&&
N
%
warps_per_block
==
0
&&
sizeof
(
T
)
<=
4
)
{
// a warp for a batch, 4 elements for a thread, only support the softmax
// dim size = 128 currently
if
(
sizeof
(
T
)
==
2
)
{
VecSoftmaxForward
<
T
,
int2
,
4
,
warps_per_block
><<<
N
/
warps_per_block
,
warps_per_block
*
WARP_SIZE
,
0
,
ctx
.
cuda_device_context
().
stream
()
>>>
(
out_data
,
x
->
data
<
T
>
(),
N
,
dim
);
}
else
if
(
sizeof
(
T
)
==
4
)
{
VecSoftmaxForward
<
T
,
int4
,
4
,
warps_per_block
><<<
N
/
warps_per_block
,
warps_per_block
*
WARP_SIZE
,
0
,
ctx
.
cuda_device_context
().
stream
()
>>>
(
out_data
,
x
->
data
<
T
>
(),
N
,
dim
);
}
else
{
assert
(
false
&&
"not support"
);
}
}
else
{
ScopedTensorDescriptor
desc
;
std
::
vector
<
int
>
tensor_dims
=
{
N
,
dim
,
D
,
1
};
DataLayout
layout
=
DataLayout
::
kNCHW
;
cudnnTensorDescriptor_t
desc_
=
desc
.
descriptor
<
T
>
(
layout
,
tensor_dims
);
auto
&
dev_ctx
=
ctx
.
template
device_context
<
platform
::
CUDADeviceContext
>();
auto
handle
=
dev_ctx
.
cudnn_handle
();
auto
mode
=
axis
==
rank
-
1
?
CUDNN_SOFTMAX_MODE_INSTANCE
:
CUDNN_SOFTMAX_MODE_CHANNEL
;
PADDLE_ENFORCE_CUDA_SUCCESS
(
platform
::
dynload
::
cudnnSoftmaxForward
(
handle
,
CUDNN_SOFTMAX_ACCURATE
,
mode
,
platform
::
CudnnDataType
<
T
>::
kOne
(),
desc_
,
x
->
data
<
T
>
(),
platform
::
CudnnDataType
<
T
>::
kZero
(),
desc_
,
out_data
));
}
}
};
...
...
@@ -88,20 +186,49 @@ class SoftmaxGradCUDNNKernel : public framework::OpKernel<T> {
const
int
N
=
SizeToAxis
(
axis
,
dims
);
const
int
D
=
SizeOutAxis
(
axis
,
dims
);
ScopedTensorDescriptor
desc
;
std
::
vector
<
int
>
tensor_dims
=
{
N
,
dim
,
D
,
1
};
DataLayout
layout
=
DataLayout
::
kNCHW
;
cudnnTensorDescriptor_t
desc_
=
desc
.
descriptor
<
T
>
(
layout
,
tensor_dims
);
auto
&
dev_ctx
=
ctx
.
template
device_context
<
platform
::
CUDADeviceContext
>();
auto
handle
=
dev_ctx
.
cudnn_handle
();
auto
mode
=
axis
==
rank
-
1
?
CUDNN_SOFTMAX_MODE_INSTANCE
:
CUDNN_SOFTMAX_MODE_CHANNEL
;
PADDLE_ENFORCE_CUDA_SUCCESS
(
platform
::
dynload
::
cudnnSoftmaxBackward
(
handle
,
CUDNN_SOFTMAX_ACCURATE
,
mode
,
platform
::
CudnnDataType
<
T
>::
kOne
(),
desc_
,
out
->
data
<
T
>
(),
desc_
,
dout
->
data
<
T
>
(),
platform
::
CudnnDataType
<
T
>::
kZero
(),
desc_
,
dx_data
));
constexpr
int
warps_per_block
=
4
;
constexpr
bool
warp_softmax_available
=
std
::
is_same
<
T
,
float
>::
value
||
std
::
is_same
<
T
,
platform
::
float16
>::
value
;
if
(
D
==
1
&&
dim
==
128
&&
N
%
warps_per_block
==
0
&&
warp_softmax_available
)
{
if
(
std
::
is_same
<
T
,
float
>::
value
)
{
VecSoftmaxBackward
<
float
,
4
,
warps_per_block
><<<
N
/
warps_per_block
,
warps_per_block
*
WARP_SIZE
,
0
,
ctx
.
cuda_device_context
().
stream
()
>>>
(
dx
->
data
<
float
>
(),
dout
->
data
<
float
>
(),
out
->
data
<
float
>
(),
N
,
dim
);
}
else
if
(
std
::
is_same
<
T
,
platform
::
float16
>::
value
)
{
VecSoftmaxBackward
<
platform
::
float16
,
4
,
warps_per_block
><<<
N
/
warps_per_block
,
warps_per_block
*
WARP_SIZE
,
0
,
ctx
.
cuda_device_context
().
stream
()
>>>
(
dx
->
data
<
platform
::
float16
>
(),
dout
->
data
<
platform
::
float16
>
(),
out
->
data
<
platform
::
float16
>
(),
N
,
dim
);
}
else
{
PADDLE_ENFORCE_EQ
(
warp_softmax_available
,
true
,
platform
::
errors
::
Unimplemented
(
"Warp softmax backward is only available for fp32 and fp16"
));
}
}
else
{
ScopedTensorDescriptor
desc
;
std
::
vector
<
int
>
tensor_dims
=
{
N
,
dim
,
D
,
1
};
DataLayout
layout
=
DataLayout
::
kNCHW
;
cudnnTensorDescriptor_t
desc_
=
desc
.
descriptor
<
T
>
(
layout
,
tensor_dims
);
auto
&
dev_ctx
=
ctx
.
template
device_context
<
platform
::
CUDADeviceContext
>();
auto
handle
=
dev_ctx
.
cudnn_handle
();
auto
mode
=
axis
==
rank
-
1
?
CUDNN_SOFTMAX_MODE_INSTANCE
:
CUDNN_SOFTMAX_MODE_CHANNEL
;
PADDLE_ENFORCE_CUDA_SUCCESS
(
platform
::
dynload
::
cudnnSoftmaxBackward
(
handle
,
CUDNN_SOFTMAX_ACCURATE
,
mode
,
platform
::
CudnnDataType
<
T
>::
kOne
(),
desc_
,
out
->
data
<
T
>
(),
desc_
,
dout
->
data
<
T
>
(),
platform
::
CudnnDataType
<
T
>::
kZero
(),
desc_
,
dx_data
));
}
}
};
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录