Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
magicwindyyd
mindspore
提交
3045af13
M
mindspore
项目概览
magicwindyyd
/
mindspore
与 Fork 源项目一致
Fork自
MindSpore / mindspore
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
mindspore
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
3045af13
编写于
9月 05, 2020
作者:
L
liuxiao93
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
fix nn.CentralCrop calulation result.
上级
d05c22a1
变更
2
隐藏空白更改
内联
并排
Showing
2 changed file
with
10 addition
and
6 deletion
+10
-6
mindspore/nn/layer/image.py
mindspore/nn/layer/image.py
+7
-4
mindspore/ops/operations/nn_ops.py
mindspore/ops/operations/nn_ops.py
+3
-2
未找到文件。
mindspore/nn/layer/image.py
浏览文件 @
3045af13
...
...
@@ -386,15 +386,15 @@ def _raise_dims_rank_error(input_shape, param_name, func_name):
raise
ValueError
(
f
"
{
func_name
}
{
param_name
}
should be 3d or 4d, but got shape
{
input_shape
}
"
)
@
constexpr
def
_get_bbox
(
rank
,
shape
,
central_fraction
):
def
_get_bbox
(
rank
,
shape
,
size_h
,
size_w
):
"""get bbox start and size for slice"""
if
rank
==
3
:
c
,
h
,
w
=
shape
else
:
n
,
c
,
h
,
w
=
shape
bbox_h_start
=
int
(
np
.
round
((
float
(
h
)
-
float
(
h
)
*
central_fraction
)
/
2
)
)
bbox_w_start
=
int
(
np
.
round
((
float
(
w
)
-
float
(
w
)
*
central_fraction
)
/
2
)
)
bbox_h_start
=
int
(
(
float
(
h
)
-
size_h
)
/
2
)
bbox_w_start
=
int
(
(
float
(
w
)
-
size_w
)
/
2
)
bbox_h_size
=
h
-
bbox_h_start
*
2
bbox_w_size
=
w
-
bbox_w_start
*
2
...
...
@@ -436,12 +436,15 @@ class CentralCrop(Cell):
def
construct
(
self
,
image
):
image_shape
=
F
.
shape
(
image
)
rank
=
len
(
image_shape
)
h
,
w
=
image_shape
[
-
2
],
image_shape
[
-
1
]
if
not
rank
in
(
3
,
4
):
return
_raise_dims_rank_error
(
image_shape
,
"image"
,
self
.
cls_name
)
if
self
.
central_fraction
==
1.0
:
return
image
bbox_begin
,
bbox_size
=
_get_bbox
(
rank
,
image_shape
,
self
.
central_fraction
)
size_h
=
self
.
central_fraction
*
h
size_w
=
self
.
central_fraction
*
w
bbox_begin
,
bbox_size
=
_get_bbox
(
rank
,
image_shape
,
size_h
,
size_w
)
image
=
self
.
slice
(
image
,
bbox_begin
,
bbox_size
)
return
image
mindspore/ops/operations/nn_ops.py
浏览文件 @
3045af13
...
...
@@ -5298,7 +5298,7 @@ class CTCLoss(PrimitiveWithInfer):
- **inputs** (Tensor) - The input Tensor should be a `3-D` tensor whose shape is
:math:`(max_time, batch_size, num_classes)`. `num_classes` should be `num_labels + 1` classes, `num_labels`
indicates the number of actual labels. Blank labels are reserved. Default blank label is `num_classes - 1`.
Data type must be float32 or float64.
Data type must be float
16, float
32 or float64.
- **labels_indices** (Tensor) - The indices of labels. `labels_indices[i, :] == [b, t]` means `labels_values[i]`
stores the id for `(batch b, time t)`. The type must be int64 and rank must be 2.
- **labels_values** (Tensor) - A `1-D` input tensor. The values are associated with the given batch and time.
...
...
@@ -5348,7 +5348,8 @@ class CTCLoss(PrimitiveWithInfer):
return
batch_size
,
inputs
def
infer_dtype
(
self
,
inputs
,
labels_indices
,
labels_values
,
sequence_length
):
validator
.
check_tensor_type_same
({
"inputs_dtype"
:
inputs
},
[
mstype
.
float32
,
mstype
.
double
],
self
.
name
)
valid_dtype
=
[
mstype
.
float16
,
mstype
.
float32
,
mstype
.
double
]
validator
.
check_tensor_type_same
({
"inputs_dtype"
:
inputs
},
valid_dtype
,
self
.
name
)
validator
.
check_tensor_type_same
({
"labels_indices_dtype"
:
labels_indices
},
[
mstype
.
int64
],
self
.
name
)
validator
.
check_tensor_type_same
({
"labels_values_dtype"
:
labels_values
},
[
mstype
.
int32
],
self
.
name
)
validator
.
check_tensor_type_same
({
"sequence_length_dtype"
:
sequence_length
},
[
mstype
.
int32
],
self
.
name
)
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录