diff --git a/paddle/fluid/operators/dropout_impl.cu.h b/paddle/fluid/operators/dropout_impl.cu.h index 83ca9ace20d0541577cab52befa9d359c3f89d21..6af8c925ff580292438159e52eff884d7ac10232 100644 --- a/paddle/fluid/operators/dropout_impl.cu.h +++ b/paddle/fluid/operators/dropout_impl.cu.h @@ -38,43 +38,9 @@ limitations under the License. */ #include "paddle/phi/kernels/funcs/distribution_helper.h" #include "paddle/phi/kernels/funcs/functors.h" -DECLARE_bool(use_curand); - namespace paddle { namespace operators { -template -struct DstMaskGenerator { - const float dropout_prob_; - const bool is_upscale_in_train_; - using MT = typename details::MPTypeTrait::Type; - MT factor; - HOSTDEVICE inline DstMaskGenerator(const float dropout_prob, - const bool is_upscale_in_train) - : dropout_prob_(dropout_prob), is_upscale_in_train_(is_upscale_in_train) { - factor = static_cast(1.0f / (1.0f - dropout_prob_)); - } - - HOSTDEVICE inline void operator()(OutT* dst, const T1* src_val, - const T2* rand, int num) const { - static constexpr int kCount = - phi::funcs::uniform_distribution::kReturnsCount; -// 0 ~ kCount -1 is dist , kCount ~ 2 * kCount - 1 is mask -#pragma unroll - for (int i = 0; i < kCount; i++) { - if (rand[i] < dropout_prob_) { - dst[i] = static_cast(0); - dst[i + kCount] = dst[i]; - } else { - dst[i] = is_upscale_in_train_ - ? static_cast(static_cast(src_val[i]) * factor) - : static_cast(src_val[i]); - dst[i + kCount] = static_cast(1); - } - } - } -}; - template struct DstMaskFunctor { const float retain_prob_; @@ -113,7 +79,7 @@ __global__ void VectorizedRandomGenerator(const size_t n, uint64_t seed, const T* src, MaskType* mask, T* dst, bool is_upscale_in_train, uint64_t increment, - size_t main_offset, bool use_curand) { + size_t main_offset) { size_t idx = static_cast(BLOCK_ID_X * BLOCK_NUM_X); static constexpr int kCount = phi::funcs::uniform_distribution::kReturnsCount; @@ -135,76 +101,41 @@ __global__ void VectorizedRandomGenerator(const size_t n, uint64_t seed, int deal_size = BLOCK_NUM_X * kCount; size_t fix = idx * kCount; - if (use_curand) { - auto dst_functor = - DstMaskFunctor(1.0f - dropout_prob, is_upscale_in_train); - for (; fix < main_offset; fix += stride) { - kps::ReadData(&dst_mask[0], src + fix, deal_size); - kps::ElementwiseRandom(&rands[0], Rand(), - &state); - // dst - kps::OperatorTernary>( - &dst_mask[0], &dst_mask[0], &rands[0], dst_functor, kCount); - kps::WriteData(dst + fix, &dst_mask[0], - deal_size); - // mask - kps::ElementwiseUnary( - &mask_result[0], &dst_mask[kCount], Cast()); - kps::WriteData(mask + fix, &mask_result[0], - deal_size); - if (fix > idx * kCount + 1) { - __syncthreads(); - } - } - int remainder = n - fix; - if (remainder > 0) { - kps::ReadData(&dst_mask[0], src + fix, remainder); - kps::ElementwiseRandom(&rands[0], Rand(), - &state); - // dst - kps::OperatorTernary>( - &dst_mask[0], &dst_mask[0], &rands[0], dst_functor, kCount); - kps::WriteData(dst + fix, &dst_mask[0], remainder); - // mask - kps::ElementwiseUnary( - &mask_result[0], &dst_mask[kCount], Cast()); - kps::WriteData(mask + fix, &mask_result[0], - remainder); + + auto dst_functor = + DstMaskFunctor(1.0f - dropout_prob, is_upscale_in_train); + for (; fix < main_offset; fix += stride) { + kps::ReadData(&dst_mask[0], src + fix, deal_size); + kps::ElementwiseRandom(&rands[0], Rand(), + &state); + // dst + kps::OperatorTernary>( + &dst_mask[0], &dst_mask[0], &rands[0], dst_functor, kCount); + kps::WriteData(dst + fix, &dst_mask[0], deal_size); + // mask + kps::ElementwiseUnary( + &mask_result[0], &dst_mask[kCount], Cast()); + kps::WriteData(mask + fix, &mask_result[0], + deal_size); + if (fix > idx * kCount + 1) { __syncthreads(); } - } else { - auto dst_functor = - DstMaskGenerator(dropout_prob, is_upscale_in_train); - for (; fix < main_offset; fix += stride) { - kps::ReadData(&dst_mask[0], src + fix, deal_size); - kps::ElementwiseRandom(&rands[0], Rand(), - &state); - // dst - kps::OperatorTernary>( - &dst_mask[0], &dst_mask[0], &rands[0], dst_functor, kCount); - kps::WriteData(dst + fix, &dst_mask[0], - deal_size); - // mask - kps::ElementwiseUnary( - &mask_result[0], &dst_mask[kCount], Cast()); - kps::WriteData(mask + fix, &mask_result[0], - deal_size); - } - int remainder = n - fix; - if (remainder > 0) { - kps::ReadData(&dst_mask[0], src + fix, remainder); - kps::ElementwiseRandom(&rands[0], Rand(), - &state); - // dst - kps::OperatorTernary>( - &dst_mask[0], &dst_mask[0], &rands[0], dst_functor, kCount); - kps::WriteData(dst + fix, &dst_mask[0], remainder); - // mask - kps::ElementwiseUnary( - &mask_result[0], &dst_mask[kCount], Cast()); - kps::WriteData(mask + fix, &mask_result[0], - remainder); - } + } + int remainder = n - fix; + if (remainder > 0) { + kps::ReadData(&dst_mask[0], src + fix, remainder); + kps::ElementwiseRandom(&rands[0], Rand(), + &state); + // dst + kps::OperatorTernary>( + &dst_mask[0], &dst_mask[0], &rands[0], dst_functor, kCount); + kps::WriteData(dst + fix, &dst_mask[0], remainder); + // mask + kps::ElementwiseUnary( + &mask_result[0], &dst_mask[kCount], Cast()); + kps::WriteData(mask + fix, &mask_result[0], + remainder); + __syncthreads(); } } @@ -251,13 +182,11 @@ void DropoutFwGPUKernelDriver(const phi::GPUContext& dev_ctx, bool is_test, size_t grid_size = gpu_config.GetGridSize(); size_t block_size = gpu_config.GetBlockSize(); - if (FLAGS_use_curand) { - int64_t device_id = dev_ctx.GetPlace().GetDeviceId(); - const auto& prop = platform::GetDeviceProperties(device_id); - size_t max_grid_size = prop.maxThreadsPerMultiProcessor * - prop.multiProcessorCount / block_size; - grid_size = std::min(grid_size, max_grid_size); - } + int64_t device_id = dev_ctx.GetPlace().GetDeviceId(); + const auto& prop = platform::GetDeviceProperties(device_id); + size_t max_grid_size = prop.maxThreadsPerMultiProcessor * + prop.multiProcessorCount / block_size; + grid_size = std::min(grid_size, max_grid_size); auto offset = ((x_numel - 1) / (grid_size * block_size * kVecSize) + 1) * kVecSize; @@ -268,7 +197,7 @@ void DropoutFwGPUKernelDriver(const phi::GPUContext& dev_ctx, bool is_test, VectorizedRandomGenerator<<>>( size, seed_data, dropout_prob, x_data, mask_data, y_data, - upscale_in_train, increment, main_offset, FLAGS_use_curand); + upscale_in_train, increment, main_offset); } else { if (upscale_in_train) { // todo: can y share with data with x directly? diff --git a/paddle/fluid/operators/gaussian_random_op.cu b/paddle/fluid/operators/gaussian_random_op.cu index 00ce10bfe3bccb404bce9f681ee3c7030e0fa4c4..552649279e9118372faa56b931fe8196c31c03d3 100644 --- a/paddle/fluid/operators/gaussian_random_op.cu +++ b/paddle/fluid/operators/gaussian_random_op.cu @@ -11,21 +11,14 @@ distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#include -#include #include -#include #include "paddle/fluid/framework/generator.h" #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/operator.h" #include "paddle/fluid/operators/amp/fp16_type_traits.h" #include "paddle/fluid/operators/fill_constant_op.h" - -#include "paddle/phi/kernels/funcs/distribution_helper.h" #include "paddle/phi/kernels/funcs/index_impl.cu.h" -DECLARE_bool(use_curand); - namespace paddle { namespace operators { diff --git a/paddle/fluid/operators/uniform_random_op.h b/paddle/fluid/operators/uniform_random_op.h index b941dc21c3ab213e5abc2c4c908413b2b6222c41..ae846f4cae6fba7314b2c046e01bfc69220349af 100644 --- a/paddle/fluid/operators/uniform_random_op.h +++ b/paddle/fluid/operators/uniform_random_op.h @@ -19,11 +19,7 @@ #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/operator.h" #if defined(__NVCC__) || defined(__HIPCC__) -DECLARE_bool(use_curand); -#include -#include #include -#include #include "paddle/fluid/framework/generator.h" #include "paddle/phi/kernels/full_kernel.h" #include "paddle/phi/kernels/funcs/distribution_helper.h" @@ -146,39 +142,6 @@ struct UniformGenerator { } }; -template -struct UniformGeneratorOffset { - T min_, max_; - unsigned int seed_; - T diag_val_; - unsigned int diag_num_; - unsigned int diag_step_; - int offset_; - __host__ __device__ UniformGeneratorOffset(T min, T max, int seed, - int diag_num, int diag_step, - T diag_val, int offset) - : min_(min), - max_(max), - seed_(seed), - diag_num_(diag_num), - diag_step_(diag_step), - diag_val_(diag_val), - offset_(offset) {} - - __host__ __device__ T operator()(const unsigned int n) const { - thrust::minstd_rand rng; - rng.seed(seed_); - thrust::uniform_real_distribution dist(min_, max_); - rng.discard(n + offset_); - T out = dist(rng); - unsigned int remainder = n % (diag_step_ + 1); - if (remainder == 0 && diag_num_ > n / (diag_step_ + 1)) { - out = diag_val_; - } - return out; - } -}; - template void UniformRandom(const framework::ExecutionContext& context, framework::Tensor* tensor) { @@ -205,19 +168,10 @@ void UniformRandom(const framework::ExecutionContext& context, int device_id = context.GetPlace().GetDeviceId(); auto gen_cuda = framework::GetDefaultCUDAGenerator(device_id); if (gen_cuda->GetIsInitPy() && seed_flag) { - if (FLAGS_use_curand) { - using MT = typename details::MPTypeTrait::Type; - phi::funcs::uniform_distribution dist; - phi::funcs::uniform_real_transform trans(min, max); - phi::funcs::distribution_and_transform(dev_cxt, tensor, dist, trans); - } else { - auto seed_offset = gen_cuda->IncrementOffset(1); - int64_t gen_offset = size * seed_offset.second; - auto func = - UniformGeneratorOffset(min, max, seed_offset.first, diag_num, - diag_step, diag_val, gen_offset); - phi::IndexKernel>(dev_cxt, tensor, func); - } + using MT = typename details::MPTypeTrait::Type; + phi::funcs::uniform_distribution dist; + phi::funcs::uniform_real_transform trans(min, max); + phi::funcs::distribution_and_transform(dev_cxt, tensor, dist, trans); } else { auto func = UniformGenerator(min, max, seed, diag_num, diag_step, diag_val); diff --git a/paddle/fluid/platform/flags.cc b/paddle/fluid/platform/flags.cc index 4e47c130c7252f6f7f6c8f9e0e993022e99d7686..c3d3f6a4f6893e9bbf49adefe54ea21f9159222f 100644 --- a/paddle/fluid/platform/flags.cc +++ b/paddle/fluid/platform/flags.cc @@ -545,8 +545,6 @@ PADDLE_DEFINE_EXPORTED_double( */ PADDLE_DEFINE_EXPORTED_bool(use_mkldnn, false, "Use MKLDNN to run"); -PADDLE_DEFINE_EXPORTED_bool(use_curand, false, "Random OP use CURAND"); - /** * Debug related FLAG * Name: FLAGS_call_stack_level diff --git a/paddle/phi/kernels/cpu/transpose_kernel.cc b/paddle/phi/kernels/cpu/transpose_kernel.cc index a80196e7f80e1b68d55265a1d5061887f12ab6bb..5dc4866e1efc33bcd9a680dfe2eb2804e28a7588 100644 --- a/paddle/phi/kernels/cpu/transpose_kernel.cc +++ b/paddle/phi/kernels/cpu/transpose_kernel.cc @@ -75,6 +75,7 @@ PD_REGISTER_KERNEL(transpose, double, int32_t, int64_t, + phi::dtype::float16, phi::dtype::bfloat16, phi::dtype::complex, phi::dtype::complex) {} diff --git a/paddle/phi/kernels/gpu/bernoulli_kernel.cu b/paddle/phi/kernels/gpu/bernoulli_kernel.cu index 79d8a7b0f3444b4272d1affd67bd5ac943f2c1cc..edcf29e2d88d387b4d1658760fc60d5ceb2954b0 100644 --- a/paddle/phi/kernels/gpu/bernoulli_kernel.cu +++ b/paddle/phi/kernels/gpu/bernoulli_kernel.cu @@ -14,8 +14,6 @@ #include "paddle/phi/kernels/bernoulli_kernel.h" -#include -#include #ifdef __NVCC__ #include #endif @@ -32,35 +30,8 @@ #include "paddle/phi/core/kernel_registry.h" #include "paddle/phi/kernels/funcs/distribution_helper.h" -// See Note [ Why still include the fluid headers? ] -#include "paddle/fluid/platform/transform.h" - -DECLARE_bool(use_curand); - namespace phi { -template -struct BernoulliCudaFunctor { - unsigned int seed_; - unsigned int offset_; - __host__ __device__ BernoulliCudaFunctor(unsigned int seed, - unsigned int offset) - : seed_(seed), offset_(offset) {} - - __host__ __device__ T operator()(const unsigned int n, const T p) const { - // NOTE(zhiqiu): currently, PADDLE_ENFORCE in cuda kernel may print several - // lines of error messages if, and it should be refined. - PADDLE_ENFORCE(p >= 0.0 && p <= 1.0, - "The probability should be >=0 and <= 1, but got %f", - p); - thrust::minstd_rand rng; - rng.seed(seed_); - thrust::uniform_real_distribution dist(0.0, 1.0); - rng.discard(n + offset_); - return static_cast(dist(rng) < p); - } -}; - // 'curand_uniform4/hiprand_uniform4' generate 4 random number each time template __global__ void bernoulli_cuda_kernel( @@ -100,30 +71,16 @@ void BernoulliKernel(const Context& ctx, auto gen_cuda = ctx.GetGenerator(); - if (FLAGS_use_curand) { - auto seed_offset = gen_cuda->IncrementOffset(12); - uint64_t seed = seed_offset.first; - uint64_t offset = seed_offset.second; + auto seed_offset = gen_cuda->IncrementOffset(12); + uint64_t seed = seed_offset.first; + uint64_t offset = seed_offset.second; - auto gpu_config = phi::backends::gpu::GetGpuLaunchConfig1D(ctx, numel, 4); - size_t grid_size = gpu_config.GetGridSize(); - size_t block_size = gpu_config.GetBlockSize(); + auto gpu_config = phi::backends::gpu::GetGpuLaunchConfig1D(ctx, numel, 4); + size_t grid_size = gpu_config.GetGridSize(); + size_t block_size = gpu_config.GetBlockSize(); - bernoulli_cuda_kernel<<>>( - numel, seed, offset, x_data, out_data); - } else { - auto seed_offset = gen_cuda->IncrementOffset(1); - int64_t gen_offset = numel * seed_offset.second; - paddle::platform::Transform trans; - thrust::counting_iterator index_sequence_begin(0); - trans(ctx, - index_sequence_begin, - index_sequence_begin + numel, - x_data, - out_data, - BernoulliCudaFunctor(static_cast(seed_offset.first), - static_cast(gen_offset))); - } + bernoulli_cuda_kernel<<>>( + numel, seed, offset, x_data, out_data); } } // namespace phi diff --git a/paddle/phi/kernels/gpu/gaussian_random_kernel.cu b/paddle/phi/kernels/gpu/gaussian_random_kernel.cu index e159e5916cff2b602e0638e3fa01d1126598381d..96ebc0353ef2453308d6b9e371b6b640e8ab7b28 100644 --- a/paddle/phi/kernels/gpu/gaussian_random_kernel.cu +++ b/paddle/phi/kernels/gpu/gaussian_random_kernel.cu @@ -14,10 +14,7 @@ #include "paddle/phi/kernels/gaussian_random_kernel.h" -#include -#include #include -#include #include "paddle/phi/backends/gpu/gpu_context.h" #include "paddle/phi/common/amp_type_traits.h" #include "paddle/phi/core/dense_tensor.h" @@ -27,8 +24,6 @@ #include "paddle/fluid/framework/generator.h" -DECLARE_bool(use_curand); - namespace phi { template @@ -83,21 +78,11 @@ void GaussianRandomKernel(const Context& dev_ctx, auto gen_cuda = paddle::framework::GetDefaultCUDAGenerator(device_id); if (gen_cuda->GetIsInitPy() && seed_flag) { - if (FLAGS_use_curand) { - using MT = typename phi::dtype::MPTypeTrait::Type; - funcs::normal_distribution dist; - funcs::normal_transform trans(static_cast(mean), - static_cast(std)); - funcs::distribution_and_transform(dev_ctx, tensor, dist, trans); - } else { - auto seed_offset = gen_cuda->IncrementOffset(1); - int64_t gen_offset = size * seed_offset.second; - auto func = GaussianGenerator(static_cast(mean), - static_cast(std), - seed_offset.first, - gen_offset); - IndexKernel>(dev_ctx, tensor, func); - } + using MT = typename phi::dtype::MPTypeTrait::Type; + funcs::normal_distribution dist; + funcs::normal_transform trans(static_cast(mean), + static_cast(std)); + funcs::distribution_and_transform(dev_ctx, tensor, dist, trans); } else { auto func = GaussianGenerator(static_cast(mean), static_cast(std), seed); diff --git a/paddle/phi/kernels/gpu/multinomial_kernel.cu b/paddle/phi/kernels/gpu/multinomial_kernel.cu index ee5f843b18a90aa9b2750d0e4beec2fefd462ebf..ef6cd1323a9df832e0051d78d522cab8eb00121c 100644 --- a/paddle/phi/kernels/gpu/multinomial_kernel.cu +++ b/paddle/phi/kernels/gpu/multinomial_kernel.cu @@ -18,11 +18,6 @@ limitations under the License. */ #include "paddle/phi/kernels/multinomial_kernel.h" -#include -#include -#include -#include - #ifdef __NVCC__ #include "cub/cub.cuh" #endif @@ -44,12 +39,6 @@ namespace cub = hipcub; #include "paddle/phi/kernels/funcs/multinomial_functor.h" #include "paddle/phi/kernels/top_k_kernel.h" -// See Note [ Why still include the fluid headers? ] -#include "paddle/fluid/memory/memcpy.h" -#include "paddle/fluid/platform/transform.h" - -DECLARE_bool(use_curand); - namespace phi { template @@ -74,32 +63,6 @@ __global__ void NormalizeProbability(T* norm_probs, } } -template -__global__ void GetCumulativeProbs(T* norm_probs_data, - int64_t num_distributions, - int64_t num_categories, - T* cumulative_probs_data) { - int id = blockIdx.x; - thrust::inclusive_scan(thrust::device, - norm_probs_data + id * num_categories, - norm_probs_data + (id + 1) * num_categories, - cumulative_probs_data + id * num_categories); -} - -template -struct RandomGeneratorCudaFunctor { - unsigned int seed_; - __host__ __device__ RandomGeneratorCudaFunctor(int seed) : seed_(seed) {} - - __host__ __device__ T operator()(const unsigned int n) const { - thrust::minstd_rand rng; - rng.seed(seed_); - thrust::uniform_real_distribution dist(0.0, 1.0); - rng.discard(n); - return dist(rng); - } -}; - template __device__ int binarySearchFunctor(T* cumulative_probs_data, T* norm_probs_data, @@ -130,7 +93,6 @@ __device__ int binarySearchFunctor(T* cumulative_probs_data, template __global__ void sampleMultinomialWithReplacement( - T* rng_data, const int64_t num_samples, int64_t* out_data, const int64_t num_distributions, @@ -138,10 +100,9 @@ __global__ void sampleMultinomialWithReplacement( T* cumulative_probs_data, T* norm_probs_data, uint64_t seed, - uint64_t offset, - bool use_curand) { + uint64_t offset) { // use binary search to get the selected category sample id. - // let cumulative_probs_data[id-1] < rng_data < cumulative_probs_data[id]. + // let cumulative_probs_data[id-1] < rng_number < cumulative_probs_data[id]. size_t idx = gridDim.x * blockDim.x * blockIdx.y + blockDim.x * blockIdx.x + threadIdx.x; @@ -151,10 +112,7 @@ __global__ void sampleMultinomialWithReplacement( int sample = blockIdx.x * blockDim.x + threadIdx.x; for (int dist = blockIdx.y; dist < num_distributions; dist += gridDim.y) { if (sample < num_samples) { - T rng_number = rng_data[sample + dist * num_samples]; - if (use_curand) { - rng_number = static_cast(curand_uniform4(&state).x); - } + T rng_number = static_cast(curand_uniform4(&state).x); // Find the bucket that a uniform random number lies in int selected_category = binarySearchFunctor(cumulative_probs_data + dist * num_categories, @@ -182,10 +140,7 @@ void MultinomialKernel(const Context& dev_ctx, const int64_t num_distributions = in_rank > 1 ? in_dims[in_rank - 2] : 1; // If replacement is False, it's not a replaceable sample. Every category - // can - // be used only once. So after every sample, probability of the distribution - // will change. The implementation can't be parallelizable. Thus, call CPU - // implementation ``funcs::MultinomialFunctor`` to sample the distribution. + // can be used only once. if (!replacement) { int64_t in_data_numel = x.numel(); int64_t out_data_numel = out->numel(); @@ -202,76 +157,50 @@ void MultinomialKernel(const Context& dev_ctx, in_data_numel * sizeof(T), cudaMemcpyDeviceToHost); #endif - if (FLAGS_use_curand) { - for (size_t i = 0; i < num_distributions; ++i) { - int zero_num = 0; - for (size_t j = 0; j < num_categories; ++j) { - T weight = cpu_in_data[i * num_distributions + j]; - PADDLE_ENFORCE_GE( - weight, - 0, - errors::InvalidArgument( - "Each element of multinomial'input must >= 0, but got %f.", - weight)); - if (weight == static_cast(0)) { - zero_num++; - } + for (size_t i = 0; i < num_distributions; ++i) { + int zero_num = 0; + for (size_t j = 0; j < num_categories; ++j) { + T weight = cpu_in_data[i * num_distributions + j]; + PADDLE_ENFORCE_GE( + weight, + 0, + errors::InvalidArgument( + "Each element of multinomial'input must >= 0, but got %f.", + weight)); + if (weight == static_cast(0)) { + zero_num++; } - int valid_samples = num_categories - zero_num; - PADDLE_ENFORCE_LE( - num_samples, - valid_samples, - errors::InvalidArgument("When replacement=False, 'num_samples' " - "must less than or eaqual to the number of " - "positive item of input")); } - - // Refer to [gumbel softmax algorithm] - DenseTensor rand = EmptyLike(dev_ctx, x); - T* rand_data = rand.data(); - funcs::uniform_distribution dist; - funcs::exponential_transform trans(1.0); - funcs::distribution_and_transform(dev_ctx, &rand, dist, trans); - - funcs::ForRange for_range(dev_ctx, x.numel()); - for_range([rand_data, in_data] __device__(size_t idx) { - rand_data[idx] = in_data[idx] / rand_data[idx]; - }); - - if (num_samples == 1) { - ArgMaxKernel( - dev_ctx, rand, -1, true, false, 3 /*proto::VarType::INT64*/, out); - } else { - std::vector out_dim_vec = vectorize(out->dims()); - DenseTensor value = Empty(dev_ctx, IntArray(out_dim_vec)); - TopkKernel( - dev_ctx, rand, Scalar(num_samples), -1, true, true, &value, out); - } - return; + int valid_samples = num_categories - zero_num; + PADDLE_ENFORCE_LE( + num_samples, + valid_samples, + errors::InvalidArgument("When replacement=False, 'num_samples' " + "must less than or eaqual to the number of " + "positive item of input")); } - funcs::MultinomialFunctor(dev_ctx, - cpu_out_data, - cpu_in_data, - num_samples, - replacement, - num_categories, - num_distributions); - -#ifdef PADDLE_WITH_HIP - hipMemcpy(out_data, - cpu_out_data, - out_data_numel * sizeof(int64_t), - hipMemcpyHostToDevice); -#else - cudaMemcpy(out_data, - cpu_out_data, - out_data_numel * sizeof(int64_t), - cudaMemcpyHostToDevice); -#endif - - delete[] cpu_in_data; - delete[] cpu_out_data; + // Refer to [gumbel softmax algorithm] + DenseTensor rand = EmptyLike(dev_ctx, x); + T* rand_data = rand.data(); + funcs::uniform_distribution dist; + funcs::exponential_transform trans(1.0); + funcs::distribution_and_transform(dev_ctx, &rand, dist, trans); + + funcs::ForRange for_range(dev_ctx, x.numel()); + for_range([rand_data, in_data] __device__(size_t idx) { + rand_data[idx] = in_data[idx] / rand_data[idx]; + }); + + if (num_samples == 1) { + ArgMaxKernel( + dev_ctx, rand, -1, true, false, 3 /*proto::VarType::INT64*/, out); + } else { + std::vector out_dim_vec = vectorize(out->dims()); + DenseTensor value = Empty(dev_ctx, IntArray(out_dim_vec)); + TopkKernel( + dev_ctx, rand, Scalar(num_samples), -1, true, true, &value, out); + } return; } @@ -322,44 +251,18 @@ void MultinomialKernel(const Context& dev_ctx, auto* cumulative_probs_data = dev_ctx.template Alloc(&cumulative_probs_tensor); - if (FLAGS_use_curand) { - // 'phi::funcs::InclusiveScan' has higher accuracy than - // 'thrust::inclusive_scan' - funcs::InclusiveScan>( - /*in*/ norm_probs_data, - /*out*/ cumulative_probs_data, - /*outer_dim*/ static_cast(num_distributions), - /*mid_dim*/ static_cast(num_categories), - /*inner_dim*/ static_cast(1), - /*init*/ static_cast(0), - std::plus(), - /*reverse=*/false, - dev_ctx); - } else { - dim3 block_cumsum(1); - dim3 grid_cumsum(num_distributions); - GetCumulativeProbs<<>>( - norm_probs_data, - num_distributions, - num_categories, - cumulative_probs_data); - } - - // Generate random number for each sample. - std::random_device rd; - auto seed = rd(); - - DenseTensor rng_data_tensor; - rng_data_tensor.Resize({num_distributions, num_samples}); - auto* rng_data = dev_ctx.template Alloc(&rng_data_tensor); - - thrust::counting_iterator index_sequence_begin(0); - paddle::platform::Transform trans; - trans(dev_ctx, - index_sequence_begin, - index_sequence_begin + num_distributions * num_samples, - rng_data, - RandomGeneratorCudaFunctor(seed)); + // 'phi::funcs::InclusiveScan' has higher accuracy than + // 'thrust::inclusive_scan' + funcs::InclusiveScan>( + /*in*/ norm_probs_data, + /*out*/ cumulative_probs_data, + /*outer_dim*/ static_cast(num_distributions), + /*mid_dim*/ static_cast(num_categories), + /*inner_dim*/ static_cast(1), + /*init*/ static_cast(0), + std::plus(), + /*reverse=*/false, + dev_ctx); // Sample the multinomial distributions. dim3 block(128); @@ -376,7 +279,6 @@ void MultinomialKernel(const Context& dev_ctx, auto seed_offset = gen_cuda->IncrementOffset(increment); sampleMultinomialWithReplacement<<>>( - rng_data, num_samples, out_data, num_distributions, @@ -384,8 +286,7 @@ void MultinomialKernel(const Context& dev_ctx, cumulative_probs_data, norm_probs_data, seed_offset.first, - seed_offset.second, - FLAGS_use_curand); + seed_offset.second); } } // namespace phi diff --git a/paddle/phi/kernels/gpu/randint_kernel.cu b/paddle/phi/kernels/gpu/randint_kernel.cu index 018850500226885bb7982341a97ac1793e6fefca..90eaea6a0868c1276687e9835b3133784806bfe1 100644 --- a/paddle/phi/kernels/gpu/randint_kernel.cu +++ b/paddle/phi/kernels/gpu/randint_kernel.cu @@ -23,8 +23,6 @@ // See Note [ Why still include the fluid headers? ] #include "paddle/fluid/memory/memcpy.h" -DECLARE_bool(use_curand); - namespace phi { template @@ -37,37 +35,9 @@ void RandintRawKernel(const Context& dev_ctx, DenseTensor* out) { out->Resize(phi::make_ddim(shape.GetData())); T* data = dev_ctx.template Alloc(out); - if (FLAGS_use_curand) { - funcs::uniform_distribution dist; - funcs::uniform_int_transform trans(low, high); - funcs::distribution_and_transform(dev_ctx, out, dist, trans); - } else { - DenseTensor tmp; - tmp.Resize(phi::make_ddim(shape.GetData())); - T* tmp_data = dev_ctx.template HostAlloc(&tmp); - - std::shared_ptr engine; - if (seed) { - engine = std::make_shared(); - engine->seed(seed); - } else { - engine = dev_ctx.GetHostGenerator()->GetCPUEngine(); - } - - std::uniform_int_distribution dist(low, high - 1); - auto numel = out->numel(); - for (int64_t i = 0; i < numel; ++i) { - tmp_data[i] = dist(*engine); - } - - paddle::memory::Copy( - out->place(), - data, - tmp.place(), - tmp_data, - numel * paddle::experimental::SizeOf(out->dtype()), - 0); - } + funcs::uniform_distribution dist; + funcs::uniform_int_transform trans(low, high); + funcs::distribution_and_transform(dev_ctx, out, dist, trans); } template diff --git a/paddle/phi/kernels/gpu/randperm_kernel.cu b/paddle/phi/kernels/gpu/randperm_kernel.cu index 678b580beca2f6515198fd4a4a126cda645d4660..4e488ed470df92a3f92640a4ae37526243ea9173 100644 --- a/paddle/phi/kernels/gpu/randperm_kernel.cu +++ b/paddle/phi/kernels/gpu/randperm_kernel.cu @@ -84,91 +84,65 @@ __global__ void SwapRepeatKernel( template void RandpermRawKernel( const Context& dev_ctx, int n, DataType dtype, int seed, DenseTensor* out) { - if (FLAGS_use_curand) { - DenseTensor key; - RandintKernel(dev_ctx, - std::numeric_limits::min(), - std::numeric_limits::max(), - IntArray({n}), - phi::DataType::INT32, - &key); - DenseTensor key_out = Empty(dev_ctx, IntArray({n})); - - DenseTensor range = Empty(dev_ctx, IntArray({n})); - T* range_data = range.data(); - funcs::ForRange for_range(dev_ctx, n); - for_range([range_data] __device__(size_t idx) { - range_data[idx] = static_cast(idx); - }); - - out->Resize(phi::make_ddim({n})); - T* out_data = dev_ctx.template Alloc(out); - - // Refer to [Algorithm of randperm] https://osf.io/af2hy/ to - // improve performance of radix sort. - double n_d = static_cast(n); - int begin_bit = 0; - int end_bit = - std::ceil(std::log2(n_d - (6 * n_d * n_d + 1) / (12 * std::log(0.9)))); - - size_t temp_storage_bytes = 0; - cub::DeviceRadixSort::SortPairs(nullptr, - temp_storage_bytes, - key.data(), - key_out.data(), - range.data(), - out_data, - n, - begin_bit, - end_bit < 32 ? end_bit : 32, - dev_ctx.stream()); - - auto d_temp_storage = paddle::memory::Alloc(dev_ctx, temp_storage_bytes); - cub::DeviceRadixSort::SortPairs(d_temp_storage->ptr(), - temp_storage_bytes, - key.data(), - key_out.data(), - range.data(), - out_data, - n, - begin_bit, - end_bit < 32 ? end_bit : 32, - dev_ctx.stream()); - - auto gen_cuda = dev_ctx.GetGenerator(); - auto seed_offset = gen_cuda->IncrementOffset(n); - uint64_t seed = seed_offset.first; - uint64_t offset = seed_offset.second; - - auto config = phi::backends::gpu::GetGpuLaunchConfig1D(dev_ctx, n); - SwapRepeatKernel<<>>( - key_out.data(), out_data, n, seed, offset); - } else { - DenseTensor tmp; - tmp.Resize(phi::make_ddim({n})); - T* tmp_data = dev_ctx.template HostAlloc(&tmp); - - std::shared_ptr engine; - if (seed) { - engine = std::make_shared(); - engine->seed(seed); - } else { - engine = dev_ctx.GetHostGenerator()->GetCPUEngine(); - } - - for (int i = 0; i < n; ++i) { - tmp_data[i] = static_cast(i); - } - std::shuffle(tmp_data, tmp_data + n, *engine); - - T* out_data = dev_ctx.template Alloc(out); - auto size = out->numel() * paddle::experimental::SizeOf(out->dtype()); - paddle::memory::Copy( - out->place(), out_data, tmp.place(), tmp_data, size, 0); - } + DenseTensor key; + RandintKernel(dev_ctx, + std::numeric_limits::min(), + std::numeric_limits::max(), + IntArray({n}), + phi::DataType::INT32, + &key); + DenseTensor key_out = Empty(dev_ctx, IntArray({n})); + + DenseTensor range = Empty(dev_ctx, IntArray({n})); + T* range_data = range.data(); + funcs::ForRange for_range(dev_ctx, n); + for_range([range_data] __device__(size_t idx) { + range_data[idx] = static_cast(idx); + }); + + out->Resize(phi::make_ddim({n})); + T* out_data = dev_ctx.template Alloc(out); + + // Refer to [Algorithm of randperm] https://osf.io/af2hy/ to + // improve performance of radix sort. + double n_d = static_cast(n); + int begin_bit = 0; + int end_bit = + std::ceil(std::log2(n_d - (6 * n_d * n_d + 1) / (12 * std::log(0.9)))); + + size_t temp_storage_bytes = 0; + cub::DeviceRadixSort::SortPairs(nullptr, + temp_storage_bytes, + key.data(), + key_out.data(), + range.data(), + out_data, + n, + begin_bit, + end_bit < 32 ? end_bit : 32, + dev_ctx.stream()); + + auto d_temp_storage = paddle::memory::Alloc(dev_ctx, temp_storage_bytes); + cub::DeviceRadixSort::SortPairs(d_temp_storage->ptr(), + temp_storage_bytes, + key.data(), + key_out.data(), + range.data(), + out_data, + n, + begin_bit, + end_bit < 32 ? end_bit : 32, + dev_ctx.stream()); + + auto gen_cuda = dev_ctx.GetGenerator(); + auto seed_offset = gen_cuda->IncrementOffset(n); + + auto config = phi::backends::gpu::GetGpuLaunchConfig1D(dev_ctx, n); + SwapRepeatKernel<<>>( + key_out.data(), out_data, n, seed_offset.first, seed_offset.second); } template diff --git a/paddle/phi/kernels/gpu/uniform_random_kernel.cu b/paddle/phi/kernels/gpu/uniform_random_kernel.cu index 2cabde0bbf9425331d36379a86929143aa2094df..a4aea10cfe762f203f326d69888becbf1ee3094e 100644 --- a/paddle/phi/kernels/gpu/uniform_random_kernel.cu +++ b/paddle/phi/kernels/gpu/uniform_random_kernel.cu @@ -14,14 +14,13 @@ #include "paddle/phi/kernels/uniform_random_kernel.h" +#include #include "gflags/gflags.h" #include "paddle/phi/core/kernel_registry.h" #include "paddle/phi/kernels/funcs/distribution_helper.h" #include "paddle/phi/kernels/funcs/index_impl.cu.h" -DECLARE_bool(use_curand); - namespace phi { template @@ -54,43 +53,6 @@ struct UniformGenerator { } }; -template -struct UniformGeneratorOffset { - T min_, max_; - unsigned int seed_; - T diag_val_; - unsigned int diag_num_; - unsigned int diag_step_; - int offset_; - __host__ __device__ UniformGeneratorOffset(T min, - T max, - int seed, - int diag_num, - int diag_step, - T diag_val, - int offset) - : min_(min), - max_(max), - seed_(seed), - diag_num_(diag_num), - diag_step_(diag_step), - diag_val_(diag_val), - offset_(offset) {} - - __host__ __device__ T operator()(const unsigned int n) const { - thrust::minstd_rand rng; - rng.seed(seed_); - thrust::uniform_real_distribution dist(min_, max_); - rng.discard(n + offset_); - T out = dist(rng); - unsigned int remainder = n % (diag_step_ + 1); - if (remainder == 0 && diag_num_ > n / (diag_step_ + 1)) { - out = diag_val_; - } - return out; - } -}; - template void UniformRandomRawKernel(const Context& dev_ctx, const IntArray& shape, @@ -114,23 +76,10 @@ void UniformRandomRawKernel(const Context& dev_ctx, auto generator = dev_ctx.GetGenerator(); if (generator->GetIsInitPy() && seed_flag) { - if (FLAGS_use_curand) { - using MT = typename kps::details::MPTypeTrait::Type; - funcs::uniform_distribution dist; - funcs::uniform_real_transform trans(min, max); - funcs::distribution_and_transform(dev_ctx, out, dist, trans); - } else { - auto seed_offset = generator->IncrementOffset(1); - int64_t gen_offset = size * seed_offset.second; - auto func = UniformGeneratorOffset(min, - max, - seed_offset.first, - diag_num, - diag_step, - diag_val, - gen_offset); - IndexKernel>(dev_ctx, out, func); - } + using MT = typename kps::details::MPTypeTrait::Type; + funcs::uniform_distribution dist; + funcs::uniform_real_transform trans(min, max); + funcs::distribution_and_transform(dev_ctx, out, dist, trans); } else { auto func = UniformGenerator(min, max, seed, diag_num, diag_step, diag_val); diff --git a/paddle/scripts/paddle_build.bat b/paddle/scripts/paddle_build.bat index cc55ea82df608251e25f3af06039fae07dbd4e74..21df60e9721214152c630ab7cb06c59d8b3f7ec4 100644 --- a/paddle/scripts/paddle_build.bat +++ b/paddle/scripts/paddle_build.bat @@ -657,7 +657,6 @@ for /F %%# in ('wmic os get localdatetime^|findstr 20') do set start=%%# set start=%start:~4,10% set FLAGS_call_stack_level=2 -set FLAGS_use_curand=True dir %THIRD_PARTY_PATH:/=\%\install\openblas\lib dir %THIRD_PARTY_PATH:/=\%\install\openblas\bin dir %THIRD_PARTY_PATH:/=\%\install\zlib\bin diff --git a/paddle/scripts/paddle_build.sh b/paddle/scripts/paddle_build.sh index d1220e453758258b65fe2cc5e5ea0ff54318db08..e8bde467e085d6e4d3fa075cb48471b2911e8de6 100755 --- a/paddle/scripts/paddle_build.sh +++ b/paddle/scripts/paddle_build.sh @@ -61,8 +61,6 @@ function init() { # NOTE(chenweihang): For easy debugging, CI displays the C++ error stacktrace by default export FLAGS_call_stack_level=2 - export FLAGS_use_curand=True - # set CI_SKIP_CPP_TEST if only *.py changed # In order to avoid using in some CI(such as daily performance), the current # branch must not be `${BRANCH}` which is usually develop. diff --git a/python/paddle/fluid/initializer.py b/python/paddle/fluid/initializer.py index 37eff6d132d03bc634f9d0ae3fdb62d118d2820e..b3baedc401504f2411e4a660fc9a3b1c5ea53924 100644 --- a/python/paddle/fluid/initializer.py +++ b/python/paddle/fluid/initializer.py @@ -561,12 +561,12 @@ class XavierInitializer(Initializer): if framework._non_static_mode(): if self._uniform: - limit = np.sqrt(6.0 / float(fan_in + fan_out)) + limit = math.sqrt(6.0 / float(fan_in + fan_out)) out_var = _C_ops.uniform_random('shape', out_var.shape, 'min', -limit, 'max', limit, 'seed', self._seed, 'dtype', out_dtype) else: - std = np.sqrt(2.0 / float(fan_in + fan_out)) + std = math.sqrt(2.0 / float(fan_in + fan_out)) out_var = _C_ops.gaussian_random( 'shape', out_var.shape, 'dtype', out_dtype, 'mean', 0.0, 'std', std, 'seed', self._seed) @@ -581,7 +581,7 @@ class XavierInitializer(Initializer): return None else: if self._uniform: - limit = np.sqrt(6.0 / float(fan_in + fan_out)) + limit = math.sqrt(6.0 / float(fan_in + fan_out)) op = block.append_op( type="uniform_random", inputs={}, @@ -595,7 +595,7 @@ class XavierInitializer(Initializer): }, stop_gradient=True) else: - std = np.sqrt(2.0 / float(fan_in + fan_out)) + std = math.sqrt(2.0 / float(fan_in + fan_out)) op = block.append_op( type="gaussian_random", outputs={"Out": out_var}, @@ -713,13 +713,13 @@ class MSRAInitializer(Initializer): if framework._non_static_mode(): if self._uniform: - limit = np.sqrt(6.0 / float(fan_in)) + limit = math.sqrt(6.0 / float(fan_in)) out_var = _C_ops.uniform_random('shape', out_var.shape, 'min', -limit, 'max', limit, 'seed', self._seed, 'dtype', int(out_dtype)) else: - std = np.sqrt(2.0 / float(fan_in)) + std = math.sqrt(2.0 / float(fan_in)) out_var = _C_ops.gaussian_random( 'shape', out_var.shape, 'dtype', int(out_dtype), 'mean', 0.0, 'std', std, 'seed', self._seed) @@ -734,7 +734,7 @@ class MSRAInitializer(Initializer): return None else: if self._uniform: - limit = np.sqrt(6.0 / float(fan_in)) + limit = math.sqrt(6.0 / float(fan_in)) op = block.append_op( type="uniform_random", inputs={}, @@ -749,7 +749,7 @@ class MSRAInitializer(Initializer): stop_gradient=True) else: - std = np.sqrt(2.0 / float(fan_in)) + std = math.sqrt(2.0 / float(fan_in)) op = block.append_op( type="gaussian_random", outputs={"Out": out_var}, diff --git a/python/paddle/fluid/tests/unittests/test_bernoulli_op.py b/python/paddle/fluid/tests/unittests/test_bernoulli_op.py index 426d5d463f4530e7662279db83fe29826d51d775..fc4ee13384b2dc435fc179b799dc8119e09ff52c 100644 --- a/python/paddle/fluid/tests/unittests/test_bernoulli_op.py +++ b/python/paddle/fluid/tests/unittests/test_bernoulli_op.py @@ -75,9 +75,6 @@ class TestRandomValue(unittest.TestCase): if not paddle.is_compiled_with_cuda(): return - if os.getenv("FLAGS_use_curand", None) in ('0', 'False', None): - return - print("Test Fixed Random number on GPU------>") paddle.disable_static() paddle.set_device('gpu') diff --git a/python/paddle/fluid/tests/unittests/test_dropout_op.py b/python/paddle/fluid/tests/unittests/test_dropout_op.py index d8a4eb8f45f7d3220f4d5b52927dfcc8d2e0f8c7..3aca428ac77af4c9881a3549c201dcb73ba41253 100644 --- a/python/paddle/fluid/tests/unittests/test_dropout_op.py +++ b/python/paddle/fluid/tests/unittests/test_dropout_op.py @@ -1034,9 +1034,6 @@ class TestRandomValue(unittest.TestCase): if not "V100" in paddle.device.cuda.get_device_name(): return - if os.getenv("FLAGS_use_curand", None) in ('0', 'False', None): - return - print("Test Fixed Random number on V100 GPU------>") paddle.disable_static() paddle.set_device('gpu') diff --git a/python/paddle/fluid/tests/unittests/test_exponential_op.py b/python/paddle/fluid/tests/unittests/test_exponential_op.py index 7a3ae203be62d644f076ae9b6bc2bf5b8641ccdf..c8f4101ea5d6ba254d7503625b156b6d7b7f7f37 100644 --- a/python/paddle/fluid/tests/unittests/test_exponential_op.py +++ b/python/paddle/fluid/tests/unittests/test_exponential_op.py @@ -100,9 +100,6 @@ class TestExponentialAPI(unittest.TestCase): if not "V100" in paddle.device.cuda.get_device_name(): return - if os.getenv("FLAGS_use_curand", None) in ('0', 'False', None): - return - print("Test Fixed Random number on V100 GPU------>") paddle.disable_static() paddle.set_device('gpu') diff --git a/python/paddle/fluid/tests/unittests/test_gaussian_random_op.py b/python/paddle/fluid/tests/unittests/test_gaussian_random_op.py index 738441a46d377ec920c911c2712bcdc7bab6dbf0..4fca8b9f2a11827609931e647a9543b71560f06d 100644 --- a/python/paddle/fluid/tests/unittests/test_gaussian_random_op.py +++ b/python/paddle/fluid/tests/unittests/test_gaussian_random_op.py @@ -342,9 +342,6 @@ class TestRandomValue(unittest.TestCase): if not "V100" in paddle.device.cuda.get_device_name(): return - if os.getenv("FLAGS_use_curand", None) in ('0', 'False', None): - return - def _check_random_value(dtype, expect, expect_mean, expect_std): x = paddle.randn([32, 3, 1024, 1024], dtype=dtype) actual = x.numpy() diff --git a/python/paddle/fluid/tests/unittests/test_linear.py b/python/paddle/fluid/tests/unittests/test_linear.py index 9d07a80da15dbfd35ffdedbcb09e82d59a84486e..6b00a86e3e900951e18c0690c3c3e64a78b6b621 100644 --- a/python/paddle/fluid/tests/unittests/test_linear.py +++ b/python/paddle/fluid/tests/unittests/test_linear.py @@ -73,6 +73,22 @@ class LinearTestCase(unittest.TestCase): np.testing.assert_array_almost_equal(res_f, res_nn) np.testing.assert_array_almost_equal(res_nn, res_np) + def test_weight_init(self): + if not paddle.is_compiled_with_cuda(): + return + paddle.seed(100) + linear = paddle.nn.Linear( + 2, 3, weight_attr=paddle.nn.initializer.Normal(0, 1.)) + paddle.nn.utils._stride_column(linear.weight) + expect = [[1.4349908, -0.8099171, -2.64788], + [-1.4981681, -1.1784115, -0.023253186]] + self.assertTrue(np.allclose(linear.weight.numpy(), expect)) + + linear = paddle.nn.Linear(2, 3) + expect = [[0.73261100, 0.43836895, 0.07908206], + [0.85075015, -1.04724526, 0.64371765]] + self.assertTrue(np.allclose(linear.weight.numpy(), expect)) + if __name__ == "__main__": unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_multinomial_op.py b/python/paddle/fluid/tests/unittests/test_multinomial_op.py index a65a1c7e14c2bf14414801d119bb54fd7f6873a6..ecde527523d3dd45eecdbb5aeba468c0456bc5b0 100644 --- a/python/paddle/fluid/tests/unittests/test_multinomial_op.py +++ b/python/paddle/fluid/tests/unittests/test_multinomial_op.py @@ -227,9 +227,6 @@ class TestRandomValue(unittest.TestCase): if not "V100" in paddle.device.cuda.get_device_name(): return - if os.getenv("FLAGS_use_curand", None) in ('0', 'False', None): - return - print("Test Fixed Random number on V100 GPU------>") paddle.disable_static() paddle.set_device('gpu') diff --git a/python/paddle/fluid/tests/unittests/test_poisson_op.py b/python/paddle/fluid/tests/unittests/test_poisson_op.py index 2123d4e0e7e35984f01b39633b76cb2c6337bb50..f8183bb5f8db28eb5717b9ae6f5b8b1276637ad7 100644 --- a/python/paddle/fluid/tests/unittests/test_poisson_op.py +++ b/python/paddle/fluid/tests/unittests/test_poisson_op.py @@ -107,9 +107,6 @@ class TestPoissonAPI(unittest.TestCase): if not paddle.is_compiled_with_cuda(): return - if os.getenv("FLAGS_use_curand", None) in ('0', 'False', None): - return - print("Test Fixed Random number on GPU------>") paddle.disable_static() paddle.set_device('gpu') diff --git a/python/paddle/fluid/tests/unittests/test_randint_op.py b/python/paddle/fluid/tests/unittests/test_randint_op.py index 1eb99e08bb8e1b6636a73d81545882d9457bc12c..361f4d280f70fab6883643d53330839512ad5d57 100644 --- a/python/paddle/fluid/tests/unittests/test_randint_op.py +++ b/python/paddle/fluid/tests/unittests/test_randint_op.py @@ -198,9 +198,6 @@ class TestRandomValue(unittest.TestCase): if not "V100" in paddle.device.cuda.get_device_name(): return - if os.getenv("FLAGS_use_curand", None) in ('0', 'False', None): - return - print("Test Fixed Random number on GPU------>") paddle.disable_static() diff --git a/python/paddle/fluid/tests/unittests/test_randperm_op.py b/python/paddle/fluid/tests/unittests/test_randperm_op.py index 5c9ab36fa34bc334b98c8c69de924aede06b6a57..deb0a9a082140f5702f50f4ef7090fc617ca8613 100644 --- a/python/paddle/fluid/tests/unittests/test_randperm_op.py +++ b/python/paddle/fluid/tests/unittests/test_randperm_op.py @@ -155,9 +155,6 @@ class TestRandomValue(unittest.TestCase): if not paddle.is_compiled_with_cuda(): return - if os.getenv("FLAGS_use_curand", None) in ('0', 'False', None): - return - print("Test Fixed Random number on GPU------>") paddle.disable_static() paddle.set_device('gpu') diff --git a/python/paddle/fluid/tests/unittests/test_uniform_random_op.py b/python/paddle/fluid/tests/unittests/test_uniform_random_op.py index 41b6ed36d65cccfad093de332ed81b007d77d3ce..683cc2fdf867e086932c08e7ce4853a66e0f869c 100644 --- a/python/paddle/fluid/tests/unittests/test_uniform_random_op.py +++ b/python/paddle/fluid/tests/unittests/test_uniform_random_op.py @@ -573,37 +573,46 @@ class TestRandomValue(unittest.TestCase): if not "V100" in paddle.device.cuda.get_device_name(): return - if os.getenv("FLAGS_use_curand", None) in ('0', 'False', None): - return - - def _check_random_value(dtype, expect, expect_mean, expect_std): - x = paddle.rand([32, 3, 1024, 1024], dtype=dtype) - actual = x.numpy() - self.assertTrue(np.allclose(actual[2, 1, 512, 1000:1010], expect)) - self.assertEqual(np.mean(actual), expect_mean) - self.assertEqual(np.std(actual), expect_std) - print("Test Fixed Random number on V100 GPU------>") paddle.disable_static() + paddle.set_device('gpu') paddle.seed(2021) + + expect_mean = 0.50000454338820143895816272561205551028251647949218750 + expect_std = 0.28867379167297479991560749112977646291255950927734375 expect = [ 0.55298901, 0.65184678, 0.49375412, 0.57943639, 0.16459608, 0.67181056, 0.03021481, 0.0238559, 0.07742096, 0.55972187 ] - expect_mean = 0.50000454338820143895816272561205551028251647949218750 - expect_std = 0.28867379167297479991560749112977646291255950927734375 - _check_random_value(core.VarDesc.VarType.FP64, expect, expect_mean, - expect_std) + out = paddle.rand([32, 3, 1024, 1024], dtype='float64').numpy() + self.assertEqual(np.mean(out), expect_mean) + self.assertEqual(np.std(out), expect_std) + self.assertTrue(np.allclose(out[2, 1, 512, 1000:1010], expect)) + expect_mean = 0.50002604722976684570312500 + expect_std = 0.2886914908885955810546875 expect = [ 0.45320973, 0.17582087, 0.725341, 0.30849215, 0.622257, 0.46352342, 0.97228295, 0.12771158, 0.286525, 0.9810645 ] - expect_mean = 0.50002604722976684570312500 - expect_std = 0.2886914908885955810546875 - _check_random_value(core.VarDesc.VarType.FP32, expect, expect_mean, - expect_std) + out = paddle.rand([32, 3, 1024, 1024], dtype='float32').numpy() + self.assertEqual(np.mean(out), expect_mean) + self.assertEqual(np.std(out), expect_std) + self.assertTrue(np.allclose(out[2, 1, 512, 1000:1010], expect)) + + expect_mean = 25.11843109130859375 + expect_std = 43.370647430419921875 + expect = [ + 30.089634, 77.05225, 3.1201615, 68.34072, 59.266724, -25.33281, + 12.973292, 27.41127, -17.412298, 27.931019 + ] + out = paddle.empty( + [16, 16, 16, 16], dtype='float32').uniform_(-50, 100).numpy() + self.assertEqual(np.mean(out), expect_mean) + self.assertEqual(np.std(out), expect_std) + self.assertTrue(np.allclose(out[10, 10, 10, 0:10], expect)) + paddle.enable_static() diff --git a/python/paddle/nn/utils/__init__.py b/python/paddle/nn/utils/__init__.py index 8f9b55d15cad0b41eca06d1e92f07ec117c102a3..8ec4e8cfd60b5aa0dacd3444c0ae0da08d543c35 100644 --- a/python/paddle/nn/utils/__init__.py +++ b/python/paddle/nn/utils/__init__.py @@ -14,7 +14,7 @@ from .spectral_norm_hook import spectral_norm from .weight_norm_hook import weight_norm, remove_weight_norm # noqa: F401 -from .transform_parameters import parameters_to_vector, vector_to_parameters # noqa: F401 +from .transform_parameters import parameters_to_vector, vector_to_parameters, _stride_column # noqa: F401 __all__ = [ #noqa 'weight_norm', 'remove_weight_norm', 'spectral_norm', 'parameters_to_vector', 'vector_to_parameters' diff --git a/python/paddle/nn/utils/transform_parameters.py b/python/paddle/nn/utils/transform_parameters.py index 99870ce29a138dab7bd52e76fd1e582a8ffc045e..feb70e02d598815743870bf2d7579d7b39d619f8 100644 --- a/python/paddle/nn/utils/transform_parameters.py +++ b/python/paddle/nn/utils/transform_parameters.py @@ -36,6 +36,39 @@ def _inplace_reshape_dygraph(x, shape): stop_gradient=True) +@dygraph_only +def _stride_column(param): + """ + A tool function. Permute date of parameter as a 'columns' stride. Now, it only support 2-D parameter. + + Args: + param(Tensor]): The param that will be strided according to 'columns'. + + Examples: + .. code-block:: python + + import paddle + paddle.seed(100) + + linear = paddle.nn.Linear(2, 3) + print(linear.weight) + # [[-0.31485492, -1.02896988, 0.45741916], + # [-0.65525872, -1.04643178, 1.07262802]] + + paddle.nn.utils.stride_column(linear.weight) + print(linear.weight) + # [[-0.31485492, 0.45741916, -1.04643178], + # [-1.02896988, -0.65525872, 1.07262802]] + + """ + assert len(param.shape) == 2 + shape = [param.shape[1], param.shape[0]] + with paddle.fluid.dygraph.no_grad(): + reshape_var = paddle.reshape(param, shape) + transpose_var = paddle.transpose(reshape_var, [1, 0]) + transpose_var._share_underline_tensor_to(param) + + @dygraph_only def parameters_to_vector(parameters, name=None): """