提交 72835980 编写于 作者: D dongshuilong

Merge branch 'develop' of https://github.com/PaddlePaddle/PaddleClas into adaface

...@@ -12,6 +12,8 @@ ...@@ -12,6 +12,8 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# reference: https://arxiv.org/abs/1908.07919
from __future__ import absolute_import from __future__ import absolute_import
from __future__ import division from __future__ import division
from __future__ import print_function from __future__ import print_function
......
...@@ -12,6 +12,8 @@ ...@@ -12,6 +12,8 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# reference: https://arxiv.org/abs/1512.00567v3
from __future__ import absolute_import, division, print_function from __future__ import absolute_import, division, print_function
import math import math
import paddle import paddle
......
...@@ -12,6 +12,8 @@ ...@@ -12,6 +12,8 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# reference: https://arxiv.org/abs/1704.04861
from __future__ import absolute_import, division, print_function from __future__ import absolute_import, division, print_function
from paddle import ParamAttr from paddle import ParamAttr
......
...@@ -12,6 +12,8 @@ ...@@ -12,6 +12,8 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# reference: https://arxiv.org/abs/1905.02244
from __future__ import absolute_import, division, print_function from __future__ import absolute_import, division, print_function
import paddle import paddle
......
...@@ -12,6 +12,8 @@ ...@@ -12,6 +12,8 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# reference: https://arxiv.org/pdf/1512.03385
from __future__ import absolute_import, division, print_function from __future__ import absolute_import, division, print_function
import numpy as np import numpy as np
...@@ -276,6 +278,7 @@ class ResNet(TheseusLayer): ...@@ -276,6 +278,7 @@ class ResNet(TheseusLayer):
config, config,
stages_pattern, stages_pattern,
version="vb", version="vb",
stem_act="relu",
class_num=1000, class_num=1000,
lr_mult_list=[1.0, 1.0, 1.0, 1.0, 1.0], lr_mult_list=[1.0, 1.0, 1.0, 1.0, 1.0],
data_format="NCHW", data_format="NCHW",
...@@ -309,13 +312,13 @@ class ResNet(TheseusLayer): ...@@ -309,13 +312,13 @@ class ResNet(TheseusLayer):
[[input_image_channel, 32, 3, 2], [32, 32, 3, 1], [32, 64, 3, 1]] [[input_image_channel, 32, 3, 2], [32, 32, 3, 1], [32, 64, 3, 1]]
} }
self.stem = nn.Sequential(* [ self.stem = nn.Sequential(*[
ConvBNLayer( ConvBNLayer(
num_channels=in_c, num_channels=in_c,
num_filters=out_c, num_filters=out_c,
filter_size=k, filter_size=k,
stride=s, stride=s,
act="relu", act=stem_act,
lr_mult=self.lr_mult_list[0], lr_mult=self.lr_mult_list[0],
data_format=data_format) data_format=data_format)
for in_c, out_c, k, s in self.stem_cfg[version] for in_c, out_c, k, s in self.stem_cfg[version]
......
...@@ -12,6 +12,8 @@ ...@@ -12,6 +12,8 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# reference: https://arxiv.org/abs/1409.1556
from __future__ import absolute_import, division, print_function from __future__ import absolute_import, division, print_function
import paddle.nn as nn import paddle.nn as nn
......
...@@ -12,6 +12,8 @@ ...@@ -12,6 +12,8 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# reference: https://proceedings.neurips.cc/paper/2012/file/c399862d3b9d6b76c8436e924a68c45b-Paper.pdf
import paddle import paddle
from paddle import ParamAttr from paddle import ParamAttr
import paddle.nn as nn import paddle.nn as nn
......
...@@ -13,6 +13,7 @@ ...@@ -13,6 +13,7 @@
# limitations under the License. # limitations under the License.
# Code was heavily based on https://github.com/rwightman/pytorch-image-models # Code was heavily based on https://github.com/rwightman/pytorch-image-models
# reference: https://arxiv.org/abs/1911.11929
import paddle import paddle
import paddle.nn as nn import paddle.nn as nn
......
...@@ -13,6 +13,7 @@ ...@@ -13,6 +13,7 @@
# limitations under the License. # limitations under the License.
# Code was based on https://github.com/BR-IDL/PaddleViT/blob/develop/image_classification/CSwin/cswin.py # Code was based on https://github.com/BR-IDL/PaddleViT/blob/develop/image_classification/CSwin/cswin.py
# reference: https://arxiv.org/abs/2107.00652
import copy import copy
import numpy as np import numpy as np
......
...@@ -12,6 +12,8 @@ ...@@ -12,6 +12,8 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# reference: https://arxiv.org/abs/1804.02767
import paddle import paddle
from paddle import ParamAttr from paddle import ParamAttr
import paddle.nn as nn import paddle.nn as nn
......
...@@ -12,6 +12,8 @@ ...@@ -12,6 +12,8 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# reference: https://arxiv.org/abs/1608.06993
from __future__ import absolute_import from __future__ import absolute_import
from __future__ import division from __future__ import division
from __future__ import print_function from __future__ import print_function
......
...@@ -13,6 +13,7 @@ ...@@ -13,6 +13,7 @@
# limitations under the License. # limitations under the License.
# Code was heavily based on https://github.com/facebookresearch/deit # Code was heavily based on https://github.com/facebookresearch/deit
# reference: https://arxiv.org/abs/2012.12877
import paddle import paddle
import paddle.nn as nn import paddle.nn as nn
......
...@@ -13,6 +13,7 @@ ...@@ -13,6 +13,7 @@
# limitations under the License. # limitations under the License.
# Code was based on https://github.com/ucbdrive/dla # Code was based on https://github.com/ucbdrive/dla
# reference: https://arxiv.org/abs/1707.06484
import math import math
......
...@@ -12,6 +12,8 @@ ...@@ -12,6 +12,8 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# reference: https://arxiv.org/abs/1707.01629
from __future__ import absolute_import from __future__ import absolute_import
from __future__ import division from __future__ import division
from __future__ import print_function from __future__ import print_function
......
...@@ -13,6 +13,7 @@ ...@@ -13,6 +13,7 @@
# limitations under the License. # limitations under the License.
# Code was based on https://github.com/lukemelas/EfficientNet-PyTorch # Code was based on https://github.com/lukemelas/EfficientNet-PyTorch
# reference: https://arxiv.org/abs/1905.11946
import paddle import paddle
from paddle import ParamAttr from paddle import ParamAttr
......
...@@ -13,6 +13,7 @@ ...@@ -13,6 +13,7 @@
# limitations under the License. # limitations under the License.
# Code was based on https://github.com/huawei-noah/CV-Backbones/tree/master/ghostnet_pytorch # Code was based on https://github.com/huawei-noah/CV-Backbones/tree/master/ghostnet_pytorch
# reference: https://arxiv.org/abs/1911.11907
import math import math
import paddle import paddle
......
# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# reference: https://arxiv.org/abs/1409.4842
import paddle import paddle
from paddle import ParamAttr from paddle import ParamAttr
import paddle.nn as nn import paddle.nn as nn
......
...@@ -13,6 +13,7 @@ ...@@ -13,6 +13,7 @@
# limitations under the License. # limitations under the License.
# Code was based on https://github.com/Meituan-AutoML/Twins # Code was based on https://github.com/Meituan-AutoML/Twins
# reference: https://arxiv.org/abs/2104.13840
from functools import partial from functools import partial
......
...@@ -13,6 +13,7 @@ ...@@ -13,6 +13,7 @@
# limitations under the License. # limitations under the License.
# Code was based on https://github.com/PingoLH/Pytorch-HarDNet # Code was based on https://github.com/PingoLH/Pytorch-HarDNet
# reference: https://arxiv.org/abs/1909.00948
import paddle import paddle
import paddle.nn as nn import paddle.nn as nn
......
...@@ -12,6 +12,8 @@ ...@@ -12,6 +12,8 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# reference: https://arxiv.org/abs/1602.07261
import paddle import paddle
from paddle import ParamAttr from paddle import ParamAttr
import paddle.nn as nn import paddle.nn as nn
......
...@@ -13,6 +13,7 @@ ...@@ -13,6 +13,7 @@
# limitations under the License. # limitations under the License.
# Code was based on https://github.com/facebookresearch/LeViT # Code was based on https://github.com/facebookresearch/LeViT
# reference: https://openaccess.thecvf.com/content/ICCV2021/html/Graham_LeViT_A_Vision_Transformer_in_ConvNets_Clothing_for_Faster_Inference_ICCV_2021_paper.html
import itertools import itertools
import math import math
......
...@@ -11,11 +11,8 @@ ...@@ -11,11 +11,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
"""
MixNet for ImageNet-1K, implemented in Paddle. # reference: https://arxiv.org/abs/1907.09595
Original paper: 'MixConv: Mixed Depthwise Convolutional Kernels,'
https://arxiv.org/abs/1907.09595.
"""
import os import os
from inspect import isfunction from inspect import isfunction
......
...@@ -12,6 +12,8 @@ ...@@ -12,6 +12,8 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# reference: https://arxiv.org/abs/1801.04381
from __future__ import absolute_import from __future__ import absolute_import
from __future__ import division from __future__ import division
from __future__ import print_function from __future__ import print_function
......
...@@ -14,6 +14,7 @@ ...@@ -14,6 +14,7 @@
# Code was based on https://github.com/BR-IDL/PaddleViT/blob/develop/image_classification/MobileViT/mobilevit.py # Code was based on https://github.com/BR-IDL/PaddleViT/blob/develop/image_classification/MobileViT/mobilevit.py
# and https://github.com/apple/ml-cvnets/blob/main/cvnets/models/classification/mobilevit.py # and https://github.com/apple/ml-cvnets/blob/main/cvnets/models/classification/mobilevit.py
# reference: https://arxiv.org/abs/2110.02178
import paddle import paddle
from paddle import ParamAttr from paddle import ParamAttr
......
...@@ -13,6 +13,7 @@ ...@@ -13,6 +13,7 @@
# limitations under the License. # limitations under the License.
# Code was heavily based on https://github.com/whai362/PVT # Code was heavily based on https://github.com/whai362/PVT
# reference: https://arxiv.org/abs/2106.13797
from functools import partial from functools import partial
import math import math
......
...@@ -13,6 +13,7 @@ ...@@ -13,6 +13,7 @@
# limitations under the License. # limitations under the License.
# Code was based on https://github.com/d-li14/involution # Code was based on https://github.com/d-li14/involution
# reference: https://arxiv.org/abs/2103.06255
import paddle import paddle
import paddle.nn as nn import paddle.nn as nn
......
...@@ -13,6 +13,7 @@ ...@@ -13,6 +13,7 @@
# limitations under the License. # limitations under the License.
# Code was based on https://github.com/facebookresearch/pycls # Code was based on https://github.com/facebookresearch/pycls
# reference: https://arxiv.org/abs/1905.13214
from __future__ import absolute_import from __future__ import absolute_import
from __future__ import division from __future__ import division
......
...@@ -13,6 +13,7 @@ ...@@ -13,6 +13,7 @@
# limitations under the License. # limitations under the License.
# Code was based on https://github.com/DingXiaoH/RepVGG # Code was based on https://github.com/DingXiaoH/RepVGG
# reference: https://arxiv.org/abs/2101.03697
import paddle.nn as nn import paddle.nn as nn
import paddle import paddle
......
...@@ -12,6 +12,8 @@ ...@@ -12,6 +12,8 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# reference: https://arxiv.org/abs/1904.01169
from __future__ import absolute_import from __future__ import absolute_import
from __future__ import division from __future__ import division
from __future__ import print_function from __future__ import print_function
......
...@@ -12,6 +12,8 @@ ...@@ -12,6 +12,8 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# reference: https://arxiv.org/abs/1904.01169 & https://arxiv.org/abs/1812.01187
from __future__ import absolute_import from __future__ import absolute_import
from __future__ import division from __future__ import division
from __future__ import print_function from __future__ import print_function
......
...@@ -13,6 +13,7 @@ ...@@ -13,6 +13,7 @@
# limitations under the License. # limitations under the License.
# Code was based on https://github.com/zhanghang1989/ResNeSt # Code was based on https://github.com/zhanghang1989/ResNeSt
# reference: https://arxiv.org/abs/2004.08955
from __future__ import absolute_import from __future__ import absolute_import
from __future__ import division from __future__ import division
......
...@@ -12,6 +12,8 @@ ...@@ -12,6 +12,8 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# reference: https://arxiv.org/abs/1812.01187
from __future__ import absolute_import from __future__ import absolute_import
from __future__ import division from __future__ import division
from __future__ import print_function from __future__ import print_function
......
...@@ -12,6 +12,8 @@ ...@@ -12,6 +12,8 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# reference: https://arxiv.org/abs/1611.05431
from __future__ import absolute_import from __future__ import absolute_import
from __future__ import division from __future__ import division
from __future__ import print_function from __future__ import print_function
......
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# reference: https://arxiv.org/abs/1805.00932
import paddle import paddle
from paddle import ParamAttr from paddle import ParamAttr
import paddle.nn as nn import paddle.nn as nn
......
...@@ -12,6 +12,8 @@ ...@@ -12,6 +12,8 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# reference: https://arxiv.org/abs/1611.05431 & https://arxiv.org/abs/1812.01187
from __future__ import absolute_import from __future__ import absolute_import
from __future__ import division from __future__ import division
from __future__ import print_function from __future__ import print_function
......
...@@ -12,6 +12,8 @@ ...@@ -12,6 +12,8 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# reference: https://arxiv.org/abs/2007.00992
from __future__ import absolute_import from __future__ import absolute_import
from __future__ import division from __future__ import division
from __future__ import print_function from __future__ import print_function
......
...@@ -11,6 +11,8 @@ ...@@ -11,6 +11,8 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# reference: https://arxiv.org/abs/1812.01187 & https://arxiv.org/abs/1709.01507
from __future__ import absolute_import from __future__ import absolute_import
from __future__ import division from __future__ import division
from __future__ import print_function from __future__ import print_function
......
...@@ -12,6 +12,8 @@ ...@@ -12,6 +12,8 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# reference: https://arxiv.org/abs/1611.05431 & https://arxiv.org/abs/1709.01507
from __future__ import absolute_import from __future__ import absolute_import
from __future__ import division from __future__ import division
from __future__ import print_function from __future__ import print_function
......
...@@ -12,6 +12,8 @@ ...@@ -12,6 +12,8 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# reference: https://arxiv.org/abs/1611.05431 & https://arxiv.org/abs/1812.01187 & https://arxiv.org/abs/1709.01507
from __future__ import absolute_import from __future__ import absolute_import
from __future__ import division from __future__ import division
from __future__ import print_function from __future__ import print_function
......
...@@ -12,6 +12,8 @@ ...@@ -12,6 +12,8 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# reference: https://arxiv.org/abs/1807.11164
from __future__ import absolute_import from __future__ import absolute_import
from __future__ import division from __future__ import division
from __future__ import print_function from __future__ import print_function
......
...@@ -12,6 +12,8 @@ ...@@ -12,6 +12,8 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# reference: https://arxiv.org/abs/1709.01507
import paddle import paddle
from paddle import ParamAttr from paddle import ParamAttr
import paddle.nn as nn import paddle.nn as nn
......
...@@ -13,6 +13,7 @@ ...@@ -13,6 +13,7 @@
# limitations under the License. # limitations under the License.
# Code was based on https://github.com/microsoft/Swin-Transformer # Code was based on https://github.com/microsoft/Swin-Transformer
# reference: https://arxiv.org/abs/2103.14030
import numpy as np import numpy as np
import paddle import paddle
......
...@@ -13,6 +13,7 @@ ...@@ -13,6 +13,7 @@
# limitations under the License. # limitations under the License.
# Code was based on https://github.com/huawei-noah/CV-Backbones/tree/master/tnt_pytorch # Code was based on https://github.com/huawei-noah/CV-Backbones/tree/master/tnt_pytorch
# reference: https://arxiv.org/abs/2103.00112
import math import math
import numpy as np import numpy as np
......
...@@ -13,6 +13,7 @@ ...@@ -13,6 +13,7 @@
# limitations under the License. # limitations under the License.
# Code was heavily based on https://github.com/Visual-Attention-Network/VAN-Classification # Code was heavily based on https://github.com/Visual-Attention-Network/VAN-Classification
# reference: https://arxiv.org/abs/2202.09741
from functools import partial from functools import partial
import math import math
......
...@@ -13,6 +13,7 @@ ...@@ -13,6 +13,7 @@
# limitations under the License. # limitations under the License.
# Code was based on https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py # Code was based on https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
# reference: https://arxiv.org/abs/2010.11929
from collections.abc import Callable from collections.abc import Callable
......
# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# reference: https://arxiv.org/abs/1610.02357
import paddle import paddle
from paddle import ParamAttr from paddle import ParamAttr
import paddle.nn as nn import paddle.nn as nn
......
...@@ -12,6 +12,8 @@ ...@@ -12,6 +12,8 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# reference: https://arxiv.org/abs/1706.05587
import paddle import paddle
from paddle import ParamAttr from paddle import ParamAttr
import paddle.nn as nn import paddle.nn as nn
......
...@@ -12,6 +12,8 @@ ...@@ -12,6 +12,8 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# reference: https://arxiv.org/abs/1801.07698
import paddle import paddle
import paddle.nn as nn import paddle.nn as nn
import math import math
......
...@@ -17,21 +17,32 @@ from __future__ import absolute_import, division, print_function ...@@ -17,21 +17,32 @@ from __future__ import absolute_import, division, print_function
import paddle import paddle
import paddle.nn as nn import paddle.nn as nn
from ppcls.arch.utils import get_param_attr_dict
class BNNeck(nn.Layer): class BNNeck(nn.Layer):
def __init__(self, num_features): def __init__(self, num_features, **kwargs):
super().__init__() super().__init__()
weight_attr = paddle.ParamAttr( weight_attr = paddle.ParamAttr(
initializer=paddle.nn.initializer.Constant(value=1.0)) initializer=paddle.nn.initializer.Constant(value=1.0))
bias_attr = paddle.ParamAttr( bias_attr = paddle.ParamAttr(
initializer=paddle.nn.initializer.Constant(value=0.0), initializer=paddle.nn.initializer.Constant(value=0.0),
trainable=False) trainable=False)
if 'weight_attr' in kwargs:
weight_attr = get_param_attr_dict(kwargs['weight_attr'])
bias_attr = None
if 'bias_attr' in kwargs:
bias_attr = get_param_attr_dict(kwargs['bias_attr'])
self.feat_bn = nn.BatchNorm1D( self.feat_bn = nn.BatchNorm1D(
num_features, num_features,
momentum=0.9, momentum=0.9,
epsilon=1e-05, epsilon=1e-05,
weight_attr=weight_attr, weight_attr=weight_attr,
bias_attr=bias_attr) bias_attr=bias_attr)
self.flatten = nn.Flatten() self.flatten = nn.Flatten()
def forward(self, x): def forward(self, x):
......
...@@ -12,6 +12,8 @@ ...@@ -12,6 +12,8 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# reference: https://arxiv.org/abs/2002.10857
import math import math
import paddle import paddle
import paddle.nn as nn import paddle.nn as nn
......
...@@ -12,6 +12,8 @@ ...@@ -12,6 +12,8 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# reference: https://arxiv.org/abs/1801.09414
import paddle import paddle
import math import math
import paddle.nn as nn import paddle.nn as nn
......
...@@ -19,16 +19,29 @@ from __future__ import print_function ...@@ -19,16 +19,29 @@ from __future__ import print_function
import paddle import paddle
import paddle.nn as nn import paddle.nn as nn
from ppcls.arch.utils import get_param_attr_dict
class FC(nn.Layer): class FC(nn.Layer):
def __init__(self, embedding_size, class_num): def __init__(self, embedding_size, class_num, **kwargs):
super(FC, self).__init__() super(FC, self).__init__()
self.embedding_size = embedding_size self.embedding_size = embedding_size
self.class_num = class_num self.class_num = class_num
weight_attr = paddle.ParamAttr( weight_attr = paddle.ParamAttr(
initializer=paddle.nn.initializer.XavierNormal()) initializer=paddle.nn.initializer.XavierNormal())
self.fc = paddle.nn.Linear( if 'weight_attr' in kwargs:
self.embedding_size, self.class_num, weight_attr=weight_attr) weight_attr = get_param_attr_dict(kwargs['weight_attr'])
bias_attr = None
if 'bias_attr' in kwargs:
bias_attr = get_param_attr_dict(kwargs['bias_attr'])
self.fc = nn.Linear(
self.embedding_size,
self.class_num,
weight_attr=weight_attr,
bias_attr=bias_attr)
def forward(self, input, label=None): def forward(self, input, label=None):
out = self.fc(input) out = self.fc(input)
......
...@@ -14,9 +14,11 @@ ...@@ -14,9 +14,11 @@
import six import six
import types import types
import paddle
from difflib import SequenceMatcher from difflib import SequenceMatcher
from . import backbone from . import backbone
from typing import Any, Dict, Union
def get_architectures(): def get_architectures():
...@@ -51,3 +53,47 @@ def similar_architectures(name='', names=[], thresh=0.1, topk=10): ...@@ -51,3 +53,47 @@ def similar_architectures(name='', names=[], thresh=0.1, topk=10):
scores.sort(key=lambda x: x[1], reverse=True) scores.sort(key=lambda x: x[1], reverse=True)
similar_names = [names[s[0]] for s in scores[:min(topk, len(scores))]] similar_names = [names[s[0]] for s in scores[:min(topk, len(scores))]]
return similar_names return similar_names
def get_param_attr_dict(ParamAttr_config: Union[None, bool, Dict[str, Dict]]
) -> Union[None, bool, paddle.ParamAttr]:
"""parse ParamAttr from an dict
Args:
ParamAttr_config (Union[None, bool, Dict[str, Dict]]): ParamAttr configure
Returns:
Union[None, bool, paddle.ParamAttr]: Generated ParamAttr
"""
if ParamAttr_config is None:
return None
if isinstance(ParamAttr_config, bool):
return ParamAttr_config
ParamAttr_dict = {}
if 'initializer' in ParamAttr_config:
initializer_cfg = ParamAttr_config.get('initializer')
if 'name' in initializer_cfg:
initializer_name = initializer_cfg.pop('name')
ParamAttr_dict['initializer'] = getattr(
paddle.nn.initializer, initializer_name)(**initializer_cfg)
else:
raise ValueError(f"'name' must specified in initializer_cfg")
if 'learning_rate' in ParamAttr_config:
# NOTE: only support an single value now
learning_rate_value = ParamAttr_config.get('learning_rate')
if isinstance(learning_rate_value, (int, float)):
ParamAttr_dict['learning_rate'] = learning_rate_value
else:
raise ValueError(
f"learning_rate_value must be float or int, but got {type(learning_rate_value)}"
)
if 'regularizer' in ParamAttr_config:
regularizer_cfg = ParamAttr_config.get('regularizer')
if 'name' in regularizer_cfg:
# L1Decay or L2Decay
regularizer_name = regularizer_cfg.pop('name')
ParamAttr_dict['regularizer'] = getattr(
paddle.regularizer, regularizer_name)(**regularizer_cfg)
else:
raise ValueError(f"'name' must specified in regularizer_cfg")
return paddle.ParamAttr(**ParamAttr_dict)
...@@ -42,11 +42,12 @@ Optimizer: ...@@ -42,11 +42,12 @@ Optimizer:
no_weight_decay_name: pos_embed cls_token .bias norm no_weight_decay_name: pos_embed cls_token .bias norm
one_dim_param_no_weight_decay: True one_dim_param_no_weight_decay: True
lr: lr:
# for 8 cards
name: Cosine name: Cosine
learning_rate: 1.25e-4 learning_rate: 2.5e-4
eta_min: 1.25e-6 eta_min: 2.5e-6
warmup_epoch: 20 warmup_epoch: 20
warmup_start_lr: 1.25e-7 warmup_start_lr: 2.5e-7
# data loader for train and eval # data loader for train and eval
......
...@@ -42,11 +42,12 @@ Optimizer: ...@@ -42,11 +42,12 @@ Optimizer:
no_weight_decay_name: pos_embed cls_token .bias norm no_weight_decay_name: pos_embed cls_token .bias norm
one_dim_param_no_weight_decay: True one_dim_param_no_weight_decay: True
lr: lr:
# for 8 cards
name: Cosine name: Cosine
learning_rate: 6.25e-5 learning_rate: 1.25e-4
eta_min: 6.25e-7 eta_min: 1.25e-6
warmup_epoch: 20 warmup_epoch: 20
warmup_start_lr: 6.25e-8 warmup_start_lr: 1.25e-7
# data loader for train and eval # data loader for train and eval
......
...@@ -42,11 +42,12 @@ Optimizer: ...@@ -42,11 +42,12 @@ Optimizer:
no_weight_decay_name: pos_embed cls_token .bias norm no_weight_decay_name: pos_embed cls_token .bias norm
one_dim_param_no_weight_decay: True one_dim_param_no_weight_decay: True
lr: lr:
# for 8 cards
name: Cosine name: Cosine
learning_rate: 1.25e-4 learning_rate: 2.5e-4
eta_min: 1.25e-6 eta_min: 2.5e-6
warmup_epoch: 20 warmup_epoch: 20
warmup_start_lr: 1.25e-7 warmup_start_lr: 2.5e-7
# data loader for train and eval # data loader for train and eval
......
...@@ -42,11 +42,12 @@ Optimizer: ...@@ -42,11 +42,12 @@ Optimizer:
no_weight_decay_name: pos_embed cls_token .bias norm no_weight_decay_name: pos_embed cls_token .bias norm
one_dim_param_no_weight_decay: True one_dim_param_no_weight_decay: True
lr: lr:
# for 8 cards
name: Cosine name: Cosine
learning_rate: 3.125e-5 learning_rate: 6.25e-5
eta_min: 3.125e-7 eta_min: 6.25e-7
warmup_epoch: 20 warmup_epoch: 20
warmup_start_lr: 3.125e-8 warmup_start_lr: 6.25e-8
# data loader for train and eval # data loader for train and eval
......
...@@ -42,11 +42,12 @@ Optimizer: ...@@ -42,11 +42,12 @@ Optimizer:
no_weight_decay_name: pos_embed cls_token .bias norm no_weight_decay_name: pos_embed cls_token .bias norm
one_dim_param_no_weight_decay: True one_dim_param_no_weight_decay: True
lr: lr:
# for 8 cards
name: Cosine name: Cosine
learning_rate: 2.5e-4 learning_rate: 5e-4
eta_min: 2.5e-6 eta_min: 5e-6
warmup_epoch: 20 warmup_epoch: 20
warmup_start_lr: 2.5e-7 warmup_start_lr: 5e-7
# data loader for train and eval # data loader for train and eval
......
...@@ -42,11 +42,12 @@ Optimizer: ...@@ -42,11 +42,12 @@ Optimizer:
no_weight_decay_name: pos_embed cls_token .bias norm no_weight_decay_name: pos_embed cls_token .bias norm
one_dim_param_no_weight_decay: True one_dim_param_no_weight_decay: True
lr: lr:
# for 8 cards
name: Cosine name: Cosine
learning_rate: 5e-4 learning_rate: 1e-3
eta_min: 5e-6 eta_min: 1e-5
warmup_epoch: 20 warmup_epoch: 20
warmup_start_lr: 5e-7 warmup_start_lr: 1e-6
# data loader for train and eval # data loader for train and eval
......
...@@ -40,11 +40,12 @@ Optimizer: ...@@ -40,11 +40,12 @@ Optimizer:
no_weight_decay_name: norm cls_token pos_embed dist_token no_weight_decay_name: norm cls_token pos_embed dist_token
one_dim_param_no_weight_decay: True one_dim_param_no_weight_decay: True
lr: lr:
# for 8 cards
name: Cosine name: Cosine
learning_rate: 1e-3 learning_rate: 2e-3
eta_min: 1e-5 eta_min: 2e-5
warmup_epoch: 5 warmup_epoch: 5
warmup_start_lr: 1e-6 warmup_start_lr: 2e-6
# data loader for train and eval # data loader for train and eval
DataLoader: DataLoader:
......
...@@ -40,11 +40,12 @@ Optimizer: ...@@ -40,11 +40,12 @@ Optimizer:
no_weight_decay_name: norm cls_token pos_embed dist_token no_weight_decay_name: norm cls_token pos_embed dist_token
one_dim_param_no_weight_decay: True one_dim_param_no_weight_decay: True
lr: lr:
# for 8 cards
name: Cosine name: Cosine
learning_rate: 1e-3 learning_rate: 2e-3
eta_min: 1e-5 eta_min: 2e-5
warmup_epoch: 5 warmup_epoch: 5
warmup_start_lr: 1e-6 warmup_start_lr: 2e-6
# data loader for train and eval # data loader for train and eval
DataLoader: DataLoader:
......
...@@ -40,11 +40,12 @@ Optimizer: ...@@ -40,11 +40,12 @@ Optimizer:
no_weight_decay_name: norm cls_token pos_embed dist_token no_weight_decay_name: norm cls_token pos_embed dist_token
one_dim_param_no_weight_decay: True one_dim_param_no_weight_decay: True
lr: lr:
# for 8 cards
name: Cosine name: Cosine
learning_rate: 1e-3 learning_rate: 2e-3
eta_min: 1e-5 eta_min: 2e-5
warmup_epoch: 5 warmup_epoch: 5
warmup_start_lr: 1e-6 warmup_start_lr: 2e-6
# data loader for train and eval # data loader for train and eval
DataLoader: DataLoader:
......
...@@ -40,11 +40,12 @@ Optimizer: ...@@ -40,11 +40,12 @@ Optimizer:
no_weight_decay_name: norm cls_token pos_embed dist_token no_weight_decay_name: norm cls_token pos_embed dist_token
one_dim_param_no_weight_decay: True one_dim_param_no_weight_decay: True
lr: lr:
# for 8 cards
name: Cosine name: Cosine
learning_rate: 1e-3 learning_rate: 2e-3
eta_min: 1e-5 eta_min: 2e-5
warmup_epoch: 5 warmup_epoch: 5
warmup_start_lr: 1e-6 warmup_start_lr: 2e-6
# data loader for train and eval # data loader for train and eval
DataLoader: DataLoader:
......
...@@ -41,10 +41,10 @@ Optimizer: ...@@ -41,10 +41,10 @@ Optimizer:
one_dim_param_no_weight_decay: True one_dim_param_no_weight_decay: True
lr: lr:
name: Cosine name: Cosine
learning_rate: 1e-3 learning_rate: 2e-3
eta_min: 1e-5 eta_min: 2e-5
warmup_epoch: 5 warmup_epoch: 5
warmup_start_lr: 1e-6 warmup_start_lr: 2e-6
# data loader for train and eval # data loader for train and eval
DataLoader: DataLoader:
......
...@@ -40,11 +40,12 @@ Optimizer: ...@@ -40,11 +40,12 @@ Optimizer:
no_weight_decay_name: norm cls_token pos_embed dist_token no_weight_decay_name: norm cls_token pos_embed dist_token
one_dim_param_no_weight_decay: True one_dim_param_no_weight_decay: True
lr: lr:
# for 8 cards
name: Cosine name: Cosine
learning_rate: 1e-3 learning_rate: 2e-3
eta_min: 1e-5 eta_min: 2e-5
warmup_epoch: 5 warmup_epoch: 5
warmup_start_lr: 1e-6 warmup_start_lr: 2e-6
# data loader for train and eval # data loader for train and eval
DataLoader: DataLoader:
......
...@@ -40,11 +40,12 @@ Optimizer: ...@@ -40,11 +40,12 @@ Optimizer:
no_weight_decay_name: norm cls_token pos_embed dist_token no_weight_decay_name: norm cls_token pos_embed dist_token
one_dim_param_no_weight_decay: True one_dim_param_no_weight_decay: True
lr: lr:
# for 8 cards
name: Cosine name: Cosine
learning_rate: 1e-3 learning_rate: 2e-3
eta_min: 1e-5 eta_min: 2e-5
warmup_epoch: 5 warmup_epoch: 5
warmup_start_lr: 1e-6 warmup_start_lr: 2e-6
# data loader for train and eval # data loader for train and eval
DataLoader: DataLoader:
......
...@@ -40,11 +40,12 @@ Optimizer: ...@@ -40,11 +40,12 @@ Optimizer:
no_weight_decay_name: norm cls_token pos_embed dist_token no_weight_decay_name: norm cls_token pos_embed dist_token
one_dim_param_no_weight_decay: True one_dim_param_no_weight_decay: True
lr: lr:
# for 8 cards
name: Cosine name: Cosine
learning_rate: 1e-3 learning_rate: 2e-3
eta_min: 1e-5 eta_min: 2e-5
warmup_epoch: 5 warmup_epoch: 5
warmup_start_lr: 1e-6 warmup_start_lr: 2e-6
# data loader for train and eval # data loader for train and eval
DataLoader: DataLoader:
......
...@@ -49,9 +49,8 @@ Loss: ...@@ -49,9 +49,8 @@ Loss:
model_name_pairs: model_name_pairs:
- ["Student", "Teacher"] - ["Student", "Teacher"]
Eval: Eval:
- DistillationGTCELoss: - CELoss:
weight: 1.0 weight: 1.0
model_names: ["Student"]
Optimizer: Optimizer:
......
...@@ -88,10 +88,8 @@ Loss: ...@@ -88,10 +88,8 @@ Loss:
s_shapes: *s_shapes s_shapes: *s_shapes
t_shapes: *t_shapes t_shapes: *t_shapes
Eval: Eval:
- DistillationGTCELoss: - CELoss:
weight: 1.0 weight: 1.0
model_names: ["Student"]
Optimizer: Optimizer:
name: Momentum name: Momentum
......
...@@ -43,11 +43,12 @@ Optimizer: ...@@ -43,11 +43,12 @@ Optimizer:
no_weight_decay_name: pos_embed1 pos_embed2 pos_embed3 pos_embed4 cls_token no_weight_decay_name: pos_embed1 pos_embed2 pos_embed3 pos_embed4 cls_token
one_dim_param_no_weight_decay: True one_dim_param_no_weight_decay: True
lr: lr:
# for 8 cards
name: Cosine name: Cosine
learning_rate: 5e-4 learning_rate: 1e-3
eta_min: 5e-6 eta_min: 1e-5
warmup_epoch: 20 warmup_epoch: 20
warmup_start_lr: 5e-7 warmup_start_lr: 1e-6
# data loader for train and eval # data loader for train and eval
......
...@@ -43,11 +43,12 @@ Optimizer: ...@@ -43,11 +43,12 @@ Optimizer:
no_weight_decay_name: pos_embed1 pos_embed2 pos_embed3 pos_embed4 cls_token no_weight_decay_name: pos_embed1 pos_embed2 pos_embed3 pos_embed4 cls_token
one_dim_param_no_weight_decay: True one_dim_param_no_weight_decay: True
lr: lr:
# for 8 cards
name: Cosine name: Cosine
learning_rate: 5e-4 learning_rate: 1e-3
eta_min: 5e-6 eta_min: 1e-5
warmup_epoch: 20 warmup_epoch: 20
warmup_start_lr: 5e-7 warmup_start_lr: 1e-6
# data loader for train and eval # data loader for train and eval
......
...@@ -43,11 +43,12 @@ Optimizer: ...@@ -43,11 +43,12 @@ Optimizer:
no_weight_decay_name: pos_embed1 pos_embed2 pos_embed3 pos_embed4 cls_token no_weight_decay_name: pos_embed1 pos_embed2 pos_embed3 pos_embed4 cls_token
one_dim_param_no_weight_decay: True one_dim_param_no_weight_decay: True
lr: lr:
# for 8 cards
name: Cosine name: Cosine
learning_rate: 5e-4 learning_rate: 1e-3
eta_min: 5e-6 eta_min: 1e-5
warmup_epoch: 20 warmup_epoch: 20
warmup_start_lr: 5e-7 warmup_start_lr: 1e-6
# data loader for train and eval # data loader for train and eval
......
...@@ -43,11 +43,12 @@ Optimizer: ...@@ -43,11 +43,12 @@ Optimizer:
no_weight_decay_name: pos_embed1 pos_embed2 pos_embed3 pos_embed4 cls_token no_weight_decay_name: pos_embed1 pos_embed2 pos_embed3 pos_embed4 cls_token
one_dim_param_no_weight_decay: True one_dim_param_no_weight_decay: True
lr: lr:
# for 8 cards
name: Cosine name: Cosine
learning_rate: 5e-4 learning_rate: 1e-3
eta_min: 5e-6 eta_min: 1e-5
warmup_epoch: 20 warmup_epoch: 20
warmup_start_lr: 5e-7 warmup_start_lr: 1e-6
# data loader for train and eval # data loader for train and eval
......
...@@ -44,11 +44,12 @@ Optimizer: ...@@ -44,11 +44,12 @@ Optimizer:
no_weight_decay_name: pos_embed1 pos_embed2 pos_embed3 pos_embed4 cls_token no_weight_decay_name: pos_embed1 pos_embed2 pos_embed3 pos_embed4 cls_token
one_dim_param_no_weight_decay: True one_dim_param_no_weight_decay: True
lr: lr:
# for 8 cards
name: Cosine name: Cosine
learning_rate: 5e-4 learning_rate: 1e-3
eta_min: 5e-6 eta_min: 1e-5
warmup_epoch: 20 warmup_epoch: 20
warmup_start_lr: 5e-7 warmup_start_lr: 1e-6
# data loader for train and eval # data loader for train and eval
......
...@@ -44,11 +44,12 @@ Optimizer: ...@@ -44,11 +44,12 @@ Optimizer:
no_weight_decay_name: pos_embed1 pos_embed2 pos_embed3 pos_embed4 cls_token no_weight_decay_name: pos_embed1 pos_embed2 pos_embed3 pos_embed4 cls_token
one_dim_param_no_weight_decay: True one_dim_param_no_weight_decay: True
lr: lr:
# for 8 cards
name: Cosine name: Cosine
learning_rate: 5e-4 learning_rate: 1e-3
eta_min: 5e-6 eta_min: 1e-5
warmup_epoch: 20 warmup_epoch: 20
warmup_start_lr: 5e-7 warmup_start_lr: 1e-6
# data loader for train and eval # data loader for train and eval
......
...@@ -44,11 +44,12 @@ Optimizer: ...@@ -44,11 +44,12 @@ Optimizer:
no_weight_decay_name: pos_embed1 pos_embed2 pos_embed3 pos_embed4 cls_token no_weight_decay_name: pos_embed1 pos_embed2 pos_embed3 pos_embed4 cls_token
one_dim_param_no_weight_decay: True one_dim_param_no_weight_decay: True
lr: lr:
# for 8 cards
name: Cosine name: Cosine
learning_rate: 5e-4 learning_rate: 1e-3
eta_min: 5e-6 eta_min: 1e-5
warmup_epoch: 20 warmup_epoch: 20
warmup_start_lr: 5e-7 warmup_start_lr: 1e-6
# data loader for train and eval # data loader for train and eval
......
...@@ -41,11 +41,12 @@ Optimizer: ...@@ -41,11 +41,12 @@ Optimizer:
no_weight_decay_name: absolute_pos_embed relative_position_bias_table .bias norm no_weight_decay_name: absolute_pos_embed relative_position_bias_table .bias norm
one_dim_param_no_weight_decay: True one_dim_param_no_weight_decay: True
lr: lr:
# for 8 cards
name: Cosine name: Cosine
learning_rate: 5e-4 learning_rate: 1e-3
eta_min: 1e-5 eta_min: 2e-5
warmup_epoch: 20 warmup_epoch: 20
warmup_start_lr: 1e-6 warmup_start_lr: 2e-6
# data loader for train and eval # data loader for train and eval
......
...@@ -41,11 +41,12 @@ Optimizer: ...@@ -41,11 +41,12 @@ Optimizer:
no_weight_decay_name: absolute_pos_embed relative_position_bias_table .bias norm no_weight_decay_name: absolute_pos_embed relative_position_bias_table .bias norm
one_dim_param_no_weight_decay: True one_dim_param_no_weight_decay: True
lr: lr:
# for 8 cards
name: Cosine name: Cosine
learning_rate: 5e-4 learning_rate: 1e-3
eta_min: 1e-5 eta_min: 2e-5
warmup_epoch: 20 warmup_epoch: 20
warmup_start_lr: 1e-6 warmup_start_lr: 2e-6
# data loader for train and eval # data loader for train and eval
......
...@@ -41,11 +41,12 @@ Optimizer: ...@@ -41,11 +41,12 @@ Optimizer:
no_weight_decay_name: absolute_pos_embed relative_position_bias_table .bias norm no_weight_decay_name: absolute_pos_embed relative_position_bias_table .bias norm
one_dim_param_no_weight_decay: True one_dim_param_no_weight_decay: True
lr: lr:
# for 8 cards
name: Cosine name: Cosine
learning_rate: 5e-4 learning_rate: 1e-3
eta_min: 1e-5 eta_min: 2e-5
warmup_epoch: 20 warmup_epoch: 20
warmup_start_lr: 1e-6 warmup_start_lr: 2e-6
# data loader for train and eval # data loader for train and eval
......
...@@ -41,11 +41,12 @@ Optimizer: ...@@ -41,11 +41,12 @@ Optimizer:
no_weight_decay_name: absolute_pos_embed relative_position_bias_table .bias norm no_weight_decay_name: absolute_pos_embed relative_position_bias_table .bias norm
one_dim_param_no_weight_decay: True one_dim_param_no_weight_decay: True
lr: lr:
# for 8 cards
name: Cosine name: Cosine
learning_rate: 5e-4 learning_rate: 1e-3
eta_min: 1e-5 eta_min: 2e-5
warmup_epoch: 20 warmup_epoch: 20
warmup_start_lr: 1e-6 warmup_start_lr: 2e-6
# data loader for train and eval # data loader for train and eval
......
...@@ -41,11 +41,12 @@ Optimizer: ...@@ -41,11 +41,12 @@ Optimizer:
no_weight_decay_name: absolute_pos_embed relative_position_bias_table .bias norm no_weight_decay_name: absolute_pos_embed relative_position_bias_table .bias norm
one_dim_param_no_weight_decay: True one_dim_param_no_weight_decay: True
lr: lr:
# for 8 cards
name: Cosine name: Cosine
learning_rate: 5e-4 learning_rate: 1e-3
eta_min: 1e-5 eta_min: 2e-5
warmup_epoch: 20 warmup_epoch: 20
warmup_start_lr: 1e-6 warmup_start_lr: 2e-6
# data loader for train and eval # data loader for train and eval
......
...@@ -41,11 +41,12 @@ Optimizer: ...@@ -41,11 +41,12 @@ Optimizer:
no_weight_decay_name: absolute_pos_embed relative_position_bias_table .bias norm no_weight_decay_name: absolute_pos_embed relative_position_bias_table .bias norm
one_dim_param_no_weight_decay: True one_dim_param_no_weight_decay: True
lr: lr:
# for 8 cards
name: Cosine name: Cosine
learning_rate: 5e-4 learning_rate: 1e-3
eta_min: 1e-5 eta_min: 2e-5
warmup_epoch: 20 warmup_epoch: 20
warmup_start_lr: 1e-6 warmup_start_lr: 2e-6
# data loader for train and eval # data loader for train and eval
......
...@@ -43,11 +43,12 @@ Optimizer: ...@@ -43,11 +43,12 @@ Optimizer:
no_weight_decay_name: norm cls_token proj.0.weight proj.1.weight proj.2.weight proj.3.weight pos_block no_weight_decay_name: norm cls_token proj.0.weight proj.1.weight proj.2.weight proj.3.weight pos_block
one_dim_param_no_weight_decay: True one_dim_param_no_weight_decay: True
lr: lr:
# for 8 cards
name: Cosine name: Cosine
learning_rate: 5e-4 learning_rate: 1e-3
eta_min: 1e-5 eta_min: 2e-5
warmup_epoch: 5 warmup_epoch: 5
warmup_start_lr: 1e-6 warmup_start_lr: 2e-6
# data loader for train and eval # data loader for train and eval
......
...@@ -43,11 +43,12 @@ Optimizer: ...@@ -43,11 +43,12 @@ Optimizer:
no_weight_decay_name: norm cls_token proj.0.weight proj.1.weight proj.2.weight proj.3.weight pos_block no_weight_decay_name: norm cls_token proj.0.weight proj.1.weight proj.2.weight proj.3.weight pos_block
one_dim_param_no_weight_decay: True one_dim_param_no_weight_decay: True
lr: lr:
# for 8 cards
name: Cosine name: Cosine
learning_rate: 5e-4 learning_rate: 1e-3
eta_min: 1e-5 eta_min: 2e-5
warmup_epoch: 5 warmup_epoch: 5
warmup_start_lr: 1e-6 warmup_start_lr: 2e-6
# data loader for train and eval # data loader for train and eval
......
...@@ -43,11 +43,12 @@ Optimizer: ...@@ -43,11 +43,12 @@ Optimizer:
no_weight_decay_name: norm cls_token proj.0.weight proj.1.weight proj.2.weight proj.3.weight pos_block no_weight_decay_name: norm cls_token proj.0.weight proj.1.weight proj.2.weight proj.3.weight pos_block
one_dim_param_no_weight_decay: True one_dim_param_no_weight_decay: True
lr: lr:
# for 8 cards
name: Cosine name: Cosine
learning_rate: 5e-4 learning_rate: 1e-3
eta_min: 1e-5 eta_min: 2e-5
warmup_epoch: 5 warmup_epoch: 5
warmup_start_lr: 1e-6 warmup_start_lr: 2e-6
# data loader for train and eval # data loader for train and eval
......
...@@ -43,11 +43,12 @@ Optimizer: ...@@ -43,11 +43,12 @@ Optimizer:
no_weight_decay_name: norm cls_token proj.0.weight proj.1.weight proj.2.weight proj.3.weight pos_block no_weight_decay_name: norm cls_token proj.0.weight proj.1.weight proj.2.weight proj.3.weight pos_block
one_dim_param_no_weight_decay: True one_dim_param_no_weight_decay: True
lr: lr:
# for 8 cards
name: Cosine name: Cosine
learning_rate: 5e-4 learning_rate: 1e-3
eta_min: 1e-5 eta_min: 2e-5
warmup_epoch: 5 warmup_epoch: 5
warmup_start_lr: 1e-6 warmup_start_lr: 2e-6
# data loader for train and eval # data loader for train and eval
......
...@@ -43,11 +43,12 @@ Optimizer: ...@@ -43,11 +43,12 @@ Optimizer:
no_weight_decay_name: norm cls_token proj.0.weight proj.1.weight proj.2.weight proj.3.weight pos_block no_weight_decay_name: norm cls_token proj.0.weight proj.1.weight proj.2.weight proj.3.weight pos_block
one_dim_param_no_weight_decay: True one_dim_param_no_weight_decay: True
lr: lr:
# for 8 cards
name: Cosine name: Cosine
learning_rate: 5e-4 learning_rate: 1e-3
eta_min: 1e-5 eta_min: 2e-5
warmup_epoch: 5 warmup_epoch: 5
warmup_start_lr: 1e-6 warmup_start_lr: 2e-6
# data loader for train and eval # data loader for train and eval
......
...@@ -43,11 +43,12 @@ Optimizer: ...@@ -43,11 +43,12 @@ Optimizer:
no_weight_decay_name: norm cls_token proj.0.weight proj.1.weight proj.2.weight proj.3.weight pos_block no_weight_decay_name: norm cls_token proj.0.weight proj.1.weight proj.2.weight proj.3.weight pos_block
one_dim_param_no_weight_decay: True one_dim_param_no_weight_decay: True
lr: lr:
# for 8 cards
name: Cosine name: Cosine
learning_rate: 5e-4 learning_rate: 1e-3
eta_min: 1e-5 eta_min: 2e-5
warmup_epoch: 5 warmup_epoch: 5
warmup_start_lr: 1e-6 warmup_start_lr: 2e-6
# data loader for train and eval # data loader for train and eval
......
# global configs
Global:
checkpoints: null
pretrained_model: null
output_dir: "./output/"
device: "gpu"
save_interval: 40
eval_during_train: True
eval_interval: 10
epochs: 120
print_batch_step: 20
use_visualdl: False
eval_mode: "retrieval"
retrieval_feature_from: "backbone" # 'backbone' or 'neck'
# used for static mode and model export
image_shape: [3, 256, 128]
save_inference_dir: "./inference"
# model architecture
Arch:
name: "RecModel"
infer_output_key: "features"
infer_add_softmax: False
Backbone:
name: "ResNet50"
pretrained: True
stem_act: null
BackboneStopLayer:
name: "flatten"
Head:
name: "FC"
embedding_size: 2048
class_num: 751
# loss function config for traing/eval process
Loss:
Train:
- CELoss:
weight: 1.0
- TripletLossV2:
weight: 1.0
margin: 0.3
normalize_feature: False
feature_from: "backbone"
Eval:
- CELoss:
weight: 1.0
Optimizer:
name: Adam
lr:
name: Piecewise
decay_epochs: [40, 70]
values: [0.00035, 0.000035, 0.0000035]
warmup_epoch: 10
by_epoch: True
last_epoch: 0
regularizer:
name: 'L2'
coeff: 0.0005
# data loader for train and eval
DataLoader:
Train:
dataset:
name: "Market1501"
image_root: "./dataset/"
cls_label_path: "bounding_box_train"
backend: "pil"
transform_ops:
- ResizeImage:
size: [128, 256]
return_numpy: False
backend: "pil"
- RandFlipImage:
flip_code: 1
- Pad:
padding: 10
- RandCropImageV2:
size: [128, 256]
- ToTensor:
- Normalize:
mean: [0.485, 0.456, 0.406]
std: [0.229, 0.224, 0.225]
sampler:
name: DistributedRandomIdentitySampler
batch_size: 64
num_instances: 4
drop_last: False
shuffle: True
loader:
num_workers: 4
use_shared_memory: True
Eval:
Query:
dataset:
name: "Market1501"
image_root: "./dataset/"
cls_label_path: "query"
backend: "pil"
transform_ops:
- ResizeImage:
size: [128, 256]
return_numpy: False
backend: "pil"
- ToTensor:
- Normalize:
mean: [0.485, 0.456, 0.406]
std: [0.229, 0.224, 0.225]
sampler:
name: DistributedBatchSampler
batch_size: 128
drop_last: False
shuffle: False
loader:
num_workers: 4
use_shared_memory: True
Gallery:
dataset:
name: "Market1501"
image_root: "./dataset/"
cls_label_path: "bounding_box_test"
backend: "pil"
transform_ops:
- ResizeImage:
size: [128, 256]
return_numpy: False
backend: "pil"
- ToTensor:
- Normalize:
mean: [0.485, 0.456, 0.406]
std: [0.229, 0.224, 0.225]
sampler:
name: DistributedBatchSampler
batch_size: 128
drop_last: False
shuffle: False
loader:
num_workers: 4
use_shared_memory: True
Metric:
Eval:
- Recallk:
topk: [1, 5]
- mAP: {}
# global configs
Global:
checkpoints: null
pretrained_model: null
output_dir: "./output/"
device: "gpu"
save_interval: 40
eval_during_train: True
eval_interval: 10
epochs: 120
print_batch_step: 20
use_visualdl: False
eval_mode: "retrieval"
retrieval_feature_from: "features" # 'backbone' or 'features'
# used for static mode and model export
image_shape: [3, 256, 128]
save_inference_dir: "./inference"
# model architecture
Arch:
name: "RecModel"
infer_output_key: "features"
infer_add_softmax: False
Backbone:
name: "ResNet50_last_stage_stride1"
pretrained: True
stem_act: null
BackboneStopLayer:
name: "flatten"
Neck:
name: BNNeck
num_features: &feat_dim 2048
weight_attr:
initializer:
name: Constant
value: 1.0
bias_attr:
initializer:
name: Constant
value: 0.0
learning_rate: 1.0e-20 # NOTE: Temporarily set lr small enough to freeze the bias to zero
Head:
name: "FC"
embedding_size: *feat_dim
class_num: 751
weight_attr:
initializer:
name: Normal
std: 0.001
bias_attr: False
# loss function config for traing/eval process
Loss:
Train:
- CELoss:
weight: 1.0
epsilon: 0.1
- TripletLossV2:
weight: 1.0
margin: 0.3
normalize_feature: False
feature_from: "backbone"
Eval:
- CELoss:
weight: 1.0
Optimizer:
name: Adam
lr:
name: Piecewise
decay_epochs: [30, 60]
values: [0.00035, 0.000035, 0.0000035]
warmup_epoch: 10
warmup_start_lr: 0.0000035
by_epoch: True
last_epoch: 0
regularizer:
name: 'L2'
coeff: 0.0005
# data loader for train and eval
DataLoader:
Train:
dataset:
name: "Market1501"
image_root: "./dataset/"
cls_label_path: "bounding_box_train"
backend: "pil"
transform_ops:
- ResizeImage:
size: [128, 256]
return_numpy: False
backend: "pil"
- RandFlipImage:
flip_code: 1
- Pad:
padding: 10
- RandCropImageV2:
size: [128, 256]
- ToTensor:
- Normalize:
mean: [0.485, 0.456, 0.406]
std: [0.229, 0.224, 0.225]
- RandomErasing:
EPSILON: 0.5
sl: 0.02
sh: 0.4
r1: 0.3
mean: [0.485, 0.456, 0.406]
sampler:
name: DistributedRandomIdentitySampler
batch_size: 64
num_instances: 4
drop_last: False
shuffle: True
loader:
num_workers: 4
use_shared_memory: True
Eval:
Query:
dataset:
name: "Market1501"
image_root: "./dataset/"
cls_label_path: "query"
backend: "pil"
transform_ops:
- ResizeImage:
size: [128, 256]
return_numpy: False
backend: "pil"
- ToTensor:
- Normalize:
mean: [0.485, 0.456, 0.406]
std: [0.229, 0.224, 0.225]
sampler:
name: DistributedBatchSampler
batch_size: 128
drop_last: False
shuffle: False
loader:
num_workers: 4
use_shared_memory: True
Gallery:
dataset:
name: "Market1501"
image_root: "./dataset/"
cls_label_path: "bounding_box_test"
backend: "pil"
transform_ops:
- ResizeImage:
size: [128, 256]
return_numpy: False
backend: "pil"
- ToTensor:
- Normalize:
mean: [0.485, 0.456, 0.406]
std: [0.229, 0.224, 0.225]
sampler:
name: DistributedBatchSampler
batch_size: 128
drop_last: False
shuffle: False
loader:
num_workers: 4
use_shared_memory: True
Metric:
Eval:
- Recallk:
topk: [1, 5]
- mAP: {}
# global configs
Global:
checkpoints: null
pretrained_model: null
output_dir: "./output/"
device: "gpu"
save_interval: 40
eval_during_train: True
eval_interval: 10
epochs: 120
print_batch_step: 20
use_visualdl: False
eval_mode: "retrieval"
retrieval_feature_from: "features" # 'backbone' or 'features'
# used for static mode and model export
image_shape: [3, 256, 128]
save_inference_dir: "./inference"
# model architecture
Arch:
name: "RecModel"
infer_output_key: "features"
infer_add_softmax: False
Backbone:
name: "ResNet50_last_stage_stride1"
pretrained: True
stem_act: null
BackboneStopLayer:
name: "flatten"
Neck:
name: BNNeck
num_features: &feat_dim 2048
weight_attr:
initializer:
name: Constant
value: 1.0
bias_attr:
initializer:
name: Constant
value: 0.0
learning_rate: 1.0e-20 # NOTE: Temporarily set lr small enough to freeze the bias to zero
Head:
name: "FC"
embedding_size: *feat_dim
class_num: &class_num 751
weight_attr:
initializer:
name: Normal
std: 0.001
bias_attr: False
# loss function config for traing/eval process
Loss:
Train:
- CELoss:
weight: 1.0
epsilon: 0.1
- TripletLossV2:
weight: 1.0
margin: 0.3
normalize_feature: False
feature_from: "backbone"
- CenterLoss:
weight: 0.0005
num_classes: *class_num
feat_dim: *feat_dim
feature_from: "backbone"
Eval:
- CELoss:
weight: 1.0
Optimizer:
- Adam:
scope: RecModel
lr:
name: Piecewise
decay_epochs: [30, 60]
values: [0.00035, 0.000035, 0.0000035]
warmup_epoch: 10
warmup_start_lr: 0.0000035
by_epoch: True
last_epoch: 0
regularizer:
name: 'L2'
coeff: 0.0005
- SGD:
scope: CenterLoss
lr:
name: Constant
learning_rate: 1000.0 # NOTE: set to ori_lr*(1/centerloss_weight) to avoid manually scaling centers' gradidents.
# data loader for train and eval
DataLoader:
Train:
dataset:
name: "Market1501"
image_root: "./dataset/"
cls_label_path: "bounding_box_train"
backend: "pil"
transform_ops:
- ResizeImage:
size: [128, 256]
return_numpy: False
backend: "pil"
- RandFlipImage:
flip_code: 1
- Pad:
padding: 10
- RandCropImageV2:
size: [128, 256]
- ToTensor:
- Normalize:
mean: [0.485, 0.456, 0.406]
std: [0.229, 0.224, 0.225]
- RandomErasing:
EPSILON: 0.5
sl: 0.02
sh: 0.4
r1: 0.3
mean: [0.485, 0.456, 0.406]
sampler:
name: DistributedRandomIdentitySampler
batch_size: 64
num_instances: 4
drop_last: False
shuffle: True
loader:
num_workers: 4
use_shared_memory: True
Eval:
Query:
dataset:
name: "Market1501"
image_root: "./dataset/"
cls_label_path: "query"
backend: "pil"
transform_ops:
- ResizeImage:
size: [128, 256]
return_numpy: False
backend: "pil"
- ToTensor:
- Normalize:
mean: [0.485, 0.456, 0.406]
std: [0.229, 0.224, 0.225]
sampler:
name: DistributedBatchSampler
batch_size: 128
drop_last: False
shuffle: False
loader:
num_workers: 4
use_shared_memory: True
Gallery:
dataset:
name: "Market1501"
image_root: "./dataset/"
cls_label_path: "bounding_box_test"
backend: "pil"
transform_ops:
- ResizeImage:
size: [128, 256]
return_numpy: False
backend: "pil"
- ToTensor:
- Normalize:
mean: [0.485, 0.456, 0.406]
std: [0.229, 0.224, 0.225]
sampler:
name: DistributedBatchSampler
batch_size: 128
drop_last: False
shuffle: False
loader:
num_workers: 4
use_shared_memory: True
Metric:
Eval:
- Recallk:
topk: [1, 5]
- mAP: {}
...@@ -43,7 +43,11 @@ class Market1501(Dataset): ...@@ -43,7 +43,11 @@ class Market1501(Dataset):
""" """
_dataset_dir = 'market1501/Market-1501-v15.09.15' _dataset_dir = 'market1501/Market-1501-v15.09.15'
def __init__(self, image_root, cls_label_path, transform_ops=None): def __init__(self,
image_root,
cls_label_path,
transform_ops=None,
backend="cv2"):
self._img_root = image_root self._img_root = image_root
self._cls_path = cls_label_path # the sub folder in the dataset self._cls_path = cls_label_path # the sub folder in the dataset
self._dataset_dir = osp.join(image_root, self._dataset_dir, self._dataset_dir = osp.join(image_root, self._dataset_dir,
...@@ -51,6 +55,7 @@ class Market1501(Dataset): ...@@ -51,6 +55,7 @@ class Market1501(Dataset):
self._check_before_run() self._check_before_run()
if transform_ops: if transform_ops:
self._transform_ops = create_operators(transform_ops) self._transform_ops = create_operators(transform_ops)
self.backend = backend
self._dtype = paddle.get_default_dtype() self._dtype = paddle.get_default_dtype()
self._load_anno(relabel=True if 'train' in self._cls_path else False) self._load_anno(relabel=True if 'train' in self._cls_path else False)
...@@ -92,9 +97,11 @@ class Market1501(Dataset): ...@@ -92,9 +97,11 @@ class Market1501(Dataset):
def __getitem__(self, idx): def __getitem__(self, idx):
try: try:
img = Image.open(self.images[idx]).convert('RGB') img = Image.open(self.images[idx]).convert('RGB')
if self.backend == "cv2":
img = np.array(img, dtype="float32").astype(np.uint8) img = np.array(img, dtype="float32").astype(np.uint8)
if self._transform_ops: if self._transform_ops:
img = transform(img, self._transform_ops) img = transform(img, self._transform_ops)
if self.backend == "cv2":
img = img.transpose((2, 0, 1)) img = img.transpose((2, 0, 1))
return (img, self.labels[idx], self.cameras[idx]) return (img, self.labels[idx], self.cameras[idx])
except Exception as ex: except Exception as ex:
......
...@@ -25,10 +25,14 @@ from ppcls.data.preprocess.ops.operators import DecodeImage ...@@ -25,10 +25,14 @@ from ppcls.data.preprocess.ops.operators import DecodeImage
from ppcls.data.preprocess.ops.operators import ResizeImage from ppcls.data.preprocess.ops.operators import ResizeImage
from ppcls.data.preprocess.ops.operators import CropImage from ppcls.data.preprocess.ops.operators import CropImage
from ppcls.data.preprocess.ops.operators import RandCropImage from ppcls.data.preprocess.ops.operators import RandCropImage
from ppcls.data.preprocess.ops.operators import RandCropImageV2
from ppcls.data.preprocess.ops.operators import RandFlipImage from ppcls.data.preprocess.ops.operators import RandFlipImage
from ppcls.data.preprocess.ops.operators import NormalizeImage from ppcls.data.preprocess.ops.operators import NormalizeImage
from ppcls.data.preprocess.ops.operators import ToCHWImage from ppcls.data.preprocess.ops.operators import ToCHWImage
from ppcls.data.preprocess.ops.operators import AugMix from ppcls.data.preprocess.ops.operators import AugMix
from ppcls.data.preprocess.ops.operators import Pad
from ppcls.data.preprocess.ops.operators import ToTensor
from ppcls.data.preprocess.ops.operators import Normalize
from ppcls.data.preprocess.batch_ops.batch_operators import MixupOperator, CutmixOperator, OpSampler, FmixOperator from ppcls.data.preprocess.batch_ops.batch_operators import MixupOperator, CutmixOperator, OpSampler, FmixOperator
......
...@@ -13,6 +13,7 @@ ...@@ -13,6 +13,7 @@
# limitations under the License. # limitations under the License.
# This code is based on https://github.com/DeepVoltaire/AutoAugment/blob/master/autoaugment.py # This code is based on https://github.com/DeepVoltaire/AutoAugment/blob/master/autoaugment.py
# reference: https://arxiv.org/abs/1805.09501
from PIL import Image, ImageEnhance, ImageOps from PIL import Image, ImageEnhance, ImageOps
import numpy as np import numpy as np
......
...@@ -13,6 +13,7 @@ ...@@ -13,6 +13,7 @@
# limitations under the License. # limitations under the License.
# This code is based on https://github.com/uoguelph-mlrg/Cutout # This code is based on https://github.com/uoguelph-mlrg/Cutout
# reference: https://arxiv.org/abs/1708.04552
import numpy as np import numpy as np
import random import random
......
...@@ -12,6 +12,9 @@ ...@@ -12,6 +12,9 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# This code is based on https://github.com/ecs-vlc/FMix
# reference: https://arxiv.org/abs/2002.12047
import math import math
import random import random
......
...@@ -13,6 +13,7 @@ ...@@ -13,6 +13,7 @@
# limitations under the License. # limitations under the License.
# This code is based on https://github.com/akuxcw/GridMask # This code is based on https://github.com/akuxcw/GridMask
# reference: https://arxiv.org/abs/2001.04086.
import numpy as np import numpy as np
from PIL import Image from PIL import Image
......
...@@ -13,6 +13,7 @@ ...@@ -13,6 +13,7 @@
# limitations under the License. # limitations under the License.
# This code is based on https://github.com/kkanshul/Hide-and-Seek # This code is based on https://github.com/kkanshul/Hide-and-Seek
# reference: http://krsingh.cs.ucdavis.edu/krishna_files/papers/hide_and_seek/my_files/iccv2017.pdf
import numpy as np import numpy as np
import random import random
......
...@@ -23,8 +23,9 @@ import math ...@@ -23,8 +23,9 @@ import math
import random import random
import cv2 import cv2
import numpy as np import numpy as np
from PIL import Image from PIL import Image, ImageOps, __version__ as PILLOW_VERSION
from paddle.vision.transforms import ColorJitter as RawColorJitter from paddle.vision.transforms import ColorJitter as RawColorJitter
from paddle.vision.transforms import ToTensor, Normalize
from .autoaugment import ImageNetPolicy from .autoaugment import ImageNetPolicy
from .functional import augmentations from .functional import augmentations
...@@ -32,7 +33,7 @@ from ppcls.utils import logger ...@@ -32,7 +33,7 @@ from ppcls.utils import logger
class UnifiedResize(object): class UnifiedResize(object):
def __init__(self, interpolation=None, backend="cv2"): def __init__(self, interpolation=None, backend="cv2", return_numpy=True):
_cv2_interp_from_str = { _cv2_interp_from_str = {
'nearest': cv2.INTER_NEAREST, 'nearest': cv2.INTER_NEAREST,
'bilinear': cv2.INTER_LINEAR, 'bilinear': cv2.INTER_LINEAR,
...@@ -56,12 +57,17 @@ class UnifiedResize(object): ...@@ -56,12 +57,17 @@ class UnifiedResize(object):
resample = random.choice(resample) resample = random.choice(resample)
return cv2.resize(src, size, interpolation=resample) return cv2.resize(src, size, interpolation=resample)
def _pil_resize(src, size, resample): def _pil_resize(src, size, resample, return_numpy=True):
if isinstance(resample, tuple): if isinstance(resample, tuple):
resample = random.choice(resample) resample = random.choice(resample)
if isinstance(src, np.ndarray):
pil_img = Image.fromarray(src) pil_img = Image.fromarray(src)
else:
pil_img = src
pil_img = pil_img.resize(size, resample) pil_img = pil_img.resize(size, resample)
if return_numpy:
return np.asarray(pil_img) return np.asarray(pil_img)
return pil_img
if backend.lower() == "cv2": if backend.lower() == "cv2":
if isinstance(interpolation, str): if isinstance(interpolation, str):
...@@ -73,7 +79,8 @@ class UnifiedResize(object): ...@@ -73,7 +79,8 @@ class UnifiedResize(object):
elif backend.lower() == "pil": elif backend.lower() == "pil":
if isinstance(interpolation, str): if isinstance(interpolation, str):
interpolation = _pil_interp_from_str[interpolation.lower()] interpolation = _pil_interp_from_str[interpolation.lower()]
self.resize_func = partial(_pil_resize, resample=interpolation) self.resize_func = partial(
_pil_resize, resample=interpolation, return_numpy=return_numpy)
else: else:
logger.warning( logger.warning(
f"The backend of Resize only support \"cv2\" or \"PIL\". \"f{backend}\" is unavailable. Use \"cv2\" instead." f"The backend of Resize only support \"cv2\" or \"PIL\". \"f{backend}\" is unavailable. Use \"cv2\" instead."
...@@ -81,6 +88,8 @@ class UnifiedResize(object): ...@@ -81,6 +88,8 @@ class UnifiedResize(object):
self.resize_func = cv2.resize self.resize_func = cv2.resize
def __call__(self, src, size): def __call__(self, src, size):
if isinstance(size, list):
size = tuple(size)
return self.resize_func(src, size) return self.resize_func(src, size)
...@@ -99,6 +108,7 @@ class DecodeImage(object): ...@@ -99,6 +108,7 @@ class DecodeImage(object):
self.channel_first = channel_first # only enabled when to_np is True self.channel_first = channel_first # only enabled when to_np is True
def __call__(self, img): def __call__(self, img):
if not isinstance(img, np.ndarray):
if six.PY2: if six.PY2:
assert type(img) is str and len( assert type(img) is str and len(
img) > 0, "invalid input 'img' in DecodeImage" img) > 0, "invalid input 'img' in DecodeImage"
...@@ -125,7 +135,8 @@ class ResizeImage(object): ...@@ -125,7 +135,8 @@ class ResizeImage(object):
size=None, size=None,
resize_short=None, resize_short=None,
interpolation=None, interpolation=None,
backend="cv2"): backend="cv2",
return_numpy=True):
if resize_short is not None and resize_short > 0: if resize_short is not None and resize_short > 0:
self.resize_short = resize_short self.resize_short = resize_short
self.w = None self.w = None
...@@ -139,10 +150,16 @@ class ResizeImage(object): ...@@ -139,10 +150,16 @@ class ResizeImage(object):
'both 'size' and 'resize_short' are None") 'both 'size' and 'resize_short' are None")
self._resize_func = UnifiedResize( self._resize_func = UnifiedResize(
interpolation=interpolation, backend=backend) interpolation=interpolation,
backend=backend,
return_numpy=return_numpy)
def __call__(self, img): def __call__(self, img):
if isinstance(img, np.ndarray):
img_h, img_w = img.shape[:2] img_h, img_w = img.shape[:2]
else:
img_w, img_h = img.size
if self.resize_short is not None: if self.resize_short is not None:
percent = float(self.resize_short) / min(img_w, img_h) percent = float(self.resize_short) / min(img_w, img_h)
w = int(round(img_w * percent)) w = int(round(img_w * percent))
...@@ -222,6 +239,40 @@ class RandCropImage(object): ...@@ -222,6 +239,40 @@ class RandCropImage(object):
return self._resize_func(img, size) return self._resize_func(img, size)
class RandCropImageV2(object):
""" RandCropImageV2 is different from RandCropImage,
it will Select a cutting position randomly in a uniform distribution way,
and cut according to the given size without resize at last."""
def __init__(self, size):
if type(size) is int:
self.size = (size, size) # (h, w)
else:
self.size = size
def __call__(self, img):
if isinstance(img, np.ndarray):
img_h, img_w = img.shap[0], img.shap[1]
else:
img_w, img_h = img.size
tw, th = self.size
if img_h + 1 < th or img_w + 1 < tw:
raise ValueError(
"Required crop size {} is larger then input image size {}".
format((th, tw), (img_h, img_w)))
if img_w == tw and img_h == th:
return img
top = random.randint(0, img_h - th + 1)
left = random.randint(0, img_w - tw + 1)
if isinstance(img, np.ndarray):
return img[top:top + th, left:left + tw, :]
else:
return img.crop((left, top, left + tw, top + th))
class RandFlipImage(object): class RandFlipImage(object):
""" random flip image """ random flip image
flip_code: flip_code:
...@@ -237,7 +288,10 @@ class RandFlipImage(object): ...@@ -237,7 +288,10 @@ class RandFlipImage(object):
def __call__(self, img): def __call__(self, img):
if random.randint(0, 1) == 1: if random.randint(0, 1) == 1:
if isinstance(img, np.ndarray):
return cv2.flip(img, self.flip_code) return cv2.flip(img, self.flip_code)
else:
return img.transpose(Image.FLIP_LEFT_RIGHT)
else: else:
return img return img
...@@ -391,3 +445,58 @@ class ColorJitter(RawColorJitter): ...@@ -391,3 +445,58 @@ class ColorJitter(RawColorJitter):
if isinstance(img, Image.Image): if isinstance(img, Image.Image):
img = np.asarray(img) img = np.asarray(img)
return img return img
class Pad(object):
"""
Pads the given PIL.Image on all sides with specified padding mode and fill value.
adapted from: https://pytorch.org/vision/stable/_modules/torchvision/transforms/transforms.html#Pad
"""
def __init__(self, padding: int, fill: int=0,
padding_mode: str="constant"):
self.padding = padding
self.fill = fill
self.padding_mode = padding_mode
def _parse_fill(self, fill, img, min_pil_version, name="fillcolor"):
# Process fill color for affine transforms
major_found, minor_found = (int(v)
for v in PILLOW_VERSION.split('.')[:2])
major_required, minor_required = (
int(v) for v in min_pil_version.split('.')[:2])
if major_found < major_required or (major_found == major_required and
minor_found < minor_required):
if fill is None:
return {}
else:
msg = (
"The option to fill background area of the transformed image, "
"requires pillow>={}")
raise RuntimeError(msg.format(min_pil_version))
num_bands = len(img.getbands())
if fill is None:
fill = 0
if isinstance(fill, (int, float)) and num_bands > 1:
fill = tuple([fill] * num_bands)
if isinstance(fill, (list, tuple)):
if len(fill) != num_bands:
msg = (
"The number of elements in 'fill' does not match the number of "
"bands of the image ({} != {})")
raise ValueError(msg.format(len(fill), num_bands))
fill = tuple(fill)
return {name: fill}
def __call__(self, img):
opts = self._parse_fill(self.fill, img, "2.3.0", name="fill")
if img.mode == "P":
palette = img.getpalette()
img = ImageOps.expand(img, border=self.padding, **opts)
img.putpalette(palette)
return img
return ImageOps.expand(img, border=self.padding, **opts)
...@@ -13,6 +13,7 @@ ...@@ -13,6 +13,7 @@
# limitations under the License. # limitations under the License.
# This code is based on https://github.com/heartInsert/randaugment # This code is based on https://github.com/heartInsert/randaugment
# reference: https://arxiv.org/abs/1909.13719
from PIL import Image, ImageEnhance, ImageOps from PIL import Image, ImageEnhance, ImageOps
import numpy as np import numpy as np
......
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册