提交 e6232d82 编写于 作者: D dangqingqing

testing in mnist

上级 733da9b9
......@@ -20,7 +20,8 @@ import event
import py_paddle.swig_paddle as api
__all__ = [
'optimizer', 'layer', 'activation', 'parameters', 'init', 'trainer', 'event'
'optimizer', 'layer', 'activation', 'parameters', 'init', 'trainer',
'event', 'data_converter'
]
......
......@@ -13,8 +13,8 @@
# limitations under the License.
import collections
import py_paddle.swig_paddle
import numpy
import py_paddle.swig_paddle as api
import numpy as np
import paddle.trainer.PyDataProvider2 as dp2
__all__ = ['DataConverter']
......@@ -50,12 +50,12 @@ class DenseConvert(IDataConverter):
:param data: input data
:type data: list | numpy array
:param argument: the type which paddle is acceptable
:type argument: swig_paddle.Arguments
:type argument: Paddle's Arguments
"""
assert isinstance(argument, swig_paddle.Arguments)
if data.dtype != numpy.float32:
data = data.astype(numpy.float32)
m = swig_paddle.Matrix.createDenseFromNumpy(data, True, False)
assert isinstance(argument, api.Arguments)
if data.dtype != np.float32:
data = data.astype(np.float32)
m = api.Matrix.createDenseFromNumpy(data, True, False)
argument.setSlotValue(self.pos, m)
......@@ -72,17 +72,16 @@ class SparseBinaryConvert(IDataConverter):
self.__height__ = len(data)
for x in data:
self.__rows__.append(self.__rows__[-1] + len(x))
self__cols__ = data.flatten()
self.__cols__ = data.flatten()
def convert(self, data, argument):
assert isinstance(argument, swig_paddle.Arguments)
assert isinstance(argument, api.Arguments)
fill_csr(data)
m = swig_paddle.Matrix.createSparse(self.__height__,
self.input_type.dim,
len(self.__cols__),
len(self.__value__) == 0)
assert isinstance(m, swig_paddle.Matrix)
m = api.Matrix.createSparse(self.__height__, self.input_type.dim,
len(self.__cols__),
len(self.__value__) == 0)
assert isinstance(m, api.Matrix)
m.sparseCopyFrom(self.__rows__, self.__cols__, self.__value__)
argument.setSlotValue(self.pos, m)
......@@ -105,9 +104,9 @@ class IndexConvert(IDataConverter):
self.__ids__ = []
def convert(self, data, argument):
assert isinstance(argument, swig_paddle.Arguments)
assert isinstance(argument, api.Arguments)
self.__ids__ = data.flatten()
ids = swig_paddle.IVector.create(self.__ids__)
ids = api.IVector.create(self.__ids__)
argument.setSlotIds(self.pos, ids)
......@@ -135,7 +134,7 @@ class SequenceConvert(IDataConverter):
def convert(self, data, argument):
fill_seq(data)
seq = swig_paddle.IVector.create(self.__seq__, False)
seq = api.IVector.create(self.__seq__, False)
self.__setter__(argument, self.pos, seq)
dat = []
......@@ -151,22 +150,21 @@ class SequenceConvert(IDataConverter):
class DataConverter(object):
def __init__(self, input_mapper):
def __init__(self, input):
"""
Usege:
.. code-block:: python
inputs = [('image', dense_vector), ('label', integer_value)]
cvt = DataConverter(inputs)
arg = cvt.convert(minibatch_data, {'image':0, 'label':1})
arg = cvt(minibatch_data, {'image':0, 'label':1})
:param input_mapper: list of (input_name, input_type)
:type input_mapper: list
"""
assert isinstance(self.input_types, collections.Sequence)
self.input_names = []
self.input_types = []
for each in self.input_types:
for each in input:
self.input_names.append(each[0])
self.input_types.append(each[1])
assert isinstance(each[1], dp2.InputType)
......@@ -186,16 +184,16 @@ class DataConverter(object):
the feature order in argument and data is the same.
:type input_dict: dict, like {string:integer, string, integer, ...}|None
:param argument: converted data will be saved in this argument. If None,
it will create a swig_paddle.Arguments firstly.
it will create a Paddle's Arguments firstly.
:param type: swig_paddle.Arguments|None
"""
if argument is None:
argument = swig_paddle.Arguments.createArguments(0)
assert isinstance(argument, swig_paddle.Arguments)
argument = api.Arguments.createArguments(0)
assert isinstance(argument, api.Arguments)
argument.resize(len(self.input_types))
converts = [
DataConverter.create_scanner(i, each_type)
DataConverter.create_converter(i, each_type)
for i, each_type in enumerate(self.input_types)
]
......@@ -212,7 +210,7 @@ class DataConverter(object):
return self.convert(dat, argument)
@staticmethod
def create_scanner(pos, each):
def create_converter(pos, each):
assert isinstance(each, dp2.InputType)
retv = None
if each.type == dp2.DataType.Dense:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册