diff --git a/cmake/inference_lib.cmake b/cmake/inference_lib.cmake index 8af45b3803647bc80d2a0bb4a3504e90b2064854..d0af307ec594b70338280167810d2e1d66458f93 100644 --- a/cmake/inference_lib.cmake +++ b/cmake/inference_lib.cmake @@ -448,10 +448,8 @@ add_dependencies(fluid_lib_dist ${platform_lib_deps}) copy( fluid_lib_dist SRCS ${src_dir}/${module}/*.h ${src_dir}/${module}/dynload/*.h - ${src_dir}/${module}/details/*.h ${PADDLE_BINARY_DIR}/paddle/phi/api/profiler/*.pb.h - DSTS ${dst_dir}/${module} ${dst_dir}/${module}/dynload - ${dst_dir}/${module}/details ${dst_dir}/${module}) + DSTS ${dst_dir}/${module} ${dst_dir}/${module}/dynload ${dst_dir}/${module}) set(module "string") copy( diff --git a/paddle/fluid/framework/data_transform.h b/paddle/fluid/framework/data_transform.h index 27bc0086c233ded09605fedadaa703672bed6e41..004742e2a44797cdd4cf5976f3e205a4241a062c 100644 --- a/paddle/fluid/framework/data_transform.h +++ b/paddle/fluid/framework/data_transform.h @@ -24,7 +24,7 @@ limitations under the License. */ #include "paddle/fluid/framework/variable.h" #include "paddle/fluid/platform/device_context.h" #include "paddle/fluid/platform/macros.h" -#include "paddle/fluid/platform/transform.h" +#include "paddle/phi/common/transform.h" #include "paddle/phi/kernels/funcs/math_function.h" namespace paddle { diff --git a/paddle/fluid/framework/data_type_transform.cc b/paddle/fluid/framework/data_type_transform.cc index 0f2e244af0ab3cdf91f96a47eccfd654c8f87c38..9d114fcf563963844c394659e995685e93a402ff 100644 --- a/paddle/fluid/framework/data_type_transform.cc +++ b/paddle/fluid/framework/data_type_transform.cc @@ -16,7 +16,7 @@ limitations under the License. 
*/ #include "paddle/fluid/framework/convert_utils.h" #include "paddle/fluid/framework/selected_rows_utils.h" -#include "paddle/fluid/platform/transform.h" +#include "paddle/phi/common/transform.h" #if defined(PADDLE_WITH_XPU) #include "paddle/fluid/platform/device/device_wrapper.h" @@ -94,7 +94,7 @@ struct CastDataType { auto* out_begin = out_->mutable_data(in_.place()); if (platform::is_cpu_place(in_.place())) { - platform::Transform trans; + phi::Transform trans; auto* context = static_cast(ctx_); trans(*context, in_begin, @@ -103,7 +103,7 @@ struct CastDataType { CastDataTypeFunctor()); #if defined(__NVCC__) || defined(__HIPCC__) } else if (platform::is_gpu_place(in_.place())) { - platform::Transform trans; + phi::Transform trans; auto* context = static_cast(ctx_); trans(*context, in_begin, @@ -114,7 +114,7 @@ struct CastDataType { #endif #if defined(PADDLE_WITH_IPU) } else if (platform::is_ipu_place(in_.place())) { - platform::Transform trans; + phi::Transform trans; auto* context = static_cast(ctx_); trans(*context, in_begin, diff --git a/paddle/fluid/operators/cast_op.h b/paddle/fluid/operators/cast_op.h index 2dde7f2ce1741dd7a98b54f46eebb1da1437910e..22bef8fbca57565c5d14ee648d1f244d7576df39 100644 --- a/paddle/fluid/operators/cast_op.h +++ b/paddle/fluid/operators/cast_op.h @@ -17,7 +17,7 @@ limitations under the License. 
*/ #include "paddle/fluid/framework/data_type.h" #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/phi_utils.h" -#include "paddle/fluid/platform/transform.h" +#include "paddle/phi/common/transform.h" #include "paddle/phi/kernels/cast_kernel.h" namespace paddle { @@ -44,7 +44,7 @@ struct CastOpFunctor { auto numel = in_->numel(); auto* in_end = in_begin + numel; auto* out_begin = out_->mutable_data(ctx_.GetPlace()); - platform::Transform trans; + phi::Transform trans; trans( ctx_, in_begin, in_end, out_begin, CastOpTransformFunctor()); } diff --git a/paddle/fluid/operators/center_loss_op.h b/paddle/fluid/operators/center_loss_op.h index 36fe957102bfb3245410bf22d6878e906a0c7d10..7632482c97b3f8d02ca16064cba6eefcad102de3 100644 --- a/paddle/fluid/operators/center_loss_op.h +++ b/paddle/fluid/operators/center_loss_op.h @@ -20,7 +20,7 @@ limitations under the License. */ #include "paddle/fluid/framework/eigen.h" #include "paddle/fluid/framework/op_registry.h" -#include "paddle/fluid/platform/transform.h" +#include "paddle/phi/common/transform.h" #include "paddle/phi/kernels/funcs/blas/blas.h" namespace paddle { @@ -94,7 +94,7 @@ class CenterLossKernel : public framework::OpKernel { T *center_out_index; T *center_loss_diff_index; T *acc_index; - platform::Transform trans; + phi::Transform trans; for (int i = 0; i < batch_size; ++i) { tLabel = label_data[i]; diff --git a/paddle/fluid/operators/clip_by_norm_op.h b/paddle/fluid/operators/clip_by_norm_op.h index f54e323eefb44c1320b7634c91c5a27e1a84084a..3895bc09a08a08dbe3d59bb666744547cfa3d5f5 100644 --- a/paddle/fluid/operators/clip_by_norm_op.h +++ b/paddle/fluid/operators/clip_by_norm_op.h @@ -17,7 +17,7 @@ limitations under the License. 
*/ #include "paddle/fluid/framework/eigen.h" #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/selected_rows_utils.h" -#include "paddle/fluid/platform/transform.h" +#include "paddle/phi/common/transform.h" #include "paddle/phi/kernels/funcs/selected_rows_functor.h" namespace paddle { diff --git a/paddle/fluid/operators/detection/anchor_generator_op.h b/paddle/fluid/operators/detection/anchor_generator_op.h index 6582ddac1411444dfc6100f4fa17666d924393fe..70194a0abcbb27ae945703a46e42595e80d3039b 100644 --- a/paddle/fluid/operators/detection/anchor_generator_op.h +++ b/paddle/fluid/operators/detection/anchor_generator_op.h @@ -17,7 +17,7 @@ limitations under the License. */ #include #include "paddle/fluid/framework/op_registry.h" -#include "paddle/fluid/platform/transform.h" +#include "paddle/phi/common/transform.h" #include "paddle/phi/kernels/funcs/math_function.h" namespace paddle { diff --git a/paddle/fluid/operators/detection/prior_box_op.h b/paddle/fluid/operators/detection/prior_box_op.h index 9968a07fcf7f4dd00a88420e39cbe8cc39a390a7..4c5249ec56fce69f3a6659ebcba3a07101fd96f4 100644 --- a/paddle/fluid/operators/detection/prior_box_op.h +++ b/paddle/fluid/operators/detection/prior_box_op.h @@ -17,7 +17,7 @@ limitations under the License. */ #include #include "paddle/fluid/framework/op_registry.h" -#include "paddle/fluid/platform/transform.h" +#include "paddle/phi/common/transform.h" #include "paddle/phi/core/visit_type.h" #include "paddle/phi/kernels/funcs/math_function.h" diff --git a/paddle/fluid/operators/elementwise/elementwise_op_function.h b/paddle/fluid/operators/elementwise/elementwise_op_function.h index 6fd13aad10a49c245a4648687b655f6252afa295..8121ab075e4b3f01e639864b3d2267bccf7fd096 100644 --- a/paddle/fluid/operators/elementwise/elementwise_op_function.h +++ b/paddle/fluid/operators/elementwise/elementwise_op_function.h @@ -28,7 +28,8 @@ limitations under the License. 
*/ #include "paddle/fluid/memory/malloc.h" #include "paddle/fluid/operators/elementwise/elementwise_functor.h" #include "paddle/fluid/platform/device/gpu/gpu_info.h" -#include "paddle/fluid/platform/transform.h" +#include "paddle/phi/api/lib/utils/tensor_utils.h" +#include "paddle/phi/common/transform.h" #include "paddle/phi/kernels/cpu/elementwise.h" #include "paddle/phi/kernels/cpu/elementwise_grad.h" diff --git a/paddle/fluid/operators/fake_quantize_op.cc b/paddle/fluid/operators/fake_quantize_op.cc index 65e4b28326a9fe55f3dd1657e60a70a4bc6da16c..accd9671868cdea4a9bdfa376c06ab4438648945 100644 --- a/paddle/fluid/operators/fake_quantize_op.cc +++ b/paddle/fluid/operators/fake_quantize_op.cc @@ -19,7 +19,7 @@ limitations under the License. */ #include "paddle/fluid/framework/eigen.h" #include "paddle/fluid/framework/op_version_registry.h" -#include "paddle/fluid/platform/transform.h" +#include "paddle/phi/common/transform.h" #include "paddle/phi/kernels/impl/clip_kernel_impl.h" namespace paddle { @@ -98,7 +98,7 @@ struct ClipAndFakeQuantFunctor { phi::DenseTensor *out) { T s = scale.data()[0]; T inv_s = inverse(s); - platform::Transform trans; + phi::Transform trans; if (round_type == 0) { trans(ctx, in.data(), @@ -130,7 +130,7 @@ struct ClipAndFakeQuantDequantFunctor { T s = scale.data()[0]; T inv_s = inverse(s); - platform::Transform trans; + phi::Transform trans; if (round_type == 0) { trans(ctx, in.data(), @@ -175,7 +175,7 @@ struct ChannelClipAndFakeQuantFunctor { auto *out_data = out->mutable_data(ctx.GetPlace()); auto in_dims = in.dims(); const int64_t channel = in_dims[quant_axis]; - platform::Transform trans; + phi::Transform trans; if (quant_axis == 0) { const int64_t channel_size = in.numel() / channel; for (int64_t i = 0; i < channel; i++) { @@ -256,7 +256,7 @@ struct ChannelClipFakeQuantDequantFunctor { auto *out_data = out->mutable_data(ctx.GetPlace()); auto in_dims = in.dims(); const int64_t channel = in_dims[quant_axis]; - platform::Transform 
trans; + phi::Transform trans; if (quant_axis == 0) { const int64_t channel_size = in.numel() / channel; for (int i = 0; i < channel; i++) { diff --git a/paddle/fluid/operators/fake_quantize_op.h b/paddle/fluid/operators/fake_quantize_op.h index 2129c4b66aa053045e2da0d73d448a23c7c0c1c0..fea05b39af7b76857bcbc0190b9a1e42965c6756 100644 --- a/paddle/fluid/operators/fake_quantize_op.h +++ b/paddle/fluid/operators/fake_quantize_op.h @@ -20,7 +20,7 @@ limitations under the License. */ #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/tensor_util.h" #include "paddle/fluid/memory/malloc.h" -#include "paddle/fluid/platform/transform.h" +#include "paddle/phi/common/transform.h" #include "paddle/phi/core/hostdevice.h" #include "paddle/phi/kernels/funcs/blas/blas.h" diff --git a/paddle/fluid/operators/isfinite_op.h b/paddle/fluid/operators/isfinite_op.h index 431d446daa7bbcb99173ff04be0551a3018d82b6..9e63d191e62b06a4d8ca803b4f5f831fa2cf8c41 100644 --- a/paddle/fluid/operators/isfinite_op.h +++ b/paddle/fluid/operators/isfinite_op.h @@ -20,7 +20,7 @@ #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/tensor_util.h" #include "paddle/fluid/platform/float16.h" -#include "paddle/fluid/platform/transform.h" +#include "paddle/phi/common/transform.h" #include "paddle/phi/kernels/isfinite_kernel.h" #include "paddle/phi/kernels/reduce_all_kernel.h" #include "paddle/phi/kernels/reduce_any_kernel.h" diff --git a/paddle/fluid/operators/lstmp_op.h b/paddle/fluid/operators/lstmp_op.h index 3272ca84d8a90bd7b33111136b5698e36cc454f1..90e3072dbf97ebbfb6f0bbcbe375d6327b94db3f 100644 --- a/paddle/fluid/operators/lstmp_op.h +++ b/paddle/fluid/operators/lstmp_op.h @@ -20,7 +20,7 @@ limitations under the License. 
*/ #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/operators/activation_op.h" #include "paddle/fluid/platform/place.h" -#include "paddle/fluid/platform/transform.h" +#include "paddle/phi/common/transform.h" #include "paddle/phi/kernels/funcs/blas/blas.h" #include "paddle/phi/kernels/funcs/detail/activation_functions.h" #include "paddle/phi/kernels/funcs/lstm_compute.h" @@ -29,7 +29,7 @@ limitations under the License. */ namespace paddle { namespace operators { -using platform::Transform; +using phi::Transform; template // For std::remove_pointer and std::is_pointer. - -#include "thrust/device_ptr.h" - -namespace paddle { -namespace platform { -namespace details { - -// PointerToThrustDevicePtr has two speicalizations, one casts a (CUDA -// device) pointer into thrust::device_ptr, the other keeps rest types -// un-casted. -template -struct PointerToThrustDevicePtr; - -template -struct PointerToThrustDevicePtr { - using ELEM = typename std::remove_pointer::type; - using RTYPE = thrust::device_ptr; - - inline thrust::device_ptr operator()(ELEM* ele) const { - return thrust::device_pointer_cast(ele); - } -}; - -template -struct PointerToThrustDevicePtr { - using RTYPE = T; - inline RTYPE operator()(RTYPE it) const { return it; } -}; - -// CastToCUDATransformIterator casts a pointer to thrust::device_ptr -// so it could be used as the iterator of thrust::transform. It -// doesn't cast other types. -// -// We need CastToCUDATransformIterator because it is often that we -// want to use device memory pointers as transform iterators, e.g., to -// transform a block of float32 to float16. In this case, we want -// CastToCUDATransformIterator to cast float16/32 pointers to -// thrust::device_ptr, otherwise they cannot work as the iterator -// required by thrust::transform. At the same time, we don't want to -// cast thrust::device_ptr to thrust::device_ptr repeatedly. 
-template -auto CastToCUDATransformIterator(T t) -> - typename PointerToThrustDevicePtr::value>::RTYPE { - PointerToThrustDevicePtr::value> cast; - return cast(t); -} - -} // namespace details -} // namespace platform -} // namespace paddle diff --git a/paddle/fluid/platform/transform.h b/paddle/phi/common/transform.h similarity index 59% rename from paddle/fluid/platform/transform.h rename to paddle/phi/common/transform.h index fc39fa33ffbb75a12abf8d05e9e84db3c36df47c..a56103a6e526774a6cff9cb5147e4426424d9ea0 100644 --- a/paddle/fluid/platform/transform.h +++ b/paddle/phi/common/transform.h @@ -1,4 +1,4 @@ -/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. +/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -17,20 +17,17 @@ limitations under the License. */ #include #include -#include "paddle/fluid/platform/device_context.h" -#include "paddle/fluid/platform/enforce.h" -#include "paddle/fluid/platform/place.h" +#include "paddle/phi/backends/all_context.h" +#include "paddle/phi/core/enforce.h" #include "paddle/phi/core/hostdevice.h" #if defined(__NVCC__) || defined(__HIPCC__) #include #include - -#include "paddle/fluid/platform/details/cuda_transform_iterator_cast.h" +#include "thrust/device_ptr.h" #endif -namespace paddle { -namespace platform { +namespace phi { // Transform applys a unary or a binary functor on each element in a // range defined by a pair of iterators. @@ -40,16 +37,16 @@ namespace platform { // // NOTE: We need to define InputIter and OutputIter defined as // different types, because the InputIter points op's inputs and -// OutputIter pints to op's outputs. +// OutputIter points to op's outputs. // // NOTE: We don't assume that InputIter to be const InputType* and // OutputIter to be OutputType*, because we might use a iterator // class, paddle::fluid::operators::RowwiseTRansformIterator. 
-template +template struct Transform { // The unary version. template - void operator()(const DeviceContext& context, + void operator()(const Context& context, InputIter first, InputIter last, OutputIter result, @@ -60,7 +57,7 @@ struct Transform { typename InputIter2, typename OutputIter, typename BinaryOperation> - void operator()(const DeviceContext& context, + void operator()(const Context& context, InputIter1 first1, InputIter1 last1, InputIter2 first2, @@ -97,6 +94,46 @@ struct Transform { #if defined(__NVCC__) || defined(__HIPCC__) +// PointerToThrustDevicePtr has two specializations, one casts a (CUDA +// device) pointer into thrust::device_ptr, the other keeps rest types +// un-casted. +template +struct PointerToThrustDevicePtr; + +template +struct PointerToThrustDevicePtr { + using ELEM = typename std::remove_pointer::type; + using RTYPE = thrust::device_ptr; + + inline thrust::device_ptr operator()(ELEM* ele) const { + return thrust::device_pointer_cast(ele); + } +}; + +template +struct PointerToThrustDevicePtr { + using RTYPE = T; + inline RTYPE operator()(RTYPE it) const { return it; } +}; + +// CastToCUDATransformIterator casts a pointer to thrust::device_ptr +// so it could be used as the iterator of thrust::transform. It +// doesn't cast other types. +// +// We need CastToCUDATransformIterator because it is often that we +// want to use device memory pointers as transform iterators, e.g., to +// transform a block of float32 to float16. In this case, we want +// CastToCUDATransformIterator to cast float16/32 pointers to +// thrust::device_ptr, otherwise they cannot work as the iterator +// required by thrust::transform. At the same time, we don't want to +// cast thrust::device_ptr to thrust::device_ptr repeatedly. 
+template +auto CastToCUDATransformIterator(T t) -> + typename PointerToThrustDevicePtr::value>::RTYPE { + PointerToThrustDevicePtr::value> cast; + return cast(t); +} + template <> struct Transform { template @@ -106,21 +143,21 @@ struct Transform { OutputIter result, UnaryOperation op) { auto place = context.GetPlace(); - PADDLE_ENFORCE_EQ(is_gpu_place(place), + PADDLE_ENFORCE_EQ(place.GetType() == phi::AllocationType::GPU, true, - platform::errors::PreconditionNotMet( + phi::errors::PreconditionNotMet( "The CUDA Transform must be used in GPU place.")); #ifdef __HIPCC__ thrust::transform(thrust::hip::par.on(context.stream()), - details::CastToCUDATransformIterator(first), - details::CastToCUDATransformIterator(last), - details::CastToCUDATransformIterator(result), + CastToCUDATransformIterator(first), + CastToCUDATransformIterator(last), + CastToCUDATransformIterator(result), op); #else thrust::transform(thrust::cuda::par.on(context.stream()), - details::CastToCUDATransformIterator(first), - details::CastToCUDATransformIterator(last), - details::CastToCUDATransformIterator(result), + CastToCUDATransformIterator(first), + CastToCUDATransformIterator(last), + CastToCUDATransformIterator(result), op); #endif } @@ -136,28 +173,27 @@ struct Transform { OutputIter result, BinaryOperation op) { auto place = context.GetPlace(); - PADDLE_ENFORCE_EQ(is_gpu_place(place), + PADDLE_ENFORCE_EQ(place.GetType() == phi::AllocationType::GPU, true, - platform::errors::PreconditionNotMet( + phi::errors::PreconditionNotMet( "The CUDA Transform must be used in GPU place.")); #ifdef __HIPCC__ thrust::transform(thrust::hip::par.on(context.stream()), - details::CastToCUDATransformIterator(first1), - details::CastToCUDATransformIterator(last1), - details::CastToCUDATransformIterator(first2), - details::CastToCUDATransformIterator(result), + CastToCUDATransformIterator(first1), + CastToCUDATransformIterator(last1), + CastToCUDATransformIterator(first2), + 
CastToCUDATransformIterator(result), op); #else thrust::transform(thrust::cuda::par.on(context.stream()), - details::CastToCUDATransformIterator(first1), - details::CastToCUDATransformIterator(last1), - details::CastToCUDATransformIterator(first2), - details::CastToCUDATransformIterator(result), + CastToCUDATransformIterator(first1), + CastToCUDATransformIterator(last1), + CastToCUDATransformIterator(first2), + CastToCUDATransformIterator(result), op); #endif } }; #endif -} // namespace platform -} // namespace paddle +} // namespace phi diff --git a/paddle/phi/kernels/cpu/bitwise_kernel.cc b/paddle/phi/kernels/cpu/bitwise_kernel.cc index 69f52790f77969e8bf29fcb50b777afe504215b7..80424ef624f61bb8f28de66cf122053ef1514c1d 100644 --- a/paddle/phi/kernels/cpu/bitwise_kernel.cc +++ b/paddle/phi/kernels/cpu/bitwise_kernel.cc @@ -20,7 +20,7 @@ limitations under the License. */ #include "paddle/phi/kernels/funcs/elementwise_base.h" // See Note [ Why still include the fluid headers? ] -#include "paddle/fluid/platform/transform.h" +#include "paddle/phi/common/transform.h" namespace phi { @@ -48,7 +48,7 @@ void BitwiseNotKernel(const Context& dev_ctx, T* out_data = dev_ctx.template Alloc(out); size_t numel = x.numel(); funcs::BitwiseNotFunctor func; - paddle::platform::Transform trans; + phi::Transform trans; trans(dev_ctx, x_data, x_data + numel, out_data, func); } diff --git a/paddle/phi/kernels/cpu/cast_impl.h b/paddle/phi/kernels/cpu/cast_impl.h index 9648b584243f5b2aa65a5eee7e4fbeb7292f0284..1af4d36dd2c73d147540426634a1e555f57bdfbe 100644 --- a/paddle/phi/kernels/cpu/cast_impl.h +++ b/paddle/phi/kernels/cpu/cast_impl.h @@ -17,7 +17,7 @@ #include "paddle/phi/backends/cpu/cpu_context.h" // See Note [ Why still include the fluid headers? 
] -#include "paddle/fluid/platform/transform.h" +#include "paddle/phi/common/transform.h" namespace phi { @@ -36,7 +36,7 @@ void CastKernelImpl(const CPUContext& dev_ctx, auto* out_begin = dev_ctx.Alloc(out); - paddle::platform::Transform trans; + phi::Transform trans; trans(dev_ctx, in_begin, in_end, diff --git a/paddle/phi/kernels/cpu/hsigmoid_loss_kernel.cc b/paddle/phi/kernels/cpu/hsigmoid_loss_kernel.cc index 062aa1be24fcca2892ff5f51b1ffec911b5d3b73..c6ee49ef34786a60cf1c78127f80bde5d968d1eb 100644 --- a/paddle/phi/kernels/cpu/hsigmoid_loss_kernel.cc +++ b/paddle/phi/kernels/cpu/hsigmoid_loss_kernel.cc @@ -14,8 +14,8 @@ #include "paddle/phi/kernels/hsigmoid_loss_kernel.h" -#include "paddle/fluid/platform/transform.h" #include "paddle/phi/backends/cpu/cpu_context.h" +#include "paddle/phi/common/transform.h" #include "paddle/phi/core/kernel_registry.h" #include "paddle/phi/kernels/funcs/eigen/common.h" #include "paddle/phi/kernels/funcs/eigen/eigen_function.h" @@ -82,7 +82,7 @@ void HSigmoidLossKernel(const Context& ctx, } bit_code->Mul(pre_out, w, x); // clip to [-40, 40] - paddle::platform::Transform trans; + phi::Transform trans; trans(ctx, pre_out_data, pre_out_data + pre_out->numel(), diff --git a/paddle/phi/kernels/cpu/logical_kernel.cc b/paddle/phi/kernels/cpu/logical_kernel.cc index a0747b128e53899b77767298cab4fa37f31e495a..2fa2dd3451121bc7eb4d3b3a92ef6eb918bcbaf2 100644 --- a/paddle/phi/kernels/cpu/logical_kernel.cc +++ b/paddle/phi/kernels/cpu/logical_kernel.cc @@ -20,7 +20,7 @@ #include "paddle/phi/kernels/funcs/logical_functor.h" // See Note [ Why still include the fluid headers? 
] -#include "paddle/fluid/platform/transform.h" +#include "paddle/phi/common/transform.h" namespace phi { @@ -47,7 +47,7 @@ void LogicalNotKernel(const Context& dev_ctx, auto* out_ptr = dev_ctx.template Alloc(out); funcs::LogicalNotFunctor unary_func; - paddle::platform::Transform trans; + phi::Transform trans; trans(dev_ctx, x.data(), x.data() + x.numel(), out_ptr, unary_func); } diff --git a/paddle/phi/kernels/funcs/elementwise_base.h b/paddle/phi/kernels/funcs/elementwise_base.h index 1f937425805733b4abd9e5a612d446e459063773..1d40d4d8c2957276ae2b9b4502089d5f585b42e1 100644 --- a/paddle/phi/kernels/funcs/elementwise_base.h +++ b/paddle/phi/kernels/funcs/elementwise_base.h @@ -14,8 +14,8 @@ limitations under the License. */ #pragma once -#include "paddle/fluid/platform/transform.h" #include "paddle/phi/backends/all_context.h" +#include "paddle/phi/common/transform.h" #include "paddle/phi/core/dense_tensor.h" #include "paddle/phi/kernels/empty_kernel.h" #include "paddle/phi/kernels/funcs/common_shape.h" @@ -220,12 +220,12 @@ class TransformFunctor { } inline void Run() const { - paddle::platform::Transform trans; + phi::Transform trans; trans(ctx_, x_, x_ + nx_, y_, z_, func_); } inline void RunRowWise(int n, int pre) const { - paddle::platform::Transform trans; + phi::Transform trans; if (is_xsize_larger_) { trans(ctx_, x_, @@ -244,7 +244,7 @@ class TransformFunctor { } inline void RunMidWise(int n, int pre, int post) const { - paddle::platform::Transform trans; + phi::Transform trans; if (is_xsize_larger_) { trans(ctx_, x_, diff --git a/paddle/phi/kernels/impl/clip_grad_kernel_impl.h b/paddle/phi/kernels/impl/clip_grad_kernel_impl.h index 0e6fc199610d2f623bf64d0b4f0d20d9014f32df..98bea06c907a94c0854658499f9941ab795a24f8 100644 --- a/paddle/phi/kernels/impl/clip_grad_kernel_impl.h +++ b/paddle/phi/kernels/impl/clip_grad_kernel_impl.h @@ -15,8 +15,8 @@ #pragma once #include "paddle/fluid/platform/device_context.h" -#include "paddle/fluid/platform/transform.h" 
#include "paddle/phi/backends/all_context.h" +#include "paddle/phi/common/transform.h" #include "paddle/phi/core/kernel_registry.h" #include "paddle/phi/kernels/clip_kernel.h" #if defined(__NVCC__) || defined(__HIPCC__) @@ -59,7 +59,7 @@ void ClipGradKernel(const Context& dev_ctx, auto* d_x_data = dev_ctx.template Alloc(x_grad); const T* d_out_data = out_grad.data(); const T* x_data = x.data(); - paddle::platform::Transform trans; + phi::Transform trans; trans(dev_ctx, d_out_data, d_out_data + numel, diff --git a/paddle/phi/kernels/impl/clip_kernel_impl.h b/paddle/phi/kernels/impl/clip_kernel_impl.h index 1b63abae525b69647bc9bdc359203a01d89f323e..3b51b09b77c8c2b31323dccb830aab2d133d8053 100644 --- a/paddle/phi/kernels/impl/clip_kernel_impl.h +++ b/paddle/phi/kernels/impl/clip_kernel_impl.h @@ -14,8 +14,8 @@ #pragma once -#include "paddle/fluid/platform/transform.h" #include "paddle/phi/backends/all_context.h" +#include "paddle/phi/common/transform.h" #include "paddle/phi/core/kernel_registry.h" #include "paddle/phi/kernels/clip_kernel.h" #if defined(__NVCC__) || defined(__HIPCC__) @@ -67,7 +67,7 @@ void ClipKernel(const Context& dev_ctx, phi::funcs::ElementwiseKernel(dev_ctx, ins, &outs, functor); #endif } else { - paddle::platform::Transform trans; + phi::Transform trans; trans( dev_ctx, x_data, x_data + numel, out_data, ClipFunctor(min_, max_)); } diff --git a/paddle/phi/kernels/impl/isfinite_kernel_impl.h b/paddle/phi/kernels/impl/isfinite_kernel_impl.h index d36a7cb915e784d8ae38e9a79393556ab39a5cdb..0f1991838e837507b2a7720a0a7642e4e07cf60b 100644 --- a/paddle/phi/kernels/impl/isfinite_kernel_impl.h +++ b/paddle/phi/kernels/impl/isfinite_kernel_impl.h @@ -18,7 +18,7 @@ #include "paddle/phi/kernels/isfinite_kernel.h" // See Note [ Why still include the fluid headers? 
] -#include "paddle/fluid/platform/transform.h" +#include "paddle/phi/common/transform.h" namespace phi { @@ -28,7 +28,7 @@ namespace phi { const Context& ctx, const DenseTensor& x, DenseTensor* out) { \ auto* out_ptr = ctx.template Alloc(out); \ funcs::functor unary_func; \ - paddle::platform::Transform trans; \ + phi::Transform trans; \ trans(ctx, x.data(), x.data() + x.numel(), out_ptr, unary_func); \ } diff --git a/paddle/phi/kernels/selected_rows/impl/clip_kernel_impl.h b/paddle/phi/kernels/selected_rows/impl/clip_kernel_impl.h index 10471f91b9a27aabe3c958be4017b1c50c003d8f..ff8333a92eb4ff783222f640bac868695ac6e046 100644 --- a/paddle/phi/kernels/selected_rows/impl/clip_kernel_impl.h +++ b/paddle/phi/kernels/selected_rows/impl/clip_kernel_impl.h @@ -50,7 +50,7 @@ void ClipSparseKernel(const Context& dev_ctx, auto* out_tensor = out->mutable_value(); auto* out_data = out_tensor->data(); int64_t numel = out_tensor->numel(); - paddle::platform::Transform trans; + phi::Transform trans; trans(dev_ctx, out_data, out_data + numel, diff --git a/paddle/phi/tests/common/CMakeLists.txt b/paddle/phi/tests/common/CMakeLists.txt index 3499489541d1cdf6abb0f2585d8488d96761197f..aed462cc9cef6de2efaf350ab710cdb2402d7ca0 100644 --- a/paddle/phi/tests/common/CMakeLists.txt +++ b/paddle/phi/tests/common/CMakeLists.txt @@ -23,10 +23,18 @@ if(WITH_GPU) phi_test_scalar SRCS test_scalar.cu DEPS scalar api_scalar) + nv_test( + transform_test + SRCS transform_test.cu + DEPS memory place device_context) endif() if(WITH_ROCM) hip_test( phi_test_scalar SRCS test_scalar.cu DEPS scalar api_scalar) + hip_test( + transform_test + SRCS transform_test.cu + DEPS memory place device_context) endif() diff --git a/paddle/fluid/platform/transform_test.cu b/paddle/phi/tests/common/transform_test.cu similarity index 69% rename from paddle/fluid/platform/transform_test.cu rename to paddle/phi/tests/common/transform_test.cu index 
ce68452ffbe32df3ae449a949bf945a907f121ea..b2547bbfe0b1c2b1b4b2d8c2833524cfc4a38db3 100644 --- a/paddle/fluid/platform/transform_test.cu +++ b/paddle/phi/tests/common/transform_test.cu @@ -14,10 +14,11 @@ limitations under the License. */ #include -#include "paddle/fluid/memory/allocation/allocator_facade.h" +#include "paddle/phi/common/transform.h" + #include "paddle/fluid/memory/memcpy.h" #include "paddle/fluid/memory/memory.h" -#include "paddle/fluid/platform/transform.h" +#include "paddle/phi/backends/all_context.h" #include "paddle/phi/core/hostdevice.h" template @@ -44,7 +45,7 @@ using paddle::platform::CUDAPlace; using phi::CPUContext; using phi::GPUContext; -using paddle::platform::Transform; +using phi::Transform; TEST(Transform, CPUUnary) { CPUContext ctx; @@ -58,19 +59,17 @@ TEST(Transform, CPUUnary) { TEST(Transform, GPUUnary) { CUDAPlace gpu0(0); - phi::GPUContext ctx(gpu0); - ctx.SetAllocator(paddle::memory::allocation::AllocatorFacade::Instance() - .GetAllocator(gpu0, ctx.stream()) - .get()); - ctx.PartialInitWithAllocator(); + phi::DeviceContextPool& pool = phi::DeviceContextPool::Instance(); + auto* ctx = reinterpret_cast(pool.Get(phi::GPUPlace())); + float cpu_buf[4] = {0.1, 0.2, 0.3, 0.4}; auto gpu_allocation = Alloc(gpu0, sizeof(float) * 4); float* gpu_buf = static_cast(gpu_allocation->ptr()); - Copy(gpu0, gpu_buf, CPUPlace(), cpu_buf, sizeof(cpu_buf), ctx.stream()); + Copy(gpu0, gpu_buf, CPUPlace(), cpu_buf, sizeof(cpu_buf), ctx->stream()); Transform trans; - trans(ctx, gpu_buf, gpu_buf + 4, gpu_buf, Scale(10)); - ctx.Wait(); - Copy(CPUPlace(), cpu_buf, gpu0, gpu_buf, sizeof(cpu_buf), ctx.stream()); + trans(*ctx, gpu_buf, gpu_buf + 4, gpu_buf, Scale(10)); + ctx->Wait(); + Copy(CPUPlace(), cpu_buf, gpu0, gpu_buf, sizeof(cpu_buf), ctx->stream()); for (int i = 0; i < 4; ++i) { ASSERT_NEAR(cpu_buf[i], static_cast(i + 1), 1e-5); } @@ -89,18 +88,16 @@ TEST(Transform, CPUBinary) { TEST(Transform, GPUBinary) { int buf[4] = {1, 2, 3, 4}; CUDAPlace 
gpu0(0); - phi::GPUContext ctx(gpu0); - ctx.SetAllocator(paddle::memory::allocation::AllocatorFacade::Instance() - .GetAllocator(gpu0, ctx.stream()) - .get()); - ctx.PartialInitWithAllocator(); + phi::DeviceContextPool& pool = phi::DeviceContextPool::Instance(); + auto* ctx = reinterpret_cast(pool.Get(phi::GPUPlace())); + auto gpu_allocation = Alloc(gpu0, sizeof(buf)); int* gpu_buf = static_cast(gpu_allocation->ptr()); - Copy(gpu0, gpu_buf, CPUPlace(), buf, sizeof(buf), ctx.stream()); + Copy(gpu0, gpu_buf, CPUPlace(), buf, sizeof(buf), ctx->stream()); Transform trans; - trans(ctx, gpu_buf, gpu_buf + 4, gpu_buf, gpu_buf, Multiply()); - ctx.Wait(); - Copy(CPUPlace(), buf, gpu0, gpu_buf, sizeof(buf), ctx.stream()); + trans(*ctx, gpu_buf, gpu_buf + 4, gpu_buf, gpu_buf, Multiply()); + ctx->Wait(); + Copy(CPUPlace(), buf, gpu0, gpu_buf, sizeof(buf), ctx->stream()); for (int i = 0; i < 4; ++i) { ASSERT_EQ((i + 1) * (i + 1), buf[i]); }