Commit 98b071f8 authored by: W Webbley

update distribute_metapath2vec

Parent 7eb8941c
@@ -2,17 +2,17 @@
[metapath2vec](https://ericdongyx.github.io/papers/KDD17-dong-chawla-swami-metapath2vec.pdf) is an algorithm framework for representation learning in heterogeneous networks, which contain multiple types of nodes and links. Given a heterogeneous graph, metapath2vec first generates meta-path-based random walks and then uses the skip-gram model to learn node representations. Based on PGL, we reproduce the metapath2vec algorithm in distributed mode.
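To make the walk-generation step concrete, below is a minimal, illustrative sketch of a meta-path-guided random walk. This is not PGL's distributed implementation; ```metapath_walk``` and its ```neighbors``` callback are hypothetical names, where ```neighbors(node, edge_type)``` is assumed to return the neighbors of ```node``` along edges of type ```edge_type```.

```python
import random

def metapath_walk(start_node, meta_path, walk_len, neighbors):
    """Generate one walk guided by a meta-path.

    meta_path: list of edge types, e.g. ["c2p", "p2a", "a2p", "p2c"].
    neighbors: hypothetical callback (node, edge_type) -> list of nodes.
    """
    walk = [start_node]
    while len(walk) < walk_len:
        # Pick the edge type dictated by the current position in the meta-path.
        edge_type = meta_path[(len(walk) - 1) % len(meta_path)]
        candidates = neighbors(walk[-1], edge_type)
        if not candidates:  # dead end: no neighbor of the required type
            break
        walk.append(random.choice(candidates))
    return walk
```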
### Datasets
DBLP: The dataset contains 14376 papers (P), 20 conferences (C), 14475 authors (A), and 8920 terms (T). There are 33791 nodes in this dataset.
You can download the dataset from [here](https://github.com/librahu/HIN-Datasets-for-Recommendation-and-Network-Embedding).
We use the ```DBLP``` dataset as an example. After downloading the dataset, place the files in, say, ```./data/DBLP/```.
### Dependencies
- paddlepaddle>=1.6
- pgl>=1.0.0
### How to run
Before training, run the command below to preprocess the data.
```sh
python data_process.py --data_path ./data/DBLP --output_path ./data/data_processed
```
@@ -30,11 +30,21 @@ python multi_class.py --dataset ./data/data_processed/author_label.txt --ckpt_pa
```
### Model Selection
There are three models in this example: ```metapath2vec```, ```metapath2vec++``` and ```multi_metapath2vec++```. You can select different models by modifying ```config.yaml```.
To run the ```metapath2vec++``` model, simply set the hyperparameter **neg_sample_type** to **m2v_plus**, and the ```metapath2vec++``` model will be selected.
```multi_metapath2vec++``` means that rather than a single metapath, several metapaths are used at the same time to train the model. For example, you might want to use ```c2p-p2a-a2p-p2c``` and ```p2a-a2p``` simultaneously. To do so, set the hyperparameters below in ```config.yaml``` (a sample fragment follows the list):
- **neg_sample_type**: "m2v_plus"
- **walk_mode**: "multi_m2v"
- **meta_path**: "c2p-p2a-a2p-p2c;p2a-a2p"
- **first_node_type**: "c;p"
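For concreteness, here is a minimal sketch of how this fragment of ```config.yaml``` might look. It is an assumed layout: only the four keys listed above are shown, and all other keys keep the values from the shipped config.

```yaml
# Illustrative multi_metapath2vec++ fragment of config.yaml (assumed layout;
# only the keys discussed above are shown, everything else keeps its defaults).
neg_sample_type: "m2v_plus"
walk_mode: "multi_m2v"
meta_path: "c2p-p2a-a2p-p2c;p2a-a2p"
first_node_type: "c;p"
```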
### Hyperparameters
All the hyperparameters are saved in the ```config.yaml``` file, so before training you can open ```config.yaml``` and modify them as you like.
Some important hyperparameters in ```config.yaml```:
- **edge_path**: the directory of graph data that you want to load
- **lr**: learning rate
- **neg_num**: number of negative samples.
...
@@ -31,7 +31,7 @@ is_distributed: False
# training config
epochs: 10
optimizer: "sgd"
lr: 0.1
warm_start_from_dir: null
walkpath_files: "None"
train_files: "None"
...
@@ -87,9 +87,12 @@ class NodeGenerator(object):
idx = cc % num_n_type
n_type = n_type_list[idx]
try:
    nodes = next(node_generators[n_type])
except StopIteration as e:
    log.info("node type of %s iteration finished in one epoch" %
             (n_type))
    node_generators[n_type] = \
        self.graph.node_batch_iter(self.batch_size, n_type=n_type)
    break
yield (nodes, idx)
cc += 1
...
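The hunk above implements a round-robin scheme over per-type node iterators, restarting an iterator once it is exhausted so the next epoch can reuse it. Below is a standalone sketch of that pattern using plain Python iterators rather than PGL's API; ```round_robin_batches``` and the ```make_iter``` factory are hypothetical names.

```python
from itertools import cycle

def round_robin_batches(make_iter, node_types):
    """Yield (nodes, type_index) round-robin across node types.

    make_iter: hypothetical factory, node_type -> fresh batch iterator.
    """
    generators = {t: make_iter(t) for t in node_types}
    for idx, n_type in cycle(list(enumerate(node_types))):
        try:
            nodes = next(generators[n_type])
        except StopIteration:
            # One node type is exhausted: restart its iterator for the next
            # epoch and end this one, mirroring the `break` in the hunk above.
            generators[n_type] = make_iter(n_type)
            return
        yield nodes, idx
```

For instance, ```round_robin_batches(lambda t: iter(range(3)), ["c", "p"])``` interleaves the two types' batches until one iterator runs out, which matches the per-epoch behavior of ```NodeGenerator```.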