未验证 提交 0739cc75 编写于 作者: W wanghuancoder 提交者: GitHub

use pre-commit formate code (#4870)

* fix ptb_dy time print for benchmark, test=develop

* use pre-commit formate code
上级 38ada7ff
...@@ -49,20 +49,25 @@ import pickle ...@@ -49,20 +49,25 @@ import pickle
SEED = 123 SEED = 123
class TimeCostAverage(object): class TimeCostAverage(object):
def __init__(self): def __init__(self):
self.reset() self.reset()
def reset(self): def reset(self):
self.cnt = 0 self.cnt = 0
self.total_time = 0 self.total_time = 0
def record(self, usetime): def record(self, usetime):
self.cnt += 1 self.cnt += 1
self.total_time += usetime self.total_time += usetime
def get_average(self): def get_average(self):
if self.cnt == 0: if self.cnt == 0:
return 0 return 0
return self.total_time / self.cnt return self.total_time / self.cnt
@contextlib.contextmanager @contextlib.contextmanager
def profile_context(profile=True, profiler_path='/tmp/paddingrnn.profile'): def profile_context(profile=True, profiler_path='/tmp/paddingrnn.profile'):
if profile: if profile:
...@@ -339,7 +344,8 @@ def main(): ...@@ -339,7 +344,8 @@ def main():
ppl = np.exp(total_loss / iters) ppl = np.exp(total_loss / iters)
print( print(
"-- Epoch:[%d]; Batch:[%d]; Time: %.5f s; ppl: %.5f, lr: %.5f" "-- Epoch:[%d]; Batch:[%d]; Time: %.5f s; ppl: %.5f, lr: %.5f"
% (epoch_id, batch_id, batch_cost_avg.get_average(), ppl[0], lr[0])) % (epoch_id, batch_id, batch_cost_avg.get_average(), ppl[0],
lr[0]))
batch_cost_avg.reset() batch_cost_avg.reset()
# profiler tools for benchmark # profiler tools for benchmark
...@@ -402,7 +408,8 @@ def main(): ...@@ -402,7 +408,8 @@ def main():
ppl = np.exp(total_loss / iters) ppl = np.exp(total_loss / iters)
print( print(
"-- Epoch:[%d]; Batch:[%d]; Time: %.5f s; ppl: %.5f, lr: %.5f" "-- Epoch:[%d]; Batch:[%d]; Time: %.5f s; ppl: %.5f, lr: %.5f"
% (epoch_id, batch_id, batch_cost_avg.get_average(), ppl[0], lr[0])) % (epoch_id, batch_id, batch_cost_avg.get_average(),
ppl[0], lr[0]))
batch_cost_avg.reset() batch_cost_avg.reset()
batch_id += 1 batch_id += 1
...@@ -507,4 +514,3 @@ def main(): ...@@ -507,4 +514,3 @@ def main():
if __name__ == '__main__': if __name__ == '__main__':
main() main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册