Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
cad2e68d
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
cad2e68d
编写于
11月 02, 2022
作者:
zhouweiwei2014
提交者:
GitHub
11月 02, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
[Zero-Dim] support input 0D Tensor for some binary api (#46909)
上级
623dce83
变更
19
隐藏空白更改
内联
并排
Showing
19 changed file
with
503 addition
and
11 deletion
+503
-11
paddle/fluid/operators/common_infer_shape_functions.cc
paddle/fluid/operators/common_infer_shape_functions.cc
+1
-1
paddle/fluid/operators/elementwise/elementwise_npu.h
paddle/fluid/operators/elementwise/elementwise_npu.h
+1
-1
paddle/phi/kernels/funcs/common_shape.h
paddle/phi/kernels/funcs/common_shape.h
+1
-1
paddle/phi/kernels/funcs/elementwise_base.h
paddle/phi/kernels/funcs/elementwise_base.h
+2
-2
paddle/phi/kernels/funcs/elementwise_grad_base.h
paddle/phi/kernels/funcs/elementwise_grad_base.h
+2
-2
paddle/phi/kernels/xpu/elementwise.h
paddle/phi/kernels/xpu/elementwise.h
+2
-2
python/paddle/fluid/tests/unittests/test_bitwise_op.py
python/paddle/fluid/tests/unittests/test_bitwise_op.py
+59
-1
python/paddle/fluid/tests/unittests/test_compare_op.py
python/paddle/fluid/tests/unittests/test_compare_op.py
+48
-0
python/paddle/fluid/tests/unittests/test_elementwise_add_op.py
...n/paddle/fluid/tests/unittests/test_elementwise_add_op.py
+21
-0
python/paddle/fluid/tests/unittests/test_elementwise_div_op.py
...n/paddle/fluid/tests/unittests/test_elementwise_div_op.py
+36
-0
python/paddle/fluid/tests/unittests/test_elementwise_floordiv_op.py
...dle/fluid/tests/unittests/test_elementwise_floordiv_op.py
+21
-0
python/paddle/fluid/tests/unittests/test_elementwise_max_op.py
...n/paddle/fluid/tests/unittests/test_elementwise_max_op.py
+30
-0
python/paddle/fluid/tests/unittests/test_elementwise_min_op.py
...n/paddle/fluid/tests/unittests/test_elementwise_min_op.py
+30
-0
python/paddle/fluid/tests/unittests/test_elementwise_mod_op.py
...n/paddle/fluid/tests/unittests/test_elementwise_mod_op.py
+21
-0
python/paddle/fluid/tests/unittests/test_elementwise_mul_op.py
...n/paddle/fluid/tests/unittests/test_elementwise_mul_op.py
+21
-0
python/paddle/fluid/tests/unittests/test_elementwise_pow_op.py
...n/paddle/fluid/tests/unittests/test_elementwise_pow_op.py
+33
-0
python/paddle/fluid/tests/unittests/test_elementwise_sub_op.py
...n/paddle/fluid/tests/unittests/test_elementwise_sub_op.py
+30
-0
python/paddle/fluid/tests/unittests/test_logical_op.py
python/paddle/fluid/tests/unittests/test_logical_op.py
+4
-1
python/paddle/fluid/tests/unittests/test_zero_dim_shape.py
python/paddle/fluid/tests/unittests/test_zero_dim_shape.py
+140
-0
未找到文件。
paddle/fluid/operators/common_infer_shape_functions.cc
浏览文件 @
cad2e68d
...
...
@@ -40,7 +40,7 @@ inline void GetBroadcastDimsArrays(const framework::DDim &x_dims,
platform
::
errors
::
InvalidArgument
(
"Axis should be great than or equal to 0, but received axis is %d."
,
axis
));
PADDLE_ENFORCE_L
T
(
axis
,
PADDLE_ENFORCE_L
E
(
axis
,
max_dim
,
platform
::
errors
::
InvalidArgument
(
"Axis should be less than %d, but received axis is %d."
,
...
...
paddle/fluid/operators/elementwise/elementwise_npu.h
浏览文件 @
cad2e68d
...
...
@@ -123,7 +123,7 @@ void NpuElementWiseOpBroadcast(const platform::NPUDeviceContext& dev_ctx,
platform
::
errors
::
InvalidArgument
(
"Axis should be great than or equal to 0, but received axis is %d."
,
axis
));
PADDLE_ENFORCE_L
T
(
axis
,
PADDLE_ENFORCE_L
E
(
axis
,
max_dim
,
platform
::
errors
::
InvalidArgument
(
"Axis should be less than %d, but received axis is %d."
,
...
...
paddle/phi/kernels/funcs/common_shape.h
浏览文件 @
cad2e68d
...
...
@@ -45,7 +45,7 @@ inline void GetBroadcastDimsArrays(const DDim &x_dims,
phi
::
errors
::
InvalidArgument
(
"Axis should be great than or equal to 0, but received axis is %d."
,
axis
));
PADDLE_ENFORCE_L
T
(
axis
,
PADDLE_ENFORCE_L
E
(
axis
,
max_dim
,
phi
::
errors
::
InvalidArgument
(
"Axis should be less than %d, but received axis is %d."
,
...
...
paddle/phi/kernels/funcs/elementwise_base.h
浏览文件 @
cad2e68d
...
...
@@ -326,7 +326,7 @@ void CommonElementwiseBroadcastForward(const CPUContext &dev_ctx,
phi
::
errors
::
InvalidArgument
(
"Axis should be great than or equal to 0, but received axis is %d."
,
axis
));
PADDLE_ENFORCE_L
T
(
axis
,
PADDLE_ENFORCE_L
E
(
axis
,
max_dim
,
phi
::
errors
::
InvalidArgument
(
"Axis should be less than %d, but received axis is %d."
,
...
...
@@ -394,7 +394,7 @@ void ElementwiseCompute(const CPUContext &dev_ctx,
errors
::
InvalidArgument
(
"Axis should be great than or equal to 0, but received axis is %d."
,
axis
));
PADDLE_ENFORCE_L
T
(
axis
,
PADDLE_ENFORCE_L
E
(
axis
,
max_dim
,
errors
::
InvalidArgument
(
"Axis should be less than %d, but received axis is %d."
,
...
...
paddle/phi/kernels/funcs/elementwise_grad_base.h
浏览文件 @
cad2e68d
...
...
@@ -287,7 +287,7 @@ void ElemwiseGradComputeWithBroadcast(const CPUContext &ctx,
errors
::
InvalidArgument
(
"Axis should be great than or equal to 0, but received axis is %d."
,
axis
));
PADDLE_ENFORCE_L
T
(
axis
,
PADDLE_ENFORCE_L
E
(
axis
,
max_dim
,
errors
::
InvalidArgument
(
"Axis should be less than %d, but received axis is %d."
,
...
...
@@ -1725,7 +1725,7 @@ void ElemwiseGradComputeWithBroadcast(const GPUContext &ctx,
errors
::
InvalidArgument
(
"Axis should be great than or equal to 0, but received axis is %d."
,
axis
));
PADDLE_ENFORCE_L
T
(
axis
,
PADDLE_ENFORCE_L
E
(
axis
,
max_dim
,
errors
::
InvalidArgument
(
"Axis should be less than %d, but received axis is %d."
,
...
...
paddle/phi/kernels/xpu/elementwise.h
浏览文件 @
cad2e68d
...
...
@@ -51,7 +51,7 @@ void XPUElementwise(const XPUContext& dev_ctx,
errors
::
InvalidArgument
(
"Axis should be great than or equal to 0, but received axis is %d."
,
axis
));
PADDLE_ENFORCE_L
T
(
axis
,
PADDLE_ENFORCE_L
E
(
axis
,
max_dim
,
errors
::
InvalidArgument
(
"Axis should be less than %d, but received axis is %d."
,
...
...
@@ -121,7 +121,7 @@ void XPUElementwiseGrad(const XPUContext& dev_ctx,
errors
::
InvalidArgument
(
"Axis should be great than or equal to 0, but received axis is %d."
,
axis
));
PADDLE_ENFORCE_L
T
(
axis
,
PADDLE_ENFORCE_L
E
(
axis
,
max_dim
,
errors
::
InvalidArgument
(
"Axis should be less than %d, but received axis is %d."
,
...
...
python/paddle/fluid/tests/unittests/test_bitwise_op.py
浏览文件 @
cad2e68d
...
...
@@ -57,6 +57,24 @@ class TestBitwiseAnd(OpTest):
self
.
high
=
100
class
TestBitwiseAnd_ZeroDim1
(
TestBitwiseAnd
):
def
init_shape
(
self
):
self
.
x_shape
=
[]
self
.
y_shape
=
[]
class
TestBitwiseAnd_ZeroDim2
(
TestBitwiseAnd
):
def
init_shape
(
self
):
self
.
x_shape
=
[
2
,
3
,
4
,
5
]
self
.
y_shape
=
[]
class
TestBitwiseAnd_ZeroDim3
(
TestBitwiseAnd
):
def
init_shape
(
self
):
self
.
x_shape
=
[]
self
.
y_shape
=
[
2
,
3
,
4
,
5
]
class
TestBitwiseAndUInt8
(
TestBitwiseAnd
):
def
init_dtype
(
self
):
self
.
dtype
=
np
.
uint8
...
...
@@ -143,6 +161,24 @@ class TestBitwiseOr(OpTest):
self
.
high
=
100
class
TestBitwiseOr_ZeroDim1
(
TestBitwiseOr
):
def
init_shape
(
self
):
self
.
x_shape
=
[]
self
.
y_shape
=
[]
class
TestBitwiseOr_ZeroDim2
(
TestBitwiseOr
):
def
init_shape
(
self
):
self
.
x_shape
=
[
2
,
3
,
4
,
5
]
self
.
y_shape
=
[]
class
TestBitwiseOr_ZeroDim3
(
TestBitwiseOr
):
def
init_shape
(
self
):
self
.
x_shape
=
[]
self
.
y_shape
=
[
2
,
3
,
4
,
5
]
class
TestBitwiseOrUInt8
(
TestBitwiseOr
):
def
init_dtype
(
self
):
self
.
dtype
=
np
.
uint8
...
...
@@ -229,6 +265,24 @@ class TestBitwiseXor(OpTest):
self
.
high
=
100
class
TestBitwiseXor_ZeroDim1
(
TestBitwiseXor
):
def
init_shape
(
self
):
self
.
x_shape
=
[]
self
.
y_shape
=
[]
class
TestBitwiseXor_ZeroDim2
(
TestBitwiseXor
):
def
init_shape
(
self
):
self
.
x_shape
=
[
2
,
3
,
4
,
5
]
self
.
y_shape
=
[]
class
TestBitwiseXor_ZeroDim3
(
TestBitwiseXor
):
def
init_shape
(
self
):
self
.
x_shape
=
[]
self
.
y_shape
=
[
2
,
3
,
4
,
5
]
class
TestBitwiseXorUInt8
(
TestBitwiseXor
):
def
init_dtype
(
self
):
self
.
dtype
=
np
.
uint8
...
...
@@ -311,6 +365,11 @@ class TestBitwiseNot(OpTest):
self
.
high
=
100
class
TestBitwiseNot_ZeroDim
(
TestBitwiseNot
):
def
init_shape
(
self
):
self
.
x_shape
=
[]
class
TestBitwiseNotUInt8
(
TestBitwiseNot
):
def
init_dtype
(
self
):
self
.
dtype
=
np
.
uint8
...
...
@@ -334,7 +393,6 @@ class TestBitwiseNotInt16(TestBitwiseNot):
def
init_shape
(
self
):
self
.
x_shape
=
[
2
,
3
,
4
,
5
]
self
.
y_shape
=
[
4
,
1
]
class
TestBitwiseNotInt64
(
TestBitwiseNot
):
...
...
python/paddle/fluid/tests/unittests/test_compare_op.py
浏览文件 @
cad2e68d
...
...
@@ -283,6 +283,54 @@ def create_paddle_case(op_type, callback):
self
.
assertEqual
((
out
.
numpy
()
==
self
.
real_result
).
all
(),
True
)
paddle
.
enable_static
()
def
test_zero_dim_api_1
(
self
):
paddle
.
enable_static
()
with
program_guard
(
Program
(),
Program
()):
x
=
paddle
.
randint
(
-
3
,
3
,
shape
=
[],
dtype
=
'int32'
)
y
=
paddle
.
randint
(
-
3
,
3
,
shape
=
[],
dtype
=
'int32'
)
op
=
eval
(
"paddle.%s"
%
(
self
.
op_type
))
out
=
op
(
x
,
y
)
exe
=
paddle
.
static
.
Executor
(
self
.
place
)
(
x_np
,
y_np
,
res
,
)
=
exe
.
run
(
fetch_list
=
[
x
,
y
,
out
])
real_result
=
callback
(
x_np
,
y_np
)
self
.
assertEqual
((
res
==
real_result
).
all
(),
True
)
def
test_zero_dim_api_2
(
self
):
paddle
.
enable_static
()
with
program_guard
(
Program
(),
Program
()):
x
=
paddle
.
randint
(
-
3
,
3
,
shape
=
[
2
,
3
,
4
],
dtype
=
'int32'
)
y
=
paddle
.
randint
(
-
3
,
3
,
shape
=
[],
dtype
=
'int32'
)
op
=
eval
(
"paddle.%s"
%
(
self
.
op_type
))
out
=
op
(
x
,
y
)
exe
=
paddle
.
static
.
Executor
(
self
.
place
)
(
x_np
,
y_np
,
res
,
)
=
exe
.
run
(
fetch_list
=
[
x
,
y
,
out
])
real_result
=
callback
(
x_np
,
y_np
)
self
.
assertEqual
((
res
==
real_result
).
all
(),
True
)
def
test_zero_dim_api_3
(
self
):
paddle
.
enable_static
()
with
program_guard
(
Program
(),
Program
()):
x
=
paddle
.
randint
(
-
3
,
3
,
shape
=
[],
dtype
=
'int32'
)
y
=
paddle
.
randint
(
-
3
,
3
,
shape
=
[
2
,
3
,
4
],
dtype
=
'int32'
)
op
=
eval
(
"paddle.%s"
%
(
self
.
op_type
))
out
=
op
(
x
,
y
)
exe
=
paddle
.
static
.
Executor
(
self
.
place
)
(
x_np
,
y_np
,
res
,
)
=
exe
.
run
(
fetch_list
=
[
x
,
y
,
out
])
real_result
=
callback
(
x_np
,
y_np
)
self
.
assertEqual
((
res
==
real_result
).
all
(),
True
)
def
test_broadcast_api_1
(
self
):
paddle
.
enable_static
()
with
program_guard
(
Program
(),
Program
()):
...
...
python/paddle/fluid/tests/unittests/test_elementwise_add_op.py
浏览文件 @
cad2e68d
...
...
@@ -102,6 +102,27 @@ class TestElementwiseAddOp(OpTest):
self
.
axis
=
-
1
class
TestElementwiseAddOp_ZeroDim1
(
TestElementwiseAddOp
):
def
init_input_output
(
self
):
self
.
x
=
np
.
random
.
uniform
(
0.1
,
1
,
[]).
astype
(
self
.
dtype
)
self
.
y
=
np
.
random
.
uniform
(
0.1
,
1
,
[]).
astype
(
self
.
dtype
)
self
.
out
=
np
.
add
(
self
.
x
,
self
.
y
)
class
TestElementwiseAddOp_ZeroDim2
(
TestElementwiseAddOp
):
def
init_input_output
(
self
):
self
.
x
=
np
.
random
.
uniform
(
0.1
,
1
,
[]).
astype
(
self
.
dtype
)
self
.
y
=
np
.
random
.
uniform
(
0.1
,
1
,
[
13
,
17
]).
astype
(
self
.
dtype
)
self
.
out
=
np
.
add
(
self
.
x
,
self
.
y
)
class
TestElementwiseAddOp_ZeroDim3
(
TestElementwiseAddOp
):
def
init_input_output
(
self
):
self
.
x
=
np
.
random
.
uniform
(
0.1
,
1
,
[
13
,
17
]).
astype
(
self
.
dtype
)
self
.
y
=
np
.
random
.
uniform
(
0.1
,
1
,
[]).
astype
(
self
.
dtype
)
self
.
out
=
np
.
add
(
self
.
x
,
self
.
y
)
@
unittest
.
skipIf
(
not
core
.
is_compiled_with_cuda
(),
"core is not compiled with CUDA"
)
...
...
python/paddle/fluid/tests/unittests/test_elementwise_div_op.py
浏览文件 @
cad2e68d
...
...
@@ -112,6 +112,42 @@ class ElementwiseDivOp(OpTest):
self
.
check_grad_with_place
(
*
check_args
,
**
check_kwargs
)
class
TestElementwiseDivOp_ZeroDim1
(
ElementwiseDivOp
):
def
init_shape
(
self
):
self
.
x_shape
=
[]
self
.
y_shape
=
[]
class
TestElementwiseDivOp_ZeroDim2
(
ElementwiseDivOp
):
def
init_shape
(
self
):
self
.
x_shape
=
[
13
,
17
]
self
.
y_shape
=
[]
def
compute_output
(
self
,
x
,
y
):
return
x
/
y
.
reshape
([
1
,
1
])
def
compute_gradient_x
(
self
,
grad_out
,
y
):
return
grad_out
/
y
.
reshape
([
1
,
1
])
def
compute_gradient_y
(
self
,
grad_out
,
out
,
y
):
return
np
.
sum
(
-
1
*
grad_out
*
out
/
y
.
reshape
([
1
,
1
]))
class
TestElementwiseDivOp_ZeroDim3
(
ElementwiseDivOp
):
def
init_shape
(
self
):
self
.
x_shape
=
[]
self
.
y_shape
=
[
13
,
17
]
def
compute_output
(
self
,
x
,
y
):
return
x
.
reshape
([
1
,
1
])
/
y
def
compute_gradient_x
(
self
,
grad_out
,
y
):
return
np
.
sum
(
grad_out
/
y
)
def
compute_gradient_y
(
self
,
grad_out
,
out
,
y
):
return
-
1
*
grad_out
*
out
/
y
@
unittest
.
skipIf
(
not
core
.
is_compiled_with_cuda
()
or
not
core
.
is_bfloat16_supported
(
core
.
CUDAPlace
(
0
)),
...
...
python/paddle/fluid/tests/unittests/test_elementwise_floordiv_op.py
浏览文件 @
cad2e68d
...
...
@@ -57,6 +57,27 @@ class TestElementwiseModOp(OpTest):
pass
class
TestElementwiseFloorDivOp_ZeroDim1
(
TestElementwiseModOp
):
def
init_input_output
(
self
):
self
.
x
=
np
.
random
.
uniform
(
0
,
10000
,
[]).
astype
(
self
.
dtype
)
self
.
y
=
np
.
random
.
uniform
(
0
,
1000
,
[]).
astype
(
self
.
dtype
)
self
.
out
=
np
.
floor_divide
(
self
.
x
,
self
.
y
)
class
TestElementwiseFloorDivOp_ZeroDim2
(
TestElementwiseModOp
):
def
init_input_output
(
self
):
self
.
x
=
np
.
random
.
uniform
(
0
,
10000
,
[
10
,
10
]).
astype
(
self
.
dtype
)
self
.
y
=
np
.
random
.
uniform
(
0
,
1000
,
[]).
astype
(
self
.
dtype
)
self
.
out
=
np
.
floor_divide
(
self
.
x
,
self
.
y
)
class
TestElementwiseFloorDivOp_ZeroDim3
(
TestElementwiseModOp
):
def
init_input_output
(
self
):
self
.
x
=
np
.
random
.
uniform
(
0
,
10000
,
[]).
astype
(
self
.
dtype
)
self
.
y
=
np
.
random
.
uniform
(
0
,
1000
,
[
10
,
10
]).
astype
(
self
.
dtype
)
self
.
out
=
np
.
floor_divide
(
self
.
x
,
self
.
y
)
class
TestElementwiseModOp_scalar
(
TestElementwiseModOp
):
def
init_input_output
(
self
):
scale_x
=
random
.
randint
(
0
,
100000000
)
...
...
python/paddle/fluid/tests/unittests/test_elementwise_max_op.py
浏览文件 @
cad2e68d
...
...
@@ -55,6 +55,36 @@ class TestElementwiseOp(OpTest):
)
class
TestElementwiseMaxOp_ZeroDim1
(
TestElementwiseOp
):
def
setUp
(
self
):
self
.
op_type
=
"elementwise_max"
self
.
python_api
=
paddle
.
maximum
x
=
np
.
random
.
uniform
(
0.1
,
1
,
[]).
astype
(
"float64"
)
y
=
np
.
random
.
uniform
(
0.1
,
1
,
[]).
astype
(
"float64"
)
self
.
inputs
=
{
'X'
:
x
,
'Y'
:
y
}
self
.
outputs
=
{
'Out'
:
np
.
maximum
(
self
.
inputs
[
'X'
],
self
.
inputs
[
'Y'
])}
class
TestElementwiseMaxOp_ZeroDim2
(
TestElementwiseOp
):
def
setUp
(
self
):
self
.
op_type
=
"elementwise_max"
self
.
python_api
=
paddle
.
maximum
x
=
np
.
random
.
uniform
(
0.1
,
1
,
[
13
,
17
]).
astype
(
"float64"
)
y
=
np
.
random
.
uniform
(
0.1
,
1
,
[]).
astype
(
"float64"
)
self
.
inputs
=
{
'X'
:
x
,
'Y'
:
y
}
self
.
outputs
=
{
'Out'
:
np
.
maximum
(
self
.
inputs
[
'X'
],
self
.
inputs
[
'Y'
])}
class
TestElementwiseMaxOp_ZeroDim3
(
TestElementwiseOp
):
def
setUp
(
self
):
self
.
op_type
=
"elementwise_max"
self
.
python_api
=
paddle
.
maximum
x
=
np
.
random
.
uniform
(
0.1
,
1
,
[]).
astype
(
"float64"
)
y
=
np
.
random
.
uniform
(
0.1
,
1
,
[
13
,
17
]).
astype
(
"float64"
)
self
.
inputs
=
{
'X'
:
x
,
'Y'
:
y
}
self
.
outputs
=
{
'Out'
:
np
.
maximum
(
self
.
inputs
[
'X'
],
self
.
inputs
[
'Y'
])}
@
unittest
.
skipIf
(
core
.
is_compiled_with_cuda
()
and
(
...
...
python/paddle/fluid/tests/unittests/test_elementwise_min_op.py
浏览文件 @
cad2e68d
...
...
@@ -58,6 +58,36 @@ class TestElementwiseOp(OpTest):
)
class
TestElementwiseMinOp_ZeroDim1
(
TestElementwiseOp
):
def
setUp
(
self
):
self
.
op_type
=
"elementwise_min"
self
.
python_api
=
paddle
.
minimum
x
=
np
.
random
.
uniform
(
0.1
,
1
,
[]).
astype
(
"float64"
)
y
=
np
.
random
.
uniform
(
0.1
,
1
,
[]).
astype
(
"float64"
)
self
.
inputs
=
{
'X'
:
x
,
'Y'
:
y
}
self
.
outputs
=
{
'Out'
:
np
.
minimum
(
self
.
inputs
[
'X'
],
self
.
inputs
[
'Y'
])}
class
TestElementwiseMinOp_ZeroDim2
(
TestElementwiseOp
):
def
setUp
(
self
):
self
.
op_type
=
"elementwise_min"
self
.
python_api
=
paddle
.
minimum
x
=
np
.
random
.
uniform
(
0.1
,
1
,
[
13
,
17
]).
astype
(
"float64"
)
y
=
np
.
random
.
uniform
(
0.1
,
1
,
[]).
astype
(
"float64"
)
self
.
inputs
=
{
'X'
:
x
,
'Y'
:
y
}
self
.
outputs
=
{
'Out'
:
np
.
minimum
(
self
.
inputs
[
'X'
],
self
.
inputs
[
'Y'
])}
class
TestElementwiseMinOp_ZeroDim3
(
TestElementwiseOp
):
def
setUp
(
self
):
self
.
op_type
=
"elementwise_min"
self
.
python_api
=
paddle
.
minimum
x
=
np
.
random
.
uniform
(
0.1
,
1
,
[]).
astype
(
"float64"
)
y
=
np
.
random
.
uniform
(
0.1
,
1
,
[
13
,
17
]).
astype
(
"float64"
)
self
.
inputs
=
{
'X'
:
x
,
'Y'
:
y
}
self
.
outputs
=
{
'Out'
:
np
.
minimum
(
self
.
inputs
[
'X'
],
self
.
inputs
[
'Y'
])}
@
skip_check_grad_ci
(
reason
=
"[skip shape check] Use y_shape(1) to test broadcast."
)
...
...
python/paddle/fluid/tests/unittests/test_elementwise_mod_op.py
浏览文件 @
cad2e68d
...
...
@@ -59,6 +59,27 @@ class TestElementwiseModOp(OpTest):
pass
class
TestElementwiseModOp_ZeroDim1
(
TestElementwiseModOp
):
def
init_input_output
(
self
):
self
.
x
=
np
.
random
.
uniform
(
0
,
10000
,
[]).
astype
(
self
.
dtype
)
self
.
y
=
np
.
random
.
uniform
(
0
,
1000
,
[]).
astype
(
self
.
dtype
)
self
.
out
=
np
.
mod
(
self
.
x
,
self
.
y
)
class
TestElementwiseModOp_ZeroDim2
(
TestElementwiseModOp
):
def
init_input_output
(
self
):
self
.
x
=
np
.
random
.
uniform
(
0
,
10000
,
[
10
,
10
]).
astype
(
self
.
dtype
)
self
.
y
=
np
.
random
.
uniform
(
0
,
1000
,
[]).
astype
(
self
.
dtype
)
self
.
out
=
np
.
mod
(
self
.
x
,
self
.
y
)
class
TestElementwiseModOp_ZeroDim3
(
TestElementwiseModOp
):
def
init_input_output
(
self
):
self
.
x
=
np
.
random
.
uniform
(
0
,
10000
,
[]).
astype
(
self
.
dtype
)
self
.
y
=
np
.
random
.
uniform
(
0
,
1000
,
[
10
,
10
]).
astype
(
self
.
dtype
)
self
.
out
=
np
.
mod
(
self
.
x
,
self
.
y
)
class
TestElementwiseModOp_scalar
(
TestElementwiseModOp
):
def
init_input_output
(
self
):
scale_x
=
random
.
randint
(
0
,
100000000
)
...
...
python/paddle/fluid/tests/unittests/test_elementwise_mul_op.py
浏览文件 @
cad2e68d
...
...
@@ -85,6 +85,27 @@ class ElementwiseMulOp(OpTest):
pass
class
TestElementwiseMulOp_ZeroDim1
(
ElementwiseMulOp
):
def
init_input_output
(
self
):
self
.
x
=
np
.
random
.
uniform
(
0.1
,
1
,
[]).
astype
(
self
.
dtype
)
self
.
y
=
np
.
random
.
uniform
(
0.1
,
1
,
[]).
astype
(
self
.
dtype
)
self
.
out
=
np
.
multiply
(
self
.
x
,
self
.
y
)
class
TestElementwiseMulOp_ZeroDim2
(
ElementwiseMulOp
):
def
init_input_output
(
self
):
self
.
x
=
np
.
random
.
uniform
(
0.1
,
1
,
[
13
,
17
]).
astype
(
self
.
dtype
)
self
.
y
=
np
.
random
.
uniform
(
0.1
,
1
,
[]).
astype
(
self
.
dtype
)
self
.
out
=
np
.
multiply
(
self
.
x
,
self
.
y
)
class
TestElementwiseMulOp_ZeroDim3
(
ElementwiseMulOp
):
def
init_input_output
(
self
):
self
.
x
=
np
.
random
.
uniform
(
0.1
,
1
,
[]).
astype
(
self
.
dtype
)
self
.
y
=
np
.
random
.
uniform
(
0.1
,
1
,
[
13
,
17
]).
astype
(
self
.
dtype
)
self
.
out
=
np
.
multiply
(
self
.
x
,
self
.
y
)
class
TestBF16ElementwiseMulOp
(
OpTest
):
def
setUp
(
self
):
self
.
op_type
=
"elementwise_mul"
...
...
python/paddle/fluid/tests/unittests/test_elementwise_pow_op.py
浏览文件 @
cad2e68d
...
...
@@ -48,6 +48,39 @@ class TestElementwisePowOp(OpTest):
self
.
check_grad
([
'X'
,
'Y'
],
'Out'
,
check_eager
=
True
)
class
TestElementwisePowOp_ZeroDim1
(
TestElementwisePowOp
):
def
setUp
(
self
):
self
.
op_type
=
"elementwise_pow"
self
.
python_api
=
paddle
.
pow
self
.
inputs
=
{
'X'
:
np
.
random
.
uniform
(
1
,
2
,
[]).
astype
(
"float64"
),
'Y'
:
np
.
random
.
uniform
(
1
,
2
,
[]).
astype
(
"float64"
),
}
self
.
outputs
=
{
'Out'
:
np
.
power
(
self
.
inputs
[
'X'
],
self
.
inputs
[
'Y'
])}
class
TestElementwisePowOp_ZeroDim2
(
TestElementwisePowOp
):
def
setUp
(
self
):
self
.
op_type
=
"elementwise_pow"
self
.
python_api
=
paddle
.
pow
self
.
inputs
=
{
'X'
:
np
.
random
.
uniform
(
1
,
2
,
[
20
,
5
]).
astype
(
"float64"
),
'Y'
:
np
.
random
.
uniform
(
1
,
2
,
[]).
astype
(
"float64"
),
}
self
.
outputs
=
{
'Out'
:
np
.
power
(
self
.
inputs
[
'X'
],
self
.
inputs
[
'Y'
])}
class
TestElementwisePowOp_ZeroDim3
(
TestElementwisePowOp
):
def
setUp
(
self
):
self
.
op_type
=
"elementwise_pow"
self
.
python_api
=
paddle
.
pow
self
.
inputs
=
{
'X'
:
np
.
random
.
uniform
(
1
,
2
,
[]).
astype
(
"float64"
),
'Y'
:
np
.
random
.
uniform
(
1
,
2
,
[
20
,
5
]).
astype
(
"float64"
),
}
self
.
outputs
=
{
'Out'
:
np
.
power
(
self
.
inputs
[
'X'
],
self
.
inputs
[
'Y'
])}
class
TestElementwisePowOp_big_shape_1
(
TestElementwisePowOp
):
def
setUp
(
self
):
self
.
op_type
=
"elementwise_pow"
...
...
python/paddle/fluid/tests/unittests/test_elementwise_sub_op.py
浏览文件 @
cad2e68d
...
...
@@ -46,6 +46,36 @@ class TestElementwiseOp(OpTest):
)
class
TestElementwiseSubOp_ZeroDim1
(
TestElementwiseOp
):
def
setUp
(
self
):
self
.
op_type
=
"elementwise_sub"
self
.
inputs
=
{
'X'
:
np
.
random
.
uniform
(
0.1
,
1
,
[]).
astype
(
"float64"
),
'Y'
:
np
.
random
.
uniform
(
0.1
,
1
,
[]).
astype
(
"float64"
),
}
self
.
outputs
=
{
'Out'
:
self
.
inputs
[
'X'
]
-
self
.
inputs
[
'Y'
]}
class
TestElementwiseSubOp_ZeroDim2
(
TestElementwiseOp
):
def
setUp
(
self
):
self
.
op_type
=
"elementwise_sub"
self
.
inputs
=
{
'X'
:
np
.
random
.
uniform
(
0.1
,
1
,
[
2
,
3
,
4
,
5
]).
astype
(
"float64"
),
'Y'
:
np
.
random
.
uniform
(
0.1
,
1
,
[]).
astype
(
"float64"
),
}
self
.
outputs
=
{
'Out'
:
self
.
inputs
[
'X'
]
-
self
.
inputs
[
'Y'
]}
class
TestElementwiseSubOp_ZeroDim3
(
TestElementwiseOp
):
def
setUp
(
self
):
self
.
op_type
=
"elementwise_sub"
self
.
inputs
=
{
'X'
:
np
.
random
.
uniform
(
0.1
,
1
,
[]).
astype
(
"float64"
),
'Y'
:
np
.
random
.
uniform
(
0.1
,
1
,
[
2
,
3
,
4
,
5
]).
astype
(
"float64"
),
}
self
.
outputs
=
{
'Out'
:
self
.
inputs
[
'X'
]
-
self
.
inputs
[
'Y'
]}
class
TestBF16ElementwiseOp
(
OpTest
):
def
setUp
(
self
):
self
.
op_type
=
"elementwise_sub"
...
...
python/paddle/fluid/tests/unittests/test_logical_op.py
浏览文件 @
cad2e68d
...
...
@@ -50,6 +50,9 @@ TEST_META_SHAPE_DATA = {
'Axis1InLargerDim'
:
{
'x_shape'
:
[
1
,
4
,
5
],
'y_shape'
:
[
2
,
3
,
1
,
5
]},
'EqualDim1'
:
{
'x_shape'
:
[
10
,
7
],
'y_shape'
:
[
10
,
7
]},
'EqualDim2'
:
{
'x_shape'
:
[
1
,
1
,
4
,
5
],
'y_shape'
:
[
2
,
3
,
1
,
5
]},
'ZeroDim1'
:
{
'x_shape'
:
[],
'y_shape'
:
[]},
'ZeroDim2'
:
{
'x_shape'
:
[
2
,
3
,
4
,
5
],
'y_shape'
:
[]},
'ZeroDim3'
:
{
'x_shape'
:
[],
'y_shape'
:
[
2
,
3
,
4
,
5
]},
}
TEST_META_WRONG_SHAPE_DATA
=
{
...
...
@@ -116,7 +119,7 @@ def np_data_generator(np_shape, dtype, *args, **kwargs):
if
dtype
==
bool
:
return
np
.
random
.
choice
(
a
=
[
True
,
False
],
size
=
np_shape
).
astype
(
bool
)
else
:
return
np
.
random
.
randn
(
*
np_shape
).
astype
(
dtype
)
return
np
.
random
.
normal
(
0
,
1
,
np_shape
).
astype
(
dtype
)
def
test
(
unit_test
,
use_gpu
=
False
,
test_error
=
False
):
...
...
python/paddle/fluid/tests/unittests/test_zero_dim_shape.py
浏览文件 @
cad2e68d
...
...
@@ -210,5 +210,145 @@ class TestReduceAPI(unittest.TestCase):
paddle
.
disable_static
()
binary_api_list
=
[
{
'func'
:
paddle
.
add
,
'cls_method'
:
'__add__'
},
{
'func'
:
paddle
.
subtract
,
'cls_method'
:
'__sub__'
},
{
'func'
:
paddle
.
multiply
,
'cls_method'
:
'__mul__'
},
{
'func'
:
paddle
.
divide
,
'cls_method'
:
'__div__'
},
{
'func'
:
paddle
.
subtract
,
'cls_method'
:
'__sub__'
},
paddle
.
pow
,
]
binary_api_list_without_grad
=
[
{
'func'
:
paddle
.
add
,
'cls_method'
:
'__add__'
},
{
'func'
:
paddle
.
subtract
,
'cls_method'
:
'__sub__'
},
{
'func'
:
paddle
.
multiply
,
'cls_method'
:
'__mul__'
},
{
'func'
:
paddle
.
divide
,
'cls_method'
:
'__div__'
},
{
'func'
:
paddle
.
subtract
,
'cls_method'
:
'__sub__'
},
paddle
.
pow
,
{
'func'
:
paddle
.
mod
,
'cls_method'
:
'__mod__'
},
paddle
.
floor_mod
,
paddle
.
remainder
,
{
'func'
:
paddle
.
equal
,
'cls_method'
:
'__eq__'
},
{
'func'
:
paddle
.
not_equal
,
'cls_method'
:
'__ne__'
},
{
'func'
:
paddle
.
greater_equal
,
'cls_method'
:
'__ge__'
},
{
'func'
:
paddle
.
greater_than
,
'cls_method'
:
'__gt__'
},
{
'func'
:
paddle
.
less_equal
,
'cls_method'
:
'__le__'
},
{
'func'
:
paddle
.
less_than
,
'cls_method'
:
'__lt__'
},
paddle
.
logical_and
,
paddle
.
logical_or
,
paddle
.
logical_xor
,
]
class
TestBinaryAPI
(
unittest
.
TestCase
):
def
test_dygraph_binary
(
self
):
paddle
.
disable_static
()
fluid
.
set_flags
({
"FLAGS_retain_grad_for_all_tensor"
:
True
})
for
api
in
binary_api_list
+
binary_api_list_without_grad
:
# 1) x/y is 0D
x
=
paddle
.
rand
([])
y
=
paddle
.
rand
([])
x
.
stop_gradient
=
False
y
.
stop_gradient
=
False
if
isinstance
(
api
,
dict
):
out
=
api
[
'func'
](
x
,
y
)
out_cls
=
getattr
(
paddle
.
Tensor
,
api
[
'cls_method'
])(
x
,
y
)
np
.
testing
.
assert_array_equal
(
out_cls
.
numpy
(),
out
.
numpy
())
else
:
out
=
api
(
x
,
y
)
self
.
assertEqual
(
x
.
shape
,
[])
self
.
assertEqual
(
y
.
shape
,
[])
self
.
assertEqual
(
out
.
shape
,
[])
if
api
not
in
binary_api_list_without_grad
:
out
.
backward
()
self
.
assertEqual
(
x
.
grad
.
shape
,
[])
self
.
assertEqual
(
y
.
grad
.
shape
,
[])
self
.
assertEqual
(
out
.
grad
.
shape
,
[])
# 2) x is not 0D , y is 0D
x
=
paddle
.
rand
([
2
,
3
,
4
])
y
=
paddle
.
rand
([])
x
.
stop_gradient
=
False
y
.
stop_gradient
=
False
if
isinstance
(
api
,
dict
):
out
=
api
[
'func'
](
x
,
y
)
out_cls
=
getattr
(
paddle
.
Tensor
,
api
[
'cls_method'
])(
x
,
y
)
np
.
testing
.
assert_array_equal
(
out_cls
.
numpy
(),
out
.
numpy
())
else
:
out
=
api
(
x
,
y
)
self
.
assertEqual
(
x
.
shape
,
[
2
,
3
,
4
])
self
.
assertEqual
(
y
.
shape
,
[])
self
.
assertEqual
(
out
.
shape
,
[
2
,
3
,
4
])
if
api
not
in
binary_api_list_without_grad
:
out
.
backward
()
self
.
assertEqual
(
x
.
grad
.
shape
,
[
2
,
3
,
4
])
self
.
assertEqual
(
y
.
grad
.
shape
,
[])
self
.
assertEqual
(
out
.
grad
.
shape
,
[
2
,
3
,
4
])
# 3) x is 0D , y is not 0D
x
=
paddle
.
rand
([])
y
=
paddle
.
rand
([
2
,
3
,
4
])
x
.
stop_gradient
=
False
y
.
stop_gradient
=
False
if
isinstance
(
api
,
dict
):
out
=
api
[
'func'
](
x
,
y
)
out_cls
=
getattr
(
paddle
.
Tensor
,
api
[
'cls_method'
])(
x
,
y
)
np
.
testing
.
assert_array_equal
(
out_cls
.
numpy
(),
out
.
numpy
())
else
:
out
=
api
(
x
,
y
)
out
.
backward
()
self
.
assertEqual
(
x
.
shape
,
[])
self
.
assertEqual
(
y
.
shape
,
[
2
,
3
,
4
])
self
.
assertEqual
(
out
.
shape
,
[
2
,
3
,
4
])
if
api
not
in
binary_api_list_without_grad
:
out
.
backward
()
self
.
assertEqual
(
x
.
grad
.
shape
,
[])
self
.
assertEqual
(
y
.
grad
.
shape
,
[
2
,
3
,
4
])
self
.
assertEqual
(
out
.
grad
.
shape
,
[
2
,
3
,
4
])
paddle
.
enable_static
()
def
test_static_unary
(
self
):
paddle
.
enable_static
()
for
api
in
binary_api_list
:
main_prog
=
fluid
.
Program
()
with
fluid
.
program_guard
(
main_prog
,
fluid
.
Program
()):
x
=
paddle
.
rand
([])
y
=
paddle
.
rand
([])
x
.
stop_gradient
=
False
y
.
stop_gradient
=
False
if
isinstance
(
api
,
dict
):
out
=
api
[
'func'
](
x
,
y
)
else
:
out
=
api
(
x
,
y
)
fluid
.
backward
.
append_backward
(
out
)
# append_backward always set grad shape to [1]
prog
=
paddle
.
static
.
default_main_program
()
block
=
prog
.
global_block
()
# Test compile shape
self
.
assertEqual
(
x
.
shape
,
())
self
.
assertEqual
(
y
.
shape
,
())
self
.
assertEqual
(
out
.
shape
,
())
exe
=
fluid
.
Executor
()
result
=
exe
.
run
(
main_prog
,
fetch_list
=
[
x
,
y
,
out
])
# Test runtime shape
self
.
assertEqual
(
result
[
0
].
shape
,
())
self
.
assertEqual
(
result
[
1
].
shape
,
())
self
.
assertEqual
(
result
[
2
].
shape
,
())
paddle
.
disable_static
()
if
__name__
==
"__main__"
:
unittest
.
main
()
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录