diff --git a/paddle/fluid/operators/distribution_helper.h b/paddle/fluid/operators/distribution_helper.h deleted file mode 100644 index c13bf687af23470d4595def6fb6fabf7385c999f..0000000000000000000000000000000000000000 --- a/paddle/fluid/operators/distribution_helper.h +++ /dev/null @@ -1,244 +0,0 @@ -/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#pragma once - -#ifdef __NVCC__ -#include -#endif -#ifdef __HIPCC__ -#include -#endif - -#include "paddle/fluid/framework/generator.h" -#include "paddle/fluid/framework/tensor.h" -#include "paddle/fluid/platform/device/gpu/gpu_info.h" -#include "paddle/fluid/platform/device_context.h" -#include "paddle/fluid/platform/for_range.h" -#include "paddle/phi/core/hostdevice.h" - -#if defined(__NVCC__) || defined(__HIPCC__) -#include "paddle/phi/kernels/primitive/kernel_primitives.h" -#endif - -#if !defined(_WIN32) -#define UNLIKELY(condition) __builtin_expect(static_cast(condition), 0) -#else -// there is no equivalent intrinsics in msvc. -#define UNLIKELY(condition) (condition) -#endif - -namespace paddle { -namespace distribution { - -using Tensor = framework::Tensor; - -/********************* Transformation Function **********************/ -template -struct exponential_transform { - explicit exponential_transform(T lambda) : lambda_(lambda) {} - - HOSTDEVICE inline T operator()(T val) const { -#if defined(__NVCC__) || defined(__HIPCC__) - if (std::is_same::value) { - return static_cast(-1.0) / lambda_ * log(val); - } else { - return static_cast(-1.0) / lambda_ * __logf(val); - } -#else - return static_cast(-1.0) / lambda_ * std::log(static_cast(1.0) - val); -#endif - } - - private: - T lambda_; -}; - -template -struct uniform_transform { - explicit uniform_transform(T min, T max) : range_(max - min), min_(min) {} - - HOSTDEVICE inline T operator()(T val) const { - if (UNLIKELY(val == static_cast(1.0))) { - return min_; - } else { - return val * range_ + min_; - } - } - - private: - T range_; - T min_; -}; - -template -struct normal_transform { - explicit normal_transform(T mean, T std) : mean_(mean), std_(std) {} - - HOSTDEVICE inline T operator()(T val) const { return val * std_ + mean_; } - - private: - T mean_; - T std_; -}; - -#if defined(__NVCC__) || defined(__HIPCC__) - -namespace kps = phi::kps; - -/*********************** Distribution Function *************************/ -template -struct uniform_distribution; - -template -struct normal_distribution; - -#if defined(__NVCC__) -template <> -struct uniform_distribution { - __device__ inline float4 operator()(curandStatePhilox4_32_10_t *state) const { - return curand_uniform4(state); - } - static constexpr int kReturnsCount = 4; -}; - -template <> -struct uniform_distribution { - __device__ inline double2 operator()( - curandStatePhilox4_32_10_t *state) const { - return curand_uniform2_double(state); - } - static constexpr int kReturnsCount = 2; -}; - -template <> -struct normal_distribution { - __device__ inline float4 operator()(curandStatePhilox4_32_10_t *state) const { - return curand_normal4(state); - } - static constexpr int kReturnsCount = 4; -}; - -template <> -struct normal_distribution { - __device__ inline double2 operator()( - curandStatePhilox4_32_10_t *state) const { - return curand_normal2_double(state); - } - static constexpr int kReturnsCount = 2; -}; - -#else -template <> -struct uniform_distribution { - __device__ inline float4 operator()( - hiprandStatePhilox4_32_10_t *state) const { - return hiprand_uniform4(state); - } - static constexpr int kReturnsCount = 4; -}; - -template <> -struct uniform_distribution { - __device__ inline double2 operator()( - hiprandStatePhilox4_32_10_t *state) const { - return hiprand_uniform2_double(state); - } - static constexpr int kReturnsCount = 2; -}; - -template <> -struct normal_distribution { - __device__ inline float4 operator()( - hiprandStatePhilox4_32_10_t *state) const { - return hiprand_normal4(state); - } - static constexpr int kReturnsCount = 4; -}; - -template <> -struct normal_distribution { - __device__ inline double2 operator()( - hiprandStatePhilox4_32_10_t *state) const { - return hiprand_normal2_double(state); - } - static constexpr int kReturnsCount = 2; -}; -#endif - -/******** Launch GPU function of distribution and transformation *********/ -template -__global__ void DistributionKernel(size_t size, uint64_t seed, uint64_t offset, - DistOp dist, TransformOp trans, T *out_data, - size_t stride) { - size_t idx = static_cast(BLOCK_ID_X * BLOCK_NUM_X); - static constexpr int kCount = DistOp::kReturnsCount; -#if defined(__NVCC__) - curandStatePhilox4_32_10_t state; - curand_init(seed, idx + THREAD_ID_X, offset, &state); - using SType = curandStatePhilox4_32_10_t; -#else - hiprandStatePhilox4_32_10_t state; - hiprand_init(seed, idx + THREAD_ID_X, offset, &state); - using SType = hiprandStatePhilox4_32_10_t; -#endif - size_t total_thread = GRID_NUM_X * BLOCK_NUM_X; - T args[kCount]; - T result[kCount]; - for (size_t i = idx; i < size; i += total_thread * kCount) { - kps::ElementwiseRandom(&args[0], dist, &state); - kps::ElementwiseUnary(&result[0], &args[0], - trans); - kps::WriteData(out_data + i, &result[0], size - i, - 1, stride, 1); - __syncthreads(); - } -} - -template -void distribution_and_transform(const platform::CUDADeviceContext &dev_ctx, - Tensor *out, DistOp dist, TransformOp trans) { - T *out_data = out->mutable_data(dev_ctx.GetPlace()); - auto size = out->numel(); - - int64_t device_id = dev_ctx.GetPlace().GetDeviceId(); - auto gen_cuda = framework::GetDefaultCUDAGenerator(device_id); - - size_t block_size = 256; - size_t expect_grid_size = (size + block_size - 1) / block_size; - const auto &prop = platform::GetDeviceProperties(device_id); - size_t max_grid_size = (prop.maxThreadsPerMultiProcessor / block_size) * - prop.multiProcessorCount; - size_t grid_size = - expect_grid_size > max_grid_size ? max_grid_size : expect_grid_size; - - size_t total_thread = block_size * grid_size; - size_t curand4_loop_times = - (size + 4 * total_thread - 1) / (4 * total_thread); - // 'increment' shoulde be multiple of 4 - uint64_t increment = curand4_loop_times * 4; - - auto seed_offset = gen_cuda->IncrementOffset(increment); - uint64_t seed = seed_offset.first; - uint64_t offset = seed_offset.second; - - DistributionKernel< - T, DistOp, TransformOp><<>>( - size, seed, offset, dist, trans, out_data, total_thread); -} - -#endif - -} // namespace distribution -} // namespace paddle diff --git a/paddle/fluid/operators/dropout_impl.cu.h b/paddle/fluid/operators/dropout_impl.cu.h index cdcf683fb92c5a5ef56f61da15e5979fd1364945..dcdab033e8f8014214900727d53f329e5a7b4ab4 100644 --- a/paddle/fluid/operators/dropout_impl.cu.h +++ b/paddle/fluid/operators/dropout_impl.cu.h @@ -34,8 +34,8 @@ limitations under the License. */ #include "paddle/fluid/operators/dropout_impl_util.h" #include "paddle/fluid/operators/dropout_op.h" #include "paddle/fluid/operators/elementwise/elementwise_op_impl.cu.h" -#include "paddle/fluid/platform/aligned_vector.h" #include "paddle/fluid/platform/device/gpu/gpu_launch_config.h" +#include "paddle/phi/kernels/funcs/aligned_vector.h" #include "paddle/phi/kernels/funcs/functors.h" namespace paddle { @@ -86,8 +86,8 @@ __global__ void VectorizedRandomGenerator(const size_t n, uint64_t seed, bool is_upscale_in_train, uint64_t increment) { using MT = typename details::MPTypeTrait::Type; - using LoadT = platform::AlignedVector; - using MaskLoadT = platform::AlignedVector; + using LoadT = phi::AlignedVector; + using MaskLoadT = phi::AlignedVector; #ifdef PADDLE_WITH_HIP int64_t idx = hipBlockDim_x * hipBlockIdx_x + hipThreadIdx_x; @@ -102,7 +102,7 @@ __global__ void VectorizedRandomGenerator(const size_t n, uint64_t seed, MT factor = static_cast(1.0f / (1.0f - dropout_prob)); for (int i = idx * VecSize; i < n; i += blockDim.x * gridDim.x * VecSize) { LoadT src_val; - platform::Load(&src[i], &src_val); + phi::Load(&src[i], &src_val); #ifdef PADDLE_WITH_HIP float4 rand = hiprand_uniform4(&state); @@ -126,8 +126,8 @@ __global__ void VectorizedRandomGenerator(const size_t n, uint64_t seed, } } - platform::Store(dst_val, &dst[i]); - platform::Store(mask_val, &mask[i]); + phi::Store(dst_val, &dst[i]); + phi::Store(mask_val, &mask[i]); } } @@ -153,16 +153,16 @@ __global__ void DropoutGradCUDAKernel( const typename details::MPTypeTrait::Type factor, const int64_t size, T* dx) { using MT = typename details::MPTypeTrait::Type; - using LoadT = platform::AlignedVector; - using MaskLoadT = platform::AlignedVector; + using LoadT = phi::AlignedVector; + using MaskLoadT = phi::AlignedVector; int64_t idx = blockDim.x * blockIdx.x + threadIdx.x; for (int i = idx * VecSize; i < size; i += blockDim.x * gridDim.x * VecSize) { LoadT dout_val; - platform::Load(&dout[i], &dout_val); + phi::Load(&dout[i], &dout_val); MaskLoadT mask_val; - platform::Load(&mask[i], &mask_val); + phi::Load(&mask[i], &mask_val); LoadT dx_val; @@ -172,7 +172,7 @@ __global__ void DropoutGradCUDAKernel( static_cast(mask_val[j]) * factor); } - platform::Store(dx_val, &dx[i]); + phi::Store(dx_val, &dx[i]); } } @@ -219,7 +219,7 @@ void DropoutFwGPUKernelDriver(const platform::CUDADeviceContext& dev_ctx, uint64_t increment; // VectorizedRandomGenerator use curand_uniform4, so we only support // vec_size is 4; - int vec_size = (platform::GetVectorizedSize(x_data) == 4) ? 4 : 1; + int vec_size = (phi::GetVectorizedSize(x_data) == 4) ? 4 : 1; auto gpu_config = GetGpuLaunchConfig1D(dev_ctx, x_numel, vec_size); auto offset = ((x_numel - 1) / (gpu_config.GetThreadNum() * vec_size) + 1) * vec_size; diff --git a/paddle/fluid/operators/exponential_op.cc b/paddle/fluid/operators/exponential_op.cc index ee456dcdafbc51d547e7beacc4e4e79f98738b88..1a48a6767852e138e7725a68ca4ffc56de8234be 100644 --- a/paddle/fluid/operators/exponential_op.cc +++ b/paddle/fluid/operators/exponential_op.cc @@ -76,7 +76,7 @@ class ExponentialKernel auto engine = gen->GetCPUEngine(); std::uniform_real_distribution uniform(0.0, 1.0); - distribution::exponential_transform trans(lambda); + phi::funcs::exponential_transform trans(lambda); for (int64_t i = 0; i < size; ++i) { out_data[i] = trans(uniform(*engine)); } diff --git a/paddle/fluid/operators/exponential_op.cu b/paddle/fluid/operators/exponential_op.cu index 8b989501e4f4248b0c2e3b23e1e75a4865b08588..d5abbf9a26afe6bcbbd8549f59d632fc4e53fec2 100644 --- a/paddle/fluid/operators/exponential_op.cu +++ b/paddle/fluid/operators/exponential_op.cu @@ -26,9 +26,9 @@ class ExponentialKernel auto& dev_cxt = ctx.template device_context(); T lambda = static_cast(ctx.Attr("lambda")); - distribution::uniform_distribution dist; - distribution::exponential_transform trans(lambda); - distribution::distribution_and_transform(dev_cxt, out, dist, trans); + phi::funcs::uniform_distribution dist; + phi::funcs::exponential_transform trans(lambda); + phi::funcs::distribution_and_transform(dev_cxt, out, dist, trans); } }; diff --git a/paddle/fluid/operators/exponential_op.h b/paddle/fluid/operators/exponential_op.h index fbcabc594db0814da1ec50934a0f02514dc208be..7ded174a9f47ede48a49b19b25539867ce344fb0 100644 --- a/paddle/fluid/operators/exponential_op.h +++ b/paddle/fluid/operators/exponential_op.h @@ -17,7 +17,7 @@ #include "paddle/fluid/framework/generator.h" #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/operator.h" -#include "paddle/fluid/operators/distribution_helper.h" +#include "paddle/phi/kernels/funcs/distribution_helper.h" #include "paddle/phi/kernels/funcs/math_function.h" namespace paddle { diff --git a/paddle/fluid/operators/fused/attn_bias_add.cu.h b/paddle/fluid/operators/fused/attn_bias_add.cu.h index 51cf3bce1cec595ac168e5b2d56c672ec96c27e0..3a2de0c4a093514a1c40321ab7dad61011709204 100644 --- a/paddle/fluid/operators/fused/attn_bias_add.cu.h +++ b/paddle/fluid/operators/fused/attn_bias_add.cu.h @@ -89,9 +89,9 @@ __global__ void BroadcastKernelBinary( template void LaunchBiasAddFwKernel(const platform::CUDADeviceContext& ctx, int m, int n, const T* in0, const T* in1, T* out) { - int in_vec_size = std::min(platform::GetVectorizedSize(in0), - platform::GetVectorizedSize(in1)); - int out_vec_size = std::min(4, platform::GetVectorizedSize(out)); + int in_vec_size = + std::min(phi::GetVectorizedSize(in0), phi::GetVectorizedSize(in1)); + int out_vec_size = std::min(4, phi::GetVectorizedSize(out)); int vec_size = std::min(out_vec_size, in_vec_size); int numel = m * n; diff --git a/paddle/fluid/operators/fused/fused_dropout_act_bias.h b/paddle/fluid/operators/fused/fused_dropout_act_bias.h index 994601a2f0608b4fc04966c7549c421f395f3ec7..9f5a1bad047b44b715e11e74d92fdca1982c96f8 100755 --- a/paddle/fluid/operators/fused/fused_dropout_act_bias.h +++ b/paddle/fluid/operators/fused/fused_dropout_act_bias.h @@ -130,17 +130,17 @@ __global__ void FusedDropoutActGrad(Functor act_grad, const T *dout, const T factor, const int64_t size, T *dx) { int64_t idx = blockDim.x * blockIdx.x + threadIdx.x; - using LoadT = platform::AlignedVector; - using StoreT = platform::AlignedVector; - using MaskLoadT = platform::AlignedVector; + using LoadT = phi::AlignedVector; + using StoreT = phi::AlignedVector; + using MaskLoadT = phi::AlignedVector; for (int i = idx * VecSize; i < size; i += blockDim.x * gridDim.x * VecSize) { LoadT dout_vec; LoadT src_vec; MaskLoadT mask_vec; - platform::Load(&dout[i], &dout_vec); - platform::Load(&mask[i], &mask_vec); - platform::Load(&src[i], &src_vec); + phi::Load(&dout[i], &dout_vec); + phi::Load(&mask[i], &mask_vec); + phi::Load(&src[i], &src_vec); StoreT dx_vec; #pragma unroll @@ -148,7 +148,7 @@ __global__ void FusedDropoutActGrad(Functor act_grad, const T *dout, T tmp = dout_vec[ii] * static_cast(mask_vec[ii]) * factor; dx_vec[ii] = tmp * act_grad.UseOut(src_vec[ii]); } - platform::Store(dx_vec, &dx[i]); + phi::Store(dx_vec, &dx[i]); } } @@ -167,9 +167,9 @@ __global__ void FusedDropoutActBiasGrad(Functor act_grad, const T *dout, T *dx, T *dbias) { int64_t col_id = blockIdx.x * blockDim.x + threadIdx.x; - using LoadT = platform::AlignedVector; - using StoreT = platform::AlignedVector; - using MaskLoadT = platform::AlignedVector; + using LoadT = phi::AlignedVector; + using StoreT = phi::AlignedVector; + using MaskLoadT = phi::AlignedVector; T tmp_sum[VecSize] = {static_cast(0)}; // calculate the dx and temporary sum if (col_id * VecSize < cols) { @@ -180,10 +180,10 @@ __global__ void FusedDropoutActBiasGrad(Functor act_grad, const T *dout, LoadT bias_vec; MaskLoadT mask_vec; - platform::Load(&dout[index], &dout_vec); - platform::Load(&src[index], &src_vec); - platform::Load(&mask[index], &mask_vec); - platform::Load(&bias[col_id * VecSize], &bias_vec); + phi::Load(&dout[index], &dout_vec); + phi::Load(&src[index], &src_vec); + phi::Load(&mask[index], &mask_vec); + phi::Load(&bias[col_id * VecSize], &bias_vec); StoreT dx_vec; #pragma unroll @@ -194,7 +194,7 @@ __global__ void FusedDropoutActBiasGrad(Functor act_grad, const T *dout, dx_vec[i] = val; tmp_sum[i] += val; } - platform::Store(dx_vec, &dx[index]); + phi::Store(dx_vec, &dx[index]); } } diff --git a/paddle/fluid/operators/fused/fused_dropout_common.h b/paddle/fluid/operators/fused/fused_dropout_common.h index f79277e4e8f0d22cedafc9f7b40b56ecd2d6a817..6bf3a7114f4ced3c7c6ecd1f1afeca60ff66528f 100644 --- a/paddle/fluid/operators/fused/fused_dropout_common.h +++ b/paddle/fluid/operators/fused/fused_dropout_common.h @@ -21,11 +21,11 @@ limitations under the License. */ #include "paddle/fluid/memory/memory.h" #include "paddle/fluid/operators/amp/fp16_type_traits.h" #include "paddle/fluid/operators/layer_norm_kernel.cu.h" -#include "paddle/fluid/platform/aligned_vector.h" #include "paddle/fluid/platform/device/gpu/gpu_device_function.h" #include "paddle/fluid/platform/device/gpu/gpu_launch_config.h" #include "paddle/fluid/platform/device_context.h" #include "paddle/fluid/platform/float16.h" +#include "paddle/phi/kernels/funcs/aligned_vector.h" #include "paddle/phi/kernels/funcs/functors.h" namespace paddle { diff --git a/paddle/fluid/operators/fused/fused_layernorm_residual_dropout_bias.h b/paddle/fluid/operators/fused/fused_layernorm_residual_dropout_bias.h index ceba3accca7727b5e4f22951d87f9e91034e3403..d53a24a57e3cc1ede127f497a9be9e3b5fa1ab0b 100644 --- a/paddle/fluid/operators/fused/fused_layernorm_residual_dropout_bias.h +++ b/paddle/fluid/operators/fused/fused_layernorm_residual_dropout_bias.h @@ -42,12 +42,12 @@ __device__ void CalcLayernormY( const LayerNormScaleBiasT *bias, const T *x, T *y, const int row_id, const int col_id, const int cols, const LayerNormParamType mean_val, const LayerNormParamType invvar) { - using LoadT = platform::AlignedVector; - using StoreT = platform::AlignedVector; - using LoadU = platform::AlignedVector; + using LoadT = phi::AlignedVector; + using StoreT = phi::AlignedVector; + using LoadU = phi::AlignedVector; using LoadScaleOrBias = - platform::AlignedVector, - VecSize>; + phi::AlignedVector, + VecSize>; for (int i = col_id * VecSize; i < cols; i += blockDim.x * VecSize) { LoadScaleOrBias scale_vec; LoadScaleOrBias bias_vec; @@ -60,15 +60,15 @@ __device__ void CalcLayernormY( static_cast>(0); } // vectorize load data from global - platform::Load(&x[row_id * cols + i], &x_vec); + phi::Load(&x[row_id * cols + i], &x_vec); if (scale != nullptr) { - platform::Load, - VecSize>(&scale[i], &scale_vec); + phi::Load, VecSize>( + &scale[i], &scale_vec); } if (bias != nullptr) { - platform::Load, - VecSize>(&bias[i], &bias_vec); + phi::Load, VecSize>( + &bias[i], &bias_vec); } StoreT y_vec; @@ -78,7 +78,7 @@ __device__ void CalcLayernormY( (static_cast(x_vec[ii]) - mean_val) * invvar + static_cast(bias_vec[ii])); } - platform::Store(y_vec, &y[row_id * cols + i]); + phi::Store(y_vec, &y[row_id * cols + i]); } } @@ -190,9 +190,9 @@ __global__ __launch_bounds__(THREADS_PER_CTA) void fused_ln_fwd_1024_kernel( const ScaleT *__restrict__ beta_ptr, MaskType *__restrict__ mask_out_ptr, U *__restrict__ mean_out_ptr, U *__restrict__ var_out_ptr, T *__restrict__ residual_out_ptr, T *__restrict__ y_ptr) { - using Vec = platform::AlignedVector; - using Vec_scale = platform::AlignedVector; - using MaskStoreT = platform::AlignedVector; + using Vec = phi::AlignedVector; + using Vec_scale = phi::AlignedVector; + using MaskStoreT = phi::AlignedVector; const int tidx = threadIdx.x; const int bidx = blockIdx.x; @@ -214,8 +214,8 @@ __global__ __launch_bounds__(THREADS_PER_CTA) void fused_ln_fwd_1024_kernel( Vec_scale beta[LDGS]; #pragma unroll for (int it = 0, col = c; it < LDGS; it++) { - platform::Load(gamma_ptr + col * VecSize, &gamma[it]); - platform::Load(beta_ptr + col * VecSize, &beta[it]); + phi::Load(gamma_ptr + col * VecSize, &gamma[it]); + phi::Load(beta_ptr + col * VecSize, &beta[it]); col += THREADS_PER_ROW; } @@ -225,10 +225,9 @@ __global__ __launch_bounds__(THREADS_PER_CTA) void fused_ln_fwd_1024_kernel( Vec residual[LDGS]; #pragma unroll for (int it = 0, col = c; it < LDGS; it++) { - platform::Load(x_ptr + row * LN_NUM_COLS + col * VecSize, - &x[it]); - platform::Load( - residual_ptr + row * LN_NUM_COLS + col * VecSize, &residual[it]); + phi::Load(x_ptr + row * LN_NUM_COLS + col * VecSize, &x[it]); + phi::Load(residual_ptr + row * LN_NUM_COLS + col * VecSize, + &residual[it]); col += THREADS_PER_ROW; } @@ -270,9 +269,9 @@ __global__ __launch_bounds__(THREADS_PER_CTA) void fused_ln_fwd_1024_kernel( // store dropout_residual_out and mask_out #pragma unroll for (int it = 0, col = c; it < LDGS; it++) { - platform::Store( + phi::Store( x[it], residual_out_ptr + row * LN_NUM_COLS + col * VecSize); - platform::Store( + phi::Store( mask_vec[it], mask_out_ptr + row * LN_NUM_COLS + col * VecSize); col += THREADS_PER_ROW; } @@ -333,8 +332,7 @@ __global__ __launch_bounds__(THREADS_PER_CTA) void fused_ln_fwd_1024_kernel( #pragma unroll for (int it = 0, col = c; it < LDGS; it++) { - platform::Store(x[it], - y_ptr + row * LN_NUM_COLS + col * VecSize); + phi::Store(x[it], y_ptr + row * LN_NUM_COLS + col * VecSize); col += THREADS_PER_ROW; } } diff --git a/paddle/fluid/operators/fused/fused_residual_dropout_bias.h b/paddle/fluid/operators/fused/fused_residual_dropout_bias.h index 1b135ad6098e58f457f5d21e73ac6d1a6a7c4074..1d3085a013f81ee9dca21468476df8f621bb26c2 100644 --- a/paddle/fluid/operators/fused/fused_residual_dropout_bias.h +++ b/paddle/fluid/operators/fused/fused_residual_dropout_bias.h @@ -32,9 +32,9 @@ __forceinline__ __device__ void FusedResidualDropoutBiasOneThread( const T *__restrict__ bias, T *dst, MaskType *mask, const bool is_test, typename details::MPTypeTrait::Type *mean_val, typename details::MPTypeTrait::Type *var_val, Functor act_func) { - using LoadT = platform::AlignedVector; - using StoreT = platform::AlignedVector; - using MaskStoreT = platform::AlignedVector; + using LoadT = phi::AlignedVector; + using StoreT = phi::AlignedVector; + using MaskStoreT = phi::AlignedVector; using U = typename details::MPTypeTrait::Type; LoadT src_vec; @@ -46,14 +46,13 @@ __forceinline__ __device__ void FusedResidualDropoutBiasOneThread( residual_vec[ii] = static_cast(0); } // vectorize load data from global - platform::Load(&src[row_id * cols + col_id], &src_vec); + phi::Load(&src[row_id * cols + col_id], &src_vec); if (residual) { - platform::Load(&residual[row_id * cols + col_id], - &residual_vec); + phi::Load(&residual[row_id * cols + col_id], &residual_vec); } if (bias) { - platform::Load(&bias[col_id], &bias_vec); + phi::Load(&bias[col_id], &bias_vec); } MaskStoreT mask_vec; @@ -89,9 +88,9 @@ __forceinline__ __device__ void FusedResidualDropoutBiasOneThread( } // store result to global - platform::Store(dest_vec, &dst[row_id * cols + col_id]); + phi::Store(dest_vec, &dst[row_id * cols + col_id]); if (!is_test) { - platform::Store(mask_vec, &mask[row_id * cols + col_id]); + phi::Store(mask_vec, &mask[row_id * cols + col_id]); } } @@ -176,21 +175,21 @@ __global__ void FusedResidualDropoutGrad(const T *dout, const MaskType *mask, T *dx) { int64_t idx = blockDim.x * blockIdx.x + threadIdx.x; - using LoadT = platform::AlignedVector; - using StoreT = platform::AlignedVector; - using MaskLoadT = platform::AlignedVector; + using LoadT = phi::AlignedVector; + using StoreT = phi::AlignedVector; + using MaskLoadT = phi::AlignedVector; for (int i = idx * VecSize; i < size; i += blockDim.x * gridDim.x * VecSize) { LoadT dout_vec; MaskLoadT mask_vec; - platform::Load(&dout[i], &dout_vec); - platform::Load(&mask[i], &mask_vec); + phi::Load(&dout[i], &dout_vec); + phi::Load(&mask[i], &mask_vec); StoreT dx_vec; #pragma unroll for (int ii = 0; ii < VecSize; ii++) { dx_vec[ii] = dout_vec[ii] * static_cast(mask_vec[ii]) * factor; } - platform::Store(dx_vec, &dx[i]); + phi::Store(dx_vec, &dx[i]); } } @@ -209,9 +208,9 @@ __global__ void FusedResidualDropoutBiasGrad(const T *dout, T *dbias) { int64_t col_id = blockIdx.x * blockDim.x + threadIdx.x; - using LoadT = platform::AlignedVector; - using StoreT = platform::AlignedVector; - using MaskLoadT = platform::AlignedVector; + using LoadT = phi::AlignedVector; + using StoreT = phi::AlignedVector; + using MaskLoadT = phi::AlignedVector; T tmp_sum[VecSize] = {static_cast(0)}; // calculate the dx and temporary sum @@ -221,8 +220,8 @@ __global__ void FusedResidualDropoutBiasGrad(const T *dout, LoadT out_vec; MaskLoadT mask_vec; StoreT dx_vec; - platform::Load(&dout[index], &out_vec); - platform::Load(&mask[index], &mask_vec); + phi::Load(&dout[index], &out_vec); + phi::Load(&mask[index], &mask_vec); #pragma unroll for (int i = 0; i < VecSize; i++) { @@ -230,7 +229,7 @@ __global__ void FusedResidualDropoutBiasGrad(const T *dout, tmp_sum[i] += out_vec[i]; } - platform::Store(dx_vec, &dx[index]); + phi::Store(dx_vec, &dx[index]); } } diff --git a/paddle/fluid/operators/gaussian_random_op.cu b/paddle/fluid/operators/gaussian_random_op.cu index d419bd70e67db27b49d9abccd3dba3227692337a..717ec774414bf892218b6e6df73dbcd57ca3066d 100644 --- a/paddle/fluid/operators/gaussian_random_op.cu +++ b/paddle/fluid/operators/gaussian_random_op.cu @@ -19,9 +19,10 @@ limitations under the License. */ #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/operator.h" #include "paddle/fluid/operators/amp/fp16_type_traits.h" -#include "paddle/fluid/operators/distribution_helper.h" #include "paddle/fluid/operators/fill_constant_op.h" -#include "paddle/fluid/operators/index_impl.cu.h" + +#include "paddle/phi/kernels/funcs/distribution_helper.h" +#include "paddle/phi/kernels/funcs/index_impl.cu.h" DECLARE_bool(use_curand); @@ -79,10 +80,10 @@ class GPUGaussianRandomBatchSizeLikeKernel : public framework::OpKernel { int64_t gen_offset = size * seed_offset.second; auto func = GaussianGenerator(mean, std, seed_offset.first, seed_offset.second); - IndexKernel>(dev_cxt, tensor, func); + phi::IndexKernel>(dev_cxt, tensor, func); } else { auto func = GaussianGenerator(mean, std, seed); - IndexKernel>(dev_cxt, tensor, func); + phi::IndexKernel>(dev_cxt, tensor, func); } } }; diff --git a/paddle/fluid/operators/gelu_op.cu b/paddle/fluid/operators/gelu_op.cu index 6b778eee4345170a0288bc5741c6c1078615022f..ef836ab72f001a540e081d7e9975ca5ee28758be 100644 --- a/paddle/fluid/operators/gelu_op.cu +++ b/paddle/fluid/operators/gelu_op.cu @@ -58,7 +58,7 @@ static __global__ void FP16FastGeluFwdCUDAKernel(const __half* x, __half* y, static_cast(threadIdx.x + blockIdx.x * blockDim.x) * VecSize; size_t stride = static_cast(blockDim.x * gridDim.x) * VecSize; for (; offset < n; offset += stride) { - using ArrT = platform::AlignedVector<__half, VecSize>; + using ArrT = phi::AlignedVector<__half, VecSize>; ArrT in_arr = *reinterpret_cast(x + offset); #pragma unroll for (int i = 0; i < VecSize; ++i) { @@ -77,7 +77,7 @@ static __global__ void FP16FastGeluBwdCUDAKernel(const __half* x, static_cast(threadIdx.x + blockIdx.x * blockDim.x) * VecSize; size_t stride = static_cast(blockDim.x * gridDim.x) * VecSize; for (; offset < n; offset += stride) { - using ArrT = platform::AlignedVector<__half, VecSize>; + using ArrT = phi::AlignedVector<__half, VecSize>; ArrT x_in_arr = *reinterpret_cast(x + offset); ArrT y_g_in_arr = *reinterpret_cast(y_g + offset); #pragma unroll @@ -103,7 +103,7 @@ static bool TryLaunchFP16FastGeluFwdVectorizeCUDAKernel( #define PD_LAUNCH_FP16_FAST_GELU_FWD_KERNEL(__vec_size, __use_fast_math) \ do { \ constexpr auto kAlignment = \ - alignof(platform::AlignedVector<__half, __vec_size>); \ + alignof(phi::AlignedVector<__half, __vec_size>); \ if (n % __vec_size == 0 && is_aligned(x, kAlignment) && \ is_aligned(y, kAlignment)) { \ size_t thread = std::min(512, dev_ctx.GetMaxThreadsPerBlock()); \ @@ -138,7 +138,7 @@ static bool TryLaunchFP16FastGeluBwdVectorizeCUDAKernel( #define PD_LAUNCH_FP16_FAST_GELU_BWD_KERNEL(__vec_size, __use_fast_math) \ do { \ constexpr auto kAlignment = \ - alignof(platform::AlignedVector<__half, __vec_size>); \ + alignof(phi::AlignedVector<__half, __vec_size>); \ if (n % __vec_size == 0 && is_aligned(x, kAlignment) && \ is_aligned(x, kAlignment) && is_aligned(y_g, kAlignment) && \ is_aligned(x_g, kAlignment)) { \ diff --git a/paddle/fluid/operators/index_impl.cu.h b/paddle/fluid/operators/index_impl.cu.h index 2e3e6569ef5a88f8dfcb6646974b70bcc6c0c95f..bb26e2f445e7034b8f982594216eacfd3007a24f 100644 --- a/paddle/fluid/operators/index_impl.cu.h +++ b/paddle/fluid/operators/index_impl.cu.h @@ -19,11 +19,11 @@ limitations under the License. */ #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/operator.h" #include "paddle/fluid/operators/amp/fp16_type_traits.h" -#include "paddle/fluid/operators/distribution_helper.h" #include "paddle/fluid/operators/fill_constant_op.h" -#include "paddle/fluid/platform/aligned_vector.h" #include "paddle/phi/backends/gpu/gpu_launch_config.h" #include "paddle/phi/core/hostdevice.h" +#include "paddle/phi/kernels/funcs/aligned_vector.h" +#include "paddle/phi/kernels/funcs/distribution_helper.h" #include "paddle/phi/kernels/primitive/kernel_primitives.h" namespace paddle { @@ -58,7 +58,7 @@ void IndexKernel(const KPDevice &dev_ctx, Tensor *out, Functor func) { int numel = out->numel(); T *out_data = out->mutable_data(dev_ctx.GetPlace()); if (numel <= 0) return; - int vec_size = paddle::platform::GetVectorizedSize(out_data); + int vec_size = phi::GetVectorizedSize(out_data); #ifdef PADDLE_WITH_XPU_KP int block = 64; int grid = 8; diff --git a/paddle/fluid/operators/layer_norm_kernel.cu.h b/paddle/fluid/operators/layer_norm_kernel.cu.h index 62c21dd2eee401e5f8a526870015c18cf13ee873..412ae3c49b5f3cc9fc2422aa220af324e6d99b69 100644 --- a/paddle/fluid/operators/layer_norm_kernel.cu.h +++ b/paddle/fluid/operators/layer_norm_kernel.cu.h @@ -22,10 +22,10 @@ limitations under the License. */ namespace cub = hipcub; #endif -#include "paddle/fluid/platform/aligned_vector.h" #include "paddle/fluid/platform/device/gpu/gpu_device_function.h" #include "paddle/fluid/platform/device/gpu/gpu_dnn.h" #include "paddle/phi/core/ddim.h" +#include "paddle/phi/kernels/funcs/aligned_vector.h" namespace paddle { namespace operators { @@ -186,8 +186,8 @@ __global__ __launch_bounds__(THREADS_PER_CTA) void ln_fwd_1024_kernel( const ScaleT *__restrict__ gamma_ptr, const ScaleT *__restrict__ beta_ptr, U *__restrict__ mean_out_ptr, U *__restrict__ var_out_ptr, T *__restrict__ y_ptr) { - using Vec = platform::AlignedVector; - using Vec_scale = platform::AlignedVector; + using Vec = phi::AlignedVector; + using Vec_scale = phi::AlignedVector; const int tidx = threadIdx.x; const int bidx = blockIdx.x; @@ -203,8 +203,8 @@ __global__ __launch_bounds__(THREADS_PER_CTA) void ln_fwd_1024_kernel( Vec_scale beta[LDGS]; #pragma unroll for (int it = 0, col = c; it < LDGS; it++) { - platform::Load(gamma_ptr + col * VecSize, &gamma[it]); - platform::Load(beta_ptr + col * VecSize, &beta[it]); + phi::Load(gamma_ptr + col * VecSize, &gamma[it]); + phi::Load(beta_ptr + col * VecSize, &beta[it]); col += THREADS_PER_ROW; } @@ -213,8 +213,7 @@ __global__ __launch_bounds__(THREADS_PER_CTA) void ln_fwd_1024_kernel( Vec x[LDGS]; #pragma unroll for (int it = 0, col = c; it < LDGS; it++) { - platform::Load(x_ptr + row * LN_NUM_COLS + col * VecSize, - &x[it]); + phi::Load(x_ptr + row * LN_NUM_COLS + col * VecSize, &x[it]); col += THREADS_PER_ROW; } U xf[LDGS * VecSize]; @@ -276,8 +275,7 @@ __global__ __launch_bounds__(THREADS_PER_CTA) void ln_fwd_1024_kernel( #pragma unroll for (int it = 0, col = c; it < LDGS; it++) { - platform::Store(x[it], - y_ptr + row * LN_NUM_COLS + col * VecSize); + phi::Store(x[it], y_ptr + row * LN_NUM_COLS + col * VecSize); col += THREADS_PER_ROW; } } @@ -401,9 +399,9 @@ __global__ __launch_bounds__(THREADS_PER_CTA) void fused_ln_bwd_1024_kernel( U *__restrict__ dgamma_temp_ptr, U *__restrict__ dbeta_temp_ptr, T *__restrict__ dx_ptr, const MaskType *mask_ptr = nullptr, T factor = static_cast(0), T *d_dropout_src_ptr = nullptr) { - using Vec = platform::AlignedVector; - using Vec_scale = platform::AlignedVector; - using MaskLoadT = platform::AlignedVector; + using Vec = phi::AlignedVector; + using Vec_scale = phi::AlignedVector; + using MaskLoadT = phi::AlignedVector; const int tidx = threadIdx.x; const int bidx = blockIdx.x; @@ -439,7 +437,7 @@ __global__ __launch_bounds__(THREADS_PER_CTA) void fused_ln_bwd_1024_kernel( int col = c; #pragma unroll for (int it = 0; it < LDGS; it++) { - platform::Load(gamma_ptr + col * VecSize, &gamma[it]); + phi::Load(gamma_ptr + col * VecSize, &gamma[it]); col += THREADS_PER_ROW; } @@ -452,12 +450,11 @@ __global__ __launch_bounds__(THREADS_PER_CTA) void fused_ln_bwd_1024_kernel( int col = c; #pragma unroll for (int it = 0; it < LDGS; it++) { - platform::Load(dout_ptr + row * LN_NUM_COLS + col * VecSize, - &dout[it]); - platform::Load(x_ptr + row * LN_NUM_COLS + col * VecSize, - &x[it]); + phi::Load(dout_ptr + row * LN_NUM_COLS + col * VecSize, + &dout[it]); + phi::Load(x_ptr + row * LN_NUM_COLS + col * VecSize, &x[it]); if (isFusedDropoutResidualLn) { - platform::Load( + phi::Load( mask_ptr + row * LN_NUM_COLS + col * VecSize, &mask_vec[it]); } @@ -552,10 +549,9 @@ __global__ __launch_bounds__(THREADS_PER_CTA) void fused_ln_bwd_1024_kernel( col = c; #pragma unroll for (int it = 0; it < LDGS; it++) { - platform::Store(x[it], - dx_ptr + row * LN_NUM_COLS + col * VecSize); + phi::Store(x[it], dx_ptr + row * LN_NUM_COLS + col * VecSize); if (isFusedDropoutResidualLn) { - platform::Store( + phi::Store( dout[it], d_dropout_src_ptr + row * LN_NUM_COLS + col * VecSize); } col += THREADS_PER_ROW; @@ -641,7 +637,7 @@ template < __global__ __launch_bounds__(THREADS_PER_CTA) void ln_bwd_1024_final_kernel( const int rows, U *__restrict__ dg_part_, U *__restrict__ db_part_, ScaleT *__restrict__ dg_, ScaleT *__restrict__ db_) { - using Vec = platform::AlignedVector; + using Vec = phi::AlignedVector; static_assert(VEC_COLS == LN_NUM_COLS / VecSize, ""); const int tidx = threadIdx.x; @@ -669,8 +665,8 @@ __global__ __launch_bounds__(THREADS_PER_CTA) void ln_bwd_1024_final_kernel( for (int row = r; row < rows; row += ROWS_PER_CTA) { Vec dg; Vec db; - platform::Load(dg_part_ptr, &dg); - platform::Load(db_part_ptr, &db); + phi::Load(dg_part_ptr, &dg); + phi::Load(db_part_ptr, &db); dg_part_ptr += ROWS_PER_CTA * LN_NUM_COLS; db_part_ptr += ROWS_PER_CTA * LN_NUM_COLS; diff --git a/paddle/fluid/operators/optimizers/cast_with_ptr.h b/paddle/fluid/operators/optimizers/cast_with_ptr.h index ab8b4f2b8f4d37d4be62c5e1dd040a1461d0bdee..a3fbb0e59e24e9be67da5048ebc644f08b385bbf 100644 --- a/paddle/fluid/operators/optimizers/cast_with_ptr.h +++ b/paddle/fluid/operators/optimizers/cast_with_ptr.h @@ -57,8 +57,7 @@ static void LaunchCastKernel(const platform::CUDADeviceContext &ctx, PADDLE_ENFORCE_NE( static_cast(x), static_cast(y), platform::errors::InvalidArgument("Inplace cast is not supported yet.")); - int vec_size = - std::min(platform::GetVectorizedSize(x), platform::GetVectorizedSize(y)); + int vec_size = std::min(phi::GetVectorizedSize(x), phi::GetVectorizedSize(y)); switch (vec_size) { case 4: return details::VecCastKernel(ctx, x, y, n); diff --git a/paddle/fluid/operators/optimizers/distributed_fused_lamb_op.cu b/paddle/fluid/operators/optimizers/distributed_fused_lamb_op.cu index 8bb4606ffff151c6f65606d8dce156f98589a6b4..5b60f65442b55dc89a845859f153048e89704f70 100644 --- a/paddle/fluid/operators/optimizers/distributed_fused_lamb_op.cu +++ b/paddle/fluid/operators/optimizers/distributed_fused_lamb_op.cu @@ -19,11 +19,11 @@ #include "paddle/fluid/operators/optimizers/distributed_fused_lamb_op.h" #include "paddle/fluid/operators/optimizers/multi_tensor_apply.h" #include "paddle/fluid/operators/tensor_to_string.h" -#include "paddle/fluid/platform/aligned_vector.h" #include "paddle/fluid/platform/collective_helper.h" #include "paddle/fluid/platform/for_range.h" #include "paddle/fluid/string/string_helper.h" #include "paddle/phi/core/utils/data_type.h" +#include "paddle/phi/kernels/funcs/aligned_vector.h" #ifdef __NVCC__ #include "cub/cub.cuh" @@ -66,8 +66,8 @@ struct L2NormFunctor { int i; for (i = threadIdx.x * VecSize; i + VecSize <= size; i += (BlockDim * VecSize)) { - platform::AlignedVector tmp_vec; - platform::Load(ptr + i, &tmp_vec); + phi::AlignedVector tmp_vec; + phi::Load(ptr + i, &tmp_vec); #pragma unroll for (int j = 0; j < VecSize; ++j) { auto tmp = static_cast(tmp_vec[j]); @@ -111,9 +111,9 @@ static int GetChunkedVecSize(const T *ptr, int chunk_size) { constexpr int max_load_bits = 128; int valid_vec_size = max_load_bits / CHAR_BIT / sizeof(T); auto address = reinterpret_cast(ptr); - constexpr int vec8 = alignof(platform::AlignedVector); - constexpr int vec4 = alignof(platform::AlignedVector); - constexpr int vec2 = alignof(platform::AlignedVector); + constexpr int vec8 = alignof(phi::AlignedVector); + constexpr int vec4 = alignof(phi::AlignedVector); + constexpr int vec2 = alignof(phi::AlignedVector); chunk_size *= sizeof(T); if (address % vec8 == 0 && chunk_size % vec8 == 0) { return std::min(8, valid_vec_size); @@ -316,15 +316,15 @@ static __global__ void ScaleCUDAKernel(const T1 *__restrict__ x, int stride = blockDim.x * gridDim.x * VecSize; for (; i + VecSize <= num; i += stride) { - platform::AlignedVector x_vec; - platform::AlignedVector y_vec; + phi::AlignedVector x_vec; + phi::AlignedVector y_vec; - platform::Load(x + i, &x_vec); + phi::Load(x + i, &x_vec); #pragma unroll for (int j = 0; j < VecSize; ++j) { y_vec[j] = static_cast(static_cast(x_vec[j]) * s); } - platform::Store(y_vec, y + i); + phi::Store(y_vec, y + i); } for (; i < num; ++i) { @@ -410,24 +410,24 @@ static __global__ void UpdateLambMomentAndTrustRatioDivCUDAKernel( int stride = blockDim.x * gridDim.x * VecSize; for (; i + VecSize <= num; i += stride) { - platform::AlignedVector param_vec; - platform::AlignedVector grad_vec; - platform::AlignedVector mom1_vec; - platform::AlignedVector mom2_vec; - platform::AlignedVector trust_ratio_div_vec; + phi::AlignedVector param_vec; + phi::AlignedVector grad_vec; + phi::AlignedVector mom1_vec; + phi::AlignedVector mom2_vec; + phi::AlignedVector trust_ratio_div_vec; T cur_weight_decay = (i < weight_decay_end_numel) * weight_decay; if (cur_weight_decay != static_cast(0.0)) { - platform::Load(param_p + i, ¶m_vec); + phi::Load(param_p + i, ¶m_vec); } else { #pragma unroll for (int j = 0; j < VecSize; ++j) { param_vec[j] = static_cast(0); } } - platform::Load(grad_p + i, &grad_vec); - platform::Load(mom1_p + i, &mom1_vec); - platform::Load(mom2_p + i, &mom2_vec); + phi::Load(grad_p + i, &grad_vec); + phi::Load(mom1_p + i, &mom1_vec); + phi::Load(mom2_p + i, &mom2_vec); #define PD_LAMB_MOM_TRUST_RATIO_DIV_UPDATE(__param, __grad, __mom1, __mom2, \ __trust_ratio_div, __idx) \ @@ -450,9 +450,9 @@ static __global__ void UpdateLambMomentAndTrustRatioDivCUDAKernel( mom2_vec, trust_ratio_div_vec, j); } - platform::Store(mom1_vec, mom1_p + i); - platform::Store(mom2_vec, mom2_p + i); - platform::Store(trust_ratio_div_vec, trust_ratio_div_p + i); + phi::Store(mom1_vec, mom1_p + i); + phi::Store(mom2_vec, mom2_p + i); + phi::Store(trust_ratio_div_vec, trust_ratio_div_p + i); } for (; i < num; ++i) { @@ -632,29 +632,29 @@ struct LambUpdateParamAndBetaPowsFunctor { trust_ratio_div += offset; for (i = threadIdx.x * VecSize; i + VecSize <= size; i += stride) { - platform::AlignedVector trust_ratio_div_vec; - platform::Load(trust_ratio_div + i, &trust_ratio_div_vec); + phi::AlignedVector trust_ratio_div_vec; + phi::Load(trust_ratio_div + i, &trust_ratio_div_vec); if (HasMasterParam) { - platform::AlignedVector master_param_vec; - platform::Load(master_param + i, &master_param_vec); - platform::AlignedVector param_vec; + phi::AlignedVector master_param_vec; + phi::Load(master_param + i, &master_param_vec); + phi::AlignedVector param_vec; #pragma unroll for (int j = 0; j < VecSize; ++j) { MT p = master_param_vec[j] - ratio * trust_ratio_div_vec[j]; master_param_vec[j] = p; param_vec[j] = static_cast(p); } - platform::Store(master_param_vec, master_param + i); - platform::Store(param_vec, param + i); + phi::Store(master_param_vec, master_param + i); + phi::Store(param_vec, param + i); } else { - platform::AlignedVector param_vec; - platform::Load(param + i, ¶m_vec); + phi::AlignedVector param_vec; + phi::Load(param + i, ¶m_vec); #pragma unroll for (int j = 0; j < VecSize; ++j) { MT p = static_cast(param_vec[j]) - ratio * trust_ratio_div_vec[j]; param_vec[j] = static_cast(p); } - platform::Store(param_vec, param + i); + phi::Store(param_vec, param + i); } } diff --git a/paddle/fluid/operators/optimizers/lars_momentum_op.cu b/paddle/fluid/operators/optimizers/lars_momentum_op.cu index df5da1b79535cc6f5e4a638e9d32c367ea7cdb9f..fe5cd066864b82c734614e33869dff1734bee6d0 100644 --- a/paddle/fluid/operators/optimizers/lars_momentum_op.cu +++ b/paddle/fluid/operators/optimizers/lars_momentum_op.cu @@ -88,8 +88,8 @@ __device__ inline void VectorizeLarsUpdate( T* param_out, MT* velocity_out, const MT mu, MT local_lr, const MT lars_weight_decay, const MT rescale_grad, const int tid, const int grid_stride, const int numel, MT* master_param_out = nullptr) { - using VecType = paddle::platform::AlignedVector; - using VecMType = paddle::platform::AlignedVector; + using VecType = phi::AlignedVector; + using VecMType = phi::AlignedVector; int main = numel >> (VecSize >> 1); int tail_offset = main * VecSize; diff --git a/paddle/fluid/operators/uniform_random_op.h b/paddle/fluid/operators/uniform_random_op.h index a864c48ad757411861b6d2b3be40361c347601f8..b941dc21c3ab213e5abc2c4c908413b2b6222c41 100644 --- a/paddle/fluid/operators/uniform_random_op.h +++ b/paddle/fluid/operators/uniform_random_op.h @@ -25,8 +25,9 @@ DECLARE_bool(use_curand); #include #include #include "paddle/fluid/framework/generator.h" -#include "paddle/fluid/operators/index_impl.cu.h" #include "paddle/phi/kernels/full_kernel.h" +#include "paddle/phi/kernels/funcs/distribution_helper.h" +#include "paddle/phi/kernels/funcs/index_impl.cu.h" #endif namespace paddle { @@ -206,21 +207,21 @@ void UniformRandom(const framework::ExecutionContext& context, if (gen_cuda->GetIsInitPy() && seed_flag) { if (FLAGS_use_curand) { using MT = typename details::MPTypeTrait::Type; - distribution::uniform_distribution dist; - distribution::uniform_transform trans(min, max); - distribution::distribution_and_transform(dev_cxt, tensor, dist, trans); + phi::funcs::uniform_distribution dist; + phi::funcs::uniform_real_transform trans(min, max); + phi::funcs::distribution_and_transform(dev_cxt, tensor, dist, trans); } else { auto seed_offset = gen_cuda->IncrementOffset(1); int64_t gen_offset = size * seed_offset.second; auto func = UniformGeneratorOffset(min, max, seed_offset.first, diag_num, diag_step, diag_val, gen_offset); - IndexKernel>(dev_cxt, tensor, func); + phi::IndexKernel>(dev_cxt, tensor, func); } } else { auto func = UniformGenerator(min, max, seed, diag_num, diag_step, diag_val); - IndexKernel>(dev_cxt, tensor, func); + phi::IndexKernel>(dev_cxt, tensor, func); } } #endif diff --git a/paddle/fluid/platform/fast_divmod.h b/paddle/fluid/platform/fast_divmod.h index f26c4fdd17ad7290c71eddf80874f7fa9e115e4f..39eefab774dbe84801bda98c9821d8c801e7fd25 100644 --- a/paddle/fluid/platform/fast_divmod.h +++ b/paddle/fluid/platform/fast_divmod.h @@ -15,7 +15,7 @@ limitations under the License. */ #pragma once #include -#include "paddle/fluid/platform/aligned_vector.h" +#include "paddle/phi/kernels/funcs/aligned_vector.h" #define INT_BITS 32 @@ -25,7 +25,7 @@ namespace platform { struct FastDivMod { // 1st value represents the result of input number divides by recorded divisor // 2nd value represents the result of input number modulo by recorded divisor - using DivModT = AlignedVector; + using DivModT = phi::AlignedVector; FastDivMod() {} HOSTDEVICE FastDivMod(uint32_t d) : divisor(d) { diff --git a/paddle/phi/kernels/funcs/broadcast_function.h b/paddle/phi/kernels/funcs/broadcast_function.h index e9fd4cf47b834775c03e9b48ff1e3a5096228fb2..aab31cfbd55b64a957ca75840cc6c0bb41e3f8c0 100644 --- a/paddle/phi/kernels/funcs/broadcast_function.h +++ b/paddle/phi/kernels/funcs/broadcast_function.h @@ -493,16 +493,14 @@ void BroadcastKernelForDifferentVecSize( "%d-th output tensor`s shape is not.", i)); out_vec_size = std::min( - paddle::platform::GetVectorizedSize((*outs)[i]->data()), - out_vec_size); + phi::GetVectorizedSize((*outs)[i]->data()), out_vec_size); } } else { - out_vec_size = - paddle::platform::GetVectorizedSize((*outs)[0]->data()); + out_vec_size = phi::GetVectorizedSize((*outs)[0]->data()); } for (auto *in : ins) { - auto temp_size = paddle::platform::GetVectorizedSize(in->data()); + auto temp_size = phi::GetVectorizedSize(in->data()); in_vec_size = in->dims() == (*outs)[0]->dims() ? std::min(temp_size, in_vec_size) : in_vec_size; diff --git a/paddle/phi/kernels/funcs/distribution_helper.h b/paddle/phi/kernels/funcs/distribution_helper.h index f0793fb9d27db68f22bc2bc27978844072c61616..3ef39dc55d124b0fca30e44c2e07c5ce4c644a30 100644 --- a/paddle/phi/kernels/funcs/distribution_helper.h +++ b/paddle/phi/kernels/funcs/distribution_helper.h @@ -28,6 +28,7 @@ limitations under the License. */ #include "paddle/phi/core/hostdevice.h" #if defined(__NVCC__) || defined(__HIPCC__) +#include "paddle/phi/kernels/funcs/index_impl.cu.h" #include "paddle/phi/kernels/primitive/kernel_primitives.h" #endif diff --git a/paddle/phi/kernels/funcs/elementwise_base.h b/paddle/phi/kernels/funcs/elementwise_base.h index 235dbdd40f6b7db5524251aec80b92cdc22aa819..332ec0b0312da96ca21b2c616440afc57a62edc2 100644 --- a/paddle/phi/kernels/funcs/elementwise_base.h +++ b/paddle/phi/kernels/funcs/elementwise_base.h @@ -23,9 +23,9 @@ limitations under the License. */ #include "paddle/phi/kernels/funcs/math_function.h" #if defined(__NVCC__) || defined(__HIPCC__) || defined(__xpu__) -#include "paddle/fluid/platform/aligned_vector.h" #include "paddle/fluid/platform/function_traits.h" #include "paddle/phi/backends/gpu/gpu_launch_config.h" +#include "paddle/phi/kernels/funcs/aligned_vector.h" #include "paddle/phi/kernels/primitive/kernel_primitives.h" #define HOSTDEVICE __host__ __device__ @@ -546,9 +546,8 @@ struct VecSizeGetter { const ArgsT &args, int *vec_size) { using Type = std::tuple_element_t; - *vec_size = std::min( - *vec_size, - paddle::platform::GetVectorizedSize(ins[Index]->data())); + *vec_size = std::min(*vec_size, + phi::GetVectorizedSize(ins[Index]->data())); } }; @@ -563,8 +562,8 @@ int GetVectorizedSizeForTensors(const std::vector &ins, // The Arg VecSize=1 is to match the Unroller template. Unroller::step(ins, arg, &vec_size); for (auto iter = outs.begin(); iter != outs.end(); ++iter) { - vec_size = std::min( - vec_size, paddle::platform::GetVectorizedSize((*iter)->data())); + vec_size = + std::min(vec_size, phi::GetVectorizedSize((*iter)->data())); } return vec_size; } diff --git a/paddle/phi/kernels/gpu/bernoulli_kernel.cu b/paddle/phi/kernels/gpu/bernoulli_kernel.cu index 2b6140d2fde0d3bcef3f15c4414444f1d2099b2e..79d8a7b0f3444b4272d1affd67bd5ac943f2c1cc 100644 --- a/paddle/phi/kernels/gpu/bernoulli_kernel.cu +++ b/paddle/phi/kernels/gpu/bernoulli_kernel.cu @@ -12,6 +12,8 @@ // See the License for the specific language governing permissions and // limitations under the License. +#include "paddle/phi/kernels/bernoulli_kernel.h" + #include #include #ifdef __NVCC__ @@ -28,7 +30,6 @@ #include "paddle/phi/backends/gpu/gpu_launch_config.h" #include "paddle/phi/core/dense_tensor.h" #include "paddle/phi/core/kernel_registry.h" -#include "paddle/phi/kernels/bernoulli_kernel.h" #include "paddle/phi/kernels/funcs/distribution_helper.h" // See Note [ Why still include the fluid headers? ] diff --git a/paddle/phi/kernels/gpu/cast_kernel.cu b/paddle/phi/kernels/gpu/cast_kernel.cu index 569a46f56d5638584262c0d1c8002459fa8ffd70..542234c80b5a1e945aec7c8342d31ef9b676cce8 100644 --- a/paddle/phi/kernels/gpu/cast_kernel.cu +++ b/paddle/phi/kernels/gpu/cast_kernel.cu @@ -20,11 +20,11 @@ #include "paddle/phi/kernels/funcs/elementwise_base.h" // See Note [ Why still include the fluid headers? ] -#include "paddle/fluid/platform/aligned_vector.h" #include "paddle/fluid/platform/device/gpu/gpu_helper.h" #include "paddle/fluid/platform/device/gpu/gpu_launch_config.h" #include "paddle/phi/common/bfloat16.h" #include "paddle/phi/common/float16.h" +#include "paddle/phi/kernels/funcs/aligned_vector.h" namespace phi {