From 02c609ec0af59834838c58055bbe5a90c5800096 Mon Sep 17 00:00:00 2001 From: MrChengmo Date: Fri, 25 Sep 2020 20:00:54 +0800 Subject: [PATCH] fix tdm cluster & update doc --- models/treebased/tdm/README.md | 16 ++++++++++++++++ models/treebased/tdm/config.yaml | 2 +- models/treebased/tdm/model.py | 7 ++++++- 3 files changed, 23 insertions(+), 2 deletions(-) diff --git a/models/treebased/tdm/README.md b/models/treebased/tdm/README.md index ea53f44c..1d0ec4ba 100644 --- a/models/treebased/tdm/README.md +++ b/models/treebased/tdm/README.md @@ -15,6 +15,22 @@ python -m paddlerec.run -m models/treebased/tdm/config.yaml ``` 3. 建树及自定义训练的细节可以查阅[TDM-Demo建树及训练](./gen_tree/README.md) +## 文件目录 + +```shell +. +├── config.yaml # 模型配置文件 +├── data # 训练及测试数据样例文件夹 +├── gen_tree # 建树代码及文档文件夹 +├── __init__.py # __init__.py +├── model.py # tdm组网文件 +├── README.md # tdm-demo文档 +├── tdm_evaluate_reader.py # 测试数据reader +├── tdm_reader.py # 训练数据reader +├── tdm_startup.py # tdm模型启动阶段自定义逻辑 +└── tree # 树文件样例 +``` + ## 树结构的准备 ### 名词概念 diff --git a/models/treebased/tdm/config.yaml b/models/treebased/tdm/config.yaml index c1e6f6a4..71c4ffc4 100755 --- a/models/treebased/tdm/config.yaml +++ b/models/treebased/tdm/config.yaml @@ -18,7 +18,7 @@ workspace: "models/treebased/tdm" dataset: - name: dataset_train # name of dataset to distinguish different datasets batch_size: 2 - type: QueueDataset # or QueueDataset + type: QueueDataset # or DataLoader data_path: "{workspace}/data/train" data_converter: "{workspace}/tdm_reader.py" - name: dataset_infer # name diff --git a/models/treebased/tdm/model.py b/models/treebased/tdm/model.py index 6ad9edc9..065f086f 100755 --- a/models/treebased/tdm/model.py +++ b/models/treebased/tdm/model.py @@ -134,9 +134,14 @@ class Model(ModelBase): tree_dtype='int64', dtype='int64') + sample_nodes = [ + fluid.layers.reshape(sample_nodes[i], [-1, 1]) + for i in range(self.max_layers) + ] + # 查表得到每个节点的Embedding sample_nodes_emb = [ - fluid.embedding( + fluid.layers.embedding( input=sample_nodes[i], is_sparse=True, size=[self.node_nums, self.node_emb_size], -- GitLab