magicwindyyd / mindspore (forked from MindSpore / mindspore)
Commit 3aeb91ee
Authored May 27, 2020 by mindspore-ci-bot; committed via Gitee on May 27, 2020.

!1443 Add parallel operator for Sigmoid

Merge pull request !1443 from yangzhenzhang/add-sigmoid-op

Parents: 650a45b2, 7c237620
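In short, the merge wires Sigmoid into MindSpore's auto-parallel machinery: it registers SigmoidInfo with the operator dynamic creator, defines SigmoidInfo as an elementwise activation, adds SIGMOID to the splittable and elementwise operator tables consulted during strategy search, and adds a unit test. An activation like Sigmoid can join these tables because it is applied element by element, so any partition of its input tensor is valid. A minimal NumPy sketch of that property (shapes illustrative, not taken from the commit):

import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

x = np.random.randn(64, 32)
# Shard the input 4 ways along axis 0, apply sigmoid shard-by-shard,
# then reassemble: the result matches the unsharded computation, which
# is what makes Sigmoid safe to split along any dimension.
shards = np.split(x, 4, axis=0)
out = np.concatenate([sigmoid(s) for s in shards], axis=0)
assert np.allclose(out, sigmoid(x))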
4 changed files with 82 additions and 73 deletions (+82 -73)
mindspore/ccsrc/parallel/dynamic_creator.h                 +1   -0
mindspore/ccsrc/parallel/ops_info/activation_info.h        +8   -0
mindspore/ccsrc/parallel/step_auto_parallel.cc             +18  -73
tests/ut/python/parallel/test_auto_parallel_activation.py  +55  -0
mindspore/ccsrc/parallel/dynamic_creator.h
@@ -122,6 +122,7 @@ REGISTER(AssignSubInfo);
 REGISTER(ReLUInfo);
 REGISTER(GatherV2Info);
 REGISTER(SqrtInfo);
+REGISTER(SigmoidInfo);
 REGISTER(GetNextInfo);
 REGISTER(NegInfo);
 REGISTER(BatchMatMulInfo);
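The single REGISTER(SigmoidInfo) line is what lets the parallel module build a SigmoidInfo object from the operator's name at graph-compilation time. As a rough, hypothetical illustration of that registry pattern (the real REGISTER macro is C++ and its definition is not part of this diff), in Python:

# Hypothetical name-to-factory registry; illustrative only, not the
# actual MindSpore internals.
_CREATORS = {}

def register(op_name, factory):
    _CREATORS[op_name] = factory

def create_operator_info(op_name, *args):
    # Build the distributed-operator info object registered for this name.
    return _CREATORS[op_name](*args)

register("Sigmoid", lambda name, ins, outs, attrs: ("SigmoidInfo", name, ins, outs, attrs))
info = create_operator_info("Sigmoid", "Sigmoid", [[64, 32]], [[64, 32]], {})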
mindspore/ccsrc/parallel/ops_info/activation_info.h
@@ -211,6 +211,14 @@ class SquareInfo : public ActivationOther {
       : ActivationOther(name, inputs_shape, outputs_shape, attrs) {}
   ~SquareInfo() override = default;
 };
+
+class SigmoidInfo : public ActivationOther {
+ public:
+  SigmoidInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
+              const PrimitiveAttrs &attrs)
+      : ActivationOther(name, inputs_shape, outputs_shape, attrs) {}
+  ~SigmoidInfo() override = default;
+};
 }  // namespace parallel
 }  // namespace mindspore
 #endif  // MINDSPORE_CCSRC_PARALLEL_OPS_INFO_ACTIVATION_INFO_H_
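SigmoidInfo mirrors SquareInfo exactly: it derives from ActivationOther and adds no members of its own, so all sharding-strategy handling is inherited from the base class. A hypothetical Python rendering of the hierarchy makes the point (names follow the header above; the body of ActivationOther is not shown in this diff):

class ActivationOther:
    # Stand-in for the C++ base class that implements generic
    # elementwise-operator strategy handling.
    def __init__(self, name, inputs_shape, outputs_shape, attrs):
        self.name = name
        self.inputs_shape = inputs_shape
        self.outputs_shape = outputs_shape
        self.attrs = attrs

class SquareInfo(ActivationOther):
    pass  # inherits everything

class SigmoidInfo(ActivationOther):
    pass  # likewise: no operator-specific logic needed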
mindspore/ccsrc/parallel/step_auto_parallel.cc
@@ -48,74 +48,6 @@
 namespace mindspore {
 namespace parallel {
-// splittable_op_ will continuously be updated
-std::vector<std::string> splittable_op_ = {
-    MATMUL, GELU, TANH, SOFTMAX, LOG_SOFTMAX, ACTIVATION, PRELU, FLOORDIV, L2_NORMALIZE, TRANSPOSE, RESHAPE,
-    TENSOR_ADD, SUB, MUL, DIV, GREATER, MAXPOOL, MAXPOOLV2, VIRTUAL_DATA_SET,
-    SPARSE_SOFTMAX_CROSS_ENTROPY_WITH_LOGITS, RELU, ONEHOT, DROPOUT_DO_MASK, REDUCE_MAX, REDUCE_MIN,
-    ARGMAXWITHVALUE, ARGMINWITHVALUE, REDUCE_SUM, CONV2D, FUSE_BATCH_NORM, POOLING,
-    SOFTMAX_CROSS_ENTROPY_WITH_LOGITS, SIGMOID_CROSS_ENTROPY_WITH_LOGITS, MAX_POOL_WITH_ARGMAX, SIMPLE_MEAN,
-    FLATTEN, BATCH_NORM, LAYER_NORM, BIAS_ADD, ASSIGN_SUB, COS, ACOS, EXP, LOG, REDUCE_MEAN, REAL_DIV,
-    SIGMOID, POW, MAXIMUM, MINIMUM, EQUAL, NOT_EQUAL, LOGICALNOT, GATHERV2, STRIDEDSLICE, SQRT, GET_NEXT,
-    CAST, NEG, SQUARE, BATCH_MATMUL, EXPAND_DIMS, SQUEEZE};
-
-std::vector<std::string> elementwise_op_ = {ACTIVATION, GELU, TANH, SOFTMAX, LOG_SOFTMAX, RELU, SQRT, CAST,
-                                            POW, EXP, LOG, COS, ACOS, LOGICALNOT, NEG, SQUARE};
-
 bool StepAutoParallel(const FuncGraphPtr &root, const opt::OptimizerPtr &) {
   MS_EXCEPTION_IF_NULL(root);
   MS_EXCEPTION_IF_NULL(ParallelContext::GetInstance());

@@ -314,14 +246,27 @@ std::vector<TypePtr> ExtractOutputTypeByNode(const CNodePtr &node) {
 }

 bool IsElementWiseOperator(const std::string &op_name) {
-  auto iter = std::find(elementwise_op_.begin(), elementwise_op_.end(), op_name);
-  return (iter != elementwise_op_.end());
+  static const std::set<std::string> elementwise_op = {ACTIVATION, GELU, TANH, SOFTMAX, LOG_SOFTMAX, RELU,
+                                                       SQRT, CAST, POW, EXP, LOG, COS, ACOS, LOGICALNOT,
+                                                       NEG, SQUARE, SIGMOID};
+  auto iter = elementwise_op.find(op_name);
+  return (iter != elementwise_op.end());
 }

 bool IsSplittableOperator(const std::string &op_name) {
-  std::vector<std::string>::iterator iter;
-  iter = std::find(splittable_op_.begin(), splittable_op_.end(), op_name);
-  return (iter != splittable_op_.end());
+  // clang-format off
+  static const std::set<std::string> splittable_op =
+    {MATMUL, TRANSPOSE, GELU, TANH, SOFTMAX, SUB, MUL, DIV, RESHAPE, GREATER, LOG_SOFTMAX, ACTIVATION, PRELU,
+     FLOORDIV, L2_NORMALIZE, TENSOR_ADD, MAXPOOL, MAXPOOLV2, VIRTUAL_DATA_SET, RELU, ONEHOT, DROPOUT_DO_MASK,
+     REDUCE_MAX, REDUCE_MIN, ARGMAXWITHVALUE, ARGMINWITHVALUE, REDUCE_SUM, CONV2D, FUSE_BATCH_NORM, POOLING,
+     MAX_POOL_WITH_ARGMAX, SIMPLE_MEAN, FLATTEN, BATCH_NORM, LAYER_NORM, BIAS_ADD, ASSIGN_SUB, COS, ACOS, EXP,
+     LOG, REDUCE_MEAN, REAL_DIV, SIGMOID, POW, MAXIMUM, MINIMUM, EQUAL, NOT_EQUAL, LOGICALNOT, GATHERV2, SQRT,
+     STRIDEDSLICE, GET_NEXT, CAST, NEG, SQUARE, BATCH_MATMUL, EXPAND_DIMS, SQUEEZE,
+     SOFTMAX_CROSS_ENTROPY_WITH_LOGITS, SIGMOID_CROSS_ENTROPY_WITH_LOGITS,
+     SPARSE_SOFTMAX_CROSS_ENTROPY_WITH_LOGITS};
+  // clang-format on
+  auto iter = splittable_op.find(op_name);
+  return (iter != splittable_op.end());
 }

 bool IsAutoParallelCareNode(const CNodePtr &cnode) {
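Two things change here. First, SIGMOID joins the elementwise table (and stays in the splittable table), so the auto-parallel cost model treats it like the other activations. Second, the lookup tables move from mutable global std::vector objects scanned with std::find into function-local static const std::set containers, which tightens their scope and replaces a linear scan with an ordered-set lookup. A Python analogue of the refactored check (the op-name strings are assumptions; the C++ code uses macro constants such as SIGMOID whose string values are not shown in this diff):

# Module-level frozenset stands in for the function-local static const
# std::set; membership is a hash lookup instead of a linear scan.
_ELEMENTWISE_OPS = frozenset({
    "Activation", "Gelu", "Tanh", "Softmax", "LogSoftmax", "ReLU", "Sqrt",
    "Cast", "Pow", "Exp", "Log", "Cos", "ACos", "LogicalNot", "Neg",
    "Square", "Sigmoid",  # Sigmoid is the entry added by this commit
})

def is_element_wise_operator(op_name):
    return op_name in _ELEMENTWISE_OPS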
tests/ut/python/parallel/test_auto_parallel_activation.py
new file mode 100644
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np

import mindspore as ms
from mindspore import context, Tensor, Parameter
from mindspore.common.api import _executor
from mindspore.nn import Cell, TrainOneStepCell, Momentum
from mindspore.ops import operations as P


class Net(Cell):
    def __init__(self, mul_weight, strategy1=None, strategy2=None):
        super().__init__()
        self.mul = P.Mul().set_strategy(strategy1)
        self.sigmoid = P.Sigmoid().set_strategy(strategy2)
        self.mul_weight = Parameter(mul_weight, "w1")

    def construct(self, x, b):
        out = self.mul(x, self.mul_weight)
        out = self.sigmoid(out)
        return out


_x = Tensor(np.ones([64, 32]), dtype=ms.float32)
_w1 = Tensor(np.ones([64, 32]), dtype=ms.float32)
_b = Tensor(np.ones([64, 32]), dtype=ms.float32)


def compile_net(net):
    optimizer = Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9)
    train_net = TrainOneStepCell(net, optimizer)
    train_net.set_auto_parallel()
    _executor.compile(train_net, _x, _b)
    context.reset_auto_parallel_context()


def test_auto_parallel_activation():
    context.set_auto_parallel_context(parallel_mode="auto_parallel", device_num=16, global_rank=0)
    strategy1 = ((4, 4), (4, 4))
    strategy2 = None
    net = Net(_w1, strategy1, strategy2)
    compile_net(net)
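Note that the test leaves strategy2 = None, so the auto-parallel search is free to choose Sigmoid's sharding strategy; only the preceding Mul is pinned to ((4, 4), (4, 4)) across 16 devices. A hypothetical variant (not part of this commit) that pins Sigmoid's strategy explicitly would follow the same pattern, since Sigmoid takes a single input:

def test_sigmoid_explicit_strategy():
    # Hypothetical companion test, not in the commit: pin Sigmoid's
    # strategy instead of leaving it to the auto-parallel search.
    context.set_auto_parallel_context(parallel_mode="auto_parallel", device_num=16, global_rank=0)
    strategy1 = ((4, 4), (4, 4))
    strategy2 = ((4, 4),)  # one input tensor, sharded 4x4 across 16 devices
    net = Net(_w1, strategy1, strategy2)
    compile_net(net)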