From a402d2b39257ae58345998ed5edd6b87b09e9a1b Mon Sep 17 00:00:00 2001 From: dzhwinter Date: Mon, 5 Feb 2018 01:22:13 -0800 Subject: [PATCH] "fix condition" --- paddle/framework/lod_tensor.h | 2 +- paddle/framework/selected_rows.h | 8 +++++++- 2 files changed, 8 insertions(+), 2 deletions(-) diff --git a/paddle/framework/lod_tensor.h b/paddle/framework/lod_tensor.h index ab28924161..3465e02c82 100644 --- a/paddle/framework/lod_tensor.h +++ b/paddle/framework/lod_tensor.h @@ -132,7 +132,7 @@ class LoDTensor : public Tensor { void set_lod(const LoD& lod) { lod_ = lod; if (holder_ != nullptr && - platform::is_same_place(holder_->place(), lod.place())) { + !platform::is_same_place(holder_->place(), lod.place())) { lod_.CopyToPeer(holder_->place()); } } diff --git a/paddle/framework/selected_rows.h b/paddle/framework/selected_rows.h index 30d3dfc1e8..1132344244 100644 --- a/paddle/framework/selected_rows.h +++ b/paddle/framework/selected_rows.h @@ -42,7 +42,13 @@ class SelectedRows { Vector* mutable_rows() { return &rows_; } - void set_rows(const Vector& rows) { rows_ = rows; } + void set_rows(const Vector& rows) { + rows_ = rows; + if (value_ != nullptr && + !platform::is_same_place(value_->place(), rows.place())) { + rows_.mutable_data(value_->place()); + } + } DDim GetCompleteDims() const { std::vector dims = vectorize(value_->dims()); -- GitLab