From f621144caf2598ad9afcbeda5bcb720909c2a2ac Mon Sep 17 00:00:00 2001 From: xiaoting <31891223+tink2123@users.noreply.github.com> Date: Thu, 22 Aug 2019 15:34:02 +0800 Subject: [PATCH] fix fetch_list type of gan (#795) --- 02.recognize_digits/README.cn.md | 2 +- 02.recognize_digits/index.cn.html | 2 +- 09.gan/README.cn.md | 10 +++++----- 09.gan/index.cn.html | 10 +++++----- 4 files changed, 12 insertions(+), 12 deletions(-) diff --git a/02.recognize_digits/README.cn.md b/02.recognize_digits/README.cn.md index fd8160a..edb4ea2 100644 --- a/02.recognize_digits/README.cn.md +++ b/02.recognize_digits/README.cn.md @@ -399,7 +399,7 @@ prediction, [avg_loss, acc] = train_program() # 输入的原始图像数据,名称为img,大小为28*28*1 # 标签层,名称为label,对应输入图片的类别标签 # 告知网络传入的数据分为两部分,第一部分是img值,第二部分是label值 -feeder = fluid.DataFeeder(feed_list=[‘img’, ‘label’], place=place) +feeder = fluid.DataFeeder(feed_list=['img', 'label'], place=place) # 选择Adam优化器 optimizer = optimizer_program() diff --git a/02.recognize_digits/index.cn.html b/02.recognize_digits/index.cn.html index e40718b..816f56a 100644 --- a/02.recognize_digits/index.cn.html +++ b/02.recognize_digits/index.cn.html @@ -441,7 +441,7 @@ prediction, [avg_loss, acc] = train_program() # 输入的原始图像数据,名称为img,大小为28*28*1 # 标签层,名称为label,对应输入图片的类别标签 # 告知网络传入的数据分为两部分,第一部分是img值,第二部分是label值 -feeder = fluid.DataFeeder(feed_list=[‘img’, ‘label’], place=place) +feeder = fluid.DataFeeder(feed_list=['img', 'label'], place=place) # 选择Adam优化器 optimizer = optimizer_program() diff --git a/09.gan/README.cn.md b/09.gan/README.cn.md index 5934efa..10266d5 100644 --- a/09.gan/README.cn.md +++ b/09.gan/README.cn.md @@ -368,7 +368,7 @@ for pass_id in range(epoch): # 虚假图片 generated_image = exe.run(g_program, feed={'noise': noise_data}, - fetch_list={g_img})[0] + fetch_list=[g_img])[0] total_images = np.concatenate([real_image, generated_image]) @@ -378,7 +378,7 @@ for pass_id in range(epoch): 'img': generated_image, 'label': fake_labels, }, - fetch_list={d_loss})[0][0] + fetch_list=[d_loss])[0][0] # D 判断真实图片为真的loss d_loss_2 = exe.run(d_program, @@ -386,7 +386,7 @@ for pass_id in range(epoch): 'img': real_image, 'label': real_labels, }, - fetch_list={d_loss})[0][0] + fetch_list=[d_loss])[0][0] d_loss_n = d_loss_1 + d_loss_2 losses[0].append(d_loss_n) @@ -398,7 +398,7 @@ for pass_id in range(epoch): size=[batch_size, NOISE_SIZE]).astype('float32') dg_loss_n = exe.run(dg_program, feed={'noise': noise_data}, - fetch_list={dg_loss})[0][0] + fetch_list=[dg_loss])[0][0] losses[1].append(dg_loss_n) t_time += (time.time() - s_time) if batch_id % 10 == 0 : @@ -407,7 +407,7 @@ for pass_id in range(epoch): # 每轮的生成结果 generated_images = exe.run(g_program_test, feed={'noise': const_n}, - fetch_list={g_img})[0] + fetch_list=[g_img])[0] # 将真实图片和生成图片连接 total_images = np.concatenate([real_image, generated_images]) fig = plot(total_images) diff --git a/09.gan/index.cn.html b/09.gan/index.cn.html index a80ffc3..ca7ee66 100644 --- a/09.gan/index.cn.html +++ b/09.gan/index.cn.html @@ -410,7 +410,7 @@ for pass_id in range(epoch): # 虚假图片 generated_image = exe.run(g_program, feed={'noise': noise_data}, - fetch_list={g_img})[0] + fetch_list=[g_img])[0] total_images = np.concatenate([real_image, generated_image]) @@ -420,7 +420,7 @@ for pass_id in range(epoch): 'img': generated_image, 'label': fake_labels, }, - fetch_list={d_loss})[0][0] + fetch_list=[d_loss])[0][0] # D 判断真实图片为真的loss d_loss_2 = exe.run(d_program, @@ -428,7 +428,7 @@ for pass_id in range(epoch): 'img': real_image, 'label': real_labels, }, - fetch_list={d_loss})[0][0] + fetch_list=[d_loss])[0][0] d_loss_n = d_loss_1 + d_loss_2 losses[0].append(d_loss_n) @@ -440,7 +440,7 @@ for pass_id in range(epoch): size=[batch_size, NOISE_SIZE]).astype('float32') dg_loss_n = exe.run(dg_program, feed={'noise': noise_data}, - fetch_list={dg_loss})[0][0] + fetch_list=[dg_loss])[0][0] losses[1].append(dg_loss_n) t_time += (time.time() - s_time) if batch_id % 10 == 0 : @@ -449,7 +449,7 @@ for pass_id in range(epoch): # 每轮的生成结果 generated_images = exe.run(g_program_test, feed={'noise': const_n}, - fetch_list={g_img})[0] + fetch_list=[g_img])[0] # 将真实图片和生成图片连接 total_images = np.concatenate([real_image, generated_images]) fig = plot(total_images) -- GitLab