From 094e3b8cf7bf8cf824955c0b23f6a2d6ba546b5b Mon Sep 17 00:00:00 2001
From: engineer1109
Date: Mon, 30 Jan 2023 16:03:24 +0800
Subject: [PATCH] add phi tensor vector array api from fluid (#49885)

replace all TensorFromVector & TensorToVector

AssignKernel async copy
---
 paddle/fluid/framework/tensor_util.h          |   2 +-
 paddle/phi/core/tensor_utils.cc               | 445 ++++++++++++++++++
 paddle/phi/core/tensor_utils.h                |  17 +
 paddle/phi/kernels/assign_kernel.cc           |  10 +-
 paddle/phi/kernels/cpu/adam_kernel.cc         |   3 +-
 paddle/phi/kernels/cpu/adamw_kernel.cc        |   4 +-
 paddle/phi/kernels/cpu/cross_grad_kernel.cc   |  12 +-
 paddle/phi/kernels/cpu/cross_kernel.cc        |   6 +-
 .../kernels/cpu/index_sample_grad_kernel.cc   |   8 +-
 paddle/phi/kernels/cpu/index_sample_kernel.cc |   8 +-
 .../cpu/repeat_interleave_grad_kernel.cc      |   2 +-
 paddle/phi/kernels/cpu/roll_grad_kernel.cc    |   6 +-
 paddle/phi/kernels/cpu/roll_kernel.cc         |   6 +-
 .../kernels/cpu/unique_consecutive_functor.h  |   7 +-
 .../funcs/repeat_tensor2index_tensor.h        |   2 +-
 paddle/phi/kernels/funcs/unique_functor.h     |   6 +-
 paddle/phi/kernels/gpu/adam_kernel.cu         |   3 +-
 paddle/phi/kernels/gpu/adamw_kernel.cu        |   3 +-
 .../kernels/gpu/class_center_sample_kernel.cu |   4 +-
 .../phi/kernels/gpu/diagonal_grad_kernel.cu   |   8 +-
 paddle/phi/kernels/gpu/diagonal_kernel.cu     |   8 +-
 .../gpu/margin_cross_entropy_grad_kernel.cu   |   5 +-
 .../gpu/margin_cross_entropy_kernel.cu        |   5 +-
 paddle/phi/kernels/gpu/prior_box_kernel.cu    |  10 +-
 .../kernels/gpu/unique_consecutive_functor.h  |   8 +-
 paddle/phi/kernels/gpu/unique_kernel.cu       |   7 +-
 .../kernels/impl/determinant_kernel_impl.h    |   6 +-
 .../impl/repeat_interleave_grad_kernel_impl.h |   2 +-
 .../impl/repeat_interleave_kernel_impl.h      |   4 +-
 .../phi/kernels/impl/set_value_kernel_impl.h  |   2 +-
 .../impl/slogdeterminant_kernel_impl.h        |   6 +-
 .../kernels/selected_rows/cpu/adam_kernel.cc  |   3 +-
 .../kernels/selected_rows/cpu/adamw_kernel.cc |   4 +-
 .../kernels/selected_rows/gpu/adam_kernel.cu  |   3 +-
 .../kernels/selected_rows/gpu/adamw_kernel.cu |   2 +-
 .../kernels/selected_rows/xpu/adam_kernel.cc  |   2 +-
 paddle/phi/kernels/xpu/adam_kernel.cc         |   2 +-
 paddle/phi/kernels/xpu/adamw_kernel.cc        |   4 +-
 paddle/phi/kernels/xpu/set_value_kernel.cc    |   2 +-
 39 files changed, 545 insertions(+), 102 deletions(-)

diff --git a/paddle/fluid/framework/tensor_util.h b/paddle/fluid/framework/tensor_util.h
index 603239067b..d61c062ac8 100644
--- a/paddle/fluid/framework/tensor_util.h
+++ b/paddle/fluid/framework/tensor_util.h
@@ -355,7 +355,7 @@ inline void TensorFromVector(const std::vector<T>& src,
         reinterpret_cast<const phi::GPUContext&>(ctx).stream());
   }
 #endif
-#ifdef PADDLE_WITH_CUSTOM_DEICE
+#ifdef PADDLE_WITH_CUSTOM_DEVICE
   else if (platform::is_custom_place(dst_place)) {  // NOLINT
     auto stream =
         reinterpret_cast<const platform::CustomDeviceContext&>(ctx).stream();
diff --git a/paddle/phi/core/tensor_utils.cc b/paddle/phi/core/tensor_utils.cc
index 467552032f..e9ed973e0d 100644
--- a/paddle/phi/core/tensor_utils.cc
+++ b/paddle/phi/core/tensor_utils.cc
@@ -422,4 +422,449 @@ template void Copy(const OneDNNContext& dev_ctx,
                    bool blocking,
                    DenseTensor* dst);
 #endif
+
+template <typename T>
+void TensorFromVector(const std::vector<T>& src,
+                      const phi::DeviceContext& ctx,
+                      phi::DenseTensor* dst) {
+  auto dst_place = ctx.GetPlace();
+  auto src_ptr = static_cast<const void*>(src.data());
+  phi::CPUPlace src_place;
+  dst->Resize({static_cast<int64_t>(src.size())});
+  ctx.template Alloc<T>(dst);
+  auto dst_ptr = static_cast<void*>(dst->data<T>());
+  auto size = src.size() * sizeof(T);
+
+  if (paddle::platform::is_cpu_place(dst_place)) {
+    paddle::memory::Copy(dst_place, dst_ptr, src_place, src_ptr, size);
+  }
+#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
+  else if (paddle::platform::is_gpu_place(dst_place)) {  // NOLINT
+    paddle::memory::Copy(
+        dst_place,
+        dst_ptr,
+        src_place,
+        src_ptr,
+        size,
+        reinterpret_cast<const phi::GPUContext&>(ctx).stream());
+  }
+#endif
+#ifdef PADDLE_WITH_CUSTOM_DEVICE
+  else if (paddle::platform::is_custom_place(dst_place)) {  // NOLINT
+    paddle::memory::Copy(
+        dst_place,
+        dst_ptr,
+        src_place,
+        src_ptr,
+        size,
+        reinterpret_cast<const phi::CustomContext&>(ctx).stream());
+  }
+#endif
+#ifdef PADDLE_WITH_XPU
+  else if (paddle::platform::is_xpu_place(dst_place)) {  // NOLINT
+    paddle::memory::Copy(dst_place, dst_ptr, src_place, src_ptr, size);
+  }
+#endif
+  else {  // NOLINT
+    PADDLE_THROW(phi::errors::Unimplemented(
+        "TensorFromVector on %s is not supported.", dst_place));
+  }
+}
+
+template <>
+void TensorFromVector(const std::vector<bool>& src,
+                      const phi::DeviceContext& ctx,
+                      phi::DenseTensor* dst) {
+  // vector<bool> has no data() member, use array instead.
+  // See details:
+  // https://stackoverflow.com/questions/46115669/why-does-stdvectorbool-have-no-data/46115714
+  bool* array = new bool[src.size()];
+  for (unsigned int i = 0; i < src.size(); i++) {
+    array[i] = static_cast<bool>(src[i]);
+  }
+
+  auto dst_place = ctx.GetPlace();
+  auto src_ptr = static_cast<const void*>(array);
+  phi::CPUPlace src_place{};
+  dst->Resize({static_cast<int64_t>(src.size())});
+  auto dst_ptr = ctx.template Alloc<bool>(dst);
+  auto size = src.size() * sizeof(bool);
+
+  if (paddle::platform::is_cpu_place(dst_place)) {
+    paddle::memory::Copy(dst_place, dst_ptr, src_place, src_ptr, size);
+  }
+#ifdef PADDLE_WITH_CUDA
+  else if (paddle::platform::is_gpu_place(dst_place)) {  // NOLINT
+    paddle::memory::Copy(
+        dst_place,
+        dst_ptr,
+        src_place,
+        src_ptr,
+        size,
+        reinterpret_cast<const phi::GPUContext&>(ctx).stream());
+  }
+#endif
+#ifdef PADDLE_WITH_CUSTOM_DEVICE
+  else if (paddle::platform::is_custom_place(dst_place)) {  // NOLINT
+    auto stream = reinterpret_cast<const phi::CustomContext&>(ctx).stream();
+    paddle::memory::Copy(dst_place, dst_ptr, src_place, src_ptr, size, stream);
+  }
+#endif
+#ifdef PADDLE_WITH_XPU
+  else if (paddle::platform::is_xpu_place(dst_place)) {  // NOLINT
+    paddle::memory::Copy(dst_place, dst_ptr, src_place, src_ptr, size);
+  }
+#endif
+  else {  // NOLINT
+    PADDLE_THROW(phi::errors::Unimplemented(
+        "TensorFromVector on %s is not supported.", dst_place));
+  }
+  delete[] array;
+}
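The std::vector<bool> detour in the specialization above is forced by the standard library: vector<bool> is a bit-packed specialization with no data() member, so elements must be widened into a plain bool array before memory::Copy can read contiguous bytes. A self-contained illustration of the pitfall (plain C++; names invented for the example, not part of the patch):

    #include <vector>

    int main() {
      std::vector<int> ints = {1, 2, 3};
      const int* p = ints.data();  // fine: vector<int> stores contiguous ints
      (void)p;

      std::vector<bool> flags = {true, false, true};
      // flags.data();             // ill-formed: vector<bool> has no data()
      bool* staged = new bool[flags.size()];
      for (size_t i = 0; i < flags.size(); ++i) {
        staged[i] = flags[i];      // unpack each proxy bit into a real bool
      }
      // `staged` can now be handed to memcpy / paddle::memory::Copy.
      delete[] staged;
      return 0;
    }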
+
+template void TensorFromVector<bool>(const std::vector<bool>& src,
+                                     const phi::DeviceContext& ctx,
+                                     phi::DenseTensor* dst);
+
+template void TensorFromVector<int16_t>(const std::vector<int16_t>& src,
+                                        const phi::DeviceContext& ctx,
+                                        phi::DenseTensor* dst);
+
+template void TensorFromVector<int>(const std::vector<int>& src,
+                                    const phi::DeviceContext& ctx,
+                                    phi::DenseTensor* dst);
+
+template void TensorFromVector<int64_t>(const std::vector<int64_t>& src,
+                                        const phi::DeviceContext& ctx,
+                                        phi::DenseTensor* dst);
+
+template void TensorFromVector<uint8_t>(const std::vector<uint8_t>& src,
+                                        const phi::DeviceContext& ctx,
+                                        phi::DenseTensor* dst);
+
+template void TensorFromVector<float>(const std::vector<float>& src,
+                                      const phi::DeviceContext& ctx,
+                                      phi::DenseTensor* dst);
+
+template void TensorFromVector<double>(const std::vector<double>& src,
+                                       const phi::DeviceContext& ctx,
+                                       phi::DenseTensor* dst);
+
+template void TensorFromVector<phi::dtype::bfloat16>(
+    const std::vector<phi::dtype::bfloat16>& src,
+    const phi::DeviceContext& ctx,
+    phi::DenseTensor* dst);
+
+template void TensorFromVector<phi::dtype::float16>(
+    const std::vector<phi::dtype::float16>& src,
+    const phi::DeviceContext& ctx,
+    phi::DenseTensor* dst);
+
+template void TensorFromVector<phi::dtype::complex<float>>(
+    const std::vector<phi::dtype::complex<float>>& src,
+    const phi::DeviceContext& ctx,
+    phi::DenseTensor* dst);
+
+template void TensorFromVector<phi::dtype::complex<double>>(
+    const std::vector<phi::dtype::complex<double>>& src,
+    const phi::DeviceContext& ctx,
+    phi::DenseTensor* dst);
+
+template <typename T>
+void TensorFromArray(const T* src,
+                     const size_t& array_size,
+                     const phi::DeviceContext& ctx,
+                     phi::DenseTensor* dst) {
+  auto dst_place = ctx.GetPlace();
+  auto src_ptr = static_cast<const void*>(src);
+  phi::CPUPlace src_place;
+  dst->Resize({static_cast<int64_t>(array_size)});
+  ctx.template Alloc<T>(dst);
+  auto dst_ptr = static_cast<void*>(dst->data<T>());
+  auto size = array_size * sizeof(T);
+
+  if (paddle::platform::is_cpu_place(dst_place)) {
+    paddle::memory::Copy(dst_place, dst_ptr, src_place, src_ptr, size);
+  }
+#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
+  else if (paddle::platform::is_gpu_place(dst_place)) {  // NOLINT
+    paddle::memory::Copy(
+        dst_place,
+        dst_ptr,
+        src_place,
+        src_ptr,
+        size,
+        reinterpret_cast<const phi::GPUContext&>(ctx).stream());
+  }
+#endif
+#ifdef PADDLE_WITH_CUSTOM_DEVICE
+  else if (paddle::platform::is_custom_place(dst_place)) {  // NOLINT
+    paddle::memory::Copy(
+        dst_place,
+        dst_ptr,
+        src_place,
+        src_ptr,
+        size,
+        reinterpret_cast<const phi::CustomContext&>(ctx).stream());
+  }
+#endif
+#ifdef PADDLE_WITH_XPU
+  else if (paddle::platform::is_xpu_place(dst_place)) {  // NOLINT
+    paddle::memory::Copy(dst_place, dst_ptr, src_place, src_ptr, size);
+  }
+#endif
+  else {  // NOLINT
+    PADDLE_THROW(phi::errors::Unimplemented(
+        "TensorFromArray on %s is not supported.", dst_place));
+  }
+}
+
+template void TensorFromArray<bool>(const bool* src,
+                                    const size_t& array_size,
+                                    const phi::DeviceContext& ctx,
+                                    phi::DenseTensor* dst);
+
+template void TensorFromArray<int16_t>(const int16_t* src,
+                                       const size_t& array_size,
+                                       const phi::DeviceContext& ctx,
+                                       phi::DenseTensor* dst);
+
+template void TensorFromArray<int>(const int* src,
+                                   const size_t& array_size,
+                                   const phi::DeviceContext& ctx,
+                                   phi::DenseTensor* dst);
+
+template void TensorFromArray<int64_t>(const int64_t* src,
+                                       const size_t& array_size,
+                                       const phi::DeviceContext& ctx,
+                                       phi::DenseTensor* dst);
+
+template void TensorFromArray<float>(const float* src,
+                                     const size_t& array_size,
+                                     const phi::DeviceContext& ctx,
+                                     phi::DenseTensor* dst);
+
+template void TensorFromArray<double>(const double* src,
+                                      const size_t& array_size,
+                                      const phi::DeviceContext& ctx,
+                                      phi::DenseTensor* dst);
+
+template void TensorFromArray<phi::dtype::bfloat16>(
+    const phi::dtype::bfloat16* src,
+    const size_t& array_size,
+    const phi::DeviceContext& ctx,
+    phi::DenseTensor* dst);
+
+template void TensorFromArray<phi::dtype::float16>(
+    const phi::dtype::float16* src,
+    const size_t& array_size,
+    const phi::DeviceContext& ctx,
+    phi::DenseTensor* dst);
+
+template void TensorFromArray<phi::dtype::complex<float>>(
+    const phi::dtype::complex<float>* src,
+    const size_t& array_size,
+    const phi::DeviceContext& ctx,
+    phi::DenseTensor* dst);
+
+template void TensorFromArray<phi::dtype::complex<double>>(
+    const phi::dtype::complex<double>* src,
+    const size_t& array_size,
+    const phi::DeviceContext& ctx,
+    phi::DenseTensor* dst);
+
+template <typename T>
+void TensorToVector(const phi::DenseTensor& src,
+                    const phi::DeviceContext& ctx,
+                    std::vector<T>* dst) {
+  auto src_ptr = static_cast<const void*>(src.data<T>());
+  auto size = src.numel() * sizeof(T);
+
+  phi::CPUPlace dst_place{};
+  dst->resize(src.numel());
+  auto dst_ptr = static_cast<void*>(dst->data());
+
+  if (paddle::platform::is_cpu_place(src.place())) {
+    paddle::memory::Copy(dst_place, dst_ptr, src.place(), src_ptr, size);
+  }
+#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
+  else if (paddle::platform::is_gpu_place(src.place())) {  // NOLINT
+    paddle::memory::Copy(
+        dst_place,
+        dst_ptr,
+        src.place(),
+        src_ptr,
+        size,
+        reinterpret_cast<const phi::GPUContext&>(ctx).stream());
+  }
+#endif
+#if defined(PADDLE_WITH_XPU)
+  else if (paddle::platform::is_xpu_place(src.place())) {  // NOLINT
+    paddle::memory::Copy(dst_place, dst_ptr, src.place(), src_ptr, size);
+  }
+#endif
+#ifdef PADDLE_WITH_CUSTOM_DEVICE
+  else if (paddle::platform::is_custom_place(src.place())) {  // NOLINT
+    paddle::memory::Copy(
+        dst_place, dst_ptr, src.place(), src_ptr, size, nullptr);
+  }
+#endif
+  else {  // NOLINT
+    PADDLE_THROW(phi::errors::Unimplemented(
+        "TensorToVector on %s is not supported.", src.place()));
+  }
+}
+
+template <>
+void TensorToVector(const phi::DenseTensor& src,
+                    const phi::DeviceContext& ctx,
+                    std::vector<bool>* dst) {
+  auto src_ptr = static_cast<const void*>(src.data<bool>());
+  auto size = src.numel() * sizeof(bool);
+
+  bool* array = new bool[src.numel()];
+
+  phi::CPUPlace dst_place{};
+  dst->resize(src.numel());
+  auto dst_ptr = static_cast<void*>(array);
+
+  if (paddle::platform::is_cpu_place(src.place())) {
+    paddle::memory::Copy(dst_place, dst_ptr, src.place(), src_ptr, size);
+  }
+#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
+  else if (paddle::platform::is_gpu_place(src.place())) {  // NOLINT
+    paddle::memory::Copy(
+        dst_place,
+        dst_ptr,
+        src.place(),
+        src_ptr,
+        size,
+        reinterpret_cast<const phi::GPUContext&>(ctx).stream());
+  }
+#endif
+#if defined(PADDLE_WITH_XPU)
+  else if (paddle::platform::is_xpu_place(src.place())) {  // NOLINT
+    paddle::memory::Copy(dst_place, dst_ptr, src.place(), src_ptr, size);
+  }
+#endif
+#ifdef PADDLE_WITH_CUSTOM_DEVICE
+  else if (paddle::platform::is_custom_place(src.place())) {  // NOLINT
+    paddle::memory::Copy(
+        dst_place, dst_ptr, src.place(), src_ptr, size, nullptr);
+  }
+#endif
+  for (unsigned int i = 0; i < src.numel(); i++) {
+    (*dst)[i] = static_cast<bool>(array[i]);
+  }
+  delete[] array;
+}
+
+template void TensorToVector(const phi::DenseTensor& src,
+                             const phi::DeviceContext& ctx,
+                             std::vector<int16_t>* dst);
+
+template void TensorToVector(const phi::DenseTensor& src,
+                             const phi::DeviceContext& ctx,
+                             std::vector<int>* dst);
+
+template void TensorToVector(const phi::DenseTensor& src,
+                             const phi::DeviceContext& ctx,
+                             std::vector<int64_t>* dst);
+
+template void TensorToVector(const phi::DenseTensor& src,
+                             const phi::DeviceContext& ctx,
+                             std::vector<float>* dst);
+
+template void TensorToVector(const phi::DenseTensor& src,
+                             const phi::DeviceContext& ctx,
+                             std::vector<double>* dst);
+
+template void TensorToVector(const phi::DenseTensor& src,
+                             const phi::DeviceContext& ctx,
+                             std::vector<phi::dtype::float16>* dst);
+
+template void TensorToVector(const phi::DenseTensor& src,
+                             const phi::DeviceContext& ctx,
+                             std::vector<phi::dtype::bfloat16>* dst);
+
+template void TensorToVector(const phi::DenseTensor& src,
+                             const phi::DeviceContext& ctx,
+                             std::vector<phi::dtype::complex<float>>* dst);
+
+template void TensorToVector(const phi::DenseTensor& src,
+                             const phi::DeviceContext& ctx,
+                             std::vector<phi::dtype::complex<double>>* dst);
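One caveat with the context-taking overloads above: on GPU places the transfer is issued on the context's stream (see the reinterpret_cast to phi::GPUContext in the copy branches), so the call can return before the data has landed on the host. A minimal sketch of a safe device-to-host read, assuming a phi::GPUContext named dev_ctx and a GPU-resident DenseTensor t (illustrative names, not from the patch):

    std::vector<float> host_vals;
    phi::TensorToVector(t, dev_ctx, &host_vals);  // copy queued on dev_ctx's stream
    dev_ctx.Wait();  // synchronize before host_vals may be read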
+
+template <typename T>
+void TensorToVector(const phi::DenseTensor& src, std::vector<T>* dst) {
+  auto src_ptr = static_cast<const void*>(src.data<T>());
+  auto size = src.numel() * sizeof(T);
+
+  phi::CPUPlace dst_place{};
+  dst->resize(src.numel());
+  auto dst_ptr = static_cast<void*>(dst->data());
+
+  PADDLE_ENFORCE_EQ(
+      paddle::platform::is_cpu_place(src.place()),
+      true,
+      phi::errors::InvalidArgument(
+          "The input tensor should be CPU device, but actually it is in %s.",
+          src.place()));
+
+  paddle::memory::Copy(dst_place, dst_ptr, src.place(), src_ptr, size);
+}
+
+template <>
+void TensorToVector(const phi::DenseTensor& src, std::vector<bool>* dst) {
+  auto src_ptr = static_cast<const void*>(src.data<bool>());
+  auto size = src.numel() * sizeof(bool);
+
+  bool* array = new bool[src.numel()];
+
+  paddle::platform::CPUPlace dst_place{};
+  dst->resize(src.numel());
+  auto dst_ptr = static_cast<void*>(array);
+
+  PADDLE_ENFORCE_EQ(
+      paddle::platform::is_cpu_place(src.place()),
+      true,
+      phi::errors::InvalidArgument(
+          "The input tensor should be CPU device, but actually it is in %s.",
+          src.place()));
+
+  paddle::memory::Copy(dst_place, dst_ptr, src.place(), src_ptr, size);
+
+  for (unsigned int i = 0; i < src.numel(); i++) {
+    (*dst)[i] = static_cast<bool>(array[i]);
+  }
+  delete[] array;
+}
+
+template void TensorToVector(const phi::DenseTensor& src,
+                             std::vector<int16_t>* dst);
+
+template void TensorToVector(const phi::DenseTensor& src,
+                             std::vector<int>* dst);
+
+template void TensorToVector(const phi::DenseTensor& src,
+                             std::vector<int64_t>* dst);
+
+template void TensorToVector(const phi::DenseTensor& src,
+                             std::vector<float>* dst);
+
+template void TensorToVector(const phi::DenseTensor& src,
+                             std::vector<double>* dst);
+
+template void TensorToVector(const phi::DenseTensor& src,
+                             std::vector<phi::dtype::float16>* dst);
+
+template void TensorToVector(const phi::DenseTensor& src,
+                             std::vector<phi::dtype::bfloat16>* dst);
+
+template void TensorToVector(const phi::DenseTensor& src,
+                             std::vector<phi::dtype::complex<float>>* dst);
+
+template void TensorToVector(const phi::DenseTensor& src,
+                             std::vector<phi::dtype::complex<double>>* dst);
+
 }  // namespace phi
diff --git a/paddle/phi/core/tensor_utils.h b/paddle/phi/core/tensor_utils.h
index ceb46e2abe..fe0393c791 100644
--- a/paddle/phi/core/tensor_utils.h
+++ b/paddle/phi/core/tensor_utils.h
@@ -15,6 +15,7 @@ limitations under the License. */
 #pragma once
 
 #include "paddle/phi/core/dense_tensor.h"
+#include "paddle/phi/core/device_context.h"
 #include "paddle/phi/core/selected_rows.h"
 #include "paddle/phi/core/sparse_coo_tensor.h"
 #include "paddle/phi/core/sparse_csr_tensor.h"
@@ -109,4 +110,20 @@ void Copy(const Context& dev_ctx,
           bool blocking,
           SparseCsrTensor* dst);
 
+template <typename T>
+void TensorFromVector(const std::vector<T>& src,
+                      const phi::DeviceContext& ctx,
+                      phi::DenseTensor* dst);
+
+template <typename T>
+void TensorFromArray(const T* src,
+                     const size_t& array_size,
+                     const phi::DeviceContext& ctx,
+                     phi::DenseTensor* dst);
+
+template <typename T>
+void TensorToVector(const phi::DenseTensor& src,
+                    const phi::DeviceContext& ctx,
+                    std::vector<T>* dst);
+
 }  // namespace phi
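Taken together, the three declarations above are the phi replacements for the fluid helpers. A minimal round-trip sketch, assuming a CPU context (function and variable names are illustrative only):

    #include <vector>

    #include "paddle/phi/backends/cpu/cpu_context.h"
    #include "paddle/phi/core/dense_tensor.h"
    #include "paddle/phi/core/tensor_utils.h"

    void RoundTrip(const phi::CPUContext& dev_ctx) {
      // Host vector -> 1-D DenseTensor allocated on dev_ctx's place.
      std::vector<int64_t> shifts = {1, 2, 3};
      phi::DenseTensor t;
      phi::TensorFromVector(shifts, dev_ctx, &t);  // t.dims() == {3}

      // Raw host array -> DenseTensor.
      const float raw[4] = {0.f, 1.f, 2.f, 3.f};
      phi::DenseTensor u;
      phi::TensorFromArray(raw, 4, dev_ctx, &u);

      // DenseTensor -> host vector; the vector is resized to t.numel().
      std::vector<int64_t> back;
      phi::TensorToVector(t, dev_ctx, &back);
    }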
diff --git a/paddle/phi/kernels/assign_kernel.cc b/paddle/phi/kernels/assign_kernel.cc
index e9f2c7547d..2b9dfe66e2 100644
--- a/paddle/phi/kernels/assign_kernel.cc
+++ b/paddle/phi/kernels/assign_kernel.cc
@@ -14,7 +14,6 @@
 
 #include "paddle/phi/kernels/assign_kernel.h"
 
-#include "paddle/fluid/framework/tensor_util.h"
 #include "paddle/phi/core/kernel_registry.h"
 #include "paddle/phi/core/tensor_utils.h"
 #include "paddle/utils/optional.h"
@@ -25,7 +24,7 @@ template <typename Context>
 void AssignKernel(const Context& dev_ctx,
                   const DenseTensor& x,
                   DenseTensor* out) {
-  paddle::framework::TensorCopy(x, x.place(), out);
+  phi::Copy(dev_ctx, x, x.place(), false, out);
 }
 
 template <typename Context>
@@ -65,15 +64,14 @@ typename std::enable_if<std::is_same<T, bool>::value>::type CopyVectorToTensor(
   for (const auto& val : values) {
     assign_values.emplace_back(val.to<T>());
   }
-  paddle::framework::TensorFromVector(assign_values, dev_ctx, out);
+  phi::TensorFromVector(assign_values, dev_ctx, out);
 
   // use the array to replace to vector
   bool* array_ptr = new T[assign_values.size()];
   for (unsigned int i = 0; i < assign_values.size(); i++) {
     array_ptr[i] = static_cast<bool>(assign_values[i]);
   }
-  paddle::framework::TensorFromArray(
-      array_ptr, assign_values.size(), dev_ctx, out);
+  phi::TensorFromArray(array_ptr, assign_values.size(), dev_ctx, out);
   delete[] array_ptr;
 }
 
@@ -87,7 +85,7 @@ typename std::enable_if<!std::is_same<T, bool>::value>::type CopyVectorToTensor(
   for (const auto& val : values) {
     assign_values.emplace_back(val.to<T>());
   }
-  paddle::framework::TensorFromVector(assign_values, dev_ctx, out);
+  phi::TensorFromVector(assign_values, dev_ctx, out);
 }
 
 template <typename T, typename Context>
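The AssignKernel rewrite above is the "AssignKernel async copy" item from the commit message: in the phi::Copy signature shown earlier in this patch, the fourth parameter is the blocking flag, so passing false requests a non-blocking copy and matches the old TensorCopy behavior. Spelled out (sketch only):

    // phi::Copy(dev_ctx, src, dst_place, blocking, dst)
    phi::Copy(dev_ctx, x, x.place(), /*blocking=*/false, out);  // non-blocking assign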
diff --git a/paddle/phi/kernels/cpu/adam_kernel.cc b/paddle/phi/kernels/cpu/adam_kernel.cc
index 03a75bd361..c850a0d774 100644
--- a/paddle/phi/kernels/cpu/adam_kernel.cc
+++ b/paddle/phi/kernels/cpu/adam_kernel.cc
@@ -16,7 +16,6 @@
 
 #include <vector>
 
-#include "paddle/fluid/framework/tensor_util.h"
 #include "paddle/fluid/operators/jit/kernels.h"
 #include "paddle/phi/backends/cpu/cpu_context.h"
 #include "paddle/phi/core/kernel_registry.h"
@@ -61,7 +60,7 @@ void AdamDenseKernel(const Context& dev_ctx,
         errors::InvalidArgument("Input(SkipUpdate) size must be 1, but get %d",
                                 skip_update->numel()));
     std::vector<bool> skip_update_vec;
-    paddle::framework::TensorToVector(*skip_update, dev_ctx, &skip_update_vec);
+    phi::TensorToVector(*skip_update, dev_ctx, &skip_update_vec);
     skip_update_ = skip_update_vec[0];
   }
   // skip_update=true, just copy input to output, and TensorCopy will call
diff --git a/paddle/phi/kernels/cpu/adamw_kernel.cc b/paddle/phi/kernels/cpu/adamw_kernel.cc
index 9309213329..f04051c623 100644
--- a/paddle/phi/kernels/cpu/adamw_kernel.cc
+++ b/paddle/phi/kernels/cpu/adamw_kernel.cc
@@ -16,11 +16,11 @@
 
 #include <vector>
 
-#include "paddle/fluid/framework/tensor_util.h"
 #include "paddle/fluid/operators/jit/kernels.h"
 #include "paddle/phi/backends/cpu/cpu_context.h"
 #include "paddle/phi/common/float16.h"
 #include "paddle/phi/core/kernel_registry.h"
+#include "paddle/phi/core/tensor_utils.h"
 #include "paddle/phi/kernels/adam_kernel.h"
 #include "paddle/phi/kernels/funcs/adam_functors.h"
 
@@ -61,7 +61,7 @@ void AdamwDenseKernel(const Context& dev_ctx,
         errors::InvalidArgument("Input(SkipUpdate) size must be 1, but get %d",
                                 skip_update->numel()));
     std::vector<bool> skip_update_vec;
-    paddle::framework::TensorToVector(*skip_update, dev_ctx, &skip_update_vec);
+    phi::TensorToVector(*skip_update, dev_ctx, &skip_update_vec);
     skip_update_ = skip_update_vec[0];
   }
   VLOG(3) << "Skip update" << skip_update_;
diff --git a/paddle/phi/kernels/cpu/cross_grad_kernel.cc b/paddle/phi/kernels/cpu/cross_grad_kernel.cc
index af573cfacf..21c63b28b3 100644
--- a/paddle/phi/kernels/cpu/cross_grad_kernel.cc
+++ b/paddle/phi/kernels/cpu/cross_grad_kernel.cc
@@ -14,10 +14,10 @@
 
 #include "paddle/phi/kernels/cross_grad_kernel.h"
 
-#include "paddle/fluid/framework/tensor_util.h"
 #include "paddle/phi/backends/cpu/cpu_context.h"
 #include "paddle/phi/core/dense_tensor.h"
 #include "paddle/phi/core/kernel_registry.h"
+#include "paddle/phi/core/tensor_utils.h"
 
 namespace phi {
 
@@ -82,9 +82,9 @@ void CrossGradKernel(const Context &dev_ctx,
   }
 
   std::vector<T> input_x_vec, input_y_vec, input_dout_vec;
-  paddle::framework::TensorToVector(input_x, dev_ctx, &input_x_vec);
-  paddle::framework::TensorToVector(input_y, dev_ctx, &input_y_vec);
-  paddle::framework::TensorToVector(input_out_grad, dev_ctx, &input_dout_vec);
+  phi::TensorToVector(input_x, dev_ctx, &input_x_vec);
+  phi::TensorToVector(input_y, dev_ctx, &input_y_vec);
+  phi::TensorToVector(input_out_grad, dev_ctx, &input_dout_vec);
   std::vector<T> out_dx_vec(output_x_grad->numel());
   std::vector<T> out_dy_vec(output_y_grad->numel());
 
@@ -106,8 +106,8 @@ void CrossGradKernel(const Context &dev_ctx,
       }
     }
   }
-  paddle::framework::TensorFromVector(out_dx_vec, dev_ctx, output_x_grad);
-  paddle::framework::TensorFromVector(out_dy_vec, dev_ctx, output_y_grad);
+  phi::TensorFromVector(out_dx_vec, dev_ctx, output_x_grad);
+  phi::TensorFromVector(out_dy_vec, dev_ctx, output_y_grad);
   output_x_grad->Resize(input_x_dims);
   output_y_grad->Resize(input_x_dims);
 }
diff --git a/paddle/phi/kernels/cpu/cross_kernel.cc b/paddle/phi/kernels/cpu/cross_kernel.cc
index a321617dea..a37efa2d3c 100644
--- a/paddle/phi/kernels/cpu/cross_kernel.cc
+++ b/paddle/phi/kernels/cpu/cross_kernel.cc
@@ -81,8 +81,8 @@ void CrossKernel(const Context& dev_ctx,
   }
 
   std::vector<T> input_x_vec, input_y_vec;
-  paddle::framework::TensorToVector(input_x, dev_ctx, &input_x_vec);
-  paddle::framework::TensorToVector(input_y, dev_ctx, &input_y_vec);
+  phi::TensorToVector(input_x, dev_ctx, &input_x_vec);
+  phi::TensorToVector(input_y, dev_ctx, &input_y_vec);
   std::vector<T> out_vec(output->numel());
 
   dev_ctx.template Alloc<T>(output);
@@ -100,7 +100,7 @@ void CrossKernel(const Context& dev_ctx,
       }
     }
   }
-  paddle::framework::TensorFromVector(out_vec, dev_ctx, output);
+  phi::TensorFromVector(out_vec, dev_ctx, output);
   output->Resize(input_x_dims);
 }
 
diff --git a/paddle/phi/kernels/cpu/index_sample_grad_kernel.cc b/paddle/phi/kernels/cpu/index_sample_grad_kernel.cc
index 42aef3cc24..808173e892 100644
--- a/paddle/phi/kernels/cpu/index_sample_grad_kernel.cc
+++ b/paddle/phi/kernels/cpu/index_sample_grad_kernel.cc
@@ -14,10 +14,10 @@
 
 #include "paddle/phi/kernels/index_sample_grad_kernel.h"
 
-#include "paddle/fluid/framework/tensor_util.h"
 #include "paddle/phi/backends/cpu/cpu_context.h"
 #include "paddle/phi/common/data_type.h"
 #include "paddle/phi/core/kernel_registry.h"
+#include "paddle/phi/core/tensor_utils.h"
 #include "paddle/phi/core/utils/data_type.h"
 namespace phi {
 template <typename T, typename Context, typename IndexT = int>
@@ -27,8 +27,8 @@ void IndexSampleGradInner(const Context& context,
                           DenseTensor* x_grad) {
   std::vector<T> out_grad_vec;
   std::vector<IndexT> index_vec;
-  paddle::framework::TensorToVector(out_grad, context, &out_grad_vec);
-  paddle::framework::TensorToVector(index, context, &index_vec);
+  phi::TensorToVector(out_grad, context, &out_grad_vec);
+  phi::TensorToVector(index, context, &index_vec);
 
   auto index_dims = index.dims();
   auto x_grad_dims = x_grad->dims();
@@ -63,7 +63,7 @@ void IndexSampleGradInner(const Context& context,
     x_grad_vec[v_i] += out_grad_vec[i];
   }
   context.template Alloc<T>(x_grad);
-  paddle::framework::TensorFromVector(x_grad_vec, context, x_grad);
+  phi::TensorFromVector(x_grad_vec, context, x_grad);
   x_grad->Resize(x_grad_dims);
 }
 
diff --git a/paddle/phi/kernels/cpu/index_sample_kernel.cc b/paddle/phi/kernels/cpu/index_sample_kernel.cc
index e51d06c442..de37392849 100644
--- a/paddle/phi/kernels/cpu/index_sample_kernel.cc
+++ b/paddle/phi/kernels/cpu/index_sample_kernel.cc
@@ -21,10 +21,10 @@
 #include <utility>
 #include <vector>
 
-#include "paddle/fluid/framework/tensor_util.h"
 #include "paddle/phi/backends/cpu/cpu_context.h"
 #include "paddle/phi/common/data_type.h"
 #include "paddle/phi/core/kernel_registry.h"
+#include "paddle/phi/core/tensor_utils.h"
 #include "paddle/phi/core/utils/data_type.h"
 namespace phi {
 template <typename T, typename Context, typename IndexT = int>
@@ -42,8 +42,8 @@ void IndexSampleInner(const Context &context,
 
   std::vector<T> input_vec;
   std::vector<IndexT> index_vec;
-  paddle::framework::TensorToVector(input, context, &input_vec);
-  paddle::framework::TensorToVector(index, context, &index_vec);
+  phi::TensorToVector(input, context, &input_vec);
+  phi::TensorToVector(index, context, &index_vec);
 
   std::vector<T> res(index_ids_num);
   for (int i = 0; i < index_ids_num; i++) {
@@ -76,7 +76,7 @@ void IndexSampleInner(const Context &context,
 
   auto ddim = phi::make_ddim({batch_size, index_length});
   context.template Alloc<T>(output);
-  paddle::framework::TensorFromVector(res, context, output);
+  phi::TensorFromVector(res, context, output);
   output->Resize(ddim);
 }
 
diff --git a/paddle/phi/kernels/cpu/repeat_interleave_grad_kernel.cc b/paddle/phi/kernels/cpu/repeat_interleave_grad_kernel.cc
index 8f4af6a82c..d37647e72c 100644
--- a/paddle/phi/kernels/cpu/repeat_interleave_grad_kernel.cc
+++ b/paddle/phi/kernels/cpu/repeat_interleave_grad_kernel.cc
@@ -94,7 +94,7 @@ void RepeatInterleaveGradKernel(const Context& ctx,
       std::fill_n(index_vec.begin() + i * repeats, repeats, i);
     }
     index.Resize(phi::make_ddim({index_size}));
-    paddle::framework::TensorFromVector(index_vec, &index);
+    phi::TensorFromVector(index_vec, ctx, &index);
 
     const DenseTensor index_copy = index;
     IndexSelectGradInner<Context, T, int>(ctx, out_grad, index_copy, x_grad, dim);
diff --git a/paddle/phi/kernels/cpu/roll_grad_kernel.cc b/paddle/phi/kernels/cpu/roll_grad_kernel.cc
index b3bd27fca1..c348bfe300 100644
--- a/paddle/phi/kernels/cpu/roll_grad_kernel.cc
+++ b/paddle/phi/kernels/cpu/roll_grad_kernel.cc
@@ -14,8 +14,8 @@
 
 #include "paddle/phi/kernels/roll_grad_kernel.h"
 
-#include "paddle/fluid/framework/tensor_util.h"
 #include "paddle/phi/core/kernel_registry.h"
+#include "paddle/phi/core/tensor_utils.h"
 #include "paddle/phi/kernels/cpu/roll_kernel_impl.h"
 
 namespace phi {
@@ -28,7 +28,7 @@ void RollGradKernel(const Context& dev_ctx,
                     const std::vector<int64_t>& axis,
                     DenseTensor* x_grad) {
   std::vector<T> out_vec;
-  paddle::framework::TensorToVector(out_grad, dev_ctx, &out_vec);
+  phi::TensorToVector(out_grad, dev_ctx, &out_vec);
 
   auto shifts_data = shifts.GetData();
   size_t nums = shifts_data.size();
@@ -46,7 +46,7 @@ void RollGradKernel(const Context& dev_ctx,
   }
 
   dev_ctx.template Alloc<T>(x_grad);
-  paddle::framework::TensorFromVector(out_vec, dev_ctx, x_grad);
+  phi::TensorFromVector(out_vec, dev_ctx, x_grad);
   x_grad->Resize(out_grad.dims());
 }
 
diff --git a/paddle/phi/kernels/cpu/roll_kernel.cc b/paddle/phi/kernels/cpu/roll_kernel.cc
index 67eb80304d..a1c4b24117 100644
--- a/paddle/phi/kernels/cpu/roll_kernel.cc
+++ b/paddle/phi/kernels/cpu/roll_kernel.cc
@@ -14,10 +14,10 @@
 
 #include "paddle/phi/kernels/roll_kernel.h"
 
-#include "paddle/fluid/framework/tensor_util.h"
 #include "paddle/phi/common/complex.h"
 #include "paddle/phi/core/dense_tensor.h"
 #include "paddle/phi/core/kernel_registry.h"
+#include "paddle/phi/core/tensor_utils.h"
 #include "paddle/phi/kernels/cpu/roll_kernel_impl.h"
 
 namespace phi {
@@ -29,7 +29,7 @@ void RollKernel(const Context& dev_ctx,
                 const std::vector<int64_t>& axis,
                 DenseTensor* out) {
   std::vector<T> out_vec;
-  paddle::framework::TensorToVector(x, dev_ctx, &out_vec);
+  phi::TensorToVector(x, dev_ctx, &out_vec);
 
   auto shifts_data = shifts.GetData();
   size_t nums = shifts_data.size();
@@ -57,7 +57,7 @@ void RollKernel(const Context& dev_ctx,
     ShiftAlongDim(out_vec.data(), input_dim, dims[i], shifts_data[i]);
   }
   dev_ctx.template Alloc<T>(out);
-  paddle::framework::TensorFromVector(out_vec, dev_ctx, out);
+  phi::TensorFromVector(out_vec, dev_ctx, out);
   out->Resize(x.dims());
 }
 
"paddle/phi/kernels/funcs/unique_functor.h" @@ -210,10 +209,10 @@ static void UniqueConsecutiveDim(const Context& context, phi::funcs::TransCompute( out_trans.dims().size(), context, out_trans, out, permute); if (return_inverse) { - paddle::framework::TensorFromVector(inverse_vec, context, inverse); + phi::TensorFromVector(inverse_vec, context, inverse); } if (return_counts) { - paddle::framework::TensorFromVector(counts_vec, context, count); + phi::TensorFromVector(counts_vec, context, count); } } diff --git a/paddle/phi/kernels/funcs/repeat_tensor2index_tensor.h b/paddle/phi/kernels/funcs/repeat_tensor2index_tensor.h index 545ecb660f..8b2732bc0f 100644 --- a/paddle/phi/kernels/funcs/repeat_tensor2index_tensor.h +++ b/paddle/phi/kernels/funcs/repeat_tensor2index_tensor.h @@ -42,7 +42,7 @@ void RepeatsTensor2IndexTensor(const Context& ctx, } index->Resize(phi::make_ddim({index_size})); - paddle::framework::TensorFromVector(index_vec, ctx, index); + phi::TensorFromVector(index_vec, ctx, index); } } // namespace funcs } // namespace phi diff --git a/paddle/phi/kernels/funcs/unique_functor.h b/paddle/phi/kernels/funcs/unique_functor.h index edd3935ef7..510236e278 100644 --- a/paddle/phi/kernels/funcs/unique_functor.h +++ b/paddle/phi/kernels/funcs/unique_functor.h @@ -312,15 +312,15 @@ static void UniqueDim(const Context& context, out_trans.dims().size(), context, out_trans, out, permute); if (return_inverse) { - paddle::framework::TensorFromVector(inverse_vec, context, index); + phi::TensorFromVector(inverse_vec, context, index); } if (return_counts) { - paddle::framework::TensorFromVector(counts_vec, context, count); + phi::TensorFromVector(counts_vec, context, count); } if (return_index) { - paddle::framework::TensorFromVector(indices_vec, context, indices); + phi::TensorFromVector(indices_vec, context, indices); } } diff --git a/paddle/phi/kernels/gpu/adam_kernel.cu b/paddle/phi/kernels/gpu/adam_kernel.cu index c4c9ff9e06..7da864de8b 100644 --- a/paddle/phi/kernels/gpu/adam_kernel.cu +++ b/paddle/phi/kernels/gpu/adam_kernel.cu @@ -18,7 +18,6 @@ #include -#include "paddle/fluid/framework/tensor_util.h" #include "paddle/phi/backends/gpu/gpu_context.h" #include "paddle/phi/common/amp_type_traits.h" #include "paddle/phi/common/float16.h" @@ -162,7 +161,7 @@ void AdamDenseKernel(const Context& dev_ctx, errors::InvalidArgument("Input(SkipUpdate) size must be 1, but get %d", skip_update->numel())); std::vector skip_update_vec; - paddle::framework::TensorToVector(*skip_update, dev_ctx, &skip_update_vec); + phi::TensorToVector(*skip_update, dev_ctx, &skip_update_vec); skip_update_ = skip_update_vec[0]; } // skip_update=true, just copy input to output, and TensorCopy will call diff --git a/paddle/phi/kernels/gpu/adamw_kernel.cu b/paddle/phi/kernels/gpu/adamw_kernel.cu index 2252deb1da..29e9b984e3 100644 --- a/paddle/phi/kernels/gpu/adamw_kernel.cu +++ b/paddle/phi/kernels/gpu/adamw_kernel.cu @@ -18,7 +18,6 @@ #include -#include "paddle/fluid/framework/tensor_util.h" #include "paddle/phi/backends/gpu/gpu_context.h" #include "paddle/phi/common/amp_type_traits.h" #include "paddle/phi/common/bfloat16.h" @@ -181,7 +180,7 @@ void AdamwDenseKernel(const Context& dev_ctx, errors::InvalidArgument("Input(SkipUpdate) size must be 1, but get %d", skip_update->numel())); std::vector skip_update_vec; - paddle::framework::TensorToVector(*skip_update, dev_ctx, &skip_update_vec); + phi::TensorToVector(*skip_update, dev_ctx, &skip_update_vec); skip_update_ = skip_update_vec[0]; } diff --git 
diff --git a/paddle/phi/kernels/gpu/class_center_sample_kernel.cu b/paddle/phi/kernels/gpu/class_center_sample_kernel.cu
index c6fae6ee9e..a98fdfaa8f 100644
--- a/paddle/phi/kernels/gpu/class_center_sample_kernel.cu
+++ b/paddle/phi/kernels/gpu/class_center_sample_kernel.cu
@@ -31,6 +31,7 @@ namespace cub = hipcub;
 
 #include "paddle/fluid/framework/tensor_util.h"
 #include "paddle/phi/core/enforce.h"
+#include "paddle/phi/core/tensor_utils.h"
 
 #if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
 #include "paddle/fluid/distributed/collective/process_group.h"
@@ -344,8 +345,7 @@ void ClassCenterSampleKernel(const Context& dev_ctx,
     std::vector<T> shard_dim_vec(nranks + 1, 0);
     shard_dim_vec[rank + 1] = num_classes;
     DenseTensor num_classes_per_device;
-    paddle::framework::TensorFromVector(
-        shard_dim_vec, dev_ctx, &num_classes_per_device);
+    phi::TensorFromVector(shard_dim_vec, dev_ctx, &num_classes_per_device);
     T* num_classes_per_device_ptr = num_classes_per_device.data<T>();
 
 #if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
diff --git a/paddle/phi/kernels/gpu/diagonal_grad_kernel.cu b/paddle/phi/kernels/gpu/diagonal_grad_kernel.cu
index a65d9af75f..bac9a297b5 100644
--- a/paddle/phi/kernels/gpu/diagonal_grad_kernel.cu
+++ b/paddle/phi/kernels/gpu/diagonal_grad_kernel.cu
@@ -14,9 +14,9 @@
 
 #include "paddle/phi/kernels/diagonal_grad_kernel.h"
 
-#include "paddle/fluid/framework/tensor_util.h"
 #include "paddle/phi/backends/gpu/gpu_primitives.h"
 #include "paddle/phi/core/kernel_registry.h"
+#include "paddle/phi/core/tensor_utils.h"
 #include "paddle/phi/kernels/funcs/diagonal.h"
 
 namespace phi {
@@ -38,8 +38,7 @@ void DiagonalGradKernel(const Context& dev_ctx,
 
   std::vector<int64_t> res_dout = vectorize(phi::stride(dout->dims()));
   DenseTensor dout_stride_tensor;
-  paddle::framework::TensorFromVector<int64_t>(
-      res_dout, dev_ctx, &dout_stride_tensor);
+  phi::TensorFromVector<int64_t>(res_dout, dev_ctx, &dout_stride_tensor);
   int64_t* dout_stride = dout_stride_tensor.data<int64_t>();
 
   auto* dx = in_grad;
@@ -49,8 +48,7 @@ void DiagonalGradKernel(const Context& dev_ctx,
 
   std::vector<int64_t> res_dx = vectorize(phi::stride(dx->dims()));
   DenseTensor dx_stride_tensor;
-  paddle::framework::TensorFromVector<int64_t>(
-      res_dx, dev_ctx, &dx_stride_tensor);
+  phi::TensorFromVector<int64_t>(res_dx, dev_ctx, &dx_stride_tensor);
   int64_t* dx_stride = dx_stride_tensor.data<int64_t>();
 
   const int64_t offset_ = offset;
diff --git a/paddle/phi/kernels/gpu/diagonal_kernel.cu b/paddle/phi/kernels/gpu/diagonal_kernel.cu
index 74e7db258c..2acc527e9b 100644
--- a/paddle/phi/kernels/gpu/diagonal_kernel.cu
+++ b/paddle/phi/kernels/gpu/diagonal_kernel.cu
@@ -14,9 +14,9 @@
 
 #include "paddle/phi/kernels/diagonal_kernel.h"
 
-#include "paddle/fluid/framework/tensor_util.h"
 #include "paddle/phi/backends/gpu/gpu_primitives.h"
 #include "paddle/phi/core/kernel_registry.h"
+#include "paddle/phi/core/tensor_utils.h"
 #include "paddle/phi/kernels/funcs/diagonal.h"
 
 namespace phi {
@@ -35,8 +35,7 @@ void DiagonalKernel(const Context& dev_ctx,
 
   std::vector<int64_t> res_in = vectorize(phi::stride(input->dims()));
   DenseTensor input_stride_tensor;
-  paddle::framework::TensorFromVector<int64_t>(
-      res_in, dev_ctx, &input_stride_tensor);
+  phi::TensorFromVector<int64_t>(res_in, dev_ctx, &input_stride_tensor);
   int64_t* input_stride = input_stride_tensor.data<int64_t>();
 
   auto* output = out;
@@ -46,8 +45,7 @@ void DiagonalKernel(const Context& dev_ctx,
 
   std::vector<int64_t> res_out = vectorize(phi::stride(output->dims()));
   DenseTensor output_stride_tensor;
-  paddle::framework::TensorFromVector<int64_t>(
-      res_out, dev_ctx, &output_stride_tensor);
+  phi::TensorFromVector<int64_t>(res_out, dev_ctx, &output_stride_tensor);
   int64_t* output_stride = output_stride_tensor.data<int64_t>();
 
   const int64_t offset_ = offset;
diff --git a/paddle/phi/kernels/gpu/margin_cross_entropy_grad_kernel.cu b/paddle/phi/kernels/gpu/margin_cross_entropy_grad_kernel.cu
index e3db3ee70c..5913031091 100644
--- a/paddle/phi/kernels/gpu/margin_cross_entropy_grad_kernel.cu
+++ b/paddle/phi/kernels/gpu/margin_cross_entropy_grad_kernel.cu
@@ -62,13 +62,12 @@ void GetClassInterval(const gpuStream_t& stream,
   std::vector<int> shard_dim_vec(nranks + 1, 0);
   shard_dim_vec[rank + 1] = D;
   if (nranks <= 1) {
-    paddle::framework::TensorFromVector(shard_dim_vec, dev_ctx, class_interval);
+    phi::TensorFromVector(shard_dim_vec, dev_ctx, class_interval);
     return;
   }
 #if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
   DenseTensor num_classes_per_device;
-  paddle::framework::TensorFromVector(
-      shard_dim_vec, dev_ctx, &num_classes_per_device);
+  phi::TensorFromVector(shard_dim_vec, dev_ctx, &num_classes_per_device);
   int* num_classes_per_device_ptr = num_classes_per_device.data<int>();
 
   auto map = paddle::distributed::ProcessGroupMapFromGid::getInstance();
diff --git a/paddle/phi/kernels/gpu/margin_cross_entropy_kernel.cu b/paddle/phi/kernels/gpu/margin_cross_entropy_kernel.cu
index a0674ae5e7..122e4a6b99 100644
--- a/paddle/phi/kernels/gpu/margin_cross_entropy_kernel.cu
+++ b/paddle/phi/kernels/gpu/margin_cross_entropy_kernel.cu
@@ -57,14 +57,13 @@ void GetClassInterval(const gpuStream_t& stream,
   std::vector<int> shard_dim_vec(nranks + 1, 0);
   shard_dim_vec[rank + 1] = D;
   if (nranks <= 1) {
-    paddle::framework::TensorFromVector(shard_dim_vec, dev_ctx, class_interval);
+    phi::TensorFromVector(shard_dim_vec, dev_ctx, class_interval);
     return;
   }
 
 #if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
   DenseTensor num_classes_per_device;
-  paddle::framework::TensorFromVector(
-      shard_dim_vec, dev_ctx, &num_classes_per_device);
+  phi::TensorFromVector(shard_dim_vec, dev_ctx, &num_classes_per_device);
   int* num_classes_per_device_ptr = num_classes_per_device.data<int>();
 
   auto map = paddle::distributed::ProcessGroupMapFromGid::getInstance();
diff --git a/paddle/phi/kernels/gpu/prior_box_kernel.cu b/paddle/phi/kernels/gpu/prior_box_kernel.cu
index 317f2a3231..b54ec6abbc 100644
--- a/paddle/phi/kernels/gpu/prior_box_kernel.cu
+++ b/paddle/phi/kernels/gpu/prior_box_kernel.cu
@@ -17,10 +17,10 @@
 #include <algorithm>
 #include <vector>
 
-#include "paddle/fluid/framework/tensor_util.h"
 #include "paddle/phi/backends/gpu/gpu_context.h"
 #include "paddle/phi/core/dense_tensor.h"
 #include "paddle/phi/core/kernel_registry.h"
+#include "paddle/phi/core/tensor_utils.h"
 
 namespace phi {
 
@@ -160,15 +160,15 @@ void PriorBoxKernel(const Context& ctx,
   ctx.template Alloc<T>(var);
 
   DenseTensor r;
-  paddle::framework::TensorFromVector(new_aspect_ratios, ctx, &r);
+  phi::TensorFromVector(new_aspect_ratios, ctx, &r);
 
   DenseTensor min;
-  paddle::framework::TensorFromVector(min_sizes, ctx, &min);
+  phi::TensorFromVector(min_sizes, ctx, &min);
 
   T* max_data = nullptr;
   DenseTensor max;
   if (max_sizes.size() > 0) {
-    paddle::framework::TensorFromVector(max_sizes, ctx, &max);
+    phi::TensorFromVector(max_sizes, ctx, &max);
     max_data = max.data<T>();
   }
 
@@ -189,7 +189,7 @@ void PriorBoxKernel(const Context& ctx,
       min_max_aspect_ratios_order);
 
   DenseTensor v;
-  paddle::framework::TensorFromVector(variances, ctx, &v);
+  phi::TensorFromVector(variances, ctx, &v);
   grid = (box_num * 4 + block - 1) / block;
   SetVariance<T><<<grid, block, 0, stream>>>(
       var->data<T>(), v.data<T>(), variances.size(), box_num * 4);
diff --git a/paddle/phi/kernels/gpu/unique_consecutive_functor.h b/paddle/phi/kernels/gpu/unique_consecutive_functor.h
index e603f69503..d70813c84a 100644
--- a/paddle/phi/kernels/gpu/unique_consecutive_functor.h
+++ b/paddle/phi/kernels/gpu/unique_consecutive_functor.h
@@ -24,8 +24,6 @@
 #include <thrust/unique.h>
 #include <vector>
 
-#include "paddle/fluid/framework/tensor_util.h"
-
 #include "paddle/phi/core/dense_tensor.h"
 #include "paddle/phi/core/tensor_utils.h"
 #include "paddle/phi/kernels/funcs/concat_and_split_functor.h"
@@ -293,8 +291,8 @@ void IndexSelect(const Context& context,
 
   std::vector<T> input_vec;
   std::vector<IndexT> index_vec;
-  paddle::framework::TensorToVector(input, context, &input_vec);
-  paddle::framework::TensorToVector(index, context, &index_vec);
+  phi::TensorToVector(input, context, &input_vec);
+  phi::TensorToVector(index, context, &index_vec);
   std::vector<T> out_vec(output->numel());
 
   for (int i = 0; i < index_size; i++) {
@@ -331,7 +329,7 @@ void IndexSelect(const Context& context,
     }
   }
   context.template Alloc<T>(output);
-  paddle::framework::TensorFromVector(out_vec, context, output);
+  phi::TensorFromVector(out_vec, context, output);
   output->Resize(output_dim);
 }
 
diff --git a/paddle/phi/kernels/gpu/unique_kernel.cu b/paddle/phi/kernels/gpu/unique_kernel.cu
index d420c8f438..c073708ed8 100644
--- a/paddle/phi/kernels/gpu/unique_kernel.cu
+++ b/paddle/phi/kernels/gpu/unique_kernel.cu
@@ -26,7 +26,6 @@
 #include <thrust/unique.h>
 #include <vector>
 
-#include "paddle/fluid/framework/tensor_util.h"  // TensorToVector()
 #include "paddle/phi/backends/gpu/gpu_context.h"
 #include "paddle/phi/core/kernel_registry.h"
 #include "paddle/phi/core/tensor_utils.h"
@@ -127,8 +126,8 @@ void IndexSelect(const Context& context,
 
   std::vector<T> input_vec;
   std::vector<IndexT> index_vec;
-  paddle::framework::TensorToVector(input, context, &input_vec);
-  paddle::framework::TensorToVector(index, context, &index_vec);
+  phi::TensorToVector(input, context, &input_vec);
+  phi::TensorToVector(index, context, &index_vec);
   std::vector<T> out_vec(output->numel());
 
   for (int i = 0; i < index_size; i++) {
@@ -165,7 +164,7 @@ void IndexSelect(const Context& context,
     }
   }
   context.template Alloc<T>(output);
-  paddle::framework::TensorFromVector(out_vec, context, output);
+  phi::TensorFromVector(out_vec, context, output);
   output->Resize(output_dim);
 }
 
diff --git a/paddle/phi/kernels/impl/determinant_kernel_impl.h b/paddle/phi/kernels/impl/determinant_kernel_impl.h
index 18fb152b28..5c7a16045c 100644
--- a/paddle/phi/kernels/impl/determinant_kernel_impl.h
+++ b/paddle/phi/kernels/impl/determinant_kernel_impl.h
@@ -20,8 +20,8 @@
 #include <string>
 #include <vector>
 
-#include "paddle/fluid/framework/tensor_util.h"
 #include "paddle/phi/core/enforce.h"
+#include "paddle/phi/core/tensor_utils.h"
 #include "paddle/phi/kernels/determinant_kernel.h"
 
 namespace phi {
@@ -71,7 +71,7 @@ struct DeterminantFunctor {
                   DenseTensor* output) {
     std::vector<T> input_vec;
     std::vector<T> output_vec;
-    paddle::framework::TensorToVector(input, dev_ctx, &input_vec);
+    phi::TensorToVector(input, dev_ctx, &input_vec);
    for (int64_t i = 0; i < batch_count; ++i) {  // maybe can be parallel
       auto begin_iter = input_vec.begin() + i * rank * rank;
       auto end_iter = input_vec.begin() + (i + 1) * rank * rank;
@@ -85,7 +85,7 @@ struct DeterminantFunctor {
       }
       output_vec.push_back(matrix.determinant());
     }
-    paddle::framework::TensorFromVector(output_vec, output);
+    phi::TensorFromVector(output_vec, dev_ctx, output);
   }
 };
 
diff --git a/paddle/phi/kernels/impl/repeat_interleave_grad_kernel_impl.h b/paddle/phi/kernels/impl/repeat_interleave_grad_kernel_impl.h
index 5e90028527..feb10d08d4 100644
--- a/paddle/phi/kernels/impl/repeat_interleave_grad_kernel_impl.h
+++ b/paddle/phi/kernels/impl/repeat_interleave_grad_kernel_impl.h
@@ -205,7 +205,7 @@ void RepeatInterleaveGradKernel(const Context& ctx,
       std::fill_n(index_vec.begin() + i * repeats, repeats, i);
     }
     index.Resize(phi::make_ddim({index_size}));
-    paddle::framework::TensorFromVector(index_vec, ctx, &index);
+    phi::TensorFromVector(index_vec, ctx, &index);
 
     const int* index_data = index.data<int>();
     int64_t index_nums = index.numel();
diff --git a/paddle/phi/kernels/impl/repeat_interleave_kernel_impl.h b/paddle/phi/kernels/impl/repeat_interleave_kernel_impl.h
index 8548785aac..d8a65afaf2 100644
--- a/paddle/phi/kernels/impl/repeat_interleave_kernel_impl.h
+++ b/paddle/phi/kernels/impl/repeat_interleave_kernel_impl.h
@@ -75,7 +75,7 @@ void RepeatInterleaveKernel(const Context& ctx,
   index.Resize(phi::make_ddim({index_size}));
   if (place == cpu_place) {
     DenseTensor x_copy = x;
-    paddle::framework::TensorFromVector(index_vec, &index);
+    phi::TensorFromVector(index_vec, ctx, &index);
 
     auto output_dim = phi::vectorize(x.dims());
     output_dim[dim] = index_size;
@@ -85,7 +85,7 @@ void RepeatInterleaveKernel(const Context& ctx,
   } else {
     auto stride_dim = phi::stride(input_dim);
     int64_t stride = stride_dim[dim];
-    paddle::framework::TensorFromVector(index_vec, ctx, &index);
+    phi::TensorFromVector(index_vec, ctx, &index);
     auto stream = ctx.stream();
     auto output_dim = phi::vectorize(x.dims());
     output_dim[dim] = index_size;
diff --git a/paddle/phi/kernels/impl/set_value_kernel_impl.h b/paddle/phi/kernels/impl/set_value_kernel_impl.h
index a0f594e9d5..d92fcfd32a 100644
--- a/paddle/phi/kernels/impl/set_value_kernel_impl.h
+++ b/paddle/phi/kernels/impl/set_value_kernel_impl.h
@@ -316,7 +316,7 @@ void SetValueKernel(const Context& dev_ctx,
     assgin_values.push_back(val.to<T>());
   }
   DenseTensor value_tensor = Empty<T>(dev_ctx, shape);
-  paddle::framework::TensorFromVector(assgin_values, dev_ctx, &value_tensor);
+  phi::TensorFromVector(assgin_values, dev_ctx, &value_tensor);
   value_tensor.Resize(phi::make_ddim(shape));
 
   SetTensorValueKernel<T, Context>(dev_ctx,
diff --git a/paddle/phi/kernels/impl/slogdeterminant_kernel_impl.h b/paddle/phi/kernels/impl/slogdeterminant_kernel_impl.h
index a6e5060385..6f590f6246 100644
--- a/paddle/phi/kernels/impl/slogdeterminant_kernel_impl.h
+++ b/paddle/phi/kernels/impl/slogdeterminant_kernel_impl.h
@@ -18,8 +18,8 @@
 #include <string>
 #include <vector>
 
-#include "paddle/fluid/framework/tensor_util.h"
 #include "paddle/phi/core/enforce.h"
+#include "paddle/phi/core/tensor_utils.h"
 #include "paddle/phi/kernels/impl/determinant_kernel_impl.h"
 #include "paddle/phi/kernels/slogdeterminant_kernel.h"
 
 namespace phi {
@@ -41,7 +41,7 @@ struct SlogDeterminantFunctor {
     std::vector<T> sign_vec;
     std::vector<T> log_vec;
     std::vector<T> output_vec;
-    paddle::framework::TensorToVector(input, dev_ctx, &input_vec);
+    phi::TensorToVector(input, dev_ctx, &input_vec);
     for (int64_t i = 0; i < batch_count; ++i) {  // maybe can be parallel
       auto begin_iter = input_vec.begin() + i * rank * rank;
       auto end_iter = input_vec.begin() + (i + 1) * rank * rank;
@@ -65,7 +65,7 @@ struct SlogDeterminantFunctor {
     // merge sign_vec and log_vec as final output_vec
     output_vec.insert(output_vec.end(), sign_vec.begin(), sign_vec.end());
     output_vec.insert(output_vec.end(), log_vec.begin(), log_vec.end());
-    paddle::framework::TensorFromVector(output_vec, output);
+    phi::TensorFromVector(output_vec, dev_ctx, output);
   }
 };
 
diff --git a/paddle/phi/kernels/selected_rows/cpu/adam_kernel.cc b/paddle/phi/kernels/selected_rows/cpu/adam_kernel.cc
index de8b4eae4f..b58bcd0258 100644
--- a/paddle/phi/kernels/selected_rows/cpu/adam_kernel.cc
+++ b/paddle/phi/kernels/selected_rows/cpu/adam_kernel.cc
@@ -14,7 +14,6 @@
 
 #include "paddle/phi/kernels/selected_rows/adam_kernel.h"
 
-#include "paddle/fluid/framework/tensor_util.h"
 #include "paddle/phi/backends/cpu/cpu_context.h"
 #include "paddle/phi/core/kernel_registry.h"
 #include "paddle/phi/core/tensor_utils.h"
@@ -60,7 +59,7 @@ void AdamDenseParamSparseGradKernel(
         errors::InvalidArgument("Input(SkipUpdate) size must be 1, but get %d",
                                 skip_update->numel()));
     std::vector<bool> skip_update_vec;
-    paddle::framework::TensorToVector(*skip_update, dev_ctx, &skip_update_vec);
+    phi::TensorToVector(*skip_update, dev_ctx, &skip_update_vec);
     skip_update_ = skip_update_vec[0];
   }
   // skip_update=true, just copy input to output, and TensorCopy will call
diff --git a/paddle/phi/kernels/selected_rows/cpu/adamw_kernel.cc b/paddle/phi/kernels/selected_rows/cpu/adamw_kernel.cc
index 6d2fc164d6..d2cd9f0617 100644
--- a/paddle/phi/kernels/selected_rows/cpu/adamw_kernel.cc
+++ b/paddle/phi/kernels/selected_rows/cpu/adamw_kernel.cc
@@ -14,10 +14,10 @@
 
 #include "paddle/phi/kernels/selected_rows/adamw_kernel.h"
 
-#include "paddle/fluid/framework/tensor_util.h"
 #include "paddle/phi/backends/cpu/cpu_context.h"
 #include "paddle/phi/common/float16.h"
 #include "paddle/phi/core/kernel_registry.h"
+#include "paddle/phi/core/tensor_utils.h"
 #include "paddle/phi/kernels/adam_kernel.h"
 #include "paddle/phi/kernels/funcs/adam_functors.h"
 #include "paddle/phi/kernels/selected_rows/adam_kernel.h"
@@ -61,7 +61,7 @@ void AdamwDenseParamSparseGradKernel(
         errors::InvalidArgument("Input(SkipUpdate) size must be 1, but get %d",
                                 skip_update->numel()));
     std::vector<bool> skip_update_vec;
-    paddle::framework::TensorToVector(*skip_update, dev_ctx, &skip_update_vec);
+    phi::TensorToVector(*skip_update, dev_ctx, &skip_update_vec);
     skip_update_ = skip_update_vec[0];
   }
   VLOG(3) << "Skip update" << skip_update_;
diff --git a/paddle/phi/kernels/selected_rows/gpu/adam_kernel.cu b/paddle/phi/kernels/selected_rows/gpu/adam_kernel.cu
index 9ac8244340..a4b3f14306 100644
--- a/paddle/phi/kernels/selected_rows/gpu/adam_kernel.cu
+++ b/paddle/phi/kernels/selected_rows/gpu/adam_kernel.cu
@@ -14,7 +14,6 @@
 
 #include "paddle/phi/kernels/selected_rows/adam_kernel.h"
 
-#include "paddle/fluid/framework/tensor_util.h"
 #include "paddle/phi/backends/gpu/gpu_context.h"
 #include "paddle/phi/common/amp_type_traits.h"
 #include "paddle/phi/common/float16.h"
@@ -129,7 +128,7 @@ void AdamDenseParamSparseGradKernel(
         errors::InvalidArgument("Input(SkipUpdate) size must be 1, but get %d",
                                 skip_update->numel()));
     std::vector<bool> skip_update_vec;
-    paddle::framework::TensorToVector(*skip_update, dev_ctx, &skip_update_vec);
+    phi::TensorToVector(*skip_update, dev_ctx, &skip_update_vec);
     skip_update_ = skip_update_vec[0];
   }
   // skip_update=true, just copy input to output, and TensorCopy will call
diff --git a/paddle/phi/kernels/selected_rows/gpu/adamw_kernel.cu b/paddle/phi/kernels/selected_rows/gpu/adamw_kernel.cu
index 6dbea3a7ff..90c95492ee 100644
--- a/paddle/phi/kernels/selected_rows/gpu/adamw_kernel.cu
+++ b/paddle/phi/kernels/selected_rows/gpu/adamw_kernel.cu
@@ -146,7 +146,7 @@ void AdamwDenseParamSparseGradKernel(
         errors::InvalidArgument("Input(SkipUpdate) size must be 1, but get %d",
                                 skip_update->numel()));
     std::vector<bool> skip_update_vec;
-    paddle::framework::TensorToVector(*skip_update, dev_ctx, &skip_update_vec);
+    phi::TensorToVector(*skip_update, dev_ctx, &skip_update_vec);
     skip_update_ = skip_update_vec[0];
   }
 
diff --git a/paddle/phi/kernels/selected_rows/xpu/adam_kernel.cc b/paddle/phi/kernels/selected_rows/xpu/adam_kernel.cc
index 1f8aa70054..e648350fd4 100644
--- a/paddle/phi/kernels/selected_rows/xpu/adam_kernel.cc
+++ b/paddle/phi/kernels/selected_rows/xpu/adam_kernel.cc
@@ -156,7 +156,7 @@ void AdamDenseParamSparseGradKernel(
         errors::InvalidArgument("Input(SkipUpdate) size must be 1, but get %d",
                                 skip_update->numel()));
     std::vector<bool> skip_update_vec;
-    paddle::framework::TensorToVector(*skip_update, dev_ctx, &skip_update_vec);
+    phi::TensorToVector(*skip_update, dev_ctx, &skip_update_vec);
     skip_update_ = skip_update_vec[0];
   }
 
diff --git a/paddle/phi/kernels/xpu/adam_kernel.cc b/paddle/phi/kernels/xpu/adam_kernel.cc
index e17af77aba..9389ce5e53 100644
--- a/paddle/phi/kernels/xpu/adam_kernel.cc
+++ b/paddle/phi/kernels/xpu/adam_kernel.cc
@@ -124,7 +124,7 @@ void AdamDenseKernel(const Context& dev_ctx,
         errors::InvalidArgument("Input(SkipUpdate) size must be 1, but get %d",
                                 skip_update->numel()));
     std::vector<bool> skip_update_vec;
-    paddle::framework::TensorToVector(*skip_update, dev_ctx, &skip_update_vec);
+    phi::TensorToVector(*skip_update, dev_ctx, &skip_update_vec);
     skip_update_ = skip_update_vec[0];
   }
 
diff --git a/paddle/phi/kernels/xpu/adamw_kernel.cc b/paddle/phi/kernels/xpu/adamw_kernel.cc
index 20f226c850..0e27f686ad 100644
--- a/paddle/phi/kernels/xpu/adamw_kernel.cc
+++ b/paddle/phi/kernels/xpu/adamw_kernel.cc
@@ -20,8 +20,6 @@
 #include "paddle/phi/backends/xpu/xpu_context.h"
 #include "paddle/phi/core/kernel_registry.h"
 #include "paddle/phi/core/tensor_utils.h"
-// for TensorToVector
-#include "paddle/fluid/framework/tensor_util.h"
 
 namespace phi {
 
@@ -61,7 +59,7 @@ void AdamwDenseKernel(const Context& dev_ctx,
         errors::InvalidArgument("Input(SkipUpdate) size must be 1, but get %d",
                                 skip_update->numel()));
     std::vector<bool> skip_update_vec;
-    paddle::framework::TensorToVector(*skip_update, dev_ctx, &skip_update_vec);
+    phi::TensorToVector(*skip_update, dev_ctx, &skip_update_vec);
     skip_update_ = skip_update_vec[0];
   }
   if (skip_update_) {
diff --git a/paddle/phi/kernels/xpu/set_value_kernel.cc b/paddle/phi/kernels/xpu/set_value_kernel.cc
index fa26f75553..3d37204337 100644
--- a/paddle/phi/kernels/xpu/set_value_kernel.cc
+++ b/paddle/phi/kernels/xpu/set_value_kernel.cc
@@ -365,7 +365,7 @@ void SetValueKernel(const Context& dev_ctx,
     assgin_values.push_back(val.to<T>());
   }
   DenseTensor value_tensor = Empty<T>(dev_ctx, shape);
-  paddle::framework::TensorFromVector(assgin_values, dev_ctx, &value_tensor);
+  phi::TensorFromVector(assgin_values, dev_ctx, &value_tensor);
   value_tensor.Resize(phi::make_ddim(shape));
 
   SetTensorValueKernel<T, Context>(dev_ctx,
-- 
GitLab