magicwindyyd / mindspore (forked from MindSpore / mindspore, in sync with upstream)
Commit 2ab211ae
Authored on April 23, 2020 by lichenever

support reshape parameter

Parent: b48d663c
Showing 3 changed files with 104 additions and 3 deletions.
mindspore/ccsrc/parallel/step_parallel.cc           +24 −1
mindspore/context.py                                 +5 −2
tests/ut/python/parallel/test_reshape_parameter.py  +75 −0
mindspore/ccsrc/parallel/step_parallel.cc

```diff
@@ -1523,9 +1523,32 @@ std::shared_ptr<TensorLayout> FindPrevParallelCareNodeLayout(const AnfNodePtr &n
   return nullptr;
 }
 
+std::shared_ptr<TensorLayout> CreateParameterLayout(const AnfNodePtr &node) {
+  // Create DataParallel tensor layout for parameter(support WideDeep).
+  CheckGlobalDeviceManager();
+  int32_t dev_num = SizeToInt(g_device_manager->GetDeviceListByStageId(0).size());
+  TensorLayout input_tensor_layout;
+  // create input_shape
+  Shapes inputs_shape = GetNodeShape(node);
+  Shape input_shape_array = inputs_shape[0];
+  if (input_shape_array.empty()) {
+    MS_LOG(EXCEPTION) << "Don't support reshape a scalar parameter.";
+  }
+  // create tensor_map
+  size_t shape_size = input_shape_array.size();
+  TensorMap input_tensor_map_array(SizeToInt(shape_size) - 1, -1);
+  input_tensor_map_array.insert(input_tensor_map_array.begin(), 0);
+  // create dev_matrix
+  Shape dev_matrix_array = {dev_num};
+  if (input_tensor_layout.InitFromVector(dev_matrix_array, input_tensor_map_array, input_shape_array) != SUCCESS) {
+    MS_LOG(EXCEPTION) << "Create tensor layout for parameter failed.";
+  }
+  return std::make_shared<TensorLayout>(input_tensor_layout);
+}
+
 std::shared_ptr<TensorLayout> FindPrevLayout(const AnfNodePtr &node) {
   if (node->isa<Parameter>()) {
-    MS_LOG(EXCEPTION) << "Failure: parameter before reshape is not supported temporary";
+    return CreateParameterLayout(node);
   }
   if (!node->isa<CNode>()) {
     return nullptr;
```
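To make the layout construction above concrete, here is a minimal plain-Python sketch (not MindSpore API; the function name and return convention are illustrative only) of what `CreateParameterLayout` computes for a parameter: a one-axis device matrix of size `dev_num`, and a tensor map that shards dimension 0 across all devices while replicating every other dimension.

```python
# Illustrative sketch of CreateParameterLayout's logic, assuming the same
# inputs as the C++ code: the parameter's shape and the stage-0 device count.
def create_parameter_layout(input_shape, dev_num):
    if not input_shape:
        raise ValueError("Don't support reshape a scalar parameter.")
    # tensor_map: dimension 0 maps to device-matrix axis 0 (sharded),
    # all remaining dimensions map to -1 (replicated on every device).
    tensor_map = [0] + [-1] * (len(input_shape) - 1)
    dev_matrix = [dev_num]           # single-axis device matrix
    return dev_matrix, tensor_map, input_shape

print(create_parameter_layout([10000, 36], 8))
# ([8], [0, -1], [10000, 36]) -> rows sharded 8 ways, columns replicated
```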
mindspore/context.py

```diff
@@ -415,8 +415,11 @@ def set_auto_parallel_context(**kwargs):
     Args:
         device_num (int): Available device number, the value must be in [1, 4096]. Default: 1.
         global_rank (int): Global rank id, the value must be in [0, 4095]. Default: 0.
-        mirror_mean (bool): Whether to perform mean operator after all-reduce of mirror. Default: False.
-        cast_before_mirror (bool): Insert Mirror Op after the cast if this flag is True. Default: True.
+        mirror_mean (bool): Whether to perform mean operator after all-reduce of mirror.
+                     "stand_alone" does not support mirror_mean. Default: False.
+        cast_before_mirror (bool): Insert Mirror Op after the cast if this flag is True.
+                     "stand_alone", "data_parallel" and "hybrid_parallel" do not support
+                     cast_before_mirror. Default: True.
         parallel_mode (str): There are five kinds of parallel modes, "stand_alone", "data_parallel",
                      "hybrid_parallel", "semi_auto_parallel" and "auto_parallel". Default: "stand_alone".
```
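A minimal usage sketch of the constraints documented above, using only the keyword arguments shown in this docstring: since `mirror_mean` and `cast_before_mirror` are unsupported in "stand_alone" (and `cast_before_mirror` also in "data_parallel" and "hybrid_parallel"), pair them with a mode that supports them, such as "semi_auto_parallel".

```python
from mindspore import context

# Both flags take effect here because "semi_auto_parallel" supports them.
context.set_auto_parallel_context(device_num=8, global_rank=0,
                                  parallel_mode="semi_auto_parallel",
                                  mirror_mean=True, cast_before_mirror=False)
```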
tests/ut/python/parallel/test_reshape_parameter.py (new file, mode 100644)

```python
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import mindspore as ms
import mindspore.nn as nn
from mindspore.ops import operations as P
from mindspore.ops import composite as C
from mindspore import Tensor
from mindspore import context
from mindspore.common.api import _executor
from tests.ut.python.ops.test_math_ops import VirtualLoss
import numpy as np


class NetWithLoss(nn.Cell):
    def __init__(self, network):
        super(NetWithLoss, self).__init__()
        self.loss = VirtualLoss()
        self.network = network

    def construct(self, x, y):
        predict = self.network(x, y)
        return self.loss(predict)


class GradWrap(nn.Cell):
    def __init__(self, network):
        super(GradWrap, self).__init__()
        self.network = network

    def construct(self, x, y):
        return C.grad_all(self.network)(x, y)


class Net(nn.Cell):
    def __init__(self, strategy):
        super().__init__()
        self.reshape = P.Reshape()
        self.mul = P.Mul().set_strategy(strategy)
        self.relu = P.ReLU()

    def construct(self, x, y):
        out = self.reshape(x, (10000, 36, 1))
        out = self.mul(out, y)
        out = self.relu(out)
        return out


def test_reshape_parameter_data_parallel():
    context.set_auto_parallel_context(device_num=8, global_rank=0,
                                      parallel_mode="semi_auto_parallel")
    strategy = ((8, 1, 1), (8, 1, 1))
    net = GradWrap(NetWithLoss(Net(strategy)))
    x = Tensor(np.ones([10000, 36]), dtype=ms.float32)
    y = Tensor(np.ones([10000, 36, 1]), dtype=ms.float32)
    _executor.compile(net, x, y)


def test_reshape_parameter_model_parallel():
    context.set_auto_parallel_context(device_num=8, global_rank=0,
                                      parallel_mode="semi_auto_parallel")
    strategy = ((4, 2, 1), (4, 2, 1))
    net = GradWrap(NetWithLoss(Net(strategy)))
    x = Tensor(np.ones([10000, 36]), dtype=ms.float32)
    y = Tensor(np.ones([10000, 36, 1]), dtype=ms.float32)
    _executor.compile(net, x, y)
```
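A quick sanity sketch of what the two `Mul` strategies imply for the reshaped tensor of shape (10000, 36, 1): each strategy tuple gives the number of slices per dimension, so the per-device slice shape is the element-wise quotient. The helper below is hypothetical, written only to illustrate the arithmetic.

```python
# Hypothetical helper: per-device slice shape implied by a sharding strategy.
def slice_shape(full_shape, strategy):
    assert all(d % s == 0 for d, s in zip(full_shape, strategy))
    return tuple(d // s for d, s in zip(full_shape, strategy))

print(slice_shape((10000, 36, 1), (8, 1, 1)))  # (1250, 36, 1) -> data parallel
print(slice_shape((10000, 36, 1), (4, 2, 1)))  # (2500, 18, 1) -> model parallel
```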