diff --git a/dnn/src/rocm/miopen_wrapper.cpp b/dnn/src/rocm/miopen_wrapper.cpp index 2d4249f0dda0f505e378555b6c65406bfd403eb8..7b0a2f6b21344adf7015ad61011d1cf4febd333e 100644 --- a/dnn/src/rocm/miopen_wrapper.cpp +++ b/dnn/src/rocm/miopen_wrapper.cpp @@ -96,34 +96,6 @@ void ConvDesc::set(const param::Convolution& param, const size_t nr_group, //! not supported } -PoolingDesc::PoolingDesc() { - miopen_check(miopenCreatePoolingDescriptor(&desc)); -} - -PoolingDesc::~PoolingDesc() { - miopen_check(miopenDestroyPoolingDescriptor(desc)); -} - -void PoolingDesc::set(const param::Pooling& param) { - miopenPoolingMode_t mode; - switch (param.mode) { - case param::Pooling::Mode::MAX: - mode = miopenPoolingMax; - break; - case param::Pooling::Mode::AVERAGE_COUNT_EXCLUDE_PADDING: - mode = miopenPoolingAverage; - break; - case param::Pooling::Mode::AVERAGE: - mode = miopenPoolingAverageInclusive; - break; - default: - megdnn_throw("Unsupported pooling mode for miopen"); - } - miopen_check(miopenSet2dPoolingDescriptor( - desc, mode, param.window_h, param.window_w, param.pad_h, - param.pad_w, param.stride_h, param.stride_w)); -} - LRNDesc::LRNDesc() { miopen_check(miopenCreateLRNDescriptor(&desc)); } diff --git a/dnn/src/rocm/miopen_wrapper.h b/dnn/src/rocm/miopen_wrapper.h index 7c1cf1bc0da2babd159733bbb1effafc6b5ad2e0..f51f789da3b6e740df0e3044a20246f367f05c32 100644 --- a/dnn/src/rocm/miopen_wrapper.h +++ b/dnn/src/rocm/miopen_wrapper.h @@ -38,14 +38,6 @@ public: miopenConvolutionDescriptor_t desc; }; -class PoolingDesc { -public: - PoolingDesc(); - void set(const param::Pooling& param); - ~PoolingDesc(); - miopenPoolingDescriptor_t desc; -}; - class LRNDesc { public: LRNDesc(); diff --git a/dnn/src/rocm/pooling/algo.cpp b/dnn/src/rocm/pooling/algo.cpp new file mode 100644 index 0000000000000000000000000000000000000000..471bd3108decc2184aedb93323fcecaca40aa53e --- /dev/null +++ b/dnn/src/rocm/pooling/algo.cpp @@ -0,0 +1,209 @@ +/** + * \file dnn/src/rocm/pooling/algos.cpp + * 
MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
+ *
+ * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ * implied.
+ */
+
+#include "./algo.h"
+#include "hcc_detail/hcc_defs_prologue.h"
+#include "src/rocm/utils.h"
+
+using namespace megdnn;
+using namespace rocm;
+
+PoolingForwardImpl::AlgoPack::AlgoPack() {
+    all_algos.push_back(&algo_miopen);
+
+    for (auto&& algo : all_algos) {
+        m_all_algos_map.emplace(algo->info().desc, algo);
+    }
+}
+
+PoolingForwardImpl::AlgoPack PoolingForwardImpl::sm_algo_pack;
+MEGDNN_DEF_GET_ALGO_FROM_DESC(PoolingForwardImpl)
+
+PoolingForwardImpl::AlgoBase::SizeArgs::SizeArgs(PoolingForwardImpl* o,
+                                                 const TensorLayout& src,
+                                                 const TensorLayout& dst)
+        : handle{concrete_handle(o->handle())},
+          opr{o},
+          layout_src{&src},
+          layout_dst{&dst} {}
+
+PoolingForwardImpl::AlgoBase::ExecArgs::ExecArgs(PoolingForwardImpl* opr,
+                                                 _megdnn_tensor_in src,
+                                                 _megdnn_tensor_out dst,
+                                                 _megdnn_workspace workspace)
+        : SizeArgs(opr, src.layout, dst.layout),
+          src_tensor{&src},
+          dst_tensor{&dst},
+          workspace{workspace} {}
+
+std::string PoolingForwardImpl::AlgoBase::SizeArgs::to_string() const {
+    return ssprintf("src=%s, dst=%s", layout_src->to_string().c_str(),
+                    layout_dst->to_string().c_str());
+}
+
+bool PoolingForwardImpl::AlgoMIOpen::is_available(const SizeArgs& args) const {
+    return true;
+}
+
+void PoolingForwardImpl::AlgoMIOpen::init_mode(
+        const ExecArgs& args, miopenPoolingMode_t& mode) const {
+    switch (args.opr->param().mode) {
+        case param::Pooling::Mode::MAX:
+            mode = miopenPoolingMax;
+            break;
+        case param::Pooling::Mode::AVERAGE:
+            mode = miopenPoolingAverageInclusive;
+            break;
+        case param::Pooling::Mode::AVERAGE_COUNT_EXCLUDE_PADDING:
+            mode = miopenPoolingAverage;
+            break;
+
default:
+            megdnn_throw(ssprintf("Unsupported pooling mode: {%d}",
+                                  static_cast<int>(args.opr->param().mode)));
+    }
+}
+
+size_t PoolingForwardImpl::AlgoMIOpen::get_workspace_in_bytes(
+        const SizeArgs& args) const {
+    return 0;
+}
+
+void PoolingForwardImpl::AlgoMIOpen::exec(const ExecArgs& args) const {
+    auto handle = miopen_handle(args.handle);
+    TensorDesc src_desc, dst_desc;
+    args.init_desc(src_desc, dst_desc);
+    miopenPoolingMode_t mode;
+    init_mode(args, mode);
+
+    miopenPoolingDescriptor_t miopen_desc;
+    miopen_check(miopenCreatePoolingDescriptor(&miopen_desc));
+    miopen_check(miopenSet2dPoolingDescriptor(
+            miopen_desc, mode, args.opr->param().window_h,
+            args.opr->param().window_w, args.opr->param().pad_h,
+            args.opr->param().pad_w, args.opr->param().stride_h,
+            args.opr->param().stride_w));
+
+    dt_float32 alpha = 1.0f, beta = 0.0f;
+    miopen_check(miopenPoolingForward(
+            handle, miopen_desc, &alpha, src_desc.desc,
+            args.src_tensor->raw_ptr, &beta, dst_desc.desc,
+            args.dst_tensor->raw_ptr, false, nullptr, 0_z));
+    miopen_check(miopenDestroyPoolingDescriptor(miopen_desc));
+}
+
+PoolingBackwardImpl::AlgoPack::AlgoPack() {
+    all_algos.push_back(&algo_miopen);
+
+    for (auto&& algo : all_algos) {
+        m_all_algos_map.emplace(algo->info().desc, algo);
+    }
+}
+
+PoolingBackwardImpl::AlgoPack PoolingBackwardImpl::sm_algo_pack;
+MEGDNN_DEF_GET_ALGO_FROM_DESC(PoolingBackwardImpl)
+
+PoolingBackwardImpl::AlgoBase::SizeArgs::SizeArgs(PoolingBackwardImpl* o,
+                                                  const TensorLayout& src,
+                                                  const TensorLayout& dst,
+                                                  const TensorLayout& diff,
+                                                  const TensorLayout& grad)
+        : handle{concrete_handle(o->handle())},
+          opr{o},
+          layout_src{&src},
+          layout_dst{&dst},
+          layout_diff{&diff},
+          layout_grad{&grad} {}
+
+PoolingBackwardImpl::AlgoBase::ExecArgs::ExecArgs(PoolingBackwardImpl* opr,
+                                                  _megdnn_tensor_in src,
+                                                  _megdnn_tensor_in dst,
+                                                  _megdnn_tensor_in diff,
+                                                  _megdnn_tensor_out grad,
+                                                  _megdnn_workspace workspace)
+        : SizeArgs(opr, src.layout, dst.layout, diff.layout,
grad.layout),
+          src_tensor{&src},
+          dst_tensor{&dst},
+          diff_tensor{&diff},
+          grad_tensor{&grad},
+          workspace{workspace} {}
+
+std::string PoolingBackwardImpl::AlgoBase::SizeArgs::to_string() const {
+    return ssprintf(
+            "src=%s, dst=%s, diff=%s, grad=%s", layout_src->to_string().c_str(),
+            layout_dst->to_string().c_str(), layout_diff->to_string().c_str(),
+            layout_grad->to_string().c_str());
+}
+
+bool PoolingBackwardImpl::AlgoMIOpen::is_available(const SizeArgs&) const {
+    return true;
+}
+
+size_t PoolingBackwardImpl::AlgoMIOpen::get_workspace_in_bytes(
+        const SizeArgs& args) const {
+    TensorDesc dst_desc;
+    dst_desc.set(*args.layout_dst);
+
+    size_t ws_size = 0_z;
+    miopenPoolingGetWorkSpaceSize(dst_desc.desc, &ws_size);
+    return ws_size;
+}
+
+void PoolingBackwardImpl::AlgoMIOpen::init_mode(const ExecArgs& args,
+                                                miopenPoolingMode_t& mode) const {
+    switch (args.opr->param().mode) {
+        case param::Pooling::Mode::MAX:
+            mode = miopenPoolingMax;
+            break;
+        case param::Pooling::Mode::AVERAGE:
+            mode = miopenPoolingAverageInclusive;
+            break;
+        case param::Pooling::Mode::AVERAGE_COUNT_EXCLUDE_PADDING:
+            mode = miopenPoolingAverage;
+            break;
+        default:
+            megdnn_throw(ssprintf("Unsupported pooling mode: {%d}",
+                                  static_cast<int>(args.opr->param().mode)));
+    }
+}
+
+void PoolingBackwardImpl::AlgoMIOpen::exec(const ExecArgs& args) const {
+    auto handle = miopen_handle(args.handle);
+    TensorDesc src_desc, dst_desc, diff_desc, grad_desc;
+    args.init_desc(src_desc, dst_desc, diff_desc, grad_desc);
+    miopenPoolingMode_t mode;
+    init_mode(args, mode);
+
+    miopenPoolingDescriptor_t miopen_desc;
+    miopen_check(miopenCreatePoolingDescriptor(&miopen_desc));
+    miopen_check(miopenSet2dPoolingDescriptor(
+            miopen_desc, mode, args.opr->param().window_h,
+            args.opr->param().window_w, args.opr->param().pad_h,
+            args.opr->param().pad_w, args.opr->param().stride_h,
+            args.opr->param().stride_w));
+
+    float alpha = 1.0f, beta = 0.0f;
+    if (args.opr->param().mode == param::Pooling::Mode::MAX)
{
+        //! FIXME: max pooling backward needs the forward indices kept in the
+        //! workspace, so recompute them here via miopenPoolingForward.
+        miopen_check(miopenPoolingForward(
+                handle, miopen_desc, &alpha, src_desc.desc,
+                args.src_tensor->raw_ptr, &beta, dst_desc.desc,
+                args.dst_tensor->raw_ptr, true, args.workspace.raw_ptr,
+                args.workspace.size));
+    }
+    miopen_check(miopenPoolingBackward(
+            handle, miopen_desc, &alpha, dst_desc.desc,
+            args.dst_tensor->raw_ptr, diff_desc.desc, args.diff_tensor->raw_ptr,
+            src_desc.desc, args.src_tensor->raw_ptr, &beta, grad_desc.desc,
+            args.grad_tensor->raw_ptr, args.workspace.raw_ptr));
+    miopen_check(miopenDestroyPoolingDescriptor(miopen_desc));
+}
diff --git a/dnn/src/rocm/pooling/algo.h b/dnn/src/rocm/pooling/algo.h
new file mode 100644
index 0000000000000000000000000000000000000000..692bb6c9855ab92d07a4bc6fb9ca4fdde7e9d404
--- /dev/null
+++ b/dnn/src/rocm/pooling/algo.h
@@ -0,0 +1,195 @@
+/**
+ * \file dnn/src/rocm/pooling/algo.h
+ * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
+ *
+ * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ * implied.
+ */
+
+#pragma once
+
+#include <unordered_map>
+#include "src/common/algo_base.h"
+#include "src/common/metahelper.h"
+#include "src/rocm/miopen_wrapper.h"
+#include "src/rocm/pooling/opr_impl.h"
+#include "src/rocm/handle.h"
+
+namespace megdnn {
+namespace rocm {
+
+class PoolingForwardImpl::AlgoBase : public Algorithm {
+public:
+    enum class AlgoType : uint32_t { ROCM_MIOPEN };
+    using Mapper = std::unordered_map<AlgorithmDesc, AlgoBase*>;
+
+    AlgoBase() : Algorithm() { m_handle_type = Handle::HandleType::ROCM; }
+    struct SizeArgs {
+        HandleImpl* handle;
+        PoolingForwardImpl* opr;
+        const TensorLayout *layout_src, *layout_dst;
+
+        std::string to_string() const;
+        void init_desc(TensorDesc& src_desc, TensorDesc& dst_desc) const {
+            src_desc.set(*layout_src, opr->param().format);
+            dst_desc.set(*layout_dst, opr->param().format);
+        }
+
+        SizeArgs(PoolingForwardImpl* opr, const TensorLayout& src,
+                 const TensorLayout& dst);
+    };
+    struct ExecArgs : public SizeArgs {
+        const TensorND *src_tensor, *dst_tensor;
+        Workspace workspace;
+
+        ExecArgs(PoolingForwardImpl* opr, _megdnn_tensor_in src,
+                 _megdnn_tensor_out dst, _megdnn_workspace workspace);
+    };
+
+    virtual bool is_available(const SizeArgs& args) const = 0;
+    virtual size_t get_workspace_in_bytes(const SizeArgs& args) const = 0;
+    virtual void exec(const ExecArgs& args) const = 0;
+
+    bool is_available_attribute(
+            const SizeArgs& args,
+            const AlgoAttribute& positive_attr = AlgoAttribute::REPRODUCIBLE,
+            const AlgoAttribute& negative_attr = AlgoAttribute::DEFAULT) {
+        return contain_attribute_all(positive_attr) &&
+               !contain_attribute_any(negative_attr) && is_available(args);
+    }
+
+protected:
+    ~AlgoBase() = default;
+};
+
+class PoolingForwardImpl::AlgoMIOpen final : public AlgoBase {
+    std::string m_algo_name;
+    AlgoAttribute m_algo_attribute;
+
+public:
+    AlgoMIOpen(AlgoAttribute attr)
+            : m_algo_name("MIOpenPoolingForward"), m_algo_attribute(attr) {}
+
+    bool is_available(const SizeArgs& args) const override;
+    size_t
get_workspace_in_bytes(const SizeArgs& args) const override;
+    void init_mode(const ExecArgs& args, miopenPoolingMode_t& mode) const;
+    void exec(const ExecArgs& args) const override;
+
+    const char* name() const override { return m_algo_name.c_str(); }
+    AlgoAttribute attribute() const override { return m_algo_attribute; }
+
+    MEGDNN_DECL_ALGO_TYPE(ROCM_MIOPEN)
+
+    std::string param() const override {
+        std::string ret;
+        serialize_write_pod(m_algo_attribute, ret);
+        return ret;
+    }
+};
+
+class PoolingForwardImpl::AlgoPack : NonCopyableObj {
+private:
+    AlgoBase::Mapper m_all_algos_map;
+
+public:
+    AlgoPack();
+    AlgoMIOpen algo_miopen{AlgoAttribute::REPRODUCIBLE};
+
+    std::vector<AlgoBase*> all_algos;
+
+    const AlgoBase::Mapper& all_algos_map() const { return m_all_algos_map; }
+};
+
+class PoolingBackwardImpl::AlgoBase : public Algorithm {
+public:
+    enum class AlgoType : uint32_t { ROCM_MIOPEN };
+    using Mapper = std::unordered_map<AlgorithmDesc, AlgoBase*>;
+
+    AlgoBase() : Algorithm() { m_handle_type = Handle::HandleType::ROCM; }
+    struct SizeArgs {
+        HandleImpl* handle;
+        PoolingBackwardImpl* opr;
+        const TensorLayout *layout_src, *layout_dst, *layout_diff, *layout_grad;
+
+        std::string to_string() const;
+        void init_desc(TensorDesc& src_desc, TensorDesc& dst_desc,
+                       TensorDesc& diff_desc, TensorDesc& grad_desc) const {
+            src_desc.set(*layout_src);
+            dst_desc.set(*layout_dst);
+            diff_desc.set(*layout_diff);
+            grad_desc.set(*layout_grad);
+        }
+        SizeArgs(PoolingBackwardImpl* opr, const TensorLayout& src,
+                 const TensorLayout& dst, const TensorLayout& diff,
+                 const TensorLayout& grad);
+    };
+    struct ExecArgs : public SizeArgs {
+        const TensorND *src_tensor, *dst_tensor, *diff_tensor, *grad_tensor;
+        Workspace workspace;
+
+        ExecArgs(PoolingBackwardImpl* opr, _megdnn_tensor_in src,
+                 _megdnn_tensor_in dst, _megdnn_tensor_in diff,
+                 _megdnn_tensor_out grad, _megdnn_workspace workspace);
+    };
+
+    virtual bool is_available(const SizeArgs& args) const = 0;
+    virtual size_t get_workspace_in_bytes(const
SizeArgs& args) const = 0;
+    virtual void exec(const ExecArgs& args) const = 0;
+
+    bool is_available_attribute(
+            const SizeArgs& args,
+            const AlgoAttribute& positive_attr = AlgoAttribute::REPRODUCIBLE,
+            const AlgoAttribute& negative_attr = AlgoAttribute::DEFAULT) {
+        return contain_attribute_all(positive_attr) &&
+               !contain_attribute_any(negative_attr) && is_available(args);
+    }
+
+protected:
+    ~AlgoBase() = default;
+};
+
+class PoolingBackwardImpl::AlgoMIOpen final : public AlgoBase {
+    std::string m_algo_name;
+    AlgoAttribute m_algo_attribute;
+
+public:
+    AlgoMIOpen(AlgoAttribute attr)
+            : m_algo_name("MIOpenPoolingBackward"), m_algo_attribute(attr) {}
+
+    bool is_available(const SizeArgs& args) const override;
+    size_t get_workspace_in_bytes(const SizeArgs& args) const override;
+    void init_mode(const ExecArgs& args, miopenPoolingMode_t& mode) const;
+    void exec(const ExecArgs& args) const override;
+
+    const char* name() const override { return m_algo_name.c_str(); }
+    AlgoAttribute attribute() const override {
+        return m_algo_attribute;
+    }
+
+    MEGDNN_DECL_ALGO_TYPE(ROCM_MIOPEN)
+
+    std::string param() const override {
+        std::string ret;
+        serialize_write_pod(m_algo_attribute, ret);
+        return ret;
+    }
+};
+
+class PoolingBackwardImpl::AlgoPack : NonCopyableObj {
+private:
+    AlgoBase::Mapper m_all_algos_map;
+
+public:
+    AlgoPack();
+    AlgoMIOpen algo_miopen{AlgoAttribute::REPRODUCIBLE};
+    std::vector<AlgoBase*> all_algos;
+
+    const AlgoBase::Mapper& all_algos_map() const { return m_all_algos_map; }
+};
+
+} // namespace rocm
+} // namespace megdnn
diff --git a/dnn/src/rocm/pooling/opr_impl.cpp b/dnn/src/rocm/pooling/opr_impl.cpp
index d7525e11662834fddb8fcd586c8f3f21a6f03528..ae49eba9af8487a150fe59ddea3ae5edef77c1cc 100644
--- a/dnn/src/rocm/pooling/opr_impl.cpp
+++ b/dnn/src/rocm/pooling/opr_impl.cpp
@@ -10,18 +10,47 @@
  */
 #include "hcc_detail/hcc_defs_prologue.h"
 #include "src/rocm/pooling/opr_impl.h"
-
 #include "src/rocm/utils.h"
+#include "./algo.h"
+#include
"src/common/algo_chooser.h" namespace megdnn { namespace rocm { -void PoolingForwardImpl::setup_descs(const TensorLayout &src, - const TensorLayout &dst) -{ - src_desc.set(src, param().format); - dst_desc.set(dst, param().format); - pooling_desc.set(this->param()); +size_t PoolingForwardImpl::get_workspace_in_bytes(const TensorLayout& src, + const TensorLayout& dst) { + AlgoBase::SizeArgs args(this, src, dst); + return get_algorithm(this, src, dst)->get_workspace_in_bytes(args); +} + +const char* PoolingForwardImpl::get_algorithm_set_name() const { + return "ROCM_POOLING_FORWARD"; +} + +std::vector +PoolingForwardImpl::get_all_algorithms(const TensorLayout& src, + const TensorLayout& dst) { + return megdnn::get_all_algorithms({this, src, dst}); +} + +PoolingForwardImpl::Algorithm* PoolingForwardImpl::get_algorithm_heuristic( + const TensorLayout& src, const TensorLayout& dst, + size_t workspace_limit_in_bytes, const AlgoAttribute& positive_attr, + const AlgoAttribute& negative_attr) { + MEGDNN_MARK_USED_VAR(workspace_limit_in_bytes); + + AlgoBase::SizeArgs args(this, src, dst); + for (auto&& iter : sm_algo_pack.all_algos) { + if (iter->is_available_attribute(args, positive_attr, negative_attr)) { + return iter; + } + } + megdnn_throw( + ssprintf("require algorithm with attribute(%s) and without " + "attribute(%s), but can't get suitable algo.\n", + Algorithm::attribute_str(positive_attr).c_str(), + Algorithm::attribute_str(negative_attr).c_str())); + return nullptr; } void PoolingForwardImpl::exec(_megdnn_tensor_in src, @@ -29,24 +58,52 @@ void PoolingForwardImpl::exec(_megdnn_tensor_in src, _megdnn_workspace workspace) { check_exec(src.layout, dst.layout, workspace.size); - auto handle = miopen_handle(this->handle()); - setup_descs(src.layout, dst.layout); - dt_float32 alpha = 1.0f, beta = 0.0f; - miopen_check(miopenPoolingForward(handle, pooling_desc.desc, &alpha, - src_desc.desc, src.raw_ptr, &beta, - dst_desc.desc, dst.raw_ptr, false, - nullptr, 0_z)); + { + 
AlgoBase::ExecArgs args(this, src, dst, workspace);
+        auto algo = get_algorithm(this, src.layout, dst.layout);
+        algo->exec(args);
+    }
 }
 
-void PoolingBackwardImpl::setup_descs(const TensorLayout& src,
-                                      const TensorLayout& dst,
-                                      const TensorLayout& diff,
-                                      const TensorLayout& grad) {
-    src_desc.set(src);
-    dst_desc.set(dst);
-    diff_desc.set(diff);
-    grad_desc.set(grad);
-    pooling_desc.set(this->param());
+size_t PoolingBackwardImpl::get_workspace_in_bytes(const TensorLayout& src,
+                                                   const TensorLayout& dst,
+                                                   const TensorLayout& diff,
+                                                   const TensorLayout& grad) {
+    AlgoBase::SizeArgs args(this, src, dst, diff, grad);
+    return get_algorithm(this, src, dst, diff, grad)
+            ->get_workspace_in_bytes(args);
+}
+
+const char* PoolingBackwardImpl::get_algorithm_set_name() const {
+    return "ROCM_POOLING_BACKWARD";
+}
+
+std::vector<Algorithm*> PoolingBackwardImpl::get_all_algorithms(
+        const TensorLayout& src, const TensorLayout& dst,
+        const TensorLayout& diff, const TensorLayout& grad) {
+    return megdnn::get_all_algorithms<PoolingBackwardImpl>(
+            {this, src, dst, diff, grad});
+}
+
+Algorithm* PoolingBackwardImpl::get_algorithm_heuristic(
+        const TensorLayout& src, const TensorLayout& dst,
+        const TensorLayout& diff, const TensorLayout& grad,
+        size_t workspace_limit_in_bytes, const AlgoAttribute& positive_attr,
+        const AlgoAttribute& negative_attr) {
+    MEGDNN_MARK_USED_VAR(workspace_limit_in_bytes);
+
+    AlgoBase::SizeArgs args(this, src, dst, diff, grad);
+    for (auto iter : sm_algo_pack.all_algos) {
+        if (iter->is_available_attribute(args, positive_attr, negative_attr)) {
+            return iter;
+        }
+    }
+    megdnn_throw(
+            ssprintf("require algorithm with attribute(%s) and without "
+                     "attribute(%s), but can't get suitable algo.\n",
+                     Algorithm::attribute_str(positive_attr).c_str(),
+                     Algorithm::attribute_str(negative_attr).c_str()));
+    return nullptr;
 }
 
 void PoolingBackwardImpl::exec(_megdnn_tensor_in src,
@@ -55,35 +112,16 @@
                                _megdnn_tensor_out grad,
_megdnn_workspace workspace) { - check_exec(src.layout, dst.layout, diff.layout, grad.layout, workspace.size); - auto handle = miopen_handle(this->handle()); - setup_descs(src.layout, dst.layout, diff.layout, grad.layout); - float alpha = 1.0f, beta = 0.0f; - if (param().mode == param::Pooling::Mode::MAX) { - //! FIXME: when using max pooling opr, the backward opr need the indices - //! of the forward opr which stored in workspace. We have to recompute - //! the indices by calling miopenPoolingForward again. - miopen_check(miopenPoolingForward(handle, pooling_desc.desc, &alpha, - src_desc.desc, src.raw_ptr, &beta, - dst_desc.desc, dst.raw_ptr, true, - workspace.raw_ptr, workspace.size)); + check_exec(src.layout, dst.layout, diff.layout, grad.layout, + workspace.size); + { + AlgoBase::ExecArgs args(this, src, dst, diff, grad, workspace); + auto algo = get_algorithm(this, src.layout, dst.layout, diff.layout, + grad.layout); + algo->exec(args); } - miopen_check(miopenPoolingBackward( - handle, pooling_desc.desc, &alpha, dst_desc.desc, dst.raw_ptr, - diff_desc.desc, diff.raw_ptr, src_desc.desc, src.raw_ptr, &beta, - grad_desc.desc, grad.raw_ptr, workspace.raw_ptr)); } -size_t PoolingBackwardImpl::get_workspace_in_bytes(const TensorLayout& src, - const TensorLayout& dst, - const TensorLayout& diff, - const TensorLayout& grad) { - setup_descs(src, dst, diff, grad); - size_t ws_size = 0_z; - miopenPoolingGetWorkSpaceSize(dst_desc.desc, &ws_size); - return ws_size; -}; - } // namespace rocm } // namespace megdnn diff --git a/dnn/src/rocm/pooling/opr_impl.h b/dnn/src/rocm/pooling/opr_impl.h index feb755ebabbcaf81743860b3f619a661b6590f7a..57be4c201a33977d89e9e286365b7b98c40011b1 100644 --- a/dnn/src/rocm/pooling/opr_impl.h +++ b/dnn/src/rocm/pooling/opr_impl.h @@ -22,13 +22,37 @@ class PoolingForwardImpl final: public PoolingForward { void exec(_megdnn_tensor_in src, _megdnn_tensor_out dst, _megdnn_workspace workspace) override; size_t get_workspace_in_bytes(const 
TensorLayout &,
-                                  const TensorLayout &) override {
-        return 0;
+                                  const TensorLayout &) override;
+
+    const char* get_algorithm_set_name() const override;
+    Algorithm* get_algorithm_from_desc(const AlgorithmDesc& desc) override;
+
+    AlgorithmInfo get_algorithm_info_heuristic(
+            const TensorLayout& src, const TensorLayout& dst,
+            size_t workspace_limit_in_bytes,
+            const AlgoAttribute& positive_attr,
+            const AlgoAttribute& negative_attr) {
+        return get_algorithm_heuristic(src, dst, workspace_limit_in_bytes,
+                                       positive_attr, negative_attr)
+                ->info();
     }
+
+    class AlgoBase;
+    class AlgoMIOpen;
+    class AlgoPack;
+
+    static const AlgoPack& algo_pack() { return sm_algo_pack; }
+
+    protected:
+    std::vector<Algorithm*> get_all_algorithms(
+            const TensorLayout& src, const TensorLayout& dst) override;
+    Algorithm* get_algorithm_heuristic(
+            const TensorLayout& src, const TensorLayout& dst,
+            size_t workspace_limit_in_bytes, const AlgoAttribute& positive_attr,
+            const AlgoAttribute& negative_attr) override;
+
     private:
-        TensorDesc src_desc, dst_desc;
-        PoolingDesc pooling_desc;
-        void setup_descs(const TensorLayout &src, const TensorLayout &dst);
+    static AlgoPack sm_algo_pack;
 };
 
 class PoolingBackwardImpl final: public PoolingBackward {
@@ -43,14 +67,41 @@
              const TensorLayout& dst,
              const TensorLayout& diff,
              const TensorLayout& grad) override;
-    private:
-        TensorDesc src_desc, dst_desc, diff_desc, grad_desc;
-        PoolingDesc pooling_desc;
-        void setup_descs(const TensorLayout &src,
-                         const TensorLayout &dst,
-                         const TensorLayout &diff,
-                         const TensorLayout &grad);
+    const char* get_algorithm_set_name() const override;
+    Algorithm* get_algorithm_from_desc(const AlgorithmDesc& desc) override;
+
+    AlgorithmInfo get_algorithm_info_heuristic(
+            const TensorLayout& src, const TensorLayout& dst,
+            const TensorLayout& diff, const TensorLayout& grad,
+            size_t workspace_limit_in_bytes,
+            const AlgoAttribute& positive_attr,
+            const AlgoAttribute&
negative_attr) {
+        return get_algorithm_heuristic(src, dst, diff, grad,
+                                       workspace_limit_in_bytes,
+                                       positive_attr, negative_attr)
+                ->info();
+    }
+
+    class AlgoBase;
+    class AlgoMIOpen;
+    class AlgoPack;
+
+    static const AlgoPack& algo_pack() { return sm_algo_pack; }
+
+    protected:
+    std::vector<Algorithm*> get_all_algorithms(
+            const TensorLayout& src, const TensorLayout& dst,
+            const TensorLayout& diff, const TensorLayout& grad) override;
+    Algorithm* get_algorithm_heuristic(
+            const TensorLayout& src, const TensorLayout& dst,
+            const TensorLayout& diff, const TensorLayout& grad,
+            size_t workspace_limit_in_bytes,
+            const AlgoAttribute& positive_attr,
+            const AlgoAttribute& negative_attr) override;
+
+    private:
+    static AlgoPack sm_algo_pack;
 };
 
 } // namespace rocm