提交 8d24f838 编写于 作者: M Megvii Engine Team

fix(imperative): fix setsubtensor dtype_promotion

GitOrigin-RevId: 025a052b6153db0031f1bb1e280cb77b17d45a5d
上级 45c0e40f
...@@ -347,7 +347,7 @@ ValueRefList batch_norm_rule(const OpDef& op, Span<ValueRef> inputs) { ...@@ -347,7 +347,7 @@ ValueRefList batch_norm_rule(const OpDef& op, Span<ValueRef> inputs) {
return imperative::apply(op, inputs); return imperative::apply(op, inputs);
} }
ValueRefList layer_norm_rule(const OpDef& op, Span<ValueRef> inputs) { ValueRefList norm_rule(const OpDef& op, Span<ValueRef> inputs) {
// avoid the amp_dtype_autocast // avoid the amp_dtype_autocast
if (DTypePromoteCfg::amp_dtype_autocast_enabled) { if (DTypePromoteCfg::amp_dtype_autocast_enabled) {
SmallVector<DType> dtypes = get_value_dtypes(inputs); SmallVector<DType> dtypes = get_value_dtypes(inputs);
...@@ -369,13 +369,12 @@ ValueRefList layer_norm_rule(const OpDef& op, Span<ValueRef> inputs) { ...@@ -369,13 +369,12 @@ ValueRefList layer_norm_rule(const OpDef& op, Span<ValueRef> inputs) {
return imperative::apply(op, inputs); return imperative::apply(op, inputs);
} }
ValueRefList group_norm_rule(const OpDef& op, Span<ValueRef> inputs) { ValueRefList naive_promote_rule(const OpDef& op, Span<ValueRef> inputs) {
if (DTypePromoteCfg::amp_dtype_autocast_enabled) {
SmallVector<DType> dtypes = get_value_dtypes(inputs); SmallVector<DType> dtypes = get_value_dtypes(inputs);
ValueRefList converted(inputs.size()); mgb::DType target_dtype = get_promoted_dtype(dtypes);
ValueRefList converted(inputs.size());
for (size_t i = 0; i < inputs.size(); ++i) { for (size_t i = 0; i < inputs.size(); ++i) {
mgb::DType target_dtype = DTypePromoteCfg::amp_high_prec_dtype;
if (dtypes[i] != target_dtype) { if (dtypes[i] != target_dtype) {
converted[i] = imperative::apply( converted[i] = imperative::apply(
ApplyOp(*TypeCvt::make(target_dtype)), inputs[i])[0]; ApplyOp(*TypeCvt::make(target_dtype)), inputs[i])[0];
...@@ -385,23 +384,22 @@ ValueRefList group_norm_rule(const OpDef& op, Span<ValueRef> inputs) { ...@@ -385,23 +384,22 @@ ValueRefList group_norm_rule(const OpDef& op, Span<ValueRef> inputs) {
} }
return imperative::apply(op, converted); return imperative::apply(op, converted);
}
return imperative::apply(op, inputs);
} }
ValueRefList naive_promote_rule(const OpDef& op, Span<ValueRef> inputs) { ValueRefList setsubtensor_rule(const OpDef& op, Span<ValueRef> inputs) {
SmallVector<DType> dtypes = get_value_dtypes(inputs); mgb::DType target_dtype = *(inputs[0].dtype());
mgb::DType target_dtype = get_promoted_dtype(dtypes);
ValueRefList converted(inputs.size()); ValueRefList converted(inputs.size());
for (size_t i = 0; i < inputs.size(); ++i) { converted[0] = inputs[0];
if (dtypes[i] != target_dtype) { if (*(inputs[1].dtype()) != target_dtype) {
converted[i] = imperative::apply( converted[1] =
ApplyOp(*TypeCvt::make(target_dtype)), inputs[i])[0]; imperative::apply(ApplyOp(*TypeCvt::make(target_dtype)), inputs[1])[0];
} else { } else {
converted[i] = inputs[i]; converted[1] = inputs[1];
} }
for (size_t i = 2; i < inputs.size(); i++) {
converted[i] = inputs[i];
} }
return imperative::apply(op, converted); return imperative::apply(op, converted);
...@@ -422,8 +420,10 @@ struct DTypePromoteRuleRegistry { ...@@ -422,8 +420,10 @@ struct DTypePromoteRuleRegistry {
register_dtype_promote_rule<BatchNorm>(batch_norm_rule); register_dtype_promote_rule<BatchNorm>(batch_norm_rule);
register_dtype_promote_rule<Convolution3D>(naive_promote_rule); register_dtype_promote_rule<Convolution3D>(naive_promote_rule);
register_dtype_promote_rule<Convolution3DBackwardData>(naive_promote_rule); register_dtype_promote_rule<Convolution3DBackwardData>(naive_promote_rule);
register_dtype_promote_rule<LayerNorm>(layer_norm_rule); register_dtype_promote_rule<LayerNorm>(norm_rule);
register_dtype_promote_rule<GroupNorm>(group_norm_rule); register_dtype_promote_rule<GroupNorm>(norm_rule);
register_dtype_promote_rule<SetSubtensor>(setsubtensor_rule);
register_dtype_promote_rule<IndexingSetMultiAxisVec>(setsubtensor_rule);
} }
} register_helper; } register_helper;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册