From e2e0fbd4188fcbcc6bf69d1ef22b3f6f0a927f84 Mon Sep 17 00:00:00 2001 From: Yi Wang Date: Mon, 26 Jun 2017 10:36:49 -0700 Subject: [PATCH] Add tesnor.h --- paddle/framework/tensor.h | 91 +++++++++++++++++++++++++++++++++++++++ 1 file changed, 91 insertions(+) create mode 100644 paddle/framework/tensor.h diff --git a/paddle/framework/tensor.h b/paddle/framework/tensor.h new file mode 100644 index 00000000000..a658537430e --- /dev/null +++ b/paddle/framework/tensor.h @@ -0,0 +1,91 @@ +/* + Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +*/ +#pragma once + +namespace paddle { +namespace framework { + +class Tensor { + using paddle::platform::Place; + using paddle::platform::get_place; + + public: + explicit Tensor(DDim dims) : dims_(dims), place_(get_place()) {} + explicit Tensor(DDim dims, Place place) : dims_(dims), place_(place) {} + + template + const T* data() const { + PADDLE_ASSERT(holder_ != nullptr); + PADDLE_ASSERT(holder_->Place() == place_); + PADDLE_ASSERT(holder_->Size() >= dims_.product() * sizeof(T)); + return static_cast(holder->Ptr()); + } + + template ::value>::type> + T* mutable_data() { + if (holder_ == nullptr || holder_->Place() != place_ || + holder_->Size() < dims_.product() * sizeof(T)) { + holder_.reset(new PlaceholderImpl(place_, dims.product() * sizeof(T))); + } + return static_cast(holder_->Ptr()); + } + + template ::value>::type> + T* mutable_data(DDim dims) { + dims_ = dims; + return mutable_data(); + } + + template ::value>::type> + T* mutable_data(DDim dims, Place place) { + dims_ = dims; + place_ = place; + return mutable_data(); + } + + private: + // Placeholder hides type T, so it doesn't appear as a template + // parameter of Variable. + struct Placeholder { + virtual ~Placeholder() {} + virtual void* Ptr() const = 0; + virtual Place Place() const = 0; + virtual size_t Size() const = 0; + }; + + template + struct PlaceholderImpl : public Placeholder { + PlaceholderImpl(Place pl, size_t size) + : ptr_(memory::Alloc(pl, size), paddle::memory::Deleter(pl)), + place_(pl), + size_(size) {} + + virtual void* Ptr() const { return static_cast(ptr_.get()); } + virtual size_t Size() const { return size_; } + virtual Place Place() const { return place_; } + + std::unique_ptr ptr_; + Place place_; // record the place of ptr_. + size_t size_; // size of the memory block. + }; + + std::unique_ptr holder_; // holds the memory block if allocated. + DDim dims_; // could be smallers than the holder_->Size(). + paddle::platform::Place place_; +}; + +} // namespace framework +} // namespace paddle -- GitLab