Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
magicwindyyd
mindspore
提交
7f1816c2
M
mindspore
项目概览
magicwindyyd
/
mindspore
与 Fork 源项目一致
Fork自
MindSpore / mindspore
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
mindspore
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
7f1816c2
编写于
4月 16, 2020
作者:
M
mindspore-ci-bot
提交者:
Gitee
4月 16, 2020
浏览文件
操作
浏览文件
下载
差异文件
!382 Add parallel operator for SigmoidCrossEntropyWithLogits
Merge pull request !382 from yangzhenzhang/sigmoidloss
上级
8f6b941a
57cd9f81
变更
5
隐藏空白更改
内联
并排
Showing
5 changed file
with
95 addition
and
0 deletion
+95
-0
mindspore/ccsrc/parallel/dynamic_creator.h
mindspore/ccsrc/parallel/dynamic_creator.h
+1
-0
mindspore/ccsrc/parallel/ops_info/arithmetic_info.h
mindspore/ccsrc/parallel/ops_info/arithmetic_info.h
+9
-0
mindspore/ccsrc/parallel/ops_info/ops_utils.h
mindspore/ccsrc/parallel/ops_info/ops_utils.h
+1
-0
mindspore/ccsrc/parallel/step_auto_parallel.cc
mindspore/ccsrc/parallel/step_auto_parallel.cc
+1
-0
tests/ut/python/parallel/test_sigmoid_cross_entropy_with_logits.py
...python/parallel/test_sigmoid_cross_entropy_with_logits.py
+83
-0
未找到文件。
mindspore/ccsrc/parallel/dynamic_creator.h
浏览文件 @
7f1816c2
...
...
@@ -127,6 +127,7 @@ REGISTER(NegInfo);
REGISTER
(
BatchMatMulInfo
);
REGISTER
(
ExpandDimsInfo
);
REGISTER
(
SqueezeInfo
);
REGISTER
(
SigmoidCrossEntropyWithLogitsInfo
);
}
// namespace parallel
}
// namespace mindspore
...
...
mindspore/ccsrc/parallel/ops_info/arithmetic_info.h
浏览文件 @
7f1816c2
...
...
@@ -120,6 +120,15 @@ class AssignSubInfo : public ArithmeticBase {
:
ArithmeticBase
(
name
,
inputs_shape
,
outputs_shape
,
attrs
,
std
::
make_shared
<
ArithmeticCost
>
(
false
))
{}
~
AssignSubInfo
()
override
=
default
;
};
// All dimensions can be split arbitrarily, but the split method of Logits should be the same as that of label.
class
SigmoidCrossEntropyWithLogitsInfo
:
public
ArithmeticBase
{
public:
SigmoidCrossEntropyWithLogitsInfo
(
const
std
::
string
&
name
,
const
Shapes
&
inputs_shape
,
const
Shapes
&
outputs_shape
,
const
PrimitiveAttrs
&
attrs
)
:
ArithmeticBase
(
name
,
inputs_shape
,
outputs_shape
,
attrs
,
std
::
make_shared
<
ArithmeticCost
>
(
false
))
{}
~
SigmoidCrossEntropyWithLogitsInfo
()
override
=
default
;
};
}
// namespace parallel
}
// namespace mindspore
...
...
mindspore/ccsrc/parallel/ops_info/ops_utils.h
浏览文件 @
7f1816c2
...
...
@@ -138,6 +138,7 @@ constexpr char ALL_GATHER[] = "AllGather";
constexpr
char
REDUCE_SCATTER
[]
=
"ReduceScatter"
;
constexpr
char
CONCAT
[]
=
"Concat"
;
constexpr
char
SOFTMAX_CROSS_ENTROPY_WITH_LOGITS
[]
=
"SoftmaxCrossEntropyWithLogits"
;
constexpr
char
SIGMOID_CROSS_ENTROPY_WITH_LOGITS
[]
=
"SigmoidCrossEntropyWithLogits"
;
constexpr
char
MATMUL
[]
=
"MatMul"
;
constexpr
char
GELU
[]
=
"Gelu"
;
constexpr
char
TANH
[]
=
"Tanh"
;
...
...
mindspore/ccsrc/parallel/step_auto_parallel.cc
浏览文件 @
7f1816c2
...
...
@@ -78,6 +78,7 @@ std::vector<std::string> splittable_op_ = {MATMUL,
FUSE_BATCH_NORM
,
POOLING
,
SOFTMAX_CROSS_ENTROPY_WITH_LOGITS
,
SIGMOID_CROSS_ENTROPY_WITH_LOGITS
,
MAX_POOL_WITH_ARGMAX
,
SIMPLE_MEAN
,
FLATTEN
,
...
...
tests/ut/python/parallel/test_sigmoid_cross_entropy_with_logits.py
0 → 100644
浏览文件 @
7f1816c2
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
numpy
as
np
import
mindspore
as
ms
from
mindspore
import
context
,
Tensor
,
Parameter
from
mindspore.nn
import
Cell
,
TrainOneStepCell
,
Momentum
from
mindspore.ops
import
operations
as
P
from
mindspore.common.api
import
_executor
class
Net
(
Cell
):
def
__init__
(
self
,
mul_weight
,
strategy1
=
None
,
strategy2
=
None
):
super
().
__init__
()
self
.
mul
=
P
.
Mul
().
set_strategy
(
strategy1
)
self
.
loss
=
P
.
SigmoidCrossEntropyWithLogits
().
set_strategy
(
strategy2
)
self
.
mul_weight
=
Parameter
(
mul_weight
,
"w1"
)
def
construct
(
self
,
x
,
b
):
out
=
self
.
mul
(
x
,
self
.
mul_weight
)
out
=
self
.
loss
(
out
,
b
)
return
out
_x
=
Tensor
(
np
.
ones
([
128
,
64
]),
dtype
=
ms
.
float32
)
_w1
=
Tensor
(
np
.
ones
([
128
,
64
]),
dtype
=
ms
.
float32
)
_b
=
Tensor
(
np
.
ones
([
128
,
64
]),
dtype
=
ms
.
float32
)
def
compile
(
net
):
optimizer
=
Momentum
(
net
.
trainable_params
(),
learning_rate
=
0.1
,
momentum
=
0.9
)
train_net
=
TrainOneStepCell
(
net
,
optimizer
)
_executor
.
compile
(
train_net
,
_x
,
_b
)
context
.
reset_auto_parallel_context
()
def
test_sigmoid_cross_entropy_with_logits_data_parallel
():
context
.
set_auto_parallel_context
(
parallel_mode
=
"semi_auto_parallel"
,
device_num
=
16
,
global_rank
=
0
)
strategy1
=
((
16
,
1
),
(
16
,
1
))
strategy2
=
((
16
,
1
),
(
16
,
1
))
net
=
Net
(
_w1
,
strategy1
,
strategy2
)
compile
(
net
)
def
test_sigmoid_cross_entropy_with_logits_model_parallel
():
context
.
set_auto_parallel_context
(
parallel_mode
=
"semi_auto_parallel"
,
device_num
=
16
,
global_rank
=
0
)
strategy1
=
((
1
,
16
),
(
1
,
16
))
strategy2
=
((
1
,
16
),
(
1
,
16
))
net
=
Net
(
_w1
,
strategy1
,
strategy2
)
compile
(
net
)
def
test_sigmoid_cross_entropy_with_logits_hybrid_parallel
():
context
.
set_auto_parallel_context
(
parallel_mode
=
"semi_auto_parallel"
,
device_num
=
16
,
global_rank
=
0
)
strategy1
=
((
2
,
8
),
(
2
,
8
))
strategy2
=
((
2
,
8
),
(
2
,
8
))
net
=
Net
(
_w1
,
strategy1
,
strategy2
)
compile
(
net
)
def
test_sigmoid_cross_entropy_with_logits_auto_parallel
():
context
.
set_auto_parallel_context
(
parallel_mode
=
"auto_parallel"
,
device_num
=
16
,
global_rank
=
0
)
net
=
Net
(
_w1
)
compile
(
net
)
def
test_sigmoid_cross_entropy_with_logits_repeat_calc
():
context
.
set_auto_parallel_context
(
parallel_mode
=
"semi_auto_parallel"
,
device_num
=
16
,
global_rank
=
0
)
strategy1
=
((
2
,
8
),
(
2
,
8
))
strategy2
=
((
2
,
2
),
(
2
,
2
))
net
=
Net
(
_w1
,
strategy1
,
strategy2
)
compile
(
net
)
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录