Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
193d1430
P
Paddle
项目概览
PaddlePaddle
/
Paddle
1 年多 前同步成功
通知
2301
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
193d1430
编写于
4月 24, 2020
作者:
石
石晓伟
提交者:
GitHub
4月 24, 2020
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
supports loading model from memory, test=release/2.0 (#24099)
上级
8aa095ca
变更
5
隐藏空白更改
内联
并排
Showing
5 changed file
with
177 addition
and
76 deletion
+177
-76
paddle/fluid/operators/save_combine_op.cc
paddle/fluid/operators/save_combine_op.cc
+18
-1
paddle/fluid/operators/save_combine_op.h
paddle/fluid/operators/save_combine_op.h
+19
-9
paddle/fluid/pybind/pybind.cc
paddle/fluid/pybind/pybind.cc
+4
-0
python/paddle/fluid/io.py
python/paddle/fluid/io.py
+98
-49
python/paddle/fluid/tests/unittests/test_inference_model_io.py
...n/paddle/fluid/tests/unittests/test_inference_model_io.py
+38
-17
未找到文件。
paddle/fluid/operators/save_combine_op.cc
浏览文件 @
193d1430
...
@@ -71,6 +71,23 @@ to a file on disk.
...
@@ -71,6 +71,23 @@ to a file on disk.
"The
\"
file_path
\"
where the LoDTensor variables will be saved."
)
"The
\"
file_path
\"
where the LoDTensor variables will be saved."
)
.
AddCustomChecker
(
.
AddCustomChecker
(
[](
const
std
::
string
&
path
)
{
return
!
path
.
empty
();
});
[](
const
std
::
string
&
path
)
{
return
!
path
.
empty
();
});
AddAttr
<
bool
>
(
"save_to_memory"
,
"(boolean, default false)"
"If true, the variables will be saved to binary strings."
)
.
SetDefault
(
false
);
AddOutput
(
"Y"
,
"(RAW, default empty)."
"This output is used when saving variables to binary strings."
)
.
AsDispensable
();
}
};
class
SaveCombineOpInferVarType
:
public
framework
::
VarTypeInference
{
public:
void
operator
()(
framework
::
InferVarTypeContext
*
ctx
)
const
override
{
for
(
auto
&
o
:
ctx
->
Output
(
"Y"
))
{
ctx
->
SetType
(
o
,
framework
::
proto
::
VarType
::
RAW
);
}
}
}
};
};
...
@@ -80,7 +97,7 @@ to a file on disk.
...
@@ -80,7 +97,7 @@ to a file on disk.
namespace
ops
=
paddle
::
operators
;
namespace
ops
=
paddle
::
operators
;
REGISTER_OPERATOR
(
save_combine
,
ops
::
SaveCombineOp
,
REGISTER_OPERATOR
(
save_combine
,
ops
::
SaveCombineOp
,
ops
::
SaveCombineOpProtoMaker
);
ops
::
SaveCombineOpProtoMaker
,
ops
::
SaveCombineOpInferVarType
);
REGISTER_OP_CPU_KERNEL
(
REGISTER_OP_CPU_KERNEL
(
save_combine
,
save_combine
,
...
...
paddle/fluid/operators/save_combine_op.h
浏览文件 @
193d1430
...
@@ -38,6 +38,8 @@ class SaveCombineOpKernel : public framework::OpKernel<T> {
...
@@ -38,6 +38,8 @@ class SaveCombineOpKernel : public framework::OpKernel<T> {
auto
filename
=
ctx
.
Attr
<
std
::
string
>
(
"file_path"
);
auto
filename
=
ctx
.
Attr
<
std
::
string
>
(
"file_path"
);
auto
overwrite
=
ctx
.
Attr
<
bool
>
(
"overwrite"
);
auto
overwrite
=
ctx
.
Attr
<
bool
>
(
"overwrite"
);
auto
save_as_fp16
=
ctx
.
Attr
<
bool
>
(
"save_as_fp16"
);
auto
save_as_fp16
=
ctx
.
Attr
<
bool
>
(
"save_as_fp16"
);
auto
save_to_memory
=
ctx
.
Attr
<
bool
>
(
"save_to_memory"
);
auto
output
=
ctx
.
Output
<
std
::
string
>
(
"Y"
);
bool
is_present
=
FileExists
(
filename
);
bool
is_present
=
FileExists
(
filename
);
if
(
is_present
&&
!
overwrite
)
{
if
(
is_present
&&
!
overwrite
)
{
...
@@ -47,12 +49,7 @@ class SaveCombineOpKernel : public framework::OpKernel<T> {
...
@@ -47,12 +49,7 @@ class SaveCombineOpKernel : public framework::OpKernel<T> {
filename
,
overwrite
));
filename
,
overwrite
));
}
}
MkDirRecursively
(
DirName
(
filename
).
c_str
());
std
::
ostringstream
ss
;
std
::
ofstream
fout
(
filename
,
std
::
ios
::
binary
);
PADDLE_ENFORCE_EQ
(
static_cast
<
bool
>
(
fout
),
true
,
platform
::
errors
::
Unavailable
(
"Cannot open %s to save variables."
,
filename
));
auto
inp_var_names
=
ctx
.
InputNames
(
"X"
);
auto
inp_var_names
=
ctx
.
InputNames
(
"X"
);
auto
&
inp_vars
=
ctx
.
MultiInputVar
(
"X"
);
auto
&
inp_vars
=
ctx
.
MultiInputVar
(
"X"
);
PADDLE_ENFORCE_GT
(
inp_var_names
.
size
(),
0UL
,
PADDLE_ENFORCE_GT
(
inp_var_names
.
size
(),
0UL
,
...
@@ -91,12 +88,25 @@ class SaveCombineOpKernel : public framework::OpKernel<T> {
...
@@ -91,12 +88,25 @@ class SaveCombineOpKernel : public framework::OpKernel<T> {
// copy LoD info to the new tensor
// copy LoD info to the new tensor
out
.
set_lod
(
tensor
.
lod
());
out
.
set_lod
(
tensor
.
lod
());
framework
::
TransDataType
(
in_kernel_type
,
out_kernel_type
,
tensor
,
&
out
);
framework
::
TransDataType
(
in_kernel_type
,
out_kernel_type
,
tensor
,
&
out
);
framework
::
SerializeToStream
(
fout
,
out
,
dev_ctx
);
framework
::
SerializeToStream
(
ss
,
out
,
dev_ctx
);
}
else
{
}
else
{
framework
::
SerializeToStream
(
fout
,
tensor
,
dev_ctx
);
framework
::
SerializeToStream
(
ss
,
tensor
,
dev_ctx
);
}
}
}
}
fout
.
close
();
if
(
save_to_memory
)
{
PADDLE_ENFORCE_NE
(
output
,
nullptr
,
platform
::
errors
::
InvalidArgument
(
"Cannot find variable Y for save_combine_op"
));
*
output
=
ss
.
str
();
}
else
{
MkDirRecursively
(
DirName
(
filename
).
c_str
());
std
::
ofstream
fout
(
filename
,
std
::
ios
::
binary
);
PADDLE_ENFORCE_EQ
(
static_cast
<
bool
>
(
fout
),
true
,
platform
::
errors
::
Unavailable
(
"Cannot open %s to save variables."
,
filename
));
fout
<<
ss
.
str
();
fout
.
close
();
}
}
}
};
};
...
...
paddle/fluid/pybind/pybind.cc
浏览文件 @
193d1430
...
@@ -957,6 +957,10 @@ All parameter, weight, gradient are variables in Paddle.
...
@@ -957,6 +957,10 @@ All parameter, weight, gradient are variables in Paddle.
return
self
.
GetMutable
<
LoDTensor
>
();
return
self
.
GetMutable
<
LoDTensor
>
();
},
},
py
::
return_value_policy
::
reference
)
py
::
return_value_policy
::
reference
)
.
def
(
"get_bytes"
,
[](
Variable
&
self
)
{
return
py
::
bytes
(
*
self
.
GetMutable
<
std
::
string
>
());
})
.
def
(
"get_lod_rank_table"
,
.
def
(
"get_lod_rank_table"
,
[](
Variable
&
self
)
{
return
self
.
GetMutable
<
LoDRankTable
>
();
},
[](
Variable
&
self
)
{
return
self
.
GetMutable
<
LoDRankTable
>
();
},
py
::
return_value_policy
::
reference
)
py
::
return_value_policy
::
reference
)
...
...
python/paddle/fluid/io.py
浏览文件 @
193d1430
...
@@ -36,6 +36,7 @@ from paddle.fluid.framework import Program, Parameter, default_main_program, def
...
@@ -36,6 +36,7 @@ from paddle.fluid.framework import Program, Parameter, default_main_program, def
from
paddle.fluid.compiler
import
CompiledProgram
from
paddle.fluid.compiler
import
CompiledProgram
from
paddle.fluid.log_helper
import
get_logger
from
paddle.fluid.log_helper
import
get_logger
from
.
import
reader
from
.
import
reader
from
.
import
unique_name
from
.reader
import
*
from
.reader
import
*
from
.
import
dataloader
from
.
import
dataloader
from
.dataloader
import
*
from
.dataloader
import
*
...
@@ -231,7 +232,8 @@ def save_vars(executor,
...
@@ -231,7 +232,8 @@ def save_vars(executor,
Args:
Args:
executor(Executor): The executor to run for saving variables.
executor(Executor): The executor to run for saving variables.
dirname(str): The folder where to save variables.
dirname(str, optional): The folder where to save variables.
When you need to save the parameter to the memory, set it to None.
main_program(Program, optional): The program whose variables will be saved.
main_program(Program, optional): The program whose variables will be saved.
If it is None, the default main program will
If it is None, the default main program will
be used automatically.
be used automatically.
...
@@ -246,7 +248,8 @@ def save_vars(executor,
...
@@ -246,7 +248,8 @@ def save_vars(executor,
Default: None
Default: None
Returns:
Returns:
None
str: When saving parameters to a file, returns None.
When saving parameters to memory, returns a binary string containing parameters.
Raises:
Raises:
TypeError: If `main_program` is not an instance of Program nor None.
TypeError: If `main_program` is not an instance of Program nor None.
...
@@ -283,17 +286,21 @@ def save_vars(executor,
...
@@ -283,17 +286,21 @@ def save_vars(executor,
fluid.io.save_vars(executor=exe, dirname=param_path, main_program=main_prog, vars=None, predicate = name_has_fc)
fluid.io.save_vars(executor=exe, dirname=param_path, main_program=main_prog, vars=None, predicate = name_has_fc)
# all variables whose names contain "fc " are saved.
# all variables whose names contain "fc " are saved.
"""
"""
save_dirname
=
os
.
path
.
normpath
(
dirname
)
save_to_memory
=
False
if
dirname
is
None
and
filename
is
None
:
save_to_memory
=
True
main_program
=
_get_valid_program
(
main_program
)
main_program
=
_get_valid_program
(
main_program
)
if
vars
is
None
:
if
vars
is
None
:
save_vars
(
return
save_vars
(
executor
,
executor
,
main_program
=
main_program
,
main_program
=
main_program
,
dirname
=
save_
dirname
,
dirname
=
dirname
,
vars
=
list
(
filter
(
predicate
,
main_program
.
list_vars
())),
vars
=
list
(
filter
(
predicate
,
main_program
.
list_vars
())),
filename
=
filename
)
filename
=
filename
)
else
:
else
:
params_var_name
=
unique_name
.
generate
(
"saved_params"
)
# give warning when there is no var in model
# give warning when there is no var in model
if
len
(
list
(
vars
))
==
0
:
if
len
(
list
(
vars
))
==
0
:
warnings
.
warn
(
warnings
.
warn
(
...
@@ -310,33 +317,45 @@ def save_vars(executor,
...
@@ -310,33 +317,45 @@ def save_vars(executor,
if
each_var
.
type
==
core
.
VarDesc
.
VarType
.
RAW
:
if
each_var
.
type
==
core
.
VarDesc
.
VarType
.
RAW
:
continue
continue
new_var
=
_clone_var_in_block_
(
save_block
,
each_var
)
new_var
=
_clone_var_in_block_
(
save_block
,
each_var
)
if
filename
is
None
:
if
filename
is
None
and
save_to_memory
is
False
:
save_file_path
=
os
.
path
.
join
(
save_dirname
,
new_var
.
name
)
save_file_path
=
os
.
path
.
join
(
save_file_path
=
os
.
path
.
normpath
(
save_file_path
)
os
.
path
.
normpath
(
dirname
),
new_var
.
name
)
save_block
.
append_op
(
save_block
.
append_op
(
type
=
'save'
,
type
=
'save'
,
inputs
=
{
'X'
:
[
new_var
]},
inputs
=
{
'X'
:
[
new_var
]},
outputs
=
{},
outputs
=
{},
attrs
=
{
'file_path'
:
save_file_path
})
attrs
=
{
'file_path'
:
os
.
path
.
normpath
(
save_file_path
)
})
else
:
else
:
save_var_map
[
new_var
.
name
]
=
new_var
save_var_map
[
new_var
.
name
]
=
new_var
if
filename
is
not
None
:
if
filename
is
not
None
or
save_to_memory
:
save_var_list
=
[]
save_var_list
=
[]
for
name
in
sorted
(
save_var_map
.
keys
()):
for
name
in
sorted
(
save_var_map
.
keys
()):
save_var_list
.
append
(
save_var_map
[
name
])
save_var_list
.
append
(
save_var_map
[
name
])
save_path
=
str
()
if
save_to_memory
is
False
:
save_path
=
os
.
path
.
join
(
os
.
path
.
normpath
(
dirname
),
filename
)
saved_params
=
save_block
.
create_var
(
type
=
core
.
VarDesc
.
VarType
.
RAW
,
name
=
params_var_name
)
saved_params
.
desc
.
set_persistable
(
True
)
save_block
.
append_op
(
save_block
.
append_op
(
type
=
'save_combine'
,
type
=
'save_combine'
,
inputs
=
{
'X'
:
save_var_list
},
inputs
=
{
'X'
:
save_var_list
},
outputs
=
{},
outputs
=
{
'Y'
:
saved_params
},
attrs
=
{
'file_path'
:
os
.
path
.
join
(
save_dirname
,
filename
)})
attrs
=
{
'file_path'
:
save_path
,
'save_to_memory'
:
save_to_memory
})
#NOTE(zhiqiu): save op will add variable kLookupTablePath in save_program.desc,
#NOTE(zhiqiu): save op will add variable kLookupTablePath in save_program.desc,
# which leads to diff on save_program and its desc. Call _sync_with_cpp
# which leads to diff on save_program and its desc. Call _sync_with_cpp
# to keep consistency.
# to keep consistency.
save_program
.
_sync_with_cpp
()
save_program
.
_sync_with_cpp
()
executor
.
run
(
save_program
)
executor
.
run
(
save_program
)
if
save_to_memory
:
return
global_scope
().
find_var
(
params_var_name
).
get_bytes
()
def
save_params
(
executor
,
dirname
,
main_program
=
None
,
filename
=
None
):
def
save_params
(
executor
,
dirname
,
main_program
=
None
,
filename
=
None
):
...
@@ -364,7 +383,8 @@ def save_params(executor, dirname, main_program=None, filename=None):
...
@@ -364,7 +383,8 @@ def save_params(executor, dirname, main_program=None, filename=None):
Args:
Args:
executor(Executor): The executor to run for saving parameters, You can
executor(Executor): The executor to run for saving parameters, You can
refer to :ref:`api_guide_executor_en`.
refer to :ref:`api_guide_executor_en`.
dirname(str): The saving directory path.
dirname(str, optional): The saving directory path.
When you need to save the parameter to the memory, set it to None.
main_program(Program, optional): The program whose parameters will be
main_program(Program, optional): The program whose parameters will be
saved. You can refer to
saved. You can refer to
:ref:`api_guide_Program_en` for more
:ref:`api_guide_Program_en` for more
...
@@ -377,7 +397,8 @@ def save_params(executor, dirname, main_program=None, filename=None):
...
@@ -377,7 +397,8 @@ def save_params(executor, dirname, main_program=None, filename=None):
Default: None
Default: None
Returns:
Returns:
None
str: When saving parameters to a file, returns None.
When saving parameters to memory, returns a binary string containing parameters.
Examples:
Examples:
.. code-block:: python
.. code-block:: python
...
@@ -399,7 +420,7 @@ def save_params(executor, dirname, main_program=None, filename=None):
...
@@ -399,7 +420,7 @@ def save_params(executor, dirname, main_program=None, filename=None):
# The parameters weights and bias of the fc layer in the network are going to
# The parameters weights and bias of the fc layer in the network are going to
# be saved in different files in the path "./my_paddle_model"
# be saved in different files in the path "./my_paddle_model"
"""
"""
save_vars
(
return
save_vars
(
executor
,
executor
,
dirname
=
dirname
,
dirname
=
dirname
,
main_program
=
main_program
,
main_program
=
main_program
,
...
@@ -576,8 +597,9 @@ def save_persistables(executor, dirname, main_program=None, filename=None):
...
@@ -576,8 +597,9 @@ def save_persistables(executor, dirname, main_program=None, filename=None):
executor(Executor): The executor to run for saving persistable variables.
executor(Executor): The executor to run for saving persistable variables.
You can refer to :ref:`api_guide_executor_en` for
You can refer to :ref:`api_guide_executor_en` for
more details.
more details.
dirname(str): The saving directory path.
dirname(str, optional): The saving directory path.
main_program(Program, optional): The program whose persistable variables will
When you need to save the parameter to the memory, set it to None.
main_program(Program, optional): The program whose persistbale variables will
be saved. You can refer to
be saved. You can refer to
:ref:`api_guide_Program_en` for more details.
:ref:`api_guide_Program_en` for more details.
If it is None, the default main program will
If it is None, the default main program will
...
@@ -588,7 +610,8 @@ def save_persistables(executor, dirname, main_program=None, filename=None):
...
@@ -588,7 +610,8 @@ def save_persistables(executor, dirname, main_program=None, filename=None):
Default: None.
Default: None.
Returns:
Returns:
None
str: When saving parameters to a file, returns None.
When saving parameters to memory, returns a binary string containing parameters.
Examples:
Examples:
.. code-block:: python
.. code-block:: python
...
@@ -612,10 +635,10 @@ def save_persistables(executor, dirname, main_program=None, filename=None):
...
@@ -612,10 +635,10 @@ def save_persistables(executor, dirname, main_program=None, filename=None):
# "./my_paddle_model"
# "./my_paddle_model"
"""
"""
if
main_program
and
main_program
.
_is_distributed
:
if
main_program
and
main_program
.
_is_distributed
:
_save_distributed_persistables
(
return
_save_distributed_persistables
(
executor
,
dirname
=
dirname
,
main_program
=
main_program
)
executor
,
dirname
=
dirname
,
main_program
=
main_program
)
else
:
else
:
save_vars
(
return
save_vars
(
executor
,
executor
,
dirname
=
dirname
,
dirname
=
dirname
,
main_program
=
main_program
,
main_program
=
main_program
,
...
@@ -705,7 +728,11 @@ def load_vars(executor,
...
@@ -705,7 +728,11 @@ def load_vars(executor,
# And all the variables are supposed to be saved in separate files.
# And all the variables are supposed to be saved in separate files.
"""
"""
load_dirname
=
os
.
path
.
normpath
(
dirname
)
vars_from_memory
=
False
if
dirname
is
not
None
:
dirname
=
os
.
path
.
normpath
(
dirname
)
else
:
vars_from_memory
=
True
if
vars
is
None
:
if
vars
is
None
:
if
main_program
is
None
:
if
main_program
is
None
:
...
@@ -717,7 +744,7 @@ def load_vars(executor,
...
@@ -717,7 +744,7 @@ def load_vars(executor,
load_vars
(
load_vars
(
executor
,
executor
,
dirname
=
load_
dirname
,
dirname
=
dirname
,
main_program
=
main_program
,
main_program
=
main_program
,
vars
=
list
(
filter
(
predicate
,
main_program
.
list_vars
())),
vars
=
list
(
filter
(
predicate
,
main_program
.
list_vars
())),
filename
=
filename
)
filename
=
filename
)
...
@@ -746,13 +773,15 @@ def load_vars(executor,
...
@@ -746,13 +773,15 @@ def load_vars(executor,
))
))
new_var
=
_clone_var_in_block_
(
load_block
,
each_var
)
new_var
=
_clone_var_in_block_
(
load_block
,
each_var
)
if
filename
is
None
:
if
filename
is
None
:
if
dirname
is
None
:
raise
ValueError
(
"The directory path and params cannot be None at the same time."
)
load_block
.
append_op
(
load_block
.
append_op
(
type
=
'load'
,
type
=
'load'
,
inputs
=
{},
inputs
=
{},
outputs
=
{
'Out'
:
[
new_var
]},
outputs
=
{
'Out'
:
[
new_var
]},
attrs
=
{
attrs
=
{
'file_path'
:
os
.
path
.
join
(
dirname
,
new_var
.
name
)})
'file_path'
:
os
.
path
.
join
(
load_dirname
,
new_var
.
name
)
})
else
:
else
:
load_var_map
[
new_var
.
name
]
=
new_var
load_var_map
[
new_var
.
name
]
=
new_var
...
@@ -761,11 +790,17 @@ def load_vars(executor,
...
@@ -761,11 +790,17 @@ def load_vars(executor,
for
name
in
sorted
(
load_var_map
.
keys
()):
for
name
in
sorted
(
load_var_map
.
keys
()):
load_var_list
.
append
(
load_var_map
[
name
])
load_var_list
.
append
(
load_var_map
[
name
])
if
vars_from_memory
is
False
:
filename
=
os
.
path
.
join
(
dirname
,
filename
)
load_block
.
append_op
(
load_block
.
append_op
(
type
=
'load_combine'
,
type
=
'load_combine'
,
inputs
=
{},
inputs
=
{},
outputs
=
{
"Out"
:
load_var_list
},
outputs
=
{
"Out"
:
load_var_list
},
attrs
=
{
'file_path'
:
os
.
path
.
join
(
load_dirname
,
filename
)})
attrs
=
{
'file_path'
:
filename
,
'model_from_memory'
:
vars_from_memory
})
executor
.
run
(
load_prog
)
executor
.
run
(
load_prog
)
# check var shape
# check var shape
...
@@ -1248,19 +1283,22 @@ def load_inference_model(dirname,
...
@@ -1248,19 +1283,22 @@ def load_inference_model(dirname,
You can refer to :ref:`api_guide_model_save_reader_en` for more details.
You can refer to :ref:`api_guide_model_save_reader_en` for more details.
Args:
Args:
dirname(str): The given directory path.
dirname(str): One of the following:
- The given directory path.
- Set to None when reading the model from memory.
executor(Executor): The executor to run for loading inference model.
executor(Executor): The executor to run for loading inference model.
See :ref:`api_guide_executor_en` for more details about it.
See :ref:`api_guide_executor_en` for more details about it.
model_filename(str, optional): The name of file to load the inference program.
model_filename(str, optional): One of the following:
If it is None, the default filename
- The name of file to load the inference program.
``__model__`` will be used.
- If it is None, the default filename ``__model__`` will be used.
Default: ``None``.
- When ``dirname`` is ``None``, it must be set to a string containing model.
params_filename(str, optional): The name of file to load all parameters.
Default: ``None``.
It is only used for the case that all
params_filename(str, optional): It is only used for the case that all
parameters were saved in a single binary
parameters were saved in a single binary file. One of the following:
file. If parameters were saved in separate
- The name of file to load all parameters.
files, set it as ``None``.
- When ``dirname`` is ``None``, it must be set to a string containing all the parameters.
Default: ``None``.
- If parameters were saved in separate files, set it as ``None``.
Default: ``None``.
pserver_endpoints(list, optional): It is only needed by the distributed inference.
pserver_endpoints(list, optional): It is only needed by the distributed inference.
If using a distributed look up table during the training,
If using a distributed look up table during the training,
...
@@ -1328,21 +1366,32 @@ def load_inference_model(dirname,
...
@@ -1328,21 +1366,32 @@ def load_inference_model(dirname,
# fetch_targets, we can use an executor to run the inference
# fetch_targets, we can use an executor to run the inference
# program for getting the inference result.
# program for getting the inference result.
"""
"""
load_dirname
=
os
.
path
.
normpath
(
dirname
)
load_from_memory
=
False
if
not
os
.
path
.
isdir
(
load_dirname
):
if
dirname
is
not
None
:
raise
ValueError
(
"There is no directory named '%s'"
,
dirname
)
load_dirname
=
os
.
path
.
normpath
(
dirname
)
if
not
os
.
path
.
isdir
(
load_dirname
):
raise
ValueError
(
"There is no directory named '%s'"
,
dirname
)
if
model_filename
is
not
None
:
if
model_filename
is
None
:
model_filename
=
os
.
path
.
basename
(
model_filename
)
model_filename
=
'__model__'
else
:
model_filename
=
"__model__"
model_filename
=
os
.
path
.
join
(
load_dirname
,
model_filename
)
if
params_filename
is
not
None
:
model_filename
=
os
.
path
.
join
(
load_dirname
,
params_filename
=
os
.
path
.
basename
(
params_filename
)
os
.
path
.
basename
(
model_filename
))
if
params_filename
is
not
None
:
params_filename
=
os
.
path
.
basename
(
params_filename
)
with
open
(
model_filename
,
"rb"
)
as
f
:
with
open
(
model_filename
,
"rb"
)
as
f
:
program_desc_str
=
f
.
read
()
program_desc_str
=
f
.
read
()
else
:
load_from_memory
=
True
if
params_filename
is
None
:
raise
ValueError
(
"The path of params cannot be None when the directory path is None."
)
load_dirname
=
dirname
program_desc_str
=
model_filename
params_filename
=
params_filename
program
=
Program
.
parse_from_string
(
program_desc_str
)
program
=
Program
.
parse_from_string
(
program_desc_str
)
if
not
core
.
_is_program_version_supported
(
program
.
_version
()):
if
not
core
.
_is_program_version_supported
(
program
.
_version
()):
...
...
python/paddle/fluid/tests/unittests/test_inference_model_io.py
浏览文件 @
193d1430
...
@@ -16,6 +16,7 @@ from __future__ import print_function
...
@@ -16,6 +16,7 @@ from __future__ import print_function
import
unittest
import
unittest
import
os
import
six
import
six
import
numpy
as
np
import
numpy
as
np
import
paddle.fluid.core
as
core
import
paddle.fluid.core
as
core
...
@@ -27,13 +28,20 @@ import paddle.fluid.layers as layers
...
@@ -27,13 +28,20 @@ import paddle.fluid.layers as layers
import
paddle.fluid.optimizer
as
optimizer
import
paddle.fluid.optimizer
as
optimizer
from
paddle.fluid.compiler
import
CompiledProgram
from
paddle.fluid.compiler
import
CompiledProgram
from
paddle.fluid.framework
import
Program
,
program_guard
from
paddle.fluid.framework
import
Program
,
program_guard
from
paddle.fluid.io
import
save_inference_model
,
load_inference_model
from
paddle.fluid.io
import
save_inference_model
,
load_inference_model
,
save_persistables
from
paddle.fluid.transpiler
import
memory_optimize
from
paddle.fluid.transpiler
import
memory_optimize
class
TestBook
(
unittest
.
TestCase
):
class
TestBook
(
unittest
.
TestCase
):
class
InferModel
(
object
):
def
__init__
(
self
,
list
):
self
.
program
=
list
[
0
]
self
.
feed_var_names
=
list
[
1
]
self
.
fetch_vars
=
list
[
2
]
def
test_fit_line_inference_model
(
self
):
def
test_fit_line_inference_model
(
self
):
MODEL_DIR
=
"./tmp/inference_model"
MODEL_DIR
=
"./tmp/inference_model"
UNI_MODEL_DIR
=
"./tmp/inference_model1"
init_program
=
Program
()
init_program
=
Program
()
program
=
Program
()
program
=
Program
()
...
@@ -65,30 +73,43 @@ class TestBook(unittest.TestCase):
...
@@ -65,30 +73,43 @@ class TestBook(unittest.TestCase):
'y'
:
tensor_y
},
'y'
:
tensor_y
},
fetch_list
=
[
avg_cost
])
fetch_list
=
[
avg_cost
])
# Separated model and unified model
save_inference_model
(
MODEL_DIR
,
[
"x"
,
"y"
],
[
avg_cost
],
exe
,
program
)
save_inference_model
(
MODEL_DIR
,
[
"x"
,
"y"
],
[
avg_cost
],
exe
,
program
)
save_inference_model
(
UNI_MODEL_DIR
,
[
"x"
,
"y"
],
[
avg_cost
],
exe
,
program
,
'model'
,
'params'
)
main_program
=
program
.
clone
().
_prune_with_input
(
feeded_var_names
=
[
"x"
,
"y"
],
targets
=
[
avg_cost
])
params_str
=
save_persistables
(
exe
,
None
,
main_program
,
None
)
expected
=
exe
.
run
(
program
,
expected
=
exe
.
run
(
program
,
feed
=
{
'x'
:
tensor_x
,
feed
=
{
'x'
:
tensor_x
,
'y'
:
tensor_y
},
'y'
:
tensor_y
},
fetch_list
=
[
avg_cost
])[
0
]
fetch_list
=
[
avg_cost
])[
0
]
six
.
moves
.
reload_module
(
executor
)
# reload to build a new scope
six
.
moves
.
reload_module
(
executor
)
# reload to build a new scope
exe
=
executor
.
Executor
(
place
)
[
infer_prog
,
feed_var_names
,
fetch_vars
]
=
load_inference_model
(
model_0
=
self
.
InferModel
(
load_inference_model
(
MODEL_DIR
,
exe
))
MODEL_DIR
,
exe
)
with
open
(
os
.
path
.
join
(
UNI_MODEL_DIR
,
'model'
),
"rb"
)
as
f
:
model_str
=
f
.
read
()
outs
=
exe
.
run
(
model_1
=
self
.
InferModel
(
infer_prog
,
load_inference_model
(
None
,
exe
,
model_str
,
params_str
))
feed
=
{
feed_var_names
[
0
]:
tensor_x
,
feed_var_names
[
1
]:
tensor_y
},
for
model
in
[
model_0
,
model_1
]:
fetch_list
=
fetch_vars
)
outs
=
exe
.
run
(
model
.
program
,
actual
=
outs
[
0
]
feed
=
{
model
.
feed_var_names
[
0
]:
tensor_x
,
self
.
assertEqual
(
feed_var_names
,
[
"x"
,
"y"
])
model
.
feed_var_names
[
1
]:
tensor_y
self
.
assertEqual
(
len
(
fetch_vars
),
1
)
},
print
(
"fetch %s"
%
str
(
fetch_vars
[
0
]))
fetch_list
=
model
.
fetch_vars
)
self
.
assertTrue
(
"scale"
in
str
(
fetch_vars
[
0
]))
actual
=
outs
[
0
]
self
.
assertEqual
(
expected
,
actual
)
self
.
assertEqual
(
model
.
feed_var_names
,
[
"x"
,
"y"
])
self
.
assertEqual
(
len
(
model
.
fetch_vars
),
1
)
print
(
"fetch %s"
%
str
(
model
.
fetch_vars
[
0
]))
self
.
assertEqual
(
expected
,
actual
)
self
.
assertRaises
(
ValueError
,
fluid
.
io
.
load_inference_model
,
None
,
exe
,
model_str
,
None
)
class
TestSaveInferenceModel
(
unittest
.
TestCase
):
class
TestSaveInferenceModel
(
unittest
.
TestCase
):
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录