From c60f039076eafc678a807088f1872fb0f437c573 Mon Sep 17 00:00:00 2001 From: Li Fuchen Date: Thu, 24 Sep 2020 10:24:18 +0800 Subject: [PATCH] modified doc of warpctc ctc_loss and CTCLoss (#2668) * add cn doc of trace API * modified doc of warpctc ctc_loss and CTCLoss, add support float64 of log_probs --- doc/fluid/api_cn/layers_cn/warpctc_cn.rst | 2 +- doc/fluid/api_cn/nn_cn/functional_cn/ctc_loss_cn.rst | 2 +- doc/fluid/api_cn/nn_cn/loss_cn/CTCLoss_cn.rst | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/doc/fluid/api_cn/layers_cn/warpctc_cn.rst b/doc/fluid/api_cn/layers_cn/warpctc_cn.rst index dbb4c4ca8..5295e1586 100644 --- a/doc/fluid/api_cn/layers_cn/warpctc_cn.rst +++ b/doc/fluid/api_cn/layers_cn/warpctc_cn.rst @@ -14,7 +14,7 @@ warpctc 该OP用于计算 `CTC loss `_ 。该OP的底层调用了第三方 `baidu-research::warp-ctc `_ 的实现。 参数: - - **input** (Variable) - 可以是3-D Tensor或2-D LoDTensor。当输入类型是3-D Tensor时,则表示输入是经过padding的定长序列,其 shape 必须是 ``[seq_length, batch_size, num_classes + 1]`` 。当输入类型是2-D LoDTensor时,则表示输入为变长序列,其shape必须为 ``[Lp,num_classes+1]`` , ``Lp`` 是所有输入序列长度之和。以上 shape 中的 ``num_classes`` 是实际类别数,不包括空白标签。该输入不需要经过 softmax 操作,因为该OP的内部对 ``input`` 做了 softmax 操作。数据类型仅支持float32。 + - **input** (Variable) - 可以是3-D Tensor或2-D LoDTensor。当输入类型是3-D Tensor时,则表示输入是经过padding的定长序列,其 shape 必须是 ``[seq_length, batch_size, num_classes + 1]`` 。当输入类型是2-D LoDTensor时,则表示输入为变长序列,其shape必须为 ``[Lp,num_classes+1]`` , ``Lp`` 是所有输入序列长度之和。以上 shape 中的 ``num_classes`` 是实际类别数,不包括空白标签。该输入不需要经过 softmax 操作,因为该OP的内部对 ``input`` 做了 softmax 操作。数据类型支持 float32 和 float64。 - **label** (Variable) - 可以是3-D Tensor或2-D LoDTensor,需要跟 ``input`` 保持一致。当输入类型为3-D Tensor时,表示输入是经过 padding 的定长序列,其 shape 为 ``[batch_size, label_length]`` ,其中, ``label_length`` 是最长的 label 序列的长度。当输入类型是2-D LoDTensor时,则表示输入为变长序列,其shape必须为 ``[Lp, 1]`` , 其中 ``Lp`` 是所有 label 序列的长度和。 ``label`` 中的数值为字符ID。数据类型支持int32。 - **blank** (int,可选) - 空格标记的ID,其取值范围为 ``[0,num_classes+1)`` 。数据类型支持int32。缺省值为0。 - **norm_by_times** (bool,可选) - 是否根据序列长度对梯度进行正则化。数据类型支持 bool 。缺省值为False。 diff --git a/doc/fluid/api_cn/nn_cn/functional_cn/ctc_loss_cn.rst b/doc/fluid/api_cn/nn_cn/functional_cn/ctc_loss_cn.rst index 2722525a6..292b6d82f 100644 --- a/doc/fluid/api_cn/nn_cn/functional_cn/ctc_loss_cn.rst +++ b/doc/fluid/api_cn/nn_cn/functional_cn/ctc_loss_cn.rst @@ -8,7 +8,7 @@ ctc_loss 参数 ::::::::: - - **log_probs** (Tensor): - 经过 padding 的概率序列,其 shape 必须是 [max_logit_length, batch_size, num_classes + 1]。其中 max_logit_length 是最长输入序列的长度。该输入不需要经过 softmax 操作,因为该 OP 的内部对 input 做了 softmax 操作。数据类型仅支持float32。 + - **log_probs** (Tensor): - 经过 padding 的概率序列,其 shape 必须是 [max_logit_length, batch_size, num_classes + 1]。其中 max_logit_length 是最长输入序列的长度。该输入不需要经过 softmax 操作,因为该 OP 的内部对 input 做了 softmax 操作。数据类型支持 float32 和 float64。 - **labels** (Tensor): - 经过 padding 的标签序列,其 shape 为 [batch_size, max_label_length],其中 max_label_length 是最长的 label 序列的长度。数据类型支持int32。 - **input_lengths** (Tensor): - 表示输入 ``log_probs`` 数据中每个序列的长度,shape为 [batch_size] 。数据类型支持int64。 - **label_lengths** (Tensor): - 表示 label 中每个序列的长度,shape为 [batch_size] 。数据类型支持int64。 diff --git a/doc/fluid/api_cn/nn_cn/loss_cn/CTCLoss_cn.rst b/doc/fluid/api_cn/nn_cn/loss_cn/CTCLoss_cn.rst index a2ad0c28a..708f57971 100644 --- a/doc/fluid/api_cn/nn_cn/loss_cn/CTCLoss_cn.rst +++ b/doc/fluid/api_cn/nn_cn/loss_cn/CTCLoss_cn.rst @@ -13,7 +13,7 @@ CTCLoss 形状 ::::::::: - - **log_probs** (Tensor): - 经过 padding 的概率序列,其 shape 必须是 [max_logit_length, batch_size, num_classes + 1]。其中 max_logit_length 是最长输入序列的长度。该输入不需要经过 softmax 操作,因为该 OP 的内部对 input 做了 softmax 操作。数据类型仅支持float32。 + - **log_probs** (Tensor): - 经过 padding 的概率序列,其 shape 必须是 [max_logit_length, batch_size, num_classes + 1]。其中 max_logit_length 是最长输入序列的长度。该输入不需要经过 softmax 操作,因为该 OP 的内部对 input 做了 softmax 操作。数据类型支持 float32 和 float64。 - **labels** (Tensor): - 经过 padding 的标签序列,其 shape 为 [batch_size, max_label_length],其中 max_label_length 是最长的 label 序列的长度。数据类型支持int32。 - **input_lengths** (Tensor): - 表示输入 ``log_probs`` 数据中每个序列的长度,shape为 [batch_size] 。数据类型支持int64。 - **label_lengths** (Tensor): - 表示 label 中每个序列的长度,shape为 [batch_size] 。数据类型支持int64。 -- GitLab