提交 4fdcda7c 编写于 作者: H HydrogenSulfate

fix bug in randaug, train_progressive and efficientnet_v2

上级 7e12c73e
......@@ -268,7 +268,8 @@ v2_xl_block = [ # only for 21k pretraining.
]
efficientnetv2_params = {
# params: (block, width, depth, dropout)
"efficientnetv2-s": (v2_s_block, 1.0, 1.0, np.linspace(0.1, 0.3, 4)),
"efficientnetv2-s":
(v2_s_block, 1.0, 1.0, np.linspace(0.1, 0.3, 4).tolist()),
"efficientnetv2-m": (v2_m_block, 1.0, 1.0, 0.3),
"efficientnetv2-l": (v2_l_block, 1.0, 1.0, 0.4),
"efficientnetv2-xl": (v2_xl_block, 1.0, 1.0, 0.4),
......
......@@ -109,6 +109,18 @@ class RandAugmentV2(RawRandAugmentV2):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
def __call__(self, img):
if not isinstance(img, Image.Image):
img = np.ascontiguousarray(img)
img = Image.fromarray(img)
img = super().__call__(img)
if isinstance(img, Image.Image):
img = np.asarray(img)
return img
class TimmAutoAugment(RawTimmAutoAugment):
""" TimmAutoAugment wrapper to auto fit different img tyeps. """
......
......@@ -203,6 +203,8 @@ class RandAugmentV2(RandAugment):
"cutout": int(40 * abso_level)
}
# from https://stackoverflow.com/questions/5252170/
# specify-image-filling-color-when-rotating-in-python-with-pil-and-setting-expand
def rotate_with_fill(img, magnitude):
rot = img.convert("RGBA").rotate(magnitude)
return Image.composite(rot,
......
......@@ -48,11 +48,12 @@ def train_epoch_progressive(engine, epoch_id, print_batch_step):
cur_image_size = engine.config["DataLoader"]["Train"]["dataset"][
"transform_ops"][1]["RandCropImage"]["progress_size"][stage_id]
cur_magnitude = engine.config["DataLoader"]["Train"]["dataset"][
"transform_ops"][3]["RandAugment"]["progress_magnitude"][stage_id]
"transform_ops"][3]["RandAugmentV2"]["progress_magnitude"][
stage_id]
engine.config["DataLoader"]["Train"]["dataset"]["transform_ops"][1][
"RandCropImage"]["size"] = cur_image_size
engine.config["DataLoader"]["Train"]["dataset"]["transform_ops"][3][
"RandAugment"]["magnitude"] = cur_magnitude
"RandAugmentV2"]["magnitude"] = cur_magnitude
engine.train_dataloader = build_dataloader(
engine.config["DataLoader"],
"Train",
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册