diff --git a/src/gopt/impl/inference.cpp b/src/gopt/impl/inference.cpp
index 68902d2f5448b154ada9e8380fd3cabeef6db662..87626537eb10d935235111bb52a797e97edae801 100644
--- a/src/gopt/impl/inference.cpp
+++ b/src/gopt/impl/inference.cpp
@@ -26,6 +26,7 @@
 #include "megbrain/opr/tensor_manip.h"
 #include "megbrain/opr/imgproc.h"
 #include "megbrain/opr/nn_int.h"
+#include "megbrain/opr/tensor_gen.h"
 
 #include "megdnn/tensor_format.h"
 
@@ -741,6 +742,19 @@ std::unique_ptr<ConvertF32ToF16Pass> ConvertF32ToF16Pass::make(
         return opr;
     };
 
+    auto replace_lsp_opr = [](OperatorNodeBase* opr,
+                              const VarNodeArray& new_inp) {
+        mgb_assert(opr->same_type<opr::Linspace>());
+        mgb_assert(opr->input().size() == new_inp.size());
+        auto& lsp_opr = opr->cast_final_safe<opr::Linspace>();
+        if (lsp_opr.output(0)->dtype() != dtype::Float16()) {
+            auto cvt_var =
+                    opr::TypeCvt::make(lsp_opr.output(0), dtype::Float16(), {});
+            return cvt_var.node()->owner_opr();
+        }
+        return opr;
+    };
+
     auto replace_conv_opr = [use_f32_comp](OperatorNodeBase* opr,
                                            const VarNodeArray& new_inp) {
         mgb_assert(opr->input().size() == new_inp.size());
@@ -778,6 +792,29 @@ std::unique_ptr<ConvertF32ToF16Pass> ConvertF32ToF16Pass::make(
         return new_matmul_opr.node()->owner_opr();
     };
 
+    auto replace_batched_matmul_opr = [use_f32_comp](
+                                              OperatorNodeBase* opr,
+                                              const VarNodeArray& new_inp) {
+        mgb_assert(opr->input().size() == new_inp.size());
+        auto& matmul_opr = opr->cast_final_safe<opr::BatchedMatrixMul>();
+        auto new_param = matmul_opr.param();
+        if (use_f32_comp) {
+            new_param.compute_mode =
+                    megdnn::param::MatrixMul::ComputeMode::FLOAT32;
+        }
+        mgb_assert(new_inp[0]->dtype() == dtype::Float16(),
+                   "inp %s:%s, owner_opr:%s", new_inp[0]->dtype().name(),
+                   new_inp[0]->name().c_str(),
+                   new_inp[0]->owner_opr()->name().c_str());
+        mgb_assert(new_inp[1]->dtype() == dtype::Float16(),
+                   "inp %s:%s, owner_opr:%s", new_inp[1]->dtype().name(),
+                   new_inp[1]->name().c_str(),
+                   new_inp[1]->owner_opr()->name().c_str());
+        auto new_matmul_opr = opr::BatchedMatrixMul::make(
+                new_inp[0], new_inp[1], new_param, matmul_opr.config());
+        return new_matmul_opr.node()->owner_opr();
+    };
+
     auto replace_reduce_opr = [use_f32_comp](OperatorNodeBase* opr,
                                              const VarNodeArray& new_inp) {
         auto& reduce_opr = opr->cast_final_safe<opr::Reduce>();
@@ -871,6 +908,7 @@ std::unique_ptr<ConvertF32ToF16Pass> ConvertF32ToF16Pass::make(
     ret->set_var_replace_check_flag(VarReplaceCheckFlag::CHECK_ALL ^
                                     VarReplaceCheckFlag::CHECK_DTYPE);
     auto&& replace_func = ret->m_opr_replace_func;
+    replace_func[opr::Linspace::typeinfo()] = replace_lsp_opr;
     replace_func[opr::Host2DeviceCopy::typeinfo()] = replace_h2d_opr;
     replace_func[opr::SharedDeviceTensor::typeinfo()] = replace_sdt_opr;
     replace_func[opr::Convolution::typeinfo()] = replace_conv_opr;
@@ -880,6 +918,8 @@
     replace_func[opr::TypeCvt::typeinfo()] = replace_cvt_opr;
     replace_func[opr::WarpPerspective::typeinfo()] = replace_warp_opr;
     replace_func[opr::Remap::typeinfo()] = replace_remap_opr;
+    replace_func[opr::BatchedMatrixMul::typeinfo()] =
+            replace_batched_matmul_opr;
     return ret;
 #endif
 }
diff --git a/src/gopt/test/inference.cpp b/src/gopt/test/inference.cpp
index 457941a0382f704c86611394da01e0e2468a89cf..7d6456ace952767f0183e576fb49d33695959f7a 100644
--- a/src/gopt/test/inference.cpp
+++ b/src/gopt/test/inference.cpp
@@ -27,6 +27,8 @@
 #include "megbrain/opr/nn_int.h"
 #include "megbrain/opr/imgproc.h"
 #include "megbrain/opr/dnn/pooling.h"
+#include "megbrain/opr/tensor_gen.h"
+#include "megbrain/opr/blas.h"
 
 #include "megbrain/comp_node_env.h"
 #include "./helper.h"
@@ -892,6 +894,67 @@ TEST(TestGoptInference, Float32TOFloat16EndpointElemwise) {
     MGB_ASSERT_TENSOR_NEAR(host_y, host_y_opt, 1e-3);
 }
 
+TEST(TestGoptInference, Float32TOFloat16Linspace) {
+    CompNode cn = CompNode::load("cpu0");
+    HostTensorGenerator<> gen(0, 1, 0);
+    auto host_x = gen({3, 1}, cn);
+    auto graph = ComputingGraph::make();
+
+    auto make_f32_to_f16_graph = [&]() {
+        graph->options().graph_opt_level = 0;
+
+        auto x = opr::Host2DeviceCopy::make(*graph, host_x);
+        auto xshp = opr::GetVarShape::make(x);
+
+        auto cv = [&x](int v) { return x.make_scalar(v); };
+        auto sub = [&xshp, &cv](int idx) {
+            return opr::IndexAt::make(xshp, {{0, cv(idx)}});
+        };
+        auto lin = opr::Linspace::make(cv(0), sub(0) - 1, sub(0), {}, {});
+        auto shp = opr::Concat::make({sub(1), sub(0)}, 0);
+        auto y = opr::Reshape::make(lin, shp);
+        auto mm = opr::MatrixMul::make(x, y);
+
+        SymbolVar mm_opt;
+        unpack_vector(gopt::optimize_for_inference(
+                              {mm}, gopt::OptimizeForInferenceOptions{}
+                                            .enable_f16_io_comp()),
+                      mm_opt);
+        return mm_opt;
+    };
+
+    auto make_f16_graph = [&]() {
+        auto x = opr::TypeCvt::make(opr::Host2DeviceCopy::make(*graph, host_x),
+                                    dtype::Float16());
+        auto xshp = opr::GetVarShape::make(x);
+
+        auto cv = [&x](int v) { return x.make_scalar(v); };
+        auto sub = [&xshp, &cv](int idx) {
+            return opr::IndexAt::make(xshp, {{0, cv(idx)}});
+        };
+        auto lin = opr::Linspace::make(cv(0), sub(0) - 1, sub(0), {}, {});
+        lin = opr::TypeCvt::make(lin, dtype::Float16());
+        auto shp = opr::Concat::make({sub(1), sub(0)}, 0);
+        auto y = opr::Reshape::make(lin, shp);
+        auto mm = opr::MatrixMul::make(x, y);
+
+        mm = opr::TypeCvt::make(mm, dtype::Float32{});
+
+        return mm;
+    };
+
+    auto y_opt = make_f32_to_f16_graph();
+    auto y = make_f16_graph();
+    ASSERT_EQ(y_opt.dtype(), dtype::Float32{});
+    ASSERT_EQ(y.dtype(), dtype::Float32{});
+
+    HostTensorND host_y_opt, host_y;
+    auto func = graph->compile({make_callback_copy(y, host_y),
+                                make_callback_copy(y_opt, host_y_opt)});
+    func->execute();
+    MGB_ASSERT_TENSOR_NEAR(host_y, host_y_opt, 1e-3);
+}
+
 TEST(TestGoptInference, ConvertFormatNHWCD4) {
     // hwcd4 is only supported in naive handle
     NaiveMegDNNHandleScope naive_megdnn_handle;
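
For reference, a minimal sketch of how the new BatchedMatrixMul rule is reached through optimize_for_inference. It assumes the same test scaffolding as above (HostTensorGenerator, unpack_vector from ./helper.h); the operand shapes are illustrative, not taken from the patch:

    // Build a small f32 graph ending in a BatchedMatrixMul, then run the
    // f16 io/comp optimization over it (mirrors the tests above).
    HostTensorGenerator<> gen(0, 1, 0);
    auto host_a = gen({4, 3, 5}), host_b = gen({4, 5, 2});  // batch of 4
    auto graph = ComputingGraph::make();

    auto a = opr::Host2DeviceCopy::make(*graph, host_a);
    auto b = opr::Host2DeviceCopy::make(*graph, host_b);
    auto bmm = opr::BatchedMatrixMul::make(a, b);  // f32 endpoint

    SymbolVar bmm_opt;
    unpack_vector(gopt::optimize_for_inference(
                          {bmm}, gopt::OptimizeForInferenceOptions{}
                                         .enable_f16_io_comp()),
                  bmm_opt);
    // replace_batched_matmul_opr rebuilds the operator with f16 inputs;
    // when the pass is constructed with use_f32_comp, it also sets the
    // param's compute_mode to FLOAT32 so accumulation stays in f32.

Note the different strategy for Linspace: rather than rebuilding the operator in f16, replace_lsp_opr keeps it computing in f32 and appends a TypeCvt to f16 on its output, presumably because its start/stop/num inputs are index-like scalars that should not themselves be converted.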