diff --git a/doc/demo/imagenet_model/resnet_model.md b/doc/demo/imagenet_model/resnet_model.md
index 2e5c7f324434d896aaf0e3232f976e5ae6ed7f03..5403ab9f17d2399fee878d0f3c512cb166aba06f 100644
--- a/doc/demo/imagenet_model/resnet_model.md
+++ b/doc/demo/imagenet_model/resnet_model.md
@@ -165,7 +165,7 @@ We provide both C++ and Python interfaces to extract features. The following exa
 
 ### C++ Interface
 
-First, specify image data list in `define_py_data_sources` in the config, see example `demo/model_zoo/resnet/resnet.py`.
+First, specify image data list in `define_py_data_sources2` in the config, see example `demo/model_zoo/resnet/resnet.py`.
 
 ```
 train_list = 'train.list' if not is_test else None
diff --git a/doc/demo/rec/ml_regression.rst b/doc/demo/rec/ml_regression.rst
index 4917f873a934dc93e8618627473dbe99644b10b5..0c14e4f5bb7f815a06c0c756b1a6e6ef9099fd66 100644
--- a/doc/demo/rec/ml_regression.rst
+++ b/doc/demo/rec/ml_regression.rst
@@ -257,7 +257,7 @@ In these network, we use several api in `trainer_config_helpers
 * Text Convolution Pooling Layer, `text_conv_pool
   <../../ui/api/trainer_config_helpers/networks.html
   #trainer_config_helpers.networks.text_conv_pool>`_
-* Declare Python Data Sources, `define_py_data_sources
+* Declare Python Data Sources, `define_py_data_sources2
   <../../ui/api/trainer_config_helpers/data_sources.html>`_
 
 Data Provider
diff --git a/paddle/gserver/tests/sequenceGen.py b/paddle/gserver/tests/sequenceGen.py
index dd2b90dd4986cf049ae4514d308fbbf549371f90..e4727e472d446b48e6001968841bfc178e34ec0c 100644
--- a/paddle/gserver/tests/sequenceGen.py
+++ b/paddle/gserver/tests/sequenceGen.py
@@ -18,27 +18,33 @@
 import os
 import sys
 
-from paddle.trainer.PyDataProviderWrapper import *
+from paddle.trainer.PyDataProvider2 import *
 
-@init_hook_wrapper
-def hook(obj, dict_file, **kwargs):
-    obj.word_dict = dict_file
-    obj.slots = [IndexSlot(len(obj.word_dict)), IndexSlot(3)]
-    obj.logger.info('dict len : %d' % (len(obj.word_dict)))
+def hook(settings, dict_file, **kwargs):
+    settings.word_dict = dict_file
+    settings.input_types = [integer_value_sequence(len(settings.word_dict)),
+                            integer_value_sequence(3)]
+    settings.logger.info('dict len : %d' % (len(settings.word_dict)))
 
-@provider(use_seq=True, init_hook=hook)
-def process(obj, file_name):
+@provider(init_hook=hook)
+def process(settings, file_name):
     with open(file_name, 'r') as fdata:
         for line in fdata:
             label, comment = line.strip().split('\t')
             label = int(''.join(label.split()))
             words = comment.split()
-            word_slot = [obj.word_dict[w] for w in words if w in obj.word_dict]
+            word_slot = [settings.word_dict[w] for w in words if w in settings.word_dict]
             yield word_slot, [label]
 
 ## for hierarchical sequence network
-@provider(use_seq=True, init_hook=hook)
-def process2(obj, file_name):
+def hook2(settings, dict_file, **kwargs):
+    settings.word_dict = dict_file
+    settings.input_types = [integer_value_sub_sequence(len(settings.word_dict)),
+                            integer_value_sub_sequence(3)]
+    settings.logger.info('dict len : %d' % (len(settings.word_dict)))
+
+@provider(init_hook=hook2)
+def process2(settings, file_name):
     with open(file_name) as fdata:
         label_list = []
         word_slot_list = []
@@ -47,7 +53,7 @@
                 label,comment = line.strip().split('\t')
                 label = int(''.join(label.split()))
                 words = comment.split()
-                word_slot = [obj.word_dict[w] for w in words if w in obj.word_dict]
+                word_slot = [settings.word_dict[w] for w in words if w in settings.word_dict]
                 label_list.append([label])
                 word_slot_list.append(word_slot)
             else:
diff --git a/paddle/gserver/tests/sequence_layer_group.conf b/paddle/gserver/tests/sequence_layer_group.conf
index 9ad2b3762845faf9e5deac2362b971e0abcbd9a7..ac031b31280df297246c1ea2e279fc2c595bd8b7 100644
--- a/paddle/gserver/tests/sequence_layer_group.conf
+++ b/paddle/gserver/tests/sequence_layer_group.conf
@@ -21,11 +21,11 @@ dict_file = dict()
 for line_count, line in enumerate(open(dict_path, "r")):
     dict_file[line.strip()] = line_count
 
-define_py_data_sources(train_list='gserver/tests/Sequence/train.list',
-                       test_list=None,
-                       module='sequenceGen',
-                       obj='process',
-                       args={"dict_file":dict_file})
+define_py_data_sources2(train_list='gserver/tests/Sequence/train.list',
+                        test_list=None,
+                        module='sequenceGen',
+                        obj='process',
+                        args={"dict_file":dict_file})
 
 settings(batch_size=5)
 ######################## network configure ################################
diff --git a/paddle/gserver/tests/sequence_nest_layer_group.conf b/paddle/gserver/tests/sequence_nest_layer_group.conf
index 8c3a08f16cd1cc593ffe48839bddd12852a41931..38c60b657b969f9fbcf46a00c542fa100da5a877 100644
--- a/paddle/gserver/tests/sequence_nest_layer_group.conf
+++ b/paddle/gserver/tests/sequence_nest_layer_group.conf
@@ -21,11 +21,11 @@ dict_file = dict()
 for line_count, line in enumerate(open(dict_path, "r")):
     dict_file[line.strip()] = line_count
 
-define_py_data_sources(train_list='gserver/tests/Sequence/train.list.nest',
-                       test_list=None,
-                       module='sequenceGen',
-                       obj='process2',
-                       args={"dict_file":dict_file})
+define_py_data_sources2(train_list='gserver/tests/Sequence/train.list.nest',
+                        test_list=None,
+                        module='sequenceGen',
+                        obj='process2',
+                        args={"dict_file":dict_file})
 
 settings(batch_size=2)
 ######################## network configure ################################