Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
8501fb00
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
8501fb00
编写于
5月 16, 2022
作者:
N
niuliling123
提交者:
GitHub
5月 16, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
delete rank switch in broadcast_function.h for compile (#42645)
上级
8ffebb5a
变更
5
展开全部
隐藏空白更改
内联
并排
Showing
5 changed file
with
192 addition
and
385 deletion
+192
-385
paddle/fluid/operators/fused/attn_bias_add.cu.h
paddle/fluid/operators/fused/attn_bias_add.cu.h
+5
-6
paddle/phi/kernels/funcs/broadcast_function.h
paddle/phi/kernels/funcs/broadcast_function.h
+135
-219
paddle/phi/kernels/primitive/datamover_primitives.h
paddle/phi/kernels/primitive/datamover_primitives.h
+13
-54
paddle/phi/kernels/primitive/datamover_primitives_xpu2.h
paddle/phi/kernels/primitive/datamover_primitives_xpu2.h
+33
-105
paddle/phi/kernels/primitive/kernel_primitives.h
paddle/phi/kernels/primitive/kernel_primitives.h
+6
-1
未找到文件。
paddle/fluid/operators/fused/attn_bias_add.cu.h
浏览文件 @
8501fb00
...
@@ -51,8 +51,7 @@ template <typename InT, typename OutT, int ShapeSize, int VecSize,
...
@@ -51,8 +51,7 @@ template <typename InT, typename OutT, int ShapeSize, int VecSize,
__global__
void
BroadcastKernelBinary
(
__global__
void
BroadcastKernelBinary
(
const
InT
*
__restrict__
in0
,
const
InT
*
__restrict__
in1
,
OutT
*
out
,
const
InT
*
__restrict__
in0
,
const
InT
*
__restrict__
in1
,
OutT
*
out
,
phi
::
Array
<
bool
,
MAX_INPUT_NUM
>
use_broadcast
,
uint32_t
numel
,
phi
::
Array
<
bool
,
MAX_INPUT_NUM
>
use_broadcast
,
uint32_t
numel
,
phi
::
Array
<
kps
::
details
::
BroadcastConfig
<
ShapeSize
>
,
MAX_INPUT_NUM
>
phi
::
Array
<
kps
::
details
::
BroadcastConfig
,
MAX_INPUT_NUM
>
configlists
,
configlists
,
int
main_tid
,
int
tail_tid
,
Functor
func
)
{
int
main_tid
,
int
tail_tid
,
Functor
func
)
{
int
fix
=
blockIdx
.
x
*
blockDim
.
x
*
VecSize
;
int
fix
=
blockIdx
.
x
*
blockDim
.
x
*
VecSize
;
int
num
=
tail_tid
;
int
num
=
tail_tid
;
...
@@ -65,14 +64,14 @@ __global__ void BroadcastKernelBinary(
...
@@ -65,14 +64,14 @@ __global__ void BroadcastKernelBinary(
// load in0
// load in0
if
(
use_broadcast
[
0
])
{
if
(
use_broadcast
[
0
])
{
kernel_primitives
::
ReadDataBc
<
InT
,
VecSize
,
DATA_PER_THREAD
,
1
,
ShapeSize
>
(
kernel_primitives
::
ReadDataBc
<
InT
,
VecSize
,
DATA_PER_THREAD
,
1
>
(
arg0
,
in0
,
fix
,
configlists
[
0
],
numel
);
arg0
,
in0
,
fix
,
configlists
[
0
],
numel
);
}
else
{
}
else
{
kernel_primitives
::
ReadData
<
InT
,
VecSize
,
1
,
1
>
(
arg0
,
in0
+
fix
,
num
);
kernel_primitives
::
ReadData
<
InT
,
VecSize
,
1
,
1
>
(
arg0
,
in0
+
fix
,
num
);
}
}
// load in1
// load in1
if
(
use_broadcast
[
1
])
{
if
(
use_broadcast
[
1
])
{
kernel_primitives
::
ReadDataBc
<
InT
,
VecSize
,
DATA_PER_THREAD
,
1
,
ShapeSize
>
(
kernel_primitives
::
ReadDataBc
<
InT
,
VecSize
,
DATA_PER_THREAD
,
1
>
(
arg1
,
in1
,
fix
,
configlists
[
1
],
numel
);
arg1
,
in1
,
fix
,
configlists
[
1
],
numel
);
}
else
{
}
else
{
kernel_primitives
::
ReadData
<
InT
,
VecSize
,
1
,
1
>
(
arg1
,
in1
+
fix
,
num
);
kernel_primitives
::
ReadData
<
InT
,
VecSize
,
1
,
1
>
(
arg1
,
in1
+
fix
,
num
);
...
@@ -104,7 +103,7 @@ void LaunchBiasAddFwKernel(const platform::CUDADeviceContext& ctx, int m, int n,
...
@@ -104,7 +103,7 @@ void LaunchBiasAddFwKernel(const platform::CUDADeviceContext& ctx, int m, int n,
int
main_tid
=
numel
/
(
data_per_thread
*
vec_size
*
threads
);
int
main_tid
=
numel
/
(
data_per_thread
*
vec_size
*
threads
);
int
tail_tid
=
numel
%
(
data_per_thread
*
vec_size
*
threads
);
int
tail_tid
=
numel
%
(
data_per_thread
*
vec_size
*
threads
);
phi
::
Array
<
kps
::
details
::
BroadcastConfig
<
2
>
,
MAX_INPUT_NUM
>
configlists
;
phi
::
Array
<
kps
::
details
::
BroadcastConfig
,
MAX_INPUT_NUM
>
configlists
;
phi
::
Array
<
bool
,
MAX_INPUT_NUM
>
use_broadcast
;
phi
::
Array
<
bool
,
MAX_INPUT_NUM
>
use_broadcast
;
use_broadcast
[
0
]
=
false
;
use_broadcast
[
0
]
=
false
;
...
@@ -115,7 +114,7 @@ void LaunchBiasAddFwKernel(const platform::CUDADeviceContext& ctx, int m, int n,
...
@@ -115,7 +114,7 @@ void LaunchBiasAddFwKernel(const platform::CUDADeviceContext& ctx, int m, int n,
// Here, dims are transposed due to the logic in BroadcastConfig.
// Here, dims are transposed due to the logic in BroadcastConfig.
std
::
vector
<
int64_t
>
input1_dims
=
{
n
,
1
};
std
::
vector
<
int64_t
>
input1_dims
=
{
n
,
1
};
std
::
vector
<
int64_t
>
out_dims
=
{
n
,
m
};
std
::
vector
<
int64_t
>
out_dims
=
{
n
,
m
};
configlists
[
1
]
=
kps
::
details
::
BroadcastConfig
<
2
>
(
out_dims
,
input1_dims
,
2
);
configlists
[
1
]
=
kps
::
details
::
BroadcastConfig
(
out_dims
,
input1_dims
,
2
);
auto
func
=
AddFunctor
<
T
>
();
auto
func
=
AddFunctor
<
T
>
();
auto
stream
=
ctx
.
stream
();
auto
stream
=
ctx
.
stream
();
...
...
paddle/phi/kernels/funcs/broadcast_function.h
浏览文件 @
8501fb00
此差异已折叠。
点击以展开。
paddle/phi/kernels/primitive/datamover_primitives.h
浏览文件 @
8501fb00
...
@@ -82,10 +82,10 @@ struct FastDivMod {
...
@@ -82,10 +82,10 @@ struct FastDivMod {
* index of the output data. if input or output shape is [dim0, dim1] then dims
* index of the output data. if input or output shape is [dim0, dim1] then dims
* must be [dim1, dim0].
* must be [dim1, dim0].
*/
*/
template
<
int
kDims
>
struct
BroadcastConfig
{
struct
BroadcastConfig
{
FastDivMod
divmoders
[
kDims
];
FastDivMod
divmoders
[
phi
::
DDim
::
kMaxRank
];
uint32_t
strides
[
phi
::
DDim
::
kMaxRank
];
uint32_t
strides
[
phi
::
DDim
::
kMaxRank
];
int
kDims
;
HOSTDEVICE
BroadcastConfig
()
{}
HOSTDEVICE
BroadcastConfig
()
{}
HOSTDEVICE
BroadcastConfig
(
const
std
::
vector
<
int64_t
>&
out_dims
,
HOSTDEVICE
BroadcastConfig
(
const
std
::
vector
<
int64_t
>&
out_dims
,
...
@@ -109,7 +109,7 @@ struct BroadcastConfig {
...
@@ -109,7 +109,7 @@ struct BroadcastConfig {
std
::
multiplies
<
int64_t
>
())
std
::
multiplies
<
int64_t
>
())
:
strides_in
[
i
];
:
strides_in
[
i
];
}
}
kDims
=
dim_size
;
memcpy
(
strides
,
strides_in
.
data
(),
kDims
*
sizeof
(
uint32_t
));
memcpy
(
strides
,
strides_in
.
data
(),
kDims
*
sizeof
(
uint32_t
));
memcpy
(
divmoders
,
divmoders_in
.
data
(),
kDims
*
sizeof
(
FastDivMod
));
memcpy
(
divmoders
,
divmoders_in
.
data
(),
kDims
*
sizeof
(
FastDivMod
));
}
}
...
@@ -436,17 +436,12 @@ __device__ __forceinline__ void ReadData(ArgsT* dst,
...
@@ -436,17 +436,12 @@ __device__ __forceinline__ void ReadData(ArgsT* dst,
* stride_nx: Each read one element stride stride_nx elements in the last dim.
* stride_nx: Each read one element stride stride_nx elements in the last dim.
* stride_ny: Each read one element stride stride_ny elements in the first dim.
* stride_ny: Each read one element stride stride_ny elements in the first dim.
*/
*/
template
<
typename
T
,
template
<
typename
T
,
int
NX
,
int
NY
,
int
BlockSize
,
bool
IsBoundary
=
false
>
int
NX
,
int
NY
,
int
BlockSize
,
int
Rank
,
bool
IsBoundary
=
false
>
__device__
__forceinline__
void
ReadDataBc
(
__device__
__forceinline__
void
ReadDataBc
(
T
*
dst
,
T
*
dst
,
const
T
*
__restrict__
src
,
const
T
*
__restrict__
src
,
uint32_t
block_offset
,
uint32_t
block_offset
,
details
::
BroadcastConfig
<
Rank
>
config
,
const
details
::
BroadcastConfig
&
config
,
int
total_num_output
,
int
total_num_output
,
int
stride_nx
,
int
stride_nx
,
int
stride_ny
)
{
int
stride_ny
)
{
...
@@ -465,7 +460,8 @@ __device__ __forceinline__ void ReadDataBc(
...
@@ -465,7 +460,8 @@ __device__ __forceinline__ void ReadDataBc(
}
}
}
}
#pragma unroll
#pragma unroll
for
(
int
i
=
0
;
i
<
Rank
;
++
i
)
{
for
(
int
i
=
0
;
i
<
phi
::
DDim
::
kMaxRank
;
++
i
)
{
if
(
i
>=
config
.
kDims
)
break
;
auto
fast_divmoder
=
config
.
divmoders
[
i
].
Divmod
(
index_output
);
auto
fast_divmoder
=
config
.
divmoders
[
i
].
Divmod
(
index_output
);
index_output
=
fast_divmoder
.
val
[
0
];
index_output
=
fast_divmoder
.
val
[
0
];
index_src
+=
fast_divmoder
.
val
[
1
]
*
config
.
strides
[
i
];
index_src
+=
fast_divmoder
.
val
[
1
]
*
config
.
strides
[
i
];
...
@@ -785,53 +781,14 @@ __device__ __forceinline__ void Init(T* dst, T* init_data, int num) {
...
@@ -785,53 +781,14 @@ __device__ __forceinline__ void Init(T* dst, T* init_data, int num) {
* coordinate mapping relationship between output data and input data.
* coordinate mapping relationship between output data and input data.
* total_num_output: Total number of original output.
* total_num_output: Total number of original output.
*/
*/
template
<
typename
T
,
template
<
typename
T
,
int
NX
,
int
NY
,
int
BlockSize
,
bool
IsBoundary
=
false
>
int
NX
,
int
NY
,
int
BlockSize
,
int
Rank
,
bool
IsBoundary
=
false
>
__device__
__forceinline__
void
ReadDataBc
(
T
*
dst
,
const
T
*
__restrict__
src
,
uint32_t
block_offset
,
details
::
BroadcastConfig
<
Rank
>
config
,
int
total_num_output
)
{
uint32_t
thread_offset
=
block_offset
+
threadIdx
.
x
*
NX
;
uint32_t
index_src
=
0
;
#pragma unroll
for
(
uint32_t
nx
=
0
;
nx
<
NX
;
++
nx
)
{
uint32_t
index_output
=
thread_offset
+
nx
;
index_src
=
0
;
if
(
IsBoundary
)
{
if
(
index_output
>=
total_num_output
)
{
break
;
}
}
#pragma unroll
for
(
int
i
=
0
;
i
<
Rank
;
++
i
)
{
auto
fast_divmoder
=
config
.
divmoders
[
i
].
Divmod
(
index_output
);
index_output
=
fast_divmoder
.
val
[
0
];
index_src
+=
fast_divmoder
.
val
[
1
]
*
config
.
strides
[
i
];
}
dst
[
nx
]
=
src
[
index_src
];
}
}
template
<
typename
T
,
int
NX
,
int
NY
,
int
BlockSize
,
int
Rank
,
bool
IsBoundary
=
false
>
__device__
__forceinline__
void
ReadDataBc
(
__device__
__forceinline__
void
ReadDataBc
(
T
*
dst
,
T
*
dst
,
const
T
*
__restrict__
src
,
const
T
*
__restrict__
src
,
uint32_t
block_offset
,
uint32_t
block_offset
,
details
::
BroadcastConfig
<
Rank
>
config
,
const
details
::
BroadcastConfig
&
config
,
int
total_num_output
,
int
total_num_output
,
int
read_lens
)
{
int
read_lens
=
NX
)
{
uint32_t
thread_offset
=
block_offset
+
threadIdx
.
x
*
NX
;
uint32_t
thread_offset
=
block_offset
+
threadIdx
.
x
*
NX
;
uint32_t
index_src
=
0
;
uint32_t
index_src
=
0
;
...
@@ -845,7 +802,8 @@ __device__ __forceinline__ void ReadDataBc(
...
@@ -845,7 +802,8 @@ __device__ __forceinline__ void ReadDataBc(
}
}
}
}
#pragma unroll
#pragma unroll
for
(
int
i
=
0
;
i
<
Rank
;
++
i
)
{
for
(
int
i
=
0
;
i
<
phi
::
DDim
::
kMaxRank
;
++
i
)
{
if
(
i
>=
config
.
kDims
)
break
;
auto
fast_divmoder
=
config
.
divmoders
[
i
].
Divmod
(
index_output
);
auto
fast_divmoder
=
config
.
divmoders
[
i
].
Divmod
(
index_output
);
index_output
=
fast_divmoder
.
val
[
0
];
index_output
=
fast_divmoder
.
val
[
0
];
index_src
+=
fast_divmoder
.
val
[
1
]
*
config
.
strides
[
i
];
index_src
+=
fast_divmoder
.
val
[
1
]
*
config
.
strides
[
i
];
...
@@ -853,6 +811,7 @@ __device__ __forceinline__ void ReadDataBc(
...
@@ -853,6 +811,7 @@ __device__ __forceinline__ void ReadDataBc(
dst
[
nx
]
=
src
[
index_src
];
dst
[
nx
]
=
src
[
index_src
];
}
}
}
}
/**
/**
* @brief Initialize register with data index.
* @brief Initialize register with data index.
*
*
...
...
paddle/phi/kernels/primitive/datamover_primitives_xpu2.h
浏览文件 @
8501fb00
...
@@ -65,7 +65,6 @@ struct alignas(sizeof(T) * VecSize) VectorType {
...
@@ -65,7 +65,6 @@ struct alignas(sizeof(T) * VecSize) VectorType {
* must be [dim1, dim0].
* must be [dim1, dim0].
*/
*/
#pragma pack(4)
#pragma pack(4)
template
<
int
kDims
>
struct
BroadcastConfig
{
struct
BroadcastConfig
{
int
strides_in
[
phi
::
DDim
::
kMaxRank
];
int
strides_in
[
phi
::
DDim
::
kMaxRank
];
int
strides_out
[
phi
::
DDim
::
kMaxRank
];
int
strides_out
[
phi
::
DDim
::
kMaxRank
];
...
@@ -78,7 +77,7 @@ struct BroadcastConfig {
...
@@ -78,7 +77,7 @@ struct BroadcastConfig {
int
n
=
1
;
int
n
=
1
;
int
k
=
1
;
int
k
=
1
;
int
buf_len
=
0
;
int
buf_len
=
0
;
int
kDims
;
HOSTDEVICE
BroadcastConfig
()
{}
HOSTDEVICE
BroadcastConfig
()
{}
HOSTDEVICE
BroadcastConfig
(
const
std
::
vector
<
int64_t
>&
out_dims
,
HOSTDEVICE
BroadcastConfig
(
const
std
::
vector
<
int64_t
>&
out_dims
,
...
@@ -99,7 +98,7 @@ struct BroadcastConfig {
...
@@ -99,7 +98,7 @@ struct BroadcastConfig {
for
(
int
i
=
0
;
i
<
dim_size
;
i
++
)
{
for
(
int
i
=
0
;
i
<
dim_size
;
i
++
)
{
dim_tmp
[
i
]
=
in_dims
[
i
];
dim_tmp
[
i
]
=
in_dims
[
i
];
}
}
kDims
=
dim_size
;
memcpy
(
strides_in
,
strides_in_tmp
.
data
(),
kDims
*
sizeof
(
int
));
memcpy
(
strides_in
,
strides_in_tmp
.
data
(),
kDims
*
sizeof
(
int
));
memcpy
(
strides_out
,
strides_out_tmp
.
data
(),
kDims
*
sizeof
(
int
));
memcpy
(
strides_out
,
strides_out_tmp
.
data
(),
kDims
*
sizeof
(
int
));
memcpy
(
in_dim
,
dim_tmp
.
data
(),
kDims
*
sizeof
(
int
));
memcpy
(
in_dim
,
dim_tmp
.
data
(),
kDims
*
sizeof
(
int
));
...
@@ -551,7 +550,6 @@ __device__ __forceinline__ void ReadData(ArgsT* dst,
...
@@ -551,7 +550,6 @@ __device__ __forceinline__ void ReadData(ArgsT* dst,
* NY: The number of data rows loaded by each thread.
* NY: The number of data rows loaded by each thread.
* BlockSize: Identifies the current device thread index method. For xpu,
* BlockSize: Identifies the current device thread index method. For xpu,
* core_id() is used as the index.
* core_id() is used as the index.
* Rank: The shape size of out. eg in[1, 35], out[32, 35] then shape size is 2.
* IsBoundary: Indicates whether to perform block access storage out-of-bounds
* IsBoundary: Indicates whether to perform block access storage out-of-bounds
* judgment. When the number of data processed by the block is less than
* judgment. When the number of data processed by the block is less than
* NX x NY x core_num(), boundary judgment is required to avoid memory access
* NX x NY x core_num(), boundary judgment is required to avoid memory access
...
@@ -567,16 +565,11 @@ __device__ __forceinline__ void ReadData(ArgsT* dst,
...
@@ -567,16 +565,11 @@ __device__ __forceinline__ void ReadData(ArgsT* dst,
* stride_nx: Each read one element stride stride_nx elements in the last dim.
* stride_nx: Each read one element stride stride_nx elements in the last dim.
* stride_ny: Each read one element stride stride_ny elements in the first dim.
* stride_ny: Each read one element stride stride_ny elements in the first dim.
*/
*/
template
<
typename
T
,
template
<
typename
T
,
int
NX
,
int
NY
,
int
BlockSize
,
bool
IsBoundary
=
false
>
int
NX
,
int
NY
,
int
BlockSize
,
int
Rank
,
bool
IsBoundary
=
false
>
__device__
__inline__
void
ReadDataBc
(
T
*
dst
,
__device__
__inline__
void
ReadDataBc
(
T
*
dst
,
const
T
_global_ptr_
*
src
,
const
T
_global_ptr_
*
src
,
uint32_t
block_offset
,
uint32_t
block_offset
,
details
::
BroadcastConfig
<
Rank
>
config
,
const
details
::
BroadcastConfig
&
config
,
int
total_num_output
,
int
total_num_output
,
int
stride_nx
,
int
stride_nx
,
int
stride_ny
)
{
int
stride_ny
)
{
...
@@ -882,60 +875,6 @@ __device__ __inline__ void Init(T* dst, T* init_data, int num) {
...
@@ -882,60 +875,6 @@ __device__ __inline__ void Init(T* dst, T* init_data, int num) {
}
}
}
}
/**
* @brief Read 1D data from global memory to register with broadcast form.
*
* @template paraments
* T: The type of data stored in the global memory.
* NX: The number of data continuously loaded by each thread.
* NY: The number of data rows loaded by each thread, only NY = 1 was supported.
* BlockSize: Identifies the current device thread index method. For xpu,
* core_id() is used as the index.
* Rank: The shape size of out. eg in[1, 35], out[32, 35] then shape size is 2.
* IsBoundary: Indicates whether to perform block access storage out-of-bounds
* judgment. When the number of data processed by the block is less than
* NX x NY x core_num(), boundary judgment is required to avoid memory access
* crossing the boundary.
*
* @param:
* dst: The register pointer of the thread, the size is NX * NY.
* src: The original input data pointer of kernel.
* block_offset: The data offset of this block, core_num() * blockIdx.x * NX;
* config: Calculation configuration of broadcast. It is used to calculate the
* coordinate mapping relationship between output data and input data.
* total_num_output: Total number of original output.
*/
template
<
typename
T
,
int
NX
,
int
NY
,
int
BlockSize
,
int
Rank
,
bool
IsBoundary
=
false
>
__device__
__inline__
void
ReadDataBc
(
T
*
dst
,
const
T
_global_ptr_
*
src
,
uint32_t
block_offset
,
const
details
::
BroadcastConfig
<
Rank
>&
config
,
int
total_num_output
)
{
int
thread_offset
=
block_offset
+
core_id
()
*
NX
;
int
index_src
=
0
;
__local__
T
in_temp
;
#pragma unroll
for
(
int
nx
=
0
;
nx
<
NX
;
++
nx
)
{
int
index_output
=
thread_offset
+
nx
;
index_src
=
0
;
if
(
IsBoundary
)
{
if
(
index_output
>=
total_num_output
)
{
break
;
}
}
index_src
=
config
(
index_output
);
GM2LM
(
src
+
index_src
,
&
in_temp
,
sizeof
(
T
));
dst
[
nx
]
=
in_temp
;
}
}
/**
/**
* @brief Read data from global memory to local memory with broadcast
* @brief Read data from global memory to local memory with broadcast
* {m, 1, k}-> {m, n, k} form.
* {m, 1, k}-> {m, n, k} form.
...
@@ -952,12 +891,12 @@ __device__ __inline__ void ReadDataBc(
...
@@ -952,12 +891,12 @@ __device__ __inline__ void ReadDataBc(
* coordinate mapping relationship between output data and input data.
* coordinate mapping relationship between output data and input data.
* read_lens: The number of data continuously loaded by each thread.
* read_lens: The number of data continuously loaded by each thread.
*/
*/
template
<
typename
T
,
int
Rank
>
template
<
typename
T
>
__device__
__inline__
void
ReadDataBcM1kMnk
(
__device__
__inline__
void
ReadDataBcM1kMnk
(
T
*
dst
,
T
*
dst
,
const
T
_global_ptr_
*
src
,
const
T
_global_ptr_
*
src
,
int
thread_offset
,
int
thread_offset
,
const
details
::
BroadcastConfig
<
Rank
>
&
config
,
const
details
::
BroadcastConfig
&
config
,
int
read_lens
)
{
int
read_lens
)
{
int
index_output
=
thread_offset
;
int
index_output
=
thread_offset
;
int
index_base
=
config
(
index_output
);
int
index_base
=
config
(
index_output
);
...
@@ -999,12 +938,12 @@ __device__ __inline__ void ReadDataBcM1kMnk(
...
@@ -999,12 +938,12 @@ __device__ __inline__ void ReadDataBcM1kMnk(
* coordinate mapping relationship between output data and input data.
* coordinate mapping relationship between output data and input data.
* read_lens: The number of data continuously loaded by each thread.
* read_lens: The number of data continuously loaded by each thread.
*/
*/
template
<
typename
T
,
int
Rank
>
template
<
typename
T
>
__device__
__inline__
void
ReadDataBcM1Mn
(
__device__
__inline__
void
ReadDataBcM1Mn
(
T
*
dst
,
T
*
dst
,
const
T
_global_ptr_
*
src
,
const
T
_global_ptr_
*
src
,
int
thread_offset
,
int
thread_offset
,
const
details
::
BroadcastConfig
<
Rank
>
&
config
,
const
details
::
BroadcastConfig
&
config
,
int
read_lens
)
{
int
read_lens
)
{
int
index_output
=
thread_offset
;
int
index_output
=
thread_offset
;
int
index_base
=
config
(
index_output
);
int
index_base
=
config
(
index_output
);
...
@@ -1027,7 +966,6 @@ __device__ __inline__ void ReadDataBcM1Mn(
...
@@ -1027,7 +966,6 @@ __device__ __inline__ void ReadDataBcM1Mn(
*
*
* @template paraments
* @template paraments
* T: Data type of register.
* T: Data type of register.
* Rank: The shape size of out. eg in[1, 35], out[32, 35] then shape size is 2.
*
*
* @param:
* @param:
* dst: The register pointer of the thread, the size is NX.
* dst: The register pointer of the thread, the size is NX.
...
@@ -1037,12 +975,12 @@ __device__ __inline__ void ReadDataBcM1Mn(
...
@@ -1037,12 +975,12 @@ __device__ __inline__ void ReadDataBcM1Mn(
* coordinate mapping relationship between output data and input data.
* coordinate mapping relationship between output data and input data.
* read_lens: The number of data continuously loaded by each thread.
* read_lens: The number of data continuously loaded by each thread.
*/
*/
template
<
typename
T
,
int
Rank
>
template
<
typename
T
>
__device__
__inline__
void
ReadDataBc1NMn
(
__device__
__inline__
void
ReadDataBc1NMn
(
T
*
dst
,
T
*
dst
,
const
T
_global_ptr_
*
src
,
const
T
_global_ptr_
*
src
,
int
thread_offset
,
int
thread_offset
,
const
details
::
BroadcastConfig
<
Rank
>
&
config
,
const
details
::
BroadcastConfig
&
config
,
int
read_lens
)
{
int
read_lens
)
{
int
index_output
=
thread_offset
;
int
index_output
=
thread_offset
;
int
index_base
=
config
(
index_output
);
int
index_base
=
config
(
index_output
);
...
@@ -1075,7 +1013,6 @@ __device__ __inline__ void ReadDataBc1NMn(
...
@@ -1075,7 +1013,6 @@ __device__ __inline__ void ReadDataBc1NMn(
*
*
* @template paraments
* @template paraments
* T: Data type of register.
* T: Data type of register.
* Rank: The shape size of out. eg in[1, 35], out[32, 35] then shape size is 2.
*
*
* @param:
* @param:
* dst: The register pointer of the thread, the size is NX.
* dst: The register pointer of the thread, the size is NX.
...
@@ -1085,12 +1022,12 @@ __device__ __inline__ void ReadDataBc1NMn(
...
@@ -1085,12 +1022,12 @@ __device__ __inline__ void ReadDataBc1NMn(
* coordinate mapping relationship between output data and input data.
* coordinate mapping relationship between output data and input data.
* read_lens: The number of data continuously loaded by each thread.
* read_lens: The number of data continuously loaded by each thread.
*/
*/
template
<
typename
T
,
int
Rank
>
template
<
typename
T
>
__device__
__inline__
void
ReadDataBc1N1Mnk
(
__device__
__inline__
void
ReadDataBc1N1Mnk
(
T
*
dst
,
T
*
dst
,
const
T
_global_ptr_
*
src
,
const
T
_global_ptr_
*
src
,
int
thread_offset
,
int
thread_offset
,
const
details
::
BroadcastConfig
<
Rank
>
&
config
,
const
details
::
BroadcastConfig
&
config
,
int
read_lens
)
{
int
read_lens
)
{
int
index_output
=
thread_offset
;
int
index_output
=
thread_offset
;
int
index_base
=
config
(
index_output
);
int
index_base
=
config
(
index_output
);
...
@@ -1130,7 +1067,6 @@ __device__ __inline__ void ReadDataBc1N1Mnk(
...
@@ -1130,7 +1067,6 @@ __device__ __inline__ void ReadDataBc1N1Mnk(
*
*
* @template paraments
* @template paraments
* T: Data type of register.
* T: Data type of register.
* Rank: The shape size of out. eg in[1, 35], out[32, 35] then shape size is 2.
*
*
* @param:
* @param:
* dst: The register pointer of the thread, the size is NX.
* dst: The register pointer of the thread, the size is NX.
...
@@ -1140,13 +1076,12 @@ __device__ __inline__ void ReadDataBc1N1Mnk(
...
@@ -1140,13 +1076,12 @@ __device__ __inline__ void ReadDataBc1N1Mnk(
* coordinate mapping relationship between output data and input data.
* coordinate mapping relationship between output data and input data.
* read_lens: The number of data continuously loaded by each thread.
* read_lens: The number of data continuously loaded by each thread.
*/
*/
template
<
typename
T
,
int
Rank
>
template
<
typename
T
>
__device__
__inline__
void
ReadDataBc1N
(
__device__
__inline__
void
ReadDataBc1N
(
T
*
dst
,
T
*
dst
,
const
T
_global_ptr_
*
src
,
const
T
_global_ptr_
*
src
,
int
thread_offset
,
int
thread_offset
,
const
details
::
BroadcastConfig
&
config
,
const
details
::
BroadcastConfig
<
Rank
>&
config
,
int
read_lens
)
{
int
read_lens
)
{
int
index_output
=
thread_offset
;
int
index_output
=
thread_offset
;
int
index_base
=
config
(
index_output
);
int
index_base
=
config
(
index_output
);
T
in_temp
;
T
in_temp
;
...
@@ -1174,12 +1109,12 @@ __device__ __inline__ void ReadDataBc1N(
...
@@ -1174,12 +1109,12 @@ __device__ __inline__ void ReadDataBc1N(
* total_num_output: Total number of original output.
* total_num_output: Total number of original output.
* read_lens: The number of data continuously loaded by each thread.
* read_lens: The number of data continuously loaded by each thread.
*/
*/
template
<
typename
T
,
int
Rank
,
bool
IsBoundary
=
false
>
template
<
typename
T
,
bool
IsBoundary
=
false
>
__device__
__inline__
void
ReadDataBcCanNotCmp
(
__device__
__inline__
void
ReadDataBcCanNotCmp
(
T
*
dst
,
T
*
dst
,
const
T
_global_ptr_
*
src
,
const
T
_global_ptr_
*
src
,
int
thread_offset
,
int
thread_offset
,
const
details
::
BroadcastConfig
<
Rank
>
&
config
,
const
details
::
BroadcastConfig
&
config
,
int
total_num_output
,
int
total_num_output
,
int
read_lens
)
{
int
read_lens
)
{
int
index_output
=
thread_offset
;
int
index_output
=
thread_offset
;
...
@@ -1215,7 +1150,6 @@ __device__ __inline__ void ReadDataBcCanNotCmp(
...
@@ -1215,7 +1150,6 @@ __device__ __inline__ void ReadDataBcCanNotCmp(
* NY: The number of data rows loaded by each thread, only NY = 1 was supported.
* NY: The number of data rows loaded by each thread, only NY = 1 was supported.
* BlockSize: Identifies the current device thread index method. For xpu,
* BlockSize: Identifies the current device thread index method. For xpu,
* core_id() is used as the index.
* core_id() is used as the index.
* Rank: The shape size of out. eg in[1, 35], out[32, 35] then shape size is 2.
* IsBoundary: Indicates whether to perform block access storage out-of-bounds
* IsBoundary: Indicates whether to perform block access storage out-of-bounds
* judgment. When the number of data processed by the block is less than
* judgment. When the number of data processed by the block is less than
* NX x NY x core_num(), boundary judgment is required to avoid memory access
* NX x NY x core_num(), boundary judgment is required to avoid memory access
...
@@ -1230,33 +1164,27 @@ __device__ __inline__ void ReadDataBcCanNotCmp(
...
@@ -1230,33 +1164,27 @@ __device__ __inline__ void ReadDataBcCanNotCmp(
* read_lens: The number of data continuously loaded by each thread.
* read_lens: The number of data continuously loaded by each thread.
* total_num_output: Total number of original output.
* total_num_output: Total number of original output.
*/
*/
template
<
typename
T
,
template
<
typename
T
,
int
NX
,
int
NY
,
int
BlockSize
,
bool
IsBoundary
=
false
>
int
NX
,
__device__
__inline__
void
ReadDataBc
(
T
*
dst
,
int
NY
,
const
T
_global_ptr_
*
src
,
int
BlockSize
,
uint32_t
block_offset
,
int
Rank
,
const
details
::
BroadcastConfig
&
config
,
bool
IsBoundary
=
false
>
int
total_num_output
,
__device__
__inline__
void
ReadDataBc
(
int
read_lens
)
{
T
*
dst
,
const
T
_global_ptr_
*
src
,
uint32_t
block_offset
,
const
details
::
BroadcastConfig
<
Rank
>&
config
,
int
total_num_output
,
int
read_lens
)
{
int
thread_offset
=
block_offset
+
core_id
()
*
read_lens
;
int
thread_offset
=
block_offset
+
core_id
()
*
read_lens
;
if
(
config
.
cmp_type
==
details
::
OptType
::
MNK_M1K
)
{
if
(
config
.
cmp_type
==
details
::
OptType
::
MNK_M1K
)
{
ReadDataBcM1kMnk
<
T
,
Rank
>
(
dst
,
src
,
thread_offset
,
config
,
read_lens
);
ReadDataBcM1kMnk
<
T
>
(
dst
,
src
,
thread_offset
,
config
,
read_lens
);
}
else
if
(
config
.
cmp_type
==
details
::
OptType
::
N_1
)
{
}
else
if
(
config
.
cmp_type
==
details
::
OptType
::
N_1
)
{
ReadDataBc1N
<
T
,
Rank
>
(
dst
,
src
,
thread_offset
,
config
,
read_lens
);
ReadDataBc1N
<
T
>
(
dst
,
src
,
thread_offset
,
config
,
read_lens
);
}
else
if
(
config
.
cmp_type
==
details
::
OptType
::
MN_M
)
{
}
else
if
(
config
.
cmp_type
==
details
::
OptType
::
MN_M
)
{
ReadDataBcM1Mn
<
T
,
Rank
>
(
dst
,
src
,
thread_offset
,
config
,
read_lens
);
ReadDataBcM1Mn
<
T
>
(
dst
,
src
,
thread_offset
,
config
,
read_lens
);
}
else
if
(
config
.
cmp_type
==
details
::
OptType
::
MN_N
)
{
}
else
if
(
config
.
cmp_type
==
details
::
OptType
::
MN_N
)
{
ReadDataBc1NMn
<
T
,
Rank
>
(
dst
,
src
,
thread_offset
,
config
,
read_lens
);
ReadDataBc1NMn
<
T
>
(
dst
,
src
,
thread_offset
,
config
,
read_lens
);
}
else
if
(
config
.
cmp_type
==
details
::
OptType
::
MNK_1N1
)
{
}
else
if
(
config
.
cmp_type
==
details
::
OptType
::
MNK_1N1
)
{
ReadDataBc1N1Mnk
<
T
,
Rank
>
(
dst
,
src
,
thread_offset
,
config
,
read_lens
);
ReadDataBc1N1Mnk
<
T
>
(
dst
,
src
,
thread_offset
,
config
,
read_lens
);
}
else
{
}
else
{
ReadDataBcCanNotCmp
<
T
,
Rank
,
IsBoundary
>
(
ReadDataBcCanNotCmp
<
T
,
IsBoundary
>
(
dst
,
src
,
thread_offset
,
config
,
total_num_output
,
read_lens
);
dst
,
src
,
thread_offset
,
config
,
total_num_output
,
read_lens
);
}
}
}
}
...
...
paddle/phi/kernels/primitive/kernel_primitives.h
浏览文件 @
8501fb00
...
@@ -40,7 +40,9 @@
...
@@ -40,7 +40,9 @@
#define GRID_NUM_X cluster_num()
#define GRID_NUM_X cluster_num()
#define GRID_NUM_Y 0
#define GRID_NUM_Y 0
#define GRID_NUM_Z 0
#define GRID_NUM_Z 0
#define VecSizeL 512
#define VecSizeM 256
#define VecSizeS 128
#else
#else
#define KPStream gpuStream_t
#define KPStream gpuStream_t
...
@@ -64,6 +66,9 @@
...
@@ -64,6 +66,9 @@
#define GRID_NUM_Y gridDim.y
#define GRID_NUM_Y gridDim.y
#define GRID_NUM_Z gridDim.z
#define GRID_NUM_Z gridDim.z
#define VecSizeL 4
#define VecSizeM 2
#define VecSizeS 1
#endif
#endif
// include file
// include file
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录