Commit dbaabc94 authored by Luo Tao, committed by Yu Yang

fix unit test of test_RecurrentGradientMachine, and some tiny doc updates

Change-Id: I028e402c964ca4f4431cbf8153bea4379dd4df70
Parent d6d91223
@@ -165,7 +165,7 @@ We provide both C++ and Python interfaces to extract features. The following exa
 ### C++ Interface
-First, specify image data list in `define_py_data_sources` in the config, see example `demo/model_zoo/resnet/resnet.py`.
+First, specify image data list in `define_py_data_sources2` in the config, see example `demo/model_zoo/resnet/resnet.py`.
 ```
 train_list = 'train.list' if not is_test else None
......
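For orientation, here is a minimal sketch (not part of the commit) of how the renamed call tends to appear in a config such as `demo/model_zoo/resnet/resnet.py`. Only the `train_list`/`test_list`/`module`/`obj` arguments seen elsewhere in this commit are used; the module and function names below are hypothetical placeholders.

```python
# Hedged, illustrative config fragment; names marked "hypothetical" are not
# taken from the real resnet.py.
from paddle.trainer_config_helpers import *

is_test = False  # placeholder; a real config derives this from its settings

define_py_data_sources2(
    train_list='train.list' if not is_test else None,
    test_list='test.list',
    module='example_provider',  # hypothetical data provider module
    obj='process')              # hypothetical function decorated with @provider
```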
@@ -257,7 +257,7 @@ In these network, we use several api in `trainer_config_helpers
 * Text Convolution Pooling Layer, `text_conv_pool
   <../../ui/api/trainer_config_helpers/networks.html
   #trainer_config_helpers.networks.text_conv_pool>`_
-* Declare Python Data Sources, `define_py_data_sources
+* Declare Python Data Sources, `define_py_data_sources2
   <../../ui/api/trainer_config_helpers/data_sources.html>`_
   Data Provider
......
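To show how the APIs listed above combine in a text-classification config, here is a rough sketch; the layer stack, the `context_len`/`hidden_size` parameters of `text_conv_pool`, the module name, and the dictionary size are assumptions for illustration only.

```python
from paddle.trainer_config_helpers import *

dict_dim = 10000  # placeholder vocabulary size

# Declare Python data sources with the new-style API referenced above.
define_py_data_sources2(train_list='train.list',
                        test_list=None,
                        module='dataprovider',  # hypothetical provider module
                        obj='process')

# Embedding -> text convolution + pooling -> softmax classifier (illustrative).
words = data_layer(name='word', size=dict_dim)
emb = embedding_layer(input=words, size=128)
conv = text_conv_pool(input=emb, context_len=3, hidden_size=512)
prob = fc_layer(input=conv, size=2, act=SoftmaxActivation())
label = data_layer(name='label', size=2)
outputs(classification_cost(input=prob, label=label))
```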
@@ -18,27 +18,33 @@
 import os
 import sys
-from paddle.trainer.PyDataProviderWrapper import *
+from paddle.trainer.PyDataProvider2 import *
-@init_hook_wrapper
-def hook(obj, dict_file, **kwargs):
-    obj.word_dict = dict_file
-    obj.slots = [IndexSlot(len(obj.word_dict)), IndexSlot(3)]
-    obj.logger.info('dict len : %d' % (len(obj.word_dict)))
+def hook(settings, dict_file, **kwargs):
+    settings.word_dict = dict_file
+    settings.input_types = [integer_value_sequence(len(settings.word_dict)),
+                            integer_value_sequence(3)]
+    settings.logger.info('dict len : %d' % (len(settings.word_dict)))
-@provider(use_seq=True, init_hook=hook)
-def process(obj, file_name):
+@provider(init_hook=hook)
+def process(settings, file_name):
     with open(file_name, 'r') as fdata:
         for line in fdata:
             label, comment = line.strip().split('\t')
             label = int(''.join(label.split()))
             words = comment.split()
-            word_slot = [obj.word_dict[w] for w in words if w in obj.word_dict]
+            word_slot = [settings.word_dict[w] for w in words if w in settings.word_dict]
             yield word_slot, [label]
 ## for hierarchical sequence network
-@provider(use_seq=True, init_hook=hook)
-def process2(obj, file_name):
+def hook2(settings, dict_file, **kwargs):
+    settings.word_dict = dict_file
+    settings.input_types = [integer_value_sub_sequence(len(settings.word_dict)),
+                            integer_value_sub_sequence(3)]
+    settings.logger.info('dict len : %d' % (len(settings.word_dict)))
+@provider(init_hook=hook2)
+def process2(settings, file_name):
     with open(file_name) as fdata:
         label_list = []
         word_slot_list = []
@@ -47,7 +53,7 @@ def process2(obj, file_name):
                 label,comment = line.strip().split('\t')
                 label = int(''.join(label.split()))
                 words = comment.split()
-                word_slot = [obj.word_dict[w] for w in words if w in obj.word_dict]
+                word_slot = [settings.word_dict[w] for w in words if w in settings.word_dict]
                 label_list.append([label])
                 word_slot_list.append(word_slot)
             else:
......
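To make the difference between the two rewritten providers concrete, the toy snippet below spells out the record shapes they declare: `process` pairs one integer word sequence with one label per input line, while `process2` accumulates several lines and emits them as sub-sequences. The word ids and labels are made up, and the final yield of `process2` (elided in the diff above) is assumed to follow the same word-then-label order as its `input_types`.

```python
# Toy illustration of the yielded record shapes; not part of the commit.
word_dict = {'this': 0, 'movie': 1, 'is': 2, 'great': 3, 'plot': 4, 'weak': 5}

# process + integer_value_sequence: one sentence and one label per sample.
flat_sample = ([word_dict[w] for w in 'this movie is great'.split()], [1])
# -> ([0, 1, 2, 3], [1])

# process2 + integer_value_sub_sequence: a group of sentences per sample, so
# each slot becomes a list of sequences (order assumed from input_types).
nested_sample = ([[0, 1, 2, 3], [4, 2, 5]], [[1], [0]])

print(flat_sample)
print(nested_sample)
```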
@@ -21,7 +21,7 @@ dict_file = dict()
 for line_count, line in enumerate(open(dict_path, "r")):
     dict_file[line.strip()] = line_count
-define_py_data_sources(train_list='gserver/tests/Sequence/train.list',
+define_py_data_sources2(train_list='gserver/tests/Sequence/train.list',
                         test_list=None,
                         module='sequenceGen',
                         obj='process',
......
@@ -21,7 +21,7 @@ dict_file = dict()
 for line_count, line in enumerate(open(dict_path, "r")):
     dict_file[line.strip()] = line_count
-define_py_data_sources(train_list='gserver/tests/Sequence/train.list.nest',
+define_py_data_sources2(train_list='gserver/tests/Sequence/train.list.nest',
                         test_list=None,
                         module='sequenceGen',
                         obj='process2',
......