提交 b6ca72c8 编写于 作者: M Megvii Engine Team

fix(dnn/dropout): fix the issue of inconsistent prob when calling dropout exec multiple times

GitOrigin-RevId: b29e7b223c5c60729fffa90707fa3bb88b03f148
上级 d510a945
......@@ -60,16 +60,16 @@ void DropoutForwardImpl::exec(
check_exec(inp.layout, oup.layout, mask.layout, workspace.size);
size_t length = inp.layout.total_nr_elems();
uint64_t seed = param().seed;
float prob = param().drop_prob;
m_rng.ensure_seed(seed);
#define cb(DType) \
if (inp.layout.dtype == DType()) { \
using T = typename DTypeTrait<DType>::ctype; \
MEGDNN_DISPATCH_CPU_KERN_OPR(forward<T>( \
inp.ptr<T>(), oup.ptr<T>(), mask.raw_ptr(), length, m_rng, \
param().drop_prob)); \
return; \
#define cb(DType) \
if (inp.layout.dtype == DType()) { \
using T = typename DTypeTrait<DType>::ctype; \
MEGDNN_DISPATCH_CPU_KERN_OPR(forward<T>( \
inp.ptr<T>(), oup.ptr<T>(), mask.raw_ptr(), length, m_rng, prob)); \
return; \
}
MEGDNN_FOREACH_COMPUTING_DTYPE_FLOAT(cb)
#undef cb
......@@ -85,14 +85,14 @@ void DropoutBackwardImpl::exec(
#if !MGE_BUILD_WITHOUT_NAIVE_EXEC
check_exec(doup.layout, mask.layout, dinp.layout, workspace.size);
size_t length = doup.layout.total_nr_elems();
float prob = param().drop_prob;
#define cb(DType) \
if (doup.layout.dtype == DType()) { \
using T = typename DTypeTrait<DType>::ctype; \
MEGDNN_DISPATCH_CPU_KERN_OPR(backward<T>( \
doup.ptr<T>(), dinp.ptr<T>(), mask.raw_ptr(), length, \
param().drop_prob)); \
return; \
#define cb(DType) \
if (doup.layout.dtype == DType()) { \
using T = typename DTypeTrait<DType>::ctype; \
MEGDNN_DISPATCH_CPU_KERN_OPR(backward<T>( \
doup.ptr<T>(), dinp.ptr<T>(), mask.raw_ptr(), length, prob)); \
return; \
}
MEGDNN_FOREACH_COMPUTING_DTYPE_FLOAT(cb)
#undef cb
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册