未验证 提交 fb1a0dfb 编写于 作者: Y Yang yaming 提交者: GitHub

Merge pull request #7943 from pkuyym/fix-7939

Bug fix for sequence_reshape operator.
...@@ -30,8 +30,13 @@ class SequenceReshapeOp : public framework::OperatorWithKernel { ...@@ -30,8 +30,13 @@ class SequenceReshapeOp : public framework::OperatorWithKernel {
auto x_numel = product(x_dims); auto x_numel = product(x_dims);
PADDLE_ENFORCE_EQ(x_dims.size(), 2U, "Rank of Input(X) should be 2."); PADDLE_ENFORCE_EQ(x_dims.size(), 2U, "Rank of Input(X) should be 2.");
int new_dim = ctx->Attrs().Get<int>("new_dim"); int new_dim = ctx->Attrs().Get<int>("new_dim");
if (ctx->IsRuntime()) {
ctx->SetOutputDim("Out", ctx->SetOutputDim("Out",
{x_numel / new_dim, static_cast<int64_t>(new_dim)}); {x_numel / new_dim, static_cast<int64_t>(new_dim)});
} else {
// when compiling, the batch size is undetermined, just set to -1
ctx->SetOutputDim("Out", {-1, static_cast<int64_t>(new_dim)});
}
} }
}; };
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册