Commit 452f5321, authored by dongshuilong

add CompCars train and fix bugs

Parent bba0cf8f
@@ -50,8 +50,7 @@ class RecModel(nn.Layer):
         self.backbone.stop_after(stop_layer_config["name"])
         if stop_layer_config.get("embedding_size", 0) > 0:
-            # self.neck = nn.Linear(stop_layer_config["output_dim"], stop_layer_config["embedding_size"])
-            self.neck = nn.Conv2D(stop_layer_config["output_dim"],
-                                  stop_layer_config["embedding_size"])
+            self.neck = nn.Linear(stop_layer_config["output_dim"],
+                                  stop_layer_config["embedding_size"])
             embedding_size = stop_layer_config["embedding_size"]
         else:
......
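Note on the hunk above: nn.Conv2D(in_channels, out_channels) is not a valid
call on its own (Conv2D also requires a kernel_size) and operates on 4-D NCHW
tensors, while the feature produced after the backbone's stop layer is a 2-D
[batch, output_dim] tensor, so nn.Linear is the correct projection. A minimal
sketch of the fixed neck in isolation, with illustrative shapes:

    # Standalone sketch of the Linear neck; 2048/512 mirror the config below.
    import paddle
    import paddle.nn as nn

    feat = paddle.rand([8, 2048])   # 2-D feature from the truncated backbone
    neck = nn.Linear(2048, 512)     # output_dim -> embedding_size
    emb = neck(feat)                # shape: [8, 512]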
@@ -12,8 +12,8 @@
 #See the License for the specific language governing permissions and
 #limitations under the License.
-import sys
 import copy
+import sys
 
 import paddle
 import paddle.nn as nn
@@ -69,6 +69,9 @@ class Topk(nn.Layer):
         self.topk = topk
 
     def forward(self, x, label):
+        if isinstance(x, dict):
+            x = x["logits"]
         metric_dict = dict()
         for k in self.topk:
             metric_dict["top{}".format(k)] = paddle.metric.accuracy(
......
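The two lines added to Topk.forward let the metric accept a model that returns
a dict of outputs and score only the "logits" entry. A standalone sketch (only
the "logits" key is confirmed by the diff; "features" is an assumption for
illustration):

    import paddle

    out = {"features": paddle.rand([4, 512]), "logits": paddle.rand([4, 431])}
    label = paddle.randint(0, 431, shape=[4, 1])
    x = out["logits"] if isinstance(out, dict) else out   # unwrap dict outputs
    top1 = paddle.metric.accuracy(x, label, k=1)          # top-1 accuracy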
@@ -16,12 +16,15 @@ Global:
   save_inference_dir: "./inference"
 
 # model architecture
-RecModel:
-  Backbone: "ResNet50"
-  Stoplayer: "adaptive_avg_pool2d_0"
-  embedding_size: 512
-
-Head:
+Arch:
+  name: "RecModel"
+  Backbone:
+    name: "ResNet50"
+  Stoplayer:
+    name: "flatten_0"
+    output_dim: 2048
+    embedding_size: 512
+  Head:
     name: "ArcMargin"
     embedding_size: 512
     class_num: 431
@@ -43,7 +46,7 @@ Optimizer:
   lr:
     name: MultiStepDecay
     learning_rate: 0.01
-    decay_epochs: [30, 60, 70, 80, 90, 100, 120, 140]
+    milestones: [30, 60, 70, 80, 90, 100, 120, 140]
     gamma: 0.5
     verbose: False
    last_epoch: -1
@@ -82,7 +85,7 @@ DataLoader:
     sampler:
       name: DistributedRandomIdentitySampler
-      batch_size: 128
+      batch_size: 64
       num_instances: 2
       drop_last: False
       shuffle: True
......
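The config is restructured so everything model-related sits under a single
Arch section whose top-level name selects the model class; the Trainer hunk
below passes this dict straight to build_model. A hedged sketch of how such a
section could be dispatched (MODEL_REGISTRY and the import path are
illustrative, not the actual ppcls builder):

    import copy
    from ppcls.arch import RecModel   # assumed import path, for illustration

    MODEL_REGISTRY = {"RecModel": RecModel}

    def build_model_sketch(arch_config):
        config = copy.deepcopy(arch_config)
        name = config.pop("name")   # e.g. "RecModel"
        # Remaining keys (Backbone, Stoplayer, Head, ...) become ctor kwargs.
        return MODEL_REGISTRY[name](**config)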
@@ -55,6 +55,14 @@ class Trainer(object):
             "distributed"] = paddle.distributed.get_world_size() != 1
         if self.config["Global"]["distributed"]:
             dist.init_parallel_env()
+
+        if "Head" in self.config["Arch"]:
+            self.config["Arch"]["Head"]["class_num"] = self.config["Global"][
+                "class_num"]
+            self.is_rec = True
+        else:
+            self.is_rec = False
+
         self.model = build_model(self.config["Arch"])
         if self.config["Global"]["pretrained_model"] is not None:
@@ -143,7 +151,10 @@ class Trainer(object):
                     .reshape([-1, 1]))
                 global_step += 1
                 # image input
-                out = self.model(batch[0])
+                if not self.is_rec:
+                    out = self.model(batch[0])
+                else:
+                    out = self.model(batch[0], batch[1])
                 # calc loss
                 loss_dict = loss_func(out, batch[-1])
                 for key in loss_dict:
......
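The label is forwarded in the rec branch because a margin-based head such as
ArcMargin needs the ground-truth class at training time to place its angular
margin. A generic sketch of that idea (not the actual ppcls ArcMargin
implementation):

    import paddle
    import paddle.nn.functional as F

    def arc_margin_logits(feat, weight, label, s=30.0, m=0.5):
        # Cosine similarity between L2-normalized features and class centers.
        cos = F.linear(F.normalize(feat), F.normalize(weight, axis=0))
        theta = paddle.acos(cos.clip(-1.0 + 1e-7, 1.0 - 1e-7))
        # Add the angular margin m only at the ground-truth class position.
        target = F.one_hot(label.squeeze(-1), num_classes=cos.shape[1])
        return s * paddle.cos(theta + m * target)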
 import copy
 import paddle
 import paddle.nn as nn
+from ppcls.utils import logger
 from .celoss import CELoss
-from .triplet import TripletLoss, TripletLossV2
-from .msmloss import MSMLoss
+from .centerloss import CenterLoss
 from .emlloss import EmlLoss
+from .msmloss import MSMLoss
 from .npairsloss import NpairsLoss
 from .trihardloss import TriHardLoss
-from .centerloss import CenterLoss
+from .triplet import TripletLoss, TripletLossV2
 
 class CombinedLoss(nn.Layer):
     def __init__(self, config_list):
@@ -39,6 +41,7 @@ class CombinedLoss(nn.Layer):
         loss_dict["loss"] = paddle.add_n(list(loss_dict.values()))
         return loss_dict
 
 def build_loss(config):
     module_class = CombinedLoss(config)
+    logger.info("build loss {} success.".format(module_class))
......
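For context, CombinedLoss keeps one (already weighted) scalar per configured
loss and sums them into the "loss" entry via paddle.add_n, as the unchanged
context line above shows. A minimal standalone illustration:

    import paddle

    loss_dict = {"CELoss": paddle.to_tensor(0.7),
                 "TripletLossV2": paddle.to_tensor(0.3)}
    loss_dict["loss"] = paddle.add_n(list(loss_dict.values()))   # total: 1.0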
@@ -31,7 +31,11 @@ def build_lr_scheduler(lr_config, epochs, step_each_epoch):
     lr_config.update({'epochs': epochs, 'step_each_epoch': step_each_epoch})
     if 'name' in lr_config:
         lr_name = lr_config.pop('name')
-        lr = getattr(learning_rate, lr_name)(**lr_config)()
+        lr = getattr(learning_rate, lr_name)(**lr_config)
+        if isinstance(lr, paddle.optimizer.lr.LRScheduler):
+            return lr
+        else:
+            return lr()
     else:
         lr = lr_config['learning_rate']
     return lr
......
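The builder now tolerates two shapes of scheduler object: the existing ppcls
wrappers (Linear, Piecewise, ...) are factories whose __call__ constructs a
paddle scheduler, while the new MultiStepDecay below subclasses
paddle.optimizer.lr.LRScheduler and is usable as-is. A condensed sketch of the
dispatch, with learning_rate/lr_name/lr_config as in the function above:

    import paddle

    obj = getattr(learning_rate, lr_name)(**lr_config)
    if isinstance(obj, paddle.optimizer.lr.LRScheduler):
        lr = obj     # already a scheduler instance (e.g. MultiStepDecay)
    else:
        lr = obj()   # factory wrapper: call it to build the scheduler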
@@ -11,11 +11,11 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-from __future__ import unicode_literals
+from __future__ import (absolute_import, division, print_function,
+                        unicode_literals)
 
 from paddle.optimizer import lr
+from paddle.optimizer.lr import LRScheduler
 
 class Linear(object):
@@ -181,3 +181,104 @@ class Piecewise(object):
             end_lr=self.values[0],
             last_epoch=self.last_epoch)
         return learning_rate
+
+
+class MultiStepDecay(LRScheduler):
+    """
+    Update the learning rate by ``gamma`` once ``epoch`` reaches one of the milestones.
+
+    The algorithm can be described as the code below.
+
+    .. code-block:: text
+
+        learning_rate = 0.5
+        milestones = [30, 50]
+        gamma = 0.1
+        if epoch < 30:
+            learning_rate = 0.5
+        elif epoch < 50:
+            learning_rate = 0.05
+        else:
+            learning_rate = 0.005
+
+    Args:
+        learning_rate (float): The initial learning rate. It is a python float number.
+        milestones (tuple|list): List or tuple of epoch boundaries. Must be increasing.
+        gamma (float, optional): The ratio by which the learning rate is reduced:
+            ``new_lr = origin_lr * gamma``. It should be less than 1.0. Default: 0.1.
+        last_epoch (int, optional): The index of the last epoch. Can be set to resume
+            training. Default: -1, meaning the initial learning rate.
+        verbose (bool, optional): If ``True``, prints a message to stdout for each
+            update. Default: ``False``.
+
+    Returns:
+        ``MultiStepDecay`` instance to schedule learning rate.
+
+    Examples:
+        .. code-block:: python
+
+            import paddle
+            import numpy as np
+
+            # train on default dynamic graph mode
+            linear = paddle.nn.Linear(10, 10)
+            scheduler = paddle.optimizer.lr.MultiStepDecay(
+                learning_rate=0.5, milestones=[2, 4, 6], gamma=0.8, verbose=True)
+            sgd = paddle.optimizer.SGD(learning_rate=scheduler,
+                                       parameters=linear.parameters())
+            for epoch in range(20):
+                for batch_id in range(5):
+                    x = paddle.uniform([10, 10])
+                    out = linear(x)
+                    loss = paddle.mean(out)
+                    loss.backward()
+                    sgd.step()
+                    sgd.clear_gradients()
+                    scheduler.step()    # If you update learning rate each step
+                # scheduler.step()      # If you update learning rate each epoch
+
+            # train on static graph mode
+            paddle.enable_static()
+            main_prog = paddle.static.Program()
+            start_prog = paddle.static.Program()
+            with paddle.static.program_guard(main_prog, start_prog):
+                x = paddle.static.data(name='x', shape=[None, 4, 5])
+                y = paddle.static.data(name='y', shape=[None, 4, 5])
+                z = paddle.static.nn.fc(x, 100)
+                loss = paddle.mean(z)
+                scheduler = paddle.optimizer.lr.MultiStepDecay(
+                    learning_rate=0.5, milestones=[2, 4, 6], gamma=0.8, verbose=True)
+                sgd = paddle.optimizer.SGD(learning_rate=scheduler)
+                sgd.minimize(loss)
+
+            exe = paddle.static.Executor()
+            exe.run(start_prog)
+            for epoch in range(20):
+                for batch_id in range(5):
+                    out = exe.run(
+                        main_prog,
+                        feed={
+                            'x': np.random.randn(3, 4, 5).astype('float32'),
+                            'y': np.random.randn(3, 4, 5).astype('float32')
+                        },
+                        fetch_list=loss.name)
+                    scheduler.step()    # If you update learning rate each step
+                # scheduler.step()      # If you update learning rate each epoch
+    """
+
+    def __init__(self,
+                 learning_rate,
+                 milestones,
+                 epochs,
+                 step_each_epoch,
+                 gamma=0.1,
+                 last_epoch=-1,
+                 verbose=False):
+        if not isinstance(milestones, (tuple, list)):
+            raise TypeError(
+                "The type of 'milestones' in 'MultiStepDecay' must be 'tuple, list', but received %s."
+                % type(milestones))
+        if not all([
+                milestones[i] < milestones[i + 1]
+                for i in range(len(milestones) - 1)
+        ]):
+            raise ValueError('The elements of milestones must be incremented')
+        if gamma >= 1.0:
+            raise ValueError('gamma should be < 1.0.')
+
+        # Milestones are given in epochs; convert them to iteration counts so
+        # that step() can be called once per iteration.
+        self.milestones = [x * step_each_epoch for x in milestones]
+        self.gamma = gamma
+        super(MultiStepDecay, self).__init__(learning_rate, last_epoch,
+                                             verbose)
+
+    def get_lr(self):
+        # Return base_lr * gamma^k, where k is the number of milestones
+        # already passed.
+        for i in range(len(self.milestones)):
+            if self.last_epoch < self.milestones[i]:
+                return self.base_lr * (self.gamma**i)
+        return self.base_lr * (self.gamma**len(self.milestones))
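Usage note on the class above: __init__ rescales the epoch milestones by
step_each_epoch, so the scheduler is meant to be stepped once per iteration
(last_epoch then counts iterations, not epochs); the epochs argument is only
accepted because build_lr_scheduler injects it. A hedged sketch with the
values from the config hunk above (epochs=160 and step_each_epoch=100 are
illustrative):

    sch = MultiStepDecay(learning_rate=0.01,
                         milestones=[30, 60, 70, 80, 90, 100, 120, 140],
                         epochs=160,
                         step_each_epoch=100,
                         gamma=0.5)
    for _ in range(160 * 100):
        sch.step()   # rate halves as iterations cross 30*100, 60*100, ...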