提交 10abb684 编写于 作者: L liuxiao

add ops CTCLoss

上级 fdad9135
......@@ -196,6 +196,8 @@ const char kNameBatchToSpace[] = "BatchToSpace";
const char kNameAtan2[] = "Atan2";
const char kNameApplyRMSProp[] = "ApplyRMSProp";
const char kNameApplyCenteredRMSProp[] = "ApplyCenteredRMSProp";
const char kNameL2Loss[] = "L2Loss";
const char kNameCTCLoss[] = "CTCLoss";
// -----------------OpAdapter initialization--------------
std::unordered_map<std::string, OpAdapterDescPtr> &DfGraphConvertor::get_adpt_map() {
......@@ -391,7 +393,9 @@ std::unordered_map<std::string, OpAdapterDescPtr> &DfGraphConvertor::get_adpt_ma
{string(kNameBatchToSpace), ADPT_DESC(BatchToSpaceD)},
{string(kNameAtan2), ADPT_DESC(Atan2)},
{string(kNameApplyRMSProp), ADPT_DESC(ApplyRMSPropD)},
{string(kNameApplyCenteredRMSProp), ADPT_DESC(ApplyCenteredRMSProp)}};
{string(kNameApplyCenteredRMSProp), ADPT_DESC(ApplyCenteredRMSProp)},
{string(kNameL2Loss), ADPT_DESC(L2Loss)},
{string(kNameCTCLoss), ADPT_DESC(CTCLoss)}};
#ifdef ENABLE_GE
adpt_map[string(kNamePrint)] = ADPT_DESC(Print);
adpt_map[string(kNameApplyAdam)] = ADPT_DESC(ApplyAdamD);
......
......@@ -1227,6 +1227,22 @@ INPUT_MAP(ApplyCenteredRMSProp) = {{1, INPUT_DESC(var)}, {2, INPUT_DESC(mg)},
ATTR_MAP(ApplyCenteredRMSProp) = {{"use_locking", ATTR_DESC(use_locking, AnyTraits<bool>())}};
OUTPUT_MAP(ApplyCenteredRMSProp) = {{0, OUTPUT_DESC(var)}};
// L2Loss
INPUT_MAP(L2Loss) = {{1, INPUT_DESC(x)}};
ATTR_MAP(L2Loss) = EMPTY_ATTR_MAP;
OUTPUT_MAP(L2Loss) = {{0, OUTPUT_DESC(y)}};
// CTCLoss
INPUT_MAP(CTCLoss) = {{1, INPUT_DESC(inputs)},
{2, INPUT_DESC(labels_indices)},
{3, INPUT_DESC(labels_values)},
{4, INPUT_DESC(sequence_length)}};
ATTR_MAP(CTCLoss) = {
{"preprocess_collapse_repeated", ATTR_DESC(preprocess_collapse_repeated, AnyTraits<bool>())},
{"ctc_merge_repeated", ATTR_DESC(ctc_merge_repeated, AnyTraits<bool>())},
{"ignore_longer_outputs_than_inputs", ATTR_DESC(ignore_longer_outputs_than_inputs, AnyTraits<bool>())}};
OUTPUT_MAP(CTCLoss) = {{0, OUTPUT_DESC(loss)}, {1, OUTPUT_DESC(gradient)}};
#ifdef ENABLE_GE
// Print
INPUT_MAP(Print) = EMPTY_INPUT_MAP;
......
......@@ -465,6 +465,10 @@ DECLARE_OP_USE_INPUT_ATTR(ApplyRMSPropD)
DECLARE_OP_USE_OUTPUT(ApplyRMSPropD)
DECLARE_OP_ADAPTER(ApplyCenteredRMSProp)
DECLARE_OP_USE_OUTPUT(ApplyCenteredRMSProp)
DECLARE_OP_ADAPTER(L2Loss)
DECLARE_OP_USE_OUTPUT(L2Loss)
DECLARE_OP_ADAPTER(CTCLoss)
DECLARE_OP_USE_OUTPUT(CTCLoss)
#ifdef ENABLE_GE
DECLARE_OP_ADAPTER(Print)
DECLARE_OP_USE_DYN_INPUT(Print)
......
......@@ -668,3 +668,16 @@ def get_bprop_dropout(self):
return (dx,)
return bprop
@bprop_getters.register(P.CTCLoss)
def get_bprop_ctc_loss(self):
"""Grad definition for `CTCLoss` operation"""
expand = P.ExpandDims()
def bprop(inputs, labels_indices, labels_values, sequence_length, out, dout):
grad_loss = out[1]
grad = grad_loss * expand(dout[0], -1)
return grad, zeros_like(labels_indices), zeros_like(labels_values), zeros_like(sequence_length)
return bprop
......@@ -55,7 +55,7 @@ from .nn_ops import (LSTM, SGD, Adam, ApplyMomentum, BatchNorm,
DropoutDoMask, DropoutGrad, Dropout,
DropoutGenMask, Flatten, FusedBatchNorm,
Gelu, Elu,
GetNext, L2Normalize, LayerNorm, L2Loss,
GetNext, L2Normalize, LayerNorm, L2Loss, CTCLoss,
LogSoftmax,
MaxPool,
AvgPool, Conv2DBackpropInput, ConfusionMulGrad,
......@@ -172,6 +172,7 @@ __all__ = [
'Reciprocal',
'SmoothL1Loss',
'L2Loss',
'CTCLoss',
'ReduceAll',
'ScalarToArray',
'ScalarToTensor',
......
......@@ -1564,7 +1564,7 @@ class L2Loss(PrimitiveWithInfer):
def infer_dtype(self, x_type):
validator.check_subclass("x_type", x_type, mstype.tensor, self.name)
valid_types = [mstype.float16, mstype.float32, mstype.double]
valid_types = [mstype.float16, mstype.float32]
validator.check_tensor_type_same({'x_type': x_type}, valid_types, self.name)
return x_type
......@@ -2871,3 +2871,78 @@ class DropoutGrad(PrimitiveWithInfer):
valid_types = (mstype.float16, mstype.float32)
validator.check_tensor_type_same({"dy_dtype": dy_dtype}, valid_types, self.name)
return dy_dtype
class CTCLoss(PrimitiveWithInfer):
"""
Calculates the CTC(Connectionist Temporal Classification) loss. Also calculates the gradient.
Args:
preprocess_collapse_repeated (bool): If True, repeated labels are collapsed prior to the CTC calculation.
Default: False.
ctc_merge_repeated (bool): If False, during CTC calculation, repeated non-blank labels will not be merged
and are interpreted as individual labels. This is a simplfied version if CTC.
Default: True.
ignore_longer_outputs_than_inputs (bool): If True, sequences with longer outputs than inputs will be ignored.
Default: False.
Inputs:
- **inputs** (Tensor) - The input Tensor should be a `3-D` tensor whose shape is
:math:`(max_time, batch_size, num_class)`. `num_class` should be `num_labels + 1` classes, `num_labels`
indicates the number of actual labels. Blank labels are reserved.
- **labels_indices** (Tensor) - The indices of labels. `labels_indices[i, :] == [b, t]` means `labels_values[i]`
stores the id for `(batch b, time t)`. The type must be int64 and rank must be 2.
- **labels_values** (Tensor) - A `1-D` input tensor. The values associated with the given batch and time. The
type must be int32. `labels_values[i]` must in the range of `[0, num_class)`.
- **sequence_length** (Tensor) - A tensor containing sequence lengths with the shape of :math:`(batch_size)`.
The type must be int32. Each value in the tensor should not greater than `max_time`.
Outputs:
- **loss** (Tensor) - A tensor containing log-probabilities, the shape is :math:`(batch_size)`. Has the same
type with `inputs`.
- **gradient** (Tensor) - The gradient of `loss`. Has the same type and shape with `inputs`.
Examples:
>>> inputs = Tensor(np.random.random((2, 2, 3)), mindspore.float32)
>>> labels_indices = Tensor(np.array([[0, 0], [1, 0]]), mindspore.int64)
>>> labels_values = Tensor(np.array([2, 2]), mindspore.int32)
>>> sequence_length = Tensor(np.array([2, 2]), mindspore.int32)
>>> ctc_loss = P.CTCloss()
>>> output = ctc_loss(inputs, labels_indices, labels_values, sequence_length)
"""
@prim_attr_register
def __init__(self, preprocess_collapse_repeated=False, ctc_merge_repeated=False,
ignore_longer_outputs_than_inputs=False):
self.init_prim_io_names(inputs=["inputs", "labels_indices", "labels_values", "sequence_length"],
outputs=["loss", "gradient"])
validator.check_value_type("preprocess_collapse_repeated", preprocess_collapse_repeated, [bool], self.name)
self.preprocess_collapse_repeated_ = preprocess_collapse_repeated
self.ctc_merge_repeated_ = validator.check_value_type("ctc_merge_repeated", ctc_merge_repeated,
[bool], self.name)
validator.check_value_type("ignore_longer_outputs_than_inputs",
ignore_longer_outputs_than_inputs, [bool], self.name)
self.ignore_longer_outputs_than_inputs_ = ignore_longer_outputs_than_inputs
def infer_shape(self, inputs, labels_indices, labels_values, sequence_length):
validator.check_integer("inputs rank", len(inputs), 3, Rel.EQ, self.name)
validator.check_integer("labels_indices rank", len(labels_indices), 2, Rel.EQ, self.name)
validator.check_integer("labels_values rank", len(labels_values), 1, Rel.EQ, self.name)
validator.check_integer("sequence_length rank", len(sequence_length), 1, Rel.EQ, self.name)
validator.check('labels_indices size', labels_indices[0], 'labels_values size',
labels_values[0], Rel.EQ, self.name)
validator.check('inputs batch_size', inputs[1], 'sequence_length batch_size',
sequence_length[0], Rel.EQ, self.name)
batch_size = []
batch_size.append(inputs[1])
return batch_size, inputs
def infer_dtype(self, inputs, labels_indices, labels_values, sequence_length):
validator.check_subclass("inputs_dtype", inputs, mstype.tensor, self.name)
validator.check_subclass("labels_indices_dtype", labels_indices, mstype.tensor, self.name)
validator.check_subclass("labels_values_dtype", labels_values, mstype.tensor, self.name)
validator.check_subclass("sequence_length_dtype", sequence_length, mstype.tensor, self.name)
validator.check_tensor_type_same({"inputs_dtype": inputs}, [mstype.float32, mstype.double], self.name)
validator.check_tensor_type_same({"labels_indices_dtype": labels_indices}, [mstype.int64], self.name)
validator.check_tensor_type_same({"labels_values_dtype": labels_values}, [mstype.int32], self.name)
validator.check_tensor_type_same({"sequence_length_dtype": sequence_length}, [mstype.int32], self.name)
return inputs, inputs
......@@ -909,6 +909,13 @@ test_case_nn_ops = [
'desc_inputs': [[3, 3], [3, 3], [3, 3], [3, 3], [3, 3]],
'desc_bprop': [3, 3],
'skip': ['backward']}),
('CTCLoss', {
'block': P.CTCLoss(),
'desc_inputs': [Tensor(np.ones([6, 4, 6]).astype(np.float32)),
Tensor(np.array([[0, 1], [1, 0], [2, 3], [3, 2]]).astype(np.int64)),
Tensor(np.array([1, 2, 3, 4]).astype(np.int32)),
Tensor(np.array([6, 6, 6, 6]).astype(np.int32))],
'desc_bprop': [[4], [6, 4, 6]]}),
('L2Loss_1', {
'block': P.L2Loss(),
'desc_inputs': [Tensor(np.array([1, 2, 3, 4]), mstype.float32)],
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册