未验证 提交 45df9be8 编写于 作者: R Ruibiao Chen 提交者: GitHub

Print IPS in auto parallel Engine (#46554)

上级 3cbf0e93
...@@ -23,7 +23,7 @@ from collections import defaultdict ...@@ -23,7 +23,7 @@ from collections import defaultdict
import paddle import paddle
import paddle.utils as utils import paddle.utils as utils
from paddle import fluid, static from paddle import fluid, profiler, static
from paddle.jit import to_static from paddle.jit import to_static
from paddle.metric import Metric from paddle.metric import Metric
from paddle.static import InputSpec from paddle.static import InputSpec
...@@ -570,7 +570,8 @@ class Engine: ...@@ -570,7 +570,8 @@ class Engine:
step=None, step=None,
lr=None, lr=None,
fetch_new_names=None, fetch_new_names=None,
fetch_sections=None): fetch_sections=None,
profiler_log=""):
prefix = "[{}] ".format(mode) prefix = "[{}] ".format(mode)
logs = {} logs = {}
if epoch is not None: if epoch is not None:
...@@ -596,7 +597,7 @@ class Engine: ...@@ -596,7 +597,7 @@ class Engine:
else: else:
for i in range(section_start, section_end): for i in range(section_start, section_end):
logs[fetch_new_names[i] + ": {} "] = outs[i] logs[fetch_new_names[i] + ": {} "] = outs[i]
string = prefix + ''.join(list(logs.keys())) string = prefix + ''.join(list(logs.keys())) + profiler_log
self._logger.info(string.format(*list(logs.values()))) self._logger.info(string.format(*list(logs.values())))
def fit(self, def fit(self,
...@@ -695,29 +696,34 @@ class Engine: ...@@ -695,29 +696,34 @@ class Engine:
mode=self.mode) mode=self.mode)
lr_scheduler = self._get_lr_scheduler(self.main_program) lr_scheduler = self._get_lr_scheduler(self.main_program)
for epoch in range(epochs): with profiler.Profiler(timer_only=True) as prof:
for step, _ in enumerate(train_dataloader): for epoch in range(epochs):
try: for step, _ in enumerate(train_dataloader):
outs = self._executor.run( try:
self.main_program, outs = self._executor.run(
fetch_list=fetch_list, self.main_program,
use_program_cache=self._strategy.use_cache, fetch_list=fetch_list,
return_numpy=self._strategy.return_numpy) use_program_cache=self._strategy.use_cache,
except core.EOFException: return_numpy=self._strategy.return_numpy)
break except core.EOFException:
if lr_scheduler and step % self._k_steps == 0: break
lr_scheduler.step() if lr_scheduler and step % self._k_steps == 0:
lr = self._get_lr(self._lr_optimizer) lr_scheduler.step()
self._print_log(outs, self.mode, epoch, step, lr, lr = self._get_lr(self._lr_optimizer)
fetch_new_names, fetch_sections)
prof.step()
if valid_data and epoch % valid_freq == 0:
self.evaluate(valid_data, valid_sample_split, batch_size, self._print_log(outs, self.mode, epoch, step, lr,
valid_steps, collate_fn, callbacks) fetch_new_names, fetch_sections,
self._switch_mode("train") prof.step_info())
else:
self._reset_metrics() if valid_data and epoch % valid_freq == 0:
return outs self.evaluate(valid_data, valid_sample_split, batch_size,
valid_steps, collate_fn, callbacks)
self._switch_mode("train")
else:
self._reset_metrics()
return outs
def evaluate(self, def evaluate(self,
valid_data, valid_data,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册