Unverified commit 890d6bc0

Modify some contents for elementwise op impl (#32414)

Authored on Apr 22, 2021 by Zhang Zheng; committed by GitHub on Apr 22, 2021.
Parent: 1064f2b8
Showing 2 changed files with 23 additions and 14 deletions (+23 -14):

paddle/fluid/operators/elementwise/elementwise_add_op.cu    (+3 -2)
paddle/fluid/operators/elementwise/elementwise_op_impl.cu.h (+20 -12)
paddle/fluid/operators/elementwise/elementwise_add_op.cu

@@ -12,7 +12,6 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */
 #include "paddle/fluid/operators/elementwise/elementwise_add_op.h"
-#include "paddle/fluid/operators/elementwise/elementwise_op_function.cu.h"
 #include "paddle/fluid/operators/elementwise/elementwise_op_impl.cu.h"
 #include "paddle/fluid/platform/complex128.h"
 #include "paddle/fluid/platform/complex64.h"
@@ -34,7 +33,9 @@ namespace operators {
 */
 template <typename T>
 struct CudaAddFunctor {
-  inline HOSTDEVICE T operator()(T args[]) const { return args[0] + args[1]; }
+  __device__ __forceinline__ T operator()(const T* args) const {
+    return args[0] + args[1];
+  }
 };
 
 template <typename T>
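To see how a device-only functor of this shape (taking a `const T*` of packed operands) is consumed by an elementwise kernel, here is a minimal, self-contained CUDA sketch. It is not Paddle code: `AddFunctor` and `ElementwiseApply` are hypothetical names standing in for `CudaAddFunctor` and the real launch machinery.

// Standalone sketch, illustrative only (hypothetical names, not Paddle's API).
#include <cstdio>

template <typename T>
struct AddFunctor {
  __device__ __forceinline__ T operator()(const T* args) const {
    return args[0] + args[1];
  }
};

template <typename T, typename Functor>
__global__ void ElementwiseApply(const T* in0, const T* in1, T* out, int n,
                                 Functor func) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) {
    T args[2] = {in0[i], in1[i]};  // pack this element's operands
    out[i] = func(args);           // apply the functor to the packed args
  }
}

int main() {
  const int n = 8;
  float h0[n], h1[n], hout[n];
  for (int i = 0; i < n; ++i) { h0[i] = i; h1[i] = 2.0f * i; }

  float *d0, *d1, *dout;
  cudaMalloc(&d0, n * sizeof(float));
  cudaMalloc(&d1, n * sizeof(float));
  cudaMalloc(&dout, n * sizeof(float));
  cudaMemcpy(d0, h0, n * sizeof(float), cudaMemcpyHostToDevice);
  cudaMemcpy(d1, h1, n * sizeof(float), cudaMemcpyHostToDevice);

  ElementwiseApply<<<1, 32>>>(d0, d1, dout, n, AddFunctor<float>());
  cudaMemcpy(hout, dout, n * sizeof(float), cudaMemcpyDeviceToHost);
  for (int i = 0; i < n; ++i) printf("%g ", hout[i]);  // prints 0 3 6 9 ...
  printf("\n");
  cudaFree(d0); cudaFree(d1); cudaFree(dout);
  return 0;
}

Dropping the `inline HOSTDEVICE` overload makes the functor device-only, which matches how it is used: only `.cu` kernels ever invoke it.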
paddle/fluid/operators/elementwise/elementwise_op_impl.cu.h

@@ -13,6 +13,17 @@ See the License for the specific language governing permissions and
 limitations under the License. */
 
 #pragma once
+
+#include "paddle/fluid/framework/tensor.h"
+#include "paddle/fluid/platform/device_context.h"
+#include "paddle/fluid/platform/enforce.h"
+#include "paddle/fluid/platform/float16.h"
+
+#ifdef __HIPCC__
+#define ELEMENTWISE_BLOCK_SIZE 256
+#else
+#define ELEMENTWISE_BLOCK_SIZE 512
+#endif
 
 namespace paddle {
 namespace operators {
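The new `ELEMENTWISE_BLOCK_SIZE` macro selects a smaller launch block on HIP (ROCm) builds than on CUDA builds at compile time. A minimal sketch of the same selection pattern, with illustrative names (`MY_BLOCK_SIZE`, `FillOnes` are not Paddle's):

// Illustrative only: pick a block size depending on the compiler in use.
#ifdef __HIPCC__
#define MY_BLOCK_SIZE 256   // HIP/ROCm build
#else
#define MY_BLOCK_SIZE 512   // CUDA build
#endif

__global__ void FillOnes(float* out, int n) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) out[i] = 1.0f;  // simple elementwise write
}

void LaunchFillOnes(float* out, int n) {
  int grid = (n + MY_BLOCK_SIZE - 1) / MY_BLOCK_SIZE;  // ceil(n / block)
  FillOnes<<<grid, MY_BLOCK_SIZE>>>(out, n);
}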
@@ -90,8 +101,7 @@ struct ElementwiseDataWrapper {
 
 template <ElementwiseType ET, int VecSize, typename T, typename Functor>
 __device__ void VectorizedKernelImpl(
-    ElementwiseDataWrapper<ET, VecSize, T> data, int size, Functor func,
-    int tid) {
+    ElementwiseDataWrapper<ET, VecSize, T> data, Functor func, int tid) {
   using VecType = CudaAlignedVector<T, VecSize>;
   VecType ins_vec[ET];
   VecType out_vec;
@@ -121,10 +131,9 @@ __device__ void VectorizedKernelImpl(
   data.store_vector(out_vec, tid);
 }
 
-template <ElementwiseType ET, typename T, typename Functor>
-__device__ void ScalarKernelImpl(ElementwiseDataWrapper<ET, 1, T> data,
-                                 int size, Functor func, int start,
-                                 int remain) {
+template <ElementwiseType ET, int VecSize, typename T, typename Functor>
+__device__ void ScalarKernelImpl(ElementwiseDataWrapper<ET, VecSize, T> data,
+                                 Functor func, int start, int remain) {
   T ins[ET];
   T out;
@@ -146,12 +155,11 @@ __global__ void VectorizedKernel(const T *__restrict__ in0,
   int tid = blockIdx.x * blockDim.x + threadIdx.x;
   int remain = size - VecSize * tid;
   remain = remain > 0 ? remain : 0;
+  auto data = ElementwiseDataWrapper<ET, VecSize, T>(out, in0, in1);
   if (remain >= VecSize) {
-    auto data = ElementwiseDataWrapper<ET, VecSize, T>(out, in0, in1);
-    VectorizedKernelImpl(data, size, func, tid);
+    VectorizedKernelImpl(data, func, tid);
   } else {
-    auto data = ElementwiseDataWrapper<ET, 1, T>(out, in0, in1);
-    ScalarKernelImpl(data, size, func, tid * VecSize, remain);
+    ScalarKernelImpl(data, func, tid * VecSize, remain);
   }
 }
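After this change, `VectorizedKernel` builds one data wrapper and dispatches on `remain`: threads owning a full `VecSize`-wide group take the vectorized path, while the ragged tail falls back to scalar processing; the `size` argument no longer threads through the impl functions because `remain` already encodes the bound. A standalone sketch of that dispatch pattern (hypothetical `VecAdd`, not the Paddle implementation; the unrolled loop stands in for the real aligned vector load/store):

// Illustrative tail handling: each thread owns VecSize consecutive elements.
template <int VecSize>
__global__ void VecAdd(const float* in0, const float* in1, float* out,
                       int size) {
  int tid = blockIdx.x * blockDim.x + threadIdx.x;
  int remain = size - VecSize * tid;  // elements left for this thread
  remain = remain > 0 ? remain : 0;
  if (remain >= VecSize) {
    // full group: unrolled loop in place of a real vector load/store
#pragma unroll
    for (int i = 0; i < VecSize; ++i) {
      int idx = VecSize * tid + i;
      out[idx] = in0[idx] + in1[idx];
    }
  } else {
    // scalar tail: touch only the `remain` valid trailing elements
    for (int i = 0; i < remain; ++i) {
      int idx = VecSize * tid + i;
      out[idx] = in0[idx] + in1[idx];
    }
  }
}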
@@ -162,7 +170,7 @@ __global__ void ScalarKernel(const T *__restrict__ in0,
   auto data = ElementwiseDataWrapper<ET, 1, T>(out, in0, in1);
   int tid = blockIdx.x * blockDim.x + threadIdx.x;
   int remain = tid < size ? 1 : 0;
-  ScalarKernelImpl(data, size, func, tid, remain);
+  ScalarKernelImpl(data, func, tid, remain);
 }
 
 template <ElementwiseType ET, typename T, typename Functor>
@@ -173,7 +181,7 @@ void LaunchElementwiseCudaKernel(
   // calculate the max vec_size for all ins and outs
   auto size = ins[0]->numel();
   int vec_size = GetVectorizedSize<T>(ins, *outs);
-  int block_size = PADDLE_CUDA_THREAD_SIZE;
+  int block_size = ELEMENTWISE_BLOCK_SIZE;
   int grid_size =
       ((size + vec_size - 1) / vec_size + block_size - 1) / block_size;
   const T *in0 = ins[0]->data<T>();
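To make the launch arithmetic concrete (a worked example, not from the commit): with size = 10000, vec_size = 4, and block_size = 512, the launcher needs ceil(10000 / 4) = 2500 vectorized work items, so grid_size = (2500 + 511) / 512 = 5 blocks. Those 5 * 512 threads * 4 elements = 10240 slots cover all 10000 elements, and the out-of-range tail is absorbed by the `remain` check inside `VectorizedKernel`.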