Unverified commit 48e21f3c, authored by G Guanghua Yu, committed by GitHub

update reader in dygraph (#1687)

* update reader in dygraph

* update dataloader some detail

* fix num_classes with background
Parent fab94925
max_iters: 180000
epoch: 12
LearningRate:
base_lr: 0.01
schedulers:
- !PiecewiseDecay
gamma: 0.1
milestones: [120000, 160000]
milestones: [8, 11]
- !LinearWarmup
start_factor: 0.3333333333333333
steps: 500
......
max_iters: 500000
epoch: 270
LearningRate:
base_lr: 0.001
......@@ -6,8 +6,8 @@ LearningRate:
- !PiecewiseDecay
gamma: 0.1
milestones:
- 400000
- 450000
- 216
- 243
- !LinearWarmup
start_factor: 0.
steps: 4000
......
worker_num: 2
TrainReader:
inputs_def:
fields: ['image', 'im_info', 'im_id', 'gt_bbox', 'gt_class', 'is_crowd', 'gt_mask']
fields: ['image', 'im_info', 'im_id', 'gt_bbox', 'gt_class', 'is_crowd', 'gt_poly']
sample_transforms:
- !DecodeImage
to_rgb: true
- !RandomFlipImage
prob: 0.5
is_mask_flip: true
- !NormalizeImage
is_channel_first: false
is_scale: true
mean: [0.485,0.456,0.406]
std: [0.229, 0.224,0.225]
- !ResizeImage
target_size: 800
max_size: 1333
interp: 1
use_cv2: true
- !Permute
to_bgr: false
channel_first: true
- DecodeImage: {to_rgb: true}
- RandomFlipImage: {prob: 0.5, is_mask_flip: true}
- NormalizeImage: {is_channel_first: false, is_scale: true, mean: [0.485,0.456,0.406], std: [0.229, 0.224,0.225]}
- ResizeImage: {target_size: 800, max_size: 1333, interp: 1, use_cv2: true}
- Permute: {to_bgr: false, channel_first: true}
batch_transforms:
- !PadBatch
pad_to_stride: 32
use_padded_im_info: false
pad_gt: True
- PadBatch: {pad_to_stride: 32, use_padded_im_info: false, pad_gt: true}
batch_size: 1
shuffle: true
worker_num: 2
drop_last: false
use_process: false
EvalReader:
inputs_def:
fields: ['image', 'im_info', 'im_id', 'im_shape']
fields: ['image', 'im_info', 'im_id']
sample_transforms:
- !DecodeImage
to_rgb: true
- !NormalizeImage
is_channel_first: false
is_scale: true
mean: [0.485,0.456,0.406]
std: [0.229, 0.224,0.225]
- !ResizeImage
interp: 1
max_size: 1333
target_size: 800
use_cv2: true
- !Permute
channel_first: true
to_bgr: false
- DecodeImage: {to_rgb: true}
- NormalizeImage: {is_channel_first: false, is_scale: true, mean: [0.485,0.456,0.406], std: [0.229, 0.224,0.225]}
- ResizeImage: {interp: 1, max_size: 1333, target_size: 800, use_cv2: true}
- Permute: {channel_first: true, to_bgr: false}
batch_transforms:
- !PadBatch
pad_to_stride: 32
use_padded_im_info: false
pad_gt: True
- PadBatch: {pad_to_stride: 32, use_padded_im_info: false, pad_gt: false}
batch_size: 1
shuffle: false
drop_last: false
drop_empty: false
worker_num: 2
TestReader:
inputs_def:
fields: ['image', 'im_info', 'im_id', 'im_shape']
fields: ['image', 'im_info', 'im_id']
sample_transforms:
- !DecodeImage
to_rgb: true
with_mixup: false
- !NormalizeImage
is_channel_first: false
is_scale: true
mean: [0.485,0.456,0.406]
std: [0.229, 0.224,0.225]
- !ResizeImage
interp: 1
max_size: 1333
target_size: 800
use_cv2: true
- !Permute
channel_first: true
to_bgr: false
- DecodeImage: {to_rgb: true, with_mixup: false}
- NormalizeImage: {is_channel_first: false, is_scale: true, mean: [0.485,0.456,0.406], std: [0.229, 0.224,0.225]}
- ResizeImage: {interp: 1, max_size: 1333, target_size: 800, use_cv2: true}
- Permute: {channel_first: true, to_bgr: false}
batch_size: 1
shuffle: false
drop_last: false
worker_num: 2
TrainReader:
inputs_def:
fields: ['image', 'gt_bbox', 'gt_class', 'gt_score']
num_max_boxes: 50
sample_transforms:
- !DecodeImage
to_rgb: True
with_mixup: True
- !MixupImage
alpha: 1.5
beta: 1.5
- !ColorDistort {}
- !RandomExpand
fill_value: [123.675, 116.28, 103.53]
- !RandomCrop {}
- !RandomFlipImage
is_normalized: false
- !NormalizeBox {}
- !PadBox
num_max_boxes: 50
- !BboxXYXY2XYWH {}
- DecodeImage: {to_rgb: True, with_mixup: True}
- MixupImage: {alpha: 1.5, beta: 1.5}
- ColorDistort: {}
- RandomExpand: {fill_value: [123.675, 116.28, 103.53]}
- RandomCrop: {}
- RandomFlipImage: {is_normalized: false}
- NormalizeBox: {}
- PadBox: {num_max_boxes: 50}
- BboxXYXY2XYWH: {}
batch_transforms:
- !RandomShape
sizes: [320, 352, 384, 416, 448, 480, 512, 544, 576, 608]
random_inter: True
- !NormalizeImage
mean: [0.485, 0.456, 0.406]
std: [0.229, 0.224, 0.225]
is_scale: True
is_channel_first: false
- !Permute
to_bgr: false
channel_first: True
- RandomShape: {sizes: [320, 352, 384, 416, 448, 480, 512, 544, 576, 608], random_inter: True}
- NormalizeImage: {mean: [0.485, 0.456, 0.406], std: [0.229, 0.224, 0.225], is_scale: True, is_channel_first: false}
- Permute: {to_bgr: false, channel_first: True}
# Gt2YoloTarget is only used when use_fine_grained_loss set as true,
# this operator will be deleted automatically if use_fine_grained_loss
# is set as false
- !Gt2YoloTarget
anchor_masks: [[6, 7, 8], [3, 4, 5], [0, 1, 2]]
anchors: [[10, 13], [16, 30], [33, 23],
[30, 61], [62, 45], [59, 119],
[116, 90], [156, 198], [373, 326]]
downsample_ratios: [32, 16, 8]
- Gt2YoloTarget: {anchor_masks: [[6, 7, 8], [3, 4, 5], [0, 1, 2]], anchors: [[10, 13], [16, 30], [33, 23], [30, 61], [62, 45], [59, 119], [116, 90], [156, 198], [373, 326]], downsample_ratios: [32, 16, 8]}
batch_size: 8
shuffle: true
mixup_epoch: 250
drop_last: true
worker_num: 4
bufsize: 4
use_process: true
EvalReader:
......@@ -54,42 +31,21 @@ EvalReader:
fields: ['image', 'im_size', 'im_id']
num_max_boxes: 50
sample_transforms:
- !DecodeImage
to_rgb: True
- !ResizeImage
target_size: 608
interp: 2
- !NormalizeImage
mean: [0.485, 0.456, 0.406]
std: [0.229, 0.224, 0.225]
is_scale: True
is_channel_first: false
- !PadBox
num_max_boxes: 50
- !Permute
to_bgr: false
channel_first: True
- DecodeImage: {to_rgb: True}
- ResizeImage: {target_size: 608, interp: 2}
- NormalizeImage: {mean: [0.485, 0.456, 0.406], std: [0.229, 0.224, 0.225], is_scale: True, is_channel_first: false}
- PadBox: {num_max_boxes: 50}
- Permute: {to_bgr: false, channel_first: True}
batch_size: 1
drop_empty: false
worker_num: 8
bufsize: 16
TestReader:
inputs_def:
image_shape: [3, 608, 608]
fields: ['image', 'im_size', 'im_id']
sample_transforms:
- !DecodeImage
to_rgb: True
- !ResizeImage
target_size: 608
interp: 2
- !NormalizeImage
mean: [0.485, 0.456, 0.406]
std: [0.229, 0.224, 0.225]
is_scale: True
is_channel_first: false
- !Permute
to_bgr: false
channel_first: True
- DecodeImage: {to_rgb: True}
- ResizeImage: {target_size: 608, interp: 2}
- NormalizeImage: {mean: [0.485, 0.456, 0.406], std: [0.229, 0.224, 0.225], is_scale: True, is_channel_first: false}
- Permute: {to_bgr: false, channel_first: True}
batch_size: 1
use_gpu: true
log_iter: 50
save_dir: output
snapshot_iter: 10000
snapshot_epoch: 2
......@@ -104,6 +104,7 @@ def _parse_with_background():
global_config['TrainReader']['with_background'] = with_background
global_config['EvalReader']['with_background'] = with_background
global_config['TestReader']['with_background'] = with_background
global_config['num_classes'] += with_background
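# Illustrative note (not part of the commit): with_background acts as an
# int-valued bool here, so enabling it reserves one extra class id for the
# background, e.g.:
#
#   with_background = True
#   num_classes = 80               # e.g. COCO categories
#   num_classes += with_background # -> 81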
def load_config(file_path):
......
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from .reader import *
from .source import *
from .transform import *
from .sampler import *
from .reader import *
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# function:
# transform samples in 'source' using 'worker'
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import sys
import six
if six.PY3:
from queue import Empty
else:
from Queue import Empty
import uuid
import logging
import signal
import threading
import traceback
logger = logging.getLogger(__name__)
main_pid = os.getpid()
worker_set = set()
class EndSignal(object):
""" signal used to notify worker to exit
"""
def __init__(self, id, errno=0, errmsg=''):
self.id = id
self.errno = errno
self.errmsg = errmsg
class ParallelMap(object):
"""
Transform samples into mapped samples, similar to
'basic.MappedDataset', but using multiple workers (threads or
processes).
Notes:
this class is not thread-safe
"""
def __init__(self,
source,
worker,
worker_num,
bufsize=100,
use_process=False,
memsize='3G'):
self._worker_num = worker_num
self._bufsize = bufsize
self._use_process = use_process
if self._use_process and sys.platform == "win32":
logger.debug("Use multi-thread reader instead of "
"multi-process reader on Windows.")
self._use_process = False
if self._use_process and type(memsize) is str:
assert memsize[-1].lower() in ['g', 'm'], \
"invalid param for memsize[%s], should " \
"end with 'G' or 'g' or 'M' or 'm'" % (memsize)
power = 3 if memsize[-1].lower() == 'g' else 2
self._memsize = int(memsize[:-1]) * (1024**power)
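# e.g. memsize='3G' -> 3 * 1024**3 bytes; memsize='512M' -> 512 * 1024**2 bytes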
self._started = False
self._source = source
self._worker = worker
self._exit = False
self._setup()
self._source_drained = False
def __iter__(self):
return self
def __next__(self):
return self.next()
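# Hedged usage sketch (illustrative only; 'sample_iter', 'transform_fn' and
# 'consume' below are assumed stand-ins, not names from this commit):
#
#   mapper = ParallelMap(source=sample_iter, worker=transform_fn,
#                        worker_num=4, bufsize=32, use_process=False)
#   for sample in mapper:       # yields transformed samples
#       consume(sample)
#   mapper.stop()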
def _setup(self):
"""setup input/output queues and workers """
use_process = self._use_process
bufsize = self._bufsize
if use_process:
from .shared_queue import SharedQueue as Queue
from multiprocessing import Process as Worker
from multiprocessing import Event
memsize = self._memsize
self._inq = Queue(bufsize, memsize=memsize)
self._outq = Queue(bufsize, memsize=memsize)
else:
if six.PY3:
from queue import Queue
else:
from Queue import Queue
from threading import Thread as Worker
from threading import Event
self._inq = Queue(bufsize)
self._outq = Queue(bufsize)
consumer_num = self._worker_num
id = str(uuid.uuid4())[-3:]
self._producer = threading.Thread(
target=self._produce,
args=('producer-' + id, self._source, self._inq))
self._producer.daemon = True
self._consumers = []
self._consumer_endsig = {}
global worker_set
for i in range(consumer_num):
consumer_id = 'consumer-' + id + '-' + str(i)
p = Worker(
target=self._consume,
args=(consumer_id, self._inq, self._outq, self._worker))
self._consumers.append(p)
p.daemon = True
setattr(p, 'id', consumer_id)
if use_process:
worker_set.add(p)
self._epoch = -1
self._feeding_ev = Event()
self._produced = 0 # produced sample in self._produce
self._consumed = 0 # consumed sample in self.next
def _produce(self, id, source, inq):
"""Fetch data from source and feed it to 'inq' queue"""
endsig = EndSignal(id)
while True:
self._feeding_ev.wait()
if self._exit:
break
try:
s = source.next()
inq.put(s)
self._produced += 1
except StopIteration:
self._source_drained = True
self._feeding_ev.clear()
self._feeding_ev.wait()
except Exception as e:
endsig.errno = -1
endsig.errmsg = "producer[{}] failed with error: {}" \
.format(id, str(e))
inq.put(endsig)
break
def _consume(self, id, inq, outq, worker):
"""Fetch data from 'inq', process it and put result to 'outq'"""
if self._use_process:
# handle SIGTERM signal to exit and avoid printing a stack frame
signal.signal(signal.SIGTERM, lambda signum, frame: sys.exit())
endsig = EndSignal(id)
while True:
sample = inq.get()
if isinstance(sample, EndSignal):
endsig.errno = sample.errno
endsig.errmsg = "consumer[{}] exits for reason[{}]" \
.format(id, sample.errmsg)
outq.put(endsig)
break
try:
result = worker(sample)
outq.put(result)
except Exception as e:
endsig.errno = -2
endsig.errmsg = "consumer[{}] failed to map with error:[{}]" \
.format(id, str(e))
outq.put(endsig)
break
def drained(self):
assert self._epoch >= 0, "first epoch has not started yet"
return self._source.drained() and self._produced == self._consumed
def stop(self):
""" notify to exit
"""
self._exit = True
self._feeding_ev.set()
for _ in range(len(self._consumers)):
self._inq.put(EndSignal(0, "notify consumers to exit"))
def _consumer_healthy(self):
abnormal_num = 0
for w in self._consumers:
if not w.is_alive() and w.id not in self._consumer_endsig:
abnormal_num += 1
if self._use_process:
errmsg = "consumer[{}] exit abnormally with exitcode[{}]" \
.format(w.pid, w.exitcode)
else:
errmsg = "consumer[{}] exit abnormally".format(w.ident)
logger.warn(errmsg)
if abnormal_num > 0:
logger.warn("{} consumers have exited abnormally!!!" \
.format(abnormal_num))
return abnormal_num == 0
def next(self):
""" get next transformed sample
"""
if self._epoch < 0:
self.reset()
if self.drained():
raise StopIteration()
while not self._exit:
try:
sample = self._outq.get(timeout=3)
except Empty as e:
if not self._consumer_healthy():
raise StopIteration()
else:
continue
if isinstance(sample, EndSignal):
self._consumer_endsig[sample.id] = sample
logger.warn("recv endsignal from outq with errmsg[{}]" \
.format(sample.errmsg))
if len(self._consumer_endsig.keys()) < len(self._consumers):
self._inq.put(sample)
else:
self._exit = True
raise StopIteration("all consumers exited, no more samples")
else:
self._consumed += 1
return sample
raise StopIteration()
def reset(self):
""" reset for a new epoch of samples
"""
assert not self._exit, "cannot reset for already stopped dataset"
if self._epoch < 0:
self._epoch = 0
for w in self._consumers:
w.start()
self._producer.start()
else:
assert self._consumer_healthy(), "cannot start another pass of data" \
" because some consumers exited abnormally before!"
if not self.drained():
logger.warn("reset before epoch[{}] finishes".format(
self._epoch))
self._produced = self._produced - self._consumed
else:
self._produced = 0
self._epoch += 1
assert len(self._consumer_endsig.keys()) == 0, "some consumers already exited," \
+ " cannot start another epoch"
self._source.reset()
self._source_drained = False
self._consumed = 0
self._feeding_ev.set()
# FIXME: fix me if you have a better implementation
# handle terminating the reader process; do not print a stack frame
signal.signal(signal.SIGTERM, lambda signum, frame: sys.exit())
# FIXME(dkp): KeyboardInterrupt should be handled inside ParallelMap
# and do such as: 1. exit workers 2. close queues 3. release shared
# memory, HACK KeyboardInterrupt with global signal.SIGINT handler
# here, should be refined later
def _term_workers(sig_num, frame):
global worker_set, main_pid
# only do subprocess killing in the main process
if os.getpid() != main_pid:
return
logger.info("KeyboardInterrupt: main proc {} exit, kill subprocess {}" \
.format(os.getpid(), [w.pid for w in worker_set]))
for w in worker_set:
if w.pid is not None:
os.kill(w.pid, signal.SIGINT)
sys.exit()
signal.signal(signal.SIGINT, _term_workers)
This diff has been collapsed.
import os
import sys
import six
import time
import math
import socket
import contextlib
import numpy as np
from paddle import fluid
from paddle.io import BatchSampler
from paddle.fluid.layers import collective
from paddle.distributed import ParallelEnv
from paddle.fluid.dygraph.parallel import ParallelStrategy
_parallel_context_initialized = False
class DistributedBatchSampler(BatchSampler):
def __init__(self, dataset, batch_size, shuffle=False, drop_last=False):
self.dataset = dataset
assert isinstance(batch_size, int) and batch_size > 0, \
"batch_size should be a positive integer"
self.batch_size = batch_size
assert isinstance(shuffle, bool), \
"shuffle should be a boolean value"
self.shuffle = shuffle
assert isinstance(drop_last, bool), \
"drop_last should be a boolean number"
self.drop_last = drop_last
self.nranks = ParallelEnv().nranks
self.local_rank = ParallelEnv().local_rank
self.epoch = 0
self.num_samples = int(math.ceil(len(self.dataset) * 1.0 / self.nranks))
self.total_size = self.num_samples * self.nranks
def __iter__(self):
num_samples = len(self.dataset)
indices = np.arange(num_samples).tolist()
indices += indices[:(self.total_size - len(indices))]
assert len(indices) == self.total_size
if self.shuffle:
np.random.RandomState(self.epoch).shuffle(indices)
self.epoch += 1
# subsample
def _get_indices_by_batch_size(indices):
subsampled_indices = []
last_batch_size = self.total_size % (self.batch_size * self.nranks)
assert last_batch_size % self.nranks == 0
last_local_batch_size = last_batch_size // self.nranks
for i in range(self.local_rank * self.batch_size,
len(indices) - last_batch_size,
self.batch_size * self.nranks):
subsampled_indices.extend(indices[i:i + self.batch_size])
indices = indices[len(indices) - last_batch_size:]
subsampled_indices.extend(indices[
self.local_rank * last_local_batch_size:(
self.local_rank + 1) * last_local_batch_size])
return subsampled_indices
if self.nranks > 1:
indices = _get_indices_by_batch_size(indices)
assert len(indices) == self.num_samples
_sample_iter = iter(indices)
batch_indices = []
for idx in _sample_iter:
batch_indices.append(idx)
if len(batch_indices) == self.batch_size:
yield batch_indices
batch_indices = []
if not self.drop_last and len(batch_indices) > 0:
yield batch_indices
def __len__(self):
num_samples = self.num_samples
num_samples += int(not self.drop_last) * (self.batch_size - 1)
return num_samples // self.batch_size
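# Note (added for clarity): with drop_last=False the expression above is a
# ceiling division, e.g. num_samples=10, batch_size=4 -> (10 + 3) // 4 = 3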
def set_epoch(self, epoch):
self.epoch = epoch
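# Hedged usage sketch (illustrative; 'dataset' and the DataLoader wiring are
# assumptions, not part of this commit):
#
#   sampler = DistributedBatchSampler(dataset, batch_size=2, shuffle=True)
#   loader = paddle.io.DataLoader(dataset, batch_sampler=sampler)
#   for epoch in range(3):
#       sampler.set_epoch(epoch)   # reseed the per-epoch shuffle
#       for data in loader:
#           ...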
def wait_server_ready(endpoints):
assert not isinstance(endpoints, six.string_types)
while True:
all_ok = True
not_ready_endpoints = []
for ep in endpoints:
ip_port = ep.split(":")
with contextlib.closing(
socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as sock:
sock.settimeout(2)
result = sock.connect_ex((ip_port[0], int(ip_port[1])))
if result != 0:
all_ok = False
not_ready_endpoints.append(ep)
if not all_ok:
time.sleep(3)
else:
break
def init_communicator(program, rank, nranks, wait_port, current_endpoint,
endpoints):
if nranks < 2:
return
other_endpoints = endpoints[:]
other_endpoints.remove(current_endpoint)
if rank == 0 and wait_port:
wait_server_ready(other_endpoints)
block = program.global_block()
nccl_id_var = block.create_var(
name=fluid.unique_name.generate('nccl_id'),
persistable=True,
type=fluid.core.VarDesc.VarType.RAW)
block.append_op(
type='c_gen_nccl_id',
inputs={},
outputs={'Out': nccl_id_var},
attrs={
'rank': rank,
'endpoint': current_endpoint,
'other_endpoints': other_endpoints
})
block.append_op(
type='c_comm_init',
inputs={'X': nccl_id_var},
outputs={},
attrs={
'nranks': nranks,
'rank': rank,
'ring_id': 0,
})
def prepare_distributed_context(place=None):
if place is None:
place = fluid.CUDAPlace(ParallelEnv().dev_id) if ParallelEnv().nranks > 1 \
else fluid.CUDAPlace(0)
strategy = ParallelStrategy()
strategy.nranks = ParallelEnv().nranks
strategy.local_rank = ParallelEnv().local_rank
strategy.trainer_endpoints = ParallelEnv().trainer_endpoints
strategy.current_endpoint = ParallelEnv().current_endpoint
if strategy.nranks < 2:
return
global _parallel_context_initialized
if not _parallel_context_initialized and isinstance(place, fluid.CUDAPlace):
def _init_context():
communicator_prog = fluid.Program()
init_communicator(communicator_prog, strategy.local_rank,
strategy.nranks, True, strategy.current_endpoint,
strategy.trainer_endpoints)
exe = fluid.Executor(place)
exe.run(communicator_prog)
fluid.disable_dygraph()
_init_context()
fluid.enable_dygraph(place)
else:
assert ("Only support CUDAPlace for now.")
_parallel_context_initialized = True
return strategy
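# Hedged usage sketch (illustrative; the DataParallel wrapping is an
# assumption, not part of this commit):
#
#   strategy = prepare_distributed_context()   # returns None when nranks < 2
#   if strategy is not None:
#       model = fluid.dygraph.DataParallel(model, strategy)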
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals
__all__ = ['SharedBuffer', 'SharedMemoryMgr', 'SharedQueue']
from .sharedmemory import SharedBuffer
from .sharedmemory import SharedMemoryMgr
from .sharedmemory import SharedMemoryError
from .queue import SharedQueue
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals
import sys
import six
if six.PY3:
import pickle
from io import BytesIO as StringIO
from queue import Empty
else:
import cPickle as pickle
from cStringIO import StringIO
from Queue import Empty
import logging
import traceback
import multiprocessing as mp
from multiprocessing.queues import Queue
from .sharedmemory import SharedMemoryMgr
logger = logging.getLogger(__name__)
class SharedQueueError(ValueError):
""" SharedQueueError
"""
pass
class SharedQueue(Queue):
""" a Queue based on shared memory to communicate data between Process,
and it's interface is compatible with 'multiprocessing.queues.Queue'
"""
def __init__(self, maxsize=0, mem_mgr=None, memsize=None, pagesize=None):
""" init
"""
if six.PY3:
super(SharedQueue, self).__init__(maxsize, ctx=mp.get_context())
else:
super(SharedQueue, self).__init__(maxsize)
if mem_mgr is not None:
self._shared_mem = mem_mgr
else:
self._shared_mem = SharedMemoryMgr(
capacity=memsize, pagesize=pagesize)
def put(self, obj, **kwargs):
""" put an object to this queue
"""
obj = pickle.dumps(obj, -1)
buff = None
try:
buff = self._shared_mem.malloc(len(obj))
buff.put(obj)
super(SharedQueue, self).put(buff, **kwargs)
except Exception as e:
stack_info = traceback.format_exc()
err_msg = 'failed to put an element to SharedQueue '\
'with stack info[%s]' % (stack_info)
logger.warn(err_msg)
if buff is not None:
buff.free()
raise e
def get(self, **kwargs):
""" get an object from this queue
"""
buff = None
try:
buff = super(SharedQueue, self).get(**kwargs)
data = buff.get()
return pickle.load(StringIO(data))
except Empty as e:
raise e
except Exception as e:
stack_info = traceback.format_exc()
err_msg = 'failed to get element from SharedQueue '\
'with stack info[%s]' % (stack_info)
logger.warn(err_msg)
raise e
finally:
if buff is not None:
buff.free()
def release(self):
self._shared_mem.release()
self._shared_mem = None
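# Hedged usage sketch (illustrative; sizes are arbitrary):
#
#   q = SharedQueue(maxsize=8, memsize=64 * 1024 * 1024)   # 64MB shared pool
#   q.put([1, 2, 3])    # pickled, copied into a SharedBuffer, then queued
#   data = q.get()      # unpickled copy; the buffer is freed in 'finally'
#   q.release()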
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# utils for memory management allocated on shared memory;
# note that these structures may not be thread-safe
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals
import os
import time
import math
import struct
import sys
import six
if six.PY3:
import pickle
else:
import cPickle as pickle
import json
import uuid
import random
import numpy as np
import weakref
import logging
from multiprocessing import Lock
from multiprocessing import RawArray
logger = logging.getLogger(__name__)
class SharedMemoryError(ValueError):
""" SharedMemoryError
"""
pass
class SharedBufferError(SharedMemoryError):
""" SharedBufferError
"""
pass
class MemoryFullError(SharedMemoryError):
""" MemoryFullError
"""
def __init__(self, errmsg=''):
super(MemoryFullError, self).__init__()
self.errmsg = errmsg
def memcopy(dst, src, offset=0, length=None):
""" copy data from 'src' to 'dst' in bytes
"""
length = length if length is not None else len(src)
assert type(dst) == np.ndarray, 'invalid type for "dst" in memcopy'
if type(src) is not np.ndarray:
if type(src) is str and six.PY3:
src = src.encode()
src = np.frombuffer(src, dtype='uint8', count=len(src))
dst[:] = src[offset:offset + length]
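# e.g. (illustrative) memcopy(base[0:3], b'abc') writes the three bytes of
# 'abc' into base[0:3] as uint8 values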
class SharedBuffer(object):
""" Buffer allocated from SharedMemoryMgr, and it stores data on shared memory
note that:
every instance of this should be freed explicitely by calling 'self.free'
"""
def __init__(self, owner, capacity, pos, size=0, alloc_status=''):
""" Init
Args:
owner (str): manager to own this buffer
capacity (int): capacity in bytes for this buffer
pos (int): page position in shared memory
size (int): bytes already used
alloc_status (str): debug info about allocator when allocate this
"""
self._owner = owner
self._cap = capacity
self._pos = pos
self._size = size
self._alloc_status = alloc_status
assert self._pos >= 0 and self._cap > 0, \
"invalid params[%d:%d] to construct SharedBuffer" \
% (self._pos, self._cap)
def owner(self):
""" get owner
"""
return SharedMemoryMgr.get_mgr(self._owner)
def put(self, data, override=False):
""" put data to this buffer
Args:
data (str): data to be stored in this buffer
Returns:
None
Raises:
SharedMemoryError when not enough space in this buffer
"""
assert type(data) in [str, bytes], \
'invalid type[%s] for SharedBuffer::put' % (str(type(data)))
if self._size > 0 and not override:
raise SharedBufferError('has already been set before')
if self.capacity() < len(data):
raise SharedBufferError('data[%d] is larger than size of buffer[%s]'\
% (len(data), str(self)))
self.owner().put_data(self, data)
self._size = len(data)
def get(self, offset=0, size=None, no_copy=True):
""" get the data stored this buffer
Args:
offset (int): position for the start point to 'get'
size (int): size to get
Returns:
data (np.ndarray('uint8')): user's data in numpy
which is passed in by 'put'
None: if no data stored in
"""
offset = offset if offset >= 0 else self._size + offset
if self._size <= 0:
return None
size = self._size if size is None else size
assert offset + size <= self._cap, 'invalid offset[%d] '\
'or size[%d] for capacity[%d]' % (offset, size, self._cap)
return self.owner().get_data(self, offset, size, no_copy=no_copy)
def size(self):
""" bytes of used memory
"""
return self._size
def resize(self, size):
""" resize the used memory to 'size', should not be greater than capacity
"""
assert size >= 0 and size <= self._cap, \
"invalid size[%d] for resize" % (size)
self._size = size
def capacity(self):
""" size of allocated memory
"""
return self._cap
def __str__(self):
""" human readable format
"""
return "SharedBuffer(owner:%s, pos:%d, size:%d, "\
"capacity:%d, alloc_status:[%s], pid:%d)" \
% (str(self._owner), self._pos, self._size, \
self._cap, self._alloc_status, os.getpid())
def free(self):
""" free this buffer to it's owner
"""
if self._owner is not None:
self.owner().free(self)
self._owner = None
self._cap = 0
self._pos = -1
self._size = 0
return True
else:
return False
class PageAllocator(object):
""" allocator used to malloc and free shared memory which
is split into pages
"""
s_allocator_header = 12
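# 12 bytes: struct.pack('III', magic_num, alloc_page_pos, used_pages)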
def __init__(self, base, total_pages, page_size):
""" init
"""
self._magic_num = 1234321000 + random.randint(100, 999)
self._base = base
self._total_pages = total_pages
self._page_size = page_size
header_pages = int(
math.ceil((total_pages + self.s_allocator_header) / page_size))
self._header_pages = header_pages
self._free_pages = total_pages - header_pages
self._header_size = self._header_pages * page_size
self._reset()
def _dump_alloc_info(self, fname):
hpages, tpages, pos, used = self.header()
start = self.s_allocator_header
end = start + self._page_size * hpages
alloc_flags = self._base[start:end].tostring()
info = {
'magic_num': self._magic_num,
'header_pages': hpages,
'total_pages': tpages,
'pos': pos,
'used': used
}
info['alloc_flags'] = alloc_flags
fname = fname + '.' + str(uuid.uuid4())[:6]
with open(fname, 'wb') as f:
f.write(pickle.dumps(info, -1))
logger.warn('dump alloc info to file[%s]' % (fname))
def _reset(self):
alloc_page_pos = self._header_pages
used_pages = self._header_pages
header_info = struct.pack(
str('III'), self._magic_num, alloc_page_pos, used_pages)
assert len(header_info) == self.s_allocator_header, \
'invalid size of header_info'
memcopy(self._base[0:self.s_allocator_header], header_info)
self.set_page_status(0, self._header_pages, '1')
self.set_page_status(self._header_pages, self._free_pages, '0')
def header(self):
""" get header info of this allocator
"""
header_str = self._base[0:self.s_allocator_header].tostring()
magic, pos, used = struct.unpack(str('III'), header_str)
assert magic == self._magic_num, \
'invalid header magic[%d] in shared memory' % (magic)
return self._header_pages, self._total_pages, pos, used
def empty(self):
""" are all allocatable pages available
"""
header_pages, pages, pos, used = self.header()
return header_pages == used
def full(self):
""" are all allocatable pages used
"""
header_pages, pages, pos, used = self.header()
return header_pages + used == pages
def __str__(self):
header_pages, pages, pos, used = self.header()
desc = '{page_info[magic:%d,total:%d,used:%d,header:%d,alloc_pos:%d,pagesize:%d]}' \
% (self._magic_num, pages, used, header_pages, pos, self._page_size)
return 'PageAllocator:%s' % (desc)
def set_alloc_info(self, alloc_pos, used_pages):
""" set allocating position to new value
"""
memcopy(self._base[4:12], struct.pack(str('II'), alloc_pos, used_pages))
def set_page_status(self, start, page_num, status):
""" set pages from 'start' to 'end' with new same status 'status'
"""
assert status in ['0', '1'], 'invalid status[%s] for page status '\
'in allocator[%s]' % (status, str(self))
start += self.s_allocator_header
end = start + page_num
assert start >= 0 and end <= self._header_size, 'invalid end[%d] of pages '\
'in allocator[%s]' % (end, str(self))
memcopy(self._base[start:end], str(status * page_num))
def get_page_status(self, start, page_num, ret_flag=False):
start += self.s_allocator_header
end = start + page_num
assert start >= 0 and end <= self._header_size, 'invalid end[%d] of pages '\
'in allocator[%s]' % (end, str(self))
status = self._base[start:end].tostring().decode()
if ret_flag:
return status
zero_num = status.count('0')
if zero_num == 0:
return (page_num, 1)
else:
return (zero_num, 0)
def malloc_page(self, page_num):
header_pages, pages, pos, used = self.header()
end = pos + page_num
if end > pages:
pos = self._header_pages
end = pos + page_num
start_pos = pos
flags = ''
while True:
flags = self.get_page_status(pos, page_num, ret_flag=True)
if flags.count('0') == page_num:
break
# did not find enough pages, so shift to the next few pages
free_pos = flags.rfind('1') + 1
pos += free_pos
end = pos + page_num
if end > pages:
pos = self._header_pages
end = pos + page_num
flags = ''
# no available pages found after scanning all pages
if pos <= start_pos and end >= start_pos:
logger.debug('no available pages found after scanning all pages')
break
page_status = (flags.count('0'), 0)
if page_status != (page_num, 0):
free_pages = self._total_pages - used
if free_pages == 0:
err_msg = 'all pages have been used:%s' % (str(self))
else:
err_msg = 'not enough pages found [avail:%d, expect:%d] '\
'with total free pages[%d]' % (page_status[0], page_num, free_pages)
err_msg = 'failed to malloc %d pages at pos[%d] for reason[%s] '\
'and allocator status[%s]' % (page_num, pos, err_msg, str(self))
raise MemoryFullError(err_msg)
self.set_page_status(pos, page_num, '1')
used += page_num
self.set_alloc_info(end, used)
return pos
def free_page(self, start, page_num):
""" free 'page_num' pages start from 'start'
"""
page_status = self.get_page_status(start, page_num)
assert page_status == (page_num, 1), \
'invalid status[%s] when free [%d, %d]' \
% (str(page_status), start, page_num)
self.set_page_status(start, page_num, '0')
_, _, pos, used = self.header()
used -= page_num
self.set_alloc_info(pos, used)
DEFAULT_SHARED_MEMORY_SIZE = 1024 * 1024 * 1024
class SharedMemoryMgr(object):
""" manage a continouse block of memory, provide
'malloc' to allocate new buffer, and 'free' to free buffer
"""
s_memory_mgrs = weakref.WeakValueDictionary()
s_mgr_num = 0
s_log_statis = False
@classmethod
def get_mgr(cls, id):
""" get a SharedMemoryMgr with size of 'capacity'
"""
assert id in cls.s_memory_mgrs, 'invalid id[%s] for memory managers' % (
id)
return cls.s_memory_mgrs[id]
def __init__(self, capacity=None, pagesize=None):
""" init
"""
logger.debug('create SharedMemoryMgr')
pagesize = 64 * 1024 if pagesize is None else pagesize
assert type(pagesize) is int, "invalid type of pagesize[%s]" \
% (str(pagesize))
capacity = DEFAULT_SHARED_MEMORY_SIZE if capacity is None else capacity
assert type(capacity) is int, "invalid type of capacity[%s]" \
% (str(capacity))
assert capacity > 0, 'size of shared memory should be greater than 0'
self._released = False
self._cap = capacity
self._page_size = pagesize
assert self._cap % self._page_size == 0, \
"capacity[%d] and pagesize[%d] are not consistent" \
% (self._cap, self._page_size)
self._total_pages = self._cap // self._page_size
self._pid = os.getpid()
SharedMemoryMgr.s_mgr_num += 1
self._id = self._pid * 100 + SharedMemoryMgr.s_mgr_num
SharedMemoryMgr.s_memory_mgrs[self._id] = self
self._locker = Lock()
self._setup()
def _setup(self):
self._shared_mem = RawArray('c', self._cap)
self._base = np.frombuffer(
self._shared_mem, dtype='uint8', count=self._cap)
self._locker.acquire()
try:
self._allocator = PageAllocator(self._base, self._total_pages,
self._page_size)
finally:
self._locker.release()
def malloc(self, size, wait=True):
""" malloc a new SharedBuffer
Args:
size (int): buffer size to be malloc
wait (bool): whether to wait when there is not enough memory
Returns:
SharedBuffer
Raises:
SharedMemoryError when no available memory is found
"""
page_num = int(math.ceil(size / self._page_size))
size = page_num * self._page_size
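# e.g. (illustrative) size=100*1024 with the default 64KB pages ->
# page_num=2, so the allocation is rounded up to 128KB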
start = None
ct = 0
errmsg = ''
while True:
self._locker.acquire()
try:
start = self._allocator.malloc_page(page_num)
alloc_status = str(self._allocator)
except MemoryFullError as e:
start = None
errmsg = e.errmsg
if not wait:
raise e
finally:
self._locker.release()
if start is None:
time.sleep(0.1)
if ct % 100 == 0:
logger.warn('not enough space for reason[%s]' % (errmsg))
ct += 1
else:
break
return SharedBuffer(self._id, size, start, alloc_status=alloc_status)
def free(self, shared_buf):
""" free a SharedBuffer
Args:
shared_buf (SharedBuffer): buffer to be freed
Returns:
None
Raises:
SharedMemoryError when failed to release this buffer
"""
assert shared_buf._owner == self._id, "invalid shared_buf[%s] "\
"because it's not allocated from me[%s]" % (str(shared_buf), str(self))
cap = shared_buf.capacity()
start_page = shared_buf._pos
page_num = cap // self._page_size
# maybe we don't need this lock here
self._locker.acquire()
try:
self._allocator.free_page(start_page, page_num)
finally:
self._locker.release()
def put_data(self, shared_buf, data):
""" fill 'data' into 'shared_buf'
"""
assert len(data) <= shared_buf.capacity(), 'too large data[%d] '\
'for this buffer[%s]' % (len(data), str(shared_buf))
start = shared_buf._pos * self._page_size
end = start + len(data)
assert start >= 0 and end <= self._cap, "invalid start "\
"position[%d] when put data to buff:%s" % (start, str(shared_buf))
self._base[start:end] = np.frombuffer(data, 'uint8', len(data))
def get_data(self, shared_buf, offset, size, no_copy=True):
""" extract 'data' from 'shared_buf' in range [offset, offset + size)
"""
start = shared_buf._pos * self._page_size
start += offset
if no_copy:
return self._base[start:start + size]
else:
return self._base[start:start + size].tostring()
def __str__(self):
return 'SharedMemoryMgr:{id:%d, %s}' % (self._id, str(self._allocator))
def __del__(self):
if SharedMemoryMgr.s_log_statis:
logger.info('destroy [%s]' % (self))
if not self._released and not self._allocator.empty():
logger.debug('not empty when delete this SharedMemoryMgr[%s]' %
(self))
else:
self._released = True
if self._id in SharedMemoryMgr.s_memory_mgrs:
del SharedMemoryMgr.s_memory_mgrs[self._id]
SharedMemoryMgr.s_mgr_num -= 1
......@@ -13,9 +13,10 @@
# limitations under the License.
from . import coco
from . import voc
from . import widerface
# TODO add voc and widerface dataset
#from . import voc
#from . import widerface
from .coco import *
from .voc import *
from .widerface import *
#from .voc import *
#from .widerface import *
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import numpy as np
from .dataset import DataSet
import logging
from ppdet.core.workspace import register, serializable
from .dataset import DetDataset
import logging
logger = logging.getLogger(__name__)
@register
@serializable
class COCODataSet(DataSet):
"""
Load COCO records with annotations in json file 'anno_path'
Args:
dataset_dir (str): root directory for dataset.
image_dir (str): directory for images.
anno_path (str): json file path.
sample_num (int): number of samples to load, -1 means all.
"""
class COCODataSet(DetDataset):
def __init__(self,
dataset_dir=None,
image_dir=None,
anno_path=None,
dataset_dir=None,
sample_num=-1):
super(COCODataSet, self).__init__(
image_dir=image_dir,
anno_path=anno_path,
dataset_dir=dataset_dir,
sample_num=sample_num)
self.anno_path = anno_path
self.sample_num = sample_num
# `roidbs` is list of dict whose structure is:
# {
# 'im_file': im_fname, # image file name
# 'im_id': img_id, # image id
# 'h': im_h, # height of image
# 'w': im_w, # width
# 'is_crowd': is_crowd,
# 'gt_score': gt_score,
# 'gt_class': gt_class,
# 'gt_bbox': gt_bbox,
# 'gt_poly': gt_poly,
# }
self.roidbs = None
# a dict used to map category name to class id
self.cname2cid = None
super(COCODataSet, self).__init__(dataset_dir, image_dir, anno_path,
sample_num)
self.load_image_only = False
self.load_semantic = False
def load_roidb_and_cname2cid(self, with_background=True):
def parse_dataset(self, with_background=True):
anno_path = os.path.join(self.dataset_dir, self.anno_path)
image_dir = os.path.join(self.dataset_dir, self.image_dir)
......@@ -99,11 +69,11 @@ class COCODataSet(DataSet):
im_w = float(img_anno['width'])
im_h = float(img_anno['height'])
im_fname = os.path.join(image_dir,
im_fname) if image_dir else im_fname
if not os.path.exists(im_fname):
im_path = os.path.join(image_dir,
im_fname) if image_dir else im_fname
if not os.path.exists(im_path):
logger.warn('Illegal image file: {}, and it will be '
'ignored'.format(im_fname))
'ignored'.format(im_path))
continue
if im_w < 0 or im_h < 0:
......@@ -113,7 +83,7 @@ class COCODataSet(DataSet):
continue
coco_rec = {
'im_file': im_fname,
'im_file': im_path,
'im_id': np.array([img_id]),
'h': im_h,
'w': im_w,
......@@ -122,14 +92,20 @@ class COCODataSet(DataSet):
if not self.load_image_only:
ins_anno_ids = coco.getAnnIds(imgIds=img_id, iscrowd=False)
instances = coco.loadAnns(ins_anno_ids)
bboxes = []
for inst in instances:
# check gt bbox
if 'bbox' not in inst.keys():
continue
else:
if not any(np.array(inst['bbox'])):
continue
x, y, box_w, box_h = inst['bbox']
x1 = max(0, x)
y1 = max(0, y)
x2 = min(im_w - 1, x1 + max(0, box_w - 1))
y2 = min(im_h - 1, y1 + max(0, box_h - 1))
if inst['area'] > 0 and x2 >= x1 and y2 >= y1:
inst['clean_bbox'] = [x1, y1, x2, y2]
bboxes.append(inst)
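# Illustrative example of the clipping above (values assumed):
# inst['bbox'] = [10, 20, 30, 40] (xywh) on a 100x100 image
# -> clean_bbox = [10, 20, 39, 59] (clipped xyxy)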
......@@ -138,7 +114,6 @@ class COCODataSet(DataSet):
'Found an invalid bbox in annotations: im_id: {}, '
'area: {} x1: {}, y1: {}, x2: {}, y2: {}.'.format(
img_id, float(inst['area']), x1, y1, x2, y2))
num_bbox = len(bboxes)
gt_bbox = np.zeros((num_bbox, 4), dtype=np.float32)
......@@ -153,9 +128,16 @@ class COCODataSet(DataSet):
gt_class[i][0] = catid2clsid[catid]
gt_bbox[i, :] = box['clean_bbox']
is_crowd[i][0] = box['iscrowd']
# check RLE format
if box['iscrowd'] == 1:
gt_poly[i] = [[0.0, 0.0], ]
continue
if 'segmentation' in box:
gt_poly[i] = box['segmentation']
if not any(gt_poly):
continue
coco_rec.update({
'is_crowd': is_crowd,
'gt_class': gt_class,
......@@ -163,9 +145,14 @@ class COCODataSet(DataSet):
'gt_score': gt_score,
'gt_poly': gt_poly,
})
# TODO: remove load_semantic
if self.load_semantic:
seg_path = os.path.join(self.dataset_dir, 'stuffthingmaps',
'train2017', im_fname[:-3] + 'png')
coco_rec.update({'semantic': seg_path})
logger.debug('Load file: {}, im_id: {}, h: {}, w: {}.'.format(
im_fname, img_id, im_h, im_w))
im_path, img_id, im_h, im_w))
records.append(coco_rec)
ct += 1
if self.sample_num > 0 and ct >= self.sample_num:
......
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import numpy as np
from collections import OrderedDict
try:
from collections.abc import Sequence
except Exception:
from collections import Sequence
from paddle.io import Dataset
from ppdet.core.workspace import register, serializable
from ppdet.utils.download import get_dataset_path
@serializable
class DataSet(object):
"""
Dataset, e.g., coco, pascal voc
Args:
annotation (str): annotation file path
image_dir (str): directory where image files are stored
shuffle (bool): shuffle samples
"""
class DetDataset(Dataset):
def __init__(self,
dataset_dir=None,
image_dir=None,
......@@ -42,95 +33,54 @@ class DataSet(object):
sample_num=-1,
use_default_label=None,
**kwargs):
super(DataSet, self).__init__()
super(DetDataset, self).__init__()
self.dataset_dir = dataset_dir if dataset_dir is not None else ''
self.anno_path = anno_path
self.image_dir = image_dir if image_dir is not None else ''
self.dataset_dir = dataset_dir if dataset_dir is not None else ''
self.sample_num = sample_num
self.use_default_label = use_default_label
self.cname2cid = None
self._imid2path = None
def load_roidb_and_cname2cid(self):
"""load dataset"""
raise NotImplementedError('%s.load_roidb_and_cname2cid not available' %
(self.__class__.__name__))
def __len__(self, ):
return len(self.roidbs)
def get_roidb(self, with_background=True):
if not self.roidbs:
data_dir = get_dataset_path(self.dataset_dir, self.anno_path,
self.image_dir)
if data_dir:
self.dataset_dir = data_dir
self.load_roidb_and_cname2cid(with_background)
def __getitem__(self, idx):
# fetch one record
roidb = self.roidbs[idx]
# apply sample transforms
roidb = self.transform(roidb)
# select output fields
out = OrderedDict()
for k in self.fields:
out[k] = roidb[k]
return out.values()
return self.roidbs
def set_out(self, sample_transform, fields):
self.transform = sample_transform
self.fields = fields
def get_cname2cid(self):
if not self.cname2cid:
self.load_roidb_and_cname2cid()
return self.cname2cid
def parse_dataset(self, with_background=True):
raise NotImplementedError(
"Need to implement parse_dataset method of Dataset")
def get_anno(self):
if self.anno_path is None:
return
return os.path.join(self.dataset_dir, self.anno_path)
def get_imid2path(self):
return self._imid2path
def _is_valid_file(f, extensions=('.jpg', '.jpeg', '.png', '.bmp')):
return f.lower().endswith(extensions)
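# e.g. _is_valid_file('img.JPG') -> True; _is_valid_file('notes.txt') -> False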
def _make_dataset(dir):
dir = os.path.expanduser(dir)
if not os.path.isdir(dir):
raise ValueError('{} should be a dir'.format(dir))
images = []
for root, _, fnames in sorted(os.walk(dir, followlinks=True)):
for fname in sorted(fnames):
path = os.path.join(root, fname)
if _is_valid_file(path):
images.append(path)
return images
@register
@serializable
class ImageFolder(DataSet):
"""
Args:
dataset_dir (str): root directory for dataset.
image_dir(list|str): list of image folders or list of image files
anno_path (str): annotation file path.
samples (int): number of samples to load, -1 means all
"""
class ImageFolder(DetDataset):
def __init__(self,
dataset_dir=None,
image_dir=None,
anno_path=None,
sample_num=-1,
use_default_label=None,
**kwargs):
super(ImageFolder, self).__init__(dataset_dir, image_dir, anno_path,
sample_num, use_default_label)
self.roidbs = None
self._imid2path = {}
def get_roidb(self):
if not self.roidbs:
self.roidbs = self._load_images()
return self.roidbs
sample_num)
def set_images(self, images):
self.image_dir = images
self.roidbs = self._load_images()
def _parse(self):
def parse_dataset(self):
image_dir = self.image_dir
if not isinstance(image_dir, Sequence):
image_dir = [image_dir]
......@@ -141,20 +91,4 @@ class ImageFolder(DataSet):
images.extend(_make_dataset(im_dir))
elif os.path.isfile(im_dir) and _is_valid_file(im_dir):
images.append(im_dir)
return images
def _load_images(self):
images = self._parse()
ct = 0
records = []
for image in images:
assert image != '' and os.path.isfile(image), \
"Image {} not found".format(image)
if self.sample_num > 0 and ct >= self.sample_num:
break
rec = {'im_id': np.array([ct]), 'im_file': image}
self._imid2path[ct] = image
ct += 1
records.append(rec)
assert len(records) > 0, "No image file found"
return records
self.roidbs = images
......@@ -26,6 +26,7 @@ import cv2
import numpy as np
from .operator import register_op, BaseOperator
from .op_helper import jaccard_overlap, gaussian2D
from .operators import NormalizeImage, Permute
logger = logging.getLogger(__name__)
......@@ -55,7 +56,7 @@ class PadBatch(BaseOperator):
self.use_padded_im_info = use_padded_im_info
self.pad_gt = pad_gt
def __call__(self, samples, context=None):
def __call__(self, samples):
"""
Args:
samples (list): a batch of sample, each is dict.
......@@ -156,7 +157,7 @@ class RandomShape(BaseOperator):
] if random_inter else []
self.resize_box = resize_box
def __call__(self, samples, context=None):
def __call__(self, samples):
shape = np.random.choice(self.sizes)
method = np.random.choice(self.interps) if self.random_inter \
else cv2.INTER_NEAREST
......@@ -191,7 +192,7 @@ class PadMultiScaleTest(BaseOperator):
super(PadMultiScaleTest, self).__init__()
self.pad_to_stride = pad_to_stride
def __call__(self, samples, context=None):
def __call__(self, samples):
coarsest_stride = self.pad_to_stride
if coarsest_stride == 0:
return samples
......@@ -247,7 +248,7 @@ class Gt2YoloTarget(BaseOperator):
self.num_classes = num_classes
self.iou_thresh = iou_thresh
def __call__(self, samples, context=None):
def __call__(self, samples):
assert len(self.anchor_masks) == len(self.downsample_ratios), \
"anchor_masks', and 'downsample_ratios' should have same length."
......@@ -430,7 +431,7 @@ class Gt2FCOSTarget(BaseOperator):
inside_gt_box = np.min(clipped_box_reg_targets, axis=2) > 0
return inside_gt_box
def __call__(self, samples, context=None):
def __call__(self, samples):
assert len(self.object_sizes_of_interest) == len(self.downsample_ratios), \
"object_sizes_of_interest', and 'downsample_ratios' should have same length."
......@@ -554,7 +555,7 @@ class Gt2TTFTarget(BaseOperator):
self.num_classes = num_classes
self.alpha = alpha
def __call__(self, samples, context=None):
def __call__(self, samples):
output_size = samples[0]['image'].shape[1]
feat_size = output_size // self.down_ratio
for sample in samples:
......
......@@ -71,7 +71,7 @@ class DecodeImage(BaseOperator):
if not isinstance(self.with_cutmix, bool):
raise TypeError("{}: input type is invalid.".format(self))
def __call__(self, sample, context=None):
def __call__(self, sample):
""" load image if 'im_file' field is not empty but 'image' is"""
if 'image' not in sample:
with open(sample['im_file'], 'rb') as f:
......@@ -106,10 +106,10 @@ class DecodeImage(BaseOperator):
[im.shape[0], im.shape[1], 1.], dtype=np.float32)
# decode mixup image
if self.with_mixup and 'mixup' in sample:
self.__call__(sample['mixup'], context)
self.__call__(sample['mixup'])
# decode cutmix image
if self.with_cutmix and 'cutmix' in sample:
self.__call__(sample['cutmix'], context)
self.__call__(sample['cutmix'])
return sample
......@@ -150,7 +150,7 @@ class MultiscaleTestResize(BaseOperator):
and isinstance(self.interp, int)):
raise TypeError("{}: input type is invalid.".format(self))
def __call__(self, sample, context=None):
def __call__(self, sample):
""" Resize the image numpy for multi-scale test.
"""
origin_ims = {}
......@@ -368,7 +368,7 @@ class RandomFlipImage(BaseOperator):
gt_keypoint[:, i] = width - old_x - 1
return gt_keypoint
def __call__(self, sample, context=None):
def __call__(self, sample):
"""Filp the image and bounding box.
Operators:
1. Flip the image numpy.
......@@ -441,7 +441,7 @@ class RandomErasingImage(BaseOperator):
self.sh = sh
self.r1 = r1
def __call__(self, sample, context=None):
def __call__(self, sample):
samples = sample
batch_input = True
if not isinstance(samples, Sequence):
......@@ -525,7 +525,7 @@ class GridMaskOp(BaseOperator):
prob=prob,
upper_iter=upper_iter)
def __call__(self, sample, context=None):
def __call__(self, sample):
samples = sample
batch_input = True
if not isinstance(samples, Sequence):
......@@ -553,7 +553,7 @@ class AutoAugmentImage(BaseOperator):
if not isinstance(self.is_normalized, bool):
raise TypeError("{}: input type is invalid.".format(self))
def __call__(self, sample, context=None):
def __call__(self, sample):
"""
Learning Data Augmentation Strategies for Object Detection, see https://arxiv.org/abs/1906.11172
"""
......@@ -632,7 +632,7 @@ class NormalizeImage(BaseOperator):
if reduce(lambda x, y: x * y, self.std) == 0:
raise ValueError('{}: std is invalid!'.format(self))
def __call__(self, sample, context=None):
def __call__(self, sample):
"""Normalize the image.
Operators:
1.(optional) Scale the image to [0,1]
......@@ -748,7 +748,7 @@ class RandomDistort(BaseOperator):
img = Image.fromarray(img, mode='HSV').convert('RGB')
return img
def __call__(self, sample, context):
def __call__(self, sample):
"""random distort the image"""
ops = [
self.random_brightness, self.random_contrast,
......@@ -789,7 +789,7 @@ class ExpandImage(BaseOperator):
self.mean = mean
self.prob = prob
def __call__(self, sample, context):
def __call__(self, sample):
"""
Expand the image and modify bounding box.
Operators:
......@@ -873,7 +873,7 @@ class CropImage(BaseOperator):
self.satisfy_all = satisfy_all
self.avoid_no_bbox = avoid_no_bbox
def __call__(self, sample, context):
def __call__(self, sample):
"""
Crop the image and modify bounding box.
Operators:
......@@ -969,7 +969,7 @@ class CropImageWithDataAchorSampling(BaseOperator):
self.avoid_no_bbox = avoid_no_bbox
self.das_anchor_scales = np.array(das_anchor_scales)
def __call__(self, sample, context):
def __call__(self, sample):
"""
Crop the image and modify bounding box.
Operators:
......@@ -1102,7 +1102,7 @@ class NormalizeBox(BaseOperator):
def __init__(self):
super(NormalizeBox, self).__init__()
def __call__(self, sample, context):
def __call__(self, sample):
gt_bbox = sample['gt_bbox']
width = sample['w']
height = sample['h']
......@@ -1948,22 +1948,21 @@ class PadBox(BaseOperator):
self.num_max_boxes = num_max_boxes
super(PadBox, self).__init__()
def __call__(self, sample, context=None):
def __call__(self, sample):
assert 'gt_bbox' in sample
bbox = sample['gt_bbox']
gt_num = min(self.num_max_boxes, len(bbox))
num_max = self.num_max_boxes
fields = context['fields'] if context else []
pad_bbox = np.zeros((num_max, 4), dtype=np.float32)
if gt_num > 0:
pad_bbox[:gt_num, :] = bbox[:gt_num, :]
sample['gt_bbox'] = pad_bbox
if 'gt_class' in fields:
if 'gt_class' in sample.keys():
pad_class = np.zeros((num_max), dtype=np.int32)
if gt_num > 0:
pad_class[:gt_num] = sample['gt_class'][:gt_num, 0]
sample['gt_class'] = pad_class
if 'gt_score' in fields:
if 'gt_score' in sample.keys():
pad_score = np.zeros((num_max), dtype=np.float32)
if gt_num > 0:
pad_score[:gt_num] = sample['gt_score'][:gt_num, 0]
......@@ -1971,7 +1970,7 @@ class PadBox(BaseOperator):
# in training, for example in op ExpandImage,
# the bbox and gt_class are expanded, but the difficult is not,
# so judge by its length
if 'is_difficult' in fields:
if 'is_difficult' in sample.keys():
pad_diff = np.zeros((num_max), dtype=np.int32)
if gt_num > 0:
pad_diff[:gt_num] = sample['difficult'][:gt_num, 0]
......@@ -1988,7 +1987,7 @@ class BboxXYXY2XYWH(BaseOperator):
def __init__(self):
super(BboxXYXY2XYWH, self).__init__()
def __call__(self, sample, context=None):
def __call__(self, sample):
assert 'gt_bbox' in sample
bbox = sample['gt_bbox']
bbox[:, 2:4] = bbox[:, 2:4] - bbox[:, :2]
......@@ -2012,7 +2011,7 @@ class Lighting(BaseOperator):
self.eigval = np.array(eigval).astype('float32')
self.eigvec = np.array(eigvec).astype('float32')
def __call__(self, sample, context=None):
def __call__(self, sample):
alpha = np.random.normal(scale=self.alphastd, size=(3, ))
sample['image'] += np.dot(self.eigvec, self.eigval * alpha)
return sample
......@@ -2050,7 +2049,7 @@ class CornerTarget(BaseOperator):
self.gaussian_iou = gaussian_iou
self.max_tag_len = max_tag_len
def __call__(self, sample, context=None):
def __call__(self, sample):
tl_heatmaps = np.zeros(
(self.num_classes, self.output_size[0], self.output_size[1]),
dtype=np.float32)
......@@ -2147,7 +2146,7 @@ class CornerCrop(BaseOperator):
self.is_train = is_train
self.input_size = input_size
def __call__(self, sample, context=None):
def __call__(self, sample):
im_h, im_w = int(sample['h']), int(sample['w'])
if self.is_train:
scale = np.random.choice(self.random_scales)
......@@ -2221,7 +2220,7 @@ class CornerRatio(BaseOperator):
self.input_size = input_size
self.output_size = output_size
def __call__(self, sample, context=None):
def __call__(self, sample):
scale = (self.input_size + 1) // self.output_size
out_height, out_width = (sample['h'] + 1) // scale, (
sample['w'] + 1) // scale
......@@ -2251,7 +2250,7 @@ class RandomScaledCrop(BaseOperator):
self.scale_range = scale_range
self.interp = interp
def __call__(self, sample, context=None):
def __call__(self, sample):
w = sample['w']
h = sample['h']
random_scale = np.random.uniform(*self.scale_range)
......@@ -2300,7 +2299,7 @@ class ResizeAndPad(BaseOperator):
self.target_dim = target_dim
self.interp = interp
def __call__(self, sample, context=None):
def __call__(self, sample):
w = sample['w']
h = sample['h']
interp = self.interp
......@@ -2404,7 +2403,7 @@ class TargetAssign(BaseOperator):
offsets[..., 2:] = np.log(whb / wha)
return offsets
def __call__(self, sample, context=None):
def __call__(self, sample):
gt_boxes = sample['gt_bbox']
gt_labels = sample['gt_class']
labels = np.full((self.anchors.shape[0], 1), 0, dtype=np.int32)
......@@ -2444,7 +2443,7 @@ class DebugVisibleImage(BaseOperator):
if not isinstance(self.is_normalized, bool):
raise TypeError("{}: input type is invalid.".format(self))
def __call__(self, sample, context=None):
def __call__(self, sample):
image = Image.open(sample['im_file']).convert('RGB')
out_file_name = sample['im_file'].split('/')[-1]
width = sample['w']
......
......@@ -7,6 +7,7 @@ from . import head
from . import loss
from . import architecture
from . import post_process
from . import layers
from .ops import *
from .bbox import *
......@@ -17,3 +18,4 @@ from .head import *
from .loss import *
from .architecture import *
from .post_process import *
from .layers import *
......@@ -31,15 +31,9 @@ class BaseArch(nn.Layer):
def build_inputs(self, data, input_def):
inputs = {}
for name in input_def:
inputs[name] = []
batch_size = len(data)
for bs in range(batch_size):
for name, input in zip(input_def, data[bs]):
input_v = np.array(input)[np.newaxis, ...]
inputs[name].append(input_v)
for name in input_def:
inputs[name] = paddle.to_tensor(np.concatenate(inputs[name]))
for i, k in enumerate(input_def):
v = paddle.to_tensor(data[i])
inputs[k] = v
return inputs
def model_arch(self):
......
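The rewritten `build_inputs` assumes the new DataLoader has already collated each field into one batched array, so `data` is a list aligned with `input_def` rather than a list of per-sample tuples. A sketch of the expected layout, with fabricated shapes:

```python
import numpy as np
import paddle

input_def = ['image', 'im_info', 'im_id']          # field order from the reader config
data = [
    np.zeros((2, 3, 800, 1344), dtype='float32'),  # batched, padded images
    np.zeros((2, 3), dtype='float32'),             # batched im_info
    np.zeros((2, 1), dtype='int64'),               # batched im_id
]
inputs = {k: paddle.to_tensor(data[i]) for i, k in enumerate(input_def)}
assert inputs['image'].shape == [2, 3, 800, 1344]
```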
......@@ -23,7 +23,7 @@ class Mask(object):
im_info=inputs['im_info'],
gt_classes=inputs['gt_class'],
is_crowd=inputs['is_crowd'],
gt_segms=inputs['gt_mask'],
gt_segms=inputs['gt_poly'],
rois=proposals,
rois_num=proposals_num,
labels_int32=labels_int32)
......
......@@ -43,7 +43,7 @@ class PiecewiseDecay(object):
milestones (list): steps at which to decay learning rate
"""
def __init__(self, gamma=[0.1, 0.01], milestones=[60000, 80000]):
def __init__(self, gamma=[0.1, 0.01], milestones=[8, 11]):
super(PiecewiseDecay, self).__init__()
if type(gamma) is not list:
self.gamma = []
......@@ -53,9 +53,13 @@ class PiecewiseDecay(object):
self.gamma = gamma
self.milestones = milestones
def __call__(self, base_lr=None, boundary=None, value=None):
def __call__(self,
base_lr=None,
boundary=None,
value=None,
step_per_epoch=None):
if boundary is not None:
boundary.extend(self.milestones)
boundary.extend([int(step_per_epoch) * i for i in self.milestones])
if value is not None:
for i in self.gamma:
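Milestones are now epoch indices and are converted to iteration boundaries with `step_per_epoch` at call time. The arithmetic, with an illustrative dataset size:

```python
step_per_epoch = 7329          # e.g. ~117k images / batch_size 16 (illustrative)
milestones = [8, 11]           # epochs, matching the config above
boundary = [int(step_per_epoch) * i for i in milestones]
print(boundary)                # [58632, 80619]
```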
......@@ -110,12 +114,13 @@ class LearningRate(object):
self.base_lr = base_lr
self.schedulers = schedulers
def __call__(self):
def __call__(self, step_per_epoch):
# TODO: split warmup & decay
# warmup
boundary, value = self.schedulers[1](self.base_lr)
# decay
decay_lr = self.schedulers[0](self.base_lr, boundary, value)
decay_lr = self.schedulers[0](self.base_lr, boundary, value,
step_per_epoch)
return decay_lr
......
......@@ -19,7 +19,6 @@ from ppdet.core.workspace import load_config, merge_config, create
from ppdet.utils.check import check_gpu, check_version, check_config
from ppdet.utils.cli import ArgsParser
from ppdet.utils.eval_utils import get_infer_results, eval_results
from ppdet.data.reader import create_reader
from ppdet.utils.checkpoint import load_dygraph_ckpt, save_dygraph_ckpt
import logging
FORMAT = '%(asctime)s-%(levelname)s: %(message)s'
......@@ -45,7 +44,7 @@ def parse_args():
return args
def run(FLAGS, cfg):
def run(FLAGS, cfg, place):
# Model
main_arch = cfg.architecture
......@@ -55,18 +54,14 @@ def run(FLAGS, cfg):
model = load_dygraph_ckpt(model, ckpt=cfg.weights)
# Data Reader
if FLAGS.use_gpu:
devices_num = 1
else:
devices_num = int(os.environ.get('CPU_NUM', 1))
eval_reader = create_reader(
cfg.EvalDataset, cfg.EvalReader, devices_num=devices_num)
dataset = cfg.EvalDataset
eval_loader, _ = create('EvalReader')(dataset, cfg['worker_num'], place)
# Run Eval
outs_res = []
start_time = time.time()
sample_num = 0
for iter_id, data in enumerate(eval_reader()):
for iter_id, data in enumerate(eval_loader):
# forward
model.eval()
outs = model(data, cfg['EvalReader']['inputs_def']['fields'], 'infer')
......@@ -86,10 +81,9 @@ def run(FLAGS, cfg):
eval_type.append('mask')
# Metric
# TODO: support other metric
dataset = cfg.EvalReader['dataset']
from ppdet.utils.coco_eval import get_category_info
anno_file = dataset.get_anno()
with_background = dataset.with_background
with_background = cfg.with_background
use_default_label = dataset.use_default_label
clsid2catid, catid2name = get_category_info(anno_file, with_background,
use_default_label)
......@@ -107,10 +101,9 @@ def main():
check_gpu(cfg.use_gpu)
check_version()
place = paddle.CUDAPlace(ParallelEnv()
.dev_id) if cfg.use_gpu else paddle.CPUPlace()
paddle.disable_static(place)
run(FLAGS, cfg)
place = 'gpu:{}'.format(ParallelEnv().dev_id) if cfg.use_gpu else 'cpu'
place = paddle.set_device(place)
run(FLAGS, cfg, place)
if __name__ == '__main__':
......
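`paddle.set_device` takes a device string, sets the global device, and returns the corresponding place, replacing the explicit `CUDAPlace`/`CPUPlace` construction plus `disable_static` of the 1.x-style API. A minimal sketch:

```python
import paddle
from paddle.distributed import ParallelEnv

use_gpu = False                # stand-in for cfg.use_gpu
device = 'gpu:{}'.format(ParallelEnv().dev_id) if use_gpu else 'cpu'
place = paddle.set_device(device)
```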
......@@ -18,12 +18,11 @@ from collections import deque
import paddle
from paddle import fluid
from ppdet.core.workspace import load_config, merge_config, create
from ppdet.data.reader import create_reader
from ppdet.utils.stats import TrainingStats
from ppdet.utils.check import check_gpu, check_version, check_config
from ppdet.utils.cli import ArgsParser
from ppdet.utils.checkpoint import load_dygraph_ckpt, save_dygraph_ckpt
import paddle.distributed as dist
from paddle.distributed import ParallelEnv
import logging
FORMAT = '%(asctime)s-%(levelname)s: %(message)s'
logging.basicConfig(level=logging.INFO, format=FORMAT)
......@@ -87,9 +86,8 @@ def parse_args():
return args
def run(FLAGS, cfg):
def run(FLAGS, cfg, place):
env = os.environ
FLAGS.dist = 'PADDLE_TRAINER_ID' in env and 'PADDLE_TRAINERS_NUM' in env
if FLAGS.dist:
trainer_id = int(env['PADDLE_TRAINER_ID'])
......@@ -101,15 +99,20 @@ def run(FLAGS, cfg):
random.seed(0)
np.random.seed(0)
if dist.ParallelEnv().nranks > 1:
if ParallelEnv().nranks > 1:
paddle.distributed.init_parallel_env()
# Data
dataset = cfg.TrainDataset
train_loader, step_per_epoch = create('TrainReader')(
dataset, cfg['worker_num'], place)
# Model
main_arch = cfg.architecture
model = create(cfg.architecture)
# Optimizer
lr = create('LearningRate')()
lr = create('LearningRate')(step_per_epoch / int(ParallelEnv().nranks))
optimizer = create('OptimizerBuilder')(lr, model.parameters())
# Init Model & Optimizer
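`step_per_epoch` is divided by the trainer count because each of the `nranks` data-parallel cards runs only its shard of the loader per epoch, while `lr.step()` is called once per card-local iteration; without the division, epoch milestones would fire `nranks` times too late. Illustrative numbers:

```python
step_per_epoch = 7329                    # full-dataset steps per epoch (illustrative)
nranks = 8                               # data-parallel cards
card_steps = step_per_epoch / nranks     # ~916 iterations each card runs per epoch
boundary = int(card_steps) * 8           # epoch-8 milestone in card-local steps
```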
......@@ -121,73 +124,60 @@ def run(FLAGS, cfg):
load_static_weights=cfg.get('load_static_weights', False))
# Parallel Model
if dist.ParallelEnv().nranks > 1:
if ParallelEnv().nranks > 1:
model = paddle.DataParallel(model)
# Data Reader
# Run Train
start_iter = 0
if cfg.use_gpu:
devices_num = fluid.core.get_cuda_device_count()
else:
devices_num = int(os.environ.get('CPU_NUM', 1))
train_reader = create_reader(
cfg.TrainDataset,
cfg.TrainReader, (cfg.max_iters - start_iter),
cfg,
devices_num=devices_num)
time_stat = deque(maxlen=cfg.log_iter)
start_time = time.time()
end_time = time.time()
# Run Train
for iter_id, data in enumerate(train_reader()):
start_time = end_time
end_time = time.time()
time_stat.append(end_time - start_time)
time_cost = np.mean(time_stat)
eta_sec = (cfg.max_iters - iter_id) * time_cost
eta = str(datetime.timedelta(seconds=int(eta_sec)))
# Model Forward
model.train()
outputs = model(data, cfg['TrainReader']['inputs_def']['fields'],
'train')
# Model Backward
loss = outputs['loss']
if dist.ParallelEnv().nranks > 1:
loss = model.scale_loss(loss)
loss.backward()
model.apply_collective_grads()
else:
loss.backward()
optimizer.minimize(loss)
optimizer.step()
curr_lr = optimizer.get_lr()
lr.step()
optimizer.clear_grad()
if dist.ParallelEnv().nranks < 2 or dist.ParallelEnv().local_rank == 0:
# Log state
if iter_id == 0:
train_stats = TrainingStats(cfg.log_iter, outputs.keys())
train_stats.update(outputs)
logs = train_stats.log()
if iter_id % cfg.log_iter == 0:
ips = float(cfg['TrainReader']['batch_size']) / time_cost
strs = 'iter: {}, lr: {:.6f}, {}, eta: {}, batch_cost: {:.5f} sec, ips: {:.5f} images/sec'.format(
iter_id, curr_lr, logs, eta, time_cost, ips)
logger.info(strs)
# Save Stage
if iter_id > 0 and iter_id % int(
cfg.snapshot_iter) == 0 or iter_id == cfg.max_iters - 1:
cfg_name = os.path.basename(FLAGS.config).split('.')[0]
save_name = str(
iter_id) if iter_id != cfg.max_iters - 1 else "model_final"
save_dir = os.path.join(cfg.save_dir, cfg_name)
save_dygraph_ckpt(model, optimizer, save_dir, save_name)
for e_id in range(int(cfg.epoch)):
for iter_id, data in enumerate(train_loader):
start_time = end_time
end_time = time.time()
time_stat.append(end_time - start_time)
time_cost = np.mean(time_stat)
eta_sec = (cfg.epoch * step_per_epoch - e_id * step_per_epoch - iter_id) * time_cost
eta = str(datetime.timedelta(seconds=int(eta_sec)))
# Model Forward
model.train()
outputs = model(data, cfg['TrainReader']['inputs_def']['fields'],
'train')
# Model Backward
loss = outputs['loss']
if ParallelEnv().nranks > 1:
loss = model.scale_loss(loss)
loss.backward()
model.apply_collective_grads()
else:
loss.backward()
optimizer.minimize(loss)
optimizer.step()
curr_lr = optimizer.get_lr()
lr.step()
optimizer.clear_grad()
if ParallelEnv().nranks < 2 or ParallelEnv().local_rank == 0:
# Log state
if iter_id == 0:
train_stats = TrainingStats(cfg.log_iter, outputs.keys())
train_stats.update(outputs)
logs = train_stats.log()
if iter_id % cfg.log_iter == 0:
strs = 'Epoch: {}, iter: {}, lr: {:.6f}, {}, time: {:.3f}, eta: {}'.format(
e_id, iter_id, curr_lr, logs, time_cost, eta)
logger.info(strs)
# Save Stage
if ParallelEnv().local_rank == 0 and (e_id % cfg.snapshot_epoch == 0 or e_id == int(cfg.epoch) - 1):
cfg_name = os.path.basename(FLAGS.config).split('.')[0]
save_name = str(e_id + 1) if e_id + 1 != int(
cfg.epoch) else "model_final"
save_dir = os.path.join(cfg.save_dir, cfg_name)
save_dygraph_ckpt(model, optimizer, save_dir, save_name)
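Checkpoints are now written per epoch; the file name is the 1-based epoch index, with the final epoch aliased to model_final. The naming rule, extracted for clarity:

```python
def ckpt_name(e_id, total_epochs):
    # mirrors the Save Stage above: 1-based index, final epoch aliased
    return str(e_id + 1) if e_id + 1 != int(total_epochs) else "model_final"

assert ckpt_name(0, 12) == "1"
assert ckpt_name(11, 12) == "model_final"
```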
def main():
......@@ -199,7 +189,10 @@ def main():
check_gpu(cfg.use_gpu)
check_version()
run(FLAGS, cfg)
place = 'gpu:{}'.format(ParallelEnv().dev_id) if cfg.use_gpu else 'cpu'
place = paddle.set_device(place)
run(FLAGS, cfg, place)
if __name__ == "__main__":
......