未验证 提交 f10100eb 编写于 作者: C Chen Weihang 提交者: GitHub

API (save/load series) error message enhancement (#23644)

上级 f39899a4
......@@ -197,11 +197,15 @@ def _get_valid_program(main_program):
elif isinstance(main_program, CompiledProgram):
main_program = main_program._program
if main_program is None:
raise TypeError("program should be as Program type or None")
raise TypeError(
"The type of input main_program is invalid, expected tyep is Program, but received None"
)
warnings.warn(
"The input is a CompiledProgram, this is not recommended.")
if not isinstance(main_program, Program):
raise TypeError("program should be as Program type or None")
raise TypeError(
"The type of input main_program is invalid, expected type is fluid.Program, but received %s"
% type(main_program))
return main_program
......@@ -705,7 +709,9 @@ def load_vars(executor,
if main_program is None:
main_program = default_main_program()
if not isinstance(main_program, Program):
raise TypeError("program's type should be Program")
raise TypeError(
"The type of input main_program is invalid, expected type is fluid.Program, but received %s"
% type(main_program))
load_vars(
executor,
......@@ -721,7 +727,9 @@ def load_vars(executor,
main_program = default_main_program()
if not isinstance(main_program, Program):
raise TypeError("program should be as Program type or None")
raise TypeError(
"The type of input main_program is invalid, expected type is fluid.Program, but received %s"
% type(main_program))
# save origin param shape
orig_para_shape = {}
......@@ -769,7 +777,7 @@ def load_vars(executor,
orig_shape = orig_para_shape.get(each_var.name)
if new_shape != orig_shape:
raise RuntimeError(
"Shape not matching: the Program requires a parameter with a shape of ({}), "
"Variable's shape does not match, the Program requires a parameter with the shape of ({}), "
"while the loaded parameter (namely [ {} ]) has a shape of ({}).".
format(orig_shape, each_var.name, new_shape))
......@@ -1385,7 +1393,7 @@ def get_parameter_value(para, executor):
p = fluid.io.get_parameter_value(param, exe)
"""
assert is_parameter(para)
assert is_parameter(para), "The input variable is not parameter."
get_program = Program()
block = get_program.global_block()
......@@ -1531,7 +1539,7 @@ def save(program, model_path):
base_name = os.path.basename(model_path)
assert base_name != "", \
"model_path MUST be format of dirname/filename [dirname\\filename in Window], Now filename is empty str"
"The input model_path MUST be format of dirname/filename [dirname\\filename in Windows system], but received model_path is empty string."
dir_name = os.path.dirname(model_path)
if dir_name and not os.path.exists(dir_name):
......@@ -1642,7 +1650,7 @@ def load(program, model_path, executor=None, var_list=None):
raise e
except:
raise RuntimeError(
"Failed to load model file , please make sure model file is saved with the "
"Failed to load model file, please make sure model file is saved with the "
"following APIs: save_params, save_persistables, save_vars")
return
......@@ -1658,7 +1666,7 @@ def load(program, model_path, executor=None, var_list=None):
for var in var_list:
if var.name not in program_var_name_set:
raise LookupError(
"loaded var [{}] not included in program variable list")
"loaded var [{}] is not in program variable list")
dir_name, file_name = os.path.split(model_path)
try:
......@@ -1903,11 +1911,11 @@ def set_program_state(program, state_dict):
orig_para_np = np.array(var_temp.get_tensor())
new_para_np = state_dict[para.name]
assert orig_para_np.shape == new_para_np.shape, \
"Shape not matching: the Program requires a parameter with a shape of ({}), " \
"Parameter's shape does not match, the Program requires a parameter with the shape of ({}), " \
"while the loaded parameter (namely [ {} ]) has a shape of ({})." \
.format(orig_para_np.shape, para.name, new_para_np.shape)
assert orig_para_np.dtype == new_para_np.dtype, \
"Dtype not matching: the Program requires a parameter with a dtype of ({}), " \
"Parameter's data type does not match, the Program requires a parameter with a dtype of ({}), " \
"while the loaded parameter (namely [ {} ]) has a dtype of ({})." \
.format(orig_para_np.dtype, para.name, new_para_np.dtype)
......
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import unittest
import paddle.fluid as fluid
from paddle.fluid import core
class TestSaveLoadAPIError(unittest.TestCase):
def test_get_valid_program_error(self):
# case 1: CompiledProgram no program
graph = core.Graph(core.ProgramDesc())
compiled_program = fluid.CompiledProgram(graph)
with self.assertRaises(TypeError):
fluid.io._get_valid_program(compiled_program)
# case 2: main_program type error
with self.assertRaises(TypeError):
fluid.io._get_valid_program("program")
def test_load_vars_error(self):
place = fluid.CPUPlace()
exe = fluid.Executor(place)
# case 1: main_program type error when vars None
with self.assertRaises(TypeError):
fluid.io.load_vars(
executor=exe, dirname="./fake_dir", main_program="program")
# case 2: main_program type error when vars not None
with self.assertRaises(TypeError):
fluid.io.load_vars(
executor=exe,
dirname="./fake_dir",
main_program="program",
vars="vars")
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册