Commit 23b83460 authored by 武毅, committed by GitHub

Fault-tolerant distributed training, just-works version, with etcd (#2849)

* using etcd for fault-tolerant training

* update

* workable version, fault tolerance not tested

* small fix

* update

* remove TODO
Parent c40707b6
......@@ -40,7 +40,7 @@ func main() {
idx = *index
} else {
e = pserver.NewEtcdClient(*etcdEndpoint, *numPservers, *etcdTimeout)
idx, err = e.Register()
idx, err = e.Register(*port)
candy.Must(err)
cp, err = pserver.NewCheckpointFromFile(*checkpointPath, idx, e)
......
......@@ -2,6 +2,7 @@ package master
import (
"os"
"time"
"github.com/PaddlePaddle/Paddle/go/connection"
"github.com/PaddlePaddle/recordio"
......@@ -36,9 +37,9 @@ func (c *Client) getRecords() {
for {
t, err := c.getTask()
if err != nil {
// TODO(helin): wait before move on with next
// getTask call.
log.Errorln(err)
log.Errorf("Get task failed, sleep 3 seconds and continue, %s", err)
time.Sleep(3 * time.Second)
continue
}
......
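The hunk above replaces the old TODO with a fixed three-second sleep before retrying getTask. As an illustration only, here is a minimal, self-contained Go sketch of that retry-and-sleep pattern factored into a helper; retryForever and its interval are illustrative names, not part of the master client API.

package main

import (
	"errors"
	"log"
	"time"
)

// retryForever keeps calling fn until it succeeds, sleeping a fixed interval
// between attempts, mirroring the fixed 3-second back-off added to getRecords.
func retryForever(interval time.Duration, fn func() error) {
	for {
		if err := fn(); err != nil {
			log.Printf("call failed, sleep %v and continue: %v", interval, err)
			time.Sleep(interval)
			continue
		}
		return
	}
}

func main() {
	attempts := 0
	retryForever(3*time.Second, func() error {
		attempts++
		if attempts < 3 {
			return errors.New("transient failure")
		}
		return nil
	})
	log.Printf("succeeded after %d attempts", attempts)
}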
......@@ -215,6 +215,7 @@ func readChunks(globPaths []string) ([]Chunk, error) {
}
count := index.NumChunks()
log.Infof("readChunks: file %s has %d chunks", path, count)
for i := 0; i < count; i++ {
chunk := Chunk{
Path: path,
......
import paddle.v2 as paddle
import paddle.v2.dataset.uci_housing as uci_housing
import paddle.v2.master as master
import os
import cPickle as pickle
etcd_ip = os.getenv("MASTER_IP", "127.0.0.1")
etcd_endpoint = "http://" + etcd_ip + ":2379"
def cloud_reader():
print "connecting to master, etcd endpoints: ", etcd_endpoint
master_client = master.client(etcd_endpoint, 5, 64)
master_client.set_dataset(
["/pfs/dlnel/public/dataset/uci_housing/uci_housing-*-of-*"])
while 1:
r, e = master_client.next_record()
if not r:
break
yield pickle.loads(r)
def main():
......@@ -22,13 +40,13 @@ def main():
# create optimizer of new remote updater to pserver
optimizer = paddle.optimizer.Momentum(momentum=0)
#TODO(zhihong) : replace optimizer with new OptimizerConfig
print "etcd endoint: ", etcd_endpoint
trainer = paddle.trainer.SGD(cost=cost,
parameters=parameters,
update_equation=optimizer,
is_local=False,
pserver_spec="localhost:3000")
pserver_spec=etcd_endpoint,
use_etcd=True)
# event_handler to print training and testing info
def event_handler(event):
......@@ -47,11 +65,11 @@ def main():
print "Test %d, %.2f" % (event.pass_id, result.cost)
# training
# NOTE: use uci_housing.train() as reader for non-paddlecloud training
trainer.train(
reader=paddle.batch(
paddle.reader.shuffle(
uci_housing.train(), buf_size=500),
batch_size=2),
cloud_reader, buf_size=500), batch_size=2),
feeding={'x': 0,
'y': 1},
event_handler=event_handler,
......
......@@ -12,6 +12,7 @@ import (
)
const (
// DefaultEtcdTimeout is the default etcd timeout
DefaultEtcdTimeout time.Duration = 5 * time.Second
)
......@@ -66,12 +67,12 @@ func (p *EtcdClient) List() []Server {
for {
for i := 0; i < psDesired; i++ {
ctx, cancel := context.WithTimeout(context.Background(), p.timeout)
cancel()
psKey := pserver.PsPath + strconv.Itoa(i)
log.Debugf("checking %s", psKey)
resp, err := p.client.Get(ctx, psKey)
cancel()
if err != nil {
log.Infof("Get psKey= %s error, %v", psKey, err)
log.Infof("Get psKey=%s error, %v", psKey, err)
time.Sleep(p.timeout)
continue
}
......
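The fix above moves cancel() to after the Get call, so the request is no longer issued with an already-cancelled context. Below is a minimal sketch of the corrected context.WithTimeout pattern, assuming a stand-in fetch function in place of the etcd client.

package main

import (
	"context"
	"fmt"
	"time"
)

// fetch stands in for an etcd Get; it returns early if the context is done.
func fetch(ctx context.Context) (string, error) {
	select {
	case <-time.After(100 * time.Millisecond):
		return "value", nil
	case <-ctx.Done():
		return "", ctx.Err()
	}
}

func main() {
	ctx, cancel := context.WithTimeout(context.Background(), time.Second)
	v, err := fetch(ctx) // issue the request first...
	cancel()             // ...then release the context, as the fixed code does
	fmt.Println(v, err)
}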
......@@ -49,7 +49,7 @@ func NewEtcdClient(endpoints string, numPservers int, timeout time.Duration) *Et
// Register registers the pserver on etcd
//
// Register returns the index of the current pserver.
func (e *EtcdClient) Register() (int, error) {
func (e *EtcdClient) Register(port int) (int, error) {
var err error
e.externalIP, err = networkhelper.GetExternalIP()
......@@ -116,7 +116,7 @@ func (e *EtcdClient) Register() (int, error) {
for {
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
var err error
pserverIdx, err = e.registerPserverEtcd(ctx)
pserverIdx, err = e.registerPserverEtcd(ctx, port)
cancel()
if err != nil {
log.Warn(err)
......@@ -140,7 +140,7 @@ func (e *EtcdClient) initDesiredPservers(ctx context.Context, numPservers int) (
}
// registerPserverEtcd registers pserver node on etcd using transaction.
func (e *EtcdClient) registerPserverEtcd(ctx context.Context) (int, error) {
func (e *EtcdClient) registerPserverEtcd(ctx context.Context, port int) (int, error) {
var idx int
_, err := concurrency.NewSTM(e.etcdClient, func(c concurrency.STM) error {
registered := false
......@@ -156,8 +156,9 @@ func (e *EtcdClient) registerPserverEtcd(ctx context.Context) (int, error) {
log.Fatal(err)
}
// find the first id and write info
c.Put(psKey, e.externalIP, clientv3.WithLease(resp.ID))
log.Debugf("set pserver node %s with value %s", psKey, e.externalIP)
pserverAddr := e.externalIP + ":" + strconv.Itoa(port)
c.Put(psKey, pserverAddr, clientv3.WithLease(resp.ID))
log.Debugf("set pserver node %s with value %s", psKey, pserverAddr)
ch, kaerr := e.etcdClient.KeepAlive(context.TODO(), resp.ID)
if kaerr != nil {
log.Errorf("keepalive etcd node error: %v", kaerr)
......
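registerPserverEtcd now stores host:port under a leased key and keeps the lease alive, so a crashed pserver's registration expires on its own. The following is a minimal sketch of that lease-plus-KeepAlive pattern written directly against etcd's clientv3 API; the endpoint, key, address, and TTL are illustrative values, not Paddle's defaults.

package main

import (
	"context"
	"log"
	"strconv"
	"time"

	"github.com/coreos/etcd/clientv3"
)

func main() {
	cli, err := clientv3.New(clientv3.Config{
		Endpoints:   []string{"http://127.0.0.1:2379"},
		DialTimeout: 5 * time.Second,
	})
	if err != nil {
		log.Fatal(err)
	}
	defer cli.Close()

	// Grant a short-lived lease; if this process dies, the key expires with it.
	lease, err := cli.Grant(context.TODO(), 5)
	if err != nil {
		log.Fatal(err)
	}

	// Publish the server's host:port under the lease, as registerPserverEtcd does.
	addr := "192.168.1.10:" + strconv.Itoa(3000)
	if _, err := cli.Put(context.TODO(), "/ps/0", addr, clientv3.WithLease(lease.ID)); err != nil {
		log.Fatal(err)
	}

	// Refresh the lease for as long as the process stays alive.
	ch, err := cli.KeepAlive(context.TODO(), lease.ID)
	if err != nil {
		log.Fatal(err)
	}
	for range ch {
		// each response is one successful lease refresh; a real pserver keeps serving here
	}
}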
......@@ -843,7 +843,8 @@ public:
bool useSparseUpdater);
static ParameterUpdater* createNewRemoteUpdater(
OptimizationConfig* config,
const std::string pserverSpec) throw(UnsupportError);
const std::string pserverSpec,
const bool useEtcd) throw(UnsupportError);
~ParameterUpdater();
/**
......
......@@ -33,11 +33,12 @@ ParameterUpdater *ParameterUpdater::createLocalUpdater(
ParameterUpdater *ParameterUpdater::createNewRemoteUpdater(
OptimizationConfig *config,
const std::string pserverSpec) throw(UnsupportError) {
const std::string pserverSpec,
const bool useEtcd) throw(UnsupportError) {
#ifndef PADDLE_WITHOUT_GOLANG
auto updater = new ParameterUpdater();
updater->m->updater.reset(new paddle::NewRemoteParameterUpdater(
config->m->getConfig(), pserverSpec));
config->m->getConfig(), pserverSpec, useEtcd));
return updater;
#else
throw UnsupportError();
......
......@@ -155,7 +155,8 @@ RUN apt-get update &&\
paddle version
${DOCKERFILE_CUDNN_DSO}
${DOCKERFILE_GPU_ENV}
ADD go/cmd/pserver/pserver /usr/bin/
ADD go/cmd/master/master /usr/bin/
# default command shows the paddle version and exits
CMD ["paddle", "version"]
EOF
......@@ -28,6 +28,17 @@ NewRemoteParameterUpdater::NewRemoteParameterUpdater(
newGradients_(nullptr),
pserverSpec_(pserverSpec) {}
NewRemoteParameterUpdater::NewRemoteParameterUpdater(
const OptimizationConfig &config,
const std::string pserverSpec,
const bool useEtcd)
: trainerConfig_(config),
parameterClient_(-1),
newParameters_(nullptr),
newGradients_(nullptr),
pserverSpec_(pserverSpec),
useEtcd_(useEtcd) {}
void NewRemoteParameterUpdater::init(
const std::vector<ParameterPtr> &parameters) {
ParameterUpdater::init(parameters);
......@@ -38,8 +49,13 @@ void NewRemoteParameterUpdater::init(
}
// create parameter server client.
parameterClient_ = paddle_new_pserver_client((char *)pserverSpec_.c_str(),
FLAGS_trainer_id == 0);
if (useEtcd_) {
parameterClient_ = paddle_new_etcd_pserver_client(
(char *)pserverSpec_.c_str(), FLAGS_trainer_id == 0);
} else {
parameterClient_ = paddle_new_pserver_client((char *)pserverSpec_.c_str(),
FLAGS_trainer_id == 0);
}
// init new parameter and gradient.
newParameters_ = initNewParameter(PARAMETER_VALUE);
......
......@@ -32,6 +32,9 @@ class NewRemoteParameterUpdater : public ParameterUpdater {
public:
NewRemoteParameterUpdater(const OptimizationConfig& config,
const std::string pserverSpec);
NewRemoteParameterUpdater(const OptimizationConfig& config,
const std::string pserverSpec,
const bool useEtcd);
~NewRemoteParameterUpdater() {
releaseNewParameter(newParameters_);
releaseNewParameter(newGradients_);
......@@ -111,6 +114,8 @@ protected:
paddle_parameter** newGradients_;
/// the specification of parameter server "host1:port,host2:port"
std::string pserverSpec_;
/// true if pserverSpec_ is an etcd endpoint; otherwise pserverSpec_ is a pserver address
bool useEtcd_;
};
} // namespace paddle
......@@ -22,6 +22,8 @@ import importlib
import paddle.v2.dataset
import cPickle
import glob
import cPickle as pickle
import random
__all__ = [
'DATA_HOME', 'download', 'md5file', 'split', 'cluster_files_reader',
......@@ -170,8 +172,6 @@ def convert(output_path,
name_prefix,
max_lines_to_shuffle=1000):
import recordio
import cPickle as pickle
import random
"""
Convert data from reader to recordio format files.
......@@ -201,7 +201,7 @@ def convert(output_path,
def write_data(w, lines):
random.shuffle(lines)
for i, d in enumerate(lines):
d = pickle.dumps(d, pickle.HIGHEST_PROTOCOL)
d = cPickle.dumps(d)
w[i % num_shards].write(d)
w = open_writers()
......
......@@ -10,8 +10,9 @@ class client(object):
client is a client to the master server.
"""
def __init__(self, addr, buf_size):
self.c = lib.paddle_new_master_client(addr, buf_size)
def __init__(self, etcd_endpoints, timeout, buf_size):
self.c = lib.paddle_new_etcd_master_client(etcd_endpoints, timeout,
buf_size)
def close(self):
lib.paddle_release_master_client(self.c)
......
......@@ -46,12 +46,12 @@ class Optimizer(object):
return swig_api.ParameterUpdater.createRemoteUpdater(
self.__opt_conf__, pass_num, use_sparse_updater)
def __create_new_remote_updater__(self, pserver_spec):
def __create_new_remote_updater__(self, pserver_spec, use_etcd):
return swig_api.ParameterUpdater.createNewRemoteUpdater(
self.__opt_conf__, pserver_spec)
self.__opt_conf__, pserver_spec, use_etcd)
def create_updater(self, is_local, num_passes, use_sparse_updater,
pserver_spec):
pserver_spec, use_etcd):
"""
create proper parameter_updater by configuration.
:param is_local: create local or remote parameter updater
......@@ -77,7 +77,7 @@ class Optimizer(object):
num_passes, use_sparse_updater)
else:
parameter_updater = self.__create_new_remote_updater__(
pserver_spec)
pserver_spec, use_etcd)
return parameter_updater
......
......@@ -45,7 +45,8 @@ class SGD(object):
update_equation,
extra_layers=None,
is_local=True,
pserver_spec=None):
pserver_spec=None,
use_etcd=True):
if not isinstance(parameters, v2_parameters.Parameters):
raise TypeError('parameters should be parameters')
......@@ -61,6 +62,7 @@ class SGD(object):
self.__topology_in_proto__ = topology.proto()
self.__is_local__ = is_local
self.__pserver_spec__ = pserver_spec
self.__use_etcd__ = use_etcd
self.__use_sparse_updater__ = self.__topology__.use_sparse_updater()
# # In local mode, disable sparse_remote_update.
......@@ -127,7 +129,7 @@ class SGD(object):
self.__parameter_updater__ = self.__optimizer__.create_updater(
self.__is_local__, num_passes, self.__use_sparse_updater__,
self.__pserver_spec__)
self.__pserver_spec__, self.__use_etcd__)
self.__parameter_updater__.init(self.__gradient_machine__)
self.__gradient_machine__.start()
......