未验证 提交 c96ee47d 编写于 作者: H Hongyu Liu 提交者: GitHub

Merge pull request #16797 from phlrain/fix_split

Fix split
...@@ -39,14 +39,22 @@ class SplitOp : public framework::OperatorWithKernel { ...@@ -39,14 +39,22 @@ class SplitOp : public framework::OperatorWithKernel {
if (num > 0) { if (num > 0) {
int64_t in_axis_dim = in_dims[axis]; int64_t in_axis_dim = in_dims[axis];
PADDLE_ENFORCE_EQ(in_axis_dim % num, 0, if (ctx->IsRuntime() || in_axis_dim > 0) {
"tensor split does not result" PADDLE_ENFORCE_EQ(in_axis_dim % num, 0,
" in an equal division"); "tensor split does not result"
size_t out_axis_dim = in_axis_dim / num; " in an equal division");
for (size_t i = 0; i < outs_number; ++i) { size_t out_axis_dim = in_axis_dim / num;
auto dim = in_dims; for (size_t i = 0; i < outs_number; ++i) {
dim[axis] = out_axis_dim; auto dim = in_dims;
outs_dims.push_back(dim); dim[axis] = out_axis_dim;
outs_dims.push_back(dim);
}
} else {
for (size_t i = 0; i < outs_number; ++i) {
auto dim = in_dims;
dim[axis] = -1;
outs_dims.push_back(dim);
}
} }
} else if (sections.size() > 0) { } else if (sections.size() > 0) {
PADDLE_ENFORCE_EQ(sections.size(), outs_number, PADDLE_ENFORCE_EQ(sections.size(), outs_number,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册