提交 02c609ec 编写于 作者: M MrChengmo

fix tdm cluster & update doc

上级 8fd5aeda
......@@ -15,6 +15,22 @@ python -m paddlerec.run -m models/treebased/tdm/config.yaml
```
3. 建树及自定义训练的细节可以查阅[TDM-Demo建树及训练](./gen_tree/README.md)
## 文件目录
```shell
.
├── config.yaml # 模型配置文件
├── data # 训练及测试数据样例文件夹
├── gen_tree # 建树代码及文档文件夹
├── __init__.py # __init__.py
├── model.py # tdm组网文件
├── README.md # tdm-demo文档
├── tdm_evaluate_reader.py # 测试数据reader
├── tdm_reader.py # 训练数据reader
├── tdm_startup.py # tdm模型启动阶段自定义逻辑
└── tree # 树文件样例
```
## 树结构的准备
### 名词概念
......
......@@ -18,7 +18,7 @@ workspace: "models/treebased/tdm"
dataset:
- name: dataset_train # name of dataset to distinguish different datasets
batch_size: 2
type: QueueDataset # or QueueDataset
type: QueueDataset # or DataLoader
data_path: "{workspace}/data/train"
data_converter: "{workspace}/tdm_reader.py"
- name: dataset_infer # name
......
......@@ -134,9 +134,14 @@ class Model(ModelBase):
tree_dtype='int64',
dtype='int64')
sample_nodes = [
fluid.layers.reshape(sample_nodes[i], [-1, 1])
for i in range(self.max_layers)
]
# 查表得到每个节点的Embedding
sample_nodes_emb = [
fluid.embedding(
fluid.layers.embedding(
input=sample_nodes[i],
is_sparse=True,
size=[self.node_nums, self.node_emb_size],
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册