未验证 提交 2424297f 编写于 作者: C Chen Weihang 提交者: GitHub

add dygraph not support limit for io apis, test=develop (#24342)

上级 8a1a2af8
...@@ -32,7 +32,7 @@ from paddle.fluid import layers ...@@ -32,7 +32,7 @@ from paddle.fluid import layers
from paddle.fluid.executor import Executor, global_scope from paddle.fluid.executor import Executor, global_scope
from paddle.fluid.evaluator import Evaluator from paddle.fluid.evaluator import Evaluator
from paddle.fluid.framework import Program, Parameter, default_main_program, default_startup_program, Variable, \ from paddle.fluid.framework import Program, Parameter, default_main_program, default_startup_program, Variable, \
program_guard program_guard, dygraph_not_support
from .wrapped_decorator import signature_safe_contextmanager from .wrapped_decorator import signature_safe_contextmanager
from paddle.fluid.compiler import CompiledProgram from paddle.fluid.compiler import CompiledProgram
from paddle.fluid.log_helper import get_logger from paddle.fluid.log_helper import get_logger
...@@ -121,6 +121,7 @@ def is_belong_to_optimizer(var): ...@@ -121,6 +121,7 @@ def is_belong_to_optimizer(var):
return False return False
@dygraph_not_support
def get_program_parameter(program): def get_program_parameter(program):
""" """
Get all the parameters from Program. Get all the parameters from Program.
...@@ -143,6 +144,7 @@ def get_program_parameter(program): ...@@ -143,6 +144,7 @@ def get_program_parameter(program):
return list(filter(is_parameter, program.list_vars())) return list(filter(is_parameter, program.list_vars()))
@dygraph_not_support
def get_program_persistable_vars(program): def get_program_persistable_vars(program):
""" """
Get all the persistable vars from Program. Get all the persistable vars from Program.
...@@ -213,6 +215,7 @@ def _get_valid_program(main_program): ...@@ -213,6 +215,7 @@ def _get_valid_program(main_program):
return main_program return main_program
@dygraph_not_support
def save_vars(executor, def save_vars(executor,
dirname, dirname,
main_program=None, main_program=None,
...@@ -359,6 +362,7 @@ def save_vars(executor, ...@@ -359,6 +362,7 @@ def save_vars(executor,
return global_scope().find_var(params_var_name).get_bytes() return global_scope().find_var(params_var_name).get_bytes()
@dygraph_not_support
def save_params(executor, dirname, main_program=None, filename=None): def save_params(executor, dirname, main_program=None, filename=None):
""" """
This operator saves all parameters from the :code:`main_program` to This operator saves all parameters from the :code:`main_program` to
...@@ -581,6 +585,7 @@ def _save_distributed_persistables(executor, dirname, main_program): ...@@ -581,6 +585,7 @@ def _save_distributed_persistables(executor, dirname, main_program):
main_program._endpoints) main_program._endpoints)
@dygraph_not_support
def save_persistables(executor, dirname, main_program=None, filename=None): def save_persistables(executor, dirname, main_program=None, filename=None):
""" """
This operator saves all persistable variables from :code:`main_program` to This operator saves all persistable variables from :code:`main_program` to
...@@ -648,6 +653,7 @@ def save_persistables(executor, dirname, main_program=None, filename=None): ...@@ -648,6 +653,7 @@ def save_persistables(executor, dirname, main_program=None, filename=None):
filename=filename) filename=filename)
@dygraph_not_support
def load_vars(executor, def load_vars(executor,
dirname, dirname,
main_program=None, main_program=None,
...@@ -820,6 +826,7 @@ def load_vars(executor, ...@@ -820,6 +826,7 @@ def load_vars(executor,
format(orig_shape, each_var.name, new_shape)) format(orig_shape, each_var.name, new_shape))
@dygraph_not_support
def load_params(executor, dirname, main_program=None, filename=None): def load_params(executor, dirname, main_program=None, filename=None):
""" """
This API filters out all parameters from the give ``main_program`` This API filters out all parameters from the give ``main_program``
...@@ -877,6 +884,7 @@ def load_params(executor, dirname, main_program=None, filename=None): ...@@ -877,6 +884,7 @@ def load_params(executor, dirname, main_program=None, filename=None):
filename=filename) filename=filename)
@dygraph_not_support
def load_persistables(executor, dirname, main_program=None, filename=None): def load_persistables(executor, dirname, main_program=None, filename=None):
""" """
This API filters out all variables with ``persistable==True`` from the This API filters out all variables with ``persistable==True`` from the
...@@ -1065,6 +1073,7 @@ def append_fetch_ops(inference_program, ...@@ -1065,6 +1073,7 @@ def append_fetch_ops(inference_program,
attrs={'col': i}) attrs={'col': i})
@dygraph_not_support
def save_inference_model(dirname, def save_inference_model(dirname,
feeded_var_names, feeded_var_names,
target_vars, target_vars,
...@@ -1272,6 +1281,7 @@ def save_inference_model(dirname, ...@@ -1272,6 +1281,7 @@ def save_inference_model(dirname,
return target_var_name_list return target_var_name_list
@dygraph_not_support
def load_inference_model(dirname, def load_inference_model(dirname,
executor, executor,
model_filename=None, model_filename=None,
...@@ -1564,6 +1574,7 @@ def _load_persistable_nodes(executor, dirname, graph): ...@@ -1564,6 +1574,7 @@ def _load_persistable_nodes(executor, dirname, graph):
load_vars(executor=executor, dirname=dirname, vars=var_list) load_vars(executor=executor, dirname=dirname, vars=var_list)
@dygraph_not_support
def save(program, model_path): def save(program, model_path):
""" """
This function save parameters, optimizer information and network description to model_path. This function save parameters, optimizer information and network description to model_path.
...@@ -1622,6 +1633,7 @@ def save(program, model_path): ...@@ -1622,6 +1633,7 @@ def save(program, model_path):
f.write(program.desc.serialize_to_string()) f.write(program.desc.serialize_to_string())
@dygraph_not_support
def load(program, model_path, executor=None, var_list=None): def load(program, model_path, executor=None, var_list=None):
""" """
This function get parameters and optimizer information from program, and then get corresponding value from file. This function get parameters and optimizer information from program, and then get corresponding value from file.
...@@ -1788,6 +1800,7 @@ def load(program, model_path, executor=None, var_list=None): ...@@ -1788,6 +1800,7 @@ def load(program, model_path, executor=None, var_list=None):
set_var(v, load_dict[v.name]) set_var(v, load_dict[v.name])
@dygraph_not_support
def load_program_state(model_path, var_list=None): def load_program_state(model_path, var_list=None):
""" """
Load program state from local file Load program state from local file
...@@ -1918,6 +1931,7 @@ def load_program_state(model_path, var_list=None): ...@@ -1918,6 +1931,7 @@ def load_program_state(model_path, var_list=None):
return para_dict return para_dict
@dygraph_not_support
def set_program_state(program, state_dict): def set_program_state(program, state_dict):
""" """
Set program parameter from state_dict Set program parameter from state_dict
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册