Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
magicwindyyd
mindspore
提交
7bc2cee3
M
mindspore
项目概览
magicwindyyd
/
mindspore
与 Fork 源项目一致
Fork自
MindSpore / mindspore
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
mindspore
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
7bc2cee3
编写于
4月 13, 2020
作者:
M
mindspore-ci-bot
提交者:
Gitee
4月 13, 2020
浏览文件
操作
浏览文件
下载
差异文件
!167 add_squeeze_distributed_op
Merge pull request !167 from lichen/add_squeeze_distributed_op
上级
478f43ab
32cd280c
变更
8
隐藏空白更改
内联
并排
Showing
8 changed files
with
261 additions
and
7 deletions
+261
-7
mindspore/ccsrc/parallel/dynamic_creator.h
mindspore/ccsrc/parallel/dynamic_creator.h
+1
-0
mindspore/ccsrc/parallel/ops_info/activation_info.cc
mindspore/ccsrc/parallel/ops_info/activation_info.cc
+156
-0
mindspore/ccsrc/parallel/ops_info/activation_info.h
mindspore/ccsrc/parallel/ops_info/activation_info.h
+20
-1
mindspore/ccsrc/parallel/ops_info/arithmetic_info.h
mindspore/ccsrc/parallel/ops_info/arithmetic_info.h
+1
-1
mindspore/ccsrc/parallel/ops_info/comparison_function_info.h
mindspore/ccsrc/parallel/ops_info/comparison_function_info.h
+1
-1
mindspore/ccsrc/parallel/ops_info/onehot_info.h
mindspore/ccsrc/parallel/ops_info/onehot_info.h
+1
-1
mindspore/ccsrc/parallel/step_parallel.cc
mindspore/ccsrc/parallel/step_parallel.cc
+2
-3
tests/ut/python/parallel/test_squeeze_info.py
tests/ut/python/parallel/test_squeeze_info.py
+79
-0
未找到文件。
mindspore/ccsrc/parallel/dynamic_creator.h
浏览文件 @
7bc2cee3
...
...
@@ -125,6 +125,7 @@ REGISTER(GetNextInfo);
REGISTER
(
NegInfo
);
REGISTER
(
BatchMatMulInfo
);
REGISTER
(
ExpandDimsInfo
);
REGISTER
(
SqueezeInfo
);
}
// namespace parallel
}
// namespace mindspore
...
...
mindspore/ccsrc/parallel/ops_info/activation_info.cc
浏览文件 @
7bc2cee3
...
...
@@ -19,6 +19,7 @@
#include <algorithm>
#include <memory>
#include <vector>
#include <utility>
#include "ir/value.h"
#include "parallel/auto_parallel/costmodel.h"
...
...
@@ -544,5 +545,160 @@ Status ExpandDimsInfo::InferMirrorOps() {
MS_LOG
(
INFO
)
<<
name_
<<
": Create mirror ops success, the group name is "
<<
group
[
0
].
name
();
return
SUCCESS
;
}
// Normalize the Squeeze 'axis' attribute into a tuple of non-negative
// dimension indices and cache it in axis_.
// - An empty tuple means "squeeze every dimension of size 1", so those
//   dimensions are collected explicitly.
// - Negative indices are converted to their positive equivalents.
// Returns FAILED when the inputs shape is missing or an axis element is not
// an int; SUCCESS otherwise.
Status SqueezeInfo::InferAxis(const ValueTuplePtr &value_tuple) {
  std::vector<int32_t> axis;
  auto axis_list = value_tuple->value();
  if (inputs_shape_.empty()) {
    MS_LOG(ERROR) << name_ << ": The inputs shape is empty";
    return FAILED;
  }
  Shape input_shape = inputs_shape_.at(0);
  size_t input_size = input_shape.size();

  // if axis tuple is empty, we should exclude the axis that the corresponding slice shape is 1.
  if (axis_list.empty()) {
    for (size_t i = 0; i < input_size; ++i) {
      if (input_shape[i] == 1) {
        // Fix: convert explicitly instead of relying on an implicit
        // size_t -> int32_t narrowing conversion, matching the SizeToInt
        // usage elsewhere in this function.
        axis.push_back(SizeToInt(i));
      }
    }
    axis_ = MakeValue(axis)->cast<ValueTuplePtr>();
    return SUCCESS;
  }

  // convert negative axis to positive.
  for (auto &dim : axis_list) {
    if (!dim->isa<Int32Imm>()) {
      MS_LOG(ERROR) << name_ << ": The type of axis is not int";
      return FAILED;
    }
    int32_t dim_value = GetValue<int32_t>(dim);
    int32_t positive_value = (dim_value < 0) ? (dim_value + SizeToInt(input_size)) : dim_value;
    axis.push_back(positive_value);
  }
  axis_ = MakeValue(axis)->cast<ValueTuplePtr>();
  return SUCCESS;
}
// Fetch the 'axis' attribute, normalize it via InferAxis, and write the
// normalized tuple back into attrs_ so later stages always see non-negative
// indices. Returns FAILED when the attribute is missing or malformed.
Status SqueezeInfo::GetAttrs() {
  auto iter = attrs_.find(AXIS);
  if (iter == attrs_.end()) {
    MS_LOG(ERROR) << name_ << ": Can't find axis attribute.";
    return FAILED;
  }
  MS_EXCEPTION_IF_NULL(iter->second);
  auto value_tuple = iter->second->cast<ValueTuplePtr>();
  MS_EXCEPTION_IF_NULL(value_tuple);
  // Fix: propagate InferAxis failures instead of discarding the status —
  // on failure axis_ would be left unset while this method still reported
  // SUCCESS, and later GetValue(axis_) calls would operate on a stale/null
  // value.
  if (InferAxis(value_tuple) != SUCCESS) {
    return FAILED;
  }
  attrs_[AXIS] = axis_;
  return SUCCESS;
}
// Build the replacement operator list for Squeeze: a single Squeeze op that
// carries the normalized axis_ attribute and takes no positional parameters.
// The strategy argument is accepted for interface symmetry with the other
// Infer* hooks and is not consulted here.
Status SqueezeInfo::InferReplaceOps(const StrategyPtr &strategy) {
  OperatorAttrs replace_attrs = {std::make_pair(AXIS, axis_)};
  OperatorParams replace_params;
  replace_op_ = {std::make_pair(SQUEEZE, std::make_pair(replace_attrs, replace_params))};
  return SUCCESS;
}
// Build the tensor maps for Squeeze's input and output. The input keeps one
// map entry per dimension (counting down from size-1); the output drops the
// entries whose dimension index appears in the normalized axis_ tuple.
Status SqueezeInfo::InferTensorMap() {
  // for example: if the shape of input is [32, 32, 1], and the axis is (2, ),
  // then the input_tensor_map is [2, 1, 0], the output_tensor_map is [2, 1]
  std::vector<int32_t> input_tensor_map, output_tensor_map;
  if (inputs_shape_.empty()) {
    MS_LOG(ERROR) << name_ << ": The inputs shape is empty";
    return FAILED;
  }
  size_t size = inputs_shape_[0].size();
  // axis_ holds non-negative indices — it was normalized by InferAxis.
  std::vector<int32_t> axis = GetValue<const std::vector<int>>(axis_);
  for (size_t i = 0; i < size; ++i) {
    // Dimension i maps to entry size - i - 1 (highest value first).
    size_t index = size - i - 1;
    auto iter = std::find(axis.begin(), axis.end(), SizeToInt(i));
    if (iter == axis.end()) {
      // Dimension i survives the squeeze, so the output keeps its entry.
      output_tensor_map.push_back(SizeToInt(index));
    }
    input_tensor_map.push_back(SizeToInt(index));
  }
  inputs_tensor_map_.push_back(input_tensor_map);
  outputs_tensor_map_.push_back(output_tensor_map);
  MS_LOG(INFO) << name_ << ": The tensor map of input is " << ShapeToString(input_tensor_map)
               << ", and the tensor map of output is " << ShapeToString(output_tensor_map);
  return SUCCESS;
}
// Build the TensorInfo objects (layout + full shape + slice shape) for the
// Squeeze input and output. The output strategy is derived from the input
// strategy by dropping the entries of the squeezed dimensions.
Status SqueezeInfo::InferTensorInfo() {
  if (inputs_shape_.empty() || outputs_shape_.empty()) {
    MS_LOG(ERROR) << name_ << ": The shape of inputs or outputs is empty";
    return FAILED;
  }
  if (inputs_tensor_map_.empty() || outputs_tensor_map_.empty()) {
    MS_LOG(ERROR) << name_ << ": The tensor map of inputs or outputs is empty";
    return FAILED;
  }
  Shape input_shape = inputs_shape_[0];
  Shape output_shape = outputs_shape_[0];

  // infer slice shape
  Shapes inputs_slice_shape, outputs_slice_shape;
  Strategys inputs_strategy = strategy_->GetInputDim();
  Dimensions output_strategy;
  // axis_ holds non-negative indices (normalized by InferAxis); keep only the
  // strategy entries of dimensions that are not squeezed away.
  std::vector<int32_t> axis = GetValue<const std::vector<int>>(axis_);
  for (size_t i = 0; i < inputs_shape_[0].size(); ++i) {
    auto iter = std::find(axis.begin(), axis.end(), SizeToInt(i));
    if (iter == axis.end()) {
      output_strategy.push_back(inputs_strategy[0].at(i));
    }
  }
  Strategys outputs_strategy = {output_strategy};
  if (InferSliceShape(inputs_strategy, outputs_strategy, &inputs_slice_shape, &outputs_slice_shape) != SUCCESS) {
    MS_LOG(ERROR) << name_ << ": Infer slice shape failed";
    return FAILED;
  }
  if (inputs_slice_shape.empty() || outputs_slice_shape.empty()) {
    MS_LOG(ERROR) << name_ << ": The slice shape of inputs or outputs is empty";
    return FAILED;
  }
  Shape input_slice_shape = inputs_slice_shape[0];
  Shape output_slice_shape = outputs_slice_shape[0];

  // infer tensor layout
  TensorLayout input_tensor_layout, output_tensor_layout;
  if (input_tensor_layout.InitFromVector(dev_matrix_shape_, inputs_tensor_map_[0], input_shape) != SUCCESS) {
    MS_LOG(ERROR) << name_ << ": Init tensor layout for input failed";
    return FAILED;
  }
  if (output_tensor_layout.InitFromVector(dev_matrix_shape_, outputs_tensor_map_[0], output_shape) != SUCCESS) {
    MS_LOG(ERROR) << name_ << ": Init tensor layout for output failed";
    return FAILED;
  }
  TensorInfo input_tensor_info(input_tensor_layout, input_shape, input_slice_shape);
  TensorInfo output_tensor_info(output_tensor_layout, output_shape, output_slice_shape);
  inputs_tensor_info_.push_back(input_tensor_info);
  outputs_tensor_info_.push_back(output_tensor_info);
  return SUCCESS;
}
// Initialize this operator for the given sharding strategy and prepare the
// replacement Squeeze op.
// Fix: previously both failure branches only logged an error and fell through
// to return SUCCESS, so callers could never detect a failed initialization;
// now each failure returns FAILED immediately.
Status SqueezeInfo::Init(const StrategyPtr &strategy) {
  if (InitWithAutoRepeatCalc(strategy) != SUCCESS) {
    MS_LOG(ERROR) << name_ << " : Init failed.";
    return FAILED;
  }
  if (InferReplaceOps(strategy) != SUCCESS) {
    MS_LOG(ERROR) << name_ << " : Infer replace ops failed";
    return FAILED;
  }
  MS_LOG(INFO) << name_ << " : Init success.";
  return SUCCESS;
}
}
// namespace parallel
}
// namespace mindspore
mindspore/ccsrc/parallel/ops_info/activation_info.h
浏览文件 @
7bc2cee3
...
...
@@ -184,6 +184,25 @@ class ExpandDimsInfo : public ActivationOther {
Strategys
inputs_strategy_
;
Strategys
outputs_strategy_
;
};
// Parallel operator info for Squeeze (removes size-1 dimensions).
// The axis attribute is normalized by InferAxis (empty tuple expanded to all
// size-1 dims, negative indices made positive) and cached in axis_.
class SqueezeInfo : public ActivationOther {
 public:
  SqueezeInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
              const PrimitiveAttrs &attrs)
      : ActivationOther(name, inputs_shape, outputs_shape, attrs) {}
  ~SqueezeInfo() override = default;

 protected:
  // Normalize the raw axis tuple into non-negative indices stored in axis_.
  Status InferAxis(const ValueTuplePtr &value_tuple);
  Status GetAttrs() override;
  // Build the replacement Squeeze operator carrying the normalized axis.
  Status InferReplaceOps(const StrategyPtr &strategy);
  Status InferTensorMap() override;
  Status InferTensorInfo() override;
  Status Init(const StrategyPtr &strategy) override;

 private:
  ValueTuplePtr axis_;  // normalized axis tuple (non-negative indices)
};
}
// namespace parallel
}
// namespace mindspore
#endif // MINDSPORE_CCSRC_
OPTIMIZER_OPS_INFO_PARALLEL
_ACTIVATION_INFO_H_
#endif // MINDSPORE_CCSRC_
PARALLEL_OPS_INFO
_ACTIVATION_INFO_H_
mindspore/ccsrc/parallel/ops_info/arithmetic_info.h
浏览文件 @
7bc2cee3
...
...
@@ -123,4 +123,4 @@ class AssignSubInfo : public ArithmeticBase {
}
// namespace parallel
}
// namespace mindspore
#endif // MINDSPORE_CCSRC_
OPTIMIZER_OPS_INFO_PARALLEL
_ARITHMETIC_INFO_H_
#endif // MINDSPORE_CCSRC_
PARALLEL_OPS_INFO
_ARITHMETIC_INFO_H_
mindspore/ccsrc/parallel/ops_info/comparison_function_info.h
浏览文件 @
7bc2cee3
...
...
@@ -53,4 +53,4 @@ class MaximumInfo : public ArithmeticBase {
}
// namespace parallel
}
// namespace mindspore
#endif // MINDSPORE_CCSRC_
OPTIMIZER_OPS_INFO_PARALLEL
_COMPARISON_FUNCTION_INFO_H_
#endif // MINDSPORE_CCSRC_
PARALLEL_OPS_INFO
_COMPARISON_FUNCTION_INFO_H_
mindspore/ccsrc/parallel/ops_info/onehot_info.h
浏览文件 @
7bc2cee3
...
...
@@ -65,4 +65,4 @@ class OneHotInfo : public OperatorInfo {
};
}
// namespace parallel
}
// namespace mindspore
#endif // MINDSPORE_CCSRC_
OPTIMIZER_OPS_INFO_PARALLEL
_ONEHOT_INFO_H_
#endif // MINDSPORE_CCSRC_
PARALLEL_OPS_INFO
_ONEHOT_INFO_H_
mindspore/ccsrc/parallel/step_parallel.cc
浏览文件 @
7bc2cee3
...
...
@@ -47,8 +47,8 @@ using mindspore::tensor::Tensor;
namespace
mindspore
{
namespace
parallel
{
const
std
::
set
<
std
::
string
>
COMMUNICATION_OPS
=
{
ALL_REDUCE
,
ALL_GATHER
,
ALL_TO_ALL
,
REDUCE_SCATTER
};
const
std
::
set
<
std
::
string
>
INVALID_LOSS_OPS
=
{
GET_NEXT
,
VIRTUALLOSS
};
static
const
std
::
set
<
std
::
string
>
COMMUNICATION_OPS
=
{
ALL_REDUCE
,
ALL_GATHER
,
ALL_TO_ALL
,
REDUCE_SCATTER
};
static
const
std
::
set
<
std
::
string
>
INVALID_LOSS_OPS
=
{
GET_NEXT
,
VIRTUALLOSS
};
// g_RefMap, for CNode B input i is a RefKey[Parameter C],
// it will be one item in map with key: C, and value: (B, i)
static
std
::
map
<
AnfNodePtr
,
std
::
pair
<
AnfNodePtr
,
int
>>
g_RefMap
;
...
...
@@ -1840,7 +1840,6 @@ void ParallelCommunication(const FuncGraphPtr& root, const std::vector<AnfNodePt
if
(
cnode
==
loss_cnode
)
{
is_loss_cnode
=
true
;
}
// insert forward ops
InsertForwardOps
(
distribute_operator
,
cnode
);
...
...
tests/ut/python/parallel/test_squeeze_info.py
0 → 100644
浏览文件 @
7bc2cee3
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
numpy
as
np
import
mindspore
as
ms
from
mindspore
import
context
,
Tensor
,
Parameter
from
mindspore.nn
import
Cell
,
TrainOneStepCell
,
Momentum
from
mindspore.ops
import
operations
as
P
from
mindspore.common.api
import
_executor
class Net(Cell):
    """Squeeze-then-Mul network used to exercise the Squeeze distributed op.

    Args:
        strategy1: sharding strategy tuple for Squeeze (None lets the
            framework pick one).
        strategy2: sharding strategy tuple for Mul.
        axis: axis argument forwarded to P.Squeeze; the default () squeezes
            every dimension of size 1.
    """

    def __init__(self, strategy1=None, strategy2=None, axis=()):
        super().__init__()
        self.squeeze = P.Squeeze(axis=axis).set_strategy(strategy1)
        self.mul = P.Mul().set_strategy(strategy2)

    def construct(self, x, b):
        # Squeeze x first, then multiply elementwise with b.
        out = self.squeeze(x)
        out = self.mul(out, b)
        return out
# Shared test inputs: _x carries two size-1 dimensions (axes 1 and 3) that
# Squeeze removes, leaving a [64, 32] tensor that matches _b's shape for the
# elementwise Mul.
_x = Tensor(np.ones([64, 1, 32, 1]), dtype=ms.float32)
_b = Tensor(np.ones([64, 32]), dtype=ms.float32)
# NOTE(review): this helper shadows the builtin compile(); kept as-is because
# every test function below calls it by this name.
def compile(net):
    # Compile the net with the shared inputs, then reset the auto-parallel
    # context so each test starts from a clean configuration.
    _executor.compile(net, _x, _b)
    context.reset_auto_parallel_context()
def test_squeeze_data_parallel():
    """Squeeze sharded only along the batch dimension (pure data parallel)."""
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=16, global_rank=0)
    squeeze_strategy = ((16, 1, 1, 1),)
    mul_strategy = ((16, 1), (16, 1))
    compile(Net(squeeze_strategy, mul_strategy))
def test_squeeze_model_parallel():
    """Squeeze sharded along a non-batch dimension (model parallel)."""
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=16, global_rank=0)
    squeeze_strategy = ((1, 1, 16, 1),)
    mul_strategy = ((1, 16), (1, 16))
    compile(Net(squeeze_strategy, mul_strategy))
def test_squeeze_specified_axis():
    """Squeeze with an explicit axis tuple (1, 3) and a mixed sharding."""
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=16, global_rank=0)
    squeeze_strategy = ((4, 1, 4, 1),)
    mul_strategy = ((8, 2), (8, 2))
    compile(Net(squeeze_strategy, mul_strategy, (1, 3)))
def test_squeeze_auto_parallel():
    """No explicit strategies: let auto-parallel search them itself."""
    context.set_auto_parallel_context(parallel_mode="auto_parallel", device_num=16, global_rank=0)
    compile(Net())
def test_squeeze_repeat_calc():
    """Strategy uses only 8 of 16 devices, exercising repeated calculation."""
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=16, global_rank=0)
    squeeze_strategy = ((1, 1, 8, 1),)
    mul_strategy = ((2, 8), (2, 8))
    compile(Net(squeeze_strategy, mul_strategy))
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录