Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
magicwindyyd
mindspore
提交
eeede168
M
mindspore
项目概览
magicwindyyd
/
mindspore
与 Fork 源项目一致
Fork自
MindSpore / mindspore
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
mindspore
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
eeede168
编写于
8月 24, 2020
作者:
Y
yao_yf
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
wide_and_deep merge ckpt in eval
上级
816ed8d8
变更
10
隐藏空白更改
内联
并排
Showing
10 changed files
with
71 additions
and
19 deletions
+71
-19
mindspore/ccsrc/frontend/parallel/step_auto_parallel.cc
mindspore/ccsrc/frontend/parallel/step_auto_parallel.cc
+1
-1
mindspore/ccsrc/frontend/parallel/step_parallel.cc
mindspore/ccsrc/frontend/parallel/step_parallel.cc
+22
-10
mindspore/ccsrc/frontend/parallel/step_parallel.h
mindspore/ccsrc/frontend/parallel/step_parallel.h
+1
-1
model_zoo/official/recommend/wide_and_deep/README.md
model_zoo/official/recommend/wide_and_deep/README.md
+5
-1
model_zoo/official/recommend/wide_and_deep/eval.py
model_zoo/official/recommend/wide_and_deep/eval.py
+24
-3
model_zoo/official/recommend/wide_and_deep/src/callbacks.py
model_zoo/official/recommend/wide_and_deep/src/callbacks.py
+2
-1
model_zoo/official/recommend/wide_and_deep/src/config.py
model_zoo/official/recommend/wide_and_deep/src/config.py
+4
-0
model_zoo/official/recommend/wide_and_deep/src/wide_and_deep.py
...zoo/official/recommend/wide_and_deep/src/wide_and_deep.py
+5
-0
model_zoo/official/recommend/wide_and_deep/train_and_eval_auto_parallel.py
...l/recommend/wide_and_deep/train_and_eval_auto_parallel.py
+3
-2
tests/st/model_zoo_tests/wide_and_deep/python_file_for_ci/config.py
...odel_zoo_tests/wide_and_deep/python_file_for_ci/config.py
+4
-0
未找到文件。
mindspore/ccsrc/frontend/parallel/step_auto_parallel.cc
浏览文件 @
eeede168
...
...
@@ -343,7 +343,7 @@ OperatorInfoPtr CreateTheOperatorInfo(const PrimitivePtr &prim, const CNodePtr &
std
::
string
strategy_key_name
=
""
;
auto
param_names
=
NodeParameterName
(
cnode
);
if
(
!
param_names
.
empty
())
{
strategy_key_name
=
param_names
[
0
].
first
;
strategy_key_name
=
p
rim
->
name
()
+
"_"
+
p
aram_names
[
0
].
first
;
}
bool
load_strategy_from_ckpt
=
StrategyCheckpoint
::
GetInstance
().
LoadCheckPointOn
()
&&
stra_map
->
find
(
strategy_key_name
)
!=
stra_map
->
end
();
...
...
mindspore/ccsrc/frontend/parallel/step_parallel.cc
浏览文件 @
eeede168
...
...
@@ -1528,7 +1528,7 @@ void ExtractInformation(const std::vector<AnfNodePtr> &all_nodes) {
std
::
string
strategy_key_name
=
""
;
auto
param_names
=
NodeParameterName
(
cnode
);
if
(
!
param_names
.
empty
())
{
strategy_key_name
=
param_names
[
0
].
first
;
strategy_key_name
=
p
rim
->
name
()
+
"_"
+
p
aram_names
[
0
].
first
;
}
bool
load_strategy_from_ckpt
=
StrategyCheckpoint
::
GetInstance
().
LoadCheckPointOn
()
&&
stra_map
.
find
(
strategy_key_name
)
!=
stra_map
.
end
();
...
...
@@ -2219,9 +2219,23 @@ std::vector<std::pair<std::string, int>> NodeParameterName(const CNodePtr &node)
auto
input
=
node_inputs
[
i
];
if
(
input
->
isa
<
Parameter
>
())
{
auto
input_parameter
=
input
->
cast
<
ParameterPtr
>
();
if
(
input_parameter
->
has_default
())
{
if
(
ParameterRequireGrad
(
input_parameter
))
{
param_names
.
push_back
({
input_parameter
->
name
(),
i
});
if
(
input_parameter
->
has_default
()
&&
ParameterRequireGrad
(
input_parameter
))
{
param_names
.
push_back
({
input_parameter
->
name
(),
i
});
}
}
else
if
(
input
->
isa
<
CNode
>
())
{
CNodePtr
cnode
=
input
->
cast
<
CNodePtr
>
();
if
(
!
IsValueNode
<
Primitive
>
(
cnode
->
input
(
0
)))
{
return
param_names
;
}
ValueNodePtr
prim_anf_node
=
cnode
->
input
(
0
)
->
cast
<
ValueNodePtr
>
();
PrimitivePtr
prim
=
prim_anf_node
->
value
()
->
cast
<
PrimitivePtr
>
();
if
(
prim
->
name
()
==
CAST
&&
cnode
->
inputs
().
size
()
>=
1
)
{
auto
cast_input
=
cnode
->
inputs
()[
1
];
if
(
cast_input
->
isa
<
Parameter
>
())
{
auto
cast_input_parameter
=
cast_input
->
cast
<
ParameterPtr
>
();
if
(
cast_input_parameter
->
has_default
()
&&
ParameterRequireGrad
(
cast_input_parameter
))
{
param_names
.
push_back
({
cast_input_parameter
->
name
(),
i
});
}
}
}
}
...
...
@@ -2229,14 +2243,11 @@ std::vector<std::pair<std::string, int>> NodeParameterName(const CNodePtr &node)
return
param_names
;
}
void
CheckpointStrategy
(
const
FuncGraphPtr
&
func_graph
)
{
MS_EXCEPTION_IF_NULL
(
func_graph
);
void
CheckpointStrategy
(
const
std
::
vector
<
AnfNodePtr
>
&
all_nodes
)
{
MS_LOG
(
DEBUG
)
<<
"Save strategy to checkpoint begin"
;
StrategyMap
stra_map
;
TensorInfoMap
tensor_info_map
;
ManualShapeMap
manual_shape_map
;
auto
ret
=
func_graph
->
get_return
();
auto
all_nodes
=
DeepScopedGraphSearch
(
ret
);
for
(
auto
&
node
:
all_nodes
)
{
MS_EXCEPTION_IF_NULL
(
node
);
auto
cnode
=
node
->
cast
<
CNodePtr
>
();
...
...
@@ -2258,7 +2269,8 @@ void CheckpointStrategy(const FuncGraphPtr &func_graph) {
std
::
vector
<
TensorInfo
>
input_tensor_info
=
operator_info
->
inputs_tensor_info
();
StrategyPtr
strategyPtr
=
operator_info
->
strategy
();
MS_EXCEPTION_IF_NULL
(
node
->
scope
());
stra_map
[
param_name
]
=
strategyPtr
;
std
::
string
stratey_key_name
=
prim
->
name
()
+
"_"
+
param_name
;
stra_map
[
stratey_key_name
]
=
strategyPtr
;
for
(
auto
param_name_pair
:
param_names
)
{
if
(
param_name_pair
.
second
-
1
>=
UintToInt
(
input_tensor_info
.
size
()))
{
continue
;
...
...
@@ -2552,7 +2564,7 @@ bool StepParallel(const FuncGraphPtr &root, const opt::OptimizerPtr &optimizer)
// save strategy as checkpoint for multi-train
if
(
StrategyCheckpoint
::
GetInstance
().
SaveCheckPointOn
())
{
CheckpointStrategy
(
root
);
CheckpointStrategy
(
all_nodes
);
}
HandleSymbolicKeyInstance
(
root
,
all_nodes
);
...
...
mindspore/ccsrc/frontend/parallel/step_parallel.h
浏览文件 @
eeede168
...
...
@@ -136,7 +136,7 @@ void ParallelCommunication(const FuncGraphPtr &root, const std::vector<AnfNodePt
std
::
vector
<
std
::
pair
<
std
::
string
,
int
>>
NodeParameterName
(
const
CNodePtr
&
node
);
void
CheckpointStrategy
(
const
FuncGraphPtr
&
func_graph
);
void
CheckpointStrategy
(
const
std
::
vector
<
AnfNodePtr
>
&
all_nodes
);
// main step of Parallel
bool
StepParallel
(
const
FuncGraphPtr
&
func_graph
,
const
opt
::
OptimizerPtr
&
optimizer
);
...
...
model_zoo/official/recommend/wide_and_deep/README.md
浏览文件 @
eeede168
...
...
@@ -152,7 +152,11 @@ optional arguments:
--keep_prob The keep rate in dropout layer.(Default:1.0)
--dropout_flag Enable dropout.(Default:0)
--output_path Deprecated
--ckpt_path The location of the checkpoint file.(Default:./checkpoints/)
--ckpt_path The location of the checkpoint file. If the checkpoint files
are slices of the weights, multiple checkpoint files need to be
passed in. Use ';' to separate them and list them in order,
e.g. "./checkpoints/0.ckpt;./checkpoints/1.ckpt".
(Default:./checkpoints/)
--eval_file_name Eval output file.(Default:eval.log)
--loss_file_name Loss output file.(Default:loss.log)
--host_device_mix Enable host device mode or not.(Default:0)
...
...
model_zoo/official/recommend/wide_and_deep/eval.py
浏览文件 @
eeede168
...
...
@@ -18,7 +18,8 @@
import
os
from
mindspore
import
Model
,
context
from
mindspore.train.serialization
import
load_checkpoint
,
load_param_into_net
from
mindspore.train.serialization
import
load_checkpoint
,
load_param_into_net
,
\
build_searched_strategy
,
merge_sliced_parameter
from
src.wide_and_deep
import
PredictWithSigmoid
,
TrainStepWrap
,
NetWithLossClass
,
WideDeepModel
from
src.callbacks
import
LossCallBack
,
EvalCallBack
...
...
@@ -81,8 +82,28 @@ def test_eval(config):
net_builder
=
ModelBuilder
()
train_net
,
eval_net
=
net_builder
.
get_net
(
config
)
param_dict
=
load_checkpoint
(
config
.
ckpt_path
)
ckpt_path
=
config
.
ckpt_path
if
";"
in
ckpt_path
:
ckpt_paths
=
ckpt_path
.
split
(
';'
)
param_list_dict
=
{}
strategy
=
build_searched_strategy
(
config
.
stra_ckpt
)
for
slice_path
in
ckpt_paths
:
param_slice_dict
=
load_checkpoint
(
slice_path
)
for
key
,
value
in
param_slice_dict
.
items
():
if
'optimizer'
in
key
:
continue
if
key
not
in
param_list_dict
:
param_list_dict
[
key
]
=
[]
param_list_dict
[
key
].
append
(
value
)
param_dict
=
{}
for
key
,
value
in
param_list_dict
.
items
():
if
key
in
strategy
:
merged_parameter
=
merge_sliced_parameter
(
value
,
strategy
)
else
:
merged_parameter
=
merge_sliced_parameter
(
value
)
param_dict
[
key
]
=
merged_parameter
else
:
param_dict
=
load_checkpoint
(
ckpt_path
)
load_param_into_net
(
eval_net
,
param_dict
)
auc_metric
=
AUCMetric
()
...
...
model_zoo/official/recommend/wide_and_deep/src/callbacks.py
浏览文件 @
eeede168
...
...
@@ -97,6 +97,7 @@ class EvalCallBack(Callback):
self
.
eval_file_name
=
config
.
eval_file_name
self
.
eval_values
=
[]
self
.
host_device_mix
=
host_device_mix
self
.
config
=
config
def
epoch_end
(
self
,
run_context
):
"""
...
...
@@ -106,7 +107,7 @@ class EvalCallBack(Callback):
parallel_mode
=
context
.
get_auto_parallel_context
(
"parallel_mode"
)
if
parallel_mode
in
(
ParallelMode
.
SEMI_AUTO_PARALLEL
,
ParallelMode
.
AUTO_PARALLEL
):
context
.
set_auto_parallel_context
(
strategy_ckpt_save_file
=
""
,
strategy_ckpt_load_file
=
"./strategy_train.ckpt"
)
strategy_ckpt_load_file
=
self
.
config
.
stra_ckpt
)
rank_id
=
0
if
parallel_mode
in
(
ParallelMode
.
SEMI_AUTO_PARALLEL
,
ParallelMode
.
AUTO_PARALLEL
,
ParallelMode
.
DATA_PARALLEL
):
...
...
model_zoo/official/recommend/wide_and_deep/src/config.py
浏览文件 @
eeede168
...
...
@@ -39,6 +39,8 @@ def argparse_init():
parser
.
add_argument
(
"--dropout_flag"
,
type
=
int
,
default
=
0
,
help
=
"Enable dropout"
)
parser
.
add_argument
(
"--output_path"
,
type
=
str
,
default
=
"./output/"
)
parser
.
add_argument
(
"--ckpt_path"
,
type
=
str
,
default
=
"./checkpoints/"
,
help
=
"The location of the checkpoint file."
)
parser
.
add_argument
(
"--stra_ckpt"
,
type
=
str
,
default
=
"./checkpoints/strategy.ckpt"
,
help
=
"The strategy checkpoint file."
)
parser
.
add_argument
(
"--eval_file_name"
,
type
=
str
,
default
=
"eval.log"
,
help
=
"Eval output file."
)
parser
.
add_argument
(
"--loss_file_name"
,
type
=
str
,
default
=
"loss.log"
,
help
=
"Loss output file."
)
parser
.
add_argument
(
"--host_device_mix"
,
type
=
int
,
default
=
0
,
help
=
"Enable host device mode or not"
)
...
...
@@ -75,6 +77,7 @@ class WideDeepConfig():
self
.
eval_file_name
=
"eval.log"
self
.
loss_file_name
=
"loss.log"
self
.
ckpt_path
=
"./checkpoints/"
self
.
stra_ckpt
=
'./checkpoints/strategy.ckpt'
self
.
host_device_mix
=
0
self
.
dataset_type
=
"tfrecord"
self
.
parameter_server
=
0
...
...
@@ -107,6 +110,7 @@ class WideDeepConfig():
self
.
eval_file_name
=
args
.
eval_file_name
self
.
loss_file_name
=
args
.
loss_file_name
self
.
ckpt_path
=
args
.
ckpt_path
self
.
stra_ckpt
=
args
.
stra_ckpt
self
.
host_device_mix
=
args
.
host_device_mix
self
.
dataset_type
=
args
.
dataset_type
self
.
parameter_server
=
args
.
parameter_server
model_zoo/official/recommend/wide_and_deep/src/wide_and_deep.py
浏览文件 @
eeede168
...
...
@@ -203,6 +203,7 @@ class WideDeepModel(nn.Cell):
self
.
dense_layer_1
.
dropout
.
dropout_do_mask
.
set_strategy
(((
1
,
get_group_size
()),))
self
.
dense_layer_1
.
dropout
.
dropout
.
set_strategy
(((
1
,
get_group_size
()),))
self
.
dense_layer_1
.
matmul
.
set_strategy
(((
1
,
get_group_size
()),
(
get_group_size
(),
1
)))
self
.
dense_layer_1
.
matmul
.
add_prim_attr
(
"field_size"
,
config
.
field_size
)
self
.
deep_embeddinglookup
=
nn
.
EmbeddingLookup
(
self
.
vocab_size
,
self
.
emb_dim
,
slice_mode
=
nn
.
EmbeddingLookUpSplitMode
.
TABLE_COLUMN_SLICE
)
self
.
wide_embeddinglookup
=
nn
.
EmbeddingLookup
(
self
.
vocab_size
,
1
,
...
...
@@ -211,6 +212,10 @@ class WideDeepModel(nn.Cell):
self
.
deep_reshape
.
add_prim_attr
(
"skip_redistribution"
,
True
)
self
.
reduce_sum
.
add_prim_attr
(
"cross_batch"
,
True
)
self
.
embedding_table
=
self
.
deep_embeddinglookup
.
embedding_table
elif
host_device_mix
:
self
.
deep_embeddinglookup
=
nn
.
EmbeddingLookup
(
self
.
vocab_size
,
self
.
emb_dim
)
self
.
wide_embeddinglookup
=
nn
.
EmbeddingLookup
(
self
.
vocab_size
,
1
)
self
.
embedding_table
=
self
.
deep_embeddinglookup
.
embedding_table
elif
parameter_server
:
self
.
deep_embeddinglookup
=
nn
.
EmbeddingLookup
(
self
.
vocab_size
,
self
.
emb_dim
)
self
.
wide_embeddinglookup
=
nn
.
EmbeddingLookup
(
self
.
vocab_size
,
1
)
...
...
model_zoo/official/recommend/wide_and_deep/train_and_eval_auto_parallel.py
浏览文件 @
eeede168
...
...
@@ -111,10 +111,11 @@ def train_and_eval(config):
eval_callback
=
EvalCallBack
(
model
,
ds_eval
,
auc_metric
,
config
,
host_device_mix
=
host_device_mix
)
callback
=
LossCallBack
(
config
=
config
,
per_print_times
=
20
)
ckptconfig
=
CheckpointConfig
(
save_checkpoint_steps
=
ds_train
.
get_dataset_size
(),
keep_checkpoint_max
=
5
)
ckptconfig
=
CheckpointConfig
(
save_checkpoint_steps
=
ds_train
.
get_dataset_size
()
*
epochs
,
keep_checkpoint_max
=
5
,
integrated_save
=
False
)
ckpoint_cb
=
ModelCheckpoint
(
prefix
=
'widedeep_train'
,
directory
=
config
.
ckpt_path
,
config
=
ckptconfig
)
context
.
set_auto_parallel_context
(
strategy_ckpt_save_file
=
"./strategy_train.ckpt"
)
context
.
set_auto_parallel_context
(
strategy_ckpt_save_file
=
config
.
stra_ckpt
)
callback_list
=
[
TimeMonitor
(
ds_train
.
get_dataset_size
()),
eval_callback
,
callback
]
if
not
host_device_mix
:
callback_list
.
append
(
ckpoint_cb
)
...
...
tests/st/model_zoo_tests/wide_and_deep/python_file_for_ci/config.py
浏览文件 @
eeede168
...
...
@@ -30,6 +30,8 @@ def argparse_init():
parser
.
add_argument
(
"--deep_layer_dim"
,
type
=
int
,
nargs
=
'+'
,
default
=
[
1024
,
512
,
256
,
128
])
parser
.
add_argument
(
"--deep_layer_act"
,
type
=
str
,
default
=
'relu'
)
parser
.
add_argument
(
"--keep_prob"
,
type
=
float
,
default
=
1.0
)
parser
.
add_argument
(
"--stra_ckpt"
,
type
=
str
,
default
=
"./strategy_train.ckpt"
,
help
=
"The strategy checkpoint file."
)
parser
.
add_argument
(
"--output_path"
,
type
=
str
,
default
=
"./output/"
)
parser
.
add_argument
(
"--ckpt_path"
,
type
=
str
,
default
=
"./checkpoints/"
)
...
...
@@ -63,6 +65,7 @@ class WideDeepConfig():
self
.
eval_file_name
=
"eval.log"
self
.
loss_file_name
=
"loss.log"
self
.
ckpt_path
=
"./checkpoints/"
self
.
stra_ckpt
=
"./strategy_train.ckpt"
def
argparse_init
(
self
):
"""
...
...
@@ -90,3 +93,4 @@ class WideDeepConfig():
self
.
eval_file_name
=
args
.
eval_file_name
self
.
loss_file_name
=
args
.
loss_file_name
self
.
ckpt_path
=
args
.
ckpt_path
self
.
stra_ckpt
=
args
.
stra_ckpt
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录