未验证 提交 0b9acb49 编写于 作者: T tangwei12 提交者: GitHub

add check of executor (#17986)

* add check of executor, test=develop
上级 ec1000cc
...@@ -25,6 +25,7 @@ from paddle.fluid import layers ...@@ -25,6 +25,7 @@ from paddle.fluid import layers
from paddle.fluid.executor import Executor from paddle.fluid.executor import Executor
from paddle.fluid.evaluator import Evaluator from paddle.fluid.evaluator import Evaluator
from paddle.fluid.framework import Program, Parameter, default_main_program, default_startup_program, Variable, program_guard from paddle.fluid.framework import Program, Parameter, default_main_program, default_startup_program, Variable, program_guard
from paddle.fluid.compiler import CompiledProgram
from paddle.fluid.log_helper import get_logger from paddle.fluid.log_helper import get_logger
from . import reader from . import reader
from .reader import * from .reader import *
...@@ -187,6 +188,7 @@ def save_vars(executor, ...@@ -187,6 +188,7 @@ def save_vars(executor,
# saved in the same file named 'var_file' in the path "./my_paddle_vars". # saved in the same file named 'var_file' in the path "./my_paddle_vars".
""" """
save_dirname = os.path.normpath(dirname) save_dirname = os.path.normpath(dirname)
if vars is None: if vars is None:
if main_program is None: if main_program is None:
main_program = default_main_program() main_program = default_main_program()
...@@ -438,7 +440,7 @@ def _save_distributed_persistables(executor, dirname, main_program): ...@@ -438,7 +440,7 @@ def _save_distributed_persistables(executor, dirname, main_program):
return is_valid return is_valid
if not isinstance(main_program, Program): if not isinstance(main_program, Program):
raise ValueError("'main_program' should be an instance of Program.") raise TypeError("'main_program' should be an instance of Program.")
if not main_program._is_distributed: if not main_program._is_distributed:
raise ValueError( raise ValueError(
...@@ -609,6 +611,7 @@ def load_vars(executor, ...@@ -609,6 +611,7 @@ def load_vars(executor,
# been saved in the same file named 'var_file' in the path "./my_paddle_vars". # been saved in the same file named 'var_file' in the path "./my_paddle_vars".
""" """
load_dirname = os.path.normpath(dirname) load_dirname = os.path.normpath(dirname)
if vars is None: if vars is None:
if main_program is None: if main_program is None:
main_program = default_main_program() main_program = default_main_program()
...@@ -627,6 +630,7 @@ def load_vars(executor, ...@@ -627,6 +630,7 @@ def load_vars(executor,
if main_program is None: if main_program is None:
main_program = default_main_program() main_program = default_main_program()
if not isinstance(main_program, Program): if not isinstance(main_program, Program):
raise TypeError("program should be as Program type or None") raise TypeError("program should be as Program type or None")
...@@ -863,7 +867,7 @@ def _load_distributed_persistables(executor, dirname, main_program=None): ...@@ -863,7 +867,7 @@ def _load_distributed_persistables(executor, dirname, main_program=None):
executor.run(load_prog) executor.run(load_prog)
if not isinstance(main_program, Program): if not isinstance(main_program, Program):
raise ValueError("'main_program' should be an instance of Program.") raise TypeError("'main_program' should be an instance of Program.")
if not main_program._is_distributed: if not main_program._is_distributed:
raise ValueError( raise ValueError(
...@@ -1027,6 +1031,9 @@ def save_inference_model(dirname, ...@@ -1027,6 +1031,9 @@ def save_inference_model(dirname,
we save the original program as inference model.", we save the original program as inference model.",
RuntimeWarning) RuntimeWarning)
elif not isinstance(main_program, Program):
raise TypeError("program should be as Program type or None")
# fix the bug that the activation op's output as target will be pruned. # fix the bug that the activation op's output as target will be pruned.
# will affect the inference performance. # will affect the inference performance.
# TODO(Superjomn) add an IR pass to remove 1-scale op. # TODO(Superjomn) add an IR pass to remove 1-scale op.
......
...@@ -23,6 +23,7 @@ import paddle.fluid.core as core ...@@ -23,6 +23,7 @@ import paddle.fluid.core as core
import paddle.fluid.executor as executor import paddle.fluid.executor as executor
import paddle.fluid.layers as layers import paddle.fluid.layers as layers
import paddle.fluid.optimizer as optimizer import paddle.fluid.optimizer as optimizer
from paddle.fluid.compiler import CompiledProgram
from paddle.fluid.framework import Program, program_guard from paddle.fluid.framework import Program, program_guard
from paddle.fluid.io import save_inference_model, load_inference_model from paddle.fluid.io import save_inference_model, load_inference_model
from paddle.fluid.transpiler import memory_optimize from paddle.fluid.transpiler import memory_optimize
...@@ -114,5 +115,36 @@ class TestSaveInferenceModel(unittest.TestCase): ...@@ -114,5 +115,36 @@ class TestSaveInferenceModel(unittest.TestCase):
save_inference_model(MODEL_DIR, ["x", "y"], [avg_cost], exe, program) save_inference_model(MODEL_DIR, ["x", "y"], [avg_cost], exe, program)
class TestInstance(unittest.TestCase):
def test_save_inference_model(self):
MODEL_DIR = "./tmp/inference_model3"
init_program = Program()
program = Program()
# fake program without feed/fetch
with program_guard(program, init_program):
x = layers.data(name='x', shape=[2], dtype='float32')
y = layers.data(name='y', shape=[1], dtype='float32')
y_predict = layers.fc(input=x, size=1, act=None)
cost = layers.square_error_cost(input=y_predict, label=y)
avg_cost = layers.mean(cost)
place = core.CPUPlace()
exe = executor.Executor(place)
exe.run(init_program, feed={}, fetch_list=[])
# will print warning message
cp_prog = CompiledProgram(program).with_data_parallel(
loss_name=avg_cost.name)
self.assertRaises(TypeError, save_inference_model,
[MODEL_DIR, ["x", "y"], [avg_cost], exe, cp_prog])
self.assertRaises(TypeError, save_inference_model,
[MODEL_DIR, ["x", "y"], [avg_cost], [], cp_prog])
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册