未验证 提交 83714347 编写于 作者: K Kaipeng Deng 提交者: GitHub

polish weights (#2376)

* unity weights
上级 abe1ffe4
_BASE_: [ _BASE_: [
'cascade_rcnn_dcn_r50_fpn_1x_coco.yml', 'cascade_rcnn_dcn_r50_fpn_1x_coco.yml',
] ]
pretrain_weights: https://paddle-imagenet-models-name.bj.bcebos.com/ResNeXt101_vd_64x4d_pretrained.tar pretrain_weights: https://paddledet.bj.bcebos.com/models/pretrained/ResNeXt101_vd_64x4d_pretrained.pdparams
weights: output/cascade_rcnn_dcn_x101_vd_64x4d_fpn_1x_coco/model_final weights: output/cascade_rcnn_dcn_x101_vd_64x4d_fpn_1x_coco/model_final
ResNet: ResNet:
......
_BASE_: [ _BASE_: [
'faster_rcnn_dcn_r50_fpn_1x_coco.yml', 'faster_rcnn_dcn_r50_fpn_1x_coco.yml',
] ]
pretrain_weights: https://paddle-imagenet-models-name.bj.bcebos.com/ResNet101_vd_pretrained.tar pretrain_weights: https://paddledet.bj.bcebos.com/models/pretrained/ResNet101_vd_pretrained.pdparams
weights: output/faster_rcnn_dcn_r101_vd_fpn_1x_coco/model_final weights: output/faster_rcnn_dcn_r101_vd_fpn_1x_coco/model_final
ResNet: ResNet:
......
_BASE_: [ _BASE_: [
'faster_rcnn_dcn_r50_fpn_1x_coco.yml', 'faster_rcnn_dcn_r50_fpn_1x_coco.yml',
] ]
pretrain_weights: https://paddle-imagenet-models-name.bj.bcebos.com/ResNet50_vd_pretrained.tar pretrain_weights: https://paddledet.bj.bcebos.com/models/pretrained/ResNet50_vd_pretrained.pdparams
weights: output/faster_rcnn_dcn_r50_vd_fpn_2x_coco/model_final weights: output/faster_rcnn_dcn_r50_vd_fpn_2x_coco/model_final
ResNet: ResNet:
......
_BASE_: [ _BASE_: [
'faster_rcnn_dcn_r50_fpn_1x_coco.yml', 'faster_rcnn_dcn_r50_fpn_1x_coco.yml',
] ]
pretrain_weights: https://paddle-imagenet-models-name.bj.bcebos.com/ResNet50_vd_pretrained.tar pretrain_weights: https://paddledet.bj.bcebos.com/models/pretrained/ResNet50_vd_pretrained.pdparams
weights: output/faster_rcnn_dcn_r50_vd_fpn_2x_coco/model_final weights: output/faster_rcnn_dcn_r50_vd_fpn_2x_coco/model_final
ResNet: ResNet:
......
_BASE_: [ _BASE_: [
'faster_rcnn_dcn_r50_fpn_1x_coco.yml', 'faster_rcnn_dcn_r50_fpn_1x_coco.yml',
] ]
pretrain_weights: https://paddle-imagenet-models-name.bj.bcebos.com/ResNeXt101_vd_64x4d_pretrained.tar pretrain_weights: https://paddledet.bj.bcebos.com/models/pretrained/ResNeXt101_vd_64x4d_pretrained.pdparams
weights: output/faster_rcnn_dcn_x101_vd_64x4d_fpn_1x_coco/model_final weights: output/faster_rcnn_dcn_x101_vd_64x4d_fpn_1x_coco/model_final
ResNet: ResNet:
......
_BASE_: [ _BASE_: [
'mask_rcnn_dcn_r50_fpn_1x_coco.yml', 'mask_rcnn_dcn_r50_fpn_1x_coco.yml',
] ]
pretrain_weights: https://paddle-imagenet-models-name.bj.bcebos.com/ResNet101_vd_pretrained.tar pretrain_weights: https://paddledet.bj.bcebos.com/models/pretrained/ResNet101_vd_pretrained.pdparams
weights: output/mask_rcnn_dcn_r101_vd_fpn_1x_coco/model_final weights: output/mask_rcnn_dcn_r101_vd_fpn_1x_coco/model_final
ResNet: ResNet:
......
_BASE_: [ _BASE_: [
'mask_rcnn_dcn_r50_fpn_1x_coco.yml', 'mask_rcnn_dcn_r50_fpn_1x_coco.yml',
] ]
pretrain_weights: https://paddle-imagenet-models-name.bj.bcebos.com/ResNet50_vd_pretrained.tar pretrain_weights: https://paddledet.bj.bcebos.com/models/pretrained/ResNet50_vd_pretrained.pdparams
weights: output/mask_rcnn_dcn_r50_vd_fpn_2x_coco/model_final weights: output/mask_rcnn_dcn_r50_vd_fpn_2x_coco/model_final
ResNet: ResNet:
......
_BASE_: [ _BASE_: [
'mask_rcnn_dcn_r50_fpn_1x_coco.yml', 'mask_rcnn_dcn_r50_fpn_1x_coco.yml',
] ]
pretrain_weights: https://paddle-imagenet-models-name.bj.bcebos.com/ResNeXt101_vd_64x4d_pretrained.tar pretrain_weights: https://paddledet.bj.bcebos.com/models/pretrained/ResNeXt101_vd_64x4d_pretrained.pdparams
weights: output/mask_rcnn_dcn_x101_vd_64x4d_fpn_1x_coco/model_final weights: output/mask_rcnn_dcn_x101_vd_64x4d_fpn_1x_coco/model_final
ResNet: ResNet:
......
...@@ -2,7 +2,7 @@ _BASE_: [ ...@@ -2,7 +2,7 @@ _BASE_: [
'faster_rcnn_r50_1x_coco.yml', 'faster_rcnn_r50_1x_coco.yml',
] ]
pretrain_weights: https://paddle-imagenet-models-name.bj.bcebos.com/ResNet101_pretrained.tar pretrain_weights: https://paddledet.bj.bcebos.com/models/pretrained/ResNet101_pretrained.pdparams
weights: output/faster_rcnn_r101_1x_coco/model_final weights: output/faster_rcnn_r101_1x_coco/model_final
ResNet: ResNet:
......
...@@ -2,7 +2,7 @@ _BASE_: [ ...@@ -2,7 +2,7 @@ _BASE_: [
'faster_rcnn_r50_fpn_1x_coco.yml', 'faster_rcnn_r50_fpn_1x_coco.yml',
] ]
pretrain_weights: https://paddle-imagenet-models-name.bj.bcebos.com/ResNet101_pretrained.tar pretrain_weights: https://paddledet.bj.bcebos.com/models/pretrained/ResNet101_pretrained.pdparams
weights: output/faster_rcnn_r101_fpn_1x_coco/model_final weights: output/faster_rcnn_r101_fpn_1x_coco/model_final
ResNet: ResNet:
......
...@@ -2,7 +2,7 @@ _BASE_: [ ...@@ -2,7 +2,7 @@ _BASE_: [
'faster_rcnn_r50_fpn_1x_coco.yml', 'faster_rcnn_r50_fpn_1x_coco.yml',
] ]
pretrain_weights: https://paddle-imagenet-models-name.bj.bcebos.com/ResNet101_pretrained.tar pretrain_weights: https://paddledet.bj.bcebos.com/models/pretrained/ResNet101_pretrained.pdparams
weights: output/faster_rcnn_r101_fpn_2x_coco/model_final weights: output/faster_rcnn_r101_fpn_2x_coco/model_final
ResNet: ResNet:
......
_BASE_: [ _BASE_: [
'faster_rcnn_r50_fpn_1x_coco.yml', 'faster_rcnn_r50_fpn_1x_coco.yml',
] ]
pretrain_weights: https://paddle-imagenet-models-name.bj.bcebos.com/ResNet101_vd_pretrained.tar pretrain_weights: https://paddledet.bj.bcebos.com/models/pretrained/ResNet101_vd_pretrained.pdparams
weights: output/faster_rcnn_r101_vd_fpn_1x_coco/model_final weights: output/faster_rcnn_r101_vd_fpn_1x_coco/model_final
ResNet: ResNet:
......
_BASE_: [ _BASE_: [
'faster_rcnn_r50_fpn_1x_coco.yml', 'faster_rcnn_r50_fpn_1x_coco.yml',
] ]
pretrain_weights: https://paddle-imagenet-models-name.bj.bcebos.com/ResNet101_vd_pretrained.tar pretrain_weights: https://paddledet.bj.bcebos.com/models/pretrained/ResNet101_vd_pretrained.pdparams
weights: output/faster_rcnn_r101_vd_fpn_2x_coco/model_final weights: output/faster_rcnn_r101_vd_fpn_2x_coco/model_final
ResNet: ResNet:
......
...@@ -2,7 +2,7 @@ _BASE_: [ ...@@ -2,7 +2,7 @@ _BASE_: [
'faster_rcnn_r50_fpn_1x_coco.yml', 'faster_rcnn_r50_fpn_1x_coco.yml',
] ]
pretrain_weights: https://paddle-imagenet-models-name.bj.bcebos.com/ResNet34_pretrained.tar pretrain_weights: https://paddledet.bj.bcebos.com/models/pretrained/ResNet34_pretrained.pdparams
weights: output/faster_rcnn_r34_fpn_1x_coco/model_final weights: output/faster_rcnn_r34_fpn_1x_coco/model_final
ResNet: ResNet:
......
...@@ -2,7 +2,7 @@ _BASE_: [ ...@@ -2,7 +2,7 @@ _BASE_: [
'faster_rcnn_r50_fpn_1x_coco.yml', 'faster_rcnn_r50_fpn_1x_coco.yml',
] ]
pretrain_weights: https://paddle-imagenet-models-name.bj.bcebos.com/ResNet34_vd_pretrained.tar pretrain_weights: https://paddledet.bj.bcebos.com/models/pretrained/ResNet34_vd_pretrained.pdparams
weights: output/faster_rcnn_r34_vd_fpn_1x_coco/model_final weights: output/faster_rcnn_r34_vd_fpn_1x_coco/model_final
ResNet: ResNet:
......
_BASE_: [ _BASE_: [
'faster_rcnn_r50_1x_coco.yml', 'faster_rcnn_r50_1x_coco.yml',
] ]
pretrain_weights: https://paddle-imagenet-models-name.bj.bcebos.com/ResNet50_vd_pretrained.tar pretrain_weights: https://paddledet.bj.bcebos.com/models/pretrained/ResNet50_vd_pretrained.pdparams
weights: output/faster_rcnn_r50_vd_1x_coco/model_final weights: output/faster_rcnn_r50_vd_1x_coco/model_final
ResNet: ResNet:
......
_BASE_: [ _BASE_: [
'faster_rcnn_r50_fpn_1x_coco.yml', 'faster_rcnn_r50_fpn_1x_coco.yml',
] ]
pretrain_weights: https://paddle-imagenet-models-name.bj.bcebos.com/ResNet50_vd_pretrained.tar pretrain_weights: https://paddledet.bj.bcebos.com/models/pretrained/ResNet50_vd_pretrained.pdparams
weights: output/faster_rcnn_r50_vd_fpn_1x_coco/model_final weights: output/faster_rcnn_r50_vd_fpn_1x_coco/model_final
ResNet: ResNet:
......
_BASE_: [ _BASE_: [
'faster_rcnn_r50_fpn_1x_coco.yml', 'faster_rcnn_r50_fpn_1x_coco.yml',
] ]
pretrain_weights: https://paddle-imagenet-models-name.bj.bcebos.com/ResNet50_vd_pretrained.tar pretrain_weights: https://paddledet.bj.bcebos.com/models/pretrained/ResNet50_vd_pretrained.pdparams
weights: output/faster_rcnn_r50_vd_fpn_2x_coco/model_final weights: output/faster_rcnn_r50_vd_fpn_2x_coco/model_final
ResNet: ResNet:
......
...@@ -2,7 +2,7 @@ _BASE_: [ ...@@ -2,7 +2,7 @@ _BASE_: [
'faster_rcnn_r50_fpn_1x_coco.yml', 'faster_rcnn_r50_fpn_1x_coco.yml',
] ]
pretrain_weights: https://paddle-imagenet-models-name.bj.bcebos.com/ResNeXt101_vd_64x4d_pretrained.tar pretrain_weights: https://paddledet.bj.bcebos.com/models/pretrained/ResNeXt101_vd_64x4d_pretrained.pdparams
weights: output/faster_rcnn_x101_vd_64x4d_fpn_1x_coco/model_final weights: output/faster_rcnn_x101_vd_64x4d_fpn_1x_coco/model_final
ResNet: ResNet:
......
...@@ -2,7 +2,7 @@ _BASE_: [ ...@@ -2,7 +2,7 @@ _BASE_: [
'faster_rcnn_r50_fpn_1x_coco.yml', 'faster_rcnn_r50_fpn_1x_coco.yml',
] ]
pretrain_weights: https://paddle-imagenet-models-name.bj.bcebos.com/ResNeXt101_vd_64x4d_pretrained.tar pretrain_weights: https://paddledet.bj.bcebos.com/models/pretrained/ResNeXt101_vd_64x4d_pretrained.pdparams
weights: output/faster_rcnn_x101_vd_64x4d_fpn_2x_coco/model_final weights: output/faster_rcnn_x101_vd_64x4d_fpn_2x_coco/model_final
ResNet: ResNet:
......
_BASE_: [ _BASE_: [
'mask_rcnn_r50_fpn_1x_coco.yml', 'mask_rcnn_r50_fpn_1x_coco.yml',
] ]
pretrain_weights: https://paddle-imagenet-models-name.bj.bcebos.com/ResNet101_pretrained.tar pretrain_weights: https://paddledet.bj.bcebos.com/models/pretrained/ResNet101_pretrained.pdparams
weights: output/mask_rcnn_r101_fpn_1x_coco/model_final weights: output/mask_rcnn_r101_fpn_1x_coco/model_final
ResNet: ResNet:
......
_BASE_: [ _BASE_: [
'mask_rcnn_r50_fpn_1x_coco.yml', 'mask_rcnn_r50_fpn_1x_coco.yml',
] ]
pretrain_weights: https://paddle-imagenet-models-name.bj.bcebos.com/ResNet101_vd_pretrained.tar pretrain_weights: https://paddledet.bj.bcebos.com/models/pretrained/ResNet101_vd_pretrained.pdparams
weights: output/mask_rcnn_r101_vd_fpn_1x_coco/model_final weights: output/mask_rcnn_r101_vd_fpn_1x_coco/model_final
ResNet: ResNet:
......
...@@ -2,7 +2,7 @@ _BASE_: [ ...@@ -2,7 +2,7 @@ _BASE_: [
'mask_rcnn_r50_fpn_1x_coco.yml', 'mask_rcnn_r50_fpn_1x_coco.yml',
] ]
pretrain_weights: https://paddle-imagenet-models-name.bj.bcebos.com/ResNet50_vd_pretrained.tar pretrain_weights: https://paddledet.bj.bcebos.com/models/pretrained/ResNet50_vd_pretrained.pdparams
weights: output/mask_rcnn_r50_vd_fpn_1x_coco/model_final weights: output/mask_rcnn_r50_vd_fpn_1x_coco/model_final
ResNet: ResNet:
......
...@@ -2,7 +2,7 @@ _BASE_: [ ...@@ -2,7 +2,7 @@ _BASE_: [
'mask_rcnn_r50_fpn_1x_coco.yml', 'mask_rcnn_r50_fpn_1x_coco.yml',
] ]
pretrain_weights: https://paddle-imagenet-models-name.bj.bcebos.com/ResNet50_vd_pretrained.tar pretrain_weights: https://paddledet.bj.bcebos.com/models/pretrained/ResNet50_vd_pretrained.pdparams
weights: output/mask_rcnn_r50_vd_fpn_2x_coco/model_final weights: output/mask_rcnn_r50_vd_fpn_2x_coco/model_final
ResNet: ResNet:
......
...@@ -2,7 +2,7 @@ _BASE_: [ ...@@ -2,7 +2,7 @@ _BASE_: [
'mask_rcnn_r50_fpn_1x_coco.yml', 'mask_rcnn_r50_fpn_1x_coco.yml',
] ]
pretrain_weights: https://paddle-imagenet-models-name.bj.bcebos.com/ResNeXt101_vd_64x4d_pretrained.tar pretrain_weights: https://paddledet.bj.bcebos.com/models/pretrained/ResNeXt101_vd_64x4d_pretrained.pdparams
weights: output/mask_rcnn_x101_vd_64x4d_fpn_1x_coco/model_final weights: output/mask_rcnn_x101_vd_64x4d_fpn_1x_coco/model_final
ResNet: ResNet:
......
...@@ -2,7 +2,7 @@ _BASE_: [ ...@@ -2,7 +2,7 @@ _BASE_: [
'mask_rcnn_r50_fpn_1x_coco.yml', 'mask_rcnn_r50_fpn_1x_coco.yml',
] ]
pretrain_weights: https://paddle-imagenet-models-name.bj.bcebos.com/ResNeXt101_vd_64x4d_pretrained.tar pretrain_weights: https://paddledet.bj.bcebos.com/models/pretrained/ResNeXt101_vd_64x4d_pretrained.pdparams
weights: output/mask_rcnn_x101_vd_64x4d_fpn_2x_coco/model_final weights: output/mask_rcnn_x101_vd_64x4d_fpn_2x_coco/model_final
ResNet: ResNet:
......
# Weights of yolov3_mobilenet_v1_voc # Weights of yolov3_mobilenet_v1_voc
pretrain_weights: https://paddlemodels.bj.bcebos.com/object_detection/dygraph/yolov3_mobilenet_v1_270e_voc.pdparams pretrain_weights: https://paddledet.bj.bcebos.com/models/yolov3_mobilenet_v1_270e_voc.pdparams
weight_type: resume weight_type: resume
slim: Pruner slim: Pruner
......
# Weights of yolov3_mobilenet_v1_voc # Weights of yolov3_mobilenet_v1_voc
pretrain_weights: https://paddlemodels.bj.bcebos.com/object_detection/dygraph/yolov3_mobilenet_v1_270e_voc.pdparams pretrain_weights: https://paddledet.bj.bcebos.com/models/yolov3_mobilenet_v1_270e_voc.pdparams
weight_type: resume weight_type: resume
slim: Pruner slim: Pruner
......
# Weights of yolov3_mobilenet_v1_coco # Weights of yolov3_mobilenet_v1_coco
pretrain_weights: https://paddlemodels.bj.bcebos.com/object_detection/dygraph/yolov3_mobilenet_v1_270e_coco.pdparams pretrain_weights: https://paddledet.bj.bcebos.com/models/yolov3_mobilenet_v1_270e_coco.pdparams
weight_type: resume weight_type: resume
slim: QAT slim: QAT
......
# Weights of yolov3_mobilenet_v3_coco # Weights of yolov3_mobilenet_v3_coco
pretrain_weights: https://paddlemodels.bj.bcebos.com/object_detection/dygraph/yolov3_mobilenet_v3_large_270e_coco.pdparams pretrain_weights: https://paddledet.bj.bcebos.com/models/yolov3_mobilenet_v3_large_270e_coco.pdparams
weight_type: resume weight_type: resume
slim: QAT slim: QAT
......
architecture: SSD architecture: SSD
pretrain_weights: https://paddlemodels.bj.bcebos.com/object_detection/dygraph/VGG16_caffe_pretrained.pdparams pretrain_weights: https://paddledet.bj.bcebos.com/models/pretrained/VGG16_caffe_pretrained.pdparams
# Model Achitecture # Model Achitecture
SSD: SSD:
......
...@@ -4,7 +4,7 @@ _BASE_: [ ...@@ -4,7 +4,7 @@ _BASE_: [
'_base_/yolov3_mobilenet_v1.yml', '_base_/yolov3_mobilenet_v1.yml',
'_base_/yolov3_reader.yml', '_base_/yolov3_reader.yml',
] ]
pretrain_weights: https://paddlemodels.bj.bcebos.com/object_detection/dygraph/yolov3_mobilenet_v1_270e_coco.pdparams pretrain_weights: https://paddledet.bj.bcebos.com/models/yolov3_mobilenet_v1_270e_coco.pdparams
norm_type: sync_bn norm_type: sync_bn
weights: output/yolov3_mobilenet_v1_roadsign/model_final weights: output/yolov3_mobilenet_v1_roadsign/model_final
metric: VOC metric: VOC
......
...@@ -58,7 +58,7 @@ class Trainer(object): ...@@ -58,7 +58,7 @@ class Trainer(object):
# model slim build # model slim build
if 'slim' in cfg and cfg.slim: if 'slim' in cfg and cfg.slim:
if self.mode == 'train': if self.mode == 'train':
self.load_weights(cfg.pretrain_weights, cfg.weight_type) self.load_weights(cfg.pretrain_weights)
self.slim = create(cfg.slim) self.slim = create(cfg.slim)
self.slim(self.model) self.slim(self.model)
...@@ -174,17 +174,14 @@ class Trainer(object): ...@@ -174,17 +174,14 @@ class Trainer(object):
"metrics shoule be instances of subclass of Metric" "metrics shoule be instances of subclass of Metric"
self._metrics.extend(metrics) self._metrics.extend(metrics)
def load_weights(self, weights, weight_type='pretrain'): def load_weights(self, weights):
assert weight_type in ['pretrain', 'resume', 'finetune'], \ self.start_epoch = 0
"weight_type can only be 'pretrain', 'resume', 'finetune'" load_pretrain_weight(self.model, weights)
if weight_type == 'resume': logger.debug("Load weights {} to start training".format(weights))
def resume_weights(self, weights):
self.start_epoch = load_weight(self.model, weights, self.optimizer) self.start_epoch = load_weight(self.model, weights, self.optimizer)
logger.debug("Resume weights of epoch {}".format(self.start_epoch)) logger.debug("Resume weights of epoch {}".format(self.start_epoch))
else:
self.start_epoch = 0
load_pretrain_weight(self.model, weights, weight_type)
logger.debug("Load {} weights {} to start training".format(
weight_type, weights))
def train(self, validate=False): def train(self, validate=False):
assert self.mode == 'train', "Model not in 'train' mode" assert self.mode == 'train', "Model not in 'train' mode"
......
...@@ -123,8 +123,7 @@ def load_weight(model, weight, optimizer=None): ...@@ -123,8 +123,7 @@ def load_weight(model, weight, optimizer=None):
assert incorrect_keys == 0, "Load weight {} incorrectly, \ assert incorrect_keys == 0, "Load weight {} incorrectly, \
{} keys unmatched, please check again.".format(weight, {} keys unmatched, please check again.".format(weight,
incorrect_keys) incorrect_keys)
logger.info('Finish loading model weight parameter: {}'.format( logger.info('Finish resuming model weights: {}'.format(pdparam_path))
pdparam_path))
model.set_dict(model_weight) model.set_dict(model_weight)
...@@ -142,9 +141,7 @@ def load_weight(model, weight, optimizer=None): ...@@ -142,9 +141,7 @@ def load_weight(model, weight, optimizer=None):
return last_epoch return last_epoch
def load_pretrain_weight(model, pretrain_weight, weight_type='pretrain'): def load_pretrain_weight(model, pretrain_weight):
assert weight_type in ['pretrain', 'finetune']
if is_url(pretrain_weight): if is_url(pretrain_weight):
pretrain_weight = get_weights_path_dist(pretrain_weight) pretrain_weight = get_weights_path_dist(pretrain_weight)
...@@ -158,24 +155,27 @@ def load_pretrain_weight(model, pretrain_weight, weight_type='pretrain'): ...@@ -158,24 +155,27 @@ def load_pretrain_weight(model, pretrain_weight, weight_type='pretrain'):
model_dict = model.state_dict() model_dict = model.state_dict()
param_state_dict = paddle.load(path + '.pdparams') weights_path = path + '.pdparams'
if weight_type == 'pretrain': param_state_dict = paddle.load(weights_path)
model.backbone.set_dict(param_state_dict)
else:
ignore_set = set() ignore_set = set()
lack_modules = set()
for name, weight in model_dict.items(): for name, weight in model_dict.items():
if name in param_state_dict.keys(): if name in param_state_dict.keys():
if weight.shape != list(param_state_dict[name].shape): if weight.shape != list(param_state_dict[name].shape):
logger.info( logger.info(
'{} not used, shape {} unmatched with {} in model.'. '{} not used, shape {} unmatched with {} in model.'.format(
format(name, name, list(param_state_dict[name].shape), weight.shape))
list(param_state_dict[name].shape),
weight.shape))
param_state_dict.pop(name, None) param_state_dict.pop(name, None)
else: else:
logger.info('Lack weight: {}'.format(name)) lack_modules.add(name.split('.')[0])
logger.debug('Lack weights: {}'.format(name))
if len(lack_modules) > 0:
logger.info('Lack weights of modules: {}'.format(', '.join(
list(lack_modules))))
model.set_dict(param_state_dict) model.set_dict(param_state_dict)
return logger.info('Finish loading model weights: {}'.format(weights_path))
def save_model(model, optimizer, save_dir, save_name, last_epoch): def save_model(model, optimizer, save_dir, save_name, last_epoch):
......
...@@ -92,7 +92,7 @@ DATASETS = { ...@@ -92,7 +92,7 @@ DATASETS = {
DOWNLOAD_RETRY_LIMIT = 3 DOWNLOAD_RETRY_LIMIT = 3
PPDET_WEIGHTS_DOWNLOAD_URL_PREFIX = 'https://paddlemodels.bj.bcebos.com/object_detection/' PPDET_WEIGHTS_DOWNLOAD_URL_PREFIX = 'https://paddledet.bj.bcebos.com/'
def parse_url(url): def parse_url(url):
......
...@@ -93,7 +93,7 @@ def run(FLAGS, cfg): ...@@ -93,7 +93,7 @@ def run(FLAGS, cfg):
trainer = Trainer(cfg, mode='eval') trainer = Trainer(cfg, mode='eval')
# load weights # load weights
trainer.load_weights(cfg.weights, 'resume') trainer.load_weights(cfg.weights)
# training # training
trainer.evaluate() trainer.evaluate()
......
...@@ -62,7 +62,7 @@ def run(FLAGS, cfg): ...@@ -62,7 +62,7 @@ def run(FLAGS, cfg):
trainer = Trainer(cfg, mode='test') trainer = Trainer(cfg, mode='test')
# load weights # load weights
trainer.load_weights(cfg.weights, 'resume') trainer.load_weights(cfg.weights)
# export model # export model
trainer.export(FLAGS.output_dir) trainer.export(FLAGS.output_dir)
......
...@@ -114,7 +114,7 @@ def run(FLAGS, cfg): ...@@ -114,7 +114,7 @@ def run(FLAGS, cfg):
trainer = Trainer(cfg, mode='test') trainer = Trainer(cfg, mode='test')
# load weights # load weights
trainer.load_weights(cfg.weights, 'resume') trainer.load_weights(cfg.weights)
# get inference images # get inference images
images = get_test_images(FLAGS.infer_dir, FLAGS.infer_img) images = get_test_images(FLAGS.infer_dir, FLAGS.infer_img)
......
...@@ -43,17 +43,13 @@ logger = setup_logger('train') ...@@ -43,17 +43,13 @@ logger = setup_logger('train')
def parse_args(): def parse_args():
parser = cli.ArgsParser() parser = cli.ArgsParser()
parser.add_argument(
"--weight_type",
default='pretrain',
type=str,
help="Loading Checkpoints only support 'pretrain', 'finetune', 'resume'."
)
parser.add_argument( parser.add_argument(
"--eval", "--eval",
action='store_true', action='store_true',
default=False, default=False,
help="Whether to perform evaluation in train") help="Whether to perform evaluation in train")
parser.add_argument(
"-r", "--resume", default=None, help="weights path for resume")
parser.add_argument( parser.add_argument(
"--slim_config", "--slim_config",
default=None, default=None,
...@@ -101,8 +97,10 @@ def run(FLAGS, cfg): ...@@ -101,8 +97,10 @@ def run(FLAGS, cfg):
trainer = Trainer(cfg, mode='train') trainer = Trainer(cfg, mode='train')
# load weights # load weights
if not FLAGS.slim_config and 'pretrain_weights' in cfg and cfg.pretrain_weights: if FLAGS.resume is not None:
trainer.load_weights(cfg.pretrain_weights, FLAGS.weight_type) trainer.resume_weights(FLAGS.resume)
elif not FLAGS.slim_config and 'pretrain_weights' in cfg and cfg.pretrain_weights:
trainer.load_weights(cfg.pretrain_weights)
# training # training
trainer.train(FLAGS.eval) trainer.train(FLAGS.eval)
...@@ -120,8 +118,6 @@ def main(): ...@@ -120,8 +118,6 @@ def main():
if FLAGS.slim_config: if FLAGS.slim_config:
slim_cfg = load_config(FLAGS.slim_config) slim_cfg = load_config(FLAGS.slim_config)
merge_config(slim_cfg) merge_config(slim_cfg)
if 'weight_type' not in cfg:
cfg.weight_type = FLAGS.weight_type
check.check_config(cfg) check.check_config(cfg)
check.check_gpu(cfg.use_gpu) check.check_gpu(cfg.use_gpu)
check.check_version() check.check_version()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册