Commit 578a2f5d authored by Leo Chen, committed by Zeng Jinle

fix SplitLodTensor when batch_size = 0, test=develop (#19866)

Parent commit: b125e327
@@ -283,6 +283,21 @@ std::vector<LoDTensor> LoDTensor::SplitLoDTensor(
  std::vector<LoDTensor> results;
  results.reserve(result_size);

  // if result_size(batch_size) is 0, just return #places.size() copies of
  // empty tensors.
  if (result_size == 0) {
    for (size_t i = 0; i < places.size(); ++i) {
      LoDTensor dst;
      dst.Resize(dims());
      dst.mutable_data(places[i], type());
      if (!lod().empty()) {
        dst.set_lod(lod());
      }
      results.emplace_back(dst);
    }
    return results;
  }

  int step_width = static_cast<int>(batch_size / result_size);
  for (size_t i = 0; i < result_size; ++i) {
    int begin = static_cast<int>(i * step_width);
......
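For reference, a minimal standalone sketch (not part of the commit) of the behavior this hunk adds: splitting a LoDTensor whose batch dimension is 0 across two places returns one empty tensor per place. It only relies on calls that appear in the diff (Resize, mutable_data, SplitLoDTensor); the include paths, namespace aliases, and the helper function name are assumptions for illustration.

#include <vector>
#include "paddle/fluid/framework/lod_tensor.h"  // assumed include path
#include "paddle/fluid/platform/place.h"        // assumed include path

namespace framework = paddle::framework;
namespace platform = paddle::platform;

// Hypothetical helper (illustration only) that exercises the new
// result_size == 0 branch above.
void SplitEmptyBatchSketch() {
  framework::LoDTensor src;
  src.Resize({0, 5});  // batch dimension is 0
  src.mutable_data<float>(platform::CPUPlace());

  std::vector<platform::Place> places{platform::CPUPlace(),
                                      platform::CPUPlace()};

  // With the fix, SplitLoDTensor returns places.size() empty tensors instead
  // of reaching the step_width = batch_size / result_size computation with
  // result_size == 0.
  std::vector<framework::LoDTensor> parts = src.SplitLoDTensor(places);
  // parts.size() == 2 and parts[i].dims()[0] == 0 for each part.
}

The early return sets dims and dtype (and, when present, the LoD) on each empty copy, so every place still receives a tensor with consistent metadata even when a batch is empty.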
@@ -155,6 +155,26 @@ TEST(LoD, SplitLoDTensor) {
  EXPECT_EQ(lods[1].lod(), lod1);
}

TEST(LoD, SplitLoDTensorWithZeroBatchSize) {
  LoD lod;
  lod.push_back(std::vector<size_t>({0}));

  platform::CPUPlace place;
  LoDTensor lod_tensor;
  lod_tensor.Resize({0, 5});
  lod_tensor.mutable_data<float>(place);
  lod_tensor.set_lod(lod);

  std::vector<platform::Place> places{platform::CPUPlace(),
                                      platform::CPUPlace()};
  LoD lod_res;
  lod_res.push_back(std::vector<size_t>({0}));

  auto lods = lod_tensor.SplitLoDTensor(places);
  EXPECT_EQ(lods[0].lod(), lod_res);
  EXPECT_EQ(lods[1].lod(), lod_res);
}

TEST(LoD, MergeLoDTensor) {
  LoD lod;
  lod.push_back(std::vector<size_t>({0, 2, 4, 5, 6}));
......