Commit 576e7f47 authored by: D dangqingqing

Support variable-dimension for convolution operation.

Parent dc530a71
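
In brief, the change lets a data reader yield multi-dimensional numpy arrays for a dense slot; the feeder then reads the frame height and width from the array shape, so the image size may change from one mini-batch to the next. A minimal usage sketch modeled on the new test at the end of this diff follows; the import paths and the initPaddle call are assumptions based on that test and may differ in your checkout:

import numpy as np
# Assumed import paths, mirroring the v2-style test added in this commit.
import py_paddle.swig_paddle as api
from paddle.v2 import data_type
from paddle.v2.data_feeder import DataFeeder

api.initPaddle("--use_gpu=0")  # the C++ core must be initialized first

# Declare the slot with dense_array; the declared dim (2352 = 3*28*28) is
# not enforced for multi-dimensional numpy input.
feeder = DataFeeder([('image', data_type.dense_array(2352))], {'image': 0})

def gen_batch(batch_size, shape):
    # One mini-batch; every sample shares the same CHW shape within it.
    return [[np.random.random(shape)] for _ in xrange(batch_size)]

arg = feeder(gen_batch(32, (3, 28, 28)))
print arg.getSlotFrameHeight(0), arg.getSlotFrameWidth(0)   # 28 28

arg = feeder(gen_batch(32, (3, 30, 32)))                    # another size
print arg.getSlotFrameHeight(0), arg.getSlotFrameWidth(0)   # 30 32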
......@@ -103,7 +103,7 @@ def stacked_lstm_net(input_dim,
if __name__ == '__main__':
# init
paddle.init(use_gpu=False)
paddle.init(use_gpu=False, log_clipping=True)
#data
print 'load dictionary...'
......@@ -131,6 +131,7 @@ if __name__ == '__main__':
# create optimizer
adam_optimizer = paddle.optimizer.Adam(
learning_rate=2e-3,
gradient_clipping_threshold=0.003,
regularization=paddle.optimizer.L2Regularization(rate=8e-4),
model_average=paddle.optimizer.ModelAverage(average_window=0.5))
......
......@@ -17,6 +17,7 @@ import collections
import swig_paddle
import numpy
import itertools
from functools import reduce
__all__ = ['DataProviderConverter']
......@@ -59,12 +60,14 @@ class IScanner(object):
"""
pass
def finish_pre_scan(self, argument):
def finish_pre_scan(self, argument, dat=None):
"""
Finish first scan pass. Allocate the memory.
:param argument: Output arguments object.
:type argument: swig_paddle.Arguments
:param dat: A sample of the input data, used to compute the allocation size.
:type dat: Python object, numpy.array or list.
:return:
"""
pass
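
To make the two-pass protocol concrete, here is a self-contained toy dense scanner (plain numpy, no swig_paddle; the class name is invented for this sketch). The flow matches the real DenseScanner below: pre_scan only counts samples, finish_pre_scan sizes the buffer from one representative sample, and scan copies each flattened sample into the buffer.

import numpy as np

class ToyDenseScanner(object):
    """Simplified stand-in for DenseScanner, showing the two-pass flow."""
    def __init__(self):
        self.__height__ = 0
        self.__shape__ = None
        self.__mat__ = None

    def pre_scan(self, dat):
        # First pass: only count the samples.
        self.__height__ += 1

    def finish_pre_scan(self, dat):
        # Allocate once the batch size and per-sample shape are known.
        self.__shape__ = np.array(dat).shape
        dim = int(np.prod(self.__shape__))
        self.__mat__ = np.ndarray(shape=(self.__height__, dim), dtype=np.float32)
        self.__height__ = 0

    def scan(self, dat):
        # Second pass: copy each flattened sample into the buffer.
        self.__mat__[self.__height__] = np.asarray(dat).flatten()
        self.__height__ += 1

batch = [np.random.random((3, 4, 5)) for _ in xrange(8)]
scanner = ToyDenseScanner()
for sample in batch:
    scanner.pre_scan(sample)
scanner.finish_pre_scan(batch[0])
for sample in batch:
    scanner.scan(sample)
assert scanner.__mat__.shape == (8, 60)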
......@@ -95,17 +98,27 @@ class DenseScanner(IScanner):
def __init__(self, input_type, pos):
IScanner.__init__(self, input_type, pos)
self.__mat__ = None
self.__shape__ = None
self.__height__ = 0
def pre_scan(self, dat):
self.__height__ += 1
def finish_pre_scan(self, argument):
def finish_pre_scan(self, argument, dat=None):
self.__shape__ = numpy.array(dat).shape
if len(self.__shape__) > 3:
raise ValueError("The dimension of input is greater than 3.")
dim = reduce(lambda x, y: x * y, self.__shape__)
if len(self.__shape__) == 1:
assert dim == self.input_type.dim
self.__mat__ = numpy.ndarray(
shape=(self.__height__, self.input_type.dim), dtype=numpy.float32)
shape=(self.__height__, dim), dtype=numpy.float32)
self.__height__ = 0
def scan(self, dat):
if isinstance(dat, numpy.ndarray):
assert self.__shape__ == dat.shape
dat = dat.flatten()
self.__mat__[self.__height__] = dat
self.__height__ += 1
......@@ -116,6 +129,13 @@ class DenseScanner(IScanner):
m = swig_paddle.Matrix.createDenseFromNumpy(self.__mat__, True,
self.data_in_gpu)
argument.setSlotValue(self.pos, m)
if len(self.__shape__) > 1:
# The last two dimensions are the frame height and width.
# For example, a 3-D image feature uses the CHW layout,
# where H and W are the frame height and width.
h, w = self.__shape__[-2:]
argument.setSlotFrameHeight(self.pos, h)
argument.setSlotFrameWidth(self.pos, w)
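
The frame bookkeeping above needs a built swig_paddle to exercise end to end, so the following only illustrates the shape arithmetic behind the CHW convention:

import numpy as np

feat = np.random.random((3, 30, 32))   # CHW: channels, frame height, frame width
h, w = feat.shape[-2:]
assert (h, w) == (30, 32)
# These are the values DenseScanner forwards through
# argument.setSlotFrameHeight(self.pos, h) and
# argument.setSlotFrameWidth(self.pos, w).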
class SparseBinaryScanner(IScanner):
......@@ -166,7 +186,7 @@ class IndexScanner(IScanner):
def pre_scan(self, dat):
self.__idx__ += 1
def finish_pre_scan(self, argument):
def finish_pre_scan(self, argument, dat=None):
self.__ids__ = [0] * self.__idx__
self.__idx__ = 0
......@@ -191,8 +211,8 @@ class SequenceScanner(IScanner):
for each in dat:
self.__inner_scanner__.pre_scan(each)
def finish_pre_scan(self, argument):
self.__inner_scanner__.finish_pre_scan(argument)
def finish_pre_scan(self, argument, dat=None):
self.__inner_scanner__.finish_pre_scan(argument, dat)
def scan(self, dat):
self.__seq__.append(self.__seq__[-1] + self.get_size(dat))
......@@ -233,8 +253,11 @@ class DataProviderConverter(object):
for each_step, scanner in itertools.izip(each_sample, scanners):
scanner.pre_scan(each_step)
for scanner in scanners:
scanner.finish_pre_scan(argument)
# Some scanners, such as the dense scanner, pre-allocate memory for the
# mini-batch in finish_pre_scan. dat[0] is used to calculate the size
# of the input data.
for scanner, each_feature in itertools.izip(scanners, dat[0]):
scanner.finish_pre_scan(argument, each_feature)
for each_sample in dat:
for each_step, scanner in itertools.izip(each_sample, scanners):
......
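
For reference, the mini-batch the converter walks over is a list of samples, each sample a list of per-slot features, so dat[0] supplies exactly one representative feature per scanner before the second pass allocates and fills the buffers. A small illustration (no PaddlePaddle calls needed):

import numpy as np

# Two samples, each carrying a dense image feature and an integer label.
dat = [[np.random.random((3, 28, 28)), 7],
       [np.random.random((3, 28, 28)), 2]]
image0, label0 = dat[0]
print np.array(image0).shape   # (3, 28, 28): enough to size the dense buffer
print label0                   # index features need no shape information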
......@@ -72,9 +72,16 @@ class InputType(object):
def dense_slot(dim, seq_type=SequenceType.NO_SEQUENCE):
"""
Dense Vector. It means the input feature is a dense float vector. For example,
if the input is an image with 28*28 pixels, the input of the Paddle neural
network should be a dense vector with dimension 784.
Dense Array. It means the input feature is a dense array of float values.
For example, if the input is an image with 28*28 pixels, the input of the
Paddle neural network could be a dense vector with dimension 784 or a
numpy array with shape (28, 28).

For the 2-D convolution operation, each sample in a mini-batch currently
must have the same size in PaddlePaddle, but the feature dimensions may
vary across mini-batches. In the variable-dimension case the param dim is
not used; the data reader must yield numpy arrays, and the data feeder
will set the data shape accordingly.
:param dim: dimension of this vector.
:type dim: int
......@@ -135,6 +142,10 @@ sparse_binary_vector = sparse_non_value_slot
sparse_vector = sparse_value_slot
integer_value = index_slot
# dense_array can be used for variable-dimension input features.
# Each feature is not a vector, but a multi-dimensional array.
dense_array = dense_slot
def dense_vector_sequence(dim):
"""
......
......@@ -16,7 +16,8 @@ import paddle.trainer.PyDataProvider2 as pydp2
import_list = [
nm for nm in dir(pydp2)
if '_' in nm and nm[0] != '_' and ('value' in nm or 'vector' in nm)
if '_' in nm and nm[0] != '_' and ('value' in nm or 'vector' in nm or
'array' in nm)
]
import_list.extend(['InputType'])
......
......@@ -233,6 +233,30 @@ class DataFeederTest(unittest.TestCase):
self.assertEqual(out_sparse.getSparseRowCols(i), data[i][1])
self.assertEqual(out_index[i], data[i][0])
def test_dense_set_shape(self):
# test 2-D data
def gen_data(batch_size, shape):
data = []
for i in xrange(batch_size):
each_sample = []
each_sample.append(np.random.random(shape))
data.append(each_sample)
return data
feeder = DataFeeder([('image', data_type.dense_array(2352))],
{'image': 0})
arg = feeder(gen_data(32, (3, 28, 28)))
h = arg.getSlotFrameHeight(0)
w = arg.getSlotFrameWidth(0)
self.assertEqual(h, 28)
self.assertEqual(w, 28)
arg = feeder(gen_data(32, (3, 30, 32)))
h = arg.getSlotFrameHeight(0)
w = arg.getSlotFrameWidth(0)
self.assertEqual(h, 30)
self.assertEqual(w, 32)
if __name__ == '__main__':
api.initPaddle("--use_gpu=0")
......