未验证 提交 45df9be8 编写于 作者: Ruibiao Chen 提交者: GitHub

Print IPS in auto parallel Engine (#46554)

上级 3cbf0e93
...@@ -23,7 +23,7 @@ from collections import defaultdict ...@@ -23,7 +23,7 @@ from collections import defaultdict
import paddle import paddle
import paddle.utils as utils import paddle.utils as utils
from paddle import fluid, static from paddle import fluid, profiler, static
from paddle.jit import to_static from paddle.jit import to_static
from paddle.metric import Metric from paddle.metric import Metric
from paddle.static import InputSpec from paddle.static import InputSpec
...@@ -570,7 +570,8 @@ class Engine: ...@@ -570,7 +570,8 @@ class Engine:
step=None, step=None,
lr=None, lr=None,
fetch_new_names=None, fetch_new_names=None,
fetch_sections=None): fetch_sections=None,
profiler_log=""):
prefix = "[{}] ".format(mode) prefix = "[{}] ".format(mode)
logs = {} logs = {}
if epoch is not None: if epoch is not None:
...@@ -596,7 +597,7 @@ class Engine: ...@@ -596,7 +597,7 @@ class Engine:
else: else:
for i in range(section_start, section_end): for i in range(section_start, section_end):
logs[fetch_new_names[i] + ": {} "] = outs[i] logs[fetch_new_names[i] + ": {} "] = outs[i]
string = prefix + ''.join(list(logs.keys())) string = prefix + ''.join(list(logs.keys())) + profiler_log
self._logger.info(string.format(*list(logs.values()))) self._logger.info(string.format(*list(logs.values())))
def fit(self, def fit(self,
...@@ -695,6 +696,7 @@ class Engine: ...@@ -695,6 +696,7 @@ class Engine:
mode=self.mode) mode=self.mode)
lr_scheduler = self._get_lr_scheduler(self.main_program) lr_scheduler = self._get_lr_scheduler(self.main_program)
with profiler.Profiler(timer_only=True) as prof:
for epoch in range(epochs): for epoch in range(epochs):
for step, _ in enumerate(train_dataloader): for step, _ in enumerate(train_dataloader):
try: try:
...@@ -708,8 +710,12 @@ class Engine: ...@@ -708,8 +710,12 @@ class Engine:
if lr_scheduler and step % self._k_steps == 0: if lr_scheduler and step % self._k_steps == 0:
lr_scheduler.step() lr_scheduler.step()
lr = self._get_lr(self._lr_optimizer) lr = self._get_lr(self._lr_optimizer)
prof.step()
self._print_log(outs, self.mode, epoch, step, lr, self._print_log(outs, self.mode, epoch, step, lr,
fetch_new_names, fetch_sections) fetch_new_names, fetch_sections,
prof.step_info())
if valid_data and epoch % valid_freq == 0: if valid_data and epoch % valid_freq == 0:
self.evaluate(valid_data, valid_sample_split, batch_size, self.evaluate(valid_data, valid_sample_split, batch_size,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请注册