Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PGL
提交
bccff22a
P
PGL
项目概览
PaddlePaddle
/
PGL
通知
76
Star
4
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
11
列表
看板
标记
里程碑
合并请求
1
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PGL
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
11
Issue
11
列表
看板
标记
里程碑
合并请求
1
合并请求
1
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
bccff22a
编写于
9月 24, 2020
作者:
Webbley
提交者:
GitHub
9月 24, 2020
浏览文件
操作
浏览文件
下载
差异文件
Merge pull request #5 from PaddlePaddle/main
merge
上级
9029259d
7e5da5f5
变更
4
隐藏空白更改
内联
并排
Showing
4 changed file
with
65 addition
and
61 deletion
+65
-61
ogb_examples/nodeproppred/unimp/main_arxiv.py
ogb_examples/nodeproppred/unimp/main_arxiv.py
+13
-17
ogb_examples/nodeproppred/unimp/main_product.py
ogb_examples/nodeproppred/unimp/main_product.py
+16
-20
ogb_examples/nodeproppred/unimp/main_protein.py
ogb_examples/nodeproppred/unimp/main_protein.py
+13
-12
pgl/utils/mp_reader.py
pgl/utils/mp_reader.py
+23
-12
未找到文件。
ogb_examples/nodeproppred/unimp/main_arxiv.py
浏览文件 @
bccff22a
...
@@ -20,7 +20,7 @@ evaluator = Evaluator(name='ogbn-arxiv')
...
@@ -20,7 +20,7 @@ evaluator = Evaluator(name='ogbn-arxiv')
def
get_config
():
def
get_config
():
parser
=
argparse
.
ArgumentParser
()
parser
=
argparse
.
ArgumentParser
()
##
基本模型参数
##
model_arg
model_group
=
parser
.
add_argument_group
(
'model_base_arg'
)
model_group
=
parser
.
add_argument_group
(
'model_base_arg'
)
model_group
.
add_argument
(
'--num_layers'
,
default
=
3
,
type
=
int
)
model_group
.
add_argument
(
'--num_layers'
,
default
=
3
,
type
=
int
)
model_group
.
add_argument
(
'--hidden_size'
,
default
=
128
,
type
=
int
)
model_group
.
add_argument
(
'--hidden_size'
,
default
=
128
,
type
=
int
)
...
@@ -28,7 +28,7 @@ def get_config():
...
@@ -28,7 +28,7 @@ def get_config():
model_group
.
add_argument
(
'--dropout'
,
default
=
0.3
,
type
=
float
)
model_group
.
add_argument
(
'--dropout'
,
default
=
0.3
,
type
=
float
)
model_group
.
add_argument
(
'--attn_dropout'
,
default
=
0
,
type
=
float
)
model_group
.
add_argument
(
'--attn_dropout'
,
default
=
0
,
type
=
float
)
## label
embedding模型参数
## label
_embed_arg
embed_group
=
parser
.
add_argument_group
(
'embed_arg'
)
embed_group
=
parser
.
add_argument_group
(
'embed_arg'
)
embed_group
.
add_argument
(
'--use_label_e'
,
action
=
'store_true'
)
embed_group
.
add_argument
(
'--use_label_e'
,
action
=
'store_true'
)
embed_group
.
add_argument
(
'--label_rate'
,
default
=
0.625
,
type
=
float
)
embed_group
.
add_argument
(
'--label_rate'
,
default
=
0.625
,
type
=
float
)
...
@@ -81,17 +81,17 @@ def eval_test(parser, program, model, test_exe, graph, y_true, split_idx):
...
@@ -81,17 +81,17 @@ def eval_test(parser, program, model, test_exe, graph, y_true, split_idx):
def
train_loop
(
parser
,
start_program
,
main_program
,
test_program
,
def
train_loop
(
parser
,
start_program
,
main_program
,
test_program
,
model
,
graph
,
label
,
split_idx
,
exe
,
run_id
,
wf
=
None
):
model
,
graph
,
label
,
split_idx
,
exe
,
run_id
,
wf
=
None
):
#
启动上文构建的训练器
#
build up training program
exe
.
run
(
start_program
)
exe
.
run
(
start_program
)
max_acc
=
0
#
最佳
test_acc
max_acc
=
0
#
best
test_acc
max_step
=
0
#
最佳test_acc 对应step
max_step
=
0
#
step for best test_acc
max_val_acc
=
0
#
最佳
val_acc
max_val_acc
=
0
#
best
val_acc
max_cor_acc
=
0
#
最佳val_acc对应test
_acc
max_cor_acc
=
0
#
test_acc for best val
_acc
max_cor_step
=
0
#
最佳val_acc对应step
max_cor_step
=
0
#
step for best val_acc
#
训练循环
#
training loop
for
epoch_id
in
tqdm
(
range
(
parser
.
epochs
)):
for
epoch_id
in
tqdm
(
range
(
parser
.
epochs
)):
#
运行训练器
#
start training
if
parser
.
use_label_e
:
if
parser
.
use_label_e
:
feed_dict
=
model
.
gw
.
to_feed
(
graph
)
feed_dict
=
model
.
gw
.
to_feed
(
graph
)
...
@@ -115,7 +115,7 @@ def train_loop(parser, start_program, main_program, test_program,
...
@@ -115,7 +115,7 @@ def train_loop(parser, start_program, main_program, test_program,
# print(loss[1][0])
# print(loss[1][0])
loss
=
loss
[
0
]
loss
=
loss
[
0
]
#
测试结果
#
eval result
result
=
eval_test
(
parser
,
test_program
,
model
,
exe
,
graph
,
label
,
split_idx
)
result
=
eval_test
(
parser
,
test_program
,
model
,
exe
,
graph
,
label
,
split_idx
)
train_acc
,
valid_acc
,
test_acc
=
result
train_acc
,
valid_acc
,
test_acc
=
result
...
@@ -191,11 +191,7 @@ if __name__ == '__main__':
...
@@ -191,11 +191,7 @@ if __name__ == '__main__':
test_prog
=
train_prog
.
clone
(
for_test
=
True
)
test_prog
=
train_prog
.
clone
(
for_test
=
True
)
model
.
train_program
()
model
.
train_program
()
adam_optimizer
=
optimizer_func
(
parser
.
lr
)
#optimizer
# ave_loss = train_program(pred_output)#训练程序
# lr, global_step= linear_warmup_decay(parser.lr, parser.epochs*0.1, parser.epochs)
# adam_optimizer = optimizer_func(lr)#训练优化函数
adam_optimizer
=
optimizer_func
(
parser
.
lr
)
#训练优化函数
adam_optimizer
.
minimize
(
model
.
avg_cost
)
adam_optimizer
.
minimize
(
model
.
avg_cost
)
exe
=
F
.
Executor
(
place
)
exe
=
F
.
Executor
(
place
)
...
@@ -206,4 +202,4 @@ if __name__ == '__main__':
...
@@ -206,4 +202,4 @@ if __name__ == '__main__':
total_test_acc
+=
train_loop
(
parser
,
startup_prog
,
train_prog
,
test_prog
,
model
,
total_test_acc
+=
train_loop
(
parser
,
startup_prog
,
train_prog
,
test_prog
,
model
,
graph
,
label
,
split_idx
,
exe
,
run_i
,
wf
)
graph
,
label
,
split_idx
,
exe
,
run_i
,
wf
)
wf
.
write
(
f
'average:
{
100
*
(
total_test_acc
/
parser
.
runs
):.
2
f
}
%'
)
wf
.
write
(
f
'average:
{
100
*
(
total_test_acc
/
parser
.
runs
):.
2
f
}
%'
)
wf
.
close
()
wf
.
close
()
\ No newline at end of file
ogb_examples/nodeproppred/unimp/main_product.py
浏览文件 @
bccff22a
...
@@ -22,14 +22,14 @@ evaluator = Evaluator(name='ogbn-products')
...
@@ -22,14 +22,14 @@ evaluator = Evaluator(name='ogbn-products')
def
get_config
():
def
get_config
():
parser
=
argparse
.
ArgumentParser
()
parser
=
argparse
.
ArgumentParser
()
##
采样参数
##
data_sampling_arg
data_group
=
parser
.
add_argument_group
(
'data_arg'
)
data_group
=
parser
.
add_argument_group
(
'data_arg'
)
data_group
.
add_argument
(
'--batch_size'
,
default
=
1500
,
type
=
int
)
data_group
.
add_argument
(
'--batch_size'
,
default
=
1500
,
type
=
int
)
data_group
.
add_argument
(
'--num_workers'
,
default
=
12
,
type
=
int
)
data_group
.
add_argument
(
'--num_workers'
,
default
=
12
,
type
=
int
)
data_group
.
add_argument
(
'--sizes'
,
default
=
[
10
,
10
,
10
],
type
=
int
,
nargs
=
'+'
)
data_group
.
add_argument
(
'--sizes'
,
default
=
[
10
,
10
,
10
],
type
=
int
,
nargs
=
'+'
)
data_group
.
add_argument
(
'--buf_size'
,
default
=
1000
,
type
=
int
)
data_group
.
add_argument
(
'--buf_size'
,
default
=
1000
,
type
=
int
)
##
基本模型参数
##
model_arg
model_group
=
parser
.
add_argument_group
(
'model_base_arg'
)
model_group
=
parser
.
add_argument_group
(
'model_base_arg'
)
model_group
.
add_argument
(
'--num_layers'
,
default
=
3
,
type
=
int
)
model_group
.
add_argument
(
'--num_layers'
,
default
=
3
,
type
=
int
)
model_group
.
add_argument
(
'--hidden_size'
,
default
=
128
,
type
=
int
)
model_group
.
add_argument
(
'--hidden_size'
,
default
=
128
,
type
=
int
)
...
@@ -37,7 +37,7 @@ def get_config():
...
@@ -37,7 +37,7 @@ def get_config():
model_group
.
add_argument
(
'--dropout'
,
default
=
0.3
,
type
=
float
)
model_group
.
add_argument
(
'--dropout'
,
default
=
0.3
,
type
=
float
)
model_group
.
add_argument
(
'--attn_dropout'
,
default
=
0
,
type
=
float
)
model_group
.
add_argument
(
'--attn_dropout'
,
default
=
0
,
type
=
float
)
## label
embedding模型参数
## label
_embed_arg
embed_group
=
parser
.
add_argument_group
(
'embed_arg'
)
embed_group
=
parser
.
add_argument_group
(
'embed_arg'
)
embed_group
.
add_argument
(
'--use_label_e'
,
action
=
'store_true'
)
embed_group
.
add_argument
(
'--use_label_e'
,
action
=
'store_true'
)
embed_group
.
add_argument
(
'--label_rate'
,
default
=
0.625
,
type
=
float
)
embed_group
.
add_argument
(
'--label_rate'
,
default
=
0.625
,
type
=
float
)
...
@@ -113,19 +113,19 @@ def eval_test(parser, test_p_list, model, test_exe, dataset, split_idx):
...
@@ -113,19 +113,19 @@ def eval_test(parser, test_p_list, model, test_exe, dataset, split_idx):
def
train_loop
(
parser
,
start_program
,
main_program
,
test_p_list
,
def
train_loop
(
parser
,
start_program
,
main_program
,
test_p_list
,
model
,
feat_init
,
place
,
dataset
,
split_idx
,
exe
,
run_id
,
wf
=
None
):
model
,
feat_init
,
place
,
dataset
,
split_idx
,
exe
,
run_id
,
wf
=
None
):
#
启动上文构建的训练器
#
build up training program
exe
.
run
(
start_program
)
exe
.
run
(
start_program
)
feat_init
(
place
)
feat_init
(
place
)
max_acc
=
0
#
最佳
test_acc
max_acc
=
0
#
best
test_acc
max_step
=
0
#
最佳test_acc 对应step
max_step
=
0
#
step for best test_acc
max_val_acc
=
0
#
最佳
val_acc
max_val_acc
=
0
#
best
val_acc
max_cor_acc
=
0
#
最佳val_acc对应test
_acc
max_cor_acc
=
0
#
test_acc for best val
_acc
max_cor_step
=
0
#
最佳val_acc对应step
max_cor_step
=
0
#
step for best val_acc
#
训练循环
#
training loop
for
epoch_id
in
range
(
parser
.
epochs
):
for
epoch_id
in
range
(
parser
.
epochs
):
#
运行训练器
#
start training
if
parser
.
use_label_e
:
if
parser
.
use_label_e
:
train_idx_temp
=
copy
.
deepcopy
(
split_idx
[
'train'
])
train_idx_temp
=
copy
.
deepcopy
(
split_idx
[
'train'
])
...
@@ -158,8 +158,7 @@ def train_loop(parser, start_program, main_program, test_p_list,
...
@@ -158,8 +158,7 @@ def train_loop(parser, start_program, main_program, test_p_list,
print
(
'acc: '
,
(
acc_num
/
unlabel_idx
.
shape
[
0
])
*
100
)
print
(
'acc: '
,
(
acc_num
/
unlabel_idx
.
shape
[
0
])
*
100
)
#测试结果
#eval result
# total=0.0
if
(
epoch_id
+
1
)
>=
50
and
(
epoch_id
+
1
)
%
10
==
0
:
if
(
epoch_id
+
1
)
>=
50
and
(
epoch_id
+
1
)
%
10
==
0
:
result
=
eval_test
(
parser
,
test_p_list
,
model
,
exe
,
dataset
,
split_idx
)
result
=
eval_test
(
parser
,
test_p_list
,
model
,
exe
,
dataset
,
split_idx
)
train_acc
,
valid_acc
,
test_acc
=
result
train_acc
,
valid_acc
,
test_acc
=
result
...
@@ -242,17 +241,14 @@ if __name__ == '__main__':
...
@@ -242,17 +241,14 @@ if __name__ == '__main__':
# test_prog=train_prog.clone(for_test=True)
# test_prog=train_prog.clone(for_test=True)
model
.
train_program
()
model
.
train_program
()
# ave_loss = train_program(pred_output)#训练程序
adam_optimizer
=
optimizer_func
(
parser
.
lr
)
#optimizer
# lr, global_step= linear_warmup_decay(0.01, 50, 500)
# adam_optimizer = optimizer_func(lr)#训练优化函数
adam_optimizer
=
optimizer_func
(
parser
.
lr
)
#训练优化函数
adam_optimizer
.
minimize
(
model
.
avg_cost
)
adam_optimizer
.
minimize
(
model
.
avg_cost
)
test_p_list
=
[]
test_p_list
=
[]
with
F
.
unique_name
.
guard
():
with
F
.
unique_name
.
guard
():
##
input层
##
build up eval program
test_p
=
F
.
Program
()
test_p
=
F
.
Program
()
with
F
.
program_guard
(
test_p
,
):
with
F
.
program_guard
(
test_p
,
):
gw_test
=
pgl
.
graph_wrapper
.
GraphWrapper
(
gw_test
=
pgl
.
graph_wrapper
.
GraphWrapper
(
...
@@ -281,7 +277,7 @@ if __name__ == '__main__':
...
@@ -281,7 +277,7 @@ if __name__ == '__main__':
with
F
.
program_guard
(
test_p
,
):
with
F
.
program_guard
(
test_p
,
):
gw_test
=
pgl
.
graph_wrapper
.
GraphWrapper
(
gw_test
=
pgl
.
graph_wrapper
.
GraphWrapper
(
name
=
"product_"
+
str
(
0
))
name
=
"product_"
+
str
(
0
))
# feature_batch=model.get_batch_feature(label_feature, test=True)
# 把图在CPU存起
# feature_batch=model.get_batch_feature(label_feature, test=True)
feature_batch
=
F
.
data
(
'hidden_node_feat'
,
feature_batch
=
F
.
data
(
'hidden_node_feat'
,
shape
=
[
None
,
model
.
num_heads
*
model
.
hidden_size
],
shape
=
[
None
,
model
.
num_heads
*
model
.
hidden_size
],
dtype
=
'float32'
)
dtype
=
'float32'
)
...
@@ -322,4 +318,4 @@ if __name__ == '__main__':
...
@@ -322,4 +318,4 @@ if __name__ == '__main__':
total_test_acc
+=
train_loop
(
parser
,
startup_prog
,
train_prog
,
test_p_list
,
model
,
feat_init
,
total_test_acc
+=
train_loop
(
parser
,
startup_prog
,
train_prog
,
test_p_list
,
model
,
feat_init
,
place
,
dataset
,
split_idx
,
exe
,
run_i
,
wf
)
place
,
dataset
,
split_idx
,
exe
,
run_i
,
wf
)
wf
.
write
(
f
'average:
{
100
*
(
total_test_acc
/
parser
.
runs
):.
2
f
}
%'
)
wf
.
write
(
f
'average:
{
100
*
(
total_test_acc
/
parser
.
runs
):.
2
f
}
%'
)
wf
.
close
()
wf
.
close
()
\ No newline at end of file
ogb_examples/nodeproppred/unimp/main_protein.py
浏览文件 @
bccff22a
...
@@ -23,7 +23,7 @@ evaluator = Evaluator(name='ogbn-proteins')
...
@@ -23,7 +23,7 @@ evaluator = Evaluator(name='ogbn-proteins')
def
get_config
():
def
get_config
():
parser
=
argparse
.
ArgumentParser
()
parser
=
argparse
.
ArgumentParser
()
##
基本模型参数
##
model_arg
model_group
=
parser
.
add_argument_group
(
'model_base_arg'
)
model_group
=
parser
.
add_argument_group
(
'model_base_arg'
)
model_group
.
add_argument
(
'--num_layers'
,
default
=
7
,
type
=
int
)
model_group
.
add_argument
(
'--num_layers'
,
default
=
7
,
type
=
int
)
model_group
.
add_argument
(
'--hidden_size'
,
default
=
64
,
type
=
int
)
model_group
.
add_argument
(
'--hidden_size'
,
default
=
64
,
type
=
int
)
...
@@ -31,7 +31,7 @@ def get_config():
...
@@ -31,7 +31,7 @@ def get_config():
model_group
.
add_argument
(
'--dropout'
,
default
=
0.1
,
type
=
float
)
model_group
.
add_argument
(
'--dropout'
,
default
=
0.1
,
type
=
float
)
model_group
.
add_argument
(
'--attn_dropout'
,
default
=
0
,
type
=
float
)
model_group
.
add_argument
(
'--attn_dropout'
,
default
=
0
,
type
=
float
)
## label
embedding模型参数
## label
_embed_arg
embed_group
=
parser
.
add_argument_group
(
'embed_arg'
)
embed_group
=
parser
.
add_argument_group
(
'embed_arg'
)
embed_group
.
add_argument
(
'--use_label_e'
,
action
=
'store_true'
)
embed_group
.
add_argument
(
'--use_label_e'
,
action
=
'store_true'
)
embed_group
.
add_argument
(
'--label_rate'
,
default
=
0.5
,
type
=
float
)
embed_group
.
add_argument
(
'--label_rate'
,
default
=
0.5
,
type
=
float
)
...
@@ -90,15 +90,16 @@ def eval_test(parser, program, model, test_exe, graph, y_true, split_idx):
...
@@ -90,15 +90,16 @@ def eval_test(parser, program, model, test_exe, graph, y_true, split_idx):
def
train_loop
(
parser
,
start_program
,
main_program
,
test_program
,
def
train_loop
(
parser
,
start_program
,
main_program
,
test_program
,
model
,
graph
,
label
,
split_idx
,
exe
,
run_id
,
wf
=
None
):
model
,
graph
,
label
,
split_idx
,
exe
,
run_id
,
wf
=
None
):
#
启动上文构建的训练器
#
build up training program
exe
.
run
(
start_program
)
exe
.
run
(
start_program
)
max_acc
=
0
# 最佳test_acc
max_acc
=
0
# best test_acc
max_step
=
0
# 最佳test_acc 对应step
max_step
=
0
# step for best test_acc
max_val_acc
=
0
# 最佳val_acc
max_val_acc
=
0
# best val_acc
max_cor_acc
=
0
# 最佳val_acc对应test_acc
max_cor_acc
=
0
# test_acc for best val_acc
max_cor_step
=
0
# 最佳val_acc对应step
max_cor_step
=
0
# step for best val_acc
#训练循环
#training loop
graph
.
node_feat
[
"label"
]
=
label
graph
.
node_feat
[
"label"
]
=
label
graph
.
node_feat
[
"nid"
]
=
np
.
arange
(
0
,
graph
.
num_nodes
)
graph
.
node_feat
[
"nid"
]
=
np
.
arange
(
0
,
graph
.
num_nodes
)
...
@@ -112,7 +113,7 @@ def train_loop(parser, start_program, main_program, test_program,
...
@@ -112,7 +113,7 @@ def train_loop(parser, start_program, main_program, test_program,
for
epoch_id
in
tqdm
(
range
(
parser
.
epochs
)):
for
epoch_id
in
tqdm
(
range
(
parser
.
epochs
)):
for
subgraph
in
random_partition
(
num_clusters
=
9
,
graph
=
graph
,
shuffle
=
True
):
for
subgraph
in
random_partition
(
num_clusters
=
9
,
graph
=
graph
,
shuffle
=
True
):
#
运行训练器
#
start training
if
parser
.
use_label_e
:
if
parser
.
use_label_e
:
feed_dict
=
model
.
gw
.
to_feed
(
subgraph
)
feed_dict
=
model
.
gw
.
to_feed
(
subgraph
)
sub_idx
=
set
(
subgraph
.
node_feat
[
"nid"
])
sub_idx
=
set
(
subgraph
.
node_feat
[
"nid"
])
...
@@ -139,7 +140,7 @@ def train_loop(parser, start_program, main_program, test_program,
...
@@ -139,7 +140,7 @@ def train_loop(parser, start_program, main_program, test_program,
fetch_list
=
[
model
.
avg_cost
])
fetch_list
=
[
model
.
avg_cost
])
loss
=
loss
[
0
]
loss
=
loss
[
0
]
#
测试结果
#
eval result
if
(
epoch_id
+
1
)
>
parser
.
epochs
*
0.9
:
if
(
epoch_id
+
1
)
>
parser
.
epochs
*
0.9
:
result
=
eval_test
(
parser
,
test_program
,
model
,
exe
,
graph
,
label
,
split_idx
)
result
=
eval_test
(
parser
,
test_program
,
model
,
exe
,
graph
,
label
,
split_idx
)
train_acc
,
valid_acc
,
test_acc
=
result
train_acc
,
valid_acc
,
test_acc
=
result
...
@@ -221,7 +222,7 @@ if __name__ == '__main__':
...
@@ -221,7 +222,7 @@ if __name__ == '__main__':
model
.
train_program
()
model
.
train_program
()
adam_optimizer
=
optimizer_func
(
parser
.
lr
)
#
训练优化函数
adam_optimizer
=
optimizer_func
(
parser
.
lr
)
#
optimizer
adam_optimizer
.
minimize
(
model
.
avg_cost
)
adam_optimizer
.
minimize
(
model
.
avg_cost
)
exe
=
F
.
Executor
(
place
)
exe
=
F
.
Executor
(
place
)
...
...
pgl/utils/mp_reader.py
浏览文件 @
bccff22a
...
@@ -27,24 +27,34 @@ import time
...
@@ -27,24 +27,34 @@ import time
import
paddle.fluid
as
fluid
import
paddle.fluid
as
fluid
from
multiprocessing
import
Queue
from
multiprocessing
import
Queue
import
threading
import
threading
from
collections
import
namedtuple
_np_serialized_data
=
namedtuple
(
"_np_serialized_data"
,
[
"value"
,
"shape"
,
"dtype"
])
def
serialize_data
(
data
):
def
serialize_data
(
data
):
"""serialize_data"""
"""serialize_data"""
if
data
is
None
:
if
data
is
None
:
return
None
return
None
return
numpy_serialize_data
(
data
)
#, ensure_ascii=False)
return
numpy_serialize_data
(
data
)
#, ensure_ascii=False)
def
index_iter
(
data
):
"""return indexing iter"""
if
isinstance
(
data
,
list
):
return
range
(
len
(
data
))
elif
isinstance
(
data
,
dict
):
return
data
.
keys
()
def
numpy_serialize_data
(
data
):
def
numpy_serialize_data
(
data
):
"""serialize_data"""
"""serialize_data"""
ret_data
=
{}
ret_data
=
copy
.
deepcopy
(
data
)
for
key
in
data
:
if
isinstance
(
data
[
key
],
np
.
ndarray
):
if
isinstance
(
ret_data
,
(
dict
,
list
)
):
ret_data
[
key
]
=
(
data
[
key
].
tobytes
(),
list
(
data
[
key
].
shape
),
for
key
in
index_iter
(
ret_data
):
"%s"
%
data
[
key
].
dtype
)
if
isinstance
(
ret_data
[
key
],
np
.
ndarray
):
else
:
ret_data
[
key
]
=
_np_serialized_data
(
value
=
ret_data
[
key
].
tobytes
(),
ret_data
[
key
]
=
data
[
key
]
shape
=
list
(
ret_data
[
key
].
shape
),
dtype
=
"%s"
%
ret_data
[
key
].
dtype
)
return
ret_data
return
ret_data
...
@@ -52,11 +62,12 @@ def numpy_deserialize_data(data):
...
@@ -52,11 +62,12 @@ def numpy_deserialize_data(data):
"""deserialize_data"""
"""deserialize_data"""
if
data
is
None
:
if
data
is
None
:
return
None
return
None
for
key
in
data
:
if
isinstance
(
data
[
key
],
tuple
):
if
isinstance
(
data
,
(
dict
,
list
)):
value
=
np
.
frombuffer
(
for
key
in
index_iter
(
data
):
data
[
key
][
0
],
dtype
=
data
[
key
][
2
]).
reshape
(
data
[
key
][
1
])
if
isinstance
(
data
[
key
],
_np_serialized_data
):
data
[
key
]
=
value
data
[
key
]
=
np
.
frombuffer
(
buffer
=
data
[
key
].
value
,
dtype
=
data
[
key
].
dtype
).
reshape
(
data
[
key
].
shape
)
return
data
return
data
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录