Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
机器未来
Paddle
提交
f1330e21
P
Paddle
项目概览
机器未来
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
f1330e21
编写于
7月 03, 2017
作者:
D
dongzhihong
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
"saving checkpoint"
上级
5ef1425a
变更
2
隐藏空白更改
内联
并排
Showing
2 changed file
with
80 addition
and
5 deletion
+80
-5
go/pserver/service.go
go/pserver/service.go
+74
-5
go/pserver/service_test.go
go/pserver/service_test.go
+6
-0
未找到文件。
go/pserver/service.go
浏览文件 @
f1330e21
package
pserver
package
pserver
import
(
import
(
"bufio"
"bytes"
"crypto/md5"
"encoding/gob"
"encoding/hex"
"errors"
"errors"
"fmt"
"fmt"
"os"
"strconv"
"sync"
"sync"
"time"
log
"github.com/sirupsen/logrus"
)
)
// ElementType is the type of elements of a Parameter.
// ElementType is the type of elements of a Parameter.
...
@@ -14,6 +24,10 @@ const (
...
@@ -14,6 +24,10 @@ const (
Uninitialized
=
"pserver not fully initialized"
Uninitialized
=
"pserver not fully initialized"
)
)
const
(
checkpoint_path
=
"/checkpoints/"
)
// Supported element types
// Supported element types
const
(
const
(
Int32
ElementType
=
iota
Int32
ElementType
=
iota
...
@@ -53,6 +67,24 @@ type Service struct {
...
@@ -53,6 +67,24 @@ type Service struct {
optMap
map
[
string
]
*
optimizer
optMap
map
[
string
]
*
optimizer
}
}
type
Checkpoint
struct
{
uuid
string
md5sum
string
timestamp
string
}
//serialize ParameterWithConfig to byte stream
func
GetBytes
(
content
...
interface
{})
([]
byte
,
error
)
{
var
buf
bytes
.
Buffer
encoder
:=
gob
.
NewEncoder
(
&
buf
)
err
:=
encoder
.
Encode
(
content
)
if
err
!=
nil
{
return
nil
,
err
}
return
buf
.
Bytes
(),
nil
}
// NewService creates a new service, will bypass etcd registration if no
// NewService creates a new service, will bypass etcd registration if no
// endpoints specified.
// endpoints specified.
func
NewService
(
idx
int
)
(
*
Service
,
error
)
{
func
NewService
(
idx
int
)
(
*
Service
,
error
)
{
...
@@ -143,13 +175,50 @@ func (s *Service) GetParam(name string, parameter *Parameter) error {
...
@@ -143,13 +175,50 @@ func (s *Service) GetParam(name string, parameter *Parameter) error {
// Save tells the parameter server to save parameters.
// Save tells the parameter server to save parameters.
func
(
s
*
Service
)
Save
(
path
string
,
dummy
*
int
)
error
{
func
(
s
*
Service
)
Save
(
path
string
,
dummy
*
int
)
error
{
//FIXME: checkpoint is only used by pserver
// and has a constant path of */checkpoints/{pserver_idx}*
<-
s
.
initialized
<-
s
.
initialized
for
opt
,
ok
:=
range
s
.
optMap
{
s
.
mu
.
Lock
()
if
ok
!=
nil
{
defer
s
.
mu
.
Unlock
()
return
fmt
.
Errorf
(
"parameter optimizerMap error: "
,
ok
)
var
paramWithConfig
ParameterWithConfig
for
name
,
opt
:=
range
s
.
optMap
{
paramWithConfig
.
Param
.
Name
=
name
paramWithConfig
.
Param
.
ElementType
=
opt
.
elementType
paramWithConfig
.
Param
.
Content
=
opt
.
GetWeights
()
paramWithConfig
.
State
=
opt
.
GetStates
()
content
,
err
:=
GetBytes
(
paramWithConfig
)
if
err
!=
nil
{
log
.
Errorln
(
err
)
}
ck
:=
Checkpoint
{}
h
:=
md5
.
New
()
ck
.
md5sum
=
hex
.
EncodeToString
(
h
.
Sum
(
content
))
ck
.
timestamp
=
time
.
Now
()
.
String
()
ck
.
uuid
=
checkpoint_path
+
strconv
.
Itoa
(
s
.
idx
)
ckbytes
,
err
:=
GetBytes
(
ck
)
if
err
!=
nil
{
log
.
Errorln
(
err
)
}
// TODO: according design doc, need to save uuid to etcd in json format
// {\"uuid\": [UUID], \"md5\", \"MD5 sum\", \"timestamp\": xxxx}
log
.
Infof
(
"parameter checkpoint %s"
,
ckbytes
)
if
_
,
err
=
os
.
Stat
(
ck
.
uuid
);
os
.
IsNotExist
(
err
)
{
log
.
Info
(
"checkpoint not exists."
)
}
else
{
err
=
os
.
Remove
(
ck
.
uuid
)
log
.
Infof
(
"remove %s"
,
ck
.
uuid
)
}
f
,
err
:=
os
.
Create
(
ck
.
uuid
)
defer
f
.
Close
()
if
err
!=
nil
{
log
.
Errorln
(
err
)
}
writer
:=
bufio
.
NewWriter
(
f
)
_
,
err
=
writer
.
Write
(
content
)
if
err
!=
nil
{
log
.
Errorln
(
err
)
}
}
state
:=
opt
.
GetStates
()
weights
:=
opt
.
GetWeights
()
}
}
return
nil
return
nil
}
}
go/pserver/service_test.go
浏览文件 @
f1330e21
...
@@ -79,6 +79,8 @@ func TestServiceFull(t *testing.T) {
...
@@ -79,6 +79,8 @@ func TestServiceFull(t *testing.T) {
if
!
reflect
.
DeepEqual
(
param1
,
p
)
{
if
!
reflect
.
DeepEqual
(
param1
,
p
)
{
t
.
FailNow
()
t
.
FailNow
()
}
}
var
dummy
int
s
.
Save
(
""
,
&
dummy
)
}
}
func
TestMultipleInit
(
t
*
testing
.
T
)
{
func
TestMultipleInit
(
t
*
testing
.
T
)
{
...
@@ -166,3 +168,7 @@ func TestBlockUntilInitialized(t *testing.T) {
...
@@ -166,3 +168,7 @@ func TestBlockUntilInitialized(t *testing.T) {
wg
.
Wait
()
wg
.
Wait
()
}
}
func
TestCheckpointSpeed
(
t
*
testing
.
T
)
{
//TODO: test speed
}
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录