Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
magicwindyyd
mindspore
提交
2723e269
M
mindspore
项目概览
magicwindyyd
/
mindspore
与 Fork 源项目一致
Fork自
MindSpore / mindspore
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
mindspore
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
2723e269
编写于
7月 14, 2020
作者:
L
liuxiao93
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Fix input validator of Assign.
上级
3bb04abc
变更
2
显示空白变更内容
内联
并排
Showing
2 changed file
with
9 addition
and
7 deletion
+9
-7
mindspore/ops/operations/nn_ops.py
mindspore/ops/operations/nn_ops.py
+6
-6
mindspore/ops/operations/other_ops.py
mindspore/ops/operations/other_ops.py
+3
-1
未找到文件。
mindspore/ops/operations/nn_ops.py
浏览文件 @
2723e269
...
@@ -234,7 +234,7 @@ class Softsign(PrimitiveWithInfer):
...
@@ -234,7 +234,7 @@ class Softsign(PrimitiveWithInfer):
\text{output} = \frac{\text{input_x}}{1 + \abs{\text{input_x}}},
\text{output} = \frac{\text{input_x}}{1 + \abs{\text{input_x}}},
Inputs:
Inputs:
- **input_x** (Tensor) - The input tensor whose data type should be float.
- **input_x** (Tensor) - The input tensor whose data type should be float
16 or float32
.
Outputs:
Outputs:
Tensor, with the same type and shape as the `input_x`.
Tensor, with the same type and shape as the `input_x`.
...
@@ -255,7 +255,7 @@ class Softsign(PrimitiveWithInfer):
...
@@ -255,7 +255,7 @@ class Softsign(PrimitiveWithInfer):
return
input_x
return
input_x
def
infer_dtype
(
self
,
input_x
):
def
infer_dtype
(
self
,
input_x
):
validator
.
check_tensor_type_same
({
'input_x'
:
input_x
},
mstype
.
float_type
,
self
.
name
)
validator
.
check_tensor_type_same
({
'input_x'
:
input_x
},
[
mstype
.
float16
,
mstype
.
float32
]
,
self
.
name
)
return
input_x
return
input_x
...
@@ -4730,19 +4730,19 @@ class CTCLoss(PrimitiveWithInfer):
...
@@ -4730,19 +4730,19 @@ class CTCLoss(PrimitiveWithInfer):
preprocess_collapse_repeated (bool): If True, repeated labels are collapsed prior to the CTC calculation.
preprocess_collapse_repeated (bool): If True, repeated labels are collapsed prior to the CTC calculation.
Default: False.
Default: False.
ctc_merge_repeated (bool): If False, during CTC calculation, repeated non-blank labels will not be merged
ctc_merge_repeated (bool): If False, during CTC calculation, repeated non-blank labels will not be merged
and are interpreted as individual labels. This is a simplfied version
i
f CTC.
and are interpreted as individual labels. This is a simplfied version
o
f CTC.
Default: True.
Default: True.
ignore_longer_outputs_than_inputs (bool): If True, sequences with longer outputs than inputs will be ignored.
ignore_longer_outputs_than_inputs (bool): If True, sequences with longer outputs than inputs will be ignored.
Default: False.
Default: False.
Inputs:
Inputs:
- **inputs** (Tensor) - The input Tensor should be a `3-D` tensor whose shape is
- **inputs** (Tensor) - The input Tensor should be a `3-D` tensor whose shape is
:math:`(max_time, batch_size, num_class
)`. `num_clas
s` should be `num_labels + 1` classes, `num_labels`
:math:`(max_time, batch_size, num_class
es)`. `num_classe
s` should be `num_labels + 1` classes, `num_labels`
indicates the number of actual labels. Blank labels are reserved.
indicates the number of actual labels. Blank labels are reserved.
Default blank label is `num_classes - 1`.
- **labels_indices** (Tensor) - The indices of labels. `labels_indices[i, :] == [b, t]` means `labels_values[i]`
- **labels_indices** (Tensor) - The indices of labels. `labels_indices[i, :] == [b, t]` means `labels_values[i]`
stores the id for `(batch b, time t)`. The type must be int64 and rank must be 2.
stores the id for `(batch b, time t)`. The type must be int64 and rank must be 2.
- **labels_values** (Tensor) - A `1-D` input tensor. The values associated with the given batch and time. The
- **labels_values** (Tensor) - A `1-D` input tensor. The values associated with the given batch and time. The
type must be int32. `labels_values[i]` must in the range of `[0, num_class)`.
type must be int32. `labels_values[i]` must in the range of `[0, num_class
es
)`.
- **sequence_length** (Tensor) - A tensor containing sequence lengths with the shape of :math:`(batch_size)`.
- **sequence_length** (Tensor) - A tensor containing sequence lengths with the shape of :math:`(batch_size)`.
The type must be int32. Each value in the tensor should not greater than `max_time`.
The type must be int32. Each value in the tensor should not greater than `max_time`.
...
...
mindspore/ops/operations/other_ops.py
浏览文件 @
2723e269
...
@@ -60,7 +60,9 @@ class Assign(PrimitiveWithInfer):
...
@@ -60,7 +60,9 @@ class Assign(PrimitiveWithInfer):
return
variable
return
variable
def
infer_dtype
(
self
,
variable
,
value
):
def
infer_dtype
(
self
,
variable
,
value
):
# Add a type validation later when we don't have to assign a value to RefKey.
if
variable
!=
mstype
.
type_refkey
:
validator
.
check_tensor_type_same
({
"variable"
:
variable
},
mstype
.
number_type
,
self
.
name
)
validator
.
check_scalar_or_tensor_type_same
({
"value"
:
value
},
mstype
.
number_type
,
self
.
name
)
return
variable
return
variable
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录