diff --git a/models/treebased/tdm/README.md b/models/treebased/tdm/README.md index ea53f44cb974d95b8651915143e14f1c9d9b5b60..1d0ec4ba0d6acd07cc06ebaebacf12cf254237dc 100644 --- a/models/treebased/tdm/README.md +++ b/models/treebased/tdm/README.md @@ -15,6 +15,22 @@ python -m paddlerec.run -m models/treebased/tdm/config.yaml ``` 3. 建树及自定义训练的细节可以查阅[TDM-Demo建树及训练](./gen_tree/README.md) +## 文件目录 + +```shell +. +├── config.yaml # 模型配置文件 +├── data # 训练及测试数据样例文件夹 +├── gen_tree # 建树代码及文档文件夹 +├── __init__.py # Python包初始化文件 +├── model.py # tdm组网文件 +├── README.md # tdm-demo文档 +├── tdm_evaluate_reader.py # 测试数据reader +├── tdm_reader.py # 训练数据reader +├── tdm_startup.py # tdm模型启动阶段自定义逻辑 +└── tree # 树文件样例 +``` + ## 树结构的准备 ### 名词概念 diff --git a/models/treebased/tdm/config.yaml b/models/treebased/tdm/config.yaml index c1e6f6a48b55838f5cd1743e8313c61b71f8c207..71c4ffc496f89c8b6f6e29d40738547ce5f3b309 100755 --- a/models/treebased/tdm/config.yaml +++ b/models/treebased/tdm/config.yaml @@ -18,7 +18,7 @@ workspace: "models/treebased/tdm" dataset: - name: dataset_train # name of dataset to distinguish different datasets batch_size: 2 - type: QueueDataset # or QueueDataset + type: QueueDataset # or DataLoader data_path: "{workspace}/data/train" data_converter: "{workspace}/tdm_reader.py" - name: dataset_infer # name diff --git a/models/treebased/tdm/model.py b/models/treebased/tdm/model.py index 6ad9edc991576f44af8beed1c542d5f6ef1575b6..065f086f18e0dab8a33eb4c4c3ef20595c2c804d 100755 --- a/models/treebased/tdm/model.py +++ b/models/treebased/tdm/model.py @@ -134,9 +134,14 @@ class Model(ModelBase): tree_dtype='int64', dtype='int64') + sample_nodes = [ + fluid.layers.reshape(sample_nodes[i], [-1, 1]) + for i in range(self.max_layers) + ] + # 查表得到每个节点的Embedding sample_nodes_emb = [ - fluid.embedding( + fluid.layers.embedding( input=sample_nodes[i], is_sparse=True, size=[self.node_nums, self.node_emb_size],