正统之独孤求败 / mindspore (forked from MindSpore / mindspore)
Commit ff710dde
Authored: May 12, 2020
Author: zhaojichen
Parent: 98112d1a

    optimize SoftmaxCrossEntropWithLogits and momentum

Showing 3 changed files with 20 additions and 7 deletions (+20 -7):

- mindspore/nn/loss/loss.py (+13 -3)
- mindspore/nn/optim/momentum.py (+5 -2)
- mindspore/ops/operations/nn_ops.py (+2 -2)
mindspore/nn/loss/loss.py

```diff
@@ -18,6 +18,8 @@ from mindspore.common.tensor import Tensor
 from mindspore.ops import operations as P
 from mindspore.ops import functional as F
 from mindspore.nn.cell import Cell
 from mindspore._checkparam import Validator as validator
+from mindspore._checkparam import Rel
+from ... import context
@@ -215,6 +217,8 @@ class SoftmaxCrossEntropyWithLogits(_Loss):
         sparse (bool): Specifies whether labels use sparse format or not. Default: False.
         reduction (Union[str, None]): Type of reduction to apply to loss. Supports 'sum' or 'mean'. If None,
             do not apply reduction. Default: None.
+        smooth_factor (float): Label smoothing factor. It is an optional input. Default: 0.
+        num_classes (int): The number of classes in the task. It is an optional input. Default: 2.

     Inputs:
         - **logits** (Tensor) - Tensor of shape :math:`(x_1, x_2, ..., x_R)`.
@@ -235,14 +239,20 @@ class SoftmaxCrossEntropyWithLogits(_Loss):
-    def __init__(self, is_grad=True, sparse=False, reduction=None):
+    def __init__(self, is_grad=True, sparse=False, reduction=None,
+                 smooth_factor=0, num_classes=2):
         super(SoftmaxCrossEntropyWithLogits, self).__init__(reduction)
         self.is_grad = is_grad
         self.sparse = sparse
+        validator.check_integer("num_classes", num_classes, 1, Rel.GT, self.cls_name)
+        validator.check_number_range("smooth_factor", smooth_factor, 0, 1, Rel.INC_BOTH, self.cls_name)
+        self.smooth_factor = smooth_factor
+        self.num_classes = num_classes
         self.softmax_cross_entropy = P.SoftmaxCrossEntropyWithLogits()
         self.one_hot = P.OneHot()
-        self.on_value = Tensor(1.0, mstype.float32)
-        self.off_value = Tensor(0.0, mstype.float32)
+        self.on_value = Tensor(1.0 - self.smooth_factor, mstype.float32)
+        self.off_value = Tensor(1.0 * self.smooth_factor / (self.num_classes - 1), mstype.float32)
+        self.is_cpugpu = context.get_context('device_target') in ["CPU", "GPU"]
+        if self.is_cpugpu:
```
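The `on_value`/`off_value` change above implements label smoothing: the true class gets probability `1 - smooth_factor`, and the remaining `smooth_factor` mass is spread evenly over the other `num_classes - 1` classes. A minimal NumPy sketch of that encoding (`smoothed_one_hot` is a hypothetical helper for illustration, not a MindSpore API):

```python
import numpy as np

def smoothed_one_hot(labels, num_classes, smooth_factor=0.0):
    """Hypothetical helper mirroring the commit's on_value/off_value logic."""
    on_value = 1.0 - smooth_factor                  # probability of the true class
    off_value = smooth_factor / (num_classes - 1)   # mass shared by the other classes
    out = np.full((len(labels), num_classes), off_value, dtype=np.float32)
    out[np.arange(len(labels)), labels] = on_value
    return out

# With smooth_factor=0 this reduces to a plain one-hot encoding;
# each row always sums to 1 regardless of smooth_factor.
smoothed = smoothed_one_hot([0, 2], 3, smooth_factor=0.1)
```

Note that `off_value` divides by `num_classes - 1`, which is why the diff validates `num_classes > 1` via `Rel.GT`.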
mindspore/nn/optim/momentum.py

```diff
@@ -17,6 +17,7 @@ from mindspore.ops import functional as F, composite as C, operations as P
 from mindspore.common.parameter import Parameter
 from mindspore.common.tensor import Tensor
 import mindspore.common.dtype as mstype
+from mindspore._checkparam import check_bool
 from .optimizer import Optimizer

 momentum_opt = C.MultitypeFuncGraph("momentum_opt")
@@ -67,6 +68,7 @@ class Momentum(Optimizer):
         momentum (float): Hyperparameter of type float, means momentum for the moving average.
         weight_decay (float): Weight decay (L2 penalty). Default: 0.0.
         loss_scale (float): A floating point value for the loss scale. Default: 1.0.
+        use_nesterov (bool): Enable Nesterov momentum. Default: False.

     Inputs:
         - **gradients** (tuple[Tensor]) - The gradients of `params`, the shape is the same as `params`.
@@ -95,15 +97,16 @@ class Momentum(Optimizer):
         >>> loss = nn.SoftmaxCrossEntropyWithLogits()
         >>> model = Model(net, loss_fn=loss, optimizer=optim, metrics=None)
     """
-    def __init__(self, params, learning_rate, momentum, weight_decay=0.0, loss_scale=1.0):
+    def __init__(self, params, learning_rate, momentum, weight_decay=0.0, loss_scale=1.0, use_nesterov=False):
         super(Momentum, self).__init__(learning_rate, params, weight_decay, loss_scale)
         if isinstance(momentum, float) and momentum < 0.0:
             raise ValueError("momentum should be at least 0.0, but got momentum {}".format(momentum))
         self.momentum = Parameter(Tensor(momentum, mstype.float32), name="momentum")
         self.params = self.parameters
+        self.use_nesterov = check_bool(use_nesterov)
         self.moments = self.params.clone(prefix="moments", init='zeros')
         self.hyper_map = C.HyperMap()
-        self.opt = P.ApplyMomentum()
+        self.opt = P.ApplyMomentum(use_nesterov=self.use_nesterov)

     def construct(self, gradients):
         params = self.params
```
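The new `use_nesterov` flag is simply forwarded to `P.ApplyMomentum`. As a rough sketch of how the two update rules differ (the exact kernel semantics are an assumption here, following the common `ApplyMomentum` convention rather than anything shown in this diff):

```python
def momentum_step(param, grad, moment, lr, momentum, use_nesterov=False):
    """Sketch of a momentum update; not the MindSpore kernel itself."""
    moment = momentum * moment + grad  # accumulate velocity
    if use_nesterov:
        # Nesterov: step from a look-ahead position, so the gradient term
        # appears once directly and once through the velocity.
        param = param - lr * (grad + momentum * moment)
    else:
        # Classic heavy-ball momentum: step along the velocity only.
        param = param - lr * moment
    return param, moment
```

On the first step (zero initial moment, as in the `init='zeros'` clone above), Nesterov takes a larger step of `lr * (1 + momentum) * grad` versus `lr * grad` for the classic rule.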
mindspore/ops/operations/nn_ops.py

```diff
@@ -1757,8 +1757,8 @@ class LayerNorm(Primitive):
         - **output_x** (Tensor) - The normalized input, has the same type and shape as the `input_x`.
           The shape is :math:`(N, C)`.
-        - **updated_gamma** (Tensor) - Tensor of shape :math:`(C,)`.
-        - **updated_beta** (Tensor) - Tensor of shape :math:`(C,)`.
+        - **mean** (Tensor) - Tensor of shape :math:`(C,)`.
+        - **variance** (Tensor) - Tensor of shape :math:`(C,)`.

     Examples:
         >>> input_x = Tensor(np.array([[1, 2, 3], [1, 2, 3]]), mindspore.float32)
```
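The docstring fix renames the documented outputs to the statistics the op actually produces. For context, a plain NumPy sketch of a layer-norm forward pass that returns the normalized output together with the per-sample mean and variance (an illustration only, not the MindSpore kernel; normalization over the last axis is an assumption):

```python
import numpy as np

def layer_norm(x, gamma, beta, eps=1e-7):
    """Illustrative layer norm: normalize each row, return (out, mean, var)."""
    mean = x.mean(axis=-1, keepdims=True)   # per-sample mean
    var = x.var(axis=-1, keepdims=True)     # per-sample (biased) variance
    out = gamma * (x - mean) / np.sqrt(var + eps) + beta
    return out, mean, var

# Mirrors the docstring example's input: two identical rows [1, 2, 3].
x = np.array([[1, 2, 3], [1, 2, 3]], dtype=np.float32)
out, mean, var = layer_norm(x, 1.0, 0.0)
```

Returning `mean` and `variance` alongside the output lets the backward pass reuse them instead of recomputing the statistics.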