---
# StarGAN v2 training/evaluation config for the CelebA-HQ dataset
# (starganv2_celeba_hq.yaml).

# Number of training epochs.
epochs: 200
# Root directory for run outputs.
output_dir: output_dir

model:
  name: StarGANv2Model
  # The anchors declared in this section (&LATENT_DIM, &IMAGE_SIZE,
  # &STYLE_DIM, &NUM_DOMAINS) are aliased by the other sub-networks and by
  # the dataset pipelines, so each value only has to be changed in one place.
  latent_dim: &LATENT_DIM 16
  # Loss weights; presumably the style-reconstruction, diversity-sensitive and
  # cycle-consistency terms of StarGAN v2 -- confirm in StarGANv2Model.
  lambda_sty: 1
  lambda_ds: 1
  lambda_cyc: 1
  generator:
    name: StarGANv2Generator
    img_size: &IMAGE_SIZE 256
    # NOTE(review): presumably the high-pass-filter weight of the StarGAN v2
    # reference code -- confirm in StarGANv2Generator.
    w_hpf: 1
    style_dim: &STYLE_DIM 64
  style:
    name: StarGANv2Style
    img_size: *IMAGE_SIZE
    style_dim: *STYLE_DIM
    num_domains: &NUM_DOMAINS 2
  mapping:
    name: StarGANv2Mapping
    latent_dim: *LATENT_DIM
    style_dim: *STYLE_DIM
    num_domains: *NUM_DOMAINS
  fan:
    name: FAN
    # NOTE(review): an unquoted `None` is the *string* "None" in YAML; the null
    # spellings are `null`, `~` or an empty value. If the FAN constructor tests
    # `fname_pretrained is None`, this should become `null`. Left unchanged
    # until the consumer's handling is confirmed.
    fname_pretrained: None
  discriminator:
    name: StarGANv2Discriminator
    img_size: *IMAGE_SIZE
    num_domains: *NUM_DOMAINS

dataset:
  train:
    name: StarGANv2Dataset
    dataroot: data/stargan-v2/celeba_hq/train/
    is_train: true
    num_workers: 8
    batch_size: 4
    # Each training sample loads three images: src, ref and ref2.
    preprocess:
      - name: LoadImageFromFile
        key: src
      - name: LoadImageFromFile
        key: ref
      - name: LoadImageFromFile
        key: ref2
      - name: Transforms
        input_keys: [src, ref, ref2]
        # Each `keys` list below has one entry per element of `input_keys`
        # (presumably required by the Transforms wrapper -- confirm).
        pipeline:
          # Random resized crop applied with probability `prob`; `scale` and
          # `ratio` presumably bound the crop's area fraction and aspect ratio.
          - name: RandomResizedCropProb
            prob: 0.9
            size: [*IMAGE_SIZE, *IMAGE_SIZE]
            scale: [0.8, 1.0]
            ratio: [0.9, 1.1]
            interpolation: 'bilinear'
            keys: [image, image, image]
          # Resize to IMAGE_SIZE x IMAGE_SIZE; presumably needed because the
          # crop above is not always applied.
          - name: Resize
            size: [*IMAGE_SIZE, *IMAGE_SIZE]
            interpolation: 'bilinear'
            keys: [image, image, image]
          - name: RandomHorizontalFlip
            prob: 0.5
            keys: [image, image, image]
          # Presumably converts HWC to CHW layout -- confirm in Transpose.
          - name: Transpose
            keys: [image, image, image]
          # (x - 127.5) / 127.5 maps 8-bit pixel values in [0, 255] to [-1, 1],
          # assuming 8-bit input.
          - name: Normalize
            mean: [127.5, 127.5, 127.5]
            std: [127.5, 127.5, 127.5]
            keys: [image, image, image]

  test:
    name: StarGANv2Dataset
    dataroot: data/stargan-v2/celeba_hq/val/
    is_train: false
    num_workers: 8
    batch_size: 16
    test_count: 16
    # Evaluation uses only src/ref pairs and no random augmentation.
    preprocess:
      - name: LoadImageFromFile
        key: src
      - name: LoadImageFromFile
        key: ref
      - name: Transforms
        input_keys: [src, ref]
        pipeline:
          - name: Resize
            size: [*IMAGE_SIZE, *IMAGE_SIZE]
            interpolation: 'bicubic'  # cv2.INTER_CUBIC
            keys: [image, image]
          - name: Transpose
            keys: [image, image]
          - name: Normalize
            mean: [127.5, 127.5, 127.5]
            std: [127.5, 127.5, 127.5]
            keys: [image, image]

lr_scheduler:
  name: LinearDecay
  # Base learning rate (written with a dot and a signed exponent so that
  # YAML 1.1 loaders such as PyYAML also resolve it as a float).
  learning_rate: 1.0e-4
  # Presumably held constant until `start_epoch`, then decayed linearly over
  # `decay_epochs`; 100 + 100 matches the 200 `epochs` configured above.
  start_epoch: 100
  decay_epochs: 100
  # Placeholder only: the actual iterations-per-epoch is obtained from the
  # real dataset at runtime.
  iters_per_epoch: 365

optimizer:
  # One Adam instance per sub-network; `net_names` presumably lists the
  # network(s) each optimizer updates. All four use identical settings.
  # NOTE(review): none of these sets its own learning rate, so presumably they
  # all follow `lr_scheduler`. The StarGAN v2 reference implementation uses a
  # much smaller rate for the mapping network -- confirm this is intended.
  generator:
    name: Adam
    net_names: [generator]
    beta1: 0.0
    beta2: 0.99
    weight_decay: 1.0e-4
  style_encoder:
    name: Adam
    net_names: [style_encoder]
    beta1: 0.0
    beta2: 0.99
    weight_decay: 1.0e-4
  mapping_network:
    name: Adam
    net_names: [mapping_network]
    beta1: 0.0
    beta2: 0.99
    weight_decay: 1.0e-4
  discriminator:
    name: Adam
    net_names: [discriminator]
    beta1: 0.0
    beta2: 0.99
    weight_decay: 1.0e-4

validate:
  # How often validation runs; the unit is defined by the trainer
  # (presumably iterations, given the magnitude -- confirm).
  interval: 5000
  # Presumably toggles writing generated validation images to disk.
  save_img: false

log_config:
  # Logging frequency (unit defined by the trainer).
  interval: 5
  # NOTE(review): "visiual" is misspelled, but this key must match exactly
  # what the trainer reads; do not rename it here without updating the
  # consumer.
  visiual_interval: 100

snapshot_config:
  # Presumably the checkpoint-saving frequency (unit defined by the trainer).
  interval: 5