Commit c1fb3c8c authored by Megvii Engine Team

fix(mgb/gopt): fix ConvertF32ToF16Pass endpoints

GitOrigin-RevId: 850eaa090681f947a165a6479ea5fb5d7df21f66
Parent 380cb6e4
@@ -625,8 +625,14 @@ void ConvertF32ToF16Pass::apply(OptState& state) const {
auto rewriter = state.graph().make_rewriter();
VarNodeArray new_inp_cache;
auto on_opr = [this, &rewriter, &new_inp_cache,
&state](OperatorNodeBase* opr) {
// record original output dtype
const SymbolVarArray& vars = state.graph().endpoint_vars();
std::vector<DType> dtypes;
for (size_t i = 0; i < vars.size(); i++) {
dtypes.push_back(vars[i].node()->dtype());
}
auto on_opr = [this, &rewriter, &new_inp_cache](OperatorNodeBase* opr) {
auto it = m_opr_replace_func.find(opr->dyn_typeinfo());
if (it != m_opr_replace_func.end()) {
auto&& new_inp = new_inp_cache;
@@ -642,40 +648,32 @@ void ConvertF32ToF16Pass::apply(OptState& state) const {
"bad opr replace: src=%s{%s} dst=%s{%s}", opr->cname(),
opr->dyn_typeinfo()->name, new_opr->cname(),
new_opr->dyn_typeinfo()->name);
//! change the output type if it's the endpoint
for (size_t i = 0; i < origin_out.size(); i++) {
if (state.graph().endpoint_contain(origin_out[i]) &&
origin_out[i]->dtype().enumv() !=
cur_out[i]->dtype().enumv()) {
rewriter.replace_var(
origin_out[i],
opr::TypeCvt::make(cur_out[i],
origin_out[i]->dtype())
.node(),
nullptr);
} else {
rewriter.replace_var(origin_out[i], cur_out[i], nullptr);
}
rewriter.replace_var(origin_out[i], cur_out[i], nullptr);
}
} else {
auto new_opr = rewriter.auto_replace_outputs(opr);
auto&& out = opr->output();
auto&& new_out = new_opr->output();
for (size_t i = 0; i < out.size(); i++) {
if (state.graph().endpoint_contain(out[i]) &&
new_out[i]->dtype().enumv() != out[i]->dtype().enumv()) {
rewriter.replace_var(
new_out[i],
opr::TypeCvt::make(new_out[i],
out[i]->dtype())
.node(),
nullptr);
}
}
rewriter.auto_replace_outputs(opr);
}
};
state.graph().iter(on_opr);
rewriter.apply_inplace();
// recover output dtype
rewriter = state.graph().make_rewriter();
const SymbolVarArray& endpoints = state.graph().endpoint_vars();
auto replace_output = [&]() {
for (size_t i = 0; i < endpoints.size(); i++) {
VarNode* var = endpoints[i].node();
if (var->dtype().enumv() != dtypes[i].enumv()) {
auto new_var = opr::TypeCvt::make(var, dtypes[i]).node();
rewriter.replace_var(var, new_var, nullptr);
}
}
};
mgb_assert(endpoints.size() > 0);
auto opr = endpoints[0].node()->owner_opr();
state.call_with_opr(opr, replace_output, OprPropertyFlag::NONE);
rewriter.apply_inplace();
}
std::unique_ptr<ConvertF32ToF16Pass> ConvertF32ToF16Pass::make(
......
@@ -994,6 +994,38 @@ TEST(TestGoptInference, Float32TOFloat16Linspace) {
MGB_ASSERT_TENSOR_NEAR(host_y, host_y_opt, 1e-3);
}
TEST(TestGoptInference, Float32TOFloat16Endpoints) {
HostTensorGenerator<> gen;
auto graph = ComputingGraph::make();
auto mkvar = [&](const char* name, const TensorShape& shp) {
return opr::Host2DeviceCopy::make(*graph, gen(shp)).rename(name);
};
auto mkcvar = [&](const char* name, const TensorShape& shp) {
return opr::SharedDeviceTensor::make(*graph, *gen(shp))
.rename(name);
};
graph->options().graph_opt_level = 0;
opr::Convolution::Param param;
param.pad_h = param.pad_w = 0;
auto x = mkvar("x", {8, 8, 8, 8}),
y = mkvar("y", {8, 8, 8, 8}),
w = mkcvar("w", {4, 8, 3, 3}),
z = opr::Convolution::make(x + y, w, param);
auto options = gopt::OptimizeForInferenceOptions{};
options.enable_f16_io_f32_comp();
SymbolVarArray out = gopt::optimize_for_inference({x + y, z}, options);
ASSERT_EQ(out[0].dtype(), dtype::Float32());
ASSERT_EQ(out[1].dtype(), dtype::Float32());
ASSERT_EQ(out[0].node()->owner_opr()->input(0)->dtype(), dtype::Float16());
ASSERT_EQ(out[1].node()->owner_opr()->input(0)->dtype(), dtype::Float16());
}
TEST(TestGoptInference, ConvertFormatNHWCD4) {
// hwcd4 is only supported in naive handle
NaiveMegDNNHandleScope naive_megdnn_handle;
......
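The diff above moves endpoint dtype handling out of the per-operator callback: the pass now records each endpoint's dtype before the F32->F16 rewrite and restores it afterwards in a second rewriter pass, which is what the new Float32TOFloat16Endpoints test checks. Below is a minimal standalone sketch of that record-then-recover pattern; it is not MegEngine code, and DType, Var, and make_typecvt are hypothetical stand-ins for the graph's VarNode dtypes and opr::TypeCvt::make().

// Sketch only: illustrates the endpoint-dtype recovery pattern under
// hypothetical types; the real pass works on graph VarNodes via a rewriter.
#include <cassert>
#include <vector>

enum class DType { Float32, Float16 };

struct Var {
    DType dtype;
};

// Stand-in for opr::TypeCvt::make(var, dtype).node(): yields a new var
// carrying the converted value.
Var make_typecvt(const Var&, DType target) {
    return Var{target};
}

int main() {
    // Graph endpoints before the pass runs (both produced in Float32).
    std::vector<Var> endpoints{Var{DType::Float32}, Var{DType::Float32}};

    // 1. Record the original endpoint dtypes, mirroring the loop over
    //    state.graph().endpoint_vars() at the top of apply().
    std::vector<DType> dtypes;
    for (const Var& v : endpoints)
        dtypes.push_back(v.dtype);

    // 2. The F32->F16 rewrite may leave the endpoints in Float16.
    for (Var& v : endpoints)
        v.dtype = DType::Float16;

    // 3. Recover the recorded dtypes, mirroring the second rewriter and the
    //    replace_output lambda: insert a conversion only where the dtype
    //    actually differs from the record.
    for (size_t i = 0; i < endpoints.size(); i++)
        if (endpoints[i].dtype != dtypes[i])
            endpoints[i] = make_typecvt(endpoints[i], dtypes[i]);

    for (const Var& v : endpoints)
        assert(v.dtype == DType::Float32);
    return 0;
}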