From f2522e91c43c1cf36d661d38b07ecd31f7050f40 Mon Sep 17 00:00:00 2001 From: Zhen Wang Date: Tue, 14 Jan 2020 09:56:59 +0800 Subject: [PATCH] fix the type error caused by setting bool attr in OpDesc. test=develop (#22257) --- paddle/fluid/framework/op_desc.cc | 8 ++++ .../tests/unittests/test_set_bool_attr.py | 39 +++++++++++++++++++ 2 files changed, 47 insertions(+) create mode 100644 python/paddle/fluid/tests/unittests/test_set_bool_attr.py diff --git a/paddle/fluid/framework/op_desc.cc b/paddle/fluid/framework/op_desc.cc index 69823126ad..87a99afc9a 100644 --- a/paddle/fluid/framework/op_desc.cc +++ b/paddle/fluid/framework/op_desc.cc @@ -469,6 +469,14 @@ void OpDesc::SetAttr(const std::string &name, const Attribute &v) { return; } + // In order to set bool attr properly + if (attr_type == proto::AttrType::INT && HasProtoAttr(name) && + GetProtoAttr(name).type() == proto::AttrType::BOOLEAN) { + this->attrs_[name] = static_cast(boost::get(v)); + need_update_ = true; + return; + } + this->attrs_[name] = v; need_update_ = true; } diff --git a/python/paddle/fluid/tests/unittests/test_set_bool_attr.py b/python/paddle/fluid/tests/unittests/test_set_bool_attr.py new file mode 100644 index 0000000000..827f63fe82 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_set_bool_attr.py @@ -0,0 +1,39 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import paddle.fluid as fluid +import unittest + + +class TestAttrSet(unittest.TestCase): + def test_set_bool_attr(self): + x = fluid.layers.data(name='x', shape=[3, 7, 3, 7], dtype='float32') + param_attr = fluid.ParamAttr( + name='batch_norm_w', + initializer=fluid.initializer.Constant(value=1.0)) + bias_attr = fluid.ParamAttr( + name='batch_norm_b', + initializer=fluid.initializer.Constant(value=0.0)) + bn = fluid.layers.batch_norm( + input=x, param_attr=param_attr, bias_attr=bias_attr) + block = fluid.default_main_program().desc.block(0) + op = block.op(0) + before_type = op.attr_type('is_test') + op._set_attr('is_test', True) + after_type = op.attr_type('is_test') + self.assertEqual(before_type, after_type) + + +if __name__ == '__main__': + unittest.main() -- GitLab