Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PaddleDetection
提交
3ff0a9fb
P
PaddleDetection
项目概览
PaddlePaddle
/
PaddleDetection
大约 1 年 前同步成功
通知
694
Star
11112
Fork
2696
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
184
列表
看板
标记
里程碑
合并请求
40
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleDetection
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
184
Issue
184
列表
看板
标记
里程碑
合并请求
40
合并请求
40
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
体验新版 GitCode,发现更多精彩内容 >>
提交
3ff0a9fb
编写于
7月 19, 2017
作者:
H
Helin Wang
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Implement distributed training save model, improve master.NewClient interface
上级
bea40565
变更
14
隐藏空白更改
内联
并排
Showing
14 changed file
with
308 addition
and
105 deletion
+308
-105
doc/design/cluster_train/save_model.md
doc/design/cluster_train/save_model.md
+5
-4
go/master/c/client.go
go/master/c/client.go
+42
-19
go/master/client.go
go/master/client.go
+85
-9
go/master/client_test.go
go/master/client_test.go
+5
-3
go/master/etcd_client.go
go/master/etcd_client.go
+4
-4
go/master/service.go
go/master/service.go
+47
-7
go/pserver/client/c/cclient.go
go/pserver/client/c/cclient.go
+7
-14
go/pserver/client/c/test/test_cclient.c
go/pserver/client/c/test/test_cclient.c
+0
-4
go/pserver/client/client.go
go/pserver/client/client.go
+0
-26
go/pserver/service.go
go/pserver/service.go
+3
-3
python/paddle/v2/__init__.py
python/paddle/v2/__init__.py
+2
-0
python/paddle/v2/master/client.py
python/paddle/v2/master/client.py
+29
-6
python/paddle/v2/model.py
python/paddle/v2/model.py
+73
-0
python/paddle/v2/reader/creator.py
python/paddle/v2/reader/creator.py
+6
-6
未找到文件。
doc/design/cluster_train/save_model.md
浏览文件 @
3ff0a9fb
...
...
@@ -75,10 +75,11 @@ snapshot to a model will be a TODO for future.
### Trainer Election
One trainer will be elected as the one to save the model. When using
etcd, trainer ID is a randomly generated UUID, we will utilize etcd to
elect one trainer. When not using etcd, unique trainer IDs will be
given by the administrator, the trainer whose ID is "0" is elected to
save the model.
etcd, trainer ID is a randomly generated UUID, the trainer will
contact the master server requesting to save the model, and find out
if itself is elected. When the master server is not used, unique
trainer IDs will be given by the administrator, the trainer whose ID
is "0" is elected to save the model.
### Model Save Path
...
...
go/master/c/client.go
浏览文件 @
3ff0a9fb
...
...
@@ -33,7 +33,6 @@ import (
"unsafe"
"github.com/PaddlePaddle/Paddle/go/master"
"github.com/coreos/etcd/clientv3"
log
"github.com/sirupsen/logrus"
)
...
...
@@ -65,32 +64,32 @@ func remove(client C.paddle_master_client) *master.Client {
}
//export paddle_new_etcd_master_client
//
// bufSize is the record buffer size.
func
paddle_new_etcd_master_client
(
etcdEndpoints
*
C
.
char
,
timeout
int
,
bufSize
int
)
C
.
paddle_master_client
{
p
:=
C
.
GoString
(
etcdEndpoints
)
cli
,
err
:=
clientv3
.
New
(
clientv3
.
Config
{
Endpoints
:
strings
.
Split
(
p
,
","
),
DialTimeout
:
time
.
Second
*
time
.
Duration
(
timeout
),
})
endpoints
:=
strings
.
Split
(
p
,
","
)
c
,
err
:=
master
.
NewClient
(
master
.
WithEtcd
(
endpoints
,
time
.
Duration
(
timeout
)
*
time
.
Second
),
master
.
WithBuffer
(
bufSize
),
)
if
err
!=
nil
{
panic
(
err
)
}
ch
:=
make
(
chan
string
,
1
)
a
,
err
:=
master
.
GetKey
(
cli
,
master
.
DefaultAddrPath
,
timeout
)
if
err
!=
nil
{
panic
(
err
)
}
ch
<-
a
go
master
.
WatchKey
(
cli
,
master
.
DefaultAddrPath
,
ch
)
c
:=
master
.
NewClient
(
ch
,
bufSize
)
return
add
(
c
)
}
//export paddle_new_master_client
//
// bufSize is the record buffer size.
func
paddle_new_master_client
(
addr
*
C
.
char
,
bufSize
int
)
C
.
paddle_master_client
{
a
:=
C
.
GoString
(
addr
)
ch
:=
make
(
chan
string
,
1
)
ch
<-
a
c
:=
master
.
NewClient
(
ch
,
bufSize
)
c
,
err
:=
master
.
NewClient
(
master
.
WithAddr
(
a
),
master
.
WithBuffer
(
bufSize
))
if
err
!=
nil
{
panic
(
err
)
}
return
add
(
c
)
}
...
...
@@ -117,9 +116,10 @@ func paddle_set_dataset(client C.paddle_master_client, path **C.char, size C.int
return
C
.
PADDLE_MASTER_OK
}
// return value:
// 0:ok
// -1:error
// paddle_next_record gets the nexts training record.
//
// returns number of bytes of the records if success, -1 if failed.
//
//export paddle_next_record
func
paddle_next_record
(
client
C
.
paddle_master_client
,
record
**
C
.
uchar
)
C
.
int
{
c
:=
get
(
client
)
...
...
@@ -143,6 +143,29 @@ func paddle_next_record(client C.paddle_master_client, record **C.uchar) C.int {
return
C
.
int
(
size
)
}
// paddle_request_save_model requests the master server to approve the
// caller to save the model.
//
// returns 1 if the save the model request is approved, 0 if does the
// request is rejected because other trainer is saving the model, -1
// if error happened.
//
//export paddle_request_save_model
func
paddle_request_save_model
(
client
C
.
paddle_master_client
,
trainerID
string
,
blockMS
int
)
C
.
int
{
c
:=
get
(
client
)
need
,
err
:=
c
.
RequestSaveModel
(
trainerID
,
time
.
Duration
(
blockMS
)
*
time
.
Millisecond
)
if
err
!=
nil
{
log
.
Errorln
(
err
)
return
-
1
}
if
need
{
return
1
}
return
0
}
//export mem_free
func
mem_free
(
p
unsafe
.
Pointer
)
{
// "free" may be a better name for this function, but doing so
...
...
go/master/client.go
浏览文件 @
3ff0a9fb
...
...
@@ -16,17 +16,20 @@ package master
import
(
"os"
"sync"
"time"
"github.com/PaddlePaddle/Paddle/go/connection"
"github.com/PaddlePaddle/recordio"
"github.com/coreos/etcd/clientv3"
log
"github.com/sirupsen/logrus"
)
// Client is the client of the master server.
type
Client
struct
{
conn
*
connection
.
Conn
ch
chan
record
conn
*
connection
.
Conn
ch
chan
record
initChOnce
sync
.
Once
}
type
record
struct
{
...
...
@@ -34,24 +37,83 @@ type record struct {
err
error
}
//
NewClient creates a new Client
.
//
WithBuffer sets the client to buffer the training record
.
//
// bufSize is the record buffer size. NextRecord will read from this
// buffer.
func
NewClient
(
addrCh
<-
chan
string
,
bufSize
int
)
*
Client
{
func
WithBuffer
(
bufSize
int
)
func
(
*
Client
)
error
{
return
func
(
c
*
Client
)
error
{
if
bufSize
<=
0
{
return
nil
}
c
.
initChOnce
.
Do
(
func
()
{
c
.
ch
=
make
(
chan
record
,
bufSize
)
go
c
.
getRecords
()
})
return
nil
}
}
// WithAddr sets the client to use fixed master address.
func
WithAddr
(
addr
string
)
func
(
c
*
Client
)
error
{
return
func
(
c
*
Client
)
error
{
ch
:=
make
(
chan
string
,
1
)
ch
<-
addr
go
c
.
monitorMaster
(
ch
)
return
nil
}
}
// WithEtcd sets the client to use etcd for master discovery.
func
WithEtcd
(
endpoints
[]
string
,
timeout
time
.
Duration
)
func
(
*
Client
)
error
{
return
func
(
c
*
Client
)
error
{
cli
,
err
:=
clientv3
.
New
(
clientv3
.
Config
{
Endpoints
:
endpoints
,
DialTimeout
:
timeout
,
})
if
err
!=
nil
{
return
err
}
ch
:=
make
(
chan
string
,
1
)
a
,
err
:=
GetKey
(
cli
,
DefaultAddrPath
,
timeout
)
if
err
!=
nil
{
return
err
}
if
a
!=
""
{
// Master is registered, send to the master address
// channel.
ch
<-
a
}
go
watchKey
(
cli
,
DefaultAddrPath
,
ch
)
go
c
.
monitorMaster
(
ch
)
return
nil
}
}
// NewClient creates a new Client.
func
NewClient
(
opts
...
func
(
*
Client
)
error
)
(
*
Client
,
error
)
{
c
:=
&
Client
{}
c
.
conn
=
connection
.
New
()
c
.
ch
=
make
(
chan
record
,
bufSize
)
go
c
.
monitorMaster
(
addrCh
)
go
c
.
getRecords
()
return
c
for
_
,
opt
:=
range
opts
{
err
:=
opt
(
c
)
if
err
!=
nil
{
return
nil
,
err
}
}
return
c
,
nil
}
func
(
c
*
Client
)
getRecords
()
{
for
{
t
,
err
:=
c
.
getTask
()
if
err
!=
nil
{
// getTask call.
log
.
Errorf
(
"Get task failed, sleep 3 seconds and continue, %s"
,
err
)
time
.
Sleep
(
3
*
time
.
Second
)
continue
...
...
@@ -146,6 +208,20 @@ func (c *Client) taskFailed(meta TaskMeta) error {
// NextRecord will block until the next record is available. It is
// thread-safe.
func
(
c
*
Client
)
NextRecord
()
([]
byte
,
error
)
{
c
.
initChOnce
.
Do
(
func
()
{
// initialize with in case WithBuffer is not used.
c
.
ch
=
make
(
chan
record
,
0
)
go
c
.
getRecords
()
})
r
:=
<-
c
.
ch
return
r
.
r
,
r
.
err
}
// RequestSaveModel requests the master server to approve the caller
// to save the model.
func
(
c
*
Client
)
RequestSaveModel
(
trainerID
string
,
blockDur
time
.
Duration
)
(
bool
,
error
)
{
var
need
bool
err
:=
c
.
conn
.
Call
(
"Service.RequestSaveModel"
,
SaveModelRequest
{
TrainerID
:
trainerID
,
BlockDur
:
blockDur
},
&
need
)
return
need
,
err
}
go/master/client_test.go
浏览文件 @
3ff0a9fb
...
...
@@ -87,9 +87,11 @@ func TestNextRecord(t *testing.T) {
panic
(
err
)
}
curAddr
:=
make
(
chan
string
,
1
)
curAddr
<-
fmt
.
Sprintf
(
":%d"
,
p
)
c
:=
master
.
NewClient
(
curAddr
,
10
)
c
,
err
:=
master
.
NewClient
(
master
.
WithAddr
(
fmt
.
Sprintf
(
":%d"
,
p
)),
master
.
WithBuffer
(
10
))
if
err
!=
nil
{
panic
(
err
)
}
err
=
c
.
SetDataset
([]
string
{
path
})
if
err
!=
nil
{
panic
(
err
)
...
...
go/master/etcd_client.go
浏览文件 @
3ff0a9fb
...
...
@@ -158,8 +158,8 @@ func (e *EtcdClient) Load() ([]byte, error) {
}
// GetKey gets the value by the specify key.
func
GetKey
(
c
*
clientv3
.
Client
,
key
string
,
timeout
int
)
(
string
,
error
)
{
ctx
,
cancel
:=
context
.
WithTimeout
(
context
.
Background
(),
time
.
Second
*
time
.
Duration
(
timeout
)
)
func
GetKey
(
c
*
clientv3
.
Client
,
key
string
,
timeout
time
.
Duration
)
(
string
,
error
)
{
ctx
,
cancel
:=
context
.
WithTimeout
(
context
.
Background
(),
time
out
)
resp
,
err
:=
c
.
Get
(
ctx
,
key
)
cancel
()
if
err
!=
nil
{
...
...
@@ -173,8 +173,8 @@ func GetKey(c *clientv3.Client, key string, timeout int) (string, error) {
return
string
(
v
),
nil
}
//
W
atchKey watches the specify key and send to valChan if there is some event.
func
W
atchKey
(
c
*
clientv3
.
Client
,
key
string
,
valChan
chan
<-
string
)
{
//
w
atchKey watches the specify key and send to valChan if there is some event.
func
w
atchKey
(
c
*
clientv3
.
Client
,
key
string
,
valChan
chan
<-
string
)
{
rch
:=
c
.
Watch
(
context
.
Background
(),
key
)
for
wresp
:=
range
rch
{
for
_
,
ev
:=
range
wresp
.
Events
{
...
...
go/master/service.go
浏览文件 @
3ff0a9fb
...
...
@@ -78,9 +78,10 @@ type Service struct {
ready
chan
struct
{}
store
Store
mu
sync
.
Mutex
initDone
bool
taskQueues
taskQueues
mu
sync
.
Mutex
initDone
bool
taskQueues
taskQueues
savingTrainer
string
}
func
partition
(
chunks
[]
Chunk
,
chunksPerTask
int
)
[]
taskEntry
{
...
...
@@ -246,7 +247,7 @@ func readChunks(globPaths []string) ([]Chunk, error) {
//
// SetDataset can be call multiple times. But only the first call will
// be honored.
func
(
s
*
Service
)
SetDataset
(
globPaths
[]
string
,
dummy
*
int
)
error
{
func
(
s
*
Service
)
SetDataset
(
globPaths
[]
string
,
_
*
int
)
error
{
if
len
(
globPaths
)
==
0
{
return
errors
.
New
(
"no dataset specified"
)
}
...
...
@@ -330,7 +331,7 @@ func (s *Service) logFields() log.Fields {
}
// GetTask gets a new task from the service.
func
(
s
*
Service
)
GetTask
(
dummy
int
,
task
*
Task
)
error
{
func
(
s
*
Service
)
GetTask
(
_
int
,
task
*
Task
)
error
{
select
{
case
<-
s
.
ready
:
}
...
...
@@ -380,7 +381,7 @@ func (s *Service) GetTask(dummy int, task *Task) error {
}
// TaskFinished tell the service that a task is finished.
func
(
s
*
Service
)
TaskFinished
(
taskID
int
,
dummy
*
int
)
error
{
func
(
s
*
Service
)
TaskFinished
(
taskID
int
,
_
*
int
)
error
{
select
{
case
<-
s
.
ready
:
}
...
...
@@ -415,7 +416,7 @@ func (s *Service) TaskFinished(taskID int, dummy *int) error {
}
// TaskFailed tells the service that a task is failed.
func
(
s
*
Service
)
TaskFailed
(
meta
TaskMeta
,
dummy
*
int
)
error
{
func
(
s
*
Service
)
TaskFailed
(
meta
TaskMeta
,
_
*
int
)
error
{
select
{
case
<-
s
.
ready
:
}
...
...
@@ -432,3 +433,42 @@ func (s *Service) TaskFailed(meta TaskMeta, dummy *int) error {
s
.
processFailedTask
(
t
,
meta
.
Epoch
)
return
nil
}
// SaveModelRequest is the request for saving model
type
SaveModelRequest
struct
{
TrainerID
string
BlockDur
time
.
Duration
}
// RequestSaveModel requests the master server to approve the caller
// to save the model.
func
(
s
*
Service
)
RequestSaveModel
(
req
SaveModelRequest
,
need
*
bool
)
error
{
s
.
mu
.
Lock
()
defer
s
.
mu
.
Unlock
()
if
req
.
TrainerID
==
""
{
return
errors
.
New
(
"trainer id is empty"
)
}
if
s
.
savingTrainer
==
""
{
*
need
=
true
}
else
{
if
req
.
TrainerID
==
s
.
savingTrainer
{
// save trainer asked to save model again
*
need
=
true
}
else
{
*
need
=
false
}
}
if
*
need
{
s
.
savingTrainer
=
req
.
TrainerID
time
.
AfterFunc
(
req
.
BlockDur
,
func
()
{
s
.
mu
.
Lock
()
s
.
savingTrainer
=
""
s
.
mu
.
Unlock
()
})
}
return
nil
}
go/pserver/client/c/cclient.go
浏览文件 @
3ff0a9fb
...
...
@@ -127,13 +127,19 @@ func paddle_pserver_client_release(client C.paddle_pserver_client) {
remove
(
client
)
}
// paddle_begin_init_params tells trainer if it needs to init the
// parameters.
//
// returns 1 if the trainer needs to init the parameters. 0 if the
// trainer does not need to init the parameters.
//
//export paddle_begin_init_params
func
paddle_begin_init_params
(
client
C
.
paddle_pserver_client
)
C
.
int
{
c
:=
get
(
client
)
if
selected
:=
c
.
BeginInitParams
();
selected
{
return
1
}
return
C
.
PSERVER_OK
return
0
}
//export paddle_init_param
...
...
@@ -256,17 +262,4 @@ func paddle_get_params(client C.paddle_pserver_client, dst **C.paddle_parameter,
return
C
.
PSERVER_OK
}
//export paddle_save_model
func
paddle_save_model
(
client
C
.
paddle_pserver_client
,
path
*
C
.
char
)
C
.
int
{
p
:=
C
.
GoString
(
path
)
c
:=
get
(
client
)
err
:=
c
.
Save
(
p
)
if
err
!=
nil
{
log
.
Errorln
(
err
)
return
C
.
PSERVER_ERROR
}
return
C
.
PSERVER_OK
}
func
main
()
{}
// Required but ignored
go/pserver/client/c/test/test_cclient.c
浏览文件 @
3ff0a9fb
...
...
@@ -111,9 +111,5 @@ retry:
getParams
(
c
);
}
if
(
paddle_save_model
(
c
,
"/tmp/"
))
{
fail
();
}
return
0
;
}
go/pserver/client/client.go
浏览文件 @
3ff0a9fb
...
...
@@ -219,32 +219,6 @@ func (c *Client) GetParams(names []string) ([]pserver.Parameter, error) {
return
ps
,
nil
}
// Save indicates parameters to save the parameter to the given path.
func
(
c
*
Client
)
Save
(
path
string
)
error
{
errCh
:=
make
(
chan
error
,
len
(
c
.
pservers
))
for
_
,
p
:=
range
c
.
pservers
{
err
:=
p
.
Call
(
"Service.Save"
,
path
,
nil
)
errCh
<-
err
}
recv
:=
0
for
err
:=
range
errCh
{
if
err
!=
nil
{
return
err
}
recv
++
if
recv
==
len
(
c
.
pservers
)
{
break
}
}
// TODO(helin): there will be many files under path, need to
// merge them into a single file.
return
nil
}
func
strHash
(
s
string
)
uint32
{
h
:=
fnv
.
New32a
()
_
,
_
=
h
.
Write
([]
byte
(
s
))
...
...
go/pserver/service.go
浏览文件 @
3ff0a9fb
...
...
@@ -164,7 +164,7 @@ func NewService(idx int, interval time.Duration, path string, client *EtcdClient
}
// InitParam initializes a parameter.
func
(
s
*
Service
)
InitParam
(
paramWithConfigs
ParameterWithConfig
,
dummy
*
int
)
error
{
func
(
s
*
Service
)
InitParam
(
paramWithConfigs
ParameterWithConfig
,
_
*
int
)
error
{
select
{
case
<-
s
.
initialized
:
return
errors
.
New
(
AlreadyInitialized
)
...
...
@@ -185,7 +185,7 @@ func (s *Service) InitParam(paramWithConfigs ParameterWithConfig, dummy *int) er
// FinishInitParams tells the parameter server that the parameter
// initialization has finished.
func
(
s
*
Service
)
FinishInitParams
(
dummy0
int
,
dummy1
*
int
)
error
{
func
(
s
*
Service
)
FinishInitParams
(
_
int
,
_
*
int
)
error
{
select
{
case
<-
s
.
initialized
:
return
errors
.
New
(
AlreadyInitialized
)
...
...
@@ -198,7 +198,7 @@ func (s *Service) FinishInitParams(dummy0 int, dummy1 *int) error {
// SendGrad sends gradient to parameter servers for parameter
// optimization.
func
(
s
*
Service
)
SendGrad
(
g
Gradient
,
dummy
*
int
)
error
{
func
(
s
*
Service
)
SendGrad
(
g
Gradient
,
_
*
int
)
error
{
select
{
case
<-
s
.
initialized
:
default
:
...
...
python/paddle/v2/__init__.py
浏览文件 @
3ff0a9fb
...
...
@@ -33,6 +33,7 @@ import networks
import
minibatch
import
plot
import
image
import
model
__all__
=
[
'optimizer'
,
...
...
@@ -54,6 +55,7 @@ __all__ = [
'evaluator'
,
'image'
,
'master'
,
'model'
,
]
...
...
python/paddle/v2/master/client.py
浏览文件 @
3ff0a9fb
...
...
@@ -10,11 +10,31 @@ class client(object):
client is a client to the master server.
"""
def
__init__
(
self
,
etcd_endpoints
,
timeout
,
buf_size
):
self
.
c
=
lib
.
paddle_new_etcd_master_client
(
etcd_endpoints
,
timeout
,
def
__init__
(
self
,
etcd_endpoints
,
timeout
_sec
,
buf_size
=
0
):
self
.
c
=
lib
.
paddle_new_etcd_master_client
(
etcd_endpoints
,
timeout
_sec
,
buf_size
)
def
close
(
self
):
def
request_save_model
(
self
,
trainer_id
,
block_ms
):
"""request to save model
Conventionally the 0-th trainer will save model. But in
distributed training, any trainer could be killed. This
function asks the master server if the trainer should proceed
with saving model.
:param trainer_id: trainer id.
:param block_ms: number of millisecond that other save model
will be blocked if this save model request succeeded.
Returns:
int: 1 if the save the model request is approved, 0 if
does the request is rejected because other trainer is
saving the model, -1 if error happened.
"""
return
lib
.
paddle_request_save_model
(
self
.
c
,
trainer_id
,
block_ms
)
def
release
(
self
):
lib
.
paddle_release_master_client
(
self
.
c
)
self
.
c
=
None
...
...
@@ -27,10 +47,13 @@ class client(object):
holder
[
idx
]
=
c_ptr
lib
.
paddle_set_dataset
(
self
.
c
,
holder
,
len
(
paths
))
# return format: (record, errno)
# errno = 0: ok
# < 0: error
def
next_record
(
self
):
"""gets next record for training
Returns:
string: the record.
int: error code, 0 if successful, < 0 otherwise.
"""
p
=
ctypes
.
c_char_p
()
ret
=
ctypes
.
pointer
(
p
)
size
=
lib
.
paddle_next_record
(
self
.
c
,
ret
)
...
...
python/paddle/v2/model.py
0 → 100644
浏览文件 @
3ff0a9fb
# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
os
import
errno
import
uuid
import
paddle.v2.master
__all__
=
[
"save_model"
,
"load_model"
]
trainer_id
=
str
(
uuid
.
uuid4
())
def
mkdir_p
(
path
):
try
:
os
.
makedirs
(
path
)
except
OSError
as
exc
:
if
exc
.
errno
==
errno
.
EEXIST
and
os
.
path
.
isdir
(
path
):
pass
else
:
raise
def
save_model
(
parameters
,
path
):
need_request
=
"KUBERNETES_SERVICE_HOST"
in
os
.
environ
.
keys
()
if
need_request
:
# TODO(helin): figure out how MPI trains, since MPI only save
# model when trainer_id == "0", we can consolidate the logic
# here.
# TODO(helin): change this environment variable name from
# MASTER_IP to ETCD_IP
etcd_name
=
"MASTER_IP"
if
etcd_name
not
in
os
.
environ
.
keys
():
raise
Exception
(
'not find '
+
etcd_name
+
' in environment variable.'
)
etcd_ip
=
os
.
environ
.
get
(
etcd_name
)
client
=
master
.
client
(
"http://"
+
etcd_ip
+
":2379"
,
5
,
0
)
r
=
client
.
request_save_model
(
trainer_id
,
5000
)
if
r
==
0
:
# do not need to save
return
elif
r
<
0
:
# error
return
else
:
# save model
path
=
os
.
path
.
join
(
path
,
trainer_id
)
path
=
os
.
path
.
join
(
path
,
"model.tar"
)
mkdir_p
(
path
)
with
open
(
path
,
'wb'
)
as
f
:
parameters
.
to_tar
(
f
)
def
load_model
(
parameters
,
path
):
with
open
(
path
,
'rb'
)
as
f
:
parameters
.
from_tar
(
f
)
python/paddle/v2/reader/creator.py
浏览文件 @
3ff0a9fb
...
...
@@ -12,8 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Creator package contains some simple reader creator, which could
be used in user
program.
Creator package contains some simple reader creator, which could
be used in user
program.
"""
__all__
=
[
'np_array'
,
'text_file'
,
"recordio"
]
...
...
@@ -59,7 +59,7 @@ def text_file(path):
def
recordio_local
(
paths
,
buf_size
=
100
):
"""
Creates a data reader from given RecordIO file paths separated by ",",
Creates a data reader from given RecordIO file paths separated by ",",
glob pattern is supported.
:path: path of recordio files.
:returns: data reader of recordio files.
...
...
@@ -83,7 +83,7 @@ def recordio_local(paths, buf_size=100):
def
recordio
(
paths
,
buf_size
=
100
):
"""
Creates a data reader that outputs record one one by one
Creates a data reader that outputs record one one by one
from given local or cloud recordio path.
:path: path of recordio files.
:returns: data reader of recordio files.
...
...
@@ -96,7 +96,7 @@ def recordio(paths, buf_size=100):
host_name
=
"MASTER_SERVICE_HOST"
if
host_name
not
in
os
.
environ
.
keys
():
raise
Exception
(
'not find '
+
host_name
+
' in environ.'
)
raise
Exception
(
'not find '
+
host_name
+
' in environ
ment variable
.'
)
addr
=
os
.
environ
(
host
)
...
...
@@ -110,6 +110,6 @@ def recordio(paths, buf_size=100):
break
yield
r
c
.
clo
se
()
c
.
relea
se
()
return
reader
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录