From 395520f1ecc1bb0081c6bec8f474cac2d50a0c50 Mon Sep 17 00:00:00 2001
From: niuliling123 <51102941+niuliling123@users.noreply.github.com>
Date: Wed, 11 Jan 2023 14:56:09 +0800
Subject: [PATCH] Update the style of print for low precision op list (#49648)

---
 paddle/fluid/pybind/pybind.cc                 | 12 ++++++++-
 paddle/phi/core/kernel_factory.cc             | 23 ++++++++++------
 paddle/phi/core/kernel_factory.h              | 17 ++++++++++--
 python/paddle/amp/auto_cast.py                | 26 ++++++++++++-------
 .../unittests/test_low_precision_list.py      | 24 ++++++++++++-----
 5 files changed, 76 insertions(+), 26 deletions(-)

diff --git a/paddle/fluid/pybind/pybind.cc b/paddle/fluid/pybind/pybind.cc
index 9e3f99aeda..d4e7345ec8 100644
--- a/paddle/fluid/pybind/pybind.cc
+++ b/paddle/fluid/pybind/pybind.cc
@@ -2595,7 +2595,17 @@ All parameter, weight, gradient are variables in Paddle.
       [] { return phi::autotune::AutoTuneStatus::Instance().Update(); });
 
   m.def("get_low_precision_op_list", [] {
-    return phi::KernelFactory::Instance().GetLowPrecisionKernelList();
+    py::dict op_list;
+    auto list_op = phi::KernelFactory::Instance().GetLowPrecisionKernelList();
+    for (auto iter = list_op.begin(); iter != list_op.end(); iter++) {
+      auto op_name = (iter->first).c_str();
+      auto counts = iter->second;
+      op_list[op_name] = std::to_string(counts.fp16_called_) + "," +
+                         std::to_string(counts.bf16_called_) + "," +
+                         std::to_string(counts.fp32_called_) + "," +
+                         std::to_string(counts.other_called_);
+    }
+    return op_list;
   });
 
   m.def("autotune_status", [] {
diff --git a/paddle/phi/core/kernel_factory.cc b/paddle/phi/core/kernel_factory.cc
index 0809cfab3f..7c15d60414 100644
--- a/paddle/phi/core/kernel_factory.cc
+++ b/paddle/phi/core/kernel_factory.cc
@@ -115,18 +115,25 @@ void KernelFactory::AddToLowPrecisionKernelList(
     if (op_name.find("_grad") != std::string::npos) {
       return;  // only record forward api
     }
-    bool is_low_precision =
-        (kernel_key_type == paddle::experimental::DataType::FLOAT16 ||
-         kernel_key_type == paddle::experimental::DataType::BFLOAT16);
-    bool need_record =
-        FLAGS_low_precision_op_list == 1 ? is_low_precision : true;
-    if (need_record) {
-      low_precision_kernels_[op_name] += 1;
+
+    if (low_precision_kernels_.find(op_name) == low_precision_kernels_.end()) {
+      auto count = OpCount();
+      low_precision_kernels_[op_name] = count;
+    }
+    if (kernel_key_type == paddle::experimental::DataType::FLOAT16) {
+      low_precision_kernels_[op_name].fp16_called_ += 1;
+    } else if (kernel_key_type == paddle::experimental::DataType::BFLOAT16) {
+      low_precision_kernels_[op_name].bf16_called_ += 1;
+    } else if (kernel_key_type == paddle::experimental::DataType::FLOAT32) {
+      low_precision_kernels_[op_name].fp32_called_ += 1;
+    } else {
+      low_precision_kernels_[op_name].other_called_ += 1;
     }
   }
 }
 
-std::map<const std::string, int> KernelFactory::GetLowPrecisionKernelList() {
+std::map<const std::string, OpCount>
+KernelFactory::GetLowPrecisionKernelList() {
   return low_precision_kernels_;
 }
 
diff --git a/paddle/phi/core/kernel_factory.h b/paddle/phi/core/kernel_factory.h
index a106ac727c..8b8eb8fd0d 100644
--- a/paddle/phi/core/kernel_factory.h
+++ b/paddle/phi/core/kernel_factory.h
@@ -34,6 +34,19 @@ namespace phi {
 
 using DataType = paddle::experimental::DataType;
 
+struct OpCount {
+  OpCount() {
+    fp16_called_ = 0;
+    bf16_called_ = 0;
+    fp32_called_ = 0;
+    other_called_ = 0;
+  }
+  int fp16_called_;
+  int bf16_called_;
+  int fp32_called_;
+  int other_called_;
+};
+
 /**
  * [ Naming considerations ]
  *
@@ -309,7 +322,7 @@
       const std::string& name,
       const paddle::experimental::DataType& kernel_key_type);
 
-  std::map<const std::string, int> GetLowPrecisionKernelList();
+  std::map<const std::string, OpCount> GetLowPrecisionKernelList();
 
  private:
   KernelFactory() = default;
@@ -317,7 +330,7 @@
   KernelNameMap kernels_;
 
   // Get the low precision kernel list of current module.
-  std::map<const std::string, int> low_precision_kernels_;
+  std::map<const std::string, OpCount> low_precision_kernels_;
 };
 
 inline std::ostream& operator<<(std::ostream& os, const KernelKey& kernel_key) {
diff --git a/python/paddle/amp/auto_cast.py b/python/paddle/amp/auto_cast.py
index 6c8ddbd579..bcba245178 100644
--- a/python/paddle/amp/auto_cast.py
+++ b/python/paddle/amp/auto_cast.py
@@ -97,21 +97,29 @@ _g_amp_state_ = None
 def low_precision_op_list():
     if os.getenv("FLAGS_low_precision_op_list") is not None:
         level = int(os.getenv("FLAGS_low_precision_op_list"))
-        if level == 0:
-            return
-        if level == 1:
-            print('<{:-^60}>'.format(" low precision op list "))
-        else:
-            print('<{:-^60}>'.format(" op list "))
+        print('<{:-^120}>'.format(" op list "))
         op_list = paddle.fluid.core.get_low_precision_op_list()
         op_count = 0
         print(
-            '<{:-^40}'.format(" op_name "), '|', '{:-^17}>'.format(" op count ")
+            '<{:-^40}'.format(" Op Name "),
+            '|',
+            '{:-^17}'.format("FP16 Calls"),
+            '|',
+            '{:-^17}'.format("BF16 Calls"),
+            '|',
+            '{:-^17}'.format('FP32 Calls'),
+            '|',
+            '{:-^17}>'.format('Other Calls'),
         )
         for x in op_list:
-            print('  %-40s|  %-15d' % (x, op_list[x]))
+            # fp16, bf16, fp32, other
+            called = op_list[x].split(",")
+            print(
+                '  %-40s|  %-17s|  %-17s|  %-17s|  %-17s'
+                % (x, called[0], called[1], called[2], called[3])
+            )
             op_count += 1
-        print('<{:-^60}>'.format(" op count: " + str(op_count) + " "))
+        print('<{:-^120}>'.format(" op count: " + str(op_count) + " "))
 
 
 def amp_state():
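Note on the binding above: get_low_precision_op_list() now returns a dict whose values are the string "fp16,bf16,fp32,other" assembled in pybind.cc from the OpCount fields, which is why the Python side splits on commas. A minimal sketch of consuming that dict directly (variable names here are illustrative, not part of the patch):

    import paddle

    # Each value is a comma-joined string of four per-dtype call counters,
    # in the order fp16, bf16, fp32, other (see pybind.cc above).
    op_list = paddle.fluid.core.get_low_precision_op_list()
    for op_name, counts in op_list.items():
        fp16, bf16, fp32, other = (int(c) for c in counts.split(","))
        print(f"{op_name}: fp16={fp16} bf16={bf16} fp32={fp32} other={other}")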
diff --git a/python/paddle/fluid/tests/unittests/test_low_precision_list.py b/python/paddle/fluid/tests/unittests/test_low_precision_list.py
index 0641a21be6..c8dea50899 100644
--- a/python/paddle/fluid/tests/unittests/test_low_precision_list.py
+++ b/python/paddle/fluid/tests/unittests/test_low_precision_list.py
@@ -30,12 +30,24 @@ class TestAMPList(unittest.TestCase):
             c = a + b
         paddle.amp.low_precision_op_list()
         op_list = paddle.fluid.core.get_low_precision_op_list()
-        if conv.dtype == paddle.float16:
-            self.assertTrue('elementwise_add' in op_list)
-            self.assertTrue('conv2d' in op_list)
-            self.assertTrue(2 == len(op_list))
-        else:
-            self.assertTrue(0 == len(op_list))
+
+        self.assertTrue('elementwise_add' in op_list)
+        self.assertTrue('conv2d' in op_list)
+
+        conv2d_called = op_list['conv2d'].split(',')
+        add_called = op_list['elementwise_add'].split(',')
+        add_num = 0
+        conv_num = 0
+        for i in range(4):
+            add_num += int(add_called[i])
+            conv_num += int(conv2d_called[i])
+
+        self.assertTrue(conv_num == 1)
+        self.assertTrue(add_num == 1)
+
+        if conv.dtype == paddle.float16:
+            self.assertTrue(int(conv2d_called[0]) == 1)
+            self.assertTrue(int(add_called[0]) == 1)
 
 
 if __name__ == "__main__":
-- 
GitLab
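For trying the new report end to end, a hedged usage sketch: it assumes the FLAGS_low_precision_op_list environment variable is picked up when paddle initializes (as the test above implies), and the toy model is arbitrary; only paddle.amp.low_precision_op_list() and the flag come from this patch.

    import os

    # Export before importing paddle so the C++ gflag is read at startup.
    os.environ["FLAGS_low_precision_op_list"] = "1"

    import paddle

    conv = paddle.nn.Conv2D(3, 2, 3)
    x = paddle.rand([1, 3, 32, 32])
    with paddle.amp.auto_cast():
        y = conv(x)   # recorded under 'conv2d'
        z = y + y     # recorded under 'elementwise_add'
    # Prints the <Op Name | FP16 Calls | BF16 Calls | FP32 Calls | Other Calls> table.
    paddle.amp.low_precision_op_list()

On a machine without float16 support the same ops are still recorded, just under the FP32 column, since the kernel factory now counts every forward op by the dtype of the kernel actually chosen.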