提交 91689b6b 编写于 作者: H helinwang 提交者: GitHub

Merge pull request #2973 from helinwang/save_model_1

Implement distributed training save model, improve master.NewClient i…
......@@ -75,10 +75,11 @@ snapshot to a model will be a TODO for future.
### Trainer Election
One trainer will be elected as the one to save the model. When using
etcd, trainer ID is a randomly generated UUID, we will utilize etcd to
elect one trainer. When not using etcd, unique trainer IDs will be
given by the administrator, the trainer whose ID is "0" is elected to
save the model.
etcd, trainer ID is a randomly generated UUID, the trainer will
contact the master server requesting to save the model, and find out
if itself is elected. When the master server is not used, unique
trainer IDs will be given by the administrator, the trainer whose ID
is "0" is elected to save the model.
### Model Save Path
......
......@@ -22,6 +22,9 @@ package main
#define PADDLE_MASTER_OK 0
#define PADDLE_MASTER_ERROR -1
#define PADDLE_SAVE_MODEL_OK 1
#define PADDLE_SAVE_MODEL_SKIP 0
typedef int paddle_master_client;
*/
import "C"
......@@ -33,7 +36,6 @@ import (
"unsafe"
"github.com/PaddlePaddle/Paddle/go/master"
"github.com/coreos/etcd/clientv3"
log "github.com/sirupsen/logrus"
)
......@@ -65,32 +67,32 @@ func remove(client C.paddle_master_client) *master.Client {
}
//export paddle_new_etcd_master_client
//
// bufSize is the record buffer size.
func paddle_new_etcd_master_client(etcdEndpoints *C.char, timeout int, bufSize int) C.paddle_master_client {
p := C.GoString(etcdEndpoints)
cli, err := clientv3.New(clientv3.Config{
Endpoints: strings.Split(p, ","),
DialTimeout: time.Second * time.Duration(timeout),
})
endpoints := strings.Split(p, ",")
c, err := master.NewClient(
master.WithEtcd(endpoints, time.Duration(timeout)*time.Second),
master.WithBuffer(bufSize),
)
if err != nil {
panic(err)
}
ch := make(chan string, 1)
a, err := master.GetKey(cli, master.DefaultAddrPath, timeout)
if err != nil {
panic(err)
}
ch <- a
go master.WatchKey(cli, master.DefaultAddrPath, ch)
c := master.NewClient(ch, bufSize)
return add(c)
}
//export paddle_new_master_client
//
// bufSize is the record buffer size.
func paddle_new_master_client(addr *C.char, bufSize int) C.paddle_master_client {
a := C.GoString(addr)
ch := make(chan string, 1)
ch <- a
c := master.NewClient(ch, bufSize)
c, err := master.NewClient(master.WithAddr(a), master.WithBuffer(bufSize))
if err != nil {
panic(err)
}
return add(c)
}
......@@ -117,9 +119,10 @@ func paddle_set_dataset(client C.paddle_master_client, path **C.char, size C.int
return C.PADDLE_MASTER_OK
}
// return value:
// 0:ok
// -1:error
// paddle_next_record gets the nexts training record.
//
// returns number of bytes of the records if success, -1 if failed.
//
//export paddle_next_record
func paddle_next_record(client C.paddle_master_client, record **C.uchar) C.int {
c := get(client)
......@@ -143,6 +146,29 @@ func paddle_next_record(client C.paddle_master_client, record **C.uchar) C.int {
return C.int(size)
}
// paddle_request_save_model requests the master server to approve the
// caller to save the model.
//
// returns 1 if the save the model request is approved, 0 if the
// request is rejected because other trainer is saving the model, -1
// if error happened.
//
//export paddle_request_save_model
func paddle_request_save_model(client C.paddle_master_client, trainerID string, blockMS int) C.int {
c := get(client)
need, err := c.RequestSaveModel(trainerID, time.Duration(blockMS)*time.Millisecond)
if err != nil {
log.Errorln(err)
return C.PADDLE_MASTER_ERROR
}
if need {
return C.PADDLE_SAVE_MODEL_OK
}
return C.PADDLE_SAVE_MODEL_SKIP
}
//export mem_free
func mem_free(p unsafe.Pointer) {
// "free" may be a better name for this function, but doing so
......
......@@ -16,17 +16,20 @@ package master
import (
"os"
"sync"
"time"
"github.com/PaddlePaddle/Paddle/go/connection"
"github.com/PaddlePaddle/recordio"
"github.com/coreos/etcd/clientv3"
log "github.com/sirupsen/logrus"
)
// Client is the client of the master server.
type Client struct {
conn *connection.Conn
ch chan record
conn *connection.Conn
ch chan record
initChOnce sync.Once
}
type record struct {
......@@ -34,24 +37,83 @@ type record struct {
err error
}
// NewClient creates a new Client.
// WithBuffer sets the client to buffer the training record.
//
// bufSize is the record buffer size. NextRecord will read from this
// buffer.
func NewClient(addrCh <-chan string, bufSize int) *Client {
func WithBuffer(bufSize int) func(*Client) error {
return func(c *Client) error {
if bufSize <= 0 {
return nil
}
c.initChOnce.Do(func() {
c.ch = make(chan record, bufSize)
go c.getRecords()
})
return nil
}
}
// WithAddr sets the client to use fixed master address.
func WithAddr(addr string) func(c *Client) error {
return func(c *Client) error {
ch := make(chan string, 1)
ch <- addr
go c.monitorMaster(ch)
return nil
}
}
// WithEtcd sets the client to use etcd for master discovery.
func WithEtcd(endpoints []string, timeout time.Duration) func(*Client) error {
return func(c *Client) error {
cli, err := clientv3.New(clientv3.Config{
Endpoints: endpoints,
DialTimeout: timeout,
})
if err != nil {
return err
}
ch := make(chan string, 1)
a, err := GetKey(cli, DefaultAddrPath, timeout)
if err != nil {
return err
}
if a != "" {
// Master is registered, send to the master address
// channel.
ch <- a
}
go watchKey(cli, DefaultAddrPath, ch)
go c.monitorMaster(ch)
return nil
}
}
// NewClient creates a new Client.
func NewClient(opts ...func(*Client) error) (*Client, error) {
c := &Client{}
c.conn = connection.New()
c.ch = make(chan record, bufSize)
go c.monitorMaster(addrCh)
go c.getRecords()
return c
for _, opt := range opts {
err := opt(c)
if err != nil {
return nil, err
}
}
return c, nil
}
func (c *Client) getRecords() {
for {
t, err := c.getTask()
if err != nil {
// getTask call.
log.Errorf("Get task failed, sleep 3 seconds and continue, %s", err)
time.Sleep(3 * time.Second)
continue
......@@ -146,6 +208,20 @@ func (c *Client) taskFailed(meta TaskMeta) error {
// NextRecord will block until the next record is available. It is
// thread-safe.
func (c *Client) NextRecord() ([]byte, error) {
c.initChOnce.Do(func() {
// initialize with in case WithBuffer is not used.
c.ch = make(chan record, 0)
go c.getRecords()
})
r := <-c.ch
return r.r, r.err
}
// RequestSaveModel requests the master server to approve the caller
// to save the model.
func (c *Client) RequestSaveModel(trainerID string, blockDur time.Duration) (bool, error) {
var need bool
err := c.conn.Call("Service.RequestSaveModel", SaveModelRequest{TrainerID: trainerID, BlockDur: blockDur}, &need)
return need, err
}
......@@ -87,9 +87,11 @@ func TestNextRecord(t *testing.T) {
panic(err)
}
curAddr := make(chan string, 1)
curAddr <- fmt.Sprintf(":%d", p)
c := master.NewClient(curAddr, 10)
c, err := master.NewClient(master.WithAddr(fmt.Sprintf(":%d", p)), master.WithBuffer(10))
if err != nil {
panic(err)
}
err = c.SetDataset([]string{path})
if err != nil {
panic(err)
......
......@@ -158,8 +158,8 @@ func (e *EtcdClient) Load() ([]byte, error) {
}
// GetKey gets the value by the specify key.
func GetKey(c *clientv3.Client, key string, timeout int) (string, error) {
ctx, cancel := context.WithTimeout(context.Background(), time.Second*time.Duration(timeout))
func GetKey(c *clientv3.Client, key string, timeout time.Duration) (string, error) {
ctx, cancel := context.WithTimeout(context.Background(), timeout)
resp, err := c.Get(ctx, key)
cancel()
if err != nil {
......@@ -173,8 +173,8 @@ func GetKey(c *clientv3.Client, key string, timeout int) (string, error) {
return string(v), nil
}
// WatchKey watches the specify key and send to valChan if there is some event.
func WatchKey(c *clientv3.Client, key string, valChan chan<- string) {
// watchKey watches the specify key and send to valChan if there is some event.
func watchKey(c *clientv3.Client, key string, valChan chan<- string) {
rch := c.Watch(context.Background(), key)
for wresp := range rch {
for _, ev := range wresp.Events {
......
......@@ -78,9 +78,10 @@ type Service struct {
ready chan struct{}
store Store
mu sync.Mutex
initDone bool
taskQueues taskQueues
mu sync.Mutex
initDone bool
taskQueues taskQueues
savingTrainer string
}
func partition(chunks []Chunk, chunksPerTask int) []taskEntry {
......@@ -246,7 +247,7 @@ func readChunks(globPaths []string) ([]Chunk, error) {
//
// SetDataset can be call multiple times. But only the first call will
// be honored.
func (s *Service) SetDataset(globPaths []string, dummy *int) error {
func (s *Service) SetDataset(globPaths []string, _ *int) error {
if len(globPaths) == 0 {
return errors.New("no dataset specified")
}
......@@ -330,7 +331,7 @@ func (s *Service) logFields() log.Fields {
}
// GetTask gets a new task from the service.
func (s *Service) GetTask(dummy int, task *Task) error {
func (s *Service) GetTask(_ int, task *Task) error {
select {
case <-s.ready:
}
......@@ -380,7 +381,7 @@ func (s *Service) GetTask(dummy int, task *Task) error {
}
// TaskFinished tell the service that a task is finished.
func (s *Service) TaskFinished(taskID int, dummy *int) error {
func (s *Service) TaskFinished(taskID int, _ *int) error {
select {
case <-s.ready:
}
......@@ -415,7 +416,7 @@ func (s *Service) TaskFinished(taskID int, dummy *int) error {
}
// TaskFailed tells the service that a task is failed.
func (s *Service) TaskFailed(meta TaskMeta, dummy *int) error {
func (s *Service) TaskFailed(meta TaskMeta, _ *int) error {
select {
case <-s.ready:
}
......@@ -432,3 +433,42 @@ func (s *Service) TaskFailed(meta TaskMeta, dummy *int) error {
s.processFailedTask(t, meta.Epoch)
return nil
}
// SaveModelRequest is the request for saving model
type SaveModelRequest struct {
TrainerID string
BlockDur time.Duration
}
// RequestSaveModel requests the master server to approve the caller
// to save the model.
func (s *Service) RequestSaveModel(req SaveModelRequest, need *bool) error {
s.mu.Lock()
defer s.mu.Unlock()
if req.TrainerID == "" {
return errors.New("trainer id is empty")
}
if s.savingTrainer == "" {
*need = true
} else {
if req.TrainerID == s.savingTrainer {
// save trainer asked to save model again
*need = true
} else {
*need = false
}
}
if *need {
s.savingTrainer = req.TrainerID
time.AfterFunc(req.BlockDur, func() {
s.mu.Lock()
s.savingTrainer = ""
s.mu.Unlock()
})
}
return nil
}
......@@ -127,13 +127,19 @@ func paddle_pserver_client_release(client C.paddle_pserver_client) {
remove(client)
}
// paddle_begin_init_params tells trainer if it needs to init the
// parameters.
//
// returns 1 if the trainer needs to init the parameters. 0 if the
// trainer does not need to init the parameters.
//
//export paddle_begin_init_params
func paddle_begin_init_params(client C.paddle_pserver_client) C.int {
c := get(client)
if selected := c.BeginInitParams(); selected {
return 1
}
return C.PSERVER_OK
return 0
}
//export paddle_init_param
......@@ -256,17 +262,4 @@ func paddle_get_params(client C.paddle_pserver_client, dst **C.paddle_parameter,
return C.PSERVER_OK
}
//export paddle_save_model
func paddle_save_model(client C.paddle_pserver_client, path *C.char) C.int {
p := C.GoString(path)
c := get(client)
err := c.Save(p)
if err != nil {
log.Errorln(err)
return C.PSERVER_ERROR
}
return C.PSERVER_OK
}
func main() {} // Required but ignored
......@@ -111,9 +111,5 @@ retry:
getParams(c);
}
if (paddle_save_model(c, "/tmp/")) {
fail();
}
return 0;
}
......@@ -219,32 +219,6 @@ func (c *Client) GetParams(names []string) ([]pserver.Parameter, error) {
return ps, nil
}
// Save indicates parameters to save the parameter to the given path.
func (c *Client) Save(path string) error {
errCh := make(chan error, len(c.pservers))
for _, p := range c.pservers {
err := p.Call("Service.Save", path, nil)
errCh <- err
}
recv := 0
for err := range errCh {
if err != nil {
return err
}
recv++
if recv == len(c.pservers) {
break
}
}
// TODO(helin): there will be many files under path, need to
// merge them into a single file.
return nil
}
func strHash(s string) uint32 {
h := fnv.New32a()
_, _ = h.Write([]byte(s))
......
......@@ -164,7 +164,7 @@ func NewService(idx int, interval time.Duration, path string, client *EtcdClient
}
// InitParam initializes a parameter.
func (s *Service) InitParam(paramWithConfigs ParameterWithConfig, dummy *int) error {
func (s *Service) InitParam(paramWithConfigs ParameterWithConfig, _ *int) error {
select {
case <-s.initialized:
return errors.New(AlreadyInitialized)
......@@ -185,7 +185,7 @@ func (s *Service) InitParam(paramWithConfigs ParameterWithConfig, dummy *int) er
// FinishInitParams tells the parameter server that the parameter
// initialization has finished.
func (s *Service) FinishInitParams(dummy0 int, dummy1 *int) error {
func (s *Service) FinishInitParams(_ int, _ *int) error {
select {
case <-s.initialized:
return errors.New(AlreadyInitialized)
......@@ -198,7 +198,7 @@ func (s *Service) FinishInitParams(dummy0 int, dummy1 *int) error {
// SendGrad sends gradient to parameter servers for parameter
// optimization.
func (s *Service) SendGrad(g Gradient, dummy *int) error {
func (s *Service) SendGrad(g Gradient, _ *int) error {
select {
case <-s.initialized:
default:
......
......@@ -33,6 +33,7 @@ import networks
import minibatch
import plot
import image
import model
__all__ = [
'optimizer',
......@@ -54,6 +55,7 @@ __all__ = [
'evaluator',
'image',
'master',
'model',
]
......
......@@ -10,11 +10,31 @@ class client(object):
client is a client to the master server.
"""
def __init__(self, etcd_endpoints, timeout, buf_size):
self.c = lib.paddle_new_etcd_master_client(etcd_endpoints, timeout,
def __init__(self, etcd_endpoints, timeout_sec, buf_size=0):
self.c = lib.paddle_new_etcd_master_client(etcd_endpoints, timeout_sec,
buf_size)
def close(self):
def request_save_model(self, trainer_id, block_ms):
"""request to save model
Conventionally the 0-th trainer will save model. But in
distributed training, any trainer could be killed. This
function asks the master server if the trainer should proceed
with saving model.
:param trainer_id: trainer id.
:param block_ms: number of millisecond that other save model
will be blocked if this save model request succeeded.
Returns:
int: 1 if the save the model request is approved, 0 if
does the request is rejected because other trainer is
saving the model, -1 if error happened.
"""
return lib.paddle_request_save_model(self.c, trainer_id, block_ms)
def release(self):
lib.paddle_release_master_client(self.c)
self.c = None
......@@ -27,10 +47,13 @@ class client(object):
holder[idx] = c_ptr
lib.paddle_set_dataset(self.c, holder, len(paths))
# return format: (record, errno)
# errno = 0: ok
# < 0: error
def next_record(self):
"""gets next record for training
Returns:
string: the record.
int: error code, 0 if successful, < 0 otherwise.
"""
p = ctypes.c_char_p()
ret = ctypes.pointer(p)
size = lib.paddle_next_record(self.c, ret)
......
# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import errno
import uuid
import paddle.v2.master
__all__ = ["save_model", "load_model"]
trainer_id = str(uuid.uuid4())
def mkdir_p(path):
try:
os.makedirs(path)
except OSError as exc:
if exc.errno == errno.EEXIST and os.path.isdir(path):
pass
else:
raise
def save_model(parameters, path):
need_request = "KUBERNETES_SERVICE_HOST" in os.environ.keys()
if need_request:
# TODO(helin): figure out how MPI trains, since MPI only save
# model when trainer_id == "0", we can consolidate the logic
# here.
# TODO(helin): change this environment variable name from
# MASTER_IP to ETCD_IP
etcd_name = "MASTER_IP"
if etcd_name not in os.environ.keys():
raise Exception('not find ' + etcd_name +
' in environment variable.')
etcd_ip = os.environ.get(etcd_name)
client = master.client("http://" + etcd_ip + ":2379", 5, 0)
r = client.request_save_model(trainer_id, 5000)
if r == 0:
# do not need to save
return
elif r < 0:
# error
return
else:
# save model
path = os.path.join(path, trainer_id)
path = os.path.join(path, "model.tar")
mkdir_p(path)
with open(path, 'wb') as f:
parameters.to_tar(f)
def load_model(parameters, path):
with open(path, 'rb') as f:
parameters.from_tar(f)
......@@ -12,8 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Creator package contains some simple reader creator, which could be used in user
program.
Creator package contains some simple reader creator, which could
be used in user program.
"""
__all__ = ['np_array', 'text_file', "recordio"]
......@@ -59,7 +59,7 @@ def text_file(path):
def recordio_local(paths, buf_size=100):
"""
Creates a data reader from given RecordIO file paths separated by ",",
Creates a data reader from given RecordIO file paths separated by ",",
glob pattern is supported.
:path: path of recordio files.
:returns: data reader of recordio files.
......@@ -83,7 +83,7 @@ def recordio_local(paths, buf_size=100):
def recordio(paths, buf_size=100):
"""
Creates a data reader that outputs record one one by one
Creates a data reader that outputs record one one by one
from given local or cloud recordio path.
:path: path of recordio files.
:returns: data reader of recordio files.
......@@ -96,7 +96,7 @@ def recordio(paths, buf_size=100):
host_name = "MASTER_SERVICE_HOST"
if host_name not in os.environ.keys():
raise Exception('not find ' + host_name + ' in environ.')
raise Exception('not find ' + host_name + ' in environment variable.')
addr = os.environ(host)
......@@ -110,6 +110,6 @@ def recordio(paths, buf_size=100):
break
yield r
c.close()
c.release()
return reader
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册