提交 47043fe1 编写于 作者: Q qiaolongfei

add type to datalayer

上级 cafb075e
......@@ -71,8 +71,8 @@ def main():
assert isinstance(updater, api.ParameterUpdater)
# define network
images = paddle_v2.layer.data(name='pixel', size=784)
label = paddle_v2.layer.data(name='label', size=10)
images = paddle_v2.layer.data(name='pixel', type=dp.dense_vector(784))
label = paddle_v2.layer.data(name='label', type=dp.integer_value(10))
hidden1 = paddle_v2.layer.fc(input=images, size=200)
hidden2 = paddle_v2.layer.fc(input=hidden1, size=200)
inference = paddle_v2.layer.fc(input=hidden2,
......
......@@ -67,6 +67,7 @@ paddle.v2.parameters.create, no longer exposed to users.
"""
import paddle.trainer_config_helpers as conf_helps
import paddle.trainer.PyDataProvider2 as dp
from paddle.trainer_config_helpers.config_parser_utils import \
parse_network_config as __parse__
from paddle.trainer_config_helpers.default_decorators import wrap_name_default
......@@ -157,7 +158,37 @@ def __convert_to_v2__(method_name, name_prefix, parent_names):
return V2LayerImpl
data = __convert_to_v2__('data_layer', None, [])
"""
Some layer may need some special config, and can not use __convert_to_v2__ to convert.
So we also need to implement some special LayerV2.
"""
class DataLayerV2(Layer):
def __init__(self, name, type, **kwargs):
self.__method_name__ = 'data_layer'
assert isinstance(type, dp.InputType)
# get data_size from type.dim
args = dict()
for key in kwargs:
args[key] = kwargs[key]
args['size'] = type.dim
self.__args__ = args
super(DataLayerV2, self).__init__(name=name, parent_layers=dict())
def to_proto_impl(self, **kwargs):
args = dict()
for each in kwargs:
args[each] = kwargs[each]
for each in self.__args__:
args[each] = self.__args__[each]
return getattr(conf_helps, self.__method_name__)(name=self.name, **args)
data = DataLayerV2
fc = __convert_to_v2__('fc_layer', name_prefix='fc', parent_names=['input'])
max_id = __convert_to_v2__(
'maxid_layer', name_prefix='maxid_layer', parent_names=['input'])
......@@ -171,8 +202,8 @@ cross_entropy_cost = __convert_to_v2__(
parent_names=['input', 'label'])
if __name__ == '__main__':
pixel = data(name='pixel', size=784)
label = data(name='label', size=10)
pixel = data(name='pixel', type=dp.dense_vector(784))
label = data(name='label', type=dp.integer_value(10))
hidden = fc(input=pixel, size=100, act=conf_helps.SigmoidActivation())
inference = fc(input=hidden, size=10, act=conf_helps.SoftmaxActivation())
maxid = max_id(input=inference)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册