Commit 3a5347ed authored by Megvii Engine Team

perf(imperative): speed up pooling

GitOrigin-RevId: 9f60b45eebf81fbb7f483328815d3744dc4d5811
Parent c0b267ff
@@ -16,50 +16,55 @@
namespace megdnn {
void PoolingBase::deduce_layout_fwd(const TensorLayout& src, TensorLayout& dst) {
auto errmsg =
megdnn_layout_msg(src) + ", " + megdnn_layout_msg(dst) + ", " +
"pad_h=" + std::to_string(param().pad_h) + ", " +
"pad_w=" + std::to_string(param().pad_w) + ", " +
"stride_h=" + std::to_string(param().stride_h) + ", " +
"stride_w=" + std::to_string(param().stride_w) + ", " +
"window_h=" + std::to_string(param().window_h) + ", " +
"window_w=" + std::to_string(param().window_w) + ", " +
"is_max=" + std::to_string(param().mode == Mode::MAX) + ", " +
"is_nhwc=" + std::to_string(param().format == Param::Format::NHWC) + ", " +
"is_nhwcd4=" + std::to_string(param().format == Param::Format::NHWCD4);
auto errmsg_c = errmsg.c_str();
MEGDNN_MARK_USED_VAR(errmsg_c);
auto& p = param();
auto pformat = p.format;
// Generating the error message costs roughly 18x as much as the rest of this
// function, so wrap it in a lambda and build it only when actually needed.
auto get_errmsg = [&](void) -> std::string {
std::string errmsg =
megdnn_layout_msg(src) + ", " + megdnn_layout_msg(dst) + ", " +
"pad_h=" + std::to_string(param().pad_h) + ", " +
"pad_w=" + std::to_string(param().pad_w) + ", " +
"stride_h=" + std::to_string(param().stride_h) + ", " +
"stride_w=" + std::to_string(param().stride_w) + ", " +
"window_h=" + std::to_string(param().window_h) + ", " +
"window_w=" + std::to_string(param().window_w) + ", " +
"is_max=" + std::to_string(param().mode == Mode::MAX) + ", " +
"is_nhwc=" + std::to_string(pformat == Param::Format::NHWC) + ", " +
"is_nhwcd4=" + std::to_string(pformat == Param::Format::NHWCD4);
return errmsg;
};
MEGDNN_MARK_USED_VAR(get_errmsg);
megdnn_assert_contiguous(src);
size_t spatial_pos, c_pos, batch_pos = 0;
if (param().format == Param::Format::NCHW) {
megdnn_assert(src.ndim == 4_z, "%s", errmsg_c);
if (pformat == Param::Format::NCHW) {
megdnn_assert(src.ndim == 4_z, "%s", get_errmsg().c_str());
spatial_pos = 2;
c_pos = 1;
} else if (param().format == Param::Format::NHWC) {
megdnn_assert(src.ndim == 4_z, "%s", errmsg_c);
} else if (pformat == Param::Format::NHWC) {
megdnn_assert(src.ndim == 4_z, "%s", get_errmsg().c_str());
spatial_pos = 1;
c_pos = 3;
} else if (
param().format == Param::Format::NCHW4 ||
param().format == Param::Format::NCHW44 ||
param().format == Param::Format::NCHW88 ||
param().format == Param::Format::NCHW32 ||
param().format == Param::Format::NCHW64) {
megdnn_assert(src.ndim == 5_z, "%s", errmsg_c);
pformat == Param::Format::NCHW4 || pformat == Param::Format::NCHW44 ||
pformat == Param::Format::NCHW88 || pformat == Param::Format::NCHW32 ||
pformat == Param::Format::NCHW64) {
megdnn_assert(src.ndim == 5_z, "%s", get_errmsg().c_str());
spatial_pos = 2;
c_pos = 1;
} else if (param().format == Param::Format::CHWN4) {
} else if (pformat == Param::Format::CHWN4) {
spatial_pos = 1;
c_pos = 0;
batch_pos = 3;
} else {
megdnn_assert(
param().format == Param::Format::NHWCD4 && src.ndim == 5_z, "%s",
errmsg_c);
pformat == Param::Format::NHWCD4 && src.ndim == 5_z, "%s",
get_errmsg().c_str());
spatial_pos = 1;
c_pos = 2;
}
@@ -67,31 +72,34 @@ void PoolingBase::deduce_layout_fwd(const TensorLayout& src, TensorLayout& dst)
size_t c = src[c_pos];
size_t ih = src[spatial_pos];
size_t iw = src[spatial_pos + 1];
if (param().format == Param::Format::NHWCD4) {
if (pformat == Param::Format::NHWCD4) {
c *= 4;
iw = src[spatial_pos + 2];
}
if (param().format == Param::Format::NCHW4 ||
param().format == Param::Format::NCHW44 ||
param().format == Param::Format::CHWN4) {
if (pformat == Param::Format::NCHW4 || pformat == Param::Format::NCHW44 ||
pformat == Param::Format::CHWN4) {
c *= 4;
}
if (param().format == Param::Format::NCHW88) {
if (pformat == Param::Format::NCHW88) {
c *= 8;
}
if (param().format == Param::Format::NCHW32) {
if (pformat == Param::Format::NCHW32) {
c *= 32;
}
if (param().format == Param::Format::NCHW64) {
if (pformat == Param::Format::NCHW64) {
c *= 64;
}
size_t oh, ow;
size_t fh = this->param().window_h;
size_t fw = this->param().window_w;
size_t sh = this->param().stride_h;
size_t sw = this->param().stride_w;
size_t ph = this->param().pad_h;
size_t pw = this->param().pad_w;
size_t fh = p.window_h;
size_t fw = p.window_w;
size_t sh = p.stride_h;
size_t sw = p.stride_w;
size_t ph = p.pad_h;
size_t pw = p.pad_w;
// Some asserts previously performed at the Python layer are handled here.
if (ph >= fh || pw >= fw) {
megdnn_log_warn(
"pooling padding size (%zu %zu) should not be bigger than "
@@ -99,26 +107,23 @@ void PoolingBase::deduce_layout_fwd(const TensorLayout& src, TensorLayout& dst)
pw, ph, fw, fh);
}
infer_conv_shape2d(ih, iw, fh, fw, sh, sw, ph, pw, oh, ow);
if (param().format == Param::Format::NCHW) {
if (pformat == Param::Format::NCHW) {
dst = TensorLayout(TensorShape({n, c, oh, ow}), src.dtype);
} else if (param().format == Param::Format::NHWC) {
megdnn_assert(param().format == Param::Format::NHWC, "invalid pooling format");
} else if (pformat == Param::Format::NHWC) {
megdnn_assert(pformat == Param::Format::NHWC, "invalid pooling format");
dst = TensorLayout({n, oh, ow, c}, src.dtype, src.format);
} else if (
param().format == Param::Format::NCHW4 ||
param().format == Param::Format::NCHW44) {
} else if (pformat == Param::Format::NCHW4 || pformat == Param::Format::NCHW44) {
dst = TensorLayout{{n, c / 4, oh, ow, 4}, src.dtype, src.format};
} else if (param().format == Param::Format::NCHW88) {
} else if (pformat == Param::Format::NCHW88) {
dst = TensorLayout{{n, c / 8, oh, ow, 8}, src.dtype, src.format};
} else if (param().format == Param::Format::NCHW32) {
} else if (pformat == Param::Format::NCHW32) {
dst = TensorLayout{{n, c / 32, oh, ow, 32}, src.dtype, src.format};
} else if (param().format == Param::Format::NCHW64) {
} else if (pformat == Param::Format::NCHW64) {
dst = TensorLayout{{n, c / 64, oh, ow, 64}, src.dtype, src.format};
} else if (param().format == Param::Format::CHWN4) {
} else if (pformat == Param::Format::CHWN4) {
dst = TensorLayout{{c / 4, oh, ow, n, 4}, src.dtype, src.format};
} else {
megdnn_assert(
param().format == Param::Format::NHWCD4, "invalid pooling format");
megdnn_assert(pformat == Param::Format::NHWCD4, "invalid pooling format");
dst = TensorLayout{{n, oh, c / 4, ow, 4}, src.dtype, src.format};
}
}
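The hunks above replace the eagerly built error string with a get_errmsg lambda, so the formatting cost (measured at roughly 18x the rest of deduce_layout_fwd, per the comment) is paid only when an assert actually fires. Below is a minimal standalone sketch of that deferral pattern; check_rank and its message are illustrative stand-ins, not taken from the codebase:

#include <cstddef>
#include <cstdio>
#include <cstdlib>
#include <string>

// Eagerly concatenating the message would allocate on every call; the
// lambda defers all string work to the (rare) failure path.
static void check_rank(std::size_t ndim, std::size_t expected) {
    auto get_errmsg = [&]() -> std::string {
        return "ndim=" + std::to_string(ndim) +
               ", expected=" + std::to_string(expected);
    };
    if (ndim != expected) {
        std::fprintf(stderr, "rank mismatch: %s\n", get_errmsg().c_str());
        std::abort();
    }
    // Hot path: no std::string was ever constructed.
}

int main() {
    check_rank(4, 4);  // passes; the error message is never built
    return 0;
}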
/**
* \file imperative/src/impl/ops/pooling.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#include "megbrain/opr/dnn/pooling.h"
#include "megbrain/imperative/ops/autogen.h"
#include "megbrain/opr/utility.h"
#include "megbrain/opr/internal/megdnn_opr_wrapper.h"
#include "../algo_chooser.h"
#include "../blob_manager_impl.h"
#include "../dnn_op_helper.h"
#include "../op_trait.h"
namespace mgb::imperative {
namespace {
namespace pooling {
// using OprHandle = opr::intl::UniqPtrWithCN<megdnn::Pooling>;
// static ThinHashMap<CompNode, OprHandle> dnn_oprs;
auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
auto&& pool = static_cast<const Pooling&>(def);
OperatorNodeConfig config{pool.make_name()};
return opr::Pooling::make(inputs[0], pool.param(), pool.policy(), config);
}
std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(
const OpDef& def, const SmallVector<LogicalTensorDesc>& inputs) {
mgb_assert(
inputs.size() == 1, "num of inputs of pooling should be 1 but got %zu",
inputs.size());
auto&& op_def = def.cast_final_safe<Pooling>();
auto&& inp = inputs[0];
auto& inp_cn = inp.comp_node;
if (inp.layout.ndim == 0) {
return {{{TensorLayout{inp.layout.dtype}, inp_cn, {}}}, false};
}
DnnOprCaller<megdnn::Pooling> caller(inp_cn);
auto&& dnn_opr = caller.op;
dnn_opr->param() = op_def.param();
TensorLayout oup_layout;
dnn_opr->deduce_layout(inp.layout, oup_layout);
return {{{oup_layout, inp_cn, {}}}, true};
}
SmallVector<TensorPtr> apply_on_physical_tensor(
const OpDef& def, const SmallVector<TensorPtr>& inputs,
SmallVector<LogicalTensorDesc>& output_descs, const bool& validated) {
mgb_assert(
inputs.size() == 1, "num of inputs of pooling should be 1 but got %zu",
inputs.size());
auto&& op_def = def.cast_final_safe<Pooling>();
auto cn = inputs[0]->comp_node();
megdnn::TensorND inp_tensornd = inputs[0]->dnn_tensor();
DnnOprCaller<megdnn::Pooling> caller(cn);
auto&& dnn_opr = caller.op;
dnn_opr->param() = op_def.param();
TensorLayout& oup_layout = output_descs[0].layout;
if (!validated) {
dnn_opr->deduce_layout(inp_tensornd.layout, oup_layout);
}
DeviceTensorND out_devtensor =
BlobManager::inst()->alloc_workspace_with_defrag(cn, oup_layout);
size_t wk_size = setup_algo<megdnn::Pooling>(
{inp_tensornd.layout, oup_layout}, dnn_opr.get(), 0, false, false, cn,
op_def.policy(), false);
megdnn::Workspace dnn_wk;
if (wk_size != 0) {
auto wk = Blob::make(cn, wk_size);
dnn_wk.raw_ptr = wk->storage().get();
dnn_wk.size = wk_size;
}
dnn_opr->exec(inp_tensornd, out_devtensor.as_megdnn(), dnn_wk);
return {Tensor::make(out_devtensor)};
}
OP_TRAIT_REG(Pooling, Pooling)
.apply_on_var_node(apply_on_var_node)
.infer_output_attrs_fallible(infer_output_attrs_fallible)
.apply_on_physical_tensor(apply_on_physical_tensor)
.fallback();
} // namespace pooling
} // namespace
} // namespace mgb::imperative
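Both imperative paths above call deduce_layout to obtain the output layout before allocating it. For intuition, here is a standalone sketch of the spatial arithmetic infer_conv_shape2d is expected to perform for the {n, c, oh, ow} shapes built in deduce_layout_fwd; the floor-mode formula is an assumption for illustration, not taken from this commit:

#include <cstddef>
#include <cstdio>

// Assumed floor-mode rule: out = (in + 2*pad - window) / stride + 1.
static std::size_t pooled_extent(
        std::size_t in, std::size_t window, std::size_t stride, std::size_t pad) {
    return (in + 2 * pad - window) / stride + 1;  // integer division == floor
}

int main() {
    // 32x32 input, 2x2 window, stride 2, no padding -> 16x16 output.
    std::printf("oh=%zu ow=%zu\n",
                pooled_extent(32, 2, 2, 0), pooled_extent(32, 2, 2, 0));
    return 0;
}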
@@ -333,17 +333,6 @@ OP_TRAIT_REG(BatchConvBias, BatchConvBias)
} // namespace batch_conv_bias
} // namespace
namespace {
namespace pooling {
auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
auto&& pool = static_cast<const Pooling&>(def);
OperatorNodeConfig config{pool.make_name()};
return opr::Pooling::make(inputs[0], pool.param(), pool.policy(), config);
}
OP_TRAIT_REG(Pooling, Pooling).apply_on_var_node(apply_on_var_node).fallback();
} // namespace pooling
} // namespace
namespace {
namespace matrix_mul {
auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {