未验证 提交 06f10942 编写于 作者: Y Yu Yang 提交者: GitHub

Merge pull request #8275 from reyoung/feature/rewrite_vector

Rewrite mixed_vector.h
paddle/operators/check_t.save
paddle/operators/check_tensor.ls
paddle/operators/tensor.save
python/paddle/v2/fluid/tests/book/image_classification_resnet.inference.model/
python/paddle/v2/fluid/tests/book/image_classification_vgg.inference.model/
python/paddle/v2/fluid/tests/book/label_semantic_roles.inference.model/
*.DS_Store *.DS_Store
build/ build/
build_doc/ build_doc/
......
...@@ -181,7 +181,8 @@ elseif(CMAKE_BUILD_TYPE STREQUAL "Release") ...@@ -181,7 +181,8 @@ elseif(CMAKE_BUILD_TYPE STREQUAL "Release")
elseif(CMAKE_BUILD_TYPE STREQUAL "RelWithDebInfo") elseif(CMAKE_BUILD_TYPE STREQUAL "RelWithDebInfo")
list(APPEND CUDA_NVCC_FLAGS ${CMAKE_CXX_FLAGS_RELWITHDEBINFO}) list(APPEND CUDA_NVCC_FLAGS ${CMAKE_CXX_FLAGS_RELWITHDEBINFO})
elseif(CMAKE_BUILD_TYPE STREQUAL "MinSizeRel") elseif(CMAKE_BUILD_TYPE STREQUAL "MinSizeRel")
list(APPEND CUDA_NVCC_FLAGS ${CMAKE_CXX_FLAGS_MINSIZEREL}) # nvcc 9 does not support -Os. Use Release flags instead
list(APPEND CUDA_NVCC_FLAGS ${CMAKE_CXX_FLAGS_RELEASE})
endif() endif()
mark_as_advanced(CUDA_BUILD_CUBIN CUDA_BUILD_EMULATION CUDA_VERBOSE_BUILD) mark_as_advanced(CUDA_BUILD_CUBIN CUDA_BUILD_EMULATION CUDA_VERBOSE_BUILD)
......
...@@ -46,29 +46,7 @@ namespace framework { ...@@ -46,29 +46,7 @@ namespace framework {
* 0 2 4 7 * 0 2 4 7
* 0 2 5 7 10 12 15 20 * 0 2 5 7 10 12 15 20
*/ */
struct LoD : public std::vector<Vector<size_t>> { using LoD = std::vector<Vector<size_t>>;
using std::vector<Vector<size_t>>::vector;
platform::Place place() const {
if (this->size() == 0) {
// Not Initialze Yet.
return platform::CPUPlace();
} else {
return this->front().place();
}
}
void CopyFromCUDA() {
for (auto it = this->begin(); it != this->end(); ++it) {
it->CopyFromCUDA();
}
}
void CopyToPeer(platform::Place place) {
for (auto it = this->begin(); it != this->end(); ++it) {
it->CopyToPeer(place);
}
}
};
std::ostream& operator<<(std::ostream& os, const LoD& lod); std::ostream& operator<<(std::ostream& os, const LoD& lod);
std::ostream& operator<<(std::ostream& os, const LoDTensor& t); std::ostream& operator<<(std::ostream& os, const LoDTensor& t);
......
...@@ -20,6 +20,7 @@ ...@@ -20,6 +20,7 @@
#include "paddle/platform/assert.h" #include "paddle/platform/assert.h"
#include <gtest/gtest.h> #include <gtest/gtest.h>
#include <paddle/platform/place.h>
__global__ void test(size_t* a, int size) { __global__ void test(size_t* a, int size) {
for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < size; for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < size;
...@@ -36,10 +37,9 @@ TEST(LoD, data) { ...@@ -36,10 +37,9 @@ TEST(LoD, data) {
lod.push_back(std::vector<size_t>({0, 1, 6, 8, 10, 11})); lod.push_back(std::vector<size_t>({0, 1, 6, 8, 10, 11}));
auto& v = lod[0]; auto& v = lod[0];
test<<<1, 1>>>(v.cuda_data(), v.size()); paddle::platform::CUDAPlace gpu(0);
test<<<1, 1>>>(v.CUDAMutableData(gpu), v.size());
cudaDeviceSynchronize(); cudaDeviceSynchronize();
v.CopyFromCUDA();
for (size_t i = 0; i < v.size(); ++i) { for (size_t i = 0; i < v.size(); ++i) {
EXPECT_EQ(v[i], i * 2); EXPECT_EQ(v[i], i * 2);
} }
...@@ -63,9 +63,8 @@ TEST(LoDTensor, LoDInGPU) { ...@@ -63,9 +63,8 @@ TEST(LoDTensor, LoDInGPU) {
auto lod = lod_tensor.lod(); auto lod = lod_tensor.lod();
test<<<1, 8>>>(lod[0].cuda_data(), lod[0].size()); test<<<1, 8>>>(lod[0].CUDAMutableData(place), lod[0].size());
cudaDeviceSynchronize(); cudaDeviceSynchronize();
lod.CopyFromCUDA();
for (size_t i = 0; i < src_lod[0].size(); ++i) { for (size_t i = 0; i < src_lod[0].size(); ++i) {
EXPECT_EQ(lod[0].data()[i], src_lod[0].data()[i] * 2); EXPECT_EQ(lod[0].data()[i], src_lod[0].data()[i] * 2);
......
...@@ -17,176 +17,347 @@ ...@@ -17,176 +17,347 @@
#include <initializer_list> #include <initializer_list>
#include <vector> #include <vector>
#include "paddle/memory/memcpy.h" #include "paddle/framework/tensor.h"
#include "paddle/memory/memory.h" #include "paddle/framework/tensor_util.h"
#include "paddle/platform/device_context.h"
#include "paddle/platform/enforce.h" #include "glog/logging.h"
#include "paddle/platform/place.h"
namespace paddle { namespace paddle {
namespace framework { namespace framework {
/** // Vector<T> implements the std::vector interface, and can get Data or
* @brief Vector support both cpu and gpu. // MutableData from any place. The data will be synced implicitly inside.
* host vector lifetime is same with Vector
* device vector is lazily malloc and modified.
*/
template <typename T> template <typename T>
class Vector : public std::vector<T> { class Vector {
public: public:
using std::vector<T>::vector; using value_type = T;
// Default ctor. Create empty Vector
Vector() { InitEmpty(); }
// Fill vector with value. The vector size is `count`.
explicit Vector(size_t count, const T& value = T()) {
if (count == 0) {
InitEmpty();
} else {
resize(count);
T* ptr = begin();
for (size_t i = 0; i < count; ++i) {
ptr[i] = value;
}
}
}
// Ctor with init_list
Vector(std::initializer_list<T> init) {
if (init.size() == 0) {
InitEmpty();
} else {
InitByIter(init.size(), init.begin(), init.end());
}
}
// implicit cast from std::vector.
template <typename U>
Vector(const std::vector<U>& dat) { // NOLINT
if (dat.size() == 0) {
InitEmpty();
} else {
InitByIter(dat.size(), dat.begin(), dat.end());
}
}
// Copy ctor
Vector(const Vector<T>& other) { this->operator=(other); }
// Copy operator
Vector<T>& operator=(const Vector<T>& other) {
if (other.size() != 0) {
this->InitByIter(other.size(), other.begin(), other.end());
} else {
InitEmpty();
}
return *this;
}
// Move ctor
Vector(Vector<T>&& other) {
this->size_ = other.size_;
this->flag_ = other.flag_;
if (other.cuda_vec_.memory_size()) {
this->cuda_vec_.ShareDataWith(other.cuda_vec_);
}
if (other.cpu_vec_.memory_size()) {
this->cpu_vec_.ShareDataWith(other.cpu_vec_);
}
}
// CPU data access method. Mutable.
T& operator[](size_t i) {
MutableCPU();
return const_cast<T*>(cpu_vec_.data<T>())[i];
}
// CPU data access method. Immutable.
const T& operator[](size_t i) const {
ImmutableCPU();
return cpu_vec_.data<T>()[i];
}
// std::vector iterator methods. Based on CPU data access method
size_t size() const { return size_; }
T* begin() { return &this->operator[](0); }
T* end() { return &this->operator[](size()); }
T& front() { return *begin(); }
T& back() {
auto it = end();
--it;
return *it;
}
const T* begin() const { return &this->operator[](0); }
const T* end() const { return &this->operator[](size()); }
const T& back() const {
auto it = end();
--it;
return *it;
}
T* data() { return begin(); }
const T* data() const { return begin(); }
const T& front() const { return *begin(); }
// end of std::vector iterator methods
// assign this from iterator.
// NOTE: the iterator must support `end-begin`
template <typename Iter>
void assign(Iter begin, Iter end) {
InitByIter(end - begin, begin, end);
}
// push_back. If the previous capacity is not enough, the memory will
// double.
void push_back(T elem) {
if (size_ + 1 > capacity()) {
reserve((size_ + 1) << 1);
}
*end() = elem;
++size_;
}
Vector() {} // extend a vector by iterator.
Vector(const std::vector<T> &v) : std::vector<T>(v) {} // NOLINT // NOTE: the iterator must support end-begin
template <typename It>
void Extend(It begin, It end) {
size_t pre_size = size_;
resize(pre_size + (end - begin));
T* ptr = this->begin() + pre_size;
for (; begin < end; ++begin, ++ptr) {
*ptr = *begin;
}
}
inline platform::Place place() const { return place_; } // resize the vector
void resize(size_t size) {
if (size + 1 < capacity()) {
size_ = size;
} else {
MutableCPU();
Tensor cpu_tensor;
platform::Place cpu = platform::CPUPlace();
T* ptr = cpu_tensor.mutable_data<T>(
framework::make_ddim({static_cast<int64_t>(size)}), cpu);
const T* old_ptr =
cpu_vec_.memory_size() == 0 ? nullptr : cpu_vec_.data<T>();
if (old_ptr != nullptr) {
std::copy(old_ptr, old_ptr + size_, ptr);
}
size_ = size;
cpu_vec_.ShareDataWith(cpu_tensor);
}
}
/*! Return a pointer to constant memory block. */ // get cuda ptr. immutable
inline const T *data(platform::Place place) const; const T* CUDAData(platform::Place place) const {
PADDLE_ENFORCE(platform::is_gpu_place(place),
"CUDA Data must on CUDA place");
ImmutableCUDA(place);
return cuda_vec_.data<T>();
}
/*! Return a pointer to mutable memory block. */ // get cuda ptr. mutable
inline T *mutable_data(platform::Place place); T* CUDAMutableData(platform::Place place) {
const T* ptr = CUDAData(place);
flag_ = kDirty | kDataInCUDA;
return const_cast<T*>(ptr);
}
// TODO(dzhwinter): below interfaces should be removed // clear
/* Get device vector */ void clear() {
T *cuda_data() { size_ = 0;
CopyToCUDA(); flag_ = kDirty | kDataInCPU;
PADDLE_ENFORCE_NOT_NULL(
cuda_ptr_, "No data or Insufficient CUDA memory to allocation");
return static_cast<T *>(cuda_ptr_.get());
} }
/* Get host vector */ size_t capacity() const {
T *data() { return std::vector<T>::data(); } return cpu_vec_.memory_size() / SizeOfType(typeid(T));
const T *data() const { return std::vector<T>::data(); } }
// reserve data
void reserve(size_t size) {
size_t pre_size = size_;
resize(size);
resize(pre_size);
}
T *data(const platform::Place &place) { // the unify method to access CPU or CUDA data. immutable.
if (platform::is_cpu_place(place)) { const T* Data(platform::Place place) const {
if (platform::is_gpu_place(place)) {
return CUDAData(place);
} else {
return data(); return data();
}
}
// the unify method to access CPU or CUDA data. mutable.
T* MutableData(platform::Place place) {
if (platform::is_gpu_place(place)) {
return CUDAMutableData(place);
} else { } else {
return cuda_data(); return data();
} }
} }
/* Synchronize host vector to device vector */ // implicit cast operator. Vector can be cast to std::vector implicitly.
void CopyToCUDA(); operator std::vector<T>() const {
/* Synchronize device vector to host vector */ std::vector<T> result;
void CopyFromCUDA(); result.resize(size());
/* Switch device vector location */ std::copy(begin(), end(), result.begin());
void CopyToPeer(platform::Place); return result;
}
bool operator==(const Vector<T>& other) const {
if (size() != other.size()) return false;
for (auto it1 = begin(), it2 = other.begin(); it1 < end(); ++it1, ++it2) {
if (*it1 != *it2) {
return false;
}
}
return true;
}
private: private:
std::shared_ptr<void> cuda_ptr_; void InitEmpty() {
size_t cuda_size_ = 0; // device vector numel size_ = 0;
platform::CUDAPlace place_; flag_ = kDataInCPU;
}; }
template <typename T> template <typename Iter>
inline const T *Vector<T>::data(platform::Place place) const { void InitByIter(size_t size, Iter begin, Iter end) {
if (platform::is_cpu_place(place)) { platform::Place cpu = platform::CPUPlace();
return std::vector<T>::data(); T* ptr = this->cpu_vec_.template mutable_data<T>(
} else if (platform::is_gpu_place(place)) { framework::make_ddim({static_cast<int64_t>(size)}), cpu);
if (cuda_ptr_ == nullptr) { for (size_t i = 0; i < size; ++i) {
return nullptr; *ptr++ = *begin++;
} }
if (boost::get<platform::CUDAPlace>(place) == place_) { flag_ = kDataInCPU | kDirty;
return static_cast<const T *>(cuda_ptr_.get()); size_ = size;
}
enum DataFlag {
kDataInCPU = 0x01,
kDataInCUDA = 0x02,
// kDirty means the data has been changed in one device.
kDirty = 0x10
};
void CopyToCPU() const {
// COPY GPU Data To CPU
Copy(cuda_vec_, platform::CPUPlace(), &cpu_vec_);
WaitPlace(cuda_vec_.place());
}
void MutableCPU() {
if (IsInCUDA() && IsDirty()) {
CopyToCPU();
}
flag_ = kDirty | kDataInCPU;
}
void ImmutableCUDA(platform::Place place) const {
if (IsDirty()) {
if (IsInCPU()) {
Copy(cpu_vec_, boost::get<platform::CUDAPlace>(place), &cuda_vec_);
WaitPlace(place);
UnsetFlag(kDirty);
SetFlag(kDataInCUDA);
} else if (IsInCUDA() && !(place == cuda_vec_.place())) {
framework::Tensor tmp;
Copy(cuda_vec_, boost::get<platform::CUDAPlace>(place), &tmp);
WaitPlace(cuda_vec_.place());
cuda_vec_.ShareDataWith(tmp);
// Still dirty
} else {
// Dirty && DataInCUDA && Device is same
// Do nothing
}
} else { } else {
PADDLE_THROW( if (!IsInCUDA()) {
"Unmatched place. Please use `mutable_data` copy lod to the target " // Even data is not dirty. However, data is not in CUDA. Copy data.
"Place first."); Copy(cpu_vec_, boost::get<platform::CUDAPlace>(place), &cuda_vec_);
WaitPlace(place);
SetFlag(kDataInCUDA);
} else if (!(place == cuda_vec_.place())) {
framework::Tensor tmp;
WaitPlace(cuda_vec_.place());
Copy(cuda_vec_, boost::get<platform::CUDAPlace>(place), &tmp);
WaitPlace(cuda_vec_.place());
WaitPlace(place);
cuda_vec_.ShareDataWith(tmp);
} else {
// Not Dirty && DataInCUDA && Device is same
// Do nothing.
}
} }
} else {
PADDLE_THROW("Unsupport Place.");
} }
}
template <typename T> void ImmutableCPU() const {
inline T *Vector<T>::mutable_data(platform::Place place) { if (IsDirty() &&
if (platform::is_cpu_place(place)) { !IsInCPU()) { // If data has been changed in CUDA, or CPU has no data.
return std::vector<T>::data(); CopyToCPU();
} else if (platform::is_gpu_place(place)) { UnsetFlag(kDirty);
if (boost::get<platform::CUDAPlace>(place) != place_) { }
place_ = boost::get<platform::CUDAPlace>(place); SetFlag(kDataInCPU);
} }
#ifdef PADDLE_WITH_CUDA
if (cuda_size_ < this->size() || cuda_ptr_ == nullptr) {
cuda_ptr_.reset(
memory::Alloc<platform::CUDAPlace>(place_, this->size() * sizeof(T)),
memory::PlainDeleter<void, platform::CUDAPlace>(place_));
}
cuda_size_ = this->size();
platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
auto *ctx = pool.GetByPlace(place_);
memory::Copy(place_, cuda_ptr_.get(), platform::CPUPlace(),
static_cast<const void *>(this->data()),
this->size() * sizeof(T), ctx->stream());
ctx->Wait();
return static_cast<T *>(cuda_ptr_.get());
#else
return nullptr;
#endif
} else {
PADDLE_THROW("Unsupport Place.");
}
}
template <typename T> void UnsetFlag(int flag) const { flag_ &= ~flag; }
void Vector<T>::CopyToCUDA() { void SetFlag(int flag) const { flag_ |= flag; }
#ifdef PADDLE_WITH_CUDA
if (cuda_size_ < this->size() || cuda_ptr_ == nullptr) {
cuda_ptr_.reset(
memory::Alloc<platform::CUDAPlace>(place_, this->size() * sizeof(T)),
memory::PlainDeleter<void, platform::CUDAPlace>(place_));
}
cuda_size_ = this->size();
platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
auto *ctx = pool.GetByPlace(place_);
memory::Copy(place_, cuda_ptr_.get(), platform::CPUPlace(),
static_cast<const void *>(this->data()),
this->size() * sizeof(T), ctx->stream());
ctx->Wait();
#endif
}
template <typename T> bool IsDirty() const { return flag_ & kDirty; }
void Vector<T>::CopyFromCUDA() {
#ifdef PADDLE_WITH_CUDA
if (cuda_ptr_ == nullptr) {
LOG(WARNING) << "No uncommitted cuda data.";
return;
}
this->resize(cuda_size_);
platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
auto *ctx = pool.GetByPlace(place_);
memory::Copy(platform::CPUPlace(), static_cast<void *>(this->data()), place_,
static_cast<const void *>(cuda_ptr_.get()),
this->size() * sizeof(T), ctx->stream());
ctx->Wait();
#endif
}
template <typename T> bool IsInCUDA() const { return flag_ & kDataInCUDA; }
void Vector<T>::CopyToPeer(platform::Place place) {
#ifdef PADDLE_WITH_CUDA bool IsInCPU() const { return flag_ & kDataInCPU; }
if (boost::get<platform::CUDAPlace>(place) != place_) {
place_ = boost::get<platform::CUDAPlace>(place); static void WaitPlace(const platform::Place place) {
} if (platform::is_gpu_place(place)) {
if (cuda_size_ < this->size() || cuda_ptr_ == nullptr) { platform::DeviceContextPool::Instance()
cuda_ptr_.reset( .Get(boost::get<platform::CUDAPlace>(place))
memory::Alloc<platform::CUDAPlace>(place_, this->size() * sizeof(T)), ->Wait();
memory::PlainDeleter<void, platform::CUDAPlace>(place_)); }
} }
cuda_size_ = this->size();
platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance(); mutable int flag_;
auto *ctx = pool.GetByPlace(place_); mutable Tensor cpu_vec_;
memory::Copy(place_, cuda_ptr_.get(), platform::CPUPlace(), mutable Tensor cuda_vec_;
static_cast<const void *>(this->data()), size_t size_;
this->size() * sizeof(T), ctx->stream()); };
ctx->Wait();
#endif
}
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
...@@ -11,62 +11,83 @@ ...@@ -11,62 +11,83 @@
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include <cuda.h>
#include <cuda_runtime.h> #include <cuda_runtime.h>
#include "gtest/gtest.h"
#include "paddle/framework/init.h" #include "glog/logging.h"
#include "gtest/gtest.h"
#include "paddle/framework/mixed_vector.h" #include "paddle/framework/mixed_vector.h"
#include "paddle/platform/gpu_info.h"
using namespace paddle::framework;
using namespace paddle::platform;
using namespace paddle::memory;
template <typename T> template <typename T>
__global__ void test(T* data, int size) { using vec = paddle::framework::Vector<T>;
for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < size;
i += blockDim.x * gridDim.x) { TEST(mixed_vector, CPU_VECTOR) {
data[i] *= 2; vec<int> tmp;
for (int i = 0; i < 10; ++i) {
tmp.push_back(i);
}
ASSERT_EQ(tmp.size(), 10);
vec<int> tmp2;
tmp2 = tmp;
ASSERT_EQ(tmp2.size(), 10);
for (int i = 0; i < 10; ++i) {
ASSERT_EQ(tmp2[i], i);
ASSERT_EQ(tmp2[i], tmp[i]);
}
int cnt = 0;
for (auto& t : tmp2) {
ASSERT_EQ(t, cnt);
++cnt;
} }
} }
TEST(Vector, Normal) { static __global__ void multiply_10(int* ptr) {
// fill the device context pool. for (int i = 0; i < 10; ++i) {
InitDevices(); ptr[i] *= 10;
}
}
cudaStream_t GetCUDAStream(paddle::platform::CUDAPlace place) {
return reinterpret_cast<const paddle::platform::CUDADeviceContext*>(
paddle::platform::DeviceContextPool::Instance().Get(place))
->stream();
}
Vector<size_t> vec({1, 2, 3}); TEST(mixed_vector, GPU_VECTOR) {
size_t* ptr = vec.data(); vec<int> tmp;
for (size_t i = 0; i < vec.size(); ++i) { for (int i = 0; i < 10; ++i) {
EXPECT_EQ(vec[i], *(ptr + i)); tmp.push_back(i);
} }
ASSERT_EQ(tmp.size(), 10);
paddle::platform::CUDAPlace gpu(0);
vec.clear(); multiply_10<<<1, 1, 0, GetCUDAStream(gpu)>>>(tmp.MutableData(gpu));
vec.CopyFromCUDA();
std::vector<size_t> v = {1, 2, 3}; for (int i = 0; i < 10; ++i) {
for (size_t i = 0; i < v.size(); ++i) { ASSERT_EQ(tmp[i], i * 10);
EXPECT_EQ(v[i], vec[i]);
} }
} }
TEST(Vector, MultipleCopy) { TEST(mixed_vector, MultiGPU) {
InitDevices(); if (paddle::platform::GetCUDADeviceCount() < 2) {
Vector<size_t> vec({1, 2, 3}); LOG(WARNING) << "Skip mixed_vector.MultiGPU since there are not multiple "
CUDAPlace place(0); "GPUs in your machine.";
vec.mutable_data(place); return;
auto vec2 = Vector<size_t>(vec); }
{
const size_t* ptr = vec2.data(CPUPlace()); vec<int> tmp;
for (size_t i = 0; i < vec2.size(); ++i) { for (int i = 0; i < 10; ++i) {
EXPECT_EQ(*(ptr + i), vec[i]); tmp.push_back(i);
}
} }
test<size_t><<<3, 3>>>(vec2.mutable_data(place), vec2.size()); ASSERT_EQ(tmp.size(), 10);
vec2.CopyFromCUDA(); paddle::platform::CUDAPlace gpu0(0);
{ paddle::platform::SetDeviceId(0);
const size_t* ptr = vec2.data(CPUPlace()); multiply_10<<<1, 1, 0, GetCUDAStream(gpu0)>>>(tmp.MutableData(gpu0));
for (size_t i = 0; i < vec2.size(); ++i) { paddle::platform::CUDAPlace gpu1(1);
EXPECT_EQ(*(ptr + i), vec[i] * 2); auto* gpu1_ptr = tmp.MutableData(gpu1);
} paddle::platform::SetDeviceId(1);
multiply_10<<<1, 1, 0, GetCUDAStream(gpu1)>>>(gpu1_ptr);
for (int i = 0; i < 10; ++i) {
ASSERT_EQ(tmp[i], i * 100);
} }
} }
...@@ -120,6 +120,7 @@ class Tensor { ...@@ -120,6 +120,7 @@ class Tensor {
return holder_->type(); return holder_->type();
} }
// memory size returns the holding memory size in byte.
size_t memory_size() const; size_t memory_size() const;
inline void check_memory_size() const; inline void check_memory_size() const;
......
...@@ -52,7 +52,7 @@ struct SizeOfTypeFunctor<HEAD, TAIL...> { ...@@ -52,7 +52,7 @@ struct SizeOfTypeFunctor<HEAD, TAIL...> {
}; };
static inline size_t SizeOfType(std::type_index type) { static inline size_t SizeOfType(std::type_index type) {
SizeOfTypeFunctor<int, float, double, int16_t, int64_t, bool> functor; SizeOfTypeFunctor<int, float, double, int16_t, int64_t, bool, size_t> functor;
size_t size = functor(type); size_t size = functor(type);
PADDLE_ENFORCE(size != 0UL, "Cannot get size of type %s", type.name()); PADDLE_ENFORCE(size != 0UL, "Cannot get size of type %s", type.name());
return size; return size;
...@@ -61,15 +61,15 @@ static inline size_t SizeOfType(std::type_index type) { ...@@ -61,15 +61,15 @@ static inline size_t SizeOfType(std::type_index type) {
inline void Tensor::check_memory_size() const { inline void Tensor::check_memory_size() const {
PADDLE_ENFORCE_NOT_NULL( PADDLE_ENFORCE_NOT_NULL(
holder_, "Tensor holds no memory. Call Tensor::mutable_data first."); holder_, "Tensor holds no memory. Call Tensor::mutable_data first.");
PADDLE_ENFORCE_GE( PADDLE_ENFORCE_LE(
holder_->size(), memory_size() + offset_, numel() * SizeOfType(type()), memory_size(),
"Tensor's dims_ is out of bound. Call Tensor::mutable_data " "Tensor's dims_ is out of bound. Call Tensor::mutable_data "
"first to re-allocate memory.\n" "first to re-allocate memory.\n"
"or maybe the required data-type mismatches the data already stored."); "or maybe the required data-type mismatches the data already stored.");
} }
inline size_t Tensor::memory_size() const { inline size_t Tensor::memory_size() const {
return holder_ == nullptr ? 0UL : numel() * SizeOfType(type()); return holder_ == nullptr ? 0UL : holder_->size() - offset_;
} }
template <typename T> template <typename T>
......
...@@ -101,9 +101,9 @@ struct SparseAdagradFunctor<platform::CUDADeviceContext, T> { ...@@ -101,9 +101,9 @@ struct SparseAdagradFunctor<platform::CUDADeviceContext, T> {
SparseAdagradFunctorKernel< SparseAdagradFunctorKernel<
T, 256><<<grid2, threads, 0, T, 256><<<grid2, threads, 0,
reinterpret_cast<const platform::CUDADeviceContext&>(context) reinterpret_cast<const platform::CUDADeviceContext&>(context)
.stream()>>>(grad_merge_data, merge_rows.cuda_data(), lr, .stream()>>>(
param_data, moment_data, grad_width, grad_merge_data, merge_rows.CUDAMutableData(context.GetPlace()), lr,
epsilon); param_data, moment_data, grad_width, epsilon);
} }
}; };
......
...@@ -201,7 +201,7 @@ class AdamOpKernel : public framework::OpKernel<T> { ...@@ -201,7 +201,7 @@ class AdamOpKernel : public framework::OpKernel<T> {
const T* grad_data = grad_tensor.template data<T>(); const T* grad_data = grad_tensor.template data<T>();
int64_t* rows = nullptr; int64_t* rows = nullptr;
if (platform::is_gpu_place(ctx.GetPlace())) { if (platform::is_gpu_place(ctx.GetPlace())) {
rows = grad_merge.mutable_rows()->cuda_data(); rows = grad_merge.mutable_rows()->CUDAMutableData(ctx.GetPlace());
} else { } else {
rows = grad_merge.mutable_rows()->data(); rows = grad_merge.mutable_rows()->data();
} }
......
...@@ -69,8 +69,9 @@ class CTCAlignOpCUDAKernel : public framework::OpKernel<T> { ...@@ -69,8 +69,9 @@ class CTCAlignOpCUDAKernel : public framework::OpKernel<T> {
auto stream = ctx.cuda_device_context().stream(); auto stream = ctx.cuda_device_context().stream();
MergeAndDelCudaKernel<T><<<1, 1, 0, stream>>>( MergeAndDelCudaKernel<T><<<1, 1, 0, stream>>>(
num_tokens, tokens, num_seq, input_lod[level].cuda_data(), blank, num_tokens, tokens, num_seq,
merge_repeated, dev_out_lod0_ptr, output_data); input_lod[level].CUDAMutableData(ctx.GetPlace()), blank, merge_repeated,
dev_out_lod0_ptr, output_data);
// set output lod // set output lod
std::vector<size_t> host_out_lod0(dev_out_lod0.begin(), dev_out_lod0.end()); std::vector<size_t> host_out_lod0(dev_out_lod0.begin(), dev_out_lod0.end());
......
...@@ -125,7 +125,9 @@ class LookupTableGradCUDAKernel : public framework::OpKernel<T> { ...@@ -125,7 +125,9 @@ class LookupTableGradCUDAKernel : public framework::OpKernel<T> {
new_rows.resize(ids_dim[0]); new_rows.resize(ids_dim[0]);
auto gpu_place = boost::get<platform::CUDAPlace>(context.GetPlace()); auto gpu_place = boost::get<platform::CUDAPlace>(context.GetPlace());
memory::Copy(platform::CPUPlace(), new_rows.cuda_data(), gpu_place, // TODO(yuyang18): Strange code here.
memory::Copy(platform::CPUPlace(),
new_rows.CUDAMutableData(context.GetPlace()), gpu_place,
ids_data, ids_dim[0] * sizeof(int64_t), stream); ids_data, ids_dim[0] * sizeof(int64_t), stream);
d_table->set_rows(new_rows); d_table->set_rows(new_rows);
......
...@@ -128,7 +128,7 @@ struct SelectedRowsAddTo<platform::CPUDeviceContext, T> { ...@@ -128,7 +128,7 @@ struct SelectedRowsAddTo<platform::CPUDeviceContext, T> {
auto* in2_value = input2->mutable_value(); auto* in2_value = input2->mutable_value();
// concat rows // concat rows
in2_rows.insert(in2_rows.end(), in1_rows.begin(), in1_rows.end()); in2_rows.Extend(in1_rows.begin(), in1_rows.end());
auto in1_place = input1.place(); auto in1_place = input1.place();
PADDLE_ENFORCE(platform::is_cpu_place(in1_place)); PADDLE_ENFORCE(platform::is_cpu_place(in1_place));
......
...@@ -126,7 +126,8 @@ struct SelectedRowsAddTensor<platform::CUDADeviceContext, T> { ...@@ -126,7 +126,8 @@ struct SelectedRowsAddTensor<platform::CUDADeviceContext, T> {
dim3 grid(1, in1_rows.size()); dim3 grid(1, in1_rows.size());
SelectedRowsAddTensorKernel< SelectedRowsAddTensorKernel<
T, block_size><<<grid, threads, 0, context.stream()>>>( T, block_size><<<grid, threads, 0, context.stream()>>>(
in1_data, in1_rows.cuda_data(), out_data, in1_row_numel); in1_data, in1_rows.CUDAData(context.GetPlace()), out_data,
in1_row_numel);
auto out_eigen = framework::EigenVector<T>::Flatten(*output); auto out_eigen = framework::EigenVector<T>::Flatten(*output);
auto in2_eigen = framework::EigenVector<T>::Flatten(input2); auto in2_eigen = framework::EigenVector<T>::Flatten(input2);
...@@ -153,7 +154,9 @@ struct SelectedRowsAddTo<platform::CUDADeviceContext, T> { ...@@ -153,7 +154,9 @@ struct SelectedRowsAddTo<platform::CUDADeviceContext, T> {
auto* in2_value = input2->mutable_value(); auto* in2_value = input2->mutable_value();
// concat rows // concat rows
in2_rows.insert(in2_rows.end(), in1_rows.begin(), in1_rows.end()); if (in1_rows.size()) {
in2_rows.Extend(in1_rows.begin(), in1_rows.end());
}
auto in1_place = input1.place(); auto in1_place = input1.place();
PADDLE_ENFORCE(platform::is_gpu_place(in1_place)); PADDLE_ENFORCE(platform::is_gpu_place(in1_place));
...@@ -216,7 +219,8 @@ struct SelectedRowsAddToTensor<platform::CUDADeviceContext, T> { ...@@ -216,7 +219,8 @@ struct SelectedRowsAddToTensor<platform::CUDADeviceContext, T> {
dim3 grid(1, in1_rows.size()); dim3 grid(1, in1_rows.size());
SelectedRowsAddToTensorKernel< SelectedRowsAddToTensorKernel<
T, block_size><<<grid, threads, 0, context.stream()>>>( T, block_size><<<grid, threads, 0, context.stream()>>>(
in1_data, in1_rows.cuda_data(), in2_data, in1_row_numel); in1_data, in1_rows.CUDAData(context.GetPlace()), in2_data,
in1_row_numel);
} }
}; };
...@@ -283,9 +287,10 @@ struct MergeAdd<platform::CUDADeviceContext, T> { ...@@ -283,9 +287,10 @@ struct MergeAdd<platform::CUDADeviceContext, T> {
MergeAddKernel< MergeAddKernel<
T, 256><<<grid1, threads, 0, T, 256><<<grid1, threads, 0,
reinterpret_cast<const platform::CUDADeviceContext&>(context) reinterpret_cast<const platform::CUDADeviceContext&>(context)
.stream()>>>(input_data, input_rows.cuda_data(), out_data, .stream()>>>(
out.mutable_rows()->cuda_data(), input_data, input_rows.CUDAData(context.GetPlace()), out_data,
out.rows().size(), input_width); out.mutable_rows()->CUDAMutableData(context.GetPlace()),
out.rows().size(), input_width);
return out; return out;
} }
}; };
......
...@@ -45,7 +45,6 @@ class CopyMatrixRowsFunctor<platform::CUDADeviceContext, T> { ...@@ -45,7 +45,6 @@ class CopyMatrixRowsFunctor<platform::CUDADeviceContext, T> {
const framework::Tensor& src, const framework::Tensor& src,
framework::Vector<size_t> index_lod, framework::Tensor& dst, framework::Vector<size_t> index_lod, framework::Tensor& dst,
bool is_src_index) { bool is_src_index) {
size_t* index = index_lod.cuda_data();
auto src_dims = src.dims(); auto src_dims = src.dims();
auto dst_dims = dst.dims(); auto dst_dims = dst.dims();
PADDLE_ENFORCE_EQ(src_dims.size(), 2, PADDLE_ENFORCE_EQ(src_dims.size(), 2,
...@@ -63,7 +62,8 @@ class CopyMatrixRowsFunctor<platform::CUDADeviceContext, T> { ...@@ -63,7 +62,8 @@ class CopyMatrixRowsFunctor<platform::CUDADeviceContext, T> {
dim3 grid(8, 1); dim3 grid(8, 1);
auto stream = context.stream(); auto stream = context.stream();
CopyMatrixRowsKernel<T, 128, 8, 8><<<grid, threads, 0, stream>>>( CopyMatrixRowsKernel<T, 128, 8, 8><<<grid, threads, 0, stream>>>(
src_data, dst_data, index, height, width, is_src_index); src_data, dst_data, index_lod.CUDAData(context.GetPlace()), height,
width, is_src_index);
} }
}; };
......
...@@ -121,12 +121,12 @@ class PaddingLoDTensorFunctor<platform::CUDADeviceContext, T> { ...@@ -121,12 +121,12 @@ class PaddingLoDTensorFunctor<platform::CUDADeviceContext, T> {
if (norm_by_times) { if (norm_by_times) {
SequencePaddingKernel<T, 1, 1><<<grid, threads, 0, context.stream()>>>( SequencePaddingKernel<T, 1, 1><<<grid, threads, 0, context.stream()>>>(
padding_data, const_cast<T*>(seq_data), padding_data, const_cast<T*>(seq_data),
abs_offset_lod[level].cuda_data(), sequence_width, abs_offset_lod[level].CUDAData(context.GetPlace()), sequence_width,
max_sequence_length, num_sequences); max_sequence_length, num_sequences);
} else { } else {
SequencePaddingKernel<T, 0, 1><<<grid, threads, 0, context.stream()>>>( SequencePaddingKernel<T, 0, 1><<<grid, threads, 0, context.stream()>>>(
padding_data, const_cast<T*>(seq_data), padding_data, const_cast<T*>(seq_data),
abs_offset_lod[level].cuda_data(), sequence_width, abs_offset_lod[level].CUDAData(context.GetPlace()), sequence_width,
max_sequence_length, num_sequences); max_sequence_length, num_sequences);
} }
} }
...@@ -196,12 +196,12 @@ class UnpaddingLoDTensorFunctor<platform::CUDADeviceContext, T> { ...@@ -196,12 +196,12 @@ class UnpaddingLoDTensorFunctor<platform::CUDADeviceContext, T> {
if (norm_by_times) { if (norm_by_times) {
SequencePaddingKernel<T, 1, 0><<<grid, threads, 0, context.stream()>>>( SequencePaddingKernel<T, 1, 0><<<grid, threads, 0, context.stream()>>>(
const_cast<T*>(padding_data), seq_data, const_cast<T*>(padding_data), seq_data,
abs_offset_lod[level].cuda_data(), sequence_width, abs_offset_lod[level].CUDAData(context.GetPlace()), sequence_width,
max_sequence_length, num_sequences); max_sequence_length, num_sequences);
} else { } else {
SequencePaddingKernel<T, 0, 0><<<grid, threads, 0, context.stream()>>>( SequencePaddingKernel<T, 0, 0><<<grid, threads, 0, context.stream()>>>(
const_cast<T*>(padding_data), seq_data, const_cast<T*>(padding_data), seq_data,
abs_offset_lod[level].cuda_data(), sequence_width, abs_offset_lod[level].CUDAData(context.GetPlace()), sequence_width,
max_sequence_length, num_sequences); max_sequence_length, num_sequences);
} }
} }
......
...@@ -73,7 +73,8 @@ class MaxSeqPoolFunctor<platform::CUDADeviceContext, T> { ...@@ -73,7 +73,8 @@ class MaxSeqPoolFunctor<platform::CUDADeviceContext, T> {
dim3 grid(num_seq, 1); dim3 grid(num_seq, 1);
auto stream = context.stream(); auto stream = context.stream();
KeMaxSequencePool<T><<<grid, threads, 0, stream>>>( KeMaxSequencePool<T><<<grid, threads, 0, stream>>>(
in_data, starts.cuda_data(), out_data, max_index, num_seq, dim); in_data, starts.CUDAData(context.GetPlace()), out_data, max_index,
num_seq, dim);
} }
}; };
......
...@@ -46,7 +46,8 @@ class ScaleLoDTensorFunctor<platform::CUDADeviceContext, T> { ...@@ -46,7 +46,8 @@ class ScaleLoDTensorFunctor<platform::CUDADeviceContext, T> {
SequenceScaleKernel<T, PADDLE_CUDA_NUM_THREADS><<< SequenceScaleKernel<T, PADDLE_CUDA_NUM_THREADS><<<
num_seq, PADDLE_CUDA_NUM_THREADS, 0, context.stream()>>>( num_seq, PADDLE_CUDA_NUM_THREADS, 0, context.stream()>>>(
seq_data, abs_offset_lod[level].cuda_data(), scales, seq_width); seq_data, abs_offset_lod[level].CUDAMutableData(context.GetPlace()),
scales, seq_width);
} }
}; };
......
...@@ -79,9 +79,6 @@ inline void CopyOrShare(const framework::Variable &src, ...@@ -79,9 +79,6 @@ inline void CopyOrShare(const framework::Variable &src,
dst->GetMutable<LoDTensor>()->set_lod(src.Get<LoDTensor>().lod()); dst->GetMutable<LoDTensor>()->set_lod(src.Get<LoDTensor>().lod());
} else { } else {
Copy(src.Get<LoDTensor>(), dst_place, dst->GetMutable<LoDTensor>()); Copy(src.Get<LoDTensor>(), dst_place, dst->GetMutable<LoDTensor>());
framework::LoD lod(src.Get<LoDTensor>().lod());
lod.CopyToPeer(dst_place);
dst->GetMutable<LoDTensor>()->set_lod(lod);
} }
} else if (src.IsType<SelectedRows>()) { } else if (src.IsType<SelectedRows>()) {
auto &src_sr = src.Get<SelectedRows>(); auto &src_sr = src.Get<SelectedRows>();
...@@ -92,9 +89,6 @@ inline void CopyOrShare(const framework::Variable &src, ...@@ -92,9 +89,6 @@ inline void CopyOrShare(const framework::Variable &src,
dst_sr->set_rows(src_sr.rows()); dst_sr->set_rows(src_sr.rows());
} else { } else {
Copy(src_sr.value(), dst_place, dst_sr->mutable_value()); Copy(src_sr.value(), dst_place, dst_sr->mutable_value());
framework::Vector<int64_t> lod(src_sr.rows());
lod.CopyToPeer(dst_place);
dst_sr->set_rows(lod);
} }
} else { } else {
PADDLE_THROW("Expect LoDTensor/SelectedRows, get %s", src.Type().name()); PADDLE_THROW("Expect LoDTensor/SelectedRows, get %s", src.Type().name());
...@@ -152,9 +146,6 @@ class ParallelDoOp : public framework::OperatorBase { ...@@ -152,9 +146,6 @@ class ParallelDoOp : public framework::OperatorBase {
auto *sub_scope = sub_scopes[i]; auto *sub_scope = sub_scopes[i];
auto *dst = sub_scope->Var(param)->GetMutable<LoDTensor>(); auto *dst = sub_scope->Var(param)->GetMutable<LoDTensor>();
framework::Copy(src, place, dst); framework::Copy(src, place, dst);
framework::LoD lod(src.lod());
lod.CopyToPeer(place);
dst->set_lod(lod);
} }
} }
WaitOnPlaces(places); WaitOnPlaces(places);
......
...@@ -307,7 +307,7 @@ class RowConvKernel<platform::CUDADeviceContext, T> ...@@ -307,7 +307,7 @@ class RowConvKernel<platform::CUDADeviceContext, T>
int input_dim = X->dims()[1]; int input_dim = X->dims()[1];
int num_sequence = batch_indices.size() - 1; int num_sequence = batch_indices.size() - 1;
int future_context = Filter->dims()[0]; int future_context = Filter->dims()[0];
size_t *idx = batch_indices.cuda_data(); size_t *idx = batch_indices.CUDAMutableData(context.GetPlace());
auto stream = context.cuda_device_context().stream(); auto stream = context.cuda_device_context().stream();
if (future_context <= 32) { if (future_context <= 32) {
...@@ -345,7 +345,7 @@ class RowConvGradKernel<platform::CUDADeviceContext, T> ...@@ -345,7 +345,7 @@ class RowConvGradKernel<platform::CUDADeviceContext, T>
int input_dim = X->dims()[1]; int input_dim = X->dims()[1];
int num_sequence = batch_indices.size() - 1; int num_sequence = batch_indices.size() - 1;
int future_context = Filter->dims()[0]; int future_context = Filter->dims()[0];
size_t *idx = batch_indices.cuda_data(); size_t *idx = batch_indices.CUDAMutableData(context.GetPlace());
auto &device_ctx = context.cuda_device_context(); auto &device_ctx = context.cuda_device_context();
math::SetConstant<platform::CUDADeviceContext, T> zero; math::SetConstant<platform::CUDADeviceContext, T> zero;
......
...@@ -87,8 +87,7 @@ class SequenceEraseOpCUDAKernel : public framework::OpKernel<T> { ...@@ -87,8 +87,7 @@ class SequenceEraseOpCUDAKernel : public framework::OpKernel<T> {
// Copy LoD to GPU // Copy LoD to GPU
auto lod0 = lod[0]; auto lod0 = lod[0];
auto lod_len = lod0.size(); auto lod_len = lod0.size();
thrust::device_vector<size_t> dev_in_lod = lod0; const size_t* dev_in_lod_ptr = lod0.CUDAData(ctx.GetPlace());
size_t* dev_in_lod_ptr = thrust::raw_pointer_cast(dev_in_lod.data());
// Calc output LoD // Calc output LoD
thrust::device_vector<size_t> dev_out_lod(lod_len); thrust::device_vector<size_t> dev_out_lod(lod_len);
......
...@@ -102,8 +102,8 @@ class SGDOpCUDAKernel : public framework::OpKernel<T> { ...@@ -102,8 +102,8 @@ class SGDOpCUDAKernel : public framework::OpKernel<T> {
dim3 grid(1, in_rows.size()); dim3 grid(1, in_rows.size());
SparseSGDFunctorKernel< SparseSGDFunctorKernel<
T, 256><<<grid, threads, 0, ctx.cuda_device_context().stream()>>>( T, 256><<<grid, threads, 0, ctx.cuda_device_context().stream()>>>(
in_data, in_rows.cuda_data(), learning_rate->data<T>(), out_data, in_data, in_rows.CUDAData(ctx.GetPlace()), learning_rate->data<T>(),
in_row_numel); out_data, in_row_numel);
} else { } else {
PADDLE_THROW("Unsupported Variable Type of Grad"); PADDLE_THROW("Unsupported Variable Type of Grad");
......
...@@ -137,8 +137,8 @@ class TargetAssignKernel : public framework::OpKernel<T> { ...@@ -137,8 +137,8 @@ class TargetAssignKernel : public framework::OpKernel<T> {
PADDLE_ENFORCE_EQ(gt_lod.data()[i], gt_label_lod.data()[i]); PADDLE_ENFORCE_EQ(gt_lod.data()[i], gt_label_lod.data()[i]);
} }
size_t* gt_lod_data = gt_lod.data(ctx.GetPlace()); size_t* gt_lod_data = gt_lod.MutableData(ctx.GetPlace());
size_t* neg_lod_data = neg_lod.data(ctx.GetPlace()); size_t* neg_lod_data = neg_lod.MutableData(ctx.GetPlace());
TargetAssignFunctor<T> functor(box_data, label_data, match_idx_data, TargetAssignFunctor<T> functor(box_data, label_data, match_idx_data,
gt_lod_data, background_label, num, gt_lod_data, background_label, num,
......
...@@ -20,6 +20,7 @@ limitations under the License. */ ...@@ -20,6 +20,7 @@ limitations under the License. */
#include "paddle/memory/memory.h" #include "paddle/memory/memory.h"
int main(int argc, char** argv) { int main(int argc, char** argv) {
testing::InitGoogleTest(&argc, argv);
std::vector<char*> new_argv; std::vector<char*> new_argv;
std::string gflags_env; std::string gflags_env;
for (int i = 0; i < argc; ++i) { for (int i = 0; i < argc; ++i) {
...@@ -35,7 +36,6 @@ int main(int argc, char** argv) { ...@@ -35,7 +36,6 @@ int main(int argc, char** argv) {
int new_argc = static_cast<int>(new_argv.size()); int new_argc = static_cast<int>(new_argv.size());
char** new_argv_address = new_argv.data(); char** new_argv_address = new_argv.data();
google::ParseCommandLineFlags(&new_argc, &new_argv_address, false); google::ParseCommandLineFlags(&new_argc, &new_argv_address, false);
testing::InitGoogleTest(&argc, argv);
paddle::memory::Used(paddle::platform::CPUPlace()); paddle::memory::Used(paddle::platform::CPUPlace());
#ifdef PADDLE_WITH_CUDA #ifdef PADDLE_WITH_CUDA
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册