Commit a1980d9c (unverified)
Authored on Jan 18, 2022 by YuanRisheng; committed by GitHub on Jan 18, 2022
break the circular dependency between reduce and elementwise (#38951)
Parent: 30845734
Showing 10 changed files with 393 additions and 381 deletions (+393 −381)
paddle/fluid/operators/elementwise/elementwise_op_impl.cu.h  +3  −2
paddle/pten/infermeta/binary.cc  +1  −1
paddle/pten/kernels/cpu/elementwise.h  +1  −0
paddle/pten/kernels/funcs/common_shape.h  +61  −2
paddle/pten/kernels/funcs/cuda_kernel_config.h  +1  −0
paddle/pten/kernels/funcs/elementwise_base.h  +305  −60
paddle/pten/kernels/gpu/cast_kernel.cu  +4  −2
paddle/pten/kernels/gpu/elementwise.h  +13  −310
paddle/pten/kernels/gpu/reduce.h  +2  −2
paddle/pten/kernels/gpu/scale_kernel.cu  +2  −2
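Read together, the hunks below do one thing: the shared GPU elementwise launch machinery (ElementwisePrimitiveCaller, ElementwiseWriteDataCaller, VectorizedElementwiseKernel, ElementwiseCudaKernel, GetVectorizedSizeForTensors, LaunchSameDimsElementwiseCudaKernel) moves out of paddle/pten/kernels/gpu/elementwise.h into paddle/pten/kernels/funcs/elementwise_base.h under namespace pten::funcs, and GetBroadcastDimsArrays moves into funcs/common_shape.h. paddle/pten/kernels/gpu/reduce.h then includes only funcs/elementwise_base.h instead of gpu/elementwise.h, which is what removes the reduce ↔ elementwise include cycle; the fluid wrapper and the cast/scale kernels simply requalify their calls as pten::funcs::. The single-file sketch below is a heavily simplified, hypothetical illustration of that layering (plain host C++ standing in for the CUDA kernels, invented names), not Paddle code:

#include <cstdio>
#include <vector>

// --- conceptually funcs/elementwise_base.h: the shared lower layer ---
namespace funcs {
template <typename T, typename F>
void LaunchSameDims(const std::vector<T>& in, std::vector<T>* out, F f) {
  out->resize(in.size());
  for (size_t i = 0; i < in.size(); ++i) (*out)[i] = f(in[i]);
}
}  // namespace funcs

// --- conceptually gpu/elementwise.h: broadcast paths would live here ---
namespace elementwise {
template <typename T, typename F>
void Launch(const std::vector<T>& in, std::vector<T>* out, F f) {
  funcs::LaunchSameDims(in, out, f);  // same-dims fast path comes from the shared layer
}
}  // namespace elementwise

// --- conceptually gpu/reduce.h ---
namespace reduce {
template <typename T>
void Reduce(const std::vector<T>& in, std::vector<T>* out) {
  // The degenerate reduction only needs the shared layer, not elementwise.h,
  // so neither header has to include the other any more.
  funcs::LaunchSameDims(in, out, [](T v) { return v; });
}
}  // namespace reduce

int main() {
  std::vector<float> x{1.0f, 2.0f, 3.0f};
  std::vector<float> doubled, copied;
  elementwise::Launch(x, &doubled, [](float v) { return 2.0f * v; });
  reduce::Reduce(x, &copied);
  std::printf("%f %f\n", doubled[2], copied[0]);  // 6.000000 1.000000
}

The only point is that both higher layers now depend on the shared layer and never on each other, so either header can be included without dragging in the other.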
paddle/fluid/operators/elementwise/elementwise_op_impl.cu.h

@@ -59,8 +59,9 @@ void LaunchSameDimsElementwiseCudaKernel(
  for (int i = 0; i < pt_outputs_tmp.size(); i++) {
    pt_outputs.push_back(pt_outputs_tmp[i].get());
  }
-  pten::LaunchSameDimsElementwiseCudaKernel<ET, InT, OutT, Functor, NumOuts>(
-      ctx, pt_inputs, &pt_outputs, func);
+  pten::funcs::LaunchSameDimsElementwiseCudaKernel<ET, InT, OutT, Functor,
+                                                   NumOuts>(
+      ctx, pt_inputs, &pt_outputs, func);
}
}  // namespace operators
paddle/pten/infermeta/binary.cc

@@ -14,7 +14,7 @@ limitations under the License. */
// See Note [ Why still include the fluid headers? ]
#include "paddle/pten/infermeta/binary.h"
-#include "paddle/pten/kernels/funcs/elementwise_base.h"
+#include "paddle/pten/kernels/funcs/common_shape.h"

namespace pten {
paddle/pten/kernels/cpu/elementwise.h

@@ -16,6 +16,7 @@ limitations under the License. */
#include "paddle/pten/backends/cpu/cpu_context.h"
#include "paddle/pten/core/dense_tensor.h"
+#include "paddle/pten/kernels/funcs/common_shape.h"
#include "paddle/pten/kernels/funcs/elementwise_base.h"

#include "paddle/fluid/operators/math/blas.h"
paddle/pten/kernels/funcs/common_shape.h

@@ -19,8 +19,8 @@ limitations under the License. */
namespace pten {
namespace funcs {

inline void SetXShape(const DenseTensor& x, DenseTensor* xshape) {
  const auto& in_dims = x.meta().dims;
  std::vector<int64_t> xshape_dims(in_dims.size() + 1);
  xshape_dims[0] = 0;
  for (int i = 0; i < in_dims.size(); ++i) {

@@ -30,5 +30,64 @@ inline void SetXShape(const DenseTensor& x, DenseTensor* xshape) {
  xshape->ResetLoD(x.meta().lod);
}

+inline void GetBroadcastDimsArrays(const DDim &x_dims,
+                                   const DDim &y_dims,
+                                   int *x_dims_array,
+                                   int *y_dims_array,
+                                   int *out_dims_array,
+                                   const int max_dim,
+                                   const int axis) {
+  PADDLE_ENFORCE_GE(
+      axis,
+      0,
+      paddle::platform::errors::InvalidArgument(
+          "Axis should be great than or equal to 0, but received axis is %d.",
+          axis));
+  PADDLE_ENFORCE_LT(
+      axis,
+      max_dim,
+      paddle::platform::errors::InvalidArgument(
+          "Axis should be less than %d, but received axis is %d.",
+          max_dim,
+          axis));
+  if (x_dims.size() > y_dims.size()) {
+    std::fill(y_dims_array, y_dims_array + axis, 1);
+    if (axis + y_dims.size() < max_dim) {
+      std::fill(y_dims_array + axis + y_dims.size(), y_dims_array + max_dim, 1);
+    }
+    std::copy(x_dims.Get(), x_dims.Get() + x_dims.size(), x_dims_array);
+    std::copy(y_dims.Get(), y_dims.Get() + y_dims.size(), y_dims_array + axis);
+  } else {
+    std::fill(x_dims_array, x_dims_array + axis, 1);
+    if (axis + x_dims.size() < max_dim) {
+      std::fill(x_dims_array + axis + x_dims.size(), x_dims_array + max_dim, 1);
+    }
+    std::copy(x_dims.Get(), x_dims.Get() + x_dims.size(), x_dims_array + axis);
+    std::copy(y_dims.Get(), y_dims.Get() + y_dims.size(), y_dims_array);
+  }
+
+  for (int i = 0; i < max_dim; i++) {
+    PADDLE_ENFORCE_EQ(
+        x_dims_array[i] == y_dims_array[i] || x_dims_array[i] <= 1 ||
+            y_dims_array[i] <= 1,
+        true,
+        paddle::platform::errors::InvalidArgument(
+            "Broadcast dimension mismatch. Operands could "
+            "not be broadcast together with the shape of X = [%s] and "
+            "the shape of Y = [%s]. Received [%d] in X is not equal to "
+            "[%d] in Y at i:%d.",
+            x_dims,
+            y_dims,
+            x_dims_array[i],
+            y_dims_array[i],
+            i));
+    if ((x_dims_array[i] > 1 || y_dims_array[i] > 1) ||
+        (x_dims_array[i] == 1 && y_dims_array[i] == 1)) {
+      out_dims_array[i] = (std::max)(x_dims_array[i], y_dims_array[i]);
+    } else {
+      out_dims_array[i] = -1;
+    }
+  }
+}
+
}  // namespace funcs
}  // namespace pten
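The broadcast rule GetBroadcastDimsArrays implements above can be summarized as: pad the shorter shape with 1s so that its dimensions line up starting at `axis`, require each position to be broadcast-compatible (equal extents, or one of them is 1), and take the larger extent as the output. A standalone sketch of that rule, an assumption-level simplification that uses std::vector instead of DDim, an exception instead of PADDLE_ENFORCE, and ignores the −1 "unknown dim" case:

#include <algorithm>
#include <stdexcept>
#include <vector>

// Simplified stand-in for GetBroadcastDimsArrays: pad the shorter shape to
// max_dim starting at `axis`, then take the elementwise max where the extents
// are compatible (equal, or one of them is 1).  Assumes a valid axis, as the
// real code enforces.
std::vector<int> BroadcastDims(std::vector<int> x, std::vector<int> y, int axis) {
  const size_t max_dim = std::max(x.size(), y.size());
  std::vector<int>& shorter = x.size() < y.size() ? x : y;
  std::vector<int> padded(max_dim, 1);
  std::copy(shorter.begin(), shorter.end(), padded.begin() + axis);
  shorter = padded;

  std::vector<int> out(max_dim);
  for (size_t i = 0; i < max_dim; ++i) {
    if (x[i] != y[i] && x[i] != 1 && y[i] != 1)
      throw std::invalid_argument("Broadcast dimension mismatch");
    out[i] = std::max(x[i], y[i]);
  }
  return out;
}

int main() {
  // x = [2, 3, 4], y = [3, 1], axis = 1  ->  out = [2, 3, 4]
  auto out = BroadcastDims({2, 3, 4}, {3, 1}, 1);
  return out == std::vector<int>{2, 3, 4} ? 0 : 1;
}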
paddle/pten/kernels/funcs/cuda_kernel_config.h

@@ -15,6 +15,7 @@ limitations under the License. */
#pragma once

#include "paddle/fluid/platform/device/gpu/gpu_device_function.h"
#include "paddle/fluid/platform/device/gpu/gpu_launch_config.h"

#ifdef __HIPCC__
#define ELEMENTWISE_BLOCK_SIZE 256
paddle/pten/kernels/funcs/elementwise_base.h

@@ -19,9 +19,26 @@ limitations under the License. */
#include "paddle/pten/backends/all_context.h"
#include "paddle/pten/core/dense_tensor.h"

+#if defined(__NVCC__) || defined(__HIPCC__)
+#include "paddle/fluid/operators/kernel_primitives/kernel_primitives.h"
+#include "paddle/fluid/platform/aligned_vector.h"
+#include "paddle/fluid/platform/device/gpu/gpu_launch_config.h"
+#include "paddle/fluid/platform/function_traits.h"
+
+namespace kps = paddle::operators::kernel_primitives;
+
+#endif
+
namespace pten {
-namespace funcs {
+
+enum ElementwiseType { kUnary = 1, kBinary = 2, kTernary = 3, kAny = -1 };
+
+/* Packing scalar type T(float, int etc.) into Array<T, NumOuts> type
+   for supporting multiple-output feature in elementwise system.*/
+template <class T, int Num>
+using ConditionalT =
+    typename std::conditional_t<Num == 1, T, paddle::framework::Array<T, Num>>;
+
+namespace funcs {
+using DDim = paddle::framework::DDim;

template <typename T, typename DX_OP, typename DY_OP, typename Tout = T>

@@ -343,65 +360,6 @@ inline void get_mid_dims(const DDim &x_dims,
  }
}

-inline void GetBroadcastDimsArrays(const DDim &x_dims,
-                                   const DDim &y_dims,
-                                   int *x_dims_array,
-                                   int *y_dims_array,
-                                   int *out_dims_array,
-                                   const int max_dim,
-                                   const int axis) {
-  PADDLE_ENFORCE_GE(
-      axis,
-      0,
-      paddle::platform::errors::InvalidArgument(
-          "Axis should be great than or equal to 0, but received axis is %d.",
-          axis));
-  PADDLE_ENFORCE_LT(
-      axis,
-      max_dim,
-      paddle::platform::errors::InvalidArgument(
-          "Axis should be less than %d, but received axis is %d.",
-          max_dim,
-          axis));
-  if (x_dims.size() > y_dims.size()) {
-    std::fill(y_dims_array, y_dims_array + axis, 1);
-    if (axis + y_dims.size() < max_dim) {
-      std::fill(y_dims_array + axis + y_dims.size(), y_dims_array + max_dim, 1);
-    }
-    std::copy(x_dims.Get(), x_dims.Get() + x_dims.size(), x_dims_array);
-    std::copy(y_dims.Get(), y_dims.Get() + y_dims.size(), y_dims_array + axis);
-  } else {
-    std::fill(x_dims_array, x_dims_array + axis, 1);
-    if (axis + x_dims.size() < max_dim) {
-      std::fill(x_dims_array + axis + x_dims.size(), x_dims_array + max_dim, 1);
-    }
-    std::copy(x_dims.Get(), x_dims.Get() + x_dims.size(), x_dims_array + axis);
-    std::copy(y_dims.Get(), y_dims.Get() + y_dims.size(), y_dims_array);
-  }
-
-  for (int i = 0; i < max_dim; i++) {
-    PADDLE_ENFORCE_EQ(
-        x_dims_array[i] == y_dims_array[i] || x_dims_array[i] <= 1 ||
-            y_dims_array[i] <= 1,
-        true,
-        paddle::platform::errors::InvalidArgument(
-            "Broadcast dimension mismatch. Operands could "
-            "not be broadcast together with the shape of X = [%s] and "
-            "the shape of Y = [%s]. Received [%d] in X is not equal to "
-            "[%d] in Y at i:%d.",
-            x_dims,
-            y_dims,
-            x_dims_array[i],
-            y_dims_array[i],
-            i));
-    if ((x_dims_array[i] > 1 || y_dims_array[i] > 1) ||
-        (x_dims_array[i] == 1 && y_dims_array[i] == 1)) {
-      out_dims_array[i] = (std::max)(x_dims_array[i], y_dims_array[i]);
-    } else {
-      out_dims_array[i] = -1;
-    }
-  }
-}
-
template <typename DeviceContext,
          typename T,
          typename DX_OP,

@@ -432,5 +390,292 @@ void ElemwiseGradComputeNoBroadcast(const DeviceContext &dev_ctx,
      dy == nullptr ? nullptr : dy->mutable_data<T>(dev_ctx.GetPlace())});
}

+#if defined(__NVCC__) || defined(__HIPCC__)
+
+template <typename InT, typename OutT>
+int GetVectorizedSizeForTensors(const std::vector<const DenseTensor *> &ins,
+                                const std::vector<DenseTensor *> &outs) {
+  int vec_size = 4;
+  for (auto iter = ins.begin(); iter != ins.end(); ++iter) {
+    vec_size = std::min<int>(
+        vec_size, paddle::platform::GetVectorizedSize((*iter)->data<InT>()));
+  }
+  for (auto iter = outs.begin(); iter != outs.end(); ++iter) {
+    vec_size = std::min<int>(
+        vec_size, paddle::platform::GetVectorizedSize((*iter)->data<OutT>()));
+  }
+  return vec_size;
+}
+
+template <typename InT, typename OutT, int VecSize, typename Functor,
+          int Arity, bool CallElementwiseAny = false>
+struct ElementwisePrimitiveCaller {
+  __device__ inline void operator()(Functor func, InT (*args)[VecSize],
+                                    OutT *result);
+};
+
+template <typename InT, typename OutT, int VecSize, typename Functor,
+          int Arity>
+struct ElementwisePrimitiveCaller<InT, OutT, VecSize, Functor, Arity, true> {
+  __device__ inline void operator()(Functor func, InT (*args)[VecSize],
+                                    OutT *result) {
+    kps::ElementwiseAny<InT, OutT, VecSize, 1, 1, Arity, Functor>(result, args,
+                                                                  func);
+  }
+};
+
+template <typename InT, typename OutT, int VecSize, typename Functor>
+struct ElementwisePrimitiveCaller<InT, OutT, VecSize, Functor, 1, false> {
+  __device__ inline void operator()(Functor func, InT (*args)[VecSize],
+                                    OutT *result) {
+    kps::ElementwiseUnary<InT, OutT, VecSize, 1, 1, Functor>(result, args[0],
+                                                             func);
+  }
+};
+
+template <typename InT, typename OutT, int VecSize, typename Functor>
+struct ElementwisePrimitiveCaller<InT, OutT, VecSize, Functor, 2, false> {
+  __device__ inline void operator()(Functor func, InT (*args)[VecSize],
+                                    OutT *result) {
+    kps::ElementwiseBinary<InT, OutT, VecSize, 1, 1, Functor>(result, args[0],
+                                                              args[1], func);
+  }
+};
+
+template <typename InT, typename OutT, int VecSize, typename Functor>
+struct ElementwisePrimitiveCaller<InT, OutT, VecSize, Functor, 3, false> {
+  __device__ inline void operator()(Functor func, InT (*args)[VecSize],
+                                    OutT *result) {
+    kps::ElementwiseTernary<InT, OutT, VecSize, 1, 1, Functor>(
+        result, args[0], args[1], args[2], func);
+  }
+};
+
+template <typename OutT, int VecSize, bool IsBoundary, int NumOuts>
+struct ElementwiseWriteDataCaller {
+  __device__ __forceinline__ void operator()(
+      paddle::framework::Array<_ptr_ OutT *, NumOuts> outs,
+      ConditionalT<OutT, NumOuts> src[VecSize],
+      int block_offset,
+      int num) {
+    OutT dst[NumOuts][VecSize];
+#pragma unroll
+    for (int i = 0; i < VecSize; ++i) {
+#pragma unroll
+      for (int j = 0; j < NumOuts; ++j) {
+        dst[j][i] = (src[i])[j];
+      }
+    }
+#pragma unroll
+    for (int i = 0; i < NumOuts; ++i) {
+      kps::WriteData<OutT, VecSize, 1, 1, IsBoundary>(outs[i] + block_offset,
+                                                      dst[i], num);
+    }
+  }
+};
+
+template <typename OutT, int VecSize, bool IsBoundary>
+struct ElementwiseWriteDataCaller<OutT, VecSize, IsBoundary, 1> {
+  __device__ __forceinline__ void operator()(
+      paddle::framework::Array<_ptr_ OutT *, 1> outs,
+      OutT src[VecSize],
+      int block_offset,
+      int num) {
+    kps::WriteData<OutT, VecSize, 1, 1, IsBoundary>(outs[0] + block_offset,
+                                                    src, num);
+  }
+};
+
+template <typename InT, typename OutT, typename Functor, int Arity,
+          int NumOuts, int VecSize, bool IsBoundary>
+__device__ void VectorizedElementwiseKernelImpl(
+    const paddle::framework::Array<const _ptr_ InT *__restrict__, Arity> &in,
+    paddle::framework::Array<_ptr_ OutT *, NumOuts> outs,
+    int num,
+    int data_offset,
+    Functor func) {
+  InT args[Arity][VecSize];
+  ConditionalT<OutT, NumOuts> result[VecSize];
+
+#pragma unroll
+  for (int i = 0; i < Arity; i++) {
+    kps::Init<InT, VecSize>(args[i], static_cast<InT>(1.0f));
+    kps::ReadData<InT, VecSize, 1, 1, IsBoundary>(args[i], in[i] + data_offset,
+                                                  num);
+  }
+
+  constexpr bool kCallElementwiseAny =
+      paddle::platform::FunctionTraits<Functor>::has_pointer_args;
+  ElementwisePrimitiveCaller<InT,
+                             ConditionalT<OutT, NumOuts>,
+                             VecSize,
+                             Functor,
+                             Arity,
+                             kCallElementwiseAny>()(func, args, result);
+
+  ElementwiseWriteDataCaller<OutT, VecSize, IsBoundary, NumOuts>()(
+      outs, result, data_offset, num);
+}
+
+template <typename InT, typename OutT, typename Functor, int Arity,
+          int NumOuts, int VecSize>
+__global__ void VectorizedElementwiseKernel(
+    paddle::framework::Array<const _ptr_ InT *__restrict__, Arity> ins,
+    paddle::framework::Array<_ptr_ OutT *, NumOuts> outs,
+    int size,
+    int main_offset,
+    Functor func) {
+  int data_offset = BLOCK_ID_X * BLOCK_NUM_X * VecSize;
+  int stride = BLOCK_NUM_X * GRID_NUM_X * VecSize;
+  for (; data_offset < main_offset; data_offset += stride) {
+    VectorizedElementwiseKernelImpl<InT, OutT, Functor, Arity, NumOuts,
+                                    VecSize, false>(
+        ins, outs, VecSize * BLOCK_NUM_X, data_offset, func);
+  }
+
+  int num = size - data_offset;
+  if (num > 0) {
+    VectorizedElementwiseKernelImpl<InT, OutT, Functor, Arity, NumOuts,
+                                    VecSize, true>(ins, outs, num, data_offset,
+                                                   func);
+  }
+}
+
+template <typename InT, typename OutT, typename Functor, int Arity,
+          int NumOuts, int VecSize>
+void ElementwiseCudaKernel(const KPDevice &ctx,
+                           const std::vector<const DenseTensor *> &ins,
+                           std::vector<DenseTensor *> *outs,
+                           Functor func) {
+  auto numel = ins[0]->numel();
+  paddle::framework::Array<const _ptr_ InT *__restrict__, Arity> ins_data;
+  paddle::framework::Array<_ptr_ OutT *, NumOuts> outs_data;
+
+  for (int i = 0; i < Arity; ++i) {
+    ins_data[i] = ins[i]->data<InT>();
+  }
+  for (int i = 0; i < NumOuts; ++i) {
+    outs_data[i] = (*outs)[i]->mutable_data<OutT>();
+  }
+#ifdef PADDLE_WITH_XPU2
+  int block_size = 64;
+  int grid_size = 8;
+  auto stream = ctx.x_context()->xpu_stream;
+  int main_offset = (numel / (VecSize * block_size)) * VecSize * block_size;
+  VectorizedElementwiseKernel<InT, OutT, Functor, Arity, NumOuts,
+                              VecSize><<<grid_size, block_size, 0, stream>>>(
+      ins_data, outs_data, numel, main_offset, func);
+#else
+  auto gpu_config = GetGpuLaunchConfig1D(ctx, numel, VecSize);
+  int main_offset = (numel / (VecSize * gpu_config.GetBlockSize())) * VecSize *
+                    gpu_config.GetBlockSize();
+  auto stream = ctx.stream();
+  VectorizedElementwiseKernel<InT, OutT, Functor, Arity, NumOuts,
+                              VecSize><<<gpu_config.block_per_grid,
+                                         gpu_config.thread_per_block,
+                                         0,
+                                         stream>>>(
+      ins_data, outs_data, numel, main_offset, func);
+#endif
+}
+
+template <ElementwiseType ET, typename InT, typename OutT, typename Functor,
+          int NumOuts = 1>
+void LaunchSameDimsElementwiseCudaKernel(
+    const KPDevice &ctx,
+    const std::vector<const DenseTensor *> &ins,
+    std::vector<DenseTensor *> *outs,
+    Functor func) {
+  using Traits = paddle::platform::FunctionTraits<Functor>;
+  const int kArity =
+      Traits::has_pointer_args ? static_cast<int>(ET) : Traits::arity;
+  PADDLE_ENFORCE_EQ(ins.size(),
+                    kArity,
+                    paddle::platform::errors::InvalidArgument(
+                        "The number of inputs is expected to be equal to the "
+                        "arity of functor. But recieved: the number of inputs "
+                        "is %d, the arity of functor is %d.",
+                        ins.size(),
+                        kArity));
+  PADDLE_ENFORCE_EQ(outs->size(),
+                    NumOuts,
+                    paddle::platform::errors::InvalidArgument(
+                        "Number of outputs shall equal to number of functions, "
+                        "but number of outputs is %d, of functions is %d.",
+                        outs->size(),
+                        NumOuts));
+
+  if (NumOuts > 1) {
+    for (int i = 1; i < NumOuts; ++i) {
+      PADDLE_ENFORCE_EQ(
+          (*outs)[i]->dims(),
+          (*outs)[0]->dims(),
+          paddle::platform::errors::InvalidArgument(
+              "The shape of each output tensor shall be identical yet, "
+              "but %dth output tensor`s shape is not.",
+              i));
+    }
+  }
+
+  // calculate the max vec_size for all ins and outs
+  int vec_size = GetVectorizedSizeForTensors<InT, OutT>(ins, *outs);
+  switch (vec_size) {
+    case 4:
+      ElementwiseCudaKernel<InT, OutT, Functor, kArity, NumOuts, 4>(
+          ctx, ins, outs, func);
+      break;
+    case 2:
+      ElementwiseCudaKernel<InT, OutT, Functor, kArity, NumOuts, 2>(
+          ctx, ins, outs, func);
+      break;
+    case 1:
+      ElementwiseCudaKernel<InT, OutT, Functor, kArity, NumOuts, 1>(
+          ctx, ins, outs, func);
+      break;
+    default: {
+      PADDLE_THROW(paddle::platform::errors::Unimplemented(
+          "Unsupported vectorized size: %d !", vec_size));
+      break;
+    }
+  }
+}
+
+#endif
+
}  // namespace funcs
}  // namespace pten
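The launch path added above picks a vector width per call: GetVectorizedSizeForTensors takes the minimum vectorization size over every input and output pointer, and LaunchSameDimsElementwiseCudaKernel then dispatches to a VecSize of 4, 2 or 1. A rough host-side sketch of that alignment-driven selection follows; it is an assumption-level model (the exact rules of paddle::platform::GetVectorizedSize live in aligned_vector.h), not the real implementation:

#include <algorithm>
#include <cstdint>
#include <cstdio>
#include <vector>

// Simplified model of the vec_size choice: the widest of 4/2/1 elements whose
// byte span the pointer is aligned to.  All tensors must agree, so the minimum
// over every input and output is taken, mirroring GetVectorizedSizeForTensors.
template <typename T>
int VectorizedSize(const T* ptr) {
  std::uintptr_t address = reinterpret_cast<std::uintptr_t>(ptr);
  if (address % (sizeof(T) * 4) == 0) return 4;
  if (address % (sizeof(T) * 2) == 0) return 2;
  return 1;
}

template <typename T>
int VectorizedSizeForTensors(const std::vector<const T*>& ins,
                             const std::vector<T*>& outs) {
  int vec_size = 4;
  for (const T* p : ins) vec_size = std::min(vec_size, VectorizedSize(p));
  for (T* p : outs) vec_size = std::min(vec_size, VectorizedSize(p));
  return vec_size;
}

int main() {
  alignas(16) float in[8] = {};
  alignas(16) float out[8] = {};
  // Aligned buffers vectorize by 4; offsetting one input by a single element
  // forces the whole launch down to the scalar path.
  std::printf("%d %d\n",
              VectorizedSizeForTensors<float>({in}, {out}),      // 4
              VectorizedSizeForTensors<float>({in + 1}, {out}));  // 1
}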
paddle/pten/kernels/gpu/cast_kernel.cu

@@ -17,9 +17,9 @@
#include "paddle/pten/api/ext/dispatch.h"
#include "paddle/pten/backends/gpu/gpu_context.h"
#include "paddle/pten/core/kernel_registry.h"
+#include "paddle/pten/kernels/funcs/elementwise_base.h"

// See Note [ Why still include the fluid headers? ]
-#include "paddle/fluid/operators/elementwise/elementwise_op_impl.cu.h"
#include "paddle/fluid/platform/aligned_vector.h"
#include "paddle/fluid/platform/bfloat16.h"
#include "paddle/fluid/platform/device/gpu/gpu_helper.h"

@@ -44,7 +44,9 @@ void CastCUDAKernelImpl(const GPUContext& dev_ctx,
  inputs.emplace_back(&x);
  outputs.emplace_back(out);
  out->mutable_data<OutT>();
-  LaunchSameDimsElementwiseCudaKernel<ElementwiseType::kUnary, InT, OutT>(
+  funcs::LaunchSameDimsElementwiseCudaKernel<ElementwiseType::kUnary,
+                                             InT,
+                                             OutT>(
      dev_ctx, inputs, &outputs, CastFuctor<InT, OutT>());
}
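CastCUDAKernelImpl above plugs a CastFuctor<InT, OutT> into the generic same-dims launcher, so the cast op is just a unary elementwise transform. A minimal host-side sketch of that functor-driven pattern (hypothetical names; std::transform stands in for the CUDA launch):

#include <algorithm>
#include <cstdio>
#include <vector>

// Unary cast functor in the spirit of CastFuctor<InT, OutT>: each element is
// converted independently, which is why the generic elementwise machinery can
// run it without any cast-specific kernel.
template <typename InT, typename OutT>
struct CastFunctor {
  OutT operator()(InT v) const { return static_cast<OutT>(v); }
};

int main() {
  std::vector<float> x = {1.5f, 2.25f, -3.75f};
  std::vector<int> out(x.size());
  std::transform(x.begin(), x.end(), out.begin(), CastFunctor<float, int>());
  for (int v : out) std::printf("%d ", v);  // prints: 1 2 -3
  std::printf("\n");
}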
paddle/pten/kernels/gpu/elementwise.h

@@ -14,12 +14,7 @@ limitations under the License. */
#pragma once

-#include "paddle/fluid/operators/kernel_primitives/kernel_primitives.h"
-#include "paddle/fluid/platform/aligned_vector.h"
-#include "paddle/fluid/platform/device/gpu/gpu_launch_config.h"
-#include "paddle/fluid/platform/function_traits.h"
#include "paddle/pten/backends/gpu/gpu_context.h"
#include "paddle/pten/core/dense_tensor.h"
#include "paddle/pten/kernels/funcs/common_shape.h"
#include "paddle/pten/kernels/funcs/cuda_kernel_config.h"
#include "paddle/pten/kernels/funcs/elementwise_base.h"

@@ -39,301 +34,7 @@ constexpr int ELEMWISE_MAX_BLOCK_DIM = 1024;
  } while (0)

namespace pten {

-namespace kps = paddle::operators::kernel_primitives;
-
-enum ElementwiseType { kUnary = 1, kBinary = 2, kTernary = 3, kAny = -1 };
-
-/* Packing scalar type T(float, int etc.) into Array<T, NumOuts> type
-   for supporting multiple-output feature in elementwise system.*/
-template <class T, int Num>
-using ConditionalT =
-    typename std::conditional_t<Num == 1, T, paddle::framework::Array<T, Num>>;
-
-// FORWARD CODE
-template <typename InT, typename OutT, int VecSize, typename Functor,
-          int Arity, bool CallElementwiseAny = false>
-struct ElementwisePrimitiveCaller {
-  __device__ inline void operator()(Functor func, InT (*args)[VecSize],
-                                    OutT *result);
-};
-
-template <typename InT, typename OutT, int VecSize, typename Functor,
-          int Arity>
-struct ElementwisePrimitiveCaller<InT, OutT, VecSize, Functor, Arity, true> {
-  __device__ inline void operator()(Functor func, InT (*args)[VecSize],
-                                    OutT *result) {
-    kps::ElementwiseAny<InT, OutT, VecSize, 1, 1, Arity, Functor>(result, args,
-                                                                  func);
-  }
-};
-
-template <typename InT, typename OutT, int VecSize, typename Functor>
-struct ElementwisePrimitiveCaller<InT, OutT, VecSize, Functor, 1, false> {
-  __device__ inline void operator()(Functor func, InT (*args)[VecSize],
-                                    OutT *result) {
-    kps::ElementwiseUnary<InT, OutT, VecSize, 1, 1, Functor>(result, args[0],
-                                                             func);
-  }
-};
-
-template <typename InT, typename OutT, int VecSize, typename Functor>
-struct ElementwisePrimitiveCaller<InT, OutT, VecSize, Functor, 2, false> {
-  __device__ inline void operator()(Functor func, InT (*args)[VecSize],
-                                    OutT *result) {
-    kps::ElementwiseBinary<InT, OutT, VecSize, 1, 1, Functor>(result, args[0],
-                                                              args[1], func);
-  }
-};
-
-template <typename InT, typename OutT, int VecSize, typename Functor>
-struct ElementwisePrimitiveCaller<InT, OutT, VecSize, Functor, 3, false> {
-  __device__ inline void operator()(Functor func, InT (*args)[VecSize],
-                                    OutT *result) {
-    kps::ElementwiseTernary<InT, OutT, VecSize, 1, 1, Functor>(
-        result, args[0], args[1], args[2], func);
-  }
-};
-
-template <typename OutT, int VecSize, bool IsBoundary, int NumOuts>
-struct ElementwiseWriteDataCaller {
-  __device__ __forceinline__ void operator()(
-      paddle::framework::Array<_ptr_ OutT *, NumOuts> outs,
-      ConditionalT<OutT, NumOuts> src[VecSize],
-      int block_offset,
-      int num) {
-    OutT dst[NumOuts][VecSize];
-#pragma unroll
-    for (int i = 0; i < VecSize; ++i) {
-#pragma unroll
-      for (int j = 0; j < NumOuts; ++j) {
-        dst[j][i] = (src[i])[j];
-      }
-    }
-#pragma unroll
-    for (int i = 0; i < NumOuts; ++i) {
-      kps::WriteData<OutT, VecSize, 1, 1, IsBoundary>(outs[i] + block_offset,
-                                                      dst[i], num);
-    }
-  }
-};
-
-template <typename OutT, int VecSize, bool IsBoundary>
-struct ElementwiseWriteDataCaller<OutT, VecSize, IsBoundary, 1> {
-  __device__ __forceinline__ void operator()(
-      paddle::framework::Array<_ptr_ OutT *, 1> outs,
-      OutT src[VecSize],
-      int block_offset,
-      int num) {
-    kps::WriteData<OutT, VecSize, 1, 1, IsBoundary>(outs[0] + block_offset,
-                                                    src, num);
-  }
-};
-
-template <typename InT, typename OutT, typename Functor, int Arity,
-          int NumOuts, int VecSize, bool IsBoundary>
-__device__ void VectorizedElementwiseKernelImpl(
-    const paddle::framework::Array<const _ptr_ InT *__restrict__, Arity> &in,
-    paddle::framework::Array<_ptr_ OutT *, NumOuts> outs,
-    int num,
-    int data_offset,
-    Functor func) {
-  InT args[Arity][VecSize];
-  ConditionalT<OutT, NumOuts> result[VecSize];
-
-#pragma unroll
-  for (int i = 0; i < Arity; i++) {
-    kps::Init<InT, VecSize>(args[i], static_cast<InT>(1.0f));
-    kps::ReadData<InT, VecSize, 1, 1, IsBoundary>(args[i], in[i] + data_offset,
-                                                  num);
-  }
-
-  constexpr bool kCallElementwiseAny =
-      paddle::platform::FunctionTraits<Functor>::has_pointer_args;
-  ElementwisePrimitiveCaller<InT,
-                             ConditionalT<OutT, NumOuts>,
-                             VecSize,
-                             Functor,
-                             Arity,
-                             kCallElementwiseAny>()(func, args, result);
-
-  ElementwiseWriteDataCaller<OutT, VecSize, IsBoundary, NumOuts>()(
-      outs, result, data_offset, num);
-}
-
-template <typename InT, typename OutT, typename Functor, int Arity,
-          int NumOuts, int VecSize>
-__global__ void VectorizedElementwiseKernel(
-    paddle::framework::Array<const _ptr_ InT *__restrict__, Arity> ins,
-    paddle::framework::Array<_ptr_ OutT *, NumOuts> outs,
-    int size,
-    int main_offset,
-    Functor func) {
-  int data_offset = BLOCK_ID_X * BLOCK_NUM_X * VecSize;
-  int stride = BLOCK_NUM_X * GRID_NUM_X * VecSize;
-  for (; data_offset < main_offset; data_offset += stride) {
-    VectorizedElementwiseKernelImpl<InT, OutT, Functor, Arity, NumOuts,
-                                    VecSize, false>(
-        ins, outs, VecSize * BLOCK_NUM_X, data_offset, func);
-  }
-
-  int num = size - data_offset;
-  if (num > 0) {
-    VectorizedElementwiseKernelImpl<InT, OutT, Functor, Arity, NumOuts,
-                                    VecSize, true>(ins, outs, num, data_offset,
-                                                   func);
-  }
-}
-
-template <typename InT, typename OutT>
-int GetVectorizedSizeForTensors(const std::vector<const DenseTensor *> &ins,
-                                const std::vector<DenseTensor *> &outs) {
-  int vec_size = 4;
-  for (auto iter = ins.begin(); iter != ins.end(); ++iter) {
-    vec_size = std::min<int>(
-        vec_size, paddle::platform::GetVectorizedSize((*iter)->data<InT>()));
-  }
-  for (auto iter = outs.begin(); iter != outs.end(); ++iter) {
-    vec_size = std::min<int>(
-        vec_size, paddle::platform::GetVectorizedSize((*iter)->data<OutT>()));
-  }
-  return vec_size;
-}
-
-template <typename InT, typename OutT, typename Functor, int Arity,
-          int NumOuts, int VecSize>
-void ElementwiseCudaKernel(const KPDevice &ctx,
-                           const std::vector<const DenseTensor *> &ins,
-                           std::vector<DenseTensor *> *outs,
-                           Functor func) {
-  auto numel = ins[0]->numel();
-  paddle::framework::Array<const _ptr_ InT *__restrict__, Arity> ins_data;
-  paddle::framework::Array<_ptr_ OutT *, NumOuts> outs_data;
-
-  for (int i = 0; i < Arity; ++i) {
-    ins_data[i] = ins[i]->data<InT>();
-  }
-  for (int i = 0; i < NumOuts; ++i) {
-    outs_data[i] = (*outs)[i]->mutable_data<OutT>();
-  }
-#ifdef PADDLE_WITH_XPU2
-  int block_size = 64;
-  int grid_size = 8;
-  auto stream = ctx.x_context()->xpu_stream;
-  int main_offset = (numel / (VecSize * block_size)) * VecSize * block_size;
-  VectorizedElementwiseKernel<InT, OutT, Functor, Arity, NumOuts,
-                              VecSize><<<grid_size, block_size, 0, stream>>>(
-      ins_data, outs_data, numel, main_offset, func);
-#else
-  auto gpu_config = GetGpuLaunchConfig1D(ctx, numel, VecSize);
-  int main_offset = (numel / (VecSize * gpu_config.GetBlockSize())) * VecSize *
-                    gpu_config.GetBlockSize();
-  auto stream = ctx.stream();
-  VectorizedElementwiseKernel<InT, OutT, Functor, Arity, NumOuts,
-                              VecSize><<<gpu_config.block_per_grid,
-                                         gpu_config.thread_per_block,
-                                         0,
-                                         stream>>>(
-      ins_data, outs_data, numel, main_offset, func);
-#endif
-}
-
-template <ElementwiseType ET, typename InT, typename OutT, typename Functor,
-          int NumOuts = 1>
-void LaunchSameDimsElementwiseCudaKernel(
-    const KPDevice &ctx,
-    const std::vector<const DenseTensor *> &ins,
-    std::vector<DenseTensor *> *outs,
-    Functor func) {
-  using Traits = paddle::platform::FunctionTraits<Functor>;
-  const int kArity =
-      Traits::has_pointer_args ? static_cast<int>(ET) : Traits::arity;
-  PADDLE_ENFORCE_EQ(ins.size(),
-                    kArity,
-                    paddle::platform::errors::InvalidArgument(
-                        "The number of inputs is expected to be equal to the "
-                        "arity of functor. But recieved: the number of inputs "
-                        "is %d, the arity of functor is %d.",
-                        ins.size(),
-                        kArity));
-  PADDLE_ENFORCE_EQ(outs->size(),
-                    NumOuts,
-                    paddle::platform::errors::InvalidArgument(
-                        "Number of outputs shall equal to number of functions, "
-                        "but number of outputs is %d, of functions is %d.",
-                        outs->size(),
-                        NumOuts));
-
-  if (NumOuts > 1) {
-    for (int i = 1; i < NumOuts; ++i) {
-      PADDLE_ENFORCE_EQ(
-          (*outs)[i]->dims(),
-          (*outs)[0]->dims(),
-          paddle::platform::errors::InvalidArgument(
-              "The shape of each output tensor shall be identical yet, "
-              "but %dth output tensor`s shape is not.",
-              i));
-    }
-  }
-
-  // calculate the max vec_size for all ins and outs
-  int vec_size = GetVectorizedSizeForTensors<InT, OutT>(ins, *outs);
-  switch (vec_size) {
-    case 4:
-      ElementwiseCudaKernel<InT, OutT, Functor, kArity, NumOuts, 4>(
-          ctx, ins, outs, func);
-      break;
-    case 2:
-      ElementwiseCudaKernel<InT, OutT, Functor, kArity, NumOuts, 2>(
-          ctx, ins, outs, func);
-      break;
-    case 1:
-      ElementwiseCudaKernel<InT, OutT, Functor, kArity, NumOuts, 1>(
-          ctx, ins, outs, func);
-      break;
-    default: {
-      PADDLE_THROW(paddle::platform::errors::Unimplemented(
-          "Unsupported vectorized size: %d !", vec_size));
-      break;
-    }
-  }
-}
-
struct DimensionsTransform {
  using DimVector = std::vector<int64_t>;
  typedef void (*MergeFunctor)(

@@ -538,14 +239,15 @@ __device__ void ElementwiseBroadcastKernelImpl(
  }
  constexpr bool kCallElementwiseAny =
      paddle::platform::FunctionTraits<Functor>::has_pointer_args;
-  ElementwisePrimitiveCaller<InT,
-                             ConditionalT<OutT, NumOuts>,
-                             VecSize,
-                             Functor,
-                             Arity,
-                             kCallElementwiseAny>()(func, args, result);
-  ElementwiseWriteDataCaller<OutT, VecSize, IsBoundary, NumOuts>()(
+  pten::funcs::ElementwisePrimitiveCaller<InT,
+                                          ConditionalT<OutT, NumOuts>,
+                                          VecSize,
+                                          Functor,
+                                          Arity,
+                                          kCallElementwiseAny>()(
+      func, args, result);
+  pten::funcs::ElementwiseWriteDataCaller<OutT, VecSize, IsBoundary, NumOuts>()(
      outs, result, block_offset, num);
}

@@ -864,8 +566,9 @@ void LaunchElementwiseCudaKernel(const KPDevice &ctx,
    dims_size.emplace_back(in->dims().size());
  }
  if (no_broadcast_flag) {
-    LaunchSameDimsElementwiseCudaKernel<ET, InT, OutT, Functor, NumOuts>(
-        ctx, ins, outs, func);
+    pten::funcs::LaunchSameDimsElementwiseCudaKernel<ET, InT, OutT, Functor,
+                                                     NumOuts>(
+        ctx, ins, outs, func);
  } else {
    axis = axis == -1
               ? *std::max_element(dims_size.begin(), dims_size.end()) -
paddle/pten/kernels/gpu/reduce.h

@@ -45,7 +45,7 @@ namespace cub = hipcub;
#include "paddle/pten/api/ext/dispatch.h"
#include "paddle/pten/backends/gpu/gpu_context.h"
#include "paddle/pten/core/dense_tensor.h"
-#include "paddle/pten/kernels/gpu/elementwise.h"
+#include "paddle/pten/kernels/funcs/elementwise_base.h"

// Reduce split or not, Whether to use ReduceHigherDim
#define REDUCE_SPLIT_BOUNDARY 512

@@ -1095,7 +1095,7 @@ void TensorReduceFunctorImpl(const pten::DenseTensor& x,
  if (config.reduce_num == 1) {
    std::vector<const DenseTensor *> inputs = {&x};
    std::vector<DenseTensor *> outputs = {y};
-    pten::LaunchSameDimsElementwiseCudaKernel<ElementwiseType::kUnary, Tx, Ty>(
+    funcs::LaunchSameDimsElementwiseCudaKernel<ElementwiseType::kUnary, Tx, Ty>(
        *dev_ctx, inputs, &outputs, transform);
    return;
  }
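The second hunk above is the degenerate case the reduce kernel forwards to the elementwise path: when config.reduce_num == 1 there is nothing to accumulate, so the whole reduction collapses to a unary transform of each element, now routed through funcs::LaunchSameDimsElementwiseCudaKernel. A small host-side sketch of why that shortcut is valid (simplified to a sum reduction over a trailing axis of extent reduce_num):

#include <cassert>
#include <vector>

// Sum-reduce the trailing axis of a row-major [outer, reduce_num] buffer.
// When reduce_num == 1 every output element is just a transform of one input
// element, which is exactly the unary elementwise launch the reduce kernel
// falls back to.
std::vector<float> ReduceLastAxis(const std::vector<float>& x,
                                  int outer, int reduce_num) {
  std::vector<float> y(outer, 0.0f);
  for (int i = 0; i < outer; ++i)
    for (int j = 0; j < reduce_num; ++j) y[i] += x[i * reduce_num + j];
  return y;
}

int main() {
  std::vector<float> x = {1.0f, 2.0f, 3.0f};
  // reduce_num == 1: the "reduction" returns the input unchanged.
  assert(ReduceLastAxis(x, 3, 1) == x);
  // reduce_num == 3: a real reduction collapses the axis.
  assert(ReduceLastAxis(x, 1, 3) == std::vector<float>{6.0f});
  return 0;
}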
paddle/pten/kernels/gpu/scale_kernel.cu

@@ -16,8 +16,8 @@ limitations under the License. */
#include "paddle/pten/backends/gpu/gpu_context.h"
#include "paddle/pten/core/kernel_registry.h"
+#include "paddle/pten/kernels/funcs/elementwise_base.h"

// See Note [ Why still include the fluid headers? ]
-#include "paddle/fluid/operators/elementwise/elementwise_op_impl.cu.h"
#include "paddle/fluid/platform/float16.h"

namespace pten {

@@ -55,7 +55,7 @@ void ScaleKernel(const Context& dev_ctx,
  inputs.emplace_back(&x);
  outputs.emplace_back(out);
  out->mutable_data<T>();
-  LaunchSameDimsElementwiseCudaKernel<ElementwiseType::kUnary, T, T>(
+  funcs::LaunchSameDimsElementwiseCudaKernel<ElementwiseType::kUnary, T, T>(
      dev_ctx,
      inputs,
      &outputs,