Commit ef1aba39 authored by Yu Yang

Rewrite mixed_vector.h

Parent b1869f16
paddle/operators/check_t.save
paddle/operators/check_tensor.ls
paddle/operators/tensor.save
python/paddle/v2/fluid/tests/book/image_classification_resnet.inference.model/
python/paddle/v2/fluid/tests/book/image_classification_vgg.inference.model/
python/paddle/v2/fluid/tests/book/label_semantic_roles.inference.model/
*.DS_Store
build/
build_doc/
......
@@ -181,7 +181,8 @@ elseif(CMAKE_BUILD_TYPE STREQUAL "Release")
 elseif(CMAKE_BUILD_TYPE STREQUAL "RelWithDebInfo")
   list(APPEND CUDA_NVCC_FLAGS ${CMAKE_CXX_FLAGS_RELWITHDEBINFO})
 elseif(CMAKE_BUILD_TYPE STREQUAL "MinSizeRel")
-  list(APPEND CUDA_NVCC_FLAGS ${CMAKE_CXX_FLAGS_MINSIZEREL})
+  # nvcc 9 does not support -Os. Use Release flags instead
+  list(APPEND CUDA_NVCC_FLAGS ${CMAKE_CXX_FLAGS_RELEASE})
 endif()
 mark_as_advanced(CUDA_BUILD_CUBIN CUDA_BUILD_EMULATION CUDA_VERBOSE_BUILD)
......
@@ -46,29 +46,7 @@ namespace framework {
  * 0 2 4 7
  * 0 2 5 7 10 12 15 20
  */
-struct LoD : public std::vector<Vector<size_t>> {
-  using std::vector<Vector<size_t>>::vector;
-
-  platform::Place place() const {
-    if (this->size() == 0) {
-      // Not Initialze Yet.
-      return platform::CPUPlace();
-    } else {
-      return this->front().place();
-    }
-  }
-
-  void CopyFromCUDA() {
-    for (auto it = this->begin(); it != this->end(); ++it) {
-      it->CopyFromCUDA();
-    }
-  }
-
-  void CopyToPeer(platform::Place place) {
-    for (auto it = this->begin(); it != this->end(); ++it) {
-      it->CopyToPeer(place);
-    }
-  }
-};
+using LoD = std::vector<Vector<size_t>>;
 
 std::ostream& operator<<(std::ostream& os, const LoD& lod);
 std::ostream& operator<<(std::ostream& os, const LoDTensor& t);
......
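Note on the LoD change above: LoD is now a plain alias of std::vector<Vector<size_t>>, so the explicit place()/CopyFromCUDA()/CopyToPeer() helpers disappear and each level synchronizes itself through the new Vector<T>. A minimal sketch of the intended call pattern, not part of the patch, assuming the device context pool has already been filled (e.g. by InitDevices()) and that the function name is illustrative:

```cpp
#include "paddle/framework/lod_tensor.h"
#include "paddle/platform/place.h"

// Sketch only: build a two-level LoD and hand one level to a CUDA kernel.
void LoDUsageSketch() {
  paddle::framework::LoD lod;
  lod.push_back(paddle::framework::Vector<size_t>({0, 2, 4, 7}));
  lod.push_back(paddle::framework::Vector<size_t>({0, 2, 5, 7, 10, 12, 15, 20}));

  paddle::platform::CUDAPlace gpu(0);
  // Read-only device pointer; the level copies itself to the GPU lazily.
  const size_t* dev_level0 = lod[0].CUDAData(gpu);
  (void)dev_level0;

  // Host access still goes through operator[]; it copies back only if the
  // device copy was marked dirty by CUDAMutableData().
  size_t first_offset = lod[0][0];
  (void)first_offset;
}
```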
@@ -20,6 +20,7 @@
 #include "paddle/platform/assert.h"
 #include <gtest/gtest.h>
+#include <paddle/platform/place.h>
 
 __global__ void test(size_t* a, int size) {
   for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < size;
@@ -36,10 +37,9 @@ TEST(LoD, data) {
   lod.push_back(std::vector<size_t>({0, 1, 6, 8, 10, 11}));
   auto& v = lod[0];
-  test<<<1, 1>>>(v.cuda_data(), v.size());
+  paddle::platform::CUDAPlace gpu(0);
+  test<<<1, 1>>>(v.CUDAMutableData(gpu), v.size());
   cudaDeviceSynchronize();
-  v.CopyFromCUDA();
   for (size_t i = 0; i < v.size(); ++i) {
     EXPECT_EQ(v[i], i * 2);
   }
@@ -63,9 +63,8 @@ TEST(LoDTensor, LoDInGPU) {
   auto lod = lod_tensor.lod();
-  test<<<1, 8>>>(lod[0].cuda_data(), lod[0].size());
+  test<<<1, 8>>>(lod[0].CUDAMutableData(place), lod[0].size());
   cudaDeviceSynchronize();
-  lod.CopyFromCUDA();
   for (size_t i = 0; i < src_lod[0].size(); ++i) {
     EXPECT_EQ(lod[0].data()[i], src_lod[0].data()[i] * 2);
@@ -17,176 +17,297 @@
 #include <initializer_list>
 #include <vector>
-#include "paddle/memory/memcpy.h"
-#include "paddle/memory/memory.h"
-#include "paddle/platform/device_context.h"
-#include "paddle/platform/enforce.h"
-#include "paddle/platform/place.h"
+
+#include "paddle/framework/tensor.h"
+#include "paddle/framework/tensor_util.h"
+
+#include "glog/logging.h"
 
 namespace paddle {
 namespace framework {
 
-/**
- * @brief Vector support both cpu and gpu.
- * host vector lifetime is same with Vector
- * device vector is lazily malloc and modified.
- */
-template <typename T>
-class Vector : public std::vector<T> {
- public:
-  using std::vector<T>::vector;
-
-  Vector() {}
-  Vector(const std::vector<T> &v) : std::vector<T>(v) {}  // NOLINT
-
-  inline platform::Place place() const { return place_; }
-
-  /*! Return a pointer to constant memory block. */
-  inline const T *data(platform::Place place) const;
-
-  /*! Return a pointer to mutable memory block. */
-  inline T *mutable_data(platform::Place place);
-
-  // TODO(dzhwinter): below interfaces should be removed
-  /* Get device vector */
-  T *cuda_data() {
-    CopyToCUDA();
-    PADDLE_ENFORCE_NOT_NULL(
-        cuda_ptr_, "No data or Insufficient CUDA memory to allocation");
-    return static_cast<T *>(cuda_ptr_.get());
-  }
-
-  /* Get host vector */
-  T *data() { return std::vector<T>::data(); }
-  const T *data() const { return std::vector<T>::data(); }
-
-  T *data(const platform::Place &place) {
-    if (platform::is_cpu_place(place)) {
-      return data();
-    } else {
-      return cuda_data();
-    }
-  }
-
-  /* Synchronize host vector to device vector */
-  void CopyToCUDA();
-  /* Synchronize device vector to host vector */
-  void CopyFromCUDA();
-  /* Switch device vector location */
-  void CopyToPeer(platform::Place);
-
- private:
-  std::shared_ptr<void> cuda_ptr_;
-  size_t cuda_size_ = 0;  // device vector numel
-  platform::CUDAPlace place_;
-};
-
-template <typename T>
-inline const T *Vector<T>::data(platform::Place place) const {
-  if (platform::is_cpu_place(place)) {
-    return std::vector<T>::data();
-  } else if (platform::is_gpu_place(place)) {
-    if (cuda_ptr_ == nullptr) {
-      return nullptr;
-    }
-    if (boost::get<platform::CUDAPlace>(place) == place_) {
-      return static_cast<const T *>(cuda_ptr_.get());
-    } else {
-      PADDLE_THROW(
-          "Unmatched place. Please use `mutable_data` copy lod to the target "
-          "Place first.");
-    }
-  } else {
-    PADDLE_THROW("Unsupport Place.");
-  }
-}
-
-template <typename T>
-inline T *Vector<T>::mutable_data(platform::Place place) {
-  if (platform::is_cpu_place(place)) {
-    return std::vector<T>::data();
-  } else if (platform::is_gpu_place(place)) {
-    if (boost::get<platform::CUDAPlace>(place) != place_) {
-      place_ = boost::get<platform::CUDAPlace>(place);
-    }
-#ifdef PADDLE_WITH_CUDA
-    if (cuda_size_ < this->size() || cuda_ptr_ == nullptr) {
-      cuda_ptr_.reset(
-          memory::Alloc<platform::CUDAPlace>(place_, this->size() * sizeof(T)),
-          memory::PlainDeleter<void, platform::CUDAPlace>(place_));
-    }
-    cuda_size_ = this->size();
-    platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
-    auto *ctx = pool.GetByPlace(place_);
-    memory::Copy(place_, cuda_ptr_.get(), platform::CPUPlace(),
-                 static_cast<const void *>(this->data()),
-                 this->size() * sizeof(T), ctx->stream());
-    ctx->Wait();
-    return static_cast<T *>(cuda_ptr_.get());
-#else
-    return nullptr;
-#endif
-  } else {
-    PADDLE_THROW("Unsupport Place.");
-  }
-}
-
-template <typename T>
-void Vector<T>::CopyToCUDA() {
-#ifdef PADDLE_WITH_CUDA
-  if (cuda_size_ < this->size() || cuda_ptr_ == nullptr) {
-    cuda_ptr_.reset(
-        memory::Alloc<platform::CUDAPlace>(place_, this->size() * sizeof(T)),
-        memory::PlainDeleter<void, platform::CUDAPlace>(place_));
-  }
-  cuda_size_ = this->size();
-  platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
-  auto *ctx = pool.GetByPlace(place_);
-  memory::Copy(place_, cuda_ptr_.get(), platform::CPUPlace(),
-               static_cast<const void *>(this->data()),
-               this->size() * sizeof(T), ctx->stream());
-  ctx->Wait();
-#endif
-}
-
-template <typename T>
-void Vector<T>::CopyFromCUDA() {
-#ifdef PADDLE_WITH_CUDA
-  if (cuda_ptr_ == nullptr) {
-    LOG(WARNING) << "No uncommitted cuda data.";
-    return;
-  }
-  this->resize(cuda_size_);
-  platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
-  auto *ctx = pool.GetByPlace(place_);
-  memory::Copy(platform::CPUPlace(), static_cast<void *>(this->data()), place_,
-               static_cast<const void *>(cuda_ptr_.get()),
-               this->size() * sizeof(T), ctx->stream());
-  ctx->Wait();
-#endif
-}
-
-template <typename T>
-void Vector<T>::CopyToPeer(platform::Place place) {
-#ifdef PADDLE_WITH_CUDA
-  if (boost::get<platform::CUDAPlace>(place) != place_) {
-    place_ = boost::get<platform::CUDAPlace>(place);
-  }
-  if (cuda_size_ < this->size() || cuda_ptr_ == nullptr) {
-    cuda_ptr_.reset(
-        memory::Alloc<platform::CUDAPlace>(place_, this->size() * sizeof(T)),
-        memory::PlainDeleter<void, platform::CUDAPlace>(place_));
-  }
-  cuda_size_ = this->size();
-  platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
-  auto *ctx = pool.GetByPlace(place_);
-  memory::Copy(place_, cuda_ptr_.get(), platform::CPUPlace(),
-               static_cast<const void *>(this->data()),
-               this->size() * sizeof(T), ctx->stream());
-  ctx->Wait();
-#endif
-}
+template <typename T>
+class Vector {
+ public:
+  using value_type = T;
+
+  Vector() {
+    size_ = 0;
+    flag_ = kDataInCPU;
+  }
+
+  explicit Vector(size_t count, const T& value = T()) {
+    resize(count);
+    T* ptr = begin();
+    for (size_t i = 0; i < count; ++i) {
+      ptr[i] = value;
+    }
+  }
+
+  Vector(std::initializer_list<T> init) {
+    InitByIter(init.size(), init.begin(), init.end());
+  }
+
+  template <typename U>
+  Vector(const std::vector<U>& dat) {  // NOLINT
+    InitByIter(dat.size(), dat.begin(), dat.end());
+  }
+
+  Vector(const Vector<T>& other) { this->operator=(other); }
+
+  Vector<T>& operator=(const Vector<T>& other) {
+    if (other.size() != 0) {
+      this->InitByIter(other.size(), other.begin(), other.end());
+    } else {
+      size_ = 0;
+      flag_ = kDataInCPU;
+    }
+    return *this;
+  }
+
+  Vector(Vector<T>&& other) {
+    this->size_ = other.size_;
+    this->flag_ = other.flag_;
+    if (other.cuda_vec_.capacity()) {
+      this->cuda_vec_.ShareDataWith(other.cuda_vec_);
+    }
+    if (other.cpu_vec_.capacity()) {
+      this->cpu_vec_.ShareDataWith(other.cpu_vec_);
+    }
+  }
+
+  T& operator[](size_t i) {
+    MutableCPU();
+    return const_cast<T*>(cpu_vec_.data<T>())[i];
+  }
+
+  const T& operator[](size_t i) const {
+    ImmutableCPU();
+    return cpu_vec_.data<T>()[i];
+  }
+
+  size_t size() const { return size_; }
+
+  T* begin() { return &this->operator[](0); }
+
+  T* end() { return &this->operator[](size()); }
+
+  T& front() { return *begin(); }
+
+  T& back() {
+    auto it = end();
+    --it;
+    return *it;
+  }
+
+  const T* begin() const { return &this->operator[](0); }
+  const T* end() const { return &this->operator[](size()); }
+
+  const T& back() const {
+    auto it = end();
+    --it;
+    return *it;
+  }
+
+  const T& front() const { return *begin(); }
+
+  template <typename Iter>
+  void assign(Iter begin, Iter end) {
+    InitByIter(end - begin, begin, end);
+  }
+
+  T* data() { return begin(); }
+
+  const T* data() const { return begin(); }
+
+  void push_back(T elem) {
+    if (size_ + 1 > capacity()) {
+      reserve((size_ + 1) << 1);
+    }
+    *end() = elem;
+    ++size_;
+  }
+
+  void resize(size_t size) {
+    if (size + 1 < capacity()) {
+      size_ = size;
+    } else {
+      MutableCPU();
+      Tensor cpu_tensor;
+      platform::Place cpu = platform::CPUPlace();
+      T* ptr = cpu_tensor.mutable_data<T>(
+          framework::make_ddim({static_cast<int64_t>(size)}), cpu);
+      const T* old_ptr =
+          cpu_vec_.capacity() == 0 ? nullptr : cpu_vec_.data<T>();
+      if (old_ptr != nullptr) {
+        std::copy(old_ptr, old_ptr + size_, ptr);
+      }
+      size_ = size;
+      cpu_vec_.ShareDataWith(cpu_tensor);
+    }
+  }
+
+  const T* CUDAData(platform::Place place) const {
+    PADDLE_ENFORCE(platform::is_gpu_place(place),
+                   "CUDA Data must on CUDA place");
+    ImmutableCUDA(place);
+    return cuda_vec_.data<T>();
+  }
+
+  T* CUDAMutableData(platform::Place place) {
+    const T* ptr = CUDAData(place);
+    flag_ = kDirty | kDataInCUDA;
+    return const_cast<T*>(ptr);
+  }
+
+  template <typename It>
+  void Extend(It begin, It end) {
+    size_t pre_size = size_;
+    resize(pre_size + (end - begin));
+    T* ptr = this->begin() + pre_size;
+    for (; begin < end; ++begin, ++ptr) {
+      *ptr = *begin;
+    }
+  }
+
+  void clear() {
+    size_ = 0;
+    flag_ = kDirty | kDataInCPU;
+  }
+
+  size_t capacity() const {
+    return cpu_vec_.capacity() / SizeOfType(typeid(T));
+  }
+
+  void reserve(size_t size) {
+    size_t pre_size = size_;
+    resize(size);
+    resize(pre_size);
+  }
+
+  const T* Data(platform::Place place) const {
+    if (platform::is_gpu_place(place)) {
+      return CUDAData(place);
+    } else {
+      return data();
+    }
+  }
+
+  T* MutableData(platform::Place place) {
+    if (platform::is_gpu_place(place)) {
+      return CUDAMutableData(place);
+    } else {
+      return data();
+    }
+  }
+
+  operator std::vector<T>() const {
+    std::vector<T> result;
+    result.resize(size());
+    std::copy(begin(), end(), result.begin());
+    return result;
+  }
+
+  bool operator==(const Vector<T>& other) const {
+    if (size() != other.size()) return false;
+    for (auto it1 = begin(), it2 = other.begin(); it1 < end(); ++it1, ++it2) {
+      if (*it1 != *it2) {
+        return false;
+      }
+    }
+    return true;
+  }
+
+ private:
+  template <typename Iter>
+  void InitByIter(size_t size, Iter begin, Iter end) {
+    platform::Place cpu = platform::CPUPlace();
+    T* ptr = this->cpu_vec_.template mutable_data<T>(
+        framework::make_ddim({static_cast<int64_t>(size)}), cpu);
+    for (size_t i = 0; i < size; ++i) {
+      *ptr++ = *begin++;
+    }
+    flag_ = kDataInCPU | kDirty;
+    size_ = size;
+  }
+
+  enum DataFlag { kDataInCPU = 0x01, kDataInCUDA = 0x02, kDirty = 0x10 };
+
+  void MutableCPU() {
+    if (IsInCUDA() && IsDirty()) {
+      // COPY GPU Data To CPU
+      Copy(cuda_vec_, platform::CPUPlace(), &cpu_vec_);
+      WaitPlace(cuda_vec_.place());
+    }
+    flag_ = kDirty | kDataInCPU;
+  }
+
+  void ImmutableCUDA(platform::Place place) const {
+    if (IsDirty()) {
+      if (IsInCPU()) {
+        Copy(cpu_vec_, boost::get<platform::CUDAPlace>(place), &cuda_vec_);
+        WaitPlace(place);
+        UnsetFlag(kDirty);
+        SetFlag(kDataInCUDA);
+      } else if (IsInCUDA() && !(place == cuda_vec_.place())) {
+        framework::Tensor tmp;
+        Copy(cuda_vec_, boost::get<platform::CUDAPlace>(place), &tmp);
+        WaitPlace(cuda_vec_.place());
+        cuda_vec_.ShareDataWith(tmp);
+        // Still dirty
+      } else {
+        // Dirty && DataInCUDA && Device is same
+        // Do nothing
+      }
+    } else {
+      if (!IsInCUDA()) {
+        // Even data is not dirty. However, data is not in CUDA. Copy data.
+        Copy(cpu_vec_, boost::get<platform::CUDAPlace>(place), &cuda_vec_);
+        WaitPlace(place);
+        SetFlag(kDataInCUDA);
+      } else if (!(place == cuda_vec_.place())) {
+        framework::Tensor tmp;
+        Copy(cuda_vec_, boost::get<platform::CUDAPlace>(place), &tmp);
+        WaitPlace(cuda_vec_.place());
+        cuda_vec_.ShareDataWith(tmp);
+      } else {
+        // Not Dirty && DataInCUDA && Device is same
+        // Do nothing.
+      }
+    }
+  }
+
+  void ImmutableCPU() const {
+    if (IsDirty() &&
+        !IsInCPU()) {  // If data has been changed in CUDA, or CPU has no data.
+      Copy(cuda_vec_, platform::CPUPlace(), &cpu_vec_);
+      WaitPlace(cuda_vec_.place());
+      UnsetFlag(kDirty);
+    }
+    SetFlag(kDataInCPU);
+  }
+
+  void UnsetFlag(int flag) const { flag_ &= ~flag; }
+  void SetFlag(int flag) const { flag_ |= flag; }
+
+  bool IsDirty() const { return flag_ & kDirty; }
+
+  bool IsInCUDA() const { return flag_ & kDataInCUDA; }
+
+  bool IsInCPU() const { return flag_ & kDataInCPU; }
+
+  static void WaitPlace(const platform::Place place) {
+    if (platform::is_gpu_place(place)) {
+      platform::DeviceContextPool::Instance()
+          .Get(boost::get<platform::CUDAPlace>(place))
+          ->Wait();
+    }
+  }
+
+  mutable int flag_;
+  mutable Tensor cpu_vec_;
+  mutable Tensor cuda_vec_;
+  size_t size_;
+};
 
 }  // namespace framework
 }  // namespace paddle
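To make the new Vector<T> state machine above easier to review: the data lives in two Tensors (cpu_vec_ and cuda_vec_), flag_ records where the freshest copy is, CUDAData() returns a read-only device pointer and lazily copies CPU to GPU, and CUDAMutableData() returns the same pointer but marks the device copy dirty so the next host access copies back. A hedged usage sketch, assuming InitDevices() has already filled the DeviceContextPool; the function name below is made up:

```cpp
#include "paddle/framework/mixed_vector.h"
#include "paddle/platform/place.h"

// Sketch only; not part of the patch.
void MixedVectorUsageSketch() {
  paddle::framework::Vector<size_t> vec({1, 2, 3});

  // Host-side mutation: stays on the CPU and marks the CPU copy dirty.
  vec.push_back(4);

  paddle::platform::CUDAPlace gpu(0);

  // Read-only device view: lazily copies CPU -> GPU and clears the dirty bit.
  const size_t* d_ro = vec.CUDAData(gpu);

  // Writable device view: same pointer, but flags the GPU copy as dirty, so
  // the next operator[]/data() on the host copies GPU -> CPU first.
  size_t* d_rw = vec.CUDAMutableData(gpu);

  // Place-agnostic accessors used by the operator call sites in this commit.
  const size_t* p1 = vec.Data(gpu);
  size_t* p2 = vec.MutableData(paddle::platform::CPUPlace());
  (void)d_ro; (void)d_rw; (void)p1; (void)p2;
}
```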
@@ -11,62 +11,3 @@
 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */
-
-#include <cuda.h>
-#include <cuda_runtime.h>
-#include "gtest/gtest.h"
-
-#include "paddle/framework/init.h"
-#include "paddle/framework/mixed_vector.h"
-
-using namespace paddle::framework;
-using namespace paddle::platform;
-using namespace paddle::memory;
-
-template <typename T>
-__global__ void test(T* data, int size) {
-  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < size;
-       i += blockDim.x * gridDim.x) {
-    data[i] *= 2;
-  }
-}
-
-TEST(Vector, Normal) {
-  // fill the device context pool.
-  InitDevices();
-
-  Vector<size_t> vec({1, 2, 3});
-  size_t* ptr = vec.data();
-  for (size_t i = 0; i < vec.size(); ++i) {
-    EXPECT_EQ(vec[i], *(ptr + i));
-  }
-
-  vec.clear();
-  vec.CopyFromCUDA();
-
-  std::vector<size_t> v = {1, 2, 3};
-  for (size_t i = 0; i < v.size(); ++i) {
-    EXPECT_EQ(v[i], vec[i]);
-  }
-}
-
-TEST(Vector, MultipleCopy) {
-  InitDevices();
-  Vector<size_t> vec({1, 2, 3});
-  CUDAPlace place(0);
-  vec.mutable_data(place);
-  auto vec2 = Vector<size_t>(vec);
-  {
-    const size_t* ptr = vec2.data(CPUPlace());
-    for (size_t i = 0; i < vec2.size(); ++i) {
-      EXPECT_EQ(*(ptr + i), vec[i]);
-    }
-  }
-  test<size_t><<<3, 3>>>(vec2.mutable_data(place), vec2.size());
-  vec2.CopyFromCUDA();
-  {
-    const size_t* ptr = vec2.data(CPUPlace());
-    for (size_t i = 0; i < vec2.size(); ++i) {
-      EXPECT_EQ(*(ptr + i), vec[i] * 2);
-    }
-  }
-}
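The whole body of the old mixed_vector_test.cu is deleted above and this commit adds no replacement. For reference, a minimal test against the new API could look like the sketch below; it is not part of the patch, and the kernel and test names are made up:

```cpp
#include <cuda_runtime.h>
#include "gtest/gtest.h"
#include "paddle/framework/init.h"
#include "paddle/framework/mixed_vector.h"

template <typename T>
__global__ void Double(T* data, int size) {
  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < size;
       i += blockDim.x * gridDim.x) {
    data[i] *= 2;
  }
}

TEST(MixedVector, SketchOnly) {
  paddle::framework::InitDevices();  // fill the device context pool
  paddle::framework::Vector<size_t> vec({1, 2, 3});
  paddle::platform::CUDAPlace gpu(0);

  // Write through the device pointer; this marks the GPU copy dirty.
  Double<size_t><<<1, 32>>>(vec.CUDAMutableData(gpu), vec.size());
  cudaDeviceSynchronize();

  // Reading through operator[] copies the dirty GPU data back to the CPU.
  for (size_t i = 0; i < vec.size(); ++i) {
    EXPECT_EQ(vec[i], (i + 1) * 2);
  }
}
```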
@@ -128,6 +128,10 @@ class Tensor {
 
   inline void set_layout(const DataLayout layout) { layout_ = layout; }
 
+  size_t capacity() const {
+    return holder_ == nullptr ? 0UL : holder_->size() - offset_;
+  }
+
  private:
   friend class LoDTensor;
......
@@ -52,7 +52,7 @@ struct SizeOfTypeFunctor<HEAD, TAIL...> {
 };
 
 static inline size_t SizeOfType(std::type_index type) {
-  SizeOfTypeFunctor<int, float, double, int16_t, int64_t, bool> functor;
+  SizeOfTypeFunctor<int, float, double, int16_t, int64_t, bool, size_t> functor;
   size_t size = functor(type);
   PADDLE_ENFORCE(size != 0UL, "Cannot get size of type %s", type.name());
   return size;
......
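The two framework tweaks above work together: the new Tensor::capacity() reports the byte size of the current allocation (holder size minus the tensor's offset), and SizeOfType() now recognizes size_t because LoD offsets are stored in Tensors through Vector<size_t>. A hedged sketch of how mixed_vector.h combines them; ElementCapacity is an illustrative name, not part of the patch:

```cpp
#include <typeinfo>
#include "paddle/framework/tensor.h"

// Sketch: element capacity of a Tensor-backed buffer, mirroring
// Vector<T>::capacity(). `buf` stands in for the cpu_vec_ member.
template <typename T>
size_t ElementCapacity(const paddle::framework::Tensor& buf) {
  // capacity() is the new byte-level accessor; SizeOfType() must know T,
  // which is why size_t joins the functor's type list.
  return buf.capacity() / paddle::framework::SizeOfType(typeid(T));
}
```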
@@ -101,9 +101,9 @@ struct SparseAdagradFunctor<platform::CUDADeviceContext, T> {
     SparseAdagradFunctorKernel<
         T, 256><<<grid2, threads, 0,
                   reinterpret_cast<const platform::CUDADeviceContext&>(context)
-                      .stream()>>>(grad_merge_data, merge_rows.cuda_data(), lr,
-                                   param_data, moment_data, grad_width,
-                                   epsilon);
+                      .stream()>>>(
+        grad_merge_data, merge_rows.CUDAMutableData(context.GetPlace()), lr,
+        param_data, moment_data, grad_width, epsilon);
   }
 };
......
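The operator call-site changes in this commit (the adagrad kernel above and the kernels that follow) all apply one pattern: the old cuda_data() call, which was always mutable with an implicit place, becomes CUDAData(place) for pointers the kernel only reads and CUDAMutableData(place) for pointers it writes, with the place taken from the device context. A hedged sketch of the pattern; LaunchSomeKernel and the variable names are illustrative:

```cpp
#include "paddle/framework/mixed_vector.h"
#include "paddle/platform/device_context.h"

// Sketch of the call-site migration pattern; not part of the patch.
void LaunchSomeKernel(const paddle::platform::CUDADeviceContext& context,
                      paddle::framework::Vector<int64_t>* rows) {
  // Old API: rows->cuda_data()  (always mutable, place implicit).
  // New API: choose by intent, place comes from the device context.
  const int64_t* read_only = rows->CUDAData(context.GetPlace());
  int64_t* written_by_kernel = rows->CUDAMutableData(context.GetPlace());
  (void)read_only;
  (void)written_by_kernel;
  // SomeKernel<<<grid, threads, 0, context.stream()>>>(read_only, ...);
}
```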
@@ -201,7 +201,7 @@ class AdamOpKernel : public framework::OpKernel<T> {
       const T* grad_data = grad_tensor.template data<T>();
       int64_t* rows = nullptr;
       if (platform::is_gpu_place(ctx.GetPlace())) {
-        rows = grad_merge.mutable_rows()->cuda_data();
+        rows = grad_merge.mutable_rows()->CUDAMutableData(ctx.GetPlace());
       } else {
         rows = grad_merge.mutable_rows()->data();
       }
......
@@ -69,8 +69,9 @@ class CTCAlignOpCUDAKernel : public framework::OpKernel<T> {
     auto stream = ctx.cuda_device_context().stream();
     MergeAndDelCudaKernel<T><<<1, 1, 0, stream>>>(
-        num_tokens, tokens, num_seq, input_lod[level].cuda_data(), blank,
-        merge_repeated, dev_out_lod0_ptr, output_data);
+        num_tokens, tokens, num_seq,
+        input_lod[level].CUDAMutableData(ctx.GetPlace()), blank, merge_repeated,
+        dev_out_lod0_ptr, output_data);
 
     // set output lod
     std::vector<size_t> host_out_lod0(dev_out_lod0.begin(), dev_out_lod0.end());
......
@@ -125,7 +125,9 @@ class LookupTableGradCUDAKernel : public framework::OpKernel<T> {
     new_rows.resize(ids_dim[0]);
     auto gpu_place = boost::get<platform::CUDAPlace>(context.GetPlace());
 
-    memory::Copy(platform::CPUPlace(), new_rows.cuda_data(), gpu_place,
+    // TODO(yuyang18): Strange code here.
+    memory::Copy(platform::CPUPlace(),
+                 new_rows.CUDAMutableData(context.GetPlace()), gpu_place,
                  ids_data, ids_dim[0] * sizeof(int64_t), stream);
 
     d_table->set_rows(new_rows);
......
@@ -128,7 +128,7 @@ struct SelectedRowsAddTo<platform::CPUDeviceContext, T> {
     auto* in2_value = input2->mutable_value();
 
     // concat rows
-    in2_rows.insert(in2_rows.end(), in1_rows.begin(), in1_rows.end());
+    in2_rows.Extend(in1_rows.begin(), in1_rows.end());
 
     auto in1_place = input1.place();
     PADDLE_ENFORCE(platform::is_cpu_place(in1_place));
......
@@ -126,7 +126,8 @@ struct SelectedRowsAddTensor<platform::CUDADeviceContext, T> {
     dim3 grid(1, in1_rows.size());
     SelectedRowsAddTensorKernel<
         T, block_size><<<grid, threads, 0, context.stream()>>>(
-        in1_data, in1_rows.cuda_data(), out_data, in1_row_numel);
+        in1_data, in1_rows.CUDAData(context.GetPlace()), out_data,
+        in1_row_numel);
 
     auto out_eigen = framework::EigenVector<T>::Flatten(*output);
     auto in2_eigen = framework::EigenVector<T>::Flatten(input2);
@@ -153,7 +154,7 @@ struct SelectedRowsAddTo<platform::CUDADeviceContext, T> {
     auto* in2_value = input2->mutable_value();
 
     // concat rows
-    in2_rows.insert(in2_rows.end(), in1_rows.begin(), in1_rows.end());
+    in2_rows.Extend(in1_rows.begin(), in1_rows.end());
 
     auto in1_place = input1.place();
     PADDLE_ENFORCE(platform::is_gpu_place(in1_place));
@@ -216,7 +217,8 @@ struct SelectedRowsAddToTensor<platform::CUDADeviceContext, T> {
     dim3 grid(1, in1_rows.size());
     SelectedRowsAddToTensorKernel<
         T, block_size><<<grid, threads, 0, context.stream()>>>(
-        in1_data, in1_rows.cuda_data(), in2_data, in1_row_numel);
+        in1_data, in1_rows.CUDAData(context.GetPlace()), in2_data,
+        in1_row_numel);
   }
 };
@@ -283,9 +285,10 @@ struct MergeAdd<platform::CUDADeviceContext, T> {
     MergeAddKernel<
         T, 256><<<grid1, threads, 0,
                   reinterpret_cast<const platform::CUDADeviceContext&>(context)
-                      .stream()>>>(input_data, input_rows.cuda_data(), out_data,
-                                   out.mutable_rows()->cuda_data(),
-                                   out.rows().size(), input_width);
+                      .stream()>>>(
+        input_data, input_rows.CUDAData(context.GetPlace()), out_data,
+        out.mutable_rows()->CUDAMutableData(context.GetPlace()),
+        out.rows().size(), input_width);
 
     return out;
   }
 };
......
@@ -45,7 +45,6 @@ class CopyMatrixRowsFunctor<platform::CUDADeviceContext, T> {
                   const framework::Tensor& src,
                   framework::Vector<size_t> index_lod, framework::Tensor& dst,
                   bool is_src_index) {
-    size_t* index = index_lod.cuda_data();
     auto src_dims = src.dims();
     auto dst_dims = dst.dims();
     PADDLE_ENFORCE_EQ(src_dims.size(), 2,
@@ -63,7 +62,8 @@ class CopyMatrixRowsFunctor<platform::CUDADeviceContext, T> {
     dim3 grid(8, 1);
     auto stream = context.stream();
     CopyMatrixRowsKernel<T, 128, 8, 8><<<grid, threads, 0, stream>>>(
-        src_data, dst_data, index, height, width, is_src_index);
+        src_data, dst_data, index_lod.CUDAData(context.GetPlace()), height,
+        width, is_src_index);
   }
 };
......
@@ -121,12 +121,12 @@ class PaddingLoDTensorFunctor<platform::CUDADeviceContext, T> {
     if (norm_by_times) {
       SequencePaddingKernel<T, 1, 1><<<grid, threads, 0, context.stream()>>>(
           padding_data, const_cast<T*>(seq_data),
-          abs_offset_lod[level].cuda_data(), sequence_width,
+          abs_offset_lod[level].CUDAData(context.GetPlace()), sequence_width,
           max_sequence_length, num_sequences);
     } else {
       SequencePaddingKernel<T, 0, 1><<<grid, threads, 0, context.stream()>>>(
           padding_data, const_cast<T*>(seq_data),
-          abs_offset_lod[level].cuda_data(), sequence_width,
+          abs_offset_lod[level].CUDAData(context.GetPlace()), sequence_width,
           max_sequence_length, num_sequences);
     }
   }
@@ -196,12 +196,12 @@ class UnpaddingLoDTensorFunctor<platform::CUDADeviceContext, T> {
     if (norm_by_times) {
       SequencePaddingKernel<T, 1, 0><<<grid, threads, 0, context.stream()>>>(
           const_cast<T*>(padding_data), seq_data,
-          abs_offset_lod[level].cuda_data(), sequence_width,
+          abs_offset_lod[level].CUDAData(context.GetPlace()), sequence_width,
           max_sequence_length, num_sequences);
     } else {
       SequencePaddingKernel<T, 0, 0><<<grid, threads, 0, context.stream()>>>(
           const_cast<T*>(padding_data), seq_data,
-          abs_offset_lod[level].cuda_data(), sequence_width,
+          abs_offset_lod[level].CUDAData(context.GetPlace()), sequence_width,
           max_sequence_length, num_sequences);
     }
   }
......
@@ -73,7 +73,8 @@ class MaxSeqPoolFunctor<platform::CUDADeviceContext, T> {
     dim3 grid(num_seq, 1);
     auto stream = context.stream();
     KeMaxSequencePool<T><<<grid, threads, 0, stream>>>(
-        in_data, starts.cuda_data(), out_data, max_index, num_seq, dim);
+        in_data, starts.CUDAData(context.GetPlace()), out_data, max_index,
+        num_seq, dim);
   }
 };
......
@@ -46,7 +46,8 @@ class ScaleLoDTensorFunctor<platform::CUDADeviceContext, T> {
     SequenceScaleKernel<T, PADDLE_CUDA_NUM_THREADS><<<
         num_seq, PADDLE_CUDA_NUM_THREADS, 0, context.stream()>>>(
-        seq_data, abs_offset_lod[level].cuda_data(), scales, seq_width);
+        seq_data, abs_offset_lod[level].CUDAMutableData(context.GetPlace()),
+        scales, seq_width);
   }
 };
......
@@ -79,9 +79,6 @@ inline void CopyOrShare(const framework::Variable &src,
       dst->GetMutable<LoDTensor>()->set_lod(src.Get<LoDTensor>().lod());
     } else {
       Copy(src.Get<LoDTensor>(), dst_place, dst->GetMutable<LoDTensor>());
-      framework::LoD lod(src.Get<LoDTensor>().lod());
-      lod.CopyToPeer(dst_place);
-      dst->GetMutable<LoDTensor>()->set_lod(lod);
     }
   } else if (src.IsType<SelectedRows>()) {
     auto &src_sr = src.Get<SelectedRows>();
@@ -92,9 +89,6 @@ inline void CopyOrShare(const framework::Variable &src,
       dst_sr->set_rows(src_sr.rows());
     } else {
       Copy(src_sr.value(), dst_place, dst_sr->mutable_value());
-      framework::Vector<int64_t> lod(src_sr.rows());
-      lod.CopyToPeer(dst_place);
-      dst_sr->set_rows(lod);
     }
   } else {
     PADDLE_THROW("Expect LoDTensor/SelectedRows, get %s", src.Type().name());
@@ -152,9 +146,6 @@ class ParallelDoOp : public framework::OperatorBase {
         auto *sub_scope = sub_scopes[i];
         auto *dst = sub_scope->Var(param)->GetMutable<LoDTensor>();
         framework::Copy(src, place, dst);
-        framework::LoD lod(src.lod());
-        lod.CopyToPeer(place);
-        dst->set_lod(lod);
       }
     }
     WaitOnPlaces(places);
......
@@ -307,7 +307,7 @@ class RowConvKernel<platform::CUDADeviceContext, T>
     int input_dim = X->dims()[1];
     int num_sequence = batch_indices.size() - 1;
     int future_context = Filter->dims()[0];
-    size_t *idx = batch_indices.cuda_data();
+    size_t *idx = batch_indices.CUDAMutableData(context.GetPlace());
     auto stream = context.cuda_device_context().stream();
 
     if (future_context <= 32) {
@@ -345,7 +345,7 @@ class RowConvGradKernel<platform::CUDADeviceContext, T>
     int input_dim = X->dims()[1];
     int num_sequence = batch_indices.size() - 1;
     int future_context = Filter->dims()[0];
-    size_t *idx = batch_indices.cuda_data();
+    size_t *idx = batch_indices.CUDAMutableData(context.GetPlace());
 
     auto &device_ctx = context.cuda_device_context();
     math::SetConstant<platform::CUDADeviceContext, T> zero;
......
@@ -87,8 +87,7 @@ class SequenceEraseOpCUDAKernel : public framework::OpKernel<T> {
     // Copy LoD to GPU
     auto lod0 = lod[0];
     auto lod_len = lod0.size();
-    thrust::device_vector<size_t> dev_in_lod = lod0;
-    size_t* dev_in_lod_ptr = thrust::raw_pointer_cast(dev_in_lod.data());
+    const size_t* dev_in_lod_ptr = lod0.CUDAData(ctx.GetPlace());
 
     // Calc output LoD
     thrust::device_vector<size_t> dev_out_lod(lod_len);
......
@@ -102,8 +102,8 @@ class SGDOpCUDAKernel : public framework::OpKernel<T> {
       dim3 grid(1, in_rows.size());
       SparseSGDFunctorKernel<
           T, 256><<<grid, threads, 0, ctx.cuda_device_context().stream()>>>(
-          in_data, in_rows.cuda_data(), learning_rate->data<T>(), out_data,
-          in_row_numel);
+          in_data, in_rows.CUDAData(ctx.GetPlace()), learning_rate->data<T>(),
+          out_data, in_row_numel);
     } else {
       PADDLE_THROW("Unsupported Variable Type of Grad");
......
@@ -137,8 +137,8 @@ class TargetAssignKernel : public framework::OpKernel<T> {
       PADDLE_ENFORCE_EQ(gt_lod.data()[i], gt_label_lod.data()[i]);
     }
 
-    size_t* gt_lod_data = gt_lod.data(ctx.GetPlace());
-    size_t* neg_lod_data = neg_lod.data(ctx.GetPlace());
+    size_t* gt_lod_data = gt_lod.MutableData(ctx.GetPlace());
+    size_t* neg_lod_data = neg_lod.MutableData(ctx.GetPlace());
 
     TargetAssignFunctor<T> functor(box_data, label_data, match_idx_data,
                                    gt_lod_data, background_label, num,
......
@@ -20,6 +20,7 @@ limitations under the License. */
 #include "paddle/memory/memory.h"
 
 int main(int argc, char** argv) {
+  testing::InitGoogleTest(&argc, argv);
   std::vector<char*> new_argv;
   std::string gflags_env;
   for (int i = 0; i < argc; ++i) {
@@ -35,7 +36,6 @@ int main(int argc, char** argv) {
   int new_argc = static_cast<int>(new_argv.size());
   char** new_argv_address = new_argv.data();
   google::ParseCommandLineFlags(&new_argc, &new_argv_address, false);
-  testing::InitGoogleTest(&argc, argv);
   paddle::memory::Used(paddle::platform::CPUPlace());
 
 #ifdef PADDLE_WITH_CUDA
......