Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
PaddleRec
提交
729d5e33
P
PaddleRec
项目概览
BaiXuePrincess
/
PaddleRec
与 Fork 源项目一致
Fork自
PaddlePaddle / PaddleRec
通知
1
Star
0
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleRec
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
体验新版 GitCode,发现更多精彩内容 >>
提交
729d5e33
编写于
5月 15, 2020
作者:
Z
zhangwenhui03
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
add esmm infer
上级
52d3c0a5
变更
3
显示空白变更内容
内联
并排
Showing
3 changed file
with
83 addition
and
4 deletion
+83
-4
models/multitask/esmm/config.yaml
models/multitask/esmm/config.yaml
+6
-0
models/multitask/esmm/esmm_infer_reader.py
models/multitask/esmm/esmm_infer_reader.py
+63
-0
models/multitask/esmm/model.py
models/multitask/esmm/model.py
+14
-4
未找到文件。
models/multitask/esmm/config.yaml
浏览文件 @
729d5e33
...
...
@@ -12,6 +12,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.
evaluate
:
reader
:
batch_size
:
1
class
:
"
{workspace}/esmm_infer_reader.py"
test_data_path
:
"
{workspace}/data/train"
train
:
trainer
:
# for cluster training
...
...
models/multitask/esmm/esmm_infer_reader.py
0 → 100644
浏览文件 @
729d5e33
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
__future__
import
print_function
from
paddlerec.core.reader
import
Reader
from
paddlerec.core.utils
import
envs
from
collections
import
defaultdict
import
numpy
as
np
class
EvaluateReader
(
Reader
):
def
init
(
self
):
all_field_id
=
[
'101'
,
'109_14'
,
'110_14'
,
'127_14'
,
'150_14'
,
'121'
,
'122'
,
'124'
,
'125'
,
'126'
,
'127'
,
'128'
,
'129'
,
'205'
,
'206'
,
'207'
,
'210'
,
'216'
,
'508'
,
'509'
,
'702'
,
'853'
,
'301'
]
self
.
all_field_id_dict
=
defaultdict
(
int
)
for
i
,
field_id
in
enumerate
(
all_field_id
):
self
.
all_field_id_dict
[
field_id
]
=
[
False
,
i
]
def
generate_sample
(
self
,
line
):
"""
Read the data line by line and process it as a dictionary
"""
def
reader
():
"""
This function needs to be implemented by the user, based on data format
"""
features
=
line
.
strip
().
split
(
','
)
ctr
=
int
(
features
[
1
])
cvr
=
int
(
features
[
2
])
padding
=
0
output
=
[(
field_id
,[])
for
field_id
in
self
.
all_field_id_dict
]
for
elem
in
features
[
4
:]:
field_id
,
feat_id
=
elem
.
strip
().
split
(
':'
)
if
field_id
not
in
self
.
all_field_id_dict
:
continue
self
.
all_field_id_dict
[
field_id
][
0
]
=
True
index
=
self
.
all_field_id_dict
[
field_id
][
1
]
output
[
index
][
1
].
append
(
int
(
feat_id
))
for
field_id
in
self
.
all_field_id_dict
:
visited
,
index
=
self
.
all_field_id_dict
[
field_id
]
if
visited
:
self
.
all_field_id_dict
[
field_id
][
0
]
=
False
else
:
output
[
index
][
1
].
append
(
padding
)
output
.
append
((
'ctr'
,
[
ctr
]))
output
.
append
((
'cvr'
,
[
cvr
]))
yield
output
return
reader
models/multitask/esmm/model.py
浏览文件 @
729d5e33
...
...
@@ -53,7 +53,7 @@ class Model(ModelBase):
return
inputs
def
net
(
self
,
inputs
):
def
net
(
self
,
inputs
,
is_infer
=
False
):
vocab_size
=
envs
.
get_global_env
(
"hyper_parameters.vocab_size"
,
None
,
self
.
_namespace
)
embed_size
=
envs
.
get_global_env
(
"hyper_parameters.embed_size"
,
None
,
self
.
_namespace
)
...
...
@@ -90,13 +90,20 @@ class Model(ModelBase):
ctcvr_prop_one
=
fluid
.
layers
.
elementwise_mul
(
ctr_prop_one
,
cvr_prop_one
)
ctcvr_prop
=
fluid
.
layers
.
concat
(
input
=
[
1
-
ctcvr_prop_one
,
ctcvr_prop_one
],
axis
=
1
)
auc_ctr
,
batch_auc_ctr
,
auc_states_ctr
=
fluid
.
layers
.
auc
(
input
=
ctr_out
,
label
=
ctr_clk
)
auc_ctcvr
,
batch_auc_ctcvr
,
auc_states_ctcvr
=
fluid
.
layers
.
auc
(
input
=
ctcvr_prop
,
label
=
ctcvr_buy
)
if
is_infer
:
self
.
_infer_results
[
"AUC_ctr"
]
=
auc_ctr
self
.
_infer_results
[
"AUC_ctcvr"
]
=
auc_ctcvr
return
loss_ctr
=
fluid
.
layers
.
cross_entropy
(
input
=
ctr_out
,
label
=
ctr_clk
)
loss_ctcvr
=
fluid
.
layers
.
cross_entropy
(
input
=
ctcvr_prop
,
label
=
ctcvr_buy
)
cost
=
loss_ctr
+
loss_ctcvr
avg_cost
=
fluid
.
layers
.
mean
(
cost
)
auc_ctr
,
batch_auc_ctr
,
auc_states_ctr
=
fluid
.
layers
.
auc
(
input
=
ctr_out
,
label
=
ctr_clk
)
auc_ctcvr
,
batch_auc_ctcvr
,
auc_states_ctcvr
=
fluid
.
layers
.
auc
(
input
=
ctcvr_prop
,
label
=
ctcvr_buy
)
self
.
_cost
=
avg_cost
self
.
_metrics
[
"AUC_ctr"
]
=
auc_ctr
...
...
@@ -111,4 +118,7 @@ class Model(ModelBase):
def
infer_net
(
self
):
pass
self
.
_infer_data_var
=
self
.
input_data
()
self
.
_infer_data_loader
=
fluid
.
io
.
DataLoader
.
from_generator
(
feed_list
=
self
.
_infer_data_var
,
capacity
=
64
,
use_double_buffer
=
False
,
iterable
=
False
)
self
.
net
(
self
.
_infer_data_var
,
is_infer
=
True
)
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录