提交 d7cc4628 编写于 作者: M Megvii Engine Team

perf(gopt): opt concat for OpenCL

GitOrigin-RevId: 9bb226d4b122bacaa9d7c1d69130bbc20eaed95e
上级 3f0bb47a
...@@ -1589,6 +1589,67 @@ std::unique_ptr<ConvertFormatPass> ConvertFormatPass::make_nhwcd4_converter() { ...@@ -1589,6 +1589,67 @@ std::unique_ptr<ConvertFormatPass> ConvertFormatPass::make_nhwcd4_converter() {
return new_opr; return new_opr;
}; };
auto replace_concat_opr = [&relayout_inp_to_chw](
OperatorNodeBase* opr,
const VarNodeArray& new_inp) {
//! map nchw axis to CD4 axis(n h c/4 w 4)
auto axis_nchw_to_cd4_map = [=](int32_t org_axis) -> int32_t {
mgb_assert(org_axis >= 0 && org_axis <= 3);
int32_t ret = 0;
if (0 == org_axis) {
ret = 0;
} else if (1 == org_axis) {
ret = 2;
} else if (2 == org_axis) {
ret = 1;
} else if (3 == org_axis) {
mgb_throw(InternalError,
"Do not support axis=3 for concat bypass for CD4!");
} else {
mgb_throw(InternalError,
"Do not support axis for concat pass, may input is "
"not NCHW format!");
}
return ret;
};
mgb_assert(opr->input().size() == new_inp.size());
auto nchw_axis = opr->cast_final_safe<opr::Concat>().param().axis;
if (nchw_axis < 0 || nchw_axis > 3) {
mgb_log_warn("concat pass fallback to relayout chw\n");
return relayout_inp_to_chw(opr, new_inp);
}
bool can_exec_cd4 = true;
//! only consider OpenCL CD4, if other backend has relayout performance
//! issue, may add other bypass format
for (size_t i = 0; i < opr->input().size(); i++) {
if (opr->input(i)->format().type() != TensorFormat::Type::DEFAULT ||
opr->input(i)->shape()[1] % 4 != 0 ||
new_inp[i]->shape().ndim != 5 ||
new_inp[i]->format().type() !=
TensorFormat::Type::IMAGE2D_PACK4 ||
nchw_axis == 3) {
can_exec_cd4 = false;
break;
}
}
if (!can_exec_cd4) {
mgb_log_warn("concat pass fallback to relayout chw");
return relayout_inp_to_chw(opr, new_inp);
}
megdnn::param::Axis param;
//! now only support nchw bypass to CD4
mgb_log_warn("concat pass bypass to CD4");
param.axis = axis_nchw_to_cd4_map(nchw_axis);
return opr::Concat::make(VarNodeArrayView(new_inp), param,
opr->config())
.node()
->owner_opr();
};
auto replace_elemwise_opr = [&relayout_inp_to_chw]( auto replace_elemwise_opr = [&relayout_inp_to_chw](
OperatorNodeBase* opr, OperatorNodeBase* opr,
const VarNodeArray& new_inp) { const VarNodeArray& new_inp) {
...@@ -1654,7 +1715,7 @@ std::unique_ptr<ConvertFormatPass> ConvertFormatPass::make_nhwcd4_converter() { ...@@ -1654,7 +1715,7 @@ std::unique_ptr<ConvertFormatPass> ConvertFormatPass::make_nhwcd4_converter() {
replace_func[opr::ConvolutionBackwardData::typeinfo()] = replace_deconv_opr; replace_func[opr::ConvolutionBackwardData::typeinfo()] = replace_deconv_opr;
replace_func[opr::PoolingForward::typeinfo()] = replace_pooling_opr; replace_func[opr::PoolingForward::typeinfo()] = replace_pooling_opr;
replace_func[opr::Elemwise::typeinfo()] = replace_elemwise_opr; replace_func[opr::Elemwise::typeinfo()] = replace_elemwise_opr;
replace_func[opr::Concat::typeinfo()] = relayout_inp_to_chw; replace_func[opr::Concat::typeinfo()] = replace_concat_opr;
replace_func[opr::Reshape::typeinfo()] = relayout_inp_to_chw; replace_func[opr::Reshape::typeinfo()] = relayout_inp_to_chw;
replace_func[opr::GetVarShape::typeinfo()] = relayout_inp_to_chw; replace_func[opr::GetVarShape::typeinfo()] = relayout_inp_to_chw;
replace_func[opr::Dimshuffle::typeinfo()] = relayout_inp_to_chw; replace_func[opr::Dimshuffle::typeinfo()] = relayout_inp_to_chw;
......
...@@ -1591,6 +1591,77 @@ TEST(TestGoptInference, ConvertFormatPadIC) { ...@@ -1591,6 +1591,77 @@ TEST(TestGoptInference, ConvertFormatPadIC) {
MGB_ASSERT_TENSOR_NEAR(host_y, host_y_opt, 1e-3); MGB_ASSERT_TENSOR_NEAR(host_y, host_y_opt, 1e-3);
} }
TEST(TestGoptInference, concatbypass) {
// hwcd4 is only supported in naive handle
NaiveMegDNNHandleScope naive_megdnn_handle;
HostTensorGenerator<> gen;
auto cn = CompNode::load("cpu0");
auto graph = ComputingGraph::make();
graph->options().graph_opt_level = 0;
auto mkcvar = [&](const char* name, const TensorShape& shp) {
return opr::SharedDeviceTensor::make(*graph, *gen(shp, cn))
.rename(name);
};
auto host_inp1 = gen({1, 6, 16, 16}, cn),
host_inp2 = gen({1, 6, 32, 32}, cn);
auto inp1 = opr::Host2DeviceCopy::make(*graph, host_inp1),
inp2 = opr::Host2DeviceCopy::make(*graph, host_inp2);
auto shape_tmp = mkcvar("tmp", {32, 32});
auto shape_of = opr::GetVarShape::make(shape_tmp);
opr::Resize::Param param_resize;
param_resize.format = opr::Resize::Param::Format::NCHW;
auto resize = opr::ResizeForward::make(inp1, shape_of, param_resize);
//! this concat should forward to chw
auto concat = opr::Concat::make({inp2, resize}, 1);
opr::Convolution::Param param;
param.pad_h = param.pad_w = 1;
param.sparse = opr::Convolution::Param::Sparse::DENSE;
auto w1 = mkcvar("w1", {12, 12, 3, 3});
auto w2 = mkcvar("w1", {12, 24, 3, 3});
auto y = opr::Convolution::make(concat, w1, param);
//! this concat should bypass CD4
y = opr::Concat::make({y, y}, 0);
y = opr::Convolution::make(y, w1, param);
//! this concat should bypass CD4
y = opr::Concat::make({y, y}, 1);
y = opr::Convolution::make(y, w2, param);
//! this concat should bypass CD4
y = opr::Concat::make({y, y}, 2);
y = opr::Convolution::make(y, w1, param);
SymbolVar y_opt;
auto options = gopt::OptimizeForInferenceOptions{};
options.enable_nhwcd4();
unpack_vector(gopt::optimize_for_inference({y}, options), y_opt);
HostTensorND host_y_opt, host_y;
auto func = graph->compile({make_callback_copy(y, host_y),
make_callback_copy(y_opt, host_y_opt)});
size_t relayout_format_nr = 0;
auto cb = [&](cg::OperatorNodeBase* opr) {
if (opr->try_cast_final<opr::Convolution>()) {
auto conv_inputs = opr->input();
for (auto& input : conv_inputs) {
if (std::string::npos !=
std::string(input->cname()).find("relayout_format")) {
relayout_format_nr++;
}
}
}
return true;
};
func->iter_opr_seq(cb);
func->execute();
MGB_ASSERT_TENSOR_NEAR(host_y, host_y_opt, 1e-3);
ASSERT_EQ(opr::Convolution::Param::Format::NHWCD4,
find_opr<opr::Convolution>(y_opt).param().format);
ASSERT_EQ(1, relayout_format_nr);
}
TEST(TestGoptInference, ConvertBatchNormPass) { TEST(TestGoptInference, ConvertBatchNormPass) {
auto cn = CompNode::load("cpu0"); auto cn = CompNode::load("cpu0");
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册