提交 f8267db6 编写于 作者: F fengjiayi

Explose check_attr to Python

上级 6604d7cd
...@@ -211,6 +211,15 @@ static InferShapeFuncMap &InferShapeFuncs() { ...@@ -211,6 +211,15 @@ static InferShapeFuncMap &InferShapeFuncs() {
return *g_map; return *g_map;
} }
void OpDescBind::CheckAttrs() {
PADDLE_ENFORCE(!Type().empty(),
"CheckAttr() can not be called before type is setted.");
const auto *checker = OpInfoMap::Instance().Get(Type()).Checker();
PADDLE_ENFORCE_NOT_NULL(checker, "Operator \"%s\" has no registered checker.",
Type());
checker->Check(attrs_);
}
void OpDescBind::InferShape(const BlockDescBind &block) const { void OpDescBind::InferShape(const BlockDescBind &block) const {
auto &funcs = InferShapeFuncs(); auto &funcs = InferShapeFuncs();
auto it = funcs.find(this->Type()); auto it = funcs.find(this->Type());
......
...@@ -100,6 +100,8 @@ class OpDescBind { ...@@ -100,6 +100,8 @@ class OpDescBind {
return &this->attrs_; return &this->attrs_;
} }
void CheckAttrs();
void InferShape(const BlockDescBind &block) const; void InferShape(const BlockDescBind &block) const;
private: private:
......
...@@ -199,6 +199,7 @@ void BindOpDesc(py::module &m) { ...@@ -199,6 +199,7 @@ void BindOpDesc(py::module &m) {
.def("attr", &OpDescBind::GetAttr) .def("attr", &OpDescBind::GetAttr)
.def("set_block_attr", &OpDescBind::SetBlockAttr) .def("set_block_attr", &OpDescBind::SetBlockAttr)
.def("get_block_attr", &OpDescBind::GetBlockAttr) .def("get_block_attr", &OpDescBind::GetBlockAttr)
.def("check_attrs", &OpDescBind::CheckAttrs)
.def("infer_shape", &OpDescBind::InferShape); .def("infer_shape", &OpDescBind::InferShape);
} }
......
...@@ -55,6 +55,12 @@ class TestOpDesc(unittest.TestCase): ...@@ -55,6 +55,12 @@ class TestOpDesc(unittest.TestCase):
op.set_block_attr("block_attr", prog.block(0)) op.set_block_attr("block_attr", prog.block(0))
self.assertEqual(0, op.get_block_attr("block_attr")) self.assertEqual(0, op.get_block_attr("block_attr"))
mul_op = block.append_op()
mul_op.set_type("mul")
mul_op.check_attrs()
self.assertEqual(mul_op.attr("x_num_col_dims"), 1)
self.assertEqual(mul_op.attr("y_num_col_dims"), 1)
class TestProgramDesc(unittest.TestCase): class TestProgramDesc(unittest.TestCase):
def test_instance(self): def test_instance(self):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册