未验证 提交 dd858051 编写于 作者: L liu zhengxi 提交者: GitHub

Fix beam_search InferShape (#25169) (#25216)

* fix beam_search infershape, test=develop

* fix beam search op unittest, test=develop
上级 772746c0
......@@ -95,6 +95,10 @@ class BeamSearchOp : public framework::OperatorWithKernel {
std::vector<std::string>({"selected_ids", "selected_scores"})) {
OP_INOUT_CHECK(ctx->HasOutput(arg), "Output", arg, "BeamSeach");
}
auto id_dims = ctx->GetInputDim("pre_ids");
ctx->SetOutputDim("selected_scores", ctx->GetInputDim("pre_scores"));
ctx->SetOutputDim("selected_ids", id_dims);
ctx->SetOutputDim("parent_idx", {id_dims[0]});
}
protected:
......
......@@ -38,9 +38,9 @@ class BeamSearchOpTester(unittest.TestCase):
self._create_pre_scores()
self._create_scores()
self._create_pre_ids()
self.scope.var('selected_ids')
self.scope.var('selected_scores')
self.scope.var('parent_idx')
self.scope.var('selected_ids').get_tensor()
self.scope.var('selected_scores').get_tensor()
self.scope.var('parent_idx').get_tensor()
def test_run(self):
op = Operator(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册