diff --git a/paddle/fluid/framework/lod_tensor_test.cu b/paddle/fluid/framework/lod_tensor_test.cu
index ddda7231887edfc78fa7b1b6adc5cd8324e5b894..006485a698fb3dc93188cd46450ea108e709ff6d 100644
--- a/paddle/fluid/framework/lod_tensor_test.cu
+++ b/paddle/fluid/framework/lod_tensor_test.cu
@@ -31,15 +31,17 @@ TEST(LoD, data) {
   lod.push_back(std::vector<size_t>({0, 1, 6, 8, 10, 11}));

   auto& v = lod[0];
+  paddle::framework::MixVector<size_t> mix_vector_v(&v);
   paddle::platform::CUDAPlace gpu(0);
 #ifdef PADDLE_WITH_HIP
-  hipLaunchKernelGGL(test, dim3(1), dim3(1), 0, 0, v.CUDAMutableData(gpu),
-                     v.size());
+  hipLaunchKernelGGL(test, dim3(1), dim3(1), 0, 0,
+                     mix_vector_v.CUDAMutableData(gpu), v.size());
   hipDeviceSynchronize();
 #else
-  test<<<1, 1>>>(v.CUDAMutableData(gpu), v.size());
+  test<<<1, 1>>>(mix_vector_v.CUDAMutableData(gpu), v.size());
   cudaDeviceSynchronize();
 #endif
+  mix_vector_v.CopyToCPU();
   for (size_t i = 0; i < v.size(); ++i) {
     EXPECT_EQ(v[i], i * 2);
   }
@@ -62,15 +64,17 @@ TEST(LoDTensor, LoDInGPU) {
   EXPECT_EQ(lod_tensor.lod_element(0, 4).first, 8UL);

   auto lod = lod_tensor.lod();
+  paddle::framework::MixVector<size_t> mix_vector(&(lod[0]));

 #ifdef PADDLE_WITH_HIP
   hipLaunchKernelGGL(test, dim3(1), dim3(8), 0, 0,
-                     lod[0].CUDAMutableData(place), lod[0].size());
+                     mix_vector.CUDAMutableData(place), lod[0].size());
   hipDeviceSynchronize();
 #else
-  test<<<1, 8>>>(lod[0].CUDAMutableData(place), lod[0].size());
+  test<<<1, 8>>>(mix_vector.CUDAMutableData(place), lod[0].size());
   cudaDeviceSynchronize();
 #endif
+  mix_vector.CopyToCPU();

   for (size_t i = 0; i < src_lod[0].size(); ++i) {
     EXPECT_EQ(lod[0].data()[i], src_lod[0].data()[i] * 2);
diff --git a/paddle/fluid/framework/mixed_vector.cc b/paddle/fluid/framework/mixed_vector.cc
index b15a66c51c4b6365cb4285894efb1e37a03b7b64..67b2d70f3440c5254abb5ff67995e6758af5c8f1 100644
--- a/paddle/fluid/framework/mixed_vector.cc
+++ b/paddle/fluid/framework/mixed_vector.cc
@@ -64,19 +64,20 @@ void CopyCPUDataToCUDAHelper(std::vector<T> *cpu_,
   auto stream = dev_ctx->stream();
   paddle::memory::Copy(OptionalCUDAPlace(*gpu_).get(), dst,
                        platform::CPUPlace(), src, *gpu_memory_size_, stream);
+  dev_ctx->Wait();
 #endif
 }

-#define INSTANTIATE_VECTOR_FOR_TYPE(__TYPE__)                                  \
-  template <>                                                                  \
-  void Vector<__TYPE__>::VectorData::CopyToCPU() const {                       \
-    CopyToCPUHelper<__TYPE__>(&cpu_, &gpu_, &gpu_memory_size_);                \
-  }                                                                            \
-                                                                               \
-  template <>                                                                  \
-  void Vector<__TYPE__>::VectorData::CopyCPUDataToCUDA(                        \
-      const platform::Place &place) const {                                    \
-    CopyCPUDataToCUDAHelper<__TYPE__>(&cpu_, &gpu_, &gpu_memory_size_, place); \
+#define INSTANTIATE_VECTOR_FOR_TYPE(__TYPE__)                                 \
+  template <>                                                                 \
+  void MixVector<__TYPE__>::VectorData::CopyToCPU() const {                   \
+    CopyToCPUHelper<__TYPE__>(cpu_, &gpu_, &gpu_memory_size_);                \
+  }                                                                           \
+                                                                              \
+  template <>                                                                 \
+  void MixVector<__TYPE__>::VectorData::CopyCPUDataToCUDA(                    \
+      const platform::Place &place) const {                                   \
+    CopyCPUDataToCUDAHelper<__TYPE__>(cpu_, &gpu_, &gpu_memory_size_, place); \
   }

 INSTANTIATE_VECTOR_FOR_TYPE(size_t)
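The mixed_vector.cc hunk above adds a dev_ctx->Wait() right after the asynchronous H2D memory::Copy. Presumably this is needed because VectorData now borrows the caller's std::vector instead of owning a copy, so the host buffer may be mutated or destroyed as soon as control returns to the caller; the copy must complete first. A minimal standalone illustration of the same hazard, using the plain CUDA runtime API rather than Paddle's allocator (all names here are illustrative):

    // Sketch: why an async H2D copy must be synchronized before the host
    // buffer may be reused. Compile with nvcc; error handling omitted.
    #include <cuda_runtime.h>
    #include <vector>

    void SafeUpload(const std::vector<int>& host, int* dev,
                    cudaStream_t stream) {
      // The copy only *starts* here; the DMA engine keeps reading
      // host.data() until the transfer completes on `stream`.
      cudaMemcpyAsync(dev, host.data(), host.size() * sizeof(int),
                      cudaMemcpyHostToDevice, stream);
      // Counterpart of the added dev_ctx->Wait(): without it the caller
      // could resize or free `host` while the DMA is still in flight.
      cudaStreamSynchronize(stream);
    }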
diff --git a/paddle/fluid/framework/mixed_vector.h b/paddle/fluid/framework/mixed_vector.h
index 0fd67efc177b3d6bd83b1c9d8325d0de81c0d2e5..a589a5b4ea7e15fc24f443e8062635b1e337adfe 100644
--- a/paddle/fluid/framework/mixed_vector.h
+++ b/paddle/fluid/framework/mixed_vector.h
@@ -22,7 +22,6 @@ limitations under the License. */
 #include <vector>

 #include "glog/logging.h"
-#include "paddle/fluid/framework/details/cow_ptr.h"
 #include "paddle/fluid/memory/allocation/allocator.h"
 #include "paddle/utils/none.h"
 #include "paddle/utils/optional.h"
@@ -30,6 +29,9 @@ limitations under the License. */
 namespace paddle {
 namespace framework {

+template <typename T>
+using Vector = std::vector<T>;
+
 inline paddle::optional<platform::CUDAPlace> OptionalCUDAPlace(
     const paddle::memory::allocation::AllocationPtr &gpu_) {
   return gpu_ == nullptr ? paddle::none
@@ -39,7 +41,7 @@ inline paddle::optional<platform::CUDAPlace> OptionalCUDAPlace(
 // Vector<T> implements the std::vector interface, and can get Data or
 // MutableData from any place. The data will be synced implicitly inside.
 template <typename T>
-class Vector {
+class MixVector {
  public:
   using value_type = T;
   using iterator = typename std::vector<T>::iterator;
@@ -49,82 +51,68 @@ class Vector {
   // The actual class to implement vector logic
   class VectorData {
    public:
-    VectorData() : flag_(kDataInCPU) {}
-    VectorData(size_t count, const T &value)
-        : cpu_(count, value), flag_(kDataInCPU) {}
-    VectorData(std::initializer_list<T> init) : cpu_(init), flag_(kDataInCPU) {}
     template <typename U>
-    explicit VectorData(const std::vector<U> &dat)
-        : cpu_(dat), flag_(kDataInCPU) {}
+    explicit VectorData(std::vector<U> *dat) : cpu_(dat), flag_(kDataInCPU) {}
     ~VectorData() {}

-    VectorData(const VectorData &o) {
-      o.ImmutableCPU();
-      cpu_ = o.cpu_;
-      flag_ = kDataInCPU;
-    }
+    VectorData(const VectorData &o) = delete;

-    VectorData &operator=(const VectorData &o) {
-      o.ImmutableCPU();
-      cpu_ = o.cpu_;
-      flag_ = kDataInCPU;
-      return *this;
-    }
+    VectorData &operator=(const VectorData &o) = delete;

     T &operator[](size_t i) {
       MutableCPU();
-      return cpu_[i];
+      return (*cpu_)[i];
     }

     const T &operator[](size_t i) const {
       ImmutableCPU();
-      return cpu_[i];
+      return (*cpu_)[i];
     }

-    size_t size() const { return cpu_.size(); }
+    size_t size() const { return (*cpu_).size(); }

     iterator begin() {
       MutableCPU();
-      return cpu_.begin();
+      return (*cpu_).begin();
     }

     iterator end() {
       MutableCPU();
-      return cpu_.end();
+      return (*cpu_).end();
     }

     T &front() {
       MutableCPU();
-      return cpu_.front();
+      return (*cpu_).front();
     }

     T &back() {
       MutableCPU();
-      return cpu_.back();
+      return (*cpu_).back();
     }

     const_iterator begin() const {
       ImmutableCPU();
-      return cpu_.begin();
+      return (*cpu_).begin();
     }

     const_iterator end() const {
       ImmutableCPU();
-      return cpu_.end();
+      return (*cpu_).end();
     }

     const T &back() const {
       ImmutableCPU();
-      return cpu_.back();
+      return (*cpu_).back();
     }

-    T *data() { return &(*this)[0]; }
+    T *data() { return cpu_->data(); }

-    const T *data() const { return &(*this)[0]; }
+    const T *data() const { return cpu_->data(); }

     const T &front() const {
       ImmutableCPU();
-      return cpu_.front();
+      return (*cpu_).front();
     }

     // assign this from iterator.
@@ -132,14 +120,14 @@ class Vector {
     template <typename Iter>
     void assign(Iter begin, Iter end) {
       MutableCPU();
-      cpu_.assign(begin, end);
+      (*cpu_).assign(begin, end);
     }

     // push_back. If the previous capacity is not enough, the memory will
     // double.
     void push_back(T elem) {
       MutableCPU();
-      cpu_.push_back(elem);
+      (*cpu_).push_back(elem);
     }

     // extend a vector by iterator.
@@ -147,14 +135,14 @@ class Vector {
     template <typename It>
     void Extend(It begin, It end) {
       MutableCPU();
-      auto out_it = std::back_inserter<std::vector<T>>(this->cpu_);
+      auto out_it = std::back_inserter<std::vector<T>>(*(this->cpu_));
       std::copy(begin, end, out_it);
     }

     // resize the vector
     void resize(size_t size) {
       MutableCPU();
-      cpu_.resize(size);
+      (*cpu_).resize(size);
     }

     // get cuda ptr. immutable
@@ -176,26 +164,16 @@ class Vector {
     // clear
     void clear() {
-      cpu_.clear();
+      (*cpu_).clear();
       flag_ = kDirty | kDataInCPU;
     }

-    size_t capacity() const { return cpu_.capacity(); }
-
-    // reserve data
-    void reserve(size_t size) const { cpu_.reserve(size); }
+    std::vector<T> *get_vector() { return cpu_; }

-    // implicit cast operator. Vector can be cast to std::vector implicitly.
-    operator std::vector<T>() const {
-      ImmutableCPU();
-      return cpu_;
-    }
+    size_t capacity() const { return (*cpu_).capacity(); }

-    bool operator==(const VectorData &other) const {
-      ImmutableCPU();
-      other.ImmutableCPU();
-      return cpu_ == other.cpu_;
-    }
+    // reserve data
+    void reserve(size_t size) const { (*cpu_).reserve(size); }

     std::mutex &Mutex() const { return mtx_; }

@@ -203,6 +181,13 @@ class Vector {
       return OptionalCUDAPlace(gpu_);
     }

+    void MutableCPU() {
+      if (IsInCUDA() && IsDirty()) {
+        CopyToCPU();
+      }
+      flag_ = kDirty | kDataInCPU;
+    }
+
    private:
     enum DataFlag {
       kDataInCPU = 0x01,
@@ -213,13 +198,6 @@ class Vector {

     void CopyToCPU() const;

-    void MutableCPU() {
-      if (IsInCUDA() && IsDirty()) {
-        CopyToCPU();
-      }
-      flag_ = kDirty | kDataInCPU;
-    }
-
     void ImmutableCUDA(platform::Place place) const {
       if (IsDirty()) {
         if (IsInCPU()) {
@@ -269,7 +247,7 @@ class Vector {

     bool IsInCPU() const { return flag_ & kDataInCPU; }

-    mutable std::vector<T> cpu_;
+    std::vector<T> *cpu_;
     mutable paddle::memory::allocation::AllocationPtr gpu_;
     mutable size_t gpu_memory_size_{0};
     mutable int flag_;
@@ -278,89 +256,77 @@ class Vector {
   };

  public:
-  // Default ctor. Create empty Vector
-  Vector() : m_(new VectorData()) {}
-
-  // Fill vector with value. The vector size is `count`.
-  explicit Vector(size_t count, const T &value = T())
-      : m_(new VectorData(count, value)) {}
-
-  // Ctor with init_list
-  Vector(std::initializer_list<T> init) : m_(new VectorData(init)) {}
-
   // implicit cast from std::vector.
   template <typename U>
-  Vector(const std::vector<U> &dat) : m_(new VectorData(dat)) {  // NOLINT
+  MixVector(const std::vector<U> *dat) {  // NOLINT
+    m_.reset(new VectorData(const_cast<std::vector<U> *>(dat)));
   }

   // Copy ctor
-  Vector(const Vector<T> &other) { m_ = other.m_; }
+  MixVector(const MixVector<T> &other) = delete;

   // Copy operator
-  Vector<T> &operator=(const Vector<T> &other) {
-    m_ = other.m_;
-    return *this;
-  }
+  MixVector<T> &operator=(const MixVector<T> &other) = delete;

   // Move ctor
-  Vector(Vector<T> &&other) { m_ = std::move(other.m_); }
+  MixVector(MixVector<T> &&other) = delete;

   // CPU data access method. Mutable.
-  T &operator[](size_t i) { return (*m_.MutableData())[i]; }
+  T &operator[](size_t i) { return (*m_)[i]; }

   // CPU data access method. Immutable.
-  const T &operator[](size_t i) const { return m_.Data()[i]; }
+  const T &operator[](size_t i) const { return (*m_)[i]; }

   // std::vector iterator methods. Based on CPU data access method
-  size_t size() const { return m_.Data().size(); }
+  size_t size() const { return m_->size(); }

-  iterator begin() { return m_.MutableData()->begin(); }
+  iterator begin() { return m_->begin(); }

-  iterator end() { return m_.MutableData()->end(); }
+  iterator end() { return m_->end(); }

-  T &front() { return m_.MutableData()->front(); }
+  T &front() { return m_->front(); }

-  T &back() { return m_.MutableData()->back(); }
+  T &back() { return m_->back(); }

-  const_iterator begin() const { return m_.Data().begin(); }
+  const_iterator begin() const { return m_->begin(); }

-  const_iterator end() const { return m_.Data().end(); }
+  const_iterator end() const { return m_->end(); }

   const_iterator cbegin() const { return begin(); }

   const_iterator cend() const { return end(); }

-  const T &back() const { return m_.Data().back(); }
+  const T &back() const { return m_->back(); }

-  T *data() { return m_.MutableData()->data(); }
+  T *data() { return m_->data(); }

-  const T *data() const { return m_.Data().data(); }
+  const T *data() const { return m_->data(); }

-  const T &front() const { return m_.Data().front(); }
+  const T &front() const { return m_->front(); }
   // end of std::vector iterator methods

   // assign this from iterator.
   // NOTE: the iterator must support `end-begin`
   template <typename Iter>
   void assign(Iter begin, Iter end) {
-    m_.MutableData()->assign(begin, end);
+    m_->assign(begin, end);
   }

   // push_back. If the previous capacity is not enough, the memory will
   // double.
-  void push_back(T elem) { m_.MutableData()->push_back(elem); }
+  void push_back(T elem) { m_->push_back(elem); }

   // extend a vector by iterator.
   // NOTE: the iterator must support end-begin
   template <typename It>
   void Extend(It begin, It end) {
-    m_.MutableData()->Extend(begin, end);
+    m_->Extend(begin, end);
   }

   // resize the vector
   void resize(size_t size) {
-    if (m_.Data().size() != size) {
-      m_.MutableData()->resize(size);
+    if (m_->size() != size) {
+      m_->resize(size);
     }
   }

@@ -368,15 +334,15 @@ class Vector {
   const T *CUDAData(platform::Place place) const {
     {
       platform::CUDAPlace p(place.GetDeviceId());
-      auto &mtx = m_.Data().Mutex();
+      auto &mtx = m_->Mutex();
       std::lock_guard<std::mutex> guard(mtx);
-      auto cuda_place = m_.Data().CUDAPlace();
+      auto cuda_place = m_->CUDAPlace();
       if (cuda_place == paddle::none || cuda_place == p) {
-        return m_.Data().CUDAData(place);
+        return m_->CUDAData(place);
       }
     }
-    // If m_ contains CUDAData in a different place. Detach manually.
-    m_.Detach();
+    m_->MutableCPU();
+    m_.reset(new VectorData(m_->get_vector()));
     return CUDAData(place);
   }

@@ -384,25 +350,25 @@ class Vector {
   T *CUDAMutableData(platform::Place place) {
     {
       platform::CUDAPlace p(place.GetDeviceId());
-      auto &mtx = m_.Data().Mutex();
+      auto &mtx = m_->Mutex();
       std::lock_guard<std::mutex> guard(mtx);
-      auto cuda_place = m_.Data().CUDAPlace();
+      auto cuda_place = m_->CUDAPlace();
       if (cuda_place == paddle::none || cuda_place == p) {
-        return m_.MutableData()->CUDAMutableData(place);
+        return m_->CUDAMutableData(place);
       }
     }
-    // If m_ contains CUDAData in a different place. Detach manually.
-    m_.Detach();
+    m_->MutableCPU();
+    m_.reset(new VectorData(m_->get_vector()));
     return CUDAMutableData(place);
   }

   // clear
-  void clear() { m_.MutableData()->clear(); }
+  void clear() { m_->clear(); }

-  size_t capacity() const { return m_.Data().capacity(); }
+  size_t capacity() const { return m_->capacity(); }

   // reserve data
-  void reserve(size_t size) { m_.Data().reserve(size); }
+  void reserve(size_t size) { m_->reserve(size); }

   // the unify method to access CPU or CUDA data. immutable.
   const T *Data(platform::Place place) const {
@@ -422,26 +388,12 @@ class Vector {
     }
   }

-  // implicit cast operator. Vector can be cast to std::vector implicitly.
-  operator std::vector<T>() const { return m_.Data(); }
-
-  bool operator==(const Vector<T> &other) const {
-    if (size() != other.size()) return false;
-    auto it1 = cbegin();
-    auto it2 = other.cbegin();
-    for (; it1 < cend(); ++it1, ++it2) {
-      if (*it1 != *it2) {
-        return false;
-      }
-    }
-    return true;
-  }
+  void CopyToCPU() { m_->MutableCPU(); }

-  const void *Handle() const { return &m_.Data(); }
+  const void *Handle() const { return m_.get(); }

  private:
-  // Vector<T> is an COW object.
-  mutable details::COWPtr<VectorData> m_;
+  mutable std::unique_ptr<VectorData> m_;
 };

 };  // namespace framework
diff --git a/paddle/fluid/framework/mixed_vector_test.cu b/paddle/fluid/framework/mixed_vector_test.cu
index 011e2729d4adffd49c65f536f2ebb33d9a949e56..4cd9aab2896b6fc5940af38cde35945d007aec64 100644
--- a/paddle/fluid/framework/mixed_vector_test.cu
+++ b/paddle/fluid/framework/mixed_vector_test.cu
@@ -28,7 +28,7 @@
 #include "paddle/fluid/platform/device_context.h"

 template <typename T>
-using vec = paddle::framework::Vector<T>;
+using vec = paddle::framework::MixVector<T>;
 using gpuStream_t = paddle::gpuStream_t;

 static __global__ void multiply_10(int* ptr) {
@@ -44,10 +44,11 @@ gpuStream_t GetCUDAStream(paddle::platform::CUDAPlace place) {
 }

 TEST(mixed_vector, GPU_VECTOR) {
-  vec<int> tmp;
+  std::vector<int> x;
   for (int i = 0; i < 10; ++i) {
-    tmp.push_back(i);
+    x.push_back(i);
   }
+  vec<int> tmp(&x);
   ASSERT_EQ(tmp.size(), 10UL);
   paddle::platform::CUDAPlace gpu(0);

@@ -70,10 +71,11 @@ TEST(mixed_vector, MultiGPU) {
     return;
   }

-  vec<int> tmp;
+  std::vector<int> x;
   for (int i = 0; i < 10; ++i) {
-    tmp.push_back(i);
+    x.push_back(i);
   }
+  vec<int> tmp(&x);
   ASSERT_EQ(tmp.size(), 10UL);
   paddle::platform::CUDAPlace gpu0(0);
   paddle::platform::SetDeviceId(0);
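The updated tests above show the new calling convention end to end: a plain std::vector owns the storage, MixVector borrows it for device access, and results are synced back explicitly. A condensed sketch of that pattern (the kernel and function names are illustrative; only the MixVector API comes from the header diff above):

    #include <vector>
    #include "paddle/fluid/framework/mixed_vector.h"

    // Hypothetical kernel: doubles each element in place.
    static __global__ void Double(size_t* data, size_t size) {
      for (size_t i = 0; i < size; ++i) data[i] *= 2;  // single-thread demo
    }

    void DoubleOnGPU(std::vector<size_t>* cpu_vec) {
      paddle::platform::CUDAPlace gpu(0);
      // MixVector borrows the std::vector; it does not own or copy it.
      paddle::framework::MixVector<size_t> mix(cpu_vec);
      // CUDAMutableData() uploads the CPU data if needed and marks the
      // GPU side dirty.
      size_t* dev_ptr = mix.CUDAMutableData(gpu);
      Double<<<1, 1>>>(dev_ptr, cpu_vec->size());
      cudaDeviceSynchronize();
      // Unlike the old copy-on-write Vector, syncing back is explicit:
      // without CopyToCPU() the borrowed vector keeps its stale values.
      mix.CopyToCPU();
    }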
*/ #include "paddle/phi/core/ddim.h" #include "paddle/phi/core/stream.h" +#include "paddle/fluid/framework/mixed_vector.h" #include "paddle/phi/core/dense_tensor.h" namespace paddle { diff --git a/paddle/fluid/framework/tensor_util.cc b/paddle/fluid/framework/tensor_util.cc index 1eb5727298c39aba41b4efe832b10d363b6030ea..10eefff093b0e867131c91fb0a8132175a28c6be 100644 --- a/paddle/fluid/framework/tensor_util.cc +++ b/paddle/fluid/framework/tensor_util.cc @@ -1455,22 +1455,10 @@ std::ostream& print_tensor>( } std::ostream& operator<<(std::ostream& os, const LoD& lod) { - os << "{"; - for (auto& v : lod) { - os << "{"; - bool is_first = true; - for (auto& i : v) { - if (is_first) { - os << i; - is_first = false; - } else { - os << ", " << i; - } - } - os << "}"; - } - os << "}"; - + // NOTE(xiongkun): + // https://stackoverflow.com/questions/5195512/namespaces-and-operator-resolution + // if we don't redefine, the operator << of pten / framework LoD is not found. + paddle::string::operator<<(os, lod); return os; } @@ -1479,6 +1467,11 @@ std::ostream& operator<<(std::ostream& os, const LoD& lod) { namespace phi { +std::ostream& operator<<(std::ostream& os, const LoD& lod) { + paddle::string::operator<<(os, lod); + return os; +} + std::ostream& operator<<(std::ostream& os, const phi::DenseTensor& t) { if (t.lod().size() > 0) { os << " - lod: " << t.lod() << "\n"; diff --git a/paddle/fluid/imperative/all_reduce.cc b/paddle/fluid/imperative/all_reduce.cc index 24a8ffbabf526ca779511f620648c64fcbb59cca..436e22f00c303d59652db33a723fe727b63657ef 100644 --- a/paddle/fluid/imperative/all_reduce.cc +++ b/paddle/fluid/imperative/all_reduce.cc @@ -90,6 +90,7 @@ static void AllReduce(const phi::SelectedRows &src, phi::SelectedRows *dst, platform::DeviceContextPool::Instance().Get(place)); bool use_calc_stream = (dev_ctx->stream() == stream); + VLOG(4) << "Is use calculate stream: " << use_calc_stream; // 1. Gather rows number from all workers. 
Here use ncclAllGather to do this, // but we can use other ways to implement is in the future @@ -97,7 +98,9 @@ static void AllReduce(const phi::SelectedRows &src, phi::SelectedRows *dst, framework::Vector rows_num_vector(strategy.nranks_); rows_num_vector[strategy.local_rank_] = static_cast(src_rows.size()); // CUDAMutableData use CalStream - auto *gpu_rows_num_ptr = rows_num_vector.CUDAMutableData(place); + paddle::framework::MixVector mixv_rows_num_vector(&rows_num_vector); + auto *gpu_rows_num_ptr = mixv_rows_num_vector.CUDAMutableData(place); + VLOG(4) << "start dev_ctx->wait"; if (!use_calc_stream) { dev_ctx->Wait(); } @@ -109,6 +112,7 @@ static void AllReduce(const phi::SelectedRows &src, phi::SelectedRows *dst, platform::GpuStreamSync(stream); } + mixv_rows_num_vector.CopyToCPU(); const auto *cpu_rows_num_ptr = rows_num_vector.data(); auto rows_num = std::accumulate(cpu_rows_num_ptr, cpu_rows_num_ptr + strategy.nranks_, @@ -121,8 +125,10 @@ static void AllReduce(const phi::SelectedRows &src, phi::SelectedRows *dst, auto *dst_rows = dst->mutable_rows(); dst_rows->resize(rows_num); - auto *dst_rows_ptr = dst_rows->CUDAMutableData(place); - const auto *src_rows_ptr = src_rows.CUDAData(place); + paddle::framework::MixVector mixv_dst_rows(dst_rows); + auto *dst_rows_ptr = mixv_dst_rows.CUDAMutableData(place); + paddle::framework::MixVector mixv_src_rows(&src_rows); + const auto *src_rows_ptr = mixv_src_rows.CUDAData(place); auto *dst_tensor = dst->mutable_value(); auto dims = src_tensor.dims(); @@ -150,24 +156,28 @@ static void AllReduce(const phi::SelectedRows &src, phi::SelectedRows *dst, PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclAllGather( src_tensor_ptr, dst_tensor_ptr, value_sendcount, nccl_dtype, comm->comm(), stream)); - return; - } - for (int i = 0; i < strategy.nranks_; ++i) { - if (cpu_rows_num_ptr[i] > 0) { - // 2. Broadcast the rows of SelectedRows - PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclBroadcast( - src_rows_ptr, dst_rows_ptr + row_offset, cpu_rows_num_ptr[i], - ncclInt64, i, comm->comm(), stream)); - // 3. Broadcast the tensor data of SelectedRows - auto *dst_tensor_ptr_i = reinterpret_cast(dst_tensor_ptr) + - row_offset * feature_size * sizeof_dtype; - PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclBroadcast( - src_tensor_ptr, dst_tensor_ptr_i, cpu_rows_num_ptr[i] * feature_size, - nccl_dtype, i, comm->comm(), stream)); - row_offset += cpu_rows_num_ptr[i]; + } else { + for (int i = 0; i < strategy.nranks_; ++i) { + if (cpu_rows_num_ptr[i] > 0) { + // 2. Broadcast the rows of SelectedRows + PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclBroadcast( + src_rows_ptr, dst_rows_ptr + row_offset, cpu_rows_num_ptr[i], + ncclInt64, i, comm->comm(), stream)); + // 3. 
+        auto *dst_tensor_ptr_i = reinterpret_cast<uint8_t *>(dst_tensor_ptr) +
+                                 row_offset * feature_size * sizeof_dtype;
+        PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclBroadcast(
+            src_tensor_ptr, dst_tensor_ptr_i,
+            cpu_rows_num_ptr[i] * feature_size, nccl_dtype, i, comm->comm(),
+            stream));
+        row_offset += cpu_rows_num_ptr[i];
+      }
     }
   }
-
+  if (!use_calc_stream) {
+    platform::GpuStreamSync(stream);
+  }
+  mixv_dst_rows.CopyToCPU();
   VLOG(3) << "Original SelectedRows rows: "
           << string::join_strings(src_rows, ',');
   VLOG(3) << "Result SelectedRows rows: "
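The control flow above implements a two-phase exchange for sparse gradients: an ncclAllGather of per-rank row counts, then (when the counts differ) one ncclBroadcast per rank for rows and values. A stripped-down sketch of that pattern (buffer names and sizes are hypothetical; the PADDLE_ENFORCE_GPU_SUCCESS error handling is omitted):

    // Sketch: gather per-rank row counts, then broadcast each rank's rows.
    #include <nccl.h>
    #include <cuda_runtime.h>

    void AllGatherThenBroadcast(int nranks, int rank, int64_t* rows_num_gpu,
                                int64_t* rows_num_cpu, int64_t* src_rows,
                                int64_t* dst_rows, ncclComm_t comm,
                                cudaStream_t stream) {
      // Phase 1: every rank contributes one int64 row count (in place).
      ncclAllGather(rows_num_gpu + rank, rows_num_gpu, 1, ncclInt64, comm,
                    stream);
      cudaMemcpyAsync(rows_num_cpu, rows_num_gpu, nranks * sizeof(int64_t),
                      cudaMemcpyDeviceToHost, stream);
      cudaStreamSynchronize(stream);  // counts are needed on the host below

      // Phase 2: concatenate all ranks' rows by broadcasting from each root.
      int64_t offset = 0;
      for (int i = 0; i < nranks; ++i) {
        if (rows_num_cpu[i] > 0) {
          ncclBroadcast(src_rows, dst_rows + offset, rows_num_cpu[i],
                        ncclInt64, /*root=*/i, comm, stream);
          offset += rows_num_cpu[i];
        }
      }
    }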
diff --git a/paddle/fluid/imperative/gloo_context.cc b/paddle/fluid/imperative/gloo_context.cc
index 8997966165769cac1c89ad7c8846cdd13bbc2348..dd34b8b619f80a0e7cb5f122d10850482b1b74ad 100644
--- a/paddle/fluid/imperative/gloo_context.cc
+++ b/paddle/fluid/imperative/gloo_context.cc
@@ -143,7 +143,7 @@ void GLOOParallelContext::AllReduce(const phi::SelectedRows &src,
   auto dtype = framework::TransToProtoVarType(src_tensor.dtype());
   // 1. Gather rows number from all workers. Here use ncclAllGather to do this,
   // but we can use other ways to implement is in the future
-  const auto &src_rows = src.rows();
+  auto &src_rows = src.rows();
   auto gloo_wrapper = framework::GlooWrapper::GetInstance();
   size_t local_row_num = src_rows.size();
   std::vector<size_t> rows_num_vector =
@@ -157,8 +157,10 @@ void GLOOParallelContext::AllReduce(const phi::SelectedRows &src,
           << ", height: " << src.height();
   auto *dst_rows = dst->mutable_rows();
   dst_rows->resize(rows_num);
-  auto *dst_rows_ptr = dst_rows->MutableData(place);
-  const int64_t *src_rows_ptr = src_rows.Data(place);
+  paddle::framework::MixVector<int64_t> mixv_dst_rows(dst_rows);
+  auto *dst_rows_ptr = mixv_dst_rows.MutableData(place);
+  paddle::framework::MixVector<int64_t> mixv_src_rows(&src_rows);
+  const int64_t *src_rows_ptr = mixv_src_rows.Data(place);

   auto *dst_tensor = dst->mutable_value();
   auto dims = src_tensor.dims();
diff --git a/paddle/fluid/inference/lite/tensor_utils.cc b/paddle/fluid/inference/lite/tensor_utils.cc
index 04ae3b9afe32c1762399e987ac5be8bc312d4d59..0e4fb3335f3d76eecea85417ac83c205d63ac9c4 100644
--- a/paddle/fluid/inference/lite/tensor_utils.cc
+++ b/paddle/fluid/inference/lite/tensor_utils.cc
@@ -38,8 +38,6 @@ void SetLoD(DstLoD* dst, const SrcLoD& src) {
     dst->emplace_back(v);
   }
 }
-template void SetLoD<paddle::lite::LoD, framework::LoD>(
-    paddle::lite::LoD* dst, const framework::LoD& src);
 template void SetLoD<framework::LoD, paddle::lite::LoD>(
     framework::LoD* dst, const paddle::lite::LoD& src);

diff --git a/paddle/fluid/operators/ctc_align_op.cu b/paddle/fluid/operators/ctc_align_op.cu
index 8a44c1327b9e6fbb1f8767a9ecdf40faf95993eb..b1f2e61ef3930d81aa56794c0d232930452b03d9 100644
--- a/paddle/fluid/operators/ctc_align_op.cu
+++ b/paddle/fluid/operators/ctc_align_op.cu
@@ -110,10 +110,12 @@ class CTCAlignOpCUDAKernel : public framework::OpKernel<T> {
     // merge elements and delete blank
     T* output_data = output->mutable_data<T>({num_tokens, 1}, ctx.GetPlace());

+    paddle::framework::MixVector<size_t> mixv_input_lod(&input_lod[level]);
     MergeAndDelCudaKernel<T><<<1, 1, 0, stream>>>(
         num_tokens, tokens, num_seq,
-        input_lod[level].CUDAMutableData(ctx.GetPlace()), blank,
-        merge_repeated, dev_out_lod0_ptr, output_data);
+        mixv_input_lod.CUDAMutableData(ctx.GetPlace()), blank, merge_repeated,
+        dev_out_lod0_ptr, output_data);
+    mixv_input_lod.CopyToCPU();

     // set output lod
     std::vector<size_t> host_out_lod0(dev_out_lod0.begin(),
diff --git a/paddle/fluid/operators/cvm_op.cu b/paddle/fluid/operators/cvm_op.cu
index ad96dc24b9206c0e7c6bc172180cec829230dde1..1a3bdee53e9bd31b410093446280a18e2f75d7a2 100644
--- a/paddle/fluid/operators/cvm_op.cu
+++ b/paddle/fluid/operators/cvm_op.cu
@@ -149,11 +149,12 @@ class CVMGradCUDAKernel : public framework::OpKernel<T> {
         batch_size, lod[lod.size() - 1],
         platform::errors::PreconditionNotMet(
             "Output(X@GRAD)'s dim[0] must be equal to last element of lod"));
+    paddle::framework::MixVector<size_t> mixv_lod(&lod);
     CvmGradComputeKernel<<<(dx_numel + PADDLE_CUDA_NUM_THREADS - 1) /
                                PADDLE_CUDA_NUM_THREADS,
                            PADDLE_CUDA_NUM_THREADS, 0, stream>>>(
         use_cvm, item_size, cvm_data, dout_data, dx_data, true,
-        lod.CUDAData(context.GetPlace()), lod.size(), dx_numel);
+        mixv_lod.CUDAData(context.GetPlace()), lod.size(), dx_numel);
     }
   }
 };
diff --git a/paddle/fluid/operators/detection/box_clip_op.cu b/paddle/fluid/operators/detection/box_clip_op.cu
index bda22dd0155cce6cec767dfe1c3b282788a5f160..65f2a5590716d42649dbf766575c72571c23eb4d 100644
--- a/paddle/fluid/operators/detection/box_clip_op.cu
+++ b/paddle/fluid/operators/detection/box_clip_op.cu
@@ -57,9 +57,11 @@ class GPUBoxClipKernel : public framework::OpKernel<T> {
     auto stream = dev_ctx.stream();
     const size_t batch_size = lod.back().size() - 1;
     T *output_data = output->mutable_data<T>(dev_ctx.GetPlace());
+    paddle::framework::MixVector<size_t> mix_vector(&abs_offset_lod[0]);
     GPUBoxClip<T, 512><<<batch_size, 512, 0, stream>>>(
-        input->data<T>(), abs_offset_lod[0].CUDAMutableData(dev_ctx.GetPlace()),
+        input->data<T>(), mix_vector.CUDAMutableData(dev_ctx.GetPlace()),
         bbox_width, im_info->data<T>(), output_data);
+    mix_vector.CopyToCPU();
   }
 };
diff --git a/paddle/fluid/operators/detection/target_assign_op.h b/paddle/fluid/operators/detection/target_assign_op.h
index 01b15865e93b6035598b382b506504e9fcc22698..c4506f04e083e0a1e7671605ef6e39a06aa68eed 100644
--- a/paddle/fluid/operators/detection/target_assign_op.h
+++ b/paddle/fluid/operators/detection/target_assign_op.h
@@ -108,7 +108,8 @@ class TargetAssignKernel : public framework::OpKernel<T> {
     auto x_lod = x->lod().back();
 #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
-    size_t* x_lod_data = x_lod.MutableData(ctx.GetPlace());
+    paddle::framework::MixVector<size_t> mixv_x_lod(&x_lod);
+    size_t* x_lod_data = mixv_x_lod.MutableData(ctx.GetPlace());
 #else
     size_t* x_lod_data = x_lod.data();
 #endif
@@ -116,6 +117,9 @@ class TargetAssignKernel : public framework::OpKernel<T> {
     TargetAssignFunctor functor(x_data, match_idx_data, x_lod_data,
                                 mismatch_value, n, m, p, k, out_data,
                                 out_wt_data);
+#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
+    mixv_x_lod.CopyToCPU();
+#endif

     auto& device_ctx = ctx.template device_context<DeviceContext>();
     platform::ForRange<DeviceContext> for_range(device_ctx, n * m);
@@ -130,13 +134,17 @@ class TargetAssignKernel : public framework::OpKernel<T> {
       const int* neg_idx_data = neg_indices->data<int>();
       auto neg_lod = neg_indices->lod().back();
 #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
-      size_t* neg_lod_data = neg_lod.MutableData(ctx.GetPlace());
+      paddle::framework::MixVector<size_t> mixv_neg_lod(&neg_lod);
+      size_t* neg_lod_data = mixv_neg_lod.MutableData(ctx.GetPlace());
 #else
       size_t* neg_lod_data = neg_lod.data();
 #endif
       NegTargetAssignFunctor neg_trg_functor;
       neg_trg_functor(device_ctx, neg_idx_data, neg_lod_data, n, m, k,
                       mismatch_value, out_data, out_wt_data);
+#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
+      mixv_neg_lod.CopyToCPU();
+#endif
     }
   }
 };
diff --git a/paddle/fluid/operators/fused/mkldnn/multi_gru_mkldnn_op.cc b/paddle/fluid/operators/fused/mkldnn/multi_gru_mkldnn_op.cc
index 177e8f5bcb7bdd1af907c397bfb75db8dd014d88..0ffc4c91b851c12a5329ae5b27bd3300753896a9 100644
--- a/paddle/fluid/operators/fused/mkldnn/multi_gru_mkldnn_op.cc
+++ b/paddle/fluid/operators/fused/mkldnn/multi_gru_mkldnn_op.cc
@@ -16,6 +16,7 @@ limitations under the License. */
 #include
 #include
 #include "dnnl.hpp"
+#include "paddle/fluid/framework/mixed_vector.h"
 #include "paddle/fluid/framework/operator.h"
 #include "paddle/fluid/operators/fused/multi_gru_op.h"
 #include "paddle/fluid/platform/errors.h"
diff --git a/paddle/fluid/operators/lookup_table_op.cu b/paddle/fluid/operators/lookup_table_op.cu
index e36c8b1c1b2531f726cc0e9ec1cde6a7aaac6bb5..29079b8b1385dee3a28c42a178a046fab77e6200 100644
--- a/paddle/fluid/operators/lookup_table_op.cu
+++ b/paddle/fluid/operators/lookup_table_op.cu
@@ -164,8 +164,10 @@ class LookupTableGradCUDAKernel : public framework::OpKernel<T> {
       auto gpu_place = context.GetPlace();

       // TODO(yuyang18): Strange code here.
-      memory::Copy(gpu_place, new_rows.CUDAMutableData(context.GetPlace()),
+      paddle::framework::MixVector<int64_t> mixv_new_rows(&new_rows);
+      memory::Copy(gpu_place, mixv_new_rows.CUDAMutableData(context.GetPlace()),
                    gpu_place, ids_data, ids_num * sizeof(int64_t), stream);
+      mixv_new_rows.CopyToCPU();
       d_table->set_rows(new_rows);

       auto *d_table_value = d_table->mutable_value();
diff --git a/paddle/fluid/operators/lookup_table_v2_op.cu b/paddle/fluid/operators/lookup_table_v2_op.cu
index 42318ca6a8d3e06a8a6560cdf6eef2d67e6116b0..4539f7091b5780a9876adb64c7d8253d51723bdf 100644
--- a/paddle/fluid/operators/lookup_table_v2_op.cu
+++ b/paddle/fluid/operators/lookup_table_v2_op.cu
@@ -152,14 +152,16 @@ struct LookupTableV2GradCUDAFunctor {
       new_rows.resize(ids_num);
       auto gpu_place = context_.GetPlace();

+      paddle::framework::MixVector<int64_t> mixv_new_rows(&new_rows);
       if (!std::is_same<IdT, int64_t>::value) {
         InputTypeConvert<<<grids, threads, 0, stream>>>(
-            ids_data, ids_num, new_rows.MutableData(gpu_place));
+            ids_data, ids_num, mixv_new_rows.MutableData(gpu_place));
       } else {
-        memory::Copy(gpu_place, new_rows.CUDAMutableData(gpu_place), gpu_place,
-                     ids_data, ids_num * sizeof(int64_t), stream);
+        memory::Copy(gpu_place, mixv_new_rows.CUDAMutableData(gpu_place),
+                     gpu_place, ids_data, ids_num * sizeof(int64_t), stream);
       }

+      mixv_new_rows.CopyToCPU();
       d_table->set_rows(new_rows);

       auto *d_table_value = d_table->mutable_value();
diff --git a/paddle/fluid/operators/math/beam_search.cu b/paddle/fluid/operators/math/beam_search.cu
index c954bdf81d30d13abc8383544e17709ee249cc99..486979aa0a8b3009d09f73de54f9b7b3ac8a77ad 100644
--- a/paddle/fluid/operators/math/beam_search.cu
+++ b/paddle/fluid/operators/math/beam_search.cu
@@ -357,8 +357,9 @@ class BeamSearchFunctor<platform::CUDADeviceContext, T> {
     framework::LoD selected_lod(2);
     selected_lod[0].assign(abs_lod[level].begin(), abs_lod[level].end());
     selected_lod[1].resize(scores->dims()[0] + 1);
-    size_t* selected_offsets =
-        selected_lod[1].CUDAMutableData(context.GetPlace());
+    paddle::framework::MixVector<size_t> mix_vector(&selected_lod[1]);
+    paddle::framework::MixVector<size_t> mixv_abs(&abs_lod[level]);
+    size_t* selected_offsets = mix_vector.CUDAMutableData(context.GetPlace());

     if (num_seqs == 1) {
       const int seq_length = static_cast<int>(abs_lod[level][1]);
@@ -377,7 +378,7 @@ class BeamSearchFunctor<platform::CUDADeviceContext, T> {
             is_accumulated, num_used_threads));
       }
     } else if (num_seqs <= 4) {
-      const size_t* seq_offsets = abs_lod[level].CUDAData(context.GetPlace());
+      const size_t* seq_offsets = mixv_abs.CUDAData(context.GetPlace());
       // Use only 1 block
       const int kMaxThreadsPerSeq = 32;
       const int kMaxSeqs = 4;
@@ -400,6 +401,7 @@ class BeamSearchFunctor<platform::CUDADeviceContext, T> {
     }

     context.Wait();
+    mix_vector.CopyToCPU();
     if (!framework::CheckLoD(selected_lod)) {
       PADDLE_THROW(platform::errors::InvalidArgument(
           "lod %s is not right in"
diff --git a/paddle/fluid/operators/math/selected_rows_functor.cc b/paddle/fluid/operators/math/selected_rows_functor.cc
index 67165ff2219891e3518673845ce224a30b117ff8..fcd5c06a6f310f8a23608a77f2d6b9098e99b33a 100644
--- a/paddle/fluid/operators/math/selected_rows_functor.cc
+++ b/paddle/fluid/operators/math/selected_rows_functor.cc
@@ -170,7 +170,8 @@ struct SelectedRowsAddTo<platform::CPUDeviceContext, T> {
     auto* in2_value = input2->mutable_value();

     // concat rows
-    in2_rows.Extend(in1_rows.begin(), in1_rows.end());
+    paddle::framework::MixVector<int64_t> mixv_in2_rows(&in2_rows);
+    mixv_in2_rows.Extend(in1_rows.begin(), in1_rows.end());

     auto in1_place = input1.place();
     PADDLE_ENFORCE_EQ(platform::is_cpu_place(in1_place), true,
diff --git a/paddle/fluid/operators/math/selected_rows_functor.cu b/paddle/fluid/operators/math/selected_rows_functor.cu
index ea0b0bb29548bef0792d00f177d6789daf211ad6..8563d8b05b186c025ecc4c970a400765adeb0c5d 100644
--- a/paddle/fluid/operators/math/selected_rows_functor.cu
+++ b/paddle/fluid/operators/math/selected_rows_functor.cu
@@ -161,9 +161,10 @@ struct SelectedRowsAddTensor<platform::CUDADeviceContext, T> {
     const int block_size = 256;
     dim3 threads(block_size, 1);
     dim3 grid(in1_rows.size(), 1);
+    paddle::framework::MixVector<int64_t> mixv_in1_rows(&in1_rows);
     SelectedRowsAddTensorKernel<
         T, block_size><<<grid, threads, 0, context.stream()>>>(
-        in1_data, in1_rows.CUDAData(context.GetPlace()), out_data,
+        in1_data, mixv_in1_rows.CUDAData(context.GetPlace()), out_data,
         in1_row_numel);

     auto out_eigen = framework::EigenVector<T>::Flatten(*output);
@@ -198,8 +199,9 @@ struct SelectedRowsAddTo<platform::CUDADeviceContext, T> {
     auto* in2_value = input2->mutable_value();

     // concat rows
+    paddle::framework::MixVector<int64_t> mixv_in2_rows(&in2_rows);
     if (in1_rows.size()) {
-      in2_rows.Extend(in1_rows.begin(), in1_rows.end());
+      mixv_in2_rows.Extend(in1_rows.begin(), in1_rows.end());
     }

     auto in1_place = input1.place();
@@ -274,9 +276,10 @@ struct SelectedRowsAddToTensor<platform::CUDADeviceContext, T> {
     const int block_size = 256;
     dim3 threads(block_size, 1);
     dim3 grid(in1_rows.size(), 1);
+    paddle::framework::MixVector<int64_t> mixv_in1_rows(&in1_rows);
     SelectedRowsAddToTensorKernel<
         T, block_size><<<grid, threads, 0, context.stream()>>>(
-        in1_data, in1_rows.CUDAData(context.GetPlace()), in2_data,
+        in1_data, mixv_in1_rows.CUDAData(context.GetPlace()), in2_data,
         in1_row_numel);
   }
 };
@@ -356,10 +359,13 @@ struct MergeAdd<platform::CUDADeviceContext, T> {

     dim3 threads(block_size, 1);
     dim3 grid1(input_rows.size(), 1);

+    paddle::framework::MixVector<int64_t> mix_vector_input(&input_rows);
+    paddle::framework::MixVector<int64_t> mix_vector_out(out.mutable_rows());
     MergeAddKernel<T, 256><<<grid1, threads, 0, context.stream()>>>(
-        input_data, input_rows.CUDAData(context.GetPlace()), out_data,
-        out.mutable_rows()->CUDAMutableData(context.GetPlace()),
-        out.rows().size(), input_width);
+        input_data, mix_vector_input.CUDAData(context.GetPlace()), out_data,
+        mix_vector_out.CUDAMutableData(context.GetPlace()), out.rows().size(),
+        input_width);
+    mix_vector_out.CopyToCPU();
   }

   void operator()(const platform::CUDADeviceContext& context,
@@ -423,10 +429,13 @@ struct MergeAdd<platform::CUDADeviceContext, T> {
       auto& input_rows = input->rows();

       dim3 grid1(input_rows.size(), 1);
+      paddle::framework::MixVector<int64_t> mix_vector_input(&input_rows);
+      paddle::framework::MixVector<int64_t> mix_vector_out(out.mutable_rows());
       MergeAddKernel<T, 256><<<grid1, threads, 0, context.stream()>>>(
-          input_data, input_rows.CUDAData(context.GetPlace()), out_data,
-          out.mutable_rows()->CUDAMutableData(context.GetPlace()),
-          out.rows().size(), input_width);
+          input_data, mix_vector_input.CUDAData(context.GetPlace()), out_data,
+          mix_vector_out.CUDAMutableData(context.GetPlace()), out.rows().size(),
+          input_width);
+      mix_vector_out.CopyToCPU();
     }
   }
 };
diff --git a/paddle/fluid/operators/math/sequence2batch.cu b/paddle/fluid/operators/math/sequence2batch.cu
index cd1ca572689bc701da801384e5ed08fe6dc10749..f56c5293971bce3b43e86686e828fad4c90639f5 100644
--- a/paddle/fluid/operators/math/sequence2batch.cu
+++ b/paddle/fluid/operators/math/sequence2batch.cu
@@ -72,8 +72,9 @@ class CopyMatrixRowsFunctor<platform::CUDADeviceContext, T> {
     dim3 threads(128, 8);
     dim3 grid(8, 1);
     auto stream = context.stream();
+    paddle::framework::MixVector<size_t> mix_index_lod(&index_lod);
     CopyMatrixRowsKernel<T, 128, 8, 8><<<grid, threads, 0, stream>>>(
-        src_data, dst_data, index_lod.CUDAData(context.GetPlace()), height,
+        src_data, dst_data, mix_index_lod.CUDAData(context.GetPlace()), height,
         width, is_src_index);
   }
 };
diff --git a/paddle/fluid/operators/math/sequence_padding.cu b/paddle/fluid/operators/math/sequence_padding.cu
index 65bf77f0d152b99059eea2ba98b5d2f0945dc273..01fd2d403c4564ba022e3ab9633fa04d998dd662 100644
--- a/paddle/fluid/operators/math/sequence_padding.cu
+++ b/paddle/fluid/operators/math/sequence_padding.cu
@@ -59,7 +59,7 @@ class PaddingLoDTensorFunctor<platform::CUDADeviceContext, T> {
                   int lod_level = 0, bool norm_by_times = false,
                   const PadLayout layout = kBatchLengthWidth) {
     auto seq_lod = seq_tensor.lod();
-    const auto seq_offsets = framework::ToAbsOffset(seq_lod)[lod_level];
+    auto seq_offsets = framework::ToAbsOffset(seq_lod)[lod_level];
     const auto& seq_tensor_dims = seq_tensor.dims();
     const auto& pad_tensor_dims = pad_tensor->dims();
     int max_seq_len = MaximumSequenceLength(seq_offsets);
@@ -104,10 +104,11 @@ class PaddingLoDTensorFunctor<platform::CUDADeviceContext, T> {
     T* pad_data = pad_tensor->data<T>();
     const T* pad_value_data = pad_value.data<T>();

+    paddle::framework::MixVector<size_t> mix_vector_seq_offsets(&seq_offsets);
     SequencePaddingKernel<T, kSeqToPad><<<grid, threads, 0, context.stream()>>>(
         pad_data, seq_data, pad_value_data, pad_value.numel() == 1,
-        seq_offsets.CUDAData(context.GetPlace()), seq_num, pad_seq_len,
-        step_width, norm_by_times, layout);
+        mix_vector_seq_offsets.CUDAData(context.GetPlace()), seq_num,
+        pad_seq_len, step_width, norm_by_times, layout);
   }
 };

@@ -157,9 +158,10 @@ class UnpaddingLoDTensorFunctor<platform::CUDADeviceContext, T> {
     const T* pad_data = pad_tensor.data<T>();
     T* seq_data = seq_tensor->data<T>();

+    paddle::framework::MixVector<size_t> mixv_seq_offsets(&seq_offsets);
     SequencePaddingKernel<T, kPadToSeq><<<grid, threads, 0, context.stream()>>>(
         seq_data, pad_data, nullptr, false,
-        seq_offsets.CUDAData(context.GetPlace()), seq_num, pad_seq_len,
+        mixv_seq_offsets.CUDAData(context.GetPlace()), seq_num, pad_seq_len,
         step_width, norm_by_times, layout);
   }
 };
diff --git a/paddle/fluid/operators/math/sequence_pooling.cu b/paddle/fluid/operators/math/sequence_pooling.cu
index 1c09acf52fae3f911b3c5e46855c9343a88ffae8..fa7b043153851460c9c8d5586ddce88872b7e3c7 100644
--- a/paddle/fluid/operators/math/sequence_pooling.cu
+++ b/paddle/fluid/operators/math/sequence_pooling.cu
@@ -168,41 +168,42 @@ class SequencePoolFunctor<platform::CUDADeviceContext, T> {
     const size_t item_dim = output->numel() / output->dims()[0];
     dim3 threads(1024, 1);
     dim3 grid(std::max(static_cast<int>(lod.size()) - 1, 1), 1);
+    paddle::framework::MixVector<size_t> mix_vector(&lod);
     if (pooltype == "MAX") {
       sequence_pool_kernel<
           T, MaxPoolFunctor<T>><<<grid, threads, 0, context.stream()>>>(
           MaxPoolFunctor<T>(), input.data<T>(), pad_value,
-          lod.CUDAData(context.GetPlace()), lod.size(), item_dim,
+          mix_vector.CUDAData(context.GetPlace()), lod.size(), item_dim,
           output->mutable_data<T>(context.GetPlace()), index->data<int>());
     } else if (pooltype == "AVERAGE") {
       sequence_pool_kernel<
           T, AvgPoolFunctor<T>><<<grid, threads, 0, context.stream()>>>(
           AvgPoolFunctor<T>(), input.data<T>(), pad_value,
-          lod.CUDAData(context.GetPlace()), lod.size(), item_dim,
+          mix_vector.CUDAData(context.GetPlace()), lod.size(), item_dim,
           output->mutable_data<T>(context.GetPlace()), nullptr);
     } else if (pooltype == "SUM") {
       sequence_pool_kernel<
           T, SumPoolFunctor<T>><<<grid, threads, 0, context.stream()>>>(
           SumPoolFunctor<T>(), input.data<T>(), pad_value,
-          lod.CUDAData(context.GetPlace()), lod.size(), item_dim,
+          mix_vector.CUDAData(context.GetPlace()), lod.size(), item_dim,
           output->mutable_data<T>(context.GetPlace()), nullptr);
     } else if (pooltype == "SQRT") {
       sequence_pool_kernel<
           T, SqrtPoolFunctor<T>><<<grid, threads, 0, context.stream()>>>(
           SqrtPoolFunctor<T>(), input.data<T>(), pad_value,
-          lod.CUDAData(context.GetPlace()), lod.size(), item_dim,
+          mix_vector.CUDAData(context.GetPlace()), lod.size(), item_dim,
           output->mutable_data<T>(context.GetPlace()), nullptr);
     } else if (pooltype == "LAST") {
       sequence_pool_kernel<
           T, LastPoolFunctor<T>><<<grid, threads, 0, context.stream()>>>(
           LastPoolFunctor<T>(), input.data<T>(), pad_value,
-          lod.CUDAData(context.GetPlace()), lod.size(), item_dim,
+          mix_vector.CUDAData(context.GetPlace()), lod.size(), item_dim,
           output->mutable_data<T>(context.GetPlace()), nullptr);
     } else if (pooltype == "FIRST") {
       sequence_pool_kernel<
           T, FirstPoolFunctor<T>><<<grid, threads, 0, context.stream()>>>(
           FirstPoolFunctor<T>(), input.data<T>(), pad_value,
-          lod.CUDAData(context.GetPlace()), lod.size(), item_dim,
+          mix_vector.CUDAData(context.GetPlace()), lod.size(), item_dim,
           output->mutable_data<T>(context.GetPlace()), nullptr);
     } else {
       PADDLE_THROW(platform::errors::InvalidArgument(
@@ -335,41 +336,42 @@ class SequencePoolGradFunctor<platform::CUDADeviceContext, T> {
     const size_t item_dim = in_grad->numel() / in_grad->dims()[0];
     dim3 threads(1024, 1);
     dim3 grid(std::max(static_cast<int>(lod.size()) - 1, 1), 1);
+    paddle::framework::MixVector<size_t> mix_vector(&lod);
     if (pooltype == "MAX") {
       sequence_pool_grad_kernel<
           T, MaxPoolGradFunctor<T>><<<grid, threads, 0, context.stream()>>>(
           MaxPoolGradFunctor<T>(), out_grad.data<T>(),
-          lod.CUDAData(context.GetPlace()), lod.size(), item_dim,
+          mix_vector.CUDAData(context.GetPlace()), lod.size(), item_dim,
           in_grad->mutable_data<T>(context.GetPlace()), index->data<int>());
     } else if (pooltype == "AVERAGE") {
       sequence_pool_grad_kernel<
           T, AvgPoolGradFunctor<T>><<<grid, threads, 0, context.stream()>>>(
           AvgPoolGradFunctor<T>(), out_grad.data<T>(),
-          lod.CUDAData(context.GetPlace()), lod.size(), item_dim,
+          mix_vector.CUDAData(context.GetPlace()), lod.size(), item_dim,
           in_grad->mutable_data<T>(context.GetPlace()), nullptr);
     } else if (pooltype == "SUM") {
       sequence_pool_grad_kernel<
           T, SumPoolGradFunctor<T>><<<grid, threads, 0, context.stream()>>>(
           SumPoolGradFunctor<T>(), out_grad.data<T>(),
-          lod.CUDAData(context.GetPlace()), lod.size(), item_dim,
+          mix_vector.CUDAData(context.GetPlace()), lod.size(), item_dim,
           in_grad->mutable_data<T>(context.GetPlace()), nullptr);
     } else if (pooltype == "SQRT") {
       sequence_pool_grad_kernel<
           T, SqrtPoolGradFunctor<T>><<<grid, threads, 0, context.stream()>>>(
           SqrtPoolGradFunctor<T>(), out_grad.data<T>(),
-          lod.CUDAData(context.GetPlace()), lod.size(), item_dim,
+          mix_vector.CUDAData(context.GetPlace()), lod.size(), item_dim,
           in_grad->mutable_data<T>(context.GetPlace()), nullptr);
     } else if (pooltype == "LAST") {
       sequence_pool_grad_kernel<
           T, LastPoolGradFunctor<T>><<<grid, threads, 0, context.stream()>>>(
           LastPoolGradFunctor<T>(), out_grad.data<T>(),
-          lod.CUDAData(context.GetPlace()), lod.size(), item_dim,
+          mix_vector.CUDAData(context.GetPlace()), lod.size(), item_dim,
           in_grad->mutable_data<T>(context.GetPlace()), nullptr);
     } else if (pooltype == "FIRST") {
       sequence_pool_grad_kernel<
           T, FirstPoolGradFunctor<T>><<<grid, threads, 0, context.stream()>>>(
           FirstPoolGradFunctor<T>(), out_grad.data<T>(),
-          lod.CUDAData(context.GetPlace()), lod.size(), item_dim,
+          mix_vector.CUDAData(context.GetPlace()), lod.size(), item_dim,
           in_grad->mutable_data<T>(context.GetPlace()), nullptr);
     } else {
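Two usage patterns recur in the operator changes: read-only access through CUDAData() with no explicit sync afterwards (as in the pooling kernels above), and mutable access through CUDAMutableData() followed by CopyToCPU(). The difference is driven by VectorData's flag_ bits from the mixed_vector.h diff. A standalone model of that state machine (a sketch only; the puts() stubs stand in for the real memory::Copy plus stream-wait logic):

    #include <cstdio>

    class SyncModel {
     public:
      void MutableCPU() {
        // CPU write: pull GPU data back first if the GPU side is newer.
        if (IsInCUDA() && IsDirty()) CopyToCPU();
        flag_ = kDirty | kDataInCPU;  // CPU copy is now the fresh one
      }
      void ImmutableCUDA() {
        if (IsDirty() && IsInCPU()) {
          CopyToCUDA();                      // CPU is newer: refresh GPU
          flag_ = kDataInCPU | kDataInCUDA;  // both sides clean
        } else if (!IsInCUDA()) {
          CopyToCUDA();                      // first GPU access
          flag_ |= kDataInCUDA;
        }
      }

     private:
      enum DataFlag { kDataInCPU = 0x01, kDataInCUDA = 0x02, kDirty = 0x04 };
      bool IsDirty() const { return flag_ & kDirty; }
      bool IsInCPU() const { return flag_ & kDataInCPU; }
      bool IsInCUDA() const { return flag_ & kDataInCUDA; }
      void CopyToCPU() { std::puts("D2H copy"); }
      void CopyToCUDA() { std::puts("H2D copy"); }
      int flag_ = kDataInCPU;
    };

    int main() {
      SyncModel v;        // data starts on the CPU only
      v.ImmutableCUDA();  // "H2D copy": first GPU read populates the GPU
      v.MutableCPU();     // no copy: the GPU side was never written
      v.ImmutableCUDA();  // "H2D copy": the CPU write made the GPU stale
    }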
diff --git a/paddle/fluid/operators/math/sequence_scale.cu b/paddle/fluid/operators/math/sequence_scale.cu
index 1807c77e37ca16967d24c423a1bebac779f59ce5..8e02d1b70ff83b3641d498567a236ffcb41bb988 100644
--- a/paddle/fluid/operators/math/sequence_scale.cu
+++ b/paddle/fluid/operators/math/sequence_scale.cu
@@ -41,21 +41,23 @@ class ScaleLoDTensorFunctor<platform::CUDADeviceContext, T> {
     auto lod = seq->lod();
     const size_t num_seq = lod[level].size() - 1;
     const size_t seq_width = seq->numel() / seq->dims()[0];
-    framework::LoD abs_offset_lod = framework::ToAbsOffset(lod);
+    auto abs_offset_lod = framework::ToAbsOffset(lod);
     T* seq_data = seq->mutable_data<T>(context.GetPlace());
+    paddle::framework::MixVector<size_t> mix_vector(&(abs_offset_lod[level]));

 #ifdef PADDLE_WITH_HIP
     hipLaunchKernelGGL(
         HIP_KERNEL_NAME(SequenceScaleKernel<T, PADDLE_CUDA_NUM_THREADS>),
         dim3(num_seq), dim3(PADDLE_CUDA_NUM_THREADS), 0, context.stream(),
-        seq_data, abs_offset_lod[level].CUDAMutableData(context.GetPlace()),
-        scales, seq_width);
+        seq_data, mix_vector.CUDAMutableData(context.GetPlace()), scales,
+        seq_width);
 #else
     SequenceScaleKernel<T, PADDLE_CUDA_NUM_THREADS><<<
         num_seq, PADDLE_CUDA_NUM_THREADS, 0, context.stream()>>>(
-        seq_data, abs_offset_lod[level].CUDAMutableData(context.GetPlace()),
-        scales, seq_width);
+        seq_data, mix_vector.CUDAMutableData(context.GetPlace()), scales,
+        seq_width);
 #endif
+    mix_vector.CopyToCPU();
   }
 };
diff --git a/paddle/fluid/operators/optimizers/adagrad_op.cu b/paddle/fluid/operators/optimizers/adagrad_op.cu
index 5bfbc3fd681b8a677e5d512750c69706cc68b2d1..3b8ef9056946a1f84d98621442394dbf3e806576 100644
--- a/paddle/fluid/operators/optimizers/adagrad_op.cu
+++ b/paddle/fluid/operators/optimizers/adagrad_op.cu
@@ -96,12 +96,14 @@ struct SparseAdagradFunctor<platform::CUDADeviceContext, T> {
     const int block_size = 256;
     dim3 threads(block_size, 1);
     dim3 grid2(1, merge_rows.size());
+    paddle::framework::MixVector<int64_t> mixv_merge_rows(&merge_rows);
     SparseAdagradFunctorKernel<
         T, 256><<<grid2, threads, 0,
                   reinterpret_cast<const platform::CUDADeviceContext&>(context)
                       .stream()>>>(
-        grad_merge_data, merge_rows.CUDAMutableData(context.GetPlace()), lr,
-        param_data, moment_data, grad_width, epsilon);
+        grad_merge_data, mixv_merge_rows.CUDAMutableData(context.GetPlace()),
+        lr, param_data, moment_data, grad_width, epsilon);
+    mixv_merge_rows.CopyToCPU();
   }
 };
diff --git a/paddle/fluid/operators/optimizers/adam_op.cu b/paddle/fluid/operators/optimizers/adam_op.cu
index 668dd41fa257f28ab819dd811c1002b024372fab..c1aa392d8a528d248d07fb9654e45e3006e79139 100644
--- a/paddle/fluid/operators/optimizers/adam_op.cu
+++ b/paddle/fluid/operators/optimizers/adam_op.cu
@@ -345,7 +345,10 @@ class AdamOpCUDAKernel : public framework::OpKernel<T> {
       auto& grad_merge = *grad_merge_ptr;
       auto& grad_tensor = grad_merge.value();
       const T* grad_data = grad_tensor.template data<T>();
-      const int64_t* rows = grad_merge.rows().Data(ctx.GetPlace());
+      auto* grad_merge_rows = &grad_merge.rows();
+      paddle::framework::MixVector<int64_t> mixv_grad_merge_rows(
+          grad_merge_rows);
+      const int64_t* rows = mixv_grad_merge_rows.Data(ctx.GetPlace());
       auto row_numel = grad_tensor.numel() / grad_merge.rows().size();

       if (beta1_pow->place() == platform::CPUPlace() &&
diff --git a/paddle/fluid/operators/optimizers/adam_op.h b/paddle/fluid/operators/optimizers/adam_op.h
index 7a04b0bd75a4950c926e7db21e13c70ea20d2bb1..decab04f1ca261a828dd749cefbdbaf9f5cfac79 100644
--- a/paddle/fluid/operators/optimizers/adam_op.h
+++ b/paddle/fluid/operators/optimizers/adam_op.h
@@ -592,7 +592,10 @@ class AdamOpKernel : public framework::OpKernel<T> {
       auto& grad_merge = *grad_merge_ptr;
       auto& grad_tensor = grad_merge.value();
       const T* grad_data = grad_tensor.template data<T>();
-      const int64_t* rows = grad_merge.rows().Data(ctx.GetPlace());
+      auto* grad_merge_rows = &grad_merge.rows();
+      paddle::framework::MixVector<int64_t> mixv_grad_merge_rows(
+          grad_merge_rows);
+      const int64_t* rows = mixv_grad_merge_rows.Data(ctx.GetPlace());
       auto row_numel = grad_tensor.numel() / grad_merge.rows().size();

       SparseAdamFunctor functor(
diff --git a/paddle/fluid/operators/optimizers/adamw_op.cu b/paddle/fluid/operators/optimizers/adamw_op.cu
index abdc61e7fcb46655e3741c1bd7b37a0ec3fd2c7f..1d61bdec26d581278758f39293e600598624435f 100644
--- a/paddle/fluid/operators/optimizers/adamw_op.cu
+++ b/paddle/fluid/operators/optimizers/adamw_op.cu
@@ -368,7 +368,10 @@ class AdamWOpCUDAKernel : public framework::OpKernel<T> {
       auto& grad_merge = *grad_merge_ptr;
       auto& grad_tensor = grad_merge.value();
       const T* grad_data = grad_tensor.template data<T>();
-      const int64_t* rows = grad_merge.rows().Data(ctx.GetPlace());
+      auto* grad_merge_rows = &grad_merge.rows();
+      paddle::framework::MixVector<int64_t> mixv_grad_merge_rows(
+          grad_merge_rows);
+      const int64_t* rows = mixv_grad_merge_rows.Data(ctx.GetPlace());
       auto row_numel = grad_tensor.numel() / grad_merge.rows().size();

       if (beta1_pow->place() == platform::CPUPlace() &&
diff --git a/paddle/fluid/operators/optimizers/ftrl_op.h b/paddle/fluid/operators/optimizers/ftrl_op.h
index b74009120abc48feb8b4da0256eac96b1e9b1698..596ed05df3ffd740958bc123582139464722ac23 100644
--- a/paddle/fluid/operators/optimizers/ftrl_op.h
+++ b/paddle/fluid/operators/optimizers/ftrl_op.h
@@ -189,7 +189,9 @@ class FTRLOpKernel : public framework::OpKernel<T> {
       merge_func(ctx.template device_context<DeviceContext>(), *grad,
                  merged_grad);

-      const int64_t* rows = merged_grad->rows().Data(ctx.GetPlace());
+      auto* merged_rows = merged_grad->mutable_rows();
+      paddle::framework::MixVector<int64_t> mixv_merged_rows(merged_rows);
+      const int64_t* rows = mixv_merged_rows.Data(ctx.GetPlace());
       auto row_numel = static_cast<int64_t>(merged_grad->value().dims()[1]);
       auto row_height = static_cast<int64_t>(merged_grad->rows().size());

diff --git a/paddle/fluid/operators/optimizers/lamb_op.h b/paddle/fluid/operators/optimizers/lamb_op.h
index a2189d2a7ca0eda833e926604affc9d9075b1e75..45acf2b3e48345c6a17c75f8409744776a03b243 100644
--- a/paddle/fluid/operators/optimizers/lamb_op.h
+++ b/paddle/fluid/operators/optimizers/lamb_op.h
@@ -594,7 +594,10 @@ class LambOpKernel : public framework::OpKernel<T> {
       auto& grad_merge = *grad_merge_ptr;
       auto& grad_tensor = grad_merge.value();
       const T* grad_data = grad_tensor.template data<T>();
-      const int64_t* rows = grad_merge.rows().Data(ctx.GetPlace());
+      auto* grad_merge_rows = &grad_merge.rows();
+      paddle::framework::MixVector<int64_t> mixv_grad_merge_rows(
+          grad_merge_rows);
+      const int64_t* rows = mixv_grad_merge_rows.Data(ctx.GetPlace());
       auto row_numel = grad_tensor.numel() / grad_merge.rows().size();
       if (platform::is_gpu_place(ctx.GetPlace()) &&
           beta1_pow.place() == platform::CPUPlace() &&
diff --git a/paddle/fluid/operators/optimizers/momentum_op.h b/paddle/fluid/operators/optimizers/momentum_op.h
index 0561c18580a3f6098ef3471d1cfaa328e5b31026..e271755b740ce33369348ca6f415af958a43616d 100644
--- a/paddle/fluid/operators/optimizers/momentum_op.h
+++ b/paddle/fluid/operators/optimizers/momentum_op.h
@@ -561,7 +561,10 @@ class MomentumOpKernel : public framework::OpKernel<T> {
       merge_func(ctx.template device_context<DeviceContext>(), *grad,
                  merged_grad);

-      const int64_t* rows = merged_grad->rows().Data(ctx.GetPlace());
+      auto* grad_merge_rows = merged_grad->mutable_rows();
+      paddle::framework::MixVector<int64_t> mixv_grad_merge_rows(
+          grad_merge_rows);
+      const int64_t* rows = mixv_grad_merge_rows.Data(ctx.GetPlace());
       int64_t row_numel =
           merged_grad->value().numel() / merged_grad->rows().size();
       platform::ForRange<DeviceContext> for_range(
diff --git a/paddle/fluid/operators/optimizers/rmsprop_op.h b/paddle/fluid/operators/optimizers/rmsprop_op.h
index 66c16d8015806982a5cf5b321e3ff019fe14831a..71decd27d0d7822c67ba4a2782c1ec2461e67911 100644
--- a/paddle/fluid/operators/optimizers/rmsprop_op.h
+++ b/paddle/fluid/operators/optimizers/rmsprop_op.h
@@ -227,7 +227,10 @@ class RmspropOpKernel : public framework::OpKernel<T> {
         merge_func(dev_ctx, grad, merged_grad);

         platform::ForRange<DeviceContext> for_range(dev_ctx, limit);
-        const int64_t *rows = merged_grad->rows().Data(ctx.GetPlace());
+        auto &grad_merge_rows = merged_grad->rows();
+        paddle::framework::MixVector<int64_t> mixv_grad_merge_rows(
+            &grad_merge_rows);
+        const int64_t *rows = mixv_grad_merge_rows.Data(ctx.GetPlace());

         auto &merged_tensor = merged_grad->value();
         int64_t row_count = merged_grad->rows().size();
diff --git a/paddle/fluid/operators/optimizers/sgd_op.cu b/paddle/fluid/operators/optimizers/sgd_op.cu
index a255f0fed3ce0c7b143de6d75beabe36b08b6d60..3149f5f56ed4964a750f61a354c6cd31a29fc526 100644
--- a/paddle/fluid/operators/optimizers/sgd_op.cu
+++ b/paddle/fluid/operators/optimizers/sgd_op.cu
@@ -148,11 +148,11 @@ class SGDOpKernel<platform::CUDADeviceContext, T>
       int thread_x = kThreadsPerBlock;
       int max_threads = ctx.cuda_device_context().GetMaxPhysicalThreadCount();
       int max_blocks = std::max(max_threads / kThreadsPerBlock, 1);
-
+      paddle::framework::MixVector<int64_t> mixv_in_rows(&in_rows);
       SparseSGDFunctorKernel<<<max_blocks, thread_x, 0,
                                ctx.cuda_device_context().stream()>>>(
-          in_data, in_rows.CUDAData(ctx.GetPlace()), learning_rate->data<T>(),
-          out_data, in_row_numel, in_rows.size());
+          in_data, mixv_in_rows.CUDAData(ctx.GetPlace()),
+          learning_rate->data<T>(), out_data, in_row_numel, in_rows.size());

     } else {
       PADDLE_ENFORCE_EQ(false, true,
diff --git a/paddle/fluid/operators/row_conv_op.cu b/paddle/fluid/operators/row_conv_op.cu
index 3def7875232e814b817a7957ab9db65ea611dcf6..c5794948aaec6b47396cbae66a962058812aba11 100644
--- a/paddle/fluid/operators/row_conv_op.cu
+++ b/paddle/fluid/operators/row_conv_op.cu
@@ -336,7 +336,8 @@ class RowConvKernel<platform::CUDADeviceContext, T>
     int num_sequence = batch_indices.size() - 1;
     int future_context = Filter->dims()[0];
-    size_t *idx = batch_indices.CUDAMutableData(context.GetPlace());
+    paddle::framework::MixVector<size_t> mix_vector(&batch_indices);
+    size_t *idx = mix_vector.CUDAMutableData(context.GetPlace());
     auto stream = context.cuda_device_context().stream();

     if (future_context <= 32) {
@@ -352,6 +353,7 @@ class RowConvKernel<platform::CUDADeviceContext, T>
       RowConvForward<T><<<grid_dim, block_dim, 0, stream>>>(
           in, weight, num_sequence, input_dim, future_context, idx, out);
     }
+    mix_vector.CopyToCPU();
   }
 };

@@ -392,7 +394,8 @@ class RowConvGradKernel<platform::CUDADeviceContext, T>
     // int input_dim = X->dims()[1];
     int num_sequence = batch_indices.size() - 1;
     int future_context = Filter->dims()[0];
-    size_t *idx = batch_indices.CUDAMutableData(context.GetPlace());
+    paddle::framework::MixVector<size_t> mixv_batch_indices(&batch_indices);
+    size_t *idx = mixv_batch_indices.CUDAMutableData(context.GetPlace());

     auto &device_ctx = context.cuda_device_context();
     phi::funcs::SetConstant<platform::CUDADeviceContext, T> zero;
@@ -444,6 +447,7 @@ class RowConvGradKernel<platform::CUDADeviceContext, T>
             dout, weights, num_sequence, input_dim, future_context, idx, din);
       }
     }
+    mixv_batch_indices.CopyToCPU();
   }
 };
 }  // namespace operators
diff --git a/paddle/fluid/operators/sequence_ops/sequence_enumerate_op.cu b/paddle/fluid/operators/sequence_ops/sequence_enumerate_op.cu
index 8092a40d19b195828c3742854e9b3656424feee7..9591f3e8b5bbfe70cb059b621eaca0ae1fff993e 100644
--- a/paddle/fluid/operators/sequence_ops/sequence_enumerate_op.cu
+++ b/paddle/fluid/operators/sequence_ops/sequence_enumerate_op.cu
@@ -71,7 +71,8 @@ class SequenceEnumerateOpCUDAKernel : public framework::OpKernel<T> {
     out->Resize({in_dims[0], win_size});
     auto out_data = out->mutable_data<T>(context.GetPlace());
     // Copy LoD to GPU
-    const size_t* dev_in_lod_ptr = lod0.CUDAData(context.GetPlace());
+    paddle::framework::MixVector<size_t> mixv_lod0(&lod0);
+    const size_t* dev_in_lod_ptr = mixv_lod0.CUDAData(context.GetPlace());
     // Calc output tensor
     CalcOutPut<<<(in_len - 1) / PADDLE_CUDA_NUM_THREADS + 1,
                  PADDLE_CUDA_NUM_THREADS, 0, stream>>>(
diff --git a/paddle/fluid/operators/sequence_ops/sequence_erase_op.cu b/paddle/fluid/operators/sequence_ops/sequence_erase_op.cu
index bb928cf401c3307b76160387e5108264cd5dbb89..12d3eee65da70edd3f360d448360bb59d2f1069f 100644
--- a/paddle/fluid/operators/sequence_ops/sequence_erase_op.cu
+++ b/paddle/fluid/operators/sequence_ops/sequence_erase_op.cu
@@ -88,7 +88,8 @@ class SequenceEraseOpCUDAKernel : public framework::OpKernel<T> {
     // Copy LoD to GPU
     auto last_lod = lod[lod.size() - 1];
     auto lod_len = last_lod.size();
-    const size_t* dev_in_lod_ptr = last_lod.CUDAData(ctx.GetPlace());
+    paddle::framework::MixVector<size_t> mixv_last_lod(&last_lod);
+    const size_t* dev_in_lod_ptr = mixv_last_lod.CUDAData(ctx.GetPlace());
     // Calc output LoD
     thrust::device_vector<size_t> dev_out_lod(lod_len);
     size_t* dev_out_lod_ptr = thrust::raw_pointer_cast(dev_out_lod.data());
diff --git a/paddle/fluid/operators/sequence_ops/sequence_expand_as_op.cu b/paddle/fluid/operators/sequence_ops/sequence_expand_as_op.cu
index f13849fda41769af12aabf93be748e3ce2ad806b..7e1a06b9eca5b9046d2b772edee0efdb1a69437f 100644
--- a/paddle/fluid/operators/sequence_ops/sequence_expand_as_op.cu
+++ b/paddle/fluid/operators/sequence_ops/sequence_expand_as_op.cu
@@ -81,8 +81,9 @@ struct SequenceExpandAsFunctor<platform::CUDADeviceContext, T> {
     dim3 block_size(thread_x);
     dim3 grid_size(block_x);
+    paddle::framework::MixVector<size_t> mixv_ref_lod(&ref_lod);
     sequence_expand_as_kernel<<<grid_size, block_size, 0, context.stream()>>>(
-        x.data<T>(), ref_lod.CUDAData(context.GetPlace()), height, width,
+        x.data<T>(), mixv_ref_lod.CUDAData(context.GetPlace()), height, width,
         out->mutable_data<T>(context.GetPlace()));
   }
 };
@@ -107,10 +108,11 @@ struct SequenceExpandAsGradFunctor<platform::CUDADeviceContext, T> {
     dim3 block_size(thread_x);
     dim3 grid_size(block_x);
+    paddle::framework::MixVector<size_t> mixv_ref_lod(&ref_lod);
     sequence_expand_as_grad_kernel<<<grid_size, block_size, 0,
                                      context.stream()>>>(
-        dout.data<T>(), ref_lod.CUDAData(context.GetPlace()), height, width,
-        dx->mutable_data<T>(context.GetPlace()));
+        dout.data<T>(), mixv_ref_lod.CUDAData(context.GetPlace()), height,
+        width, dx->mutable_data<T>(context.GetPlace()));
   }
 };
diff --git a/paddle/fluid/operators/sequence_ops/sequence_expand_op.cu b/paddle/fluid/operators/sequence_ops/sequence_expand_op.cu
index cbf5df001707592e03b315b357e3a5d484068011..7b7bc5183bf1f6c98ef386150fcfa4d048e73f01 100644
--- a/paddle/fluid/operators/sequence_ops/sequence_expand_op.cu
+++ b/paddle/fluid/operators/sequence_ops/sequence_expand_op.cu
@@ -157,7 +157,9 @@ struct SequenceExpandFunctor<platform::CUDADeviceContext, T> {
       out_offset[2 * x_lod_size + i] = ref_lod[i];
     }

-    const size_t* out_offset_data = out_offset.CUDAData(context.GetPlace());
+    paddle::framework::MixVector<size_t> mixv_out_offset(&out_offset);
+    const size_t* out_offset_data =
+        mixv_out_offset.CUDAData(context.GetPlace());
     const size_t* x_lod_data = out_offset_data + x_lod_size;
     const size_t* ref_lod_data = out_offset_data + 2 * x_lod_size;

@@ -193,11 +195,14 @@ struct SequenceExpandGradFunctor<platform::CUDADeviceContext, T> {
     int block_x = static_cast<int>(ref_lod.size());
     dim3 block_size(thread_x, thread_y, thread_z);
     dim3 grid_size(block_x, 1);
+    paddle::framework::MixVector<size_t> mixv_ref_lod(&ref_lod);
+    paddle::framework::MixVector<size_t> mixv_x_lod(&x_lod);
+    paddle::framework::MixVector<size_t> mixv_out_offset(&out_offset);
     sequence_expand_grad_kernel<<<grid_size, block_size, 0, context.stream()>>>(
-        dout.data<T>(), ref_lod.CUDAData(context.GetPlace()),
-        x_lod.CUDAData(context.GetPlace()),
-        out_offset.CUDAData(context.GetPlace()), ref_lod.size(), x_item_length,
-        dx->mutable_data<T>(context.GetPlace()));
+        dout.data<T>(), mixv_ref_lod.CUDAData(context.GetPlace()),
+        mixv_x_lod.CUDAData(context.GetPlace()),
+        mixv_out_offset.CUDAData(context.GetPlace()), ref_lod.size(),
+        x_item_length, dx->mutable_data<T>(context.GetPlace()));
   }
 };
diff --git a/paddle/fluid/operators/sequence_ops/sequence_reverse_op.h b/paddle/fluid/operators/sequence_ops/sequence_reverse_op.h
index c42df836de15f5c51caf32e5d0b7b7d8123ff201..90a17d713cf299a3a61169cfc6f16fce7bb5901c 100644
--- a/paddle/fluid/operators/sequence_ops/sequence_reverse_op.h
+++ b/paddle/fluid/operators/sequence_ops/sequence_reverse_op.h
@@ -132,7 +132,9 @@ class SequenceReverseOpKernel : public framework::OpKernel<T> {

 #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
     if (platform::is_gpu_place(ctx.GetPlace())) {
-      lod = x.lod()[0].CUDAData(ctx.GetPlace());
+      auto xlod = x.lod()[0];
+      paddle::framework::MixVector<size_t> mixv_xlod(&xlod);
+      lod = mixv_xlod.CUDAData(ctx.GetPlace());
     } else {
 #endif
       lod = x.lod()[0].data();
diff --git a/paddle/fluid/operators/sequence_ops/sequence_softmax_op.cu b/paddle/fluid/operators/sequence_ops/sequence_softmax_op.cu
index 220165ac1bd4f6a80a2f3c0b21f5423352982588..c91c59dbfee9993711e777668063bec73a3746d8 100644
--- a/paddle/fluid/operators/sequence_ops/sequence_softmax_op.cu
+++ b/paddle/fluid/operators/sequence_ops/sequence_softmax_op.cu
@@ -133,9 +133,10 @@ struct SequenceSoftmaxFunctor<platform::CUDADeviceContext, T> {
     dim3 block_size(thread_x);
     dim3 grid_size(max_blocks);
+    paddle::framework::MixVector<size_t> mixv_ref_lod(&ref_lod);
     sequence_softmax_kernel<
         T, kThreadsPerBlock><<<grid_size, block_size, 0, context.stream()>>>(
-        x.data<T>(), ref_lod.CUDAData(context.GetPlace()), height,
+        x.data<T>(), mixv_ref_lod.CUDAData(context.GetPlace()), height,
         out->mutable_data<T>(context.GetPlace()));
   }
 };
@@ -156,10 +157,12 @@ struct SequenceSoftmaxGradFunctor<platform::CUDADeviceContext, T> {
     dim3 block_size(thread_x);
     dim3 grid_size(max_blocks);

+    paddle::framework::MixVector<size_t> mixv_ref_lod(&ref_lod);
     sequence_softmax_grad_kernel<
         T, kThreadsPerBlock><<<grid_size, block_size, 0, context.stream()>>>(
-        dout.data<T>(), out.data<T>(), ref_lod.CUDAData(context.GetPlace()),
-        height, dx->mutable_data<T>(context.GetPlace()));
+        dout.data<T>(), out.data<T>(),
+        mixv_ref_lod.CUDAData(context.GetPlace()), height,
+        dx->mutable_data<T>(context.GetPlace()));
   }
 };
diff --git a/paddle/phi/api/ext/dispatch.h b/paddle/phi/api/ext/dispatch.h
index 4e5fa879a2cfc759cea753be8db19e116d91669e..6b6d0ae7fe7230263454d0bf08da40e4a793549b 100644
--- a/paddle/phi/api/ext/dispatch.h
+++ b/paddle/phi/api/ext/dispatch.h
@@ -292,7 +292,7 @@ namespace paddle {
                               paddle::experimental::complex128,              \
                               __VA_ARGS__)                                   \
       default:                                                               \
-        PADDLE_THROW(paddle::platform::errors::InvalidArgument(              \
+        PADDLE_THROW(phi::errors::InvalidArgument(                           \
            "Invalid enum data type `%d`.", static_cast<int>(__dtype__)));   \
     }                                                                        \
   }()
diff --git a/paddle/phi/api/lib/utils/storage.cc b/paddle/phi/api/lib/utils/storage.cc
index db3f5f0c8f98bcd4831ba7be69537e9db9efbee2..09ff18d10e312f1f1be130bb2411316dca184458 100644
--- a/paddle/phi/api/lib/utils/storage.cc
+++ b/paddle/phi/api/lib/utils/storage.cc
@@ -19,7 +19,7 @@ namespace experimental {

 ExternalStorage::ExternalStorage(void* ptr,
                                  size_t size,
-                                 const paddle::platform::Place& place)
paddle::platform::Place& place) + const phi::Place& place) : phi::Storage(std::make_shared(ptr, size, place)), size_(size) {} @@ -29,11 +29,11 @@ ExternalStorage::ExternalStorage(const phi::intrusive_ptr& root, : Storage(std::make_shared( static_cast(root->data()) + delta, size, root->place())), size_(size) { - PADDLE_ENFORCE_LE(static_cast(delta + size), - root->size(), - paddle::platform::errors::InvalidArgument( - "The size of the external storage does " - "not meet the metadata requirements.")); + PADDLE_ENFORCE_LE( + static_cast(delta + size), + root->size(), + phi::errors::InvalidArgument("The size of the external storage does " + "not meet the metadata requirements.")); } } // namespace experimental diff --git a/paddle/phi/api/lib/utils/storage.h b/paddle/phi/api/lib/utils/storage.h index ede5f804836621a88a294d05cbae6a15c9eceb81..c2eedd0fa63f787d7aff6e5f20d807f363bc8b95 100644 --- a/paddle/phi/api/lib/utils/storage.h +++ b/paddle/phi/api/lib/utils/storage.h @@ -30,7 +30,7 @@ class ExternalStorage : public phi::Storage { static const char* name() { return "ExternalStorage"; } void Realloc(size_t n) override { - PADDLE_THROW(paddle::platform::errors::Unavailable( + PADDLE_THROW(phi::errors::Unavailable( "The external shared storage cannot be reallocated.")); } @@ -55,7 +55,7 @@ class ExternalStorage : public phi::Storage { const phi::Place& place() const override { PADDLE_ENFORCE_NOT_NULL( data_, - paddle::platform::errors::Unavailable( + phi::errors::Unavailable( "Unable to visit place as data_ has not been initialized yet.")); return data_->place(); } diff --git a/paddle/phi/backends/dynload/cudnn.cc b/paddle/phi/backends/dynload/cudnn.cc index ff000d27c4f2e185c88259e2353e476b1ff9220b..02d626d5f98f9fc0c260a55c846031634b68e144 100644 --- a/paddle/phi/backends/dynload/cudnn.cc +++ b/paddle/phi/backends/dynload/cudnn.cc @@ -54,7 +54,7 @@ bool HasCUDNN() { void EnforceCUDNNLoaded(const char* fn_name) { PADDLE_ENFORCE_NOT_NULL( cudnn_dso_handle, - paddle::platform::errors::PreconditionNotMet( + phi::errors::PreconditionNotMet( "Cannot load cudnn shared library. Cannot invoke method %s.", fn_name)); } diff --git a/paddle/phi/backends/dynload/cufft.cc b/paddle/phi/backends/dynload/cufft.cc index 14240af41046c3a735b30392b0ab7685bc3d5806..596a68c1ed6aad96942ddd2b5eee82b8102e2444 100644 --- a/paddle/phi/backends/dynload/cufft.cc +++ b/paddle/phi/backends/dynload/cufft.cc @@ -33,7 +33,7 @@ bool HasCUFFT() { void EnforceCUFFTLoaded(const char* fn_name) { PADDLE_ENFORCE_NOT_NULL( cufft_dso_handle, - paddle::platform::errors::PreconditionNotMet( + phi::errors::PreconditionNotMet( "Cannot load cufft shared library. Cannot invoke method %s.", fn_name)); } diff --git a/paddle/phi/backends/dynload/dynamic_loader.cc b/paddle/phi/backends/dynload/dynamic_loader.cc index 473c58b33eebc46a62b6b31af10d6b71b0fff53d..2f35e22a18f820cd15325d8516447e3652c132f1 100644 --- a/paddle/phi/backends/dynload/dynamic_loader.cc +++ b/paddle/phi/backends/dynload/dynamic_loader.cc @@ -24,7 +24,7 @@ limitations under the License. */ #include #endif -// TODO(wilber): The pten computing library requires a component to manage flags +// TODO(wilber): The phi computing library requires a component to manage flags // (maybe not use gflags). 
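// [Sketch] Every dynload change in this patch follows the same shape: the
// enforce macro and the message are untouched, and only the error factory
// moves from paddle::platform::errors to phi::errors. A minimal illustration
// of that pattern, with hypothetical names (foo_dso_handle, EnforceFooLoaded):
//
//   #include "paddle/phi/core/enforce.h"
//
//   static void* foo_dso_handle = nullptr;
//
//   void EnforceFooLoaded(const char* fn_name) {
//     PADDLE_ENFORCE_NOT_NULL(
//         foo_dso_handle,
//         phi::errors::PreconditionNotMet(
//             "Cannot load foo shared library. Cannot invoke method %s.",
//             fn_name));
//   }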
#include "gflags/gflags.h" #include "glog/logging.h" @@ -299,8 +299,8 @@ static inline void* GetDsoHandleFromSearchPath( #endif // !_WIN32 if (throw_on_error) { // NOTE: Special error report case, no need to change its format - PADDLE_THROW(paddle::platform::errors::PreconditionNotMet( - error_msg, dso_name, errorno)); + PADDLE_THROW( + phi::errors::PreconditionNotMet(error_msg, dso_name, errorno)); } else { LOG(WARNING) << paddle::string::Sprintf(error_msg, dso_name, errorno); } @@ -547,14 +547,11 @@ void* GetOpDsoHandle(const std::string& dso_name) { void* GetNvtxDsoHandle() { #if defined(__APPLE__) || defined(__OSX__) - PADDLE_THROW( - paddle::platform::errors::Unimplemented("Nvtx do not support Apple.")); + PADDLE_THROW(phi::errors::Unimplemented("Nvtx do not support Apple.")); #elif defined(_WIN32) - PADDLE_THROW( - paddle::platform::errors::Unimplemented("Nvtx do not support Windows.")); + PADDLE_THROW(phi::errors::Unimplemented("Nvtx do not support Windows.")); #elif !defined(PADDLE_WITH_CUDA) - PADDLE_THROW(paddle::platform::errors::Unimplemented( - "Nvtx do not support without CUDA.")); + PADDLE_THROW(phi::errors::Unimplemented("Nvtx do not support without CUDA.")); #else return GetDsoHandleFromSearchPath(FLAGS_cuda_dir, "libnvToolsExt.so"); #endif diff --git a/paddle/phi/backends/dynload/miopen.cc b/paddle/phi/backends/dynload/miopen.cc index a57574dbab13bc88065cb91b9b175f164799584e..e7916873ccfde7e1e5d0933045c9b44557f2f07a 100644 --- a/paddle/phi/backends/dynload/miopen.cc +++ b/paddle/phi/backends/dynload/miopen.cc @@ -58,7 +58,7 @@ bool HasCUDNN() { void EnforceCUDNNLoaded(const char* fn_name) { PADDLE_ENFORCE_NOT_NULL( miopen_dso_handle, - paddle::platform::errors::PreconditionNotMet( + phi::errors::PreconditionNotMet( "Cannot load miopen shared library. Cannot invoke method %s.", fn_name)); } diff --git a/paddle/phi/backends/dynload/tensorrt.h b/paddle/phi/backends/dynload/tensorrt.h index 77f25ec0b5aaff99fcaba8cae418d4045dfedf3a..cd8c6457f1b91b938f1ef927119c9ec63a7b6e1b 100644 --- a/paddle/phi/backends/dynload/tensorrt.h +++ b/paddle/phi/backends/dynload/tensorrt.h @@ -54,21 +54,21 @@ extern void* tensorrt_plugin_dso_handle; }; \ extern DynLoad__##__name __name -#define DECLARE_DYNAMIC_LOAD_TENSORRT_NON_POINTER_WRAP(__name) \ - struct DynLoad__##__name { \ - template \ - auto operator()(Args... args) -> DECLARE_TYPE(__name, args...) { \ - std::call_once(tensorrt_dso_flag, []() { \ - tensorrt_dso_handle = phi::dynload::GetTensorRtHandle(); \ - }); \ - static void* p_##__name = dlsym(tensorrt_dso_handle, #__name); \ - PADDLE_ENFORCE_NOT_NULL(p_##__name, \ - paddle::platform::errors::Unavailable( \ - "Load tensorrt api %s failed", #__name)); \ - using tensorrt_func = decltype(&::__name); \ - return reinterpret_cast(p_##__name)(args...); \ - } \ - }; \ +#define DECLARE_DYNAMIC_LOAD_TENSORRT_NON_POINTER_WRAP(__name) \ + struct DynLoad__##__name { \ + template \ + auto operator()(Args... args) -> DECLARE_TYPE(__name, args...) 
{ \ + std::call_once(tensorrt_dso_flag, []() { \ + tensorrt_dso_handle = phi::dynload::GetTensorRtHandle(); \ + }); \ + static void* p_##__name = dlsym(tensorrt_dso_handle, #__name); \ + PADDLE_ENFORCE_NOT_NULL( \ + p_##__name, \ + phi::errors::Unavailable("Load tensorrt api %s failed", #__name)); \ + using tensorrt_func = decltype(&::__name); \ + return reinterpret_cast(p_##__name)(args...); \ + } \ + }; \ extern DynLoad__##__name __name #define DECLARE_DYNAMIC_LOAD_TENSORRT_PLUGIN_WRAP(__name) \ @@ -80,7 +80,7 @@ extern void* tensorrt_plugin_dso_handle; }); \ static void* p_##__name = dlsym(tensorrt_plugin_dso_handle, #__name); \ PADDLE_ENFORCE_NOT_NULL(p_##__name, \ - paddle::platform::errors::Unavailable( \ + phi::errors::Unavailable( \ "Load tensorrt plugin %s failed", #__name)); \ using tensorrt_plugin_func = decltype(&::__name); \ return reinterpret_cast(p_##__name)(args...); \ diff --git a/paddle/phi/backends/gpu/cuda/cuda_info.cc b/paddle/phi/backends/gpu/cuda/cuda_info.cc index f8e4ec02bc39e3406437a0503d4cd9622565dbeb..7be21e85f0005b9bfe7849ac6f12561cf108c7e3 100644 --- a/paddle/phi/backends/gpu/cuda/cuda_info.cc +++ b/paddle/phi/backends/gpu/cuda/cuda_info.cc @@ -14,7 +14,7 @@ #include "paddle/phi/backends/gpu/gpu_info.h" -// TODO(pten): remove fluid headers. +// TODO(phi): remove fluid headers. #include "paddle/fluid/platform/enforce.h" static std::once_flag g_device_props_size_init_flag; @@ -74,13 +74,13 @@ int GetGPUDeviceCount() { } int GetGPUComputeCapability(int id) { - PADDLE_ENFORCE_LT(id, - GetGPUDeviceCount(), - paddle::platform::errors::InvalidArgument( - "Device id must be less than GPU count, " - "but received id is: %d. GPU count is: %d.", - id, - GetGPUDeviceCount())); + PADDLE_ENFORCE_LT( + id, + GetGPUDeviceCount(), + phi::errors::InvalidArgument("Device id must be less than GPU count, " + "but received id is: %d. GPU count is: %d.", + id, + GetGPUDeviceCount())); int major, minor; auto major_error_code = cudaDeviceGetAttribute(&major, cudaDevAttrComputeCapabilityMajor, id); @@ -93,26 +93,26 @@ int GetGPUComputeCapability(int id) { } int GetGPURuntimeVersion(int id) { - PADDLE_ENFORCE_LT(id, - GetGPUDeviceCount(), - paddle::platform::errors::InvalidArgument( - "Device id must be less than GPU count, " - "but received id is: %d. GPU count is: %d.", - id, - GetGPUDeviceCount())); + PADDLE_ENFORCE_LT( + id, + GetGPUDeviceCount(), + phi::errors::InvalidArgument("Device id must be less than GPU count, " + "but received id is: %d. GPU count is: %d.", + id, + GetGPUDeviceCount())); int runtime_version = 0; PADDLE_ENFORCE_GPU_SUCCESS(cudaRuntimeGetVersion(&runtime_version)); return runtime_version; } int GetGPUDriverVersion(int id) { - PADDLE_ENFORCE_LT(id, - GetGPUDeviceCount(), - paddle::platform::errors::InvalidArgument( - "Device id must be less than GPU count, " - "but received id is: %d. GPU count is: %d.", - id, - GetGPUDeviceCount())); + PADDLE_ENFORCE_LT( + id, + GetGPUDeviceCount(), + phi::errors::InvalidArgument("Device id must be less than GPU count, " + "but received id is: %d. GPU count is: %d.", + id, + GetGPUDeviceCount())); int driver_version = 0; PADDLE_ENFORCE_GPU_SUCCESS(cudaDriverGetVersion(&driver_version)); return driver_version; @@ -125,13 +125,13 @@ bool TensorCoreAvailable() { } int GetGPUMultiProcessors(int id) { - PADDLE_ENFORCE_LT(id, - GetGPUDeviceCount(), - paddle::platform::errors::InvalidArgument( - "Device id must be less than GPU count, " - "but received id is: %d. 
GPU count is: %d.", - id, - GetGPUDeviceCount())); + PADDLE_ENFORCE_LT( + id, + GetGPUDeviceCount(), + phi::errors::InvalidArgument("Device id must be less than GPU count, " + "but received id is: %d. GPU count is: %d.", + id, + GetGPUDeviceCount())); int count; PADDLE_ENFORCE_GPU_SUCCESS( cudaDeviceGetAttribute(&count, cudaDevAttrMultiProcessorCount, id)); @@ -139,13 +139,13 @@ int GetGPUMultiProcessors(int id) { } int GetGPUMaxThreadsPerMultiProcessor(int id) { - PADDLE_ENFORCE_LT(id, - GetGPUDeviceCount(), - paddle::platform::errors::InvalidArgument( - "Device id must be less than GPU count, " - "but received id is: %d. GPU count is: %d.", - id, - GetGPUDeviceCount())); + PADDLE_ENFORCE_LT( + id, + GetGPUDeviceCount(), + phi::errors::InvalidArgument("Device id must be less than GPU count, " + "but received id is: %d. GPU count is: %d.", + id, + GetGPUDeviceCount())); int count; PADDLE_ENFORCE_GPU_SUCCESS(cudaDeviceGetAttribute( &count, cudaDevAttrMaxThreadsPerMultiProcessor, id)); @@ -154,13 +154,13 @@ int GetGPUMaxThreadsPerMultiProcessor(int id) { } int GetGPUMaxThreadsPerBlock(int id) { - PADDLE_ENFORCE_LT(id, - GetGPUDeviceCount(), - paddle::platform::errors::InvalidArgument( - "Device id must be less than GPU count, " - "but received id is: %d. GPU count is: %d.", - id, - GetGPUDeviceCount())); + PADDLE_ENFORCE_LT( + id, + GetGPUDeviceCount(), + phi::errors::InvalidArgument("Device id must be less than GPU count, " + "but received id is: %d. GPU count is: %d.", + id, + GetGPUDeviceCount())); int count; PADDLE_ENFORCE_GPU_SUCCESS( cudaDeviceGetAttribute(&count, cudaDevAttrMaxThreadsPerBlock, id)); @@ -174,13 +174,13 @@ int GetCurrentDeviceId() { } std::array GetGpuMaxGridDimSize(int id) { - PADDLE_ENFORCE_LT(id, - GetGPUDeviceCount(), - paddle::platform::errors::InvalidArgument( - "Device id must be less than GPU count, " - "but received id is: %d. GPU count is: %d.", - id, - GetGPUDeviceCount())); + PADDLE_ENFORCE_LT( + id, + GetGPUDeviceCount(), + phi::errors::InvalidArgument("Device id must be less than GPU count, " + "but received id is: %d. GPU count is: %d.", + id, + GetGPUDeviceCount())); std::array ret; int size; auto error_code_x = cudaDeviceGetAttribute(&size, cudaDevAttrMaxGridDimX, id); @@ -213,7 +213,7 @@ const gpuDeviceProp &GetDeviceProperties(int id) { } if (id < 0 || id >= static_cast(g_device_props.size())) { - PADDLE_THROW(paddle::platform::errors::OutOfRange( + PADDLE_THROW(phi::errors::OutOfRange( "The device id %d is out of range [0, %d), where %d is the number of " "devices on this machine. Because the device id should be greater than " "or equal to zero and smaller than the number of gpus. Please input " @@ -233,13 +233,13 @@ const gpuDeviceProp &GetDeviceProperties(int id) { void SetDeviceId(int id) { // TODO(qijun): find a better way to cache the cuda device count - PADDLE_ENFORCE_LT(id, - GetGPUDeviceCount(), - paddle::platform::errors::InvalidArgument( - "Device id must be less than GPU count, " - "but received id is: %d. GPU count is: %d.", - id, - GetGPUDeviceCount())); + PADDLE_ENFORCE_LT( + id, + GetGPUDeviceCount(), + phi::errors::InvalidArgument("Device id must be less than GPU count, " + "but received id is: %d. 
GPU count is: %d.", + id, + GetGPUDeviceCount())); PADDLE_RETRY_CUDA_SUCCESS(cudaSetDevice(id)); } @@ -294,13 +294,13 @@ gpuError_t GpuGetLastError() { return cudaGetLastError(); } // https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#um-requirements // for more detail about managed memory requirements bool IsGPUManagedMemorySupported(int dev_id) { - PADDLE_ENFORCE_LT(dev_id, - GetGPUDeviceCount(), - paddle::platform::errors::InvalidArgument( - "Device id must be less than GPU count, " - "but received id is: %d. GPU count is: %d.", - dev_id, - GetGPUDeviceCount())); + PADDLE_ENFORCE_LT( + dev_id, + GetGPUDeviceCount(), + phi::errors::InvalidArgument("Device id must be less than GPU count, " + "but received id is: %d. GPU count is: %d.", + dev_id, + GetGPUDeviceCount())); #if defined(__linux__) || defined(_WIN32) int ManagedMemoryAttr; PADDLE_ENFORCE_GPU_SUCCESS(cudaDeviceGetAttribute( @@ -312,13 +312,13 @@ bool IsGPUManagedMemorySupported(int dev_id) { } bool IsGPUManagedMemoryOversubscriptionSupported(int dev_id) { - PADDLE_ENFORCE_LT(dev_id, - GetGPUDeviceCount(), - paddle::platform::errors::InvalidArgument( - "Device id must be less than GPU count, " - "but received id is: %d. GPU count is: %d.", - dev_id, - GetGPUDeviceCount())); + PADDLE_ENFORCE_LT( + dev_id, + GetGPUDeviceCount(), + phi::errors::InvalidArgument("Device id must be less than GPU count, " + "but received id is: %d. GPU count is: %d.", + dev_id, + GetGPUDeviceCount())); #ifdef __linux__ return IsGPUManagedMemorySupported(dev_id) && GetGPUComputeCapability(dev_id) >= 60; diff --git a/paddle/phi/backends/gpu/gpu_launch_config.h b/paddle/phi/backends/gpu/gpu_launch_config.h index 21193755044579eb4f19936dca1c2b6b3c5b4bea..5aa569e0197bdcf62d7f178d61ec47cce57cd96c 100644 --- a/paddle/phi/backends/gpu/gpu_launch_config.h +++ b/paddle/phi/backends/gpu/gpu_launch_config.h @@ -100,12 +100,12 @@ struct GpuLaunchConfig { inline GpuLaunchConfig GetGpuLaunchConfig1D(const phi::GPUContext& context, int64_t numel, int vec_size = 1) { - PADDLE_ENFORCE_GT(numel, - 0, - paddle::platform::errors::InvalidArgument( - "element quantity should be greater than 0," - " but received value is: %d.", - numel)); + PADDLE_ENFORCE_GT( + numel, + 0, + phi::errors::InvalidArgument("element quantity should be greater than 0," + " but received value is: %d.", + numel)); // Get compute_capability const int capability = context.GetComputeCapability(); /* If thread number per block is 64/128/256/512, cuda performs better.*/ @@ -142,18 +142,18 @@ inline GpuLaunchConfig GetGpuLaunchConfig1D(const phi::GPUContext& context, inline GpuLaunchConfig GetGpuLaunchConfig2D(const phi::GPUContext& context, int x_dim, int y_dim) { - PADDLE_ENFORCE_GT(x_dim, - 0, - paddle::platform::errors::InvalidArgument( - "x dim number should greater than 0," - " but received value is: %d", - x_dim)); - PADDLE_ENFORCE_GT(y_dim, - 0, - paddle::platform::errors::InvalidArgument( - "y dim number should greater than 0," - " but received value is: %d", - y_dim)); + PADDLE_ENFORCE_GT( + x_dim, + 0, + phi::errors::InvalidArgument("x dim number should greater than 0," + " but received value is: %d", + x_dim)); + PADDLE_ENFORCE_GT( + y_dim, + 0, + phi::errors::InvalidArgument("y dim number should greater than 0," + " but received value is: %d", + y_dim)); const int kThreadsPerBlock = 256; int block_cols = (std::min)(x_dim, kThreadsPerBlock); diff --git a/paddle/phi/backends/gpu/rocm/rocm_info.cc b/paddle/phi/backends/gpu/rocm/rocm_info.cc index 
c7390cfb6a2198904f081ffbb8f5f4f8532324e2..11dd4f724878266d52fdcbeee031b6ac6a9a9438 100644 --- a/paddle/phi/backends/gpu/rocm/rocm_info.cc +++ b/paddle/phi/backends/gpu/rocm/rocm_info.cc @@ -78,13 +78,13 @@ int GetGPUDeviceCount() { } int GetGPUComputeCapability(int id) { - PADDLE_ENFORCE_LT(id, - GetGPUDeviceCount(), - paddle::platform::errors::InvalidArgument( - "Device id must be less than GPU count, " - "but received id is: %d. GPU count is: %d.", - id, - GetGPUDeviceCount())); + PADDLE_ENFORCE_LT( + id, + GetGPUDeviceCount(), + phi::errors::InvalidArgument("Device id must be less than GPU count, " + "but received id is: %d. GPU count is: %d.", + id, + GetGPUDeviceCount())); int major, minor; auto major_error_code = hipDeviceGetAttribute( &major, hipDeviceAttributeComputeCapabilityMajor, id); @@ -97,26 +97,26 @@ int GetGPUComputeCapability(int id) { } int GetGPURuntimeVersion(int id) { - PADDLE_ENFORCE_LT(id, - GetGPUDeviceCount(), - paddle::platform::errors::InvalidArgument( - "Device id must be less than GPU count, " - "but received id is: %d. GPU count is: %d.", - id, - GetGPUDeviceCount())); + PADDLE_ENFORCE_LT( + id, + GetGPUDeviceCount(), + phi::errors::InvalidArgument("Device id must be less than GPU count, " + "but received id is: %d. GPU count is: %d.", + id, + GetGPUDeviceCount())); int runtime_version = 0; PADDLE_ENFORCE_GPU_SUCCESS(hipRuntimeGetVersion(&runtime_version)); return runtime_version; } int GetGPUDriverVersion(int id) { - PADDLE_ENFORCE_LT(id, - GetGPUDeviceCount(), - paddle::platform::errors::InvalidArgument( - "Device id must be less than GPU count, " - "but received id is: %d. GPU count is: %d.", - id, - GetGPUDeviceCount())); + PADDLE_ENFORCE_LT( + id, + GetGPUDeviceCount(), + phi::errors::InvalidArgument("Device id must be less than GPU count, " + "but received id is: %d. GPU count is: %d.", + id, + GetGPUDeviceCount())); int driver_version = 0; PADDLE_ENFORCE_GPU_SUCCESS(hipDriverGetVersion(&driver_version)); return driver_version; @@ -125,13 +125,13 @@ int GetGPUDriverVersion(int id) { bool TensorCoreAvailable() { return false; } int GetGPUMultiProcessors(int id) { - PADDLE_ENFORCE_LT(id, - GetGPUDeviceCount(), - paddle::platform::errors::InvalidArgument( - "Device id must be less than GPU count, " - "but received id is: %d. GPU count is: %d.", - id, - GetGPUDeviceCount())); + PADDLE_ENFORCE_LT( + id, + GetGPUDeviceCount(), + phi::errors::InvalidArgument("Device id must be less than GPU count, " + "but received id is: %d. GPU count is: %d.", + id, + GetGPUDeviceCount())); int count; PADDLE_ENFORCE_GPU_SUCCESS( hipDeviceGetAttribute(&count, hipDeviceAttributeMultiprocessorCount, id)); @@ -139,13 +139,13 @@ int GetGPUMultiProcessors(int id) { } int GetGPUMaxThreadsPerMultiProcessor(int id) { - PADDLE_ENFORCE_LT(id, - GetGPUDeviceCount(), - paddle::platform::errors::InvalidArgument( - "Device id must be less than GPU count, " - "but received id is: %d. GPU count is: %d.", - id, - GetGPUDeviceCount())); + PADDLE_ENFORCE_LT( + id, + GetGPUDeviceCount(), + phi::errors::InvalidArgument("Device id must be less than GPU count, " + "but received id is: %d. 
GPU count is: %d.", + id, + GetGPUDeviceCount())); int count; PADDLE_ENFORCE_GPU_SUCCESS(hipDeviceGetAttribute( &count, hipDeviceAttributeMaxThreadsPerMultiProcessor, id)); @@ -154,13 +154,13 @@ int GetGPUMaxThreadsPerMultiProcessor(int id) { } int GetGPUMaxThreadsPerBlock(int id) { - PADDLE_ENFORCE_LT(id, - GetGPUDeviceCount(), - paddle::platform::errors::InvalidArgument( - "Device id must be less than GPU count, " - "but received id is: %d. GPU count is: %d.", - id, - GetGPUDeviceCount())); + PADDLE_ENFORCE_LT( + id, + GetGPUDeviceCount(), + phi::errors::InvalidArgument("Device id must be less than GPU count, " + "but received id is: %d. GPU count is: %d.", + id, + GetGPUDeviceCount())); int count; PADDLE_ENFORCE_GPU_SUCCESS( hipDeviceGetAttribute(&count, hipDeviceAttributeMaxThreadsPerBlock, id)); @@ -174,13 +174,13 @@ int GetCurrentDeviceId() { } std::array GetGpuMaxGridDimSize(int id) { - PADDLE_ENFORCE_LT(id, - GetGPUDeviceCount(), - paddle::platform::errors::InvalidArgument( - "Device id must be less than GPU count, " - "but received id is: %d. GPU count is: %d.", - id, - GetGPUDeviceCount())); + PADDLE_ENFORCE_LT( + id, + GetGPUDeviceCount(), + phi::errors::InvalidArgument("Device id must be less than GPU count, " + "but received id is: %d. GPU count is: %d.", + id, + GetGPUDeviceCount())); std::array ret; int size; auto error_code_x = @@ -216,7 +216,7 @@ const gpuDeviceProp &GetDeviceProperties(int id) { } if (id < 0 || id >= static_cast(g_device_props.size())) { - PADDLE_THROW(paddle::platform::errors::OutOfRange( + PADDLE_THROW(phi::errors::OutOfRange( "The device id %d is out of range [0, %d), where %d is the number of " "devices on this machine. Because the device id should be greater than " "or equal to zero and smaller than the number of gpus. Please input " @@ -235,13 +235,13 @@ const gpuDeviceProp &GetDeviceProperties(int id) { void SetDeviceId(int id) { // TODO(qijun): find a better way to cache the cuda device count - PADDLE_ENFORCE_LT(id, - GetGPUDeviceCount(), - paddle::platform::errors::InvalidArgument( - "Device id must be less than GPU count, " - "but received id is: %d. GPU count is: %d.", - id, - GetGPUDeviceCount())); + PADDLE_ENFORCE_LT( + id, + GetGPUDeviceCount(), + phi::errors::InvalidArgument("Device id must be less than GPU count, " + "but received id is: %d. GPU count is: %d.", + id, + GetGPUDeviceCount())); PADDLE_RETRY_CUDA_SUCCESS(hipSetDevice(id)); } @@ -293,13 +293,13 @@ void GpuDeviceSync() { PADDLE_ENFORCE_GPU_SUCCESS(hipDeviceSynchronize()); } gpuError_t GpuGetLastError() { return hipGetLastError(); } bool IsGPUManagedMemorySupported(int dev_id) { - PADDLE_ENFORCE_LT(dev_id, - GetGPUDeviceCount(), - paddle::platform::errors::InvalidArgument( - "Device id must be less than GPU count, " - "but received id is: %d. GPU count is: %d.", - dev_id, - GetGPUDeviceCount())); + PADDLE_ENFORCE_LT( + dev_id, + GetGPUDeviceCount(), + phi::errors::InvalidArgument("Device id must be less than GPU count, " + "but received id is: %d. GPU count is: %d.", + dev_id, + GetGPUDeviceCount())); #if defined(__linux__) || defined(_WIN32) int ManagedMemoryAttr; PADDLE_ENFORCE_GPU_SUCCESS(hipDeviceGetAttribute( @@ -311,13 +311,13 @@ bool IsGPUManagedMemorySupported(int dev_id) { } bool IsGPUManagedMemoryOversubscriptionSupported(int dev_id) { - PADDLE_ENFORCE_LT(dev_id, - GetGPUDeviceCount(), - paddle::platform::errors::InvalidArgument( - "Device id must be less than GPU count, " - "but received id is: %d. 
GPU count is: %d.", - dev_id, - GetGPUDeviceCount())); + PADDLE_ENFORCE_LT( + dev_id, + GetGPUDeviceCount(), + phi::errors::InvalidArgument("Device id must be less than GPU count, " + "but received id is: %d. GPU count is: %d.", + dev_id, + GetGPUDeviceCount())); #ifdef __linux__ return IsGPUManagedMemorySupported(dev_id) && GetGPUComputeCapability(dev_id) >= 60; diff --git a/paddle/phi/backends/xpu/enforce_xpu.h b/paddle/phi/backends/xpu/enforce_xpu.h index bcfebf6d49fb87b7fa1a0fc29595f6f20ca57f77..29b048ead852dd91788316c2284b438d7dcbd61c 100644 --- a/paddle/phi/backends/xpu/enforce_xpu.h +++ b/paddle/phi/backends/xpu/enforce_xpu.h @@ -173,7 +173,7 @@ DEFINE_EXTERNAL_API_TYPE(BKCLResult_t, BKCL_SUCCESS); ::phi::backends::xpu::details::ExternalApiType< \ __XPU_STATUS_TYPE__>::kSuccess; \ if (UNLIKELY(__cond__ != __success_type__)) { \ - auto __summary__ = paddle::platform::errors::External( \ + auto __summary__ = phi::errors::External( \ ::phi::backends::xpu::build_xpu_error_msg(__cond__)); \ __THROW_ERROR_INTERNAL__(__summary__); \ } \ @@ -183,7 +183,7 @@ DEFINE_EXTERNAL_API_TYPE(BKCLResult_t, BKCL_SUCCESS); do { \ auto __cond__ = (COND); \ if (UNLIKELY(__cond__ != baidu::xpu::api::Error_t::SUCCESS)) { \ - auto __summary__ = paddle::platform::errors::External( \ + auto __summary__ = phi::errors::External( \ ::phi::backends::xpu::build_xpu_xdnn_error_msg(__cond__, MSG)); \ __THROW_ERROR_INTERNAL__(__summary__); \ } \ @@ -192,7 +192,7 @@ DEFINE_EXTERNAL_API_TYPE(BKCLResult_t, BKCL_SUCCESS); #define PADDLE_ENFORCE_XDNN_NOT_NULL(ptr) \ do { \ if (UNLIKELY(ptr == nullptr)) { \ - auto __summary__ = paddle::platform::errors::External( \ + auto __summary__ = phi::errors::External( \ ::phi::backends::xpu::build_xpu_xdnn_error_msg( \ baidu::xpu::api::Error_t::NO_ENOUGH_WORKSPACE, \ "XPU memory is not enough")); \ diff --git a/paddle/phi/backends/xpu/xpu_info.cc b/paddle/phi/backends/xpu/xpu_info.cc index 527e13238082ec154b3ece67ca719425ae40d211..96e95df7a9886f2bb1b5485c822a98d4f42b5f12 100644 --- a/paddle/phi/backends/xpu/xpu_info.cc +++ b/paddle/phi/backends/xpu/xpu_info.cc @@ -100,7 +100,7 @@ void SetXPUDeviceId(int id) { PADDLE_ENFORCE_LT( id, GetXPUDeviceCount(), - paddle::platform::errors::InvalidArgument("id must less than XPU count")); + phi::errors::InvalidArgument("id must less than XPU count")); PADDLE_ENFORCE_XPU_SUCCESS(xpu_set_device(id)); } diff --git a/paddle/phi/core/CMakeLists.txt b/paddle/phi/core/CMakeLists.txt index 32b9b42f74f6219bb2c234080fd3f3ce6d28dda8..80bcc66477cb10fb3672afdbe26591654e664787 100644 --- a/paddle/phi/core/CMakeLists.txt +++ b/paddle/phi/core/CMakeLists.txt @@ -13,8 +13,8 @@ cc_library(kernel_context SRCS kernel_context.cc DEPS pten_enforce pten_context) cc_library(ddim SRCS ddim.cc DEPS pten_enforce) cc_library(tensor_base SRCS tensor_base.cc allocator.cc DEPS pten_enforce) -cc_library(tensor_meta SRCS tensor_meta.cc DEPS pten_enforce mixed_vector) -cc_library(lod_utils SRCS lod_utils.cc DEPS pten_enforce mixed_vector) +cc_library(tensor_meta SRCS tensor_meta.cc DEPS pten_enforce) +cc_library(lod_utils SRCS lod_utils.cc DEPS pten_enforce) cc_library(pten_device_context SRCS device_context.cc DEPS tensor_base) cc_library(dense_tensor SRCS dense_tensor.cc dense_tensor_impl.cc DEPS fluid_convert_utils tensor_meta tensor_base) @@ -23,7 +23,7 @@ cc_library(sparse_csr_tensor SRCS sparse_csr_tensor.cc DEPS dense_tensor tensor_ cc_library(meta_tensor SRCS meta_tensor.cc DEPS tensor_base tensor_meta dense_tensor) cc_library(infermeta_utils SRCS infermeta_utils.cc 
DEPS meta_tensor) -cc_library(selected_rows SRCS selected_rows_impl.cc DEPS dense_tensor mixed_vector pten_enforce ddim) +cc_library(selected_rows SRCS selected_rows_impl.cc DEPS dense_tensor pten_enforce ddim) cc_library(pten_custom_kernel SRCS custom_kernel.cc DEPS kernel_factory convert_utils) diff --git a/paddle/phi/core/ddim.h b/paddle/phi/core/ddim.h index 1d186fe3b43fe00965db2ff32c51d43d6b7a3c11..ce462d8d954023a1ccd2ff4d33e1cf9611b40513 100644 --- a/paddle/phi/core/ddim.h +++ b/paddle/phi/core/ddim.h @@ -13,6 +13,7 @@ // limitations under the License. #pragma once #include +#include #include #include #include diff --git a/paddle/phi/core/dense_tensor.cc b/paddle/phi/core/dense_tensor.cc index a363d3cbaaa340e183dfa3281800db4a9f72b104..44cb63e2b874bd2df9b034ecf9f03053d1888c94 100644 --- a/paddle/phi/core/dense_tensor.cc +++ b/paddle/phi/core/dense_tensor.cc @@ -73,7 +73,7 @@ void* DenseTensor::AllocateFrom(Allocator* allocator, size_t requested_size) { PADDLE_ENFORCE_NOT_NULL( allocator, - paddle::platform::errors::InvalidArgument( + phi::errors::InvalidArgument( "Required allocator shall not be nullptr, but received nullptr.")); if (this->dtype() != dtype) { VLOG(10) << "change data type in mutbale_data, target dtype - " << dtype; @@ -81,13 +81,13 @@ void* DenseTensor::AllocateFrom(Allocator* allocator, } PADDLE_ENFORCE( valid(), - paddle::platform::errors::PreconditionNotMet( + phi::errors::PreconditionNotMet( "The meta data must be valid when call the mutable data function.")); size_t bytes = numel() * SizeOf(this->dtype()); if (requested_size) { PADDLE_ENFORCE_GE(requested_size, bytes, - paddle::platform::errors::InvalidArgument( + phi::errors::InvalidArgument( "The reserved size %d should be enough to meet the " "volume required by metadata %d.", requested_size, @@ -112,7 +112,7 @@ const T* DenseTensor::data() const { check_memory_size(); PADDLE_ENFORCE( (dtype() == paddle::experimental::CppTypeToDataType::Type()), - paddle::platform::errors::InvalidArgument( + phi::errors::InvalidArgument( "The type of data we are trying to retrieve does not match the " "type of data currently contained in the container.")); return static_cast(data()); @@ -123,7 +123,7 @@ T* DenseTensor::data() { check_memory_size(); PADDLE_ENFORCE( (dtype() == paddle::experimental::CppTypeToDataType::Type()), - paddle::platform::errors::InvalidArgument( + phi::errors::InvalidArgument( "The type of data we are trying to retrieve does not match the " "type of data currently contained in the container.")); return static_cast(data()); @@ -133,7 +133,7 @@ void* DenseTensor::data() { check_memory_size(); PADDLE_ENFORCE_NOT_NULL( holder_, - paddle::platform::errors::PreconditionNotMet( + phi::errors::PreconditionNotMet( "The storage must be valid when call the data function.")); return reinterpret_cast(reinterpret_cast(holder_->ptr()) + meta_.offset); @@ -143,7 +143,7 @@ const void* DenseTensor::data() const { check_memory_size(); PADDLE_ENFORCE_NOT_NULL( holder_, - paddle::platform::errors::PreconditionNotMet( + phi::errors::PreconditionNotMet( "The storage must be valid when call the data function.")); return reinterpret_cast( reinterpret_cast(holder_->ptr()) + meta_.offset); @@ -151,7 +151,7 @@ const void* DenseTensor::data() const { void DenseTensor::set_meta(DenseTensorMeta&& meta) { PADDLE_ENFORCE(!meta_.valid(), - paddle::platform::errors::InvalidArgument( + phi::errors::InvalidArgument( "Only when the original attribute of Tensor is " "incomplete, can it be reset.")); meta_ = std::move(meta); @@ -160,7 +160,7 
@@ void DenseTensor::set_meta(DenseTensorMeta&& meta) { void DenseTensor::set_meta(const DenseTensorMeta& meta) { PADDLE_ENFORCE( meta.valid(), - paddle::platform::errors::InvalidArgument( + phi::errors::InvalidArgument( "Input meta is invalid, please check the meta attribute.")); meta_.dims = meta.dims; meta_.dtype = meta.dtype; diff --git a/paddle/phi/core/dense_tensor.inl b/paddle/phi/core/dense_tensor.inl index 0547776acad1f3e08752f8ee14d7acf235bdfab4..a422a95346e8b65e91a7404d70c213847e1dcf3e 100644 --- a/paddle/phi/core/dense_tensor.inl +++ b/paddle/phi/core/dense_tensor.inl @@ -54,22 +54,22 @@ DenseTensor(intrusive_ptr storage, DenseTensorMeta&& meta); inline bool IsInitialized() const { return holder_ != nullptr; } template -T* mutable_data(const paddle::platform::Place& place, +T* mutable_data(const phi::Place& place, size_t requested_size = 0); template T* mutable_data(const DDim& dims, - const paddle::platform::Place& place, + const phi::Place& place, size_t requested_size = 0); -void* mutable_data(const paddle::platform::Place& place, +void* mutable_data(const phi::Place& place, paddle::experimental::DataType type, size_t requested_size = 0); -void* mutable_data(const paddle::platform::Place& place, +void* mutable_data(const phi::Place& place, size_t requested_size = 0); -void* mutable_data(const paddle::platform::Place& place, +void* mutable_data(const phi::Place& place, paddle::experimental::DataType type, const phi::Stream& stream); diff --git a/paddle/phi/core/infermeta_utils.h b/paddle/phi/core/infermeta_utils.h index 1b8cfea130d4900b331f24526332b80903f55e19..7cf92e4d933b3674b128bf233350f9530e39d9a9 100644 --- a/paddle/phi/core/infermeta_utils.h +++ b/paddle/phi/core/infermeta_utils.h @@ -25,6 +25,7 @@ limitations under the License. */ #include "paddle/phi/core/macros.h" #include "paddle/phi/core/meta_tensor.h" #include "paddle/phi/core/type_defs.h" +#include "paddle/utils/any.h" #include "paddle/utils/flat_hash_map.h" #include "paddle/utils/small_vector.h" diff --git a/paddle/phi/core/kernel_context.cc b/paddle/phi/core/kernel_context.cc index 3c7222f7a5379fe1f9d6c87ffdb38d6e6a8fa48c..a32e0e44f469694c62ff33863971d3b04004ff37 100644 --- a/paddle/phi/core/kernel_context.cc +++ b/paddle/phi/core/kernel_context.cc @@ -69,7 +69,7 @@ void KernelContext::AssignInputRange(std::pair&& range, size_t idx) { } else if (idx == input_range_.size()) { input_range_.emplace_back(range); } else { - PADDLE_THROW(paddle::platform::errors::PreconditionNotMet( + PADDLE_THROW(phi::errors::PreconditionNotMet( "Invalid idx when trying to set InputRange, " "index is `%d`, it is greater than the size(%d) of InputRange.", idx, @@ -83,7 +83,7 @@ void KernelContext::AssignOutputRange(std::pair&& range, size_t idx) { } else if (idx == output_range_.size()) { output_range_.emplace_back(range); } else { - PADDLE_THROW(paddle::platform::errors::PreconditionNotMet( + PADDLE_THROW(phi::errors::PreconditionNotMet( "Invalid idx when trying to set InputRange, " "index is `%d`, it is greater than the size(%d) of InputRange.", idx, diff --git a/paddle/phi/core/lod_utils.h b/paddle/phi/core/lod_utils.h index a5f73b66fb99b6e50a08bb80698e43efba3ce8fc..147fca4cb576ce1625df83cca95d3701e082e6f6 100644 --- a/paddle/phi/core/lod_utils.h +++ b/paddle/phi/core/lod_utils.h @@ -13,18 +13,11 @@ // limitations under the License. #pragma once - -// See Note [ Why still include the fluid headers? 
] -#ifndef PADDLE_WITH_CUSTOM_KERNEL -#include "paddle/fluid/framework/mixed_vector.h" -#endif +#include +#include namespace phi { -#ifndef PADDLE_WITH_CUSTOM_KERNEL -using LoD = std::vector>; -#else -using LoD = std::vector>; -#endif +using LoD = std::vector>; void AppendLoD(LoD* lod, const LoD& lod_length); @@ -40,4 +33,4 @@ void AppendLoD(LoD* lod, const LoD& lod_length); */ LoD ConvertToLengthBasedLoD(const LoD& offset_lod); -} // namespace pten +} // namespace phi diff --git a/paddle/phi/core/selected_rows.h b/paddle/phi/core/selected_rows.h index cd48777b8ea61d58991923ea5919d7555d0a219b..7ee475b4d5d9e03d0931587f2a607f5f4950a426 100644 --- a/paddle/phi/core/selected_rows.h +++ b/paddle/phi/core/selected_rows.h @@ -55,25 +55,17 @@ class SelectedRows : public TensorBase, void set_height(int64_t height) { impl_->set_height(height); } - const paddle::framework::Vector& rows() const { - return impl_->rows(); - } + const std::vector& rows() const { return impl_->rows(); } - paddle::framework::Vector* mutable_rows() { - return impl_->mutable_rows(); - } - - void set_rows(const paddle::framework::Vector& rows) { - impl_->set_rows(rows); - } + std::vector* mutable_rows() { return impl_->mutable_rows(); } + void set_rows(const std::vector& rows) { impl_->set_rows(rows); } /* * @brief Get the index of key in rows * * @return -1 if the key does not exists. */ int64_t Index(int64_t key) const { return impl_->Index(key); } - /* * @brief whether has the specified key in the table. * diff --git a/paddle/phi/core/selected_rows_impl.cc b/paddle/phi/core/selected_rows_impl.cc index 920e9935d5899de82eb2cdd81616f8466916d7e3..7e5fd51343a09aa4ae974ad30f3265169489862c 100644 --- a/paddle/phi/core/selected_rows_impl.cc +++ b/paddle/phi/core/selected_rows_impl.cc @@ -28,7 +28,7 @@ struct ReAllocateVisitor { template void operator()() const { phi::DenseTensor cpu_tensor; - paddle::platform::CPUPlace cpu; + phi::CPUPlace cpu; T* ptr = cpu_tensor.mutable_data(dims_, cpu); const T* old_ptr = tensor_->memory_size() == 0 ? 
nullptr : tensor_->data(); @@ -57,7 +57,7 @@ struct TensorCopyVisitor { template void apply() const { // TODO(Yancey1989): support other place - paddle::platform::CPUPlace cpu; + phi::CPUPlace cpu; paddle::memory::Copy(cpu, dst_->mutable_data(cpu) + dst_offset_, cpu, @@ -82,7 +82,7 @@ struct TensorFillVisitor { template void apply() const { // TODO(qiao): support other place - paddle::platform::CPUPlace cpu; + phi::CPUPlace cpu; auto* tensor_data = dst_->mutable_data(cpu); auto* start = tensor_data + dst_offset_; auto* end = start + size_; @@ -121,16 +121,16 @@ int64_t SelectedRowsImpl::AutoGrownIndex(int64_t key, auto iter = id_to_index_.find(key); if (iter == id_to_index_.end()) { rwlock_->UNLock(); - PADDLE_ENFORCE_EQ(auto_grown, - true, - paddle::platform::errors::NotFound( - "Input key(%lld) is not found.", key)); + PADDLE_ENFORCE_EQ( + auto_grown, + true, + phi::errors::NotFound("Input key(%lld) is not found.", key)); rwlock_->WRLock(); auto map_size = id_to_index_.size(); auto vector_size = rows_.size(); if (map_size != vector_size) { rwlock_->UNLock(); - PADDLE_THROW(paddle::platform::errors::InvalidArgument( + PADDLE_THROW(phi::errors::InvalidArgument( "Row map size(%zu) should be equal to rows size(%zu).", map_size, vector_size)); @@ -140,7 +140,7 @@ int64_t SelectedRowsImpl::AutoGrownIndex(int64_t key, int row_num = rows_.size(); if (row_num == value_->dims()[0]) { rwlock_->UNLock(); - PADDLE_THROW(paddle::platform::errors::InvalidArgument( + PADDLE_THROW(phi::errors::InvalidArgument( "Selected rows is full, then length exceed the length of first " "dimension (%d).", row_num)); @@ -187,7 +187,7 @@ void SelectedRowsImpl::Get(const phi::DenseTensor& ids, PADDLE_ENFORCE_EQ( value_width, value->numel() / value->dims()[0], - paddle::platform::errors::InvalidArgument( + phi::errors::InvalidArgument( "Output tensor should have the same shape with table " "except the first dimmension, excepted value width not counting " "the first dimension is %d, actual value width is %d.", diff --git a/paddle/phi/core/selected_rows_impl.h b/paddle/phi/core/selected_rows_impl.h index 86579e529371ad1289e8c792725b642b3a8e117c..3c54b59a159ddfdac25ad64f083cde97cfdd39f6 100644 --- a/paddle/phi/core/selected_rows_impl.h +++ b/paddle/phi/core/selected_rows_impl.h @@ -27,8 +27,6 @@ limitations under the License. */ #include "paddle/phi/core/enforce.h" #include "paddle/phi/core/utils/rw_lock.h" -// See Note [ Why still include the fluid headers? ] -#include "paddle/fluid/framework/mixed_vector.h" namespace phi { class SelectedRowsImpl { /* @@ -68,13 +66,11 @@ class SelectedRowsImpl { void set_height(int64_t height) { height_ = height; } - const paddle::framework::Vector& rows() const { return rows_; } + const std::vector& rows() const { return rows_; } - paddle::framework::Vector* mutable_rows() { return &rows_; } + std::vector* mutable_rows() { return &rows_; } - void set_rows(const paddle::framework::Vector& rows) { - rows_ = rows; - } + void set_rows(const std::vector& rows) { rows_ = rows; } /* * @brief Get the index of key in rows @@ -84,7 +80,7 @@ class SelectedRowsImpl { int64_t Index(int64_t key) const { auto it = std::find(rows_.begin(), rows_.end(), key); if (it == rows_.end()) { - PADDLE_THROW(paddle::platform::errors::NotFound( + PADDLE_THROW(phi::errors::NotFound( "Input id (%lld) is not in current rows table.", key)); } return static_cast(std::distance(rows_.begin(), it)); @@ -156,10 +152,7 @@ class SelectedRowsImpl { /// \brief Returns the dims of the tensor. /// \return The dims of the tensor. 
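// [Sketch] With rows_ now a plain std::vector<int64_t> (see the SelectedRows
// changes above), callers that need the row indices on the GPU are expected
// to wrap them on demand, mirroring the MixVector usage in the sequence_ops
// kernels earlier in this patch. Illustrative only; dev_ctx and the kernel
// launch are assumed context:
//
//   std::vector<int64_t> rows = selected_rows.rows();  // host-side copy
//   paddle::framework::MixVector<int64_t> mixv_rows(&rows);
//   const int64_t* dev_rows = mixv_rows.CUDAData(dev_ctx.GetPlace());
//   // ... launch a kernel that reads dev_rows ...
//   // For in-place writes, use CUDAMutableData() and call CopyToCPU() after.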
- const DDim& dims() const noexcept { - return value_->dims(); - // return phi::make_ddim(dims); - } + const DDim& dims() const noexcept { return value_->dims(); } /// \brief Returns the data type of the tensor. /// \return The data type of the tensor. @@ -185,7 +178,7 @@ class SelectedRowsImpl { // Notice: rows can be duplicate. We can have {0, 4, 7, 0, 5, 7, 9} here. // SelectedRowsImpl are simply concated when adding together. Until a // SelectedRowsImpl add a Tensor, will the duplicate rows be handled. - paddle::framework::Vector rows_; + std::vector rows_; std::unordered_map id_to_index_; // should not be used when rows_ has duplicate member std::unique_ptr value_{nullptr}; diff --git a/paddle/phi/core/sparse_coo_tensor.cc b/paddle/phi/core/sparse_coo_tensor.cc index 1659f09248be02a74243a2de071606a9a8d5667c..f2987e36d3db0163c275562562bf5d6bf7aa91af 100644 --- a/paddle/phi/core/sparse_coo_tensor.cc +++ b/paddle/phi/core/sparse_coo_tensor.cc @@ -69,17 +69,17 @@ void SparseCooTensor::Resize(const DDim& dense_dims, const int64_t non_zero_num) { PADDLE_ENFORCE_GE(non_zero_num, this->nnz(), - paddle::platform::errors::InvalidArgument( + phi::errors::InvalidArgument( "the non_zero_num must be greater than or equal to the " "origin non_zero_num.")); PADDLE_ENFORCE_GE(sparse_dim, 1, - paddle::platform::errors::InvalidArgument( + phi::errors::InvalidArgument( "the sparse_dim must be greater than or equal 1.")); PADDLE_ENFORCE_LE( sparse_dim, dense_dims.size(), - paddle::platform::errors::InvalidArgument( + phi::errors::InvalidArgument( "the sparse_dim must be less than or equal dense_dims.")); DDim indices_dims = phi::make_ddim({sparse_dim, non_zero_num}); diff --git a/paddle/phi/core/sparse_csr_tensor.cc b/paddle/phi/core/sparse_csr_tensor.cc index 7f7cd76378cc4932063ecd105147f0bc1a9d07b7..cbf5f941b665d8ae2be58472069d2e04891afe29 100644 --- a/paddle/phi/core/sparse_csr_tensor.cc +++ b/paddle/phi/core/sparse_csr_tensor.cc @@ -20,7 +20,7 @@ inline void check_shape(const DDim& dims) { bool valid = dims.size() == 2 || dims.size() == 3; PADDLE_ENFORCE(valid, - paddle::platform::errors::InvalidArgument( + phi::errors::InvalidArgument( "the SparseCsrTensor only support 2-D Tensor.")); } #define Check(non_zero_crows, non_zero_cols, non_zero_elements, dims) \ @@ -29,12 +29,12 @@ inline void check_shape(const DDim& dims) { PADDLE_ENFORCE_EQ( \ non_zero_cols.place(), \ non_zero_crows.place(), \ - paddle::platform::errors::InvalidArgument( \ + phi::errors::InvalidArgument( \ "non_zero_crows and non_zero_cols must have the same place.")); \ PADDLE_ENFORCE_EQ( \ non_zero_cols.place(), \ non_zero_elements.place(), \ - paddle::platform::errors::InvalidArgument( \ + phi::errors::InvalidArgument( \ "non_zero_cols and non_zero_elements must have the same place.")); \ } @@ -77,7 +77,7 @@ void* SparseCsrTensor::AllocateFrom(Allocator* allocator, void SparseCsrTensor::Resize(const DDim& dense_dims, const int64_t non_zero_num) { PADDLE_ENFORCE(this->initialized(), - paddle::platform::errors::InvalidArgument( + phi::errors::InvalidArgument( "the SparseCsrTensor must be initialized when call Resize " "function.")); check_shape(dense_dims); diff --git a/paddle/phi/core/tensor_meta.h b/paddle/phi/core/tensor_meta.h index ede9b43b1f382d1db0aeaee0fa8969e05891b888..3d2da542c74176017492bdb9f567396f81308d6a 100644 --- a/paddle/phi/core/tensor_meta.h +++ b/paddle/phi/core/tensor_meta.h @@ -20,6 +20,8 @@ limitations under the License. 
*/ #include "paddle/phi/common/data_type.h" #include "paddle/phi/common/layout.h" #include "paddle/phi/core/ddim.h" +#include "paddle/utils/any.h" +#include "paddle/utils/optional.h" // Note: mixed_vector include many header now, LoD will be // used on CUDA device? Can we use small_vector here? @@ -31,11 +33,7 @@ limitations under the License. */ namespace phi { using DDim = phi::DDim; -#ifndef PADDLE_WITH_CUSTOM_KERNEL -using LoD = std::vector>; -#else using LoD = std::vector>; -#endif /// \brief The meta data of dense tensor. Take the structure type /// and use all default operations. /// diff --git a/paddle/phi/infermeta/binary.cc b/paddle/phi/infermeta/binary.cc index a964788b15e3122aeb2af857a25913543aad1c82..7455f1e6a0896fa25a3b02a03da3f3223f1d087b 100644 --- a/paddle/phi/infermeta/binary.cc +++ b/paddle/phi/infermeta/binary.cc @@ -23,7 +23,7 @@ void DotInferMeta(const MetaTensor& x, const MetaTensor& y, MetaTensor* out) { auto x_rank = static_cast(x_dims.size()); PADDLE_ENFORCE_EQ(true, 1 == x_rank || 2 == x_rank, - paddle::platform::errors::PreconditionNotMet( + phi::errors::PreconditionNotMet( "ShapeError: The dimensions of input tensor X (%s) " "should be 1 or 2", x_dims.to_str())); @@ -32,7 +32,7 @@ void DotInferMeta(const MetaTensor& x, const MetaTensor& y, MetaTensor* out) { PADDLE_ENFORCE_EQ( true, x_rank == static_cast(y_dims.size()), - paddle::platform::errors::PreconditionNotMet( + phi::errors::PreconditionNotMet( "ShapeError: The shape of input tensor Y: %s should match with " "input tenosr X: %s", y_dims.to_str(), @@ -47,7 +47,7 @@ void DotInferMeta(const MetaTensor& x, const MetaTensor& y, MetaTensor* out) { PADDLE_ENFORCE_EQ(true, shape_match, - paddle::platform::errors::PreconditionNotMet( + phi::errors::PreconditionNotMet( "ShapeError: The shape of input tensor X: %s should " "be exactly the same " "with input tensor Y: %s", @@ -71,12 +71,12 @@ void MatmulInferMeta(const MetaTensor& x, auto ndims_y = dims_y.size(); PADDLE_ENFORCE_GT(ndims_x, 0UL, - paddle::platform::errors::InvalidArgument( + phi::errors::InvalidArgument( "The Input(x) dims size must be greater than 0," " but reviced dims size is 0. ")); PADDLE_ENFORCE_GT(ndims_y, 0UL, - paddle::platform::errors::InvalidArgument( + phi::errors::InvalidArgument( "The Input(y) dims size must be greater than 0," " but reviced dims size is 0. ")); @@ -150,7 +150,7 @@ void ElementwiseRawInferMeta(const MetaTensor& x, if (x_dims.size() == y_dims.size()) { PADDLE_ENFORCE_EQ((axis == -1) || (axis == 0), true, - paddle::platform::errors::InvalidArgument( + phi::errors::InvalidArgument( "axis should be -1 or 0 while the dimension of " "tensor X (%s) is equal to the dimension of " "tensor Y (%s), but received axis: %s", @@ -160,7 +160,7 @@ void ElementwiseRawInferMeta(const MetaTensor& x, } PADDLE_ENFORCE_EQ((axis >= (-1 * max_dim)) && (axis < max_dim), true, - paddle::platform::errors::InvalidArgument( + phi::errors::InvalidArgument( "The axis range must be [%s, %s), but axis is %s. 
" "Please set the axis again.", -1 * max_dim, diff --git a/paddle/phi/infermeta/multiary.cc b/paddle/phi/infermeta/multiary.cc index 5e7dd1de69d7d0f3de5ef7e67dc8d1f48373abdb..d72033f95285738f20c75b5d2a678fe4811e8a18 100644 --- a/paddle/phi/infermeta/multiary.cc +++ b/paddle/phi/infermeta/multiary.cc @@ -24,7 +24,7 @@ void ConcatInferMeta(const std::vector& x, MetaConfig config) { PADDLE_ENFORCE_GE(x.size(), 0UL, - paddle::platform::errors::InvalidArgument( + phi::errors::InvalidArgument( "The size of input meta vector should be greater" "than 0.")); @@ -34,7 +34,7 @@ void ConcatInferMeta(const std::vector& x, PADDLE_ENFORCE_EQ( axis >= -rank && axis < rank, true, - paddle::platform::errors::InvalidArgument( + phi::errors::InvalidArgument( "The axis is expected to be in range of [%d, %d), but got %d", -rank, rank, diff --git a/paddle/phi/infermeta/unary.cc b/paddle/phi/infermeta/unary.cc index fda395e6d95ec381694610e9028bf48359e4d94c..1fbd6c2b6c2f5f5b3a86917c9ff35031da9b6b93 100644 --- a/paddle/phi/infermeta/unary.cc +++ b/paddle/phi/infermeta/unary.cc @@ -38,11 +38,11 @@ void FlattenInferMeta(const MetaTensor& x, if (stop_axis < 0) { stop_axis = stop_axis + in_dims_size; } - PADDLE_ENFORCE_GE(stop_axis, - start_axis, - paddle::platform::errors::InvalidArgument( - "The stop_axis should be greater" - "than or equal to start_axis.")); + PADDLE_ENFORCE_GE( + stop_axis, + start_axis, + phi::errors::InvalidArgument("The stop_axis should be greater" + "than or equal to start_axis.")); int64_t outer = 1; std::vector out_shape; @@ -113,7 +113,7 @@ static phi::DDim ValidateShape(const std::vector shape, PADDLE_ENFORCE_EQ( unk_dim_idx, -1, - paddle::platform::errors::InvalidArgument( + phi::errors::InvalidArgument( "Only one dimension value of 'shape' in ReshapeOp can " "be -1. But received shape = [%s], shape[%d] is also -1.", phi::make_ddim(shape), @@ -123,7 +123,7 @@ static phi::DDim ValidateShape(const std::vector shape, PADDLE_ENFORCE_LT( static_cast(i), in_dims.size(), - paddle::platform::errors::InvalidArgument( + phi::errors::InvalidArgument( "The index of 0 in `shape` must be less than " "the input tensor X's dimensions. " "But received shape = [%s], shape[%d] = 0, X's shape = [%s], " @@ -136,7 +136,7 @@ static phi::DDim ValidateShape(const std::vector shape, PADDLE_ENFORCE_GT( shape[i], 0, - paddle::platform::errors::InvalidArgument( + phi::errors::InvalidArgument( "Each dimension value of 'shape' in ReshapeOp must not " "be negative except one unknown dimension. " "But received shape = [%s], shape[%d] = %d.", @@ -161,7 +161,7 @@ static phi::DDim ValidateShape(const std::vector shape, PADDLE_ENFORCE_EQ( output_shape[unk_dim_idx] * capacity, -in_size, - paddle::platform::errors::InvalidArgument( + phi::errors::InvalidArgument( "The 'shape' attribute in ReshapeOp is invalid. " "The input tensor X'size must be divisible by known " "capacity of 'shape'. " @@ -179,7 +179,7 @@ static phi::DDim ValidateShape(const std::vector shape, PADDLE_ENFORCE_EQ( capacity, in_size, - paddle::platform::errors::InvalidArgument( + phi::errors::InvalidArgument( "The 'shape' in ReshapeOp is invalid. " "The input tensor X'size must be equal to the capacity of " "'shape'. " @@ -199,7 +199,7 @@ static phi::DDim ValidateShape(const std::vector shape, PADDLE_ENFORCE_LE( capacity, in_size, - paddle::platform::errors::InvalidArgument( + phi::errors::InvalidArgument( "The 'shape' in ReshapeOp is invalid. " "The input tensor X's shape = [%s], X's capacity = %d." 
"But the target shape of Out is [%s], the " @@ -364,7 +364,7 @@ void SplitInferMeta(const MetaTensor& x, PADDLE_ENFORCE_EQ( axis_value >= -rank && axis_value < rank, true, - paddle::platform::errors::InvalidArgument( + phi::errors::InvalidArgument( "The axis is expected to be in range of [%d, %d), but got %d", -rank, rank, @@ -383,7 +383,7 @@ void SplitInferMeta(const MetaTensor& x, PADDLE_ENFORCE_EQ(input_axis_dim % num, 0, - paddle::platform::errors::InvalidArgument( + phi::errors::InvalidArgument( "The input's size along the split dimension " "must be evenly divisible by Attr(num_or_sections). " "But received Attr(num_or_sections) " @@ -416,7 +416,7 @@ void SplitInferMeta(const MetaTensor& x, if (config.is_runtime) { PADDLE_ENFORCE_LE(num_of_unknow, 1, - paddle::platform::errors::InvalidArgument( + phi::errors::InvalidArgument( "Only one dimension value of Attr(num_or_sections) " "in SplitOp can be -1. " "But received Attr(num_or_sections) = [%s].", @@ -430,7 +430,7 @@ void SplitInferMeta(const MetaTensor& x, PADDLE_ENFORCE_LT( sum_of_section, input_axis_dim, - paddle::platform::errors::InvalidArgument( + phi::errors::InvalidArgument( "Sum of Attr(num_or_sections) other than unknown section " "must be less than the input's " "size " @@ -447,7 +447,7 @@ void SplitInferMeta(const MetaTensor& x, PADDLE_ENFORCE_EQ( sum_of_section, input_axis_dim, - paddle::platform::errors::InvalidArgument( + phi::errors::InvalidArgument( "Sum of Attr(num_or_sections) must be equal to the input's " "size " "along the split dimension. But received Attr(num_or_sections)" diff --git a/paddle/phi/kernels/cpu/concat_kernel.cc b/paddle/phi/kernels/cpu/concat_kernel.cc index 0cae2599f8d13fe807baa71be2692c85201fc5a8..3b74951a5041cd303c85c6a57766f5a06412f71b 100644 --- a/paddle/phi/kernels/cpu/concat_kernel.cc +++ b/paddle/phi/kernels/cpu/concat_kernel.cc @@ -54,7 +54,7 @@ void ConcatKernel(const Context& dev_ctx, PADDLE_ENFORCE_EQ( x[i].lod().size(), lod_size_0, - paddle::platform::errors::Unimplemented( + phi::errors::Unimplemented( "The lod level of all input LoDTensors should be same. " "Maybe different lod level of input LoDTensors can concat," "it is not supported currently. The lod level of %dth input " diff --git a/paddle/phi/kernels/cpu/elementwise.h b/paddle/phi/kernels/cpu/elementwise.h index c692038d24a0a885d21b9c632709b143681a438d..28bf5ab743f6d5d0608fe65c00d5a0de2af3415b 100644 --- a/paddle/phi/kernels/cpu/elementwise.h +++ b/paddle/phi/kernels/cpu/elementwise.h @@ -127,7 +127,7 @@ struct SameDimsDivideFunctor< const DenseTensor& x, const DenseTensor& y, DenseTensor* z) { - paddle::platform::errors::InvalidArgument( + phi::errors::InvalidArgument( "If use SameDimsDivideFunctor, template args(T) must be floating " "point. 
"); } @@ -278,12 +278,10 @@ void CommonForwardBroadcastCPU(const DenseTensor& x, std::vector index_array(max_dim, 0); const T* x_data = x.data(); const T* y_data = y.data(); - PADDLE_ENFORCE_NOT_NULL(x_data, - paddle::platform::errors::InvalidArgument( - "The input X should not be empty.")); - PADDLE_ENFORCE_NOT_NULL(y_data, - paddle::platform::errors::InvalidArgument( - "The input Y should not be empty.")); + PADDLE_ENFORCE_NOT_NULL( + x_data, phi::errors::InvalidArgument("The input X should not be empty.")); + PADDLE_ENFORCE_NOT_NULL( + y_data, phi::errors::InvalidArgument("The input Y should not be empty.")); OutType* out_data = ctx.Alloc(z); const int out_size = std::accumulate( @@ -317,12 +315,12 @@ void CommonElementwiseBroadcastForward(const CPUContext& dev_ctx, PADDLE_ENFORCE_GE( axis, 0, - paddle::platform::errors::InvalidArgument( + phi::errors::InvalidArgument( "Axis should be great than or equal to 0, but received axis is %d.", axis)); PADDLE_ENFORCE_LT(axis, max_dim, - paddle::platform::errors::InvalidArgument( + phi::errors::InvalidArgument( "Axis should be less than %d, but received axis is %d.", max_dim, axis)); @@ -385,12 +383,12 @@ void ElementwiseCompute(const CPUContext& dev_ctx, PADDLE_ENFORCE_GE( axis, 0, - paddle::platform::errors::InvalidArgument( + phi::errors::InvalidArgument( "Axis should be great than or equal to 0, but received axis is %d.", axis)); PADDLE_ENFORCE_LT(axis, max_dim, - paddle::platform::errors::InvalidArgument( + phi::errors::InvalidArgument( "Axis should be less than %d, but received axis is %d.", max_dim, axis)); @@ -630,12 +628,12 @@ void ElemwiseGradComputeWithBroadcast(const CPUContext& ctx, PADDLE_ENFORCE_GE( axis, 0, - paddle::platform::errors::InvalidArgument( + phi::errors::InvalidArgument( "Axis should be great than or equal to 0, but received axis is %d.", axis)); PADDLE_ENFORCE_LT(axis, max_dim, - paddle::platform::errors::InvalidArgument( + phi::errors::InvalidArgument( "Axis should be less than %d, but received axis is %d.", max_dim, axis)); diff --git a/paddle/phi/kernels/cpu/masked_select_kernel.cc b/paddle/phi/kernels/cpu/masked_select_kernel.cc index 274863a863b799a397840ceec314219fbbf70a39..f377658d507f6086101e1cdb0f0ab1891536e771 100644 --- a/paddle/phi/kernels/cpu/masked_select_kernel.cc +++ b/paddle/phi/kernels/cpu/masked_select_kernel.cc @@ -48,7 +48,7 @@ void MaskedSelectKernel(const Context& dev_ctx, DDim out_dim{out_size}; out->Resize(out_dim); - auto out_data = out->mutable_data(paddle::platform::CPUPlace()); + auto out_data = out->mutable_data(phi::CPUPlace()); int index = 0; for (int i = 0; i < mask_size; i++) { diff --git a/paddle/phi/kernels/funcs/common_shape.h b/paddle/phi/kernels/funcs/common_shape.h index e14241d03c3af09bd1d0201da0f53ffadd2b2c4a..8bd9867f39edd297396d392bd5d286ad5d10056f 100644 --- a/paddle/phi/kernels/funcs/common_shape.h +++ b/paddle/phi/kernels/funcs/common_shape.h @@ -42,12 +42,12 @@ inline void GetBroadcastDimsArrays(const DDim &x_dims, PADDLE_ENFORCE_GE( axis, 0, - paddle::platform::errors::InvalidArgument( + phi::errors::InvalidArgument( "Axis should be great than or equal to 0, but received axis is %d.", axis)); PADDLE_ENFORCE_LT(axis, max_dim, - paddle::platform::errors::InvalidArgument( + phi::errors::InvalidArgument( "Axis should be less than %d, but received axis is %d.", max_dim, axis)); @@ -72,7 +72,7 @@ inline void GetBroadcastDimsArrays(const DDim &x_dims, x_dims_array[i] == y_dims_array[i] || x_dims_array[i] <= 1 || y_dims_array[i] <= 1, true, - 
paddle::platform::errors::InvalidArgument( + phi::errors::InvalidArgument( "Broadcast dimension mismatch. Operands could " "not be broadcast together with the shape of X = [%s] and " "the shape of Y = [%s]. Received [%d] in X is not equal to " diff --git a/paddle/phi/kernels/funcs/concat_funcs.h b/paddle/phi/kernels/funcs/concat_funcs.h index 63f0c8058acc16f1665bda7d6a2b91cdc24ef2b0..32237e2cc236657db5a99fdd64392da4ff900562 100644 --- a/paddle/phi/kernels/funcs/concat_funcs.h +++ b/paddle/phi/kernels/funcs/concat_funcs.h @@ -23,7 +23,7 @@ static inline int64_t ComputeAxis(int64_t axis, int64_t rank) { PADDLE_ENFORCE_EQ( axis >= -rank && axis < rank, true, - paddle::platform::errors::InvalidArgument( + phi::errors::InvalidArgument( "The axis is expected to be in range of [%d, %d), but got %d", -rank, rank, @@ -42,17 +42,17 @@ static inline phi::DDim ComputeAndCheckShape( auto out_dims = inputs_dims[0]; size_t in_zero_dims_size = out_dims.size(); for (size_t i = 1; i < n; i++) { - PADDLE_ENFORCE_EQ(inputs_dims[i].size(), - out_dims.size(), - paddle::platform::errors::InvalidArgument( - "The shape of input[0] and input[%d] " - "is expected to be equal." - "But received input[0]'s shape = " - "[%s], input[%d]'s shape = [%s].", - i, - inputs_dims[0], - i, - inputs_dims[i])); + PADDLE_ENFORCE_EQ( + inputs_dims[i].size(), + out_dims.size(), + phi::errors::InvalidArgument("The shape of input[0] and input[%d] " + "is expected to be equal." + "But received input[0]'s shape = " + "[%s], input[%d]'s shape = [%s].", + i, + inputs_dims[0], + i, + inputs_dims[i])); for (size_t j = 0; j < in_zero_dims_size; j++) { if (j == axis) { if (is_runtime) { @@ -71,7 +71,7 @@ static inline phi::DDim ComputeAndCheckShape( // check all shape in run time PADDLE_ENFORCE_EQ(inputs_dims[0][j], inputs_dims[i][j], - paddle::platform::errors::InvalidArgument( + phi::errors::InvalidArgument( "The %d-th dimension of input[0] and input[%d] " "is expected to be equal." "But received input[0]'s shape = " @@ -92,4 +92,4 @@ static inline phi::DDim ComputeAndCheckShape( } } // namespace funcs -} // namespace pten +} // namespace phi diff --git a/paddle/phi/kernels/funcs/eigen/common.h b/paddle/phi/kernels/funcs/eigen/common.h index dc64d3b122f1014ddfed081269859d46c26f43ad..d34427df0e499b78fccdfe80660277152560e34d 100644 --- a/paddle/phi/kernels/funcs/eigen/common.h +++ b/paddle/phi/kernels/funcs/eigen/common.h @@ -21,7 +21,7 @@ limitations under the License. */ namespace phi { -// EigenDim converts paddle::platform::DDim into Eigen::DSizes. +// EigenDim converts phi::DDim into Eigen::DSizes. template struct EigenDim { using Type = Eigen::DSizes; @@ -29,7 +29,7 @@ struct EigenDim { static Type From(const DDim& dims) { PADDLE_ENFORCE_EQ(arity(dims), D, - paddle::platform::errors::InvalidArgument( + phi::errors::InvalidArgument( "Input dimension size should be equal to %d, but " "received dimension size is %d.", arity(dims), @@ -42,7 +42,7 @@ struct EigenDim { } }; -// Interpret paddle::platform::Tensor as EigenTensor and EigenConstTensor. +// Interpret phi::Tensor as EigenTensor and EigenConstTensor. 
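// [Sketch] EigenDim<D>::From above maps a runtime phi::DDim onto a
// compile-time Eigen::DSizes, raising phi::errors::InvalidArgument when the
// ranks disagree. Hypothetical usage:
//
//   phi::DDim dims = phi::make_ddim({2, 3});
//   auto dsizes = phi::EigenDim<2>::From(dims);  // DSizes holding {2, 3}
//   // phi::EigenDim<3>::From(dims) would fail the rank check (2 != 3)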
template { int rank = tensor.dims().size(); PADDLE_ENFORCE_EQ((num_col_dims > 0 && num_col_dims < rank), true, - paddle::platform::errors::InvalidArgument( + phi::errors::InvalidArgument( "Input dimension number(num_col_dims) must be " "between 0 and %d, but received number is %d.", rank, @@ -100,7 +100,7 @@ struct EigenMatrix : public EigenTensor { int rank = tensor.dims().size(); PADDLE_ENFORCE_EQ((num_col_dims > 0 && num_col_dims < rank), true, - paddle::platform::errors::InvalidArgument( + phi::errors::InvalidArgument( "Input dimension number(num_col_dims) must be " "between 0 and %d, but received number is %d.", rank, diff --git a/paddle/phi/kernels/funcs/elementwise_base.h b/paddle/phi/kernels/funcs/elementwise_base.h index 9fb2dac6c425f6224da713fb6ada636355b42c26..9a429dfaaf957785ab0108fe19ff63244659df11 100644 --- a/paddle/phi/kernels/funcs/elementwise_base.h +++ b/paddle/phi/kernels/funcs/elementwise_base.h @@ -343,7 +343,7 @@ inline void get_mid_dims(const DDim &x_dims, if (x_dims[i + axis] != y_dims[i]) { PADDLE_ENFORCE_EQ(y_dims[i] == 1 || x_dims[i + axis] == 1, true, - paddle::platform::errors::InvalidArgument( + phi::errors::InvalidArgument( "Broadcast dimension mismatch. Operands " "could not be broadcast together with the shape of " "X = [%s] and the shape of Y = [%s]. Received [%d] " @@ -754,7 +754,7 @@ void ElementwiseKernel(const KPDevice &ctx, const int kArity = Traits::arity; PADDLE_ENFORCE_EQ(ins.size(), kArity, - paddle::platform::errors::InvalidArgument( + phi::errors::InvalidArgument( "The number of inputs is expected to be equal to the " "arity of functor. But recieved: the number of inputs " "is %d, the arity of functor is %d.", @@ -762,7 +762,7 @@ void ElementwiseKernel(const KPDevice &ctx, kArity)); PADDLE_ENFORCE_EQ(outs->size(), NumOuts, - paddle::platform::errors::InvalidArgument( + phi::errors::InvalidArgument( "Number of outputs shall equal to number of functions, " "but number of outputs is %d, of functions is %d.", outs->size(), @@ -773,7 +773,7 @@ void ElementwiseKernel(const KPDevice &ctx, PADDLE_ENFORCE_EQ( (*outs)[i]->dims(), (*outs)[0]->dims(), - paddle::platform::errors::InvalidArgument( + phi::errors::InvalidArgument( "The shape of each output tensor shall be identical yet, " "but %dth output tensor`s shape is not.", i)); @@ -796,7 +796,7 @@ void ElementwiseKernel(const KPDevice &ctx, ctx, ins, outs, func); break; default: { - PADDLE_THROW(paddle::platform::errors::Unimplemented( + PADDLE_THROW(phi::errors::Unimplemented( "Unsupported vectorized size: %d !", vec_size)); break; } diff --git a/paddle/phi/kernels/funcs/math_function.cc b/paddle/phi/kernels/funcs/math_function.cc index 8aed099d9f2433284d83cb9d0c18a70e1415cf8f..4201a75be8ac7ee9f7e633f6def1e002ce4b7e8a 100644 --- a/paddle/phi/kernels/funcs/math_function.cc +++ b/paddle/phi/kernels/funcs/math_function.cc @@ -184,7 +184,7 @@ struct TensorSetConstantCPU { : tensor_(tensor), value_(value) {} template void apply() const { - auto cpu = paddle::platform::CPUPlace(); + auto cpu = phi::CPUPlace(); auto* begin = tensor_->mutable_data(cpu); std::fill(begin, begin + tensor_->numel(), static_cast(value_)); } @@ -197,8 +197,7 @@ void set_constant_with_place( const paddle::platform::DeviceContext& context, paddle::framework::Tensor* tensor, float value) { - PADDLE_THROW( - paddle::platform::errors::Unimplemented("XPUPlace is not supported")); + PADDLE_THROW(phi::errors::Unimplemented("XPUPlace is not supported")); } template <> @@ -206,8 +205,7 @@ void set_constant_with_place( const 
paddle::platform::DeviceContext& context, paddle::framework::Tensor* tensor, float value) { - PADDLE_THROW( - paddle::platform::errors::Unimplemented("NPUPlace is not supported")); + PADDLE_THROW(phi::errors::Unimplemented("NPUPlace is not supported")); } template <> @@ -215,8 +213,7 @@ void set_constant_with_place( const paddle::platform::DeviceContext& context, paddle::framework::Tensor* tensor, float value) { - PADDLE_THROW(paddle::platform::errors::Unimplemented( - "NPUPinnedPlace is not supported")); + PADDLE_THROW(phi::errors::Unimplemented("NPUPinnedPlace is not supported")); } template <> @@ -224,8 +221,7 @@ void set_constant_with_place( const paddle::platform::DeviceContext& context, paddle::framework::Tensor* tensor, float value) { - PADDLE_THROW( - paddle::platform::errors::Unimplemented("IPUPlace is not supported")); + PADDLE_THROW(phi::errors::Unimplemented("IPUPlace is not supported")); } template <> @@ -233,12 +229,11 @@ void set_constant_with_place( const paddle::platform::DeviceContext& context, paddle::framework::Tensor* tensor, float value) { - PADDLE_THROW( - paddle::platform::errors::Unimplemented("CustomPlace is not supported")); + PADDLE_THROW(phi::errors::Unimplemented("CustomPlace is not supported")); } template <> -void set_constant_with_place<paddle::platform::CPUPlace>( +void set_constant_with_place<phi::CPUPlace>( const paddle::platform::DeviceContext& context, paddle::framework::Tensor* tensor, float value) { @@ -250,8 +245,7 @@ void set_constant_with_place( const paddle::platform::DeviceContext& context, paddle::framework::Tensor* tensor, float value) { - PADDLE_THROW( - paddle::platform::errors::Unimplemented("MLUPlace is not supported")); + PADDLE_THROW(phi::errors::Unimplemented("MLUPlace is not supported")); } template <> @@ -286,7 +280,7 @@ void set_constant(const paddle::platform::DeviceContext& context, // tensor->place().apply_visitor(func); paddle::platform::VisitPlace(tensor->place(), func); #else - func(paddle::platform::CPUPlace()); + func(phi::CPUPlace()); #endif } @@ -302,7 +296,7 @@ struct RowwiseAdd { PADDLE_ENFORCE_EQ( vector.numel(), size, - paddle::platform::errors::InvalidArgument( + phi::errors::InvalidArgument( "The input vector size" " should be equal to the size of each row of input tensor." " Expected vector size=%d, but received %d", @@ -312,7 +306,7 @@ const char* out_dims_cstr = out_dims.to_str().c_str(); PADDLE_ENFORCE_EQ(out_dims, in_dims, - paddle::platform::errors::InvalidArgument( + phi::errors::InvalidArgument( "The output tensor shape should be same as the input" " tensor shape. Expected output tensor shape: %s," " but received %s", diff --git a/paddle/phi/kernels/funcs/math_function.cu b/paddle/phi/kernels/funcs/math_function.cu index 0b2b53c28c984527a8e4199ed6dc92ab0b50f3f9..ae368a005f057994d9f2c4a91188358aa26e09c2 100644 --- a/paddle/phi/kernels/funcs/math_function.cu +++ b/paddle/phi/kernels/funcs/math_function.cu @@ -257,7 +257,7 @@ struct RowwiseAdd { PADDLE_ENFORCE_EQ( vector.numel(), size, - paddle::platform::errors::InvalidArgument( + phi::errors::InvalidArgument( "The input vector size" " should be equal to the size of each row of input tensor." " Expected vector size=%d, but received %d", @@ -268,7 +268,7 @@ PADDLE_ENFORCE_EQ( out_dims, in_dims, - paddle::platform::errors::InvalidArgument( + phi::errors::InvalidArgument( "The output tensor shape should be same as the input tensor" " shape. 
Expected output tensor shape: %s," " but received %s", @@ -303,7 +303,7 @@ void ColwiseSum::operator()( auto size = input.numel() / in_dims[0]; PADDLE_ENFORCE_EQ(vector->numel(), size, - paddle::platform::errors::InvalidArgument( + phi::errors::InvalidArgument( "The size of input vector" " should be equal to the size of input tensor column" " dimension. Expected vector size=%d, but received %d", @@ -339,7 +339,7 @@ void RowwiseSum::operator()( auto size = input.numel() / in_dims[0]; PADDLE_ENFORCE_EQ(vector->numel(), in_dims[0], - paddle::platform::errors::InvalidArgument( + phi::errors::InvalidArgument( "The size of input vector" " should be equal to the size of input tensor row" " dimension. Expected vector size=%d, but received %d", diff --git a/paddle/phi/kernels/funcs/math_function.h b/paddle/phi/kernels/funcs/math_function.h index 7f581c395cc713f702bcf8f512f3e1f1ca764a32..8e1a4cdd1a9688a12d7f0a8b5ba088f6abfc9512 100644 --- a/paddle/phi/kernels/funcs/math_function.h +++ b/paddle/phi/kernels/funcs/math_function.h @@ -115,7 +115,7 @@ struct TensorSetConstantXPU { std::fill(data_cpu.get(), data_cpu.get() + numel, static_cast(value_)); paddle::memory::Copy(place_, begin, - paddle::platform::CPUPlace(), + phi::CPUPlace(), static_cast(data_cpu.get()), numel * sizeof(T)); } diff --git a/paddle/phi/kernels/funcs/math_function_impl.h b/paddle/phi/kernels/funcs/math_function_impl.h index b099c6d411602126e72d3b5fdfb3107f92b2bd2f..1638d03e50f95a9338aff0d25bd41dd5d95e9738 100644 --- a/paddle/phi/kernels/funcs/math_function_impl.h +++ b/paddle/phi/kernels/funcs/math_function_impl.h @@ -74,7 +74,7 @@ void ColwiseSum::operator()( auto size = input.numel() / in_dims[0]; PADDLE_ENFORCE_EQ(out->numel(), size, - paddle::platform::errors::InvalidArgument( + phi::errors::InvalidArgument( "The size of output tensor " "should be equal to the size of input tensor column" " dimension. Expected output size=%d, but received %d", @@ -102,7 +102,7 @@ class ColwiseSum { PADDLE_ENFORCE_EQ( out->numel(), size, - paddle::platform::errors::InvalidArgument( + phi::errors::InvalidArgument( "The size of output tensor " "should be equal to the size of input tensor column" " dimension. Expected output size=%d, but received %d", @@ -130,15 +130,14 @@ void RowwiseMean::operator()( const paddle::framework::Tensor& input, paddle::framework::Tensor* out) { auto in_dims = input.dims(); - PADDLE_ENFORCE_EQ( - in_dims.size(), - 2U, - paddle::platform::errors::InvalidArgument("The rank of input tensor " - "should be 2, but received %d", - in_dims.size())); + PADDLE_ENFORCE_EQ(in_dims.size(), + 2U, + phi::errors::InvalidArgument("The rank of input tensor " + "should be 2, but received %d", + in_dims.size())); PADDLE_ENFORCE_EQ(out->numel(), in_dims[0], - paddle::platform::errors::InvalidArgument( + phi::errors::InvalidArgument( "The size of output tensor " "should be equal to the size of input tensor row" " dimension. 
Expected output size=%d, but received %d", @@ -161,18 +160,18 @@ class RowwiseMean { const paddle::framework::Tensor& input, paddle::framework::Tensor* out) { auto& in_dims = input.dims(); - PADDLE_ENFORCE_EQ(in_dims.size(), - 2U, - paddle::platform::errors::InvalidArgument( - "The rank of input tensor " - "should be 2, but received %d", - in_dims.size())); + PADDLE_ENFORCE_EQ( + in_dims.size(), + 2U, + phi::errors::InvalidArgument("The rank of input tensor " + "should be 2, but received %d", + in_dims.size())); auto height = in_dims[0]; auto size = in_dims[1]; PADDLE_ENFORCE_EQ( out->numel(), height, - paddle::platform::errors::InvalidArgument( + phi::errors::InvalidArgument( "The size of output tensor " "should be equal to the size of input tensor row" " dimension. Expected output size=%d, but received %d", @@ -198,15 +197,14 @@ void RowwiseSum::operator()( const paddle::framework::Tensor& input, paddle::framework::Tensor* out) { auto in_dims = input.dims(); - PADDLE_ENFORCE_EQ( - in_dims.size(), - 2U, - paddle::platform::errors::InvalidArgument("The rank of input tensor " - "should be 2, but received %d", - in_dims.size())); + PADDLE_ENFORCE_EQ(in_dims.size(), + 2U, + phi::errors::InvalidArgument("The rank of input tensor " + "should be 2, but received %d", + in_dims.size())); PADDLE_ENFORCE_EQ(out->numel(), in_dims[0], - paddle::platform::errors::InvalidArgument( + phi::errors::InvalidArgument( "The size of output tensor " "should be equal to the size of input tensor row" " dimension. Expected output size=%d, but received %d", @@ -229,18 +227,18 @@ class RowwiseSum { const paddle::framework::Tensor& input, paddle::framework::Tensor* out) { auto& in_dims = input.dims(); - PADDLE_ENFORCE_EQ(in_dims.size(), - 2U, - paddle::platform::errors::InvalidArgument( - "The rank of input tensor " - "should be 2, but received %d", - in_dims.size())); + PADDLE_ENFORCE_EQ( + in_dims.size(), + 2U, + phi::errors::InvalidArgument("The rank of input tensor " + "should be 2, but received %d", + in_dims.size())); auto height = in_dims[0]; auto size = in_dims[1]; PADDLE_ENFORCE_EQ( out->numel(), height, - paddle::platform::errors::InvalidArgument( + phi::errors::InvalidArgument( "The size of output tensor " "should be equal to the size of input tensor row" " dimension. 
Expected output size=%d, but received %d", diff --git a/paddle/phi/kernels/gpu/concat_and_split.h b/paddle/phi/kernels/gpu/concat_and_split.h index 46586012ccc1efc488d815471d0be5c87109ca6c..ced48ece979f06fbf2bd3f9fd8b7e07cc2954fbf 100644 --- a/paddle/phi/kernels/gpu/concat_and_split.h +++ b/paddle/phi/kernels/gpu/concat_and_split.h @@ -16,7 +16,6 @@ #include #include #include "gflags/gflags.h" -#include "paddle/fluid/framework/mixed_vector.h" #include "paddle/fluid/memory/malloc.h" #include "paddle/fluid/operators/math/concat_and_split.h" #include "paddle/fluid/platform/cuda_graph_with_memory_pool.h" @@ -329,7 +328,7 @@ void ConcatImpl(const Context& context, inputs_data, in_num); paddle::memory::Copy(context.GetPlace(), tmp_dev_ins_data->ptr(), - paddle::platform::CPUPlace(), + phi::CPUPlace(), restored, in_num * sizeof(T*), context.stream()); @@ -376,7 +375,7 @@ void ConcatImpl(const Context& context, inputs_col, inputs_col_num); paddle::memory::Copy(context.GetPlace(), tmp_dev_ins_col_data->ptr(), - paddle::platform::CPUPlace(), + phi::CPUPlace(), restored, inputs_col_num * sizeof(int64_t), context.stream()); @@ -488,7 +487,7 @@ void SplitImpl(const Context& context, outputs_data, o_num); paddle::memory::Copy(context.GetPlace(), tmp_dev_outs_data->ptr(), - paddle::platform::CPUPlace(), + phi::CPUPlace(), restored, o_num * sizeof(T*), context.stream()); @@ -535,7 +534,7 @@ void SplitImpl(const Context& context, outputs_cols, outputs_cols_num); paddle::memory::Copy(context.GetPlace(), tmp_dev_ins_col_data->ptr(), - paddle::platform::CPUPlace(), + phi::CPUPlace(), restored, outputs_cols_num * sizeof(int64_t), context.stream()); diff --git a/paddle/phi/kernels/gpu/concat_kernel.cu b/paddle/phi/kernels/gpu/concat_kernel.cu index c80a873127708c244c88eaf83516662f34b40993..b787b80c7e4ed9c10fafb139648e17fd91ca7529 100644 --- a/paddle/phi/kernels/gpu/concat_kernel.cu +++ b/paddle/phi/kernels/gpu/concat_kernel.cu @@ -54,7 +54,7 @@ void ConcatKernel(const Context& dev_ctx, PADDLE_ENFORCE_EQ( x[i].lod().size(), lod_size_0, - paddle::platform::errors::Unimplemented( + phi::errors::Unimplemented( "The lod level of all input LoDTensors should be same. " "Maybe different lod level of input LoDTensors can concat," "it is not supported currently. 
The lod level of %dth input " diff --git a/paddle/phi/kernels/gpu/copy_kernel.cu b/paddle/phi/kernels/gpu/copy_kernel.cu index e88795b6173706a8b54cd23c64f73b11e08f0fa6..0cbf5525d60f53aa47ce58bc217e8ce75b399c14 100644 --- a/paddle/phi/kernels/gpu/copy_kernel.cu +++ b/paddle/phi/kernels/gpu/copy_kernel.cu @@ -35,7 +35,7 @@ void Copy(const Context& dev_ctx, auto dst_place = dst->place(); if (src_place == dst_place && paddle::platform::is_cpu_place(src_place)) { - PADDLE_THROW(paddle::platform::errors::InvalidArgument( + PADDLE_THROW(phi::errors::InvalidArgument( "The src and dst tensor are all CPU tensor, you should call copy " "function in CPU mode.")); } @@ -74,13 +74,13 @@ void Copy(const Context& dev_ctx, PADDLE_ENFORCE_EQ( paddle::platform::is_gpu_place(ctx_place), true, - paddle::platform::errors::PreconditionNotMet( + phi::errors::PreconditionNotMet( "Context place error, excepted GPUPlace, but actually %s.", ctx_place)); auto ctx_gpu_place = ctx_place; PADDLE_ENFORCE_EQ(src_gpu_place, ctx_gpu_place, - paddle::platform::errors::Unavailable( + phi::errors::Unavailable( "Source place and context place do not match, source " "place is %s, context place is %s.", src_gpu_place, @@ -98,13 +98,13 @@ void Copy(const Context& dev_ctx, PADDLE_ENFORCE_EQ( paddle::platform::is_gpu_place(ctx_place), true, - paddle::platform::errors::PreconditionNotMet( + phi::errors::PreconditionNotMet( "Context place error, excepted GPUPlace, but actually %s.", ctx_place)); auto ctx_gpu_place = ctx_place; PADDLE_ENFORCE_EQ(dst_gpu_place, ctx_gpu_place, - paddle::platform::errors::Unavailable( + phi::errors::Unavailable( "Destination place and context place do not match, " "destination place is %s, context place is %s.", dst_gpu_place, @@ -121,14 +121,14 @@ void Copy(const Context& dev_ctx, auto ctx_place = dev_ctx.GetPlace(); PADDLE_ENFORCE_EQ(paddle::platform::is_gpu_place(ctx_place), true, - paddle::platform::errors::PreconditionNotMet( + phi::errors::PreconditionNotMet( "Device context place mismatch. When copying Tensor " "data from GPU memory to CUDA Pinned memory, current " "device context place should be GPU.")); auto ctx_gpu_place = ctx_place; PADDLE_ENFORCE_EQ(src_gpu_place, ctx_gpu_place, - paddle::platform::errors::PreconditionNotMet( + phi::errors::PreconditionNotMet( "The source GPU device and current device context do " "not match. The source GPU device number is %d, but " "device context GPU number is %d.", @@ -146,14 +146,14 @@ void Copy(const Context& dev_ctx, auto ctx_place = dev_ctx.GetPlace(); PADDLE_ENFORCE_EQ(paddle::platform::is_gpu_place(ctx_place), true, - paddle::platform::errors::PreconditionNotMet( + phi::errors::PreconditionNotMet( "Device context place mismatch. When copying Tensor " "data from CUDA Pinned memory to GPU memory, current " "device context place should be GPU.")); auto ctx_gpu_place = ctx_place; PADDLE_ENFORCE_EQ(dst_gpu_place, ctx_gpu_place, - paddle::platform::errors::PreconditionNotMet( + phi::errors::PreconditionNotMet( "The target GPU device and current device context do " "not match. 
The target GPU device number is %d, but " "device context GPU number is %d.", @@ -172,7 +172,7 @@ void Copy(const Context& dev_ctx, PADDLE_ENFORCE_EQ( paddle::platform::is_gpu_place(ctx_place), true, - paddle::platform::errors::PreconditionNotMet( + phi::errors::PreconditionNotMet( "Context place error, excepted GPUPlace, but actually %s.", ctx_place)); auto stream = @@ -195,12 +195,12 @@ void Copy(const Context& dev_ctx, paddle::memory::Copy( dst_gpu_place, dst_ptr, src_gpu_place, src_ptr, size, stream); } else { - PADDLE_THROW(paddle::platform::errors::Unavailable( + PADDLE_THROW(phi::errors::Unavailable( "Context place dose not match the source and destination place.")); } } } else { - PADDLE_THROW(paddle::platform::errors::InvalidArgument( + PADDLE_THROW(phi::errors::InvalidArgument( "Place type error. Please check the place of src and dst Tensor.")); } } diff --git a/paddle/phi/kernels/gpu/elementwise.h b/paddle/phi/kernels/gpu/elementwise.h index df66a00a8072569ed2b3e4e01ab27bbda84598d5..a2992702b164af380e961859f236b6b725c898c7 100644 --- a/paddle/phi/kernels/gpu/elementwise.h +++ b/paddle/phi/kernels/gpu/elementwise.h @@ -714,7 +714,7 @@ void CommonGradBroadcastCUDA(const DenseTensor &x, DX_OP dx_op, DY_OP dy_op) { const auto gplace = ctx.GetPlace(); - auto cplace = paddle::platform::CPUPlace(); + auto cplace = phi::CPUPlace(); const T *x_data = x.data(); const T *y_data = y.data(); const Tout *out_data = out.data(); @@ -1339,12 +1339,12 @@ void ElemwiseGradComputeWithBroadcast(const GPUContext &ctx, PADDLE_ENFORCE_GE( axis, 0, - paddle::platform::errors::InvalidArgument( + phi::errors::InvalidArgument( "Axis should be great than or equal to 0, but received axis is %d.", axis)); PADDLE_ENFORCE_LT(axis, max_dim, - paddle::platform::errors::InvalidArgument( + phi::errors::InvalidArgument( "Axis should be less than %d, but received axis is %d.", max_dim, axis)); diff --git a/paddle/phi/kernels/gpu/histogram_kernel.cu b/paddle/phi/kernels/gpu/histogram_kernel.cu index 6db987e22fc6c2ed59f67ac82adee8176ede0c9b..c5eb5220537cdd471402fc12cf8b98cf3a586ebc 100644 --- a/paddle/phi/kernels/gpu/histogram_kernel.cu +++ b/paddle/phi/kernels/gpu/histogram_kernel.cu @@ -111,9 +111,9 @@ void HistogramKernel(const Context& dev_ctx, DenseTensor input_min_cpu, input_max_cpu; paddle::framework::TensorCopySync( - input_min_t, paddle::platform::CPUPlace(), &input_min_cpu); + input_min_t, phi::CPUPlace(), &input_min_cpu); paddle::framework::TensorCopySync( - input_max_t, paddle::platform::CPUPlace(), &input_max_cpu); + input_max_t, phi::CPUPlace(), &input_max_cpu); output_min = input_min_cpu.data()[0]; output_max = input_max_cpu.data()[0]; diff --git a/paddle/phi/kernels/impl/full_kernel_impl.h b/paddle/phi/kernels/impl/full_kernel_impl.h index 40675dd175bef8ca6840264b1a3715363c6c3fb4..8cced49906eccdc41ccfb02518dcd06d771d23c9 100644 --- a/paddle/phi/kernels/impl/full_kernel_impl.h +++ b/paddle/phi/kernels/impl/full_kernel_impl.h @@ -59,7 +59,7 @@ void FullLikeKernel(const Context& dev_ctx, (common_type_value <= static_cast(std::numeric_limits::max())), true, - paddle::platform::errors::InvalidArgument( + phi::errors::InvalidArgument( "The filled value is out of range for target type, " "current kernel type is %s, the range should between %f " "and %f, but now value is %f.", diff --git a/paddle/phi/kernels/impl/matmul_kernel_impl.h b/paddle/phi/kernels/impl/matmul_kernel_impl.h index 119bdc2986ea5559ea818af86b4cc6c1e6efe8a5..f6136de5d8d0c3d04c83b0446abc82d0eeb11376 100644 --- 
a/paddle/phi/kernels/impl/matmul_kernel_impl.h +++ b/paddle/phi/kernels/impl/matmul_kernel_impl.h @@ -38,7 +38,7 @@ static void GetBroadcastFromDims(const int x_ndim, PADDLE_ENFORCE_EQ( x_bd_dims[i] == y_bd_dims[i] || x_bd_dims[i] <= 1 || y_bd_dims[i] <= 1, true, - paddle::platform::errors::InvalidArgument( + phi::errors::InvalidArgument( "Input(X) and Input(Y) has error dim." "X_broadcast's shape[%s] must be equal to Y_broadcast's shape[%s]," "or X_broadcast's shape[%s] <= 1, or Y_broadcast's shape[%s] <= 1," @@ -110,7 +110,7 @@ void MatMulFunction(const Context& dev_ctx, PADDLE_ENFORCE_EQ( M, N, - paddle::platform::errors::InvalidArgument( + phi::errors::InvalidArgument( "X's numbers must be equal to Y's numbers," "when X/Y's dims =1. But received X has [%d] elements," "received Y has [%d] elements", @@ -135,27 +135,27 @@ void MatMulFunction(const Context& dev_ctx, if (x_ndim == 1) { const int N = X.numel(); if (trans_y) { - PADDLE_ENFORCE_EQ(y_dims[y_ndim - 1], - N, - paddle::platform::errors::InvalidArgument( - "Input(Y) has error dim." - "Y'dims[%d] must be equal to %d" - "But received Y'dims[%d] is %d", - y_ndim - 1, - N, - y_ndim - 1, - y_dims[y_ndim - 1])); + PADDLE_ENFORCE_EQ( + y_dims[y_ndim - 1], + N, + phi::errors::InvalidArgument("Input(Y) has error dim." + "Y'dims[%d] must be equal to %d" + "But received Y'dims[%d] is %d", + y_ndim - 1, + N, + y_ndim - 1, + y_dims[y_ndim - 1])); } else { - PADDLE_ENFORCE_EQ(y_dims[y_ndim - 2], - N, - paddle::platform::errors::InvalidArgument( - "Input(Y) has error dim." - "Y'dims[%d] must be equal to %d" - "But received Y'dims[%d] is %d", - y_ndim - 2, - N, - y_ndim - 2, - y_dims[y_ndim - 2])); + PADDLE_ENFORCE_EQ( + y_dims[y_ndim - 2], + N, + phi::errors::InvalidArgument("Input(Y) has error dim." + "Y'dims[%d] must be equal to %d" + "But received Y'dims[%d] is %d", + y_ndim - 2, + N, + y_ndim - 2, + y_dims[y_ndim - 2])); } std::vector out_dims(y_ndim - 1); if (trans_y) { @@ -213,27 +213,27 @@ void MatMulFunction(const Context& dev_ctx, if (y_ndim == 1) { const int N = Y.numel(); if (trans_x) { - PADDLE_ENFORCE_EQ(x_dims[x_ndim - 2], - N, - paddle::platform::errors::InvalidArgument( - "Input(X) has error dim." - "X'dims[%d] must be equal to %d" - "But received X'dims[%d] is %d", - x_ndim - 2, - N, - x_ndim - 2, - x_dims[x_ndim - 2])); + PADDLE_ENFORCE_EQ( + x_dims[x_ndim - 2], + N, + phi::errors::InvalidArgument("Input(X) has error dim." + "X'dims[%d] must be equal to %d" + "But received X'dims[%d] is %d", + x_ndim - 2, + N, + x_ndim - 2, + x_dims[x_ndim - 2])); } else { - PADDLE_ENFORCE_EQ(x_dims[x_ndim - 1], - N, - paddle::platform::errors::InvalidArgument( - "Input(X) has error dim." - "X'dims[%d] must be equal to %d" - "But received X'dims[%d] is %d", - x_ndim - 1, - N, - x_ndim - 1, - x_dims[x_ndim - 1])); + PADDLE_ENFORCE_EQ( + x_dims[x_ndim - 1], + N, + phi::errors::InvalidArgument("Input(X) has error dim." + "X'dims[%d] must be equal to %d" + "But received X'dims[%d] is %d", + x_ndim - 1, + N, + x_ndim - 1, + x_dims[x_ndim - 1])); } std::vector out_dims(x_ndim - 1); if (trans_x) { @@ -292,27 +292,27 @@ void MatMulFunction(const Context& dev_ctx, const int M = trans_x ? x_dims[x_ndim - 1] : x_dims[x_ndim - 2]; const int K = trans_x ? x_dims[x_ndim - 2] : x_dims[x_ndim - 1]; if (trans_y) { - PADDLE_ENFORCE_EQ(y_dims[y_ndim - 1], - K, - paddle::platform::errors::InvalidArgument( - "Input(Y) has error dim." 
- "Y'dims[%d] must be equal to %d" - "But received Y'dims[%d] is %d", - y_ndim - 1, - K, - y_ndim - 1, - y_dims[y_ndim - 1])); + PADDLE_ENFORCE_EQ( + y_dims[y_ndim - 1], + K, + phi::errors::InvalidArgument("Input(Y) has error dim." + "Y'dims[%d] must be equal to %d" + "But received Y'dims[%d] is %d", + y_ndim - 1, + K, + y_ndim - 1, + y_dims[y_ndim - 1])); } else { - PADDLE_ENFORCE_EQ(y_dims[y_ndim - 2], - K, - paddle::platform::errors::InvalidArgument( - "Input(Y) has error dim." - "Y'dims[%d] must be equal to %d" - "But received Y'dims[%d] is %d", - y_ndim - 2, - K, - y_ndim - 2, - y_dims[y_ndim - 2])); + PADDLE_ENFORCE_EQ( + y_dims[y_ndim - 2], + K, + phi::errors::InvalidArgument("Input(Y) has error dim." + "Y'dims[%d] must be equal to %d" + "But received Y'dims[%d] is %d", + y_ndim - 2, + K, + y_ndim - 2, + y_dims[y_ndim - 2])); } const int N = trans_y ? y_dims[y_ndim - 2] : y_dims[y_ndim - 1]; const int ndim = (std::max)(x_ndim, y_ndim); @@ -493,16 +493,16 @@ void MatmulKernel(const Context& dev_ctx, bool transpose_x, bool transpose_y, DenseTensor* out) { - PADDLE_ENFORCE_NE(phi::product(x.dims()), - 0, - paddle::platform::errors::InvalidArgument( - "The Input(X) dims size must not be equal 0," - " but reviced dims size is 0. ")); - PADDLE_ENFORCE_NE(phi::product(y.dims()), - 0, - paddle::platform::errors::InvalidArgument( - "The Input(Y) dims size must not be equal 0," - " but reviced dims size is 0. ")); + PADDLE_ENFORCE_NE( + phi::product(x.dims()), + 0, + phi::errors::InvalidArgument("The Input(X) dims size must not be equal 0," + " but reviced dims size is 0. ")); + PADDLE_ENFORCE_NE( + phi::product(y.dims()), + 0, + phi::errors::InvalidArgument("The Input(Y) dims size must not be equal 0," + " but reviced dims size is 0. ")); MatMulFunction(dev_ctx, x, y, out, transpose_x, transpose_y); } diff --git a/paddle/phi/kernels/sparse/cpu/sparse_utils_kernel.cc b/paddle/phi/kernels/sparse/cpu/sparse_utils_kernel.cc index 4374b5d7f1a1d9992619cffcdafcda8708e4c640..ba89135641e0e67daa84cd526d8b389953ef1862 100644 --- a/paddle/phi/kernels/sparse/cpu/sparse_utils_kernel.cc +++ b/paddle/phi/kernels/sparse/cpu/sparse_utils_kernel.cc @@ -41,7 +41,7 @@ inline int64_t GetNonZeroNum(const DenseTensor& dense, PADDLE_ENFORCE_GE( dims.size(), sparse_dim, - paddle::platform::errors::InvalidArgument( + phi::errors::InvalidArgument( "sparse_dim(%d) should be less than or equal to dense.dim(%d)", sparse_dim, dims.size())); @@ -161,7 +161,7 @@ void SparseCooToCsrKernel(const Context& dev_ctx, bool valid = x_dims.size() == 2 || x_dims.size() == 3; PADDLE_ENFORCE_EQ(valid, true, - paddle::platform::errors::InvalidArgument( + phi::errors::InvalidArgument( "SparseCsrTensor only support 2-D or 3-D matrix")); const int64_t non_zero_num = x.nnz(); if (non_zero_num <= 0) return; diff --git a/paddle/phi/kernels/sparse/gpu/sparse_utils_kernel.cu b/paddle/phi/kernels/sparse/gpu/sparse_utils_kernel.cu index b7793e40554455075e98b12192750d862045fa82..1e2c70a9cf39bf0df738a74b301afcc0fcbd8699 100644 --- a/paddle/phi/kernels/sparse/gpu/sparse_utils_kernel.cu +++ b/paddle/phi/kernels/sparse/gpu/sparse_utils_kernel.cu @@ -379,7 +379,7 @@ void SparseCooToCsrKernel(const Context& dev_ctx, bool valid = x_dims.size() == 2 || x_dims.size() == 3; PADDLE_ENFORCE_EQ(valid, true, - paddle::platform::errors::InvalidArgument( + phi::errors::InvalidArgument( "SparseCsrTensor only support 2-D or 3-D matrix")); const int64_t non_zero_num = x.nnz(); if (non_zero_num <= 0) return; diff --git a/paddle/phi/kernels/sparse/sparse_utils_kernel.h 
b/paddle/phi/kernels/sparse/sparse_utils_kernel.h index 3d7304653e77b3c45b82ccb7426de56457f14b03..b5201e16f548d594af47aa9a4611d35f9cf2ad4f 100644 --- a/paddle/phi/kernels/sparse/sparse_utils_kernel.h +++ b/paddle/phi/kernels/sparse/sparse_utils_kernel.h @@ -97,7 +97,7 @@ void DenseToSparseCsrKernel(const Context& dev_ctx, bool valid = x_dims.size() == 2 || x_dims.size() == 3; PADDLE_ENFORCE_EQ(valid, true, - paddle::platform::errors::InvalidArgument( + phi::errors::InvalidArgument( "SparseCsrTensor only support 2-D or 3-D Tensor.")); const int64_t sparse_dim = x_dims.size() == 2 ? 2 : 3; DenseTensor indices = phi::Empty(dev_ctx); diff --git a/paddle/phi/kernels/xpu/copy_kernel.cc b/paddle/phi/kernels/xpu/copy_kernel.cc index 3bbedbbb346e42e55824c833244774544648ab40..58efbafc88bee0933a364ff9872604c94174305f 100644 --- a/paddle/phi/kernels/xpu/copy_kernel.cc +++ b/paddle/phi/kernels/xpu/copy_kernel.cc @@ -62,7 +62,7 @@ void Copy(const Context& dev_ctx, } paddle::memory::Copy(dst_place, dst_ptr, src_place, src_ptr, size); } else { - PADDLE_THROW(paddle::platform::errors::Unimplemented( + PADDLE_THROW(phi::errors::Unimplemented( "Copy from %s to %s is not supported.", src_place, dst_place)); } } diff --git a/paddle/phi/kernels/xpu/scale_kernel.cc b/paddle/phi/kernels/xpu/scale_kernel.cc index e103e5afdcf9bea9206541ee5c94c1c3d7a87e5f..b5a07a7a146c3e8d058f8f8d90e2dbc3cd68e7ab 100644 --- a/paddle/phi/kernels/xpu/scale_kernel.cc +++ b/paddle/phi/kernels/xpu/scale_kernel.cc @@ -32,13 +32,13 @@ void ScaleKernel(const Context& dev_ctx, DenseTensor* out) { out->mutable_data(dev_ctx.GetPlace()); - PADDLE_ENFORCE_EQ(x.dims(), - out->dims(), - paddle::platform::errors::InvalidArgument( - "In and out should have the same dim," - " expected %s, but got %s.", - x.dims().to_str().c_str(), - out->dims().to_str().c_str())); + PADDLE_ENFORCE_EQ( + x.dims(), + out->dims(), + phi::errors::InvalidArgument("In and out should have the same dim," + " expected %s, but got %s.", + x.dims().to_str().c_str(), + out->dims().to_str().c_str())); using XPUType = typename XPUTypeTrait::Type; int r = xpu::scale(dev_ctx.x_context(), reinterpret_cast(x.data()), @@ -50,7 +50,7 @@ void ScaleKernel(const Context& dev_ctx, PADDLE_ENFORCE_EQ( r, XPU_SUCCESS, - paddle::platform::errors::External( + phi::errors::External( "XPU scale kernel return wrong value[%d %s]", r, XPUAPIErrorMsg[r])); } diff --git a/paddle/phi/tests/core/allocator.h b/paddle/phi/tests/core/allocator.h index 66e5b4885c8363550561df84685950767c76914b..b92178eba3045365806f2cd94a36d74794be83f3 100644 --- a/paddle/phi/tests/core/allocator.h +++ b/paddle/phi/tests/core/allocator.h @@ -29,8 +29,7 @@ class FancyAllocator : public phi::Allocator { AllocationPtr Allocate(size_t bytes_size) override { void* data = ::operator new(bytes_size); - auto* allocation = - new phi::Allocation(data, bytes_size, paddle::platform::CPUPlace()); + auto* allocation = new phi::Allocation(data, bytes_size, phi::CPUPlace()); return AllocationPtr(allocation, Delete); } }; diff --git a/paddle/phi/tests/core/test_dense_tensor.cc b/paddle/phi/tests/core/test_dense_tensor.cc index 6464ff24d24aa59d2f8eb3d28818af5b552ab658..ddfa184df2c1ede0f953b426f14bd5730ee3a9b0 100644 --- a/paddle/phi/tests/core/test_dense_tensor.cc +++ b/paddle/phi/tests/core/test_dense_tensor.cc @@ -85,7 +85,7 @@ TEST(dense_tensor, ctor) { r = r && (t.dims() == m.dims); r = r && (t.dtype() == m.dtype); r = r && (t.layout() == m.layout); - r = r && (t.place() == paddle::platform::CPUPlace()); + r = r && (t.place() == 
phi::CPUPlace()); r = r && t.initialized(); r = r && t.IsSharedWith(t); return r; diff --git a/paddle/phi/tests/core/test_sparse_coo_tensor.cc b/paddle/phi/tests/core/test_sparse_coo_tensor.cc index e93f1f0b0ecaffce35c56f939304d0f182d65bfc..5d0e16b0528e7ba73d1b1fea858e8b0529cc9ddf 100644 --- a/paddle/phi/tests/core/test_sparse_coo_tensor.cc +++ b/paddle/phi/tests/core/test_sparse_coo_tensor.cc @@ -53,7 +53,7 @@ TEST(sparse_coo_tensor, construct) { CHECK(sparse.dims() == dense_dims); CHECK(sparse.dtype() == DataType::FLOAT32); CHECK(sparse.layout() == DataLayout::SPARSE_COO); - CHECK(sparse.place() == paddle::platform::CPUPlace()); + CHECK(sparse.place() == phi::CPUPlace()); } TEST(sparse_coo_tensor, other_function) { diff --git a/paddle/utils/string/tinyformat/tinyformat.h b/paddle/utils/string/tinyformat/tinyformat.h index 28a444f87c48fdde7d41aa257fe0e91538c9b7a7..4e46cbc26b6380687639de140333c123662543b5 100644 --- a/paddle/utils/string/tinyformat/tinyformat.h +++ b/paddle/utils/string/tinyformat/tinyformat.h @@ -133,6 +133,8 @@ #include #include +#include "paddle/utils/string/to_string.h" + namespace paddle { namespace string { namespace tinyformat { diff --git a/paddle/utils/string/to_string.h b/paddle/utils/string/to_string.h index 7b3332861e0fa3edbbb8915e3e3f068fed3b412f..3cec88a4571b6bf50ccebb7f9b2c6224f42166e8 100644 --- a/paddle/utils/string/to_string.h +++ b/paddle/utils/string/to_string.h @@ -56,5 +56,26 @@ inline std::string to_string(const char* v) { return std::string(v); } +inline std::ostream& operator<<(std::ostream& os, + const std::vector<std::vector<size_t>>& lod) { + os << "{"; + for (auto& v : lod) { + os << "{"; + bool is_first = true; + for (auto& i : v) { + if (is_first) { + os << i; + is_first = false; + } else { + os << ", " << i; + } + } + os << "}"; + } + os << "}"; + + return os; +} + } // namespace string } // namespace paddle
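(Closing aside: a small usage sketch for the operator<< overload added to to_string.h above — illustrative only; the nested-vector shape mirrors Paddle's LoD convention.)

#include <iostream>
#include <vector>

#include "paddle/utils/string/to_string.h"

int main() {
  // An LoD-style nested vector; each level is printed in braces.
  std::vector<std::vector<size_t>> lod{{0, 2, 5}, {0, 1, 3}};
  // The overload lives in namespace paddle::string, so a using-declaration
  // makes it visible to unqualified operator lookup at this call site.
  using paddle::string::operator<<;
  std::cout << lod << "\n";  // prints {{0, 2, 5}{0, 1, 3}}
  return 0;
}

Note that, as written, the overload comma-separates elements within a level but places no separator between consecutive levels, which is why the two inner brace groups print back to back.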