# data_structure.py (2.0 KB) - committed by FDInSky
import numpy as np


class BufferDict(dict):
    """A ``dict`` that forbids silent overwrites and fails loudly on missing keys.

    * ``d[key] = value``, ``set`` and ``update`` only insert *new* keys;
      assigning a key that already exists raises.  Use ``update_v`` to
      overwrite an existing key on purpose.
    * ``d[key]`` and ``get`` raise if the key is absent (``get`` may be given
      an explicit default instead).
    * ``debug`` prints a summary of selected entries, controlled by the
      ``'open_debug'`` and ``'debug_names'`` entries of the dict itself.

    All failures raise a plain ``Exception`` with a descriptive message.
    """

    # Sentinel distinguishing "no default given" from ``default=None`` in get().
    _MISSING = object()

    def __init__(self, **kwargs):
        # dict.__init__ does not route through __setitem__; keyword names are
        # unique by construction, so no duplicate check is needed here.
        super(BufferDict, self).__init__(**kwargs)

    def __getitem__(self, key):
        """Return ``self[key]``; raise ``Exception`` if *key* is absent."""
        if key not in self:
            # ``(key,)``: a tuple key must be formatted as one value, not
            # unpacked into several %-arguments (which raised TypeError).
            raise Exception("The %s is not in global inputs dict" % (key,))
        return super(BufferDict, self).__getitem__(key)

    def __setitem__(self, key, value):
        """Insert a new *key*; raise ``Exception`` if it already exists."""
        if key in self:
            raise Exception("The %s is already in global inputs dict" % (key,))
        super(BufferDict, self).__setitem__(key, value)

    def update(self, *args, **kwargs):
        """Insert every pair of ``dict(*args, **kwargs)``; all keys must be new.

        Every key is checked before anything is inserted, so a clash raises
        and leaves this dict unchanged rather than half-updated.
        """
        new_items = dict(*args, **kwargs)
        for key in new_items:
            if key in self:
                raise Exception(
                    "The %s is already in global inputs dict" % (key,))
        for key, value in new_items.items():
            self[key] = value

    def update_v(self, key, value):
        """Overwrite the value of an existing *key*; raise if it is absent."""
        if key not in self:
            raise Exception("The %s is not in global inputs dict" % (key,))
        super(BufferDict, self).__setitem__(key, value)

    def get(self, key, default=_MISSING):
        """Return ``self[key]``.

        If *key* is absent, return *default* when one was passed; otherwise
        raise ``Exception`` (unlike ``dict.get``, which returns ``None``).
        """
        if default is not self._MISSING and key not in self:
            return default
        return self[key]

    def set(self, key, value):
        """Same as ``self[key] = value``: insert a new *key* only."""
        self[key] = value

    def debug(self, dshape=True, dvalue=True, dtype=False):
        """Print a summary of the entries selected for debugging.

        Nothing happens unless ``self['open_debug']`` is truthy.
        ``self['debug_names']`` selects what to summarise. ``None`` means
        every key. Otherwise each item is one of:

        * a key of this dict, or
        * a ``dict`` mapping an outer key to a list/tuple of inner keys, in
          which case ``self[outer][inner]`` is summarised for each inner key.

        Args:
            dshape, dvalue, dtype: forwarded to ``get_debug_info`` for
                every summarised value.

        Returns:
            The printed ``{name: info}`` dict, or ``None`` when debugging
            is off.
        """
        if not self['open_debug']:
            return None
        names = self['debug_names']
        if names is None:
            names = self.keys()

        def describe(value):
            return self.get_debug_info(
                value, dshape=dshape, dvalue=dvalue, dtype=dtype)

        infos = {}
        for name in names:
            if not isinstance(name, dict):
                infos[name] = describe(self[name])
                continue
            for outer, inner_keys in name.items():
                # Build a fresh dict per outer key so entries of different
                # outer keys are neither mixed nor aliased.
                if isinstance(inner_keys, (list, tuple)):
                    infos[outer] = {
                        inner: describe(self[outer][inner])
                        for inner in inner_keys}
                else:
                    infos[outer] = {}
        print(infos)
        return infos

    def get_debug_info(self, v, dshape=True, dvalue=True, dtype=False):
        """Return a list summarising *v*.

        Entries, in this order and each only when enabled:

        * ``v.shape`` if *dshape* and *v* has a ``shape`` attribute;
        * ``np.mean(np.abs(v.numpy()))`` if *dvalue* and *v* has a ``numpy``
          attribute (presumably a framework tensor - TODO confirm which);
        * ``type(v)`` if *dtype*.
        """
        info = []
        if dshape and hasattr(v, 'shape'):
            info.append(v.shape)
        if dvalue and hasattr(v, 'numpy'):
            info.append(np.mean(np.abs(v.numpy())))
        if dtype:
            info.append(type(v))
        return info