Commit 8ac9e9a8 authored by Luo Tao

fix protobuf size limit of seq2seq demo

Parent 9ffa434b
@@ -19,27 +19,38 @@ START = "<s>"
 END = "<e>"
 
 
-def hook(settings, src_dict, trg_dict, file_list, **kwargs):
+def hook(settings, src_dict_path, trg_dict_path, is_generating, file_list,
+         **kwargs):
     # job_mode = 1: training mode
     # job_mode = 0: generating mode
-    settings.job_mode = trg_dict is not None
-    settings.src_dict = src_dict
+    settings.job_mode = not is_generating
+    settings.src_dict = dict()
+    for line_count, line in enumerate(open(src_dict_path, "r")):
+        settings.src_dict[line.strip()] = line_count
+    settings.trg_dict = dict()
+    for line_count, line in enumerate(open(trg_dict_path, "r")):
+        settings.trg_dict[line.strip()] = line_count
+
     settings.logger.info("src dict len : %d" % (len(settings.src_dict)))
     settings.sample_count = 0
 
     if settings.job_mode:
-        settings.trg_dict = trg_dict
-        settings.slots = [
+        settings.slots = {
+            'source_language_word':
             integer_value_sequence(len(settings.src_dict)),
+            'target_language_word':
             integer_value_sequence(len(settings.trg_dict)),
+            'target_language_next_word':
             integer_value_sequence(len(settings.trg_dict))
-        ]
+        }
         settings.logger.info("trg dict len : %d" % (len(settings.trg_dict)))
     else:
-        settings.slots = [
+        settings.slots = {
+            'source_language_word':
             integer_value_sequence(len(settings.src_dict)),
+            'sent_id':
             integer_value_sequence(len(open(file_list[0], "r").readlines()))
-        ]
+        }
 
 
 def _get_ids(s, dictionary):

@@ -69,6 +80,10 @@ def process(settings, file_name):
                     continue
                 trg_ids_next = trg_ids + [settings.trg_dict[END]]
                 trg_ids = [settings.trg_dict[START]] + trg_ids
-                yield src_ids, trg_ids, trg_ids_next
+                yield {
+                    'source_language_word': src_ids,
+                    'target_language_word': trg_ids,
+                    'target_language_next_word': trg_ids_next
+                }
             else:
-                yield src_ids, [line_count]
+                yield {'source_language_word': src_ids, 'sent_id': [line_count]}
@@ -37,17 +37,10 @@ def seq_to_seq_data(data_dir,
     """
     src_lang_dict = os.path.join(data_dir, 'src.dict')
     trg_lang_dict = os.path.join(data_dir, 'trg.dict')
-    src_dict = dict()
-    for line_count, line in enumerate(open(src_lang_dict, "r")):
-        src_dict[line.strip()] = line_count
-    trg_dict = dict()
-    for line_count, line in enumerate(open(trg_lang_dict, "r")):
-        trg_dict[line.strip()] = line_count
 
     if is_generating:
         train_list = None
         test_list = os.path.join(data_dir, gen_list)
-        trg_dict = None
     else:
         train_list = os.path.join(data_dir, train_list)
         test_list = os.path.join(data_dir, test_list)

@@ -57,8 +50,11 @@ def seq_to_seq_data(data_dir,
         test_list,
         module="dataprovider",
         obj="process",
-        args={"src_dict": src_dict,
-              "trg_dict": trg_dict})
+        args={
+            "src_dict_path": src_lang_dict,
+            "trg_dict_path": trg_lang_dict,
+            "is_generating": is_generating
+        })
 
     return {
         "src_dict_path": src_lang_dict,
@@ -214,3 +214,41 @@ PaddlePaddle uses the name :code:`name` as a parameter's ID; parameters with the same name
     cmake .. -DPYTHON_EXECUTABLE=<exc_path> -DPYTHON_LIBRARY=<lib_path> -DPYTHON_INCLUDE_DIR=<inc_path>
 
 You need to specify the paths of the local Python installation: ``<exc_path>``, ``<lib_path>``, ``<inc_path>``
+
+10. A protocol message was rejected because it was too big
+----------------------------------------------------------
+
+If the following error occurs when training an NLP-related model:
+
+.. code-block:: bash
+
+    [libprotobuf ERROR google/protobuf/io/coded_stream.cc:171] A protocol message was rejected because it was too big (more than 67108864 bytes). To increase the limit (or to disable these warnings), see CodedInputStream::SetTotalBytesLimit() in google/protobuf/io/coded_stream.h.
+    F1205 14:59:50.295174 14703 TrainerConfigHelper.cpp:59] Check failed: m->conf.ParseFromString(configProtoStr)
+
+the likely cause is that one of the args passed to the dataprovider is too large, which usually happens when a large dictionary is passed in directly. An incorrect define_py_data_sources2 call looks like this:
+
+.. code-block:: python
+
+    src_dict = dict()
+    for line_count, line in enumerate(open(src_dict_path, "r")):
+        src_dict[line.strip()] = line_count
+
+    define_py_data_sources2(
+        train_list,
+        test_list,
+        module="dataprovider",
+        obj="process",
+        args={"src_dict": src_dict})
+
+The solution is to pass the dictionary's file path to the dataprovider as an arg, and then load the dictionary from that path inside the dataprovider. That is, define_py_data_sources2 should be changed to:
+
+.. code-block:: python
+
+    define_py_data_sources2(
+        train_list,
+        test_list,
+        module="dataprovider",
+        obj="process",
+        args={"src_dict_path": src_dict_path})
+
+For the complete source code, refer to the `seqToseq <https://github.com/PaddlePaddle/Paddle/tree/develop/demo/seqToseq>`_ demo.
\ No newline at end of file
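
For reference, below is a minimal sketch (not part of this commit) of the dataprovider-side pattern the FAQ entry describes: the init hook receives only the dictionary *path* through ``args`` and builds the dictionary itself, so the serialized trainer config carries a short string instead of a multi-megabyte dict. The ``@provider(init_hook=...)`` wiring and the ``PyDataProvider2`` import follow the v1 PaddlePaddle API used by this demo; the unknown-word id of 0 is a placeholder assumption.

    # Sketch of a dataprovider that loads its dictionary from a path in args,
    # mirroring the hook/process pair changed in this commit.
    from paddle.trainer.PyDataProvider2 import provider, integer_value_sequence


    def hook(settings, src_dict_path, **kwargs):
        # Build the dictionary here, inside the dataprovider process,
        # instead of receiving it pre-built through args.
        settings.src_dict = dict()
        for line_count, line in enumerate(open(src_dict_path, "r")):
            settings.src_dict[line.strip()] = line_count
        settings.slots = {
            'source_language_word':
            integer_value_sequence(len(settings.src_dict))
        }


    @provider(init_hook=hook)
    def process(settings, file_name):
        for line in open(file_name, "r"):
            # 0 stands in for the unknown-word id (the demo defines its own).
            word_ids = [settings.src_dict.get(w, 0) for w in line.strip().split()]
            yield {'source_language_word': word_ids}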