未验证 提交 6e12f8a1 编写于 作者: L lvmengsi 提交者: GitHub

change set (#3042)

change set to list in c_gan
上级 d594e88b
......@@ -31,7 +31,6 @@ matplotlib.use('agg')
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
NOISE_SIZE = 100
LEARNING_RATE = 2e-4
......@@ -98,8 +97,7 @@ def train(args):
exe.run(fluid.default_startup_program())
if args.run_ce:
train_reader = paddle.batch(
paddle.dataset.mnist.train(),
batch_size=args.batch_size)
paddle.dataset.mnist.train(), batch_size=args.batch_size)
else:
train_reader = paddle.batch(
paddle.reader.shuffle(
......@@ -111,7 +109,7 @@ def train(args):
low=-1.0, high=1.0,
size=[args.batch_size, NOISE_SIZE]).astype('float32')
t_time = 0
losses = [[],[]]
losses = [[], []]
for pass_id in range(args.epoch):
for batch_id, data in enumerate(train_reader()):
if len(data) != args.batch_size:
......@@ -133,7 +131,7 @@ def train(args):
g_program,
feed={'noise': noise_data,
'conditions': conditions_data},
fetch_list={g_img})[0]
fetch_list=[g_img])[0]
total_images = np.concatenate([real_image, generated_image])
......@@ -143,7 +141,7 @@ def train(args):
'label': fake_labels,
'conditions': conditions_data
},
fetch_list={d_loss})[0][0]
fetch_list=[d_loss])[0][0]
d_loss_2 = exe.run(d_program,
feed={
......@@ -151,7 +149,7 @@ def train(args):
'label': real_labels,
'conditions': conditions_data
},
fetch_list={d_loss})[0][0]
fetch_list=[d_loss])[0][0]
d_loss_n = d_loss_1 + d_loss_2
losses[0].append(d_loss_n)
......@@ -163,13 +161,11 @@ def train(args):
dg_program,
feed={'noise': noise_data,
'conditions': conditions_data},
fetch_list={dg_loss})[0][0]
fetch_list=[dg_loss])[0][0]
losses[1].append(dg_loss_n)
batch_time = time.time() - s_time
t_time += batch_time
if batch_id % 10 == 0 and not args.run_ce:
if not os.path.exists(args.output):
os.makedirs(args.output)
......@@ -178,7 +174,7 @@ def train(args):
g_program_test,
feed={'noise': const_n,
'conditions': conditions_data},
fetch_list={g_img})[0]
fetch_list=[g_img])[0]
total_images = np.concatenate([real_image, generated_images])
fig = plot(total_images)
msg = "Epoch ID={0}\n Batch ID={1}\n D-Loss={2}\n DG-Loss={3}\n gen={4}\n " \
......@@ -196,7 +192,7 @@ def train(args):
print("kpis,cgan_d_train_cost,{}".format(np.mean(losses[0])))
print("kpis,cgan_g_train_cost,{}".format(np.mean(losses[1])))
print("kpis,cgan_duration,{}".format(t_time / args.epoch))
if __name__ == "__main__":
args = parser.parse_args()
......
......@@ -93,8 +93,7 @@ def train(args):
if args.run_ce:
train_reader = paddle.batch(
paddle.dataset.mnist.train(),
batch_size=args.batch_size)
paddle.dataset.mnist.train(), batch_size=args.batch_size)
else:
train_reader = paddle.batch(
paddle.reader.shuffle(
......@@ -125,7 +124,7 @@ def train(args):
s_time = time.time()
generated_image = exe.run(g_program,
feed={'noise': noise_data},
fetch_list={g_img})[0]
fetch_list=[g_img])[0]
total_images = np.concatenate([real_image, generated_image])
......@@ -134,14 +133,14 @@ def train(args):
'img': generated_image,
'label': fake_labels,
},
fetch_list={d_loss})[0][0]
fetch_list=[d_loss])[0][0]
d_loss_2 = exe.run(d_program,
feed={
'img': real_image,
'label': real_labels,
},
fetch_list={d_loss})[0][0]
fetch_list=[d_loss])[0][0]
d_loss_n = d_loss_1 + d_loss_2
losses[0].append(d_loss_n)
......@@ -150,8 +149,8 @@ def train(args):
low=-1.0, high=1.0,
size=[args.batch_size, NOISE_SIZE]).astype('float32')
dg_loss_n = exe.run(dg_program,
feed={'noise': noise_data},
fetch_list={dg_loss})[0][0]
feed={'noise': noise_data},
fetch_list=[dg_loss])[0][0]
losses[1].append(dg_loss_n)
t_time += (time.time() - s_time)
if batch_id % 10 == 0 and not args.run_ce:
......@@ -160,12 +159,12 @@ def train(args):
# generate image each batch
generated_images = exe.run(g_program_test,
feed={'noise': const_n},
fetch_list={g_img})[0]
fetch_list=[g_img])[0]
total_images = np.concatenate([real_image, generated_images])
fig = plot(total_images)
msg = "Epoch ID={0} Batch ID={1} D-Loss={2} DG-Loss={3}\n gen={4}".format(
pass_id, batch_id,
d_loss_n, dg_loss_n, check(generated_images))
pass_id, batch_id, d_loss_n, dg_loss_n,
check(generated_images))
print(msg)
plt.title(msg)
plt.savefig(
......@@ -177,7 +176,7 @@ def train(args):
print("kpis,dcgan_d_train_cost,{}".format(np.mean(losses[0])))
print("kpis,dcgan_g_train_cost,{}".format(np.mean(losses[1])))
print("kpis,dcgan_duration,{}".format(t_time / args.epoch))
if __name__ == "__main__":
args = parser.parse_args()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册