Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PGL
提交
8e6ea314
P
PGL
项目概览
PaddlePaddle
/
PGL
通知
76
Star
4
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
11
列表
看板
标记
里程碑
合并请求
1
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PGL
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
11
Issue
11
列表
看板
标记
里程碑
合并请求
1
合并请求
1
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
8e6ea314
编写于
5月 11, 2020
作者:
S
suweiyue
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
fix position id
上级
cbf4a1a3
变更
10
隐藏空白更改
内联
并排
Showing
10 changed file
with
101 addition
and
19 deletion
+101
-19
examples/erniesage/config/erniesage_v2_gpu.yaml
examples/erniesage/config/erniesage_v2_gpu.yaml
+4
-3
examples/erniesage/dataset/graph_reader.py
examples/erniesage/dataset/graph_reader.py
+2
-1
examples/erniesage/infer.py
examples/erniesage/infer.py
+1
-1
examples/erniesage/learner.py
examples/erniesage/learner.py
+1
-0
examples/erniesage/local_run.sh
examples/erniesage/local_run.sh
+1
-3
examples/erniesage/models/base.py
examples/erniesage/models/base.py
+45
-3
examples/erniesage/models/ernie_model/ernie.py
examples/erniesage/models/ernie_model/ernie.py
+6
-2
examples/erniesage/models/erniesage_v2.py
examples/erniesage/models/erniesage_v2.py
+34
-3
examples/erniesage/models/erniesage_v3.py
examples/erniesage/models/erniesage_v3.py
+0
-1
examples/erniesage/preprocessing/dump_graph.py
examples/erniesage/preprocessing/dump_graph.py
+7
-2
未找到文件。
examples/erniesage/config/erniesage_v2_gpu.yaml
浏览文件 @
8e6ea314
...
@@ -6,9 +6,9 @@ optimizer_type: "adam"
...
@@ -6,9 +6,9 @@ optimizer_type: "adam"
lr
:
0.00005
lr
:
0.00005
batch_size
:
32
batch_size
:
32
CPU_NUM
:
10
CPU_NUM
:
10
epoch
:
20
epoch
:
3
log_per_step
:
1
log_per_step
:
1
0
save_per_step
:
100
save_per_step
:
100
0
output_path
:
"
./output"
output_path
:
"
./output"
ckpt_path
:
"
./ernie_base_ckpt"
ckpt_path
:
"
./ernie_base_ckpt"
...
@@ -31,6 +31,7 @@ final_fc: true
...
@@ -31,6 +31,7 @@ final_fc: true
final_l2_norm
:
true
final_l2_norm
:
true
loss_type
:
"
hinge"
loss_type
:
"
hinge"
margin
:
0.3
margin
:
0.3
neg_type
:
"
random_neg"
# infer config ------
# infer config ------
infer_model
:
"
./output/last"
infer_model
:
"
./output/last"
...
...
examples/erniesage/dataset/graph_reader.py
浏览文件 @
8e6ea314
...
@@ -86,6 +86,7 @@ class GraphGenerator(BaseDataGenerator):
...
@@ -86,6 +86,7 @@ class GraphGenerator(BaseDataGenerator):
nodes
=
np
.
unique
(
np
.
concatenate
([
batch_src
,
batch_dst
,
batch_neg
],
0
))
nodes
=
np
.
unique
(
np
.
concatenate
([
batch_src
,
batch_dst
,
batch_neg
],
0
))
subgraphs
=
graphsage_sample
(
self
.
graph
,
nodes
,
self
.
samples
,
ignore_edges
=
ignore_edges
)
subgraphs
=
graphsage_sample
(
self
.
graph
,
nodes
,
self
.
samples
,
ignore_edges
=
ignore_edges
)
#subgraphs[0].reindex_to_parrent_nodes(subgraphs[0].nodes)
feed_dict
=
{}
feed_dict
=
{}
for
i
in
range
(
self
.
num_layers
):
for
i
in
range
(
self
.
num_layers
):
feed_dict
.
update
(
self
.
graph_wrappers
[
i
].
to_feed
(
subgraphs
[
i
]))
feed_dict
.
update
(
self
.
graph_wrappers
[
i
].
to_feed
(
subgraphs
[
i
]))
...
@@ -97,7 +98,7 @@ class GraphGenerator(BaseDataGenerator):
...
@@ -97,7 +98,7 @@ class GraphGenerator(BaseDataGenerator):
feed_dict
[
"user_index"
]
=
np
.
array
(
sub_src_idx
,
dtype
=
"int64"
)
feed_dict
[
"user_index"
]
=
np
.
array
(
sub_src_idx
,
dtype
=
"int64"
)
feed_dict
[
"item_index"
]
=
np
.
array
(
sub_dst_idx
,
dtype
=
"int64"
)
feed_dict
[
"item_index"
]
=
np
.
array
(
sub_dst_idx
,
dtype
=
"int64"
)
#
feed_dict["neg_item_index"] = np.array(sub_neg_idx, dtype="int64")
feed_dict
[
"neg_item_index"
]
=
np
.
array
(
sub_neg_idx
,
dtype
=
"int64"
)
feed_dict
[
"term_ids"
]
=
self
.
term_ids
[
subgraphs
[
0
].
node_feat
[
"index"
]]
feed_dict
[
"term_ids"
]
=
self
.
term_ids
[
subgraphs
[
0
].
node_feat
[
"index"
]]
return
feed_dict
return
feed_dict
...
...
examples/erniesage/infer.py
浏览文件 @
8e6ea314
...
@@ -72,7 +72,7 @@ def run_predict(py_reader,
...
@@ -72,7 +72,7 @@ def run_predict(py_reader,
for
batch_feed_dict
in
py_reader
():
for
batch_feed_dict
in
py_reader
():
batch
+=
1
batch
+=
1
batch_usr_feat
,
batch_ad_feat
,
batch_src_real_index
=
exe
.
run
(
batch_usr_feat
,
batch_ad_feat
,
_
,
batch_src_real_index
=
exe
.
run
(
program
,
program
,
feed
=
batch_feed_dict
,
feed
=
batch_feed_dict
,
fetch_list
=
model_dict
.
outputs
)
fetch_list
=
model_dict
.
outputs
)
...
...
examples/erniesage/learner.py
浏览文件 @
8e6ea314
...
@@ -193,6 +193,7 @@ class CollectiveLearner(Learner):
...
@@ -193,6 +193,7 @@ class CollectiveLearner(Learner):
def
optimize
(
self
,
loss
,
optimizer_type
,
lr
):
def
optimize
(
self
,
loss
,
optimizer_type
,
lr
):
optimizer
=
F
.
optimizer
.
Adam
(
learning_rate
=
lr
)
optimizer
=
F
.
optimizer
.
Adam
(
learning_rate
=
lr
)
dist_strategy
=
DistributedStrategy
()
dist_strategy
=
DistributedStrategy
()
dist_strategy
.
enable_sequential_execution
=
True
optimizer
=
cfleet
.
distributed_optimizer
(
optimizer
,
strategy
=
dist_strategy
)
optimizer
=
cfleet
.
distributed_optimizer
(
optimizer
,
strategy
=
dist_strategy
)
_
,
param_grads
=
optimizer
.
minimize
(
loss
,
F
.
default_startup_program
())
_
,
param_grads
=
optimizer
.
minimize
(
loss
,
F
.
default_startup_program
())
...
...
examples/erniesage/local_run.sh
浏览文件 @
8e6ea314
...
@@ -50,7 +50,6 @@ transpiler_local_train(){
...
@@ -50,7 +50,6 @@ transpiler_local_train(){
}
}
collective_local_train
(){
collective_local_train
(){
export
PATH
=
./python27-gcc482-gpu/bin/:
$PATH
echo
`
which python
`
echo
`
which python
`
python
-m
paddle.distributed.launch train.py
--conf
$config
python
-m
paddle.distributed.launch train.py
--conf
$config
python
-m
paddle.distributed.launch infer.py
--conf
$config
python
-m
paddle.distributed.launch infer.py
--conf
$config
...
@@ -58,8 +57,7 @@ collective_local_train(){
...
@@ -58,8 +57,7 @@ collective_local_train(){
eval
$(
parse_yaml
$config
)
eval
$(
parse_yaml
$config
)
python3 ./preprocessing/dump_graph.py
-i
$input_data
-o
$graph_path
--encoding
$encoding
\
python ./preprocessing/dump_graph.py
-i
$input_data
-o
$graph_path
--encoding
$encoding
-l
$max_seqlen
--vocab_file
$ernie_vocab_file
-l
$max_seqlen
--vocab_file
$ernie_vocab_file
if
[[
$learner_type
==
"cpu"
]]
;
then
if
[[
$learner_type
==
"cpu"
]]
;
then
transpiler_local_train
transpiler_local_train
...
...
examples/erniesage/models/base.py
浏览文件 @
8e6ea314
...
@@ -129,7 +129,9 @@ class BaseNet(object):
...
@@ -129,7 +129,9 @@ class BaseNet(object):
"user_index"
,
shape
=
[
None
],
dtype
=
"int64"
,
append_batch_size
=
False
)
"user_index"
,
shape
=
[
None
],
dtype
=
"int64"
,
append_batch_size
=
False
)
item_index
=
L
.
data
(
item_index
=
L
.
data
(
"item_index"
,
shape
=
[
None
],
dtype
=
"int64"
,
append_batch_size
=
False
)
"item_index"
,
shape
=
[
None
],
dtype
=
"int64"
,
append_batch_size
=
False
)
return
[
user_index
,
item_index
]
neg_item_index
=
L
.
data
(
"neg_item_index"
,
shape
=
[
None
],
dtype
=
"int64"
,
append_batch_size
=
False
)
return
[
user_index
,
item_index
,
neg_item_index
]
def
build_embedding
(
self
,
graph_wrappers
,
inputs
=
None
):
def
build_embedding
(
self
,
graph_wrappers
,
inputs
=
None
):
num_embed
=
int
(
np
.
load
(
os
.
path
.
join
(
self
.
config
.
graph_path
,
"num_nodes.npy"
)))
num_embed
=
int
(
np
.
load
(
os
.
path
.
join
(
self
.
config
.
graph_path
,
"num_nodes.npy"
)))
...
@@ -177,18 +179,58 @@ class BaseNet(object):
...
@@ -177,18 +179,58 @@ class BaseNet(object):
outputs
.
append
(
src_real_index
)
outputs
.
append
(
src_real_index
)
return
inputs
,
outputs
return
inputs
,
outputs
def
all_gather
(
X
):
trainer_id
=
int
(
os
.
getenv
(
"PADDLE_TRAINER_ID"
,
"0"
))
trainer_num
=
int
(
os
.
getenv
(
"PADDLE_TRAINERS_NUM"
,
"0"
))
if
trainer_num
==
1
:
copy_X
=
X
*
1
copy_X
.
stop_gradients
=
True
return
copy_X
Xs
=
[]
for
i
in
range
(
trainer_num
):
copy_X
=
X
*
1
copy_X
=
L
.
collective
.
_broadcast
(
copy_X
,
i
,
True
)
copy_X
.
stop_gradients
=
True
Xs
.
append
(
copy_X
)
if
len
(
Xs
)
>
1
:
Xs
=
L
.
concat
(
Xs
,
0
)
Xs
.
stop_gradients
=
True
else
:
Xs
=
Xs
[
0
]
return
Xs
class
BaseLoss
(
object
):
class
BaseLoss
(
object
):
def
__init__
(
self
,
config
):
def
__init__
(
self
,
config
):
self
.
config
=
config
self
.
config
=
config
def
__call__
(
self
,
outputs
):
def
__call__
(
self
,
outputs
):
user_feat
,
item_feat
=
outputs
[
0
],
outputs
[
1
]
user_feat
,
item_feat
,
neg_item_feat
=
outputs
[
0
],
outputs
[
1
],
outputs
[
2
]
loss_type
=
self
.
config
.
loss_type
loss_type
=
self
.
config
.
loss_type
if
self
.
config
.
neg_type
==
"batch_neg"
:
neg_item_feat
=
item_feat
# Calc Loss
# Calc Loss
if
self
.
config
.
loss_type
==
"hinge"
:
if
self
.
config
.
loss_type
==
"hinge"
:
pos
=
L
.
reduce_sum
(
user_feat
*
item_feat
,
-
1
,
keep_dim
=
True
)
# [B, 1]
pos
=
L
.
reduce_sum
(
user_feat
*
item_feat
,
-
1
,
keep_dim
=
True
)
# [B, 1]
neg
=
L
.
matmul
(
user_feat
,
item_feat
,
transpose_y
=
True
)
# [B, B]
neg
=
L
.
matmul
(
user_feat
,
neg_
item_feat
,
transpose_y
=
True
)
# [B, B]
loss
=
L
.
reduce_mean
(
L
.
relu
(
neg
-
pos
+
self
.
config
.
margin
))
loss
=
L
.
reduce_mean
(
L
.
relu
(
neg
-
pos
+
self
.
config
.
margin
))
elif
self
.
config
.
loss_type
==
"all_hinge"
:
pos
=
L
.
reduce_sum
(
user_feat
*
item_feat
,
-
1
,
keep_dim
=
True
)
# [B, 1]
all_pos
=
all_gather
(
pos
)
# [B * n, 1]
all_neg_item_feat
=
all_gather
(
neg_item_feat
)
# [B * n, 1]
all_user_feat
=
all_gather
(
user_feat
)
# [B * n, 1]
neg1
=
L
.
matmul
(
user_feat
,
all_neg_item_feat
,
transpose_y
=
True
)
# [B, B * n]
neg2
=
L
.
matmul
(
all_user_feat
,
neg_item_feat
,
transpose_y
=
True
)
# [B *n, B]
loss1
=
L
.
reduce_mean
(
L
.
relu
(
neg1
-
pos
+
self
.
config
.
margin
))
loss2
=
L
.
reduce_mean
(
L
.
relu
(
neg2
-
all_pos
+
self
.
config
.
margin
))
#loss = (loss1 + loss2) / 2
loss
=
loss1
+
loss2
elif
self
.
config
.
loss_type
==
"softmax"
:
elif
self
.
config
.
loss_type
==
"softmax"
:
pass
pass
# TODO
# TODO
...
...
examples/erniesage/models/ernie_model/ernie.py
浏览文件 @
8e6ea314
...
@@ -59,6 +59,8 @@ class ErnieModel(object):
...
@@ -59,6 +59,8 @@ class ErnieModel(object):
def
__init__
(
self
,
def
__init__
(
self
,
src_ids
,
src_ids
,
sentence_ids
,
sentence_ids
,
position_ids
=
None
,
input_mask
=
None
,
task_ids
=
None
,
task_ids
=
None
,
config
=
None
,
config
=
None
,
weight_sharing
=
True
,
weight_sharing
=
True
,
...
@@ -66,8 +68,10 @@ class ErnieModel(object):
...
@@ -66,8 +68,10 @@ class ErnieModel(object):
name
=
""
):
name
=
""
):
self
.
_set_config
(
config
,
name
,
weight_sharing
)
self
.
_set_config
(
config
,
name
,
weight_sharing
)
input_mask
=
self
.
_build_input_mask
(
src_ids
)
if
position_ids
is
None
:
position_ids
=
self
.
_build_position_ids
(
src_ids
)
position_ids
=
self
.
_build_position_ids
(
src_ids
)
if
input_mask
is
None
:
input_mask
=
self
.
_build_input_mask
(
src_ids
)
self
.
_build_model
(
src_ids
,
position_ids
,
sentence_ids
,
task_ids
,
self
.
_build_model
(
src_ids
,
position_ids
,
sentence_ids
,
task_ids
,
input_mask
)
input_mask
)
self
.
_debug_summary
(
input_mask
)
self
.
_debug_summary
(
input_mask
)
...
...
examples/erniesage/models/erniesage_v2.py
浏览文件 @
8e6ea314
...
@@ -3,8 +3,6 @@ import paddle.fluid as F
...
@@ -3,8 +3,6 @@ import paddle.fluid as F
import
paddle.fluid.layers
as
L
import
paddle.fluid.layers
as
L
from
models.base
import
BaseNet
,
BaseGNNModel
from
models.base
import
BaseNet
,
BaseGNNModel
from
models.ernie_model.ernie
import
ErnieModel
from
models.ernie_model.ernie
import
ErnieModel
from
models.ernie_model.ernie
import
ErnieGraphModel
from
models.ernie_model.ernie
import
ErnieConfig
class
ErnieSageV2
(
BaseNet
):
class
ErnieSageV2
(
BaseNet
):
...
@@ -16,19 +14,52 @@ class ErnieSageV2(BaseNet):
...
@@ -16,19 +14,52 @@ class ErnieSageV2(BaseNet):
return
inputs
+
[
term_ids
]
return
inputs
+
[
term_ids
]
def
gnn_layer
(
self
,
gw
,
feature
,
hidden_size
,
act
,
initializer
,
learning_rate
,
name
):
def
gnn_layer
(
self
,
gw
,
feature
,
hidden_size
,
act
,
initializer
,
learning_rate
,
name
):
def
build_position_ids
(
src_ids
,
dst_ids
):
src_shape
=
L
.
shape
(
src_ids
)
src_batch
=
src_shape
[
0
]
src_seqlen
=
src_shape
[
1
]
dst_seqlen
=
src_seqlen
-
1
# without cls
src_position_ids
=
L
.
reshape
(
L
.
range
(
0
,
src_seqlen
,
1
,
dtype
=
'int32'
),
[
1
,
src_seqlen
,
1
],
inplace
=
True
)
# [1, slot_seqlen, 1]
src_position_ids
=
L
.
expand
(
src_position_ids
,
[
src_batch
,
1
,
1
])
# [B, slot_seqlen * num_b, 1]
zero
=
L
.
fill_constant
([
1
],
dtype
=
'int64'
,
value
=
0
)
input_mask
=
L
.
cast
(
L
.
equal
(
src_ids
,
zero
),
"int32"
)
# assume pad id == 0 [B, slot_seqlen, 1]
src_pad_len
=
L
.
reduce_sum
(
input_mask
,
1
)
# [B, 1, 1]
dst_position_ids
=
L
.
reshape
(
L
.
range
(
src_seqlen
,
src_seqlen
+
dst_seqlen
,
1
,
dtype
=
'int32'
),
[
1
,
dst_seqlen
,
1
],
inplace
=
True
)
# [1, slot_seqlen, 1]
dst_position_ids
=
L
.
expand
(
dst_position_ids
,
[
src_batch
,
1
,
1
])
# [B, slot_seqlen, 1]
dst_position_ids
=
dst_position_ids
-
src_pad_len
# [B, slot_seqlen, 1]
position_ids
=
L
.
concat
([
src_position_ids
,
dst_position_ids
],
1
)
position_ids
=
L
.
cast
(
position_ids
,
'int64'
)
position_ids
.
stop_gradient
=
True
return
position_ids
def
ernie_send
(
src_feat
,
dst_feat
,
edge_feat
):
def
ernie_send
(
src_feat
,
dst_feat
,
edge_feat
):
"""doc"""
"""doc"""
# input_ids
cls
=
L
.
fill_constant_batch_size_like
(
src_feat
[
"term_ids"
],
[
-
1
,
1
,
1
],
"int64"
,
1
)
cls
=
L
.
fill_constant_batch_size_like
(
src_feat
[
"term_ids"
],
[
-
1
,
1
,
1
],
"int64"
,
1
)
src_ids
=
L
.
concat
([
cls
,
src_feat
[
"term_ids"
]],
1
)
src_ids
=
L
.
concat
([
cls
,
src_feat
[
"term_ids"
]],
1
)
dst_ids
=
dst_feat
[
"term_ids"
]
dst_ids
=
dst_feat
[
"term_ids"
]
# sent_ids
sent_ids
=
L
.
concat
([
L
.
zeros_like
(
src_ids
),
L
.
ones_like
(
dst_ids
)],
1
)
sent_ids
=
L
.
concat
([
L
.
zeros_like
(
src_ids
),
L
.
ones_like
(
dst_ids
)],
1
)
term_ids
=
L
.
concat
([
src_ids
,
dst_ids
],
1
)
term_ids
=
L
.
concat
([
src_ids
,
dst_ids
],
1
)
# position_ids
position_ids
=
build_position_ids
(
src_ids
,
dst_ids
)
term_ids
.
stop_gradient
=
True
term_ids
.
stop_gradient
=
True
sent_ids
.
stop_gradient
=
True
sent_ids
.
stop_gradient
=
True
ernie
=
ErnieModel
(
ernie
=
ErnieModel
(
term_ids
,
sent_ids
,
term_ids
,
sent_ids
,
position_ids
,
config
=
self
.
config
.
ernie_config
)
config
=
self
.
config
.
ernie_config
)
feature
=
ernie
.
get_pooled_output
()
feature
=
ernie
.
get_pooled_output
()
return
feature
return
feature
...
...
examples/erniesage/models/erniesage_v3.py
浏览文件 @
8e6ea314
...
@@ -18,7 +18,6 @@ import paddle.fluid.layers as L
...
@@ -18,7 +18,6 @@ import paddle.fluid.layers as L
from
models.base
import
BaseNet
,
BaseGNNModel
from
models.base
import
BaseNet
,
BaseGNNModel
from
models.ernie_model.ernie
import
ErnieModel
from
models.ernie_model.ernie
import
ErnieModel
from
models.ernie_model.ernie
import
ErnieGraphModel
from
models.ernie_model.ernie
import
ErnieGraphModel
from
models.ernie_model.ernie
import
ErnieConfig
from
models.message_passing
import
copy_send
from
models.message_passing
import
copy_send
...
...
examples/erniesage/preprocessing/dump_graph.py
浏览文件 @
8e6ea314
...
@@ -53,6 +53,7 @@ def dump_graph(args):
...
@@ -53,6 +53,7 @@ def dump_graph(args):
term_file
=
io
.
open
(
os
.
path
.
join
(
args
.
outpath
,
"terms.txt"
),
"w"
,
encoding
=
args
.
encoding
)
term_file
=
io
.
open
(
os
.
path
.
join
(
args
.
outpath
,
"terms.txt"
),
"w"
,
encoding
=
args
.
encoding
)
terms
=
[]
terms
=
[]
count
=
0
count
=
0
item_distribution
=
[]
with
io
.
open
(
args
.
inpath
,
encoding
=
args
.
encoding
)
as
f
:
with
io
.
open
(
args
.
inpath
,
encoding
=
args
.
encoding
)
as
f
:
edges
=
[]
edges
=
[]
...
@@ -66,6 +67,7 @@ def dump_graph(args):
...
@@ -66,6 +67,7 @@ def dump_graph(args):
str2id
[
s
]
=
count
str2id
[
s
]
=
count
count
+=
1
count
+=
1
term_file
.
write
(
str
(
col_idx
)
+
"
\t
"
+
col
+
"
\n
"
)
term_file
.
write
(
str
(
col_idx
)
+
"
\t
"
+
col
+
"
\n
"
)
item_distribution
.
append
(
0
)
slots
.
append
(
str2id
[
s
])
slots
.
append
(
str2id
[
s
])
...
@@ -74,6 +76,7 @@ def dump_graph(args):
...
@@ -74,6 +76,7 @@ def dump_graph(args):
neg_samples
.
append
(
slots
[
2
:])
neg_samples
.
append
(
slots
[
2
:])
edges
.
append
((
src
,
dst
))
edges
.
append
((
src
,
dst
))
edges
.
append
((
dst
,
src
))
edges
.
append
((
dst
,
src
))
item_distribution
[
dst
]
+=
1
term_file
.
close
()
term_file
.
close
()
edges
=
np
.
array
(
edges
,
dtype
=
"int64"
)
edges
=
np
.
array
(
edges
,
dtype
=
"int64"
)
...
@@ -82,12 +85,14 @@ def dump_graph(args):
...
@@ -82,12 +85,14 @@ def dump_graph(args):
log
.
info
(
"building graph..."
)
log
.
info
(
"building graph..."
)
graph
=
pgl
.
graph
.
Graph
(
num_nodes
=
num_nodes
,
edges
=
edges
)
graph
=
pgl
.
graph
.
Graph
(
num_nodes
=
num_nodes
,
edges
=
edges
)
indegree
=
graph
.
indegree
()
indegree
=
graph
.
indegree
()
graph
.
indegree
()
graph
.
outdegree
()
graph
.
outdegree
()
graph
.
dump
(
args
.
outpath
)
graph
.
dump
(
args
.
outpath
)
# dump alias sample table
# dump alias sample table
sqrt_indegree
=
np
.
sqrt
(
indegree
)
item_distribution
=
np
.
array
(
item_distribution
)
distribution
=
1.
*
sqrt_indegree
/
sqrt_indegree
.
sum
()
item_distribution
=
np
.
sqrt
(
item_distribution
)
distribution
=
1.
*
item_distribution
/
item_distribution
.
sum
()
alias
,
events
=
alias_sample_build_table
(
distribution
)
alias
,
events
=
alias_sample_build_table
(
distribution
)
np
.
save
(
os
.
path
.
join
(
args
.
outpath
,
"alias.npy"
),
alias
)
np
.
save
(
os
.
path
.
join
(
args
.
outpath
,
"alias.npy"
),
alias
)
np
.
save
(
os
.
path
.
join
(
args
.
outpath
,
"events.npy"
),
events
)
np
.
save
(
os
.
path
.
join
(
args
.
outpath
,
"events.npy"
),
events
)
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录