# data_structure.py (2.1 KB) -- committed by FDInSky
import numpy as np


class BufferDict(dict):
    """A ``dict`` whose entries are write-once unless explicitly overwritten.

    * ``d[key]`` / ``get(key)`` raise ``Exception`` when ``key`` is missing.
    * ``d[key] = value`` / ``set`` / ``update`` raise ``Exception`` when
      ``key`` already exists; use :meth:`update_v` to overwrite an existing
      entry on purpose.
    * :meth:`debug` prints a short summary of selected entries when the
      ``'open_debug'`` entry is truthy.

    Note that the constructor stores its initial items through
    ``dict.__init__``, which does not go through ``__setitem__``.
    """

    # Sentinel distinguishing "no default supplied" from ``default=None``.
    _MISSING = object()

    def __init__(self, *args, **kwargs):
        """Create the buffer from the same arguments ``dict`` accepts."""
        super(BufferDict, self).__init__(*args, **kwargs)

    def __getitem__(self, key):
        """Return the value stored under ``key``.

        Raises:
            Exception: if ``key`` is not present.
        """
        if key in self:
            return super(BufferDict, self).__getitem__(key)
        raise Exception("The %s is not in global inputs dict" % key)

    def __setitem__(self, key, value):
        """Store ``value`` under a new ``key``.

        Raises:
            Exception: if ``key`` is already present (entries are write-once;
                see :meth:`update_v` for deliberate overwrites).
        """
        if key in self:
            raise Exception("The %s is already in global inputs dict" % key)
        super(BufferDict, self).__setitem__(key, value)

    def update(self, *args, **kwargs):
        """Insert several new entries, enforcing the write-once rule.

        Accepts the same arguments as ``dict.update``. Items are inserted one
        by one through ``__setitem__``, so an already-present key raises
        ``Exception``; items inserted before the offending key stay inserted.
        """
        for k, v in dict(*args, **kwargs).items():
            self[k] = v

    def update_v(self, key, value):
        """Overwrite the value of an existing ``key``.

        Raises:
            Exception: if ``key`` is not present.
        """
        if key not in self:
            raise Exception("The %s is not in global inputs dict" % key)
        # Bypass the write-once check in our own __setitem__.
        super(BufferDict, self).__setitem__(key, value)

    def get(self, key, default=_MISSING):
        """Return the value stored under ``key``.

        Without ``default`` this behaves like ``self[key]`` and raises
        ``Exception`` for a missing key. When ``default`` is supplied it is
        returned for a missing key instead, matching ``dict.get(key, default)``.
        """
        if default is BufferDict._MISSING:
            return self[key]
        return super(BufferDict, self).get(key, default)

    def set(self, key, value):
        """Store ``value`` under a new ``key``; same as ``self[key] = value``."""
        return self.__setitem__(key, value)

    def debug(self, dshape=True, dvalue=True, dtype=False):
        """Print a summary dict of selected entries if debugging is enabled.

        Nothing is printed unless ``self['open_debug']`` is truthy (a missing
        ``'open_debug'`` entry raises ``Exception`` via ``__getitem__``).

        The entries to report come from ``self['debug_names']`` when present,
        otherwise every key in the buffer is reported. Each element of
        ``debug_names`` is either

        * a key of this buffer, reported as ``{key: info}``; or
        * a dict ``{key: [member, ...]}``, reported as
          ``{key: {member: info}}`` where each value is
          ``self[key][member]``. A non-list value yields an empty report
          for that key.

        Args:
            dshape: forwarded to :meth:`get_debug_info`.
            dvalue: forwarded to :meth:`get_debug_info`.
            dtype: forwarded to :meth:`get_debug_info`.
        """
        if not self['open_debug']:
            return

        if 'debug_names' in self:
            ditems = self['debug_names']
        else:
            # Snapshot the keys so the report covers the whole buffer.
            ditems = list(self)

        infos = {}
        for k in ditems:
            if isinstance(k, dict):
                for i, j in k.items():
                    # Fresh dict per buffer key so reports are not shared
                    # between different keys of the same spec.
                    i_d = {}
                    if isinstance(j, list):
                        for jj in j:
                            i_d[jj] = self.get_debug_info(
                                self[i][jj], dshape, dvalue, dtype)
                    infos[i] = i_d
            else:
                infos[k] = self.get_debug_info(self[k], dshape, dvalue, dtype)
        print(infos)

    def get_debug_info(self, v, dshape=True, dvalue=True, dtype=False):
        """Build a list of summary facts about ``v``.

        Args:
            v: the object to summarise.
            dshape: if truthy and ``v`` has a ``shape`` attribute, append
                ``v.shape``.
            dvalue: if truthy and ``v`` has a ``numpy`` attribute, append the
                mean absolute value of ``v.numpy()``.
            dtype: if truthy, append ``type(v)``.

        Returns:
            A list with the requested facts, in the order shape, value, type.
            It is empty when none of them apply.
        """
        info = []
        if dshape and hasattr(v, 'shape'):
            info.append(v.shape)
        if dvalue and hasattr(v, 'numpy'):
            info.append(np.mean(np.abs(v.numpy())))
        if dtype:
            info.append(type(v))
        return info