未验证 提交 101f74cb 编写于 作者: T tangwei12 提交者: GitHub

fix save/load in fleet (#17675)

* fix save/load in Fleet
* add UT framework of Fleet
上级 f1d458da
......@@ -53,7 +53,7 @@ paddle.fluid.io.save_persistables (ArgSpec(args=['executor', 'dirname', 'main_pr
paddle.fluid.io.load_vars (ArgSpec(args=['executor', 'dirname', 'main_program', 'vars', 'predicate', 'filename'], varargs=None, keywords=None, defaults=(None, None, None, None)), ('document', '1bb9454cf09d71f190bb51550c5a3ac9'))
paddle.fluid.io.load_params (ArgSpec(args=['executor', 'dirname', 'main_program', 'filename'], varargs=None, keywords=None, defaults=(None, None)), ('document', '944291120d37bdb037a689d2c86d0a6e'))
paddle.fluid.io.load_persistables (ArgSpec(args=['executor', 'dirname', 'main_program', 'filename'], varargs=None, keywords=None, defaults=(None, None)), ('document', '28df5bfe26ca7a077f91156abb0fe6d2'))
paddle.fluid.io.save_inference_model (ArgSpec(args=['dirname', 'feeded_var_names', 'target_vars', 'executor', 'main_program', 'model_filename', 'params_filename', 'export_for_deployment'], varargs=None, keywords=None, defaults=(None, None, None, True)), ('document', '89539e459eb959145f15c9c3e38fa97c'))
paddle.fluid.io.save_inference_model (ArgSpec(args=['dirname', 'feeded_var_names', 'target_vars', 'executor', 'main_program', 'model_filename', 'params_filename', 'export_for_deployment', 'program_only'], varargs=None, keywords=None, defaults=(None, None, None, True, False)), ('document', 'fc82bfd137a9b1ab8ebd1651bd35b6e5'))
paddle.fluid.io.load_inference_model (ArgSpec(args=['dirname', 'executor', 'model_filename', 'params_filename', 'pserver_endpoints'], varargs=None, keywords=None, defaults=(None, None, None)), ('document', '2f54d7c206b62f8c10f4f9d78c731cfd'))
paddle.fluid.io.PyReader.__init__ (ArgSpec(args=['self', 'feed_list', 'capacity', 'use_double_buffer', 'iterable', 'return_list'], varargs=None, keywords=None, defaults=(None, None, True, True, False)), ('document', '6adf97f83acf6453d4a6a4b1070f3754'))
paddle.fluid.io.PyReader.decorate_batch_generator (ArgSpec(args=['self', 'reader', 'places'], varargs=None, keywords=None, defaults=(None,)), ('document', '4a072de39998ee4e0de33fcec11325a6'))
......
......@@ -104,7 +104,7 @@ bool RequestGetHandler::Handle(const std::string& varname,
} else {
if (varname != FETCH_BARRIER_MESSAGE && varname != COMPLETE_MESSAGE) {
if (enable_dc_asgd_) {
// NOTE: the format is determined by distributed_transpiler.py
// NOTE: the format is determined by distribute_transpiler.py
std::string param_bak_name =
string::Sprintf("%s.trainer_%d_bak", varname, trainer_id);
VLOG(3) << "getting " << param_bak_name << " trainer_id " << trainer_id;
......
......@@ -15,23 +15,22 @@
from __future__ import print_function
import abc
from enum import Enum
import paddle.fluid as fluid
from paddle.fluid.executor import Executor
from paddle.fluid.optimizer import SGD
from role_maker import MPISymetricRoleMaker
from role_maker import RoleMakerBase
from role_maker import UserDefinedRoleMaker
from paddle.fluid.incubate.fleet.base.role_maker import MPISymetricRoleMaker
from paddle.fluid.incubate.fleet.base.role_maker import RoleMakerBase
from paddle.fluid.incubate.fleet.base.role_maker import UserDefinedRoleMaker
class Mode(Enum):
class Mode:
"""
There are various mode for fleet, each of them is designed for different model.
"""
TRANSPILER = 1,
PSLIB = 2,
TRANSPILER = 1
PSLIB = 2
COLLECTIVE = 3
......
......@@ -13,7 +13,6 @@
# limitations under the License.
from __future__ import print_function
from enum import Enum
__all__ = [
'Role', 'RoleMakerBase', 'MPISymetricRoleMaker', 'UserDefinedRoleMaker',
......@@ -21,8 +20,8 @@ __all__ = [
]
class Role(Enum):
WORKER = 1,
class Role:
WORKER = 1
SERVER = 2
......@@ -313,7 +312,7 @@ class UserDefinedRoleMaker(RoleMakerBase):
raise ValueError("current_id must be gather or equal 0")
self._current_id = current_id
if not isinstance(role, Role):
if role != Role.WORKER and role != Role.SERVER:
raise TypeError("role must be as Role")
else:
self._role = role
......
......@@ -17,9 +17,9 @@ import paddle.fluid as fluid
import paddle.fluid.io as io
import paddle.fluid.transpiler.distribute_transpiler as dist_transpiler
from ..base.fleet_base import Fleet
from ..base.fleet_base import Mode
from ..base.fleet_base import DistributedOptimizer
from paddle.fluid.incubate.fleet.base.fleet_base import Fleet
from paddle.fluid.incubate.fleet.base.fleet_base import Mode
from paddle.fluid.incubate.fleet.base.fleet_base import DistributedOptimizer
class Collective(Fleet):
......
......@@ -15,14 +15,16 @@ import os
import paddle.fluid.io as io
from paddle.fluid.communicator import Communicator
from paddle.fluid.framework import default_main_program
from paddle.fluid.framework import default_startup_program
from paddle.fluid.framework import Program
from paddle.fluid.optimizer import Optimizer
from paddle.fluid.transpiler.distribute_transpiler import DistributeTranspiler as OriginTranspiler
from paddle.fluid.transpiler.distribute_transpiler import DistributeTranspilerConfig
from ...base.fleet_base import DistributedOptimizer
from ...base.fleet_base import Fleet
from ...base.fleet_base import Mode
from paddle.fluid.incubate.fleet.base.fleet_base import DistributedOptimizer
from paddle.fluid.incubate.fleet.base.fleet_base import Fleet
from paddle.fluid.incubate.fleet.base.fleet_base import Mode
class DistributedTranspiler(Fleet):
......@@ -34,6 +36,7 @@ class DistributedTranspiler(Fleet):
super(DistributedTranspiler, self).__init__(Mode.TRANSPILER)
self._transpile_config = None
self._transpiler = None
self._origin_program = None
self.startup_program = None
self.main_program = None
self._communicator = None
......@@ -75,8 +78,7 @@ class DistributedTranspiler(Fleet):
if not os.path.isdir(model_dir):
raise ValueError("There is no directory named '%s'", model_dir)
io.load_persistables(self._executor, model_dir,
self.startup_program)
io.load_persistables(self._executor, model_dir, self.main_program)
def run_server(self):
"""
......@@ -137,9 +139,31 @@ class DistributedTranspiler(Fleet):
Prune the given `main_program` to build a new program especially for inference,
and then save it and all related parameters to given `dirname` by the `executor`.
"""
if main_program is not None:
io.save_inference_model(dirname, feeded_var_names, target_vars,
executor, main_program, None, None,
export_for_deployment)
else:
io.save_inference_model(
dirname,
feeded_var_names,
target_vars,
executor,
self._origin_program,
None,
None,
export_for_deployment,
model_only=True)
model_basename = "__model__"
model_filename = os.path.join(dirname, model_basename)
with open(model_filename, "rb") as f:
program_desc_str = f.read()
program = Program.parse_from_string(program_desc_str)
program._copy_dist_param_info_from(self.main_program)
self.save_persistables(executor, dirname, program)
def save_persistables(self, executor, dirname, main_program=None):
"""
......@@ -152,6 +176,14 @@ class DistributedTranspiler(Fleet):
files, set `filename` None; if you would like to save all variables in a
single file, use `filename` to specify the file name.
"""
if main_program is None:
main_program = self.main_program
if not main_program._is_distributed:
raise ValueError(
"main_program is for local, may not use fleet.save_persistables")
io.save_persistables(executor, dirname, main_program, None)
def _transpile(self, config):
......@@ -162,18 +194,27 @@ class DistributedTranspiler(Fleet):
if not config.sync_mode:
config.runtime_split_send_recv = True
# _origin_program is a deep copy for default_main_program, for inference
self._origin_program = default_main_program().clone(for_test=False)
self._transpile_config = config
self._transpiler = OriginTranspiler(config)
if self.is_worker():
self._transpiler.transpile(
trainer_id=fleet.worker_index(),
pservers=fleet.server_endpoints(to_string=True),
trainers=fleet.worker_num(),
sync_mode=config.sync_mode)
if self.is_worker():
self.main_program = self._transpiler.get_trainer_program()
self.startup_program = default_startup_program()
else:
self._transpiler.transpile(
trainer_id=fleet.worker_index(),
pservers=fleet.server_endpoints(to_string=True),
trainers=fleet.worker_num(),
sync_mode=config.sync_mode,
current_endpoint=self.server_endpoints()[self.server_index()])
self.main_program, self.startup_program = \
self._transpiler.get_pserver_programs(self.server_endpoints()[self.server_index()])
......
......@@ -12,16 +12,16 @@
# See the License for the specific language governing permissions and
import sys
from .optimizer_factory import *
from optimizer_factory import *
from google.protobuf import text_format
import paddle.fluid as fluid
from paddle.fluid.framework import Program
from ...base.fleet_base import Fleet
from ...base.fleet_base import Mode
from ...base.role_maker import MPISymetricRoleMaker
from ...base.fleet_base import DistributedOptimizer
from paddle.fluid.incubate.fleet.base.fleet_base import Fleet
from paddle.fluid.incubate.fleet.base.fleet_base import Mode
from paddle.fluid.incubate.fleet.base.fleet_base import DistributedOptimizer
from paddle.fluid.incubate.fleet.base.role_maker import MPISymetricRoleMaker
class PSLib(Fleet):
......
......@@ -18,7 +18,7 @@ import time
import paddle.fluid as fluid
import paddle.fluid.incubate.fleet.base.role_maker as role_maker
from paddle.fluid.incubate.fleet.parameter_server.distributed_transpiler import fleet
from paddle.fluid.incubate.fleet.parameter_server.distribute_transpiler import fleet
from paddle.fluid.transpiler.distribute_transpiler import DistributeTranspilerConfig
from paddle.fluid.log_helper import get_logger
......
......@@ -907,7 +907,8 @@ def save_inference_model(dirname,
main_program=None,
model_filename=None,
params_filename=None,
export_for_deployment=True):
export_for_deployment=True,
program_only=False):
"""
Prune the given `main_program` to build a new program especially for inference,
and then save it and all related parameters to given `dirname` by the `executor`.
......@@ -938,6 +939,7 @@ def save_inference_model(dirname,
more information will be stored for flexible
optimization and re-training. Currently, only
True is supported.
program_only(bool): If True, It will save inference program only, and do not save params of Program.
Returns:
target_var_name_list(list): The fetch variables' name list
......@@ -1071,6 +1073,12 @@ def save_inference_model(dirname,
with open(model_basename + ".main_program", "wb") as f:
f.write(main_program.desc.serialize_to_string())
if program_only:
warnings.warn(
"save_inference_model specified the param `program_only` to True, It will not save params of Program."
)
return target_var_name_list
main_program._copy_dist_param_info_from(origin_program)
if params_filename is not None:
......
......@@ -17,6 +17,7 @@ if(NOT WITH_DISTRIBUTE)
LIST(REMOVE_ITEM TEST_OPS test_dist_text_classification)
LIST(REMOVE_ITEM TEST_OPS test_nce_remote_table_op)
LIST(REMOVE_ITEM TEST_OPS test_hsigmoid_remote_table_op)
LIST(REMOVE_ITEM TEST_OPS test_dist_fleet_ctr)
endif(NOT WITH_DISTRIBUTE)
LIST(REMOVE_ITEM TEST_OPS test_launch)
......
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import logging
import tarfile
import os
import paddle
import paddle.fluid.incubate.data_generator as data_generator
logging.basicConfig()
logger = logging.getLogger("paddle")
logger.setLevel(logging.INFO)
DATA_URL = "http://paddle-ctr-data.bj.bcebos.com/avazu_ctr_data.tgz"
DATA_MD5 = "c11df99fbd14e53cd4bfa6567344b26e"
"""
avazu_ctr_data/train.txt
avazu_ctr_data/infer.txt
avazu_ctr_data/test.txt
avazu_ctr_data/data.meta.txt
"""
def download_file():
file_name = "avazu_ctr_data"
path = paddle.dataset.common.download(DATA_URL, file_name, DATA_MD5)
dir_name = os.path.dirname(path)
text_file_dir_name = os.path.join(dir_name, file_name)
if not os.path.exists(text_file_dir_name):
tar = tarfile.open(path, "r:gz")
tar.extractall(dir_name)
return text_file_dir_name
def load_dnn_input_record(sent):
return list(map(int, sent.split()))
def load_lr_input_record(sent):
res = []
for _ in [x.split(':') for x in sent.split()]:
res.append(int(_[0]))
return res
class DatasetCtrReader(data_generator.MultiSlotDataGenerator):
def generate_sample(self, line):
def iter():
fs = line.strip().split('\t')
dnn_input = load_dnn_input_record(fs[0])
lr_input = load_lr_input_record(fs[1])
click = [int(fs[2])]
yield ("dnn_data", dnn_input), \
("lr_data", lr_input), \
("click", click)
return iter
def prepare_data():
"""
load data meta info from path, return (dnn_input_dim, lr_input_dim)
"""
file_dir_name = download_file()
meta_file_path = os.path.join(file_dir_name, 'data.meta.txt')
train_file_path = os.path.join(file_dir_name, 'train.txt')
with open(meta_file_path, "r") as f:
lines = f.readlines()
err_info = "wrong meta format"
assert len(lines) == 2, err_info
assert 'dnn_input_dim:' in lines[0] and 'lr_input_dim:' in lines[
1], err_info
res = map(int, [_.split(':')[1] for _ in lines])
res = list(res)
dnn_input_dim = res[0]
lr_input_dim = res[1]
logger.info('dnn input dim: %d' % dnn_input_dim)
logger.info('lr input dim: %d' % lr_input_dim)
return dnn_input_dim, lr_input_dim, train_file_path
if __name__ == "__main__":
pairwise_reader = DatasetCtrReader()
pairwise_reader.run_from_stdin()
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import shutil
import tempfile
import time
import paddle.fluid as fluid
import os
import ctr_dataset_reader
from test_dist_fleet_base import runtime_main, FleetDistRunnerBase
# Fix seed for test
fluid.default_startup_program().random_seed = 1
fluid.default_main_program().random_seed = 1
class TestDistCTR2x2(FleetDistRunnerBase):
def net(self, batch_size=4, lr=0.01):
dnn_input_dim, lr_input_dim, train_file_path = ctr_dataset_reader.prepare_data(
)
""" network definition """
dnn_data = fluid.layers.data(
name="dnn_data",
shape=[-1, 1],
dtype="int64",
lod_level=1,
append_batch_size=False)
lr_data = fluid.layers.data(
name="lr_data",
shape=[-1, 1],
dtype="int64",
lod_level=1,
append_batch_size=False)
label = fluid.layers.data(
name="click",
shape=[-1, 1],
dtype="int64",
lod_level=0,
append_batch_size=False)
datas = [dnn_data, lr_data, label]
# build dnn model
dnn_layer_dims = [128, 64, 32, 1]
dnn_embedding = fluid.layers.embedding(
is_distributed=False,
input=dnn_data,
size=[dnn_input_dim, dnn_layer_dims[0]],
param_attr=fluid.ParamAttr(
name="deep_embedding",
initializer=fluid.initializer.Constant(value=0.01)),
is_sparse=True)
dnn_pool = fluid.layers.sequence_pool(
input=dnn_embedding, pool_type="sum")
dnn_out = dnn_pool
for i, dim in enumerate(dnn_layer_dims[1:]):
fc = fluid.layers.fc(
input=dnn_out,
size=dim,
act="relu",
param_attr=fluid.ParamAttr(
initializer=fluid.initializer.Constant(value=0.01)),
name='dnn-fc-%d' % i)
dnn_out = fc
# build lr model
lr_embbding = fluid.layers.embedding(
is_distributed=False,
input=lr_data,
size=[lr_input_dim, 1],
param_attr=fluid.ParamAttr(
name="wide_embedding",
initializer=fluid.initializer.Constant(value=0.01)),
is_sparse=True)
lr_pool = fluid.layers.sequence_pool(input=lr_embbding, pool_type="sum")
merge_layer = fluid.layers.concat(input=[dnn_out, lr_pool], axis=1)
predict = fluid.layers.fc(input=merge_layer, size=2, act='softmax')
acc = fluid.layers.accuracy(input=predict, label=label)
auc_var, batch_auc_var, auc_states = fluid.layers.auc(input=predict,
label=label)
cost = fluid.layers.cross_entropy(input=predict, label=label)
avg_cost = fluid.layers.mean(x=cost)
self.feeds = datas
self.train_file_path = train_file_path
self.avg_cost = avg_cost
self.predict = predict
return avg_cost
def check_model_right(self, dirname):
model_filename = os.path.join(dirname, "__model__")
with open(model_filename, "rb") as f:
program_desc_str = f.read()
program = fluid.Program.parse_from_string(program_desc_str)
with open(os.path.join(dirname, "__model__.proto"), "w") as wn:
wn.write(str(program))
def do_training(self, fleet):
dnn_input_dim, lr_input_dim, train_file_path = ctr_dataset_reader.prepare_data(
)
exe = fluid.Executor(fluid.CPUPlace())
fleet.init_worker()
exe.run(fleet.startup_program)
thread_num = 2
filelist = []
for _ in range(thread_num):
filelist.append(train_file_path)
# config dataset
dataset = fluid.DatasetFactory().create_dataset()
dataset.set_batch_size(128)
dataset.set_use_var(self.feeds)
pipe_command = 'python ctr_dataset_reader.py'
dataset.set_pipe_command(pipe_command)
dataset.set_filelist(filelist)
dataset.set_thread(thread_num)
for epoch_id in range(2):
pass_start = time.time()
dataset.set_filelist(filelist)
exe.train_from_dataset(
program=fleet.main_program,
dataset=dataset,
fetch_list=[self.avg_cost],
fetch_info=["cost"],
print_period=100,
debug=False)
pass_time = time.time() - pass_start
model_dir = tempfile.mkdtemp()
fleet.save_inference_model(
exe, model_dir, [feed.name for feed in self.feeds], self.avg_cost)
self.check_model_right(model_dir)
shutil.rmtree(model_dir)
fleet.stop_worker()
if __name__ == "__main__":
runtime_main(TestDistCTR2x2)
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import argparse
import os
import pickle
import subprocess
import sys
import time
import traceback
import math
import collections
import socket
from contextlib import closing
import six
import unittest
import numpy as np
import paddle.fluid as fluid
import paddle.fluid.incubate.fleet.base.role_maker as role_maker
from paddle.fluid.incubate.fleet.parameter_server.distribute_transpiler import fleet
from paddle.fluid.transpiler.distribute_transpiler import DistributeTranspilerConfig
RUN_STEP = 5
LEARNING_RATE = 0.01
class FleetDistRunnerBase(object):
def run_pserver(self, args):
if args.role.upper() != "PSERVER":
raise ValueError("args role must be PSERVER")
role = role_maker.UserDefinedRoleMaker(
current_id=args.current_id,
role=role_maker.Role.SERVER,
worker_num=args.trainers,
server_endpoints=args.endpoints.split(","))
fleet.init(role)
strategy = DistributeTranspilerConfig()
strategy.sync_mode = args.sync_mode
avg_cost = self.net()
optimizer = fluid.optimizer.SGD(LEARNING_RATE)
optimizer = fleet.distributed_optimizer(optimizer, strategy)
optimizer.minimize(avg_cost)
fleet.init_server()
fleet.run_server()
def run_trainer(self, args):
if args.role.upper() != "TRAINER":
raise ValueError("args role must be TRAINER")
role = role_maker.UserDefinedRoleMaker(
current_id=args.current_id,
role=role_maker.Role.WORKER,
worker_num=args.trainers,
server_endpoints=args.endpoints.split(","))
fleet.init(role)
strategy = DistributeTranspilerConfig()
strategy.sync_mode = args.sync_mode
avg_cost = self.net()
optimizer = fluid.optimizer.SGD(LEARNING_RATE)
optimizer = fleet.distributed_optimizer(optimizer, strategy)
optimizer.minimize(avg_cost)
self.do_training(fleet)
out = self.do_training(fleet)
def net(self, batch_size=4, lr=0.01):
raise NotImplementedError(
"get_model should be implemented by child classes.")
def do_training(self, fleet):
raise NotImplementedError(
"do_training should be implemented by child classes.")
class TestFleetBase(unittest.TestCase):
def _setup_config(self):
raise NotImplementedError("tests should have _setup_config implemented")
def setUp(self):
self._sync_mode = True
self._trainers = 2
self._pservers = 2
self._port_set = set()
self._ps_endpoints = "127.0.0.1:%s,127.0.0.1:%s" % (
self._find_free_port(), self._find_free_port())
self._python_interp = sys.executable
self._setup_config()
def _find_free_port(self):
def __free_port():
with closing(socket.socket(socket.AF_INET,
socket.SOCK_STREAM)) as s:
s.bind(('', 0))
return s.getsockname()[1]
while True:
port = __free_port()
if port not in self._port_set:
self._port_set.add(port)
return port
def _start_pserver(self, cmd, required_envs):
ps0_cmd, ps1_cmd = cmd.format(0), cmd.format(1)
ps0_pipe = open("/tmp/ps0_err.log", "wb+")
ps1_pipe = open("/tmp/ps1_err.log", "wb+")
ps0_proc = subprocess.Popen(
ps0_cmd.strip().split(" "),
stdout=subprocess.PIPE,
stderr=ps0_pipe,
env=required_envs)
ps1_proc = subprocess.Popen(
ps1_cmd.strip().split(" "),
stdout=subprocess.PIPE,
stderr=ps1_pipe,
env=required_envs)
return ps0_proc, ps1_proc, ps0_pipe, ps1_pipe
def _start_trainer(self, cmd, required_envs):
tr0_cmd, tr1_cmd = cmd.format(0), cmd.format(1)
tr0_pipe = open("/tmp/tr0_err.log", "wb+")
tr1_pipe = open("/tmp/tr1_err.log", "wb+")
tr0_proc = subprocess.Popen(
tr0_cmd.strip().split(" "),
stdout=subprocess.PIPE,
stderr=tr0_pipe,
env=required_envs)
tr1_proc = subprocess.Popen(
tr1_cmd.strip().split(" "),
stdout=subprocess.PIPE,
stderr=tr1_pipe,
env=required_envs)
return tr0_proc, tr1_proc, tr0_pipe, tr1_pipe
def _run_cluster(self, model, envs):
env = {'CPU_NUM': '1'}
env.update(envs)
tr_cmd = "{0} {1} --role trainer --endpoints {2} --current_id {{}} --trainers {3}".format(
self._python_interp, model, self._ps_endpoints, self._trainers)
ps_cmd = "{0} {1} --role pserver --endpoints {2} --current_id {{}} --trainers {3}".format(
self._python_interp, model, self._ps_endpoints, self._trainers)
if self._sync_mode:
tr_cmd += " --sync_mode"
ps_cmd += " --sync_mode"
# Run dist train to compare with local results
ps0, ps1, ps0_pipe, ps1_pipe = self._start_pserver(ps_cmd, env)
tr0, tr1, tr0_pipe, tr1_pipe = self._start_trainer(tr_cmd, env)
# Wait until trainer process terminate
while True:
stat0 = tr0.poll()
time.sleep(0.1)
if stat0 is not None:
break
while True:
stat1 = tr1.poll()
time.sleep(0.1)
if stat1 is not None:
break
tr0_out, tr0_err = tr0.communicate()
tr1_out, tr1_err = tr1.communicate()
# close trainer file
tr0_pipe.close()
tr1_pipe.close()
ps0_pipe.close()
ps1_pipe.close()
ps0.terminate()
ps1.terminate()
with open("/tmp/tr0_out.log", "wb+") as wn:
wn.write(tr0_out)
with open("/tmp/tr1_out.log", "wb+") as wn:
wn.write(tr1_out)
# print server log
with open("/tmp/ps0_err.log", "r") as fn:
sys.stderr.write("ps0 stderr: %s\n" % fn.read())
with open("/tmp/ps1_err.log", "r") as fn:
sys.stderr.write("ps1 stderr: %s\n" % fn.read())
# print log
with open("/tmp/tr0_err.log", "r") as fn:
sys.stderr.write('trainer 0 stderr: %s\n' % fn.read())
with open("/tmp/tr1_err.log", "r") as fn:
sys.stderr.write('trainer 1 stderr: %s\n' % fn.read())
return 0, 0
def check_with_place(self,
model_file,
delta=1e-3,
check_error_log=False,
need_envs={}):
required_envs = {
"PATH": os.getenv("PATH", ""),
"PYTHONPATH": os.getenv("PYTHONPATH", ""),
"LD_LIBRARY_PATH": os.getenv("LD_LIBRARY_PATH", ""),
"FLAGS_rpc_deadline": "5000", # 5sec to fail fast
"http_proxy": ""
}
required_envs.update(need_envs)
if check_error_log:
required_envs["GLOG_v"] = "3"
required_envs["GLOG_logtostderr"] = "1"
tr0_losses, tr1_losses = self._run_cluster(model_file, required_envs)
def runtime_main(test_class):
parser = argparse.ArgumentParser(description='Run Fleet test.')
parser.add_argument(
'--role', type=str, required=True, choices=['pserver', 'trainer'])
parser.add_argument('--endpoints', type=str, required=False, default="")
parser.add_argument('--current_id', type=int, required=False, default=0)
parser.add_argument('--trainers', type=int, required=False, default=1)
parser.add_argument('--sync_mode', action='store_true')
args = parser.parse_args()
model = test_class()
if args.role == "pserver":
model.run_pserver(args)
else:
model.run_trainer(args)
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import os
import unittest
from test_dist_fleet_base import TestFleetBase
class TestDistMnist2x2(TestFleetBase):
def _setup_config(self):
self._sync_mode = False
def check_with_place(self,
model_file,
delta=1e-3,
check_error_log=False,
need_envs={}):
required_envs = {
"PATH": os.getenv("PATH", ""),
"PYTHONPATH": os.getenv("PYTHONPATH", ""),
"LD_LIBRARY_PATH": os.getenv("LD_LIBRARY_PATH", ""),
"FLAGS_rpc_deadline": "5000", # 5sec to fail fast
"http_proxy": ""
}
required_envs.update(need_envs)
if check_error_log:
required_envs["GLOG_v"] = "3"
required_envs["GLOG_logtostderr"] = "1"
tr0_losses, tr1_losses = self._run_cluster(model_file, required_envs)
def test_dist_train(self):
self.check_with_place(
"dist_fleet_ctr.py", delta=1e-5, check_error_log=False)
if __name__ == "__main__":
unittest.main()
......@@ -131,7 +131,7 @@ packages=['paddle',
'paddle.fluid.incubate.fleet',
'paddle.fluid.incubate.fleet.base',
'paddle.fluid.incubate.fleet.parameter_server',
'paddle.fluid.incubate.fleet.parameter_server.distributed_transpiler',
'paddle.fluid.incubate.fleet.parameter_server.distribute_transpiler',
'paddle.fluid.incubate.fleet.parameter_server.pslib',
'paddle.fluid.incubate.fleet.collective']
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册