# filesystem.py (1.2 KB) -- last committed by LielinJiang.
import os
import six
import pickle
import paddle

def makedirs(dir):
    """Create directory ``dir``, including any missing parents.

    Nothing is done when ``dir`` already exists, whether it is a directory
    or some other kind of file system entry.
    """
    if os.path.exists(dir):
        return
    os.makedirs(dir)

def save(state_dicts, file_name):
    """Pickle a state dict (or a dict of state dicts) to ``file_name``.

    Paddle tensor values are converted with ``.numpy()`` before pickling;
    every other value is stored unchanged.

    Two layouts are handled:

    * a flat state dict that has at least one Paddle tensor at the top
      level -- the whole dict is converted;
    * a dict whose values are state dicts (e.g. ``{'net': {...},
      'opt': {...}}``) -- each nested dict is converted and any other
      top-level value is stored as-is.

    Args:
        state_dicts (dict): the state to save.
        file_name (str): destination path; an existing file is overwritten.
    """

    def is_tensor(value):
        # The paddle attributes are looked up only when a value is inspected.
        return isinstance(
            value, (paddle.framework.Variable, paddle.imperative.core.VarBase))

    def convert(state_dict):
        # Convert *every* tensor entry to numpy and keep the rest unchanged;
        # the loop must not stop early at the first non-tensor value.
        return {
            k: v.numpy() if is_tensor(v) else v
            for k, v in state_dict.items()
        }

    final_dict = {}
    for k, v in state_dicts.items():
        if is_tensor(v):
            # A tensor at the top level means `state_dicts` is itself a flat
            # state dict: convert it as a whole, discarding partial results.
            final_dict = convert(state_dicts)
            break
        elif isinstance(v, dict):
            final_dict[k] = convert(v)
        else:
            final_dict[k] = v

    with open(file_name, 'wb') as f:
        # Protocol 2 is the highest pickle protocol Python 2 can read.
        pickle.dump(final_dict, f, protocol=2)


def load(file_name):
    """Unpickle and return the object stored in ``file_name``.

    On Python 3 the file is read with ``encoding='latin1'``, the encoding the
    ``pickle`` docs require for NumPy arrays pickled under Python 2.

    WARNING: unpickling can execute arbitrary code -- only load files from
    a trusted source.
    """
    load_kwargs = {} if six.PY2 else {'encoding': 'latin1'}
    with open(file_name, 'rb') as fh:
        return pickle.load(fh, **load_kwargs)