From 17318c1a1d1491ce2ce24e9c21ff1c4774f91f8b Mon Sep 17 00:00:00 2001 From: Huang Jiyi <43315610+huangjiyi@users.noreply.github.com> Date: Thu, 9 Feb 2023 15:40:31 +0800 Subject: [PATCH] [PHI decoupling] move strided_memcpy.h to phi (#50346) * decouple strided_memcpy * move strided_memcpy * move strided_memcpy to phi * fix namespace * update * fix gpu compile bugs --- paddle/fluid/imperative/reducer.cc | 4 +- paddle/fluid/operators/CMakeLists.txt | 1 - paddle/fluid/operators/concat_op.h | 2 +- paddle/fluid/operators/crop_op.h | 2 +- .../detection/collect_fpn_proposals_op.cu | 2 +- paddle/fluid/operators/partial_concat_op.h | 2 +- .../sequence_ops/sequence_slice_op.h | 26 ++--- paddle/fluid/operators/spp_op.h | 41 ++++---- paddle/fluid/operators/unbind_op.h | 2 +- paddle/fluid/pybind/tensor_py.h | 17 ++-- paddle/phi/kernels/cpu/concat_kernel.cc | 17 ++-- .../kernels/funcs}/detail/strided_memcpy.h | 63 +++++++------ .../kernels/funcs}/strided_memcpy.h | 94 ++++++++++--------- paddle/phi/kernels/gpu/concat_kernel.cu | 17 ++-- .../kernels/impl/concat_grad_kernel_impl.h | 4 +- paddle/phi/kernels/impl/split_kernel_impl.h | 5 +- paddle/phi/tests/kernels/CMakeLists.txt | 5 + .../tests/kernels}/strided_memcpy_test.cc | 76 ++++++++------- 18 files changed, 197 insertions(+), 183 deletions(-) rename paddle/{fluid/operators => phi/kernels/funcs}/detail/strided_memcpy.h (66%) rename paddle/{fluid/operators => phi/kernels/funcs}/strided_memcpy.h (68%) rename paddle/{fluid/operators => phi/tests/kernels}/strided_memcpy_test.cc (63%) diff --git a/paddle/fluid/imperative/reducer.cc b/paddle/fluid/imperative/reducer.cc index f89fe234c20..d0d985874fa 100644 --- a/paddle/fluid/imperative/reducer.cc +++ b/paddle/fluid/imperative/reducer.cc @@ -20,7 +20,7 @@ #include "paddle/fluid/imperative/layer.h" #include "paddle/fluid/imperative/parallel_context.h" #include "paddle/fluid/operators/math/concat_and_split.h" -#include "paddle/fluid/operators/strided_memcpy.h" +#include "paddle/phi/kernels/funcs/strided_memcpy.h" #ifdef PADDLE_WITH_XPU #include "paddle/fluid/platform/device/xpu/enforce_xpu.h" #endif @@ -103,7 +103,7 @@ static void SplitTensorsForAllReduce( } // Sometimes direct copies will be faster if (p_dense_tensors->size() < 10) { - operators::StridedMemcpyWithAxis0(context, *in, shape_refer, &outs); + phi::funcs::StridedMemcpyWithAxis0(context, *in, shape_refer, &outs); } else { operators::math::SplitFunctor split_functor_; split_functor_(context, *in, shape_refer, 0, &outs); diff --git a/paddle/fluid/operators/CMakeLists.txt b/paddle/fluid/operators/CMakeLists.txt index a287298ef02..4a64c4411df 100644 --- a/paddle/fluid/operators/CMakeLists.txt +++ b/paddle/fluid/operators/CMakeLists.txt @@ -188,7 +188,6 @@ cc_test(gather_test SRCS gather_test.cc DEPS tensor) cc_test(assign_op_test SRCS assign_op_test.cc DEPS assign_op) cc_test(scatter_test SRCS scatter_test.cc DEPS tensor math_function) cc_test(beam_search_decode_op_test SRCS beam_search_decode_op_test.cc DEPS lod_tensor) -cc_test(strided_memcpy_test SRCS strided_memcpy_test.cc DEPS tensor memory) cc_test(save_load_op_test SRCS save_load_op_test.cc DEPS save_op load_op) cc_test(save_load_combine_op_test SRCS save_load_combine_op_test.cc DEPS save_combine_op load_combine_op) if (WITH_GPU) diff --git a/paddle/fluid/operators/concat_op.h b/paddle/fluid/operators/concat_op.h index de7d5807238..58d978ea9c7 100644 --- a/paddle/fluid/operators/concat_op.h +++ b/paddle/fluid/operators/concat_op.h @@ -20,10 +20,10 @@ limitations under the License. 
*/ #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/operators/math/concat_and_split.h" -#include "paddle/fluid/operators/strided_memcpy.h" #include "paddle/fluid/operators/utils.h" #include "paddle/phi/kernels/concat_kernel.h" #include "paddle/phi/kernels/funcs/concat_funcs.h" +#include "paddle/phi/kernels/funcs/strided_memcpy.h" namespace paddle { namespace operators { diff --git a/paddle/fluid/operators/crop_op.h b/paddle/fluid/operators/crop_op.h index c193eabba37..0c791f01bd9 100644 --- a/paddle/fluid/operators/crop_op.h +++ b/paddle/fluid/operators/crop_op.h @@ -19,7 +19,7 @@ limitations under the License. */ #include "paddle/fluid/framework/eigen.h" #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/operators/eigen/eigen_function.h" -#include "paddle/fluid/operators/strided_memcpy.h" +#include "paddle/phi/kernels/funcs/strided_memcpy.h" namespace paddle { namespace operators { // Internal diff --git a/paddle/fluid/operators/detection/collect_fpn_proposals_op.cu b/paddle/fluid/operators/detection/collect_fpn_proposals_op.cu index 3f9a55225ca..38116fc1216 100644 --- a/paddle/fluid/operators/detection/collect_fpn_proposals_op.cu +++ b/paddle/fluid/operators/detection/collect_fpn_proposals_op.cu @@ -24,11 +24,11 @@ namespace cub = hipcub; #include "paddle/fluid/operators/detection/bbox_util.h" #include "paddle/fluid/operators/detection/collect_fpn_proposals_op.h" #include "paddle/fluid/operators/math/concat_and_split.h" -#include "paddle/fluid/operators/strided_memcpy.h" #include "paddle/fluid/platform/for_range.h" #include "paddle/phi/backends/gpu/gpu_primitives.h" #include "paddle/phi/core/mixed_vector.h" #include "paddle/phi/kernels/funcs/gather.cu.h" +#include "paddle/phi/kernels/funcs/strided_memcpy.h" namespace paddle { namespace operators { diff --git a/paddle/fluid/operators/partial_concat_op.h b/paddle/fluid/operators/partial_concat_op.h index 050752f2388..407b57e3a82 100644 --- a/paddle/fluid/operators/partial_concat_op.h +++ b/paddle/fluid/operators/partial_concat_op.h @@ -18,8 +18,8 @@ limitations under the License. */ #include "paddle/fluid/framework/eigen.h" #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/operators/math/concat_and_split.h" -#include "paddle/fluid/operators/strided_memcpy.h" #include "paddle/fluid/operators/utils.h" +#include "paddle/phi/kernels/funcs/strided_memcpy.h" namespace paddle { namespace operators { diff --git a/paddle/fluid/operators/sequence_ops/sequence_slice_op.h b/paddle/fluid/operators/sequence_ops/sequence_slice_op.h index 1c418d5e037..205a743605c 100644 --- a/paddle/fluid/operators/sequence_ops/sequence_slice_op.h +++ b/paddle/fluid/operators/sequence_ops/sequence_slice_op.h @@ -14,8 +14,8 @@ limitations under the License. 
*/ #pragma once #include "paddle/fluid/framework/op_registry.h" -#include "paddle/fluid/operators/strided_memcpy.h" #include "paddle/phi/kernels/funcs/math_function.h" +#include "paddle/phi/kernels/funcs/strided_memcpy.h" namespace paddle { namespace operators { @@ -140,12 +140,12 @@ class SequenceSliceOpKernel : public framework::OpKernel { static_cast(lod[0][i] + offset_data[i]), static_cast(lod[0][i] + offset_data[i] + length_data[i])); - StridedMemcpy(ctx.device_context(), - in_t.data(), - in_stride, - in_t.dims(), - out_stride, - out->data() + out_offset); + phi::funcs::StridedMemcpy(ctx.device_context(), + in_t.data(), + in_stride, + in_t.dims(), + out_stride, + out->data() + out_offset); out_offset += length_data[i] * in_stride[0]; } } @@ -201,12 +201,12 @@ class SequenceSliceGradOpKernel : public framework::OpKernel { static_cast(lod[0][i] + offset_data[i]), static_cast(lod[0][i] + offset_data[i] + length_data[i])); - StridedMemcpy(ctx.device_context(), - out_grad_t.data(), - out_grad_stride, - out_grad_t.dims(), - x_grad_stride, - x_grad_t.data()); + phi::funcs::StridedMemcpy(ctx.device_context(), + out_grad_t.data(), + out_grad_stride, + out_grad_t.dims(), + x_grad_stride, + x_grad_t.data()); } } } diff --git a/paddle/fluid/operators/spp_op.h b/paddle/fluid/operators/spp_op.h index 260d368dd0b..0b5c3f91ae1 100644 --- a/paddle/fluid/operators/spp_op.h +++ b/paddle/fluid/operators/spp_op.h @@ -18,9 +18,9 @@ limitations under the License. */ #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/phi_utils.h" -#include "paddle/fluid/operators/strided_memcpy.h" #include "paddle/phi/kernels/funcs/math_function.h" #include "paddle/phi/kernels/funcs/pooling.h" +#include "paddle/phi/kernels/funcs/strided_memcpy.h" namespace paddle { namespace operators { @@ -96,12 +96,13 @@ class SppKernel : public framework::OpKernel { out_level.Resize(output_flatten_shape); // concat auto out_level_stride = phi::stride(out_level.dims()); - StridedMemcpy(context.template device_context(), - out_level.data(), - out_level_stride, - out_level.dims(), - out_stride, - out->data() + output_offset); + phi::funcs::StridedMemcpy( + context.template device_context(), + out_level.data(), + out_level_stride, + out_level.dims(), + out_stride, + out->data() + output_offset); output_offset += out_level.dims()[1] * out_level_stride[1]; } } @@ -150,19 +151,21 @@ class SppGradKernel : public framework::OpKernel { outgrad_level.mutable_data(out_flatten_shape, context.GetPlace()); auto flatten_stride = phi::stride(out_level.dims()); // memcpy - StridedMemcpy(context.template device_context(), - out->data() + out_offset, - out_stride, - out_level.dims(), - flatten_stride, - out_level.data()); + phi::funcs::StridedMemcpy( + context.template device_context(), + out->data() + out_offset, + out_stride, + out_level.dims(), + flatten_stride, + out_level.data()); - StridedMemcpy(context.template device_context(), - out_grad->data() + out_offset, - out_stride, - outgrad_level.dims(), - flatten_stride, - outgrad_level.data()); + phi::funcs::StridedMemcpy( + context.template device_context(), + out_grad->data() + out_offset, + out_stride, + outgrad_level.dims(), + flatten_stride, + outgrad_level.data()); out_offset += out_level.dims()[1] * out_stride[1]; // flatten backward to nchw diff --git a/paddle/fluid/operators/unbind_op.h b/paddle/fluid/operators/unbind_op.h index 082e4584616..51347e45929 100644 --- a/paddle/fluid/operators/unbind_op.h +++ b/paddle/fluid/operators/unbind_op.h @@ -21,8 +21,8 @@ 
limitations under the License. */ #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/operators/math/concat_and_split.h" -#include "paddle/fluid/operators/strided_memcpy.h" #include "paddle/fluid/operators/utils.h" +#include "paddle/phi/kernels/funcs/strided_memcpy.h" namespace paddle { namespace operators { diff --git a/paddle/fluid/pybind/tensor_py.h b/paddle/fluid/pybind/tensor_py.h index f0c038226f8..b21d735b589 100644 --- a/paddle/fluid/pybind/tensor_py.h +++ b/paddle/fluid/pybind/tensor_py.h @@ -28,9 +28,9 @@ limitations under the License. */ #include "paddle/fluid/memory/memcpy.h" #include "paddle/fluid/operators/eigen/eigen_function.h" #include "paddle/fluid/operators/math/concat_and_split.h" -#include "paddle/fluid/operators/strided_memcpy.h" #include "paddle/fluid/platform/bfloat16.h" #include "paddle/fluid/platform/device/device_wrapper.h" +#include "paddle/phi/kernels/funcs/strided_memcpy.h" #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) #include "paddle/fluid/platform/cuda_device_guard.h" #endif @@ -724,14 +724,13 @@ void _concatCompute(const std::vector &ins, for (auto &in : ins) { auto in_stride = phi::stride_numel(in.dims()); auto out_stride = phi::stride_numel(out->dims()); - paddle::operators::StridedNumelCopyWithAxis( - ctx, - axis, - out->data() + output_offset, - out_stride, - in.data(), - in_stride, - in_stride[axis]); + phi::funcs::StridedNumelCopyWithAxis(ctx, + axis, + out->data() + output_offset, + out_stride, + in.data(), + in_stride, + in_stride[axis]); output_offset += in_stride[axis]; } } else { diff --git a/paddle/phi/kernels/cpu/concat_kernel.cc b/paddle/phi/kernels/cpu/concat_kernel.cc index 1075cb9f777..da5415d9e49 100644 --- a/paddle/phi/kernels/cpu/concat_kernel.cc +++ b/paddle/phi/kernels/cpu/concat_kernel.cc @@ -14,7 +14,6 @@ #include "paddle/phi/kernels/concat_kernel.h" -#include "paddle/fluid/operators/strided_memcpy.h" #include "paddle/phi/backends/cpu/cpu_context.h" #include "paddle/phi/common/bfloat16.h" #include "paddle/phi/common/complex.h" @@ -24,6 +23,7 @@ #include "paddle/phi/core/lod_utils.h" #include "paddle/phi/kernels/funcs/concat_and_split_functor.h" #include "paddle/phi/kernels/funcs/concat_funcs.h" +#include "paddle/phi/kernels/funcs/strided_memcpy.h" namespace phi { @@ -86,14 +86,13 @@ void ConcatKernel(const Context& dev_ctx, } auto in_stride = phi::stride_numel(in->dims()); auto out_stride = phi::stride_numel(out->dims()); - paddle::operators::StridedNumelCopyWithAxis( - dev_ctx, - axis, - out->data() + output_offset, - out_stride, - in->data(), - in_stride, - in_stride[axis]); + phi::funcs::StridedNumelCopyWithAxis(dev_ctx, + axis, + out->data() + output_offset, + out_stride, + in->data(), + in_stride, + in_stride[axis]); output_offset += in_stride[axis]; } } else { diff --git a/paddle/fluid/operators/detail/strided_memcpy.h b/paddle/phi/kernels/funcs/detail/strided_memcpy.h similarity index 66% rename from paddle/fluid/operators/detail/strided_memcpy.h rename to paddle/phi/kernels/funcs/detail/strided_memcpy.h index 4c729a65f59..57d9765985f 100644 --- a/paddle/fluid/operators/detail/strided_memcpy.h +++ b/paddle/phi/kernels/funcs/detail/strided_memcpy.h @@ -1,4 +1,4 @@ -/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. +/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -14,11 +14,15 @@ limitations under the License. 
*/ #pragma once #include "paddle/fluid/memory/memcpy.h" -#include "paddle/fluid/platform/device_context.h" #include "paddle/phi/core/ddim.h" +#include "paddle/phi/core/device_context.h" -namespace paddle { -namespace operators { +#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) +#include "paddle/phi/backends/gpu/gpu_context.h" +#endif + +namespace phi { +namespace funcs { namespace detail { template @@ -26,25 +30,25 @@ struct StridedMemcpyFunctor; template struct StridedMemcpyFunctor { - void operator()(const platform::DeviceContext& dev_ctx, + void operator()(const phi::DeviceContext& dev_ctx, const T* src, const int64_t* src_stride, const int64_t* dst_dim, const int64_t* dst_stride, T* dst) const { auto place = dev_ctx.GetPlace(); - if (platform::is_cpu_place(place)) { + if (place.GetType() == phi::AllocationType::CPU) { auto& cpu_place = place; - memory::Copy(cpu_place, dst, cpu_place, src, sizeof(T)); + paddle::memory::Copy(cpu_place, dst, cpu_place, src, sizeof(T)); } else { #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) auto& gpu_place = place; auto& cuda_ctx = reinterpret_cast(dev_ctx); - memory::Copy( + paddle::memory::Copy( gpu_place, dst, gpu_place, src, sizeof(T), cuda_ctx.stream()); #else PADDLE_THROW( - platform::errors::Unavailable("Paddle is not compiled with GPU.")); + phi::errors::Unavailable("Paddle is not compiled with GPU.")); #endif } } @@ -52,29 +56,30 @@ struct StridedMemcpyFunctor { template struct StridedMemcpyFunctor { - void operator()(const platform::DeviceContext& dev_ctx, + void operator()(const phi::DeviceContext& dev_ctx, const T* src, const int64_t* src_stride, const int64_t* dst_dim, const int64_t* dst_stride, T* dst) const { auto place = dev_ctx.GetPlace(); - if (platform::is_cpu_place(place)) { + if (place.GetType() == phi::AllocationType::CPU) { auto& cpu_place = place; - memory::Copy(cpu_place, dst, cpu_place, src, sizeof(T) * dst_dim[0]); + paddle::memory::Copy( + cpu_place, dst, cpu_place, src, sizeof(T) * dst_dim[0]); } else { #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) auto& gpu_place = place; auto& cuda_ctx = reinterpret_cast(dev_ctx); - memory::Copy(gpu_place, - dst, - gpu_place, - src, - sizeof(T) * dst_dim[0], - cuda_ctx.stream()); + paddle::memory::Copy(gpu_place, + dst, + gpu_place, + src, + sizeof(T) * dst_dim[0], + cuda_ctx.stream()); #else PADDLE_THROW( - platform::errors::Unavailable("Paddle is not compiled with GPU.")); + phi::errors::Unavailable("Paddle is not compiled with GPU.")); #endif } } @@ -82,7 +87,7 @@ struct StridedMemcpyFunctor { template struct StridedMemcpyFunctor { - void operator()(const platform::DeviceContext& dev_ctx, + void operator()(const phi::DeviceContext& dev_ctx, const T* src, const int64_t* src_stride, const int64_t* dst_dim, @@ -99,10 +104,10 @@ struct StridedMemcpyFunctor { template struct StridedCopyDimVisitor { - StridedCopyDimVisitor(const platform::DeviceContext& dev_ctx, + StridedCopyDimVisitor(const phi::DeviceContext& dev_ctx, const T* src, - const framework::DDim& src_stride, - const framework::DDim& dst_stride, + const phi::DDim& src_stride, + const phi::DDim& dst_stride, T* dst) : dev_ctx_(dev_ctx), src_(src), @@ -111,7 +116,7 @@ struct StridedCopyDimVisitor { dst_(dst) {} template - void operator()(const framework::Dim& dst_dim) const { + void operator()(const phi::Dim& dst_dim) const { StridedMemcpyFunctor functor; functor(dev_ctx_, src_, @@ -121,13 +126,13 @@ struct StridedCopyDimVisitor { dst_); } - const platform::DeviceContext& dev_ctx_; + const 
phi::DeviceContext& dev_ctx_; const T* src_; - const framework::DDim& src_stride_; - const framework::DDim& dst_stride_; + const phi::DDim& src_stride_; + const phi::DDim& dst_stride_; T* dst_; }; } // namespace detail -} // namespace operators -} // namespace paddle +} // namespace funcs +} // namespace phi diff --git a/paddle/fluid/operators/strided_memcpy.h b/paddle/phi/kernels/funcs/strided_memcpy.h similarity index 68% rename from paddle/fluid/operators/strided_memcpy.h rename to paddle/phi/kernels/funcs/strided_memcpy.h index 3a562d2f26e..cac82faf64a 100644 --- a/paddle/fluid/operators/strided_memcpy.h +++ b/paddle/phi/kernels/funcs/strided_memcpy.h @@ -1,4 +1,4 @@ -/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. +/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at @@ -12,11 +12,13 @@ limitations under the License. */ #pragma once #include -#include "paddle/fluid/framework/tensor.h" -#include "paddle/fluid/operators/detail/strided_memcpy.h" +#include "paddle/phi/kernels/funcs/detail/strided_memcpy.h" -namespace paddle { -namespace operators { +#include "paddle/fluid/platform/device_context.h" +#include "paddle/phi/core/dense_tensor.h" + +namespace phi { +namespace funcs { // Strided memory copy from src to dst. // @@ -33,13 +35,13 @@ namespace operators { // NOTE: When use GPU, the memcpy is async. To sync memcpy, please invoke // `dev_ctx.Wait()`. template -inline void StridedMemcpy(const platform::DeviceContext& dev_ctx, +inline void StridedMemcpy(const phi::DeviceContext& dev_ctx, const T* src, - const framework::DDim& src_stride, - const framework::DDim& dst_dim, - const framework::DDim& dst_stride, + const phi::DDim& src_stride, + const phi::DDim& dst_dim, + const phi::DDim& dst_stride, T* dst) { - paddle::operators::detail::StridedCopyDimVisitor func( + detail::StridedCopyDimVisitor func( dev_ctx, src, src_stride, dst_stride, dst); dst_dim.apply_visitor(func); } @@ -52,12 +54,12 @@ inline void StridedMemcpy(const platform::DeviceContext& dev_ctx, // NOTE: The src and dst tensor should have the same elements // except the specified axis. 
template -inline void StridedNumelCopyWithAxis(const platform::DeviceContext& ctx, +inline void StridedNumelCopyWithAxis(const phi::DeviceContext& ctx, int64_t axis, T* dst, - const framework::DDim& dst_stride_numel, + const phi::DDim& dst_stride_numel, const T* src, - const framework::DDim& src_stride_numel, + const phi::DDim& src_stride_numel, int64_t size) { int64_t before = dst_stride_numel[0] / dst_stride_numel[axis]; int64_t src_after = src_stride_numel[axis]; @@ -66,7 +68,7 @@ inline void StridedNumelCopyWithAxis(const platform::DeviceContext& ctx, PADDLE_ENFORCE_EQ(src_stride_numel.size(), dst_stride_numel.size(), - platform::errors::InvalidArgument( + phi::errors::InvalidArgument( "Source and destination tensor should have the same " "dimension size, but source tensor dimension size is " "%u, destination tensor size is %u.", @@ -78,7 +80,7 @@ inline void StridedNumelCopyWithAxis(const platform::DeviceContext& ctx, PADDLE_ENFORCE_EQ( src_stride_numel[i] / src_stride_numel[axis], dst_stride_numel[i] / dst_stride_numel[axis], - platform::errors::InvalidArgument( + phi::errors::InvalidArgument( "Source and destination tensor should have the same number of " "elements except the specified axis, but the source elements " "number is %d, destination elements number is %d.", @@ -90,7 +92,7 @@ inline void StridedNumelCopyWithAxis(const platform::DeviceContext& ctx, PADDLE_ENFORCE_EQ( src_stride_numel[i], dst_stride_numel[i], - platform::errors::InvalidArgument( + phi::errors::InvalidArgument( "Source and destination tensor should have the same number of " "elements except the specified axis, but the source elements " "number is %d, destination elements number is %d.", @@ -100,44 +102,44 @@ inline void StridedNumelCopyWithAxis(const platform::DeviceContext& ctx, } for (int64_t i = 0; i < before; ++i) { - if (platform::is_cpu_place(place)) { + if (place.GetType() == phi::AllocationType::CPU) { auto& cpu_place = place; - memory::Copy(cpu_place, - dst + i * dst_after, - cpu_place, - src + i * src_after, - sizeof(T) * size); + paddle::memory::Copy(cpu_place, + dst + i * dst_after, + cpu_place, + src + i * src_after, + sizeof(T) * size); } else { #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) auto& gpu_place = place; auto& cuda_ctx = reinterpret_cast(ctx); - memory::Copy(gpu_place, - dst + i * dst_after, - gpu_place, - src + i * src_after, - sizeof(T) * size, - cuda_ctx.stream()); + paddle::memory::Copy(gpu_place, + dst + i * dst_after, + gpu_place, + src + i * src_after, + sizeof(T) * size, + cuda_ctx.stream()); #elif defined(PADDLE_WITH_ASCEND_CL) auto& npu_place = place; auto& npu_ctx = reinterpret_cast(ctx); - memory::Copy(npu_place, - dst + i * dst_after, - npu_place, - src + i * src_after, - sizeof(T) * size, - npu_ctx.stream()); + paddle::memory::Copy(npu_place, + dst + i * dst_after, + npu_place, + src + i * src_after, + sizeof(T) * size, + npu_ctx.stream()); #elif defined(PADDLE_WITH_MLU) auto& mlu_place = place; auto& mlu_ctx = reinterpret_cast(ctx); - memory::Copy(mlu_place, - dst + i * dst_after, - mlu_place, - src + i * src_after, - sizeof(T) * size, - mlu_ctx.stream()); + paddle::memory::Copy(mlu_place, + dst + i * dst_after, + mlu_place, + src + i * src_after, + sizeof(T) * size, + mlu_ctx.stream()); #else - PADDLE_THROW(platform::errors::PreconditionNotMet( - "Paddle is not compiled with GPU.")); + PADDLE_THROW( + phi::errors::PreconditionNotMet("Paddle is not compiled with GPU.")); #endif } } @@ -145,11 +147,11 @@ inline void StridedNumelCopyWithAxis(const 
platform::DeviceContext& ctx, template inline void StridedMemcpyWithAxis0( - const platform::DeviceContext& dev_ctx, + const phi::DeviceContext& dev_ctx, const phi::DenseTensor& input, const std::vector& shape_refer, std::vector* outputs) { - const framework::DDim in_stride = stride_numel(input.dims()); + const phi::DDim in_stride = stride_numel(input.dims()); const int axis = 0; size_t input_offset = 0; @@ -169,5 +171,5 @@ inline void StridedMemcpyWithAxis0( } } -} // namespace operators -} // namespace paddle +} // namespace funcs +} // namespace phi diff --git a/paddle/phi/kernels/gpu/concat_kernel.cu b/paddle/phi/kernels/gpu/concat_kernel.cu index 80ff71b2158..ac83cb3f829 100644 --- a/paddle/phi/kernels/gpu/concat_kernel.cu +++ b/paddle/phi/kernels/gpu/concat_kernel.cu @@ -14,7 +14,6 @@ #include "paddle/phi/kernels/concat_kernel.h" -#include "paddle/fluid/operators/strided_memcpy.h" #include "paddle/phi/backends/gpu/gpu_context.h" #include "paddle/phi/common/bfloat16.h" #include "paddle/phi/common/complex.h" @@ -24,6 +23,7 @@ #include "paddle/phi/core/lod_utils.h" #include "paddle/phi/kernels/funcs/concat_and_split_functor.h" #include "paddle/phi/kernels/funcs/concat_funcs.h" +#include "paddle/phi/kernels/funcs/strided_memcpy.h" namespace phi { @@ -85,14 +85,13 @@ void ConcatKernel(const Context& dev_ctx, } auto in_stride = phi::stride_numel(in->dims()); auto out_stride = phi::stride_numel(out->dims()); - paddle::operators::StridedNumelCopyWithAxis( - dev_ctx, - axis, - out->data() + output_offset, - out_stride, - in->data(), - in_stride, - in_stride[axis]); + phi::funcs::StridedNumelCopyWithAxis(dev_ctx, + axis, + out->data() + output_offset, + out_stride, + in->data(), + in_stride, + in_stride[axis]); output_offset += in_stride[axis]; } } else { diff --git a/paddle/phi/kernels/impl/concat_grad_kernel_impl.h b/paddle/phi/kernels/impl/concat_grad_kernel_impl.h index 6d169354cb4..b0b0e5728d4 100644 --- a/paddle/phi/kernels/impl/concat_grad_kernel_impl.h +++ b/paddle/phi/kernels/impl/concat_grad_kernel_impl.h @@ -13,10 +13,10 @@ // limitations under the License. #pragma once -#include "paddle/fluid/operators/strided_memcpy.h" #include "paddle/phi/kernels/concat_grad_kernel.h" #include "paddle/phi/kernels/funcs/concat_and_split_functor.h" #include "paddle/phi/kernels/funcs/concat_funcs.h" +#include "paddle/phi/kernels/funcs/strided_memcpy.h" namespace phi { @@ -57,7 +57,7 @@ void ConcatGradKernel(const Context& dev_ctx, if (axis == 0 && outs.size() < 10) { std::vector ref_shape; ref_shape.insert(ref_shape.begin(), x.begin(), x.end()); - paddle::operators::StridedMemcpyWithAxis0( + phi::funcs::StridedMemcpyWithAxis0( dev_ctx, out_grad, ref_shape, &outputs); } else { phi::funcs::SplitFunctor split_functor; diff --git a/paddle/phi/kernels/impl/split_kernel_impl.h b/paddle/phi/kernels/impl/split_kernel_impl.h index 6f43e8ea143..77acf81cf4c 100644 --- a/paddle/phi/kernels/impl/split_kernel_impl.h +++ b/paddle/phi/kernels/impl/split_kernel_impl.h @@ -15,11 +15,11 @@ #pragma once #include "paddle/phi/kernels/split_kernel.h" -#include "paddle/fluid/operators/strided_memcpy.h" #include "paddle/phi/common/int_array.h" #include "paddle/phi/common/scalar.h" #include "paddle/phi/core/dense_tensor.h" #include "paddle/phi/kernels/funcs/concat_and_split_functor.h" +#include "paddle/phi/kernels/funcs/strided_memcpy.h" namespace phi { template @@ -37,8 +37,7 @@ void SplitKernel(const Context& dev_ctx, int axis = axis_scalar.to(); // Sometimes direct copies will be faster, this maybe need deeply analysis. 
if (axis == 0 && outs.size() < 10) { - paddle::operators::StridedMemcpyWithAxis0( - dev_ctx, x, shape_refer, &outs); + phi::funcs::StridedMemcpyWithAxis0(dev_ctx, x, shape_refer, &outs); } else { phi::funcs::SplitFunctor functor; functor(dev_ctx, x, shape_refer, axis, &outs); diff --git a/paddle/phi/tests/kernels/CMakeLists.txt b/paddle/phi/tests/kernels/CMakeLists.txt index 646c2b36798..61be79303b0 100644 --- a/paddle/phi/tests/kernels/CMakeLists.txt +++ b/paddle/phi/tests/kernels/CMakeLists.txt @@ -95,3 +95,8 @@ cc_test( test_cache SRCS test_cache.cc DEPS gtest cache) + +cc_test( + strided_memcpy_test + SRCS strided_memcpy_test.cc + DEPS device_context memory) diff --git a/paddle/fluid/operators/strided_memcpy_test.cc b/paddle/phi/tests/kernels/strided_memcpy_test.cc similarity index 63% rename from paddle/fluid/operators/strided_memcpy_test.cc rename to paddle/phi/tests/kernels/strided_memcpy_test.cc index 3d8902a68ac..7ffc83bb31b 100644 --- a/paddle/fluid/operators/strided_memcpy_test.cc +++ b/paddle/phi/tests/kernels/strided_memcpy_test.cc @@ -1,4 +1,4 @@ -/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. +/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -12,13 +12,13 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#include "paddle/fluid/operators/strided_memcpy.h" +#include "paddle/phi/kernels/funcs/strided_memcpy.h" #include "gtest/gtest.h" #include "paddle/fluid/memory/allocation/allocator_facade.h" -namespace paddle { -namespace operators { +namespace phi { +namespace tests { TEST(StridedMemcpy, CPUCrop) { // clang-format off @@ -29,14 +29,15 @@ TEST(StridedMemcpy, CPUCrop) { }; // clang-format on - framework::DDim src_stride({5, 1}); + phi::DDim src_stride({5, 1}); int dst[4]; - framework::DDim dst_dim({2, 2}); - framework::DDim dst_stride({2, 1}); + phi::DDim dst_dim({2, 2}); + phi::DDim dst_stride({2, 1}); phi::CPUContext ctx; - StridedMemcpy(ctx, src + 1, src_stride, dst_dim, dst_stride, dst); + phi::funcs::StridedMemcpy( + ctx, src + 1, src_stride, dst_dim, dst_stride, dst); ASSERT_EQ(1, dst[0]); ASSERT_EQ(2, dst[1]); @@ -54,13 +55,15 @@ TEST(StridedMemcpy, CPUConcat) { int dst[8]; - framework::DDim src_stride({2, 1}); - framework::DDim dst_dim({2, 2}); - framework::DDim dst_stride({4, 1}); + phi::DDim src_stride({2, 1}); + phi::DDim dst_dim({2, 2}); + phi::DDim dst_stride({4, 1}); phi::CPUContext ctx; - StridedMemcpy(ctx, src, src_stride, dst_dim, dst_stride, dst); - StridedMemcpy(ctx, src, src_stride, dst_dim, dst_stride, dst + 2); + phi::funcs::StridedMemcpy( + ctx, src, src_stride, dst_dim, dst_stride, dst); + phi::funcs::StridedMemcpy( + ctx, src, src_stride, dst_dim, dst_stride, dst + 2); // clang-format off int expect_dst[] = { @@ -83,8 +86,8 @@ TEST(StridedMemcpy, GPUCrop) { }; // clang-format on - platform::CUDAPlace gpu0(0); - platform::CPUPlace cpu; + phi::GPUPlace gpu0(0); + phi::CPUPlace cpu; phi::GPUContext ctx(gpu0); ctx.SetAllocator(paddle::memory::allocation::AllocatorFacade::Instance() @@ -92,24 +95,24 @@ TEST(StridedMemcpy, GPUCrop) { .get()); ctx.PartialInitWithAllocator(); - auto src_allocation = memory::Alloc(gpu0, sizeof(src)); + auto src_allocation = paddle::memory::Alloc(gpu0, sizeof(src)); int* gpu_src = reinterpret_cast(src_allocation->ptr()); - 
memory::Copy(gpu0, gpu_src, cpu, src, sizeof(src), ctx.stream()); + paddle::memory::Copy(gpu0, gpu_src, cpu, src, sizeof(src), ctx.stream()); - framework::DDim src_stride({5, 1}); + phi::DDim src_stride({5, 1}); int dst[4]; - auto dst_allocation = memory::Alloc(gpu0, sizeof(dst)); + auto dst_allocation = paddle::memory::Alloc(gpu0, sizeof(dst)); int* gpu_dst = reinterpret_cast(dst_allocation->ptr()); - framework::DDim dst_dim({2, 2}); - framework::DDim dst_stride({2, 1}); + phi::DDim dst_dim({2, 2}); + phi::DDim dst_stride({2, 1}); - StridedMemcpy( + phi::funcs::StridedMemcpy( ctx, gpu_src + 1, src_stride, dst_dim, dst_stride, gpu_dst); - memory::Copy(cpu, dst, gpu0, gpu_dst, sizeof(dst), ctx.stream()); + paddle::memory::Copy(cpu, dst, gpu0, gpu_dst, sizeof(dst), ctx.stream()); ctx.Wait(); ASSERT_EQ(1, dst[0]); @@ -126,30 +129,31 @@ TEST(StridedMemcpy, GPUConcat) { }; // clang-format on - platform::CUDAPlace gpu0(0); - platform::CPUPlace cpu; + phi::GPUPlace gpu0(0); + phi::CPUPlace cpu; phi::GPUContext ctx(gpu0); ctx.SetAllocator(paddle::memory::allocation::AllocatorFacade::Instance() .GetAllocator(gpu0, ctx.stream()) .get()); ctx.PartialInitWithAllocator(); - auto gpu_src_allocation = memory::Alloc(gpu0, sizeof(src)); + auto gpu_src_allocation = paddle::memory::Alloc(gpu0, sizeof(src)); int* gpu_src = reinterpret_cast(gpu_src_allocation->ptr()); - memory::Copy(gpu0, gpu_src, cpu, src, sizeof(src), ctx.stream()); + paddle::memory::Copy(gpu0, gpu_src, cpu, src, sizeof(src), ctx.stream()); int dst[8]; - auto gpu_dst_allocation = memory::Alloc(gpu0, sizeof(dst)); + auto gpu_dst_allocation = paddle::memory::Alloc(gpu0, sizeof(dst)); int* gpu_dst = reinterpret_cast(gpu_dst_allocation->ptr()); - framework::DDim src_stride({2, 1}); - framework::DDim dst_dim({2, 2}); - framework::DDim dst_stride({4, 1}); + phi::DDim src_stride({2, 1}); + phi::DDim dst_dim({2, 2}); + phi::DDim dst_stride({4, 1}); - StridedMemcpy(ctx, gpu_src, src_stride, dst_dim, dst_stride, gpu_dst); - StridedMemcpy( + phi::funcs::StridedMemcpy( + ctx, gpu_src, src_stride, dst_dim, dst_stride, gpu_dst); + phi::funcs::StridedMemcpy( ctx, gpu_src, src_stride, dst_dim, dst_stride, gpu_dst + 2); - memory::Copy(cpu, dst, gpu0, gpu_dst, sizeof(dst), ctx.stream()); + paddle::memory::Copy(cpu, dst, gpu0, gpu_dst, sizeof(dst), ctx.stream()); ctx.Wait(); // clang-format off @@ -164,5 +168,5 @@ TEST(StridedMemcpy, GPUConcat) { } #endif -} // namespace operators -} // namespace paddle +} // namespace tests +} // namespace phi -- GitLab
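
Usage notes: the three helpers keep their signatures across the move; callers only
need to switch the include to "paddle/phi/kernels/funcs/strided_memcpy.h" and the
namespace from paddle::operators to phi::funcs. Below is a minimal sketch of the
relocated StridedMemcpy, mirroring the CPUCrop test above. The function name
crop_center_2x2 is illustrative, not from the patch; per the NOTE in the header,
on GPU the copy is asynchronous, so call dev_ctx.Wait() before reading the result.

#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/kernels/funcs/strided_memcpy.h"

void crop_center_2x2() {
  // 3 x 5 row-major source; copy the 2 x 2 block starting at column 1.
  int src[15] = {0, 1,  2,  3,  4,
                 5, 6,  7,  8,  9,
                 10, 11, 12, 13, 14};
  int dst[4];

  phi::CPUContext ctx;           // dev_ctx is only consulted for its place here
  phi::DDim src_stride({5, 1});  // element strides of the source layout
  phi::DDim dst_dim({2, 2});     // extent of the copied block
  phi::DDim dst_stride({2, 1});  // element strides of the destination layout

  // Formerly paddle::operators::StridedMemcpy; now in phi::funcs.
  phi::funcs::StridedMemcpy<int>(
      ctx, src + 1, src_stride, dst_dim, dst_stride, dst);
  // dst is now {1, 2, 6, 7}.
}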
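
The axis-wise concat/split paths instead use StridedNumelCopyWithAxis with
phi::stride_numel() strides. The sketch below follows the same call shape as
ConcatKernel above, concatenating two 2 x 2 blocks into a 2 x 4 output along
axis 1. concat_axis1 is an illustrative name, and the strides are hand-written
constants here where a kernel would compute them with phi::stride_numel().

#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/kernels/funcs/strided_memcpy.h"

void concat_axis1() {
  int a[4] = {0, 1, 2, 3};  // 2 x 2, row-major
  int b[4] = {4, 5, 6, 7};  // 2 x 2, row-major
  int out[8];               // 2 x 4 destination

  phi::CPUContext ctx;
  const int64_t axis = 1;
  phi::DDim in_stride({4, 2});   // phi::stride_numel of a {2, 2} tensor
  phi::DDim out_stride({8, 4});  // phi::stride_numel of the {2, 4} output

  int64_t offset = 0;
  const int* ins[2] = {a, b};
  for (const int* in : ins) {
    // Copies in_stride[axis] elements per outer slice, striding both sides.
    phi::funcs::StridedNumelCopyWithAxis<int>(
        ctx, axis, out + offset, out_stride, in, in_stride, in_stride[axis]);
    offset += in_stride[axis];
  }
  // out, read as 2 x 4: {0, 1, 4, 5,
  //                      2, 3, 6, 7}
}

StridedMemcpyWithAxis0 serves the same purpose for the axis == 0 fast path in
SplitKernel/ConcatGradKernel, but takes DenseTensors directly, so it is not
sketched here.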