diff --git a/paddle/framework/lod_tensor.h b/paddle/framework/lod_tensor.h index ab289241610bbcc72461b3fb7ea68ff4c58c8000..3465e02c8262ffd626298f2672679732396c1c4f 100644 --- a/paddle/framework/lod_tensor.h +++ b/paddle/framework/lod_tensor.h @@ -132,7 +132,7 @@ class LoDTensor : public Tensor { void set_lod(const LoD& lod) { lod_ = lod; if (holder_ != nullptr && - platform::is_same_place(holder_->place(), lod.place())) { + !platform::is_same_place(holder_->place(), lod.place())) { lod_.CopyToPeer(holder_->place()); } } diff --git a/paddle/framework/selected_rows.h b/paddle/framework/selected_rows.h index 30d3dfc1e89f073a8180ceacf77619b36f7079a9..113234424467c97e2069a884a2d2bb43f2c8e657 100644 --- a/paddle/framework/selected_rows.h +++ b/paddle/framework/selected_rows.h @@ -42,7 +42,13 @@ class SelectedRows { Vector* mutable_rows() { return &rows_; } - void set_rows(const Vector& rows) { rows_ = rows; } + void set_rows(const Vector& rows) { + rows_ = rows; + if (value_ != nullptr && + !platform::is_same_place(value_->place(), rows.place())) { + rows_.mutable_data(value_->place()); + } + } DDim GetCompleteDims() const { std::vector dims = vectorize(value_->dims());