Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
1354652b
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
1354652b
编写于
2月 17, 2022
作者:
N
niuliling123
提交者:
GitHub
2月 17, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Modified distribution kernel with Kernel Primitive API (#39563)
上级
a909bdf1
变更
2
隐藏空白更改
内联
并排
Showing
2 changed file
with
74 addition
and
14 deletion
+74
-14
paddle/fluid/operators/distribution_helper.h
paddle/fluid/operators/distribution_helper.h
+21
-14
paddle/pten/kernels/primitive/compute_primitives.h
paddle/pten/kernels/primitive/compute_primitives.h
+53
-0
未找到文件。
paddle/fluid/operators/distribution_helper.h
浏览文件 @
1354652b
...
@@ -28,6 +28,10 @@ limitations under the License. */
...
@@ -28,6 +28,10 @@ limitations under the License. */
#include "paddle/fluid/platform/for_range.h"
#include "paddle/fluid/platform/for_range.h"
#include "paddle/pten/core/hostdevice.h"
#include "paddle/pten/core/hostdevice.h"
#if defined(__NVCC__) || defined(__HIPCC__)
#include "paddle/pten/kernels/primitive/kernel_primitives.h"
#endif
#if !defined(_WIN32)
#if !defined(_WIN32)
#define UNLIKELY(condition) __builtin_expect(static_cast<bool>(condition), 0)
#define UNLIKELY(condition) __builtin_expect(static_cast<bool>(condition), 0)
#else
#else
...
@@ -91,6 +95,8 @@ struct normal_transform {
...
@@ -91,6 +95,8 @@ struct normal_transform {
#if defined(__NVCC__) || defined(__HIPCC__)
#if defined(__NVCC__) || defined(__HIPCC__)
namespace
kps
=
pten
::
kps
;
/*********************** Distribution Function *************************/
/*********************** Distribution Function *************************/
template
<
typename
T
>
template
<
typename
T
>
struct
uniform_distribution
;
struct
uniform_distribution
;
...
@@ -176,25 +182,26 @@ template <typename T, typename DistOp, typename TransformOp>
...
@@ -176,25 +182,26 @@ template <typename T, typename DistOp, typename TransformOp>
__global__
void
DistributionKernel
(
size_t
size
,
uint64_t
seed
,
uint64_t
offset
,
__global__
void
DistributionKernel
(
size_t
size
,
uint64_t
seed
,
uint64_t
offset
,
DistOp
dist
,
TransformOp
trans
,
DistOp
dist
,
TransformOp
trans
,
T
*
out_data
)
{
T
*
out_data
)
{
size_t
idx
=
static_cast
<
size_t
>
(
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
);
size_t
idx
=
static_cast
<
size_t
>
(
BLOCK_ID_X
*
BLOCK_NUM_X
);
int32_t
returns_c
ount
=
DistOp
::
kReturnsCount
;
static
constexpr
int
kC
ount
=
DistOp
::
kReturnsCount
;
#if defined(__NVCC__)
#if defined(__NVCC__)
curandStatePhilox4_32_10_t
state
;
curandStatePhilox4_32_10_t
state
;
curand_init
(
seed
,
idx
,
offset
,
&
state
);
curand_init
(
seed
,
idx
+
THREAD_ID_X
,
offset
,
&
state
);
using
SType
=
curandStatePhilox4_32_10_t
;
#else
#else
hiprandStatePhilox4_32_10_t
state
;
hiprandStatePhilox4_32_10_t
state
;
hiprand_init
(
seed
,
idx
,
offset
,
&
state
);
hiprand_init
(
seed
,
idx
+
THREAD_ID_X
,
offset
,
&
state
);
using
SType
=
hiprandStatePhilox4_32_10_t
;
#endif
#endif
size_t
total_thread
=
gridDim
.
x
*
blockDim
.
x
;
size_t
total_thread
=
GRID_NUM_X
*
BLOCK_NUM_X
;
for
(
size_t
i
=
idx
;
i
<
size
;
i
+=
total_thread
*
returns_count
)
{
T
args
[
kCount
];
auto
random_tuple
=
dist
(
&
state
);
T
result
[
kCount
];
for
(
size_t
j
=
0
;
j
<
returns_count
;
j
++
)
{
for
(
size_t
i
=
idx
;
i
<
size
;
i
+=
total_thread
*
kCount
)
{
size_t
index
=
i
+
j
*
total_thread
;
kps
::
ElementwiseRandom
<
SType
,
T
,
kCount
,
1
,
DistOp
>
(
&
args
[
0
],
dist
,
&
state
);
if
(
index
<
size
)
{
kps
::
ElementwiseUnary
<
T
,
T
,
kCount
,
1
,
1
,
TransformOp
>
(
&
result
[
0
],
&
args
[
0
],
auto
random
=
(
&
random_tuple
.
x
)[
j
];
trans
);
out_data
[
index
]
=
static_cast
<
T
>
(
trans
(
random
));
kps
::
WriteData
<
T
,
T
,
kCount
,
1
,
1
,
true
>
(
out_data
+
i
,
&
result
[
0
],
size
-
i
,
}
1
,
total_thread
,
1
);
}
}
}
}
}
...
...
paddle/pten/kernels/primitive/compute_primitives.h
浏览文件 @
1354652b
...
@@ -428,5 +428,58 @@ __device__ __forceinline__ void ElementwiseConstant(OutT* out, OpFunc compute) {
...
@@ -428,5 +428,58 @@ __device__ __forceinline__ void ElementwiseConstant(OutT* out, OpFunc compute) {
}
}
}
}
/**
 * @brief Call `compute` once on `state` and copy the components of the value
 * it returns into `out`.
 *
 * The object returned by `compute(state)` is read through the address of its
 * member `x`: element k of the result is `(&result.x)[k]`, converted to OutT.
 *
 * @tparam StateType    Type of the object `state` points to.
 * @tparam OutT         Element type written to `out`.
 * @tparam ReturnsCount Number of components copied into `out`.
 * @tparam BlockSize    Not referenced in this function body.
 * @tparam OpFunc       Callable accepting `StateType*`; its result must expose
 *                      a member named `x`.
 *
 * @param out     Destination buffer; elements [0, ReturnsCount) are written.
 * @param compute Functor invoked exactly once with `state`.
 * @param state   Pointer forwarded unchanged to `compute`.
 */
template <typename StateType,
          typename OutT,
          int ReturnsCount,
          int BlockSize,
          class OpFunc>
__device__ __forceinline__ void ElementwiseRandom(OutT* out,
                                                  OpFunc compute,
                                                  StateType* state) {
  auto sample = compute(state);
  // NOTE(review): indexing past `x` assumes the returned type stores at least
  // ReturnsCount same-typed members contiguously starting at `x` - confirm
  // for every OpFunc used with this primitive.
  auto* components = &sample.x;
#pragma unroll
  for (int k = 0; k < ReturnsCount; ++k) {
    out[k] = static_cast<OutT>(components[k]);
  }
}
// Attention: please set shared_size = blockDim.x;
// data and b are register pointers
#define shared_size 64
/**
 * @brief Block-cooperative running sum over two values per thread, staged
 * through a `__shared__` buffer.
 *
 * Each thread stores `in[0]` at logical slot `threadIdx.x` and `in[1]` at
 * logical slot `shared_size + threadIdx.x`. Two passes then accumulate
 * partial sums with `+=`: the first with a doubling stride, the second with a
 * halving stride. Each pass iteration begins with `__syncthreads()`. After a
 * final barrier every thread reads its two slots back into `out[0]` and
 * `out[1]`.
 *
 * A logical slot i is stored at `temp[i + i / 32]`, i.e. one unused element
 * is skipped every 32 slots (presumably to spread shared-memory accesses
 * across banks - TODO confirm).
 *
 * NOTE(review): the buffer holds `shared_size * 2` logical slots and the
 * second input is placed at offset `shared_size`, while the loop bounds use
 * `blockDim.x`; the code therefore appears to require
 * `blockDim.x == shared_size` - confirm with callers.
 *
 * @tparam InT       Element type of `in` and of the shared buffer.
 * @tparam OutT      Element type written to `out`.
 * @tparam NX        Not referenced in this function body.
 * @tparam NY        Not referenced in this function body.
 * @tparam BlockSize Not referenced in this function body.
 * @tparam OpFunc    Type of `compute`.
 *
 * @param out     Receives two results: `out[0]` and `out[1]`.
 * @param in      Supplies two inputs: `in[0]` and `in[1]`.
 * @param compute Not referenced in this function body; accumulation is
 *                hard-coded as `+=`.
 */
template <typename InT,
          typename OutT,
          int NX,
          int NY,
          int BlockSize,
          class OpFunc>
__device__ __forceinline__ void Cumsum(OutT* out,
                                       const InT* in,
                                       OpFunc compute) {
  __shared__ InT temp[shared_size * 2 + (shared_size * 2) / 32];
  const int lane = threadIdx.x;

  // Logical slots owned by this thread for its first and second element.
  const int lo_slot = lane;
  const int hi_slot = shared_size + lane;
  temp[lo_slot + lo_slot / 32] = in[0];
  temp[hi_slot + hi_slot / 32] = in[1];

  // First pass: stride doubles each iteration; a slot accumulates the slot
  // `step` positions before it.
  for (int step = 1; step <= blockDim.x; step *= 2) {
    __syncthreads();
    const int dst = (lane + 1) * 2 * step - 1;
    if (dst < (blockDim.x * 2)) {
      const int src = dst - step;
      temp[dst + dst / 32] += temp[src + src / 32];
    }
  }

  // Second pass: stride halves each iteration; a slot's value is added into
  // the slot `step` positions after it.
  for (int step = (blockDim.x * 2) / 4; step > 0; step /= 2) {
    __syncthreads();
    const int src = (lane + 1) * 2 * step - 1;
    if ((src + step) < (blockDim.x * 2)) {
      const int dst = src + step;
      temp[dst + dst / 32] += temp[src + src / 32];
    }
  }

  // All accumulation must be visible before any thread reads its results.
  __syncthreads();
  out[0] = static_cast<OutT>(temp[lo_slot + lo_slot / 32]);
  out[1] = static_cast<OutT>(temp[hi_slot + hi_slot / 32]);
}
}
// namespace kps
}
// namespace kps
}
// namespace pten
}
// namespace pten
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录