Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
magicwindyyd
mindspore
提交
2ce0537a
M
mindspore
项目概览
magicwindyyd
/
mindspore
与 Fork 源项目一致
Fork自
MindSpore / mindspore
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
mindspore
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
2ce0537a
编写于
8月 20, 2020
作者:
H
huangxinjing
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Add evaluation
上级
7a3b6667
变更
1
隐藏空白更改
内联
并排
Showing
1 changed file
with
97 addition
and
0 deletion
+97
-0
model_zoo/official/recommend/wide_and_deep_multitable/eval.py
...l_zoo/official/recommend/wide_and_deep_multitable/eval.py
+97
-0
未找到文件。
model_zoo/official/recommend/wide_and_deep_multitable/eval.py
0 → 100644
浏览文件 @
2ce0537a
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
""" training_and_evaluating """
import
os
import
sys
from
mindspore
import
Model
,
context
from
mindspore.train.serialization
import
load_checkpoint
,
load_param_into_net
from
src.wide_and_deep
import
PredictWithSigmoid
,
TrainStepWrap
,
NetWithLossClass
,
WideDeepModel
from
src.callbacks
import
LossCallBack
,
EvalCallBack
from
src.datasets
import
create_dataset
,
compute_emb_dim
from
src.metrics
import
AUCMetric
from
src.config
import
WideDeepConfig
sys
.
path
.
append
(
os
.
path
.
dirname
(
os
.
path
.
dirname
(
os
.
path
.
abspath
(
__file__
))))
def
get_WideDeep_net
(
config
):
"""
Get network of wide&deep model.
"""
WideDeep_net
=
WideDeepModel
(
config
)
loss_net
=
NetWithLossClass
(
WideDeep_net
,
config
)
train_net
=
TrainStepWrap
(
loss_net
,
config
)
eval_net
=
PredictWithSigmoid
(
WideDeep_net
)
return
train_net
,
eval_net
class
ModelBuilder
():
"""
ModelBuilder.
"""
def
__init__
(
self
):
pass
def
get_hook
(
self
):
pass
def
get_train_hook
(
self
):
hooks
=
[]
callback
=
LossCallBack
()
hooks
.
append
(
callback
)
if
int
(
os
.
getenv
(
'DEVICE_ID'
))
==
0
:
pass
return
hooks
def
get_net
(
self
,
config
):
return
get_WideDeep_net
(
config
)
def
train_and_eval
(
config
):
"""
train_and_eval.
"""
data_path
=
config
.
data_path
epochs
=
config
.
epochs
print
(
"epochs is {}"
.
format
(
epochs
))
ds_eval
=
create_dataset
(
data_path
,
train_mode
=
False
,
epochs
=
1
,
batch_size
=
config
.
batch_size
,
is_tf_dataset
=
config
.
is_tf_dataset
)
print
(
"ds_eval.size: {}"
.
format
(
ds_eval
.
get_dataset_size
()))
net_builder
=
ModelBuilder
()
train_net
,
eval_net
=
net_builder
.
get_net
(
config
)
param_dict
=
load_checkpoint
(
config
.
ckpt_path
)
load_param_into_net
(
eval_net
,
param_dict
)
auc_metric
=
AUCMetric
()
model
=
Model
(
train_net
,
eval_network
=
eval_net
,
metrics
=
{
"auc"
:
auc_metric
})
eval_callback
=
EvalCallBack
(
model
,
ds_eval
,
auc_metric
,
config
)
model
.
eval
(
ds_eval
,
callbacks
=
eval_callback
)
if
__name__
==
"__main__"
:
wide_and_deep_config
=
WideDeepConfig
()
wide_and_deep_config
.
argparse_init
()
compute_emb_dim
(
wide_and_deep_config
)
context
.
set_context
(
mode
=
context
.
GRAPH_MODE
,
device_target
=
"Davinci"
)
train_and_eval
(
wide_and_deep_config
)
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录