From 24f72797d13cf912678e96c5932449713437a151 Mon Sep 17 00:00:00 2001 From: lvmengsi Date: Mon, 24 Jun 2019 11:57:02 +0800 Subject: [PATCH] Fix py3 in cyclegan and save image (#2499) * fix py3 --- PaddleCV/gan/trainer/CycleGAN.py | 12 ++++++------ PaddleCV/gan/util/utility.py | 3 ++- 2 files changed, 8 insertions(+), 7 deletions(-) diff --git a/PaddleCV/gan/trainer/CycleGAN.py b/PaddleCV/gan/trainer/CycleGAN.py index a8b9561e..751695f9 100644 --- a/PaddleCV/gan/trainer/CycleGAN.py +++ b/PaddleCV/gan/trainer/CycleGAN.py @@ -96,11 +96,11 @@ class GTrainer(): learning_rate=fluid.layers.piecewise_decay( boundaries=[99 * step_per_epoch] + [ x * step_per_epoch - for x in xrange(100, cfg.epoch - 1) + for x in range(100, cfg.epoch - 1) ], values=[lr] + [ lr * (1.0 - (x - 99.0) / 101.0) - for x in xrange(100, cfg.epoch) + for x in range(100, cfg.epoch) ]), beta1=0.5, beta2=0.999, @@ -136,11 +136,11 @@ class DATrainer(): learning_rate=fluid.layers.piecewise_decay( boundaries=[99 * step_per_epoch] + [ x * step_per_epoch - for x in xrange(100, cfg.epoch - 1) + for x in range(100, cfg.epoch - 1) ], values=[lr] + [ lr * (1.0 - (x - 99.0) / 101.0) - for x in xrange(100, cfg.epoch) + for x in range(100, cfg.epoch) ]), beta1=0.5, beta2=0.999, @@ -175,11 +175,11 @@ class DBTrainer(): learning_rate=fluid.layers.piecewise_decay( boundaries=[99 * step_per_epoch] + [ x * step_per_epoch - for x in xrange(100, cfg.epoch - 1) + for x in range(100, cfg.epoch - 1) ], values=[lr] + [ lr * (1.0 - (x - 99.0) / 101.0) - for x in xrange(100, cfg.epoch) + for x in range(100, cfg.epoch) ]), beta1=0.5, beta2=0.999, diff --git a/PaddleCV/gan/util/utility.py b/PaddleCV/gan/util/utility.py index d14878a8..aacfb1e6 100644 --- a/PaddleCV/gan/util/utility.py +++ b/PaddleCV/gan/util/utility.py @@ -149,7 +149,8 @@ def save_test_image(epoch, for j in range(len(label_org)): label_trg_tmp[j][i] = 1.0 - label_trg_tmp[j][i] - label_trg_ = map(lambda x: ((x * 2) - 1) * 0.5, label_trg_tmp) + label_trg_ = list( + map(lambda x: ((x * 2) - 1) * 0.5, label_trg_tmp)) for j in range(len(label_org)): label_trg_[j][i] = label_trg_[j][i] * 2.0 -- GitLab