diff --git a/paddle/fluid/pybind/pybind.cc b/paddle/fluid/pybind/pybind.cc
index 9e3f99aedaf54ab024cc59cb1e106c7139aadc14..d4e7345ec81d270d41353c409e5a5cef4e1800ad 100644
--- a/paddle/fluid/pybind/pybind.cc
+++ b/paddle/fluid/pybind/pybind.cc
@@ -2595,7 +2595,17 @@ All parameter, weight, gradient are variables in Paddle.
       [] { return phi::autotune::AutoTuneStatus::Instance().Update(); });
 
   m.def("get_low_precision_op_list", [] {
-    return phi::KernelFactory::Instance().GetLowPrecisionKernelList();
+    py::dict op_list;
+    auto list_op = phi::KernelFactory::Instance().GetLowPrecisionKernelList();
+    for (auto iter = list_op.begin(); iter != list_op.end(); iter++) {
+      auto op_name = (iter->first).c_str();
+      auto counts = iter->second;
+      op_list[op_name] = std::to_string(counts.fp16_called_) + "," +
+                         std::to_string(counts.bf16_called_) + "," +
+                         std::to_string(counts.fp32_called_) + "," +
+                         std::to_string(counts.other_called_);
+    }
+    return op_list;
   });
 
   m.def("autotune_status", [] {
diff --git a/paddle/phi/core/kernel_factory.cc b/paddle/phi/core/kernel_factory.cc
index 0809cfab3f702c86d84a1b9e2aec6425b49219e4..7c15d60414673cb2c3fb9e9f866dc5d64057fe52 100644
--- a/paddle/phi/core/kernel_factory.cc
+++ b/paddle/phi/core/kernel_factory.cc
@@ -115,18 +115,25 @@ void KernelFactory::AddToLowPrecisionKernelList(
     if (op_name.find("_grad") != std::string::npos) {
      return;  // only record forward api
     }
-    bool is_low_precision =
-        (kernel_key_type == paddle::experimental::DataType::FLOAT16 ||
-         kernel_key_type == paddle::experimental::DataType::BFLOAT16);
-    bool need_record =
-        FLAGS_low_precision_op_list == 1 ? is_low_precision : true;
-    if (need_record) {
-      low_precision_kernels_[op_name] += 1;
+
+    if (low_precision_kernels_.find(op_name) == low_precision_kernels_.end()) {
+      auto count = OpCount();
+      low_precision_kernels_[op_name] = count;
+    }
+    if (kernel_key_type == paddle::experimental::DataType::FLOAT16) {
+      low_precision_kernels_[op_name].fp16_called_ += 1;
+    } else if (kernel_key_type == paddle::experimental::DataType::BFLOAT16) {
+      low_precision_kernels_[op_name].bf16_called_ += 1;
+    } else if (kernel_key_type == paddle::experimental::DataType::FLOAT32) {
+      low_precision_kernels_[op_name].fp32_called_ += 1;
+    } else {
+      low_precision_kernels_[op_name].other_called_ += 1;
     }
   }
 }
 
-std::map<const std::string, int> KernelFactory::GetLowPrecisionKernelList() {
+std::map<const std::string, OpCount>
+KernelFactory::GetLowPrecisionKernelList() {
   return low_precision_kernels_;
 }
diff --git a/paddle/phi/core/kernel_factory.h b/paddle/phi/core/kernel_factory.h
index a106ac727c5d0dde0fed1a339057edf146b19d42..8b8eb8fd0d958d0c686ed2ef4ccbfea124c52415 100644
--- a/paddle/phi/core/kernel_factory.h
+++ b/paddle/phi/core/kernel_factory.h
@@ -34,6 +34,19 @@ namespace phi {
 
 using DataType = paddle::experimental::DataType;
 
+struct OpCount {
+  OpCount() {
+    fp16_called_ = 0;
+    bf16_called_ = 0;
+    fp32_called_ = 0;
+    other_called_ = 0;
+  }
+  int fp16_called_;
+  int bf16_called_;
+  int fp32_called_;
+  int other_called_;
+};
+
 /**
  * [ Naming considerations ]
  *
@@ -309,7 +322,7 @@ class KernelFactory {
       const std::string& name,
       const paddle::experimental::DataType& kernel_key_type);
 
-  std::map<const std::string, int> GetLowPrecisionKernelList();
+  std::map<const std::string, OpCount> GetLowPrecisionKernelList();
 
  private:
   KernelFactory() = default;
@@ -317,7 +330,7 @@
   KernelNameMap kernels_;
 
   // Get the low precision kernel list of current module.
-  std::map<const std::string, int> low_precision_kernels_;
+  std::map<const std::string, OpCount> low_precision_kernels_;
 };
 
 inline std::ostream& operator<<(std::ostream& os, const KernelKey& kernel_key) {
diff --git a/python/paddle/amp/auto_cast.py b/python/paddle/amp/auto_cast.py
index 6c8ddbd579359c2f84e07771fa9c05a3aa03ab91..bcba245178d429460509d7b073f28c003a313178 100644
--- a/python/paddle/amp/auto_cast.py
+++ b/python/paddle/amp/auto_cast.py
@@ -97,21 +97,29 @@ _g_amp_state_ = None
 def low_precision_op_list():
     if os.getenv("FLAGS_low_precision_op_list") is not None:
         level = int(os.getenv("FLAGS_low_precision_op_list"))
-        if level == 0:
-            return
-        if level == 1:
-            print('<{:-^60}>'.format(" low precision op list "))
-        else:
-            print('<{:-^60}>'.format(" op list "))
+        print('<{:-^120}>'.format(" op list "))
         op_list = paddle.fluid.core.get_low_precision_op_list()
         op_count = 0
         print(
-            '<{:-^40}'.format(" op_name "), '|', '{:-^17}>'.format(" op count ")
+            '<{:-^40}'.format(" Op Name "),
+            '|',
+            '{:-^17}'.format("FP16 Calls"),
+            '|',
+            '{:-^17}'.format("BF16 Calls"),
+            '|',
+            '{:-^17}'.format('FP32 Calls'),
+            '|',
+            '{:-^17}>'.format('Other Calls'),
         )
         for x in op_list:
-            print(' %-40s| %-15d' % (x, op_list[x]))
+            # fp16, bf16, fp32, other
+            called = op_list[x].split(",")
+            print(
+                ' %-40s| %-17s| %-17s| %-17s| %-17s'
+                % (x, called[0], called[1], called[2], called[3])
+            )
             op_count += 1
-        print('<{:-^60}>'.format(" op count: " + str(op_count) + " "))
+        print('<{:-^120}>'.format(" op count: " + str(op_count) + " "))
 
 
 def amp_state():
diff --git a/python/paddle/fluid/tests/unittests/test_low_precision_list.py b/python/paddle/fluid/tests/unittests/test_low_precision_list.py
index 0641a21be6354ae09007ff051ccef57edb01e72c..c8dea508999ea1935c62c5c0a163a1b9077565d8 100644
--- a/python/paddle/fluid/tests/unittests/test_low_precision_list.py
+++ b/python/paddle/fluid/tests/unittests/test_low_precision_list.py
@@ -30,12 +30,24 @@ class TestAMPList(unittest.TestCase):
             c = a + b
         paddle.amp.low_precision_op_list()
         op_list = paddle.fluid.core.get_low_precision_op_list()
-        if conv.dtype == paddle.float16:
-            self.assertTrue('elementwise_add' in op_list)
-            self.assertTrue('conv2d' in op_list)
-            self.assertTrue(2 == len(op_list))
-        else:
-            self.assertTrue(0 == len(op_list))
+
+        self.assertTrue('elementwise_add' in op_list)
+        self.assertTrue('conv2d' in op_list)
+
+        conv2d_called = op_list['conv2d'].split(',')
+        add_called = op_list['elementwise_add'].split(',')
+        add_num = 0
+        conv_num = 0
+        for i in range(4):
+            add_num += int(add_called[i])
+            conv_num += int(conv2d_called[i])
+
+        self.assertTrue(conv_num == 1)
+        self.assertTrue(add_num == 1)
+
+        if conv.dtype == paddle.float16:
+            self.assertTrue(int(conv2d_called[0]) == 1)
+            self.assertTrue(int(add_called[0]) == 1)
 
 
 if __name__ == "__main__":
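
For reference, below is a minimal usage sketch of the per-dtype counters introduced by this patch, modeled on the unit test above. The API names (FLAGS_low_precision_op_list, paddle.amp.low_precision_op_list, paddle.fluid.core.get_low_precision_op_list) come from the diff; the layer shapes, the environment-variable setup, and the expectation of running on a device with FP16 kernels are assumptions, and the exact counts depend on which kernels AMP actually selects.

# Sketch only: enable recording before importing paddle so the flag is visible
# both to KernelFactory (C++ side) and to paddle.amp.low_precision_op_list().
import os

os.environ["FLAGS_low_precision_op_list"] = "1"

import paddle

conv = paddle.nn.Conv2D(3, 2, 3)
x = paddle.rand([10, 3, 32, 32])

with paddle.amp.auto_cast(enable=True):
    y = conv(x)
    z = y + paddle.rand(y.shape)

# Prints the new five-column table: Op Name | FP16 | BF16 | FP32 | Other Calls.
paddle.amp.low_precision_op_list()

# The raw dict maps op name -> "fp16,bf16,fp32,other", as assembled in pybind.cc.
op_list = paddle.fluid.core.get_low_precision_op_list()
for op_name, counts in op_list.items():
    fp16, bf16, fp32, other = (int(n) for n in counts.split(","))
    print(op_name, fp16, bf16, fp32, other)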