提交 26ea33c6 编写于 作者: M Megvii Engine Team

perf(imperative): improve convbwd performance

GitOrigin-RevId: cfc8623d7a818c84ee8d6ed56e0dea6251ef36d6
上级 3949d425
...@@ -310,7 +310,7 @@ public: ...@@ -310,7 +310,7 @@ public:
const TensorLayout& filter, const TensorLayout& diff, const TensorLayout& filter, const TensorLayout& diff,
const TensorLayout& grad) = 0; const TensorLayout& grad) = 0;
void deduce_dtype(DType filter, DType diff, DType& grad); MGE_WIN_DECLSPEC_FUC void deduce_dtype(DType filter, DType diff, DType& grad);
void deduce_layout( void deduce_layout(
const TensorLayout& filter, const TensorLayout& diff, TensorLayout& grad); const TensorLayout& filter, const TensorLayout& diff, TensorLayout& grad);
......
...@@ -250,8 +250,165 @@ auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) { ...@@ -250,8 +250,165 @@ auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
} }
} }
TensorLayout convbwd_do_shape_infer(
const OpDef& def, size_t diff_ndim, TensorLayout filter, TensorLayout diff,
CompNode cn) {
auto&& bwd_conv = static_cast<const ConvolutionBackwardData&>(def);
DnnOprCaller<megdnn::ConvolutionBackwardData> caller(cn);
auto&& dnn_opr = caller.op;
using Param = ::megdnn::param::Convolution;
// using Param1 = ::megdnn::param::ConvolutionBackwardData;
auto img_ndim = diff_ndim - 2;
mgb_assert(
img_ndim == 2,
"only 2D convolution is supported, and input should be 4-dim; "
"got input dim = %zu",
diff_ndim);
size_t group = 1;
size_t flt_start, flt_spatial_start, ocpg_pos, icpg_pos;
if (bwd_conv.sparse == Param::Sparse::DENSE) {
mgb_assert(
filter.ndim == img_ndim + 2 || filter.ndim == img_ndim + 4,
"bad filter ndim for dense convolution: "
"spatial_ndim=%zu filter_ndim=%zu",
img_ndim, filter.ndim);
group = 1;
flt_start = 0;
} else { // Param::Sparse::GROUP
mgb_assert(
filter.ndim == img_ndim + 3 || filter.ndim == img_ndim + 5,
"bad filter ndim for group convolution: "
"spatial_ndim=%zu filter_ndim=%zu",
img_ndim, filter.ndim);
// grp, oc, ic, dims[]
group = filter[0];
flt_start = 1;
}
uint32_t ic_block_size = 1, oc_block_size = 1;
size_t src_or_dst_c_pos = 0;
size_t src_or_dst_spatial_start = 0;
if (bwd_conv.format == Param::Format::NCHW) {
// filter should be (oc, ic, fh, fw)
flt_spatial_start = 2;
ocpg_pos = 0;
icpg_pos = 1;
src_or_dst_c_pos = 1;
src_or_dst_spatial_start = 2;
} else { // Param::Format::NHWC
// filter should be (oc, fh, fw, ic)
flt_spatial_start = 1;
ocpg_pos = 0;
icpg_pos = 3;
src_or_dst_c_pos = 3;
src_or_dst_spatial_start = 1;
}
size_t ocpg = filter[flt_start + ocpg_pos] * oc_block_size;
size_t icpg = filter[flt_start + icpg_pos] * ic_block_size;
uint32_t dilation[2], dilated_spatial[2], stride[2], padding[2];
dilation[0] = bwd_conv.dilate_h;
dilation[1] = bwd_conv.dilate_w;
stride[0] = bwd_conv.stride_h;
stride[1] = bwd_conv.stride_w;
padding[0] = bwd_conv.pad_h;
padding[1] = bwd_conv.pad_w;
for (size_t i = 0; i < img_ndim; ++i) {
mgb_assert(
dilation[i] > 0, "invalid dilation on spatial dim %zu: %u", i,
dilation[i]);
dilated_spatial[i] =
(filter[i + flt_start + flt_spatial_start] - 1) * dilation[i] + 1;
}
mgb_assert(ocpg * group == diff[src_or_dst_c_pos], "group conv invalid");
auto deduce = [](size_t out, size_t filter, size_t stride, size_t pad) {
auto i = (out - 1) * stride + filter;
mgb_assert(i > pad * 2);
return i - pad * 2;
};
DType dst_dtype = bwd_conv.dtype;
dnn_opr->deduce_dtype(filter.dtype, diff.dtype, dst_dtype);
TensorLayout dst{dst_dtype};
dst.ndim = diff_ndim;
dst[0] = diff[0];
dst[src_or_dst_c_pos] = icpg * group;
for (size_t i = 0; i < img_ndim; ++i) {
dst[i + src_or_dst_spatial_start] =
deduce(diff[i + src_or_dst_spatial_start], dilated_spatial[i],
stride[i], padding[i]);
}
dst.init_contiguous_stride();
return dst;
}
std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(
const OpDef& def, const SmallVector<LogicalTensorDesc>& inputs) {
auto&& conv = static_cast<const ConvolutionBackwardData&>(def);
SmallVector<LogicalTensorDesc> dests(1);
auto&& desc = dests[0];
desc.comp_node = inputs[0].comp_node;
TensorLayout filter = inputs[0].layout;
TensorLayout diff = inputs[1].layout;
size_t filter_ndim = filter.ndim;
size_t diff_ndim = diff.ndim;
if (filter_ndim == 0) {
desc.layout = filter;
return {dests, false};
}
desc.layout =
convbwd_do_shape_infer(def, diff_ndim, filter, diff, inputs[0].comp_node);
return {dests, true};
}
SmallVector<TensorPtr> apply_on_physical_tensor(
const OpDef& def, const SmallVector<TensorPtr>& inputs,
SmallVector<LogicalTensorDesc>& output_descs, const bool& validated) {
// create megdnn opr
auto&& convbwd = static_cast<const ConvolutionBackwardData&>(def);
CompNode cn = inputs[0]->comp_node();
TensorLayout out_layout = output_descs[0].layout;
if (!validated)
out_layout = convbwd_do_shape_infer(
def, inputs[1]->layout().ndim, inputs[0]->layout(), inputs[1]->layout(),
cn);
DeviceTensorND out =
BlobManager::inst()->alloc_workspace_with_defrag(cn, out_layout);
using TensorND = megdnn::TensorND;
SmallVector<TensorND> inp_tensornds(inputs.size());
TensorLayoutArray inp_shapes(inputs.size()), oup_shapes(output_descs.size());
for (unsigned i = 0; i < inputs.size(); ++i) {
inp_tensornds[i] = inputs[i]->dnn_tensor();
inp_shapes[i] = inputs[i]->layout();
}
oup_shapes[0] = out_layout;
DnnOprCaller<megdnn::ConvolutionBackwardData> dnn_opr(cn);
dnn_opr.op->param() = convbwd.param();
size_t sz = setup_algo<megdnn::ConvolutionBackwardData>(
{inp_shapes[0], inp_shapes[1], oup_shapes[0]}, dnn_opr.op.get(), 0, false,
false, cn, convbwd.policy(), false);
auto wk = Blob::make(cn, sz);
auto ptr = wk->storage().get();
megdnn::Workspace dnn_wk(ptr, sz);
// exeucte
dnn_opr.op->exec(inp_tensornds[0], inp_tensornds[1], out.as_megdnn(), dnn_wk);
return {Tensor::make(out)};
}
OP_TRAIT_REG(ConvolutionBackwardData, ConvolutionBackwardData) OP_TRAIT_REG(ConvolutionBackwardData, ConvolutionBackwardData)
.apply_on_var_node(apply_on_var_node) .apply_on_var_node(apply_on_var_node)
.infer_output_attrs_fallible(infer_output_attrs_fallible)
.apply_on_physical_tensor(apply_on_physical_tensor)
.fallback(); .fallback();
} // namespace convolution_backward_data } // namespace convolution_backward_data
} // namespace } // namespace
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册