Unverified commit eae4bf5b
Authored Sep 07, 2021 by niuliling123; committed by GitHub on Sep 07, 2021
Modify the elementwise op according to the kernel primitive API (#34456)
Parent: b211f02b
Showing 2 changed files with 180 additions and 408 deletions (+180 −408):

paddle/fluid/operators/elementwise/elementwise_op_broadcast.cu.h  (+116 −273)
paddle/fluid/operators/elementwise/elementwise_op_impl.cu.h  (+64 −135)
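Both files converge on the same kernel-primitive pattern after this commit: initialize a per-thread register buffer, read a vectorized segment, apply the functor, and write the result back, with a compile-time IsBoundary flag so that only the tail block pays for bounds checks. Below is a minimal self-contained sketch of that pattern; the ReadData/WriteData helpers are simplified stand-ins for the real kps:: primitives, and Square is a hypothetical functor.

    #include <cstdio>
    #include <cuda_runtime.h>

    // Simplified stand-ins for kps::ReadData / kps::WriteData: each thread
    // owns a contiguous chunk of VecSize elements; IsBoundary enables checks.
    template <typename T, int VecSize, bool IsBoundary>
    __device__ void ReadData(T *dst, const T *__restrict__ src, int num) {
    #pragma unroll
      for (int i = 0; i < VecSize; ++i) {
        int idx = threadIdx.x * VecSize + i;
        if (!IsBoundary || idx < num) dst[i] = src[idx];
      }
    }

    template <typename T, int VecSize, bool IsBoundary>
    __device__ void WriteData(T *__restrict__ dst, const T *src, int num) {
    #pragma unroll
      for (int i = 0; i < VecSize; ++i) {
        int idx = threadIdx.x * VecSize + i;
        if (!IsBoundary || idx < num) dst[idx] = src[i];
      }
    }

    // load -> compute -> store, mirroring the shape of DealSegment in the diff.
    template <typename T, int VecSize, typename Functor, bool IsBoundary>
    __device__ void DealSegmentSketch(const T *in, T *out, int num,
                                      Functor func) {
      T args[VecSize];
      T result[VecSize];
      ReadData<T, VecSize, IsBoundary>(args, in, num);  // load
    #pragma unroll
      for (int i = 0; i < VecSize; ++i) result[i] = func(args[i]);  // compute
      WriteData<T, VecSize, IsBoundary>(out, result, num);  // store
    }

    template <typename T, int VecSize, typename Functor>
    __global__ void ElementwiseSketch(const T *in, T *out, int size,
                                      Functor func) {
      int offset = VecSize * blockIdx.x * blockDim.x;
      int num = size - offset;  // elements left for this block
      if (VecSize * (int)blockDim.x > num) {  // tail block: checks enabled
        DealSegmentSketch<T, VecSize, Functor, true>(in + offset, out + offset,
                                                     num, func);
      } else {  // full block: checks compiled out
        DealSegmentSketch<T, VecSize, Functor, false>(in + offset, out + offset,
                                                      num, func);
      }
    }

    struct Square {
      __device__ float operator()(float x) const { return x * x; }
    };

    int main() {
      const int n = 1027, vec = 4, threads = 256;
      float *in, *out;
      cudaMallocManaged(&in, n * sizeof(float));
      cudaMallocManaged(&out, n * sizeof(float));
      for (int i = 0; i < n; ++i) in[i] = static_cast<float>(i);
      int blocks = ((n + vec - 1) / vec + threads - 1) / threads;  // 2
      ElementwiseSketch<float, vec, Square><<<blocks, threads>>>(in, out, n,
                                                                 Square{});
      cudaDeviceSynchronize();
      printf("out[1026] = %.0f\n", out[1026]);  // 1052676 = 1026^2
      return 0;
    }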
paddle/fluid/operators/elementwise/elementwise_op_broadcast.cu.h

@@ -15,10 +15,14 @@
 #pragma once

 #include "paddle/fluid/operators/elementwise/elementwise_op_impl.cu.h"
+#include "paddle/fluid/operators/kernel_primitives/kernel_primitives.h"

 namespace paddle {
 namespace operators {

+#define MAX_INPUT_NUM 3  // the max num of ET for BroadcacstConfig
+
+namespace kps = paddle::operators::kernel_primitives;
+
 struct DimensionsTransform {
   using DimVector = std::vector<int64_t>;
   typedef void (*MergeFunctor)(bool &, std::vector<DimVector> &, DimVector &,
...
@@ -161,202 +165,113 @@ struct DimensionsTransform {
   }
 };

-struct StridesCalculation {
-  std::vector<std::vector<uint32_t>> strides;
-  std::vector<platform::FastDivMod> divmoders;
-
- private:
-  // To calculate the strides of each input_tensor.
-  __inline__ void CalculateStrides(
-      int N, int dim_size, const std::vector<std::vector<int64_t>> &in_dims) {
-    for (int j = 0; j < N; ++j) {
-      for (int i = 0; i < dim_size; ++i) {
-        strides[j][i] = in_dims[j][i] == 1 ? 0 : strides[j][i];
-        strides[j][i] = (i != 0 && strides[j][i] != 0)
-                            ? std::accumulate(in_dims[j].begin(),
-                                              in_dims[j].begin() + i, 1,
-                                              std::multiplies<int64_t>())
-                            : strides[j][i];
-      }
-    }
-  }
-
- public:
-  explicit StridesCalculation(const int64_t &dim_size,
-                              const std::vector<std::vector<int64_t>> &in_dims,
-                              const std::vector<int64_t> &out_dims) {
-    const auto N = in_dims.size();
-    divmoders.resize(dim_size);
-    strides.resize(N, std::vector<uint32_t>(dim_size, 1));
-
-    for (int i = 0; i < dim_size; ++i) {
-      divmoders[i] = platform::FastDivMod(out_dims[i]);
-    }
-    CalculateStrides(N, dim_size, in_dims);
-  }
-};
+template <typename T, int VecSize, int ShapeSize, bool IsBoundary = false>
+__device__ __forceinline__ void LoadData(
+    T *dst, const T *__restrict__ src, uint32_t block_offset,
+    const kps::details::BroadcastConfig<ShapeSize> &config, int numel, int num,
+    bool need_broadcast) {
+  // numel : whole num of output
+  // num: how many data will be deal with in this time
+  if (need_broadcast) {
+    kps::ReadDataBc<T, VecSize, 1, 1, ShapeSize, IsBoundary>(
+        dst, src, block_offset, config, numel, 1, 1);
+  } else {
+    kps::ReadData<T, VecSize, 1, 1, IsBoundary>(dst, src + block_offset, num);
+  }
+}

-template <typename InT, typename OutT, typename Functor, ElementwiseType ET,
-          int VecSize, int kDims>
-struct BroadcastArgsWrapper {
-  using InVecType = platform::AlignedVector<InT, VecSize>;
-  using OutVecType = platform::AlignedVector<OutT, VecSize>;
-
-  OutT *out_data;
-  OutVecType *vec_out_data;
-  const InT *__restrict__ in_data[ET];
-  const InVecType *__restrict__ vec_in_data[ET];
-  bool no_broadcast[ET];
-  platform::FastDivMod divmoders[kDims];
-  uint32_t strides[ET][framework::DDim::kMaxRank];
-  uint32_t scalar_cal_offset;
-  Functor func;
-
-  HOSTDEVICE BroadcastArgsWrapper(
-      const std::vector<const framework::Tensor *> &ins,
-      framework::Tensor *out, int scalar_cal_offset, Functor func,
-      const StridesCalculation &offset_calculator)
-      : scalar_cal_offset(scalar_cal_offset), func(func) {
-    for (int j = 0; j < ET; ++j) {
-      in_data[j] = ins[j]->data<InT>();
-      vec_in_data[j] = reinterpret_cast<const InVecType *>(in_data[j]);
-      no_broadcast[j] = ins[j]->dims() == out->dims() ? true : false;
-      memcpy(strides[j], offset_calculator.strides[j].data(),
-             kDims * sizeof(uint32_t));
-    }
-    out_data = out->data<OutT>();
-    vec_out_data = reinterpret_cast<OutVecType *>(out_data);
-    memcpy(divmoders, offset_calculator.divmoders.data(),
-           kDims * sizeof(platform::FastDivMod));
-  }
-
-  __device__ __forceinline__ uint32_t GetOffsetByDivmod(int idx, int in_idx) {
-    uint32_t offset = 0;
-
-#pragma unroll(kDims)
-    for (int i = 0; i < kDims; ++i) {
-      auto fast_divmoder = divmoders[i].Divmod(idx);
-      idx = fast_divmoder.val[0];
-      offset += fast_divmoder.val[1] * strides[in_idx][i];
-    }
-    return offset;
-  }
-
-  __device__ __forceinline__ void LoadVectorizedDataCommon(
-      InVecType *vector_args, int tid, int idx) {
-    *vector_args = vec_in_data[idx][tid];
-  }
-
-  __device__ __forceinline__ void LoadVectorizedDataByDivmod(
-      InT *scalar_args, int tid, int idx) {
-    int index = tid * VecSize;
-#pragma unroll(VecSize)
-    for (int i = 0; i < VecSize; ++i) {
-      uint32_t offset = GetOffsetByDivmod(index + i, idx);
-      scalar_args[i] = in_data[idx][offset];
-    }
-  }
-
-  __device__ __forceinline__ void LoadScalarizedDataCommon(InT args[], int tid,
-                                                           int idx) {
-    args[idx] = in_data[idx][tid + scalar_cal_offset];
-  }
-
-  __device__ __forceinline__ void LoadScalarizedDataByDivmod(InT args[],
-                                                             int tid,
-                                                             int idx) {
-    auto offset = GetOffsetByDivmod(tid + scalar_cal_offset, idx);
-    args[idx] = in_data[idx][offset];
-  }
-
-  __device__ __forceinline__ void LoadVectorizedData(InT (*args)[VecSize],
-                                                     int tid) {
-#pragma unroll(ET)
-    for (int j = 0; j < ET; ++j) {
-      if (no_broadcast[j]) {
-        InVecType *vector_args = reinterpret_cast<InVecType *>(args[j]);
-        LoadVectorizedDataCommon(vector_args, tid, j);
-      } else {
-        LoadVectorizedDataByDivmod(args[j], tid, j);
-      }
-    }
-  }
-
-  __device__ __forceinline__ void LoadScalarizedData(InT args[], int tid) {
-#pragma unroll(ET)
-    for (int j = 0; j < ET; ++j) {
-      if (no_broadcast[j]) {
-        LoadScalarizedDataCommon(args, tid, j);
-      } else {
-        LoadScalarizedDataByDivmod(args, tid, j);
-      }
-    }
-  }
-
-  __device__ __forceinline__ void StoreVectorizedData(OutVecType vec_args_out,
-                                                      int tid) {
-    vec_out_data[tid] = vec_args_out;
-  }
-
-  __device__ __forceinline__ void StoreScalarizedData(OutT args_out, int tid) {
-    out_data[scalar_cal_offset + tid] = args_out;
-  }
-};
+template <ElementwiseType ET, typename InT, typename OutT, int ShapeSize,
+          int VecSize, typename Functor, bool IsBoundary = false>
+__device__ void DealSegment(
+    const framework::Array<const InT *__restrict__, ET> &in, OutT *out,
+    const framework::Array<bool, MAX_INPUT_NUM> &use_broadcast, uint32_t numel,
+    const framework::Array<kps::details::BroadcastConfig<ShapeSize>,
+                           MAX_INPUT_NUM> &configlists,
+    int num, Functor func) {
+  InT args[ET][VecSize];
+  OutT result[VecSize];
+  int block_offset = blockIdx.x * blockDim.x * VecSize;
+// load
+#pragma unroll
+  for (int i = 0; i < ET; i++) {
+    kps::Init<InT, VecSize>(args[i], static_cast<InT>(1.0f));
+    LoadData<InT, VecSize, ShapeSize, IsBoundary>(args[i], in[i], block_offset,
+                                                  configlists[i], numel, num,
+                                                  use_broadcast[i]);
+  }
+  // compute
+  if (ET == kUnary) {
+    kps::ElementwiseUnary<InT, OutT, VecSize, 1, 1, Functor>(result, args[0],
+                                                             func);
+  } else if (ET == kBinary) {
+    kps::ElementwiseBinary<InT, OutT, VecSize, 1, 1, Functor>(result, args[0],
+                                                              args[1], func);
+  } else {
+    kps::ElementwiseTernary<InT, OutT, VecSize, 1, 1, Functor>(
+        result, args[0], args[1], args[2], func);
+  }
+  // compute
+  kps::WriteData<OutT, VecSize, 1, 1, IsBoundary>(out + block_offset, result,
+                                                  num);
+}

-template <typename InT, typename OutT, typename BroadcastArgsWrapper,
-          ElementwiseType ET>
-__device__ inline void ScalarizedBroadcastKernelImpl(
-    BroadcastArgsWrapper broadcast_wrapper, int tid) {
-  InT args[ET];
-  OutT args_out;
-  broadcast_wrapper.LoadScalarizedData(args, tid);
-
-  // Calcualtion of the in_tensor data.
-  args_out = broadcast_wrapper.func(args);
-
-  broadcast_wrapper.StoreScalarizedData(args_out, tid);
-}
+template <ElementwiseType ET, typename InT, typename OutT, int ShapeSize,
+          int VecSize, typename Functor>
+__global__ void BroadcastKernel(
+    framework::Array<const InT *__restrict__, ET> in, OutT *out,
+    framework::Array<bool, MAX_INPUT_NUM> use_broadcast, uint32_t numel,
+    framework::Array<kps::details::BroadcastConfig<ShapeSize>, MAX_INPUT_NUM>
+        configlists,
+    int main_tid, int tail_tid, Functor func) {
+  int block_offset = blockIdx.x * blockDim.x * VecSize;
+  // data offset of this block
+  if (blockIdx.x < main_tid) {
+    int num = blockDim.x * VecSize;  // blockIdx.x < main_tid
+    DealSegment<ET, InT, OutT, ShapeSize, VecSize, Functor, false>(
+        in, out, use_broadcast, numel, configlists, num, func);
+  } else {  // reminder
+    int num = tail_tid;
+    DealSegment<ET, InT, OutT, ShapeSize, VecSize, Functor, true>(
+        in, out, use_broadcast, numel, configlists, num, func);
+  }
+}

-template <typename InT, typename OutT, typename BroadcastArgsWrapper,
-          ElementwiseType ET, int VecSize>
-__device__ inline void VectorizedBroadcastKernelImpl(
-    BroadcastArgsWrapper broadcast_wrapper, int tid) {
-  using OutVecType = platform::AlignedVector<OutT, VecSize>;
-  OutVecType args_out;
-  InT ins[ET];
-  InT args[ET][VecSize];
-  broadcast_wrapper.LoadVectorizedData(args, tid);
-
-#pragma unroll(VecSize)
-  for (int i = 0; i < VecSize; ++i) {
-#pragma unroll(ET)
-    for (int j = 0; j < ET; ++j) {
-      ins[j] = args[j][i];
-    }
-    args_out.val[i] = broadcast_wrapper.func(ins);
-  }
-  broadcast_wrapper.StoreVectorizedData(args_out, tid);
-}
-
-template <typename InT, typename OutT, typename BroadcastArgsWrapper,
-          ElementwiseType ET, int VecSize>
-__global__ void ElementwiseBroadcastKernel(
-    BroadcastArgsWrapper broadcast_wrapper, int main_tid, int tail_tid) {
-  int tid = threadIdx.x + blockIdx.x * blockDim.x;
-  // Vectorized calculation of major data whose length is the max multipler of
-  // VecSize,
-  // eg: Calcualting the front 1024-length data in total 1027 data once VecSize
-  // is 4.
-  if (tid < main_tid) {
-    VectorizedBroadcastKernelImpl<InT, OutT, BroadcastArgsWrapper, ET,
-                                  VecSize>(broadcast_wrapper, tid);
-  }
-  // Scalarzed calculation of rest data whose lenght cannot fulfill VecSize.
-  // eg: Calcualting the rest 3-length data in total 1027 data once VecSize is
-  // 4.
-  if (tid < tail_tid) {
-    ScalarizedBroadcastKernelImpl<InT, OutT, BroadcastArgsWrapper, ET>(
-        broadcast_wrapper, tid);
-  }
-}
+template <typename InT, typename OutT, ElementwiseType ET, int VecSize,
+          int Size, typename Functor>
+void LaunchKernel(const platform::CUDADeviceContext &ctx,
+                  const std::vector<const framework::Tensor *> &ins,
+                  framework::Tensor *out, Functor func,
+                  DimensionsTransform merge_dims) {
+  int numel = out->numel();
+  const int threads = 256;
+  int blocks = ((numel + VecSize - 1) / VecSize + threads - 1) / threads;
+
+  int main_tid = numel / (VecSize * threads);
+  int tail_tid = numel % (VecSize * threads);
+  auto stream = ctx.stream();
+
+  OutT *out_data = out->data<OutT>();
+
+  framework::Array<kps::details::BroadcastConfig<Size>, MAX_INPUT_NUM>
+      configlists;
+  framework::Array<bool, MAX_INPUT_NUM> use_broadcast;
+  framework::Array<const InT *__restrict__, ET> ins_data;
+
+  for (int i = 0; i < ET; i++) {
+    use_broadcast[i] = (ins[i]->numel() != numel);
+    ins_data[i] = ins[i]->data<InT>();
+    if (use_broadcast[i]) {
+      // get the broadcast config,
+      // if data shape is[m, n], then you should set data_dim = {n, m}
+      // eg: out's shape [3, 45, 1]. then out_dims = {1, 45, 3}
+      configlists[i] = kps::details::BroadcastConfig<Size>(
+          merge_dims.out_dims, merge_dims.in_dims[i], merge_dims.dim_size);
+    }
+  }
+
+  BroadcastKernel<ET, InT, OutT, Size, VecSize,
+                  Functor><<<blocks, threads, 0, stream>>>(
+      ins_data, out_data, use_broadcast, numel, configlists, main_tid,
+      tail_tid, func);
+}

 template <typename InT, typename OutT, ElementwiseType ET, int VecSize,
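A note on the launch decomposition in LaunchKernel above: main_tid counts the blocks that process a full blockDim.x * VecSize elements with IsBoundary = false, while tail_tid is the number of leftover elements that the last block handles with IsBoundary = true. Reusing the 1027-element example from the removed kernel's comments, a host-only check of the arithmetic:

    #include <cstdio>

    int main() {
      const int numel = 1027, VecSize = 4, threads = 256;
      int blocks = ((numel + VecSize - 1) / VecSize + threads - 1) / threads;
      int main_tid = numel / (VecSize * threads);  // full blocks
      int tail_tid = numel % (VecSize * threads);  // leftover elements
      // prints: blocks=2 main_tid=1 tail_tid=3
      printf("blocks=%d main_tid=%d tail_tid=%d\n", blocks, main_tid, tail_tid);
      return 0;
    }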
...
@@ -365,98 +280,24 @@ void LaunchBroadcastKernelForDifferentDimSize(
     const platform::CUDADeviceContext &ctx,
     const std::vector<const framework::Tensor *> &ins, framework::Tensor *out,
     int axis, Functor func) {
-  int numel = out->numel();
-  int threads = GetThreadsConfig(ctx, numel, VecSize);
-  int blocks = ((numel + VecSize - 1) / VecSize + threads - 1) / threads;
-  int main_tid = numel / VecSize;
-  int tail_tid = numel % VecSize;
-  int vec_len = main_tid * VecSize;
-
-  auto stream = ctx.stream();
   const auto merge_dims = DimensionsTransform(ins, out->dims(), axis);
-  const auto offset_calculator = StridesCalculation(
-      merge_dims.dim_size, merge_dims.in_dims, merge_dims.out_dims);
+
+#define DIM_SIZE(size)                                                       \
+  case size: {                                                               \
+    LaunchKernel<InT, OutT, ET, VecSize, size, Functor>(ctx, ins, out, func, \
+                                                        merge_dims);         \
+  } break;

   switch (merge_dims.dim_size) {
-    case 1: {
-      auto broadcast_wrapper =
-          BroadcastArgsWrapper<InT, OutT, Functor, ET, VecSize, 1>(
-              ins, out, vec_len, func, offset_calculator);
-      ElementwiseBroadcastKernel<InT, OutT, decltype(broadcast_wrapper), ET,
-                                 VecSize><<<blocks, threads, 0, stream>>>(
-          broadcast_wrapper, main_tid, tail_tid);
-      break;
-    }
-    case 2: {
-      auto broadcast_wrapper =
-          BroadcastArgsWrapper<InT, OutT, Functor, ET, VecSize, 2>(
-              ins, out, vec_len, func, offset_calculator);
-      ElementwiseBroadcastKernel<InT, OutT, decltype(broadcast_wrapper), ET,
-                                 VecSize><<<blocks, threads, 0, stream>>>(
-          broadcast_wrapper, main_tid, tail_tid);
-      break;
-    }
-    case 3: {
-      auto broadcast_wrapper =
-          BroadcastArgsWrapper<InT, OutT, Functor, ET, VecSize, 3>(
-              ins, out, vec_len, func, offset_calculator);
-      ElementwiseBroadcastKernel<InT, OutT, decltype(broadcast_wrapper), ET,
-                                 VecSize><<<blocks, threads, 0, stream>>>(
-          broadcast_wrapper, main_tid, tail_tid);
-      break;
-    }
-    case 4: {
-      auto broadcast_wrapper =
-          BroadcastArgsWrapper<InT, OutT, Functor, ET, VecSize, 4>(
-              ins, out, vec_len, func, offset_calculator);
-      ElementwiseBroadcastKernel<InT, OutT, decltype(broadcast_wrapper), ET,
-                                 VecSize><<<blocks, threads, 0, stream>>>(
-          broadcast_wrapper, main_tid, tail_tid);
-      break;
-    }
-    case 5: {
-      auto broadcast_wrapper =
-          BroadcastArgsWrapper<InT, OutT, Functor, ET, VecSize, 5>(
-              ins, out, vec_len, func, offset_calculator);
-      ElementwiseBroadcastKernel<InT, OutT, decltype(broadcast_wrapper), ET,
-                                 VecSize><<<blocks, threads, 0, stream>>>(
-          broadcast_wrapper, main_tid, tail_tid);
-      break;
-    }
-    case 6: {
-      auto broadcast_wrapper =
-          BroadcastArgsWrapper<InT, OutT, Functor, ET, VecSize, 6>(
-              ins, out, vec_len, func, offset_calculator);
-      ElementwiseBroadcastKernel<InT, OutT, decltype(broadcast_wrapper), ET,
-                                 VecSize><<<blocks, threads, 0, stream>>>(
-          broadcast_wrapper, main_tid, tail_tid);
-      break;
-    }
-    case 7: {
-      auto broadcast_wrapper =
-          BroadcastArgsWrapper<InT, OutT, Functor, ET, VecSize, 7>(
-              ins, out, vec_len, func, offset_calculator);
-      ElementwiseBroadcastKernel<InT, OutT, decltype(broadcast_wrapper), ET,
-                                 VecSize><<<blocks, threads, 0, stream>>>(
-          broadcast_wrapper, main_tid, tail_tid);
-      break;
-    }
-    case 8: {
-      auto broadcast_wrapper =
-          BroadcastArgsWrapper<InT, OutT, Functor, ET, VecSize, 8>(
-              ins, out, vec_len, func, offset_calculator);
-      ElementwiseBroadcastKernel<InT, OutT, decltype(broadcast_wrapper), ET,
-                                 VecSize><<<blocks, threads, 0, stream>>>(
-          broadcast_wrapper, main_tid, tail_tid);
-      break;
-    }
+    DIM_SIZE(1);
+    DIM_SIZE(2);
+    DIM_SIZE(3);
+    DIM_SIZE(4);
+    DIM_SIZE(5);
+    DIM_SIZE(6);
+    DIM_SIZE(7);
+    DIM_SIZE(8);
     default: {
       PADDLE_THROW(platform::errors::InvalidArgument(
           "The maximum dimension of input tensor is expected to be less than "
           "%d, but recieved %d.\n",
           merge_dims.dim_size, framework::DDim::kMaxRank));
     }
   }
+#undef DIM_SIZE
 }
 template <ElementwiseType ET, typename InT, typename OutT, typename Functor>
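The DIM_SIZE macro above exists because BroadcastConfig<Size> needs the rank as a compile-time template argument while merge_dims.dim_size is only known at run time, so the switch enumerates every supported rank. A minimal standalone version of the same dispatch trick (Launch and Dispatch are illustrative names, not Paddle APIs):

    #include <cstdio>

    template <int Rank>
    void Launch(int numel) {
      printf("launching rank-%d kernel over %d elements\n", Rank, numel);
    }

    // Each case hands the run-time rank to a distinct template instantiation.
    #define DIM_SIZE(size)   \
      case size: {           \
        Launch<size>(numel); \
      } break;

    void Dispatch(int rank, int numel) {
      switch (rank) {
        DIM_SIZE(1);
        DIM_SIZE(2);
        DIM_SIZE(3);
        default:
          printf("unsupported rank %d\n", rank);
      }
    }
    #undef DIM_SIZE

    int main() {
      Dispatch(2, 1027);  // selects Launch<2>
      return 0;
    }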
...
@@ -528,5 +369,7 @@ void LaunchElementwiseCudaKernel(
   }
 }

+#undef MAX_INPUT_NUM
+
 }  // namespace operators
 }  // namespace paddle
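Both the removed GetOffsetByDivmod and the kps::details::BroadcastConfig machinery that replaces it map a linear output index to an input offset the same way: peel off one output dimension at a time with div/mod and multiply each remainder by an input stride that is zero along broadcast axes. Below is a standalone sketch with plain integer division standing in for platform::FastDivMod; BroadcastOffset is an illustrative name, not a Paddle API.

    #include <cstdio>
    #include <vector>

    // Map a linear output index to the linear index of a (possibly broadcast)
    // input. Dims are fastest-varying first, matching the reversed layout the
    // BroadcastConfig comment describes; in_strides[i] == 0 for broadcast axes.
    int BroadcastOffset(int idx, const std::vector<int> &out_dims,
                        const std::vector<int> &in_strides) {
      int offset = 0;
      for (size_t i = 0; i < out_dims.size(); ++i) {
        offset += (idx % out_dims[i]) * in_strides[i];
        idx /= out_dims[i];
      }
      return offset;
    }

    int main() {
      // out shape [3, 45, 1] stored reversed as {1, 45, 3}; an input of shape
      // [1, 45, 1] broadcasts on the outer axes, so only the middle axis
      // contributes a stride.
      std::vector<int> out_dims = {1, 45, 3};
      std::vector<int> in_strides = {0, 1, 0};
      printf("%d\n", BroadcastOffset(50, out_dims, in_strides));  // prints 5
      return 0;
    }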
paddle/fluid/operators/elementwise/elementwise_op_impl.cu.h

...
@@ -14,6 +14,7 @@ limitations under the License. */
 #pragma once

 #include "paddle/fluid/framework/tensor.h"
+#include "paddle/fluid/operators/kernel_primitives/kernel_primitives.h"
 #include "paddle/fluid/platform/cuda_device_function.h"
 #include "paddle/fluid/platform/fast_divmod.h"
...
@@ -26,6 +27,7 @@ limitations under the License. */
 namespace paddle {
 namespace operators {

+namespace kps = paddle::operators::kernel_primitives;
 enum ElementwiseType { kUnary = 1, kBinary = 2, kTernary = 3 };

 /*
...
@@ -67,121 +69,74 @@ int GetVectorizedSizeForIO(const std::vector<const framework::Tensor *> &ins,
   return vec_size;
 }

-template <ElementwiseType ET, int VecSize, typename InT, typename OutT>
-struct ElementwiseDataWrapper {
-  using InVecType = platform::AlignedVector<InT, VecSize>;
-  using OutVecType = platform::AlignedVector<OutT, VecSize>;
-
-  const InT *__restrict__ in_data[ET];
-  OutT *out_data;
-  uint32_t scalar_cal_offset;
-
-  HOSTDEVICE ElementwiseDataWrapper(
-      const std::vector<const framework::Tensor *> &ins,
-      std::vector<framework::Tensor *> *outs, uint32_t scalar_cal_offset)
-      : scalar_cal_offset(scalar_cal_offset) {
-#pragma unroll
-    for (int i = 0; i < ET; ++i) {
-      in_data[i] = ins[i]->data<InT>();
-    }
-    out_data = (*outs)[0]->data<OutT>();
-  }
-
-  inline __device__ void LoadVectorizedData(InVecType vec_args[], int tid) {
-#pragma unroll
-    for (int i = 0; i < ET; ++i) {
-      const InVecType *in_vec_data =
-          reinterpret_cast<const InVecType *>(in_data[i]);
-      vec_args[i] = in_vec_data[tid];
-    }
-  }
-
-  inline __device__ void LoadScalarizedData(InT args[], int tid) {
-#pragma unroll
-    for (int i = 0; i < ET; ++i) {
-      args[i] = in_data[i][tid + scalar_cal_offset];
-    }
-  }
-
-  inline __device__ void StoreVectorizedData(OutVecType res, int tid) {
-    OutVecType *out_vec = reinterpret_cast<OutVecType *>(out_data);
-    out_vec[tid] = res;
-  }
-
-  inline __device__ void StoreScalarizedData(OutT res, int tid) {
-    out_data[tid + scalar_cal_offset] = res;
-  }
-};
+template <ElementwiseType ET, int VecSize, typename InT, typename OutT,
+          typename Functor, bool IsBoundary>
+__device__ void DealSegment(
+    const framework::Array<const InT *__restrict__, ET> &in, OutT *out,
+    int num, Functor func) {
+  int data_offset = VecSize * blockIdx.x * blockDim.x;
+  InT args[ET][VecSize];
+  OutT result[VecSize];
+// load data
+#pragma unroll
+  for (int i = 0; i < ET; i++) {
+    kps::Init<InT, VecSize>(args[i], static_cast<InT>(1.0f));
+    kps::ReadData<InT, VecSize, 1, 1, IsBoundary>(args[i], in[i] + data_offset,
+                                                  num);
+  }

-template <ElementwiseType ET, int VecSize, typename ElementwiseWrapper,
-          typename InT, typename OutT, typename Functor>
-__device__ inline void VectorizedKernelImpl(ElementwiseWrapper data,
-                                            Functor func, int tid) {
-  using InVecType = platform::AlignedVector<InT, VecSize>;
-  using OutVecType = platform::AlignedVector<OutT, VecSize>;
-  InVecType ins_vec[ET];
-  OutVecType out_vec;
-  InT *ins_ptr[ET];
-  InT ins[ET];
-#pragma unroll
-  for (int i = 0; i < ET; ++i) {
-    ins_ptr[i] = reinterpret_cast<InT *>(&(ins_vec[i]));
-  }
-  // load
-  data.LoadVectorizedData(ins_vec, tid);
-
-// compute
-#pragma unroll
-  for (int i = 0; i < VecSize; ++i) {
-#pragma unroll
-    for (int j = 0; j < ET; ++j) {
-      ins[j] = ins_ptr[j][i];
-    }
-    out_vec.val[i] = func(ins);
-  }
-  // store
-  data.StoreVectorizedData(out_vec, tid);
-}
-
-template <ElementwiseType ET, typename ElementwiseWrapper, typename InT,
-          typename OutT, typename Functor>
-__device__ inline void ScalarKernelImpl(ElementwiseWrapper data, Functor func,
-                                        int tid) {
-  InT ins[ET];
-  OutT out;
-
-  // load
-  data.LoadScalarizedData(ins, tid);
-  // compute
-  out = func(ins);
-  // store
-  data.StoreScalarizedData(out, tid);
-}
+  // compute
+  if (ET == kUnary) {
+    kps::ElementwiseUnary<InT, OutT, VecSize, 1, 1, Functor>(result, args[0],
+                                                             func);
+  } else if (ET == kBinary) {
+    kps::ElementwiseBinary<InT, OutT, VecSize, 1, 1, Functor>(result, args[0],
+                                                              args[1], func);
+  } else {
+    kps::ElementwiseTernary<InT, OutT, VecSize, 1, 1, Functor>(
+        result, args[0], args[1], args[2], func);
+  }
+
+  // store
+  kps::WriteData<OutT, VecSize, 1, 1, IsBoundary>(out + data_offset, result,
+                                                  num);
+}

-template <ElementwiseType ET, typename ElementwiseWrapper, typename InT,
-          typename OutT, int VecSize, typename Functor>
-__global__ void VectorizedKernel(ElementwiseWrapper data, int main_tid,
-                                 int tail_tid, Functor func) {
-  int tid = blockIdx.x * blockDim.x + threadIdx.x;
-
-  if (tid < main_tid) {
-    VectorizedKernelImpl<ET, VecSize, ElementwiseWrapper, InT, OutT, Functor>(
-        data, func, tid);
-  }
-  if (tid < tail_tid) {
-    ScalarKernelImpl<ET, ElementwiseWrapper, InT, OutT, Functor>(data, func,
-                                                                 tid);
-  }
-}
+template <ElementwiseType ET, int VecSize, typename InT, typename OutT,
+          typename Functor>
+__global__ void ElementVectorizeKernel(
+    framework::Array<const InT *__restrict__, ET> in, OutT *out, int size,
+    Functor func) {
+  int data_offset = VecSize * blockIdx.x * blockDim.x;
+  int num = size - data_offset;  // the num this time have to deal with
+  if (VecSize * blockDim.x > num) {  // reminder segment
+    DealSegment<ET, VecSize, InT, OutT, Functor, true>(in, out, num, func);
+  } else {  // complete segment
+    DealSegment<ET, VecSize, InT, OutT, Functor, false>(in, out, num, func);
+  }
+}

-template <ElementwiseType ET, typename ElementwiseWrapper, typename InT,
-          typename OutT, typename Functor>
-__global__ void ScalarKernel(ElementwiseWrapper data, int numel,
-                             Functor func) {
-  int tid = blockIdx.x * blockDim.x + threadIdx.x;
-  if (tid < numel) {
-    ScalarKernelImpl<ET, ElementwiseWrapper, InT, OutT, Functor>(data, func,
-                                                                 tid);
-  }
-}
+template <ElementwiseType ET, typename InT, typename OutT, typename Functor,
+          int VecSize>
+void ElementwiseCudaKernel(const platform::CUDADeviceContext &ctx,
+                           const std::vector<const framework::Tensor *> &ins,
+                           std::vector<framework::Tensor *> *outs,
+                           Functor func) {
+  auto numel = ins[0]->numel();
+  int block_size = GetThreadsConfig(ctx, numel, VecSize);
+  int grid_size =
+      ((numel + VecSize - 1) / VecSize + block_size - 1) / block_size;
+
+  auto stream = ctx.stream();
+  OutT *out = (*outs)[0]->data<OutT>();
+  framework::Array<const InT *__restrict__, ET> in;
+  for (int i = 0; i < ET; i++) {
+    in[i] = ins[i]->data<InT>();
+  }
+
+  ElementVectorizeKernel<ET, VecSize, InT, OutT,
+                         Functor><<<grid_size, block_size, 0, stream>>>(
+      in, out, numel, func);
+}
 template <ElementwiseType ET, typename InT, typename OutT, typename Functor>
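Since ET is a compile-time template parameter, the kUnary/kBinary/kTernary chain inside DealSegment above is resolved during instantiation rather than per element. A standalone illustration of that idea, written with C++17 if constexpr to make the pruning explicit; Apply and AddFunctor are hypothetical, not Paddle APIs.

    #include <cstdio>

    enum ElementwiseType { kUnary = 1, kBinary = 2, kTernary = 3 };

    struct AddFunctor {
      float operator()(float a, float b) const { return a + b; }
    };

    template <ElementwiseType ET, typename Functor>
    float Apply(const float *args, Functor func) {
      if constexpr (ET == kBinary) {
        return func(args[0], args[1]);  // only this branch is instantiated
      } else {
        return 0.0f;  // placeholder for the unary/ternary branches
      }
    }

    int main() {
      float args[2] = {1.5f, 2.5f};
      printf("%f\n", Apply<kBinary>(args, AddFunctor{}));  // prints 4.000000
      return 0;
    }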
...
@@ -190,43 +145,17 @@ void LaunchSameDimsElementwiseCudaKernel(
     const std::vector<const framework::Tensor *> &ins,
     std::vector<framework::Tensor *> *outs, Functor func) {
   // calculate the max vec_size for all ins and outs
-  auto numel = ins[0]->numel();
   int vec_size = GetVectorizedSizeForIO<InT, OutT>(ins, *outs);
-  int block_size = GetThreadsConfig(ctx, numel, vec_size);
-  int grid_size =
-      ((numel + vec_size - 1) / vec_size + block_size - 1) / block_size;
-  int main_tid = numel / vec_size;
-  int tail_tid = numel % vec_size;
-  uint32_t vec_len = main_tid * vec_size;
-
-  // cuda kernel
-  auto stream = ctx.stream();
   switch (vec_size) {
-    case 4: {
-      auto data_wrapper =
-          ElementwiseDataWrapper<ET, 4, InT, OutT>(ins, outs, vec_len);
-      VectorizedKernel<ET, decltype(data_wrapper), InT, OutT,
-                       4><<<grid_size, block_size, 0, stream>>>(
-          data_wrapper, main_tid, tail_tid, func);
+    case 4:
+      ElementwiseCudaKernel<ET, InT, OutT, Functor, 4>(ctx, ins, outs, func);
       break;
-    }
-    case 2: {
-      auto data_wrapper =
-          ElementwiseDataWrapper<ET, 2, InT, OutT>(ins, outs, vec_len);
-      VectorizedKernel<ET, decltype(data_wrapper), InT, OutT,
-                       2><<<grid_size, block_size, 0, stream>>>(
-          data_wrapper, main_tid, tail_tid, func);
+    case 2:
+      ElementwiseCudaKernel<ET, InT, OutT, Functor, 2>(ctx, ins, outs, func);
       break;
-    }
-    case 1: {
-      auto data_wrapper =
-          ElementwiseDataWrapper<ET, 1, InT, OutT>(ins, outs, 0);
-      ScalarKernel<ET, decltype(data_wrapper), InT,
-                   OutT><<<grid_size, block_size, 0, stream>>>(
-          data_wrapper, numel, func);
+    case 1:
+      ElementwiseCudaKernel<ET, InT, OutT, Functor, 1>(ctx, ins, outs, func);
       break;
-    }
     default: {
       PADDLE_THROW(platform::errors::Unimplemented(
           "Unsupported vectorized size: %d !", vec_size));

...
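The vec_size switch above dispatches on whatever width GetVectorizedSizeForIO reports, which is bounded by the byte alignment of every input and output pointer. A simplified stand-in for that kind of alignment rule (not Paddle's exact policy): pick the widest vector width, 4 then 2 then 1, whose alignment the pointer satisfies.

    #include <cstdint>
    #include <cstdio>

    template <typename T>
    int VectorizedSize(const T *ptr) {
      uint64_t address = reinterpret_cast<uint64_t>(ptr);
      if (address % (sizeof(T) * 4) == 0) return 4;  // e.g. float4-aligned
      if (address % (sizeof(T) * 2) == 0) return 2;  // e.g. float2-aligned
      return 1;                                      // scalar fallback
    }

    int main() {
      alignas(16) float buf[8];
      printf("%d\n", VectorizedSize(buf));      // 4: 16-byte aligned
      printf("%d\n", VectorizedSize(buf + 1));  // 1: only 4-byte aligned
      printf("%d\n", VectorizedSize(buf + 2));  // 2: 8-byte aligned
      return 0;
    }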