diff --git a/paddle/fluid/operators/beam_search_op.cc b/paddle/fluid/operators/beam_search_op.cc index c6715227cf90aabe3a78058d6a98a3e332fd40dd..887d28f5875e366503e67ac21f78f846f6e21a1a 100644 --- a/paddle/fluid/operators/beam_search_op.cc +++ b/paddle/fluid/operators/beam_search_op.cc @@ -95,6 +95,10 @@ class BeamSearchOp : public framework::OperatorWithKernel { std::vector({"selected_ids", "selected_scores"})) { OP_INOUT_CHECK(ctx->HasOutput(arg), "Output", arg, "BeamSeach"); } + auto id_dims = ctx->GetInputDim("pre_ids"); + ctx->SetOutputDim("selected_scores", ctx->GetInputDim("pre_scores")); + ctx->SetOutputDim("selected_ids", id_dims); + ctx->SetOutputDim("parent_idx", {id_dims[0]}); } protected: diff --git a/python/paddle/fluid/tests/unittests/test_beam_search_op.py b/python/paddle/fluid/tests/unittests/test_beam_search_op.py index faf085eb6ca5d5896e9e11360e9f767a316edce3..346cd1e21291887630d79f948b1db28d76a4e2ac 100644 --- a/python/paddle/fluid/tests/unittests/test_beam_search_op.py +++ b/python/paddle/fluid/tests/unittests/test_beam_search_op.py @@ -38,9 +38,9 @@ class BeamSearchOpTester(unittest.TestCase): self._create_pre_scores() self._create_scores() self._create_pre_ids() - self.scope.var('selected_ids') - self.scope.var('selected_scores') - self.scope.var('parent_idx') + self.scope.var('selected_ids').get_tensor() + self.scope.var('selected_scores').get_tensor() + self.scope.var('parent_idx').get_tensor() def test_run(self): op = Operator(