Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
正统之独孤求败
mindspore
提交
b109e6f6
M
mindspore
项目概览
正统之独孤求败
/
mindspore
与 Fork 源项目一致
Fork自
MindSpore / mindspore
通知
1
Star
0
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
mindspore
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
体验新版 GitCode,发现更多精彩内容 >>
提交
b109e6f6
编写于
7月 25, 2020
作者:
Z
ZPaC
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Add parameter server model_zoo case and CI test cases.
上级
7f6f140d
变更
9
隐藏空白更改
内联
并排
Showing 9 changed files with 479 additions and 18 deletions
+479
-18
mindspore/ccsrc/frontend/parallel/ps/parameter_server.h
mindspore/ccsrc/frontend/parallel/ps/parameter_server.h
+2
-1
model_zoo/official/recommend/wide_and_deep/README.md
model_zoo/official/recommend/wide_and_deep/README.md
+26
-15
model_zoo/official/recommend/wide_and_deep/script/run_parameter_server_train.sh
...ommend/wide_and_deep/script/run_parameter_server_train.sh
+64
-0
model_zoo/official/recommend/wide_and_deep/src/config.py
model_zoo/official/recommend/wide_and_deep/src/config.py
+3
-0
model_zoo/official/recommend/wide_and_deep/src/wide_and_deep.py
...zoo/official/recommend/wide_and_deep/src/wide_and_deep.py
+27
-2
model_zoo/official/recommend/wide_and_deep/train_and_eval_parameter_server.py
...ecommend/wide_and_deep/train_and_eval_parameter_server.py
+129
-0
tests/st/ps/full_ps/run_full_ps_lenet.sh
tests/st/ps/full_ps/run_full_ps_lenet.sh
+61
-0
tests/st/ps/full_ps/test_full_ps_lenet.py
tests/st/ps/full_ps/test_full_ps_lenet.py
+137
-0
tests/st/ps/full_ps/test_run.py
tests/st/ps/full_ps/test_run.py
+30
-0
未找到文件。
mindspore/ccsrc/frontend/parallel/ps/parameter_server.h
浏览文件 @
b109e6f6
...
...
@@ -344,6 +344,7 @@ void ParameterServer<T>::InitOptimInputsShape(const Keys &keys, const Values &va
template
<
typename
T
>
void
ParameterServer
<
T
>::
InitWeight
(
const
Key
&
key
,
const
WeightPtr
&
weight
)
{
MS_LOG
(
INFO
)
<<
"Initializing weight for key "
<<
key
;
if
(
weights_
.
count
(
key
)
==
0
)
{
weights_
[
key
]
=
weight
;
}
...
...
@@ -360,7 +361,7 @@ void ParameterServer<T>::InitGrad(const Key &key, const GradPtr &grad) {
template
<
typename
T
>
void
ParameterServer
<
T
>::
InitEmbeddingTable
(
const
Key
&
key
,
const
std
::
shared_ptr
<
std
::
vector
<
std
::
shared_ptr
<
std
::
vector
<
size_t
>>>>
&
shapes
)
{
// Init embedding lookup kernel
MS_LOG
(
INFO
)
<<
"Initializing embedding table for key "
<<
key
;
std
::
shared_ptr
<
PServerKernel
>
lookup
=
std
::
make_shared
<
kernel
::
ps
::
EmbeddingLookUpPSKernel
>
(
rank_id_
,
pserver_num_
);
lookup
->
InitKernel
(
shapes
);
embedding_lookup_ops_
[
key
]
=
lookup
;
...
...
model_zoo/official/recommend/wide_and_deep/README.md
浏览文件 @
b109e6f6
...
...
@@ -24,22 +24,24 @@ The common used benchmark datasets are used for model training and evaluation.
The entire code structure is as following:
```
|--- wide_and_deep/
train_and_eval.py "Entrance of Wide&Deep model training and evaluation"
eval.py "Entrance of Wide&Deep model evaluation"
train.py "Entrance of Wide&Deep model training"
train_and_eval_multinpu.py "Entrance of Wide&Deep model data parallel training and evaluation"
train_and_eval.py
"Entrance of Wide&Deep model training and evaluation"
eval.py
"Entrance of Wide&Deep model evaluation"
train.py
"Entrance of Wide&Deep model training"
train_and_eval_multinpu.py
"Entrance of Wide&Deep model data parallel training and evaluation"
train_and_eval_auto_parallel.py
|--- src/ "Entrance of training and evaluation"
config.py "Parameters configuration"
dataset.py "Dataset loader class"
process_data.py "Process dataset"
preprocess_data.py "Pre_process dataset"
wide_and_deep.py "Model structure"
callbacks.py "Callback class for training and evaluation"
metrics.py "Metric class"
|--- script/ "Run shell dir"
run_multinpu_train.sh "Run data parallel"
run_auto_parallel_train.sh "Run auto parallel"
train_and_eval_parameter_server.py "Entrance of Wide&Deep model parameter server training and evaluation"
|--- src/ "Entrance of training and evaluation"
config.py "Parameters configuration"
dataset.py "Dataset loader class"
process_data.py "Process dataset"
preprocess_data.py "Pre_process dataset"
wide_and_deep.py "Model structure"
callbacks.py "Callback class for training and evaluation"
metrics.py "Metric class"
|--- script/ "Run shell dir"
run_multinpu_train.sh "Run data parallel"
run_auto_parallel_train.sh "Run auto parallel"
run_parameter_server_train.sh "Run parameter server"
```
### Train and evaluate model
...
...
@@ -110,6 +112,15 @@ bash start_cluster.sh CLUSTER_CONFIG_PATH EPOCH_SIZE VOCAB_SIZE EMB_DIM
DATASET ENV_SH RANK_TABLE_FILE MODE
```
To train and evaluate the model in parameter server mode, command as follows:
```
# SERVER_NUM is the number of parameter servers for this task.
# SCHED_HOST is the IP address of scheduler.
# SCHED_PORT is the port of scheduler.
# The number of workers is the same as RANK_SIZE.
bash run_parameter_server_train.sh RANK_SIZE EPOCHS DATASET RANK_TABLE_FILE SERVER_NUM SCHED_HOST SCHED_PORT
```
To evaluate the model, command as follows:
```
python eval.py
...
...
model_zoo/official/recommend/wide_and_deep/script/run_parameter_server_train.sh
0 → 100644
浏览文件 @
b109e6f6
#!/bin/bash
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

# Launch Wide&Deep training in parameter-server mode.
# Usage:
#   bash run_parameter_server_train.sh RANK_SIZE EPOCHS DATASET RANK_TABLE_FILE SERVER_NUM SCHED_HOST SCHED_PORT
# Spawns 1 scheduler, $5 parameter servers and $1 workers as background
# processes, each in its own fresh working directory with output captured
# in a per-process log file.
execute_path=$(pwd)
script_self=$(readlink -f "$0")
self_path=$(dirname "${script_self}")
export RANK_SIZE=$1
export EPOCH_SIZE=$2
export DATASET=$3
export RANK_TABLE_FILE=$4
# Same rank table file is exposed under both variable names.
export MINDSPORE_HCCL_CONFIG_PATH=$4
export MS_COMM_TYPE=zmq
export MS_SCHED_NUM=1
# Number of workers equals RANK_SIZE.
export MS_WORKER_NUM=$RANK_SIZE
export MS_SERVER_NUM=$5
export MS_SCHED_HOST=$6
export MS_SCHED_PORT=$7

# Scheduler: exactly one process (loop bound is fixed at 1).
export MS_ROLE=MS_SCHED
for((i=0;i<1;i++));
do
  rm -rf ${execute_path}/sched_$i/
  mkdir ${execute_path}/sched_$i/
  cd ${execute_path}/sched_$i/ || exit
  export RANK_ID=$i
  export DEVICE_ID=$i
  python -s ${self_path}/../train_and_eval_parameter_server.py --data_path=$DATASET --epochs=$EPOCH_SIZE --parameter_server=1 >sched_$i.log 2>&1 &
done

# Parameter-server processes, one per MS_SERVER_NUM.
export MS_ROLE=MS_PSERVER
for((i=0;i<$MS_SERVER_NUM;i++));
do
  rm -rf ${execute_path}/server_$i/
  mkdir ${execute_path}/server_$i/
  cd ${execute_path}/server_$i/ || exit
  export RANK_ID=$i
  export DEVICE_ID=$i
  python -s ${self_path}/../train_and_eval_parameter_server.py --data_path=$DATASET --epochs=$EPOCH_SIZE --parameter_server=1 >server_$i.log 2>&1 &
done

# Worker processes, one per MS_WORKER_NUM.
export MS_ROLE=MS_WORKER
for((i=0;i<$MS_WORKER_NUM;i++));
do
  rm -rf ${execute_path}/worker_$i/
  mkdir ${execute_path}/worker_$i/
  cd ${execute_path}/worker_$i/ || exit
  export RANK_ID=$i
  export DEVICE_ID=$i
  python -s ${self_path}/../train_and_eval_parameter_server.py --data_path=$DATASET --epochs=$EPOCH_SIZE --parameter_server=1 >worker_$i.log 2>&1 &
done
model_zoo/official/recommend/wide_and_deep/src/config.py
浏览文件 @
b109e6f6
...
...
@@ -40,6 +40,7 @@ def argparse_init():
parser
.
add_argument
(
"--loss_file_name"
,
type
=
str
,
default
=
"loss.log"
)
parser
.
add_argument
(
"--host_device_mix"
,
type
=
int
,
default
=
0
)
parser
.
add_argument
(
"--dataset_type"
,
type
=
str
,
default
=
"tfrecord"
)
parser
.
add_argument
(
"--parameter_server"
,
type
=
int
,
default
=
0
)
return
parser
...
...
@@ -72,6 +73,7 @@ class WideDeepConfig():
self
.
ckpt_path
=
"./checkpoints/"
self
.
host_device_mix
=
0
self
.
dataset_type
=
"tfrecord"
self
.
parameter_server
=
0
def
argparse_init
(
self
):
"""
...
...
@@ -103,3 +105,4 @@ class WideDeepConfig():
self
.
ckpt_path
=
args
.
ckpt_path
self
.
host_device_mix
=
args
.
host_device_mix
self
.
dataset_type
=
args
.
dataset_type
self
.
parameter_server
=
args
.
parameter_server
model_zoo/official/recommend/wide_and_deep/src/wide_and_deep.py
浏览文件 @
b109e6f6
...
...
@@ -108,6 +108,9 @@ class DenseLayer(nn.Cell):
return
act_func
def
construct
(
self
,
x
):
'''
Construct Dense layer
'''
if
self
.
training
and
self
.
drop_out
:
x
=
self
.
dropout
(
x
)
if
self
.
convert_dtype
:
...
...
@@ -138,6 +141,7 @@ class WideDeepModel(nn.Cell):
super
(
WideDeepModel
,
self
).
__init__
()
self
.
batch_size
=
config
.
batch_size
host_device_mix
=
bool
(
config
.
host_device_mix
)
parameter_server
=
bool
(
config
.
parameter_server
)
parallel_mode
=
_get_parallel_mode
()
is_auto_parallel
=
parallel_mode
in
(
ParallelMode
.
SEMI_AUTO_PARALLEL
,
ParallelMode
.
AUTO_PARALLEL
)
if
is_auto_parallel
:
...
...
@@ -164,6 +168,9 @@ class WideDeepModel(nn.Cell):
self
.
wide_w
=
var_map
[
"Wide_w"
]
self
.
wide_b
=
var_map
[
"Wide_b"
]
self
.
embedding_table
=
var_map
[
"V_l2"
]
if
parameter_server
:
self
.
wide_w
.
set_param_ps
()
self
.
embedding_table
.
set_param_ps
()
self
.
dense_layer_1
=
DenseLayer
(
self
.
all_dim_list
[
0
],
self
.
all_dim_list
[
1
],
self
.
weight_bias_init
,
...
...
@@ -209,6 +216,9 @@ class WideDeepModel(nn.Cell):
self
.
deep_mul
.
set_strategy
(((
1
,
1
,
get_group_size
()),
(
1
,
1
,
1
)))
self
.
deep_reshape
.
add_prim_attr
(
"skip_redistribution"
,
True
)
self
.
reduce_sum
.
add_prim_attr
(
"cross_batch"
,
True
)
elif
parameter_server
:
self
.
deep_embeddinglookup
=
nn
.
EmbeddingLookup
()
self
.
wide_embeddinglookup
=
nn
.
EmbeddingLookup
()
else
:
self
.
deep_embeddinglookup
=
nn
.
EmbeddingLookup
(
target
=
'DEVICE'
)
self
.
wide_embeddinglookup
=
nn
.
EmbeddingLookup
(
target
=
'DEVICE'
)
...
...
@@ -249,9 +259,10 @@ class NetWithLossClass(nn.Cell):
def
__init__
(
self
,
network
,
config
):
super
(
NetWithLossClass
,
self
).
__init__
(
auto_prefix
=
False
)
host_device_mix
=
bool
(
config
.
host_device_mix
)
parameter_server
=
bool
(
config
.
parameter_server
)
parallel_mode
=
_get_parallel_mode
()
is_auto_parallel
=
parallel_mode
in
(
ParallelMode
.
SEMI_AUTO_PARALLEL
,
ParallelMode
.
AUTO_PARALLEL
)
self
.
no_l2loss
=
host_device_mix
and
is_auto_parallel
self
.
no_l2loss
=
(
is_auto_parallel
if
host_device_mix
else
parameter_server
)
self
.
network
=
network
self
.
l2_coef
=
config
.
l2_coef
self
.
loss
=
P
.
SigmoidCrossEntropyWithLogits
()
...
...
@@ -262,6 +273,9 @@ class NetWithLossClass(nn.Cell):
self
.
reduceSum_false
=
P
.
ReduceSum
(
keep_dims
=
False
)
def
construct
(
self
,
batch_ids
,
batch_wts
,
label
):
'''
Construct NetWithLossClass
'''
predict
,
embedding_table
=
self
.
network
(
batch_ids
,
batch_wts
)
log_loss
=
self
.
loss
(
predict
,
label
)
wide_loss
=
self
.
reduceMean_false
(
log_loss
)
...
...
@@ -294,9 +308,10 @@ class TrainStepWrap(nn.Cell):
network (Cell): The training network. Note that loss function should have been added.
sens (Number): The adjust parameter. Default: 1024.0
host_device_mix (Bool): Whether run in host and device mix mode. Default: False
parameter_server (Bool): Whether run in parameter server mode. Default: False
"""
def
__init__
(
self
,
network
,
sens
=
1024.0
,
host_device_mix
=
False
):
def
__init__
(
self
,
network
,
sens
=
1024.0
,
host_device_mix
=
False
,
parameter_server
=
False
):
super
(
TrainStepWrap
,
self
).
__init__
()
parallel_mode
=
_get_parallel_mode
()
is_auto_parallel
=
parallel_mode
in
(
ParallelMode
.
SEMI_AUTO_PARALLEL
,
ParallelMode
.
AUTO_PARALLEL
)
...
...
@@ -320,6 +335,13 @@ class TrainStepWrap(nn.Cell):
l1
=
1e-8
,
l2
=
1e-8
,
initial_accum
=
1.0
,
loss_scale
=
sens
)
self
.
optimizer_w
.
sparse_opt
.
add_prim_attr
(
"primitive_target"
,
"CPU"
)
self
.
optimizer_d
.
sparse_opt
.
add_prim_attr
(
"primitive_target"
,
"CPU"
)
elif
parameter_server
:
self
.
optimizer_d
=
Adam
(
self
.
weights_d
,
learning_rate
=
3.5e-4
,
eps
=
1e-8
,
loss_scale
=
sens
)
self
.
optimizer_w
=
FTRL
(
learning_rate
=
5e-2
,
params
=
self
.
weights_w
,
l1
=
1e-8
,
l2
=
1e-8
,
initial_accum
=
1.0
,
loss_scale
=
sens
)
self
.
optimizer_w
.
sparse_opt
.
add_prim_attr
(
"primitive_target"
,
"CPU"
)
self
.
optimizer_d
.
sparse_opt
.
add_prim_attr
(
"primitive_target"
,
"CPU"
)
else
:
self
.
optimizer_d
=
Adam
(
self
.
weights_d
,
learning_rate
=
3.5e-4
,
eps
=
1e-8
,
loss_scale
=
sens
)
...
...
@@ -347,6 +369,9 @@ class TrainStepWrap(nn.Cell):
self
.
grad_reducer_d
=
DistributedGradReducer
(
self
.
optimizer_d
.
parameters
,
mean
,
degree
)
def
construct
(
self
,
batch_ids
,
batch_wts
,
label
):
'''
Construct wide and deep model
'''
weights_w
=
self
.
weights_w
weights_d
=
self
.
weights_d
loss_w
,
loss_d
=
self
.
network
(
batch_ids
,
batch_wts
,
label
)
...
...
model_zoo/official/recommend/wide_and_deep/train_and_eval_parameter_server.py
0 → 100644
浏览文件 @
b109e6f6
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""train_multinpu."""
import
os
import
sys
import
numpy
as
np
from
mindspore
import
Model
,
context
from
mindspore.train.callback
import
ModelCheckpoint
,
CheckpointConfig
,
TimeMonitor
from
mindspore.train
import
ParallelMode
from
mindspore.communication.management
import
get_rank
,
get_group_size
,
init
from
src.wide_and_deep
import
PredictWithSigmoid
,
TrainStepWrap
,
NetWithLossClass
,
WideDeepModel
from
src.callbacks
import
LossCallBack
,
EvalCallBack
from
src.datasets
import
create_dataset
,
DataType
from
src.metrics
import
AUCMetric
from
src.config
import
WideDeepConfig
# Make the package root importable when this script is run directly.
# NOTE(review): this append happens after the `src.*` imports above, so those
# imports rely on the launcher's working directory already being on sys.path —
# confirm against how run_parameter_server_train.sh invokes this file.
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
# Sparse tensors are enabled globally; the parameter-server EmbeddingLookup
# path in wide_and_deep.py depends on it.
context.set_context(enable_sparse=True)
def get_WideDeep_net(config):
    """
    Assemble the Wide&Deep networks for one worker.

    Returns a pair ``(train_net, eval_net)``: the loss-wrapped training
    step (parameter-server aware) and the sigmoid prediction network,
    both sharing a single backbone model.
    """
    backbone = WideDeepModel(config)
    with_loss = NetWithLossClass(backbone, config)
    return (
        TrainStepWrap(with_loss, parameter_server=bool(config.parameter_server)),
        PredictWithSigmoid(backbone),
    )
class ModelBuilder():
    """
    Factory for the Wide&Deep train/eval networks and training hooks.
    """
    def __init__(self):
        pass

    def get_hook(self):
        # Placeholder for framework hooks; intentionally a no-op.
        pass

    def get_train_hook(self):
        """
        Return the list of training callbacks (currently only LossCallBack).
        """
        hooks = []
        callback = LossCallBack()
        hooks.append(callback)

        # Fix: os.getenv returns None when DEVICE_ID is unset, and int(None)
        # raises TypeError. Default to "0" so runs without the env var set
        # do not crash here. The branch body is still a deliberate no-op.
        if int(os.getenv('DEVICE_ID', '0')) == 0:
            pass
        return hooks

    def get_net(self, config):
        """Build (train_net, eval_net) for the given config."""
        return get_WideDeep_net(config)
def train_and_eval(config):
    """
    Train and evaluate Wide&Deep on one worker.

    Builds the train/eval datasets, assembles the networks through
    ModelBuilder, and runs Model.train with time/eval/loss/checkpoint
    callbacks. Dataset sink mode is disabled when running in
    parameter-server mode.
    """
    np.random.seed(1000)
    data_path = config.data_path
    batch_size = config.batch_size
    epochs = config.epochs
    # Map the config string onto the dataset backend; anything that is
    # neither "tfrecord" nor "mindrecord" falls back to H5.
    if config.dataset_type == "tfrecord":
        dataset_type = DataType.TFRECORD
    elif config.dataset_type == "mindrecord":
        dataset_type = DataType.MINDRECORD
    else:
        dataset_type = DataType.H5
    parameter_server = bool(config.parameter_server)
    print("epochs is {}".format(epochs))
    # NOTE(review): datasets are built with epochs=1 while model.train below
    # receives `epochs` — presumably Model re-iterates the dataset per epoch;
    # confirm against create_dataset's contract.
    ds_train = create_dataset(data_path, train_mode=True, epochs=1,
                              batch_size=batch_size, rank_id=get_rank(),
                              rank_size=get_group_size(), data_type=dataset_type)
    ds_eval = create_dataset(data_path, train_mode=False, epochs=1,
                             batch_size=batch_size, rank_id=get_rank(),
                             rank_size=get_group_size(), data_type=dataset_type)
    print("ds_train.size: {}".format(ds_train.get_dataset_size()))
    print("ds_eval.size: {}".format(ds_eval.get_dataset_size()))

    net_builder = ModelBuilder()
    train_net, eval_net = net_builder.get_net(config)
    train_net.set_train()
    auc_metric = AUCMetric()

    model = Model(train_net, eval_network=eval_net, metrics={"auc": auc_metric})

    eval_callback = EvalCallBack(model, ds_eval, auc_metric, config)

    callback = LossCallBack(config=config)
    # Checkpoint once per epoch (save interval equals dataset size).
    ckptconfig = CheckpointConfig(save_checkpoint_steps=ds_train.get_dataset_size(),
                                  keep_checkpoint_max=5)
    # NOTE(review): ckpoint_cb is only bound for "Ascend" and "GPU"; any other
    # device_target would raise NameError at model.train below — confirm that
    # only these two targets can reach this function.
    if config.device_target == "Ascend":
        ckpoint_cb = ModelCheckpoint(prefix='widedeep_train',
                                     directory=config.ckpt_path, config=ckptconfig)
    elif config.device_target == "GPU":
        # On GPU each rank writes its own checkpoint prefix.
        ckpoint_cb = ModelCheckpoint(prefix='widedeep_train_' + str(get_rank()),
                                     directory=config.ckpt_path, config=ckptconfig)
    # Sink mode is enabled only when NOT in parameter-server mode.
    model.train(epochs, ds_train,
                callbacks=[TimeMonitor(ds_train.get_dataset_size()), eval_callback,
                           callback, ckpoint_cb],
                dataset_sink_mode=(not parameter_server))
if __name__ == "__main__":
    wide_deep_config = WideDeepConfig()
    # Overwrite defaults with command-line arguments.
    wide_deep_config.argparse_init()

    context.set_context(mode=context.GRAPH_MODE,
                        device_target=wide_deep_config.device_target)
    # Pick the collective-communication backend matching the device target.
    if wide_deep_config.device_target == "Ascend":
        init("hccl")
    elif wide_deep_config.device_target == "GPU":
        init("nccl")
    # Data-parallel training across all devices in the group, with
    # gradient mirroring (mean) across ranks.
    context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL,
                                      mirror_mean=True,
                                      device_num=get_group_size())
    train_and_eval(wide_deep_config)
tests/st/ps/full_ps/run_full_ps_lenet.sh
0 → 100644
浏览文件 @
b109e6f6
#!/bin/bash
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

# Launch the full parameter-server LeNet ST test.
# Usage:
#   bash run_full_ps_lenet.sh DEVICE_TARGET WORKER_NUM SERVER_NUM SCHED_HOST SCHED_PORT
# Spawns 1 scheduler, $3 servers and $2 workers, each in its own fresh
# working directory.
execute_path=$(pwd)
# Fix: script_self was never assigned, so self_path silently became
# $(dirname "") == "." and the python path below only resolved correctly
# when the script was run from its own directory. Resolve the script's real
# location instead; test_full_ps_lenet.py lives in the SAME directory as
# this script, so no "/.." is needed.
script_self=$(readlink -f "$0")
self_path=$(dirname "${script_self}")
export MS_COMM_TYPE=zmq
export MS_SCHED_NUM=1
DEVICE_TARGET=$1
export MS_WORKER_NUM=$2
export MS_SERVER_NUM=$3
export MS_SCHED_HOST=$4
export MS_SCHED_PORT=$5

# Scheduler: exactly one process.
export MS_ROLE=MS_SCHED
for((i=0;i<1;i++));
do
  rm -rf ${execute_path}/sched_$i/
  mkdir ${execute_path}/sched_$i/
  cd ${execute_path}/sched_$i/ || exit
  export RANK_ID=$i
  export DEVICE_ID=$i
  python -s ${self_path}/test_full_ps_lenet.py --device_target=$DEVICE_TARGET &
done

# Parameter-server processes.
export MS_ROLE=MS_PSERVER
for((i=0;i<$MS_SERVER_NUM;i++));
do
  rm -rf ${execute_path}/server_$i/
  mkdir ${execute_path}/server_$i/
  cd ${execute_path}/server_$i/ || exit
  export RANK_ID=$i
  export DEVICE_ID=$i
  python -s ${self_path}/test_full_ps_lenet.py --device_target=$DEVICE_TARGET &
done

# Worker processes.
export MS_ROLE=MS_WORKER
for((i=0;i<$MS_WORKER_NUM;i++));
do
  rm -rf ${execute_path}/worker_$i/
  mkdir ${execute_path}/worker_$i/
  cd ${execute_path}/worker_$i/ || exit
  export RANK_ID=$i
  export DEVICE_ID=$i
  python -s ${self_path}/test_full_ps_lenet.py --device_target=$DEVICE_TARGET &
done

# NOTE(review): `wait $!` only waits on the most recently launched worker;
# the test's exit status reflects that single process. Confirm this is the
# intended success criterion.
wait $!
exit $?
tests/st/ps/full_ps/test_full_ps_lenet.py
0 → 100644
浏览文件 @
b109e6f6
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import
os
import
argparse
import
mindspore.context
as
context
import
mindspore.dataset
as
ds
import
mindspore.dataset.transforms.c_transforms
as
C
import
mindspore.dataset.transforms.vision.c_transforms
as
CV
import
mindspore.nn
as
nn
from
mindspore.common
import
dtype
as
mstype
from
mindspore.dataset.transforms.vision
import
Inter
from
mindspore.nn.metrics
import
Accuracy
from
mindspore.train
import
Model
from
mindspore.train.callback
import
LossMonitor
from
mindspore.common.initializer
import
TruncatedNormal
# Command-line options: the target backend and the MNIST dataset location.
parser = argparse.ArgumentParser(description='test_ps_lenet')
parser.add_argument("--device_target", type=str, default="Ascend")
parser.add_argument("--dataset_path", type=str,
                    default="/home/workspace/mindspore_dataset/mnist")
# parse_known_args: silently ignore extra arguments injected by the
# run_full_ps_lenet.sh launcher.
args, _ = parser.parse_known_args()
device_target = args.device_target
dataset_path = args.dataset_path
context.set_context(mode=context.GRAPH_MODE, device_target=device_target)
def conv(in_channels, out_channels, kernel_size, stride=1, padding=0):
    """Build a bias-free, valid-padded Conv2d with truncated-normal weights."""
    return nn.Conv2d(
        in_channels,
        out_channels,
        kernel_size=kernel_size,
        stride=stride,
        padding=padding,
        weight_init=weight_variable(),
        has_bias=False,
        pad_mode="valid",
    )
def fc_with_initialize(input_channels, out_channels):
    """Build a Dense layer whose weight and bias use truncated-normal init."""
    w_init = weight_variable()
    b_init = weight_variable()
    return nn.Dense(input_channels, out_channels, w_init, b_init)
def weight_variable():
    """Return the shared weight initializer: TruncatedNormal with sigma 0.02."""
    sigma = 0.02
    return TruncatedNormal(sigma)
class LeNet5(nn.Cell):
    """Classic LeNet-5: two conv/pool stages followed by three dense layers."""

    def __init__(self, num_class=10, channel=1):
        super(LeNet5, self).__init__()
        self.num_class = num_class
        # Feature extractor.
        self.conv1 = conv(channel, 6, 5)
        self.conv2 = conv(6, 16, 5)
        # Classifier head; 16 feature maps of 5x5 remain after the second pool.
        self.fc1 = fc_with_initialize(16 * 5 * 5, 120)
        self.fc2 = fc_with_initialize(120, 84)
        self.fc3 = fc_with_initialize(84, self.num_class)
        self.relu = nn.ReLU()
        self.max_pool2d = nn.MaxPool2d(kernel_size=2, stride=2)
        self.flatten = nn.Flatten()

    def construct(self, x):
        """Forward pass: (conv->relu->pool) x2, flatten, fc->relu->fc->relu->fc."""
        out = self.max_pool2d(self.relu(self.conv1(x)))
        out = self.max_pool2d(self.relu(self.conv2(out)))
        out = self.flatten(out)
        out = self.relu(self.fc1(out))
        out = self.relu(self.fc2(out))
        return self.fc3(out)
def create_dataset(data_path, batch_size=32, repeat_size=1,
                   num_parallel_workers=1):
    """
    Build the MNIST pipeline for train or test.

    Images are resized to 32x32, rescaled to [0, 1], normalized with the
    MNIST mean/std constants, converted HWC->CHW, then shuffled, batched
    (dropping the remainder) and repeated.
    """
    # define dataset
    mnist_ds = ds.MnistDataset(data_path)

    resize_height, resize_width = 32, 32
    # Map raw uint8 pixels into [0, 1].
    rescale = 1.0 / 255.0
    shift = 0.0
    # MNIST normalization constants (mean 0.1307, std 0.3081).
    rescale_nml = 1 / 0.3081
    shift_nml = -1 * 0.1307 / 0.3081

    # define map operations
    resize_op = CV.Resize((resize_height, resize_width),
                          interpolation=Inter.LINEAR)  # Bilinear mode
    rescale_nml_op = CV.Rescale(rescale_nml, shift_nml)
    rescale_op = CV.Rescale(rescale, shift)
    hwc2chw_op = CV.HWC2CHW()
    type_cast_op = C.TypeCast(mstype.int32)

    # apply map operations on images
    mnist_ds = mnist_ds.map(input_columns="label", operations=type_cast_op,
                            num_parallel_workers=num_parallel_workers)
    mnist_ds = mnist_ds.map(input_columns="image", operations=resize_op,
                            num_parallel_workers=num_parallel_workers)
    mnist_ds = mnist_ds.map(input_columns="image", operations=rescale_op,
                            num_parallel_workers=num_parallel_workers)
    mnist_ds = mnist_ds.map(input_columns="image", operations=rescale_nml_op,
                            num_parallel_workers=num_parallel_workers)
    mnist_ds = mnist_ds.map(input_columns="image", operations=hwc2chw_op,
                            num_parallel_workers=num_parallel_workers)

    # apply DatasetOps
    buffer_size = 10000
    mnist_ds = mnist_ds.shuffle(buffer_size=buffer_size)  # 10000 as in LeNet train script
    mnist_ds = mnist_ds.batch(batch_size, drop_remainder=True)
    mnist_ds = mnist_ds.repeat(repeat_size)

    return mnist_ds
if __name__ == "__main__":
    network = LeNet5(10)
    # Place all trainable parameters on the parameter server.
    network.set_param_ps()
    net_loss = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True,
                                                reduction="mean")
    net_opt = nn.Momentum(network.trainable_params(), 0.01, 0.9)
    model = Model(network, net_loss, net_opt, metrics={"Accuracy": Accuracy()})
    # One epoch of training without dataset sink mode (parameter-server path).
    ds_train = create_dataset(os.path.join(dataset_path, "train"), 32, 1)
    model.train(1, ds_train, callbacks=[LossMonitor()], dataset_sink_mode=False)

    # Evaluate on the held-out split and gate the ST test on accuracy.
    ds_eval = create_dataset(os.path.join(dataset_path, "test"), 32, 1)
    acc = model.eval(ds_eval, dataset_sink_mode=False)
    print("Accuracy:", acc['Accuracy'])
    assert acc['Accuracy'] > 0.93
tests/st/ps/full_ps/test_run.py
0 → 100644
浏览文件 @
b109e6f6
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import
os
# @pytest.mark.level0
# @pytest.mark.platform_arm_ascend_training
# @pytest.mark.platform_x86_ascend_training
# @pytest.mark.env_onecard
def test_full_ps_ascend_lenet():
    """Run the full parameter-server LeNet test on Ascend; expect exit code 0."""
    cmd = "bash run_full_ps_lenet.sh Ascend 1 1 127.0.0.1 8088"
    return_code = os.system(cmd)
    assert return_code == 0
# @pytest.mark.level0
# @pytest.mark.platform_x86_gpu_training
# @pytest.mark.env_onecard
def test_full_ps_gpu_lenet():
    """Run the full parameter-server LeNet test on GPU; expect exit code 0."""
    cmd = "bash run_full_ps_lenet.sh GPU 1 1 127.0.0.1 8088"
    return_code = os.system(cmd)
    assert return_code == 0
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录