Unverified commit 17bbda7f, authored by qingqing01, committed by GitHub

Refine and clean code (#1902)

* Refine and clean code
* Fix conflicts
Co-authored-by: Kaipeng Deng <dengkaipeng@baidu.com>
Parent 531242be
@@ -24,6 +24,11 @@
 import yaml
 import copy
 import collections
+try:
+    collectionsAbc = collections.abc
+except AttributeError:
+    collectionsAbc = collections

 from .config.schema import SchemaDict, SharedConfig, extract_schema
 from .config.yaml_helpers import serializable
@@ -143,7 +148,7 @@ def dict_merge(dct, merge_dct):
     """
     for k, v in merge_dct.items():
         if (k in dct and isinstance(dct[k], dict) and
-                isinstance(merge_dct[k], collections.Mapping)):
+                isinstance(merge_dct[k], collectionsAbc.Mapping)):
             dict_merge(dct[k], merge_dct[k])
         else:
             dct[k] = merge_dct[k]
......
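Note: the try/except shim added above guards against Python 3.10+, where the `collections.Mapping` aliases were removed (the ABCs have lived in `collections.abc` since Python 3.3). A minimal sketch of how the shim resolves on each interpreter:

```python
# Minimal check of the compatibility shim (Python 2 and 3).
import collections

try:
    collectionsAbc = collections.abc  # Python 3.3+
except AttributeError:
    collectionsAbc = collections  # Python 2: ABCs live on collections itself

# dict passes the Mapping check either way, so dict_merge still recurses.
print(isinstance({'a': 1}, collectionsAbc.Mapping))  # True
```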
@@ -14,5 +14,4 @@
 from .source import *
 from .transform import *
-from .sampler import *
 from .reader import *
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
 import copy
 import traceback
-import logging
-import threading
 import six
 import sys
 if sys.version_info >= (3, 0):
@@ -9,13 +21,16 @@ if sys.version_info >= (3, 0):
 else:
     import Queue
 import numpy as np

 from paddle.io import DataLoader
+from paddle.io import DistributedBatchSampler
 from ppdet.core.workspace import register, serializable, create
-from .sampler import DistributedBatchSampler
 from . import transform
 from .transform import operator, batch_operator

-logger = logging.getLogger(__name__)
+from ppdet.utils.logger import setup_logger
+logger = setup_logger('reader')
 class Compose(object):
@@ -118,7 +133,6 @@ class BaseDataLoader(object):
     def __call__(self,
                  dataset,
                  worker_num,
-                 device=None,
                  batch_sampler=None,
                  return_list=False,
                  use_prefetch=True):
@@ -144,7 +158,6 @@ class BaseDataLoader(object):
             batch_sampler=self._batch_sampler,
             collate_fn=self._batch_transforms,
             num_workers=worker_num,
-            places=device,
             return_list=return_list,
             use_buffer_reader=use_prefetch,
             use_shared_memory=False)
......
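Note: with `device`/`places` dropped, `paddle.io.DataLoader` falls back to the device selected via `paddle.set_device`. A hypothetical call site under that assumption (`dataset` built elsewhere):

```python
import paddle
from ppdet.core.workspace import create

paddle.set_device('gpu:0')  # DataLoader now infers places from the global device
eval_loader = create('EvalReader')(dataset, 2)  # worker_num; no device argument
for data in eval_loader:
    pass  # batches arrive on the selected device
```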
-import os
-import sys
-import six
-import time
-import math
-import socket
-import contextlib
-import numpy as np
-
-from paddle.io import BatchSampler
-from paddle.distributed import ParallelEnv
-
-_parallel_context_initialized = False
-
-
-class DistributedBatchSampler(BatchSampler):
-    def __init__(self, dataset, batch_size, shuffle=False, drop_last=False):
-        self.dataset = dataset
-
-        assert isinstance(batch_size, int) and batch_size > 0, \
-            "batch_size should be a positive integer"
-        self.batch_size = batch_size
-        assert isinstance(shuffle, bool), \
-            "shuffle should be a boolean value"
-        self.shuffle = shuffle
-        assert isinstance(drop_last, bool), \
-            "drop_last should be a boolean value"
-        self.drop_last = drop_last
-
-        self.nranks = ParallelEnv().nranks
-        self.local_rank = ParallelEnv().local_rank
-        self.epoch = 0
-        self.num_samples = int(math.ceil(len(self.dataset) * 1.0 / self.nranks))
-        self.total_size = self.num_samples * self.nranks
-
-    def __iter__(self):
-        num_samples = len(self.dataset)
-        indices = np.arange(num_samples).tolist()
-        indices += indices[:(self.total_size - len(indices))]
-        assert len(indices) == self.total_size
-        if self.shuffle:
-            np.random.RandomState(self.epoch).shuffle(indices)
-            self.epoch += 1
-
-        # subsample
-        def _get_indices_by_batch_size(indices):
-            subsampled_indices = []
-            last_batch_size = self.total_size % (self.batch_size * self.nranks)
-            assert last_batch_size % self.nranks == 0
-            last_local_batch_size = last_batch_size // self.nranks
-
-            for i in range(self.local_rank * self.batch_size,
-                           len(indices) - last_batch_size,
-                           self.batch_size * self.nranks):
-                subsampled_indices.extend(indices[i:i + self.batch_size])
-
-            indices = indices[len(indices) - last_batch_size:]
-            subsampled_indices.extend(indices[
-                self.local_rank * last_local_batch_size:(
-                    self.local_rank + 1) * last_local_batch_size])
-            return subsampled_indices
-
-        if self.nranks > 1:
-            indices = _get_indices_by_batch_size(indices)
-
-        assert len(indices) == self.num_samples
-        _sample_iter = iter(indices)
-
-        batch_indices = []
-        for idx in _sample_iter:
-            batch_indices.append(idx)
-            if len(batch_indices) == self.batch_size:
-                yield batch_indices
-                batch_indices = []
-        if not self.drop_last and len(batch_indices) > 0:
-            yield batch_indices
-
-    def __len__(self):
-        num_samples = self.num_samples
-        num_samples += int(not self.drop_last) * (self.batch_size - 1)
-        return num_samples // self.batch_size
-
-    def set_epoch(self, epoch):
-        self.epoch = epoch
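Note: the whole file above is deleted; since Paddle 2.0, `paddle.io` ships a `DistributedBatchSampler` with the same constructor and `set_epoch` semantics, which `reader.py` now imports instead. A usage sketch with a hypothetical toy dataset:

```python
import paddle
from paddle.io import Dataset, DataLoader, DistributedBatchSampler

class ToyDataset(Dataset):  # hypothetical stand-in for a detection dataset
    def __len__(self):
        return 10

    def __getitem__(self, idx):
        return paddle.to_tensor([idx], dtype='float32')

ds = ToyDataset()
sampler = DistributedBatchSampler(ds, batch_size=4, shuffle=True)
loader = DataLoader(ds, batch_sampler=sampler)
for epoch in range(2):
    sampler.set_epoch(epoch)  # deterministic per-epoch reshuffle, as in the deleted code
    for batch in loader:
        pass
```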
@@ -60,7 +60,7 @@ class COCODataSet(DetDataset):
         if 'annotations' not in coco.dataset:
             self.load_image_only = True
-            logger.warn('Annotation file: {} does not contains ground truth '
-                        'and load image information only.'.format(anno_path))
+            logger.warning('Annotation file: {} does not contain ground truth '
+                           'and load image information only.'.format(anno_path))

         for img_id in img_ids:
@@ -72,14 +72,14 @@ class COCODataSet(DetDataset):
             im_path = os.path.join(image_dir,
                                    im_fname) if image_dir else im_fname
             if not os.path.exists(im_path):
-                logger.warn('Illegal image file: {}, and it will be '
-                            'ignored'.format(im_path))
+                logger.warning('Illegal image file: {}, and it will be '
+                               'ignored'.format(im_path))
                 continue

             if im_w < 0 or im_h < 0:
-                logger.warn('Illegal width: {} or height: {} in annotation, '
-                            'and im_id: {} will be ignored'.format(im_w, im_h,
-                                                                   img_id))
+                logger.warning('Illegal width: {} or height: {} in annotation, '
+                               'and im_id: {} will be ignored'.format(
+                                   im_w, im_h, img_id))
                 continue

             coco_rec = {
@@ -110,7 +110,7 @@ class COCODataSet(DetDataset):
                     inst['clean_bbox'] = [x1, y1, x2, y2]
                     bboxes.append(inst)
                 else:
-                    logger.warn(
+                    logger.warning(
                         'Found an invalid bbox in annotations: im_id: {}, '
                         'area: {} x1: {}, y1: {}, x2: {}, y2: {}.'.format(
                             img_id, float(inst['area']), x1, y1, x2, y2))
......
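Note: `Logger.warn` is a deprecated alias of `Logger.warning` in the standard `logging` module (it raises a `DeprecationWarning` on recent Python 3), hence the blanket rename:

```python
import logging

logger = logging.getLogger('demo')
logger.warning('preferred spelling')  # logger.warn(...) is the deprecated alias
```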
@@ -40,7 +40,7 @@ class DetDataset(Dataset):
         self.image_dir = image_dir if image_dir is not None else ''
         self.sample_num = sample_num
         self.use_default_label = use_default_label
-        self.epoch = 0
+        self._epoch = 0

     def __len__(self, ):
         return len(self.roidbs)
@@ -48,15 +48,15 @@ class DetDataset(Dataset):
     def __getitem__(self, idx):
         # data batch
         roidb = copy.deepcopy(self.roidbs[idx])
-        if self.mixup_epoch == 0 or self.epoch < self.mixup_epoch:
+        if self.mixup_epoch == 0 or self._epoch < self.mixup_epoch:
             n = len(self.roidbs)
             idx = np.random.randint(n)
             roidb = [roidb, copy.deepcopy(self.roidbs[idx])]
-        elif self.cutmix_epoch == 0 or self.epoch < self.cutmix_epoch:
+        elif self.cutmix_epoch == 0 or self._epoch < self.cutmix_epoch:
             n = len(self.roidbs)
             idx = np.random.randint(n)
             roidb = [roidb, copy.deepcopy(self.roidbs[idx])]
-        elif self.mosaic_epoch == 0 or self.epoch < self.mosaic_epoch:
+        elif self.mosaic_epoch == 0 or self._epoch < self.mosaic_epoch:
             n = len(self.roidbs)
             roidb = [roidb, ] + [
                 copy.deepcopy(self.roidbs[np.random.randint(n)])
@@ -76,6 +76,9 @@ class DetDataset(Dataset):
         self.cutmix_epoch = kwargs.get('cutmix_epoch', -1)
         self.mosaic_epoch = kwargs.get('mosaic_epoch', -1)

+    def set_epoch(self, epoch_id):
+        self._epoch = epoch_id
+
     def set_out(self, sample_transform, fields):
         self.transform = sample_transform
         self.fields = fields
......
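Note: the epoch counter becomes the private `_epoch`, mutated only through `set_epoch`; the training loop further down calls it instead of poking the attribute directly:

```python
# Sketch of the call pattern (mirrors the train.py hunk below).
for cur_eid in range(start_epoch, end_epoch):
    dataset.set_epoch(cur_eid)  # replaces `train_loader.dataset.epoch = cur_eid`
    for iter_id, data in enumerate(train_loader):
        ...
```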
@@ -212,8 +212,11 @@ class BBoxHead(nn.Layer):
     def get_loss(self, bbox_head_out, targets):
         loss_bbox = {}
+        cls_name = 'loss_bbox_cls'
+        reg_name = 'loss_bbox_reg'
         for lvl, (bboxhead, target) in enumerate(zip(bbox_head_out, targets)):
             score, delta = bboxhead
-            cls_name = 'loss_bbox_cls_{}'.format(lvl)
-            reg_name = 'loss_bbox_reg_{}'.format(lvl)
+            if len(targets) > 1:
+                cls_name = 'loss_bbox_cls_{}'.format(lvl)
+                reg_name = 'loss_bbox_reg_{}'.format(lvl)
             loss_bbox_cls, loss_bbox_reg = self._get_head_loss(score, delta,
......
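Note: a single-stage head now keeps the stable keys `loss_bbox_cls`/`loss_bbox_reg`, while multi-stage heads (e.g. cascade) still get per-level suffixes. A hypothetical illustration of the resulting key names:

```python
def loss_keys(num_levels):
    # mirrors the naming logic in get_loss (illustration only)
    keys = []
    for lvl in range(num_levels):
        suffix = '_{}'.format(lvl) if num_levels > 1 else ''
        keys += ['loss_bbox_cls' + suffix, 'loss_bbox_reg' + suffix]
    return keys

print(loss_keys(1))  # ['loss_bbox_cls', 'loss_bbox_reg']
print(loss_keys(3))  # ['loss_bbox_cls_0', 'loss_bbox_reg_0', ..., 'loss_bbox_reg_2']
```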
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import functools
+import logging
+import os
+import sys
+
+from paddle.distributed import ParallelEnv
+
+__all__ = ['setup_logger']
+
+logger_initialized = []
+
+
+def setup_logger(name="ppdet", output=None):
+    """
+    Initialize logger and set its verbosity level to INFO.
+    Args:
+        name (str): the root module name of this logger
+        output (str): a file name or a directory to save log. If None, will not save log file.
+            If ends with ".txt" or ".log", assumed to be a file name.
+            Otherwise, logs will be saved to `output/log.txt`.
+    Returns:
+        logging.Logger: a logger
+    """
+    logger = logging.getLogger(name)
+    if name in logger_initialized:
+        return logger
+
+    logger.setLevel(logging.INFO)
+    logger.propagate = False
+
+    formatter = logging.Formatter(
+        "[%(asctime)s] %(name)s %(levelname)s: %(message)s",
+        datefmt="%m/%d %H:%M:%S")
+    # stdout logging: master only
+    local_rank = ParallelEnv().local_rank
+    if local_rank == 0:
+        ch = logging.StreamHandler(stream=sys.stdout)
+        ch.setLevel(logging.DEBUG)
+        ch.setFormatter(formatter)
+        logger.addHandler(ch)
+
+    # file logging: all workers
+    if output is not None:
+        if output.endswith(".txt") or output.endswith(".log"):
+            filename = output
+        else:
+            filename = os.path.join(output, "log.txt")
+        if local_rank > 0:
+            filename = filename + ".rank{}".format(local_rank)
+        os.makedirs(os.path.dirname(filename), exist_ok=True)
+        fh = logging.FileHandler(filename, mode='a')
+        fh.setLevel(logging.DEBUG)
+        fh.setFormatter(formatter)
+        logger.addHandler(fh)
+    logger_initialized.append(name)
+    return logger
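Note: `setup_logger` caches by name, logs to stdout only on rank 0, and optionally tees to a file (suffixed `.rank{N}` on non-master workers). A usage sketch:

```python
from ppdet.utils.logger import setup_logger

logger = setup_logger('train')                     # stdout handler on rank 0 only
logger.info('training started')
eval_logger = setup_logger('eval', output='logs')  # also writes logs/log.txt
```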
@@ -16,7 +16,7 @@
 import collections
 import numpy as np
 import datetime

-__all__ = ['TrainingStats', 'Time']
+__all__ = ['SmoothedValue', 'TrainingStats']


 class SmoothedValue(object):
@@ -24,42 +24,72 @@ class SmoothedValue(object):
     window or the global series average.
     """

-    def __init__(self, window_size):
+    def __init__(self, window_size=20, fmt=None):
+        if fmt is None:
+            fmt = "{median:.4f} ({avg:.4f})"
         self.deque = collections.deque(maxlen=window_size)
+        self.fmt = fmt
+        self.total = 0.
+        self.count = 0

-    def add_value(self, value):
+    def update(self, value, n=1):
         self.deque.append(value)
+        self.count += n
+        self.total += value * n

-    def get_median_value(self):
+    @property
+    def median(self):
         return np.median(self.deque)

+    @property
+    def avg(self):
+        return np.mean(self.deque)

-def Time():
-    return datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S.%f')
+    @property
+    def max(self):
+        return np.max(self.deque)
+
+    @property
+    def value(self):
+        return self.deque[-1]
+
+    @property
+    def global_avg(self):
+        return self.total / self.count
+
+    def __str__(self):
+        return self.fmt.format(
+            median=self.median, avg=self.avg, max=self.max, value=self.value)


 class TrainingStats(object):
-    def __init__(self, window_size, stats_keys):
-        self.smoothed_losses_and_metrics = {
-            key: SmoothedValue(window_size)
-            for key in stats_keys
-        }
+    def __init__(self, window_size, delimiter=' '):
+        self.meters = None
+        self.window_size = window_size
+        self.delimiter = delimiter

     def update(self, stats):
-        for k, v in self.smoothed_losses_and_metrics.items():
-            v.add_value(stats[k].numpy())
+        if self.meters is None:
+            self.meters = {
+                k: SmoothedValue(self.window_size)
+                for k in stats.keys()
+            }
+        for k, v in self.meters.items():
+            v.update(stats[k].numpy())

     def get(self, extras=None):
         stats = collections.OrderedDict()
         if extras:
             for k, v in extras.items():
                 stats[k] = v
-        for k, v in self.smoothed_losses_and_metrics.items():
-            stats[k] = format(v.get_median_value(), '.6f')
+        for k, v in self.meters.items():
+            stats[k] = format(v.median, '.4f')
         return stats

     def log(self, extras=None):
         d = self.get(extras)
-        strs = ', '.join(str(dict({x: y})).strip('{}') for x, y in d.items())
-        return strs
+        strs = []
+        for k, v in d.items():
+            strs.append("{}: {}".format(k, str(v)))
+        return self.delimiter.join(strs)
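Note: `SmoothedValue` now tracks both a sliding window and global totals, and `TrainingStats` builds its meters lazily from the first `update` call. A quick sketch of the meter arithmetic:

```python
meter = SmoothedValue(window_size=3, fmt='{avg:.2f}')
for v in [1.0, 2.0, 3.0, 4.0]:
    meter.update(v)

print(meter.median)      # 3.0 -- median of the last 3 values [2, 3, 4]
print(meter.global_avg)  # 2.5 -- running mean over all 4 updates
print(str(meter))        # '3.00' via the fmt string
```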
@@ -65,9 +65,9 @@ def _walk_voc_dir(devkit_dir, year, output_dir):
     for _, _, files in os.walk(filelist_dir):
         for fname in files:
             img_ann_list = []
-            if re.match('[a-z]+_trainval\.txt', fname):
+            if re.match(r'[a-z]+_trainval\.txt', fname):
                 img_ann_list = trainval_list
-            elif re.match('[a-z]+_test\.txt', fname):
+            elif re.match(r'[a-z]+_test\.txt', fname):
                 img_ann_list = test_list
             else:
                 continue
......
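Note: without the `r` prefix, `'\.'` is an invalid string escape (a `DeprecationWarning` since Python 3.6, later a `SyntaxWarning`); raw strings hand the backslash to `re` untouched:

```python
import re

# Both spellings match today, but only the raw string is future-proof.
assert re.match(r'[a-z]+_trainval\.txt', 'voc_trainval.txt')
```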
@@ -21,23 +21,23 @@ parent_path = os.path.abspath(os.path.join(__file__, *(['..'] * 2)))
 if parent_path not in sys.path:
     sys.path.append(parent_path)

+import time
 # ignore numba warning
 import warnings
 warnings.filterwarnings('ignore')
 import random
 import numpy as np

 import paddle
-import time
 from paddle.distributed import ParallelEnv

 from ppdet.core.workspace import load_config, merge_config, create
 from ppdet.utils.check import check_gpu, check_version, check_config
 from ppdet.utils.cli import ArgsParser
 from ppdet.utils.eval_utils import get_infer_results, eval_results
 from ppdet.utils.checkpoint import load_weight
-import logging
-FORMAT = '%(asctime)s-%(levelname)s: %(message)s'
-logging.basicConfig(level=logging.INFO, format=FORMAT)
-logger = logging.getLogger(__name__)
+from ppdet.utils.logger import setup_logger
+logger = setup_logger('eval')


 def parse_args():
@@ -69,16 +69,15 @@ def run(FLAGS, cfg, place):
     # Data Reader
     dataset = cfg.EvalDataset
-    eval_loader = create('EvalReader')(dataset, cfg['worker_num'], place)
+    eval_loader = create('EvalReader')(dataset, cfg['worker_num'])

     extra_key = ['im_shape', 'scale_factor', 'im_id']
     if cfg.metric == 'VOC':
         extra_key += ['gt_bbox', 'gt_class', 'difficult']

     # Run Eval
     outs_res = []
-    start_time = time.time()
     sample_num = 0
+    start_time = time.time()
     for iter_id, data in enumerate(eval_loader):
         # forward
         model.eval()
......
@@ -27,19 +27,20 @@ warnings.filterwarnings('ignore')
 import glob
 import numpy as np
 from PIL import Image

 import paddle
+import paddle.nn as nn
+from paddle.static import InputSpec
 from ppdet.core.workspace import load_config, merge_config, create
 from ppdet.utils.check import check_gpu, check_version, check_config
 from ppdet.utils.cli import ArgsParser
 from ppdet.utils.checkpoint import load_weight
 from export_utils import dump_infer_config
 from paddle.jit import to_static
-import paddle.nn as nn
-from paddle.static import InputSpec
-import logging
-FORMAT = '%(asctime)s-%(levelname)s: %(message)s'
-logging.basicConfig(level=logging.INFO, format=FORMAT)
-logger = logging.getLogger(__name__)
+
+from ppdet.utils.logger import setup_logger
+logger = setup_logger('eval')


 def parse_args():
......
@@ -21,10 +21,10 @@ import yaml
 import numpy as np
 from collections import OrderedDict

-import logging
-logger = logging.getLogger(__name__)
+from ppdet.utils.logger import setup_logger
+logger = setup_logger('export_utils')

-__all__ = ['dump_infer_config', 'save_infer_model']
+__all__ = ['dump_infer_config']

 # Global dictionary
 TRT_MIN_SUBGRAPH = {
......
@@ -27,6 +27,7 @@ warnings.filterwarnings('ignore')
 import glob
 import numpy as np
 from PIL import Image
+
 import paddle
 from paddle.distributed import ParallelEnv
 from ppdet.core.workspace import load_config, merge_config, create
@@ -35,10 +36,9 @@ from ppdet.utils.visualizer import visualize_results
 from ppdet.utils.cli import ArgsParser
 from ppdet.utils.checkpoint import load_weight
 from ppdet.utils.eval_utils import get_infer_results
-import logging
-FORMAT = '%(asctime)s-%(levelname)s: %(message)s'
-logging.basicConfig(level=logging.INFO, format=FORMAT)
-logger = logging.getLogger(__name__)
+
+from ppdet.utils.logger import setup_logger
+logger = setup_logger('train')


 def parse_args():
@@ -129,8 +129,7 @@ def run(FLAGS, cfg, place):
     dataset = cfg.TestDataset
     test_images = get_test_images(FLAGS.infer_dir, FLAGS.infer_img)
     dataset.set_images(test_images)
-    test_loader = create('TestReader')(dataset, cfg['worker_num'], place)
+    test_loader = create('TestReader')(dataset, cfg['worker_num'])

     extra_key = ['im_shape', 'scale_factor', 'im_id']
     if cfg.metric == 'VOC':
         extra_key += ['gt_bbox', 'gt_class', 'difficult']
......
@@ -15,38 +15,36 @@
 from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function

 import os, sys
 # add python path of PaddleDetection to sys.path
 parent_path = os.path.abspath(os.path.join(__file__, *(['..'] * 2)))
 if parent_path not in sys.path:
     sys.path.append(parent_path)

+import time
 # ignore numba warning
 import warnings
 warnings.filterwarnings('ignore')
 import random
 import datetime
-import time
 import numpy as np
-from collections import deque

 import paddle
+from paddle.distributed import ParallelEnv

 from ppdet.core.workspace import load_config, merge_config, create
-from ppdet.utils.stats import TrainingStats
-from ppdet.utils.check import check_gpu, check_version, check_config
-from ppdet.utils.cli import ArgsParser
 from ppdet.utils.checkpoint import load_weight, load_pretrain_weight, save_model
-from export_model import dygraph_to_static
-from paddle.distributed import ParallelEnv
-import logging
-FORMAT = '%(asctime)s-%(levelname)s: %(message)s'
-logging.basicConfig(level=logging.INFO, format=FORMAT)
-logger = logging.getLogger(__name__)
+import ppdet.utils.cli as cli
+import ppdet.utils.check as check
+import ppdet.utils.stats as stats
+from ppdet.utils.logger import setup_logger
+logger = setup_logger('train')


 def parse_args():
-    parser = ArgsParser()
+    parser = cli.ArgsParser()
     parser.add_argument(
         "--weight_type",
         default='pretrain',
@@ -102,15 +100,15 @@ def run(FLAGS, cfg, place):
         paddle.distributed.init_parallel_env()

     # Data
-    dataset = cfg.TrainDataset
-    train_loader = create('TrainReader')(dataset, cfg['worker_num'], place)
-    step_per_epoch = len(train_loader)
+    datasets = cfg.TrainDataset
+    train_loader = create('TrainReader')(datasets, cfg['worker_num'])
+    steps = len(train_loader)

     # Model
     model = create(cfg.architecture)

     # Optimizer
-    lr = create('LearningRate')(step_per_epoch)
+    lr = create('LearningRate')(steps)
     optimizer = create('OptimizerBuilder')(lr, model.parameters())

     # Init Model & Optimizer
@@ -136,56 +134,70 @@ def run(FLAGS, cfg, place):
     cfg_name = os.path.basename(FLAGS.config).split('.')[0]
     save_dir = os.path.join(cfg.save_dir, cfg_name)

     # Run Train
-    time_stat = deque(maxlen=cfg.log_iter)
-    start_time = time.time()
+    end_epoch = int(cfg.epoch)
+    batch_size = int(cfg['TrainReader']['batch_size'])
+    total_steps = (end_epoch - start_epoch) * steps
+    step_id = 0
+
+    train_stats = stats.TrainingStats(cfg.log_iter)
+    batch_time = stats.SmoothedValue(fmt='{avg:.4f}')
+    data_time = stats.SmoothedValue(fmt='{avg:.4f}')
     end_time = time.time()
+
+    space_fmt = ':' + str(len(str(steps))) + 'd'
     # Run Train
-    for cur_eid in range(start_epoch, int(cfg.epoch)):
-        train_loader.dataset.epoch = cur_eid
+    for cur_eid in range(start_epoch, end_epoch):
+        datasets.set_epoch(cur_eid)
         for iter_id, data in enumerate(train_loader):
-            start_time = end_time
-            end_time = time.time()
-            time_stat.append(end_time - start_time)
-            time_cost = np.mean(time_stat)
-            eta_sec = (
-                (cfg.epoch - cur_eid) * step_per_epoch - iter_id) * time_cost
-            eta = str(datetime.timedelta(seconds=int(eta_sec)))
+            data_time.update(time.time() - end_time)
             # Model Forward
             model.train()
             outputs = model(data, mode='train')
             loss = outputs['loss']
             # Model Backward
             loss.backward()
             optimizer.step()
             curr_lr = optimizer.get_lr()
             lr.step()
             optimizer.clear_grad()
+            batch_time.update(time.time() - end_time)

             if ParallelEnv().nranks < 2 or ParallelEnv().local_rank == 0:
-                # Log state
-                if cur_eid == start_epoch and iter_id == 0:
-                    train_stats = TrainingStats(cfg.log_iter, outputs.keys())
                 train_stats.update(outputs)
                 logs = train_stats.log()
                 if iter_id % cfg.log_iter == 0:
-                    ips = float(cfg['TrainReader']['batch_size']) / time_cost
-                    strs = 'Epoch:{}: iter: {}, lr: {:.6f}, {}, eta: {}, batch_cost: {:.5f} sec, ips: {:.5f} images/sec'.format(
-                        cur_eid, iter_id, curr_lr, logs, eta, time_cost, ips)
-                    logger.info(strs)
+                    eta_sec = (total_steps - step_id) * batch_time.global_avg
+                    eta_str = str(datetime.timedelta(seconds=int(eta_sec)))
+                    ips = float(batch_size) / batch_time.avg
+                    fmt = ' '.join([
+                        'Epoch: [{}]',
+                        '[{' + space_fmt + '}/{}]',
+                        '{meters}',
+                        'eta: {eta}',
+                        'batch_cost: {btime}',
+                        'data_cost: {dtime}',
+                        'ips: {ips:.4f} images/s',
+                    ])
+                    fmt = fmt.format(
+                        cur_eid,
+                        iter_id,
+                        steps,
+                        meters=logs,
+                        eta=eta_str,
+                        btime=str(batch_time),
+                        dtime=str(data_time),
+                        ips=ips)
+                    logger.info(fmt)
+            step_id += 1
+            end_time = time.time()  # after copy outputs to CPU.

         # Save Stage
-        if ParallelEnv().local_rank == 0 and (
-                cur_eid % cfg.snapshot_epoch == 0 or
-                (cur_eid + 1) == int(cfg.epoch)):
-            save_name = str(cur_eid) if cur_eid + 1 != int(
-                cfg.epoch) else "model_final"
+        if (ParallelEnv().local_rank == 0 and \
+                (cur_eid % cfg.snapshot_epoch) == 0) or (cur_eid + 1) == end_epoch:
+            save_name = str(
+                cur_eid) if cur_eid + 1 != end_epoch else "model_final"
             save_model(model, optimizer, save_dir, save_name, cur_eid + 1)
-    # TODO(guanghua): dygraph model to static model
-    # if ParallelEnv().local_rank == 0 and (cur_eid + 1) == int(cfg.epoch)):
-    #     dygraph_to_static(model, os.path.join(save_dir, 'static_model_final'), cfg)


 def main():
@@ -193,9 +205,9 @@ def main():
     cfg = load_config(FLAGS.config)
     merge_config(FLAGS.opt)
-    check_config(cfg)
-    check_gpu(cfg.use_gpu)
-    check_version()
+    check.check_config(cfg)
+    check.check_gpu(cfg.use_gpu)
+    check.check_version()

     place = 'gpu:{}'.format(ParallelEnv().dev_id) if cfg.use_gpu else 'cpu'
     place = paddle.set_device(place)
......
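Note: the refactored loop derives ETA from the global average batch time and throughput (ips) from the windowed average. The arithmetic, with hypothetical numbers:

```python
import datetime

total_steps, step_id = 5000, 1200
global_avg_batch_time = 0.25  # seconds, i.e. batch_time.global_avg
batch_size = 8

eta_sec = (total_steps - step_id) * global_avg_batch_time
print(str(datetime.timedelta(seconds=int(eta_sec))))  # 0:15:50
print(batch_size / 0.25)                              # ips: 32.0 images/s
```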