提交 c8545d66 编写于 作者: H hysunflower 提交者: Jinhua Liang

Add profiler for paddle gan (#3996)

* add_for_stgan_models

* add_for_pix2pix_models

* add_for_stargan_models
上级 0f9ad827
...@@ -70,7 +70,7 @@ if __name__ == "__main__": ...@@ -70,7 +70,7 @@ if __name__ == "__main__":
if cfg.profile: if cfg.profile:
if cfg.use_gpu: if cfg.use_gpu:
with fluid.profiler.profiler('All', 'total', with fluid.profiler.profiler('All', 'total',
'/tmp/profile') as prof: cfg.profiler_path) as prof:
train(cfg) train(cfg)
else: else:
with fluid.profiler.profiler("CPU", sorted_key='total') as cpuprof: with fluid.profiler.profiler("CPU", sorted_key='total') as cpuprof:
......
...@@ -18,6 +18,7 @@ from __future__ import print_function ...@@ -18,6 +18,7 @@ from __future__ import print_function
from network.Pix2pix_network import Pix2pix_model from network.Pix2pix_network import Pix2pix_model
from util import utility from util import utility
import paddle.fluid as fluid import paddle.fluid as fluid
from paddle.fluid import profiler
import sys import sys
import time import time
...@@ -255,9 +256,13 @@ class Pix2pix(object): ...@@ -255,9 +256,13 @@ class Pix2pix(object):
t_time = 0 t_time = 0
total_train_batch = 0 # used for benchmark
for epoch_id in range(self.cfg.epoch): for epoch_id in range(self.cfg.epoch):
batch_id = 0 batch_id = 0
for tensor in loader(): for tensor in loader():
if self.cfg.max_iter and total_train_batch == self.cfg.max_iter: # used for benchmark
return
s_time = time.time() s_time = time.time()
tensor_A, tensor_B = tensor[0]['input_A'], tensor[0]['input_B'] tensor_A, tensor_B = tensor[0]['input_A'], tensor[0]['input_B']
...@@ -294,6 +299,12 @@ class Pix2pix(object): ...@@ -294,6 +299,12 @@ class Pix2pix(object):
sys.stdout.flush() sys.stdout.flush()
batch_id += 1 batch_id += 1
total_train_batch += 1 # used for benchmark
# profiler tools
if self.cfg.profile and epoch_id == 0 and batch_id == self.cfg.print_freq:
profiler.reset_profiler()
elif self.cfg.profile and epoch_id == 0 and batch_id == self.cfg.print_freq + 5:
return
if self.cfg.run_test: if self.cfg.run_test:
image_name = fluid.data( image_name = fluid.data(
......
...@@ -17,6 +17,7 @@ from __future__ import print_function ...@@ -17,6 +17,7 @@ from __future__ import print_function
from network.STGAN_network import STGAN_model from network.STGAN_network import STGAN_model
from util import utility from util import utility
import paddle.fluid as fluid import paddle.fluid as fluid
from paddle.fluid import profiler
import sys import sys
import time import time
import copy import copy
...@@ -341,9 +342,13 @@ class STGAN(object): ...@@ -341,9 +342,13 @@ class STGAN(object):
t_time = 0 t_time = 0
total_train_batch = 0 # used for benchmark
for epoch_id in range(self.cfg.epoch): for epoch_id in range(self.cfg.epoch):
batch_id = 0 batch_id = 0
for data in py_reader(): for data in py_reader():
if self.cfg.max_iter and total_train_batch == self.cfg.max_iter: # used for benchmark
return
s_time = time.time() s_time = time.time()
# optimize the discriminator network # optimize the discriminator network
fetches = [ fetches = [
...@@ -377,6 +382,12 @@ class STGAN(object): ...@@ -377,6 +382,12 @@ class STGAN(object):
d_loss_gp[0], batch_time)) d_loss_gp[0], batch_time))
sys.stdout.flush() sys.stdout.flush()
batch_id += 1 batch_id += 1
total_train_batch += 1 # used for benchmark
# profiler tools
if self.cfg.profile and epoch_id == 0 and batch_id == self.cfg.print_freq:
profiler.reset_profiler()
elif self.cfg.profile and epoch_id == 0 and batch_id == self.cfg.print_freq + 5:
return
if self.cfg.run_test: if self.cfg.run_test:
image_name = fluid.data( image_name = fluid.data(
......
...@@ -17,6 +17,7 @@ from __future__ import print_function ...@@ -17,6 +17,7 @@ from __future__ import print_function
from network.StarGAN_network import StarGAN_model from network.StarGAN_network import StarGAN_model
from util import utility from util import utility
import paddle.fluid as fluid import paddle.fluid as fluid
from paddle.fluid import profiler
import sys import sys
import time import time
import copy import copy
...@@ -305,10 +306,12 @@ class StarGAN(object): ...@@ -305,10 +306,12 @@ class StarGAN(object):
build_strategy=build_strategy) build_strategy=build_strategy)
t_time = 0 t_time = 0
total_train_batch = 0 # used for benchmark
for epoch_id in range(self.cfg.epoch): for epoch_id in range(self.cfg.epoch):
batch_id = 0 batch_id = 0
for data in py_reader(): for data in py_reader():
if self.cfg.max_iter and total_train_batch == self.cfg.max_iter: # used for benchmark
return
s_time = time.time() s_time = time.time()
d_loss_real, d_loss_fake, d_loss, d_loss_cls, d_loss_gp = exe.run( d_loss_real, d_loss_fake, d_loss, d_loss_cls, d_loss_gp = exe.run(
dis_trainer_program, dis_trainer_program,
...@@ -344,6 +347,12 @@ class StarGAN(object): ...@@ -344,6 +347,12 @@ class StarGAN(object):
sys.stdout.flush() sys.stdout.flush()
batch_id += 1 batch_id += 1
total_train_batch += 1 # used for benchmark
# profiler tools
if self.cfg.profile and epoch_id == 0 and batch_id == self.cfg.print_freq:
profiler.reset_profiler()
elif self.cfg.profile and epoch_id == 0 and batch_id == self.cfg.print_freq + 5:
return
if self.cfg.run_test: if self.cfg.run_test:
image_name = fluid.data( image_name = fluid.data(
......
...@@ -85,6 +85,11 @@ def base_parse_args(parser): ...@@ -85,6 +85,11 @@ def base_parse_args(parser):
add_arg('run_test', bool, True, "Whether to run test.") add_arg('run_test', bool, True, "Whether to run test.")
add_arg('use_gpu', bool, True, "Whether to use GPU to train.") add_arg('use_gpu', bool, True, "Whether to use GPU to train.")
add_arg('profile', bool, False, "Whether to profile.") add_arg('profile', bool, False, "Whether to profile.")
# NOTE: add args for profiler, used for benchmark
add_arg('profiler_path', str, '/tmp/profile', "the profiler output files. (used for benchmark)")
add_arg('max_iter', int, 0, "the max iter to train. (used for benchmark)")
add_arg('dropout', bool, False, "Whether to use drouput.") add_arg('dropout', bool, False, "Whether to use drouput.")
add_arg('drop_last', bool, False, add_arg('drop_last', bool, False,
"Whether to drop the last images that cannot form a batch") "Whether to drop the last images that cannot form a batch")
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册