提交 f12b3f36 编写于 作者: Y Yancey1989

use memcpy

上级 13e7194e
...@@ -39,11 +39,10 @@ struct ReAllocateVisitor { ...@@ -39,11 +39,10 @@ struct ReAllocateVisitor {
}; };
struct TensorCopyVisitor { struct TensorCopyVisitor {
TensorCopyVisitor(const platform::Place& place, framework::Tensor* dst, TensorCopyVisitor(framework::Tensor* dst, int64_t dst_offset,
int64_t dst_offset, const framework::Tensor src, const framework::Tensor src, int64_t src_offset,
int64_t src_offset, int64_t size) int64_t size)
: place_(place), : dst_(dst),
dst_(dst),
dst_offset_(dst_offset), dst_offset_(dst_offset),
src_(src), src_(src),
src_offset_(src_offset), src_offset_(src_offset),
...@@ -51,12 +50,12 @@ struct TensorCopyVisitor { ...@@ -51,12 +50,12 @@ struct TensorCopyVisitor {
template <typename T> template <typename T>
void operator()() const { void operator()() const {
std::copy(src_.data<T>() + src_offset_, // TODO(Yancey1989): support other place
src_.data<T>() + src_offset_ + size_, platform::CPUPlace cpu;
dst_->mutable_data<T>(place_) + dst_offset_); memory::Copy(cpu, dst_->mutable_data<T>(cpu) + dst_offset_, cpu,
src_.data<T>() + src_offset_, size_ * sizeof(T));
} }
platform::Place place_;
framework::Tensor* dst_; framework::Tensor* dst_;
int64_t dst_offset_; int64_t dst_offset_;
framework::Tensor src_; framework::Tensor src_;
...@@ -125,16 +124,12 @@ std::vector<int64_t> SelectedRows::Get(std::vector<int64_t> keys, ...@@ -125,16 +124,12 @@ std::vector<int64_t> SelectedRows::Get(std::vector<int64_t> keys,
framework::Tensor* value) const { framework::Tensor* value) const {
PADDLE_ENFORCE(value->IsInitialized(), PADDLE_ENFORCE(value->IsInitialized(),
"The value tensor should be initialized."); "The value tensor should be initialized.");
std::vector<int64_t> non_keys; std::vector<int64_t> non_keys;
int64_t value_width = value_->numel() / value_->dims()[0]; int64_t value_width = value_->numel() / value_->dims()[0];
PADDLE_ENFORCE_EQ(value_width, value->numel() / value->dims()[0], PADDLE_ENFORCE_EQ(value_width, value->numel() / value->dims()[0],
"output tensor should have the same shape with table " "output tensor should have the same shape with table "
"execpt the dims[0]."); "execpt the dims[0].");
// TODO(Yancey1989): support other place
platform::CPUPlace cpu;
for (size_t i = 0; i < keys.size(); ++i) { for (size_t i = 0; i < keys.size(); ++i) {
int64_t index = Index(keys[i]); int64_t index = Index(keys[i]);
if (index == -1) { if (index == -1) {
...@@ -142,7 +137,7 @@ std::vector<int64_t> SelectedRows::Get(std::vector<int64_t> keys, ...@@ -142,7 +137,7 @@ std::vector<int64_t> SelectedRows::Get(std::vector<int64_t> keys,
} else { } else {
framework::VisitDataType( framework::VisitDataType(
framework::ToDataType(value_->type()), framework::ToDataType(value_->type()),
TensorCopyVisitor(cpu, value, i * value_width, *value_.get(), TensorCopyVisitor(value, i * value_width, *value_.get(),
index * value_width, value_width)); index * value_width, value_width));
} }
} }
...@@ -159,7 +154,6 @@ bool SelectedRows::Set(int64_t key, const framework::Tensor& value) { ...@@ -159,7 +154,6 @@ bool SelectedRows::Set(int64_t key, const framework::Tensor& value) {
PADDLE_ENFORCE_EQ(value.dims()[0], static_cast<size_t>(1), PADDLE_ENFORCE_EQ(value.dims()[0], static_cast<size_t>(1),
"The first dim of value should be 1."); "The first dim of value should be 1.");
auto index = Index(key); auto index = Index(key);
platform::Place cpu = platform::CPUPlace();
bool is_new_key = false; bool is_new_key = false;
if (index == -1) { if (index == -1) {
rows_.push_back(key); rows_.push_back(key);
...@@ -176,7 +170,7 @@ bool SelectedRows::Set(int64_t key, const framework::Tensor& value) { ...@@ -176,7 +170,7 @@ bool SelectedRows::Set(int64_t key, const framework::Tensor& value) {
framework::VisitDataType( framework::VisitDataType(
framework::ToDataType(value.type()), framework::ToDataType(value.type()),
TensorCopyVisitor(cpu, value_.get(), TensorCopyVisitor(value_.get(),
index * value_->numel() / value_->dims()[0], value, index * value_->numel() / value_->dims()[0], value,
static_cast<int64_t>(0), value.numel())); static_cast<int64_t>(0), value.numel()));
return is_new_key; return is_new_key;
......
...@@ -19,6 +19,7 @@ limitations under the License. */ ...@@ -19,6 +19,7 @@ limitations under the License. */
#include "paddle/fluid/framework/lod_tensor.h" #include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/tensor.h" #include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/memory/memcpy.h"
namespace paddle { namespace paddle {
namespace framework { namespace framework {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册