diff --git a/imperative/python/src/transformation.h b/imperative/python/src/transformation.h
index 84bd6c4997f68686099854dce94179ae70a5bbe7..e02135adb2bece4d33e27698e38be5e49bb55ae4 100644
--- a/imperative/python/src/transformation.h
+++ b/imperative/python/src/transformation.h
@@ -25,8 +25,8 @@ namespace mgb::imperative::python {
 struct TransformationManager {
     enum Segment {
         ModuleTrace,
-        Grad,
         DTypePromote,
+        Grad,
         Scalar,
         Trace,
         Eval,
diff --git a/imperative/src/impl/transformations/dtype_promote.cpp b/imperative/src/impl/transformations/dtype_promote.cpp
index 5036c7276339575173001194b28b2a1da13a089b..4cdcb0b152015114701974481acfd53bca82097d 100644
--- a/imperative/src/impl/transformations/dtype_promote.cpp
+++ b/imperative/src/impl/transformations/dtype_promote.cpp
@@ -218,17 +218,15 @@ ValueRefList convolution_backward_rule(const OpDef& op, Span<ValueRef> inputs) {
 ValueRefList batch_norm_rule(const OpDef& op, Span<ValueRef> inputs) {
     if (DTypePromoteCfg::amp_dtype_autocast_enabled) {
         mgb_assert(inputs.size() > 0);
-
+        SmallVector<DType> dtypes = get_value_dtypes(inputs);
         ValueRefList converted(inputs.size());
-        converted[0] = imperative::apply(
-                ApplyOp(*TypeCvt::make(dtype::Float16())), inputs[0])[0];
-        for (size_t i = 1; i < inputs.size(); ++i) {
-            DType idtype = *(inputs[i].dtype());
-            if (idtype != DTypePromoteCfg::amp_high_prec_dtype) {
+        for (size_t i = 0; i < inputs.size(); ++i) {
+            mgb::DType target_dtype = i == 0 ? DTypePromoteCfg::amp_low_prec_dtype
+                                             : DTypePromoteCfg::amp_high_prec_dtype;
+            if (dtypes[i] != target_dtype) {
                 converted[i] = imperative::apply(
-                        ApplyOp(*TypeCvt::make(DTypePromoteCfg::amp_high_prec_dtype)),
-                        inputs[i])[0];
+                        ApplyOp(*TypeCvt::make(target_dtype)), inputs[i])[0];
             } else {
                 converted[i] = inputs[i];
             }
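
For context, a minimal standalone sketch of the dtype-selection rule the second hunk introduces: input 0 is cast to `DTypePromoteCfg::amp_low_prec_dtype` while every remaining input is cast to `amp_high_prec_dtype`, and inputs that already match their target dtype are passed through unchanged. This is an illustrative sketch only, not part of the patch; the `DType` enum, the dtype names, and the config values below are simplified stand-ins for the MegEngine types.

```cpp
// Sketch of the per-input dtype selection performed by the patched loop in
// batch_norm_rule. All types here are hypothetical stand-ins, not MegEngine's.
#include <cstddef>
#include <iostream>
#include <string>
#include <vector>

enum class DType { Float16, Float32 };

static std::string to_string(DType d) {
    return d == DType::Float16 ? "float16" : "float32";
}

int main() {
    // Stand-ins for DTypePromoteCfg::amp_low_prec_dtype / amp_high_prec_dtype.
    const DType amp_low_prec_dtype = DType::Float16;
    const DType amp_high_prec_dtype = DType::Float32;

    // Stand-ins for the batch-norm inputs' dtypes.
    std::vector<DType> input_dtypes = {
            DType::Float32, DType::Float32, DType::Float16, DType::Float32};

    for (size_t i = 0; i < input_dtypes.size(); ++i) {
        // Same selection as the patched loop: only input 0 goes to the
        // low-precision dtype; all other inputs target high precision.
        DType target = (i == 0) ? amp_low_prec_dtype : amp_high_prec_dtype;
        if (input_dtypes[i] != target) {
            std::cout << "input " << i << ": convert " << to_string(input_dtypes[i])
                      << " -> " << to_string(target) << "\n";
        } else {
            std::cout << "input " << i << ": keep " << to_string(target) << "\n";
        }
    }
    return 0;
}
```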