PaddlePaddle / PGL

Commit f78f4639 (unverified)
Authored May 12, 2020 by Weiyue Su; committed via GitHub on May 12, 2020

Merge pull request #73 from WeiyueSu/erniesage

fix position id

Parents: 05276913, 2839bccc

Showing 13 changed files with 139 additions and 97 deletions (+139 −97)
Changed files:

examples/erniesage/config/erniesage_v2_cpu.yaml                 +4   -3
examples/erniesage/config/erniesage_v2_gpu.yaml                 +4   -3
examples/erniesage/dataset/graph_reader.py                      +2   -1
examples/erniesage/infer.py                                     +2   -1
examples/erniesage/job.sh                                       +0   -45
examples/erniesage/learner.py                                   +14  -3
examples/erniesage/local_run.sh                                 +4   -6
examples/erniesage/models/base.py                               +45  -3
examples/erniesage/models/ernie_model/ernie.py                  +6   -2
examples/erniesage/models/ernie_model/transformer_encoder.py    +17  -24
examples/erniesage/models/erniesage_v2.py                       +34  -3
examples/erniesage/models/erniesage_v3.py                       +0   -1
examples/erniesage/preprocessing/dump_graph.py                  +7   -2
examples/erniesage/config/erniesage_v2_cpu.yaml (view file @ f78f4639)

@@ -4,9 +4,9 @@
 learner_type: "cpu"
 optimizer_type: "adam"
 lr: 0.00005
-batch_size: 2
-CPU_NUM: 10
-epoch: 20
+batch_size: 4
+CPU_NUM: 16
+epoch: 3
 log_per_step: 1
 save_per_step: 100
 output_path: "./output"

@@ -31,6 +31,7 @@ final_fc: true
 final_l2_norm: true
 loss_type: "hinge"
 margin: 0.3
+neg_type: "random_neg"

 # infer config ------
 infer_model: "./output/last"
examples/erniesage/config/erniesage_v2_gpu.yaml (view file @ f78f4639)

@@ -6,9 +6,9 @@ optimizer_type: "adam"
 lr: 0.00005
 batch_size: 32
 CPU_NUM: 10
-epoch: 20
-log_per_step: 1
-save_per_step: 100
+epoch: 3
+log_per_step: 10
+save_per_step: 1000
 output_path: "./output"
 ckpt_path: "./ernie_base_ckpt"

@@ -31,6 +31,7 @@ final_fc: true
 final_l2_norm: true
 loss_type: "hinge"
 margin: 0.3
+neg_type: "random_neg"

 # infer config ------
 infer_model: "./output/last"
examples/erniesage/dataset/graph_reader.py (view file @ f78f4639)

@@ -86,6 +86,7 @@ class GraphGenerator(BaseDataGenerator):
         nodes = np.unique(np.concatenate([batch_src, batch_dst, batch_neg], 0))
         subgraphs = graphsage_sample(self.graph, nodes, self.samples, ignore_edges=ignore_edges)
+        #subgraphs[0].reindex_to_parrent_nodes(subgraphs[0].nodes)
         feed_dict = {}
         for i in range(self.num_layers):
             feed_dict.update(self.graph_wrappers[i].to_feed(subgraphs[i]))

@@ -97,7 +98,7 @@ class GraphGenerator(BaseDataGenerator):
         feed_dict["user_index"] = np.array(sub_src_idx, dtype="int64")
         feed_dict["item_index"] = np.array(sub_dst_idx, dtype="int64")
-        # feed_dict["neg_item_index"] = np.array(sub_neg_idx, dtype="int64")
+        feed_dict["neg_item_index"] = np.array(sub_neg_idx, dtype="int64")
         feed_dict["term_ids"] = self.term_ids[subgraphs[0].node_feat["index"]]
         return feed_dict
examples/erniesage/infer.py (view file @ f78f4639)

@@ -72,7 +72,7 @@ def run_predict(py_reader,
     for batch_feed_dict in py_reader():
         batch += 1
-        batch_usr_feat, batch_ad_feat, batch_src_real_index = exe.run(
+        batch_usr_feat, batch_ad_feat, _, batch_src_real_index = exe.run(
             program,
             feed=batch_feed_dict,
             fetch_list=model_dict.outputs)

@@ -183,5 +183,6 @@ if __name__ == "__main__":
     parser.add_argument("--conf", type=str, default="./config.yaml")
     args = parser.parse_args()
     config = edict(yaml.load(open(args.conf), Loader=yaml.FullLoader))
+    config.loss_type = "hinge"
     print(config)
     main(config)
examples/erniesage/job.sh (deleted, file mode 100644 → 0; view file @ 05276913)

-unset http_proxy https_proxy
-set -x
-mode=${1:-local}
-config=${2:-"./config.yaml"}
-
-function parse_yaml {
-   local prefix=$2
-   local s='[[:space:]]*' w='[a-zA-Z0-9_]*' fs=$(echo @|tr @ '\034')
-   sed -ne "s|^\($s\):|\1|" \
-        -e "s|^\($s\)\($w\)$s:$s[\"']\(.*\)[\"']$s\$|\1$fs\2$fs\3|p" \
-        -e "s|^\($s\)\($w\)$s:$s\(.*\)$s\$|\1$fs\2$fs\3|p" $1 |
-   awk -F$fs '{
-      indent = length($1)/2;
-      vname[indent] = $2;
-      for (i in vname) {if (i > indent) {delete vname[i]}}
-      if (length($3) > 0) {
-         vn=""; for (i=0; i<indent; i++) {vn=(vn)(vname[i])("_")}
-         printf("%s%s%s=\"%s\"\n", "'$prefix'",vn, $2, $3);
-      }
-   }'
-}
-
-eval $(parse_yaml $config)
-
-export CPU_NUM=$CPU_NUM
-export FLAGS_rpc_deadline=3000000
-export FLAGS_rpc_retry_times=1000
-
-if [[ $async_mode == "True" ]]; then
-    echo "async_mode is True"
-else
-    export FLAGS_communicator_send_queue_size=1
-    export FLAGS_communicator_min_send_grad_num_before_recv=0
-    export FLAGS_communicator_max_merge_var_num=1 # important!
-    export FLAGS_communicator_merge_sparse_grad=0
-fi
-export FLAGS_communicator_recv_wait_times=5000000
-
-mkdir -p output
-
-python ./train.py --conf $config
-
-if [[ $TRAINING_ROLE == "TRAINER" ]]; then
-    python ./infer.py --conf $config
-fi
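Note: job.sh is removed because local_run.sh now launches train.py/infer.py directly (see below). Its parse_yaml helper flattened the YAML config into shell variables, joining nested keys with underscores, so that eval $(parse_yaml $config) exposed e.g. $lr and $batch_size; local_run.sh still calls parse_yaml, so it presumably keeps its own copy. A rough Python equivalent of that flattening, as a sketch only (parse_yaml_vars is a hypothetical name, not part of the repo):

import yaml

def parse_yaml_vars(path, prefix=""):
    # Flatten {"a": {"b": 1}} into {"a_b": 1}, mirroring the awk output format.
    with open(path) as f:
        tree = yaml.safe_load(f)
    flat = {}

    def walk(node, name):
        if isinstance(node, dict):
            for key, value in node.items():
                walk(value, name + key + "_")
        else:
            flat[prefix + name[:-1]] = node  # drop the trailing "_"

    walk(tree, "")
    return flat

# e.g. parse_yaml_vars("./config.yaml") -> {"lr": 5e-05, "batch_size": 4, ...}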
examples/erniesage/learner.py (view file @ f78f4639)

@@ -26,6 +26,17 @@ from paddle.fluid.incubate.fleet.collective import fleet as cfleet
 from paddle.fluid.incubate.fleet.parameter_server.distribute_transpiler import fleet as tfleet
 import paddle.fluid.incubate.fleet.base.role_maker as role_maker
 from tensorboardX import SummaryWriter
 from paddle.fluid.transpiler.distribute_transpiler import DistributedMode
+from paddle.fluid.incubate.fleet.parameter_server.distribute_transpiler.distributed_strategy import TrainerRuntimeConfig
+
+# hack it!
+base_get_communicator_flags = TrainerRuntimeConfig.get_communicator_flags
+def get_communicator_flags(self):
+    flag_dict = base_get_communicator_flags(self)
+    flag_dict['communicator_max_merge_var_num'] = str(1)
+    flag_dict['communicator_send_queue_size'] = str(1)
+    return flag_dict
+TrainerRuntimeConfig.get_communicator_flags = get_communicator_flags
+
 class Learner(object):

@@ -132,8 +143,6 @@ class TranspilerLearner(Learner):
         self.model = model

     def optimize(self, loss, optimizer_type, lr):
-        strategy = DistributeTranspilerConfig()
-        strategy.sync_mode = False
         log.info('learning rate:%f' % lr)
         if optimizer_type == "sgd":
             optimizer = F.optimizer.SGD(learning_rate=lr)

@@ -143,7 +152,8 @@ class TranspilerLearner(Learner):
         else:
             raise ValueError("Unknown Optimizer %s" % optimizer_type)
         #create the DistributeTranspiler configure
-        optimizer = tfleet.distributed_optimizer(optimizer, strategy)
+        self.strategy = StrategyFactory.create_sync_strategy()
+        optimizer = tfleet.distributed_optimizer(optimizer, self.strategy)
         optimizer.minimize(loss)

     def init_and_run_ps_worker(self, ckpt_path):

@@ -193,6 +203,7 @@ class CollectiveLearner(Learner):
     def optimize(self, loss, optimizer_type, lr):
         optimizer = F.optimizer.Adam(learning_rate=lr)
         dist_strategy = DistributedStrategy()
+        dist_strategy.enable_sequential_execution = True
         optimizer = cfleet.distributed_optimizer(optimizer, strategy=dist_strategy)
         _, param_grads = optimizer.minimize(loss, F.default_startup_program())
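A note on the "hack it!" block above: it wraps the original TrainerRuntimeConfig.get_communicator_flags and pins communicator_max_merge_var_num and communicator_send_queue_size to 1 regardless of what the base implementation computes. Read together with the switch to StrategyFactory.create_sync_strategy() below it, this looks like a way to force fully synchronous communicator behavior; the commit itself does not state the motivation, so treat that reading as an inference from the diff.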
examples/erniesage/local_run.sh (view file @ f78f4639)

@@ -36,7 +36,7 @@ transpiler_local_train(){
     for((i=0;i<${PADDLE_PSERVERS_NUM};i++))
     do
         echo "start ps server: ${i}"
-        TRAINING_ROLE="PSERVER" PADDLE_TRAINER_ID=${i} sh job.sh local $config \
+        TRAINING_ROLE="PSERVER" PADDLE_TRAINER_ID=${i} python ./train.py --conf $config \
             &> $BASE/pserver.$i.log &
         echo $! >> job_id
     done

@@ -44,13 +44,12 @@ transpiler_local_train(){
     for((j=0;j<${PADDLE_TRAINERS_NUM};j++))
     do
         echo "start ps work: ${j}"
-        TRAINING_ROLE="TRAINER" PADDLE_TRAINER_ID=${j} sh job.sh local $config \
-        echo $! >> job_id
+        TRAINING_ROLE="TRAINER" PADDLE_TRAINER_ID=${j} python ./train.py --conf $config
+        TRAINING_ROLE="TRAINER" PADDLE_TRAINER_ID=${j} python ./infer.py --conf $config
     done
 }

 collective_local_train(){
     export PATH=./python27-gcc482-gpu/bin/:$PATH
     echo `which python`
     python -m paddle.distributed.launch train.py --conf $config
     python -m paddle.distributed.launch infer.py --conf $config

@@ -58,8 +57,7 @@ collective_local_train(){
 eval $(parse_yaml $config)

-python3 ./preprocessing/dump_graph.py -i $input_data -o $graph_path --encoding $encoding \
-    -l $max_seqlen --vocab_file $ernie_vocab_file
+python ./preprocessing/dump_graph.py -i $input_data -o $graph_path --encoding $encoding -l $max_seqlen --vocab_file $ernie_vocab_file

 if [[ $learner_type == "cpu" ]]; then
     transpiler_local_train
examples/erniesage/models/base.py (view file @ f78f4639)

@@ -129,7 +129,9 @@ class BaseNet(object):
             "user_index", shape=[None], dtype="int64", append_batch_size=False)
         item_index = L.data(
             "item_index", shape=[None], dtype="int64", append_batch_size=False)
-        return [user_index, item_index]
+        neg_item_index = L.data(
+            "neg_item_index", shape=[None], dtype="int64", append_batch_size=False)
+        return [user_index, item_index, neg_item_index]

     def build_embedding(self, graph_wrappers, inputs=None):
         num_embed = int(np.load(os.path.join(self.config.graph_path, "num_nodes.npy")))

@@ -177,18 +179,58 @@ class BaseNet(object):
         outputs.append(src_real_index)
         return inputs, outputs

+def all_gather(X):
+    trainer_id = int(os.getenv("PADDLE_TRAINER_ID", "0"))
+    trainer_num = int(os.getenv("PADDLE_TRAINERS_NUM", "0"))
+    if trainer_num == 1:
+        copy_X = X * 1
+        copy_X.stop_gradients = True
+        return copy_X
+
+    Xs = []
+    for i in range(trainer_num):
+        copy_X = X * 1
+        copy_X = L.collective._broadcast(copy_X, i, True)
+        copy_X.stop_gradients = True
+        Xs.append(copy_X)
+
+    if len(Xs) > 1:
+        Xs = L.concat(Xs, 0)
+        Xs.stop_gradients = True
+    else:
+        Xs = Xs[0]
+    return Xs
+
 class BaseLoss(object):
     def __init__(self, config):
         self.config = config

     def __call__(self, outputs):
-        user_feat, item_feat = outputs[0], outputs[1]
+        user_feat, item_feat, neg_item_feat = outputs[0], outputs[1], outputs[2]
         loss_type = self.config.loss_type
+        if self.config.neg_type == "batch_neg":
+            neg_item_feat = item_feat
         # Calc Loss
         if self.config.loss_type == "hinge":
             pos = L.reduce_sum(user_feat * item_feat, -1, keep_dim=True) # [B, 1]
-            neg = L.matmul(user_feat, item_feat, transpose_y=True) # [B, B]
+            neg = L.matmul(user_feat, neg_item_feat, transpose_y=True) # [B, B]
             loss = L.reduce_mean(L.relu(neg - pos + self.config.margin))
+        elif self.config.loss_type == "all_hinge":
+            pos = L.reduce_sum(user_feat * item_feat, -1, keep_dim=True) # [B, 1]
+            all_pos = all_gather(pos) # [B * n, 1]
+            all_neg_item_feat = all_gather(neg_item_feat) # [B * n, 1]
+            all_user_feat = all_gather(user_feat) # [B * n, 1]
+            neg1 = L.matmul(user_feat, all_neg_item_feat, transpose_y=True) # [B, B * n]
+            neg2 = L.matmul(all_user_feat, neg_item_feat, transpose_y=True) # [B * n, B]
+            loss1 = L.reduce_mean(L.relu(neg1 - pos + self.config.margin))
+            loss2 = L.reduce_mean(L.relu(neg2 - all_pos + self.config.margin))
+            #loss = (loss1 + loss2) / 2
+            loss = loss1 + loss2
         elif self.config.loss_type == "softmax":
             pass
             # TODO
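For intuition, the two hinge variants above can be written in plain NumPy (a sketch only, not PGL code; shapes follow the comments in the diff, with B the per-trainer batch size and n the trainer count). "all_hinge" scores each local user against negatives gathered from all trainers and each gathered user against the local negatives, then sums both margin losses:

import numpy as np

def hinge(user, item, neg_item, margin=0.3):
    # user, item, neg_item: [B, D]; in-batch negatives via a [B, B] score matrix
    pos = np.sum(user * item, -1, keepdims=True)          # [B, 1]
    neg = user @ neg_item.T                               # [B, B]
    return np.mean(np.maximum(neg - pos + margin, 0.0))

def all_hinge(user, item, neg_item, all_user, all_neg_item, all_pos, margin=0.3):
    # all_* arguments stand in for the all_gather results: [B*n, ...] stacks
    pos = np.sum(user * item, -1, keepdims=True)          # [B, 1]
    neg1 = user @ all_neg_item.T                          # [B, B*n]
    neg2 = all_user @ neg_item.T                          # [B*n, B]
    loss1 = np.mean(np.maximum(neg1 - pos + margin, 0.0))
    loss2 = np.mean(np.maximum(neg2 - all_pos + margin, 0.0))
    return loss1 + loss2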
examples/erniesage/models/ernie_model/ernie.py (view file @ f78f4639)

@@ -59,6 +59,8 @@ class ErnieModel(object):
     def __init__(self,
                  src_ids,
                  sentence_ids,
+                 position_ids=None,
+                 input_mask=None,
                  task_ids=None,
                  config=None,
                  weight_sharing=True,

@@ -66,8 +68,10 @@ class ErnieModel(object):
                  name=""):
         self._set_config(config, name, weight_sharing)
-        input_mask = self._build_input_mask(src_ids)
-        position_ids = self._build_position_ids(src_ids)
+        if position_ids is None:
+            position_ids = self._build_position_ids(src_ids)
+        if input_mask is None:
+            input_mask = self._build_input_mask(src_ids)
         self._build_model(src_ids, position_ids, sentence_ids, task_ids, input_mask)
         self._debug_summary(input_mask)
examples/erniesage/models/ernie_model/transformer_encoder.py (view file @ f78f4639)

@@ -19,8 +19,6 @@ from contextlib import contextmanager
 import paddle.fluid as fluid
 import paddle.fluid.layers as L
 import paddle.fluid.layers as layers
-#import propeller.paddle as propeller
-#from propeller import log
 #determin this at the begining
 to_3d = lambda a: a # will change later

@@ -85,7 +83,7 @@ def multi_head_attention(queries,
         # The value 0 in shape attr means copying the corresponding dimension
         # size of the input as the output dimension size.
         reshaped = layers.reshape(
             x=x, shape=[0, 0, n_head, hidden_size // n_head], inplace=True)
         # permuate the dimensions into:
         # [batch_size, n_head, max_sequence_len, hidden_size_per_head]

@@ -262,7 +260,6 @@ def encoder_layer(enc_input,
     with the post_process_layer to add residual connection, layer normalization
     and droput.
     """
-    #L.Print(L.reduce_mean(enc_input), message='1')
     attn_output, ctx_multiheads_attn = multi_head_attention(
         pre_process_layer(
             enc_input,

@@ -279,7 +276,6 @@
         attention_dropout,
         param_initializer=param_initializer,
         name=name + '_multi_head_att')
-    #L.Print(L.reduce_mean(attn_output), message='1')
     attn_output = post_process_layer(
         enc_input,
         attn_output,

@@ -287,7 +283,6 @@
         prepostprocess_dropout,
         name=name + '_post_att')
-    #L.Print(L.reduce_mean(attn_output), message='2')
     ffd_output = positionwise_feed_forward(
         pre_process_layer(
             attn_output,

@@ -300,14 +295,12 @@
         hidden_act,
         param_initializer=param_initializer,
         name=name + '_ffn')
-    #L.Print(L.reduce_mean(ffd_output), message='3')
     ret = post_process_layer(
         attn_output,
         ffd_output,
         postprocess_cmd,
         prepostprocess_dropout,
         name=name + '_post_ffn')
-    #L.Print(L.reduce_mean(ret), message='4')
     return ret, ctx_multiheads_attn, ffd_output

@@ -374,7 +367,7 @@ def encoder(enc_input,
         encoder_layer.
     """
-    # global to_2d, to_3d #, batch, seqlen, dynamic_dim
+    global to_2d, to_3d #, batch, seqlen, dynamic_dim
     d_shape = L.shape(input_mask)
     pad_idx = build_pad_idx(input_mask)
     attn_bias = build_attn_bias(input_mask, n_head, enc_input.dtype)

@@ -391,14 +384,14 @@
     # if attn_bias.dtype != enc_input.dtype:
     #     attn_bias = L.cast(attn_bias, enc_input.dtype)
-    # def to_2d(t_3d):
-    #     t_2d = L.gather_nd(t_3d, pad_idx)
-    #     return t_2d
+    def to_2d(t_3d):
+        t_2d = L.gather_nd(t_3d, pad_idx)
+        return t_2d
-    # def to_3d(t_2d):
-    #     t_3d = L.scatter_nd(
-    #         pad_idx, t_2d, shape=[d_shape[0], d_shape[1], d_model])
-    #     return t_3d
+    def to_3d(t_2d):
+        t_3d = L.scatter_nd(
+            pad_idx, t_2d, shape=[d_shape[0], d_shape[1], d_model])
+        return t_3d
     enc_input = to_2d(enc_input)
     all_hidden = []

@@ -456,7 +449,7 @@ def graph_encoder(enc_input,
         encoder_layer.
     """
-    # global to_2d, to_3d #, batch, seqlen, dynamic_dim
+    global to_2d, to_3d #, batch, seqlen, dynamic_dim
     d_shape = L.shape(input_mask)
     pad_idx = build_pad_idx(input_mask)
     attn_bias = build_graph_attn_bias(input_mask, n_head, enc_input.dtype, slot_seqlen)

@@ -474,14 +467,14 @@
     # if attn_bias.dtype != enc_input.dtype:
     #     attn_bias = L.cast(attn_bias, enc_input.dtype)
-    # def to_2d(t_3d):
-    #     t_2d = L.gather_nd(t_3d, pad_idx)
-    #     return t_2d
+    def to_2d(t_3d):
+        t_2d = L.gather_nd(t_3d, pad_idx)
+        return t_2d
-    # def to_3d(t_2d):
-    #     t_3d = L.scatter_nd(
-    #         pad_idx, t_2d, shape=[d_shape[0], d_shape[1], d_model])
-    #     return t_3d
+    def to_3d(t_2d):
+        t_3d = L.scatter_nd(
+            pad_idx, t_2d, shape=[d_shape[0], d_shape[1], d_model])
+        return t_3d
     enc_input = to_2d(enc_input)
     all_hidden = []
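The to_2d/to_3d helpers activated above pack the batch into a dense [num_tokens, d_model] tensor before the encoder stack (keeping only the positions selected by pad_idx) and scatter it back to [batch, seqlen, d_model] afterwards; before this change the module-level identity lambdas (to_3d = lambda a: a) were apparently in effect. A NumPy sketch of the idea (illustrative only, not PGL code):

import numpy as np

mask = np.array([[1, 1, 0],
                 [1, 0, 0]])                      # [B, S]; 1 marks a real token
pad_idx = np.argwhere(mask == 1)                  # analogue of build_pad_idx
x = np.arange(24, dtype=float).reshape(2, 3, 4)   # [B, S, D] hidden states

x_2d = x[pad_idx[:, 0], pad_idx[:, 1]]            # to_2d ~ L.gather_nd
x_3d = np.zeros_like(x)                           # to_3d ~ L.scatter_nd
x_3d[pad_idx[:, 0], pad_idx[:, 1]] = x_2d         # pad positions stay zero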
examples/erniesage/models/erniesage_v2.py (view file @ f78f4639)

@@ -3,8 +3,6 @@ import paddle.fluid as F
 import paddle.fluid.layers as L
 from models.base import BaseNet, BaseGNNModel
 from models.ernie_model.ernie import ErnieModel
-from models.ernie_model.ernie import ErnieGraphModel
-from models.ernie_model.ernie import ErnieConfig

 class ErnieSageV2(BaseNet):

@@ -16,19 +14,52 @@ class ErnieSageV2(BaseNet):
         return inputs + [term_ids]

     def gnn_layer(self, gw, feature, hidden_size, act, initializer, learning_rate, name):
+        def build_position_ids(src_ids, dst_ids):
+            src_shape = L.shape(src_ids)
+            src_batch = src_shape[0]
+            src_seqlen = src_shape[1]
+            dst_seqlen = src_seqlen - 1 # without cls
+
+            src_position_ids = L.reshape(
+                L.range(0, src_seqlen, 1, dtype='int32'), [1, src_seqlen, 1],
+                inplace=True) # [1, slot_seqlen, 1]
+            src_position_ids = L.expand(src_position_ids, [src_batch, 1, 1]) # [B, slot_seqlen * num_b, 1]
+            zero = L.fill_constant([1], dtype='int64', value=0)
+            input_mask = L.cast(L.equal(src_ids, zero), "int32") # assume pad id == 0 [B, slot_seqlen, 1]
+            src_pad_len = L.reduce_sum(input_mask, 1) # [B, 1, 1]
+
+            dst_position_ids = L.reshape(
+                L.range(src_seqlen, src_seqlen + dst_seqlen, 1, dtype='int32'),
+                [1, dst_seqlen, 1],
+                inplace=True) # [1, slot_seqlen, 1]
+            dst_position_ids = L.expand(dst_position_ids, [src_batch, 1, 1]) # [B, slot_seqlen, 1]
+            dst_position_ids = dst_position_ids - src_pad_len # [B, slot_seqlen, 1]
+
+            position_ids = L.concat([src_position_ids, dst_position_ids], 1)
+            position_ids = L.cast(position_ids, 'int64')
+            position_ids.stop_gradient = True
+            return position_ids
+
         def ernie_send(src_feat, dst_feat, edge_feat):
             """doc"""
             # input_ids
             cls = L.fill_constant_batch_size_like(src_feat["term_ids"], [-1, 1, 1], "int64", 1)
             src_ids = L.concat([cls, src_feat["term_ids"]], 1)
             dst_ids = dst_feat["term_ids"]

             # sent_ids
             sent_ids = L.concat([L.zeros_like(src_ids), L.ones_like(dst_ids)], 1)
             term_ids = L.concat([src_ids, dst_ids], 1)
+            # position_ids
+            position_ids = build_position_ids(src_ids, dst_ids)

             term_ids.stop_gradient = True
             sent_ids.stop_gradient = True
             ernie = ErnieModel(
-                term_ids, sent_ids,
+                term_ids, sent_ids, position_ids,
                 config=self.config.ernie_config)
             feature = ernie.get_pooled_output()
             return feature
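The build_position_ids helper is the "fix position id" of the commit title: each edge's src and dst token sequences are concatenated into a single ERNIE input, and the dst block's position ids must continue from the number of non-pad src tokens rather than from a fixed offset of src_seqlen. A plain NumPy sketch of the same computation (illustrative only, assuming pad id 0 and id tensors shaped [B, seqlen, 1] as in the diff):

import numpy as np

def build_position_ids_np(src_ids, dst_ids):
    # src_ids: [B, S, 1] with CLS already prepended; dst_ids: [B, S-1, 1]
    B, S = src_ids.shape[0], src_ids.shape[1]
    D = S - 1                                     # dst length, without cls
    src_pos = np.tile(np.arange(S).reshape(1, S, 1), (B, 1, 1))
    # per-example count of pad (id 0) tokens in src: [B, 1, 1]
    pad_len = np.sum(src_ids == 0, axis=1, keepdims=True)
    # dst positions continue right after the last non-pad src token
    dst_pos = np.tile(np.arange(S, S + D).reshape(1, D, 1), (B, 1, 1)) - pad_len
    return np.concatenate([src_pos, dst_pos], axis=1).astype("int64")

# e.g. with S=4 and one pad token in a row of src, that row's dst positions
# start at 3 instead of 4.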
examples/erniesage/models/erniesage_v3.py (view file @ f78f4639)

@@ -18,7 +18,6 @@ import paddle.fluid.layers as L
 from models.base import BaseNet, BaseGNNModel
 from models.ernie_model.ernie import ErnieModel
 from models.ernie_model.ernie import ErnieGraphModel
-from models.ernie_model.ernie import ErnieConfig
 from models.message_passing import copy_send
examples/erniesage/preprocessing/dump_graph.py (view file @ f78f4639)

@@ -53,6 +53,7 @@ def dump_graph(args):
     term_file = io.open(os.path.join(args.outpath, "terms.txt"), "w", encoding=args.encoding)
     terms = []
     count = 0
+    item_distribution = []

     with io.open(args.inpath, encoding=args.encoding) as f:
         edges = []

@@ -66,6 +67,7 @@ def dump_graph(args):
                 str2id[s] = count
                 count += 1
                 term_file.write(str(col_idx) + "\t" + col + "\n")
+                item_distribution.append(0)

             slots.append(str2id[s])

@@ -74,6 +76,7 @@ def dump_graph(args):
         neg_samples.append(slots[2:])
         edges.append((src, dst))
         edges.append((dst, src))
+        item_distribution[dst] += 1

     term_file.close()
     edges = np.array(edges, dtype="int64")

@@ -82,12 +85,14 @@ def dump_graph(args):
     log.info("building graph...")
     graph = pgl.graph.Graph(num_nodes=num_nodes, edges=edges)
-    indegree = graph.indegree()
+    graph.indegree()
     graph.outdegree()
     graph.dump(args.outpath)

     # dump alias sample table
-    sqrt_indegree = np.sqrt(indegree)
-    distribution = 1. * sqrt_indegree / sqrt_indegree.sum()
+    item_distribution = np.array(item_distribution)
+    item_distribution = np.sqrt(item_distribution)
+    distribution = 1. * item_distribution / item_distribution.sum()
     alias, events = alias_sample_build_table(distribution)
     np.save(os.path.join(args.outpath, "alias.npy"), alias)
     np.save(os.path.join(args.outpath, "events.npy"), events)
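With this change the negative-sampling alias table is built from how often each node occurs as a dst item instead of from graph indegree, with the same sqrt damping. A toy NumPy illustration (not PGL code; alias_sample_build_table is PGL's existing helper):

import numpy as np

item_counts = np.array([5., 0., 2., 9.])     # dst occurrences per node (toy)
damped = np.sqrt(item_counts)                # sqrt damping, as in the diff
distribution = damped / damped.sum()         # normalized sampling weights
# distribution is then passed to alias_sample_build_table(...)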