diff --git a/paddle/fluid/framework/attribute.h b/paddle/fluid/framework/attribute.h index 67054eccb3397ea40f0fb3e2ff2530ee1ea64736..aa452ac220ea63bbf7a79c09b90aadfd2764856b 100644 --- a/paddle/fluid/framework/attribute.h +++ b/paddle/fluid/framework/attribute.h @@ -133,6 +133,32 @@ struct ExtractAttribute> { const std::string& attr_name_; }; +template <> +struct ExtractAttribute { + explicit ExtractAttribute(const std::string& attr_name) + : attr_name_(attr_name) {} + + float* operator()(Attribute& attr) const { + if (attr.type() == typeid(int)) { // NOLINT + int val = boost::get(attr); + attr = static_cast(val); + } else if (attr.type() == typeid(int64_t)) { // NOLINT + int64_t val = boost::get(attr); + attr = static_cast(val); + } + float* attr_value = nullptr; + try { + attr_value = &boost::get(attr); + } catch (boost::bad_get& bad_get) { + PADDLE_THROW("Cannot get attribute %s by type float, its type is %s", + attr_name_, paddle::platform::demangle(attr.type().name())); + } + return attr_value; + } + + const std::string& attr_name_; +}; + template inline proto::AttrType AttrTypeID() { Attribute tmp = T(); diff --git a/paddle/fluid/pybind/pybind_boost_headers.h b/paddle/fluid/pybind/pybind_boost_headers.h index 70c3136d095fbdcf27d6fec0b0b17140a3ee82ee..3eb4db175a745c8ea7a3afaff919e4f21d430a8b 100644 --- a/paddle/fluid/pybind/pybind_boost_headers.h +++ b/paddle/fluid/pybind/pybind_boost_headers.h @@ -77,6 +77,15 @@ struct paddle_variant_caster> { } } + if (std::is_same::value) { + auto caster_int64 = make_caster(); + if (caster_int64.load(src, convert)) { + VLOG(4) << "this value are float and int64 satisfy simula."; + value = cast_op(caster_int64); + return true; + } + } + value = cast_op(caster); return true; } diff --git a/python/paddle/fluid/tests/unittests/test_fill_constant_op.py b/python/paddle/fluid/tests/unittests/test_fill_constant_op.py index fd59c5bb7cff5dd33fae284ba3efe04e667ed75a..e22bd09ed06a5dc2385006498a7794a70c776de8 100644 --- a/python/paddle/fluid/tests/unittests/test_fill_constant_op.py +++ b/python/paddle/fluid/tests/unittests/test_fill_constant_op.py @@ -50,6 +50,34 @@ class TestFillConstantOp2(OpTest): self.check_output() +class TestFillConstantOp3(OpTest): + def setUp(self): + '''Test fill_constant op with specified int64 value + ''' + self.op_type = "fill_constant" + + self.inputs = {} + self.attrs = {'shape': [123, 92], 'value': 10000000000} + self.outputs = {'Out': np.full((123, 92), 10000000000)} + + def test_check_output(self): + self.check_output() + + +class TestFillConstantOp4(OpTest): + def setUp(self): + '''Test fill_constant op with specified int value + ''' + self.op_type = "fill_constant" + + self.inputs = {} + self.attrs = {'shape': [123, 92], 'value': 3} + self.outputs = {'Out': np.full((123, 92), 3)} + + def test_check_output(self): + self.check_output() + + class TestFillConstantOpWithSelectedRows(OpTest): def check_with_place(self, place): scope = core.Scope()