---
# StarGAN v2 training configuration for CelebA-HQ (starganv2_celeba_hq.yaml).

# Number of training epochs; equals lr_scheduler start_epoch + decay_epochs
# (100 + 100) as configured below.
epochs: 200
output_dir: output_dir

model:
  name: StarGANv2Model
  # &LATENT_DIM, &IMAGE_SIZE, &STYLE_DIM and &NUM_DOMAINS below are YAML
  # anchors; `*NAME` aliases elsewhere in this file (including the dataset
  # transforms) reuse these exact values, so change sizes only at the anchor.
  latent_dim: &LATENT_DIM 16
  lambda_sty: 1
  lambda_ds: 1
  lambda_cyc: 1
  generator:
    name: StarGANv2Generator
    img_size: &IMAGE_SIZE 256
    # NOTE(review): presumably w_hpf > 0 enables the FAN-based high-pass
    # path configured under `fan` — confirm in StarGANv2Model.
    w_hpf: 1
    style_dim: &STYLE_DIM 64
  style:
    name: StarGANv2Style
    img_size: *IMAGE_SIZE
    style_dim: *STYLE_DIM
    num_domains: &NUM_DOMAINS 2
  mapping:
    name: StarGANv2Mapping
    latent_dim: *LATENT_DIM
    style_dim: *STYLE_DIM
    num_domains: *NUM_DOMAINS
  fan:
    name: FAN
    # NOTE(review): YAML does not treat a bare `None` as null; it loads as the
    # string 'None'. Left unchanged because how FAN handles this value is not
    # visible here. If "no pretrained file" is intended, `null` is the YAML
    # spelling, but verify the FAN constructor first.
    fname_pretrained: None
  discriminator:
    name: StarGANv2Discriminator
    img_size: *IMAGE_SIZE
    num_domains: *NUM_DOMAINS

dataset:
  train:
    name: StarGANv2Dataset
    dataroot: data/stargan-v2/celeba_hq/train/
    is_train: true
    num_workers: 8
    batch_size: 4
    # Loads three images per sample, under the keys src, ref and ref2.
    preprocess:
      - name: LoadImageFromFile
        key: src
      - name: LoadImageFromFile
        key: ref
      - name: LoadImageFromFile
        key: ref2
      # Runs the pipeline below over the entries named in input_keys; every
      # transform's `keys` list has one entry per input key (presumably the
      # kind of data at that position — verify against Transforms).
      - name: Transforms
        input_keys: [src, ref, ref2]
        pipeline:
          # Random resized crop, applied with probability `prob`.
          - name: RandomResizedCropProb
            prob: 0.9
            size: [*IMAGE_SIZE, *IMAGE_SIZE]
            scale: [0.8, 1.0]
            ratio: [0.9, 1.1]
            interpolation: 'bilinear'
            keys: [image, image, image]
          # Resize to IMAGE_SIZE x IMAGE_SIZE; presumably guarantees the final
          # size when the probabilistic crop above is skipped.
          - name: Resize
            size: [*IMAGE_SIZE, *IMAGE_SIZE]
            interpolation: 'bilinear'
            keys: [image, image, image]
          - name: RandomHorizontalFlip
            prob: 0.5
            keys: [image, image, image]
          - name: Transpose
            keys: [image, image, image]
          # Per-channel (x - 127.5) / 127.5, presumably mapping 0-255 pixel
          # values to [-1, 1].
          - name: Normalize
            mean: [127.5, 127.5, 127.5]
            std: [127.5, 127.5, 127.5]
            keys: [image, image, image]

  test:
    name: StarGANv2Dataset
    dataroot: data/stargan-v2/celeba_hq/val/
    is_train: false
    num_workers: 8
    batch_size: 16
    test_count: 16
    # Loads two images per sample, under the keys src and ref.
    preprocess:
      - name: LoadImageFromFile
        key: src
      - name: LoadImageFromFile
        key: ref
      - name: Transforms
        input_keys: [src, ref]
        pipeline:
          - name: Resize
            size: [*IMAGE_SIZE, *IMAGE_SIZE]
            interpolation: 'bicubic'  # cv2.INTER_CUBIC
            keys: [image, image]
          - name: Transpose
            keys: [image, image]
          - name: Normalize
            mean: [127.5, 127.5, 127.5]
            std: [127.5, 127.5, 127.5]
            keys: [image, image]

lr_scheduler:
  name: LinearDecay
  # Placeholder only: the real iterations-per-epoch value is taken from the
  # actual dataset at runtime.
  iters_per_epoch: 365
  learning_rate: 0.0001
  # Presumably the decay begins at start_epoch and lasts decay_epochs epochs
  # (100 + 100 = the 200 configured epochs) — confirm in LinearDecay.
  start_epoch: 100
  decay_epochs: 100

optimizer:
  # One Adam optimizer per network, all with the same hyper-parameters.
  # `net_names` names the network(s) each optimizer is attached to
  # (presumably matched against the model's network names — verify).
  generator:
    name: Adam
    net_names: [generator]
    beta1: 0.0
    beta2: 0.99
    weight_decay: 0.0001
  style_encoder:
    name: Adam
    net_names: [style_encoder]
    beta1: 0.0
    beta2: 0.99
    weight_decay: 0.0001
  mapping_network:
    name: Adam
    net_names: [mapping_network]
    beta1: 0.0
    beta2: 0.99
    weight_decay: 0.0001
  discriminator:
    name: Adam
    net_names: [discriminator]
    beta1: 0.0
    beta2: 0.99
    weight_decay: 0.0001

validate:
  # Validation frequency (units not shown here; presumably iterations).
  interval: 3000
  save_img: false

log_config:
  interval: 100
  # NOTE(review): "visiual" is misspelled, but presumably the consuming code
  # looks up this exact key. Do not rename it without updating the consumer.
  visiual_interval: 3000

# Snapshot settings (interval units not shown here; presumably epochs).
snapshot_config: {interval: 5}