提交 2723e269 编写于 作者: L liuxiao93

Fix input validator of Assign.

上级 3bb04abc
...@@ -234,7 +234,7 @@ class Softsign(PrimitiveWithInfer): ...@@ -234,7 +234,7 @@ class Softsign(PrimitiveWithInfer):
\text{output} = \frac{\text{input_x}}{1 + \abs{\text{input_x}}}, \text{output} = \frac{\text{input_x}}{1 + \abs{\text{input_x}}},
Inputs: Inputs:
- **input_x** (Tensor) - The input tensor whose data type should be float. - **input_x** (Tensor) - The input tensor whose data type should be float16 or float32.
Outputs: Outputs:
Tensor, with the same type and shape as the `input_x`. Tensor, with the same type and shape as the `input_x`.
...@@ -255,7 +255,7 @@ class Softsign(PrimitiveWithInfer): ...@@ -255,7 +255,7 @@ class Softsign(PrimitiveWithInfer):
return input_x return input_x
def infer_dtype(self, input_x): def infer_dtype(self, input_x):
validator.check_tensor_type_same({'input_x': input_x}, mstype.float_type, self.name) validator.check_tensor_type_same({'input_x': input_x}, [mstype.float16, mstype.float32], self.name)
return input_x return input_x
...@@ -4730,19 +4730,19 @@ class CTCLoss(PrimitiveWithInfer): ...@@ -4730,19 +4730,19 @@ class CTCLoss(PrimitiveWithInfer):
preprocess_collapse_repeated (bool): If True, repeated labels are collapsed prior to the CTC calculation. preprocess_collapse_repeated (bool): If True, repeated labels are collapsed prior to the CTC calculation.
Default: False. Default: False.
ctc_merge_repeated (bool): If False, during CTC calculation, repeated non-blank labels will not be merged ctc_merge_repeated (bool): If False, during CTC calculation, repeated non-blank labels will not be merged
and are interpreted as individual labels. This is a simplfied version if CTC. and are interpreted as individual labels. This is a simplfied version of CTC.
Default: True. Default: True.
ignore_longer_outputs_than_inputs (bool): If True, sequences with longer outputs than inputs will be ignored. ignore_longer_outputs_than_inputs (bool): If True, sequences with longer outputs than inputs will be ignored.
Default: False. Default: False.
Inputs: Inputs:
- **inputs** (Tensor) - The input Tensor should be a `3-D` tensor whose shape is - **inputs** (Tensor) - The input Tensor should be a `3-D` tensor whose shape is
:math:`(max_time, batch_size, num_class)`. `num_class` should be `num_labels + 1` classes, `num_labels` :math:`(max_time, batch_size, num_classes)`. `num_classes` should be `num_labels + 1` classes, `num_labels`
indicates the number of actual labels. Blank labels are reserved. indicates the number of actual labels. Blank labels are reserved. Default blank label is `num_classes - 1`.
- **labels_indices** (Tensor) - The indices of labels. `labels_indices[i, :] == [b, t]` means `labels_values[i]` - **labels_indices** (Tensor) - The indices of labels. `labels_indices[i, :] == [b, t]` means `labels_values[i]`
stores the id for `(batch b, time t)`. The type must be int64 and rank must be 2. stores the id for `(batch b, time t)`. The type must be int64 and rank must be 2.
- **labels_values** (Tensor) - A `1-D` input tensor. The values associated with the given batch and time. The - **labels_values** (Tensor) - A `1-D` input tensor. The values associated with the given batch and time. The
type must be int32. `labels_values[i]` must in the range of `[0, num_class)`. type must be int32. `labels_values[i]` must in the range of `[0, num_classes)`.
- **sequence_length** (Tensor) - A tensor containing sequence lengths with the shape of :math:`(batch_size)`. - **sequence_length** (Tensor) - A tensor containing sequence lengths with the shape of :math:`(batch_size)`.
The type must be int32. Each value in the tensor should not greater than `max_time`. The type must be int32. Each value in the tensor should not greater than `max_time`.
......
...@@ -60,7 +60,9 @@ class Assign(PrimitiveWithInfer): ...@@ -60,7 +60,9 @@ class Assign(PrimitiveWithInfer):
return variable return variable
def infer_dtype(self, variable, value): def infer_dtype(self, variable, value):
# Add a type validation later when we don't have to assign a value to RefKey. if variable != mstype.type_refkey:
validator.check_tensor_type_same({"variable": variable}, mstype.number_type, self.name)
validator.check_scalar_or_tensor_type_same({"value": value}, mstype.number_type, self.name)
return variable return variable
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册