From dd85805131f5a48dc09857c52d9c5bbe730385c7 Mon Sep 17 00:00:00 2001 From: liu zhengxi <380185688@qq.com> Date: Sun, 28 Jun 2020 19:53:10 +0800 Subject: [PATCH] Fix beam_search InferShape (#25169) (#25216) * fix beam_search infershape, test=develop * fix beam search op unittest, test=develop --- paddle/fluid/operators/beam_search_op.cc | 4 ++++ python/paddle/fluid/tests/unittests/test_beam_search_op.py | 6 +++--- 2 files changed, 7 insertions(+), 3 deletions(-) diff --git a/paddle/fluid/operators/beam_search_op.cc b/paddle/fluid/operators/beam_search_op.cc index c6715227cf..887d28f587 100644 --- a/paddle/fluid/operators/beam_search_op.cc +++ b/paddle/fluid/operators/beam_search_op.cc @@ -95,6 +95,10 @@ class BeamSearchOp : public framework::OperatorWithKernel { std::vector({"selected_ids", "selected_scores"})) { OP_INOUT_CHECK(ctx->HasOutput(arg), "Output", arg, "BeamSeach"); } + auto id_dims = ctx->GetInputDim("pre_ids"); + ctx->SetOutputDim("selected_scores", ctx->GetInputDim("pre_scores")); + ctx->SetOutputDim("selected_ids", id_dims); + ctx->SetOutputDim("parent_idx", {id_dims[0]}); } protected: diff --git a/python/paddle/fluid/tests/unittests/test_beam_search_op.py b/python/paddle/fluid/tests/unittests/test_beam_search_op.py index faf085eb6c..346cd1e212 100644 --- a/python/paddle/fluid/tests/unittests/test_beam_search_op.py +++ b/python/paddle/fluid/tests/unittests/test_beam_search_op.py @@ -38,9 +38,9 @@ class BeamSearchOpTester(unittest.TestCase): self._create_pre_scores() self._create_scores() self._create_pre_ids() - self.scope.var('selected_ids') - self.scope.var('selected_scores') - self.scope.var('parent_idx') + self.scope.var('selected_ids').get_tensor() + self.scope.var('selected_scores').get_tensor() + self.scope.var('parent_idx').get_tensor() def test_run(self): op = Operator( -- GitLab