Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
14dd68e1
P
Paddle
项目概览
PaddlePaddle
/
Paddle
大约 1 年 前同步成功
通知
2299
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
14dd68e1
编写于
2月 02, 2023
作者:
L
liuruyan
提交者:
GitHub
2月 02, 2023
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Fix the FP16 precision problem of add_n. (#50129)
上级
ec6e0a2c
变更
2
显示空白变更内容
内联
并排
Showing
2 changed file
with
75 addition
and
9 deletion
+75
-9
paddle/phi/kernels/gpu/add_n_kernel.cu
paddle/phi/kernels/gpu/add_n_kernel.cu
+11
-9
python/paddle/fluid/tests/unittests/test_add_n_op.py
python/paddle/fluid/tests/unittests/test_add_n_op.py
+64
-0
未找到文件。
paddle/phi/kernels/gpu/add_n_kernel.cu
浏览文件 @
14dd68e1
...
@@ -14,11 +14,10 @@
...
@@ -14,11 +14,10 @@
#include "paddle/phi/kernels/add_n_kernel.h"
#include "paddle/phi/kernels/add_n_kernel.h"
#include "paddle/phi/kernels/impl/add_n_kernel_impl.h"
#include "paddle/fluid/memory/malloc.h"
#include "paddle/fluid/memory/malloc.h"
#include "paddle/fluid/memory/memcpy.h"
#include "paddle/fluid/memory/memcpy.h"
#include "paddle/phi/common/amp_type_traits.h"
#include "paddle/phi/kernels/impl/add_n_kernel_impl.h"
namespace
phi
{
namespace
phi
{
#define CEIL_DIV(x, y) (((x) + (y)-1) / (y))
#define CEIL_DIV(x, y) (((x) + (y)-1) / (y))
...
@@ -38,16 +37,18 @@ __global__ void Sum2CUDAKernel(const T *in_0,
...
@@ -38,16 +37,18 @@ __global__ void Sum2CUDAKernel(const T *in_0,
template
<
class
T
>
template
<
class
T
>
__global__
void
SumArrayCUDAKernel
(
__global__
void
SumArrayCUDAKernel
(
T
**
in
,
T
*
out
,
int64_t
N
,
size_t
in_size
,
bool
read_dst
)
{
T
**
in
,
T
*
out
,
int64_t
N
,
size_t
in_size
,
bool
read_dst
)
{
using
MPType
=
typename
phi
::
dtype
::
MPTypeTrait
<
T
>::
Type
;
int
id
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
int
id
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
while
(
id
<
N
)
{
while
(
id
<
N
)
{
T
total
(
read_dst
?
out
[
id
]
:
static_cast
<
T
>
(
0
));
MPType
total
(
read_dst
?
static_cast
<
MPType
>
(
out
[
id
])
:
static_cast
<
MPType
>
(
0
));
for
(
int
i
=
0
;
i
<
in_size
;
++
i
)
{
for
(
int
i
=
0
;
i
<
in_size
;
++
i
)
{
const
T
*
tmp
=
in
[
i
];
const
T
*
tmp
=
in
[
i
];
if
(
tmp
)
{
if
(
tmp
)
{
total
+=
tmp
[
id
]
;
total
+=
static_cast
<
MPType
>
(
tmp
[
id
])
;
}
}
}
}
out
[
id
]
=
total
;
out
[
id
]
=
static_cast
<
T
>
(
total
)
;
id
+=
blockDim
.
x
*
gridDim
.
x
;
id
+=
blockDim
.
x
*
gridDim
.
x
;
}
}
}
}
...
@@ -116,11 +117,12 @@ void AddNKernel(const Context &dev_ctx,
...
@@ -116,11 +117,12 @@ void AddNKernel(const Context &dev_ctx,
int64_t
length_0
=
in_0
.
numel
();
int64_t
length_0
=
in_0
.
numel
();
int64_t
length_1
=
in_1
.
numel
();
int64_t
length_1
=
in_1
.
numel
();
if
(
length_0
&&
length_1
&&
in_0
.
IsInitialized
()
&&
in_1
.
IsInitialized
())
{
if
(
length_0
&&
length_1
&&
in_0
.
IsInitialized
()
&&
in_1
.
IsInitialized
())
{
using
MPType
=
typename
phi
::
dtype
::
MPTypeTrait
<
T
>::
Type
;
auto
result
=
EigenVector
<
T
>::
Flatten
(
*
out
);
auto
result
=
EigenVector
<
T
>::
Flatten
(
*
out
);
auto
&
place
=
*
dev_ctx
.
eigen_device
();
auto
&
place
=
*
dev_ctx
.
eigen_device
();
auto
in_0_e
=
EigenVector
<
T
>::
Flatten
(
in_0
);
auto
in_0_e
=
EigenVector
<
T
>::
Flatten
(
in_0
)
.
template
cast
<
MPType
>()
;
auto
in_1_e
=
EigenVector
<
T
>::
Flatten
(
in_1
);
auto
in_1_e
=
EigenVector
<
T
>::
Flatten
(
in_1
)
.
template
cast
<
MPType
>()
;
result
.
device
(
place
)
=
in_0_e
+
in_1_e
;
result
.
device
(
place
)
=
(
in_0_e
+
in_1_e
).
template
cast
<
T
>()
;
}
else
if
(
length_0
&&
in_0
.
IsInitialized
())
{
}
else
if
(
length_0
&&
in_0
.
IsInitialized
())
{
auto
result
=
EigenVector
<
T
>::
Flatten
(
*
out
);
auto
result
=
EigenVector
<
T
>::
Flatten
(
*
out
);
auto
&
place
=
*
dev_ctx
.
eigen_device
();
auto
&
place
=
*
dev_ctx
.
eigen_device
();
...
...
python/paddle/fluid/tests/unittests/test_add_n_op.py
0 → 100644
浏览文件 @
14dd68e1
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
unittest
import
numpy
as
np
import
paddle
class
TestAddnOp
(
unittest
.
TestCase
):
def
setUp
(
self
):
np
.
random
.
seed
(
20
)
l
=
32
self
.
x_np
=
np
.
random
.
random
([
l
,
16
,
256
])
def
check_main
(
self
,
x_np
,
dtype
,
axis
=
None
):
paddle
.
disable_static
()
x
=
[]
for
i
in
range
(
x_np
.
shape
[
0
]):
val
=
paddle
.
to_tensor
(
x_np
[
i
].
astype
(
dtype
))
val
.
stop_gradient
=
False
x
.
append
(
val
)
y
=
paddle
.
add_n
(
x
)
x_g
=
paddle
.
grad
(
y
,
x
)
y_np
=
y
.
numpy
().
astype
(
'float32'
)
x_g_np
=
[]
for
val
in
x_g
:
x_g_np
.
append
(
val
.
numpy
().
astype
(
'float32'
))
paddle
.
enable_static
()
return
y_np
,
x_g_np
def
test_add_n_fp16
(
self
):
if
not
paddle
.
is_compiled_with_cuda
():
return
y_np_16
,
x_g_np_16
=
self
.
check_main
(
self
.
x_np
,
'float16'
)
y_np_32
,
x_g_np_32
=
self
.
check_main
(
self
.
x_np
,
'float32'
)
np
.
testing
.
assert_allclose
(
y_np_16
,
y_np_32
,
rtol
=
1e-03
)
for
i
in
range
(
len
(
x_g_np_32
)):
np
.
testing
.
assert_allclose
(
x_g_np_16
[
i
],
x_g_np_32
[
i
],
rtol
=
1e-03
)
def
test_add_n_api
(
self
):
if
not
paddle
.
is_compiled_with_cuda
():
return
y_np_32
,
x_g_np_32
=
self
.
check_main
(
self
.
x_np
,
'float32'
)
y_np_gt
=
np
.
sum
(
self
.
x_np
,
axis
=
0
).
astype
(
'float32'
)
np
.
testing
.
assert_allclose
(
y_np_32
,
y_np_gt
,
rtol
=
1e-06
)
if
__name__
==
"__main__"
:
unittest
.
main
()
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录