未验证 提交 395520f1 编写于 作者: N niuliling123 提交者: GitHub

Update the style of print for low precision op list (#49648)

上级 18a7e13f
...@@ -2595,7 +2595,17 @@ All parameter, weight, gradient are variables in Paddle. ...@@ -2595,7 +2595,17 @@ All parameter, weight, gradient are variables in Paddle.
[] { return phi::autotune::AutoTuneStatus::Instance().Update(); }); [] { return phi::autotune::AutoTuneStatus::Instance().Update(); });
m.def("get_low_precision_op_list", [] { m.def("get_low_precision_op_list", [] {
return phi::KernelFactory::Instance().GetLowPrecisionKernelList(); py::dict op_list;
auto list_op = phi::KernelFactory::Instance().GetLowPrecisionKernelList();
for (auto iter = list_op.begin(); iter != list_op.end(); iter++) {
auto op_name = (iter->first).c_str();
auto counts = iter->second;
op_list[op_name] = std::to_string(counts.fp16_called_) + "," +
std::to_string(counts.bf16_called_) + "," +
std::to_string(counts.fp32_called_) + "," +
std::to_string(counts.other_called_);
}
return op_list;
}); });
m.def("autotune_status", [] { m.def("autotune_status", [] {
......
...@@ -115,18 +115,25 @@ void KernelFactory::AddToLowPrecisionKernelList( ...@@ -115,18 +115,25 @@ void KernelFactory::AddToLowPrecisionKernelList(
if (op_name.find("_grad") != std::string::npos) { if (op_name.find("_grad") != std::string::npos) {
return; // only record forward api return; // only record forward api
} }
bool is_low_precision =
(kernel_key_type == paddle::experimental::DataType::FLOAT16 || if (low_precision_kernels_.find(op_name) == low_precision_kernels_.end()) {
kernel_key_type == paddle::experimental::DataType::BFLOAT16); auto count = OpCount();
bool need_record = low_precision_kernels_[op_name] = count;
FLAGS_low_precision_op_list == 1 ? is_low_precision : true; }
if (need_record) { if (kernel_key_type == paddle::experimental::DataType::FLOAT16) {
low_precision_kernels_[op_name] += 1; low_precision_kernels_[op_name].fp16_called_ += 1;
} else if (kernel_key_type == paddle::experimental::DataType::BFLOAT16) {
low_precision_kernels_[op_name].bf16_called_ += 1;
} else if (kernel_key_type == paddle::experimental::DataType::FLOAT32) {
low_precision_kernels_[op_name].fp32_called_ += 1;
} else {
low_precision_kernels_[op_name].other_called_ += 1;
} }
} }
} }
std::map<const std::string, int> KernelFactory::GetLowPrecisionKernelList() { std::map<const std::string, OpCount>
KernelFactory::GetLowPrecisionKernelList() {
return low_precision_kernels_; return low_precision_kernels_;
} }
......
...@@ -34,6 +34,19 @@ namespace phi { ...@@ -34,6 +34,19 @@ namespace phi {
using DataType = paddle::experimental::DataType; using DataType = paddle::experimental::DataType;
/**
 * Per-op call counters, one bucket per kernel data type.
 *
 * All counters start at zero via default member initializers, so a
 * default-constructed or value-initialized OpCount (`OpCount()`) is
 * always fully initialized without a hand-written constructor.
 */
struct OpCount {
  int fp16_called_{0};   // calls recorded with a FLOAT16 kernel key
  int bf16_called_{0};   // calls recorded with a BFLOAT16 kernel key
  int fp32_called_{0};   // calls recorded with a FLOAT32 kernel key
  int other_called_{0};  // calls recorded with any other kernel key
};
/** /**
* [ Naming considerations ] * [ Naming considerations ]
* *
...@@ -309,7 +322,7 @@ class KernelFactory { ...@@ -309,7 +322,7 @@ class KernelFactory {
const std::string& name, const std::string& name,
const paddle::experimental::DataType& kernel_key_type); const paddle::experimental::DataType& kernel_key_type);
std::map<const std::string, int> GetLowPrecisionKernelList(); std::map<const std::string, OpCount> GetLowPrecisionKernelList();
private: private:
KernelFactory() = default; KernelFactory() = default;
...@@ -317,7 +330,7 @@ class KernelFactory { ...@@ -317,7 +330,7 @@ class KernelFactory {
KernelNameMap kernels_; KernelNameMap kernels_;
// Get the low precision kernel list of current module. // Get the low precision kernel list of current module.
std::map<const std::string, int> low_precision_kernels_; std::map<const std::string, OpCount> low_precision_kernels_;
}; };
inline std::ostream& operator<<(std::ostream& os, const KernelKey& kernel_key) { inline std::ostream& operator<<(std::ostream& os, const KernelKey& kernel_key) {
......
...@@ -97,21 +97,29 @@ _g_amp_state_ = None ...@@ -97,21 +97,29 @@ _g_amp_state_ = None
def low_precision_op_list(): def low_precision_op_list():
if os.getenv("FLAGS_low_precision_op_list") is not None: if os.getenv("FLAGS_low_precision_op_list") is not None:
level = int(os.getenv("FLAGS_low_precision_op_list")) level = int(os.getenv("FLAGS_low_precision_op_list"))
if level == 0: print('<{:-^120}>'.format(" op list "))
return
if level == 1:
print('<{:-^60}>'.format(" low precision op list "))
else:
print('<{:-^60}>'.format(" op list "))
op_list = paddle.fluid.core.get_low_precision_op_list() op_list = paddle.fluid.core.get_low_precision_op_list()
op_count = 0 op_count = 0
print( print(
'<{:-^40}'.format(" op_name "), '|', '{:-^17}>'.format(" op count ") '<{:-^40}'.format(" Op Name "),
'|',
'{:-^17}'.format("FP16 Calls"),
'|',
'{:-^17}'.format("BF16 Calls"),
'|',
'{:-^17}'.format('FP32 Calls'),
'|',
'{:-^17}>'.format('Other Calls'),
) )
for x in op_list: for x in op_list:
print(' %-40s| %-15d' % (x, op_list[x])) # fp16, bf16, fp32, other
called = op_list[x].split(",")
print(
' %-40s| %-17s| %-17s| %-17s| %-17s'
% (x, called[0], called[1], called[2], called[3])
)
op_count += 1 op_count += 1
print('<{:-^60}>'.format(" op count: " + str(op_count) + " ")) print('<{:-^120}>'.format(" op count: " + str(op_count) + " "))
def amp_state(): def amp_state():
......
...@@ -30,12 +30,24 @@ class TestAMPList(unittest.TestCase): ...@@ -30,12 +30,24 @@ class TestAMPList(unittest.TestCase):
c = a + b c = a + b
paddle.amp.low_precision_op_list() paddle.amp.low_precision_op_list()
op_list = paddle.fluid.core.get_low_precision_op_list() op_list = paddle.fluid.core.get_low_precision_op_list()
if conv.dtype == paddle.float16:
self.assertTrue('elementwise_add' in op_list) self.assertTrue('elementwise_add' in op_list)
self.assertTrue('conv2d' in op_list) self.assertTrue('conv2d' in op_list)
self.assertTrue(2 == len(op_list))
else: conv2d_called = op_list['conv2d'].split(',')
self.assertTrue(0 == len(op_list)) add_called = op_list['elementwise_add'].split(',')
# Sum the four per-dtype call counts (fp16, bf16, fp32, other) of each op.
add_num = 0
conv_num = 0
for i in range(4):
    add_num += int(add_called[i])
    # Accumulate conv2d's own counters; summing add_called here would make
    # the conv2d assertion below a duplicate of the elementwise_add one.
    conv_num += int(conv2d_called[i])
# Each op is expected to have been recorded exactly once in total.
self.assertTrue(conv_num == 1)
self.assertTrue(add_num == 1)
if conv.dtype == "float16":
self.assertTrue(int(conv2d_called[0]) == 1)
self.assertTrue(int(add_called[0]) == 1)
if __name__ == "__main__": if __name__ == "__main__":
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册