Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
64b3f2f6
P
Paddle
项目概览
PaddlePaddle
/
Paddle
大约 1 年 前同步成功
通知
2298
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
64b3f2f6
编写于
1月 19, 2023
作者:
F
Feiyu Chan
提交者:
GitHub
1月 19, 2023
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
add test for zero dimensional tensor for real, imag, angle, conj, as_real and sequence_pad (#49921)
上级
11e34ae0
变更
2
显示空白变更内容
内联
并排
Showing
2 changed file
with
188 addition
and
0 deletion
+188
-0
paddle/fluid/operators/sequence_ops/sequence_pad_op.cc
paddle/fluid/operators/sequence_ops/sequence_pad_op.cc
+1
-0
python/paddle/fluid/tests/unittests/test_zero_dim_tensor.py
python/paddle/fluid/tests/unittests/test_zero_dim_tensor.py
+187
-0
未找到文件。
paddle/fluid/operators/sequence_ops/sequence_pad_op.cc
浏览文件 @
64b3f2f6
...
...
@@ -56,6 +56,7 @@ class SequencePadOp : public framework::OperatorWithKernel {
auto
pad_value_dims
=
ctx
->
GetInputDim
(
"PadValue"
);
PADDLE_ENFORCE_EQ
(
pad_value_dims
==
phi
::
make_ddim
({
1
})
||
pad_value_dims
==
phi
::
make_ddim
({})
||
pad_value_dims
==
time_step_dims
,
true
,
platform
::
errors
::
InvalidArgument
(
...
...
python/paddle/fluid/tests/unittests/test_zero_dim_tensor.py
浏览文件 @
64b3f2f6
...
...
@@ -2282,6 +2282,21 @@ class TestSundryAPIStatic(unittest.TestCase):
self
.
assertEqual
(
res
[
1
].
shape
,
())
self
.
assertEqual
(
res
[
2
].
shape
,
())
@
prog_scope
()
def
test_sequence_pad
(
self
):
x
=
paddle
.
static
.
data
(
"x"
,
[
-
1
,
2
],
dtype
=
paddle
.
int64
,
lod_level
=
1
)
value
=
paddle
.
to_tensor
(
1000
,
dtype
=
paddle
.
int64
).
squeeze
()
out
=
paddle
.
static
.
nn
.
sequence_pad
(
x
,
value
)
x_tensor
=
paddle
.
fluid
.
create_lod_tensor
(
np
.
arange
(
20
).
astype
(
np
.
int64
).
reshape
(
-
1
,
2
),
[[
3
,
3
,
4
]],
place
=
self
.
exe
.
place
,
)
prog
=
paddle
.
static
.
default_main_program
()
res
=
self
.
exe
.
run
(
prog
,
feed
=
{
"x"
:
x_tensor
},
fetch_list
=
[
out
])
self
.
assertEqual
(
res
[
0
].
shape
,
(
3
,
4
,
2
))
# Use to test API whose zero-dim input tensors don't have grad and not need to test backward in OpTest.
class
TestNoBackwardAPI
(
unittest
.
TestCase
):
...
...
@@ -2633,5 +2648,177 @@ class TestNoBackwardAPIStatic(unittest.TestCase):
self
.
assertEqual
(
res
[
0
][
2
],
1
)
unary_apis_with_complex_input
=
[
paddle
.
real
,
paddle
.
imag
,
paddle
.
angle
,
paddle
.
conj
,
]
class
TestUnaryElementwiseAPIWithComplexInput
(
unittest
.
TestCase
):
def
test_dygraph_unary
(
self
):
paddle
.
disable_static
()
for
api
in
unary_apis_with_complex_input
:
x
=
paddle
.
to_tensor
(
2.0
+
3.0j
).
squeeze
()
x
.
stop_gradient
=
False
x
.
retain_grads
()
out
=
api
(
x
)
out
.
retain_grads
()
out
.
backward
()
self
.
assertEqual
(
x
.
shape
,
[])
self
.
assertEqual
(
out
.
shape
,
[])
if
x
.
grad
is
not
None
:
self
.
assertEqual
(
x
.
grad
.
shape
,
[])
self
.
assertEqual
(
out
.
grad
.
shape
,
[])
paddle
.
enable_static
()
def
test_static_unary
(
self
):
paddle
.
enable_static
()
for
api
in
unary_apis_with_complex_input
:
main_prog
=
paddle
.
static
.
Program
()
block
=
main_prog
.
global_block
()
exe
=
paddle
.
static
.
Executor
()
with
paddle
.
static
.
program_guard
(
main_prog
,
paddle
.
static
.
Program
()
):
# before full support for complex, we cannot create complex tensor with the same code as in dynamic graph
x
=
paddle
.
complex
(
paddle
.
to_tensor
(
2.0
),
paddle
.
to_tensor
(
2.0
)
).
squeeze
()
x
.
stop_gradient
=
False
out
=
api
(
x
)
# TODO(zhouwei):
# ScaleLossGradOp / append_backward set grad shape to [1]
# after output 0D, may change it to []
# use out.sum() to avoid this two problem now
loss
=
out
.
sum
()
paddle
.
static
.
append_backward
(
loss
)
fetch_list
=
[
x
,
out
]
if
block
.
has_var
(
x
.
grad_name
):
fetch_list
.
extend
([
x
.
grad_name
,
out
.
grad_name
])
# 1) Test Program
res
=
exe
.
run
(
main_prog
,
fetch_list
=
fetch_list
)
for
item
in
res
:
self
.
assertEqual
(
item
.
shape
,
())
# 2) Test CompiledProgram Program
if
paddle
.
device
.
is_compiled_with_cuda
():
places
=
[
paddle
.
CUDAPlace
(
0
)]
expect_shape
=
()
else
:
places
=
[
paddle
.
CPUPlace
()]
*
4
expect_shape
=
(
4
,)
compile_prog
=
paddle
.
static
.
CompiledProgram
(
main_prog
).
with_data_parallel
(
loss
.
name
,
places
=
places
)
# return_merged=False #
res
=
exe
.
run
(
compile_prog
,
fetch_list
=
fetch_list
,
return_merged
=
False
)
for
item1
in
res
:
for
item2
in
item1
:
self
.
assertEqual
(
item2
.
shape
,
())
# return_merged=True #
res
=
exe
.
run
(
compile_prog
,
fetch_list
=
fetch_list
,
return_merged
=
True
)
for
item
in
res
:
self
.
assertEqual
(
item
.
shape
,
expect_shape
)
paddle
.
disable_static
()
class
TestAsReal
(
unittest
.
TestCase
):
def
test_dygraph
(
self
):
paddle
.
disable_static
()
for
api
in
unary_apis_with_complex_input
:
x
=
paddle
.
to_tensor
(
2.0
+
3.0j
).
squeeze
()
x
.
stop_gradient
=
False
x
.
retain_grads
()
out
=
paddle
.
as_real
(
x
)
out
.
retain_grads
()
out
.
backward
()
self
.
assertEqual
(
x
.
shape
,
[])
self
.
assertEqual
(
out
.
shape
,
[
2
])
if
x
.
grad
is
not
None
:
self
.
assertEqual
(
x
.
grad
.
shape
,
[])
self
.
assertEqual
(
out
.
grad
.
shape
,
[
2
])
paddle
.
enable_static
()
def
test_static
(
self
):
paddle
.
enable_static
()
for
api
in
unary_apis_with_complex_input
:
main_prog
=
paddle
.
static
.
Program
()
block
=
main_prog
.
global_block
()
exe
=
paddle
.
static
.
Executor
()
with
paddle
.
static
.
program_guard
(
main_prog
,
paddle
.
static
.
Program
()
):
# before full support for complex, we cannot create complex tensor with the same code as in dynamic graph
x
=
paddle
.
complex
(
paddle
.
to_tensor
(
2.0
),
paddle
.
to_tensor
(
2.0
)
).
squeeze
()
x
.
stop_gradient
=
False
out
=
paddle
.
as_real
(
x
)
self
.
assertEqual
(
x
.
shape
,
())
self
.
assertEqual
(
out
.
shape
,
(
2
,))
# TODO(zhouwei):
# ScaleLossGradOp / append_backward set grad shape to [1]
# after output 0D, may change it to []
# use out.sum() to avoid this two problem now
loss
=
out
.
abs
().
sum
()
paddle
.
static
.
append_backward
(
loss
)
fetch_list
=
[
x
,
out
]
if
block
.
has_var
(
x
.
grad_name
):
fetch_list
.
extend
([
x
.
grad_name
,
out
.
grad_name
])
# 1) Test Program
res
=
exe
.
run
(
main_prog
,
fetch_list
=
fetch_list
)
self
.
assertEqual
(
res
[
0
].
shape
,
())
self
.
assertEqual
(
res
[
1
].
shape
,
(
2
,))
self
.
assertEqual
(
res
[
2
].
shape
,
())
self
.
assertEqual
(
res
[
3
].
shape
,
(
2
,))
# 2) Test CompiledProgram Program
if
paddle
.
device
.
is_compiled_with_cuda
():
places
=
[
paddle
.
CUDAPlace
(
0
)]
expect_shapes
=
(),
(
2
,),
(),
(
2
,)
else
:
places
=
[
paddle
.
CPUPlace
()]
*
4
expect_shapes
=
(
4
,),
(
8
,),
(
4
,),
(
8
,)
compile_prog
=
paddle
.
static
.
CompiledProgram
(
main_prog
).
with_data_parallel
(
loss
.
name
,
places
=
places
)
# return_merged=False #
res
=
exe
.
run
(
compile_prog
,
fetch_list
=
fetch_list
,
return_merged
=
False
)
for
out_i
,
expect
in
zip
(
res
,
[(),
(
2
,),
(),
(
2
,)]):
for
replica
in
out_i
:
self
.
assertEqual
(
replica
.
shape
,
expect
)
# return_merged=True #
res
=
exe
.
run
(
compile_prog
,
fetch_list
=
fetch_list
,
return_merged
=
True
)
for
actual
,
expect
in
zip
(
res
,
expect_shapes
):
self
.
assertEqual
(
actual
.
shape
,
expect
)
paddle
.
disable_static
()
if
__name__
==
"__main__"
:
unittest
.
main
()
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录