diff --git a/paddle/fluid/framework/selected_rows.cc b/paddle/fluid/framework/selected_rows.cc index 2eefe79588f32658b1f0eab317528a31142d7ce5..794e7f743413b068119afd5df232bfc2bb91a8c7 100644 --- a/paddle/fluid/framework/selected_rows.cc +++ b/paddle/fluid/framework/selected_rows.cc @@ -39,11 +39,10 @@ struct ReAllocateVisitor { }; struct TensorCopyVisitor { - TensorCopyVisitor(const platform::Place& place, framework::Tensor* dst, - int64_t dst_offset, const framework::Tensor src, - int64_t src_offset, int64_t size) - : place_(place), - dst_(dst), + TensorCopyVisitor(framework::Tensor* dst, int64_t dst_offset, + const framework::Tensor src, int64_t src_offset, + int64_t size) + : dst_(dst), dst_offset_(dst_offset), src_(src), src_offset_(src_offset), @@ -51,12 +50,12 @@ struct TensorCopyVisitor { template void operator()() const { - std::copy(src_.data() + src_offset_, - src_.data() + src_offset_ + size_, - dst_->mutable_data(place_) + dst_offset_); + // TODO(Yancey1989): support other place + platform::CPUPlace cpu; + memory::Copy(cpu, dst_->mutable_data(cpu) + dst_offset_, cpu, + src_.data() + src_offset_, size_ * sizeof(T)); } - platform::Place place_; framework::Tensor* dst_; int64_t dst_offset_; framework::Tensor src_; @@ -125,16 +124,12 @@ std::vector SelectedRows::Get(std::vector keys, framework::Tensor* value) const { PADDLE_ENFORCE(value->IsInitialized(), "The value tensor should be initialized."); - std::vector non_keys; int64_t value_width = value_->numel() / value_->dims()[0]; PADDLE_ENFORCE_EQ(value_width, value->numel() / value->dims()[0], "output tensor should have the same shape with table " "execpt the dims[0]."); - // TODO(Yancey1989): support other place - platform::CPUPlace cpu; - for (size_t i = 0; i < keys.size(); ++i) { int64_t index = Index(keys[i]); if (index == -1) { @@ -142,7 +137,7 @@ std::vector SelectedRows::Get(std::vector keys, } else { framework::VisitDataType( framework::ToDataType(value_->type()), - TensorCopyVisitor(cpu, value, i * value_width, *value_.get(), + TensorCopyVisitor(value, i * value_width, *value_.get(), index * value_width, value_width)); } } @@ -159,7 +154,6 @@ bool SelectedRows::Set(int64_t key, const framework::Tensor& value) { PADDLE_ENFORCE_EQ(value.dims()[0], static_cast(1), "The first dim of value should be 1."); auto index = Index(key); - platform::Place cpu = platform::CPUPlace(); bool is_new_key = false; if (index == -1) { rows_.push_back(key); @@ -176,7 +170,7 @@ bool SelectedRows::Set(int64_t key, const framework::Tensor& value) { framework::VisitDataType( framework::ToDataType(value.type()), - TensorCopyVisitor(cpu, value_.get(), + TensorCopyVisitor(value_.get(), index * value_->numel() / value_->dims()[0], value, static_cast(0), value.numel())); return is_new_key; diff --git a/paddle/fluid/framework/selected_rows.h b/paddle/fluid/framework/selected_rows.h index cef3ddab475eee45cb1eec9064ae75cf9ba4ab53..d6c9507b1681855e759a4b1b9d3dddf6fcb2fc13 100644 --- a/paddle/fluid/framework/selected_rows.h +++ b/paddle/fluid/framework/selected_rows.h @@ -19,6 +19,7 @@ limitations under the License. */ #include "paddle/fluid/framework/lod_tensor.h" #include "paddle/fluid/framework/tensor.h" +#include "paddle/fluid/memory/memcpy.h" namespace paddle { namespace framework {