提交 5419f16b 编写于 作者: F fengjiayi

Add unittests

上级 f9f910a3
......@@ -99,7 +99,7 @@ template <typename T, typename RepeatedField>
inline void VectorToRepeated(const std::vector<T> &vec,
RepeatedField *repeated_field) {
repeated_field->Reserve(vec.size());
for (auto &elem : vec) {
for (const auto &elem : vec) {
*repeated_field->Add() = elem;
}
}
......@@ -124,18 +124,23 @@ public:
VarDesc *Proto() { return &desc_; }
py::bytes Name() { return desc_.name(); }
void SetShape(const std::vector<int64_t> &dims) {
VectorToRepeated(dims, desc_.mutable_lod_tensor()->mutable_dims());
}
void SetDataType(int type_id) {
desc_.mutable_lod_tensor()->set_data_type(static_cast<DataType>(type_id));
desc_.mutable_lod_tensor()->set_data_type(
static_cast<enum DataType>(type_id));
}
std::vector<int64_t> Shape() {
return RepeatedToVector(desc_.lod_tensor().dims());
}
int DataType() { return desc_.lod_tensor().data_type(); }
private:
VarDesc desc_;
};
......@@ -322,6 +327,22 @@ public:
return var;
}
VarDescBind *Var(py::bytes name_bytes) const {
std::string name = name_bytes;
auto it = vars_.find(name);
PADDLE_ENFORCE(
it != vars_.end(), "Can not find variable %s in current block.", name);
return it->second.get();
}
std::vector<VarDescBind *> AllVars() const {
std::vector<VarDescBind *> res;
for (const auto &p : vars_) {
res.push_back(p.second.get());
}
return res;
}
BlockDescBind *ParentBlock() const;
OpDescBind *AppendOp() {
......@@ -336,6 +357,14 @@ public:
return ops_.front().get();
}
std::vector<OpDescBind *> AllOps() const {
std::vector<OpDescBind *> res;
for (const auto &op : ops_) {
res.push_back(op.get());
}
return res;
}
void Sync() {
if (need_update_) {
auto &op_field = *this->desc_->mutable_ops();
......@@ -461,16 +490,26 @@ void BindBlockDesc(py::module &m) {
.def("prepend_op",
&BlockDescBind::PrependOp,
py::return_value_policy::reference)
.def("new_var",
&BlockDescBind::NewVar,
.def(
"new_var", &BlockDescBind::NewVar, py::return_value_policy::reference)
.def("var", &BlockDescBind::Var, py::return_value_policy::reference)
.def("all_vars",
&BlockDescBind::AllVars,
py::return_value_policy::reference)
.def("all_ops",
&BlockDescBind::AllOps,
py::return_value_policy::reference);
}
void BindVarDsec(py::module &m) {
py::class_<VarDescBind>(m, "VarDesc", "")
.def("name", &VarDescBind::Name, py::return_value_policy::reference)
.def("set_shape", &VarDescBind::SetShape)
.def("set_data_type", &VarDescBind::SetDataType)
.def("shape", &VarDescBind::Shape);
.def("shape", &VarDescBind::Shape, py::return_value_policy::reference)
.def("data_type",
&VarDescBind::DataType,
py::return_value_policy::reference);
}
void BindOpDesc(py::module &m) {
......
......@@ -57,7 +57,7 @@ class TestOpDesc(unittest.TestCase):
class TestProgramDesc(unittest.TestCase):
def test_instance(self):
program_desc = core.ProgramDesc.instance()
program_desc = core.ProgramDesc.__create_program_desc__()
self.assertIsNotNone(program_desc)
del program_desc
program_desc = core.ProgramDesc.instance()
......@@ -84,7 +84,7 @@ class TestProgramDesc(unittest.TestCase):
class TestVarDesc(unittest.TestCase):
def test_shape(self):
program_desc = core.ProgramDesc.instance()
program_desc = core.ProgramDesc.__create_program_desc__()
block = program_desc.block(0)
var = block.new_var('my_var')
src_shape = [3, 2, 10, 8]
......@@ -92,6 +92,39 @@ class TestVarDesc(unittest.TestCase):
res_shape = var.shape()
self.assertEqual(src_shape, res_shape)
def test_data_type(self):
program_desc = core.ProgramDesc.__create_program_desc__()
block = program_desc.block(0)
var = block.new_var('my_var')
var.set_data_type(2)
self.assertEqual(2, var.data_type)
class TestBlockDesc(unittest.TestCase):
def test_add_var(self):
prog = core.ProgramDesc.__create_program_desc__()
self.assertIsNotNone(prog)
block = prog.block(0)
self.assertIsNotNone(block)
var1 = block.new_var("var1")
var2 = block.new_var("var2")
var3 = block.new_var("var3")
all_vars = block.all_vars()
self.assertEqual(set(all_vars), set([var1, var2, var3]))
var2_re = block.var("var2")
self.assertEqual(var2_re, var2)
def test_add_op(self):
prog = core.ProgramDesc.__create_program_desc__()
self.assertIsNotNone(prog)
block = prog.block(0)
self.assertIsNotNone(block)
op1 = block.append_op()
op2 = block.append_op()
op0 = block.prepend_op()
all_ops = block.all_ops()
self.assertEqual(all_ops, [op0, op1, op2])
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册