magicwindyyd / mindspore
Forked from MindSpore / mindspore
Commit 2a9c4588
Authored Sep 08, 2020 by mindspore-ci-bot; committed via Gitee on Sep 08, 2020

!5812 Add PS context.
Merge pull request !5812 from ZPaC/master-context-for-ps

Parents: 39408806 87bf2a7d
Showing 24 changed files with 381 additions and 94 deletions (+381 -94)
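What the 24 diffs below add up to: role detection through ad-hoc MS_ROLE environment-variable reads is replaced by a single C++ PSContext singleton, exposed to Python and wrapped by a new public API in mindspore.context. A minimal usage sketch, assuming a MindSpore build that contains this commit and that the PS environment variables are exported by the launch script:

# Hedged sketch; the call names come from the context.py diff in this commit.
from mindspore import context

context.set_ps_context(enable_ps=True)   # enable PS mode; the process role is read from MS_ROLE here
# ... define the network, mark parameters with set_param_ps(), run training ...
context.reset_ps_context()               # restore the default (enable_ps=False)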
mindspore/ccsrc/backend/kernel_compiler/cpu/ps/embedding_look_up_proxy_kernel.cc  +1 -2
mindspore/ccsrc/frontend/parallel/ps/common.h  +0 -5
mindspore/ccsrc/frontend/parallel/ps/parameter_server.h  +2 -1
mindspore/ccsrc/frontend/parallel/ps/ps_context.cc  +86 -0
mindspore/ccsrc/frontend/parallel/ps/ps_context.h  +61 -0
mindspore/ccsrc/frontend/parallel/ps/util.cc  +6 -29
mindspore/ccsrc/frontend/parallel/ps/util.h  +0 -2
mindspore/ccsrc/frontend/parallel/ps/worker_proxy.h  +2 -1
mindspore/ccsrc/pipeline/jit/init.cc  +11 -3
mindspore/common/api.py  +2 -2
mindspore/common/parameter.py  +8 -2
mindspore/communication/_comm_helper.py  +4 -5
mindspore/communication/management.py  +2 -3
mindspore/context.py  +58 -1
mindspore/parallel/_ps_context.py  +115 -0
mindspore/parallel/_ps_utils.py  +0 -23
mindspore/train/callback/_checkpoint.py  +2 -2
mindspore/train/model.py  +3 -3
model_zoo/official/cv/resnet/train.py  +1 -0
model_zoo/official/nlp/bert_thor/src/model_thor.py  +0 -4
model_zoo/official/recommend/wide_and_deep/train_and_eval_parameter_server.py  +1 -0
tests/st/ps/cmp_sparse_embedding/test_cmp_sparse_embedding.py  +6 -4
tests/st/ps/full_ps/test_full_ps_lenet.py  +1 -0
tests/st/ps/multi_full_ps/test_multi_full_ps.py  +9 -2
mindspore/ccsrc/backend/kernel_compiler/cpu/ps/embedding_look_up_proxy_kernel.cc
@@ -41,8 +41,7 @@ void EmbeddingLookUpProxyKernel::InitKernel(const CNodePtr &kernel_node) {
   MS_LOG(INFO) << "Init embedding lookup proxy kernel, input shape:" << input_shape
                << ", indices_shape:" << indices_shape << ", output_shape:" << output_shape;
   std::vector<int> lens{SizeToInt(input_shape.size()), SizeToInt(indices_shape.size()), SizeToInt(output_shape.size())};
-  const char *env_role = getenv(mindspore::parallel::ps::kEnvRole);
-  if (env_role != nullptr && strcmp(env_role, mindspore::parallel::ps::kEnvRoleOfWorker) == 0) {
+  if (mindspore::parallel::ps::Util::IsRoleOfWorker()) {
     parallel::ps::Worker<float>::GetInstance().AddEmbeddingTable(key_, input_shape[axis]);
     parallel::ps::Worker<float>::GetInstance().InitPSEmbeddingTable(keys, values, lens);
   }
mindspore/ccsrc/frontend/parallel/ps/common.h
@@ -32,11 +32,6 @@ constexpr char kEnvWorkerNum[] = "MS_WORKER_NUM";
 constexpr char kEnvSchedulerHost[] = "MS_SCHED_HOST";
 constexpr char kEnvSchedulerPort[] = "MS_SCHED_PORT";
-constexpr char kEnvRole[] = "MS_ROLE";
-constexpr char kEnvRoleOfPServer[] = "MS_PSERVER";
-constexpr char kEnvRoleOfWorker[] = "MS_WORKER";
-constexpr char kEnvRoleOfScheduler[] = "MS_SCHED";
 constexpr char kDmlcCommType[] = "DMLC_PS_VAN_TYPE";
 constexpr char kDmlcInterface[] = "DMLC_INTERFACE";
 constexpr char kDmlcPServerNum[] = "DMLC_NUM_SERVER";
mindspore/ccsrc/frontend/parallel/ps/parameter_server.h
@@ -39,6 +39,7 @@
 #include "frontend/parallel/ps/optimizer_info.h"
 #include "frontend/parallel/ps/optimizer_info_builder.h"
 #include "frontend/parallel/ps/util.h"
+#include "frontend/parallel/ps/ps_context.h"
 #include "runtime/device/cpu/kernel_select_cpu.h"
 #include "utils/ms_context.h"
 #include "backend/kernel_compiler/kernel.h"
@@ -741,7 +742,7 @@ void ParameterServer<T>::Run(const FuncGraphPtr &func_graph) {
     return;
   }
   Init(func_graph);
-  Util::SetRankId(rank_id_);
+  PSContext::instance()->SetPSRankId(rank_id_);
   thread_->join();
   ::ps::Finalize(0, true);
 }
mindspore/ccsrc/frontend/parallel/ps/ps_context.cc
new file mode 100644
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "frontend/parallel/ps/ps_context.h"
#include "utils/log_adapter.h"
#include "utils/ms_utils.h"

namespace mindspore {
namespace parallel {
namespace ps {
std::shared_ptr<PSContext> PSContext::instance() {
  static std::shared_ptr<PSContext> ps_instance = nullptr;
  if (ps_instance == nullptr) {
    ps_instance.reset(new (std::nothrow) PSContext());
  }
  return ps_instance;
}

void PSContext::SetPSEnable(bool enabled) {
  ps_enabled_ = enabled;
  if (ps_enabled_) {
    std::string ms_role = common::GetEnv(kEnvRole);
    MS_LOG(INFO) << "PS mode is enabled. MS_ROLE is " << ms_role;
    if (ms_role == kEnvRoleOfWorker) {
      is_worker_ = true;
    } else if (ms_role == kEnvRoleOfPServer) {
      is_pserver_ = true;
    } else if (ms_role == kEnvRoleOfScheduler) {
      is_sched_ = true;
    } else {
      MS_LOG(WARNING) << "MS_ROLE is " << ms_role << ", which is invalid.";
    }
  } else {
    MS_LOG(INFO) << "PS mode is disabled.";
    is_worker_ = false;
    is_pserver_ = false;
    is_sched_ = false;
  }
}

bool PSContext::is_ps_enabled() const { return ps_enabled_; }

void PSContext::Reset() {
  ps_enabled_ = false;
  is_worker_ = false;
  is_pserver_ = false;
  is_sched_ = false;
}

std::string PSContext::ms_role() const {
  if (is_worker_) {
    return kEnvRoleOfWorker;
  } else if (is_pserver_) {
    return kEnvRoleOfPServer;
  } else if (is_sched_) {
    return kEnvRoleOfScheduler;
  } else {
    return kEnvRoleOfNotPS;
  }
}

bool PSContext::is_role_worker() const { return is_worker_; }

bool PSContext::is_role_pserver() const { return is_pserver_; }

bool PSContext::is_role_sched() const { return is_sched_; }

void PSContext::SetPSRankId(int rank_id) { rank_id_ = rank_id; }

int PSContext::ps_rank_id() const { return rank_id_; }
}  // namespace ps
}  // namespace parallel
}  // namespace mindspore
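SetPSEnable reads MS_ROLE through common::GetEnv at the moment PS mode is switched on and latches exactly one of is_worker_/is_pserver_/is_sched_. A hedged Python-side sketch of that behaviour, using the private helpers this commit adds in mindspore/parallel/_ps_context.py further down:

# Sketch only; _is_role_* are internal helpers introduced later in this diff.
import os
from mindspore import context
from mindspore.parallel._ps_context import _is_role_worker, _is_role_pserver, _is_role_sched

os.environ["MS_ROLE"] = "MS_WORKER"    # read when PS mode is enabled, not at import time
context.set_ps_context(enable_ps=True)
assert _is_role_worker() and not _is_role_pserver() and not _is_role_sched()
context.reset_ps_context()             # Reset() clears ps_enabled_ and all role flags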
mindspore/ccsrc/frontend/parallel/ps/ps_context.h
new file mode 100644
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_FRONTEND_PARALLEL_PS_CONTEXT_H_
#define MINDSPORE_CCSRC_FRONTEND_PARALLEL_PS_CONTEXT_H_
#include <string>
#include <memory>

namespace mindspore {
namespace parallel {
namespace ps {
constexpr char kEnvRole[] = "MS_ROLE";
constexpr char kEnvRoleOfPServer[] = "MS_PSERVER";
constexpr char kEnvRoleOfWorker[] = "MS_WORKER";
constexpr char kEnvRoleOfScheduler[] = "MS_SCHED";
constexpr char kEnvRoleOfNotPS[] = "MS_NOT_PS";

class PSContext {
 public:
  ~PSContext() = default;
  PSContext(PSContext const &) = delete;
  PSContext &operator=(const PSContext &) = delete;

  static std::shared_ptr<PSContext> instance();

  void SetPSEnable(bool enabled);
  bool is_ps_enabled() const;
  void Reset();
  std::string ms_role() const;
  bool is_role_worker() const;
  bool is_role_pserver() const;
  bool is_role_sched() const;
  void SetPSRankId(int rank_id);
  int ps_rank_id() const;

 private:
  PSContext() : ps_enabled_(false), is_worker_(false), is_pserver_(false), is_sched_(false), rank_id_(-1) {}
  bool ps_enabled_;
  bool is_worker_;
  bool is_pserver_;
  bool is_sched_;
  int rank_id_;
};
}  // namespace ps
}  // namespace parallel
}  // namespace mindspore
#endif  // MINDSPORE_CCSRC_FRONTEND_PARALLEL_PS_CONTEXT_H_
mindspore/ccsrc/frontend/parallel/ps/util.cc
@@ -16,7 +16,9 @@
 #include "frontend/parallel/ps/util.h"
 #include <unordered_map>
 #include <vector>
 #include "frontend/parallel/ps/common.h"
+#include "frontend/parallel/ps/ps_context.h"
 #include "utils/ms_utils.h"

 namespace mindspore {
@@ -45,34 +47,13 @@ std::unordered_map<int, std::string> Util::id_to_optimizer_nodes{
   {3, kSparseFtrlOp},
 };

-bool Util::IsParamServerMode() { return IsRoleOfWorker() || IsRoleOfPServer() || IsRoleOfScheduler(); }
+bool Util::IsParamServerMode() { return PSContext::instance()->is_ps_enabled(); }

-bool Util::IsRoleOfWorker() {
-  auto role = common::GetEnv(kEnvRole);
-  if (strcmp(role.c_str(), kEnvRoleOfWorker) == 0) {
-    return true;
-  } else {
-    return false;
-  }
-}
+bool Util::IsRoleOfWorker() { return PSContext::instance()->is_role_worker(); }

-bool Util::IsRoleOfPServer() {
-  auto role = common::GetEnv(kEnvRole);
-  if (strcmp(role.c_str(), kEnvRoleOfPServer) == 0) {
-    return true;
-  } else {
-    return false;
-  }
-}
+bool Util::IsRoleOfPServer() { return PSContext::instance()->is_role_pserver(); }

-bool Util::IsRoleOfScheduler() {
-  auto role = common::GetEnv(kEnvRole);
-  if (strcmp(role.c_str(), kEnvRoleOfScheduler) == 0) {
-    return true;
-  } else {
-    return false;
-  }
-}
+bool Util::IsRoleOfScheduler() { return PSContext::instance()->is_role_sched(); }

 void Util::SetInternalEnvVar() {
   if (IsParamServerMode()) {
@@ -163,10 +144,6 @@ std::map<int, int> Util::AllRankLocalShard(int first_dim, int rank_id, int serve
   return shard_dims;
 }

-void Util::SetRankId(int rank_id) { rank_id_ = rank_id; }
-
-int Util::GetRankId() { return rank_id_; }
-
 void Util::ReduceSparseGradient(float *gradients, int *indices, const size_t indices_size, size_t segment_size,
                                 const size_t first_dim_size, const size_t outer_dim_size,
                                 mindspore::kernel::SparseGradient<int> *unique_sparse_grad) {
mindspore/ccsrc/frontend/parallel/ps/util.h
@@ -40,8 +40,6 @@ class Util {
   static bool is_optimizer(std::string name);
   static int LocalShard(int first_dim, int rank_id, int server_num);
   static std::map<int, int> AllRankLocalShard(int first_dim, int rank_id, int server_num);
-  static void SetRankId(int rank_id);
-  static int GetRankId();
   static void ReduceSparseGradient(float *gradients, int *indices, const size_t indices_size, size_t segment_size,
                                    const size_t first_dim_size, const size_t outer_dim_size,
                                    mindspore::kernel::SparseGradient<int> *unique_sparse_grad);
mindspore/ccsrc/frontend/parallel/ps/worker_proxy.h
@@ -27,6 +27,7 @@
 #include "ps/ps.h"
 #include "frontend/parallel/ps/util.h"
 #include "backend/kernel_compiler/common_utils.h"
+#include "frontend/parallel/ps/ps_context.h"

 namespace mindspore {
 namespace parallel {
@@ -43,7 +44,7 @@ class WorkerProxy : public ::ps::KVWorker<T> {
   explicit WorkerProxy(int app_id, int customer_id, int lookup_customer_id, int general_customer_id)
       : Worker(app_id, customer_id) {
     server_num_ = ::ps::NumServers();
-    Util::SetRankId(::ps::MyRank());
+    PSContext::instance()->SetPSRankId(::ps::MyRank());
     using std::placeholders::_1;
     using std::placeholders::_2;
     using std::placeholders::_3;
mindspore/ccsrc/pipeline/jit/init.cc
@@ -36,6 +36,7 @@
 #if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU))
 #include "frontend/parallel/ps/util.h"
 #endif
+#include "frontend/parallel/ps/ps_context.h"

 namespace py = pybind11;

 using EnvInstance = mindspore::EnvInstance;
@@ -49,6 +50,7 @@ using OpInfoLoaderPy = mindspore::kernel::OpInfoLoaderPy;
 using ParallelContext = mindspore::parallel::ParallelContext;
 using CostModelContext = mindspore::parallel::CostModelContext;
 using mindspore::MsCtxParam;
+using PSContext = mindspore::parallel::ps::PSContext;

 // Interface with python
 PYBIND11_MODULE(_c_expression, m) {
@@ -276,9 +278,15 @@ PYBIND11_MODULE(_c_expression, m) {
               "Finalize gpu collective communication mode.");
 #endif
-#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU))
-  (void)m.def("get_ps_mode_rank", &mindspore::parallel::ps::Util::GetRankId, "Get Worker and PServer rank id.");
-#endif
+  (void)py::class_<PSContext, std::shared_ptr<PSContext>>(m, "PSContext")
+    .def_static("get_instance", &PSContext::instance, "Get PS context instance.")
+    .def("set_ps_enable", &PSContext::SetPSEnable, "Set PS mode enabled or disabled.")
+    .def("is_ps_enabled", &PSContext::is_ps_enabled, "Get PS mode enable-disable status.")
+    .def("reset", &PSContext::Reset, "Reset PS context attributes.")
+    .def("is_role_worker", &PSContext::is_role_worker, "Get whether the role of this process is Worker.")
+    .def("is_role_pserver", &PSContext::is_role_pserver, "Get whether the role of this process is PServer.")
+    .def("is_role_sched", &PSContext::is_role_sched, "Get whether the role of this process is Scheduler.")
+    .def("ps_rank_id", &PSContext::ps_rank_id, "Get Worker and PServer rank id.");
   (void)py::class_<OpInfoLoaderPy, std::shared_ptr<OpInfoLoaderPy>>(m, "OpInfoLoaderPy")
     .def(py::init())
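The binding above replaces the module-level get_ps_mode_rank function (which pointed at the now-removed Util::GetRankId) with methods on an exported PSContext class. A sketch of how the Python wrapper layer reaches it; this is internal API, shown only to illustrate the plumbing:

# Internal-API sketch; mirrors what mindspore/parallel/_ps_context.py does further down in this commit.
from mindspore._c_expression import PSContext

ctx = PSContext.get_instance()   # singleton registered via def_static("get_instance", ...)
ctx.set_ps_enable(True)          # the call that context.set_ps_context(enable_ps=True) ends up making
print(ctx.is_ps_enabled(), ctx.is_role_worker(), ctx.ps_rank_id())   # rank id stays -1 until SetPSRankId runs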
mindspore/common/api.py
@@ -15,7 +15,6 @@
 # limitations under the License.
 # ============================================================================
 """Providing interface methods."""
-import os
 import types
 from collections import OrderedDict
 from functools import wraps
@@ -25,6 +24,7 @@ from .._c_expression import generate_key, Executor_, Tensor, MetaTensor, Pynative
 from .._c_expression import verify_inputs_signature, init_exec_dataset, _set_dataset_mode_config, init_backend
 from .tensor import Tensor as MsTensor
 from ..parallel._utils import _get_device_num, _get_global_rank, _need_to_full, _to_full_tensor
+from ..parallel._ps_context import _is_role_pserver

 # store ms_function class compiled pipeline cache
 ms_compile_cache = {}
@@ -469,7 +469,7 @@
         return self._executor.has_compiled(phase)

     def __call__(self, obj, *args, phase='predict'):
-        if context.get_context("precompile_only") or os.getenv("MS_ROLE") == "MS_PSERVER":
+        if context.get_context("precompile_only") or _is_role_pserver():
             return None
         return self.run(obj, *args, phase=phase)
mindspore/common/parameter.py
@@ -22,6 +22,7 @@ from .tensor import Tensor, MetaTensor
 from .._checkparam import _check_str_by_regular
 from ..parallel._tensor import _get_slice_index
 from ..parallel._auto_parallel_context import auto_parallel_context
+from ..parallel._ps_context import _is_role_worker, _is_role_pserver, _is_role_sched

 __all__ = ['Parameter', 'ParameterTuple']
@@ -168,8 +169,13 @@ class Parameter(MetaTensor):
         """For parse check."""

     def set_param_ps(self, init_in_server=False):
-        self.is_param_ps = True
-        self.init_in_server = init_in_server
+        if _is_role_worker() or _is_role_pserver() or _is_role_sched():
+            self.is_param_ps = True
+            self.init_in_server = init_in_server
+        else:
+            raise RuntimeError("Must complete following two steps before calling set_param_ps: \
+                               1. set_ps_context(enable_ps=True) \
+                               2. export MS_ROLE environment variable.")

     @property
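set_param_ps now refuses to run unless the process already has a PS role, so the call order matters: enable the PS context (with MS_ROLE exported) first, then mark parameters. A hedged sketch with illustrative shapes and names:

# Sketch; the parameter name and shape are placeholders, and MS_ROLE must be set in the environment.
import numpy as np
from mindspore import context, Tensor, Parameter
from mindspore import dtype as mstype

context.set_ps_context(enable_ps=True)     # must come first, otherwise set_param_ps raises RuntimeError
weight = Parameter(Tensor(np.zeros((16, 8)), mstype.float32), name="fc_weight")
weight.set_param_ps()                      # now allowed; the parameter will live on the parameter server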
mindspore/communication/_comm_helper.py
@@ -14,7 +14,7 @@
 # ============================================================================
 """comm_helper"""

-import os
+from mindspore.parallel._ps_context import _is_role_pserver, _is_role_sched
 from ._hccl_management import load_lib as hccl_load_lib

 _HCCL_AVAILABLE = False
@@ -44,7 +44,6 @@ else:
 HCCL_WORLD_COMM_GROUP = "hccl_world_group"
 NCCL_WORLD_COMM_GROUP = "nccl_world_group"
-MS_ROLE = os.getenv("MS_ROLE")

 class Backend:
     """
@@ -113,7 +112,7 @@ def check_parameter_available(func):
         Wrapper. If not available, raise Error.
     """
     def wrapper(*args, **kargs):
-        if MS_ROLE in ("MS_PSERVER", "MS_SCHED"):
+        if _is_role_pserver() or _is_role_sched():
             return func(*args, **kargs)
         group = None
         if "group" in kargs.keys():
@@ -154,7 +153,7 @@ def _get_rank_helper(group, backend):
         Integer. The local rank id of the calling process.
     """
     rank_id = None
-    if MS_ROLE in ("MS_PSERVER", "MS_SCHED"):
+    if _is_role_pserver() or _is_role_sched():
         rank_id = 0
         return rank_id
     if backend == Backend.HCCL:
@@ -213,7 +212,7 @@ def _get_size_helper(group, backend):
         Integer. The rank size of specified group.
     """
     size = None
-    if MS_ROLE in ("MS_PSERVER", "MS_SCHED"):
+    if _is_role_pserver() or _is_role_sched():
         size = 1
         return size
     if backend == Backend.HCCL:
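Because these helpers now consult the PS context rather than a module-level MS_ROLE constant, a PServer or Scheduler process gets placeholder rank/size values without ever initializing HCCL or NCCL. A hedged sketch of what that looks like on such a process:

# Sketch; assumes this process was launched with MS_ROLE=MS_PSERVER (or MS_SCHED) and PS mode enabled.
from mindspore import context
from mindspore.communication.management import init, get_rank, get_group_size

context.set_ps_context(enable_ps=True)
init()                     # returns immediately on PServer/Scheduler after this change
print(get_rank())          # 0, from _get_rank_helper's early return
print(get_group_size())    # 1, from _get_size_helper's early return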
mindspore/communication/management.py
@@ -13,8 +13,8 @@
 # limitations under the License.
 # ============================================================================
 """Communication management API"""
-import os
 from mindspore import context
+from mindspore.parallel._ps_context import _is_role_pserver, _is_role_sched
 from ._comm_helper import Backend, _get_rank_helper, _get_size_helper, \
     _get_world_rank_from_group_rank_helper, _get_group_rank_from_world_rank_helper, \
     _create_group_helper, _destroy_group_helper, HCCL_WORLD_COMM_GROUP, NCCL_WORLD_COMM_GROUP, \
@@ -29,7 +29,6 @@ __all__ = ["init", "release", "get_rank", "get_local_rank", "get_group_size",
 DEFAULT_WORLD_COMM_GROUP = HCCL_WORLD_COMM_GROUP
 DEFAULT_BACKEND = Backend("hccl")
-MS_ROLE = os.getenv("MS_ROLE")

 def _get_group(group):
@@ -61,7 +60,7 @@ def init(backend_name=None):
         RuntimeError: If device target is invalid.
         RuntimeError: If backend is invalid or distributed init fails.
     """
-    if MS_ROLE in ("MS_PSERVER", "MS_SCHED"):
+    if _is_role_pserver() or _is_role_sched():
         return
     if backend_name is None:
         device_target = context.get_context("device_target")
mindspore/context.py
@@ -26,9 +26,11 @@ from mindspore._c_expression import MSContext, ms_ctx_param
 from mindspore._checkparam import args_type_check
 from mindspore.parallel._auto_parallel_context import _set_auto_parallel_context, _get_auto_parallel_context, \
     _reset_auto_parallel_context
+from mindspore.parallel._ps_context import _set_ps_context, _get_ps_context, _reset_ps_context

 __all__ = ['GRAPH_MODE', 'PYNATIVE_MODE', 'set_context', 'get_context', 'set_auto_parallel_context',
-           'get_auto_parallel_context', 'reset_auto_parallel_context', 'ParallelMode']
+           'get_auto_parallel_context', 'reset_auto_parallel_context', 'ParallelMode', 'set_ps_context',
+           'get_ps_context', 'reset_ps_context']

 GRAPH_MODE = 0
 PYNATIVE_MODE = 1
@@ -569,3 +571,58 @@ class ParallelMode:
     SEMI_AUTO_PARALLEL = "semi_auto_parallel"
     AUTO_PARALLEL = "auto_parallel"
     MODE_LIST = [STAND_ALONE, DATA_PARALLEL, HYBRID_PARALLEL, SEMI_AUTO_PARALLEL, AUTO_PARALLEL]
+
+
+@args_type_check(enable_ps=bool)
+def set_ps_context(**kwargs):
+    """
+    Set parameter server training mode context.
+
+    Note:
+        Some other environment variables should also be set for parameter server training mode.
+        These environment variables are listed below:
+
+        MS_SERVER_NUM  # Server number
+        MS_WORKER_NUM  # Worker number
+        MS_SCHED_HOST  # Scheduler IP address
+        MS_SCHED_PORT  # Scheduler port
+        MS_ROLE        # The role of this process:
+                       # MS_SCHED represents the scheduler,
+                       # MS_WORKER represents the worker,
+                       # MS_PSERVER represents the Server
+
+    Args:
+        enable_ps (bool): Whether to enable parameter server training mode.
+            Only after enable_ps is set True, the environment variables will be effective.
+            Default: False.
+
+    Raises:
+        ValueError: If input key is not the attribute in parameter server training mode context.
+
+    Examples:
+        >>> context.set_ps_context(enable_ps=True)
+    """
+    _set_ps_context(**kwargs)
+
+
+def get_ps_context(attr_key):
+    """
+    Get parameter server training mode context attribute value according to the key.
+
+    Args:
+        attr_key (str): The key of the attribute.
+
+    Returns:
+        Returns attribute value according to the key.
+
+    Raises:
+        ValueError: If input key is not attribute in auto parallel context.
+    """
+    return _get_ps_context(attr_key)
+
+
+def reset_ps_context():
+    """
+    Reset parameter server training mode context attributes to the default values:
+
+    - enable_ps: False.
+    """
+    _reset_ps_context()
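The Note section above lists the environment variables that must accompany enable_ps=True. A hedged sketch of a single-machine, single-worker launch environment (host, port, and counts are placeholder values; normally the launch script exports them before Python starts):

# Placeholder values for illustration only.
import os
from mindspore import context

os.environ["MS_SERVER_NUM"] = "1"
os.environ["MS_WORKER_NUM"] = "1"
os.environ["MS_SCHED_HOST"] = "127.0.0.1"
os.environ["MS_SCHED_PORT"] = "8081"
os.environ["MS_ROLE"] = "MS_WORKER"      # MS_PSERVER / MS_SCHED in the other two processes

context.set_ps_context(enable_ps=True)   # only now do the variables above take effect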
mindspore/parallel/_ps_context.py
new file mode 100644
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Context for parameter server training mode"""
from mindspore._c_expression import PSContext

_ps_context = None


def ps_context():
    """
    Get the global _ps_context, if it is not created, create a new one.

    Returns:
        _ps_context, the global parameter server training mode context.
    """
    global _ps_context
    if _ps_context is None:
        _ps_context = PSContext.get_instance()
    return _ps_context

_set_ps_context_func_map = {
    "enable_ps": ps_context().set_ps_enable
}

_get_ps_context_func_map = {
    "enable_ps": ps_context().is_ps_enabled
}


def _get_ps_mode_rank():
    ps_rank = ps_context().ps_rank_id()
    if ps_rank == -1:
        raise RuntimeError("The parameter server mode training is not enabled yet.")
    return ps_rank


def _set_ps_context(**kwargs):
    """
    Set parameter server training mode context.

    Note:
        Some other environment variables should also be set for parameter server training mode.
        These environment variables are listed below:

        MS_SERVER_NUM  # Server number
        MS_WORKER_NUM  # Worker number
        MS_SCHED_HOST  # Scheduler IP address
        MS_SCHED_PORT  # Scheduler port
        MS_ROLE        # The role of this process:
                       # MS_SCHED represents the scheduler,
                       # MS_WORKER represents the worker,
                       # MS_PSERVER represents the Server

    Args:
        enable_ps (bool): Whether to enable parameter server training mode.
            Only after enable_ps is set True, the environment variables will be effective.
            Default: False.

    Raises:
        ValueError: If input key is not the attribute in parameter server training mode context.

    Examples:
        >>> context.set_ps_context(enable_ps=True)
    """
    for key, value in kwargs.items():
        if key not in _set_ps_context_func_map:
            raise ValueError("Set PS context keyword %s is not recognized!" % key)
        set_func = _set_ps_context_func_map[key]
        set_func(value)


def _get_ps_context(attr_key):
    """
    Get parameter server training mode context attribute value according to the key.

    Args:
        attr_key (str): The key of the attribute.

    Returns:
        Returns attribute value according to the key.

    Raises:
        ValueError: If input key is not attribute in auto parallel context.
    """
    if key not in _get_ps_context_func_map:
        raise ValueError("Get PS context keyword %s is not recognized!" % key)
    get_func = _get_ps_context_func_map[attr_key]
    get_func(attr_key)


def _reset_ps_context():
    """
    Reset parameter server training mode context attributes to the default values:

    - enable_ps: False.
    """
    ps_context().reset()


def _is_role_worker():
    return ps_context().is_role_worker()


def _is_role_pserver():
    return ps_context().is_role_pserver()


def _is_role_sched():
    return ps_context().is_role_sched()
mindspore/parallel/_ps_utils.py
deleted file mode 100644
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Utils for parameter server training mode"""
from mindspore._c_expression import get_ps_mode_rank


def _get_ps_mode_rank():
    ps_rank = get_ps_mode_rank()
    if ps_rank == -1:
        raise RuntimeError("The parameter server mode training is not launched yet.")
    return ps_rank
mindspore/train/callback/_checkpoint.py
@@ -24,6 +24,7 @@ from mindspore import log as logger
 from mindspore._checkparam import check_bool, check_int_non_negative
 from mindspore.train._utils import _make_directory
 from mindspore.train.serialization import save_checkpoint, _save_graph
+from mindspore.parallel._ps_context import _is_role_pserver, _get_ps_mode_rank
 from ._callback import Callback, set_cur_net
@@ -280,8 +281,7 @@ class ModelCheckpoint(Callback):
             if save_ckpt:
                 cur_ckpoint_file = self._prefix + "-" + str(cb_params.cur_epoch_num) + "_" \
                     + str(step_num_in_epoch) + ".ckpt"
-                if os.getenv("MS_ROLE") == "MS_PSERVER":
-                    from mindspore.parallel._ps_utils import _get_ps_mode_rank
+                if _is_role_pserver():
                     cur_ckpoint_file = "PServer_" + str(_get_ps_mode_rank()) + "_" + cur_ckpoint_file
                 # update checkpoint file list.
                 self._manager.update_ckpoint_filelist(self._directory, self._prefix)
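On a parameter-server process the checkpoint name is now prefixed with the PS rank, so checkpoints written by different servers cannot collide. A small sketch of the resulting file name (values are illustrative):

# Mirrors the string composition in ModelCheckpoint above; plain Python, no MindSpore needed.
prefix, epoch, step, ps_rank = "lenet", 1, 1875, 0
ckpt = prefix + "-" + str(epoch) + "_" + str(step) + ".ckpt"
ckpt = "PServer_" + str(ps_rank) + "_" + ckpt
print(ckpt)   # PServer_0_lenet-1_1875.ckpt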
mindspore/train/model.py
@@ -27,6 +27,7 @@ from .callback import _InternalCallbackParam, RunContext, _CallbackManager
 from .. import context
 from ..parallel._utils import _get_parallel_mode, _get_device_num, _get_global_rank, \
     _get_parameter_broadcast, _device_number_check, _parameter_broadcast_check
+from ..parallel._ps_context import _is_role_pserver, _is_role_sched
 from ..nn.metrics import Loss
 from .. import nn
 from ..nn.wrap.cell_wrapper import _VirtualDatasetCell
@@ -378,8 +379,7 @@ class Model:
         cb_params.list_callback = self._transform_callbacks(callbacks)
         cb_params.train_dataset_element = None
         cb_params.network = self._network
-        ms_role = os.getenv("MS_ROLE")
-        if ms_role in ("MS_PSERVER", "MS_SCHED"):
+        if _is_role_pserver() or _is_role_sched():
             epoch = 1
         # build callback list
@@ -516,7 +516,7 @@ class Model:
                 self._loss_scale_manager.update_loss_scale(overflow)

                 list_callback.step_end(run_context)
-                if os.getenv("MS_ROLE") == "MS_PSERVER":
+                if _is_role_pserver():
                     os._exit(0)
                 should_stop = should_stop or run_context.get_stop_requested()
                 if should_stop:
model_zoo/official/cv/resnet/train.py
@@ -70,6 +70,7 @@ if __name__ == '__main__':
     # init context
     context.set_context(mode=context.GRAPH_MODE, device_target=target, save_graphs=False)
+    context.set_ps_context(enable_ps=True)
     if args_opt.run_distribute:
         if target == "Ascend":
             device_id = int(os.getenv('DEVICE_ID'))
model_zoo/official/nlp/bert_thor/src/model_thor.py
@@ -14,7 +14,6 @@
 # ============================================================================
 """Model."""
 import math
-import os
 from collections.abc import Iterable

 import numpy as np
@@ -405,9 +404,6 @@ class Model:
         cb_params.list_callback = self._transform_callbacks(callbacks)
         cb_params.train_dataset_element = None
         cb_params.network = self._network
-        ms_role = os.getenv("MS_ROLE")
-        if ms_role in ("MS_PSERVER", "MS_SCHED"):
-            epoch = 1
         # build callback list
         with _CallbackManager(callbacks) as list_callback:
model_zoo/official/recommend/wide_and_deep/train_and_eval_parameter_server.py
@@ -118,6 +118,7 @@ if __name__ == "__main__":
     wide_deep_config.argparse_init()

     context.set_context(mode=context.GRAPH_MODE, device_target=wide_deep_config.device_target)
+    context.set_ps_context(enable_ps=True)
     init()
     context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL, gradients_mean=True,
                                       device_num=get_group_size())
tests/st/ps/cmp_sparse_embedding/test_cmp_sparse_embedding.py
@@ -26,6 +26,7 @@ from mindspore.nn import TrainOneStepCell, WithLossCell
 from mindspore.nn.optim import Adam
 from mindspore.ops import operations as P
 from mindspore.common.initializer import TruncatedNormal
+from mindspore.parallel._ps_context import _is_role_pserver, _is_role_worker

 parser = argparse.ArgumentParser(description="test_sparse_embedding")
 parser.add_argument("--device_target", type=str, default="Ascend")
@@ -34,6 +35,7 @@ device_target = args.device_target
 context.set_context(mode=context.GRAPH_MODE, device_target=device_target, enable_sparse=True)
+context.set_ps_context(enable_ps=True)

 def fc_with_initialize(input_channels, out_channels):
@@ -81,7 +83,7 @@ def do_sparse_embedding(ps=False):
     for _ in range(epoch):
         data = Tensor(np.random.randint(0, 15, (32, 3), np.int32))
         label = Tensor(np.random.randint(0, 9, (32), np.int32))
-        if envs.get("MS_ROLE") == "MS_PSERVER":
+        if _is_role_pserver():
             train_network(data, label)
             sys.exit()
         else:
@@ -96,10 +98,10 @@ if __name__ == "__main__":
     np.random.seed(0)
     ps_loss = do_sparse_embedding(True)

-    if envs.get("MS_ROLE") == "MS_WORKER":
-        envs["MS_ROLE"] = ""
+    if _is_role_worker():
+        context.reset_ps_context()
         np.random.seed(0)
         no_ps_loss = do_sparse_embedding()
-        envs["MS_ROLE"] = "MS_WORKER"
+        context.set_ps_context(enable_ps=True)

         assert np.allclose(ps_loss, no_ps_loss, rtol=1.0e-6, atol=1.0e-6)
tests/st/ps/full_ps/test_full_ps_lenet.py
@@ -35,6 +35,7 @@ args, _ = parser.parse_known_args()
 device_target = args.device_target
 dataset_path = args.dataset_path
 context.set_context(mode=context.GRAPH_MODE, device_target=device_target)
+context.set_ps_context(enable_ps=True)

 def conv(in_channels, out_channels, kernel_size, stride=1, padding=0):
     """weight initial for conv layer"""
tests/st/ps/multi_full_ps/test_multi_full_ps.py
@@ -13,6 +13,7 @@
 # limitations under the License.
 # ============================================================================
+import sys
 import argparse
 import numpy as np
@@ -22,6 +23,7 @@ from mindspore.common.initializer import TruncatedNormal
 from mindspore import Tensor
 from mindspore.nn import TrainOneStepCell, WithLossCell
 from mindspore.communication.management import init, get_group_size
+from mindspore.parallel._ps_context import _is_role_pserver
 # from resnet import resnet50

 parser = argparse.ArgumentParser(description="test_ps_lenet")
@@ -29,6 +31,7 @@ parser.add_argument("--device_target", type=str, default="Ascend")
 args, _ = parser.parse_known_args()
 device_target = args.device_target
 context.set_context(mode=context.GRAPH_MODE, device_target=device_target)
+context.set_ps_context(enable_ps=True)
 if device_target == "GPU":
     init()
@@ -106,6 +109,10 @@
     for _ in range(epoch):
         data = Tensor(np.random.rand(32, 3, 32, 32).astype(np.float32))
         label = Tensor(np.random.randint(0, 9, (32)).astype(np.int32))
-        loss = train_network(data, label).asnumpy()
-        losses.append(loss)
+        if _is_role_pserver():
+            train_network(data, label)
+            sys.exit()
+        else:
+            loss = train_network(data, label).asnumpy()
+            losses.append(loss)
     print(losses)