Commit 9c5d5665 (unverified)
Authored by niuliling123 on Nov 17, 2021; committed via GitHub on Nov 17, 2021.
Modify reduce_op.op.h for xpu2 with kernel primitive api (#36904)
* Modify reduce_op.op.h for xpu2 with kernel primitive api
Parent: d08753df

Showing 6 changed files with 407 additions and 239 deletions (+407 -239).
paddle/fluid/operators/kernel_primitives/datamover_primitives.h    +6    -6
paddle/fluid/operators/kernel_primitives/helper_primitives.h       +32   -47
paddle/fluid/operators/kernel_primitives/kernel_primitives.h       +36   -2
paddle/fluid/operators/margin_cross_entropy_op.cu                  +4    -4
paddle/fluid/operators/reduce_ops/reduce_functor_op.h              +8    -8
paddle/fluid/operators/reduce_ops/reduce_op.cu.h                   +321  -172
paddle/fluid/operators/kernel_primitives/datamover_primitives.h

@@ -360,12 +360,12 @@ __device__ __forceinline__ void ReadDataBc(
  * reduce_last_dim: Used to indicate whether the dimension of reduce contains
  * the lowest dimension.
  */
-template <typename T, int NX, int NY, int BlockSize, int Rank,
-          typename IndexCal, bool IsBoundary = false>
+template <typename Tx, typename Ty, int NX, int NY, int BlockSize, int Rank,
+          typename IndexCal, typename Functor, bool IsBoundary = false>
 __device__ __forceinline__ void ReadDataReduce(
-    T* dst, const T* __restrict__ src, int block_offset,
+    Ty* dst, const Tx* __restrict__ src, int block_offset,
     const IndexCal& index_cal, int size_nx, int size_ny, int stride_nx,
-    int stride_ny, bool reduce_last_dim) {
+    int stride_ny, Functor func, bool reduce_last_dim) {
   int thread_offset = 0;
   int left_idx = 0;
   if (reduce_last_dim) {

@@ -385,7 +385,7 @@ __device__ __forceinline__ void ReadDataReduce(
       }
     }
       uint32_t index_src = index_cal(thread_offset + block_offset);
-      dst[ny] = src[index_src];
+      dst[ny] = static_cast<Ty>(func(src[index_src]));
       thread_offset += stride_ny;
     }
   } else {

@@ -400,7 +400,7 @@ __device__ __forceinline__ void ReadDataReduce(
       }
     }
       uint32_t index_src = index_cal(thread_offset + block_offset);
-      dst[nx + ny * NX] = src[index_src];
+      dst[nx + ny * NX] = static_cast<Ty>(func(src[index_src]));
       thread_offset += stride_ny;
     }
   }
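With the extra Functor template parameter, the per-element transform is applied while data is gathered through index_cal rather than in a separate pass over the register buffer, and the destination type (Ty) may differ from the source type (Tx). A minimal device-side fragment of the new call shape, mirroring how ReduceAnyKernel later in this commit uses it; the surrounding names (input, input_idx, reduce_index_calculator, stride, reduce_num, reduce_last_dim, Calculator) are assumed to exist as in that kernel, so this is a sketch, not a standalone program:

    // Sketch only: assumes the in-tree kernel_primitives headers and the local
    // variables of a reduce kernel.
    Tx input_reg[REDUCE_VEC_SIZE];
    // The trailing functor argument (here the no-op kps::IdentityFunctor<Tx>)
    // runs on every element as it is read.
    kps::ReadDataReduce<Tx, Tx, 1, REDUCE_VEC_SIZE, 1, 1, Calculator,
                        kps::IdentityFunctor<Tx>, false>(
        &input_reg[0], input, input_idx, reduce_index_calculator,
        /*size_nx=*/1, /*size_ny=*/reduce_num, /*stride_nx=*/1,
        /*stride_ny=*/stride, kps::IdentityFunctor<Tx>(), reduce_last_dim);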
paddle/fluid/operators/kernel_primitives/helper_primitives.h

@@ -17,64 +17,49 @@
 namespace paddle {
 namespace operators {
 namespace kernel_primitives {
-namespace details {
-static __device__ __forceinline__ platform::float16 ExpFunctor(
-    platform::float16 x) {
-  return ::Eigen::numext::exp(x);
-}
-static __device__ __forceinline__ float ExpFunctor(float x) { return expf(x); }
-static __device__ __forceinline__ double ExpFunctor(double x) { return exp(x); }
-static __device__ __forceinline__ platform::float16 LogFunctor(
-    platform::float16 x) {
-  return ::Eigen::numext::log(x);
-}
-static __device__ __forceinline__ float LogFunctor(float x) { return logf(x); }
-static __device__ __forceinline__ double LogFunctor(double x) { return log(x); }
-
-/*************************** Compute Functor****************************/
-// for margin_cross_entropy
-template <typename Tx, typename Ty = Tx>
-struct ExpLogitTransformer {
-  HOSTDEVICE explicit inline ExpLogitTransformer(int n) {}
-
-  HOSTDEVICE inline Ty operator()(const Tx* x) const {
-    return static_cast<Ty>(details::ExpFunctor(x[0]));
-  }
-
-  HOSTDEVICE inline Ty operator()(const Tx& x) const {
-    return static_cast<Ty>(details::ExpFunctor(x));
-  }
-};
-
-// Post processing function for sum, max, min, prod, any
-template <typename Tx, typename Ty = Tx>
-struct IdentityFunctor {
-  HOSTDEVICE explicit inline IdentityFunctor(int n) {}
-
-  HOSTDEVICE inline Ty operator()(const Tx* x) const {
-    return static_cast<Ty>(x[0]);
-  }
-
-  HOSTDEVICE inline Ty operator()(const Tx& x) const {
-    return static_cast<Ty>(x);
-  }
-};
-
-// Post processing function for mean
-template <typename T>
-struct DivideFunctor {
-  HOSTDEVICE explicit inline DivideFunctor(int n) : n_inv((T)(1.0 / n)) {}
-
-  HOSTDEVICE inline T operator()(const T* x) const { return x[0] * n_inv; }
-
-  HOSTDEVICE inline T operator()(const T& x) const { return x * n_inv; }
-
- private:
-  T n_inv;
-};
-}  // namespace details
+#ifdef PADDLE_WITH_XPU2
+struct dim3 {
+  int x;
+  int y;
+  int z;
+
+  explicit inline dim3(int split_x, int split_y = 1, int split_z = 1) {
+    x = split_x;
+    y = split_y;
+    z = split_z;
+  }
+};
+#endif
+
+struct DimConfig {
+  int split_num_x;
+  int split_num_y;
+  int split_num_z;
+  int deal_size_x;
+  int deal_size_y;
+  int deal_size_z;
+  int rem_x;
+  int rem_y;
+  int rem_z;
+
+  HOSTDEVICE explicit inline DimConfig(int split_x, int split_y, int split_z,
+                                       int size_x, int size_y, int size_z) {
+    split_num_x = split_x;
+    split_num_y = split_y;
+    split_num_z = split_z;
+    deal_size_x = size_x;
+    deal_size_y = size_y;
+    deal_size_z = size_z;
+  }
+
+  HOSTDEVICE void SetRem(int rem_nx, int rem_ny, int rem_nz) {
+    rem_x = rem_nx;
+    rem_y = rem_ny;
+    rem_z = rem_nz;
+  }
+};
 }  // namespace kernel_primitives
 }  // namespace operators
 }  // namespace paddle
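DimConfig carries the launch geometry (split counts, per-split deal sizes, and remainders) into device code so that on XPU2, where CUDA's blockDim/gridDim built-ins do not exist, kernels can still recover it; the dim3 shim plays the role of CUDA's dim3 on that backend. A minimal host-side sketch of how it is filled, mirroring the usage added to LaunchReduceKernel later in this diff (the config object with grid/block members is an assumption borrowed from ReduceConfig):

    // Sketch, assuming a ReduceConfig-like `config` with CUDA-style grid/block
    // and the kps alias for paddle::operators::kernel_primitives.
    kps::DimConfig dim = kps::DimConfig(config.grid.x, config.grid.y,
                                        config.grid.z, config.block.x,
                                        config.block.y, 0);
    // Remainders tell boundary iterations how many valid elements are left
    // when the problem size does not divide the split size evenly.
    dim.SetRem(config.reduce_num % config.block.x, 0, 0);
    // `dim` is then passed by value to the reduce kernels in reduce_op.cu.h.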
paddle/fluid/operators/kernel_primitives/kernel_primitives.h

@@ -13,11 +13,45 @@
 // limitations under the License.

 #pragma once
+#include "paddle/fluid/operators/kernel_primitives/functor_primitives.h"
+#include "paddle/fluid/operators/kernel_primitives/helper_primitives.h"
+
+#ifdef PADDLE_WITH_XPU2
+#include "paddle/fluid/operators/kernel_primitives/compute_primitives_xpu2.h"
+#include "paddle/fluid/operators/kernel_primitives/datamover_primitives_xpu2.h"
+
+#define THREAD_ID_X core_id()
+#define THREAD_ID_Y 0
+#define THREAD_ID_Z 0
+
+#define BLOCK_NUM_X core_num()
+#define BLOCK_NUM_Y 0
+#define BLOCK_NUM_Z 0
+
+#define BLOCK_ID_X cluster_id()
+#define BLOCK_ID_Y 0
+#define BLOCK_ID_Z 0
+
+#define GRID_NUM_X cluster_num()
+#define GRID_NUM_Y 0
+#define GRID_NUM_Z 0
+#else
 #include "paddle/fluid/operators/kernel_primitives/compute_primitives.h"
 #include "paddle/fluid/operators/kernel_primitives/datamover_primitives.h"
-#include "paddle/fluid/operators/kernel_primitives/functor_primitives.h"
-#include "paddle/fluid/operators/kernel_primitives/helper_primitives.h"
+
+#define THREAD_ID_X threadIdx.x
+#define THREAD_ID_Y threadIdx.y
+#define THREAD_ID_Z threadIdx.z
+
+#define BLOCK_NUM_X blockDim.x
+#define BLOCK_NUM_Y blockDim.y
+#define BLOCK_NUM_Z blockDim.z
+
+#define BLOCK_ID_X blockIdx.x
+#define BLOCK_ID_Y blockIdx.y
+#define BLOCK_ID_Z blockIdx.z
+
+#define GRID_NUM_X gridDim.x
+#define GRID_NUM_Y gridDim.y
+#define GRID_NUM_Z gridDim.z
+#endif

 namespace paddle {
 namespace operators {
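These macros give kernels a backend-neutral way to query thread and block indices: on CUDA they expand to threadIdx/blockDim/blockIdx/gridDim, while on XPU2 they expand to core_id()/core_num()/cluster_id()/cluster_num() (with the Y/Z variants fixed to 0, since XPU2 launches are one-dimensional). A hedged illustration of a kernel body written against them; the kernel itself is not part of this commit:

    // Illustrative only: a grid-stride fill using the backend-neutral macros.
    template <typename T>
    __global__ void FillKernel(T* out, T value, int n) {
      int tid = BLOCK_ID_X * BLOCK_NUM_X + THREAD_ID_X;  // global lane id
      int stride = GRID_NUM_X * BLOCK_NUM_X;             // total lanes launched
      for (int i = tid; i < n; i += stride) {
        out[i] = value;
      }
    }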
paddle/fluid/operators/margin_cross_entropy_op.cu

@@ -130,7 +130,7 @@ __global__ void AddMarginToPositiveLogitsKernel(
 template <typename Tx, typename Ty = Tx>
 struct ExpAndSum {
-  using Transformer = kpds::ExpLogitTransformer<Tx>;
+  using Transformer = kps::ExpFunctor<Tx>;

   inline Ty initial() { return static_cast<Ty>(0.0f); }

@@ -159,7 +159,7 @@ __global__ void LogitsMinusLogSumKernel(T* logits, const T* logits_sum_per_row,
                                         const int64_t N, const int64_t D) {
   CUDA_KERNEL_LOOP(i, N * D) {
     auto row = i / D;
-    logits[i] -= kpds::LogFunctor(logits_sum_per_row[row]);
+    logits[i] -= kps::details::Log(logits_sum_per_row[row]);
   }
 }

@@ -174,9 +174,9 @@ __global__ void HardLabelSoftmaxWithCrossEntropyKernel(
     if ((col + start_index) == labels[row]) {
       auto softmax = log_softmax[i];
       loss[row] = -softmax;
-      log_softmax[i] = kpds::ExpFunctor(softmax);
+      log_softmax[i] = kps::details::Exp(softmax);
     } else {
-      log_softmax[i] = kpds::ExpFunctor(log_softmax[i]);
+      log_softmax[i] = kps::details::Exp(log_softmax[i]);
     }
   }
 }
paddle/fluid/operators/reduce_ops/reduce_functor_op.h

@@ -24,11 +24,11 @@ limitations under the License. */
 namespace paddle {
 namespace operators {

-namespace kpds = paddle::operators::kernel_primitives::details;
+namespace kps = paddle::operators::kernel_primitives;

 template <typename Tx, typename Ty = Tx>
 struct CustomMin {
-  using Transformer = kpds::IdentityFunctor<Tx>;
+  using Transformer = kps::IdentityFunctor<Tx>;

   inline Ty initial() {
     return static_cast<Ty>(std::numeric_limits<Ty>::max());

@@ -41,7 +41,7 @@ struct CustomMin {
 template <typename Tx, typename Ty = Tx>
 struct CustomMax {
-  using Transformer = kpds::IdentityFunctor<Tx>;
+  using Transformer = kps::IdentityFunctor<Tx>;

   inline Ty initial() {
     return static_cast<Ty>(std::numeric_limits<Ty>::lowest());

@@ -55,7 +55,7 @@ struct CustomMax {
 // for cub::Reduce
 template <typename Tx, typename Ty = Tx>
 struct CustomSum {
-  using Transformer = kpds::IdentityFunctor<Tx, Ty>;
+  using Transformer = kps::IdentityFunctor<Tx, Ty>;

   inline Ty initial() { return static_cast<Ty>(0.0f); }

@@ -66,7 +66,7 @@ struct CustomSum {
 template <typename Tx, typename Ty = Tx>
 struct CustomMean {
-  using Transformer = kpds::DivideFunctor<Tx>;
+  using Transformer = kps::DivideFunctor<Tx>;

   inline Ty initial() { return static_cast<Ty>(0.0f); }

@@ -77,7 +77,7 @@ struct CustomMean {
 template <typename Tx, typename Ty = Tx>
 struct CustomMul {
-  using Transformer = kpds::IdentityFunctor<Tx>;
+  using Transformer = kps::IdentityFunctor<Tx>;

   inline Ty initial() { return static_cast<Ty>(1.0f); }

@@ -88,7 +88,7 @@ struct CustomMul {
 template <typename Tx, typename Ty = Tx>
 struct CustomLogicalOr {
-  using Transformer = kpds::IdentityFunctor<Tx>;
+  using Transformer = kps::IdentityFunctor<Tx>;

   inline Ty initial() { return static_cast<Ty>(false); }

@@ -99,7 +99,7 @@ struct CustomLogicalOr {
 template <typename Tx, typename Ty = Tx>
 struct CustomLogicalAnd {
-  using Transformer = kpds::IdentityFunctor<Tx>;
+  using Transformer = kps::IdentityFunctor<Tx>;

   inline Ty initial() { return static_cast<Ty>(true); }
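All of these functors now alias the shared kps::IdentityFunctor / kps::DivideFunctor from functor_primitives.h instead of the kpds copies that this commit removes from helper_primitives.h. The reduce path only relies on the small contract sketched below; CustomMaxAbs is a hypothetical example for illustration, not part of the commit:

    // Hypothetical example of the contract a reduce functor must satisfy.
    template <typename Tx, typename Ty = Tx>
    struct CustomMaxAbs {
      // Per-element pre-processing applied while data is loaded
      // (ReadDataReduce) or via ElementwiseUnary; identity here.
      using Transformer = kps::IdentityFunctor<Tx, Ty>;

      // Identity element used to initialize every partial accumulator.
      inline Ty initial() { return static_cast<Ty>(0); }

      // Pairwise combination applied on device by kps::Reduce.
      __device__ __forceinline__ Ty operator()(const Ty& a, const Ty& b) const {
        Ty abs_a = a < 0 ? -a : a;
        Ty abs_b = b < 0 ? -b : b;
        return abs_a > abs_b ? abs_a : abs_b;
      }
    };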
paddle/fluid/operators/reduce_ops/reduce_op.cu.h

@@ -165,10 +165,93 @@ struct IndexCalculator {
   framework::Array<platform::FastDivMod, kMaxRank> divmoders;
 };

+template <bool ReduceLastDim = false>
+struct ReduceIndexMapping {
+  const kps::DimConfig dim;
+  HOSTDEVICE explicit ReduceIndexMapping(const kps::DimConfig& dims)
+      : dim(dims) {}
+
+  __device__ __forceinline__ int BlockIdX() {
+#ifdef PADDLE_WITH_XPU2
+    if (ReduceLastDim) {
+      return (cluster_id() / dim.split_num_x % dim.split_num_y);
+    } else {
+      return cluster_id() % dim.split_num_x;
+    }
+#else
+    return blockIdx.x;
+#endif
+  }
+
+  __device__ __forceinline__ int BlockIdY() {
+#ifdef PADDLE_WITH_XPU2
+    if (ReduceLastDim) {
+      return (cluster_id() % dim.split_num_x);
+    } else {
+      return (cluster_id() / dim.split_num_x % dim.split_num_y);
+    }
+#else
+    return blockIdx.y;
+#endif
+  }
+
+  __device__ __forceinline__ int BlockDimX() {
+#ifdef PADDLE_WITH_XPU2
+    return dim.deal_size_x;
+#else
+    return blockDim.x;
+#endif
+  }
+
+  __device__ __forceinline__ int BlockDimY() {
+#ifdef PADDLE_WITH_XPU2
+    return dim.deal_size_y;
+#else
+    return blockDim.y;
+#endif
+  }
+
+  __device__ __forceinline__ int GridDimX() {
+#ifdef PADDLE_WITH_XPU2
+    if (ReduceLastDim) {
+      return dim.split_num_y;
+    } else {
+      return dim.split_num_x;
+    }
+#else
+    return gridDim.x;
+#endif
+  }
+
+  __device__ __forceinline__ int GridDimY() {
+#ifdef PADDLE_WITH_XPU2
+    if (ReduceLastDim) {
+      return dim.split_num_x;
+    } else {
+      return dim.split_num_y;
+    }
+#else
+    return gridDim.y;
+#endif
+  }
+
+  __device__ __forceinline__ int GetLoopSize() {
+#ifdef PADDLE_WITH_XPU2
+    if (ReduceLastDim) {
+      return dim.deal_size_y;
+    } else {
+      return dim.deal_size_x;
+    }
+#else
+    return 1;
+#endif
+  }
+};
+
 // when reduce_type == kReduceLastDim this struct will be used
 // for higher performance
-struct LastDimIndexCal {
-  explicit LastDimIndexCal(int num) : stride(num) {}
+struct OneDimIndexCal {
+  explicit OneDimIndexCal(int num) : stride(num) {}

   __device__ inline int operator()(int index) const { return index * stride; }

   int stride;
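ReduceIndexMapping wraps every block/grid query so the kernels below ask "which logical block am I" instead of reading blockIdx/gridDim directly: on CUDA it is a thin pass-through, on XPU2 it derives the two logical block coordinates from the single cluster_id() using the split counts stored in DimConfig. A hedged device-side fragment of how a kernel consumes it, mirroring the ReduceAnyKernel change later in this file (dim, left_num, and the reduce-last-dim choice are assumed from that kernel):

    // Fragment only: `dim` is the kps::DimConfig passed into the kernel.
    auto block = ReduceIndexMapping<true>(dim);        // last dim is reduced
    int input_idx = block.BlockIdY() * block.BlockDimX();
    int left_idx  = block.BlockIdX() * block.BlockDimY() + THREAD_ID_Y;
    int stride    = block.GridDimY() * block.BlockDimX();
    // GetLoopSize() is 1 on CUDA (one output column per block) and
    // deal_size_y on XPU2 (each cluster loops over several columns).
    int loop_left = min(block.GetLoopSize(), left_num - left_idx);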
@@ -331,8 +414,16 @@ struct ReduceConfig {
     if (rank == reduce_rank || is_last_dim) {
       reduce_type = static_cast<int>(ReduceType::kReduceLastDim);
     } else if (reduce_rank == 1) {
-      // ReduceFirstDim and reduceSecondDim
+// ReduceFirstDim and reduceSecondDim
+#ifdef PADDLE_WITH_XPU2
+      if (reduce_dim[0] == 0) {
+        reduce_type = static_cast<int>(ReduceType::kReduceHigherDim);
+      } else {
+        reduce_type = static_cast<int>(ReduceType::kReduceAny);
+      }
+#else
       reduce_type = static_cast<int>(ReduceType::kReduceHigherDim);
+#endif
     } else {
       reduce_type = static_cast<int>(ReduceType::kReduceAny);
     }
@@ -408,59 +499,61 @@ struct ReduceConfig {
   // for ReduceHigherDim: if block is enough -> splite reduce_num
   //                     else init block(32, 1) grid(block_num, 1)
   // for others: block(block_num, 1) , grid(left_num, 1)
+  void SetBlockDimForHigher(dim3* block_dim, dim3* grid_dim) {
+    int last_dim_num = x_dim.back();
+    // update left_num
+    int grid_z = left_num / last_dim_num;
+    left_num = last_dim_num;
+    grid_dim->z = grid_z;
+    int device_id = platform::GetCurrentDeviceId();
+    int max_mp = platform::GetCUDAMultiProcessors(device_id);
+    int max_threads_per_mp =
+        platform::GetCUDAMaxThreadsPerMultiProcessor(device_id);
+    int max_threads = max_threads_per_mp * max_mp;
+    // init
+    int num_block = (max_threads / left_num);
+    block_dim->x = details::GetBlockDim(left_num);
+    grid_dim->x = details::AlignUp(left_num, block_dim->x);
+    blocking_size = reduce_num;
+    if (num_block > 1 && reduce_num >= REDUCE_SPLIT_BOUNDARY) {
+      blocking_size = details::GetLastPow2(reduce_num / num_block);
+      if (blocking_size <= 1) {
+        blocking_size = details::GetLastPow2(sqrt(reduce_num));
+      } else if (blocking_size * 2 < reduce_num) {
+        blocking_size *= 2;
+      }
+      should_reduce_again = true;
+      grid_dim->y = details::AlignUp(reduce_num, blocking_size);
+    }
+  }
+
   void SetBlockDim() {
     // init
     int block_num = details::GetBlockDim(reduce_num);
     should_reduce_again = false;
-    dim3 block_dim(block_num, 1);
-    dim3 grid_dim(left_num, 1);
+    dim3 block_dim(block_num, 1, 1);
+    dim3 grid_dim(left_num, 1, 1);
     blocking_size = reduce_num;
+#ifdef PADDLE_WITH_XPU2
+    if (reduce_last_dim) {
+      block_dim.x = 128;
+      block_dim.y = reduce_num;
+      grid_dim.x = 8;
+      grid_dim.y = 1;
+    } else {
+      block_dim.x = 128;
+      block_dim.y = left_num;
+      grid_dim.x = 8;
+      grid_dim.y = 1;
+    }
+#else
     if (reduce_type == ReduceType::kReduceHigherDim) {
-      int last_dim_num = x_dim.back();
-      // update left_num
-      int grid_z = left_num / last_dim_num;
-      left_num = last_dim_num;
-      block_dim.z = 1;
-      grid_dim.z = grid_z;
-      int device_id = platform::GetCurrentDeviceId();
-      int max_mp = platform::GetCUDAMultiProcessors(device_id);
-      int max_threads_per_mp =
-          platform::GetCUDAMaxThreadsPerMultiProcessor(device_id);
-      int max_threads = max_threads_per_mp * max_mp;
-      // init
-      int num_block = (max_threads / left_num);
-      if (num_block > 1 && reduce_num >= REDUCE_SPLIT_BOUNDARY) {
-        blocking_size = details::GetLastPow2(reduce_num / num_block);
-        if (blocking_size <= 1) {
-          blocking_size = details::GetLastPow2(sqrt(reduce_num));
-        } else if (blocking_size * 2 < reduce_num) {
-          blocking_size *= 2;
-        }
-        should_reduce_again = true;
-        block_dim.x = details::GetBlockDim(left_num);
-        block_dim.y = 1;
-        grid_dim.x = (left_num + block_dim.x - 1) / block_dim.x;
-        grid_dim.y = (reduce_num + blocking_size - 1) / blocking_size;
-      } else {
-        block_dim.x = details::GetBlockDim(left_num);
-        block_dim.y = 1;
-        blocking_size = reduce_num;
-        grid_dim.x = (left_num + block_dim.x - 1) / block_dim.x;
-        grid_dim.y = 1;
-      }
+      SetBlockDimForHigher(&block_dim, &grid_dim);
     } else {
       SetBlockDimForReduceAny(&block_dim, &grid_dim);
     }
+#endif

     block = block_dim;
     grid = grid_dim;
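A worked numeric example of the blocking_size rule in SetBlockDimForHigher above, as a self-contained sketch; GetLastPow2 and AlignUp are simplified stand-ins for the in-tree details:: helpers, the shape and device numbers are illustrative, and the REDUCE_SPLIT_BOUNDARY / num_block guard is assumed to have passed:

    // Self-contained illustration of the blocking_size choice above.
    #include <cmath>
    #include <cstdio>

    static int GetLastPow2(int n) {  // largest power of two <= n (stand-in)
      int p = 1;
      while (p * 2 <= n) p *= 2;
      return n > 0 ? p : 1;
    }
    static int AlignUp(int a, int b) { return (a + b - 1) / b; }  // stand-in

    int main() {
      // Assume left_num = 512 outputs, reduce_num = 8192 rows per output,
      // and a device that can keep max_threads = 2048 * 80 threads resident.
      int left_num = 512, reduce_num = 8192, max_threads = 2048 * 80;
      int num_block = max_threads / left_num;                   // 320
      int blocking_size = GetLastPow2(reduce_num / num_block);  // GetLastPow2(25) = 16
      if (blocking_size <= 1) {
        blocking_size = GetLastPow2(static_cast<int>(std::sqrt(reduce_num)));
      } else if (blocking_size * 2 < reduce_num) {
        blocking_size *= 2;                                     // 32 rows per partial sum
      }
      // grid.y partial results per output column are reduced again afterwards.
      printf("blocking_size = %d, grid.y = %d\n", blocking_size,
             AlignUp(reduce_num, blocking_size));               // 32, 256
      return 0;
    }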
@@ -487,72 +580,6 @@ struct ReduceConfig {
   dim3 block;
   dim3 grid;
 };

-/* size : how many colonms left have to be reduced
- * loop : how many rows data have to be reduced
- * block_size: max rows this block to reduce
- */
-template <typename Tx, typename Ty, typename MPType, typename ReduceOp,
-          typename TransformOp, bool IsBoundary = false>
-__device__ void HigherDimDealSegment(const Tx* x, Ty* y, ReduceOp reducer,
-                                     TransformOp transformer, MPType init,
-                                     int reduce_num, int left_num,
-                                     int block_size) {
-  const int NY = 1;
-  int idx = blockIdx.x * blockDim.x;
-  int idy = blockIdx.y * block_size;
-  // block_offset of rows
-  Tx reduce_input[NY];
-  MPType reduce_compute[NY];
-  MPType result = init;
-  // the offset of this block
-  int block_offset = idy * left_num + idx + blockIdx.z * reduce_num * left_num;
-  const Tx* input = x + block_offset;
-  int store_offset =
-      blockIdx.y * left_num + blockIdx.z * gridDim.y * left_num + idx;
-  // how many columns left
-  int size = left_num - idx;
-  // how many rows have to be reduced
-  int loop = reduce_num - idy;
-  loop = loop > block_size ? block_size : loop;
-
-  for (int loop_index = 0; loop_index < loop; loop_index += NY) {
-    kps::ReadData<Tx, Tx, 1, NY, 1, IsBoundary>(
-        &reduce_input[0], input + loop_index * left_num, size, NY, 1, left_num);
-    kps::ElementwiseUnary<Tx, MPType, REDUCE_VEC_SIZE, 1, 1, TransformOp>(
-        &reduce_compute[0], &reduce_input[0], transformer);
-    kps::Reduce<MPType, NY, 1, 1, ReduceOp,
-                kps::details::ReduceMode::kLocalMode>(
-        &result, &reduce_compute[0], reducer, false);
-  }
-
-  Ty temp_data = static_cast<Ty>(result);
-  kps::WriteData<Ty, 1, 1, 1, IsBoundary>(y + store_offset, &temp_data, size);
-}
-
-template <typename Tx, typename MPType, typename ReduceOp, typename TransformOp,
-          typename Calculator, bool IsBoundary>
-__device__ void ReduceAnyKernelImpl(const Tx* input, MPType* reduce_var,
-                                    ReduceOp reducer, TransformOp transformer,
-                                    MPType init, int reduce_num, int input_idx,
-                                    bool reduce_last_dim,
-                                    const Calculator& reduce_index_calculator,
-                                    int stride, int num) {
-  Tx input_reg[REDUCE_VEC_SIZE];
-  MPType input_compute[REDUCE_VEC_SIZE];
-  MPType input_transform[REDUCE_VEC_SIZE];
-
-  kps::Init<MPType, REDUCE_VEC_SIZE>(&input_compute[0], init);
-  kps::ReadDataReduce<Tx, 1, REDUCE_VEC_SIZE, 1, 1, Calculator, IsBoundary>(
-      &input_reg[0], input, input_idx, reduce_index_calculator, 1, reduce_num,
-      1, stride, reduce_last_dim);
-  kps::ElementwiseUnary<Tx, MPType, REDUCE_VEC_SIZE, 1, 1, TransformOp>(
-      &input_transform[0], &input_reg[0], transformer);
-  kps::Init<MPType, REDUCE_VEC_SIZE, IsBoundary>(input_compute, input_transform,
-                                                 num);
-  kps::Reduce<MPType, REDUCE_VEC_SIZE, 1, 1, ReduceOp,
-              kps::details::ReduceMode::kLocalMode>(
-      reduce_var, &input_compute[0], reducer, reduce_last_dim);
-}
-
 // when reduce_dim.size() == 1 and reduce_dim[0] == x_dim.size() - 1, or
 // when reduce_dim.size() != 1 and reduce_dim.size() != x_dim.size(), this
@@ -564,54 +591,76 @@ __global__ void ReduceAnyKernel(const Tx* x, Ty* y, ReduceOp reducer,
                                 int reduce_num, int left_num,
                                 bool reduce_last_dim,
                                 const Calculator reduce_index_calculator,
-                                const Calculator left_index_calculator) {
+                                const Calculator left_index_calculator,
+                                const kps::DimConfig dim) {
   int input_idx, left_idx, stride;
   int block_size = 0;
   bool need_store = true;
+  int loop_left = 0;
   int tid = 0;
   // the last dim gets involved in reduction
+  int store_offset = 0;
+  int stride_left = 0;
   if (reduce_last_dim) {
-    input_idx = blockIdx.y * blockDim.x;
-    left_idx = blockIdx.x * blockDim.y + threadIdx.y;
-    stride = gridDim.y * blockDim.x;
-    block_size = blockDim.x;
-    need_store = (threadIdx.x == 0) && (left_idx < left_num);
+    auto block = ReduceIndexMapping<true>(dim);
+    input_idx = block.BlockIdY() * block.BlockDimX();
+    left_idx = block.BlockIdX() * block.BlockDimY() + THREAD_ID_Y;
+    stride = block.GridDimY() * block.BlockDimX();
+    block_size = block.BlockDimX();
+    need_store = (THREAD_ID_X == 0) && (left_idx < left_num);
+    store_offset = block.BlockIdY() * left_num + left_idx;
+    loop_left = min(block.GetLoopSize(), left_num - left_idx);
+    stride_left = 1;
     tid = threadIdx.x;
   } else {
-    input_idx = blockIdx.y * blockDim.y;
-    left_idx = blockIdx.x * blockDim.x + threadIdx.x;
-    stride = gridDim.y * blockDim.y;
-    block_size = blockDim.y;
-    need_store = (threadIdx.y == 0) && (left_idx < left_num);
+    auto block = ReduceIndexMapping<false>(dim);
+    input_idx = block.BlockIdY() * block.BlockDimY();
+    left_idx = block.BlockIdX() * block.BlockDimX() + THREAD_ID_X;
+    stride = block.GridDimY() * block.BlockDimY();
+    block_size = block.BlockDimY();
+    need_store = (THREAD_ID_Y == 0) && (left_idx < left_num);
+    loop_left = min(block.GetLoopSize(), left_num - left_idx);
+    stride_left = block.BlockDimX() * block.GridDimX();
+    store_offset = block.BlockIdY() * left_num + left_idx;
     tid = threadIdx.y;
   }
-  int store_offset = blockIdx.y * left_num + left_idx;
   // calculate the offset, means the addr where each thread really start.
-  int input_offset = left_index_calculator(left_idx);
-  const Tx* input = x + input_offset;
-  MPType reduce_var = init;
-  Ty store_data;
-
   // 1. reduce for each thread
   if (left_idx < left_num) {
+    MPType input_compute[REDUCE_VEC_SIZE];
+    Tx input_reg[REDUCE_VEC_SIZE];
+    for (int i = 0; i < loop_left; i += stride_left) {
+      int input_offset = left_index_calculator(left_idx + i);
+      const Tx* input = x + input_offset;
+      MPType reduce_var = init;
       // load REDUCE_VEC_SIZE data once, and then compute
       int bound = reduce_num - (REDUCE_VEC_SIZE - 1) * stride;
       for (; input_idx + block_size < bound;
            input_idx += REDUCE_VEC_SIZE * stride) {
-      ReduceAnyKernelImpl<Tx, MPType, ReduceOp, TransformOp, Calculator, false>(
-          input, &reduce_var, reducer, transformer, init, reduce_num, input_idx,
-          reduce_last_dim, reduce_index_calculator, stride, reduce_num);
+        kps::ReadDataReduce<Tx, Tx, 1, REDUCE_VEC_SIZE, 1, 1, Calculator,
+                            kps::IdentityFunctor<Tx>, false>(
+            &input_reg[0], input, input_idx, reduce_index_calculator, 1,
+            reduce_num, 1, stride, kps::IdentityFunctor<Tx>(),
+            reduce_last_dim);
+        kps::ElementwiseUnary<Tx, MPType, REDUCE_VEC_SIZE, 1, 1, TransformOp>(
+            &input_compute[0], &input_reg[0], transformer);
+        kps::Reduce<MPType, REDUCE_VEC_SIZE, 1, 1, ReduceOp,
+                    kps::details::ReduceMode::kLocalMode>(
+            &reduce_var, &input_compute[0], reducer, reduce_last_dim);
       }
-    int num = (reduce_num - input_idx - tid + stride - 1) / stride;
-    ReduceAnyKernelImpl<Tx, MPType, ReduceOp, TransformOp, Calculator, true>(
-        input, &reduce_var, reducer, transformer, init, reduce_num - input_idx,
-        input_idx, reduce_last_dim, reduce_index_calculator, stride, num);
-  }
-
-  kps::Reduce<MPType, 1, 1, 1, ReduceOp, kps::details::kGlobalMode>(
-      &reduce_var, &reduce_var, reducer, reduce_last_dim);
-  if (need_store) {
-    y[store_offset] = static_cast<Ty>(reduce_var);
+
+      kps::Init<MPType, REDUCE_VEC_SIZE>(&input_compute[0], init);
+      kps::ReadDataReduce<Tx, MPType, 1, REDUCE_VEC_SIZE, 1, 1, Calculator,
+                          TransformOp, true>(
+          &input_compute[0], input, input_idx, reduce_index_calculator, 1,
+          reduce_num - input_idx, 1, stride, transformer, reduce_last_dim);
+      kps::Reduce<MPType, REDUCE_VEC_SIZE, 1, 1, ReduceOp,
+                  kps::details::ReduceMode::kLocalMode>(
+          &reduce_var, &input_compute[0], reducer, reduce_last_dim);
+
+      kps::Reduce<MPType, 1, 1, 1, ReduceOp, kps::details::kGlobalMode>(
+          &reduce_var, &reduce_var, reducer, reduce_last_dim);
+      if (need_store) {
+        y[store_offset + i] = static_cast<Ty>(reduce_var);
+      }
+    }
   }
 }
@@ -620,21 +669,55 @@ template <typename Tx, typename Ty, typename MPType, typename ReduceOp,
 __global__ void ReduceHigherDimKernel(const Tx* x, Ty* y, ReduceOp reducer,
                                       TransformOp transformer, MPType init,
                                       int reduce_num, int left_num,
-                                      int blocking_size) {
+                                      int blocking_size,
+                                      const kps::DimConfig dim) {
   // when reduce_dim.size() == 1 and reduce_dim[0] != x_dim.size() - 1, this
   // function will be used
   // eg: x_dim = {nz, ny, nx}, nx != 1, axis can be 0 or 1
   // if axis = 1 then grid.z = nz, grid.y = ny / block_size, grid.x = nx / 32
   // else grid.z = 1, grid.y = ny / block_size, grid.x = nx /32
-  int idx = blockIdx.x * blockDim.x;
-  int size = left_num - idx;
-  if (size >= blockDim.x) {  // complete segment
-    HigherDimDealSegment<Tx, Ty, MPType, ReduceOp, TransformOp>(
-        x, y, reducer, transformer, init, reduce_num, left_num, blocking_size);
-  } else {
-    HigherDimDealSegment<Tx, Ty, MPType, ReduceOp, TransformOp, true>(
-        x, y, reducer, transformer, init, reduce_num, left_num, blocking_size);
+  auto block = ReduceIndexMapping<false>(dim);
+  int idy = block.BlockIdY() * blocking_size;
+  int idx = block.BlockIdX() * block.BlockDimX();
+  int idz = BLOCK_ID_Z * left_num;
+  int stride = dim.split_num_x * dim.deal_size_x;
+  int size = left_num - dim.rem_x;
+  int loop_size = min(reduce_num - idy, blocking_size);
+  int store_offset = block.BlockIdY() * left_num + idz * block.GridDimY();
+  int block_offset = idy * left_num + idz * reduce_num;
+  const Tx* input = x + block_offset;
+  Tx reduce_input;
+  for (; idx < size; idx += stride) {
+    MPType reduce_var = init;
+    MPType reduce_compute = init;
+    for (int loop_idx = 0; loop_idx < loop_size; ++loop_idx) {
+      kps::ReadData<Tx, Tx, 1, 1, 1, false>(&reduce_input,
+                                            input + loop_idx * left_num + idx,
+                                            block.BlockDimX(), 1, 1, left_num);
+      kps::ElementwiseUnary<Tx, MPType, REDUCE_VEC_SIZE, 1, 1, TransformOp>(
+          &reduce_compute, &reduce_input, transformer);
+      kps::Reduce<MPType, 1, 1, 1, ReduceOp,
+                  kps::details::ReduceMode::kLocalMode>(
+          &reduce_var, &reduce_compute, reducer, false);
+    }
+    Ty result = static_cast<Ty>(reduce_var);
+    kps::WriteData<Ty, 1, 1, 1, false>(y + store_offset + idx, &result,
+                                       block.BlockDimX());
+  }
+
+  if (idx < left_num) {
+    MPType reduce_var = init;
+    MPType reduce_compute = init;
+    for (int loop_idx = 0; loop_idx < loop_size; ++loop_idx) {
+      kps::ReadData<Tx, Tx, 1, 1, 1, true>(&reduce_input,
+                                           input + loop_idx * left_num + idx,
+                                           dim.rem_x, 1, 1, left_num);
+      kps::ElementwiseUnary<Tx, MPType, REDUCE_VEC_SIZE, 1, 1, TransformOp>(
+          &reduce_compute, &reduce_input, transformer);
+      kps::Reduce<MPType, 1, 1, 1, ReduceOp,
+                  kps::details::ReduceMode::kLocalMode>(
+          &reduce_var, &reduce_compute, reducer, false);
+    }
+    Ty result = static_cast<Ty>(reduce_var);
+    kps::WriteData<Ty, 1, 1, 1, true>(y + store_offset + idx, &result,
+                                      dim.rem_x);
   }
 }
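A concrete reading of the grid-layout comment inside ReduceHigherDimKernel above, as a self-contained arithmetic check; the shape and block size are illustrative, the real values come from ReduceConfig:

    // Worked example of the layout comment (reduce axis is not the last one).
    #include <cstdio>

    int main() {
      int nz = 4, ny = 1024, nx = 512;  // x_dim = {nz, ny, nx}, reduce axis = 1
      int block_size = 64;              // rows summed per block before re-reduce
      int grid_z = nz;                  // one z-slice per independent batch
      int grid_y = ny / block_size;     // partial sums per column, reduced again
      int grid_x = nx / 32;             // 32 output columns handled per x-block
      printf("grid = (%d, %d, %d)\n", grid_x, grid_y, grid_z);  // (16, 16, 4)
      return 0;
    }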
@@ -648,14 +731,27 @@ static void LaunchReduceKernel(const Tx* x_data, Ty* y_data,
     int stride_reduce = 1;
     int stride_left = config.reduce_num;
     // for higher performance
-    auto reduce_index_calculator = LastDimIndexCal(stride_reduce);
-    auto left_index_calculator = LastDimIndexCal(stride_left);
+    auto reduce_index_calculator = OneDimIndexCal(stride_reduce);
+    auto left_index_calculator = OneDimIndexCal(stride_left);
+    kps::DimConfig dim =
+        kps::DimConfig(config.grid.x, config.grid.y, config.grid.z,
+                       config.block.x, config.block.y, 0);
+    dim.SetRem(config.reduce_num % config.block.x, 0, 0);
+
+#ifdef PADDLE_WITH_XPU2
+    ReduceAnyKernel<Tx, Ty, MPType, ReduceOp, TransformOp,
+                    OneDimIndexCal><<<8, 128, stream>>>(
+        x_data, config.output_data, reducer, TransformOp(config.reduce_num),
+        init, config.reduce_num, config.left_num, config.reduce_last_dim,
+        reduce_index_calculator, left_index_calculator, dim);
+#else
     ReduceAnyKernel<Tx, Ty, MPType, ReduceOp, TransformOp,
-                    LastDimIndexCal><<<config.grid, config.block, 0, stream>>>(
+                    OneDimIndexCal><<<config.grid, config.block, 0, stream>>>(
         x_data, config.output_data, reducer, TransformOp(config.reduce_num),
         init, config.reduce_num, config.left_num, config.reduce_last_dim,
-        reduce_index_calculator, left_index_calculator);
+        reduce_index_calculator, left_index_calculator, dim);
+#endif
   } else {
     int reduce_rank = config.reduce_strides.size();
@@ -665,11 +761,25 @@ static void LaunchReduceKernel(const Tx* x_data, Ty* y_data,
                                            config.x_strides);
     auto left_index_calculator = IndexCalculator(
         left_rank, config.left_dim, config.left_strides, config.x_strides);
+    kps::DimConfig dim =
+        kps::DimConfig(config.grid.x, config.grid.y, config.grid.z,
+                       config.block.x, config.block.y, 0);
+    dim.SetRem(config.reduce_num % config.block.x, 0, 0);
+
+#ifdef PADDLE_WITH_XPU2
+    ReduceAnyKernel<Tx, Ty, MPType, ReduceOp, TransformOp,
+                    IndexCalculator><<<8, 128, stream>>>(
+        x_data, config.output_data, reducer, TransformOp(config.reduce_num),
+        init, config.reduce_num, config.left_num, config.reduce_last_dim,
+        reduce_index_calculator, left_index_calculator, dim);
+#else
     ReduceAnyKernel<Tx, Ty, MPType, ReduceOp, TransformOp,
                     IndexCalculator><<<config.grid, config.block, 0, stream>>>(
         x_data, config.output_data, reducer, TransformOp(config.reduce_num),
         init, config.reduce_num, config.left_num, config.reduce_last_dim,
-        reduce_index_calculator, left_index_calculator);
+        reduce_index_calculator, left_index_calculator, dim);
+#endif
   }

   if (config.should_reduce_again) {
@@ -683,12 +793,25 @@ static void LaunchReduceKernel(const Tx* x_data, Ty* y_data,
       grid = dim3(config.grid.x, 1, config.grid.z);
     }

+    auto last_index = OneDimIndexCal(1);
+    auto first_index = OneDimIndexCal(config.left_num);
+    kps::DimConfig dim =
+        kps::DimConfig(grid.x, grid.y, grid.z, block.x, config.grid.y, 0);
+    dim.SetRem(config.left_num % block.x, 0, 0);
+#ifdef PADDLE_WITH_XPU2
+    ReduceHigherDimKernel<Ty, Ty, MPType, ReduceOp,
+                          kps::IdentityFunctor<Ty, MPType>><<<8, 128, stream>>>(
+        config.output_data, y_data, reducer,
+        kps::IdentityFunctor<Ty, MPType>(config.grid.y), init, config.grid.y,
+        config.left_num, config.grid.y, dim);
+#else
     ReduceHigherDimKernel<
         Ty, Ty, MPType, ReduceOp,
-        kps::details::IdentityFunctor<Ty, MPType>><<<grid, block, 0, stream>>>(
+        kps::IdentityFunctor<Ty, MPType>><<<grid, block, 0, stream>>>(
         config.output_data, y_data, reducer,
-        kps::details::IdentityFunctor<Ty, MPType>(config.grid.y), init,
-        config.grid.y, config.left_num, config.grid.y);
+        kps::IdentityFunctor<Ty, MPType>(config.grid.y), init, config.grid.y,
+        config.left_num, config.grid.y, dim);
+#endif
   }
 }
@@ -699,7 +822,7 @@ void TensorReduceFunctorImpl(const framework::Tensor& x, framework::Tensor* y,
                              gpuStream_t stream) {
   auto x_dim = framework::vectorize<int>(x.dims());
   auto config = ReduceConfig<Ty>(origin_reduce_dims, x_dim);
-  config.Run();  // get the parameters of LaunchReduceKernel
+  config.Run();
   int numel = x.numel();
   // after config.run()
   // SetOutputData for ReduceHigherDim when should_reduce_again is true,
@@ -759,23 +882,49 @@ void TensorReduceFunctorImpl(const framework::Tensor& x, framework::Tensor* y,
   // else grid.z = 1, grid.y = ny / block_size, grid.x = nx /32
   if (config.reduce_type == ReduceType::kReduceHigherDim) {
     using TransformOp = typename ReduceOp<Tx, MPType>::Transformer;
+    kps::DimConfig dim =
+        kps::DimConfig(config.grid.x, config.grid.y, config.grid.z,
+                       config.block.x, config.blocking_size, 0);
+    dim.SetRem(config.left_num % config.block.x,
+               config.reduce_num % config.blocking_size, 0);
+
+#ifdef PADDLE_WITH_XPU2
+    ReduceHigherDimKernel<Tx, Ty, MPType, ReduceOp<Tx, MPType>,
+                          TransformOp><<<8, 128, stream>>>(
+        x_data, config.output_data, reducer, TransformOp(config.reduce_num),
+        reducer.initial(), config.reduce_num, config.left_num,
+        config.blocking_size, dim);
+#else
     ReduceHigherDimKernel<
         Tx, Ty, MPType, ReduceOp<Tx, MPType>,
         TransformOp><<<config.grid, config.block, 0, stream>>>(
         x_data, config.output_data, reducer, TransformOp(config.reduce_num),
         reducer.initial(), config.reduce_num, config.left_num,
-        config.blocking_size);
+        config.blocking_size, dim);
+#endif

     if (config.should_reduce_again) {
       dim3 block = dim3(config.block.x, 1, 1);
       dim3 grid = dim3(config.grid.x, 1, config.grid.z);
-      ReduceHigherDimKernel<
-          Ty, Ty, MPType, ReduceOp<Tx, MPType>,
-          kps::details::IdentityFunctor<Ty, MPType>><<<grid, block, 0, stream>>>(
-          config.output_data, y_data, reducer,
-          kps::details::IdentityFunctor<Ty, MPType>(config.grid.y),
-          reducer.initial(), config.grid.y, config.left_num, config.grid.y);
+      kps::DimConfig dim2 =
+          kps::DimConfig(grid.x, grid.y, grid.z, block.x, config.grid.y, 0);
+      dim2.SetRem(config.left_num % config.block.x, 0, 0);
+
+#ifdef PADDLE_WITH_XPU2
+      ReduceHigherDimKernel<
+          Ty, Ty, MPType, ReduceOp<Tx, MPType>,
+          kps::IdentityFunctor<Ty, MPType>><<<8, 128, stream>>>(
+          config.output_data, y_data, reducer,
+          kps::IdentityFunctor<Ty, MPType>(config.grid.y), reducer.initial(),
+          config.grid.y, config.left_num, config.grid.y, dim2);
+#else
+      ReduceHigherDimKernel<
+          Ty, Ty, MPType, ReduceOp<Tx, MPType>,
+          kps::IdentityFunctor<Ty, MPType>><<<grid, block, 0, stream>>>(
+          config.output_data, y_data, reducer,
+          kps::IdentityFunctor<Ty, MPType>(config.grid.y), reducer.initial(),
+          config.grid.y, config.left_num, config.grid.y, dim2);
+#endif
     }
     return;
   }