Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
机器未来
Paddle
提交
89a8989f
P
Paddle
项目概览
机器未来
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
89a8989f
编写于
10月 29, 2021
作者:
N
niuliling123
提交者:
GitHub
10月 29, 2021
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Add io api and compute api for XPU (#36423)
上级
92d6a048
变更
2
展开全部
显示空白变更内容
内联
并排
Showing
2 changed file
with
891 addition
and
0 deletion
+891
-0
paddle/fluid/operators/kernel_primitives/compute_primitives_xpu2.h
...uid/operators/kernel_primitives/compute_primitives_xpu2.h
+324
-0
paddle/fluid/operators/kernel_primitives/datamover_primitives_xpu2.h
...d/operators/kernel_primitives/datamover_primitives_xpu2.h
+567
-0
未找到文件。
paddle/fluid/operators/kernel_primitives/compute_primitives_xpu2.h
0 → 100644
浏览文件 @
89a8989f
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "xpu/kernel/cluster_header.h"
#include "xpu/kernel/debug.h"
#include "xpu/kernel/math.h"
namespace
paddle
{
namespace
operators
{
namespace
kernel_primitives
{
namespace
details
{
// kGlobalMode: block reduce, each block gets an output;
// kLocalMode: thread reduce, each thread gets an output;
enum
ReduceMode
{
kGlobalMode
,
kLocalMode
};
template
<
typename
T
>
class
MPTypeTrait
{
public:
using
Type
=
T
;
};
template
<
>
class
MPTypeTrait
<
platform
::
float16
>
{
public:
using
Type
=
float
;
};
static
inline
__device__
void
sync_all
()
{
__asm__
__volatile__
(
"sync_local
\t\n
"
"csr_set csr3, %0
\t\n
"
"sync_group csr3"
::
"r"
(
-
1
));
}
#define ncores 64
template
<
typename
T
,
typename
OpFunc
,
int
VecSize
>
__device__
void
BlockXReduce
(
T
*
data
,
OpFunc
reducer
)
{
__shared__
T
sum_array
[
ncores
*
VecSize
];
int
core_idx
=
core_id
()
*
VecSize
;
mfence
();
sync_all
();
#pragma unroll
for
(
int
i
=
0
;
i
<
VecSize
;
i
++
)
{
mfence
();
sum_array
[
core_idx
+
i
]
=
data
[
i
];
mfence
();
data
[
i
]
=
0
;
}
sync_all
();
#pragma unroll
for
(
int
i
=
0
;
i
<
VecSize
;
i
++
)
{
#pragma unroll
for
(
int
j
=
0
;
j
<
ncores
;
j
++
)
{
mfence
();
T
tmp
=
sum_array
[
j
*
VecSize
+
i
];
mfence
();
data
[
i
]
=
reducer
(
data
[
i
],
tmp
);
mfence
();
}
}
sync_all
();
}
#undef ncores
}
// namespace details
/**
* @brief Perform unary calculation according to OpFunc. Shape of input and
* output are the same.
*
* @template paraments
* InT: The data type of in.
* OutT: The data type of out.
* NX: The number of data columns loaded by each thread.
* NY: The number of data rows loaded by each thread.
* BlockSize: Identifies the current device thread index method. For xpu,
* core_id() is used as the index.
* OpFunc: Compute functor which has an operator() as following:
* template <typename InT, typename OutT>
* struct XxxFunctor {
* HOSTDEVICE OutT operator()(const InT& a) const {
* return ...;
* }
* };
*
* @param:
* out: The register pointer of out, the size is NX * NY.
* in: The register pointer of in, the size is NX * NY.
* compute: Compute function which was declared like OpFunc<InT, OutT>().
*/
template
<
typename
InT
,
typename
OutT
,
int
NX
,
int
NY
,
int
BlockSize
,
class
OpFunc
>
__device__
__forceinline__
void
ElementwiseUnary
(
OutT
*
out
,
const
InT
*
in
,
OpFunc
compute
)
{
#pragma unroll
for
(
int
idx
=
0
;
idx
<
NX
*
NY
;
idx
++
)
{
out
[
idx
]
=
static_cast
<
OutT
>
(
compute
(
in
[
idx
]));
}
}
/**
* @brief Binary calculation according to OpFunc. Shape of The input and output
* are the same.
*
* @template paraments
* InT: The data type of in1 and in2.
* OutT: The data type of out.
* NX: The number of data columns computed by each thread.
* NY: The number of data rows computed by each thread.
* BlockSize: Identifies the current device thread index method. For xpu,
* core_id() is used as the index.
* OpFunc: Compute functor which has an operator() as following:
* template <typename InT>
* struct XxxFunctor {
* HOSTDEVICE InT operator()(const InT& a, const InT& b) const {
* return ...;
* }
* };
*
* @param:
* out: The register pointer of out, the size is NX * NY.
* in1: The register pointer of fist input, size is NX * NY.
* in2: The register pointer of second input, size is NX * NY.
* compute: Compute function which was declared like OpFunc<InT>().
*/
template
<
typename
InT
,
typename
OutT
,
int
NX
,
int
NY
,
int
BlockSize
,
class
OpFunc
>
__device__
__forceinline__
void
ElementwiseBinary
(
OutT
*
out
,
const
InT
*
in1
,
const
InT
*
in2
,
OpFunc
compute
)
{
#pragma unroll
for
(
int
idx
=
0
;
idx
<
NX
*
NY
;
++
idx
)
{
out
[
idx
]
=
static_cast
<
OutT
>
(
compute
(
in1
[
idx
],
in2
[
idx
]));
}
}
/**
* @brief Ternary calculation according to OpFunc. Shape of input and output
* are the same.
*
* @template paraments
* InT: The data type of in1 and in2.
* OutT: The data type of out.
* NX: The number of data columns loaded by each thread.
* NY: The number of data rows loaded by each thread.
* BlockSize: Identifies the current device thread index method. For xpu,
* core_id() is used as the index.
* OpFunc: Compute functor which has an operator() as following
* template <typename InT>
* struct XxxFunctor {
* HOSTDEVICE InT operator()(const InT& a, const InT& b, const InT& c)
* const {
* return ...;
* }
* };
*
* @param
* out: The register pointer of out, the size is NX * NY.
* in1: The register pointer of fist input, size is NX * NY.
* in2: The register pointer of second input, size is NX * NY.
* in3: The register pointer of third input, size is NX * NY.
* compute: Compute function which was declared like OpFunc<InT>().
*/
template
<
typename
InT
,
typename
OutT
,
int
NX
,
int
NY
,
int
BlockSize
,
class
OpFunc
>
__device__
__forceinline__
void
ElementwiseTernary
(
OutT
*
out
,
const
InT
*
in1
,
const
InT
*
in2
,
const
InT
*
in3
,
OpFunc
compute
)
{
#pragma unroll
for
(
int
idx
=
0
;
idx
<
NX
*
NY
;
++
idx
)
{
out
[
idx
]
=
static_cast
<
OutT
>
(
compute
(
in1
[
idx
],
in2
[
idx
],
in3
[
idx
]));
}
}
/**
* @brief Multivariate calculation according to OpFunc. Shape of inputs and
* output are the same.
*
* @template paraments
* InT: The data type of in1, in2 and in3.
* OutT: The data type of out.
* NX: The number of data columns loaded by each thread.
* NY: The number of data rows loaded by each thread.
* BlockSize: Identifies the current device thread index method. For xpu,
* core_id() is used as the index.
* Arity: The size of ins
* OpFunc: Compute functor which has an operator() as following:
* template <typename InT>
* struct XxxFunctor {
* HOSTDEVICE InT operator()(const InT* args) const {
* return ...;
* }
* };
*
* @param
* out: The register pointer of out, the size is NX * NY.
* ins: A pointers of array consisting of multiple inputs.
* compute: Compute function which was declared like OpFunc<InT>().
*/
template
<
typename
InT
,
typename
OutT
,
int
NX
,
int
NY
,
int
BlockSize
,
int
Arity
,
class
OpFunc
>
__device__
__forceinline__
void
ElementwiseAny
(
OutT
*
out
,
InT
(
*
ins
)[
NX
*
NY
],
OpFunc
compute
)
{
__local__
InT
args
[
Arity
];
#pragma unroll
for
(
int
idx
=
0
;
idx
<
NX
*
NY
;
++
idx
)
{
#pragma unroll
for
(
int
j
=
0
;
j
<
Arity
;
++
j
)
{
args
[
j
]
=
ins
[
j
][
idx
];
}
out
[
idx
]
=
static_cast
<
OutT
>
(
compute
(
args
));
}
}
/**
* @brief Binary calculation according to OpFunc. The shape of in1 and in2 are
* different. When in1's shape is [1, NX], in2's shape is [NY, NX], then
* output's shape is [NY, NX].
*
* @template paraments
* InT: The data type of in1 and in2.
* OutT: The data type of out.
* NX: The number of data columns loaded by each thread.
* NY: The number of data rows loaded by each thread.
* BlockSize: Identifies the current device thread index method. For xpu,
* core_id() is used as the index.
* OpFunc: Compute functor which has an operator() as following
* template <typename InT, typename OutT>
* struct XxxFunctor {
* HOSTDEVICE OutT operator()(const InT& a, const InT& b) const {
* return ...;
* }
* };
*
* @param
* out: The register pointer of out, the size is NX * NY.
* in1: The register pointer of fist input, size is NX * 1.
* in2: The register pointer of second input, size is NX * NY.
* compute: Compute function which was declared like OpFunc<InT, OutT>().
*/
template
<
typename
InT
,
typename
OutT
,
int
NX
,
int
NY
,
int
BlockSize
,
class
OpFunc
>
__device__
__forceinline__
void
CycleBinary
(
OutT
*
out
,
const
InT
*
in1
,
const
InT
*
in2
,
OpFunc
compute
)
{
#pragma unroll
for
(
int
idx
=
0
;
idx
<
NX
;
idx
++
)
{
#pragma unroll
for
(
int
idy
=
0
;
idy
<
NY
;
idy
++
)
{
out
[
idx
+
idy
*
NX
]
=
static_cast
<
OutT
>
(
compute
(
in1
[
idx
],
in2
[
idx
+
idy
*
NX
]));
}
}
}
/**
* @brief The Reduce provides collective methods for computing a parallel
* reduction of items partitioned across a CUDA block and intra thread. When
* ReduceMode == kLocalMode, thread reduce along nx. When ReduceMode ==
* kGlobalMode, use shared memory to reduce between threads.
*
* @template paraments
* T: The type of data.
* NX: The number of data continuously loaded by each thread.
* NY: The number of data rows loaded by each thread, only NY = 1 was supported.
* BlockSize: Identifies the current device thread index method. For xpu,
* core_id() is used as the index.
* ReduceFunctor: Compute functor which has an operator() as following
* template <typename InT>
* struct ReduceFunctor {
* HOSTDEVICE InT operator()(const InT& a, const InT& b) const {
* return ...;
* }
* };
* ReduceMode: Reduce mode, can be kLocalMode, kGlobalMode.
*
* @param
* out: The register pointer of out, the size is NX * NY.
* in: The register pointer of in, the size is NX * NY.
* reducer: Compute function which was declared like ReduceFunctor<InT>().
* reduce_last_dim: if the last dim gets involved in reduction.
*/
template
<
typename
T
,
int
NX
,
int
NY
,
int
BlockSize
,
class
ReduceFunctor
,
details
::
ReduceMode
Mode
>
__device__
__forceinline__
void
Reduce
(
T
*
out
,
const
T
*
in
,
ReduceFunctor
reducer
,
bool
reduce_last_dim
)
{
if
(
Mode
==
kGlobalMode
)
{
#pragma unroll
for
(
int
i
=
0
;
i
<
NY
;
++
i
)
{
#pragma unroll
for
(
int
j
=
0
;
j
<
NX
;
++
j
)
{
out
[
i
]
=
reducer
(
out
[
i
],
in
[
i
*
NX
+
j
]);
}
}
BlockXReduce
<
T
,
OpFunc
,
NY
>
(
out
,
reducer
);
}
else
{
// else kLocalMode
#pragma unroll
for
(
int
i
=
0
;
i
<
NY
;
++
i
)
{
#pragma unroll
for
(
int
j
=
0
;
j
<
NX
;
++
j
)
{
out
[
i
]
=
reducer
(
out
[
i
],
in
[
i
*
NX
+
j
]);
}
}
}
}
}
// namespace kernel_primitives
}
// namespace operators
}
// namespace paddle
paddle/fluid/operators/kernel_primitives/datamover_primitives_xpu2.h
0 → 100644
浏览文件 @
89a8989f
此差异已折叠。
点击以展开。
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录