graph_test.py 3.4 KB
Newer Older
D
daminglu 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74
import json
import os
import unittest

import graph


class GraphTest(unittest.TestCase):
    """Regression tests for the edge list produced by ``graph.load_model``.

    Each test loads a serialized model from the mock directory, decodes the
    JSON string returned by ``graph.load_model`` and checks the total number
    of edges plus a handful of edges sampled at fixed indices.
    """

    def setUp(self):
        # Resolve the fixture directory next to this file instead of relative
        # to the current working directory, so the suite also passes when the
        # runner is launched from a different directory.
        self.mock_dir = os.path.join(
            os.path.dirname(os.path.abspath(__file__)), 'mock')

    def _load_edges(self, model_filename):
        """Load *model_filename* from the mock directory and return its edges.

        ``graph.load_model`` returns a JSON string; only the ``'edges'`` list
        of the decoded object is inspected by these tests.
        """
        json_str = graph.load_model(os.path.join(self.mock_dir, model_filename))
        return json.loads(json_str)['edges']

    def _assert_edge(self, edges, index, source, target):
        """Assert that ``edges[index]`` goes from *source* to *target*.

        The expected label is derived from the position in the list
        (``'label_<index>'``), which holds for every edge sampled below.
        """
        edge = edges[index]
        self.assertEqual(edge['source'], source)
        self.assertEqual(edge['target'], target)
        self.assertEqual(edge['label'], 'label_%d' % index)

    def test_graph_edges_squeezenet(self):
        edges = self._load_edges('squeezenet_model.pb')

        # 126 edges + 66 nodes (out-edge of each node is counted twice)
        self.assertEqual(len(edges), 126 + 66)

        # In-edges run from a data variable into an op node ('node_<k>');
        # out-edges run from an op node to its output variable.
        # label_0: (in-edge)
        self._assert_edge(edges, 0, 'data_0', 'node_0')
        # label_50: (in-edge)
        self._assert_edge(edges, 50, 'fire3/concat_1', 'node_17')
        # label_100: (in-edge)
        self._assert_edge(edges, 100, 'fire6/squeeze1x1_1', 'node_34')
        # label_111: (out-edge)
        self._assert_edge(edges, 111, 'node_37', 'fire6/expand3x3_1')

    def test_graph_edges_inception_v1(self):
        edges = self._load_edges('inception_v1_model.pb')

        # 286 edges + 143 nodes (out-edge of each node is counted twice)
        self.assertEqual(len(edges), 286 + 143)

        # label_0: (in-edge)
        self._assert_edge(edges, 0, 'data_0', 'node_0')
        # label_50: (in-edge)
        self._assert_edge(edges, 50, 'inception_3a/5x5_reduce_2', 'node_18')
        # label_100: (out-edge)
        self._assert_edge(edges, 100, 'node_34', 'inception_3b/pool_1')
        # label_420: (out-edge)
        self._assert_edge(edges, 420, 'node_139', 'pool5/7x7_s1_2')


if __name__ == '__main__':
    # Executed as a script: hand control to unittest's command-line runner,
    # which collects the TestCase classes defined in this module and runs them.
    unittest.main(module='__main__')