未验证 提交 307801d5 编写于 作者: F Feiyu Chan 提交者: GitHub

add strongly typed functions to set attributes to avoid unexpected type conversions. (#45107)

上级 642f6df9
...@@ -96,6 +96,15 @@ class OpDesc { ...@@ -96,6 +96,15 @@ class OpDesc {
void SetAttr(const std::string &name, const Attribute &v); void SetAttr(const std::string &name, const Attribute &v);
void RemoveAttr(const std::string &name); void RemoveAttr(const std::string &name);
// NOTE(chenfeiyu): this template is added to avoid using a variant(Attribute)
// as a parameter of a function which is bound to python, which causes
// unexpected type conversion due to the overload resolution mechanism
// https://pybind11.readthedocs.io/en/stable/advanced/cast/stl.html#c-17-library-containers
template <typename T>
void SetPlainAttr(const std::string &name, const T &value) {
SetAttr(name, value);
}
void SetVarAttr(const std::string &name, VarDesc *var); void SetVarAttr(const std::string &name, VarDesc *var);
void SetVarsAttr(const std::string &name, std::vector<VarDesc *> vars); void SetVarsAttr(const std::string &name, std::vector<VarDesc *> vars);
......
...@@ -286,6 +286,8 @@ void BindOpDesc(pybind11::module *m) { ...@@ -286,6 +286,8 @@ void BindOpDesc(pybind11::module *m) {
.value("LONGS", pd::proto::AttrType::LONGS) .value("LONGS", pd::proto::AttrType::LONGS)
.value("FLOAT", pd::proto::AttrType::FLOAT) .value("FLOAT", pd::proto::AttrType::FLOAT)
.value("FLOATS", pd::proto::AttrType::FLOATS) .value("FLOATS", pd::proto::AttrType::FLOATS)
// .value("FLOAT64", pd::proto::AttrType::FLOAT64)
.value("FLOAT64S", pd::proto::AttrType::FLOAT64S)
.value("STRING", pd::proto::AttrType::STRING) .value("STRING", pd::proto::AttrType::STRING)
.value("STRINGS", pd::proto::AttrType::STRINGS) .value("STRINGS", pd::proto::AttrType::STRINGS)
.value("BOOL", pd::proto::AttrType::BOOLEAN) .value("BOOL", pd::proto::AttrType::BOOLEAN)
...@@ -361,6 +363,21 @@ void BindOpDesc(pybind11::module *m) { ...@@ -361,6 +363,21 @@ void BindOpDesc(pybind11::module *m) {
py::arg("with_attr_var") = false) py::arg("with_attr_var") = false)
.def("_set_attr", &pd::OpDesc::SetAttr) .def("_set_attr", &pd::OpDesc::SetAttr)
.def("remove_attr", &pd::OpDesc::RemoveAttr) .def("remove_attr", &pd::OpDesc::RemoveAttr)
.def("_set_bool_attr", &pd::OpDesc::SetPlainAttr<bool>)
.def("_set_int32_attr", &pd::OpDesc::SetPlainAttr<int>)
.def("_set_int64_attr", &pd::OpDesc::SetPlainAttr<int64_t>)
.def("_set_float32_attr", &pd::OpDesc::SetPlainAttr<float>)
// .def("_set_float64_attr", &pd::OpDesc::SetPlainAttr<double>)
.def("_set_str_attr", &pd::OpDesc::SetPlainAttr<std::string>)
.def("_set_bools_attr", &pd::OpDesc::SetPlainAttr<std::vector<bool>>)
.def("_set_int32s_attr", &pd::OpDesc::SetPlainAttr<std::vector<int>>)
.def("_set_int64s_attr", &pd::OpDesc::SetPlainAttr<std::vector<int64_t>>)
.def("_set_float32s_attr", &pd::OpDesc::SetPlainAttr<std::vector<float>>)
.def("_set_float64s_attr", &pd::OpDesc::SetPlainAttr<std::vector<double>>)
.def("_set_strs_attr",
&pd::OpDesc::SetPlainAttr<std::vector<std::string>>)
.def( .def(
"attr", "attr",
[](pd::OpDesc &self, const std::string &name, bool with_attr_var) { [](pd::OpDesc &self, const std::string &name, bool with_attr_var) {
......
...@@ -2675,6 +2675,16 @@ class Operator(object): ...@@ -2675,6 +2675,16 @@ class Operator(object):
inputs=None, inputs=None,
outputs=None, outputs=None,
attrs=None): attrs=None):
# read attr type index from op proto to avoid unexpected type
# conversions, e.g. narrowing conversion like double to float
try:
proto = OpProtoHolder.instance().get_op_proto(type)
self._attr_types = {}
for attr in proto.attrs:
self._attr_types[attr.name] = attr.type
except ValueError:
pass
if _non_static_mode(): if _non_static_mode():
if type is None: if type is None:
raise ValueError( raise ValueError(
...@@ -3159,7 +3169,42 @@ class Operator(object): ...@@ -3159,7 +3169,42 @@ class Operator(object):
isinstance(val, core.ProgramDesc): isinstance(val, core.ProgramDesc):
self.desc.set_serialized_attr(name, val.serialize_to_string()) self.desc.set_serialized_attr(name, val.serialize_to_string())
else: else:
self.desc._set_attr(name, val) self._update_desc_plain_attr(name, val)
def _update_desc_plain_attr(self, name, val):
desc = self.desc
if not hasattr(self, "_attr_types") or (name not in self._attr_types):
desc._set_attr(name, val)
return
type_index = self._attr_types[name]
if type_index == core.AttrType.BOOL:
desc._set_bool_attr(name, val)
elif type_index == core.AttrType.INT:
desc._set_int32_attr(name, val)
elif type_index == core.AttrType.LONG:
desc._set_int64_attr(name, val)
elif type_index == core.AttrType.FLOAT:
desc._set_float32_attr(name, val)
# elif type_index == core.AttrType.FLOAT64:
# desc._set_float64_attr(name, val)
elif type_index == core.AttrType.STRING:
desc._set_str_attr(name, val)
elif type_index == core.AttrType.BOOLS:
desc._set_bools_attr(name, val)
elif type_index == core.AttrType.INTS:
desc._set_int32s_attr(name, val)
elif type_index == core.AttrType.LONGS:
desc._set_int64s_attr(name, val)
elif type_index == core.AttrType.FLOATS:
desc._set_float32s_attr(name, val)
elif type_index == core.AttrType.FLOAT64S:
desc._set_float64s_attr(name, val)
elif type_index == core.AttrType.STRINGS:
desc._set_strs_attr(name, val)
else:
# defaults to old methods
desc._set_attr(name, val)
@property @property
def attr_names(self): def attr_names(self):
......
...@@ -206,7 +206,7 @@ class TestFoldOpError(unittest.TestCase): ...@@ -206,7 +206,7 @@ class TestFoldOpError(unittest.TestCase):
self.assertRaises(AssertionError, test_dilations_shape) self.assertRaises(AssertionError, test_dilations_shape)
self.assertRaises(AssertionError, test_strides_shape) self.assertRaises(AssertionError, test_strides_shape)
self.assertRaises(ValueError, test_output_size) self.assertRaises(ValueError, test_output_size)
self.assertRaises(ValueError, test_output_size_2) self.assertRaises(TypeError, test_output_size_2)
self.assertRaises(ValueError, test_block_h_w) self.assertRaises(ValueError, test_block_h_w)
self.assertRaises(ValueError, test_GT_0) self.assertRaises(ValueError, test_GT_0)
......
...@@ -111,7 +111,7 @@ class TestHistogramOpError(unittest.TestCase): ...@@ -111,7 +111,7 @@ class TestHistogramOpError(unittest.TestCase):
value=3.0) value=3.0)
paddle.histogram(input=input_value, bins=1, min=-np.inf, max=5) paddle.histogram(input=input_value, bins=1, min=-np.inf, max=5)
with self.assertRaises(ValueError): with self.assertRaises(TypeError):
self.run_network(net_func) self.run_network(net_func)
def test_type_errors(self): def test_type_errors(self):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册