From 47043fe1d3a511e0c88625ff4382adb1fea0feb8 Mon Sep 17 00:00:00 2001
From: qiaolongfei
Date: Mon, 20 Feb 2017 16:15:22 +0800
Subject: [PATCH] add type to datalayer

---
 demo/mnist/api_train.py   |  4 ++--
 python/paddle/v2/layer.py | 37 ++++++++++++++++++++++++++++++++++---
 2 files changed, 36 insertions(+), 5 deletions(-)

diff --git a/demo/mnist/api_train.py b/demo/mnist/api_train.py
index fe39f0bd23..77e6fa8900 100644
--- a/demo/mnist/api_train.py
+++ b/demo/mnist/api_train.py
@@ -71,8 +71,8 @@ def main():
     assert isinstance(updater, api.ParameterUpdater)
 
     # define network
-    images = paddle_v2.layer.data(name='pixel', size=784)
-    label = paddle_v2.layer.data(name='label', size=10)
+    images = paddle_v2.layer.data(name='pixel', type=dp.dense_vector(784))
+    label = paddle_v2.layer.data(name='label', type=dp.integer_value(10))
     hidden1 = paddle_v2.layer.fc(input=images, size=200)
     hidden2 = paddle_v2.layer.fc(input=hidden1, size=200)
     inference = paddle_v2.layer.fc(input=hidden2,
diff --git a/python/paddle/v2/layer.py b/python/paddle/v2/layer.py
index 0ce4ecd569..e006b78922 100644
--- a/python/paddle/v2/layer.py
+++ b/python/paddle/v2/layer.py
@@ -67,6 +67,7 @@ paddle.v2.parameters.create, no longer exposed to users.
 """
 
 import paddle.trainer_config_helpers as conf_helps
+import paddle.trainer.PyDataProvider2 as dp
 from paddle.trainer_config_helpers.config_parser_utils import \
     parse_network_config as __parse__
 from paddle.trainer_config_helpers.default_decorators import wrap_name_default
@@ -157,7 +158,37 @@ def __convert_to_v2__(method_name, name_prefix, parent_names):
     return V2LayerImpl
 
 
-data = __convert_to_v2__('data_layer', None, [])
+"""
+Some layer may need some special config, and can not use __convert_to_v2__ to convert.
+So we also need to implement some special LayerV2.
+"""
+
+
+class DataLayerV2(Layer):
+    def __init__(self, name, type, **kwargs):
+        self.__method_name__ = 'data_layer'
+
+        assert isinstance(type, dp.InputType)
+
+        # get data_size from type.dim
+        args = dict()
+        for key in kwargs:
+            args[key] = kwargs[key]
+        args['size'] = type.dim
+        self.__args__ = args
+
+        super(DataLayerV2, self).__init__(name=name, parent_layers=dict())
+
+    def to_proto_impl(self, **kwargs):
+        args = dict()
+        for each in kwargs:
+            args[each] = kwargs[each]
+        for each in self.__args__:
+            args[each] = self.__args__[each]
+        return getattr(conf_helps, self.__method_name__)(name=self.name, **args)
+
+
+data = DataLayerV2
 fc = __convert_to_v2__('fc_layer', name_prefix='fc', parent_names=['input'])
 max_id = __convert_to_v2__(
     'maxid_layer', name_prefix='maxid_layer', parent_names=['input'])
@@ -171,8 +202,8 @@ cross_entropy_cost = __convert_to_v2__(
     parent_names=['input', 'label'])
 
 if __name__ == '__main__':
-    pixel = data(name='pixel', size=784)
-    label = data(name='label', size=10)
+    pixel = data(name='pixel', type=dp.dense_vector(784))
+    label = data(name='label', type=dp.integer_value(10))
     hidden = fc(input=pixel, size=100, act=conf_helps.SigmoidActivation())
     inference = fc(input=hidden, size=10, act=conf_helps.SoftmaxActivation())
     maxid = max_id(input=inference)
-- 
GitLab