提交 f185af8d 编写于 作者: Y Yu Yang

Complete parameter

上级 0c37a061
import paddle.v2.framework.core as core
import collections
import numpy as np
import copy
__all__ = ['Block', 'Variable', 'Program', 'Operator']
class Variable(object):
def __init__(self, block, name=None, shape=None, dtype=None,
lod_level=None):
def __init__(self,
block,
name=None,
shape=None,
dtype=None,
lod_level=None,
**kwargs):
self.block = block
if name is None:
......@@ -144,6 +150,10 @@ class Block(object):
def create_var(self, *args, **kwargs):
return Variable(self, *args, **kwargs)
def create_parameter(self, *args, **kwargs):
global_block = self.program.global_block()
return Parameter(global_block, *args, **kwargs)
def append_op(self, *args, **kwargs):
op_desc = self.desc.append_op()
op = Operator(self, op_desc, *args, **kwargs)
......@@ -190,5 +200,41 @@ class Program(object):
self.current_block_idx = self.current_block().parent_idx
class Parameter(Variable):
def __init__(self, block, shape, dtype, **kwargs):
if shape is None or dtype is None:
raise ValueError("Parameter must set shape and dtype")
if len(shape) == 0:
raise ValueError("Parameter shape cannot be empty")
for each in shape:
if each < 0:
raise ValueError("Parameter shape should not be related with "
"batch-size")
Variable.__init__(self, block, shape=shape, dtype=dtype, **kwargs)
self.trainable = kwargs.get('trainable', True)
self.init_attr = kwargs.get('initialize_attr', {
'type': 'uniform_random',
'min': -1.0,
'max': 1.0
})
self.optimize_attr = kwargs.get('optimize_attr', {'learning_rate': 1.0})
self._append_initialize_ops_()
def _append_initialize_ops_(self):
attr = copy.deepcopy(self.init_attr)
op_type = attr.pop('type', None)
block = self.block
assert isinstance(block, Block)
shape = self.shape
attr['dims'] = shape
attr['data_type'] = int(self.data_type)
op = block.prepend_op(
type=op_type, inputs=None, outputs={'Out': [self]}, attrs=attr)
self.op = op
# program is a global instance.
g_program = Program.instance()
import unittest
from paddle.v2.framework.graph import g_program
import paddle.v2.framework.core as core
class TestParameter(unittest.TestCase):
def test_param(self):
b = g_program.create_block()
param = b.create_parameter(
name='fc.w',
shape=[784, 100],
dtype='float32',
initialize_attr={
'type': 'uniform_random',
'seed': 13,
'min': -5.0,
'max': 5.0
})
self.assertIsNotNone(param)
self.assertEqual('fc.w', param.name)
self.assertEqual((784, 100), param.shape)
self.assertEqual(core.DataType.FP32, param.data_type)
self.assertEqual(0, param.block.idx)
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册