提交 26d1e4f7 编写于 作者: M Megvii Engine Team

feat(gopt): optimize cd4 pass rule for elemwise and typecvt to let cd4 start as soon as possible

GitOrigin-RevId: 6580dedca7d9854142fd3cf824604f52e6826093
上级 ac26bdce
...@@ -1588,45 +1588,61 @@ std::unique_ptr<ConvertFormatPass> ConvertFormatPass::make_nhwcd4_converter() { ...@@ -1588,45 +1588,61 @@ std::unique_ptr<ConvertFormatPass> ConvertFormatPass::make_nhwcd4_converter() {
auto replace_elemwise_opr = [&relayout_inp_to_chw]( auto replace_elemwise_opr = [&relayout_inp_to_chw](
OperatorNodeBase* opr, OperatorNodeBase* opr,
const VarNodeArray& new_inp) { const VarNodeArray& new_inp) {
enum class TENSORPROPERTY {
SCALA = 0,
DEFAULT_MOD4 = 1,
NHWCD4 = 2,
UNKNOW = 3,
};
auto get_property = [](VarNode* node) -> TENSORPROPERTY {
auto&& shape = node->shape();
auto&& format = node->format();
if (shape.ndim == 4 && format.is_default() && shape[1] % 4 == 0) {
return TENSORPROPERTY::DEFAULT_MOD4;
}
if (shape.is_scalar()) {
return TENSORPROPERTY::SCALA;
}
if (shape.ndim == 5 && format.type() == TensorFormat::Type::IMAGE2D_PACK4) {
return TENSORPROPERTY::NHWCD4;
}
return TENSORPROPERTY::UNKNOW;
};
mgb_assert(opr->input().size() == new_inp.size()); mgb_assert(opr->input().size() == new_inp.size());
bool has_inp_changed = false;
bool can_exec_cd4 = true; bool can_exec_cd4 = true;
for (size_t i = 0; i < opr->input().size(); i++) { for (size_t i = 0; i < opr->input().size(); i++) {
if (!new_inp[i]->format().is_default()) { auto property = get_property(new_inp[i]);
has_inp_changed = true; if (property == TENSORPROPERTY::UNKNOW) {
} else if (new_inp[i]->shape().ndim == 4) {
if (new_inp[i]->shape()[1] % 4 != 0) {
can_exec_cd4 = false;
}
//! cd4 elemwise with scaler is unsupported
} else if (!new_inp[i]->shape().is_scalar()) {
can_exec_cd4 = false; can_exec_cd4 = false;
break;
} }
} }
if (!can_exec_cd4) { if (!can_exec_cd4) {
return relayout_inp_to_chw(opr, new_inp); return relayout_inp_to_chw(opr, new_inp);
} }
if (has_inp_changed) {
// assumption: all inputs are changed from nchw to nhwcd4 //! check and change all inputs to cd4
auto t_inp = new_inp; auto t_inp = new_inp;
for (size_t i = 0; i < opr->input().size(); i++) { for (size_t i = 0; i < opr->input().size(); i++) {
if (new_inp[i]->shape().ndim == 4) { auto property = get_property(new_inp[i]);
auto param = megdnn::param::RelayoutFormat(); if (property == TENSORPROPERTY::DEFAULT_MOD4) {
param.mode = megdnn::param::RelayoutFormat::Mode::NCHW_NHWCD4I; auto param = megdnn::param::RelayoutFormat();
auto rf = opr::RelayoutFormat::make(new_inp[i], param); param.mode = megdnn::param::RelayoutFormat::Mode::NCHW_NHWCD4I;
t_inp[i] = rf.node(); auto rf = opr::RelayoutFormat::make(new_inp[i], param);
} else { t_inp[i] = rf.node();
mgb_assert( } else {
(new_inp[i]->shape().ndim == 5 && mgb_assert(
new_inp[i]->format().type() == property == TENSORPROPERTY::SCALA ||
TensorFormat::Type::IMAGE2D_PACK4) || property == TENSORPROPERTY::NHWCD4,
new_inp[i]->shape().is_scalar()); "This node should be scala ir CD4 format, but got shape = %s, "
} "format = %s",
new_inp[i]->shape().to_string().c_str(),
new_inp[i]->format().to_string().c_str());
} }
return serialization::copy_opr_shallow(*opr, t_inp, opr->config());
} else {
return serialization::copy_opr_shallow(*opr, new_inp, opr->config());
} }
return serialization::copy_opr_shallow(*opr, t_inp, opr->config());
}; };
/* This helper function converts the first input to the NCHW format to /* This helper function converts the first input to the NCHW format to
...@@ -1661,7 +1677,6 @@ std::unique_ptr<ConvertFormatPass> ConvertFormatPass::make_nhwcd4_converter() { ...@@ -1661,7 +1677,6 @@ std::unique_ptr<ConvertFormatPass> ConvertFormatPass::make_nhwcd4_converter() {
replace_func[opr::Broadcast::typeinfo()] = relayout_inp_to_chw; replace_func[opr::Broadcast::typeinfo()] = relayout_inp_to_chw;
replace_func[opr::IncrSubtensor::typeinfo()] = relayout_inp_to_chw; replace_func[opr::IncrSubtensor::typeinfo()] = relayout_inp_to_chw;
replace_func[opr::AxisAddRemove::typeinfo()] = relayout_inp_to_chw; replace_func[opr::AxisAddRemove::typeinfo()] = relayout_inp_to_chw;
replace_func[opr::TypeCvt::typeinfo()] = replace_elemwise_opr;
replace_func[opr::ResizeForward::typeinfo()] = replace_resize_opr; replace_func[opr::ResizeForward::typeinfo()] = replace_resize_opr;
replace_func[opr::WarpPerspectiveForward::typeinfo()] = replace_func[opr::WarpPerspectiveForward::typeinfo()] =
replace_warp_perspective_opr; replace_warp_perspective_opr;
......
...@@ -1270,6 +1270,54 @@ TEST(TestGoptInference, ConvertFormatNHWCD4OpenCL) { ...@@ -1270,6 +1270,54 @@ TEST(TestGoptInference, ConvertFormatNHWCD4OpenCL) {
#undef REQUIRE_OPENCL #undef REQUIRE_OPENCL
#endif #endif
//! this is to test elemwise to cd4 only
TEST(TestGoptInference, ConvertFormatNHWCD4Elemwise0) {
// hwcd4 is only supported in naive handle
NaiveMegDNNHandleScope naive_megdnn_handle;
HostTensorGenerator<> gen;
auto cn = CompNode::load("cpu0");
auto graph = ComputingGraph::make();
graph->options().graph_opt_level = 0;
auto mkvar = [&](const char* name, const TensorShape& shp) {
return opr::Host2DeviceCopy::make(*graph, gen(shp, cn)).rename(name);
};
auto host_x = gen({8, 8, 8, 8}, cn);
auto x = opr::Host2DeviceCopy::make(*graph, host_x);
auto a = mkvar("a", {1});
auto b = mkvar("b", {1});
auto y = x * a + b;
SymbolVar y_opt;
auto options = gopt::OptimizeForInferenceOptions{};
options.enable_nhwcd4();
unpack_vector(gopt::optimize_for_inference({y}, options), y_opt);
ASSERT_EQ(
opr::Elemwise::Mode::FUSE_MUL_ADD3,
find_opr<opr::Elemwise>(y_opt).param().mode);
ASSERT_EQ(
TensorFormat::Type::IMAGE2D_PACK4,
find_opr<opr::Elemwise>(y_opt).input(1)->format().type());
graph->compile({{y_opt, {}}})
->to_json()
->writeto_fpath(
output_file("TestGoptInference.ConvertFormatNHWCD4Elemwise0.json"));
HostTensorND host_y_opt, host_y;
auto func = graph->compile(
{make_callback_copy(y, host_y), make_callback_copy(y_opt, host_y_opt)});
func->execute();
MGB_ASSERT_TENSOR_NEAR(host_y, host_y_opt, 1e-3);
*host_x = *gen({8, 8, 16, 16}, cn);
func->execute();
MGB_ASSERT_TENSOR_NEAR(host_y, host_y_opt, 1e-3);
}
TEST(TestGoptInference, ConvertFormatNHWCD4Elemwise) { TEST(TestGoptInference, ConvertFormatNHWCD4Elemwise) {
// hwcd4 is only supported in naive handle // hwcd4 is only supported in naive handle
NaiveMegDNNHandleScope naive_megdnn_handle; NaiveMegDNNHandleScope naive_megdnn_handle;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册