From 9714878cc76b6db1e1fdec2a81dabc4874f25ea6 Mon Sep 17 00:00:00 2001 From: Zhou Wei <1183042833@qq.com> Date: Thu, 7 Apr 2022 22:15:52 +0800 Subject: [PATCH] remove FLAGS_use_curand and change all random op CUDA implementation (#41308) --- paddle/fluid/operators/dropout_impl.cu.h | 151 ++++--------- paddle/fluid/operators/gaussian_random_op.cu | 7 - paddle/fluid/operators/uniform_random_op.h | 54 +---- paddle/fluid/platform/flags.cc | 2 - paddle/phi/kernels/cpu/transpose_kernel.cc | 1 + paddle/phi/kernels/gpu/bernoulli_kernel.cu | 59 +---- .../phi/kernels/gpu/gaussian_random_kernel.cu | 25 +- paddle/phi/kernels/gpu/multinomial_kernel.cu | 213 +++++------------- paddle/phi/kernels/gpu/randint_kernel.cu | 36 +-- paddle/phi/kernels/gpu/randperm_kernel.cu | 144 +++++------- .../phi/kernels/gpu/uniform_random_kernel.cu | 61 +---- paddle/scripts/paddle_build.bat | 1 - paddle/scripts/paddle_build.sh | 2 - python/paddle/fluid/initializer.py | 16 +- .../tests/unittests/test_bernoulli_op.py | 3 - .../fluid/tests/unittests/test_dropout_op.py | 3 - .../tests/unittests/test_exponential_op.py | 3 - .../unittests/test_gaussian_random_op.py | 3 - .../fluid/tests/unittests/test_linear.py | 16 ++ .../tests/unittests/test_multinomial_op.py | 3 - .../fluid/tests/unittests/test_poisson_op.py | 3 - .../fluid/tests/unittests/test_randint_op.py | 3 - .../fluid/tests/unittests/test_randperm_op.py | 3 - .../tests/unittests/test_uniform_random_op.py | 45 ++-- python/paddle/nn/utils/__init__.py | 2 +- .../paddle/nn/utils/transform_parameters.py | 33 +++ 26 files changed, 267 insertions(+), 625 deletions(-) diff --git a/paddle/fluid/operators/dropout_impl.cu.h b/paddle/fluid/operators/dropout_impl.cu.h index 83ca9ace20d..6af8c925ff5 100644 --- a/paddle/fluid/operators/dropout_impl.cu.h +++ b/paddle/fluid/operators/dropout_impl.cu.h @@ -38,43 +38,9 @@ limitations under the License. */ #include "paddle/phi/kernels/funcs/distribution_helper.h" #include "paddle/phi/kernels/funcs/functors.h" -DECLARE_bool(use_curand); - namespace paddle { namespace operators { -template -struct DstMaskGenerator { - const float dropout_prob_; - const bool is_upscale_in_train_; - using MT = typename details::MPTypeTrait::Type; - MT factor; - HOSTDEVICE inline DstMaskGenerator(const float dropout_prob, - const bool is_upscale_in_train) - : dropout_prob_(dropout_prob), is_upscale_in_train_(is_upscale_in_train) { - factor = static_cast(1.0f / (1.0f - dropout_prob_)); - } - - HOSTDEVICE inline void operator()(OutT* dst, const T1* src_val, - const T2* rand, int num) const { - static constexpr int kCount = - phi::funcs::uniform_distribution::kReturnsCount; -// 0 ~ kCount -1 is dist , kCount ~ 2 * kCount - 1 is mask -#pragma unroll - for (int i = 0; i < kCount; i++) { - if (rand[i] < dropout_prob_) { - dst[i] = static_cast(0); - dst[i + kCount] = dst[i]; - } else { - dst[i] = is_upscale_in_train_ - ? static_cast(static_cast(src_val[i]) * factor) - : static_cast(src_val[i]); - dst[i + kCount] = static_cast(1); - } - } - } -}; - template struct DstMaskFunctor { const float retain_prob_; @@ -113,7 +79,7 @@ __global__ void VectorizedRandomGenerator(const size_t n, uint64_t seed, const T* src, MaskType* mask, T* dst, bool is_upscale_in_train, uint64_t increment, - size_t main_offset, bool use_curand) { + size_t main_offset) { size_t idx = static_cast(BLOCK_ID_X * BLOCK_NUM_X); static constexpr int kCount = phi::funcs::uniform_distribution::kReturnsCount; @@ -135,76 +101,41 @@ __global__ void VectorizedRandomGenerator(const size_t n, uint64_t seed, int deal_size = BLOCK_NUM_X * kCount; size_t fix = idx * kCount; - if (use_curand) { - auto dst_functor = - DstMaskFunctor(1.0f - dropout_prob, is_upscale_in_train); - for (; fix < main_offset; fix += stride) { - kps::ReadData(&dst_mask[0], src + fix, deal_size); - kps::ElementwiseRandom(&rands[0], Rand(), - &state); - // dst - kps::OperatorTernary>( - &dst_mask[0], &dst_mask[0], &rands[0], dst_functor, kCount); - kps::WriteData(dst + fix, &dst_mask[0], - deal_size); - // mask - kps::ElementwiseUnary( - &mask_result[0], &dst_mask[kCount], Cast()); - kps::WriteData(mask + fix, &mask_result[0], - deal_size); - if (fix > idx * kCount + 1) { - __syncthreads(); - } - } - int remainder = n - fix; - if (remainder > 0) { - kps::ReadData(&dst_mask[0], src + fix, remainder); - kps::ElementwiseRandom(&rands[0], Rand(), - &state); - // dst - kps::OperatorTernary>( - &dst_mask[0], &dst_mask[0], &rands[0], dst_functor, kCount); - kps::WriteData(dst + fix, &dst_mask[0], remainder); - // mask - kps::ElementwiseUnary( - &mask_result[0], &dst_mask[kCount], Cast()); - kps::WriteData(mask + fix, &mask_result[0], - remainder); + + auto dst_functor = + DstMaskFunctor(1.0f - dropout_prob, is_upscale_in_train); + for (; fix < main_offset; fix += stride) { + kps::ReadData(&dst_mask[0], src + fix, deal_size); + kps::ElementwiseRandom(&rands[0], Rand(), + &state); + // dst + kps::OperatorTernary>( + &dst_mask[0], &dst_mask[0], &rands[0], dst_functor, kCount); + kps::WriteData(dst + fix, &dst_mask[0], deal_size); + // mask + kps::ElementwiseUnary( + &mask_result[0], &dst_mask[kCount], Cast()); + kps::WriteData(mask + fix, &mask_result[0], + deal_size); + if (fix > idx * kCount + 1) { __syncthreads(); } - } else { - auto dst_functor = - DstMaskGenerator(dropout_prob, is_upscale_in_train); - for (; fix < main_offset; fix += stride) { - kps::ReadData(&dst_mask[0], src + fix, deal_size); - kps::ElementwiseRandom(&rands[0], Rand(), - &state); - // dst - kps::OperatorTernary>( - &dst_mask[0], &dst_mask[0], &rands[0], dst_functor, kCount); - kps::WriteData(dst + fix, &dst_mask[0], - deal_size); - // mask - kps::ElementwiseUnary( - &mask_result[0], &dst_mask[kCount], Cast()); - kps::WriteData(mask + fix, &mask_result[0], - deal_size); - } - int remainder = n - fix; - if (remainder > 0) { - kps::ReadData(&dst_mask[0], src + fix, remainder); - kps::ElementwiseRandom(&rands[0], Rand(), - &state); - // dst - kps::OperatorTernary>( - &dst_mask[0], &dst_mask[0], &rands[0], dst_functor, kCount); - kps::WriteData(dst + fix, &dst_mask[0], remainder); - // mask - kps::ElementwiseUnary( - &mask_result[0], &dst_mask[kCount], Cast()); - kps::WriteData(mask + fix, &mask_result[0], - remainder); - } + } + int remainder = n - fix; + if (remainder > 0) { + kps::ReadData(&dst_mask[0], src + fix, remainder); + kps::ElementwiseRandom(&rands[0], Rand(), + &state); + // dst + kps::OperatorTernary>( + &dst_mask[0], &dst_mask[0], &rands[0], dst_functor, kCount); + kps::WriteData(dst + fix, &dst_mask[0], remainder); + // mask + kps::ElementwiseUnary( + &mask_result[0], &dst_mask[kCount], Cast()); + kps::WriteData(mask + fix, &mask_result[0], + remainder); + __syncthreads(); } } @@ -251,13 +182,11 @@ void DropoutFwGPUKernelDriver(const phi::GPUContext& dev_ctx, bool is_test, size_t grid_size = gpu_config.GetGridSize(); size_t block_size = gpu_config.GetBlockSize(); - if (FLAGS_use_curand) { - int64_t device_id = dev_ctx.GetPlace().GetDeviceId(); - const auto& prop = platform::GetDeviceProperties(device_id); - size_t max_grid_size = prop.maxThreadsPerMultiProcessor * - prop.multiProcessorCount / block_size; - grid_size = std::min(grid_size, max_grid_size); - } + int64_t device_id = dev_ctx.GetPlace().GetDeviceId(); + const auto& prop = platform::GetDeviceProperties(device_id); + size_t max_grid_size = prop.maxThreadsPerMultiProcessor * + prop.multiProcessorCount / block_size; + grid_size = std::min(grid_size, max_grid_size); auto offset = ((x_numel - 1) / (grid_size * block_size * kVecSize) + 1) * kVecSize; @@ -268,7 +197,7 @@ void DropoutFwGPUKernelDriver(const phi::GPUContext& dev_ctx, bool is_test, VectorizedRandomGenerator<<>>( size, seed_data, dropout_prob, x_data, mask_data, y_data, - upscale_in_train, increment, main_offset, FLAGS_use_curand); + upscale_in_train, increment, main_offset); } else { if (upscale_in_train) { // todo: can y share with data with x directly? diff --git a/paddle/fluid/operators/gaussian_random_op.cu b/paddle/fluid/operators/gaussian_random_op.cu index 00ce10bfe3b..552649279e9 100644 --- a/paddle/fluid/operators/gaussian_random_op.cu +++ b/paddle/fluid/operators/gaussian_random_op.cu @@ -11,21 +11,14 @@ distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#include -#include #include -#include #include "paddle/fluid/framework/generator.h" #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/operator.h" #include "paddle/fluid/operators/amp/fp16_type_traits.h" #include "paddle/fluid/operators/fill_constant_op.h" - -#include "paddle/phi/kernels/funcs/distribution_helper.h" #include "paddle/phi/kernels/funcs/index_impl.cu.h" -DECLARE_bool(use_curand); - namespace paddle { namespace operators { diff --git a/paddle/fluid/operators/uniform_random_op.h b/paddle/fluid/operators/uniform_random_op.h index b941dc21c3a..ae846f4cae6 100644 --- a/paddle/fluid/operators/uniform_random_op.h +++ b/paddle/fluid/operators/uniform_random_op.h @@ -19,11 +19,7 @@ #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/operator.h" #if defined(__NVCC__) || defined(__HIPCC__) -DECLARE_bool(use_curand); -#include -#include #include -#include #include "paddle/fluid/framework/generator.h" #include "paddle/phi/kernels/full_kernel.h" #include "paddle/phi/kernels/funcs/distribution_helper.h" @@ -146,39 +142,6 @@ struct UniformGenerator { } }; -template -struct UniformGeneratorOffset { - T min_, max_; - unsigned int seed_; - T diag_val_; - unsigned int diag_num_; - unsigned int diag_step_; - int offset_; - __host__ __device__ UniformGeneratorOffset(T min, T max, int seed, - int diag_num, int diag_step, - T diag_val, int offset) - : min_(min), - max_(max), - seed_(seed), - diag_num_(diag_num), - diag_step_(diag_step), - diag_val_(diag_val), - offset_(offset) {} - - __host__ __device__ T operator()(const unsigned int n) const { - thrust::minstd_rand rng; - rng.seed(seed_); - thrust::uniform_real_distribution dist(min_, max_); - rng.discard(n + offset_); - T out = dist(rng); - unsigned int remainder = n % (diag_step_ + 1); - if (remainder == 0 && diag_num_ > n / (diag_step_ + 1)) { - out = diag_val_; - } - return out; - } -}; - template void UniformRandom(const framework::ExecutionContext& context, framework::Tensor* tensor) { @@ -205,19 +168,10 @@ void UniformRandom(const framework::ExecutionContext& context, int device_id = context.GetPlace().GetDeviceId(); auto gen_cuda = framework::GetDefaultCUDAGenerator(device_id); if (gen_cuda->GetIsInitPy() && seed_flag) { - if (FLAGS_use_curand) { - using MT = typename details::MPTypeTrait::Type; - phi::funcs::uniform_distribution dist; - phi::funcs::uniform_real_transform trans(min, max); - phi::funcs::distribution_and_transform(dev_cxt, tensor, dist, trans); - } else { - auto seed_offset = gen_cuda->IncrementOffset(1); - int64_t gen_offset = size * seed_offset.second; - auto func = - UniformGeneratorOffset(min, max, seed_offset.first, diag_num, - diag_step, diag_val, gen_offset); - phi::IndexKernel>(dev_cxt, tensor, func); - } + using MT = typename details::MPTypeTrait::Type; + phi::funcs::uniform_distribution dist; + phi::funcs::uniform_real_transform trans(min, max); + phi::funcs::distribution_and_transform(dev_cxt, tensor, dist, trans); } else { auto func = UniformGenerator(min, max, seed, diag_num, diag_step, diag_val); diff --git a/paddle/fluid/platform/flags.cc b/paddle/fluid/platform/flags.cc index 4e47c130c72..c3d3f6a4f68 100644 --- a/paddle/fluid/platform/flags.cc +++ b/paddle/fluid/platform/flags.cc @@ -545,8 +545,6 @@ PADDLE_DEFINE_EXPORTED_double( */ PADDLE_DEFINE_EXPORTED_bool(use_mkldnn, false, "Use MKLDNN to run"); -PADDLE_DEFINE_EXPORTED_bool(use_curand, false, "Random OP use CURAND"); - /** * Debug related FLAG * Name: FLAGS_call_stack_level diff --git a/paddle/phi/kernels/cpu/transpose_kernel.cc b/paddle/phi/kernels/cpu/transpose_kernel.cc index a80196e7f80..5dc4866e1ef 100644 --- a/paddle/phi/kernels/cpu/transpose_kernel.cc +++ b/paddle/phi/kernels/cpu/transpose_kernel.cc @@ -75,6 +75,7 @@ PD_REGISTER_KERNEL(transpose, double, int32_t, int64_t, + phi::dtype::float16, phi::dtype::bfloat16, phi::dtype::complex, phi::dtype::complex) {} diff --git a/paddle/phi/kernels/gpu/bernoulli_kernel.cu b/paddle/phi/kernels/gpu/bernoulli_kernel.cu index 79d8a7b0f34..edcf29e2d88 100644 --- a/paddle/phi/kernels/gpu/bernoulli_kernel.cu +++ b/paddle/phi/kernels/gpu/bernoulli_kernel.cu @@ -14,8 +14,6 @@ #include "paddle/phi/kernels/bernoulli_kernel.h" -#include -#include #ifdef __NVCC__ #include #endif @@ -32,35 +30,8 @@ #include "paddle/phi/core/kernel_registry.h" #include "paddle/phi/kernels/funcs/distribution_helper.h" -// See Note [ Why still include the fluid headers? ] -#include "paddle/fluid/platform/transform.h" - -DECLARE_bool(use_curand); - namespace phi { -template -struct BernoulliCudaFunctor { - unsigned int seed_; - unsigned int offset_; - __host__ __device__ BernoulliCudaFunctor(unsigned int seed, - unsigned int offset) - : seed_(seed), offset_(offset) {} - - __host__ __device__ T operator()(const unsigned int n, const T p) const { - // NOTE(zhiqiu): currently, PADDLE_ENFORCE in cuda kernel may print several - // lines of error messages if, and it should be refined. - PADDLE_ENFORCE(p >= 0.0 && p <= 1.0, - "The probability should be >=0 and <= 1, but got %f", - p); - thrust::minstd_rand rng; - rng.seed(seed_); - thrust::uniform_real_distribution dist(0.0, 1.0); - rng.discard(n + offset_); - return static_cast(dist(rng) < p); - } -}; - // 'curand_uniform4/hiprand_uniform4' generate 4 random number each time template __global__ void bernoulli_cuda_kernel( @@ -100,30 +71,16 @@ void BernoulliKernel(const Context& ctx, auto gen_cuda = ctx.GetGenerator(); - if (FLAGS_use_curand) { - auto seed_offset = gen_cuda->IncrementOffset(12); - uint64_t seed = seed_offset.first; - uint64_t offset = seed_offset.second; + auto seed_offset = gen_cuda->IncrementOffset(12); + uint64_t seed = seed_offset.first; + uint64_t offset = seed_offset.second; - auto gpu_config = phi::backends::gpu::GetGpuLaunchConfig1D(ctx, numel, 4); - size_t grid_size = gpu_config.GetGridSize(); - size_t block_size = gpu_config.GetBlockSize(); + auto gpu_config = phi::backends::gpu::GetGpuLaunchConfig1D(ctx, numel, 4); + size_t grid_size = gpu_config.GetGridSize(); + size_t block_size = gpu_config.GetBlockSize(); - bernoulli_cuda_kernel<<>>( - numel, seed, offset, x_data, out_data); - } else { - auto seed_offset = gen_cuda->IncrementOffset(1); - int64_t gen_offset = numel * seed_offset.second; - paddle::platform::Transform trans; - thrust::counting_iterator index_sequence_begin(0); - trans(ctx, - index_sequence_begin, - index_sequence_begin + numel, - x_data, - out_data, - BernoulliCudaFunctor(static_cast(seed_offset.first), - static_cast(gen_offset))); - } + bernoulli_cuda_kernel<<>>( + numel, seed, offset, x_data, out_data); } } // namespace phi diff --git a/paddle/phi/kernels/gpu/gaussian_random_kernel.cu b/paddle/phi/kernels/gpu/gaussian_random_kernel.cu index e159e5916cf..96ebc0353ef 100644 --- a/paddle/phi/kernels/gpu/gaussian_random_kernel.cu +++ b/paddle/phi/kernels/gpu/gaussian_random_kernel.cu @@ -14,10 +14,7 @@ #include "paddle/phi/kernels/gaussian_random_kernel.h" -#include -#include #include -#include #include "paddle/phi/backends/gpu/gpu_context.h" #include "paddle/phi/common/amp_type_traits.h" #include "paddle/phi/core/dense_tensor.h" @@ -27,8 +24,6 @@ #include "paddle/fluid/framework/generator.h" -DECLARE_bool(use_curand); - namespace phi { template @@ -83,21 +78,11 @@ void GaussianRandomKernel(const Context& dev_ctx, auto gen_cuda = paddle::framework::GetDefaultCUDAGenerator(device_id); if (gen_cuda->GetIsInitPy() && seed_flag) { - if (FLAGS_use_curand) { - using MT = typename phi::dtype::MPTypeTrait::Type; - funcs::normal_distribution dist; - funcs::normal_transform trans(static_cast(mean), - static_cast(std)); - funcs::distribution_and_transform(dev_ctx, tensor, dist, trans); - } else { - auto seed_offset = gen_cuda->IncrementOffset(1); - int64_t gen_offset = size * seed_offset.second; - auto func = GaussianGenerator(static_cast(mean), - static_cast(std), - seed_offset.first, - gen_offset); - IndexKernel>(dev_ctx, tensor, func); - } + using MT = typename phi::dtype::MPTypeTrait::Type; + funcs::normal_distribution dist; + funcs::normal_transform trans(static_cast(mean), + static_cast(std)); + funcs::distribution_and_transform(dev_ctx, tensor, dist, trans); } else { auto func = GaussianGenerator(static_cast(mean), static_cast(std), seed); diff --git a/paddle/phi/kernels/gpu/multinomial_kernel.cu b/paddle/phi/kernels/gpu/multinomial_kernel.cu index ee5f843b18a..ef6cd1323a9 100644 --- a/paddle/phi/kernels/gpu/multinomial_kernel.cu +++ b/paddle/phi/kernels/gpu/multinomial_kernel.cu @@ -18,11 +18,6 @@ limitations under the License. */ #include "paddle/phi/kernels/multinomial_kernel.h" -#include -#include -#include -#include - #ifdef __NVCC__ #include "cub/cub.cuh" #endif @@ -44,12 +39,6 @@ namespace cub = hipcub; #include "paddle/phi/kernels/funcs/multinomial_functor.h" #include "paddle/phi/kernels/top_k_kernel.h" -// See Note [ Why still include the fluid headers? ] -#include "paddle/fluid/memory/memcpy.h" -#include "paddle/fluid/platform/transform.h" - -DECLARE_bool(use_curand); - namespace phi { template @@ -74,32 +63,6 @@ __global__ void NormalizeProbability(T* norm_probs, } } -template -__global__ void GetCumulativeProbs(T* norm_probs_data, - int64_t num_distributions, - int64_t num_categories, - T* cumulative_probs_data) { - int id = blockIdx.x; - thrust::inclusive_scan(thrust::device, - norm_probs_data + id * num_categories, - norm_probs_data + (id + 1) * num_categories, - cumulative_probs_data + id * num_categories); -} - -template -struct RandomGeneratorCudaFunctor { - unsigned int seed_; - __host__ __device__ RandomGeneratorCudaFunctor(int seed) : seed_(seed) {} - - __host__ __device__ T operator()(const unsigned int n) const { - thrust::minstd_rand rng; - rng.seed(seed_); - thrust::uniform_real_distribution dist(0.0, 1.0); - rng.discard(n); - return dist(rng); - } -}; - template __device__ int binarySearchFunctor(T* cumulative_probs_data, T* norm_probs_data, @@ -130,7 +93,6 @@ __device__ int binarySearchFunctor(T* cumulative_probs_data, template __global__ void sampleMultinomialWithReplacement( - T* rng_data, const int64_t num_samples, int64_t* out_data, const int64_t num_distributions, @@ -138,10 +100,9 @@ __global__ void sampleMultinomialWithReplacement( T* cumulative_probs_data, T* norm_probs_data, uint64_t seed, - uint64_t offset, - bool use_curand) { + uint64_t offset) { // use binary search to get the selected category sample id. - // let cumulative_probs_data[id-1] < rng_data < cumulative_probs_data[id]. + // let cumulative_probs_data[id-1] < rng_number < cumulative_probs_data[id]. size_t idx = gridDim.x * blockDim.x * blockIdx.y + blockDim.x * blockIdx.x + threadIdx.x; @@ -151,10 +112,7 @@ __global__ void sampleMultinomialWithReplacement( int sample = blockIdx.x * blockDim.x + threadIdx.x; for (int dist = blockIdx.y; dist < num_distributions; dist += gridDim.y) { if (sample < num_samples) { - T rng_number = rng_data[sample + dist * num_samples]; - if (use_curand) { - rng_number = static_cast(curand_uniform4(&state).x); - } + T rng_number = static_cast(curand_uniform4(&state).x); // Find the bucket that a uniform random number lies in int selected_category = binarySearchFunctor(cumulative_probs_data + dist * num_categories, @@ -182,10 +140,7 @@ void MultinomialKernel(const Context& dev_ctx, const int64_t num_distributions = in_rank > 1 ? in_dims[in_rank - 2] : 1; // If replacement is False, it's not a replaceable sample. Every category - // can - // be used only once. So after every sample, probability of the distribution - // will change. The implementation can't be parallelizable. Thus, call CPU - // implementation ``funcs::MultinomialFunctor`` to sample the distribution. + // can be used only once. if (!replacement) { int64_t in_data_numel = x.numel(); int64_t out_data_numel = out->numel(); @@ -202,76 +157,50 @@ void MultinomialKernel(const Context& dev_ctx, in_data_numel * sizeof(T), cudaMemcpyDeviceToHost); #endif - if (FLAGS_use_curand) { - for (size_t i = 0; i < num_distributions; ++i) { - int zero_num = 0; - for (size_t j = 0; j < num_categories; ++j) { - T weight = cpu_in_data[i * num_distributions + j]; - PADDLE_ENFORCE_GE( - weight, - 0, - errors::InvalidArgument( - "Each element of multinomial'input must >= 0, but got %f.", - weight)); - if (weight == static_cast(0)) { - zero_num++; - } + for (size_t i = 0; i < num_distributions; ++i) { + int zero_num = 0; + for (size_t j = 0; j < num_categories; ++j) { + T weight = cpu_in_data[i * num_distributions + j]; + PADDLE_ENFORCE_GE( + weight, + 0, + errors::InvalidArgument( + "Each element of multinomial'input must >= 0, but got %f.", + weight)); + if (weight == static_cast(0)) { + zero_num++; } - int valid_samples = num_categories - zero_num; - PADDLE_ENFORCE_LE( - num_samples, - valid_samples, - errors::InvalidArgument("When replacement=False, 'num_samples' " - "must less than or eaqual to the number of " - "positive item of input")); } - - // Refer to [gumbel softmax algorithm] - DenseTensor rand = EmptyLike(dev_ctx, x); - T* rand_data = rand.data(); - funcs::uniform_distribution dist; - funcs::exponential_transform trans(1.0); - funcs::distribution_and_transform(dev_ctx, &rand, dist, trans); - - funcs::ForRange for_range(dev_ctx, x.numel()); - for_range([rand_data, in_data] __device__(size_t idx) { - rand_data[idx] = in_data[idx] / rand_data[idx]; - }); - - if (num_samples == 1) { - ArgMaxKernel( - dev_ctx, rand, -1, true, false, 3 /*proto::VarType::INT64*/, out); - } else { - std::vector out_dim_vec = vectorize(out->dims()); - DenseTensor value = Empty(dev_ctx, IntArray(out_dim_vec)); - TopkKernel( - dev_ctx, rand, Scalar(num_samples), -1, true, true, &value, out); - } - return; + int valid_samples = num_categories - zero_num; + PADDLE_ENFORCE_LE( + num_samples, + valid_samples, + errors::InvalidArgument("When replacement=False, 'num_samples' " + "must less than or eaqual to the number of " + "positive item of input")); } - funcs::MultinomialFunctor(dev_ctx, - cpu_out_data, - cpu_in_data, - num_samples, - replacement, - num_categories, - num_distributions); - -#ifdef PADDLE_WITH_HIP - hipMemcpy(out_data, - cpu_out_data, - out_data_numel * sizeof(int64_t), - hipMemcpyHostToDevice); -#else - cudaMemcpy(out_data, - cpu_out_data, - out_data_numel * sizeof(int64_t), - cudaMemcpyHostToDevice); -#endif - - delete[] cpu_in_data; - delete[] cpu_out_data; + // Refer to [gumbel softmax algorithm] + DenseTensor rand = EmptyLike(dev_ctx, x); + T* rand_data = rand.data(); + funcs::uniform_distribution dist; + funcs::exponential_transform trans(1.0); + funcs::distribution_and_transform(dev_ctx, &rand, dist, trans); + + funcs::ForRange for_range(dev_ctx, x.numel()); + for_range([rand_data, in_data] __device__(size_t idx) { + rand_data[idx] = in_data[idx] / rand_data[idx]; + }); + + if (num_samples == 1) { + ArgMaxKernel( + dev_ctx, rand, -1, true, false, 3 /*proto::VarType::INT64*/, out); + } else { + std::vector out_dim_vec = vectorize(out->dims()); + DenseTensor value = Empty(dev_ctx, IntArray(out_dim_vec)); + TopkKernel( + dev_ctx, rand, Scalar(num_samples), -1, true, true, &value, out); + } return; } @@ -322,44 +251,18 @@ void MultinomialKernel(const Context& dev_ctx, auto* cumulative_probs_data = dev_ctx.template Alloc(&cumulative_probs_tensor); - if (FLAGS_use_curand) { - // 'phi::funcs::InclusiveScan' has higher accuracy than - // 'thrust::inclusive_scan' - funcs::InclusiveScan>( - /*in*/ norm_probs_data, - /*out*/ cumulative_probs_data, - /*outer_dim*/ static_cast(num_distributions), - /*mid_dim*/ static_cast(num_categories), - /*inner_dim*/ static_cast(1), - /*init*/ static_cast(0), - std::plus(), - /*reverse=*/false, - dev_ctx); - } else { - dim3 block_cumsum(1); - dim3 grid_cumsum(num_distributions); - GetCumulativeProbs<<>>( - norm_probs_data, - num_distributions, - num_categories, - cumulative_probs_data); - } - - // Generate random number for each sample. - std::random_device rd; - auto seed = rd(); - - DenseTensor rng_data_tensor; - rng_data_tensor.Resize({num_distributions, num_samples}); - auto* rng_data = dev_ctx.template Alloc(&rng_data_tensor); - - thrust::counting_iterator index_sequence_begin(0); - paddle::platform::Transform trans; - trans(dev_ctx, - index_sequence_begin, - index_sequence_begin + num_distributions * num_samples, - rng_data, - RandomGeneratorCudaFunctor(seed)); + // 'phi::funcs::InclusiveScan' has higher accuracy than + // 'thrust::inclusive_scan' + funcs::InclusiveScan>( + /*in*/ norm_probs_data, + /*out*/ cumulative_probs_data, + /*outer_dim*/ static_cast(num_distributions), + /*mid_dim*/ static_cast(num_categories), + /*inner_dim*/ static_cast(1), + /*init*/ static_cast(0), + std::plus(), + /*reverse=*/false, + dev_ctx); // Sample the multinomial distributions. dim3 block(128); @@ -376,7 +279,6 @@ void MultinomialKernel(const Context& dev_ctx, auto seed_offset = gen_cuda->IncrementOffset(increment); sampleMultinomialWithReplacement<<>>( - rng_data, num_samples, out_data, num_distributions, @@ -384,8 +286,7 @@ void MultinomialKernel(const Context& dev_ctx, cumulative_probs_data, norm_probs_data, seed_offset.first, - seed_offset.second, - FLAGS_use_curand); + seed_offset.second); } } // namespace phi diff --git a/paddle/phi/kernels/gpu/randint_kernel.cu b/paddle/phi/kernels/gpu/randint_kernel.cu index 01885050022..90eaea6a086 100644 --- a/paddle/phi/kernels/gpu/randint_kernel.cu +++ b/paddle/phi/kernels/gpu/randint_kernel.cu @@ -23,8 +23,6 @@ // See Note [ Why still include the fluid headers? ] #include "paddle/fluid/memory/memcpy.h" -DECLARE_bool(use_curand); - namespace phi { template @@ -37,37 +35,9 @@ void RandintRawKernel(const Context& dev_ctx, DenseTensor* out) { out->Resize(phi::make_ddim(shape.GetData())); T* data = dev_ctx.template Alloc(out); - if (FLAGS_use_curand) { - funcs::uniform_distribution dist; - funcs::uniform_int_transform trans(low, high); - funcs::distribution_and_transform(dev_ctx, out, dist, trans); - } else { - DenseTensor tmp; - tmp.Resize(phi::make_ddim(shape.GetData())); - T* tmp_data = dev_ctx.template HostAlloc(&tmp); - - std::shared_ptr engine; - if (seed) { - engine = std::make_shared(); - engine->seed(seed); - } else { - engine = dev_ctx.GetHostGenerator()->GetCPUEngine(); - } - - std::uniform_int_distribution dist(low, high - 1); - auto numel = out->numel(); - for (int64_t i = 0; i < numel; ++i) { - tmp_data[i] = dist(*engine); - } - - paddle::memory::Copy( - out->place(), - data, - tmp.place(), - tmp_data, - numel * paddle::experimental::SizeOf(out->dtype()), - 0); - } + funcs::uniform_distribution dist; + funcs::uniform_int_transform trans(low, high); + funcs::distribution_and_transform(dev_ctx, out, dist, trans); } template diff --git a/paddle/phi/kernels/gpu/randperm_kernel.cu b/paddle/phi/kernels/gpu/randperm_kernel.cu index 678b580beca..4e488ed470d 100644 --- a/paddle/phi/kernels/gpu/randperm_kernel.cu +++ b/paddle/phi/kernels/gpu/randperm_kernel.cu @@ -84,91 +84,65 @@ __global__ void SwapRepeatKernel( template void RandpermRawKernel( const Context& dev_ctx, int n, DataType dtype, int seed, DenseTensor* out) { - if (FLAGS_use_curand) { - DenseTensor key; - RandintKernel(dev_ctx, - std::numeric_limits::min(), - std::numeric_limits::max(), - IntArray({n}), - phi::DataType::INT32, - &key); - DenseTensor key_out = Empty(dev_ctx, IntArray({n})); - - DenseTensor range = Empty(dev_ctx, IntArray({n})); - T* range_data = range.data(); - funcs::ForRange for_range(dev_ctx, n); - for_range([range_data] __device__(size_t idx) { - range_data[idx] = static_cast(idx); - }); - - out->Resize(phi::make_ddim({n})); - T* out_data = dev_ctx.template Alloc(out); - - // Refer to [Algorithm of randperm] https://osf.io/af2hy/ to - // improve performance of radix sort. - double n_d = static_cast(n); - int begin_bit = 0; - int end_bit = - std::ceil(std::log2(n_d - (6 * n_d * n_d + 1) / (12 * std::log(0.9)))); - - size_t temp_storage_bytes = 0; - cub::DeviceRadixSort::SortPairs(nullptr, - temp_storage_bytes, - key.data(), - key_out.data(), - range.data(), - out_data, - n, - begin_bit, - end_bit < 32 ? end_bit : 32, - dev_ctx.stream()); - - auto d_temp_storage = paddle::memory::Alloc(dev_ctx, temp_storage_bytes); - cub::DeviceRadixSort::SortPairs(d_temp_storage->ptr(), - temp_storage_bytes, - key.data(), - key_out.data(), - range.data(), - out_data, - n, - begin_bit, - end_bit < 32 ? end_bit : 32, - dev_ctx.stream()); - - auto gen_cuda = dev_ctx.GetGenerator(); - auto seed_offset = gen_cuda->IncrementOffset(n); - uint64_t seed = seed_offset.first; - uint64_t offset = seed_offset.second; - - auto config = phi::backends::gpu::GetGpuLaunchConfig1D(dev_ctx, n); - SwapRepeatKernel<<>>( - key_out.data(), out_data, n, seed, offset); - } else { - DenseTensor tmp; - tmp.Resize(phi::make_ddim({n})); - T* tmp_data = dev_ctx.template HostAlloc(&tmp); - - std::shared_ptr engine; - if (seed) { - engine = std::make_shared(); - engine->seed(seed); - } else { - engine = dev_ctx.GetHostGenerator()->GetCPUEngine(); - } - - for (int i = 0; i < n; ++i) { - tmp_data[i] = static_cast(i); - } - std::shuffle(tmp_data, tmp_data + n, *engine); - - T* out_data = dev_ctx.template Alloc(out); - auto size = out->numel() * paddle::experimental::SizeOf(out->dtype()); - paddle::memory::Copy( - out->place(), out_data, tmp.place(), tmp_data, size, 0); - } + DenseTensor key; + RandintKernel(dev_ctx, + std::numeric_limits::min(), + std::numeric_limits::max(), + IntArray({n}), + phi::DataType::INT32, + &key); + DenseTensor key_out = Empty(dev_ctx, IntArray({n})); + + DenseTensor range = Empty(dev_ctx, IntArray({n})); + T* range_data = range.data(); + funcs::ForRange for_range(dev_ctx, n); + for_range([range_data] __device__(size_t idx) { + range_data[idx] = static_cast(idx); + }); + + out->Resize(phi::make_ddim({n})); + T* out_data = dev_ctx.template Alloc(out); + + // Refer to [Algorithm of randperm] https://osf.io/af2hy/ to + // improve performance of radix sort. + double n_d = static_cast(n); + int begin_bit = 0; + int end_bit = + std::ceil(std::log2(n_d - (6 * n_d * n_d + 1) / (12 * std::log(0.9)))); + + size_t temp_storage_bytes = 0; + cub::DeviceRadixSort::SortPairs(nullptr, + temp_storage_bytes, + key.data(), + key_out.data(), + range.data(), + out_data, + n, + begin_bit, + end_bit < 32 ? end_bit : 32, + dev_ctx.stream()); + + auto d_temp_storage = paddle::memory::Alloc(dev_ctx, temp_storage_bytes); + cub::DeviceRadixSort::SortPairs(d_temp_storage->ptr(), + temp_storage_bytes, + key.data(), + key_out.data(), + range.data(), + out_data, + n, + begin_bit, + end_bit < 32 ? end_bit : 32, + dev_ctx.stream()); + + auto gen_cuda = dev_ctx.GetGenerator(); + auto seed_offset = gen_cuda->IncrementOffset(n); + + auto config = phi::backends::gpu::GetGpuLaunchConfig1D(dev_ctx, n); + SwapRepeatKernel<<>>( + key_out.data(), out_data, n, seed_offset.first, seed_offset.second); } template diff --git a/paddle/phi/kernels/gpu/uniform_random_kernel.cu b/paddle/phi/kernels/gpu/uniform_random_kernel.cu index 2cabde0bbf9..a4aea10cfe7 100644 --- a/paddle/phi/kernels/gpu/uniform_random_kernel.cu +++ b/paddle/phi/kernels/gpu/uniform_random_kernel.cu @@ -14,14 +14,13 @@ #include "paddle/phi/kernels/uniform_random_kernel.h" +#include #include "gflags/gflags.h" #include "paddle/phi/core/kernel_registry.h" #include "paddle/phi/kernels/funcs/distribution_helper.h" #include "paddle/phi/kernels/funcs/index_impl.cu.h" -DECLARE_bool(use_curand); - namespace phi { template @@ -54,43 +53,6 @@ struct UniformGenerator { } }; -template -struct UniformGeneratorOffset { - T min_, max_; - unsigned int seed_; - T diag_val_; - unsigned int diag_num_; - unsigned int diag_step_; - int offset_; - __host__ __device__ UniformGeneratorOffset(T min, - T max, - int seed, - int diag_num, - int diag_step, - T diag_val, - int offset) - : min_(min), - max_(max), - seed_(seed), - diag_num_(diag_num), - diag_step_(diag_step), - diag_val_(diag_val), - offset_(offset) {} - - __host__ __device__ T operator()(const unsigned int n) const { - thrust::minstd_rand rng; - rng.seed(seed_); - thrust::uniform_real_distribution dist(min_, max_); - rng.discard(n + offset_); - T out = dist(rng); - unsigned int remainder = n % (diag_step_ + 1); - if (remainder == 0 && diag_num_ > n / (diag_step_ + 1)) { - out = diag_val_; - } - return out; - } -}; - template void UniformRandomRawKernel(const Context& dev_ctx, const IntArray& shape, @@ -114,23 +76,10 @@ void UniformRandomRawKernel(const Context& dev_ctx, auto generator = dev_ctx.GetGenerator(); if (generator->GetIsInitPy() && seed_flag) { - if (FLAGS_use_curand) { - using MT = typename kps::details::MPTypeTrait::Type; - funcs::uniform_distribution dist; - funcs::uniform_real_transform trans(min, max); - funcs::distribution_and_transform(dev_ctx, out, dist, trans); - } else { - auto seed_offset = generator->IncrementOffset(1); - int64_t gen_offset = size * seed_offset.second; - auto func = UniformGeneratorOffset(min, - max, - seed_offset.first, - diag_num, - diag_step, - diag_val, - gen_offset); - IndexKernel>(dev_ctx, out, func); - } + using MT = typename kps::details::MPTypeTrait::Type; + funcs::uniform_distribution dist; + funcs::uniform_real_transform trans(min, max); + funcs::distribution_and_transform(dev_ctx, out, dist, trans); } else { auto func = UniformGenerator(min, max, seed, diag_num, diag_step, diag_val); diff --git a/paddle/scripts/paddle_build.bat b/paddle/scripts/paddle_build.bat index cc55ea82df6..21df60e9721 100644 --- a/paddle/scripts/paddle_build.bat +++ b/paddle/scripts/paddle_build.bat @@ -657,7 +657,6 @@ for /F %%# in ('wmic os get localdatetime^|findstr 20') do set start=%%# set start=%start:~4,10% set FLAGS_call_stack_level=2 -set FLAGS_use_curand=True dir %THIRD_PARTY_PATH:/=\%\install\openblas\lib dir %THIRD_PARTY_PATH:/=\%\install\openblas\bin dir %THIRD_PARTY_PATH:/=\%\install\zlib\bin diff --git a/paddle/scripts/paddle_build.sh b/paddle/scripts/paddle_build.sh index d1220e45375..e8bde467e08 100755 --- a/paddle/scripts/paddle_build.sh +++ b/paddle/scripts/paddle_build.sh @@ -61,8 +61,6 @@ function init() { # NOTE(chenweihang): For easy debugging, CI displays the C++ error stacktrace by default export FLAGS_call_stack_level=2 - export FLAGS_use_curand=True - # set CI_SKIP_CPP_TEST if only *.py changed # In order to avoid using in some CI(such as daily performance), the current # branch must not be `${BRANCH}` which is usually develop. diff --git a/python/paddle/fluid/initializer.py b/python/paddle/fluid/initializer.py index 37eff6d132d..b3baedc4015 100644 --- a/python/paddle/fluid/initializer.py +++ b/python/paddle/fluid/initializer.py @@ -561,12 +561,12 @@ class XavierInitializer(Initializer): if framework._non_static_mode(): if self._uniform: - limit = np.sqrt(6.0 / float(fan_in + fan_out)) + limit = math.sqrt(6.0 / float(fan_in + fan_out)) out_var = _C_ops.uniform_random('shape', out_var.shape, 'min', -limit, 'max', limit, 'seed', self._seed, 'dtype', out_dtype) else: - std = np.sqrt(2.0 / float(fan_in + fan_out)) + std = math.sqrt(2.0 / float(fan_in + fan_out)) out_var = _C_ops.gaussian_random( 'shape', out_var.shape, 'dtype', out_dtype, 'mean', 0.0, 'std', std, 'seed', self._seed) @@ -581,7 +581,7 @@ class XavierInitializer(Initializer): return None else: if self._uniform: - limit = np.sqrt(6.0 / float(fan_in + fan_out)) + limit = math.sqrt(6.0 / float(fan_in + fan_out)) op = block.append_op( type="uniform_random", inputs={}, @@ -595,7 +595,7 @@ class XavierInitializer(Initializer): }, stop_gradient=True) else: - std = np.sqrt(2.0 / float(fan_in + fan_out)) + std = math.sqrt(2.0 / float(fan_in + fan_out)) op = block.append_op( type="gaussian_random", outputs={"Out": out_var}, @@ -713,13 +713,13 @@ class MSRAInitializer(Initializer): if framework._non_static_mode(): if self._uniform: - limit = np.sqrt(6.0 / float(fan_in)) + limit = math.sqrt(6.0 / float(fan_in)) out_var = _C_ops.uniform_random('shape', out_var.shape, 'min', -limit, 'max', limit, 'seed', self._seed, 'dtype', int(out_dtype)) else: - std = np.sqrt(2.0 / float(fan_in)) + std = math.sqrt(2.0 / float(fan_in)) out_var = _C_ops.gaussian_random( 'shape', out_var.shape, 'dtype', int(out_dtype), 'mean', 0.0, 'std', std, 'seed', self._seed) @@ -734,7 +734,7 @@ class MSRAInitializer(Initializer): return None else: if self._uniform: - limit = np.sqrt(6.0 / float(fan_in)) + limit = math.sqrt(6.0 / float(fan_in)) op = block.append_op( type="uniform_random", inputs={}, @@ -749,7 +749,7 @@ class MSRAInitializer(Initializer): stop_gradient=True) else: - std = np.sqrt(2.0 / float(fan_in)) + std = math.sqrt(2.0 / float(fan_in)) op = block.append_op( type="gaussian_random", outputs={"Out": out_var}, diff --git a/python/paddle/fluid/tests/unittests/test_bernoulli_op.py b/python/paddle/fluid/tests/unittests/test_bernoulli_op.py index 426d5d463f4..fc4ee13384b 100644 --- a/python/paddle/fluid/tests/unittests/test_bernoulli_op.py +++ b/python/paddle/fluid/tests/unittests/test_bernoulli_op.py @@ -75,9 +75,6 @@ class TestRandomValue(unittest.TestCase): if not paddle.is_compiled_with_cuda(): return - if os.getenv("FLAGS_use_curand", None) in ('0', 'False', None): - return - print("Test Fixed Random number on GPU------>") paddle.disable_static() paddle.set_device('gpu') diff --git a/python/paddle/fluid/tests/unittests/test_dropout_op.py b/python/paddle/fluid/tests/unittests/test_dropout_op.py index d8a4eb8f45f..3aca428ac77 100644 --- a/python/paddle/fluid/tests/unittests/test_dropout_op.py +++ b/python/paddle/fluid/tests/unittests/test_dropout_op.py @@ -1034,9 +1034,6 @@ class TestRandomValue(unittest.TestCase): if not "V100" in paddle.device.cuda.get_device_name(): return - if os.getenv("FLAGS_use_curand", None) in ('0', 'False', None): - return - print("Test Fixed Random number on V100 GPU------>") paddle.disable_static() paddle.set_device('gpu') diff --git a/python/paddle/fluid/tests/unittests/test_exponential_op.py b/python/paddle/fluid/tests/unittests/test_exponential_op.py index 7a3ae203be6..c8f4101ea5d 100644 --- a/python/paddle/fluid/tests/unittests/test_exponential_op.py +++ b/python/paddle/fluid/tests/unittests/test_exponential_op.py @@ -100,9 +100,6 @@ class TestExponentialAPI(unittest.TestCase): if not "V100" in paddle.device.cuda.get_device_name(): return - if os.getenv("FLAGS_use_curand", None) in ('0', 'False', None): - return - print("Test Fixed Random number on V100 GPU------>") paddle.disable_static() paddle.set_device('gpu') diff --git a/python/paddle/fluid/tests/unittests/test_gaussian_random_op.py b/python/paddle/fluid/tests/unittests/test_gaussian_random_op.py index 738441a46d3..4fca8b9f2a1 100644 --- a/python/paddle/fluid/tests/unittests/test_gaussian_random_op.py +++ b/python/paddle/fluid/tests/unittests/test_gaussian_random_op.py @@ -342,9 +342,6 @@ class TestRandomValue(unittest.TestCase): if not "V100" in paddle.device.cuda.get_device_name(): return - if os.getenv("FLAGS_use_curand", None) in ('0', 'False', None): - return - def _check_random_value(dtype, expect, expect_mean, expect_std): x = paddle.randn([32, 3, 1024, 1024], dtype=dtype) actual = x.numpy() diff --git a/python/paddle/fluid/tests/unittests/test_linear.py b/python/paddle/fluid/tests/unittests/test_linear.py index 9d07a80da15..6b00a86e3e9 100644 --- a/python/paddle/fluid/tests/unittests/test_linear.py +++ b/python/paddle/fluid/tests/unittests/test_linear.py @@ -73,6 +73,22 @@ class LinearTestCase(unittest.TestCase): np.testing.assert_array_almost_equal(res_f, res_nn) np.testing.assert_array_almost_equal(res_nn, res_np) + def test_weight_init(self): + if not paddle.is_compiled_with_cuda(): + return + paddle.seed(100) + linear = paddle.nn.Linear( + 2, 3, weight_attr=paddle.nn.initializer.Normal(0, 1.)) + paddle.nn.utils._stride_column(linear.weight) + expect = [[1.4349908, -0.8099171, -2.64788], + [-1.4981681, -1.1784115, -0.023253186]] + self.assertTrue(np.allclose(linear.weight.numpy(), expect)) + + linear = paddle.nn.Linear(2, 3) + expect = [[0.73261100, 0.43836895, 0.07908206], + [0.85075015, -1.04724526, 0.64371765]] + self.assertTrue(np.allclose(linear.weight.numpy(), expect)) + if __name__ == "__main__": unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_multinomial_op.py b/python/paddle/fluid/tests/unittests/test_multinomial_op.py index a65a1c7e14c..ecde527523d 100644 --- a/python/paddle/fluid/tests/unittests/test_multinomial_op.py +++ b/python/paddle/fluid/tests/unittests/test_multinomial_op.py @@ -227,9 +227,6 @@ class TestRandomValue(unittest.TestCase): if not "V100" in paddle.device.cuda.get_device_name(): return - if os.getenv("FLAGS_use_curand", None) in ('0', 'False', None): - return - print("Test Fixed Random number on V100 GPU------>") paddle.disable_static() paddle.set_device('gpu') diff --git a/python/paddle/fluid/tests/unittests/test_poisson_op.py b/python/paddle/fluid/tests/unittests/test_poisson_op.py index 2123d4e0e7e..f8183bb5f8d 100644 --- a/python/paddle/fluid/tests/unittests/test_poisson_op.py +++ b/python/paddle/fluid/tests/unittests/test_poisson_op.py @@ -107,9 +107,6 @@ class TestPoissonAPI(unittest.TestCase): if not paddle.is_compiled_with_cuda(): return - if os.getenv("FLAGS_use_curand", None) in ('0', 'False', None): - return - print("Test Fixed Random number on GPU------>") paddle.disable_static() paddle.set_device('gpu') diff --git a/python/paddle/fluid/tests/unittests/test_randint_op.py b/python/paddle/fluid/tests/unittests/test_randint_op.py index 1eb99e08bb8..361f4d280f7 100644 --- a/python/paddle/fluid/tests/unittests/test_randint_op.py +++ b/python/paddle/fluid/tests/unittests/test_randint_op.py @@ -198,9 +198,6 @@ class TestRandomValue(unittest.TestCase): if not "V100" in paddle.device.cuda.get_device_name(): return - if os.getenv("FLAGS_use_curand", None) in ('0', 'False', None): - return - print("Test Fixed Random number on GPU------>") paddle.disable_static() diff --git a/python/paddle/fluid/tests/unittests/test_randperm_op.py b/python/paddle/fluid/tests/unittests/test_randperm_op.py index 5c9ab36fa34..deb0a9a0821 100644 --- a/python/paddle/fluid/tests/unittests/test_randperm_op.py +++ b/python/paddle/fluid/tests/unittests/test_randperm_op.py @@ -155,9 +155,6 @@ class TestRandomValue(unittest.TestCase): if not paddle.is_compiled_with_cuda(): return - if os.getenv("FLAGS_use_curand", None) in ('0', 'False', None): - return - print("Test Fixed Random number on GPU------>") paddle.disable_static() paddle.set_device('gpu') diff --git a/python/paddle/fluid/tests/unittests/test_uniform_random_op.py b/python/paddle/fluid/tests/unittests/test_uniform_random_op.py index 41b6ed36d65..683cc2fdf86 100644 --- a/python/paddle/fluid/tests/unittests/test_uniform_random_op.py +++ b/python/paddle/fluid/tests/unittests/test_uniform_random_op.py @@ -573,37 +573,46 @@ class TestRandomValue(unittest.TestCase): if not "V100" in paddle.device.cuda.get_device_name(): return - if os.getenv("FLAGS_use_curand", None) in ('0', 'False', None): - return - - def _check_random_value(dtype, expect, expect_mean, expect_std): - x = paddle.rand([32, 3, 1024, 1024], dtype=dtype) - actual = x.numpy() - self.assertTrue(np.allclose(actual[2, 1, 512, 1000:1010], expect)) - self.assertEqual(np.mean(actual), expect_mean) - self.assertEqual(np.std(actual), expect_std) - print("Test Fixed Random number on V100 GPU------>") paddle.disable_static() + paddle.set_device('gpu') paddle.seed(2021) + + expect_mean = 0.50000454338820143895816272561205551028251647949218750 + expect_std = 0.28867379167297479991560749112977646291255950927734375 expect = [ 0.55298901, 0.65184678, 0.49375412, 0.57943639, 0.16459608, 0.67181056, 0.03021481, 0.0238559, 0.07742096, 0.55972187 ] - expect_mean = 0.50000454338820143895816272561205551028251647949218750 - expect_std = 0.28867379167297479991560749112977646291255950927734375 - _check_random_value(core.VarDesc.VarType.FP64, expect, expect_mean, - expect_std) + out = paddle.rand([32, 3, 1024, 1024], dtype='float64').numpy() + self.assertEqual(np.mean(out), expect_mean) + self.assertEqual(np.std(out), expect_std) + self.assertTrue(np.allclose(out[2, 1, 512, 1000:1010], expect)) + expect_mean = 0.50002604722976684570312500 + expect_std = 0.2886914908885955810546875 expect = [ 0.45320973, 0.17582087, 0.725341, 0.30849215, 0.622257, 0.46352342, 0.97228295, 0.12771158, 0.286525, 0.9810645 ] - expect_mean = 0.50002604722976684570312500 - expect_std = 0.2886914908885955810546875 - _check_random_value(core.VarDesc.VarType.FP32, expect, expect_mean, - expect_std) + out = paddle.rand([32, 3, 1024, 1024], dtype='float32').numpy() + self.assertEqual(np.mean(out), expect_mean) + self.assertEqual(np.std(out), expect_std) + self.assertTrue(np.allclose(out[2, 1, 512, 1000:1010], expect)) + + expect_mean = 25.11843109130859375 + expect_std = 43.370647430419921875 + expect = [ + 30.089634, 77.05225, 3.1201615, 68.34072, 59.266724, -25.33281, + 12.973292, 27.41127, -17.412298, 27.931019 + ] + out = paddle.empty( + [16, 16, 16, 16], dtype='float32').uniform_(-50, 100).numpy() + self.assertEqual(np.mean(out), expect_mean) + self.assertEqual(np.std(out), expect_std) + self.assertTrue(np.allclose(out[10, 10, 10, 0:10], expect)) + paddle.enable_static() diff --git a/python/paddle/nn/utils/__init__.py b/python/paddle/nn/utils/__init__.py index 8f9b55d15ca..8ec4e8cfd60 100644 --- a/python/paddle/nn/utils/__init__.py +++ b/python/paddle/nn/utils/__init__.py @@ -14,7 +14,7 @@ from .spectral_norm_hook import spectral_norm from .weight_norm_hook import weight_norm, remove_weight_norm # noqa: F401 -from .transform_parameters import parameters_to_vector, vector_to_parameters # noqa: F401 +from .transform_parameters import parameters_to_vector, vector_to_parameters, _stride_column # noqa: F401 __all__ = [ #noqa 'weight_norm', 'remove_weight_norm', 'spectral_norm', 'parameters_to_vector', 'vector_to_parameters' diff --git a/python/paddle/nn/utils/transform_parameters.py b/python/paddle/nn/utils/transform_parameters.py index 99870ce29a1..feb70e02d59 100644 --- a/python/paddle/nn/utils/transform_parameters.py +++ b/python/paddle/nn/utils/transform_parameters.py @@ -36,6 +36,39 @@ def _inplace_reshape_dygraph(x, shape): stop_gradient=True) +@dygraph_only +def _stride_column(param): + """ + A tool function. Permute date of parameter as a 'columns' stride. Now, it only support 2-D parameter. + + Args: + param(Tensor]): The param that will be strided according to 'columns'. + + Examples: + .. code-block:: python + + import paddle + paddle.seed(100) + + linear = paddle.nn.Linear(2, 3) + print(linear.weight) + # [[-0.31485492, -1.02896988, 0.45741916], + # [-0.65525872, -1.04643178, 1.07262802]] + + paddle.nn.utils.stride_column(linear.weight) + print(linear.weight) + # [[-0.31485492, 0.45741916, -1.04643178], + # [-1.02896988, -0.65525872, 1.07262802]] + + """ + assert len(param.shape) == 2 + shape = [param.shape[1], param.shape[0]] + with paddle.fluid.dygraph.no_grad(): + reshape_var = paddle.reshape(param, shape) + transpose_var = paddle.transpose(reshape_var, [1, 0]) + transpose_var._share_underline_tensor_to(param) + + @dygraph_only def parameters_to_vector(parameters, name=None): """ -- GitLab