Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
Crayon鑫
Paddle
提交
ae1f3209
P
Paddle
项目概览
Crayon鑫
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
ae1f3209
编写于
1月 13, 2021
作者:
C
Chen Weihang
提交者:
GitHub
1月 14, 2021
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
fix prune input bug (#30384)
上级
cf786d22
变更
3
隐藏空白更改
内联
并排
Showing
3 changed file
with
109 addition
and
32 deletion
+109
-32
python/paddle/fluid/dygraph/dygraph_to_static/program_translator.py
...dle/fluid/dygraph/dygraph_to_static/program_translator.py
+9
-6
python/paddle/fluid/dygraph/dygraph_to_static/utils.py
python/paddle/fluid/dygraph/dygraph_to_static/utils.py
+26
-22
python/paddle/fluid/tests/unittests/test_jit_save_load.py
python/paddle/fluid/tests/unittests/test_jit_save_load.py
+74
-4
未找到文件。
python/paddle/fluid/dygraph/dygraph_to_static/program_translator.py
浏览文件 @
ae1f3209
...
@@ -470,19 +470,22 @@ class StaticFunction(object):
...
@@ -470,19 +470,22 @@ class StaticFunction(object):
cached_program_len
=
len
(
self
.
_program_cache
)
cached_program_len
=
len
(
self
.
_program_cache
)
# If specific `input_spec`, apply convertion from dygraph layers into static Program.
# If specific `input_spec`, apply convertion from dygraph layers into static Program.
if
cached_program_len
==
0
:
if
cached_program_len
==
0
:
if
input_spec
is
None
:
desired_input_spec
=
input_spec
input_spec
=
self
.
_function_spec
.
input_spec
if
self
.
_function_spec
.
input_spec
is
not
None
:
elif
self
.
_function_spec
.
input_spec
is
not
None
:
if
input_spec
is
not
None
and
not
input_specs_compatible
(
if
not
input_specs_compatible
(
flatten
(
input_spec
),
flatten
(
input_spec
),
flatten
(
self
.
_function_spec
.
input_spec
)):
flatten
(
self
.
_function_spec
.
input_spec
)):
raise
ValueError
(
raise
ValueError
(
"The `input_spec`: {} used to construct concrete_program is conflict with the `input_spec`: {} in `@paddle.jit.to_static`"
.
"The `input_spec`: {} used to construct concrete_program is conflict with the `input_spec`: {} in `@paddle.jit.to_static`"
.
format
(
input_spec
,
self
.
_function_spec
.
input_spec
))
format
(
input_spec
,
self
.
_function_spec
.
input_spec
))
# NOTE(chenweihang): we should always translated program based on the `input_spec`
# decorated on forward if it is valid
desired_input_spec
=
self
.
_function_spec
.
input_spec
has_input_spec
=
(
input_spec
is
not
None
)
has_input_spec
=
(
desired_
input_spec
is
not
None
)
if
has_input_spec
:
if
has_input_spec
:
concrete_program
,
_
=
self
.
get_concrete_program
(
*
input_spec
)
concrete_program
,
_
=
self
.
get_concrete_program
(
*
desired_input_spec
)
return
concrete_program
return
concrete_program
else
:
else
:
raise
ValueError
(
raise
ValueError
(
...
...
python/paddle/fluid/dygraph/dygraph_to_static/utils.py
浏览文件 @
ae1f3209
...
@@ -1222,37 +1222,41 @@ def unwrap(func):
...
@@ -1222,37 +1222,41 @@ def unwrap(func):
return
unwrapped_f
return
unwrapped_f
def
input_specs_compatible
(
src_input_specs
,
other
_input_specs
):
def
input_specs_compatible
(
src_input_specs
,
desired
_input_specs
):
"""
"""
Returns True if the two input specs are compatible, otherwise False.
Returns True if the two input specs are compatible, otherwise False.
args:
args:
src_input_spec (list[InputSpec]|tuple(InputSpec)): list/tuple of
src_input_spec (list[InputSpec]|tuple(InputSpec)): list/tuple of
paddle.static.InputSpec
paddle.static.InputSpec
other_input_spec
(list[InputSpec]|tuple(InputSpec)): list/tuple of
desired_input_specs
(list[InputSpec]|tuple(InputSpec)): list/tuple of
paddle.static.InputSpec
paddle.static.InputSpec
"""
"""
len_specs
=
len
(
src_input_specs
)
len_specs
=
len
(
src_input_specs
)
if
len_specs
!=
len
(
other_input_specs
):
if
len_specs
!=
len
(
desired_input_specs
):
return
False
# NOTE(chenweihang): if the input_spec of jit.save is a subset of
# input_spec of to_static, also compatible
for
i
in
range
(
len_specs
):
for
spec
in
src_input_specs
:
src_shape
=
src_input_specs
[
i
].
shape
if
spec
not
in
desired_input_specs
:
other_shape
=
other_input_specs
[
i
].
shape
return
False
len_shape
=
len
(
src_shape
)
else
:
if
len_shape
!=
len
(
other_shape
):
for
i
in
range
(
len_specs
):
return
False
src_shape
=
src_input_specs
[
i
].
shape
for
j
in
range
(
len_shape
):
other_shape
=
desired_input_specs
[
i
].
shape
if
src_shape
[
j
]
is
None
or
src_shape
[
j
]
<
0
:
len_shape
=
len
(
src_shape
)
continue
if
len_shape
!=
len
(
other_shape
):
if
other_shape
[
j
]
is
None
or
other_shape
[
j
]
<
0
:
return
False
continue
for
j
in
range
(
len_shape
):
if
src_shape
[
j
]
!=
other_shape
[
j
]:
if
src_shape
[
j
]
is
None
or
src_shape
[
j
]
<
0
:
continue
if
other_shape
[
j
]
is
None
or
other_shape
[
j
]
<
0
:
continue
if
src_shape
[
j
]
!=
other_shape
[
j
]:
return
False
src_dtype
=
convert_dtype
(
src_input_specs
[
i
].
dtype
)
other_dtype
=
convert_dtype
(
desired_input_specs
[
i
].
dtype
)
if
src_dtype
!=
other_dtype
:
return
False
return
False
src_dtype
=
convert_dtype
(
src_input_specs
[
i
].
dtype
)
other_dtype
=
convert_dtype
(
other_input_specs
[
i
].
dtype
)
if
src_dtype
!=
other_dtype
:
return
False
return
True
return
True
python/paddle/fluid/tests/unittests/test_jit_save_load.py
浏览文件 @
ae1f3209
...
@@ -95,6 +95,38 @@ class LinerNetWithLabel(paddle.nn.Layer):
...
@@ -95,6 +95,38 @@ class LinerNetWithLabel(paddle.nn.Layer):
return
out
,
avg_loss
return
out
,
avg_loss
class
LinerNetWithPruneInput
(
paddle
.
nn
.
Layer
):
def
__init__
(
self
,
in_size
,
out_size
):
super
(
LinerNetWithPruneInput
,
self
).
__init__
()
self
.
_linear
=
Linear
(
in_size
,
out_size
)
@
declarative
(
input_spec
=
[
InputSpec
(
shape
=
[
None
,
784
],
dtype
=
'float32'
,
name
=
"image"
),
InputSpec
(
shape
=
[
None
,
1
],
dtype
=
'int64'
,
name
=
"label"
)
])
def
forward
(
self
,
x
,
label
):
out
=
self
.
_linear
(
x
)
loss
=
fluid
.
layers
.
cross_entropy
(
out
,
label
)
avg_loss
=
fluid
.
layers
.
mean
(
loss
)
return
out
class
LinerNetWithUselessInput
(
paddle
.
nn
.
Layer
):
def
__init__
(
self
,
in_size
,
out_size
):
super
(
LinerNetWithUselessInput
,
self
).
__init__
()
self
.
_linear
=
Linear
(
in_size
,
out_size
)
@
declarative
(
input_spec
=
[
InputSpec
(
shape
=
[
None
,
784
],
dtype
=
'float32'
,
name
=
"image"
),
InputSpec
(
shape
=
[
None
,
1
],
dtype
=
'int64'
,
name
=
"label"
)
])
def
forward
(
self
,
x
,
label
):
out
=
self
.
_linear
(
x
)
return
out
class
LinearNetReturnLoss
(
fluid
.
dygraph
.
Layer
):
class
LinearNetReturnLoss
(
fluid
.
dygraph
.
Layer
):
def
__init__
(
self
,
in_size
,
out_size
):
def
__init__
(
self
,
in_size
,
out_size
):
super
(
LinearNetReturnLoss
,
self
).
__init__
()
super
(
LinearNetReturnLoss
,
self
).
__init__
()
...
@@ -627,16 +659,24 @@ class TestJitSaveMultiCases(unittest.TestCase):
...
@@ -627,16 +659,24 @@ class TestJitSaveMultiCases(unittest.TestCase):
paddle
.
seed
(
SEED
)
paddle
.
seed
(
SEED
)
paddle
.
framework
.
random
.
_manual_program_seed
(
SEED
)
paddle
.
framework
.
random
.
_manual_program_seed
(
SEED
)
def
verify_inference_correctness
(
self
,
layer
,
model_path
,
with_label
=
False
):
def
verify_inference_correctness
(
self
,
layer
,
model_path
,
with_label_and_loss
=
False
,
with_label
=
False
):
layer
.
eval
()
layer
.
eval
()
loaded_layer
=
paddle
.
jit
.
load
(
model_path
)
loaded_layer
=
paddle
.
jit
.
load
(
model_path
)
loaded_layer
.
eval
()
loaded_layer
.
eval
()
# inference & compare
# inference & compare
x
=
paddle
.
to_tensor
(
np
.
random
.
random
((
1
,
784
)).
astype
(
'float32'
))
x
=
paddle
.
to_tensor
(
np
.
random
.
random
((
1
,
784
)).
astype
(
'float32'
))
if
with_label
:
if
with_label
_and_loss
:
y
=
paddle
.
to_tensor
(
np
.
random
.
random
((
1
,
1
)).
astype
(
'int64'
))
y
=
paddle
.
to_tensor
(
np
.
random
.
random
((
1
,
1
)).
astype
(
'int64'
))
pred
,
_
=
layer
(
x
,
y
)
pred
,
_
=
layer
(
x
,
y
)
pred
=
pred
.
numpy
()
pred
=
pred
.
numpy
()
elif
with_label
:
y
=
paddle
.
to_tensor
(
np
.
random
.
random
((
1
,
1
)).
astype
(
'int64'
))
pred
=
layer
(
x
,
y
)
pred
=
pred
.
numpy
()
else
:
else
:
pred
=
layer
(
x
).
numpy
()
pred
=
layer
(
x
).
numpy
()
loaded_pred
=
loaded_layer
(
x
).
numpy
()
loaded_pred
=
loaded_layer
(
x
).
numpy
()
...
@@ -714,7 +754,8 @@ class TestJitSaveMultiCases(unittest.TestCase):
...
@@ -714,7 +754,8 @@ class TestJitSaveMultiCases(unittest.TestCase):
],
],
output_spec
=
[
out
])
output_spec
=
[
out
])
self
.
verify_inference_correctness
(
layer
,
model_path
,
True
)
self
.
verify_inference_correctness
(
layer
,
model_path
,
with_label_and_loss
=
True
)
def
test_prune_to_static_no_train
(
self
):
def
test_prune_to_static_no_train
(
self
):
layer
=
LinerNetWithLabel
(
784
,
1
)
layer
=
LinerNetWithLabel
(
784
,
1
)
...
@@ -732,7 +773,36 @@ class TestJitSaveMultiCases(unittest.TestCase):
...
@@ -732,7 +773,36 @@ class TestJitSaveMultiCases(unittest.TestCase):
],
],
output_spec
=
output_spec
)
output_spec
=
output_spec
)
self
.
verify_inference_correctness
(
layer
,
model_path
,
True
)
self
.
verify_inference_correctness
(
layer
,
model_path
,
with_label_and_loss
=
True
)
def
test_prune_input_to_static_no_train
(
self
):
layer
=
LinerNetWithPruneInput
(
784
,
1
)
model_path
=
"test_prune_input_to_static_no_train/model"
paddle
.
jit
.
save
(
layer
,
model_path
,
input_spec
=
[
InputSpec
(
shape
=
[
None
,
784
],
dtype
=
'float32'
,
name
=
"image"
)
])
self
.
verify_inference_correctness
(
layer
,
model_path
,
with_label
=
True
)
def
test_prune_useless_input_to_static_no_train
(
self
):
layer
=
LinerNetWithUselessInput
(
784
,
1
)
model_path
=
"test_prune_useless_input_to_static_no_train/model"
paddle
.
jit
.
save
(
layer
,
model_path
,
input_spec
=
[
InputSpec
(
shape
=
[
None
,
784
],
dtype
=
'float32'
,
name
=
"image"
)
])
self
.
verify_inference_correctness
(
layer
,
model_path
,
with_label
=
True
)
def
test_no_prune_input_spec_name_warning
(
self
):
def
test_no_prune_input_spec_name_warning
(
self
):
layer
=
LinearNetWithInputSpec
(
784
,
1
)
layer
=
LinearNetWithInputSpec
(
784
,
1
)
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录