提交 e64536a3 编写于 作者: M Megvii Engine Team

fix(imperative): fix the dtype promotion problem when AMP is enabled

GitOrigin-RevId: 43e1035fc86bc5d9212f5de00e4cce347940a7dd
上级 2b80806f
...@@ -25,8 +25,8 @@ namespace mgb::imperative::python { ...@@ -25,8 +25,8 @@ namespace mgb::imperative::python {
struct TransformationManager { struct TransformationManager {
enum Segment { enum Segment {
ModuleTrace, ModuleTrace,
Grad,
DTypePromote, DTypePromote,
Grad,
Scalar, Scalar,
Trace, Trace,
Eval, Eval,
......
...@@ -218,17 +218,15 @@ ValueRefList convolution_backward_rule(const OpDef& op, Span<ValueRef> inputs) { ...@@ -218,17 +218,15 @@ ValueRefList convolution_backward_rule(const OpDef& op, Span<ValueRef> inputs) {
ValueRefList batch_norm_rule(const OpDef& op, Span<ValueRef> inputs) { ValueRefList batch_norm_rule(const OpDef& op, Span<ValueRef> inputs) {
if (DTypePromoteCfg::amp_dtype_autocast_enabled) { if (DTypePromoteCfg::amp_dtype_autocast_enabled) {
mgb_assert(inputs.size() > 0); mgb_assert(inputs.size() > 0);
SmallVector<DType> dtypes = get_value_dtypes(inputs);
ValueRefList converted(inputs.size()); ValueRefList converted(inputs.size());
converted[0] = imperative::apply(
ApplyOp(*TypeCvt::make(dtype::Float16())), inputs[0])[0];
for (size_t i = 1; i < inputs.size(); ++i) { for (size_t i = 0; i < inputs.size(); ++i) {
DType idtype = *(inputs[i].dtype()); mgb::DType target_dtype = i == 0 ? DTypePromoteCfg::amp_low_prec_dtype
if (idtype != DTypePromoteCfg::amp_high_prec_dtype) { : DTypePromoteCfg::amp_high_prec_dtype;
if (dtypes[i] != target_dtype) {
converted[i] = imperative::apply( converted[i] = imperative::apply(
ApplyOp(*TypeCvt::make(DTypePromoteCfg::amp_high_prec_dtype)), ApplyOp(*TypeCvt::make(target_dtype)), inputs[i])[0];
inputs[i])[0];
} else { } else {
converted[i] = inputs[i]; converted[i] = inputs[i];
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册