未验证 提交 18ee051e 编写于 作者: Z zhangbo9674 提交者: GitHub

[bf16] Refine BF16 amp-o1 logic (#39815)

* refine bf16 amp-o1 logic

* refine amp GLOG

* refine unittest

* refine unittest
上级 d1595c26
......@@ -273,8 +273,9 @@ static inline std::shared_ptr<VarType> CastToBF16(
template <typename VarType>
static inline framework::proto::VarType::Type GetPromoteType(
const std::string& op_type, const NameVarMap<VarType>& ins) {
auto dst_type = framework::proto::VarType::FP16;
const std::string& op_type, const NameVarMap<VarType>& ins,
const framework::proto::VarType::Type amp_dtype) {
auto dst_type = amp_dtype;
for (const auto& pair : ins) {
for (const auto& var : pair.second) {
if (GetDataType<VarType>(var) == framework::proto::VarType::FP32) {
......@@ -337,7 +338,8 @@ NameVarMap<VarType> AutoCastInputs(const std::string& op_type,
}
return new_ins;
} else {
auto dst_type = GetPromoteType<VarType>(op_type, ins);
auto dst_type =
GetPromoteType<VarType>(op_type, ins, framework::proto::VarType::FP16);
// NOTE(zhiqiu): if the op has op fp16 kernel, fall back to fp32.
if (dst_type == framework::proto::VarType::FP16 &&
......@@ -435,7 +437,7 @@ NameVarMap<VarType> AutoCastBF16Inputs(const std::string& op_type,
}
}
return new_ins;
} else {
} else if (AmpOperators::Instance().GetMutableBlockOps()->count(op_type)) {
for (auto& pair : new_ins) {
VLOG(5) << "Op(" << op_type << "): Cast " << pair.first << " from "
<< GetDtypeStr(*pair.second.cbegin()) << " to float";
......@@ -444,6 +446,26 @@ NameVarMap<VarType> AutoCastBF16Inputs(const std::string& op_type,
}
}
return new_ins;
} else {
auto dst_type =
GetPromoteType<VarType>(op_type, ins, framework::proto::VarType::BF16);
// NOTE(zhangbo): if the op has op fp16 kernel, fall back to fp32.
if (dst_type == framework::proto::VarType::BF16 &&
AmpOperators::Instance().GetMutableUnsupportedBf16Ops()->count(
op_type)) {
dst_type = framework::proto::VarType::FP32;
}
for (auto& pair : new_ins) {
VLOG(5) << "Op(" << op_type << "): Cast " << pair.first << " from "
<< GetDtypeStr(*pair.second.cbegin()) << " to "
<< framework::DataTypeToString(dst_type);
for (auto& var : pair.second) {
var = (dst_type == framework::proto::VarType::FP32
? CastToFP32<VarType>(var)
: CastToBF16<VarType>(var));
}
}
return new_ins;
}
return new_ins;
}
......
......@@ -205,17 +205,19 @@ void Tracer::TraceOp(const std::string& type, const NameVarMap<VarType>& ins,
NameVarMap<VarType> new_ins = ins;
if (amp_level_ == AmpLevel::O1) {
VLOG(5) << "Auto mixed precision run operator: " << type;
if (amp_dtype_ == phi::DataType::FLOAT16) {
VLOG(5) << "Float16 Auto Mixed Precision O1 run operator: " << type;
new_ins = AutoCastInputs<VarType>(type, ins);
} else if (amp_dtype_ == phi::DataType::BFLOAT16) {
VLOG(5) << "BFloat16 Auto Mixed Precision O1 run operator: " << type;
new_ins = AutoCastBF16Inputs<VarType>(type, ins);
}
} else if (amp_level_ == AmpLevel::O2) {
VLOG(5) << "Pure fp16 run operator: " << type;
if (amp_dtype_ == phi::DataType::FLOAT16) {
VLOG(5) << "Float16 Auto Mixed Precision O2 run operator: " << type;
new_ins = CastPureFp16Inputs<VarType>(type, ins);
} else if (amp_dtype_ == phi::DataType::BFLOAT16) {
VLOG(5) << "BFloat16 Auto Mixed Precision O2 run operator: " << type;
new_ins = CastPureBf16Inputs<VarType>(type, ins);
}
}
......
......@@ -75,7 +75,7 @@ PURE_FP16_BLACK_LIST = {
'lookup_table', 'lookup_table_v2', 'scatter', 'scatter_grad'
}
BF16_WHITE_LIST = {'conv2d'}
BF16_WHITE_LIST = {'conv2d', 'matmul_v2'}
BF16_BLACK_LIST = {' '}
_g_amp_state_ = None
......
......@@ -1131,20 +1131,29 @@ class TestBf16(unittest.TestCase):
test amp for BF16
'''
def train(self, enable_amp=True):
def train(self, enable_amp=True, amp_level='O1'):
paddle.seed(100)
input = paddle.uniform((2, 4, 8, 8), dtype='float32', min=-1., max=1.)
conv = paddle.nn.Conv2D(4, 6, (3, 3))
with paddle.amp.auto_cast(
enable=enable_amp, level='O2', dtype='bfloat16'):
enable=enable_amp, level=amp_level, dtype='bfloat16'):
output = conv(input)
output = output.cast('float32')
return output.numpy()
def test_bf16(self):
if fluid.core.is_compiled_with_cuda():
cudnn_version = paddle.device.get_cudnn_version()
if cudnn_version is not None and cudnn_version >= 8100:
out_fp32 = self.train(enable_amp=False)
out_bf16 = self.train(enable_amp=True)
self.assertTrue(np.allclose(out_fp32, out_bf16, rtol=1.e-3, atol=1.e-1))
out_bf16_O1 = self.train(enable_amp=True, amp_level='O1')
out_bf16_O2 = self.train(enable_amp=True, amp_level='O2')
self.assertTrue(
np.allclose(
out_fp32, out_bf16_O1, rtol=1.e-3, atol=1.e-1))
self.assertTrue(
np.allclose(
out_fp32, out_bf16_O2, rtol=1.e-3, atol=1.e-1))
class TestPyLayerWithAmp(unittest.TestCase):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册