diff --git a/dygraph/datasets/__init__.py b/dygraph/datasets/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..77dae995267f45ed2551b38ee6b1daf396fc0b8a --- /dev/null +++ b/dygraph/datasets/__init__.py @@ -0,0 +1,16 @@ +# coding: utf8 +# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .dataset import Dataset diff --git a/dygraph/datasets/dataset.py b/dygraph/datasets/dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..07bfdebd4ac106c0c8e28726590d3ec3aeb148a8 --- /dev/null +++ b/dygraph/datasets/dataset.py @@ -0,0 +1,275 @@ +# coding: utf8 +# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
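Note: the dataset.py module added below couples user-supplied transforms with a threaded or multiprocess batch reader. A minimal consumption sketch, assuming the transforms package from this same patch and placeholder dataset paths (the real wiring appears in dygraph/main.py later in the diff):

# Illustrative only -- the paths below are placeholders, not part of the patch.
import transforms
from datasets import Dataset

train_transforms = transforms.Compose([
    transforms.Resize((192, 192)),
    transforms.RandomHorizontalFlip(),
    transforms.Normalize()
])

train_dataset = Dataset(
    data_dir='path/to/dataset',            # hypothetical path
    file_list='path/to/train_list.txt',    # hypothetical path
    transforms=train_transforms,
    num_workers='auto',
    parallel_method='thread',
    shuffle=True)

# generator() returns a reader callable; calling it yields lists of
# `batch_size` transformed samples per iteration.
data_generator = train_dataset.generator(batch_size=2, drop_last=True)
for step, batch in enumerate(data_generator()):
    # Each element of `batch` is the tuple returned by the last transform op;
    # the model code appends an Arrange op so this becomes an (image, label) pair.
    pass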
+ +import os.path as osp +from threading import Thread +import multiprocessing +import collections +import numpy as np +import six +import sys +import copy +import random +import platform +import chardet +import utils.logging as logging + + +class EndSignal(): + pass + + +def is_pic(img_name): + valid_suffix = ['JPEG', 'jpeg', 'JPG', 'jpg', 'BMP', 'bmp', 'PNG', 'png'] + suffix = img_name.split('.')[-1] + if suffix not in valid_suffix: + return False + return True + + +def is_valid(sample): + if sample is None: + return False + if isinstance(sample, tuple): + for s in sample: + if s is None: + return False + elif isinstance(s, np.ndarray) and s.size == 0: + return False + elif isinstance(s, collections.Sequence) and len(s) == 0: + return False + return True + + +def get_encoding(path): + f = open(path, 'rb') + data = f.read() + file_encoding = chardet.detect(data).get('encoding') + return file_encoding + + +def multithread_reader(mapper, + reader, + num_workers=4, + buffer_size=1024, + batch_size=8, + drop_last=True): + from queue import Queue + end = EndSignal() + + # define a worker to read samples from reader to in_queue + def read_worker(reader, in_queue): + for i in reader(): + in_queue.put(i) + in_queue.put(end) + + # define a worker to handle samples from in_queue by mapper + # and put mapped samples into out_queue + def handle_worker(in_queue, out_queue, mapper): + sample = in_queue.get() + while not isinstance(sample, EndSignal): + if len(sample) == 2: + r = mapper(sample[0], sample[1]) + elif len(sample) == 3: + r = mapper(sample[0], sample[1], sample[2]) + else: + raise Exception('The sample\'s length must be 2 or 3.') + if is_valid(r): + out_queue.put(r) + sample = in_queue.get() + in_queue.put(end) + out_queue.put(end) + + def xreader(): + in_queue = Queue(buffer_size) + out_queue = Queue(buffer_size) + # start a read worker in a thread + target = read_worker + t = Thread(target=target, args=(reader, in_queue)) + t.daemon = True + t.start() + # start several handle_workers + target = handle_worker + args = (in_queue, out_queue, mapper) + workers = [] + for i in range(num_workers): + worker = Thread(target=target, args=args) + worker.daemon = True + workers.append(worker) + for w in workers: + w.start() + + batch_data = [] + sample = out_queue.get() + while not isinstance(sample, EndSignal): + batch_data.append(sample) + if len(batch_data) == batch_size: + yield batch_data + batch_data = [] + sample = out_queue.get() + finish = 1 + while finish < num_workers: + sample = out_queue.get() + if isinstance(sample, EndSignal): + finish += 1 + else: + batch_data.append(sample) + if len(batch_data) == batch_size: + yield batch_data + batch_data = [] + if not drop_last and len(batch_data) != 0: + yield batch_data + batch_data = [] + + return xreader + + +def multiprocess_reader(mapper, + reader, + num_workers=4, + buffer_size=1024, + batch_size=8, + drop_last=True): + from .shared_queue import SharedQueue as Queue + + def _read_into_queue(samples, mapper, queue): + end = EndSignal() + try: + for sample in samples: + if sample is None: + raise ValueError("sample has None") + if len(sample) == 2: + result = mapper(sample[0], sample[1]) + elif len(sample) == 3: + result = mapper(sample[0], sample[1], sample[2]) + else: + raise Exception('The sample\'s length must be 2 or 3.') + if is_valid(result): + queue.put(result) + queue.put(end) + except: + queue.put("") + six.reraise(*sys.exc_info()) + + def queue_reader(): + queue = Queue(buffer_size, memsize=3 * 1024**3) + total_samples = [[] for i 
in range(num_workers)]
+        for i, sample in enumerate(reader()):
+            index = i % num_workers
+            total_samples[index].append(sample)
+        for i in range(num_workers):
+            p = multiprocessing.Process(
+                target=_read_into_queue, args=(total_samples[i], mapper, queue))
+            p.start()
+
+        finish_num = 0
+        batch_data = list()
+        while finish_num < num_workers:
+            sample = queue.get()
+            if isinstance(sample, EndSignal):
+                finish_num += 1
+            elif sample == "":
+                raise ValueError("multiprocess reader raised an exception")
+            else:
+                batch_data.append(sample)
+                if len(batch_data) == batch_size:
+                    yield batch_data
+                    batch_data = []
+        if len(batch_data) != 0 and not drop_last:
+            yield batch_data
+            batch_data = []
+
+    return queue_reader
+
+
+class Dataset:
+    def __init__(self,
+                 data_dir,
+                 file_list,
+                 label_list=None,
+                 transforms=None,
+                 num_workers='auto',
+                 buffer_size=100,
+                 parallel_method='thread',
+                 shuffle=False):
+        if num_workers == 'auto':
+            import multiprocessing as mp
+            num_workers = mp.cpu_count() // 2 if mp.cpu_count() // 2 < 8 else 8
+        if transforms is None:
+            raise Exception("transforms must be defined.")
+        self.transforms = transforms
+        self.num_workers = num_workers
+        self.buffer_size = buffer_size
+        self.parallel_method = parallel_method
+        self.shuffle = shuffle
+
+        self.file_list = list()
+        self.labels = list()
+        self._epoch = 0
+
+        if label_list is not None:
+            with open(label_list, encoding=get_encoding(label_list)) as f:
+                for line in f:
+                    item = line.strip()
+                    self.labels.append(item)
+
+        with open(file_list, encoding=get_encoding(file_list)) as f:
+            for line in f:
+                items = line.strip().split()
+                if not is_pic(items[0]):
+                    continue
+                full_path_im = osp.join(data_dir, items[0])
+                full_path_label = osp.join(data_dir, items[1])
+                if not osp.exists(full_path_im):
+                    raise IOError(
+                        'The image file {} does not exist!'.format(full_path_im))
+                if not osp.exists(full_path_label):
+                    raise IOError('The label file {} does not exist!'.format(
+                        full_path_label))
+                self.file_list.append([full_path_im, full_path_label])
+        self.num_samples = len(self.file_list)
+        logging.info("{} samples in file {}".format(
+            len(self.file_list), file_list))
+
+    def iterator(self):
+        self._epoch += 1
+        self._pos = 0
+        files = copy.deepcopy(self.file_list)
+        if self.shuffle:
+            random.shuffle(files)
+        files = files[:self.num_samples]
+        self.num_samples = len(files)
+        for f in files:
+            label_path = f[1]
+            sample = [f[0], None, label_path]
+            yield sample
+
+    def generator(self, batch_size=1, drop_last=True):
+        self.batch_size = batch_size
+        parallel_reader = multithread_reader
+        if self.parallel_method == "process":
+            if platform.platform().startswith("Windows"):
+                logging.debug(
+                    "multiprocess_reader is not supported on Windows; falling back to multithread_reader."
+                )
+            else:
+                parallel_reader = multiprocess_reader
+        return parallel_reader(
+            self.transforms,
+            self.iterator,
+            num_workers=self.num_workers,
+            buffer_size=self.buffer_size,
+            batch_size=batch_size,
+            drop_last=drop_last)
diff --git a/dygraph/datasets/shared_queue/__init__.py b/dygraph/datasets/shared_queue/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..1662f739d491243a56d1cc70c13b7a91dbb46a7e
--- /dev/null
+++ b/dygraph/datasets/shared_queue/__init__.py
@@ -0,0 +1,26 @@
+# coding: utf8
+# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+from __future__ import unicode_literals
+
+__all__ = ['SharedBuffer', 'SharedMemoryMgr', 'SharedQueue']
+
+from .sharedmemory import SharedBuffer
+from .sharedmemory import SharedMemoryMgr
+from .sharedmemory import SharedMemoryError
+from .queue import SharedQueue
diff --git a/dygraph/datasets/shared_queue/queue.py b/dygraph/datasets/shared_queue/queue.py
new file mode 100644
index 0000000000000000000000000000000000000000..7a67f98de700e74e0361ae80cfe983c594f88024
--- /dev/null
+++ b/dygraph/datasets/shared_queue/queue.py
@@ -0,0 +1,103 @@
+# coding: utf8
+# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+from __future__ import unicode_literals
+
+import sys
+import six
+if six.PY3:
+    import pickle
+    from io import BytesIO as StringIO
+else:
+    import cPickle as pickle
+    from cStringIO import StringIO
+
+import logging
+import traceback
+import multiprocessing as mp
+from multiprocessing.queues import Queue
+from .sharedmemory import SharedMemoryMgr
+
+logger = logging.getLogger(__name__)
+
+
+class SharedQueueError(ValueError):
+    """ SharedQueueError
+    """
+    pass
+
+
+class SharedQueue(Queue):
+    """ a Queue based on shared memory to communicate data between processes,
+        and its interface is compatible with 'multiprocessing.queues.Queue'
+    """
+
+    def __init__(self, maxsize=0, mem_mgr=None, memsize=None, pagesize=None):
+        """ init
+        """
+        if six.PY3:
+            super(SharedQueue, self).__init__(maxsize, ctx=mp.get_context())
+        else:
+            super(SharedQueue, self).__init__(maxsize)
+
+        if mem_mgr is not None:
+            self._shared_mem = mem_mgr
+        else:
+            self._shared_mem = SharedMemoryMgr(
+                capacity=memsize, pagesize=pagesize)
+
+    def put(self, obj, **kwargs):
+        """ put an object to this queue
+        """
+        obj = pickle.dumps(obj, -1)
+        buff = None
+        try:
+            buff = self._shared_mem.malloc(len(obj))
+            buff.put(obj)
+            super(SharedQueue, self).put(buff, **kwargs)
+        except Exception as e:
+            stack_info = traceback.format_exc()
+            err_msg = 'failed to put an element to SharedQueue '\
+                'with stack info[%s]' % (stack_info)
+            logger.warn(err_msg)
+
+            if buff is not None:
+                buff.free()
+            raise e
+
+    def get(self, **kwargs):
+        """ get an object from this queue
+        """
+        buff = None
+        try:
+            buff = super(SharedQueue, self).get(**kwargs)
+            data = buff.get()
+            return pickle.load(StringIO(data))
+        except Exception as e:
+            stack_info =
traceback.format_exc() + err_msg = 'failed to get element from SharedQueue '\ + 'with stack info[%s]' % (stack_info) + logger.warn(err_msg) + raise e + finally: + if buff is not None: + buff.free() + + def release(self): + self._shared_mem.release() + self._shared_mem = None diff --git a/dygraph/datasets/shared_queue/sharedmemory.py b/dygraph/datasets/shared_queue/sharedmemory.py new file mode 100644 index 0000000000000000000000000000000000000000..8df13752194575628d173ac6c96d40ae404590b5 --- /dev/null +++ b/dygraph/datasets/shared_queue/sharedmemory.py @@ -0,0 +1,532 @@ +# coding: utf8 +# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function +from __future__ import unicode_literals + +import os +import time +import math +import struct +import sys +import six + +if six.PY3: + import pickle +else: + import cPickle as pickle + +import json +import uuid +import random +import numpy as np +import weakref +import logging +from multiprocessing import Lock +from multiprocessing import RawArray + +logger = logging.getLogger(__name__) + + +class SharedMemoryError(ValueError): + """ SharedMemoryError + """ + pass + + +class SharedBufferError(SharedMemoryError): + """ SharedBufferError + """ + pass + + +class MemoryFullError(SharedMemoryError): + """ MemoryFullError + """ + + def __init__(self, errmsg=''): + super(MemoryFullError, self).__init__() + self.errmsg = errmsg + + +def memcopy(dst, src, offset=0, length=None): + """ copy data from 'src' to 'dst' in bytes + """ + length = length if length is not None else len(src) + assert type(dst) == np.ndarray, 'invalid type for "dst" in memcopy' + if type(src) is not np.ndarray: + if type(src) is str and six.PY3: + src = src.encode() + src = np.frombuffer(src, dtype='uint8', count=len(src)) + + dst[:] = src[offset:offset + length] + + +class SharedBuffer(object): + """ Buffer allocated from SharedMemoryMgr, and it stores data on shared memory + + note that: + every instance of this should be freed explicitely by calling 'self.free' + """ + + def __init__(self, owner, capacity, pos, size=0, alloc_status=''): + """ Init + + Args: + owner (str): manager to own this buffer + capacity (int): capacity in bytes for this buffer + pos (int): page position in shared memory + size (int): bytes already used + alloc_status (str): debug info about allocator when allocate this + """ + self._owner = owner + self._cap = capacity + self._pos = pos + self._size = size + self._alloc_status = alloc_status + assert self._pos >= 0 and self._cap > 0, \ + "invalid params[%d:%d] to construct SharedBuffer" \ + % (self._pos, self._cap) + + def owner(self): + """ get owner + """ + return SharedMemoryMgr.get_mgr(self._owner) + + def put(self, data, override=False): + """ put data to this buffer + + Args: + data (str): data to be stored in this buffer + + Returns: + None + + Raises: + SharedMemoryError when not enough space 
in this buffer + """ + assert type(data) in [str, bytes], \ + 'invalid type[%s] for SharedBuffer::put' % (str(type(data))) + if self._size > 0 and not override: + raise SharedBufferError('already has already been setted before') + + if self.capacity() < len(data): + raise SharedBufferError('data[%d] is larger than size of buffer[%s]'\ + % (len(data), str(self))) + + self.owner().put_data(self, data) + self._size = len(data) + + def get(self, offset=0, size=None, no_copy=True): + """ get the data stored this buffer + + Args: + offset (int): position for the start point to 'get' + size (int): size to get + + Returns: + data (np.ndarray('uint8')): user's data in numpy + which is passed in by 'put' + None: if no data stored in + """ + offset = offset if offset >= 0 else self._size + offset + if self._size <= 0: + return None + + size = self._size if size is None else size + assert offset + size <= self._cap, 'invalid offset[%d] '\ + 'or size[%d] for capacity[%d]' % (offset, size, self._cap) + return self.owner().get_data(self, offset, size, no_copy=no_copy) + + def size(self): + """ bytes of used memory + """ + return self._size + + def resize(self, size): + """ resize the used memory to 'size', should not be greater than capacity + """ + assert size >= 0 and size <= self._cap, \ + "invalid size[%d] for resize" % (size) + + self._size = size + + def capacity(self): + """ size of allocated memory + """ + return self._cap + + def __str__(self): + """ human readable format + """ + return "SharedBuffer(owner:%s, pos:%d, size:%d, "\ + "capacity:%d, alloc_status:[%s], pid:%d)" \ + % (str(self._owner), self._pos, self._size, \ + self._cap, self._alloc_status, os.getpid()) + + def free(self): + """ free this buffer to it's owner + """ + if self._owner is not None: + self.owner().free(self) + self._owner = None + self._cap = 0 + self._pos = -1 + self._size = 0 + return True + else: + return False + + +class PageAllocator(object): + """ allocator used to malloc and free shared memory which + is split into pages + """ + s_allocator_header = 12 + + def __init__(self, base, total_pages, page_size): + """ init + """ + self._magic_num = 1234321000 + random.randint(100, 999) + self._base = base + self._total_pages = total_pages + self._page_size = page_size + + header_pages = int( + math.ceil((total_pages + self.s_allocator_header) / page_size)) + + self._header_pages = header_pages + self._free_pages = total_pages - header_pages + self._header_size = self._header_pages * page_size + self._reset() + + def _dump_alloc_info(self, fname): + hpages, tpages, pos, used = self.header() + + start = self.s_allocator_header + end = start + self._page_size * hpages + alloc_flags = self._base[start:end].tostring() + info = { + 'magic_num': self._magic_num, + 'header_pages': hpages, + 'total_pages': tpages, + 'pos': pos, + 'used': used + } + info['alloc_flags'] = alloc_flags + fname = fname + '.' 
+ str(uuid.uuid4())[:6] + with open(fname, 'wb') as f: + f.write(pickle.dumps(info, -1)) + logger.warn('dump alloc info to file[%s]' % (fname)) + + def _reset(self): + alloc_page_pos = self._header_pages + used_pages = self._header_pages + header_info = struct.pack( + str('III'), self._magic_num, alloc_page_pos, used_pages) + assert len(header_info) == self.s_allocator_header, \ + 'invalid size of header_info' + + memcopy(self._base[0:self.s_allocator_header], header_info) + self.set_page_status(0, self._header_pages, '1') + self.set_page_status(self._header_pages, self._free_pages, '0') + + def header(self): + """ get header info of this allocator + """ + header_str = self._base[0:self.s_allocator_header].tostring() + magic, pos, used = struct.unpack(str('III'), header_str) + + assert magic == self._magic_num, \ + 'invalid header magic[%d] in shared memory' % (magic) + return self._header_pages, self._total_pages, pos, used + + def empty(self): + """ are all allocatable pages available + """ + header_pages, pages, pos, used = self.header() + return header_pages == used + + def full(self): + """ are all allocatable pages used + """ + header_pages, pages, pos, used = self.header() + return header_pages + used == pages + + def __str__(self): + header_pages, pages, pos, used = self.header() + desc = '{page_info[magic:%d,total:%d,used:%d,header:%d,alloc_pos:%d,pagesize:%d]}' \ + % (self._magic_num, pages, used, header_pages, pos, self._page_size) + return 'PageAllocator:%s' % (desc) + + def set_alloc_info(self, alloc_pos, used_pages): + """ set allocating position to new value + """ + memcopy(self._base[4:12], struct.pack(str('II'), alloc_pos, used_pages)) + + def set_page_status(self, start, page_num, status): + """ set pages from 'start' to 'end' with new same status 'status' + """ + assert status in ['0', '1'], 'invalid status[%s] for page status '\ + 'in allocator[%s]' % (status, str(self)) + start += self.s_allocator_header + end = start + page_num + assert start >= 0 and end <= self._header_size, 'invalid end[%d] of pages '\ + 'in allocator[%s]' % (end, str(self)) + memcopy(self._base[start:end], str(status * page_num)) + + def get_page_status(self, start, page_num, ret_flag=False): + start += self.s_allocator_header + end = start + page_num + assert start >= 0 and end <= self._header_size, 'invalid end[%d] of pages '\ + 'in allocator[%s]' % (end, str(self)) + status = self._base[start:end].tostring().decode() + if ret_flag: + return status + + zero_num = status.count('0') + if zero_num == 0: + return (page_num, 1) + else: + return (zero_num, 0) + + def malloc_page(self, page_num): + header_pages, pages, pos, used = self.header() + end = pos + page_num + if end > pages: + pos = self._header_pages + end = pos + page_num + + start_pos = pos + flags = '' + while True: + # maybe flags already has some '0' pages, + # so just check 'page_num - len(flags)' pages + flags = self.get_page_status(pos, page_num, ret_flag=True) + + if flags.count('0') == page_num: + break + + # not found enough pages, so shift to next few pages + free_pos = flags.rfind('1') + 1 + pos += free_pos + end = pos + page_num + if end > pages: + pos = self._header_pages + end = pos + page_num + flags = '' + + # not found available pages after scan all pages + if pos <= start_pos and end >= start_pos: + logger.debug('not found available pages after scan all pages') + break + + page_status = (flags.count('0'), 0) + if page_status != (page_num, 0): + free_pages = self._total_pages - used + if free_pages == 0: + err_msg = 'all 
pages have been used:%s' % (str(self)) + else: + err_msg = 'not found available pages with page_status[%s] '\ + 'and %d free pages' % (str(page_status), free_pages) + err_msg = 'failed to malloc %d pages at pos[%d] for reason[%s] and allocator status[%s]' \ + % (page_num, pos, err_msg, str(self)) + raise MemoryFullError(err_msg) + + self.set_page_status(pos, page_num, '1') + used += page_num + self.set_alloc_info(end, used) + return pos + + def free_page(self, start, page_num): + """ free 'page_num' pages start from 'start' + """ + page_status = self.get_page_status(start, page_num) + assert page_status == (page_num, 1), \ + 'invalid status[%s] when free [%d, %d]' \ + % (str(page_status), start, page_num) + self.set_page_status(start, page_num, '0') + _, _, pos, used = self.header() + used -= page_num + self.set_alloc_info(pos, used) + + +DEFAULT_SHARED_MEMORY_SIZE = 1024 * 1024 * 1024 + + +class SharedMemoryMgr(object): + """ manage a continouse block of memory, provide + 'malloc' to allocate new buffer, and 'free' to free buffer + """ + s_memory_mgrs = weakref.WeakValueDictionary() + s_mgr_num = 0 + s_log_statis = False + + @classmethod + def get_mgr(cls, id): + """ get a SharedMemoryMgr with size of 'capacity' + """ + assert id in cls.s_memory_mgrs, 'invalid id[%s] for memory managers' % ( + id) + return cls.s_memory_mgrs[id] + + def __init__(self, capacity=None, pagesize=None): + """ init + """ + logger.debug('create SharedMemoryMgr') + + pagesize = 64 * 1024 if pagesize is None else pagesize + assert type(pagesize) is int, "invalid type of pagesize[%s]" \ + % (str(pagesize)) + + capacity = DEFAULT_SHARED_MEMORY_SIZE if capacity is None else capacity + assert type(capacity) is int, "invalid type of capacity[%s]" \ + % (str(capacity)) + + assert capacity > 0, '"size of shared memory should be greater than 0' + self._released = False + self._cap = capacity + self._page_size = pagesize + + assert self._cap % self._page_size == 0, \ + "capacity[%d] and pagesize[%d] are not consistent" \ + % (self._cap, self._page_size) + self._total_pages = self._cap // self._page_size + + self._pid = os.getpid() + SharedMemoryMgr.s_mgr_num += 1 + self._id = self._pid * 100 + SharedMemoryMgr.s_mgr_num + SharedMemoryMgr.s_memory_mgrs[self._id] = self + self._locker = Lock() + self._setup() + + def _setup(self): + self._shared_mem = RawArray('c', self._cap) + self._base = np.frombuffer( + self._shared_mem, dtype='uint8', count=self._cap) + self._locker.acquire() + try: + self._allocator = PageAllocator(self._base, self._total_pages, + self._page_size) + finally: + self._locker.release() + + def malloc(self, size, wait=True): + """ malloc a new SharedBuffer + + Args: + size (int): buffer size to be malloc + wait (bool): whether to wait when no enough memory + + Returns: + SharedBuffer + + Raises: + SharedMemoryError when not found available memory + """ + page_num = int(math.ceil(size / self._page_size)) + size = page_num * self._page_size + + start = None + ct = 0 + errmsg = '' + while True: + self._locker.acquire() + try: + start = self._allocator.malloc_page(page_num) + alloc_status = str(self._allocator) + except MemoryFullError as e: + start = None + errmsg = e.errmsg + if not wait: + raise e + finally: + self._locker.release() + + if start is None: + time.sleep(0.1) + if ct % 100 == 0: + logger.warn('not enough space for reason[%s]' % (errmsg)) + + ct += 1 + else: + break + + return SharedBuffer(self._id, size, start, alloc_status=alloc_status) + + def free(self, shared_buf): + """ free a SharedBuffer + 
+ Args: + shared_buf (SharedBuffer): buffer to be freed + + Returns: + None + + Raises: + SharedMemoryError when failed to release this buffer + """ + assert shared_buf._owner == self._id, "invalid shared_buf[%s] "\ + "for it's not allocated from me[%s]" % (str(shared_buf), str(self)) + cap = shared_buf.capacity() + start_page = shared_buf._pos + page_num = cap // self._page_size + + #maybe we don't need this lock here + self._locker.acquire() + try: + self._allocator.free_page(start_page, page_num) + finally: + self._locker.release() + + def put_data(self, shared_buf, data): + """ fill 'data' into 'shared_buf' + """ + assert len(data) <= shared_buf.capacity(), 'too large data[%d] '\ + 'for this buffer[%s]' % (len(data), str(shared_buf)) + start = shared_buf._pos * self._page_size + end = start + len(data) + assert start >= 0 and end <= self._cap, "invalid start "\ + "position[%d] when put data to buff:%s" % (start, str(shared_buf)) + self._base[start:end] = np.frombuffer(data, 'uint8', len(data)) + + def get_data(self, shared_buf, offset, size, no_copy=True): + """ extract 'data' from 'shared_buf' in range [offset, offset + size) + """ + start = shared_buf._pos * self._page_size + start += offset + if no_copy: + return self._base[start:start + size] + else: + return self._base[start:start + size].tostring() + + def __str__(self): + return 'SharedMemoryMgr:{id:%d, %s}' % (self._id, str(self._allocator)) + + def __del__(self): + if SharedMemoryMgr.s_log_statis: + logger.info('destroy [%s]' % (self)) + + if not self._released and not self._allocator.empty(): + logger.debug( + 'not empty when delete this SharedMemoryMgr[%s]' % (self)) + else: + self._released = True + + if self._id in SharedMemoryMgr.s_memory_mgrs: + del SharedMemoryMgr.s_memory_mgrs[self._id] + SharedMemoryMgr.s_mgr_num -= 1 diff --git a/dygraph/main.py b/dygraph/main.py new file mode 100644 index 0000000000000000000000000000000000000000..aeb861fb8ed10427657f2c3aee6fb04ef18321fe --- /dev/null +++ b/dygraph/main.py @@ -0,0 +1,46 @@ +from datasets import Dataset +import transforms +import paddle.fluid as fluid +from models import UNet + +data_dir = '/ssd1/chenguowei01/dataset/optic_disc_seg' +train_list = '/ssd1/chenguowei01/dataset/optic_disc_seg/train_list.txt' +val_list = '/ssd1/chenguowei01/dataset/optic_disc_seg/val_list.txt' +img_file = data_dir + '/JPEGImages/H0005.jpg' + +train_transforms = transforms.Compose([ + transforms.Resize((192, 192)), + transforms.RandomHorizontalFlip(), + transforms.Normalize() +]) + +train_dataset = Dataset( + data_dir=data_dir, + file_list=train_list, + transforms=train_transforms, + num_workers='auto', + buffer_size=100, + parallel_method='thread', + shuffle=True) + +eval_transforms = transforms.Compose( + [transforms.Resize((192, 192)), + transforms.Normalize()]) + +eval_dataset = Dataset( + data_dir=data_dir, + file_list=val_list, + transforms=eval_transforms, + num_workers='auto', + buffer_size=100, + parallel_method='thread', + shuffle=True) + +model = UNet(num_classes=2) +with fluid.dygraph.guard(model.places): + model.build_model() + #model.load_model('output/epoch_10/') + model.train( + num_epochs=10, train_dataset=train_dataset, eval_dataset=eval_dataset) + model.evaluate(eval_dataset) + model.predict(img_file, eval_transforms) diff --git a/dygraph/models/__init__.py b/dygraph/models/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..c5bb8150e5af38deeee37716e345aff174c9064e --- /dev/null +++ b/dygraph/models/__init__.py @@ -0,0 +1,16 @@ +# coding: utf8 
+# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .unet import UNet diff --git a/dygraph/models/unet.py b/dygraph/models/unet.py new file mode 100644 index 0000000000000000000000000000000000000000..f4989e699aa18af18662e02dbb2611a86b7dfa86 --- /dev/null +++ b/dygraph/models/unet.py @@ -0,0 +1,307 @@ +# coding: utf8 +# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import absolute_import +import paddle.fluid as fluid +import os +from os import path as osp +import numpy as np +from collections import OrderedDict +import copy +import math +import time +import tqdm +import cv2 +import yaml +import shutil + +from paddle.fluid.dygraph.base import to_variable + +import utils +import utils.logging as logging +from utils import seconds_to_hms +from utils import ConfusionMatrix +from utils import get_environ_info +import nets +import transforms as T + + +def dict2str(dict_input): + out = '' + for k, v in dict_input.items(): + try: + v = round(float(v), 6) + except: + pass + out = out + '{}={}, '.format(k, v) + return out.strip(', ') + + +class UNet(object): + # DeepLab mobilenet + def __init__(self, + num_classes=2, + upsample_mode='bilinear', + ignore_index=255): + + self.num_classes = num_classes + self.upsample_mode = upsample_mode + self.ignore_index = ignore_index + + self.labels = None + self.env_info = get_environ_info() + if self.env_info['place'] == 'cpu': + self.places = fluid.CPUPlace() + else: + self.places = fluid.CUDAPlace(0) + + def build_model(self): + self.model = nets.UNet(self.num_classes, self.upsample_mode) + + def arrange_transform(self, transforms, mode='train'): + arrange_transform = T.ArrangeSegmenter + if type(transforms.transforms[-1]).__name__.startswith('Arrange'): + transforms.transforms[-1] = arrange_transform(mode=mode) + else: + transforms.transforms.append(arrange_transform(mode=mode)) + + def load_model(self, model_dir): + ckpt_path = osp.join(model_dir, 'model') + para_state_dict, opti_state_dict = fluid.load_dygraph(ckpt_path) + self.model.set_dict(para_state_dict) + + def save_model(self, state_dict, save_dir): + if not osp.isdir(save_dir): + if osp.exists(save_dir): + os.remove(save_dir) + os.makedirs(save_dir) + fluid.save_dygraph(state_dict, osp.join(save_dir, 'model')) + + def default_optimizer(self, + learning_rate, + num_epochs, + num_steps_each_epoch, + parameter_list=None, + lr_decay_power=0.9, + 
regularization_coeff=4e-5): + decay_step = num_epochs * num_steps_each_epoch + lr_decay = fluid.layers.polynomial_decay( + learning_rate, + decay_step, + end_learning_rate=0, + power=lr_decay_power) + optimizer = fluid.optimizer.Momentum( + lr_decay, + momentum=0.9, + parameter_list=parameter_list, + regularization=fluid.regularizer.L2Decay( + regularization_coeff=regularization_coeff)) + return optimizer + + def _get_loss(self, logit, label): + mask = label != self.ignore_index + mask = fluid.layers.cast(mask, 'float32') + loss, probs = fluid.layers.softmax_with_cross_entropy( + logit, + label, + ignore_index=self.ignore_index, + return_softmax=True, + axis=1) + + loss = loss * mask + avg_loss = fluid.layers.mean(loss) / (fluid.layers.mean(mask) + 0.00001) + + label.stop_gradient = True + mask.stop_gradient = True + return avg_loss + + def train(self, + num_epochs, + train_dataset, + train_batch_size=2, + eval_dataset=None, + save_interval_epochs=1, + log_interval_steps=2, + save_dir='output', + pretrained_weights=None, + resume_weights=None, + optimizer=None, + learning_rate=0.01, + lr_decay_power=0.9, + regularization_coeff=4e-5, + use_vdl=False): + self.labels = train_dataset.labels + self.train_transforms = train_dataset.transforms + self.train_init = locals() + self.begin_epoch = 0 + if optimizer is None: + num_steps_each_epoch = train_dataset.num_samples // train_batch_size + optimizer = self.default_optimizer( + learning_rate=learning_rate, + num_epochs=num_epochs, + num_steps_each_epoch=num_steps_each_epoch, + parameter_list=self.model.parameters(), + lr_decay_power=lr_decay_power, + regularization_coeff=regularization_coeff) + + # to do: 预训练模型加载, resume + + if self.begin_epoch >= num_epochs: + raise ValueError( + ("begin epoch[{}] is larger than num_epochs[{}]").format( + self.begin_epoch, num_epochs)) + + if not osp.isdir(save_dir): + if osp.exists(save_dir): + os.remove(save_dir) + os.makedirs(save_dir) + + # add arrange op to transforms + self.arrange_transform( + transforms=train_dataset.transforms, mode='train') + + if eval_dataset is not None: + self.eval_transforms = eval_dataset.transforms + self.test_transforms = copy.deepcopy(eval_dataset.transforms) + + data_generator = train_dataset.generator( + batch_size=train_batch_size, drop_last=True) + total_num_steps = math.floor( + train_dataset.num_samples / train_batch_size) + + for i in range(self.begin_epoch, num_epochs): + for step, data in enumerate(data_generator()): + images = np.array([d[0] for d in data]) + labels = np.array([d[1] for d in data]).astype('int64') + images = to_variable(images) + labels = to_variable(labels) + logit = self.model(images) + loss = self._get_loss(logit, labels) + loss.backward() + optimizer.minimize(loss) + print("[TRAIN] Epoch={}/{}, Step={}/{}, loss={}".format( + i + 1, num_epochs, step + 1, total_num_steps, loss.numpy())) + + if (i + 1) % save_interval_epochs == 0 or i == num_epochs - 1: + current_save_dir = osp.join(save_dir, "epoch_{}".format(i + 1)) + if not osp.isdir(current_save_dir): + os.makedirs(current_save_dir) + self.save_model(self.model.state_dict(), current_save_dir) + if eval_dataset is not None: + self.model.eval() + self.evaluate(eval_dataset, batch_size=train_batch_size) + self.model.train() + + def evaluate(self, eval_dataset, batch_size=1, epoch_id=None): + """评估。 + + Args: + eval_dataset (paddlex.datasets): 评估数据读取器。 + batch_size (int): 评估时的batch大小。默认1。 + epoch_id (int): 当前评估模型所在的训练轮数。 + return_details (bool): 是否返回详细信息。默认False。 + + Returns: + dict: 
当return_details为False时,返回dict。包含关键字:'miou'、'category_iou'、'macc'、 + 'category_acc'和'kappa',分别表示平均iou、各类别iou、平均准确率、各类别准确率和kappa系数。 + tuple (metrics, eval_details):当return_details为True时,增加返回dict (eval_details), + 包含关键字:'confusion_matrix',表示评估的混淆矩阵。 + """ + self.model.eval() + self.arrange_transform(transforms=eval_dataset.transforms, mode='train') + total_steps = math.ceil(eval_dataset.num_samples * 1.0 / batch_size) + conf_mat = ConfusionMatrix(self.num_classes, streaming=True) + data_generator = eval_dataset.generator( + batch_size=batch_size, drop_last=False) + logging.info( + "Start to evaluating(total_samples={}, total_steps={})...".format( + eval_dataset.num_samples, total_steps)) + for step, data in tqdm.tqdm( + enumerate(data_generator()), total=total_steps): + images = np.array([d[0] for d in data]) + labels = np.array([d[1] for d in data]) + images = to_variable(images) + + logit = self.model(images) + pred = fluid.layers.argmax(logit, axis=1) + pred = fluid.layers.unsqueeze(pred, axes=[3]) + pred = pred.numpy() + + mask = labels != self.ignore_index + conf_mat.calculate(pred=pred, label=labels, ignore=mask) + _, iou = conf_mat.mean_iou() + + logging.debug("[EVAL] Epoch={}, Step={}/{}, iou={}".format( + epoch_id, step + 1, total_steps, iou)) + + category_iou, miou = conf_mat.mean_iou() + category_acc, macc = conf_mat.accuracy() + + metrics = OrderedDict( + zip(['miou', 'category_iou', 'macc', 'category_acc', 'kappa'], + [miou, category_iou, macc, category_acc, + conf_mat.kappa()])) + + logging.info('[EVAL] Finished, Epoch={}, {} .'.format( + epoch_id, dict2str(metrics))) + return metrics + + def predict(self, im_file, transforms=None): + """预测。 + Args: + img_file(str|np.ndarray): 预测图像。 + transforms(paddlex.cv.transforms): 数据预处理操作。 + + Returns: + dict: 包含关键字'label_map'和'score_map', 'label_map'存储预测结果灰度图, + 像素值表示对应的类别,'score_map'存储各类别的概率,shape=(h, w, num_classes) + """ + if isinstance(im_file, str): + if not osp.exists(im_file): + raise ValueError( + 'The Image file does not exist: {}'.format(im_file)) + + if transforms is None and not hasattr(self, 'test_transforms'): + raise Exception("transforms need to be defined, now is None.") + if transforms is not None: + self.arrange_transform(transforms=transforms, mode='test') + im, im_info = transforms(im_file) + else: + self.arrange_transform(transforms=self.test_transforms, mode='test') + im, im_info = self.test_transforms(im_file) + im = np.expand_dims(im, axis=0) + im = to_variable(im) + logit = self.model(im) + logit = fluid.layers.softmax(logit) + pred = fluid.layers.argmax(logit, axis=1) + logit = logit.numpy() + pred = pred.numpy() + + logit = np.squeeze(logit) + logit = np.transpose(logit, (1, 2, 0)) + pred = np.squeeze(pred).astype('uint8') + keys = list(im_info.keys()) + print(pred.shape, logit.shape) + for k in keys[::-1]: + if k == 'shape_before_resize': + h, w = im_info[k][0], im_info[k][1] + pred = cv2.resize(pred, (w, h), cv2.INTER_NEAREST) + logit = cv2.resize(logit, (w, h), cv2.INTER_LINEAR) + elif k == 'shape_before_padding': + h, w = im_info[k][0], im_info[k][1] + pred = pred[0:h, 0:w] + logit = logit[0:h, 0:w, :] + + return {'label_map': pred, 'score_map': logit} diff --git a/dygraph/nets/__init__.py b/dygraph/nets/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..26d9fb22e1a39ed75217e0d82cd1b0f4eccb6e4c --- /dev/null +++ b/dygraph/nets/__init__.py @@ -0,0 +1,16 @@ +# coding: utf8 +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .unet import UNet diff --git a/dygraph/nets/unet.py b/dygraph/nets/unet.py new file mode 100644 index 0000000000000000000000000000000000000000..5f5ed3b895dceb6ba44afff999280a4a6ccbf9cd --- /dev/null +++ b/dygraph/nets/unet.py @@ -0,0 +1,220 @@ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from collections import OrderedDict + +import paddle.fluid as fluid +from paddle.fluid.dygraph import Conv2D, BatchNorm, Pool2D +import contextlib + +regularizer = fluid.regularizer.L2DecayRegularizer(regularization_coeff=0.0) +name_scope = "" + + +@contextlib.contextmanager +def scope(name): + global name_scope + bk = name_scope + name_scope = name_scope + name + '/' + yield + name_scope = bk + + +class UNet(fluid.dygraph.Layer): + def __init__( + self, + num_classes, + upsample_mode='bilinear', + ): + super().__init__() + self.encode = Encoder() + self.decode = Decode(upsample_mode=upsample_mode) + self.get_logit = GetLogit(64, num_classes) + + def forward(self, x): + encode_data, short_cuts = self.encode(x) + decode_data = self.decode(encode_data, short_cuts) + logit = self.get_logit(decode_data) + return logit + + +class Encoder(fluid.dygraph.Layer): + def __init__(self): + super().__init__() + with scope('encode'): + with scope('block1'): + self.double_conv = DoubleConv(3, 64) + with scope('block1'): + self.down1 = Down(64, 128) + with scope('block2'): + self.down2 = Down(128, 256) + with scope('block3'): + self.down3 = Down(256, 512) + with scope('block4'): + self.down4 = Down(512, 512) + + def forward(self, x): + short_cuts = [] + x = self.double_conv(x) + short_cuts.append(x) + x = self.down1(x) + short_cuts.append(x) + x = self.down2(x) + short_cuts.append(x) + x = self.down3(x) + short_cuts.append(x) + x = self.down4(x) + return x, short_cuts + + +class Decode(fluid.dygraph.Layer): + def __init__(self, upsample_mode='bilinear'): + super().__init__() + with scope('decode'): + with scope('decode1'): + self.up1 = Up(512, 256, upsample_mode) + with scope('decode2'): + self.up2 = Up(256, 128, upsample_mode) + with scope('decode3'): + self.up3 = Up(128, 64, upsample_mode) + with scope('decode4'): + self.up4 = Up(64, 64, upsample_mode) + + def forward(self, x, short_cuts): + x = self.up1(x, short_cuts[3]) + x = self.up2(x, short_cuts[2]) + x = self.up3(x, short_cuts[1]) + x = self.up4(x, short_cuts[0]) + return x + + +class GetLogit(fluid.dygraph.Layer): + def __init__(self): + super().__init__() + + +class DoubleConv(fluid.dygraph.Layer): + def __init__(self, num_channels, num_filters): + super().__init__() + with scope('conv0'): + param_attr = fluid.ParamAttr( + name=name_scope + 'weights', + regularizer=regularizer, + initializer=fluid.initializer.TruncatedNormal( + loc=0.0, scale=0.33)) + self.conv0 = Conv2D( + num_channels=num_channels, + num_filters=num_filters, + filter_size=3, + stride=1, + padding=1, + param_attr=param_attr) + self.bn0 = BatchNorm( + 
num_channels=num_filters, + param_attr=fluid.ParamAttr( + name=name_scope + 'gamma', regularizer=regularizer), + bias_attr=fluid.ParamAttr( + name=name_scope + 'beta', regularizer=regularizer), + moving_mean_name=name_scope + 'moving_mean', + moving_variance_name=name_scope + 'moving_variance') + with scope('conv1'): + param_attr = fluid.ParamAttr( + name=name_scope + 'weights', + regularizer=regularizer, + initializer=fluid.initializer.TruncatedNormal( + loc=0.0, scale=0.33)) + self.conv1 = Conv2D( + num_channels=num_filters, + num_filters=num_filters, + filter_size=3, + stride=1, + padding=1, + param_attr=param_attr) + self.bn1 = BatchNorm( + num_channels=num_filters, + param_attr=fluid.ParamAttr( + name=name_scope + 'gamma', regularizer=regularizer), + bias_attr=fluid.ParamAttr( + name=name_scope + 'beta', regularizer=regularizer), + moving_mean_name=name_scope + 'moving_mean', + moving_variance_name=name_scope + 'moving_variance') + + def forward(self, x): + x = self.conv0(x) + x = self.bn0(x) + x = fluid.layers.relu(x) + x = self.conv1(x) + x = self.bn1(x) + x = fluid.layers.relu(x) + return x + + +class Down(fluid.dygraph.Layer): + def __init__(self, num_channels, num_filters): + super().__init__() + with scope("down"): + self.max_pool = Pool2D( + pool_size=2, pool_type='max', pool_stride=2, pool_padding=0) + self.double_conv = DoubleConv(num_channels, num_filters) + + def forward(self, x): + x = self.max_pool(x) + x = self.double_conv(x) + return x + + +class Up(fluid.dygraph.Layer): + def __init__(self, num_channels, num_filters, upsample_mode): + super().__init__() + self.upsample_mode = upsample_mode + with scope('up'): + if upsample_mode == 'bilinear': + self.double_conv = DoubleConv(2 * num_channels, num_filters) + if not upsample_mode == 'bilinear': + param_attr = fluid.ParamAttr( + name=name_scope + 'weights', + regularizer=regularizer, + initializer=fluid.initializer.XavierInitializer(), + ) + self.deconv = fluid.dygraph.Conv2DTranspose( + num_channels=num_channels, + num_filters=num_filters // 2, + filter_size=2, + stride=2, + padding=0, + param_attr=param_attr) + self.double_conv = DoubleConv(num_channels + num_filters // 2, + num_filters) + + def forward(self, x, short_cut): + if self.upsample_mode == 'bilinear': + short_cut_shape = fluid.layers.shape(short_cut) + x = fluid.layers.resize_bilinear(x, short_cut_shape[2:]) + else: + x = self.deconv(x) + x = fluid.layers.concat([x, short_cut], axis=1) + x = self.double_conv(x) + return x + + +class GetLogit(fluid.dygraph.Layer): + def __init__(self, num_channels, num_classes): + super().__init__() + with scope('logit'): + param_attr = fluid.ParamAttr( + name=name_scope + 'weights', + regularizer=regularizer, + initializer=fluid.initializer.TruncatedNormal( + loc=0.0, scale=0.01)) + self.conv = Conv2D( + num_channels=num_channels, + num_filters=num_classes, + filter_size=3, + stride=1, + padding=1, + param_attr=param_attr) + + def forward(self, x): + x = self.conv(x) + return x diff --git a/dygraph/transforms/__init__.py b/dygraph/transforms/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..b537d3216c5b72c75bf3e201304d20638f3ea706 --- /dev/null +++ b/dygraph/transforms/__init__.py @@ -0,0 +1,17 @@ +# coding: utf8 +# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .transforms import * +from . import functional diff --git a/dygraph/transforms/functional.py b/dygraph/transforms/functional.py new file mode 100644 index 0000000000000000000000000000000000000000..c3f0265b0e1f4a250b41d5b86717b3d87797351d --- /dev/null +++ b/dygraph/transforms/functional.py @@ -0,0 +1,100 @@ +# coding: utf8 +# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import cv2 +import numpy as np +from PIL import Image, ImageEnhance + + +def normalize(im, mean, std): + im = im.astype(np.float32, copy=False) / 255.0 + im -= mean + im /= std + return im + + +def permute(im): + im = np.transpose(im, (2, 0, 1)) + return im + + +def resize(im, target_size=608, interp=cv2.INTER_LINEAR): + if isinstance(target_size, list) or isinstance(target_size, tuple): + w = target_size[0] + h = target_size[1] + else: + w = target_size + h = target_size + im = cv2.resize(im, (w, h), interpolation=interp) + return im + + +def resize_long(im, long_size=224, interpolation=cv2.INTER_LINEAR): + value = max(im.shape[0], im.shape[1]) + scale = float(long_size) / float(value) + resized_width = int(round(im.shape[1] * scale)) + resized_height = int(round(im.shape[0] * scale)) + + im = cv2.resize( + im, (resized_width, resized_height), interpolation=interpolation) + return im + + +def horizontal_flip(im): + if len(im.shape) == 3: + im = im[:, ::-1, :] + elif len(im.shape) == 2: + im = im[:, ::-1] + return im + + +def vertical_flip(im): + if len(im.shape) == 3: + im = im[::-1, :, :] + elif len(im.shape) == 2: + im = im[::-1, :] + return im + + +def brightness(im, brightness_lower, brightness_upper): + brightness_delta = np.random.uniform(brightness_lower, brightness_upper) + im = ImageEnhance.Brightness(im).enhance(brightness_delta) + return im + + +def contrast(im, contrast_lower, contrast_upper): + contrast_delta = np.random.uniform(contrast_lower, contrast_upper) + im = ImageEnhance.Contrast(im).enhance(contrast_delta) + return im + + +def saturation(im, saturation_lower, saturation_upper): + saturation_delta = np.random.uniform(saturation_lower, saturation_upper) + im = ImageEnhance.Color(im).enhance(saturation_delta) + return im + + +def hue(im, hue_lower, hue_upper): + hue_delta = np.random.uniform(hue_lower, hue_upper) + im = np.array(im.convert('HSV')) + im[:, :, 0] = im[:, :, 0] + hue_delta + im = Image.fromarray(im, mode='HSV').convert('RGB') + return im + + +def rotate(im, rotate_lower, rotate_upper): + rotate_delta = np.random.uniform(rotate_lower, rotate_upper) + im = im.rotate(int(rotate_delta)) + return im diff --git 
a/dygraph/transforms/transforms.py b/dygraph/transforms/transforms.py new file mode 100644 index 0000000000000000000000000000000000000000..261b29781f59ca211f66b240937952b51af5949f --- /dev/null +++ b/dygraph/transforms/transforms.py @@ -0,0 +1,915 @@ +# coding: utf8 +# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .functional import * +import random +import numpy as np +from PIL import Image +import cv2 +from collections import OrderedDict + + +class Compose: + """根据数据预处理/增强算子对输入数据进行操作。 + 所有操作的输入图像流形状均是[H, W, C],其中H为图像高,W为图像宽,C为图像通道数。 + + Args: + transforms (list): 数据预处理/增强算子。 + to_rgb (bool): 是否转化为rgb通道格式 + + Raises: + TypeError: transforms不是list对象 + ValueError: transforms元素个数小于1。 + + """ + + def __init__(self, transforms, to_rgb=False): + if not isinstance(transforms, list): + raise TypeError('The transforms must be a list!') + if len(transforms) < 1: + raise ValueError('The length of transforms ' + \ + 'must be equal or larger than 1!') + self.transforms = transforms + self.to_rgb = to_rgb + + def __call__(self, im, im_info=None, label=None): + """ + Args: + im (str/np.ndarray): 图像路径/图像np.ndarray数据。 + im_info (dict): 存储与图像相关的信息,dict中的字段如下: + - shape_before_resize (tuple): 图像resize之前的大小(h, w)。 + - shape_before_padding (tuple): 图像padding之前的大小(h, w)。 + label (str/np.ndarray): 标注图像路径/标注图像np.ndarray数据。 + + Returns: + tuple: 根据网络所需字段所组成的tuple;字段由transforms中的最后一个数据预处理操作决定。 + """ + + if im_info is None: + im_info = dict() + if isinstance(im, str): + im = cv2.imread(im).astype('float32') + if isinstance(label, str): + label = np.asarray(Image.open(label)) + if im is None: + raise ValueError('Can\'t read The image file {}!'.format(im)) + if self.to_rgb: + im = cv2.cvtColor(im, cv2.COLOR_BGR2RGB) + + for op in self.transforms: + outputs = op(im, im_info, label) + im = outputs[0] + if len(outputs) >= 2: + im_info = outputs[1] + if len(outputs) == 3: + label = outputs[2] + return outputs + + +class RandomHorizontalFlip: + """以一定的概率对图像进行水平翻转。当存在标注图像时,则同步进行翻转。 + + Args: + prob (float): 随机水平翻转的概率。默认值为0.5。 + + """ + + def __init__(self, prob=0.5): + self.prob = prob + + def __call__(self, im, im_info=None, label=None): + """ + Args: + im (np.ndarray): 图像np.ndarray数据。 + im_info (dict): 存储与图像相关的信息。 + label (np.ndarray): 标注图像np.ndarray数据。 + + Returns: + tuple: 当label为空时,返回的tuple为(im, im_info),分别对应图像np.ndarray数据、存储与图像相关信息的字典; + 当label不为空时,返回的tuple为(im, im_info, label),分别对应图像np.ndarray数据、 + 存储与图像相关信息的字典和标注图像np.ndarray数据。 + """ + if random.random() < self.prob: + im = horizontal_flip(im) + if label is not None: + label = horizontal_flip(label) + if label is None: + return (im, im_info) + else: + return (im, im_info, label) + + +class RandomVerticalFlip: + """以一定的概率对图像进行垂直翻转。当存在标注图像时,则同步进行翻转。 + + Args: + prob (float): 随机垂直翻转的概率。默认值为0.1。 + """ + + def __init__(self, prob=0.1): + self.prob = prob + + def __call__(self, im, im_info=None, label=None): + """ + Args: + im (np.ndarray): 图像np.ndarray数据。 + im_info (dict): 存储与图像相关的信息。 + label 
(np.ndarray): 标注图像np.ndarray数据。 + + Returns: + tuple: 当label为空时,返回的tuple为(im, im_info),分别对应图像np.ndarray数据、存储与图像相关信息的字典; + 当label不为空时,返回的tuple为(im, im_info, label),分别对应图像np.ndarray数据、 + 存储与图像相关信息的字典和标注图像np.ndarray数据。 + """ + if random.random() < self.prob: + im = vertical_flip(im) + if label is not None: + label = vertical_flip(label) + if label is None: + return (im, im_info) + else: + return (im, im_info, label) + + +class Resize: + """调整图像大小(resize)。 + + - 当目标大小(target_size)类型为int时,根据插值方式, + 将图像resize为[target_size, target_size]。 + - 当目标大小(target_size)类型为list或tuple时,根据插值方式, + 将图像resize为target_size。 + 注意:当插值方式为“RANDOM”时,则随机选取一种插值方式进行resize。 + + Args: + target_size (int/list/tuple): 短边目标长度。默认为608。 + interp (str): resize的插值方式,与opencv的插值方式对应,取值范围为 + ['NEAREST', 'LINEAR', 'CUBIC', 'AREA', 'LANCZOS4', 'RANDOM']。默认为"LINEAR"。 + + Raises: + TypeError: 形参数据类型不满足需求。 + ValueError: 插值方式不在['NEAREST', 'LINEAR', 'CUBIC', + 'AREA', 'LANCZOS4', 'RANDOM']中。 + """ + + # The interpolation mode + interp_dict = { + 'NEAREST': cv2.INTER_NEAREST, + 'LINEAR': cv2.INTER_LINEAR, + 'CUBIC': cv2.INTER_CUBIC, + 'AREA': cv2.INTER_AREA, + 'LANCZOS4': cv2.INTER_LANCZOS4 + } + + def __init__(self, target_size=512, interp='LINEAR'): + self.interp = interp + if not (interp == "RANDOM" or interp in self.interp_dict): + raise ValueError("interp should be one of {}".format( + self.interp_dict.keys())) + if isinstance(target_size, list) or isinstance(target_size, tuple): + if len(target_size) != 2: + raise TypeError( + 'when target is list or tuple, it should include 2 elements, but it is {}' + .format(target_size)) + elif not isinstance(target_size, int): + raise TypeError( + "Type of target_size is invalid. Must be Integer or List or tuple, now is {}" + .format(type(target_size))) + + self.target_size = target_size + + def __call__(self, im, im_info=None, label=None): + """ + Args: + im (np.ndarray): 图像np.ndarray数据。 + im_info (dict, 可选): 存储与图像相关的信息。 + label (np.ndarray): 标注图像np.ndarray数据。 + + Returns: + tuple: 当label为空时,返回的tuple为(im, im_info),分别对应图像np.ndarray数据、存储与图像相关信息的字典; + 当label不为空时,返回的tuple为(im, im_info, label),分别对应图像np.ndarray数据、 + 存储与图像相关信息的字典和标注图像np.ndarray数据。 + 其中,im_info跟新字段为: + -shape_before_resize (tuple): 保存resize之前图像的形状(h, w)。 + + Raises: + TypeError: 形参数据类型不满足需求。 + ValueError: 数据长度不匹配。 + """ + if im_info is None: + im_info = OrderedDict() + im_info['shape_before_resize'] = im.shape[:2] + if not isinstance(im, np.ndarray): + raise TypeError("Resize: image type is not numpy.") + if len(im.shape) != 3: + raise ValueError('Resize: image is not 3-dimensional.') + if self.interp == "RANDOM": + interp = random.choice(list(self.interp_dict.keys())) + else: + interp = self.interp + im = resize(im, self.target_size, self.interp_dict[interp]) + if label is not None: + label = resize(label, self.target_size, cv2.INTER_NEAREST) + + if label is None: + return (im, im_info) + else: + return (im, im_info, label) + + +class ResizeByLong: + """对图像长边resize到固定值,短边按比例进行缩放。当存在标注图像时,则同步进行处理。 + + Args: + long_size (int): resize后图像的长边大小。 + """ + + def __init__(self, long_size): + self.long_size = long_size + + def __call__(self, im, im_info=None, label=None): + """ + Args: + im (np.ndarray): 图像np.ndarray数据。 + im_info (dict): 存储与图像相关的信息。 + label (np.ndarray): 标注图像np.ndarray数据。 + + Returns: + tuple: 当label为空时,返回的tuple为(im, im_info),分别对应图像np.ndarray数据、存储与图像相关信息的字典; + 当label不为空时,返回的tuple为(im, im_info, label),分别对应图像np.ndarray数据、 + 存储与图像相关信息的字典和标注图像np.ndarray数据。 + 其中,im_info新增字段为: + -shape_before_resize (tuple): 保存resize之前图像的形状(h, w)。 + """ 
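+        # Keep the original (h, w) so later steps (e.g. predict) can map the
+        # result back, then scale the long edge to `long_size`; the label map,
+        # if given, is resized with nearest-neighbour interpolation to keep
+        # class ids intact. E.g. a 768x1024 image with long_size=512 becomes
+        # 384x512.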
+ if im_info is None: + im_info = OrderedDict() + + im_info['shape_before_resize'] = im.shape[:2] + im = resize_long(im, self.long_size) + if label is not None: + label = resize_long(label, self.long_size, cv2.INTER_NEAREST) + + if label is None: + return (im, im_info) + else: + return (im, im_info, label) + + +class ResizeRangeScaling: + """对图像长边随机resize到指定范围内,短边按比例进行缩放。当存在标注图像时,则同步进行处理。 + + Args: + min_value (int): 图像长边resize后的最小值。默认值400。 + max_value (int): 图像长边resize后的最大值。默认值600。 + + Raises: + ValueError: min_value大于max_value + """ + + def __init__(self, min_value=400, max_value=600): + if min_value > max_value: + raise ValueError('min_value must be less than max_value, ' + 'but they are {} and {}.'.format( + min_value, max_value)) + self.min_value = min_value + self.max_value = max_value + + def __call__(self, im, im_info=None, label=None): + """ + Args: + im (np.ndarray): 图像np.ndarray数据。 + im_info (dict): 存储与图像相关的信息。 + label (np.ndarray): 标注图像np.ndarray数据。 + + Returns: + tuple: 当label为空时,返回的tuple为(im, im_info),分别对应图像np.ndarray数据、存储与图像相关信息的字典; + 当label不为空时,返回的tuple为(im, im_info, label),分别对应图像np.ndarray数据、 + 存储与图像相关信息的字典和标注图像np.ndarray数据。 + """ + if self.min_value == self.max_value: + random_size = self.max_value + else: + random_size = int( + np.random.uniform(self.min_value, self.max_value) + 0.5) + im = resize_long(im, random_size, cv2.INTER_LINEAR) + if label is not None: + label = resize_long(label, random_size, cv2.INTER_NEAREST) + + if label is None: + return (im, im_info) + else: + return (im, im_info, label) + + +class ResizeStepScaling: + """对图像按照某一个比例resize,这个比例以scale_step_size为步长 + 在[min_scale_factor, max_scale_factor]随机变动。当存在标注图像时,则同步进行处理。 + + Args: + min_scale_factor(float), resize最小尺度。默认值0.75。 + max_scale_factor (float), resize最大尺度。默认值1.25。 + scale_step_size (float), resize尺度范围间隔。默认值0.25。 + + Raises: + ValueError: min_scale_factor大于max_scale_factor + """ + + def __init__(self, + min_scale_factor=0.75, + max_scale_factor=1.25, + scale_step_size=0.25): + if min_scale_factor > max_scale_factor: + raise ValueError( + 'min_scale_factor must be less than max_scale_factor, ' + 'but they are {} and {}.'.format(min_scale_factor, + max_scale_factor)) + self.min_scale_factor = min_scale_factor + self.max_scale_factor = max_scale_factor + self.scale_step_size = scale_step_size + + def __call__(self, im, im_info=None, label=None): + """ + Args: + im (np.ndarray): 图像np.ndarray数据。 + im_info (dict): 存储与图像相关的信息。 + label (np.ndarray): 标注图像np.ndarray数据。 + + Returns: + tuple: 当label为空时,返回的tuple为(im, im_info),分别对应图像np.ndarray数据、存储与图像相关信息的字典; + 当label不为空时,返回的tuple为(im, im_info, label),分别对应图像np.ndarray数据、 + 存储与图像相关信息的字典和标注图像np.ndarray数据。 + """ + if self.min_scale_factor == self.max_scale_factor: + scale_factor = self.min_scale_factor + + elif self.scale_step_size == 0: + scale_factor = np.random.uniform(self.min_scale_factor, + self.max_scale_factor) + + else: + num_steps = int((self.max_scale_factor - self.min_scale_factor) / + self.scale_step_size + 1) + scale_factors = np.linspace(self.min_scale_factor, + self.max_scale_factor, + num_steps).tolist() + np.random.shuffle(scale_factors) + scale_factor = scale_factors[0] + w = int(round(scale_factor * im.shape[1])) + h = int(round(scale_factor * im.shape[0])) + + im = resize(im, (w, h), cv2.INTER_LINEAR) + if label is not None: + label = resize(label, (w, h), cv2.INTER_NEAREST) + + if label is None: + return (im, im_info) + else: + return (im, im_info, label) + + +class Normalize: + """对图像进行标准化。 + 1.尺度缩放到 [0,1]。 + 2.对图像进行减均值除以标准差操作。 + + Args: + 
mean (list): 图像数据集的均值。默认值[0.5, 0.5, 0.5]。 + std (list): 图像数据集的标准差。默认值[0.5, 0.5, 0.5]。 + + Raises: + ValueError: mean或std不是list对象。std包含0。 + """ + + def __init__(self, mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]): + self.mean = mean + self.std = std + if not (isinstance(self.mean, list) and isinstance(self.std, list)): + raise ValueError("{}: input type is invalid.".format(self)) + from functools import reduce + if reduce(lambda x, y: x * y, self.std) == 0: + raise ValueError('{}: std is invalid!'.format(self)) + + def __call__(self, im, im_info=None, label=None): + """ + Args: + im (np.ndarray): 图像np.ndarray数据。 + im_info (dict): 存储与图像相关的信息。 + label (np.ndarray): 标注图像np.ndarray数据。 + + Returns: + tuple: 当label为空时,返回的tuple为(im, im_info),分别对应图像np.ndarray数据、存储与图像相关信息的字典; + 当label不为空时,返回的tuple为(im, im_info, label),分别对应图像np.ndarray数据、 + 存储与图像相关信息的字典和标注图像np.ndarray数据。 + """ + + mean = np.array(self.mean)[np.newaxis, np.newaxis, :] + std = np.array(self.std)[np.newaxis, np.newaxis, :] + im = normalize(im, mean, std) + + if label is None: + return (im, im_info) + else: + return (im, im_info, label) + + +class Padding: + """对图像或标注图像进行padding,padding方向为右和下。 + 根据提供的值对图像或标注图像进行padding操作。 + + Args: + target_size (int|list|tuple): padding后图像的大小。 + im_padding_value (list): 图像padding的值。默认为[127.5, 127.5, 127.5]。 + label_padding_value (int): 标注图像padding的值。默认值为255。 + + Raises: + TypeError: target_size不是int|list|tuple。 + ValueError: target_size为list|tuple时元素个数不等于2。 + """ + + def __init__(self, + target_size, + im_padding_value=[127.5, 127.5, 127.5], + label_padding_value=255): + if isinstance(target_size, list) or isinstance(target_size, tuple): + if len(target_size) != 2: + raise ValueError( + 'when target is list or tuple, it should include 2 elements, but it is {}' + .format(target_size)) + elif not isinstance(target_size, int): + raise TypeError( + "Type of target_size is invalid. 
Must be Integer or List or tuple, now is {}" + .format(type(target_size))) + self.target_size = target_size + self.im_padding_value = im_padding_value + self.label_padding_value = label_padding_value + + def __call__(self, im, im_info=None, label=None): + """ + Args: + im (np.ndarray): 图像np.ndarray数据。 + im_info (dict): 存储与图像相关的信息。 + label (np.ndarray): 标注图像np.ndarray数据。 + + Returns: + tuple: 当label为空时,返回的tuple为(im, im_info),分别对应图像np.ndarray数据、存储与图像相关信息的字典; + 当label不为空时,返回的tuple为(im, im_info, label),分别对应图像np.ndarray数据、 + 存储与图像相关信息的字典和标注图像np.ndarray数据。 + 其中,im_info新增字段为: + -shape_before_padding (tuple): 保存padding之前图像的形状(h, w)。 + + Raises: + ValueError: 输入图像im或label的形状大于目标值 + """ + if im_info is None: + im_info = OrderedDict() + im_info['shape_before_padding'] = im.shape[:2] + + im_height, im_width = im.shape[0], im.shape[1] + if isinstance(self.target_size, int): + target_height = self.target_size + target_width = self.target_size + else: + target_height = self.target_size[1] + target_width = self.target_size[0] + pad_height = target_height - im_height + pad_width = target_width - im_width + if pad_height < 0 or pad_width < 0: + raise ValueError( + 'the size of image should be less than target_size, but the size of image ({}, {}), is larger than target_size ({}, {})' + .format(im_width, im_height, target_width, target_height)) + else: + im = cv2.copyMakeBorder( + im, + 0, + pad_height, + 0, + pad_width, + cv2.BORDER_CONSTANT, + value=self.im_padding_value) + if label is not None: + label = cv2.copyMakeBorder( + label, + 0, + pad_height, + 0, + pad_width, + cv2.BORDER_CONSTANT, + value=self.label_padding_value) + if label is None: + return (im, im_info) + else: + return (im, im_info, label) + + +class RandomPaddingCrop: + """对图像和标注图进行随机裁剪,当所需要的裁剪尺寸大于原图时,则进行padding操作。 + + Args: + crop_size (int|list|tuple): 裁剪图像大小。默认为512。 + im_padding_value (list): 图像padding的值。默认为[127.5, 127.5, 127.5]。 + label_padding_value (int): 标注图像padding的值。默认值为255。 + + Raises: + TypeError: crop_size不是int/list/tuple。 + ValueError: target_size为list/tuple时元素个数不等于2。 + """ + + def __init__(self, + crop_size=512, + im_padding_value=[127.5, 127.5, 127.5], + label_padding_value=255): + if isinstance(crop_size, list) or isinstance(crop_size, tuple): + if len(crop_size) != 2: + raise ValueError( + 'when crop_size is list or tuple, it should include 2 elements, but it is {}' + .format(crop_size)) + elif not isinstance(crop_size, int): + raise TypeError( + "Type of crop_size is invalid. 
Must be Integer or List or tuple, now is {}" + .format(type(crop_size))) + self.crop_size = crop_size + self.im_padding_value = im_padding_value + self.label_padding_value = label_padding_value + + def __call__(self, im, im_info=None, label=None): + """ + Args: + im (np.ndarray): 图像np.ndarray数据。 + im_info (dict): 存储与图像相关的信息。 + label (np.ndarray): 标注图像np.ndarray数据。 + + Returns: + tuple: 当label为空时,返回的tuple为(im, im_info),分别对应图像np.ndarray数据、存储与图像相关信息的字典; + 当label不为空时,返回的tuple为(im, im_info, label),分别对应图像np.ndarray数据、 + 存储与图像相关信息的字典和标注图像np.ndarray数据。 + """ + if isinstance(self.crop_size, int): + crop_width = self.crop_size + crop_height = self.crop_size + else: + crop_width = self.crop_size[0] + crop_height = self.crop_size[1] + + img_height = im.shape[0] + img_width = im.shape[1] + + if img_height == crop_height and img_width == crop_width: + if label is None: + return (im, im_info) + else: + return (im, im_info, label) + else: + pad_height = max(crop_height - img_height, 0) + pad_width = max(crop_width - img_width, 0) + if (pad_height > 0 or pad_width > 0): + im = cv2.copyMakeBorder( + im, + 0, + pad_height, + 0, + pad_width, + cv2.BORDER_CONSTANT, + value=self.im_padding_value) + if label is not None: + label = cv2.copyMakeBorder( + label, + 0, + pad_height, + 0, + pad_width, + cv2.BORDER_CONSTANT, + value=self.label_padding_value) + img_height = im.shape[0] + img_width = im.shape[1] + + if crop_height > 0 and crop_width > 0: + h_off = np.random.randint(img_height - crop_height + 1) + w_off = np.random.randint(img_width - crop_width + 1) + + im = im[h_off:(crop_height + h_off), w_off:( + w_off + crop_width), :] + if label is not None: + label = label[h_off:(crop_height + h_off), w_off:( + w_off + crop_width)] + if label is None: + return (im, im_info) + else: + return (im, im_info, label) + + +class RandomBlur: + """以一定的概率对图像进行高斯模糊。 + + Args: + prob (float): 图像模糊概率。默认为0.1。 + """ + + def __init__(self, prob=0.1): + self.prob = prob + + def __call__(self, im, im_info=None, label=None): + """ + Args: + im (np.ndarray): 图像np.ndarray数据。 + im_info (dict): 存储与图像相关的信息。 + label (np.ndarray): 标注图像np.ndarray数据。 + + Returns: + tuple: 当label为空时,返回的tuple为(im, im_info),分别对应图像np.ndarray数据、存储与图像相关信息的字典; + 当label不为空时,返回的tuple为(im, im_info, label),分别对应图像np.ndarray数据、 + 存储与图像相关信息的字典和标注图像np.ndarray数据。 + """ + if self.prob <= 0: + n = 0 + elif self.prob >= 1: + n = 1 + else: + n = int(1.0 / self.prob) + if n > 0: + if np.random.randint(0, n) == 0: + radius = np.random.randint(3, 10) + if radius % 2 != 1: + radius = radius + 1 + if radius > 9: + radius = 9 + im = cv2.GaussianBlur(im, (radius, radius), 0, 0) + + if label is None: + return (im, im_info) + else: + return (im, im_info, label) + + +class RandomRotation: + """对图像进行随机旋转。 + 在不超过最大旋转角度的情况下,图像进行随机旋转,当存在标注图像时,同步进行, + 并对旋转后的图像和标注图像进行相应的padding。 + + Args: + max_rotation (float): 最大旋转角度。默认为15度。 + im_padding_value (list): 图像padding的值。默认为[127.5, 127.5, 127.5]。 + label_padding_value (int): 标注图像padding的值。默认为255。 + + """ + + def __init__(self, + max_rotation=15, + im_padding_value=[127.5, 127.5, 127.5], + label_padding_value=255): + self.max_rotation = max_rotation + self.im_padding_value = im_padding_value + self.label_padding_value = label_padding_value + + def __call__(self, im, im_info=None, label=None): + """ + Args: + im (np.ndarray): 图像np.ndarray数据。 + im_info (dict): 存储与图像相关的信息。 + label (np.ndarray): 标注图像np.ndarray数据。 + + Returns: + tuple: 当label为空时,返回的tuple为(im, im_info),分别对应图像np.ndarray数据、存储与图像相关信息的字典; + 当label不为空时,返回的tuple为(im, im_info, 
label),分别对应图像np.ndarray数据、 + 存储与图像相关信息的字典和标注图像np.ndarray数据。 + """ + if self.max_rotation > 0: + (h, w) = im.shape[:2] + do_rotation = np.random.uniform(-self.max_rotation, + self.max_rotation) + pc = (w // 2, h // 2) + r = cv2.getRotationMatrix2D(pc, do_rotation, 1.0) + cos = np.abs(r[0, 0]) + sin = np.abs(r[0, 1]) + + nw = int((h * sin) + (w * cos)) + nh = int((h * cos) + (w * sin)) + + (cx, cy) = pc + r[0, 2] += (nw / 2) - cx + r[1, 2] += (nh / 2) - cy + dsize = (nw, nh) + im = cv2.warpAffine( + im, + r, + dsize=dsize, + flags=cv2.INTER_LINEAR, + borderMode=cv2.BORDER_CONSTANT, + borderValue=self.im_padding_value) + label = cv2.warpAffine( + label, + r, + dsize=dsize, + flags=cv2.INTER_NEAREST, + borderMode=cv2.BORDER_CONSTANT, + borderValue=self.label_padding_value) + + if label is None: + return (im, im_info) + else: + return (im, im_info, label) + + +class RandomScaleAspect: + """裁剪并resize回原始尺寸的图像和标注图像。 + 按照一定的面积比和宽高比对图像进行裁剪,并reszie回原始图像的图像,当存在标注图时,同步进行。 + + Args: + min_scale (float):裁取图像占原始图像的面积比,取值[0,1],为0时则返回原图。默认为0.5。 + aspect_ratio (float): 裁取图像的宽高比范围,非负值,为0时返回原图。默认为0.33。 + """ + + def __init__(self, min_scale=0.5, aspect_ratio=0.33): + self.min_scale = min_scale + self.aspect_ratio = aspect_ratio + + def __call__(self, im, im_info=None, label=None): + """ + Args: + im (np.ndarray): 图像np.ndarray数据。 + im_info (dict): 存储与图像相关的信息。 + label (np.ndarray): 标注图像np.ndarray数据。 + + Returns: + tuple: 当label为空时,返回的tuple为(im, im_info),分别对应图像np.ndarray数据、存储与图像相关信息的字典; + 当label不为空时,返回的tuple为(im, im_info, label),分别对应图像np.ndarray数据、 + 存储与图像相关信息的字典和标注图像np.ndarray数据。 + """ + if self.min_scale != 0 and self.aspect_ratio != 0: + img_height = im.shape[0] + img_width = im.shape[1] + for i in range(0, 10): + area = img_height * img_width + target_area = area * np.random.uniform(self.min_scale, 1.0) + aspectRatio = np.random.uniform(self.aspect_ratio, + 1.0 / self.aspect_ratio) + + dw = int(np.sqrt(target_area * 1.0 * aspectRatio)) + dh = int(np.sqrt(target_area * 1.0 / aspectRatio)) + if (np.random.randint(10) < 5): + tmp = dw + dw = dh + dh = tmp + + if (dh < img_height and dw < img_width): + h1 = np.random.randint(0, img_height - dh) + w1 = np.random.randint(0, img_width - dw) + + im = im[h1:(h1 + dh), w1:(w1 + dw), :] + label = label[h1:(h1 + dh), w1:(w1 + dw)] + im = cv2.resize( + im, (img_width, img_height), + interpolation=cv2.INTER_LINEAR) + label = cv2.resize( + label, (img_width, img_height), + interpolation=cv2.INTER_NEAREST) + break + if label is None: + return (im, im_info) + else: + return (im, im_info, label) + + +class RandomDistort: + """对图像进行随机失真。 + + 1. 对变换的操作顺序进行随机化操作。 + 2. 
按照1中的顺序以一定的概率对图像进行随机像素内容变换。 + + Args: + brightness_range (float): 明亮度因子的范围。默认为0.5。 + brightness_prob (float): 随机调整明亮度的概率。默认为0.5。 + contrast_range (float): 对比度因子的范围。默认为0.5。 + contrast_prob (float): 随机调整对比度的概率。默认为0.5。 + saturation_range (float): 饱和度因子的范围。默认为0.5。 + saturation_prob (float): 随机调整饱和度的概率。默认为0.5。 + hue_range (int): 色调因子的范围。默认为18。 + hue_prob (float): 随机调整色调的概率。默认为0.5。 + """ + + def __init__(self, + brightness_range=0.5, + brightness_prob=0.5, + contrast_range=0.5, + contrast_prob=0.5, + saturation_range=0.5, + saturation_prob=0.5, + hue_range=18, + hue_prob=0.5): + self.brightness_range = brightness_range + self.brightness_prob = brightness_prob + self.contrast_range = contrast_range + self.contrast_prob = contrast_prob + self.saturation_range = saturation_range + self.saturation_prob = saturation_prob + self.hue_range = hue_range + self.hue_prob = hue_prob + + def __call__(self, im, im_info=None, label=None): + """ + Args: + im (np.ndarray): 图像np.ndarray数据。 + im_info (dict): 存储与图像相关的信息。 + label (np.ndarray): 标注图像np.ndarray数据。 + + Returns: + tuple: 当label为空时,返回的tuple为(im, im_info),分别对应图像np.ndarray数据、存储与图像相关信息的字典; + 当label不为空时,返回的tuple为(im, im_info, label),分别对应图像np.ndarray数据、 + 存储与图像相关信息的字典和标注图像np.ndarray数据。 + """ + brightness_lower = 1 - self.brightness_range + brightness_upper = 1 + self.brightness_range + contrast_lower = 1 - self.contrast_range + contrast_upper = 1 + self.contrast_range + saturation_lower = 1 - self.saturation_range + saturation_upper = 1 + self.saturation_range + hue_lower = -self.hue_range + hue_upper = self.hue_range + ops = [brightness, contrast, saturation, hue] + random.shuffle(ops) + params_dict = { + 'brightness': { + 'brightness_lower': brightness_lower, + 'brightness_upper': brightness_upper + }, + 'contrast': { + 'contrast_lower': contrast_lower, + 'contrast_upper': contrast_upper + }, + 'saturation': { + 'saturation_lower': saturation_lower, + 'saturation_upper': saturation_upper + }, + 'hue': { + 'hue_lower': hue_lower, + 'hue_upper': hue_upper + } + } + prob_dict = { + 'brightness': self.brightness_prob, + 'contrast': self.contrast_prob, + 'saturation': self.saturation_prob, + 'hue': self.hue_prob + } + im = im.astype('uint8') + im = Image.fromarray(im) + for id in range(4): + params = params_dict[ops[id].__name__] + prob = prob_dict[ops[id].__name__] + params['im'] = im + if np.random.uniform(0, 1) < prob: + im = ops[id](**params) + im = np.asarray(im).astype('float32') + if label is None: + return (im, im_info) + else: + return (im, im_info, label) + + +class ArrangeSegmenter: + """获取训练/验证/预测所需的信息。 + + Args: + mode (str): 指定数据用于何种用途,取值范围为['train', 'eval', 'test', 'quant']。 + + Raises: + ValueError: mode的取值不在['train', 'eval', 'test', 'quant']之内 + """ + + def __init__(self, mode): + if mode not in ['train', 'eval', 'test', 'quant']: + raise ValueError( + "mode should be defined as one of ['train', 'eval', 'test', 'quant']!" 
+ ) + self.mode = mode + + def __call__(self, im, im_info, label=None): + """ + Args: + im (np.ndarray): 图像np.ndarray数据。 + im_info (dict): 存储与图像相关的信息。 + label (np.ndarray): 标注图像np.ndarray数据。 + + Returns: + tuple: 当mode为'train'或'eval'时,返回的tuple为(im, label),分别对应图像np.ndarray数据、存储与图像相关信息的字典; + 当mode为'test'时,返回的tuple为(im, im_info),分别对应图像np.ndarray数据、存储与图像相关信息的字典;当mode为 + 'quant'时,返回的tuple为(im,),为图像np.ndarray数据。 + """ + im = permute(im) + if self.mode == 'train' or self.mode == 'eval': + label = label[np.newaxis, :, :] + return (im, label) + elif self.mode == 'test': + return (im, im_info) + else: + return (im, ) diff --git a/dygraph/utils/__init__.py b/dygraph/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..9b6978c8a43e18cfe172f480b08092095752ccdb --- /dev/null +++ b/dygraph/utils/__init__.py @@ -0,0 +1,18 @@ +# coding: utf8 +# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from . import logging +from .metrics import ConfusionMatrix +from .utils import * diff --git a/dygraph/utils/logging.py b/dygraph/utils/logging.py new file mode 100644 index 0000000000000000000000000000000000000000..6f0c25ff554cb80b9435e72ed47432d4bb6e6813 --- /dev/null +++ b/dygraph/utils/logging.py @@ -0,0 +1,47 @@ +# coding: utf8 +# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import time +import os +import sys + +levels = {0: 'ERROR', 1: 'WARNING', 2: 'INFO', 3: 'DEBUG'} +log_level = 2 + + +def log(level=2, message=""): + current_time = time.time() + time_array = time.localtime(current_time) + current_time = time.strftime("%Y-%m-%d %H:%M:%S", time_array) + if log_level >= level: + print("{} [{}]\t{}".format(current_time, levels[level], + message).encode("utf-8").decode("latin1")) + sys.stdout.flush() + + +def debug(message=""): + log(level=3, message=message) + + +def info(message=""): + log(level=2, message=message) + + +def warning(message=""): + log(level=1, message=message) + + +def error(message=""): + log(level=0, message=message) diff --git a/dygraph/utils/metrics.py b/dygraph/utils/metrics.py new file mode 100644 index 0000000000000000000000000000000000000000..f7bab020ab44131f5d46643cf0522dd6acf483ee --- /dev/null +++ b/dygraph/utils/metrics.py @@ -0,0 +1,145 @@ +# coding: utf8 +# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import sys +import numpy as np +from scipy.sparse import csr_matrix + + +class ConfusionMatrix(object): + """ + Confusion Matrix for segmentation evaluation + """ + + def __init__(self, num_classes=2, streaming=False): + self.confusion_matrix = np.zeros([num_classes, num_classes], + dtype='int64') + self.num_classes = num_classes + self.streaming = streaming + + def calculate(self, pred, label, ignore=None): + # If not in streaming mode, clear matrix everytime when call `calculate` + if not self.streaming: + self.zero_matrix() + + label = np.transpose(label, (0, 2, 3, 1)) + ignore = np.transpose(ignore, (0, 2, 3, 1)) + mask = np.array(ignore) == 1 + + label = np.asarray(label)[mask] + pred = np.asarray(pred)[mask] + one = np.ones_like(pred) + # Accumuate ([row=label, col=pred], 1) into sparse matrix + spm = csr_matrix((one, (label, pred)), + shape=(self.num_classes, self.num_classes)) + spm = spm.todense() + self.confusion_matrix += spm + + def zero_matrix(self): + """ Clear confusion matrix """ + self.confusion_matrix = np.zeros([self.num_classes, self.num_classes], + dtype='int64') + + def mean_iou(self): + iou_list = [] + avg_iou = 0 + # TODO: use numpy sum axis api to simpliy + vji = np.zeros(self.num_classes, dtype=int) + vij = np.zeros(self.num_classes, dtype=int) + for j in range(self.num_classes): + v_j = 0 + for i in range(self.num_classes): + v_j += self.confusion_matrix[j][i] + vji[j] = v_j + + for i in range(self.num_classes): + v_i = 0 + for j in range(self.num_classes): + v_i += self.confusion_matrix[j][i] + vij[i] = v_i + + for c in range(self.num_classes): + total = vji[c] + vij[c] - self.confusion_matrix[c][c] + if total == 0: + iou = 0 + else: + iou = float(self.confusion_matrix[c][c]) / total + avg_iou += iou + iou_list.append(iou) + avg_iou = float(avg_iou) / float(self.num_classes) + return np.array(iou_list), avg_iou + + def accuracy(self): + total = self.confusion_matrix.sum() + total_right = 0 + for c in range(self.num_classes): + total_right += self.confusion_matrix[c][c] + if total == 0: + avg_acc = 0 + else: + avg_acc = float(total_right) / total + + vij = np.zeros(self.num_classes, dtype=int) + for i in range(self.num_classes): + v_i = 0 + for j in range(self.num_classes): + v_i += self.confusion_matrix[j][i] + vij[i] = v_i + + acc_list = [] + for c in range(self.num_classes): + if vij[c] == 0: + acc = 0 + else: + acc = self.confusion_matrix[c][c] / float(vij[c]) + acc_list.append(acc) + return np.array(acc_list), avg_acc + + def kappa(self): + vji = np.zeros(self.num_classes) + vij = np.zeros(self.num_classes) + for j in range(self.num_classes): + v_j = 0 + for i in range(self.num_classes): + v_j += self.confusion_matrix[j][i] + vji[j] = v_j + + for i in range(self.num_classes): + v_i = 0 + for j in range(self.num_classes): + v_i += self.confusion_matrix[j][i] + vij[i] = v_i + + total = self.confusion_matrix.sum() + + # avoid spillovers + # TODO: is it reasonable to hard code 10000.0? 
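+        # Cohen's kappa: po is the observed agreement (diagonal sum / total),
+        # pe is the agreement expected by chance (sum over classes of
+        # row_total * col_total / total^2), and kappa = (po - pe) / (1 - pe).
+        # Dividing the counts by 10000.0 below rescales numerator and
+        # denominator alike purely to avoid overflow; the ratio is unchanged.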
+ total = float(total) / 10000.0 + vji = vji / 10000.0 + vij = vij / 10000.0 + + tp = 0 + tc = 0 + for c in range(self.num_classes): + tp += vji[c] * vij[c] + tc += self.confusion_matrix[c][c] + + tc = tc / 10000.0 + pe = tp / (total * total) + po = tc / total + + kappa = (po - pe) / (1 - pe) + return kappa diff --git a/dygraph/utils/utils.py b/dygraph/utils/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..435374aea27e57db16fee13522cb09d68d40fbec --- /dev/null +++ b/dygraph/utils/utils.py @@ -0,0 +1,275 @@ +# coding: utf8 +# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import sys +import time +import os +import os.path as osp +import numpy as np +import six +import yaml +import math +import cv2 +from . import logging + + +def seconds_to_hms(seconds): + h = math.floor(seconds / 3600) + m = math.floor((seconds - h * 3600) / 60) + s = int(seconds - h * 3600 - m * 60) + hms_str = "{}:{}:{}".format(h, m, s) + return hms_str + + +def setting_environ_flags(): + if 'FLAGS_eager_delete_tensor_gb' not in os.environ: + os.environ['FLAGS_eager_delete_tensor_gb'] = '0.0' + if 'FLAGS_allocator_strategy' not in os.environ: + os.environ['FLAGS_allocator_strategy'] = 'auto_growth' + if "CUDA_VISIBLE_DEVICES" in os.environ: + if os.environ["CUDA_VISIBLE_DEVICES"].count("-1") > 0: + os.environ["CUDA_VISIBLE_DEVICES"] = "" + + +def get_environ_info(): + setting_environ_flags() + import paddle.fluid as fluid + info = dict() + info['place'] = 'cpu' + info['num'] = int(os.environ.get('CPU_NUM', 1)) + if os.environ.get('CUDA_VISIBLE_DEVICES', None) != "": + if hasattr(fluid.core, 'get_cuda_device_count'): + gpu_num = 0 + try: + gpu_num = fluid.core.get_cuda_device_count() + except: + os.environ['CUDA_VISIBLE_DEVICES'] = '' + pass + if gpu_num > 0: + info['place'] = 'cuda' + info['num'] = fluid.core.get_cuda_device_count() + return info + + +def parse_param_file(param_file, return_shape=True): + from paddle.fluid.proto.framework_pb2 import VarType + f = open(param_file, 'rb') + version = np.fromstring(f.read(4), dtype='int32') + lod_level = np.fromstring(f.read(8), dtype='int64') + for i in range(int(lod_level)): + _size = np.fromstring(f.read(8), dtype='int64') + _ = f.read(_size) + version = np.fromstring(f.read(4), dtype='int32') + tensor_desc = VarType.TensorDesc() + tensor_desc_size = np.fromstring(f.read(4), dtype='int32') + tensor_desc.ParseFromString(f.read(int(tensor_desc_size))) + tensor_shape = tuple(tensor_desc.dims) + if return_shape: + f.close() + return tuple(tensor_desc.dims) + if tensor_desc.data_type != 5: + raise Exception( + "Unexpected data type while parse {}".format(param_file)) + data_size = 4 + for i in range(len(tensor_shape)): + data_size *= tensor_shape[i] + weight = np.fromstring(f.read(data_size), dtype='float32') + f.close() + return np.reshape(weight, tensor_shape) + + +def fuse_bn_weights(exe, main_prog, weights_dir): + import paddle.fluid as fluid + logging.info("Try to fuse weights of 
batch_norm...") + bn_vars = list() + for block in main_prog.blocks: + ops = list(block.ops) + for op in ops: + if op.type == 'affine_channel': + scale_name = op.input('Scale')[0] + bias_name = op.input('Bias')[0] + prefix = scale_name[:-5] + mean_name = prefix + 'mean' + variance_name = prefix + 'variance' + if not osp.exists(osp.join( + weights_dir, mean_name)) or not osp.exists( + osp.join(weights_dir, variance_name)): + logging.info( + "There's no batch_norm weight found to fuse, skip fuse_bn." + ) + return + + bias = block.var(bias_name) + pretrained_shape = parse_param_file( + osp.join(weights_dir, bias_name)) + actual_shape = tuple(bias.shape) + if pretrained_shape != actual_shape: + continue + bn_vars.append( + [scale_name, bias_name, mean_name, variance_name]) + eps = 1e-5 + for names in bn_vars: + scale_name, bias_name, mean_name, variance_name = names + scale = parse_param_file( + osp.join(weights_dir, scale_name), return_shape=False) + bias = parse_param_file( + osp.join(weights_dir, bias_name), return_shape=False) + mean = parse_param_file( + osp.join(weights_dir, mean_name), return_shape=False) + variance = parse_param_file( + osp.join(weights_dir, variance_name), return_shape=False) + bn_std = np.sqrt(np.add(variance, eps)) + new_scale = np.float32(np.divide(scale, bn_std)) + new_bias = bias - mean * new_scale + scale_tensor = fluid.global_scope().find_var(scale_name).get_tensor() + bias_tensor = fluid.global_scope().find_var(bias_name).get_tensor() + scale_tensor.set(new_scale, exe.place) + bias_tensor.set(new_bias, exe.place) + if len(bn_vars) == 0: + logging.info( + "There's no batch_norm weight found to fuse, skip fuse_bn.") + else: + logging.info("There's {} batch_norm ops been fused.".format( + len(bn_vars))) + + +def load_pdparams(exe, main_prog, model_dir): + import paddle.fluid as fluid + from paddle.fluid.proto.framework_pb2 import VarType + from paddle.fluid.framework import Program + + vars_to_load = list() + import pickle + with open(osp.join(model_dir, 'model.pdparams'), 'rb') as f: + params_dict = pickle.load(f) if six.PY2 else pickle.load( + f, encoding='latin1') + unused_vars = list() + for var in main_prog.list_vars(): + if not isinstance(var, fluid.framework.Parameter): + continue + if var.name not in params_dict: + raise Exception("{} is not in saved model".format(var.name)) + if var.shape != params_dict[var.name].shape: + unused_vars.append(var.name) + logging.warning( + "[SKIP] Shape of pretrained weight {} doesn't match.(Pretrained: {}, Actual: {})" + .format(var.name, params_dict[var.name].shape, var.shape)) + continue + vars_to_load.append(var) + logging.debug("Weight {} will be load".format(var.name)) + for var_name in unused_vars: + del params_dict[var_name] + fluid.io.set_program_state(main_prog, params_dict) + + if len(vars_to_load) == 0: + logging.warning( + "There is no pretrain weights loaded, maybe you should check you pretrain model!" 
+ ) + else: + logging.info("There are {} varaibles in {} are loaded.".format( + len(vars_to_load), model_dir)) + + +def load_pretrained_weights(exe, main_prog, weights_dir, fuse_bn=False): + if not osp.exists(weights_dir): + raise Exception("Path {} not exists.".format(weights_dir)) + if osp.exists(osp.join(weights_dir, "model.pdparams")): + return load_pdparams(exe, main_prog, weights_dir) + import paddle.fluid as fluid + vars_to_load = list() + for var in main_prog.list_vars(): + if not isinstance(var, fluid.framework.Parameter): + continue + if not osp.exists(osp.join(weights_dir, var.name)): + logging.debug("[SKIP] Pretrained weight {}/{} doesn't exist".format( + weights_dir, var.name)) + continue + pretrained_shape = parse_param_file(osp.join(weights_dir, var.name)) + actual_shape = tuple(var.shape) + if pretrained_shape != actual_shape: + logging.warning( + "[SKIP] Shape of pretrained weight {}/{} doesn't match.(Pretrained: {}, Actual: {})" + .format(weights_dir, var.name, pretrained_shape, actual_shape)) + continue + vars_to_load.append(var) + logging.debug("Weight {} will be load".format(var.name)) + + params_dict = fluid.io.load_program_state( + weights_dir, var_list=vars_to_load) + fluid.io.set_program_state(main_prog, params_dict) + if len(vars_to_load) == 0: + logging.warning( + "There is no pretrain weights loaded, maybe you should check you pretrain model!" + ) + else: + logging.info("There are {} varaibles in {} are loaded.".format( + len(vars_to_load), weights_dir)) + if fuse_bn: + fuse_bn_weights(exe, main_prog, weights_dir) + + +def visualize(image, result, save_dir=None, weight=0.6): + """ + Convert segment result to color image, and save added image. + Args: + image: the path of origin image + result: the predict result of image + save_dir: the directory for saving visual image + weight: the image weight of visual image, and the result weight is (1 - weight) + """ + label_map = result['label_map'] + color_map = get_color_map_list(256) + color_map = np.array(color_map).astype("uint8") + # Use OpenCV LUT for color mapping + c1 = cv2.LUT(label_map, color_map[:, 0]) + c2 = cv2.LUT(label_map, color_map[:, 1]) + c3 = cv2.LUT(label_map, color_map[:, 2]) + pseudo_img = np.dstack((c1, c2, c3)) + + im = cv2.imread(image) + vis_result = cv2.addWeighted(im, weight, pseudo_img, 1 - weight, 0) + + if save_dir is not None: + if not os.path.exists(save_dir): + os.makedirs(save_dir) + image_name = os.path.split(image)[-1] + out_path = os.path.join(save_dir, image_name) + cv2.imwrite(out_path, vis_result) + else: + return vis_result + + +def get_color_map_list(num_classes): + """ Returns the color map for visualizing the segmentation mask, + which can support arbitrary number of classes. + Args: + num_classes: Number of classes + Returns: + The color map + """ + num_classes += 1 + color_map = num_classes * [0, 0, 0] + for i in range(0, num_classes): + j = 0 + lab = i + while lab: + color_map[i * 3] |= (((lab >> 0) & 1) << (7 - j)) + color_map[i * 3 + 1] |= (((lab >> 1) & 1) << (7 - j)) + color_map[i * 3 + 2] |= (((lab >> 2) & 1) << (7 - j)) + j += 1 + lab >>= 3 + color_map = [color_map[i:i + 3] for i in range(0, len(color_map), 3)] + color_map = color_map[1:] + return color_map
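A small standalone sketch (not part of the patch) of how visualize() builds its pseudo-color rendering from get_color_map_list and cv2.LUT; the import path and the dummy label map below are assumptions for illustration only:

import numpy as np
import cv2
from dygraph.utils import get_color_map_list  # hypothetical import path

label_map = np.random.randint(0, 4, size=(64, 64), dtype=np.uint8)  # dummy prediction
color_map = np.array(get_color_map_list(256), dtype=np.uint8)       # 256 x 3 palette

# cv2.LUT maps every uint8 label value through one palette channel at a time.
c1 = cv2.LUT(label_map, color_map[:, 0])
c2 = cv2.LUT(label_map, color_map[:, 1])
c3 = cv2.LUT(label_map, color_map[:, 2])
pseudo_img = np.dstack((c1, c2, c3))  # (64, 64, 3) color rendering of the mask

# Blend with the original image exactly as visualize() does (weight=0.6).
im = np.zeros((64, 64, 3), dtype=np.uint8)  # stand-in for cv2.imread(image)
vis_result = cv2.addWeighted(im, 0.6, pseudo_img, 0.4, 0)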