diff --git a/src/gopt/impl/inference.cpp b/src/gopt/impl/inference.cpp index 4d641942e915dcc509c88a37bd5b4f982aa17970..0a05aae15dce9c0e793d02e910bfcc3f32fa723a 100644 --- a/src/gopt/impl/inference.cpp +++ b/src/gopt/impl/inference.cpp @@ -625,8 +625,14 @@ void ConvertF32ToF16Pass::apply(OptState& state) const { auto rewriter = state.graph().make_rewriter(); VarNodeArray new_inp_cache; - auto on_opr = [this, &rewriter, &new_inp_cache, - &state](OperatorNodeBase* opr) { + // record original output dtype + const SymbolVarArray& vars = state.graph().endpoint_vars(); + std::vector<DType> dtypes; + for (size_t i = 0; i < vars.size(); i++) { + dtypes.push_back(vars[i].node()->dtype()); + } + + auto on_opr = [this, &rewriter, &new_inp_cache](OperatorNodeBase* opr) { auto it = m_opr_replace_func.find(opr->dyn_typeinfo()); if (it != m_opr_replace_func.end()) { auto&& new_inp = new_inp_cache; @@ -642,40 +648,32 @@ void ConvertF32ToF16Pass::apply(OptState& state) const { "bad opr replace: src=%s{%s} dst=%s{%s}", opr->cname(), opr->dyn_typeinfo()->name, new_opr->cname(), new_opr->dyn_typeinfo()->name); - //! 
change the output type if it's the endpoint for (size_t i = 0; i < origin_out.size(); i++) { - if (state.graph().endpoint_contain(origin_out[i]) && - origin_out[i]->dtype().enumv() != - cur_out[i]->dtype().enumv()) { - rewriter.replace_var( - origin_out[i], - opr::TypeCvt::make(cur_out[i], - origin_out[i]->dtype()) - .node(), - nullptr); - } else { - rewriter.replace_var(origin_out[i], cur_out[i], nullptr); - } + rewriter.replace_var(origin_out[i], cur_out[i], nullptr); } } else { - auto new_opr = rewriter.auto_replace_outputs(opr); - auto&& out = opr->output(); - auto&& new_out = new_opr->output(); - for (size_t i = 0; i < out.size(); i++) { - if (state.graph().endpoint_contain(out[i]) && - new_out[i]->dtype().enumv() != out[i]->dtype().enumv()) { - rewriter.replace_var( - new_out[i], - opr::TypeCvt::make(new_out[i], - out[i]->dtype()) - .node(), - nullptr); - } - } + rewriter.auto_replace_outputs(opr); } }; state.graph().iter(on_opr); rewriter.apply_inplace(); + + // recover output dtype + rewriter = state.graph().make_rewriter(); + const SymbolVarArray& endpoints = state.graph().endpoint_vars(); + auto replace_output = [&]() { + for (size_t i = 0; i < endpoints.size(); i++) { + VarNode* var = endpoints[i].node(); + if (var->dtype().enumv() != dtypes[i].enumv()) { + auto new_var = opr::TypeCvt::make(var, dtypes[i]).node(); + rewriter.replace_var(var, new_var, nullptr); + } + } + }; + mgb_assert(endpoints.size() > 0); + auto opr = endpoints[0].node()->owner_opr(); + state.call_with_opr(opr, replace_output, OprPropertyFlag::NONE); + rewriter.apply_inplace(); } std::unique_ptr<ConvertF32ToF16Pass> ConvertF32ToF16Pass::make( diff --git a/src/gopt/test/inference.cpp b/src/gopt/test/inference.cpp index 91adc3ea557da78e7fe27052465ba80a3c78a8f8..f276a188ade444158f7879468b1164f7f7e16844 100644 --- a/src/gopt/test/inference.cpp +++ b/src/gopt/test/inference.cpp @@ -994,6 +994,38 @@ TEST(TestGoptInference, Float32TOFloat16Linspace) { MGB_ASSERT_TENSOR_NEAR(host_y, host_y_opt, 1e-3); }
+TEST(TestGoptInference, Float32TOFloat16Endpoints) { + HostTensorGenerator<> gen; + auto graph = ComputingGraph::make(); + + auto mkvar = [&](const char* name, const TensorShape& shp) { + return opr::Host2DeviceCopy::make(*graph, gen(shp)).rename(name); + }; + + auto mkcvar = [&](const char* name, const TensorShape& shp) { + return opr::SharedDeviceTensor::make(*graph, *gen(shp)) + .rename(name); + }; + + graph->options().graph_opt_level = 0; + opr::Convolution::Param param; + param.pad_h = param.pad_w = 0; + + auto x = mkvar("x", {8, 8, 8, 8}), + y = mkvar("y", {8, 8, 8, 8}), + w = mkcvar("w", {4, 8, 3, 3}), + z = opr::Convolution::make(x + y, w, param); + + auto options = gopt::OptimizeForInferenceOptions{}; + options.enable_f16_io_f32_comp(); + SymbolVarArray out = gopt::optimize_for_inference({x + y, z}, options); + + ASSERT_EQ(out[0].dtype(), dtype::Float32()); + ASSERT_EQ(out[1].dtype(), dtype::Float32()); + ASSERT_EQ(out[0].node()->owner_opr()->input(0)->dtype(), dtype::Float16()); + ASSERT_EQ(out[1].node()->owner_opr()->input(0)->dtype(), dtype::Float16()); +} + TEST(TestGoptInference, ConvertFormatNHWCD4) { // hwcd4 is only supported in naive handle NaiveMegDNNHandleScope naive_megdnn_handle;