提交 709c157a 编写于 作者: D dzhwinter

"fix ci"

上级 17b1c369
...@@ -129,13 +129,7 @@ class LoDTensor : public Tensor { ...@@ -129,13 +129,7 @@ class LoDTensor : public Tensor {
explicit LoDTensor(const LoD& lod) : lod_(lod) {} explicit LoDTensor(const LoD& lod) : lod_(lod) {}
void set_lod(const LoD& lod) { void set_lod(const LoD& lod) { lod_ = lod; }
lod_ = lod;
if (holder_ != nullptr &&
!platform::is_same_place(holder_->place(), lod.place())) {
lod_.CopyToPeer(holder_->place());
}
}
const LoD& lod() const { return lod_; } const LoD& lod() const { return lod_; }
......
...@@ -42,13 +42,7 @@ class SelectedRows { ...@@ -42,13 +42,7 @@ class SelectedRows {
Vector<int64_t>* mutable_rows() { return &rows_; } Vector<int64_t>* mutable_rows() { return &rows_; }
void set_rows(const Vector<int64_t>& rows) { void set_rows(const Vector<int64_t>& rows) { rows_ = rows; }
rows_ = rows;
if (value_ != nullptr &&
!platform::is_same_place(value_->place(), rows.place())) {
rows_.mutable_data(value_->place());
}
}
DDim GetCompleteDims() const { DDim GetCompleteDims() const {
std::vector<int64_t> dims = vectorize(value_->dims()); std::vector<int64_t> dims = vectorize(value_->dims());
......
...@@ -76,21 +76,26 @@ inline void CopyOrShare(const framework::Variable &src, ...@@ -76,21 +76,26 @@ inline void CopyOrShare(const framework::Variable &src,
if (src.IsType<LoDTensor>()) { if (src.IsType<LoDTensor>()) {
if (src.Get<LoDTensor>().place() == dst_place) { if (src.Get<LoDTensor>().place() == dst_place) {
dst->GetMutable<LoDTensor>()->ShareDataWith(src.Get<LoDTensor>()); dst->GetMutable<LoDTensor>()->ShareDataWith(src.Get<LoDTensor>());
dst->GetMutable<LoDTensor>()->set_lod(src.Get<LoDTensor>().lod());
} else { } else {
Copy(src.Get<LoDTensor>(), dst_place, dst->GetMutable<LoDTensor>()); Copy(src.Get<LoDTensor>(), dst_place, dst->GetMutable<LoDTensor>());
LoD lod(src.Get<LoDTensor>().lod());
lod.CopyToPeer(dst_place);
dst->GetMutable<LoDTensor>()->set_lod(lod);
} }
dst->GetMutable<LoDTensor>()->set_lod(src.Get<LoDTensor>().lod());
} else if (src.IsType<SelectedRows>()) { } else if (src.IsType<SelectedRows>()) {
auto &src_sr = src.Get<SelectedRows>(); auto &src_sr = src.Get<SelectedRows>();
auto *dst_sr = dst->GetMutable<SelectedRows>(); auto *dst_sr = dst->GetMutable<SelectedRows>();
dst_sr->set_rows(src_sr.rows());
dst_sr->set_height(src_sr.height()); dst_sr->set_height(src_sr.height());
if (src_sr.value().place() == dst_place) { if (src_sr.value().place() == dst_place) {
dst_sr->mutable_value()->ShareDataWith(src_sr.value()); dst_sr->mutable_value()->ShareDataWith(src_sr.value());
dst_sr->set_rows(src_sr.rows());
} else { } else {
Copy(src_sr.value(), dst_place, dst_sr->mutable_value()); Copy(src_sr.value(), dst_place, dst_sr->mutable_value());
LoD lod(src.Get<LoDTensor>().lod());
lod.CopyToPeer(dst_place);
dst_sr->set_rows(lod);
} }
dst_sr->set_rows(src_sr.rows());
} else { } else {
PADDLE_THROW("Expect LoDTensor/SelectedRows, get %s", src.Type().name()); PADDLE_THROW("Expect LoDTensor/SelectedRows, get %s", src.Type().name());
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册