未验证 提交 45880f60 编写于 作者: C Chen Weihang 提交者: GitHub

API(Program) error message enhancement (#23519)

* polish api program error message, test=develop

* fix condition error, test=develop

* fix test prune error, test=develop

* fix coverage problem, test=develop
上级 078dd05b
......@@ -228,9 +228,11 @@ BlockDesc::BlockDesc(const BlockDesc &other, proto::BlockDesc *desc,
}
void BlockDesc::SetForwardBlockID(int32_t forward_block_id) {
PADDLE_ENFORCE(!desc_->has_forward_block_idx(),
"Parent block ID has been set to %d. Cannot set to %d",
desc_->forward_block_idx(), forward_block_id);
PADDLE_ENFORCE_EQ(
desc_->has_forward_block_idx(), false,
platform::errors::PreconditionNotMet(
"Block %d's parent block ID has been set to %d, cannot be set to %d.",
desc_->idx(), desc_->forward_block_idx(), forward_block_id));
desc_->set_forward_block_idx(forward_block_id);
}
......
......@@ -99,8 +99,9 @@ void ProgramDesc::CopyFrom(const proto::ProgramDesc &desc) {
}
ProgramDesc::ProgramDesc(const std::string &binary_str) {
PADDLE_ENFORCE(desc_.ParseFromString(binary_str),
"Fail to parse program_desc from binary string.");
PADDLE_ENFORCE_EQ(desc_.ParseFromString(binary_str), true,
platform::errors::InvalidArgument(
"Failed to parse program_desc from binary string."));
InitFromProto();
}
......
......@@ -36,8 +36,9 @@ static pybind11::bytes SerializeMessage(
T &self) { // NOLINT due to pybind11 convention.
// Check IsInitialized in Python
std::string retv;
PADDLE_ENFORCE(self.Proto()->SerializePartialToString(&retv),
"Cannot serialize message");
PADDLE_ENFORCE_EQ(self.Proto()->SerializePartialToString(&retv), true,
platform::errors::InvalidArgument(
"Failed to serialize input Desc to string."));
return retv;
}
......@@ -66,9 +67,10 @@ void BindProgramDesc(pybind11::module *m) {
.def("parse_from_string",
[](pd::ProgramDesc &program_desc, const std::string &data) {
pd::proto::ProgramDesc *desc = program_desc.Proto();
PADDLE_ENFORCE(desc->ParseFromString(data),
"Fail to parse ProgramDesc from string. This could "
"be a bug of Paddle.");
PADDLE_ENFORCE_EQ(
desc->ParseFromString(data), true,
platform::errors::InvalidArgument(
"Failed to parse ProgramDesc from binary string."));
})
.def("_set_version",
[](pd::ProgramDesc &self, int64_t version) {
......
......@@ -3838,8 +3838,15 @@ class Program(object):
prog_string_with_detail = prog.to_string(throw_on_error=True, with_details=True)
print("program string with detail: {}".format(prog_string_with_detail))
"""
assert isinstance(throw_on_error, bool) and isinstance(with_details,
bool)
assert isinstance(
throw_on_error, bool
), "The type of throw_on_error parameter is wrong, expected bool, but received {}.".format(
type(throw_on_error))
assert isinstance(
with_details, bool
), "The type of with_details parameter is wrong, expected bool, but received {}.".format(
type(with_details))
if with_details:
res_str = ""
for block in self.blocks:
......@@ -4105,8 +4112,9 @@ class Program(object):
for var in feeded_var_names:
if not isinstance(var, six.string_types):
raise ValueError("All feeded_var_names of prune() can only be "
"str.")
raise ValueError(
"All feeded_var_names of Program._prune_with_input() can only be "
"str, but received %s." % type(var))
targets_idx = []
for t in targets:
......@@ -4116,8 +4124,9 @@ class Program(object):
elif isinstance(t, six.string_types):
name = str(t)
else:
raise ValueError("All targets of prune() can only be "
"Variable or Operator.")
raise ValueError(
"All targets of Program._prune_with_input() can only be "
"Variable or Operator, but received %s." % type(t))
# After transpiler processing, the op that output this
# variable maybe has been changed, so t.op is not reliable
# and we need to find the current op that generate this
......@@ -4327,7 +4336,9 @@ class Program(object):
@random_seed.setter
def random_seed(self, seed):
if not isinstance(seed, int):
raise ValueError("Seed must be a integer.")
raise ValueError(
"Program.random_seed's input seed must be an integer, but received %s."
% type(seed))
self._seed = seed
def __repr__(self):
......@@ -4460,8 +4471,9 @@ class Program(object):
None
"""
if not isinstance(other, Program):
raise TypeError("_copy_param_info_from should be invoked with "
"Program")
raise TypeError(
"Function Program._copy_param_info_from() needs to pass in a source Program, but received %s"
% type(other))
self.global_block()._copy_param_info_from(other.global_block())
......@@ -4476,8 +4488,9 @@ class Program(object):
None
"""
if not isinstance(other, Program):
raise TypeError("_copy_dist_param_info_from should be invoked with "
"Program")
raise TypeError(
"Function Program._copy_param_info_from() needs to pass in a source Program, but received %s"
% type(other))
self._is_distributed = other._is_distributed
self._is_chief = other._is_chief
self._parameters_on_pservers = other._parameters_on_pservers
......@@ -4503,8 +4516,9 @@ class Program(object):
None
"""
if not isinstance(other, Program):
raise TypeError("_copy_data_info_from should be invoked with "
"Program")
raise TypeError(
"Function Program._copy_param_info_from() needs to pass in a source Program, but received %s"
% type(other))
if not pruned_origin_block_id_map:
pruned_origin_block_id_map = {
......
......@@ -145,6 +145,23 @@ class TestProgram(unittest.TestCase):
self.assertEqual(param_list[0].name, "fc_0.w_0")
self.assertEqual(param_list[1].name, "fc_0.b_0")
def test_prune_with_input_type_error(self):
program = fluid.default_main_program()
feed_var_names = [2, 3, 4]
self.assertRaises(ValueError, program._prune_with_input, feed_var_names,
[])
def test_random_seed_error(self):
program = fluid.default_main_program()
with self.assertRaises(ValueError):
program.random_seed = "seed"
def test_copy_info_from_error(self):
program = fluid.default_main_program()
self.assertRaises(TypeError, program._copy_param_info_from, "program")
self.assertRaises(TypeError, program._copy_dist_param_info_from,
"program")
if __name__ == '__main__':
unittest.main()
......@@ -94,8 +94,8 @@ class TestPrune(unittest.TestCase):
try:
pruned_program = program._prune(targets=None)
except ValueError as e:
self.assertEqual(
"All targets of prune() can only be Variable or Operator.",
self.assertIn(
"All targets of Program._prune_with_input() can only be Variable or Operator",
cpt.get_exception_message(e))
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册