未验证 提交 2f531842 编写于 作者: L lvmengsi 提交者: GitHub

fix gan in python3 (#2494)

上级 fcc4581e
......@@ -294,7 +294,6 @@ class celeba_reader_creator(reader_creator):
img = Image.open(os.path.join(self.image_dir,
file)).convert('RGB')
label = np.array(label).astype("float32")
label = (label + 1) // 2
img = CentorCrop(img, args.crop_size, args.crop_size)
img = img.resize((args.image_size, args.image_size),
Image.BILINEAR)
......
......@@ -291,10 +291,12 @@ class AttGAN(object):
label_trg = copy.deepcopy(label_org)
np.random.shuffle(label_trg)
label_org_ = map(lambda x: (x * 2.0 - 1.0) * self.cfg.thres_int,
label_org)
label_trg_ = map(lambda x: (x * 2.0 - 1.0) * self.cfg.thres_int,
label_trg)
label_org_ = list(
map(lambda x: (x * 2.0 - 1.0) * self.cfg.thres_int,
label_org))
label_trg_ = list(
map(lambda x: (x * 2.0 - 1.0) * self.cfg.thres_int,
label_trg))
tensor_img = fluid.LoDTensor()
tensor_label_org = fluid.LoDTensor()
......
......@@ -296,10 +296,12 @@ class STGAN(object):
label_trg = copy.deepcopy(label_org)
np.random.shuffle(label_trg)
label_org_ = map(lambda x: (x * 2.0 - 1.0) * self.cfg.thres_int,
label_org)
label_trg_ = map(lambda x: (x * 2.0 - 1.0) * self.cfg.thres_int,
label_trg)
label_org_ = list(
map(lambda x: (x * 2.0 - 1.0) * self.cfg.thres_int,
label_org))
label_trg_ = list(
map(lambda x: (x * 2.0 - 1.0) * self.cfg.thres_int,
label_trg))
tensor_img = fluid.LoDTensor()
tensor_label_org = fluid.LoDTensor()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册