From 898acb1a132540e1a3d4e4954861df6bcc081a1a Mon Sep 17 00:00:00 2001 From: Leo Chen Date: Mon, 9 Aug 2021 11:11:33 +0800 Subject: [PATCH] fix split on empty tensor (#34356) --- paddle/fluid/operators/math/concat_and_split.cc | 6 ++++++ paddle/fluid/operators/math/concat_and_split.cu | 6 ++++++ 2 files changed, 12 insertions(+) diff --git a/paddle/fluid/operators/math/concat_and_split.cc b/paddle/fluid/operators/math/concat_and_split.cc index 6c1ee863737..83b4e89fe04 100644 --- a/paddle/fluid/operators/math/concat_and_split.cc +++ b/paddle/fluid/operators/math/concat_and_split.cc @@ -83,6 +83,12 @@ class SplitFunctor { const framework::Tensor& input, const std::vector& ref_inputs, const int axis, std::vector* outputs) { + // NOTE(zhiqiu): split a tensor of shape [0,3,4] at axis=1, result in 3 + // tensors of shape [0,1,4] + if (input.numel() == 0) { + return; + } + // TODO(zcd): Add input data validity checking size_t num = outputs->size(); diff --git a/paddle/fluid/operators/math/concat_and_split.cu b/paddle/fluid/operators/math/concat_and_split.cu index f9cce061383..b9481f1c8e4 100644 --- a/paddle/fluid/operators/math/concat_and_split.cu +++ b/paddle/fluid/operators/math/concat_and_split.cu @@ -352,6 +352,12 @@ class SplitFunctor { const framework::Tensor& input, const std::vector& ref_inputs, int axis, std::vector* outputs) { + // NOTE(zhiqiu): split a tensor of shape [0,3,4] at axis=1, result in 3 + // tensors of shape [0,1,4] + if (input.numel() == 0) { + return; + } + // TODO(zcd): Add input data validity checking int o_num = outputs->size(); int64_t out_row = 1; -- GitLab