Unverified commit 8c7878d9, authored by LielinJiang, committed by GitHub

Add reference for some codes (#460)

* add copyright

* update comment
Parent 79b2287d
# code was referenced from mmcv
import cv2
from .builder import PREPROCESS
......
# code was heavily based on https://github.com/wtjiang98/PSGAN
# MIT License
# Copyright (c) 2020 Wentao Jiang
import os
......
@@ -11,5 +11,3 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .resnet_backbone import resnet18, resnet34, resnet50, resnet101, resnet152
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
import paddle.nn as nn
__all__ = [
'ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101', 'resnet152'
]
def conv3x3(in_planes, out_planes, stride=1):
"3x3 convolution with padding"
return nn.Conv2D(in_planes,
out_planes,
kernel_size=3,
stride=stride,
padding=1,
bias_attr=False)
class BasicBlock(nn.Layer):
expansion = 1
def __init__(self, inplanes, planes, stride=1, downsample=None):
super(BasicBlock, self).__init__()
self.conv1 = conv3x3(inplanes, planes, stride)
self.bn1 = nn.BatchNorm(planes)
self.relu = nn.ReLU()
self.conv2 = conv3x3(planes, planes)
self.bn2 = nn.BatchNorm(planes)
self.downsample = downsample
self.stride = stride
def forward(self, x):
residual = x
out = self.conv1(x)
out = self.bn1(out)
out = self.relu(out)
out = self.conv2(out)
out = self.bn2(out)
if self.downsample is not None:
residual = self.downsample(x)
out += residual
out = self.relu(out)
return out
class Bottleneck(nn.Layer):
expansion = 4
def __init__(self, inplanes, planes, stride=1, downsample=None):
super(Bottleneck, self).__init__()
self.conv1 = nn.Conv2D(inplanes, planes, kernel_size=1, bias_attr=False)
self.bn1 = nn.BatchNorm(planes)
self.conv2 = nn.Conv2D(planes,
planes,
kernel_size=3,
stride=stride,
padding=1,
bias_attr=False)
self.bn2 = nn.BatchNorm(planes)
self.conv3 = nn.Conv2D(planes,
planes * 4,
kernel_size=1,
bias_attr=False)
self.bn3 = nn.BatchNorm(planes * 4)
self.relu = nn.ReLU()
self.downsample = downsample
self.stride = stride
def forward(self, x):
residual = x
out = self.conv1(x)
out = self.bn1(out)
out = self.relu(out)
out = self.conv2(out)
out = self.bn2(out)
out = self.relu(out)
out = self.conv3(out)
out = self.bn3(out)
if self.downsample is not None:
residual = self.downsample(x)
out += residual
out = self.relu(out)
return out
class ResNet(nn.Layer):
def __init__(self, block, layers, num_classes=1000):
self.inplanes = 64
super(ResNet, self).__init__()
self.conv1 = nn.Conv2D(3,
64,
kernel_size=7,
stride=2,
padding=3,
bias_attr=False)
self.bn1 = nn.BatchNorm(64)
self.relu = nn.ReLU()
self.maxpool = nn.Pool2D(pool_size=3, pool_stride=2, pool_padding=1)
self.layer1 = self._make_layer(block, 64, layers[0])
self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
self.avgpool = nn.Pool2D(7, pool_stride=1, pool_type='avg')
self.fc = nn.Linear(512 * block.expansion, num_classes)
def _make_layer(self, block, planes, blocks, stride=1):
downsample = None
if stride != 1 or self.inplanes != planes * block.expansion:
downsample = nn.Sequential(
nn.Conv2D(self.inplanes,
planes * block.expansion,
kernel_size=1,
stride=stride,
bias_attr=False),
nn.BatchNorm(planes * block.expansion),
)
layers = []
layers.append(block(self.inplanes, planes, stride, downsample))
self.inplanes = planes * block.expansion
for _ in range(1, blocks):
layers.append(block(self.inplanes, planes))
return nn.Sequential(*layers)
def forward(self, x):
x = self.conv1(x)
x = self.bn1(x)
x = self.relu(x)
x = self.maxpool(x)
x = self.layer1(x)
x = self.layer2(x)
x = self.layer3(x)
x = self.layer4(x)
x = self.avgpool(x)
x = paddle.reshape(x, (x.shape[0], -1))
x = self.fc(x)
return x
def resnet18(pretrained=False, **kwargs):
"""Constructs a ResNet-18 model.
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
"""
model = ResNet(BasicBlock, [2, 2, 2, 2], **kwargs)
return model
def resnet34(pretrained=False, **kwargs):
"""Constructs a ResNet-34 model.
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
"""
model = ResNet(BasicBlock, [3, 4, 6, 3], **kwargs)
return model
def resnet50(pretrained=False, **kwargs):
"""Constructs a ResNet-50 model.
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
"""
model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs)
return model
def resnet101(pretrained=False, **kwargs):
"""Constructs a ResNet-101 model.
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
"""
model = ResNet(Bottleneck, [3, 4, 23, 3], **kwargs)
return model
def resnet152(pretrained=False, **kwargs):
"""Constructs a ResNet-152 model.
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
"""
model = ResNet(Bottleneck, [3, 8, 36, 3], **kwargs)
return model
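For orientation, a minimal forward-pass sketch for the backbone defined above (an illustrative example, not part of the commit; it assumes a Paddle build that still provides the nn.BatchNorm and nn.Pool2D layers this file uses):

import paddle

model = resnet18()  # `pretrained` is accepted but these constructors load no weights
model.eval()
x = paddle.randn([1, 3, 224, 224])  # NCHW, ImageNet-sized input
with paddle.no_grad():
    logits = model(x)
print(logits.shape)  # [1, 1000]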
# code was based on https://github.com/znxlwm/UGATIT-pytorch
import paddle
import paddle.nn as nn
import paddle.nn.functional as F
......
# code was based on https://github.com/xinntao/ESRGAN
import paddle.nn as nn
from .builder import DISCRIMINATORS
......
- # Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- #     http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
+ # MIT License
+ # Copyright (c) 2020 Yong Guo
+ # code was based on https://github.com/guoyongcs/DRN
import math
import paddle
......
- # Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- #     http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
+ # code was based on https://github.com/fastai/fastai
import numpy as np
......
# code was based on https://github.com/hellloxiaotian/LESRCNN
import math
import numpy as np
......
- # Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- #     http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
+ # code was based on https://github.com/swz30/MPRNet
+ # Users should be careful about adopting these functions in any commercial matters.
+ # https://github.com/swz30/MPRNet/blob/main/LICENSE.md
import numpy as np
@@ -23,7 +13,6 @@ from ...modules.init import kaiming_normal_, constant_
from .builder import GENERATORS
- ##########################################################################
def conv(in_channels, out_channels, kernel_size, bias_attr=False, stride=1):
    return nn.Conv2D(in_channels,
                     out_channels,
@@ -33,7 +22,6 @@ def conv(in_channels, out_channels, kernel_size, bias_attr=False, stride=1):
                     stride=stride)
- ##########################################################################
## Channel Attention Layer
class CALayer(nn.Layer):
    def __init__(self, channel, reduction=16, bias_attr=False):
@@ -59,7 +47,6 @@ class CALayer(nn.Layer):
        return x * y
- ##########################################################################
## Channel Attention Block (CAB)
class CAB(nn.Layer):
    def __init__(self, n_feat, kernel_size, reduction, bias_attr, act):
@@ -81,7 +68,6 @@ class CAB(nn.Layer):
        return res
- ##########################################################################
##---------- Resizing Modules ----------
class DownSample(nn.Layer):
    def __init__(self, in_channels, s_factor):
@@ -274,7 +260,6 @@ class Decoder(nn.Layer):
        return [dec1, dec2, dec3]
- ##########################################################################
## Original Resolution Block (ORB)
class ORB(nn.Layer):
    def __init__(self, n_feat, kernel_size, reduction, act, bias_attr, num_cab):
@@ -293,7 +278,6 @@ class ORB(nn.Layer):
        return res
- ##########################################################################
class ORSNet(nn.Layer):
    def __init__(self, n_feat, scale_orsnetfeats, kernel_size, reduction, act,
                 bias_attr, scale_unetfeats, num_cab):
@@ -358,8 +342,7 @@ class ORSNet(nn.Layer):
        return x
- ##########################################################################
- ## Supervised Attention Module
+ # Supervised Attention Module
class SAM(nn.Layer):
    def __init__(self, n_feat, kernel_size, bias_attr):
        super(SAM, self).__init__()
......
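The CALayer signature above follows the squeeze-and-excitation pattern. A self-contained sketch of that pattern in Paddle (a minimal reconstruction for illustration, under a hypothetical name; not the commit's exact code):

import paddle
import paddle.nn as nn

class TinyCALayer(nn.Layer):  # hypothetical name, for illustration only
    def __init__(self, channel, reduction=16, bias_attr=False):
        super().__init__()
        self.avg_pool = nn.AdaptiveAvgPool2D(1)  # squeeze: global average over H x W
        self.conv_du = nn.Sequential(
            nn.Conv2D(channel, channel // reduction, 1, bias_attr=bias_attr),
            nn.ReLU(),
            nn.Conv2D(channel // reduction, channel, 1, bias_attr=bias_attr),
            nn.Sigmoid())

    def forward(self, x):
        y = self.conv_du(self.avg_pool(x))  # per-channel weights in (0, 1)
        return x * y  # re-scale each channel, matching `return x * y` above

feat = paddle.randn([1, 64, 32, 32])
print(TinyCALayer(64)(feat).shape)  # [1, 64, 32, 32]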
@@ -91,16 +91,14 @@ class UpsampleConcat(nn.Layer):
class SourceReferenceAttention(nn.Layer):
    """
    Source-Reference Attention Layer

    Args:
        in_planes_s (int): Number of input source feature vector channels.
        in_planes_r (int): Number of input reference feature vector channels.
    """
    def __init__(self, in_planes_s, in_planes_r):
        super(SourceReferenceAttention, self).__init__()
        self.query_conv = nn.Conv3D(in_channels=in_planes_s,
                                    out_channels=in_planes_s // 8,
......
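A shape sketch for the query projection above (illustrative values; kernel_size=1 is an assumption, since the full constructor call is truncated here):

import paddle
import paddle.nn as nn

query_conv = nn.Conv3D(in_channels=64, out_channels=64 // 8, kernel_size=1)
src = paddle.randn([1, 64, 4, 16, 16])  # N, C_s, T, H, W source features
print(query_conv(src).shape)  # [1, 8, 4, 16, 16]: channels reduced 8x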
@@ -12,6 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.

# code was based on https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix
import paddle
import paddle.nn as nn
import functools
@@ -26,6 +28,16 @@ class ResnetGenerator(nn.Layer):
    """Resnet-based generator that consists of Resnet blocks between a few downsampling/upsampling operations.
    code and idea from Justin Johnson's neural style transfer project (https://github.com/jcjohnson/fast-neural-style)

    Args:
        input_nc (int): the number of channels in input images
        output_nc (int): the number of channels in output images
        ngf (int): the number of filters in the last conv layer
        norm_type (str): the name of the normalization layer: batch | instance | none
        use_dropout (bool): whether to use dropout layers
        n_blocks (int): the number of ResNet blocks
        padding_type (str): the name of padding layer in conv layers: reflect | replicate | zero
    """
    def __init__(self,
                 input_nc,
@@ -35,17 +47,7 @@ class ResnetGenerator(nn.Layer):
                 use_dropout=False,
                 n_blocks=6,
                 padding_type='reflect'):
"""Construct a Resnet-based generator
Args:
input_nc (int) -- the number of channels in input images
output_nc (int) -- the number of channels in output images
ngf (int) -- the number of filters in the last conv layer
norm_type (str) -- the name of the normalization layer: batch | instance | none
use_dropout (bool) -- if use dropout layers
n_blocks (int) -- the number of ResNet blocks
padding_type (str) -- the name of padding layer in conv layers: reflect | replicate | zero
"""
assert (n_blocks >= 0) assert (n_blocks >= 0)
super(ResnetGenerator, self).__init__() super(ResnetGenerator, self).__init__()
@@ -133,12 +135,12 @@ class ResnetBlock(nn.Layer):
                 use_bias):
        """Construct a convolutional block.

        Args:
            dim (int): the number of channels in the conv layer.
            padding_type (str): the name of padding layer: reflect | replicate | zero.
            norm_layer (paddle.nn.Layer): normalization layer.
            use_dropout (bool): whether to use dropout layers.
            use_bias (bool): whether to use the conv layer bias or not.

        Returns a conv block (with a conv layer, a normalization layer, and a non-linearity layer (ReLU))
        """
......
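A quick construction sketch for the documented arguments (a hedged example, not part of the commit; the import path, and the ngf/norm_type keyword names hidden in the elided signature, are assumptions based on the Args list above):

import paddle
from ppgan.models.generators.resnet import ResnetGenerator  # assumed module path

# a 9-block generator, the usual CycleGAN setting for 256x256 images
net = ResnetGenerator(input_nc=3, output_nc=3, ngf=64, norm_type='instance',
                      use_dropout=False, n_blocks=9, padding_type='reflect')
fake = net(paddle.randn([1, 3, 256, 256]))
print(fake.shape)  # [1, 3, 256, 256]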
- # Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- #     http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
+ # MIT License
+ # Copyright (c) 2019 Hyeonwoo Kang
+ # code was based on https://github.com/znxlwm/UGATIT-pytorch
import functools
import paddle
......
@@ -12,6 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.

# code was based on https://github.com/xinntao/ESRGAN
import functools
import paddle
import paddle.nn as nn
......
@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

# code was based on https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix
import functools
import paddle
import paddle.nn as nn
@@ -30,17 +31,19 @@ class UnetGenerator(nn.Layer):
                 ngf=64,
                 norm_type='batch',
                 use_dropout=False):
"""Construct a Unet generator """
Construct a Unet generator
the U-Net from the innermost layer to the outermost layer.
It is a recursive process.
Args: Args:
input_nc (int) -- the number of channels in input images input_nc (int): the number of channels in input images.
output_nc (int) -- the number of channels in output images output_nc (int): the number of channels in output images.
num_downs (int) -- the number of downsamplings in UNet. For example, # if |num_downs| == 7, num_downs (int): the number of downsamplings in UNet. For example, # if |num_downs| == 7,
image of size 128x128 will become of size 1x1 # at the bottleneck image of size 128x128 will become of size 1x1 # at the bottleneck.
ngf (int) -- the number of filters in the last conv layer ngf (int): the number of filters in the last conv layer.
norm_layer -- normalization layer norm_type (str): normalization type, default: 'batch'.
We construct the U-Net from the innermost layer to the outermost layer.
It is a recursive process.
""" """
        super(UnetGenerator, self).__init__()
        norm_layer = build_norm_layer(norm_type)
@@ -105,15 +108,15 @@ class UnetSkipConnectionBlock(nn.Layer):
                 use_dropout=False):
        """Construct a Unet submodule with skip connections.

        Args:
            outer_nc (int): the number of filters in the outer conv layer.
            inner_nc (int): the number of filters in the inner conv layer.
            input_nc (int): the number of channels in input images/features.
            submodule (UnetSkipConnectionBlock): previously defined submodules.
            outermost (bool): if this module is the outermost module.
            innermost (bool): if this module is the innermost module.
            norm_layer (paddle.nn.Layer): normalization layer.
            use_dropout (bool): whether to use dropout layers.
        """
        super(UnetSkipConnectionBlock, self).__init__()
        self.outermost = outermost
@@ -173,5 +176,6 @@ class UnetSkipConnectionBlock(nn.Layer):
    def forward(self, x):
        if self.outermost:
            return self.model(x)
        # add skip connections
        else:
            return paddle.concat([x, self.model(x)], 1)
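The num_downs arithmetic in the docstring can be checked directly (a hedged sketch; the import path is an assumption about the repository layout). With num_downs == 7, each downsampling halves the spatial size, so 128 / 2**7 == 1 at the bottleneck.

import paddle
from ppgan.models.generators.unet import UnetGenerator  # assumed module path

net = UnetGenerator(input_nc=3, output_nc=3, num_downs=7, ngf=64,
                    norm_type='batch', use_dropout=False)
net.eval()  # eval mode sidesteps batch-norm statistics on a single 1x1 bottleneck sample
out = net(paddle.randn([1, 3, 128, 128]))
print(out.shape)  # [1, 3, 128, 128]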
- # Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- #     http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
+ # code was based on torch init module
import math
import numpy as np
@@ -325,6 +313,7 @@ def init_weights(net,
    logger.debug('initialize network with %s' % init_type)
    net.apply(init_func)  # apply the initialization function <init_func>


def reset_parameters(m):
    kaiming_uniform_(m.weight, a=math.sqrt(5))
    if m.bias is not None:
......
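Why a=math.sqrt(5) above: for Kaiming-uniform initialization the bound is gain * sqrt(3 / fan_in) with gain = sqrt(2 / (1 + a^2)), so a^2 == 5 reduces the bound to 1 / sqrt(fan_in), the classic conv/linear reset. A small arithmetic sketch (my derivation, not code from the commit):

import math

def kaiming_uniform_bound(fan_in, a=math.sqrt(5)):
    gain = math.sqrt(2.0 / (1 + a ** 2))  # leaky-ReLU gain used by Kaiming init
    return gain * math.sqrt(3.0 / fan_in)  # half-width of the uniform range

print(kaiming_uniform_bound(64))  # 0.125 == 1 / math.sqrt(64)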
@@ -45,46 +45,26 @@ class LinearDecay(LambdaDecay):
        super().__init__(learning_rate, lambda_rule)
def get_position_from_periods(iteration, cumulative_period):
    """Get the position from a period list.
    It will return the index of the right-closest number in the period list.
    For example, the cumulative_period = [100, 200, 300, 400],
    if iteration == 50, return 0;
    if iteration == 210, return 2;
    if iteration == 300, return 2.

    Args:
        iteration (int): Current iteration.
        cumulative_period (list[int]): Cumulative period list.

    Returns:
        int: The position of the right-closest number in the period list.
    """
    for i, period in enumerate(cumulative_period):
        if iteration <= period:
            return i
@LRSCHEDULERS.register()
class CosineAnnealingRestartLR(LRScheduler):
""" Cosine annealing with restarts learning rate scheme. """ Cosine annealing with restarts learning rate scheme.
An example of config: An example config from configs/edvr_l_blur_wo_tsa.yaml:
periods = [10, 10, 10, 10] learning_rate: !!float 4e-4
restart_weights = [1, 0.5, 0.5, 0.5] periods: [150000, 150000, 150000, 150000]
eta_min=1e-7 restart_weights: [1, 1, 1, 1]
eta_min: !!float 1e-7
It has four cycles, each has 10 iterations. At 10th, 20th, 30th, the It has four cycles, each has 150000 iterations. At 150000th, 300000th,
scheduler will restart with the weights in restart_weights. 450000th, the scheduler will restart with the weights in restart_weights.
Args: Args:
learning_rate (float|paddle.nn.optimizer): PaddlePaddle optimizer. learning_rate (float): Base learning rate.
periods (list): Period for each cosine anneling cycle. periods (list): Period for each cosine anneling cycle.
restart_weights (list): Restart weights at each restart iteration. restart_weights (list): Restart weights at each restart iteration.
Default: [1]. Default: [1].
eta_min (float): The mimimum lr. Default: 0. eta_min (float): The mimimum learning rate of the cosine anneling cycle. Default: 0.
last_epoch (int): Used in _LRScheduler. Default: -1. last_epoch (int): Used in paddle.nn._LRScheduler. Default: -1.
""" """
    def __init__(self,
                 learning_rate,
@@ -104,10 +84,14 @@ class CosineAnnealingRestartLR(LRScheduler):
                 last_epoch)
    def get_lr(self):
        for i, period in enumerate(self.cumulative_period):
            if self.last_epoch <= period:
                index = i
                break

        current_weight = self.restart_weights[index]
        nearest_restart = 0 if index == 0 else self.cumulative_period[index - 1]
        current_period = self.periods[index]

        lr = self.eta_min + current_weight * 0.5 * (
            self.base_lr - self.eta_min) * (1 + math.cos(math.pi * (
......
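A standalone numeric sketch of the get_lr() arithmetic above (toy values, not the EDVR config): four cycles of 10 iterations each, restarting at reduced weights.

import math

periods = [10, 10, 10, 10]
restart_weights = [1, 0.5, 0.5, 0.5]
cumulative_period = [sum(periods[:i + 1]) for i in range(len(periods))]  # [10, 20, 30, 40]
base_lr, eta_min = 4e-4, 1e-7

def lr_at(last_epoch):
    # the same right-closest search that get_lr() inlines
    index = next(i for i, p in enumerate(cumulative_period) if last_epoch <= p)
    weight = restart_weights[index]
    nearest_restart = 0 if index == 0 else cumulative_period[index - 1]
    return eta_min + weight * 0.5 * (base_lr - eta_min) * (
        1 + math.cos(math.pi * ((last_epoch - nearest_restart) / periods[index])))

print(lr_at(0))   # ~4e-4: start of the first cycle
print(lr_at(10))  # ~1e-7: cosine bottom at the end of the cycle
print(lr_at(11))  # restart with weight 0.5 -> ~2e-4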