From f8267db65714885ec240442877740b93a8074856 Mon Sep 17 00:00:00 2001 From: fengjiayi Date: Tue, 10 Oct 2017 20:26:36 -0700 Subject: [PATCH] Explose check_attr to Python --- paddle/framework/op_desc.cc | 9 +++++++++ paddle/framework/op_desc.h | 2 ++ paddle/pybind/protobuf.cc | 1 + python/paddle/v2/framework/tests/test_protobuf_descs.py | 6 ++++++ 4 files changed, 18 insertions(+) diff --git a/paddle/framework/op_desc.cc b/paddle/framework/op_desc.cc index e7538b4af34..d3c11ad60a0 100644 --- a/paddle/framework/op_desc.cc +++ b/paddle/framework/op_desc.cc @@ -211,6 +211,15 @@ static InferShapeFuncMap &InferShapeFuncs() { return *g_map; } +void OpDescBind::CheckAttrs() { + PADDLE_ENFORCE(!Type().empty(), + "CheckAttr() can not be called before type is setted."); + const auto *checker = OpInfoMap::Instance().Get(Type()).Checker(); + PADDLE_ENFORCE_NOT_NULL(checker, "Operator \"%s\" has no registered checker.", + Type()); + checker->Check(attrs_); +} + void OpDescBind::InferShape(const BlockDescBind &block) const { auto &funcs = InferShapeFuncs(); auto it = funcs.find(this->Type()); diff --git a/paddle/framework/op_desc.h b/paddle/framework/op_desc.h index 81c42250411..90155fadeac 100644 --- a/paddle/framework/op_desc.h +++ b/paddle/framework/op_desc.h @@ -100,6 +100,8 @@ class OpDescBind { return &this->attrs_; } + void CheckAttrs(); + void InferShape(const BlockDescBind &block) const; private: diff --git a/paddle/pybind/protobuf.cc b/paddle/pybind/protobuf.cc index 116c99bd2c1..c73d064fcfd 100644 --- a/paddle/pybind/protobuf.cc +++ b/paddle/pybind/protobuf.cc @@ -199,6 +199,7 @@ void BindOpDesc(py::module &m) { .def("attr", &OpDescBind::GetAttr) .def("set_block_attr", &OpDescBind::SetBlockAttr) .def("get_block_attr", &OpDescBind::GetBlockAttr) + .def("check_attrs", &OpDescBind::CheckAttrs) .def("infer_shape", &OpDescBind::InferShape); } diff --git a/python/paddle/v2/framework/tests/test_protobuf_descs.py b/python/paddle/v2/framework/tests/test_protobuf_descs.py index 2b7ba6688a6..3db1e79ce43 100644 --- a/python/paddle/v2/framework/tests/test_protobuf_descs.py +++ b/python/paddle/v2/framework/tests/test_protobuf_descs.py @@ -55,6 +55,12 @@ class TestOpDesc(unittest.TestCase): op.set_block_attr("block_attr", prog.block(0)) self.assertEqual(0, op.get_block_attr("block_attr")) + mul_op = block.append_op() + mul_op.set_type("mul") + mul_op.check_attrs() + self.assertEqual(mul_op.attr("x_num_col_dims"), 1) + self.assertEqual(mul_op.attr("y_num_col_dims"), 1) + class TestProgramDesc(unittest.TestCase): def test_instance(self): -- GitLab