# test_parameter.py
# Committed by Yu Yang.
import unittest
from paddle.v2.framework.framework import g_main_program
import paddle.v2.framework.core as core


class TestParameter(unittest.TestCase):
    """Tests for creating a parameter through ``Block.create_parameter``."""

    def test_param(self):
        """Create an ``fc.w`` parameter and check its recorded attributes.

        The parameter is requested with a uniform-random initializer
        (seed 13, range [-5.0, 5.0]). The test then checks that the
        returned object exposes the requested name, shape and data type.
        """
        block = g_main_program.create_block()
        param = block.create_parameter(
            name='fc.w',
            shape=[784, 100],
            dtype='float32',
            initialize_attr={
                'type': 'uniform_random',
                'seed': 13,
                'min': -5.0,
                'max': 5.0
            })
        self.assertIsNotNone(param)
        self.assertEqual('fc.w', param.name)
        # The shape is passed in as a list but is expected back as a tuple.
        self.assertEqual((784, 100), param.shape)
        # The 'float32' dtype string is expected to map to core's FP32 enum.
        self.assertEqual(core.DataType.FP32, param.data_type)
        # NOTE(review): the parameter is requested on a freshly created block,
        # yet its owning block is expected to have idx 0. Presumably
        # create_parameter places parameters in the program's global block.
        # Confirm against the framework implementation.
        self.assertEqual(0, param.block.idx)


if __name__ == "__main__":
    # Run every test case defined in this module when executed as a script.
    unittest.main()