Commit 7cd71c31, authored by Megvii Engine Team

fix(mgb/gopt): fix cd4 elewise transform

GitOrigin-RevId: 027d5e53e43088fb3bfb9932369a4a0429f736bc
Parent cae8c8a4
......@@ -1561,16 +1561,27 @@ std::unique_ptr<ConvertFormatPass> ConvertFormatPass::make_nhwcd4_converter() {
         return new_opr;
     };
-    auto replace_elemwise_opr = [](OperatorNodeBase* opr,
-                                   const VarNodeArray& new_inp) {
+    auto replace_elemwise_opr = [&relayout_inp_to_chw](
+                                        OperatorNodeBase* opr,
+                                        const VarNodeArray& new_inp) {
         mgb_assert(opr->input().size() == new_inp.size());
         bool has_inp_changed = false;
+        bool can_exec_cd4 = true;
         for (size_t i = 0; i < opr->input().size(); i++) {
             if (!new_inp[i]->format().is_default()) {
                 has_inp_changed = true;
                 break;
+            } else if (new_inp[i]->shape().ndim == 4) {
+                if (new_inp[i]->shape()[1] % 4 != 0) {
+                    can_exec_cd4 = false;
+                }
+                //! cd4 elemwise with scalar is supported
+            } else if (!new_inp[i]->shape().is_scalar()) {
+                can_exec_cd4 = false;
             }
         }
+        if (!can_exec_cd4) {
+            return relayout_inp_to_chw(opr, new_inp);
+        }
         if (has_inp_changed) {
             // assumption: all inputs are changed from nchw to nhwcd4
             auto t_inp = new_inp;
......
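For context: NHWCD4 stores a tensor roughly as (N, H, C/4, W, 4), with channels packed in groups of four on the innermost axis, so an elemwise operator can only stay in that format when every 4-D input has a channel count divisible by 4 (scalar inputs are merely broadcast and remain fine); anything else is now routed back to NCHW via relayout_inp_to_chw. Below is a minimal standalone sketch of that shape relation — illustrative only, not MegEngine code, and assuming the (N, H, C/4, W, 4) ordering.

    // Illustrative only -- not MegEngine code.  Maps an NCHW shape to the
    // NHWCD4 layout (N, H, C/4, W, 4); returns nullopt when the channel
    // count cannot be split into whole groups of 4, which is exactly the
    // case the pass now sends back to NCHW.
    #include <array>
    #include <cstddef>
    #include <optional>

    std::optional<std::array<std::size_t, 5>> nchw_to_nhwcd4(
            const std::array<std::size_t, 4>& nchw) {
        const std::size_t n = nchw[0], c = nchw[1], h = nchw[2], w = nchw[3];
        if (c % 4 != 0)
            return std::nullopt;
        return std::array<std::size_t, 5>{n, h, c / 4, w, 4};
    }

For example, nchw_to_nhwcd4({8, 8, 6, 6}) yields {8, 6, 2, 6, 4}, while a channel count of 6 instead of 8 would yield nullopt.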
......@@ -2232,6 +2232,8 @@ void EnableNchwxxPass::fill_opr_convert_fun(size_t pack_c_size) {
                 if (new_inp[i]->shape()[1] % pack_c_size != 0) {
                     can_exec_ncwxx = false;
                 }
+            } else if (!new_inp[i]->shape().is_scalar()) {
+                can_exec_ncwxx = false;
             }
         }
         if (has_inp_changed) {
......
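The NCHWxx hunk above applies the same guard as the NHWCD4 one; only the packing factor differs: 4 for NHWCD4 versus pack_c_size (4 for NCHW44, 8 for NCHW88). A rough restatement of the shared predicate, with a hypothetical name chosen here only for illustration:

    // Hypothetical helper, not MegEngine API: the rule shared by both passes.
    #include <cstddef>
    #include <vector>

    bool channels_packable(const std::vector<std::size_t>& shape,
                           std::size_t pack_c_size) {
        if (shape.size() == 4)
            return shape[1] % pack_c_size == 0;    // NCHW channel axis
        return shape.size() == 1 && shape[0] == 1; // only scalars otherwise
    }

So channels_packable({1, 8, 32, 32}, 4) is true, channels_packable({1, 6, 32, 32}, 4) is false, and channels_packable({1}, 8) is true because a scalar is only broadcast.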
......@@ -1197,6 +1197,67 @@ TEST(TestGoptInference, ConvertFormatNHWCD4) {
     MGB_ASSERT_TENSOR_NEAR(host_y, host_y_opt, 1e-3);
 }
 
+TEST(TestGoptInference, ConvertFormatNHWCD4Elemwise) {
+    // hwcd4 is only supported in naive handle
+    NaiveMegDNNHandleScope naive_megdnn_handle;
+
+    HostTensorGenerator<> gen;
+    auto cn = CompNode::load("cpu0");
+    auto graph = ComputingGraph::make();
+    graph->options().graph_opt_level = 0;
+    auto mkvar = [&](const char* name, const TensorShape& shp) {
+        return opr::Host2DeviceCopy::make(*graph, gen(shp, cn)).rename(name);
+    };
+    auto mkcvar = [&](const char* name, const TensorShape& shp) {
+        return opr::SharedDeviceTensor::make(*graph, *gen(shp, cn))
+                .rename(name);
+    };
+
+    auto host_x = gen({8, 8, 8, 8}, cn);
+    auto x = opr::Host2DeviceCopy::make(*graph, host_x);
+
+    opr::Convolution::Param param;
+    param.pad_h = param.pad_w = 0;
+    auto w1 = mkcvar("w1", {8, 8, 3, 3}),
+         conv = opr::Convolution::make(x, w1, param);
+    auto b = mkvar("b", {1, 1, 1, 1}),
+         elem = opr::Elemwise::make({conv + b},
+                                    opr::Elemwise::Param::Mode::RELU);
+
+    param.pad_h = param.pad_w = 1;
+    auto w2 = mkcvar("w2", {8, 8, 3, 3}),
+         conv2 = opr::Convolution::make(elem, w2, param);
+    auto b_scaler = mkvar("b", {1}), elem2 = conv2 + b_scaler;
+
+    param.pad_h = param.pad_w = 1;
+    auto w3 = mkcvar("w2", {8, 8, 3, 3}),
+         y = opr::Convolution::make(elem2, w3, param);
+
+    SymbolVar y_opt;
+    auto options = gopt::OptimizeForInferenceOptions{};
+    options.enable_nhwcd4();
+    unpack_vector(gopt::optimize_for_inference({y}, options), y_opt);
+
+    ASSERT_EQ(opr::Convolution::Param::Format::NHWCD4,
+              find_opr<opr::Convolution>(y_opt).param().format);
+
+    graph->compile({{y_opt, {}}})
+            ->to_json()
+            ->writeto_fpath(output_file(
+                    "TestGoptInference.ConvertFormatNHWCD4Elemwise.json"));
+
+    HostTensorND host_y_opt, host_y;
+    auto func = graph->compile({make_callback_copy(y, host_y),
+                                make_callback_copy(y_opt, host_y_opt)});
+    func->execute();
+    MGB_ASSERT_TENSOR_NEAR(host_y, host_y_opt, 1e-3);
+
+    *host_x = *gen({8, 8, 16, 16}, cn);
+    func->execute();
+    MGB_ASSERT_TENSOR_NEAR(host_y, host_y_opt, 1e-3);
+}
+
 TEST(TestGoptInference, ConvertFormatNHWCD4LOCAL) {
     // hwcd4 is only supported in naive handle
     NaiveMegDNNHandleScope naive_megdnn_handle;
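The new test appears to exercise both branches of the check: the first elemwise adds a bias of shape {1, 1, 1, 1}, a 4-D tensor with C = 1, so that elemwise falls back to NCHW, while the second adds the scalar b_scaler of shape {1}, which may remain in NHWCD4. A tiny self-contained check of those two shapes using the same rule (assumed semantics, not MegEngine code):

    // Assumed semantics, not MegEngine code: evaluate the divisibility rule
    // on the two bias shapes fed to the elemwise oprs in the new test.
    #include <cstddef>
    #include <cstdio>
    #include <vector>

    int main() {
        auto stays_cd4 = [](const std::vector<std::size_t>& shp) {
            if (shp.size() == 4)
                return shp[1] % 4 == 0;            // 4-D input: C must divide by 4
            return shp.size() == 1 && shp[0] == 1; // scalar broadcast is allowed
        };
        std::printf("b {1,1,1,1}  -> %d\n", int(stays_cd4({1, 1, 1, 1})));  // 0: falls back to NCHW
        std::printf("b_scaler {1} -> %d\n", int(stays_cd4({1})));           // 1: stays NHWCD4
    }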
......@@ -3228,7 +3289,15 @@ TEST(TestGoptInference, ConvertFormatNCHW44MultiInput) {
          conv1 = opr::Convolution::make(x, w1, param_conv);
     auto b = mkvar("b", {1, 1, 16, 16}),
-         y = opr::Elemwise::make({conv1 + b}, opr::Elemwise::Param::Mode::RELU);
+         elem0 = opr::Elemwise::make({conv1 + b + b},
+                                     opr::Elemwise::Param::Mode::RELU);
+
+    auto w2 = mkcvar("w2", {8, 8, 3, 3}),
+         conv2 = opr::Convolution::make(elem0, w2, param_conv);
+
+    auto b1 = mkvar("b1", {1}),
+         y = opr::Elemwise::make({conv2 + b1 + b},
+                                 opr::Elemwise::Param::Mode::RELU);
 
     SymbolVar y_opt;
     auto options = gopt::OptimizeForInferenceOptions{};
......