提交 f93af824 编写于 作者: Q qingqing01 提交者: GitHub

Merge pull request #730 from luotao1/demo

fix protobuf size limit of seq2seq demo
......@@ -19,27 +19,44 @@ START = "<s>"
END = "<e>"
def hook(settings, src_dict, trg_dict, file_list, **kwargs):
def hook(settings, src_dict_path, trg_dict_path, is_generating, file_list,
**kwargs):
# job_mode = 1: training mode
# job_mode = 0: generating mode
settings.job_mode = trg_dict is not None
settings.src_dict = src_dict
settings.job_mode = not is_generating
settings.src_dict = dict()
with open(src_dict_path, "r") as fin:
settings.src_dict = {
line.strip(): line_count
for line_count, line in enumerate(fin)
}
settings.trg_dict = dict()
with open(trg_dict_path, "r") as fin:
settings.trg_dict = {
line.strip(): line_count
for line_count, line in enumerate(fin)
}
settings.logger.info("src dict len : %d" % (len(settings.src_dict)))
settings.sample_count = 0
if settings.job_mode:
settings.trg_dict = trg_dict
settings.slots = [
settings.slots = {
'source_language_word':
integer_value_sequence(len(settings.src_dict)),
'target_language_word':
integer_value_sequence(len(settings.trg_dict)),
'target_language_next_word':
integer_value_sequence(len(settings.trg_dict))
]
}
settings.logger.info("trg dict len : %d" % (len(settings.trg_dict)))
else:
settings.slots = [
settings.slots = {
'source_language_word':
integer_value_sequence(len(settings.src_dict)),
'sent_id':
integer_value_sequence(len(open(file_list[0], "r").readlines()))
]
}
def _get_ids(s, dictionary):
......@@ -69,6 +86,10 @@ def process(settings, file_name):
continue
trg_ids_next = trg_ids + [settings.trg_dict[END]]
trg_ids = [settings.trg_dict[START]] + trg_ids
yield src_ids, trg_ids, trg_ids_next
yield {
'source_language_word': src_ids,
'target_language_word': trg_ids,
'target_language_next_word': trg_ids_next
}
else:
yield src_ids, [line_count]
yield {'source_language_word': src_ids, 'sent_id': [line_count]}
......@@ -37,17 +37,10 @@ def seq_to_seq_data(data_dir,
"""
src_lang_dict = os.path.join(data_dir, 'src.dict')
trg_lang_dict = os.path.join(data_dir, 'trg.dict')
src_dict = dict()
for line_count, line in enumerate(open(src_lang_dict, "r")):
src_dict[line.strip()] = line_count
trg_dict = dict()
for line_count, line in enumerate(open(trg_lang_dict, "r")):
trg_dict[line.strip()] = line_count
if is_generating:
train_list = None
test_list = os.path.join(data_dir, gen_list)
trg_dict = None
else:
train_list = os.path.join(data_dir, train_list)
test_list = os.path.join(data_dir, test_list)
......@@ -57,8 +50,11 @@ def seq_to_seq_data(data_dir,
test_list,
module="dataprovider",
obj="process",
args={"src_dict": src_dict,
"trg_dict": trg_dict})
args={
"src_dict_path": src_lang_dict,
"trg_dict_path": trg_lang_dict,
"is_generating": is_generating
})
return {
"src_dict_path": src_lang_dict,
......
......@@ -214,3 +214,41 @@ PaddlePaddle的参数使用名字 :code:`name` 作为参数的ID,相同名字
cmake .. -DPYTHON_EXECUTABLE=<exc_path> -DPYTHON_LIBRARY=<lib_path> -DPYTHON_INCLUDE_DIR=<inc_path>
用户需要指定本机上Python的路径:``<exc_path>``, ``<lib_path>``, ``<inc_path>``
10. A protocol message was rejected because it was too big
----------------------------------------------------------
如果在训练NLP相关模型时,出现以下错误:
.. code-block:: bash
[libprotobuf ERROR google/protobuf/io/coded_stream.cc:171] A protocol message was rejected because it was too big (more than 67108864 bytes). To increase the limit (or to disable these warnings), see CodedInputStream::SetTotalBytesLimit() in google/protobuf/io/coded_stream.h.
F1205 14:59:50.295174 14703 TrainerConfigHelper.cpp:59] Check failed: m->conf.ParseFromString(configProtoStr)
可能的原因是:传给dataprovider的某一个args过大,一般是由于直接传递大字典导致的。错误的define_py_data_sources2类似:
.. code-block:: python
src_dict = dict()
for line_count, line in enumerate(open(src_dict_path, "r")):
src_dict[line.strip()] = line_count
define_py_data_sources2(
train_list,
test_list,
module="dataprovider",
obj="process",
args={"src_dict": src_dict})
解决方案是:将字典的地址作为args传给dataprovider,然后在dataprovider里面根据该地址加载字典。即define_py_data_sources2应改为:
.. code-block:: python
define_py_data_sources2(
train_list,
test_list,
module="dataprovider",
obj="process",
args={"src_dict_path": src_dict_path})
完整源码可参考 `seqToseq <https://github.com/PaddlePaddle/Paddle/tree/develop/demo/seqToseq>`_ 示例。
\ No newline at end of file
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册