import mxnet as mx


def _add_warp_ctc_loss(pred, seq_len, num_label, label):
    """Add ``mx.sym.WarpCTC`` on top of ``pred`` and return the resulting symbol.

    Parameters
    ----------
    pred : mx.sym.Symbol
        Network output, passed to WarpCTC as ``data``.
    seq_len : int
        Passed to WarpCTC as ``input_length``.
    num_label : int
        Passed to WarpCTC as ``label_length``.
    label : mx.sym.Symbol
        Label symbol; flattened to 1-D and cast to int32 before use.

    Returns
    -------
    mx.sym.Symbol
        The WarpCTC loss symbol.
    """
    # WarpCTC is fed the labels as a single flat int32 vector.
    label = mx.sym.Reshape(data=label, shape=(-1,))
    label = mx.sym.Cast(data=label, dtype='int32')
    return mx.sym.WarpCTC(data=pred, label=label, label_length=num_label,
                          input_length=seq_len)


def _add_mxnet_ctc_loss(pred, seq_len, label):
    """Add ``mx.sym.contrib.ctc_loss`` on top of ``pred`` and return a group.

    Parameters
    ----------
    pred : mx.sym.Symbol
        2-D network output. Its first axis is split into ``(seq_len, -1)``
        by the reshape below, so its length must be divisible by
        ``seq_len``. Presumably laid out as (seq_len * batch, num_classes),
        time-major; TODO confirm against the caller.
    seq_len : int
        Number of time steps; used as a literal in the reshape target shape.
    label : mx.sym.Symbol
        Label symbol passed unchanged to ``ctc_loss``.

    Returns
    -------
    mx.sym.Symbol
        ``mx.sym.Group([softmax_loss, ctc_loss])``. The first output is the
        softmax of ``pred`` with gradients blocked. The second output is the
        CTC loss wrapped in ``MakeLoss``.
    """
    # Reshape code -4 splits axis 0 into (seq_len, -1); 0 keeps the last
    # axis as-is, giving a 3-D (seq_len, -1, C) input for ctc_loss.
    pred_ctc = mx.sym.Reshape(data=pred, shape=(-4, seq_len, -1, 0))

    loss = mx.sym.contrib.ctc_loss(data=pred_ctc, label=label)
    ctc_loss = mx.sym.MakeLoss(loss)

    # Expose the softmax probabilities as an extra output. BlockGrad keeps
    # this branch from contributing gradients; only the CTC loss trains.
    softmax_class = mx.symbol.SoftmaxActivation(data=pred)
    softmax_loss = mx.sym.MakeLoss(softmax_class)
    softmax_loss = mx.sym.BlockGrad(softmax_loss)
    return mx.sym.Group([softmax_loss, ctc_loss])


def add_ctc_loss(pred, seq_len, num_label, loss_type):
    """Add a CTC loss on top of ``pred`` and return the resulting symbol.

    A fresh ``mx.sym.Variable('label')`` is created for the labels.

    Parameters
    ----------
    pred : mx.sym.Symbol
        Network output to attach the loss to.
    seq_len : int
        Number of time steps.
    num_label : int
        Label length; only used by the ``'warpctc'`` path.
    loss_type : str
        ``'warpctc'`` to use ``mx.sym.WarpCTC``; ``'ctc'`` to use
        ``mx.sym.contrib.ctc_loss``.

    Returns
    -------
    mx.sym.Symbol
        The WarpCTC loss symbol for ``'warpctc'``. For ``'ctc'``, a group of
        (gradient-blocked softmax, CTC loss).

    Raises
    ------
    ValueError
        If ``loss_type`` is neither ``'warpctc'`` nor ``'ctc'``.
    """
    # Validate explicitly rather than with ``assert``: assertions are
    # stripped under ``python -O``, which would silently route any unknown
    # loss_type to the MXNet CTC path. Checking first also avoids printing
    # a misleading "Using ..." message for an invalid value.
    if loss_type not in ('warpctc', 'ctc'):
        raise ValueError(
            "loss_type must be 'warpctc' or 'ctc', got {!r}".format(loss_type))

    label = mx.sym.Variable('label')
    if loss_type == 'warpctc':
        print("Using WarpCTC Loss")
        sm = _add_warp_ctc_loss(pred, seq_len, num_label, label)
    else:
        print("Using MXNet CTC Loss")
        sm = _add_mxnet_ctc_loss(pred, seq_len, label)
    return sm