提交 8b0315b3 编写于 作者: M Megvii Engine Team

fix(mgb): fix nhwcd4 optpass

GitOrigin-RevId: 9295abec77af2763d301ba372116e4b2281f442a
上级 b588d93e
......@@ -94,6 +94,11 @@ DefaultConvolution3DBackwardFilterAlgorithm
HandleImpl::m_default_conv3d_bwd_filter_algo;
DefaultBatchConvBiasForwardAlgorithm
HandleImpl::m_default_batch_conv_bias_fwd_algo;
// Out-of-class definitions of the static default local-share algorithm
// objects declared in HandleImpl; one singleton per opr direction
// (forward / backward-data / backward-filter).
DefaultLocalShareForwardAlgorithm HandleImpl::m_default_local_share_fwd_algo;
DefaultLocalShareBackwardDataAlgorithm
HandleImpl::m_default_local_share_bwd_data_algo;
DefaultLocalShareBackwardFilterAlgorithm
HandleImpl::m_default_local_share_bwd_filter_algo;
HandleImpl::HandleImpl(megcoreComputingHandle_t computing_handle,
HandleType type)
......
......@@ -13,6 +13,7 @@
#include "src/common/handle_impl.h"
#include "src/naive/convolution/algorithms.h"
#include "src/naive/local_share/algorithms.h"
#include "src/naive/convolution3d/algorithms.h"
#include <functional>
......@@ -39,6 +40,11 @@ class HandleImpl : public HandleImplHelper {
m_default_conv3d_bwd_filter_algo;
static DefaultBatchConvBiasForwardAlgorithm
m_default_batch_conv_bias_fwd_algo;
static DefaultLocalShareForwardAlgorithm m_default_local_share_fwd_algo;
static DefaultLocalShareBackwardDataAlgorithm
m_default_local_share_bwd_data_algo;
static DefaultLocalShareBackwardFilterAlgorithm
m_default_local_share_bwd_filter_algo;
//! move KernFunc to alloc_kern()->func, destruct func, and call dispatch
template <typename T>
......@@ -91,6 +97,18 @@ public:
return &m_default_batch_conv_bias_fwd_algo;
}
//! accessor for the singleton default LocalShareForward algorithm
LocalShareForward::Algorithm* default_local_share_fwd_algo() {
return &m_default_local_share_fwd_algo;
}
//! accessor for the singleton default LocalShareBackwardData algorithm
LocalShareBackwardData::Algorithm* default_local_share_bwd_data_algo() {
return &m_default_local_share_bwd_data_algo;
}
//! accessor for the singleton default LocalShareBackwardFilter algorithm
LocalShareBackwardFilter::Algorithm* default_local_share_bwd_filter_algo() {
return &m_default_local_share_bwd_filter_algo;
}
//! Relayout operator cached in helper slot 2 of HandleImplHelper
Relayout* relayout_opr() override {
return get_helper_opr<Relayout, 2>(this);
}
......
/**
* \file dnn/src/naive/local_share/algorithms.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#pragma once
#include "megdnn/oprs.h"
namespace megdnn {
namespace naive {
//! the single, always-reproducible "DEFAULT" algorithm of naive
//! LocalShareForward
class DefaultLocalShareForwardAlgorithm final
        : public megdnn::LocalShareForward::Algorithm {
    bool is_reproducible() const override { return true; }
    const char* name() const override { return "DEFAULT"; }
};
//! the single, always-reproducible "DEFAULT" algorithm of naive
//! LocalShareBackwardData
class DefaultLocalShareBackwardDataAlgorithm final
        : public megdnn::LocalShareBackwardData::Algorithm {
    bool is_reproducible() const override { return true; }
    const char* name() const override { return "DEFAULT"; }
};
//! the single, always-reproducible "DEFAULT" algorithm of naive
//! LocalShareBackwardFilter
class DefaultLocalShareBackwardFilterAlgorithm final
        : public megdnn::LocalShareBackwardFilter::Algorithm {
    bool is_reproducible() const override { return true; }
    const char* name() const override { return "DEFAULT"; }
};
} // namespace naive
} // namespace megdnn
// vim: syntax=cpp.doxygen
......@@ -152,4 +152,77 @@ void LocalShareBackwardFilterImpl::exec(_megdnn_tensor_in src,
StrategyBwdFlt>(src, grad, diff, param())););
}
std::vector<LocalShareForward::Algorithm*>
LocalShareForwardImpl::get_all_algorithms(const TensorLayout&,
const TensorLayout&,
const TensorLayout&) {
return {static_cast<HandleImpl*>(handle())->default_local_share_fwd_algo()};
}
// Heuristic algorithm selection for naive LocalShareForward: there is only
// the default algorithm, so it is returned unconditionally; when the caller
// requires reproducibility we assert the chosen algorithm provides it.
// NOTE: parameter comments fixed — the forward opr takes (src, filter, dst),
// matching the declaration in opr_impl.h, not (src, diff, grad).
LocalShareForward::Algorithm* LocalShareForwardImpl::get_algorithm_heuristic(
const TensorLayout& /* src */, const TensorLayout& /* filter */,
const TensorLayout& /* dst */, size_t /* workspace_limit_in_bytes */,
bool reproducible) {
auto algo =
static_cast<HandleImpl*>(handle())->default_local_share_fwd_algo();
if (reproducible) {
// default algo is reproducible (see algorithms.h); assert as a guard
megdnn_assert(algo->is_reproducible(),
"require reproducible algorithm, but heuristic "
"algorithm(%s) is not "
"reproducible",
algo->name());
}
return algo;
}
std::vector<LocalShareBackwardData::Algorithm*>
LocalShareBackwardDataImpl::get_all_algorithms(const TensorLayout&,
const TensorLayout&,
const TensorLayout&) {
return {static_cast<HandleImpl*>(handle())
->default_local_share_bwd_data_algo()};
}
// Heuristic selection for naive LocalShareBackwardData: only the default
// algorithm exists; assert reproducibility when the caller demands it.
LocalShareBackwardData::Algorithm*
LocalShareBackwardDataImpl::get_algorithm_heuristic(
        const TensorLayout& /* filter */, const TensorLayout& /* diff */,
        const TensorLayout& /* grad */, size_t /* workspace_limit_in_bytes */,
        bool reproducible) {
    auto* naive_handle = static_cast<HandleImpl*>(handle());
    auto algo = naive_handle->default_local_share_bwd_data_algo();
    if (reproducible) {
        megdnn_assert(algo->is_reproducible(),
                      "require reproducible algorithm, but heuristic "
                      "algorithm(%s) is not "
                      "reproducible",
                      algo->name());
    }
    return algo;
}
std::vector<LocalShareBackwardFilter::Algorithm*>
LocalShareBackwardFilterImpl::get_all_algorithms(const TensorLayout&,
const TensorLayout&,
const TensorLayout&) {
return {static_cast<HandleImpl*>(handle())
->default_local_share_bwd_filter_algo()};
}
// Heuristic selection for naive LocalShareBackwardFilter: only the default
// algorithm exists; assert reproducibility when the caller demands it.
LocalShareBackwardFilter::Algorithm*
LocalShareBackwardFilterImpl::get_algorithm_heuristic(
        const TensorLayout& /* src */, const TensorLayout& /* diff */,
        const TensorLayout& /* grad */, size_t /* workspace_limit_in_bytes */,
        bool reproducible) {
    auto* naive_handle = static_cast<HandleImpl*>(handle());
    auto algo = naive_handle->default_local_share_bwd_filter_algo();
    if (reproducible) {
        megdnn_assert(algo->is_reproducible(),
                      "require reproducible algorithm, but heuristic "
                      "algorithm(%s) is not "
                      "reproducible",
                      algo->name());
    }
    return algo;
}
// vim: syntax=cpp.doxygen
......@@ -27,17 +27,13 @@ public:
std::vector<Algorithm*> get_all_algorithms(
const TensorLayout& /*src*/, const TensorLayout& /*filter*/,
const TensorLayout& /*dst*/) override {
return {};
}
const TensorLayout& /*dst*/) override;
Algorithm* get_algorithm_heuristic(const TensorLayout& /*src*/,
const TensorLayout& /*filter*/,
const TensorLayout& /*dst*/,
size_t /*workspace_limit_in_bytes*/,
bool /*reproducible*/) override {
return nullptr;
}
bool /*reproducible*/) override;
const char* get_algorithm_set_name() const override { return "DEFAULT"; }
};
......@@ -55,17 +51,13 @@ public:
std::vector<Algorithm*> get_all_algorithms(
const TensorLayout& /*filter*/, const TensorLayout& /*diff*/,
const TensorLayout& /*grad*/) override {
return {};
}
const TensorLayout& /*grad*/) override;
Algorithm* get_algorithm_heuristic(const TensorLayout& /*filter*/,
const TensorLayout& /*diff*/,
const TensorLayout& /*grad*/,
size_t /*workspace_limit_in_bytes*/,
bool /*reproducible*/) override {
return nullptr;
}
bool /*reproducible*/) override;
const char* get_algorithm_set_name() const override { return "DEFAULT"; }
};
......@@ -83,17 +75,13 @@ public:
std::vector<Algorithm*> get_all_algorithms(
const TensorLayout& /*src*/, const TensorLayout& /*diff*/,
const TensorLayout& /*grad*/) override {
return {};
}
const TensorLayout& /*grad*/) override;
Algorithm* get_algorithm_heuristic(const TensorLayout& /*src*/,
const TensorLayout& /*diff*/,
const TensorLayout& /*grad*/,
size_t /*workspace_limit_in_bytes*/,
bool /*reproducible*/) override {
return nullptr;
}
bool /*reproducible*/) override;
const char* get_algorithm_set_name() const override { return "DEFAULT"; }
};
......
......@@ -14,6 +14,7 @@
#include "megbrain/gopt/basic_arith.h"
#include "megbrain/graph/event.h"
#include "megbrain/opr/dnn/batch_norm.h"
#include "megbrain/opr/dnn/local.h"
#include "megbrain/utils/shared_set.h"
#include "megbrain/serialization/opr_shallow_copy.h"
#include "megbrain/opr/basic_arith.h"
......@@ -1358,23 +1359,28 @@ std::unique_ptr<ConvertFormatPass> ConvertFormatPass::make_nhwcd4_converter() {
return new_pooling_opr.node()->owner_opr();
};
auto relayout_inp_to_chw = [](OperatorNodeBase* opr,
// Map a var back to NCHW after the NHWCD4 conversion pass.
// `inp` is the var before the pass, `new_inp` its replacement. If both
// shapes still agree the replacement is already NCHW and is returned as-is;
// otherwise `new_inp` must be the 5-dim IMAGE2D_PACK4 (NHWCD4) form of the
// 4-dim `inp`, and a RelayoutFormat(NHWCD4I_NCHW) opr is inserted to
// restore the NCHW layout.
auto var_to_chw = [](VarNode* inp, VarNode* new_inp) {
if (!inp->shape().eq_shape(new_inp->shape())) {
// original var: 4-dim, not image-packed
mgb_assert(inp->shape().ndim == 4 &&
inp->format().type() !=
TensorFormat::Type::IMAGE2D_PACK4);
// replacement var: 5-dim NHWCD4 image-packed layout
mgb_assert(new_inp->shape().ndim == 5 &&
new_inp->format().type() ==
TensorFormat::Type::IMAGE2D_PACK4);
auto param = megdnn::param::RelayoutFormat();
param.mode = megdnn::param::RelayoutFormat::Mode::NHWCD4I_NCHW;
auto rf = opr::RelayoutFormat::make(new_inp, param);
return rf.node();
}
return new_inp;
};
auto relayout_inp_to_chw = [var_to_chw](OperatorNodeBase* opr,
const VarNodeArray& new_inp) {
mgb_assert(opr->input().size() == new_inp.size());
VarNodeArray t_inp = new_inp;
for (size_t i = 0; i < opr->input().size(); i++) {
if (!opr->input(i)->shape().eq_shape(new_inp[i]->shape())) {
mgb_assert(opr->input(i)->shape().ndim == 4 &&
opr->input(i)->format().type() !=
TensorFormat::Type::IMAGE2D_PACK4);
mgb_assert(new_inp[i]->shape().ndim == 5 &&
new_inp[i]->format().type() ==
TensorFormat::Type::IMAGE2D_PACK4);
auto param = megdnn::param::RelayoutFormat();
param.mode = megdnn::param::RelayoutFormat::Mode::NHWCD4I_NCHW;
auto rf = opr::RelayoutFormat::make(new_inp[i], param);
t_inp[i] = rf.node();
}
t_inp[i] = var_to_chw(opr->input(i), new_inp[i]);
}
auto new_opr =
serialization::copy_opr_shallow(*opr, t_inp, opr->config());
......@@ -1415,6 +1421,18 @@ std::unique_ptr<ConvertFormatPass> ConvertFormatPass::make_nhwcd4_converter() {
}
};
/* This helper function converts the first input to the NCHW format to
* handle operations that do not support NHWCD4 format
*/
// Rebuild `opr` with only input(0) forced back to NCHW; used for oprs
// (Local/GroupLocal) whose feature-map input may have been converted to
// NHWCD4 but whose remaining inputs (e.g. weights) were left untouched.
auto relayout_first_inp_to_chw =
[var_to_chw](OperatorNodeBase* opr,
const VarNodeArray& new_inp) -> OperatorNodeBase* {
mgb_assert(opr->input().size() == new_inp.size());
VarNodeArray t_inp = new_inp;
// only the first input can carry the NHWCD4 layout here
t_inp[0] = var_to_chw(opr->input(0), new_inp[0]);
return serialization::copy_opr_shallow(*opr, t_inp, opr->config());
};
auto ret = std::make_unique<ConvertFormatPass>();
ret->set_var_replace_check_flag(VarReplaceCheckFlag::NOCHECK);
auto&& replace_func = ret->m_opr_replace_func;
......@@ -1436,6 +1454,9 @@ std::unique_ptr<ConvertFormatPass> ConvertFormatPass::make_nhwcd4_converter() {
replace_func[opr::WarpPerspectiveForward::typeinfo()] =
replace_warp_perspective_opr;
replace_func[opr::WarpAffineForward::typeinfo()] = replace_warp_affine_opr;
replace_func[opr::LocalForward::typeinfo()] = relayout_first_inp_to_chw;
replace_func[opr::GroupLocalForward::typeinfo()] =
relayout_first_inp_to_chw;
return ret;
}
......
......@@ -9,6 +9,7 @@
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#include "megbrain/opr/dnn/local.h"
#include "megbrain/test/helper.h"
#include "megbrain/gopt/inference.h"
......@@ -919,6 +920,69 @@ TEST(TestGoptInference, ConvertFormatNHWCD4) {
MGB_ASSERT_TENSOR_NEAR(host_y, host_y_opt, 1e-3);
}
// Regression test for the nhwcd4 convert pass with Local/GroupLocal oprs:
// convolutions should be converted to NHWCD4 while Local and GroupLocal —
// which do not support that format — must stay NCHW (their input relayouted
// back), and the optimized graph must match the reference output.
TEST(TestGoptInference, ConvertFormatNHWCD4LOCAL) {
// hwcd4 is only supported in naive handle
NaiveMegDNNHandleScope naive_megdnn_handle;
HostTensorGenerator<> gen;
auto cn = CompNode::load("cpu0");
auto graph = ComputingGraph::make();
// disable generic graph opt so only the pass under test rewrites the graph
graph->options().graph_opt_level = 0;
// helper: shared const tensor filled with random data
auto mkcvar = [&](const char* name, const TensorShape& shp) {
return opr::SharedDeviceTensor::make(*graph, *gen(shp, cn))
.rename(name);
};
auto host_x = gen({2, 8, 8, 16}, cn);
auto x = opr::Host2DeviceCopy::make(*graph, host_x);
opr::Convolution::Param param;
param.pad_h = param.pad_w = 1;
// conv1: convertible to NHWCD4
auto w1 = mkcvar("w1", {4, 8, 3, 3}),
conv1 = opr::Convolution::make(x, w1, param);
// local opr: must remain NCHW after the pass
auto w2 = mkcvar("w2", {8, 16, 4, 3, 3, 4}),
local = opr::Local::make(conv1, w2, param);
auto w3 = mkcvar("w3", {4, 4, 3, 3}),
conv2 = opr::Convolution::make(local, w3, param);
opr::GroupLocal::Param param_group_local;
param_group_local.pad_h = param_group_local.pad_w = 1;
// group-local opr: must also remain NCHW after the pass
auto w4 = mkcvar("w4", {2, 8, 16, 2, 3, 3, 2}),
group_local = opr::GroupLocal::make(conv2, w4, param_group_local);
auto w5 = mkcvar("w5", {4, 4, 3, 3}),
y = opr::Convolution::make(group_local, w5, param);
SymbolVar y_opt;
unpack_vector(
gopt::optimize_for_inference(
{y},
gopt::OptimizeForInferenceOptions{}.enable_use_nhwcd4()),
y_opt);
// convolutions were converted ...
ASSERT_EQ(opr::Convolution::Param::Format::NHWCD4,
find_opr<opr::Convolution>(y_opt).param().format);
// ... while (group-)local oprs kept the NCHW format
ASSERT_EQ(opr::Local::Param::Format::NCHW,
find_opr<opr::Local>(y_opt).param().format);
ASSERT_EQ(opr::GroupLocal::Param::Format::NCHW,
find_opr<opr::GroupLocal>(y_opt).param().format);
graph->compile({{y_opt, {}}})
->to_json()
->writeto_fpath(output_file(
"TestGoptInference.ConvertFormatNHWCD4LOCAL.json"));
// numeric check: optimized graph must match the unoptimized reference
HostTensorND host_y_opt, host_y;
auto func = graph->compile({make_callback_copy(y, host_y),
make_callback_copy(y_opt, host_y_opt)});
func->execute();
MGB_ASSERT_TENSOR_NEAR(host_y, host_y_opt, 1e-3);
}
TEST(TestGoptInference, ConvertFormatNHWCD4Deconv) {
// hwcd4 is only supported in naive handle
NaiveMegDNNHandleScope naive_megdnn_handle;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册