filesystem.py 1.8 KB
Newer Older
Q
qingqing01 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14
#   Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

L
LielinJiang 已提交
15 16 17 18 19
import os
import six
import pickle
import paddle

L
LielinJiang 已提交
20

L
LielinJiang 已提交
21 22
def makedirs(dir):
    """Create directory ``dir`` (and any missing parents) if nothing exists there.

    Several processes may call this at the same time, e.g. one per GPU when
    training on multiple GPUs. If another process creates the directory between
    the existence check and ``os.makedirs``, the resulting error is ignored.
    Any other failure is re-raised instead of being silently swallowed, such as
    permission denied or a parent path component that is a regular file.

    If ``dir`` already exists, even as a non-directory, this returns without
    doing anything.

    Args:
        dir (str): Path of the directory to create.

    Raises:
        OSError: If the directory could not be created and still does not exist.
    """
    if os.path.exists(dir):
        return
    try:
        os.makedirs(dir)
    except OSError:
        # Lost the creation race to another process: the directory now
        # exists, which is all the caller needs. Anything else is a real error.
        if not os.path.isdir(dir):
            raise
L
LielinJiang 已提交
28 29


L
LielinJiang 已提交
30
def save(state_dicts, file_name):
    """Pickle ``state_dicts`` to ``file_name``, converting paddle tensors to numpy.

    Two input layouts are handled:

    * A flat state dict: if any top-level value is a paddle
      ``fluid.framework.Variable`` / ``fluid.core.VarBase``, the whole mapping
      is treated as one state dict. Every tensor value in it is replaced by its
      ``.numpy()`` result. Nested dicts are not descended into in this case.
    * A mapping of state dicts: each top-level value that is a ``dict`` has its
      tensor values converted as above (one level deep). All other top-level
      values are stored unchanged.

    The result is written with pickle protocol 2, which Python 2 can also read.

    Args:
        state_dicts (dict): A single state dict or a dict of state dicts.
        file_name (str): Destination path; the file is overwritten.
    """

    def _is_tensor(value):
        # The paddle type tuple is built on each call, so paddle is only
        # touched when there is actually a value to inspect.
        return isinstance(
            value,
            (paddle.fluid.framework.Variable, paddle.fluid.core.VarBase))

    def _to_numpy(state_dict):
        # Shallow copy with tensor values swapped for numpy arrays.
        return {
            name: value.numpy() if _is_tensor(value) else value
            for name, value in state_dict.items()
        }

    payload = {}
    for name, value in state_dicts.items():
        if _is_tensor(value):
            # A tensor at the top level means state_dicts is itself one flat
            # state dict: discard any partial result and convert it whole.
            payload = _to_numpy(state_dicts)
            break
        payload[name] = _to_numpy(value) if isinstance(value, dict) else value

    with open(file_name, 'wb') as handle:
        pickle.dump(payload, handle, protocol=2)


def load(file_name):
    """Unpickle and return the object stored in ``file_name``.

    On Python 3 the file is read with ``encoding='latin1'``. That tells pickle
    how to decode 8-bit strings written by Python 2, which is needed for numpy
    arrays pickled under Python 2. On Python 2 no extra arguments are passed.

    NOTE: ``pickle.load`` can execute arbitrary code embedded in the file;
    only call this on files from a trusted source.

    Args:
        file_name (str): Path of the pickle file to read.

    Returns:
        The unpickled object.
    """
    load_kwargs = {} if six.PY2 else {'encoding': 'latin1'}
    with open(file_name, 'rb') as handle:
        return pickle.load(handle, **load_kwargs)