未验证 提交 2f531842 编写于 作者: L lvmengsi 提交者: GitHub

fix gan in python3 (#2494)

上级 fcc4581e
...@@ -294,7 +294,6 @@ class celeba_reader_creator(reader_creator): ...@@ -294,7 +294,6 @@ class celeba_reader_creator(reader_creator):
img = Image.open(os.path.join(self.image_dir, img = Image.open(os.path.join(self.image_dir,
file)).convert('RGB') file)).convert('RGB')
label = np.array(label).astype("float32") label = np.array(label).astype("float32")
label = (label + 1) // 2
img = CentorCrop(img, args.crop_size, args.crop_size) img = CentorCrop(img, args.crop_size, args.crop_size)
img = img.resize((args.image_size, args.image_size), img = img.resize((args.image_size, args.image_size),
Image.BILINEAR) Image.BILINEAR)
......
...@@ -291,10 +291,12 @@ class AttGAN(object): ...@@ -291,10 +291,12 @@ class AttGAN(object):
label_trg = copy.deepcopy(label_org) label_trg = copy.deepcopy(label_org)
np.random.shuffle(label_trg) np.random.shuffle(label_trg)
label_org_ = map(lambda x: (x * 2.0 - 1.0) * self.cfg.thres_int, label_org_ = list(
label_org) map(lambda x: (x * 2.0 - 1.0) * self.cfg.thres_int,
label_trg_ = map(lambda x: (x * 2.0 - 1.0) * self.cfg.thres_int, label_org))
label_trg) label_trg_ = list(
map(lambda x: (x * 2.0 - 1.0) * self.cfg.thres_int,
label_trg))
tensor_img = fluid.LoDTensor() tensor_img = fluid.LoDTensor()
tensor_label_org = fluid.LoDTensor() tensor_label_org = fluid.LoDTensor()
......
...@@ -296,10 +296,12 @@ class STGAN(object): ...@@ -296,10 +296,12 @@ class STGAN(object):
label_trg = copy.deepcopy(label_org) label_trg = copy.deepcopy(label_org)
np.random.shuffle(label_trg) np.random.shuffle(label_trg)
label_org_ = map(lambda x: (x * 2.0 - 1.0) * self.cfg.thres_int, label_org_ = list(
label_org) map(lambda x: (x * 2.0 - 1.0) * self.cfg.thres_int,
label_trg_ = map(lambda x: (x * 2.0 - 1.0) * self.cfg.thres_int, label_org))
label_trg) label_trg_ = list(
map(lambda x: (x * 2.0 - 1.0) * self.cfg.thres_int,
label_trg))
tensor_img = fluid.LoDTensor() tensor_img = fluid.LoDTensor()
tensor_label_org = fluid.LoDTensor() tensor_label_org = fluid.LoDTensor()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册