提交 e6232d82 编写于 作者: D dangqingqing

testing in mnist

上级 733da9b9
...@@ -20,7 +20,8 @@ import event ...@@ -20,7 +20,8 @@ import event
import py_paddle.swig_paddle as api import py_paddle.swig_paddle as api
__all__ = [ __all__ = [
'optimizer', 'layer', 'activation', 'parameters', 'init', 'trainer', 'event' 'optimizer', 'layer', 'activation', 'parameters', 'init', 'trainer',
'event', 'data_converter'
] ]
......
...@@ -13,8 +13,8 @@ ...@@ -13,8 +13,8 @@
# limitations under the License. # limitations under the License.
import collections import collections
import py_paddle.swig_paddle import py_paddle.swig_paddle as api
import numpy import numpy as np
import paddle.trainer.PyDataProvider2 as dp2 import paddle.trainer.PyDataProvider2 as dp2
__all__ = ['DataConverter'] __all__ = ['DataConverter']
...@@ -50,12 +50,12 @@ class DenseConvert(IDataConverter): ...@@ -50,12 +50,12 @@ class DenseConvert(IDataConverter):
:param data: input data :param data: input data
:type data: list | numpy array :type data: list | numpy array
:param argument: the type which paddle is acceptable :param argument: the type which paddle is acceptable
:type argument: swig_paddle.Arguments :type argument: Paddle's Arguments
""" """
assert isinstance(argument, swig_paddle.Arguments) assert isinstance(argument, api.Arguments)
if data.dtype != numpy.float32: if data.dtype != np.float32:
data = data.astype(numpy.float32) data = data.astype(np.float32)
m = swig_paddle.Matrix.createDenseFromNumpy(data, True, False) m = api.Matrix.createDenseFromNumpy(data, True, False)
argument.setSlotValue(self.pos, m) argument.setSlotValue(self.pos, m)
...@@ -72,17 +72,16 @@ class SparseBinaryConvert(IDataConverter): ...@@ -72,17 +72,16 @@ class SparseBinaryConvert(IDataConverter):
self.__height__ = len(data) self.__height__ = len(data)
for x in data: for x in data:
self.__rows__.append(self.__rows__[-1] + len(x)) self.__rows__.append(self.__rows__[-1] + len(x))
self__cols__ = data.flatten() self.__cols__ = data.flatten()
def convert(self, data, argument): def convert(self, data, argument):
assert isinstance(argument, swig_paddle.Arguments) assert isinstance(argument, api.Arguments)
fill_csr(data) fill_csr(data)
m = swig_paddle.Matrix.createSparse(self.__height__, m = api.Matrix.createSparse(self.__height__, self.input_type.dim,
self.input_type.dim,
len(self.__cols__), len(self.__cols__),
len(self.__value__) == 0) len(self.__value__) == 0)
assert isinstance(m, swig_paddle.Matrix) assert isinstance(m, api.Matrix)
m.sparseCopyFrom(self.__rows__, self.__cols__, self.__value__) m.sparseCopyFrom(self.__rows__, self.__cols__, self.__value__)
argument.setSlotValue(self.pos, m) argument.setSlotValue(self.pos, m)
...@@ -105,9 +104,9 @@ class IndexConvert(IDataConverter): ...@@ -105,9 +104,9 @@ class IndexConvert(IDataConverter):
self.__ids__ = [] self.__ids__ = []
def convert(self, data, argument): def convert(self, data, argument):
assert isinstance(argument, swig_paddle.Arguments) assert isinstance(argument, api.Arguments)
self.__ids__ = data.flatten() self.__ids__ = data.flatten()
ids = swig_paddle.IVector.create(self.__ids__) ids = api.IVector.create(self.__ids__)
argument.setSlotIds(self.pos, ids) argument.setSlotIds(self.pos, ids)
...@@ -135,7 +134,7 @@ class SequenceConvert(IDataConverter): ...@@ -135,7 +134,7 @@ class SequenceConvert(IDataConverter):
def convert(self, data, argument): def convert(self, data, argument):
fill_seq(data) fill_seq(data)
seq = swig_paddle.IVector.create(self.__seq__, False) seq = api.IVector.create(self.__seq__, False)
self.__setter__(argument, self.pos, seq) self.__setter__(argument, self.pos, seq)
dat = [] dat = []
...@@ -151,22 +150,21 @@ class SequenceConvert(IDataConverter): ...@@ -151,22 +150,21 @@ class SequenceConvert(IDataConverter):
class DataConverter(object): class DataConverter(object):
def __init__(self, input_mapper): def __init__(self, input):
""" """
Usege: Usege:
.. code-block:: python .. code-block:: python
inputs = [('image', dense_vector), ('label', integer_value)] inputs = [('image', dense_vector), ('label', integer_value)]
cvt = DataConverter(inputs) cvt = DataConverter(inputs)
arg = cvt.convert(minibatch_data, {'image':0, 'label':1}) arg = cvt(minibatch_data, {'image':0, 'label':1})
:param input_mapper: list of (input_name, input_type) :param input_mapper: list of (input_name, input_type)
:type input_mapper: list :type input_mapper: list
""" """
assert isinstance(self.input_types, collections.Sequence)
self.input_names = [] self.input_names = []
self.input_types = [] self.input_types = []
for each in self.input_types: for each in input:
self.input_names.append(each[0]) self.input_names.append(each[0])
self.input_types.append(each[1]) self.input_types.append(each[1])
assert isinstance(each[1], dp2.InputType) assert isinstance(each[1], dp2.InputType)
...@@ -186,16 +184,16 @@ class DataConverter(object): ...@@ -186,16 +184,16 @@ class DataConverter(object):
the feature order in argument and data is the same. the feature order in argument and data is the same.
:type input_dict: dict, like {string:integer, string, integer, ...}|None :type input_dict: dict, like {string:integer, string, integer, ...}|None
:param argument: converted data will be saved in this argument. If None, :param argument: converted data will be saved in this argument. If None,
it will create a swig_paddle.Arguments firstly. it will create a Paddle's Arguments firstly.
:param type: swig_paddle.Arguments|None :param type: swig_paddle.Arguments|None
""" """
if argument is None: if argument is None:
argument = swig_paddle.Arguments.createArguments(0) argument = api.Arguments.createArguments(0)
assert isinstance(argument, swig_paddle.Arguments) assert isinstance(argument, api.Arguments)
argument.resize(len(self.input_types)) argument.resize(len(self.input_types))
converts = [ converts = [
DataConverter.create_scanner(i, each_type) DataConverter.create_converter(i, each_type)
for i, each_type in enumerate(self.input_types) for i, each_type in enumerate(self.input_types)
] ]
...@@ -212,7 +210,7 @@ class DataConverter(object): ...@@ -212,7 +210,7 @@ class DataConverter(object):
return self.convert(dat, argument) return self.convert(dat, argument)
@staticmethod @staticmethod
def create_scanner(pos, each): def create_converter(pos, each):
assert isinstance(each, dp2.InputType) assert isinstance(each, dp2.InputType)
retv = None retv = None
if each.type == dp2.DataType.Dense: if each.type == dp2.DataType.Dense:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册