diff --git a/core/trainers/transpiler_trainer.py b/core/trainers/transpiler_trainer.py index c5c4513572bc59c2f30ee8d599743b9aa1011930..3dc8bfd6e04fb3da2b217f84126fac8101102945 100755 --- a/core/trainers/transpiler_trainer.py +++ b/core/trainers/transpiler_trainer.py @@ -42,7 +42,7 @@ class TranspileTrainer(Trainer): namespace = "train.reader" class_name = "TrainReader" else: - dataloader = self.model._infer_data_loader + readerdataloader = self.model._infer_data_loader namespace = "evaluate.reader" class_name = "EvaluateReader" @@ -58,6 +58,16 @@ class TranspileTrainer(Trainer): dataloader.set_sample_list_generator(reader) else: dataloader.set_sample_generator(reader, batch_size) + + debug_mode = envs.get_global_env("reader_debug_mode", False, namespace) + if debug_mode: + print("--- DataLoader Debug Mode Begin , show pre 10 data ---") + for idx, line in enumerate(reader()): + print(line) + if idx >= 9: + break + print("--- DataLoader Debug Mode End , show pre 10 data ---") + exit(0) return dataloader def _get_dataset(self, state="TRAIN"): @@ -98,6 +108,16 @@ class TranspileTrainer(Trainer): ] dataset.set_filelist(file_list) + + debug_mode = envs.get_global_env("reader_debug_mode", False, namespace) + if debug_mode: + print( + "--- Dataset Debug Mode Begin , show pre 10 data of {}---".format(file_list[0])) + os.system("cat {} | {} | head -10".format(file_list[0], pipe_cmd)) + print( + "--- Dataset Debug Mode End , show pre 10 data of {}---".format(file_list[0])) + exit(0) + return dataset def save(self, epoch_id, namespace, is_fleet=False): @@ -116,23 +136,28 @@ class TranspileTrainer(Trainer): if not need_save(epoch_id, save_interval, False): return - + # print("save inference model is not supported now.") # return - feed_varnames = envs.get_global_env("save.inference.feed_varnames", None, namespace) - fetch_varnames = envs.get_global_env("save.inference.fetch_varnames", None, namespace) + feed_varnames = envs.get_global_env( + "save.inference.feed_varnames", None, namespace) + fetch_varnames = envs.get_global_env( + "save.inference.fetch_varnames", None, namespace) if feed_varnames is None or fetch_varnames is None: return - fetch_vars = [fluid.default_main_program().global_block().vars[varname] for varname in fetch_varnames] - dirname = envs.get_global_env("save.inference.dirname", None, namespace) + fetch_vars = [fluid.default_main_program().global_block().vars[varname] + for varname in fetch_varnames] + dirname = envs.get_global_env( + "save.inference.dirname", None, namespace) assert dirname is not None dirname = os.path.join(dirname, str(epoch_id)) if is_fleet: - fleet.save_inference_model(self._exe, dirname, feed_varnames, fetch_vars) + fleet.save_inference_model( + self._exe, dirname, feed_varnames, fetch_vars) else: fluid.io.save_inference_model( dirname, feed_varnames, fetch_vars, self._exe) diff --git a/doc/custom_dataset_reader.md b/doc/custom_dataset_reader.md index 19e2eaae2b32d19dff06e6a5fc71c18ff4fb27a1..c6dba95100908d741437f4003119c83a072eba89 100644 --- a/doc/custom_dataset_reader.md +++ b/doc/custom_dataset_reader.md @@ -1,105 +1,235 @@ # PaddleRec 自定义数据集及Reader -## dataset数据读取 -为了能高速运行CTR模型的训练,我们使用`dataset`API进行高性能的IO,dataset是为多线程及全异步方式量身打造的数据读取方式,每个数据读取线程会与一个训练线程耦合,形成了多生产者-多消费者的模式,会极大的加速我们的模型训练。 +## 数据集及reader配置简介 -如何在我们的训练中引入dataset读取方式呢?无需变更数据格式,只需在我们的训练代码中加入以下内容,便可达到媲美二进制读取的高效率,以下是一个比较完整的流程: +以`ctr-dnn`模型举例: -### 引入dataset +```yaml +reader: + batch_size: 2 + class: "{workspace}/../criteo_reader.py" + train_data_path: "{workspace}/data/train" + reader_debug_mode: False +``` +有以上4个需要重点关注的配置选项: + +- batch_size: 网络进行小批量训练的一组数据的大小 +- class: 指定数据处理及读取的`reader` python文件 +- train_data_path: 训练数据所在地址 +- reader_debug_mode: 测试reader语法,及输出是否符合预期的debug模式的开关 + +## 自定义数据集 + +PaddleRec支持模型自定义数据集,在model.config.yaml文件中的reader部分,通过`train_data_path`指定数据读取路径。 + +关于数据的tips + +- PaddleRec 面向的是推荐与搜索领域,数据以文本格式为主 +- Dataset模式支持读取文本数据压缩后的`.gz`格式 +- Dataset模式下,训练线程与数据读取线程的关系强相关,为了多线程充分利用,`强烈建议将文件拆成多个小文件`,尤其是在分布式训练场景下,可以均衡各个节点的数据量。 + +## 自定义Reader + +数据集准备就绪后,需要适当修改或重写一个新的reader以适配数据集或新组网。 + +我们以`ctr-dnn`网络举例`reader`的正确打开方式,网络文件位于`models/rank/dnn`。 + +### Criteo数据集格式 + +CTR-DNN训练及测试数据集选用[Display Advertising Challenge](https://www.kaggle.com/c/criteo-display-ad-challenge/)所用的Criteo数据集。该数据集包括两部分:训练集和测试集。训练集包含一段时间内Criteo的部分流量,测试集则对应训练数据后一天的广告点击流量。 +每一行数据格式如下所示: +```bash +