提交 23b83460 编写于 作者: 武毅 提交者: GitHub

Fault tolerant distributed training, just work version, with etcd (#2849)

* using etcd as fault tolerant training

* update

* workable version, ft not tested

* small fix

* update

* remove TODO
上级 c40707b6
...@@ -40,7 +40,7 @@ func main() { ...@@ -40,7 +40,7 @@ func main() {
idx = *index idx = *index
} else { } else {
e = pserver.NewEtcdClient(*etcdEndpoint, *numPservers, *etcdTimeout) e = pserver.NewEtcdClient(*etcdEndpoint, *numPservers, *etcdTimeout)
idx, err = e.Register() idx, err = e.Register(*port)
candy.Must(err) candy.Must(err)
cp, err = pserver.NewCheckpointFromFile(*checkpointPath, idx, e) cp, err = pserver.NewCheckpointFromFile(*checkpointPath, idx, e)
......
...@@ -2,6 +2,7 @@ package master ...@@ -2,6 +2,7 @@ package master
import ( import (
"os" "os"
"time"
"github.com/PaddlePaddle/Paddle/go/connection" "github.com/PaddlePaddle/Paddle/go/connection"
"github.com/PaddlePaddle/recordio" "github.com/PaddlePaddle/recordio"
...@@ -36,9 +37,9 @@ func (c *Client) getRecords() { ...@@ -36,9 +37,9 @@ func (c *Client) getRecords() {
for { for {
t, err := c.getTask() t, err := c.getTask()
if err != nil { if err != nil {
// TODO(helin): wait before move on with next
// getTask call. // getTask call.
log.Errorln(err) log.Errorf("Get task failed, sleep 3 seconds and continue, %s", err)
time.Sleep(3 * time.Second)
continue continue
} }
......
...@@ -215,6 +215,7 @@ func readChunks(globPaths []string) ([]Chunk, error) { ...@@ -215,6 +215,7 @@ func readChunks(globPaths []string) ([]Chunk, error) {
} }
count := index.NumChunks() count := index.NumChunks()
log.Infof("readChunks: file %s has %d chunks", path, count)
for i := 0; i < count; i++ { for i := 0; i < count; i++ {
chunk := Chunk{ chunk := Chunk{
Path: path, Path: path,
......
import paddle.v2 as paddle import paddle.v2 as paddle
import paddle.v2.dataset.uci_housing as uci_housing import paddle.v2.dataset.uci_housing as uci_housing
import paddle.v2.master as master
import os
import cPickle as pickle
etcd_ip = os.getenv("MASTER_IP", "127.0.0.1")
etcd_endpoint = "http://" + etcd_ip + ":2379"
def cloud_reader():
print "connecting to master, etcd endpoints: ", etcd_endpoint
master_client = master.client(etcd_endpoint, 5, 64)
master_client.set_dataset(
["/pfs/dlnel/public/dataset/uci_housing/uci_housing-*-of-*"])
while 1:
r, e = master_client.next_record()
if not r:
break
yield pickle.loads(r)
def main(): def main():
...@@ -22,13 +40,13 @@ def main(): ...@@ -22,13 +40,13 @@ def main():
# create optimizer of new remote updater to pserver # create optimizer of new remote updater to pserver
optimizer = paddle.optimizer.Momentum(momentum=0) optimizer = paddle.optimizer.Momentum(momentum=0)
#TODO(zhihong) : replace optimizer with new OptimizerConfig print "etcd endoint: ", etcd_endpoint
trainer = paddle.trainer.SGD(cost=cost, trainer = paddle.trainer.SGD(cost=cost,
parameters=parameters, parameters=parameters,
update_equation=optimizer, update_equation=optimizer,
is_local=False, is_local=False,
pserver_spec="localhost:3000") pserver_spec=etcd_endpoint,
use_etcd=True)
# event_handler to print training and testing info # event_handler to print training and testing info
def event_handler(event): def event_handler(event):
...@@ -47,11 +65,11 @@ def main(): ...@@ -47,11 +65,11 @@ def main():
print "Test %d, %.2f" % (event.pass_id, result.cost) print "Test %d, %.2f" % (event.pass_id, result.cost)
# training # training
# NOTE: use uci_housing.train() as reader for non-paddlecloud training
trainer.train( trainer.train(
reader=paddle.batch( reader=paddle.batch(
paddle.reader.shuffle( paddle.reader.shuffle(
uci_housing.train(), buf_size=500), cloud_reader, buf_size=500), batch_size=2),
batch_size=2),
feeding={'x': 0, feeding={'x': 0,
'y': 1}, 'y': 1},
event_handler=event_handler, event_handler=event_handler,
......
...@@ -12,6 +12,7 @@ import ( ...@@ -12,6 +12,7 @@ import (
) )
const ( const (
// DefaultEtcdTimeout is the default etcd timeout
DefaultEtcdTimeout time.Duration = 5 * time.Second DefaultEtcdTimeout time.Duration = 5 * time.Second
) )
...@@ -66,12 +67,12 @@ func (p *EtcdClient) List() []Server { ...@@ -66,12 +67,12 @@ func (p *EtcdClient) List() []Server {
for { for {
for i := 0; i < psDesired; i++ { for i := 0; i < psDesired; i++ {
ctx, cancel := context.WithTimeout(context.Background(), p.timeout) ctx, cancel := context.WithTimeout(context.Background(), p.timeout)
cancel()
psKey := pserver.PsPath + strconv.Itoa(i) psKey := pserver.PsPath + strconv.Itoa(i)
log.Debugf("checking %s", psKey) log.Debugf("checking %s", psKey)
resp, err := p.client.Get(ctx, psKey) resp, err := p.client.Get(ctx, psKey)
cancel()
if err != nil { if err != nil {
log.Infof("Get psKey= %s error, %v", psKey, err) log.Infof("Get psKey=%s error, %v", psKey, err)
time.Sleep(p.timeout) time.Sleep(p.timeout)
continue continue
} }
......
...@@ -49,7 +49,7 @@ func NewEtcdClient(endpoints string, numPservers int, timeout time.Duration) *Et ...@@ -49,7 +49,7 @@ func NewEtcdClient(endpoints string, numPservers int, timeout time.Duration) *Et
// Register registers the pserver on etcd // Register registers the pserver on etcd
// //
// Register returns the index of the current pserver. // Register returns the index of the current pserver.
func (e *EtcdClient) Register() (int, error) { func (e *EtcdClient) Register(port int) (int, error) {
var err error var err error
e.externalIP, err = networkhelper.GetExternalIP() e.externalIP, err = networkhelper.GetExternalIP()
...@@ -116,7 +116,7 @@ func (e *EtcdClient) Register() (int, error) { ...@@ -116,7 +116,7 @@ func (e *EtcdClient) Register() (int, error) {
for { for {
ctx, cancel := context.WithTimeout(context.Background(), time.Second) ctx, cancel := context.WithTimeout(context.Background(), time.Second)
var err error var err error
pserverIdx, err = e.registerPserverEtcd(ctx) pserverIdx, err = e.registerPserverEtcd(ctx, port)
cancel() cancel()
if err != nil { if err != nil {
log.Warn(err) log.Warn(err)
...@@ -140,7 +140,7 @@ func (e *EtcdClient) initDesiredPservers(ctx context.Context, numPservers int) ( ...@@ -140,7 +140,7 @@ func (e *EtcdClient) initDesiredPservers(ctx context.Context, numPservers int) (
} }
// registerPserverEtcd registers pserver node on etcd using transaction. // registerPserverEtcd registers pserver node on etcd using transaction.
func (e *EtcdClient) registerPserverEtcd(ctx context.Context) (int, error) { func (e *EtcdClient) registerPserverEtcd(ctx context.Context, port int) (int, error) {
var idx int var idx int
_, err := concurrency.NewSTM(e.etcdClient, func(c concurrency.STM) error { _, err := concurrency.NewSTM(e.etcdClient, func(c concurrency.STM) error {
registered := false registered := false
...@@ -156,8 +156,9 @@ func (e *EtcdClient) registerPserverEtcd(ctx context.Context) (int, error) { ...@@ -156,8 +156,9 @@ func (e *EtcdClient) registerPserverEtcd(ctx context.Context) (int, error) {
log.Fatal(err) log.Fatal(err)
} }
// find the first id and write info // find the first id and write info
c.Put(psKey, e.externalIP, clientv3.WithLease(resp.ID)) pserverAddr := e.externalIP + ":" + strconv.Itoa(port)
log.Debugf("set pserver node %s with value %s", psKey, e.externalIP) c.Put(psKey, pserverAddr, clientv3.WithLease(resp.ID))
log.Debugf("set pserver node %s with value %s", psKey, pserverAddr)
ch, kaerr := e.etcdClient.KeepAlive(context.TODO(), resp.ID) ch, kaerr := e.etcdClient.KeepAlive(context.TODO(), resp.ID)
if kaerr != nil { if kaerr != nil {
log.Errorf("keepalive etcd node error: %v", kaerr) log.Errorf("keepalive etcd node error: %v", kaerr)
......
...@@ -843,7 +843,8 @@ public: ...@@ -843,7 +843,8 @@ public:
bool useSparseUpdater); bool useSparseUpdater);
static ParameterUpdater* createNewRemoteUpdater( static ParameterUpdater* createNewRemoteUpdater(
OptimizationConfig* config, OptimizationConfig* config,
const std::string pserverSpec) throw(UnsupportError); const std::string pserverSpec,
const bool useEtcd) throw(UnsupportError);
~ParameterUpdater(); ~ParameterUpdater();
/** /**
......
...@@ -33,11 +33,12 @@ ParameterUpdater *ParameterUpdater::createLocalUpdater( ...@@ -33,11 +33,12 @@ ParameterUpdater *ParameterUpdater::createLocalUpdater(
ParameterUpdater *ParameterUpdater::createNewRemoteUpdater( ParameterUpdater *ParameterUpdater::createNewRemoteUpdater(
OptimizationConfig *config, OptimizationConfig *config,
const std::string pserverSpec) throw(UnsupportError) { const std::string pserverSpec,
const bool useEtcd) throw(UnsupportError) {
#ifndef PADDLE_WITHOUT_GOLANG #ifndef PADDLE_WITHOUT_GOLANG
auto updater = new ParameterUpdater(); auto updater = new ParameterUpdater();
updater->m->updater.reset(new paddle::NewRemoteParameterUpdater( updater->m->updater.reset(new paddle::NewRemoteParameterUpdater(
config->m->getConfig(), pserverSpec)); config->m->getConfig(), pserverSpec, useEtcd));
return updater; return updater;
#else #else
throw UnsupportError(); throw UnsupportError();
......
...@@ -155,7 +155,8 @@ RUN apt-get update &&\ ...@@ -155,7 +155,8 @@ RUN apt-get update &&\
paddle version paddle version
${DOCKERFILE_CUDNN_DSO} ${DOCKERFILE_CUDNN_DSO}
${DOCKERFILE_GPU_ENV} ${DOCKERFILE_GPU_ENV}
ADD go/cmd/pserver/pserver /usr/bin/
ADD go/cmd/master/master /usr/bin/
# default command shows the paddle version and exit # default command shows the paddle version and exit
CMD ["paddle", "version"] CMD ["paddle", "version"]
EOF EOF
...@@ -28,6 +28,17 @@ NewRemoteParameterUpdater::NewRemoteParameterUpdater( ...@@ -28,6 +28,17 @@ NewRemoteParameterUpdater::NewRemoteParameterUpdater(
newGradients_(nullptr), newGradients_(nullptr),
pserverSpec_(pserverSpec) {} pserverSpec_(pserverSpec) {}
NewRemoteParameterUpdater::NewRemoteParameterUpdater(
const OptimizationConfig &config,
const std::string pserverSpec,
const bool useEtcd)
: trainerConfig_(config),
parameterClient_(-1),
newParameters_(nullptr),
newGradients_(nullptr),
pserverSpec_(pserverSpec),
useEtcd_(useEtcd) {}
void NewRemoteParameterUpdater::init( void NewRemoteParameterUpdater::init(
const std::vector<ParameterPtr> &parameters) { const std::vector<ParameterPtr> &parameters) {
ParameterUpdater::init(parameters); ParameterUpdater::init(parameters);
...@@ -38,8 +49,13 @@ void NewRemoteParameterUpdater::init( ...@@ -38,8 +49,13 @@ void NewRemoteParameterUpdater::init(
} }
// create parameter server client. // create parameter server client.
if (useEtcd_) {
parameterClient_ = paddle_new_etcd_pserver_client(
(char *)pserverSpec_.c_str(), FLAGS_trainer_id == 0);
} else {
parameterClient_ = paddle_new_pserver_client((char *)pserverSpec_.c_str(), parameterClient_ = paddle_new_pserver_client((char *)pserverSpec_.c_str(),
FLAGS_trainer_id == 0); FLAGS_trainer_id == 0);
}
// init new parameter and gradient. // init new parameter and gradient.
newParameters_ = initNewParameter(PARAMETER_VALUE); newParameters_ = initNewParameter(PARAMETER_VALUE);
......
...@@ -32,6 +32,9 @@ class NewRemoteParameterUpdater : public ParameterUpdater { ...@@ -32,6 +32,9 @@ class NewRemoteParameterUpdater : public ParameterUpdater {
public: public:
NewRemoteParameterUpdater(const OptimizationConfig& config, NewRemoteParameterUpdater(const OptimizationConfig& config,
const std::string pserverSpec); const std::string pserverSpec);
NewRemoteParameterUpdater(const OptimizationConfig& config,
const std::string pserverSpec,
const bool useEtcd);
~NewRemoteParameterUpdater() { ~NewRemoteParameterUpdater() {
releaseNewParameter(newParameters_); releaseNewParameter(newParameters_);
releaseNewParameter(newGradients_); releaseNewParameter(newGradients_);
...@@ -111,6 +114,8 @@ protected: ...@@ -111,6 +114,8 @@ protected:
paddle_parameter** newGradients_; paddle_parameter** newGradients_;
/// the specification of parameter server "host1:port,host1:port" /// the specification of parameter server "host1:port,host1:port"
std::string pserverSpec_; std::string pserverSpec_;
/// true if pserverSpec_ is etcd endpoint, else pserverSpec_ is pserver addr
bool useEtcd_;
}; };
} // namespace paddle } // namespace paddle
...@@ -22,6 +22,8 @@ import importlib ...@@ -22,6 +22,8 @@ import importlib
import paddle.v2.dataset import paddle.v2.dataset
import cPickle import cPickle
import glob import glob
import cPickle as pickle
import random
__all__ = [ __all__ = [
'DATA_HOME', 'download', 'md5file', 'split', 'cluster_files_reader', 'DATA_HOME', 'download', 'md5file', 'split', 'cluster_files_reader',
...@@ -170,8 +172,6 @@ def convert(output_path, ...@@ -170,8 +172,6 @@ def convert(output_path,
name_prefix, name_prefix,
max_lines_to_shuffle=1000): max_lines_to_shuffle=1000):
import recordio import recordio
import cPickle as pickle
import random
""" """
Convert data from reader to recordio format files. Convert data from reader to recordio format files.
...@@ -201,7 +201,7 @@ def convert(output_path, ...@@ -201,7 +201,7 @@ def convert(output_path,
def write_data(w, lines): def write_data(w, lines):
random.shuffle(lines) random.shuffle(lines)
for i, d in enumerate(lines): for i, d in enumerate(lines):
d = pickle.dumps(d, pickle.HIGHEST_PROTOCOL) d = cPickle.dumps(d)
w[i % num_shards].write(d) w[i % num_shards].write(d)
w = open_writers() w = open_writers()
......
...@@ -10,8 +10,9 @@ class client(object): ...@@ -10,8 +10,9 @@ class client(object):
client is a client to the master server. client is a client to the master server.
""" """
def __init__(self, addr, buf_size): def __init__(self, etcd_endpoints, timeout, buf_size):
self.c = lib.paddle_new_master_client(addr, buf_size) self.c = lib.paddle_new_etcd_master_client(etcd_endpoints, timeout,
buf_size)
def close(self): def close(self):
lib.paddle_release_master_client(self.c) lib.paddle_release_master_client(self.c)
......
...@@ -46,12 +46,12 @@ class Optimizer(object): ...@@ -46,12 +46,12 @@ class Optimizer(object):
return swig_api.ParameterUpdater.createRemoteUpdater( return swig_api.ParameterUpdater.createRemoteUpdater(
self.__opt_conf__, pass_num, use_sparse_updater) self.__opt_conf__, pass_num, use_sparse_updater)
def __create_new_remote_updater__(self, pserver_spec): def __create_new_remote_updater__(self, pserver_spec, use_etcd):
return swig_api.ParameterUpdater.createNewRemoteUpdater( return swig_api.ParameterUpdater.createNewRemoteUpdater(
self.__opt_conf__, pserver_spec) self.__opt_conf__, pserver_spec, use_etcd)
def create_updater(self, is_local, num_passes, use_sparse_updater, def create_updater(self, is_local, num_passes, use_sparse_updater,
pserver_spec): pserver_spec, use_etcd):
""" """
create proper parameter_updater by configuration. create proper parameter_updater by configuration.
:param is_local: create local or remote parameter updater :param is_local: create local or remote parameter updater
...@@ -77,7 +77,7 @@ class Optimizer(object): ...@@ -77,7 +77,7 @@ class Optimizer(object):
num_passes, use_sparse_updater) num_passes, use_sparse_updater)
else: else:
parameter_updater = self.__create_new_remote_updater__( parameter_updater = self.__create_new_remote_updater__(
pserver_spec) pserver_spec, use_etcd)
return parameter_updater return parameter_updater
......
...@@ -45,7 +45,8 @@ class SGD(object): ...@@ -45,7 +45,8 @@ class SGD(object):
update_equation, update_equation,
extra_layers=None, extra_layers=None,
is_local=True, is_local=True,
pserver_spec=None): pserver_spec=None,
use_etcd=True):
if not isinstance(parameters, v2_parameters.Parameters): if not isinstance(parameters, v2_parameters.Parameters):
raise TypeError('parameters should be parameters') raise TypeError('parameters should be parameters')
...@@ -61,6 +62,7 @@ class SGD(object): ...@@ -61,6 +62,7 @@ class SGD(object):
self.__topology_in_proto__ = topology.proto() self.__topology_in_proto__ = topology.proto()
self.__is_local__ = is_local self.__is_local__ = is_local
self.__pserver_spec__ = pserver_spec self.__pserver_spec__ = pserver_spec
self.__use_etcd__ = use_etcd
self.__use_sparse_updater__ = self.__topology__.use_sparse_updater() self.__use_sparse_updater__ = self.__topology__.use_sparse_updater()
# # In local mode, disable sparse_remote_update. # # In local mode, disable sparse_remote_update.
...@@ -127,7 +129,7 @@ class SGD(object): ...@@ -127,7 +129,7 @@ class SGD(object):
self.__parameter_updater__ = self.__optimizer__.create_updater( self.__parameter_updater__ = self.__optimizer__.create_updater(
self.__is_local__, num_passes, self.__use_sparse_updater__, self.__is_local__, num_passes, self.__use_sparse_updater__,
self.__pserver_spec__) self.__pserver_spec__, self.__use_etcd__)
self.__parameter_updater__.init(self.__gradient_machine__) self.__parameter_updater__.init(self.__gradient_machine__)
self.__gradient_machine__.start() self.__gradient_machine__.start()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册