diff --git a/demo/seqToseq/dataprovider.py b/demo/seqToseq/dataprovider.py index c5da1b7685f47fda337921c7c60ac1497b9e48bb..0614d54aaf8bb64f8a4d4a8af364fb993fd8cc7e 100755 --- a/demo/seqToseq/dataprovider.py +++ b/demo/seqToseq/dataprovider.py @@ -19,27 +19,38 @@ START = "" END = "" -def hook(settings, src_dict, trg_dict, file_list, **kwargs): +def hook(settings, src_dict_path, trg_dict_path, is_generating, file_list, + **kwargs): # job_mode = 1: training mode # job_mode = 0: generating mode - settings.job_mode = trg_dict is not None - settings.src_dict = src_dict + settings.job_mode = not is_generating + settings.src_dict = dict() + for line_count, line in enumerate(open(src_dict_path, "r")): + settings.src_dict[line.strip()] = line_count + settings.trg_dict = dict() + for line_count, line in enumerate(open(trg_dict_path, "r")): + settings.trg_dict[line.strip()] = line_count + settings.logger.info("src dict len : %d" % (len(settings.src_dict))) settings.sample_count = 0 if settings.job_mode: - settings.trg_dict = trg_dict - settings.slots = [ + settings.slots = { + 'source_language_word': integer_value_sequence(len(settings.src_dict)), + 'target_language_word': integer_value_sequence(len(settings.trg_dict)), + 'target_language_next_word': integer_value_sequence(len(settings.trg_dict)) - ] + } settings.logger.info("trg dict len : %d" % (len(settings.trg_dict))) else: - settings.slots = [ + settings.slots = { + 'source_language_word': integer_value_sequence(len(settings.src_dict)), + 'sent_id': integer_value_sequence(len(open(file_list[0], "r").readlines())) - ] + } def _get_ids(s, dictionary): @@ -69,6 +80,10 @@ def process(settings, file_name): continue trg_ids_next = trg_ids + [settings.trg_dict[END]] trg_ids = [settings.trg_dict[START]] + trg_ids - yield src_ids, trg_ids, trg_ids_next + yield { + 'source_language_word': src_ids, + 'target_language_word': trg_ids, + 'target_language_next_word': trg_ids_next + } else: - yield src_ids, [line_count] + yield 
{'source_language_word': src_ids, 'sent_id': [line_count]} diff --git a/demo/seqToseq/seqToseq_net.py b/demo/seqToseq/seqToseq_net.py index ad5e3339c1461de06732eb62aca9e8323eea707b..fc9db05ba706ee6eff6eb0ce0885a645ebd76340 100644 --- a/demo/seqToseq/seqToseq_net.py +++ b/demo/seqToseq/seqToseq_net.py @@ -37,17 +37,10 @@ def seq_to_seq_data(data_dir, """ src_lang_dict = os.path.join(data_dir, 'src.dict') trg_lang_dict = os.path.join(data_dir, 'trg.dict') - src_dict = dict() - for line_count, line in enumerate(open(src_lang_dict, "r")): - src_dict[line.strip()] = line_count - trg_dict = dict() - for line_count, line in enumerate(open(trg_lang_dict, "r")): - trg_dict[line.strip()] = line_count if is_generating: train_list = None test_list = os.path.join(data_dir, gen_list) - trg_dict = None else: train_list = os.path.join(data_dir, train_list) test_list = os.path.join(data_dir, test_list) @@ -57,8 +50,11 @@ def seq_to_seq_data(data_dir, test_list, module="dataprovider", obj="process", - args={"src_dict": src_dict, - "trg_dict": trg_dict}) + args={ + "src_dict_path": src_lang_dict, + "trg_dict_path": trg_lang_dict, + "is_generating": is_generating + }) return { "src_dict_path": src_lang_dict, diff --git a/doc_cn/faq/index.rst b/doc_cn/faq/index.rst index 73f7a0e06a7db0253318510e50a7efc17bd22aa6..f611255aaccd54f079c04dd509454bfd08af1307 100644 --- a/doc_cn/faq/index.rst +++ b/doc_cn/faq/index.rst @@ -214,3 +214,41 @@ PaddlePaddle的参数使用名字 :code:`name` 作为参数的ID,相同名字 cmake .. -DPYTHON_EXECUTABLE= -DPYTHON_LIBRARY= -DPYTHON_INCLUDE_DIR= 用户需要指定本机上Python的路径:````, ````, ```` + +10. A protocol message was rejected because it was too big +---------------------------------------------------------- + +如果在训练NLP相关模型时,出现以下错误: + +.. code-block:: bash + + [libprotobuf ERROR google/protobuf/io/coded_stream.cc:171] A protocol message was rejected because it was too big (more than 67108864 bytes). 
To increase the limit (or to disable these warnings), see CodedInputStream::SetTotalBytesLimit() in google/protobuf/io/coded_stream.h. + F1205 14:59:50.295174 14703 TrainerConfigHelper.cpp:59] Check failed: m->conf.ParseFromString(configProtoStr) + +可能的原因是:传给dataprovider的某一个args过大,一般是由于直接传递大字典导致的。错误的define_py_data_sources2类似: + +.. code-block:: python + + src_dict = dict() + for line_count, line in enumerate(open(src_dict_path, "r")): + src_dict[line.strip()] = line_count + + define_py_data_sources2( + train_list, + test_list, + module="dataprovider", + obj="process", + args={"src_dict": src_dict}) + +解决方案是:将字典的地址作为args传给dataprovider,然后在dataprovider里面根据该地址加载字典。即define_py_data_sources2应改为: + +.. code-block:: python + + define_py_data_sources2( + train_list, + test_list, + module="dataprovider", + obj="process", + args={"src_dict_path": src_dict_path}) + +完整源码可参考 `seqToseq <https://github.com/PaddlePaddle/Paddle/tree/develop/demo/seqToseq>`_ 示例。 \ No newline at end of file