提交 5419f16b 编写于 作者: F fengjiayi

Add unittests

上级 f9f910a3
...@@ -99,7 +99,7 @@ template <typename T, typename RepeatedField> ...@@ -99,7 +99,7 @@ template <typename T, typename RepeatedField>
inline void VectorToRepeated(const std::vector<T> &vec, inline void VectorToRepeated(const std::vector<T> &vec,
RepeatedField *repeated_field) { RepeatedField *repeated_field) {
repeated_field->Reserve(vec.size()); repeated_field->Reserve(vec.size());
for (auto &elem : vec) { for (const auto &elem : vec) {
*repeated_field->Add() = elem; *repeated_field->Add() = elem;
} }
} }
...@@ -124,18 +124,23 @@ public: ...@@ -124,18 +124,23 @@ public:
VarDesc *Proto() { return &desc_; } VarDesc *Proto() { return &desc_; }
py::bytes Name() { return desc_.name(); }
void SetShape(const std::vector<int64_t> &dims) { void SetShape(const std::vector<int64_t> &dims) {
VectorToRepeated(dims, desc_.mutable_lod_tensor()->mutable_dims()); VectorToRepeated(dims, desc_.mutable_lod_tensor()->mutable_dims());
} }
void SetDataType(int type_id) { void SetDataType(int type_id) {
desc_.mutable_lod_tensor()->set_data_type(static_cast<DataType>(type_id)); desc_.mutable_lod_tensor()->set_data_type(
static_cast<enum DataType>(type_id));
} }
std::vector<int64_t> Shape() { std::vector<int64_t> Shape() {
return RepeatedToVector(desc_.lod_tensor().dims()); return RepeatedToVector(desc_.lod_tensor().dims());
} }
int DataType() { return desc_.lod_tensor().data_type(); }
private: private:
VarDesc desc_; VarDesc desc_;
}; };
...@@ -322,6 +327,22 @@ public: ...@@ -322,6 +327,22 @@ public:
return var; return var;
} }
VarDescBind *Var(py::bytes name_bytes) const {
std::string name = name_bytes;
auto it = vars_.find(name);
PADDLE_ENFORCE(
it != vars_.end(), "Can not find variable %s in current block.", name);
return it->second.get();
}
std::vector<VarDescBind *> AllVars() const {
std::vector<VarDescBind *> res;
for (const auto &p : vars_) {
res.push_back(p.second.get());
}
return res;
}
BlockDescBind *ParentBlock() const; BlockDescBind *ParentBlock() const;
OpDescBind *AppendOp() { OpDescBind *AppendOp() {
...@@ -336,6 +357,14 @@ public: ...@@ -336,6 +357,14 @@ public:
return ops_.front().get(); return ops_.front().get();
} }
std::vector<OpDescBind *> AllOps() const {
std::vector<OpDescBind *> res;
for (const auto &op : ops_) {
res.push_back(op.get());
}
return res;
}
void Sync() { void Sync() {
if (need_update_) { if (need_update_) {
auto &op_field = *this->desc_->mutable_ops(); auto &op_field = *this->desc_->mutable_ops();
...@@ -461,16 +490,26 @@ void BindBlockDesc(py::module &m) { ...@@ -461,16 +490,26 @@ void BindBlockDesc(py::module &m) {
.def("prepend_op", .def("prepend_op",
&BlockDescBind::PrependOp, &BlockDescBind::PrependOp,
py::return_value_policy::reference) py::return_value_policy::reference)
.def("new_var", .def(
&BlockDescBind::NewVar, "new_var", &BlockDescBind::NewVar, py::return_value_policy::reference)
.def("var", &BlockDescBind::Var, py::return_value_policy::reference)
.def("all_vars",
&BlockDescBind::AllVars,
py::return_value_policy::reference)
.def("all_ops",
&BlockDescBind::AllOps,
py::return_value_policy::reference); py::return_value_policy::reference);
} }
void BindVarDsec(py::module &m) { void BindVarDsec(py::module &m) {
py::class_<VarDescBind>(m, "VarDesc", "") py::class_<VarDescBind>(m, "VarDesc", "")
.def("name", &VarDescBind::Name, py::return_value_policy::reference)
.def("set_shape", &VarDescBind::SetShape) .def("set_shape", &VarDescBind::SetShape)
.def("set_data_type", &VarDescBind::SetDataType) .def("set_data_type", &VarDescBind::SetDataType)
.def("shape", &VarDescBind::Shape); .def("shape", &VarDescBind::Shape, py::return_value_policy::reference)
.def("data_type",
&VarDescBind::DataType,
py::return_value_policy::reference);
} }
void BindOpDesc(py::module &m) { void BindOpDesc(py::module &m) {
......
...@@ -57,7 +57,7 @@ class TestOpDesc(unittest.TestCase): ...@@ -57,7 +57,7 @@ class TestOpDesc(unittest.TestCase):
class TestProgramDesc(unittest.TestCase): class TestProgramDesc(unittest.TestCase):
def test_instance(self): def test_instance(self):
program_desc = core.ProgramDesc.instance() program_desc = core.ProgramDesc.__create_program_desc__()
self.assertIsNotNone(program_desc) self.assertIsNotNone(program_desc)
del program_desc del program_desc
program_desc = core.ProgramDesc.instance() program_desc = core.ProgramDesc.instance()
...@@ -84,7 +84,7 @@ class TestProgramDesc(unittest.TestCase): ...@@ -84,7 +84,7 @@ class TestProgramDesc(unittest.TestCase):
class TestVarDesc(unittest.TestCase): class TestVarDesc(unittest.TestCase):
def test_shape(self): def test_shape(self):
program_desc = core.ProgramDesc.instance() program_desc = core.ProgramDesc.__create_program_desc__()
block = program_desc.block(0) block = program_desc.block(0)
var = block.new_var('my_var') var = block.new_var('my_var')
src_shape = [3, 2, 10, 8] src_shape = [3, 2, 10, 8]
...@@ -92,6 +92,39 @@ class TestVarDesc(unittest.TestCase): ...@@ -92,6 +92,39 @@ class TestVarDesc(unittest.TestCase):
res_shape = var.shape() res_shape = var.shape()
self.assertEqual(src_shape, res_shape) self.assertEqual(src_shape, res_shape)
def test_data_type(self):
program_desc = core.ProgramDesc.__create_program_desc__()
block = program_desc.block(0)
var = block.new_var('my_var')
var.set_data_type(2)
self.assertEqual(2, var.data_type)
class TestBlockDesc(unittest.TestCase):
def test_add_var(self):
prog = core.ProgramDesc.__create_program_desc__()
self.assertIsNotNone(prog)
block = prog.block(0)
self.assertIsNotNone(block)
var1 = block.new_var("var1")
var2 = block.new_var("var2")
var3 = block.new_var("var3")
all_vars = block.all_vars()
self.assertEqual(set(all_vars), set([var1, var2, var3]))
var2_re = block.var("var2")
self.assertEqual(var2_re, var2)
def test_add_op(self):
prog = core.ProgramDesc.__create_program_desc__()
self.assertIsNotNone(prog)
block = prog.block(0)
self.assertIsNotNone(block)
op1 = block.append_op()
op2 = block.append_op()
op0 = block.prepend_op()
all_ops = block.all_ops()
self.assertEqual(all_ops, [op0, op1, op2])
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册