Unverified commit 82b33be3

Authored by niuliling123 on Sep 08, 2021; committed by GitHub on Sep 08, 2021.
Modify the reduce op according to the kernel primitive api (#35282)
Parent: 7aa4d879
Showing 7 changed files with 293 additions and 361 deletions (+293, -361):
paddle/fluid/operators/fused/attn_bias_add.cu.h                   +2    -2
paddle/fluid/operators/kernel_primitives/compute_primitives.h     +5    -3
paddle/fluid/operators/kernel_primitives/datamover_primitives.h   +61   -60
paddle/fluid/operators/kernel_primitives/helper_primitives.h      +1    -1
paddle/fluid/operators/margin_cross_entropy_op.cu                 +4    -34
paddle/fluid/operators/reduce_ops/reduce_functor_op.h             +9    -7
paddle/fluid/operators/reduce_ops/reduce_op.cu.h                  +211  -254
paddle/fluid/operators/fused/attn_bias_add.cu.h

@@ -202,9 +202,9 @@ void SetConfigForColumnReduce(const int max_threads, const int reduce_num,
   int num_block = (max_threads / left_num);

   if (num_block > 1 && reduce_num >= REDUCE_SPLIT_BOUNDARY) {
-    *blocking_size = detail::GetLastPow2(reduce_num / num_block);
+    *blocking_size = details::GetLastPow2(reduce_num / num_block);
     if (*blocking_size <= 1) {
-      *blocking_size = detail::GetLastPow2(sqrt(reduce_num));
+      *blocking_size = details::GetLastPow2(sqrt(reduce_num));
     } else if (*blocking_size * 2 < reduce_num) {
       *blocking_size *= 2;
     }
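As context for the rename: GetLastPow2 rounds an integer down to the largest power of two that does not exceed it, which keeps blocking_size a power of two. A minimal standalone sketch of that helper (the real definition lives in reduce_op.cu.h, shown later in this diff):

#include <algorithm>
#include <cassert>

// Round n down to the largest power of two <= n by smearing the high
// bit into all lower positions, then keeping only the top bit.
static inline int GetLastPow2(int n) {
  n |= (n >> 1);
  n |= (n >> 2);
  n |= (n >> 4);
  n |= (n >> 8);
  n |= (n >> 16);
  return std::max(1, n - (n >> 1));
}

int main() {
  assert(GetLastPow2(5) == 4);
  assert(GetLastPow2(64) == 64);
  assert(GetLastPow2(1000) == 512);
  return 0;
}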
paddle/fluid/operators/kernel_primitives/compute_primitives.h

@@ -31,13 +31,15 @@ namespace kernel_primitives {
 namespace details {

 #ifdef __HIPCC__
-constexpr int kMaxThread = 256;
+constexpr int kReduceMaxThread = 256;
 constexpr int kWarpSize = 64;
 #else
-constexpr int kMaxThread = 128;
+constexpr int kReduceMaxThread = 128;
 constexpr int kWarpSize = 32;
 #endif

+// kGlobalMode: block reduce, each block gets an output;
+// kLocalMode: thread reduce, each thread gets an output;
 enum ReduceMode { kGlobalMode, kLocalMode };

 template <typename T>

@@ -118,7 +120,7 @@ __device__ __forceinline__ T BlockXReduce(T val, ReduceOp reducer) {
  */
 template <typename T, typename ReduceOp>
 __device__ __forceinline__ T BlockYReduce(T val, ReduceOp reducer) {
-  __shared__ T shared_memory[details::kMaxThread];
+  __shared__ T shared_memory[details::kReduceMaxThread];
   shared_memory[SharedMemoryIndex(0)] = val;
   for (int stride = blockDim.y / 2; stride > 0; stride >>= 1) {
     __syncthreads();
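BlockYReduce is the standard shared-memory tree reduction, halving the active stride every step with a barrier in between. A self-contained sketch of the same pattern, simplified to a 1-D block computing a sum (not the Paddle template itself):

#include <cuda_runtime.h>
#include <cstdio>

// Shared-memory tree reduction: halve the active stride each step,
// with a barrier at the top of every iteration (as BlockYReduce does
// along blockDim.y).
__global__ void BlockSum(const float* in, float* out) {
  __shared__ float smem[128];
  smem[threadIdx.x] = in[threadIdx.x];
  for (int stride = blockDim.x / 2; stride > 0; stride >>= 1) {
    __syncthreads();
    if (threadIdx.x < stride) {
      smem[threadIdx.x] += smem[threadIdx.x + stride];
    }
  }
  if (threadIdx.x == 0) *out = smem[0];
}

int main() {
  float h_in[128], *d_in, *d_out, h_out;
  for (int i = 0; i < 128; ++i) h_in[i] = 1.0f;
  cudaMalloc(&d_in, sizeof(h_in));
  cudaMalloc(&d_out, sizeof(float));
  cudaMemcpy(d_in, h_in, sizeof(h_in), cudaMemcpyHostToDevice);
  BlockSum<<<1, 128>>>(d_in, d_out);
  cudaMemcpy(&h_out, d_out, sizeof(float), cudaMemcpyDeviceToHost);
  printf("sum = %f\n", h_out);  // expect 128.0
  return 0;
}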
paddle/fluid/operators/kernel_primitives/datamover_primitives.h

@@ -124,36 +124,36 @@ struct BroadcastConfig {
 template <typename Tx, typename Ty, int NX, int NY, int BlockSize>
 __device__ __forceinline__ void ReadData(Ty* dst, const Tx* __restrict__ src,
                                          int stride_nx, int stride_ny) {
+  int thread_offset = threadIdx.x * NX;
   if (NY == 1 && NX == 1) {
-    dst[0] = static_cast<Ty>(src[threadIdx.x]);
+    dst[0] = static_cast<Ty>(src[thread_offset]);
   } else if (NX == 1) {
-    int dx = threadIdx.x;
 #pragma unroll
     for (int idy = 0; idy < NY; ++idy) {
-      dst[idy] = static_cast<Ty>(src[dx + idy * stride_ny]);
+      dst[idy] = static_cast<Ty>(src[thread_offset + idy * stride_ny]);
     }
   } else if (NY == 1) {
 #pragma unroll
     for (int idx = 0; idx < NX; ++idx) {
-      dst[idx] = static_cast<Ty>(src[idx * stride_nx]);
+      dst[idx] = static_cast<Ty>(src[thread_offset + idx * stride_nx]);
     }
   } else {
-    int dx = threadIdx.x * NX;
 #pragma unroll
     for (int idx = 0; idx < NX; ++idx) {
 #pragma unroll
       for (int idy = 0; idy < NY; ++idy) {
-        dst[idy * NX + idx] = static_cast<Ty>(
-            src[idx * stride_nx + dx + idy * stride_ny]);
+        dst[idy * NX + idx] = static_cast<Ty>(
+            src[thread_offset + idx * stride_nx + idy * stride_ny]);
       }
     }
   }
 }

 /**
- * @brief load data from src to dst, src can be 1D data or 2D data. When
- * boundary judgment is required, you need to set a to true, and a is false by
- * default.
+ * @brief load data from src to dst with stride, src can be 1D data or 2D data.
+ * When boundary judgment is required, you need to set a to true, and a is
+ * false by default.
  * @typename:
  * Tx: data type of src
  * Ty: data type of dst
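The theme of this hunk is that every branch now computes a single thread_offset = threadIdx.x * NX and indexes relative to it, instead of mixing threadIdx.x with ad-hoc dx offsets. A standalone kernel using the same register-tile addressing (hypothetical TileCopy, written for this note, not a Paddle primitive):

#include <cuda_runtime.h>
#include <cstdio>

// Each thread loads an NX-element tile starting at threadIdx.x * NX
// into registers, then writes it back doubled -- the same
// thread_offset addressing the patched ReadData uses.
template <typename T, int NX>
__global__ void TileCopy(T* dst, const T* __restrict__ src) {
  int thread_offset = threadIdx.x * NX;
  T reg[NX];
#pragma unroll
  for (int i = 0; i < NX; ++i) {
    reg[i] = src[thread_offset + i];
  }
#pragma unroll
  for (int i = 0; i < NX; ++i) {
    dst[thread_offset + i] = reg[i] * T(2);
  }
}

int main() {
  const int kThreads = 32, kNX = 4, kN = kThreads * kNX;
  float h[kN], *d_src, *d_dst;
  for (int i = 0; i < kN; ++i) h[i] = float(i);
  cudaMalloc(&d_src, sizeof(h));
  cudaMalloc(&d_dst, sizeof(h));
  cudaMemcpy(d_src, h, sizeof(h), cudaMemcpyHostToDevice);
  TileCopy<float, kNX><<<1, kThreads>>>(d_dst, d_src);
  cudaMemcpy(h, d_dst, sizeof(h), cudaMemcpyDeviceToHost);
  printf("h[5] = %f\n", h[5]);  // expect 10.0
  return 0;
}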
@@ -172,17 +172,17 @@ template <typename Tx, typename Ty, int NX, int NY, int BlockSize,
 __device__ __forceinline__ void ReadData(Ty* dst, const Tx* __restrict__ src,
                                          int size_nx, int size_ny,
                                          int stride_nx, int stride_ny) {
-  int dx = threadIdx.x * NX;
-  int size = size_nx - dx;
+  int thread_offset = threadIdx.x * NX;
+  int left_size_nx = size_nx - thread_offset;

   // Each branch is added for better performance
   if (NX == 1 && NY == 1) {  // for NX == 1 and NY == 1
     if (IsBoundary) {
-      if (dx < size_nx) {
-        dst[0] = static_cast<Ty>(src[dx]);
+      if (left_size_nx > 0) {
+        dst[0] = static_cast<Ty>(src[thread_offset]);
       }
     } else {
-      dst[0] = static_cast<Ty>(src[dx]);
+      dst[0] = static_cast<Ty>(src[thread_offset]);
     }
   } else if (NX == 1) {  // for NX == 1 and NY != 1
 #pragma unroll

@@ -192,23 +192,23 @@ __device__ __forceinline__ void ReadData(Ty* dst, const Tx* __restrict__ src,
           break;
         }
       }
-      dst[idy] = static_cast<Ty>(src[dx + idy * stride_ny]);
+      dst[idy] = static_cast<Ty>(src[thread_offset + idy * stride_ny]);
     }
   } else if (NY == 1) {  // for NY == 1 and NX != 1
 #pragma unroll
     for (int idx = 0; idx < NX; ++idx) {
       if (IsBoundary) {
-        if (idx >= size) {
+        if (idx >= left_size_nx) {
           break;
         }
       }
-      dst[idx] = static_cast<Ty>(src[idx * stride_nx + dx]);
+      dst[idx] = static_cast<Ty>(src[thread_offset + idx * stride_nx]);
     }
   } else {  // for NX != 1 and NY != 1
 #pragma unroll
     for (int idx = 0; idx < NX; ++idx) {
       if (IsBoundary) {
-        if (idx >= size) {
+        if (idx >= left_size_nx) {
           break;
         }
       }

@@ -219,8 +219,8 @@ __device__ __forceinline__ void ReadData(Ty* dst, const Tx* __restrict__ src,
           break;
         }
       }
-      dst[idy * NX + idx] = static_cast<Ty>(
-          src[idx * stride_nx + dx + idy * stride_ny]);
+      dst[idy * NX + idx] = static_cast<Ty>(
+          src[thread_offset + idx * stride_nx + idy * stride_ny]);
     }
   }
 }
@@ -251,17 +251,17 @@ template <typename T, int NX, int NY, int BlockSize, bool IsBoundary = false>
 __device__ __forceinline__ void ReadData(T* dst, const T* __restrict__ src,
                                          int num) {
   if (IsBoundary) {  // blockDim.x * NX > num
-    int dx = threadIdx.x * NX;
+    int thread_offset = threadIdx.x * NX;
 #pragma unroll
     for (int idx = 0; idx < NX; ++idx) {
-      if (idx + dx < num) {
-        dst[idx] = src[idx + dx];
+      if (idx + thread_offset < num) {
+        dst[idx] = src[thread_offset + idx];
       }
     }
  } else {  // blockDim.x * NX < num
     const int kVectorSize = (NX % 4 == 0) ? 4 : (NX % 2 == 0) ? 2 : 1;
     const int kVectorsPerThread = NX / kVectorSize;
-    int tid = threadIdx.x * kVectorsPerThread;
+    int thread_offset = threadIdx.x * kVectorsPerThread;

     using VecType = details::VectorType<T, kVectorSize>;
     const VecType* vec_input = reinterpret_cast<const VecType*>(src);

@@ -269,7 +269,7 @@ __device__ __forceinline__ void ReadData(T* dst, const T* __restrict__ src,
 #pragma unroll
     for (int i = 0; i < kVectorsPerThread; ++i) {
-      vec_temp[i] = vec_input[i + tid];
+      vec_temp[i] = vec_input[thread_offset + i];
 #pragma unroll
       for (int idx = 0; idx < NX; ++idx) {
         dst[idx] = *(reinterpret_cast<T*>(vec_temp) + idx);
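In the non-boundary branch the data moves through details::VectorType so that NX elements per thread become 2- or 4-wide vector loads. The same trick in a standalone kernel using CUDA's built-in float4 (assuming src is 16-byte aligned, which cudaMalloc guarantees):

#include <cuda_runtime.h>
#include <cstdio>

// Vectorized read: reinterpret the input as float4 so each thread
// issues one 128-bit load for four floats (the VectorType idea).
__global__ void VecRead(float* dst, const float* __restrict__ src) {
  int thread_offset = threadIdx.x;  // one float4 per thread
  const float4* vec_input = reinterpret_cast<const float4*>(src);
  float4 v = vec_input[thread_offset];
  dst[4 * thread_offset + 0] = v.x;
  dst[4 * thread_offset + 1] = v.y;
  dst[4 * thread_offset + 2] = v.z;
  dst[4 * thread_offset + 3] = v.w;
}

int main() {
  const int kThreads = 32, kN = 4 * kThreads;
  float h[kN], *d_src, *d_dst;
  for (int i = 0; i < kN; ++i) h[i] = float(i);
  cudaMalloc(&d_src, sizeof(h));
  cudaMalloc(&d_dst, sizeof(h));
  cudaMemcpy(d_src, h, sizeof(h), cudaMemcpyHostToDevice);
  VecRead<<<1, kThreads>>>(d_dst, d_src);
  cudaMemcpy(h, d_dst, sizeof(h), cudaMemcpyDeviceToHost);
  printf("h[7] = %f\n", h[7]);  // expect 7.0
  return 0;
}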
@@ -289,39 +289,39 @@ __device__ __forceinline__ void ReadData(T* dst, const T* __restrict__ src,
  * is 2
  * IsBoundary: whether to make boundary judgment
  * @param:
- * fix: data offset of this block, blockDim.x * blockIdx.x * NX;
+ * block_offset: data offset of this block, blockDim.x * blockIdx.x * NX;
  * config: get the global index in src, attention config was declared in host;
- * num: the num of out
+ * total_num_output: total num of output
  * stride_nx: the stride of cols
  * stride_ny: the stride of rows
  */
 template <typename T, int NX, int NY, int BlockSize, int ShapeSize,
           bool IsBoundary = false>
 __device__ __forceinline__ void ReadDataBc(
-    T* dst, const T* __restrict__ src, uint32_t fix,
-    details::BroadcastConfig<ShapeSize> config, int num, int stride_nx,
-    int stride_ny) {
-  uint32_t base_offset = fix + threadIdx.x * NX;
-  uint32_t offset = 0;
+    T* dst, const T* __restrict__ src, uint32_t block_offset,
+    details::BroadcastConfig<ShapeSize> config, int total_num_output,
+    int stride_nx, int stride_ny) {
+  uint32_t thread_offset = block_offset + threadIdx.x * NX;
+  uint32_t index_src = 0;

 #pragma unroll
   for (int ny = 0; ny < NY; ++ny) {
 #pragma unroll
     for (uint32_t nx = 0; nx < NX; ++nx) {
-      uint32_t idx = base_offset + ny * stride_ny + nx * stride_nx;
+      uint32_t index_output = thread_offset + ny * stride_ny + nx * stride_nx;
+      index_src = 0;
       if (IsBoundary) {
-        if (idx >= num) {
+        if (index_output >= total_num_output) {
           break;
         }
       }
-      offset = 0;
 #pragma unroll
       for (int i = 0; i < ShapeSize; ++i) {
-        auto fast_divmoder = config.divmoders[i].Divmod(idx);
-        idx = fast_divmoder.val[0];
-        offset += fast_divmoder.val[1] * config.strides[i];
+        auto fast_divmoder = config.divmoders[i].Divmod(index_output);
+        index_output = fast_divmoder.val[0];
+        index_src += fast_divmoder.val[1] * config.strides[i];
       }
-      dst[nx + ny * NX] = src[offset];
+      dst[nx + ny * NX] = src[index_src];
     }
   }
 }
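The renamed variables make the flow explicit: index_output walks the broadcast output space while index_src accumulates the source offset one dimension at a time via FastDivMod. A plain-integer sketch of the same mapping, using ordinary division instead of the divisor magic:

#include <cassert>

// Map a flat output index to a source offset for a broadcast read:
// peel dimensions off the output index and accumulate src strides
// (src strides are 0 along broadcast dimensions).
int BroadcastSrcIndex(int index_output, const int out_dims[],
                      const int src_strides[], int rank) {
  int index_src = 0;
  for (int i = rank - 1; i >= 0; --i) {
    int coord = index_output % out_dims[i];
    index_output /= out_dims[i];
    index_src += coord * src_strides[i];
  }
  return index_src;
}

int main() {
  // src shape {3, 1} broadcast to out shape {3, 4}: stride 0 on dim 1.
  int out_dims[] = {3, 4};
  int src_strides[] = {1, 0};
  // Output element (1, 1) (flat index 5) reads src element (1, 0).
  assert(BroadcastSrcIndex(5, out_dims, src_strides, 2) == 1);
  return 0;
}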
@@ -338,7 +338,7 @@ __device__ __forceinline__ void ReadDataBc(
  * IndexCal: get the global index in src, attention config was declared in host;
  * IsBoundary: whether to make boundary judgment
  * @param:
- * fix: data offset of this block, blockDim.x * blockIdx.x * NX;
+ * block_offset: data offset of this block, blockDim.x * blockIdx.x * NX;
  * index_cal: get the global index in src, attention config was declared in
  * host;
  * size_nx: number of columns to be processed by the current block
@@ -350,27 +350,27 @@ __device__ __forceinline__ void ReadDataBc(
 template <typename T, int NX, int NY, int BlockSize, int ShapeSize,
           typename IndexCal, bool IsBoundary = false>
 __device__ __forceinline__ void ReadDataReduce(
-    T* dst, const T* __restrict__ src, int fix, const IndexCal& index_cal,
-    int size_nx, int size_ny, int stride_nx, int stride_ny,
-    bool reduce_last_dim) {
-  int base_offset = fix;
+    T* dst, const T* __restrict__ src, int block_offset,
+    const IndexCal& index_cal, int size_nx, int size_ny,
+    int stride_nx, int stride_ny, bool reduce_last_dim) {
+  int thread_offset = 0;
   if (reduce_last_dim) {
-    base_offset += threadIdx.x;
+    thread_offset = block_offset + threadIdx.x;
   } else {
-    base_offset += threadIdx.y;
+    thread_offset = block_offset + threadIdx.y;
   }

   if (NX == 1) {
 #pragma unroll
     for (int ny = 0; ny < NY; ++ny) {
       if (IsBoundary) {
-        if (base_offset >= size_ny) {
+        if (thread_offset >= size_ny) {
           break;
         }
       }
-      uint32_t offset = index_cal(base_offset);
-      dst[ny] = src[offset];
-      base_offset += stride_ny;
+      uint32_t index_src = index_cal(thread_offset);
+      dst[ny] = src[index_src];
+      thread_offset += stride_ny;
     }
   } else {
 #pragma unroll
@@ -387,15 +387,16 @@ __device__ __forceinline__ void ReadDataReduce(
           break;
         }
       }
-      uint32_t offset = index_cal(base_offset);
-      dst[nx + ny * NX] = src[offset];
-      base_offset += stride_ny;
+      uint32_t index_src = index_cal(thread_offset);
+      dst[nx + ny * NX] = src[index_src];
+      thread_offset += stride_ny;
     }
+    thread_offset += stride_nx;
   }
 }

-/**
-@brief: WriteData
+/**
  * @brief store data from src to dst, src can be 1D data, you should set NY = 1.
  * When boundary judgment is required, you need to set a to true, and a is false
  * by default.
@@ -412,11 +413,11 @@ template <typename T, int NX, int NY, int BlockSize, bool IsBoundary = false>
 __device__ __forceinline__ void WriteData(T* dst, T* __restrict__ src,
                                           int num) {
   if (IsBoundary) {
-    int dx = threadIdx.x * NX;
+    int thread_offset = threadIdx.x * NX;
 #pragma unroll
     for (int idx = 0; idx < NX; ++idx) {
-      if ((idx + dx) < num) {
-        dst[idx + dx] = src[idx];
+      if ((thread_offset + idx) < num) {
+        dst[thread_offset + idx] = src[idx];
       }
     }
   } else {

@@ -424,14 +425,14 @@ __device__ __forceinline__ void WriteData(T* dst, T* __restrict__ src,
     const int kVectorSize = (NX % 4 == 0) ? 4 : (NX % 2 == 0) ? 2 : 1;
     const int kVectorsPerThread = NX / kVectorSize;
-    int dx = threadIdx.x * kVectorsPerThread;
+    int thread_offset = threadIdx.x * kVectorsPerThread;

     using VecType = details::VectorType<T, kVectorSize>;
     VecType* vec_dst = reinterpret_cast<VecType*>(dst);
     VecType vec_temp[kVectorsPerThread];
 #pragma unroll
     for (int idx = 0; idx < kVectorsPerThread; ++idx) {
       vec_temp[idx] = *(reinterpret_cast<VecType*>(src) + idx);
-      vec_dst[dx + idx] = vec_temp[idx];
+      vec_dst[thread_offset + idx] = vec_temp[idx];
     }
   }
 }
paddle/fluid/operators/kernel_primitives/helper_primitives.h

@@ -32,7 +32,6 @@ static __device__ __forceinline__ platform::float16 LogFunctor(
 static __device__ __forceinline__ float LogFunctor(float x) { return logf(x); }
 static __device__ __forceinline__ double LogFunctor(double x) { return log(x); }
-}  // namespace details

 /*************************** Compute Functor****************************/
 // for margin_cross_entropy
 template <typename Tx, typename Ty = Tx>

@@ -75,6 +74,7 @@ struct DivideFunctor {
   T n_inv;
 };
+}  // namespace details

 }  // namespace kernel_primitives
 }  // namespace operators
 }  // namespace paddle
paddle/fluid/operators/margin_cross_entropy_op.cu

@@ -128,39 +128,9 @@ __global__ void AddMarginToPositiveLogitsKernel(
   }
 }

-static __device__ __forceinline__ platform::float16 exp_on_device(
-    platform::float16 x) {
-  return ::Eigen::numext::exp(x);
-}
-static __device__ __forceinline__ float exp_on_device(float x) {
-  return expf(x);
-}
-static __device__ __forceinline__ double exp_on_device(double x) {
-  return exp(x);
-}
-static __device__ __forceinline__ platform::float16 log_on_device(
-    platform::float16 x) {
-  return ::Eigen::numext::log(x);
-}
-static __device__ __forceinline__ float log_on_device(float x) {
-  return logf(x);
-}
-static __device__ __forceinline__ double log_on_device(double x) {
-  return log(x);
-}
-
-template <typename Tx, typename Ty = Tx>
-struct ExpLogitTransformer {
-  HOSTDEVICE explicit inline ExpLogitTransformer(int n) {}
-
-  HOSTDEVICE inline Ty operator()(const Tx& x) const {
-    return static_cast<Ty>(exp_on_device(x));
-  }
-};
-
 template <typename Tx, typename Ty = Tx>
 struct ExpAndSum {
-  using Transformer = ExpLogitTransformer<Tx>;
+  using Transformer = kpds::ExpLogitTransformer<Tx>;

   inline Ty initial() { return static_cast<Ty>(0.0f); }

@@ -189,7 +159,7 @@ __global__ void LogitsMinusLogSumKernel(T* logits, const T* logits_sum_per_row,
                                         const int64_t N, const int64_t D) {
   CUDA_KERNEL_LOOP(i, N * D) {
     auto row = i / D;
-    logits[i] -= log_on_device(logits_sum_per_row[row]);
+    logits[i] -= kpds::LogFunctor(logits_sum_per_row[row]);
   }
 }

@@ -204,9 +174,9 @@ __global__ void HardLabelSoftmaxWithCrossEntropyKernel(
     if ((col + start_index) == labels[row]) {
       auto softmax = log_softmax[i];
       loss[row] = -softmax;
-      log_softmax[i] = exp_on_device(softmax);
+      log_softmax[i] = kpds::ExpFunctor(softmax);
     } else {
-      log_softmax[i] = exp_on_device(log_softmax[i]);
+      log_softmax[i] = kpds::ExpFunctor(log_softmax[i]);
     }
   }
 }
paddle/fluid/operators/reduce_ops/reduce_functor_op.h

@@ -24,9 +24,11 @@ limitations under the License. */
 namespace paddle {
 namespace operators {

+namespace kpds = paddle::operators::kernel_primitives::details;
+
 template <typename Tx, typename Ty = Tx>
 struct CustomMin {
-  using Transformer = detail::IdentityFunctor<Tx>;
+  using Transformer = kpds::IdentityFunctor<Tx>;

   inline Ty initial() {
     return static_cast<Ty>(std::numeric_limits<Ty>::max());

@@ -39,7 +41,7 @@ struct CustomMin {
 template <typename Tx, typename Ty = Tx>
 struct CustomMax {
-  using Transformer = detail::IdentityFunctor<Tx>;
+  using Transformer = kpds::IdentityFunctor<Tx>;

   inline Ty initial() {
     return static_cast<Ty>(std::numeric_limits<Ty>::lowest());

@@ -53,7 +55,7 @@ struct CustomMax {
 // for cub::Reduce
 template <typename Tx, typename Ty = Tx>
 struct CustomSum {
-  using Transformer = detail::IdentityFunctor<Tx, Ty>;
+  using Transformer = kpds::IdentityFunctor<Tx, Ty>;

   inline Ty initial() { return static_cast<Ty>(0.0f); }

@@ -64,7 +66,7 @@ struct CustomSum {
 template <typename Tx, typename Ty = Tx>
 struct CustomMean {
-  using Transformer = detail::DivideFunctor<Tx>;
+  using Transformer = kpds::DivideFunctor<Tx>;

   inline Ty initial() { return static_cast<Ty>(0.0f); }

@@ -75,7 +77,7 @@ struct CustomMean {
 template <typename Tx, typename Ty = Tx>
 struct CustomMul {
-  using Transformer = detail::IdentityFunctor<Tx>;
+  using Transformer = kpds::IdentityFunctor<Tx>;

   inline Ty initial() { return static_cast<Ty>(1.0f); }

@@ -86,7 +88,7 @@ struct CustomMul {
 template <typename Tx, typename Ty = Tx>
 struct CustomLogicalOr {
-  using Transformer = detail::IdentityFunctor<Tx>;
+  using Transformer = kpds::IdentityFunctor<Tx>;

   inline Ty initial() { return static_cast<Ty>(false); }

@@ -97,7 +99,7 @@ struct CustomLogicalOr {
 template <typename Tx, typename Ty = Tx>
 struct CustomLogicalAnd {
-  using Transformer = detail::IdentityFunctor<Tx>;
+  using Transformer = kpds::IdentityFunctor<Tx>;

   inline Ty initial() { return static_cast<Ty>(true); }
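Every Custom* functor follows one contract: a Transformer preprocesses each element, initial() seeds the accumulator, and operator() combines partials. A host-side sketch of that contract with hypothetical mini-types (not the Paddle definitions):

#include <cassert>
#include <limits>

// Sketch of the reducer contract the Custom* functors follow:
// Transformer preprocesses each element, initial() seeds the
// accumulator, operator() combines two partials.
template <typename T>
struct IdentityFunctor {
  T operator()(const T& x) const { return x; }
};

template <typename T>
struct CustomMax {
  using Transformer = IdentityFunctor<T>;
  T initial() const { return std::numeric_limits<T>::lowest(); }
  T operator()(const T& a, const T& b) const { return b > a ? b : a; }
};

template <typename Reducer, typename T>
T Reduce(const T* data, int n, Reducer reducer) {
  typename Reducer::Transformer transform;
  T acc = reducer.initial();
  for (int i = 0; i < n; ++i) acc = reducer(acc, transform(data[i]));
  return acc;
}

int main() {
  float v[] = {1.5f, -2.0f, 7.25f, 0.0f};
  assert(Reduce(v, 4, CustomMax<float>()) == 7.25f);
  return 0;
}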
paddle/fluid/operators/reduce_ops/reduce_op.cu.h

@@ -34,6 +34,7 @@ namespace cub = hipcub;
 #include "paddle/fluid/framework/tensor.h"
 #include "paddle/fluid/framework/tensor_util.h"
 #include "paddle/fluid/operators/amp/fp16_type_traits.h"
+#include "paddle/fluid/operators/kernel_primitives/kernel_primitives.h"
 #include "paddle/fluid/platform/cuda_device_function.h"
 #include "paddle/fluid/platform/fast_divmod.h"
@@ -43,28 +44,10 @@ namespace cub = hipcub;
 namespace paddle {
 namespace operators {

-namespace detail {
-
-// Post processing function for sum, max, min, prod, any
-template <typename Tx, typename Ty = Tx>
-struct IdentityFunctor {
-  HOSTDEVICE explicit inline IdentityFunctor(int n) {}
-
-  HOSTDEVICE inline Ty operator()(const Tx& x) const {
-    return static_cast<Ty>(x);
-  }
-};
-
-// Post processing function for mean
-template <typename T>
-struct DivideFunctor {
-  HOSTDEVICE explicit inline DivideFunctor(int n) : n_inv((T)(1.0 / n)) {}
-
-  HOSTDEVICE inline T operator()(const T& x) const { return x * n_inv; }
-
- private:
-  T n_inv;
-};
+namespace kps = paddle::operators::kernel_primitives;
+
+namespace details {

 static inline int GetLastPow2(int n) {
   n |= (n >> 1);
@@ -90,17 +73,11 @@ static inline std::vector<int> GetDimStrides(const std::vector<int>& dims,
   return strides;
 }

-#ifdef __HIPCC__
-constexpr int kMaxThread = 256;
-constexpr int kWarpSize = 64;
-#else
-constexpr int kMaxThread = 128;
-constexpr int kWarpSize = 32;
-#endif
-
 // get blockDim for reduceLastDim and reduceAny
 static inline int GetBlockDim(int block_dim) {
-  return block_dim >= kMaxThread ? kMaxThread : GetLastPow2(block_dim);
+  return block_dim >= kps::details::kReduceMaxThread
+             ? kps::details::kReduceMaxThread
+             : GetLastPow2(block_dim);
 }

 // check reduce rank is valid
@@ -140,7 +117,7 @@ static inline paddle::framework::Array<T, ElementCount> VectorToArray(
   return ret;
 }

-}  // namespace detail
+}  // namespace details

 using Tensor = framework::Tensor;
 constexpr int kMaxRank = framework::DDim::kMaxRank;
@@ -156,18 +133,18 @@ struct IndexCalculator {
                   const std::vector<int>& cal_strides,
                   const std::vector<int>& full_strides)
       : dim(dim) {
-    dims = detail::VectorToArray<int, kMaxRank>(cal_dims);
-    strides = detail::VectorToArray<int, kMaxRank>(full_strides);
+    dims = details::VectorToArray<int, kMaxRank>(cal_dims);
+    strides = details::VectorToArray<int, kMaxRank>(full_strides);
     std::vector<platform::FastDivMod> cal_divmoders;
     // fast divmod
     for (auto i : cal_strides) {
       cal_divmoders.push_back(platform::FastDivMod(i));
     }
-    divmoders = detail::VectorToArray<platform::FastDivMod, kMaxRank>(cal_divmoders);
+    divmoders =
+        details::VectorToArray<platform::FastDivMod, kMaxRank>(cal_divmoders);
   }

-  __device__ inline int Get(int offset) const {
+  __device__ inline int operator()(int offset) const {
     int index = 0;
 #pragma unroll
     for (int i = 0; i < kMaxRank; ++i) {

@@ -187,6 +164,15 @@ struct IndexCalculator {
   framework::Array<platform::FastDivMod, kMaxRank> divmoders;
 };

+// when reduce_type == kReduceLastDim this struct will be used
+// for higher performance
+struct LastDimIndexCal {
+  explicit LastDimIndexCal(int num) : stride(num) {}
+
+  __device__ inline int operator()(int index) const { return index * stride; }
+  int stride;
+};
+
 // reduce config
 template <typename Ty>
 struct ReduceConfig {
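Turning the named Get into operator() is what lets the kernels accept either IndexCalculator or the new LastDimIndexCal through one Calculator template parameter. A tiny host-side sketch of why that works (hypothetical caller, not Paddle code):

#include <cassert>

// Both index calculators expose operator(), so a kernel template can
// take either one through the same Calculator type parameter.
struct LastDimIndexCal {
  explicit LastDimIndexCal(int num) : stride(num) {}
  int operator()(int index) const { return index * stride; }
  int stride;
};

template <typename Calculator>
int SrcIndexOf(Calculator index_cal, int logical_index) {
  return index_cal(logical_index);  // works for any functor-style calculator
}

int main() {
  assert(SrcIndexOf(LastDimIndexCal(8), 3) == 24);
  return 0;
}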
@@ -307,7 +293,7 @@ struct ReduceConfig {
     left_dim.assign(left_set.begin(), left_set.end());

     // if the last dim gets involved in reduction
-    reduce_lastdim = (reduce_dim.back() == x_dim.size() - 1);
+    reduce_last_dim = (reduce_dim.back() == x_dim.size() - 1);
   }

   // set x_strides, reduce_strides, left_strides for reduceLastDim and reduceAny
@@ -320,9 +306,9 @@ struct ReduceConfig {
       idx_dim.push_back(i);
     }

-    x_strides = detail::GetDimStrides(x_dim, idx_dim);
-    reduce_strides = detail::GetDimStrides(x_dim, reduce_dim);
-    left_strides = detail::GetDimStrides(x_dim, left_dim);
+    x_strides = details::GetDimStrides(x_dim, idx_dim);
+    reduce_strides = details::GetDimStrides(x_dim, reduce_dim);
+    left_strides = details::GetDimStrides(x_dim, left_dim);
     reduce_num = reduce_strides[0] * x_dim[reduce_dim[0]];

     left_num = 1;
@@ -354,36 +340,36 @@ struct ReduceConfig {
   void SetBlockDimForReduceAny(dim3* block_dim, dim3* grid_dim) {
     constexpr int min_reduce_num_per_thread = 16;
     constexpr int max_reduce_num_per_thread = 256;
-    constexpr int max_num_threads = detail::kMaxThread;
+    constexpr int max_num_threads = kps::details::kReduceMaxThread;

     // set block size.
-    // 1. If reduce_lastdim == true, all the threads whose threadIdx.y are same
+    // 1. If reduce_last_dim == true, all the threads whose threadIdx.y are same
     //    will process the reduction for one output.
     //    The number of output for one block is blockDim.y;
-    // 2. If reduce_lastdim == false, different threadIdx.x will process
+    // 2. If reduce_last_dim == false, different threadIdx.x will process
     //    different reduction and gets the output separately. If it is
     //    necessary, it should reduce in block y.
     //    The number of output for one block is blockDim.x;
     int block_x, block_y;
     int grid_num, reduce_num_per_thread;
-    if (reduce_lastdim) {
-      block_x = detail::GetBlockDim(reduce_num);
-      block_y = detail::GetBlockDim(left_num);
+    if (reduce_last_dim) {
+      block_x = details::GetBlockDim(reduce_num);
+      block_y = details::GetBlockDim(left_num);
       block_dim->x = block_x;
       block_dim->y =
           std::min(block_y, static_cast<int>(max_num_threads / block_dim->x));
-      grid_num = detail::AlignUp(left_num, block_dim->y);
-      reduce_num_per_thread = detail::AlignUp(reduce_num, block_dim->x);
+      grid_num = details::AlignUp(left_num, block_dim->y);
+      reduce_num_per_thread = details::AlignUp(reduce_num, block_dim->x);
     } else {
-      block_x = detail::GetBlockDim(left_num);
-      block_y = detail::GetBlockDim(reduce_num);
+      block_x = details::GetBlockDim(left_num);
+      block_y = details::GetBlockDim(reduce_num);
       block_dim->x = std::min(block_x, 32);
       block_dim->y =
           std::min(block_y, static_cast<int>(max_num_threads / block_dim->x));
+      block_dim->x =
+          std::min(block_x, static_cast<int>(max_num_threads / block_dim->y));
-      grid_num = detail::AlignUp(left_num, block_dim->x);
-      reduce_num_per_thread = detail::AlignUp(reduce_num, block_dim->y);
+      grid_num = details::AlignUp(left_num, block_dim->x);
+      reduce_num_per_thread = details::AlignUp(reduce_num, block_dim->y);
     }
     int device_id = platform::GetCurrentDeviceId();
     int max_mp = platform::GetCUDAMultiProcessors(device_id);
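The net effect of the extra clamp is that blockDim.x * blockDim.y can no longer exceed the per-block thread budget. The same arithmetic as a host-side check (assumed budget of 128, the CUDA kReduceMaxThread value):

#include <algorithm>
#include <cassert>

int main() {
  // Mutual clamp from SetBlockDimForReduceAny's else-branch: y is
  // capped by the threads left after x, then x is re-capped against y,
  // so x * y never exceeds the per-block thread budget.
  const int max_num_threads = 128;  // kReduceMaxThread on CUDA (assumed)
  int block_x = 32;                 // after std::min(block_x, 32)
  int block_y = 64;                 // candidate from GetBlockDim(reduce_num)
  block_y = std::min(block_y, max_num_threads / block_x);  // -> 4
  block_x = std::min(block_x, max_num_threads / block_y);  // stays 32
  assert(block_x * block_y <= max_num_threads);
  return 0;
}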
@@ -403,10 +389,10 @@ struct ReduceConfig {
     // the number cannot be larger than max_reduce_num_per_thread, so we
     // choose the maximum between the result above and input_split_num_2.
     int input_split_num_1 =
-        detail::AlignUp(reduce_num_per_thread, min_reduce_num_per_thread);
+        details::AlignUp(reduce_num_per_thread, min_reduce_num_per_thread);
     int input_split_num_2 =
-        detail::AlignUp(reduce_num_per_thread, max_reduce_num_per_thread);
-    int input_split_num_3 = detail::AlignUp(max_num_blocks, grid_num);
+        details::AlignUp(reduce_num_per_thread, max_reduce_num_per_thread);
+    int input_split_num_3 = details::AlignUp(max_num_blocks, grid_num);

     grid_dim->x = grid_num;
     grid_dim->y = std::max(std::min(input_split_num_1, input_split_num_3),
@@ -423,7 +409,7 @@ struct ReduceConfig {
   //     for others: block(block_num, 1) , grid(left_num, 1)
   void SetBlockDim() {
     // init
-    int block_num = detail::GetBlockDim(reduce_num);
+    int block_num = details::GetBlockDim(reduce_num);
     should_reduce_again = false;

     dim3 block_dim(block_num, 1);
@@ -449,23 +435,23 @@ struct ReduceConfig {
     int num_block = (max_threads / left_num);

     if (num_block > 1 && reduce_num >= REDUCE_SPLIT_BOUNDARY) {
-      blocking_size = detail::GetLastPow2(reduce_num / num_block);
+      blocking_size = details::GetLastPow2(reduce_num / num_block);
       if (blocking_size <= 1) {
-        blocking_size = detail::GetLastPow2(sqrt(reduce_num));
+        blocking_size = details::GetLastPow2(sqrt(reduce_num));
       } else if (blocking_size * 2 < reduce_num) {
         blocking_size *= 2;
       }
       should_reduce_again = true;
-      block_dim.x = 32;
+      block_dim.x = details::GetBlockDim(left_num);
       block_dim.y = 1;
       grid_dim.x = (left_num + block_dim.x - 1) / block_dim.x;
       grid_dim.y = (reduce_num + blocking_size - 1) / blocking_size;
     } else {
-      block_dim.x = 32;
+      block_dim.x = details::GetBlockDim(left_num);
       block_dim.y = 1;
       blocking_size = reduce_num;
       grid_dim.x = (left_num + block_dim.x - 1) / block_dim.x;
@@ -493,272 +479,210 @@ struct ReduceConfig {
   int left_num;
   int blocking_size;
   bool should_reduce_again;
-  bool reduce_lastdim;
+  bool reduce_last_dim;

   Ty* output_data;

   dim3 block;
   dim3 grid;
 };
(The block-wide reduction helpers below are deleted from this file; their counterparts now live in kernel_primitives/compute_primitives.h, as the earlier hunks show.)

-static __device__ int SharedMemoryIndex(int index) {
-  return (threadIdx.y + index) * blockDim.x + threadIdx.x;
-}
-
-template <typename T, typename ReduceOp>
-static __device__ T WarpReduce(T val, ReduceOp reducer) {
-  unsigned mask = 0u;
-  CREATE_SHFL_MASK(mask, true);
-  for (int stride = detail::kWarpSize / 2; stride > 0; stride >>= 1) {
-    T temp = paddle::platform::CudaShuffleDownSync(mask, val, stride);
-    val = reducer(val, temp);
-  }
-  return val;
-}
-
-/* e.g.
- * |---------block---------|
- * |warp0|warp1|warp2|warp3|
- * |0~31|32~63|64~95|96~127|  ---->blockDim.x = 128
- *  \|/  \|/   \|/    \|/    ---->1. First WarpReduce in each warp
- *   res0  res1  res2  res3  ---->2. Store result of each warp to shared memory
- *    \    \    /     /      ---->3. Load the result above from shared memory
- *                                    to warp0 and process the second WarpReduce
- */
-template <typename T, typename ReduceOp>
-static __device__ T BlockXReduce(T val, ReduceOp reducer) {
-  using detail::kWarpSize;
-  __shared__ T shared[2 * kWarpSize];
-  int block_dim_x = blockDim.x;
-  if (blockDim.x > kWarpSize) {
-    block_dim_x = blockDim.x / kWarpSize;
-    int lane = threadIdx.x % kWarpSize;
-    int tid = threadIdx.y * blockDim.x + threadIdx.x;
-    int wid = tid / kWarpSize;
-    int bid = threadIdx.y;
-    val = WarpReduce(val, reducer);
-    if (lane == 0) {
-      shared[wid] = val;
-    }
-    __syncthreads();
-    val = shared[bid * block_dim_x + lane];
-  }
-
-  unsigned mask = 0u;
-  CREATE_SHFL_MASK(mask, true);
-  for (int stride = 1; stride < block_dim_x; stride <<= 1) {
-    T temp = paddle::platform::CudaShuffleDownSync(mask, val, stride);
-    val = reducer(val, temp);
-  }
-  return val;
-}
-
-template <typename T, typename ReduceOp>
-static __device__ T BlockYReduce(T val, ReduceOp reducer) {
-  __shared__ T shared_memory[detail::kMaxThread];
-  shared_memory[SharedMemoryIndex(0)] = val;
-  for (int stride = blockDim.y / 2; stride > 0; stride >>= 1) {
-    __syncthreads();
-    if (threadIdx.y < stride && threadIdx.y + stride < blockDim.y) {
-      T temp = shared_memory[SharedMemoryIndex(stride)];
-      val = reducer(val, temp);
-    }
-    shared_memory[SharedMemoryIndex(0)] = val;
-  }
-  return val;
-}
-// when reduce_dim.size() == 1 and reduce_dim[0] != x_dim.size() - 1, this
-// function will be used
-// eg: x_dim = {nz, ny, nx}, nx != 1, axis can be 0 or 1
-// if axis = 1 then grid.z = nz, grid.y = ny / block_size, grid.x = nx / 32
-// else grid.z = 1, grid.y = ny / block_size, grid.x = nx / 32
+/* size : how many columns left have to be reduced
+ * loop : how many rows data have to be reduced
+ * block_size: max rows this block to reduce
+ */
 template <typename Tx, typename Ty, typename MPType, typename ReduceOp,
-          typename TransformOp>
-__device__ void ReduceHigherDim(const Tx* x, Ty* y, ReduceOp reducer,
-                                TransformOp transformer, MPType init,
-                                int reduce_num, int left_num, int block_size) {
-  int idx = blockIdx.x * blockDim.x + threadIdx.x;
+          typename TransformOp, bool IsBoundary = false>
+__device__ void HigherDimDealSegment(const Tx* x, Ty* y, ReduceOp reducer,
+                                     TransformOp transformer, MPType init,
+                                     int reduce_num, int left_num,
+                                     int block_size) {
+  const int NY = 1;
+  int idx = blockIdx.x * blockDim.x;
   int idy = blockIdx.y * block_size;
-  MPType reduce_var = init;
-
-  if (idx < left_num) {
-    int loop = reduce_num - idy;
-    loop = loop > block_size ? block_size : loop;
-
-    for (int iy = 0; iy < loop; iy++) {
-      int id = (idy + iy) * left_num + idx + blockIdx.z * reduce_num * left_num;
-      reduce_var = reducer(reduce_var, static_cast<MPType>(transformer(x[id])));
-    }
-
-    y[idx + blockIdx.y * left_num + blockIdx.z * gridDim.y * left_num] =
-        static_cast<Ty>(reduce_var);
-  }
+  // block_offset of rows
+  Tx reduce_input[NY];
+  MPType reduce_compute[NY];
+  MPType result = init;
+  // the offset of this block
+  int block_offset = idy * left_num + idx + blockIdx.z * reduce_num * left_num;
+  const Tx* input = x + block_offset;
+  int store_offset =
+      blockIdx.y * left_num + blockIdx.z * gridDim.y * left_num + idx;
+  // how many columns left
+  int size = left_num - idx;
+  // how many rows have to be reduced
+  int loop = reduce_num - idy;
+  loop = loop > block_size ? block_size : loop;
+
+  for (int loop_index = 0; loop_index < loop; loop_index += NY) {
+    kps::ReadData<Tx, Tx, 1, NY, 1, IsBoundary>(
+        &reduce_input[0], input + loop_index * left_num, size, NY, 1, left_num);
+    kps::ElementwiseUnary<Tx, MPType, REDUCE_VEC_SIZE, 1, 1, TransformOp>(
+        &reduce_compute[0], &reduce_input[0], transformer);
+    kps::Reduce<MPType, NY, 1, 1, ReduceOp,
+                kps::details::ReduceMode::kLocalMode>(
+        &result, &reduce_compute[0], reducer, false);
+  }

+  Ty temp_data = static_cast<Ty>(result);
+  kps::WriteData<Ty, 1, 1, 1, IsBoundary>(y + store_offset, &temp_data, size);
 }
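HigherDimDealSegment now expresses each thread's work as three primitives: read a register tile, transform it elementwise, fold it into a local accumulator. A host-side sketch of that pipeline shape with stand-in mini-primitives (not the kps templates):

#include <cassert>

// Three-stage shape of the kernel-primitive pipeline:
// ReadData -> ElementwiseUnary -> Reduce (local mode).
template <int N, typename T>
void ReadData(T* dst, const T* src) {
  for (int i = 0; i < N; ++i) dst[i] = src[i];
}

template <int N, typename T, typename Op>
void ElementwiseUnary(T* dst, const T* src, Op op) {
  for (int i = 0; i < N; ++i) dst[i] = op(src[i]);
}

template <int N, typename T, typename Op>
void Reduce(T* acc, const T* src, Op op) {
  for (int i = 0; i < N; ++i) *acc = op(*acc, src[i]);
}

int main() {
  float in[4] = {1, 2, 3, 4}, reg[4], tmp[4], acc = 0.0f;
  ReadData<4>(reg, in);                                          // load tile
  ElementwiseUnary<4>(tmp, reg, [](float x) { return x * x; });  // transform
  Reduce<4>(&acc, tmp, [](float a, float b) { return a + b; });  // fold
  assert(acc == 30.0f);  // 1 + 4 + 9 + 16
  return 0;
}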
-// when reduce_dim.size() == 1 and reduce_dim[0] == x_dim.size() - 1, or
-// when reduce_dim.size() != 1 and reduce_dim.size() != x_dim.size(), this
-// function will be used
 template <typename Tx, typename Ty, typename MPType, typename ReduceOp,
-          typename TransformOp>
-__device__ void ReduceAny(const Tx* x, Ty* y, ReduceOp reducer,
-                          TransformOp transformer, MPType init, int reduce_num,
-                          int left_num, bool reduce_lastdim,
-                          const IndexCalculator& reduce_index_calculator,
-                          const IndexCalculator& left_index_calculator) {
+          typename TransformOp, typename Calculator>
+__global__ void ReduceAnyKernel(const Tx* x, Ty* y, ReduceOp reducer,
+                                TransformOp transformer, MPType init,
+                                int reduce_num, int left_num,
+                                bool reduce_last_dim,
+                                const Calculator reduce_index_calculator,
+                                const Calculator left_index_calculator) {
   int input_idx, left_idx, stride;
+  int block_size = 0;
+  bool need_store = true;
+  int tid = 0;
   // the last dim gets involved in reduction
-  if (reduce_lastdim) {
-    input_idx = blockIdx.y * blockDim.x + threadIdx.x;
+  if (reduce_last_dim) {
+    input_idx = blockIdx.y * blockDim.x;
     left_idx = blockIdx.x * blockDim.y + threadIdx.y;
     stride = gridDim.y * blockDim.x;
+    block_size = blockDim.x;
+    need_store = (threadIdx.x == 0) && (left_idx < left_num);
+    tid = threadIdx.x;
   } else {
-    input_idx = blockIdx.y * blockDim.y + threadIdx.y;
+    input_idx = blockIdx.y * blockDim.y;
     left_idx = blockIdx.x * blockDim.x + threadIdx.x;
     stride = gridDim.y * blockDim.y;
+    block_size = blockDim.y;
+    need_store = (threadIdx.y == 0) && (left_idx < left_num);
+    tid = threadIdx.y;
   }
+  int store_offset = blockIdx.y * left_num + left_idx;
   // calculate the offset, means the addr where each thread really start.
-  int input_offset = left_index_calculator.Get(left_idx);
+  int input_offset = left_index_calculator(left_idx);
   const Tx* input = x + input_offset;
   MPType reduce_var = init;
+  Ty store_data;

   // 1. reduce for each thread
   if (left_idx < left_num) {
     // load REDUCE_VEC_SIZE data once, and then compute
     Tx input_reg[REDUCE_VEC_SIZE];
     MPType input_compute[REDUCE_VEC_SIZE];
     int bound = reduce_num - (REDUCE_VEC_SIZE - 1) * stride;
-    while (input_idx < bound) {
-#pragma unroll
-      for (int i = 0; i < REDUCE_VEC_SIZE; ++i) {
-        int reduce_idx = input_idx + i * stride;
-        int idx_x = reduce_index_calculator.Get(reduce_idx);
-        input_reg[i] = input[idx_x];
-      }
-#pragma unroll
-      for (int i = 0; i < REDUCE_VEC_SIZE; ++i) {
-        reduce_var =
-            reducer(reduce_var, static_cast<MPType>(transformer(input_reg[i])));
-      }
-      input_idx += REDUCE_VEC_SIZE * stride;
+    for (; input_idx + block_size < bound;
+         input_idx += REDUCE_VEC_SIZE * stride) {
+      kps::ReadDataReduce<Tx, 1, REDUCE_VEC_SIZE, 1, 1, Calculator>(
+          &input_reg[0], input, input_idx, reduce_index_calculator, 1,
+          reduce_num, 1, stride, reduce_last_dim);
+      kps::ElementwiseUnary<Tx, MPType, REDUCE_VEC_SIZE, 1, 1, TransformOp>(
+          &input_compute[0], &input_reg[0], transformer);
+      kps::Reduce<MPType, REDUCE_VEC_SIZE, 1, 1, ReduceOp,
+                  kps::details::ReduceMode::kLocalMode>(
+          &reduce_var, &input_compute[0], reducer, reduce_last_dim);
     }

     // deal with the remain part
-    int input_idx_tmp = input_idx;
-#pragma unroll
-    for (int i = 0; i < REDUCE_VEC_SIZE; ++i) {
-      if (input_idx >= reduce_num) {
-        break;
-      }
-      int reduce_idx = input_idx;
-      int idx_x = reduce_index_calculator.Get(reduce_idx);
-      input_reg[i] = input[idx_x];
-      input_idx += stride;
-    }
-    input_idx = input_idx_tmp;
+    kps::Init<MPType, REDUCE_VEC_SIZE>(&input_compute[0], init);
+    kps::ReadDataReduce<Tx, 1, REDUCE_VEC_SIZE, 1, 1, Calculator, true>(
+        &input_reg[0], input, input_idx, reduce_index_calculator, 1,
+        reduce_num, 1, stride, reduce_last_dim);
+    input_idx += tid;
 #pragma unroll
     for (int i = 0; i < REDUCE_VEC_SIZE; ++i) {
       if (input_idx >= reduce_num) {
         break;
       }
-      reduce_var =
-          reducer(reduce_var, static_cast<MPType>(transformer(input_reg[i])));
+      input_compute[i] = static_cast<MPType>(transformer(input_reg[i]));
       input_idx += stride;
     }
+    kps::Reduce<MPType, REDUCE_VEC_SIZE, 1, 1, ReduceOp,
+                kps::details::ReduceMode::kLocalMode>(
+        &reduce_var, &input_compute[0], reducer, reduce_last_dim);
   }

-  // 2. reduce in block y
-  if (!reduce_lastdim && blockDim.y > 1) {
-    reduce_var = BlockYReduce(reduce_var, reducer);
-  }
-  __syncthreads();
-
-  if (reduce_lastdim) {
-    // 3. reduce in block x
-    reduce_var = BlockXReduce(reduce_var, reducer);
-    if (left_idx < left_num && threadIdx.x == 0) {
-      y[blockIdx.y * left_num + left_idx] = static_cast<Ty>(reduce_var);
-    }
-  } else {
-    if (left_idx < left_num && threadIdx.y == 0) {
-      y[blockIdx.y * left_num + left_idx] = static_cast<Ty>(reduce_var);
-    }
-  }
+  kps::Reduce<MPType, 1, 1, 1, ReduceOp, kps::details::kGlobalMode>(
+      &reduce_var, &reduce_var, reducer, reduce_last_dim);
+  if (need_store) {
+    y[store_offset] = static_cast<Ty>(reduce_var);
+  }
 }
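The rewritten loop keeps the old structure: a main loop that consumes REDUCE_VEC_SIZE elements per step at a grid-wide stride, then a bounds-checked tail. The same main-loop/tail split in a minimal host sketch:

#include <cassert>

// Main-loop + tail split used by ReduceAnyKernel: full tiles of
// kVecSize elements at a fixed stride, then a bounds-checked remainder.
int StridedSum(const int* data, int n, int start, int stride) {
  const int kVecSize = 4;
  int acc = 0;
  int idx = start;
  int bound = n - (kVecSize - 1) * stride;
  for (; idx < bound; idx += kVecSize * stride) {
    for (int i = 0; i < kVecSize; ++i) acc += data[idx + i * stride];  // full tile
  }
  for (; idx < n; idx += stride) acc += data[idx];  // remainder
  return acc;
}

int main() {
  int v[10] = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9};
  // One "thread" (start 0, stride 2) sums the even entries: 0+2+4+6+8.
  assert(StridedSum(v, 10, 0, 2) == 20);
  return 0;
}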
-// module function designed for global function
 template <typename Tx, typename Ty, typename MPType, typename ReduceOp,
           typename TransformOp>
-__device__ void ReduceModule(const Tx* x, Ty* y, ReduceOp reducer,
-                             TransformOp transformer, MPType init,
-                             int reduce_num, int left_num, int blocking_size,
-                             int reduce_type, bool reduce_lastdim,
-                             const IndexCalculator& reduce_index_calculator,
-                             const IndexCalculator& left_index_calculator) {
-  if (reduce_type == ReduceType::kReduceLastDim ||
-      reduce_type == ReduceType::kReduceAny) {
-    ReduceAny<Tx, Ty, MPType, ReduceOp, TransformOp>(
-        x, y, reducer, transformer, init, reduce_num, left_num,
-        reduce_lastdim, reduce_index_calculator, left_index_calculator);
-    // reduce_rank == 1 && reduce_dim[0] != x_dim.size() - 1
-  } else if (reduce_type == ReduceType::kReduceHigherDim) {
-    ReduceHigherDim<Tx, Ty, MPType, ReduceOp, TransformOp>(
-        x, y, reducer, transformer, init, reduce_num, left_num,
-        blocking_size);
+__global__ void ReduceHigherDimKernel(const Tx* x, Ty* y, ReduceOp reducer,
+                                      TransformOp transformer, MPType init,
+                                      int reduce_num, int left_num,
+                                      int blocking_size) {
+  // when reduce_dim.size() == 1 and reduce_dim[0] != x_dim.size() - 1, this
+  // function will be used
+  // eg: x_dim = {nz, ny, nx}, nx != 1, axis can be 0 or 1
+  // if axis = 1 then grid.z = nz, grid.y = ny / block_size, grid.x = nx / 32
+  // else grid.z = 1, grid.y = ny / block_size, grid.x = nx / 32
+  int idx = blockIdx.x * blockDim.x;
+  int size = left_num - idx;
+  if (size >= blockDim.x) {  // complete segment
+    HigherDimDealSegment<Tx, Ty, MPType, ReduceOp, TransformOp>(
+        x, y, reducer, transformer, init, reduce_num, left_num, blocking_size);
+  } else {
+    HigherDimDealSegment<Tx, Ty, MPType, ReduceOp, TransformOp, true>(
+        x, y, reducer, transformer, init, reduce_num, left_num, blocking_size);
   }
 }
-template <typename Tx, typename Ty, typename MPType, typename ReduceOp,
-          typename TransformOp>
-__global__ void ReduceKernelFunction(const Tx* x, Ty* y, ReduceOp reducer,
-                                     TransformOp transformer, MPType init,
-                                     int reduce_num, int left_num,
-                                     int blocking_size, int reduce_type,
-                                     bool reduce_lastdim,
-                                     IndexCalculator reduce_index_calculator,
-                                     IndexCalculator left_index_calculator) {
-  ReduceModule<Tx, Ty, MPType, ReduceOp, TransformOp>(
-      x, y, reducer, transformer, init, reduce_num, left_num, blocking_size,
-      reduce_type, reduce_lastdim, reduce_index_calculator,
-      left_index_calculator);
-}
 template <typename Tx, typename Ty, typename MPType, typename ReduceOp>
 static void LaunchReduceKernel(const Tx* x_data, Ty* y_data,
                                const ReduceOp& reducer, MPType init,
                                gpuStream_t stream, ReduceConfig<Ty> config) {
   using TransformOp = typename ReduceOp::Transformer;

-  int reduce_rank = config.reduce_strides.size();
-  int left_rank = config.left_strides.size();
-  auto reduce_index_calculator = IndexCalculator(
-      reduce_rank, config.reduce_dim, config.reduce_strides, config.x_strides);
-  auto left_index_calculator = IndexCalculator(
-      left_rank, config.left_dim, config.left_strides, config.x_strides);
-
-  ReduceKernelFunction<Tx, Ty, MPType, ReduceOp,
-                       TransformOp><<<config.grid, config.block, 0, stream>>>(
-      x_data, config.output_data, reducer, TransformOp(config.reduce_num),
-      init, config.reduce_num, config.left_num, config.blocking_size,
-      config.reduce_type, config.reduce_lastdim, reduce_index_calculator,
-      left_index_calculator);
+  if (config.reduce_type == kReduceLastDim) {
+    int stride_reduce = 1;
+    int stride_left = config.reduce_num;
+    // for higher performance
+    auto reduce_index_calculator = LastDimIndexCal(stride_reduce);
+    auto left_index_calculator = LastDimIndexCal(stride_left);
+
+    ReduceAnyKernel<Tx, Ty, MPType, ReduceOp, TransformOp,
+                    LastDimIndexCal><<<config.grid, config.block, 0, stream>>>(
+        x_data, config.output_data, reducer, TransformOp(config.reduce_num),
+        init, config.reduce_num, config.left_num, config.reduce_last_dim,
+        reduce_index_calculator, left_index_calculator);
+
+  } else {
+    int reduce_rank = config.reduce_strides.size();
+    int left_rank = config.left_strides.size();
+    auto reduce_index_calculator =
+        IndexCalculator(reduce_rank, config.reduce_dim, config.reduce_strides,
+                        config.x_strides);
+    auto left_index_calculator = IndexCalculator(
+        left_rank, config.left_dim, config.left_strides, config.x_strides);
+
+    ReduceAnyKernel<Tx, Ty, MPType, ReduceOp, TransformOp,
+                    IndexCalculator><<<config.grid, config.block, 0, stream>>>(
+        x_data, config.output_data, reducer, TransformOp(config.reduce_num),
+        init, config.reduce_num, config.left_num, config.reduce_last_dim,
+        reduce_index_calculator, left_index_calculator);
+  }

   if (config.should_reduce_again) {
     dim3 block;
     dim3 grid;
-    if (config.reduce_lastdim) {
+    if (config.reduce_last_dim) {
       block = dim3(32, 1, 1);
-      grid = dim3(detail::AlignUp(config.left_num, 32), 1, 1);
+      grid = dim3(details::AlignUp(config.left_num, 32), 1, 1);
     } else {
       block = dim3(config.block.x, 1, 1);
       grid = dim3(config.grid.x, 1, config.grid.z);
     }

-    ReduceKernelFunction<
-        Ty, Ty, MPType, ReduceOp,
-        detail::IdentityFunctor<Ty, MPType>><<<grid, block, 0, stream>>>(
-        config.output_data, y_data, reducer,
-        detail::IdentityFunctor<Ty, MPType>(config.grid.y), init, config.grid.y,
-        config.left_num, config.grid.y, ReduceType::kReduceHigherDim,
-        config.reduce_lastdim, reduce_index_calculator,
-        left_index_calculator);
+    ReduceHigherDimKernel<
+        Ty, Ty, MPType, ReduceOp,
+        kps::details::IdentityFunctor<Ty, MPType>><<<grid, block, 0, stream>>>(
+        config.output_data, y_data, reducer,
+        kps::details::IdentityFunctor<Ty, MPType>(config.grid.y), init,
+        config.grid.y, config.left_num, config.grid.y);
   }
 }
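should_reduce_again encodes a two-pass scheme: the first launch leaves one partial result per grid.y slice, and the second folds those partials through an identity transformer. The control flow in a host-side sketch:

#include <cassert>
#include <vector>

// Two-pass reduction: pass 1 leaves one partial per "block row"
// (grid.y), pass 2 reduces the partials -- the should_reduce_again path.
int main() {
  std::vector<int> data(1000, 1);
  const int kSlices = 8;  // stand-in for grid.y
  std::vector<int> partials(kSlices, 0);
  for (size_t i = 0; i < data.size(); ++i) {
    partials[i % kSlices] += data[i];  // pass 1: per-slice partial sums
  }
  int total = 0;
  for (int p : partials) total += p;   // pass 2: fold the partials
  assert(total == 1000);
  return 0;
}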
@@ -812,6 +736,39 @@ void TensorReduceFunctorImpl(const framework::Tensor& x, framework::Tensor* y,
   using MPType = typename details::MPTypeTrait<Ty>::Type;
   auto reducer = ReduceOp<Tx, MPType>();

+  // launch ReduceHigherDimKernel
+  // when reduce_dim.size() == 1 and reduce_dim[0] != x_dim.size() - 1, this
+  // function will be used
+  // eg: x_dim = {nz, ny, nx}, nx != 1, axis can be 0 or 1
+  // if axis = 1 then grid.z = nz, grid.y = ny / block_size, grid.x = nx / 32
+  // else grid.z = 1, grid.y = ny / block_size, grid.x = nx / 32
+  if (config.reduce_type == ReduceType::kReduceHigherDim) {
+    using TransformOp = typename ReduceOp<Tx, MPType>::Transformer;
+
+    ReduceHigherDimKernel<
+        Tx, Ty, MPType, ReduceOp<Tx, MPType>,
+        TransformOp><<<config.grid, config.block, 0, stream>>>(
+        x_data, config.output_data, reducer, TransformOp(config.reduce_num),
+        reducer.initial(), config.reduce_num, config.left_num,
+        config.blocking_size);
+
+    if (config.should_reduce_again) {
+      dim3 block = dim3(config.block.x, 1, 1);
+      dim3 grid = dim3(config.grid.x, 1, config.grid.z);
+
+      ReduceHigherDimKernel<
+          Ty, Ty, MPType, ReduceOp<Tx, MPType>,
+          kps::details::IdentityFunctor<Ty, MPType>><<<grid, block, 0,
+                                                       stream>>>(
+          config.output_data, y_data, reducer,
+          kps::details::IdentityFunctor<Ty, MPType>(config.grid.y),
+          reducer.initial(), config.grid.y, config.left_num, config.grid.y);
+    }
+    return;
+  }
+
+  // when reduce_dim.size() == 1 and reduce_dim[0] == x_dim.size() - 1, or
+  // when reduce_dim.size() != 1 and reduce_dim.size() != x_dim.size(), this
+  // function will be used
   LaunchReduceKernel<Tx, Ty, MPType, ReduceOp<Tx, MPType>>(
       x_data, y_data, reducer, reducer.initial(), stream, config);
 }