Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
480b284c
P
Paddle
项目概览
PaddlePaddle
/
Paddle
大约 1 年 前同步成功
通知
2298
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
480b284c
编写于
6月 22, 2021
作者:
N
niuliling123
提交者:
GitHub
6月 22, 2021
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
modified reduce_max, reduce_min, reduce_prod to higher_performance implementation. (#32974)
上级
20eafd79
变更
5
展开全部
隐藏空白更改
内联
并排
Showing
5 changed file
with
349 addition
and
177 deletion
+349
-177
paddle/fluid/operators/reduce_ops/reduce_functor_op.h
paddle/fluid/operators/reduce_ops/reduce_functor_op.h
+68
-16
paddle/fluid/operators/reduce_ops/reduce_max_op.cu
paddle/fluid/operators/reduce_ops/reduce_max_op.cu
+9
-11
paddle/fluid/operators/reduce_ops/reduce_min_op.cu
paddle/fluid/operators/reduce_ops/reduce_min_op.cu
+9
-11
paddle/fluid/operators/reduce_ops/reduce_op.cu.h
paddle/fluid/operators/reduce_ops/reduce_op.cu.h
+251
-123
paddle/fluid/operators/reduce_ops/reduce_prod_op.cu
paddle/fluid/operators/reduce_ops/reduce_prod_op.cu
+12
-16
未找到文件。
paddle/fluid/operators/reduce_ops/reduce_functor_op.h
浏览文件 @
480b284c
...
...
@@ -13,46 +13,98 @@ See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <string>
#include <vector>
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/operators/amp/fp16_type_traits.h"
#include "paddle/fluid/platform/device_context.h"
#include <cmath>
#include <limits>
#include "paddle/fluid/operators/reduce_ops/reduce_op.cu.h"
#include "paddle/fluid/platform/hostdevice.h"
#include "paddle/fluid/platform/macros.h"
#ifdef __HIPCC__
#include <hip/hip_runtime.h>
#endif
namespace
paddle
{
namespace
operators
{
template
<
typename
T
>
template
<
typename
T
x
,
typename
Ty
=
Tx
>
struct
CustomMin
{
__device__
__forceinline__
T
operator
()(
const
T
&
a
,
const
T
&
b
)
const
{
using
Transformer
=
detail
::
IdentityFunctor
<
Tx
>
;
inline
Ty
initial
()
{
return
static_cast
<
Ty
>
(
std
::
numeric_limits
<
Ty
>::
max
());
}
__device__
__forceinline__
Ty
operator
()(
const
Ty
&
a
,
const
Ty
&
b
)
const
{
return
(
b
<
a
)
?
b
:
a
;
}
};
template
<
typename
T
>
template
<
typename
T
x
,
typename
Ty
=
Tx
>
struct
CustomMax
{
__device__
__forceinline__
T
operator
()(
const
T
&
a
,
const
T
&
b
)
const
{
using
Transformer
=
detail
::
IdentityFunctor
<
Tx
>
;
inline
Ty
initial
()
{
return
static_cast
<
Ty
>
(
std
::
numeric_limits
<
Ty
>::
lowest
());
}
__device__
__forceinline__
Ty
operator
()(
const
Ty
&
a
,
const
Ty
&
b
)
const
{
return
(
b
>
a
)
?
b
:
a
;
}
};
template
<
typename
T
>
// for cub::Reduce
template
<
typename
Tx
,
typename
Ty
=
Tx
>
struct
CustomSum
{
__device__
__forceinline__
T
operator
()(
const
T
&
a
,
const
T
&
b
)
const
{
using
Transformer
=
detail
::
IdentityFunctor
<
Tx
,
Ty
>
;
inline
Ty
initial
()
{
return
static_cast
<
Ty
>
(
0.0
f
);
}
__device__
__forceinline__
Ty
operator
()(
const
Ty
&
a
,
const
Ty
&
b
)
const
{
return
b
+
a
;
}
};
template
<
typename
T
>
template
<
typename
Tx
,
typename
Ty
=
Tx
>
struct
CustomMean
{
using
Transformer
=
detail
::
DivideFunctor
<
Tx
>
;
inline
Ty
initial
()
{
return
static_cast
<
Ty
>
(
0.0
f
);
}
__device__
__forceinline__
Ty
operator
()(
const
Ty
&
a
,
const
Ty
&
b
)
const
{
return
b
+
a
;
}
};
template
<
typename
Tx
,
typename
Ty
=
Tx
>
struct
CustomMul
{
__device__
__forceinline__
T
operator
()(
const
T
&
a
,
const
T
&
b
)
const
{
using
Transformer
=
detail
::
IdentityFunctor
<
Tx
>
;
inline
Ty
initial
()
{
return
static_cast
<
Ty
>
(
1.0
f
);
}
__device__
__forceinline__
Ty
operator
()(
const
Ty
&
a
,
const
Ty
&
b
)
const
{
return
b
*
a
;
}
};
template
<
typename
Tx
,
typename
Ty
=
Tx
>
struct
CustomLogicalOr
{
using
Transformer
=
detail
::
IdentityFunctor
<
Tx
>
;
inline
Ty
initial
()
{
return
static_cast
<
Ty
>
(
false
);
}
__device__
__forceinline__
Ty
operator
()(
const
Ty
&
a
,
const
Ty
&
b
)
const
{
return
b
||
a
;
}
};
template
<
typename
Tx
,
typename
Ty
=
Tx
>
struct
CustomLogicalAnd
{
using
Transformer
=
detail
::
IdentityFunctor
<
Tx
>
;
inline
Ty
initial
()
{
return
static_cast
<
Ty
>
(
true
);
}
__device__
__forceinline__
Ty
operator
()(
const
Ty
&
a
,
const
Ty
&
b
)
const
{
return
b
&&
a
;
}
};
}
// namespace operators
}
// namespace paddle
paddle/fluid/operators/reduce_ops/reduce_max_op.cu
浏览文件 @
480b284c
...
...
@@ -11,15 +11,13 @@
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/operators/reduce_ops/reduce_functor_op.h"
#include "paddle/fluid/operators/reduce_ops/reduce_op.cu.h"
#include "paddle/fluid/operators/reduce_ops/reduce_op.h"
#include "paddle/fluid/operators/reduce_ops/reduce_min_max_op.h"
REGISTER_OP_CUDA_KERNEL
(
reduce_max
,
ops
::
ReduceKernel
<
paddle
::
platform
::
CUDADeviceContext
,
float
,
ops
::
MaxFunctor
>
,
ops
::
ReduceKernel
<
paddle
::
platform
::
CUDADeviceContext
,
double
,
ops
::
MaxFunctor
>
,
ops
::
ReduceKernel
<
paddle
::
platform
::
CUDADeviceContext
,
int
,
ops
::
MaxFunctor
>
,
ops
::
ReduceKernel
<
paddle
::
platform
::
CUDADeviceContext
,
int64_t
,
ops
::
MaxFunctor
>
);
// reduce_max
REGISTER_OP_CUDA_KERNEL
(
reduce_max
,
ops
::
ReduceCudaKernel
<
float
,
paddle
::
operators
::
CustomMax
>
,
ops
::
ReduceCudaKernel
<
double
,
paddle
::
operators
::
CustomMax
>
,
ops
::
ReduceCudaKernel
<
int
,
paddle
::
operators
::
CustomMax
>
,
ops
::
ReduceCudaKernel
<
int64_t
,
paddle
::
operators
::
CustomMax
>
);
paddle/fluid/operators/reduce_ops/reduce_min_op.cu
浏览文件 @
480b284c
...
...
@@ -11,15 +11,13 @@
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/operators/reduce_ops/reduce_functor_op.h"
#include "paddle/fluid/operators/reduce_ops/reduce_op.cu.h"
#include "paddle/fluid/operators/reduce_ops/reduce_op.h"
#include "paddle/fluid/operators/reduce_ops/reduce_min_max_op.h"
REGISTER_OP_CUDA_KERNEL
(
reduce_min
,
ops
::
ReduceKernel
<
paddle
::
platform
::
CUDADeviceContext
,
float
,
ops
::
MinFunctor
>
,
ops
::
ReduceKernel
<
paddle
::
platform
::
CUDADeviceContext
,
double
,
ops
::
MinFunctor
>
,
ops
::
ReduceKernel
<
paddle
::
platform
::
CUDADeviceContext
,
int
,
ops
::
MinFunctor
>
,
ops
::
ReduceKernel
<
paddle
::
platform
::
CUDADeviceContext
,
int64_t
,
ops
::
MinFunctor
>
);
// reduce_min
REGISTER_OP_CUDA_KERNEL
(
reduce_min
,
ops
::
ReduceCudaKernel
<
float
,
paddle
::
operators
::
CustomMin
>
,
ops
::
ReduceCudaKernel
<
double
,
paddle
::
operators
::
CustomMin
>
,
ops
::
ReduceCudaKernel
<
int
,
paddle
::
operators
::
CustomMin
>
,
ops
::
ReduceCudaKernel
<
int64_t
,
paddle
::
operators
::
CustomMin
>
);
paddle/fluid/operators/reduce_ops/reduce_op.cuh
→
paddle/fluid/operators/reduce_ops/reduce_op.cu
.
h
浏览文件 @
480b284c
此差异已折叠。
点击以展开。
paddle/fluid/operators/reduce_ops/reduce_prod_op.cu
浏览文件 @
480b284c
...
...
@@ -12,26 +12,22 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/operators/reduce_ops/reduce_functor_op.h"
#include "paddle/fluid/operators/reduce_ops/reduce_op.cu.h"
#include "paddle/fluid/operators/reduce_ops/reduce_prod_op.h"
// reduce_prod
#ifdef __HIPCC__
// Eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorReductionGpu.h:922
// do not support double in HIPCC platform (Eigen3 to be fixed)
REGISTER_OP_CUDA_KERNEL
(
reduce_prod
,
ops
::
ReduceKernel
<
paddle
::
platform
::
CUDADeviceContext
,
float
,
ops
::
ProdFunctor
>
,
ops
::
ReduceKernel
<
paddle
::
platform
::
CUDADeviceContext
,
int
,
ops
::
ProdFunctor
>
,
ops
::
ReduceKernel
<
paddle
::
platform
::
CUDADeviceContext
,
int64_t
,
ops
::
ProdFunctor
>
);
REGISTER_OP_CUDA_KERNEL
(
reduce_prod
,
ops
::
ReduceCudaKernel
<
float
,
paddle
::
operators
::
CustomMul
>
,
ops
::
ReduceCudaKernel
<
int
,
paddle
::
operators
::
CustomMul
>
,
ops
::
ReduceCudaKernel
<
int64_t
,
paddle
::
operators
::
CustomMul
>
);
#else
REGISTER_OP_CUDA_KERNEL
(
reduce_prod
,
ops
::
ReduceKernel
<
paddle
::
platform
::
CUDADeviceContext
,
float
,
ops
::
ProdFunctor
>
,
ops
::
ReduceKernel
<
paddle
::
platform
::
CUDADeviceContext
,
double
,
ops
::
ProdFunctor
>
,
ops
::
ReduceKernel
<
paddle
::
platform
::
CUDADeviceContext
,
int
,
ops
::
ProdFunctor
>
,
ops
::
ReduceKernel
<
paddle
::
platform
::
CUDADeviceContext
,
int64_t
,
ops
::
ProdFunctor
>
);
REGISTER_OP_CUDA_KERNEL
(
reduce_prod
,
ops
::
ReduceCudaKernel
<
float
,
paddle
::
operators
::
CustomMul
>
,
ops
::
ReduceCudaKernel
<
int
,
paddle
::
operators
::
CustomMul
>
,
ops
::
ReduceCudaKernel
<
double
,
paddle
::
operators
::
CustomMul
>
,
ops
::
ReduceCudaKernel
<
int64_t
,
paddle
::
operators
::
CustomMul
>
);
#endif
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录