未验证 提交 8a7daeea 编写于 作者: 乔龙飞 Qiao Longfei 提交者: GitHub

Merge pull request #16871 from jacquesqiao/fix-shape

fix split_byref_op infer shape
...@@ -31,14 +31,16 @@ class SplitByrefOp : public framework::OperatorWithKernel { ...@@ -31,14 +31,16 @@ class SplitByrefOp : public framework::OperatorWithKernel {
auto in_dims = ctx->GetInputDim("X"); auto in_dims = ctx->GetInputDim("X");
auto outs_names = ctx->Outputs("Out"); auto outs_names = ctx->Outputs("Out");
size_t num = static_cast<size_t>(ctx->Attrs().Get<int>("num")); size_t num = static_cast<size_t>(ctx->Attrs().Get<int>("num"));
std::vector<int> sections = static_cast<std::vector<int>>( auto sections = ctx->Attrs().Get<std::vector<int>>("sections");
ctx->Attrs().Get<std::vector<int>>("sections"));
const size_t outs_number = outs_names.size(); const size_t outs_number = outs_names.size();
std::vector<framework::DDim> outs_dims; std::vector<framework::DDim> outs_dims;
outs_dims.reserve(outs_number); outs_dims.reserve(outs_number);
if (num > 0) { if (num > 0) {
int64_t in_axis_dim = in_dims[0]; int64_t in_axis_dim = 0;
if (ctx->IsRuntime()) {
in_axis_dim = in_dims[0];
}
PADDLE_ENFORCE_EQ(in_axis_dim % num, 0, PADDLE_ENFORCE_EQ(in_axis_dim % num, 0,
"tensor split does not result" "tensor split does not result"
" in an equal division"); " in an equal division");
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册