From 2723e2698d9c18c4411733373ac6803ced8b6767 Mon Sep 17 00:00:00 2001 From: liuxiao93 Date: Tue, 14 Jul 2020 10:06:11 +0800 Subject: [PATCH] Fix input validator of Assign. --- mindspore/ops/operations/nn_ops.py | 12 ++++++------ mindspore/ops/operations/other_ops.py | 4 +++- 2 files changed, 9 insertions(+), 7 deletions(-) diff --git a/mindspore/ops/operations/nn_ops.py b/mindspore/ops/operations/nn_ops.py index 0d2499c0a..e97c4c91c 100644 --- a/mindspore/ops/operations/nn_ops.py +++ b/mindspore/ops/operations/nn_ops.py @@ -234,7 +234,7 @@ class Softsign(PrimitiveWithInfer): \text{output} = \frac{\text{input_x}}{1 + \abs{\text{input_x}}}, Inputs: - - **input_x** (Tensor) - The input tensor whose data type should be float. + - **input_x** (Tensor) - The input tensor whose data type should be float16 or float32. Outputs: Tensor, with the same type and shape as the `input_x`. @@ -255,7 +255,7 @@ class Softsign(PrimitiveWithInfer): return input_x def infer_dtype(self, input_x): - validator.check_tensor_type_same({'input_x': input_x}, mstype.float_type, self.name) + validator.check_tensor_type_same({'input_x': input_x}, [mstype.float16, mstype.float32], self.name) return input_x @@ -4730,19 +4730,19 @@ class CTCLoss(PrimitiveWithInfer): preprocess_collapse_repeated (bool): If True, repeated labels are collapsed prior to the CTC calculation. Default: False. ctc_merge_repeated (bool): If False, during CTC calculation, repeated non-blank labels will not be merged - and are interpreted as individual labels. This is a simplfied version if CTC. + and are interpreted as individual labels. This is a simplfied version of CTC. Default: True. ignore_longer_outputs_than_inputs (bool): If True, sequences with longer outputs than inputs will be ignored. Default: False. Inputs: - **inputs** (Tensor) - The input Tensor should be a `3-D` tensor whose shape is - :math:`(max_time, batch_size, num_class)`. `num_class` should be `num_labels + 1` classes, `num_labels` - indicates the number of actual labels. Blank labels are reserved. + :math:`(max_time, batch_size, num_classes)`. `num_classes` should be `num_labels + 1` classes, `num_labels` + indicates the number of actual labels. Blank labels are reserved. Default blank label is `num_classes - 1`. - **labels_indices** (Tensor) - The indices of labels. `labels_indices[i, :] == [b, t]` means `labels_values[i]` stores the id for `(batch b, time t)`. The type must be int64 and rank must be 2. - **labels_values** (Tensor) - A `1-D` input tensor. The values associated with the given batch and time. The - type must be int32. `labels_values[i]` must in the range of `[0, num_class)`. + type must be int32. `labels_values[i]` must in the range of `[0, num_classes)`. - **sequence_length** (Tensor) - A tensor containing sequence lengths with the shape of :math:`(batch_size)`. The type must be int32. Each value in the tensor should not greater than `max_time`. diff --git a/mindspore/ops/operations/other_ops.py b/mindspore/ops/operations/other_ops.py index 7221f7790..a58403f88 100644 --- a/mindspore/ops/operations/other_ops.py +++ b/mindspore/ops/operations/other_ops.py @@ -60,7 +60,9 @@ class Assign(PrimitiveWithInfer): return variable def infer_dtype(self, variable, value): - # Add a type validation later when we don't have to assign a value to RefKey. + if variable != mstype.type_refkey: + validator.check_tensor_type_same({"variable": variable}, mstype.number_type, self.name) + validator.check_scalar_or_tensor_type_same({"value": value}, mstype.number_type, self.name) return variable -- GitLab