From 307801d5ecfe1aa64ca8eba74cf70170e382325a Mon Sep 17 00:00:00 2001 From: Feiyu Chan Date: Tue, 16 Aug 2022 10:24:13 +0800 Subject: [PATCH] add strongly typed functions to set attributes to avoid unexpected type conversions. (#45107) --- paddle/fluid/framework/op_desc.h | 9 ++++ paddle/fluid/pybind/protobuf.cc | 17 +++++++ python/paddle/fluid/framework.py | 47 ++++++++++++++++++- .../fluid/tests/unittests/test_fold_op.py | 2 +- .../tests/unittests/test_histogram_op.py | 2 +- 5 files changed, 74 insertions(+), 3 deletions(-) diff --git a/paddle/fluid/framework/op_desc.h b/paddle/fluid/framework/op_desc.h index 7b0d7c587e7..a2f503f4b96 100644 --- a/paddle/fluid/framework/op_desc.h +++ b/paddle/fluid/framework/op_desc.h @@ -96,6 +96,15 @@ class OpDesc { void SetAttr(const std::string &name, const Attribute &v); void RemoveAttr(const std::string &name); + // NOTE(chenfeiyu): this template is added to avoid using a variant(Attribute) + // as a parameter of a function which is bound to python, which causes + // unexpected type conversion due to the overload resolution mechanism + // https://pybind11.readthedocs.io/en/stable/advanced/cast/stl.html#c-17-library-containers + template + void SetPlainAttr(const std::string &name, const T &value) { + SetAttr(name, value); + } + void SetVarAttr(const std::string &name, VarDesc *var); void SetVarsAttr(const std::string &name, std::vector vars); diff --git a/paddle/fluid/pybind/protobuf.cc b/paddle/fluid/pybind/protobuf.cc index 6736c79d5af..ab725575351 100644 --- a/paddle/fluid/pybind/protobuf.cc +++ b/paddle/fluid/pybind/protobuf.cc @@ -286,6 +286,8 @@ void BindOpDesc(pybind11::module *m) { .value("LONGS", pd::proto::AttrType::LONGS) .value("FLOAT", pd::proto::AttrType::FLOAT) .value("FLOATS", pd::proto::AttrType::FLOATS) + // .value("FLOAT64", pd::proto::AttrType::FLOAT64) + .value("FLOAT64S", pd::proto::AttrType::FLOAT64S) .value("STRING", pd::proto::AttrType::STRING) .value("STRINGS", pd::proto::AttrType::STRINGS) .value("BOOL", pd::proto::AttrType::BOOLEAN) @@ -361,6 +363,21 @@ void BindOpDesc(pybind11::module *m) { py::arg("with_attr_var") = false) .def("_set_attr", &pd::OpDesc::SetAttr) .def("remove_attr", &pd::OpDesc::RemoveAttr) + .def("_set_bool_attr", &pd::OpDesc::SetPlainAttr) + .def("_set_int32_attr", &pd::OpDesc::SetPlainAttr) + .def("_set_int64_attr", &pd::OpDesc::SetPlainAttr) + .def("_set_float32_attr", &pd::OpDesc::SetPlainAttr) + // .def("_set_float64_attr", &pd::OpDesc::SetPlainAttr) + .def("_set_str_attr", &pd::OpDesc::SetPlainAttr) + + .def("_set_bools_attr", &pd::OpDesc::SetPlainAttr>) + .def("_set_int32s_attr", &pd::OpDesc::SetPlainAttr>) + .def("_set_int64s_attr", &pd::OpDesc::SetPlainAttr>) + .def("_set_float32s_attr", &pd::OpDesc::SetPlainAttr>) + .def("_set_float64s_attr", &pd::OpDesc::SetPlainAttr>) + .def("_set_strs_attr", + &pd::OpDesc::SetPlainAttr>) + .def( "attr", [](pd::OpDesc &self, const std::string &name, bool with_attr_var) { diff --git a/python/paddle/fluid/framework.py b/python/paddle/fluid/framework.py index edf68762328..3d7a743376c 100644 --- a/python/paddle/fluid/framework.py +++ b/python/paddle/fluid/framework.py @@ -2675,6 +2675,16 @@ class Operator(object): inputs=None, outputs=None, attrs=None): + # read attr type index from op proto to avoid unexpected type + # conversions, e.g. narrowing conversion like double to float + try: + proto = OpProtoHolder.instance().get_op_proto(type) + self._attr_types = {} + for attr in proto.attrs: + self._attr_types[attr.name] = attr.type + except ValueError: + pass + if _non_static_mode(): if type is None: raise ValueError( @@ -3159,7 +3169,42 @@ class Operator(object): isinstance(val, core.ProgramDesc): self.desc.set_serialized_attr(name, val.serialize_to_string()) else: - self.desc._set_attr(name, val) + self._update_desc_plain_attr(name, val) + + def _update_desc_plain_attr(self, name, val): + desc = self.desc + if not hasattr(self, "_attr_types") or (name not in self._attr_types): + desc._set_attr(name, val) + return + + type_index = self._attr_types[name] + if type_index == core.AttrType.BOOL: + desc._set_bool_attr(name, val) + elif type_index == core.AttrType.INT: + desc._set_int32_attr(name, val) + elif type_index == core.AttrType.LONG: + desc._set_int64_attr(name, val) + elif type_index == core.AttrType.FLOAT: + desc._set_float32_attr(name, val) + # elif type_index == core.AttrType.FLOAT64: + # desc._set_float64_attr(name, val) + elif type_index == core.AttrType.STRING: + desc._set_str_attr(name, val) + elif type_index == core.AttrType.BOOLS: + desc._set_bools_attr(name, val) + elif type_index == core.AttrType.INTS: + desc._set_int32s_attr(name, val) + elif type_index == core.AttrType.LONGS: + desc._set_int64s_attr(name, val) + elif type_index == core.AttrType.FLOATS: + desc._set_float32s_attr(name, val) + elif type_index == core.AttrType.FLOAT64S: + desc._set_float64s_attr(name, val) + elif type_index == core.AttrType.STRINGS: + desc._set_strs_attr(name, val) + else: + # defaults to old methods + desc._set_attr(name, val) @property def attr_names(self): diff --git a/python/paddle/fluid/tests/unittests/test_fold_op.py b/python/paddle/fluid/tests/unittests/test_fold_op.py index fc873cda95b..8ae0442a1e7 100644 --- a/python/paddle/fluid/tests/unittests/test_fold_op.py +++ b/python/paddle/fluid/tests/unittests/test_fold_op.py @@ -206,7 +206,7 @@ class TestFoldOpError(unittest.TestCase): self.assertRaises(AssertionError, test_dilations_shape) self.assertRaises(AssertionError, test_strides_shape) self.assertRaises(ValueError, test_output_size) - self.assertRaises(ValueError, test_output_size_2) + self.assertRaises(TypeError, test_output_size_2) self.assertRaises(ValueError, test_block_h_w) self.assertRaises(ValueError, test_GT_0) diff --git a/python/paddle/fluid/tests/unittests/test_histogram_op.py b/python/paddle/fluid/tests/unittests/test_histogram_op.py index 17b7b95942f..ccace6ebc11 100644 --- a/python/paddle/fluid/tests/unittests/test_histogram_op.py +++ b/python/paddle/fluid/tests/unittests/test_histogram_op.py @@ -111,7 +111,7 @@ class TestHistogramOpError(unittest.TestCase): value=3.0) paddle.histogram(input=input_value, bins=1, min=-np.inf, max=5) - with self.assertRaises(ValueError): + with self.assertRaises(TypeError): self.run_network(net_func) def test_type_errors(self): -- GitLab