Commit a9c3515c authored by Megvii Engine Team

fix(opr/naive): add DeformableConv algorithms interface

GitOrigin-RevId: adccb05f1a85552f7ab74aba5b2556675d5e2685
Parent d4bb54d4
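This commit hooks the naive (reference) backend's DeformableConv forward, backward-data and backward-filter operators into megdnn's algorithm-query interface: each operator now reports a single default algorithm named "DEFAULT" with the REPRODUCIBLE and NAIVE attributes, instead of returning an empty algorithm list and a null heuristic. A minimal usage sketch of the queried interface follows (a hypothetical standalone snippet, not part of the diff; it assumes a naive megdnn::Handle named `handle` has already been created, and the layouts mirror the DEFORMABLE_CONV_FWD test further down):

#include <cstdio>
#include "megdnn/dtype.h"
#include "megdnn/handle.h"
#include "megdnn/oprs/nn.h"

// Query the heuristic algorithm for a deformable convolution forward pass.
void print_default_deformable_conv_algo(megdnn::Handle* handle) {
    auto opr = handle->create_operator<megdnn::DeformableConvForward>();
    megdnn::TensorLayout im({1, 2, 5, 5}, megdnn::dtype::Float32());
    megdnn::TensorLayout filter({2, 1, 1, 3, 3}, megdnn::dtype::Float32());
    megdnn::TensorLayout offset({1, 2 * 2 * 3 * 3, 5, 5}, megdnn::dtype::Float32());
    megdnn::TensorLayout mask({1, 2 * 3 * 3, 5, 5}, megdnn::dtype::Float32());
    // On the naive handle this should now resolve to the "DEFAULT" algorithm
    // added by this commit rather than a null heuristic.
    auto info = opr->get_algorithm_info_heuristic(im, filter, offset, mask, {});
    printf("%s\n", info.desc.name.c_str());
}

The backward-filter and backward-data operators expose the same query path, as exercised by the DEFORMABLE_CONV_BWD_FILTER and DEFORMABLE_CONV_BWD_DATA tests in this diff.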
......@@ -63,6 +63,32 @@ class DefaultPoolingBackwardAlgorithm final
const char* name() const override { return "DEFAULT"; }
};
class DeformableConvForwardAlgorithm final
: public megdnn::DeformableConvForward::Algorithm {
AlgoAttribute attribute() const override {
return AlgoAttribute::REPRODUCIBLE | AlgoAttribute::NAIVE;
}
uint32_t type() const override { return 0; }
const char* name() const override { return "DEFAULT"; }
};
class DeformableConvBackwardFilterAlgorithm final
: public megdnn::DeformableConvBackwardFilter::Algorithm {
AlgoAttribute attribute() const override {
return AlgoAttribute::REPRODUCIBLE | AlgoAttribute::NAIVE;
}
uint32_t type() const override { return 0; }
const char* name() const override { return "DEFAULT"; }
};
class DeformableConvBackwardDataAlgorithm final
: public megdnn::DeformableConvBackwardData::Algorithm {
AlgoAttribute attribute() const override {
return AlgoAttribute::REPRODUCIBLE | AlgoAttribute::NAIVE;
}
uint32_t type() const override { return 0; }
const char* name() const override { return "DEFAULT"; }
};
} // namespace naive
} // namespace megdnn
......
#include "src/naive/deformable_conv/opr_impl.h"
#include <vector>
#include "src/common/utils.h"
#include "src/naive/convolution/helper.h"
#include "src/naive/handle.h"
......@@ -123,6 +124,38 @@ void Fwd::exec(
return;
}
std::vector<DeformableConvForward::Algorithm*> Fwd::get_all_algorithms(
const TensorLayout& /* im */, const TensorLayout& /* filter */,
const TensorLayout& /* offset */, const TensorLayout& /* mask */,
const TensorLayout& /* dst */) {
return {static_cast<HandleImpl*>(handle())->default_deformable_conv_fwd_algo()};
}
std::vector<DeformableConvForward::Algorithm*> Fwd::get_all_algorithms_safe(
const TensorLayout& /* im */, const TensorLayout& /* filter */,
const TensorLayout& /* offset */, const TensorLayout& /* mask */,
const TensorLayout& /* dst */) {
return {static_cast<HandleImpl*>(handle())->default_deformable_conv_fwd_algo()};
}
DeformableConvForward::Algorithm* Fwd::get_algorithm_heuristic(
const TensorLayout& /* src */, const TensorLayout& /* filter */,
const TensorLayout& /* offset */, const TensorLayout& /* mask */,
const TensorLayout& /* dst */, size_t /* workspace_limit_in_bytes */,
const AlgoAttribute& positive_attr, const AlgoAttribute& negative_attr) {
auto algo = static_cast<HandleImpl*>(handle())->default_deformable_conv_fwd_algo();
algo->check_attribute(positive_attr, negative_attr);
return algo;
}
DeformableConvForward::Algorithm* Fwd::get_algorithm_from_desc(
const AlgorithmDesc& desc) {
Algorithm* ret =
static_cast<HandleImpl*>(handle())->default_deformable_conv_fwd_algo();
megdnn_assert(desc == ret->info().desc);
return ret;
}
/* ============== Bwd Implementation ============== */
static void deformable_conv_backward_weight(
......@@ -388,6 +421,41 @@ void BwdFlt::exec(
out_grad.ptr<float>(), filter_grad.ptr<float>(), OC, IC, N, FH, FW, IH, IW,
PH, PW, DH, DW, SH, SW, OH, OW, group, deformable_group));
}
std::vector<BwdFlt::Algorithm*> BwdFlt::get_all_algorithms(
const TensorLayout& /* im */, const TensorLayout& /* filter */,
const TensorLayout& /* offset */, const TensorLayout& /* mask */,
const TensorLayout& /* dst */) {
return {static_cast<HandleImpl*>(handle())
->default_deformable_conv_bwd_filter_algo()};
}
std::vector<BwdFlt::Algorithm*> BwdFlt::get_all_algorithms_safe(
const TensorLayout& /* im */, const TensorLayout& /* filter */,
const TensorLayout& /* offset */, const TensorLayout& /* mask */,
const TensorLayout& /* dst */) {
return {static_cast<HandleImpl*>(handle())
->default_deformable_conv_bwd_filter_algo()};
}
BwdFlt::Algorithm* BwdFlt::get_algorithm_heuristic(
const TensorLayout& /* src */, const TensorLayout& /* filter */,
const TensorLayout& /* offset */, const TensorLayout& /* mask */,
const TensorLayout& /* dst */, size_t /* workspace_limit_in_bytes */,
const AlgoAttribute& positive_attr, const AlgoAttribute& negative_attr) {
auto algo = static_cast<HandleImpl*>(handle())
->default_deformable_conv_bwd_filter_algo();
algo->check_attribute(positive_attr, negative_attr);
return algo;
}
BwdFlt::Algorithm* BwdFlt::get_algorithm_from_desc(const AlgorithmDesc& desc) {
Algorithm* ret = static_cast<HandleImpl*>(handle())
->default_deformable_conv_bwd_filter_algo();
megdnn_assert(desc == ret->info().desc);
return ret;
}
size_t BwdData::get_workspace_in_bytes(
const TensorLayout& /* im */, const TensorLayout& /* filter */,
const TensorLayout& /* offset */, const TensorLayout& /* mask */,
......@@ -417,4 +485,42 @@ void BwdData::exec(
PH, PW, SH, SW, DH, DW, OH, OW, group, deformable_group));
}
std::vector<BwdData::Algorithm*> BwdData::get_all_algorithms(
const TensorLayout& /* im */, const TensorLayout& /* filter */,
const TensorLayout& /* offset */, const TensorLayout& /* mask */,
const TensorLayout& /* out_grad */, const TensorLayout& /* im_grad */,
const TensorLayout& /* offset_grad */, const TensorLayout& /* mask_grad */) {
return {static_cast<HandleImpl*>(handle())
->default_deformable_conv_bwd_data_algo()};
}
std::vector<BwdData::Algorithm*> BwdData::get_all_algorithms_safe(
const TensorLayout& /* im */, const TensorLayout& /* filter */,
const TensorLayout& /* offset */, const TensorLayout& /* mask */,
const TensorLayout& /* out_grad */, const TensorLayout& /* im_grad */,
const TensorLayout& /* offset_grad */, const TensorLayout& /* mask_grad */) {
return {static_cast<HandleImpl*>(handle())
->default_deformable_conv_bwd_data_algo()};
}
BwdData::Algorithm* BwdData::get_algorithm_heuristic(
const TensorLayout& /* im */, const TensorLayout& /* filter */,
const TensorLayout& /* offset */, const TensorLayout& /* mask */,
const TensorLayout& /* out_grad */, const TensorLayout& /* im_grad */,
const TensorLayout& /* offset_grad */, const TensorLayout& /* mask_grad */,
size_t /* workspace_limit_in_bytes */, const AlgoAttribute& positive_attr,
const AlgoAttribute& negative_attr) {
auto algo =
static_cast<HandleImpl*>(handle())->default_deformable_conv_bwd_data_algo();
algo->check_attribute(positive_attr, negative_attr);
return algo;
}
BwdData::Algorithm* BwdData::get_algorithm_from_desc(const AlgorithmDesc& desc) {
Algorithm* ret =
static_cast<HandleImpl*>(handle())->default_deformable_conv_bwd_data_algo();
megdnn_assert(desc == ret->info().desc);
return ret;
}
// vim: syntax=cpp.doxygen
......@@ -12,24 +12,18 @@ public:
std::vector<Algorithm*> get_all_algorithms(
const TensorLayout& /* im */, const TensorLayout& /* filter */,
const TensorLayout& /* offset */, const TensorLayout& /* mask */,
const TensorLayout& /* dst */) override {
return std::vector<Algorithm*>();
};
const TensorLayout& /* dst */) override;
std::vector<Algorithm*> get_all_algorithms_safe(
const TensorLayout& /* im */, const TensorLayout& /* filter */,
const TensorLayout& /* offset */, const TensorLayout& /* mask */,
const TensorLayout& /* dst */) override {
return std::vector<Algorithm*>();
};
const TensorLayout& /* dst */) override;
Algorithm* get_algorithm_heuristic(
const TensorLayout& /* src */, const TensorLayout& /* filter */,
const TensorLayout& /* offset */, const TensorLayout& /* mask */,
const TensorLayout& /* dst */, size_t /* workspace_limit_in_bytes */,
const AlgoAttribute& /*positive_attr*/,
const AlgoAttribute& /*negative_attr*/) override {
return nullptr;
};
const AlgoAttribute& /*negative_attr*/) override;
size_t get_workspace_in_bytes(
const TensorLayout& /* src */, const TensorLayout& /* filter */,
......@@ -42,7 +36,7 @@ public:
return "DEFORMABLE_CONV2_NAIVE";
};
Algorithm* get_algorithm_from_desc(const AlgorithmDesc&) override { return {}; }
Algorithm* get_algorithm_from_desc(const AlgorithmDesc&) override;
void exec(
_megdnn_tensor_in src, _megdnn_tensor_in filter, _megdnn_tensor_in offset,
......@@ -57,16 +51,12 @@ public:
std::vector<Algorithm*> get_all_algorithms(
const TensorLayout& /* im */, const TensorLayout& /* offset */,
const TensorLayout& /* mask */, const TensorLayout& /* out_grad */,
const TensorLayout& /* filter_grad */) override {
return std::vector<Algorithm*>();
};
const TensorLayout& /* filter_grad */) override;
std::vector<Algorithm*> get_all_algorithms_safe(
const TensorLayout& /* im */, const TensorLayout& /* offset */,
const TensorLayout& /* mask */, const TensorLayout& /* out_grad */,
const TensorLayout& /* filter_grad */) override {
return std::vector<Algorithm*>();
};
const TensorLayout& /* filter_grad */) override;
Algorithm* get_algorithm_heuristic(
const TensorLayout& /* im */, const TensorLayout& /* offset */,
......@@ -74,9 +64,7 @@ public:
const TensorLayout& /* filter_grad */,
size_t /* workspace_limit_in_bytes */,
const AlgoAttribute& /*positive_attr*/,
const AlgoAttribute& /*negative_attr*/) override {
return nullptr;
};
const AlgoAttribute& /*negative_attr*/) override;
size_t get_workspace_in_bytes(
const TensorLayout& im, const TensorLayout& offset,
......@@ -87,7 +75,7 @@ public:
return "DEFORMABLE_CONV2_BWD_FILTER_NAIVE";
};
Algorithm* get_algorithm_from_desc(const AlgorithmDesc&) override { return {}; }
Algorithm* get_algorithm_from_desc(const AlgorithmDesc&) override;
void exec(
_megdnn_tensor_in im, _megdnn_tensor_in offset, _megdnn_tensor_in mask,
......@@ -104,18 +92,14 @@ public:
const TensorLayout& /* offset */, const TensorLayout& /* mask */,
const TensorLayout& /* out_grad */, const TensorLayout& /* im_grad */,
const TensorLayout& /* offset_grad */,
const TensorLayout& /* mask_grad */) override {
return std::vector<Algorithm*>();
};
const TensorLayout& /* mask_grad */) override;
std::vector<Algorithm*> get_all_algorithms_safe(
const TensorLayout& /* im */, const TensorLayout& /* filter */,
const TensorLayout& /* offset */, const TensorLayout& /* mask */,
const TensorLayout& /* out_grad */, const TensorLayout& /* im_grad */,
const TensorLayout& /* offset_grad */,
const TensorLayout& /* mask_grad */) override {
return std::vector<Algorithm*>();
};
const TensorLayout& /* mask_grad */) override;
Algorithm* get_algorithm_heuristic(
const TensorLayout& /* im */, const TensorLayout& /* filter */,
......@@ -124,9 +108,7 @@ public:
const TensorLayout& /* offset_grad */, const TensorLayout& /* mask_grad */,
size_t /* workspace_limit_in_bytes */,
const AlgoAttribute& /*positive_attr*/,
const AlgoAttribute& /*negative_attr*/) override {
return nullptr;
};
const AlgoAttribute& /*negative_attr*/) override;
size_t get_workspace_in_bytes(
const TensorLayout& im, const TensorLayout& filter,
......@@ -138,7 +120,7 @@ public:
return "DEFORMABLE_CONV2_BWD_DATA_NAIVE";
};
Algorithm* get_algorithm_from_desc(const AlgorithmDesc&) override { return {}; }
Algorithm* get_algorithm_from_desc(const AlgorithmDesc&) override;
void exec(
_megdnn_tensor_in im, _megdnn_tensor_in filter, _megdnn_tensor_in offset,
......
......@@ -115,6 +115,10 @@ DefaultBatchedMatrixMulAlgorithm HandleImpl::m_default_batched_matmul_fwd_algo;
DefaultPoolingForwardAlgorithm HandleImpl::m_default_pooling_fwd_algo;
DefaultPoolingBackwardAlgorithm HandleImpl::m_default_pooling_bwd_algo;
DeformableConvForwardAlgorithm HandleImpl::m_default_deformable_conv_fwd_algo;
DeformableConvBackwardDataAlgorithm HandleImpl::m_default_deformable_conv_bwd_data_algo;
DeformableConvBackwardFilterAlgorithm
HandleImpl::m_default_deformable_conv_bwd_filter_algo;
HandleImpl::HandleImpl(megcoreComputingHandle_t computing_handle, HandleType type)
: HandleImplHelper(computing_handle, type),
......
......@@ -38,6 +38,10 @@ class HandleImpl : public HandleImplHelper {
static DefaultPoolingForwardAlgorithm m_default_pooling_fwd_algo;
static DefaultPoolingBackwardAlgorithm m_default_pooling_bwd_algo;
static DeformableConvForwardAlgorithm m_default_deformable_conv_fwd_algo;
static DeformableConvBackwardDataAlgorithm m_default_deformable_conv_bwd_data_algo;
static DeformableConvBackwardFilterAlgorithm
m_default_deformable_conv_bwd_filter_algo;
//! move KernFunc to alloc_kern()->func, destruct func, and call dispatch
template <typename T>
......@@ -119,6 +123,18 @@ public:
return &m_default_pooling_bwd_algo;
}
DeformableConvForward::Algorithm* default_deformable_conv_fwd_algo() {
return &m_default_deformable_conv_fwd_algo;
}
DeformableConvBackwardData::Algorithm* default_deformable_conv_bwd_data_algo() {
return &m_default_deformable_conv_bwd_data_algo;
}
DeformableConvBackwardFilter::Algorithm* default_deformable_conv_bwd_filter_algo() {
return &m_default_deformable_conv_bwd_filter_algo;
}
Relayout* relayout_opr() override { return get_helper_opr<Relayout, 2>(this); }
/*!
* \brief pass a kernel to the dispatcher associated with the megcore
......
#include "megdnn/dtype.h"
#include "test/naive/fixture.h"
#include "megdnn/oprs/nn.h"
......@@ -52,6 +53,15 @@ TEST_F(NAIVE, DEFORMABLE_CONV_FWD) {
{1, 2 * 2 * 3 * 3, 5, 5},
{1, 2 * 3 * 3, 5, 5},
{}});
//! check algo interface
auto opr = handle()->create_operator<DeformableConv>();
auto i0 = megdnn::TensorLayout({1, 2, 5, 5}, megdnn::dtype::Float32());
auto i1 = megdnn::TensorLayout({2, 1, 1, 3, 3}, megdnn::dtype::Float32());
auto i2 = megdnn::TensorLayout({1, 2 * 2 * 3 * 3, 5, 5}, megdnn::dtype::Float32());
auto i3 = megdnn::TensorLayout({1, 2 * 3 * 3, 5, 5}, megdnn::dtype::Float32());
auto o = opr->get_algorithm_info_heuristic(i0, i1, i2, i3, {});
auto kk = o.desc.name;
printf("%s\n", kk.c_str());
}
TEST_F(NAIVE, DEFORMABLE_CONV_BWD_FILTER) {
......@@ -82,6 +92,18 @@ TEST_F(NAIVE, DEFORMABLE_CONV_BWD_FILTER) {
{1, 2 * 3 * 3, 5, 5},
{1, 2, 5, 5},
{2, 1, 1, 3, 3}});
//! check algo interface
auto opr = handle()->create_operator<DeformableConvBackwardFilter>();
auto i0 = megdnn::TensorLayout({1, 2, 5, 5}, megdnn::dtype::Float32());
auto i1 = megdnn::TensorLayout({1, 2 * 2 * 3 * 3, 5, 5}, megdnn::dtype::Float32());
auto i2 = megdnn::TensorLayout({1, 2 * 3 * 3, 5, 5}, megdnn::dtype::Float32());
auto i3 = megdnn::TensorLayout({1, 2, 5, 5}, megdnn::dtype::Float32());
auto i4 = megdnn::TensorLayout({2, 1, 1, 3, 3}, megdnn::dtype::Float32());
auto o = opr->get_algorithm_info_heuristic(i0, i1, i2, i3, i4);
auto kk = o.desc.name;
printf("%s\n", kk.c_str());
}
TEST_F(NAIVE, DEFORMABLE_CONV_BWD_DATA) {
......@@ -118,5 +140,18 @@ TEST_F(NAIVE, DEFORMABLE_CONV_BWD_DATA) {
{1, 2, 5, 5},
{1, 1 * 2 * 3 * 3, 5, 5},
{1, 1 * 3 * 3, 5, 5}});
//! check algo interface
auto opr = handle()->create_operator<DeformableConvBackwardData>();
auto i0 = megdnn::TensorLayout({1, 2, 5, 5}, megdnn::dtype::Float32());
auto i1 = megdnn::TensorLayout({2, 1, 1, 3, 3}, megdnn::dtype::Float32());
auto i2 = megdnn::TensorLayout({1, 1 * 2 * 3 * 3, 5, 5}, megdnn::dtype::Float32());
auto i3 = megdnn::TensorLayout({1, 1 * 3 * 3, 5, 5}, megdnn::dtype::Float32());
auto i4 = megdnn::TensorLayout({1, 2, 5, 5}, megdnn::dtype::Float32());
auto i5 = megdnn::TensorLayout({1, 2, 5, 5}, megdnn::dtype::Float32());
auto i6 = megdnn::TensorLayout({1, 1 * 2 * 3 * 3, 5, 5}, megdnn::dtype::Float32());
auto i7 = megdnn::TensorLayout({1, 1 * 3 * 3, 5, 5}, megdnn::dtype::Float32());
auto o = opr->get_algorithm_info_heuristic(i0, i1, i2, i3, i4, i5, i6, i7);
auto kk = o.desc.name;
printf("%s\n", kk.c_str());
}
// vim: syntax=cpp.doxygen