未验证 提交 898acb1a 编写于 作者: L Leo Chen 提交者: GitHub

fix split on empty tensor (#34356)

上级 0dff82c2
......@@ -83,6 +83,12 @@ class SplitFunctor<platform::CPUDeviceContext, T> {
const framework::Tensor& input,
const std::vector<const framework::Tensor*>& ref_inputs,
const int axis, std::vector<framework::Tensor*>* outputs) {
// NOTE(zhiqiu): split a tensor of shape [0,3,4] at axis=1, result in 3
// tensors of shape [0,1,4]
if (input.numel() == 0) {
return;
}
// TODO(zcd): Add input data validity checking
size_t num = outputs->size();
......
......@@ -352,6 +352,12 @@ class SplitFunctor<platform::CUDADeviceContext, T> {
const framework::Tensor& input,
const std::vector<const framework::Tensor*>& ref_inputs,
int axis, std::vector<framework::Tensor*>* outputs) {
// NOTE(zhiqiu): split a tensor of shape [0,3,4] at axis=1, result in 3
// tensors of shape [0,1,4]
if (input.numel() == 0) {
return;
}
// TODO(zcd): Add input data validity checking
int o_num = outputs->size();
int64_t out_row = 1;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册