From 52119d62a7acb39f044bab8e1ea80d216c3cf647 Mon Sep 17 00:00:00 2001 From: chengduoZH Date: Mon, 25 Dec 2017 23:32:50 +0800 Subject: [PATCH] refine --- paddle/operators/dropout_op.cc | 4 ++-- paddle/operators/dropout_op.cu | 5 ++--- 2 files changed, 4 insertions(+), 5 deletions(-) diff --git a/paddle/operators/dropout_op.cc b/paddle/operators/dropout_op.cc index 82e21408088..fe72aa56efe 100644 --- a/paddle/operators/dropout_op.cc +++ b/paddle/operators/dropout_op.cc @@ -47,8 +47,8 @@ class DropoutOpMaker : public framework::OpProtoAndCheckerMaker { AddAttr("dropout_prob", "Probability of setting units to zero.") .SetDefault(.5f) .AddCustomChecker([](const float& drop_p) { - PADDLE_ENFORCE(drop_p > 0.0f && drop_p < 1.0f, - "'dropout_prob' must be between 0 and 1."); + PADDLE_ENFORCE(drop_p >= 0.0f && drop_p <= 1.0f, + "'dropout_prob' must be between 0.0 and 1.0."); }); AddAttr("is_test", "True if in test phase.").SetDefault(false); AddAttr("seed", "Dropout random seed.").SetDefault(0); diff --git a/paddle/operators/dropout_op.cu b/paddle/operators/dropout_op.cu index c31d2195e95..12e8e989e31 100644 --- a/paddle/operators/dropout_op.cu +++ b/paddle/operators/dropout_op.cu @@ -30,16 +30,15 @@ struct MaskGenerator { __host__ __device__ MaskGenerator(AttrType dropout_prob, int seed) : dropout_prob(dropout_prob), seed(seed) {} - __host__ __device__ T operator()(const unsigned int n) const { + inline __host__ __device__ T operator()(const unsigned int n) const { thrust::minstd_rand rng; rng.seed(seed); thrust::uniform_real_distribution dist(0, 1); rng.discard(n); if (dist(rng) < dropout_prob) { return static_cast(0); - } else { - return static_cast(1); } + return static_cast(1); } }; -- GitLab