diff --git a/paddle/framework/op_desc.cc b/paddle/framework/op_desc.cc index e7538b4af3429e566a439d5a0db8496efcd94969..d3c11ad60a0f9319329a59c16bfc4668cd75b7ae 100644 --- a/paddle/framework/op_desc.cc +++ b/paddle/framework/op_desc.cc @@ -211,6 +211,15 @@ static InferShapeFuncMap &InferShapeFuncs() { return *g_map; } +void OpDescBind::CheckAttrs() { + PADDLE_ENFORCE(!Type().empty(), + "CheckAttr() can not be called before type is setted."); + const auto *checker = OpInfoMap::Instance().Get(Type()).Checker(); + PADDLE_ENFORCE_NOT_NULL(checker, "Operator \"%s\" has no registered checker.", + Type()); + checker->Check(attrs_); +} + void OpDescBind::InferShape(const BlockDescBind &block) const { auto &funcs = InferShapeFuncs(); auto it = funcs.find(this->Type()); diff --git a/paddle/framework/op_desc.h b/paddle/framework/op_desc.h index 81c4225041157ac600d1db73ef2363ebcd4abfc0..90155fadeac148bd9cae4ce9066ac4ce8d9df52d 100644 --- a/paddle/framework/op_desc.h +++ b/paddle/framework/op_desc.h @@ -100,6 +100,8 @@ class OpDescBind { return &this->attrs_; } + void CheckAttrs(); + void InferShape(const BlockDescBind &block) const; private: diff --git a/paddle/pybind/protobuf.cc b/paddle/pybind/protobuf.cc index 116c99bd2c1ca59b093392f9e6cc481c089309bc..c73d064fcfd3dffd36a86c069349fdb22ec4e27e 100644 --- a/paddle/pybind/protobuf.cc +++ b/paddle/pybind/protobuf.cc @@ -199,6 +199,7 @@ void BindOpDesc(py::module &m) { .def("attr", &OpDescBind::GetAttr) .def("set_block_attr", &OpDescBind::SetBlockAttr) .def("get_block_attr", &OpDescBind::GetBlockAttr) + .def("check_attrs", &OpDescBind::CheckAttrs) .def("infer_shape", &OpDescBind::InferShape); } diff --git a/python/paddle/v2/framework/tests/test_protobuf_descs.py b/python/paddle/v2/framework/tests/test_protobuf_descs.py index 2b7ba6688a65c466d5bc656178f2991da8dfe016..3db1e79ce43b7f559c7caab8397817b76d56161e 100644 --- a/python/paddle/v2/framework/tests/test_protobuf_descs.py +++ b/python/paddle/v2/framework/tests/test_protobuf_descs.py @@ -55,6 +55,12 @@ class TestOpDesc(unittest.TestCase): op.set_block_attr("block_attr", prog.block(0)) self.assertEqual(0, op.get_block_attr("block_attr")) + mul_op = block.append_op() + mul_op.set_type("mul") + mul_op.check_attrs() + self.assertEqual(mul_op.attr("x_num_col_dims"), 1) + self.assertEqual(mul_op.attr("y_num_col_dims"), 1) + class TestProgramDesc(unittest.TestCase): def test_instance(self):