# filesystem.py
# Author (per repository history): LielinJiang
import os
import six
import pickle
import paddle

def makedirs(dir):
    """Create directory ``dir``, including any missing parents.

    Does nothing if ``dir`` already exists as a directory. This also holds
    when another process or thread creates it at the same moment.

    Args:
        dir (str): path of the directory to create. The name shadows the
            ``dir`` builtin; it is kept for keyword-argument compatibility.

    Raises:
        OSError: if the directory cannot be created, or if ``dir`` exists
            but is not a directory.
    """
    try:
        os.makedirs(dir)
    except OSError:
        # Checking existence first and then creating is racy: the path can
        # appear between the two calls. Instead, attempt creation and accept
        # the failure only if a directory is present now. ``exist_ok`` is
        # avoided because this module still supports Python 2.
        if not os.path.isdir(dir):
            raise


def _to_serializable(obj):
    """Return ``obj`` with every paddle tensor replaced by a numpy array.

    Tensors (``paddle.fluid.framework.Variable`` or
    ``paddle.fluid.core.VarBase``) are converted with ``.numpy()``. Dicts,
    including ``OrderedDict``, are rebuilt as plain dicts whose values are
    converted recursively. Any other value is returned unchanged.
    """
    if isinstance(obj,
                  (paddle.fluid.framework.Variable, paddle.fluid.core.VarBase)):
        return obj.numpy()
    if isinstance(obj, dict):
        return {k: _to_serializable(v) for k, v in obj.items()}
    return obj


def save(state_dicts, file_name):
    """Pickle ``state_dicts`` to ``file_name``, storing tensors as numpy arrays.

    ``state_dicts`` may be a single state dict (name -> tensor or value) or
    a dict of state dicts. The latter presumably holds one entry per model
    or optimizer; confirm against callers. Tensors are converted at any
    nesting depth, so a dict that mixes top-level tensors with nested dicts
    is fully converted too.

    Args:
        state_dicts (dict): the (possibly nested) state to save.
        file_name (str): destination path; an existing file is overwritten.
    """
    final_dict = _to_serializable(state_dicts)

    with open(file_name, 'wb') as f:
        # Protocol 2 is the newest pickle protocol Python 2 can read.
        pickle.dump(final_dict, f, protocol=2)


def load(file_name):
    """Unpickle and return the object stored in ``file_name``.

    On Python 3 the pickle is read with ``encoding='latin1'``. This is
    presumably so that files written under Python 2 (for example, ones
    containing numpy arrays) can still be loaded; confirm if that matters.

    WARNING: unpickling can execute arbitrary code. Only load files that
    come from a trusted source.
    """
    with open(file_name, 'rb') as handle:
        if six.PY2:
            result = pickle.load(handle)
        else:
            result = pickle.load(handle, encoding='latin1')
    return result