Unverified · Commit 395520f1 · authored by niuliling123 · committed via GitHub

Update the style of print for low precision op list (#49648)

Parent commit: 18a7e13f
......@@ -2595,7 +2595,17 @@ All parameter, weight, gradient are variables in Paddle.
[] { return phi::autotune::AutoTuneStatus::Instance().Update(); });
// Expose the per-op low-precision call counters to Python.
// Returns a dict mapping op name -> "fp16,bf16,fp32,other", where each
// field is the number of kernel launches at that dtype. The string is
// split on ',' by paddle.amp.low_precision_op_list() on the Python side,
// so the encoding/order here must stay in sync with that function.
m.def("get_low_precision_op_list", [] {
  py::dict op_list;
  auto list_op = phi::KernelFactory::Instance().GetLowPrecisionKernelList();
  for (const auto& kv : list_op) {
    const auto& counts = kv.second;  // const-ref: avoid copying the OpCount
    op_list[kv.first.c_str()] = std::to_string(counts.fp16_called_) + "," +
                                std::to_string(counts.bf16_called_) + "," +
                                std::to_string(counts.fp32_called_) + "," +
                                std::to_string(counts.other_called_);
  }
  return op_list;
});
m.def("autotune_status", [] {
......
......@@ -115,18 +115,25 @@ void KernelFactory::AddToLowPrecisionKernelList(
if (op_name.find("_grad") != std::string::npos) {
return; // only record forward api
}
bool is_low_precision =
(kernel_key_type == paddle::experimental::DataType::FLOAT16 ||
kernel_key_type == paddle::experimental::DataType::BFLOAT16);
bool need_record =
FLAGS_low_precision_op_list == 1 ? is_low_precision : true;
if (need_record) {
low_precision_kernels_[op_name] += 1;
if (low_precision_kernels_.find(op_name) == low_precision_kernels_.end()) {
auto count = OpCount();
low_precision_kernels_[op_name] = count;
}
if (kernel_key_type == paddle::experimental::DataType::FLOAT16) {
low_precision_kernels_[op_name].fp16_called_ += 1;
} else if (kernel_key_type == paddle::experimental::DataType::BFLOAT16) {
low_precision_kernels_[op_name].bf16_called_ += 1;
} else if (kernel_key_type == paddle::experimental::DataType::FLOAT32) {
low_precision_kernels_[op_name].fp32_called_ += 1;
} else {
low_precision_kernels_[op_name].other_called_ += 1;
}
}
}
std::map<const std::string, int> KernelFactory::GetLowPrecisionKernelList() {
std::map<const std::string, OpCount>
KernelFactory::GetLowPrecisionKernelList() {
return low_precision_kernels_;
}
......
......@@ -34,6 +34,19 @@ namespace phi {
using DataType = paddle::experimental::DataType;
// Per-operator counters of kernel launches, bucketed by compute dtype.
// Filled by KernelFactory::AddToLowPrecisionKernelList and surfaced to
// Python via GetLowPrecisionKernelList for the AMP low-precision report.
struct OpCount {
  // Default member initializers replace the assign-in-body constructor;
  // a default-constructed OpCount still starts with all counters at zero.
  OpCount() = default;
  int fp16_called_{0};   // launches with float16 kernels
  int bf16_called_{0};   // launches with bfloat16 kernels
  int fp32_called_{0};   // launches with float32 kernels
  int other_called_{0};  // launches with any other dtype
};
/**
* [ Naming considerations ]
*
......@@ -309,7 +322,7 @@ class KernelFactory {
const std::string& name,
const paddle::experimental::DataType& kernel_key_type);
// NOTE(review): the next two declarations are old/new diff residue for the
// same member — the int-valued one predates the OpCount change and only the
// OpCount-valued one should remain. Confirm against kernel_factory.h.
std::map<const std::string, int> GetLowPrecisionKernelList();
std::map<const std::string, OpCount> GetLowPrecisionKernelList();
private:
// Singleton: default-constructed only through Instance().
KernelFactory() = default;
......@@ -317,7 +330,7 @@ class KernelFactory {
KernelNameMap kernels_;
// Get the low precision kernel list of current module.
// NOTE(review): old/new diff residue again — the member's mapped type
// changed from int to OpCount; keep only the OpCount variant.
std::map<const std::string, int> low_precision_kernels_;
std::map<const std::string, OpCount> low_precision_kernels_;
};
inline std::ostream& operator<<(std::ostream& os, const KernelKey& kernel_key) {
......
......@@ -97,21 +97,29 @@ _g_amp_state_ = None
def low_precision_op_list():
    """Print a table of ops recorded by the low-precision kernel tracker.

    Enabled by the ``FLAGS_low_precision_op_list`` environment variable;
    a value of 0 (or an unset variable) disables the report. Each row
    shows, per op, how many kernel launches ran at FP16 / BF16 / FP32 /
    other dtypes. Counts come from the C++ side as a comma-separated
    string "fp16,bf16,fp32,other" (see the pybind binding).
    """
    if os.getenv("FLAGS_low_precision_op_list") is not None:
        level = int(os.getenv("FLAGS_low_precision_op_list"))
        if level == 0:
            return
        print('<{:-^120}>'.format(" op list "))
        op_list = paddle.fluid.core.get_low_precision_op_list()
        op_count = 0
        print(
            '<{:-^40}'.format(" Op Name "),
            '|',
            '{:-^17}'.format("FP16 Calls"),
            '|',
            '{:-^17}'.format("BF16 Calls"),
            '|',
            '{:-^17}'.format('FP32 Calls'),
            '|',
            '{:-^17}>'.format('Other Calls'),
        )
        for x in op_list:
            # fp16, bf16, fp32, other — order fixed by the C++ encoder.
            called = op_list[x].split(",")
            print(
                ' %-40s|  %-17s|  %-17s|  %-17s|  %-17s'
                % (x, called[0], called[1], called[2], called[3])
            )
            op_count += 1
        print('<{:-^120}>'.format(" op count: " + str(op_count) + " "))
def amp_state():
......
......@@ -30,12 +30,24 @@ class TestAMPList(unittest.TestCase):
c = a + b
paddle.amp.low_precision_op_list()
op_list = paddle.fluid.core.get_low_precision_op_list()
# Both forward ops must appear in the report regardless of dtype.
self.assertTrue('elementwise_add' in op_list)
self.assertTrue('conv2d' in op_list)
# Each entry encodes "fp16,bf16,fp32,other" call counts as a string.
conv2d_called = op_list['conv2d'].split(',')
add_called = op_list['elementwise_add'].split(',')
add_num = 0
conv_num = 0
for i in range(4):
    add_num += int(add_called[i])
    # Bug fix: previously accumulated add_called here as well, so
    # conv_num never reflected the conv2d counters.
    conv_num += int(conv2d_called[i])
# Exactly one launch of each op was recorded, in some dtype bucket.
self.assertTrue(conv_num == 1)
self.assertTrue(add_num == 1)
if conv.dtype == "float16":
    # Under AMP the single launch must land in the FP16 bucket.
    self.assertTrue(int(conv2d_called[0]) == 1)
    self.assertTrue(int(add_called[0]) == 1)
if __name__ == "__main__":
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
To comment, please register