未验证 提交 c5382b8d 编写于 作者: Webbley's avatar Webbley 提交者: GitHub

Merge pull request #55 from Liwb5/develop

update mol task for ogb
......@@ -18,11 +18,16 @@ python setup.py install
### How to run
For example, use GPU to train model on ogbg-molhiv dataset and ogb-molpcba dataset.
```
export CUDA_VISIBLE_DEVICES=1
python -u main.py --config hiv_config.yaml
CUDA_VISIBLE_DEVICES=1 python -u main.py --config hiv_config.yaml --use_cuda
export CUDA_VISIBLE_DEVICES=2
python -u main.py --config pcba_config.yaml
CUDA_VISIBLE_DEVICES=2 python -u main.py --config pcba_config.yaml --use_cuda
```
If you want to use CPU to train model, environment variables `CPU_NUM` should be specified and should be in the range of 1 to N, where N is the total CPU number on your machine.
```
CPU_NUM=1 python -u main.py --config hiv_config.yaml
CPU_NUM=1 python -u main.py --config pcba_config.yaml
```
### Experiment results
......
......@@ -25,6 +25,7 @@ from utils.args import ArgumentGroup
# yapf: disable
parser = argparse.ArgumentParser(__doc__)
parser.add_argument('--use_cuda', action='store_true')
model_g = ArgumentGroup(parser, "model", "model configuration and paths.")
model_g.add_arg("init_checkpoint", str, None, "Init checkpoint to resume training from.")
model_g.add_arg("init_pretraining_params", str, None,
......@@ -80,7 +81,6 @@ data_g.add_arg("random_seed", int, None, "Random seed.")
data_g.add_arg("buf_size", int, 1000, "Random seed.")
run_type_g = ArgumentGroup(parser, "run_type", "running type options.")
run_type_g.add_arg("use_cuda", bool, False, "If set, use GPU for training.")
run_type_g.add_arg("num_iteration_per_drop_scope", int, 10, "Iteration intervals to drop scope.")
run_type_g.add_arg("do_train", bool, True, "Whether to perform training.")
run_type_g.add_arg("do_val", bool, True, "Whether to perform evaluation on dev data set.")
......
task_name: hiv
use_cuda: True
seed: 15391
dataset_name: ogbg-molhiv
eval_metrics: null
......
......@@ -171,8 +171,10 @@ def main(args):
if __name__ == "__main__":
args = parser.parse_args()
if args.config is not None:
args = Config(args.config, isCreate=True, isSave=True)
config = Config(args.config, isCreate=True, isSave=True)
log.info(args)
config['use_cuda'] = args.use_cuda
main(args)
log.info(config)
main(config)
task_name: pcba
use_cuda: True
seed: 28994
dataset_name: ogbg-molpcba
eval_metrics: null
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册