diff --git a/tools/train.py b/tools/train.py index 54af54ab809348ed2c7b91c1f622a5bd24a34cc7..83198fcea3b057f12b21468bf882fa8f54d0a922 100644 --- a/tools/train.py +++ b/tools/train.py @@ -22,6 +22,7 @@ import numpy as np import random import datetime from collections import deque +from paddle.fluid import profiler def set_paddle_flags(**kwargs): @@ -256,6 +257,13 @@ def main(): it, np.mean(outs[-1]), logs, time_cost, eta) logger.info(strs) + # NOTE : profiler tools, used for benchmark + if FLAGS.is_profiler and it == 5: + profiler.start_profiler("All") + elif FLAGS.is_profiler and it == 10: + profiler.stop_profiler("total", FLAGS.profiler_path) + return + if (it > 0 and it % cfg.snapshot_iter == 0 or it == cfg.max_iters - 1) \ and (not FLAGS.dist or trainer_id == 0): @@ -340,5 +348,17 @@ if __name__ == '__main__': default=False, help="If set True, enable continuous evaluation job." "This flag is only used for internal test.") + + #NOTE:args for profiler tools, used for benchmark + parser.add_argument( + '--is_profiler', + type=int, + default=0, + help='The switch of profiler tools. (used for benchmark)') + parser.add_argument( + '--profiler_path', + type=str, + default="./detection.profiler", + help='The profiler output file path. (used for benchmark)') FLAGS = parser.parse_args() main()