Commit 88b43b51 (unverified)
Authored May 25, 2021 by niuliling123; committed via GitHub on May 25, 2021
Add a new high performance framework for reduce ops (#32697)
Parent: 4920c474
Showing 2 changed files with 704 additions and 0 deletions (+704, -0)
paddle/fluid/operators/reduce_ops/reduce_functor_op.h   +58   -0
paddle/fluid/operators/reduce_ops/reduce_op.cuh         +646  -0
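Taken together, reduce_functor_op.h supplies the device-side binary functors (CustomMin, CustomMax, CustomSum, CustomMul) and reduce_op.cuh supplies the configuration and launch machinery behind the new TensorReduceFunc entry point. A minimal sketch of how a caller might invoke it for a sum reduction follows; input, output, reduce_dims and stream are placeholders rather than names taken from this diff:

// Hypothetical call site (illustrative only): sum `input` over `reduce_dims`
// into `output` on `stream`. Only TensorReduceFunc, CustomSum and
// detail::IdentityFunctor come from this commit.
paddle::operators::TensorReduceFunc<
    float, float, paddle::operators::CustomSum<float>,
    paddle::operators::detail::IdentityFunctor<float>>(
    input, &output, reduce_dims, /*init=*/0.0f,
    paddle::operators::CustomSum<float>(),
    paddle::operators::detail::IdentityFunctor<float>(), stream);

A mean reduction would presumably pair CustomSum with detail::DivideFunctor<float>(n) as the TransformOp, matching the "Post processing function for mean" comment in reduce_op.cuh.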
paddle/fluid/operators/reduce_ops/reduce_functor_op.h
0 → 100644
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <string>
#include <vector>
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/operators/amp/fp16_type_traits.h"
#include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/platform/hostdevice.h"
#include "paddle/fluid/platform/macros.h"
namespace paddle {
namespace operators {

template <typename T>
struct CustomMin {
  __device__ __forceinline__ T operator()(const T& a, const T& b) const {
    return (b < a) ? b : a;
  }
};

template <typename T>
struct CustomMax {
  __device__ __forceinline__ T operator()(const T& a, const T& b) const {
    return (b > a) ? b : a;
  }
};

template <typename T>
struct CustomSum {
  __device__ __forceinline__ T operator()(const T& a, const T& b) const {
    return b + a;
  }
};

template <typename T>
struct CustomMul {
  __device__ __forceinline__ T operator()(const T& a, const T& b) const {
    return b * a;
  }
};

}  // namespace operators
}  // namespace paddle
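Each functor exposes the binary form T operator()(const T& a, const T& b) const, which is exactly what both the hand-written kernels and the cub reductions in reduce_op.cuh below expect of their ReduceOp. For reference, the kReduceAll branch drives cub::DeviceReduce::Reduce with such a functor in essentially this two-call pattern (a sketch; d_in, d_out, d_temp, num_items and stream are assumed to be set up elsewhere):

// Illustrative only: the first call queries the scratch size, the second call
// performs the reduction with CustomSum as the reduction operator.
size_t temp_storage_bytes = 0;
cub::DeviceReduce::Reduce(nullptr, temp_storage_bytes, d_in, d_out, num_items,
                          paddle::operators::CustomSum<float>(), 0.0f, stream);
// ...allocate temp_storage_bytes bytes of device memory as d_temp, then:
cub::DeviceReduce::Reduce(d_temp, temp_storage_bytes, d_in, d_out, num_items,
                          paddle::operators::CustomSum<float>(), 0.0f, stream);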
paddle/fluid/operators/reduce_ops/reduce_op.cuh
0 → 100644
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <algorithm>
#include <cmath>
#include <numeric>
#include <set>
#include <vector>
#ifdef __NVCC__
#include "cub/cub.cuh"
#endif
#ifdef __HIPCC__
#include <hipcub/hipcub.hpp>
namespace cub = hipcub;
#endif
#include "paddle/fluid/framework/array.h"
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/framework/tensor_util.h"
namespace paddle {
namespace operators {

namespace detail {

// Post processing function for sum, max, min, prod, any
template <typename T>
struct IdentityFunctor {
  DEVICE explicit inline IdentityFunctor() {}

  DEVICE inline T operator()(const T& x) const { return x; }
};

// Post processing function for mean
template <typename T>
struct DivideFunctor {
  DEVICE explicit inline DivideFunctor(int n) : n_inv((T)(1.0 / n)) {}

  DEVICE inline T operator()(const T& x) const { return x * n_inv; }

 private:
  T n_inv;
};
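// GetLastPow2 rounds its argument down to the nearest power of two, e.g.
// GetLastPow2(6) == 4, GetLastPow2(9) == 8, GetLastPow2(512) == 512.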
static inline int GetLastPow2(int n) {
  n |= (n >> 1);
  n |= (n >> 2);
  n |= (n >> 4);
  n |= (n >> 8);
  n |= (n >> 16);
  return std::max(1, n - (n >> 1));
}
static inline std::vector<int> GetStrides(const std::vector<int>& dims,
                                          const std::vector<int>& idx) {
  int n = static_cast<int>(idx.size());
  if (n == 0) return std::vector<int>();
  std::vector<int> strides(n);
  strides.back() = 1;
  for (int i = n - 2; i >= 0; --i) {
    strides[i] = strides[i + 1] * dims[idx[i + 1]];
  }
  return strides;
}
#ifdef __HIPCC__
constexpr int kMaxBlockDim = 256;
#else
constexpr int kMaxBlockDim = 512;
#endif
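// GetDesiredBlockDim returns the largest power-of-two block size that does not
// exceed block_dim, capped at kMaxBlockDim, e.g. GetDesiredBlockDim(100) == 64.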
static inline int GetDesiredBlockDim(int block_dim) {
  return block_dim >= kMaxBlockDim
             ? kMaxBlockDim
             : (1 << static_cast<int>(std::log2(block_dim)));
}
static inline void CheckReduceRankIsValid(int reduce_rank, int rank) {
  if (rank % 2 == 0) {
    PADDLE_ENFORCE_EQ(reduce_rank, rank / 2,
                      platform::errors::InvalidArgument(
                          "ReduceOp: invalid reduce rank. When rank = %d, "
                          "reduce_rank must be %d, but got %d.",
                          rank, rank / 2, reduce_rank));
  } else {
    auto lower_rank = (rank - 1) / 2;
    auto upper_rank = (rank + 1) / 2;
    PADDLE_ENFORCE_EQ(
        reduce_rank == lower_rank || reduce_rank == upper_rank, true,
        platform::errors::InvalidArgument(
            "ReduceOp: invalid reduce rank. When rank = %d, reduce_rank "
            "must be %d or %d, but got %d.",
            rank, lower_rank, upper_rank, reduce_rank));
  }
}
template <typename T, size_t ElementCount, typename VectorLikeType>
static inline paddle::framework::Array<T, ElementCount> from(
    const VectorLikeType& vec) {
  PADDLE_ENFORCE_EQ(vec.size(), ElementCount,
                    platform::errors::InvalidArgument(
                        "Cub reduce Array: size not match. Received "
                        "vec.size() %d != ElementCount %d.",
                        vec.size(), ElementCount));
  size_t n = static_cast<size_t>(vec.size());
  paddle::framework::Array<T, ElementCount> ret;
  for (size_t i = 0; i < n; ++i) ret[i] = vec[i];
  return ret;
}

}  // namespace detail
enum ReduceType {
  kReduceAll = 0x00,
  kReduceLastDim = 0x01,
  kReduceHigherDim = 0x02,  // ReduceFirstDim or reduceSecondDim
  kReduceAny = 0x03,
};
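// Worked example: x_dim = [16, 8, 6], origin_reduce_dims = [1]
//   SetReduceDim  -> x_dim = [16, 8, 6], reduce_dim = [1], left_dim = [0, 2]
//   SetStrides    -> x_strides = [48, 6, 1], reduce_strides = [1],
//                    left_strides = [6, 1], reduce_num = 8, left_num = 96
//   SetReduceType -> kReduceHigherDim (rank = 3, reduce_rank = 1)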
// reduce config
template <typename Ty>
struct ReduceConfig {
  ReduceConfig(std::vector<int> origin_reduce_dims, std::vector<int> x_dim)
      : reduce_dims_origin(origin_reduce_dims), x_dim(x_dim) {}

  // get the parameters of reduceKernel
  void Run() {
    // step1: update the reduce_dim left_dim and x_dim
    SetReduceDim();
    // step2: get the strides of dim for reduceAny and reduceLastDim
    SetStrides();
    // step3: get the type of reduce
    SetReduceType();
    // step4: set the block and grid for launch kernel
    SetBlockDim();
  }

  // when should_reduce_again is true, we need malloc temp space for temp data
  void SetOutputData(Ty* y_data, const platform::Place& place,
                     framework::Tensor& tmp) {
    if (should_reduce_again) {
      output_data = tmp.mutable_data<Ty>(
          framework::make_ddim(
              {static_cast<int64_t>(left_num * grid.y * sizeof(Ty))}),
          place);
    } else {
      output_data = y_data;
    }
  }

 private:
  // set reduce_dim, left_dim and update x_dim
  // eg: x_dim = [2, 4, 6] origin_reduce_dims = [0, 1]
  //     --SetReduceDim--> x_dim = [8, 6], reduce_dim = [0], left_dim = [1]
  void SetReduceDim() {
    std::set<int> reduce_set;

    for (auto e : reduce_dims_origin) {
      auto pos = e >= 0 ? e : e + x_dim.size();
      reduce_set.insert(pos);
    }
    std::vector<int> reduce_dim_temp(reduce_set.begin(), reduce_set.end());
    std::sort(reduce_dim_temp.begin(), reduce_dim_temp.end());
    // get reduce_dim
    if (reduce_dim_temp.size() > 1) {
      int num = 0;  // for update axis
      reduce_dim.push_back(reduce_dim_temp[0]);
      for (int idx = 1; idx < reduce_dim_temp.size(); idx++) {
        // update x_dim
        if (reduce_dim_temp[idx] - reduce_dim_temp[idx - 1] == 1) {
          x_dim[reduce_dim_temp[idx - 1]] *= x_dim[reduce_dim_temp[idx]];
          x_dim.erase(x_dim.begin() + reduce_dim_temp[idx]);
          num++;
        } else {
          reduce_dim.push_back(reduce_dim_temp[idx] - num);
        }
      }
    } else {
      reduce_dim = reduce_dim_temp;
    }

    // update new_x_dim and new_reduce_dim
    std::vector<int> new_x_dim, new_reduce_dim_temp;
    int is_reduced = 0;
    for (auto e : reduce_dim) {
      is_reduced |= 1 << e;
    }

    for (int i = 0; i < x_dim.size(); i++) {
      if ((i == 0) || (((is_reduced >> i) ^ (is_reduced >> (i - 1))) & 1)) {
        new_x_dim.push_back(x_dim[i]);
        if ((is_reduced >> i) & 1)
          new_reduce_dim_temp.push_back(new_x_dim.size() - 1);
      } else {
        new_x_dim[new_x_dim.size() - 1] *= x_dim[i];
      }
    }

    x_dim = new_x_dim;
    reduce_dim = new_reduce_dim_temp;

    int x_rank = static_cast<int>(x_dim.size());
    std::set<int> left_set;

    for (int i = 0; i < x_rank; ++i) {
      left_set.insert(i);
    }

    for (auto e : reduce_dim) {
      left_set.erase(e);
    }

    left_dim.assign(left_set.begin(), left_set.end());
  }

  // set x_strides, reduce_strides, left_strides for reduceLastDim and reduceAny
  // eg: x_dim = [8, 6], reduce_dim = [0], left_dim = [1]
  //     --SetStrides--> x_strides = [6, 1], reduce_strides = [1],
  //                     left_strides = [1]
  void SetStrides() {
    std::vector<int> idx_dim;
    for (int i = 0; i < x_dim.size(); i++) {
      idx_dim.push_back(i);
    }

    x_strides = detail::GetStrides(x_dim, idx_dim);
    reduce_strides = detail::GetStrides(x_dim, reduce_dim);
    left_strides = detail::GetStrides(x_dim, left_dim);
    reduce_num = reduce_strides[0] * x_dim[reduce_dim[0]];

    left_num = 1;
    if (left_dim.size()) {
      left_num = left_strides[0] * x_dim[left_dim[0]];
    }
  }

  // get the reduceType
  // eg: x_dim = [8, 6] reduce_dim = [0] --> ReduceHigherDim --> reduceFirstDim
  //     x_dim = [8, 6] reduce_dim = [1] --> reduceLastDim
  //     x_dim = [8] reduce_dim = [0] --> reduceAll
  //     x_dim = [8, 6, 4, 2] reduce_dim = [0, 2] --> reduceAny
  void SetReduceType() {
    int rank = x_dim.size();
    int reduce_rank = reduce_dim.size();

    if (rank == reduce_rank) {
      reduce_type = static_cast<int>(ReduceType::kReduceAll);

    } else if (rank == 2 && reduce_rank == 1 && reduce_dim[0] == 1) {
      reduce_type = static_cast<int>(ReduceType::kReduceLastDim);

    } else if (reduce_rank == 1) {
      // ReduceFirstDim and reduceSecondDim
      reduce_type = static_cast<int>(ReduceType::kReduceHigherDim);

    } else {
      reduce_type = static_cast<int>(ReduceType::kReduceAny);
    }
  }

  // set block and grid for launch kernel
  // for ReduceHigherDim: if block is enough -> split reduce_num
  //                      else init block(32, 1) grid(block_num, 1)
  // for others: block(block_num, 1), grid(left_num, 1)
  void SetBlockDim() {
    // init
    int block_num = detail::GetDesiredBlockDim(reduce_num);
    should_reduce_again = false;

    dim3 block_dim(block_num, 1);
    dim3 grid_dim(left_num, 1);
    blocking_size = reduce_num;

    if (reduce_type == ReduceType::kReduceHigherDim) {
      int last_dim_num = x_dim.back();
      // update left_num
      int grid_z = left_num / last_dim_num;
      left_num = last_dim_num;

      block_dim.z = 1;
      grid_dim.z = grid_z;

      int device_id = platform::GetCurrentDeviceId();
      int max_mp = platform::GetCUDAMultiProcessors(device_id);
      int max_threads_per_mp =
          platform::GetCUDAMaxThreadsPerMultiProcessor(device_id);
      int max_threads = max_threads_per_mp * max_mp;

      // init
      int num_block = (max_threads / left_num);

      if (num_block > 1 && reduce_num >= 512) {
        blocking_size = detail::GetLastPow2(reduce_num / num_block);

        if (blocking_size <= 1) {
          blocking_size = detail::GetLastPow2(sqrt(reduce_num));
        } else if (blocking_size * 2 < reduce_num) {
          blocking_size *= 2;
        }

        should_reduce_again = true;

        block_dim.x = 32;
        block_dim.y = 1;
        grid_dim.x = (left_num + block_dim.x - 1) / block_dim.x;
        grid_dim.y = (reduce_num + blocking_size - 1) / blocking_size;

      } else {
        block_dim.x = 32;
        block_dim.y = 1;
        blocking_size = reduce_num;
        grid_dim.x = (left_num + block_dim.x - 1) / block_dim.x;
        grid_dim.y = 1;
      }
    }

    block = block_dim;
    grid = grid_dim;
  }

 public:
  std::vector<int> reduce_dims_origin;
  std::vector<int> reduce_dim;
  std::vector<int> x_dim;
  std::vector<int> left_dim;
  std::vector<int> x_strides;
  std::vector<int> left_strides;
  std::vector<int> reduce_strides;

  int reduce_type;
  int reduce_num;
  int left_num;
  int blocking_size;
  bool should_reduce_again;

  Ty* output_data;

  dim3 block;
  dim3 grid;
};
template <typename Tx, typename Ty, typename ReduceOp, typename TransformOp,
          int BlockDim>
__device__ __forceinline__ void ReduceLastDim(const Tx* x, Ty* y,
                                              ReduceOp reducer,
                                              TransformOp transformer, Ty init,
                                              int reduce_num) {
  __shared__ typename cub::BlockReduce<Ty, BlockDim>::TempStorage temp_storage;
  int idx_x = blockIdx.x * reduce_num;
  int idx_y = threadIdx.x;
  Ty reduce_var = init;
  for (int idx_y = threadIdx.x; idx_y < reduce_num; idx_y += BlockDim)
    reduce_var = reducer(reduce_var, static_cast<Ty>(x[idx_x + idx_y]));
  __syncthreads();

  reduce_var =
      cub::BlockReduce<Ty, BlockDim>(temp_storage).Reduce(reduce_var, reducer);

  if (threadIdx.x == 0) {
    y[blockIdx.x] = transformer(reduce_var);
  }
}
template <typename Tx, typename Ty, typename ReduceOp, typename TransformOp>
__device__ __forceinline__ void ReduceHigherDim(const Tx* x, Ty* y,
                                                ReduceOp reducer,
                                                TransformOp transformer,
                                                Ty init, int reduce_num,
                                                int left_num, int block_size) {
  int idx = blockIdx.x * blockDim.x + threadIdx.x;
  int idy = blockIdx.y * block_size;

  Ty temp = init;
  Ty reduce_var = init;

  if (idx < left_num) {
    int loop = reduce_num - idy;
    loop = loop > block_size ? block_size : loop;
    for (int iy = 0; iy < loop; iy++) {
      int id = (idy + iy) * left_num + idx + blockIdx.z * reduce_num * left_num;
      reduce_var = reducer(reduce_var, static_cast<Ty>(x[id]));
    }

    y[idx + blockIdx.y * left_num + blockIdx.z * gridDim.y * left_num] =
        static_cast<Ty>(transformer(reduce_var));
  }
}
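// ReduceAny handles arbitrary reduce patterns: each block produces one output
// element; a thread decomposes its linear reduce index into per-dimension
// sub-indices using reduce_strides, combines them with the left-dim
// sub-indices through x_strides to address x, and the per-thread partial
// results are combined with a cub::BlockReduce.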
template <typename Tx, typename Ty, typename ReduceOp, typename TransformOp,
          int BlockDim, int Rank, int ReduceRank>
__device__ __forceinline__ void ReduceAny(
    const Tx* x, Ty* y, ReduceOp reducer, TransformOp transformer, Ty init,
    int reduce_num, paddle::framework::Array<int, Rank> x_strides,
    paddle::framework::Array<int, ReduceRank> reduce_dim,
    paddle::framework::Array<int, ReduceRank> reduce_strides,
    paddle::framework::Array<int, Rank - ReduceRank> left_dim,
    paddle::framework::Array<int, Rank - ReduceRank> left_strides) {
  __shared__ typename cub::BlockReduce<Ty, BlockDim>::TempStorage temp_storage;

  int sub_index[Rank];
  int left_idx = blockIdx.x;
  for (int i = 0; i < Rank - ReduceRank; ++i) {
    sub_index[left_dim[i]] = left_idx / left_strides[i];
    left_idx %= left_strides[i];
  }

  int reduce_idx = threadIdx.x;
  for (int j = 0; j < ReduceRank; ++j) {
    sub_index[reduce_dim[j]] = reduce_idx / reduce_strides[j];
    reduce_idx %= reduce_strides[j];
  }

  int idx_x = 0;
  for (int k = 0; k < Rank; ++k) idx_x += (sub_index[k] * x_strides[k]);
  Ty reduce_var = static_cast<Ty>(x[idx_x]);

  for (int i = threadIdx.x + BlockDim; i < reduce_num; i += BlockDim) {
    int reduce_idx = i;
    for (int j = 0; j < ReduceRank; ++j) {
      sub_index[reduce_dim[j]] = reduce_idx / reduce_strides[j];
      reduce_idx %= reduce_strides[j];
    }

    int idx_x = 0;
    for (int k = 0; k < Rank; ++k) idx_x += (sub_index[k] * x_strides[k]);
    reduce_var =
        static_cast<Ty>(reducer(reduce_var, static_cast<Ty>(x[idx_x])));
  }
  __syncthreads();

  reduce_var =
      cub::BlockReduce<Ty, BlockDim>(temp_storage).Reduce(reduce_var, reducer);

  if (threadIdx.x == 0) {
    y[blockIdx.x] = transformer(reduce_var);
  }
}
template <typename Tx, typename Ty, typename ReduceOp, typename TransformOp,
          int BlockDim, int Rank, int ReduceRank, int ReduceType>
__device__ __forceinline__ void ReduceModule(
    const Tx* x, Ty* y, ReduceOp reducer, TransformOp transformer, Ty init,
    int reduce_num, int left_num, int blocking_size,
    paddle::framework::Array<int, Rank> x_strides,
    paddle::framework::Array<int, ReduceRank> reduce_dim,
    paddle::framework::Array<int, ReduceRank> reduce_strides,
    paddle::framework::Array<int, Rank - ReduceRank> left_dim,
    paddle::framework::Array<int, Rank - ReduceRank> left_strides) {
  if (ReduceType == ReduceType::kReduceLastDim) {
    ReduceLastDim<Tx, Ty, ReduceOp, TransformOp, BlockDim>(
        x, y, reducer, transformer, init, reduce_num);

  } else if (ReduceType == ReduceType::kReduceHigherDim) {
    ReduceHigherDim<Tx, Ty, ReduceOp, TransformOp>(
        x, y, reducer, transformer, init, reduce_num, left_num, blocking_size);

  } else {
    ReduceAny<Tx, Ty, ReduceOp, TransformOp, BlockDim, Rank, ReduceRank>(
        x, y, reducer, transformer, init, reduce_num, x_strides, reduce_dim,
        reduce_strides, left_dim, left_strides);
  }
}
template <typename Tx, typename Ty, typename ReduceOp, typename TransformOp,
          int BlockDim, int Rank, int ReduceRank, int ReduceType>
__global__ void ReduceKernelFunction(
    const Tx* x, Ty* y, ReduceOp reducer, TransformOp transformer, Ty init,
    int reduce_num, int left_num, int block_size,
    paddle::framework::Array<int, Rank> x_strides,
    paddle::framework::Array<int, ReduceRank> reduce_dim,
    paddle::framework::Array<int, ReduceRank> reduce_strides,
    paddle::framework::Array<int, Rank - ReduceRank> left_dim,
    paddle::framework::Array<int, Rank - ReduceRank> left_strides) {
  ReduceModule<Tx, Ty, ReduceOp, TransformOp, BlockDim, Rank, ReduceRank,
               ReduceType>(x, y, reducer, transformer, init, reduce_num,
                           left_num, block_size, x_strides, reduce_dim,
                           reduce_strides, left_dim, left_strides);
}
template <typename Tx, typename Ty, int BlockDim, typename ReduceOp,
          typename TransformOp, int kRank, int kReduceRank>
static void launchKernel(const Tx* x_data, Ty* y_data,
                         const platform::Place& place, const ReduceOp& reducer,
                         const TransformOp& transformer, const Ty& init,
                         gpuStream_t stream, ReduceConfig<Ty> config) {
#define CUB_REDUCE_TYPE_CASE(type) \
case type: { \
constexpr auto kReduceType = type; \
ReduceKernelFunction< \
Tx, Ty, ReduceOp, TransformOp, BlockDim, kRank, kReduceRank, \
kReduceType><<<config.grid, config.block, 0, stream>>>( \
x_data, config.output_data, reducer, transformer, init, \
config.reduce_num, config.left_num, config.blocking_size, \
detail::from<int, kRank>(config.x_strides), \
detail::from<int, kReduceRank>(config.reduce_dim), \
detail::from<int, kReduceRank>(config.reduce_strides), \
detail::from<int, kRank - kReduceRank>(config.left_dim), \
detail::from<int, kRank - kReduceRank>(config.left_strides)); \
} break
  switch (config.reduce_type) {
    CUB_REDUCE_TYPE_CASE(1);  // reduceLastDim
    CUB_REDUCE_TYPE_CASE(2);  // ReduceHigherDim
    CUB_REDUCE_TYPE_CASE(3);  // reduceAny
  }

  if (config.should_reduce_again) {
    dim3 block(config.block.x, 1, 1);
    dim3 grid(config.grid.x, 1, config.grid.z);

    ReduceKernelFunction<Ty, Ty, ReduceOp, detail::IdentityFunctor<Ty>, 128,
                         kRank, kReduceRank,
                         ReduceType::kReduceHigherDim><<<grid, block, 0,
                                                         stream>>>(
        config.output_data, y_data, reducer, detail::IdentityFunctor<Ty>(),
        init, config.grid.y, config.left_num, config.grid.y,
        detail::from<int, kRank>(config.x_strides),
        detail::from<int, kReduceRank>(config.reduce_dim),
        detail::from<int, kReduceRank>(config.reduce_strides),
        detail::from<int, kRank - kReduceRank>(config.left_dim),
        detail::from<int, kRank - kReduceRank>(config.left_strides));
  }
}
template <typename Tx, typename Ty, int BlockDim, typename ReduceOp,
          typename TransformOp>
static void launchReduceKernel(const Tx* x_data, Ty* y_data,
                               const platform::Place& place,
                               const ReduceOp& reducer,
                               const TransformOp& transformer, const Ty& init,
                               gpuStream_t stream, ReduceConfig<Ty> config) {
  int reduce_rank = config.reduce_strides.size();
  int rank = config.x_strides.size();
#define CUB_RANK_CASE(i, ...) \
case i: { \
constexpr auto kRank = i; \
switch (reduce_rank) { __VA_ARGS__; } \
} break
#define CUB_REDUCE_RANK_CASE(i, ...) \
case i: { \
constexpr auto kReduceRank = i; \
launchKernel<Tx, Ty, BlockDim, ReduceOp, TransformOp, kRank, kReduceRank>( \
x_data, y_data, place, reducer, transformer, init, stream, config); \
} break
  // launch CUB::Reduce
  if (config.reduce_type == static_cast<int>(ReduceType::kReduceAll)) {
    cub::TransformInputIterator<Ty, TransformOp, const Tx*> trans_x(
        x_data, transformer);
    size_t temp_storage_bytes = 0;
    cub::DeviceReduce::Reduce(nullptr, temp_storage_bytes, trans_x, y_data,
                              config.reduce_num, reducer, init, stream);
    framework::Tensor tmp;
    auto* temp_storage = tmp.mutable_data<uint8_t>(
        framework::make_ddim({static_cast<int64_t>(temp_storage_bytes)}),
        place);
    cub::DeviceReduce::Reduce(temp_storage, temp_storage_bytes, trans_x, y_data,
                              config.reduce_num, reducer, init, stream);
    return;
  }

  detail::CheckReduceRankIsValid(reduce_rank, rank);
  switch (rank) {
    CUB_RANK_CASE(2, CUB_REDUCE_RANK_CASE(1););

    CUB_RANK_CASE(3, CUB_REDUCE_RANK_CASE(1); CUB_REDUCE_RANK_CASE(2););

    CUB_RANK_CASE(4, CUB_REDUCE_RANK_CASE(2););

    CUB_RANK_CASE(5, CUB_REDUCE_RANK_CASE(2); CUB_REDUCE_RANK_CASE(3););

    CUB_RANK_CASE(6, CUB_REDUCE_RANK_CASE(3););

    CUB_RANK_CASE(7, CUB_REDUCE_RANK_CASE(3); CUB_REDUCE_RANK_CASE(4););

    CUB_RANK_CASE(8, CUB_REDUCE_RANK_CASE(4););

    CUB_RANK_CASE(9, CUB_REDUCE_RANK_CASE(4); CUB_REDUCE_RANK_CASE(5););
  }

#undef CUB_REDUCE_RANK_CASE
#undef CUB_RANK_CASE
}
template <typename Tx, typename Ty, typename ReduceOp, typename TransformOp>
void TensorReduceFunc(const framework::Tensor& x, framework::Tensor* y,
                      std::vector<int> origin_reduce_dims, const Ty& init,
                      const ReduceOp& reducer, const TransformOp& transformer,
                      gpuStream_t stream) {
  auto x_dim = framework::vectorize<int>(x.dims());
  auto config = ReduceConfig<Ty>(origin_reduce_dims, x_dim);
  config.Run();

  auto x_data = x.data<Tx>();
  auto y_data = y->mutable_data<Ty>(x.place());

  framework::Tensor tmp;
  // SetOutputData for ReduceHigherDim: when should_reduce_again is true, the
  // intermediate result is written to a temporary buffer (output_data);
  // otherwise output_data points directly at y_data.
  config.SetOutputData(y_data, x.place(), tmp);

  if (config.reduce_num == 1) {
    auto out_dims = y->dims();
    framework::TensorCopy(x, y->place(), y);
    y->Resize(out_dims);
    return;
  }
#define CUB_BLOCK_DIM_CASE(block_dim) \
case block_dim: { \
constexpr auto kBlockDim = block_dim; \
launchReduceKernel<Tx, Ty, block_dim, ReduceOp, TransformOp>( \
x_data, y_data, x.place(), reducer, transformer, init, stream, \
config); \
} break
  switch (detail::GetDesiredBlockDim(config.reduce_num)) {
    CUB_BLOCK_DIM_CASE(512);
    CUB_BLOCK_DIM_CASE(256);
    CUB_BLOCK_DIM_CASE(128);
    CUB_BLOCK_DIM_CASE(64);
    CUB_BLOCK_DIM_CASE(32);
    CUB_BLOCK_DIM_CASE(16);
    CUB_BLOCK_DIM_CASE(8);
    CUB_BLOCK_DIM_CASE(4);
    CUB_BLOCK_DIM_CASE(2);
  }
#undef CUB_BLOCK_DIM_CASE
}

}  // namespace operators
}  // namespace paddle