Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
正统之独孤求败
mindspore
提交
c44e1271
M
mindspore
项目概览
正统之独孤求败
/
mindspore
与 Fork 源项目一致
Fork自
MindSpore / mindspore
通知
1
Star
0
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
mindspore
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
体验新版 GitCode,发现更多精彩内容 >>
提交
c44e1271
编写于
3月 31, 2020
作者:
Y
yangzhenzhang
提交者:
高东海
4月 08, 2020
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
add parallel ops for neg and batchmatmul
上级
fe1b7358
变更
7
隐藏空白更改
内联
并排
Showing
7 changed file
with
198 addition
and
0 deletion
+198
-0
mindspore/ccsrc/parallel/dynamic_creator.h
mindspore/ccsrc/parallel/dynamic_creator.h
+2
-0
mindspore/ccsrc/parallel/ops_info/activation_info.h
mindspore/ccsrc/parallel/ops_info/activation_info.h
+7
-0
mindspore/ccsrc/parallel/ops_info/matmul_info.h
mindspore/ccsrc/parallel/ops_info/matmul_info.h
+8
-0
mindspore/ccsrc/parallel/ops_info/ops_utils.h
mindspore/ccsrc/parallel/ops_info/ops_utils.h
+2
-0
mindspore/ccsrc/parallel/step_auto_parallel.cc
mindspore/ccsrc/parallel/step_auto_parallel.cc
+2
-0
tests/ut/python/parallel/test_batch_matmul.py
tests/ut/python/parallel/test_batch_matmul.py
+93
-0
tests/ut/python/parallel/test_neg.py
tests/ut/python/parallel/test_neg.py
+84
-0
未找到文件。
mindspore/ccsrc/parallel/dynamic_creator.h
浏览文件 @
c44e1271
...
...
@@ -123,6 +123,8 @@ REGISTER(ReLUInfo);
REGISTER
(
GatherV2Info
);
REGISTER
(
SqrtInfo
);
REGISTER
(
GetNextInfo
);
REGISTER
(
NegInfo
);
REGISTER
(
BatchMatMulInfo
);
}
// namespace parallel
}
// namespace mindspore
...
...
mindspore/ccsrc/parallel/ops_info/activation_info.h
浏览文件 @
c44e1271
...
...
@@ -167,6 +167,13 @@ class SqrtInfo : public ActivationOther {
:
ActivationOther
(
name
,
inputs_shape
,
outputs_shape
,
attrs
)
{}
~
SqrtInfo
()
override
=
default
;
};
class
NegInfo
:
public
ActivationOther
{
public:
NegInfo
(
const
std
::
string
&
name
,
const
Shapes
&
inputs_shape
,
const
Shapes
&
outputs_shape
,
const
PrimitiveAttrs
&
attrs
)
:
ActivationOther
(
name
,
inputs_shape
,
outputs_shape
,
attrs
)
{}
~
NegInfo
()
override
=
default
;
};
}
// namespace parallel
}
// namespace mindspore
#endif // MINDSPORE_CCSRC_OPTIMIZER_OPS_INFO_PARALLEL_ACTIVATION_INFO_H_
mindspore/ccsrc/parallel/ops_info/matmul_info.h
浏览文件 @
c44e1271
...
...
@@ -87,6 +87,14 @@ class MatMulInfo : public MatMul {
:
MatMul
(
name
,
inputs_shape
,
outputs_shape
,
attrs
)
{}
~
MatMulInfo
()
override
=
default
;
};
class
BatchMatMulInfo
:
public
MatMul
{
public:
BatchMatMulInfo
(
const
std
::
string
&
name
,
const
Shapes
&
inputs_shape
,
const
Shapes
&
outputs_shape
,
const
PrimitiveAttrs
&
attrs
)
:
MatMul
(
name
,
inputs_shape
,
outputs_shape
,
attrs
)
{}
~
BatchMatMulInfo
()
override
=
default
;
};
}
// namespace parallel
}
// namespace mindspore
#endif // MINDSPORE_CCSRC_PARALLEL_OPS_INFO_MATMUL_INFO_H_
mindspore/ccsrc/parallel/ops_info/ops_utils.h
浏览文件 @
c44e1271
...
...
@@ -188,6 +188,8 @@ constexpr char SQRT[] = "Sqrt";
constexpr
char
ASSIGN
[]
=
"Assign"
;
constexpr
char
GET_NEXT
[]
=
"GetNext"
;
constexpr
char
SQUEEZE
[]
=
"Squeeze"
;
constexpr
char
Neg
[]
=
"Neg"
;
constexpr
char
BATCH_MATMUL
[]
=
"BatchMatMul"
;
// Parallel don't care
constexpr
char
TUPLE_GETITEM
[]
=
"tuple_getitem"
;
...
...
mindspore/ccsrc/parallel/step_auto_parallel.cc
浏览文件 @
c44e1271
...
...
@@ -101,6 +101,8 @@ std::vector<std::string> splittable_op_ = {MATMUL,
SQRT
,
GET_NEXT
,
CAST
,
Neg
,
BATCH_MATMUL
,
SQUEEZE
};
std
::
vector
<
std
::
string
>
elementwise_op_
=
{
ACTIVATION
,
GELU
,
TANH
,
SOFTMAX
,
LOG_SOFTMAX
,
RELU
,
SQRT
,
...
...
tests/ut/python/parallel/test_batch_matmul.py
0 → 100644
浏览文件 @
c44e1271
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
numpy
as
np
import
mindspore
as
ms
from
mindspore
import
context
,
Tensor
,
Parameter
from
mindspore.nn
import
Cell
,
TrainOneStepCell
,
Momentum
from
mindspore.ops
import
operations
as
P
from
mindspore.common.api
import
_executor
class
Net
(
Cell
):
def
__init__
(
self
,
mul_weight
,
batch_matmul_weight
,
transpose_b
=
False
,
strategy1
=
None
,
strategy2
=
None
):
super
().
__init__
()
self
.
mul
=
P
.
Mul
().
set_strategy
(
strategy1
)
self
.
batch_matmul
=
P
.
BatchMatMul
(
transpose_b
=
transpose_b
).
set_strategy
(
strategy2
)
self
.
mul_weight
=
Parameter
(
mul_weight
,
"w1"
)
self
.
batch_matmul_weight
=
Parameter
(
batch_matmul_weight
,
"w2"
)
def
construct
(
self
,
x
,
b
):
out
=
self
.
mul
(
x
,
self
.
mul_weight
)
out
=
self
.
batch_matmul
(
out
,
self
.
batch_matmul_weight
)
return
out
_x
=
Tensor
(
np
.
ones
([
128
,
64
,
32
]),
dtype
=
ms
.
float32
)
_w1
=
Tensor
(
np
.
ones
([
128
,
64
,
32
]),
dtype
=
ms
.
float32
)
_w2
=
Tensor
(
np
.
ones
([
128
,
32
,
32
]),
dtype
=
ms
.
float32
)
_b
=
Tensor
(
np
.
ones
([
128
,
64
,
16
]),
dtype
=
ms
.
float32
)
def
compile
(
net
):
optimizer
=
Momentum
(
net
.
trainable_params
(),
learning_rate
=
0.1
,
momentum
=
0.9
)
train_net
=
TrainOneStepCell
(
net
,
optimizer
)
_executor
.
compile
(
train_net
,
_x
,
_b
)
context
.
reset_auto_parallel_context
()
def
test_batch_matmul_data_parallel
():
context
.
set_auto_parallel_context
(
parallel_mode
=
"semi_auto_parallel"
,
device_num
=
16
,
global_rank
=
0
)
strategy1
=
((
16
,
1
,
1
),
(
16
,
1
,
1
))
strategy2
=
((
16
,
1
,
1
),
(
16
,
1
,
1
))
net
=
Net
(
_w1
,
_w2
,
False
,
strategy1
,
strategy2
)
compile
(
net
)
def
test_batch_matmul_model_parallel
():
context
.
set_auto_parallel_context
(
parallel_mode
=
"semi_auto_parallel"
,
device_num
=
16
,
global_rank
=
0
)
strategy1
=
((
1
,
1
,
1
),
(
1
,
1
,
1
))
strategy2
=
((
1
,
1
,
1
),
(
1
,
1
,
16
))
net
=
Net
(
_w1
,
_w2
,
False
,
strategy1
,
strategy2
)
compile
(
net
)
def
test_batch_matmul_hybrid_parallel
():
context
.
set_auto_parallel_context
(
parallel_mode
=
"semi_auto_parallel"
,
device_num
=
16
,
global_rank
=
0
)
strategy1
=
((
2
,
2
,
2
),
(
2
,
2
,
2
))
strategy2
=
((
2
,
2
,
2
),
(
2
,
2
,
2
))
net
=
Net
(
_w1
,
_w2
,
False
,
strategy1
,
strategy2
)
compile
(
net
)
def
test_batch_matmul_auto_parallel
():
context
.
set_auto_parallel_context
(
parallel_mode
=
"auto_parallel"
,
device_num
=
16
,
global_rank
=
0
)
net
=
Net
(
_w1
,
_w2
,
False
)
compile
(
net
)
def
test_batch_matmul_repeat_calc
():
context
.
set_auto_parallel_context
(
parallel_mode
=
"semi_auto_parallel"
,
device_num
=
16
,
global_rank
=
0
)
strategy1
=
((
2
,
2
,
4
),
(
2
,
2
,
4
))
strategy2
=
((
1
,
2
,
2
),
(
1
,
2
,
2
))
net
=
Net
(
_w1
,
_w2
,
False
,
strategy1
,
strategy2
)
compile
(
net
)
def
test_batch_matmul_transpose_b
():
context
.
set_auto_parallel_context
(
parallel_mode
=
"semi_auto_parallel"
,
device_num
=
16
,
global_rank
=
0
)
strategy1
=
((
2
,
2
,
4
),
(
2
,
2
,
4
))
strategy2
=
((
1
,
2
,
2
),
(
1
,
2
,
2
))
net
=
Net
(
_w1
,
_w2
,
True
,
strategy1
,
strategy2
)
compile
(
net
)
tests/ut/python/parallel/test_neg.py
0 → 100644
浏览文件 @
c44e1271
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
numpy
as
np
import
mindspore
as
ms
from
mindspore
import
context
,
Tensor
,
Parameter
from
mindspore.nn
import
Cell
,
TrainOneStepCell
,
Momentum
from
mindspore.ops
import
operations
as
P
from
mindspore.common.api
import
_executor
class
Net
(
Cell
):
def
__init__
(
self
,
mul_weight
,
strategy1
=
None
,
strategy2
=
None
):
super
().
__init__
()
self
.
mul
=
P
.
Mul
().
set_strategy
(
strategy1
)
self
.
neg
=
P
.
Neg
().
set_strategy
(
strategy2
)
self
.
mul_weight
=
Parameter
(
mul_weight
,
"w1"
)
def
construct
(
self
,
x
,
b
):
out
=
self
.
mul
(
x
,
self
.
mul_weight
)
out
=
self
.
neg
(
out
)
return
out
_x
=
Tensor
(
np
.
ones
([
128
,
64
,
32
]),
dtype
=
ms
.
float32
)
_w1
=
Tensor
(
np
.
ones
([
128
,
64
,
32
]),
dtype
=
ms
.
float32
)
_b
=
Tensor
(
np
.
ones
([
128
,
64
,
32
]),
dtype
=
ms
.
float32
)
def
compile
(
net
):
optimizer
=
Momentum
(
net
.
trainable_params
(),
learning_rate
=
0.1
,
momentum
=
0.9
)
train_net
=
TrainOneStepCell
(
net
,
optimizer
)
_executor
.
compile
(
train_net
,
_x
,
_b
)
context
.
reset_auto_parallel_context
()
def
test_neg_data_parallel
():
context
.
set_auto_parallel_context
(
parallel_mode
=
"semi_auto_parallel"
,
device_num
=
16
,
global_rank
=
0
)
strategy1
=
((
16
,
1
,
1
),
(
16
,
1
,
1
))
strategy2
=
((
16
,
1
,
1
),
)
net
=
Net
(
_w1
,
strategy1
,
strategy2
)
compile
(
net
)
def
test_neg_model_parallel
():
context
.
set_auto_parallel_context
(
parallel_mode
=
"semi_auto_parallel"
,
device_num
=
16
,
global_rank
=
0
)
strategy1
=
((
1
,
1
,
16
),
(
1
,
1
,
16
))
strategy2
=
((
1
,
1
,
16
),
)
net
=
Net
(
_w1
,
strategy1
,
strategy2
)
compile
(
net
)
def
test_neg_hybrid_parallel
():
context
.
set_auto_parallel_context
(
parallel_mode
=
"semi_auto_parallel"
,
device_num
=
16
,
global_rank
=
0
)
strategy1
=
((
2
,
2
,
4
),
(
2
,
2
,
4
))
strategy2
=
((
2
,
2
,
4
),
)
net
=
Net
(
_w1
,
strategy1
,
strategy2
)
compile
(
net
)
def
test_neg_auto_parallel
():
context
.
set_auto_parallel_context
(
parallel_mode
=
"auto_parallel"
,
device_num
=
16
,
global_rank
=
0
)
net
=
Net
(
_w1
)
compile
(
net
)
def
test_neg_repeat_calc
():
context
.
set_auto_parallel_context
(
parallel_mode
=
"semi_auto_parallel"
,
device_num
=
16
,
global_rank
=
0
)
strategy1
=
((
2
,
2
,
4
),
(
2
,
2
,
4
))
strategy2
=
((
1
,
2
,
2
),
)
net
=
Net
(
_w1
,
strategy1
,
strategy2
)
compile
(
net
)
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录