diff --git a/.gitignore b/.gitignore index ac56a3320ec85769d2c87c072512f5217eca0c24..59e650bdfe801c7e2ff19b6c0a9d60bed1e1ee10 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,9 @@ +paddle/operators/check_t.save +paddle/operators/check_tensor.ls +paddle/operators/tensor.save +python/paddle/v2/fluid/tests/book/image_classification_resnet.inference.model/ +python/paddle/v2/fluid/tests/book/image_classification_vgg.inference.model/ +python/paddle/v2/fluid/tests/book/label_semantic_roles.inference.model/ *.DS_Store build/ build_doc/ diff --git a/cmake/cuda.cmake b/cmake/cuda.cmake index 6bea7cf3022242ce48cc882915f7e71810937283..de94bd5008effef1bf0fd3a125d4aed56e1b7f81 100644 --- a/cmake/cuda.cmake +++ b/cmake/cuda.cmake @@ -181,7 +181,8 @@ elseif(CMAKE_BUILD_TYPE STREQUAL "Release") elseif(CMAKE_BUILD_TYPE STREQUAL "RelWithDebInfo") list(APPEND CUDA_NVCC_FLAGS ${CMAKE_CXX_FLAGS_RELWITHDEBINFO}) elseif(CMAKE_BUILD_TYPE STREQUAL "MinSizeRel") - list(APPEND CUDA_NVCC_FLAGS ${CMAKE_CXX_FLAGS_MINSIZEREL}) + # nvcc 9 does not support -Os. Use Release flags instead + list(APPEND CUDA_NVCC_FLAGS ${CMAKE_CXX_FLAGS_RELEASE}) endif() mark_as_advanced(CUDA_BUILD_CUBIN CUDA_BUILD_EMULATION CUDA_VERBOSE_BUILD) diff --git a/paddle/framework/lod_tensor.h b/paddle/framework/lod_tensor.h index be2b301619639106ac7b578e5a79cf33f4379e48..9de454428d9fd733aa70601f5012e77b9ceb2022 100644 --- a/paddle/framework/lod_tensor.h +++ b/paddle/framework/lod_tensor.h @@ -46,29 +46,7 @@ namespace framework { * 0 2 4 7 * 0 2 5 7 10 12 15 20 */ -struct LoD : public std::vector> { - using std::vector>::vector; - platform::Place place() const { - if (this->size() == 0) { - // Not Initialze Yet. - return platform::CPUPlace(); - } else { - return this->front().place(); - } - } - - void CopyFromCUDA() { - for (auto it = this->begin(); it != this->end(); ++it) { - it->CopyFromCUDA(); - } - } - - void CopyToPeer(platform::Place place) { - for (auto it = this->begin(); it != this->end(); ++it) { - it->CopyToPeer(place); - } - } -}; +using LoD = std::vector>; std::ostream& operator<<(std::ostream& os, const LoD& lod); std::ostream& operator<<(std::ostream& os, const LoDTensor& t); diff --git a/paddle/framework/lod_tensor_test.cu b/paddle/framework/lod_tensor_test.cu index adea02e3b3fdcf4873de76ff91116f43ac9fe259..a28b7caf86c689d55808c4e7defecd37a5a03442 100644 --- a/paddle/framework/lod_tensor_test.cu +++ b/paddle/framework/lod_tensor_test.cu @@ -20,6 +20,7 @@ #include "paddle/platform/assert.h" #include +#include __global__ void test(size_t* a, int size) { for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < size; @@ -36,10 +37,9 @@ TEST(LoD, data) { lod.push_back(std::vector({0, 1, 6, 8, 10, 11})); auto& v = lod[0]; - test<<<1, 1>>>(v.cuda_data(), v.size()); + paddle::platform::CUDAPlace gpu(0); + test<<<1, 1>>>(v.CUDAMutableData(gpu), v.size()); cudaDeviceSynchronize(); - - v.CopyFromCUDA(); for (size_t i = 0; i < v.size(); ++i) { EXPECT_EQ(v[i], i * 2); } @@ -63,9 +63,8 @@ TEST(LoDTensor, LoDInGPU) { auto lod = lod_tensor.lod(); - test<<<1, 8>>>(lod[0].cuda_data(), lod[0].size()); + test<<<1, 8>>>(lod[0].CUDAMutableData(place), lod[0].size()); cudaDeviceSynchronize(); - lod.CopyFromCUDA(); for (size_t i = 0; i < src_lod[0].size(); ++i) { EXPECT_EQ(lod[0].data()[i], src_lod[0].data()[i] * 2); diff --git a/paddle/framework/mixed_vector.h b/paddle/framework/mixed_vector.h index 5202775515d335ff81bb17e6ce21338c40041ca3..2a80079695f2dd19945bf4bc7ed17d6373fb51f2 100644 --- a/paddle/framework/mixed_vector.h +++ b/paddle/framework/mixed_vector.h @@ -17,176 +17,297 @@ #include #include -#include "paddle/memory/memcpy.h" -#include "paddle/memory/memory.h" -#include "paddle/platform/device_context.h" -#include "paddle/platform/enforce.h" -#include "paddle/platform/place.h" +#include "paddle/framework/tensor.h" +#include "paddle/framework/tensor_util.h" + +#include "glog/logging.h" namespace paddle { namespace framework { -/** - * @brief Vector support both cpu and gpu. - * host vector lifetime is same with Vector - * device vector is lazily malloc and modified. - */ - template -class Vector : public std::vector { +class Vector { public: - using std::vector::vector; + using value_type = T; + + Vector() { + size_ = 0; + flag_ = kDataInCPU; + } + + explicit Vector(size_t count, const T& value = T()) { + resize(count); + T* ptr = begin(); + for (size_t i = 0; i < count; ++i) { + ptr[i] = value; + } + } + + Vector(std::initializer_list init) { + InitByIter(init.size(), init.begin(), init.end()); + } + + template + Vector(const std::vector& dat) { // NOLINT + InitByIter(dat.size(), dat.begin(), dat.end()); + } + + Vector(const Vector& other) { this->operator=(other); } + + Vector& operator=(const Vector& other) { + if (other.size() != 0) { + this->InitByIter(other.size(), other.begin(), other.end()); + } else { + size_ = 0; + flag_ = kDataInCPU; + } + return *this; + } + + Vector(Vector&& other) { + this->size_ = other.size_; + this->flag_ = other.flag_; + if (other.cuda_vec_.capacity()) { + this->cuda_vec_.ShareDataWith(other.cuda_vec_); + } + if (other.cpu_vec_.capacity()) { + this->cpu_vec_.ShareDataWith(other.cpu_vec_); + } + } - Vector() {} - Vector(const std::vector &v) : std::vector(v) {} // NOLINT + T& operator[](size_t i) { + MutableCPU(); + return const_cast(cpu_vec_.data())[i]; + } + + const T& operator[](size_t i) const { + ImmutableCPU(); + return cpu_vec_.data()[i]; + } + + size_t size() const { return size_; } + + T* begin() { return &this->operator[](0); } + + T* end() { return &this->operator[](size()); } + + T& front() { return *begin(); } + + T& back() { + auto it = end(); + --it; + return *it; + } + + const T* begin() const { return &this->operator[](0); } + const T* end() const { return &this->operator[](size()); } - inline platform::Place place() const { return place_; } + const T& back() const { + auto it = end(); + --it; + return *it; + } + + const T& front() const { return *begin(); } + + template + void assign(Iter begin, Iter end) { + InitByIter(end - begin, begin, end); + } + + T* data() { return begin(); } - /*! Return a pointer to constant memory block. */ - inline const T *data(platform::Place place) const; + const T* data() const { return begin(); } - /*! Return a pointer to mutable memory block. */ - inline T *mutable_data(platform::Place place); + void push_back(T elem) { + if (size_ + 1 > capacity()) { + reserve((size_ + 1) << 1); + } + *end() = elem; + ++size_; + } - // TODO(dzhwinter): below interfaces should be removed - /* Get device vector */ - T *cuda_data() { - CopyToCUDA(); - PADDLE_ENFORCE_NOT_NULL( - cuda_ptr_, "No data or Insufficient CUDA memory to allocation"); - return static_cast(cuda_ptr_.get()); + void resize(size_t size) { + if (size + 1 < capacity()) { + size_ = size; + } else { + MutableCPU(); + Tensor cpu_tensor; + platform::Place cpu = platform::CPUPlace(); + T* ptr = cpu_tensor.mutable_data( + framework::make_ddim({static_cast(size)}), cpu); + const T* old_ptr = + cpu_vec_.capacity() == 0 ? nullptr : cpu_vec_.data(); + if (old_ptr != nullptr) { + std::copy(old_ptr, old_ptr + size_, ptr); + } + size_ = size; + cpu_vec_.ShareDataWith(cpu_tensor); + } } - /* Get host vector */ - T *data() { return std::vector::data(); } - const T *data() const { return std::vector::data(); } + const T* CUDAData(platform::Place place) const { + PADDLE_ENFORCE(platform::is_gpu_place(place), + "CUDA Data must on CUDA place"); + ImmutableCUDA(place); + return cuda_vec_.data(); + } - T *data(const platform::Place &place) { - if (platform::is_cpu_place(place)) { + T* CUDAMutableData(platform::Place place) { + const T* ptr = CUDAData(place); + flag_ = kDirty | kDataInCUDA; + return const_cast(ptr); + } + + template + void Extend(It begin, It end) { + size_t pre_size = size_; + resize(pre_size + (end - begin)); + T* ptr = this->begin() + pre_size; + for (; begin < end; ++begin, ++ptr) { + *ptr = *begin; + } + } + + void clear() { + size_ = 0; + flag_ = kDirty | kDataInCPU; + } + + size_t capacity() const { + return cpu_vec_.capacity() / SizeOfType(typeid(T)); + } + + void reserve(size_t size) { + size_t pre_size = size_; + resize(size); + resize(pre_size); + } + + const T* Data(platform::Place place) const { + if (platform::is_gpu_place(place)) { + return CUDAData(place); + } else { return data(); + } + } + + T* MutableData(platform::Place place) { + if (platform::is_gpu_place(place)) { + return CUDAMutableData(place); } else { - return cuda_data(); + return data(); } } - /* Synchronize host vector to device vector */ - void CopyToCUDA(); - /* Synchronize device vector to host vector */ - void CopyFromCUDA(); - /* Switch device vector location */ - void CopyToPeer(platform::Place); + operator std::vector() const { + std::vector result; + result.resize(size()); + std::copy(begin(), end(), result.begin()); + return result; + } + + bool operator==(const Vector& other) const { + if (size() != other.size()) return false; + for (auto it1 = begin(), it2 = other.begin(); it1 < end(); ++it1, ++it2) { + if (*it1 != *it2) { + return false; + } + } + return true; + } private: - std::shared_ptr cuda_ptr_; - size_t cuda_size_ = 0; // device vector numel - platform::CUDAPlace place_; -}; + template + void InitByIter(size_t size, Iter begin, Iter end) { + platform::Place cpu = platform::CPUPlace(); + T* ptr = this->cpu_vec_.template mutable_data( + framework::make_ddim({static_cast(size)}), cpu); + for (size_t i = 0; i < size; ++i) { + *ptr++ = *begin++; + } + flag_ = kDataInCPU | kDirty; + size_ = size; + } -template -inline const T *Vector::data(platform::Place place) const { - if (platform::is_cpu_place(place)) { - return std::vector::data(); - } else if (platform::is_gpu_place(place)) { - if (cuda_ptr_ == nullptr) { - return nullptr; + enum DataFlag { kDataInCPU = 0x01, kDataInCUDA = 0x02, kDirty = 0x10 }; + + void MutableCPU() { + if (IsInCUDA() && IsDirty()) { + // COPY GPU Data To CPU + Copy(cuda_vec_, platform::CPUPlace(), &cpu_vec_); + WaitPlace(cuda_vec_.place()); } - if (boost::get(place) == place_) { - return static_cast(cuda_ptr_.get()); + flag_ = kDirty | kDataInCPU; + } + + void ImmutableCUDA(platform::Place place) const { + if (IsDirty()) { + if (IsInCPU()) { + Copy(cpu_vec_, boost::get(place), &cuda_vec_); + WaitPlace(place); + UnsetFlag(kDirty); + SetFlag(kDataInCUDA); + } else if (IsInCUDA() && !(place == cuda_vec_.place())) { + framework::Tensor tmp; + Copy(cuda_vec_, boost::get(place), &tmp); + WaitPlace(cuda_vec_.place()); + cuda_vec_.ShareDataWith(tmp); + // Still dirty + } else { + // Dirty && DataInCUDA && Device is same + // Do nothing + } } else { - PADDLE_THROW( - "Unmatched place. Please use `mutable_data` copy lod to the target " - "Place first."); + if (!IsInCUDA()) { + // Even data is not dirty. However, data is not in CUDA. Copy data. + Copy(cpu_vec_, boost::get(place), &cuda_vec_); + WaitPlace(place); + SetFlag(kDataInCUDA); + } else if (!(place == cuda_vec_.place())) { + framework::Tensor tmp; + Copy(cuda_vec_, boost::get(place), &tmp); + WaitPlace(cuda_vec_.place()); + cuda_vec_.ShareDataWith(tmp); + } else { + // Not Dirty && DataInCUDA && Device is same + // Do nothing. + } } - } else { - PADDLE_THROW("Unsupport Place."); } -} -template -inline T *Vector::mutable_data(platform::Place place) { - if (platform::is_cpu_place(place)) { - return std::vector::data(); - } else if (platform::is_gpu_place(place)) { - if (boost::get(place) != place_) { - place_ = boost::get(place); + void ImmutableCPU() const { + if (IsDirty() && + !IsInCPU()) { // If data has been changed in CUDA, or CPU has no data. + Copy(cuda_vec_, platform::CPUPlace(), &cpu_vec_); + WaitPlace(cuda_vec_.place()); + UnsetFlag(kDirty); } -#ifdef PADDLE_WITH_CUDA - if (cuda_size_ < this->size() || cuda_ptr_ == nullptr) { - cuda_ptr_.reset( - memory::Alloc(place_, this->size() * sizeof(T)), - memory::PlainDeleter(place_)); - } - cuda_size_ = this->size(); - platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance(); - auto *ctx = pool.GetByPlace(place_); - memory::Copy(place_, cuda_ptr_.get(), platform::CPUPlace(), - static_cast(this->data()), - this->size() * sizeof(T), ctx->stream()); - ctx->Wait(); - return static_cast(cuda_ptr_.get()); -#else - return nullptr; -#endif - } else { - PADDLE_THROW("Unsupport Place."); - } -} + SetFlag(kDataInCPU); + } -template -void Vector::CopyToCUDA() { -#ifdef PADDLE_WITH_CUDA - if (cuda_size_ < this->size() || cuda_ptr_ == nullptr) { - cuda_ptr_.reset( - memory::Alloc(place_, this->size() * sizeof(T)), - memory::PlainDeleter(place_)); - } - cuda_size_ = this->size(); - platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance(); - auto *ctx = pool.GetByPlace(place_); - memory::Copy(place_, cuda_ptr_.get(), platform::CPUPlace(), - static_cast(this->data()), - this->size() * sizeof(T), ctx->stream()); - ctx->Wait(); -#endif -} + void UnsetFlag(int flag) const { flag_ &= ~flag; } + void SetFlag(int flag) const { flag_ |= flag; } -template -void Vector::CopyFromCUDA() { -#ifdef PADDLE_WITH_CUDA - if (cuda_ptr_ == nullptr) { - LOG(WARNING) << "No uncommitted cuda data."; - return; - } - this->resize(cuda_size_); - platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance(); - auto *ctx = pool.GetByPlace(place_); - memory::Copy(platform::CPUPlace(), static_cast(this->data()), place_, - static_cast(cuda_ptr_.get()), - this->size() * sizeof(T), ctx->stream()); - ctx->Wait(); -#endif -} + bool IsDirty() const { return flag_ & kDirty; } -template -void Vector::CopyToPeer(platform::Place place) { -#ifdef PADDLE_WITH_CUDA - if (boost::get(place) != place_) { - place_ = boost::get(place); - } - if (cuda_size_ < this->size() || cuda_ptr_ == nullptr) { - cuda_ptr_.reset( - memory::Alloc(place_, this->size() * sizeof(T)), - memory::PlainDeleter(place_)); - } - cuda_size_ = this->size(); - platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance(); - auto *ctx = pool.GetByPlace(place_); - memory::Copy(place_, cuda_ptr_.get(), platform::CPUPlace(), - static_cast(this->data()), - this->size() * sizeof(T), ctx->stream()); - ctx->Wait(); -#endif -} + bool IsInCUDA() const { return flag_ & kDataInCUDA; } + + bool IsInCPU() const { return flag_ & kDataInCPU; } + + static void WaitPlace(const platform::Place place) { + if (platform::is_gpu_place(place)) { + platform::DeviceContextPool::Instance() + .Get(boost::get(place)) + ->Wait(); + } + } + + mutable int flag_; + mutable Tensor cpu_vec_; + mutable Tensor cuda_vec_; + size_t size_; +}; } // namespace framework } // namespace paddle diff --git a/paddle/framework/mixed_vector_test.cu b/paddle/framework/mixed_vector_test.cu index 7b571788ad1ade50e05dc9a70cba35b83f8db3ea..6adad6c12c350cf284700579533040eec5cbb095 100644 --- a/paddle/framework/mixed_vector_test.cu +++ b/paddle/framework/mixed_vector_test.cu @@ -11,62 +11,3 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#include -#include -#include "gtest/gtest.h" - -#include "paddle/framework/init.h" -#include "paddle/framework/mixed_vector.h" - -using namespace paddle::framework; -using namespace paddle::platform; -using namespace paddle::memory; - -template -__global__ void test(T* data, int size) { - for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < size; - i += blockDim.x * gridDim.x) { - data[i] *= 2; - } -} - -TEST(Vector, Normal) { - // fill the device context pool. - InitDevices(); - - Vector vec({1, 2, 3}); - size_t* ptr = vec.data(); - for (size_t i = 0; i < vec.size(); ++i) { - EXPECT_EQ(vec[i], *(ptr + i)); - } - - vec.clear(); - vec.CopyFromCUDA(); - - std::vector v = {1, 2, 3}; - for (size_t i = 0; i < v.size(); ++i) { - EXPECT_EQ(v[i], vec[i]); - } -} - -TEST(Vector, MultipleCopy) { - InitDevices(); - Vector vec({1, 2, 3}); - CUDAPlace place(0); - vec.mutable_data(place); - auto vec2 = Vector(vec); - { - const size_t* ptr = vec2.data(CPUPlace()); - for (size_t i = 0; i < vec2.size(); ++i) { - EXPECT_EQ(*(ptr + i), vec[i]); - } - } - test<<<3, 3>>>(vec2.mutable_data(place), vec2.size()); - vec2.CopyFromCUDA(); - { - const size_t* ptr = vec2.data(CPUPlace()); - for (size_t i = 0; i < vec2.size(); ++i) { - EXPECT_EQ(*(ptr + i), vec[i] * 2); - } - } -} diff --git a/paddle/framework/tensor.h b/paddle/framework/tensor.h index f0ea709a5c37e769e3ffa1b2e9d1e39721979251..a8767a75430b98c6b0aada69ace72be6dd127562 100644 --- a/paddle/framework/tensor.h +++ b/paddle/framework/tensor.h @@ -128,6 +128,10 @@ class Tensor { inline void set_layout(const DataLayout layout) { layout_ = layout; } + size_t capacity() const { + return holder_ == nullptr ? 0UL : holder_->size() - offset_; + } + private: friend class LoDTensor; diff --git a/paddle/framework/tensor_impl.h b/paddle/framework/tensor_impl.h index 1340c5e48520ccdd537e694abf452fd79129df99..6dcaa024245f78df5bfba073c2cec5686fee657e 100644 --- a/paddle/framework/tensor_impl.h +++ b/paddle/framework/tensor_impl.h @@ -52,7 +52,7 @@ struct SizeOfTypeFunctor { }; static inline size_t SizeOfType(std::type_index type) { - SizeOfTypeFunctor functor; + SizeOfTypeFunctor functor; size_t size = functor(type); PADDLE_ENFORCE(size != 0UL, "Cannot get size of type %s", type.name()); return size; diff --git a/paddle/operators/adagrad_op.cu b/paddle/operators/adagrad_op.cu index 00cb6e9cafb4e79ed3d59cd4a6e40ea132e5efda..9a21e00b12bc2795e1bf1591f7db60c0245bacd3 100644 --- a/paddle/operators/adagrad_op.cu +++ b/paddle/operators/adagrad_op.cu @@ -101,9 +101,9 @@ struct SparseAdagradFunctor { SparseAdagradFunctorKernel< T, 256><<(context) - .stream()>>>(grad_merge_data, merge_rows.cuda_data(), lr, - param_data, moment_data, grad_width, - epsilon); + .stream()>>>( + grad_merge_data, merge_rows.CUDAMutableData(context.GetPlace()), lr, + param_data, moment_data, grad_width, epsilon); } }; diff --git a/paddle/operators/adam_op.h b/paddle/operators/adam_op.h index bf536687d398b8342e6ae76a07c11e5fe47483e0..af2c3ecd725ed1c916ff3b8a0291794d35a70e8b 100644 --- a/paddle/operators/adam_op.h +++ b/paddle/operators/adam_op.h @@ -201,7 +201,7 @@ class AdamOpKernel : public framework::OpKernel { const T* grad_data = grad_tensor.template data(); int64_t* rows = nullptr; if (platform::is_gpu_place(ctx.GetPlace())) { - rows = grad_merge.mutable_rows()->cuda_data(); + rows = grad_merge.mutable_rows()->CUDAMutableData(ctx.GetPlace()); } else { rows = grad_merge.mutable_rows()->data(); } diff --git a/paddle/operators/ctc_align_op.cu b/paddle/operators/ctc_align_op.cu index cea595d7c5d461b40198e622abf08248e7ca69e1..6406825d4a5c4538b5e2780efbe5ba86adce5b72 100644 --- a/paddle/operators/ctc_align_op.cu +++ b/paddle/operators/ctc_align_op.cu @@ -69,8 +69,9 @@ class CTCAlignOpCUDAKernel : public framework::OpKernel { auto stream = ctx.cuda_device_context().stream(); MergeAndDelCudaKernel<<<1, 1, 0, stream>>>( - num_tokens, tokens, num_seq, input_lod[level].cuda_data(), blank, - merge_repeated, dev_out_lod0_ptr, output_data); + num_tokens, tokens, num_seq, + input_lod[level].CUDAMutableData(ctx.GetPlace()), blank, merge_repeated, + dev_out_lod0_ptr, output_data); // set output lod std::vector host_out_lod0(dev_out_lod0.begin(), dev_out_lod0.end()); diff --git a/paddle/operators/lookup_table_op.cu b/paddle/operators/lookup_table_op.cu index 07372808bbf078bd2e9b0bb5782b95a046253f46..9684b6d4612c8e134ccad658840bd028a8508085 100644 --- a/paddle/operators/lookup_table_op.cu +++ b/paddle/operators/lookup_table_op.cu @@ -125,7 +125,9 @@ class LookupTableGradCUDAKernel : public framework::OpKernel { new_rows.resize(ids_dim[0]); auto gpu_place = boost::get(context.GetPlace()); - memory::Copy(platform::CPUPlace(), new_rows.cuda_data(), gpu_place, + // TODO(yuyang18): Strange code here. + memory::Copy(platform::CPUPlace(), + new_rows.CUDAMutableData(context.GetPlace()), gpu_place, ids_data, ids_dim[0] * sizeof(int64_t), stream); d_table->set_rows(new_rows); diff --git a/paddle/operators/math/selected_rows_functor.cc b/paddle/operators/math/selected_rows_functor.cc index 8a1ebb58c26578f076bf243adfbd51d10c682b99..4e15d01a3071995e1412fed2a451e4ad3f171862 100644 --- a/paddle/operators/math/selected_rows_functor.cc +++ b/paddle/operators/math/selected_rows_functor.cc @@ -128,7 +128,7 @@ struct SelectedRowsAddTo { auto* in2_value = input2->mutable_value(); // concat rows - in2_rows.insert(in2_rows.end(), in1_rows.begin(), in1_rows.end()); + in2_rows.Extend(in1_rows.begin(), in1_rows.end()); auto in1_place = input1.place(); PADDLE_ENFORCE(platform::is_cpu_place(in1_place)); diff --git a/paddle/operators/math/selected_rows_functor.cu b/paddle/operators/math/selected_rows_functor.cu index acdd87cb3550bc5f3891aed6fefd4301a3395f9f..5c3a53ae1ba92dbd11f3158789f53bd205747149 100644 --- a/paddle/operators/math/selected_rows_functor.cu +++ b/paddle/operators/math/selected_rows_functor.cu @@ -126,7 +126,8 @@ struct SelectedRowsAddTensor { dim3 grid(1, in1_rows.size()); SelectedRowsAddTensorKernel< T, block_size><<>>( - in1_data, in1_rows.cuda_data(), out_data, in1_row_numel); + in1_data, in1_rows.CUDAData(context.GetPlace()), out_data, + in1_row_numel); auto out_eigen = framework::EigenVector::Flatten(*output); auto in2_eigen = framework::EigenVector::Flatten(input2); @@ -153,7 +154,7 @@ struct SelectedRowsAddTo { auto* in2_value = input2->mutable_value(); // concat rows - in2_rows.insert(in2_rows.end(), in1_rows.begin(), in1_rows.end()); + in2_rows.Extend(in1_rows.begin(), in1_rows.end()); auto in1_place = input1.place(); PADDLE_ENFORCE(platform::is_gpu_place(in1_place)); @@ -216,7 +217,8 @@ struct SelectedRowsAddToTensor { dim3 grid(1, in1_rows.size()); SelectedRowsAddToTensorKernel< T, block_size><<>>( - in1_data, in1_rows.cuda_data(), in2_data, in1_row_numel); + in1_data, in1_rows.CUDAData(context.GetPlace()), in2_data, + in1_row_numel); } }; @@ -283,9 +285,10 @@ struct MergeAdd { MergeAddKernel< T, 256><<(context) - .stream()>>>(input_data, input_rows.cuda_data(), out_data, - out.mutable_rows()->cuda_data(), - out.rows().size(), input_width); + .stream()>>>( + input_data, input_rows.CUDAData(context.GetPlace()), out_data, + out.mutable_rows()->CUDAMutableData(context.GetPlace()), + out.rows().size(), input_width); return out; } }; diff --git a/paddle/operators/math/sequence2batch.cu b/paddle/operators/math/sequence2batch.cu index f27631271a42b4d64abef00d7f119b85e32edda4..eaed2c30a80c75d56aef329f6e6f67b8bac3520a 100644 --- a/paddle/operators/math/sequence2batch.cu +++ b/paddle/operators/math/sequence2batch.cu @@ -45,7 +45,6 @@ class CopyMatrixRowsFunctor { const framework::Tensor& src, framework::Vector index_lod, framework::Tensor& dst, bool is_src_index) { - size_t* index = index_lod.cuda_data(); auto src_dims = src.dims(); auto dst_dims = dst.dims(); PADDLE_ENFORCE_EQ(src_dims.size(), 2, @@ -63,7 +62,8 @@ class CopyMatrixRowsFunctor { dim3 grid(8, 1); auto stream = context.stream(); CopyMatrixRowsKernel<<>>( - src_data, dst_data, index, height, width, is_src_index); + src_data, dst_data, index_lod.CUDAData(context.GetPlace()), height, + width, is_src_index); } }; diff --git a/paddle/operators/math/sequence_padding.cu b/paddle/operators/math/sequence_padding.cu index 65c9cfe4a0ec14d220ad237baa71703a783ed0fa..c2bd56448aa363160f6bf621ec67deff9e369c92 100644 --- a/paddle/operators/math/sequence_padding.cu +++ b/paddle/operators/math/sequence_padding.cu @@ -121,12 +121,12 @@ class PaddingLoDTensorFunctor { if (norm_by_times) { SequencePaddingKernel<<>>( padding_data, const_cast(seq_data), - abs_offset_lod[level].cuda_data(), sequence_width, + abs_offset_lod[level].CUDAData(context.GetPlace()), sequence_width, max_sequence_length, num_sequences); } else { SequencePaddingKernel<<>>( padding_data, const_cast(seq_data), - abs_offset_lod[level].cuda_data(), sequence_width, + abs_offset_lod[level].CUDAData(context.GetPlace()), sequence_width, max_sequence_length, num_sequences); } } @@ -196,12 +196,12 @@ class UnpaddingLoDTensorFunctor { if (norm_by_times) { SequencePaddingKernel<<>>( const_cast(padding_data), seq_data, - abs_offset_lod[level].cuda_data(), sequence_width, + abs_offset_lod[level].CUDAData(context.GetPlace()), sequence_width, max_sequence_length, num_sequences); } else { SequencePaddingKernel<<>>( const_cast(padding_data), seq_data, - abs_offset_lod[level].cuda_data(), sequence_width, + abs_offset_lod[level].CUDAData(context.GetPlace()), sequence_width, max_sequence_length, num_sequences); } } diff --git a/paddle/operators/math/sequence_pooling.cu b/paddle/operators/math/sequence_pooling.cu index f66534a6812a66c737445ea96914a393077d7d65..c69bd3da7e741d3113de63fe08d22c23f772dda4 100644 --- a/paddle/operators/math/sequence_pooling.cu +++ b/paddle/operators/math/sequence_pooling.cu @@ -73,7 +73,8 @@ class MaxSeqPoolFunctor { dim3 grid(num_seq, 1); auto stream = context.stream(); KeMaxSequencePool<<>>( - in_data, starts.cuda_data(), out_data, max_index, num_seq, dim); + in_data, starts.CUDAData(context.GetPlace()), out_data, max_index, + num_seq, dim); } }; diff --git a/paddle/operators/math/sequence_scale.cu b/paddle/operators/math/sequence_scale.cu index fd4e28f6113729cd1fa9dc179bd9b601d29b8a7f..7cb9242db932ba8b2490f528ee08bd3f4b4e8f83 100644 --- a/paddle/operators/math/sequence_scale.cu +++ b/paddle/operators/math/sequence_scale.cu @@ -46,7 +46,8 @@ class ScaleLoDTensorFunctor { SequenceScaleKernel<<< num_seq, PADDLE_CUDA_NUM_THREADS, 0, context.stream()>>>( - seq_data, abs_offset_lod[level].cuda_data(), scales, seq_width); + seq_data, abs_offset_lod[level].CUDAMutableData(context.GetPlace()), + scales, seq_width); } }; diff --git a/paddle/operators/parallel_do_op.cc b/paddle/operators/parallel_do_op.cc index 89045923f9ff2f33bc112b199c493047440e15c4..edb9de82509f9ee7619f5e90f49022de977a2ea4 100644 --- a/paddle/operators/parallel_do_op.cc +++ b/paddle/operators/parallel_do_op.cc @@ -79,9 +79,6 @@ inline void CopyOrShare(const framework::Variable &src, dst->GetMutable()->set_lod(src.Get().lod()); } else { Copy(src.Get(), dst_place, dst->GetMutable()); - framework::LoD lod(src.Get().lod()); - lod.CopyToPeer(dst_place); - dst->GetMutable()->set_lod(lod); } } else if (src.IsType()) { auto &src_sr = src.Get(); @@ -92,9 +89,6 @@ inline void CopyOrShare(const framework::Variable &src, dst_sr->set_rows(src_sr.rows()); } else { Copy(src_sr.value(), dst_place, dst_sr->mutable_value()); - framework::Vector lod(src_sr.rows()); - lod.CopyToPeer(dst_place); - dst_sr->set_rows(lod); } } else { PADDLE_THROW("Expect LoDTensor/SelectedRows, get %s", src.Type().name()); @@ -152,9 +146,6 @@ class ParallelDoOp : public framework::OperatorBase { auto *sub_scope = sub_scopes[i]; auto *dst = sub_scope->Var(param)->GetMutable(); framework::Copy(src, place, dst); - framework::LoD lod(src.lod()); - lod.CopyToPeer(place); - dst->set_lod(lod); } } WaitOnPlaces(places); diff --git a/paddle/operators/row_conv_op.cu b/paddle/operators/row_conv_op.cu index b3825212e1ac41b13a2f4cad2c128da39c5f6e71..d1a6d119d3da605b1d455d38f38a8808234b8ad1 100644 --- a/paddle/operators/row_conv_op.cu +++ b/paddle/operators/row_conv_op.cu @@ -307,7 +307,7 @@ class RowConvKernel int input_dim = X->dims()[1]; int num_sequence = batch_indices.size() - 1; int future_context = Filter->dims()[0]; - size_t *idx = batch_indices.cuda_data(); + size_t *idx = batch_indices.CUDAMutableData(context.GetPlace()); auto stream = context.cuda_device_context().stream(); if (future_context <= 32) { @@ -345,7 +345,7 @@ class RowConvGradKernel int input_dim = X->dims()[1]; int num_sequence = batch_indices.size() - 1; int future_context = Filter->dims()[0]; - size_t *idx = batch_indices.cuda_data(); + size_t *idx = batch_indices.CUDAMutableData(context.GetPlace()); auto &device_ctx = context.cuda_device_context(); math::SetConstant zero; diff --git a/paddle/operators/sequence_erase_op.cu b/paddle/operators/sequence_erase_op.cu index a5311f15f0c607c880a6f12c0bef10b2dd8c8a79..4a7217cfd656f4f6b46d5a80a9c8e165c839df1d 100644 --- a/paddle/operators/sequence_erase_op.cu +++ b/paddle/operators/sequence_erase_op.cu @@ -87,8 +87,7 @@ class SequenceEraseOpCUDAKernel : public framework::OpKernel { // Copy LoD to GPU auto lod0 = lod[0]; auto lod_len = lod0.size(); - thrust::device_vector dev_in_lod = lod0; - size_t* dev_in_lod_ptr = thrust::raw_pointer_cast(dev_in_lod.data()); + const size_t* dev_in_lod_ptr = lod0.CUDAData(ctx.GetPlace()); // Calc output LoD thrust::device_vector dev_out_lod(lod_len); diff --git a/paddle/operators/sgd_op.cu b/paddle/operators/sgd_op.cu index 29f5aa3542c26c76a1b80da61ec6752019216131..d27befe4460550f7b7b30aa93a23c8e51aa52da9 100644 --- a/paddle/operators/sgd_op.cu +++ b/paddle/operators/sgd_op.cu @@ -102,8 +102,8 @@ class SGDOpCUDAKernel : public framework::OpKernel { dim3 grid(1, in_rows.size()); SparseSGDFunctorKernel< T, 256><<>>( - in_data, in_rows.cuda_data(), learning_rate->data(), out_data, - in_row_numel); + in_data, in_rows.CUDAData(ctx.GetPlace()), learning_rate->data(), + out_data, in_row_numel); } else { PADDLE_THROW("Unsupported Variable Type of Grad"); diff --git a/paddle/operators/target_assign_op.h b/paddle/operators/target_assign_op.h index 82fca5724c0bd9fbfb60a98b91944700bfab9cdf..574919e1ef8d28c2a27b73b97a91d29e89896a6b 100644 --- a/paddle/operators/target_assign_op.h +++ b/paddle/operators/target_assign_op.h @@ -137,8 +137,8 @@ class TargetAssignKernel : public framework::OpKernel { PADDLE_ENFORCE_EQ(gt_lod.data()[i], gt_label_lod.data()[i]); } - size_t* gt_lod_data = gt_lod.data(ctx.GetPlace()); - size_t* neg_lod_data = neg_lod.data(ctx.GetPlace()); + size_t* gt_lod_data = gt_lod.MutableData(ctx.GetPlace()); + size_t* neg_lod_data = neg_lod.MutableData(ctx.GetPlace()); TargetAssignFunctor functor(box_data, label_data, match_idx_data, gt_lod_data, background_label, num, diff --git a/paddle/testing/paddle_gtest_main.cc b/paddle/testing/paddle_gtest_main.cc index fd8c4a69da897cc39f31f435036e32c41285fb59..ab84f1c292b97f55d88165e7ef0e32b93d542802 100644 --- a/paddle/testing/paddle_gtest_main.cc +++ b/paddle/testing/paddle_gtest_main.cc @@ -20,6 +20,7 @@ limitations under the License. */ #include "paddle/memory/memory.h" int main(int argc, char** argv) { + testing::InitGoogleTest(&argc, argv); std::vector new_argv; std::string gflags_env; for (int i = 0; i < argc; ++i) { @@ -35,7 +36,6 @@ int main(int argc, char** argv) { int new_argc = static_cast(new_argv.size()); char** new_argv_address = new_argv.data(); google::ParseCommandLineFlags(&new_argc, &new_argv_address, false); - testing::InitGoogleTest(&argc, argv); paddle::memory::Used(paddle::platform::CPUPlace()); #ifdef PADDLE_WITH_CUDA