Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
机器未来
Paddle
提交
48f061fb
P
Paddle
项目概览
机器未来
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
48f061fb
编写于
12月 28, 2021
作者:
L
limingshu
提交者:
GitHub
12月 28, 2021
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Support multi-output feature for elementwise (#38410)
* first commit * pass ctest of elementwise_div_grad
上级
85f5d264
变更
6
隐藏空白更改
内联
并排
Showing
6 changed file
with
107 addition
and
48 deletion
+107
-48
paddle/fluid/operators/elementwise/elementwise_op_broadcast.cu.h
...fluid/operators/elementwise/elementwise_op_broadcast.cu.h
+7
-5
paddle/fluid/operators/elementwise/elementwise_op_impl.cu.h
paddle/fluid/operators/elementwise/elementwise_op_impl.cu.h
+4
-3
paddle/pten/kernels/hybird/cuda/elementwise/elementwise.h
paddle/pten/kernels/hybird/cuda/elementwise/elementwise.h
+7
-3
paddle/pten/kernels/hybird/cuda/elementwise/elementwise_broadcast.cu.h
...ernels/hybird/cuda/elementwise/elementwise_broadcast.cu.h
+13
-14
paddle/pten/kernels/hybird/cuda/elementwise/elementwise_common.cu.h
...n/kernels/hybird/cuda/elementwise/elementwise_common.cu.h
+2
-2
paddle/pten/kernels/hybird/cuda/elementwise/elementwise_no_broadcast.cu.h
...els/hybird/cuda/elementwise/elementwise_no_broadcast.cu.h
+74
-21
未找到文件。
paddle/fluid/operators/elementwise/elementwise_op_broadcast.cu.h
浏览文件 @
48f061fb
...
...
@@ -162,7 +162,8 @@ struct DimensionsTransform {
}
};
template
<
ElementwiseType
ET
,
typename
InT
,
typename
OutT
,
typename
Functor
>
template
<
ElementwiseType
ET
,
typename
InT
,
typename
OutT
,
typename
Functor
,
int
NumOuts
=
1
>
void
LaunchBroadcastElementwiseCudaKernel
(
const
platform
::
CUDADeviceContext
&
ctx
,
const
std
::
vector
<
const
framework
::
Tensor
*>
&
ins
,
...
...
@@ -190,11 +191,12 @@ void LaunchBroadcastElementwiseCudaKernel(
for
(
int
i
=
0
;
i
<
pt_outputs_tmp
.
size
();
i
++
)
{
pt_outputs
.
push_back
(
pt_outputs_tmp
[
i
].
get
());
}
pten
::
LaunchBroadcastElementwiseCudaKernel
<
ET
,
InT
,
OutT
>
(
pten
::
LaunchBroadcastElementwiseCudaKernel
<
ET
,
InT
,
OutT
,
Functor
,
NumOuts
>
(
ctx
,
pt_inputs
,
&
pt_outputs
,
axis
,
func
);
}
template
<
ElementwiseType
ET
,
typename
InT
,
typename
OutT
,
typename
Functor
>
template
<
ElementwiseType
ET
,
typename
InT
,
typename
OutT
,
typename
Functor
,
int
NumOuts
=
1
>
void
LaunchElementwiseCudaKernel
(
const
platform
::
CUDADeviceContext
&
cuda_ctx
,
const
std
::
vector
<
const
framework
::
Tensor
*>
&
ins
,
...
...
@@ -222,8 +224,8 @@ void LaunchElementwiseCudaKernel(
for
(
int
i
=
0
;
i
<
pt_outputs_tmp
.
size
();
i
++
)
{
pt_outputs
.
push_back
(
pt_outputs_tmp
[
i
].
get
());
}
pten
::
LaunchElementwiseCudaKernel
<
ET
,
InT
,
OutT
>
(
cuda_ctx
,
pt_inputs
,
&
pt_outputs
,
axis
,
func
);
pten
::
LaunchElementwiseCudaKernel
<
ET
,
InT
,
OutT
,
Functor
,
NumOuts
>
(
cuda_ctx
,
pt_inputs
,
&
pt_outputs
,
axis
,
func
);
}
}
// namespace operators
...
...
paddle/fluid/operators/elementwise/elementwise_op_impl.cu.h
浏览文件 @
48f061fb
...
...
@@ -38,7 +38,8 @@ namespace kps = paddle::operators::kernel_primitives;
using
ElementwiseType
=
pten
::
ElementwiseType
;
template
<
ElementwiseType
ET
,
typename
InT
,
typename
OutT
,
typename
Functor
>
template
<
ElementwiseType
ET
,
typename
InT
,
typename
OutT
,
typename
Functor
,
int
NumOuts
=
1
>
void
LaunchSameDimsElementwiseCudaKernel
(
const
platform
::
CUDADeviceContext
&
ctx
,
const
std
::
vector
<
const
framework
::
Tensor
*>
&
ins
,
...
...
@@ -66,8 +67,8 @@ void LaunchSameDimsElementwiseCudaKernel(
for
(
int
i
=
0
;
i
<
pt_outputs_tmp
.
size
();
i
++
)
{
pt_outputs
.
push_back
(
pt_outputs_tmp
[
i
].
get
());
}
pten
::
LaunchSameDimsElementwiseCudaKernel
<
ET
,
InT
,
OutT
>
(
ctx
,
pt_inputs
,
&
pt_outputs
,
func
);
pten
::
LaunchSameDimsElementwiseCudaKernel
<
ET
,
InT
,
OutT
,
Functor
,
NumOuts
>
(
ctx
,
pt_inputs
,
&
pt_outputs
,
func
);
}
}
// namespace operators
...
...
paddle/pten/kernels/hybird/cuda/elementwise/elementwise.h
浏览文件 @
48f061fb
...
...
@@ -19,7 +19,11 @@ limitations under the License. */
namespace
pten
{
template
<
ElementwiseType
ET
,
typename
InT
,
typename
OutT
,
typename
Functor
>
template
<
ElementwiseType
ET
,
typename
InT
,
typename
OutT
,
typename
Functor
,
int
NumOuts
=
1
>
void
LaunchElementwiseCudaKernel
(
const
paddle
::
platform
::
CUDADeviceContext
&
cuda_ctx
,
const
std
::
vector
<
const
DenseTensor
*>
&
ins
,
...
...
@@ -33,14 +37,14 @@ void LaunchElementwiseCudaKernel(
dims_size
.
emplace_back
(
in
->
dims
().
size
());
}
if
(
no_broadcast_flag
)
{
LaunchSameDimsElementwiseCudaKernel
<
ET
,
InT
,
OutT
>
(
LaunchSameDimsElementwiseCudaKernel
<
ET
,
InT
,
OutT
,
Functor
,
NumOuts
>
(
cuda_ctx
,
ins
,
outs
,
func
);
}
else
{
axis
=
axis
==
-
1
?
*
std
::
max_element
(
dims_size
.
begin
(),
dims_size
.
end
())
-
*
std
::
min_element
(
dims_size
.
begin
(),
dims_size
.
end
())
:
axis
;
LaunchBroadcastElementwiseCudaKernel
<
ET
,
InT
,
OutT
>
(
LaunchBroadcastElementwiseCudaKernel
<
ET
,
InT
,
OutT
,
Functor
,
NumOuts
>
(
cuda_ctx
,
ins
,
outs
,
axis
,
func
);
}
}
...
...
paddle/pten/kernels/hybird/cuda/elementwise/elementwise_broadcast.cu.h
浏览文件 @
48f061fb
...
...
@@ -208,7 +208,7 @@ __device__ void ElementwiseBroadcastKernelImpl(
int
block_offset
,
Functor
func
)
{
InT
args
[
Arity
][
VecSize
];
OutType
<
OutT
,
NumOuts
>
result
[
VecSize
];
ConditionalT
<
OutT
,
NumOuts
>
result
[
VecSize
];
#pragma unroll
for
(
int
i
=
0
;
i
<
Arity
;
i
++
)
{
...
...
@@ -224,7 +224,7 @@ __device__ void ElementwiseBroadcastKernelImpl(
constexpr
bool
kCallElementwiseAny
=
paddle
::
platform
::
FunctionTraits
<
Functor
>::
has_pointer_args
;
ElementwisePrimitiveCaller
<
InT
,
OutType
<
OutT
,
NumOuts
>
,
ConditionalT
<
OutT
,
NumOuts
>
,
VecSize
,
Functor
,
Arity
,
...
...
@@ -455,20 +455,19 @@ void LaunchBroadcastElementwiseCudaKernel(
"is %d, the arity of functor is %d."
,
ins
.
size
(),
kArity
));
PADDLE_ENFORCE_
EQ
(
kArity
,
2
,
PADDLE_ENFORCE_
LE
(
kArity
,
ElementwiseType
::
kTernary
,
paddle
::
platform
::
errors
::
InvalidArgument
(
"Currently only broadcast of
binary is supported an
d "
"verified, but received %d."
,
"Currently only broadcast of
ternary is supporte
d "
"
and
verified, but received %d."
,
kArity
));
PADDLE_ENFORCE_EQ
(
outs
->
size
(),
NumOuts
,
paddle
::
platform
::
errors
::
InvalidArgument
(
"Number of outputs shall equal to number of functions, "
"but number of outputs is %d, number of functions is %d."
,
outs
->
size
(),
NumOuts
));
PADDLE_ENFORCE_EQ
(
outs
->
size
(),
NumOuts
,
paddle
::
platform
::
errors
::
InvalidArgument
(
"Number of outputs shall equal to number of functions, "
"but number of outputs is %d, of functions is %d."
,
outs
->
size
(),
NumOuts
));
int
in_vec_size
=
4
;
int
out_vec_size
=
4
;
if
(
NumOuts
>
1
)
{
...
...
paddle/pten/kernels/hybird/cuda/elementwise/elementwise_common.cu.h
浏览文件 @
48f061fb
...
...
@@ -27,7 +27,7 @@ enum ElementwiseType { kUnary = 1, kBinary = 2, kTernary = 3, kAny = -1 };
/* Packing scalar type T(float, int etc.) into Array<T, NumOuts> type
for supporting multiple-output feature in elementwise system.*/
template
<
class
T
,
int
Num
>
using
OutType
=
using
ConditionalT
=
typename
std
::
conditional_t
<
Num
==
1
,
T
,
paddle
::
framework
::
Array
<
T
,
Num
>>
;
template
<
typename
InT
,
...
...
@@ -86,7 +86,7 @@ template <typename OutT, int VecSize, bool IsBoundary, int NumOuts>
struct
ElementwiseWriteDataCaller
{
__device__
__forceinline__
void
operator
()(
paddle
::
framework
::
Array
<
OutT
*
,
NumOuts
>
outs
,
OutType
<
OutT
,
NumOuts
>
src
[
VecSize
],
ConditionalT
<
OutT
,
NumOuts
>
src
[
VecSize
],
int
block_offset
,
int
num
)
{
OutT
dst
[
NumOuts
][
VecSize
];
...
...
paddle/pten/kernels/hybird/cuda/elementwise/elementwise_no_broadcast.cu.h
浏览文件 @
48f061fb
...
...
@@ -55,16 +55,17 @@ template <typename InT,
typename
OutT
,
typename
Functor
,
int
Arity
,
int
NumOuts
,
int
VecSize
,
bool
IsBoundary
>
__device__
void
VectorizedElementwiseKernelImpl
(
const
paddle
::
framework
::
Array
<
const
InT
*
__restrict__
,
Arity
>
&
in
,
OutT
*
out
,
paddle
::
framework
::
Array
<
OutT
*
,
NumOuts
>
outs
,
int
num
,
int
data_offset
,
Functor
func
)
{
InT
args
[
Arity
][
VecSize
];
OutT
result
[
VecSize
];
ConditionalT
<
OutT
,
NumOuts
>
result
[
VecSize
];
#pragma unroll
for
(
int
i
=
0
;
i
<
Arity
;
i
++
)
{
...
...
@@ -73,36 +74,53 @@ __device__ void VectorizedElementwiseKernelImpl(
args
[
i
],
in
[
i
]
+
data_offset
,
num
);
}
const
bool
kCallElementwiseAny
=
const
expr
bool
kCallElementwiseAny
=
paddle
::
platform
::
FunctionTraits
<
Functor
>::
has_pointer_args
;
ElementwisePrimitiveCaller
<
InT
,
OutT
,
ConditionalT
<
OutT
,
NumOuts
>
,
VecSize
,
Functor
,
Arity
,
kCallElementwiseAny
>
()(
func
,
args
,
result
);
kps
::
WriteData
<
OutT
,
VecSize
,
1
,
1
,
IsBoundary
>
(
out
+
data_offset
,
result
,
num
);
ElementwiseWriteDataCaller
<
OutT
,
VecSize
,
IsBoundary
,
NumOuts
>
()(
outs
,
result
,
data_offset
,
num
);
}
template
<
typename
InT
,
typename
OutT
,
typename
Functor
,
int
Arity
,
int
VecSize
>
template
<
typename
InT
,
typename
OutT
,
typename
Functor
,
int
Arity
,
int
NumOuts
,
int
VecSize
>
__global__
void
VectorizedElementwiseKernel
(
paddle
::
framework
::
Array
<
const
InT
*
__restrict__
,
Arity
>
ins
,
OutT
*
out
,
paddle
::
framework
::
Array
<
OutT
*
,
NumOuts
>
outs
,
int
size
,
int
main_offset
,
Functor
func
)
{
int
data_offset
=
BLOCK_ID_X
*
BLOCK_NUM_X
*
VecSize
;
int
stride
=
BLOCK_NUM_X
*
GRID_NUM_X
*
VecSize
;
for
(;
data_offset
<
main_offset
;
data_offset
+=
stride
)
{
VectorizedElementwiseKernelImpl
<
InT
,
OutT
,
Functor
,
Arity
,
VecSize
,
false
>
(
ins
,
out
,
VecSize
*
BLOCK_NUM_X
,
data_offset
,
func
);
VectorizedElementwiseKernelImpl
<
InT
,
OutT
,
Functor
,
Arity
,
NumOuts
,
VecSize
,
false
>
(
ins
,
outs
,
VecSize
*
BLOCK_NUM_X
,
data_offset
,
func
);
}
int
num
=
size
-
data_offset
;
if
(
num
>
0
)
{
VectorizedElementwiseKernelImpl
<
InT
,
OutT
,
Functor
,
Arity
,
VecSize
,
true
>
(
ins
,
out
,
num
,
data_offset
,
func
);
VectorizedElementwiseKernelImpl
<
InT
,
OutT
,
Functor
,
Arity
,
NumOuts
,
VecSize
,
true
>
(
ins
,
outs
,
num
,
data_offset
,
func
);
}
}
...
...
@@ -121,7 +139,12 @@ int GetVectorizedSizeForTensors(const std::vector<const DenseTensor *> &ins,
return
vec_size
;
}
template
<
typename
InT
,
typename
OutT
,
typename
Functor
,
int
Arity
,
int
VecSize
>
template
<
typename
InT
,
typename
OutT
,
typename
Functor
,
int
Arity
,
int
NumOuts
,
int
VecSize
>
void
ElementwiseCudaKernel
(
const
paddle
::
platform
::
CUDADeviceContext
&
ctx
,
const
std
::
vector
<
const
DenseTensor
*>
&
ins
,
std
::
vector
<
DenseTensor
*>
*
outs
,
...
...
@@ -131,11 +154,15 @@ void ElementwiseCudaKernel(const paddle::platform::CUDADeviceContext &ctx,
int
grid_size
=
((
numel
+
VecSize
-
1
)
/
VecSize
+
block_size
-
1
)
/
block_size
;
auto
stream
=
ctx
.
stream
();
OutT
*
out_data
=
(
*
outs
)[
0
]
->
mutable_data
<
OutT
>
();
paddle
::
framework
::
Array
<
const
InT
*
__restrict__
,
Arity
>
ins_data
;
for
(
int
i
=
0
;
i
<
Arity
;
i
++
)
{
paddle
::
framework
::
Array
<
OutT
*
,
NumOuts
>
outs_data
;
for
(
int
i
=
0
;
i
<
Arity
;
++
i
)
{
ins_data
[
i
]
=
ins
[
i
]
->
data
<
InT
>
();
}
for
(
int
i
=
0
;
i
<
NumOuts
;
++
i
)
{
outs_data
[
i
]
=
(
*
outs
)[
i
]
->
mutable_data
<
OutT
>
();
}
#ifdef PADDLE_WITH_XPU2
block_size
=
128
;
grid_size
=
8
;
...
...
@@ -144,20 +171,26 @@ void ElementwiseCudaKernel(const paddle::platform::CUDADeviceContext &ctx,
OutT
,
Functor
,
Arity
,
NumOuts
,
VecSize
><<<
grid_size
,
block_size
,
0
,
stream
>>>
(
ins_data
,
out_data
,
numel
,
main_offset
,
func
);
ins_data
,
out
s
_data
,
numel
,
main_offset
,
func
);
#else
int
main_offset
=
(
numel
/
(
VecSize
*
block_size
))
*
VecSize
*
block_size
;
VectorizedElementwiseKernel
<
InT
,
OutT
,
Functor
,
Arity
,
NumOuts
,
VecSize
><<<
grid_size
,
block_size
,
0
,
stream
>>>
(
ins_data
,
out_data
,
numel
,
main_offset
,
func
);
ins_data
,
out
s
_data
,
numel
,
main_offset
,
func
);
#endif
}
template
<
ElementwiseType
ET
,
typename
InT
,
typename
OutT
,
typename
Functor
>
template
<
ElementwiseType
ET
,
typename
InT
,
typename
OutT
,
typename
Functor
,
int
NumOuts
=
1
>
void
LaunchSameDimsElementwiseCudaKernel
(
const
paddle
::
platform
::
CUDADeviceContext
&
ctx
,
const
std
::
vector
<
const
DenseTensor
*>
&
ins
,
...
...
@@ -174,19 +207,39 @@ void LaunchSameDimsElementwiseCudaKernel(
"is %d, the arity of functor is %d."
,
ins
.
size
(),
kArity
));
PADDLE_ENFORCE_EQ
(
outs
->
size
(),
NumOuts
,
paddle
::
platform
::
errors
::
InvalidArgument
(
"Number of outputs shall equal to number of functions, "
"but number of outputs is %d, of functions is %d."
,
outs
->
size
(),
NumOuts
));
if
(
NumOuts
>
1
)
{
for
(
int
i
=
1
;
i
<
NumOuts
;
++
i
)
{
PADDLE_ENFORCE_EQ
(
(
*
outs
)[
i
]
->
dims
(),
(
*
outs
)[
0
]
->
dims
(),
paddle
::
platform
::
errors
::
InvalidArgument
(
"The shape of each output tensor shall be identical yet, "
"but %dth output tensor`s shape is not."
,
i
));
}
}
// calculate the max vec_size for all ins and outs
int
vec_size
=
GetVectorizedSizeForTensors
<
InT
,
OutT
>
(
ins
,
*
outs
);
switch
(
vec_size
)
{
case
4
:
ElementwiseCudaKernel
<
InT
,
OutT
,
Functor
,
kArity
,
4
>
(
ElementwiseCudaKernel
<
InT
,
OutT
,
Functor
,
kArity
,
NumOuts
,
4
>
(
ctx
,
ins
,
outs
,
func
);
break
;
case
2
:
ElementwiseCudaKernel
<
InT
,
OutT
,
Functor
,
kArity
,
2
>
(
ElementwiseCudaKernel
<
InT
,
OutT
,
Functor
,
kArity
,
NumOuts
,
2
>
(
ctx
,
ins
,
outs
,
func
);
break
;
case
1
:
ElementwiseCudaKernel
<
InT
,
OutT
,
Functor
,
kArity
,
1
>
(
ElementwiseCudaKernel
<
InT
,
OutT
,
Functor
,
kArity
,
NumOuts
,
1
>
(
ctx
,
ins
,
outs
,
func
);
break
;
default:
{
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录