Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
585f1136
P
Paddle
项目概览
PaddlePaddle
/
Paddle
大约 2 年 前同步成功
通知
2325
Star
20933
Fork
5424
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
585f1136
编写于
6月 02, 2023
作者:
C
Charles-hit
提交者:
GitHub
6月 02, 2023
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
support some prim ops for bf16 dtype (#54285)
上级
f5342918
变更
5
隐藏空白更改
内联
并排
Showing
5 changed file
with
170 addition
and
17 deletion
+170
-17
python/paddle/tensor/manipulation.py
python/paddle/tensor/manipulation.py
+1
-0
test/legacy_test/test_elementwise_sub_op.py
test/legacy_test/test_elementwise_sub_op.py
+28
-2
test/legacy_test/test_gather_op.py
test/legacy_test/test_gather_op.py
+97
-6
test/legacy_test/test_reduce_op.py
test/legacy_test/test_reduce_op.py
+43
-7
test/legacy_test/test_slice_op.py
test/legacy_test/test_slice_op.py
+1
-2
未找到文件。
python/paddle/tensor/manipulation.py
浏览文件 @
585f1136
...
@@ -2735,6 +2735,7 @@ def gather(x, index, axis=None, name=None):
...
@@ -2735,6 +2735,7 @@ def gather(x, index, axis=None, name=None):
'int32'
,
'int32'
,
'int64'
,
'int64'
,
'uint8'
,
'uint8'
,
'uint16'
,
],
],
'gather'
,
'gather'
,
)
)
...
...
test/legacy_test/test_elementwise_sub_op.py
浏览文件 @
585f1136
...
@@ -71,7 +71,7 @@ class TestElementwiseOp(OpTest):
...
@@ -71,7 +71,7 @@ class TestElementwiseOp(OpTest):
self
.
check_prim
=
True
self
.
check_prim
=
True
def
if_enable_cinn
(
self
):
def
if_enable_cinn
(
self
):
self
.
enable_cinn
=
False
pass
class
TestElementwiseFP16OP
(
TestElementwiseOp
):
class
TestElementwiseFP16OP
(
TestElementwiseOp
):
...
@@ -87,6 +87,7 @@ class TestElementwiseFP16OP(TestElementwiseOp):
...
@@ -87,6 +87,7 @@ class TestElementwiseFP16OP(TestElementwiseOp):
class
TestElementwiseBF16OP
(
TestElementwiseOp
):
class
TestElementwiseBF16OP
(
TestElementwiseOp
):
def
setUp
(
self
):
def
setUp
(
self
):
self
.
op_type
=
"elementwise_sub"
self
.
op_type
=
"elementwise_sub"
self
.
prim_op_type
=
"prim"
self
.
dtype
=
np
.
uint16
self
.
dtype
=
np
.
uint16
self
.
python_api
=
paddle
.
subtract
self
.
python_api
=
paddle
.
subtract
self
.
public_python_api
=
paddle
.
subtract
self
.
public_python_api
=
paddle
.
subtract
...
@@ -103,6 +104,9 @@ class TestElementwiseBF16OP(TestElementwiseOp):
...
@@ -103,6 +104,9 @@ class TestElementwiseBF16OP(TestElementwiseOp):
self
.
if_check_prim
()
self
.
if_check_prim
()
self
.
if_enable_cinn
()
self
.
if_enable_cinn
()
def
if_enable_cinn
(
self
):
self
.
enable_cinn
=
False
def
test_check_grad_normal
(
self
):
def
test_check_grad_normal
(
self
):
place
=
core
.
CUDAPlace
(
0
)
place
=
core
.
CUDAPlace
(
0
)
self
.
check_grad_with_place
(
self
.
check_grad_with_place
(
...
@@ -118,7 +122,12 @@ class TestElementwiseBF16OP(TestElementwiseOp):
...
@@ -118,7 +122,12 @@ class TestElementwiseBF16OP(TestElementwiseOp):
def
test_check_grad_ingore_y
(
self
):
def
test_check_grad_ingore_y
(
self
):
place
=
core
.
CUDAPlace
(
0
)
place
=
core
.
CUDAPlace
(
0
)
self
.
check_grad_with_place
(
self
.
check_grad_with_place
(
place
,
[
'X'
],
'Out'
,
no_grad_set
=
set
(
'Y'
),
max_relative_error
=
0.1
place
,
[
'X'
],
'Out'
,
no_grad_set
=
set
(
'Y'
),
max_relative_error
=
0.1
,
check_prim
=
True
,
)
)
...
@@ -135,6 +144,10 @@ class TestElementwiseSubOp_ZeroDim1(TestElementwiseOp):
...
@@ -135,6 +144,10 @@ class TestElementwiseSubOp_ZeroDim1(TestElementwiseOp):
}
}
self
.
outputs
=
{
'Out'
:
self
.
inputs
[
'X'
]
-
self
.
inputs
[
'Y'
]}
self
.
outputs
=
{
'Out'
:
self
.
inputs
[
'X'
]
-
self
.
inputs
[
'Y'
]}
self
.
if_check_prim
()
self
.
if_check_prim
()
self
.
if_enable_cinn
()
def
if_enable_cinn
(
self
):
self
.
enable_cinn
=
False
class
TestElementwiseSubFP16OP_ZeroDim1
(
TestElementwiseSubOp_ZeroDim1
):
class
TestElementwiseSubFP16OP_ZeroDim1
(
TestElementwiseSubOp_ZeroDim1
):
...
@@ -181,6 +194,10 @@ class TestElementwiseSubOp_ZeroDim2(TestElementwiseOp):
...
@@ -181,6 +194,10 @@ class TestElementwiseSubOp_ZeroDim2(TestElementwiseOp):
}
}
self
.
outputs
=
{
'Out'
:
self
.
inputs
[
'X'
]
-
self
.
inputs
[
'Y'
]}
self
.
outputs
=
{
'Out'
:
self
.
inputs
[
'X'
]
-
self
.
inputs
[
'Y'
]}
self
.
if_check_prim
()
self
.
if_check_prim
()
self
.
if_enable_cinn
()
def
if_enable_cinn
(
self
):
self
.
enable_cinn
=
False
class
TestElementwiseSubFP16OP_ZeroDim2
(
TestElementwiseSubOp_ZeroDim2
):
class
TestElementwiseSubFP16OP_ZeroDim2
(
TestElementwiseSubOp_ZeroDim2
):
...
@@ -227,6 +244,10 @@ class TestElementwiseSubOp_ZeroDim3(TestElementwiseOp):
...
@@ -227,6 +244,10 @@ class TestElementwiseSubOp_ZeroDim3(TestElementwiseOp):
}
}
self
.
outputs
=
{
'Out'
:
self
.
inputs
[
'X'
]
-
self
.
inputs
[
'Y'
]}
self
.
outputs
=
{
'Out'
:
self
.
inputs
[
'X'
]
-
self
.
inputs
[
'Y'
]}
self
.
if_check_prim
()
self
.
if_check_prim
()
self
.
if_enable_cinn
()
def
if_enable_cinn
(
self
):
self
.
enable_cinn
=
False
class
TestElementwiseSubFP16OP_ZeroDim3
(
TestElementwiseSubOp_ZeroDim3
):
class
TestElementwiseSubFP16OP_ZeroDim3
(
TestElementwiseSubOp_ZeroDim3
):
...
@@ -580,6 +601,7 @@ class TestElementwiseSubOp_broadcast_4(TestElementwiseOp):
...
@@ -580,6 +601,7 @@ class TestElementwiseSubOp_broadcast_4(TestElementwiseOp):
}
}
self
.
outputs
=
{
'Out'
:
self
.
inputs
[
'X'
]
-
self
.
inputs
[
'Y'
]}
self
.
outputs
=
{
'Out'
:
self
.
inputs
[
'X'
]
-
self
.
inputs
[
'Y'
]}
self
.
if_check_prim
()
self
.
if_check_prim
()
self
.
if_enable_cinn
()
@
unittest
.
skipIf
(
@
unittest
.
skipIf
(
...
@@ -653,6 +675,7 @@ class TestElementwiseBF16OP_commonuse_1(TestElementwiseBF16OP):
...
@@ -653,6 +675,7 @@ class TestElementwiseBF16OP_commonuse_1(TestElementwiseBF16OP):
}
}
self
.
outputs
=
{
'Out'
:
convert_float_to_uint16
(
self
.
outputs
[
'Out'
])}
self
.
outputs
=
{
'Out'
:
convert_float_to_uint16
(
self
.
outputs
[
'Out'
])}
self
.
if_check_prim
()
self
.
if_check_prim
()
self
.
if_enable_cinn
()
class
TestElementwiseSubOp_commonuse_2
(
TestElementwiseOp
):
class
TestElementwiseSubOp_commonuse_2
(
TestElementwiseOp
):
...
@@ -698,6 +721,7 @@ class TestElementwiseBF16OP_commonuse_2(TestElementwiseBF16OP):
...
@@ -698,6 +721,7 @@ class TestElementwiseBF16OP_commonuse_2(TestElementwiseBF16OP):
}
}
self
.
outputs
=
{
'Out'
:
convert_float_to_uint16
(
self
.
outputs
[
'Out'
])}
self
.
outputs
=
{
'Out'
:
convert_float_to_uint16
(
self
.
outputs
[
'Out'
])}
self
.
if_check_prim
()
self
.
if_check_prim
()
self
.
if_enable_cinn
()
class
TestElementwiseSubOp_xsize_lessthan_ysize
(
TestElementwiseOp
):
class
TestElementwiseSubOp_xsize_lessthan_ysize
(
TestElementwiseOp
):
...
@@ -717,6 +741,7 @@ class TestElementwiseSubOp_xsize_lessthan_ysize(TestElementwiseOp):
...
@@ -717,6 +741,7 @@ class TestElementwiseSubOp_xsize_lessthan_ysize(TestElementwiseOp):
'Out'
:
self
.
inputs
[
'X'
].
reshape
(
1
,
1
,
10
,
12
)
-
self
.
inputs
[
'Y'
]
'Out'
:
self
.
inputs
[
'X'
].
reshape
(
1
,
1
,
10
,
12
)
-
self
.
inputs
[
'Y'
]
}
}
self
.
if_check_prim
()
self
.
if_check_prim
()
self
.
if_enable_cinn
()
class
TestElementwiseSubFP16OP_xsize_lessthan_ysize
(
class
TestElementwiseSubFP16OP_xsize_lessthan_ysize
(
...
@@ -750,6 +775,7 @@ class TestElementwiseBF16OP_xsize_lessthan_ysize(TestElementwiseBF16OP):
...
@@ -750,6 +775,7 @@ class TestElementwiseBF16OP_xsize_lessthan_ysize(TestElementwiseBF16OP):
}
}
self
.
outputs
=
{
'Out'
:
convert_float_to_uint16
(
self
.
outputs
[
'Out'
])}
self
.
outputs
=
{
'Out'
:
convert_float_to_uint16
(
self
.
outputs
[
'Out'
])}
self
.
if_check_prim
()
self
.
if_check_prim
()
self
.
if_enable_cinn
()
class
TestComplexElementwiseSubOp
(
OpTest
):
class
TestComplexElementwiseSubOp
(
OpTest
):
...
...
test/legacy_test/test_gather_op.py
浏览文件 @
585f1136
...
@@ -37,12 +37,8 @@ class TestGatherOp(OpTest):
...
@@ -37,12 +37,8 @@ class TestGatherOp(OpTest):
self
.
public_python_api
=
paddle
.
gather
self
.
public_python_api
=
paddle
.
gather
self
.
config
()
self
.
config
()
self
.
prim_op_type
=
"prim"
self
.
prim_op_type
=
"prim"
xnp
=
np
.
random
.
random
(
self
.
x_shape
).
astype
(
self
.
x_type
)
self
.
init_inputs_and_outputs
()
self
.
inputs
=
{
self
.
if_enable_cinn
()
'X'
:
xnp
,
'Index'
:
np
.
array
(
self
.
index
).
astype
(
self
.
index_type
),
}
self
.
outputs
=
{
'Out'
:
self
.
inputs
[
"X"
][
self
.
inputs
[
"Index"
]]}
def
test_check_output
(
self
):
def
test_check_output
(
self
):
self
.
check_output
()
self
.
check_output
()
...
@@ -62,12 +58,56 @@ class TestGatherOp(OpTest):
...
@@ -62,12 +58,56 @@ class TestGatherOp(OpTest):
def
config_dtype
(
self
):
def
config_dtype
(
self
):
self
.
x_type
=
"float64"
self
.
x_type
=
"float64"
def
init_inputs_and_outputs
(
self
):
xnp
=
np
.
random
.
random
(
self
.
x_shape
).
astype
(
self
.
x_type
)
self
.
inputs
=
{
'X'
:
xnp
,
'Index'
:
np
.
array
(
self
.
index
).
astype
(
self
.
index_type
),
}
self
.
outputs
=
{
'Out'
:
self
.
inputs
[
"X"
][
self
.
inputs
[
"Index"
]]}
def
if_enable_cinn
(
self
):
pass
class
TestGatherOpFP16
(
TestGatherOp
):
class
TestGatherOpFP16
(
TestGatherOp
):
def
config_dtype
(
self
):
def
config_dtype
(
self
):
self
.
x_type
=
"float16"
self
.
x_type
=
"float16"
@
unittest
.
skipIf
(
not
core
.
is_compiled_with_cuda
()
or
core
.
cudnn_version
()
<
8100
or
paddle
.
device
.
cuda
.
get_device_capability
()[
0
]
<
8
,
"only support compiled with CUDA and cudnn version need larger than 8.1.0 and device's compute capability is at least 8.0"
,
)
class
TestGatherOpBFP16
(
TestGatherOp
):
def
config_dtype
(
self
):
self
.
x_type
=
"float32"
self
.
dtype
=
np
.
uint16
def
init_inputs_and_outputs
(
self
):
xnp
=
np
.
random
.
random
(
self
.
x_shape
).
astype
(
self
.
x_type
)
self
.
inputs
=
{
'X'
:
convert_float_to_uint16
(
xnp
),
'Index'
:
np
.
array
(
self
.
index
).
astype
(
self
.
index_type
),
}
self
.
outputs
=
{
'Out'
:
convert_float_to_uint16
(
xnp
[
self
.
inputs
[
"Index"
]])
}
def
if_enable_cinn
(
self
):
self
.
enable_cinn
=
False
def
test_check_output
(
self
):
self
.
check_output_with_place
(
place
=
paddle
.
CUDAPlace
(
0
))
def
test_check_grad
(
self
):
self
.
check_grad_with_place
(
paddle
.
CUDAPlace
(
0
),
[
'X'
],
'Out'
,
check_prim
=
True
)
class
TestCase1
(
TestGatherOp
):
class
TestCase1
(
TestGatherOp
):
def
config
(
self
):
def
config
(
self
):
"""
"""
...
@@ -87,6 +127,14 @@ class TestCase1FP16(TestCase1):
...
@@ -87,6 +127,14 @@ class TestCase1FP16(TestCase1):
self
.
x_type
=
"float16"
self
.
x_type
=
"float16"
class
TestCase1BFP16
(
TestGatherOpBFP16
):
def
config
(
self
):
self
.
x_shape
=
100
self
.
config_dtype
()
self
.
index
=
[
1
,
3
,
5
]
self
.
index_type
=
"int32"
class
TestCase2
(
TestGatherOp
):
class
TestCase2
(
TestGatherOp
):
def
config
(
self
):
def
config
(
self
):
"""
"""
...
@@ -106,6 +154,14 @@ class TestCase2FP16(TestCase2):
...
@@ -106,6 +154,14 @@ class TestCase2FP16(TestCase2):
self
.
x_type
=
"float16"
self
.
x_type
=
"float16"
class
TestCase2BFP16
(
TestGatherOpBFP16
):
def
config
(
self
):
self
.
x_shape
=
100
self
.
config_dtype
()
self
.
index
=
[
1
,
3
,
5
]
self
.
index_type
=
"int64"
class
TestCase3
(
TestGatherOp
):
class
TestCase3
(
TestGatherOp
):
def
config
(
self
):
def
config
(
self
):
"""
"""
...
@@ -125,6 +181,14 @@ class TestCase3Fp16(TestCase3):
...
@@ -125,6 +181,14 @@ class TestCase3Fp16(TestCase3):
self
.
x_type
=
"float16"
self
.
x_type
=
"float16"
class
TestCase3BFP16
(
TestGatherOpBFP16
):
def
config
(
self
):
self
.
x_shape
=
(
10
,
20
)
self
.
config_dtype
()
self
.
index
=
[
1
,
3
,
5
]
self
.
index_type
=
"int64"
class
TestCase4
(
TestGatherOp
):
class
TestCase4
(
TestGatherOp
):
def
config
(
self
):
def
config
(
self
):
self
.
x_shape
=
(
10
,
20
)
self
.
x_shape
=
(
10
,
20
)
...
@@ -142,6 +206,15 @@ class TestCase4FP16(TestCase4):
...
@@ -142,6 +206,15 @@ class TestCase4FP16(TestCase4):
self
.
x_type
=
"float16"
self
.
x_type
=
"float16"
class
TestCase4BFP16
(
TestGatherOpBFP16
):
def
config
(
self
):
self
.
x_shape
=
(
10
,
20
)
self
.
attrs
=
{
'overwrite'
:
False
}
self
.
config_dtype
()
self
.
index
=
[
1
,
1
]
self
.
index_type
=
"int32"
class
TestCase5
(
TestGatherOp
):
class
TestCase5
(
TestGatherOp
):
def
config
(
self
):
def
config
(
self
):
self
.
x_shape
=
(
10
,
20
)
self
.
x_shape
=
(
10
,
20
)
...
@@ -154,6 +227,15 @@ class TestCase5(TestGatherOp):
...
@@ -154,6 +227,15 @@ class TestCase5(TestGatherOp):
self
.
x_type
=
"float64"
self
.
x_type
=
"float64"
class
TestCase5BFP16
(
TestGatherOpBFP16
):
def
config
(
self
):
self
.
x_shape
=
(
10
,
20
)
self
.
attrs
=
{
'overwrite'
:
False
}
self
.
config_dtype
()
self
.
index
=
[
1
,
1
]
self
.
index_type
=
"int32"
class
TestCase5FP16
(
TestCase5
):
class
TestCase5FP16
(
TestCase5
):
def
config_dtype
(
self
):
def
config_dtype
(
self
):
self
.
x_type
=
"float16"
self
.
x_type
=
"float16"
...
@@ -176,6 +258,15 @@ class TestCase6FP16(TestCase6):
...
@@ -176,6 +258,15 @@ class TestCase6FP16(TestCase6):
self
.
x_type
=
"float16"
self
.
x_type
=
"float16"
class
TestCase6BFP16
(
TestGatherOpBFP16
):
def
config
(
self
):
self
.
x_shape
=
(
10
,
20
)
self
.
attrs
=
{
'overwrite'
:
True
}
self
.
config_dtype
()
self
.
index
=
[
1
,
3
]
self
.
index_type
=
"int32"
class
TestGatherBF16Op
(
OpTest
):
class
TestGatherBF16Op
(
OpTest
):
def
setUp
(
self
):
def
setUp
(
self
):
self
.
op_type
=
"gather"
self
.
op_type
=
"gather"
...
...
test/legacy_test/test_reduce_op.py
浏览文件 @
585f1136
...
@@ -36,7 +36,7 @@ class TestSumOp(OpTest):
...
@@ -36,7 +36,7 @@ class TestSumOp(OpTest):
self
.
prim_op_type
=
"prim"
self
.
prim_op_type
=
"prim"
self
.
inputs
=
{
'X'
:
self
.
x
}
self
.
inputs
=
{
'X'
:
self
.
x
}
self
.
outputs
=
{
'Out'
:
self
.
out
}
self
.
outputs
=
{
'Out'
:
self
.
out
}
self
.
enable_cinn
=
True
self
.
if_enable_cinn
()
def
init_dtype
(
self
):
def
init_dtype
(
self
):
self
.
dtype
=
np
.
float64
self
.
dtype
=
np
.
float64
...
@@ -47,6 +47,9 @@ class TestSumOp(OpTest):
...
@@ -47,6 +47,9 @@ class TestSumOp(OpTest):
def
init_attrs
(
self
):
def
init_attrs
(
self
):
self
.
attrs
=
{
'dim'
:
[
0
]}
self
.
attrs
=
{
'dim'
:
[
0
]}
def
if_enable_cinn
(
self
):
pass
def
calc_output
(
self
):
def
calc_output
(
self
):
self
.
out
=
self
.
x
.
sum
(
axis
=
tuple
(
self
.
attrs
[
'dim'
]))
self
.
out
=
self
.
x
.
sum
(
axis
=
tuple
(
self
.
attrs
[
'dim'
]))
...
@@ -984,7 +987,10 @@ class Test1DReduce(OpTest):
...
@@ -984,7 +987,10 @@ class Test1DReduce(OpTest):
self
.
prim_op_type
=
"prim"
self
.
prim_op_type
=
"prim"
self
.
inputs
=
{
'X'
:
np
.
random
.
random
(
120
).
astype
(
"float64"
)}
self
.
inputs
=
{
'X'
:
np
.
random
.
random
(
120
).
astype
(
"float64"
)}
self
.
outputs
=
{
'Out'
:
self
.
inputs
[
'X'
].
sum
(
axis
=
0
)}
self
.
outputs
=
{
'Out'
:
self
.
inputs
[
'X'
].
sum
(
axis
=
0
)}
self
.
enable_cinn
=
True
self
.
if_enable_cinn
()
def
if_enable_cinn
(
self
):
pass
def
test_check_output
(
self
):
def
test_check_output
(
self
):
self
.
check_output
()
self
.
check_output
()
...
@@ -1002,6 +1008,7 @@ class Test2DReduce0(Test1DReduce):
...
@@ -1002,6 +1008,7 @@ class Test2DReduce0(Test1DReduce):
self
.
attrs
=
{
'dim'
:
[
0
]}
self
.
attrs
=
{
'dim'
:
[
0
]}
self
.
inputs
=
{
'X'
:
np
.
random
.
random
((
20
,
10
)).
astype
(
"float64"
)}
self
.
inputs
=
{
'X'
:
np
.
random
.
random
((
20
,
10
)).
astype
(
"float64"
)}
self
.
outputs
=
{
'Out'
:
self
.
inputs
[
'X'
].
sum
(
axis
=
0
)}
self
.
outputs
=
{
'Out'
:
self
.
inputs
[
'X'
].
sum
(
axis
=
0
)}
self
.
if_enable_cinn
()
class
Test2DReduce1
(
Test1DReduce
):
class
Test2DReduce1
(
Test1DReduce
):
...
@@ -1015,6 +1022,7 @@ class Test2DReduce1(Test1DReduce):
...
@@ -1015,6 +1022,7 @@ class Test2DReduce1(Test1DReduce):
self
.
outputs
=
{
self
.
outputs
=
{
'Out'
:
self
.
inputs
[
'X'
].
sum
(
axis
=
tuple
(
self
.
attrs
[
'dim'
]))
'Out'
:
self
.
inputs
[
'X'
].
sum
(
axis
=
tuple
(
self
.
attrs
[
'dim'
]))
}
}
self
.
if_enable_cinn
()
class
Test3DReduce0
(
Test1DReduce
):
class
Test3DReduce0
(
Test1DReduce
):
...
@@ -1028,6 +1036,7 @@ class Test3DReduce0(Test1DReduce):
...
@@ -1028,6 +1036,7 @@ class Test3DReduce0(Test1DReduce):
self
.
outputs
=
{
self
.
outputs
=
{
'Out'
:
self
.
inputs
[
'X'
].
sum
(
axis
=
tuple
(
self
.
attrs
[
'dim'
]))
'Out'
:
self
.
inputs
[
'X'
].
sum
(
axis
=
tuple
(
self
.
attrs
[
'dim'
]))
}
}
self
.
if_enable_cinn
()
class
Test3DReduce1
(
Test1DReduce
):
class
Test3DReduce1
(
Test1DReduce
):
...
@@ -1041,6 +1050,7 @@ class Test3DReduce1(Test1DReduce):
...
@@ -1041,6 +1050,7 @@ class Test3DReduce1(Test1DReduce):
self
.
outputs
=
{
self
.
outputs
=
{
'Out'
:
self
.
inputs
[
'X'
].
sum
(
axis
=
tuple
(
self
.
attrs
[
'dim'
]))
'Out'
:
self
.
inputs
[
'X'
].
sum
(
axis
=
tuple
(
self
.
attrs
[
'dim'
]))
}
}
self
.
if_enable_cinn
()
class
Test3DReduce2
(
Test1DReduce
):
class
Test3DReduce2
(
Test1DReduce
):
...
@@ -1054,6 +1064,7 @@ class Test3DReduce2(Test1DReduce):
...
@@ -1054,6 +1064,7 @@ class Test3DReduce2(Test1DReduce):
self
.
outputs
=
{
self
.
outputs
=
{
'Out'
:
self
.
inputs
[
'X'
].
sum
(
axis
=
tuple
(
self
.
attrs
[
'dim'
]))
'Out'
:
self
.
inputs
[
'X'
].
sum
(
axis
=
tuple
(
self
.
attrs
[
'dim'
]))
}
}
self
.
if_enable_cinn
()
class
Test3DReduce3
(
Test1DReduce
):
class
Test3DReduce3
(
Test1DReduce
):
...
@@ -1067,6 +1078,7 @@ class Test3DReduce3(Test1DReduce):
...
@@ -1067,6 +1078,7 @@ class Test3DReduce3(Test1DReduce):
self
.
outputs
=
{
self
.
outputs
=
{
'Out'
:
self
.
inputs
[
'X'
].
sum
(
axis
=
tuple
(
self
.
attrs
[
'dim'
]))
'Out'
:
self
.
inputs
[
'X'
].
sum
(
axis
=
tuple
(
self
.
attrs
[
'dim'
]))
}
}
self
.
if_enable_cinn
()
def
reduce_sum_wrapper2
(
x
,
axis
=
[
0
],
dtype
=
None
,
keepdim
=
False
):
def
reduce_sum_wrapper2
(
x
,
axis
=
[
0
],
dtype
=
None
,
keepdim
=
False
):
...
@@ -1105,6 +1117,7 @@ class TestKeepDimReduce(Test1DReduce):
...
@@ -1105,6 +1117,7 @@ class TestKeepDimReduce(Test1DReduce):
axis
=
tuple
(
self
.
attrs
[
'dim'
]),
keepdims
=
self
.
attrs
[
'keep_dim'
]
axis
=
tuple
(
self
.
attrs
[
'dim'
]),
keepdims
=
self
.
attrs
[
'keep_dim'
]
)
)
}
}
self
.
if_enable_cinn
()
class
TestKeepDimReduceForEager
(
Test1DReduce
):
class
TestKeepDimReduceForEager
(
Test1DReduce
):
...
@@ -1208,6 +1221,10 @@ class TestKeepDimReduceSumMultiAxises(OpTest):
...
@@ -1208,6 +1221,10 @@ class TestKeepDimReduceSumMultiAxises(OpTest):
axis
=
tuple
(
self
.
attrs
[
'dim'
]),
keepdims
=
True
axis
=
tuple
(
self
.
attrs
[
'dim'
]),
keepdims
=
True
)
)
}
}
self
.
if_enable_cinn
()
def
if_enable_cinn
(
self
):
pass
def
test_check_output
(
self
):
def
test_check_output
(
self
):
self
.
check_output
()
self
.
check_output
()
...
@@ -1248,7 +1265,10 @@ class TestReduceSumWithDimOne(OpTest):
...
@@ -1248,7 +1265,10 @@ class TestReduceSumWithDimOne(OpTest):
axis
=
tuple
(
self
.
attrs
[
'dim'
]),
keepdims
=
True
axis
=
tuple
(
self
.
attrs
[
'dim'
]),
keepdims
=
True
)
)
}
}
self
.
enable_cinn
=
True
self
.
if_enable_cinn
()
def
if_enable_cinn
(
self
):
pass
def
test_check_output
(
self
):
def
test_check_output
(
self
):
self
.
check_output
()
self
.
check_output
()
...
@@ -1290,7 +1310,10 @@ class TestReduceSumWithNumelOne(OpTest):
...
@@ -1290,7 +1310,10 @@ class TestReduceSumWithNumelOne(OpTest):
axis
=
tuple
(
self
.
attrs
[
'dim'
]),
keepdims
=
False
axis
=
tuple
(
self
.
attrs
[
'dim'
]),
keepdims
=
False
)
)
}
}
self
.
enable_cinn
=
True
self
.
if_enable_cinn
()
def
if_enable_cinn
(
self
):
pass
def
test_check_output
(
self
):
def
test_check_output
(
self
):
self
.
check_output
()
self
.
check_output
()
...
@@ -1314,7 +1337,10 @@ class TestReduceAll(OpTest):
...
@@ -1314,7 +1337,10 @@ class TestReduceAll(OpTest):
self
.
inputs
=
{
'X'
:
np
.
random
.
random
((
100
,
1
,
1
)).
astype
(
"float64"
)}
self
.
inputs
=
{
'X'
:
np
.
random
.
random
((
100
,
1
,
1
)).
astype
(
"float64"
)}
self
.
attrs
=
{
'reduce_all'
:
True
,
'keep_dim'
:
False
}
self
.
attrs
=
{
'reduce_all'
:
True
,
'keep_dim'
:
False
}
self
.
outputs
=
{
'Out'
:
self
.
inputs
[
'X'
].
sum
()}
self
.
outputs
=
{
'Out'
:
self
.
inputs
[
'X'
].
sum
()}
self
.
enable_cinn
=
True
self
.
if_enable_cinn
()
def
if_enable_cinn
(
self
):
pass
def
test_check_output
(
self
):
def
test_check_output
(
self
):
self
.
check_output
()
self
.
check_output
()
...
@@ -1332,7 +1358,10 @@ class TestReduceAllFp32(OpTest):
...
@@ -1332,7 +1358,10 @@ class TestReduceAllFp32(OpTest):
self
.
inputs
=
{
'X'
:
np
.
random
.
random
((
100
,
1
,
1
)).
astype
(
"float32"
)}
self
.
inputs
=
{
'X'
:
np
.
random
.
random
((
100
,
1
,
1
)).
astype
(
"float32"
)}
self
.
attrs
=
{
'reduce_all'
:
True
,
'keep_dim'
:
False
}
self
.
attrs
=
{
'reduce_all'
:
True
,
'keep_dim'
:
False
}
self
.
outputs
=
{
'Out'
:
self
.
inputs
[
'X'
].
sum
()}
self
.
outputs
=
{
'Out'
:
self
.
inputs
[
'X'
].
sum
()}
self
.
enable_cinn
=
True
self
.
if_enable_cinn
()
def
if_enable_cinn
(
self
):
pass
def
test_check_output
(
self
):
def
test_check_output
(
self
):
self
.
check_output
()
self
.
check_output
()
...
@@ -1350,7 +1379,10 @@ class Test1DReduceWithAxes1(OpTest):
...
@@ -1350,7 +1379,10 @@ class Test1DReduceWithAxes1(OpTest):
self
.
inputs
=
{
'X'
:
np
.
random
.
random
(
100
).
astype
(
"float64"
)}
self
.
inputs
=
{
'X'
:
np
.
random
.
random
(
100
).
astype
(
"float64"
)}
self
.
attrs
=
{
'dim'
:
[
0
],
'keep_dim'
:
False
}
self
.
attrs
=
{
'dim'
:
[
0
],
'keep_dim'
:
False
}
self
.
outputs
=
{
'Out'
:
self
.
inputs
[
'X'
].
sum
(
axis
=
0
)}
self
.
outputs
=
{
'Out'
:
self
.
inputs
[
'X'
].
sum
(
axis
=
0
)}
self
.
enable_cinn
=
True
self
.
if_enable_cinn
()
def
if_enable_cinn
(
self
):
pass
def
test_check_output
(
self
):
def
test_check_output
(
self
):
self
.
check_output
()
self
.
check_output
()
...
@@ -1380,6 +1412,10 @@ class TestReduceWithDtype(OpTest):
...
@@ -1380,6 +1412,10 @@ class TestReduceWithDtype(OpTest):
'out_dtype'
:
int
(
convert_np_dtype_to_dtype_
(
np
.
float64
)),
'out_dtype'
:
int
(
convert_np_dtype_to_dtype_
(
np
.
float64
)),
}
}
)
)
self
.
if_enable_cinn
()
def
if_enable_cinn
(
self
):
pass
def
test_check_output
(
self
):
def
test_check_output
(
self
):
self
.
check_output
()
self
.
check_output
()
...
...
test/legacy_test/test_slice_op.py
浏览文件 @
585f1136
...
@@ -531,9 +531,8 @@ class TestBF16(OpTest):
...
@@ -531,9 +531,8 @@ class TestBF16(OpTest):
def
test_check_output
(
self
):
def
test_check_output
(
self
):
self
.
check_output
()
self
.
check_output
()
# pad not support bfloat16, so we can't test prim.
def
test_check_grad_normal
(
self
):
def
test_check_grad_normal
(
self
):
self
.
check_grad
([
'Input'
],
'Out'
)
self
.
check_grad
([
'Input'
],
'Out'
,
check_prim
=
True
)
# Test python API
# Test python API
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录