Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
24258c27
P
Paddle
项目概览
PaddlePaddle
/
Paddle
1 年多 前同步成功
通知
2302
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
24258c27
编写于
3月 10, 2023
作者:
陈
陈沧夜
提交者:
GitHub
3月 10, 2023
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
No.54:为 Paddle allclose、isclose 算子实现 float16 数据类型支持 (#51168)
上级
07d8770f
变更
6
隐藏空白更改
内联
并排
Showing
6 changed file
with
110 addition
and
22 deletion
+110
-22
paddle/phi/kernels/gpu/allclose_kernel.cu
paddle/phi/kernels/gpu/allclose_kernel.cu
+15
-6
paddle/phi/kernels/gpu/isclose_kernel.cu
paddle/phi/kernels/gpu/isclose_kernel.cu
+8
-2
paddle/phi/kernels/impl/isclose_kernel_impl.h
paddle/phi/kernels/impl/isclose_kernel_impl.h
+7
-4
python/paddle/fluid/tests/unittests/test_allclose_op.py
python/paddle/fluid/tests/unittests/test_allclose_op.py
+32
-1
python/paddle/fluid/tests/unittests/test_isclose_op.py
python/paddle/fluid/tests/unittests/test_isclose_op.py
+32
-1
python/paddle/tensor/logic.py
python/paddle/tensor/logic.py
+16
-8
未找到文件。
paddle/phi/kernels/gpu/allclose_kernel.cu
浏览文件 @
24258c27
...
...
@@ -16,6 +16,8 @@
#include "glog/logging.h"
#include "paddle/phi/common/amp_type_traits.h"
#include "paddle/phi/common/data_type.h"
#include "paddle/phi/core/enforce.h"
#include "paddle/phi/core/kernel_registry.h"
...
...
@@ -31,14 +33,16 @@ __global__ void AllcloseCUDAKernel(const T* in_data,
bool
*
out_data
)
{
unsigned
int
idx
=
threadIdx
.
x
+
blockIdx
.
x
*
blockDim
.
x
;
bool
val
;
using
MPType
=
typename
phi
::
dtype
::
MPTypeTrait
<
T
>::
Type
;
for
(
int
i
=
idx
;
i
<
num
;
i
+=
blockDim
.
x
*
gridDim
.
x
)
{
const
T
a
=
in_data
[
i
],
b
=
other_data
[
i
];
const
MPType
a
=
static_cast
<
MPType
>
(
in_data
[
i
]);
const
MPType
b
=
static_cast
<
MPType
>
(
other_data
[
i
]);
if
(
isnan
(
a
)
||
isnan
(
b
))
{
val
=
equal_nan
&&
isnan
(
a
)
==
isnan
(
b
);
}
else
{
T
left
=
(
a
>
b
?
a
-
b
:
b
-
a
);
T
right
=
atol
+
(
b
>
0
?
rtol
*
b
:
(
-
rtol
)
*
b
);
T
diff
=
(
left
>
right
?
left
-
right
:
right
-
left
);
MPType
left
=
(
a
>
b
?
a
-
b
:
b
-
a
);
MPType
right
=
atol
+
(
b
>
0
?
rtol
*
b
:
(
-
rtol
)
*
b
);
MPType
diff
=
(
left
>
right
?
left
-
right
:
right
-
left
);
val
=
a
==
b
||
left
<=
right
||
diff
<=
1e-15
;
}
if
(
!
val
)
*
out_data
=
false
;
...
...
@@ -92,7 +96,12 @@ void AllCloseKernel(const Context& dev_ctx,
}
// namespace phi
PD_REGISTER_KERNEL
(
allclose
,
GPU
,
ALL_LAYOUT
,
phi
::
AllCloseKernel
,
float
,
double
)
{
PD_REGISTER_KERNEL
(
allclose
,
GPU
,
ALL_LAYOUT
,
phi
::
AllCloseKernel
,
float
,
double
,
phi
::
dtype
::
float16
)
{
kernel
->
OutputAt
(
0
).
SetDataType
(
phi
::
DataType
::
BOOL
);
}
paddle/phi/kernels/gpu/isclose_kernel.cu
浏览文件 @
24258c27
...
...
@@ -15,8 +15,14 @@
#include "paddle/phi/kernels/isclose_kernel.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/common/data_type.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/impl/isclose_kernel_impl.h"
PD_REGISTER_KERNEL
(
isclose
,
GPU
,
ALL_LAYOUT
,
phi
::
IscloseKernel
,
float
,
double
)
{}
PD_REGISTER_KERNEL
(
isclose
,
GPU
,
ALL_LAYOUT
,
phi
::
IscloseKernel
,
float
,
double
,
phi
::
dtype
::
float16
)
{}
paddle/phi/kernels/impl/isclose_kernel_impl.h
浏览文件 @
24258c27
...
...
@@ -18,6 +18,7 @@
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/common/amp_type_traits.h"
#include "paddle/phi/common/data_type.h"
#include "paddle/phi/common/place.h"
#include "paddle/phi/core/dense_tensor.h"
...
...
@@ -109,14 +110,16 @@ __global__ void IscloseCUDAKernel(const T* in_data,
bool
*
out_data
)
{
unsigned
int
idx
=
threadIdx
.
x
+
blockIdx
.
x
*
blockDim
.
x
;
bool
val
;
using
MPType
=
typename
phi
::
dtype
::
MPTypeTrait
<
T
>::
Type
;
for
(
int
i
=
idx
;
i
<
num
;
i
+=
blockDim
.
x
*
gridDim
.
x
)
{
const
T
a
=
in_data
[
i
],
b
=
other_data
[
i
];
const
MPType
a
=
static_cast
<
MPType
>
(
in_data
[
i
]);
const
MPType
b
=
static_cast
<
MPType
>
(
other_data
[
i
]);
if
(
isnan
(
a
)
||
isnan
(
b
))
{
val
=
equal_nan
&&
isnan
(
a
)
==
isnan
(
b
);
}
else
{
T
left
=
(
a
>
b
?
a
-
b
:
b
-
a
);
T
right
=
atol
+
(
b
>
0
?
rtol
*
b
:
(
-
rtol
)
*
b
);
T
diff
=
(
left
>
right
?
left
-
right
:
right
-
left
);
MPType
left
=
(
a
>
b
?
a
-
b
:
b
-
a
);
MPType
right
=
atol
+
(
b
>
0
?
rtol
*
b
:
(
-
rtol
)
*
b
);
MPType
diff
=
(
left
>
right
?
left
-
right
:
right
-
left
);
val
=
a
==
b
||
left
<=
right
||
diff
<=
1e-15
;
}
out_data
[
i
]
=
val
;
...
...
python/paddle/fluid/tests/unittests/test_allclose_op.py
浏览文件 @
24258c27
...
...
@@ -18,6 +18,7 @@ import numpy as np
from
op_test
import
OpTest
import
paddle
import
paddle.fluid.core
as
core
class
TestAllcloseOp
(
OpTest
):
...
...
@@ -134,7 +135,7 @@ class TestAllcloseError(unittest.TestCase):
with
paddle
.
static
.
program_guard
(
paddle
.
static
.
Program
(),
paddle
.
static
.
Program
()
):
x
=
paddle
.
fluid
.
data
(
name
=
'x'
,
shape
=
[
10
,
10
],
dtype
=
'
float16
'
)
x
=
paddle
.
fluid
.
data
(
name
=
'x'
,
shape
=
[
10
,
10
],
dtype
=
'
int32
'
)
y
=
paddle
.
fluid
.
data
(
name
=
'y'
,
shape
=
[
10
,
10
],
dtype
=
'float64'
)
result
=
paddle
.
allclose
(
x
,
y
)
...
...
@@ -170,6 +171,36 @@ class TestAllcloseError(unittest.TestCase):
self
.
assertRaises
(
TypeError
,
test_equal_nan
)
class
TestAllcloseOpFp16
(
unittest
.
TestCase
):
def
test_fp16
(
self
):
x_data
=
np
.
random
.
rand
(
10
,
10
).
astype
(
'float16'
)
y_data
=
np
.
random
.
rand
(
10
,
10
).
astype
(
'float16'
)
with
paddle
.
static
.
program_guard
(
paddle
.
static
.
Program
()):
x
=
paddle
.
static
.
data
(
shape
=
[
10
,
10
],
name
=
'x'
,
dtype
=
'float16'
)
y
=
paddle
.
static
.
data
(
shape
=
[
10
,
10
],
name
=
'x'
,
dtype
=
'float16'
)
out
=
paddle
.
allclose
(
x
,
y
,
rtol
=
1e-05
,
atol
=
1e-08
)
if
core
.
is_compiled_with_cuda
():
place
=
paddle
.
CUDAPlace
(
0
)
exe
=
paddle
.
static
.
Executor
(
place
)
exe
.
run
(
paddle
.
static
.
default_startup_program
())
out
=
exe
.
run
(
feed
=
{
'x'
:
x_data
,
'y'
:
y_data
},
fetch_list
=
[
out
])
class
TestAllcloseOpFloat16
(
TestAllcloseOp
):
def
set_args
(
self
):
self
.
input
=
np
.
array
([
10.1
]).
astype
(
"float16"
)
self
.
other
=
np
.
array
([
10
]).
astype
(
"float16"
)
self
.
rtol
=
np
.
array
([
0.01
]).
astype
(
"float64"
)
self
.
atol
=
np
.
array
([
0
]).
astype
(
"float64"
)
self
.
equal_nan
=
False
def
test_check_output
(
self
):
if
core
.
is_compiled_with_cuda
():
place
=
core
.
CUDAPlace
(
0
)
if
core
.
is_float16_supported
(
place
):
self
.
check_output_with_place
(
place
,
check_eager
=
True
)
class
TestAllcloseOpFloat32
(
TestAllcloseOp
):
def
set_args
(
self
):
self
.
input
=
np
.
array
([
10.1
]).
astype
(
"float32"
)
...
...
python/paddle/fluid/tests/unittests/test_isclose_op.py
浏览文件 @
24258c27
...
...
@@ -18,6 +18,7 @@ import numpy as np
from
op_test
import
OpTest
import
paddle
import
paddle.fluid.core
as
core
class
TestIscloseOp
(
OpTest
):
...
...
@@ -166,7 +167,7 @@ class TestIscloseError(unittest.TestCase):
with
paddle
.
static
.
program_guard
(
paddle
.
static
.
Program
(),
paddle
.
static
.
Program
()
):
x
=
paddle
.
fluid
.
data
(
name
=
'x'
,
shape
=
[
10
,
10
],
dtype
=
'
float16
'
)
x
=
paddle
.
fluid
.
data
(
name
=
'x'
,
shape
=
[
10
,
10
],
dtype
=
'
int32
'
)
y
=
paddle
.
fluid
.
data
(
name
=
'y'
,
shape
=
[
10
,
10
],
dtype
=
'float64'
)
result
=
paddle
.
isclose
(
x
,
y
)
...
...
@@ -203,6 +204,36 @@ class TestIscloseError(unittest.TestCase):
self
.
assertRaises
(
TypeError
,
test_equal_nan
)
class
TestIscloseOpFp16
(
unittest
.
TestCase
):
def
test_fp16
(
self
):
x_data
=
np
.
random
.
rand
(
10
,
10
).
astype
(
'float16'
)
y_data
=
np
.
random
.
rand
(
10
,
10
).
astype
(
'float16'
)
with
paddle
.
static
.
program_guard
(
paddle
.
static
.
Program
()):
x
=
paddle
.
static
.
data
(
shape
=
[
10
,
10
],
name
=
'x'
,
dtype
=
'float16'
)
y
=
paddle
.
static
.
data
(
shape
=
[
10
,
10
],
name
=
'x'
,
dtype
=
'float16'
)
out
=
paddle
.
isclose
(
x
,
y
,
rtol
=
1e-05
,
atol
=
1e-08
)
if
core
.
is_compiled_with_cuda
():
place
=
paddle
.
CUDAPlace
(
0
)
exe
=
paddle
.
static
.
Executor
(
place
)
exe
.
run
(
paddle
.
static
.
default_startup_program
())
out
=
exe
.
run
(
feed
=
{
'x'
:
x_data
,
'y'
:
y_data
},
fetch_list
=
[
out
])
class
TestIscloseOpFloat16
(
TestIscloseOp
):
def
set_args
(
self
):
self
.
input
=
np
.
array
([
10.1
]).
astype
(
"float16"
)
self
.
other
=
np
.
array
([
10
]).
astype
(
"float16"
)
self
.
rtol
=
np
.
array
([
0.01
]).
astype
(
"float64"
)
self
.
atol
=
np
.
array
([
0
]).
astype
(
"float64"
)
self
.
equal_nan
=
False
def
test_check_output
(
self
):
if
core
.
is_compiled_with_cuda
():
place
=
core
.
CUDAPlace
(
0
)
if
core
.
is_float16_supported
(
place
):
self
.
check_output_with_place
(
place
,
check_eager
=
True
)
class
TestIscloseOpFloat32
(
TestIscloseOp
):
def
set_args
(
self
):
self
.
input
=
np
.
array
([
10.1
]).
astype
(
"float32"
)
...
...
python/paddle/tensor/logic.py
浏览文件 @
24258c27
...
...
@@ -361,8 +361,8 @@ def allclose(x, y, rtol=1e-05, atol=1e-08, equal_nan=False, name=None):
two tensors are elementwise equal within a tolerance.
Args:
x(Tensor): The input tensor, it's data type should be float32, float64..
y(Tensor): The input tensor, it's data type should be float32, float64..
x(Tensor): The input tensor, it's data type should be float
16, float
32, float64..
y(Tensor): The input tensor, it's data type should be float
16, float
32, float64..
rtol(rtoltype, optional): The relative tolerance. Default: :math:`1e-5` .
atol(atoltype, optional): The absolute tolerance. Default: :math:`1e-8` .
equal_nan(equalnantype, optional): ${equal_nan_comment}.
...
...
@@ -401,8 +401,12 @@ def allclose(x, y, rtol=1e-05, atol=1e-08, equal_nan=False, name=None):
if
in_dygraph_mode
():
return
_C_ops
.
allclose
(
x
,
y
,
rtol
,
atol
,
equal_nan
)
else
:
check_variable_and_dtype
(
x
,
"input"
,
[
'float32'
,
'float64'
],
'allclose'
)
check_variable_and_dtype
(
y
,
"input"
,
[
'float32'
,
'float64'
],
'allclose'
)
check_variable_and_dtype
(
x
,
"input"
,
[
'float16'
,
'float32'
,
'float64'
],
'allclose'
)
check_variable_and_dtype
(
y
,
"input"
,
[
'float16'
,
'float32'
,
'float64'
],
'allclose'
)
check_type
(
rtol
,
'rtol'
,
float
,
'allclose'
)
check_type
(
atol
,
'atol'
,
float
,
'allclose'
)
check_type
(
equal_nan
,
'equal_nan'
,
bool
,
'allclose'
)
...
...
@@ -989,8 +993,8 @@ def isclose(x, y, rtol=1e-05, atol=1e-08, equal_nan=False, name=None):
two tensors are elementwise equal within a tolerance.
Args:
x(Tensor): The input tensor, it's data type should be float32, float64.
y(Tensor): The input tensor, it's data type should be float32, float64.
x(Tensor): The input tensor, it's data type should be float
16, float
32, float64.
y(Tensor): The input tensor, it's data type should be float
16, float
32, float64.
rtol(rtoltype, optional): The relative tolerance. Default: :math:`1e-5` .
atol(atoltype, optional): The absolute tolerance. Default: :math:`1e-8` .
equal_nan(equalnantype, optional): If :math:`True` , then two :math:`NaNs` will be compared as equal. Default: :math:`False` .
...
...
@@ -1027,8 +1031,12 @@ def isclose(x, y, rtol=1e-05, atol=1e-08, equal_nan=False, name=None):
if
in_dygraph_mode
():
return
_C_ops
.
isclose
(
x
,
y
,
rtol
,
atol
,
equal_nan
)
else
:
check_variable_and_dtype
(
x
,
"input"
,
[
'float32'
,
'float64'
],
'isclose'
)
check_variable_and_dtype
(
y
,
"input"
,
[
'float32'
,
'float64'
],
'isclose'
)
check_variable_and_dtype
(
x
,
"input"
,
[
'float16'
,
'float32'
,
'float64'
],
'isclose'
)
check_variable_and_dtype
(
y
,
"input"
,
[
'float16'
,
'float32'
,
'float64'
],
'isclose'
)
check_type
(
rtol
,
'rtol'
,
float
,
'isclose'
)
check_type
(
atol
,
'atol'
,
float
,
'isclose'
)
check_type
(
equal_nan
,
'equal_nan'
,
bool
,
'isclose'
)
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录