未验证 提交 d5e7d20d 编写于 作者: J jakpiase 提交者: GitHub

minor split optimization (#47314)

上级 0d04bfe1
...@@ -69,10 +69,8 @@ void SplitWithNumKernel(const Context& dev_ctx, ...@@ -69,10 +69,8 @@ void SplitWithNumKernel(const Context& dev_ctx,
std::vector<DenseTensor*> outs) { std::vector<DenseTensor*> outs) {
int axis_value = axis_scalar.to<int>(); int axis_value = axis_scalar.to<int>();
auto input_axis_dim = x.dims().at(axis_value); auto input_axis_dim = x.dims().at(axis_value);
std::vector<int64_t> sections_vec; const std::vector<int64_t> sections_vec(num, input_axis_dim / num);
for (int i = 0; i < num; ++i) {
sections_vec.push_back(input_axis_dim / num);
}
IntArray sections(sections_vec); IntArray sections(sections_vec);
SplitKernel<T, Context>(dev_ctx, x, sections, axis_scalar, outs); SplitKernel<T, Context>(dev_ctx, x, sections, axis_scalar, outs);
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册