diff --git a/paddle/fluid/operators/math/concat_and_split.cc b/paddle/fluid/operators/math/concat_and_split.cc index 6c1ee863737011e7cc9561b18b974d1dc280fb73..83b4e89fe046f4c30528c521577406d12d0b07ec 100644 --- a/paddle/fluid/operators/math/concat_and_split.cc +++ b/paddle/fluid/operators/math/concat_and_split.cc @@ -83,6 +83,12 @@ class SplitFunctor { const framework::Tensor& input, const std::vector& ref_inputs, const int axis, std::vector* outputs) { + // NOTE(zhiqiu): split a tensor of shape [0,3,4] at axis=1, result in 3 + // tensors of shape [0,1,4] + if (input.numel() == 0) { + return; + } + // TODO(zcd): Add input data validity checking size_t num = outputs->size(); diff --git a/paddle/fluid/operators/math/concat_and_split.cu b/paddle/fluid/operators/math/concat_and_split.cu index f9cce061383939b0819f2af788072286bdeaf09f..b9481f1c8e40e2bffe75eb7ec6f8e9410b98be02 100644 --- a/paddle/fluid/operators/math/concat_and_split.cu +++ b/paddle/fluid/operators/math/concat_and_split.cu @@ -352,6 +352,12 @@ class SplitFunctor { const framework::Tensor& input, const std::vector& ref_inputs, int axis, std::vector* outputs) { + // NOTE(zhiqiu): split a tensor of shape [0,3,4] at axis=1, result in 3 + // tensors of shape [0,1,4] + if (input.numel() == 0) { + return; + } + // TODO(zcd): Add input data validity checking int o_num = outputs->size(); int64_t out_row = 1;