提交 91689b6b 编写于 作者: H helinwang 提交者: GitHub

Merge pull request #2973 from helinwang/save_model_1

Implement distributed training save model, improve master.NewClient i…
...@@ -75,10 +75,11 @@ snapshot to a model will be a TODO for future. ...@@ -75,10 +75,11 @@ snapshot to a model will be a TODO for future.
### Trainer Election ### Trainer Election
One trainer will be elected as the one to save the model. When using One trainer will be elected as the one to save the model. When using
etcd, trainer ID is a randomly generated UUID, we will utilize etcd to etcd, trainer ID is a randomly generated UUID, the trainer will
elect one trainer. When not using etcd, unique trainer IDs will be contact the master server requesting to save the model, and find out
given by the administrator, the trainer whose ID is "0" is elected to if itself is elected. When the master server is not used, unique
save the model. trainer IDs will be given by the administrator, the trainer whose ID
is "0" is elected to save the model.
### Model Save Path ### Model Save Path
......
...@@ -22,6 +22,9 @@ package main ...@@ -22,6 +22,9 @@ package main
#define PADDLE_MASTER_OK 0 #define PADDLE_MASTER_OK 0
#define PADDLE_MASTER_ERROR -1 #define PADDLE_MASTER_ERROR -1
#define PADDLE_SAVE_MODEL_OK 1
#define PADDLE_SAVE_MODEL_SKIP 0
typedef int paddle_master_client; typedef int paddle_master_client;
*/ */
import "C" import "C"
...@@ -33,7 +36,6 @@ import ( ...@@ -33,7 +36,6 @@ import (
"unsafe" "unsafe"
"github.com/PaddlePaddle/Paddle/go/master" "github.com/PaddlePaddle/Paddle/go/master"
"github.com/coreos/etcd/clientv3"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
) )
...@@ -65,32 +67,32 @@ func remove(client C.paddle_master_client) *master.Client { ...@@ -65,32 +67,32 @@ func remove(client C.paddle_master_client) *master.Client {
} }
//export paddle_new_etcd_master_client //export paddle_new_etcd_master_client
//
// bufSize is the record buffer size.
func paddle_new_etcd_master_client(etcdEndpoints *C.char, timeout int, bufSize int) C.paddle_master_client { func paddle_new_etcd_master_client(etcdEndpoints *C.char, timeout int, bufSize int) C.paddle_master_client {
p := C.GoString(etcdEndpoints) p := C.GoString(etcdEndpoints)
cli, err := clientv3.New(clientv3.Config{ endpoints := strings.Split(p, ",")
Endpoints: strings.Split(p, ","), c, err := master.NewClient(
DialTimeout: time.Second * time.Duration(timeout), master.WithEtcd(endpoints, time.Duration(timeout)*time.Second),
}) master.WithBuffer(bufSize),
)
if err != nil { if err != nil {
panic(err) panic(err)
} }
ch := make(chan string, 1)
a, err := master.GetKey(cli, master.DefaultAddrPath, timeout)
if err != nil {
panic(err)
}
ch <- a
go master.WatchKey(cli, master.DefaultAddrPath, ch)
c := master.NewClient(ch, bufSize)
return add(c) return add(c)
} }
//export paddle_new_master_client //export paddle_new_master_client
//
// bufSize is the record buffer size.
func paddle_new_master_client(addr *C.char, bufSize int) C.paddle_master_client { func paddle_new_master_client(addr *C.char, bufSize int) C.paddle_master_client {
a := C.GoString(addr) a := C.GoString(addr)
ch := make(chan string, 1) c, err := master.NewClient(master.WithAddr(a), master.WithBuffer(bufSize))
ch <- a if err != nil {
c := master.NewClient(ch, bufSize) panic(err)
}
return add(c) return add(c)
} }
...@@ -117,9 +119,10 @@ func paddle_set_dataset(client C.paddle_master_client, path **C.char, size C.int ...@@ -117,9 +119,10 @@ func paddle_set_dataset(client C.paddle_master_client, path **C.char, size C.int
return C.PADDLE_MASTER_OK return C.PADDLE_MASTER_OK
} }
// return value: // paddle_next_record gets the nexts training record.
// 0:ok //
// -1:error // returns number of bytes of the records if success, -1 if failed.
//
//export paddle_next_record //export paddle_next_record
func paddle_next_record(client C.paddle_master_client, record **C.uchar) C.int { func paddle_next_record(client C.paddle_master_client, record **C.uchar) C.int {
c := get(client) c := get(client)
...@@ -143,6 +146,29 @@ func paddle_next_record(client C.paddle_master_client, record **C.uchar) C.int { ...@@ -143,6 +146,29 @@ func paddle_next_record(client C.paddle_master_client, record **C.uchar) C.int {
return C.int(size) return C.int(size)
} }
// paddle_request_save_model requests the master server to approve the
// caller to save the model.
//
// returns 1 if the save the model request is approved, 0 if the
// request is rejected because other trainer is saving the model, -1
// if error happened.
//
//export paddle_request_save_model
func paddle_request_save_model(client C.paddle_master_client, trainerID string, blockMS int) C.int {
c := get(client)
need, err := c.RequestSaveModel(trainerID, time.Duration(blockMS)*time.Millisecond)
if err != nil {
log.Errorln(err)
return C.PADDLE_MASTER_ERROR
}
if need {
return C.PADDLE_SAVE_MODEL_OK
}
return C.PADDLE_SAVE_MODEL_SKIP
}
//export mem_free //export mem_free
func mem_free(p unsafe.Pointer) { func mem_free(p unsafe.Pointer) {
// "free" may be a better name for this function, but doing so // "free" may be a better name for this function, but doing so
......
...@@ -16,17 +16,20 @@ package master ...@@ -16,17 +16,20 @@ package master
import ( import (
"os" "os"
"sync"
"time" "time"
"github.com/PaddlePaddle/Paddle/go/connection" "github.com/PaddlePaddle/Paddle/go/connection"
"github.com/PaddlePaddle/recordio" "github.com/PaddlePaddle/recordio"
"github.com/coreos/etcd/clientv3"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
) )
// Client is the client of the master server. // Client is the client of the master server.
type Client struct { type Client struct {
conn *connection.Conn conn *connection.Conn
ch chan record ch chan record
initChOnce sync.Once
} }
type record struct { type record struct {
...@@ -34,24 +37,83 @@ type record struct { ...@@ -34,24 +37,83 @@ type record struct {
err error err error
} }
// NewClient creates a new Client. // WithBuffer sets the client to buffer the training record.
// //
// bufSize is the record buffer size. NextRecord will read from this // bufSize is the record buffer size. NextRecord will read from this
// buffer. // buffer.
func NewClient(addrCh <-chan string, bufSize int) *Client { func WithBuffer(bufSize int) func(*Client) error {
return func(c *Client) error {
if bufSize <= 0 {
return nil
}
c.initChOnce.Do(func() {
c.ch = make(chan record, bufSize)
go c.getRecords()
})
return nil
}
}
// WithAddr sets the client to use fixed master address.
func WithAddr(addr string) func(c *Client) error {
return func(c *Client) error {
ch := make(chan string, 1)
ch <- addr
go c.monitorMaster(ch)
return nil
}
}
// WithEtcd sets the client to use etcd for master discovery.
func WithEtcd(endpoints []string, timeout time.Duration) func(*Client) error {
return func(c *Client) error {
cli, err := clientv3.New(clientv3.Config{
Endpoints: endpoints,
DialTimeout: timeout,
})
if err != nil {
return err
}
ch := make(chan string, 1)
a, err := GetKey(cli, DefaultAddrPath, timeout)
if err != nil {
return err
}
if a != "" {
// Master is registered, send to the master address
// channel.
ch <- a
}
go watchKey(cli, DefaultAddrPath, ch)
go c.monitorMaster(ch)
return nil
}
}
// NewClient creates a new Client.
func NewClient(opts ...func(*Client) error) (*Client, error) {
c := &Client{} c := &Client{}
c.conn = connection.New() c.conn = connection.New()
c.ch = make(chan record, bufSize)
go c.monitorMaster(addrCh) for _, opt := range opts {
go c.getRecords() err := opt(c)
return c if err != nil {
return nil, err
}
}
return c, nil
} }
func (c *Client) getRecords() { func (c *Client) getRecords() {
for { for {
t, err := c.getTask() t, err := c.getTask()
if err != nil { if err != nil {
// getTask call.
log.Errorf("Get task failed, sleep 3 seconds and continue, %s", err) log.Errorf("Get task failed, sleep 3 seconds and continue, %s", err)
time.Sleep(3 * time.Second) time.Sleep(3 * time.Second)
continue continue
...@@ -146,6 +208,20 @@ func (c *Client) taskFailed(meta TaskMeta) error { ...@@ -146,6 +208,20 @@ func (c *Client) taskFailed(meta TaskMeta) error {
// NextRecord will block until the next record is available. It is // NextRecord will block until the next record is available. It is
// thread-safe. // thread-safe.
func (c *Client) NextRecord() ([]byte, error) { func (c *Client) NextRecord() ([]byte, error) {
c.initChOnce.Do(func() {
// initialize with in case WithBuffer is not used.
c.ch = make(chan record, 0)
go c.getRecords()
})
r := <-c.ch r := <-c.ch
return r.r, r.err return r.r, r.err
} }
// RequestSaveModel requests the master server to approve the caller
// to save the model.
func (c *Client) RequestSaveModel(trainerID string, blockDur time.Duration) (bool, error) {
var need bool
err := c.conn.Call("Service.RequestSaveModel", SaveModelRequest{TrainerID: trainerID, BlockDur: blockDur}, &need)
return need, err
}
...@@ -87,9 +87,11 @@ func TestNextRecord(t *testing.T) { ...@@ -87,9 +87,11 @@ func TestNextRecord(t *testing.T) {
panic(err) panic(err)
} }
curAddr := make(chan string, 1) c, err := master.NewClient(master.WithAddr(fmt.Sprintf(":%d", p)), master.WithBuffer(10))
curAddr <- fmt.Sprintf(":%d", p) if err != nil {
c := master.NewClient(curAddr, 10) panic(err)
}
err = c.SetDataset([]string{path}) err = c.SetDataset([]string{path})
if err != nil { if err != nil {
panic(err) panic(err)
......
...@@ -158,8 +158,8 @@ func (e *EtcdClient) Load() ([]byte, error) { ...@@ -158,8 +158,8 @@ func (e *EtcdClient) Load() ([]byte, error) {
} }
// GetKey gets the value by the specify key. // GetKey gets the value by the specify key.
func GetKey(c *clientv3.Client, key string, timeout int) (string, error) { func GetKey(c *clientv3.Client, key string, timeout time.Duration) (string, error) {
ctx, cancel := context.WithTimeout(context.Background(), time.Second*time.Duration(timeout)) ctx, cancel := context.WithTimeout(context.Background(), timeout)
resp, err := c.Get(ctx, key) resp, err := c.Get(ctx, key)
cancel() cancel()
if err != nil { if err != nil {
...@@ -173,8 +173,8 @@ func GetKey(c *clientv3.Client, key string, timeout int) (string, error) { ...@@ -173,8 +173,8 @@ func GetKey(c *clientv3.Client, key string, timeout int) (string, error) {
return string(v), nil return string(v), nil
} }
// WatchKey watches the specify key and send to valChan if there is some event. // watchKey watches the specify key and send to valChan if there is some event.
func WatchKey(c *clientv3.Client, key string, valChan chan<- string) { func watchKey(c *clientv3.Client, key string, valChan chan<- string) {
rch := c.Watch(context.Background(), key) rch := c.Watch(context.Background(), key)
for wresp := range rch { for wresp := range rch {
for _, ev := range wresp.Events { for _, ev := range wresp.Events {
......
...@@ -78,9 +78,10 @@ type Service struct { ...@@ -78,9 +78,10 @@ type Service struct {
ready chan struct{} ready chan struct{}
store Store store Store
mu sync.Mutex mu sync.Mutex
initDone bool initDone bool
taskQueues taskQueues taskQueues taskQueues
savingTrainer string
} }
func partition(chunks []Chunk, chunksPerTask int) []taskEntry { func partition(chunks []Chunk, chunksPerTask int) []taskEntry {
...@@ -246,7 +247,7 @@ func readChunks(globPaths []string) ([]Chunk, error) { ...@@ -246,7 +247,7 @@ func readChunks(globPaths []string) ([]Chunk, error) {
// //
// SetDataset can be call multiple times. But only the first call will // SetDataset can be call multiple times. But only the first call will
// be honored. // be honored.
func (s *Service) SetDataset(globPaths []string, dummy *int) error { func (s *Service) SetDataset(globPaths []string, _ *int) error {
if len(globPaths) == 0 { if len(globPaths) == 0 {
return errors.New("no dataset specified") return errors.New("no dataset specified")
} }
...@@ -330,7 +331,7 @@ func (s *Service) logFields() log.Fields { ...@@ -330,7 +331,7 @@ func (s *Service) logFields() log.Fields {
} }
// GetTask gets a new task from the service. // GetTask gets a new task from the service.
func (s *Service) GetTask(dummy int, task *Task) error { func (s *Service) GetTask(_ int, task *Task) error {
select { select {
case <-s.ready: case <-s.ready:
} }
...@@ -380,7 +381,7 @@ func (s *Service) GetTask(dummy int, task *Task) error { ...@@ -380,7 +381,7 @@ func (s *Service) GetTask(dummy int, task *Task) error {
} }
// TaskFinished tell the service that a task is finished. // TaskFinished tell the service that a task is finished.
func (s *Service) TaskFinished(taskID int, dummy *int) error { func (s *Service) TaskFinished(taskID int, _ *int) error {
select { select {
case <-s.ready: case <-s.ready:
} }
...@@ -415,7 +416,7 @@ func (s *Service) TaskFinished(taskID int, dummy *int) error { ...@@ -415,7 +416,7 @@ func (s *Service) TaskFinished(taskID int, dummy *int) error {
} }
// TaskFailed tells the service that a task is failed. // TaskFailed tells the service that a task is failed.
func (s *Service) TaskFailed(meta TaskMeta, dummy *int) error { func (s *Service) TaskFailed(meta TaskMeta, _ *int) error {
select { select {
case <-s.ready: case <-s.ready:
} }
...@@ -432,3 +433,42 @@ func (s *Service) TaskFailed(meta TaskMeta, dummy *int) error { ...@@ -432,3 +433,42 @@ func (s *Service) TaskFailed(meta TaskMeta, dummy *int) error {
s.processFailedTask(t, meta.Epoch) s.processFailedTask(t, meta.Epoch)
return nil return nil
} }
// SaveModelRequest is the request for saving model
type SaveModelRequest struct {
TrainerID string
BlockDur time.Duration
}
// RequestSaveModel requests the master server to approve the caller
// to save the model.
func (s *Service) RequestSaveModel(req SaveModelRequest, need *bool) error {
s.mu.Lock()
defer s.mu.Unlock()
if req.TrainerID == "" {
return errors.New("trainer id is empty")
}
if s.savingTrainer == "" {
*need = true
} else {
if req.TrainerID == s.savingTrainer {
// save trainer asked to save model again
*need = true
} else {
*need = false
}
}
if *need {
s.savingTrainer = req.TrainerID
time.AfterFunc(req.BlockDur, func() {
s.mu.Lock()
s.savingTrainer = ""
s.mu.Unlock()
})
}
return nil
}
...@@ -127,13 +127,19 @@ func paddle_pserver_client_release(client C.paddle_pserver_client) { ...@@ -127,13 +127,19 @@ func paddle_pserver_client_release(client C.paddle_pserver_client) {
remove(client) remove(client)
} }
// paddle_begin_init_params tells trainer if it needs to init the
// parameters.
//
// returns 1 if the trainer needs to init the parameters. 0 if the
// trainer does not need to init the parameters.
//
//export paddle_begin_init_params //export paddle_begin_init_params
func paddle_begin_init_params(client C.paddle_pserver_client) C.int { func paddle_begin_init_params(client C.paddle_pserver_client) C.int {
c := get(client) c := get(client)
if selected := c.BeginInitParams(); selected { if selected := c.BeginInitParams(); selected {
return 1 return 1
} }
return C.PSERVER_OK return 0
} }
//export paddle_init_param //export paddle_init_param
...@@ -256,17 +262,4 @@ func paddle_get_params(client C.paddle_pserver_client, dst **C.paddle_parameter, ...@@ -256,17 +262,4 @@ func paddle_get_params(client C.paddle_pserver_client, dst **C.paddle_parameter,
return C.PSERVER_OK return C.PSERVER_OK
} }
//export paddle_save_model
func paddle_save_model(client C.paddle_pserver_client, path *C.char) C.int {
p := C.GoString(path)
c := get(client)
err := c.Save(p)
if err != nil {
log.Errorln(err)
return C.PSERVER_ERROR
}
return C.PSERVER_OK
}
func main() {} // Required but ignored func main() {} // Required but ignored
...@@ -111,9 +111,5 @@ retry: ...@@ -111,9 +111,5 @@ retry:
getParams(c); getParams(c);
} }
if (paddle_save_model(c, "/tmp/")) {
fail();
}
return 0; return 0;
} }
...@@ -219,32 +219,6 @@ func (c *Client) GetParams(names []string) ([]pserver.Parameter, error) { ...@@ -219,32 +219,6 @@ func (c *Client) GetParams(names []string) ([]pserver.Parameter, error) {
return ps, nil return ps, nil
} }
// Save indicates parameters to save the parameter to the given path.
func (c *Client) Save(path string) error {
errCh := make(chan error, len(c.pservers))
for _, p := range c.pservers {
err := p.Call("Service.Save", path, nil)
errCh <- err
}
recv := 0
for err := range errCh {
if err != nil {
return err
}
recv++
if recv == len(c.pservers) {
break
}
}
// TODO(helin): there will be many files under path, need to
// merge them into a single file.
return nil
}
func strHash(s string) uint32 { func strHash(s string) uint32 {
h := fnv.New32a() h := fnv.New32a()
_, _ = h.Write([]byte(s)) _, _ = h.Write([]byte(s))
......
...@@ -164,7 +164,7 @@ func NewService(idx int, interval time.Duration, path string, client *EtcdClient ...@@ -164,7 +164,7 @@ func NewService(idx int, interval time.Duration, path string, client *EtcdClient
} }
// InitParam initializes a parameter. // InitParam initializes a parameter.
func (s *Service) InitParam(paramWithConfigs ParameterWithConfig, dummy *int) error { func (s *Service) InitParam(paramWithConfigs ParameterWithConfig, _ *int) error {
select { select {
case <-s.initialized: case <-s.initialized:
return errors.New(AlreadyInitialized) return errors.New(AlreadyInitialized)
...@@ -185,7 +185,7 @@ func (s *Service) InitParam(paramWithConfigs ParameterWithConfig, dummy *int) er ...@@ -185,7 +185,7 @@ func (s *Service) InitParam(paramWithConfigs ParameterWithConfig, dummy *int) er
// FinishInitParams tells the parameter server that the parameter // FinishInitParams tells the parameter server that the parameter
// initialization has finished. // initialization has finished.
func (s *Service) FinishInitParams(dummy0 int, dummy1 *int) error { func (s *Service) FinishInitParams(_ int, _ *int) error {
select { select {
case <-s.initialized: case <-s.initialized:
return errors.New(AlreadyInitialized) return errors.New(AlreadyInitialized)
...@@ -198,7 +198,7 @@ func (s *Service) FinishInitParams(dummy0 int, dummy1 *int) error { ...@@ -198,7 +198,7 @@ func (s *Service) FinishInitParams(dummy0 int, dummy1 *int) error {
// SendGrad sends gradient to parameter servers for parameter // SendGrad sends gradient to parameter servers for parameter
// optimization. // optimization.
func (s *Service) SendGrad(g Gradient, dummy *int) error { func (s *Service) SendGrad(g Gradient, _ *int) error {
select { select {
case <-s.initialized: case <-s.initialized:
default: default:
......
...@@ -33,6 +33,7 @@ import networks ...@@ -33,6 +33,7 @@ import networks
import minibatch import minibatch
import plot import plot
import image import image
import model
__all__ = [ __all__ = [
'optimizer', 'optimizer',
...@@ -54,6 +55,7 @@ __all__ = [ ...@@ -54,6 +55,7 @@ __all__ = [
'evaluator', 'evaluator',
'image', 'image',
'master', 'master',
'model',
] ]
......
...@@ -10,11 +10,31 @@ class client(object): ...@@ -10,11 +10,31 @@ class client(object):
client is a client to the master server. client is a client to the master server.
""" """
def __init__(self, etcd_endpoints, timeout, buf_size): def __init__(self, etcd_endpoints, timeout_sec, buf_size=0):
self.c = lib.paddle_new_etcd_master_client(etcd_endpoints, timeout, self.c = lib.paddle_new_etcd_master_client(etcd_endpoints, timeout_sec,
buf_size) buf_size)
def close(self): def request_save_model(self, trainer_id, block_ms):
"""request to save model
Conventionally the 0-th trainer will save model. But in
distributed training, any trainer could be killed. This
function asks the master server if the trainer should proceed
with saving model.
:param trainer_id: trainer id.
:param block_ms: number of millisecond that other save model
will be blocked if this save model request succeeded.
Returns:
int: 1 if the save the model request is approved, 0 if
does the request is rejected because other trainer is
saving the model, -1 if error happened.
"""
return lib.paddle_request_save_model(self.c, trainer_id, block_ms)
def release(self):
lib.paddle_release_master_client(self.c) lib.paddle_release_master_client(self.c)
self.c = None self.c = None
...@@ -27,10 +47,13 @@ class client(object): ...@@ -27,10 +47,13 @@ class client(object):
holder[idx] = c_ptr holder[idx] = c_ptr
lib.paddle_set_dataset(self.c, holder, len(paths)) lib.paddle_set_dataset(self.c, holder, len(paths))
# return format: (record, errno)
# errno = 0: ok
# < 0: error
def next_record(self): def next_record(self):
"""gets next record for training
Returns:
string: the record.
int: error code, 0 if successful, < 0 otherwise.
"""
p = ctypes.c_char_p() p = ctypes.c_char_p()
ret = ctypes.pointer(p) ret = ctypes.pointer(p)
size = lib.paddle_next_record(self.c, ret) size = lib.paddle_next_record(self.c, ret)
......
# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import errno
import uuid
import paddle.v2.master
__all__ = ["save_model", "load_model"]
trainer_id = str(uuid.uuid4())
def mkdir_p(path):
try:
os.makedirs(path)
except OSError as exc:
if exc.errno == errno.EEXIST and os.path.isdir(path):
pass
else:
raise
def save_model(parameters, path):
need_request = "KUBERNETES_SERVICE_HOST" in os.environ.keys()
if need_request:
# TODO(helin): figure out how MPI trains, since MPI only save
# model when trainer_id == "0", we can consolidate the logic
# here.
# TODO(helin): change this environment variable name from
# MASTER_IP to ETCD_IP
etcd_name = "MASTER_IP"
if etcd_name not in os.environ.keys():
raise Exception('not find ' + etcd_name +
' in environment variable.')
etcd_ip = os.environ.get(etcd_name)
client = master.client("http://" + etcd_ip + ":2379", 5, 0)
r = client.request_save_model(trainer_id, 5000)
if r == 0:
# do not need to save
return
elif r < 0:
# error
return
else:
# save model
path = os.path.join(path, trainer_id)
path = os.path.join(path, "model.tar")
mkdir_p(path)
with open(path, 'wb') as f:
parameters.to_tar(f)
def load_model(parameters, path):
with open(path, 'rb') as f:
parameters.from_tar(f)
...@@ -12,8 +12,8 @@ ...@@ -12,8 +12,8 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
""" """
Creator package contains some simple reader creator, which could be used in user Creator package contains some simple reader creator, which could
program. be used in user program.
""" """
__all__ = ['np_array', 'text_file', "recordio"] __all__ = ['np_array', 'text_file', "recordio"]
...@@ -59,7 +59,7 @@ def text_file(path): ...@@ -59,7 +59,7 @@ def text_file(path):
def recordio_local(paths, buf_size=100): def recordio_local(paths, buf_size=100):
""" """
Creates a data reader from given RecordIO file paths separated by ",", Creates a data reader from given RecordIO file paths separated by ",",
glob pattern is supported. glob pattern is supported.
:path: path of recordio files. :path: path of recordio files.
:returns: data reader of recordio files. :returns: data reader of recordio files.
...@@ -83,7 +83,7 @@ def recordio_local(paths, buf_size=100): ...@@ -83,7 +83,7 @@ def recordio_local(paths, buf_size=100):
def recordio(paths, buf_size=100): def recordio(paths, buf_size=100):
""" """
Creates a data reader that outputs record one one by one Creates a data reader that outputs record one one by one
from given local or cloud recordio path. from given local or cloud recordio path.
:path: path of recordio files. :path: path of recordio files.
:returns: data reader of recordio files. :returns: data reader of recordio files.
...@@ -96,7 +96,7 @@ def recordio(paths, buf_size=100): ...@@ -96,7 +96,7 @@ def recordio(paths, buf_size=100):
host_name = "MASTER_SERVICE_HOST" host_name = "MASTER_SERVICE_HOST"
if host_name not in os.environ.keys(): if host_name not in os.environ.keys():
raise Exception('not find ' + host_name + ' in environ.') raise Exception('not find ' + host_name + ' in environment variable.')
addr = os.environ(host) addr = os.environ(host)
...@@ -110,6 +110,6 @@ def recordio(paths, buf_size=100): ...@@ -110,6 +110,6 @@ def recordio(paths, buf_size=100):
break break
yield r yield r
c.close() c.release()
return reader return reader
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册