PaddlePaddle / Paddle — commit 1d7b75dd

Support Ternary ops in elementwise and broadcast (#33976)

Authored by limingshu on Aug 05, 2021; committed via GitHub on Aug 05, 2021.
Parent commit: a68709d8
Changes: 5 files changed, 212 additions and 195 deletions.

  paddle/fluid/operators/elementwise/elementwise_add_op.h           +0    -1
  paddle/fluid/operators/elementwise/elementwise_op_broadcast.cu.h  +62   -63
  paddle/fluid/operators/elementwise/elementwise_op_impl.cu.h       +109  -125
  paddle/fluid/operators/reduce_ops/reduce_op.cu.h                  +5    -4
  paddle/fluid/platform/fast_divmod.h                               +36   -2
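For orientation (an illustration, not code from this commit): the rewritten kernels call the user functor as func(args), where args is a per-thread array holding ET loaded inputs. A three-input functor along the following lines is therefore what the new kTernary path enables; the name and the select semantics are assumptions, only the calling convention comes from the kernels in this diff.

// Hypothetical ternary functor sketch. HOSTDEVICE is Paddle's host/device
// qualifier macro; TernarySelectFunctor itself is not shipped by this commit.
template <typename T>
struct TernarySelectFunctor {
  inline HOSTDEVICE T operator()(const T *args) const {
    // args[0] acts as the condition, args[1]/args[2] as the two candidates.
    return args[0] > static_cast<T>(0) ? args[1] : args[2];
  }
};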
paddle/fluid/operators/elementwise/elementwise_add_op.h

@@ -17,7 +17,6 @@ limitations under the License. */
 #include <utility>
 #include "paddle/fluid/operators/elementwise/elementwise_op.h"
-#include "paddle/fluid/operators/elementwise/elementwise_op_function.cu.h"
 #include "paddle/fluid/operators/elementwise/elementwise_op_function.h"
 #include "paddle/fluid/operators/math/blas.h"
 #include "paddle/fluid/operators/math/math_function.h"
paddle/fluid/operators/elementwise/elementwise_op_broadcast.cu.h

@@ -163,7 +163,7 @@ struct DimensionsTransform {

 struct StridesCalculation {
   std::vector<std::vector<uint32_t>> strides;
-  std::vector<FastDivMod> divmoders;
+  std::vector<platform::FastDivMod> divmoders;

  private:
  // To calculate the strides of each input_tensor.
@@ -190,7 +190,7 @@ struct StridesCalculation {
     strides.resize(N, std::vector<uint32_t>(dim_size, 1));

     for (int i = 0; i < dim_size; ++i) {
-      divmoders[i] = FastDivMod(out_dims[i]);
+      divmoders[i] = platform::FastDivMod(out_dims[i]);
     }
     CalculateStrides(N, dim_size, in_dims);
   }
@@ -198,21 +198,21 @@ struct StridesCalculation {
template
<
typename
InT
,
typename
OutT
,
typename
Functor
,
ElementwiseType
ET
,
int
VecSize
,
int
kDims
>
struct
BroadcastArgsW
ar
pper
{
using
InVecType
=
CudaAlignedVector
<
InT
,
VecSize
>
;
using
OutVecType
=
CudaAlignedVector
<
OutT
,
VecSize
>
;
struct
BroadcastArgsW
ra
pper
{
using
InVecType
=
platform
::
CudaAlignedVector
<
InT
,
VecSize
>
;
using
OutVecType
=
platform
::
CudaAlignedVector
<
OutT
,
VecSize
>
;
OutT
*
out_data
;
OutVecType
*
vec_out_data
;
const
InT
*
__restrict__
in_data
[
ET
];
const
InVecType
*
__restrict__
vec_in_data
[
ET
];
bool
no_broadcast
[
ET
];
FastDivMod
divmoders
[
kDims
];
platform
::
FastDivMod
divmoders
[
kDims
];
uint32_t
strides
[
ET
][
framework
::
DDim
::
kMaxRank
];
uint32_t
scalar_cal_offset
;
Functor
func
;
HOSTDEVICE
BroadcastArgsW
ar
pper
(
HOSTDEVICE
BroadcastArgsW
ra
pper
(
const
std
::
vector
<
const
framework
::
Tensor
*>
&
ins
,
framework
::
Tensor
*
out
,
int
scalar_cal_offset
,
Functor
func
,
const
StridesCalculation
&
offset_calculator
)
...
...
@@ -227,7 +227,7 @@ struct BroadcastArgsWarpper {
     out_data = out->data<OutT>();
     vec_out_data = reinterpret_cast<OutVecType *>(out_data);
     memcpy(divmoders, offset_calculator.divmoders.data(),
-           kDims * sizeof(FastDivMod));
+           kDims * sizeof(platform::FastDivMod));
   }

   __device__ __forceinline__ uint32_t GetOffsetByDivmod(int idx, int in_idx) {
@@ -302,30 +302,29 @@ struct BroadcastArgsWarpper {
   }
 };

-template <typename InT, typename OutT, typename BroadcastArgsWarpper,
+template <typename InT, typename OutT, typename BroadcastArgsWrapper,
           ElementwiseType ET>
 __device__ inline void ScalarizedBroadcastKernelImpl(
-    BroadcastArgsWarpper broadcast_warpper, int tid) {
+    BroadcastArgsWrapper broadcast_wrapper, int tid) {
   InT args[ET];
   OutT args_out;
-  broadcast_warpper.LoadScalarizedData(args, tid);
+  broadcast_wrapper.LoadScalarizedData(args, tid);

-#pragma unroll(ET)
-  for (int j = 1; j < ET; ++j) {
-    args_out = broadcast_warpper.func(args);
-  }
-  broadcast_warpper.StoreScalarizedData(args_out, tid);
+  // Calcualtion of the in_tensor data.
+  args_out = broadcast_wrapper.func(args);
+  broadcast_wrapper.StoreScalarizedData(args_out, tid);
 }

-template <typename InT, typename OutT, typename BroadcastArgsWarpper,
+template <typename InT, typename OutT, typename BroadcastArgsWrapper,
           ElementwiseType ET, int VecSize>
 __device__ inline void VectorizedBroadcastKernelImpl(
-    BroadcastArgsWarpper broadcast_warpper, int tid) {
-  using OutVecType = CudaAlignedVector<OutT, VecSize>;
+    BroadcastArgsWrapper broadcast_wrapper, int tid) {
+  using OutVecType = platform::CudaAlignedVector<OutT, VecSize>;
   OutVecType args_out;
   InT ins[ET];
   InT args[ET][VecSize];
-  broadcast_warpper.LoadVectorizedData(args, tid);
+  broadcast_wrapper.LoadVectorizedData(args, tid);

#pragma unroll(VecSize)
   for (int i = 0; i < VecSize; ++i) {
@@ -333,30 +332,30 @@ __device__ inline void VectorizedBroadcastKernelImpl(
     for (int j = 0; j < ET; ++j) {
       ins[j] = args[j][i];
     }
-    args_out.val[i] = broadcast_warpper.func(ins);
+    args_out.val[i] = broadcast_wrapper.func(ins);
   }
-  broadcast_warpper.StoreVectorizedData(args_out, tid);
+  broadcast_wrapper.StoreVectorizedData(args_out, tid);
 }

-template <typename InT, typename OutT, typename BroadcastArgsWarpper,
+template <typename InT, typename OutT, typename BroadcastArgsWrapper,
           ElementwiseType ET, int VecSize>
 __global__ void ElementwiseBroadcastKernel(
-    BroadcastArgsWarpper broadcast_warpper, int main_tid, int tail_tid) {
+    BroadcastArgsWrapper broadcast_wrapper, int main_tid, int tail_tid) {
   int tid = threadIdx.x + blockIdx.x * blockDim.x;

   // Vectorized calculation of major data whose length is the max multipler of
   // VecSize,
   // eg: Calcualting the front 1024-length data in total 1027 data once VecSize
   // is 4.
   if (tid < main_tid) {
-    VectorizedBroadcastKernelImpl<InT, OutT, BroadcastArgsWarpper, ET, VecSize>(
-        broadcast_warpper, tid);
+    VectorizedBroadcastKernelImpl<InT, OutT, BroadcastArgsWrapper, ET, VecSize>(
+        broadcast_wrapper, tid);
   }
   // Scalarzed calculation of rest data whose lenght cannot fulfill VecSize.
   // eg: Calcualting the rest 3-length data in total 1027 data once VecSize is
   // 4.
   if (tid < tail_tid) {
-    ScalarizedBroadcastKernelImpl<InT, OutT, BroadcastArgsWarpper, ET>(
-        broadcast_warpper, tid);
+    ScalarizedBroadcastKernelImpl<InT, OutT, BroadcastArgsWrapper, ET>(
+        broadcast_wrapper, tid);
   }
 }
@@ -367,7 +366,7 @@ void LaunchBroadcastKernelForDifferentDimSize(
     const std::vector<const framework::Tensor *> &ins, framework::Tensor *out,
     int axis, Functor func) {
   int numel = out->numel();
-  const int threads = 256;
+  int threads = GetThreadsConfig(ctx, numel, VecSize);
   int blocks = ((numel + VecSize - 1) / VecSize + threads - 1) / threads;

   int main_tid = numel / VecSize;
   int tail_tid = numel % VecSize;
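To make the main/tail split above concrete, here is the arithmetic for the 1027-element example used in the kernel comments, taking VecSize = 4 and, for illustration only, 256 threads per block (GetThreadsConfig may choose a different block size):

#include <cassert>

// Minimal host-side sketch (not part of the diff) reproducing the split.
int main() {
  const int numel = 1027, vec_size = 4, threads = 256;
  int main_tid = numel / vec_size;  // 256 threads each handle one 4-wide load (1024 elements)
  int tail_tid = numel % vec_size;  // 3 threads handle the trailing 3 scalars
  int blocks = ((numel + vec_size - 1) / vec_size + threads - 1) / threads;
  assert(main_tid == 256 && tail_tid == 3 && blocks == 2);
  return 0;
}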
@@ -380,75 +379,75 @@
   switch (merge_dims.dim_size) {
     case 1: {
-      auto broadcast_warpper =
-          BroadcastArgsWarpper<InT, OutT, Functor, ET, VecSize, 1>(
+      auto broadcast_wrapper =
+          BroadcastArgsWrapper<InT, OutT, Functor, ET, VecSize, 1>(
               ins, out, vec_len, func, offset_calculator);
-      ElementwiseBroadcastKernel<InT, OutT, decltype(broadcast_warpper), ET,
+      ElementwiseBroadcastKernel<InT, OutT, decltype(broadcast_wrapper), ET,
                                  VecSize><<<blocks, threads, 0, stream>>>(
-          broadcast_warpper, main_tid, tail_tid);
+          broadcast_wrapper, main_tid, tail_tid);
       break;
     }
     case 2: {
-      auto broadcast_warpper =
-          BroadcastArgsWarpper<InT, OutT, Functor, ET, VecSize, 2>(
+      auto broadcast_wrapper =
+          BroadcastArgsWrapper<InT, OutT, Functor, ET, VecSize, 2>(
               ins, out, vec_len, func, offset_calculator);
-      ElementwiseBroadcastKernel<InT, OutT, decltype(broadcast_warpper), ET,
+      ElementwiseBroadcastKernel<InT, OutT, decltype(broadcast_wrapper), ET,
                                  VecSize><<<blocks, threads, 0, stream>>>(
-          broadcast_warpper, main_tid, tail_tid);
+          broadcast_wrapper, main_tid, tail_tid);
       break;
     }
     case 3: {
-      auto broadcast_warpper =
-          BroadcastArgsWarpper<InT, OutT, Functor, ET, VecSize, 3>(
+      auto broadcast_wrapper =
+          BroadcastArgsWrapper<InT, OutT, Functor, ET, VecSize, 3>(
               ins, out, vec_len, func, offset_calculator);
-      ElementwiseBroadcastKernel<InT, OutT, decltype(broadcast_warpper), ET,
+      ElementwiseBroadcastKernel<InT, OutT, decltype(broadcast_wrapper), ET,
                                  VecSize><<<blocks, threads, 0, stream>>>(
-          broadcast_warpper, main_tid, tail_tid);
+          broadcast_wrapper, main_tid, tail_tid);
       break;
     }
     case 4: {
-      auto broadcast_warpper =
-          BroadcastArgsWarpper<InT, OutT, Functor, ET, VecSize, 4>(
+      auto broadcast_wrapper =
+          BroadcastArgsWrapper<InT, OutT, Functor, ET, VecSize, 4>(
               ins, out, vec_len, func, offset_calculator);
-      ElementwiseBroadcastKernel<InT, OutT, decltype(broadcast_warpper), ET,
+      ElementwiseBroadcastKernel<InT, OutT, decltype(broadcast_wrapper), ET,
                                  VecSize><<<blocks, threads, 0, stream>>>(
-          broadcast_warpper, main_tid, tail_tid);
+          broadcast_wrapper, main_tid, tail_tid);
       break;
     }
     case 5: {
-      auto broadcast_warpper =
-          BroadcastArgsWarpper<InT, OutT, Functor, ET, VecSize, 5>(
+      auto broadcast_wrapper =
+          BroadcastArgsWrapper<InT, OutT, Functor, ET, VecSize, 5>(
               ins, out, vec_len, func, offset_calculator);
-      ElementwiseBroadcastKernel<InT, OutT, decltype(broadcast_warpper), ET,
+      ElementwiseBroadcastKernel<InT, OutT, decltype(broadcast_wrapper), ET,
                                  VecSize><<<blocks, threads, 0, stream>>>(
-          broadcast_warpper, main_tid, tail_tid);
+          broadcast_wrapper, main_tid, tail_tid);
       break;
     }
     case 6: {
-      auto broadcast_warpper =
-          BroadcastArgsWarpper<InT, OutT, Functor, ET, VecSize, 6>(
+      auto broadcast_wrapper =
+          BroadcastArgsWrapper<InT, OutT, Functor, ET, VecSize, 6>(
               ins, out, vec_len, func, offset_calculator);
-      ElementwiseBroadcastKernel<InT, OutT, decltype(broadcast_warpper), ET,
+      ElementwiseBroadcastKernel<InT, OutT, decltype(broadcast_wrapper), ET,
                                  VecSize><<<blocks, threads, 0, stream>>>(
-          broadcast_warpper, main_tid, tail_tid);
+          broadcast_wrapper, main_tid, tail_tid);
       break;
     }
     case 7: {
-      auto broadcast_warpper =
-          BroadcastArgsWarpper<InT, OutT, Functor, ET, VecSize, 7>(
+      auto broadcast_wrapper =
+          BroadcastArgsWrapper<InT, OutT, Functor, ET, VecSize, 7>(
               ins, out, vec_len, func, offset_calculator);
-      ElementwiseBroadcastKernel<InT, OutT, decltype(broadcast_warpper), ET,
+      ElementwiseBroadcastKernel<InT, OutT, decltype(broadcast_wrapper), ET,
                                  VecSize><<<blocks, threads, 0, stream>>>(
-          broadcast_warpper, main_tid, tail_tid);
+          broadcast_wrapper, main_tid, tail_tid);
       break;
     }
     case 8: {
-      auto broadcast_warpper =
-          BroadcastArgsWarpper<InT, OutT, Functor, ET, VecSize, 8>(
+      auto broadcast_wrapper =
+          BroadcastArgsWrapper<InT, OutT, Functor, ET, VecSize, 8>(
               ins, out, vec_len, func, offset_calculator);
-      ElementwiseBroadcastKernel<InT, OutT, decltype(broadcast_warpper), ET,
+      ElementwiseBroadcastKernel<InT, OutT, decltype(broadcast_wrapper), ET,
                                  VecSize><<<blocks, threads, 0, stream>>>(
-          broadcast_warpper, main_tid, tail_tid);
+          broadcast_wrapper, main_tid, tail_tid);
       break;
     }
     default: {
@@ -473,11 +472,11 @@ void LaunchBroadcastElementwiseCudaKernel(
   int in_vec_size = 4;
   framework::Tensor *out = (*outs)[0];
   for (auto *in : ins) {
-    auto temp_size = GetVectorizedSizeImpl<InT>(in->data<InT>());
+    auto temp_size = platform::GetVectorizedSize<InT>(in->data<InT>());
     in_vec_size = in->dims() == out->dims() ? std::min(temp_size, in_vec_size)
                                             : in_vec_size;
   }
-  int out_vec_size = GetVectorizedSizeImpl<OutT>(out->data<OutT>());
+  int out_vec_size = platform::GetVectorizedSize<OutT>(out->data<OutT>());
   int vec_size = std::min(out_vec_size, in_vec_size);

   switch (vec_size) {
paddle/fluid/operators/elementwise/elementwise_op_impl.cu.h

@@ -26,7 +26,7 @@ limitations under the License. */
 namespace paddle {
 namespace operators {

-enum ElementwiseType { kUnary = 1, kBinary = 2 };
+enum ElementwiseType { kUnary = 1, kBinary = 2, kTernary = 3 };

 /*
 * According to NVIDIA, if number of threads per block is 64/128/256/512,
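A side note on how the new enumerator is consumed (illustration only, not part of the diff): because ElementwiseType is an unscoped enum with explicit values, ET can be used directly as a compile-time array extent, which is how the wrappers below size their per-thread pointer and argument buffers.

// Sketch: ET doubles as the number of inputs, so ET == kTernary yields
// three-slot arrays; the struct name is hypothetical.
template <ElementwiseType ET, typename InT>
struct PerThreadArgs {
  const InT *in_data[ET];  // three input pointers when ET == kTernary
  InT args[ET];            // three loaded values per thread
};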
@@ -52,98 +52,73 @@ inline int GetThreadsConfig(const platform::CUDADeviceContext &ctx,
   return std::max(64, threads);
 }

-/*
-* Only the address of input data is the multiplier of 1,2,4, vectorized load
-* with corresponding multiplier-value is possible. Moreover, the maximum length
-* of vectorized load is 128 bits once. Hence, valid length of vectorized load
-* shall be determined under both former constraints.
-*/
-template <typename T>
-int GetVectorizedSizeImpl(const T *pointer) {
-  constexpr int max_load_bits = 128;
-  int valid_vec_size = max_load_bits / CHAR_BIT / sizeof(T);
-  uint64_t address = reinterpret_cast<uint64_t>(pointer);
-  constexpr int vec8 =
-      std::alignment_of<CudaAlignedVector<T, 8>>::value;  // NOLINT
-  constexpr int vec4 =
-      std::alignment_of<CudaAlignedVector<T, 4>>::value;  // NOLINT
-  constexpr int vec2 =
-      std::alignment_of<CudaAlignedVector<T, 2>>::value;  // NOLINT
-  if (address % vec8 == 0) {
-    /*
-    * Currently, decide to deal with no more than 4 data once while adopting
-    * vectorization load/store, if performance test shows that dealing with
-    * 8 data once in vectorization load/store does get optimized, return code
-    * below can be changed into " return std::min(8, valid_vec_size); " .
-    */
-    return std::min(4, valid_vec_size);
-  } else if (address % vec4 == 0) {
-    return std::min(4, valid_vec_size);
-  } else if (address % vec2 == 0) {
-    return std::min(2, valid_vec_size);
-  } else {
-    return 1;
-  }
-}
-
 template <typename InT, typename OutT>
-int GetVectorizedSize(const std::vector<const framework::Tensor *> &ins,
-                      const std::vector<framework::Tensor *> &outs) {
+int GetVectorizedSizeForIO(const std::vector<const framework::Tensor *> &ins,
+                           const std::vector<framework::Tensor *> &outs) {
   int vec_size = 4;
   for (auto iter = ins.begin(); iter != ins.end(); ++iter) {
-    vec_size = std::min<int>(vec_size,
-                             GetVectorizedSizeImpl((*iter)->data<InT>()));
+    vec_size = std::min<int>(
+        vec_size, platform::GetVectorizedSize((*iter)->data<InT>()));
   }
   for (auto iter = outs.begin(); iter != outs.end(); ++iter) {
-    vec_size = std::min<int>(vec_size,
-                             GetVectorizedSizeImpl((*iter)->data<OutT>()));
+    vec_size = std::min<int>(
+        vec_size, platform::GetVectorizedSize((*iter)->data<OutT>()));
   }
   return vec_size;
 }

 template <ElementwiseType ET, int VecSize, typename InT, typename OutT>
 struct ElementwiseDataWrapper {
-  OutT *out;
-  const InT *in0;
-  const InT *in1;
-  __device__ ElementwiseDataWrapper(OutT *out, const InT *in0,
-                                    const InT *in1 = nullptr)
-      : out(out), in0(in0), in1(in1) {}
-
-  using InVecType = CudaAlignedVector<InT, VecSize>;
-  using OutVecType = CudaAlignedVector<OutT, VecSize>;
-
-  inline __device__ void load_vector(InVecType args[], int idx) {
-    const InVecType *x_vec = reinterpret_cast<const InVecType *>(in0);
-    args[0] = x_vec[idx];
-    if (ET == ElementwiseType::kBinary) {
-      const InVecType *y_vec = reinterpret_cast<const InVecType *>(in1);
-      args[1] = y_vec[idx];
-    }
-  }
+  using InVecType = platform::CudaAlignedVector<InT, VecSize>;
+  using OutVecType = platform::CudaAlignedVector<OutT, VecSize>;
+
+  const InT *__restrict__ in_data[ET];
+  OutT *out_data;
+  uint32_t scalar_cal_offset;
+
+  HOSTDEVICE ElementwiseDataWrapper(
+      const std::vector<const framework::Tensor *> &ins,
+      std::vector<framework::Tensor *> *outs, uint32_t scalar_cal_offset)
+      : scalar_cal_offset(scalar_cal_offset) {
+#pragma unroll
+    for (int i = 0; i < ET; ++i) {
+      in_data[i] = ins[i]->data<InT>();
+    }
+    out_data = (*outs)[0]->data<OutT>();
+  }
+
+  inline __device__ void LoadVectorizedData(InVecType vec_args[], int tid) {
+#pragma unroll
+    for (int i = 0; i < ET; ++i) {
+      const InVecType *in_vec_data =
+          reinterpret_cast<const InVecType *>(in_data[i]);
+      vec_args[i] = in_vec_data[tid];
+    }
+  }

-  inline __device__ void load_scalar(InT args[], int idx) {
-    args[0] = in0[idx];
-    if (ET == ElementwiseType::kBinary) {
-      args[1] = in1[idx];
-    }
-  }
+  inline __device__ void LoadScalarizedData(InT args[], int tid) {
+#pragma unroll
+    for (int i = 0; i < ET; ++i) {
+      args[i] = in_data[i][tid + scalar_cal_offset];
+    }
+  }

-  inline __device__ void store_vector(OutVecType res, int idx) {
-    OutVecType *out_vec = reinterpret_cast<OutVecType *>(out);
-    out_vec[idx] = res;
-  }
+  inline __device__ void StoreVectorizedData(OutVecType res, int tid) {
+    OutVecType *out_vec = reinterpret_cast<OutVecType *>(out_data);
+    out_vec[tid] = res;
+  }

-  inline __device__ void store_scalar(OutT res, int idx) { out[idx] = res; }
+  inline __device__ void StoreScalarizedData(OutT res, int tid) {
+    out_data[tid + scalar_cal_offset] = res;
+  }
 };

-template <ElementwiseType ET, int VecSize, typename InT, typename OutT,
-          typename Functor>
-__device__ inline void VectorizedKernelImpl(
-    ElementwiseDataWrapper<ET, VecSize, InT, OutT> data, Functor func,
-    int tid) {
-  using InVecType = CudaAlignedVector<InT, VecSize>;
-  using OutVecType = CudaAlignedVector<OutT, VecSize>;
+template <ElementwiseType ET, int VecSize, typename ElementwiseWrapper,
+          typename InT, typename OutT, typename Functor>
+__device__ inline void VectorizedKernelImpl(ElementwiseWrapper data,
+                                            Functor func, int tid) {
+  using InVecType = platform::CudaAlignedVector<InT, VecSize>;
+  using OutVecType = platform::CudaAlignedVector<OutT, VecSize>;
   InVecType ins_vec[ET];
   OutVecType out_vec;
   InT *ins_ptr[ET];
@@ -153,7 +128,7 @@ __device__ inline void VectorizedKernelImpl(
     ins_ptr[i] = reinterpret_cast<InT *>(&(ins_vec[i]));
   }
   // load
-  data.load_vector(ins_vec, tid);
+  data.LoadVectorizedData(ins_vec, tid);

 // compute
 #pragma unroll
@@ -165,52 +140,48 @@ __device__ inline void VectorizedKernelImpl(
     out_vec.val[i] = func(ins);
   }
   // store
-  data.store_vector(out_vec, tid);
+  data.StoreVectorizedData(out_vec, tid);
 }

-template <ElementwiseType ET, int VecSize, typename InT, typename OutT,
-          typename Functor>
-__device__ inline void ScalarKernelImpl(
-    ElementwiseDataWrapper<ET, VecSize, InT, OutT> data, Functor func,
-    int start, int remain) {
+template <ElementwiseType ET, typename ElementwiseWrapper, typename InT,
+          typename OutT, typename Functor>
+__device__ inline void ScalarKernelImpl(ElementwiseWrapper data, Functor func,
+                                        int tid) {
   InT ins[ET];
   OutT out;

-  for (int i = 0; i < remain; ++i) {
-    int idx = start + i;
-    // load
-    data.load_scalar(ins, idx);
-    // compute
-    out = func(ins);
-    // store
-    data.store_scalar(out, idx);
-  }
+  // load
+  data.LoadScalarizedData(ins, tid);
+  // compute
+  out = func(ins);
+  // store
+  data.StoreScalarizedData(out, tid);
 }

-template <ElementwiseType ET, int VecSize, typename InT, typename OutT,
-          typename Functor>
-__global__ void VectorizedKernel(const InT *__restrict__ in0,
-                                 const InT *__restrict__ in1, OutT *out,
-                                 int size, Functor func) {
+template <ElementwiseType ET, typename ElementwiseWrapper, typename InT,
+          typename OutT, int VecSize, typename Functor>
+__global__ void VectorizedKernel(ElementwiseWrapper data, int main_tid,
+                                 int tail_tid, Functor func) {
   int tid = blockIdx.x * blockDim.x + threadIdx.x;
-  int remain = size - VecSize * tid;
-  remain = remain > 0 ? remain : 0;
-  auto data = ElementwiseDataWrapper<ET, VecSize, InT, OutT>(out, in0, in1);
-  if (remain >= VecSize) {
-    VectorizedKernelImpl(data, func, tid);
-  } else {
-    ScalarKernelImpl(data, func, tid * VecSize, remain);
+  if (tid < main_tid) {
+    VectorizedKernelImpl<ET, VecSize, ElementwiseWrapper, InT, OutT, Functor>(
+        data, func, tid);
+  }
+  if (tid < tail_tid) {
+    ScalarKernelImpl<ET, ElementwiseWrapper, InT, OutT, Functor>(data, func,
+                                                                 tid);
   }
 }

-template <ElementwiseType ET, typename InT, typename OutT, typename Functor>
-__global__ void ScalarKernel(const InT *__restrict__ in0,
-                             const InT *__restrict__ in1, OutT *out, int size,
-                             Functor func) {
-  auto data = ElementwiseDataWrapper<ET, 1, InT, OutT>(out, in0, in1);
+template <ElementwiseType ET, typename ElementwiseWrapper, typename InT,
+          typename OutT, typename Functor>
+__global__ void ScalarKernel(ElementwiseWrapper data, int numel, Functor func) {
   int tid = blockIdx.x * blockDim.x + threadIdx.x;
-  int remain = tid < size ? 1 : 0;
-  ScalarKernelImpl(data, func, tid, remain);
+  if (tid < numel) {
+    ScalarKernelImpl<ET, ElementwiseWrapper, InT, OutT, Functor>(data, func,
+                                                                 tid);
+  }
 }

 template <ElementwiseType ET, typename InT, typename OutT, typename Functor>
@@ -219,35 +190,48 @@ void LaunchSameDimsElementwiseCudaKernel(
     const std::vector<const framework::Tensor *> &ins,
     std::vector<framework::Tensor *> *outs, Functor func) {
   // calculate the max vec_size for all ins and outs
-  auto size = ins[0]->numel();
-  int vec_size = GetVectorizedSize<InT, OutT>(ins, *outs);
-  int block_size = GetThreadsConfig(ctx, size, vec_size);
+  auto numel = ins[0]->numel();
+  int vec_size = GetVectorizedSizeForIO<InT, OutT>(ins, *outs);
+  int block_size = GetThreadsConfig(ctx, numel, vec_size);
   int grid_size =
-      ((size + vec_size - 1) / vec_size + block_size - 1) / block_size;
-  const InT *in0 = ins[0]->data<InT>();
-  const InT *in1 =
-      (ET == ElementwiseType::kBinary) ? ins[1]->data<InT>() : nullptr;
-  OutT *out = (*outs)[0]->data<OutT>();
+      ((numel + vec_size - 1) / vec_size + block_size - 1) / block_size;
+  int main_tid = numel / vec_size;
+  int tail_tid = numel % vec_size;
+  uint32_t vec_len = main_tid * vec_size;

   // cuda kernel
   auto stream = ctx.stream();

   switch (vec_size) {
-    case 4:
-      VectorizedKernel<ET, 4><<<grid_size, block_size, 0, stream>>>(
-          in0, in1, out, size, func);
+    case 4: {
+      auto data_wrapper =
+          ElementwiseDataWrapper<ET, 4, InT, OutT>(ins, outs, vec_len);
+      VectorizedKernel<ET, decltype(data_wrapper), InT, OutT,
+                       4><<<grid_size, block_size, 0, stream>>>(
+          data_wrapper, main_tid, tail_tid, func);
       break;
-    case 2:
-      VectorizedKernel<ET, 2><<<grid_size, block_size, 0, stream>>>(
-          in0, in1, out, size, func);
+    }
+    case 2: {
+      auto data_wrapper =
+          ElementwiseDataWrapper<ET, 2, InT, OutT>(ins, outs, vec_len);
+      VectorizedKernel<ET, decltype(data_wrapper), InT, OutT,
+                       2><<<grid_size, block_size, 0, stream>>>(
+          data_wrapper, main_tid, tail_tid, func);
       break;
-    case 1:
-      ScalarKernel<ET><<<grid_size, block_size, 0, stream>>>(in0, in1, out,
-                                                             size, func);
+    }
+    case 1: {
+      auto data_wrapper =
+          ElementwiseDataWrapper<ET, 1, InT, OutT>(ins, outs, 0);
+      ScalarKernel<ET, decltype(data_wrapper), InT,
+                   OutT><<<grid_size, block_size, 0, stream>>>(data_wrapper,
+                                                               numel, func);
       break;
-    default:
+    }
+    default: {
       PADDLE_THROW(platform::errors::Unimplemented(
           "Unsupported vectorized size: %d !", vec_size));
       break;
+    }
   }
 }
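A hypothetical call site for the rewritten same-dims launcher (sketch only: the device-context argument is not visible in this hunk and is assumed to be the platform::CUDADeviceContext, the tensors are assumed to exist with identical shapes, and TernarySelectFunctor is the illustrative functor from the note near the top, not an operator shipped by this commit):

// cond, x, y, z: framework::Tensor objects of identical shape on the GPU;
// cuda_ctx: the current platform::CUDADeviceContext (assumed parameter).
std::vector<const framework::Tensor *> ins = {&cond, &x, &y};  // ET == 3 inputs
std::vector<framework::Tensor *> outs = {&z};                  // single output
LaunchSameDimsElementwiseCudaKernel<ElementwiseType::kTernary, float, float>(
    cuda_ctx, ins, &outs, TernarySelectFunctor<float>());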
paddle/fluid/operators/reduce_ops/reduce_op.cu.h

@@ -158,12 +158,13 @@ struct IndexCalculator {
       : dim(dim) {
     dims = detail::VectorToArray<int, kMaxRank>(cal_dims);
     strides = detail::VectorToArray<int, kMaxRank>(full_strides);
-    std::vector<FastDivMod> cal_divmoders;
+    std::vector<platform::FastDivMod> cal_divmoders;
     // fast divmod
     for (auto i : cal_strides) {
-      cal_divmoders.push_back(FastDivMod(i));
+      cal_divmoders.push_back(platform::FastDivMod(i));
     }
-    divmoders = detail::VectorToArray<FastDivMod, kMaxRank>(cal_divmoders);
+    divmoders =
+        detail::VectorToArray<platform::FastDivMod, kMaxRank>(cal_divmoders);
   }

   __device__ inline int Get(int offset) const {
@@ -183,7 +184,7 @@ struct IndexCalculator {
   int dim;
   framework::Array<int, kMaxRank> dims;
   framework::Array<int, kMaxRank> strides;
-  framework::Array<FastDivMod, kMaxRank> divmoders;
+  framework::Array<platform::FastDivMod, kMaxRank> divmoders;
 };

 // reduce config
paddle/fluid/platform/fast_divmod.h

@@ -20,7 +20,7 @@ limitations under the License. */
 #define INT_BITS 32

 namespace paddle {
-namespace operators {
+namespace platform {

 template <typename T, int Size>
 struct alignas(sizeof(T) * Size) CudaAlignedVector {
@@ -65,5 +65,39 @@ struct FastDivMod {
   uint32_t multiplier;
 };

-}  // namespace operators
+/*
+* Only the address of input data is the multiplier of 1,2,4, vectorized load
+* with corresponding multiplier-value is possible. Moreover, the maximum length
+* of vectorized load is 128 bits once. Hence, valid length of vectorized load
+* shall be determined under both former constraints.
+*/
+template <typename T>
+int GetVectorizedSize(const T *pointer) {
+  constexpr int max_load_bits = 128;
+  int valid_vec_size = max_load_bits / CHAR_BIT / sizeof(T);
+  uint64_t address = reinterpret_cast<uint64_t>(pointer);
+  constexpr int vec8 =
+      std::alignment_of<CudaAlignedVector<T, 8>>::value;  // NOLINT
+  constexpr int vec4 =
+      std::alignment_of<CudaAlignedVector<T, 4>>::value;  // NOLINT
+  constexpr int vec2 =
+      std::alignment_of<CudaAlignedVector<T, 2>>::value;  // NOLINT
+  if (address % vec8 == 0) {
+    /*
+    * Currently, decide to deal with no more than 4 data once while adopting
+    * vectorization load/store, if performance test shows that dealing with
+    * 8 data once in vectorization load/store does get optimized, return code
+    * below can be changed into " return std::min(8, valid_vec_size); " .
+    */
+    return std::min(4, valid_vec_size);
+  } else if (address % vec4 == 0) {
+    return std::min(4, valid_vec_size);
+  } else if (address % vec2 == 0) {
+    return std::min(2, valid_vec_size);
+  } else {
+    return 1;
+  }
+}
+
+}  // namespace platform
 }  // namespace paddle
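The alignment rule implemented by platform::GetVectorizedSize can be checked on the host. The sketch below mirrors its logic for T = float (the 32/16/8-byte alignment thresholds collapse to widths 4/4/2, with a scalar fallback) rather than including the Paddle header; it is an illustration, not project code.

#include <cstdint>
#include <cstdio>

// Mirror of the float case: 16-byte (or better) alignment allows 4-wide loads,
// 8-byte alignment allows 2-wide loads, anything else falls back to scalars.
static int VecWidthForFloat(const float *p) {
  uint64_t addr = reinterpret_cast<uint64_t>(p);
  if (addr % 16 == 0) return 4;
  if (addr % 8 == 0) return 2;
  return 1;
}

int main() {
  alignas(16) float buf[8] = {0};
  printf("%d %d %d\n",
         VecWidthForFloat(buf),      // 4: 16-byte aligned base
         VecWidthForFloat(buf + 2),  // 2: only 8-byte aligned
         VecWidthForFloat(buf + 1)); // 1: only 4-byte aligned
  return 0;
}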