/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

   Licensed under the Apache License, Version 2.0 (the "License");
   you may not use this file except in compliance with the License.
   You may obtain a copy of the License at

   http://www.apache.org/licenses/LICENSE-2.0

   Unless required by applicable law or agreed to in writing, software
   distributed under the License is distributed on an "AS IS" BASIS,
   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
   See the License for the specific language governing permissions and
   limitations under the License. */

#pragma once

#include <initializer_list>
#include <memory>
#include <vector>

#include "paddle/memory/memcpy.h"
#include "paddle/memory/memory.h"
#include "paddle/platform/device_context.h"
#include "paddle/platform/enforce.h"
#include "paddle/platform/place.h"

namespace paddle {
namespace framework {

/**
 * @brief A vector that supports both CPU and GPU.
 * The host vector shares the Vector's lifetime; the device copy is lazily
 * allocated and synchronized on demand.
 */
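// A minimal usage sketch (illustrative only; assumes a CUDA build and that
// GPU 0 is available):
//
//   Vector<int64_t> rows = {1, 2, 3};
//   // Lazily allocates a device buffer and copies the host data to GPU 0.
//   int64_t *d_rows = rows.mutable_data(platform::CUDAPlace(0));
//   // ... launch kernels that read or write d_rows ...
//   rows.CopyFromCUDA();  // copy the device buffer back into the host vector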

template <typename T>
class Vector : public std::vector<T> {
 public:
  using std::vector<T>::vector;

  Vector() {}
  Vector(const std::vector<T> &v) : std::vector<T>(v) {}  // NOLINT

  inline platform::Place place() const { return place_; }

  /*! Return a pointer to the constant memory block for the given place. */
  inline const T *data(platform::Place place) const;

  /*! Return a pointer to the mutable memory block for the given place. */
  inline T *mutable_data(platform::Place place);

  // TODO(dzhwinter): the interfaces below should be removed
  /* Get device vector */
  T *cuda_data() {
    CopyToCUDA();
    PADDLE_ENFORCE_NOT_NULL(
        cuda_ptr_, "No data or Insufficient CUDA memory to allocation");
    return static_cast<T *>(cuda_ptr_.get());
  }

  /* Get host vector */
  T *data() { return std::vector<T>::data(); }
  const T *data() const { return std::vector<T>::data(); }

  T *data(const platform::Place &place) {
    if (platform::is_cpu_place(place)) {
      return data();
    } else {
      return cuda_data();
    }
  }

  /* Synchronize host vector to device vector */
  void CopyToCUDA();
  /* Synchronize device vector to host vector */
  void CopyFromCUDA();
  /* Switch device vector location */
  void CopyToPeer(platform::Place);

 private:
  std::shared_ptr<void> cuda_ptr_;
  size_t cuda_size_ = 0;  // device vector numel
  platform::CUDAPlace place_;
};

template <typename T>
inline const T *Vector<T>::data(platform::Place place) const {
  if (platform::is_cpu_place(place)) {
    return std::vector<T>::data();
  } else if (platform::is_gpu_place(place)) {
    if (cuda_ptr_ == nullptr) {
      return nullptr;
    }
    if (boost::get<platform::CUDAPlace>(place) == place_) {
      return static_cast<const T *>(cuda_ptr_.get());
    } else {
      PADDLE_THROW(
          "Unmatched place. Please use `mutable_data` copy lod to the target "
          "Place first.");
    }
  } else {
    PADDLE_THROW("Unsupport Place.");
  }
}

template <typename T>
inline T *Vector<T>::mutable_data(platform::Place place) {
  if (platform::is_cpu_place(place)) {
    return std::vector<T>::data();
  } else if (platform::is_gpu_place(place)) {
    if (boost::get<platform::CUDAPlace>(place) != place_) {
      place_ = boost::get<platform::CUDAPlace>(place);
    }
#ifdef PADDLE_WITH_CUDA
    if (cuda_size_ < this->size() || cuda_ptr_ == nullptr) {
      cuda_ptr_.reset(
          memory::Alloc<platform::CUDAPlace>(place_, this->size() * sizeof(T)),
          memory::PlainDeleter<void, platform::CUDAPlace>(place_));
    }
    cuda_size_ = this->size();
    platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
    auto *ctx = pool.GetByPlace(place_);
    memory::Copy(place_, cuda_ptr_.get(), platform::CPUPlace(),
                 static_cast<const void *>(this->data()),
                 this->size() * sizeof(T), ctx->stream());
    ctx->Wait();
    return static_cast<T *>(cuda_ptr_.get());
#else
    return nullptr;
#endif
  } else {
    PADDLE_THROW("Unsupport Place.");
  }
}

template <typename T>
void Vector<T>::CopyToCUDA() {
#ifdef PADDLE_WITH_CUDA
  if (cuda_size_ < this->size() || cuda_ptr_ == nullptr) {
    cuda_ptr_.reset(
        memory::Alloc<platform::CUDAPlace>(place_, this->size() * sizeof(T)),
        memory::PlainDeleter<void, platform::CUDAPlace>(place_));
  }
  cuda_size_ = this->size();
  platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
  auto *ctx = pool.GetByPlace(place_);
  memory::Copy(place_, cuda_ptr_.get(), platform::CPUPlace(),
               static_cast<const void *>(this->data()),
               this->size() * sizeof(T), ctx->stream());
  ctx->Wait();
#endif
}

template <typename T>
void Vector<T>::CopyFromCUDA() {
#ifdef PADDLE_WITH_CUDA
  if (cuda_ptr_ == nullptr) {
    LOG(WARNING) << "No CUDA data to copy back to the host.";
    return;
  }
  this->resize(cuda_size_);
  platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
  auto *ctx = pool.GetByPlace(place_);
  memory::Copy(platform::CPUPlace(), static_cast<void *>(this->data()), place_,
               static_cast<const void *>(cuda_ptr_.get()),
               this->size() * sizeof(T), ctx->stream());
  ctx->Wait();
#endif
}

template <typename T>
void Vector<T>::CopyToPeer(platform::Place place) {
#ifdef PADDLE_WITH_CUDA
  if (boost::get<platform::CUDAPlace>(place) != place_) {
    place_ = boost::get<platform::CUDAPlace>(place);
  }
  if (cuda_size_ < this->size() || cuda_ptr_ == nullptr) {
    cuda_ptr_.reset(
        memory::Alloc<platform::CUDAPlace>(place_, this->size() * sizeof(T)),
        memory::PlainDeleter<void, platform::CUDAPlace>(place_));
  }
  cuda_size_ = this->size();
  platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
  auto *ctx = pool.GetByPlace(place_);
  memory::Copy(place_, cuda_ptr_.get(), platform::CPUPlace(),
               static_cast<const void *>(this->data()),
               this->size() * sizeof(T), ctx->stream());
  ctx->Wait();
#endif
}
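
// Illustrative sketch of CopyToPeer (assumes a CUDA build with at least two
// GPUs). Note that CopyToPeer refreshes the device copy from the host data
// rather than performing a device-to-device transfer:
//
//   Vector<float> v(100, 1.0f);
//   v.mutable_data(platform::CUDAPlace(0));  // device copy lives on GPU 0
//   v.CopyToPeer(platform::CUDAPlace(1));    // reallocate on GPU 1 from host data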

}  // namespace framework
}  // namespace paddle