perf(imperative): speed up conv_transpose3d

GitOrigin-RevId: e741305446e926086c36affcb54d77f739133bbe
上级 3a5347ed
......@@ -784,6 +784,10 @@ public:
protected:
//! Deduce the forward output layout \p dst from \p src using this operator's
//! own param(); the definition forwards to the static deduce_layout_impl().
void deduce_layout_fwd(const TensorLayout& src, TensorLayout& dst);
//! Assert that \p dst equals the layout deduced from \p src and that both
//! layouts carry the same, accepted dtype.
void check_layout_fwd(const TensorLayout& src, const TensorLayout& dst);
public:
/*!
 * \brief Deduce the pooling output layout without an operator instance.
 *
 * Static so that callers holding only a Param can compute the layout
 * directly; deduce_layout_fwd() forwards here with this operator's param().
 *
 * \param[in] src layout of the input tensor
 * \param[in] param pooling parameters (format, window, stride, padding, mode)
 * \param[out] dst receives the deduced output layout
 */
MGE_WIN_DECLSPEC_FUC static void deduce_layout_impl(
const TensorLayout& src, const Param& param, TensorLayout& dst);
};
class PoolingForward : public PoolingBase,
......@@ -1241,6 +1245,8 @@ protected:
const TensorLayout& src, const TensorLayout& filter,
const TensorLayout& dst) const;
/*!
 * \brief Build the canonized filter meta from an explicitly supplied param.
 *
 * Static counterpart of make_canonized_filter_meta(); it reads format, mode,
 * sparse type, stride, padding and dilation from \p param instead of from an
 * operator instance.
 *
 * \param src_ndim number of dimensions of the source tensor
 * \param filter layout of the filter tensor (asserted to be contiguous)
 * \param param convolution parameters to canonize against
 */
static CanonizedFilterMeta make_canonized_filter_meta_impl(
size_t src_ndim, const TensorLayout& filter, const Param& param);
//! Instance-level convenience wrapper: forwards to
//! make_canonized_filter_meta_impl() with this operator's param().
CanonizedFilterMeta make_canonized_filter_meta(
size_t src_ndim, const TensorLayout& filter) const;
};
......@@ -1286,6 +1292,10 @@ public:
* \param[in] diff (n, oc, od, oh, ow)
* \param[out] grad (n, ic, id, ih, iw)
*/
// Static layout deduction: computes the layout of `grad` from the `filter`
// and `diff` layouts and an explicitly supplied `param`, so no operator
// instance is required. The member deduce_layout() forwards here with param().
MGE_WIN_DECLSPEC_FUC static void deduce_layout_impl(
const TensorLayout& filter, const TensorLayout& diff, const Param& param,
TensorLayout& grad);
/**
 * Pure-virtual execution entry point: takes \p filter and \p diff as inputs
 * and \p grad as the output tensor, together with a caller-provided
 * \p workspace.
 *
 * NOTE(review): the tensor shapes are presumably the ones listed in the
 * \param block that sits above deduce_layout_impl() -- confirm.
 */
virtual void exec(
_megdnn_tensor_in filter, _megdnn_tensor_in diff, _megdnn_tensor_out grad,
_megdnn_workspace workspace) = 0;
......
......@@ -38,17 +38,18 @@ std::string get_errmsg(
}
} // namespace
Convolution3DBase::CanonizedFilterMeta Convolution3DBase::make_canonized_filter_meta(
size_t src_ndim, const TensorLayout& filter) const {
Convolution3DBase::CanonizedFilterMeta Convolution3DBase::
make_canonized_filter_meta_impl(
size_t src_ndim, const TensorLayout& filter, const Param& param) {
megdnn_assert_contiguous(filter);
auto img_ndim = src_ndim - 2;
CanonizedFilterMeta ret;
ret.dtype_enum = filter.dtype.enumv();
ret.format = param().format;
if (param().mode == Mode::CONVOLUTION) {
ret.format = param.format;
if (param.mode == Mode::CONVOLUTION) {
ret.should_flip = true;
} else {
megdnn_assert(param().mode == Mode::CROSS_CORRELATION, "invalid conv mode");
megdnn_assert(param.mode == Mode::CROSS_CORRELATION, "invalid conv mode");
ret.should_flip = false;
}
size_t flt_start, flt_spatial_start, ocpg_pos, icpg_pos;
......@@ -56,7 +57,7 @@ Convolution3DBase::CanonizedFilterMeta Convolution3DBase::make_canonized_filter_
MEGDNN_MARK_USED_VAR(ocpg_pos);
MEGDNN_MARK_USED_VAR(icpg_pos);
if (param().sparse == Param::Sparse::DENSE) {
if (param.sparse == Param::Sparse::DENSE) {
megdnn_assert(
filter.ndim == img_ndim + 2,
"bad filter ndim for dense convolution: "
......@@ -66,7 +67,7 @@ Convolution3DBase::CanonizedFilterMeta Convolution3DBase::make_canonized_filter_
flt_start = 0;
} else {
megdnn_assert(
param().sparse == Param::Sparse::GROUP,
param.sparse == Param::Sparse::GROUP,
"invalid convolution sparse type");
megdnn_assert(
filter.ndim == img_ndim + 3,
......@@ -77,14 +78,14 @@ Convolution3DBase::CanonizedFilterMeta Convolution3DBase::make_canonized_filter_
flt_start = 1;
}
if (param().format == Param::Format::NCDHW) {
if (param.format == Param::Format::NCDHW) {
// filter should be (oc, ic, fd, fh, fw)
flt_spatial_start = 2;
ocpg_pos = 0;
icpg_pos = 1;
} else {
megdnn_assert(
param().format == Param::Format::NDHWC, "invalid conv tensor format");
param.format == Param::Format::NDHWC, "invalid conv tensor format");
// filter should be (oc, fd, fh, fw, ic)
flt_spatial_start = 1;
ocpg_pos = 0;
......@@ -96,15 +97,15 @@ Convolution3DBase::CanonizedFilterMeta Convolution3DBase::make_canonized_filter_
"only 3D convolution is supported, and input should be 5-dim; "
"got input dim = %zu",
src_ndim);
ret.stride[0] = this->param().stride_d;
ret.stride[1] = this->param().stride_h;
ret.stride[2] = this->param().stride_w;
ret.padding[0] = this->param().pad_d;
ret.padding[1] = this->param().pad_h;
ret.padding[2] = this->param().pad_w;
ret.dilation[0] = param().dilate_d;
ret.dilation[1] = param().dilate_h;
ret.dilation[2] = param().dilate_w;
ret.stride[0] = param.stride_d;
ret.stride[1] = param.stride_h;
ret.stride[2] = param.stride_w;
ret.padding[0] = param.pad_d;
ret.padding[1] = param.pad_h;
ret.padding[2] = param.pad_w;
ret.dilation[0] = param.dilate_d;
ret.dilation[1] = param.dilate_h;
ret.dilation[2] = param.dilate_w;
ret.ocpg = filter[flt_start + ocpg_pos];
ret.icpg = filter[flt_start + icpg_pos];
for (size_t i = 0; i < ret.spatial_ndim; ++i) {
......@@ -117,6 +118,11 @@ Convolution3DBase::CanonizedFilterMeta Convolution3DBase::make_canonized_filter_
return ret;
}
/*!
 * \brief Build the canonized filter meta for \p filter with this operator's
 * own parameters.
 *
 * Thin instance-level wrapper; all of the work is done by the static
 * make_canonized_filter_meta_impl(), which receives the param explicitly.
 */
Convolution3DBase::CanonizedFilterMeta Convolution3DBase::make_canonized_filter_meta(
        size_t src_ndim, const TensorLayout& filter) const {
    const auto& own_param = param();
    return make_canonized_filter_meta_impl(src_ndim, filter, own_param);
}
Convolution3DBase::CanonizedFilterMeta Convolution3DBase::deduce_layout_fwd(
const TensorLayout& src, const TensorLayout& filter, TensorLayout& dst) const {
auto errmsg = [&]() { return get_errmsg(src, filter, dst, param()); };
......@@ -213,12 +219,13 @@ Convolution3DBase::CanonizedFilterMeta Convolution3DBackwardData::check_exec(
return ret;
}
void Convolution3DBackwardData::deduce_layout(
const TensorLayout& filter, const TensorLayout& diff, TensorLayout& grad) {
void Convolution3DBackwardData::deduce_layout_impl(
const TensorLayout& filter, const TensorLayout& diff, const Param& param,
TensorLayout& grad) {
megdnn_assert(
param().data_type == Param::DataType::FLOAT,
param.data_type == Param::DataType::FLOAT,
"only float type is supported for conv backward");
auto errmsg = [&]() { return get_errmsg(filter, diff, grad, param()); };
auto errmsg = [&]() { return get_errmsg(filter, diff, grad, param); };
MEGDNN_MARK_USED_VAR(errmsg);
megdnn_assert_contiguous(filter);
megdnn_assert_contiguous(diff);
......@@ -226,7 +233,7 @@ void Convolution3DBackwardData::deduce_layout(
megdnn_assert(diff.ndim == 5_z, "%s", errmsg().c_str());
megdnn_assert(filter.dtype == diff.dtype, "%s", errmsg().c_str());
auto cflt = make_canonized_filter_meta(diff.ndim, filter);
auto cflt = make_canonized_filter_meta_impl(diff.ndim, filter, param);
megdnn_assert(cflt.ocpg * cflt.group == diff[1], "%s", errmsg().c_str());
auto deduce = [&errmsg](size_t out, size_t filter, size_t stride, size_t pad) {
......@@ -247,6 +254,11 @@ void Convolution3DBackwardData::deduce_layout(
grad.init_contiguous_stride();
}
/*!
 * \brief Deduce the layout of \p grad from \p filter and \p diff using this
 * operator's own parameters.
 *
 * Delegates to the static deduce_layout_impl(), passing param() explicitly.
 */
void Convolution3DBackwardData::deduce_layout(
        const TensorLayout& filter, const TensorLayout& diff, TensorLayout& grad) {
    const auto& own_param = param();
    deduce_layout_impl(filter, diff, own_param, grad);
}
Convolution3DBase::CanonizedFilterMeta Convolution3DBackwardFilter::check_exec(
const TensorLayout& src, const TensorLayout& diff, const TensorLayout& grad,
size_t workspace_in_bytes) {
......
......@@ -15,22 +15,22 @@
namespace megdnn {
void PoolingBase::deduce_layout_fwd(const TensorLayout& src, TensorLayout& dst) {
auto& p = param();
auto pformat = p.format;
void PoolingBase::deduce_layout_impl(
const TensorLayout& src, const Param& param, TensorLayout& dst) {
auto pformat = param.format;
// the overhead of generating error message is about 18x of the other part of this
// function so we use a function to wrap the error message and get it only when need.
auto get_errmsg = [&](void) -> std::string {
std::string errmsg =
megdnn_layout_msg(src) + ", " + megdnn_layout_msg(dst) + ", " +
"pad_h=" + std::to_string(param().pad_h) + ", " +
"pad_w=" + std::to_string(param().pad_w) + ", " +
"stride_h=" + std::to_string(param().stride_h) + ", " +
"stride_w=" + std::to_string(param().stride_w) + ", " +
"window_h=" + std::to_string(param().window_h) + ", " +
"window_w=" + std::to_string(param().window_w) + ", " +
"is_max=" + std::to_string(param().mode == Mode::MAX) + ", " +
"pad_h=" + std::to_string(param.pad_h) + ", " +
"pad_w=" + std::to_string(param.pad_w) + ", " +
"stride_h=" + std::to_string(param.stride_h) + ", " +
"stride_w=" + std::to_string(param.stride_w) + ", " +
"window_h=" + std::to_string(param.window_h) + ", " +
"window_w=" + std::to_string(param.window_w) + ", " +
"is_max=" + std::to_string(param.mode == Mode::MAX) + ", " +
"is_nhwc=" + std::to_string(pformat == Param::Format::NHWC) + ", " +
"is_nhwcd4=" + std::to_string(pformat == Param::Format::NHWCD4);
return errmsg;
......@@ -90,12 +90,12 @@ void PoolingBase::deduce_layout_fwd(const TensorLayout& src, TensorLayout& dst)
c *= 64;
}
size_t oh, ow;
size_t fh = p.window_h;
size_t fw = p.window_w;
size_t sh = p.stride_h;
size_t sw = p.stride_w;
size_t ph = p.pad_h;
size_t pw = p.pad_w;
size_t fh = param.window_h;
size_t fw = param.window_w;
size_t sh = param.stride_h;
size_t sw = param.stride_w;
size_t ph = param.pad_h;
size_t pw = param.pad_w;
// moving some python assert to here
// megdnn_assert()
......@@ -128,12 +128,15 @@ void PoolingBase::deduce_layout_fwd(const TensorLayout& src, TensorLayout& dst)
}
}
/*!
 * \brief Deduce the forward output layout \p dst from \p src using this
 * operator's own parameters.
 *
 * Delegates to the static deduce_layout_impl(), passing param() explicitly.
 */
void PoolingBase::deduce_layout_fwd(const TensorLayout& src, TensorLayout& dst) {
    const auto& own_param = param();
    deduce_layout_impl(src, own_param, dst);
}
void PoolingBase::check_layout_fwd(const TensorLayout& src, const TensorLayout& dst) {
TensorLayout dst_expected;
megdnn_assert_eq_dtype(src, dst);
deduce_layout_fwd(src, dst_expected);
megdnn_assert_eq_layout(dst_expected, dst);
megdnn_assert(src.dtype == dst.dtype);
megdnn_assert(
src.dtype.category() == DTypeCategory::FLOAT ||
src.dtype == dtype::Int8() ||
......
......@@ -93,12 +93,17 @@ __all__ = [
def expand_hw(x):
    """Expand ``x`` into an ``(h, w)`` pair.

    A plain ``int`` is duplicated as-is, a ``Sequence`` contributes its first
    two items (each converted with ``int``), and any other value is converted
    with ``int`` and duplicated.
    """
    # Test for int first: it is about 5 times faster than testing for Sequence.
    if isinstance(x, int):
        return x, x
    if not isinstance(x, Sequence):
        return int(x), int(x)
    return tuple(int(x[axis]) for axis in range(2))
def expand_dhw(x):
    """Expand ``x`` into a ``(d, h, w)`` triple.

    A plain ``int`` is repeated as-is, a ``Sequence`` contributes its first
    three items (each converted with ``int``), and any other value is
    converted with ``int`` and repeated.
    """
    # The int test comes first so the common scalar case returns immediately.
    if isinstance(x, int):
        return x, x, x
    if not isinstance(x, Sequence):
        return int(x), int(x), int(x)
    return tuple(int(x[axis]) for axis in range(3))
......
#pragma once
#include "megbrain/rdnn/algo_chooser.h"
#include "megdnn/heuristic_cache.h"
......
......@@ -8,6 +8,7 @@
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#pragma once
#include "megbrain/comp_node.h"
#include "megbrain/comp_node_env.h"
......
......@@ -579,6 +579,63 @@ OP_TRAIT_REG(Convolution3D, Convolution3D, opr::Convolution3D)
namespace {
namespace convolution3d_backward_data {
/*!
 * \brief Infer the output descriptor of conv_transpose3d
 * (Convolution3DBackwardData) without executing it.
 *
 * \param def op definition; must be a Convolution3DBackwardData
 * \param inputs exactly two descriptors: inputs[0] is the weight (filter),
 *        inputs[1] is the diff
 * \return the single output descriptor and a flag telling whether its layout
 *         was actually deduced (true) or only the dtype/comp node is known
 *         (false)
 */
std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(
        const OpDef& def, const SmallVector<LogicalTensorDesc>& inputs) {
    mgb_assert(
            inputs.size() == 2,
            "inputs num of conv_transpose3d should be 2 but you give %zu",
            inputs.size());

    auto&& op_def = def.cast_final_safe<Convolution3DBackwardData>();
    auto&& weight = inputs[0];
    auto&& diff = inputs[1];
    auto& cn = weight.comp_node;

    // A layout with ndim == 0 is treated as "shape not available yet".
    // deduce_layout_impl() asserts diff.ndim == 5 and inspects the filter
    // dims, so it must not be called unless BOTH layouts are populated;
    // otherwise report a dtype-only descriptor and let the caller retry later.
    if (weight.layout.ndim == 0 || diff.layout.ndim == 0) {
        return {{{TensorLayout{weight.layout.dtype}, cn, {}}}, false};
    }

    TensorLayout oup_layout;
    megdnn::Convolution3DBackwardData::deduce_layout_impl(
            weight.layout, diff.layout, op_def.param(), oup_layout);
    return {{{oup_layout, cn, {}}}, true};
}
/*!
 * \brief Execute conv_transpose3d (Convolution3DBackwardData) on concrete
 * tensors.
 *
 * \param def op definition; must be a Convolution3DBackwardData
 * \param inputs inputs[0] is the weight (filter), inputs[1] is the diff
 * \param output_descs output_descs[0].layout is used as the output layout and
 *        is recomputed in place when \p validated is false
 * \param validated whether output_descs already carries a deduced layout
 * \return a single tensor holding the computed result
 */
SmallVector<TensorPtr> apply_on_physical_tensor(
        const OpDef& def, const SmallVector<TensorPtr>& inputs,
        SmallVector<LogicalTensorDesc>& output_descs, const bool& validated) {
    auto&& op_def = def.cast_final_safe<Convolution3DBackwardData>();
    auto cn = inputs[0]->comp_node();

    megdnn::TensorND weight = inputs[0]->dnn_tensor();
    megdnn::TensorND diff = inputs[1]->dnn_tensor();

    DnnOprCaller<megdnn::Convolution3DBackwardData> caller(cn);
    auto&& dnn_opr = caller.op;
    dnn_opr->param() = op_def.param();

    TensorLayout& oup_layout = output_descs[0].layout;
    if (!validated) {
        // The layout was not deduced during inference; compute it now.
        megdnn::Convolution3DBackwardData::deduce_layout_impl(
                weight.layout, diff.layout, op_def.param(), oup_layout);
    }
    DeviceTensorND oup =
            BlobManager::inst()->alloc_workspace_with_defrag(cn, oup_layout);

    size_t wk_size = setup_algo<megdnn::Convolution3DBackwardData>(
            {weight.layout, diff.layout, oup_layout}, dnn_opr.get(), 0, false, false,
            cn, op_def.policy(), false);

    megdnn::Workspace dnn_wk;
    // The handle that owns the workspace memory lives at function scope so
    // that it is still alive when exec() dereferences dnn_wk.raw_ptr below.
    // Scoping it to the `if` body would release it before the kernel runs.
    decltype(Blob::make(cn, wk_size)) wk_holder{};
    if (wk_size != 0) {
        wk_holder = Blob::make(cn, wk_size);
        dnn_wk.raw_ptr = wk_holder->storage().get();
        dnn_wk.size = wk_size;
    }

    dnn_opr->exec(weight, diff, oup.as_megdnn(), dnn_wk);
    return {Tensor::make(oup)};
}
auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
auto&& conv = static_cast<const Convolution3DBackwardData&>(def);
OperatorNodeConfig config{conv.make_name()};
......@@ -589,6 +646,8 @@ auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
// Register the op trait callbacks for Convolution3DBackwardData using the
// functions defined in this namespace: apply_on_var_node,
// infer_output_attrs_fallible and apply_on_physical_tensor.
// NOTE(review): .fallback() presumably supplies defaults for the trait
// entries that are not registered explicitly here -- confirm.
OP_TRAIT_REG(Convolution3DBackwardData, Convolution3DBackwardData)
.apply_on_var_node(apply_on_var_node)
.infer_output_attrs_fallible(infer_output_attrs_fallible)
.apply_on_physical_tensor(apply_on_physical_tensor)
.fallback();
} // namespace convolution3d_backward_data
} // namespace
......
......@@ -11,6 +11,7 @@
#include "megbrain/opr/dnn/pooling.h"
#include "megbrain/imperative/ops/autogen.h"
#include "megbrain/imperative/utils/stats.h"
#include "megbrain/opr/utility.h"
#include "megbrain/opr/internal/megdnn_opr_wrapper.h"
......@@ -25,9 +26,6 @@ namespace mgb::imperative {
namespace {
namespace pooling {
// using OprHandle = opr::intl::UniqPtrWithCN<megdnn::Pooling>;
// static ThinHashMap<CompNode, OprHandle> dnn_oprs;
auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
auto&& pool = static_cast<const Pooling&>(def);
OperatorNodeConfig config{pool.make_name()};
......@@ -48,11 +46,9 @@ std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(
return {{{TensorLayout{inp.layout.dtype}, inp_cn, {}}}, false};
}
DnnOprCaller<megdnn::Pooling> caller(inp_cn);
auto&& dnn_opr = caller.op;
dnn_opr->param() = op_def.param();
TensorLayout oup_layout;
dnn_opr->deduce_layout(inp.layout, oup_layout);
megdnn::Pooling::deduce_layout_impl(inp.layout, op_def.param(), oup_layout);
return {{{oup_layout, inp_cn, {}}}, true};
}
......@@ -73,7 +69,8 @@ SmallVector<TensorPtr> apply_on_physical_tensor(
TensorLayout& oup_layout = output_descs[0].layout;
if (!validated) {
dnn_opr->deduce_layout(inp_tensornd.layout, oup_layout);
megdnn::Pooling::deduce_layout_impl(
inp_tensornd.layout, op_def.param(), oup_layout);
}
DeviceTensorND out_devtensor =
BlobManager::inst()->alloc_workspace_with_defrag(cn, oup_layout);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册
新手
引导
客服 返回
顶部