From f185af8d7bbe45cf8029a5c3727e27a5d001c984 Mon Sep 17 00:00:00 2001 From: Yu Yang Date: Tue, 10 Oct 2017 13:55:01 -0700 Subject: [PATCH] Complete parameter --- python/paddle/v2/framework/graph.py | 50 ++++++++++++++++++- .../v2/framework/tests/test_parameter.py | 27 ++++++++++ 2 files changed, 75 insertions(+), 2 deletions(-) create mode 100644 python/paddle/v2/framework/tests/test_parameter.py diff --git a/python/paddle/v2/framework/graph.py b/python/paddle/v2/framework/graph.py index ba148854627..0f0a2847e58 100644 --- a/python/paddle/v2/framework/graph.py +++ b/python/paddle/v2/framework/graph.py @@ -1,13 +1,19 @@ import paddle.v2.framework.core as core import collections import numpy as np +import copy __all__ = ['Block', 'Variable', 'Program', 'Operator'] class Variable(object): - def __init__(self, block, name=None, shape=None, dtype=None, - lod_level=None): + def __init__(self, + block, + name=None, + shape=None, + dtype=None, + lod_level=None, + **kwargs): self.block = block if name is None: @@ -144,6 +150,10 @@ class Block(object): def create_var(self, *args, **kwargs): return Variable(self, *args, **kwargs) + def create_parameter(self, *args, **kwargs): + global_block = self.program.global_block() + return Parameter(global_block, *args, **kwargs) + def append_op(self, *args, **kwargs): op_desc = self.desc.append_op() op = Operator(self, op_desc, *args, **kwargs) @@ -190,5 +200,41 @@ class Program(object): self.current_block_idx = self.current_block().parent_idx +class Parameter(Variable): + def __init__(self, block, shape, dtype, **kwargs): + if shape is None or dtype is None: + raise ValueError("Parameter must set shape and dtype") + if len(shape) == 0: + raise ValueError("Parameter shape cannot be empty") + + for each in shape: + if each < 0: + raise ValueError("Parameter shape should not be related with " + "batch-size") + + Variable.__init__(self, block, shape=shape, dtype=dtype, **kwargs) + self.trainable = kwargs.get('trainable', True) + self.init_attr = kwargs.get('initialize_attr', { + 'type': 'uniform_random', + 'min': -1.0, + 'max': 1.0 + }) + + self.optimize_attr = kwargs.get('optimize_attr', {'learning_rate': 1.0}) + self._append_initialize_ops_() + + def _append_initialize_ops_(self): + attr = copy.deepcopy(self.init_attr) + op_type = attr.pop('type', None) + block = self.block + assert isinstance(block, Block) + shape = self.shape + attr['dims'] = shape + attr['data_type'] = int(self.data_type) + op = block.prepend_op( + type=op_type, inputs=None, outputs={'Out': [self]}, attrs=attr) + self.op = op + + # program is a global instance. g_program = Program.instance() diff --git a/python/paddle/v2/framework/tests/test_parameter.py b/python/paddle/v2/framework/tests/test_parameter.py new file mode 100644 index 00000000000..3b5d38f257e --- /dev/null +++ b/python/paddle/v2/framework/tests/test_parameter.py @@ -0,0 +1,27 @@ +import unittest +from paddle.v2.framework.graph import g_program +import paddle.v2.framework.core as core + + +class TestParameter(unittest.TestCase): + def test_param(self): + b = g_program.create_block() + param = b.create_parameter( + name='fc.w', + shape=[784, 100], + dtype='float32', + initialize_attr={ + 'type': 'uniform_random', + 'seed': 13, + 'min': -5.0, + 'max': 5.0 + }) + self.assertIsNotNone(param) + self.assertEqual('fc.w', param.name) + self.assertEqual((784, 100), param.shape) + self.assertEqual(core.DataType.FP32, param.data_type) + self.assertEqual(0, param.block.idx) + + +if __name__ == '__main__': + unittest.main() -- GitLab