未验证 提交 6e12f8a1 编写于 作者: L lvmengsi 提交者: GitHub

change set (#3042)

change set to list in c_gan
上级 d594e88b
...@@ -31,7 +31,6 @@ matplotlib.use('agg') ...@@ -31,7 +31,6 @@ matplotlib.use('agg')
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec import matplotlib.gridspec as gridspec
NOISE_SIZE = 100 NOISE_SIZE = 100
LEARNING_RATE = 2e-4 LEARNING_RATE = 2e-4
...@@ -98,8 +97,7 @@ def train(args): ...@@ -98,8 +97,7 @@ def train(args):
exe.run(fluid.default_startup_program()) exe.run(fluid.default_startup_program())
if args.run_ce: if args.run_ce:
train_reader = paddle.batch( train_reader = paddle.batch(
paddle.dataset.mnist.train(), paddle.dataset.mnist.train(), batch_size=args.batch_size)
batch_size=args.batch_size)
else: else:
train_reader = paddle.batch( train_reader = paddle.batch(
paddle.reader.shuffle( paddle.reader.shuffle(
...@@ -111,7 +109,7 @@ def train(args): ...@@ -111,7 +109,7 @@ def train(args):
low=-1.0, high=1.0, low=-1.0, high=1.0,
size=[args.batch_size, NOISE_SIZE]).astype('float32') size=[args.batch_size, NOISE_SIZE]).astype('float32')
t_time = 0 t_time = 0
losses = [[],[]] losses = [[], []]
for pass_id in range(args.epoch): for pass_id in range(args.epoch):
for batch_id, data in enumerate(train_reader()): for batch_id, data in enumerate(train_reader()):
if len(data) != args.batch_size: if len(data) != args.batch_size:
...@@ -133,7 +131,7 @@ def train(args): ...@@ -133,7 +131,7 @@ def train(args):
g_program, g_program,
feed={'noise': noise_data, feed={'noise': noise_data,
'conditions': conditions_data}, 'conditions': conditions_data},
fetch_list={g_img})[0] fetch_list=[g_img])[0]
total_images = np.concatenate([real_image, generated_image]) total_images = np.concatenate([real_image, generated_image])
...@@ -143,7 +141,7 @@ def train(args): ...@@ -143,7 +141,7 @@ def train(args):
'label': fake_labels, 'label': fake_labels,
'conditions': conditions_data 'conditions': conditions_data
}, },
fetch_list={d_loss})[0][0] fetch_list=[d_loss])[0][0]
d_loss_2 = exe.run(d_program, d_loss_2 = exe.run(d_program,
feed={ feed={
...@@ -151,7 +149,7 @@ def train(args): ...@@ -151,7 +149,7 @@ def train(args):
'label': real_labels, 'label': real_labels,
'conditions': conditions_data 'conditions': conditions_data
}, },
fetch_list={d_loss})[0][0] fetch_list=[d_loss])[0][0]
d_loss_n = d_loss_1 + d_loss_2 d_loss_n = d_loss_1 + d_loss_2
losses[0].append(d_loss_n) losses[0].append(d_loss_n)
...@@ -163,13 +161,11 @@ def train(args): ...@@ -163,13 +161,11 @@ def train(args):
dg_program, dg_program,
feed={'noise': noise_data, feed={'noise': noise_data,
'conditions': conditions_data}, 'conditions': conditions_data},
fetch_list={dg_loss})[0][0] fetch_list=[dg_loss])[0][0]
losses[1].append(dg_loss_n) losses[1].append(dg_loss_n)
batch_time = time.time() - s_time batch_time = time.time() - s_time
t_time += batch_time t_time += batch_time
if batch_id % 10 == 0 and not args.run_ce: if batch_id % 10 == 0 and not args.run_ce:
if not os.path.exists(args.output): if not os.path.exists(args.output):
os.makedirs(args.output) os.makedirs(args.output)
...@@ -178,7 +174,7 @@ def train(args): ...@@ -178,7 +174,7 @@ def train(args):
g_program_test, g_program_test,
feed={'noise': const_n, feed={'noise': const_n,
'conditions': conditions_data}, 'conditions': conditions_data},
fetch_list={g_img})[0] fetch_list=[g_img])[0]
total_images = np.concatenate([real_image, generated_images]) total_images = np.concatenate([real_image, generated_images])
fig = plot(total_images) fig = plot(total_images)
msg = "Epoch ID={0}\n Batch ID={1}\n D-Loss={2}\n DG-Loss={3}\n gen={4}\n " \ msg = "Epoch ID={0}\n Batch ID={1}\n D-Loss={2}\n DG-Loss={3}\n gen={4}\n " \
...@@ -196,7 +192,7 @@ def train(args): ...@@ -196,7 +192,7 @@ def train(args):
print("kpis,cgan_d_train_cost,{}".format(np.mean(losses[0]))) print("kpis,cgan_d_train_cost,{}".format(np.mean(losses[0])))
print("kpis,cgan_g_train_cost,{}".format(np.mean(losses[1]))) print("kpis,cgan_g_train_cost,{}".format(np.mean(losses[1])))
print("kpis,cgan_duration,{}".format(t_time / args.epoch)) print("kpis,cgan_duration,{}".format(t_time / args.epoch))
if __name__ == "__main__": if __name__ == "__main__":
args = parser.parse_args() args = parser.parse_args()
......
...@@ -93,8 +93,7 @@ def train(args): ...@@ -93,8 +93,7 @@ def train(args):
if args.run_ce: if args.run_ce:
train_reader = paddle.batch( train_reader = paddle.batch(
paddle.dataset.mnist.train(), paddle.dataset.mnist.train(), batch_size=args.batch_size)
batch_size=args.batch_size)
else: else:
train_reader = paddle.batch( train_reader = paddle.batch(
paddle.reader.shuffle( paddle.reader.shuffle(
...@@ -125,7 +124,7 @@ def train(args): ...@@ -125,7 +124,7 @@ def train(args):
s_time = time.time() s_time = time.time()
generated_image = exe.run(g_program, generated_image = exe.run(g_program,
feed={'noise': noise_data}, feed={'noise': noise_data},
fetch_list={g_img})[0] fetch_list=[g_img])[0]
total_images = np.concatenate([real_image, generated_image]) total_images = np.concatenate([real_image, generated_image])
...@@ -134,14 +133,14 @@ def train(args): ...@@ -134,14 +133,14 @@ def train(args):
'img': generated_image, 'img': generated_image,
'label': fake_labels, 'label': fake_labels,
}, },
fetch_list={d_loss})[0][0] fetch_list=[d_loss])[0][0]
d_loss_2 = exe.run(d_program, d_loss_2 = exe.run(d_program,
feed={ feed={
'img': real_image, 'img': real_image,
'label': real_labels, 'label': real_labels,
}, },
fetch_list={d_loss})[0][0] fetch_list=[d_loss])[0][0]
d_loss_n = d_loss_1 + d_loss_2 d_loss_n = d_loss_1 + d_loss_2
losses[0].append(d_loss_n) losses[0].append(d_loss_n)
...@@ -150,8 +149,8 @@ def train(args): ...@@ -150,8 +149,8 @@ def train(args):
low=-1.0, high=1.0, low=-1.0, high=1.0,
size=[args.batch_size, NOISE_SIZE]).astype('float32') size=[args.batch_size, NOISE_SIZE]).astype('float32')
dg_loss_n = exe.run(dg_program, dg_loss_n = exe.run(dg_program,
feed={'noise': noise_data}, feed={'noise': noise_data},
fetch_list={dg_loss})[0][0] fetch_list=[dg_loss])[0][0]
losses[1].append(dg_loss_n) losses[1].append(dg_loss_n)
t_time += (time.time() - s_time) t_time += (time.time() - s_time)
if batch_id % 10 == 0 and not args.run_ce: if batch_id % 10 == 0 and not args.run_ce:
...@@ -160,12 +159,12 @@ def train(args): ...@@ -160,12 +159,12 @@ def train(args):
# generate image each batch # generate image each batch
generated_images = exe.run(g_program_test, generated_images = exe.run(g_program_test,
feed={'noise': const_n}, feed={'noise': const_n},
fetch_list={g_img})[0] fetch_list=[g_img])[0]
total_images = np.concatenate([real_image, generated_images]) total_images = np.concatenate([real_image, generated_images])
fig = plot(total_images) fig = plot(total_images)
msg = "Epoch ID={0} Batch ID={1} D-Loss={2} DG-Loss={3}\n gen={4}".format( msg = "Epoch ID={0} Batch ID={1} D-Loss={2} DG-Loss={3}\n gen={4}".format(
pass_id, batch_id, pass_id, batch_id, d_loss_n, dg_loss_n,
d_loss_n, dg_loss_n, check(generated_images)) check(generated_images))
print(msg) print(msg)
plt.title(msg) plt.title(msg)
plt.savefig( plt.savefig(
...@@ -177,7 +176,7 @@ def train(args): ...@@ -177,7 +176,7 @@ def train(args):
print("kpis,dcgan_d_train_cost,{}".format(np.mean(losses[0]))) print("kpis,dcgan_d_train_cost,{}".format(np.mean(losses[0])))
print("kpis,dcgan_g_train_cost,{}".format(np.mean(losses[1]))) print("kpis,dcgan_g_train_cost,{}".format(np.mean(losses[1])))
print("kpis,dcgan_duration,{}".format(t_time / args.epoch)) print("kpis,dcgan_duration,{}".format(t_time / args.epoch))
if __name__ == "__main__": if __name__ == "__main__":
args = parser.parse_args() args = parser.parse_args()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册