提交 247fb2a0 编写于 作者: F fengjiayi

Add unittests

上级 cffca923
...@@ -91,9 +91,5 @@ BlockDescBind *BlockDescBind::ParentBlock() const { ...@@ -91,9 +91,5 @@ BlockDescBind *BlockDescBind::ParentBlock() const {
return prog_->Block(static_cast<size_t>(this->desc_->parent_idx())); return prog_->Block(static_cast<size_t>(this->desc_->parent_idx()));
} }
void OpDescBind::SetBlockAttr(const std::string &name, BlockDescBind &block) {
BlockDesc *desc = block.RawPtr();
this->attrs_[name] = desc;
}
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
...@@ -97,6 +97,11 @@ void OpDescBind::SetAttr(const std::string &name, const Attribute &v) { ...@@ -97,6 +97,11 @@ void OpDescBind::SetAttr(const std::string &name, const Attribute &v) {
need_update_ = true; need_update_ = true;
} }
void OpDescBind::SetBlockAttr(const std::string &name, BlockDescBind &block) {
BlockDesc *desc = block.RawPtr();
this->attrs_[name] = desc;
}
void OpDescBind::SetAttrMap( void OpDescBind::SetAttrMap(
const std::unordered_map<std::string, Attribute> &attr_map) { const std::unordered_map<std::string, Attribute> &attr_map) {
attrs_ = attr_map; attrs_ = attr_map;
......
...@@ -196,7 +196,7 @@ void BindOpDesc(py::module &m) { ...@@ -196,7 +196,7 @@ void BindOpDesc(py::module &m) {
.def("set_attr", &OpDescBind::SetAttr) .def("set_attr", &OpDescBind::SetAttr)
.def("attr", &OpDescBind::GetAttr) .def("attr", &OpDescBind::GetAttr)
.def("set_block_attr", &OpDescBind::SetBlockAttr) .def("set_block_attr", &OpDescBind::SetBlockAttr)
.def("get_block_attr", &OpDescBind::GetBlockAttr); .def("block_attr", &OpDescBind::GetBlockAttr);
} }
} // namespace pybind } // namespace pybind
......
...@@ -32,7 +32,7 @@ class OpProtoHolder(object): ...@@ -32,7 +32,7 @@ class OpProtoHolder(object):
op_protos = get_all_op_protos() op_protos = get_all_op_protos()
self.op_proto_map = {} self.op_proto_map = {}
for proto in op_protos: for proto in op_protos:
sefl.op_proto_map[proto.type] = proto self.op_proto_map[proto.type] = proto
def get_op_proto(self, type): def get_op_proto(self, type):
assert type in self.op_proto_map, "Operator with type \"%s\" has not been registered." % type assert type in self.op_proto_map, "Operator with type \"%s\" has not been registered." % type
...@@ -116,7 +116,39 @@ class Operator(object): ...@@ -116,7 +116,39 @@ class Operator(object):
else: else:
self.desc.set_block_attr(attr_name, attrs[attr_name].desc) self.desc.set_block_attr(attr_name, attrs[attr_name].desc)
# TODO: Getters @property
def type(self):
return self.desc.type()
def input(self, name):
return self.desc.input(name)
@property
def input_names(self):
return self.desc.input_names()
def output(self, name):
return self.desc.output(name)
@property
def output_names(self):
return self.desc.output_names()
def has_attr(self, name):
return self.desc.has_attr(name)
def attr_type(self, name):
return self.desc.attr_type(name)
@property
def attr_names(self):
return self.desc.attr_names()
def attr(self, name):
return self.desc.attr(name)
def block_attr(self, name):
return self.desc.block_attr(name)
class Block(object): class Block(object):
......
import unittest
from paddle.v2.framework.graph import Variable, g_program
import paddle.v2.framework.core as core
class TestOperator(unittest.TestCase):
def test_error_type(self):
block = g_program.create_block()
try:
block.append_op(type="no_such_op")
self.assertFail()
except AssertionError as err:
self.assertEqual(
err.message,
"Operator with type \"no_such_op\" has not been registered.")
def test_input_output(self):
block = g_program.current_block()
mul_x = block.create_var(
dtype="float32", shape=[5, 10], lod_level=0, name="mul.x")
mul_y = block.create_var(
dtype="float32", shape=[10, 8], lod_level=0, name="mul.y")
mul_out = block.create_var(
dtype="float32", shape=[5, 8], lod_level=0, name="mul.out")
mul_op = block.append_op(
type="mul",
inputs={"X": [mul_x],
"Y": mul_y},
outputs={"Out": [mul_out]})
self.assertEqual(mul_op.type, "mul")
self.assertEqual(mul_op.input_names, ["X", "Y"])
self.assertEqual(mul_op.input("X"), ["x"])
self.assertEqual(mul_op.output_names, ["Out"])
self.assertEqual(mul_op.output("Out"), ["out"])
def test_mult_input(self):
block = g_program.current_block()
sum_x1 = block.create_var(
dtype="int", shape=[3, 4], lod_level=0, name="sum.x1")
sum_x2 = block.create_var(
dtype="int", shape=[3, 4], lod_level=0, name="sum.x2")
sum_x3 = block.create_var(
dtype="int", shape=[3, 4], lod_level=0, name="sum.x3")
sum_out = block.create_var(
dtype="int", shape=[3, 4], lod_level=0, name="sum.out")
sum_op = block.append_op(
type="sum",
inputs={"X": [sum_x1, sum_x2, sum_x3]},
outputs={"Out": sum_out})
self.assertEqual(sum_op.type, "sum")
self.assertEqual(sum_op.input_names, ["X"])
self.assertEqual(sum_op.input("X"), ["x1", "x2", "x3"])
self.assertEqual(sum_op.output_names, ["Out"])
self.assertEqual(sum_op.output("Out"), ["out"])
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册