提交 28b240bb 编写于 作者: Y Yang Yang

delete todo in MergeLoDTensor

上级 208f950c
...@@ -262,24 +262,38 @@ std::vector<LoDTensor> LoDTensor::SplitLoDTensor( ...@@ -262,24 +262,38 @@ std::vector<LoDTensor> LoDTensor::SplitLoDTensor(
return lods; return lods;
} }
// TODO(tonyyang-svail): make this function support LoD
void LoDTensor::MergeLoDTensor( void LoDTensor::MergeLoDTensor(
const std::vector<const LoDTensor *> &lod_tensors, const std::vector<const LoDTensor *> &lod_tensors,
platform::Place dst_place) { platform::Place dst_place) {
PADDLE_ENFORCE(!lod_tensors.empty()); PADDLE_ENFORCE(!lod_tensors.empty());
framework::DDim new_dim = lod_tensors[0]->dims(); framework::DDim new_dim = lod_tensors[0]->dims();
std::type_index new_type = lod_tensors[0]->type(); std::type_index new_type = lod_tensors[0]->type();
auto new_layout = lod_tensors[0]->layout(); framework::DataLayout new_layout = lod_tensors[0]->layout();
for (auto *lod : lod_tensors) { LoD new_lod = lod_tensors[0]->lod();
PADDLE_ENFORCE(new_dim == lod->dims()); for (size_t i = 1; i < lod_tensors.size(); ++i) {
PADDLE_ENFORCE(new_type == lod->type()); auto *t = lod_tensors[i];
PADDLE_ENFORCE(new_layout == lod->layout()); PADDLE_ENFORCE_EQ(new_type.hash_code(), t->type().hash_code());
PADDLE_ENFORCE_EQ(new_layout, t->layout());
PADDLE_ENFORCE_EQ(framework::product(new_dim) / new_dim[0],
framework::product(t->dims()) / t->dims()[0]);
new_dim[0] += t->dims()[0];
auto &lod = t->lod();
for (size_t j = 0; j < lod.size(); ++j) {
auto &sub_lod = new_lod[j];
auto &offset = sub_lod.back();
for (size_t k = 1; k < lod[j].size(); ++k) {
sub_lod.push_back(lod[j][k] + offset);
}
}
} }
new_dim[0] *= lod_tensors.size();
Resize(new_dim); Resize(new_dim);
set_layout(new_layout); set_layout(new_layout);
set_lod(new_lod);
mutable_data(dst_place, new_type); mutable_data(dst_place, new_type);
int begin = 0; int begin = 0;
for (auto *src : lod_tensors) { for (auto *src : lod_tensors) {
int end = begin + src->dims()[0]; int end = begin + src->dims()[0];
......
...@@ -159,5 +159,42 @@ TEST(LoD, SplitLoDTensor) { ...@@ -159,5 +159,42 @@ TEST(LoD, SplitLoDTensor) {
EXPECT_EQ(lods[1].lod(), lod1); EXPECT_EQ(lods[1].lod(), lod1);
} }
TEST(LoD, MergeLoDTensor) {
  // Expected LoD after merging: tensor_b's offsets are appended to
  // tensor_a's at each level, shifted by tensor_a's final offset
  // (4 and 13 respectively for the two levels).
  LoD expected_lod;
  expected_lod.push_back(std::vector<size_t>({0, 2, 4, 5, 6}));
  expected_lod.push_back(std::vector<size_t>({0, 1, 6, 8, 13, 15, 20}));

  platform::CPUPlace cpu;

  // First input: 13 rows with a two-level LoD.
  LoDTensor tensor_a;
  LoD lod_a;
  lod_a.push_back(std::vector<size_t>({0, 2, 4}));
  lod_a.push_back(std::vector<size_t>({0, 1, 6, 8, 13}));
  tensor_a.set_lod(lod_a);
  tensor_a.Resize({13, 1});
  float* data = tensor_a.mutable_data<float>(cpu);
  for (int i = 0; i < tensor_a.numel(); ++i) data[i] = i;

  // Second input: 7 rows with a two-level LoD.
  LoDTensor tensor_b;
  LoD lod_b;
  lod_b.push_back(std::vector<size_t>({0, 1, 2}));
  lod_b.push_back(std::vector<size_t>({0, 2, 7}));
  tensor_b.set_lod(lod_b);
  tensor_b.Resize({7, 1});
  data = tensor_b.mutable_data<float>(cpu);
  for (int i = 0; i < tensor_b.numel(); ++i) data[i] = i;

  // Merge both tensors into one and verify the stitched LoD.
  std::vector<const LoDTensor*> inputs{&tensor_a, &tensor_b};
  LoDTensor merged;
  merged.MergeLoDTensor(inputs, cpu);
  EXPECT_EQ(merged.lod(), expected_lod);
}
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册