未验证 提交 f8da1e2e 编写于 作者: L lvmengsi 提交者: GitHub

fix infer (#2819)

上级 cea2bbcc
...@@ -177,8 +177,8 @@ def infer(args): ...@@ -177,8 +177,8 @@ def infer(args):
label_trg_ = list( label_trg_ = list(
map(lambda x: ((x * 2) - 1) * 0.5, label_trg_tmp)) map(lambda x: ((x * 2) - 1) * 0.5, label_trg_tmp))
if args.model_net == 'AttGAN': if args.model_net == 'AttGAN':
for j in range(len(label_org)): for k in range(len(label_org)):
label_trg_[j][i] = label_trg_[j][i] * 2.0 label_trg_[k][i] = label_trg_[k][i] * 2.0
tensor_label_org_.set(label_org_, place) tensor_label_org_.set(label_org_, place)
tensor_label_trg.set(label_trg, place) tensor_label_trg.set(label_trg, place)
tensor_label_trg_.set(label_trg_, place) tensor_label_trg_.set(label_trg_, place)
......
...@@ -160,8 +160,8 @@ def save_test_image(epoch, ...@@ -160,8 +160,8 @@ def save_test_image(epoch,
map(lambda x: ((x * 2) - 1) * 0.5, label_trg_tmp)) map(lambda x: ((x * 2) - 1) * 0.5, label_trg_tmp))
if cfg.model_net == 'AttGAN': if cfg.model_net == 'AttGAN':
for j in range(len(label_org)): for k in range(len(label_org)):
label_trg_[j][i] = label_trg_[j][i] * 2.0 label_trg_[k][i] = label_trg_[k][i] * 2.0
tensor_label_org_.set(label_org_, place) tensor_label_org_.set(label_org_, place)
tensor_label_trg.set(label_trg, place) tensor_label_trg.set(label_trg, place)
tensor_label_trg_.set(label_trg_, place) tensor_label_trg_.set(label_trg_, place)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册