Commit 480b284c
Authored June 22, 2021 by niuliling123; committed by GitHub on June 22, 2021.
modified reduce_max, reduce_min, reduce_prod to higher_performance implementation. (#32974)
Parent: 20eafd79
Changes: 5 changed files with 349 additions and 177 deletions (+349, -177)
paddle/fluid/operators/reduce_ops/reduce_functor_op.h    +68   -16
paddle/fluid/operators/reduce_ops/reduce_max_op.cu       +9    -11
paddle/fluid/operators/reduce_ops/reduce_min_op.cu       +9    -11
paddle/fluid/operators/reduce_ops/reduce_op.cu.h         +251  -123
paddle/fluid/operators/reduce_ops/reduce_prod_op.cu      +12   -16
paddle/fluid/operators/reduce_ops/reduce_functor_op.h
@@ -13,46 +13,98 @@ See the License for the specific language governing permissions and
 limitations under the License. */

 #pragma once
-#include <string>
-#include <vector>
-
-#include "paddle/fluid/framework/eigen.h"
-#include "paddle/fluid/framework/tensor.h"
-#include "paddle/fluid/operators/amp/fp16_type_traits.h"
-#include "paddle/fluid/platform/device_context.h"
+#include <cmath>
+#include <limits>
+#include "paddle/fluid/operators/reduce_ops/reduce_op.cu.h"
+#include "paddle/fluid/platform/hostdevice.h"
+#include "paddle/fluid/platform/macros.h"

 #ifdef __HIPCC__
 #include <hip/hip_runtime.h>
 #endif

 namespace paddle {
 namespace operators {

-template <typename T>
+template <typename Tx, typename Ty = Tx>
 struct CustomMin {
-  __device__ __forceinline__ T operator()(const T& a, const T& b) const {
+  using Transformer = detail::IdentityFunctor<Tx>;
+
+  inline Ty initial() {
+    return static_cast<Ty>(std::numeric_limits<Ty>::max());
+  }
+
+  __device__ __forceinline__ Ty operator()(const Ty& a, const Ty& b) const {
     return (b < a) ? b : a;
   }
 };

-template <typename T>
+template <typename Tx, typename Ty = Tx>
 struct CustomMax {
-  __device__ __forceinline__ T operator()(const T& a, const T& b) const {
+  using Transformer = detail::IdentityFunctor<Tx>;
+
+  inline Ty initial() {
+    return static_cast<Ty>(std::numeric_limits<Ty>::lowest());
+  }
+
+  __device__ __forceinline__ Ty operator()(const Ty& a, const Ty& b) const {
     return (b > a) ? b : a;
   }
 };

-template <typename T>
+// for cub::Reduce
+template <typename Tx, typename Ty = Tx>
 struct CustomSum {
-  __device__ __forceinline__ T operator()(const T& a, const T& b) const {
+  using Transformer = detail::IdentityFunctor<Tx, Ty>;
+
+  inline Ty initial() { return static_cast<Ty>(0.0f); }
+
+  __device__ __forceinline__ Ty operator()(const Ty& a, const Ty& b) const {
     return b + a;
   }
 };

-template <typename T>
+template <typename Tx, typename Ty = Tx>
+struct CustomMean {
+  using Transformer = detail::DivideFunctor<Tx>;
+
+  inline Ty initial() { return static_cast<Ty>(0.0f); }
+
+  __device__ __forceinline__ Ty operator()(const Ty& a, const Ty& b) const {
+    return b + a;
+  }
+};
+
+template <typename Tx, typename Ty = Tx>
 struct CustomMul {
-  __device__ __forceinline__ T operator()(const T& a, const T& b) const {
+  using Transformer = detail::IdentityFunctor<Tx>;
+
+  inline Ty initial() { return static_cast<Ty>(1.0f); }
+
+  __device__ __forceinline__ Ty operator()(const Ty& a, const Ty& b) const {
     return b * a;
   }
 };

+template <typename Tx, typename Ty = Tx>
+struct CustomLogicalOr {
+  using Transformer = detail::IdentityFunctor<Tx>;
+
+  inline Ty initial() { return static_cast<Ty>(false); }
+
+  __device__ __forceinline__ Ty operator()(const Ty& a, const Ty& b) const {
+    return b || a;
+  }
+};
+
+template <typename Tx, typename Ty = Tx>
+struct CustomLogicalAnd {
+  using Transformer = detail::IdentityFunctor<Tx>;
+
+  inline Ty initial() { return static_cast<Ty>(true); }
+
+  __device__ __forceinline__ Ty operator()(const Ty& a, const Ty& b) const {
+    return b && a;
+  }
+};
+
 }  // namespace operators
 }  // namespace paddle
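All of the new functors follow one contract, which the rest of the diff relies on: initial() supplies the identity element of the operation, operator() folds two partial results, and Transformer is applied to each input element before folding (IdentityFunctor for max/min/prod, DivideFunctor for mean). A minimal host-side sketch of that contract, using a serial fold in place of the CUB block/device reductions the GPU kernels use; names here are illustrative, not Paddle's:

// Host-side sketch (not part of the commit) of the reduction-functor contract.
#include <cstdio>
#include <limits>

struct SketchMax {  // mirrors CustomMax<float>
  float initial() { return std::numeric_limits<float>::lowest(); }
  float operator()(float a, float b) const { return b > a ? b : a; }
};

template <typename Reducer, typename Transform>
float SerialReduce(const float* x, int n, Reducer r, Transform t) {
  float acc = r.initial();                            // start from the identity
  for (int i = 0; i < n; ++i) acc = r(acc, t(x[i]));  // transform, then fold
  return acc;
}

int main() {
  float data[4] = {1.f, 7.f, -3.f, 2.f};
  float m = SerialReduce(data, 4, SketchMax{}, [](float v) { return v; });
  std::printf("%f\n", m);  // prints 7.000000
  return 0;
}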
paddle/fluid/operators/reduce_ops/reduce_max_op.cu
@@ -11,15 +11,13 @@
 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 // See the License for the specific language governing permissions and
 // limitations under the License.

+#include "paddle/fluid/operators/reduce_ops/reduce_functor_op.h"
+#include "paddle/fluid/operators/reduce_ops/reduce_op.cu.h"
+#include "paddle/fluid/operators/reduce_ops/reduce_op.h"
-#include "paddle/fluid/operators/reduce_ops/reduce_min_max_op.h"

-REGISTER_OP_CUDA_KERNEL(reduce_max,
-                        ops::ReduceKernel<paddle::platform::CUDADeviceContext,
-                                          float, ops::MaxFunctor>,
-                        ops::ReduceKernel<paddle::platform::CUDADeviceContext,
-                                          double, ops::MaxFunctor>,
-                        ops::ReduceKernel<paddle::platform::CUDADeviceContext,
-                                          int, ops::MaxFunctor>,
-                        ops::ReduceKernel<paddle::platform::CUDADeviceContext,
-                                          int64_t, ops::MaxFunctor>);
+// reduce_max
+REGISTER_OP_CUDA_KERNEL(
+    reduce_max, ops::ReduceCudaKernel<float, paddle::operators::CustomMax>,
+    ops::ReduceCudaKernel<double, paddle::operators::CustomMax>,
+    ops::ReduceCudaKernel<int, paddle::operators::CustomMax>,
+    ops::ReduceCudaKernel<int64_t, paddle::operators::CustomMax>);
paddle/fluid/operators/reduce_ops/reduce_min_op.cu
@@ -11,15 +11,13 @@
 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 // See the License for the specific language governing permissions and
 // limitations under the License.

+#include "paddle/fluid/operators/reduce_ops/reduce_functor_op.h"
+#include "paddle/fluid/operators/reduce_ops/reduce_op.cu.h"
+#include "paddle/fluid/operators/reduce_ops/reduce_op.h"
-#include "paddle/fluid/operators/reduce_ops/reduce_min_max_op.h"

-REGISTER_OP_CUDA_KERNEL(reduce_min,
-                        ops::ReduceKernel<paddle::platform::CUDADeviceContext,
-                                          float, ops::MinFunctor>,
-                        ops::ReduceKernel<paddle::platform::CUDADeviceContext,
-                                          double, ops::MinFunctor>,
-                        ops::ReduceKernel<paddle::platform::CUDADeviceContext,
-                                          int, ops::MinFunctor>,
-                        ops::ReduceKernel<paddle::platform::CUDADeviceContext,
-                                          int64_t, ops::MinFunctor>);
+// reduce_min
+REGISTER_OP_CUDA_KERNEL(
+    reduce_min, ops::ReduceCudaKernel<float, paddle::operators::CustomMin>,
+    ops::ReduceCudaKernel<double, paddle::operators::CustomMin>,
+    ops::ReduceCudaKernel<int, paddle::operators::CustomMin>,
+    ops::ReduceCudaKernel<int64_t, paddle::operators::CustomMin>);
paddle/fluid/operators/reduce_ops/reduce_op.cuh → paddle/fluid/operators/reduce_ops/reduce_op.cu.h (renamed)
@@ -30,32 +30,59 @@ namespace cub = hipcub;
 #endif

 #include "paddle/fluid/framework/array.h"
 #include "paddle/fluid/framework/op_registry.h"
 #include "paddle/fluid/framework/tensor.h"
 #include "paddle/fluid/framework/tensor_util.h"

+// Reduce split or not, Whether to use ReduceHigherDim
+#define REDUCE_SPLIT_BOUNDARY 512
+
 namespace paddle {
 namespace operators {
 namespace detail {

 // Post processing function for sum, max, min, prod, any
-template <typename T>
+template <typename Tx, typename Ty = Tx>
 struct IdentityFunctor {
-  DEVICE explicit inline IdentityFunctor() {}
+  HOSTDEVICE explicit inline IdentityFunctor(int n) {}

-  DEVICE inline T operator()(const T& x) const { return x; }
+  HOSTDEVICE inline Ty operator()(const Tx& x) const {
+    return static_cast<Ty>(x);
+  }
 };

 // Post processing function for mean
 template <typename T>
 struct DivideFunctor {
-  DEVICE explicit inline DivideFunctor(int n) : n_inv((T)(1.0 / n)) {}
+  HOSTDEVICE explicit inline DivideFunctor(int n) : n_inv((T)(1.0 / n)) {}

-  DEVICE inline T operator()(const T& x) const { return x * n_inv; }
+  HOSTDEVICE inline T operator()(const T& x) const { return x * n_inv; }

  private:
   T n_inv;
 };

+static inline std::vector<int> GetReduceDim(const std::vector<int>& dims,
+                                            int dim_size, bool reduce_all) {
+  std::vector<int> reduce_dims;
+  if (reduce_all) {
+    reduce_dims.resize(dim_size);
+    for (int i = 0; i < reduce_dims.size(); ++i) {
+      reduce_dims[i] = i;
+    }
+  } else {
+    for (auto e : dims) {
+      PADDLE_ENFORCE_LT(e, dim_size,
+                        paddle::platform::errors::InvalidArgument(
+                            "ReduceOp: invalid axis, when x_dims is %d, "
+                            "axis[i] should less than x_dims, but got %d.",
+                            dim_size, e));
+      reduce_dims.push_back(e >= 0 ? e : e + dim_size);
+    }
+  }
+  return reduce_dims;
+}
+
 static inline int GetLastPow2(int n) {
   n |= (n >> 1);
   n |= (n >> 2);
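The newly added detail::GetReduceDim normalizes the user-facing axis list before any kernel work: negative axes are wrapped by the tensor rank, and reduce_all expands to every axis. A standalone sketch of that behavior (the PADDLE_ENFORCE_LT validation is omitted; the sketch names are illustrative, not Paddle's):

// Standalone sketch (not the Paddle source) of GetReduceDim's axis normalization.
#include <cassert>
#include <vector>

std::vector<int> GetReduceDimSketch(const std::vector<int>& dims, int dim_size,
                                    bool reduce_all) {
  std::vector<int> out;
  if (reduce_all) {
    for (int i = 0; i < dim_size; ++i) out.push_back(i);  // reduce every axis
  } else {
    for (int e : dims) out.push_back(e >= 0 ? e : e + dim_size);  // wrap negatives
  }
  return out;
}

int main() {
  assert((GetReduceDimSketch({-1}, 3, false) == std::vector<int>{2}));
  assert((GetReduceDimSketch({}, 3, true) == std::vector<int>{0, 1, 2}));
  return 0;
}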
@@ -65,8 +92,9 @@ static inline int GetLastPow2(int n) {
   return std::max(1, n - (n >> 1));
 }

-static inline std::vector<int> GetStrides(const std::vector<int>& dims,
-                                          const std::vector<int>& idx) {
+// get strides of x_dim, reduce_dim and left_dim for reduceLastDim and reduceAny
+static inline std::vector<int> GetDimStrides(const std::vector<int>& dims,
+                                             const std::vector<int>& idx) {
   int n = static_cast<int>(idx.size());
   if (n == 0) return std::vector<int>();

   std::vector<int> strides(n);
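GetLastPow2, whose tail appears above and which the renamed GetBlockDim below relies on, uses the standard bit-smearing trick: the shifts copy the highest set bit into every lower position, so n - (n >> 1) leaves only that top bit, i.e. the largest power of two not exceeding n. A self-contained sketch, assuming the elided middle of the function continues the smear through >> 16 as in the usual form of the trick:

// Sketch (not the Paddle source) of the GetLastPow2 bit trick.
#include <algorithm>
#include <cassert>

int GetLastPow2Sketch(int n) {
  n |= (n >> 1);   // smear the highest set bit downward ...
  n |= (n >> 2);
  n |= (n >> 4);
  n |= (n >> 8);
  n |= (n >> 16);  // ... until every bit below it is set
  return std::max(1, n - (n >> 1));  // keep only the highest bit
}

int main() {
  assert(GetLastPow2Sketch(1) == 1);
  assert(GetLastPow2Sketch(96) == 64);
  assert(GetLastPow2Sketch(1024) == 1024);
  return 0;
}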
@@ -78,18 +106,18 @@ static inline std::vector<int> GetStrides(const std::vector<int>& dims,
 }

 #ifdef __HIPCC__
-constexpr int kMaxBlockDim = 256;
+constexpr int kMaxThread = 256;
 #else
-constexpr int kMaxBlockDim = 512;
+constexpr int kMaxThread = 128;
 #endif

-static inline int GetDesiredBlockDim(int block_dim) {
-  return block_dim >= kMaxBlockDim
-             ? kMaxBlockDim
-             : (1 << static_cast<int>(std::log2(block_dim)));
+// get blockDim for reduceLastDim and reduceAny
+static inline int GetBlockDim(int block_dim) {
+  return block_dim >= kMaxThread ? kMaxThread : GetLastPow2(block_dim);
 }

-static inline void CheckReduceRankIsValid(int reduce_rank, int rank) {
+// check reduce rand is valid
+static inline void CheckReduceRank(int reduce_rank, int rank) {
   if (rank % 2 == 0) {
     PADDLE_ENFORCE_EQ(reduce_rank, rank / 2,
                       platform::errors::InvalidArgument(
@@ -108,8 +136,9 @@ static inline void CheckReduceRankIsValid(int reduce_rank, int rank) {
   }
 }

+// convert dims from vector to array
 template <typename T, size_t ElementCount, typename VectorLikeType>
-static inline paddle::framework::Array<T, ElementCount> from(
+static inline paddle::framework::Array<T, ElementCount> VectorToArray(
     const VectorLikeType& vec) {
   PADDLE_ENFORCE_EQ(vec.size(), ElementCount,
                     platform::errors::InvalidArgument(
@@ -118,17 +147,21 @@ static inline paddle::framework::Array<T, ElementCount> from(
                         vec.size(), ElementCount));
   size_t n = static_cast<size_t>(vec.size());
   paddle::framework::Array<T, ElementCount> ret;
-  for (size_t i = 0; i < n; ++i) ret[i] = vec[i];
+  for (size_t i = 0; i < n; ++i) {
+    ret[i] = vec[i];
+  }
   return ret;
 }

 }  // namespace detail

 using Tensor = framework::Tensor;

 enum ReduceType {
-  kReduceAll = 0x00,
-  kReduceLastDim = 0x01,
+  kReduceAll = 0x00,        // when reduce_rank == x_rank
+  kReduceLastDim = 0x01,    // when reduce_dim[0] == x_dim.size() - 1;
   kReduceHigherDim = 0x02,  // ReduceFirstDim or reduceSecondDim
-  kReduceAny = 0x03,
+  kReduceAny = 0x03,        // when reduce_dim.size() > 1
 };

 // reduce config
@@ -141,21 +174,24 @@ struct ReduceConfig {
   void Run() {
     // step1: update the reduce_dim left_dim and x_dim
     SetReduceDim();
     // step2: get the strides of dim for reduceAny and reduceLastDim
     SetStrides();
     // step3: get the type of reduce
     SetReduceType();
     // step4: set the block and grid for launch kernel
     SetBlockDim();
   }

+  // when should_reduce_again is true, we need malloc temp space for temp data
   void SetOutputData(Ty* y_data, const platform::Place& place,
-                     framework::Tensor& tmp) {
+                     framework::Tensor* tmp) {
     if (should_reduce_again) {
-      output_data = tmp.mutable_data<Ty>(
-          framework::make_ddim(
-              {static_cast<int64_t>(left_num * grid.y * sizeof(Ty))}),
+      output_data = tmp->mutable_data<Ty>(
+          framework::make_ddim(
+              {static_cast<int64_t>(left_num * grid.z * grid.y * sizeof(Ty))}),
           place);
     } else {
       output_data = y_data;
@@ -168,50 +204,70 @@ struct ReduceConfig {
   // --SetReduceDim--> x_dim = [8,6], reduce_dim = [0], left_dim = [1]
   void SetReduceDim() {
     std::set<int> reduce_set;
     for (auto e : reduce_dims_origin) {
       auto pos = e >= 0 ? e : e + x_dim.size();
       reduce_set.insert(pos);
     }

     std::vector<int> reduce_dim_temp(reduce_set.begin(), reduce_set.end());
     std::sort(reduce_dim_temp.begin(), reduce_dim_temp.end());

-    // get reduce_dim
+    // update reduce_dim and x_dim
+    std::vector<int> x_new_dim;
+
+    reduce_dim.push_back(reduce_dim_temp[0]);
+    x_new_dim.push_back(x_dim[0]);
+
+    int idx_reduce = 1;
+    int num = 0;
+
     if (reduce_dim_temp.size() > 1) {
-      int num = 0;  // for update axis
-      reduce_dim.push_back(reduce_dim_temp[0]);
-      for (int idx = 1; idx < reduce_dim_temp.size(); idx++) {
-        // update x_dim
-        if (reduce_dim_temp[idx] - reduce_dim_temp[idx - 1] == 1) {
-          x_dim[reduce_dim_temp[idx - 1]] *= x_dim[reduce_dim_temp[idx]];
-          x_dim.erase(x_dim.begin() + reduce_dim_temp[idx]);
-          num++;
+      for (int i = 1; i < x_dim.size(); i++) {
+        if ((idx_reduce < reduce_dim_temp.size()) &&
+            (i == reduce_dim_temp[idx_reduce])) {
+          int result =
+              reduce_dim_temp[idx_reduce] - reduce_dim[reduce_dim.size() - 1];
+          bool is_equal = ((result - num) == 1);
+          if (is_equal) {
+            x_new_dim[x_new_dim.size() - 1] *= x_dim[i];
+            num++;
+          } else {
+            reduce_dim.push_back(reduce_dim_temp[idx_reduce] - num);
+            x_new_dim.push_back(x_dim[i]);
+          }
+          idx_reduce++;
         } else {
-          reduce_dim.push_back(reduce_dim_temp[idx] - num);
+          x_new_dim.push_back(x_dim[i]);
         }
       }
     } else {
-      reduce_dim = reduce_dim_temp;
+      x_new_dim = x_dim;
     }

-    // update new_x_dim and new_reduce_dim
-    std::vector<int> new_x_dim, new_reduce_dim_temp;
+    // update x_dim
+    x_dim = x_new_dim;
+    std::vector<int>().swap(x_new_dim);
+
+    std::vector<int> reduce_dim_new;
     int is_reduced = 0;
     for (auto e : reduce_dim) {
       is_reduced |= 1 << e;
     }

+    std::vector<int>().swap(reduce_dim);
+
     for (int i = 0; i < x_dim.size(); i++) {
       if ((i == 0) || (((is_reduced >> i) ^ (is_reduced >> (i - 1))) & 1)) {
-        new_x_dim.push_back(x_dim[i]);
+        x_new_dim.push_back(x_dim[i]);
         if ((is_reduced >> i) & 1)
-          new_reduce_dim_temp.push_back(new_x_dim.size() - 1);
+          reduce_dim_new.push_back(x_new_dim.size() - 1);
       } else {
-        new_x_dim[new_x_dim.size() - 1] *= x_dim[i];
+        x_new_dim[x_new_dim.size() - 1] *= x_dim[i];
       }
     }

-    x_dim = new_x_dim;
-    reduce_dim = new_reduce_dim_temp;
+    x_dim = x_new_dim;
+    reduce_dim = reduce_dim_new;

     int x_rank = static_cast<int>(x_dim.size());
     std::set<int> left_set;
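A worked example may help here: for x_dim = {32, 4, 8, 16} reduced over axes {1, 2}, the two adjacent reduce axes are folded into one (4 * 8 = 32), giving x_dim = {32, 32, 16} with reduce_dim = {1}, so the later steps see a single middle reduce axis. A trivial illustrative check of the invariant this merging preserves (not part of the commit):

// Illustrative check: merging adjacent reduce axes keeps the total element
// count and the number of reduced elements per output unchanged.
#include <cassert>

int main() {
  int before_total = 32 * 4 * 8 * 16;  // original shape {32, 4, 8, 16}
  int after_total = 32 * 32 * 16;      // merged shape {32, 32, 16}
  int before_reduced = 4 * 8;          // elements covered by axes {1, 2}
  int after_reduced = 32;              // elements covered by merged axis {1}
  assert(before_total == after_total);
  assert(before_reduced == after_reduced);
  return 0;
}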
@@ -237,9 +293,9 @@ struct ReduceConfig {
       idx_dim.push_back(i);
     }

-    x_strides = detail::GetStrides(x_dim, idx_dim);
-    reduce_strides = detail::GetStrides(x_dim, reduce_dim);
-    left_strides = detail::GetStrides(x_dim, left_dim);
+    x_strides = detail::GetDimStrides(x_dim, idx_dim);
+    reduce_strides = detail::GetDimStrides(x_dim, reduce_dim);
+    left_strides = detail::GetDimStrides(x_dim, left_dim);
     reduce_num = reduce_strides[0] * x_dim[reduce_dim[0]];

     left_num = 1;
@@ -256,13 +312,17 @@ struct ReduceConfig {
   void SetReduceType() {
     int rank = x_dim.size();
     int reduce_rank = reduce_dim.size();
+    bool is_large_enough = (reduce_num > REDUCE_SPLIT_BOUNDARY / 2) ||
+                           (left_num > REDUCE_SPLIT_BOUNDARY);
     if (rank == reduce_rank) {
       reduce_type = static_cast<int>(ReduceType::kReduceAll);
     } else if (rank == 2 && reduce_rank == 1 && reduce_dim[0] == 1) {
       reduce_type = static_cast<int>(ReduceType::kReduceLastDim);
-    } else if (reduce_rank == 1) {
+    } else if (reduce_rank == 1 &&
+               ((rank == 2 && is_large_enough) || rank != 2)) {
       // ReduceFirstDim and reduceSecondDim
       reduce_type = static_cast<int>(ReduceType::kReduceHigherDim);
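The new is_large_enough guard changes which kernel a rank-2 tensor reduced over its first axis gets: small problems now fall through to kReduceAny instead of kReduceHigherDim. A simplified mirror of the branch logic (not the Paddle source; REDUCE_SPLIT_BOUNDARY is 512, as defined earlier in this file):

// Simplified mirror of SetReduceType's decision with the is_large_enough guard.
#include <cassert>

enum Kind { kAll, kLastDim, kHigherDim, kAny };

Kind Pick(int rank, int reduce_rank, int reduce_dim0, int reduce_num,
          int left_num) {
  bool is_large_enough = (reduce_num > 512 / 2) || (left_num > 512);
  if (rank == reduce_rank) return kAll;
  if (rank == 2 && reduce_rank == 1 && reduce_dim0 == 1) return kLastDim;
  if (reduce_rank == 1 && ((rank == 2 && is_large_enough) || rank != 2))
    return kHigherDim;
  return kAny;
}

int main() {
  // x_dim = {8, 6}, reduce axis 0: too small to split, falls through to kAny.
  assert(Pick(2, 1, 0, 8, 6) == kAny);
  // x_dim = {10000, 6}, reduce axis 0: large reduce_num, uses kHigherDim.
  assert(Pick(2, 1, 0, 10000, 6) == kHigherDim);
  // x_dim = {8, 6}, reduce axis 1 (the last dim): kLastDim regardless of size.
  assert(Pick(2, 1, 1, 6, 8) == kLastDim);
  return 0;
}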
@@ -277,7 +337,7 @@ struct ReduceConfig {
   // for others: block(block_num, 1) , grid(left_num, 1)
   void SetBlockDim() {
     // init
-    int block_num = detail::GetDesiredBlockDim(reduce_num);
+    int block_num = detail::GetBlockDim(reduce_num);
     should_reduce_again = false;

     dim3 block_dim(block_num, 1);
@@ -302,7 +362,7 @@ struct ReduceConfig {
     // init
     int num_block = (max_threads / left_num);

-    if (num_block > 1 && reduce_num >= 512) {
+    if (num_block > 1 && reduce_num >= REDUCE_SPLIT_BOUNDARY) {
       blocking_size = detail::GetLastPow2(reduce_num / num_block);

       if (blocking_size <= 1) {
@@ -352,6 +412,9 @@ struct ReduceConfig {
   dim3 grid;
 };

+// when reduce_dim.size() == 1 and reduce_dim[0] == x_dim.size() - 1, this
+// function will be used
+// blockId.x -> left_num, threadId.x -> reduce_num
 template <typename Tx, typename Ty, typename ReduceOp, typename TransformOp,
           int BlockDim>
 __device__ __forceinline__ void ReduceLastDim(const Tx* x, Ty* y,
@@ -362,18 +425,25 @@ __device__ __forceinline__ void ReduceLastDim(const Tx* x, Ty* y,
   int idx_x = blockIdx.x * reduce_num;
   int idx_y = threadIdx.x;
   Ty reduce_var = init;
-  for (int idx_y = threadIdx.x; idx_y < reduce_num; idx_y += BlockDim)
-    reduce_var = reducer(reduce_var, static_cast<Ty>(x[idx_x + idx_y]));
+  for (int idx_y = threadIdx.x; idx_y < reduce_num; idx_y += BlockDim) {
+    reduce_var =
+        reducer(reduce_var, static_cast<Ty>(transformer(x[idx_x + idx_y])));
+  }
   __syncthreads();

   reduce_var =
       cub::BlockReduce<Ty, BlockDim>(temp_storage).Reduce(reduce_var, reducer);

   if (threadIdx.x == 0) {
-    y[blockIdx.x] = transformer(reduce_var);
+    y[blockIdx.x] = reduce_var;
   }
 }

+// when reduce_dim.size() == 1 and reduce_dim[0] != x_dim.size() - 1, this
+// function will be used
+// eg: x_dim = {nz, ny, nx}, nx != 1, axis can be 0 or 1
+// if axis = 1 then grid.z = nz, grid.y = ny / block_size, grid.x = nx / 32
+// else grid.z = 1, grid.y = ny / block_size, grid.x = nx /32
 template <typename Tx, typename Ty, typename ReduceOp, typename TransformOp>
 __device__ __forceinline__ void ReduceHigherDim(const Tx* x, Ty* y,
                                                 ReduceOp reducer,
@@ -383,25 +453,29 @@ __device__ __forceinline__ void ReduceHigherDim(const Tx* x, Ty* y,
   int idx = blockIdx.x * blockDim.x + threadIdx.x;
   int idy = blockIdx.y * block_size;

-  Ty temp = init;
   Ty reduce_var = init;

   if (idx < left_num) {
     int loop = reduce_num - idy;
     loop = loop > block_size ? block_size : loop;

     for (int iy = 0; iy < loop; iy++) {
       int id = (idy + iy) * left_num + idx + blockIdx.z * reduce_num * left_num;
-      reduce_var = reducer(reduce_var, static_cast<Ty>(x[id]));
+      reduce_var = reducer(reduce_var, static_cast<Ty>(transformer(x[id])));
     }

     y[idx + blockIdx.y * left_num + blockIdx.z * gridDim.y * left_num] =
-        static_cast<Ty>(transformer(reduce_var));
+        reduce_var;
   }
 }

+// when reduce_dim.size() != 1 and reduce_dim.size() != x_dim.size(), this
+// function will be used
+// blockId.x -> left_num, threadId.x -> reduce_num
 template <typename Tx, typename Ty, typename ReduceOp, typename TransformOp,
           int BlockDim, int Rank, int ReduceRank>
 __device__ __forceinline__ void ReduceAny(
-    const Tx* x, Ty* y, ReduceOp reducer, TransformOp transformer, Ty init,
+    const Tx* x, Ty* y, ReduceOp reducer, TransformOp transformer,
     int reduce_num, paddle::framework::Array<int, Rank> x_strides,
     paddle::framework::Array<int, ReduceRank> reduce_dim,
     paddle::framework::Array<int, ReduceRank> reduce_strides,
@@ -423,20 +497,26 @@ __device__ __forceinline__ void ReduceAny(
   }
   int idx_x = 0;
-  for (int k = 0; k < Rank; ++k) idx_x += (sub_index[k] * x_strides[k]);
-  Ty reduce_var = static_cast<Ty>(x[idx_x]);
+  for (int k = 0; k < Rank; ++k) {
+    idx_x += (sub_index[k] * x_strides[k]);
+  }
+  Ty reduce_var = static_cast<Ty>(transformer(x[idx_x]));

   for (int i = threadIdx.x + BlockDim; i < reduce_num; i += BlockDim) {
     int reduce_idx = i;
     for (int j = 0; j < ReduceRank; ++j) {
       sub_index[reduce_dim[j]] = reduce_idx / reduce_strides[j];
       reduce_idx %= reduce_strides[j];
     }

     int idx_x = 0;
-    for (int k = 0; k < Rank; ++k) idx_x += (sub_index[k] * x_strides[k]);
-    reduce_var =
-        static_cast<Ty>(reducer(reduce_var, static_cast<Ty>(x[idx_x])));
+    for (int k = 0; k < Rank; ++k) {
+      idx_x += (sub_index[k] * x_strides[k]);
+    }
+    reduce_var = static_cast<Ty>(
+        reducer(reduce_var, static_cast<Ty>(transformer(x[idx_x]))));
   }
   __syncthreads();
@@ -444,10 +524,11 @@ __device__ __forceinline__ void ReduceAny(
       cub::BlockReduce<Ty, BlockDim>(temp_storage).Reduce(reduce_var, reducer);

   if (threadIdx.x == 0) {
-    y[blockIdx.x] = transformer(reduce_var);
+    y[blockIdx.x] = reduce_var;
   }
 }

 // module function designed for global function
 template <typename Tx, typename Ty, typename ReduceOp, typename TransformOp,
           int BlockDim, int Rank, int ReduceRank, int ReduceType>
 __device__ __forceinline__ void ReduceModule(
@@ -458,17 +539,20 @@ __device__ __forceinline__ void ReduceModule(
     paddle::framework::Array<int, ReduceRank> reduce_strides,
     paddle::framework::Array<int, Rank - ReduceRank> left_dim,
     paddle::framework::Array<int, Rank - ReduceRank> left_strides) {
+  // reduce_rank == 1 && reduce_dim[0] == x_dim.size() - 1
   if (ReduceType == ReduceType::kReduceLastDim) {
     ReduceLastDim<Tx, Ty, ReduceOp, TransformOp, BlockDim>(
         x, y, reducer, transformer, init, reduce_num);

+    // reduce_rank == 1 && reduce_dim[0] != x_dim.size() - 1
   } else if (ReduceType == ReduceType::kReduceHigherDim) {
     ReduceHigherDim<Tx, Ty, ReduceOp, TransformOp>(
         x, y, reducer, transformer, init, reduce_num, left_num, blocking_size);

+    // reduce_rank >= 2
   } else {
     ReduceAny<Tx, Ty, ReduceOp, TransformOp, BlockDim, Rank, ReduceRank>(
-        x, y, reducer, transformer, init, reduce_num, x_strides, reduce_dim,
+        x, y, reducer, transformer, reduce_num, x_strides, reduce_dim,
         reduce_strides, left_dim, left_strides);
   }
 }
@@ -491,23 +575,22 @@ __global__ void ReduceKernelFunction(
 template <typename Tx, typename Ty, int BlockDim, typename ReduceOp,
           typename TransformOp, int kRank, int kReduceRank>
-static void launchKernel(const Tx* x_data, Ty* y_data,
-                         const platform::Place& place, const ReduceOp& reducer,
-                         const TransformOp& transformer, const Ty& init,
+static void LaunchKernel(const Tx* x_data, Ty* y_data, const ReduceOp& reducer,
+                         const TransformOp& transformer, Ty init,
                          gpuStream_t stream, ReduceConfig<Ty> config) {
-#define CUB_REDUCE_TYPE_CASE(type)                                        \
-  case type: {                                                            \
-    constexpr auto kReduceType = type;                                    \
-    ReduceKernelFunction<                                                 \
-        Tx, Ty, ReduceOp, TransformOp, BlockDim, kRank, kReduceRank,      \
-        kReduceType><<<config.grid, config.block, 0, stream>>>(           \
-        x_data, config.output_data, reducer, transformer, init,           \
-        config.reduce_num, config.left_num, config.blocking_size,         \
-        detail::from<int, kRank>(config.x_strides),                       \
-        detail::from<int, kReduceRank>(config.reduce_dim),                \
-        detail::from<int, kReduceRank>(config.reduce_strides),            \
-        detail::from<int, kRank - kReduceRank>(config.left_dim),          \
-        detail::from<int, kRank - kReduceRank>(config.left_strides));     \
+#define CUB_REDUCE_TYPE_CASE(type)                                        \
+  case type: {                                                            \
+    constexpr auto kReduceType = type;                                    \
+    ReduceKernelFunction<                                                 \
+        Tx, Ty, ReduceOp, TransformOp, BlockDim, kRank, kReduceRank,      \
+        kReduceType><<<config.grid, config.block, 0, stream>>>(           \
+        x_data, config.output_data, reducer, transformer, init,           \
+        config.reduce_num, config.left_num, config.blocking_size,         \
+        detail::VectorToArray<int, kRank>(config.x_strides),              \
+        detail::VectorToArray<int, kReduceRank>(config.reduce_dim),       \
+        detail::VectorToArray<int, kReduceRank>(config.reduce_strides),   \
+        detail::VectorToArray<int, kRank - kReduceRank>(config.left_dim), \
+        detail::VectorToArray<int, kRank - kReduceRank>(config.left_strides)); \
   } break

   switch (config.reduce_type) {
@@ -523,22 +606,22 @@ static void launchKernel(const Tx* x_data, Ty* y_data,
     ReduceKernelFunction<
         Ty, Ty, ReduceOp, detail::IdentityFunctor<Ty>, 128, kRank, kReduceRank,
         ReduceType::kReduceHigherDim><<<grid, block, 0, stream>>>(
-        config.output_data, y_data, reducer, detail::IdentityFunctor<Ty>(),
-        init, config.grid.y, config.left_num, config.grid.y,
-        detail::from<int, kRank>(config.x_strides),
-        detail::from<int, kReduceRank>(config.reduce_dim),
-        detail::from<int, kReduceRank>(config.reduce_strides),
-        detail::from<int, kRank - kReduceRank>(config.left_dim),
-        detail::from<int, kRank - kReduceRank>(config.left_strides));
+        config.output_data, y_data, reducer,
+        detail::IdentityFunctor<Ty>(config.grid.y), init, config.grid.y,
+        config.left_num, config.grid.y,
+        detail::VectorToArray<int, kRank>(config.x_strides),
+        detail::VectorToArray<int, kReduceRank>(config.reduce_dim),
+        detail::VectorToArray<int, kReduceRank>(config.reduce_strides),
+        detail::VectorToArray<int, kRank - kReduceRank>(config.left_dim),
+        detail::VectorToArray<int, kRank - kReduceRank>(config.left_strides));
   }
 }

 template <typename Tx, typename Ty, int BlockDim, typename ReduceOp,
           typename TransformOp>
-static void launchReduceKernel(const Tx* x_data, Ty* y_data,
-                               const platform::Place& place,
+static void LaunchReduceKernel(const Tx* x_data, Ty* y_data,
                                const ReduceOp& reducer,
-                               const TransformOp& transformer, const Ty& init,
+                               const TransformOp& transformer, Ty init,
                                gpuStream_t stream, ReduceConfig<Ty> config) {
   int reduce_rank = config.reduce_strides.size();
   int rank = config.x_strides.size();
@@ -552,28 +635,11 @@ static void launchReduceKernel(const Tx* x_data, Ty* y_data,
 #define CUB_REDUCE_RANK_CASE(i, ...)                                          \
   case i: {                                                                   \
     constexpr auto kReduceRank = i;                                           \
-    launchKernel<Tx, Ty, BlockDim, ReduceOp, TransformOp, kRank, kReduceRank>( \
-        x_data, y_data, place, reducer, transformer, init, stream, config);   \
+    LaunchKernel<Tx, Ty, BlockDim, ReduceOp, TransformOp, kRank, kReduceRank>( \
+        x_data, y_data, reducer, transformer, init, stream, config);          \
   } break

-  // launch CUB::Reduce
-  if (config.reduce_type == static_cast<int>(ReduceType::kReduceAll)) {
-    cub::TransformInputIterator<Ty, TransformOp, const Tx*> trans_x(
-        x_data, transformer);
-    size_t temp_storage_bytes = 0;
-    cub::DeviceReduce::Reduce(nullptr, temp_storage_bytes, trans_x, y_data,
-                              config.reduce_num, reducer, init, stream);
-    framework::Tensor tmp;
-    auto* temp_storage = tmp.mutable_data<uint8_t>(
-        framework::make_ddim({static_cast<int64_t>(temp_storage_bytes)}),
-        place);
-    cub::DeviceReduce::Reduce(temp_storage, temp_storage_bytes, trans_x,
-                              y_data, config.reduce_num, reducer, init,
-                              stream);
-    return;
-  }
-
-  detail::CheckReduceRankIsValid(reduce_rank, rank);
+  detail::CheckReduceRank(reduce_rank, rank);
   switch (rank) {
     CUB_RANK_CASE(2, CUB_REDUCE_RANK_CASE(1););
@@ -595,23 +661,25 @@ static void launchReduceKernel(const Tx* x_data, Ty* y_data,
 #undef CUB_REDUCE_RANK_CASE
 #undef CUB_RANK_CASE
 }

-template <typename Tx, typename Ty, typename ReduceOp, typename TransformOp>
-void TensorReduceFunc(const framework::Tensor& x, framework::Tensor* y,
-                      std::vector<int> origin_reduce_dims, const Ty& init,
-                      const ReduceOp& reducer, const TransformOp& transformer,
-                      gpuStream_t stream) {
+template <typename Tx, typename Ty,
+          template <typename, typename> class ReduceOp>
+void TensorReduceFunctorImpl(const framework::Tensor& x, framework::Tensor* y,
+                             std::vector<int> origin_reduce_dims,
+                             gpuStream_t stream) {
   auto x_dim = framework::vectorize<int>(x.dims());
   auto config = ReduceConfig<Ty>(origin_reduce_dims, x_dim);
-  config.Run();
+  config.Run();  // get the parameters of LaunchReduceKernel

   auto x_data = x.data<Tx>();
   auto y_data = y->mutable_data<Ty>(x.place());

-  framework::Tensor tmp;
   // after config.run()
   // SetOutputData for ReduceHigherDim when should_reduce_again is true,
   // temp_output should be stored temp_data in output_data space or stored in
   // y_data;
-  config.SetOutputData(y_data, x.place(), tmp);
+  framework::Tensor tmp;
+  config.SetOutputData(y_data, x.place(), &tmp);

   if (config.reduce_num == 1) {
     auto out_dims = y->dims();
@@ -619,17 +687,36 @@ void TensorReduceFunc(const framework::Tensor& x, framework::Tensor* y,
     y->Resize(out_dims);
     return;
   }

+  using TransformOp = typename ReduceOp<Tx, Ty>::Transformer;
+  auto reducer = ReduceOp<Tx, Ty>();
+  // launch CUB::Reduce
+  if (config.reduce_type == static_cast<int>(ReduceType::kReduceAll)) {
+    cub::TransformInputIterator<Ty, TransformOp, const Tx*> trans_x(
+        x_data, TransformOp(config.reduce_num));
+    size_t temp_storage_bytes = 0;
+    cub::DeviceReduce::Reduce(nullptr, temp_storage_bytes, trans_x, y_data,
+                              config.reduce_num, reducer, reducer.initial(),
+                              stream);
+    framework::Tensor tmp;
+    auto* temp_storage = tmp.mutable_data<uint8_t>(
+        framework::make_ddim({static_cast<int64_t>(temp_storage_bytes)}),
+        x.place());
+    cub::DeviceReduce::Reduce(temp_storage, temp_storage_bytes, trans_x,
+                              y_data, config.reduce_num, reducer,
+                              reducer.initial(), stream);
+    return;
+  }
+
-#define CUB_BLOCK_DIM_CASE(block_dim)                                   \
-  case block_dim: {                                                     \
-    constexpr auto kBlockDim = block_dim;                               \
-    launchReduceKernel<Tx, Ty, block_dim, ReduceOp, TransformOp>(       \
-        x_data, y_data, x.place(), reducer, transformer, init, stream,  \
-        config);                                                        \
+#define CUB_BLOCK_DIM_CASE(block_dim)                                    \
+  case block_dim: {                                                      \
+    constexpr auto kBlockDim = block_dim;                                \
+    LaunchReduceKernel<Tx, Ty, block_dim, ReduceOp<Tx, Ty>, TransformOp>( \
+        x_data, y_data, reducer, TransformOp(config.reduce_num),         \
+        reducer.initial(), stream, config);                              \
   } break

-  switch (detail::GetDesiredBlockDim(config.reduce_num)) {
-    CUB_BLOCK_DIM_CASE(512);
+  switch (detail::GetBlockDim(config.reduce_num)) {
     CUB_BLOCK_DIM_CASE(256);
     CUB_BLOCK_DIM_CASE(128);
     CUB_BLOCK_DIM_CASE(64);
@@ -642,5 +729,46 @@ void TensorReduceFunc(const framework::Tensor& x, framework::Tensor* y,
 #undef CUB_BLOCK_DIM_CASE
 }

+template <typename Tx, template <typename, typename> class ReduceOp>
+struct TensorReduceFunc {
+  const framework::Tensor& x;
+  framework::Tensor* y;
+  std::vector<int> origin_reduce_dims;
+  gpuStream_t stream;
+  TensorReduceFunc(const framework::Tensor& x, framework::Tensor* y,
+                   std::vector<int> origin_reduce_dims, gpuStream_t stream)
+      : x(x), y(y), origin_reduce_dims(origin_reduce_dims), stream(stream) {}
+
+  template <typename Ty>
+  void apply() const {
+    TensorReduceFunctorImpl<Tx, Ty, ReduceOp>(x, y, origin_reduce_dims, stream);
+  }
+};
+
+template <typename T, template <typename, typename> class ReduceOp>
+class ReduceCudaKernel : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext& context) const override {
+    bool reduce_all = context.Attr<bool>("reduce_all");
+    const Tensor* input = context.Input<Tensor>("X");
+    Tensor* output = context.Output<Tensor>("Out");
+    auto out_dtype = context.Attr<int>("out_dtype");
+    std::vector<int> dims = context.Attr<std::vector<int>>("dim");
+
+    std::vector<int> reduce_dims =
+        detail::GetReduceDim(dims, input->dims().size(), reduce_all);
+
+    gpuStream_t stream = context.cuda_device_context().stream();
+    if (out_dtype >= 0) {
+      framework::VisitDataTypeSmall(
+          static_cast<framework::proto::VarType::Type>(out_dtype),
+          TensorReduceFunc<T, ReduceOp>(*input, output, reduce_dims, stream));
+    } else {
+      TensorReduceFunctorImpl<T, T, ReduceOp>(*input, output, reduce_dims,
+                                              stream);
+    }
+  }
+};
+
 }  // namespace operators
 }  // namespace paddle
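ReduceCudaKernel::Compute above dispatches on the runtime out_dtype attribute by wrapping the arguments in TensorReduceFunc and handing it to framework::VisitDataTypeSmall, which invokes the templated apply<Ty>() for the selected output type. A standalone sketch of that visitor/dispatch pattern, assuming a toy two-type visitor in place of Paddle's:

// Standalone sketch (not Paddle code) of the runtime-dtype dispatch pattern.
#include <cstdio>

enum class DType { kFloat, kDouble };

struct ReduceDispatch {
  int n;  // stands in for the captured tensor/stream arguments
  template <typename Ty>
  void apply() const {
    std::printf("reduce %d elements into %zu-byte accumulators\n", n, sizeof(Ty));
  }
};

template <typename Visitor>
void VisitDataType(DType t, Visitor v) {  // simplified visitor, two types only
  switch (t) {
    case DType::kFloat:  v.template apply<float>();  break;
    case DType::kDouble: v.template apply<double>(); break;
  }
}

int main() {
  VisitDataType(DType::kDouble, ReduceDispatch{1024});  // runs apply<double>()
  return 0;
}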
paddle/fluid/operators/reduce_ops/reduce_prod_op.cu
@@ -12,26 +12,22 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.

+#include "paddle/fluid/operators/reduce_ops/reduce_functor_op.h"
+#include "paddle/fluid/operators/reduce_ops/reduce_op.cu.h"
 #include "paddle/fluid/operators/reduce_ops/reduce_prod_op.h"

+// reduce_prod
 #ifdef __HIPCC__
 // Eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorReductionGpu.h:922
 // do not support double in HIPCC platform (Eigen3 to be fixed)
-REGISTER_OP_CUDA_KERNEL(reduce_prod,
-                        ops::ReduceKernel<paddle::platform::CUDADeviceContext,
-                                          float, ops::ProdFunctor>,
-                        ops::ReduceKernel<paddle::platform::CUDADeviceContext,
-                                          int, ops::ProdFunctor>,
-                        ops::ReduceKernel<paddle::platform::CUDADeviceContext,
-                                          int64_t, ops::ProdFunctor>);
+REGISTER_OP_CUDA_KERNEL(
+    reduce_prod, ops::ReduceCudaKernel<float, paddle::operators::CustomMul>,
+    ops::ReduceCudaKernel<int, paddle::operators::CustomMul>,
+    ops::ReduceCudaKernel<int64_t, paddle::operators::CustomMul>);
 #else
-REGISTER_OP_CUDA_KERNEL(reduce_prod,
-                        ops::ReduceKernel<paddle::platform::CUDADeviceContext,
-                                          float, ops::ProdFunctor>,
-                        ops::ReduceKernel<paddle::platform::CUDADeviceContext,
-                                          double, ops::ProdFunctor>,
-                        ops::ReduceKernel<paddle::platform::CUDADeviceContext,
-                                          int, ops::ProdFunctor>,
-                        ops::ReduceKernel<paddle::platform::CUDADeviceContext,
-                                          int64_t, ops::ProdFunctor>);
+REGISTER_OP_CUDA_KERNEL(
+    reduce_prod, ops::ReduceCudaKernel<float, paddle::operators::CustomMul>,
+    ops::ReduceCudaKernel<int, paddle::operators::CustomMul>,
+    ops::ReduceCudaKernel<double, paddle::operators::CustomMul>,
+    ops::ReduceCudaKernel<int64_t, paddle::operators::CustomMul>);
 #endif