From 1999307015a5035d3d3f0c34bfcf5cb3d7bfac02 Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Sat, 1 Jan 2022 20:19:15 +0800 Subject: [PATCH] feat(mgb/opr): add dropout kernel GitOrigin-RevId: d248bd20050c030196cf9febeefed8c6e7a5b06b --- dnn/include/megdnn/oprs/nn.h | 44 ++++++ dnn/scripts/opr_param_defs.py | 5 + dnn/src/common/dropout.cpp | 74 ++++++++++ dnn/src/common/handle_impl.h | 5 +- dnn/src/common/opr_trait.h | 2 + dnn/src/cuda/dropout/opr_impl.cpp | 118 ++++++++++++++++ dnn/src/cuda/dropout/opr_impl.h | 116 +++++++++++++++ dnn/src/cuda/handle_create.cpp | 1 + dnn/src/naive/dropout/opr_impl.cpp | 110 +++++++++++++++ dnn/src/naive/dropout/opr_impl.h | 49 +++++++ dnn/src/naive/handle.cpp | 1 + dnn/test/cuda/rng.cpp | 72 ++++++++++ dnn/test/naive/rng.cpp | 69 +++++++++ imperative/python/megengine/functional/nn.py | 12 +- imperative/python/src/ops.cpp | 10 +- .../test/unit/functional/test_functional.py | 49 +++++-- imperative/src/impl/ops/rng.cpp | 55 ++++++++ src/core/include/megbrain/ir/ops.td | 15 ++ src/opr/impl/rand.cpp | 132 ++++++++++++++++++ src/opr/impl/rand.sereg.h | 15 ++ src/opr/include/megbrain/opr/rand.h | 23 +++ src/opr/test/rand.cpp | 38 +++++ src/serialization/impl/schema.fbs | 1 + 23 files changed, 998 insertions(+), 18 deletions(-) create mode 100644 dnn/src/common/dropout.cpp create mode 100644 dnn/src/cuda/dropout/opr_impl.cpp create mode 100644 dnn/src/cuda/dropout/opr_impl.h create mode 100644 dnn/src/naive/dropout/opr_impl.cpp create mode 100644 dnn/src/naive/dropout/opr_impl.h diff --git a/dnn/include/megdnn/oprs/nn.h b/dnn/include/megdnn/oprs/nn.h index cfac433f..7188d941 100644 --- a/dnn/include/megdnn/oprs/nn.h +++ b/dnn/include/megdnn/oprs/nn.h @@ -2005,6 +2005,50 @@ protected: size_t workspace_in_bytes); }; +class DropoutBase : public OperatorBase { + DEF_OPR_IMPL_CTOR(DropoutBase, OperatorBase); + DEF_OPR_PARAM(Dropout); +}; + +class DropoutForward : public DropoutBase { + DEF_OPR_IMPL(DropoutForward, DropoutBase, 1, 2); + +public: + void deduce_layout(const TensorLayout& inp, TensorLayout& oup, TensorLayout& mask); + virtual void exec( + _megdnn_tensor_in inp, _megdnn_tensor_out oup, _megdnn_tensor_out mask, + _megdnn_workspace workspace) = 0; + virtual size_t get_workspace_in_bytes( + const TensorLayout& inp, const TensorLayout& oup, + const TensorLayout& mask) = 0; + virtual size_t get_mask_size_in_bytes(const TensorLayout& inp) = 0; + +protected: + void check_exec( + const TensorLayout& inp, const TensorLayout& oup, const TensorLayout& mask, + size_t workspace_in_bytes); +}; +using Dropout = DropoutForward; + +class DropoutBackward : public DropoutBase { + DEF_OPR_IMPL(DropoutBackward, DropoutBase, 2, 1); + +public: + void deduce_layout( + const TensorLayout& doup, const TensorLayout& mask, TensorLayout& dinp); + virtual void exec( + _megdnn_tensor_in doup, _megdnn_tensor_in mask, _megdnn_tensor_out dinp, + _megdnn_workspace workspace) = 0; + virtual size_t get_workspace_in_bytes( + const TensorLayout& doup, const TensorLayout& mask, + const TensorLayout& dinp) = 0; + +protected: + void check_exec( + const TensorLayout& doup, const TensorLayout& mask, + const TensorLayout& dinp, size_t workspace_in_bytes); +}; + } // namespace megdnn #include "megdnn/internal/opr_header_epilogue.h" diff --git a/dnn/scripts/opr_param_defs.py b/dnn/scripts/opr_param_defs.py index 08828e8e..9da6bbe9 100755 --- a/dnn/scripts/opr_param_defs.py +++ b/dnn/scripts/opr_param_defs.py @@ -1218,4 +1218,9 @@ PADDING_MODES = [Doc('REPLICATE = 0', 'aaaaaa|abcdefgh|hhhhhhh'), 
.add_fields('float32', 'eps', '1e-5f') .add_fields('uint64', 'normalized_dim', '1') .add_fields('uint64', 'normalized_size', '1') +) + +(pdef('Dropout') + .add_fields('float32', 'drop_prob', '0') + .add_fields('uint64', 'seed', '0') ) diff --git a/dnn/src/common/dropout.cpp b/dnn/src/common/dropout.cpp new file mode 100644 index 00000000..7327ca99 --- /dev/null +++ b/dnn/src/common/dropout.cpp @@ -0,0 +1,74 @@ +/** + * \file dnn/src/common/dropout.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. + */ + +#include +#include "megdnn/oprs.h" +#include "src/common/utils.h" + +namespace megdnn { + +void DropoutForward::deduce_layout( + const TensorLayout& inp, TensorLayout& oup, TensorLayout& mask) { + oup = inp; + size_t mask_size = get_mask_size_in_bytes(inp); + mask = TensorLayout(TensorShape({mask_size}), dtype::Byte()); +} + +void DropoutForward::check_exec( + const TensorLayout& inp, const TensorLayout& oup, const TensorLayout& mask, + size_t workspace_in_bytes) { + auto errmsg = [&]() { + return megdnn_layout_msg(inp) + ", " + megdnn_layout_msg(oup) + ", " + + megdnn_layout_msg(mask); + }; + MEGDNN_MARK_USED_VAR(errmsg); + + megdnn_assert_contiguous(inp); + megdnn_assert_contiguous(oup); + megdnn_assert_contiguous(mask); + megdnn_assert(inp.eq_layout(oup), "%s", errmsg().c_str()); + megdnn_assert(inp.dtype.category() == DTypeCategory::FLOAT); + + auto required_workspace_in_bytes = get_workspace_in_bytes(inp, oup, mask); + megdnn_assert(workspace_in_bytes >= required_workspace_in_bytes); + auto required_mask_size_in_bytes = get_mask_size_in_bytes(inp); + megdnn_assert(mask.total_nr_elems() >= required_mask_size_in_bytes); + megdnn_assert(mask.dtype == dtype::Byte()); +} + +void DropoutBackward::deduce_layout( + const TensorLayout& doup, const TensorLayout&, TensorLayout& dinp) { + dinp = doup; +} + +void DropoutBackward::check_exec( + const TensorLayout& doup, const TensorLayout& mask, const TensorLayout& dinp, + size_t workspace_in_bytes) { + auto errmsg = [&]() { + return megdnn_layout_msg(doup) + ", " + megdnn_layout_msg(mask) + ", " + + megdnn_layout_msg(dinp); + }; + MEGDNN_MARK_USED_VAR(errmsg); + + megdnn_assert_contiguous(doup); + megdnn_assert_contiguous(mask); + megdnn_assert_contiguous(dinp); + megdnn_assert(doup.eq_layout(dinp), "%s", errmsg().c_str()); + + auto required_workspace_in_bytes = get_workspace_in_bytes(doup, mask, dinp); + megdnn_assert(workspace_in_bytes >= required_workspace_in_bytes); + megdnn_assert(doup.dtype.category() == DTypeCategory::FLOAT); + megdnn_assert(mask.dtype == dtype::Byte()); + megdnn_assert(mask.ndim == 1); +} + +} // namespace megdnn diff --git a/dnn/src/common/handle_impl.h b/dnn/src/common/handle_impl.h index 2b5206f9..ff030f25 100644 --- a/dnn/src/common/handle_impl.h +++ b/dnn/src/common/handle_impl.h @@ -211,8 +211,9 @@ private: cb(PaddingForward) \ cb(PaddingBackward) \ cb(LayerNormForward) \ - cb(LayerNormBackward) - + cb(LayerNormBackward) \ + cb(DropoutForward) \ + cb(DropoutBackward) // clang-format on /*! 
diff --git a/dnn/src/common/opr_trait.h b/dnn/src/common/opr_trait.h index 851b5d8e..8b6145a4 100644 --- a/dnn/src/common/opr_trait.h +++ b/dnn/src/common/opr_trait.h @@ -137,6 +137,8 @@ DEF(LSQBackward, 7, true, false); DEF(Fill, 1, true, false); DEF(LayerNormForward, 6, true, true); DEF(LayerNormBackward, 8, true, true); +DEF(DropoutForward, 3, true, true); +DEF(DropoutBackward, 3, true, true); } // namespace megdnn // vim: syntax=cpp.doxygen diff --git a/dnn/src/cuda/dropout/opr_impl.cpp b/dnn/src/cuda/dropout/opr_impl.cpp new file mode 100644 index 00000000..d7349114 --- /dev/null +++ b/dnn/src/cuda/dropout/opr_impl.cpp @@ -0,0 +1,118 @@ +/** + * \file dnn/src/cuda/dropout/opr_impl.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. + */ + +#include "src/cuda/dropout/opr_impl.h" + +namespace megdnn { +namespace cuda { + +using Param = megdnn::Dropout::Param; + +struct DropoutTensorDesc : public TensorDesc { +public: + DropoutTensorDesc(const TensorLayout& layout) : TensorDesc() { + set_dropout_desc(layout); + } + void set_dropout_desc(const TensorLayout& layout) { + cudnnDataType_t cudnn_dtype; + switch (layout.dtype.enumv()) { + case DTypeEnum::Float32: + cudnn_dtype = CUDNN_DATA_FLOAT; + break; + case DTypeEnum::Float16: + cudnn_dtype = CUDNN_DATA_HALF; + break; + default: + megdnn_throw("dtype must be float16/float32"); + } + cudnn_check(cudnnSetTensor4dDescriptor( + desc, CUDNN_TENSOR_NCHW, cudnn_dtype, 1, 1, 1, + layout.total_nr_elems())); + } +}; + +size_t DropoutForwardImpl::get_mask_size_in_bytes(const TensorLayout& inp) { + size_t reserve_space_size_in_bytes = 0; + DropoutTensorDesc ddesc(inp); + cudnn_check( + cudnnDropoutGetReserveSpaceSize(ddesc.desc, &reserve_space_size_in_bytes)); + return reserve_space_size_in_bytes; +} + +void DropoutForwardImpl::exec( + _megdnn_tensor_in inp, _megdnn_tensor_out oup, _megdnn_tensor_out mask, + _megdnn_workspace workspace) { + check_exec(inp.layout, oup.layout, mask.layout, workspace.size); + uint64_t seed = param().seed; + float drop_prob = param().drop_prob; + + if (!dropout_status.initialized()) { + dropout_status.set(cudnn_handle(this->handle()), seed, drop_prob); + } + if (dropout_status.drop_prob != drop_prob) { + dropout_status.drop_prob = drop_prob; + dropout_status.restore_desc(cudnn_handle(this->handle())); + } + megdnn_assert(dropout_status.seed == seed); + + DropoutTensorDesc inp_desc(inp.layout), oup_desc(oup.layout); + auto&& op_desc = dropout_status.desc; + + cudnn_check(cudnnDropoutForward( + cudnn_handle(this->handle()), op_desc.desc, inp_desc.desc, inp.raw_ptr(), + oup_desc.desc, oup.raw_ptr(), mask.raw_ptr(), + mask.layout.total_nr_elems())); +} + +void DropoutBackwardImpl::exec( + _megdnn_tensor_in doup, _megdnn_tensor_in mask, _megdnn_tensor_out dinp, + _megdnn_workspace workspace) { + check_exec(doup.layout, mask.layout, dinp.layout, workspace.size); + +#if CUDNN_VERSION >= 7000 + size_t status_size_in_bytes = 0; + cudnn_check(cudnnDropoutGetStatesSize( + cudnn_handle(this->handle()), &status_size_in_bytes)); + + DropoutTensorDesc doup_desc(doup.layout), dinp_desc(dinp.layout); + op_desc.restore( + cudnn_handle(this->handle()), param().drop_prob, nullptr, + status_size_in_bytes, 0); + 
cudnn_check(cudnnDropoutBackward( + cudnn_handle(this->handle()), op_desc.desc, doup_desc.desc, doup.raw_ptr(), + dinp_desc.desc, dinp.raw_ptr(), mask.raw_ptr(), + mask.layout.total_nr_elems())); +#else + uint64_t seed = param().seed; + float drop_prob = param().drop_prob; + + if (!dropout_status.initialized()) { + dropout_status.set(cudnn_handle(this->handle()), seed, drop_prob); + } + if (dropout_status.drop_prob != drop_prob) { + dropout_status.drop_prob = drop_prob; + dropout_status.restore_desc(cudnn_handle(this->handle())); + } + + auto&& op_desc = dropout_status.desc; + DropoutTensorDesc doup_desc(doup.layout), dinp_desc(dinp.layout); + + cudnn_check(cudnnDropoutBackward( + cudnn_handle(this->handle()), op_desc.desc, doup_desc.desc, doup.raw_ptr(), + dinp_desc.desc, dinp.raw_ptr(), mask.raw_ptr(), + mask.layout.total_nr_elems())); +#endif +} + +} // namespace cuda +} // namespace megdnn +// vim: syntax=cpp.doxygen diff --git a/dnn/src/cuda/dropout/opr_impl.h b/dnn/src/cuda/dropout/opr_impl.h new file mode 100644 index 00000000..4db5df47 --- /dev/null +++ b/dnn/src/cuda/dropout/opr_impl.h @@ -0,0 +1,116 @@ +/** + * \file dnn/src/cuda/dropout/opr_impl.h + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. + */ +#pragma once +#include "megdnn/oprs.h" +#include "src/cuda/cudnn_wrapper.h" +#include "src/cuda/utils.h" + +namespace megdnn { +namespace cuda { + +class DropoutDesc { +public: + DropoutDesc() { cudnn_check(cudnnCreateDropoutDescriptor(&desc)); } + ~DropoutDesc() { cudnn_check(cudnnDestroyDropoutDescriptor(desc)); } + void set( + cudnnHandle_t handle, void* status, size_t states_size_in_bytes, + uint64_t seed, float drop_prob) { + cudnn_check(cudnnSetDropoutDescriptor( + desc, handle, drop_prob, status, states_size_in_bytes, seed)); + } + void restore( + cudnnHandle_t handle, float drop_prob, void* status, + size_t states_size_in_bytes, uint64_t seed) { +#if CUDNN_VERSION >= 7000 + cudnn_check(cudnnRestoreDropoutDescriptor( + desc, handle, drop_prob, status, states_size_in_bytes, 0)); +#else + // cudnnDropoutRestore is not support when cudnn version < 7000 + // so we set the dropoutDesc rather than restore + cudnn_check(cudnnSetDropoutDescriptor( + desc, handle, drop_prob, status, states_size_in_bytes, seed)); +#endif + } + cudnnDropoutDescriptor_t desc; +}; + +class DropoutStatus { + void* status; + uint64_t status_size; + uint64_t seed; + float drop_prob; + DropoutDesc desc; + +public: + DropoutStatus() { + status = nullptr; + status_size = 0; + } + ~DropoutStatus() { + if (status != nullptr) + cuda_check(cudaFree(status)); + } + void set(cudnnHandle_t handle, uint64_t seed, float drop_prob) { + this->seed = seed; + this->drop_prob = drop_prob; + cudnn_check(cudnnDropoutGetStatesSize(handle, &status_size)); + cuda_check(cudaMalloc(&status, status_size)); + desc.set(handle, status, status_size, seed, drop_prob); + } + void restore_desc(cudnnHandle_t handle) { + desc.restore(handle, drop_prob, status, status_size, seed); + } + bool initialized() { return status != nullptr; } + friend class DropoutForwardImpl; + friend class DropoutBackwardImpl; +}; + +// similar to RNG operator, dropout operator also have status +class DropoutForwardImpl final : public 
DropoutForward {
+    DropoutStatus dropout_status;
+
+public:
+    using DropoutForward::DropoutForward;
+    void exec(
+            _megdnn_tensor_in inp, _megdnn_tensor_out oup, _megdnn_tensor_out mask,
+            _megdnn_workspace workspace) override;
+    size_t get_mask_size_in_bytes(const TensorLayout& inp) override;
+    size_t get_workspace_in_bytes(
+            const TensorLayout&, const TensorLayout&, const TensorLayout&) override {
+        return 0;
+    }
+};
+
+class DropoutBackwardImpl final : public DropoutBackward {
+#if CUDNN_VERSION >= 7000
+    DropoutDesc op_desc;
+#else
+    // cudnnRestoreDropoutDescriptor is not supported when the cuDNN version is
+    // below 7000, so we need to save the dropout status and set the dropout
+    // descriptor rather than restore it
+    DropoutStatus dropout_status;
+#endif
+
+public:
+    using DropoutBackward::DropoutBackward;
+    void exec(
+            _megdnn_tensor_in doup, _megdnn_tensor_in mask, _megdnn_tensor_out dinp,
+            _megdnn_workspace workspace) override;
+    size_t get_workspace_in_bytes(
+            const TensorLayout&, const TensorLayout&, const TensorLayout&) override {
+        return 0;
+    }
+};
+} // namespace cuda
+} // namespace megdnn
+
+// vim: syntax=cpp.doxygen
diff --git a/dnn/src/cuda/handle_create.cpp b/dnn/src/cuda/handle_create.cpp
index 6738740a..df3c5ccb 100644
--- a/dnn/src/cuda/handle_create.cpp
+++ b/dnn/src/cuda/handle_create.cpp
@@ -34,6 +34,7 @@
 #include "src/cuda/deformable_conv/opr_impl.h"
 #include "src/cuda/deformable_ps_roi_pooling/opr_impl.h"
 #include "src/cuda/dot/opr_impl.h"
+#include "src/cuda/dropout/opr_impl.h"
 #include "src/cuda/elemwise/opr_impl.h"
 #include "src/cuda/elemwise_multi_type/opr_impl.h"
 #include "src/cuda/eye/opr_impl.h"
diff --git a/dnn/src/naive/dropout/opr_impl.cpp b/dnn/src/naive/dropout/opr_impl.cpp
new file mode 100644
index 00000000..64b359d4
--- /dev/null
+++ b/dnn/src/naive/dropout/opr_impl.cpp
@@ -0,0 +1,110 @@
+/**
+ * \file dnn/src/naive/dropout/opr_impl.cpp
+ * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
+ *
+ * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ * implied.
+ */
+#include "src/naive/dropout/opr_impl.h"
+#include
+#include
+#include
+#include "src/common/utils.h"
+#include "src/naive/handle.h"
+
+using namespace megdnn;
+using namespace naive;
+using namespace std;
+namespace {
+
+using Param = megdnn::Dropout::Param;
+
+dt_float32 get_random_number(uint64_t x) {
+    union {
+        uint32_t i;
+        dt_float32 f;
+    } u;
+    u.i = (0x7F << 23) | (x >> 41);
+    return 2 - u.f;
+}
+
+template <typename T>
+void forward(
+        T* inp, T* oup, void* raw_reserved, size_t len, Xoroshiro128plus& rng,
+        float drop_prob) {
+    uint8_t* reserved = reinterpret_cast<uint8_t*>(raw_reserved);
+    float scale = 1.0f / (1.0f - drop_prob);
+    for (size_t i = 0; i < len; ++i) {
+        float rn = get_random_number(rng());
+        reserved[i] = rn < drop_prob ? 0 : 1;
+        oup[i] = static_cast<T>(reserved[i] ? static_cast<float>(inp[i]) * scale : 0.f);
+    }
+}
+
+template <typename T>
+void backward(T* doup, T* dinp, void* raw_reserved, size_t len, float drop_prob) {
+    uint8_t* reserved = reinterpret_cast<uint8_t*>(raw_reserved);
+    float scale = 1.0f / (1.0f - drop_prob);
+    for (size_t i = 0; i < len; ++i) {
+        dinp[i] = static_cast<T>(
+                reserved[i] ? static_cast<float>(doup[i]) * scale : 0.f);
+    }
+}
+
+} // namespace
+
+namespace megdnn {
+namespace naive {
+
+size_t DropoutForwardImpl::get_mask_size_in_bytes(const TensorLayout& inp) {
+    return inp.total_nr_elems();
+}
+
+void DropoutForwardImpl::exec(
+        _megdnn_tensor_in inp, _megdnn_tensor_out oup, _megdnn_tensor_out mask,
+        _megdnn_workspace workspace) {
+    check_exec(inp.layout, oup.layout, mask.layout, workspace.size);
+    size_t length = inp.layout.total_nr_elems();
+    uint64_t seed = param().seed;
+
+    m_rng.ensure_seed(seed);
+
+#define cb(DType)                                                          \
+    if (inp.layout.dtype == DType()) {                                     \
+        using T = typename DTypeTrait<DType>::ctype;                       \
+        MEGDNN_DISPATCH_CPU_KERN_OPR(forward<T>(                           \
+                inp.ptr<T>(), oup.ptr<T>(), mask.raw_ptr(), length, m_rng, \
+                param().drop_prob));                                       \
+        return;                                                            \
+    }
+    MEGDNN_FOREACH_COMPUTING_DTYPE_FLOAT(cb)
+#undef cb
+    megdnn_throw("bad dtype");
+}
+
+void DropoutBackwardImpl::exec(
+        _megdnn_tensor_in doup, _megdnn_tensor_in mask, _megdnn_tensor_out dinp,
+        _megdnn_workspace workspace) {
+    check_exec(doup.layout, mask.layout, dinp.layout, workspace.size);
+    size_t length = doup.layout.total_nr_elems();
+
+#define cb(DType)                                                          \
+    if (doup.layout.dtype == DType()) {                                    \
+        using T = typename DTypeTrait<DType>::ctype;                       \
+        MEGDNN_DISPATCH_CPU_KERN_OPR(backward<T>(                          \
+                doup.ptr<T>(), dinp.ptr<T>(), mask.raw_ptr(), length,      \
+                param().drop_prob));                                       \
+        return;                                                            \
+    }
+    MEGDNN_FOREACH_COMPUTING_DTYPE_FLOAT(cb)
+#undef cb
+    megdnn_throw("bad dtype");
+}
+
+} // namespace naive
+} // namespace megdnn
+// vim: syntax=cpp.doxygen
diff --git a/dnn/src/naive/dropout/opr_impl.h b/dnn/src/naive/dropout/opr_impl.h
new file mode 100644
index 00000000..f40de4e2
--- /dev/null
+++ b/dnn/src/naive/dropout/opr_impl.h
@@ -0,0 +1,49 @@
+/**
+ * \file dnn/src/naive/dropout/opr_impl.h
+ * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
+ *
+ * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ * implied.
+ */ +#pragma once +#include "megdnn/oprs.h" +#include "src/naive/rng/opr_impl.h" + +namespace megdnn { +namespace naive { + +class DropoutForwardImpl final : public DropoutForward { + Xoroshiro128plus m_rng; + +public: + using DropoutForward::DropoutForward; + void exec( + _megdnn_tensor_in inp, _megdnn_tensor_out oup, _megdnn_tensor_out mask, + _megdnn_workspace workspace) override; + size_t get_mask_size_in_bytes(const TensorLayout& inp) override; + size_t get_workspace_in_bytes( + const TensorLayout&, const TensorLayout&, const TensorLayout&) override { + return 0; + } +}; + +class DropoutBackwardImpl final : public DropoutBackward { +public: + using DropoutBackward::DropoutBackward; + void exec( + _megdnn_tensor_in doup, _megdnn_tensor_in mask, _megdnn_tensor_out dinp, + _megdnn_workspace workspace) override; + size_t get_workspace_in_bytes( + const TensorLayout&, const TensorLayout&, const TensorLayout&) override { + return 0; + } +}; + +} // namespace naive +} // namespace megdnn + +// vim: syntax=cpp.doxygen diff --git a/dnn/src/naive/handle.cpp b/dnn/src/naive/handle.cpp index 1ff6675f..2a705335 100644 --- a/dnn/src/naive/handle.cpp +++ b/dnn/src/naive/handle.cpp @@ -36,6 +36,7 @@ #include "src/naive/deformable_conv/opr_impl.h" #include "src/naive/deformable_ps_roi_pooling/opr_impl.h" #include "src/naive/dot/opr_impl.h" +#include "src/naive/dropout/opr_impl.h" #include "src/naive/elemwise/opr_impl.h" #include "src/naive/elemwise_multi_type/opr_impl.h" #include "src/naive/eye/opr_impl.h" diff --git a/dnn/test/cuda/rng.cpp b/dnn/test/cuda/rng.cpp index 0a5a549d..3b0b705d 100644 --- a/dnn/test/cuda/rng.cpp +++ b/dnn/test/cuda/rng.cpp @@ -193,6 +193,70 @@ void run_shuffle(Handle* handle, bool bwd_flag) { run({6, 3}); } +template +void run_dropout(Handle* handle) { + using ctype = typename DTypeTrait::ctype; + auto run = [&](TensorShape shape, float drop_prob) { + auto fwd = handle->create_operator(); + auto bwd = handle->create_operator(); + fwd->param().drop_prob = drop_prob; + bwd->param().drop_prob = drop_prob; + double scale = 1.0 / (1.0 - drop_prob); + + TensorLayout inp_lay{shape, T()}; + TensorLayout oup_lay{shape, T()}; + TensorLayout mask_lay{{fwd->get_mask_size_in_bytes(inp_lay)}, dtype::Byte()}; + TensorLayout doup_lay{shape, T()}; + TensorLayout dinp_lay{shape, T()}; + TensorLayout fwd_ws_lay{ + {fwd->get_workspace_in_bytes(inp_lay, oup_lay, mask_lay)}, + dtype::Byte()}; + TensorLayout bwd_ws_lay{ + {bwd->get_workspace_in_bytes(doup_lay, mask_lay, dinp_lay)}, + dtype::Byte()}; + + SyncedTensor inp(handle, inp_lay); + SyncedTensor oup(handle, oup_lay); + SyncedTensor::ctype> mask(handle, mask_lay); + SyncedTensor doup(handle, doup_lay); + SyncedTensor dinp(handle, dinp_lay); + SyncedTensor::ctype> fwd_ws(handle, fwd_ws_lay); + SyncedTensor::ctype> bwd_ws(handle, bwd_ws_lay); + + for (size_t i = 0; i < inp.layout().total_nr_elems(); ++i) { + inp.ptr_mutable_host()[i] = 1; + doup.ptr_mutable_host()[i] = 1; + } + + fwd->exec( + inp.tensornd_dev(), oup.tensornd_dev(), mask.tensornd_dev(), + {fwd_ws.ptr_mutable_dev(), fwd_ws.layout().total_nr_elems()}); + size_t droped_cnt = 0; + for (size_t i = 0; i < inp.layout().total_nr_elems(); ++i) { + ASSERT_TRUE( + oup.ptr_host()[i] == 0 || + oup.ptr_host()[i] == static_cast(scale)); + if (oup.ptr_host()[i] == 0) { + droped_cnt++; + } + } + float real_drop = droped_cnt * 1.0 / inp.layout().total_nr_elems(); + ASSERT_LT(abs(drop_prob - real_drop), 1e-2); + +#if CUDNN_VERSION >= 7000 + bwd->exec( + doup.tensornd_dev(), mask.tensornd_dev(), 
dinp.tensornd_dev(), + {bwd_ws.ptr_mutable_dev(), bwd_ws.layout().total_nr_elems()}); + for (size_t i = 0; i < inp.layout().total_nr_elems(); ++i) { + ASSERT_TRUE(oup.ptr_host()[i] == dinp.ptr_host()[i]); + } +#endif + }; + + run({32, 32, 32, 32}, 0.2); + run({100000}, 0.3); +} + } // anonymous namespace TEST_F(CUDA, UNIFORM_RNG_F32) { @@ -290,6 +354,14 @@ TEST_F(CUDA, SHUFFLE_RNG_BWD_F16) { run_shuffle(handle_cuda(), true); } +TEST_F(CUDA, DROPOUT_F32) { + run_dropout(handle_cuda()); +} + +TEST_F(CUDA, DROPOUT_F16) { + run_dropout(handle_cuda()); +} + } // namespace test } // namespace megdnn diff --git a/dnn/test/naive/rng.cpp b/dnn/test/naive/rng.cpp index b5a827ad..40f1c520 100644 --- a/dnn/test/naive/rng.cpp +++ b/dnn/test/naive/rng.cpp @@ -231,6 +231,67 @@ void run_shuffle(Handle* handle, bool bwd_flag) { run({10}); run({6, 3}); } + +template +void run_dropout(Handle* handle) { + using ctype = typename DTypeTrait::ctype; + auto run = [&](TensorShape shape, float drop_prob) { + auto fwd = handle->create_operator(); + auto bwd = handle->create_operator(); + fwd->param().drop_prob = drop_prob; + bwd->param().drop_prob = drop_prob; + double scale = 1.0 / (1.0 - drop_prob); + + TensorLayout inp_lay{shape, T()}; + TensorLayout oup_lay{shape, T()}; + TensorLayout mask_lay{{fwd->get_mask_size_in_bytes(inp_lay)}, dtype::Byte()}; + TensorLayout doup_lay{shape, T()}; + TensorLayout dinp_lay{shape, T()}; + TensorLayout fwd_ws_lay{ + {fwd->get_workspace_in_bytes(inp_lay, oup_lay, mask_lay)}, + dtype::Byte()}; + TensorLayout bwd_ws_lay{ + {bwd->get_workspace_in_bytes(doup_lay, mask_lay, dinp_lay)}, + dtype::Byte()}; + + Tensor inp(handle, inp_lay); + Tensor oup(handle, oup_lay); + Tensor::ctype> mask(handle, mask_lay); + Tensor doup(handle, doup_lay); + Tensor dinp(handle, dinp_lay); + Tensor::ctype> fwd_ws(handle, fwd_ws_lay); + Tensor::ctype> bwd_ws(handle, bwd_ws_lay); + + for (size_t i = 0; i < inp.layout().total_nr_elems(); ++i) { + inp.ptr()[i] = 1; + doup.ptr()[i] = 1; + } + + fwd->exec( + inp.tensornd(), oup.tensornd(), mask.tensornd(), + {fwd_ws.ptr(), fwd_ws.layout().total_nr_elems()}); + size_t droped_cnt = 0; + for (size_t i = 0; i < inp.layout().total_nr_elems(); ++i) { + ASSERT_TRUE(oup.ptr()[i] == 0 || oup.ptr()[i] == static_cast(scale)); + if (oup.ptr()[i] == 0) { + droped_cnt++; + } + } + float real_drop = droped_cnt * 1.0 / inp.layout().total_nr_elems(); + ASSERT_LT(abs(drop_prob - real_drop), 1e-2); + + bwd->exec( + doup.tensornd(), mask.tensornd(), dinp.tensornd(), + {bwd_ws.ptr(), bwd_ws.layout().total_nr_elems()}); + for (size_t i = 0; i < inp.layout().total_nr_elems(); ++i) { + ASSERT_TRUE(oup.ptr()[i] == dinp.ptr()[i]); + } + }; + + run({32, 32, 32, 32}, 0.2); + run({100000}, 0.3); +} + } // namespace TEST_F(NAIVE, UNIFORM_RNG_F32) { @@ -309,6 +370,14 @@ TEST_F(NAIVE, SHUFFLE_RNG_BWD_F16) { run_shuffle(handle(), true); } +TEST_F(NAIVE, DROPOUT_F32) { + run_dropout(handle()); +} + +TEST_F(NAIVE, DROPOUT_F16) { + run_dropout(handle()); +} + } // namespace test } // namespace megdnn diff --git a/imperative/python/megengine/functional/nn.py b/imperative/python/megengine/functional/nn.py index 9dfd4298..b0e4f52e 100644 --- a/imperative/python/megengine/functional/nn.py +++ b/imperative/python/megengine/functional/nn.py @@ -13,10 +13,12 @@ from typing import NamedTuple, Optional, Sequence, Tuple, Union from ..core import _config from ..core._imperative_rt.core2 import apply, dtype_promotion from ..core._imperative_rt.ops import SubgraphBuilder as _SubgraphBuilder +from 
..core._imperative_rt.ops import get_global_rng_seed as _get_global_rng_seed from ..core.ops import builtin from ..core.ops.builtin import ( BatchNorm, Dimshuffle, + Dropout, Elemwise, GetVarShape, Identity, @@ -39,7 +41,6 @@ from ..core.tensor.utils import ( from ..device import get_default_device from ..distributed import WORLD, is_distributed from ..jit import exclude_from_trace -from ..random import uniform from ..tensor import Tensor from ..utils.deprecation import deprecated_func from ..utils.tuple_function import _pair, _pair_nonzero, _triple, _triple_nonzero @@ -1503,12 +1504,9 @@ def dropout(inp: Tensor, drop_prob: float, training: bool = True) -> Tensor: return inp # model in training mode, e.g. model.train() - rv = uniform(size=inp.shape) - mask = rv > drop_prob - ret = inp * mask.astype(inp.dtype) - ret *= 1 / (1 - drop_prob) - - return ret + op = Dropout(drop_prob=drop_prob, seed=_get_global_rng_seed(), handle=0) + outputs = apply(op, inp) + return outputs[0] def one_hot(inp: Tensor, num_classes: int) -> Tensor: diff --git a/imperative/python/src/ops.cpp b/imperative/python/src/ops.cpp index 515f8ead..be184cff 100644 --- a/imperative/python/src/ops.cpp +++ b/imperative/python/src/ops.cpp @@ -567,7 +567,15 @@ void init_ops(py::module m) { rng::delete_handle(handle); }, py::call_guard()); - m.def("set_global_rng_seed", &rng::set_global_rng_seed); + m.def("set_global_rng_seed", [](uint64_t seed) -> void { + mgb_assert( + python::interpreter_for_py->check_available(), + "set global random seed failed since imperative interpreter has been " + "destroyed"); + python::interpreter_for_py->sync(); + mgb::CompNode::sync_all(); + rng::set_global_rng_seed(seed); + }); m.def("get_global_rng_seed", &rng::get_global_rng_seed); m.def("get_rng_handle_compnode", &rng::get_rng_handle_compnode); diff --git a/imperative/python/test/unit/functional/test_functional.py b/imperative/python/test/unit/functional/test_functional.py index 39b3eeba..3176c057 100644 --- a/imperative/python/test/unit/functional/test_functional.py +++ b/imperative/python/test/unit/functional/test_functional.py @@ -59,14 +59,47 @@ def test_where(): def test_dropout(): - # test training mode - data = tensor(np.ones(10000000, dtype=np.float32)) - out = F.nn.dropout(data, 1.0 / 3.0, training=True) - assert not out.numpy().all() - - # test eval mode - out = F.nn.dropout(data, 1.0 / 3.0, training=False) - assert out.numpy().all() + from megengine.autodiff import GradManager + from megengine.core._imperative_rt.ops import set_global_rng_seed + + def test_dropout_with_shape(shape, rate): + data = tensor(np.ones(shape, dtype=np.float32)) + gm = GradManager().attach([data]) + with gm: + out = F.nn.dropout(data, rate, training=True) + gm.backward(out, tensor(np.ones(shape, dtype=np.float32))) + assert not out.numpy().all() + np.testing.assert_allclose(out.numpy(), data.grad.numpy(), 1e-7, 1e-7) + + def test_multiple_dropout(shape, rate): + data = tensor(np.ones(shape, dtype=np.float32)) + gm = GradManager().attach([data]) + with gm: + out1 = F.nn.dropout(data, rate, training=True) + out2 = F.nn.dropout(out1, rate, training=True) + out3 = F.nn.dropout(out2, rate, training=True) + gm.backward(out3, tensor(np.ones(shape, dtype=np.float32))) + np.testing.assert_allclose(out3.numpy(), data.grad.numpy(), 1e-7, 1e-7) + + def test_dropout_seed(shape, rate): + data = tensor(np.random.randn(*shape), dtype="float32") + set_global_rng_seed(111) + out1 = F.nn.dropout(data, rate, training=True) + out2 = F.nn.dropout(data, rate, training=True) + 
assert not (out1.numpy() == out2.numpy()).all() + + set_global_rng_seed(111) + out3 = F.nn.dropout(data, rate, training=True) + assert (out1.numpy() == out3.numpy()).all() + + set_global_rng_seed(222) + out4 = F.nn.dropout(data, rate, training=True) + assert not (out1.numpy() == out4.numpy()).all() + + test_dropout_with_shape([13, 17, 63, 21], 0.4) + test_dropout_with_shape([16, 32, 64], 0.3) + test_multiple_dropout([1024], 0.2) + test_dropout_seed([16, 32], 0.2) def test_matinv(): diff --git a/imperative/src/impl/ops/rng.cpp b/imperative/src/impl/ops/rng.cpp index 232629f8..311a780a 100644 --- a/imperative/src/impl/ops/rng.cpp +++ b/imperative/src/impl/ops/rng.cpp @@ -282,6 +282,21 @@ struct OpMeth { } }; +template <> +struct OpMeth { + using DnnOp = megdnn::Dropout; + using Param = DnnOp::Param; + using OpNode = mgb::opr::Dropout; + static Param make_param(const Dropout& opdef) { + auto handle_seed = RNGDnnOpManager::get_seed(opdef.handle); + mgb_assert( + handle_seed == opdef.seed, + "inconsistent dropout seed: dropout op: %lu handle: %lu", handle_seed, + opdef.seed); + return {opdef.drop_prob, handle_seed}; + } +}; + template struct _InferLayout; @@ -482,6 +497,26 @@ SmallVector infer_output_attrs( return dests; } +template <> +SmallVector infer_output_attrs( + const OpDef& op, const SmallVector& inputs) { + SmallVector dests(2); + auto&& cn = inputs[0]->comp_node(); + + dests[0].comp_node = cn; + dests[0].layout = TensorLayout(inputs[0]->layout()); + dests[0].layout.dtype = inputs[0]->layout().dtype; + + auto get_mask_size = [&]() -> size_t { + auto dnn_handle = MegDNNHandle::get(CompNodeEnv::from_comp_node(cn)).handle(); + return dnn_handle->create_operator()->get_mask_size_in_bytes( + inputs[0]->layout()); + }; + dests[1].comp_node = cn; + dests[1].layout = TensorLayout(TensorShape({get_mask_size()}), dtype::Byte()); + return dests; +} + template std::tuple, SmallVector> infer_output_mem_desc( const OpDef& def, const SmallVector& inputs_tensors, @@ -559,6 +594,25 @@ std::tuple, bool> infer_output_attrs_fallible< return {dests, true}; } +template <> +std::tuple, bool> infer_output_attrs_fallible( + const OpDef& op, const SmallVector& inputs) { + SmallVector dests(2); + auto cn = inputs[0].comp_node; + dests[0].comp_node = cn; + dests[0].layout = TensorLayout(inputs[0].layout); + dests[0].layout.dtype = inputs[0].layout.dtype; + + auto get_mask_size = [&]() -> size_t { + auto dnn_handle = MegDNNHandle::get(CompNodeEnv::from_comp_node(cn)).handle(); + return dnn_handle->create_operator()->get_mask_size_in_bytes( + inputs[0].layout); + }; + dests[1].comp_node = cn; + dests[1].layout = TensorLayout(TensorShape({get_mask_size()}), dtype::Byte()); + return {dests, true}; +} + } // anonymous namespace Handle new_handle(CompNode comp_node, uint64_t seed) { @@ -599,6 +653,7 @@ REG_RNG_OP(PermutationRNG, SymbolVar) REG_RNG_OP(PoissonRNG, SymbolVar) REG_RNG_OP(BetaRNG, SymbolVar) REG_RNG_OP(ShuffleRNG, SymbolVarArray) +REG_RNG_OP(Dropout, SymbolVarArray) #undef REG_RNG_OP } // namespace mgb::imperative::rng diff --git a/src/core/include/megbrain/ir/ops.td b/src/core/include/megbrain/ir/ops.td index a300ccf8..30c795ad 100644 --- a/src/core/include/megbrain/ir/ops.td +++ b/src/core/include/megbrain/ir/ops.td @@ -433,4 +433,19 @@ def LRN: MgbHashableOp<"LRN", [LRNParam]>; def LayerNorm: MgbHashableOp<"LayerNorm", [LayerNormParam]>; +def Dropout: MgbHashableOp<"Dropout", [DropoutParam]> { + let extraArguments = (ins + MgbSizeTAddr:$handle + ); + let hashFunction = [{ + return 
mgb::hash_pair_combine( + mgb::hash($_self.dyn_typeinfo()), + mgb::hash_pair_combine( + mgb::hash($_self.drop_prob), + mgb::hash($_self.handle)) + ); + }]; + let cmpFunction = [{return $0.handle == $1.handle && $0.drop_prob == $1.drop_prob;}]; + +} #endif // MGB_OPS diff --git a/src/opr/impl/rand.cpp b/src/opr/impl/rand.cpp index 6be879f1..043a15e5 100644 --- a/src/opr/impl/rand.cpp +++ b/src/opr/impl/rand.cpp @@ -201,6 +201,8 @@ template class RNGOprBase<::megdnn::BetaRNG>; template class RNGOprBase<::megdnn::PoissonRNG>; template class RNGOprBase<::megdnn::ShuffleRNGForward>; template class RNGOprBase<::megdnn::ShuffleRNGBackward>; +template class RNGOprBase<::megdnn::DropoutForward>; +template class RNGOprBase<::megdnn::DropoutBackward>; #if MGB_ENABLE_GRAD IMPL(GaussianRNG); IMPL(UniformRNG); @@ -300,4 +302,134 @@ MGB_IMPL_OPR_GRAD(ShuffleRNGForward) { MGB_DYN_TYPE_OBJ_FINAL_IMPL(ShuffleRNGBackward); MEGDNN_OPR_INIT3(ShuffleRNGBackward, "shuffle_rng_bwd", 2, true) +/* ================= DropoutForward ================= */ + +MGB_DYN_TYPE_OBJ_FINAL_IMPL(DropoutForward); + +DropoutForward::DropoutForward( + VarNode* inp, const Param& param, const OperatorNodeConfig& config) + : Super({inp->owner_graph(), config, "dropout", {inp}}, param) { + add_input({inp}); + add_output(None)->dtype(inp->dtype()).add_flag(VarNode::Flag::ALLOW_EMPTY_SHAPE); + add_output(None)->dtype(dtype::Byte()).add_flag(VarNode::Flag::ALLOW_EMPTY_SHAPE); + cg::add_workspace_output(this); + add_equivalence_component>(this); +} + +SymbolVarArray DropoutForward::make( + SymbolVar inp, const Param& param, const OperatorNodeConfig& config) { + auto node = inp.node()->owner_graph()->insert_opr( + std::make_unique(inp.node(), param, config)); + mgb_assert(node->output().size() == 3); + return {node->output(0), node->output(1)}; +} + +void DropoutForward::init_output_static_infer_desc() { + using namespace cg::static_infer; + auto&& mgr = owner_graph()->static_infer_manager(); + mgr.register_shape_infer(output(0), ShapeInferDesc::make_identity(input(0))); + + auto infer_mask = [this](TensorShape& dest, const InpVal& iv) { + ensure_megdnn_opr(); + dest.ndim = 1; + dest.shape[0] = m_dnn_opr->get_mask_size_in_bytes( + {iv.val[0].shape(), input(0)->dtype()}); + return true; + }; + mgr.register_shape_infer( + output(1), {SourceType::DEP, {{input(0), DepType::SHAPE}}, infer_mask}); + + auto infer_wk = [this](TensorShape& dest, const InpVal& inp) { + ensure_megdnn_opr(); + dest.ndim = 1; + dest.shape[0] = m_dnn_opr->get_workspace_in_bytes( + {inp.val[0].shape(), input(0)->dtype()}, + {output(0)->shape(), output(0)->dtype()}, + {output(1)->shape(), output(1)->dtype()}); + return true; + }; + mgr.register_shape_infer( + output(2), {SourceType::DEP, {{input(0), DepType::SHAPE}}, infer_wk}); +} + +void DropoutForward::add_input_layout_constraint() { + input(0)->add_layout_constraint_contiguous(); +}; + +void DropoutForward::scn_do_execute() { + auto&& ret = output(0); + if (ret->layout().is_empty()) { + mgb_assert(ret->dev_tensor().empty()); + return; + } + m_dnn_opr->exec( + input(0)->dev_tensor().as_megdnn(), output(0)->dev_tensor().as_megdnn(), + output(1)->dev_tensor().as_megdnn(), + get_megdnn_workspace_from_var(output(2))); +} + +cg::OperatorNodeBase::NodeProp* DropoutForward::do_make_node_prop() const { + auto prop = Super::do_make_node_prop(); + prop->add_flag(NodeProp::Flag::IMPURE_FUNC); + for (auto i : input()) { + prop->add_dep_type_existing_var(i, NodeProp::DepType::VALUE_ALLOW_EMPTY); + } + return prop; +} + +#if 
MGB_ENABLE_GRAD
+MGB_IMPL_OPR_GRAD(DropoutForward) {
+    SymbolVar grad = DropoutBackward::make(out_grad[0], opr.output(1), opr.param());
+    VarNodeArray ret;
+    ret.push_back(grad.node());
+    return ret;
+}
+#endif
+
+/* ==================== DropoutBackward ==================== */
+
+MGB_DYN_TYPE_OBJ_FINAL_IMPL(DropoutBackward);
+
+DropoutBackward::DropoutBackward(
+        VarNode* doup, VarNode* mask, const Param& param,
+        const OperatorNodeConfig& config)
+        : Super({doup->owner_graph(), config, "dropout_backward", {doup, mask}}, 0,
+                true) {
+    init_megdnn_opr(*this, param);
+    add_input({doup, mask});
+}
+
+SymbolVar DropoutBackward::make(
+        SymbolVar doup, SymbolVar mask, const Param& param,
+        const OperatorNodeConfig& config) {
+    return doup.insert_single_output_opr<DropoutBackward>(
+            doup.node(), mask.node(), param, config);
+}
+
+void DropoutBackward::init_output_static_infer_desc() {
+    using namespace cg::static_infer;
+    auto&& mgr = owner_graph()->static_infer_manager();
+    mgr.register_shape_infer(output(0), ShapeInferDesc::make_identity(input(0)));
+    this->init_output_static_infer_desc_workspace(false);
+}
+
+void DropoutBackward::init_output_dtype() {
+    output(0)->dtype(input(0)->dtype());
+}
+
+size_t DropoutBackward::get_workspace_size_bytes(
+        const TensorShapeArray& input_shapes,
+        const TensorShapeArray& output_shapes) const {
+    return megdnn_opr()->get_workspace_in_bytes(
+            {input_shapes[0], input(0)->dtype(), input(0)->format()},
+            {input_shapes[1], input(1)->dtype(), input(1)->format()},
+            {output_shapes[0], output(0)->dtype(), output(0)->format()});
+}
+
+void DropoutBackward::scn_do_execute() {
+    megdnn_opr()->exec(
+            input(0)->dev_tensor().as_megdnn(), input(1)->dev_tensor().as_megdnn(),
+            output(0)->dev_tensor().as_megdnn(), {});
+}
+
 // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
diff --git a/src/opr/impl/rand.sereg.h b/src/opr/impl/rand.sereg.h
index fe3bd8b1..869fb72c 100644
--- a/src/opr/impl/rand.sereg.h
+++ b/src/opr/impl/rand.sereg.h
@@ -29,6 +29,19 @@ struct OprMaker {
         return out[0].node()->owner_opr();
     }
 };
+
+// OprMaker in MGB_SEREG_OPR only supports oprs with a unique output
+template <>
+struct OprMaker<opr::DropoutForward, 1> {
+    using Param = opr::DropoutForward::Param;
+    static cg::OperatorNodeBase* make(
+            const Param& param, const cg::VarNodeArray& i, ComputingGraph& graph,
+            const OperatorNodeConfig& config) {
+        MGB_MARK_USED_VAR(graph);
+        return opr::DropoutForward::make(i[0], param, config)[0].node()->owner_opr();
+    }
+};
+
 } // namespace serialization
 
 namespace opr {
@@ -43,6 +56,8 @@
 MGB_SEREG_OPR(PermutationRNG, 1);
 MGB_SEREG_OPR(BetaRNG, 2);
 MGB_SEREG_OPR(ShuffleRNG, 1);
 MGB_SEREG_OPR(ShuffleRNGBackward, 3);
+MGB_SEREG_OPR(Dropout, 1);
+MGB_SEREG_OPR(DropoutBackward, 2);
 } // namespace opr
 } // namespace mgb
diff --git a/src/opr/include/megbrain/opr/rand.h b/src/opr/include/megbrain/opr/rand.h
index 257d11d6..e7199ccf 100644
--- a/src/opr/include/megbrain/opr/rand.h
+++ b/src/opr/include/megbrain/opr/rand.h
@@ -87,6 +87,7 @@ _DEFINE_RNG_OPR_WITH_INPUT_CLASS(PoissonRNG)
 #undef _OUTPUTS
 #define _OUTPUTS SymbolVarArray
 _DEFINE_RNG_OPR_WITH_INPUT_CLASS(ShuffleRNGForward)
+_DEFINE_RNG_OPR_WITH_INPUT_CLASS(DropoutForward)
 #undef _OUTPUTS
 #undef _INPUTS
 
@@ -108,6 +109,8 @@ using PermutationRNG = intl::PermutationRNG;
 using PoissonRNG = intl::PoissonRNG;
 using BetaRNG = intl::BetaRNG;
 using ShuffleRNG = intl::ShuffleRNGForward;
+using Dropout = intl::DropoutForward;
+using DropoutForward = intl::DropoutForward;
 
 MGB_DEFINE_OPR_CLASS_WITH_EXPORT(
         ShuffleRNGBackward,
intl::MegDNNOprWrapperBwd) // { @@ -121,6 +124,26 @@ public: const Param& param = {}, const OperatorNodeConfig& config = {}); }; +MGB_DEFINE_OPR_CLASS_WITH_EXPORT( + DropoutBackward, intl::MegDNNOprWrapperBwd) // { +public: + MGE_WIN_DECLSPEC_FUC DropoutBackward( + VarNode* doup, VarNode* mask, const Param& param, + const OperatorNodeConfig& config); + + MGE_WIN_DECLSPEC_FUC static SymbolVar make( + SymbolVar doup, SymbolVar mask, const Param& param = {}, + const OperatorNodeConfig& config = {}); + +private: + void init_output_static_infer_desc() override; + void init_output_dtype() override; + size_t get_workspace_size_bytes( + const TensorShapeArray& input_shapes, + const TensorShapeArray& output_shapes) const override; + void scn_do_execute() override; +}; + } // namespace opr } // namespace mgb diff --git a/src/opr/test/rand.cpp b/src/opr/test/rand.cpp index b584629a..a7171a19 100644 --- a/src/opr/test/rand.cpp +++ b/src/opr/test/rand.cpp @@ -446,4 +446,42 @@ TEST(TestOprRand, PermutationReprod) { }); } +TEST(TestOprRand, Dropout) { + auto run = [&](TensorShape shape, uint64_t seed, float drop_prob) { + using Param = megdnn::DropoutBase::Param; + Param param(drop_prob, seed); + float scale = 1.0 / (1.0 - drop_prob); + + std::shared_ptr inp_host( + new HostTensorND{CompNode::load("xpux"), shape, dtype::Float32()}); + for (size_t i = 0; i < shape.total_nr_elems(); ++i) { + inp_host->ptr()[i] = 1.0f; + } + + auto graph = ComputingGraph::make(); + auto inp_sym = opr::Host2DeviceCopy::make(*graph, inp_host); + auto outs = opr::DropoutForward::make(inp_sym, param); + + HostTensorND oup_host, mask_host, ws_host; + auto func = graph->compile( + {make_callback_copy(outs[0], oup_host), + make_callback_copy(outs[1], mask_host)}); + func->execute(); + + size_t droped_cnt = 0; + for (size_t i = 0; i < shape.total_nr_elems(); ++i) { + ASSERT_TRUE( + oup_host.ptr()[i] == 0 || + oup_host.ptr()[i] == scale); + if (oup_host.ptr()[i] == 0) { + droped_cnt++; + } + } + float real_drop = droped_cnt * 1.0 / shape.total_nr_elems(); + ASSERT_LT(abs(drop_prob - real_drop), 1e-2); + }; + run({100000}, 0, 0.2); + run({64, 32, 16, 16}, 1, 0.4); +} + // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} diff --git a/src/serialization/impl/schema.fbs b/src/serialization/impl/schema.fbs index 86894891..7b3aa4ed 100644 --- a/src/serialization/impl/schema.fbs +++ b/src/serialization/impl/schema.fbs @@ -117,6 +117,7 @@ union OperatorParam { param.ShuffleRNG = 83, param.CheckNonFinite = 84, param.LayerNorm = 85, + param.Dropout = 86, } table Operator { -- GitLab
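
Usage note (not part of the patch): a minimal Python sketch of the API this change enables, mirroring the seeding behaviour exercised by `test_dropout_seed` above. It assumes a MegEngine build that already contains this patch; the entry points used (`F.nn.dropout`, `set_global_rng_seed`) come directly from the diff.

# Minimal sketch, assuming a MegEngine build that includes this patch.
import numpy as np
import megengine.functional as F
from megengine import tensor
from megengine.core._imperative_rt.ops import set_global_rng_seed

x = tensor(np.ones((16, 32), dtype=np.float32))

set_global_rng_seed(111)
y1 = F.nn.dropout(x, 0.2, training=True)   # kept entries are scaled by 1 / (1 - 0.2)

set_global_rng_seed(111)                   # resetting the global seed reproduces the mask
y2 = F.nn.dropout(x, 0.2, training=True)
assert (y1.numpy() == y2.numpy()).all()

y3 = F.nn.dropout(x, 0.2, training=False)  # eval mode: input is returned unchanged
assert (y3.numpy() == x.numpy()).all()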