未验证 提交 45880f60 编写于 作者: C Chen Weihang 提交者: GitHub

API(Program) error message enhancement (#23519)

* polish api program error message, test=develop

* fix condition error, test=develop

* fix test prune error, test=develop

* fix coverage problem, test=develop
上级 078dd05b
...@@ -228,9 +228,11 @@ BlockDesc::BlockDesc(const BlockDesc &other, proto::BlockDesc *desc, ...@@ -228,9 +228,11 @@ BlockDesc::BlockDesc(const BlockDesc &other, proto::BlockDesc *desc,
} }
void BlockDesc::SetForwardBlockID(int32_t forward_block_id) { void BlockDesc::SetForwardBlockID(int32_t forward_block_id) {
PADDLE_ENFORCE(!desc_->has_forward_block_idx(), PADDLE_ENFORCE_EQ(
"Parent block ID has been set to %d. Cannot set to %d", desc_->has_forward_block_idx(), false,
desc_->forward_block_idx(), forward_block_id); platform::errors::PreconditionNotMet(
"Block %d's parent block ID has been set to %d, cannot be set to %d.",
desc_->idx(), desc_->forward_block_idx(), forward_block_id));
desc_->set_forward_block_idx(forward_block_id); desc_->set_forward_block_idx(forward_block_id);
} }
......
...@@ -99,8 +99,9 @@ void ProgramDesc::CopyFrom(const proto::ProgramDesc &desc) { ...@@ -99,8 +99,9 @@ void ProgramDesc::CopyFrom(const proto::ProgramDesc &desc) {
} }
ProgramDesc::ProgramDesc(const std::string &binary_str) { ProgramDesc::ProgramDesc(const std::string &binary_str) {
PADDLE_ENFORCE(desc_.ParseFromString(binary_str), PADDLE_ENFORCE_EQ(desc_.ParseFromString(binary_str), true,
"Fail to parse program_desc from binary string."); platform::errors::InvalidArgument(
"Failed to parse program_desc from binary string."));
InitFromProto(); InitFromProto();
} }
......
...@@ -36,8 +36,9 @@ static pybind11::bytes SerializeMessage( ...@@ -36,8 +36,9 @@ static pybind11::bytes SerializeMessage(
T &self) { // NOLINT due to pybind11 convention. T &self) { // NOLINT due to pybind11 convention.
// Check IsInitialized in Python // Check IsInitialized in Python
std::string retv; std::string retv;
PADDLE_ENFORCE(self.Proto()->SerializePartialToString(&retv), PADDLE_ENFORCE_EQ(self.Proto()->SerializePartialToString(&retv), true,
"Cannot serialize message"); platform::errors::InvalidArgument(
"Failed to serialize input Desc to string."));
return retv; return retv;
} }
...@@ -66,9 +67,10 @@ void BindProgramDesc(pybind11::module *m) { ...@@ -66,9 +67,10 @@ void BindProgramDesc(pybind11::module *m) {
.def("parse_from_string", .def("parse_from_string",
[](pd::ProgramDesc &program_desc, const std::string &data) { [](pd::ProgramDesc &program_desc, const std::string &data) {
pd::proto::ProgramDesc *desc = program_desc.Proto(); pd::proto::ProgramDesc *desc = program_desc.Proto();
PADDLE_ENFORCE(desc->ParseFromString(data), PADDLE_ENFORCE_EQ(
"Fail to parse ProgramDesc from string. This could " desc->ParseFromString(data), true,
"be a bug of Paddle."); platform::errors::InvalidArgument(
"Failed to parse ProgramDesc from binary string."));
}) })
.def("_set_version", .def("_set_version",
[](pd::ProgramDesc &self, int64_t version) { [](pd::ProgramDesc &self, int64_t version) {
......
...@@ -3838,8 +3838,15 @@ class Program(object): ...@@ -3838,8 +3838,15 @@ class Program(object):
prog_string_with_detail = prog.to_string(throw_on_error=True, with_details=True) prog_string_with_detail = prog.to_string(throw_on_error=True, with_details=True)
print("program string with detail: {}".format(prog_string_with_detail)) print("program string with detail: {}".format(prog_string_with_detail))
""" """
assert isinstance(throw_on_error, bool) and isinstance(with_details, assert isinstance(
bool) throw_on_error, bool
), "The type of throw_on_error parameter is wrong, expected bool, but received {}.".format(
type(throw_on_error))
assert isinstance(
with_details, bool
), "The type of with_details parameter is wrong, expected bool, but received {}.".format(
type(with_details))
if with_details: if with_details:
res_str = "" res_str = ""
for block in self.blocks: for block in self.blocks:
...@@ -4105,8 +4112,9 @@ class Program(object): ...@@ -4105,8 +4112,9 @@ class Program(object):
for var in feeded_var_names: for var in feeded_var_names:
if not isinstance(var, six.string_types): if not isinstance(var, six.string_types):
raise ValueError("All feeded_var_names of prune() can only be " raise ValueError(
"str.") "All feeded_var_names of Program._prune_with_input() can only be "
"str, but received %s." % type(var))
targets_idx = [] targets_idx = []
for t in targets: for t in targets:
...@@ -4116,8 +4124,9 @@ class Program(object): ...@@ -4116,8 +4124,9 @@ class Program(object):
elif isinstance(t, six.string_types): elif isinstance(t, six.string_types):
name = str(t) name = str(t)
else: else:
raise ValueError("All targets of prune() can only be " raise ValueError(
"Variable or Operator.") "All targets of Program._prune_with_input() can only be "
"Variable or Operator, but received %s." % type(t))
# After transpiler processing, the op that output this # After transpiler processing, the op that output this
# variable maybe has been changed, so t.op is not reliable # variable maybe has been changed, so t.op is not reliable
# and we need to find the current op that generate this # and we need to find the current op that generate this
...@@ -4327,7 +4336,9 @@ class Program(object): ...@@ -4327,7 +4336,9 @@ class Program(object):
@random_seed.setter @random_seed.setter
def random_seed(self, seed): def random_seed(self, seed):
if not isinstance(seed, int): if not isinstance(seed, int):
raise ValueError("Seed must be a integer.") raise ValueError(
"Program.random_seed's input seed must be an integer, but received %s."
% type(seed))
self._seed = seed self._seed = seed
def __repr__(self): def __repr__(self):
...@@ -4460,8 +4471,9 @@ class Program(object): ...@@ -4460,8 +4471,9 @@ class Program(object):
None None
""" """
if not isinstance(other, Program): if not isinstance(other, Program):
raise TypeError("_copy_param_info_from should be invoked with " raise TypeError(
"Program") "Function Program._copy_param_info_from() needs to pass in a source Program, but received %s"
% type(other))
self.global_block()._copy_param_info_from(other.global_block()) self.global_block()._copy_param_info_from(other.global_block())
...@@ -4476,8 +4488,9 @@ class Program(object): ...@@ -4476,8 +4488,9 @@ class Program(object):
None None
""" """
if not isinstance(other, Program): if not isinstance(other, Program):
raise TypeError("_copy_dist_param_info_from should be invoked with " raise TypeError(
"Program") "Function Program._copy_param_info_from() needs to pass in a source Program, but received %s"
% type(other))
self._is_distributed = other._is_distributed self._is_distributed = other._is_distributed
self._is_chief = other._is_chief self._is_chief = other._is_chief
self._parameters_on_pservers = other._parameters_on_pservers self._parameters_on_pservers = other._parameters_on_pservers
...@@ -4503,8 +4516,9 @@ class Program(object): ...@@ -4503,8 +4516,9 @@ class Program(object):
None None
""" """
if not isinstance(other, Program): if not isinstance(other, Program):
raise TypeError("_copy_data_info_from should be invoked with " raise TypeError(
"Program") "Function Program._copy_param_info_from() needs to pass in a source Program, but received %s"
% type(other))
if not pruned_origin_block_id_map: if not pruned_origin_block_id_map:
pruned_origin_block_id_map = { pruned_origin_block_id_map = {
......
...@@ -145,6 +145,23 @@ class TestProgram(unittest.TestCase): ...@@ -145,6 +145,23 @@ class TestProgram(unittest.TestCase):
self.assertEqual(param_list[0].name, "fc_0.w_0") self.assertEqual(param_list[0].name, "fc_0.w_0")
self.assertEqual(param_list[1].name, "fc_0.b_0") self.assertEqual(param_list[1].name, "fc_0.b_0")
def test_prune_with_input_type_error(self):
program = fluid.default_main_program()
feed_var_names = [2, 3, 4]
self.assertRaises(ValueError, program._prune_with_input, feed_var_names,
[])
def test_random_seed_error(self):
program = fluid.default_main_program()
with self.assertRaises(ValueError):
program.random_seed = "seed"
def test_copy_info_from_error(self):
program = fluid.default_main_program()
self.assertRaises(TypeError, program._copy_param_info_from, "program")
self.assertRaises(TypeError, program._copy_dist_param_info_from,
"program")
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
...@@ -94,8 +94,8 @@ class TestPrune(unittest.TestCase): ...@@ -94,8 +94,8 @@ class TestPrune(unittest.TestCase):
try: try:
pruned_program = program._prune(targets=None) pruned_program = program._prune(targets=None)
except ValueError as e: except ValueError as e:
self.assertEqual( self.assertIn(
"All targets of prune() can only be Variable or Operator.", "All targets of Program._prune_with_input() can only be Variable or Operator",
cpt.get_exception_message(e)) cpt.get_exception_message(e))
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册