Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
d6aea4ac
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
d6aea4ac
编写于
5月 24, 2021
作者:
L
limingshu
提交者:
GitHub
5月 24, 2021
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Support OutType tmeplate argument in elementwise_broadcast branch (#33060)
上级
a6dc68b7
变更
1
显示空白变更内容
内联
并排
Showing
1 changed file
with
89 addition
and
75 deletion
+89
-75
paddle/fluid/operators/elementwise/elementwise_op_broadcast.cu.h
...fluid/operators/elementwise/elementwise_op_broadcast.cu.h
+89
-75
未找到文件。
paddle/fluid/operators/elementwise/elementwise_op_broadcast.cu.h
浏览文件 @
d6aea4ac
...
...
@@ -196,15 +196,16 @@ struct StridesCalculation {
}
};
template
<
typename
T
,
typename
Functor
,
ElementwiseType
ET
,
int
VecSize
,
int
kDims
>
template
<
typename
InT
,
typename
OutT
,
typename
Functor
,
ElementwiseType
ET
,
int
VecSize
,
int
kDims
>
struct
BroadcastArgsWarpper
{
using
VecType
=
CudaAlignedVector
<
T
,
VecSize
>
;
using
InVecType
=
CudaAlignedVector
<
InT
,
VecSize
>
;
using
OutVecType
=
CudaAlignedVector
<
OutT
,
VecSize
>
;
T
*
out_data
;
VecType
*
vec_out_data
;
const
T
*
__restrict__
in_data
[
ET
];
const
VecType
*
__restrict__
vec_in_data
[
ET
];
Out
T
*
out_data
;
Out
VecType
*
vec_out_data
;
const
In
T
*
__restrict__
in_data
[
ET
];
const
In
VecType
*
__restrict__
vec_in_data
[
ET
];
bool
no_broadcast
[
ET
];
FastDivMod
divmoders
[
kDims
];
uint32_t
strides
[
ET
][
framework
::
DDim
::
kMaxRank
];
...
...
@@ -217,14 +218,14 @@ struct BroadcastArgsWarpper {
const
StridesCalculation
&
offset_calculator
)
:
scalar_cal_offset
(
scalar_cal_offset
),
func
(
func
)
{
for
(
int
j
=
0
;
j
<
ET
;
++
j
)
{
in_data
[
j
]
=
ins
[
j
]
->
data
<
T
>
();
vec_in_data
[
j
]
=
reinterpret_cast
<
const
VecType
*>
(
in_data
[
j
]);
in_data
[
j
]
=
ins
[
j
]
->
data
<
In
T
>
();
vec_in_data
[
j
]
=
reinterpret_cast
<
const
In
VecType
*>
(
in_data
[
j
]);
no_broadcast
[
j
]
=
ins
[
j
]
->
dims
()
==
out
->
dims
()
?
true
:
false
;
memcpy
(
strides
[
j
],
offset_calculator
.
strides
[
j
].
data
(),
kDims
*
sizeof
(
uint32_t
));
}
out_data
=
out
->
data
<
T
>
();
vec_out_data
=
reinterpret_cast
<
VecType
*>
(
out_data
);
out_data
=
out
->
data
<
Out
T
>
();
vec_out_data
=
reinterpret_cast
<
Out
VecType
*>
(
out_data
);
memcpy
(
divmoders
,
offset_calculator
.
divmoders
.
data
(),
kDims
*
sizeof
(
FastDivMod
));
}
...
...
@@ -241,12 +242,12 @@ struct BroadcastArgsWarpper {
return
offset
;
}
__device__
__forceinline__
void
LoadVectorizedDataCommon
(
VecType
*
vector_args
,
int
tid
,
int
idx
)
{
__device__
__forceinline__
void
LoadVectorizedDataCommon
(
InVecType
*
vector_args
,
int
tid
,
int
idx
)
{
*
vector_args
=
vec_in_data
[
idx
][
tid
];
}
__device__
__forceinline__
void
LoadVectorizedDataByDivmod
(
T
*
scalar_args
,
__device__
__forceinline__
void
LoadVectorizedDataByDivmod
(
In
T
*
scalar_args
,
int
tid
,
int
idx
)
{
int
index
=
tid
*
VecSize
;
#pragma unroll(VecSize)
...
...
@@ -256,23 +257,23 @@ struct BroadcastArgsWarpper {
}
}
__device__
__forceinline__
void
LoadScalarizedDataCommon
(
T
args
[],
int
tid
,
__device__
__forceinline__
void
LoadScalarizedDataCommon
(
In
T
args
[],
int
tid
,
int
idx
)
{
args
[
idx
]
=
in_data
[
idx
][
tid
+
scalar_cal_offset
];
}
__device__
__forceinline__
void
LoadScalarizedDataByDivmod
(
T
args
[],
int
tid
,
int
idx
)
{
__device__
__forceinline__
void
LoadScalarizedDataByDivmod
(
InT
args
[]
,
int
tid
,
int
idx
)
{
auto
offset
=
GetOffsetByDivmod
(
tid
+
scalar_cal_offset
,
idx
);
args
[
idx
]
=
in_data
[
idx
][
offset
];
}
__device__
__forceinline__
void
LoadVectorizedData
(
T
(
*
args
)[
VecSize
],
__device__
__forceinline__
void
LoadVectorizedData
(
In
T
(
*
args
)[
VecSize
],
int
tid
)
{
#pragma unroll(ET)
for
(
int
j
=
0
;
j
<
ET
;
++
j
)
{
if
(
no_broadcast
[
j
])
{
VecType
*
vector_args
=
reinterpret_cast
<
VecType
*>
(
args
[
j
]);
InVecType
*
vector_args
=
reinterpret_cast
<
In
VecType
*>
(
args
[
j
]);
LoadVectorizedDataCommon
(
vector_args
,
tid
,
j
);
}
else
{
LoadVectorizedDataByDivmod
(
args
[
j
],
tid
,
j
);
...
...
@@ -280,7 +281,7 @@ struct BroadcastArgsWarpper {
}
}
__device__
__forceinline__
void
LoadScalarizedData
(
T
args
[],
int
tid
)
{
__device__
__forceinline__
void
LoadScalarizedData
(
In
T
args
[],
int
tid
)
{
#pragma unroll(ET)
for
(
int
j
=
0
;
j
<
ET
;
++
j
)
{
if
(
no_broadcast
[
j
])
{
...
...
@@ -291,36 +292,39 @@ struct BroadcastArgsWarpper {
}
}
__device__
__forceinline__
void
StoreVectorizedData
(
T
(
*
args
)[
VecSize
]
,
__device__
__forceinline__
void
StoreVectorizedData
(
OutVecType
vec_args_out
,
int
tid
)
{
VecType
*
args_out
=
reinterpret_cast
<
VecType
*>
(
args
[
0
]);
vec_out_data
[
tid
]
=
*
args_out
;
vec_out_data
[
tid
]
=
vec_args_out
;
}
__device__
__forceinline__
void
StoreScalarizedData
(
T
args
[]
,
int
tid
)
{
out_data
[
scalar_cal_offset
+
tid
]
=
args
[
0
]
;
__device__
__forceinline__
void
StoreScalarizedData
(
OutT
args_out
,
int
tid
)
{
out_data
[
scalar_cal_offset
+
tid
]
=
args
_out
;
}
};
template
<
typename
T
,
typename
BroadcastArgsWarpper
,
ElementwiseType
ET
>
template
<
typename
InT
,
typename
OutT
,
typename
BroadcastArgsWarpper
,
ElementwiseType
ET
>
__device__
inline
void
ScalarizedBroadcastKernelImpl
(
BroadcastArgsWarpper
broadcast_warpper
,
int
tid
)
{
T
args
[
ET
];
InT
args
[
ET
];
OutT
args_out
;
broadcast_warpper
.
LoadScalarizedData
(
args
,
tid
);
#pragma unroll(ET)
for
(
int
j
=
1
;
j
<
ET
;
++
j
)
{
args
[
0
]
=
broadcast_warpper
.
func
(
args
);
args
_out
=
broadcast_warpper
.
func
(
args
);
}
broadcast_warpper
.
StoreScalarizedData
(
args
,
tid
);
broadcast_warpper
.
StoreScalarizedData
(
args
_out
,
tid
);
}
template
<
typename
T
,
typename
BroadcastArgsWarpper
,
ElementwiseType
ET
,
int
VecSize
>
template
<
typename
InT
,
typename
OutT
,
typename
BroadcastArgsWarpper
,
ElementwiseType
ET
,
int
VecSize
>
__device__
inline
void
VectorizedBroadcastKernelImpl
(
BroadcastArgsWarpper
broadcast_warpper
,
int
tid
)
{
T
ins
[
ET
];
T
args
[
ET
][
VecSize
];
using
OutVecType
=
CudaAlignedVector
<
OutT
,
VecSize
>
;
OutVecType
args_out
;
InT
ins
[
ET
];
InT
args
[
ET
][
VecSize
];
broadcast_warpper
.
LoadVectorizedData
(
args
,
tid
);
#pragma unroll(VecSize)
...
...
@@ -329,13 +333,13 @@ __device__ inline void VectorizedBroadcastKernelImpl(
for
(
int
j
=
0
;
j
<
ET
;
++
j
)
{
ins
[
j
]
=
args
[
j
][
i
];
}
args
[
0
]
[
i
]
=
broadcast_warpper
.
func
(
ins
);
args
_out
.
val
[
i
]
=
broadcast_warpper
.
func
(
ins
);
}
broadcast_warpper
.
StoreVectorizedData
(
args
,
tid
);
broadcast_warpper
.
StoreVectorizedData
(
args
_out
,
tid
);
}
template
<
typename
T
,
typename
BroadcastArgsWarpper
,
ElementwiseType
ET
,
int
VecSize
>
template
<
typename
InT
,
typename
OutT
,
typename
BroadcastArgsWarpper
,
ElementwiseType
ET
,
int
VecSize
>
__global__
void
ElementwiseBroadcastKernel
(
BroadcastArgsWarpper
broadcast_warpper
,
int
main_tid
,
int
tail_tid
)
{
int
tid
=
threadIdx
.
x
+
blockIdx
.
x
*
blockDim
.
x
;
...
...
@@ -345,19 +349,20 @@ __global__ void ElementwiseBroadcastKernel(
// eg: Calcualting the front 1024-length data in total 1027 data once VecSize
// is 4.
if
(
tid
<
main_tid
)
{
VectorizedBroadcastKernelImpl
<
T
,
BroadcastArgsWarpper
,
ET
,
VecSize
>
(
VectorizedBroadcastKernelImpl
<
InT
,
Out
T
,
BroadcastArgsWarpper
,
ET
,
VecSize
>
(
broadcast_warpper
,
tid
);
}
// Scalarzed calculation of rest data whose lenght cannot fulfill VecSize.
// eg: Calcualting the rest 3-length data in total 1027 data once VecSize is
// 4.
if
(
tid
<
tail_tid
)
{
ScalarizedBroadcastKernelImpl
<
T
,
BroadcastArgsWarpper
,
ET
>
(
ScalarizedBroadcastKernelImpl
<
InT
,
Out
T
,
BroadcastArgsWarpper
,
ET
>
(
broadcast_warpper
,
tid
);
}
}
template
<
typename
T
,
ElementwiseType
ET
,
int
VecSize
,
typename
Functor
>
template
<
typename
InT
,
typename
OutT
,
ElementwiseType
ET
,
int
VecSize
,
typename
Functor
>
void
LaunchBroadcastKernelForDifferentDimSize
(
const
platform
::
CUDADeviceContext
&
ctx
,
const
std
::
vector
<
const
framework
::
Tensor
*>
&
ins
,
framework
::
Tensor
*
out
,
...
...
@@ -376,65 +381,73 @@ void LaunchBroadcastKernelForDifferentDimSize(
switch
(
merge_dims
.
dim_size
)
{
case
1
:
{
auto
broadcast_warpper
=
BroadcastArgsWarpper
<
T
,
Functor
,
ET
,
VecSize
,
1
>
(
auto
broadcast_warpper
=
BroadcastArgsWarpper
<
InT
,
OutT
,
Functor
,
ET
,
VecSize
,
1
>
(
ins
,
out
,
vec_len
,
func
,
offset_calculator
);
ElementwiseBroadcastKernel
<
T
,
decltype
(
broadcast_warpper
),
ET
,
ElementwiseBroadcastKernel
<
InT
,
Out
T
,
decltype
(
broadcast_warpper
),
ET
,
VecSize
><<<
blocks
,
threads
,
0
,
stream
>>>
(
broadcast_warpper
,
main_tid
,
tail_tid
);
break
;
}
case
2
:
{
auto
broadcast_warpper
=
BroadcastArgsWarpper
<
T
,
Functor
,
ET
,
VecSize
,
2
>
(
auto
broadcast_warpper
=
BroadcastArgsWarpper
<
InT
,
OutT
,
Functor
,
ET
,
VecSize
,
2
>
(
ins
,
out
,
vec_len
,
func
,
offset_calculator
);
ElementwiseBroadcastKernel
<
T
,
decltype
(
broadcast_warpper
),
ET
,
ElementwiseBroadcastKernel
<
InT
,
Out
T
,
decltype
(
broadcast_warpper
),
ET
,
VecSize
><<<
blocks
,
threads
,
0
,
stream
>>>
(
broadcast_warpper
,
main_tid
,
tail_tid
);
break
;
}
case
3
:
{
auto
broadcast_warpper
=
BroadcastArgsWarpper
<
T
,
Functor
,
ET
,
VecSize
,
3
>
(
auto
broadcast_warpper
=
BroadcastArgsWarpper
<
InT
,
OutT
,
Functor
,
ET
,
VecSize
,
3
>
(
ins
,
out
,
vec_len
,
func
,
offset_calculator
);
ElementwiseBroadcastKernel
<
T
,
decltype
(
broadcast_warpper
),
ET
,
ElementwiseBroadcastKernel
<
InT
,
Out
T
,
decltype
(
broadcast_warpper
),
ET
,
VecSize
><<<
blocks
,
threads
,
0
,
stream
>>>
(
broadcast_warpper
,
main_tid
,
tail_tid
);
break
;
}
case
4
:
{
auto
broadcast_warpper
=
BroadcastArgsWarpper
<
T
,
Functor
,
ET
,
VecSize
,
4
>
(
auto
broadcast_warpper
=
BroadcastArgsWarpper
<
InT
,
OutT
,
Functor
,
ET
,
VecSize
,
4
>
(
ins
,
out
,
vec_len
,
func
,
offset_calculator
);
ElementwiseBroadcastKernel
<
T
,
decltype
(
broadcast_warpper
),
ET
,
ElementwiseBroadcastKernel
<
InT
,
Out
T
,
decltype
(
broadcast_warpper
),
ET
,
VecSize
><<<
blocks
,
threads
,
0
,
stream
>>>
(
broadcast_warpper
,
main_tid
,
tail_tid
);
break
;
}
case
5
:
{
auto
broadcast_warpper
=
BroadcastArgsWarpper
<
T
,
Functor
,
ET
,
VecSize
,
5
>
(
auto
broadcast_warpper
=
BroadcastArgsWarpper
<
InT
,
OutT
,
Functor
,
ET
,
VecSize
,
5
>
(
ins
,
out
,
vec_len
,
func
,
offset_calculator
);
ElementwiseBroadcastKernel
<
T
,
decltype
(
broadcast_warpper
),
ET
,
ElementwiseBroadcastKernel
<
InT
,
Out
T
,
decltype
(
broadcast_warpper
),
ET
,
VecSize
><<<
blocks
,
threads
,
0
,
stream
>>>
(
broadcast_warpper
,
main_tid
,
tail_tid
);
break
;
}
case
6
:
{
auto
broadcast_warpper
=
BroadcastArgsWarpper
<
T
,
Functor
,
ET
,
VecSize
,
6
>
(
auto
broadcast_warpper
=
BroadcastArgsWarpper
<
InT
,
OutT
,
Functor
,
ET
,
VecSize
,
6
>
(
ins
,
out
,
vec_len
,
func
,
offset_calculator
);
ElementwiseBroadcastKernel
<
T
,
decltype
(
broadcast_warpper
),
ET
,
ElementwiseBroadcastKernel
<
InT
,
Out
T
,
decltype
(
broadcast_warpper
),
ET
,
VecSize
><<<
blocks
,
threads
,
0
,
stream
>>>
(
broadcast_warpper
,
main_tid
,
tail_tid
);
break
;
}
case
7
:
{
auto
broadcast_warpper
=
BroadcastArgsWarpper
<
T
,
Functor
,
ET
,
VecSize
,
7
>
(
auto
broadcast_warpper
=
BroadcastArgsWarpper
<
InT
,
OutT
,
Functor
,
ET
,
VecSize
,
7
>
(
ins
,
out
,
vec_len
,
func
,
offset_calculator
);
ElementwiseBroadcastKernel
<
T
,
decltype
(
broadcast_warpper
),
ET
,
ElementwiseBroadcastKernel
<
InT
,
Out
T
,
decltype
(
broadcast_warpper
),
ET
,
VecSize
><<<
blocks
,
threads
,
0
,
stream
>>>
(
broadcast_warpper
,
main_tid
,
tail_tid
);
break
;
}
case
8
:
{
auto
broadcast_warpper
=
BroadcastArgsWarpper
<
T
,
Functor
,
ET
,
VecSize
,
8
>
(
auto
broadcast_warpper
=
BroadcastArgsWarpper
<
InT
,
OutT
,
Functor
,
ET
,
VecSize
,
8
>
(
ins
,
out
,
vec_len
,
func
,
offset_calculator
);
ElementwiseBroadcastKernel
<
T
,
decltype
(
broadcast_warpper
),
ET
,
ElementwiseBroadcastKernel
<
InT
,
Out
T
,
decltype
(
broadcast_warpper
),
ET
,
VecSize
><<<
blocks
,
threads
,
0
,
stream
>>>
(
broadcast_warpper
,
main_tid
,
tail_tid
);
break
;
...
...
@@ -448,7 +461,7 @@ void LaunchBroadcastKernelForDifferentDimSize(
}
}
template
<
ElementwiseType
ET
,
typename
T
,
typename
Functor
>
template
<
ElementwiseType
ET
,
typename
InT
,
typename
Out
T
,
typename
Functor
>
void
LaunchBroadcastElementwiseCudaKernel
(
const
platform
::
CUDADeviceContext
&
ctx
,
const
std
::
vector
<
const
framework
::
Tensor
*>
&
ins
,
...
...
@@ -457,27 +470,27 @@ void LaunchBroadcastElementwiseCudaKernel(
int
in_vec_size
=
4
;
framework
::
Tensor
*
out
=
(
*
outs
)[
0
];
for
(
auto
*
in
:
ins
)
{
auto
temp_size
=
GetVectorizedSizeImpl
<
T
>
(
in
->
data
<
T
>
());
auto
temp_size
=
GetVectorizedSizeImpl
<
InT
>
(
in
->
data
<
In
T
>
());
in_vec_size
=
in
->
dims
()
==
out
->
dims
()
?
std
::
min
(
temp_size
,
in_vec_size
)
:
in_vec_size
;
}
int
out_vec_size
=
GetVectorizedSizeImpl
<
T
>
(
out
->
data
<
T
>
());
int
out_vec_size
=
GetVectorizedSizeImpl
<
OutT
>
(
out
->
data
<
Out
T
>
());
int
vec_size
=
std
::
min
(
out_vec_size
,
in_vec_size
);
switch
(
vec_size
)
{
case
4
:
{
LaunchBroadcastKernelForDifferentDimSize
<
T
,
ET
,
4
>
(
ctx
,
ins
,
out
,
axis
,
func
);
LaunchBroadcastKernelForDifferentDimSize
<
InT
,
OutT
,
ET
,
4
>
(
ctx
,
ins
,
out
,
axis
,
func
);
break
;
}
case
2
:
{
LaunchBroadcastKernelForDifferentDimSize
<
T
,
ET
,
2
>
(
ctx
,
ins
,
out
,
axis
,
func
);
LaunchBroadcastKernelForDifferentDimSize
<
InT
,
OutT
,
ET
,
2
>
(
ctx
,
ins
,
out
,
axis
,
func
);
break
;
}
case
1
:
{
LaunchBroadcastKernelForDifferentDimSize
<
T
,
ET
,
1
>
(
ctx
,
ins
,
out
,
axis
,
func
);
LaunchBroadcastKernelForDifferentDimSize
<
InT
,
OutT
,
ET
,
1
>
(
ctx
,
ins
,
out
,
axis
,
func
);
break
;
}
default:
{
...
...
@@ -502,8 +515,9 @@ void LaunchElementwiseCudaKernel(
LaunchSameDimsElementwiseCudaKernel
<
ElementwiseType
::
kBinary
,
InT
,
OutType
>
(
cuda_ctx
,
ins
,
outs
,
func
);
}
else
{
LaunchBroadcastElementwiseCudaKernel
<
ElementwiseType
::
kBinary
,
InT
>
(
cuda_ctx
,
ins
,
outs
,
axis
,
func
);
LaunchBroadcastElementwiseCudaKernel
<
ElementwiseType
::
kBinary
,
InT
,
OutType
>
(
cuda_ctx
,
ins
,
outs
,
axis
,
func
);
}
}
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录