提交 ee984e86 编写于 作者: M Megvii Engine Team

fix(imperative/amp): fix distributed backward callback for nhwc amp

GitOrigin-RevId: 4d725b0ea438d078f1a57c1f58dc707e96f314c5
上级 15c6da62
......@@ -379,8 +379,11 @@ ValueRefList concat_rule(
ValueRefList identity_rule_helper(
const OpDef& op, const Span<ValueRef>& inputs, const FormatTransformation& t) {
// mgb_assert(inputs.size() == 1);
auto& src = inputs[0].cast(t.value_type());
return t.wrap_outputs(imperative::apply(op, t.unwrap_inputs(inputs)), src.format());
if (auto& src = inputs[0].as_ref(t.value_type())) {
return t.wrap_outputs(imperative::apply(op, t.unwrap_inputs(inputs)), src->format());
} else {
return t.wrap_outputs(imperative::apply(op, t.unwrap_inputs(inputs)));
}
}
ValueRefList batchnorm_rule(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册