diff --git a/paddle/framework/tensor.h b/paddle/framework/tensor.h index a0945e8055625ca4c21ea1c3fa9f27321ca9ba3c..7f3894bb3c1e42a213dd8c8afc174e653f9beed1 100644 --- a/paddle/framework/tensor.h +++ b/paddle/framework/tensor.h @@ -15,6 +15,7 @@ limitations under the License. */ #pragma once #include +#include #include #include #include "paddle/framework/ddim.h" @@ -44,11 +45,17 @@ class Tensor { typename std::enable_if::value>::type* = nullptr> T* mutable_data(DDim dims, paddle::platform::Place place) { dims_ = dims; + return mutable_data(place); + } + + template ::value>::type* = nullptr> + T* mutable_data(paddle::platform::Place place) { if (holder_ == nullptr || !(holder_->Place() == place) /* some versions of boost::variant don't have operator!= */ - || holder_->Size() < product(dims) * sizeof(T) + offset_) { - holder_.reset(new PlaceholderImpl(place, product(dims) * sizeof(T))); + || holder_->Size() < product(dims_) * sizeof(T) + offset_) { + holder_.reset(new PlaceholderImpl(place, product(dims_) * sizeof(T))); offset_ = 0; } return reinterpret_cast(reinterpret_cast(holder_->Ptr()) + @@ -63,6 +70,15 @@ class Tensor { offset_ = src.offset_; } + void CopyFrom(const Tensor& src, paddle::platform::Place dst_place) { + PADDLE_ENFORCE(src.holder_ != nullptr, + "Can not copy from an uninitialized tensor."); + size_t size = product(src.dims()) * src.holder_->TypeSize(); + holder_.reset(src.holder_->Clone(src.offset_, size, dst_place)); + dims_ = src.dims(); + offset_ = 0; + } + Tensor Slice(const int& begin_idx, const int& end_idx) const { PADDLE_ENFORCE(holder_ != nullptr, "The sliced tenosr has not been initialized."); @@ -95,6 +111,8 @@ class Tensor { virtual paddle::platform::Place Place() const = 0; virtual size_t Size() const = 0; virtual size_t TypeSize() const = 0; + virtual Placeholder* Clone(size_t begin, size_t size, + paddle::platform::Place place) const = 0; }; template @@ -122,6 +140,18 @@ class Tensor { virtual size_t Size() const { return size_; } virtual paddle::platform::Place Place() const { return place_; } virtual size_t TypeSize() const { return sizeof(T); } + // TODO: Clone only support CPU now. GPU support is needed. + virtual Placeholder* Clone(size_t begin, size_t size, + paddle::platform::Place place) const { + PADDLE_ENFORCE(paddle::platform::is_cpu_place(place_) && + paddle::platform::is_cpu_place(place), + "PlaceholderImpl::Clone only support CPU now."); + PlaceholderImpl* dst = new PlaceholderImpl(place, size); + void* begin_ptr = + reinterpret_cast(reinterpret_cast(Ptr()) + begin); + memcpy(dst->Ptr(), begin_ptr, size); + return dst; + } std::unique_ptr ptr_; paddle::platform::Place place_; // record the place of ptr_. diff --git a/paddle/framework/tensor_test.cc b/paddle/framework/tensor_test.cc index f4822838cfbd27656232a23b14f716f2fbe510e0..6db0ba8c798ee6ea503e6305213822408807b1e8 100644 --- a/paddle/framework/tensor_test.cc +++ b/paddle/framework/tensor_test.cc @@ -178,4 +178,29 @@ TEST(Tensor, Slice) { } } +TEST(Tensor, CopyFrom) { + using namespace paddle::framework; + using namespace paddle::platform; + + Tensor src_tensor; + int* src_ptr = src_tensor.mutable_data(make_ddim({3, 3}), CPUPlace()); + int arr[9] = {1, 2, 3, 4, 5, 6, 7, 8, 9}; + memcpy(src_ptr, arr, 9 * sizeof(int)); + Tensor dst_tensor; + dst_tensor.CopyFrom(src_tensor, CPUPlace()); + const int* dst_ptr = dst_tensor.data(); + ASSERT_NE(src_ptr, dst_ptr); + for (size_t i = 0; i < 9; ++i) { + EXPECT_EQ(src_ptr[i], dst_ptr[i]); + } + + Tensor slice_tensor = src_tensor.Slice(1, 2); + dst_tensor.CopyFrom(slice_tensor, CPUPlace()); + const int* slice_ptr = slice_tensor.data(); + dst_ptr = dst_tensor.data(); + ASSERT_NE(dst_ptr, slice_ptr); + for (size_t i = 0; i < 3; ++i) { + EXPECT_EQ(dst_ptr[i], slice_ptr[i]); + } +} */ \ No newline at end of file