diff --git a/paddle/operators/dropout_op.cc b/paddle/operators/dropout_op.cc index 82e21408088bf1957f068c3eef491145233a594d..fe72aa56efe050aff19ad80e482451a415cb45a0 100644 --- a/paddle/operators/dropout_op.cc +++ b/paddle/operators/dropout_op.cc @@ -47,8 +47,8 @@ class DropoutOpMaker : public framework::OpProtoAndCheckerMaker { AddAttr("dropout_prob", "Probability of setting units to zero.") .SetDefault(.5f) .AddCustomChecker([](const float& drop_p) { - PADDLE_ENFORCE(drop_p > 0.0f && drop_p < 1.0f, - "'dropout_prob' must be between 0 and 1."); + PADDLE_ENFORCE(drop_p >= 0.0f && drop_p <= 1.0f, + "'dropout_prob' must be between 0.0 and 1.0."); }); AddAttr("is_test", "True if in test phase.").SetDefault(false); AddAttr("seed", "Dropout random seed.").SetDefault(0); diff --git a/paddle/operators/dropout_op.cu b/paddle/operators/dropout_op.cu index c31d2195e95b116451b0f620f6582f65c0dae706..12e8e989e31919470b473ae3b5f3b99a64f25e9a 100644 --- a/paddle/operators/dropout_op.cu +++ b/paddle/operators/dropout_op.cu @@ -30,16 +30,15 @@ struct MaskGenerator { __host__ __device__ MaskGenerator(AttrType dropout_prob, int seed) : dropout_prob(dropout_prob), seed(seed) {} - __host__ __device__ T operator()(const unsigned int n) const { + inline __host__ __device__ T operator()(const unsigned int n) const { thrust::minstd_rand rng; rng.seed(seed); thrust::uniform_real_distribution dist(0, 1); rng.discard(n); if (dist(rng) < dropout_prob) { return static_cast(0); - } else { - return static_cast(1); } + return static_cast(1); } };