Unverified · Commit fe332794 · Authored by Huang Jiyi · Committed by GitHub

[phi decoupling] move platform/transform to phi (#50498)

* move platform::transform to phi

* fix bugs

* move transform_test to phi

* fix cmake

* update namespace

* fix cmake
Parent b5da73c5
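The call-site change is mechanical: the include of paddle/fluid/platform/transform.h becomes paddle/phi/common/transform.h, and paddle::platform::Transform becomes phi::Transform. A minimal sketch of a migrated call site follows; the Scale functor matches the one used in the test file below, while the ScaleInPlace wrapper is illustrative and not part of this commit:

```cpp
#include <cstddef>

#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/common/transform.h"  // was: paddle/fluid/platform/transform.h
#include "paddle/phi/core/hostdevice.h"

// Unary functor; HOSTDEVICE lets the same functor serve the GPU specialization.
template <typename T>
struct Scale {
  explicit Scale(const T& scale) : s(scale) {}
  HOSTDEVICE T operator()(const T& a) const { return a * s; }
  T s;
};

// Apply the functor elementwise over [data, data + n), in place.
void ScaleInPlace(const phi::CPUContext& ctx, float* data, size_t n) {
  phi::Transform<phi::CPUContext> trans;  // was: paddle::platform::Transform
  trans(ctx, data, data + n, data, Scale<float>(10));
}
```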
@@ -448,10 +448,8 @@ add_dependencies(fluid_lib_dist ${platform_lib_deps})
 copy(
   fluid_lib_dist
   SRCS ${src_dir}/${module}/*.h ${src_dir}/${module}/dynload/*.h
-       ${src_dir}/${module}/details/*.h
        ${PADDLE_BINARY_DIR}/paddle/phi/api/profiler/*.pb.h
-  DSTS ${dst_dir}/${module} ${dst_dir}/${module}/dynload
-       ${dst_dir}/${module}/details ${dst_dir}/${module})
+  DSTS ${dst_dir}/${module} ${dst_dir}/${module}/dynload ${dst_dir}/${module})

 set(module "string")
 copy(
...
@@ -24,7 +24,7 @@ limitations under the License. */
 #include "paddle/fluid/framework/variable.h"
 #include "paddle/fluid/platform/device_context.h"
 #include "paddle/fluid/platform/macros.h"
-#include "paddle/fluid/platform/transform.h"
+#include "paddle/phi/common/transform.h"
 #include "paddle/phi/kernels/funcs/math_function.h"

 namespace paddle {
...
@@ -16,7 +16,7 @@ limitations under the License. */
 #include "paddle/fluid/framework/convert_utils.h"
 #include "paddle/fluid/framework/selected_rows_utils.h"
-#include "paddle/fluid/platform/transform.h"
+#include "paddle/phi/common/transform.h"

 #if defined(PADDLE_WITH_XPU)
 #include "paddle/fluid/platform/device/device_wrapper.h"
@@ -94,7 +94,7 @@ struct CastDataType {
     auto* out_begin = out_->mutable_data<OutType>(in_.place());

     if (platform::is_cpu_place(in_.place())) {
-      platform::Transform<phi::CPUContext> trans;
+      phi::Transform<phi::CPUContext> trans;
       auto* context = static_cast<const phi::CPUContext*>(ctx_);
       trans(*context,
             in_begin,
@@ -103,7 +103,7 @@ struct CastDataType {
             CastDataTypeFunctor<InType, OutType>());
 #if defined(__NVCC__) || defined(__HIPCC__)
     } else if (platform::is_gpu_place(in_.place())) {
-      platform::Transform<phi::GPUContext> trans;
+      phi::Transform<phi::GPUContext> trans;
       auto* context = static_cast<const phi::GPUContext*>(ctx_);
       trans(*context,
             in_begin,
@@ -114,7 +114,7 @@ struct CastDataType {
 #endif
 #if defined(PADDLE_WITH_IPU)
     } else if (platform::is_ipu_place(in_.place())) {
-      platform::Transform<phi::CPUContext> trans;
+      phi::Transform<phi::CPUContext> trans;
       auto* context = static_cast<const phi::CPUContext*>(ctx_);
       trans(*context,
             in_begin,
...
@@ -17,7 +17,7 @@ limitations under the License. */
 #include "paddle/fluid/framework/data_type.h"
 #include "paddle/fluid/framework/op_registry.h"
 #include "paddle/fluid/framework/phi_utils.h"
-#include "paddle/fluid/platform/transform.h"
+#include "paddle/phi/common/transform.h"
 #include "paddle/phi/kernels/cast_kernel.h"

 namespace paddle {
@@ -44,7 +44,7 @@ struct CastOpFunctor {
     auto numel = in_->numel();
     auto* in_end = in_begin + numel;
     auto* out_begin = out_->mutable_data<OutT>(ctx_.GetPlace());
-    platform::Transform<DeviceContext> trans;
+    phi::Transform<DeviceContext> trans;
     trans(
         ctx_, in_begin, in_end, out_begin, CastOpTransformFunctor<InT, OutT>());
   }
...
@@ -20,7 +20,7 @@ limitations under the License. */
 #include "paddle/fluid/framework/eigen.h"
 #include "paddle/fluid/framework/op_registry.h"
-#include "paddle/fluid/platform/transform.h"
+#include "paddle/phi/common/transform.h"
 #include "paddle/phi/kernels/funcs/blas/blas.h"

 namespace paddle {
@@ -94,7 +94,7 @@ class CenterLossKernel : public framework::OpKernel<T> {
     T *center_out_index;
     T *center_loss_diff_index;
     T *acc_index;
-    platform::Transform<DeviceContext> trans;
+    phi::Transform<DeviceContext> trans;
     for (int i = 0; i < batch_size; ++i) {
       tLabel = label_data[i];
...
@@ -17,7 +17,7 @@ limitations under the License. */
 #include "paddle/fluid/framework/eigen.h"
 #include "paddle/fluid/framework/op_registry.h"
 #include "paddle/fluid/framework/selected_rows_utils.h"
-#include "paddle/fluid/platform/transform.h"
+#include "paddle/phi/common/transform.h"
 #include "paddle/phi/kernels/funcs/selected_rows_functor.h"

 namespace paddle {
...
@@ -17,7 +17,7 @@ limitations under the License. */
 #include <vector>

 #include "paddle/fluid/framework/op_registry.h"
-#include "paddle/fluid/platform/transform.h"
+#include "paddle/phi/common/transform.h"
 #include "paddle/phi/kernels/funcs/math_function.h"

 namespace paddle {
...
@@ -17,7 +17,7 @@ limitations under the License. */
 #include <vector>

 #include "paddle/fluid/framework/op_registry.h"
-#include "paddle/fluid/platform/transform.h"
+#include "paddle/phi/common/transform.h"
 #include "paddle/phi/core/visit_type.h"
 #include "paddle/phi/kernels/funcs/math_function.h"
...
@@ -28,7 +28,8 @@ limitations under the License. */
 #include "paddle/fluid/memory/malloc.h"
 #include "paddle/fluid/operators/elementwise/elementwise_functor.h"
 #include "paddle/fluid/platform/device/gpu/gpu_info.h"
-#include "paddle/fluid/platform/transform.h"
+#include "paddle/phi/api/lib/utils/tensor_utils.h"
+#include "paddle/phi/common/transform.h"
 #include "paddle/phi/kernels/cpu/elementwise.h"
 #include "paddle/phi/kernels/cpu/elementwise_grad.h"
...
@@ -19,7 +19,7 @@ limitations under the License. */
 #include "paddle/fluid/framework/eigen.h"
 #include "paddle/fluid/framework/op_version_registry.h"
-#include "paddle/fluid/platform/transform.h"
+#include "paddle/phi/common/transform.h"
 #include "paddle/phi/kernels/impl/clip_kernel_impl.h"

 namespace paddle {
@@ -98,7 +98,7 @@ struct ClipAndFakeQuantFunctor<phi::CPUContext, T> {
                   phi::DenseTensor *out) {
     T s = scale.data<T>()[0];
     T inv_s = inverse(s);
-    platform::Transform<phi::CPUContext> trans;
+    phi::Transform<phi::CPUContext> trans;
     if (round_type == 0) {
       trans(ctx,
             in.data<T>(),
@@ -130,7 +130,7 @@ struct ClipAndFakeQuantDequantFunctor<phi::CPUContext, T> {
     T s = scale.data<T>()[0];
     T inv_s = inverse(s);
-    platform::Transform<phi::CPUContext> trans;
+    phi::Transform<phi::CPUContext> trans;
     if (round_type == 0) {
       trans(ctx,
             in.data<T>(),
@@ -175,7 +175,7 @@ struct ChannelClipAndFakeQuantFunctor<phi::CPUContext, T> {
     auto *out_data = out->mutable_data<T>(ctx.GetPlace());
     auto in_dims = in.dims();
     const int64_t channel = in_dims[quant_axis];
-    platform::Transform<phi::CPUContext> trans;
+    phi::Transform<phi::CPUContext> trans;
     if (quant_axis == 0) {
       const int64_t channel_size = in.numel() / channel;
       for (int64_t i = 0; i < channel; i++) {
@@ -256,7 +256,7 @@ struct ChannelClipFakeQuantDequantFunctor<phi::CPUContext, T> {
     auto *out_data = out->mutable_data<T>(ctx.GetPlace());
     auto in_dims = in.dims();
     const int64_t channel = in_dims[quant_axis];
-    platform::Transform<phi::CPUContext> trans;
+    phi::Transform<phi::CPUContext> trans;
     if (quant_axis == 0) {
       const int64_t channel_size = in.numel() / channel;
       for (int i = 0; i < channel; i++) {
...
@@ -20,7 +20,7 @@ limitations under the License. */
 #include "paddle/fluid/framework/op_registry.h"
 #include "paddle/fluid/framework/tensor_util.h"
 #include "paddle/fluid/memory/malloc.h"
-#include "paddle/fluid/platform/transform.h"
+#include "paddle/phi/common/transform.h"
 #include "paddle/phi/core/hostdevice.h"
 #include "paddle/phi/kernels/funcs/blas/blas.h"
...
@@ -20,7 +20,7 @@
 #include "paddle/fluid/framework/op_registry.h"
 #include "paddle/fluid/framework/tensor_util.h"
 #include "paddle/fluid/platform/float16.h"
-#include "paddle/fluid/platform/transform.h"
+#include "paddle/phi/common/transform.h"
 #include "paddle/phi/kernels/isfinite_kernel.h"
 #include "paddle/phi/kernels/reduce_all_kernel.h"
 #include "paddle/phi/kernels/reduce_any_kernel.h"
...
@@ -20,7 +20,7 @@ limitations under the License. */
 #include "paddle/fluid/framework/op_registry.h"
 #include "paddle/fluid/operators/activation_op.h"
 #include "paddle/fluid/platform/place.h"
-#include "paddle/fluid/platform/transform.h"
+#include "paddle/phi/common/transform.h"
 #include "paddle/phi/kernels/funcs/blas/blas.h"
 #include "paddle/phi/kernels/funcs/detail/activation_functions.h"
 #include "paddle/phi/kernels/funcs/lstm_compute.h"
@@ -29,7 +29,7 @@ limitations under the License. */
 namespace paddle {
 namespace operators {

-using platform::Transform;
+using phi::Transform;

 template <typename T,
           int MajorType = Eigen::RowMajor,
...
@@ -17,7 +17,7 @@ limitations under the License. */
 #include "paddle/fluid/framework/eigen.h"
 #include "paddle/fluid/framework/op_version_registry.h"
-#include "paddle/fluid/platform/transform.h"
+#include "paddle/phi/common/transform.h"
 #include "paddle/phi/kernels/funcs/blas/blas.h"
 #include "paddle/phi/kernels/impl/clip_kernel_impl.h"
...
@@ -18,8 +18,8 @@ limitations under the License. */
 #include "paddle/fluid/framework/tensor_util.h"
 #include "paddle/fluid/memory/malloc.h"
 #include "paddle/fluid/operators/fake_quantize_op.h"
-#include "paddle/fluid/platform/transform.h"
 #include "paddle/phi/common/data_type.h"
+#include "paddle/phi/common/transform.h"
 #include "paddle/phi/core/ddim.h"
 #include "paddle/phi/core/hostdevice.h"
 #include "paddle/phi/kernels/cast_kernel.h"
...
@@ -248,10 +248,6 @@ if(WITH_GPU)
     device_context_test_cuda_graph
     SRCS device_context_test_cuda_graph.cu
     DEPS device_context gpu_info cuda_graph_with_memory_pool)
-  nv_test(
-    transform_test
-    SRCS transform_test.cu
-    DEPS memory place device_context)
 endif()

 if(WITH_ROCM)
@@ -277,10 +273,6 @@ if(WITH_ROCM)
     device_context_test
     SRCS device_context_test.cu
     DEPS device_context gpu_info)
-  hip_test(
-    transform_test
-    SRCS transform_test.cu
-    DEPS memory place device_context)
 endif()

 cc_library(timer SRCS timer.cc)
...
[deleted file: paddle/fluid/platform/details/cuda_transform_iterator_cast.h]
-/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License. */
-
-#pragma once
-
-#if !defined(__NVCC__) && !defined(__HIPCC__)
-#error device_ptr_cast must be included by .cu file
-#endif
-
-#include <type_traits>  // For std::remove_pointer and std::is_pointer.
-
-#include "thrust/device_ptr.h"
-
-namespace paddle {
-namespace platform {
-namespace details {
-
-// PointerToThrustDevicePtr has two specializations: one casts a (CUDA
-// device) pointer into thrust::device_ptr, the other keeps the rest of
-// the types un-casted.
-template <typename T, bool is_ptr>
-struct PointerToThrustDevicePtr;
-
-template <typename T>
-struct PointerToThrustDevicePtr<T, true> {
-  using ELEM = typename std::remove_pointer<T>::type;
-  using RTYPE = thrust::device_ptr<ELEM>;
-
-  inline thrust::device_ptr<ELEM> operator()(ELEM* ele) const {
-    return thrust::device_pointer_cast(ele);
-  }
-};
-
-template <typename T>
-struct PointerToThrustDevicePtr<T, false> {
-  using RTYPE = T;
-  inline RTYPE operator()(RTYPE it) const { return it; }
-};
-
-// CastToCUDATransformIterator casts a pointer to thrust::device_ptr
-// so it can be used as the iterator of thrust::transform. It
-// doesn't cast other types.
-//
-// We need CastToCUDATransformIterator because we often want to use
-// device memory pointers as transform iterators, e.g., to transform
-// a block of float32 to float16. In this case, we want
-// CastToCUDATransformIterator to cast float16/32 pointers to
-// thrust::device_ptr; otherwise they cannot work as the iterator
-// required by thrust::transform. At the same time, we don't want to
-// cast thrust::device_ptr to thrust::device_ptr repeatedly.
-template <typename T>
-auto CastToCUDATransformIterator(T t) ->
-    typename PointerToThrustDevicePtr<T, std::is_pointer<T>::value>::RTYPE {
-  PointerToThrustDevicePtr<T, std::is_pointer<T>::value> cast;
-  return cast(t);
-}
-
-}  // namespace details
-}  // namespace platform
-}  // namespace paddle
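Taken on its own, the helper above has exactly two behaviors, which the following compile-time sketch illustrates (the check function is hypothetical, and an nvcc/hipcc build is assumed since the header refuses plain C++ translation units):

```cpp
#include <thrust/device_ptr.h>

#include <type_traits>

// Illustrative only: a raw device pointer is wrapped into thrust::device_ptr,
// while any non-pointer iterator type passes through unchanged.
void CheckCastBehavior(float* dev_buf) {
  using paddle::platform::details::CastToCUDATransformIterator;

  auto wrapped = CastToCUDATransformIterator(dev_buf);
  static_assert(
      std::is_same<decltype(wrapped), thrust::device_ptr<float>>::value,
      "a raw pointer becomes thrust::device_ptr");

  auto passthrough = CastToCUDATransformIterator(wrapped);
  static_assert(
      std::is_same<decltype(passthrough), thrust::device_ptr<float>>::value,
      "a thrust::device_ptr is not wrapped again");
}
```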
[moved: paddle/fluid/platform/transform.h → paddle/phi/common/transform.h]
-/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
+/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
@@ -17,20 +17,17 @@ limitations under the License. */
 #include <algorithm>
 #include <type_traits>

-#include "paddle/fluid/platform/device_context.h"
-#include "paddle/fluid/platform/enforce.h"
-#include "paddle/fluid/platform/place.h"
+#include "paddle/phi/backends/all_context.h"
+#include "paddle/phi/core/enforce.h"
 #include "paddle/phi/core/hostdevice.h"

 #if defined(__NVCC__) || defined(__HIPCC__)
 #include <thrust/execution_policy.h>
 #include <thrust/transform.h>
-#include "paddle/fluid/platform/details/cuda_transform_iterator_cast.h"
+#include "thrust/device_ptr.h"
 #endif

-namespace paddle {
-namespace platform {
+namespace phi {

 // Transform applies a unary or a binary functor on each element in a
 // range defined by a pair of iterators.
@@ -40,16 +37,16 @@ namespace platform {
 //
 // NOTE: We need InputIter and OutputIter to be defined as
 // different types, because the InputIter points to op's inputs and
-// OutputIter pints to op's outputs.
+// OutputIter points to op's outputs.
 //
 // NOTE: We don't assume InputIter to be const InputType* and
 // OutputIter to be OutputType*, because we might use an iterator
 // class, e.g., paddle::fluid::operators::RowwiseTransformIterator.
-template <typename DeviceContext>
+template <typename Context>
 struct Transform {
   // The unary version.
   template <typename InputIter, typename OutputIter, typename UnaryOperation>
-  void operator()(const DeviceContext& context,
+  void operator()(const Context& context,
                   InputIter first,
                   InputIter last,
                   OutputIter result,
@@ -60,7 +57,7 @@ struct Transform {
             typename InputIter2,
             typename OutputIter,
             typename BinaryOperation>
-  void operator()(const DeviceContext& context,
+  void operator()(const Context& context,
                   InputIter1 first1,
                   InputIter1 last1,
                   InputIter2 first2,
@@ -97,6 +94,46 @@ struct Transform<phi::CPUContext> {

 #if defined(__NVCC__) || defined(__HIPCC__)
+// PointerToThrustDevicePtr has two specializations: one casts a (CUDA
+// device) pointer into thrust::device_ptr, the other keeps the rest of
+// the types un-casted.
+template <typename T, bool is_ptr>
+struct PointerToThrustDevicePtr;
+
+template <typename T>
+struct PointerToThrustDevicePtr<T, true> {
+  using ELEM = typename std::remove_pointer<T>::type;
+  using RTYPE = thrust::device_ptr<ELEM>;
+
+  inline thrust::device_ptr<ELEM> operator()(ELEM* ele) const {
+    return thrust::device_pointer_cast(ele);
+  }
+};
+
+template <typename T>
+struct PointerToThrustDevicePtr<T, false> {
+  using RTYPE = T;
+  inline RTYPE operator()(RTYPE it) const { return it; }
+};
+
+// CastToCUDATransformIterator casts a pointer to thrust::device_ptr
+// so it can be used as the iterator of thrust::transform. It
+// doesn't cast other types.
+//
+// We need CastToCUDATransformIterator because we often want to use
+// device memory pointers as transform iterators, e.g., to transform
+// a block of float32 to float16. In this case, we want
+// CastToCUDATransformIterator to cast float16/32 pointers to
+// thrust::device_ptr; otherwise they cannot work as the iterator
+// required by thrust::transform. At the same time, we don't want to
+// cast thrust::device_ptr to thrust::device_ptr repeatedly.
+template <typename T>
+auto CastToCUDATransformIterator(T t) ->
+    typename PointerToThrustDevicePtr<T, std::is_pointer<T>::value>::RTYPE {
+  PointerToThrustDevicePtr<T, std::is_pointer<T>::value> cast;
+  return cast(t);
+}
+
 template <>
 struct Transform<phi::GPUContext> {
   template <typename InputIter, typename OutputIter, typename UnaryOperation>
@@ -106,21 +143,21 @@ struct Transform<phi::GPUContext> {
                   OutputIter result,
                   UnaryOperation op) {
     auto place = context.GetPlace();
-    PADDLE_ENFORCE_EQ(is_gpu_place(place),
+    PADDLE_ENFORCE_EQ(place.GetType() == phi::AllocationType::GPU,
                       true,
-                      platform::errors::PreconditionNotMet(
+                      phi::errors::PreconditionNotMet(
                           "The CUDA Transform must be used in GPU place."));
 #ifdef __HIPCC__
     thrust::transform(thrust::hip::par.on(context.stream()),
-                      details::CastToCUDATransformIterator(first),
-                      details::CastToCUDATransformIterator(last),
-                      details::CastToCUDATransformIterator(result),
+                      CastToCUDATransformIterator(first),
+                      CastToCUDATransformIterator(last),
+                      CastToCUDATransformIterator(result),
                       op);
 #else
     thrust::transform(thrust::cuda::par.on(context.stream()),
-                      details::CastToCUDATransformIterator(first),
-                      details::CastToCUDATransformIterator(last),
-                      details::CastToCUDATransformIterator(result),
+                      CastToCUDATransformIterator(first),
+                      CastToCUDATransformIterator(last),
+                      CastToCUDATransformIterator(result),
                       op);
 #endif
   }
@@ -136,28 +173,27 @@ struct Transform<phi::GPUContext> {
                   OutputIter result,
                   BinaryOperation op) {
     auto place = context.GetPlace();
-    PADDLE_ENFORCE_EQ(is_gpu_place(place),
+    PADDLE_ENFORCE_EQ(place.GetType() == phi::AllocationType::GPU,
                       true,
-                      platform::errors::PreconditionNotMet(
+                      phi::errors::PreconditionNotMet(
                           "The CUDA Transform must be used in GPU place."));
 #ifdef __HIPCC__
     thrust::transform(thrust::hip::par.on(context.stream()),
-                      details::CastToCUDATransformIterator(first1),
-                      details::CastToCUDATransformIterator(last1),
-                      details::CastToCUDATransformIterator(first2),
-                      details::CastToCUDATransformIterator(result),
+                      CastToCUDATransformIterator(first1),
+                      CastToCUDATransformIterator(last1),
+                      CastToCUDATransformIterator(first2),
+                      CastToCUDATransformIterator(result),
                       op);
 #else
     thrust::transform(thrust::cuda::par.on(context.stream()),
-                      details::CastToCUDATransformIterator(first1),
-                      details::CastToCUDATransformIterator(last1),
-                      details::CastToCUDATransformIterator(first2),
-                      details::CastToCUDATransformIterator(result),
+                      CastToCUDATransformIterator(first1),
+                      CastToCUDATransformIterator(last1),
+                      CastToCUDATransformIterator(first2),
+                      CastToCUDATransformIterator(result),
                       op);
 #endif
   }
 };
 #endif

-}  // namespace platform
-}  // namespace paddle
+}  // namespace phi
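For raw device pointers, the GPU specialization above boils down to a plain thrust::transform on the context's stream once the pointers have been cast. A hand-written equivalent of the unary path, as a sketch (the Scale10 functor and ScaleOnStream wrapper are illustrative, not part of the header):

```cpp
#include <cuda_runtime.h>
#include <thrust/device_ptr.h>
#include <thrust/execution_policy.h>
#include <thrust/transform.h>

struct Scale10 {
  __device__ float operator()(float x) const { return 10.0f * x; }
};

// What Transform<phi::GPUContext> effectively does for raw device pointers:
// wrap them, then run thrust::transform asynchronously on the given stream.
void ScaleOnStream(cudaStream_t stream, float* dev_buf, int n) {
  thrust::device_ptr<float> first = thrust::device_pointer_cast(dev_buf);
  thrust::transform(
      thrust::cuda::par.on(stream), first, first + n, first, Scale10());
}
```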
@@ -20,7 +20,7 @@ limitations under the License. */
 #include "paddle/phi/kernels/funcs/elementwise_base.h"

 // See Note [ Why still include the fluid headers? ]
-#include "paddle/fluid/platform/transform.h"
+#include "paddle/phi/common/transform.h"

 namespace phi {
@@ -48,7 +48,7 @@ void BitwiseNotKernel(const Context& dev_ctx,
   T* out_data = dev_ctx.template Alloc<T>(out);
   size_t numel = x.numel();
   funcs::BitwiseNotFunctor<T> func;
-  paddle::platform::Transform<Context> trans;
+  phi::Transform<Context> trans;
   trans(dev_ctx, x_data, x_data + numel, out_data, func);
 }
...
@@ -17,7 +17,7 @@
 #include "paddle/phi/backends/cpu/cpu_context.h"

 // See Note [ Why still include the fluid headers? ]
-#include "paddle/fluid/platform/transform.h"
+#include "paddle/phi/common/transform.h"

 namespace phi {
@@ -36,7 +36,7 @@ void CastKernelImpl(const CPUContext& dev_ctx,
   auto* out_begin = dev_ctx.Alloc<OutT>(out);

-  paddle::platform::Transform<CPUContext> trans;
+  phi::Transform<CPUContext> trans;
   trans(dev_ctx,
         in_begin,
         in_end,
...
@@ -14,8 +14,8 @@

 #include "paddle/phi/kernels/hsigmoid_loss_kernel.h"

-#include "paddle/fluid/platform/transform.h"
 #include "paddle/phi/backends/cpu/cpu_context.h"
+#include "paddle/phi/common/transform.h"
 #include "paddle/phi/core/kernel_registry.h"
 #include "paddle/phi/kernels/funcs/eigen/common.h"
 #include "paddle/phi/kernels/funcs/eigen/eigen_function.h"
@@ -82,7 +82,7 @@ void HSigmoidLossKernel(const Context& ctx,
   }
   bit_code->Mul(pre_out, w, x);
   // clip to [-40, 40]
-  paddle::platform::Transform<Context> trans;
+  phi::Transform<Context> trans;
   trans(ctx,
         pre_out_data,
         pre_out_data + pre_out->numel(),
...
@@ -20,7 +20,7 @@
 #include "paddle/phi/kernels/funcs/logical_functor.h"

 // See Note [ Why still include the fluid headers? ]
-#include "paddle/fluid/platform/transform.h"
+#include "paddle/phi/common/transform.h"

 namespace phi {
@@ -47,7 +47,7 @@ void LogicalNotKernel(const Context& dev_ctx,
   auto* out_ptr = dev_ctx.template Alloc<bool>(out);
   funcs::LogicalNotFunctor<T> unary_func;
-  paddle::platform::Transform<Context> trans;
+  phi::Transform<Context> trans;
   trans(dev_ctx, x.data<T>(), x.data<T>() + x.numel(), out_ptr, unary_func);
 }
...
@@ -14,8 +14,8 @@ limitations under the License. */

 #pragma once

-#include "paddle/fluid/platform/transform.h"
 #include "paddle/phi/backends/all_context.h"
+#include "paddle/phi/common/transform.h"
 #include "paddle/phi/core/dense_tensor.h"
 #include "paddle/phi/kernels/empty_kernel.h"
 #include "paddle/phi/kernels/funcs/common_shape.h"
@@ -220,12 +220,12 @@ class TransformFunctor {
   }

   inline void Run() const {
-    paddle::platform::Transform<DeviceContext> trans;
+    phi::Transform<DeviceContext> trans;
     trans(ctx_, x_, x_ + nx_, y_, z_, func_);
   }

   inline void RunRowWise(int n, int pre) const {
-    paddle::platform::Transform<DeviceContext> trans;
+    phi::Transform<DeviceContext> trans;
     if (is_xsize_larger_) {
       trans(ctx_,
             x_,
@@ -244,7 +244,7 @@ class TransformFunctor {
   }

   inline void RunMidWise(int n, int pre, int post) const {
-    paddle::platform::Transform<DeviceContext> trans;
+    phi::Transform<DeviceContext> trans;
     if (is_xsize_larger_) {
       trans(ctx_,
             x_,
...
@@ -15,8 +15,8 @@

 #pragma once

 #include "paddle/fluid/platform/device_context.h"
-#include "paddle/fluid/platform/transform.h"
 #include "paddle/phi/backends/all_context.h"
+#include "paddle/phi/common/transform.h"
 #include "paddle/phi/core/kernel_registry.h"
 #include "paddle/phi/kernels/clip_kernel.h"

 #if defined(__NVCC__) || defined(__HIPCC__)
@@ -59,7 +59,7 @@ void ClipGradKernel(const Context& dev_ctx,
   auto* d_x_data = dev_ctx.template Alloc<T>(x_grad);
   const T* d_out_data = out_grad.data<T>();
   const T* x_data = x.data<T>();
-  paddle::platform::Transform<Context> trans;
+  phi::Transform<Context> trans;
   trans(dev_ctx,
         d_out_data,
         d_out_data + numel,
...
@@ -14,8 +14,8 @@

 #pragma once

-#include "paddle/fluid/platform/transform.h"
 #include "paddle/phi/backends/all_context.h"
+#include "paddle/phi/common/transform.h"
 #include "paddle/phi/core/kernel_registry.h"
 #include "paddle/phi/kernels/clip_kernel.h"

 #if defined(__NVCC__) || defined(__HIPCC__)
@@ -67,7 +67,7 @@ void ClipKernel(const Context& dev_ctx,
     phi::funcs::ElementwiseKernel<T>(dev_ctx, ins, &outs, functor);
 #endif
   } else {
-    paddle::platform::Transform<Context> trans;
+    phi::Transform<Context> trans;
     trans(
         dev_ctx, x_data, x_data + numel, out_data, ClipFunctor<T>(min_, max_));
   }
...
@@ -18,7 +18,7 @@
 #include "paddle/phi/kernels/isfinite_kernel.h"

 // See Note [ Why still include the fluid headers? ]
-#include "paddle/fluid/platform/transform.h"
+#include "paddle/phi/common/transform.h"

 namespace phi {
@@ -28,7 +28,7 @@ namespace phi {
       const Context& ctx, const DenseTensor& x, DenseTensor* out) {        \
     auto* out_ptr = ctx.template Alloc<bool>(out);                         \
     funcs::functor<T> unary_func;                                          \
-    paddle::platform::Transform<Context> trans;                            \
+    phi::Transform<Context> trans;                                         \
     trans(ctx, x.data<T>(), x.data<T>() + x.numel(), out_ptr, unary_func); \
   }
...
@@ -50,7 +50,7 @@ void ClipSparseKernel(const Context& dev_ctx,
   auto* out_tensor = out->mutable_value();
   auto* out_data = out_tensor->data<T>();
   int64_t numel = out_tensor->numel();
-  paddle::platform::Transform<Context> trans;
+  phi::Transform<Context> trans;
   trans(dev_ctx,
         out_data,
         out_data + numel,
...
@@ -23,10 +23,18 @@ if(WITH_GPU)
     phi_test_scalar
     SRCS test_scalar.cu
     DEPS scalar api_scalar)
+  nv_test(
+    transform_test
+    SRCS transform_test.cu
+    DEPS memory place device_context)
 endif()

 if(WITH_ROCM)
   hip_test(
     phi_test_scalar
     SRCS test_scalar.cu
     DEPS scalar api_scalar)
+  hip_test(
+    transform_test
+    SRCS transform_test.cu
+    DEPS memory place device_context)
 endif()
[transform_test.cu, moved from paddle/fluid/platform into the phi tests]
@@ -14,10 +14,11 @@ limitations under the License. */

 #include <gtest/gtest.h>

-#include "paddle/fluid/memory/allocation/allocator_facade.h"
+#include "paddle/phi/common/transform.h"
 #include "paddle/fluid/memory/memcpy.h"
 #include "paddle/fluid/memory/memory.h"
-#include "paddle/fluid/platform/transform.h"
+#include "paddle/phi/backends/all_context.h"
 #include "paddle/phi/core/hostdevice.h"

 template <typename T>
@@ -44,7 +45,7 @@ using paddle::platform::CUDAPlace;
 using phi::CPUContext;
 using phi::GPUContext;

-using paddle::platform::Transform;
+using phi::Transform;

 TEST(Transform, CPUUnary) {
   CPUContext ctx;
@@ -58,19 +59,17 @@ TEST(Transform, CPUUnary) {
 TEST(Transform, GPUUnary) {
   CUDAPlace gpu0(0);
-  phi::GPUContext ctx(gpu0);
-  ctx.SetAllocator(paddle::memory::allocation::AllocatorFacade::Instance()
-                       .GetAllocator(gpu0, ctx.stream())
-                       .get());
-  ctx.PartialInitWithAllocator();
+  phi::DeviceContextPool& pool = phi::DeviceContextPool::Instance();
+  auto* ctx = reinterpret_cast<phi::GPUContext*>(pool.Get(phi::GPUPlace()));
   float cpu_buf[4] = {0.1, 0.2, 0.3, 0.4};
   auto gpu_allocation = Alloc(gpu0, sizeof(float) * 4);
   float* gpu_buf = static_cast<float*>(gpu_allocation->ptr());
-  Copy(gpu0, gpu_buf, CPUPlace(), cpu_buf, sizeof(cpu_buf), ctx.stream());
+  Copy(gpu0, gpu_buf, CPUPlace(), cpu_buf, sizeof(cpu_buf), ctx->stream());
   Transform<phi::GPUContext> trans;
-  trans(ctx, gpu_buf, gpu_buf + 4, gpu_buf, Scale<float>(10));
-  ctx.Wait();
-  Copy(CPUPlace(), cpu_buf, gpu0, gpu_buf, sizeof(cpu_buf), ctx.stream());
+  trans(*ctx, gpu_buf, gpu_buf + 4, gpu_buf, Scale<float>(10));
+  ctx->Wait();
+  Copy(CPUPlace(), cpu_buf, gpu0, gpu_buf, sizeof(cpu_buf), ctx->stream());
   for (int i = 0; i < 4; ++i) {
     ASSERT_NEAR(cpu_buf[i], static_cast<float>(i + 1), 1e-5);
   }
@@ -89,18 +88,16 @@ TEST(Transform, CPUBinary) {
 TEST(Transform, GPUBinary) {
   int buf[4] = {1, 2, 3, 4};
   CUDAPlace gpu0(0);
-  phi::GPUContext ctx(gpu0);
-  ctx.SetAllocator(paddle::memory::allocation::AllocatorFacade::Instance()
-                       .GetAllocator(gpu0, ctx.stream())
-                       .get());
-  ctx.PartialInitWithAllocator();
+  phi::DeviceContextPool& pool = phi::DeviceContextPool::Instance();
+  auto* ctx = reinterpret_cast<phi::GPUContext*>(pool.Get(phi::GPUPlace()));
   auto gpu_allocation = Alloc(gpu0, sizeof(buf));
   int* gpu_buf = static_cast<int*>(gpu_allocation->ptr());
-  Copy(gpu0, gpu_buf, CPUPlace(), buf, sizeof(buf), ctx.stream());
+  Copy(gpu0, gpu_buf, CPUPlace(), buf, sizeof(buf), ctx->stream());
   Transform<phi::GPUContext> trans;
-  trans(ctx, gpu_buf, gpu_buf + 4, gpu_buf, gpu_buf, Multiply<int>());
-  ctx.Wait();
-  Copy(CPUPlace(), buf, gpu0, gpu_buf, sizeof(buf), ctx.stream());
+  trans(*ctx, gpu_buf, gpu_buf + 4, gpu_buf, gpu_buf, Multiply<int>());
+  ctx->Wait();
+  Copy(CPUPlace(), buf, gpu0, gpu_buf, sizeof(buf), ctx->stream());
   for (int i = 0; i < 4; ++i) {
     ASSERT_EQ((i + 1) * (i + 1), buf[i]);
   }
...
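Beyond the namespace swap, the substantive test change is that the tests no longer construct a phi::GPUContext and wire its allocator by hand; they borrow the fully initialized context from the global pool. The retrieval pattern in isolation, as a sketch mirroring the test code above (assumes a CUDA build and an initialized pool):

```cpp
#include "paddle/phi/backends/all_context.h"

// The pool owns and initializes the per-place contexts; the caller gets a
// borrowed pointer and must not delete or re-initialize it.
phi::GPUContext* BorrowGPUContext() {
  phi::DeviceContextPool& pool = phi::DeviceContextPool::Instance();
  return reinterpret_cast<phi::GPUContext*>(pool.Get(phi::GPUPlace()));
}
```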