From e64536a31edba544bb3fbdd19457600aa9322f1f Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Fri, 11 Mar 2022 18:30:32 +0800 Subject: [PATCH] fix(imperative): fix the dtype promote problem when amp GitOrigin-RevId: 43e1035fc86bc5d9212f5de00e4cce347940a7dd --- imperative/python/src/transformation.h | 2 +- .../src/impl/transformations/dtype_promote.cpp | 14 ++++++-------- 2 files changed, 7 insertions(+), 9 deletions(-) diff --git a/imperative/python/src/transformation.h b/imperative/python/src/transformation.h index 84bd6c499..e02135adb 100644 --- a/imperative/python/src/transformation.h +++ b/imperative/python/src/transformation.h @@ -25,8 +25,8 @@ namespace mgb::imperative::python { struct TransformationManager { enum Segment { ModuleTrace, - Grad, DTypePromote, + Grad, Scalar, Trace, Eval, diff --git a/imperative/src/impl/transformations/dtype_promote.cpp b/imperative/src/impl/transformations/dtype_promote.cpp index 5036c7276..4cdcb0b15 100644 --- a/imperative/src/impl/transformations/dtype_promote.cpp +++ b/imperative/src/impl/transformations/dtype_promote.cpp @@ -218,17 +218,15 @@ ValueRefList convolution_backward_rule(const OpDef& op, Span inputs) { ValueRefList batch_norm_rule(const OpDef& op, Span inputs) { if (DTypePromoteCfg::amp_dtype_autocast_enabled) { mgb_assert(inputs.size() > 0); - + SmallVector dtypes = get_value_dtypes(inputs); ValueRefList converted(inputs.size()); - converted[0] = imperative::apply( - ApplyOp(*TypeCvt::make(dtype::Float16())), inputs[0])[0]; - for (size_t i = 1; i < inputs.size(); ++i) { - DType idtype = *(inputs[i].dtype()); - if (idtype != DTypePromoteCfg::amp_high_prec_dtype) { + for (size_t i = 0; i < inputs.size(); ++i) { + mgb::DType target_dtype = i == 0 ? DTypePromoteCfg::amp_low_prec_dtype + : DTypePromoteCfg::amp_high_prec_dtype; + if (dtypes[i] != target_dtype) { converted[i] = imperative::apply( - ApplyOp(*TypeCvt::make(DTypePromoteCfg::amp_high_prec_dtype)), - inputs[i])[0]; + ApplyOp(*TypeCvt::make(target_dtype)), inputs[i])[0]; } else { converted[i] = inputs[i]; } -- GitLab