# test_net.py (826 bytes) - committed by Yu Yang.
import paddle.v2.framework.core as core
from paddle.v2.framework.create_op_creation_methods import op_creations
import unittest


class TestNet(unittest.TestCase):
    """Checks the textual dump produced by ``str()`` on a nested ``core.Net``."""

    def test_net_all(self):
        """Dump an outer net that wraps an ``add_two`` op and an inner ``fc`` net.

        The inner net is completed before it is appended to the outer net; the
        outer net is completed last, and its string form is compared with the
        expected indented listing (which ends with a trailing newline).
        """
        # Outer net: starts with a single add_two op.
        outer = core.Net.create()
        outer.add_op(op_creations.add_two(X="X", Y="Y", Out="Out"))

        # Inner net: a lone fc op, completed before being nested.
        inner = core.Net.create()
        inner.add_op(op_creations.fc(X="X", W="w", Y="fc.out"))
        inner.complete_add_op(True)

        outer.add_op(inner)
        outer.complete_add_op(True)

        # One entry per dumped line; nesting depth is 4 spaces per level.
        # Per the expected text, the fc op shows up as mul + sigmoid ops joined
        # through the temporary variable @TEMP@fc@0.
        expected_lines = [
            "naive_net:",
            "    Op(add_two), inputs:(X, Y), outputs:(Out).",
            "    naive_net:",
            "        fc:",
            "            Op(mul), inputs:(X, w), outputs:(@TEMP@fc@0).",
            "            Op(sigmoid), inputs:(@TEMP@fc@0), outputs:(fc.out).",
        ]
        expected = "\n".join(expected_lines) + "\n"
        self.assertEqual(expected, str(outer))


if __name__ == '__main__':
    # Only when executed as a script (not on import): hand control to the
    # unittest command-line runner, which loads and runs the tests defined in
    # this module.
    unittest.main()