#   Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function

import os
import shutil
import unittest
import tempfile

import numpy as np

from test_dist_base import TestDistBase, RUN_STEP

import os

# Path of this test file without its ".py" extension; passed as the
# log_name when launching cluster runs so their logs are grouped per test.
flag_name = os.path.splitext(__file__)[0]


class TestDistSaveLoadDense2x2(TestDistBase):
    """Save a dense model from a local run, reload it on a 2-trainer
    cluster, and check all runs converge to (nearly) the same variables."""

    def _setup_config(self):
        self._sync_mode = True
        self._enforce_place = "CPU"

    def check_with_place(self,
                         model_file,
                         delta=1e-3,
                         check_error_log=False,
                         need_envs=None,
                         log_name=""):
        """Run *model_file* locally (with SAVE=1) and then on a cluster
        (with LOAD=1 against the same model dir), comparing results.

        Args:
            model_file: script path executed by the base-class runners.
            delta: kept for signature parity with the base class; the
                comparison below uses a fixed decimal tolerance instead.
            check_error_log: when True, enable verbose GLOG logging.
            need_envs: optional dict of extra environment variables merged
                into both runs (None means no extras).
            log_name: kept for signature parity; cluster runs log under
                ``flag_name``.
        """
        required_envs = {
            "PATH": os.getenv("PATH", ""),
            "PYTHONPATH": os.getenv("PYTHONPATH", ""),
            "LD_LIBRARY_PATH": os.getenv("LD_LIBRARY_PATH", ""),
            "http_proxy": ""
        }

        # Avoid the mutable-default-argument pitfall: merge extras here.
        required_envs.update(need_envs or {})

        if check_error_log:
            required_envs["GLOG_vmodule"] = \
                "fused_all_reduce_op_handle=10,all_reduce_op_handle=10,alloc_continuous_space_op=10,fuse_all_reduce_op_pass=10,alloc_continuous_space_for_grad_pass=10,fast_threaded_ssa_graph_executor=10"
            required_envs["GLOG_logtostderr"] = "1"

        model_dir = tempfile.mkdtemp()
        try:
            # The local run trains and saves the model into model_dir ...
            local_env = {}
            local_env["SAVE"] = "1"
            local_env["MODEL_DIR"] = model_dir
            local_env.update(required_envs)

            # ... and the cluster run loads it back before training.
            cluster_env = {}
            cluster_env["LOAD"] = "1"
            cluster_env["MODEL_DIR"] = model_dir
            cluster_env.update(required_envs)

            local_var = self._run_local(model_file, local_env,
                                        check_error_log)
            tr0_var, tr1_var = self._run_cluster(
                model_file, cluster_env, check_error_log,
                log_name=flag_name)
        finally:
            # Clean up the temp dir even if a run or assertion fails.
            shutil.rmtree(model_dir, ignore_errors=True)

        local_np = np.array(local_var)
        train0_np = np.array(tr0_var)
        train1_np = np.array(tr1_var)

        # Local and both trainers should agree to ~2 decimal places.
        np.testing.assert_almost_equal(local_np, train0_np, decimal=2)
        np.testing.assert_almost_equal(local_np, train1_np, decimal=2)
        np.testing.assert_almost_equal(train0_np, train1_np, decimal=2)

    def test_dist(self):
        need_envs = {
            "IS_DISTRIBUTED": '0',
            "IS_SPARSE": '0',
            'IS_SELF_CONTAINED_LR': '1',
            'SAVE_MODE': 'LOCAL',
        }
        self.check_with_place(
            "dist_save_load.py",
            delta=0,
            check_error_log=False,
            need_envs=need_envs)


class TestDistSaveLoadWithPServerStateDense2x2(TestDistBase):
    """Save a model in DIST mode (including pserver optimizer state) from
    one cluster run, reload it in a second cluster run, and check both
    runs end with (nearly) the same trainer variables."""

    def _setup_config(self):
        self._sync_mode = True
        self._enforce_place = "CPU"

    def check_with_place(self,
                         model_file,
                         delta=1e-3,
                         check_error_log=False,
                         need_envs=None,
                         log_name=""):
        """Run *model_file* on a cluster twice: first with SAVE_MODE=DIST
        and SAVE=1, then with LOAD=1 against the same model dir.

        Args:
            model_file: script path executed by the base-class runners.
            delta: kept for signature parity with the base class; the
                comparison below uses a fixed decimal tolerance instead.
            check_error_log: when True, enable verbose GLOG logging.
            need_envs: optional dict of extra environment variables merged
                into both runs (None means no extras).
            log_name: kept for signature parity; cluster runs log under
                ``flag_name``.
        """
        required_envs = {
            "PATH": os.getenv("PATH", ""),
            "PYTHONPATH": os.getenv("PYTHONPATH", ""),
            "LD_LIBRARY_PATH": os.getenv("LD_LIBRARY_PATH", ""),
            "http_proxy": ""
        }

        # Avoid the mutable-default-argument pitfall: merge extras here.
        required_envs.update(need_envs or {})

        if check_error_log:
            required_envs["GLOG_vmodule"] = \
                "fused_all_reduce_op_handle=10,all_reduce_op_handle=10,alloc_continuous_space_op=10,fuse_all_reduce_op_pass=10,alloc_continuous_space_for_grad_pass=10,fast_threaded_ssa_graph_executor=10"
            required_envs["GLOG_logtostderr"] = "1"

        model_dir = tempfile.mkdtemp()
        try:
            # First cluster run: save model + pserver state into model_dir.
            save_env = {}
            save_env["SAVE_MODE"] = "DIST"
            save_env["SAVE"] = "1"
            save_env["MODEL_DIR"] = model_dir
            save_env.update(required_envs)

            tr0_var_1, tr1_var_1 = self._run_cluster(
                model_file, save_env, check_error_log, log_name=flag_name)

            # Second cluster run: load the saved state back and continue.
            load_env = {}
            load_env["LOAD"] = "1"
            load_env["MODEL_DIR"] = model_dir
            load_env.update(required_envs)

            tr0_var_2, tr1_var_2 = self._run_cluster(
                model_file, load_env, check_error_log, log_name=flag_name)
        finally:
            # Clean up the temp dir even if a run or assertion fails.
            shutil.rmtree(model_dir, ignore_errors=True)

        train0_1_np = np.array(tr0_var_1)
        train1_1_np = np.array(tr1_var_1)
        train0_2_np = np.array(tr0_var_2)
        train1_2_np = np.array(tr1_var_2)

        # Each trainer's result should be reproducible across save/load.
        np.testing.assert_almost_equal(train0_1_np, train0_2_np, decimal=2)
        np.testing.assert_almost_equal(train1_1_np, train1_2_np, decimal=2)

    def test_dist(self):
        need_envs = {
            "IS_DISTRIBUTED": '0',
            "IS_SPARSE": '0',
            'IS_SELF_CONTAINED_LR': '1',
            'SAVE_MODE': 'DIST',
            'OPTIMIZER': 'ADAM',
            # Randomized so the load run resumes from a varying step count.
            'SKIP_STEPS': str(np.random.randint(2, 6))
        }
        self.check_with_place(
            "dist_save_load.py",
            delta=0,
            check_error_log=True,
            need_envs=need_envs,
            log_name=flag_name)


# Standard unittest entry point: discover and run the test cases above.
if __name__ == "__main__":
    unittest.main()