From 89cfb39fe450faeb71fc00d13a6421824a43b646 Mon Sep 17 00:00:00 2001 From: qiaolongfei Date: Mon, 20 Feb 2017 16:44:58 +0800 Subject: [PATCH] use data_type of data layer --- demo/mnist/api_train.py | 2 +- demo/mnist/api_train_v2.py | 4 ++-- python/paddle/v2/layer.py | 4 ++-- 3 files changed, 5 insertions(+), 5 deletions(-) diff --git a/demo/mnist/api_train.py b/demo/mnist/api_train.py index ae00e9f54c..c6993eda38 100644 --- a/demo/mnist/api_train.py +++ b/demo/mnist/api_train.py @@ -99,7 +99,7 @@ def main(): # DataProvider Converter is a utility convert Python Object to Paddle C++ # Input. The input format is as same as Paddle's DataProvider. converter = DataProviderConverter( - input_types=[dp.dense_vector(784), dp.integer_value(10)]) + input_types=[images.data_type, label.data_type]) train_file = './data/raw_data/train' test_file = './data/raw_data/t10k' diff --git a/demo/mnist/api_train_v2.py b/demo/mnist/api_train_v2.py index b46cf1c870..6bbc3a4ab3 100644 --- a/demo/mnist/api_train_v2.py +++ b/demo/mnist/api_train_v2.py @@ -51,8 +51,8 @@ def main(): batch_size=32, # batch size should be refactor in Data reader data_types={ # data_types will be removed, It should be in # network topology - 'pixel': dense_vector(784), - 'label': integer_value(10) + 'pixel': images.data_type, + 'label': label.data_type }) diff --git a/python/paddle/v2/layer.py b/python/paddle/v2/layer.py index 511b3e7457..f2bc47946b 100644 --- a/python/paddle/v2/layer.py +++ b/python/paddle/v2/layer.py @@ -168,15 +168,15 @@ class DataLayerV2(Layer): def __init__(self, name, data_type, **kwargs): assert isinstance(data_type, dp.InputType) + self.data_type = data_type self.__method_name__ = 'data_layer' self.__kwargs__ = kwargs - self.__data_size__ = data_type.dim super(DataLayerV2, self).__init__(name=name, parent_layers=dict()) def to_proto_impl(self, **kwargs): args = dict() - args['size'] = self.__data_size__ + args['size'] = self.data_type.dim for each in kwargs: args[each] = kwargs[each] for each in self.__kwargs__: -- GitLab