diff --git a/paddle/framework/block_desc.cc b/paddle/framework/block_desc.cc index 509aa235d3ee226adef15f08f5785866700499f1..b77d5525d4508056c9d6d487e63e500265e1d700 100644 --- a/paddle/framework/block_desc.cc +++ b/paddle/framework/block_desc.cc @@ -91,9 +91,5 @@ BlockDescBind *BlockDescBind::ParentBlock() const { return prog_->Block(static_cast(this->desc_->parent_idx())); } -void OpDescBind::SetBlockAttr(const std::string &name, BlockDescBind &block) { - BlockDesc *desc = block.RawPtr(); - this->attrs_[name] = desc; -} } // namespace framework } // namespace paddle diff --git a/paddle/framework/op_desc.cc b/paddle/framework/op_desc.cc index c2e796b7c1b6e359765bafd6cd66fa16d69897a1..5d341584af7e214af56d92ff6ef6dc16ec13fe90 100644 --- a/paddle/framework/op_desc.cc +++ b/paddle/framework/op_desc.cc @@ -97,6 +97,11 @@ void OpDescBind::SetAttr(const std::string &name, const Attribute &v) { need_update_ = true; } +void OpDescBind::SetBlockAttr(const std::string &name, BlockDescBind &block) { + BlockDesc *desc = block.RawPtr(); + this->attrs_[name] = desc; +} + void OpDescBind::SetAttrMap( const std::unordered_map &attr_map) { attrs_ = attr_map; diff --git a/paddle/pybind/protobuf.cc b/paddle/pybind/protobuf.cc index 47bd7bc3bb3f0737ba3c9efe5b49defed87f36a1..534c615a5f16fab9f38eba6932a51b4bbc03acd4 100644 --- a/paddle/pybind/protobuf.cc +++ b/paddle/pybind/protobuf.cc @@ -196,7 +196,7 @@ void BindOpDesc(py::module &m) { .def("set_attr", &OpDescBind::SetAttr) .def("attr", &OpDescBind::GetAttr) .def("set_block_attr", &OpDescBind::SetBlockAttr) - .def("get_block_attr", &OpDescBind::GetBlockAttr); + .def("block_attr", &OpDescBind::GetBlockAttr); } } // namespace pybind diff --git a/python/paddle/v2/framework/graph.py b/python/paddle/v2/framework/graph.py index 8d808556851c42a6ecd096d247e9874f997808c1..0f446c46eb26c19b0e6e9c7fd7e10d66a9e08238 100644 --- a/python/paddle/v2/framework/graph.py +++ b/python/paddle/v2/framework/graph.py @@ -32,7 +32,7 @@ class 
OpProtoHolder(object): op_protos = get_all_op_protos() self.op_proto_map = {} for proto in op_protos: - sefl.op_proto_map[proto.type] = proto + self.op_proto_map[proto.type] = proto def get_op_proto(self, type): assert type in self.op_proto_map, "Operator with type \"%s\" has not been registered." % type @@ -116,7 +116,39 @@ class Operator(object): else: self.desc.set_block_attr(attr_name, attrs[attr_name].desc) - # TODO: Getters + @property + def type(self): + return self.desc.type() + + def input(self, name): + return self.desc.input(name) + + @property + def input_names(self): + return self.desc.input_names() + + def output(self, name): + return self.desc.output(name) + + @property + def output_names(self): + return self.desc.output_names() + + def has_attr(self, name): + return self.desc.has_attr(name) + + def attr_type(self, name): + return self.desc.attr_type(name) + + @property + def attr_names(self): + return self.desc.attr_names() + + def attr(self, name): + return self.desc.attr(name) + + def block_attr(self, name): + return self.desc.block_attr(name) class Block(object): diff --git a/python/paddle/v2/framework/tests/test_operator_desc.py b/python/paddle/v2/framework/tests/test_operator_desc.py new file mode 100644 index 0000000000000000000000000000000000000000..5ee1409c8f1abdf79873f5c1bf9860cac16fee8a --- /dev/null +++ b/python/paddle/v2/framework/tests/test_operator_desc.py @@ -0,0 +1,58 @@ +import unittest +from paddle.v2.framework.graph import Variable, g_program +import paddle.v2.framework.core as core + + +class TestOperator(unittest.TestCase): + def test_error_type(self): + block = g_program.create_block() + try: + block.append_op(type="no_such_op") + self.assertFail() + except AssertionError as err: + self.assertEqual( + err.message, + "Operator with type \"no_such_op\" has not been registered.") + + def test_input_output(self): + block = g_program.current_block() + mul_x = block.create_var( + dtype="float32", shape=[5, 10], lod_level=0, 
name="mul.x") + mul_y = block.create_var( + dtype="float32", shape=[10, 8], lod_level=0, name="mul.y") + mul_out = block.create_var( + dtype="float32", shape=[5, 8], lod_level=0, name="mul.out") + mul_op = block.append_op( + type="mul", + inputs={"X": [mul_x], + "Y": mul_y}, + outputs={"Out": [mul_out]}) + self.assertEqual(mul_op.type, "mul") + self.assertEqual(mul_op.input_names, ["X", "Y"]) + self.assertEqual(mul_op.input("X"), ["mul.x"]) + self.assertEqual(mul_op.output_names, ["Out"]) + self.assertEqual(mul_op.output("Out"), ["mul.out"]) + + def test_mult_input(self): + block = g_program.current_block() + sum_x1 = block.create_var( + dtype="int", shape=[3, 4], lod_level=0, name="sum.x1") + sum_x2 = block.create_var( + dtype="int", shape=[3, 4], lod_level=0, name="sum.x2") + sum_x3 = block.create_var( + dtype="int", shape=[3, 4], lod_level=0, name="sum.x3") + sum_out = block.create_var( + dtype="int", shape=[3, 4], lod_level=0, name="sum.out") + sum_op = block.append_op( + type="sum", + inputs={"X": [sum_x1, sum_x2, sum_x3]}, + outputs={"Out": sum_out}) + self.assertEqual(sum_op.type, "sum") + self.assertEqual(sum_op.input_names, ["X"]) + self.assertEqual(sum_op.input("X"), ["sum.x1", "sum.x2", "sum.x3"]) + self.assertEqual(sum_op.output_names, ["Out"]) + self.assertEqual(sum_op.output("Out"), ["sum.out"]) + + +if __name__ == '__main__': + unittest.main()