未验证 提交 307801d5 编写于 作者: F Feiyu Chan 提交者: GitHub

add strongly typed functions to set attributes to avoid unexpected type conversions. (#45107)

上级 642f6df9
......@@ -96,6 +96,15 @@ class OpDesc {
void SetAttr(const std::string &name, const Attribute &v);
void RemoveAttr(const std::string &name);
// NOTE(chenfeiyu): this template is added to avoid using a variant(Attribute)
// as a parameter of a function which is bound to python, which causes
// unexpected type conversion due to the overload resolution mechanism
// https://pybind11.readthedocs.io/en/stable/advanced/cast/stl.html#c-17-library-containers
template <typename T>
void SetPlainAttr(const std::string &name, const T &value) {
SetAttr(name, value);
}
void SetVarAttr(const std::string &name, VarDesc *var);
void SetVarsAttr(const std::string &name, std::vector<VarDesc *> vars);
......
......@@ -286,6 +286,8 @@ void BindOpDesc(pybind11::module *m) {
.value("LONGS", pd::proto::AttrType::LONGS)
.value("FLOAT", pd::proto::AttrType::FLOAT)
.value("FLOATS", pd::proto::AttrType::FLOATS)
// .value("FLOAT64", pd::proto::AttrType::FLOAT64)
.value("FLOAT64S", pd::proto::AttrType::FLOAT64S)
.value("STRING", pd::proto::AttrType::STRING)
.value("STRINGS", pd::proto::AttrType::STRINGS)
.value("BOOL", pd::proto::AttrType::BOOLEAN)
......@@ -361,6 +363,21 @@ void BindOpDesc(pybind11::module *m) {
py::arg("with_attr_var") = false)
.def("_set_attr", &pd::OpDesc::SetAttr)
.def("remove_attr", &pd::OpDesc::RemoveAttr)
.def("_set_bool_attr", &pd::OpDesc::SetPlainAttr<bool>)
.def("_set_int32_attr", &pd::OpDesc::SetPlainAttr<int>)
.def("_set_int64_attr", &pd::OpDesc::SetPlainAttr<int64_t>)
.def("_set_float32_attr", &pd::OpDesc::SetPlainAttr<float>)
// .def("_set_float64_attr", &pd::OpDesc::SetPlainAttr<double>)
.def("_set_str_attr", &pd::OpDesc::SetPlainAttr<std::string>)
.def("_set_bools_attr", &pd::OpDesc::SetPlainAttr<std::vector<bool>>)
.def("_set_int32s_attr", &pd::OpDesc::SetPlainAttr<std::vector<int>>)
.def("_set_int64s_attr", &pd::OpDesc::SetPlainAttr<std::vector<int64_t>>)
.def("_set_float32s_attr", &pd::OpDesc::SetPlainAttr<std::vector<float>>)
.def("_set_float64s_attr", &pd::OpDesc::SetPlainAttr<std::vector<double>>)
.def("_set_strs_attr",
&pd::OpDesc::SetPlainAttr<std::vector<std::string>>)
.def(
"attr",
[](pd::OpDesc &self, const std::string &name, bool with_attr_var) {
......
......@@ -2675,6 +2675,16 @@ class Operator(object):
inputs=None,
outputs=None,
attrs=None):
# read attr type index from op proto to avoid unexpected type
# conversions, e.g. narrowing conversion like double to float
try:
proto = OpProtoHolder.instance().get_op_proto(type)
self._attr_types = {}
for attr in proto.attrs:
self._attr_types[attr.name] = attr.type
except ValueError:
pass
if _non_static_mode():
if type is None:
raise ValueError(
......@@ -3159,7 +3169,42 @@ class Operator(object):
isinstance(val, core.ProgramDesc):
self.desc.set_serialized_attr(name, val.serialize_to_string())
else:
self.desc._set_attr(name, val)
self._update_desc_plain_attr(name, val)
def _update_desc_plain_attr(self, name, val):
desc = self.desc
if not hasattr(self, "_attr_types") or (name not in self._attr_types):
desc._set_attr(name, val)
return
type_index = self._attr_types[name]
if type_index == core.AttrType.BOOL:
desc._set_bool_attr(name, val)
elif type_index == core.AttrType.INT:
desc._set_int32_attr(name, val)
elif type_index == core.AttrType.LONG:
desc._set_int64_attr(name, val)
elif type_index == core.AttrType.FLOAT:
desc._set_float32_attr(name, val)
# elif type_index == core.AttrType.FLOAT64:
# desc._set_float64_attr(name, val)
elif type_index == core.AttrType.STRING:
desc._set_str_attr(name, val)
elif type_index == core.AttrType.BOOLS:
desc._set_bools_attr(name, val)
elif type_index == core.AttrType.INTS:
desc._set_int32s_attr(name, val)
elif type_index == core.AttrType.LONGS:
desc._set_int64s_attr(name, val)
elif type_index == core.AttrType.FLOATS:
desc._set_float32s_attr(name, val)
elif type_index == core.AttrType.FLOAT64S:
desc._set_float64s_attr(name, val)
elif type_index == core.AttrType.STRINGS:
desc._set_strs_attr(name, val)
else:
# defaults to old methods
desc._set_attr(name, val)
@property
def attr_names(self):
......
......@@ -206,7 +206,7 @@ class TestFoldOpError(unittest.TestCase):
self.assertRaises(AssertionError, test_dilations_shape)
self.assertRaises(AssertionError, test_strides_shape)
self.assertRaises(ValueError, test_output_size)
self.assertRaises(ValueError, test_output_size_2)
self.assertRaises(TypeError, test_output_size_2)
self.assertRaises(ValueError, test_block_h_w)
self.assertRaises(ValueError, test_GT_0)
......
......@@ -111,7 +111,7 @@ class TestHistogramOpError(unittest.TestCase):
value=3.0)
paddle.histogram(input=input_value, bins=1, min=-np.inf, max=5)
with self.assertRaises(ValueError):
with self.assertRaises(TypeError):
self.run_network(net_func)
def test_type_errors(self):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册