未验证 提交 23ab01e3 编写于 作者: R Roc 提交者: GitHub

Dynamic amp support sync_batch_norm op (#32770)

上级 beab9563
......@@ -160,7 +160,8 @@ NameVarBaseMap AutoCastInputs(const std::string& op_type,
if (AmpOperators::Instance().GetMutableAllowOps()->count(op_type)) {
for (auto& pair : new_ins) {
// NOTE(zhiqiu): batch_norm and layer_norm support only input x is fp16.
if ((op_type == "batch_norm" || op_type == "layer_norm") &&
if ((op_type == "batch_norm" || op_type == "layer_norm" ||
op_type == "sync_batch_norm") &&
pair.first != "X") {
continue;
}
......@@ -191,7 +192,8 @@ NameVarBaseMap AutoCastInputs(const std::string& op_type,
}
for (auto& pair : new_ins) {
// NOTE(zhiqiu): batch_norm and layer_norm support only input x is fp16.
if ((op_type == "batch_norm" || op_type == "layer_norm") &&
if ((op_type == "batch_norm" || op_type == "layer_norm" ||
op_type == "sync_batch_norm") &&
pair.first == "X" && dst_type == framework::proto::VarType::FP32) {
continue;
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册