未验证 提交 6612068e 编写于 作者: W whs 提交者: GitHub

Merge pull request #8114 from wanghaoshuang/fix_ctc_align

Make CTC align op support for empty output
......@@ -80,6 +80,14 @@ class CTCAlignOpCUDAKernel : public framework::OpKernel<T> {
// resize output dims
output->Resize({static_cast<int64_t>(host_out_lod0.back()), 1});
if (host_out_lod0.back() == 0) {
output->Resize({1, 1});
output->mutable_data<T>(ctx.GetPlace());
math::SetConstant<platform::CUDADeviceContext, T> set_constant;
set_constant(ctx.template device_context<platform::CUDADeviceContext>(),
output, -1);
}
}
};
......
......@@ -16,6 +16,8 @@ limitations under the License. */
#include <string.h>
#include "paddle/framework/op_registry.h"
#include "paddle/operators/math/math_function.h"
namespace paddle {
namespace operators {
......@@ -65,9 +67,14 @@ class CTCAlignKernel : public framework::OpKernel<T> {
framework::LoD output_lod;
output_lod.push_back(output_lod0);
output->set_lod(output_lod);
// resize output dims
output->Resize({static_cast<int64_t>(output_lod0.back()), 1});
// for empty sequence
if (output_lod0.back() == 0) {
output->Resize({1, 1});
output_data = output->mutable_data<T>(ctx.GetPlace());
output_data[0] = -1;
}
}
};
......
......@@ -2525,7 +2525,8 @@ def ctc_greedy_decoder(input, blank, name=None):
interval [0, num_classes + 1).
Returns:
Variable: CTC greedy decode result.
Variable: CTC greedy decode result. If all the sequences in result were
empty, the result LoDTensor will be [-1] with LoD [[0]] and dims [1, 1].
Examples:
.. code-block:: python
......
......@@ -31,6 +31,8 @@ def CTCAlign(input, lod, blank, merge_repeated):
result.append(token)
prev_token = token
result = np.array(result).reshape([len(result), 1]).astype("int32")
if len(result) == 0:
result = np.array([-1])
return result
......@@ -72,5 +74,14 @@ class TestCTCAlignOpCase1(TestCTCAlignOp):
[19, 1]).astype("int32")
class TestCTCAlignOpCase2(TestCTCAlignOp):
def config(self):
self.op_type = "ctc_align"
self.input_lod = [[0, 4]]
self.blank = 0
self.merge_repeated = True
self.input = np.array([0, 0, 0, 0]).reshape([4, 1]).astype("int32")
if __name__ == "__main__":
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册