Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
82b33be3
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
82b33be3
编写于
9月 08, 2021
作者:
N
niuliling123
提交者:
GitHub
9月 08, 2021
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Modify the reduce op according to the kernel primitive api (#35282)
上级
7aa4d879
变更
7
展开全部
隐藏空白更改
内联
并排
Showing
7 changed file
with
293 addition
and
361 deletion
+293
-361
paddle/fluid/operators/fused/attn_bias_add.cu.h
paddle/fluid/operators/fused/attn_bias_add.cu.h
+2
-2
paddle/fluid/operators/kernel_primitives/compute_primitives.h
...le/fluid/operators/kernel_primitives/compute_primitives.h
+5
-3
paddle/fluid/operators/kernel_primitives/datamover_primitives.h
.../fluid/operators/kernel_primitives/datamover_primitives.h
+61
-60
paddle/fluid/operators/kernel_primitives/helper_primitives.h
paddle/fluid/operators/kernel_primitives/helper_primitives.h
+1
-1
paddle/fluid/operators/margin_cross_entropy_op.cu
paddle/fluid/operators/margin_cross_entropy_op.cu
+4
-34
paddle/fluid/operators/reduce_ops/reduce_functor_op.h
paddle/fluid/operators/reduce_ops/reduce_functor_op.h
+9
-7
paddle/fluid/operators/reduce_ops/reduce_op.cu.h
paddle/fluid/operators/reduce_ops/reduce_op.cu.h
+211
-254
未找到文件。
paddle/fluid/operators/fused/attn_bias_add.cu.h
浏览文件 @
82b33be3
...
...
@@ -202,9 +202,9 @@ void SetConfigForColumnReduce(const int max_threads, const int reduce_num,
int
num_block
=
(
max_threads
/
left_num
);
if
(
num_block
>
1
&&
reduce_num
>=
REDUCE_SPLIT_BOUNDARY
)
{
*
blocking_size
=
detail
::
GetLastPow2
(
reduce_num
/
num_block
);
*
blocking_size
=
detail
s
::
GetLastPow2
(
reduce_num
/
num_block
);
if
(
*
blocking_size
<=
1
)
{
*
blocking_size
=
detail
::
GetLastPow2
(
sqrt
(
reduce_num
));
*
blocking_size
=
detail
s
::
GetLastPow2
(
sqrt
(
reduce_num
));
}
else
if
(
*
blocking_size
*
2
<
reduce_num
)
{
*
blocking_size
*=
2
;
}
...
...
paddle/fluid/operators/kernel_primitives/compute_primitives.h
浏览文件 @
82b33be3
...
...
@@ -31,13 +31,15 @@ namespace kernel_primitives {
namespace
details
{
#ifdef __HIPCC__
constexpr
int
kMaxThread
=
256
;
constexpr
int
k
Reduce
MaxThread
=
256
;
constexpr
int
kWarpSize
=
64
;
#else
constexpr
int
kMaxThread
=
128
;
constexpr
int
k
Reduce
MaxThread
=
128
;
constexpr
int
kWarpSize
=
32
;
#endif
// kGlobalMode: block reduce, each block gets an output;
// kLocalMode: thread reduce, each thread gets an output;
enum
ReduceMode
{
kGlobalMode
,
kLocalMode
};
template
<
typename
T
>
...
...
@@ -118,7 +120,7 @@ __device__ __forceinline__ T BlockXReduce(T val, ReduceOp reducer) {
*/
template
<
typename
T
,
typename
ReduceOp
>
__device__
__forceinline__
T
BlockYReduce
(
T
val
,
ReduceOp
reducer
)
{
__shared__
T
shared_memory
[
details
::
kMaxThread
];
__shared__
T
shared_memory
[
details
::
k
Reduce
MaxThread
];
shared_memory
[
SharedMemoryIndex
(
0
)]
=
val
;
for
(
int
stride
=
blockDim
.
y
/
2
;
stride
>
0
;
stride
>>=
1
)
{
__syncthreads
();
...
...
paddle/fluid/operators/kernel_primitives/datamover_primitives.h
浏览文件 @
82b33be3
...
...
@@ -124,36 +124,36 @@ struct BroadcastConfig {
template
<
typename
Tx
,
typename
Ty
,
int
NX
,
int
NY
,
int
BlockSize
>
__device__
__forceinline__
void
ReadData
(
Ty
*
dst
,
const
Tx
*
__restrict__
src
,
int
stride_nx
,
int
stride_ny
)
{
int
thread_offset
=
threadIdx
.
x
*
NX
;
if
(
NY
==
1
&&
NX
==
1
)
{
dst
[
0
]
=
static_cast
<
Ty
>
(
src
[
thread
Idx
.
x
]);
dst
[
0
]
=
static_cast
<
Ty
>
(
src
[
thread
_offset
]);
}
else
if
(
NX
==
1
)
{
int
dx
=
threadIdx
.
x
;
#pragma unroll
for
(
int
idy
=
0
;
idy
<
NY
;
++
idy
)
{
dst
[
idy
]
=
static_cast
<
Ty
>
(
src
[
dx
+
idy
*
stride_ny
]);
dst
[
idy
]
=
static_cast
<
Ty
>
(
src
[
thread_offset
+
idy
*
stride_ny
]);
}
}
else
if
(
NY
==
1
)
{
#pragma unroll
for
(
int
idx
=
0
;
idx
<
NX
;
++
idx
)
{
dst
[
idx
]
=
static_cast
<
Ty
>
(
src
[
idx
*
stride_nx
]);
dst
[
idx
]
=
static_cast
<
Ty
>
(
src
[
thread_offset
+
idx
*
stride_nx
]);
}
}
else
{
int
dx
=
threadIdx
.
x
*
NX
;
#pragma unroll
for
(
int
idx
=
0
;
idx
<
NX
;
++
idx
)
{
#pragma unroll
for
(
int
idy
=
0
;
idy
<
NY
;
++
idy
)
{
dst
[
idy
*
NX
+
idx
]
=
s
tatic_cast
<
Ty
>
(
src
[
idx
*
stride_nx
+
d
x
+
idy
*
stride_ny
]);
dst
[
idy
*
NX
+
idx
]
=
static_cast
<
Ty
>
(
s
rc
[
thread_offset
+
idx
*
stride_n
x
+
idy
*
stride_ny
]);
}
}
}
}
/**
* @brief load data from src to dst
, src can be 1D data or 2D data. When
*
boundary judgment is required, you need to set a to true, and a is false by
* default.
* @brief load data from src to dst
with stride, src can be 1D data or 2D data.
*
When boundary judgment is required, you need to set a to true, and a is false
*
by
default.
* @typename:
* Tx: data type of src
* Ty: data type of dstt
...
...
@@ -172,17 +172,17 @@ template <typename Tx, typename Ty, int NX, int NY, int BlockSize,
__device__
__forceinline__
void
ReadData
(
Ty
*
dst
,
const
Tx
*
__restrict__
src
,
int
size_nx
,
int
size_ny
,
int
stride_nx
,
int
stride_ny
)
{
int
dx
=
threadIdx
.
x
*
NX
;
int
size
=
size_nx
-
dx
;
int
thread_offset
=
threadIdx
.
x
*
NX
;
int
left_size_nx
=
size_nx
-
thread_offset
;
// Each branch is added for better performance
if
(
NX
==
1
&&
NY
==
1
)
{
// for NX == 1 and NY == 1
if
(
IsBoundary
)
{
if
(
dx
<
size_nx
)
{
dst
[
0
]
=
static_cast
<
Ty
>
(
src
[
dx
]);
if
(
left_size_nx
>
0
)
{
dst
[
0
]
=
static_cast
<
Ty
>
(
src
[
thread_offset
]);
}
}
else
{
dst
[
0
]
=
static_cast
<
Ty
>
(
src
[
dx
]);
dst
[
0
]
=
static_cast
<
Ty
>
(
src
[
thread_offset
]);
}
}
else
if
(
NX
==
1
)
{
// for NX == 1 and NY != 1
#pragma unroll
...
...
@@ -192,23 +192,23 @@ __device__ __forceinline__ void ReadData(Ty* dst, const Tx* __restrict__ src,
break
;
}
}
dst
[
idy
]
=
static_cast
<
Ty
>
(
src
[
dx
+
idy
*
stride_ny
]);
dst
[
idy
]
=
static_cast
<
Ty
>
(
src
[
thread_offset
+
idy
*
stride_ny
]);
}
}
else
if
(
NY
==
1
)
{
// for NY == 1 and NX != 1
#pragma unroll
for
(
int
idx
=
0
;
idx
<
NX
;
++
idx
)
{
if
(
IsBoundary
)
{
if
(
idx
>=
size
)
{
if
(
idx
>=
left_size_nx
)
{
break
;
}
}
dst
[
idx
]
=
static_cast
<
Ty
>
(
src
[
idx
*
stride_nx
+
d
x
]);
dst
[
idx
]
=
static_cast
<
Ty
>
(
src
[
thread_offset
+
idx
*
stride_n
x
]);
}
}
else
{
// for NX != 1 and NY != 1
#pragma unroll
for
(
int
idx
=
0
;
idx
<
NX
;
++
idx
)
{
if
(
IsBoundary
)
{
if
(
idx
>=
size
)
{
if
(
idx
>=
left_size_nx
)
{
break
;
}
}
...
...
@@ -219,8 +219,8 @@ __device__ __forceinline__ void ReadData(Ty* dst, const Tx* __restrict__ src,
break
;
}
}
dst
[
idy
*
NX
+
idx
]
=
s
tatic_cast
<
Ty
>
(
src
[
idx
*
stride_nx
+
d
x
+
idy
*
stride_ny
]);
dst
[
idy
*
NX
+
idx
]
=
static_cast
<
Ty
>
(
s
rc
[
thread_offset
+
idx
*
stride_n
x
+
idy
*
stride_ny
]);
}
}
}
...
...
@@ -251,17 +251,17 @@ template <typename T, int NX, int NY, int BlockSize, bool IsBoundary = false>
__device__
__forceinline__
void
ReadData
(
T
*
dst
,
const
T
*
__restrict__
src
,
int
num
)
{
if
(
IsBoundary
)
{
// blockDim.x * NX > num
int
dx
=
threadIdx
.
x
*
NX
;
int
thread_offset
=
threadIdx
.
x
*
NX
;
#pragma unroll
for
(
int
idx
=
0
;
idx
<
NX
;
++
idx
)
{
if
(
idx
+
dx
<
num
)
{
dst
[
idx
]
=
src
[
idx
+
dx
];
if
(
idx
+
thread_offset
<
num
)
{
dst
[
idx
]
=
src
[
thread_offset
+
i
dx
];
}
}
}
else
{
// blockDim,x * NX < num
const
int
kVectorSize
=
(
NX
%
4
==
0
)
?
4
:
(
NX
%
2
==
0
)
?
2
:
1
;
const
int
kVectorsPerThread
=
NX
/
kVectorSize
;
int
t
id
=
threadIdx
.
x
*
kVectorsPerThread
;
int
t
hread_offset
=
threadIdx
.
x
*
kVectorsPerThread
;
using
VecType
=
details
::
VectorType
<
T
,
kVectorSize
>
;
const
VecType
*
vec_input
=
reinterpret_cast
<
const
VecType
*>
(
src
);
...
...
@@ -269,7 +269,7 @@ __device__ __forceinline__ void ReadData(T* dst, const T* __restrict__ src,
#pragma unroll
for
(
int
i
=
0
;
i
<
kVectorsPerThread
;
++
i
)
{
vec_temp
[
i
]
=
vec_input
[
i
+
tid
];
vec_temp
[
i
]
=
vec_input
[
thread_offset
+
i
];
#pragma unroll
for
(
int
idx
=
0
;
idx
<
NX
;
++
idx
)
{
dst
[
idx
]
=
*
(
reinterpret_cast
<
T
*>
(
vec_temp
)
+
idx
);
...
...
@@ -289,39 +289,39 @@ __device__ __forceinline__ void ReadData(T* dst, const T* __restrict__ src,
* is 2
* IsBoundary: whether to make boundary judgment
* @param:
*
fix
: data offset of this block, blockDim.x * blockIdx.x * NX;
*
block_offset
: data offset of this block, blockDim.x * blockIdx.x * NX;
* config: get the global index in src, attention config was declared in host;
*
num: the num of o
ut
*
total_num_output: total num of outp
ut
* stride_nx: the stride of cols
* stride_ny: the stride of rows
*/
template
<
typename
T
,
int
NX
,
int
NY
,
int
BlockSize
,
int
ShapeSize
,
bool
IsBoundary
=
false
>
__device__
__forceinline__
void
ReadDataBc
(
T
*
dst
,
const
T
*
__restrict__
src
,
uint32_t
fix
,
details
::
BroadcastConfig
<
ShapeSize
>
config
,
int
num
,
int
stride_nx
,
int
stride_ny
)
{
uint32_t
base_offset
=
fix
+
threadIdx
.
x
*
NX
;
uint32_t
offset
=
0
;
T
*
dst
,
const
T
*
__restrict__
src
,
uint32_t
block_offset
,
details
::
BroadcastConfig
<
ShapeSize
>
config
,
int
total_num_output
,
int
stride_n
x
,
int
stride_n
y
)
{
uint32_t
thread_offset
=
block_offset
+
threadIdx
.
x
*
NX
;
uint32_t
index_src
=
0
;
#pragma unroll
for
(
int
ny
=
0
;
ny
<
NY
;
++
ny
)
{
#pragma unroll
for
(
uint32_t
nx
=
0
;
nx
<
NX
;
++
nx
)
{
uint32_t
idx
=
base_offset
+
ny
*
stride_ny
+
nx
*
stride_nx
;
uint32_t
index_output
=
thread_offset
+
ny
*
stride_ny
+
nx
*
stride_nx
;
index_src
=
0
;
if
(
IsBoundary
)
{
if
(
i
dx
>=
num
)
{
if
(
i
ndex_output
>=
total_num_output
)
{
break
;
}
}
offset
=
0
;
#pragma unroll
for
(
int
i
=
0
;
i
<
ShapeSize
;
++
i
)
{
auto
fast_divmoder
=
config
.
divmoders
[
i
].
Divmod
(
i
dx
);
i
dx
=
fast_divmoder
.
val
[
0
];
offset
+=
fast_divmoder
.
val
[
1
]
*
config
.
strides
[
i
];
auto
fast_divmoder
=
config
.
divmoders
[
i
].
Divmod
(
i
ndex_output
);
i
ndex_output
=
fast_divmoder
.
val
[
0
];
index_src
+=
fast_divmoder
.
val
[
1
]
*
config
.
strides
[
i
];
}
dst
[
nx
+
ny
*
NX
]
=
src
[
offset
];
dst
[
nx
+
ny
*
NX
]
=
src
[
index_src
];
}
}
}
...
...
@@ -338,7 +338,7 @@ __device__ __forceinline__ void ReadDataBc(
* IndexCal: get the global index in src, attention config was declared in host;
* IsBoundary: whether to make boundary judgment
* @param:
*
fix
: data offset of this block, blockDim.x * blockIdx.x * NX;
*
block_offset
: data offset of this block, blockDim.x * blockIdx.x * NX;
* index_cal: get the global index in src, attention config was declared in
* host;
* size_nx: number of columns to be processed by the current block
...
...
@@ -350,27 +350,27 @@ __device__ __forceinline__ void ReadDataBc(
template
<
typename
T
,
int
NX
,
int
NY
,
int
BlockSize
,
int
ShapeSize
,
typename
IndexCal
,
bool
IsBoundary
=
false
>
__device__
__forceinline__
void
ReadDataReduce
(
T
*
dst
,
const
T
*
__restrict__
src
,
int
fix
,
const
IndexCal
&
index_cal
,
int
size_nx
,
int
size_ny
,
int
stride_nx
,
int
stride_ny
,
bool
reduce_last_dim
)
{
int
base_offset
=
fix
;
T
*
dst
,
const
T
*
__restrict__
src
,
int
block_offset
,
const
IndexCal
&
index_cal
,
int
size_nx
,
int
size_ny
,
int
stride_nx
,
int
stride_ny
,
bool
reduce_last_dim
)
{
int
thread_offset
=
0
;
if
(
reduce_last_dim
)
{
base_offset
+=
threadIdx
.
x
;
thread_offset
=
block_offset
+
threadIdx
.
x
;
}
else
{
base_offset
+=
threadIdx
.
y
;
thread_offset
=
block_offset
+
threadIdx
.
y
;
}
if
(
NX
==
1
)
{
#pragma unroll
for
(
int
ny
=
0
;
ny
<
NY
;
++
ny
)
{
if
(
IsBoundary
)
{
if
(
base
_offset
>=
size_ny
)
{
if
(
thread
_offset
>=
size_ny
)
{
break
;
}
}
uint32_t
offset
=
index_cal
(
base
_offset
);
dst
[
ny
]
=
src
[
offset
];
base
_offset
+=
stride_ny
;
uint32_t
index_src
=
index_cal
(
thread
_offset
);
dst
[
ny
]
=
src
[
index_src
];
thread
_offset
+=
stride_ny
;
}
}
else
{
#pragma unroll
...
...
@@ -387,15 +387,16 @@ __device__ __forceinline__ void ReadDataReduce(
break
;
}
}
uint32_t
offset
=
index_cal
(
base
_offset
);
dst
[
nx
+
ny
*
NX
]
=
src
[
offset
];
base
_offset
+=
stride_ny
;
uint32_t
index_src
=
index_cal
(
thread
_offset
);
dst
[
nx
+
ny
*
NX
]
=
src
[
index_src
];
thread
_offset
+=
stride_ny
;
}
thread_offset
+=
stride_nx
;
}
}
}
/**
@brief: WriteData
/**
* @brief store data from src to dst, src can be 1D data, you should set NY = 1.
* When boundary judgment is required, you need to set a to true, and a is false
* by default.
...
...
@@ -412,11 +413,11 @@ template <typename T, int NX, int NY, int BlockSize, bool IsBoundary = false>
__device__
__forceinline__
void
WriteData
(
T
*
dst
,
T
*
__restrict__
src
,
int
num
)
{
if
(
IsBoundary
)
{
int
dx
=
threadIdx
.
x
*
NX
;
int
thread_offset
=
threadIdx
.
x
*
NX
;
#pragma unroll
for
(
int
idx
=
0
;
idx
<
NX
;
++
idx
)
{
if
((
idx
+
dx
)
<
num
)
{
dst
[
idx
+
dx
]
=
src
[
idx
];
if
((
thread_offset
+
i
dx
)
<
num
)
{
dst
[
thread_offset
+
i
dx
]
=
src
[
idx
];
}
}
}
else
{
...
...
@@ -424,14 +425,14 @@ __device__ __forceinline__ void WriteData(T* dst, T* __restrict__ src,
const
int
kVectorSize
=
(
NX
%
4
==
0
)
?
4
:
(
NX
%
2
==
0
)
?
2
:
1
;
const
int
kVectorsPerThread
=
NX
/
kVectorSize
;
int
dx
=
threadIdx
.
x
*
kVectorsPerThread
;
int
thread_offset
=
threadIdx
.
x
*
kVectorsPerThread
;
using
VecType
=
details
::
VectorType
<
T
,
kVectorSize
>
;
VecType
*
vec_dst
=
reinterpret_cast
<
VecType
*>
(
dst
);
VecType
vec_temp
[
kVectorsPerThread
];
#pragma unroll
for
(
int
idx
=
0
;
idx
<
kVectorsPerThread
;
++
idx
)
{
vec_temp
[
idx
]
=
*
(
reinterpret_cast
<
VecType
*>
(
src
)
+
idx
);
vec_dst
[
dx
+
idx
]
=
vec_temp
[
idx
];
vec_dst
[
thread_offset
+
idx
]
=
vec_temp
[
idx
];
}
}
}
...
...
paddle/fluid/operators/kernel_primitives/helper_primitives.h
浏览文件 @
82b33be3
...
...
@@ -32,7 +32,6 @@ static __device__ __forceinline__ platform::float16 LogFunctor(
static
__device__
__forceinline__
float
LogFunctor
(
float
x
)
{
return
logf
(
x
);
}
static
__device__
__forceinline__
double
LogFunctor
(
double
x
)
{
return
log
(
x
);
}
}
// namespace details
/*************************** Compute Functor****************************/
// for margin_cross_entropy
template
<
typename
Tx
,
typename
Ty
=
Tx
>
...
...
@@ -75,6 +74,7 @@ struct DivideFunctor {
T
n_inv
;
};
}
// namespace details
}
// namespace kernel_primitives
}
// namespace operators
}
// namespace paddle
paddle/fluid/operators/margin_cross_entropy_op.cu
浏览文件 @
82b33be3
...
...
@@ -128,39 +128,9 @@ __global__ void AddMarginToPositiveLogitsKernel(
}
}
static
__device__
__forceinline__
platform
::
float16
exp_on_device
(
platform
::
float16
x
)
{
return
::
Eigen
::
numext
::
exp
(
x
);
}
static
__device__
__forceinline__
float
exp_on_device
(
float
x
)
{
return
expf
(
x
);
}
static
__device__
__forceinline__
double
exp_on_device
(
double
x
)
{
return
exp
(
x
);
}
static
__device__
__forceinline__
platform
::
float16
log_on_device
(
platform
::
float16
x
)
{
return
::
Eigen
::
numext
::
log
(
x
);
}
static
__device__
__forceinline__
float
log_on_device
(
float
x
)
{
return
logf
(
x
);
}
static
__device__
__forceinline__
double
log_on_device
(
double
x
)
{
return
log
(
x
);
}
template
<
typename
Tx
,
typename
Ty
=
Tx
>
struct
ExpLogitTransformer
{
HOSTDEVICE
explicit
inline
ExpLogitTransformer
(
int
n
)
{}
HOSTDEVICE
inline
Ty
operator
()(
const
Tx
&
x
)
const
{
return
static_cast
<
Ty
>
(
exp_on_device
(
x
));
}
};
template
<
typename
Tx
,
typename
Ty
=
Tx
>
struct
ExpAndSum
{
using
Transformer
=
ExpLogitTransformer
<
Tx
>
;
using
Transformer
=
kpds
::
ExpLogitTransformer
<
Tx
>
;
inline
Ty
initial
()
{
return
static_cast
<
Ty
>
(
0.0
f
);
}
...
...
@@ -189,7 +159,7 @@ __global__ void LogitsMinusLogSumKernel(T* logits, const T* logits_sum_per_row,
const
int64_t
N
,
const
int64_t
D
)
{
CUDA_KERNEL_LOOP
(
i
,
N
*
D
)
{
auto
row
=
i
/
D
;
logits
[
i
]
-=
log_on_device
(
logits_sum_per_row
[
row
]);
logits
[
i
]
-=
kpds
::
LogFunctor
(
logits_sum_per_row
[
row
]);
}
}
...
...
@@ -204,9 +174,9 @@ __global__ void HardLabelSoftmaxWithCrossEntropyKernel(
if
((
col
+
start_index
)
==
labels
[
row
])
{
auto
softmax
=
log_softmax
[
i
];
loss
[
row
]
=
-
softmax
;
log_softmax
[
i
]
=
exp_on_device
(
softmax
);
log_softmax
[
i
]
=
kpds
::
ExpFunctor
(
softmax
);
}
else
{
log_softmax
[
i
]
=
exp_on_device
(
log_softmax
[
i
]);
log_softmax
[
i
]
=
kpds
::
ExpFunctor
(
log_softmax
[
i
]);
}
}
}
...
...
paddle/fluid/operators/reduce_ops/reduce_functor_op.h
浏览文件 @
82b33be3
...
...
@@ -24,9 +24,11 @@ limitations under the License. */
namespace
paddle
{
namespace
operators
{
namespace
kpds
=
paddle
::
operators
::
kernel_primitives
::
details
;
template
<
typename
Tx
,
typename
Ty
=
Tx
>
struct
CustomMin
{
using
Transformer
=
detail
::
IdentityFunctor
<
Tx
>
;
using
Transformer
=
kpds
::
IdentityFunctor
<
Tx
>
;
inline
Ty
initial
()
{
return
static_cast
<
Ty
>
(
std
::
numeric_limits
<
Ty
>::
max
());
...
...
@@ -39,7 +41,7 @@ struct CustomMin {
template
<
typename
Tx
,
typename
Ty
=
Tx
>
struct
CustomMax
{
using
Transformer
=
detail
::
IdentityFunctor
<
Tx
>
;
using
Transformer
=
kpds
::
IdentityFunctor
<
Tx
>
;
inline
Ty
initial
()
{
return
static_cast
<
Ty
>
(
std
::
numeric_limits
<
Ty
>::
lowest
());
...
...
@@ -53,7 +55,7 @@ struct CustomMax {
// for cub::Reduce
template
<
typename
Tx
,
typename
Ty
=
Tx
>
struct
CustomSum
{
using
Transformer
=
detail
::
IdentityFunctor
<
Tx
,
Ty
>
;
using
Transformer
=
kpds
::
IdentityFunctor
<
Tx
,
Ty
>
;
inline
Ty
initial
()
{
return
static_cast
<
Ty
>
(
0.0
f
);
}
...
...
@@ -64,7 +66,7 @@ struct CustomSum {
template
<
typename
Tx
,
typename
Ty
=
Tx
>
struct
CustomMean
{
using
Transformer
=
detail
::
DivideFunctor
<
Tx
>
;
using
Transformer
=
kpds
::
DivideFunctor
<
Tx
>
;
inline
Ty
initial
()
{
return
static_cast
<
Ty
>
(
0.0
f
);
}
...
...
@@ -75,7 +77,7 @@ struct CustomMean {
template
<
typename
Tx
,
typename
Ty
=
Tx
>
struct
CustomMul
{
using
Transformer
=
detail
::
IdentityFunctor
<
Tx
>
;
using
Transformer
=
kpds
::
IdentityFunctor
<
Tx
>
;
inline
Ty
initial
()
{
return
static_cast
<
Ty
>
(
1.0
f
);
}
...
...
@@ -86,7 +88,7 @@ struct CustomMul {
template
<
typename
Tx
,
typename
Ty
=
Tx
>
struct
CustomLogicalOr
{
using
Transformer
=
detail
::
IdentityFunctor
<
Tx
>
;
using
Transformer
=
kpds
::
IdentityFunctor
<
Tx
>
;
inline
Ty
initial
()
{
return
static_cast
<
Ty
>
(
false
);
}
...
...
@@ -97,7 +99,7 @@ struct CustomLogicalOr {
template
<
typename
Tx
,
typename
Ty
=
Tx
>
struct
CustomLogicalAnd
{
using
Transformer
=
detail
::
IdentityFunctor
<
Tx
>
;
using
Transformer
=
kpds
::
IdentityFunctor
<
Tx
>
;
inline
Ty
initial
()
{
return
static_cast
<
Ty
>
(
true
);
}
...
...
paddle/fluid/operators/reduce_ops/reduce_op.cu.h
浏览文件 @
82b33be3
此差异已折叠。
点击以展开。
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录