提交 02c609ec 编写于 作者: M MrChengmo

fix tdm cluster & update doc

上级 8fd5aeda
...@@ -15,6 +15,22 @@ python -m paddlerec.run -m models/treebased/tdm/config.yaml ...@@ -15,6 +15,22 @@ python -m paddlerec.run -m models/treebased/tdm/config.yaml
``` ```
3. 建树及自定义训练的细节可以查阅[TDM-Demo建树及训练](./gen_tree/README.md) 3. 建树及自定义训练的细节可以查阅[TDM-Demo建树及训练](./gen_tree/README.md)
## 文件目录
```shell
.
├── config.yaml # 模型配置文件
├── data # 训练及测试数据样例文件夹
├── gen_tree # 建树代码及文档文件夹
├── __init__.py # __init__.py
├── model.py # tdm组网文件
├── README.md # tdm-demo文档
├── tdm_evaluate_reader.py # 测试数据reader
├── tdm_reader.py # 训练数据reader
├── tdm_startup.py # tdm模型启动阶段自定义逻辑
└── tree # 树文件样例
```
## 树结构的准备 ## 树结构的准备
### 名词概念 ### 名词概念
......
...@@ -18,7 +18,7 @@ workspace: "models/treebased/tdm" ...@@ -18,7 +18,7 @@ workspace: "models/treebased/tdm"
dataset: dataset:
- name: dataset_train # name of dataset to distinguish different datasets - name: dataset_train # name of dataset to distinguish different datasets
batch_size: 2 batch_size: 2
type: QueueDataset # or QueueDataset type: QueueDataset # or DataLoader
data_path: "{workspace}/data/train" data_path: "{workspace}/data/train"
data_converter: "{workspace}/tdm_reader.py" data_converter: "{workspace}/tdm_reader.py"
- name: dataset_infer # name - name: dataset_infer # name
......
...@@ -134,9 +134,14 @@ class Model(ModelBase): ...@@ -134,9 +134,14 @@ class Model(ModelBase):
tree_dtype='int64', tree_dtype='int64',
dtype='int64') dtype='int64')
sample_nodes = [
fluid.layers.reshape(sample_nodes[i], [-1, 1])
for i in range(self.max_layers)
]
# 查表得到每个节点的Embedding # 查表得到每个节点的Embedding
sample_nodes_emb = [ sample_nodes_emb = [
fluid.embedding( fluid.layers.embedding(
input=sample_nodes[i], input=sample_nodes[i],
is_sparse=True, is_sparse=True,
size=[self.node_nums, self.node_emb_size], size=[self.node_nums, self.node_emb_size],
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册