PaddlePaddle / Paddle

Unverified commit 89a8989f

Add io api and compute api for XPU (#36423)

Authored by niuliling123 on Oct 29, 2021; committed via GitHub on Oct 29, 2021.
Parent: 92d6a048

Showing 2 changed files with 891 additions and 0 deletions (+891, -0):

paddle/fluid/operators/kernel_primitives/compute_primitives_xpu2.h    +324  -0
paddle/fluid/operators/kernel_primitives/datamover_primitives_xpu2.h  +567  -0
paddle/fluid/operators/kernel_primitives/compute_primitives_xpu2.h (new file, mode 100644)
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "xpu/kernel/cluster_header.h"
#include "xpu/kernel/debug.h"
#include "xpu/kernel/math.h"
namespace paddle {
namespace operators {
namespace kernel_primitives {
namespace details {

// kGlobalMode: block reduce, each block gets an output;
// kLocalMode: thread reduce, each thread gets an output;
enum ReduceMode { kGlobalMode, kLocalMode };

template <typename T>
class MPTypeTrait {
 public:
  using Type = T;
};

template <>
class MPTypeTrait<platform::float16> {
 public:
  using Type = float;
};

static inline __device__ void sync_all() {
  __asm__ __volatile__(
      "sync_local\t\n"
      "csr_set csr3, %0\t\n"
      "sync_group csr3" ::"r"(-1));
}

#define ncores 64
template <typename T, typename OpFunc, int VecSize>
__device__ void BlockXReduce(T* data, OpFunc reducer) {
  __shared__ T sum_array[ncores * VecSize];
  int core_idx = core_id() * VecSize;
  mfence();
  sync_all();

#pragma unroll
  for (int i = 0; i < VecSize; i++) {
    mfence();
    sum_array[core_idx + i] = data[i];
    mfence();
    data[i] = 0;
  }
  sync_all();
#pragma unroll
  for (int i = 0; i < VecSize; i++) {
#pragma unroll
    for (int j = 0; j < ncores; j++) {
      mfence();
      T tmp = sum_array[j * VecSize + i];
      mfence();
      data[i] = reducer(data[i], tmp);
      mfence();
    }
  }
  sync_all();
}
#undef ncores

}  // namespace details
/**
 * @brief Perform unary calculation according to OpFunc. Shapes of the input
 * and output are the same.
 *
 * @template parameters
 * InT: The data type of in.
 * OutT: The data type of out.
 * NX: The number of data columns loaded by each thread.
 * NY: The number of data rows loaded by each thread.
 * BlockSize: Identifies the current device thread index method. For xpu,
 * core_id() is used as the index.
 * OpFunc: Compute functor which has an operator() as following:
 *         template <typename InT, typename OutT>
 *         struct XxxFunctor {
 *           HOSTDEVICE OutT operator()(const InT& a) const {
 *             return ...;
 *           }
 *         };
 *
 * @param:
 * out: The register pointer of out, the size is NX * NY.
 * in: The register pointer of in, the size is NX * NY.
 * compute: Compute function which was declared like OpFunc<InT, OutT>().
 */
template <typename InT, typename OutT, int NX, int NY, int BlockSize,
          class OpFunc>
__device__ __forceinline__ void ElementwiseUnary(OutT* out, const InT* in,
                                                 OpFunc compute) {
#pragma unroll
  for (int idx = 0; idx < NX * NY; idx++) {
    out[idx] = static_cast<OutT>(compute(in[idx]));
  }
}
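// Usage sketch (illustrative; ScaleFunctor and the buffers are assumptions,
// not part of this API): a doubling functor applied through ElementwiseUnary
// with NX = 4, NY = 1, so each core transforms four register elements.
//
//   template <typename InT, typename OutT>
//   struct ScaleFunctor {
//     HOSTDEVICE OutT operator()(const InT& a) const {
//       return static_cast<OutT>(a + a);
//     }
//   };
//
//   __local__ float in[4];
//   __local__ float out[4];
//   ElementwiseUnary<float, float, 4, 1, 1, ScaleFunctor<float, float>>(
//       out, in, ScaleFunctor<float, float>());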
/**
 * @brief Binary calculation according to OpFunc. Shapes of the input and
 * output are the same.
 *
 * @template parameters
 * InT: The data type of in1 and in2.
 * OutT: The data type of out.
 * NX: The number of data columns computed by each thread.
 * NY: The number of data rows computed by each thread.
 * BlockSize: Identifies the current device thread index method. For xpu,
 * core_id() is used as the index.
 * OpFunc: Compute functor which has an operator() as following:
 *         template <typename InT>
 *         struct XxxFunctor {
 *           HOSTDEVICE InT operator()(const InT& a, const InT& b) const {
 *             return ...;
 *           }
 *         };
 *
 * @param:
 * out: The register pointer of out, the size is NX * NY.
 * in1: The register pointer of the first input, size is NX * NY.
 * in2: The register pointer of the second input, size is NX * NY.
 * compute: Compute function which was declared like OpFunc<InT>().
 */
template <typename InT, typename OutT, int NX, int NY, int BlockSize,
          class OpFunc>
__device__ __forceinline__ void ElementwiseBinary(OutT* out, const InT* in1,
                                                  const InT* in2,
                                                  OpFunc compute) {
#pragma unroll
  for (int idx = 0; idx < NX * NY; ++idx) {
    out[idx] = static_cast<OutT>(compute(in1[idx], in2[idx]));
  }
}
/**
 * @brief Ternary calculation according to OpFunc. Shapes of the input and
 * output are the same.
 *
 * @template parameters
 * InT: The data type of in1, in2 and in3.
 * OutT: The data type of out.
 * NX: The number of data columns loaded by each thread.
 * NY: The number of data rows loaded by each thread.
 * BlockSize: Identifies the current device thread index method. For xpu,
 * core_id() is used as the index.
 * OpFunc: Compute functor which has an operator() as following
 *         template <typename InT>
 *         struct XxxFunctor {
 *           HOSTDEVICE InT operator()(const InT& a, const InT& b, const InT& c)
 *               const {
 *             return ...;
 *           }
 *         };
 *
 * @param
 * out: The register pointer of out, the size is NX * NY.
 * in1: The register pointer of the first input, size is NX * NY.
 * in2: The register pointer of the second input, size is NX * NY.
 * in3: The register pointer of the third input, size is NX * NY.
 * compute: Compute function which was declared like OpFunc<InT>().
 */
template <typename InT, typename OutT, int NX, int NY, int BlockSize,
          class OpFunc>
__device__ __forceinline__ void ElementwiseTernary(OutT* out, const InT* in1,
                                                   const InT* in2,
                                                   const InT* in3,
                                                   OpFunc compute) {
#pragma unroll
  for (int idx = 0; idx < NX * NY; ++idx) {
    out[idx] = static_cast<OutT>(compute(in1[idx], in2[idx], in3[idx]));
  }
}
/**
 * @brief Multivariate calculation according to OpFunc. Shapes of the inputs
 * and output are the same.
 *
 * @template parameters
 * InT: The data type of the inputs.
 * OutT: The data type of out.
 * NX: The number of data columns loaded by each thread.
 * NY: The number of data rows loaded by each thread.
 * BlockSize: Identifies the current device thread index method. For xpu,
 * core_id() is used as the index.
 * Arity: The number of inputs in ins.
 * OpFunc: Compute functor which has an operator() as following:
 *         template <typename InT>
 *         struct XxxFunctor {
 *           HOSTDEVICE InT operator()(const InT* args) const {
 *             return ...;
 *           }
 *         };
 *
 * @param
 * out: The register pointer of out, the size is NX * NY.
 * ins: An array of register pointers consisting of the multiple inputs.
 * compute: Compute function which was declared like OpFunc<InT>().
 */
template <typename InT, typename OutT, int NX, int NY, int BlockSize, int Arity,
          class OpFunc>
__device__ __forceinline__ void ElementwiseAny(OutT* out, InT (*ins)[NX * NY],
                                               OpFunc compute) {
  __local__ InT args[Arity];
#pragma unroll
  for (int idx = 0; idx < NX * NY; ++idx) {
#pragma unroll
    for (int j = 0; j < Arity; ++j) {
      args[j] = ins[j][idx];
    }
    out[idx] = static_cast<OutT>(compute(args));
  }
}
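// Usage sketch (illustrative; Sum3Functor is an assumption): ElementwiseAny
// with Arity = 3. The inputs are packed into one two-dimensional register
// array so a single functor consumes them through the args pointer.
//
//   template <typename InT>
//   struct Sum3Functor {
//     HOSTDEVICE InT operator()(const InT* args) const {
//       return args[0] + args[1] + args[2];
//     }
//   };
//
//   __local__ float ins[3][4];  // three inputs, NX * NY = 4 elements each
//   __local__ float out[4];
//   ElementwiseAny<float, float, 4, 1, 1, 3, Sum3Functor<float>>(
//       out, ins, Sum3Functor<float>());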
/**
 * @brief Binary calculation according to OpFunc. The shapes of in1 and in2 are
 * different. When in1's shape is [1, NX] and in2's shape is [NY, NX], the
 * output's shape is [NY, NX].
 *
 * @template parameters
 * InT: The data type of in1 and in2.
 * OutT: The data type of out.
 * NX: The number of data columns loaded by each thread.
 * NY: The number of data rows loaded by each thread.
 * BlockSize: Identifies the current device thread index method. For xpu,
 * core_id() is used as the index.
 * OpFunc: Compute functor which has an operator() as following
 *         template <typename InT, typename OutT>
 *         struct XxxFunctor {
 *           HOSTDEVICE OutT operator()(const InT& a, const InT& b) const {
 *             return ...;
 *           }
 *         };
 *
 * @param
 * out: The register pointer of out, the size is NX * NY.
 * in1: The register pointer of the first input, size is NX * 1.
 * in2: The register pointer of the second input, size is NX * NY.
 * compute: Compute function which was declared like OpFunc<InT, OutT>().
 */
template <typename InT, typename OutT, int NX, int NY, int BlockSize,
          class OpFunc>
__device__ __forceinline__ void CycleBinary(OutT* out, const InT* in1,
                                            const InT* in2, OpFunc compute) {
#pragma unroll
  for (int idx = 0; idx < NX; idx++) {
#pragma unroll
    for (int idy = 0; idy < NY; idy++) {
      out[idx + idy * NX] =
          static_cast<OutT>(compute(in1[idx], in2[idx + idy * NX]));
    }
  }
}
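// Usage sketch (illustrative; AddFunctor is a hypothetical binary functor of
// the shape documented above): CycleBinary cycles the [1, NX] operand over
// every row of the [NY, NX] operand. With NX = 4, NY = 2,
// out[idx + idy * 4] = bias[idx] + x[idx + idy * 4] for both rows.
//
//   __local__ float bias[4];  // shape [1, 4]
//   __local__ float x[8];     // shape [2, 4]
//   __local__ float out[8];   // shape [2, 4]
//   CycleBinary<float, float, 4, 2, 1, AddFunctor<float>>(
//       out, bias, x, AddFunctor<float>());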
/**
 * @brief The Reduce provides collective methods for computing a parallel
 * reduction of items partitioned across a block and within a thread. When
 * ReduceMode == kLocalMode, each thread reduces along nx. When ReduceMode ==
 * kGlobalMode, shared memory is used to reduce between threads.
 *
 * @template parameters
 * T: The type of data.
 * NX: The number of data continuously loaded by each thread.
 * NY: The number of data rows loaded by each thread, only NY = 1 is supported.
 * BlockSize: Identifies the current device thread index method. For xpu,
 * core_id() is used as the index.
 * ReduceFunctor: Compute functor which has an operator() as following
 *         template <typename InT>
 *         struct ReduceFunctor {
 *           HOSTDEVICE InT operator()(const InT& a, const InT& b) const {
 *             return ...;
 *           }
 *         };
 * ReduceMode: Reduce mode, can be kLocalMode, kGlobalMode.
 *
 * @param
 * out: The register pointer of out, the size is NX * NY.
 * in: The register pointer of in, the size is NX * NY.
 * reducer: Compute function which was declared like ReduceFunctor<InT>().
 * reduce_last_dim: Whether the last dim takes part in the reduction.
 */
template <typename T, int NX, int NY, int BlockSize, class ReduceFunctor,
          details::ReduceMode Mode>
__device__ __forceinline__ void Reduce(T* out, const T* in,
                                       ReduceFunctor reducer,
                                       bool reduce_last_dim) {
  if (Mode == details::kGlobalMode) {
#pragma unroll
    for (int i = 0; i < NY; ++i) {
#pragma unroll
      for (int j = 0; j < NX; ++j) {
        out[i] = reducer(out[i], in[i * NX + j]);
      }
    }
    details::BlockXReduce<T, ReduceFunctor, NY>(out, reducer);
  } else {  // else kLocalMode
#pragma unroll
    for (int i = 0; i < NY; ++i) {
#pragma unroll
      for (int j = 0; j < NX; ++j) {
        out[i] = reducer(out[i], in[i * NX + j]);
      }
    }
  }
}

}  // namespace kernel_primitives
}  // namespace operators
}  // namespace paddle
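// Usage sketch for Reduce (illustrative; MaxFunctor and the buffers are
// assumptions). In kLocalMode each core folds its NX elements per row into
// out[i]; kGlobalMode runs the same per-core fold and then combines the
// partials of all 64 cores through details::BlockXReduce.
//
//   template <typename T>
//   struct MaxFunctor {
//     HOSTDEVICE T operator()(const T& a, const T& b) const {
//       return a > b ? a : b;
//     }
//   };
//
//   __local__ float partial[1];  // NY = 1 output per core
//   Reduce<float, 4, 1, 1, MaxFunctor<float>, details::kLocalMode>(
//       partial, in, MaxFunctor<float>(), true);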
paddle/fluid/operators/kernel_primitives/datamover_primitives_xpu2.h (new file, mode 100644)
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "xpu/kernel/cluster_header.h"
#include "xpu/kernel/debug.h"
#include "xpu/kernel/math.h"
namespace paddle {
namespace operators {
namespace kernel_primitives {
namespace details {

template <typename T, int VecSize>
struct alignas(sizeof(T) * VecSize) VectorType {
  T val[VecSize];
};
/**
 * Configuration of broadcast. Calculates the input data index according to
 * the index of the output data. If the input or output shape is [dim0, dim1],
 * then dims must be passed as [dim1, dim0].
 */
template <int kDims>
struct BroadcastConfig {
  uint32_t stride_in[framework::DDim::kMaxRank];
  uint32_t stride_out[framework::DDim::kMaxRank];
  uint32_t shape_in[framework::DDim::kMaxRank];

  HOSTDEVICE BroadcastConfig() {}

  HOSTDEVICE BroadcastConfig(const std::vector<int64_t>& out_dims,
                             const std::vector<int64_t>& in_dims,
                             int dim_size) {
    std::vector<uint32_t> strides_in;
    std::vector<uint32_t> strides_out;
    std::vector<uint32_t> shapes_in;

    strides_out.resize(dim_size, 1);
    strides_in.resize(dim_size, 1);
    shapes_in.resize(dim_size, 1);

    // Store the input shape in reversed order; the memcpy below publishes it
    // into the shape_in member.
    for (int i = 0; i < dim_size; ++i) {
      shapes_in[i] = in_dims[dim_size - i - 1];
    }

    for (int i = 1; i < dim_size - 1; ++i) {
      strides_out[dim_size - i - 1] =
          std::accumulate(out_dims.begin(), out_dims.begin() + i, 1,
                          std::multiplies<int64_t>());
      strides_in[dim_size - i - 1] =
          std::accumulate(in_dims.begin(), in_dims.begin() + i, 1,
                          std::multiplies<int64_t>());
    }

    memcpy(stride_in, strides_in.data(), kDims * sizeof(uint32_t));
    memcpy(stride_out, strides_out.data(), kDims * sizeof(uint32_t));
    memcpy(shape_in, shapes_in.data(), kDims * sizeof(uint32_t));
  }
};
}  // namespace details
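// Construction sketch (host side, illustrative): broadcasting in[1, 35] to
// out[32, 35] with kDims = 2. Dims are passed highest-dimension-first and the
// config stores them reversed, as the comment above requires.
//
//   std::vector<int64_t> out_dims = {32, 35};
//   std::vector<int64_t> in_dims = {1, 35};
//   details::BroadcastConfig<2> config(out_dims, in_dims, 2);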
/**
 * @brief Read 2D data from global memory to registers according to Tx type,
 * and store it as Ty type into the registers.
 *
 * @template parameters
 * Tx: The type of data stored in the global memory.
 * Ty: The type of data that needs to be stored in registers.
 * NX: The number of data columns loaded by each thread.
 * NY: The number of data rows loaded by each thread.
 * BlockSize: Identifies the current device thread index method. For xpu,
 * core_id() is used as the index.
 * IsBoundary: Indicates whether to perform out-of-bounds checks on memory
 * accesses. When the number of data processed by the block is less than
 * NX x NY x core_num(), boundary checks are required to avoid reading past
 * the end.
 *
 * @param:
 * dst: The register pointer of the thread, the size is NX * NY.
 * src: The data pointer of the current block.
 * size_nx: The maximum offset of the current block in the lowest dimension,
 * in elements. Only used when IsBoundary = true.
 * size_ny: The maximum offset of the current block in the first dimension,
 * in elements. Only used when IsBoundary = true.
 * stride_nx: The stride between two consecutive reads in the last dim.
 * stride_ny: The stride between two consecutive reads in the first dim.
 */
template <typename Tx, typename Ty, int NX, int NY, int BlockSize,
          bool IsBoundary = false>
__device__ __forceinline__ void ReadData(Ty* dst, const Tx _global_ptr_* src,
                                         int size_nx, int size_ny,
                                         int stride_nx, int stride_ny) {
  int thread_offset = core_id();
  int left_size_nx = size_nx - thread_offset;
  __local__ Tx in_temp[1];
  // Each branch is added for better performance
  if (NX == 1 && NY == 1) {  // for NX == 1 and NY == 1
    if (IsBoundary) {
      if (left_size_nx > 0) {
        GM2LM(src + thread_offset, in_temp, sizeof(Tx));
        dst[0] = static_cast<Ty>(in_temp[0]);
      }
    } else {
      GM2LM(src + thread_offset, in_temp, sizeof(Tx));
      dst[0] = static_cast<Ty>(in_temp[0]);
    }
  } else if (NX == 1) {  // for NX == 1 and NY != 1
#pragma unroll
    for (int idy = 0; idy < NY; ++idy) {
      if (IsBoundary) {
        if (idy * stride_ny >= size_ny) {
          break;
        }
      }
      GM2LM(src + thread_offset + idy * stride_ny, in_temp, sizeof(Tx));
      dst[idy] = static_cast<Ty>(in_temp[0]);
    }
  } else if (NY == 1) {  // for NY == 1 and NX != 1
#pragma unroll
    for (int idx = 0; idx < NX; ++idx) {
      if (IsBoundary) {
        if (idx * stride_nx >= left_size_nx) {
          break;
        }
      }
      GM2LM(src + thread_offset + idx * stride_nx, in_temp, sizeof(Tx));
      dst[idx] = static_cast<Ty>(in_temp[0]);
    }
  } else {  // for NX != 1 and NY != 1
#pragma unroll
    for (int idx = 0; idx < NX; ++idx) {
#pragma unroll
      for (int idy = 0; idy < NY; ++idy) {
        if (IsBoundary) {
          if (idy * stride_ny >= size_ny || idx * stride_nx >= left_size_nx) {
            break;
          }
        }
        int fix = thread_offset + idx * stride_nx + idy * stride_ny;
        GM2LM(src + fix, in_temp, sizeof(Tx));
        dst[idy * NX + idx] = static_cast<Ty>(in_temp[0]);
      }
    }
  }
}
/**
* @brief Initialize register with init_data.
*
* @template paraments
* T: Data type of register.
* NX: Number of data to initialize.
*
* @param:
* dst: The register pointer of the thread, the size is NX.
* init_data: Initial value.
*/
template <typename T, int NX>
__device__ __forceinline__ void Init(T* dst, T init_data) {
#pragma unroll
  for (int i = 0; i < NX; i++) {
    dst[i] = init_data;
  }
}
/**
 * @brief Read 1D data from global memory to registers. When IsBoundary = true
 * and (NX % 4 == 0 or NX % 2 == 0), vectorized loads are used to improve
 * memory access efficiency.
 *
 * @template parameters
 * T: The type of data.
 * NX: Each thread loads NX data from global memory continuously.
 * NY: Each thread needs to load NY rows, only NY = 1 is supported.
 * BlockSize: Identifies the current device thread index method. For xpu,
 * core_id() is used as the index.
 * IsBoundary: Whether to perform out-of-bounds checks on memory accesses.
 * When the number of data processed by this block is less than
 * NX x NY x core_num(), boundary checks are required to avoid reading past
 * the end.
 *
 * @param:
 * dst: The register pointer of the thread, the size is NX * NY.
 * src: The data pointer of the current block.
 * num: The current block needs to load num data continuously.
 */
template <typename T, int NX, int NY, int BlockSize, bool IsBoundary = false>
__device__ __forceinline__ void ReadData(T* dst, const T _global_ptr_* src,
                                         int num) {
  int thread_offset = core_id() * NX;
  __local__ T in_temp[1];
  if (IsBoundary) {  // core_num() * NX > num
#pragma unroll
    for (int idx = 0; idx < NX; ++idx) {
      if (idx + thread_offset < num) {
        GM2LM(src + thread_offset + idx, in_temp, sizeof(T));
        dst[idx] = in_temp[0];
      }
    }
  } else {  // core_num() * NX < num
    GM2LM(src + thread_offset, dst, NX * sizeof(T));
  }
}
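// Usage sketch (illustrative; src, block_offset and num are assumptions): the
// caller picks the instantiation by comparing the block's remaining element
// count against core_num() * NX, taking the guarded path only for the tail.
//
//   __local__ float buf[4];
//   if (num >= core_num() * 4) {
//     ReadData<float, 4, 1, 1, false>(buf, src + block_offset, num);  // bulk
//   } else {
//     ReadData<float, 4, 1, 1, true>(buf, src + block_offset, num);   // tail
//   }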
/**
 * @brief Read 2D data from global memory to registers in broadcast form.
 *
 * @template parameters
 * T: The type of data stored in the global memory.
 * NX: The number of data columns loaded by each thread.
 * NY: The number of data rows loaded by each thread.
 * BlockSize: Identifies the current device thread index method. For xpu,
 * core_id() is used as the index.
 * Rank: The shape size of out, e.g. for in[1, 35] and out[32, 35] the shape
 * size is 2.
 * IsBoundary: Indicates whether to perform out-of-bounds checks on memory
 * accesses. When the number of data processed by the block is less than
 * NX x NY x core_num(), boundary checks are required to avoid reading past
 * the end.
 *
 * @param:
 * dst: The register pointer of the thread, the size is NX * NY.
 * src: The raw input data pointer of the kernel.
 * block_offset: The data offset of this block, core_num() * cluster_id() * NX.
 * config: Calculation configuration of the broadcast. It is used to calculate
 * the coordinate mapping relationship between output data and input data.
 * total_num_output: Total number of original output elements.
 * stride_nx: The stride between two consecutive reads in the last dim.
 * stride_ny: The stride between two consecutive reads in the first dim.
 */
template <typename T, int NX, int NY, int BlockSize, int Rank,
          bool IsBoundary = false>
__device__ __forceinline__ void ReadDataBc(
    T* dst, const T _global_ptr_* src, uint32_t block_offset,
    details::BroadcastConfig<Rank> config, int total_num_output, int stride_nx,
    int stride_ny) {
  uint32_t thread_offset = block_offset + core_id();
  uint32_t index_src = 0;
  __local__ T in_temp[1];

#pragma unroll
  for (int ny = 0; ny < NY; ++ny) {
#pragma unroll
    for (uint32_t nx = 0; nx < NX; ++nx) {
      uint32_t index_output = thread_offset + ny * stride_ny + nx * stride_nx;
      index_src = 0;
      if (IsBoundary) {
        if (index_output >= total_num_output) {
          break;
        }
      }
#pragma unroll
      for (int i = 0; i < Rank; ++i) {
        uint32_t tmp = index_output / config.stride_out[i];
        index_output = index_output - tmp * config.stride_out[i];
        index_src += (tmp % config.shape_in[i]) * config.stride_in[i];
      }
      GM2LM(src + index_src, in_temp, sizeof(T));
      dst[nx + ny * NX] = in_temp[0];
    }
  }
}
/**
 * @brief Read 2D data from global memory to registers in reduce form.
 *
 * @template parameters
 * T: The type of data.
 * NX: The number of data columns loaded by each thread.
 * NY: The number of data rows loaded by each thread.
 * BlockSize: Identifies the current device thread index method. For xpu,
 * core_id() is used as the index.
 * Rank: The shape size of out, e.g. for in[1, 35] and out[32, 35] the shape
 * size is 2.
 * IsBoundary: Indicates whether to perform out-of-bounds checks on memory
 * accesses. When the number of data processed by the block is less than
 * NX x NY x core_num(), boundary checks are required to avoid reading past
 * the end.
 *
 * @param:
 * dst: The register pointer of the thread, the size is NX * NY.
 * src: The input data pointer of this block.
 * block_offset: The data offset of this block, core_num() * cluster_id() * NX.
 * index_cal: Calculation configuration of the reduce. It is used to calculate
 * the coordinate mapping relationship between output data and input data.
 * size_nx: The current block needs to load size_nx columns of data. Only used
 * when IsBoundary = true.
 * size_ny: The current block needs to load size_ny rows of data. Only used
 * when IsBoundary = true.
 * stride_nx: The stride between two consecutive reads in columns.
 * stride_ny: The stride between two consecutive reads in rows.
 * reduce_last_dim: Indicates whether the reduce dimensions contain the lowest
 * dimension.
 */
template <typename T, int NX, int NY, int BlockSize, int Rank,
          typename IndexCal, bool IsBoundary = false>
__device__ __forceinline__ void ReadDataReduce(
    T* dst, const T _global_ptr_* src, int block_offset,
    const IndexCal& index_cal, int size_nx, int size_ny, int stride_nx,
    int stride_ny, bool reduce_last_dim) {
  __local__ T in_temp[1];
  int thread_offset = 0;
  int left_size_nx = size_nx;
  int left_size_ny = size_ny;
  if (reduce_last_dim) {
    thread_offset = block_offset + core_id();
    left_size_nx -= thread_offset;
  } else {
    thread_offset = block_offset + core_id();
    left_size_ny -= thread_offset;
  }

  if (NX == 1) {
#pragma unroll
    for (int ny = 0; ny < NY; ++ny) {
      if (IsBoundary) {
        if (ny * stride_ny >= left_size_ny) {
          break;
        }
      }
      uint32_t index_src = index_cal(thread_offset);
      GM2LM(src + index_src, in_temp, sizeof(T));
      dst[ny] = in_temp[0];
      thread_offset += stride_ny;
    }
  } else {
#pragma unroll
    for (int nx = 0; nx < NX; ++nx) {
#pragma unroll
      for (int ny = 0; ny < NY; ++ny) {
        if (IsBoundary) {
          if ((ny * stride_ny >= left_size_ny) ||
              (nx * stride_nx >= left_size_nx)) {
            break;
          }
        }
        uint32_t index_src = index_cal(thread_offset);
        GM2LM(src + index_src, in_temp, sizeof(T));
        dst[nx + ny * NX] = in_temp[0];
        thread_offset += stride_ny;
      }
      thread_offset += stride_nx;
    }
  }
}
/**
 * @brief Write 1D data from registers to global memory. When IsBoundary = true
 * and (NX % 4 == 0 or NX % 2 == 0), the writes are vectorized to improve
 * memory access efficiency.
 *
 * @template parameters
 * T: The type of data.
 * NX: The number of data continuously written by each thread.
 * NY: The number of data rows loaded by each thread, only NY = 1 is supported.
 * BlockSize: Identifies the current device thread index method. For xpu,
 * core_id() is used as the index.
 * IsBoundary: Indicates whether to perform out-of-bounds checks on memory
 * accesses. When the number of data processed by the block is less than
 * NX x NY x core_num(), boundary checks are required to avoid writing past
 * the end.
 *
 * @param:
 * dst: The data pointer of the current block.
 * src: The register pointer, the size is NX * NY.
 * num: The current block needs to store num elements continuously.
 */
template <typename T, int NX, int NY, int BlockSize, bool IsBoundary>
__device__ void WriteData(T _global_ptr_* dst, const T* src, int num) {
  int thread_offset = core_id() * NX;
  __local__ T in_temp[1];
  if (IsBoundary) {  // core_num() * NX > num
#pragma unroll
    for (int idx = 0; idx < NX; ++idx) {
      if (idx + thread_offset < num) {
        in_temp[0] = src[idx];
        LM2GM(in_temp, dst + idx + thread_offset, sizeof(T));
      }
    }
  } else {  // core_num() * NX < num
    LM2GM(src, dst + thread_offset, NX * sizeof(T));
  }
}
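// Usage sketch (illustrative; src, dst, block_offset and num are assumptions):
// ReadData and WriteData pair up into a guarded copy of NX elements per core,
// with num as the element count remaining to this block.
//
//   __local__ float buf[4];
//   ReadData<float, 4, 1, 1, true>(buf, src + block_offset, num);
//   WriteData<float, 4, 1, 1, true>(dst + block_offset, buf, num);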
/**
 * @brief Write 2D data from registers to global memory according to Tx type,
 * and store it as Ty type.
 *
 * @template parameters
 * Tx: The type of data that needs to be stored in registers.
 * Ty: The type of data stored in the global memory.
 * NX: The number of data columns loaded by each thread.
 * NY: The number of data rows loaded by each thread.
 * BlockSize: Identifies the current device thread index method. For xpu,
 * core_id() is used as the index.
 * IsBoundary: Indicates whether to perform out-of-bounds checks on memory
 * accesses. When the number of data processed by the block is less than
 * NX x NY x core_num(), boundary checks are required to avoid writing past
 * the end.
 *
 * @param:
 * dst: The data pointer of the current block.
 * src: The register pointer of the thread, the size is NX * NY.
 * size_nx: The current block needs to store size_nx columns of data. Only
 * used when IsBoundary = true.
 * size_ny: The current block needs to store size_ny rows of data. Only used
 * when IsBoundary = true.
 * stride_nx: The stride between two consecutive writes in the last dim.
 * stride_ny: The stride between two consecutive writes in the first dim.
 */
template <typename Tx, typename Ty, int NX, int NY, int BlockSize,
          bool IsBoundary = false>
__device__ __forceinline__ void WriteData(Ty _global_ptr_* dst, const Tx* src,
                                          int size_nx, int size_ny,
                                          int stride_nx, int stride_ny) {
  int thread_offset = core_id();
  int left_size_nx = size_nx - thread_offset;
  __local__ Ty in_temp[1];
  // Each branch is added for better performance
  if (NX == 1 && NY == 1) {
    if (IsBoundary) {
      if (left_size_nx > 0) {
        in_temp[0] = static_cast<Ty>(src[0]);
        LM2GM(in_temp, dst + thread_offset, sizeof(Ty));
      }
    } else {
      in_temp[0] = static_cast<Ty>(src[0]);
      LM2GM(in_temp, dst + thread_offset, sizeof(Ty));
    }
  } else if (NX == 1) {
#pragma unroll
    for (int idy = 0; idy < NY; ++idy) {
      if (IsBoundary) {
        if (idy * stride_ny >= size_ny) {
          break;
        }
      }
      in_temp[0] = static_cast<Ty>(src[idy]);
      LM2GM(in_temp, dst + thread_offset + idy * stride_ny, sizeof(Ty));
    }
  } else if (NY == 1) {  // for NY == 1 and NX != 1
#pragma unroll
    for (int idx = 0; idx < NX; ++idx) {
      if (IsBoundary) {
        if (idx * stride_nx >= left_size_nx) {
          break;
        }
      }
      in_temp[0] = static_cast<Ty>(src[idx]);
      LM2GM(in_temp, dst + thread_offset + idx * stride_nx, sizeof(Ty));
    }
  } else {  // for NX != 1 and NY != 1
#pragma unroll
    for (int idx = 0; idx < NX; ++idx) {
      if (IsBoundary) {
        if (idx * stride_nx >= left_size_nx) {
          break;
        }
      }
#pragma unroll
      for (int idy = 0; idy < NY; ++idy) {
        if (IsBoundary) {
          if (idy * stride_ny >= size_ny) {
            break;
          }
        }
        in_temp[0] = static_cast<Ty>(src[idx + idy * NX]);
        LM2GM(in_temp, dst + thread_offset + idx * stride_nx + idy * stride_ny,
              sizeof(Ty));
      }
    }
  }
}
/**
 * @brief Initialize registers with the values in init_data.
 *
 * @template parameters
 * T: Data type of register.
 * NX: Number of data to initialize.
 * IsBoundary: Whether to check the index against num on each step.
 *
 * @param:
 * dst: The register pointer of the thread, the size is NX.
 * init_data: The register pointer of the init data, the size is NX.
 * num: The number of valid elements, only used when IsBoundary = true.
 */
template <typename T, int NX, bool IsBoundary = false>
__device__ __forceinline__ void Init(T* dst, T* init_data, int num) {
#pragma unroll
  for (int i = 0; i < NX; i++) {
    if (IsBoundary) {
      if (i >= num) {
        break;
      }
    }
    dst[i] = init_data[i];
  }
}
/**
 * @brief Read 1D data from global memory to registers in broadcast form.
 *
 * @template parameters
 * T: The type of data stored in the global memory.
 * NX: The number of data continuously loaded by each thread.
 * NY: The number of data rows loaded by each thread, only NY = 1 is supported.
 * BlockSize: Identifies the current device thread index method. For xpu,
 * core_id() is used as the index.
 * Rank: The shape size of out, e.g. for in[1, 35] and out[32, 35] the shape
 * size is 2.
 * IsBoundary: Indicates whether to perform out-of-bounds checks on memory
 * accesses. When the number of data processed by the block is less than
 * NX x NY x core_num(), boundary checks are required to avoid reading past
 * the end.
 *
 * @param:
 * dst: The register pointer of the thread, the size is NX * NY.
 * src: The original input data pointer of the kernel.
 * block_offset: The data offset of this block, core_num() * cluster_id() * NX.
 * config: Calculation configuration of the broadcast. It is used to calculate
 * the coordinate mapping relationship between output data and input data.
 * total_num_output: Total number of original output elements.
 */
template <typename T, int NX, int NY, int BlockSize, int Rank,
          bool IsBoundary = false>
__device__ __forceinline__ void ReadDataBc(
    T* dst, const T _global_ptr_* src, uint32_t block_offset,
    details::BroadcastConfig<Rank> config, int total_num_output) {
  uint32_t thread_offset = block_offset + core_id() * NX;
  uint32_t index_src = 0;
  __local__ T in_temp[1];

#pragma unroll
  for (uint32_t nx = 0; nx < NX; ++nx) {
    uint32_t index_output = thread_offset + nx;
    index_src = 0;
    if (IsBoundary) {
      if (index_output >= total_num_output) {
        break;
      }
    }
#pragma unroll
    for (int i = 0; i < Rank; ++i) {
      uint32_t tmp = index_output / config.stride_out[i];
      index_output = index_output - tmp * config.stride_out[i];
      index_src += (tmp % config.shape_in[i]) * config.stride_in[i];
    }
    GM2LM(src + index_src, in_temp, sizeof(T));
    dst[nx] = in_temp[0];
  }
}

}  // namespace kernel_primitives
}  // namespace operators
}  // namespace paddle
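Taken together, the two headers give XPU2 kernels the same primitive vocabulary as the CUDA kernel_primitives: Init, ReadData and WriteData move data between global memory and registers, while the Elementwise and Reduce routines compute on registers. A minimal sketch of a unary kernel built from them follows; the kernel signature, AbsFunctor and the NX choice are assumptions for illustration, not part of this commit.

// Illustrative XPU kernel body (assumes the XTDK toolchain provides
// __global__, __local__, _global_ptr_, core_num() and cluster_id(), and that
// HOSTDEVICE is available from Paddle's hostdevice.h).
#include "paddle/fluid/operators/kernel_primitives/compute_primitives_xpu2.h"
#include "paddle/fluid/operators/kernel_primitives/datamover_primitives_xpu2.h"

namespace kps = paddle::operators::kernel_primitives;

template <typename InT, typename OutT>
struct AbsFunctor {
  HOSTDEVICE OutT operator()(const InT& a) const { return a > 0 ? a : -a; }
};

// Each core handles NX elements of its block; the tail block takes the
// IsBoundary = true path so no access crosses the end of the buffers.
template <typename T, int NX>
__global__ void AbsKernel(const T _global_ptr_* x, T _global_ptr_* y, int num) {
  __local__ T in[NX];
  __local__ T out[NX];
  int block_offset = core_num() * cluster_id() * NX;
  int remain = num - block_offset;  // elements left for this block
  if (remain >= core_num() * NX) {
    kps::ReadData<T, NX, 1, 1, false>(in, x + block_offset, remain);
    kps::ElementwiseUnary<T, T, NX, 1, 1, AbsFunctor<T, T>>(
        out, in, AbsFunctor<T, T>());
    kps::WriteData<T, NX, 1, 1, false>(y + block_offset, out, remain);
  } else {
    kps::ReadData<T, NX, 1, 1, true>(in, x + block_offset, remain);
    kps::ElementwiseUnary<T, T, NX, 1, 1, AbsFunctor<T, T>>(
        out, in, AbsFunctor<T, T>());
    kps::WriteData<T, NX, 1, 1, true>(y + block_offset, out, remain);
  }
}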