提交 f8978a2f 编写于 作者: Z zh-hike 提交者: Walter

修改cifar100参数配置

上级 d8f049ae
......@@ -97,7 +97,6 @@ DataLoader:
- RandCropImageV2:
size: [32, 32]
- NormalizeImage:
- Normalize:
scale: 1.0/255.0
mean: [0.5071, 0.4867, 0.4408]
std: [0.2675, 0.2565, 0.2761]
......@@ -152,7 +151,7 @@ DataLoader:
order: hwc
transform_ops_strong2:
- RandCropImageV2:
- RandomResizedCrop:
size: [32, 32]
- RandFlipImage:
flip_code: 1
......@@ -168,8 +167,8 @@ DataLoader:
p: 0.2
- NormalizeImage:
scale: 1.0/255.0
mean: [0.5071, 0.4867, 0.4408]
std: [0.2675, 0.2565, 0.2761]
mean: [0., 0., 0.]
std: [1., 1., 1.]
order: hwc
......
......@@ -60,7 +60,7 @@ Loss:
UnLabelLoss:
Train:
- CCSSLCeLoss:
- CCSSLCELoss:
weight: 1.
- SoftSupConLoss:
weight: 1.0
......
......@@ -89,12 +89,14 @@ class Cifar100(Cifar100_paddle):
expand_labels=1,
transform_ops=None,
transform_ops_weak=None,
transform_ops_strong=None):
transform_ops_strong=None,
transform_ops_strong2=None):
super().__init__(data_file, mode, None, download, backend)
assert isinstance(expand_labels, int)
self._transform_ops = create_operators(transform_ops)
self._transform_ops_weak = create_operators(transform_ops_weak)
self._transform_ops_strong = create_operators(transform_ops_strong)
self._transform_ops_strong2 = create_operators(transform_ops_strong2)
self.class_num = 100
labels = []
......@@ -117,6 +119,14 @@ class Cifar100(Cifar100_paddle):
image1 = transform(image, self._transform_ops)
image1 = image1.transpose((2, 0, 1))
return (image1, np.int64(label))
elif self._transform_ops_weak and self._transform_ops_strong and self._transform_ops_strong2:
image2 = transform(image, self._transform_ops_weak)
image2 = image2.transpose((2, 0, 1))
image3 = transform(image, self._transform_ops_strong)
image3 = image3.transpose((2, 0, 1))
image4 = transform(image, self._transform_ops_strong2)
image4 = image4.transpose((2, 0, 1))
return (image2, image3, image4, np.int64(label))
elif self._transform_ops_weak and self._transform_ops_strong:
image2 = transform(image, self._transform_ops_weak)
image2 = image2.transpose((2, 0, 1))
......
......@@ -41,7 +41,7 @@ def train_epoch_fixmatch_ccssl(engine, epoch_id, print_batch_step):
engine.unlabel_train_dataloader_iter = iter(engine.unlabel_train_dataloader)
unlabel_data_batch = engine.unlabel_train_dataloader_iter.next()
assert len(unlabel_data_batch) == 4
assert len(unlabel_data_batch) in [3, 4]
assert unlabel_data_batch[0].shape == unlabel_data_batch[1].shape == unlabel_data_batch[2].shape
engine.time_info['reader_cost'].update(time.time() - tic)
......
......@@ -39,4 +39,4 @@ class ExponentialMovingAverage():
self._update(model, update_fn=lambda e, m: self.decay * e + (1. - self.decay) * m)
def set(self, model):
self._update(model, update_fn=lambda e, m: m)
self._update(model, update_fn=lambda e, m: m)
\ No newline at end of file
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册