Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
magicwindyyd
mindspore
提交
14868eb2
M
mindspore
项目概览
magicwindyyd
/
mindspore
与 Fork 源项目一致
Fork自
MindSpore / mindspore
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
mindspore
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
14868eb2
编写于
6月 23, 2020
作者:
M
mindspore-ci-bot
提交者:
Gitee
6月 23, 2020
浏览文件
操作
浏览文件
下载
差异文件
!2465 GCN rectification
Merge pull request !2465 from chentingting/gcn_rectification
上级
106f7980
a733102d
变更
12
隐藏空白更改
内联
并排
Showing
12 changed file
with
63 addition
and
339 deletion
+63
-339
model_zoo/gcn/README.md
model_zoo/gcn/README.md
+5
-3
model_zoo/gcn/scripts/run_process_data.sh
model_zoo/gcn/scripts/run_process_data.sh
+2
-1
model_zoo/gcn/src/dataset.py
model_zoo/gcn/src/dataset.py
+1
-1
model_zoo/gcn/src/gcn.py
model_zoo/gcn/src/gcn.py
+1
-1
model_zoo/gcn/t-SNE_visualization_on_Cora.gif
model_zoo/gcn/t-SNE_visualization_on_Cora.gif
+0
-0
model_zoo/gcn/train.py
model_zoo/gcn/train.py
+43
-9
tests/st/gnn/gcn/src/__init__.py
tests/st/gnn/gcn/src/__init__.py
+0
-0
tests/st/gnn/gcn/src/config.py
tests/st/gnn/gcn/src/config.py
+0
-23
tests/st/gnn/gcn/src/dataset.py
tests/st/gnn/gcn/src/dataset.py
+0
-60
tests/st/gnn/gcn/src/gcn.py
tests/st/gnn/gcn/src/gcn.py
+0
-163
tests/st/gnn/gcn/src/metrics.py
tests/st/gnn/gcn/src/metrics.py
+0
-68
tests/st/gnn/gcn/test_gcn.py
tests/st/gnn/gcn/test_gcn.py
+11
-10
未找到文件。
model_zoo/gcn/README.md
浏览文件 @
14868eb2
...
...
@@ -36,9 +36,9 @@ sh run_process_data.sh [SRC_PATH] [DATASET_NAME]
>> Launch
```
#Generate dataset in mindrecord format for cora
sh run_process_data.sh cora
sh run_process_data.sh
./data
cora
#Generate dataset in mindrecord format for citeseer
sh run_process_data.sh citeseer
sh run_process_data.sh
./data
citeseer
```
## Structure
...
...
@@ -110,4 +110,6 @@ Epoch: 0200 train_loss= 0.57948 train_acc= 0.96429 val_loss= 1.04753 val_acc= 0.
Optimization Finished!
Test set results: cost= 1.00983 accuracy= 0.81300 time= 0.39083
...
```
\ No newline at end of file
```
model_zoo/gcn/scripts/run_process_data.sh
浏览文件 @
14868eb2
...
...
@@ -40,7 +40,8 @@ else
fi
MINDRECORD_PATH
=
`
pwd
`
/data_mr
rm
-f
$MINDRECORD_PATH
/
*
rm
-f
$MINDRECORD_PATH
/
$DATASET_NAME
rm
-f
$MINDRECORD_PATH
/
$DATASET_NAME
.db
cd
../../../example/graph_to_mindrecord
||
exit
...
...
model_zoo/gcn/src/dataset.py
浏览文件 @
14868eb2
...
...
@@ -55,7 +55,7 @@ def get_adj_features_labels(data_dir):
adj
=
adj
+
adj
.
T
.
multiply
(
adj
.
T
>
adj
)
+
sp
.
eye
(
nodes_num
)
nor_adj
=
normalize_adj
(
adj
)
nor_adj
=
np
.
array
(
nor_adj
.
todense
())
return
nor_adj
,
features
,
labels_onehot
return
nor_adj
,
features
,
labels_onehot
,
labels
def
get_mask
(
total
,
begin
,
end
):
...
...
model_zoo/gcn/src/gcn.py
浏览文件 @
14868eb2
...
...
@@ -21,7 +21,7 @@ from mindspore.ops import functional as F
from
mindspore.ops
import
operations
as
P
from
mindspore
import
Tensor
from
mindspore.nn.layer.activation
import
get_activation
from
src.metrics
import
Loss
,
Accuracy
from
model_zoo.gcn.
src.metrics
import
Loss
,
Accuracy
def
glorot
(
shape
):
...
...
model_zoo/gcn/t-SNE_visualization_on_Cora.gif
0 → 100644
浏览文件 @
14868eb2
6.3 MB
model_zoo/gcn/train.py
浏览文件 @
14868eb2
...
...
@@ -21,11 +21,25 @@ import time
import
argparse
import
numpy
as
np
from
matplotlib
import
pyplot
as
plt
from
matplotlib
import
animation
from
sklearn
import
manifold
from
mindspore
import
context
from
src.gcn
import
GCN
,
LossAccuracyWrapper
,
TrainNetWrapper
from
src.config
import
ConfigGCN
from
src.dataset
import
get_adj_features_labels
,
get_mask
from
model_zoo.gcn.src.gcn
import
GCN
,
LossAccuracyWrapper
,
TrainNetWrapper
from
model_zoo.gcn.src.config
import
ConfigGCN
from
model_zoo.gcn.src.dataset
import
get_adj_features_labels
,
get_mask
def
t_SNE
(
out_feature
,
dim
):
t_sne
=
manifold
.
TSNE
(
n_components
=
dim
,
init
=
'pca'
,
random_state
=
0
)
return
t_sne
.
fit_transform
(
out_feature
)
def
update_graph
(
i
,
data
,
scat
,
plot
):
scat
.
set_offsets
(
data
[
i
])
plt
.
title
(
't-SNE visualization of Epoch:{0}'
.
format
(
i
))
return
scat
,
plot
def
train
():
...
...
@@ -36,28 +50,39 @@ def train():
parser
.
add_argument
(
'--train_nodes_num'
,
type
=
int
,
default
=
140
,
help
=
'Nodes numbers for training'
)
parser
.
add_argument
(
'--eval_nodes_num'
,
type
=
int
,
default
=
500
,
help
=
'Nodes numbers for evaluation'
)
parser
.
add_argument
(
'--test_nodes_num'
,
type
=
int
,
default
=
1000
,
help
=
'Nodes numbers for test'
)
parser
.
add_argument
(
'--save_TSNE'
,
type
=
bool
,
default
=
False
,
help
=
'Whether to save t-SNE graph'
)
args_opt
=
parser
.
parse_args
()
np
.
random
.
seed
(
args_opt
.
seed
)
context
.
set_context
(
mode
=
context
.
GRAPH_MODE
,
device_target
=
"Ascend"
,
save_graphs
=
False
)
config
=
ConfigGCN
()
adj
,
feature
,
label
=
get_adj_features_labels
(
args_opt
.
data_dir
)
adj
,
feature
,
label
_onehot
,
label
=
get_adj_features_labels
(
args_opt
.
data_dir
)
nodes_num
=
label
.
shape
[
0
]
nodes_num
=
label
_onehot
.
shape
[
0
]
train_mask
=
get_mask
(
nodes_num
,
0
,
args_opt
.
train_nodes_num
)
eval_mask
=
get_mask
(
nodes_num
,
args_opt
.
train_nodes_num
,
args_opt
.
train_nodes_num
+
args_opt
.
eval_nodes_num
)
test_mask
=
get_mask
(
nodes_num
,
nodes_num
-
args_opt
.
test_nodes_num
,
nodes_num
)
class_num
=
label
.
shape
[
1
]
class_num
=
label
_onehot
.
shape
[
1
]
gcn_net
=
GCN
(
config
,
adj
,
feature
,
class_num
)
gcn_net
.
add_flags_recursive
(
fp16
=
True
)
eval_net
=
LossAccuracyWrapper
(
gcn_net
,
label
,
eval_mask
,
config
.
weight_decay
)
test_net
=
LossAccuracyWrapper
(
gcn_net
,
label
,
test_mask
,
config
.
weight_decay
)
train_net
=
TrainNetWrapper
(
gcn_net
,
label
,
train_mask
,
config
)
eval_net
=
LossAccuracyWrapper
(
gcn_net
,
label
_onehot
,
eval_mask
,
config
.
weight_decay
)
test_net
=
LossAccuracyWrapper
(
gcn_net
,
label
_onehot
,
test_mask
,
config
.
weight_decay
)
train_net
=
TrainNetWrapper
(
gcn_net
,
label
_onehot
,
train_mask
,
config
)
loss_list
=
[]
if
args_opt
.
save_TSNE
:
out_feature
=
gcn_net
()
tsne_result
=
t_SNE
(
out_feature
.
asnumpy
(),
2
)
graph_data
=
[]
graph_data
.
append
(
tsne_result
)
fig
=
plt
.
figure
()
scat
=
plt
.
scatter
(
tsne_result
[:,
0
],
tsne_result
[:,
1
],
s
=
2
,
c
=
label
,
cmap
=
'rainbow'
)
plt
.
title
(
't-SNE visualization of Epoch:0'
,
fontsize
=
'large'
,
fontweight
=
'bold'
,
verticalalignment
=
'center'
)
for
epoch
in
range
(
config
.
epochs
):
t
=
time
.
time
()
...
...
@@ -76,6 +101,11 @@ def train():
"train_acc="
,
"{:.5f}"
.
format
(
train_accuracy
),
"val_loss="
,
"{:.5f}"
.
format
(
eval_loss
),
"val_acc="
,
"{:.5f}"
.
format
(
eval_accuracy
),
"time="
,
"{:.5f}"
.
format
(
time
.
time
()
-
t
))
if
args_opt
.
save_TSNE
:
out_feature
=
gcn_net
()
tsne_result
=
t_SNE
(
out_feature
.
asnumpy
(),
2
)
graph_data
.
append
(
tsne_result
)
if
epoch
>
config
.
early_stopping
and
loss_list
[
-
1
]
>
np
.
mean
(
loss_list
[
-
(
config
.
early_stopping
+
1
):
-
1
]):
print
(
"Early stopping..."
)
break
...
...
@@ -88,6 +118,10 @@ def train():
print
(
"Test set results:"
,
"loss="
,
"{:.5f}"
.
format
(
test_loss
),
"accuracy="
,
"{:.5f}"
.
format
(
test_accuracy
),
"time="
,
"{:.5f}"
.
format
(
time
.
time
()
-
t_test
))
if
args_opt
.
save_TSNE
:
ani
=
animation
.
FuncAnimation
(
fig
,
update_graph
,
frames
=
range
(
config
.
epochs
+
1
),
fargs
=
(
graph_data
,
scat
,
plt
))
ani
.
save
(
't-SNE_visualization.gif'
,
writer
=
'imagemagick'
)
if
__name__
==
'__main__'
:
train
()
tests/st/gnn/gcn/src/__init__.py
已删除
100644 → 0
浏览文件 @
106f7980
tests/st/gnn/gcn/src/config.py
已删除
100644 → 0
浏览文件 @
106f7980
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
class
ConfigGCN
():
learning_rate
=
0.01
epochs
=
200
hidden1
=
16
dropout
=
0.0
weight_decay
=
5e-4
early_stopping
=
10
tests/st/gnn/gcn/src/dataset.py
已删除
100644 → 0
浏览文件 @
106f7980
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import
numpy
as
np
import
scipy.sparse
as
sp
import
mindspore.dataset
as
ds
def
normalize_adj
(
adj
):
rowsum
=
np
.
array
(
adj
.
sum
(
1
))
d_inv_sqrt
=
np
.
power
(
rowsum
,
-
0.5
).
flatten
()
d_inv_sqrt
[
np
.
isinf
(
d_inv_sqrt
)]
=
0.
d_mat_inv_sqrt
=
sp
.
diags
(
d_inv_sqrt
)
return
adj
.
dot
(
d_mat_inv_sqrt
).
transpose
().
dot
(
d_mat_inv_sqrt
).
tocoo
()
def
get_adj_features_labels
(
data_dir
):
g
=
ds
.
GraphData
(
data_dir
)
nodes
=
g
.
get_all_nodes
(
0
)
nodes_list
=
nodes
.
tolist
()
row_tensor
=
g
.
get_node_feature
(
nodes_list
,
[
1
,
2
])
features
=
row_tensor
[
0
]
labels
=
row_tensor
[
1
]
nodes_num
=
labels
.
shape
[
0
]
class_num
=
labels
.
max
()
+
1
labels_onehot
=
np
.
eye
(
nodes_num
,
class_num
)[
labels
].
astype
(
np
.
float32
)
neighbor
=
g
.
get_all_neighbors
(
nodes_list
,
0
)
node_map
=
{
node_id
:
index
for
index
,
node_id
in
enumerate
(
nodes_list
)}
adj
=
np
.
zeros
([
nodes_num
,
nodes_num
],
dtype
=
np
.
float32
)
for
index
,
value
in
np
.
ndenumerate
(
neighbor
):
# The first column of neighbor is node_id, second column to last column are neighbors of the first column.
# So we only care index[1] > 1.
# If the node does not have that many neighbors, -1 is padded. So if value < 0, we will not deal with it.
if
value
>=
0
and
index
[
1
]
>
0
:
adj
[
node_map
[
neighbor
[
index
[
0
],
0
]],
node_map
[
value
]]
=
1
adj
=
sp
.
coo_matrix
(
adj
)
adj
=
adj
+
adj
.
T
.
multiply
(
adj
.
T
>
adj
)
+
sp
.
eye
(
nodes_num
)
nor_adj
=
normalize_adj
(
adj
)
nor_adj
=
np
.
array
(
nor_adj
.
todense
())
return
nor_adj
,
features
,
labels_onehot
def
get_mask
(
total
,
begin
,
end
):
mask
=
np
.
zeros
([
total
]).
astype
(
np
.
float32
)
mask
[
begin
:
end
]
=
1
return
mask
tests/st/gnn/gcn/src/gcn.py
已删除
100644 → 0
浏览文件 @
106f7980
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import
numpy
as
np
from
mindspore
import
nn
from
mindspore.common.parameter
import
ParameterTuple
from
mindspore.ops
import
composite
as
C
from
mindspore.ops
import
functional
as
F
from
mindspore.ops
import
operations
as
P
from
mindspore
import
Tensor
from
mindspore.nn.layer.activation
import
get_activation
from
src.metrics
import
Loss
,
Accuracy
def
glorot
(
shape
):
init_range
=
np
.
sqrt
(
6.0
/
(
shape
[
0
]
+
shape
[
1
]))
initial
=
np
.
random
.
uniform
(
-
init_range
,
init_range
,
shape
).
astype
(
np
.
float32
)
return
Tensor
(
initial
)
class
GraphConvolution
(
nn
.
Cell
):
def
__init__
(
self
,
feature_in_dim
,
feature_out_dim
,
dropout_ratio
=
None
,
activation
=
None
):
super
(
GraphConvolution
,
self
).
__init__
()
self
.
in_dim
=
feature_in_dim
self
.
out_dim
=
feature_out_dim
self
.
weight_init
=
glorot
([
self
.
out_dim
,
self
.
in_dim
])
self
.
fc
=
nn
.
Dense
(
self
.
in_dim
,
self
.
out_dim
,
weight_init
=
self
.
weight_init
,
has_bias
=
False
)
self
.
dropout_ratio
=
dropout_ratio
if
self
.
dropout_ratio
is
not
None
:
self
.
dropout
=
nn
.
Dropout
(
keep_prob
=
1
-
self
.
dropout_ratio
)
self
.
dropout_flag
=
self
.
dropout_ratio
is
not
None
self
.
activation
=
get_activation
(
activation
)
self
.
activation_flag
=
self
.
activation
is
not
None
self
.
matmul
=
P
.
MatMul
()
def
construct
(
self
,
adj
,
input_feature
):
dropout
=
input_feature
if
self
.
dropout_flag
:
dropout
=
self
.
dropout
(
dropout
)
fc
=
self
.
fc
(
dropout
)
output_feature
=
self
.
matmul
(
adj
,
fc
)
if
self
.
activation_flag
:
output_feature
=
self
.
activation
(
output_feature
)
return
output_feature
class
GCN
(
nn
.
Cell
):
def
__init__
(
self
,
config
,
adj
,
feature
,
output_dim
):
super
(
GCN
,
self
).
__init__
()
self
.
adj
=
Tensor
(
adj
)
self
.
feature
=
Tensor
(
feature
)
input_dim
=
feature
.
shape
[
1
]
self
.
layer0
=
GraphConvolution
(
input_dim
,
config
.
hidden1
,
activation
=
"relu"
,
dropout_ratio
=
config
.
dropout
)
self
.
layer1
=
GraphConvolution
(
config
.
hidden1
,
output_dim
,
dropout_ratio
=
None
)
def
construct
(
self
):
output0
=
self
.
layer0
(
self
.
adj
,
self
.
feature
)
output1
=
self
.
layer1
(
self
.
adj
,
output0
)
return
output1
class
LossAccuracyWrapper
(
nn
.
Cell
):
def
__init__
(
self
,
network
,
label
,
mask
,
weight_decay
):
super
(
LossAccuracyWrapper
,
self
).
__init__
()
self
.
network
=
network
self
.
loss
=
Loss
(
label
,
mask
,
weight_decay
,
network
.
trainable_params
()[
0
])
self
.
accuracy
=
Accuracy
(
label
,
mask
)
def
construct
(
self
):
preds
=
self
.
network
()
loss
=
self
.
loss
(
preds
)
accuracy
=
self
.
accuracy
(
preds
)
return
loss
,
accuracy
class
LossWrapper
(
nn
.
Cell
):
def
__init__
(
self
,
network
,
label
,
mask
,
weight_decay
):
super
(
LossWrapper
,
self
).
__init__
()
self
.
network
=
network
self
.
loss
=
Loss
(
label
,
mask
,
weight_decay
,
network
.
trainable_params
()[
0
])
def
construct
(
self
):
preds
=
self
.
network
()
loss
=
self
.
loss
(
preds
)
return
loss
class
TrainOneStepCell
(
nn
.
Cell
):
r
"""
Network training package class.
Wraps the network with an optimizer. The resulting Cell be trained without inputs.
Backward graph will be created in the construct function to do parameter updating. Different
parallel modes are available to run the training.
Args:
network (Cell): The training network.
optimizer (Cell): Optimizer for updating the weights.
sens (Number): The scaling number to be filled as the input of backpropagation. Default value is 1.0.
Outputs:
Tensor, a scalar Tensor with shape :math:`()`.
Examples:
>>> net = Net()
>>> loss_fn = nn.SoftmaxCrossEntropyWithLogits()
>>> optim = nn.Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9)
>>> loss_net = nn.WithLossCell(net, loss_fn)
>>> train_net = nn.TrainOneStepCell(loss_net, optim)
"""
def
__init__
(
self
,
network
,
optimizer
,
sens
=
1.0
):
super
(
TrainOneStepCell
,
self
).
__init__
(
auto_prefix
=
False
)
self
.
network
=
network
self
.
network
.
add_flags
(
defer_inline
=
True
)
self
.
weights
=
ParameterTuple
(
network
.
trainable_params
())
self
.
optimizer
=
optimizer
self
.
grad
=
C
.
GradOperation
(
'grad'
,
get_by_list
=
True
,
sens_param
=
True
)
self
.
sens
=
sens
def
construct
(
self
):
weights
=
self
.
weights
loss
=
self
.
network
()
sens
=
P
.
Fill
()(
P
.
DType
()(
loss
),
P
.
Shape
()(
loss
),
self
.
sens
)
grads
=
self
.
grad
(
self
.
network
,
weights
)(
sens
)
return
F
.
depend
(
loss
,
self
.
optimizer
(
grads
))
class
TrainNetWrapper
(
nn
.
Cell
):
def
__init__
(
self
,
network
,
label
,
mask
,
config
):
super
(
TrainNetWrapper
,
self
).
__init__
(
auto_prefix
=
True
)
self
.
network
=
network
loss_net
=
LossWrapper
(
network
,
label
,
mask
,
config
.
weight_decay
)
optimizer
=
nn
.
Adam
(
loss_net
.
trainable_params
(),
learning_rate
=
config
.
learning_rate
)
self
.
loss_train_net
=
TrainOneStepCell
(
loss_net
,
optimizer
)
self
.
accuracy
=
Accuracy
(
label
,
mask
)
def
construct
(
self
):
loss
=
self
.
loss_train_net
()
accuracy
=
self
.
accuracy
(
self
.
network
())
return
loss
,
accuracy
tests/st/gnn/gcn/src/metrics.py
已删除
100644 → 0
浏览文件 @
106f7980
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
from
mindspore
import
nn
from
mindspore
import
Tensor
from
mindspore.common
import
dtype
as
mstype
from
mindspore.ops
import
operations
as
P
class
Loss
(
nn
.
Cell
):
def
__init__
(
self
,
label
,
mask
,
weight_decay
,
param
):
super
(
Loss
,
self
).
__init__
()
self
.
label
=
Tensor
(
label
)
self
.
mask
=
Tensor
(
mask
)
self
.
loss
=
P
.
SoftmaxCrossEntropyWithLogits
()
self
.
one
=
Tensor
(
1.0
,
mstype
.
float32
)
self
.
zero
=
Tensor
(
0.0
,
mstype
.
float32
)
self
.
mean
=
P
.
ReduceMean
()
self
.
cast
=
P
.
Cast
()
self
.
l2_loss
=
P
.
L2Loss
()
self
.
reduce_sum
=
P
.
ReduceSum
()
self
.
weight_decay
=
weight_decay
self
.
param
=
param
def
construct
(
self
,
preds
):
param
=
self
.
l2_loss
(
self
.
param
)
loss
=
self
.
weight_decay
*
param
preds
=
self
.
cast
(
preds
,
mstype
.
float32
)
loss
=
loss
+
self
.
loss
(
preds
,
self
.
label
)[
0
]
mask
=
self
.
cast
(
self
.
mask
,
mstype
.
float32
)
mask_reduce
=
self
.
mean
(
mask
)
mask
=
mask
/
mask_reduce
loss
=
loss
*
mask
loss
=
self
.
mean
(
loss
)
return
loss
class
Accuracy
(
nn
.
Cell
):
def
__init__
(
self
,
label
,
mask
):
super
(
Accuracy
,
self
).
__init__
()
self
.
label
=
Tensor
(
label
)
self
.
mask
=
Tensor
(
mask
)
self
.
equal
=
P
.
Equal
()
self
.
argmax
=
P
.
Argmax
()
self
.
cast
=
P
.
Cast
()
self
.
mean
=
P
.
ReduceMean
()
def
construct
(
self
,
preds
):
preds
=
self
.
cast
(
preds
,
mstype
.
float32
)
correct_prediction
=
self
.
equal
(
self
.
argmax
(
preds
),
self
.
argmax
(
self
.
label
))
accuracy_all
=
self
.
cast
(
correct_prediction
,
mstype
.
float32
)
mask
=
self
.
cast
(
self
.
mask
,
mstype
.
float32
)
mask_reduce
=
self
.
mean
(
mask
)
mask
=
mask
/
mask_reduce
accuracy_all
*=
mask
return
self
.
mean
(
accuracy_all
)
tests/st/gnn/gcn/test_gcn.py
浏览文件 @
14868eb2
...
...
@@ -17,9 +17,9 @@ import time
import
pytest
import
numpy
as
np
from
mindspore
import
context
from
src.gcn
import
GCN
,
LossAccuracyWrapper
,
TrainNetWrapper
from
src.config
import
ConfigGCN
from
src.dataset
import
get_adj_features_labels
,
get_mask
from
model_zoo.gcn.
src.gcn
import
GCN
,
LossAccuracyWrapper
,
TrainNetWrapper
from
model_zoo.gcn.
src.config
import
ConfigGCN
from
model_zoo.gcn.
src.dataset
import
get_adj_features_labels
,
get_mask
DATA_DIR
=
'/home/workspace/mindspore_dataset/cora/cora_mr/cora_mr'
...
...
@@ -37,22 +37,23 @@ def test_gcn():
print
(
"test_gcn begin"
)
np
.
random
.
seed
(
SEED
)
context
.
set_context
(
mode
=
context
.
GRAPH_MODE
,
device_target
=
"Ascend"
,
save_graphs
=
Tru
e
)
device_target
=
"Ascend"
,
save_graphs
=
Fals
e
)
config
=
ConfigGCN
()
adj
,
feature
,
label
=
get_adj_features_labels
(
DATA_DIR
)
config
.
dropout
=
0.0
adj
,
feature
,
label_onehot
,
_
=
get_adj_features_labels
(
DATA_DIR
)
nodes_num
=
label
.
shape
[
0
]
nodes_num
=
label
_onehot
.
shape
[
0
]
train_mask
=
get_mask
(
nodes_num
,
0
,
TRAIN_NODE_NUM
)
eval_mask
=
get_mask
(
nodes_num
,
TRAIN_NODE_NUM
,
TRAIN_NODE_NUM
+
EVAL_NODE_NUM
)
test_mask
=
get_mask
(
nodes_num
,
nodes_num
-
TEST_NODE_NUM
,
nodes_num
)
class_num
=
label
.
shape
[
1
]
class_num
=
label
_onehot
.
shape
[
1
]
gcn_net
=
GCN
(
config
,
adj
,
feature
,
class_num
)
gcn_net
.
add_flags_recursive
(
fp16
=
True
)
eval_net
=
LossAccuracyWrapper
(
gcn_net
,
label
,
eval_mask
,
config
.
weight_decay
)
test_net
=
LossAccuracyWrapper
(
gcn_net
,
label
,
test_mask
,
config
.
weight_decay
)
train_net
=
TrainNetWrapper
(
gcn_net
,
label
,
train_mask
,
config
)
eval_net
=
LossAccuracyWrapper
(
gcn_net
,
label
_onehot
,
eval_mask
,
config
.
weight_decay
)
test_net
=
LossAccuracyWrapper
(
gcn_net
,
label
_onehot
,
test_mask
,
config
.
weight_decay
)
train_net
=
TrainNetWrapper
(
gcn_net
,
label
_onehot
,
train_mask
,
config
)
loss_list
=
[]
for
epoch
in
range
(
config
.
epochs
):
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录