提交 dbaabc94 编写于 作者: L Luo Tao 提交者: Yu Yang

fix unitest of test_RecurrentGradientMachine, and some tiny doc update

Change-Id: I028e402c964ca4f4431cbf8153bea4379dd4df70
上级 d6d91223
...@@ -165,7 +165,7 @@ We provide both C++ and Python interfaces to extract features. The following exa ...@@ -165,7 +165,7 @@ We provide both C++ and Python interfaces to extract features. The following exa
### C++ Interface ### C++ Interface
First, specify image data list in `define_py_data_sources` in the config, see example `demo/model_zoo/resnet/resnet.py`. First, specify image data list in `define_py_data_sources2` in the config, see example `demo/model_zoo/resnet/resnet.py`.
``` ```
train_list = 'train.list' if not is_test else None train_list = 'train.list' if not is_test else None
......
...@@ -257,7 +257,7 @@ In these network, we use several api in `trainer_config_helpers ...@@ -257,7 +257,7 @@ In these network, we use several api in `trainer_config_helpers
* Text Convolution Pooling Layer, `text_conv_pool * Text Convolution Pooling Layer, `text_conv_pool
<../../ui/api/trainer_config_helpers/networks.html <../../ui/api/trainer_config_helpers/networks.html
#trainer_config_helpers.networks.text_conv_pool>`_ #trainer_config_helpers.networks.text_conv_pool>`_
* Declare Python Data Sources, `define_py_data_sources * Declare Python Data Sources, `define_py_data_sources2
<../../ui/api/trainer_config_helpers/data_sources.html>`_ <../../ui/api/trainer_config_helpers/data_sources.html>`_
Data Provider Data Provider
......
...@@ -18,27 +18,33 @@ ...@@ -18,27 +18,33 @@
import os import os
import sys import sys
from paddle.trainer.PyDataProviderWrapper import * from paddle.trainer.PyDataProvider2 import *
@init_hook_wrapper def hook(settings, dict_file, **kwargs):
def hook(obj, dict_file, **kwargs): settings.word_dict = dict_file
obj.word_dict = dict_file settings.input_types = [integer_value_sequence(len(settings.word_dict)),
obj.slots = [IndexSlot(len(obj.word_dict)), IndexSlot(3)] integer_value_sequence(3)]
obj.logger.info('dict len : %d' % (len(obj.word_dict))) settings.logger.info('dict len : %d' % (len(settings.word_dict)))
@provider(use_seq=True, init_hook=hook) @provider(init_hook=hook)
def process(obj, file_name): def process(settings, file_name):
with open(file_name, 'r') as fdata: with open(file_name, 'r') as fdata:
for line in fdata: for line in fdata:
label, comment = line.strip().split('\t') label, comment = line.strip().split('\t')
label = int(''.join(label.split())) label = int(''.join(label.split()))
words = comment.split() words = comment.split()
word_slot = [obj.word_dict[w] for w in words if w in obj.word_dict] word_slot = [settings.word_dict[w] for w in words if w in settings.word_dict]
yield word_slot, [label] yield word_slot, [label]
## for hierarchical sequence network ## for hierarchical sequence network
@provider(use_seq=True, init_hook=hook) def hook2(settings, dict_file, **kwargs):
def process2(obj, file_name): settings.word_dict = dict_file
settings.input_types = [integer_value_sub_sequence(len(settings.word_dict)),
integer_value_sub_sequence(3)]
settings.logger.info('dict len : %d' % (len(settings.word_dict)))
@provider(init_hook=hook2)
def process2(settings, file_name):
with open(file_name) as fdata: with open(file_name) as fdata:
label_list = [] label_list = []
word_slot_list = [] word_slot_list = []
...@@ -47,7 +53,7 @@ def process2(obj, file_name): ...@@ -47,7 +53,7 @@ def process2(obj, file_name):
label,comment = line.strip().split('\t') label,comment = line.strip().split('\t')
label = int(''.join(label.split())) label = int(''.join(label.split()))
words = comment.split() words = comment.split()
word_slot = [obj.word_dict[w] for w in words if w in obj.word_dict] word_slot = [settings.word_dict[w] for w in words if w in settings.word_dict]
label_list.append([label]) label_list.append([label])
word_slot_list.append(word_slot) word_slot_list.append(word_slot)
else: else:
......
...@@ -21,7 +21,7 @@ dict_file = dict() ...@@ -21,7 +21,7 @@ dict_file = dict()
for line_count, line in enumerate(open(dict_path, "r")): for line_count, line in enumerate(open(dict_path, "r")):
dict_file[line.strip()] = line_count dict_file[line.strip()] = line_count
define_py_data_sources(train_list='gserver/tests/Sequence/train.list', define_py_data_sources2(train_list='gserver/tests/Sequence/train.list',
test_list=None, test_list=None,
module='sequenceGen', module='sequenceGen',
obj='process', obj='process',
......
...@@ -21,7 +21,7 @@ dict_file = dict() ...@@ -21,7 +21,7 @@ dict_file = dict()
for line_count, line in enumerate(open(dict_path, "r")): for line_count, line in enumerate(open(dict_path, "r")):
dict_file[line.strip()] = line_count dict_file[line.strip()] = line_count
define_py_data_sources(train_list='gserver/tests/Sequence/train.list.nest', define_py_data_sources2(train_list='gserver/tests/Sequence/train.list.nest',
test_list=None, test_list=None,
module='sequenceGen', module='sequenceGen',
obj='process2', obj='process2',
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册