提交 52119d62 编写于 作者: C chengduoZH

refine

上级 a1e1ae3f
...@@ -47,8 +47,8 @@ class DropoutOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -47,8 +47,8 @@ class DropoutOpMaker : public framework::OpProtoAndCheckerMaker {
AddAttr<float>("dropout_prob", "Probability of setting units to zero.") AddAttr<float>("dropout_prob", "Probability of setting units to zero.")
.SetDefault(.5f) .SetDefault(.5f)
.AddCustomChecker([](const float& drop_p) { .AddCustomChecker([](const float& drop_p) {
PADDLE_ENFORCE(drop_p > 0.0f && drop_p < 1.0f, PADDLE_ENFORCE(drop_p >= 0.0f && drop_p <= 1.0f,
"'dropout_prob' must be between 0 and 1."); "'dropout_prob' must be between 0.0 and 1.0.");
}); });
AddAttr<bool>("is_test", "True if in test phase.").SetDefault(false); AddAttr<bool>("is_test", "True if in test phase.").SetDefault(false);
AddAttr<int>("seed", "Dropout random seed.").SetDefault(0); AddAttr<int>("seed", "Dropout random seed.").SetDefault(0);
......
...@@ -30,16 +30,15 @@ struct MaskGenerator { ...@@ -30,16 +30,15 @@ struct MaskGenerator {
__host__ __device__ MaskGenerator(AttrType dropout_prob, int seed) __host__ __device__ MaskGenerator(AttrType dropout_prob, int seed)
: dropout_prob(dropout_prob), seed(seed) {} : dropout_prob(dropout_prob), seed(seed) {}
__host__ __device__ T operator()(const unsigned int n) const { inline __host__ __device__ T operator()(const unsigned int n) const {
thrust::minstd_rand rng; thrust::minstd_rand rng;
rng.seed(seed); rng.seed(seed);
thrust::uniform_real_distribution<AttrType> dist(0, 1); thrust::uniform_real_distribution<AttrType> dist(0, 1);
rng.discard(n); rng.discard(n);
if (dist(rng) < dropout_prob) { if (dist(rng) < dropout_prob) {
return static_cast<T>(0); return static_cast<T>(0);
} else {
return static_cast<T>(1);
} }
return static_cast<T>(1);
} }
}; };
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册