提交 e8cad5a1 编写于 作者: Q qiaolongfei

add more unit test for test_append_backward

上级 2e554693
...@@ -117,7 +117,7 @@ void BindProgramDesc(py::module &m) { ...@@ -117,7 +117,7 @@ void BindProgramDesc(py::module &m) {
py::return_value_policy::reference) py::return_value_policy::reference)
.def("append_block", &ProgramDescBind::AppendBlock, .def("append_block", &ProgramDescBind::AppendBlock,
py::return_value_policy::reference) py::return_value_policy::reference)
.def("backward", .def("append_backward",
[](ProgramDescBind &program_desc, [](ProgramDescBind &program_desc,
const std::unordered_set<std::string> &no_grad_vars) { const std::unordered_set<std::string> &no_grad_vars) {
AppendBackward(program_desc, no_grad_vars); AppendBackward(program_desc, no_grad_vars);
......
...@@ -33,20 +33,33 @@ class TestProgram(unittest.TestCase): ...@@ -33,20 +33,33 @@ class TestProgram(unittest.TestCase):
self.assertEqual(1, b.idx) self.assertEqual(1, b.idx)
self.assertEqual(0, b.parent_idx) self.assertEqual(0, b.parent_idx)
def test_backward(self): def test_append_backward(self):
prog = core.ProgramDesc.__create_program_desc__() prog = core.ProgramDesc.__create_program_desc__()
self.assertIsNotNone(prog) self.assertIsNotNone(prog)
block = prog.block(0) block = prog.block(0)
self.assertIsNotNone(block) self.assertIsNotNone(block)
mul_op_desc = block.append_op()
mul_op_desc.set_type("mul")
mul_op_desc.set_input("X", ["x1"])
mul_op_desc.set_input("Y", ["y1"])
mul_op_desc.set_output("Out", ["out1"])
sum_op_desc = block.append_op() sum_op_desc = block.append_op()
sum_op_desc.set_type("sum") sum_op_desc.set_type("elementwise_add")
sum_op_desc.set_input("X", ["x1", "x2"]) sum_op_desc.set_input("X", ["out1"])
sum_op_desc.set_output("Out", ["out"]) sum_op_desc.set_input("Y", ["b1"])
sum_op_desc.set_output("Out", ["out2"])
self.assertEqual(len(block.all_ops()), 1) expect_ops = [
prog.backward(set()) "mul", "elementwise_add", "elementwise_add_grad", "mul_grad"
self.assertEqual(len(block.all_ops()), 3) ]
actual_ops = []
prog.append_backward(set())
for op in block.all_ops():
actual_ops.append(op.type())
print(actual_ops)
self.assertEqual(actual_ops, expect_ops)
if __name__ == '__main__': if __name__ == '__main__':
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册