# starganv2_celeba_hq.yaml
# Top-level run settings.
# `epochs` equals lr_scheduler.start_epoch + lr_scheduler.decay_epochs
# (100 + 100) as configured further down in this file.
epochs: 200
# NOTE(review): presumably the directory the run writes its outputs to,
# given as a relative path - confirm against the consumer.
output_dir: output_dir

# Model definition: the top-level StarGANv2Model plus one sub-section per
# network. Values tagged with an anchor (&NAME) are reused through aliases
# (*NAME) elsewhere in this file, so editing the anchored value changes every
# place that aliases it.
model:
  name: StarGANv2Model
  # Anchored as LATENT_DIM; aliased by mapping.latent_dim.
  latent_dim: &LATENT_DIM 16
  # NOTE(review): presumably weights of the style, diversity and cycle loss
  # terms - confirm against the model implementation.
  lambda_sty: 1
  lambda_ds: 1
  lambda_cyc: 1
  generator:
    name: StarGANv2Generator
    # Anchored as IMAGE_SIZE; aliased by style.img_size,
    # discriminator.img_size and the crop/resize sizes in the dataset
    # pipelines.
    img_size: &IMAGE_SIZE 256
    w_hpf: 1
    # Anchored as STYLE_DIM; aliased by style.style_dim and mapping.style_dim.
    style_dim: &STYLE_DIM 64
  style:
    name: StarGANv2Style
    img_size: *IMAGE_SIZE
    style_dim: *STYLE_DIM
    # Anchored as NUM_DOMAINS; aliased by mapping.num_domains and
    # discriminator.num_domains.
    num_domains: &NUM_DOMAINS 2
  mapping:
    name: StarGANv2Mapping
    latent_dim: *LATENT_DIM
    style_dim: *STYLE_DIM
    num_domains: *NUM_DOMAINS
  fan:
    name: FAN
    # NOTE(review): looks like a relative path to pretrained weights for this
    # network - confirm the file exists where the consumer resolves it.
    fname_pretrained: models/stargan-v2/wing.pdparams
  discriminator:
    name: StarGANv2Discriminator
    img_size: *IMAGE_SIZE
    num_domains: *NUM_DOMAINS

# Dataset definitions for the train and test splits. Both use
# StarGANv2Dataset and resize to IMAGE_SIZE, which is anchored under
# model.generator.img_size.
dataset:
  train:
    name: StarGANv2Dataset
    dataroot: data/stargan-v2/celeba_hq/train/
    is_train: true
    num_workers: 8
    batch_size: 4
    # Three images are loaded (keys src, ref, ref2) and passed together
    # through one transform pipeline.
    preprocess:
      - name: LoadImageFromFile
        key: src
      - name: LoadImageFromFile
        key: ref
      - name: LoadImageFromFile
        key: ref2
      - name: Transforms
        input_keys: [src, ref, ref2]
        # Every step lists one `image` entry per input key.
        pipeline:
          - name: RandomResizedCropProb
            prob: 0.9
            size: [*IMAGE_SIZE, *IMAGE_SIZE]
            scale: [0.8, 1.0]
            ratio: [0.9, 1.1]
            interpolation: 'bilinear'
            keys: [image, image, image]
          - name: Resize
            size: [*IMAGE_SIZE, *IMAGE_SIZE]
            interpolation: 'bilinear'
            keys: [image, image, image]
          - name: RandomHorizontalFlip
            prob: 0.5
            keys: [image, image, image]
          - name: Transpose
            keys: [image, image, image]
          # NOTE(review): assuming Normalize computes (x - mean) / std on
          # 0-255 pixel values, this maps them to [-1, 1] - confirm.
          - name: Normalize
            mean: [127.5, 127.5, 127.5]
            std: [127.5, 127.5, 127.5]
            keys: [image, image, image]

  test:
    name: StarGANv2Dataset
    dataroot: data/stargan-v2/celeba_hq/val/
    is_train: false
    num_workers: 8
    batch_size: 16
    test_count: 16
    # Two images are loaded (keys src, ref); the pipeline has no random
    # steps, only Resize, Transpose and Normalize.
    preprocess:
      - name: LoadImageFromFile
        key: src
      - name: LoadImageFromFile
        key: ref
      - name: Transforms
        input_keys: [src, ref]
        pipeline:
          - name: Resize
            size: [*IMAGE_SIZE, *IMAGE_SIZE]
            interpolation: 'bicubic'  # cv2.INTER_CUBIC
            keys: [image, image]
          - name: Transpose
            keys: [image, image]
          - name: Normalize
            mean: [127.5, 127.5, 127.5]
            std: [127.5, 127.5, 127.5]
            keys: [image, image]

# Learning-rate schedule (LinearDecay).
# start_epoch (100) + decay_epochs (100) equals the top-level `epochs` (200).
# NOTE(review): presumably the rate stays at learning_rate until start_epoch
# and then decays linearly over decay_epochs - confirm against the scheduler.
lr_scheduler:
  name: LinearDecay
  learning_rate: 0.0001
  start_epoch: 100
  decay_epochs: 100
  # Placeholder: the actual value is taken from the real dataset at run time.
  iters_per_epoch: 365

# One Adam optimizer per network. All four entries share the same settings
# (beta1 0.0, beta2 0.99, weight_decay 0.0001) and differ only in net_names.
# NOTE(review): `style_encoder` and `mapping_network` do not match the `style`
# and `mapping` keys used under `model:`; presumably net_names refers to names
# the model class registers internally - confirm against the consumer.
optimizer:
  generator:
    name: Adam
    net_names: [generator]
    beta1: 0.0
    beta2: 0.99
    weight_decay: 0.0001
  style_encoder:
    name: Adam
    net_names: [style_encoder]
    beta1: 0.0
    beta2: 0.99
    weight_decay: 0.0001
  mapping_network:
    name: Adam
    net_names: [mapping_network]
    beta1: 0.0
    beta2: 0.99
    weight_decay: 0.0001
  discriminator:
    name: Adam
    net_names: [discriminator]
    beta1: 0.0
    beta2: 0.99
    weight_decay: 0.0001

# Validation settings.
# NOTE(review): the unit of `interval` (iterations or epochs) is not visible
# in this file - confirm against the consumer.
validate:
  interval: 5000
  save_img: false

# Logging settings.
# NOTE(review): the unit of both intervals is not visible in this file.
log_config:
  interval: 5
  # NOTE(review): the key is spelled `visiual_interval` (sic). It is left
  # unchanged because the consumer may look it up under this exact name -
  # confirm before renaming.
  visiual_interval: 100

# Snapshot settings.
# NOTE(review): presumably how often a checkpoint is saved; the unit of
# `interval` is not visible in this file - confirm against the consumer.
snapshot_config:
  interval: 5