提交 c8545d66 编写于 作者: H hysunflower 提交者: Jinhua Liang

Add profiler for paddle gan (#3996)

* add_for_stgan_models

* add_for_pix2pix_models

* add_for_stargan_models
上级 0f9ad827
......@@ -70,7 +70,7 @@ if __name__ == "__main__":
if cfg.profile:
if cfg.use_gpu:
with fluid.profiler.profiler('All', 'total',
'/tmp/profile') as prof:
cfg.profiler_path) as prof:
train(cfg)
else:
with fluid.profiler.profiler("CPU", sorted_key='total') as cpuprof:
......
......@@ -18,6 +18,7 @@ from __future__ import print_function
from network.Pix2pix_network import Pix2pix_model
from util import utility
import paddle.fluid as fluid
from paddle.fluid import profiler
import sys
import time
......@@ -255,9 +256,13 @@ class Pix2pix(object):
t_time = 0
total_train_batch = 0 # used for benchmark
for epoch_id in range(self.cfg.epoch):
batch_id = 0
for tensor in loader():
if self.cfg.max_iter and total_train_batch == self.cfg.max_iter: # used for benchmark
return
s_time = time.time()
tensor_A, tensor_B = tensor[0]['input_A'], tensor[0]['input_B']
......@@ -294,6 +299,12 @@ class Pix2pix(object):
sys.stdout.flush()
batch_id += 1
total_train_batch += 1 # used for benchmark
# profiler tools
if self.cfg.profile and epoch_id == 0 and batch_id == self.cfg.print_freq:
profiler.reset_profiler()
elif self.cfg.profile and epoch_id == 0 and batch_id == self.cfg.print_freq + 5:
return
if self.cfg.run_test:
image_name = fluid.data(
......
......@@ -17,6 +17,7 @@ from __future__ import print_function
from network.STGAN_network import STGAN_model
from util import utility
import paddle.fluid as fluid
from paddle.fluid import profiler
import sys
import time
import copy
......@@ -341,9 +342,13 @@ class STGAN(object):
t_time = 0
total_train_batch = 0 # used for benchmark
for epoch_id in range(self.cfg.epoch):
batch_id = 0
for data in py_reader():
if self.cfg.max_iter and total_train_batch == self.cfg.max_iter: # used for benchmark
return
s_time = time.time()
# optimize the discriminator network
fetches = [
......@@ -377,6 +382,12 @@ class STGAN(object):
d_loss_gp[0], batch_time))
sys.stdout.flush()
batch_id += 1
total_train_batch += 1 # used for benchmark
# profiler tools
if self.cfg.profile and epoch_id == 0 and batch_id == self.cfg.print_freq:
profiler.reset_profiler()
elif self.cfg.profile and epoch_id == 0 and batch_id == self.cfg.print_freq + 5:
return
if self.cfg.run_test:
image_name = fluid.data(
......
......@@ -17,6 +17,7 @@ from __future__ import print_function
from network.StarGAN_network import StarGAN_model
from util import utility
import paddle.fluid as fluid
from paddle.fluid import profiler
import sys
import time
import copy
......@@ -305,10 +306,12 @@ class StarGAN(object):
build_strategy=build_strategy)
t_time = 0
total_train_batch = 0 # used for benchmark
for epoch_id in range(self.cfg.epoch):
batch_id = 0
for data in py_reader():
if self.cfg.max_iter and total_train_batch == self.cfg.max_iter: # used for benchmark
return
s_time = time.time()
d_loss_real, d_loss_fake, d_loss, d_loss_cls, d_loss_gp = exe.run(
dis_trainer_program,
......@@ -344,6 +347,12 @@ class StarGAN(object):
sys.stdout.flush()
batch_id += 1
total_train_batch += 1 # used for benchmark
# profiler tools
if self.cfg.profile and epoch_id == 0 and batch_id == self.cfg.print_freq:
profiler.reset_profiler()
elif self.cfg.profile and epoch_id == 0 and batch_id == self.cfg.print_freq + 5:
return
if self.cfg.run_test:
image_name = fluid.data(
......
......@@ -85,6 +85,11 @@ def base_parse_args(parser):
add_arg('run_test', bool, True, "Whether to run test.")
add_arg('use_gpu', bool, True, "Whether to use GPU to train.")
add_arg('profile', bool, False, "Whether to profile.")
# NOTE: add args for profiler, used for benchmark
add_arg('profiler_path', str, '/tmp/profile', "the profiler output files. (used for benchmark)")
add_arg('max_iter', int, 0, "the max iter to train. (used for benchmark)")
add_arg('dropout', bool, False, "Whether to use drouput.")
add_arg('drop_last', bool, False,
"Whether to drop the last images that cannot form a batch")
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册