Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PGL
提交
49bed8cf
P
PGL
项目概览
PaddlePaddle
/
PGL
通知
76
Star
4
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
11
列表
看板
标记
里程碑
合并请求
1
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PGL
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
11
Issue
11
列表
看板
标记
里程碑
合并请求
1
合并请求
1
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
49bed8cf
编写于
4月 23, 2020
作者:
W
Webbley
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
add ogb graph classification task
上级
b0a20434
变更
17
隐藏空白更改
内联
并排
Showing
17 changed file
with
1883 addition
and
2 deletion
+1883
-2
ogb_examples/graphproppred/mol/README.md
ogb_examples/graphproppred/mol/README.md
+6
-2
ogb_examples/graphproppred/mol/args.py
ogb_examples/graphproppred/mol/args.py
+104
-0
ogb_examples/graphproppred/mol/data/__init__.py
ogb_examples/graphproppred/mol/data/__init__.py
+13
-0
ogb_examples/graphproppred/mol/data/base_dataset.py
ogb_examples/graphproppred/mol/data/base_dataset.py
+83
-0
ogb_examples/graphproppred/mol/data/dataloader.py
ogb_examples/graphproppred/mol/data/dataloader.py
+183
-0
ogb_examples/graphproppred/mol/data/splitters.py
ogb_examples/graphproppred/mol/data/splitters.py
+153
-0
ogb_examples/graphproppred/mol/main.py
ogb_examples/graphproppred/mol/main.py
+178
-0
ogb_examples/graphproppred/mol/model.py
ogb_examples/graphproppred/mol/model.py
+210
-0
ogb_examples/graphproppred/mol/mol_encoder.py
ogb_examples/graphproppred/mol/mol_encoder.py
+71
-0
ogb_examples/graphproppred/mol/monitor/train_monitor.py
ogb_examples/graphproppred/mol/monitor/train_monitor.py
+154
-0
ogb_examples/graphproppred/mol/optimization.py
ogb_examples/graphproppred/mol/optimization.py
+163
-0
ogb_examples/graphproppred/mol/utils/__init__.py
ogb_examples/graphproppred/mol/utils/__init__.py
+13
-0
ogb_examples/graphproppred/mol/utils/args.py
ogb_examples/graphproppred/mol/utils/args.py
+94
-0
ogb_examples/graphproppred/mol/utils/cards.py
ogb_examples/graphproppred/mol/utils/cards.py
+30
-0
ogb_examples/graphproppred/mol/utils/config.py
ogb_examples/graphproppred/mol/utils/config.py
+136
-0
ogb_examples/graphproppred/mol/utils/fp16.py
ogb_examples/graphproppred/mol/utils/fp16.py
+201
-0
ogb_examples/graphproppred/mol/utils/init.py
ogb_examples/graphproppred/mol/utils/init.py
+91
-0
未找到文件。
ogb_examples/graphproppred/README.md
→
ogb_examples/graphproppred/
mol/
README.md
浏览文件 @
49bed8cf
...
...
@@ -16,7 +16,11 @@ python setup.py install
```
### How to run
For example, use GPU to train model on ogbg-molhiv dataset.
For example, use GPU to train model on ogbg-molhiv dataset
and ogb-molpcba dataset
.
```
python main_pgl.py --use_cuda --dataset ogbg-molhiv
export CUDA_VISIBLE_DEVICES=1
python -u main.py --config hiv_config.yaml
export CUDA_VISIBLE_DEVICES=2
python -u main.py --config pcba_config.yaml
```
ogb_examples/graphproppred/mol/args.py
0 → 100644
浏览文件 @
49bed8cf
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
__future__
import
absolute_import
from
__future__
import
division
from
__future__
import
print_function
from
__future__
import
unicode_literals
from
__future__
import
absolute_import
import
os
import
time
import
argparse
from
utils.args
import
ArgumentGroup
# yapf: disable
parser
=
argparse
.
ArgumentParser
(
__doc__
)
model_g
=
ArgumentGroup
(
parser
,
"model"
,
"model configuration and paths."
)
model_g
.
add_arg
(
"init_checkpoint"
,
str
,
None
,
"Init checkpoint to resume training from."
)
model_g
.
add_arg
(
"init_pretraining_params"
,
str
,
None
,
"Init pre-training params which preforms fine-tuning from. If the "
"arg 'init_checkpoint' has been set, this argument wouldn't be valid."
)
model_g
.
add_arg
(
"./save_dir"
,
str
,
"./checkpoints"
,
"Path to save checkpoints."
)
model_g
.
add_arg
(
"hidden_size"
,
int
,
128
,
"hidden size."
)
train_g
=
ArgumentGroup
(
parser
,
"training"
,
"training options."
)
train_g
.
add_arg
(
"epoch"
,
int
,
3
,
"Number of epoches for fine-tuning."
)
train_g
.
add_arg
(
"learning_rate"
,
float
,
5e-5
,
"Learning rate used to train with warmup."
)
train_g
.
add_arg
(
"lr_scheduler"
,
str
,
"linear_warmup_decay"
,
"scheduler of learning rate."
,
choices
=
[
'linear_warmup_decay'
,
'noam_decay'
])
train_g
.
add_arg
(
"weight_decay"
,
float
,
0.01
,
"Weight decay rate for L2 regularizer."
)
train_g
.
add_arg
(
"warmup_proportion"
,
float
,
0.1
,
"Proportion of training steps to perform linear learning rate warmup for."
)
train_g
.
add_arg
(
"save_steps"
,
int
,
10000
,
"The steps interval to save checkpoints."
)
train_g
.
add_arg
(
"validation_steps"
,
int
,
1000
,
"The steps interval to evaluate model performance."
)
train_g
.
add_arg
(
"use_dynamic_loss_scaling"
,
bool
,
True
,
"Whether to use dynamic loss scaling."
)
train_g
.
add_arg
(
"init_loss_scaling"
,
float
,
102400
,
"Loss scaling factor for mixed precision training, only valid when use_fp16 is enabled."
)
train_g
.
add_arg
(
"test_save"
,
str
,
"./checkpoints/test_result"
,
"test_save"
)
train_g
.
add_arg
(
"metric"
,
str
,
"simple_accuracy"
,
"metric"
)
train_g
.
add_arg
(
"incr_every_n_steps"
,
int
,
100
,
"Increases loss scaling every n consecutive."
)
train_g
.
add_arg
(
"decr_every_n_nan_or_inf"
,
int
,
2
,
"Decreases loss scaling every n accumulated steps with nan or inf gradients."
)
train_g
.
add_arg
(
"incr_ratio"
,
float
,
2.0
,
"The multiplier to use when increasing the loss scaling."
)
train_g
.
add_arg
(
"decr_ratio"
,
float
,
0.8
,
"The less-than-one-multiplier to use when decreasing."
)
log_g
=
ArgumentGroup
(
parser
,
"logging"
,
"logging related."
)
log_g
.
add_arg
(
"skip_steps"
,
int
,
10
,
"The steps interval to print loss."
)
log_g
.
add_arg
(
"verbose"
,
bool
,
False
,
"Whether to output verbose log."
)
log_g
.
add_arg
(
"log_dir"
,
str
,
'./logs/'
,
"Whether to output verbose log."
)
data_g
=
ArgumentGroup
(
parser
,
"data"
,
"Data paths, vocab paths and data processing options"
)
data_g
.
add_arg
(
"tokenizer"
,
str
,
"FullTokenizer"
,
"ATTENTION: the INPUT must be splited by Word with blank while using SentencepieceTokenizer or WordsegTokenizer"
)
data_g
.
add_arg
(
"train_set"
,
str
,
None
,
"Path to training data."
)
data_g
.
add_arg
(
"test_set"
,
str
,
None
,
"Path to test data."
)
data_g
.
add_arg
(
"dev_set"
,
str
,
None
,
"Path to validation data."
)
data_g
.
add_arg
(
"aug1_type"
,
str
,
"scheme1"
,
"augment type"
)
data_g
.
add_arg
(
"aug2_type"
,
str
,
"scheme1"
,
"augment type"
)
data_g
.
add_arg
(
"batch_size"
,
int
,
32
,
"Total examples' number in batch for training. see also --in_tokens."
)
data_g
.
add_arg
(
"predict_batch_size"
,
int
,
None
,
"Total examples' number in batch for predict. see also --in_tokens."
)
data_g
.
add_arg
(
"random_seed"
,
int
,
None
,
"Random seed."
)
data_g
.
add_arg
(
"buf_size"
,
int
,
1000
,
"Random seed."
)
run_type_g
=
ArgumentGroup
(
parser
,
"run_type"
,
"running type options."
)
run_type_g
.
add_arg
(
"use_cuda"
,
bool
,
False
,
"If set, use GPU for training."
)
run_type_g
.
add_arg
(
"num_iteration_per_drop_scope"
,
int
,
10
,
"Iteration intervals to drop scope."
)
run_type_g
.
add_arg
(
"do_train"
,
bool
,
True
,
"Whether to perform training."
)
run_type_g
.
add_arg
(
"do_val"
,
bool
,
True
,
"Whether to perform evaluation on dev data set."
)
run_type_g
.
add_arg
(
"do_test"
,
bool
,
True
,
"Whether to perform evaluation on test data set."
)
run_type_g
.
add_arg
(
"metrics"
,
bool
,
True
,
"Whether to perform evaluation on test data set."
)
run_type_g
.
add_arg
(
"shuffle"
,
bool
,
True
,
""
)
run_type_g
.
add_arg
(
"for_cn"
,
bool
,
True
,
"model train for cn or for other langs."
)
run_type_g
.
add_arg
(
"num_workers"
,
int
,
1
,
"use multiprocess to generate graph"
)
run_type_g
.
add_arg
(
"output_dir"
,
str
,
None
,
"path to save model"
)
run_type_g
.
add_arg
(
"config"
,
str
,
None
,
"configure yaml file"
)
run_type_g
.
add_arg
(
"n"
,
str
,
None
,
"task name"
)
run_type_g
.
add_arg
(
"task_name"
,
str
,
None
,
"task name"
)
run_type_g
.
add_arg
(
"pretrain"
,
bool
,
False
,
"Whether do pretrian"
)
run_type_g
.
add_arg
(
"pretrain_name"
,
str
,
None
,
"pretrain task name"
)
run_type_g
.
add_arg
(
"pretrain_config"
,
str
,
None
,
"pretrain config.yaml file"
)
run_type_g
.
add_arg
(
"pretrain_model_step"
,
str
,
None
,
"pretrain model step"
)
run_type_g
.
add_arg
(
"model_type"
,
str
,
"BaseLineModel"
,
"pretrain model step"
)
run_type_g
.
add_arg
(
"num_class"
,
int
,
1
,
"number class"
)
run_type_g
.
add_arg
(
"dataset_name"
,
str
,
None
,
"finetune dataset name"
)
run_type_g
.
add_arg
(
"eval_metrics"
,
str
,
None
,
"evaluate metrics"
)
run_type_g
.
add_arg
(
"task_type"
,
str
,
None
,
"regression or classification"
)
ogb_examples/graphproppred/mol/data/__init__.py
0 → 100644
浏览文件 @
49bed8cf
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
ogb_examples/graphproppred/mol/data/base_dataset.py
0 → 100644
浏览文件 @
49bed8cf
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
sys
import
os
from
ogb.graphproppred
import
GraphPropPredDataset
import
pgl
from
pgl.utils.logger
import
log
class
BaseDataset
(
object
):
def
__init__
(
self
):
pass
def
__getitem__
(
self
,
idx
):
raise
NotImplementedError
def
__len__
(
self
):
raise
NotImplementedError
class
Subset
(
BaseDataset
):
r
"""
Subset of a dataset at specified indices.
Arguments:
dataset (Dataset): The whole Dataset
indices (sequence): Indices in the whole set selected for subset
"""
def
__init__
(
self
,
dataset
,
indices
):
self
.
dataset
=
dataset
self
.
indices
=
indices
def
__getitem__
(
self
,
idx
):
return
self
.
dataset
[
self
.
indices
[
idx
]]
def
__len__
(
self
):
return
len
(
self
.
indices
)
class
Dataset
(
BaseDataset
):
def
__init__
(
self
,
args
):
self
.
args
=
args
self
.
raw_dataset
=
GraphPropPredDataset
(
name
=
args
.
dataset_name
)
self
.
num_tasks
=
self
.
raw_dataset
.
num_tasks
self
.
eval_metrics
=
self
.
raw_dataset
.
eval_metric
self
.
task_type
=
self
.
raw_dataset
.
task_type
self
.
pgl_graph_list
=
[]
self
.
graph_label_list
=
[]
for
i
in
range
(
len
(
self
.
raw_dataset
)):
graph
,
label
=
self
.
raw_dataset
[
i
]
edges
=
list
(
zip
(
graph
[
"edge_index"
][
0
],
graph
[
"edge_index"
][
1
]))
g
=
pgl
.
graph
.
Graph
(
num_nodes
=
graph
[
"num_nodes"
],
edges
=
edges
)
if
graph
[
"edge_feat"
]
is
not
None
:
g
.
edge_feat
[
"feat"
]
=
graph
[
"edge_feat"
]
if
graph
[
"node_feat"
]
is
not
None
:
g
.
node_feat
[
"feat"
]
=
graph
[
"node_feat"
]
self
.
pgl_graph_list
.
append
(
g
)
self
.
graph_label_list
.
append
(
label
)
def
__getitem__
(
self
,
idx
):
return
self
.
pgl_graph_list
[
idx
],
self
.
graph_label_list
[
idx
]
def
__len__
(
self
):
return
len
(
slef
.
pgl_graph_list
)
def
get_idx_split
(
self
):
return
self
.
raw_dataset
.
get_idx_split
()
ogb_examples/graphproppred/mol/data/dataloader.py
0 → 100644
浏览文件 @
49bed8cf
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This file implement the graph dataloader.
"""
from
__future__
import
absolute_import
from
__future__
import
division
from
__future__
import
print_function
from
__future__
import
unicode_literals
from
__future__
import
absolute_import
import
ssl
ssl
.
_create_default_https_context
=
ssl
.
_create_unverified_context
# SSL
import
torch
import
sys
import
six
from
io
import
open
import
collections
from
collections
import
namedtuple
import
numpy
as
np
import
tqdm
import
time
import
paddle
import
paddle.fluid
as
fluid
import
paddle.fluid.layers
as
fl
import
pgl
from
pgl.utils
import
mp_reader
from
pgl.utils.logger
import
log
from
ogb.graphproppred
import
GraphPropPredDataset
def
batch_iter
(
data
,
batch_size
,
fid
,
num_workers
):
"""node_batch_iter
"""
size
=
len
(
data
)
perm
=
np
.
arange
(
size
)
np
.
random
.
shuffle
(
perm
)
start
=
0
cc
=
0
while
start
<
size
:
index
=
perm
[
start
:
start
+
batch_size
]
start
+=
batch_size
cc
+=
1
if
cc
%
num_workers
!=
fid
:
continue
yield
data
[
index
]
def
scan_batch_iter
(
data
,
batch_size
,
fid
,
num_workers
):
"""scan_batch_iter
"""
batch
=
[]
cc
=
0
for
line_example
in
data
.
scan
():
cc
+=
1
if
cc
%
num_workers
!=
fid
:
continue
batch
.
append
(
line_example
)
if
len
(
batch
)
==
batch_size
:
yield
batch
batch
=
[]
if
len
(
batch
)
>
0
:
yield
batch
class
GraphDataloader
(
object
):
"""Graph Dataloader
"""
def
__init__
(
self
,
dataset
,
graph_wrapper
,
batch_size
,
seed
=
0
,
num_workers
=
1
,
buf_size
=
1000
,
shuffle
=
True
):
self
.
shuffle
=
shuffle
self
.
seed
=
seed
self
.
num_workers
=
num_workers
self
.
buf_size
=
buf_size
self
.
batch_size
=
batch_size
self
.
dataset
=
dataset
self
.
graph_wrapper
=
graph_wrapper
def
batch_fn
(
self
,
batch_examples
):
""" batch_fn batch producer"""
graphs
=
[
b
[
0
]
for
b
in
batch_examples
]
labels
=
[
b
[
1
]
for
b
in
batch_examples
]
join_graph
=
pgl
.
graph
.
MultiGraph
(
graphs
)
labels
=
np
.
array
(
labels
)
feed_dict
=
self
.
graph_wrapper
.
to_feed
(
join_graph
)
batch_valid
=
(
labels
==
labels
).
astype
(
"float32"
)
labels
=
np
.
nan_to_num
(
labels
).
astype
(
"float32"
)
feed_dict
[
'labels'
]
=
labels
feed_dict
[
'unmask'
]
=
batch_valid
return
feed_dict
def
batch_iter
(
self
,
fid
):
"""batch_iter"""
if
self
.
shuffle
:
for
batch
in
batch_iter
(
self
,
self
.
batch_size
,
fid
,
self
.
num_workers
):
yield
batch
else
:
for
batch
in
scan_batch_iter
(
self
,
self
.
batch_size
,
fid
,
self
.
num_workers
):
yield
batch
def
__len__
(
self
):
"""__len__"""
return
len
(
self
.
dataset
)
def
__getitem__
(
self
,
idx
):
"""__getitem__"""
if
isinstance
(
idx
,
collections
.
Iterable
):
return
[
self
[
bidx
]
for
bidx
in
idx
]
else
:
return
self
.
dataset
[
idx
]
def
__iter__
(
self
):
"""__iter__"""
def
worker
(
filter_id
):
def
func_run
():
for
batch_examples
in
self
.
batch_iter
(
filter_id
):
batch_dict
=
self
.
batch_fn
(
batch_examples
)
yield
batch_dict
return
func_run
if
self
.
num_workers
==
1
:
r
=
paddle
.
reader
.
buffered
(
worker
(
0
),
self
.
buf_size
)
else
:
worker_pool
=
[
worker
(
wid
)
for
wid
in
range
(
self
.
num_workers
)]
worker
=
mp_reader
.
multiprocess_reader
(
worker_pool
,
use_pipe
=
True
,
queue_size
=
1000
)
r
=
paddle
.
reader
.
buffered
(
worker
,
self
.
buf_size
)
for
batch
in
r
():
yield
batch
def
scan
(
self
):
"""scan"""
for
example
in
self
.
dataset
:
yield
example
if
__name__
==
"__main__"
:
from
base_dataset
import
BaseDataset
,
Subset
dataset
=
GraphPropPredDataset
(
name
=
"ogbg-molhiv"
)
splitted_index
=
dataset
.
get_idx_split
()
train_dataset
=
Subset
(
dataset
,
splitted_index
[
'train'
])
valid_dataset
=
Subset
(
dataset
,
splitted_index
[
'valid'
])
test_dataset
=
Subset
(
dataset
,
splitted_index
[
'test'
])
log
.
info
(
"Train Examples: %s"
%
len
(
train_dataset
))
log
.
info
(
"Val Examples: %s"
%
len
(
valid_dataset
))
log
.
info
(
"Test Examples: %s"
%
len
(
test_dataset
))
# train_loader = GraphDataloader(train_dataset, batch_size=3)
# for batch_data in train_loader:
# graphs, labels = batch_data
# print(labels.shape)
# time.sleep(4)
ogb_examples/graphproppred/mol/data/splitters.py
0 → 100644
浏览文件 @
49bed8cf
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
sys
import
os
import
logging
from
random
import
random
import
pandas
as
pd
import
numpy
as
np
from
itertools
import
compress
import
scipy.sparse
as
sp
from
sklearn.model_selection
import
StratifiedKFold
from
sklearn.preprocessing
import
StandardScaler
from
rdkit.Chem.Scaffolds
import
MurckoScaffold
import
pgl
from
pgl.utils
import
paddle_helper
try
:
from
dataset.Dataset
import
Subset
from
dataset.Dataset
import
ChemDataset
except
:
from
Dataset
import
Subset
from
Dataset
import
ChemDataset
log
=
logging
.
getLogger
(
"logger"
)
def
random_split
(
dataset
,
args
):
total_precent
=
args
.
frac_train
+
args
.
frac_valid
+
args
.
frac_test
np
.
testing
.
assert_almost_equal
(
total_precent
,
1.0
)
length
=
len
(
dataset
)
perm
=
list
(
range
(
length
))
np
.
random
.
shuffle
(
perm
)
num_train
=
int
(
args
.
frac_train
*
length
)
num_valid
=
int
(
args
.
frac_valid
*
length
)
num_test
=
int
(
args
.
frac_test
*
length
)
train_indices
=
perm
[
0
:
num_train
]
valid_indices
=
perm
[
num_train
:(
num_train
+
num_valid
)]
test_indices
=
perm
[(
num_train
+
num_valid
):]
assert
(
len
(
train_indices
)
+
len
(
valid_indices
)
+
len
(
test_indices
)
)
==
length
train_dataset
=
Subset
(
dataset
,
train_indices
)
valid_dataset
=
Subset
(
dataset
,
valid_indices
)
test_dataset
=
Subset
(
dataset
,
test_indices
)
return
train_dataset
,
valid_dataset
,
test_dataset
def
scaffold_split
(
dataset
,
args
,
return_smiles
=
False
):
total_precent
=
args
.
frac_train
+
args
.
frac_valid
+
args
.
frac_test
np
.
testing
.
assert_almost_equal
(
total_precent
,
1.0
)
smiles_list_file
=
os
.
path
.
join
(
args
.
data_dir
,
"smiles.csv"
)
smiles_list
=
pd
.
read_csv
(
smiles_list_file
,
header
=
None
)[
0
].
tolist
()
non_null
=
np
.
ones
(
len
(
dataset
))
==
1
smiles_list
=
list
(
compress
(
enumerate
(
smiles_list
),
non_null
))
# create dict of the form {scaffold_i: [idx1, idx....]}
all_scaffolds
=
{}
for
i
,
smiles
in
smiles_list
:
scaffold
=
MurckoScaffold
.
MurckoScaffoldSmiles
(
smiles
=
smiles
,
includeChirality
=
True
)
# scaffold = generate_scaffold(smiles, include_chirality=True)
if
scaffold
not
in
all_scaffolds
:
all_scaffolds
[
scaffold
]
=
[
i
]
else
:
all_scaffolds
[
scaffold
].
append
(
i
)
# sort from largest to smallest sets
all_scaffolds
=
{
key
:
sorted
(
value
)
for
key
,
value
in
all_scaffolds
.
items
()
}
all_scaffold_sets
=
[
scaffold_set
for
(
scaffold
,
scaffold_set
)
in
sorted
(
all_scaffolds
.
items
(),
key
=
lambda
x
:
(
len
(
x
[
1
]),
x
[
1
][
0
]),
reverse
=
True
)
]
# get train, valid test indices
train_cutoff
=
args
.
frac_train
*
len
(
smiles_list
)
valid_cutoff
=
(
args
.
frac_train
+
args
.
frac_valid
)
*
len
(
smiles_list
)
train_idx
,
valid_idx
,
test_idx
=
[],
[],
[]
for
scaffold_set
in
all_scaffold_sets
:
if
len
(
train_idx
)
+
len
(
scaffold_set
)
>
train_cutoff
:
if
len
(
train_idx
)
+
len
(
valid_idx
)
+
len
(
scaffold_set
)
>
valid_cutoff
:
test_idx
.
extend
(
scaffold_set
)
else
:
valid_idx
.
extend
(
scaffold_set
)
else
:
train_idx
.
extend
(
scaffold_set
)
assert
len
(
set
(
train_idx
).
intersection
(
set
(
valid_idx
)))
==
0
assert
len
(
set
(
test_idx
).
intersection
(
set
(
valid_idx
)))
==
0
# log.info(len(scaffold_set))
# log.info(["train_idx", train_idx])
# log.info(["valid_idx", valid_idx])
# log.info(["test_idx", test_idx])
train_dataset
=
Subset
(
dataset
,
train_idx
)
valid_dataset
=
Subset
(
dataset
,
valid_idx
)
test_dataset
=
Subset
(
dataset
,
test_idx
)
if
return_smiles
:
train_smiles
=
[
smiles_list
[
i
][
1
]
for
i
in
train_idx
]
valid_smiles
=
[
smiles_list
[
i
][
1
]
for
i
in
valid_idx
]
test_smiles
=
[
smiles_list
[
i
][
1
]
for
i
in
test_idx
]
return
train_dataset
,
valid_dataset
,
test_dataset
,
(
train_smiles
,
valid_smiles
,
test_smiles
)
return
train_dataset
,
valid_dataset
,
test_dataset
if
__name__
==
"__main__"
:
file_path
=
os
.
path
.
dirname
(
os
.
path
.
realpath
(
__file__
))
proj_path
=
os
.
path
.
join
(
file_path
,
'../'
)
sys
.
path
.
append
(
proj_path
)
from
utils.config
import
Config
from
dataset.Dataset
import
Subset
from
dataset.Dataset
import
ChemDataset
config_file
=
"./finetune_config.yaml"
args
=
Config
(
config_file
)
log
.
info
(
"loading dataset"
)
dataset
=
ChemDataset
(
args
)
train_dataset
,
valid_dataset
,
test_dataset
=
scaffold_split
(
dataset
,
args
)
log
.
info
(
"Train Examples: %s"
%
len
(
train_dataset
))
log
.
info
(
"Val Examples: %s"
%
len
(
valid_dataset
))
log
.
info
(
"Test Examples: %s"
%
len
(
test_dataset
))
import
ipdb
ipdb
.
set_trace
()
log
.
info
(
"preprocess finish"
)
ogb_examples/graphproppred/mol/main.py
0 → 100644
浏览文件 @
49bed8cf
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
ssl
ssl
.
_create_default_https_context
=
ssl
.
_create_unverified_context
# SSL
import
torch
import
os
import
re
import
time
from
random
import
random
from
functools
import
reduce
,
partial
import
numpy
as
np
import
multiprocessing
from
ogb.graphproppred
import
Evaluator
import
paddle
import
paddle.fluid
as
F
import
paddle.fluid.layers
as
L
import
pgl
from
pgl.utils
import
paddle_helper
from
pgl.utils.logger
import
log
from
utils.args
import
print_arguments
,
check_cuda
,
prepare_logger
from
utils.init
import
init_checkpoint
,
init_pretraining_params
from
utils.config
import
Config
from
optimization
import
optimization
from
monitor.train_monitor
import
train_and_evaluate
from
args
import
parser
import
model
as
Model
from
data.base_dataset
import
Subset
,
Dataset
from
data.dataloader
import
GraphDataloader
def
main
(
args
):
log
.
info
(
'loading data'
)
dataset
=
Dataset
(
args
)
args
.
num_class
=
dataset
.
num_tasks
args
.
eval_metrics
=
dataset
.
eval_metrics
args
.
task_type
=
dataset
.
task_type
splitted_index
=
dataset
.
get_idx_split
()
train_dataset
=
Subset
(
dataset
,
splitted_index
[
'train'
])
valid_dataset
=
Subset
(
dataset
,
splitted_index
[
'valid'
])
test_dataset
=
Subset
(
dataset
,
splitted_index
[
'test'
])
log
.
info
(
"preprocess finish"
)
log
.
info
(
"Train Examples: %s"
%
len
(
train_dataset
))
log
.
info
(
"Val Examples: %s"
%
len
(
valid_dataset
))
log
.
info
(
"Test Examples: %s"
%
len
(
test_dataset
))
train_prog
=
F
.
Program
()
startup_prog
=
F
.
Program
()
if
args
.
use_cuda
:
dev_list
=
F
.
cuda_places
()
place
=
dev_list
[
0
]
dev_count
=
len
(
dev_list
)
else
:
place
=
F
.
CPUPlace
()
dev_count
=
int
(
os
.
environ
.
get
(
'CPU_NUM'
,
multiprocessing
.
cpu_count
()))
# dev_count = args.cpu_num
log
.
info
(
"building model"
)
with
F
.
program_guard
(
train_prog
,
startup_prog
):
with
F
.
unique_name
.
guard
():
graph_model
=
getattr
(
Model
,
args
.
model_type
)(
args
,
dataset
)
train_ds
=
GraphDataloader
(
train_dataset
,
graph_model
.
graph_wrapper
,
batch_size
=
args
.
batch_size
)
num_train_examples
=
len
(
train_dataset
)
max_train_steps
=
args
.
epoch
*
num_train_examples
//
args
.
batch_size
//
dev_count
warmup_steps
=
int
(
max_train_steps
*
args
.
warmup_proportion
)
scheduled_lr
,
loss_scaling
=
optimization
(
loss
=
graph_model
.
loss
,
warmup_steps
=
warmup_steps
,
num_train_steps
=
max_train_steps
,
learning_rate
=
args
.
learning_rate
,
train_program
=
train_prog
,
startup_prog
=
startup_prog
,
weight_decay
=
args
.
weight_decay
,
scheduler
=
args
.
lr_scheduler
,
use_fp16
=
False
,
use_dynamic_loss_scaling
=
args
.
use_dynamic_loss_scaling
,
init_loss_scaling
=
args
.
init_loss_scaling
,
incr_every_n_steps
=
args
.
incr_every_n_steps
,
decr_every_n_nan_or_inf
=
args
.
decr_every_n_nan_or_inf
,
incr_ratio
=
args
.
incr_ratio
,
decr_ratio
=
args
.
decr_ratio
)
test_prog
=
F
.
Program
()
with
F
.
program_guard
(
test_prog
,
startup_prog
):
with
F
.
unique_name
.
guard
():
_graph_model
=
getattr
(
Model
,
args
.
model_type
)(
args
,
dataset
)
test_prog
=
test_prog
.
clone
(
for_test
=
True
)
valid_ds
=
GraphDataloader
(
valid_dataset
,
graph_model
.
graph_wrapper
,
batch_size
=
args
.
batch_size
,
shuffle
=
False
)
test_ds
=
GraphDataloader
(
test_dataset
,
graph_model
.
graph_wrapper
,
batch_size
=
args
.
batch_size
,
shuffle
=
False
)
exe
=
F
.
Executor
(
place
)
exe
.
run
(
startup_prog
)
for
init
in
graph_model
.
init_vars
:
init
(
place
)
for
init
in
_graph_model
.
init_vars
:
init
(
place
)
if
args
.
init_pretraining_params
is
not
None
:
init_pretraining_params
(
exe
,
args
.
init_pretraining_params
,
main_program
=
startup_prog
)
nccl2_num_trainers
=
1
nccl2_trainer_id
=
0
if
dev_count
>
1
:
exec_strategy
=
F
.
ExecutionStrategy
()
exec_strategy
.
num_threads
=
dev_count
train_exe
=
F
.
ParallelExecutor
(
use_cuda
=
args
.
use_cuda
,
loss_name
=
graph_model
.
loss
.
name
,
exec_strategy
=
exec_strategy
,
main_program
=
train_prog
,
num_trainers
=
nccl2_num_trainers
,
trainer_id
=
nccl2_trainer_id
)
test_exe
=
exe
else
:
train_exe
,
test_exe
=
exe
,
exe
evaluator
=
Evaluator
(
args
.
dataset_name
)
train_and_evaluate
(
exe
=
exe
,
train_exe
=
train_exe
,
valid_exe
=
test_exe
,
train_ds
=
train_ds
,
valid_ds
=
valid_ds
,
test_ds
=
test_ds
,
train_prog
=
train_prog
,
valid_prog
=
test_prog
,
args
=
args
,
dev_count
=
dev_count
,
evaluator
=
evaluator
,
model
=
graph_model
)
if
__name__
==
"__main__"
:
args
=
parser
.
parse_args
()
if
args
.
config
is
not
None
:
args
=
Config
(
args
.
config
,
isCreate
=
True
,
isSave
=
True
)
log
.
info
(
args
)
main
(
args
)
ogb_examples/graphproppred/mol/model.py
0 → 100644
浏览文件 @
49bed8cf
#-*- coding: utf-8 -*-
import
os
import
re
import
time
import
logging
from
random
import
random
from
functools
import
reduce
,
partial
import
numpy
as
np
import
multiprocessing
import
paddle
import
paddle.fluid
as
F
import
paddle.fluid.layers
as
L
import
pgl
from
pgl.graph_wrapper
import
GraphWrapper
from
pgl.layers.conv
import
gcn
,
gat
from
pgl.utils
import
paddle_helper
from
pgl.utils.logger
import
log
from
utils.args
import
print_arguments
,
check_cuda
,
prepare_logger
from
utils.init
import
init_checkpoint
,
init_pretraining_params
from
mol_encoder
import
AtomEncoder
,
BondEncoder
def
copy_send
(
src_feat
,
dst_feat
,
edge_feat
):
return
src_feat
[
"h"
]
def
mean_recv
(
feat
):
return
L
.
sequence_pool
(
feat
,
pool_type
=
"average"
)
def
sum_recv
(
feat
):
return
L
.
sequence_pool
(
feat
,
pool_type
=
"sum"
)
def
max_recv
(
feat
):
return
L
.
sequence_pool
(
feat
,
pool_type
=
"max"
)
def
unsqueeze
(
tensor
):
tensor
=
L
.
unsqueeze
(
tensor
,
axes
=-
1
)
tensor
.
stop_gradient
=
True
return
tensor
class
Metric
:
def
__init__
(
self
,
**
args
):
self
.
args
=
args
@
property
def
vars
(
self
):
values
=
[
self
.
args
[
k
]
for
k
in
self
.
args
.
keys
()]
return
values
def
parse
(
self
,
fetch_list
):
tup
=
list
(
zip
(
self
.
args
.
keys
(),
[
float
(
v
[
0
])
for
v
in
fetch_list
]))
return
dict
(
tup
)
def
gin_layer
(
gw
,
node_features
,
edge_features
,
train_eps
,
name
):
def
send_func
(
src_feat
,
dst_feat
,
edge_feat
):
"""Send"""
return
src_feat
[
"h"
]
+
edge_feat
[
"h"
]
epsilon
=
L
.
create_parameter
(
shape
=
[
1
,
1
],
dtype
=
"float32"
,
attr
=
F
.
ParamAttr
(
name
=
"%s_eps"
%
name
),
default_initializer
=
F
.
initializer
.
ConstantInitializer
(
value
=
0.0
))
if
not
train_eps
:
epsilon
.
stop_gradient
=
True
msg
=
gw
.
send
(
send_func
,
nfeat_list
=
[(
"h"
,
node_features
)],
efeat_list
=
[(
"h"
,
edge_features
)])
node_feat
=
gw
.
recv
(
msg
,
"sum"
)
+
node_features
*
(
epsilon
+
1.0
)
# if apply_func is not None:
# node_feat = apply_func(node_feat, name)
return
node_feat
class
GNNModel
(
object
):
def
__init__
(
self
,
args
,
dataset
):
self
.
args
=
args
self
.
dataset
=
dataset
self
.
hidden_size
=
self
.
args
.
hidden_size
self
.
embed_dim
=
self
.
args
.
embed_dim
self
.
dropout_prob
=
self
.
args
.
dropout_rate
self
.
pool_type
=
self
.
args
.
pool_type
self
.
_init_vars
=
[]
graph_data
=
[]
g
,
label
=
self
.
dataset
[
0
]
graph_data
.
append
(
g
)
g
,
label
=
self
.
dataset
[
1
]
graph_data
.
append
(
g
)
batch_graph
=
pgl
.
graph
.
MultiGraph
(
graph_data
)
graph_data
=
batch_graph
graph_data
.
edge_feat
[
"feat"
]
=
graph_data
.
edge_feat
[
"feat"
].
astype
(
"int64"
)
graph_data
.
node_feat
[
"feat"
]
=
graph_data
.
node_feat
[
"feat"
].
astype
(
"int64"
)
self
.
graph_wrapper
=
GraphWrapper
(
name
=
"graph"
,
place
=
F
.
CPUPlace
(),
node_feat
=
graph_data
.
node_feat_info
(),
edge_feat
=
graph_data
.
edge_feat_info
())
self
.
atom_encoder
=
AtomEncoder
(
name
=
"atom"
,
emb_dim
=
self
.
embed_dim
)
self
.
bond_encoder
=
BondEncoder
(
name
=
"bond"
,
emb_dim
=
self
.
embed_dim
)
self
.
labels
=
L
.
data
(
"labels"
,
shape
=
[
None
,
self
.
args
.
num_class
],
dtype
=
"float32"
,
append_batch_size
=
False
)
self
.
unmask
=
L
.
data
(
"unmask"
,
shape
=
[
None
,
self
.
args
.
num_class
],
dtype
=
"float32"
,
append_batch_size
=
False
)
self
.
build_model
()
def
build_model
(
self
):
node_features
=
self
.
atom_encoder
(
self
.
graph_wrapper
.
node_feat
[
'feat'
])
edge_features
=
self
.
bond_encoder
(
self
.
graph_wrapper
.
edge_feat
[
'feat'
])
self
.
_enc_out
=
self
.
node_repr_encode
(
node_features
,
edge_features
)
logits
=
L
.
fc
(
self
.
_enc_out
,
self
.
args
.
num_class
,
act
=
None
,
param_attr
=
F
.
ParamAttr
(
name
=
"final_fc"
))
# L.Print(self.labels, message="labels")
# L.Print(self.unmask, message="unmask")
loss
=
L
.
sigmoid_cross_entropy_with_logits
(
x
=
logits
,
label
=
self
.
labels
)
loss
=
loss
*
self
.
unmask
self
.
loss
=
L
.
reduce_sum
(
loss
)
/
L
.
reduce_sum
(
self
.
unmask
)
self
.
pred
=
L
.
sigmoid
(
logits
)
self
.
_metrics
=
Metric
(
loss
=
self
.
loss
)
def
node_repr_encode
(
self
,
node_features
,
edge_features
):
features_list
=
[
node_features
]
for
layer
in
range
(
self
.
args
.
num_layers
):
feat
=
gin_layer
(
self
.
graph_wrapper
,
features_list
[
layer
],
edge_features
,
train_eps
=
self
.
args
.
train_eps
,
name
=
"gin_%s"
%
layer
,
)
feat
=
self
.
mlp
(
feat
,
name
=
"mlp_%s"
%
layer
)
feat
=
feat
+
features_list
[
layer
]
# residual
features_list
.
append
(
feat
)
output
=
pgl
.
layers
.
graph_pooling
(
self
.
graph_wrapper
,
features_list
[
-
1
],
self
.
args
.
pool_type
)
return
output
def
mlp
(
self
,
features
,
name
):
h
=
features
dim
=
features
.
shape
[
-
1
]
dim_list
=
[
dim
*
2
,
dim
]
for
i
in
range
(
2
):
h
=
L
.
fc
(
h
,
size
=
dim_list
[
i
],
name
=
"%s_fc_%s"
%
(
name
,
i
),
act
=
None
)
if
self
.
args
.
norm_type
==
"layer_norm"
:
log
.
info
(
"norm_type is %s"
%
self
.
args
.
norm_type
)
h
=
L
.
layer_norm
(
h
,
begin_norm_axis
=
1
,
param_attr
=
F
.
ParamAttr
(
name
=
"norm_scale_%s_%s"
%
(
name
,
i
),
initializer
=
F
.
initializer
.
Constant
(
1.0
)),
bias_attr
=
F
.
ParamAttr
(
name
=
"norm_bias_%s_%s"
%
(
name
,
i
),
initializer
=
F
.
initializer
.
Constant
(
0.0
)),
)
else
:
log
.
info
(
"using batch_norm"
)
h
=
L
.
batch_norm
(
h
)
h
=
pgl
.
layers
.
graph_norm
(
self
.
graph_wrapper
,
h
)
h
=
L
.
relu
(
h
)
return
h
def
get_enc_output
(
self
):
return
self
.
_enc_out
@
property
def
init_vars
(
self
):
return
self
.
_init_vars
@
property
def
metrics
(
self
):
return
self
.
_metrics
ogb_examples/graphproppred/mol/mol_encoder.py
0 → 100644
浏览文件 @
49bed8cf
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""MolEncoder for ogb
"""
import
paddle.fluid
as
fluid
from
ogb.utils.features
import
get_atom_feature_dims
,
get_bond_feature_dims
class
AtomEncoder
(
object
):
"""AtomEncoder for encoding node features"""
def
__init__
(
self
,
name
,
emb_dim
):
self
.
emb_dim
=
emb_dim
self
.
name
=
name
def
__call__
(
self
,
x
):
atom_feature
=
get_atom_feature_dims
()
atom_input
=
fluid
.
layers
.
split
(
x
,
num_or_sections
=
len
(
atom_feature
),
dim
=-
1
)
outputs
=
None
count
=
0
for
_x
,
_atom_input_dim
in
zip
(
atom_input
,
atom_feature
):
count
+=
1
emb
=
fluid
.
layers
.
embedding
(
_x
,
size
=
(
_atom_input_dim
,
self
.
emb_dim
),
param_attr
=
fluid
.
ParamAttr
(
name
=
self
.
name
+
'_atom_feat_%s'
%
count
))
if
outputs
is
None
:
outputs
=
emb
else
:
outputs
=
outputs
+
emb
return
outputs
class
BondEncoder
(
object
):
"""Bond for encoding edge features"""
def
__init__
(
self
,
name
,
emb_dim
):
self
.
emb_dim
=
emb_dim
self
.
name
=
name
def
__call__
(
self
,
x
):
bond_feature
=
get_bond_feature_dims
()
bond_input
=
fluid
.
layers
.
split
(
x
,
num_or_sections
=
len
(
bond_feature
),
dim
=-
1
)
outputs
=
None
count
=
0
for
_x
,
_bond_input_dim
in
zip
(
bond_input
,
bond_feature
):
count
+=
1
emb
=
fluid
.
layers
.
embedding
(
_x
,
size
=
(
_bond_input_dim
,
self
.
emb_dim
),
param_attr
=
fluid
.
ParamAttr
(
name
=
self
.
name
+
'_bond_feat_%s'
%
count
))
if
outputs
is
None
:
outputs
=
emb
else
:
outputs
=
outputs
+
emb
return
outputs
ogb_examples/graphproppred/mol/monitor/train_monitor.py
0 → 100644
浏览文件 @
49bed8cf
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
tqdm
import
json
import
numpy
as
np
import
os
from
datetime
import
datetime
import
logging
from
collections
import
defaultdict
from
tensorboardX
import
SummaryWriter
import
paddle.fluid
as
F
from
pgl.utils.logger
import
log
def
multi_device
(
reader
,
dev_count
):
if
dev_count
==
1
:
for
batch
in
reader
:
yield
batch
else
:
batches
=
[]
for
batch
in
reader
:
batches
.
append
(
batch
)
if
len
(
batches
)
==
dev_count
:
yield
batches
batches
=
[]
def
evaluate
(
exe
,
loader
,
prog
,
model
,
evaluator
):
total_labels
=
[]
for
i
in
range
(
len
(
loader
.
dataset
)):
g
,
l
=
loader
.
dataset
[
i
]
total_labels
.
append
(
l
)
total_labels
=
np
.
vstack
(
total_labels
)
pred_output
=
[]
for
feed_dict
in
loader
:
ret
=
exe
.
run
(
prog
,
feed
=
feed_dict
,
fetch_list
=
model
.
pred
)
pred_output
.
append
(
ret
[
0
])
pred_output
=
np
.
vstack
(
pred_output
)
result
=
evaluator
.
eval
({
"y_true"
:
total_labels
,
"y_pred"
:
pred_output
})
return
result
def
_create_if_not_exist
(
path
):
basedir
=
os
.
path
.
dirname
(
path
)
if
not
os
.
path
.
exists
(
basedir
):
os
.
makedirs
(
basedir
)
def
train_and_evaluate
(
exe
,
train_exe
,
valid_exe
,
train_ds
,
valid_ds
,
test_ds
,
train_prog
,
valid_prog
,
args
,
model
,
evaluator
,
dev_count
=
1
):
global_step
=
0
timestamp
=
datetime
.
now
().
strftime
(
"%Hh%Mm%Ss"
)
log_path
=
os
.
path
.
join
(
args
.
log_dir
,
"tensorboard_log_%s"
%
timestamp
)
_create_if_not_exist
(
log_path
)
writer
=
SummaryWriter
(
log_path
)
best_valid_score
=
0.0
for
e
in
range
(
args
.
epoch
):
for
feed_dict
in
multi_device
(
train_ds
,
dev_count
):
if
dev_count
>
1
:
ret
=
train_exe
.
run
(
feed
=
feed_dict
,
fetch_list
=
model
.
metrics
.
vars
)
ret
=
[[
np
.
mean
(
v
)]
for
v
in
ret
]
else
:
ret
=
train_exe
.
run
(
train_prog
,
feed
=
feed_dict
,
fetch_list
=
model
.
metrics
.
vars
)
ret
=
model
.
metrics
.
parse
(
ret
)
if
global_step
%
args
.
train_log_step
==
0
:
writer
.
add_scalar
(
"batch_loss"
,
ret
[
'loss'
],
global_step
=
global_step
)
log
.
info
(
"epoch: %d | step: %d | loss: %.4f "
%
(
e
,
global_step
,
ret
[
'loss'
]))
global_step
+=
1
if
global_step
%
args
.
eval_step
==
0
:
valid_ret
=
evaluate
(
exe
,
valid_ds
,
valid_prog
,
model
,
evaluator
)
message
=
"valid: "
for
key
,
value
in
valid_ret
.
items
():
message
+=
"%s %.4f | "
%
(
key
,
value
)
writer
.
add_scalar
(
"eval_%s"
%
key
,
value
,
global_step
=
global_step
)
log
.
info
(
message
)
# testing
test_ret
=
evaluate
(
exe
,
test_ds
,
valid_prog
,
model
,
evaluator
)
message
=
"test: "
for
key
,
value
in
test_ret
.
items
():
message
+=
"%s %.4f | "
%
(
key
,
value
)
writer
.
add_scalar
(
"test_%s"
%
key
,
value
,
global_step
=
global_step
)
log
.
info
(
message
)
# evaluate after one epoch
valid_ret
=
evaluate
(
exe
,
valid_ds
,
valid_prog
,
model
,
evaluator
)
message
=
"epoch %s valid: "
%
e
for
key
,
value
in
valid_ret
.
items
():
message
+=
"%s %.4f | "
%
(
key
,
value
)
writer
.
add_scalar
(
"eval_%s"
%
key
,
value
,
global_step
=
global_step
)
log
.
info
(
message
)
# testing
test_ret
=
evaluate
(
exe
,
test_ds
,
valid_prog
,
model
,
evaluator
)
message
=
"epoch %s test: "
%
e
for
key
,
value
in
test_ret
.
items
():
message
+=
"%s %.4f | "
%
(
key
,
value
)
writer
.
add_scalar
(
"test_%s"
%
key
,
value
,
global_step
=
global_step
)
log
.
info
(
message
)
message
=
"epoch %s best %s result | "
%
(
e
,
args
.
eval_metrics
)
if
valid_ret
[
args
.
eval_metrics
]
>
best_valid_score
:
best_valid_score
=
valid_ret
[
args
.
eval_metrics
]
best_test_score
=
test_ret
[
args
.
eval_metrics
]
message
+=
"valid %.4f | test %.4f"
%
(
best_valid_score
,
best_test_score
)
log
.
info
(
message
)
# if global_step % args.save_step == 0:
# F.io.save_persistables(exe, os.path.join(args.save_dir, "%s" % global_step), train_prog)
writer
.
close
()
ogb_examples/graphproppred/mol/optimization.py
0 → 100644
浏览文件 @
49bed8cf
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Optimization and learning rate scheduling."""
from
__future__
import
absolute_import
from
__future__
import
division
from
__future__
import
print_function
from
__future__
import
unicode_literals
from
__future__
import
absolute_import
import
numpy
as
np
import
paddle.fluid
as
fluid
from
utils.fp16
import
create_master_params_grads
,
master_param_to_train_param
,
apply_dynamic_loss_scaling
def
linear_warmup_decay
(
learning_rate
,
warmup_steps
,
num_train_steps
):
""" Applies linear warmup of learning rate from 0 and decay to 0."""
with
fluid
.
default_main_program
().
_lr_schedule_guard
():
lr
=
fluid
.
layers
.
tensor
.
create_global_var
(
shape
=
[
1
],
value
=
0.0
,
dtype
=
'float32'
,
persistable
=
True
,
name
=
"scheduled_learning_rate"
)
global_step
=
fluid
.
layers
.
learning_rate_scheduler
.
_decay_step_counter
(
)
with
fluid
.
layers
.
control_flow
.
Switch
()
as
switch
:
with
switch
.
case
(
global_step
<
warmup_steps
):
warmup_lr
=
learning_rate
*
(
global_step
/
warmup_steps
)
fluid
.
layers
.
tensor
.
assign
(
warmup_lr
,
lr
)
with
switch
.
default
():
decayed_lr
=
fluid
.
layers
.
learning_rate_scheduler
.
polynomial_decay
(
learning_rate
=
learning_rate
,
decay_steps
=
num_train_steps
,
end_learning_rate
=
0.0
,
power
=
1.0
,
cycle
=
False
)
fluid
.
layers
.
tensor
.
assign
(
decayed_lr
,
lr
)
return
lr
def
optimization
(
loss
,
warmup_steps
,
num_train_steps
,
learning_rate
,
train_program
,
startup_prog
,
weight_decay
,
scheduler
=
'linear_warmup_decay'
,
use_fp16
=
False
,
use_dynamic_loss_scaling
=
False
,
init_loss_scaling
=
1.0
,
incr_every_n_steps
=
1000
,
decr_every_n_nan_or_inf
=
2
,
incr_ratio
=
2.0
,
decr_ratio
=
0.8
):
if
warmup_steps
>
0
:
if
scheduler
==
'noam_decay'
:
scheduled_lr
=
fluid
.
layers
.
learning_rate_scheduler
\
.
noam_decay
(
1
/
(
warmup_steps
*
(
learning_rate
**
2
)),
warmup_steps
)
elif
scheduler
==
'linear_warmup_decay'
:
scheduled_lr
=
linear_warmup_decay
(
learning_rate
,
warmup_steps
,
num_train_steps
)
else
:
raise
ValueError
(
"Unkown learning rate scheduler, should be "
"'noam_decay' or 'linear_warmup_decay'"
)
optimizer
=
fluid
.
optimizer
.
Adam
(
learning_rate
=
scheduled_lr
)
else
:
scheduled_lr
=
fluid
.
layers
.
create_global_var
(
name
=
fluid
.
unique_name
.
generate
(
"learning_rate"
),
shape
=
[
1
],
value
=
learning_rate
,
dtype
=
'float32'
,
persistable
=
True
)
optimizer
=
fluid
.
optimizer
.
Adam
(
learning_rate
=
scheduled_lr
)
optimizer
.
_learning_rate_map
[
fluid
.
default_main_program
(
)]
=
scheduled_lr
fluid
.
clip
.
set_gradient_clip
(
clip
=
fluid
.
clip
.
GradientClipByGlobalNorm
(
clip_norm
=
1.0
))
def
exclude_from_weight_decay
(
name
):
if
name
.
find
(
"layer_norm"
)
>
-
1
:
return
True
bias_suffix
=
[
"_bias"
,
"_b"
,
".b_0"
]
for
suffix
in
bias_suffix
:
if
name
.
endswith
(
suffix
):
return
True
return
False
param_list
=
dict
()
loss_scaling
=
fluid
.
layers
.
create_global_var
(
name
=
fluid
.
unique_name
.
generate
(
"loss_scaling"
),
shape
=
[
1
],
value
=
init_loss_scaling
,
dtype
=
'float32'
,
persistable
=
True
)
if
use_fp16
:
loss
*=
loss_scaling
param_grads
=
optimizer
.
backward
(
loss
)
master_param_grads
=
create_master_params_grads
(
param_grads
,
train_program
,
startup_prog
,
loss_scaling
)
for
param
,
_
in
master_param_grads
:
param_list
[
param
.
name
]
=
param
*
1.0
param_list
[
param
.
name
].
stop_gradient
=
True
if
use_dynamic_loss_scaling
:
apply_dynamic_loss_scaling
(
loss_scaling
,
master_param_grads
,
incr_every_n_steps
,
decr_every_n_nan_or_inf
,
incr_ratio
,
decr_ratio
)
optimizer
.
apply_gradients
(
master_param_grads
)
if
weight_decay
>
0
:
for
param
,
grad
in
master_param_grads
:
if
exclude_from_weight_decay
(
param
.
name
.
rstrip
(
".master"
)):
continue
with
param
.
block
.
program
.
_optimized_guard
(
[
param
,
grad
]),
fluid
.
framework
.
name_scope
(
"weight_decay"
):
updated_param
=
param
-
param_list
[
param
.
name
]
*
weight_decay
*
scheduled_lr
fluid
.
layers
.
assign
(
output
=
param
,
input
=
updated_param
)
master_param_to_train_param
(
master_param_grads
,
param_grads
,
train_program
)
else
:
for
param
in
train_program
.
global_block
().
all_parameters
():
param_list
[
param
.
name
]
=
param
*
1.0
param_list
[
param
.
name
].
stop_gradient
=
True
_
,
param_grads
=
optimizer
.
minimize
(
loss
)
if
weight_decay
>
0
:
for
param
,
grad
in
param_grads
:
if
exclude_from_weight_decay
(
param
.
name
):
continue
with
param
.
block
.
program
.
_optimized_guard
(
[
param
,
grad
]),
fluid
.
framework
.
name_scope
(
"weight_decay"
):
updated_param
=
param
-
param_list
[
param
.
name
]
*
weight_decay
*
scheduled_lr
fluid
.
layers
.
assign
(
output
=
param
,
input
=
updated_param
)
return
scheduled_lr
,
loss_scaling
ogb_examples/graphproppred/mol/utils/__init__.py
0 → 100644
浏览文件 @
49bed8cf
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
ogb_examples/graphproppred/mol/utils/args.py
0 → 100644
浏览文件 @
49bed8cf
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Arguments for configuration."""
from
__future__
import
absolute_import
from
__future__
import
division
from
__future__
import
print_function
from
__future__
import
unicode_literals
from
__future__
import
absolute_import
import
six
import
os
import
sys
import
argparse
import
logging
import
paddle.fluid
as
fluid
log
=
logging
.
getLogger
(
"logger"
)
def
prepare_logger
(
logger
,
debug
=
False
,
save_to_file
=
None
):
formatter
=
logging
.
Formatter
(
fmt
=
'[%(levelname)s] %(asctime)s [%(filename)12s:%(lineno)5d]:
\t
%(message)s'
)
# console_hdl = logging.StreamHandler()
# console_hdl.setFormatter(formatter)
# logger.addHandler(console_hdl)
if
save_to_file
is
not
None
:
#and not os.path.exists(save_to_file):
if
os
.
path
.
isdir
(
save_to_file
):
file_hdl
=
logging
.
FileHandler
(
os
.
path
.
join
(
save_to_file
,
'log.txt'
))
else
:
file_hdl
=
logging
.
FileHandler
(
save_to_file
)
file_hdl
.
setFormatter
(
formatter
)
logger
.
addHandler
(
file_hdl
)
logger
.
setLevel
(
logging
.
DEBUG
)
logger
.
propagate
=
False
def
str2bool
(
v
):
# because argparse does not support to parse "true, False" as python
# boolean directly
return
v
.
lower
()
in
(
"true"
,
"t"
,
"1"
)
class
ArgumentGroup
(
object
):
def
__init__
(
self
,
parser
,
title
,
des
):
self
.
_group
=
parser
.
add_argument_group
(
title
=
title
,
description
=
des
)
def
add_arg
(
self
,
name
,
type
,
default
,
help
,
positional_arg
=
False
,
**
kwargs
):
prefix
=
""
if
positional_arg
else
"--"
type
=
str2bool
if
type
==
bool
else
type
self
.
_group
.
add_argument
(
prefix
+
name
,
default
=
default
,
type
=
type
,
help
=
help
+
' Default: %(default)s.'
,
**
kwargs
)
def
print_arguments
(
args
):
log
.
info
(
'----------- Configuration Arguments -----------'
)
for
arg
,
value
in
sorted
(
six
.
iteritems
(
vars
(
args
))):
log
.
info
(
'%s: %s'
%
(
arg
,
value
))
log
.
info
(
'------------------------------------------------'
)
def
check_cuda
(
use_cuda
,
err
=
\
"
\n
You can not set use_cuda = True in the model because you are using paddlepaddle-cpu.
\n
\
Please: 1. Install paddlepaddle-gpu to run your models on GPU or 2. Set use_cuda = False to run models on CPU.
\n
"
):
try
:
if
use_cuda
==
True
and
fluid
.
is_compiled_with_cuda
()
==
False
:
log
.
error
(
err
)
sys
.
exit
(
1
)
except
Exception
as
e
:
pass
ogb_examples/graphproppred/mol/utils/cards.py
0 → 100644
浏览文件 @
49bed8cf
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
__future__
import
absolute_import
from
__future__
import
division
from
__future__
import
print_function
from
__future__
import
unicode_literals
from
__future__
import
absolute_import
import
os
def
get_cards
():
"""
get gpu cards number
"""
num
=
0
cards
=
os
.
environ
.
get
(
'CUDA_VISIBLE_DEVICES'
,
''
)
if
cards
!=
''
:
num
=
len
(
cards
.
split
(
","
))
return
num
ogb_examples/graphproppred/mol/utils/config.py
0 → 100644
浏览文件 @
49bed8cf
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This file implement a class for model configure.
"""
import
datetime
import
os
import
yaml
import
random
import
shutil
import
six
import
logging
log
=
logging
.
getLogger
(
"logger"
)
class
AttrDict
(
dict
):
"""Attr dict
"""
def
__init__
(
self
,
d
):
self
.
dict
=
d
def
__getattr__
(
self
,
attr
):
value
=
self
.
dict
[
attr
]
if
isinstance
(
value
,
dict
):
return
AttrDict
(
value
)
else
:
return
value
def
__str__
(
self
):
return
str
(
self
.
dict
)
class
Config
(
object
):
"""Implementation of Config class for model configure.
Args:
config_file(str): configure filename, which is a yaml file.
isCreate(bool): if true, create some neccessary directories to save models, log file and other outputs.
isSave(bool): if true, save config_file in order to record the configure message.
"""
def
__init__
(
self
,
config_file
,
isCreate
=
False
,
isSave
=
False
):
self
.
config_file
=
config_file
# self.config = self.get_config_from_yaml(config_file)
self
.
config
=
self
.
load_config
(
config_file
)
if
isCreate
:
self
.
create_necessary_dirs
()
if
isSave
:
self
.
save_config_file
()
def
load_config
(
self
,
config_file
):
"""Load config file"""
with
open
(
config_file
)
as
f
:
if
hasattr
(
yaml
,
'FullLoader'
):
config
=
yaml
.
load
(
f
,
Loader
=
yaml
.
FullLoader
)
else
:
config
=
yaml
.
load
(
f
)
return
config
def
create_necessary_dirs
(
self
):
"""Create some necessary directories to save some important files.
"""
self
.
config
[
'log_dir'
]
=
os
.
path
.
join
(
self
.
config
[
'log_dir'
],
self
.
config
[
'task_name'
])
self
.
config
[
'save_dir'
]
=
os
.
path
.
join
(
self
.
config
[
'save_dir'
],
self
.
config
[
'task_name'
])
self
.
config
[
'output_dir'
]
=
os
.
path
.
join
(
self
.
config
[
'output_dir'
],
self
.
config
[
'task_name'
])
self
.
make_dir
(
self
.
config
[
'log_dir'
])
self
.
make_dir
(
self
.
config
[
'save_dir'
])
self
.
make_dir
(
self
.
config
[
'output_dir'
])
def
save_config_file
(
self
):
"""Save config file so that we can know the config when we look back
"""
filename
=
self
.
config_file
.
split
(
'/'
)[
-
1
]
targetpath
=
os
.
path
.
join
(
self
.
config
[
'save_dir'
],
filename
)
try
:
shutil
.
copyfile
(
self
.
config_file
,
targetpath
)
except
shutil
.
SameFileError
:
log
.
info
(
"%s and %s are the same file, did not copy by shutil"
\
%
(
self
.
config_file
,
targetpath
))
def
make_dir
(
self
,
path
):
"""Build directory"""
if
not
os
.
path
.
exists
(
path
):
os
.
makedirs
(
path
)
def
__getitem__
(
self
,
key
):
return
self
.
config
[
key
]
def
__call__
(
self
):
"""__call__"""
return
self
.
config
def
__getattr__
(
self
,
attr
):
try
:
result
=
self
.
config
[
attr
]
except
KeyError
:
log
.
warn
(
"%s attribute is not existed, return None"
%
attr
)
result
=
None
return
result
def
__setitem__
(
self
,
key
,
value
):
self
.
config
[
key
]
=
value
def
__str__
(
self
):
return
str
(
self
.
config
)
def
pretty_print
(
self
):
log
.
info
(
"-----------------------------------------------------------------"
)
log
.
info
(
"config file: %s"
%
self
.
config_file
)
for
key
,
value
in
sorted
(
self
.
config
.
items
(),
key
=
lambda
item
:
item
[
0
]):
log
.
info
(
"%s: %s"
%
(
key
,
value
))
log
.
info
(
"-----------------------------------------------------------------"
)
ogb_examples/graphproppred/mol/utils/fp16.py
0 → 100644
浏览文件 @
49bed8cf
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
__future__
import
print_function
import
paddle
import
paddle.fluid
as
fluid
def
append_cast_op
(
i
,
o
,
prog
):
"""
Append a cast op in a given Program to cast input `i` to data type `o.dtype`.
Args:
i (Variable): The input Variable.
o (Variable): The output Variable.
prog (Program): The Program to append cast op.
"""
prog
.
global_block
().
append_op
(
type
=
"cast"
,
inputs
=
{
"X"
:
i
},
outputs
=
{
"Out"
:
o
},
attrs
=
{
"in_dtype"
:
i
.
dtype
,
"out_dtype"
:
o
.
dtype
})
def
copy_to_master_param
(
p
,
block
):
v
=
block
.
vars
.
get
(
p
.
name
,
None
)
if
v
is
None
:
raise
ValueError
(
"no param name %s found!"
%
p
.
name
)
new_p
=
fluid
.
framework
.
Parameter
(
block
=
block
,
shape
=
v
.
shape
,
dtype
=
fluid
.
core
.
VarDesc
.
VarType
.
FP32
,
type
=
v
.
type
,
lod_level
=
v
.
lod_level
,
stop_gradient
=
p
.
stop_gradient
,
trainable
=
p
.
trainable
,
optimize_attr
=
p
.
optimize_attr
,
regularizer
=
p
.
regularizer
,
gradient_clip_attr
=
p
.
gradient_clip_attr
,
error_clip
=
p
.
error_clip
,
name
=
v
.
name
+
".master"
)
return
new_p
def
apply_dynamic_loss_scaling
(
loss_scaling
,
master_params_grads
,
incr_every_n_steps
,
decr_every_n_nan_or_inf
,
incr_ratio
,
decr_ratio
):
_incr_every_n_steps
=
fluid
.
layers
.
fill_constant
(
shape
=
[
1
],
dtype
=
'int32'
,
value
=
incr_every_n_steps
)
_decr_every_n_nan_or_inf
=
fluid
.
layers
.
fill_constant
(
shape
=
[
1
],
dtype
=
'int32'
,
value
=
decr_every_n_nan_or_inf
)
_num_good_steps
=
fluid
.
layers
.
create_global_var
(
name
=
fluid
.
unique_name
.
generate
(
"num_good_steps"
),
shape
=
[
1
],
value
=
0
,
dtype
=
'int32'
,
persistable
=
True
)
_num_bad_steps
=
fluid
.
layers
.
create_global_var
(
name
=
fluid
.
unique_name
.
generate
(
"num_bad_steps"
),
shape
=
[
1
],
value
=
0
,
dtype
=
'int32'
,
persistable
=
True
)
grads
=
[
fluid
.
layers
.
reduce_sum
(
g
)
for
[
_
,
g
]
in
master_params_grads
]
all_grads
=
fluid
.
layers
.
concat
(
grads
)
all_grads_sum
=
fluid
.
layers
.
reduce_sum
(
all_grads
)
is_overall_finite
=
fluid
.
layers
.
isfinite
(
all_grads_sum
)
update_loss_scaling
(
is_overall_finite
,
loss_scaling
,
_num_good_steps
,
_num_bad_steps
,
_incr_every_n_steps
,
_decr_every_n_nan_or_inf
,
incr_ratio
,
decr_ratio
)
# apply_gradient append all ops in global block, thus we shouldn't
# apply gradient in the switch branch.
with
fluid
.
layers
.
Switch
()
as
switch
:
with
switch
.
case
(
is_overall_finite
):
pass
with
switch
.
default
():
for
_
,
g
in
master_params_grads
:
fluid
.
layers
.
assign
(
fluid
.
layers
.
zeros_like
(
g
),
g
)
def
create_master_params_grads
(
params_grads
,
main_prog
,
startup_prog
,
loss_scaling
):
master_params_grads
=
[]
for
p
,
g
in
params_grads
:
with
main_prog
.
_optimized_guard
([
p
,
g
]):
# create master parameters
master_param
=
copy_to_master_param
(
p
,
main_prog
.
global_block
())
startup_master_param
=
startup_prog
.
global_block
().
_clone_variable
(
master_param
)
startup_p
=
startup_prog
.
global_block
().
var
(
p
.
name
)
append_cast_op
(
startup_p
,
startup_master_param
,
startup_prog
)
# cast fp16 gradients to fp32 before apply gradients
if
g
.
name
.
find
(
"layer_norm"
)
>
-
1
:
scaled_g
=
g
/
loss_scaling
master_params_grads
.
append
([
p
,
scaled_g
])
continue
master_grad
=
fluid
.
layers
.
cast
(
g
,
"float32"
)
master_grad
=
master_grad
/
loss_scaling
master_params_grads
.
append
([
master_param
,
master_grad
])
return
master_params_grads
def
master_param_to_train_param
(
master_params_grads
,
params_grads
,
main_prog
):
for
idx
,
m_p_g
in
enumerate
(
master_params_grads
):
train_p
,
_
=
params_grads
[
idx
]
if
train_p
.
name
.
find
(
"layer_norm"
)
>
-
1
:
continue
with
main_prog
.
_optimized_guard
([
m_p_g
[
0
],
m_p_g
[
1
]]):
append_cast_op
(
m_p_g
[
0
],
train_p
,
main_prog
)
def
update_loss_scaling
(
is_overall_finite
,
prev_loss_scaling
,
num_good_steps
,
num_bad_steps
,
incr_every_n_steps
,
decr_every_n_nan_or_inf
,
incr_ratio
,
decr_ratio
):
"""
Update loss scaling according to overall gradients. If all gradients is
finite after incr_every_n_steps, loss scaling will increase by incr_ratio.
Otherwisw, loss scaling will decrease by decr_ratio after
decr_every_n_nan_or_inf steps and each step some gradients are infinite.
Args:
is_overall_finite (Variable): A boolean variable indicates whether
all gradients are finite.
prev_loss_scaling (Variable): Previous loss scaling.
num_good_steps (Variable): A variable accumulates good steps in which
all gradients are finite.
num_bad_steps (Variable): A variable accumulates bad steps in which
some gradients are infinite.
incr_every_n_steps (Variable): A variable represents increasing loss
scaling every n consecutive steps with
finite gradients.
decr_every_n_nan_or_inf (Variable): A variable represents decreasing
loss scaling every n accumulated
steps with nan or inf gradients.
incr_ratio(float): The multiplier to use when increasing the loss
scaling.
decr_ratio(float): The less-than-one-multiplier to use when decreasing
loss scaling.
"""
zero_steps
=
fluid
.
layers
.
fill_constant
(
shape
=
[
1
],
dtype
=
'int32'
,
value
=
0
)
with
fluid
.
layers
.
Switch
()
as
switch
:
with
switch
.
case
(
is_overall_finite
):
should_incr_loss_scaling
=
fluid
.
layers
.
less_than
(
incr_every_n_steps
,
num_good_steps
+
1
)
with
fluid
.
layers
.
Switch
()
as
switch1
:
with
switch1
.
case
(
should_incr_loss_scaling
):
new_loss_scaling
=
prev_loss_scaling
*
incr_ratio
loss_scaling_is_finite
=
fluid
.
layers
.
isfinite
(
new_loss_scaling
)
with
fluid
.
layers
.
Switch
()
as
switch2
:
with
switch2
.
case
(
loss_scaling_is_finite
):
fluid
.
layers
.
assign
(
new_loss_scaling
,
prev_loss_scaling
)
with
switch2
.
default
():
pass
fluid
.
layers
.
assign
(
zero_steps
,
num_good_steps
)
fluid
.
layers
.
assign
(
zero_steps
,
num_bad_steps
)
with
switch1
.
default
():
fluid
.
layers
.
increment
(
num_good_steps
)
fluid
.
layers
.
assign
(
zero_steps
,
num_bad_steps
)
with
switch
.
default
():
should_decr_loss_scaling
=
fluid
.
layers
.
less_than
(
decr_every_n_nan_or_inf
,
num_bad_steps
+
1
)
with
fluid
.
layers
.
Switch
()
as
switch3
:
with
switch3
.
case
(
should_decr_loss_scaling
):
new_loss_scaling
=
prev_loss_scaling
*
decr_ratio
static_loss_scaling
=
\
fluid
.
layers
.
fill_constant
(
shape
=
[
1
],
dtype
=
'float32'
,
value
=
1.0
)
less_than_one
=
fluid
.
layers
.
less_than
(
new_loss_scaling
,
static_loss_scaling
)
with
fluid
.
layers
.
Switch
()
as
switch4
:
with
switch4
.
case
(
less_than_one
):
fluid
.
layers
.
assign
(
static_loss_scaling
,
prev_loss_scaling
)
with
switch4
.
default
():
fluid
.
layers
.
assign
(
new_loss_scaling
,
prev_loss_scaling
)
fluid
.
layers
.
assign
(
zero_steps
,
num_good_steps
)
fluid
.
layers
.
assign
(
zero_steps
,
num_bad_steps
)
with
switch3
.
default
():
fluid
.
layers
.
assign
(
zero_steps
,
num_good_steps
)
fluid
.
layers
.
increment
(
num_bad_steps
)
ogb_examples/graphproppred/mol/utils/init.py
0 → 100644
浏览文件 @
49bed8cf
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
__future__
import
absolute_import
from
__future__
import
division
from
__future__
import
print_function
from
__future__
import
unicode_literals
from
__future__
import
absolute_import
import
os
import
six
import
ast
import
copy
import
logging
import
numpy
as
np
import
paddle.fluid
as
fluid
log
=
logging
.
getLogger
(
"logger"
)
def
cast_fp32_to_fp16
(
exe
,
main_program
):
log
.
info
(
"Cast parameters to float16 data format."
)
for
param
in
main_program
.
global_block
().
all_parameters
():
if
not
param
.
name
.
endswith
(
".master"
):
param_t
=
fluid
.
global_scope
().
find_var
(
param
.
name
).
get_tensor
()
data
=
np
.
array
(
param_t
)
if
param
.
name
.
startswith
(
"encoder_layer"
)
\
and
"layer_norm"
not
in
param
.
name
:
param_t
.
set
(
np
.
float16
(
data
).
view
(
np
.
uint16
),
exe
.
place
)
#load fp32
master_param_var
=
fluid
.
global_scope
().
find_var
(
param
.
name
+
".master"
)
if
master_param_var
is
not
None
:
master_param_var
.
get_tensor
().
set
(
data
,
exe
.
place
)
def
init_checkpoint
(
exe
,
init_checkpoint_path
,
main_program
,
use_fp16
=
False
):
assert
os
.
path
.
exists
(
init_checkpoint_path
),
"[%s] cann't be found."
%
init_checkpoint_path
def
existed_persitables
(
var
):
if
not
fluid
.
io
.
is_persistable
(
var
):
return
False
return
os
.
path
.
exists
(
os
.
path
.
join
(
init_checkpoint_path
,
var
.
name
))
fluid
.
io
.
load_vars
(
exe
,
init_checkpoint_path
,
main_program
=
main_program
,
predicate
=
existed_persitables
)
log
.
info
(
"Load model from {}"
.
format
(
init_checkpoint_path
))
if
use_fp16
:
cast_fp32_to_fp16
(
exe
,
main_program
)
def
init_pretraining_params
(
exe
,
pretraining_params_path
,
main_program
,
use_fp16
=
False
):
assert
os
.
path
.
exists
(
pretraining_params_path
),
"[%s] cann't be found."
%
pretraining_params_path
def
existed_params
(
var
):
if
not
isinstance
(
var
,
fluid
.
framework
.
Parameter
):
return
False
return
os
.
path
.
exists
(
os
.
path
.
join
(
pretraining_params_path
,
var
.
name
))
fluid
.
io
.
load_vars
(
exe
,
pretraining_params_path
,
main_program
=
main_program
,
predicate
=
existed_params
)
log
.
info
(
"Load pretraining parameters from {}."
.
format
(
pretraining_params_path
))
if
use_fp16
:
cast_fp32_to_fp16
(
exe
,
main_program
)
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录