From 326fa176ea6401f171e9325aa29fb0b5cf6f7a29 Mon Sep 17 00:00:00 2001 From: wanghaoshuang Date: Sun, 4 Feb 2018 22:45:47 +0800 Subject: [PATCH] Fix empty output tensor and add an unitest case --- paddle/operators/ctc_align_op.cu | 8 ++++++++ paddle/operators/ctc_align_op.h | 9 ++++++++- python/paddle/v2/fluid/tests/test_ctc_align.py | 11 +++++++++++ 3 files changed, 27 insertions(+), 1 deletion(-) diff --git a/paddle/operators/ctc_align_op.cu b/paddle/operators/ctc_align_op.cu index 2a970cd9fa9..918df83effb 100644 --- a/paddle/operators/ctc_align_op.cu +++ b/paddle/operators/ctc_align_op.cu @@ -80,6 +80,14 @@ class CTCAlignOpCUDAKernel : public framework::OpKernel { // resize output dims output->Resize({static_cast(host_out_lod0.back()), 1}); + + if (host_out_lod0.back() == 0) { + output->Resize({1}); + output->mutable_data(ctx.GetPlace()); + math::SetConstant set_constant; + set_constant(ctx.template device_context(), + output, -1); + } } }; diff --git a/paddle/operators/ctc_align_op.h b/paddle/operators/ctc_align_op.h index fed89aa1e89..7a063870f3c 100644 --- a/paddle/operators/ctc_align_op.h +++ b/paddle/operators/ctc_align_op.h @@ -16,6 +16,8 @@ limitations under the License. */ #include #include "paddle/framework/op_registry.h" +#include "paddle/operators/math/math_function.h" + namespace paddle { namespace operators { @@ -65,9 +67,14 @@ class CTCAlignKernel : public framework::OpKernel { framework::LoD output_lod; output_lod.push_back(output_lod0); output->set_lod(output_lod); - // resize output dims output->Resize({static_cast(output_lod0.back()), 1}); + // for empty sequence + if (output_lod0.back() == 0) { + output->Resize({1}); + output_data = output->mutable_data(ctx.GetPlace()); + output_data[0] = -1; + } } }; diff --git a/python/paddle/v2/fluid/tests/test_ctc_align.py b/python/paddle/v2/fluid/tests/test_ctc_align.py index 773c69d1ad0..cc815d8e9e1 100644 --- a/python/paddle/v2/fluid/tests/test_ctc_align.py +++ b/python/paddle/v2/fluid/tests/test_ctc_align.py @@ -31,6 +31,8 @@ def CTCAlign(input, lod, blank, merge_repeated): result.append(token) prev_token = token result = np.array(result).reshape([len(result), 1]).astype("int32") + if len(result) == 0: + result = np.array([-1]) return result @@ -72,5 +74,14 @@ class TestCTCAlignOpCase1(TestCTCAlignOp): [19, 1]).astype("int32") +class TestCTCAlignOpCase2(TestCTCAlignOp): + def config(self): + self.op_type = "ctc_align" + self.input_lod = [[0, 4]] + self.blank = 0 + self.merge_repeated = True + self.input = np.array([0, 0, 0, 0]).reshape([4, 1]).astype("int32") + + if __name__ == "__main__": unittest.main() -- GitLab