Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
magicwindyyd
mindspore
提交
72d2fc74
M
mindspore
项目概览
magicwindyyd
/
mindspore
与 Fork 源项目一致
Fork自
MindSpore / mindspore
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
mindspore
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
72d2fc74
编写于
8月 07, 2020
作者:
M
mindspore-ci-bot
提交者:
Gitee
8月 07, 2020
浏览文件
操作
浏览文件
下载
差异文件
!4050 fix unsortedgrad clipbynorm boundingboxdecode
Merge pull request !4050 from fangzehua/unsortedgrad
上级
ae6f9524
c16a22c6
变更
4
隐藏空白更改
内联
并排
Showing
4 changed file
with
20 addition
and
33 deletion
+20
-33
mindspore/nn/layer/basic.py
mindspore/nn/layer/basic.py
+8
-5
mindspore/ops/_grad/grad_array_ops.py
mindspore/ops/_grad/grad_array_ops.py
+0
-10
mindspore/ops/operations/_grad_ops.py
mindspore/ops/operations/_grad_ops.py
+0
-14
mindspore/ops/operations/other_ops.py
mindspore/ops/operations/other_ops.py
+12
-4
未找到文件。
mindspore/nn/layer/basic.py
浏览文件 @
72d2fc74
...
...
@@ -250,6 +250,10 @@ def _is_equal_one(x):
return
False
return
bool
(
x
.
asnumpy
().
mean
()
==
1.0
)
@
constexpr
def
_dtype_check
(
x_dtype
):
if
x_dtype
not
in
[
mstype
.
float32
,
mstype
.
float16
]:
raise
TypeError
(
"The input type must be float32 or float16."
)
class
ClipByNorm
(
Cell
):
r
"""
...
...
@@ -264,12 +268,11 @@ class ClipByNorm(Cell):
where :math:`L_2(X)` is the :math:`L_2`-norm of :math:`X`.
Inputs:
- **input** (Tensor) - Tensor of shape N-D.
- **clip_norm** (Tensor) - A scalar Tensor of shape :math:`()` or :math:`(1)` and of
the same type as the input Tensor.
- **input** (Tensor) - Tensor of shape N-D. The type should be float32 or float16.
- **clip_norm** (Tensor) - A scalar Tensor of shape :math:`()` or :math:`(1)`.
Outputs:
Tensor, clipped tensor with the same shape as the input.
Tensor, clipped tensor with the same shape as the input
, whose type is float32
.
Examples:
>>> net = nn.ClipByNorm()
...
...
@@ -300,10 +303,10 @@ class ClipByNorm(Cell):
l2sum
=
self
.
cast
(
self
.
reduce_sum
(
mul_x
),
mstype
.
float32
)
cond
=
self
.
greater_
(
l2sum
,
0
)
ones_
=
self
.
fill
(
self
.
dtype
(
cond
),
self
.
shape
(
cond
),
1.0
)
l2sum_safe
=
self
.
select_
(
cond
,
l2sum
,
self
.
cast
(
ones_
,
self
.
dtype
(
l2sum
)))
l2norm
=
self
.
select_
(
cond
,
self
.
sqrt
(
l2sum_safe
),
l2sum
)
_dtype_check
(
self
.
dtype
(
x
))
if
_is_equal_one
(
clip_norm
):
intermediate
=
x
else
:
...
...
mindspore/ops/_grad/grad_array_ops.py
浏览文件 @
72d2fc74
...
...
@@ -827,13 +827,3 @@ def get_bprop_unique(self):
dx
=
op
(
dout
,
out
)
return
(
dx
,)
return
bprop
@
bprop_getters
.
register
(
P
.
UnsortedSegmentSum
)
def
get_bprop_unsorted_segment_sum
(
self
):
"""Generate bprop for UnsortedSegmentSum"""
op
=
G
.
UnsortedSegmentSumGrad
()
def
bprop
(
x
,
segment_ids
,
num_segments
,
out
,
dout
):
dx
=
op
(
dout
,
segment_ids
)
return
(
dx
,
zeros_like
(
segment_ids
),
zeros_like
(
num_segments
))
return
bprop
mindspore/ops/operations/_grad_ops.py
浏览文件 @
72d2fc74
...
...
@@ -502,20 +502,6 @@ class UniqueGrad(Primitive):
raise
NotImplementedError
class
UnsortedSegmentSumGrad
(
PrimitiveWithInfer
):
"""Gradients of UnsortedSegmentSum operation."""
@
prim_attr_register
def
__init__
(
self
):
self
.
init_prim_io_names
(
inputs
=
[
'grads'
,
'ids'
],
outputs
=
[
'y'
])
def
infer_shape
(
self
,
grads
,
ids
):
return
ids
+
grads
[
len
(
ids
):]
def
infer_dtype
(
self
,
grads
,
ids
):
return
grads
class
BNTrainingReduceGrad
(
PrimitiveWithInfer
):
"""Gradients of FusedBatchNorm operation."""
...
...
mindspore/ops/operations/other_ops.py
浏览文件 @
72d2fc74
...
...
@@ -93,8 +93,12 @@ class BoundingBoxEncode(PrimitiveWithInfer):
@
prim_attr_register
def
__init__
(
self
,
means
=
(
0.0
,
0.0
,
0.0
,
0.0
),
stds
=
(
1.0
,
1.0
,
1.0
,
1.0
)):
validator
.
check_value_type
(
'means'
,
means
,
[
tuple
],
self
.
name
)
validator
.
check_value_type
(
'stds'
,
stds
,
[
tuple
],
self
.
name
)
validator
.
check_value_type
(
'means'
,
means
,
[
tuple
,
list
],
self
.
name
)
validator
.
check_value_type
(
'stds'
,
stds
,
[
tuple
,
list
],
self
.
name
)
for
i
,
value
in
enumerate
(
means
):
validator
.
check_value_type
(
"means[%d]"
%
i
,
value
,
[
float
],
self
.
name
)
for
i
,
value
in
enumerate
(
stds
):
validator
.
check_value_type
(
"stds[%d]"
%
i
,
value
,
[
float
],
self
.
name
)
validator
.
check_integer
(
"means len"
,
len
(
means
),
4
,
Rel
.
EQ
,
self
.
name
)
validator
.
check_integer
(
"stds len"
,
len
(
stds
),
4
,
Rel
.
EQ
,
self
.
name
)
...
...
@@ -143,8 +147,12 @@ class BoundingBoxDecode(PrimitiveWithInfer):
@
prim_attr_register
def
__init__
(
self
,
max_shape
,
means
=
(
0.0
,
0.0
,
0.0
,
0.0
),
stds
=
(
1.0
,
1.0
,
1.0
,
1.0
),
wh_ratio_clip
=
0.016
):
validator
.
check_value_type
(
'means'
,
means
,
[
tuple
],
self
.
name
)
validator
.
check_value_type
(
'stds'
,
stds
,
[
tuple
],
self
.
name
)
validator
.
check_value_type
(
'means'
,
means
,
[
tuple
,
list
],
self
.
name
)
validator
.
check_value_type
(
'stds'
,
stds
,
[
tuple
,
list
],
self
.
name
)
for
i
,
value
in
enumerate
(
means
):
validator
.
check_value_type
(
"means[%d]"
%
i
,
value
,
[
float
],
self
.
name
)
for
i
,
value
in
enumerate
(
stds
):
validator
.
check_value_type
(
"stds[%d]"
%
i
,
value
,
[
float
],
self
.
name
)
validator
.
check_value_type
(
'wh_ratio_clip'
,
wh_ratio_clip
,
[
float
],
self
.
name
)
validator
.
check_integer
(
"means len"
,
len
(
means
),
4
,
Rel
.
EQ
,
self
.
name
)
validator
.
check_integer
(
"stds len"
,
len
(
stds
),
4
,
Rel
.
EQ
,
self
.
name
)
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录