From b414645a6581b29912288de4942b6973f5431d67 Mon Sep 17 00:00:00 2001 From: 123malin Date: Fri, 12 Jul 2019 18:07:09 +0800 Subject: [PATCH] =?UTF-8?q?fix=20#17430:=20int64=E7=B1=BB=E5=9E=8B?= =?UTF-8?q?=E7=9A=84attr=E8=AE=AD=E7=BB=83=E9=9D=9E=E9=A2=84=E6=9C=9F=20(#?= =?UTF-8?q?18264)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * fix int64_t * update fill constant op unittest * add empty line --- paddle/fluid/framework/attribute.h | 26 +++++++++++++++++ paddle/fluid/pybind/pybind_boost_headers.h | 9 ++++++ .../tests/unittests/test_fill_constant_op.py | 28 +++++++++++++++++++ 3 files changed, 63 insertions(+) diff --git a/paddle/fluid/framework/attribute.h b/paddle/fluid/framework/attribute.h index 67054eccb33..aa452ac220e 100644 --- a/paddle/fluid/framework/attribute.h +++ b/paddle/fluid/framework/attribute.h @@ -133,6 +133,32 @@ struct ExtractAttribute> { const std::string& attr_name_; }; +template <> +struct ExtractAttribute { + explicit ExtractAttribute(const std::string& attr_name) + : attr_name_(attr_name) {} + + float* operator()(Attribute& attr) const { + if (attr.type() == typeid(int)) { // NOLINT + int val = boost::get(attr); + attr = static_cast(val); + } else if (attr.type() == typeid(int64_t)) { // NOLINT + int64_t val = boost::get(attr); + attr = static_cast(val); + } + float* attr_value = nullptr; + try { + attr_value = &boost::get(attr); + } catch (boost::bad_get& bad_get) { + PADDLE_THROW("Cannot get attribute %s by type float, its type is %s", + attr_name_, paddle::platform::demangle(attr.type().name())); + } + return attr_value; + } + + const std::string& attr_name_; +}; + template inline proto::AttrType AttrTypeID() { Attribute tmp = T(); diff --git a/paddle/fluid/pybind/pybind_boost_headers.h b/paddle/fluid/pybind/pybind_boost_headers.h index 70c3136d095..3eb4db175a7 100644 --- a/paddle/fluid/pybind/pybind_boost_headers.h +++ b/paddle/fluid/pybind/pybind_boost_headers.h @@ -77,6 +77,15 @@ struct paddle_variant_caster> { } } + if (std::is_same::value) { + auto caster_int64 = make_caster(); + if (caster_int64.load(src, convert)) { + VLOG(4) << "this value are float and int64 satisfy simula."; + value = cast_op(caster_int64); + return true; + } + } + value = cast_op(caster); return true; } diff --git a/python/paddle/fluid/tests/unittests/test_fill_constant_op.py b/python/paddle/fluid/tests/unittests/test_fill_constant_op.py index fd59c5bb7cf..e22bd09ed06 100644 --- a/python/paddle/fluid/tests/unittests/test_fill_constant_op.py +++ b/python/paddle/fluid/tests/unittests/test_fill_constant_op.py @@ -50,6 +50,34 @@ class TestFillConstantOp2(OpTest): self.check_output() +class TestFillConstantOp3(OpTest): + def setUp(self): + '''Test fill_constant op with specified int64 value + ''' + self.op_type = "fill_constant" + + self.inputs = {} + self.attrs = {'shape': [123, 92], 'value': 10000000000} + self.outputs = {'Out': np.full((123, 92), 10000000000)} + + def test_check_output(self): + self.check_output() + + +class TestFillConstantOp4(OpTest): + def setUp(self): + '''Test fill_constant op with specified int value + ''' + self.op_type = "fill_constant" + + self.inputs = {} + self.attrs = {'shape': [123, 92], 'value': 3} + self.outputs = {'Out': np.full((123, 92), 3)} + + def test_check_output(self): + self.check_output() + + class TestFillConstantOpWithSelectedRows(OpTest): def check_with_place(self, place): scope = core.Scope() -- GitLab