未验证 提交 0f01ed13 编写于 作者: C ceci3 提交者: GitHub

Fix multi cpu (#4131)

* fix elementwise

* fix_pix2pix

* fix reader
上级 4ffcb3b1
...@@ -630,6 +630,7 @@ class data_reader(object): ...@@ -630,6 +630,7 @@ class data_reader(object):
batch_size=self.cfg.batch_size, batch_size=self.cfg.batch_size,
mode="TRAIN") mode="TRAIN")
reader_test = None reader_test = None
id2name = None
if self.cfg.run_test: if self.cfg.run_test:
test_list = os.path.join(dataset_dir, "test.txt") test_list = os.path.join(dataset_dir, "test.txt")
if self.cfg.test_list is not None: if self.cfg.test_list is not None:
......
...@@ -21,6 +21,7 @@ import paddle.fluid as fluid ...@@ -21,6 +21,7 @@ import paddle.fluid as fluid
from paddle.fluid import profiler from paddle.fluid import profiler
import sys import sys
import time import time
import numpy as np
class GTrainer(): class GTrainer():
...@@ -271,7 +272,6 @@ class Pix2pix(object): ...@@ -271,7 +272,6 @@ class Pix2pix(object):
return return
s_time = time.time() s_time = time.time()
tensor_A, tensor_B = tensor[0]['input_A'], tensor[0]['input_B']
# optimize the generator network # optimize the generator network
g_loss_gan, g_loss_l1, fake_B_tmp = exe.run( g_loss_gan, g_loss_l1, fake_B_tmp = exe.run(
gen_trainer_program, gen_trainer_program,
...@@ -281,17 +281,18 @@ class Pix2pix(object): ...@@ -281,17 +281,18 @@ class Pix2pix(object):
], ],
feed=tensor) feed=tensor)
devices_num = utility.get_device_num(self.cfg)
fake_per_device = int(len(fake_B_tmp) / devices_num)
for dev in range(devices_num):
tensor[dev]['input_fake'] = fake_B_tmp[dev * fake_per_device : (dev+1) * fake_per_device]
# optimize the discriminator network # optimize the discriminator network
d_loss_real, d_loss_fake = exe.run(dis_trainer_program, d_loss_real, d_loss_fake = exe.run(dis_trainer_program,
fetch_list=[ fetch_list=[
dis_trainer.d_loss_real, dis_trainer.d_loss_real,
dis_trainer.d_loss_fake dis_trainer.d_loss_fake
], ],
feed={ feed=tensor)
"input_A": tensor_A,
"input_B": tensor_B,
"input_fake": fake_B_tmp
})
batch_time = time.time() - s_time batch_time = time.time() - s_time
t_time += batch_time t_time += batch_time
......
...@@ -425,3 +425,12 @@ def check_version(): ...@@ -425,3 +425,12 @@ def check_version():
except Exception as e: except Exception as e:
print(err) print(err)
sys.exit(1) sys.exit(1)
def get_device_num(args):
if args.use_gpu:
gpus = os.environ.get("CUDA_VISIBLE_DEVICES", 1)
gpu_num = len(gpus.split(','))
return gpu_num
else:
cpu_num = os.environ.get("CPU_NUM", 1)
return int(cpu_num)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册