未验证 提交 68e93d8a 编写于 作者: L liu zhengxi 提交者: GitHub

Fix beam_search InferShape (#25169)

* fix beam_search infershape, test=develop

* fix beam search op unittest, test=develop
上级 cd4d9122
...@@ -95,6 +95,10 @@ class BeamSearchOp : public framework::OperatorWithKernel { ...@@ -95,6 +95,10 @@ class BeamSearchOp : public framework::OperatorWithKernel {
std::vector<std::string>({"selected_ids", "selected_scores"})) { std::vector<std::string>({"selected_ids", "selected_scores"})) {
OP_INOUT_CHECK(ctx->HasOutput(arg), "Output", arg, "BeamSeach"); OP_INOUT_CHECK(ctx->HasOutput(arg), "Output", arg, "BeamSeach");
} }
auto id_dims = ctx->GetInputDim("pre_ids");
ctx->SetOutputDim("selected_scores", ctx->GetInputDim("pre_scores"));
ctx->SetOutputDim("selected_ids", id_dims);
ctx->SetOutputDim("parent_idx", {id_dims[0]});
} }
protected: protected:
......
...@@ -38,9 +38,9 @@ class BeamSearchOpTester(unittest.TestCase): ...@@ -38,9 +38,9 @@ class BeamSearchOpTester(unittest.TestCase):
self._create_pre_scores() self._create_pre_scores()
self._create_scores() self._create_scores()
self._create_pre_ids() self._create_pre_ids()
self.scope.var('selected_ids') self.scope.var('selected_ids').get_tensor()
self.scope.var('selected_scores') self.scope.var('selected_scores').get_tensor()
self.scope.var('parent_idx') self.scope.var('parent_idx').get_tensor()
def test_run(self): def test_run(self):
op = Operator( op = Operator(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册