Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
524389ee
P
Paddle
项目概览
PaddlePaddle
/
Paddle
1 年多 前同步成功
通知
2305
Star
20932
Fork
5423
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
524389ee
编写于
12月 16, 2021
作者:
N
niuliling123
提交者:
GitHub
12月 16, 2021
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Add the transformop parameter in TensorReduceFunctorImpl (#38135)
* Add the transformop parameter in TensorReduceFunctorImpl
上级
be874c08
变更
19
隐藏空白更改
内联
并排
Showing
19 changed file
with
138 addition
and
232 deletion
+138
-232
paddle/fluid/operators/clip_by_norm_op.cu
paddle/fluid/operators/clip_by_norm_op.cu
+4
-25
paddle/fluid/operators/elementwise/elementwise_add_op.cu
paddle/fluid/operators/elementwise/elementwise_add_op.cu
+4
-3
paddle/fluid/operators/elementwise/elementwise_sub_op.cu
paddle/fluid/operators/elementwise/elementwise_sub_op.cu
+4
-3
paddle/fluid/operators/fused/attn_gemm.h
paddle/fluid/operators/fused/attn_gemm.h
+3
-4
paddle/fluid/operators/margin_cross_entropy_op.cu
paddle/fluid/operators/margin_cross_entropy_op.cu
+6
-16
paddle/fluid/operators/p_norm_op.cu
paddle/fluid/operators/p_norm_op.cu
+10
-63
paddle/fluid/operators/pool_op.h
paddle/fluid/operators/pool_op.h
+4
-4
paddle/fluid/operators/prelu_op.cu
paddle/fluid/operators/prelu_op.cu
+3
-11
paddle/fluid/operators/reduce_ops/reduce_all_op.cu
paddle/fluid/operators/reduce_ops/reduce_all_op.cu
+1
-2
paddle/fluid/operators/reduce_ops/reduce_any_op.cu
paddle/fluid/operators/reduce_ops/reduce_any_op.cu
+1
-2
paddle/fluid/operators/reduce_ops/reduce_max_op.cu
paddle/fluid/operators/reduce_ops/reduce_max_op.cu
+5
-5
paddle/fluid/operators/reduce_ops/reduce_mean_op.cu
paddle/fluid/operators/reduce_ops/reduce_mean_op.cu
+4
-4
paddle/fluid/operators/reduce_ops/reduce_min_op.cu
paddle/fluid/operators/reduce_ops/reduce_min_op.cu
+5
-5
paddle/fluid/operators/reduce_ops/reduce_op.cu.h
paddle/fluid/operators/reduce_ops/reduce_op.cu.h
+52
-48
paddle/fluid/operators/reduce_ops/reduce_op.h
paddle/fluid/operators/reduce_ops/reduce_op.h
+10
-5
paddle/fluid/operators/reduce_ops/reduce_prod_op.cu
paddle/fluid/operators/reduce_ops/reduce_prod_op.cu
+5
-5
paddle/fluid/operators/reduce_ops/reduce_sum_op.cu
paddle/fluid/operators/reduce_ops/reduce_sum_op.cu
+12
-12
paddle/fluid/operators/trace_op.cu
paddle/fluid/operators/trace_op.cu
+3
-13
paddle/fluid/operators/triangular_solve_op.cu
paddle/fluid/operators/triangular_solve_op.cu
+2
-2
未找到文件。
paddle/fluid/operators/clip_by_norm_op.cu
浏览文件 @
524389ee
...
...
@@ -18,29 +18,6 @@ limitations under the License. */
namespace
paddle
{
namespace
operators
{
using
Tensor
=
framework
::
Tensor
;
template
<
typename
Tx
,
typename
Ty
=
Tx
>
struct
SquareTransformer
{
HOSTDEVICE
explicit
inline
SquareTransformer
(
int
n
)
{}
HOSTDEVICE
inline
Ty
operator
()(
const
Tx
&
x
)
const
{
return
static_cast
<
Ty
>
(
x
)
*
static_cast
<
Ty
>
(
x
);
}
HOSTDEVICE
inline
Ty
operator
()(
const
Tx
*
x
)
const
{
return
static_cast
<
Ty
>
(
x
[
0
])
*
static_cast
<
Ty
>
(
x
[
0
]);
}
};
template
<
typename
Tx
,
typename
Ty
=
Tx
>
struct
SquareSum
{
using
Transformer
=
SquareTransformer
<
Tx
,
Ty
>
;
inline
Ty
initial
()
{
return
static_cast
<
Ty
>
(
0.0
f
);
}
__device__
__forceinline__
Ty
operator
()(
const
Ty
&
a
,
const
Ty
&
b
)
const
{
return
b
+
a
;
}
};
template
<
>
class
ClipByNormKernel
<
platform
::
CUDADeviceContext
,
platform
::
float16
>
...
...
@@ -97,8 +74,10 @@ class ClipByNormKernel<platform::CUDADeviceContext, platform::float16>
}
Tensor
tmp
=
context
.
AllocateTmpTensor
<
float
,
platform
::
CUDADeviceContext
>
(
{
1
},
dev_ctx
);
TensorReduceFunctorImpl
<
platform
::
float16
,
float
,
SquareSum
>
(
*
input
,
&
tmp
,
reduce_dims
,
dev_ctx
.
stream
());
TensorReduceFunctorImpl
<
platform
::
float16
,
float
,
kps
::
AddFunctor
,
kps
::
SquareFunctor
<
platform
::
float16
,
float
>>
(
*
input
,
&
tmp
,
kps
::
SquareFunctor
<
platform
::
float16
,
float
>
(),
reduce_dims
,
dev_ctx
.
stream
());
auto
tmp_eigen
=
EigenVector
<
float
>::
Flatten
(
tmp
);
auto
x_norm
=
tmp_eigen
.
sqrt
();
...
...
paddle/fluid/operators/elementwise/elementwise_add_op.cu
浏览文件 @
524389ee
...
...
@@ -15,7 +15,6 @@ limitations under the License. */
#include "paddle/fluid/framework/pten_utils.h"
#include "paddle/fluid/operators/elementwise/elementwise_add_op.h"
#include "paddle/fluid/operators/elementwise/elementwise_op_broadcast.cu.h"
#include "paddle/fluid/operators/reduce_ops/reduce_functor_op.h"
#include "paddle/fluid/operators/reduce_ops/reduce_op.cu.h"
#include "paddle/fluid/platform/complex.h"
#include "paddle/fluid/platform/float16.h"
...
...
@@ -91,7 +90,8 @@ default_elementwise_add_grad(const framework::ExecutionContext& ctx,
}
std
::
vector
<
int
>
reduce_dims
=
GetReduceDim
(
x
->
dims
(),
out
->
dims
(),
axis
);
gpuStream_t
stream
=
ctx
.
cuda_device_context
().
stream
();
TensorReduceFunctorImpl
<
T
,
T
,
CustomSum
>
(
*
dout
,
dx
,
reduce_dims
,
stream
);
TensorReduceFunctorImpl
<
T
,
T
,
kps
::
AddFunctor
,
kps
::
IdentityFunctor
<
T
>>
(
*
dout
,
dx
,
kps
::
IdentityFunctor
<
T
>
(),
reduce_dims
,
stream
);
}
}
// dy
...
...
@@ -106,7 +106,8 @@ default_elementwise_add_grad(const framework::ExecutionContext& ctx,
}
else
{
std
::
vector
<
int
>
reduce_dims
=
GetReduceDim
(
y
->
dims
(),
out
->
dims
(),
axis
);
gpuStream_t
stream
=
ctx
.
cuda_device_context
().
stream
();
TensorReduceFunctorImpl
<
T
,
T
,
CustomSum
>
(
*
dout
,
dy
,
reduce_dims
,
stream
);
TensorReduceFunctorImpl
<
T
,
T
,
kps
::
AddFunctor
,
kps
::
IdentityFunctor
<
T
>>
(
*
dout
,
dy
,
kps
::
IdentityFunctor
<
T
>
(),
reduce_dims
,
stream
);
}
}
}
...
...
paddle/fluid/operators/elementwise/elementwise_sub_op.cu
浏览文件 @
524389ee
...
...
@@ -14,7 +14,6 @@ limitations under the License. */
#include "paddle/fluid/operators/elementwise/elementwise_op_broadcast.cu.h"
#include "paddle/fluid/operators/elementwise/elementwise_sub_op.h"
#include "paddle/fluid/operators/reduce_ops/reduce_functor_op.h"
#include "paddle/fluid/operators/reduce_ops/reduce_op.cu.h"
#include "paddle/fluid/platform/complex.h"
#include "paddle/fluid/platform/float16.h"
...
...
@@ -69,7 +68,8 @@ default_elementwise_sub_grad(const framework::ExecutionContext& ctx,
}
std
::
vector
<
int
>
reduce_dims
=
GetReduceDim
(
x
->
dims
(),
out
->
dims
(),
axis
);
gpuStream_t
stream
=
ctx
.
cuda_device_context
().
stream
();
TensorReduceFunctorImpl
<
T
,
T
,
CustomSum
>
(
*
dout
,
dx
,
reduce_dims
,
stream
);
TensorReduceFunctorImpl
<
T
,
T
,
kps
::
AddFunctor
,
kps
::
IdentityFunctor
<
T
>>
(
*
dout
,
dx
,
kps
::
IdentityFunctor
<
T
>
(),
reduce_dims
,
stream
);
}
}
// dy
...
...
@@ -90,7 +90,8 @@ default_elementwise_sub_grad(const framework::ExecutionContext& ctx,
}
else
{
std
::
vector
<
int
>
reduce_dims
=
GetReduceDim
(
y
->
dims
(),
out
->
dims
(),
axis
);
gpuStream_t
stream
=
ctx
.
cuda_device_context
().
stream
();
TensorReduceFunctorImpl
<
T
,
T
,
CustomSub
>
(
*
dout
,
dy
,
reduce_dims
,
stream
);
TensorReduceFunctorImpl
<
T
,
T
,
kps
::
AddFunctor
,
kps
::
InverseFunctor
<
T
>>
(
*
dout
,
dy
,
kps
::
InverseFunctor
<
T
>
(),
reduce_dims
,
stream
);
}
}
}
...
...
paddle/fluid/operators/fused/attn_gemm.h
浏览文件 @
524389ee
...
...
@@ -16,11 +16,10 @@ limitations under the License. */
#include "paddle/fluid/operators/elementwise/elementwise_op_broadcast.cu.h"
#include "paddle/fluid/operators/kernel_primitives/kernel_primitives.h"
#include "paddle/fluid/operators/reduce_ops/reduce_
functor_op
.h"
#include "paddle/fluid/operators/reduce_ops/reduce_
op.cu
.h"
namespace
paddle
{
namespace
operators
{
// support gemm-nt and gemm-nn, which is used in fused_attention_op.
template
<
typename
T
>
class
AttnMatMul
{
...
...
@@ -165,8 +164,8 @@ class AttnMatMul {
(
input_dims
[
2
]
==
output_dims
[
0
]));
if
(
support_case_1
||
support_case_2
)
{
gpuStream_t
stream
=
dev_ctx_
.
stream
();
TensorReduceFunctorImpl
<
T
,
T
,
CustomSum
>
(
*
d_output
,
d_bias
,
{
0
,
1
},
stream
);
TensorReduceFunctorImpl
<
T
,
T
,
kps
::
AddFunctor
,
kps
::
IdentityFunctor
<
T
>>
(
*
d_output
,
d_bias
,
kps
::
IdentityFunctor
<
T
>
(),
{
0
,
1
},
stream
);
}
else
{
PADDLE_THROW
(
platform
::
errors
::
InvalidArgument
(
"Only support reduce when the input dims are [0,1,2,3,4] and "
...
...
paddle/fluid/operators/margin_cross_entropy_op.cu
浏览文件 @
524389ee
...
...
@@ -24,7 +24,6 @@ namespace cub = hipcub;
#include "paddle/fluid/operators/margin_cross_entropy_op.h"
#include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/operators/math/softmax_impl.h"
#include "paddle/fluid/operators/reduce_ops/reduce_functor_op.h"
#include "paddle/fluid/operators/reduce_ops/reduce_op.h"
#include "paddle/fluid/string/string_helper.h"
...
...
@@ -128,17 +127,6 @@ __global__ void AddMarginToPositiveLogitsKernel(
}
}
template
<
typename
Tx
,
typename
Ty
=
Tx
>
struct
ExpAndSum
{
using
Transformer
=
kps
::
ExpFunctor
<
Tx
>
;
inline
Ty
initial
()
{
return
static_cast
<
Ty
>
(
0.0
f
);
}
__device__
__forceinline__
Ty
operator
()(
const
Ty
&
a
,
const
Ty
&
b
)
const
{
return
b
+
a
;
}
};
template
<
typename
T
>
__global__
void
ScaleLogitKernel
(
T
*
logits
,
const
float
scale
,
const
int64_t
N
,
const
int64_t
D
)
{
...
...
@@ -309,8 +297,9 @@ class MarginCrossEntropyOpCUDAKernel : public framework::OpKernel<T> {
logits_max
=
ctx
.
AllocateTmpTensor
<
T
,
platform
::
CUDADeviceContext
>
({
N
,
1
},
dev_ctx
);
T
*
logits_max_buff
=
logits_max
.
mutable_data
<
T
>
(
place
);
TensorReduceFunctorImpl
<
T
,
T
,
CustomMax
>
(
softmax_2d
,
&
logits_max
,
{
1
},
dev_ctx
.
stream
());
TensorReduceFunctorImpl
<
T
,
T
,
kps
::
MaxFunctor
,
kps
::
IdentityFunctor
<
T
>>
(
softmax_2d
,
&
logits_max
,
kps
::
IdentityFunctor
<
T
>
(),
{
1
},
dev_ctx
.
stream
());
#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
if
(
nranks
>
1
)
{
...
...
@@ -330,8 +319,9 @@ class MarginCrossEntropyOpCUDAKernel : public framework::OpKernel<T> {
sum_exp_logits
=
ctx
.
AllocateTmpTensor
<
T
,
platform
::
CUDADeviceContext
>
({
N
,
1
},
dev_ctx
);
T
*
sum_exp_logits_buff
=
sum_exp_logits
.
mutable_data
<
T
>
(
place
);
TensorReduceFunctorImpl
<
T
,
T
,
ExpAndSum
>
(
softmax_2d
,
&
sum_exp_logits
,
{
1
},
dev_ctx
.
stream
());
TensorReduceFunctorImpl
<
T
,
T
,
kps
::
AddFunctor
,
kps
::
ExpFunctor
<
T
>>
(
softmax_2d
,
&
sum_exp_logits
,
kps
::
ExpFunctor
<
T
>
(),
{
1
},
dev_ctx
.
stream
());
#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
if
(
nranks
>
1
)
{
...
...
paddle/fluid/operators/p_norm_op.cu
浏览文件 @
524389ee
...
...
@@ -59,28 +59,17 @@ __device__ __forceinline__ double inline_pow(double base, double exponent) {
return
pow
(
base
,
exponent
);
}
struct
IdentityFunctor
{
HOSTDEVICE
explicit
inline
IdentityFunctor
()
{}
HOSTDEVICE
explicit
inline
IdentityFunctor
(
int
n
)
{}
template
<
typename
T
>
HOSTDEVICE
inline
T
operator
()(
const
T
&
x
)
const
{
return
static_cast
<
T
>
(
x
);
}
};
template
<
typename
T
>
struct
NonzeroFunctor
{
HOSTDEVICE
explicit
inline
NonzeroFunctor
()
{}
HOSTDEVICE
explicit
inline
NonzeroFunctor
(
int
n
)
{}
template
<
typename
T
>
HOSTDEVICE
inline
T
operator
()(
const
T
&
x
)
const
{
return
static_cast
<
T
>
(
static_cast
<
double
>
(
x
)
!=
0
);
}
};
template
<
typename
T
>
struct
AbsFunctor
{
HOSTDEVICE
explicit
inline
AbsFunctor
()
{}
HOSTDEVICE
explicit
inline
AbsFunctor
(
int
n
)
{}
template
<
typename
T
>
HOSTDEVICE
inline
T
operator
()(
const
T
&
x
)
const
{
return
static_cast
<
T
>
(
inline_abs
(
x
));
}
...
...
@@ -106,48 +95,6 @@ struct PowFunctor {
float
porder
;
};
template
<
typename
Tx
,
typename
Ty
=
Tx
>
struct
AbsAndMin
{
using
Transformer
=
AbsFunctor
;
using
MT
=
typename
details
::
MPTypeTrait
<
Ty
>::
Type
;
inline
Ty
initial
()
{
return
static_cast
<
Ty
>
(
std
::
numeric_limits
<
MT
>::
infinity
());
}
__device__
__forceinline__
Ty
operator
()(
const
Ty
&
a
,
const
Ty
&
b
)
const
{
return
(
a
<
b
)
?
a
:
b
;
}
};
template
<
typename
Tx
,
typename
Ty
=
Tx
>
struct
AbsAndMax
{
using
Transformer
=
AbsFunctor
;
using
MT
=
typename
details
::
MPTypeTrait
<
Ty
>::
Type
;
inline
Ty
initial
()
{
return
static_cast
<
Ty
>
(
-
std
::
numeric_limits
<
MT
>::
infinity
());
}
__device__
__forceinline__
Ty
operator
()(
const
Ty
&
a
,
const
Ty
&
b
)
const
{
return
(
a
>
b
)
?
a
:
b
;
}
};
template
<
typename
Tx
,
typename
Ty
=
Tx
>
struct
NonzeroAndSum
{
using
Transformer
=
NonzeroFunctor
;
inline
Ty
initial
()
{
return
static_cast
<
Ty
>
(
0.0
f
);
}
__device__
__forceinline__
Ty
operator
()(
const
Ty
&
a
,
const
Ty
&
b
)
const
{
return
b
+
a
;
}
};
template
<
typename
Tx
,
typename
Ty
=
Tx
>
struct
IdentityAndSum
{
using
Transformer
=
IdentityFunctor
;
inline
Ty
initial
()
{
return
static_cast
<
Ty
>
(
0.0
f
);
}
__device__
__forceinline__
Ty
operator
()(
const
Ty
&
a
,
const
Ty
&
b
)
const
{
return
b
+
a
;
}
};
template
<
typename
DeviceContext
,
typename
T
>
class
PnormCUDAKernel
:
public
framework
::
OpKernel
<
T
>
{
public:
...
...
@@ -167,14 +114,14 @@ class PnormCUDAKernel : public framework::OpKernel<T> {
using
MT
=
typename
details
::
MPTypeTrait
<
T
>::
Type
;
if
(
porder
==
0
)
{
TensorReduceFunctorImpl
<
T
,
T
,
NonzeroAndSum
>
(
*
in_x
,
out_norm
,
reduce_axis
,
stream
);
TensorReduceFunctorImpl
<
T
,
T
,
kps
::
AddFunctor
,
NonzeroFunctor
<
T
>>
(
*
in_x
,
out_norm
,
NonzeroFunctor
<
T
>
(),
reduce_axis
,
stream
);
}
else
if
(
porder
==
INFINITY
)
{
TensorReduceFunctorImpl
<
T
,
T
,
AbsAndMax
>
(
*
in_x
,
out_norm
,
reduce_axis
,
stream
);
TensorReduceFunctorImpl
<
T
,
T
,
kps
::
MaxFunctor
,
AbsFunctor
<
T
>>
(
*
in_x
,
out_norm
,
AbsFunctor
<
T
>
(),
reduce_axis
,
stream
);
}
else
if
(
porder
==
-
INFINITY
)
{
TensorReduceFunctorImpl
<
T
,
T
,
AbsAndMin
>
(
*
in_x
,
out_norm
,
reduce_axis
,
stream
);
TensorReduceFunctorImpl
<
T
,
T
,
kps
::
MinFunctor
,
AbsFunctor
<
T
>>
(
*
in_x
,
out_norm
,
AbsFunctor
<
T
>
(),
reduce_axis
,
stream
);
}
else
{
framework
::
Tensor
tmp_x
;
tmp_x
.
mutable_data
<
T
>
(
xdim
,
ctx
.
GetPlace
());
...
...
@@ -189,8 +136,8 @@ class PnormCUDAKernel : public framework::OpKernel<T> {
cuda_ctx
,
ins
,
&
outs
,
func
);
framework
::
Tensor
tmp_y
;
tmp_y
.
mutable_data
<
T
>
(
ndim
,
ctx
.
GetPlace
());
TensorReduceFunctorImpl
<
T
,
T
,
IdentityAndSum
>
(
tmp_x
,
&
tmp_y
,
reduce_axis
,
stream
);
TensorReduceFunctorImpl
<
T
,
T
,
kps
::
AddFunctor
,
kps
::
IdentityFunctor
<
T
>>
(
tmp_x
,
&
tmp_y
,
kps
::
IdentityFunctor
<
T
>
(),
reduce_axis
,
stream
);
const
framework
::
Tensor
*
tmp_norm
=
&
tmp_y
;
ins
=
{
tmp_norm
};
outs
=
{
out_norm
};
...
...
paddle/fluid/operators/pool_op.h
浏览文件 @
524389ee
...
...
@@ -23,7 +23,6 @@ limitations under the License. */
#include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/operators/math/pooling.h"
#if defined(__HIPCC__) || defined(__NVCC__)
#include "paddle/fluid/operators/reduce_ops/reduce_functor_op.h"
#include "paddle/fluid/operators/reduce_ops/reduce_op.cu.h"
#endif
...
...
@@ -203,13 +202,14 @@ class PoolKernel : public framework::OpKernel<T> {
}
else
if
(
pooling_type
==
"avg"
)
{
std
::
vector
<
int
>
reduce_dim
;
int
reduce_num
=
getReduceNum
(
*
in_x
,
out
,
data_format
,
&
reduce_dim
);
if
(
reduce_num
>
0
&&
adaptive
)
{
// for adaptive_avg_pool2d && output_size == 1
#if defined(__HIPCC__) || defined(__NVCC__)
auto
stream
=
dev_ctx
.
stream
();
TensorReduceFunctorImpl
<
T
,
T
,
CustomMean
>
(
*
in_x
,
out
,
reduce_dim
,
stream
);
TensorReduceFunctorImpl
<
T
,
T
,
kps
::
AddFunctor
,
kps
::
DivideFunctor
<
T
>>
(
*
in_x
,
out
,
kps
::
DivideFunctor
<
T
>
(
reduce_num
),
reduce_dim
,
stream
);
#else // for cpu
paddle
::
operators
::
math
::
Pool2dFunctor
<
DeviceContext
,
paddle
::
operators
::
math
::
AvgPool
<
T
>
,
T
>
...
...
paddle/fluid/operators/prelu_op.cu
浏览文件 @
524389ee
...
...
@@ -15,7 +15,7 @@ limitations under the License. */
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/math/prelu.h"
#include "paddle/fluid/operators/prelu_op.h"
#include "paddle/fluid/operators/reduce_ops/
cub_reduce
.h"
#include "paddle/fluid/operators/reduce_ops/
reduce_op.cu
.h"
#include "paddle/fluid/platform/device/gpu/gpu_primitives.h"
namespace
paddle
{
...
...
@@ -123,13 +123,6 @@ class PreluOpGradFunctor {
}
};
struct
IdentityFunctor
{
template
<
typename
T
>
HOSTDEVICE
inline
T
operator
()(
const
T
&
x
)
const
{
return
x
;
}
};
template
<
typename
DeviceContext
,
typename
T
>
class
CUDAPReluGradKernel
:
public
framework
::
OpKernel
<
T
>
{
public:
...
...
@@ -192,9 +185,8 @@ class CUDAPReluGradKernel : public framework::OpKernel<T> {
reduce_dims
.
push_back
(
i
);
}
TensorReduce
<
T
,
T
,
cub
::
Sum
,
IdentityFunctor
>
(
dalpha_tmp
,
dalpha
,
reduce_dims
,
static_cast
<
T
>
(
0
),
cub
::
Sum
(),
IdentityFunctor
(),
stream
);
TensorReduceFunctorImpl
<
T
,
T
,
kps
::
AddFunctor
,
kps
::
IdentityFunctor
<
T
>>
(
dalpha_tmp
,
dalpha
,
kps
::
IdentityFunctor
<
T
>
(),
reduce_dims
,
stream
);
}
};
...
...
paddle/fluid/operators/reduce_ops/reduce_all_op.cu
浏览文件 @
524389ee
...
...
@@ -13,8 +13,7 @@
// limitations under the License.
#include "paddle/fluid/operators/reduce_ops/reduce_all_op.h"
#include "paddle/fluid/operators/reduce_ops/reduce_functor_op.h"
REGISTER_OP_CUDA_KERNEL
(
reduce_all
,
ops
::
ReduceCudaKernel
<
bool
,
paddle
::
operators
::
CustomLogicalAnd
>
);
ops
::
ReduceCudaKernel
<
bool
,
kps
::
LogicalAndFunctor
,
kps
::
IdentityFunctor
>
);
paddle/fluid/operators/reduce_ops/reduce_any_op.cu
浏览文件 @
524389ee
...
...
@@ -13,9 +13,8 @@
// limitations under the License.
#include "paddle/fluid/operators/reduce_ops/reduce_any_op.h"
#include "paddle/fluid/operators/reduce_ops/reduce_functor_op.h"
#include "paddle/fluid/operators/reduce_ops/reduce_op.h"
REGISTER_OP_CUDA_KERNEL
(
reduce_any
,
ops
::
ReduceCudaKernel
<
bool
,
paddle
::
operators
::
CustomLogicalO
r
>
);
ops
::
ReduceCudaKernel
<
bool
,
kps
::
LogicalOrFunctor
,
kps
::
IdentityFuncto
r
>
);
paddle/fluid/operators/reduce_ops/reduce_max_op.cu
浏览文件 @
524389ee
...
...
@@ -11,13 +11,13 @@
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/operators/reduce_ops/reduce_functor_op.h"
#include "paddle/fluid/operators/reduce_ops/reduce_op.cu.h"
#include "paddle/fluid/operators/reduce_ops/reduce_op.h"
// reduce_max
REGISTER_OP_CUDA_KERNEL
(
reduce_max
,
ops
::
ReduceCudaKernel
<
float
,
paddle
::
operators
::
CustomMax
>
,
ops
::
ReduceCudaKernel
<
double
,
paddle
::
operators
::
CustomMax
>
,
ops
::
ReduceCudaKernel
<
int
,
paddle
::
operators
::
CustomMax
>
,
ops
::
ReduceCudaKernel
<
int64_t
,
paddle
::
operators
::
CustomMax
>
);
reduce_max
,
ops
::
ReduceCudaKernel
<
float
,
kps
::
MaxFunctor
,
kps
::
IdentityFunctor
>
,
ops
::
ReduceCudaKernel
<
double
,
kps
::
MaxFunctor
,
kps
::
IdentityFunctor
>
,
ops
::
ReduceCudaKernel
<
int
,
kps
::
MaxFunctor
,
kps
::
IdentityFunctor
>
,
ops
::
ReduceCudaKernel
<
int64_t
,
kps
::
MaxFunctor
,
kps
::
IdentityFunctor
>
);
paddle/fluid/operators/reduce_ops/reduce_mean_op.cu
浏览文件 @
524389ee
...
...
@@ -13,11 +13,11 @@
// limitations under the License.
#include <vector>
#include "paddle/fluid/operators/reduce_ops/reduce_functor_op.h"
#include "paddle/fluid/operators/reduce_ops/reduce_mean_op.h"
#include "paddle/fluid/operators/reduce_ops/reduce_op.h"
REGISTER_OP_CUDA_KERNEL
(
reduce_mean
,
ops
::
ReduceCudaKernel
<
bool
,
paddle
::
operators
::
CustomMean
>
,
ops
::
ReduceCudaKernel
<
float
,
paddle
::
operators
::
CustomMean
>
,
ops
::
ReduceCudaKernel
<
double
,
paddle
::
operators
::
CustomMean
>
);
reduce_mean
,
ops
::
ReduceCudaKernel
<
bool
,
kps
::
AddFunctor
,
kps
::
DivideFunctor
>
,
ops
::
ReduceCudaKernel
<
float
,
kps
::
AddFunctor
,
kps
::
DivideFunctor
>
,
ops
::
ReduceCudaKernel
<
double
,
kps
::
AddFunctor
,
kps
::
DivideFunctor
>
);
paddle/fluid/operators/reduce_ops/reduce_min_op.cu
浏览文件 @
524389ee
...
...
@@ -11,13 +11,13 @@
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/operators/reduce_ops/reduce_functor_op.h"
#include "paddle/fluid/operators/reduce_ops/reduce_op.cu.h"
#include "paddle/fluid/operators/reduce_ops/reduce_op.h"
// reduce_min
REGISTER_OP_CUDA_KERNEL
(
reduce_min
,
ops
::
ReduceCudaKernel
<
float
,
paddle
::
operators
::
CustomMin
>
,
ops
::
ReduceCudaKernel
<
double
,
paddle
::
operators
::
CustomMin
>
,
ops
::
ReduceCudaKernel
<
int
,
paddle
::
operators
::
CustomMin
>
,
ops
::
ReduceCudaKernel
<
int64_t
,
paddle
::
operators
::
CustomMin
>
);
reduce_min
,
ops
::
ReduceCudaKernel
<
float
,
kps
::
MinFunctor
,
kps
::
IdentityFunctor
>
,
ops
::
ReduceCudaKernel
<
double
,
kps
::
MinFunctor
,
kps
::
IdentityFunctor
>
,
ops
::
ReduceCudaKernel
<
int
,
kps
::
MinFunctor
,
kps
::
IdentityFunctor
>
,
ops
::
ReduceCudaKernel
<
int64_t
,
kps
::
MinFunctor
,
kps
::
IdentityFunctor
>
);
paddle/fluid/operators/reduce_ops/reduce_op.cu.h
浏览文件 @
524389ee
...
...
@@ -44,11 +44,11 @@ namespace cub = hipcub;
#define REDUCE_SPLIT_BOUNDARY 512
#define REDUCE_VEC_SIZE 4
namespace
kps
=
paddle
::
operators
::
kernel_primitives
;
namespace
paddle
{
namespace
operators
{
namespace
kps
=
paddle
::
operators
::
kernel_primitives
;
namespace
details
{
static
inline
int
GetLastPow2
(
int
n
)
{
...
...
@@ -722,12 +722,12 @@ __global__ void ReduceHigherDimKernel(const Tx* x, Ty* y, ReduceOp reducer,
}
}
template
<
typename
Tx
,
typename
Ty
,
typename
MPType
,
typename
ReduceOp
>
template
<
typename
Tx
,
typename
Ty
,
typename
MPType
,
typename
ReduceOp
,
typename
TransformOp
>
static
void
LaunchReduceKernel
(
const
Tx
*
x_data
,
Ty
*
y_data
,
const
ReduceOp
&
reducer
,
MPType
init
,
const
ReduceOp
&
reducer
,
const
TransformOp
&
transform
,
MPType
init
,
gpuStream_t
stream
,
ReduceConfig
<
Ty
>
config
)
{
using
TransformOp
=
typename
ReduceOp
::
Transformer
;
if
(
config
.
reduce_type
==
kReduceLastDim
)
{
int
stride_reduce
=
1
;
int
stride_left
=
config
.
reduce_num
;
...
...
@@ -743,15 +743,15 @@ static void LaunchReduceKernel(const Tx* x_data, Ty* y_data,
#ifdef PADDLE_WITH_XPU2
ReduceAnyKernel
<
Tx
,
Ty
,
MPType
,
ReduceOp
,
TransformOp
,
OneDimIndexCal
><<<
8
,
128
,
stream
>>>
(
x_data
,
config
.
output_data
,
reducer
,
TransformOp
(
config
.
reduce_num
)
,
init
,
config
.
reduce_num
,
config
.
left_num
,
config
.
reduce_last_dim
,
reduce_index_calculator
,
left_index_calculator
,
dim
);
x_data
,
config
.
output_data
,
reducer
,
transform
,
init
,
config
.
reduce_num
,
config
.
left_num
,
config
.
reduce_last_dim
,
reduce_index_calculator
,
left_index_calculator
,
dim
);
#else
ReduceAnyKernel
<
Tx
,
Ty
,
MPType
,
ReduceOp
,
TransformOp
,
OneDimIndexCal
><<<
config
.
grid
,
config
.
block
,
0
,
stream
>>>
(
x_data
,
config
.
output_data
,
reducer
,
TransformOp
(
config
.
reduce_num
)
,
init
,
config
.
reduce_num
,
config
.
left_num
,
config
.
reduce_last_dim
,
reduce_index_calculator
,
left_index_calculator
,
dim
);
x_data
,
config
.
output_data
,
reducer
,
transform
,
init
,
config
.
reduce_num
,
config
.
left_num
,
config
.
reduce_last_dim
,
reduce_index_calculator
,
left_index_calculator
,
dim
);
#endif
}
else
{
...
...
@@ -771,15 +771,15 @@ static void LaunchReduceKernel(const Tx* x_data, Ty* y_data,
#ifdef PADDLE_WITH_XPU2
ReduceAnyKernel
<
Tx
,
Ty
,
MPType
,
ReduceOp
,
TransformOp
,
IndexCalculator
><<<
8
,
128
,
stream
>>>
(
x_data
,
config
.
output_data
,
reducer
,
TransformOp
(
config
.
reduce_num
)
,
init
,
config
.
reduce_num
,
config
.
left_num
,
config
.
reduce_last_dim
,
reduce_index_calculator
,
left_index_calculator
,
dim
);
x_data
,
config
.
output_data
,
reducer
,
transform
,
init
,
config
.
reduce_num
,
config
.
left_num
,
config
.
reduce_last_dim
,
reduce_index_calculator
,
left_index_calculator
,
dim
);
#else
ReduceAnyKernel
<
Tx
,
Ty
,
MPType
,
ReduceOp
,
TransformOp
,
IndexCalculator
><<<
config
.
grid
,
config
.
block
,
0
,
stream
>>>
(
x_data
,
config
.
output_data
,
reducer
,
TransformOp
(
config
.
reduce_num
)
,
init
,
config
.
reduce_num
,
config
.
left_num
,
config
.
reduce_last_dim
,
reduce_index_calculator
,
left_index_calculator
,
dim
);
x_data
,
config
.
output_data
,
reducer
,
transform
,
init
,
config
.
reduce_num
,
config
.
left_num
,
config
.
reduce_last_dim
,
reduce_index_calculator
,
left_index_calculator
,
dim
);
#endif
}
...
...
@@ -802,23 +802,22 @@ static void LaunchReduceKernel(const Tx* x_data, Ty* y_data,
#ifdef PADDLE_WITH_XPU2
ReduceHigherDimKernel
<
Ty
,
Ty
,
MPType
,
ReduceOp
,
kps
::
IdentityFunctor
<
Ty
,
MPType
>><<<
8
,
128
,
stream
>>>
(
config
.
output_data
,
y_data
,
reducer
,
kps
::
IdentityFunctor
<
Ty
,
MPType
>
(
config
.
grid
.
y
),
init
,
config
.
grid
.
y
,
config
.
left_num
,
config
.
grid
.
y
,
dim
);
config
.
output_data
,
y_data
,
reducer
,
kps
::
IdentityFunctor
<
Ty
,
MPType
>
(),
init
,
config
.
grid
.
y
,
config
.
left_num
,
config
.
grid
.
y
,
dim
);
#else
ReduceHigherDimKernel
<
Ty
,
Ty
,
MPType
,
ReduceOp
,
kps
::
IdentityFunctor
<
Ty
,
MPType
>><<<
grid
,
block
,
0
,
stream
>>>
(
config
.
output_data
,
y_data
,
reducer
,
kps
::
IdentityFunctor
<
Ty
,
MPType
>
(
config
.
grid
.
y
),
init
,
config
.
grid
.
y
,
config
.
left_num
,
config
.
grid
.
y
,
dim
);
config
.
output_data
,
y_data
,
reducer
,
kps
::
IdentityFunctor
<
Ty
,
MPType
>
(),
init
,
config
.
grid
.
y
,
config
.
left_num
,
config
.
grid
.
y
,
dim
);
#endif
}
}
template
<
typename
Tx
,
typename
Ty
,
t
emplate
<
typename
,
typename
>
class
Reduce
Op
>
template
<
typename
Tx
,
typename
Ty
,
template
<
typename
>
class
ReduceOp
,
t
ypename
Transform
Op
>
void
TensorReduceFunctorImpl
(
const
framework
::
Tensor
&
x
,
framework
::
Tensor
*
y
,
const
TransformOp
&
transform
,
std
::
vector
<
int
>
origin_reduce_dims
,
gpuStream_t
stream
)
{
auto
x_dim
=
framework
::
vectorize
<
int
>
(
x
.
dims
());
...
...
@@ -853,10 +852,9 @@ void TensorReduceFunctorImpl(const framework::Tensor& x, framework::Tensor* y,
(
!
std
::
is_same
<
Tx
,
paddle
::
platform
::
float16
>::
value
);
if
(
use_cub_reduce
)
{
// launch CUB::Reduce
using
TransformOp
=
typename
ReduceOp
<
Tx
,
Ty
>::
Transformer
;
auto
reducer
=
ReduceOp
<
Tx
,
Ty
>
();
cub
::
TransformInputIterator
<
Ty
,
TransformOp
,
const
Tx
*>
trans_x
(
x_data
,
TransformOp
(
config
.
reduce_num
));
auto
reducer
=
ReduceOp
<
Ty
>
();
cub
::
TransformInputIterator
<
Ty
,
TransformOp
,
const
Tx
*>
trans_x
(
x_data
,
transform
);
size_t
temp_storage_bytes
=
0
;
cub
::
DeviceReduce
::
Reduce
(
nullptr
,
temp_storage_bytes
,
trans_x
,
y_data
,
config
.
reduce_num
,
reducer
,
reducer
.
initial
(),
...
...
@@ -873,7 +871,7 @@ void TensorReduceFunctorImpl(const framework::Tensor& x, framework::Tensor* y,
}
using
MPType
=
typename
details
::
MPTypeTrait
<
Ty
>::
Type
;
auto
reducer
=
ReduceOp
<
Tx
,
MPType
>
();
auto
reducer
=
ReduceOp
<
MPType
>
();
// launch ReduceHigherDimKernel
// when reduce_dim.size() == 1 and reduce_dim[0] != x_dim.size() - 1, this
// function will be used
...
...
@@ -882,7 +880,6 @@ void TensorReduceFunctorImpl(const framework::Tensor& x, framework::Tensor* y,
// 32
// else grid.z = 1, grid.y = ny / block_size, grid.x = nx /32
if
(
config
.
reduce_type
==
ReduceType
::
kReduceHigherDim
)
{
using
TransformOp
=
typename
ReduceOp
<
Tx
,
MPType
>::
Transformer
;
kps
::
DimConfig
dim
=
kps
::
DimConfig
(
config
.
grid
.
x
,
config
.
grid
.
y
,
config
.
grid
.
z
,
config
.
block
.
x
,
config
.
blocking_size
,
0
);
...
...
@@ -890,18 +887,16 @@ void TensorReduceFunctorImpl(const framework::Tensor& x, framework::Tensor* y,
config
.
reduce_num
%
config
.
blocking_size
,
0
);
#ifdef PADDLE_WITH_XPU2
ReduceHigherDimKernel
<
Tx
,
Ty
,
MPType
,
ReduceOp
<
Tx
,
MPType
>
,
ReduceHigherDimKernel
<
Tx
,
Ty
,
MPType
,
ReduceOp
<
MPType
>
,
TransformOp
><<<
8
,
128
,
stream
>>>
(
x_data
,
config
.
output_data
,
reducer
,
TransformOp
(
config
.
reduce_num
),
reducer
.
initial
(),
config
.
reduce_num
,
config
.
left_num
,
config
.
blocking_size
,
dim
);
x_data
,
config
.
output_data
,
reducer
,
transform
,
reducer
.
initial
(),
config
.
reduce_num
,
config
.
left_num
,
config
.
blocking_size
,
dim
);
#else
ReduceHigherDimKernel
<
Tx
,
Ty
,
MPType
,
ReduceOp
<
Tx
,
MPType
>
,
Tx
,
Ty
,
MPType
,
ReduceOp
<
MPType
>
,
TransformOp
><<<
config
.
grid
,
config
.
block
,
0
,
stream
>>>
(
x_data
,
config
.
output_data
,
reducer
,
TransformOp
(
config
.
reduce_num
),
reducer
.
initial
(),
config
.
reduce_num
,
config
.
left_num
,
config
.
blocking_size
,
dim
);
x_data
,
config
.
output_data
,
reducer
,
transform
,
reducer
.
initial
(),
config
.
reduce_num
,
config
.
left_num
,
config
.
blocking_size
,
dim
);
#endif
if
(
config
.
should_reduce_again
)
{
...
...
@@ -913,14 +908,14 @@ void TensorReduceFunctorImpl(const framework::Tensor& x, framework::Tensor* y,
#ifdef PADDLE_WITH_XPU2
ReduceHigherDimKernel
<
Ty
,
Ty
,
MPType
,
ReduceOp
<
Tx
,
MPType
>
,
Ty
,
Ty
,
MPType
,
ReduceOp
<
MPType
>
,
kps
::
IdentityFunctor
<
Ty
,
MPType
>><<<
8
,
128
,
stream
>>>
(
config
.
output_data
,
y_data
,
reducer
,
kps
::
IdentityFunctor
<
Ty
,
MPType
>
(
config
.
grid
.
y
),
reducer
.
initial
(),
config
.
grid
.
y
,
config
.
left_num
,
config
.
grid
.
y
,
dim2
);
#else
ReduceHigherDimKernel
<
Ty
,
Ty
,
MPType
,
ReduceOp
<
Tx
,
MPType
>
,
Ty
,
Ty
,
MPType
,
ReduceOp
<
MPType
>
,
kps
::
IdentityFunctor
<
Ty
,
MPType
>><<<
grid
,
block
,
0
,
stream
>>>
(
config
.
output_data
,
y_data
,
reducer
,
kps
::
IdentityFunctor
<
Ty
,
MPType
>
(
config
.
grid
.
y
),
reducer
.
initial
(),
...
...
@@ -933,23 +928,32 @@ void TensorReduceFunctorImpl(const framework::Tensor& x, framework::Tensor* y,
// when reduce_dim.size() == 1 and reduce_dim[0] == x_dim.size() - 1, or
// when reduce_dim.size() != 1 and reduce_dim.size() != x_dim.size(), this
// function will be used
LaunchReduceKernel
<
Tx
,
Ty
,
MPType
,
ReduceOp
<
Tx
,
MPType
>
>
(
x_data
,
y_data
,
reducer
,
reducer
.
initial
(),
stream
,
config
);
LaunchReduceKernel
<
Tx
,
Ty
,
MPType
,
ReduceOp
<
MPType
>
,
TransformOp
>
(
x_data
,
y_data
,
reducer
,
transform
,
reducer
.
initial
(),
stream
,
config
);
}
template
<
typename
Tx
,
template
<
typename
,
typename
>
class
ReduceOp
>
template
<
typename
Tx
,
template
<
typename
>
class
ReduceOp
,
template
<
typename
,
typename
>
class
TransformOp
>
struct
TensorReduceFunc
{
const
framework
::
Tensor
&
x
;
framework
::
Tensor
*
y
;
std
::
vector
<
int
>
origin_reduce_dims
;
gpuStream_t
stream
;
int
reduce_num
;
TensorReduceFunc
(
const
framework
::
Tensor
&
x
,
framework
::
Tensor
*
y
,
std
::
vector
<
int
>
origin_reduce_dims
,
gpuStream_t
stream
)
:
x
(
x
),
y
(
y
),
origin_reduce_dims
(
origin_reduce_dims
),
stream
(
stream
)
{}
std
::
vector
<
int
>
origin_reduce_dims
,
int
num_reduce
,
gpuStream_t
stream
)
:
x
(
x
),
y
(
y
),
origin_reduce_dims
(
origin_reduce_dims
),
reduce_num
(
num_reduce
),
stream
(
stream
)
{}
template
<
typename
Ty
>
void
apply
()
const
{
TensorReduceFunctorImpl
<
Tx
,
Ty
,
ReduceOp
>
(
x
,
y
,
origin_reduce_dims
,
stream
);
using
MPType
=
typename
details
::
MPTypeTrait
<
Ty
>::
Type
;
TensorReduceFunctorImpl
<
Tx
,
Ty
,
ReduceOp
,
TransformOp
<
Tx
,
MPType
>>
(
x
,
y
,
TransformOp
<
Tx
,
MPType
>
(
reduce_num
),
origin_reduce_dims
,
stream
);
}
};
...
...
paddle/fluid/operators/reduce_ops/reduce_op.h
浏览文件 @
524389ee
...
...
@@ -670,7 +670,8 @@ If reduce_all is true, just reduce along all dimensions and output a scalar.
};
#if defined(__HIPCC__) || defined(__NVCC__)
template
<
typename
T
,
template
<
typename
,
typename
>
class
ReduceOp
>
template
<
typename
T
,
template
<
typename
>
class
ReduceOp
,
template
<
typename
,
typename
>
class
TransformOp
>
class
ReduceCudaKernel
:
public
framework
::
OpKernel
<
T
>
{
public:
void
Compute
(
const
framework
::
ExecutionContext
&
context
)
const
override
{
...
...
@@ -682,15 +683,19 @@ class ReduceCudaKernel : public framework::OpKernel<T> {
std
::
vector
<
int
>
reduce_dims
=
GetReduceDim
(
dims
,
input
->
dims
().
size
(),
reduce_all
);
int
reduce_num
=
1
;
for
(
int
i
=
0
;
i
<
input
->
dims
().
size
();
i
++
)
{
reduce_num
*=
(
input
->
dims
())[
i
];
}
gpuStream_t
stream
=
context
.
cuda_device_context
().
stream
();
if
(
out_dtype
>=
0
)
{
framework
::
VisitDataTypeSmall
(
static_cast
<
framework
::
proto
::
VarType
::
Type
>
(
out_dtype
),
TensorReduceFunc
<
T
,
ReduceOp
>
(
*
input
,
output
,
reduce_dims
,
stream
));
TensorReduceFunc
<
T
,
ReduceOp
,
TransformOp
>
(
*
input
,
output
,
reduce_dims
,
reduce_num
,
stream
));
}
else
{
TensorReduceFunctorImpl
<
T
,
T
,
ReduceOp
>
(
*
input
,
output
,
reduce_dims
,
stream
);
TensorReduceFunctorImpl
<
T
,
T
,
ReduceOp
,
TransformOp
<
T
,
T
>>
(
*
input
,
output
,
TransformOp
<
T
,
T
>
(
reduce_num
),
reduce_dims
,
stream
);
}
}
};
...
...
paddle/fluid/operators/reduce_ops/reduce_prod_op.cu
浏览文件 @
524389ee
...
...
@@ -12,12 +12,12 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/operators/reduce_ops/reduce_functor_op.h"
#include "paddle/fluid/operators/reduce_ops/reduce_op.cu.h"
#include "paddle/fluid/operators/reduce_ops/reduce_prod_op.h"
REGISTER_OP_CUDA_KERNEL
(
reduce_prod
,
ops
::
ReduceCudaKernel
<
float
,
paddle
::
operators
::
CustomMul
>
,
ops
::
ReduceCudaKernel
<
int
,
paddle
::
operators
::
CustomMul
>
,
ops
::
ReduceCudaKernel
<
double
,
paddle
::
operators
::
CustomMul
>
,
ops
::
ReduceCudaKernel
<
int64_t
,
paddle
::
operators
::
CustomMul
>
);
reduce_prod
,
ops
::
ReduceCudaKernel
<
float
,
kps
::
MulFunctor
,
kps
::
IdentityFunctor
>
,
ops
::
ReduceCudaKernel
<
int
,
kps
::
MulFunctor
,
kps
::
IdentityFunctor
>
,
ops
::
ReduceCudaKernel
<
double
,
kps
::
MulFunctor
,
kps
::
IdentityFunctor
>
,
ops
::
ReduceCudaKernel
<
int64_t
,
kps
::
MulFunctor
,
kps
::
IdentityFunctor
>
);
paddle/fluid/operators/reduce_ops/reduce_sum_op.cu
浏览文件 @
524389ee
...
...
@@ -11,18 +11,18 @@
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/operators/reduce_ops/reduce_functor_op.h"
#include "paddle/fluid/operators/reduce_ops/reduce_op.h"
#include "paddle/fluid/operators/reduce_ops/reduce_sum_op.h"
REGISTER_OP_CUDA_KERNEL
(
reduce_sum
,
ops
::
ReduceCudaKernel
<
bool
,
paddle
::
operators
::
CustomSum
>
,
ops
::
ReduceCudaKernel
<
float
,
paddle
::
operators
::
CustomSum
>
,
ops
::
ReduceCudaKernel
<
double
,
paddle
::
operators
::
CustomSum
>
,
ops
::
ReduceCudaKernel
<
paddle
::
platform
::
float16
,
paddle
::
operators
::
CustomSum
>
,
ops
::
ReduceCudaKernel
<
int
,
paddle
::
operators
::
CustomSum
>
,
ops
::
ReduceCudaKernel
<
int64_t
,
paddle
::
operators
::
CustomSum
>
,
ops
::
ReduceCudaKernel
<
paddle
::
platform
::
complex
<
float
>
,
paddle
::
operators
::
CustomSum
>
,
ops
::
ReduceCudaKernel
<
paddle
::
platform
::
complex
<
double
>
,
paddle
::
operators
::
CustomSum
>
);
reduce_sum
,
ops
::
ReduceCudaKernel
<
bool
,
kps
::
AddFunctor
,
kps
::
IdentityFunctor
>
,
ops
::
ReduceCudaKernel
<
float
,
kps
::
AddFunctor
,
kps
::
IdentityFunctor
>
,
ops
::
ReduceCudaKernel
<
double
,
kps
::
AddFunctor
,
kps
::
IdentityFunctor
>
,
ops
::
ReduceCudaKernel
<
paddle
::
platform
::
float16
,
kps
::
AddFunctor
,
kps
::
IdentityFunctor
>
,
ops
::
ReduceCudaKernel
<
int
,
kps
::
AddFunctor
,
kps
::
IdentityFunctor
>
,
ops
::
ReduceCudaKernel
<
int64_t
,
kps
::
AddFunctor
,
kps
::
IdentityFunctor
>
,
ops
::
ReduceCudaKernel
<
paddle
::
platform
::
complex
<
float
>
,
kps
::
AddFunctor
,
kps
::
IdentityFunctor
>
,
ops
::
ReduceCudaKernel
<
paddle
::
platform
::
complex
<
double
>
,
kps
::
AddFunctor
,
kps
::
IdentityFunctor
>
);
paddle/fluid/operators/trace_op.cu
浏览文件 @
524389ee
...
...
@@ -15,21 +15,12 @@
#include <thrust/device_vector.h>
#include <thrust/host_vector.h>
#include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/operators/reduce_ops/
cub_reduce
.h"
#include "paddle/fluid/operators/reduce_ops/
reduce_op.cu
.h"
#include "paddle/fluid/operators/trace_op.h"
namespace
paddle
{
namespace
operators
{
struct
IdentityFunctor
{
HOSTDEVICE
explicit
inline
IdentityFunctor
()
{}
template
<
typename
U
>
HOSTDEVICE
inline
U
operator
()(
const
U
&
x
)
const
{
return
x
;
}
};
template
<
typename
DeviceContext
,
typename
T
>
class
TraceCUDAKernel
:
public
framework
::
OpKernel
<
T
>
{
public:
...
...
@@ -48,9 +39,8 @@ class TraceCUDAKernel : public framework::OpKernel<T> {
auto
stream
=
context
.
cuda_device_context
().
stream
();
std
::
vector
<
int
>
reduce_dims
;
reduce_dims
.
push_back
(
out
->
dims
().
size
());
TensorReduce
<
T
,
T
,
cub
::
Sum
,
IdentityFunctor
>
(
diag
,
out
,
reduce_dims
,
static_cast
<
T
>
(
0
),
cub
::
Sum
(),
IdentityFunctor
(),
stream
);
TensorReduceFunctorImpl
<
T
,
T
,
kps
::
AddFunctor
,
kps
::
IdentityFunctor
<
T
>>
(
diag
,
out
,
kps
::
IdentityFunctor
<
T
>
(),
reduce_dims
,
stream
);
}
else
{
math
::
SetConstant
<
DeviceContext
,
T
>
functor
;
functor
(
context
.
device_context
<
DeviceContext
>
(),
out
,
static_cast
<
T
>
(
0
));
...
...
paddle/fluid/operators/triangular_solve_op.cu
浏览文件 @
524389ee
...
...
@@ -12,7 +12,6 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/reduce_ops/reduce_functor_op.h"
#include "paddle/fluid/operators/reduce_ops/reduce_op.h"
#include "paddle/fluid/operators/triangular_solve_op.h"
...
...
@@ -44,7 +43,8 @@ struct MatrixReduceSumFunctor<platform::CUDADeviceContext, T> {
}
}
gpuStream_t
stream
=
ctx
.
cuda_device_context
().
stream
();
TensorReduceFunctorImpl
<
T
,
T
,
CustomSum
>
(
in
,
out
,
out_reduce_dims
,
stream
);
TensorReduceFunctorImpl
<
T
,
T
,
kps
::
AddFunctor
,
kps
::
IdentityFunctor
<
T
>>
(
in
,
out
,
kps
::
IdentityFunctor
<
T
>
(),
out_reduce_dims
,
stream
);
}
};
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录