Commit 5240b1f6
magicwindyyd / mindspore (forked from MindSpore / mindspore)
Authored on Apr 02, 2020 by lichenever
Parent: a44b5293

fix refkey bug for auto parallel

Showing 2 changed files with 50 additions and 7 deletions (+50 -7)

mindspore/ccsrc/parallel/step_parallel.cc    +19 -2
tests/ut/python/parallel/test_arithmetic.py  +31 -5
mindspore/ccsrc/parallel/step_parallel.cc
@@ -49,6 +49,9 @@ namespace mindspore {
 namespace parallel {
 const std::set<std::string> COMMUNICATION_OPS = {ALL_REDUCE, ALL_GATHER, ALL_TO_ALL, REDUCE_SCATTER};
 const std::set<std::string> INVALID_LOSS_OPS = {GET_NEXT, VIRTUALLOSS};
+// g_RefMap, for CNode B input i is a RefKey[Parameter C],
+// it will be one item in map with key: C, and value: (B, i)
+static std::map<AnfNodePtr, std::pair<AnfNodePtr, int>> g_RefMap;

 void SetCommunicationOpGroupLabel(std::vector<AnfNodePtr> new_node_input) {
   if (new_node_input.empty()) {
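The new comment above describes the bookkeeping this commit introduces: when a CNode B reads its input i through a RefKey that resolves to Parameter C, one entry with key C and value (B, i) is recorded in g_RefMap. Below is a minimal, self-contained sketch of that mapping, using plain std::string stand-ins for the real AnfNodePtr/RefKey types; the node and parameter names are purely illustrative.

#include <iostream>
#include <map>
#include <string>
#include <utility>

int main() {
  // key: the Parameter the RefKey resolves to; value: (using CNode, input index)
  std::map<std::string, std::pair<std::string, int>> ref_map;

  // e.g. a CNode "AssignSub" reads its input 1 through RefKey[assignsub_weight]
  ref_map["assignsub_weight"] = {"AssignSub", 1};

  // each entry tells a later pass which node/input to consult for this parameter
  for (const auto& item : ref_map) {
    std::cout << item.first << " -> (" << item.second.first << ", "
              << item.second.second << ")\n";
  }
  return 0;
}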
@@ -1085,11 +1088,19 @@ std::vector<Shapes> ExtractShape(const CNodePtr& node) {
   std::vector<AnfNodePtr> all_inputs = node->inputs();
   std::vector<AnfNodePtr> node_inputs{all_inputs.begin() + 1, all_inputs.end()};
-  for (auto& input : node_inputs) {
+  size_t inputs_size = all_inputs.size();
+  for (size_t i = 1; i < inputs_size; ++i) {
     Shapes input_shapes;
+    AnfNodePtr input = all_inputs[i];
     if (IsValueNode<RefKey>(input)) {
       auto func_graph = node->func_graph();
       MS_EXCEPTION_IF_NULL(func_graph);
+      std::vector<AnfNodePtr> parameters = FindParameterByRefKeyNode(input, func_graph);
+      if (parameters.size() != 1) {
+        MS_LOG(EXCEPTION) << "Find parameter by ref key node failed";
+      }
+      std::pair<AnfNodePtr, int> node_pair = std::make_pair(node, SizeToInt(i));
+      g_RefMap[parameters[0]] = node_pair;
       input_shapes = GetRefKeyNodeShape(input, func_graph);
     } else if (IsValueNode<Tensor>(input) || input->isa<CNode>() || input->isa<Parameter>()) {
       input_shapes = GetNodeShape(input);
@@ -1205,14 +1216,20 @@ void CoverSliceShape(const FuncGraphPtr& root) {
   auto parameters = root->parameters();
   for (auto& parameter : parameters) {
     MS_EXCEPTION_IF_NULL(parameter->Shape());
+    auto iter = g_RefMap.find(parameter);
+    if (iter != g_RefMap.end()) {
+      SetParallelShape(parameter, g_RefMap[parameter]);
+      continue;
+    }
     std::pair<AnfNodePtr, int> res = FindSubGraph(root, parameter);
     if (res.first == nullptr) {
       MS_LOG(INFO) << "Parameter " << parameter->ToString() << " don't need to set parallel shape";
     } else {
       SetParallelShape(parameter, res);
-      MS_LOG(DEBUG) << "parameter " << parameter->ToString() << " shape " << parameter->Shape()->ToString();
+      MS_LOG(DEBUG) << "Parameter " << parameter->ToString() << " shape " << parameter->Shape()->ToString();
     }
   }
+  g_RefMap.clear();
 }

 bool ParameterIsCloned(const FuncGraphPtr& root, const AnfNodePtr& parameter_node) {
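Taken together, the hunks above form a produce/consume pair: ExtractShape records a (node, input index) pair for every parameter reached through a RefKey, and CoverSliceShape prefers that recorded pair over the FindSubGraph search before clearing the map. The following is a rough stand-alone sketch of that lookup-with-fallback flow under the same simplified stand-in types; resolve_by_search is a hypothetical placeholder for FindSubGraph, not a real MindSpore function.

#include <iostream>
#include <map>
#include <optional>
#include <string>
#include <utility>

using NodeIndex = std::pair<std::string, int>;  // stand-in for (AnfNodePtr, int)

// Hypothetical fallback, playing the role of FindSubGraph in the real code.
std::optional<NodeIndex> resolve_by_search(const std::string& parameter) {
  if (parameter == "mul_weight") return NodeIndex{"Mul", 2};
  return std::nullopt;  // parameter not used in this graph
}

int main() {
  // Entries recorded earlier (the ExtractShape pass in the real code).
  std::map<std::string, NodeIndex> ref_map = {{"assignsub_weight", {"AssignSub", 1}}};

  for (const std::string& parameter : {"assignsub_weight", "mul_weight", "unused"}) {
    auto iter = ref_map.find(parameter);
    if (iter != ref_map.end()) {
      // A recorded RefKey user wins; no graph search is needed.
      std::cout << parameter << ": shape set from " << iter->second.first << "\n";
      continue;
    }
    auto res = resolve_by_search(parameter);
    if (!res) {
      std::cout << parameter << ": no parallel shape needed\n";
    } else {
      std::cout << parameter << ": shape set from " << res->first << "\n";
    }
  }
  ref_map.clear();  // mirrors g_RefMap.clear() at the end of CoverSliceShape
  return 0;
}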
tests/ut/python/parallel/test_arithmetic.py
@@ -13,14 +13,13 @@
 # limitations under the License.

 import numpy as np
-from mindspore import context
+import mindspore as ms
+from mindspore import Parameter, Tensor, context
 import mindspore.nn as nn
 from mindspore.ops import operations as P
-from mindspore import Tensor
-from tests.ut.python.ops.test_math_ops import VirtualLoss
-import mindspore as ms
-from mindspore.common.api import _executor
 from mindspore.ops import composite as C
+from mindspore.common.api import _executor
+from tests.ut.python.ops.test_math_ops import VirtualLoss


 class NetWithLoss(nn.Cell):
@@ -470,3 +469,30 @@ def test_matmul_floordiv_broadcast2():
     y = Tensor(np.ones([32, 1]), dtype=ms.float32)
     b = Tensor(np.ones([1, 64]), dtype=ms.float32)
     _executor.compile(net, x, y, b)
+
+
+def test_assign_sub():
+    class Net(nn.Cell):
+        def __init__(self):
+            super().__init__()
+            self.assign_sub = P.AssignSub()
+            self.mul = P.Mul()
+            self.mul_weight = Parameter(Tensor(np.full([128, 32], 0.5, dtype=np.float32)), name="mul_weight")
+            self.assignsub_weight = Parameter(Tensor(np.full([128, 32], 1.1, dtype=np.float32)),
+                                              name="assignsub_weight")
+
+        def construct(self, x, y, z):
+            out = self.mul(x, self.mul_weight)
+            out = self.assign_sub(self.assignsub_weight, out)
+            return out
+
+    context.set_auto_parallel_context(device_num=64, global_rank=15)
+    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
+    net = GradWrap(NetWithLoss(Net()))
+
+    x = Tensor(np.ones([128, 32]), dtype=ms.float32)
+    y = Tensor(np.ones([128, 32]), dtype=ms.float32)
+    z = Tensor(np.ones([128, 32]), dtype=ms.float32)
+    _executor.compile(net, x, y, z)
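The new case exercises the path the C++ change targets: a Parameter (assignsub_weight) consumed by AssignSub, presumably reaching the RefKey handling added above, while semi_auto_parallel shape inference runs. It only compiles the graph through _executor.compile, so no device is required. Assuming the repository's Python UT environment is available, it can presumably be run on its own with pytest, e.g. pytest tests/ut/python/parallel/test_arithmetic.py::test_assign_sub.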