Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
3eaf8d2c
P
Paddle
项目概览
PaddlePaddle
/
Paddle
接近 2 年 前同步成功
通知
2323
Star
20933
Fork
5424
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
3eaf8d2c
编写于
1月 11, 2022
作者:
N
niuliling123
提交者:
GitHub
1月 11, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Modified Kernel Primitive API and elementwise for xpu2 #38688
上级
2bed9b9c
变更
6
隐藏空白更改
内联
并排
Showing
6 changed file
with
164 addition
and
147 deletion
+164
-147
paddle/fluid/operators/elementwise/elementwise_op_broadcast.cu.h
...fluid/operators/elementwise/elementwise_op_broadcast.cu.h
+3
-5
paddle/fluid/operators/elementwise/elementwise_op_impl.cu.h
paddle/fluid/operators/elementwise/elementwise_op_impl.cu.h
+1
-2
paddle/fluid/operators/kernel_primitives/datamover_primitives_xpu2.h
...d/operators/kernel_primitives/datamover_primitives_xpu2.h
+86
-86
paddle/fluid/operators/kernel_primitives/kernel_primitives.h
paddle/fluid/operators/kernel_primitives/kernel_primitives.h
+14
-1
paddle/fluid/platform/hostdevice.h
paddle/fluid/platform/hostdevice.h
+8
-1
paddle/pten/kernels/gpu/elementwise.h
paddle/pten/kernels/gpu/elementwise.h
+52
-52
未找到文件。
paddle/fluid/operators/elementwise/elementwise_op_broadcast.cu.h
浏览文件 @
3eaf8d2c
...
...
@@ -25,8 +25,7 @@ namespace kps = paddle::operators::kernel_primitives;
template
<
ElementwiseType
ET
,
typename
InT
,
typename
OutT
,
typename
Functor
,
int
NumOuts
=
1
>
void
LaunchBroadcastElementwiseCudaKernel
(
const
platform
::
CUDADeviceContext
&
ctx
,
const
std
::
vector
<
const
framework
::
Tensor
*>
&
ins
,
const
KPDevice
&
ctx
,
const
std
::
vector
<
const
framework
::
Tensor
*>
&
ins
,
std
::
vector
<
framework
::
Tensor
*>
*
outs
,
int
axis
,
Functor
func
)
{
std
::
vector
<
const
pten
::
DenseTensor
*>
pt_inputs
;
std
::
vector
<
pten
::
DenseTensor
*>
pt_outputs
;
...
...
@@ -58,8 +57,7 @@ void LaunchBroadcastElementwiseCudaKernel(
template
<
ElementwiseType
ET
,
typename
InT
,
typename
OutT
,
typename
Functor
,
int
NumOuts
=
1
>
void
LaunchElementwiseCudaKernel
(
const
platform
::
CUDADeviceContext
&
cuda_ctx
,
const
std
::
vector
<
const
framework
::
Tensor
*>
&
ins
,
const
KPDevice
&
ctx
,
const
std
::
vector
<
const
framework
::
Tensor
*>
&
ins
,
std
::
vector
<
framework
::
Tensor
*>
*
outs
,
int
axis
,
Functor
func
)
{
std
::
vector
<
const
pten
::
DenseTensor
*>
pt_inputs
;
std
::
vector
<
pten
::
DenseTensor
*>
pt_outputs
;
...
...
@@ -85,7 +83,7 @@ void LaunchElementwiseCudaKernel(
pt_outputs
.
push_back
(
pt_outputs_tmp
[
i
].
get
());
}
pten
::
LaunchElementwiseCudaKernel
<
ET
,
InT
,
OutT
,
Functor
,
NumOuts
>
(
c
uda_c
tx
,
pt_inputs
,
&
pt_outputs
,
axis
,
func
);
ctx
,
pt_inputs
,
&
pt_outputs
,
axis
,
func
);
}
}
// namespace operators
...
...
paddle/fluid/operators/elementwise/elementwise_op_impl.cu.h
浏览文件 @
3eaf8d2c
...
...
@@ -35,8 +35,7 @@ using ElementwiseType = pten::ElementwiseType;
template
<
ElementwiseType
ET
,
typename
InT
,
typename
OutT
,
typename
Functor
,
int
NumOuts
=
1
>
void
LaunchSameDimsElementwiseCudaKernel
(
const
platform
::
CUDADeviceContext
&
ctx
,
const
std
::
vector
<
const
framework
::
Tensor
*>
&
ins
,
const
KPDevice
&
ctx
,
const
std
::
vector
<
const
framework
::
Tensor
*>
&
ins
,
std
::
vector
<
framework
::
Tensor
*>
*
outs
,
Functor
func
)
{
std
::
vector
<
const
pten
::
DenseTensor
*>
pt_inputs
;
std
::
vector
<
pten
::
DenseTensor
*>
pt_outputs
;
...
...
paddle/fluid/operators/kernel_primitives/datamover_primitives_xpu2.h
浏览文件 @
3eaf8d2c
...
...
@@ -32,42 +32,50 @@ struct alignas(sizeof(T) * VecSize) VectorType {
* index of the output data. if input or output shape is [dim0, dim1] then dims
* must be [dim1, dim0].
*/
#pragma pack(4)
template
<
int
kDims
>
struct
BroadcastConfig
{
uint32_t
stride
_in
[
framework
::
DDim
::
kMaxRank
];
uint32_t
stride
_out
[
framework
::
DDim
::
kMaxRank
];
uint32_t
shape_in
[
framework
::
DDim
::
kMaxRank
];
int
strides
_in
[
framework
::
DDim
::
kMaxRank
];
int
strides
_out
[
framework
::
DDim
::
kMaxRank
];
int
in_dim
[
framework
::
DDim
::
kMaxRank
];
HOSTDEVICE
BroadcastConfig
()
{}
HOSTDEVICE
BroadcastConfig
(
const
std
::
vector
<
int64_t
>&
out_dims
,
const
std
::
vector
<
int64_t
>&
in_dims
,
int
dim_size
)
{
std
::
vector
<
uint32_t
>
strides_in
;
std
::
vector
<
uint32_t
>
strides_out
;
std
::
vector
<
uint32_t
>
shapes_in
;
strides_out
.
resize
(
dim_size
,
1
);
strides_in
.
resize
(
dim_size
,
1
);
shapes_in
.
resize
(
dim_size
,
1
);
for
(
int
i
=
0
;
i
<
dim_size
;
++
i
)
{
shape_in
[
i
]
=
in_dims
[
dim_size
-
i
-
1
];
std
::
vector
<
int
>
strides_in_tmp
;
std
::
vector
<
int
>
strides_out_tmp
;
std
::
vector
<
int
>
dim_tmp
;
strides_in_tmp
.
resize
(
dim_size
,
1
);
strides_out_tmp
.
resize
(
dim_size
,
1
);
dim_tmp
.
resize
(
dim_size
,
1
);
for
(
int
i
=
1
;
i
<
dim_size
;
i
++
)
{
strides_in_tmp
[
i
]
=
strides_in_tmp
[
i
-
1
]
*
in_dims
[
i
-
1
];
strides_out_tmp
[
i
]
=
strides_out_tmp
[
i
-
1
]
*
out_dims
[
i
-
1
];
}
for
(
int
i
=
1
;
i
<
dim_size
-
1
;
++
i
)
{
strides_out
[
dim_size
-
i
-
1
]
=
std
::
accumulate
(
out_dims
.
begin
(),
out_dims
.
begin
()
+
i
,
1
,
std
::
multiplies
<
int64_t
>
())
strides_in
[
dim_size
-
i
-
1
]
=
std
::
accumulate
(
in_dims
.
begin
(),
in_dims
.
begin
()
+
i
,
1
,
std
::
multiplies
<
int64_t
>
())
for
(
int
i
=
0
;
i
<
dim_size
;
i
++
)
{
dim_tmp
[
i
]
=
in_dims
[
i
];
}
memcpy
(
stride_in
,
strides_in
.
data
(),
kDims
*
sizeof
(
uint32_t
));
memcpy
(
stride_out
,
strides_out
.
data
(),
kDims
*
sizeof
(
uint32_t
));
memcpy
(
shape_in
,
shapes_in
.
data
(),
kDims
*
sizeof
(
uint32_t
));
memcpy
(
strides_in
,
strides_in_tmp
.
data
(),
kDims
*
sizeof
(
int
));
memcpy
(
strides_out
,
strides_out_tmp
.
data
(),
kDims
*
sizeof
(
int
));
memcpy
(
in_dim
,
dim_tmp
.
data
(),
kDims
*
sizeof
(
int
));
}
__device__
inline
int
operator
()(
int
index_output
)
const
{
int
index_src
=
0
;
#pragma unroll
for
(
int
i
=
kDims
-
1
;
i
>=
0
;
--
i
)
{
int
tmp_index
=
(
index_output
/
strides_out
[
i
]);
index_output
=
index_output
-
tmp_index
*
strides_out
[
i
];
index_src
+=
(
tmp_index
%
in_dim
[
i
])
*
strides_in
[
i
];
}
return
index_src
;
}
};
#pragma pack()
}
// namespace details
...
...
@@ -99,12 +107,12 @@ struct BroadcastConfig {
*/
template
<
typename
Tx
,
typename
Ty
,
int
NX
,
int
NY
,
int
BlockSize
,
bool
IsBoundary
=
false
>
__device__
__
force
inline__
void
ReadData
(
Ty
*
dst
,
const
Tx
_global_ptr_
*
src
,
int
size_nx
,
int
size_ny
,
int
stride_nx
,
int
stride_ny
)
{
__device__
__inline__
void
ReadData
(
Ty
*
dst
,
const
Tx
_global_ptr_
*
src
,
int
size_nx
,
int
size_ny
,
int
stride_nx
,
int
stride_ny
)
{
int
thread_offset
=
core_id
();
int
left_size_nx
=
size_nx
-
thread_offset
;
__local__
T
in_temp
[
1
];
__local__
T
x
in_temp
[
1
];
// Each branch is added for better performance
if
(
NX
==
1
&&
NY
==
1
)
{
// for NX == 1 and NY == 1
if
(
IsBoundary
)
{
...
...
@@ -168,7 +176,7 @@ __device__ __forceinline__ void ReadData(Ty* dst, const Tx _global_ptr_* src,
* init_data: Initial value.
*/
template
<
typename
T
,
int
NX
>
__device__
__
force
inline__
void
Init
(
T
*
dst
,
T
init_data
)
{
__device__
__inline__
void
Init
(
T
*
dst
,
T
init_data
)
{
#pragma unroll
for
(
int
i
=
0
;
i
<
NX
;
i
++
)
{
dst
[
i
]
=
init_data
;
...
...
@@ -197,8 +205,8 @@ __device__ __forceinline__ void Init(T* dst, T init_data) {
* size: The current block needs to load size data continuously.
*/
template
<
typename
T
,
int
NX
,
int
NY
,
int
BlockSize
,
bool
IsBoundary
=
false
>
__device__
__
force
inline__
void
ReadData
(
T
*
dst
,
const
T
_global_ptr_
*
src
,
int
num
)
{
__device__
__inline__
void
ReadData
(
T
*
dst
,
const
T
_global_ptr_
*
src
,
int
num
)
{
int
thread_offset
=
core_id
()
*
NX
;
__local__
T
in_temp
[
1
];
if
(
IsBoundary
)
{
// core_num() * NX > num
...
...
@@ -241,10 +249,11 @@ __device__ __forceinline__ void ReadData(T* dst, const T _global_ptr_* src,
*/
template
<
typename
T
,
int
NX
,
int
NY
,
int
BlockSize
,
int
Rank
,
bool
IsBoundary
=
false
>
__device__
__forceinline__
void
ReadDataBc
(
T
*
dst
,
const
T
_global_ptr_
*
src
,
uint32_t
block_offset
,
details
::
BroadcastConfig
<
Rank
>
config
,
int
total_num_output
,
int
stride_nx
,
int
stride_ny
)
{
__device__
__inline__
void
ReadDataBc
(
T
*
dst
,
const
T
_global_ptr_
*
src
,
uint32_t
block_offset
,
details
::
BroadcastConfig
<
Rank
>
config
,
int
total_num_output
,
int
stride_nx
,
int
stride_ny
)
{
uint32_t
thread_offset
=
block_offset
+
core_id
();
uint32_t
index_src
=
0
;
__local__
T
in_temp
[
1
];
...
...
@@ -256,16 +265,11 @@ __device__ __forceinline__ void ReadDataBc(
uint32_t
index_output
=
thread_offset
+
ny
*
stride_ny
+
nx
*
stride_nx
;
index_src
=
0
;
if
(
IsBoundary
)
{
if
(
index_output
>=
total_num_output
)
{
if
(
index_output
>=
(
uint32_t
)
total_num_output
)
{
break
;
}
}
#pragma unroll
for
(
int
i
=
0
;
i
<
Rank
;
++
i
)
{
uint32_t
tmp
=
index_output
/
config
.
stride_out
[
i
];
index_output
=
index_output
-
tmp
*
config
.
stride_out
[
i
];
index_src
+=
(
tmp
%
config
.
shape_in
[
i
])
*
config
.
stride_in
[
i
];
}
index_src
=
config
(
index_output
);
GM2LM
(
src
+
index_src
,
in_temp
,
sizeof
(
T
));
dst
[
nx
+
ny
*
NX
]
=
in_temp
[
0
];
}
...
...
@@ -305,33 +309,34 @@ __device__ __forceinline__ void ReadDataBc(
*/
template
<
typename
T
,
int
NX
,
int
NY
,
int
BlockSize
,
int
Rank
,
typename
IndexCal
,
bool
IsBoundary
=
false
>
__device__
__forceinline__
void
ReadDataReduce
(
T
*
dst
,
const
T
_global_ptr_
*
src
,
int
block_offset
,
const
IndexCal
&
index_cal
,
int
size_nx
,
int
size_ny
,
int
stride_nx
,
int
stride_ny
,
bool
reduce_last_dim
)
{
__local__
T
in_temp
[
1
];
__device__
__inline__
void
ReadDataReduce
(
T
*
dst
,
const
T
_global_ptr_
*
src
,
int
block_offset
,
const
IndexCal
&
index_cal
,
int
size_nx
,
int
size_ny
,
int
stride_nx
,
int
stride_ny
,
bool
reduce_last_dim
)
{
__local__
Tx
in_temp
[
1
];
int
thread_offset
=
0
;
int
left_size_nx
=
size_nx
;
int
left_size_ny
=
size_ny
;
int
left_idx
=
0
;
if
(
reduce_last_dim
)
{
thread_offset
=
block_offset
+
core_id
();
left_
size_nx
-=
thread_offset
;
thread_offset
=
core_id
();
left_
idx
=
0
;
}
else
{
thread_offset
=
block_offset
+
core_id
()
;
left_
size_ny
-=
thread_offset
;
thread_offset
=
0
;
left_
idx
=
0
;
}
if
(
NX
==
1
)
{
#pragma unroll
for
(
int
ny
=
0
;
ny
<
NY
;
++
ny
)
{
if
(
IsBoundary
)
{
if
(
ny
*
stride_ny
>=
left_
size_ny
)
{
if
(
thread_offset
>=
size_ny
)
{
break
;
}
}
uint32_t
index_src
=
index_cal
(
thread_offset
);
GM2LM
(
src
+
index_src
,
in_temp
,
sizeof
(
T
));
dst
[
ny
]
=
in_temp
[
0
]
;
uint32_t
index_src
=
index_cal
(
thread_offset
+
block_offset
);
GM2LM
(
src
+
index_src
,
in_temp
,
sizeof
(
T
x
));
dst
[
ny
]
=
static_cast
<
Ty
>
(
func
(
in_temp
[
0
]))
;
thread_offset
+=
stride_ny
;
}
}
else
{
...
...
@@ -340,17 +345,16 @@ __device__ __forceinline__ void ReadDataReduce(
#pragma unroll
for
(
int
ny
=
0
;
ny
<
NY
;
++
ny
)
{
if
(
IsBoundary
)
{
if
((
ny
*
stride_ny
>=
left_
size_ny
)
||
(
nx
*
stride_nx
>=
left_
size_nx
))
{
if
((
thread_offset
>=
size_ny
)
||
(
left_idx
+
nx
*
stride_nx
>=
size_nx
))
{
break
;
}
}
uint32_t
index_src
=
index_cal
(
thread_offset
);
GM2LM
(
src
+
index_src
,
in_temp
,
sizeof
(
T
));
dst
[
nx
+
ny
*
NX
]
=
in_temp
[
0
]
;
uint32_t
index_src
=
index_cal
(
thread_offset
+
block_offset
);
GM2LM
(
src
+
index_src
,
in_temp
,
sizeof
(
T
x
));
dst
[
nx
+
ny
*
NX
]
=
static_cast
<
Ty
>
(
func
(
in_temp
[
0
]))
;
thread_offset
+=
stride_ny
;
}
thread_offset
+=
stride_nx
;
}
}
}
...
...
@@ -421,9 +425,9 @@ __device__ void WriteData(T _global_ptr_* dst, const T* src, int num) {
*/
template
<
typename
Tx
,
typename
Ty
,
int
NX
,
int
NY
,
int
BlockSize
,
bool
IsBoundary
=
false
>
__device__
__
force
inline__
void
WriteData
(
Ty
_global_ptr_
*
dst
,
const
Tx
*
src
,
int
size_nx
,
int
size_ny
,
int
stride_nx
,
int
stride_ny
)
{
__device__
__inline__
void
WriteData
(
Ty
_global_ptr_
*
dst
,
const
Tx
*
src
,
int
size_nx
,
int
size_ny
,
int
stride_nx
,
int
stride_ny
)
{
int
thread_offset
=
core_id
();
int
left_size_nx
=
size_nx
-
thread_offset
;
__local__
Ty
in_temp
[
1
];
...
...
@@ -433,11 +437,11 @@ __device__ __forceinline__ void WriteData(Ty _global_ptr_* dst, const Tx* src,
if
(
IsBoundary
)
{
if
(
left_size_nx
>
0
)
{
in_temp
[
0
]
=
static_cast
<
Ty
>
(
src
[
0
]);
LM2GM
(
in_temp
,
dst
+
thread_offset
,
sizeof
(
T
));
LM2GM
(
in_temp
,
dst
+
thread_offset
,
sizeof
(
T
y
));
}
}
else
{
in_temp
[
0
]
=
static_cast
<
Ty
>
(
src
[
0
]);
LM2GM
(
in_temp
,
dst
+
thread_offset
,
sizeof
(
T
));
LM2GM
(
in_temp
,
dst
+
thread_offset
,
sizeof
(
T
y
));
}
}
else
if
(
NX
==
1
)
{
#pragma unroll
...
...
@@ -449,7 +453,7 @@ __device__ __forceinline__ void WriteData(Ty _global_ptr_* dst, const Tx* src,
}
in_temp
[
0
]
=
static_cast
<
Ty
>
(
src
[
idy
]);
LM2GM
(
in_temp
,
dst
+
thread_offset
+
idy
*
stride_ny
,
sizeof
(
T
));
LM2GM
(
in_temp
,
dst
+
thread_offset
+
idy
*
stride_ny
,
sizeof
(
T
y
));
}
}
else
if
(
NY
==
1
)
{
// for NY == 1 and NX != 1
#pragma unroll
...
...
@@ -461,7 +465,7 @@ __device__ __forceinline__ void WriteData(Ty _global_ptr_* dst, const Tx* src,
}
in_temp
[
0
]
=
static_cast
<
Ty
>
(
src
[
idx
]);
LM2GM
(
in_temp
,
dst
+
thread_offset
+
idx
*
stride_nx
,
sizeof
(
T
));
LM2GM
(
in_temp
,
dst
+
thread_offset
+
idx
*
stride_nx
,
sizeof
(
T
y
));
}
}
else
{
// for NX != 1 and NY != 1
#pragma unroll
...
...
@@ -480,7 +484,7 @@ __device__ __forceinline__ void WriteData(Ty _global_ptr_* dst, const Tx* src,
}
in_temp
[
0
]
=
static_cast
<
Ty
>
(
src
[
idx
+
idy
*
NX
]);
LM2GM
(
in_temp
,
dst
+
thread_offset
+
idx
*
stride_nx
+
idy
*
stride_ny
,
sizeof
(
T
));
sizeof
(
T
y
));
}
}
}
...
...
@@ -498,7 +502,7 @@ __device__ __forceinline__ void WriteData(Ty _global_ptr_* dst, const Tx* src,
* init_data: The register pointer of init data, the size is NX.
*/
template
<
typename
T
,
int
NX
,
bool
IsBoundary
=
false
>
__device__
__
force
inline__
void
Init
(
T
*
dst
,
T
*
init_data
,
int
num
)
{
__device__
__inline__
void
Init
(
T
*
dst
,
T
*
init_data
,
int
num
)
{
#pragma unroll
for
(
int
i
=
0
;
i
<
NX
;
i
++
)
{
if
(
IsBoundary
)
{
...
...
@@ -535,30 +539,26 @@ __device__ __forceinline__ void Init(T* dst, T* init_data, int num) {
*/
template
<
typename
T
,
int
NX
,
int
NY
,
int
BlockSize
,
int
Rank
,
bool
IsBoundary
=
false
>
__device__
__
forceinline__
void
ReadDataBc
(
T
*
dst
,
const
T
_global_ptr_
*
src
,
uint32_t
block_offset
,
details
::
BroadcastConfig
<
Rank
>
config
,
int
total_num_output
)
{
uint32_t
thread_offset
=
block_offset
+
core_id
()
*
NX
;
uint32_t
index_src
=
0
;
__local__
T
in_temp
[
1
]
;
__device__
__
inline__
void
ReadDataBc
(
T
*
dst
,
const
T
_global_ptr_
*
src
,
uint32_t
block_offset
,
details
::
BroadcastConfig
<
Rank
>
config
,
int
total_num_output
)
{
int
thread_offset
=
block_offset
+
core_id
()
*
NX
;
int
index_src
=
0
;
__local__
T
in_temp
;
#pragma unroll
for
(
uint32_
t
nx
=
0
;
nx
<
NX
;
++
nx
)
{
uint32_
t
index_output
=
thread_offset
+
nx
;
for
(
in
t
nx
=
0
;
nx
<
NX
;
++
nx
)
{
in
t
index_output
=
thread_offset
+
nx
;
index_src
=
0
;
if
(
IsBoundary
)
{
if
(
index_output
>=
total_num_output
)
{
break
;
}
}
#pragma unroll
for
(
int
i
=
0
;
i
<
Rank
;
++
i
)
{
uint32_t
tmp
=
index_output
/
config
.
stride_out
[
i
];
index_output
=
index_output
-
tmp
*
config
.
stride_out
[
i
];
index_src
+=
(
tmp
%
config
.
shape_in
[
i
])
*
config
.
stride_in
[
i
];
}
GM2LM
(
src
+
index_src
,
in_temp
,
sizeof
(
T
));
dst
[
nx
+
ny
*
NX
]
=
in_temp
[
0
];
index_src
=
config
(
index_output
);
GM2LM
(
src
+
index_src
,
&
in_temp
,
sizeof
(
T
));
dst
[
nx
]
=
in_temp
;
}
}
...
...
paddle/fluid/operators/kernel_primitives/kernel_primitives.h
浏览文件 @
3eaf8d2c
...
...
@@ -13,11 +13,18 @@
// limitations under the License.
#pragma once
#include "paddle/fluid/operators/kernel_primitives/functor_primitives.h"
#include "paddle/fluid/operators/kernel_primitives/helper_primitives.h"
#ifdef PADDLE_WITH_XPU2
#include "paddle/fluid/operators/kernel_primitives/compute_primitives_xpu2.h"
#include "paddle/fluid/operators/kernel_primitives/datamover_primitives_xpu2.h"
#include "paddle/fluid/operators/kernel_primitives/functor_primitives_xpu2.h"
#define KPStream XPUStream
#define KPDevice paddle::platform::XPUDeviceContext
#define _ptr_ _global_ptr_
#define __forceinline__ __inline__
#define __restrict__
#define THREAD_ID_X core_id()
#define THREAD_ID_Y 0
#define THREAD_ID_Z 0
...
...
@@ -36,6 +43,12 @@
#else
#include "paddle/fluid/operators/kernel_primitives/compute_primitives.h"
#include "paddle/fluid/operators/kernel_primitives/datamover_primitives.h"
#include "paddle/fluid/operators/kernel_primitives/functor_primitives.h"
#define KPStream gpuStream_t
#define KPDevice paddle::platform::CUDADeviceContext
#define _ptr_
#define THREAD_ID_X threadIdx.x
#define THREAD_ID_Y threadIdx.y
#define THREAD_ID_Z threadIdx.z
...
...
paddle/fluid/platform/hostdevice.h
浏览文件 @
3eaf8d2c
...
...
@@ -17,7 +17,14 @@
#include <hip/hip_runtime.h>
#endif
#if (defined(__CUDACC__) || defined(__HIPCC__))
#ifdef __xpu_kp__
#include <xpu/runtime.h>
#include "xpu/kernel/cluster_header.h"
#include "xpu/kernel/debug.h"
#include "xpu/kernel/math.h"
#endif
#if (defined(__CUDACC__) || defined(__HIPCC__) || defined(__xpu_kp__))
#define HOSTDEVICE __host__ __device__
#define DEVICE __device__
#define HOST __host__
...
...
paddle/pten/kernels/gpu/elementwise.h
浏览文件 @
3eaf8d2c
...
...
@@ -86,7 +86,7 @@ struct ElementwisePrimitiveCaller<InT, OutT, VecSize, Functor, 3, false> {
template
<
typename
OutT
,
int
VecSize
,
bool
IsBoundary
,
int
NumOuts
>
struct
ElementwiseWriteDataCaller
{
__device__
__forceinline__
void
operator
()(
paddle
::
framework
::
Array
<
OutT
*
,
NumOuts
>
outs
,
paddle
::
framework
::
Array
<
_ptr_
OutT
*
,
NumOuts
>
outs
,
ConditionalT
<
OutT
,
NumOuts
>
src
[
VecSize
],
int
block_offset
,
int
num
)
{
...
...
@@ -109,7 +109,7 @@ struct ElementwiseWriteDataCaller {
template
<
typename
OutT
,
int
VecSize
,
bool
IsBoundary
>
struct
ElementwiseWriteDataCaller
<
OutT
,
VecSize
,
IsBoundary
,
1
>
{
__device__
__forceinline__
void
operator
()(
paddle
::
framework
::
Array
<
OutT
*
,
1
>
outs
,
paddle
::
framework
::
Array
<
_ptr_
OutT
*
,
1
>
outs
,
OutT
src
[
VecSize
],
int
block_offset
,
int
num
)
{
...
...
@@ -126,8 +126,8 @@ template <typename InT,
int
VecSize
,
bool
IsBoundary
>
__device__
void
VectorizedElementwiseKernelImpl
(
const
paddle
::
framework
::
Array
<
const
InT
*
__restrict__
,
Arity
>
&
in
,
paddle
::
framework
::
Array
<
OutT
*
,
NumOuts
>
outs
,
const
paddle
::
framework
::
Array
<
const
_ptr_
InT
*
__restrict__
,
Arity
>
&
in
,
paddle
::
framework
::
Array
<
_ptr_
OutT
*
,
NumOuts
>
outs
,
int
num
,
int
data_offset
,
Functor
func
)
{
...
...
@@ -161,8 +161,8 @@ template <typename InT,
int
NumOuts
,
int
VecSize
>
__global__
void
VectorizedElementwiseKernel
(
paddle
::
framework
::
Array
<
const
InT
*
__restrict__
,
Arity
>
ins
,
paddle
::
framework
::
Array
<
OutT
*
,
NumOuts
>
outs
,
paddle
::
framework
::
Array
<
const
_ptr_
InT
*
__restrict__
,
Arity
>
ins
,
paddle
::
framework
::
Array
<
_ptr_
OutT
*
,
NumOuts
>
outs
,
int
size
,
int
main_offset
,
Functor
func
)
{
...
...
@@ -212,17 +212,13 @@ template <typename InT,
int
Arity
,
int
NumOuts
,
int
VecSize
>
void
ElementwiseCudaKernel
(
const
paddle
::
platform
::
CUDADeviceContext
&
ctx
,
void
ElementwiseCudaKernel
(
const
KPDevice
&
ctx
,
const
std
::
vector
<
const
DenseTensor
*>
&
ins
,
std
::
vector
<
DenseTensor
*>
*
outs
,
Functor
func
)
{
auto
numel
=
ins
[
0
]
->
numel
();
int
block_size
=
funcs
::
GetThreadsConfig
(
ctx
,
numel
,
VecSize
);
int
grid_size
=
((
numel
+
VecSize
-
1
)
/
VecSize
+
block_size
-
1
)
/
block_size
;
auto
stream
=
ctx
.
stream
();
paddle
::
framework
::
Array
<
const
InT
*
__restrict__
,
Arity
>
ins_data
;
paddle
::
framework
::
Array
<
OutT
*
,
NumOuts
>
outs_data
;
paddle
::
framework
::
Array
<
const
_ptr_
InT
*
__restrict__
,
Arity
>
ins_data
;
paddle
::
framework
::
Array
<
_ptr_
OutT
*
,
NumOuts
>
outs_data
;
for
(
int
i
=
0
;
i
<
Arity
;
++
i
)
{
ins_data
[
i
]
=
ins
[
i
]
->
data
<
InT
>
();
...
...
@@ -231,8 +227,9 @@ void ElementwiseCudaKernel(const paddle::platform::CUDADeviceContext &ctx,
outs_data
[
i
]
=
(
*
outs
)[
i
]
->
mutable_data
<
OutT
>
();
}
#ifdef PADDLE_WITH_XPU2
block_size
=
128
;
grid_size
=
8
;
int
block_size
=
64
;
int
grid_size
=
8
;
auto
stream
=
ctx
.
x_context
()
->
xpu_stream
;
int
main_offset
=
(
numel
/
(
VecSize
*
block_size
))
*
VecSize
*
block_size
;
VectorizedElementwiseKernel
<
InT
,
OutT
,
...
...
@@ -242,7 +239,11 @@ void ElementwiseCudaKernel(const paddle::platform::CUDADeviceContext &ctx,
VecSize
><<<
grid_size
,
block_size
,
0
,
stream
>>>
(
ins_data
,
outs_data
,
numel
,
main_offset
,
func
);
#else
int
block_size
=
funcs
::
GetThreadsConfig
(
ctx
,
numel
,
VecSize
);
int
grid_size
=
((
numel
+
VecSize
-
1
)
/
VecSize
+
block_size
-
1
)
/
block_size
;
int
main_offset
=
(
numel
/
(
VecSize
*
block_size
))
*
VecSize
*
block_size
;
auto
stream
=
ctx
.
stream
();
VectorizedElementwiseKernel
<
InT
,
OutT
,
Functor
,
...
...
@@ -259,7 +260,7 @@ template <ElementwiseType ET,
typename
Functor
,
int
NumOuts
=
1
>
void
LaunchSameDimsElementwiseCudaKernel
(
const
paddle
::
platform
::
CUDADeviceContext
&
ctx
,
const
KPDevice
&
ctx
,
const
std
::
vector
<
const
DenseTensor
*>
&
ins
,
std
::
vector
<
DenseTensor
*>
*
outs
,
Functor
func
)
{
...
...
@@ -471,12 +472,12 @@ struct DimensionsTransform {
template
<
typename
T
,
int
VecSize
,
int
Rank
,
bool
IsBoundary
=
false
>
__device__
__forceinline__
void
LoadData
(
T
*
dst
,
const
T
*
__restrict__
src
,
const
_ptr_
T
*
src
,
uint32_t
block_offset
,
const
kps
::
details
::
BroadcastConfig
<
Rank
>
&
config
,
int
numel
,
int
num
,
bool
need_broadcast
)
{
int
need_broadcast
)
{
// numel : whole num of output
// num: how many data will be deal with in this time
if
(
need_broadcast
)
{
...
...
@@ -496,9 +497,9 @@ template <typename InT,
int
Rank
,
bool
IsBoundary
=
false
>
__device__
void
ElementwiseBroadcastKernelImpl
(
const
paddle
::
framework
::
Array
<
const
InT
*
__restrict__
,
Arity
>
&
ins
,
paddle
::
framework
::
Array
<
OutT
*
,
NumOuts
>
outs
,
const
paddle
::
framework
::
Array
<
bool
,
Arity
>
&
use_broadcast
,
const
paddle
::
framework
::
Array
<
const
_ptr_
InT
*
__restrict__
,
Arity
>
&
ins
,
paddle
::
framework
::
Array
<
_ptr_
OutT
*
,
NumOuts
>
outs
,
const
paddle
::
framework
::
Array
<
int
,
Arity
>
&
use_broadcast
,
uint32_t
numel
,
const
paddle
::
framework
::
Array
<
kps
::
details
::
BroadcastConfig
<
Rank
>
,
Arity
>
&
configs
,
...
...
@@ -540,9 +541,9 @@ template <typename InT,
int
VecSize
,
int
Rank
>
__global__
void
ElementwiseBroadcastKernel
(
paddle
::
framework
::
Array
<
const
InT
*
__restrict__
,
Arity
>
ins
,
paddle
::
framework
::
Array
<
OutT
*
,
NumOuts
>
outs
,
paddle
::
framework
::
Array
<
bool
,
Arity
>
use_broadcast
,
paddle
::
framework
::
Array
<
const
_ptr_
InT
*
__restrict__
,
Arity
>
ins
,
paddle
::
framework
::
Array
<
_ptr_
OutT
*
,
NumOuts
>
outs
,
paddle
::
framework
::
Array
<
int
,
Arity
>
use_broadcast
,
uint32_t
numel
,
paddle
::
framework
::
Array
<
kps
::
details
::
BroadcastConfig
<
Rank
>
,
Arity
>
configs
,
...
...
@@ -570,7 +571,8 @@ __global__ void ElementwiseBroadcastKernel(
block_offset
,
func
);
}
if
(
block_offset
<
numel
)
{
int
num
=
numel
-
block_offset
;
if
(
num
>
0
)
{
ElementwiseBroadcastKernelImpl
<
InT
,
OutT
,
Functor
,
...
...
@@ -579,7 +581,7 @@ __global__ void ElementwiseBroadcastKernel(
VecSize
,
Rank
,
true
>
(
ins
,
outs
,
use_broadcast
,
numel
,
configs
,
tail_tid
,
block_offset
,
func
);
ins
,
outs
,
use_broadcast
,
numel
,
configs
,
num
,
block_offset
,
func
);
}
#else
if
(
block_offset
<
main_offset
)
{
...
...
@@ -619,23 +621,16 @@ template <typename InT,
int
NumOuts
,
int
VecSize
,
int
Rank
>
void
LaunchKernel
(
const
paddle
::
platform
::
CUDADeviceContext
&
ctx
,
void
LaunchKernel
(
const
KPDevice
&
ctx
,
const
std
::
vector
<
const
DenseTensor
*>
&
ins
,
std
::
vector
<
DenseTensor
*>
*
outs
,
Functor
func
,
DimensionsTransform
merge_dims
)
{
int
numel
=
(
*
outs
)[
0
]
->
numel
();
const
int
threads
=
256
;
int
blocks
=
((
numel
+
VecSize
-
1
)
/
VecSize
+
threads
-
1
)
/
threads
;
int
main_offset
=
(
numel
/
(
VecSize
*
threads
))
*
VecSize
*
threads
;
int
tail_tid
=
numel
%
(
VecSize
*
threads
);
auto
stream
=
ctx
.
stream
();
paddle
::
framework
::
Array
<
kps
::
details
::
BroadcastConfig
<
Rank
>
,
Arity
>
configs
;
paddle
::
framework
::
Array
<
bool
,
Arity
>
use_broadcast
;
paddle
::
framework
::
Array
<
const
InT
*
__restrict__
,
Arity
>
ins_data
;
paddle
::
framework
::
Array
<
OutT
*
,
NumOuts
>
outs_data
;
paddle
::
framework
::
Array
<
int
,
Arity
>
use_broadcast
;
paddle
::
framework
::
Array
<
const
_ptr_
InT
*
__restrict__
,
Arity
>
ins_data
;
paddle
::
framework
::
Array
<
_ptr_
OutT
*
,
NumOuts
>
outs_data
;
for
(
int
i
=
0
;
i
<
NumOuts
;
++
i
)
{
outs_data
[
i
]
=
(
*
outs
)[
i
]
->
mutable_data
<
OutT
>
();
...
...
@@ -643,7 +638,7 @@ void LaunchKernel(const paddle::platform::CUDADeviceContext &ctx,
for
(
int
i
=
0
;
i
<
Arity
;
i
++
)
{
use_broadcast
[
i
]
=
(
ins
[
i
]
->
numel
()
!=
numel
);
ins_data
[
i
]
=
ins
[
i
]
->
data
<
InT
>
(
);
ins_data
[
i
]
=
(
_ptr_
InT
*
)(
ins
[
i
]
->
data
<
InT
>
()
);
if
(
use_broadcast
[
i
])
{
// get the broadcast config,
// if data shape is[m, n], then you should set data_dim = {n, m}
...
...
@@ -654,10 +649,11 @@ void LaunchKernel(const paddle::platform::CUDADeviceContext &ctx,
}
#ifdef PADDLE_WITH_XPU2
threads
=
128
;
blocks
=
8
;
main_offset
=
(
numel
/
(
VecSize
*
threads
))
*
VecSize
*
threads
;
tail_tid
=
numel
%
(
VecSize
*
threads
);
const
int
threads
=
64
;
const
int
blocks
=
8
;
int
main_offset
=
(
numel
/
(
VecSize
*
threads
))
*
VecSize
*
threads
;
int
tail_tid
=
numel
%
(
VecSize
*
threads
);
auto
stream
=
ctx
.
x_context
()
->
xpu_stream
;
ElementwiseBroadcastKernel
<
InT
,
OutT
,
Functor
,
...
...
@@ -673,6 +669,11 @@ void LaunchKernel(const paddle::platform::CUDADeviceContext &ctx,
tail_tid
,
func
);
#else
const
int
threads
=
256
;
int
blocks
=
((
numel
+
VecSize
-
1
)
/
VecSize
+
threads
-
1
)
/
threads
;
int
main_offset
=
(
numel
/
(
VecSize
*
threads
))
*
VecSize
*
threads
;
int
tail_tid
=
numel
%
(
VecSize
*
threads
);
auto
stream
=
ctx
.
stream
();
ElementwiseBroadcastKernel
<
InT
,
OutT
,
Functor
,
...
...
@@ -698,7 +699,7 @@ template <typename InT,
int
NumOuts
,
int
VecSize
>
void
LaunchBroadcastKernelForDifferentVecSize
(
const
paddle
::
platform
::
CUDADeviceContext
&
ctx
,
const
KPDevice
&
ctx
,
const
std
::
vector
<
const
DenseTensor
*>
&
ins
,
std
::
vector
<
DenseTensor
*>
*
outs
,
int
axis
,
...
...
@@ -737,7 +738,7 @@ template <ElementwiseType ET,
typename
Functor
,
int
NumOuts
=
1
>
void
LaunchBroadcastElementwiseCudaKernel
(
const
paddle
::
platform
::
CUDADeviceContext
&
ctx
,
const
KPDevice
&
ctx
,
const
std
::
vector
<
const
DenseTensor
*>
&
ins
,
std
::
vector
<
DenseTensor
*>
*
outs
,
int
axis
,
...
...
@@ -835,12 +836,11 @@ template <ElementwiseType ET,
typename
OutT
,
typename
Functor
,
int
NumOuts
=
1
>
void
LaunchElementwiseCudaKernel
(
const
paddle
::
platform
::
CUDADeviceContext
&
cuda_ctx
,
const
std
::
vector
<
const
DenseTensor
*>
&
ins
,
std
::
vector
<
DenseTensor
*>
*
outs
,
int
axis
,
Functor
func
)
{
void
LaunchElementwiseCudaKernel
(
const
KPDevice
&
ctx
,
const
std
::
vector
<
const
DenseTensor
*>
&
ins
,
std
::
vector
<
DenseTensor
*>
*
outs
,
int
axis
,
Functor
func
)
{
std
::
vector
<
int
>
dims_size
;
bool
no_broadcast_flag
=
true
;
for
(
auto
*
in
:
ins
)
{
...
...
@@ -849,14 +849,14 @@ void LaunchElementwiseCudaKernel(
}
if
(
no_broadcast_flag
)
{
LaunchSameDimsElementwiseCudaKernel
<
ET
,
InT
,
OutT
,
Functor
,
NumOuts
>
(
c
uda_c
tx
,
ins
,
outs
,
func
);
ctx
,
ins
,
outs
,
func
);
}
else
{
axis
=
axis
==
-
1
?
*
std
::
max_element
(
dims_size
.
begin
(),
dims_size
.
end
())
-
*
std
::
min_element
(
dims_size
.
begin
(),
dims_size
.
end
())
:
axis
;
LaunchBroadcastElementwiseCudaKernel
<
ET
,
InT
,
OutT
,
Functor
,
NumOuts
>
(
c
uda_c
tx
,
ins
,
outs
,
axis
,
func
);
ctx
,
ins
,
outs
,
axis
,
func
);
}
}
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录