Unverified commit 9714878c, authored by zhouweiwei2014, committed by GitHub

remove FLAGS_use_curand and change all random op CUDA implementations (#41308)

Parent 0d642d3a
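Every file below follows the same pattern: the FLAGS_use_curand branch is deleted and the counter-based curand/hiprand path becomes the only implementation, so seeded GPU random ops produce fixed values without setting any flag. A minimal usage-level sketch of that behavior (assuming a CUDA build of Paddle at this commit; the printed statistics are illustrative, not taken from this diff):

import numpy as np
import paddle

if paddle.is_compiled_with_cuda():
    paddle.set_device('gpu')
    paddle.seed(2021)
    # With the flag gone, a fixed seed always goes through the curand-based
    # kernels, which is what the updated unit tests below assert.
    out = paddle.rand([32, 3, 1024, 1024], dtype='float32').numpy()
    print(np.mean(out), np.std(out))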
@@ -38,43 +38,9 @@ limitations under the License. */
 #include "paddle/phi/kernels/funcs/distribution_helper.h"
 #include "paddle/phi/kernels/funcs/functors.h"
-DECLARE_bool(use_curand);
 namespace paddle {
 namespace operators {
template <typename T1, typename T2 = T1, typename OutT = T1>
struct DstMaskGenerator {
const float dropout_prob_;
const bool is_upscale_in_train_;
using MT = typename details::MPTypeTrait<T1>::Type;
MT factor;
HOSTDEVICE inline DstMaskGenerator(const float dropout_prob,
const bool is_upscale_in_train)
: dropout_prob_(dropout_prob), is_upscale_in_train_(is_upscale_in_train) {
factor = static_cast<MT>(1.0f / (1.0f - dropout_prob_));
}
HOSTDEVICE inline void operator()(OutT* dst, const T1* src_val,
const T2* rand, int num) const {
static constexpr int kCount =
phi::funcs::uniform_distribution<T2>::kReturnsCount;
// 0 ~ kCount -1 is dist , kCount ~ 2 * kCount - 1 is mask
#pragma unroll
for (int i = 0; i < kCount; i++) {
if (rand[i] < dropout_prob_) {
dst[i] = static_cast<T1>(0);
dst[i + kCount] = dst[i];
} else {
dst[i] = is_upscale_in_train_
? static_cast<T1>(static_cast<MT>(src_val[i]) * factor)
: static_cast<T1>(src_val[i]);
dst[i + kCount] = static_cast<T1>(1);
}
}
}
};
 template <typename T1, typename T2 = T1, typename OutT = T1>
 struct DstMaskFunctor {
   const float retain_prob_;
@@ -113,7 +79,7 @@ __global__ void VectorizedRandomGenerator(const size_t n, uint64_t seed,
                                           const T* src, MaskType* mask, T* dst,
                                           bool is_upscale_in_train,
                                           uint64_t increment,
-                                          size_t main_offset, bool use_curand) {
+                                          size_t main_offset) {
   size_t idx = static_cast<size_t>(BLOCK_ID_X * BLOCK_NUM_X);
   static constexpr int kCount =
       phi::funcs::uniform_distribution<float>::kReturnsCount;
@@ -135,76 +101,41 @@ __global__ void VectorizedRandomGenerator(const size_t n, uint64_t seed,
   int deal_size = BLOCK_NUM_X * kCount;
   size_t fix = idx * kCount;
-  if (use_curand) {
   auto dst_functor =
       DstMaskFunctor<T, float>(1.0f - dropout_prob, is_upscale_in_train);
   for (; fix < main_offset; fix += stride) {
     kps::ReadData<T, kCount, 1, 1, false>(&dst_mask[0], src + fix, deal_size);
     kps::ElementwiseRandom<SType, float, kCount, 1, Rand>(&rands[0], Rand(),
                                                           &state);
     // dst
     kps::OperatorTernary<T, float, T, DstMaskFunctor<T, float>>(
         &dst_mask[0], &dst_mask[0], &rands[0], dst_functor, kCount);
     kps::WriteData<T, kCount, 1, 1, false>(dst + fix, &dst_mask[0], deal_size);
     // mask
     kps::ElementwiseUnary<T, MaskType, kCount, 1, 1, Cast>(
         &mask_result[0], &dst_mask[kCount], Cast());
     kps::WriteData<MaskType, kCount, 1, 1, false>(mask + fix, &mask_result[0],
                                                   deal_size);
     if (fix > idx * kCount + 1) {
       __syncthreads();
     }
   }
   int remainder = n - fix;
   if (remainder > 0) {
     kps::ReadData<T, kCount, 1, 1, true>(&dst_mask[0], src + fix, remainder);
     kps::ElementwiseRandom<SType, float, kCount, 1, Rand>(&rands[0], Rand(),
                                                           &state);
     // dst
     kps::OperatorTernary<T, float, T, DstMaskFunctor<T, float>>(
         &dst_mask[0], &dst_mask[0], &rands[0], dst_functor, kCount);
     kps::WriteData<T, kCount, 1, 1, true>(dst + fix, &dst_mask[0], remainder);
     // mask
     kps::ElementwiseUnary<T, MaskType, kCount, 1, 1, Cast>(
         &mask_result[0], &dst_mask[kCount], Cast());
     kps::WriteData<MaskType, kCount, 1, 1, true>(mask + fix, &mask_result[0],
                                                  remainder);
     __syncthreads();
   }
-  } else {
-    auto dst_functor =
-        DstMaskGenerator<T, float>(dropout_prob, is_upscale_in_train);
-    for (; fix < main_offset; fix += stride) {
-      kps::ReadData<T, kCount, 1, 1, false>(&dst_mask[0], src + fix, deal_size);
-      kps::ElementwiseRandom<SType, float, kCount, 1, Rand>(&rands[0], Rand(),
-                                                            &state);
-      // dst
-      kps::OperatorTernary<T, float, T, DstMaskGenerator<T, float>>(
-          &dst_mask[0], &dst_mask[0], &rands[0], dst_functor, kCount);
-      kps::WriteData<T, kCount, 1, 1, false>(dst + fix, &dst_mask[0],
-                                             deal_size);
-      // mask
-      kps::ElementwiseUnary<T, MaskType, kCount, 1, 1, Cast>(
-          &mask_result[0], &dst_mask[kCount], Cast());
-      kps::WriteData<MaskType, kCount, 1, 1, false>(mask + fix, &mask_result[0],
-                                                    deal_size);
-    }
-    int remainder = n - fix;
-    if (remainder > 0) {
-      kps::ReadData<T, kCount, 1, 1, true>(&dst_mask[0], src + fix, remainder);
-      kps::ElementwiseRandom<SType, float, kCount, 1, Rand>(&rands[0], Rand(),
-                                                            &state);
-      // dst
-      kps::OperatorTernary<T, float, T, DstMaskGenerator<T, float>>(
-          &dst_mask[0], &dst_mask[0], &rands[0], dst_functor, kCount);
-      kps::WriteData<T, kCount, 1, 1, true>(dst + fix, &dst_mask[0], remainder);
-      // mask
-      kps::ElementwiseUnary<T, MaskType, kCount, 1, 1, Cast>(
-          &mask_result[0], &dst_mask[kCount], Cast());
-      kps::WriteData<MaskType, kCount, 1, 1, true>(mask + fix, &mask_result[0],
-                                                   remainder);
-    }
-  }
 }
@@ -251,13 +182,11 @@ void DropoutFwGPUKernelDriver(const phi::GPUContext& dev_ctx, bool is_test,
   size_t grid_size = gpu_config.GetGridSize();
   size_t block_size = gpu_config.GetBlockSize();
-  if (FLAGS_use_curand) {
-    int64_t device_id = dev_ctx.GetPlace().GetDeviceId();
-    const auto& prop = platform::GetDeviceProperties(device_id);
-    size_t max_grid_size = prop.maxThreadsPerMultiProcessor *
-                           prop.multiProcessorCount / block_size;
-    grid_size = std::min(grid_size, max_grid_size);
-  }
+  int64_t device_id = dev_ctx.GetPlace().GetDeviceId();
+  const auto& prop = platform::GetDeviceProperties(device_id);
+  size_t max_grid_size = prop.maxThreadsPerMultiProcessor *
+                         prop.multiProcessorCount / block_size;
+  grid_size = std::min(grid_size, max_grid_size);
   auto offset =
       ((x_numel - 1) / (grid_size * block_size * kVecSize) + 1) * kVecSize;
@@ -268,7 +197,7 @@ void DropoutFwGPUKernelDriver(const phi::GPUContext& dev_ctx, bool is_test,
     VectorizedRandomGenerator<T, uint8_t><<<grid_size, block_size, 0, stream>>>(
         size, seed_data, dropout_prob, x_data, mask_data, y_data,
-        upscale_in_train, increment, main_offset, FLAGS_use_curand);
+        upscale_in_train, increment, main_offset);
   } else {
     if (upscale_in_train) {
       // todo: can y share with data with x directly?
......
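The kernel above vectorizes a standard dropout forward. A rough NumPy sketch of the per-element math it implements (a simplification that ignores the Philox state and vectorized loads; dropout_forward is a name introduced here for illustration):

import numpy as np

def dropout_forward(x, dropout_prob, upscale_in_train=True, rng=None):
    rng = rng or np.random.default_rng()
    rand = rng.random(x.shape)
    # Elements with rand < dropout_prob are zeroed; kept elements are scaled
    # by 1 / (1 - dropout_prob) when upscale_in_train is set.
    mask = (rand >= dropout_prob).astype(np.uint8)
    scale = 1.0 / (1.0 - dropout_prob) if upscale_in_train else 1.0
    return x * mask * scale, mask

y, mask = dropout_forward(np.ones((4, 4), dtype=np.float32), 0.5)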
@@ -11,21 +11,14 @@ distributed under the License is distributed on an "AS IS" BASIS,
 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */
-#include <thrust/device_vector.h>
-#include <thrust/host_vector.h>
 #include <thrust/random.h>
-#include <thrust/transform.h>
 #include "paddle/fluid/framework/generator.h"
 #include "paddle/fluid/framework/op_registry.h"
 #include "paddle/fluid/framework/operator.h"
 #include "paddle/fluid/operators/amp/fp16_type_traits.h"
 #include "paddle/fluid/operators/fill_constant_op.h"
 #include "paddle/phi/kernels/funcs/distribution_helper.h"
 #include "paddle/phi/kernels/funcs/index_impl.cu.h"
-DECLARE_bool(use_curand);
 namespace paddle {
 namespace operators {
......
@@ -19,11 +19,7 @@
 #include "paddle/fluid/framework/op_registry.h"
 #include "paddle/fluid/framework/operator.h"
 #if defined(__NVCC__) || defined(__HIPCC__)
-DECLARE_bool(use_curand);
-#include <thrust/device_vector.h>
-#include <thrust/host_vector.h>
 #include <thrust/random.h>
-#include <thrust/transform.h>
 #include "paddle/fluid/framework/generator.h"
 #include "paddle/phi/kernels/full_kernel.h"
 #include "paddle/phi/kernels/funcs/distribution_helper.h"
@@ -146,39 +142,6 @@ struct UniformGenerator {
   }
 };
template <typename T>
struct UniformGeneratorOffset {
T min_, max_;
unsigned int seed_;
T diag_val_;
unsigned int diag_num_;
unsigned int diag_step_;
int offset_;
__host__ __device__ UniformGeneratorOffset(T min, T max, int seed,
int diag_num, int diag_step,
T diag_val, int offset)
: min_(min),
max_(max),
seed_(seed),
diag_num_(diag_num),
diag_step_(diag_step),
diag_val_(diag_val),
offset_(offset) {}
__host__ __device__ T operator()(const unsigned int n) const {
thrust::minstd_rand rng;
rng.seed(seed_);
thrust::uniform_real_distribution<T> dist(min_, max_);
rng.discard(n + offset_);
T out = dist(rng);
unsigned int remainder = n % (diag_step_ + 1);
if (remainder == 0 && diag_num_ > n / (diag_step_ + 1)) {
out = diag_val_;
}
return out;
}
};
 template <typename T>
 void UniformRandom(const framework::ExecutionContext& context,
                    framework::Tensor* tensor) {
@@ -205,19 +168,10 @@ void UniformRandom(const framework::ExecutionContext& context,
   int device_id = context.GetPlace().GetDeviceId();
   auto gen_cuda = framework::GetDefaultCUDAGenerator(device_id);
   if (gen_cuda->GetIsInitPy() && seed_flag) {
-    if (FLAGS_use_curand) {
-      using MT = typename details::MPTypeTrait<T>::Type;
-      phi::funcs::uniform_distribution<MT> dist;
-      phi::funcs::uniform_real_transform<MT> trans(min, max);
-      phi::funcs::distribution_and_transform<T>(dev_cxt, tensor, dist, trans);
-    } else {
-      auto seed_offset = gen_cuda->IncrementOffset(1);
-      int64_t gen_offset = size * seed_offset.second;
-      auto func =
-          UniformGeneratorOffset<T>(min, max, seed_offset.first, diag_num,
-                                    diag_step, diag_val, gen_offset);
-      phi::IndexKernel<T, UniformGeneratorOffset<T>>(dev_cxt, tensor, func);
-    }
+    using MT = typename details::MPTypeTrait<T>::Type;
+    phi::funcs::uniform_distribution<MT> dist;
+    phi::funcs::uniform_real_transform<MT> trans(min, max);
+    phi::funcs::distribution_and_transform<T>(dev_cxt, tensor, dist, trans);
   } else {
     auto func =
         UniformGenerator<T>(min, max, seed, diag_num, diag_step, diag_val);
......
@@ -545,8 +545,6 @@ PADDLE_DEFINE_EXPORTED_double(
  */
 PADDLE_DEFINE_EXPORTED_bool(use_mkldnn, false, "Use MKLDNN to run");
-PADDLE_DEFINE_EXPORTED_bool(use_curand, false, "Random OP use CURAND");
 /**
  * Debug related FLAG
  * Name: FLAGS_call_stack_level
......
@@ -75,6 +75,7 @@ PD_REGISTER_KERNEL(transpose,
                    double,
                    int32_t,
                    int64_t,
+                   phi::dtype::float16,
                    phi::dtype::bfloat16,
                    phi::dtype::complex<float>,
                    phi::dtype::complex<double>) {}
@@ -14,8 +14,6 @@
 #include "paddle/phi/kernels/bernoulli_kernel.h"
-#include <thrust/random.h>
-#include <thrust/transform.h>
 #ifdef __NVCC__
 #include <curand_kernel.h>
 #endif
@@ -32,35 +30,8 @@
 #include "paddle/phi/core/kernel_registry.h"
 #include "paddle/phi/kernels/funcs/distribution_helper.h"
-// See Note [ Why still include the fluid headers? ]
-#include "paddle/fluid/platform/transform.h"
-DECLARE_bool(use_curand);
 namespace phi {
template <typename T>
struct BernoulliCudaFunctor {
unsigned int seed_;
unsigned int offset_;
__host__ __device__ BernoulliCudaFunctor(unsigned int seed,
unsigned int offset)
: seed_(seed), offset_(offset) {}
__host__ __device__ T operator()(const unsigned int n, const T p) const {
// NOTE(zhiqiu): currently, PADDLE_ENFORCE in cuda kernel may print several
// lines of error messages if, and it should be refined.
PADDLE_ENFORCE(p >= 0.0 && p <= 1.0,
"The probability should be >=0 and <= 1, but got %f",
p);
thrust::minstd_rand rng;
rng.seed(seed_);
thrust::uniform_real_distribution<T> dist(0.0, 1.0);
rng.discard(n + offset_);
return static_cast<T>(dist(rng) < p);
}
};
 // 'curand_uniform4/hiprand_uniform4' generate 4 random number each time
 template <typename T>
 __global__ void bernoulli_cuda_kernel(
@@ -100,30 +71,16 @@ void BernoulliKernel(const Context& ctx,
   auto gen_cuda = ctx.GetGenerator();
-  if (FLAGS_use_curand) {
-    auto seed_offset = gen_cuda->IncrementOffset(12);
-    uint64_t seed = seed_offset.first;
-    uint64_t offset = seed_offset.second;
+  auto seed_offset = gen_cuda->IncrementOffset(12);
+  uint64_t seed = seed_offset.first;
+  uint64_t offset = seed_offset.second;
   auto gpu_config = phi::backends::gpu::GetGpuLaunchConfig1D(ctx, numel, 4);
   size_t grid_size = gpu_config.GetGridSize();
   size_t block_size = gpu_config.GetBlockSize();
   bernoulli_cuda_kernel<<<grid_size, block_size, 0, ctx.stream()>>>(
       numel, seed, offset, x_data, out_data);
-  } else {
-    auto seed_offset = gen_cuda->IncrementOffset(1);
-    int64_t gen_offset = numel * seed_offset.second;
-    paddle::platform::Transform<phi::GPUContext> trans;
-    thrust::counting_iterator<int64_t> index_sequence_begin(0);
-    trans(ctx,
-          index_sequence_begin,
-          index_sequence_begin + numel,
-          x_data,
-          out_data,
-          BernoulliCudaFunctor<T>(static_cast<int64_t>(seed_offset.first),
-                                  static_cast<int64_t>(gen_offset)));
-  }
 }
 }  // namespace phi
......
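For reference, the per-element rule the bernoulli kernel applies with curand_uniform4 reduces to comparing a uniform draw against the probability. A small NumPy sketch (illustrative only; bernoulli here is not the Paddle API):

import numpy as np

def bernoulli(p, rng=None):
    rng = rng or np.random.default_rng()
    p = np.asarray(p, dtype=np.float32)
    return (rng.random(p.shape) < p).astype(p.dtype)   # 1 with probability p

print(bernoulli([0.1, 0.5, 0.9]))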
@@ -14,10 +14,7 @@
 #include "paddle/phi/kernels/gaussian_random_kernel.h"
-#include <thrust/device_vector.h>
-#include <thrust/host_vector.h>
 #include <thrust/random.h>
-#include <thrust/transform.h>
 #include "paddle/phi/backends/gpu/gpu_context.h"
 #include "paddle/phi/common/amp_type_traits.h"
 #include "paddle/phi/core/dense_tensor.h"
@@ -27,8 +24,6 @@
 #include "paddle/fluid/framework/generator.h"
-DECLARE_bool(use_curand);
 namespace phi {
 template <typename T>
@@ -83,21 +78,11 @@ void GaussianRandomKernel(const Context& dev_ctx,
   auto gen_cuda = paddle::framework::GetDefaultCUDAGenerator(device_id);
   if (gen_cuda->GetIsInitPy() && seed_flag) {
-    if (FLAGS_use_curand) {
-      using MT = typename phi::dtype::MPTypeTrait<T>::Type;
-      funcs::normal_distribution<MT> dist;
-      funcs::normal_transform<MT> trans(static_cast<MT>(mean),
-                                        static_cast<MT>(std));
-      funcs::distribution_and_transform<T>(dev_ctx, tensor, dist, trans);
-    } else {
-      auto seed_offset = gen_cuda->IncrementOffset(1);
-      int64_t gen_offset = size * seed_offset.second;
-      auto func = GaussianGenerator<T>(static_cast<T>(mean),
-                                       static_cast<T>(std),
-                                       seed_offset.first,
-                                       gen_offset);
-      IndexKernel<T, GaussianGenerator<T>>(dev_ctx, tensor, func);
-    }
+    using MT = typename phi::dtype::MPTypeTrait<T>::Type;
+    funcs::normal_distribution<MT> dist;
+    funcs::normal_transform<MT> trans(static_cast<MT>(mean),
+                                      static_cast<MT>(std));
+    funcs::distribution_and_transform<T>(dev_ctx, tensor, dist, trans);
   } else {
     auto func =
         GaussianGenerator<T>(static_cast<T>(mean), static_cast<T>(std), seed);
......
@@ -18,11 +18,6 @@ limitations under the License. */
 #include "paddle/phi/kernels/multinomial_kernel.h"
-#include <thrust/execution_policy.h>
-#include <thrust/random.h>
-#include <thrust/scan.h>
-#include <thrust/transform.h>
 #ifdef __NVCC__
 #include "cub/cub.cuh"
 #endif
@@ -44,12 +39,6 @@ namespace cub = hipcub;
 #include "paddle/phi/kernels/funcs/multinomial_functor.h"
 #include "paddle/phi/kernels/top_k_kernel.h"
-// See Note [ Why still include the fluid headers? ]
-#include "paddle/fluid/memory/memcpy.h"
-#include "paddle/fluid/platform/transform.h"
-DECLARE_bool(use_curand);
 namespace phi {
 template <typename T>
@@ -74,32 +63,6 @@ __global__ void NormalizeProbability(T* norm_probs,
   }
 }
template <typename T>
__global__ void GetCumulativeProbs(T* norm_probs_data,
int64_t num_distributions,
int64_t num_categories,
T* cumulative_probs_data) {
int id = blockIdx.x;
thrust::inclusive_scan(thrust::device,
norm_probs_data + id * num_categories,
norm_probs_data + (id + 1) * num_categories,
cumulative_probs_data + id * num_categories);
}
template <typename T>
struct RandomGeneratorCudaFunctor {
unsigned int seed_;
__host__ __device__ RandomGeneratorCudaFunctor(int seed) : seed_(seed) {}
__host__ __device__ T operator()(const unsigned int n) const {
thrust::minstd_rand rng;
rng.seed(seed_);
thrust::uniform_real_distribution<T> dist(0.0, 1.0);
rng.discard(n);
return dist(rng);
}
};
 template <typename T>
 __device__ int binarySearchFunctor(T* cumulative_probs_data,
                                    T* norm_probs_data,
@@ -130,7 +93,6 @@ __device__ int binarySearchFunctor(T* cumulative_probs_data,
 template <typename T>
 __global__ void sampleMultinomialWithReplacement(
-    T* rng_data,
     const int64_t num_samples,
     int64_t* out_data,
     const int64_t num_distributions,
@@ -138,10 +100,9 @@ __global__ void sampleMultinomialWithReplacement(
     T* cumulative_probs_data,
     T* norm_probs_data,
     uint64_t seed,
-    uint64_t offset,
-    bool use_curand) {
+    uint64_t offset) {
   // use binary search to get the selected category sample id.
-  // let cumulative_probs_data[id-1] < rng_data < cumulative_probs_data[id].
+  // let cumulative_probs_data[id-1] < rng_number < cumulative_probs_data[id].
   size_t idx = gridDim.x * blockDim.x * blockIdx.y + blockDim.x * blockIdx.x +
                threadIdx.x;
@@ -151,10 +112,7 @@ __global__ void sampleMultinomialWithReplacement(
     int sample = blockIdx.x * blockDim.x + threadIdx.x;
     for (int dist = blockIdx.y; dist < num_distributions; dist += gridDim.y) {
       if (sample < num_samples) {
-        T rng_number = rng_data[sample + dist * num_samples];
-        if (use_curand) {
-          rng_number = static_cast<T>(curand_uniform4(&state).x);
-        }
+        T rng_number = static_cast<T>(curand_uniform4(&state).x);
         // Find the bucket that a uniform random number lies in
         int selected_category =
             binarySearchFunctor<T>(cumulative_probs_data + dist * num_categories,
@@ -182,10 +140,7 @@ void MultinomialKernel(const Context& dev_ctx,
   const int64_t num_distributions = in_rank > 1 ? in_dims[in_rank - 2] : 1;
   // If replacement is False, it's not a replaceable sample. Every category
-  // can
-  // be used only once. So after every sample, probability of the distribution
-  // will change. The implementation can't be parallelizable. Thus, call CPU
-  // implementation ``funcs::MultinomialFunctor`` to sample the distribution.
+  // can be used only once.
   if (!replacement) {
     int64_t in_data_numel = x.numel();
     int64_t out_data_numel = out->numel();
@@ -202,76 +157,50 @@ void MultinomialKernel(const Context& dev_ctx,
               in_data_numel * sizeof(T),
               cudaMemcpyDeviceToHost);
 #endif
-    if (FLAGS_use_curand) {
     for (size_t i = 0; i < num_distributions; ++i) {
       int zero_num = 0;
       for (size_t j = 0; j < num_categories; ++j) {
         T weight = cpu_in_data[i * num_distributions + j];
         PADDLE_ENFORCE_GE(
             weight,
             0,
             errors::InvalidArgument(
                 "Each element of multinomial'input must >= 0, but got %f.",
                 weight));
         if (weight == static_cast<T>(0)) {
           zero_num++;
         }
       }
       int valid_samples = num_categories - zero_num;
       PADDLE_ENFORCE_LE(
           num_samples,
           valid_samples,
           errors::InvalidArgument("When replacement=False, 'num_samples' "
                                   "must less than or eaqual to the number of "
                                   "positive item of input"));
     }
     // Refer to [gumbel softmax algorithm]
     DenseTensor rand = EmptyLike<T, Context>(dev_ctx, x);
     T* rand_data = rand.data<T>();
     funcs::uniform_distribution<T> dist;
     funcs::exponential_transform<T> trans(1.0);
     funcs::distribution_and_transform<T>(dev_ctx, &rand, dist, trans);
     funcs::ForRange<Context> for_range(dev_ctx, x.numel());
     for_range([rand_data, in_data] __device__(size_t idx) {
       rand_data[idx] = in_data[idx] / rand_data[idx];
     });
     if (num_samples == 1) {
       ArgMaxKernel<T, Context>(
           dev_ctx, rand, -1, true, false, 3 /*proto::VarType::INT64*/, out);
     } else {
       std::vector<int64_t> out_dim_vec = vectorize<int64_t>(out->dims());
       DenseTensor value = Empty<T, Context>(dev_ctx, IntArray(out_dim_vec));
       TopkKernel<T, Context>(
           dev_ctx, rand, Scalar(num_samples), -1, true, true, &value, out);
     }
     return;
-    }
-    funcs::MultinomialFunctor<T>(dev_ctx,
-                                 cpu_out_data,
-                                 cpu_in_data,
-                                 num_samples,
-                                 replacement,
-                                 num_categories,
-                                 num_distributions);
-#ifdef PADDLE_WITH_HIP
-    hipMemcpy(out_data,
-              cpu_out_data,
-              out_data_numel * sizeof(int64_t),
-              hipMemcpyHostToDevice);
-#else
-    cudaMemcpy(out_data,
-               cpu_out_data,
-               out_data_numel * sizeof(int64_t),
-               cudaMemcpyHostToDevice);
-#endif
-    delete[] cpu_in_data;
-    delete[] cpu_out_data;
-    return;
   }
@@ -322,44 +251,18 @@ void MultinomialKernel(const Context& dev_ctx,
   auto* cumulative_probs_data =
       dev_ctx.template Alloc<T>(&cumulative_probs_tensor);
-  if (FLAGS_use_curand) {
   // 'phi::funcs::InclusiveScan' has higher accuracy than
   // 'thrust::inclusive_scan'
   funcs::InclusiveScan<T, std::plus<T>>(
       /*in*/ norm_probs_data,
       /*out*/ cumulative_probs_data,
       /*outer_dim*/ static_cast<size_t>(num_distributions),
       /*mid_dim*/ static_cast<size_t>(num_categories),
       /*inner_dim*/ static_cast<size_t>(1),
       /*init*/ static_cast<T>(0),
       std::plus<T>(),
       /*reverse=*/false,
       dev_ctx);
-  } else {
-    dim3 block_cumsum(1);
-    dim3 grid_cumsum(num_distributions);
-    GetCumulativeProbs<T><<<grid_cumsum, block_cumsum, 0, dev_ctx.stream()>>>(
-        norm_probs_data,
-        num_distributions,
-        num_categories,
-        cumulative_probs_data);
-  }
-  // Generate random number for each sample.
-  std::random_device rd;
-  auto seed = rd();
-  DenseTensor rng_data_tensor;
-  rng_data_tensor.Resize({num_distributions, num_samples});
-  auto* rng_data = dev_ctx.template Alloc<T>(&rng_data_tensor);
-  thrust::counting_iterator<int64_t> index_sequence_begin(0);
-  paddle::platform::Transform<GPUContext> trans;
-  trans(dev_ctx,
-        index_sequence_begin,
-        index_sequence_begin + num_distributions * num_samples,
-        rng_data,
-        RandomGeneratorCudaFunctor<T>(seed));
   // Sample the multinomial distributions.
   dim3 block(128);
@@ -376,7 +279,6 @@ void MultinomialKernel(const Context& dev_ctx,
   auto seed_offset = gen_cuda->IncrementOffset(increment);
   sampleMultinomialWithReplacement<T><<<grid, block, 0, dev_ctx.stream()>>>(
-      rng_data,
       num_samples,
       out_data,
       num_distributions,
@@ -384,8 +286,7 @@ void MultinomialKernel(const Context& dev_ctx,
       cumulative_probs_data,
       norm_probs_data,
       seed_offset.first,
-      seed_offset.second,
-      FLAGS_use_curand);
+      seed_offset.second);
 }
 }  // namespace phi
......
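The replacement=False path above relies on the exponential (Gumbel-style) trick referenced by the "// Refer to [gumbel softmax algorithm]" comment: dividing each weight by an Exp(1) draw and taking the arg-max, or the top-k indices, samples categories without replacement in proportion to their weights. A NumPy sketch of the idea (names are illustrative, not Paddle internals):

import numpy as np

def sample_without_replacement(weights, num_samples, rng=None):
    rng = rng or np.random.default_rng()
    w = np.asarray(weights, dtype=np.float64)
    keys = w / rng.exponential(1.0, size=w.shape)   # rand ~ Exp(1), key = w / rand
    return np.argsort(-keys)[:num_samples]          # argmax for one sample, top-k otherwise

print(sample_without_replacement([0.1, 0.0, 0.7, 0.2], num_samples=2))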
@@ -23,8 +23,6 @@
 // See Note [ Why still include the fluid headers? ]
 #include "paddle/fluid/memory/memcpy.h"
-DECLARE_bool(use_curand);
 namespace phi {
 template <typename T, typename Context>
@@ -37,37 +35,9 @@ void RandintRawKernel(const Context& dev_ctx,
                       DenseTensor* out) {
   out->Resize(phi::make_ddim(shape.GetData()));
   T* data = dev_ctx.template Alloc<T>(out);
-  if (FLAGS_use_curand) {
   funcs::uniform_distribution<uint32_t> dist;
   funcs::uniform_int_transform<T, uint32_t> trans(low, high);
   funcs::distribution_and_transform<T>(dev_ctx, out, dist, trans);
-  } else {
-    DenseTensor tmp;
-    tmp.Resize(phi::make_ddim(shape.GetData()));
-    T* tmp_data = dev_ctx.template HostAlloc<T>(&tmp);
-    std::shared_ptr<std::mt19937_64> engine;
-    if (seed) {
-      engine = std::make_shared<std::mt19937_64>();
-      engine->seed(seed);
-    } else {
-      engine = dev_ctx.GetHostGenerator()->GetCPUEngine();
-    }
-    std::uniform_int_distribution<T> dist(low, high - 1);
-    auto numel = out->numel();
-    for (int64_t i = 0; i < numel; ++i) {
-      tmp_data[i] = dist(*engine);
-    }
-    paddle::memory::Copy<phi::GPUPlace, phi::Place>(
-        out->place(),
-        data,
-        tmp.place(),
-        tmp_data,
-        numel * paddle::experimental::SizeOf(out->dtype()),
-        0);
-  }
 }
 template <typename T, typename Context>
......
@@ -84,91 +84,65 @@ __global__ void SwapRepeatKernel(
 template <typename T, typename Context>
 void RandpermRawKernel(
     const Context& dev_ctx, int n, DataType dtype, int seed, DenseTensor* out) {
-  if (FLAGS_use_curand) {
   DenseTensor key;
   RandintKernel<int, Context>(dev_ctx,
                               std::numeric_limits<int>::min(),
                               std::numeric_limits<int>::max(),
                               IntArray({n}),
                               phi::DataType::INT32,
                               &key);
   DenseTensor key_out = Empty<int, Context>(dev_ctx, IntArray({n}));
   DenseTensor range = Empty<T, Context>(dev_ctx, IntArray({n}));
   T* range_data = range.data<T>();
   funcs::ForRange<Context> for_range(dev_ctx, n);
   for_range([range_data] __device__(size_t idx) {
     range_data[idx] = static_cast<T>(idx);
   });
   out->Resize(phi::make_ddim({n}));
   T* out_data = dev_ctx.template Alloc<T>(out);
   // Refer to [Algorithm of randperm] https://osf.io/af2hy/ to
   // improve performance of radix sort.
   double n_d = static_cast<double>(n);
   int begin_bit = 0;
   int end_bit =
       std::ceil(std::log2(n_d - (6 * n_d * n_d + 1) / (12 * std::log(0.9))));
   size_t temp_storage_bytes = 0;
   cub::DeviceRadixSort::SortPairs<int, T>(nullptr,
                                           temp_storage_bytes,
                                           key.data<int>(),
                                           key_out.data<int>(),
                                           range.data<T>(),
                                           out_data,
                                           n,
                                           begin_bit,
                                           end_bit < 32 ? end_bit : 32,
                                           dev_ctx.stream());
   auto d_temp_storage = paddle::memory::Alloc(dev_ctx, temp_storage_bytes);
   cub::DeviceRadixSort::SortPairs<int, T>(d_temp_storage->ptr(),
                                           temp_storage_bytes,
                                           key.data<int>(),
                                           key_out.data<int>(),
                                           range.data<T>(),
                                           out_data,
                                           n,
                                           begin_bit,
                                           end_bit < 32 ? end_bit : 32,
                                           dev_ctx.stream());
   auto gen_cuda = dev_ctx.GetGenerator();
   auto seed_offset = gen_cuda->IncrementOffset(n);
-  uint64_t seed = seed_offset.first;
-  uint64_t offset = seed_offset.second;
   auto config = phi::backends::gpu::GetGpuLaunchConfig1D(dev_ctx, n);
   SwapRepeatKernel<T><<<config.block_per_grid.x,
                         config.thread_per_block.x,
                         0,
                         dev_ctx.stream()>>>(
-      key_out.data<int>(), out_data, n, seed, offset);
+      key_out.data<int>(), out_data, n, seed_offset.first, seed_offset.second);
-  } else {
-    DenseTensor tmp;
-    tmp.Resize(phi::make_ddim({n}));
-    T* tmp_data = dev_ctx.template HostAlloc<T>(&tmp);
-    std::shared_ptr<std::mt19937_64> engine;
-    if (seed) {
-      engine = std::make_shared<std::mt19937_64>();
-      engine->seed(seed);
-    } else {
-      engine = dev_ctx.GetHostGenerator()->GetCPUEngine();
-    }
-    for (int i = 0; i < n; ++i) {
-      tmp_data[i] = static_cast<T>(i);
-    }
-    std::shuffle(tmp_data, tmp_data + n, *engine);
-    T* out_data = dev_ctx.template Alloc<T>(out);
-    auto size = out->numel() * paddle::experimental::SizeOf(out->dtype());
-    paddle::memory::Copy<phi::GPUPlace, phi::Place>(
-        out->place(), out_data, tmp.place(), tmp_data, size, 0);
-  }
 }
 template <typename T, typename Context>
......
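The new RandpermRawKernel above builds a permutation by drawing one random 32-bit key per position and radix-sorting the identity range by those keys, with a follow-up SwapRepeatKernel that reshuffles runs of equal keys. A NumPy sketch of the key-sort idea, without the equal-key fix-up (illustrative only):

import numpy as np

def randperm_by_keys(n, rng=None):
    rng = rng or np.random.default_rng()
    keys = rng.integers(np.iinfo(np.int32).min, np.iinfo(np.int32).max, size=n)
    return np.argsort(keys, kind='stable')   # positions sorted by random key

print(randperm_by_keys(10))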
@@ -14,14 +14,13 @@
 #include "paddle/phi/kernels/uniform_random_kernel.h"
 #include <thrust/random.h>
 #include "gflags/gflags.h"
 #include "paddle/phi/core/kernel_registry.h"
 #include "paddle/phi/kernels/funcs/distribution_helper.h"
 #include "paddle/phi/kernels/funcs/index_impl.cu.h"
-DECLARE_bool(use_curand);
 namespace phi {
 template <typename T>
@@ -54,43 +53,6 @@ struct UniformGenerator {
   }
 };
template <typename T>
struct UniformGeneratorOffset {
T min_, max_;
unsigned int seed_;
T diag_val_;
unsigned int diag_num_;
unsigned int diag_step_;
int offset_;
__host__ __device__ UniformGeneratorOffset(T min,
T max,
int seed,
int diag_num,
int diag_step,
T diag_val,
int offset)
: min_(min),
max_(max),
seed_(seed),
diag_num_(diag_num),
diag_step_(diag_step),
diag_val_(diag_val),
offset_(offset) {}
__host__ __device__ T operator()(const unsigned int n) const {
thrust::minstd_rand rng;
rng.seed(seed_);
thrust::uniform_real_distribution<T> dist(min_, max_);
rng.discard(n + offset_);
T out = dist(rng);
unsigned int remainder = n % (diag_step_ + 1);
if (remainder == 0 && diag_num_ > n / (diag_step_ + 1)) {
out = diag_val_;
}
return out;
}
};
 template <typename T, typename Context>
 void UniformRandomRawKernel(const Context& dev_ctx,
                             const IntArray& shape,
@@ -114,23 +76,10 @@ void UniformRandomRawKernel(const Context& dev_ctx,
   auto generator = dev_ctx.GetGenerator();
   if (generator->GetIsInitPy() && seed_flag) {
-    if (FLAGS_use_curand) {
-      using MT = typename kps::details::MPTypeTrait<T>::Type;
-      funcs::uniform_distribution<MT> dist;
-      funcs::uniform_real_transform<MT> trans(min, max);
-      funcs::distribution_and_transform<T>(dev_ctx, out, dist, trans);
-    } else {
-      auto seed_offset = generator->IncrementOffset(1);
-      int64_t gen_offset = size * seed_offset.second;
-      auto func = UniformGeneratorOffset<T>(min,
-                                            max,
-                                            seed_offset.first,
-                                            diag_num,
-                                            diag_step,
-                                            diag_val,
-                                            gen_offset);
-      IndexKernel<T, UniformGeneratorOffset<T>>(dev_ctx, out, func);
-    }
+    using MT = typename kps::details::MPTypeTrait<T>::Type;
+    funcs::uniform_distribution<MT> dist;
+    funcs::uniform_real_transform<MT> trans(min, max);
+    funcs::distribution_and_transform<T>(dev_ctx, out, dist, trans);
   } else {
     auto func =
         UniformGenerator<T>(min, max, seed, diag_num, diag_step, diag_val);
......
@@ -657,7 +657,6 @@ for /F %%# in ('wmic os get localdatetime^|findstr 20') do set start=%%#
 set start=%start:~4,10%
 set FLAGS_call_stack_level=2
-set FLAGS_use_curand=True
 dir %THIRD_PARTY_PATH:/=\%\install\openblas\lib
 dir %THIRD_PARTY_PATH:/=\%\install\openblas\bin
 dir %THIRD_PARTY_PATH:/=\%\install\zlib\bin
......
@@ -61,8 +61,6 @@ function init() {
 # NOTE(chenweihang): For easy debugging, CI displays the C++ error stacktrace by default
 export FLAGS_call_stack_level=2
-export FLAGS_use_curand=True
 # set CI_SKIP_CPP_TEST if only *.py changed
 # In order to avoid using in some CI(such as daily performance), the current
 # branch must not be `${BRANCH}` which is usually develop.
......
@@ -561,12 +561,12 @@ class XavierInitializer(Initializer):
         if framework._non_static_mode():
             if self._uniform:
-                limit = np.sqrt(6.0 / float(fan_in + fan_out))
+                limit = math.sqrt(6.0 / float(fan_in + fan_out))
                 out_var = _C_ops.uniform_random('shape', out_var.shape, 'min',
                                                 -limit, 'max', limit, 'seed',
                                                 self._seed, 'dtype', out_dtype)
             else:
-                std = np.sqrt(2.0 / float(fan_in + fan_out))
+                std = math.sqrt(2.0 / float(fan_in + fan_out))
                 out_var = _C_ops.gaussian_random(
                     'shape', out_var.shape, 'dtype', out_dtype, 'mean', 0.0,
                     'std', std, 'seed', self._seed)
@@ -581,7 +581,7 @@ class XavierInitializer(Initializer):
             return None
         else:
             if self._uniform:
-                limit = np.sqrt(6.0 / float(fan_in + fan_out))
+                limit = math.sqrt(6.0 / float(fan_in + fan_out))
                 op = block.append_op(
                     type="uniform_random",
                     inputs={},
@@ -595,7 +595,7 @@ class XavierInitializer(Initializer):
                     },
                     stop_gradient=True)
             else:
-                std = np.sqrt(2.0 / float(fan_in + fan_out))
+                std = math.sqrt(2.0 / float(fan_in + fan_out))
                 op = block.append_op(
                     type="gaussian_random",
                     outputs={"Out": out_var},
@@ -713,13 +713,13 @@ class MSRAInitializer(Initializer):
         if framework._non_static_mode():
             if self._uniform:
-                limit = np.sqrt(6.0 / float(fan_in))
+                limit = math.sqrt(6.0 / float(fan_in))
                 out_var = _C_ops.uniform_random('shape', out_var.shape, 'min',
                                                 -limit, 'max', limit, 'seed',
                                                 self._seed, 'dtype',
                                                 int(out_dtype))
             else:
-                std = np.sqrt(2.0 / float(fan_in))
+                std = math.sqrt(2.0 / float(fan_in))
                 out_var = _C_ops.gaussian_random(
                     'shape', out_var.shape, 'dtype',
                     int(out_dtype), 'mean', 0.0, 'std', std, 'seed', self._seed)
@@ -734,7 +734,7 @@ class MSRAInitializer(Initializer):
             return None
         else:
             if self._uniform:
-                limit = np.sqrt(6.0 / float(fan_in))
+                limit = math.sqrt(6.0 / float(fan_in))
                 op = block.append_op(
                     type="uniform_random",
                     inputs={},
@@ -749,7 +749,7 @@ class MSRAInitializer(Initializer):
                     stop_gradient=True)
             else:
-                std = np.sqrt(2.0 / float(fan_in))
+                std = math.sqrt(2.0 / float(fan_in))
                 op = block.append_op(
                     type="gaussian_random",
                     outputs={"Out": out_var},
......
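The initializer change above only swaps np.sqrt for math.sqrt; the bounds themselves are unchanged. For reference, the quantities being computed:

import math

fan_in, fan_out = 64, 128
xavier_uniform_limit = math.sqrt(6.0 / float(fan_in + fan_out))
xavier_normal_std = math.sqrt(2.0 / float(fan_in + fan_out))
msra_uniform_limit = math.sqrt(6.0 / float(fan_in))
msra_normal_std = math.sqrt(2.0 / float(fan_in))
print(xavier_uniform_limit, xavier_normal_std, msra_uniform_limit, msra_normal_std)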
@@ -75,9 +75,6 @@ class TestRandomValue(unittest.TestCase):
         if not paddle.is_compiled_with_cuda():
             return
-        if os.getenv("FLAGS_use_curand", None) in ('0', 'False', None):
-            return
         print("Test Fixed Random number on GPU------>")
         paddle.disable_static()
         paddle.set_device('gpu')
......
@@ -1034,9 +1034,6 @@ class TestRandomValue(unittest.TestCase):
         if not "V100" in paddle.device.cuda.get_device_name():
             return
-        if os.getenv("FLAGS_use_curand", None) in ('0', 'False', None):
-            return
         print("Test Fixed Random number on V100 GPU------>")
         paddle.disable_static()
         paddle.set_device('gpu')
......
@@ -100,9 +100,6 @@ class TestExponentialAPI(unittest.TestCase):
         if not "V100" in paddle.device.cuda.get_device_name():
             return
-        if os.getenv("FLAGS_use_curand", None) in ('0', 'False', None):
-            return
         print("Test Fixed Random number on V100 GPU------>")
         paddle.disable_static()
         paddle.set_device('gpu')
......
@@ -342,9 +342,6 @@ class TestRandomValue(unittest.TestCase):
         if not "V100" in paddle.device.cuda.get_device_name():
             return
-        if os.getenv("FLAGS_use_curand", None) in ('0', 'False', None):
-            return
         def _check_random_value(dtype, expect, expect_mean, expect_std):
             x = paddle.randn([32, 3, 1024, 1024], dtype=dtype)
             actual = x.numpy()
......
@@ -73,6 +73,22 @@ class LinearTestCase(unittest.TestCase):
         np.testing.assert_array_almost_equal(res_f, res_nn)
         np.testing.assert_array_almost_equal(res_nn, res_np)
+    def test_weight_init(self):
+        if not paddle.is_compiled_with_cuda():
+            return
+        paddle.seed(100)
+        linear = paddle.nn.Linear(
+            2, 3, weight_attr=paddle.nn.initializer.Normal(0, 1.))
+        paddle.nn.utils._stride_column(linear.weight)
+        expect = [[1.4349908, -0.8099171, -2.64788],
+                  [-1.4981681, -1.1784115, -0.023253186]]
+        self.assertTrue(np.allclose(linear.weight.numpy(), expect))
+        linear = paddle.nn.Linear(2, 3)
+        expect = [[0.73261100, 0.43836895, 0.07908206],
+                  [0.85075015, -1.04724526, 0.64371765]]
+        self.assertTrue(np.allclose(linear.weight.numpy(), expect))
 if __name__ == "__main__":
     unittest.main()
@@ -227,9 +227,6 @@ class TestRandomValue(unittest.TestCase):
         if not "V100" in paddle.device.cuda.get_device_name():
             return
-        if os.getenv("FLAGS_use_curand", None) in ('0', 'False', None):
-            return
         print("Test Fixed Random number on V100 GPU------>")
         paddle.disable_static()
         paddle.set_device('gpu')
......
@@ -107,9 +107,6 @@ class TestPoissonAPI(unittest.TestCase):
         if not paddle.is_compiled_with_cuda():
             return
-        if os.getenv("FLAGS_use_curand", None) in ('0', 'False', None):
-            return
         print("Test Fixed Random number on GPU------>")
         paddle.disable_static()
         paddle.set_device('gpu')
......
@@ -198,9 +198,6 @@ class TestRandomValue(unittest.TestCase):
         if not "V100" in paddle.device.cuda.get_device_name():
             return
-        if os.getenv("FLAGS_use_curand", None) in ('0', 'False', None):
-            return
         print("Test Fixed Random number on GPU------>")
         paddle.disable_static()
......
@@ -155,9 +155,6 @@ class TestRandomValue(unittest.TestCase):
         if not paddle.is_compiled_with_cuda():
             return
-        if os.getenv("FLAGS_use_curand", None) in ('0', 'False', None):
-            return
         print("Test Fixed Random number on GPU------>")
         paddle.disable_static()
         paddle.set_device('gpu')
......
@@ -573,37 +573,46 @@ class TestRandomValue(unittest.TestCase):
         if not "V100" in paddle.device.cuda.get_device_name():
             return
-        if os.getenv("FLAGS_use_curand", None) in ('0', 'False', None):
-            return
-        def _check_random_value(dtype, expect, expect_mean, expect_std):
-            x = paddle.rand([32, 3, 1024, 1024], dtype=dtype)
-            actual = x.numpy()
-            self.assertTrue(np.allclose(actual[2, 1, 512, 1000:1010], expect))
-            self.assertEqual(np.mean(actual), expect_mean)
-            self.assertEqual(np.std(actual), expect_std)
         print("Test Fixed Random number on V100 GPU------>")
         paddle.disable_static()
         paddle.set_device('gpu')
         paddle.seed(2021)
+        expect_mean = 0.50000454338820143895816272561205551028251647949218750
+        expect_std = 0.28867379167297479991560749112977646291255950927734375
         expect = [
             0.55298901, 0.65184678, 0.49375412, 0.57943639, 0.16459608,
             0.67181056, 0.03021481, 0.0238559, 0.07742096, 0.55972187
         ]
-        expect_mean = 0.50000454338820143895816272561205551028251647949218750
-        expect_std = 0.28867379167297479991560749112977646291255950927734375
-        _check_random_value(core.VarDesc.VarType.FP64, expect, expect_mean,
-                            expect_std)
+        out = paddle.rand([32, 3, 1024, 1024], dtype='float64').numpy()
+        self.assertEqual(np.mean(out), expect_mean)
+        self.assertEqual(np.std(out), expect_std)
+        self.assertTrue(np.allclose(out[2, 1, 512, 1000:1010], expect))
+        expect_mean = 0.50002604722976684570312500
+        expect_std = 0.2886914908885955810546875
         expect = [
             0.45320973, 0.17582087, 0.725341, 0.30849215, 0.622257, 0.46352342,
             0.97228295, 0.12771158, 0.286525, 0.9810645
         ]
-        expect_mean = 0.50002604722976684570312500
-        expect_std = 0.2886914908885955810546875
-        _check_random_value(core.VarDesc.VarType.FP32, expect, expect_mean,
-                            expect_std)
+        out = paddle.rand([32, 3, 1024, 1024], dtype='float32').numpy()
+        self.assertEqual(np.mean(out), expect_mean)
+        self.assertEqual(np.std(out), expect_std)
+        self.assertTrue(np.allclose(out[2, 1, 512, 1000:1010], expect))
+        expect_mean = 25.11843109130859375
+        expect_std = 43.370647430419921875
+        expect = [
+            30.089634, 77.05225, 3.1201615, 68.34072, 59.266724, -25.33281,
+            12.973292, 27.41127, -17.412298, 27.931019
+        ]
+        out = paddle.empty(
+            [16, 16, 16, 16], dtype='float32').uniform_(-50, 100).numpy()
+        self.assertEqual(np.mean(out), expect_mean)
+        self.assertEqual(np.std(out), expect_std)
+        self.assertTrue(np.allclose(out[10, 10, 10, 0:10], expect))
         paddle.enable_static()
......
@@ -14,7 +14,7 @@
 from .spectral_norm_hook import spectral_norm
 from .weight_norm_hook import weight_norm, remove_weight_norm  # noqa: F401
-from .transform_parameters import parameters_to_vector, vector_to_parameters  # noqa: F401
+from .transform_parameters import parameters_to_vector, vector_to_parameters, _stride_column  # noqa: F401
 __all__ = [  #noqa
     'weight_norm', 'remove_weight_norm', 'spectral_norm', 'parameters_to_vector', 'vector_to_parameters'
......
@@ -36,6 +36,39 @@ def _inplace_reshape_dygraph(x, shape):
         stop_gradient=True)
+@dygraph_only
+def _stride_column(param):
+    """
+    A tool function. Permute data of parameter as a 'columns' stride. Now, it only supports 2-D parameters.
+
+    Args:
+        param(Tensor): The param that will be strided according to 'columns'.
+
+    Examples:
+       .. code-block:: python
+
+            import paddle
+            paddle.seed(100)
+
+            linear = paddle.nn.Linear(2, 3)
+            print(linear.weight)
+            # [[-0.31485492, -1.02896988, 0.45741916],
+            #  [-0.65525872, -1.04643178, 1.07262802]]
+
+            paddle.nn.utils._stride_column(linear.weight)
+            print(linear.weight)
+            # [[-0.31485492, 0.45741916, -1.04643178],
+            #  [-1.02896988, -0.65525872, 1.07262802]]
+
+    """
+    assert len(param.shape) == 2
+    shape = [param.shape[1], param.shape[0]]
+    with paddle.fluid.dygraph.no_grad():
+        reshape_var = paddle.reshape(param, shape)
+        transpose_var = paddle.transpose(reshape_var, [1, 0])
+        transpose_var._share_underline_tensor_to(param)
 @dygraph_only
 def parameters_to_vector(parameters, name=None):
     """
......
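The new _stride_column helper permutes a 2-D parameter by viewing it as (columns, rows) and transposing back in place. A NumPy sketch of the same reshape-then-transpose (an assumption: this mirrors the dygraph implementation above, minus the in-place tensor sharing):

import numpy as np

w = np.arange(6, dtype=np.float32).reshape(2, 3)
strided = w.reshape(w.shape[1], w.shape[0]).T   # view as (cols, rows), then transpose
print(w)
print(strided)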