提交 e109ae91 编写于 作者: M Megvii Engine Team 提交者: Xinran Xu

fix(mgb/gopt): fix float32 to float16 opt pass

GitOrigin-RevId: d828512e444ea17f66be20fe47b5c0755501cfe4
上级 1255c9f1
......@@ -26,6 +26,7 @@
#include "megbrain/opr/tensor_manip.h"
#include "megbrain/opr/imgproc.h"
#include "megbrain/opr/nn_int.h"
#include "megbrain/opr/tensor_gen.h"
#include "megdnn/tensor_format.h"
......@@ -741,6 +742,19 @@ std::unique_ptr<ConvertF32ToF16Pass> ConvertF32ToF16Pass::make(
return opr;
};
auto replace_lsp_opr = [](OperatorNodeBase* opr,
const VarNodeArray& new_inp) {
mgb_assert(opr->same_type<opr::Linspace>());
mgb_assert(opr->input().size() == new_inp.size());
auto& lsp_opr = opr->cast_final_safe<opr::Linspace>();
if (lsp_opr.output(0)->dtype() != dtype::Float16()) {
auto cvt_var =
opr::TypeCvt::make(lsp_opr.output(0), dtype::Float16(), {});
return cvt_var.node()->owner_opr();
}
return opr;
};
auto replace_conv_opr = [use_f32_comp](OperatorNodeBase* opr,
const VarNodeArray& new_inp) {
mgb_assert(opr->input().size() == new_inp.size());
......@@ -778,6 +792,29 @@ std::unique_ptr<ConvertF32ToF16Pass> ConvertF32ToF16Pass::make(
return new_matmul_opr.node()->owner_opr();
};
auto replace_batched_matmul_opr = [use_f32_comp](
OperatorNodeBase* opr,
const VarNodeArray& new_inp) {
mgb_assert(opr->input().size() == new_inp.size());
auto& matmul_opr = opr->cast_final_safe<opr::BatchedMatrixMul>();
auto new_param = matmul_opr.param();
if (use_f32_comp) {
new_param.compute_mode =
megdnn::param::MatrixMul::ComputeMode::FLOAT32;
}
mgb_assert(new_inp[0]->dtype() == dtype::Float16(),
"inp %s:%s, owner_opr:%s", new_inp[0]->dtype().name(),
new_inp[0]->name().c_str(),
new_inp[0]->owner_opr()->name().c_str());
mgb_assert(new_inp[1]->dtype() == dtype::Float16(),
"inp %s:%s, owner_opr:%s", new_inp[1]->dtype().name(),
new_inp[1]->name().c_str(),
new_inp[1]->owner_opr()->name().c_str());
auto new_matmul_opr = opr::BatchedMatrixMul::make(
new_inp[0], new_inp[1], new_param, matmul_opr.config());
return new_matmul_opr.node()->owner_opr();
};
auto replace_reduce_opr = [use_f32_comp](OperatorNodeBase* opr,
const VarNodeArray& new_inp) {
auto& reduce_opr = opr->cast_final_safe<opr::Reduce>();
......@@ -871,6 +908,7 @@ std::unique_ptr<ConvertF32ToF16Pass> ConvertF32ToF16Pass::make(
ret->set_var_replace_check_flag(VarReplaceCheckFlag::CHECK_ALL ^
VarReplaceCheckFlag::CHECK_DTYPE);
auto&& replace_func = ret->m_opr_replace_func;
replace_func[opr::Linspace::typeinfo()] = replace_lsp_opr;
replace_func[opr::Host2DeviceCopy::typeinfo()] = replace_h2d_opr;
replace_func[opr::SharedDeviceTensor::typeinfo()] = replace_sdt_opr;
replace_func[opr::Convolution::typeinfo()] = replace_conv_opr;
......@@ -880,6 +918,8 @@ std::unique_ptr<ConvertF32ToF16Pass> ConvertF32ToF16Pass::make(
replace_func[opr::TypeCvt::typeinfo()] = replace_cvt_opr;
replace_func[opr::WarpPerspective::typeinfo()] = replace_warp_opr;
replace_func[opr::Remap::typeinfo()] = replace_remap_opr;
replace_func[opr::BatchedMatrixMul::typeinfo()] =
replace_batched_matmul_opr;
return ret;
#endif
}
......
......@@ -27,6 +27,8 @@
#include "megbrain/opr/nn_int.h"
#include "megbrain/opr/imgproc.h"
#include "megbrain/opr/dnn/pooling.h"
#include "megbrain/opr/tensor_gen.h"
#include "megbrain/opr/blas.h"
#include "megbrain/comp_node_env.h"
#include "./helper.h"
......@@ -892,6 +894,67 @@ TEST(TestGoptInference, Float32TOFloat16EndpointElemwise) {
MGB_ASSERT_TENSOR_NEAR(host_y, host_y_opt, 1e-3);
}
TEST(TestGoptInference, Float32TOFloat16Linspace) {
CompNode cn = CompNode::load("cpu0");
HostTensorGenerator<> gen(0, 1, 0);
auto host_x = gen({3, 1}, cn);
auto graph = ComputingGraph::make();
auto make_f32_to_f16_graph = [&]() {
graph->options().graph_opt_level = 0;
auto x = opr::Host2DeviceCopy::make(*graph, host_x);
auto xshp = opr::GetVarShape::make(x);
auto cv = [&x](int v) { return x.make_scalar(v); };
auto sub = [&xshp, &cv](int idx) {
return opr::IndexAt::make(xshp, {{0, cv(idx)}});
};
auto lin = opr::Linspace::make(cv(0), sub(0) - 1, sub(0), {}, {});
auto shp = opr::Concat::make({sub(1), sub(0)}, 0);
auto y = opr::Reshape::make(lin, shp);
auto mm = opr::MatrixMul::make(x, y);
SymbolVar mm_opt;
unpack_vector(gopt::optimize_for_inference(
{mm}, gopt::OptimizeForInferenceOptions{}
.enable_f16_io_comp()),
mm_opt);
return mm_opt;
};
auto make_f16_graph = [&]() {
auto x = opr::TypeCvt::make(opr::Host2DeviceCopy::make(*graph, host_x),
dtype::Float16());
auto xshp = opr::GetVarShape::make(x);
auto cv = [&x](int v) { return x.make_scalar(v); };
auto sub = [&xshp, &cv](int idx) {
return opr::IndexAt::make(xshp, {{0, cv(idx)}});
};
auto lin = opr::Linspace::make(cv(0), sub(0) - 1, sub(0), {}, {});
lin = opr::TypeCvt::make(lin, dtype::Float16());
auto shp = opr::Concat::make({sub(1), sub(0)}, 0);
auto y = opr::Reshape::make(lin, shp);
auto mm = opr::MatrixMul::make(x, y);
mm = opr::TypeCvt::make(mm, dtype::Float32{});
return mm;
};
auto y_opt = make_f32_to_f16_graph();
auto y = make_f16_graph();
ASSERT_EQ(y_opt.dtype(), dtype::Float32{});
ASSERT_EQ(y.dtype(), dtype::Float32{});
HostTensorND host_y_opt, host_y;
auto func = graph->compile({make_callback_copy(y, host_y),
make_callback_copy(y_opt, host_y_opt)});
func->execute();
MGB_ASSERT_TENSOR_NEAR(host_y, host_y_opt, 1e-3);
}
TEST(TestGoptInference, ConvertFormatNHWCD4) {
// hwcd4 is only supported in naive handle
NaiveMegDNNHandleScope naive_megdnn_handle;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册