机器未来 / Paddle (forked from PaddlePaddle / Paddle)
Unverified commit 8f9d573f
Authored on Aug 10, 2021 by niuliling123; committed by GitHub on Aug 10, 2021

Kernel primitives api (#34672)
Add kernel primitives API: ReadData, WriteData, ComputeFunctor
Parent: 8b9bd165
Showing 2 changed files with 330 additions and 2 deletions (+330 -2):
paddle/fluid/operators/kernel_primitives/compute_primitives.h   +133 -1
paddle/fluid/operators/kernel_primitives/datamover_primitives.h  +197 -1
paddle/fluid/operators/kernel_primitives/compute_primitives.h
@@ -14,8 +14,140 @@
#pragma once

#ifdef PADDLE_WITH_CUDA
#include <cuda_fp16.h>
#endif
#ifdef PADDLE_WITH_HIP
#include <hip/hip_fp16.h>
#endif

#include <algorithm>
#include "paddle/fluid/platform/float16.h"

namespace paddle {
namespace operators {
namespace kernel_primitives {
namespace details {
template <typename T>
class MPTypeTrait {
 public:
  using Type = T;
};

template <>
class MPTypeTrait<platform::float16> {
 public:
  using Type = float;
};

}  // namespace details
/*************************** Compute Functor****************************/
template <typename T, typename Enable = void>
struct DivFunctor {
  inline HOSTDEVICE T operator()(const T* args) const {
    return args[0] / args[1];
  }
};

template <typename T>
struct DivFunctor<T, typename std::enable_if_t<std::is_integral<T>::value>> {
  inline HOSTDEVICE T operator()(const T* args) const {
    PADDLE_ENFORCE(args[1] != 0,
                   platform::errors::InvalidArgument(
                       "Invalid Argument Error: Integer division by zero "
                       "encountered in divide. Please check the input value."));
    return args[0] / args[1];
  }
};
/*************************** Compute Function****************************/

/**
 * @brief compute functor for elementwise_two; in1 and in2 have the same shape
 * @param:
 * T : the type of in1 and in2
 * NX: the row of in1 and in2
 * NY: the col of in1 and in2
 * BlockSize: the stride of col
 * OpFunc: compute functor, e.g. ADD, SUB, XOR, OR, MUL
 */
template <typename T, typename OutT, int NX, int NY, int BlockSize,
          class OpFunc>
__device__ __forceinline__ void ElementwiseBinary(OutT* out, const T* in1,
                                                  const T* in2,
                                                  OpFunc compute) {
  T args[2];
#pragma unroll
  for (int idx = 0; idx < NX * NY; ++idx) {
    args[0] = in1[idx];
    args[1] = in2[idx];
    out[idx] = static_cast<OutT>(compute(args));
  }
}
/**
 * @brief fma, e.g. a * b + c; in1, in2, in3 and out have the same shape
 * @param:
 * T : the type of in1, in2 and in3
 * NX: the row of in1, in2 and in3
 * NY: the col of in1, in2 and in3
 * BlockSize: the stride of col
 */
template <typename T, typename OutT, int NX, int NY, int BlockSize,
          class OpFunc>
__device__ __forceinline__ void ElementwiseFma(OutT* out, const T* in1,
                                               const T* in2, const T* in3,
                                               OpFunc compute) {
#pragma unroll
  for (int idx = 0; idx < NX * NY; ++idx) {
    out[idx] = static_cast<OutT>(compute(in1[idx], in2[idx], in3[idx]));
  }
}
/**
 * @brief compute functor for elementwise_two; in1 is [1, NY], in2 is [NX, NY]
 * @param:
 * T : the type of in1 and in2
 * NX: the row of in1 and in2
 * NY: the col of in2
 * BlockSize: the stride of col
 * OpFunc: compute functor, e.g. ADD, SUB, XOR, OR, MUL
 */
template <typename T, typename OutT, int NX, int NY, int BlockSize,
          class OpFunc>
__device__ __forceinline__ void CycleBinary(OutT* out, const T* in1,
                                            const T* in2, OpFunc compute) {
#pragma unroll
  for (int idx = 0; idx < NX; idx++) {
#pragma unroll
    for (int idy = 0; idy < NY; idy++) {
      out[idx + idy * NX] =
          static_cast<OutT>(compute(in1[idx], in2[idx + idy * NX]));
    }
  }
}
/**
 * @brief compute functor for unary; in is [NX, NY]
 * @param:
 * T : the type of in
 * NX: the row of in
 * NY: the col of in
 * BlockSize: the stride of col
 * OpFunc: compute functor, e.g. relu, sigmoid, exp
 */
template <typename T, typename OutT, int NX, int NY, int BlockSize,
          class OpFunc>
__device__ __forceinline__ void ElementwiseUnary(OutT* out, const T* in,
                                                 OpFunc compute) {
#pragma unroll
  for (int idx = 0; idx < NX * NY; idx++) {
    out[idx] = static_cast<OutT>(compute(in + idx));
  }
}

}  // namespace kernel_primitives
}  // namespace operators
}  // namespace paddle
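
The compute primitives above are parameterized by a user-supplied functor. As a hedged illustration (not part of this diff), any functor with the same pointer-based call operator as DivFunctor can be passed to ElementwiseBinary; the AddFunctor name and the literal template arguments below are hypothetical:

// Hypothetical functor with the same calling convention as DivFunctor:
// ElementwiseBinary packs the two operands into a small array and passes
// a pointer to it.
template <typename T>
struct AddFunctor {
  inline HOSTDEVICE T operator()(const T* args) const {
    return args[0] + args[1];
  }
};

// Inside a __global__ kernel, with in1, in2 and out being per-thread
// register tiles of NX * NY elements each:
//
//   namespace kp = paddle::operators::kernel_primitives;
//   kp::ElementwiseBinary<float, float, NX, NY, BlockSize, AddFunctor<float>>(
//       out, in1, in2, AddFunctor<float>());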
paddle/fluid/operators/kernel_primitives/datamover_primitives.h
@@ -13,9 +13,205 @@
// limitations under the License.
#pragma once
#include <cuda.h>
#include <cuda_fp16.h>
#include <math.h>
#include <iostream>
#include <vector>

namespace paddle {
namespace operators {
namespace kernel_primitives {
namespace details {
#define INT_BITS 32

template <typename T, int VecSize>
struct alignas(sizeof(T) * VecSize) VectorType {
  T val[VecSize];
};
struct FastDivMod {
  // 1st value represents the result of input number divides by recorded divisor
  // 2nd value represents the result of input number modulo by recorded divisor
  using DivModT = VectorType<uint32_t, 2>;

  FastDivMod() {}

  HOSTDEVICE FastDivMod(uint32_t d) : divisor(d) {
    static_assert(sizeof(unsigned int) == 4,
                  "Only Support 32-bit unsigned int.");

    for (shift_val = 0; shift_val < INT_BITS; ++shift_val) {
      auto shift_limit = 1 << shift_val;
      if (shift_limit >= divisor) break;
    }
    uint64_t long_one = 1;
    uint64_t temp_div =
        ((long_one << INT_BITS) * ((long_one << shift_val) - divisor)) /
            divisor +
        1;
    multiplier = temp_div;
  }

  __device__ __forceinline__ uint32_t Div(uint32_t n) const {
    uint32_t t = __umulhi(n, multiplier);
    return (t + n) >> shift_val;
  }

  __device__ __forceinline__ DivModT Divmod(uint32_t n) const {
    uint32_t q = Div(n);
    DivModT result = {q, n - q * divisor};
    return result;
  }

  int32_t divisor;
  int32_t shift_val;
  uint32_t multiplier;
};
template <int kDims>
struct BroadcastConfig {
  FastDivMod divmoders[kDims];
  uint32_t strides[framework::DDim::kMaxRank];
  HOSTDEVICE BroadcastConfig() {}

  HOSTDEVICE BroadcastConfig(const std::vector<int64_t>& out_dims,
                             const std::vector<int64_t>& in_dims,
                             int dim_size) {
    std::vector<uint32_t> strides_in;
    std::vector<FastDivMod> divmoders_in;

    // for divmoders
    divmoders_in.resize(dim_size);
    for (int i = 0; i < dim_size; ++i) {
      divmoders_in[i] = FastDivMod(out_dims[i]);
    }

    // for strides
    strides_in.resize(dim_size, 1);
    for (int i = 0; i < dim_size; ++i) {
      strides_in[i] = in_dims[i] == 1 ? 0 : strides_in[i];
      strides_in[i] =
          (i != 0 && strides_in[i] != 0)
              ? std::accumulate(in_dims.begin(), in_dims.begin() + i, 1,
                                std::multiplies<int64_t>())
              : strides_in[i];
    }

    memcpy(strides, strides_in.data(), kDims * sizeof(uint32_t));
    memcpy(divmoders, divmoders_in.data(), kDims * sizeof(FastDivMod));
  }
};
#undef INT_BITS

}  // namespace details
template <typename T, int NX, int NY, int BlockSize>
__device__ __forceinline__ void ReadDataBase(T* dst, const T* __restrict__ src,
                                             int size) {
  int dx = threadIdx.x * NX;
#pragma unroll
  for (int idx = 0; idx < NX; ++idx) {
    if ((idx + dx) >= size) {
      break;
    }
    dst[idx] = src[idx + dx];
  }
}
template <typename T, int NX, int NY, int BlockSize>
__device__ __forceinline__ void ReadData(T* dst, const T* __restrict__ src,
                                         int size) {
  const int VECTOR_SIZE = (NX % 4 == 0) ? 4 : (NX % 2 == 0) ? 2 : 1;
  const int VECTORS_PER_THREAD = NX / VECTOR_SIZE;

  // Vector per thread
  if (blockDim.x * NX > size) {
    ReadDataBase<T, NX, NY, BlockSize>(dst, src, size);
  } else {
    // Vector type
    using VecType = details::VectorType<T, VECTOR_SIZE>;
    VecType vec_temp[VECTORS_PER_THREAD];
    const VecType* vec_input = reinterpret_cast<const VecType*>(src);
    ReadDataBase<VecType, VECTORS_PER_THREAD, NY, BlockSize>(
        vec_temp, vec_input, VECTORS_PER_THREAD * blockDim.x);
#pragma unroll
    for (int idx = 0; idx < NX; ++idx) {
      dst[idx] = *(reinterpret_cast<T*>(vec_temp) + idx);
    }
  }
}
/** @brief: ReadDataBc
 * read data from src ptr when the shapes of src and dst are different
 * @param:
 * src: the source pointer
 * dst: the dst pointer
 * stride_nx: the stride of src along nx
 * stride_ny: the stride of src along ny
 * the shape of dst is [NY, NX]
 */
template <typename T, int NX, int NY, int BlockSize, int ShapeSize>
__device__ __forceinline__ void ReadDataBc(
    T* dst, const T* __restrict__ src, uint32_t fix,
    details::BroadcastConfig<ShapeSize> config, int num, int stride_nx,
    int stride_ny) {
  uint32_t base_offset = fix + threadIdx.x * NX;
  uint32_t offset = 0;

#pragma unroll
  for (int ny = 0; ny < NY; ++ny) {
#pragma unroll
    for (uint32_t nx = 0; nx < NX; ++nx) {
      uint32_t idx = base_offset + ny * stride_ny + nx * stride_nx;
      if (idx < num) {
        offset = 0;
#pragma unroll
        for (int i = 0; i < ShapeSize; ++i) {
          auto fast_divmoder = config.divmoders[i].Divmod(idx);
          idx = fast_divmoder.val[0];
          offset += fast_divmoder.val[1] * config.strides[i];
        }
        dst[nx + ny * NX] = src[offset];
      }
    }
  }
}
template <typename T, int NX, int NY, int BlockSize>
__device__ __forceinline__ void WriteDataBase(T* dst,
                                              const T* __restrict__ src,
                                              int size) {
  int dx = threadIdx.x * NX;
#pragma unroll
  for (int idx = 0; idx < NX; ++idx) {
    if ((idx + dx) >= size) {
      break;
    }
    dst[idx + dx] = src[idx];
  }
}
template <typename T, int NX, int NY, int BlockSize>
__device__ __forceinline__ void WriteData(T* dst, T* __restrict__ src,
                                          int size) {
  const int VECTOR_SIZE = (NX % 4 == 0) ? 4 : (NX % 2 == 0) ? 2 : 1;
  const int VECTORS_PER_THREAD = NX / VECTOR_SIZE;

  // Vector per thread
  if (blockDim.x * NX > size) {
    WriteDataBase<T, NX, NY, BlockSize>(dst, src, size);
  } else {
    // Vector type
    using VecType = details::VectorType<T, VECTOR_SIZE>;
    VecType vec_temp[VECTORS_PER_THREAD];
#pragma unroll
    for (int idx = 0; idx < VECTORS_PER_THREAD; ++idx) {
      vec_temp[idx] = *(reinterpret_cast<VecType*>(src) + idx);
    }
    VecType* vec_dst = reinterpret_cast<VecType*>(dst);
    WriteDataBase<VecType, VECTORS_PER_THREAD, NY, BlockSize>(
        vec_dst, vec_temp, VECTORS_PER_THREAD * blockDim.x);
  }
}

}  // namespace kernel_primitives
}  // namespace operators
}  // namespace paddle
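
A minimal end-to-end sketch, not part of this commit, of how the new data movers and compute primitives are meant to compose: each thread of a hypothetical elementwise-divide kernel loads an NX-element register tile with ReadData (vectorized whenever the block's slice is full), applies DivFunctor through ElementwiseBinary, and stores the tile with WriteData. The kernel name, tiling scheme, and launch assumptions are illustrative only.

#include "paddle/fluid/operators/kernel_primitives/compute_primitives.h"
#include "paddle/fluid/operators/kernel_primitives/datamover_primitives.h"

namespace kp = paddle::operators::kernel_primitives;

// Hypothetical kernel: assumes the grid is sized so every block owns a
// (possibly partial) slice of blockDim.x * NX elements of the input.
template <typename T, int NX, int BlockSize>
__global__ void ElementwiseDivKernel(T* out, const T* x, const T* y, int n) {
  int block_offset = blockIdx.x * blockDim.x * NX;
  int remain = n - block_offset;  // valid elements in this block's slice

  T x_tile[NX];
  T y_tile[NX];
  T out_tile[NX];

  // ReadData/WriteData take a pointer already offset to the block's slice
  // and the number of valid elements left in it.
  kp::ReadData<T, NX, 1, BlockSize>(x_tile, x + block_offset, remain);
  kp::ReadData<T, NX, 1, BlockSize>(y_tile, y + block_offset, remain);

  kp::ElementwiseBinary<T, T, NX, 1, BlockSize, kp::DivFunctor<T>>(
      out_tile, x_tile, y_tile, kp::DivFunctor<T>());

  kp::WriteData<T, NX, 1, BlockSize>(out + block_offset, out_tile, remain);
}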