Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
magicwindyyd
mindspore
提交
348c67ac
M
mindspore
项目概览
magicwindyyd
/
mindspore
与 Fork 源项目一致
Fork自
MindSpore / mindspore
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
mindspore
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
348c67ac
编写于
8月 11, 2020
作者:
Z
ZPaC
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
PS support multi server.
上级
cde69647
变更
8
隐藏空白更改
内联
并排
Showing
8 changed file
with
549 addition
and
59 deletion
+549
-59
mindspore/ccsrc/frontend/parallel/ps/parameter_server.h
mindspore/ccsrc/frontend/parallel/ps/parameter_server.h
+5
-3
mindspore/ccsrc/frontend/parallel/ps/worker.h
mindspore/ccsrc/frontend/parallel/ps/worker.h
+4
-4
mindspore/ccsrc/frontend/parallel/ps/worker_proxy.h
mindspore/ccsrc/frontend/parallel/ps/worker_proxy.h
+203
-16
model_zoo/official/cv/resnet/scripts/run_parameter_server_train_gpu.sh
...icial/cv/resnet/scripts/run_parameter_server_train_gpu.sh
+23
-19
tests/st/ps/multi_full_ps/entry.py
tests/st/ps/multi_full_ps/entry.py
+14
-11
tests/st/ps/multi_full_ps/resnet.py
tests/st/ps/multi_full_ps/resnet.py
+283
-0
tests/st/ps/multi_full_ps/shell_run_test.sh
tests/st/ps/multi_full_ps/shell_run_test.sh
+11
-5
tests/st/ps/multi_full_ps/test_multi_full_ps.py
tests/st/ps/multi_full_ps/test_multi_full_ps.py
+6
-1
未找到文件。
mindspore/ccsrc/frontend/parallel/ps/parameter_server.h
浏览文件 @
348c67ac
...
...
@@ -257,6 +257,7 @@ void ParameterServer<T>::ServerHandler::HandleInitEmbeddings(const ::ps::KVMeta
const
::
ps
::
KVPairs
<
T
>
&
req_data
,
::
ps
::
KVPairs
<
T
>
*
res
)
{
std
::
unique_lock
<
std
::
mutex
>
lock
(
ps_
->
mutex
());
const
Key
&
key
=
req_data
.
keys
[
0
];
MS_LOG
(
INFO
)
<<
"Initializing embedding table for key:"
<<
key
;
std
::
shared_ptr
<
std
::
vector
<
std
::
shared_ptr
<
std
::
vector
<
size_t
>>>>
shapes
=
std
::
make_shared
<
std
::
vector
<
std
::
shared_ptr
<
std
::
vector
<
size_t
>>>>
();
std
::
shared_ptr
<
std
::
vector
<
size_t
>>
input_shape
=
std
::
make_shared
<
std
::
vector
<
size_t
>>
();
...
...
@@ -348,6 +349,8 @@ void ParameterServer<T>::InitWeightKeyToOptims(const Key &key, const int &optim_
}
weight_key_to_optims_
[
key
]
=
Util
::
optimizer_name
(
optim_id
);
weight_key_to_optim_op_
[
key
]
=
Util
::
optimizer_node_name
(
optim_id
);
MS_LOG
(
INFO
)
<<
"Initializing optimizer id for key:"
<<
key
<<
", optimizer name:"
<<
weight_key_to_optims_
[
key
]
<<
", optimizer op name:"
<<
weight_key_to_optim_op_
[
key
];
}
template
<
typename
T
>
...
...
@@ -355,7 +358,7 @@ void ParameterServer<T>::InitOptimInputsShape(const Keys &keys, const Values &va
InputsShapePtr
inputs_shape
=
std
::
make_shared
<
InputsShape
>
();
int
val_idx
=
0
;
const
Key
&
key
=
keys
[
0
];
MS_LOG
(
INFO
)
<<
"Initializing optimizer inputs shape for key:"
<<
key
;
if
(
optim_inputs_shape_
.
count
(
key
)
==
0
)
{
optim_inputs_shape_
[
key
]
=
inputs_shape
;
}
...
...
@@ -413,7 +416,7 @@ const CNodePtr ParameterServer<T>::GetCNode(const std::string &name) const {
template
<
typename
T
>
void
ParameterServer
<
T
>::
InitWeight
(
const
Key
&
key
,
const
WeightPtr
&
weight
)
{
MS_LOG
(
INFO
)
<<
"Initializing weight for key "
<<
key
;
MS_LOG
(
INFO
)
<<
"Initializing weight for key "
<<
key
<<
", server rank "
<<
rank_id_
;
if
((
weights_
.
count
(
key
)
==
0
)
||
(
is_embedding_
[
key
]
&&
weights_
.
count
(
key
)
!=
0
))
{
weights_
[
key
]
=
weight
;
tokens_
[
key
]
=
0
;
...
...
@@ -432,7 +435,6 @@ void ParameterServer<T>::InitGrad(const Key &key, const GradPtr &grad) {
template
<
typename
T
>
void
ParameterServer
<
T
>::
InitEmbeddingTable
(
const
Key
&
key
,
const
std
::
shared_ptr
<
std
::
vector
<
std
::
shared_ptr
<
std
::
vector
<
size_t
>>>>
&
shapes
)
{
MS_LOG
(
INFO
)
<<
"Initializing embedding table for key "
<<
key
;
std
::
shared_ptr
<
PServerKernel
>
lookup
=
std
::
make_shared
<
kernel
::
ps
::
EmbeddingLookUpPSKernel
>
(
rank_id_
,
pserver_num_
);
lookup
->
InitKernel
(
shapes
);
embedding_lookup_ops_
[
key
]
=
lookup
;
...
...
mindspore/ccsrc/frontend/parallel/ps/worker.h
浏览文件 @
348c67ac
...
...
@@ -89,7 +89,7 @@ void Worker<T>::Run() {
if
(
!::
ps
::
IsWorker
())
{
MS_LOG
(
EXCEPTION
)
<<
"The role is not worker."
;
}
kv_worker_
=
std
::
make_shared
<
WorkerProxy
<
T
>>
(
0
,
0
,
1
);
kv_worker_
=
std
::
make_shared
<
WorkerProxy
<
T
>>
(
0
,
0
,
1
,
2
);
running_
=
true
;
}
...
...
@@ -121,7 +121,7 @@ void Worker<T>::Pull(const size_t key, void *dev_addr, const size_t size) {
while
(
!
kv_worker_
->
IsReadyForPull
(
key
))
{
continue
;
}
kv_worker_
->
Wait
(
kv_worker_
->
ZPull
({
key
},
&
variables
)
);
kv_worker_
->
PullData
({
key
},
&
variables
);
auto
ret
=
memcpy_s
(
dev_addr
,
size
,
variables
.
data
(),
size
);
if
(
ret
!=
0
)
{
MS_LOG
(
EXCEPTION
)
<<
"memcpy_s error, errorno("
<<
ret
<<
")"
;
...
...
@@ -149,7 +149,7 @@ void Worker<T>::InitPSParamData(const std::vector<size_t> &keys, void *origin_ad
::
ps
::
SArray
<::
ps
::
Key
>
key
(
keys
);
::
ps
::
SArray
<
int
>
lens
;
lens
.
push_back
(
addr
.
size
());
kv_worker_
->
Wait
(
kv_worker_
->
ZPush
(
key
,
addr
,
lens
,
kInitWeightsCmd
)
);
kv_worker_
->
PushData
(
key
,
addr
,
lens
,
kInitWeightsCmd
);
init_keys_
[
key
[
0
]]
=
true
;
}
...
...
@@ -269,7 +269,6 @@ void Worker<T>::InitPSEmbeddingTable(const std::vector<size_t> &keys, std::vecto
}
template
<
typename
T
>
// Initialize parameters and optimizer kernels of Parameter Server.
void
Worker
<
T
>::
InitPSParamAndOptim
(
const
std
::
string
&
param_name
,
tensor
::
TensorPtr
tensor
)
{
void
*
param_data
=
tensor
->
data_c
();
size_t
param_size
=
LongToSize
(
tensor
->
data
().
nbytes
());
...
...
@@ -290,6 +289,7 @@ void Worker<T>::InitPSParamAndOptim(const std::string ¶m_name, tensor::Tenso
if
(
!
init
)
{
MS_LOG
(
INFO
)
<<
"Init paramter and optimizer in parameter server side for "
<<
param_name
<<
", whether init in server: "
<<
init_in_server
;
kv_worker_
->
AddKeyToServerId
(
param_key
);
if
(
!
init_in_server
)
{
InitPSParamData
({
param_key
},
param_data
,
param_size
);
}
...
...
mindspore/ccsrc/frontend/parallel/ps/worker_proxy.h
浏览文件 @
348c67ac
...
...
@@ -38,19 +38,26 @@ class WorkerProxy : public ::ps::KVWorker<T> {
using
Slicer
=
std
::
function
<
void
(
int
ts
,
const
::
ps
::
KVPairs
<
T
>
&
send
,
const
std
::
vector
<::
ps
::
Range
>
&
ranges
,
SlicedKVs
*
sliced
)
>
;
using
::
ps
::
SimpleApp
::
obj_
;
explicit
WorkerProxy
(
int
app_id
,
int
customer_id
,
int
lookup_customer_id
)
:
Worker
(
app_id
,
customer_id
)
{
explicit
WorkerProxy
(
int
app_id
,
int
customer_id
,
int
lookup_customer_id
,
int
general_customer_id
)
:
Worker
(
app_id
,
customer_id
)
{
server_num_
=
::
ps
::
NumServers
();
using
std
::
placeholders
::
_1
;
using
std
::
placeholders
::
_2
;
using
std
::
placeholders
::
_3
;
using
std
::
placeholders
::
_4
;
lookup_customer_
=
std
::
unique_ptr
<::
ps
::
Customer
>
(
new
::
ps
::
Customer
(
app_id
,
lookup_customer_id
,
std
::
bind
(
&
WorkerProxy
<
T
>::
ProcessLookupResult
,
this
,
_1
)));
general_customer_
=
std
::
unique_ptr
<::
ps
::
Customer
>
(
new
::
ps
::
Customer
(
app_id
,
general_customer_id
,
std
::
bind
(
&
WorkerProxy
<
T
>::
ProcessResponse
,
this
,
_1
)));
lookup_slicer_
=
std
::
bind
(
&
WorkerProxy
<
T
>::
LookupIdSlicer
,
this
,
_1
,
_2
,
_3
,
_4
);
broadcast_slicer_
=
std
::
bind
(
&
WorkerProxy
<
T
>::
BroadcastSlicer
,
this
,
_1
,
_2
,
_3
,
_4
);
round_robin_slicer_
=
std
::
bind
(
&
WorkerProxy
<
T
>::
RoundRobinSlicer
,
this
,
_1
,
_2
,
_3
,
_4
);
worker_init_embedding_slicer_
=
std
::
bind
(
&
WorkerProxy
<
T
>::
WorkerInitEmbeddingSlicer
,
this
,
_1
,
_2
,
_3
,
_4
);
}
~
WorkerProxy
()
override
=
default
;
void
AddEmbeddingTable
(
const
::
ps
::
Key
&
key
,
const
size_t
&
row_count
);
void
AddKeyToServerId
(
const
::
ps
::
Key
&
key
);
void
EmbeddingLookup
(
const
::
ps
::
SArray
<::
ps
::
Key
>
&
keys
,
const
::
ps
::
SArray
<
int
>
&
lookup_ids
,
const
::
ps
::
SArray
<
int
>
&
lens
,
::
ps
::
SArray
<
T
>
*
outs
,
int
cmd
=
0
,
const
Callback
&
cb
=
nullptr
,
int
priority
=
0
);
...
...
@@ -60,37 +67,54 @@ class WorkerProxy : public ::ps::KVWorker<T> {
bool
IsReadyForPull
(
const
Key
&
key
);
void
PushData
(
const
::
ps
::
SArray
<::
ps
::
Key
>
&
keys
,
const
::
ps
::
SArray
<
T
>
&
vals
,
const
::
ps
::
SArray
<
int
>
&
lens
=
{},
int
cmd
=
0
,
int
priority
=
0
);
void
PullData
(
const
::
ps
::
SArray
<::
ps
::
Key
>
&
keys
,
::
ps
::
SArray
<
T
>
*
vals
,
::
ps
::
SArray
<
int
>
*
lens
=
nullptr
,
int
cmd
=
0
,
int
priority
=
0
);
void
Finalize
();
private:
template
<
typename
C
>
int
AddLookupCB
(
const
::
ps
::
SArray
<::
ps
::
Key
>
&
keys
,
const
::
ps
::
SArray
<
int
>
&
lookup_ids
,
C
*
vals
,
int
cmd
,
const
Callback
&
cb
);
int
AddGeneralRspCB
(
const
::
ps
::
SArray
<::
ps
::
Key
>
&
keys
,
::
ps
::
SArray
<
T
>
*
vals
,
::
ps
::
SArray
<
int
>
*
lens
,
int
cmd
,
const
Callback
&
cb
);
void
LookupIdSlicer
(
int
timestamp
,
const
::
ps
::
KVPairs
<
T
>
&
send
,
const
std
::
vector
<::
ps
::
Range
>
&
,
std
::
vector
<
std
::
pair
<
bool
,
::
ps
::
KVPairs
<
T
>>>
*
sliced
);
void
BroadcastSlicer
(
int
timestamp
,
const
::
ps
::
KVPairs
<
T
>
&
send
,
const
std
::
vector
<::
ps
::
Range
>
&
,
std
::
vector
<
std
::
pair
<
bool
,
::
ps
::
KVPairs
<
T
>>>
*
sliced
);
void
RoundRobinSlicer
(
int
timestamp
,
const
::
ps
::
KVPairs
<
T
>
&
send
,
const
std
::
vector
<::
ps
::
Range
>
&
,
std
::
vector
<
std
::
pair
<
bool
,
::
ps
::
KVPairs
<
T
>>>
*
sliced
);
void
WorkerInitEmbeddingSlicer
(
int
timestamp
,
const
::
ps
::
KVPairs
<
T
>
&
send
,
const
std
::
vector
<::
ps
::
Range
>
&
,
std
::
vector
<
std
::
pair
<
bool
,
::
ps
::
KVPairs
<
T
>>>
*
sliced
);
void
ProcessLookupResult
(
const
::
ps
::
Message
&
msg
);
void
ProcessResponse
(
const
::
ps
::
Message
&
msg
);
void
Send
(
::
ps
::
Customer
*
customer
,
int
timestamp
,
bool
push
,
bool
pull
,
int
cmd
,
const
::
ps
::
KVPairs
<
T
>
&
kvs
,
const
Slicer
&
slicer
);
void
AddKeyByHashMod
(
const
::
ps
::
Key
&
key
);
int
server_num_
;
std
::
unique_ptr
<::
ps
::
Customer
>
lookup_customer_
;
std
::
unique_ptr
<::
ps
::
Customer
>
general_customer_
;
std
::
unordered_map
<::
ps
::
Key
,
std
::
shared_ptr
<
std
::
vector
<::
ps
::
Range
>>>
embedding_table_ranges_
;
std
::
unordered_map
<
int
,
std
::
vector
<::
ps
::
KVPairs
<
T
>>>
lookup_results_
;
std
::
unordered_map
<
int
,
::
ps
::
KVPairs
<
T
>>
gathered_response_
;
std
::
mutex
mutex_
;
Slicer
lookup_slicer_
;
Slicer
broadcast_slicer_
;
Slicer
round_robin_slicer_
;
Slicer
worker_init_embedding_slicer_
;
std
::
unordered_map
<
int
,
Callback
>
lookup_callbacks_
;
std
::
unordered_map
<
int
,
Callback
>
general_callbacks_
;
std
::
unordered_map
<
int
,
int
>
expected_result_count_
;
std
::
unordered_map
<::
ps
::
Key
,
int
>
key_to_server_id_
;
std
::
unordered_map
<::
ps
::
Key
,
size_t
>
embedding_row_cnt_
;
};
template
<
typename
T
>
void
WorkerProxy
<
T
>::
AddEmbeddingTable
(
const
::
ps
::
Key
&
key
,
const
size_t
&
row_count
)
{
uint64_t
begin
=
0
;
uint64_t
end
=
0
;
int
server_num
=
::
ps
::
NumServers
();
for
(
int
i
=
0
;
i
<
server_num
;
i
++
)
{
int
local_row_cnt
=
Util
::
LocalShard
(
row_count
,
i
,
server_num
);
for
(
int
i
=
0
;
i
<
server_num_
;
i
++
)
{
int
local_row_cnt
=
Util
::
LocalShard
(
row_count
,
i
,
server_num_
);
if
(
i
==
0
)
{
end
=
local_row_cnt
-
1
;
}
else
{
...
...
@@ -103,6 +127,21 @@ void WorkerProxy<T>::AddEmbeddingTable(const ::ps::Key &key, const size_t &row_c
}
embedding_table_ranges_
[
key
]
->
push_back
(
range
);
}
embedding_row_cnt_
[
key
]
=
row_count
;
}
template
<
typename
T
>
void
WorkerProxy
<
T
>::
AddKeyByHashMod
(
const
::
ps
::
Key
&
key
)
{
if
(
server_num_
==
0
)
{
MS_LOG
(
EXCEPTION
)
<<
"Server number is invalid:0"
;
}
key_to_server_id_
[
key
]
=
static_cast
<
int
>
(
key
%
server_num_
);
MS_LOG
(
INFO
)
<<
"The server id of key "
<<
key
<<
" is "
<<
key_to_server_id_
[
key
];
}
template
<
typename
T
>
void
WorkerProxy
<
T
>::
AddKeyToServerId
(
const
::
ps
::
Key
&
key
)
{
AddKeyByHashMod
(
key
);
}
template
<
typename
T
>
...
...
@@ -116,9 +155,8 @@ void WorkerProxy<T>::EmbeddingLookup(const ::ps::SArray<::ps::Key> &keys, const
kvs
.
priority
=
priority
;
expected_result_count_
[
ts
]
=
0
;
Send
(
lookup_customer_
.
get
(),
ts
,
true
,
true
,
cmd
,
kvs
,
lookup_slicer_
);
int
server_num
=
::
ps
::
NumServers
();
int
expect_rt_count
=
expected_result_count_
[
ts
];
lookup_customer_
->
AddResponse
(
ts
,
server_num
-
expect_rt_count
);
lookup_customer_
->
AddResponse
(
ts
,
server_num
_
-
expect_rt_count
);
lookup_customer_
->
WaitRequest
(
ts
);
expected_result_count_
.
erase
(
ts
);
}
...
...
@@ -139,7 +177,7 @@ int WorkerProxy<T>::InitEmbeddingTable(const ::ps::SArray<::ps::Key> &keys, cons
template
<
typename
T
>
bool
WorkerProxy
<
T
>::
IsReadyForPush
(
const
Key
&
key
)
{
::
ps
::
SArray
<
T
>
result
(
1
,
0
);
this
->
Wait
(
this
->
ZPull
({
key
},
&
result
,
nullptr
,
kCheckReadyForPushCmd
)
);
PullData
({
key
},
&
result
,
nullptr
,
kCheckReadyForPushCmd
);
if
(
result
[
0
]
>
0
)
{
return
true
;
}
else
{
...
...
@@ -150,7 +188,7 @@ bool WorkerProxy<T>::IsReadyForPush(const Key &key) {
template
<
typename
T
>
bool
WorkerProxy
<
T
>::
IsReadyForPull
(
const
Key
&
key
)
{
::
ps
::
SArray
<
T
>
result
(
1
,
0
);
this
->
Wait
(
this
->
ZPull
({
key
},
&
result
,
nullptr
,
kCheckReadyForPullCmd
)
);
PullData
({
key
},
&
result
,
nullptr
,
kCheckReadyForPullCmd
);
if
(
result
[
0
]
>
0
)
{
return
true
;
}
else
{
...
...
@@ -161,14 +199,43 @@ bool WorkerProxy<T>::IsReadyForPull(const Key &key) {
template
<
typename
T
>
void
WorkerProxy
<
T
>::
PushData
(
const
::
ps
::
SArray
<::
ps
::
Key
>
&
keys
,
const
::
ps
::
SArray
<
T
>
&
vals
,
const
::
ps
::
SArray
<
int
>
&
lens
,
int
cmd
,
int
priority
)
{
int
ts
=
obj_
->
NewRequest
(
::
ps
::
kServerGroup
);
int
ts
=
AddGeneralRspCB
(
keys
,
nullptr
,
nullptr
,
cmd
,
nullptr
);
::
ps
::
KVPairs
<
T
>
kvs
;
kvs
.
keys
=
keys
;
kvs
.
vals
=
vals
;
kvs
.
lens
=
lens
;
kvs
.
priority
=
priority
;
Send
(
obj_
,
ts
,
true
,
false
,
cmd
,
kvs
,
broadcast_slicer_
);
obj_
->
WaitRequest
(
ts
);
if
(
embedding_table_ranges_
.
count
(
keys
[
0
]))
{
if
(
cmd
==
kInitWeightsCmd
)
{
Send
(
general_customer_
.
get
(),
ts
,
true
,
false
,
cmd
,
kvs
,
worker_init_embedding_slicer_
);
}
else
{
Send
(
general_customer_
.
get
(),
ts
,
true
,
false
,
cmd
,
kvs
,
broadcast_slicer_
);
}
}
else
{
Send
(
general_customer_
.
get
(),
ts
,
true
,
false
,
cmd
,
kvs
,
round_robin_slicer_
);
}
if
(
expected_result_count_
[
ts
]
<
server_num_
)
{
general_customer_
->
AddResponse
(
ts
,
server_num_
-
expected_result_count_
[
ts
]);
}
general_customer_
->
WaitRequest
(
ts
);
}
template
<
typename
T
>
void
WorkerProxy
<
T
>::
PullData
(
const
::
ps
::
SArray
<::
ps
::
Key
>
&
keys
,
::
ps
::
SArray
<
T
>
*
vals
,
::
ps
::
SArray
<
int
>
*
lens
,
int
cmd
,
int
priority
)
{
int
ts
=
AddGeneralRspCB
(
keys
,
vals
,
lens
,
cmd
,
nullptr
);
::
ps
::
KVPairs
<
T
>
kvs
;
kvs
.
keys
=
keys
;
kvs
.
priority
=
priority
;
if
(
embedding_table_ranges_
.
count
(
keys
[
0
]))
{
Send
(
general_customer_
.
get
(),
ts
,
false
,
true
,
cmd
,
kvs
,
broadcast_slicer_
);
}
else
{
Send
(
general_customer_
.
get
(),
ts
,
false
,
true
,
cmd
,
kvs
,
round_robin_slicer_
);
}
if
(
expected_result_count_
[
ts
]
<
server_num_
)
{
general_customer_
->
AddResponse
(
ts
,
server_num_
-
expected_result_count_
[
ts
]);
}
general_customer_
->
WaitRequest
(
ts
);
}
template
<
typename
T
>
...
...
@@ -192,8 +259,13 @@ int WorkerProxy<T>::AddLookupCB(const ::ps::SArray<::ps::Key> &keys, const ::ps:
auto
&
kvs
=
lookup_results_
[
ts
];
mutex_
.
unlock
();
auto
&
s
=
kvs
[
0
];
*
lookup_result
=
s
.
vals
;
::
ps
::
SArray
<
T
>
result
(
kvs
[
0
].
vals
.
size
(),
0
);
for
(
auto
k
:
kvs
)
{
for
(
size_t
i
=
0
;
i
<
k
.
vals
.
size
();
i
++
)
{
result
[
i
]
+=
k
.
vals
[
i
];
}
}
*
lookup_result
=
result
;
mutex_
.
lock
();
lookup_results_
.
erase
(
ts
);
...
...
@@ -204,6 +276,31 @@ int WorkerProxy<T>::AddLookupCB(const ::ps::SArray<::ps::Key> &keys, const ::ps:
return
ts
;
}
template
<
typename
T
>
int
WorkerProxy
<
T
>::
AddGeneralRspCB
(
const
::
ps
::
SArray
<::
ps
::
Key
>
&
keys
,
::
ps
::
SArray
<
T
>
*
vals
,
::
ps
::
SArray
<
int
>
*
lens
,
int
cmd
,
const
Callback
&
cb
)
{
int
ts
=
general_customer_
->
NewRequest
(
::
ps
::
kServerGroup
);
const
auto
&
callback
=
[
this
,
ts
,
keys
,
vals
,
lens
,
cb
]()
mutable
{
mutex_
.
lock
();
auto
&
kvs
=
gathered_response_
[
ts
];
mutex_
.
unlock
();
*
vals
=
kvs
.
vals
;
if
(
lens
)
{
*
lens
=
kvs
.
lens
;
}
mutex_
.
lock
();
gathered_response_
.
erase
(
ts
);
mutex_
.
unlock
();
if
(
cb
)
{
cb
();
}
};
general_callbacks_
[
ts
]
=
callback
;
return
ts
;
}
template
<
typename
T
>
void
WorkerProxy
<
T
>::
LookupIdSlicer
(
int
timestamp
,
const
::
ps
::
KVPairs
<
T
>
&
send
,
const
std
::
vector
<::
ps
::
Range
>
&
,
std
::
vector
<
std
::
pair
<
bool
,
::
ps
::
KVPairs
<
T
>>>
*
sliced
)
{
...
...
@@ -236,11 +333,70 @@ void WorkerProxy<T>::LookupIdSlicer(int timestamp, const ::ps::KVPairs<T> &send,
template
<
typename
T
>
void
WorkerProxy
<
T
>::
BroadcastSlicer
(
int
timestamp
,
const
::
ps
::
KVPairs
<
T
>
&
send
,
const
std
::
vector
<::
ps
::
Range
>
&
,
std
::
vector
<
std
::
pair
<
bool
,
::
ps
::
KVPairs
<
T
>>>
*
sliced
)
{
auto
server_num
=
::
ps
::
Postoffice
::
Get
()
->
num_servers
();
sliced
->
resize
(
server_num
);
for
(
int
i
=
0
;
i
<
server_num
;
i
++
)
{
sliced
->
resize
(
server_num_
);
for
(
int
i
=
0
;
i
<
server_num_
;
i
++
)
{
sliced
->
at
(
i
).
first
=
true
;
sliced
->
at
(
i
).
second
=
send
;
expected_result_count_
[
timestamp
]
+=
1
;
}
}
template
<
typename
T
>
void
WorkerProxy
<
T
>::
RoundRobinSlicer
(
int
timestamp
,
const
::
ps
::
KVPairs
<
T
>
&
send
,
const
std
::
vector
<::
ps
::
Range
>
&
,
std
::
vector
<
std
::
pair
<
bool
,
::
ps
::
KVPairs
<
T
>>>
*
sliced
)
{
sliced
->
resize
(
server_num_
);
auto
keys
=
send
.
keys
;
auto
vals
=
send
.
vals
;
auto
lens
=
send
.
lens
;
int
server_id
,
len
;
::
ps
::
Key
param_key
;
for
(
size_t
i
=
0
;
i
<
keys
.
size
();
i
++
)
{
param_key
=
keys
[
i
];
server_id
=
key_to_server_id_
[
param_key
];
if
(
!
sliced
->
at
(
server_id
).
first
)
{
sliced
->
at
(
server_id
).
first
=
true
;
expected_result_count_
[
timestamp
]
+=
1
;
}
::
ps
::
KVPairs
<
T
>
&
server_kv_pairs
=
sliced
->
at
(
server_id
).
second
;
server_kv_pairs
.
keys
.
push_back
(
param_key
);
if
(
vals
.
empty
())
{
continue
;
}
len
=
lens
[
i
];
int
offset
=
std
::
accumulate
(
lens
.
begin
(),
lens
.
begin
()
+
i
,
0
);
auto
val_begin
=
vals
.
begin
()
+
offset
;
auto
val_end
=
val_begin
+
len
;
for
(
auto
iter
=
val_begin
;
iter
!=
val_end
;
iter
++
)
{
server_kv_pairs
.
vals
.
push_back
(
*
iter
);
}
server_kv_pairs
.
lens
.
push_back
(
len
);
}
}
template
<
typename
T
>
void
WorkerProxy
<
T
>::
WorkerInitEmbeddingSlicer
(
int
timestamp
,
const
::
ps
::
KVPairs
<
T
>
&
send
,
const
std
::
vector
<::
ps
::
Range
>
&
,
std
::
vector
<
std
::
pair
<
bool
,
::
ps
::
KVPairs
<
T
>>>
*
sliced
)
{
sliced
->
resize
(
server_num_
);
auto
keys
=
send
.
keys
;
auto
vals
=
send
.
vals
;
auto
lens
=
send
.
lens
;
size_t
col_cnt
=
lens
[
0
]
/
embedding_row_cnt_
[
keys
[
0
]];
const
std
::
vector
<::
ps
::
Range
>
&
ranges
=
*
(
embedding_table_ranges_
[
keys
[
0
]]);
for
(
size_t
i
=
0
;
i
<
ranges
.
size
();
i
++
)
{
size_t
offset_begin
=
ranges
[
i
].
begin
()
*
col_cnt
;
size_t
offset_end
=
(
ranges
[
i
].
end
()
+
1
)
*
col_cnt
;
::
ps
::
KVPairs
<
T
>
kvs
;
kvs
.
keys
=
keys
;
kvs
.
vals
=
vals
.
segment
(
offset_begin
,
offset_end
);
kvs
.
lens
.
push_back
(
offset_end
-
offset_begin
);
sliced
->
at
(
i
).
first
=
true
;
sliced
->
at
(
i
).
second
=
kvs
;
}
}
...
...
@@ -266,6 +422,37 @@ void WorkerProxy<T>::ProcessLookupResult(const ::ps::Message &msg) {
}
}
template
<
typename
T
>
void
WorkerProxy
<
T
>::
ProcessResponse
(
const
::
ps
::
Message
&
msg
)
{
int
ts
=
msg
.
meta
.
timestamp
;
if
(
msg
.
meta
.
pull
)
{
CHECK_GE
(
msg
.
data
.
size
(),
(
size_t
)
2
);
::
ps
::
KVPairs
<
T
>
kvs
;
kvs
.
keys
=
msg
.
data
[
0
];
kvs
.
vals
=
msg
.
data
[
1
];
if
(
msg
.
data
.
size
()
>
(
size_t
)
2
)
{
kvs
.
lens
=
msg
.
data
[
2
];
}
mutex_
.
lock
();
for
(
auto
key
:
kvs
.
keys
)
{
gathered_response_
[
ts
].
keys
.
push_back
(
key
);
}
for
(
auto
val
:
kvs
.
vals
)
{
gathered_response_
[
ts
].
vals
.
push_back
(
val
);
}
for
(
auto
len
:
kvs
.
lens
)
{
gathered_response_
[
ts
].
lens
.
push_back
(
len
);
}
mutex_
.
unlock
();
if
(
general_customer_
->
NumResponse
(
ts
)
+
1
==
server_num_
)
{
const
auto
&
cb
=
general_callbacks_
[
ts
];
cb
();
general_callbacks_
.
erase
(
ts
);
}
}
}
template
<
typename
T
>
void
WorkerProxy
<
T
>::
Send
(
::
ps
::
Customer
*
customer
,
int
timestamp
,
bool
push
,
bool
pull
,
int
cmd
,
const
::
ps
::
KVPairs
<
T
>
&
kvs
,
const
Slicer
&
slicer
)
{
...
...
model_zoo/official/cv/resnet/scripts/run_parameter_server_train_gpu.sh
浏览文件 @
348c67ac
...
...
@@ -99,27 +99,31 @@ then
fi
cd
..
export
MS_ROLE
=
MS_PSERVER
rm
-rf
./server
mkdir
./server
cp
../
*
.py ./server
cp
*
.sh ./server
cp
-r
../src ./server
cd
./server
||
exit
if
[
$#
==
3
]
then
mpirun
--allow-run-as-root
-n
1
\
python train.py
--net
=
$1
--dataset
=
$2
--run_distribute
=
True
\
--device_num
=
$DEVICE_NUM
--device_target
=
"GPU"
--dataset_path
=
$PATH1
--parameter_server
=
True &> server.log &
fi
if
[
$#
==
4
]
then
mpirun
--allow-run-as-root
-n
1
\
for
((
i
=
0
;
i<
$MS_SERVER_NUM
;
i++
))
;
do
rm
-rf
./server_
$i
mkdir
./server_
$i
cp
../
*
.py ./server_
$i
cp
*
.sh ./server_
$i
cp
-r
../src ./server_
$i
cd
./server_
$i
||
exit
if
[
$#
==
3
]
then
mpirun
--allow-run-as-root
-n
1
\
python train.py
--net
=
$1
--dataset
=
$2
--run_distribute
=
True
\
--device_num
=
$DEVICE_NUM
--device_target
=
"GPU"
--dataset_path
=
$PATH1
--parameter_server
=
True
--pre_trained
=
$PATH2
&> server.log &
fi
cd
..
--device_num
=
$DEVICE_NUM
--device_target
=
"GPU"
--dataset_path
=
$PATH1
--parameter_server
=
True &> server_
$i
.log &
fi
if
[
$#
==
4
]
then
mpirun
--allow-run-as-root
-n
1
\
python train.py
--net
=
$1
--dataset
=
$2
--run_distribute
=
True
\
--device_num
=
$DEVICE_NUM
--device_target
=
"GPU"
--dataset_path
=
$PATH1
--parameter_server
=
True
--pre_trained
=
$PATH2
&> server_
$i
.log &
fi
cd
..
done
export
MS_ROLE
=
MS_WORKER
rm
-rf
./worker
...
...
tests/st/ps/multi_
worker_
full_ps/entry.py
→
tests/st/ps/multi_full_ps/entry.py
浏览文件 @
348c67ac
...
...
@@ -14,19 +14,22 @@
# ============================================================================
import
os
# @pytest.mark.level0
# @pytest.mark.platform_arm_ascend_training
# @pytest.mark.platform_x86_ascend_training
# @pytest.mark.env_single
def
test_multi_worker_full_ps_ascend_lenet
():
return_code
=
os
.
system
(
"bash shell_run_test.sh Ascend 8 1 127.0.0.1 8088"
)
def
test_ps_ascend_multi_worker_multi_server
():
return_code
=
os
.
system
(
"bash shell_run_test.sh Ascend 8 8 127.0.0.1 8088"
)
assert
return_code
==
0
# @pytest.mark.level0
# @pytest.mark.platform_arm_ascend_training
# @pytest.mark.platform_x86_ascend_training
# @pytest.mark.env_onecard
def
test_full_ps_ascend_lenet
():
def
test_ps_ascend
():
return_code
=
os
.
system
(
"bash shell_run_test.sh Ascend 1 1 127.0.0.1 8088"
)
assert
return_code
==
0
def
test_ps_gpu_multi_worker_multi_server
():
return_code
=
os
.
system
(
"bash shell_run_test.sh GPU 8 8 127.0.0.1 8088"
)
assert
return_code
==
0
def
test_ps_gpu
():
return_code
=
os
.
system
(
"bash shell_run_test.sh GPU 1 1 127.0.0.1 8088"
)
assert
return_code
==
0
tests/st/ps/multi_full_ps/resnet.py
0 → 100755
浏览文件 @
348c67ac
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""ResNet."""
import
numpy
as
np
import
mindspore.nn
as
nn
from
mindspore.ops
import
operations
as
P
from
mindspore.common.tensor
import
Tensor
def
_weight_variable
(
shape
,
factor
=
0.01
):
init_value
=
np
.
random
.
randn
(
*
shape
).
astype
(
np
.
float32
)
*
factor
return
Tensor
(
init_value
)
def
_conv3x3
(
in_channel
,
out_channel
,
stride
=
1
):
weight_shape
=
(
out_channel
,
in_channel
,
3
,
3
)
weight
=
_weight_variable
(
weight_shape
)
return
nn
.
Conv2d
(
in_channel
,
out_channel
,
kernel_size
=
3
,
stride
=
stride
,
padding
=
0
,
pad_mode
=
'same'
,
weight_init
=
weight
)
def
_conv1x1
(
in_channel
,
out_channel
,
stride
=
1
):
weight_shape
=
(
out_channel
,
in_channel
,
1
,
1
)
weight
=
_weight_variable
(
weight_shape
)
return
nn
.
Conv2d
(
in_channel
,
out_channel
,
kernel_size
=
1
,
stride
=
stride
,
padding
=
0
,
pad_mode
=
'same'
,
weight_init
=
weight
)
def
_conv7x7
(
in_channel
,
out_channel
,
stride
=
1
):
weight_shape
=
(
out_channel
,
in_channel
,
7
,
7
)
weight
=
_weight_variable
(
weight_shape
)
return
nn
.
Conv2d
(
in_channel
,
out_channel
,
kernel_size
=
7
,
stride
=
stride
,
padding
=
0
,
pad_mode
=
'same'
,
weight_init
=
weight
)
def
_bn
(
channel
):
return
nn
.
BatchNorm2d
(
channel
,
eps
=
1e-4
,
momentum
=
0.9
,
gamma_init
=
1
,
beta_init
=
0
,
moving_mean_init
=
0
,
moving_var_init
=
1
)
def
_bn_last
(
channel
):
return
nn
.
BatchNorm2d
(
channel
,
eps
=
1e-4
,
momentum
=
0.9
,
gamma_init
=
0
,
beta_init
=
0
,
moving_mean_init
=
0
,
moving_var_init
=
1
)
def
_fc
(
in_channel
,
out_channel
):
weight_shape
=
(
out_channel
,
in_channel
)
weight
=
_weight_variable
(
weight_shape
)
return
nn
.
Dense
(
in_channel
,
out_channel
,
has_bias
=
True
,
weight_init
=
weight
,
bias_init
=
0
)
class
ResidualBlock
(
nn
.
Cell
):
"""
ResNet V1 residual block definition.
Args:
in_channel (int): Input channel.
out_channel (int): Output channel.
stride (int): Stride size for the first convolutional layer. Default: 1.
Returns:
Tensor, output tensor.
Examples:
>>> ResidualBlock(3, 256, stride=2)
"""
expansion
=
4
def
__init__
(
self
,
in_channel
,
out_channel
,
stride
=
1
):
super
(
ResidualBlock
,
self
).
__init__
()
channel
=
out_channel
//
self
.
expansion
self
.
conv1
=
_conv1x1
(
in_channel
,
channel
,
stride
=
1
)
self
.
bn1
=
_bn
(
channel
)
self
.
conv2
=
_conv3x3
(
channel
,
channel
,
stride
=
stride
)
self
.
bn2
=
_bn
(
channel
)
self
.
conv3
=
_conv1x1
(
channel
,
out_channel
,
stride
=
1
)
self
.
bn3
=
_bn_last
(
out_channel
)
self
.
relu
=
nn
.
ReLU
()
self
.
down_sample
=
False
if
stride
!=
1
or
in_channel
!=
out_channel
:
self
.
down_sample
=
True
self
.
down_sample_layer
=
None
if
self
.
down_sample
:
self
.
down_sample_layer
=
nn
.
SequentialCell
([
_conv1x1
(
in_channel
,
out_channel
,
stride
),
_bn
(
out_channel
)])
self
.
add
=
P
.
TensorAdd
()
def
construct
(
self
,
x
):
identity
=
x
out
=
self
.
conv1
(
x
)
out
=
self
.
bn1
(
out
)
out
=
self
.
relu
(
out
)
out
=
self
.
conv2
(
out
)
out
=
self
.
bn2
(
out
)
out
=
self
.
relu
(
out
)
out
=
self
.
conv3
(
out
)
out
=
self
.
bn3
(
out
)
if
self
.
down_sample
:
identity
=
self
.
down_sample_layer
(
identity
)
out
=
self
.
add
(
out
,
identity
)
out
=
self
.
relu
(
out
)
return
out
class
ResNet
(
nn
.
Cell
):
"""
ResNet architecture.
Args:
block (Cell): Block for network.
layer_nums (list): Numbers of block in different layers.
in_channels (list): Input channel in each layer.
out_channels (list): Output channel in each layer.
strides (list): Stride size in each layer.
num_classes (int): The number of classes that the training images are belonging to.
Returns:
Tensor, output tensor.
Examples:
>>> ResNet(ResidualBlock,
>>> [3, 4, 6, 3],
>>> [64, 256, 512, 1024],
>>> [256, 512, 1024, 2048],
>>> [1, 2, 2, 2],
>>> 10)
"""
def
__init__
(
self
,
block
,
layer_nums
,
in_channels
,
out_channels
,
strides
,
num_classes
):
super
(
ResNet
,
self
).
__init__
()
if
not
len
(
layer_nums
)
==
len
(
in_channels
)
==
len
(
out_channels
)
==
4
:
raise
ValueError
(
"the length of layer_num, in_channels, out_channels list must be 4!"
)
self
.
conv1
=
_conv7x7
(
3
,
64
,
stride
=
2
)
self
.
bn1
=
_bn
(
64
)
self
.
relu
=
P
.
ReLU
()
self
.
maxpool
=
nn
.
MaxPool2d
(
kernel_size
=
3
,
stride
=
2
,
pad_mode
=
"same"
)
self
.
layer1
=
self
.
_make_layer
(
block
,
layer_nums
[
0
],
in_channel
=
in_channels
[
0
],
out_channel
=
out_channels
[
0
],
stride
=
strides
[
0
])
self
.
layer2
=
self
.
_make_layer
(
block
,
layer_nums
[
1
],
in_channel
=
in_channels
[
1
],
out_channel
=
out_channels
[
1
],
stride
=
strides
[
1
])
self
.
layer3
=
self
.
_make_layer
(
block
,
layer_nums
[
2
],
in_channel
=
in_channels
[
2
],
out_channel
=
out_channels
[
2
],
stride
=
strides
[
2
])
self
.
layer4
=
self
.
_make_layer
(
block
,
layer_nums
[
3
],
in_channel
=
in_channels
[
3
],
out_channel
=
out_channels
[
3
],
stride
=
strides
[
3
])
self
.
mean
=
P
.
ReduceMean
(
keep_dims
=
True
)
self
.
flatten
=
nn
.
Flatten
()
self
.
end_point
=
_fc
(
out_channels
[
3
],
num_classes
)
def
_make_layer
(
self
,
block
,
layer_num
,
in_channel
,
out_channel
,
stride
):
"""
Make stage network of ResNet.
Args:
block (Cell): Resnet block.
layer_num (int): Layer number.
in_channel (int): Input channel.
out_channel (int): Output channel.
stride (int): Stride size for the first convolutional layer.
Returns:
SequentialCell, the output layer.
Examples:
>>> _make_layer(ResidualBlock, 3, 128, 256, 2)
"""
layers
=
[]
resnet_block
=
block
(
in_channel
,
out_channel
,
stride
=
stride
)
layers
.
append
(
resnet_block
)
for
_
in
range
(
1
,
layer_num
):
resnet_block
=
block
(
out_channel
,
out_channel
,
stride
=
1
)
layers
.
append
(
resnet_block
)
return
nn
.
SequentialCell
(
layers
)
def
construct
(
self
,
x
):
x
=
self
.
conv1
(
x
)
x
=
self
.
bn1
(
x
)
x
=
self
.
relu
(
x
)
c1
=
self
.
maxpool
(
x
)
c2
=
self
.
layer1
(
c1
)
c3
=
self
.
layer2
(
c2
)
c4
=
self
.
layer3
(
c3
)
c5
=
self
.
layer4
(
c4
)
out
=
self
.
mean
(
c5
,
(
2
,
3
))
out
=
self
.
flatten
(
out
)
out
=
self
.
end_point
(
out
)
return
out
def
resnet50
(
class_num
=
10
):
"""
Get ResNet50 neural network.
Args:
class_num (int): Class number.
Returns:
Cell, cell instance of ResNet50 neural network.
Examples:
>>> net = resnet50(10)
"""
return
ResNet
(
ResidualBlock
,
[
3
,
4
,
6
,
3
],
[
64
,
256
,
512
,
1024
],
[
256
,
512
,
1024
,
2048
],
[
1
,
2
,
2
,
2
],
class_num
)
def
resnet101
(
class_num
=
1001
):
"""
Get ResNet101 neural network.
Args:
class_num (int): Class number.
Returns:
Cell, cell instance of ResNet101 neural network.
Examples:
>>> net = resnet101(1001)
"""
return
ResNet
(
ResidualBlock
,
[
3
,
4
,
23
,
3
],
[
64
,
256
,
512
,
1024
],
[
256
,
512
,
1024
,
2048
],
[
1
,
2
,
2
,
2
],
class_num
)
tests/st/ps/multi_
worker_
full_ps/shell_run_test.sh
→
tests/st/ps/multi_full_ps/shell_run_test.sh
浏览文件 @
348c67ac
...
...
@@ -30,9 +30,7 @@ do
rm
-rf
${
execute_path
}
/sched_
$i
/
mkdir
${
execute_path
}
/sched_
$i
/
cd
${
execute_path
}
/sched_
$i
/
||
exit
export
RANK_ID
=
$i
export
DEVICE_ID
=
$i
python
${
self_path
}
/../test_multi_worker_full_ps_lenet.py
--device_target
=
$DEVICE_TARGET
&
python
${
self_path
}
/../test_multi_full_ps.py
--device_target
=
$DEVICE_TARGET
&
done
export
MS_ROLE
=
MS_PSERVER
...
...
@@ -43,10 +41,11 @@ do
cd
${
execute_path
}
/server_
$i
/
||
exit
export
RANK_ID
=
$i
export
DEVICE_ID
=
$i
python
${
self_path
}
/../test_multi_
worker_full_ps_lenet
.py
--device_target
=
$DEVICE_TARGET
&
python
${
self_path
}
/../test_multi_
full_ps
.py
--device_target
=
$DEVICE_TARGET
&
done
export
MS_ROLE
=
MS_WORKER
if
[
$DEVICE_TARGET
==
"Ascend"
]
;
then
for
((
i
=
0
;
i<
$MS_WORKER_NUM
;
i++
))
;
do
rm
-rf
${
execute_path
}
/worker_
$i
/
...
...
@@ -54,8 +53,15 @@ do
cd
${
execute_path
}
/worker_
$i
/
||
exit
export
RANK_ID
=
$i
export
DEVICE_ID
=
$i
python
${
self_path
}
/../test_multi_
worker_full_ps_lenet
.py
--device_target
=
$DEVICE_TARGET
&
python
${
self_path
}
/../test_multi_
full_ps
.py
--device_target
=
$DEVICE_TARGET
&
done
fi
if
[
$DEVICE_TARGET
==
"GPU"
]
;
then
rm
-rf
${
execute_path
}
/worker/
mkdir
${
execute_path
}
/worker/
cd
${
execute_path
}
/worker/
||
exit
mpirun
-n
$MS_WORKER_NUM
python
${
self_path
}
/../test_multi_full_ps.py
--device_target
=
$DEVICE_TARGET
&
fi
wait
$!
exit
$?
tests/st/ps/multi_
worker_full_ps/test_multi_worker_full_ps_lenet
.py
→
tests/st/ps/multi_
full_ps/test_multi_full_ps
.py
浏览文件 @
348c67ac
...
...
@@ -21,12 +21,16 @@ import mindspore.nn as nn
from
mindspore.common.initializer
import
TruncatedNormal
from
mindspore
import
Tensor
from
mindspore.nn
import
TrainOneStepCell
,
WithLossCell
from
mindspore.communication.management
import
init
,
get_group_size
# from resnet import resnet50
parser
=
argparse
.
ArgumentParser
(
description
=
"test_ps_lenet"
)
parser
.
add_argument
(
"--device_target"
,
type
=
str
,
default
=
"Ascend"
)
args
,
_
=
parser
.
parse_known_args
()
device_target
=
args
.
device_target
context
.
set_context
(
mode
=
context
.
GRAPH_MODE
,
device_target
=
device_target
)
if
device_target
==
"GPU"
:
init
(
'nccl'
)
def
conv
(
in_channels
,
out_channels
,
kernel_size
,
stride
=
1
,
padding
=
0
):
...
...
@@ -94,7 +98,8 @@ if __name__ == "__main__":
is_grad
=
False
,
sparse
=
True
,
reduction
=
"mean"
)
net_opt
=
nn
.
Momentum
(
network
.
trainable_params
(),
0.01
,
0.9
)
if
device_target
==
"GPU"
:
context
.
set_auto_parallel_context
(
parallel_mode
=
"data_parallel"
,
mirror_mean
=
True
,
device_num
=
get_group_size
())
net_with_criterion
=
WithLossCell
(
network
,
criterion
)
train_network
=
TrainOneStepCell
(
net_with_criterion
,
net_opt
)
train_network
.
set_train
()
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录