提交 25932352 编写于 作者: M Megvii Engine Team 提交者: huangxinda

refactor(mgb/dnn): rocm pooling rebase algochooser

GitOrigin-RevId: 95be9298415636583c9737362568b59e2d38580d
上级 1cfdbc56
......@@ -96,34 +96,6 @@ void ConvDesc::set(const param::Convolution& param, const size_t nr_group,
//! not supported
}
PoolingDesc::PoolingDesc() {
miopen_check(miopenCreatePoolingDescriptor(&desc));
}
PoolingDesc::~PoolingDesc() {
miopen_check(miopenDestroyPoolingDescriptor(desc));
}
void PoolingDesc::set(const param::Pooling& param) {
miopenPoolingMode_t mode;
switch (param.mode) {
case param::Pooling::Mode::MAX:
mode = miopenPoolingMax;
break;
case param::Pooling::Mode::AVERAGE_COUNT_EXCLUDE_PADDING:
mode = miopenPoolingAverage;
break;
case param::Pooling::Mode::AVERAGE:
mode = miopenPoolingAverageInclusive;
break;
default:
megdnn_throw("Unsupported pooling mode for miopen");
}
miopen_check(miopenSet2dPoolingDescriptor(
desc, mode, param.window_h, param.window_w, param.pad_h,
param.pad_w, param.stride_h, param.stride_w));
}
LRNDesc::LRNDesc() {
miopen_check(miopenCreateLRNDescriptor(&desc));
}
......
......@@ -38,14 +38,6 @@ public:
miopenConvolutionDescriptor_t desc;
};
class PoolingDesc {
public:
PoolingDesc();
void set(const param::Pooling& param);
~PoolingDesc();
miopenPoolingDescriptor_t desc;
};
class LRNDesc {
public:
LRNDesc();
......
/**
* \file dnn/src/rocm/pooling/algos.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#include "./algo.h"
#include "hcc_detail/hcc_defs_prologue.h"
#include "src/rocm/utils.h"
using namespace megdnn;
using namespace rocm;
PoolingForwardImpl::AlgoPack::AlgoPack() {
all_algos.push_back(&algo_miopen);
for (auto&& algo : all_algos) {
m_all_algos_map.emplace(algo->info().desc, algo);
}
}
PoolingForwardImpl::AlgoPack PoolingForwardImpl::sm_algo_pack;
MEGDNN_DEF_GET_ALGO_FROM_DESC(PoolingForwardImpl)
PoolingForwardImpl::AlgoBase::SizeArgs::SizeArgs(PoolingForwardImpl* o,
const TensorLayout& src,
const TensorLayout& dst)
: handle{concrete_handle(o->handle())},
opr{o},
layout_src{&src},
layout_dst{&dst} {}
PoolingForwardImpl::AlgoBase::ExecArgs::ExecArgs(PoolingForwardImpl* opr,
_megdnn_tensor_in src,
_megdnn_tensor_out dst,
_megdnn_workspace workspace)
: SizeArgs(opr, src.layout, dst.layout),
src_tensor{&src},
dst_tensor{&dst},
workspace{workspace} {}
std::string PoolingForwardImpl::AlgoBase::SizeArgs::to_string() const {
return ssprintf("src=%s, dst=%s", layout_src->to_string().c_str(),
layout_dst->to_string().c_str());
}
bool PoolingForwardImpl::AlgoMIOpen::is_available(const SizeArgs& args) const {
return true;
}
void PoolingForwardImpl::AlgoMIOpen::init_mode(
const ExecArgs& args, miopenPoolingMode_t& mode) const {
switch (args.opr->param().mode) {
case param::Pooling::Mode::MAX:
mode = miopenPoolingMax;
break;
case param::Pooling::Mode::AVERAGE:
mode = miopenPoolingAverage;
break;
case param::Pooling::Mode::AVERAGE_COUNT_EXCLUDE_PADDING:
mode = miopenPoolingAverageInclusive;
break;
default:
megdnn_throw(ssprintf("Unspport pooling mode : {%d}",
static_cast<int>(args.opr->param().mode)));
}
}
size_t PoolingForwardImpl::AlgoMIOpen::get_workspace_in_bytes(
const SizeArgs& args) const {
return 0;
}
void PoolingForwardImpl::AlgoMIOpen::exec(const ExecArgs& args) const {
auto handle = miopen_handle(args.handle);
TensorDesc src_desc, dst_desc;
args.init_desc(src_desc, dst_desc);
miopenPoolingMode_t mode;
init_mode(args, mode);
miopenPoolingDescriptor_t miopen_desc;
miopen_check(miopenCreatePoolingDescriptor(&miopen_desc));
miopen_check(miopenSet2dPoolingDescriptor(
miopen_desc, mode, args.opr->param().window_h,
args.opr->param().window_w, args.opr->param().pad_h,
args.opr->param().pad_w, args.opr->param().stride_h,
args.opr->param().stride_w));
dt_float32 alpha = 1.0f, beta = 0.0f;
miopen_check(miopenPoolingForward(
handle, miopen_desc, &alpha, src_desc.desc,
args.src_tensor->raw_ptr, &beta, dst_desc.desc,
args.src_tensor->raw_ptr, false, nullptr, 0_z));
miopen_check(miopenDestroyPoolingDescriptor(miopen_desc));
}
PoolingBackwardImpl::AlgoPack::AlgoPack() {
all_algos.push_back(&algo_miopen);
for (auto&& algo : all_algos) {
m_all_algos_map.emplace(algo->info().desc, algo);
}
}
PoolingBackwardImpl::AlgoPack PoolingBackwardImpl::sm_algo_pack;
MEGDNN_DEF_GET_ALGO_FROM_DESC(PoolingBackwardImpl)
PoolingBackwardImpl::AlgoBase::SizeArgs::SizeArgs(PoolingBackwardImpl* o,
const TensorLayout& src,
const TensorLayout& dst,
const TensorLayout& diff,
const TensorLayout& grad)
: handle{concrete_handle(o->handle())},
opr{o},
layout_src{&src},
layout_dst{&dst},
layout_diff{&diff},
layout_grad{&grad} {}
PoolingBackwardImpl::AlgoBase::ExecArgs::ExecArgs(PoolingBackwardImpl* opr,
_megdnn_tensor_in src,
_megdnn_tensor_in dst,
_megdnn_tensor_in diff,
_megdnn_tensor_out grad,
_megdnn_workspace workspace)
: SizeArgs(opr, src.layout, dst.layout, diff.layout, grad.layout),
src_tensor{&src},
dst_tensor{&dst},
diff_tensor{&diff},
grad_tensor{&grad},
workspace{workspace} {}
std::string PoolingBackwardImpl::AlgoBase::SizeArgs::to_string() const {
return ssprintf(
"src=%s, dst=%s, diff=%s, grad=%s", layout_src->to_string().c_str(),
layout_dst->to_string().c_str(), layout_diff->to_string().c_str(),
layout_grad->to_string().c_str());
}
bool PoolingBackwardImpl::AlgoMIOpen::is_available(const SizeArgs&) const {
return true;
}
size_t PoolingBackwardImpl::AlgoMIOpen::get_workspace_in_bytes(
const SizeArgs& args) const {
TensorDesc dst_desc;
dst_desc.set(*args.layout_dst);
size_t ws_size = 0_z;
miopenPoolingGetWorkSpaceSize(dst_desc.desc, &ws_size);
return ws_size;
}
void PoolingBackwardImpl::AlgoMIOpen::init_mode(const ExecArgs& args,
miopenPoolingMode_t& mode) const {
switch (args.opr->param().mode) {
case param::Pooling::Mode::MAX:
mode = miopenPoolingMax;
break;
case param::Pooling::Mode::AVERAGE:
mode = miopenPoolingAverage;
break;
case param::Pooling::Mode::AVERAGE_COUNT_EXCLUDE_PADDING:
mode = miopenPoolingAverageInclusive;
break;
default:
megdnn_throw(ssprintf("Unspport pooling mode : {%d}",
static_cast<int>(args.opr->param().mode)));
}
}
void PoolingBackwardImpl::AlgoMIOpen::exec(const ExecArgs& args) const {
auto handle = miopen_handle(args.handle);
TensorDesc src_desc, dst_desc, diff_desc, grad_desc;
args.init_desc(src_desc, dst_desc, diff_desc, grad_desc);
miopenPoolingMode_t mode;
init_mode(args, mode);
miopenPoolingDescriptor_t miopen_desc;
miopen_check(miopenCreatePoolingDescriptor(&miopen_desc));
miopen_check(miopenSet2dPoolingDescriptor(
miopen_desc, mode, args.opr->param().window_h,
args.opr->param().window_w, args.opr->param().pad_h,
args.opr->param().pad_w, args.opr->param().stride_h,
args.opr->param().stride_w));
float alpha = 1.0f, beta = 0.0f;
if (args.opr->param().mode == param::Pooling::Mode::MAX) {
//! FIXME: when using max pooling opr, the backward opr need the indices
//! of the forward opr which stored in workspace. We have to recompute
//! the indices by calling miopenPoolingForward again.
miopen_check(miopenPoolingForward(
handle, miopen_desc, &alpha, src_desc.desc,
args.src_tensor->raw_ptr, &beta, dst_desc.desc,
args.dst_tensor->raw_ptr, true, args.workspace.raw_ptr,
args.workspace.size));
}
miopen_check(miopenPoolingBackward(
handle, miopen_desc, &alpha, dst_desc.desc,
args.dst_tensor->raw_ptr, diff_desc.desc, args.diff_tensor->raw_ptr,
src_desc.desc, args.src_tensor->raw_ptr, &beta, grad_desc.desc,
args.grad_tensor->raw_ptr, args.workspace.raw_ptr));
}
\ No newline at end of file
/**
* \file dnn/src/rocm/pooling/algo.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#pragma once
#include <unordered_map>
#include "src/common/algo_base.h"
#include "src/common/metahelper.h"
#include "src/rocm/miopen_wrapper.h"
#include "src/rocm/pooling/opr_impl.h"
#include "src/rocm/handle.h"
namespace megdnn {
namespace rocm {
class PoolingForwardImpl::AlgoBase : public Algorithm {
public:
enum class AlgoType : uint32_t { ROCM_MIOPEN };
using Mapper = std::unordered_map<AlgorithmDesc, AlgoBase*>;
AlgoBase() : Algorithm() { m_handle_type = Handle::HandleType::ROCM; }
struct SizeArgs {
HandleImpl* handle;
PoolingForwardImpl* opr;
const TensorLayout *layout_src, *layout_dst;
std::string to_string() const;
void init_desc(TensorDesc& src_desc, TensorDesc& dst_desc) const {
src_desc.set(*layout_src, opr->param().format);
dst_desc.set(*layout_dst, opr->param().format);
}
SizeArgs(PoolingForwardImpl* opr, const TensorLayout& src,
const TensorLayout& dst);
};
struct ExecArgs : public SizeArgs {
const TensorND *src_tensor, *dst_tensor;
Workspace workspace;
ExecArgs(PoolingForwardImpl* opr, _megdnn_tensor_in src,
_megdnn_tensor_out dst, _megdnn_workspace workspace);
};
virtual bool is_available(const SizeArgs& args) const = 0;
virtual size_t get_workspace_in_bytes(const SizeArgs& args) const = 0;
virtual void exec(const ExecArgs& args) const = 0;
bool is_available_attribute(
const SizeArgs& args,
const AlgoAttribute& positive_attr = AlgoAttribute::REPRODUCIBLE,
const AlgoAttribute& negative_attr = AlgoAttribute::DEFAULT) {
return contain_attribute_all(positive_attr) &&
!contain_attribute_any(negative_attr) && is_available(args);
}
protected:
~AlgoBase() = default;
};
class PoolingForwardImpl::AlgoMIOpen final : public AlgoBase {
std::string m_algo_name;
AlgoAttribute m_algo_attribute;
public:
AlgoMIOpen(AlgoAttribute attr)
: m_algo_name("MIOpenPoolingForward"), m_algo_attribute(attr) {}
bool is_available(const SizeArgs& args) const override;
size_t get_workspace_in_bytes(const SizeArgs& args) const override;
void init_mode(const ExecArgs& args, miopenPoolingMode_t& mode) const;
void exec(const ExecArgs& args) const override;
const char* name() const override { return m_algo_name.c_str(); }
AlgoAttribute attribute() const override { return m_algo_attribute; }
MEGDNN_DECL_ALGO_TYPE(ROCM_MIOPEN)
std::string param() const override {
std::string ret;
serialize_write_pod(m_algo_attribute, ret);
return ret;
}
};
class PoolingForwardImpl::AlgoPack : NonCopyableObj {
private:
AlgoBase::Mapper m_all_algos_map;
public:
AlgoPack();
AlgoMIOpen algo_miopen{AlgoAttribute::REPRODUCIBLE};
std::vector<AlgoBase*> all_algos;
const AlgoBase::Mapper& all_algos_map() const { return m_all_algos_map; }
};
class PoolingBackwardImpl::AlgoBase : public Algorithm {
public:
enum class AlgoType : uint32_t { ROCM_MIOPEN };
using Mapper = std::unordered_map<AlgorithmDesc, AlgoBase*>;
AlgoBase() : Algorithm() { m_handle_type = Handle::HandleType::ROCM; }
struct SizeArgs {
HandleImpl* handle;
PoolingBackwardImpl* opr;
const TensorLayout *layout_src, *layout_dst, *layout_diff, *layout_grad;
std::string to_string() const;
void init_desc(TensorDesc& src_desc, TensorDesc& dst_desc,
TensorDesc& diff_desc, TensorDesc& grad_desc) const {
src_desc.set(*layout_src);
dst_desc.set(*layout_dst);
diff_desc.set(*layout_diff);
grad_desc.set(*layout_grad);
}
SizeArgs(PoolingBackwardImpl* opr, const TensorLayout& src,
const TensorLayout& dst, const TensorLayout& diff,
const TensorLayout& grad);
};
struct ExecArgs : public SizeArgs {
const TensorND *src_tensor, *dst_tensor, *diff_tensor, *grad_tensor;
Workspace workspace;
ExecArgs(PoolingBackwardImpl* opr, _megdnn_tensor_in src,
_megdnn_tensor_in dst, _megdnn_tensor_in diff,
_megdnn_tensor_out grad, _megdnn_workspace workspace);
};
virtual bool is_available(const SizeArgs& args) const = 0;
virtual size_t get_workspace_in_bytes(const SizeArgs& args) const = 0;
virtual void exec(const ExecArgs& args) const = 0;
bool is_available_attribute(
const SizeArgs& args,
const AlgoAttribute& positive_attr = AlgoAttribute::REPRODUCIBLE,
const AlgoAttribute& negative_attr = AlgoAttribute::DEFAULT) {
return contain_attribute_all(positive_attr) &&
!contain_attribute_any(negative_attr) && is_available(args);
}
protected:
~AlgoBase() = default;
};
class PoolingBackwardImpl::AlgoMIOpen final : public AlgoBase {
std::string m_algo_name;
AlgoAttribute m_algo_attribute;
public:
AlgoMIOpen(AlgoAttribute attr)
: m_algo_name("MIOpenPoolingBackward"), m_algo_attribute(attr) {}
bool is_available(const SizeArgs& args) const override;
size_t get_workspace_in_bytes(const SizeArgs& args) const override;
void init_mode(const ExecArgs& args, miopenPoolingMode_t& mode) const;
void exec(const ExecArgs& args) const override;
const char* name() const override { return m_algo_name.c_str(); }
AlgoAttribute attribute() const override {
return m_algo_attribute;
}
MEGDNN_DECL_ALGO_TYPE(ROCM_MIOPEN)
std::string param() const override {
std::string ret;
serialize_write_pod(m_algo_attribute, ret);
return ret;
}
};
class PoolingBackwardImpl::AlgoPack : NonCopyableObj {
private:
AlgoBase::Mapper m_all_algos_map;
public:
AlgoPack();
AlgoMIOpen algo_miopen{AlgoAttribute::REPRODUCIBLE};
std::vector<AlgoBase*> all_algos;
const AlgoBase::Mapper& all_algos_map() const { return m_all_algos_map; }
};
} // namespace rocm
} // namespace megdnn
......@@ -10,18 +10,47 @@
*/
#include "hcc_detail/hcc_defs_prologue.h"
#include "src/rocm/pooling/opr_impl.h"
#include "src/rocm/utils.h"
#include "./algo.h"
#include "src/common/algo_chooser.h"
namespace megdnn {
namespace rocm {
void PoolingForwardImpl::setup_descs(const TensorLayout &src,
const TensorLayout &dst)
{
src_desc.set(src, param().format);
dst_desc.set(dst, param().format);
pooling_desc.set(this->param());
size_t PoolingForwardImpl::get_workspace_in_bytes(const TensorLayout& src,
const TensorLayout& dst) {
AlgoBase::SizeArgs args(this, src, dst);
return get_algorithm(this, src, dst)->get_workspace_in_bytes(args);
}
const char* PoolingForwardImpl::get_algorithm_set_name() const {
return "ROCM_POOLING_FORWARD";
}
std::vector<PoolingForwardImpl::Algorithm*>
PoolingForwardImpl::get_all_algorithms(const TensorLayout& src,
const TensorLayout& dst) {
return megdnn::get_all_algorithms<PoolingForwardImpl>({this, src, dst});
}
PoolingForwardImpl::Algorithm* PoolingForwardImpl::get_algorithm_heuristic(
const TensorLayout& src, const TensorLayout& dst,
size_t workspace_limit_in_bytes, const AlgoAttribute& positive_attr,
const AlgoAttribute& negative_attr) {
MEGDNN_MARK_USED_VAR(workspace_limit_in_bytes);
AlgoBase::SizeArgs args(this, src, dst);
for (auto&& iter : sm_algo_pack.all_algos) {
if (iter->is_available_attribute(args, positive_attr, negative_attr)) {
return iter;
}
}
megdnn_throw(
ssprintf("require algorithm with attribute(%s) and without "
"attribute(%s), but can't get suitable algo.\n",
Algorithm::attribute_str(positive_attr).c_str(),
Algorithm::attribute_str(negative_attr).c_str()));
return nullptr;
}
void PoolingForwardImpl::exec(_megdnn_tensor_in src,
......@@ -29,24 +58,52 @@ void PoolingForwardImpl::exec(_megdnn_tensor_in src,
_megdnn_workspace workspace)
{
check_exec(src.layout, dst.layout, workspace.size);
auto handle = miopen_handle(this->handle());
setup_descs(src.layout, dst.layout);
dt_float32 alpha = 1.0f, beta = 0.0f;
miopen_check(miopenPoolingForward(handle, pooling_desc.desc, &alpha,
src_desc.desc, src.raw_ptr, &beta,
dst_desc.desc, dst.raw_ptr, false,
nullptr, 0_z));
{
AlgoBase::ExecArgs args(this, src, dst, workspace);
auto algo = get_algorithm(this, src.layout, dst.layout);
algo->exec(args);
}
}
void PoolingBackwardImpl::setup_descs(const TensorLayout& src,
const TensorLayout& dst,
const TensorLayout& diff,
const TensorLayout& grad) {
src_desc.set(src);
dst_desc.set(dst);
diff_desc.set(diff);
grad_desc.set(grad);
pooling_desc.set(this->param());
size_t PoolingBackwardImpl::get_workspace_in_bytes(const TensorLayout& src,
const TensorLayout& dst,
const TensorLayout& diff,
const TensorLayout& grad) {
AlgoBase::SizeArgs args(this, src, dst, diff, grad);
return get_algorithm(this, src, dst, diff, grad)
->get_workspace_in_bytes(args);
};
const char* PoolingBackwardImpl::get_algorithm_set_name() const {
return "ROCM_POOLING_BACKWARD";
}
std::vector<Algorithm*> PoolingBackwardImpl::get_all_algorithms(
const TensorLayout& src, const TensorLayout& dst,
const TensorLayout& diff, const TensorLayout& grad) {
return megdnn::get_all_algorithms<PoolingBackwardImpl>(
{this, src, dst, diff, grad});
}
Algorithm* PoolingBackwardImpl::get_algorithm_heuristic(
const TensorLayout& src, const TensorLayout& dst,
const TensorLayout& diff, const TensorLayout& grad,
size_t workspace_limit_in_bytes, const AlgoAttribute& positive_attr,
const AlgoAttribute& negative_attr) {
MEGDNN_MARK_USED_VAR(workspace_limit_in_bytes);
AlgoBase::SizeArgs args(this, src, dst, diff, grad);
for (auto iter : sm_algo_pack.all_algos) {
if (iter->is_available_attribute(args, positive_attr, negative_attr)) {
return iter;
}
}
megdnn_throw(
ssprintf("require algorithm with attribute(%s) and without "
"attribute(%s), but can't get suitable algo.\n",
Algorithm::attribute_str(positive_attr).c_str(),
Algorithm::attribute_str(negative_attr).c_str()));
return nullptr;
}
void PoolingBackwardImpl::exec(_megdnn_tensor_in src,
......@@ -55,35 +112,16 @@ void PoolingBackwardImpl::exec(_megdnn_tensor_in src,
_megdnn_tensor_out grad,
_megdnn_workspace workspace)
{
check_exec(src.layout, dst.layout, diff.layout, grad.layout, workspace.size);
auto handle = miopen_handle(this->handle());
setup_descs(src.layout, dst.layout, diff.layout, grad.layout);
float alpha = 1.0f, beta = 0.0f;
if (param().mode == param::Pooling::Mode::MAX) {
//! FIXME: when using max pooling opr, the backward opr need the indices
//! of the forward opr which stored in workspace. We have to recompute
//! the indices by calling miopenPoolingForward again.
miopen_check(miopenPoolingForward(handle, pooling_desc.desc, &alpha,
src_desc.desc, src.raw_ptr, &beta,
dst_desc.desc, dst.raw_ptr, true,
workspace.raw_ptr, workspace.size));
check_exec(src.layout, dst.layout, diff.layout, grad.layout,
workspace.size);
{
AlgoBase::ExecArgs args(this, src, dst, diff, grad, workspace);
auto algo = get_algorithm(this, src.layout, dst.layout, diff.layout,
grad.layout);
algo->exec(args);
}
miopen_check(miopenPoolingBackward(
handle, pooling_desc.desc, &alpha, dst_desc.desc, dst.raw_ptr,
diff_desc.desc, diff.raw_ptr, src_desc.desc, src.raw_ptr, &beta,
grad_desc.desc, grad.raw_ptr, workspace.raw_ptr));
}
size_t PoolingBackwardImpl::get_workspace_in_bytes(const TensorLayout& src,
const TensorLayout& dst,
const TensorLayout& diff,
const TensorLayout& grad) {
setup_descs(src, dst, diff, grad);
size_t ws_size = 0_z;
miopenPoolingGetWorkSpaceSize(dst_desc.desc, &ws_size);
return ws_size;
};
} // namespace rocm
} // namespace megdnn
......
......@@ -22,13 +22,37 @@ class PoolingForwardImpl final: public PoolingForward {
void exec(_megdnn_tensor_in src, _megdnn_tensor_out dst,
_megdnn_workspace workspace) override;
size_t get_workspace_in_bytes(const TensorLayout &,
const TensorLayout &) override {
return 0;
const TensorLayout &) override;
const char* get_algorithm_set_name() const override;
Algorithm* get_algorithm_from_desc(const AlgorithmDesc& desc) override;
AlgorithmInfo get_algorithm_info_heuristic(
const TensorLayout& src, const TensorLayout& dst,
size_t workspace_limit_in_bytes,
const AlgoAttribute& positive_attr,
const AlgoAttribute& negative_attr) {
return get_algorithm_heuristic(src, dst, workspace_limit_in_bytes,
positive_attr, negative_attr)
->info();
}
class AlgoBase;
class AlgoMIOpen;
class AlgoPack;
static const AlgoPack& algo_pack() { return sm_algo_pack; }
protected:
std::vector<Algorithm*> get_all_algorithms(
const TensorLayout& src, const TensorLayout& dst) override;
Algorithm* get_algorithm_heuristic(
const TensorLayout& src, const TensorLayout& dst,
size_t workspace_limit_in_bytes, const AlgoAttribute& positive_attr,
const AlgoAttribute& negative_attr) override;
private:
TensorDesc src_desc, dst_desc;
PoolingDesc pooling_desc;
void setup_descs(const TensorLayout &src, const TensorLayout &dst);
static AlgoPack sm_algo_pack;
};
class PoolingBackwardImpl final: public PoolingBackward {
......@@ -43,14 +67,41 @@ class PoolingBackwardImpl final: public PoolingBackward {
const TensorLayout& dst,
const TensorLayout& diff,
const TensorLayout& grad) override;
private:
TensorDesc src_desc, dst_desc, diff_desc, grad_desc;
PoolingDesc pooling_desc;
void setup_descs(const TensorLayout &src,
const TensorLayout &dst,
const TensorLayout &diff,
const TensorLayout &grad);
const char* get_algorithm_set_name() const override;
Algorithm* get_algorithm_from_desc(const AlgorithmDesc& desc) override;
AlgorithmInfo get_algorithm_info_heuristic(
const TensorLayout& src, const TensorLayout& dst,
const TensorLayout& diff, const TensorLayout& grad,
size_t workspace_limit_in_bytes,
const AlgoAttribute& positive_attr,
const AlgoAttribute& negative_attr) {
return get_algorithm_heuristic(src, dst, diff, grad,
workspace_limit_in_bytes,
positive_attr, negative_attr)
->info();
}
class AlgoBase;
class AlgoMIOpen;
class AlgoPack;
static const AlgoPack& algo_pack() { return sm_algo_pack; }
protected:
std::vector<Algorithm*> get_all_algorithms(
const TensorLayout& src, const TensorLayout& dst,
const TensorLayout& diff, const TensorLayout& grad) override;
Algorithm* get_algorithm_heuristic(
const TensorLayout& src, const TensorLayout& dst,
const TensorLayout& diff, const TensorLayout& grad,
size_t workspace_limit_in_bytes,
const AlgoAttribute& positive_attr,
const AlgoAttribute& negative_attr) override;
private:
static AlgoPack sm_algo_pack;
};
} // namespace rocm
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册