From dbaabc94fb0b21b7bf91132eab5de954143d870b Mon Sep 17 00:00:00 2001 From: Luo Tao Date: Wed, 7 Sep 2016 15:48:36 +0800 Subject: [PATCH] fix unitest of test_RecurrentGradientMachine, and some tiny doc update Change-Id: I028e402c964ca4f4431cbf8153bea4379dd4df70 --- doc/demo/imagenet_model/resnet_model.md | 2 +- doc/demo/rec/ml_regression.rst | 2 +- paddle/gserver/tests/sequenceGen.py | 30 +++++++++++-------- .../gserver/tests/sequence_layer_group.conf | 10 +++---- .../tests/sequence_nest_layer_group.conf | 10 +++---- 5 files changed, 30 insertions(+), 24 deletions(-) diff --git a/doc/demo/imagenet_model/resnet_model.md b/doc/demo/imagenet_model/resnet_model.md index 2e5c7f3244..5403ab9f17 100644 --- a/doc/demo/imagenet_model/resnet_model.md +++ b/doc/demo/imagenet_model/resnet_model.md @@ -165,7 +165,7 @@ We provide both C++ and Python interfaces to extract features. The following exa ### C++ Interface -First, specify image data list in `define_py_data_sources` in the config, see example `demo/model_zoo/resnet/resnet.py`. +First, specify image data list in `define_py_data_sources2` in the config, see example `demo/model_zoo/resnet/resnet.py`. ``` train_list = 'train.list' if not is_test else None diff --git a/doc/demo/rec/ml_regression.rst b/doc/demo/rec/ml_regression.rst index 4917f873a9..0c14e4f5bb 100644 --- a/doc/demo/rec/ml_regression.rst +++ b/doc/demo/rec/ml_regression.rst @@ -257,7 +257,7 @@ In these network, we use several api in `trainer_config_helpers * Text Convolution Pooling Layer, `text_conv_pool <../../ui/api/trainer_config_helpers/networks.html #trainer_config_helpers.networks.text_conv_pool>`_ -* Declare Python Data Sources, `define_py_data_sources +* Declare Python Data Sources, `define_py_data_sources2 <../../ui/api/trainer_config_helpers/data_sources.html>`_ Data Provider diff --git a/paddle/gserver/tests/sequenceGen.py b/paddle/gserver/tests/sequenceGen.py index dd2b90dd49..e4727e472d 100644 --- a/paddle/gserver/tests/sequenceGen.py +++ b/paddle/gserver/tests/sequenceGen.py @@ -18,27 +18,33 @@ import os import sys -from paddle.trainer.PyDataProviderWrapper import * +from paddle.trainer.PyDataProvider2 import * -@init_hook_wrapper -def hook(obj, dict_file, **kwargs): - obj.word_dict = dict_file - obj.slots = [IndexSlot(len(obj.word_dict)), IndexSlot(3)] - obj.logger.info('dict len : %d' % (len(obj.word_dict))) +def hook(settings, dict_file, **kwargs): + settings.word_dict = dict_file + settings.input_types = [integer_value_sequence(len(settings.word_dict)), + integer_value_sequence(3)] + settings.logger.info('dict len : %d' % (len(settings.word_dict))) -@provider(use_seq=True, init_hook=hook) -def process(obj, file_name): +@provider(init_hook=hook) +def process(settings, file_name): with open(file_name, 'r') as fdata: for line in fdata: label, comment = line.strip().split('\t') label = int(''.join(label.split())) words = comment.split() - word_slot = [obj.word_dict[w] for w in words if w in obj.word_dict] + word_slot = [settings.word_dict[w] for w in words if w in settings.word_dict] yield word_slot, [label] ## for hierarchical sequence network -@provider(use_seq=True, init_hook=hook) -def process2(obj, file_name): +def hook2(settings, dict_file, **kwargs): + settings.word_dict = dict_file + settings.input_types = [integer_value_sub_sequence(len(settings.word_dict)), + integer_value_sub_sequence(3)] + settings.logger.info('dict len : %d' % (len(settings.word_dict))) + +@provider(init_hook=hook2) +def process2(settings, file_name): with open(file_name) as fdata: label_list = [] word_slot_list = [] @@ -47,7 +53,7 @@ def process2(obj, file_name): label,comment = line.strip().split('\t') label = int(''.join(label.split())) words = comment.split() - word_slot = [obj.word_dict[w] for w in words if w in obj.word_dict] + word_slot = [settings.word_dict[w] for w in words if w in settings.word_dict] label_list.append([label]) word_slot_list.append(word_slot) else: diff --git a/paddle/gserver/tests/sequence_layer_group.conf b/paddle/gserver/tests/sequence_layer_group.conf index 9ad2b37628..ac031b3128 100644 --- a/paddle/gserver/tests/sequence_layer_group.conf +++ b/paddle/gserver/tests/sequence_layer_group.conf @@ -21,11 +21,11 @@ dict_file = dict() for line_count, line in enumerate(open(dict_path, "r")): dict_file[line.strip()] = line_count -define_py_data_sources(train_list='gserver/tests/Sequence/train.list', - test_list=None, - module='sequenceGen', - obj='process', - args={"dict_file":dict_file}) +define_py_data_sources2(train_list='gserver/tests/Sequence/train.list', + test_list=None, + module='sequenceGen', + obj='process', + args={"dict_file":dict_file}) settings(batch_size=5) ######################## network configure ################################ diff --git a/paddle/gserver/tests/sequence_nest_layer_group.conf b/paddle/gserver/tests/sequence_nest_layer_group.conf index 8c3a08f16c..38c60b657b 100644 --- a/paddle/gserver/tests/sequence_nest_layer_group.conf +++ b/paddle/gserver/tests/sequence_nest_layer_group.conf @@ -21,11 +21,11 @@ dict_file = dict() for line_count, line in enumerate(open(dict_path, "r")): dict_file[line.strip()] = line_count -define_py_data_sources(train_list='gserver/tests/Sequence/train.list.nest', - test_list=None, - module='sequenceGen', - obj='process2', - args={"dict_file":dict_file}) +define_py_data_sources2(train_list='gserver/tests/Sequence/train.list.nest', + test_list=None, + module='sequenceGen', + obj='process2', + args={"dict_file":dict_file}) settings(batch_size=2) ######################## network configure ################################ -- GitLab