From 2739bd733e81f288b62797f28e53bb8465e2dbc3 Mon Sep 17 00:00:00 2001 From: YuanRisheng Date: Fri, 24 Jun 2022 17:25:24 +0800 Subject: [PATCH] [Phi]Change Copy from Kernel to basic component utils (#43622) * perfect copy * deal with conflict * deal with conflict * fix compile bugs * fix unittest bugs * change code format * deal with conflict * modify code by review * fix ce bugs * fix ce bugs * add lo * perfect code format * deal with conflicts --- .../data_structure_tests/eager_tensor_test.cc | 10 +- .../eager/tests/task_tests/backward_test.cc | 77 +++-- .../fluid/eager/tests/task_tests/grad_test.cc | 104 ++++-- paddle/fluid/jit/layer_test.cc | 2 +- paddle/fluid/operators/transpose_op.cu.h | 318 ++++++++++++------ paddle/phi/api/lib/data_transform.cc | 2 +- paddle/phi/api/lib/tensor_copy.cc | 15 +- paddle/phi/api/lib/tensor_method.cc | 52 ++- paddle/phi/core/CMakeLists.txt | 40 +++ .../copy_kernel.cu => core/tensor_utils.cc} | 129 ++++++- paddle/phi/core/tensor_utils.h | 15 + paddle/phi/kernels/CMakeLists.txt | 3 +- paddle/phi/kernels/assign_kernel.cc | 2 +- paddle/phi/kernels/autotune/auto_tune_test.cu | 2 +- paddle/phi/kernels/copy_kernel.h | 27 -- paddle/phi/kernels/cpu/adam_kernel.cc | 2 +- paddle/phi/kernels/cpu/copy_kernel.cc | 61 ---- .../kernels/cpu/cross_entropy_grad_kernel.cc | 2 +- .../phi/kernels/cpu/cross_entropy_kernel.cc | 2 +- .../cpu/elementwise_divide_grad_kernel.cc | 2 +- .../kernels/cpu/elementwise_grad_kernel.cc | 2 +- .../cpu/elementwise_multiply_grad_kernel.cc | 2 +- paddle/phi/kernels/cpu/index_select_impl.h | 2 +- .../phi/kernels/cpu/kthvalue_grad_kernel.cc | 2 +- .../phi/kernels/cpu/layer_norm_grad_kernel.cc | 2 +- paddle/phi/kernels/cpu/mode_grad_kernel.cc | 2 +- .../kernels/cpu/put_along_axis_grad_kernel.cc | 2 +- .../phi/kernels/cpu/put_along_axis_kernel.cc | 2 +- paddle/phi/kernels/cpu/rnn_functor.h | 2 +- paddle/phi/kernels/cpu/rnn_grad_kernel.cc | 2 +- paddle/phi/kernels/cpu/rnn_kernel.cc | 2 +- paddle/phi/kernels/cpu/scatter_grad_kernel.cc | 2 +- paddle/phi/kernels/cpu/scatter_kernel.cc | 2 +- .../kernels/cpu/scatter_nd_add_grad_kernel.cc | 2 +- .../phi/kernels/cpu/scatter_nd_add_kernel.cc | 2 +- .../phi/kernels/cpu/viterbi_decode_kernel.cc | 2 +- paddle/phi/kernels/flatten_grad_kernel.cc | 2 +- paddle/phi/kernels/flatten_kernel.cc | 2 +- paddle/phi/kernels/funcs/mode.h | 2 +- paddle/phi/kernels/funcs/strided_slice.h | 2 +- paddle/phi/kernels/gpu/adam_kernel.cu | 2 +- paddle/phi/kernels/gpu/adamw_kernel.cu | 2 +- paddle/phi/kernels/gpu/arange_kernel.cu | 2 +- .../kernels/gpu/cross_entropy_grad_kernel.cu | 2 +- .../phi/kernels/gpu/cross_entropy_kernel.cu | 2 +- .../gpu/elementwise_add_grad_kernel.cu | 2 +- paddle/phi/kernels/gpu/elementwise_grad.h | 2 +- .../kernels/gpu/elementwise_grad_kernel.cu | 2 +- .../gpu/elementwise_subtract_grad_kernel.cu | 2 +- .../kernels/gpu/instance_norm_grad_kernel.cu | 2 +- paddle/phi/kernels/gpu/linspace_kernel.cu | 2 +- paddle/phi/kernels/gpu/logspace_kernel.cu | 2 +- .../phi/kernels/gpu/psroi_pool_grad_kernel.cu | 2 +- paddle/phi/kernels/gpu/psroi_pool_kernel.cu | 2 +- .../kernels/gpu/put_along_axis_grad_kernel.cu | 2 +- .../phi/kernels/gpu/put_along_axis_kernel.cu | 2 +- paddle/phi/kernels/gpu/scatter_grad_kernel.cu | 2 +- paddle/phi/kernels/gpu/scatter_kernel.cu | 2 +- .../kernels/gpu/scatter_nd_add_grad_kernel.cu | 2 +- .../phi/kernels/gpu/scatter_nd_add_kernel.cu | 2 +- .../gpu/sigmoid_cross_entropy_with_logits.h | 2 +- paddle/phi/kernels/gpu/top_k_kernel.cu | 2 +- paddle/phi/kernels/gpu/unique_kernel.cu | 2 +- 
.../phi/kernels/gpu/viterbi_decode_kernel.cu | 2 +- .../impl/cholesky_solve_grad_kernel_impl.h | 1 - .../kernels/impl/cholesky_solve_kernel_impl.h | 1 - .../impl/determinant_grad_kernel_impl.h | 2 +- .../impl/elementwise_grad_kernel_impl.h | 2 +- .../kernels/impl/expand_as_grad_kernel_impl.h | 2 +- .../kernels/impl/expand_grad_kernel_impl.h | 2 +- .../phi/kernels/impl/meshgrid_kernel_impl.h | 2 +- .../kernels/impl/set_value_grad_kernel_impl.h | 2 +- .../phi/kernels/impl/set_value_kernel_impl.h | 2 +- paddle/phi/kernels/impl/size_kernel_impl.h | 2 +- .../kernels/impl/squeeze_grad_kernel_impl.h | 2 +- paddle/phi/kernels/impl/squeeze_kernel_impl.h | 2 +- .../phi/kernels/impl/tile_grad_kernel_impl.h | 2 +- .../impl/triangular_solve_grad_kernel_impl.h | 2 +- .../kernels/impl/unsqueeze_grad_kernel_impl.h | 2 +- .../phi/kernels/impl/unsqueeze_kernel_impl.h | 2 +- paddle/phi/kernels/impl/warpctc_kernel_impl.h | 2 +- paddle/phi/kernels/reshape_grad_kernel.cc | 2 +- paddle/phi/kernels/reshape_kernel.cc | 2 +- paddle/phi/kernels/reverse_kernel.cc | 2 +- .../phi/kernels/selected_rows/copy_kernel.cc | 49 --- .../phi/kernels/selected_rows/copy_kernel.h | 31 -- .../kernels/selected_rows/cpu/adam_kernel.cc | 2 +- .../kernels/selected_rows/gpu/adam_kernel.cu | 2 +- .../kernels/selected_rows/gpu/adamw_kernel.cu | 2 +- paddle/phi/kernels/sparse/copy_kernel.cc | 2 +- .../sparse/cpu/convolution_grad_kernel.cc | 1 - paddle/phi/kernels/sparse/cpu/full_kernel.cc | 2 +- .../kernels/sparse/cpu/sparse_mask_kernel.cc | 2 +- .../sparse/cpu/sparse_pool_grad_kernel.cc | 2 +- paddle/phi/kernels/sparse/empty_kernel.cc | 2 +- .../phi/kernels/sparse/gpu/convolution.cu.h | 2 +- .../sparse/gpu/convolution_grad_kernel.cu | 2 +- paddle/phi/kernels/sparse/gpu/full_kernel.cu | 2 +- .../kernels/sparse/gpu/matmul_grad_kernel.cu | 2 +- .../phi/kernels/sparse/gpu/matmul_kernel.cu | 2 +- .../kernels/sparse/gpu/sparse_mask_kernel.cu | 2 +- .../sparse/gpu/sparse_pool_grad_kernel.cu | 2 +- .../phi/kernels/sparse/unary_grad_kernel.cc | 2 +- paddle/phi/kernels/sparse/unary_kernel.cc | 2 +- .../strings/gpu/strings_copy_kernel.cu | 2 +- paddle/phi/kernels/xpu/copy_kernel.cc | 80 ----- paddle/phi/tests/api/test_data_transform.cc | 1 - paddle/phi/tests/api/test_fill_api.cc | 1 - paddle/phi/tests/api/test_matmul_api.cc | 2 +- paddle/phi/tests/api/test_pten_tensor.cc | 6 - paddle/phi/tests/api/test_scale_api.cc | 1 - paddle/phi/tests/api/test_to_api.cc | 5 - paddle/phi/tests/common/test_int_array.cc | 2 - paddle/phi/tests/common/test_scalar.cu | 2 - paddle/phi/tests/kernels/test_copy_dev_api.cc | 2 +- .../phi/tests/kernels/test_flatten_dev_api.cc | 10 - .../kernels/test_sparse_conv3d_dev_api.cc | 2 +- .../tests/kernels/test_sparse_pool_dev_api.cc | 2 +- .../kernels/test_sparse_utils_dev_api.cc | 2 +- 119 files changed, 622 insertions(+), 606 deletions(-) rename paddle/phi/{kernels/gpu/copy_kernel.cu => core/tensor_utils.cc} (63%) delete mode 100644 paddle/phi/kernels/copy_kernel.h delete mode 100644 paddle/phi/kernels/cpu/copy_kernel.cc delete mode 100644 paddle/phi/kernels/selected_rows/copy_kernel.cc delete mode 100644 paddle/phi/kernels/selected_rows/copy_kernel.h delete mode 100644 paddle/phi/kernels/xpu/copy_kernel.cc diff --git a/paddle/fluid/eager/tests/data_structure_tests/eager_tensor_test.cc b/paddle/fluid/eager/tests/data_structure_tests/eager_tensor_test.cc index a82965303af..cd81d3e4829 100644 --- a/paddle/fluid/eager/tests/data_structure_tests/eager_tensor_test.cc +++ 
b/paddle/fluid/eager/tests/data_structure_tests/eager_tensor_test.cc @@ -21,13 +21,6 @@ #include "paddle/phi/common/layout.h" #include "paddle/phi/core/kernel_registry.h" -PD_DECLARE_KERNEL(copy, CPU, ALL_LAYOUT); -PD_DECLARE_KERNEL(copy_sr, CPU, ALL_LAYOUT); -#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) -PD_DECLARE_KERNEL(copy, GPU, ALL_LAYOUT); -PD_DECLARE_KERNEL(copy_sr, GPU, ALL_LAYOUT); -#endif - namespace eager_test { using AbstractAutogradMeta = paddle::experimental::AbstractAutogradMeta; class AutogradMetaTest : public AbstractAutogradMeta { @@ -212,7 +205,8 @@ TEST(EagerVariable, Constructor) { TEST(EagerVariable, DataLayout) { paddle::experimental::Tensor tensor; phi::DenseTensorMeta meta = - phi::DenseTensorMeta(phi::DataType::FLOAT32, phi::make_ddim({1, 1, 1, 1}), + phi::DenseTensorMeta(phi::DataType::FLOAT32, + phi::make_ddim({1, 1, 1, 1}), paddle::experimental::DataLayout::UNDEFINED); std::shared_ptr dt = std::make_shared( std::make_unique( diff --git a/paddle/fluid/eager/tests/task_tests/backward_test.cc b/paddle/fluid/eager/tests/task_tests/backward_test.cc index c6d4514fa8e..c91ac93897c 100644 --- a/paddle/fluid/eager/tests/task_tests/backward_test.cc +++ b/paddle/fluid/eager/tests/task_tests/backward_test.cc @@ -30,7 +30,6 @@ #include "paddle/phi/core/tensor_meta.h" PD_DECLARE_KERNEL(full, CPU, ALL_LAYOUT); -PD_DECLARE_KERNEL(copy, CPU, ALL_LAYOUT); PD_DECLARE_KERNEL(add, CPU, ALL_LAYOUT); namespace egr { @@ -44,9 +43,12 @@ TEST(Backward, SingleNodeEmptyGrad) { // Create Target Tensor paddle::experimental::Tensor target_tensor = - egr_utils_api::CreateTensorWithValue( - ddim, paddle::platform::CPUPlace(), phi::DataType::FLOAT32, - phi::DataLayout::NCHW, 1.0 /*value*/, false /*is_leaf*/); + egr_utils_api::CreateTensorWithValue(ddim, + paddle::platform::CPUPlace(), + phi::DataType::FLOAT32, + phi::DataLayout::NCHW, + 1.0 /*value*/, + false /*is_leaf*/); paddle::experimental::Tensor leaf_tensor; { @@ -92,17 +94,24 @@ TEST(Backward, SingleNodeCustomGrad) { paddle::framework::DDim ddim = phi::make_ddim({4, 16, 16, 32}); // Create Target Tensor - paddle::experimental::Tensor tensor = egr_utils_api::CreateTensorWithValue( - ddim, paddle::platform::CPUPlace(), phi::DataType::FLOAT32, - phi::DataLayout::NCHW, 1.0 /*value*/, false /*is_leaf*/); + paddle::experimental::Tensor tensor = + egr_utils_api::CreateTensorWithValue(ddim, + paddle::platform::CPUPlace(), + phi::DataType::FLOAT32, + phi::DataLayout::NCHW, + 1.0 /*value*/, + false /*is_leaf*/); target_tensors.emplace_back(std::move(tensor)); std::vector grad_tensors; // Create Grad Tensor paddle::experimental::Tensor grad_tensor = - egr_utils_api::CreateTensorWithValue( - ddim, paddle::platform::CPUPlace(), phi::DataType::FLOAT32, - phi::DataLayout::NCHW, 10.0 /*value*/, false /*is_leaf*/); + egr_utils_api::CreateTensorWithValue(ddim, + paddle::platform::CPUPlace(), + phi::DataType::FLOAT32, + phi::DataLayout::NCHW, + 10.0 /*value*/, + false /*is_leaf*/); grad_tensors.emplace_back(std::move(grad_tensor)); paddle::experimental::Tensor leaf_tensor; @@ -157,9 +166,13 @@ TEST(Backward, LinearNodes) { paddle::framework::DDim ddim = phi::make_ddim({4, 16, 16, 32}); // Create Target Tensor - paddle::experimental::Tensor tensor = egr_utils_api::CreateTensorWithValue( - ddim, paddle::platform::CPUPlace(), phi::DataType::FLOAT32, - phi::DataLayout::NCHW, 1.0 /*value*/, false /*is_leaf*/); + paddle::experimental::Tensor tensor = + egr_utils_api::CreateTensorWithValue(ddim, + paddle::platform::CPUPlace(), + 
phi::DataType::FLOAT32, + phi::DataLayout::NCHW, + 1.0 /*value*/, + false /*is_leaf*/); target_tensors.emplace_back(std::move(tensor)); paddle::experimental::Tensor leaf_tensor; @@ -229,25 +242,39 @@ TEST(Backward, WithAccumulation) { // Create Target Tensor std::vector target_tensors; - paddle::experimental::Tensor tensor0 = egr_utils_api::CreateTensorWithValue( - ddim, paddle::platform::CPUPlace(), phi::DataType::FLOAT32, - phi::DataLayout::NCHW, 1.0 /*value*/, false /*is_leaf*/); - paddle::experimental::Tensor tensor1 = egr_utils_api::CreateTensorWithValue( - ddim, paddle::platform::CPUPlace(), phi::DataType::FLOAT32, - phi::DataLayout::NCHW, 1.0 /*value*/, false /*is_leaf*/); + paddle::experimental::Tensor tensor0 = + egr_utils_api::CreateTensorWithValue(ddim, + paddle::platform::CPUPlace(), + phi::DataType::FLOAT32, + phi::DataLayout::NCHW, + 1.0 /*value*/, + false /*is_leaf*/); + paddle::experimental::Tensor tensor1 = + egr_utils_api::CreateTensorWithValue(ddim, + paddle::platform::CPUPlace(), + phi::DataType::FLOAT32, + phi::DataLayout::NCHW, + 1.0 /*value*/, + false /*is_leaf*/); target_tensors.emplace_back(std::move(tensor0)); target_tensors.emplace_back(std::move(tensor1)); // Create Grad Tensor std::vector grad_tensors; paddle::experimental::Tensor grad_tensor0 = - egr_utils_api::CreateTensorWithValue( - ddim, paddle::platform::CPUPlace(), phi::DataType::FLOAT32, - phi::DataLayout::NCHW, 5.0 /*value*/, false /*is_leaf*/); + egr_utils_api::CreateTensorWithValue(ddim, + paddle::platform::CPUPlace(), + phi::DataType::FLOAT32, + phi::DataLayout::NCHW, + 5.0 /*value*/, + false /*is_leaf*/); paddle::experimental::Tensor grad_tensor1 = - egr_utils_api::CreateTensorWithValue( - ddim, paddle::platform::CPUPlace(), phi::DataType::FLOAT32, - phi::DataLayout::NCHW, 10.0 /*value*/, false /*is_leaf*/); + egr_utils_api::CreateTensorWithValue(ddim, + paddle::platform::CPUPlace(), + phi::DataType::FLOAT32, + phi::DataLayout::NCHW, + 10.0 /*value*/, + false /*is_leaf*/); grad_tensors.emplace_back(std::move(grad_tensor0)); grad_tensors.emplace_back(std::move(grad_tensor1)); diff --git a/paddle/fluid/eager/tests/task_tests/grad_test.cc b/paddle/fluid/eager/tests/task_tests/grad_test.cc index 8d6c4d7843f..30c0e92511a 100644 --- a/paddle/fluid/eager/tests/task_tests/grad_test.cc +++ b/paddle/fluid/eager/tests/task_tests/grad_test.cc @@ -29,7 +29,6 @@ #include "paddle/phi/core/tensor_meta.h" PD_DECLARE_KERNEL(full, CPU, ALL_LAYOUT); -PD_DECLARE_KERNEL(copy, CPU, ALL_LAYOUT); PD_DECLARE_KERNEL(add, CPU, ALL_LAYOUT); namespace egr { @@ -43,15 +42,21 @@ TEST(Grad, SingleNodeEmptyGrad) { // Create Target Tensor (output) paddle::experimental::Tensor output_tensor = - egr_utils_api::CreateTensorWithValue( - ddim, paddle::platform::CPUPlace(), phi::DataType::FLOAT32, - phi::DataLayout::NCHW, 1.0 /*value*/, false /*is_leaf*/); + egr_utils_api::CreateTensorWithValue(ddim, + paddle::platform::CPUPlace(), + phi::DataType::FLOAT32, + phi::DataLayout::NCHW, + 1.0 /*value*/, + false /*is_leaf*/); // Create input tensor const paddle::experimental::Tensor leaf_tensor = - egr_utils_api::CreateTensorWithValue( - ddim, paddle::platform::CPUPlace(), phi::DataType::FLOAT32, - phi::DataLayout::NCHW, 1.0 /*value*/, true /*is_leaf*/); + egr_utils_api::CreateTensorWithValue(ddim, + paddle::platform::CPUPlace(), + phi::DataType::FLOAT32, + phi::DataLayout::NCHW, + 1.0 /*value*/, + true /*is_leaf*/); { // Create Scale Node @@ -103,23 +108,33 @@ TEST(Grad, SingleNodeCustomGrad) { paddle::framework::DDim ddim = phi::make_ddim({4, 
16, 16, 32}); // Create Target Tensor - paddle::experimental::Tensor tensor = egr_utils_api::CreateTensorWithValue( - ddim, paddle::platform::CPUPlace(), phi::DataType::FLOAT32, - phi::DataLayout::NCHW, 1.0 /*value*/, false /*is_leaf*/); + paddle::experimental::Tensor tensor = + egr_utils_api::CreateTensorWithValue(ddim, + paddle::platform::CPUPlace(), + phi::DataType::FLOAT32, + phi::DataLayout::NCHW, + 1.0 /*value*/, + false /*is_leaf*/); target_tensors.emplace_back(std::move(tensor)); std::vector grad_tensors; // Create Grad Tensor paddle::experimental::Tensor grad_tensor = - egr_utils_api::CreateTensorWithValue( - ddim, paddle::platform::CPUPlace(), phi::DataType::FLOAT32, - phi::DataLayout::NCHW, 10.0 /*value*/, false /*is_leaf*/); + egr_utils_api::CreateTensorWithValue(ddim, + paddle::platform::CPUPlace(), + phi::DataType::FLOAT32, + phi::DataLayout::NCHW, + 10.0 /*value*/, + false /*is_leaf*/); grad_tensors.emplace_back(std::move(grad_tensor)); paddle::experimental::Tensor leaf_tensor = - egr_utils_api::CreateTensorWithValue( - ddim, paddle::platform::CPUPlace(), phi::DataType::FLOAT32, - phi::DataLayout::NCHW, 1.0 /*value*/, true /*is_leaf*/); + egr_utils_api::CreateTensorWithValue(ddim, + paddle::platform::CPUPlace(), + phi::DataType::FLOAT32, + phi::DataLayout::NCHW, + 1.0 /*value*/, + true /*is_leaf*/); { // Create Scale Node @@ -172,15 +187,22 @@ TEST(Grad, LinearNodes) { paddle::framework::DDim ddim = phi::make_ddim({4, 16, 16, 32}); // Create Target Tensor - paddle::experimental::Tensor tensor = egr_utils_api::CreateTensorWithValue( - ddim, paddle::platform::CPUPlace(), phi::DataType::FLOAT32, - phi::DataLayout::NCHW, 1.0 /*value*/, false /*is_leaf*/); + paddle::experimental::Tensor tensor = + egr_utils_api::CreateTensorWithValue(ddim, + paddle::platform::CPUPlace(), + phi::DataType::FLOAT32, + phi::DataLayout::NCHW, + 1.0 /*value*/, + false /*is_leaf*/); target_tensors.emplace_back(std::move(tensor)); paddle::experimental::Tensor leaf_tensor = - egr_utils_api::CreateTensorWithValue( - ddim, paddle::platform::CPUPlace(), phi::DataType::FLOAT32, - phi::DataLayout::NCHW, 1.0 /*value*/, true /*is_leaf*/); + egr_utils_api::CreateTensorWithValue(ddim, + paddle::platform::CPUPlace(), + phi::DataType::FLOAT32, + phi::DataLayout::NCHW, + 1.0 /*value*/, + true /*is_leaf*/); { // Create Node0 auto node0_ptr = std::make_shared(1, 1); @@ -247,25 +269,39 @@ TEST(Grad, WithAccumulation) { // Create Target Tensor std::vector target_tensors; - paddle::experimental::Tensor tensor0 = egr_utils_api::CreateTensorWithValue( - ddim, paddle::platform::CPUPlace(), phi::DataType::FLOAT32, - phi::DataLayout::NCHW, 1.0 /*value*/, false /*is_leaf*/); - paddle::experimental::Tensor tensor1 = egr_utils_api::CreateTensorWithValue( - ddim, paddle::platform::CPUPlace(), phi::DataType::FLOAT32, - phi::DataLayout::NCHW, 1.0 /*value*/, false /*is_leaf*/); + paddle::experimental::Tensor tensor0 = + egr_utils_api::CreateTensorWithValue(ddim, + paddle::platform::CPUPlace(), + phi::DataType::FLOAT32, + phi::DataLayout::NCHW, + 1.0 /*value*/, + false /*is_leaf*/); + paddle::experimental::Tensor tensor1 = + egr_utils_api::CreateTensorWithValue(ddim, + paddle::platform::CPUPlace(), + phi::DataType::FLOAT32, + phi::DataLayout::NCHW, + 1.0 /*value*/, + false /*is_leaf*/); target_tensors.emplace_back(std::move(tensor0)); target_tensors.emplace_back(std::move(tensor1)); // Create Grad Tensor std::vector grad_tensors; paddle::experimental::Tensor grad_tensor0 = - egr_utils_api::CreateTensorWithValue( - ddim, 
paddle::platform::CPUPlace(), phi::DataType::FLOAT32, - phi::DataLayout::NCHW, 5.0 /*value*/, false /*is_leaf*/); + egr_utils_api::CreateTensorWithValue(ddim, + paddle::platform::CPUPlace(), + phi::DataType::FLOAT32, + phi::DataLayout::NCHW, + 5.0 /*value*/, + false /*is_leaf*/); paddle::experimental::Tensor grad_tensor1 = - egr_utils_api::CreateTensorWithValue( - ddim, paddle::platform::CPUPlace(), phi::DataType::FLOAT32, - phi::DataLayout::NCHW, 10.0 /*value*/, false /*is_leaf*/); + egr_utils_api::CreateTensorWithValue(ddim, + paddle::platform::CPUPlace(), + phi::DataType::FLOAT32, + phi::DataLayout::NCHW, + 10.0 /*value*/, + false /*is_leaf*/); grad_tensors.emplace_back(std::move(grad_tensor0)); grad_tensors.emplace_back(std::move(grad_tensor1)); diff --git a/paddle/fluid/jit/layer_test.cc b/paddle/fluid/jit/layer_test.cc index 881c0602920..ef35d254c57 100644 --- a/paddle/fluid/jit/layer_test.cc +++ b/paddle/fluid/jit/layer_test.cc @@ -21,7 +21,7 @@ #include "paddle/fluid/imperative/tracer.h" #include "paddle/phi/core/dense_tensor.h" #include "paddle/phi/core/kernel_registry.h" -#include "paddle/phi/kernels/copy_kernel.h" +#include "paddle/phi/core/tensor_utils.h" #include "paddle/phi/kernels/funcs/math_function.h" #include "paddle/fluid/jit/layer.h" diff --git a/paddle/fluid/operators/transpose_op.cu.h b/paddle/fluid/operators/transpose_op.cu.h index f9d91fec4c3..1b90ad2c313 100644 --- a/paddle/fluid/operators/transpose_op.cu.h +++ b/paddle/fluid/operators/transpose_op.cu.h @@ -20,9 +20,9 @@ limitations under the License. */ #include "paddle/fluid/platform/fast_divmod.h" #include "paddle/phi/backends/gpu/gpu_context.h" #include "paddle/phi/backends/gpu/gpu_launch_config.h" +#include "paddle/phi/core/tensor_utils.h" #include "paddle/phi/kernels/autotune/auto_tune_base.h" #include "paddle/phi/kernels/autotune/cache.h" -#include "paddle/phi/kernels/copy_kernel.h" namespace paddle { namespace operators { @@ -41,7 +41,9 @@ struct GreaterThan { // Value can be decided in compile time. template -constexpr bool CheckProperTileSize(int tile_long, int tile_short, int size_T, +constexpr bool CheckProperTileSize(int tile_long, + int tile_short, + int size_T, FUN op) { return (size_T == 16 && ((tile_long == INT_32 && op(tile_short, 4)) || (tile_long == 2 * INT_32 && op(tile_short, 4)) || @@ -79,7 +81,8 @@ constexpr bool CheckNonLongTileSize(int tile_long, int tile_short, int size_T) { // Use SM to do data transfer, load a tile into SM then store out. 
// All tile read and write are colascing, so can speedup memory copy template -__global__ void TilingSwapDim1And2(const T* __restrict__ input, Dim3 input_dims, +__global__ void TilingSwapDim1And2(const T* __restrict__ input, + Dim3 input_dims, T* __restrict__ output) { assert(blockDim.x == NumThreads); assert(blockDim.y == 1); @@ -218,12 +221,14 @@ __global__ void TilingSwapDim1And2(const T* __restrict__ input, Dim3 input_dims, template bool SelectProperTileSize(std::vector>* tiles) { PADDLE_ENFORCE_LE( - TSIZE, 16, + TSIZE, + 16, platform::errors::InvalidArgument( "The tile size should smaller than 16, but received is:%d.", TSIZE)); PADDLE_ENFORCE_EQ( - (TSIZE & (TSIZE - 1)), 0, + (TSIZE & (TSIZE - 1)), + 0, platform::errors::InvalidArgument( "Data types should be powers of 2, but reived size is:%d.", TSIZE)); @@ -269,29 +274,37 @@ struct SystemElemType<16> { }; template -void LaunchNarrowDims2TransposeKernel(const phi::GPUContext& d, int tile_size_i, - int tile_size_j, int total_tiles_count, - const T* input, const Dim3& input_dims, +void LaunchNarrowDims2TransposeKernel(const phi::GPUContext& d, + int tile_size_i, + int tile_size_j, + int total_tiles_count, + const T* input, + const Dim3& input_dims, T* output) { constexpr int NumThreads = tile_long; if (tile_size_i <= tile_long && tile_size_j <= tile_short) { TilingSwapDim1And2 - <<>>(input, input_dims, - output); + <<>>( + input, input_dims, output); } else { TilingSwapDim1And2 - <<>>(input, input_dims, - output); + <<>>( + input, input_dims, output); } } template struct NarrowDims2TransposeDispatch { - static void DoTranspose(const phi::GPUContext& d, int tile_size_i, - int tile_size_j, int total_tiles_count, - const T* input, const Dim3& input_dims, T* output) { + static void DoTranspose(const phi::GPUContext& d, + int tile_size_i, + int tile_size_j, + int total_tiles_count, + const T* input, + const Dim3& input_dims, + T* output) { PADDLE_ENFORCE_EQ( - (tile_long & (tile_long - 1)), 0, + (tile_long & (tile_long - 1)), + 0, platform::errors::InvalidArgument( "The length of the longer side of the tile should be power of 2." " But received value is:%d.", @@ -302,7 +315,12 @@ struct NarrowDims2TransposeDispatch { if (request_satisfied) { LaunchNarrowDims2TransposeKernel( - d, tile_size_i, tile_size_j, total_tiles_count, input, input_dims, + d, + tile_size_i, + tile_size_j, + total_tiles_count, + input, + input_dims, output); return; } @@ -312,11 +330,21 @@ struct NarrowDims2TransposeDispatch { if (long_side_request_not_satisfied) { NarrowDims2TransposeDispatch::DoTranspose( - d, tile_size_i, tile_size_j, total_tiles_count, input, input_dims, + d, + tile_size_i, + tile_size_j, + total_tiles_count, + input, + input_dims, output); } else { NarrowDims2TransposeDispatch::DoTranspose( - d, tile_size_i, tile_size_j, total_tiles_count, input, input_dims, + d, + tile_size_i, + tile_size_j, + total_tiles_count, + input, + input_dims, output); } } @@ -325,14 +353,22 @@ struct NarrowDims2TransposeDispatch { // If Not long tile size, goto this function when compile. 
template struct NarrowDims2TransposeDispatch< - T, tile_long, tile_short, - typename std::enable_if< - CheckNonLongTileSize(tile_long, tile_short, sizeof(T)), void>::type> { - static void DoTranspose(const phi::GPUContext& d, int tile_size_i, - int tile_size_j, int total_tiles_count, - const T* input, const Dim3& input_dims, T* output) { + T, + tile_long, + tile_short, + typename std::enable_if::type> { + static void DoTranspose(const phi::GPUContext& d, + int tile_size_i, + int tile_size_j, + int total_tiles_count, + const T* input, + const Dim3& input_dims, + T* output) { PADDLE_ENFORCE_EQ( - (tile_long & (tile_long - 1)), 0, + (tile_long & (tile_long - 1)), + 0, platform::errors::InvalidArgument( "The length of the longer side of the tile should be power of 2." " But received value is:%d.", @@ -343,13 +379,23 @@ struct NarrowDims2TransposeDispatch< if (request_satisfied) { LaunchNarrowDims2TransposeKernel( - d, tile_size_i, tile_size_j, total_tiles_count, input, input_dims, + d, + tile_size_i, + tile_size_j, + total_tiles_count, + input, + input_dims, output); return; } NarrowDims2TransposeDispatch::DoTranspose( - d, tile_size_i, tile_size_j, total_tiles_count, input, input_dims, + d, + tile_size_i, + tile_size_j, + total_tiles_count, + input, + input_dims, output); } }; @@ -357,34 +403,49 @@ struct NarrowDims2TransposeDispatch< // If long tile size, goto this function when compile. template struct NarrowDims2TransposeDispatch< - T, tile_long, tile_short, + T, + tile_long, + tile_short, typename std::enable_if::type> { - static void DoTranspose(const phi::GPUContext& d, int tile_size_i, - int tile_size_j, int total_tiles_count, - const T* input, const Dim3& input_dims, T* output) { + static void DoTranspose(const phi::GPUContext& d, + int tile_size_i, + int tile_size_j, + int total_tiles_count, + const T* input, + const Dim3& input_dims, + T* output) { PADDLE_ENFORCE_EQ( - (tile_long & (tile_long - 1)), 0, + (tile_long & (tile_long - 1)), + 0, platform::errors::InvalidArgument( "The length of the longer side of the tile should be power of 2," " but received is:%d.", tile_long)); LaunchNarrowDims2TransposeKernel( - d, tile_size_i, tile_size_j, total_tiles_count, input, input_dims, + d, + tile_size_i, + tile_size_j, + total_tiles_count, + input, + input_dims, output); } }; template -void SwapDim1And2InNarrow(const phi::GPUContext& d, const T* input, - const Dim3& input_dims, T* output, +void SwapDim1And2InNarrow(const phi::GPUContext& d, + const T* input, + const Dim3& input_dims, + T* output, const int kMinTileSize) { // First get available tile sizes for the data type requested as backups std::vector> tile_sele; auto ret = SelectProperTileSize(&tile_sele); PADDLE_ENFORCE_EQ( - ret, true, + ret, + true, platform::errors::InvalidArgument( "SelectProperTileSize should return true, but return value is:%d.", ret)); @@ -451,16 +512,22 @@ void SwapDim1And2InNarrow(const phi::GPUContext& d, const T* input, using ElemType = typename SystemElemType::type; NarrowDims2TransposeDispatch::DoTranspose( - d, select_tile_size_i, select_tile_size_j, total_tiles_count, - reinterpret_cast(input), input_dims, + d, + select_tile_size_i, + select_tile_size_j, + total_tiles_count, + reinterpret_cast(input), + input_dims, reinterpret_cast(output)); } // This is for case that cannot do coalescing read and write. // Or input is too small to split into tiles. 
template -__global__ void TransposeSimpleKernel(int nthreads, const T* __restrict__ input, - Dim3 input_dims, T* __restrict__ output) { +__global__ void TransposeSimpleKernel(int nthreads, + const T* __restrict__ input, + Dim3 input_dims, + T* __restrict__ output) { Dim3 output_dims; output_dims[pos0] = input_dims[0]; output_dims[pos1] = input_dims[1]; @@ -482,8 +549,10 @@ __global__ void TransposeSimpleKernel(int nthreads, const T* __restrict__ input, // Here suppose convert all tensor to dim3, so just change dim1 and 2. template -void SendSwapDim1And2InTranspose(const phi::GPUContext& d, const T* input, - const Dim3& input_dims, T* output) { +void SendSwapDim1And2InTranspose(const phi::GPUContext& d, + const T* input, + const Dim3& input_dims, + T* output) { // Suppose tile size > 16 static const int kMinTileSize = 16; static const int kMinNarrowTileSize = 96; @@ -508,8 +577,8 @@ void SendSwapDim1And2InTranspose(const phi::GPUContext& d, const T* input, input_dims_aligned[0] * input_dims_aligned[1] * input_dims_aligned[2]; TilingSwapDim1And2 - <<>>(input, input_dims, - output); + <<>>( + input, input_dims, output); } else if (narrow_tile) { // If input shape is like Rect, such as 2X100, use Narrow tile size. @@ -529,8 +598,10 @@ void SendSwapDim1And2InTranspose(const phi::GPUContext& d, const T* input, template struct SwapDim1And2InTranspose { typedef phi::GPUContext Device; - void operator()(const Device& d, const T* in, - const std::vector& combined_dims, T* out) { + void operator()(const Device& d, + const T* in, + const std::vector& combined_dims, + T* out) { Dim3 input_dims = {static_cast(combined_dims[0]), static_cast(combined_dims[1]), static_cast(combined_dims[2])}; @@ -541,8 +612,10 @@ struct SwapDim1And2InTranspose { template struct SwapDim0And2InTranspose { typedef phi::GPUContext Device; - void operator()(const Device& d, const T* in, - const std::vector& combined_dims, T* out) { + void operator()(const Device& d, + const T* in, + const std::vector& combined_dims, + T* out) { Dim3 input_dims = {static_cast(combined_dims[0]), static_cast(combined_dims[1]), static_cast(combined_dims[2])}; @@ -562,11 +635,13 @@ inline void CombineTransposeDim3(const framework::DDim& shape, const std::vector& perm, std::vector* new_perm, framework::DDim* new_dims) { - PADDLE_ENFORCE_EQ(shape.size(), perm.size(), + PADDLE_ENFORCE_EQ(shape.size(), + perm.size(), platform::errors::InvalidArgument( " shape should have the save dim with perm, but" " received shape size is:%d, perm size is:%d.", - shape.size(), perm.size())); + shape.size(), + perm.size())); std::vector dim_vec; if (shape.size() == 1) { @@ -614,8 +689,10 @@ inline void CombineTransposeDim3(const framework::DDim& shape, template struct TransposeSimple { - static bool run(const phi::GPUContext& ctx, const Tensor& in, - const std::vector perm, Tensor* out) { + static bool run(const phi::GPUContext& ctx, + const Tensor& in, + const std::vector perm, + Tensor* out) { // First reduce the dimensions of the input tensor if possible. std::vector new_perm; framework::DDim new_dims; @@ -805,7 +882,8 @@ __global__ void VectorizedPermuteKernel(PermuteParams params, // A general kernel for normal case, only support vectorized write. 
template __global__ void GeneralPermuteKernel(PermuteParams params, - const T* __restrict__ src, T* dst, + const T* __restrict__ src, + T* dst, const size_t main_cnt, const size_t tail_cnt, const size_t offset) { @@ -859,10 +937,12 @@ __global__ void GeneralPermuteKernel(PermuteParams params, // A Gerneral permute method that drectly find the dst data // coordinate in the source data. template -inline void LaunchPermuteKernel(const phi::GPUContext& ctx, const IndexT count, +inline void LaunchPermuteKernel(const phi::GPUContext& ctx, + const IndexT count, const PermuteType perm_type, const std::vector& dims, - const std::vector& perm, const T* src, + const std::vector& perm, + const T* src, T* dst) { size_t main_count = count / VecSize; auto params = PermuteParams(dims, perm); @@ -871,15 +951,13 @@ inline void LaunchPermuteKernel(const phi::GPUContext& ctx, const IndexT count, if (perm_type == PermuteType::kNormalPermute) { size_t tail_count = count - main_count * VecSize; size_t offset = count - tail_count; - GeneralPermuteKernel< - T, IndexT, VecSize, - Rank><<>>( - params, src, dst, main_count, tail_count, offset); + GeneralPermuteKernel + <<>>( + params, src, dst, main_count, tail_count, offset); } else { - VectorizedPermuteKernel< - T, IndexT, VecSize, - Rank><<>>( - params, main_count, src, dst); + VectorizedPermuteKernel + <<>>( + params, main_count, src, dst); } } @@ -889,12 +967,13 @@ inline void LaunchPermuteRankDispatch(const phi::GPUContext& ctx, const PermuteType perm_type, const std::vector& dims, const std::vector& perm, - const T* src, T* dst) { -#define CALL_DISPATCH_RANK(rank) \ - case rank: { \ - LaunchPermuteKernel(ctx, count, perm_type, dims, \ - perm, src, dst); \ - break; \ + const T* src, + T* dst) { +#define CALL_DISPATCH_RANK(rank) \ + case rank: { \ + LaunchPermuteKernel( \ + ctx, count, perm_type, dims, perm, src, dst); \ + break; \ } switch (dims.size()) { @@ -915,7 +994,9 @@ inline void LaunchPermuteRankDispatch(const phi::GPUContext& ctx, // https://developer.nvidia.com/blog/efficient-matrix-transpose-cuda-cc/ template __global__ void BatchTransposeKernel(const T* __restrict__ src_data, - T* dst_data, IndexT rows, IndexT cols) { + T* dst_data, + IndexT rows, + IndexT cols) { using VecT = phi::AlignedVector; __shared__ VecT tile[kTileSize][kShareCol]; @@ -961,10 +1042,13 @@ __global__ void BatchTransposeKernel(const T* __restrict__ src_data, // With the byte limitation of shared_memory, the VecSize shall be restricted // for the type whose byte-size is less than 8. -template 8 ? 1 : Size)> inline void LaunchTransposeKernel(const phi::GPUContext& ctx, - const std::vector& dims, const T* src, + const std::vector& dims, + const T* src, T* dst) { auto rank = dims.size(); IndexT num_batches = (rank == 2) ? 
1 : dims[0]; @@ -976,9 +1060,8 @@ inline void LaunchTransposeKernel(const phi::GPUContext& ctx, dim3 blocks(num_tile_cols, num_tile_rows, num_batches); dim3 threads(kTileSize, kBlockRows, 1); - BatchTransposeKernel<<>>( - src, dst, rows, cols); + BatchTransposeKernel + <<>>(src, dst, rows, cols); } template @@ -987,16 +1070,18 @@ inline void LaunchWithDispatchVecSize(const phi::GPUContext& ctx, const PermuteType perm_type, const std::vector& dims, const std::vector& perm, - const T* src, T* dst, IndexT count) { -#define CALL_DISPATCH_VEC_SIZE(vec_size) \ - case vec_size: { \ - if (perm_type == PermuteType::kTranspose) { \ - LaunchTransposeKernel(ctx, dims, src, dst); \ - } else { \ - LaunchPermuteRankDispatch(ctx, count, perm_type, \ - dims, perm, src, dst); \ - } \ - break; \ + const T* src, + T* dst, + IndexT count) { +#define CALL_DISPATCH_VEC_SIZE(vec_size) \ + case vec_size: { \ + if (perm_type == PermuteType::kTranspose) { \ + LaunchTransposeKernel(ctx, dims, src, dst); \ + } else { \ + LaunchPermuteRankDispatch( \ + ctx, count, perm_type, dims, perm, src, dst); \ + } \ + break; \ } switch (vec_size) { @@ -1014,45 +1099,64 @@ inline void LaunchWithDispatchVecSize(const phi::GPUContext& ctx, template inline void LaunchWithDispatchIndex(const phi::GPUContext& ctx, - const size_t count, const int vec_size, + const size_t count, + const int vec_size, const PermuteType perm_type, const std::vector& dims, - const std::vector& perm, const T* src, + const std::vector& perm, + const T* src, T* dst) { if (count < std::numeric_limits::max()) { - LaunchWithDispatchVecSize(ctx, vec_size, perm_type, dims, perm, - src, dst, + LaunchWithDispatchVecSize(ctx, + vec_size, + perm_type, + dims, + perm, + src, + dst, static_cast(count)); } else { int64_t cnt = static_cast(count); - LaunchWithDispatchVecSize(ctx, vec_size, perm_type, dims, perm, - src, dst, + LaunchWithDispatchVecSize(ctx, + vec_size, + perm_type, + dims, + perm, + src, + dst, static_cast(count)); } } template -inline void SimplifyThenLaunch(const int rank, const DeviceContext& ctx, - const Tensor& in, Tensor* out, +inline void SimplifyThenLaunch(const int rank, + const DeviceContext& ctx, + const Tensor& in, + Tensor* out, const std::vector& perm) { int sm_count = ctx.GetSMCount(); auto src_dims = phi::vectorize(in.dims()); - auto simplifier = DimsSimplifier(sm_count, rank, perm, src_dims, - in.data(), out->data()); + auto simplifier = DimsSimplifier( + sm_count, rank, perm, src_dims, in.data(), out->data()); if (simplifier.GetPermType() == PermuteType::kCopy) { // If perm is [0,1,2,3], then just operate a DtoD copy. 
phi::Copy(ctx, in, ctx.GetPlace(), false, out); } else { - LaunchWithDispatchIndex( - ctx, simplifier.GetCount(), simplifier.GetVecSize(), - simplifier.GetPermType(), simplifier.GetDims(), simplifier.GetPerm(), - in.data(), out->data()); + LaunchWithDispatchIndex(ctx, + simplifier.GetCount(), + simplifier.GetVecSize(), + simplifier.GetPermType(), + simplifier.GetDims(), + simplifier.GetPerm(), + in.data(), + out->data()); } } template -size_t GetTransposeKey(const int rank, const Tensor& in, +size_t GetTransposeKey(const int rank, + const Tensor& in, const std::vector& perm) { auto in_shape = phi::vectorize(in.dims()); return phi::autotune::GetKey( @@ -1060,15 +1164,19 @@ size_t GetTransposeKey(const int rank, const Tensor& in, } template -void TransposeGPUKernelDriver(const phi::GPUContext& dev_ctx, const int rank, +void TransposeGPUKernelDriver(const phi::GPUContext& dev_ctx, + const int rank, const Tensor& in, - const std::vector& perm, Tensor* out) { + const std::vector& perm, + Tensor* out) { PADDLE_ENFORCE_LT( - rank, phi::DDim::kMaxRank, + rank, + phi::DDim::kMaxRank, platform::errors::OutOfRange( "The maximum dimension rank of " "tensor is expected to be less than %d, but here is %d.", - phi::DDim::kMaxRank, rank)); + phi::DDim::kMaxRank, + rank)); auto ret = TransposeSimple::run(dev_ctx, in, perm, out); if (!ret) { diff --git a/paddle/phi/api/lib/data_transform.cc b/paddle/phi/api/lib/data_transform.cc index 4803616812c..4dafc7a7ee5 100644 --- a/paddle/phi/api/lib/data_transform.cc +++ b/paddle/phi/api/lib/data_transform.cc @@ -19,8 +19,8 @@ limitations under the License. */ #include "paddle/phi/api/lib/utils/allocator.h" #include "paddle/phi/backends/all_context.h" #include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/core/tensor_utils.h" #include "paddle/phi/kernels/cast_kernel.h" -#include "paddle/phi/kernels/copy_kernel.h" #include "paddle/phi/kernels/transfer_layout_kernel.h" #include "paddle/fluid/framework/tensor_util.h" diff --git a/paddle/phi/api/lib/tensor_copy.cc b/paddle/phi/api/lib/tensor_copy.cc index fb18a3b05c7..5b0bb52daae 100644 --- a/paddle/phi/api/lib/tensor_copy.cc +++ b/paddle/phi/api/lib/tensor_copy.cc @@ -20,6 +20,7 @@ limitations under the License. */ #include "paddle/phi/core/compat/convert_utils.h" #include "paddle/phi/core/kernel_registry.h" #include "paddle/phi/core/meta_tensor.h" +#include "paddle/phi/core/tensor_utils.h" #include "paddle/phi/infermeta/unary.h" namespace paddle { @@ -31,10 +32,7 @@ void copy(const Tensor& src, const Place& place, bool blocking, Tensor* dst) { kernel_key_set.backend_set | BackendSet(phi::TransToPhiBackend(place)); auto kernel_key = kernel_key_set.GetHighestPriorityKernelKey(); - VLOG(6) << "copy API kernel key: " << kernel_key; - auto kernel = phi::KernelFactory::Instance().SelectKernelOrThrowError( - "copy", kernel_key); - VLOG(6) << "copy API kernel: " << kernel; + VLOG(6) << "start copy. 
"; auto target_place = phi::TransToPhiPlace(kernel_key.backend()); auto& pool = paddle::experimental::DeviceContextPool::Instance(); @@ -47,14 +45,9 @@ void copy(const Tensor& src, const Place& place, bool blocking, Tensor* dst) { phi::MetaTensor meta_out(kernel_out); phi::UnchangedInferMeta(*dense_x, &meta_out); - using kernel_signature = void (*)(const platform::DeviceContext&, - const phi::DenseTensor&, - phi::Place, - bool, - phi::DenseTensor*); + phi::Copy(*dev_ctx, *dense_x, place, blocking, kernel_out); - auto* kernel_fn = kernel.GetVariadicKernelFn(); - (*kernel_fn)(*dev_ctx, *dense_x, place, blocking, kernel_out); + VLOG(6) << "copy finished. "; } } // namespace experimental diff --git a/paddle/phi/api/lib/tensor_method.cc b/paddle/phi/api/lib/tensor_method.cc index fbeeb3332ea..2ead95e11b7 100644 --- a/paddle/phi/api/lib/tensor_method.cc +++ b/paddle/phi/api/lib/tensor_method.cc @@ -19,9 +19,11 @@ limitations under the License. */ #include "paddle/phi/core/compat/convert_utils.h" #include "paddle/phi/core/tensor_base.h" +#include "paddle/phi/api/include/context_pool.h" #include "paddle/phi/api/include/sparse_api.h" #include "paddle/phi/api/lib/api_gen_utils.h" #include "paddle/phi/api/lib/kernel_dispatch.h" +#include "paddle/phi/core/tensor_utils.h" #include "paddle/phi/infermeta/unary.h" // clang-format off @@ -113,9 +115,15 @@ void Tensor::copy_(const Tensor &src, // Deep Copy AutoGrad info from src to self. *autograd_meta_ = *(src.autograd_meta_); } - + kernel_key_set.backend_set = + kernel_key_set.backend_set | + BackendSet(phi::TransToPhiBackend(target_place)); auto kernel_key = kernel_key_set.GetHighestPriorityKernelKey(); - auto *dev_ctx = GetDeviceContextByBackend(kernel_key.backend()); + auto place = phi::TransToPhiPlace(kernel_key.backend()); + auto& pool = paddle::experimental::DeviceContextPool::Instance(); + auto* dev_ctx = pool.GetMutable( + place.GetType() == target_place.GetType() ? 
target_place : place); + Backend kernel_backend = Backend::UNDEFINED; DataLayout kernel_layout = DataLayout::UNDEFINED; DataType kernel_data_type = DataType::UNDEFINED; @@ -135,49 +143,29 @@ void Tensor::copy_(const Tensor &src, } if (kernel_type == KernelType::DENSE_TENSOR_KENREL) { - auto kernel = phi::KernelFactory::Instance().SelectKernelOrThrowError( - "copy", {kernel_backend, kernel_layout, kernel_data_type}); - VLOG(6) << "copy API kernel key: " << kernel_key; - VLOG(6) << "copy API kernel: " << kernel; - using kernel_signature = void (*)(const platform::DeviceContext &, - const phi::DenseTensor &, - phi::Place, - bool, - phi::DenseTensor *); SetKernelOutput(kernel_backend, this); phi::MetaTensor meta_out(impl_.get()); phi::UnchangedInferMeta( MakeMetaTensor( *(std::static_pointer_cast(src.impl_))), &meta_out); - auto *kernel_fn = kernel.GetVariadicKernelFn(); - (*kernel_fn)(*dev_ctx, - (*(std::static_pointer_cast(src.impl_))), - target_place, - blocking, - static_cast(impl_.get())); + phi::Copy(*dev_ctx, + (*(std::static_pointer_cast(src.impl_))), + target_place, + blocking, + static_cast(impl_.get())); } else if (kernel_type == KernelType::SELECTED_ROWS_KENREL) { - auto kernel = phi::KernelFactory::Instance().SelectKernelOrThrowError( - "copy_sr", {kernel_backend, kernel_layout, kernel_data_type}); - VLOG(6) << "copy API kernel key: " << kernel_key; - VLOG(6) << "copy API kernel: " << kernel; - using kernel_signature = void (*)(const platform::DeviceContext &, - const phi::SelectedRows &, - phi::Place, - bool, - phi::SelectedRows *); SetSelectedRowsKernelOutput(kernel_backend, this); phi::MetaTensor meta_out(impl_.get()); phi::UnchangedInferMeta( MakeMetaTensor( *(std::static_pointer_cast(src.impl_))), &meta_out); - auto *kernel_fn = kernel.GetVariadicKernelFn(); - (*kernel_fn)(*dev_ctx, - (*(std::static_pointer_cast(src.impl_))), - target_place, - blocking, - static_cast(impl_.get())); + phi::Copy(*dev_ctx, + (*(std::static_pointer_cast(src.impl_))), + target_place, + blocking, + static_cast(impl_.get())); } else if (kernel_type == KernelType::SPARSE_COO_KERNEL) { auto kernel = phi::KernelFactory::Instance().SelectKernelOrThrowError( "copy_sparse_coo", {kernel_backend, kernel_layout, kernel_data_type}); diff --git a/paddle/phi/core/CMakeLists.txt b/paddle/phi/core/CMakeLists.txt index 8b180a2c2ae..d7ffa1b82f1 100644 --- a/paddle/phi/core/CMakeLists.txt +++ b/paddle/phi/core/CMakeLists.txt @@ -81,3 +81,43 @@ if(WITH_MKLDNN) add_dependencies(dense_tensor mkldnn) add_dependencies(tensor_base mkldnn) endif() + +if(WITH_GPU) + nv_library( + phi_tensor_utils + SRCS tensor_utils.cc + DEPS cpu_context + gpu_context + dense_tensor + selected_rows + malloc + memcpy + device_context) +elseif(WITH_ROCM) + hip_library( + phi_tensor_utils + SRCS tensor_utils.cc + DEPS cpu_context + gpu_context + dense_tensor + selected_rows + malloc + memcpy + device_context) +elseif(WITH_XPU_KP) + xpu_library( + phi_tensor_utils + SRCS tensor_utils.cc + DEPS cpu_context + xpu_context + dense_tensor + selected_rows + malloc + memcpy + device_context) +else() + cc_library( + phi_tensor_utils + SRCS tensor_utils.cc + DEPS cpu_context dense_tensor selected_rows malloc memcpy device_context) +endif() diff --git a/paddle/phi/kernels/gpu/copy_kernel.cu b/paddle/phi/core/tensor_utils.cc similarity index 63% rename from paddle/phi/kernels/gpu/copy_kernel.cu rename to paddle/phi/core/tensor_utils.cc index 16eff5b26e3..f6743a0c184 100644 --- a/paddle/phi/kernels/gpu/copy_kernel.cu +++ b/paddle/phi/core/tensor_utils.cc 
@@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#include "paddle/phi/kernels/copy_kernel.h" +#include "paddle/phi/core/tensor_utils.h" #include "paddle/phi/backends/gpu/gpu_context.h" #include "paddle/phi/common/data_type.h" @@ -43,30 +43,49 @@ void Copy(const Context& dev_ctx, void* dst_ptr = nullptr; if (paddle::platform::is_cpu_place(dst_place)) { dst_ptr = dev_ctx.HostAlloc(dst, src.dtype()); - } else if (paddle::platform::is_cuda_pinned_place(dst_place)) { - // now we only can use mutable_data to Alloc pinned memory here, - // dev_ctx can not alloc pinned memory now - dst_ptr = dst->mutable_data(dst_place, src.dtype()); - } else { +#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) + } else if (paddle::platform::is_gpu_place(dst_place) || + paddle::platform::is_cuda_pinned_place(dst_place)) { dst_ptr = dev_ctx.Alloc( dst, src.dtype(), 0, paddle::platform::is_cuda_pinned_place(dst_place)); +#endif + +#ifdef PADDLE_WITH_XPU + } else if (paddle::platform::is_xpu_place(dst_place)) { + dst_ptr = dev_ctx.Alloc(dst, src.dtype()); +#endif + } + + auto size = src.numel() * paddle::experimental::SizeOf(src.dtype()); + if (UNLIKELY(size) == 0) { + return; } + PADDLE_ENFORCE_EQ( + dst->place(), + dst_place, + phi::errors::Unavailable( + "The Dst Tensor's place and dst_place do not match, Tensor's place " + "place is %s, dst_place is %s.", + dst->place(), + dst_place)); + if (src_ptr == dst_ptr && src_place == dst_place) { VLOG(3) << "Skip copy the same data async from " << src_place << " to " << dst_place; return; } VLOG(4) << "src:" << src_ptr << ", dst:" << dst_ptr; - CHECK(dst->layout() == src.layout()); - auto size = src.numel() * paddle::experimental::SizeOf(src.dtype()); - - if ((paddle::platform::is_cpu_place(src_place) || - paddle::platform::is_cuda_pinned_place(src_place)) && // NOLINT - (paddle::platform::is_cpu_place(dst_place) || - paddle::platform::is_cuda_pinned_place(dst_place))) { + if (paddle::platform::is_cpu_place(src_place) && + paddle::platform::is_cpu_place(dst_place)) { + paddle::memory::Copy(src_place, dst_ptr, src_place, src_ptr, size); +#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) + } else if ((paddle::platform::is_cpu_place(src_place) || + paddle::platform::is_cuda_pinned_place(src_place)) && // NOLINT + (paddle::platform::is_cpu_place(dst_place) || + paddle::platform::is_cuda_pinned_place(dst_place))) { paddle::memory::Copy(dst_place, dst_ptr, src_place, src_ptr, size, nullptr); } else if (paddle::platform::is_gpu_place(src_place) && // NOLINT paddle::platform::is_cpu_place(dst_place)) { @@ -176,13 +195,87 @@ void Copy(const Context& dev_ctx, : reinterpret_cast(dev_ctx).stream(); paddle::memory::Copy( dst_cuda_pinned_place, dst_ptr, src_gpu_place, src_ptr, size, stream); +#endif + } +#ifdef PADDLE_WITH_XPU + else if (paddle::platform::is_xpu_place(src_place) && // NOLINT + paddle::platform::is_cpu_place(dst_place)) { + paddle::memory::Copy(dst_place, dst_ptr, src_place, src_ptr, size); + } else if (paddle::platform::is_cpu_place(src_place) && + paddle::platform::is_xpu_place(dst_place)) { + paddle::memory::Copy(dst_place, dst_ptr, src_place, src_ptr, size); + } else if (paddle::platform::is_xpu_place(src_place) && + paddle::platform::is_xpu_place(dst_place)) { + if (src_ptr == dst_ptr) { + VLOG(3) << "Skip copy the same data async from " << src_place << " to " + << dst_place; + return; + } + 
paddle::memory::Copy(dst_place, dst_ptr, src_place, src_ptr, size); } else { - PADDLE_THROW(phi::errors::InvalidArgument( - "Place type error. Please check the place of src and dst Tensor.")); + PADDLE_THROW(phi::errors::Unimplemented( + "Copy from %s to %s is not supported.", src_place, dst_place)); } +#endif } -} // namespace phi +template +void Copy(const Context& dev_ctx, + const SelectedRows& src, + Place dst_place, + bool blocking, + SelectedRows* dst) { + if (src.value().Holder() != dst->value().Holder() || + src.value().data() != dst->value().data()) { + dst->set_rows(src.rows()); + dst->set_height(src.height()); + } + Copy( + dev_ctx, src.value(), dst_place, blocking, dst->mutable_value()); +} + +template void Copy(const CPUContext& dev_ctx, + const DenseTensor& src, + Place dst_place, + bool blocking, + DenseTensor* dst); + +template void Copy(const DeviceContext& dev_ctx, + const DenseTensor& src, + Place dst_place, + bool blocking, + DenseTensor* dst); + +template void Copy(const CPUContext& dev_ctx, + const SelectedRows& src, + Place dst_place, + bool blocking, + SelectedRows* dst); +template void Copy(const DeviceContext& dev_ctx, + const SelectedRows& src, + Place dst_place, + bool blocking, + SelectedRows* dst); -PD_REGISTER_GENERAL_KERNEL( - copy, GPU, ALL_LAYOUT, phi::Copy, ALL_DTYPE) {} +#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) +template void Copy(const GPUContext& dev_ctx, + const DenseTensor& src, + Place dst_place, + bool blocking, + DenseTensor* dst); +template void Copy(const GPUContext& dev_ctx, + const SelectedRows& src, + Place dst_place, + bool blocking, + SelectedRows* dst); +#endif + +#ifdef PADDLE_WITH_XPU +template void Copy(const XPUContext& dev_ctx, + const DenseTensor& src, + Place dst_place, + bool blocking, + DenseTensor* dst); +#endif + +} // namespace phi diff --git a/paddle/phi/core/tensor_utils.h b/paddle/phi/core/tensor_utils.h index abf8aeff4d3..1c490fd5393 100644 --- a/paddle/phi/core/tensor_utils.h +++ b/paddle/phi/core/tensor_utils.h @@ -15,6 +15,7 @@ limitations under the License. 
*/ #pragma once #include "paddle/phi/core/dense_tensor.h" +#include "paddle/phi/core/selected_rows.h" #include "paddle/phi/core/tensor_meta.h" namespace phi { @@ -70,4 +71,18 @@ class DenseTensorUtils { } }; +template +void Copy(const Context& dev_ctx, + const DenseTensor& src, + Place dst_place, + bool blocking, + DenseTensor* dst); + +template +void Copy(const Context& dev_ctx, + const SelectedRows& src, + Place dst_place, + bool blocking, + SelectedRows* dst); + } // namespace phi diff --git a/paddle/phi/kernels/CMakeLists.txt b/paddle/phi/kernels/CMakeLists.txt index ad71823e3c0..1611c89667c 100644 --- a/paddle/phi/kernels/CMakeLists.txt +++ b/paddle/phi/kernels/CMakeLists.txt @@ -29,7 +29,8 @@ set(COMMON_KERNEL_DEPS arg_map_context convert_utils lod_utils - custom_kernel) + custom_kernel + phi_tensor_utils) set(COMMON_KERNEL_DEPS ${COMMON_KERNEL_DEPS} eigen_function diff --git a/paddle/phi/kernels/assign_kernel.cc b/paddle/phi/kernels/assign_kernel.cc index 3d8e4db08bb..16e9bb384b5 100644 --- a/paddle/phi/kernels/assign_kernel.cc +++ b/paddle/phi/kernels/assign_kernel.cc @@ -16,7 +16,7 @@ #include "paddle/fluid/framework/tensor_util.h" #include "paddle/phi/core/kernel_registry.h" -#include "paddle/phi/kernels/copy_kernel.h" +#include "paddle/phi/core/tensor_utils.h" #include "paddle/utils/optional.h" namespace phi { diff --git a/paddle/phi/kernels/autotune/auto_tune_test.cu b/paddle/phi/kernels/autotune/auto_tune_test.cu index 8701a0572fc..d80790dbf2c 100644 --- a/paddle/phi/kernels/autotune/auto_tune_test.cu +++ b/paddle/phi/kernels/autotune/auto_tune_test.cu @@ -20,8 +20,8 @@ #include "paddle/phi/backends/gpu/gpu_launch_config.h" #include "paddle/phi/core/dense_tensor.h" #include "paddle/phi/core/tensor_meta.h" +#include "paddle/phi/core/tensor_utils.h" #include "paddle/phi/kernels/autotune/auto_tune_base.h" -#include "paddle/phi/kernels/copy_kernel.h" #include "paddle/phi/kernels/funcs/aligned_vector.h" namespace tune = phi::autotune; diff --git a/paddle/phi/kernels/copy_kernel.h b/paddle/phi/kernels/copy_kernel.h deleted file mode 100644 index 21b59d8d11b..00000000000 --- a/paddle/phi/kernels/copy_kernel.h +++ /dev/null @@ -1,27 +0,0 @@ -/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
*/ - -#pragma once - -#include "paddle/phi/core/dense_tensor.h" - -namespace phi { - -template -void Copy(const Context& dev_ctx, - const DenseTensor& src, - Place dst_place, - bool blocking, - DenseTensor* dst); -} // namespace phi diff --git a/paddle/phi/kernels/cpu/adam_kernel.cc b/paddle/phi/kernels/cpu/adam_kernel.cc index 339d690310f..03e2a539640 100644 --- a/paddle/phi/kernels/cpu/adam_kernel.cc +++ b/paddle/phi/kernels/cpu/adam_kernel.cc @@ -20,7 +20,7 @@ #include "paddle/fluid/operators/jit/kernels.h" #include "paddle/phi/backends/cpu/cpu_context.h" #include "paddle/phi/core/kernel_registry.h" -#include "paddle/phi/kernels/copy_kernel.h" +#include "paddle/phi/core/tensor_utils.h" #include "paddle/phi/kernels/funcs/adam_functors.h" DECLARE_int32(inner_op_parallelism); diff --git a/paddle/phi/kernels/cpu/copy_kernel.cc b/paddle/phi/kernels/cpu/copy_kernel.cc deleted file mode 100644 index fa11fd05bf1..00000000000 --- a/paddle/phi/kernels/cpu/copy_kernel.cc +++ /dev/null @@ -1,61 +0,0 @@ -/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#include "paddle/phi/kernels/copy_kernel.h" - -#include "paddle/phi/backends/cpu/cpu_context.h" -#include "paddle/phi/common/data_type.h" -#include "paddle/phi/core/compat/convert_utils.h" -#include "paddle/phi/core/kernel_registry.h" - -// See Note [ Why still include the fluid headers? ] -#include "paddle/fluid/memory/memcpy.h" - -namespace phi { - -// NOTE(chenweihang): blocking is useless in cpu kernel -template -void Copy(const Context& dev_ctx, - const DenseTensor& src, - Place dst_place, - bool blocking, - DenseTensor* dst) { - auto* src_ptr = src.data(); - const auto& src_place = src.place(); - - VLOG(3) << "TensorCopy " << src.dims() << " from " << src.place() << " to " - << src_place; - - dst->Resize(src.dims()); - auto* dst_ptr = dev_ctx.HostAlloc(dst, src.dtype()); - - if (src_ptr == dst_ptr) { - VLOG(3) << "Skip copy the same data async from " << src_place << " to " - << src_place; - return; - } - VLOG(4) << "src:" << src_ptr << ", dst:" << dst_ptr; - CHECK(dst->layout() == src.layout()); - - auto size = src.numel() * paddle::experimental::SizeOf(src.dtype()); - - if (paddle::platform::is_cpu_place(src_place)) { - paddle::memory::Copy(src_place, dst_ptr, src_place, src_ptr, size); - } -} - -} // namespace phi - -PD_REGISTER_GENERAL_KERNEL( - copy, CPU, ALL_LAYOUT, phi::Copy, ALL_DTYPE) {} diff --git a/paddle/phi/kernels/cpu/cross_entropy_grad_kernel.cc b/paddle/phi/kernels/cpu/cross_entropy_grad_kernel.cc index 021fdac2253..305a9accc49 100644 --- a/paddle/phi/kernels/cpu/cross_entropy_grad_kernel.cc +++ b/paddle/phi/kernels/cpu/cross_entropy_grad_kernel.cc @@ -16,8 +16,8 @@ limitations under the License. 
*/ #include "paddle/phi/backends/cpu/cpu_context.h" #include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/core/tensor_utils.h" #include "paddle/phi/core/visit_type.h" -#include "paddle/phi/kernels/copy_kernel.h" #include "paddle/phi/kernels/funcs/axis_utils.h" #include "paddle/phi/kernels/funcs/eigen/common.h" diff --git a/paddle/phi/kernels/cpu/cross_entropy_kernel.cc b/paddle/phi/kernels/cpu/cross_entropy_kernel.cc index bd3eb3eb754..27675fa8b5a 100644 --- a/paddle/phi/kernels/cpu/cross_entropy_kernel.cc +++ b/paddle/phi/kernels/cpu/cross_entropy_kernel.cc @@ -17,7 +17,7 @@ limitations under the License. */ #include "paddle/fluid/operators/math/cross_entropy.h" #include "paddle/phi/backends/cpu/cpu_context.h" #include "paddle/phi/core/kernel_registry.h" -#include "paddle/phi/kernels/copy_kernel.h" +#include "paddle/phi/core/tensor_utils.h" #include "paddle/phi/kernels/funcs/axis_utils.h" #include "paddle/phi/kernels/funcs/math_function.h" #include "paddle/phi/kernels/softmax_kernel.h" diff --git a/paddle/phi/kernels/cpu/elementwise_divide_grad_kernel.cc b/paddle/phi/kernels/cpu/elementwise_divide_grad_kernel.cc index b6541ec0e68..a0e2611f92c 100644 --- a/paddle/phi/kernels/cpu/elementwise_divide_grad_kernel.cc +++ b/paddle/phi/kernels/cpu/elementwise_divide_grad_kernel.cc @@ -16,7 +16,7 @@ #include "paddle/phi/backends/cpu/cpu_context.h" #include "paddle/phi/core/kernel_registry.h" -#include "paddle/phi/kernels/copy_kernel.h" +#include "paddle/phi/core/tensor_utils.h" #include "paddle/phi/kernels/cpu/elementwise_grad.h" #include "paddle/phi/kernels/funcs/elementwise_functor.h" #include "paddle/phi/kernels/impl/elementwise_grad_kernel_impl.h" diff --git a/paddle/phi/kernels/cpu/elementwise_grad_kernel.cc b/paddle/phi/kernels/cpu/elementwise_grad_kernel.cc index ee384cc7519..287d41b5455 100644 --- a/paddle/phi/kernels/cpu/elementwise_grad_kernel.cc +++ b/paddle/phi/kernels/cpu/elementwise_grad_kernel.cc @@ -16,7 +16,7 @@ #include "paddle/phi/backends/cpu/cpu_context.h" #include "paddle/phi/core/kernel_registry.h" -#include "paddle/phi/kernels/copy_kernel.h" +#include "paddle/phi/core/tensor_utils.h" #include "paddle/phi/kernels/cpu/elementwise_grad.h" #include "paddle/phi/kernels/funcs/elementwise_functor.h" #include "paddle/phi/kernels/impl/elementwise_grad_kernel_impl.h" diff --git a/paddle/phi/kernels/cpu/elementwise_multiply_grad_kernel.cc b/paddle/phi/kernels/cpu/elementwise_multiply_grad_kernel.cc index 6055541c805..4cef9fef460 100644 --- a/paddle/phi/kernels/cpu/elementwise_multiply_grad_kernel.cc +++ b/paddle/phi/kernels/cpu/elementwise_multiply_grad_kernel.cc @@ -16,7 +16,7 @@ #include "paddle/phi/backends/cpu/cpu_context.h" #include "paddle/phi/core/kernel_registry.h" -#include "paddle/phi/kernels/copy_kernel.h" +#include "paddle/phi/core/tensor_utils.h" #include "paddle/phi/kernels/cpu/elementwise_grad.h" #include "paddle/phi/kernels/funcs/elementwise_functor.h" #include "paddle/phi/kernels/impl/elementwise_grad_kernel_impl.h" diff --git a/paddle/phi/kernels/cpu/index_select_impl.h b/paddle/phi/kernels/cpu/index_select_impl.h index 163174580ff..24b561c0336 100644 --- a/paddle/phi/kernels/cpu/index_select_impl.h +++ b/paddle/phi/kernels/cpu/index_select_impl.h @@ -15,7 +15,7 @@ #pragma once #include "paddle/phi/core/dense_tensor.h" -#include "paddle/phi/kernels/copy_kernel.h" +#include "paddle/phi/core/tensor_utils.h" #include "paddle/phi/kernels/funcs/blas/blas.h" #include "paddle/phi/kernels/funcs/eigen/common.h" #include "paddle/phi/kernels/funcs/math_function.h" 
diff --git a/paddle/phi/kernels/cpu/kthvalue_grad_kernel.cc b/paddle/phi/kernels/cpu/kthvalue_grad_kernel.cc index de7dfd167b7..386d41984b0 100644 --- a/paddle/phi/kernels/cpu/kthvalue_grad_kernel.cc +++ b/paddle/phi/kernels/cpu/kthvalue_grad_kernel.cc @@ -16,7 +16,7 @@ #include "paddle/phi/backends/cpu/cpu_context.h" #include "paddle/phi/core/kernel_registry.h" -#include "paddle/phi/kernels/copy_kernel.h" +#include "paddle/phi/core/tensor_utils.h" #include "paddle/phi/kernels/funcs/eigen/common.h" #include "paddle/phi/kernels/funcs/math_function.h" diff --git a/paddle/phi/kernels/cpu/layer_norm_grad_kernel.cc b/paddle/phi/kernels/cpu/layer_norm_grad_kernel.cc index 081a32b4f24..58d69cb3454 100644 --- a/paddle/phi/kernels/cpu/layer_norm_grad_kernel.cc +++ b/paddle/phi/kernels/cpu/layer_norm_grad_kernel.cc @@ -22,7 +22,7 @@ #endif #include "paddle/phi/backends/cpu/cpu_context.h" #include "paddle/phi/core/kernel_registry.h" -#include "paddle/phi/kernels/copy_kernel.h" +#include "paddle/phi/core/tensor_utils.h" #include "paddle/phi/kernels/funcs/blas/blas.h" #include "paddle/phi/kernels/funcs/elementwise_base.h" #include "paddle/phi/kernels/funcs/elementwise_functor.h" diff --git a/paddle/phi/kernels/cpu/mode_grad_kernel.cc b/paddle/phi/kernels/cpu/mode_grad_kernel.cc index ca813c1757e..05675cf1ab4 100644 --- a/paddle/phi/kernels/cpu/mode_grad_kernel.cc +++ b/paddle/phi/kernels/cpu/mode_grad_kernel.cc @@ -16,7 +16,7 @@ #include "paddle/phi/backends/cpu/cpu_context.h" #include "paddle/phi/core/kernel_registry.h" -#include "paddle/phi/kernels/copy_kernel.h" +#include "paddle/phi/core/tensor_utils.h" #include "paddle/phi/kernels/funcs/mode.h" namespace phi { diff --git a/paddle/phi/kernels/cpu/put_along_axis_grad_kernel.cc b/paddle/phi/kernels/cpu/put_along_axis_grad_kernel.cc index e94d09e0337..ca92fcee121 100644 --- a/paddle/phi/kernels/cpu/put_along_axis_grad_kernel.cc +++ b/paddle/phi/kernels/cpu/put_along_axis_grad_kernel.cc @@ -19,7 +19,7 @@ #include "paddle/fluid/platform/place.h" #include "paddle/phi/backends/cpu/cpu_context.h" #include "paddle/phi/core/kernel_registry.h" -#include "paddle/phi/kernels/copy_kernel.h" +#include "paddle/phi/core/tensor_utils.h" namespace phi { diff --git a/paddle/phi/kernels/cpu/put_along_axis_kernel.cc b/paddle/phi/kernels/cpu/put_along_axis_kernel.cc index 83c9a915ee6..a297843b0c7 100644 --- a/paddle/phi/kernels/cpu/put_along_axis_kernel.cc +++ b/paddle/phi/kernels/cpu/put_along_axis_kernel.cc @@ -19,7 +19,7 @@ #include "paddle/fluid/platform/place.h" #include "paddle/phi/backends/cpu/cpu_context.h" #include "paddle/phi/core/kernel_registry.h" -#include "paddle/phi/kernels/copy_kernel.h" +#include "paddle/phi/core/tensor_utils.h" namespace phi { diff --git a/paddle/phi/kernels/cpu/rnn_functor.h b/paddle/phi/kernels/cpu/rnn_functor.h index 911814647d6..e6139b45272 100644 --- a/paddle/phi/kernels/cpu/rnn_functor.h +++ b/paddle/phi/kernels/cpu/rnn_functor.h @@ -17,7 +17,7 @@ #include "paddle/fluid/framework/generator.h" #include "paddle/fluid/operators/utils.h" #include "paddle/phi/core/dense_tensor.h" -#include "paddle/phi/kernels/copy_kernel.h" +#include "paddle/phi/core/tensor_utils.h" #include "paddle/phi/kernels/empty_kernel.h" #include "paddle/phi/kernels/funcs/eigen/common.h" #include "paddle/phi/kernels/funcs/eigen/eigen_function.h" diff --git a/paddle/phi/kernels/cpu/rnn_grad_kernel.cc b/paddle/phi/kernels/cpu/rnn_grad_kernel.cc index 1cd4add7d50..b4ec6652eb9 100644 --- a/paddle/phi/kernels/cpu/rnn_grad_kernel.cc +++ 
b/paddle/phi/kernels/cpu/rnn_grad_kernel.cc @@ -16,7 +16,7 @@ #include "paddle/phi/backends/cpu/cpu_context.h" #include "paddle/phi/core/kernel_registry.h" -#include "paddle/phi/kernels/copy_kernel.h" +#include "paddle/phi/core/tensor_utils.h" #include "paddle/phi/kernels/cpu/rnn_functor.h" #include "paddle/phi/kernels/funcs/activation_functor.h" #include "paddle/phi/kernels/funcs/blas/blas.h" diff --git a/paddle/phi/kernels/cpu/rnn_kernel.cc b/paddle/phi/kernels/cpu/rnn_kernel.cc index e2e784b2943..c46bba8c23f 100644 --- a/paddle/phi/kernels/cpu/rnn_kernel.cc +++ b/paddle/phi/kernels/cpu/rnn_kernel.cc @@ -16,7 +16,7 @@ #include "paddle/phi/backends/cpu/cpu_context.h" #include "paddle/phi/core/kernel_registry.h" -#include "paddle/phi/kernels/copy_kernel.h" +#include "paddle/phi/core/tensor_utils.h" #include "paddle/phi/kernels/cpu/rnn_functor.h" #include "paddle/phi/kernels/empty_kernel.h" #include "paddle/phi/kernels/full_kernel.h" diff --git a/paddle/phi/kernels/cpu/scatter_grad_kernel.cc b/paddle/phi/kernels/cpu/scatter_grad_kernel.cc index f09015f24a1..9fb1136e766 100644 --- a/paddle/phi/kernels/cpu/scatter_grad_kernel.cc +++ b/paddle/phi/kernels/cpu/scatter_grad_kernel.cc @@ -16,7 +16,7 @@ #include "paddle/phi/backends/cpu/cpu_context.h" #include "paddle/phi/core/kernel_registry.h" -#include "paddle/phi/kernels/copy_kernel.h" +#include "paddle/phi/core/tensor_utils.h" #include "paddle/phi/kernels/funcs/gather.h" #include "paddle/phi/kernels/funcs/scatter.h" diff --git a/paddle/phi/kernels/cpu/scatter_kernel.cc b/paddle/phi/kernels/cpu/scatter_kernel.cc index 7032c3bb5a3..2c3e8a2f31d 100644 --- a/paddle/phi/kernels/cpu/scatter_kernel.cc +++ b/paddle/phi/kernels/cpu/scatter_kernel.cc @@ -16,7 +16,7 @@ #include "paddle/phi/backends/cpu/cpu_context.h" #include "paddle/phi/core/kernel_registry.h" -#include "paddle/phi/kernels/copy_kernel.h" +#include "paddle/phi/core/tensor_utils.h" #include "paddle/phi/kernels/funcs/scatter.h" namespace phi { diff --git a/paddle/phi/kernels/cpu/scatter_nd_add_grad_kernel.cc b/paddle/phi/kernels/cpu/scatter_nd_add_grad_kernel.cc index 7c3665c5d2e..844e6370caf 100644 --- a/paddle/phi/kernels/cpu/scatter_nd_add_grad_kernel.cc +++ b/paddle/phi/kernels/cpu/scatter_nd_add_grad_kernel.cc @@ -16,7 +16,7 @@ #include "paddle/phi/backends/cpu/cpu_context.h" #include "paddle/phi/core/kernel_registry.h" -#include "paddle/phi/kernels/copy_kernel.h" +#include "paddle/phi/core/tensor_utils.h" #include "paddle/phi/kernels/funcs/gather.h" namespace phi { diff --git a/paddle/phi/kernels/cpu/scatter_nd_add_kernel.cc b/paddle/phi/kernels/cpu/scatter_nd_add_kernel.cc index 31e2f4c7161..dcdec2343fb 100644 --- a/paddle/phi/kernels/cpu/scatter_nd_add_kernel.cc +++ b/paddle/phi/kernels/cpu/scatter_nd_add_kernel.cc @@ -16,7 +16,7 @@ #include "paddle/phi/backends/cpu/cpu_context.h" #include "paddle/phi/core/kernel_registry.h" -#include "paddle/phi/kernels/copy_kernel.h" +#include "paddle/phi/core/tensor_utils.h" #include "paddle/phi/kernels/funcs/scatter.h" namespace phi { diff --git a/paddle/phi/kernels/cpu/viterbi_decode_kernel.cc b/paddle/phi/kernels/cpu/viterbi_decode_kernel.cc index c98a098aa0e..ae6bb5ae4fc 100644 --- a/paddle/phi/kernels/cpu/viterbi_decode_kernel.cc +++ b/paddle/phi/kernels/cpu/viterbi_decode_kernel.cc @@ -21,7 +21,7 @@ #include "paddle/phi/backends/cpu/cpu_context.h" #include "paddle/phi/core/kernel_registry.h" -#include "paddle/phi/kernels/copy_kernel.h" +#include "paddle/phi/core/tensor_utils.h" #include "paddle/phi/kernels/empty_kernel.h" #include 
"paddle/phi/kernels/funcs/compare_functors.h" #include "paddle/phi/kernels/funcs/concat_and_split_functor.h" diff --git a/paddle/phi/kernels/flatten_grad_kernel.cc b/paddle/phi/kernels/flatten_grad_kernel.cc index 73d963f606e..031f4afe98b 100644 --- a/paddle/phi/kernels/flatten_grad_kernel.cc +++ b/paddle/phi/kernels/flatten_grad_kernel.cc @@ -16,7 +16,7 @@ #include "paddle/phi/backends/all_context.h" #include "paddle/phi/core/kernel_registry.h" -#include "paddle/phi/kernels/copy_kernel.h" +#include "paddle/phi/core/tensor_utils.h" namespace phi { diff --git a/paddle/phi/kernels/flatten_kernel.cc b/paddle/phi/kernels/flatten_kernel.cc index 006d3438288..58ba3d70a34 100644 --- a/paddle/phi/kernels/flatten_kernel.cc +++ b/paddle/phi/kernels/flatten_kernel.cc @@ -16,8 +16,8 @@ #include "paddle/phi/backends/all_context.h" #include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/core/tensor_utils.h" #include "paddle/phi/infermeta/unary.h" -#include "paddle/phi/kernels/copy_kernel.h" #include "paddle/phi/kernels/funcs/common_shape.h" namespace phi { diff --git a/paddle/phi/kernels/funcs/mode.h b/paddle/phi/kernels/funcs/mode.h index 3bd6c19545e..632b0ce7e15 100644 --- a/paddle/phi/kernels/funcs/mode.h +++ b/paddle/phi/kernels/funcs/mode.h @@ -35,7 +35,7 @@ #include "paddle/phi/backends/gpu/gpu_context.h" #include "paddle/phi/core/dense_tensor.h" -#include "paddle/phi/kernels/copy_kernel.h" +#include "paddle/phi/core/tensor_utils.h" #include "paddle/phi/kernels/funcs/eigen/common.h" #include "paddle/phi/kernels/funcs/eigen/eigen_function.h" #include "paddle/phi/kernels/funcs/math_function.h" diff --git a/paddle/phi/kernels/funcs/strided_slice.h b/paddle/phi/kernels/funcs/strided_slice.h index c39a9694e18..4d045bdeb59 100644 --- a/paddle/phi/kernels/funcs/strided_slice.h +++ b/paddle/phi/kernels/funcs/strided_slice.h @@ -20,7 +20,7 @@ #include "paddle/phi/core/ddim.h" #include "paddle/phi/core/dense_tensor.h" #include "paddle/phi/core/enforce.h" -#include "paddle/phi/kernels/copy_kernel.h" +#include "paddle/phi/core/tensor_utils.h" #include "paddle/phi/kernels/funcs/eigen/common.h" #include "paddle/phi/kernels/funcs/eigen/eigen_function.h" #include "paddle/phi/kernels/funcs/math_function.h" diff --git a/paddle/phi/kernels/gpu/adam_kernel.cu b/paddle/phi/kernels/gpu/adam_kernel.cu index 1322428270d..59aa4cf597e 100644 --- a/paddle/phi/kernels/gpu/adam_kernel.cu +++ b/paddle/phi/kernels/gpu/adam_kernel.cu @@ -24,7 +24,7 @@ #include "paddle/phi/common/amp_type_traits.h" #include "paddle/phi/common/float16.h" #include "paddle/phi/core/kernel_registry.h" -#include "paddle/phi/kernels/copy_kernel.h" +#include "paddle/phi/core/tensor_utils.h" #include "paddle/phi/kernels/funcs/adam_functors.h" #include "paddle/phi/kernels/funcs/for_range.h" diff --git a/paddle/phi/kernels/gpu/adamw_kernel.cu b/paddle/phi/kernels/gpu/adamw_kernel.cu index cead67fd39a..9ce4d229f10 100644 --- a/paddle/phi/kernels/gpu/adamw_kernel.cu +++ b/paddle/phi/kernels/gpu/adamw_kernel.cu @@ -24,7 +24,7 @@ #include "paddle/phi/common/amp_type_traits.h" #include "paddle/phi/common/float16.h" #include "paddle/phi/core/kernel_registry.h" -#include "paddle/phi/kernels/copy_kernel.h" +#include "paddle/phi/core/tensor_utils.h" #include "paddle/phi/kernels/funcs/adam_functors.h" #include "paddle/phi/kernels/funcs/for_range.h" diff --git a/paddle/phi/kernels/gpu/arange_kernel.cu b/paddle/phi/kernels/gpu/arange_kernel.cu index 9ea0d7c5393..858191c44ee 100644 --- a/paddle/phi/kernels/gpu/arange_kernel.cu +++ 
b/paddle/phi/kernels/gpu/arange_kernel.cu @@ -16,7 +16,7 @@ #include "paddle/phi/backends/gpu/gpu_context.h" #include "paddle/phi/core/kernel_registry.h" -#include "paddle/phi/kernels/copy_kernel.h" +#include "paddle/phi/core/tensor_utils.h" #include "paddle/phi/kernels/funcs/range_function.h" namespace phi { diff --git a/paddle/phi/kernels/gpu/cross_entropy_grad_kernel.cu b/paddle/phi/kernels/gpu/cross_entropy_grad_kernel.cu index 94d91cbcbbd..5d40304c5e0 100644 --- a/paddle/phi/kernels/gpu/cross_entropy_grad_kernel.cu +++ b/paddle/phi/kernels/gpu/cross_entropy_grad_kernel.cu @@ -28,8 +28,8 @@ namespace cub = hipcub; #include "paddle/fluid/platform/device/gpu/gpu_dnn.h" #include "paddle/phi/common/amp_type_traits.h" #include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/core/tensor_utils.h" #include "paddle/phi/core/visit_type.h" -#include "paddle/phi/kernels/copy_kernel.h" #include "paddle/phi/kernels/funcs/axis_utils.h" #include "paddle/phi/kernels/funcs/for_range.h" #include "paddle/phi/kernels/funcs/math_function.h" diff --git a/paddle/phi/kernels/gpu/cross_entropy_kernel.cu b/paddle/phi/kernels/gpu/cross_entropy_kernel.cu index 75a4658ee7d..1a4559d5cd6 100644 --- a/paddle/phi/kernels/gpu/cross_entropy_kernel.cu +++ b/paddle/phi/kernels/gpu/cross_entropy_kernel.cu @@ -28,8 +28,8 @@ namespace cub = hipcub; #include "paddle/fluid/platform/device/gpu/gpu_dnn.h" #include "paddle/phi/common/amp_type_traits.h" #include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/core/tensor_utils.h" #include "paddle/phi/core/visit_type.h" -#include "paddle/phi/kernels/copy_kernel.h" #include "paddle/phi/kernels/funcs/axis_utils.h" #include "paddle/phi/kernels/funcs/for_range.h" #include "paddle/phi/kernels/funcs/math_function.h" diff --git a/paddle/phi/kernels/gpu/elementwise_add_grad_kernel.cu b/paddle/phi/kernels/gpu/elementwise_add_grad_kernel.cu index 517fbcba158..26ddb68c4b1 100644 --- a/paddle/phi/kernels/gpu/elementwise_add_grad_kernel.cu +++ b/paddle/phi/kernels/gpu/elementwise_add_grad_kernel.cu @@ -19,7 +19,7 @@ #include "paddle/phi/common/complex.h" #include "paddle/phi/common/float16.h" #include "paddle/phi/core/kernel_registry.h" -#include "paddle/phi/kernels/copy_kernel.h" +#include "paddle/phi/core/tensor_utils.h" #include "paddle/phi/kernels/funcs/elementwise_functor.h" #include "paddle/phi/kernels/gpu/elementwise_grad.h" #include "paddle/phi/kernels/impl/elementwise_grad_kernel_impl.h" diff --git a/paddle/phi/kernels/gpu/elementwise_grad.h b/paddle/phi/kernels/gpu/elementwise_grad.h index 9c1ced3c1bd..e8f01be8973 100644 --- a/paddle/phi/kernels/gpu/elementwise_grad.h +++ b/paddle/phi/kernels/gpu/elementwise_grad.h @@ -15,7 +15,7 @@ limitations under the License. 
*/ #pragma once #include "paddle/phi/common/place.h" -#include "paddle/phi/kernels/copy_kernel.h" +#include "paddle/phi/core/tensor_utils.h" #include "paddle/phi/kernels/funcs/broadcast_function.h" #include "paddle/phi/kernels/funcs/elementwise_grad_base.h" #include "paddle/phi/kernels/funcs/reduce_function.h" diff --git a/paddle/phi/kernels/gpu/elementwise_grad_kernel.cu b/paddle/phi/kernels/gpu/elementwise_grad_kernel.cu index 3e7430fd84e..4921cf884c4 100644 --- a/paddle/phi/kernels/gpu/elementwise_grad_kernel.cu +++ b/paddle/phi/kernels/gpu/elementwise_grad_kernel.cu @@ -19,7 +19,7 @@ #include "paddle/phi/common/complex.h" #include "paddle/phi/common/float16.h" #include "paddle/phi/core/kernel_registry.h" -#include "paddle/phi/kernels/copy_kernel.h" +#include "paddle/phi/core/tensor_utils.h" #include "paddle/phi/kernels/funcs/elementwise_functor.h" #include "paddle/phi/kernels/gpu/elementwise_grad.h" #include "paddle/phi/kernels/impl/elementwise_grad_kernel_impl.h" diff --git a/paddle/phi/kernels/gpu/elementwise_subtract_grad_kernel.cu b/paddle/phi/kernels/gpu/elementwise_subtract_grad_kernel.cu index 2edf7a132ed..376b2ec8424 100644 --- a/paddle/phi/kernels/gpu/elementwise_subtract_grad_kernel.cu +++ b/paddle/phi/kernels/gpu/elementwise_subtract_grad_kernel.cu @@ -17,7 +17,7 @@ #include "paddle/phi/common/complex.h" #include "paddle/phi/common/float16.h" #include "paddle/phi/core/kernel_registry.h" -#include "paddle/phi/kernels/copy_kernel.h" +#include "paddle/phi/core/tensor_utils.h" #include "paddle/phi/kernels/elementwise_grad_kernel.h" #include "paddle/phi/kernels/funcs/elementwise_functor.h" #include "paddle/phi/kernels/gpu/elementwise_grad.h" diff --git a/paddle/phi/kernels/gpu/instance_norm_grad_kernel.cu b/paddle/phi/kernels/gpu/instance_norm_grad_kernel.cu index d7ea2340afc..35ac4233f37 100644 --- a/paddle/phi/kernels/gpu/instance_norm_grad_kernel.cu +++ b/paddle/phi/kernels/gpu/instance_norm_grad_kernel.cu @@ -17,7 +17,7 @@ #include "paddle/phi/backends/gpu/gpu_context.h" #include "paddle/phi/common/layout.h" #include "paddle/phi/core/kernel_registry.h" -#include "paddle/phi/kernels/copy_kernel.h" +#include "paddle/phi/core/tensor_utils.h" #include "paddle/phi/kernels/funcs/math_function.h" #include "paddle/phi/kernels/funcs/norm_utils.h" #include "paddle/phi/kernels/gpu/instance_norm_utils.h" diff --git a/paddle/phi/kernels/gpu/linspace_kernel.cu b/paddle/phi/kernels/gpu/linspace_kernel.cu index 66a3f833d27..9db11381cbc 100644 --- a/paddle/phi/kernels/gpu/linspace_kernel.cu +++ b/paddle/phi/kernels/gpu/linspace_kernel.cu @@ -17,7 +17,7 @@ #include "paddle/fluid/platform/device/gpu/gpu_primitives.h" #include "paddle/phi/backends/gpu/gpu_context.h" #include "paddle/phi/core/kernel_registry.h" -#include "paddle/phi/kernels/copy_kernel.h" +#include "paddle/phi/core/tensor_utils.h" #include "paddle/phi/kernels/funcs/math_function.h" namespace phi { diff --git a/paddle/phi/kernels/gpu/logspace_kernel.cu b/paddle/phi/kernels/gpu/logspace_kernel.cu index 95a196fe1b2..b5e4904fdcf 100644 --- a/paddle/phi/kernels/gpu/logspace_kernel.cu +++ b/paddle/phi/kernels/gpu/logspace_kernel.cu @@ -16,7 +16,7 @@ #include "paddle/phi/backends/gpu/gpu_context.h" #include "paddle/phi/core/kernel_registry.h" -#include "paddle/phi/kernels/copy_kernel.h" +#include "paddle/phi/core/tensor_utils.h" #include "paddle/phi/kernels/funcs/data_type_transform.h" #include "paddle/phi/kernels/funcs/math_function.h" diff --git a/paddle/phi/kernels/gpu/psroi_pool_grad_kernel.cu 
b/paddle/phi/kernels/gpu/psroi_pool_grad_kernel.cu index 8b58340efd5..6ecaaef1870 100644 --- a/paddle/phi/kernels/gpu/psroi_pool_grad_kernel.cu +++ b/paddle/phi/kernels/gpu/psroi_pool_grad_kernel.cu @@ -19,7 +19,7 @@ #include "paddle/fluid/platform/device/gpu/gpu_primitives.h" #include "paddle/phi/common/place.h" #include "paddle/phi/core/kernel_registry.h" -#include "paddle/phi/kernels/copy_kernel.h" +#include "paddle/phi/core/tensor_utils.h" #include "paddle/phi/kernels/funcs/math_function.h" #include "paddle/phi/kernels/psroi_pool_kernel.h" diff --git a/paddle/phi/kernels/gpu/psroi_pool_kernel.cu b/paddle/phi/kernels/gpu/psroi_pool_kernel.cu index e0b17a55933..a8fed022f91 100644 --- a/paddle/phi/kernels/gpu/psroi_pool_kernel.cu +++ b/paddle/phi/kernels/gpu/psroi_pool_kernel.cu @@ -20,7 +20,7 @@ #include "paddle/fluid/memory/memory.h" #include "paddle/phi/common/place.h" #include "paddle/phi/core/kernel_registry.h" -#include "paddle/phi/kernels/copy_kernel.h" +#include "paddle/phi/core/tensor_utils.h" namespace phi { diff --git a/paddle/phi/kernels/gpu/put_along_axis_grad_kernel.cu b/paddle/phi/kernels/gpu/put_along_axis_grad_kernel.cu index f553da361f1..62c93a989e5 100644 --- a/paddle/phi/kernels/gpu/put_along_axis_grad_kernel.cu +++ b/paddle/phi/kernels/gpu/put_along_axis_grad_kernel.cu @@ -19,7 +19,7 @@ #include "paddle/fluid/platform/place.h" #include "paddle/phi/backends/gpu/gpu_context.h" #include "paddle/phi/core/kernel_registry.h" -#include "paddle/phi/kernels/copy_kernel.h" +#include "paddle/phi/core/tensor_utils.h" namespace phi { diff --git a/paddle/phi/kernels/gpu/put_along_axis_kernel.cu b/paddle/phi/kernels/gpu/put_along_axis_kernel.cu index d363c0c2836..b4fde608b1e 100644 --- a/paddle/phi/kernels/gpu/put_along_axis_kernel.cu +++ b/paddle/phi/kernels/gpu/put_along_axis_kernel.cu @@ -19,7 +19,7 @@ #include "paddle/fluid/platform/place.h" #include "paddle/phi/backends/gpu/gpu_context.h" #include "paddle/phi/core/kernel_registry.h" -#include "paddle/phi/kernels/copy_kernel.h" +#include "paddle/phi/core/tensor_utils.h" namespace phi { diff --git a/paddle/phi/kernels/gpu/scatter_grad_kernel.cu b/paddle/phi/kernels/gpu/scatter_grad_kernel.cu index 7f93fd0a905..1750ad2a3ae 100644 --- a/paddle/phi/kernels/gpu/scatter_grad_kernel.cu +++ b/paddle/phi/kernels/gpu/scatter_grad_kernel.cu @@ -16,7 +16,7 @@ #include "paddle/phi/backends/gpu/gpu_context.h" #include "paddle/phi/core/kernel_registry.h" -#include "paddle/phi/kernels/copy_kernel.h" +#include "paddle/phi/core/tensor_utils.h" #include "paddle/phi/kernels/funcs/gather.cu.h" #include "paddle/phi/kernels/funcs/scatter.cu.h" diff --git a/paddle/phi/kernels/gpu/scatter_kernel.cu b/paddle/phi/kernels/gpu/scatter_kernel.cu index af8919bec41..a088754381d 100644 --- a/paddle/phi/kernels/gpu/scatter_kernel.cu +++ b/paddle/phi/kernels/gpu/scatter_kernel.cu @@ -16,7 +16,7 @@ #include "paddle/phi/backends/gpu/gpu_context.h" #include "paddle/phi/core/kernel_registry.h" -#include "paddle/phi/kernels/copy_kernel.h" +#include "paddle/phi/core/tensor_utils.h" #include "paddle/phi/kernels/funcs/scatter.cu.h" namespace phi { diff --git a/paddle/phi/kernels/gpu/scatter_nd_add_grad_kernel.cu b/paddle/phi/kernels/gpu/scatter_nd_add_grad_kernel.cu index 66b373f3b28..135c683bedb 100644 --- a/paddle/phi/kernels/gpu/scatter_nd_add_grad_kernel.cu +++ b/paddle/phi/kernels/gpu/scatter_nd_add_grad_kernel.cu @@ -16,7 +16,7 @@ #include "paddle/phi/backends/gpu/gpu_context.h" #include "paddle/phi/core/kernel_registry.h" -#include 
"paddle/phi/kernels/copy_kernel.h" +#include "paddle/phi/core/tensor_utils.h" #include "paddle/phi/kernels/funcs/gather.cu.h" namespace phi { diff --git a/paddle/phi/kernels/gpu/scatter_nd_add_kernel.cu b/paddle/phi/kernels/gpu/scatter_nd_add_kernel.cu index a7b8bebd38c..563b8868ad3 100644 --- a/paddle/phi/kernels/gpu/scatter_nd_add_kernel.cu +++ b/paddle/phi/kernels/gpu/scatter_nd_add_kernel.cu @@ -16,7 +16,7 @@ #include "paddle/phi/backends/gpu/gpu_context.h" #include "paddle/phi/core/kernel_registry.h" -#include "paddle/phi/kernels/copy_kernel.h" +#include "paddle/phi/core/tensor_utils.h" #include "paddle/phi/kernels/funcs/scatter.cu.h" namespace phi { diff --git a/paddle/phi/kernels/gpu/sigmoid_cross_entropy_with_logits.h b/paddle/phi/kernels/gpu/sigmoid_cross_entropy_with_logits.h index c300b6d3f3d..84a24449b3a 100644 --- a/paddle/phi/kernels/gpu/sigmoid_cross_entropy_with_logits.h +++ b/paddle/phi/kernels/gpu/sigmoid_cross_entropy_with_logits.h @@ -22,7 +22,7 @@ #include "paddle/phi/backends/gpu/gpu_helper.h" #include "paddle/phi/core/hostdevice.h" #include "paddle/phi/core/kernel_registry.h" -#include "paddle/phi/kernels/copy_kernel.h" +#include "paddle/phi/core/tensor_utils.h" #include "paddle/phi/kernels/funcs/elementwise_base.h" #include "paddle/phi/kernels/gpu/reduce.h" diff --git a/paddle/phi/kernels/gpu/top_k_kernel.cu b/paddle/phi/kernels/gpu/top_k_kernel.cu index e5038e0f3be..e0b7bba50d6 100644 --- a/paddle/phi/kernels/gpu/top_k_kernel.cu +++ b/paddle/phi/kernels/gpu/top_k_kernel.cu @@ -17,7 +17,7 @@ #include "paddle/fluid/operators/top_k_function_cuda.h" #include "paddle/phi/backends/gpu/gpu_context.h" #include "paddle/phi/core/kernel_registry.h" -#include "paddle/phi/kernels/copy_kernel.h" +#include "paddle/phi/core/tensor_utils.h" #include "paddle/phi/kernels/funcs/gather.cu.h" #include "paddle/phi/kernels/funcs/math_function.h" diff --git a/paddle/phi/kernels/gpu/unique_kernel.cu b/paddle/phi/kernels/gpu/unique_kernel.cu index 2f24a44c232..3d44c9af03c 100644 --- a/paddle/phi/kernels/gpu/unique_kernel.cu +++ b/paddle/phi/kernels/gpu/unique_kernel.cu @@ -28,7 +28,7 @@ #include "paddle/fluid/framework/tensor_util.h" // TensorToVector() #include "paddle/phi/backends/gpu/gpu_context.h" #include "paddle/phi/core/kernel_registry.h" -#include "paddle/phi/kernels/copy_kernel.h" +#include "paddle/phi/core/tensor_utils.h" #include "paddle/phi/kernels/funcs/unique_functor.h" namespace phi { diff --git a/paddle/phi/kernels/gpu/viterbi_decode_kernel.cu b/paddle/phi/kernels/gpu/viterbi_decode_kernel.cu index dc04c69ec70..224651326d7 100644 --- a/paddle/phi/kernels/gpu/viterbi_decode_kernel.cu +++ b/paddle/phi/kernels/gpu/viterbi_decode_kernel.cu @@ -33,7 +33,7 @@ namespace cub = hipcub; #include "paddle/fluid/operators/elementwise/elementwise_op_function.h" #include "paddle/phi/backends/gpu/gpu_context.h" #include "paddle/phi/core/kernel_registry.h" -#include "paddle/phi/kernels/copy_kernel.h" +#include "paddle/phi/core/tensor_utils.h" #include "paddle/phi/kernels/empty_kernel.h" #include "paddle/phi/kernels/funcs/compare_functors.h" #include "paddle/phi/kernels/funcs/concat_and_split_functor.h" diff --git a/paddle/phi/kernels/impl/cholesky_solve_grad_kernel_impl.h b/paddle/phi/kernels/impl/cholesky_solve_grad_kernel_impl.h index f68a3e59629..22bb4973ea4 100644 --- a/paddle/phi/kernels/impl/cholesky_solve_grad_kernel_impl.h +++ b/paddle/phi/kernels/impl/cholesky_solve_grad_kernel_impl.h @@ -17,7 +17,6 @@ #include "paddle/phi/kernels/cholesky_solve_grad_kernel.h" #include 
"paddle/phi/kernels/cholesky_solve_kernel.h" #include "paddle/phi/kernels/complex_kernel.h" -#include "paddle/phi/kernels/copy_kernel.h" #include "paddle/phi/kernels/elementwise_add_kernel.h" #include "paddle/phi/kernels/empty_kernel.h" #include "paddle/phi/kernels/expand_kernel.h" diff --git a/paddle/phi/kernels/impl/cholesky_solve_kernel_impl.h b/paddle/phi/kernels/impl/cholesky_solve_kernel_impl.h index 1cc8acc21f3..562ff25317e 100644 --- a/paddle/phi/kernels/impl/cholesky_solve_kernel_impl.h +++ b/paddle/phi/kernels/impl/cholesky_solve_kernel_impl.h @@ -16,7 +16,6 @@ #include "paddle/phi/kernels/cholesky_solve_kernel.h" #include "paddle/phi/kernels/complex_kernel.h" -#include "paddle/phi/kernels/copy_kernel.h" #include "paddle/phi/kernels/empty_kernel.h" #include "paddle/phi/kernels/expand_kernel.h" #include "paddle/phi/kernels/funcs/common_shape.h" diff --git a/paddle/phi/kernels/impl/determinant_grad_kernel_impl.h b/paddle/phi/kernels/impl/determinant_grad_kernel_impl.h index d9c3333fc24..248305b7fc0 100644 --- a/paddle/phi/kernels/impl/determinant_grad_kernel_impl.h +++ b/paddle/phi/kernels/impl/determinant_grad_kernel_impl.h @@ -14,7 +14,7 @@ #pragma once -#include "paddle/phi/kernels/copy_kernel.h" +#include "paddle/phi/core/tensor_utils.h" #include "paddle/phi/kernels/determinant_grad_kernel.h" #include "paddle/phi/kernels/elementwise_multiply_kernel.h" #include "paddle/phi/kernels/empty_kernel.h" diff --git a/paddle/phi/kernels/impl/elementwise_grad_kernel_impl.h b/paddle/phi/kernels/impl/elementwise_grad_kernel_impl.h index 4b4a75727a5..da74280b267 100644 --- a/paddle/phi/kernels/impl/elementwise_grad_kernel_impl.h +++ b/paddle/phi/kernels/impl/elementwise_grad_kernel_impl.h @@ -16,7 +16,7 @@ limitations under the License. */ #include "paddle/phi/common/complex.h" #include "paddle/phi/core/dense_tensor.h" -#include "paddle/phi/kernels/copy_kernel.h" +#include "paddle/phi/core/tensor_utils.h" #include "paddle/phi/kernels/funcs/broadcast_function.h" #include "paddle/phi/kernels/funcs/eigen/common.h" #include "paddle/phi/kernels/funcs/elementwise_functor.h" diff --git a/paddle/phi/kernels/impl/expand_as_grad_kernel_impl.h b/paddle/phi/kernels/impl/expand_as_grad_kernel_impl.h index 6ef282d4703..998c54e77fe 100644 --- a/paddle/phi/kernels/impl/expand_as_grad_kernel_impl.h +++ b/paddle/phi/kernels/impl/expand_as_grad_kernel_impl.h @@ -14,7 +14,7 @@ #pragma once -#include "paddle/phi/kernels/copy_kernel.h" +#include "paddle/phi/core/tensor_utils.h" #include "paddle/phi/kernels/impl/expand_as_kernel_impl.h" namespace phi { diff --git a/paddle/phi/kernels/impl/expand_grad_kernel_impl.h b/paddle/phi/kernels/impl/expand_grad_kernel_impl.h index a4fc7157eea..31cb87da25f 100644 --- a/paddle/phi/kernels/impl/expand_grad_kernel_impl.h +++ b/paddle/phi/kernels/impl/expand_grad_kernel_impl.h @@ -14,7 +14,7 @@ #pragma once -#include "paddle/phi/kernels/copy_kernel.h" +#include "paddle/phi/core/tensor_utils.h" #include "paddle/phi/kernels/funcs/eigen/common.h" #include "paddle/phi/kernels/funcs/eigen/eigen_function.h" #include "paddle/phi/kernels/impl/expand_kernel_impl.h" diff --git a/paddle/phi/kernels/impl/meshgrid_kernel_impl.h b/paddle/phi/kernels/impl/meshgrid_kernel_impl.h index e5e7f785b81..e66632498f6 100644 --- a/paddle/phi/kernels/impl/meshgrid_kernel_impl.h +++ b/paddle/phi/kernels/impl/meshgrid_kernel_impl.h @@ -16,7 +16,7 @@ #include "paddle/fluid/framework/tensor_util.h" #include "paddle/phi/core/dense_tensor.h" -#include "paddle/phi/kernels/copy_kernel.h" +#include 
"paddle/phi/core/tensor_utils.h" #include "paddle/phi/kernels/funcs/eigen/common.h" #include "paddle/phi/kernels/funcs/eigen/eigen_function.h" #include "paddle/phi/kernels/meshgrid_kernel.h" diff --git a/paddle/phi/kernels/impl/set_value_grad_kernel_impl.h b/paddle/phi/kernels/impl/set_value_grad_kernel_impl.h index 40543645b01..de930734be6 100644 --- a/paddle/phi/kernels/impl/set_value_grad_kernel_impl.h +++ b/paddle/phi/kernels/impl/set_value_grad_kernel_impl.h @@ -16,7 +16,7 @@ #include "paddle/phi/common/int_array.h" #include "paddle/phi/core/dense_tensor.h" -#include "paddle/phi/kernels/copy_kernel.h" +#include "paddle/phi/core/tensor_utils.h" #include "paddle/phi/kernels/full_kernel.h" #include "paddle/phi/kernels/funcs/eigen/common.h" #include "paddle/phi/kernels/funcs/eigen/eigen_function.h" diff --git a/paddle/phi/kernels/impl/set_value_kernel_impl.h b/paddle/phi/kernels/impl/set_value_kernel_impl.h index 4859a7348e5..a0f594e9d58 100644 --- a/paddle/phi/kernels/impl/set_value_kernel_impl.h +++ b/paddle/phi/kernels/impl/set_value_kernel_impl.h @@ -17,7 +17,7 @@ #include "paddle/phi/common/int_array.h" #include "paddle/phi/common/scalar.h" #include "paddle/phi/core/dense_tensor.h" -#include "paddle/phi/kernels/copy_kernel.h" +#include "paddle/phi/core/tensor_utils.h" #include "paddle/phi/kernels/empty_kernel.h" #include "paddle/phi/kernels/funcs/broadcast_function.h" #include "paddle/phi/kernels/funcs/eigen/common.h" diff --git a/paddle/phi/kernels/impl/size_kernel_impl.h b/paddle/phi/kernels/impl/size_kernel_impl.h index 7b781dba3ad..f9757bc4477 100644 --- a/paddle/phi/kernels/impl/size_kernel_impl.h +++ b/paddle/phi/kernels/impl/size_kernel_impl.h @@ -14,7 +14,7 @@ #pragma once -#include "paddle/phi/kernels/copy_kernel.h" +#include "paddle/phi/core/tensor_utils.h" namespace phi { diff --git a/paddle/phi/kernels/impl/squeeze_grad_kernel_impl.h b/paddle/phi/kernels/impl/squeeze_grad_kernel_impl.h index c74aa5c7243..1e3dfd66ece 100644 --- a/paddle/phi/kernels/impl/squeeze_grad_kernel_impl.h +++ b/paddle/phi/kernels/impl/squeeze_grad_kernel_impl.h @@ -14,7 +14,7 @@ #pragma once #include "paddle/phi/core/dense_tensor.h" -#include "paddle/phi/kernels/copy_kernel.h" +#include "paddle/phi/core/tensor_utils.h" namespace phi { template diff --git a/paddle/phi/kernels/impl/squeeze_kernel_impl.h b/paddle/phi/kernels/impl/squeeze_kernel_impl.h index bb1627d4092..b4c94d619cc 100644 --- a/paddle/phi/kernels/impl/squeeze_kernel_impl.h +++ b/paddle/phi/kernels/impl/squeeze_kernel_impl.h @@ -14,7 +14,7 @@ #pragma once #include "paddle/phi/core/dense_tensor.h" -#include "paddle/phi/kernels/copy_kernel.h" +#include "paddle/phi/core/tensor_utils.h" #include "paddle/phi/kernels/funcs/unsqueeze.h" namespace phi { diff --git a/paddle/phi/kernels/impl/tile_grad_kernel_impl.h b/paddle/phi/kernels/impl/tile_grad_kernel_impl.h index 9e56e50534d..05f9139b148 100644 --- a/paddle/phi/kernels/impl/tile_grad_kernel_impl.h +++ b/paddle/phi/kernels/impl/tile_grad_kernel_impl.h @@ -16,7 +16,7 @@ #include #include -#include "paddle/phi/kernels/copy_kernel.h" +#include "paddle/phi/core/tensor_utils.h" #include "paddle/phi/kernels/funcs/eigen/common.h" #include "paddle/phi/kernels/funcs/eigen/eigen_function.h" #include "paddle/phi/kernels/tile_grad_kernel.h" diff --git a/paddle/phi/kernels/impl/triangular_solve_grad_kernel_impl.h b/paddle/phi/kernels/impl/triangular_solve_grad_kernel_impl.h index 3ea75b036a5..8faca812a02 100644 --- a/paddle/phi/kernels/impl/triangular_solve_grad_kernel_impl.h +++ 
b/paddle/phi/kernels/impl/triangular_solve_grad_kernel_impl.h @@ -14,7 +14,7 @@ #pragma once -#include "paddle/phi/kernels/copy_kernel.h" +#include "paddle/phi/core/tensor_utils.h" #include "paddle/phi/kernels/empty_kernel.h" #include "paddle/phi/kernels/funcs/blas/blas.h" #include "paddle/phi/kernels/funcs/common_shape.h" diff --git a/paddle/phi/kernels/impl/unsqueeze_grad_kernel_impl.h b/paddle/phi/kernels/impl/unsqueeze_grad_kernel_impl.h index 54b332ea4c8..ff45ec49b7c 100644 --- a/paddle/phi/kernels/impl/unsqueeze_grad_kernel_impl.h +++ b/paddle/phi/kernels/impl/unsqueeze_grad_kernel_impl.h @@ -14,7 +14,7 @@ #pragma once #include "paddle/phi/core/dense_tensor.h" -#include "paddle/phi/kernels/copy_kernel.h" +#include "paddle/phi/core/tensor_utils.h" namespace phi { template diff --git a/paddle/phi/kernels/impl/unsqueeze_kernel_impl.h b/paddle/phi/kernels/impl/unsqueeze_kernel_impl.h index 02110d631fb..4f81fa6c423 100644 --- a/paddle/phi/kernels/impl/unsqueeze_kernel_impl.h +++ b/paddle/phi/kernels/impl/unsqueeze_kernel_impl.h @@ -14,7 +14,7 @@ #pragma once #include "paddle/phi/core/dense_tensor.h" -#include "paddle/phi/kernels/copy_kernel.h" +#include "paddle/phi/core/tensor_utils.h" #include "paddle/phi/kernels/funcs/unsqueeze.h" namespace phi { diff --git a/paddle/phi/kernels/impl/warpctc_kernel_impl.h b/paddle/phi/kernels/impl/warpctc_kernel_impl.h index 6c792507c6f..c8f8d28ce11 100644 --- a/paddle/phi/kernels/impl/warpctc_kernel_impl.h +++ b/paddle/phi/kernels/impl/warpctc_kernel_impl.h @@ -20,7 +20,7 @@ #include "paddle/fluid/operators/math/sequence_scale.h" #include "paddle/phi/backends/dynload/warpctc.h" #include "paddle/phi/core/dense_tensor.h" -#include "paddle/phi/kernels/copy_kernel.h" +#include "paddle/phi/core/tensor_utils.h" #include "paddle/phi/kernels/empty_kernel.h" #include "paddle/phi/kernels/funcs/math_function.h" #include "paddle/utils/optional.h" diff --git a/paddle/phi/kernels/reshape_grad_kernel.cc b/paddle/phi/kernels/reshape_grad_kernel.cc index 35f85ba86aa..c4b92c4f760 100644 --- a/paddle/phi/kernels/reshape_grad_kernel.cc +++ b/paddle/phi/kernels/reshape_grad_kernel.cc @@ -16,7 +16,7 @@ #include "paddle/phi/backends/all_context.h" #include "paddle/phi/core/kernel_registry.h" -#include "paddle/phi/kernels/copy_kernel.h" +#include "paddle/phi/core/tensor_utils.h" namespace phi { diff --git a/paddle/phi/kernels/reshape_kernel.cc b/paddle/phi/kernels/reshape_kernel.cc index a723ea19d34..632a63c9ab7 100644 --- a/paddle/phi/kernels/reshape_kernel.cc +++ b/paddle/phi/kernels/reshape_kernel.cc @@ -16,8 +16,8 @@ #include "paddle/phi/backends/all_context.h" #include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/core/tensor_utils.h" #include "paddle/phi/infermeta/unary.h" -#include "paddle/phi/kernels/copy_kernel.h" #include "paddle/phi/kernels/funcs/common_shape.h" namespace phi { diff --git a/paddle/phi/kernels/reverse_kernel.cc b/paddle/phi/kernels/reverse_kernel.cc index c6c2781a07b..d89e68e7389 100644 --- a/paddle/phi/kernels/reverse_kernel.cc +++ b/paddle/phi/kernels/reverse_kernel.cc @@ -16,7 +16,7 @@ #include "paddle/phi/core/enforce.h" #include "paddle/phi/core/kernel_registry.h" -#include "paddle/phi/kernels/copy_kernel.h" +#include "paddle/phi/core/tensor_utils.h" namespace phi { diff --git a/paddle/phi/kernels/selected_rows/copy_kernel.cc b/paddle/phi/kernels/selected_rows/copy_kernel.cc deleted file mode 100644 index cf71ab0583f..00000000000 --- a/paddle/phi/kernels/selected_rows/copy_kernel.cc +++ /dev/null @@ -1,49 +0,0 @@ -/* Copyright (c) 
2022 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#include "paddle/phi/kernels/selected_rows/copy_kernel.h" - -#include "paddle/phi/backends/cpu/cpu_context.h" -#include "paddle/phi/backends/gpu/gpu_context.h" -#include "paddle/phi/common/bfloat16.h" -#include "paddle/phi/core/kernel_registry.h" -#include "paddle/phi/kernels/copy_kernel.h" -namespace phi { -namespace sr { - -template -void Copy(const Context& dev_ctx, - const SelectedRows& src, - Place dst_place, - bool blocking, - SelectedRows* dst) { - if (src.value().Holder() != dst->value().Holder() || - src.value().data() != dst->value().data()) { - dst->set_rows(src.rows()); - dst->set_height(src.height()); - } - phi::Copy( - dev_ctx, src.value(), dst_place, blocking, dst->mutable_value()); -} - -} // namespace sr -} // namespace phi - -PD_REGISTER_GENERAL_KERNEL( - copy_sr, CPU, ALL_LAYOUT, phi::sr::Copy, ALL_DTYPE) {} - -#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) -PD_REGISTER_GENERAL_KERNEL( - copy_sr, GPU, ALL_LAYOUT, phi::sr::Copy, ALL_DTYPE) {} -#endif diff --git a/paddle/phi/kernels/selected_rows/copy_kernel.h b/paddle/phi/kernels/selected_rows/copy_kernel.h deleted file mode 100644 index 4aa848bea2a..00000000000 --- a/paddle/phi/kernels/selected_rows/copy_kernel.h +++ /dev/null @@ -1,31 +0,0 @@ -// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -#pragma once - -#include "paddle/phi/core/selected_rows.h" -#include "paddle/phi/core/sparse_csr_tensor.h" - -namespace phi { -namespace sr { - -template -void Copy(const Context& dev_ctx, - const SelectedRows& src, - Place dst_place, - bool blocking, - SelectedRows* dst); - -} // namespace sr -} // namespace phi diff --git a/paddle/phi/kernels/selected_rows/cpu/adam_kernel.cc b/paddle/phi/kernels/selected_rows/cpu/adam_kernel.cc index d96c707538e..ba5d6feb48f 100644 --- a/paddle/phi/kernels/selected_rows/cpu/adam_kernel.cc +++ b/paddle/phi/kernels/selected_rows/cpu/adam_kernel.cc @@ -19,7 +19,7 @@ #include "paddle/fluid/operators/math/selected_rows_functor.h" #include "paddle/phi/backends/cpu/cpu_context.h" #include "paddle/phi/core/kernel_registry.h" -#include "paddle/phi/kernels/copy_kernel.h" +#include "paddle/phi/core/tensor_utils.h" #include "paddle/phi/kernels/funcs/adam_functors.h" namespace phi { diff --git a/paddle/phi/kernels/selected_rows/gpu/adam_kernel.cu b/paddle/phi/kernels/selected_rows/gpu/adam_kernel.cu index 842e05fe58e..9aecbb8e99c 100644 --- a/paddle/phi/kernels/selected_rows/gpu/adam_kernel.cu +++ b/paddle/phi/kernels/selected_rows/gpu/adam_kernel.cu @@ -20,7 +20,7 @@ #include "paddle/phi/common/amp_type_traits.h" #include "paddle/phi/common/float16.h" #include "paddle/phi/core/kernel_registry.h" -#include "paddle/phi/kernels/copy_kernel.h" +#include "paddle/phi/core/tensor_utils.h" #include "paddle/phi/kernels/funcs/adam_functors.h" #include "paddle/phi/kernels/funcs/for_range.h" diff --git a/paddle/phi/kernels/selected_rows/gpu/adamw_kernel.cu b/paddle/phi/kernels/selected_rows/gpu/adamw_kernel.cu index 6e0123d2fca..e04784c2620 100644 --- a/paddle/phi/kernels/selected_rows/gpu/adamw_kernel.cu +++ b/paddle/phi/kernels/selected_rows/gpu/adamw_kernel.cu @@ -24,7 +24,7 @@ #include "paddle/phi/common/amp_type_traits.h" #include "paddle/phi/common/float16.h" #include "paddle/phi/core/kernel_registry.h" -#include "paddle/phi/kernels/copy_kernel.h" +#include "paddle/phi/core/tensor_utils.h" #include "paddle/phi/kernels/funcs/adam_functors.h" #include "paddle/phi/kernels/funcs/for_range.h" diff --git a/paddle/phi/kernels/sparse/copy_kernel.cc b/paddle/phi/kernels/sparse/copy_kernel.cc index 705c19e020c..76726f0ffcc 100644 --- a/paddle/phi/kernels/sparse/copy_kernel.cc +++ b/paddle/phi/kernels/sparse/copy_kernel.cc @@ -19,7 +19,7 @@ limitations under the License. */ #include "paddle/phi/core/kernel_registry.h" #include "paddle/phi/core/sparse_coo_tensor.h" #include "paddle/phi/core/sparse_csr_tensor.h" -#include "paddle/phi/kernels/copy_kernel.h" +#include "paddle/phi/core/tensor_utils.h" namespace phi { namespace sparse { diff --git a/paddle/phi/kernels/sparse/cpu/convolution_grad_kernel.cc b/paddle/phi/kernels/sparse/cpu/convolution_grad_kernel.cc index 5a981fb8df3..a675853ac47 100644 --- a/paddle/phi/kernels/sparse/cpu/convolution_grad_kernel.cc +++ b/paddle/phi/kernels/sparse/cpu/convolution_grad_kernel.cc @@ -15,7 +15,6 @@ limitations under the License. 
*/ #include "paddle/phi/kernels/sparse/convolution_grad_kernel.h" #include "paddle/phi/core/visit_type.h" -#include "paddle/phi/kernels/copy_kernel.h" #include "paddle/phi/kernels/funcs/blas/blas.h" #include "paddle/phi/kernels/funcs/math_function.h" #include "paddle/phi/kernels/sparse/cpu/convolution.h" diff --git a/paddle/phi/kernels/sparse/cpu/full_kernel.cc b/paddle/phi/kernels/sparse/cpu/full_kernel.cc index 3c8be166262..b848751deb9 100644 --- a/paddle/phi/kernels/sparse/cpu/full_kernel.cc +++ b/paddle/phi/kernels/sparse/cpu/full_kernel.cc @@ -17,7 +17,7 @@ limitations under the License. */ #include "paddle/phi/backends/cpu/cpu_context.h" #include "paddle/phi/core/dense_tensor.h" #include "paddle/phi/core/kernel_registry.h" -#include "paddle/phi/kernels/copy_kernel.h" +#include "paddle/phi/core/tensor_utils.h" #include "paddle/phi/kernels/funcs/eigen/common.h" #include "paddle/phi/kernels/funcs/eigen/eigen_function.h" diff --git a/paddle/phi/kernels/sparse/cpu/sparse_mask_kernel.cc b/paddle/phi/kernels/sparse/cpu/sparse_mask_kernel.cc index 37579ae8564..cf2acd85573 100644 --- a/paddle/phi/kernels/sparse/cpu/sparse_mask_kernel.cc +++ b/paddle/phi/kernels/sparse/cpu/sparse_mask_kernel.cc @@ -18,8 +18,8 @@ limitations under the License. */ #include "paddle/phi/core/ddim.h" #include "paddle/phi/core/enforce.h" #include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/core/tensor_utils.h" #include "paddle/phi/core/visit_type.h" -#include "paddle/phi/kernels/copy_kernel.h" #include "paddle/phi/kernels/empty_kernel.h" #include "paddle/phi/kernels/funcs/math_function.h" #include "paddle/phi/kernels/funcs/sparse/flatten_indices.h" diff --git a/paddle/phi/kernels/sparse/cpu/sparse_pool_grad_kernel.cc b/paddle/phi/kernels/sparse/cpu/sparse_pool_grad_kernel.cc index fdf8e5aa7eb..64c843c07a6 100644 --- a/paddle/phi/kernels/sparse/cpu/sparse_pool_grad_kernel.cc +++ b/paddle/phi/kernels/sparse/cpu/sparse_pool_grad_kernel.cc @@ -15,8 +15,8 @@ limitations under the License. */ #include "paddle/phi/kernels/sparse/sparse_pool_grad_kernel.h" #include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/core/tensor_utils.h" #include "paddle/phi/core/visit_type.h" -#include "paddle/phi/kernels/copy_kernel.h" #include "paddle/phi/kernels/empty_kernel.h" #include "paddle/phi/kernels/funcs/pooling.h" #include "paddle/phi/kernels/funcs/sparse/convolution.h" diff --git a/paddle/phi/kernels/sparse/empty_kernel.cc b/paddle/phi/kernels/sparse/empty_kernel.cc index 4b7a5fe615a..2d04f935214 100644 --- a/paddle/phi/kernels/sparse/empty_kernel.cc +++ b/paddle/phi/kernels/sparse/empty_kernel.cc @@ -17,7 +17,7 @@ limitations under the License. */ #include "paddle/phi/backends/cpu/cpu_context.h" #include "paddle/phi/backends/gpu/gpu_context.h" #include "paddle/phi/core/kernel_registry.h" -#include "paddle/phi/kernels/copy_kernel.h" +#include "paddle/phi/core/tensor_utils.h" namespace phi { namespace sparse { diff --git a/paddle/phi/kernels/sparse/gpu/convolution.cu.h b/paddle/phi/kernels/sparse/gpu/convolution.cu.h index 24a7387d4fe..d56575cddbf 100644 --- a/paddle/phi/kernels/sparse/gpu/convolution.cu.h +++ b/paddle/phi/kernels/sparse/gpu/convolution.cu.h @@ -23,7 +23,7 @@ limitations under the License. 
*/ #include "paddle/phi/backends/gpu/gpu_context.h" #include "paddle/phi/backends/gpu/gpu_info.h" #include "paddle/phi/backends/gpu/gpu_launch_config.h" -#include "paddle/phi/kernels/copy_kernel.h" +#include "paddle/phi/core/tensor_utils.h" #include "paddle/phi/kernels/funcs/index_impl.cu.h" #include "paddle/phi/kernels/funcs/math_function.h" #include "paddle/phi/kernels/funcs/sparse/utils.cu.h" diff --git a/paddle/phi/kernels/sparse/gpu/convolution_grad_kernel.cu b/paddle/phi/kernels/sparse/gpu/convolution_grad_kernel.cu index 805c417b8db..1f82f2ff93e 100644 --- a/paddle/phi/kernels/sparse/gpu/convolution_grad_kernel.cu +++ b/paddle/phi/kernels/sparse/gpu/convolution_grad_kernel.cu @@ -20,8 +20,8 @@ limitations under the License. */ #include "paddle/phi/backends/gpu/gpu_launch_config.h" #include "paddle/phi/core/kernel_registry.h" #include "paddle/phi/core/tensor_meta.h" +#include "paddle/phi/core/tensor_utils.h" #include "paddle/phi/core/visit_type.h" -#include "paddle/phi/kernels/copy_kernel.h" #include "paddle/phi/kernels/funcs/blas/blas.h" #include "paddle/phi/kernels/funcs/math_function.h" #include "paddle/phi/kernels/funcs/scatter.cu.h" diff --git a/paddle/phi/kernels/sparse/gpu/full_kernel.cu b/paddle/phi/kernels/sparse/gpu/full_kernel.cu index 500217d6edc..a3dc5a9534b 100644 --- a/paddle/phi/kernels/sparse/gpu/full_kernel.cu +++ b/paddle/phi/kernels/sparse/gpu/full_kernel.cu @@ -17,7 +17,7 @@ limitations under the License. */ #include "paddle/phi/backends/gpu/gpu_context.h" #include "paddle/phi/core/dense_tensor.h" #include "paddle/phi/core/kernel_registry.h" -#include "paddle/phi/kernels/copy_kernel.h" +#include "paddle/phi/core/tensor_utils.h" #include "paddle/phi/kernels/funcs/elementwise_base.h" namespace phi { diff --git a/paddle/phi/kernels/sparse/gpu/matmul_grad_kernel.cu b/paddle/phi/kernels/sparse/gpu/matmul_grad_kernel.cu index 8bc162eaae2..d5c128fea6f 100644 --- a/paddle/phi/kernels/sparse/gpu/matmul_grad_kernel.cu +++ b/paddle/phi/kernels/sparse/gpu/matmul_grad_kernel.cu @@ -18,7 +18,7 @@ limitations under the License. */ #include "paddle/phi/backends/gpu/gpu_context.h" #include "paddle/phi/core/kernel_registry.h" -#include "paddle/phi/kernels/copy_kernel.h" +#include "paddle/phi/core/tensor_utils.h" #include "paddle/phi/kernels/empty_kernel.h" #include "paddle/phi/kernels/funcs/sparse/sparse_blas.h" #include "paddle/phi/kernels/sparse/empty_kernel.h" diff --git a/paddle/phi/kernels/sparse/gpu/matmul_kernel.cu b/paddle/phi/kernels/sparse/gpu/matmul_kernel.cu index df5a4b57520..9357bbd2ad0 100644 --- a/paddle/phi/kernels/sparse/gpu/matmul_kernel.cu +++ b/paddle/phi/kernels/sparse/gpu/matmul_kernel.cu @@ -23,7 +23,7 @@ limitations under the License. */ #include "paddle/phi/core/meta_tensor.h" #include "paddle/phi/core/sparse_coo_tensor.h" #include "paddle/phi/core/sparse_csr_tensor.h" -#include "paddle/phi/kernels/copy_kernel.h" +#include "paddle/phi/core/tensor_utils.h" #include "paddle/phi/kernels/empty_kernel.h" #include "paddle/phi/kernels/funcs/sparse/sparse_blas.h" #include "paddle/phi/kernels/sparse/empty_kernel.h" diff --git a/paddle/phi/kernels/sparse/gpu/sparse_mask_kernel.cu b/paddle/phi/kernels/sparse/gpu/sparse_mask_kernel.cu index 2153d9dfe68..21d6850bdc4 100644 --- a/paddle/phi/kernels/sparse/gpu/sparse_mask_kernel.cu +++ b/paddle/phi/kernels/sparse/gpu/sparse_mask_kernel.cu @@ -21,8 +21,8 @@ limitations under the License. 
*/ #include "paddle/phi/core/ddim.h" #include "paddle/phi/core/enforce.h" #include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/core/tensor_utils.h" #include "paddle/phi/core/visit_type.h" -#include "paddle/phi/kernels/copy_kernel.h" #include "paddle/phi/kernels/empty_kernel.h" #include "paddle/phi/kernels/funcs/math_function.h" #include "paddle/phi/kernels/funcs/sparse/flatten_indices.cu.h" diff --git a/paddle/phi/kernels/sparse/gpu/sparse_pool_grad_kernel.cu b/paddle/phi/kernels/sparse/gpu/sparse_pool_grad_kernel.cu index 669ecb017dc..5fe6e68c1e8 100644 --- a/paddle/phi/kernels/sparse/gpu/sparse_pool_grad_kernel.cu +++ b/paddle/phi/kernels/sparse/gpu/sparse_pool_grad_kernel.cu @@ -18,8 +18,8 @@ limitations under the License. */ #include "paddle/phi/backends/gpu/gpu_info.h" #include "paddle/phi/backends/gpu/gpu_launch_config.h" #include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/core/tensor_utils.h" #include "paddle/phi/core/visit_type.h" -#include "paddle/phi/kernels/copy_kernel.h" #include "paddle/phi/kernels/empty_kernel.h" #include "paddle/phi/kernels/funcs/math_function.h" #include "paddle/phi/kernels/funcs/pooling.h" diff --git a/paddle/phi/kernels/sparse/unary_grad_kernel.cc b/paddle/phi/kernels/sparse/unary_grad_kernel.cc index 1fd3ef27112..cd844532e93 100644 --- a/paddle/phi/kernels/sparse/unary_grad_kernel.cc +++ b/paddle/phi/kernels/sparse/unary_grad_kernel.cc @@ -19,8 +19,8 @@ #include "paddle/phi/core/kernel_registry.h" #include "paddle/phi/core/sparse_coo_tensor.h" #include "paddle/phi/core/sparse_csr_tensor.h" +#include "paddle/phi/core/tensor_utils.h" #include "paddle/phi/kernels/activation_grad_kernel.h" -#include "paddle/phi/kernels/copy_kernel.h" #include "paddle/phi/kernels/empty_kernel.h" #define DEFINE_SPARSE_UNARY_GRAD_KERNEL(DenseKernelFunc) \ diff --git a/paddle/phi/kernels/sparse/unary_kernel.cc b/paddle/phi/kernels/sparse/unary_kernel.cc index e02d7757664..2999536b34e 100644 --- a/paddle/phi/kernels/sparse/unary_kernel.cc +++ b/paddle/phi/kernels/sparse/unary_kernel.cc @@ -19,8 +19,8 @@ #include "paddle/phi/core/kernel_registry.h" #include "paddle/phi/core/sparse_coo_tensor.h" #include "paddle/phi/core/sparse_csr_tensor.h" +#include "paddle/phi/core/tensor_utils.h" #include "paddle/phi/kernels/activation_kernel.h" -#include "paddle/phi/kernels/copy_kernel.h" #include "paddle/phi/kernels/empty_kernel.h" #define DEFINE_SPARSE_UNARY_KERNEL(DenseKernelFunc) \ diff --git a/paddle/phi/kernels/strings/gpu/strings_copy_kernel.cu b/paddle/phi/kernels/strings/gpu/strings_copy_kernel.cu index c49b41e0d3f..fb9d32264b0 100644 --- a/paddle/phi/kernels/strings/gpu/strings_copy_kernel.cu +++ b/paddle/phi/kernels/strings/gpu/strings_copy_kernel.cu @@ -19,7 +19,7 @@ limitations under the License. */ #include "paddle/phi/backends/gpu/gpu_launch_config.h" #include "paddle/phi/common/pstring.h" #include "paddle/phi/core/kernel_registry.h" -#include "paddle/phi/kernels/copy_kernel.h" +#include "paddle/phi/core/tensor_utils.h" #include "paddle/phi/kernels/empty_kernel.h" #include "paddle/phi/kernels/strings/gpu/copy_utils.h" diff --git a/paddle/phi/kernels/xpu/copy_kernel.cc b/paddle/phi/kernels/xpu/copy_kernel.cc deleted file mode 100644 index fb931ef18a8..00000000000 --- a/paddle/phi/kernels/xpu/copy_kernel.cc +++ /dev/null @@ -1,80 +0,0 @@ -/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. 
-You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#include "paddle/phi/kernels/copy_kernel.h" - -#include "paddle/phi/backends/xpu/xpu_context.h" -#include "paddle/phi/common/data_type.h" -#include "paddle/phi/core/compat/convert_utils.h" -#include "paddle/phi/core/kernel_registry.h" - -// See Note [ Why still include the fluid headers? ] -#include "paddle/fluid/memory/memcpy.h" - -namespace phi { - -template -void Copy(const Context& dev_ctx, - const DenseTensor& src, - Place dst_place, - bool blocking, - DenseTensor* dst) { - auto* src_ptr = src.data(); - void* dst_ptr = nullptr; - - dst->Resize(src.dims()); - if (paddle::platform::is_cpu_place(dst_place)) { - dst_ptr = dev_ctx.HostAlloc(dst, src.dtype()); - } else { - dst_ptr = dev_ctx.Alloc(dst, src.dtype()); - } - const auto& src_place = src.place(); - - if (src_ptr == dst_ptr && src_place == dst_place) { - VLOG(3) << "Skip copy the same data async from " << src_place << " to " - << dst_place; - return; - } - VLOG(4) << "src:" << src_ptr << ", dst:" << dst_ptr; - - VLOG(3) << "TensorCopy " << src.dims() << " from " << src.place() << " to " - << dst_place; - - CHECK(dst->layout() == src.layout()); - auto size = src.numel() * paddle::experimental::SizeOf(src.dtype()); - - if (paddle::platform::is_xpu_place(src_place) && // NOLINT - paddle::platform::is_cpu_place(dst_place)) { - paddle::memory::Copy(dst_place, dst_ptr, src_place, src_ptr, size); - } else if (paddle::platform::is_cpu_place(src_place) && - paddle::platform::is_xpu_place(dst_place)) { - paddle::memory::Copy(dst_place, dst_ptr, src_place, src_ptr, size); - } else if (paddle::platform::is_xpu_place(src_place) && - paddle::platform::is_xpu_place(dst_place)) { - if (src_ptr == dst_ptr) { - VLOG(3) << "Skip copy the same data async from " << src_place << " to " - << dst_place; - return; - } - paddle::memory::Copy(dst_place, dst_ptr, src_place, src_ptr, size); - } else { - PADDLE_THROW(phi::errors::Unimplemented( - "Copy from %s to %s is not supported.", src_place, dst_place)); - } -} - -} // namespace phi - -PD_REGISTER_GENERAL_KERNEL( - copy, XPU, ALL_LAYOUT, phi::Copy, ALL_DTYPE) {} diff --git a/paddle/phi/tests/api/test_data_transform.cc b/paddle/phi/tests/api/test_data_transform.cc index 7e8204ea6c7..36f4b19e566 100644 --- a/paddle/phi/tests/api/test_data_transform.cc +++ b/paddle/phi/tests/api/test_data_transform.cc @@ -29,7 +29,6 @@ PD_DECLARE_KERNEL(matmul, CPU, ALL_LAYOUT); #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) PD_DECLARE_KERNEL(full, GPU, ALL_LAYOUT); PD_DECLARE_KERNEL(matmul, GPU, ALL_LAYOUT); -PD_DECLARE_KERNEL(copy, GPU, ALL_LAYOUT); #endif namespace paddle { diff --git a/paddle/phi/tests/api/test_fill_api.cc b/paddle/phi/tests/api/test_fill_api.cc index cae56fd6634..58f74321f49 100644 --- a/paddle/phi/tests/api/test_fill_api.cc +++ b/paddle/phi/tests/api/test_fill_api.cc @@ -22,7 +22,6 @@ limitations under the License. 
*/ #include "paddle/phi/core/kernel_registry.h" PD_DECLARE_KERNEL(full, CPU, ALL_LAYOUT); -PD_DECLARE_KERNEL(copy, CPU, ALL_LAYOUT); namespace paddle { namespace tests { diff --git a/paddle/phi/tests/api/test_matmul_api.cc b/paddle/phi/tests/api/test_matmul_api.cc index c54c5398280..ff8bd8bfff6 100644 --- a/paddle/phi/tests/api/test_matmul_api.cc +++ b/paddle/phi/tests/api/test_matmul_api.cc @@ -22,7 +22,7 @@ limitations under the License. */ #include "paddle/phi/backends/gpu/gpu_context.h" #include "paddle/phi/core/dense_tensor.h" #include "paddle/phi/core/kernel_registry.h" -#include "paddle/phi/kernels/copy_kernel.h" +#include "paddle/phi/core/tensor_utils.h" // See Note [ Why still include the fluid headers? ] #include "paddle/fluid/platform/device_context.h" diff --git a/paddle/phi/tests/api/test_pten_tensor.cc b/paddle/phi/tests/api/test_pten_tensor.cc index 590717b8d7b..049aa1c355a 100644 --- a/paddle/phi/tests/api/test_pten_tensor.cc +++ b/paddle/phi/tests/api/test_pten_tensor.cc @@ -17,12 +17,6 @@ #include "paddle/phi/api/include/tensor.h" #include "paddle/phi/core/kernel_registry.h" -PD_DECLARE_KERNEL(copy, CPU, ALL_LAYOUT); - -#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) -PD_DECLARE_KERNEL(copy, GPU, ALL_LAYOUT); -#endif - namespace paddle { namespace tests { diff --git a/paddle/phi/tests/api/test_scale_api.cc b/paddle/phi/tests/api/test_scale_api.cc index 2795ebcf286..a4999cf0907 100644 --- a/paddle/phi/tests/api/test_scale_api.cc +++ b/paddle/phi/tests/api/test_scale_api.cc @@ -25,7 +25,6 @@ limitations under the License. */ PD_DECLARE_KERNEL(full, CPU, ALL_LAYOUT); PD_DECLARE_KERNEL(scale, CPU, ALL_LAYOUT); PD_DECLARE_KERNEL(scale_sr, CPU, ALL_LAYOUT); -PD_DECLARE_KERNEL(copy, CPU, ALL_LAYOUT); namespace paddle { namespace tests { diff --git a/paddle/phi/tests/api/test_to_api.cc b/paddle/phi/tests/api/test_to_api.cc index dcf43348251..1580dd08f7c 100644 --- a/paddle/phi/tests/api/test_to_api.cc +++ b/paddle/phi/tests/api/test_to_api.cc @@ -21,11 +21,6 @@ limitations under the License. */ #include "paddle/phi/core/dense_tensor.h" #include "paddle/phi/core/kernel_registry.h" -PD_DECLARE_KERNEL(copy, CPU, ALL_LAYOUT); -#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) -PD_DECLARE_KERNEL(copy, GPU, ALL_LAYOUT); -#endif - namespace paddle { namespace tests { diff --git a/paddle/phi/tests/common/test_int_array.cc b/paddle/phi/tests/common/test_int_array.cc index 30ad7cdd74c..c97eac38b13 100644 --- a/paddle/phi/tests/common/test_int_array.cc +++ b/paddle/phi/tests/common/test_int_array.cc @@ -22,10 +22,8 @@ limitations under the License. */ #include "paddle/phi/kernels/full_kernel.h" PD_DECLARE_KERNEL(full, CPU, ALL_LAYOUT); -PD_DECLARE_KERNEL(copy, CPU, ALL_LAYOUT); #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) PD_DECLARE_KERNEL(full, GPU, ALL_LAYOUT); -PD_DECLARE_KERNEL(copy, GPU, ALL_LAYOUT); #endif namespace phi { diff --git a/paddle/phi/tests/common/test_scalar.cu b/paddle/phi/tests/common/test_scalar.cu index 89b41ef1e58..50b9e198da0 100644 --- a/paddle/phi/tests/common/test_scalar.cu +++ b/paddle/phi/tests/common/test_scalar.cu @@ -25,8 +25,6 @@ limitations under the License. 
*/ #include "paddle/phi/core/dense_tensor.h" #include "paddle/phi/core/kernel_registry.h" -PD_DECLARE_KERNEL(copy, GPU, ALL_LAYOUT); - namespace phi { namespace tests { diff --git a/paddle/phi/tests/kernels/test_copy_dev_api.cc b/paddle/phi/tests/kernels/test_copy_dev_api.cc index 9eba14ebc81..1c9b17ed613 100644 --- a/paddle/phi/tests/kernels/test_copy_dev_api.cc +++ b/paddle/phi/tests/kernels/test_copy_dev_api.cc @@ -21,7 +21,7 @@ limitations under the License. */ #include "paddle/phi/backends/cpu/cpu_context.h" #include "paddle/phi/core/dense_tensor.h" #include "paddle/phi/core/kernel_registry.h" -#include "paddle/phi/kernels/copy_kernel.h" +#include "paddle/phi/core/tensor_utils.h" namespace phi { namespace tests { diff --git a/paddle/phi/tests/kernels/test_flatten_dev_api.cc b/paddle/phi/tests/kernels/test_flatten_dev_api.cc index 23ee9869c0e..fb1cdee7e5f 100644 --- a/paddle/phi/tests/kernels/test_flatten_dev_api.cc +++ b/paddle/phi/tests/kernels/test_flatten_dev_api.cc @@ -23,16 +23,6 @@ limitations under the License. */ #include "paddle/phi/core/kernel_registry.h" #include "paddle/phi/kernels/flatten_kernel.h" -PD_DECLARE_KERNEL(copy, CPU, ALL_LAYOUT); - -#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) -PD_DECLARE_KERNEL(copy, GPU, ALL_LAYOUT); -#endif - -#ifdef PADDLE_WITH_XPU -PD_DECLARE_KERNEL(copy, XPU, ALL_LAYOUT); -#endif - namespace phi { namespace tests { diff --git a/paddle/phi/tests/kernels/test_sparse_conv3d_dev_api.cc b/paddle/phi/tests/kernels/test_sparse_conv3d_dev_api.cc index b7d56cb0d2b..bb84690cd07 100644 --- a/paddle/phi/tests/kernels/test_sparse_conv3d_dev_api.cc +++ b/paddle/phi/tests/kernels/test_sparse_conv3d_dev_api.cc @@ -21,7 +21,7 @@ limitations under the License. */ #include "paddle/phi/backends/gpu/gpu_context.h" #include "paddle/phi/common/place.h" #include "paddle/phi/core/kernel_registry.h" -#include "paddle/phi/kernels/copy_kernel.h" +#include "paddle/phi/core/tensor_utils.h" #include "paddle/phi/kernels/sparse/convolution_grad_kernel.h" #include "paddle/phi/kernels/sparse/convolution_kernel.h" diff --git a/paddle/phi/tests/kernels/test_sparse_pool_dev_api.cc b/paddle/phi/tests/kernels/test_sparse_pool_dev_api.cc index 5640da399f4..7d7cd1ceaf5 100644 --- a/paddle/phi/tests/kernels/test_sparse_pool_dev_api.cc +++ b/paddle/phi/tests/kernels/test_sparse_pool_dev_api.cc @@ -21,7 +21,7 @@ limitations under the License. */ #include "paddle/phi/backends/gpu/gpu_context.h" #include "paddle/phi/common/place.h" #include "paddle/phi/core/kernel_registry.h" -#include "paddle/phi/kernels/copy_kernel.h" +#include "paddle/phi/core/tensor_utils.h" #include "paddle/phi/kernels/sparse/sparse_pool_grad_kernel.h" #include "paddle/phi/kernels/sparse/sparse_pool_kernel.h" diff --git a/paddle/phi/tests/kernels/test_sparse_utils_dev_api.cc b/paddle/phi/tests/kernels/test_sparse_utils_dev_api.cc index 0c1a7bbb3d8..d4f1d6efb5d 100644 --- a/paddle/phi/tests/kernels/test_sparse_utils_dev_api.cc +++ b/paddle/phi/tests/kernels/test_sparse_utils_dev_api.cc @@ -22,7 +22,7 @@ limitations under the License. */ #include "paddle/phi/common/place.h" #include "paddle/phi/core/dense_tensor.h" #include "paddle/phi/core/kernel_registry.h" -#include "paddle/phi/kernels/copy_kernel.h" +#include "paddle/phi/core/tensor_utils.h" #include "paddle/phi/kernels/sparse/sparse_utils_kernel.h" namespace phi { -- GitLab