diff --git a/paddle/fluid/imperative/amp_auto_cast.cc b/paddle/fluid/imperative/amp_auto_cast.cc index 94c6d0a4d569a1ce458ed3590385de446d0ee150..6e8bfbb4a7761009f031cd4da74310cbb6294114 100644 --- a/paddle/fluid/imperative/amp_auto_cast.cc +++ b/paddle/fluid/imperative/amp_auto_cast.cc @@ -273,8 +273,9 @@ static inline std::shared_ptr CastToBF16( template static inline framework::proto::VarType::Type GetPromoteType( - const std::string& op_type, const NameVarMap& ins) { - auto dst_type = framework::proto::VarType::FP16; + const std::string& op_type, const NameVarMap& ins, + const framework::proto::VarType::Type amp_dtype) { + auto dst_type = amp_dtype; for (const auto& pair : ins) { for (const auto& var : pair.second) { if (GetDataType(var) == framework::proto::VarType::FP32) { @@ -337,7 +338,8 @@ NameVarMap AutoCastInputs(const std::string& op_type, } return new_ins; } else { - auto dst_type = GetPromoteType(op_type, ins); + auto dst_type = + GetPromoteType(op_type, ins, framework::proto::VarType::FP16); // NOTE(zhiqiu): if the op has op fp16 kernel, fall back to fp32. if (dst_type == framework::proto::VarType::FP16 && @@ -435,7 +437,7 @@ NameVarMap AutoCastBF16Inputs(const std::string& op_type, } } return new_ins; - } else { + } else if (AmpOperators::Instance().GetMutableBlockOps()->count(op_type)) { for (auto& pair : new_ins) { VLOG(5) << "Op(" << op_type << "): Cast " << pair.first << " from " << GetDtypeStr(*pair.second.cbegin()) << " to float"; @@ -444,6 +446,26 @@ NameVarMap AutoCastBF16Inputs(const std::string& op_type, } } return new_ins; + } else { + auto dst_type = + GetPromoteType(op_type, ins, framework::proto::VarType::BF16); + // NOTE(zhangbo): if the op has no bf16 kernel, fall back to fp32.
+ if (dst_type == framework::proto::VarType::BF16 && + AmpOperators::Instance().GetMutableUnsupportedBf16Ops()->count( + op_type)) { + dst_type = framework::proto::VarType::FP32; + } + for (auto& pair : new_ins) { + VLOG(5) << "Op(" << op_type << "): Cast " << pair.first << " from " + << GetDtypeStr(*pair.second.cbegin()) << " to " + << framework::DataTypeToString(dst_type); + for (auto& var : pair.second) { + var = (dst_type == framework::proto::VarType::FP32 + ? CastToFP32(var) + : CastToBF16(var)); + } + } + return new_ins; } return new_ins; } diff --git a/paddle/fluid/imperative/tracer.cc b/paddle/fluid/imperative/tracer.cc index 03811ac778779c24beb765de118f2d7d00af515b..c832787d9890621f131d3934e0190935aa9d9d24 100644 --- a/paddle/fluid/imperative/tracer.cc +++ b/paddle/fluid/imperative/tracer.cc @@ -205,17 +205,19 @@ void Tracer::TraceOp(const std::string& type, const NameVarMap& ins, NameVarMap new_ins = ins; if (amp_level_ == AmpLevel::O1) { - VLOG(5) << "Auto mixed precision run operator: " << type; if (amp_dtype_ == phi::DataType::FLOAT16) { + VLOG(5) << "Float16 Auto Mixed Precision O1 run operator: " << type; new_ins = AutoCastInputs(type, ins); } else if (amp_dtype_ == phi::DataType::BFLOAT16) { + VLOG(5) << "BFloat16 Auto Mixed Precision O1 run operator: " << type; new_ins = AutoCastBF16Inputs(type, ins); } } else if (amp_level_ == AmpLevel::O2) { - VLOG(5) << "Pure fp16 run operator: " << type; if (amp_dtype_ == phi::DataType::FLOAT16) { + VLOG(5) << "Float16 Auto Mixed Precision O2 run operator: " << type; new_ins = CastPureFp16Inputs(type, ins); } else if (amp_dtype_ == phi::DataType::BFLOAT16) { + VLOG(5) << "BFloat16 Auto Mixed Precision O2 run operator: " << type; new_ins = CastPureBf16Inputs(type, ins); } } diff --git a/python/paddle/fluid/dygraph/amp/auto_cast.py b/python/paddle/fluid/dygraph/amp/auto_cast.py index 8230e4bbd777473dda2d4a654ffb93da3c14f91c..f43a51063b00ac0439aacfbf46ff593e7b1b4f43 100644 --- 
a/python/paddle/fluid/dygraph/amp/auto_cast.py +++ b/python/paddle/fluid/dygraph/amp/auto_cast.py @@ -75,7 +75,7 @@ PURE_FP16_BLACK_LIST = { 'lookup_table', 'lookup_table_v2', 'scatter', 'scatter_grad' } -BF16_WHITE_LIST = {'conv2d'} +BF16_WHITE_LIST = {'conv2d', 'matmul_v2'} BF16_BLACK_LIST = {' '} _g_amp_state_ = None diff --git a/python/paddle/fluid/tests/unittests/test_imperative_auto_mixed_precision.py b/python/paddle/fluid/tests/unittests/test_imperative_auto_mixed_precision.py index 67c4bb3b2c7464b48218e1093fa87d9ec337385b..5cb72512f99af7b4948e9fe4c01e9b993c1e247e 100644 --- a/python/paddle/fluid/tests/unittests/test_imperative_auto_mixed_precision.py +++ b/python/paddle/fluid/tests/unittests/test_imperative_auto_mixed_precision.py @@ -1131,20 +1131,29 @@ class TestBf16(unittest.TestCase): test amp for BF16 ''' - def train(self, enable_amp=True): + def train(self, enable_amp=True, amp_level='O1'): paddle.seed(100) input = paddle.uniform((2, 4, 8, 8), dtype='float32', min=-1., max=1.) conv = paddle.nn.Conv2D(4, 6, (3, 3)) with paddle.amp.auto_cast( - enable=enable_amp, level='O2', dtype='bfloat16'): + enable=enable_amp, level=amp_level, dtype='bfloat16'): output = conv(input) output = output.cast('float32') return output.numpy() def test_bf16(self): - out_fp32 = self.train(enable_amp=False) - out_bf16 = self.train(enable_amp=True) - self.assertTrue(np.allclose(out_fp32, out_bf16, rtol=1.e-3, atol=1.e-1)) + if fluid.core.is_compiled_with_cuda(): + cudnn_version = paddle.device.get_cudnn_version() + if cudnn_version is not None and cudnn_version >= 8100: + out_fp32 = self.train(enable_amp=False) + out_bf16_O1 = self.train(enable_amp=True, amp_level='O1') + out_bf16_O2 = self.train(enable_amp=True, amp_level='O2') + self.assertTrue( + np.allclose( + out_fp32, out_bf16_O1, rtol=1.e-3, atol=1.e-1)) + self.assertTrue( + np.allclose( + out_fp32, out_bf16_O2, rtol=1.e-3, atol=1.e-1)) class TestPyLayerWithAmp(unittest.TestCase):