Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
magicwindyyd
mindspore
提交
9aa84b3d
M
mindspore
项目概览
magicwindyyd
/
mindspore
与 Fork 源项目一致
Fork自
MindSpore / mindspore
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
mindspore
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
9aa84b3d
编写于
7月 29, 2020
作者:
Y
yangzhenzhang
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
add strided slice op
上级
1b699234
变更
7
隐藏空白更改
内联
并排
Showing
7 changed file
with
550 addition
and
0 deletion
+550
-0
mindspore/ccsrc/frontend/parallel/auto_parallel/operator_costmodel.h
...csrc/frontend/parallel/auto_parallel/operator_costmodel.h
+2
-0
mindspore/ccsrc/frontend/parallel/dynamic_creator.h
mindspore/ccsrc/frontend/parallel/dynamic_creator.h
+1
-0
mindspore/ccsrc/frontend/parallel/ops_info/ops_info_head_files.h
...re/ccsrc/frontend/parallel/ops_info/ops_info_head_files.h
+1
-0
mindspore/ccsrc/frontend/parallel/ops_info/ops_utils.h
mindspore/ccsrc/frontend/parallel/ops_info/ops_utils.h
+5
-0
mindspore/ccsrc/frontend/parallel/ops_info/strided_slice_info.cc
...re/ccsrc/frontend/parallel/ops_info/strided_slice_info.cc
+305
-0
mindspore/ccsrc/frontend/parallel/ops_info/strided_slice_info.h
...ore/ccsrc/frontend/parallel/ops_info/strided_slice_info.h
+72
-0
tests/ut/python/parallel/test_stridedslice.py
tests/ut/python/parallel/test_stridedslice.py
+164
-0
未找到文件。
mindspore/ccsrc/frontend/parallel/auto_parallel/operator_costmodel.h
浏览文件 @
9aa84b3d
...
...
@@ -170,6 +170,8 @@ class ActivationCost : public OperatorCost {
using
ActivationCostPtr
=
std
::
shared_ptr
<
ActivationCost
>
;
using
TransposeCost
=
ActivationCost
;
using
TransposeCostPtr
=
std
::
shared_ptr
<
TransposeCost
>
;
using
StridedSliceCost
=
ActivationCost
;
using
StridedSliceCostPtr
=
std
::
shared_ptr
<
StridedSliceCost
>
;
class
SoftmaxCost
:
public
OperatorCost
{
public:
...
...
mindspore/ccsrc/frontend/parallel/dynamic_creator.h
浏览文件 @
9aa84b3d
...
...
@@ -134,6 +134,7 @@ REGISTER(SquareInfo);
REGISTER
(
GatherV2PInfo
);
REGISTER
(
EmbeddingLookupInfo
);
REGISTER
(
TileInfo
);
REGISTER
(
StridedSliceInfo
);
}
// namespace parallel
}
// namespace mindspore
...
...
mindspore/ccsrc/frontend/parallel/ops_info/ops_info_head_files.h
浏览文件 @
9aa84b3d
...
...
@@ -38,5 +38,6 @@
#include "frontend/parallel/ops_info/virtual_dataset_info.h"
#include "frontend/parallel/ops_info/gather_v2_p_info.h"
#include "frontend/parallel/ops_info/tile_info.h"
#include "frontend/parallel/ops_info/strided_slice_info.h"
#endif // MINDSPORE_CCSRC_FRONTEND_PARALLEL_OPS_INFO_HEAD_FILES_H_
mindspore/ccsrc/frontend/parallel/ops_info/ops_utils.h
浏览文件 @
9aa84b3d
...
...
@@ -29,6 +29,11 @@ constexpr int32_t NO_SPLIT_STRATEGY = 1;
constexpr
int32_t
SPLIT_FLAG
=
1
;
constexpr
int32_t
NO_SPLIT_FLAG
=
0
;
constexpr
size_t
MATMUL_ATTRS_SIZE
=
2
;
constexpr
size_t
STRIDED_SLICE_ATTRS_SIZE
=
5
;
constexpr
size_t
STRIDED_SLICE_INPUTS_SIZE
=
4
;
constexpr
size_t
STRIDED_SLICE_BEGIN_INDEX
=
1
;
constexpr
size_t
STRIDED_SLICE_END_INDEX
=
2
;
constexpr
size_t
STRIDED_SLICE_STRIDES_INDEX
=
3
;
constexpr
size_t
MATMUL_INPUTS_SIZE
=
2
;
constexpr
size_t
MATMUL_OUTPUTS_SIZE
=
1
;
constexpr
size_t
ACTIVATION_ATTR_SIZE
=
1
;
...
...
mindspore/ccsrc/frontend/parallel/ops_info/strided_slice_info.cc
0 → 100644
浏览文件 @
9aa84b3d
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "frontend/parallel/ops_info/strided_slice_info.h"
#include <algorithm>
#include <memory>
#include <utility>
#include <vector>
#include "frontend/parallel/device_matrix.h"
#include "frontend/parallel/strategy.h"
#include "frontend/parallel/tensor_layout/tensor_redistribution.h"
#include "pipeline/jit/resource.h"
namespace
mindspore
{
namespace
parallel
{
Status
StridedSliceInfo
::
GetMask
(
const
std
::
string
&
mask_name
,
int32_t
*
mask_value
)
{
if
(
mask_value
==
nullptr
)
{
return
FAILED
;
}
auto
mask_iter
=
attrs_
.
find
(
mask_name
);
if
(
mask_iter
!=
attrs_
.
end
())
{
MS_EXCEPTION_IF_NULL
(
mask_iter
->
second
);
if
(
mask_iter
->
second
->
isa
<
Int32Imm
>
())
{
*
mask_value
=
mask_iter
->
second
->
cast
<
Int32ImmPtr
>
()
->
value
();
}
else
{
MS_LOG
(
ERROR
)
<<
name_
<<
": The value of "
<<
mask_name
<<
" is not int"
;
return
FAILED
;
}
}
return
SUCCESS
;
}
Status
GetInput
(
const
ValuePtr
&
input_value
,
std
::
vector
<
int32_t
>
*
input
)
{
MS_EXCEPTION_IF_NULL
(
input_value
);
ValueTuplePtr
value_tuple
=
input_value
->
cast
<
ValueTuplePtr
>
();
if
(
value_tuple
==
nullptr
)
{
MS_LOG
(
ERROR
)
<<
"Input value must be ValueTuplePtr."
;
return
FAILED
;
}
for
(
auto
&
element
:
value_tuple
->
value
())
{
MS_EXCEPTION_IF_NULL
(
element
);
if
(
element
->
isa
<
Int32Imm
>
())
{
int32_t
value
=
element
->
cast
<
Int32ImmPtr
>
()
->
value
();
input
->
push_back
(
value
);
}
else
{
MS_LOG
(
ERROR
)
<<
"The value must be int32"
;
return
FAILED
;
}
}
return
SUCCESS
;
}
Status
StridedSliceInfo
::
GetAttrs
()
{
if
(
attrs_
.
size
()
<
STRIDED_SLICE_ATTRS_SIZE
)
{
MS_LOG
(
ERROR
)
<<
name_
<<
": The size of attrs small than "
<<
STRIDED_SLICE_ATTRS_SIZE
;
return
FAILED
;
}
if
((
GetMask
(
BEGIN_MASK
,
&
begin_mask_
)
!=
SUCCESS
)
||
(
GetMask
(
END_MASK
,
&
end_mask_
)
!=
SUCCESS
)
||
(
GetMask
(
ELLIPSIS_MASK
,
&
ellipsis_mask_
)
!=
SUCCESS
)
||
(
GetMask
(
NEW_AXIS_MASK
,
&
new_axis_mask_
)
!=
SUCCESS
)
||
(
GetMask
(
SHRINK_AXIS_MASK
,
&
shrink_axis_mask_
)
!=
SUCCESS
))
{
return
FAILED
;
}
has_mask_
=
((
begin_mask_
!=
0
)
||
(
end_mask_
!=
0
)
||
(
ellipsis_mask_
!=
0
)
||
(
new_axis_mask_
!=
0
)
||
(
shrink_axis_mask_
!=
0
));
if
(
input_value_
.
size
()
!=
STRIDED_SLICE_INPUTS_SIZE
)
{
MS_LOG
(
ERROR
)
<<
name_
<<
": The size of input value must be "
<<
STRIDED_SLICE_INPUTS_SIZE
<<
", but got "
<<
input_value_
.
size
();
return
FAILED
;
}
if
((
GetInput
(
input_value_
[
STRIDED_SLICE_BEGIN_INDEX
],
&
begin_
)
!=
SUCCESS
)
||
(
GetInput
(
input_value_
[
STRIDED_SLICE_END_INDEX
],
&
end_
)
!=
SUCCESS
)
||
(
GetInput
(
input_value_
[
STRIDED_SLICE_STRIDES_INDEX
],
&
strides_
)
!=
SUCCESS
))
{
return
FAILED
;
}
return
SUCCESS
;
}
Status
StridedSliceInfo
::
CheckStrategy
(
const
StrategyPtr
&
strategy
)
{
MS_EXCEPTION_IF_NULL
(
strategy
);
if
(
CheckStrategyValue
(
strategy
,
inputs_shape_
,
is_auto_parallel_
)
!=
SUCCESS
)
{
MS_LOG
(
ERROR
)
<<
name_
<<
": Invalid strategy"
;
return
FAILED
;
}
std
::
vector
<
Dimensions
>
stra
=
strategy
->
GetInputDim
();
if
(
stra
.
empty
())
{
MS_LOG
(
ERROR
)
<<
name_
<<
": The strategy is empty"
;
return
FAILED
;
}
Dimensions
strategy_value
=
stra
[
0
];
bool
has_split
=
std
::
any_of
(
strategy_value
.
begin
(),
strategy_value
.
end
(),
[](
int32_t
v
)
{
return
v
>
1
;
});
if
(
has_split
&&
has_mask_
)
{
MS_LOG
(
ERROR
)
<<
name_
<<
": When there is a mask, the input is not supported to be split"
;
return
FAILED
;
}
if
(
strategy_value
.
size
()
<
strides_
.
size
())
{
MS_LOG
(
ERROR
)
<<
name_
<<
": The size of strategy must be larger or equal to the size of strides"
;
return
FAILED
;
}
for
(
size_t
i
=
0
;
i
<
strides_
.
size
();
++
i
)
{
if
((
strides_
[
i
]
!=
1
)
&&
(
strategy_value
[
i
]
>
1
))
{
MS_LOG
(
ERROR
)
<<
name_
<<
": When a certain dimension is split, now does not support that the stride is not 1"
;
return
FAILED
;
}
}
if
((
begin_
.
size
()
!=
end_
.
size
())
||
(
begin_
.
size
()
!=
strides_
.
size
()))
{
MS_LOG
(
ERROR
)
<<
name_
<<
": The size of begin "
<<
begin_
.
size
()
<<
", end "
<<
end_
.
size
()
<<
" and strides "
<<
strides_
.
size
()
<<
" must be equal"
;
return
FAILED
;
}
for
(
size_t
i
=
0
;
i
<
begin_
.
size
();
++
i
)
{
bool
no_fully_fetch
=
((
begin_
[
i
]
!=
0
)
||
(
end_
[
i
]
<
inputs_shape_
[
0
][
i
]));
if
(
no_fully_fetch
&&
(
strategy_value
[
i
]
!=
1
))
{
MS_LOG
(
ERROR
)
<<
name_
<<
"When a dimension is not fully fetched, the dimension can not be split now"
;
return
FAILED
;
}
}
return
SUCCESS
;
}
Status
StridedSliceInfo
::
InferDevMatrixShape
()
{
MS_EXCEPTION_IF_NULL
(
strategy_
);
std
::
vector
<
Dimensions
>
stra
=
strategy_
->
GetInputDim
();
if
(
stra
.
empty
())
{
MS_LOG
(
ERROR
)
<<
name_
<<
"The strategy is empty"
;
return
FAILED
;
}
dev_matrix_shape_
=
stra
[
0
];
return
SUCCESS
;
}
Status
StridedSliceInfo
::
InferTensorMap
()
{
TensorMap
tensor_map
;
if
(
inputs_shape_
.
empty
())
{
MS_LOG
(
ERROR
)
<<
name_
<<
"The inputs shape is empty"
;
return
FAILED
;
}
// cannot use dev_matrix_shape_ replace inputs_shape_[0], because it may not be fully split in all devices.
int32_t
size
=
SizeToInt
(
inputs_shape_
[
0
].
size
());
for
(
int
i
=
0
;
i
<
size
;
++
i
)
{
tensor_map
.
push_back
(
size
-
i
-
1
);
}
inputs_tensor_map_
.
push_back
(
tensor_map
);
outputs_tensor_map_
.
push_back
(
tensor_map
);
return
SUCCESS
;
}
Status
StridedSliceInfo
::
InferMirrorOps
()
{
mirror_ops_
.
clear
();
if
(
inputs_tensor_map_
.
empty
())
{
MS_LOG
(
ERROR
)
<<
name_
<<
": The inputs tensor map is empty"
;
return
FAILED
;
}
Shape
input_tensor_map
=
inputs_tensor_map_
[
0
];
std
::
vector
<
Group
>
group
;
if
(
CreateGroupByTensorMap
(
input_tensor_map
,
&
group
)
!=
SUCCESS
)
{
MS_LOG
(
ERROR
)
<<
name_
<<
": Create group for input failed."
;
return
FAILED
;
}
if
(
group
.
empty
())
{
MS_LOG
(
INFO
)
<<
name_
<<
": The mirror group is empty."
;
return
SUCCESS
;
}
OperatorVector
input_op
,
begin_op
,
end_op
,
strides_op
;
input_op
=
CreateMirrorOps
(
group
[
0
].
name
(),
group
[
0
].
GetDevNum
());
mirror_ops_
.
push_back
(
input_op
);
mirror_ops_
.
push_back
(
begin_op
);
mirror_ops_
.
push_back
(
end_op
);
mirror_ops_
.
push_back
(
strides_op
);
return
SUCCESS
;
}
Status
StridedSliceInfo
::
InferTensorInfo
()
{
if
(
inputs_shape_
.
empty
()
||
outputs_shape_
.
empty
()
||
inputs_tensor_map_
.
empty
()
||
outputs_tensor_map_
.
empty
())
{
MS_LOG
(
ERROR
)
<<
name_
<<
": Invalid args"
;
return
FAILED
;
}
// infer tensor layout
TensorLayout
input_layout
,
output_layout
;
if
(
input_layout
.
InitFromVector
(
dev_matrix_shape_
,
inputs_tensor_map_
[
0
],
inputs_shape_
[
0
])
!=
SUCCESS
)
{
MS_LOG
(
ERROR
)
<<
name_
<<
": Infer input tensor layout failed."
;
return
FAILED
;
}
if
(
output_layout
.
InitFromVector
(
dev_matrix_shape_
,
outputs_tensor_map_
[
0
],
outputs_shape_
[
0
])
!=
SUCCESS
)
{
MS_LOG
(
ERROR
)
<<
name_
<<
": Infer output tensor layout failed."
;
return
FAILED
;
}
TensorInfo
input_tensor_info
(
input_layout
);
TensorInfo
output_tensor_info
(
output_layout
);
inputs_tensor_info_
.
push_back
(
input_tensor_info
);
outputs_tensor_info_
.
push_back
(
output_tensor_info
);
return
SUCCESS
;
}
// Note: if the batch dimension is not fully fetched, the batch strategy may not work.
std
::
shared_ptr
<
std
::
vector
<
std
::
vector
<
int32_t
>>>
StridedSliceInfo
::
GenerateBatchStrategies
()
{
split_flag_list_
=
{
true
};
return
GenerateBatchStrategiesBySplitFlag
(
inputs_shape_
,
split_flag_list_
);
}
Status
StridedSliceInfo
::
SetCostUnderStrategy
(
const
StrategyPtr
&
strategy
)
{
if
(
SetCostUnderStrategyBase
(
strategy
)
!=
SUCCESS
)
{
MS_LOG
(
ERROR
)
<<
name_
<<
": Set cost under strategy failed."
;
return
FAILED
;
}
return
SUCCESS
;
}
Status
StridedSliceInfo
::
GenerateStrategies
(
int32_t
stage_id
)
{
if
(
InferAttrs
()
!=
SUCCESS
)
{
MS_LOG
(
ERROR
)
<<
name_
<<
": Infer attrs failed"
;
return
FAILED
;
}
if
(
inputs_shape_
.
empty
())
{
MS_LOG
(
ERROR
)
<<
name_
<<
": The inputs shape is empty"
;
return
FAILED
;
}
Shape
input_split
(
inputs_shape_
[
0
].
size
(),
1
);
if
(
has_mask_
)
{
for
(
size_t
i
=
0
;
i
<
inputs_shape_
[
0
].
size
();
++
i
)
{
input_split
[
i
]
=
0
;
}
}
else
{
for
(
size_t
i
=
0
;
i
<
begin_
.
size
();
++
i
)
{
bool
no_fully_fetch
=
((
begin_
[
i
]
!=
0
)
||
(
end_
[
i
]
<
inputs_shape_
[
0
][
i
]));
if
(
no_fully_fetch
||
(
strides_
[
i
]
!=
1
))
{
input_split
[
i
]
=
0
;
}
}
}
Shapes
splittable_inputs
=
{
input_split
};
std
::
vector
<
StrategyPtr
>
sp_vector
;
is_auto_parallel_
=
true
;
if
(
GenerateStrategiesForIndependentInputs
(
stage_id
,
inputs_shape_
,
splittable_inputs
,
&
sp_vector
)
!=
SUCCESS
)
{
return
FAILED
;
}
size_t
success
=
0
;
for
(
auto
&
sp
:
sp_vector
)
{
PrintStrategy
(
sp
);
if
(
SetCostUnderStrategy
(
sp
)
==
SUCCESS
)
{
success
++
;
MS_LOG
(
INFO
)
<<
name_
<<
": Successfully generated "
<<
success
<<
" strategy."
;
PrintStrategy
(
sp
);
}
}
return
SUCCESS
;
}
Status
StridedSliceInfo
::
Init
(
const
StrategyPtr
&
strategy
)
{
if
(
InitWithAutoRepeatCalc
(
strategy
)
!=
SUCCESS
)
{
MS_LOG
(
ERROR
)
<<
name_
<<
": Init failed."
;
return
FAILED
;
}
MS_LOG
(
INFO
)
<<
name_
<<
": Init success."
;
return
SUCCESS
;
}
Status
StridedSliceInfo
::
InitForCostModel
(
const
StrategyPtr
&
strategy
)
{
if
(
InitForCostModelWithAutoRepeatCalc
(
strategy
)
!=
SUCCESS
)
{
MS_LOG
(
ERROR
)
<<
name_
<<
": Init for cost model failed."
;
return
FAILED
;
}
MS_LOG
(
INFO
)
<<
name_
<<
": Init for cost model success."
;
return
SUCCESS
;
}
}
// namespace parallel
}
// namespace mindspore
mindspore/ccsrc/frontend/parallel/ops_info/strided_slice_info.h
0 → 100644
浏览文件 @
9aa84b3d
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_FRONTEND_PARALLEL_OPS_INFO_STRIDED_SLICE_INFO_H_
#define MINDSPORE_CCSRC_FRONTEND_PARALLEL_OPS_INFO_STRIDED_SLICE_INFO_H_
#include <string>
#include <memory>
#include <unordered_map>
#include <vector>
#include "ir/value.h"
#include "frontend/parallel/auto_parallel/operator_costmodel.h"
#include "frontend/parallel/ops_info/operator_info.h"
#include "frontend/parallel/strategy.h"
namespace
mindspore
{
namespace
parallel
{
class
StridedSliceInfo
:
public
OperatorInfo
{
public:
StridedSliceInfo
(
const
std
::
string
&
operator_name
,
const
Shapes
&
inputs_shape
,
const
Shapes
&
outputs_shape
,
const
PrimitiveAttrs
&
attrs
)
:
OperatorInfo
(
operator_name
,
inputs_shape
,
outputs_shape
,
attrs
,
std
::
make_shared
<
StridedSliceCost
>
(
false
))
{}
~
StridedSliceInfo
()
override
=
default
;
Status
Init
(
const
StrategyPtr
&
strategy
)
override
;
Status
InitForCostModel
(
const
StrategyPtr
&
strategy
)
override
;
Status
GenerateStrategies
(
int32_t
)
override
;
Status
SetCostUnderStrategy
(
const
StrategyPtr
&
)
override
;
std
::
shared_ptr
<
std
::
vector
<
std
::
vector
<
int32_t
>>>
GenerateBatchStrategies
()
override
;
protected:
Status
GetAttrs
()
override
;
Status
CheckStrategy
(
const
StrategyPtr
&
strategy
)
override
;
Status
InferMirrorOps
()
override
;
Status
InferForwardCommunication
()
override
{
return
SUCCESS
;
}
Status
InferTensorInfo
()
override
;
Status
InferDevMatrixShape
()
override
;
Status
InferTensorMap
()
override
;
Status
GetMask
(
const
std
::
string
&
mask_name
,
int32_t
*
mask_value
);
private:
std
::
vector
<
int32_t
>
begin_
;
std
::
vector
<
int32_t
>
end_
;
std
::
vector
<
int32_t
>
strides_
;
int32_t
begin_mask_
=
0
;
int32_t
end_mask_
=
0
;
int32_t
ellipsis_mask_
=
0
;
int32_t
new_axis_mask_
=
0
;
int32_t
shrink_axis_mask_
=
0
;
bool
has_mask_
=
false
;
};
using
StridedSliceInfoPtr
=
std
::
shared_ptr
<
StridedSliceInfo
>
;
}
// namespace parallel
}
// namespace mindspore
#endif // MINDSPORE_CCSRC_FRONTEND_PARALLEL_OPS_INFO_STRIDED_SLICE_INFO_H_
tests/ut/python/parallel/test_stridedslice.py
0 → 100644
浏览文件 @
9aa84b3d
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import
numpy
as
np
import
pytest
import
mindspore
as
ms
from
mindspore
import
context
,
Tensor
,
Parameter
from
mindspore.common.api
import
_executor
from
mindspore.nn
import
Cell
,
TrainOneStepCell
,
Momentum
from
mindspore.ops
import
operations
as
P
class
Net
(
Cell
):
def
__init__
(
self
,
weight
,
w2
,
begin
,
end
,
strides
,
strategy1
=
None
,
strategy2
=
None
,
is_parameter
=
True
,
mask
=
0
):
super
().
__init__
()
self
.
mul
=
P
.
Mul
().
set_strategy
(
strategy1
)
self
.
strided_slice
=
P
.
StridedSlice
(
begin_mask
=
mask
).
set_strategy
(
strategy2
)
if
is_parameter
:
self
.
weight
=
Parameter
(
weight
,
"w1"
)
else
:
self
.
weight
=
weight
self
.
mul2
=
P
.
Mul
()
self
.
weight2
=
Parameter
(
w2
,
"w2"
)
self
.
begin
=
begin
self
.
end
=
end
self
.
strides
=
strides
def
construct
(
self
,
x
,
b
):
out
=
self
.
strided_slice
(
self
.
weight
,
self
.
begin
,
self
.
end
,
self
.
strides
)
out
=
self
.
mul
(
x
,
out
)
out
=
self
.
mul2
(
out
,
self
.
weight2
)
return
out
class
Net2
(
Cell
):
def
__init__
(
self
,
weight2
,
begin
,
end
,
strides
,
strategy1
=
None
,
strategy2
=
None
):
super
().
__init__
()
self
.
mul
=
P
.
Mul
().
set_strategy
(
strategy1
)
self
.
strided_slice
=
P
.
StridedSlice
().
set_strategy
(
strategy2
)
self
.
weight2
=
Parameter
(
weight2
,
"w2"
)
self
.
begin
=
begin
self
.
end
=
end
self
.
strides
=
strides
def
construct
(
self
,
x
,
b
):
out
=
self
.
mul
(
x
,
self
.
weight2
)
out
=
self
.
strided_slice
(
out
,
self
.
begin
,
self
.
end
,
self
.
strides
)
return
out
_x
=
Tensor
(
np
.
ones
([
128
,
64
,
1
]),
dtype
=
ms
.
float32
)
_w1
=
Tensor
(
np
.
ones
([
256
,
64
,
32
]),
dtype
=
ms
.
float32
)
_w2
=
Tensor
(
np
.
ones
([
128
,
64
,
1
]),
dtype
=
ms
.
float32
)
_b
=
Tensor
(
np
.
ones
([
128
,
64
,
32
]),
dtype
=
ms
.
float32
)
def
compile_net
(
net
):
context
.
set_context
(
save_graphs
=
True
)
optimizer
=
Momentum
(
net
.
trainable_params
(),
learning_rate
=
0.1
,
momentum
=
0.9
)
train_net
=
TrainOneStepCell
(
net
,
optimizer
)
train_net
.
set_auto_parallel
()
_executor
.
compile
(
train_net
,
_x
,
_b
)
context
.
reset_auto_parallel_context
()
def
test_stridedslice_no_fully_fetch_split_error
():
context
.
set_auto_parallel_context
(
parallel_mode
=
"semi_auto_parallel"
,
device_num
=
8
,
global_rank
=
0
)
strategy1
=
((
2
,
2
,
2
),
(
2
,
2
,
2
))
strategy2
=
((
2
,
2
,
2
),)
net
=
Net
(
_w1
,
_w2
,
(
0
,
0
,
0
),
(
128
,
64
,
32
),
(
1
,
1
,
1
),
strategy1
,
strategy2
,
is_parameter
=
True
)
with
pytest
.
raises
(
RuntimeError
):
compile_net
(
net
)
def
test_stridedslice_strides_no_1_split_error
():
context
.
set_auto_parallel_context
(
parallel_mode
=
"semi_auto_parallel"
,
device_num
=
8
,
global_rank
=
0
)
strategy1
=
((
2
,
2
,
2
),
(
2
,
2
,
2
))
strategy2
=
((
1
,
2
,
2
),)
net
=
Net
(
_w1
,
_w2
,
(
0
,
0
,
0
),
(
128
,
64
,
32
),
(
1
,
1
,
2
),
strategy1
,
strategy2
,
is_parameter
=
True
)
with
pytest
.
raises
(
RuntimeError
):
compile_net
(
net
)
def
test_stridedslice_mask_no_0_split_error
():
context
.
set_auto_parallel_context
(
parallel_mode
=
"semi_auto_parallel"
,
device_num
=
8
,
global_rank
=
0
)
strategy1
=
((
2
,
2
,
2
),
(
2
,
2
,
2
))
strategy2
=
((
1
,
2
,
2
),)
net
=
Net
(
_w1
,
_w2
,
(
0
,
0
,
0
),
(
128
,
64
,
32
),
(
1
,
1
,
1
),
strategy1
,
strategy2
,
is_parameter
=
True
,
mask
=
1
)
with
pytest
.
raises
(
RuntimeError
):
compile_net
(
net
)
def
test_stridedslice_begin_size_smaller
():
context
.
set_auto_parallel_context
(
parallel_mode
=
"semi_auto_parallel"
,
device_num
=
8
,
global_rank
=
0
)
strategy1
=
((
1
,
4
,
1
),
(
1
,
4
,
2
))
strategy2
=
((
1
,
4
,
2
),)
net
=
Net
(
_w1
,
_w2
,
(
0
,
0
),
(
128
,
64
),
(
1
,
1
),
strategy1
,
strategy2
,
is_parameter
=
True
)
compile_net
(
net
)
def
test_stridedslice_parameter
():
context
.
set_auto_parallel_context
(
parallel_mode
=
"semi_auto_parallel"
,
device_num
=
8
,
global_rank
=
0
)
strategy1
=
((
1
,
4
,
1
),
(
1
,
4
,
2
))
strategy2
=
((
1
,
4
,
2
),)
net
=
Net
(
_w1
,
_w2
,
(
0
,
0
,
0
),
(
128
,
64
,
32
),
(
1
,
1
,
1
),
strategy1
,
strategy2
,
is_parameter
=
True
)
compile_net
(
net
)
def
test_stridedslice_tensor
():
context
.
set_auto_parallel_context
(
parallel_mode
=
"semi_auto_parallel"
,
device_num
=
8
,
global_rank
=
0
)
strategy1
=
((
1
,
4
,
1
),
(
1
,
4
,
2
))
strategy2
=
((
1
,
4
,
2
),)
net
=
Net
(
_w1
,
_w2
,
(
0
,
0
,
0
),
(
128
,
64
,
32
),
(
1
,
1
,
1
),
strategy1
,
strategy2
,
is_parameter
=
False
)
compile_net
(
net
)
def
test_stridedslice_parameter_no_full_split
():
context
.
set_auto_parallel_context
(
parallel_mode
=
"semi_auto_parallel"
,
device_num
=
8
,
global_rank
=
0
)
strategy1
=
((
1
,
4
,
1
),
(
1
,
4
,
2
))
strategy2
=
((
1
,
2
,
2
),)
net
=
Net
(
_w1
,
_w2
,
(
0
,
0
,
0
),
(
128
,
64
,
32
),
(
1
,
1
,
1
),
strategy1
,
strategy2
,
is_parameter
=
True
)
compile_net
(
net
)
def
test_stridedslice_output
():
context
.
set_auto_parallel_context
(
parallel_mode
=
"semi_auto_parallel"
,
device_num
=
8
,
global_rank
=
0
)
strategy1
=
((
1
,
8
,
1
),
(
1
,
8
,
1
))
strategy2
=
((
1
,
8
,
1
),)
net
=
Net2
(
_w2
,
(
0
,
0
,
0
),
(
64
,
64
,
1
),
(
1
,
1
,
1
),
strategy1
,
strategy2
)
compile_net
(
net
)
def
test_stridedslice_output_no_full_split
():
context
.
set_auto_parallel_context
(
parallel_mode
=
"semi_auto_parallel"
,
device_num
=
8
,
global_rank
=
0
)
strategy1
=
((
1
,
8
,
1
),
(
1
,
8
,
1
))
strategy2
=
((
1
,
4
,
1
),)
net
=
Net2
(
_w2
,
(
0
,
0
,
0
),
(
64
,
64
,
1
),
(
1
,
1
,
1
),
strategy1
,
strategy2
)
compile_net
(
net
)
def
test_stridedslice_no_strategy
():
context
.
set_auto_parallel_context
(
parallel_mode
=
"semi_auto_parallel"
,
device_num
=
8
,
global_rank
=
0
)
strategy1
=
((
1
,
8
,
1
),
(
1
,
8
,
1
))
strategy2
=
None
net
=
Net2
(
_w2
,
(
0
,
0
,
0
),
(
128
,
64
,
1
),
(
1
,
1
,
1
),
strategy1
,
strategy2
)
compile_net
(
net
)
def
test_stridedslice_auto_parallel
():
context
.
set_auto_parallel_context
(
parallel_mode
=
"auto_parallel"
,
device_num
=
8
,
global_rank
=
0
)
net
=
Net2
(
_w2
,
(
0
,
0
,
0
),
(
32
,
64
,
1
),
(
1
,
1
,
1
))
compile_net
(
net
)
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录