提交 b6ca72c8 编写于 作者: M Megvii Engine Team

fix(dnn/dropout): fix the issue of inconsistent prob when calling dropout exec multiple times

GitOrigin-RevId: b29e7b223c5c60729fffa90707fa3bb88b03f148
上级 d510a945
...@@ -60,16 +60,16 @@ void DropoutForwardImpl::exec( ...@@ -60,16 +60,16 @@ void DropoutForwardImpl::exec(
check_exec(inp.layout, oup.layout, mask.layout, workspace.size); check_exec(inp.layout, oup.layout, mask.layout, workspace.size);
size_t length = inp.layout.total_nr_elems(); size_t length = inp.layout.total_nr_elems();
uint64_t seed = param().seed; uint64_t seed = param().seed;
float prob = param().drop_prob;
m_rng.ensure_seed(seed); m_rng.ensure_seed(seed);
#define cb(DType) \ #define cb(DType) \
if (inp.layout.dtype == DType()) { \ if (inp.layout.dtype == DType()) { \
using T = typename DTypeTrait<DType>::ctype; \ using T = typename DTypeTrait<DType>::ctype; \
MEGDNN_DISPATCH_CPU_KERN_OPR(forward<T>( \ MEGDNN_DISPATCH_CPU_KERN_OPR(forward<T>( \
inp.ptr<T>(), oup.ptr<T>(), mask.raw_ptr(), length, m_rng, \ inp.ptr<T>(), oup.ptr<T>(), mask.raw_ptr(), length, m_rng, prob)); \
param().drop_prob)); \ return; \
return; \
} }
MEGDNN_FOREACH_COMPUTING_DTYPE_FLOAT(cb) MEGDNN_FOREACH_COMPUTING_DTYPE_FLOAT(cb)
#undef cb #undef cb
...@@ -85,14 +85,14 @@ void DropoutBackwardImpl::exec( ...@@ -85,14 +85,14 @@ void DropoutBackwardImpl::exec(
#if !MGE_BUILD_WITHOUT_NAIVE_EXEC #if !MGE_BUILD_WITHOUT_NAIVE_EXEC
check_exec(doup.layout, mask.layout, dinp.layout, workspace.size); check_exec(doup.layout, mask.layout, dinp.layout, workspace.size);
size_t length = doup.layout.total_nr_elems(); size_t length = doup.layout.total_nr_elems();
float prob = param().drop_prob;
#define cb(DType) \ #define cb(DType) \
if (doup.layout.dtype == DType()) { \ if (doup.layout.dtype == DType()) { \
using T = typename DTypeTrait<DType>::ctype; \ using T = typename DTypeTrait<DType>::ctype; \
MEGDNN_DISPATCH_CPU_KERN_OPR(backward<T>( \ MEGDNN_DISPATCH_CPU_KERN_OPR(backward<T>( \
doup.ptr<T>(), dinp.ptr<T>(), mask.raw_ptr(), length, \ doup.ptr<T>(), dinp.ptr<T>(), mask.raw_ptr(), length, prob)); \
param().drop_prob)); \ return; \
return; \
} }
MEGDNN_FOREACH_COMPUTING_DTYPE_FLOAT(cb) MEGDNN_FOREACH_COMPUTING_DTYPE_FLOAT(cb)
#undef cb #undef cb
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册