Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
magicwindyyd
mindspore
提交
3618b084
M
mindspore
项目概览
magicwindyyd
/
mindspore
与 Fork 源项目一致
Fork自
MindSpore / mindspore
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
mindspore
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
3618b084
编写于
7月 13, 2020
作者:
Z
ZPaC
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Adaptation for ps mode.
上级
49da4e79
变更
7
隐藏空白更改
内联
并排
Showing
7 changed files
with
155 additions
and
24 deletions
+155
-24
mindspore/ccsrc/kernel/cpu/cpu_kernel.h
mindspore/ccsrc/kernel/cpu/cpu_kernel.h
+1
-1
mindspore/ccsrc/kernel/cpu/cpu_kernel_factory.h
mindspore/ccsrc/kernel/cpu/cpu_kernel_factory.h
+5
-3
mindspore/ccsrc/kernel/cpu/ps/sparse_apply_ftrl_ps_kernel.cc
mindspore/ccsrc/kernel/cpu/ps/sparse_apply_ftrl_ps_kernel.cc
+4
-18
mindspore/ccsrc/pre_activate/pass/replace_node_by_proxy.cc
mindspore/ccsrc/pre_activate/pass/replace_node_by_proxy.cc
+92
-0
mindspore/ccsrc/pre_activate/pass/replace_node_by_proxy.h
mindspore/ccsrc/pre_activate/pass/replace_node_by_proxy.h
+41
-0
mindspore/communication/_comm_helper.py
mindspore/communication/_comm_helper.py
+8
-2
mindspore/communication/management.py
mindspore/communication/management.py
+4
-0
未找到文件。
mindspore/ccsrc/kernel/cpu/cpu_kernel.h
浏览文件 @
3618b084
...
...
@@ -55,7 +55,7 @@ class CPUKernel : public kernel::KernelMod {
public:
CPUKernel
()
=
default
;
~
CPUKernel
()
override
=
default
;
void
Init
(
const
CNodePtr
&
kernel_node
);
v
irtual
v
oid
Init
(
const
CNodePtr
&
kernel_node
);
virtual
void
InitKernel
(
const
CNodePtr
&
kernel_node
)
=
0
;
bool
Launch
(
const
std
::
vector
<
AddressPtr
>
&
inputs
,
const
std
::
vector
<
AddressPtr
>
&
workspace
,
const
std
::
vector
<
AddressPtr
>
&
outputs
,
void
*
/*stream_ptr*/
)
override
{
...
...
mindspore/ccsrc/kernel/cpu/cpu_kernel_factory.h
浏览文件 @
3618b084
...
...
@@ -62,10 +62,12 @@ class CPUKernelRegistrar {
static const CPUKernelRegistrar g_cpu_kernel_##COUNT##_reg(#OPNAME, ATTR, \
[]() { return std::make_shared<OPCLASS>(); });
#define MS_REG_CPU_KERNEL_T(OPNAME, ATTR, OPCLASS, T) \
#define MS_REG_CPU_KERNEL_T(OPNAME, ATTR, OPCLASS, T) MS_REG_CPU_KERNEL_T_(__COUNTER__, OPNAME, ATTR, OPCLASS, T)
#define MS_REG_CPU_KERNEL_T_(COUNT, OPNAME, ATTR, OPCLASS, T) _MS_REG_CPU_KERNEL_T_(COUNT, OPNAME, ATTR, OPCLASS, T)
#define _MS_REG_CPU_KERNEL_T_(COUNT, OPNAME, ATTR, OPCLASS, T) \
static_assert(std::is_base_of<CPUKernel, OPCLASS<T>>::value, " must be base of CPUKernel"); \
static const CPUKernelRegistrar g_cpu_kernel_##
OPNAME##_##T##_reg(#OPNAME, ATTR,
\
[]() { return std::make_shared<OPCLASS<T>>(); });
static const CPUKernelRegistrar g_cpu_kernel_##
COUNT##_##OPNAME##_##T##_reg(
\
#OPNAME, ATTR,
[]() { return std::make_shared<OPCLASS<T>>(); });
#define MS_REG_CPU_KERNEL_T_S(OPNAME, ATTR, OPCLASS, T, S) \
static_assert(std::is_base_of<CPUKernel, OPCLASS<T, S>>::value, " must be base of CPUKernel"); \
...
...
mindspore/ccsrc/kernel/cpu/ps/sparse_apply_ftrl_ps_kernel.cc
浏览文件 @
3618b084
...
...
@@ -46,24 +46,10 @@ void SparseApplyFtrlPSKernel::InitKernel(
if
(
grad_shape
[
0
]
!=
indices_size_
)
{
MS_LOG
(
EXCEPTION
)
<<
"The first dimension of grad shape must be equal to indices"
;
}
/*
lr_ = AnfAlgo::GetNodeAttr<float>(kernel_node, "lr");
if (lr_ <= 0) {
MS_LOG(EXCEPTION) << "lr should be a positive scalar";
}
l1_ = AnfAlgo::GetNodeAttr<float>(kernel_node, "l1");
if (l1_ < 0) {
MS_LOG(EXCEPTION) << "l1 should be a non-negative scalar";
}
l2_ = AnfAlgo::GetNodeAttr<float>(kernel_node, "l2");
if (l2_ < 0) {
MS_LOG(EXCEPTION) << "l2 should be a non-negative scalar";
}
lr_power_ = AnfAlgo::GetNodeAttr<float>(kernel_node, "lr_power");
if (lr_power_ > 0) {
MS_LOG(EXCEPTION) << "lr_power should be a non-positive scalar";
}
*/
lr_
=
0.01
;
l1_
=
1e-8
;
l2_
=
1e-8
;
lr_power_
=
-
0.5
;
workspace_size_list_
.
emplace_back
(
indices_size_
*
var_outer_dim_size_
*
sizeof
(
float
));
workspace_size_list_
.
emplace_back
(
indices_size_
*
sizeof
(
int
));
}
...
...
mindspore/ccsrc/pre_activate/pass/replace_node_by_proxy.cc
0 → 100644
浏览文件 @
3618b084
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "pre_activate/pass/replace_node_by_proxy.h"
#include <vector>
#include <memory>
#include "device/kernel_info.h"
#include "session/anf_runtime_algorithm.h"
#include "kernel/kernel_build_info.h"
namespace
mindspore
{
namespace
opt
{
// Builds a KernelBuildInfo that mirrors the kernel selection already attached to `cnode`.
//
// The input/output device formats and device data types of `cnode`, together with its fusion
// type, processor and kernel type, are copied into a fresh KernelBuildInfoBuilder.
//
// @param cnode  Node whose kernel build information is read; checked with MS_EXCEPTION_IF_NULL.
// @return       The result of KernelBuildInfoBuilder::Build().
kernel::KernelBuildInfoPtr ReplaceNodeByProxy::GenerateKernelBuildInfo(const CNodePtr &cnode) {
  MS_EXCEPTION_IF_NULL(cnode);
  std::vector<std::string> inputs_device_format;
  std::vector<std::string> outputs_device_format;
  std::vector<TypeId> inputs_device_type;
  std::vector<TypeId> outputs_device_type;
  // NOTE(review): outputs_shape is filled below but never handed to the builder in this function.
  // The GetOutputInferShape call is kept so any checks it performs still run; confirm whether
  // both can be dropped.
  std::vector<std::vector<size_t>> outputs_shape;
  kernel::KernelBuildInfo::KernelBuildInfoBuilder builder;

  // `cnode` is only read from in the loops, so each tensor count is queried once instead of on
  // every iteration.
  const size_t input_num = AnfAlgo::GetInputTensorNum(cnode);
  inputs_device_format.reserve(input_num);
  inputs_device_type.reserve(input_num);
  for (size_t input_index = 0; input_index < input_num; ++input_index) {
    inputs_device_format.push_back(AnfAlgo::GetInputFormat(cnode, input_index));
    inputs_device_type.push_back(AnfAlgo::GetInputDeviceDataType(cnode, input_index));
  }

  const size_t output_num = AnfAlgo::GetOutputTensorNum(cnode);
  outputs_device_format.reserve(output_num);
  outputs_device_type.reserve(output_num);
  outputs_shape.reserve(output_num);
  for (size_t output_index = 0; output_index < output_num; ++output_index) {
    outputs_device_format.push_back(AnfAlgo::GetOutputFormat(cnode, output_index));
    outputs_device_type.push_back(AnfAlgo::GetOutputDeviceDataType(cnode, output_index));
    outputs_shape.push_back(AnfAlgo::GetOutputInferShape(cnode, output_index));
  }

  builder.SetFusionType(AnfAlgo::GetFusionType(cnode));
  builder.SetProcessor(AnfAlgo::GetProcessor(cnode));
  builder.SetKernelType(AnfAlgo::GetKernelType(cnode));
  builder.SetInputsFormat(inputs_device_format);
  builder.SetOutputsFormat(outputs_device_format);
  builder.SetInputsDeviceType(inputs_device_type);
  builder.SetOutputsDeviceType(outputs_device_type);
  return builder.Build();
}
// Replaces every CNode named kEmbeddingLookupOpName in `func_graph` with a proxy CNode whose
// primitive is kEmbeddingLookupProxyOpName.
//
// For each matching node the proxy:
//   - receives the original node's inputs except the first one (the original primitive),
//   - gets a fresh device::KernelInfo,
//   - has the kAttrPsKey, "reduce_scatter_flag" and "offset" attributes copied from the original,
//   - has its abstract set to a one-element AbstractTuple wrapping the original node's abstract,
//   - gets a kernel build info generated from the original node.
// The graph manager is then asked to substitute the proxy for the original node.
//
// @param func_graph  Graph to rewrite; it and its manager are checked with MS_EXCEPTION_IF_NULL.
// @return            Always true; a failed replacement is reported through MS_LOG(EXCEPTION).
bool ReplaceNodeByProxy::Run(const FuncGraphPtr &func_graph) {
  MS_EXCEPTION_IF_NULL(func_graph);
  auto manager = func_graph->manager();
  MS_EXCEPTION_IF_NULL(manager);
  std::vector<AnfNodePtr> node_list = TopoSort(func_graph->get_return());
  // Iterate by const reference: node_list is a local snapshot that is not modified in the loop,
  // so there is no need to copy (and ref-count) each shared pointer.
  for (const auto &node : node_list) {
    if (node == nullptr || !node->isa<CNode>() || AnfAlgo::GetCNodeName(node) != kEmbeddingLookupOpName) {
      continue;
    }
    CNodePtr cnode = node->cast<CNodePtr>();
    auto prim = std::make_shared<Primitive>(kEmbeddingLookupProxyOpName);
    MS_EXCEPTION_IF_NULL(prim);

    // The proxy's inputs are the new primitive followed by the original inputs after index 0.
    std::vector<AnfNodePtr> proxy_inputs = {NewValueNode(prim)};
    const auto &origin_inputs = cnode->inputs();
    proxy_inputs.insert(proxy_inputs.end(), origin_inputs.begin() + 1, origin_inputs.end());
    AnfNodePtr proxy_node = func_graph->NewCNode(proxy_inputs);
    MS_EXCEPTION_IF_NULL(proxy_node);

    auto kernel_info = std::make_shared<device::KernelInfo>();
    MS_EXCEPTION_IF_NULL(kernel_info);
    proxy_node->set_kernel_info(kernel_info);

    AbstractBasePtrList abstract_list;
    AnfAlgo::CopyNodeAttr(kAttrPsKey, cnode, proxy_node);
    AnfAlgo::CopyNodeAttr("reduce_scatter_flag", cnode, proxy_node);
    AnfAlgo::CopyNodeAttr("offset", cnode, proxy_node);
    abstract_list.push_back(cnode->abstract());
    auto abstract_tuple = std::make_shared<abstract::AbstractTuple>(abstract_list);
    MS_EXCEPTION_IF_NULL(abstract_tuple);
    proxy_node->set_abstract(abstract_tuple);

    auto kernel_build_info = GenerateKernelBuildInfo(cnode);
    AnfAlgo::SetSelectKernelBuildInfo(kernel_build_info, proxy_node.get());

    if (!manager->Replace(cnode, proxy_node)) {
      MS_LOG(EXCEPTION) << "Replace node by proxy node failed.";
    }
  }
  return true;
}
}
// namespace opt
}
// namespace mindspore
mindspore/ccsrc/pre_activate/pass/replace_node_by_proxy.h
0 → 100644
浏览文件 @
3618b084
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_PASS_REPLACE_NODE_BY_PROXY_H_
#define MINDSPORE_CCSRC_PRE_ACTIVATE_PASS_REPLACE_NODE_BY_PROXY_H_
#include <utility>
#include <vector>
#include <string>
#include "pre_activate/common/pass.h"
#include "ir/func_graph.h"
#include "ir/anf.h"
#include "utils/utils.h"
#include "kernel/kernel_build_info.h"
namespace
mindspore
{
namespace
opt
{
// Graph pass that swaps nodes for proxy nodes.
//
// Run() is the Pass entry point; GenerateKernelBuildInfo() is a private helper that produces the
// kernel build info for a given CNode.
class ReplaceNodeByProxy : public Pass {
 public:
  // Forwards the pass name to the Pass base class.
  explicit ReplaceNodeByProxy(const std::string &name) : Pass(name) {}

  ~ReplaceNodeByProxy() override = default;

  // Applies the pass to `graph` and returns the pass result.
  bool Run(const FuncGraphPtr &graph) override;

 private:
  // Produces a KernelBuildInfo from the kernel information of `cnode`.
  kernel::KernelBuildInfoPtr GenerateKernelBuildInfo(const CNodePtr &cnode);
};
}
// namespace opt
}
// namespace mindspore
#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_PASS_REPLACE_NODE_BY_PROXY_H_
mindspore/communication/_comm_helper.py
浏览文件 @
3618b084
...
...
@@ -14,7 +14,7 @@
# ============================================================================
"""comm_helper"""
import
os
from
._hccl_management
import
load_lib
as
hccl_load_lib
_HCCL_AVAILABLE
=
False
...
...
@@ -44,7 +44,7 @@ else:
HCCL_WORLD_COMM_GROUP
=
"hccl_world_group"
NCCL_WORLD_COMM_GROUP
=
"nccl_world_group"
MS_ROLE
=
os
.
getenv
(
"MS_ROLE"
)
class
Backend
:
"""
...
...
@@ -152,6 +152,9 @@ def _get_rank_helper(group, backend):
Integer. The local rank id of the calling process.
"""
rank_id
=
None
if
MS_ROLE
in
(
"MS_PSERVER"
,
"MS_SCHED"
):
rank_id
=
0
return
rank_id
if
backend
==
Backend
.
HCCL
:
if
group
==
HCCL_WORLD_COMM_GROUP
:
rank_id
=
hccl
.
get_rank_id
()
...
...
@@ -211,6 +214,9 @@ def _get_size_helper(group, backend):
Integer. The rank size of specified group.
"""
size
=
None
if
MS_ROLE
in
(
"MS_PSERVER"
,
"MS_SCHED"
):
size
=
1
return
size
if
backend
==
Backend
.
HCCL
:
if
group
==
HCCL_WORLD_COMM_GROUP
:
size
=
hccl
.
get_rank_size
()
...
...
mindspore/communication/management.py
浏览文件 @
3618b084
...
...
@@ -13,6 +13,7 @@
# limitations under the License.
# ============================================================================
"""Communication management API"""
import
os
from
mindspore.parallel._auto_parallel_context
import
auto_parallel_context
from
._comm_helper
import
Backend
,
_get_rank_helper
,
_get_size_helper
,
\
_get_world_rank_from_group_rank_helper
,
_get_group_rank_from_world_rank_helper
,
\
...
...
@@ -28,6 +29,7 @@ __all__ = ["init", "release", "get_rank", "get_local_rank", "get_group_size",
DEFAULT_WORLD_COMM_GROUP
=
HCCL_WORLD_COMM_GROUP
DEFAULT_BACKEND
=
Backend
(
"hccl"
)
MS_ROLE
=
os
.
getenv
(
"MS_ROLE"
)
def
_get_group
(
group
):
...
...
@@ -58,6 +60,8 @@ def init(backend_name="hccl"):
TypeError: If backend name is not a string.
RuntimeError: If backend is invalid or distributed init fails.
"""
if
MS_ROLE
in
(
"MS_PSERVER"
,
"MS_SCHED"
):
return
if
not
isinstance
(
backend_name
,
str
):
raise
TypeError
(
"Backend name must be a string, but got {}"
.
format
(
type
(
backend_name
)))
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录