From edb32495c63046652a2b62634eeb3587224d9a6e Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Fri, 18 Sep 2020 16:01:49 +0800 Subject: [PATCH] feat(dnn/opr): add megdnn adaptive pooling opr GitOrigin-RevId: 563ce65479c90cb5686761889d9b8459c9afcefa --- dnn/include/megdnn/oprs/nn.h | 47 ++++++++++ dnn/src/common/adaptive_pooling.cpp | 37 ++++++++ dnn/src/common/handle_impl.h | 2 + dnn/src/cuda/adaptive_pooling/opr_impl.cpp | 53 +++++++++++ dnn/src/cuda/adaptive_pooling/opr_impl.h | 44 ++++++++++ dnn/src/cuda/handle_create.cpp | 1 + dnn/src/naive/adaptive_pooling/opr_impl.cpp | 52 +++++++++++ dnn/src/naive/adaptive_pooling/opr_impl.h | 43 +++++++++ dnn/src/naive/handle.cpp | 1 + dnn/test/common/adaptive_pooling.h | 55 ++++++++++++ dnn/test/common/opr_trait.h | 2 + dnn/test/cuda/adaptive_pooling.cpp | 97 +++++++++++++++++++++ 12 files changed, 434 insertions(+) create mode 100644 dnn/src/common/adaptive_pooling.cpp create mode 100644 dnn/src/cuda/adaptive_pooling/opr_impl.cpp create mode 100644 dnn/src/cuda/adaptive_pooling/opr_impl.h create mode 100644 dnn/src/naive/adaptive_pooling/opr_impl.cpp create mode 100644 dnn/src/naive/adaptive_pooling/opr_impl.h create mode 100644 dnn/test/common/adaptive_pooling.h create mode 100644 dnn/test/cuda/adaptive_pooling.cpp diff --git a/dnn/include/megdnn/oprs/nn.h b/dnn/include/megdnn/oprs/nn.h index 411429484..3b1a5caaa 100644 --- a/dnn/include/megdnn/oprs/nn.h +++ b/dnn/include/megdnn/oprs/nn.h @@ -682,6 +682,53 @@ protected: size_t workspace_in_bytes); }; +/** + * \brief base class for AdaptivePooling + */ +class AdaptivePoolingBase : public OperatorBase { + DEF_OPR_IMPL_CTOR(AdaptivePoolingBase, OperatorBase); + DEF_OPR_PARAM(AdaptivePooling); + +protected: + param::Pooling deduce_pooling_param(const TensorLayout& src, + const TensorLayout& dst); +}; + +class AdaptivePoolingForward : public AdaptivePoolingBase { + DEF_OPR_IMPL(AdaptivePoolingForward, AdaptivePoolingBase, 1, 1); + +public: + /** + * \param[in] src input tensor + * \param[out] dst output tensor + */ + virtual void exec(_megdnn_tensor_in src, _megdnn_tensor_out dst, + _megdnn_workspace workspace) = 0; + virtual size_t get_workspace_in_bytes(const TensorLayout& src, + const TensorLayout& dst) = 0; +}; + +using AdaptivePooling = AdaptivePoolingForward; + +class AdaptivePoolingBackward : public AdaptivePoolingBase { + DEF_OPR_IMPL(AdaptivePoolingBackward, AdaptivePoolingBase, 3, 1); + +public: + /** + * \param[in] src the `src' parameter in AdaptivePoolingForward::exec + * \param[in] dst the `dst' parameter in AdaptivePoolingForward::exec + * \param[in] diff the backpropagated gradient wrt. dst + * \param[out] grad the backpropagated gradient wrt. src + */ + virtual void exec(_megdnn_tensor_in src, _megdnn_tensor_in dst, + _megdnn_tensor_in diff, _megdnn_tensor_out grad, + _megdnn_workspace workspace) = 0; + virtual size_t get_workspace_in_bytes(const TensorLayout& src, + const TensorLayout& dst, + const TensorLayout& diff, + const TensorLayout& grad) = 0; +}; + /** * \brief base class for Local */ diff --git a/dnn/src/common/adaptive_pooling.cpp b/dnn/src/common/adaptive_pooling.cpp new file mode 100644 index 000000000..56bcb3a1b --- /dev/null +++ b/dnn/src/common/adaptive_pooling.cpp @@ -0,0 +1,37 @@ +/** + * \file dnn/src/common/adaptive_pooling.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. + */ +#include "megdnn/opr_param_defs.h" +#include "megdnn/oprs.h" + +#include "src/common/utils.h" +namespace megdnn { + +param::Pooling AdaptivePoolingBase::deduce_pooling_param( + const TensorLayout& src, const TensorLayout& dst) { + megdnn_assert(param().format == param::AdaptivePooling::Format::NCHW); + size_t IH = src.shape[2], IW = src.shape[3], OH = dst.shape[2], + OW = dst.shape[3]; + + param::Pooling ret; + ret.mode = param().mode; + ret.format = param().format; + ret.pad_h = ret.pad_w = 0; + ret.stride_h = floor(IH / OH); + ret.stride_w = floor(IW / OW); + ret.window_h = IH - (OH - 1) * ret.stride_h; + ret.window_w = IW - (OW - 1) * ret.stride_w; + + return ret; +} +} // namespace megdnn + +// vim: syntax=cpp.doxygen diff --git a/dnn/src/common/handle_impl.h b/dnn/src/common/handle_impl.h index 354bbd0ac..844699728 100644 --- a/dnn/src/common/handle_impl.h +++ b/dnn/src/common/handle_impl.h @@ -199,6 +199,8 @@ private: cb(Remap) \ cb(RemapBackwardData) \ cb(RemapBackwardMat) \ + cb(AdaptivePoolingForward) \ + cb(AdaptivePoolingBackward) \ /*! * \brief specialize HandleImpl::create_operator for a single opr type; diff --git a/dnn/src/cuda/adaptive_pooling/opr_impl.cpp b/dnn/src/cuda/adaptive_pooling/opr_impl.cpp new file mode 100644 index 000000000..c9ece2b4b --- /dev/null +++ b/dnn/src/cuda/adaptive_pooling/opr_impl.cpp @@ -0,0 +1,53 @@ +/** + * \file dnn/src/cuda/adaptive_pooling/opr_impl.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. + */ +#include "src/cuda/adaptive_pooling/opr_impl.h" +#include "src/cuda/utils.h" + +namespace megdnn { +namespace cuda { + +void AdaptivePoolingForwardImpl::exec(_megdnn_tensor_in src, + _megdnn_tensor_out dst, + _megdnn_workspace workspace) { + auto opr = handle()->create_operator(); + opr->param() = deduce_pooling_param(src.layout, dst.layout); + opr->exec(src, dst, workspace); +} + +size_t AdaptivePoolingForwardImpl::get_workspace_in_bytes( + const TensorLayout& src, const TensorLayout& dst) { + auto opr = handle()->create_operator(); + opr->param() = deduce_pooling_param(src, dst); + return opr->get_workspace_in_bytes(src, dst); +} + +void AdaptivePoolingBackwardImpl::exec(_megdnn_tensor_in src, + _megdnn_tensor_in dst, + _megdnn_tensor_in diff, + _megdnn_tensor_out grad, + _megdnn_workspace workspace) { + auto opr = handle()->create_operator(); + opr->param() = deduce_pooling_param(src.layout, dst.layout); + opr->exec(src, dst, diff, grad, workspace); +} + +size_t AdaptivePoolingBackwardImpl::get_workspace_in_bytes( + const TensorLayout& src, const TensorLayout& dst, + const TensorLayout& diff, const TensorLayout& grad) { + auto opr = handle()->create_operator(); + opr->param() = deduce_pooling_param(src, dst); + return opr->get_workspace_in_bytes(src, dst, diff, grad); +} +} // namespace cuda +} // namespace megdnn + +// vim: syntax=cpp.doxygen diff --git a/dnn/src/cuda/adaptive_pooling/opr_impl.h b/dnn/src/cuda/adaptive_pooling/opr_impl.h new file mode 100644 index 000000000..5df0538bb --- /dev/null +++ b/dnn/src/cuda/adaptive_pooling/opr_impl.h @@ -0,0 +1,44 @@ +/** + * \file dnn/src/cuda/adaptive_pooling/opr_impl.h + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. + */ +#pragma once +#include "megdnn/oprs.h" + +#include "src/cuda/cudnn_wrapper.h" +#include "src/cuda/utils.h" + +namespace megdnn { +namespace cuda { + +class AdaptivePoolingForwardImpl final : public AdaptivePoolingForward { +public: + using AdaptivePoolingForward::AdaptivePoolingForward; + void exec(_megdnn_tensor_in src, _megdnn_tensor_out dst, + _megdnn_workspace workspace) override; + size_t get_workspace_in_bytes(const TensorLayout& src, + const TensorLayout& dst) override; +}; + +class AdaptivePoolingBackwardImpl final : public AdaptivePoolingBackward { +public: + using AdaptivePoolingBackward::AdaptivePoolingBackward; + void exec(_megdnn_tensor_in src, _megdnn_tensor_in dst, + _megdnn_tensor_in diff, _megdnn_tensor_out grad, + _megdnn_workspace workspace) override; + size_t get_workspace_in_bytes(const TensorLayout& src, + const TensorLayout& dst, + const TensorLayout& diff, + const TensorLayout& grad) override; +}; +} // namespace cuda +} // namespace megdnn + +// vim: syntax=cpp.doxygen diff --git a/dnn/src/cuda/handle_create.cpp b/dnn/src/cuda/handle_create.cpp index e060a43d9..59262b5d1 100644 --- a/dnn/src/cuda/handle_create.cpp +++ b/dnn/src/cuda/handle_create.cpp @@ -11,6 +11,7 @@ #include "src/common/handle_impl.h" +#include "src/cuda/adaptive_pooling/opr_impl.h" #include "src/cuda/add_update/opr_impl.h" #include "src/cuda/argmxx/opr_impl.h" #include "src/cuda/argsort/opr_impl.h" diff --git a/dnn/src/naive/adaptive_pooling/opr_impl.cpp b/dnn/src/naive/adaptive_pooling/opr_impl.cpp new file mode 100644 index 000000000..0d6f53e67 --- /dev/null +++ b/dnn/src/naive/adaptive_pooling/opr_impl.cpp @@ -0,0 +1,52 @@ +/** + * \file dnn/src/naive/adaptive_pooling/opr_impl.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. + */ +#include "src/naive/adaptive_pooling/opr_impl.h" + +#include "src/common/opr_delegate.h" +#include "src/common/utils.h" +#include "src/naive/handle.h" + +namespace megdnn { +namespace naive { + +void AdaptivePoolingForwardImpl::exec(_megdnn_tensor_in src, + _megdnn_tensor_out dst, + _megdnn_workspace workspace) { + MEGDNN_DISPATCH_CPU_KERN(static_cast(handle()), { + auto opr = inplace_cpu_handle()->create_operator(); + opr->param() = deduce_pooling_param(src.layout, dst.layout); + opr->exec(src, dst, workspace); + }); +} + +void AdaptivePoolingBackwardImpl::exec(_megdnn_tensor_in src, + _megdnn_tensor_in dst, + _megdnn_tensor_in diff, + _megdnn_tensor_out grad, + _megdnn_workspace workspace) { + MEGDNN_DISPATCH_CPU_KERN(static_cast(handle()), { + auto opr = inplace_cpu_handle()->create_operator(); + opr->param() = deduce_pooling_param(src.layout, dst.layout); + opr->exec(src, dst, diff, grad, workspace); + }); +} + +size_t AdaptivePoolingBackwardImpl::get_workspace_in_bytes( + const TensorLayout& src, const TensorLayout& dst, + const TensorLayout& diff, const TensorLayout& grad) { + auto opr = inplace_cpu_handle()->create_operator(); + opr->param() = deduce_pooling_param(src, dst); + return opr->get_workspace_in_bytes(src, dst, diff, grad); +} +} // namespace naive +} // namespace megdnn +// vim: syntax=cpp.doxygen diff --git a/dnn/src/naive/adaptive_pooling/opr_impl.h b/dnn/src/naive/adaptive_pooling/opr_impl.h new file mode 100644 index 000000000..cb3bec172 --- /dev/null +++ b/dnn/src/naive/adaptive_pooling/opr_impl.h @@ -0,0 +1,43 @@ +/** + * \file dnn/src/naive/adaptive_pooling/opr_impl.h + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. + */ +#pragma once +#include "megdnn/oprs.h" +#include "src/common/utils.h" + +namespace megdnn { +namespace naive { + +class AdaptivePoolingForwardImpl : public AdaptivePoolingForward { +public: + using AdaptivePoolingForward::AdaptivePoolingForward; + void exec(_megdnn_tensor_in src, _megdnn_tensor_out dst, + _megdnn_workspace workspace) override; + size_t get_workspace_in_bytes(const TensorLayout&, + const TensorLayout&) override { + return 0; + } +}; + +class AdaptivePoolingBackwardImpl : public AdaptivePoolingBackward { +public: + using AdaptivePoolingBackward::AdaptivePoolingBackward; + void exec(_megdnn_tensor_in src, _megdnn_tensor_in dst, + _megdnn_tensor_in diff, _megdnn_tensor_out grad, + _megdnn_workspace workspace) override; + size_t get_workspace_in_bytes(const TensorLayout& src, + const TensorLayout& dst, + const TensorLayout& diff, + const TensorLayout& grad) override; +}; +} // namespace naive +} // namespace megdnn +// vim: syntax=cpp.doxygen diff --git a/dnn/src/naive/handle.cpp b/dnn/src/naive/handle.cpp index 6a1129d95..fd4cb7803 100644 --- a/dnn/src/naive/handle.cpp +++ b/dnn/src/naive/handle.cpp @@ -13,6 +13,7 @@ #include "src/common/handle_impl.h" +#include "src/naive/adaptive_pooling/opr_impl.h" #include "src/naive/add_update/opr_impl.h" #include "src/naive/argmxx/opr_impl.h" #include "src/naive/argsort/opr_impl.h" diff --git a/dnn/test/common/adaptive_pooling.h b/dnn/test/common/adaptive_pooling.h new file mode 100644 index 000000000..7e4495133 --- /dev/null +++ b/dnn/test/common/adaptive_pooling.h @@ -0,0 +1,55 @@ +/** + * \file dnn/test/common/adaptive_pooling.h + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. + */ +#pragma once +#include +#include "megdnn/basic_types.h" +#include "megdnn/opr_param_defs.h" + +namespace megdnn { +namespace test { +namespace adaptive_pooling { + +struct TestArg { + param::AdaptivePooling param; + TensorShape ishape; + TensorShape oshape; + TestArg(param::AdaptivePooling param, TensorShape ishape, + TensorShape oshape) + : param(param), ishape(ishape), oshape(oshape) {} +}; + +inline std::vector get_args() { + std::vector args; + using Param = param::AdaptivePooling; + using Mode = param::AdaptivePooling::Mode; + + for (size_t i = 36; i < 40; ++i) { + args.emplace_back(Param{Mode::AVERAGE}, TensorShape{2, 3, i, i + 1}, + TensorShape{2, 3, i - 4, i - 2}); + args.emplace_back(Param{Mode::MAX}, TensorShape{2, 3, i, i + 1}, + TensorShape{2, 3, i - 4, i - 2}); + } + + for (size_t i = 5; i < 10; ++i) { + args.emplace_back(Param{Mode::AVERAGE}, TensorShape{2, 3, i, i + 1}, + TensorShape{2, 3, i - 3, i - 2}); + args.emplace_back(Param{Mode::MAX}, TensorShape{2, 3, i, i + 1}, + TensorShape{2, 3, i - 3, i - 2}); + } + return args; +} + +} // namespace adaptive_pooling +} // namespace test +} // namespace megdnn + +// vim: syntax=cpp.doxygen diff --git a/dnn/test/common/opr_trait.h b/dnn/test/common/opr_trait.h index 22f0f349b..66af7db3c 100644 --- a/dnn/test/common/opr_trait.h +++ b/dnn/test/common/opr_trait.h @@ -41,6 +41,8 @@ DEF(Images2NeibsForward, 2, true, true); DEF(Images2NeibsBackward, 2, true, false); DEF(PoolingForward, 2, true, true); DEF(PoolingBackward, 4, true, false); +DEF(AdaptivePoolingForward, 2, true, false); +DEF(AdaptivePoolingBackward, 4, true, false); DEF(LocalForward, 3, true, true); DEF(LocalBackwardData, 3, true, false); DEF(LocalBackwardFilter, 3, true, false); diff --git a/dnn/test/cuda/adaptive_pooling.cpp b/dnn/test/cuda/adaptive_pooling.cpp new file mode 100644 index 000000000..14d444c58 --- /dev/null +++ b/dnn/test/cuda/adaptive_pooling.cpp @@ -0,0 +1,97 @@ +/** + * \file dnn/test/cuda/adaptive_pooling.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. + */ +#include "test/cuda/fixture.h" + +#include "megdnn/tensor_iter.h" +#include "test/common/adaptive_pooling.h" +#include "test/common/checker.h" + +#include "src/common/utils.h" +#include "test/cuda/utils.h" + +#include +#include "test/cuda/benchmark.h" + +namespace megdnn { +namespace test { + +TEST_F(CUDA, ADAPTIVE_POOLING_FORWARD) { + auto args = adaptive_pooling::get_args(); + using Format = param::AdaptivePooling::Format; + DType dtype = dtype::Float32(); + for (auto&& arg : args) { + auto param = arg.param; + auto src = arg.ishape; + auto dst = arg.oshape; + param.format = Format::NCHW; + Checker checker(handle_cuda()); + checker.set_epsilon(1e-2); + checker.set_param(param).set_dtype(0, dtype).set_dtype(1, dtype).exec( + TensorShapeArray{src, dst, {}}); + } +} + +TEST_F(CUDA, ADAPTIVE_POOLING_BACKWARD) { + auto args = adaptive_pooling::get_args(); + for (auto&& arg : args) { + Checker checker(handle_cuda()); + TensorLayout ilayout = TensorLayout(arg.ishape, dtype::Float32()); + TensorLayout olayout = TensorLayout(arg.oshape, dtype::Float32()); + + auto constraint = [this, + arg](CheckerHelper::TensorValueArray& tensors_orig) { + megdnn_assert(tensors_orig.size() == 4); + auto opr = handle_cuda()->create_operator(); + opr->param() = arg.param; + + auto tensors_cuda_storage = CheckerHelper::alloc_tensors( + handle_cuda(), + {tensors_orig[0].layout, tensors_orig[1].layout}, 0); + auto&& tensors_cuda = *tensors_cuda_storage; + + auto span = tensors_cuda[0].layout.span(); + auto dst = static_cast(tensors_cuda[0].raw_ptr) + + span.low_byte; + auto src = static_cast(tensors_orig[0].raw_ptr) + + span.low_byte; + megdnn_memcpy_H2D(handle_cuda(), dst, src, span.dist_byte()); + + auto workspace_size = opr->get_workspace_in_bytes( + tensors_cuda[0].layout, tensors_cuda[1].layout); + auto workspace_cuda = megdnn_malloc(handle_cuda(), workspace_size); + Workspace workspace{static_cast(workspace_cuda), + workspace_size}; + opr->exec(tensors_cuda[0], tensors_cuda[1], workspace); + megdnn_free(handle_cuda(), workspace_cuda); + + span = tensors_cuda[1].layout.span(); + dst = static_cast(tensors_orig[1].raw_ptr) + + span.low_byte; + src = static_cast(tensors_cuda[1].raw_ptr) + + span.low_byte; + megdnn_memcpy_D2H(handle_cuda(), dst, src, span.dist_byte()); + }; + + DType dtype = dtype::Float32(); + checker.set_tensors_constraint(constraint) + .set_dtype(0, dtype) + .set_dtype(1, dtype) + .set_dtype(2, dtype) + .set_dtype(3, dtype) + .set_param(arg.param) + .exec(TensorShapeArray{ilayout, olayout, olayout, ilayout}); + } +} +} // namespace test +} // namespace megdnn + +// vim: syntax=cpp.doxygen -- GitLab