提交 fa72e544 编写于 作者: Y Yu Yang 提交者: Yang Yang(Tony)

Python API for StaticRNN (#4991)

上级 23bf6b2c
...@@ -113,6 +113,10 @@ class Variable(object): ...@@ -113,6 +113,10 @@ class Variable(object):
def lod_level(self): def lod_level(self):
return self.desc.lod_level() return self.desc.lod_level()
@property
def type(self):
return self.desc.type()
@staticmethod @staticmethod
def _unique_var_name_(): def _unique_var_name_():
uid = core.unique_integer() # unique during whole process. uid = core.unique_integer() # unique during whole process.
......
from paddle.v2.framework.framework import Variable, OpProtoHolder, g_program, g_init_program
import paddle.v2.framework.core as core
import copy import copy
import itertools import itertools
import paddle.v2.framework.core as core
from paddle.v2.framework.framework import Variable, g_program, \
g_init_program
def unique_name(prefix): def unique_name(prefix):
uid = core.unique_integer() # unique during whole process. uid = core.unique_integer() # unique during whole process.
...@@ -130,6 +133,9 @@ class LayerHelper(object): ...@@ -130,6 +133,9 @@ class LayerHelper(object):
dtype=dtype, dtype=dtype,
persistable=False) persistable=False)
def create_variable(self, *args, **kwargs):
return self.program.current_block().create_var(*args, **kwargs)
def create_global_variable(self, *args, **kwargs): def create_global_variable(self, *args, **kwargs):
return self.program.global_block().create_var( return self.program.global_block().create_var(
*args, persistable=False, **kwargs) *args, persistable=False, **kwargs)
......
from paddle.v2.framework.layer_helper import LayerHelper from paddle.v2.framework.layer_helper import LayerHelper, unique_name
import paddle.v2.framework.core as core import paddle.v2.framework.core as core
from paddle.v2.framework.framework import OpProtoHolder, Variable from paddle.v2.framework.framework import OpProtoHolder, Variable, Program
import re import re
__all__ = [ __all__ = [
'fc', 'data', 'cross_entropy', 'conv2d', 'pool2d', 'embedding', 'concat' 'fc', 'data', 'cross_entropy', 'conv2d', 'pool2d', 'embedding', 'concat',
'StaticRNN'
] ]
...@@ -26,7 +27,9 @@ def fc(input, ...@@ -26,7 +27,9 @@ def fc(input,
mul_results = [] mul_results = []
for input_var, param_attr in helper.iter_inputs_and_params(): for input_var, param_attr in helper.iter_inputs_and_params():
input_shape = input_var.shape input_shape = input_var.shape
param_shape = list(input_shape[num_flatten_dims:]) + [size] param_shape = [
reduce(lambda a, b: a * b, input_shape[num_flatten_dims:], 1)
] + [size]
w = helper.create_parameter( w = helper.create_parameter(
attr=param_attr, shape=param_shape, dtype=dtype) attr=param_attr, shape=param_shape, dtype=dtype)
...@@ -38,10 +41,8 @@ def fc(input, ...@@ -38,10 +41,8 @@ def fc(input,
"Y": w, "Y": w,
}, },
outputs={"Out": tmp}, outputs={"Out": tmp},
attrs={ attrs={'x_num_col_dims': num_flatten_dims,
'x_num_col_dims': num_flatten_dims, 'y_num_col_dims': 1})
'y_num_col_dims': len(input_shape) - num_flatten_dims
})
mul_results.append(tmp) mul_results.append(tmp)
# sum # sum
...@@ -273,3 +274,170 @@ def pool2d(input, ...@@ -273,3 +274,170 @@ def pool2d(input,
}) })
return pool_out return pool_out
class BlockGuard(object):
"""
BlockGuard used to create sub-block in program by using Python `with`
keyword.
"""
def __init__(self, program):
if not isinstance(program, Program):
raise TypeError("BlockGuard takes a program")
self.program = program
def __enter__(self):
self.program.create_block()
def __exit__(self, exc_type, exc_val, exc_tb):
self.program.rollback()
if exc_type is not None:
return False # re-raise exception
return True
class StaticRNNGuard(BlockGuard):
def __init__(self, rnn):
if not isinstance(rnn, StaticRNN):
raise TypeError("StaticRNNGuard takes an StaticRNN")
super(StaticRNNGuard, self).__init__(rnn.helper.program)
self.rnn = rnn
def __enter__(self):
self.rnn.status = StaticRNN.IN_RNN_BLOCK
return super(StaticRNNGuard, self).__enter__()
def __exit__(self, exc_type, exc_val, exc_tb):
self.rnn.status = StaticRNN.AFTER_RNN_BLOCK
self.rnn.complete_rnn_op()
return super(StaticRNNGuard, self).__exit__(exc_type, exc_val, exc_tb)
class StaticRNNMemoryLink(object):
"""
:param init: the initial variable for Memory
:type init: Variable
:param pre_mem: the memory variable in previous time step
:type pre_mem: Variable
:param mem: the memory variable in current time step
:type mem: Variable
"""
def __init__(self, init, pre_mem, mem=None):
self.init = init
self.pre_mem = pre_mem
self.mem = mem
class StaticRNN(object):
BEFORE_RNN_BLOCK = 0
IN_RNN_BLOCK = 1
AFTER_RNN_BLOCK = 2
def __init__(self, name=None, program=None):
self.helper = LayerHelper("static_rnn", name=name, program=program)
self.memories = {} # memory map, from pre_mem.name --> MemoryLink
self.inputs = [] # input variable list in current block
self.outputs = [] # output variable list in parent block
self.status = StaticRNN.BEFORE_RNN_BLOCK # status flag.
# sequence length, since it is a static RNN, sequence length are fixed.
self.seq_len = None
def step(self):
return StaticRNNGuard(self)
def _assert_in_rnn_block_(self, method):
if self.status != StaticRNN.IN_RNN_BLOCK:
raise ValueError("You must invoke {0} in rnn block".format(method))
def memory(self, init=None, shape=None, dtype=None, init_value=0):
self._assert_in_rnn_block_('memory')
if init is None:
if shape is None or dtype is None:
raise ValueError(
"if init is None, memory at least need shape and dtype")
parent_block = self.parent_block()
var_name = unique_name("@".join([self.helper.name, "memory_boot"]))
boot_var = parent_block.create_var(
name=var_name, shape=shape, dtype=dtype, persistable=False)
parent_block.append_op(
type="fill_constant",
inputs={},
outputs={'Out': [boot_var]},
attrs={
'value': init_value,
'shape': boot_var.shape,
'data_type': boot_var.data_type
})
return self.memory(init=boot_var)
else:
pre_mem = self.helper.create_variable(
name=unique_name("@".join([self.helper.name, "mem"])),
dtype=init.data_type,
shape=init.shape)
self.memories[pre_mem.name] = StaticRNNMemoryLink(
init=init, pre_mem=pre_mem)
return pre_mem
def step_input(self, x):
self._assert_in_rnn_block_('step_input')
if not isinstance(x, Variable):
raise TypeError("step input takes a Variable")
if self.seq_len is None:
self.seq_len = x.shape[1]
elif self.seq_len != x.shape[1]:
raise ValueError("Static RNN only take fix seq_len input")
ipt = self.helper.create_variable(
name=x.name,
dtype=x.data_type,
shape=[-1] + list(x.shape[2:]),
type=x.type)
self.inputs.append(ipt)
return ipt
def step_output(self, o):
self._assert_in_rnn_block_('step_output')
if not isinstance(o, Variable):
raise TypeError("step output takes a Variable")
out_var = self.parent_block().create_var(
name=o.name,
shape=[-1, self.seq_len] + list(o.shape[1:]),
dtype=o.data_type)
self.outputs.append(out_var)
def output(self, *outputs):
for each in outputs:
self.step_output(each)
def update_memory(self, mem, var):
if not isinstance(mem, Variable) or not isinstance(var, Variable):
raise TypeError("update memory should take variables")
self.memories[mem.name].mem = var
def parent_block(self):
prog = self.helper.program
parent_idx = prog.current_block().parent_idx
assert parent_idx >= 0
parent_block = prog.block(parent_idx)
return parent_block
def __call__(self, *args, **kwargs):
if self.status != StaticRNN.AFTER_RNN_BLOCK:
raise ValueError("RNN output can only be retrieved after rnn block")
if len(self.outputs) == 0:
raise ValueError("RNN has no output")
elif len(self.outputs) == 1:
return self.outputs[0]
else:
return self.outputs
def complete_rnn_op(self):
# TODO(yuyang18): Create RNN Op here.
# Implement this method after RNN op complete.
pass
import unittest
from paddle.v2.framework.layers import *
from paddle.v2.framework.framework import g_program
class TestRNN(unittest.TestCase):
def test_rnn(self):
img = data(
shape=[
80, # sequence length
22, # image height
22
], # image width
data_type='float32',
name='image')
hidden = fc(input=img, size=100, act='sigmoid', num_flatten_dims=2)
self.assertEqual((-1, 80, 100), hidden.shape)
hidden = fc(input=hidden, size=100, act='sigmoid', num_flatten_dims=2)
self.assertEqual((-1, 80, 100), hidden.shape)
rnn = StaticRNN()
with rnn.step():
hidden = rnn.step_input(hidden)
self.assertEqual((-1, 100), hidden.shape)
memory = rnn.memory(shape=(-1, 32), dtype='float32', init_value=0.0)
rnn_out = fc(input=[hidden, memory], size=32, act='sigmoid')
self.assertEqual((-1, 32), rnn_out.shape)
rnn.update_memory(memory, rnn_out)
rnn.output(rnn_out)
out = rnn()
self.assertEqual((-1, 80, 32), out.shape)
print g_program
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册