Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
5a4f33df
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
5a4f33df
编写于
7月 11, 2017
作者:
Y
yi.wu
浏览文件
操作
浏览文件
下载
差异文件
Merge branch 'develop' of
https://github.com/PaddlePaddle/Paddle
into fix_newupdater
上级
26d95a6b
15f021a9
变更
21
隐藏空白更改
内联
并排
Showing
21 changed file
with
440 addition
and
91 deletion
+440
-91
CMakeLists.txt
CMakeLists.txt
+1
-0
cmake/generic.cmake
cmake/generic.cmake
+10
-10
go/cmd/pserver/pserver.go
go/cmd/pserver/pserver.go
+6
-2
go/pserver/etcd_client.go
go/pserver/etcd_client.go
+13
-0
go/pserver/optimizer.go
go/pserver/optimizer.go
+20
-6
go/pserver/optimizer_test.go
go/pserver/optimizer_test.go
+1
-1
go/pserver/service.go
go/pserver/service.go
+99
-11
go/pserver/service_test.go
go/pserver/service_test.go
+12
-14
paddle/optimizer/adadelta_optimizer.cc
paddle/optimizer/adadelta_optimizer.cc
+5
-3
paddle/optimizer/adagrad_optimizer.cc
paddle/optimizer/adagrad_optimizer.cc
+6
-3
paddle/optimizer/adam_optimizer.cc
paddle/optimizer/adam_optimizer.cc
+6
-3
paddle/optimizer/lr_policy.h
paddle/optimizer/lr_policy.h
+34
-14
paddle/optimizer/sgd_optimizer.cc
paddle/optimizer/sgd_optimizer.cc
+5
-1
paddle/platform/CMakeLists.txt
paddle/platform/CMakeLists.txt
+2
-0
paddle/platform/device_context.h
paddle/platform/device_context.h
+159
-0
paddle/platform/device_context_test.cc
paddle/platform/device_context_test.cc
+33
-0
proto/OptimizerConfig.proto
proto/OptimizerConfig.proto
+12
-15
python/paddle/trainer_config_helpers/layers.py
python/paddle/trainer_config_helpers/layers.py
+11
-3
python/paddle/trainer_config_helpers/networks.py
python/paddle/trainer_config_helpers/networks.py
+2
-2
python/paddle/v2/dataset/wmt14.py
python/paddle/v2/dataset/wmt14.py
+2
-2
python/setup.py.in
python/setup.py.in
+1
-1
未找到文件。
CMakeLists.txt
浏览文件 @
5a4f33df
...
@@ -16,6 +16,7 @@ cmake_minimum_required(VERSION 3.0)
...
@@ -16,6 +16,7 @@ cmake_minimum_required(VERSION 3.0)
set
(
CMAKE_MODULE_PATH
${
CMAKE_MODULE_PATH
}
"
${
CMAKE_CURRENT_SOURCE_DIR
}
/cmake"
)
set
(
CMAKE_MODULE_PATH
${
CMAKE_MODULE_PATH
}
"
${
CMAKE_CURRENT_SOURCE_DIR
}
/cmake"
)
set
(
PROJ_ROOT
${
CMAKE_CURRENT_SOURCE_DIR
}
)
set
(
PROJ_ROOT
${
CMAKE_CURRENT_SOURCE_DIR
}
)
set
(
PROJ_BINARY_ROOT
${
CMAKE_CURRENT_BINARY_DIR
}
)
include
(
system
)
include
(
system
)
...
...
cmake/generic.cmake
浏览文件 @
5a4f33df
...
@@ -88,7 +88,7 @@
...
@@ -88,7 +88,7 @@
#
#
# including binary directory for generated headers.
# including binary directory for generated headers.
include_directories
(
${
CMAKE_BINARY_DIR
}
)
include_directories
(
${
CMAKE_
CURRENT_
BINARY_DIR
}
)
if
(
NOT APPLE
)
if
(
NOT APPLE
)
find_package
(
Threads REQUIRED
)
find_package
(
Threads REQUIRED
)
...
@@ -106,7 +106,7 @@ function(merge_static_libs TARGET_NAME)
...
@@ -106,7 +106,7 @@ function(merge_static_libs TARGET_NAME)
if
(
APPLE
)
# Use OSX's libtool to merge archives
if
(
APPLE
)
# Use OSX's libtool to merge archives
# To produce a library we need at least one source file.
# To produce a library we need at least one source file.
# It is created by add_custom_command below and will helps
# It is created by add_custom_command below and will helps
# also help to track dependencies.
# also help to track dependencies.
set
(
dummyfile
${
CMAKE_CURRENT_BINARY_DIR
}
/
${
TARGET_NAME
}
_dummy.c
)
set
(
dummyfile
${
CMAKE_CURRENT_BINARY_DIR
}
/
${
TARGET_NAME
}
_dummy.c
)
...
@@ -144,24 +144,24 @@ function(merge_static_libs TARGET_NAME)
...
@@ -144,24 +144,24 @@ function(merge_static_libs TARGET_NAME)
DEPENDS
${
lib
}
${
objdir
}
DEPENDS
${
lib
}
${
objdir
}
WORKING_DIRECTORY
${
objdir
}
)
WORKING_DIRECTORY
${
objdir
}
)
# Empty dummy source file that goes into merged library
# Empty dummy source file that goes into merged library
set
(
mergebase
${
lib
}
.mergebase.c
)
set
(
mergebase
${
lib
}
.mergebase.c
)
add_custom_command
(
OUTPUT
${
mergebase
}
add_custom_command
(
OUTPUT
${
mergebase
}
COMMAND
${
CMAKE_COMMAND
}
-E touch
${
mergebase
}
COMMAND
${
CMAKE_COMMAND
}
-E touch
${
mergebase
}
DEPENDS
${
objlistfile
}
)
DEPENDS
${
objlistfile
}
)
list
(
APPEND mergebases
"
${
mergebase
}
"
)
list
(
APPEND mergebases
"
${
mergebase
}
"
)
endforeach
()
endforeach
()
add_library
(
${
TARGET_NAME
}
STATIC
${
mergebases
}
)
add_library
(
${
TARGET_NAME
}
STATIC
${
mergebases
}
)
target_link_libraries
(
${
TARGET_NAME
}
${
libs_deps
}
)
target_link_libraries
(
${
TARGET_NAME
}
${
libs_deps
}
)
# Get the file name of the generated library
# Get the file name of the generated library
set
(
outlibfile
"$<TARGET_FILE:
${
TARGET_NAME
}
>"
)
set
(
outlibfile
"$<TARGET_FILE:
${
TARGET_NAME
}
>"
)
foreach
(
lib
${
libs
}
)
foreach
(
lib
${
libs
}
)
add_custom_command
(
TARGET
${
TARGET_NAME
}
POST_BUILD
add_custom_command
(
TARGET
${
TARGET_NAME
}
POST_BUILD
COMMAND
${
CMAKE_AR
}
cr
${
outlibfile
}
*.o
COMMAND
${
CMAKE_AR
}
cr
${
outlibfile
}
*.o
COMMAND
${
CMAKE_RANLIB
}
${
outlibfile
}
COMMAND
${
CMAKE_RANLIB
}
${
outlibfile
}
WORKING_DIRECTORY
${
lib
}
.objdir
)
WORKING_DIRECTORY
${
lib
}
.objdir
)
endforeach
()
endforeach
()
...
@@ -362,4 +362,4 @@ function(py_proto_compile TARGET_NAME)
...
@@ -362,4 +362,4 @@ function(py_proto_compile TARGET_NAME)
set
(
py_srcs
)
set
(
py_srcs
)
protobuf_generate_python
(
py_srcs
${
py_proto_compile_SRCS
}
)
protobuf_generate_python
(
py_srcs
${
py_proto_compile_SRCS
}
)
add_custom_target
(
${
TARGET_NAME
}
ALL DEPENDS
${
py_srcs
}
)
add_custom_target
(
${
TARGET_NAME
}
ALL DEPENDS
${
py_srcs
}
)
endfunction
()
endfunction
()
\ No newline at end of file
go/cmd/pserver/pserver.go
浏览文件 @
5a4f33df
...
@@ -20,6 +20,8 @@ func main() {
...
@@ -20,6 +20,8 @@ func main() {
"comma separated endpoint string for pserver to connect to etcd"
)
"comma separated endpoint string for pserver to connect to etcd"
)
etcdTimeout
:=
flag
.
Int
(
"etcd-timeout"
,
5
,
"timeout for etcd calls"
)
etcdTimeout
:=
flag
.
Int
(
"etcd-timeout"
,
5
,
"timeout for etcd calls"
)
numPservers
:=
flag
.
Int
(
"num-pservers"
,
1
,
"total pserver count in a training job"
)
numPservers
:=
flag
.
Int
(
"num-pservers"
,
1
,
"total pserver count in a training job"
)
checkpointPath
:=
flag
.
String
(
"checkpoint-path"
,
"/checkpoints/"
,
"save checkpoint path"
)
checkpointInterval
:=
flag
.
Int
(
"checkpoint-interval"
,
600
,
"save checkpoint per interval seconds"
)
logLevel
:=
flag
.
String
(
"log-level"
,
"info"
,
logLevel
:=
flag
.
String
(
"log-level"
,
"info"
,
"log level, possible values: debug, info, warning, error, fatal, panic"
)
"log level, possible values: debug, info, warning, error, fatal, panic"
)
flag
.
Parse
()
flag
.
Parse
()
...
@@ -31,18 +33,20 @@ func main() {
...
@@ -31,18 +33,20 @@ func main() {
log
.
SetLevel
(
level
)
log
.
SetLevel
(
level
)
var
idx
int
var
idx
int
var
cp
pserver
.
Checkpoint
var
e
*
pserver
.
EtcdClient
if
*
index
>=
0
{
if
*
index
>=
0
{
idx
=
*
index
idx
=
*
index
}
else
{
}
else
{
timeout
:=
time
.
Second
*
time
.
Duration
((
*
etcdTimeout
))
timeout
:=
time
.
Second
*
time
.
Duration
((
*
etcdTimeout
))
e
:
=
pserver
.
NewEtcdClient
(
*
etcdEndpoint
,
*
numPservers
,
timeout
)
e
=
pserver
.
NewEtcdClient
(
*
etcdEndpoint
,
*
numPservers
,
timeout
)
idx
,
err
=
e
.
Register
()
idx
,
err
=
e
.
Register
()
if
err
!=
nil
{
if
err
!=
nil
{
panic
(
err
)
panic
(
err
)
}
}
}
}
s
,
err
:=
pserver
.
NewService
(
idx
)
s
,
err
:=
pserver
.
NewService
(
idx
,
*
checkpointInterval
,
*
checkpointPath
,
e
,
cp
)
if
err
!=
nil
{
if
err
!=
nil
{
panic
(
err
)
panic
(
err
)
}
}
...
...
go/pserver/etcd_client.go
浏览文件 @
5a4f33df
...
@@ -18,6 +18,8 @@ const (
...
@@ -18,6 +18,8 @@ const (
PsDesired
=
"/ps_desired"
PsDesired
=
"/ps_desired"
// PsAddr is the base dir for pserver to store their addr
// PsAddr is the base dir for pserver to store their addr
PsPath
=
"/ps/"
PsPath
=
"/ps/"
// PsCheckpoint is the etcd path for store checkpoints information
PsCheckpoint
=
"/checkpoints/"
)
)
// EtcdClient is the etcd client that the pserver uses for fault
// EtcdClient is the etcd client that the pserver uses for fault
...
@@ -186,3 +188,14 @@ func (e *EtcdClient) registerPserverEtcd(ctx context.Context) (int, error) {
...
@@ -186,3 +188,14 @@ func (e *EtcdClient) registerPserverEtcd(ctx context.Context) (int, error) {
return
idx
,
nil
return
idx
,
nil
}
}
// PutKey put into etcd with value by key specified
func
(
e
*
EtcdClient
)
PutKey
(
key
string
,
value
[]
byte
,
timeout
int
)
error
{
ctx
,
cancel
:=
context
.
WithTimeout
(
context
.
Background
(),
time
.
Second
*
time
.
Duration
(
timeout
))
_
,
err
:=
e
.
etcdClient
.
Put
(
ctx
,
key
,
string
(
value
))
cancel
()
if
err
!=
nil
{
return
err
}
return
nil
}
go/pserver/optimizer.go
浏览文件 @
5a4f33df
...
@@ -35,22 +35,30 @@ func cArrayToSlice(p unsafe.Pointer, len int) []byte {
...
@@ -35,22 +35,30 @@ func cArrayToSlice(p unsafe.Pointer, len int) []byte {
return
(
*
[
1
<<
30
]
byte
)(
p
)[
:
len
:
len
]
return
(
*
[
1
<<
30
]
byte
)(
p
)[
:
len
:
len
]
}
}
func
newOptimizer
(
paramWithConfigs
ParameterWithConfig
)
*
optimizer
{
func
newOptimizer
(
paramWithConfigs
ParameterWithConfig
,
State
[]
byte
)
*
optimizer
{
o
:=
&
optimizer
{}
o
:=
&
optimizer
{}
o
.
elementType
=
paramWithConfigs
.
Param
.
ElementType
o
.
elementType
=
paramWithConfigs
.
Param
.
ElementType
p
:=
paramWithConfigs
.
Param
p
:=
paramWithConfigs
.
Param
c
:=
paramWithConfigs
.
Config
c
:=
paramWithConfigs
.
Config
s
:=
State
paramBufferSize
:=
C
.
size_t
(
len
(
p
.
Content
)
/
C
.
sizeof_float
)
log
.
WithFields
(
log
.
Fields
{
log
.
WithFields
(
log
.
Fields
{
"ElementType"
:
p
.
ElementType
,
"ElementType"
:
p
.
ElementType
,
"ParamSize"
:
len
(
p
.
Content
)
/
C
.
sizeof_float
,
"ParamSize"
:
paramBufferSize
,
"ConfigSize"
:
len
(
c
),
"ConfigSize"
:
len
(
c
),
"StateSize"
:
len
(
s
),
})
.
Info
(
"New Optimizer Created with config:"
)
})
.
Info
(
"New Optimizer Created with config:"
)
var
cbuffer
unsafe
.
Pointer
var
cbuffer
unsafe
.
Pointer
cbuffer
=
C
.
malloc
(
C
.
size_t
(
len
(
p
.
Content
)))
cbuffer
=
C
.
malloc
(
paramBufferSize
)
C
.
memcpy
(
cbuffer
,
unsafe
.
Pointer
(
&
p
.
Content
[
0
]),
C
.
size_t
(
len
(
p
.
Content
)
/
C
.
sizeof_float
))
C
.
memcpy
(
cbuffer
,
unsafe
.
Pointer
(
&
p
.
Content
[
0
]),
paramBufferSize
)
var
cstate
unsafe
.
Pointer
if
len
(
s
)
!=
0
{
cstate
=
unsafe
.
Pointer
(
&
s
[
0
])
}
o
.
opt
=
C
.
paddle_create_optimizer
((
*
C
.
uchar
)(
&
c
[
0
]),
C
.
int
(
len
(
c
)),
o
.
opt
=
C
.
paddle_create_optimizer
((
*
C
.
uchar
)(
&
c
[
0
]),
C
.
int
(
len
(
c
)),
C
.
paddle_element_type
(
p
.
ElementType
),
cbuffer
,
C
.
int
(
len
(
p
.
Content
)
/
C
.
sizeof_float
),
C
.
paddle_element_type
(
p
.
ElementType
),
cbuffer
,
C
.
int
(
paramBufferSize
),
(
*
C
.
char
)(
cstate
),
C
.
int
(
len
(
s
)))
(
*
C
.
char
)(
nullPtr
),
0
)
return
o
return
o
}
}
...
@@ -60,6 +68,12 @@ func (o *optimizer) GetWeights() []byte {
...
@@ -60,6 +68,12 @@ func (o *optimizer) GetWeights() []byte {
return
cArrayToSlice
(
buffer
,
int
(
bufferLen
)
*
C
.
sizeof_float
)
return
cArrayToSlice
(
buffer
,
int
(
bufferLen
)
*
C
.
sizeof_float
)
}
}
func
(
o
*
optimizer
)
GetStates
()
[]
byte
{
var
cbuffer
*
C
.
char
cbufferLen
:=
C
.
paddle_optimizer_get_state
(
o
.
opt
,
&
cbuffer
)
return
cArrayToSlice
(
unsafe
.
Pointer
(
cbuffer
),
int
(
cbufferLen
))
}
func
(
o
*
optimizer
)
UpdateParameter
(
g
Gradient
)
error
{
func
(
o
*
optimizer
)
UpdateParameter
(
g
Gradient
)
error
{
if
o
.
elementType
!=
g
.
ElementType
{
if
o
.
elementType
!=
g
.
ElementType
{
return
fmt
.
Errorf
(
"Name: %s, parameter and gradient element type not match, parameter: %v, gradient: %v"
,
g
.
Name
,
o
.
elementType
,
g
.
ElementType
)
return
fmt
.
Errorf
(
"Name: %s, parameter and gradient element type not match, parameter: %v, gradient: %v"
,
g
.
Name
,
o
.
elementType
,
g
.
ElementType
)
...
...
go/pserver/optimizer_test.go
浏览文件 @
5a4f33df
...
@@ -19,6 +19,6 @@ func TestOptimizerCreateRelease(t *testing.T) {
...
@@ -19,6 +19,6 @@ func TestOptimizerCreateRelease(t *testing.T) {
Param
:
p
,
Param
:
p
,
Config
:
config
,
Config
:
config
,
}
}
o
:=
newOptimizer
(
param
)
o
:=
newOptimizer
(
param
,
nil
)
o
.
Cleanup
()
o
.
Cleanup
()
}
}
go/pserver/service.go
浏览文件 @
5a4f33df
package
pserver
package
pserver
import
(
import
(
"bufio"
"bytes"
"crypto/md5"
"encoding/gob"
"encoding/hex"
"encoding/json"
"errors"
"errors"
"fmt"
"fmt"
"os"
"path/filepath"
"strconv"
"sync"
"sync"
"time"
log
"github.com/sirupsen/logrus"
)
)
// ElementType is the type of elements of a Parameter.
// ElementType is the type of elements of a Parameter.
...
@@ -39,26 +51,55 @@ type ParameterWithConfig struct {
...
@@ -39,26 +51,55 @@ type ParameterWithConfig struct {
Config
[]
byte
// parameter configuration in Proto Buffer format
Config
[]
byte
// parameter configuration in Proto Buffer format
}
}
// ParameterCheckpoint is Parameter and State checkpoint
type
ParameterCheckpoint
struct
{
ParamConfig
ParameterWithConfig
State
[]
byte
}
// checkpoint signature
type
checkpointMeta
struct
{
UUID
string
`json:"uuid"`
Md5sum
string
`json:"md5sum"`
Timestamp
string
`json:"timestamp"`
}
// Checkpoint is the pserver shard persist in file
type
Checkpoint
[]
ParameterCheckpoint
// Gradient is the gradient of the parameter.
// Gradient is the gradient of the parameter.
type
Gradient
Parameter
type
Gradient
Parameter
// Service is the RPC service for pserver.
// Service is the RPC service for pserver.
type
Service
struct
{
type
Service
struct
{
initialized
chan
struct
{}
initialized
chan
struct
{}
idx
int
idx
int
checkpointInterval
time
.
Duration
mu
sync
.
Mutex
checkpointPath
string
optMap
map
[
string
]
*
optimizer
client
*
EtcdClient
mu
sync
.
Mutex
optMap
map
[
string
]
*
optimizer
}
}
// NewService creates a new service, will bypass etcd registration if no
// NewService creates a new service, will bypass etcd registration if no
// endpoints specified.
// endpoints specified.
func
NewService
(
idx
int
)
(
*
Service
,
error
)
{
func
NewService
(
idx
int
,
seconds
int
,
path
string
,
client
*
EtcdClient
,
cp
Checkpoint
)
(
*
Service
,
error
)
{
s
:=
&
Service
{
s
:=
&
Service
{
idx
:
idx
,
idx
:
idx
,
checkpointInterval
:
time
.
Second
*
time
.
Duration
(
seconds
),
checkpointPath
:
path
,
client
:
client
,
}
}
s
.
optMap
=
make
(
map
[
string
]
*
optimizer
)
s
.
optMap
=
make
(
map
[
string
]
*
optimizer
)
s
.
initialized
=
make
(
chan
struct
{})
s
.
initialized
=
make
(
chan
struct
{})
if
cp
!=
nil
{
for
_
,
item
:=
range
cp
{
p
:=
item
.
ParamConfig
st
:=
item
.
State
s
.
optMap
[
p
.
Param
.
Name
]
=
newOptimizer
(
p
,
st
)
}
}
return
s
,
nil
return
s
,
nil
}
}
...
@@ -78,7 +119,7 @@ func (s *Service) InitParam(paramWithConfigs ParameterWithConfig, dummy *int) er
...
@@ -78,7 +119,7 @@ func (s *Service) InitParam(paramWithConfigs ParameterWithConfig, dummy *int) er
// TODO(helin): check if paramWithConfigs.Param.Content is
// TODO(helin): check if paramWithConfigs.Param.Content is
// properly memory aligned, if not, make copy to a memory
// properly memory aligned, if not, make copy to a memory
// aligned region.
// aligned region.
s
.
optMap
[
paramWithConfigs
.
Param
.
Name
]
=
newOptimizer
(
paramWithConfigs
)
s
.
optMap
[
paramWithConfigs
.
Param
.
Name
]
=
newOptimizer
(
paramWithConfigs
,
nil
)
return
nil
return
nil
}
}
...
@@ -139,10 +180,57 @@ func (s *Service) GetParam(name string, parameter *Parameter) error {
...
@@ -139,10 +180,57 @@ func (s *Service) GetParam(name string, parameter *Parameter) error {
return
nil
return
nil
}
}
//
Save tells the parameter server to save parameters.
//
pserver save checkpoint
func
(
s
*
Service
)
Save
(
path
string
,
dummy
*
int
)
error
{
func
(
s
*
Service
)
doCheckpoint
(
)
error
{
<-
s
.
initialized
<-
s
.
initialized
s
.
mu
.
Lock
()
defer
s
.
mu
.
Unlock
()
cp
:=
make
([]
ParameterCheckpoint
,
0
,
len
(
s
.
optMap
))
index
:=
0
for
name
,
opt
:=
range
s
.
optMap
{
var
pc
ParameterCheckpoint
pc
.
ParamConfig
.
Param
.
Name
=
name
pc
.
ParamConfig
.
Param
.
ElementType
=
opt
.
elementType
pc
.
ParamConfig
.
Param
.
Content
=
opt
.
GetWeights
()
pc
.
State
=
opt
.
GetStates
()
cp
[
index
]
=
pc
index
++
}
var
buf
bytes
.
Buffer
encoder
:=
gob
.
NewEncoder
(
&
buf
)
err
:=
encoder
.
Encode
(
cp
)
if
err
!=
nil
{
return
err
}
cpMeta
:=
checkpointMeta
{}
cpMeta
.
UUID
=
s
.
checkpointPath
+
strconv
.
Itoa
(
s
.
idx
)
cpMeta
.
Timestamp
=
time
.
Now
()
.
String
()
h
:=
md5
.
New
()
cpMeta
.
Md5sum
=
hex
.
EncodeToString
(
h
.
Sum
(
buf
.
Bytes
()))
// TODO
cpMetajson
,
_
:=
json
.
Marshal
(
cpMeta
)
err
=
s
.
client
.
PutKey
(
filepath
.
Join
(
PsCheckpoint
,
strconv
.
Itoa
(
s
.
idx
)),
cpMetajson
,
3
)
if
err
!=
nil
{
return
err
}
if
_
,
err
=
os
.
Stat
(
cpMeta
.
UUID
);
os
.
IsNotExist
(
err
)
{
log
.
Info
(
"checkpoint does not exists."
)
}
else
{
err
=
os
.
Remove
(
cpMeta
.
UUID
)
log
.
Infof
(
"checkpoint %s already exsits, removing "
,
cpMeta
.
UUID
)
}
f
,
err
:=
os
.
Create
(
cpMeta
.
UUID
)
defer
f
.
Close
()
if
err
!=
nil
{
return
err
}
writer
:=
bufio
.
NewWriter
(
f
)
_
,
err
=
writer
.
Write
(
buf
.
Bytes
())
writer
.
Flush
()
if
err
!=
nil
{
return
err
}
return
nil
return
nil
}
}
go/pserver/service_test.go
浏览文件 @
5a4f33df
...
@@ -15,7 +15,8 @@ const (
...
@@ -15,7 +15,8 @@ const (
)
)
func
TestServiceFull
(
t
*
testing
.
T
)
{
func
TestServiceFull
(
t
*
testing
.
T
)
{
s
,
err
:=
pserver
.
NewService
(
0
)
var
cp
pserver
.
Checkpoint
s
,
err
:=
pserver
.
NewService
(
0
,
1
,
""
,
nil
,
cp
)
if
err
!=
nil
{
if
err
!=
nil
{
t
.
Error
(
err
)
t
.
Error
(
err
)
}
}
...
@@ -86,7 +87,8 @@ func TestServiceFull(t *testing.T) {
...
@@ -86,7 +87,8 @@ func TestServiceFull(t *testing.T) {
}
}
func
TestMultipleInit
(
t
*
testing
.
T
)
{
func
TestMultipleInit
(
t
*
testing
.
T
)
{
s
,
err
:=
pserver
.
NewService
(
0
)
var
cp
pserver
.
Checkpoint
s
,
err
:=
pserver
.
NewService
(
0
,
1
,
""
,
nil
,
cp
)
if
err
!=
nil
{
if
err
!=
nil
{
t
.
Error
(
err
)
t
.
Error
(
err
)
}
}
...
@@ -102,7 +104,8 @@ func TestMultipleInit(t *testing.T) {
...
@@ -102,7 +104,8 @@ func TestMultipleInit(t *testing.T) {
}
}
func
TestUninitialized
(
t
*
testing
.
T
)
{
func
TestUninitialized
(
t
*
testing
.
T
)
{
s
,
err
:=
pserver
.
NewService
(
0
)
var
cp
pserver
.
Checkpoint
s
,
err
:=
pserver
.
NewService
(
0
,
1
,
""
,
nil
,
cp
)
err
=
s
.
SendGrad
(
pserver
.
Gradient
{},
nil
)
err
=
s
.
SendGrad
(
pserver
.
Gradient
{},
nil
)
if
err
.
Error
()
!=
pserver
.
Uninitialized
{
if
err
.
Error
()
!=
pserver
.
Uninitialized
{
t
.
FailNow
()
t
.
FailNow
()
...
@@ -110,7 +113,8 @@ func TestUninitialized(t *testing.T) {
...
@@ -110,7 +113,8 @@ func TestUninitialized(t *testing.T) {
}
}
func
TestBlockUntilInitialized
(
t
*
testing
.
T
)
{
func
TestBlockUntilInitialized
(
t
*
testing
.
T
)
{
s
,
err
:=
pserver
.
NewService
(
0
)
var
cp
pserver
.
Checkpoint
s
,
err
:=
pserver
.
NewService
(
0
,
1
,
""
,
nil
,
cp
)
if
err
!=
nil
{
if
err
!=
nil
{
t
.
Error
(
err
)
t
.
Error
(
err
)
}
}
...
@@ -128,16 +132,6 @@ func TestBlockUntilInitialized(t *testing.T) {
...
@@ -128,16 +132,6 @@ func TestBlockUntilInitialized(t *testing.T) {
ch
<-
struct
{}{}
ch
<-
struct
{}{}
}()
}()
wg
.
Add
(
1
)
go
func
()
{
err
:=
s
.
Save
(
""
,
nil
)
if
err
!=
nil
{
errCh
<-
err
}
wg
.
Done
()
ch
<-
struct
{}{}
}()
time
.
Sleep
(
50
*
time
.
Millisecond
)
time
.
Sleep
(
50
*
time
.
Millisecond
)
select
{
select
{
...
@@ -170,3 +164,7 @@ func TestBlockUntilInitialized(t *testing.T) {
...
@@ -170,3 +164,7 @@ func TestBlockUntilInitialized(t *testing.T) {
wg
.
Wait
()
wg
.
Wait
()
}
}
func
TestCheckpointSpeed
(
t
*
testing
.
T
)
{
//TODO(zhihong): test speed
}
paddle/optimizer/adadelta_optimizer.cc
浏览文件 @
5a4f33df
...
@@ -27,22 +27,24 @@ void AdadeltaOptimizer::Update(const Tensor* gradient) {
...
@@ -27,22 +27,24 @@ void AdadeltaOptimizer::Update(const Tensor* gradient) {
const
char
*
AdadeltaOptimizer
::
SerializeState
(
int
*
state_len
)
{
const
char
*
AdadeltaOptimizer
::
SerializeState
(
int
*
state_len
)
{
AdadeltaOptimizerState
state
;
AdadeltaOptimizerState
state
;
// TODO(zhihong) : add lr_policy serialization
state
.
set_num_sample_passed
(
num_sample_passed_
);
state
.
set_num_sample_passed
(
num_sample_passed_
);
std
::
string
lr_str
=
this
->
lr_policy_
->
SerializeState
(
state_len
);
state
.
mutable_lr_state
()
->
ParseFromString
(
lr_str
);
TensorToProto
(
*
parameter_
,
state
.
mutable_parameter
());
TensorToProto
(
*
parameter_
,
state
.
mutable_parameter
());
TensorToProto
(
*
accum_gradient_
,
state
.
mutable_accum_gradient
());
TensorToProto
(
*
accum_gradient_
,
state
.
mutable_accum_gradient
());
TensorToProto
(
*
accum_delta_
,
state
.
mutable_accum_delta
());
TensorToProto
(
*
accum_delta_
,
state
.
mutable_accum_delta
());
TensorToProto
(
*
update_delta_
,
state
.
mutable_update_delta
());
TensorToProto
(
*
update_delta_
,
state
.
mutable_update_delta
());
auto
str
=
state
.
SerializeAsString
();
auto
str
=
state
.
SerializeAsString
();
*
state_len
=
str
.
size
();
*
state_len
+
=
str
.
size
();
return
str
.
c_str
();
return
str
.
c_str
();
}
}
void
AdadeltaOptimizer
::
DeserializeState
(
const
std
::
string
&
str
)
{
void
AdadeltaOptimizer
::
DeserializeState
(
const
std
::
string
&
str
)
{
AdadeltaOptimizerState
state
;
AdadeltaOptimizerState
state
;
state
.
ParseFromString
(
str
);
state
.
ParseFromString
(
str
);
// TODO(zhihong) : add lr_policy DeserializeState
auto
lr_state
=
state
.
lr_state
();
this
->
lr_policy_
->
DeserializeState
(
lr_state
.
SerializeAsString
());
num_sample_passed_
=
state
.
num_sample_passed
();
num_sample_passed_
=
state
.
num_sample_passed
();
ProtoToTensor
(
state
.
parameter
(),
parameter_
);
ProtoToTensor
(
state
.
parameter
(),
parameter_
);
...
...
paddle/optimizer/adagrad_optimizer.cc
浏览文件 @
5a4f33df
...
@@ -19,20 +19,23 @@ void AdagradOptimizer::Update(const Tensor* gradient) {
...
@@ -19,20 +19,23 @@ void AdagradOptimizer::Update(const Tensor* gradient) {
}
}
const
char
*
AdagradOptimizer
::
SerializeState
(
int
*
state_len
)
{
const
char
*
AdagradOptimizer
::
SerializeState
(
int
*
state_len
)
{
AdagradOptimizerState
state
;
AdagradOptimizerState
state
;
// TODO(zhihong) : add lr_policy serialization
state
.
set_num_sample_passed
(
num_sample_passed_
);
state
.
set_num_sample_passed
(
num_sample_passed_
);
std
::
string
lr_str
=
this
->
lr_policy_
->
SerializeState
(
state_len
);
state
.
mutable_lr_state
()
->
ParseFromString
(
lr_str
);
TensorToProto
(
*
parameter_
,
state
.
mutable_parameter
());
TensorToProto
(
*
parameter_
,
state
.
mutable_parameter
());
TensorToProto
(
*
accum_gradient_
,
state
.
mutable_accum_gradient
());
TensorToProto
(
*
accum_gradient_
,
state
.
mutable_accum_gradient
());
auto
str
=
state
.
SerializeAsString
();
auto
str
=
state
.
SerializeAsString
();
*
state_len
=
str
.
size
();
*
state_len
+
=
str
.
size
();
return
str
.
c_str
();
return
str
.
c_str
();
}
}
void
AdagradOptimizer
::
DeserializeState
(
const
std
::
string
&
str
)
{
void
AdagradOptimizer
::
DeserializeState
(
const
std
::
string
&
str
)
{
AdagradOptimizerState
state
;
AdagradOptimizerState
state
;
state
.
ParseFromString
(
str
);
state
.
ParseFromString
(
str
);
// TODO(zhihong) : add lr_policy DeserializeState
auto
lr_state
=
state
.
lr_state
();
this
->
lr_policy_
->
DeserializeState
(
lr_state
.
SerializeAsString
());
num_sample_passed_
=
state
.
num_sample_passed
();
num_sample_passed_
=
state
.
num_sample_passed
();
ProtoToTensor
(
state
.
parameter
(),
parameter_
);
ProtoToTensor
(
state
.
parameter
(),
parameter_
);
ProtoToTensor
(
state
.
accum_gradient
(),
accum_gradient_
);
ProtoToTensor
(
state
.
accum_gradient
(),
accum_gradient_
);
...
...
paddle/optimizer/adam_optimizer.cc
浏览文件 @
5a4f33df
...
@@ -24,20 +24,23 @@ void AdamOptimizer::Update(const Tensor *gradient) {
...
@@ -24,20 +24,23 @@ void AdamOptimizer::Update(const Tensor *gradient) {
const
char
*
AdamOptimizer
::
SerializeState
(
int
*
state_len
)
{
const
char
*
AdamOptimizer
::
SerializeState
(
int
*
state_len
)
{
AdamOptimizerState
state
;
AdamOptimizerState
state
;
// TODO(zhihong) : add lr_policy serialization
std
::
string
lr_str
=
this
->
lr_policy_
->
SerializeState
(
state_len
);
state
.
mutable_lr_state
()
->
ParseFromString
(
lr_str
);
state
.
set_num_sample_passed
(
num_sample_passed_
);
state
.
set_num_sample_passed
(
num_sample_passed_
);
TensorToProto
(
*
parameter_
,
state
.
mutable_parameter
());
TensorToProto
(
*
parameter_
,
state
.
mutable_parameter
());
TensorToProto
(
*
momentums_
,
state
.
mutable_momentums
());
TensorToProto
(
*
momentums_
,
state
.
mutable_momentums
());
TensorToProto
(
*
velocitys_
,
state
.
mutable_velocitys
());
TensorToProto
(
*
velocitys_
,
state
.
mutable_velocitys
());
auto
str
=
state
.
SerializeAsString
();
auto
str
=
state
.
SerializeAsString
();
*
state_len
=
str
.
size
();
*
state_len
+
=
str
.
size
();
return
str
.
c_str
();
return
str
.
c_str
();
}
}
void
AdamOptimizer
::
DeserializeState
(
const
std
::
string
&
str
)
{
void
AdamOptimizer
::
DeserializeState
(
const
std
::
string
&
str
)
{
AdamOptimizerState
state
;
AdamOptimizerState
state
;
state
.
ParseFromString
(
str
);
state
.
ParseFromString
(
str
);
// TODO(zhihong) : add lr_policy DeserializeState
auto
lr_state
=
state
.
lr_state
();
this
->
lr_policy_
->
DeserializeState
(
lr_state
.
SerializeAsString
());
num_sample_passed_
=
state
.
num_sample_passed
();
num_sample_passed_
=
state
.
num_sample_passed
();
ProtoToTensor
(
state
.
parameter
(),
parameter_
);
ProtoToTensor
(
state
.
parameter
(),
parameter_
);
...
...
paddle/optimizer/lr_policy.h
浏览文件 @
5a4f33df
...
@@ -17,36 +17,56 @@ public:
...
@@ -17,36 +17,56 @@ public:
// constant learning rate policy
// constant learning rate policy
class
ConstLr
final
:
public
LrPolicy
{
class
ConstLr
final
:
public
LrPolicy
{
public:
public:
ConstLr
(
double
lr
)
:
learning_rate
(
lr
){};
ConstLr
(
double
lr
)
:
learning_rate
_
(
lr
){};
double
LearningRate
(
const
uint64_t
num_sample_passed
)
{
double
LearningRate
(
const
uint64_t
num_sample_passed
)
{
return
learning_rate
;
return
learning_rate_
;
}
const
char
*
SerializeState
(
int
*
state_len
)
{
LrPolicyState
state
;
state
.
set_learning_rate
(
learning_rate_
);
auto
str
=
state
.
SerializeAsString
();
*
state_len
=
str
.
size
();
return
str
.
c_str
();
}
void
DeserializeState
(
const
std
::
string
&
str
)
{
LrPolicyState
state
;
state
.
ParseFromString
(
str
);
learning_rate_
=
state
.
learning_rate
();
}
}
const
char
*
SerializeState
(
int
*
state_len
)
{
return
nullptr
;
}
void
DeserializeState
(
const
std
::
string
&
state
)
{}
private:
private:
double
learning_rate
;
double
learning_rate
_
;
};
};
class
LinearLr
final
:
public
LrPolicy
{
class
LinearLr
final
:
public
LrPolicy
{
public:
public:
LinearLr
(
double
lr
,
double
lr_decay_a
,
double
lr_decay_b
)
LinearLr
(
double
lr
,
double
lr_decay_a
,
double
lr_decay_b
)
:
learning_rate
(
lr
),
lr_decay_a
(
lr_decay_a
),
lr_decay_b
(
lr_decay_b
)
{}
:
learning_rate
_
(
lr
),
lr_decay_a_
(
lr_decay_a
),
lr_decay_b_
(
lr_decay_b
)
{}
double
LearningRate
(
const
uint64_t
num_sample_passed
)
{
double
LearningRate
(
const
uint64_t
num_sample_passed
)
{
return
std
::
max
(
learning_rate
-
lr_decay_a
*
num_sample_passed
,
lr_decay_b
);
return
std
::
max
(
learning_rate_
-
lr_decay_a_
*
num_sample_passed
,
lr_decay_b_
);
}
}
const
char
*
SerializeState
(
int
*
state_len
)
{
const
char
*
SerializeState
(
int
*
state_len
)
{
// TODO(zhihong) : add lr_policy serialization
LrPolicyState
state
;
return
nullptr
;
state
.
set_learning_rate
(
learning_rate_
);
state
.
set_lr_decay_a
(
lr_decay_a_
);
state
.
set_lr_decay_b
(
lr_decay_b_
);
auto
str
=
state
.
SerializeAsString
();
*
state_len
=
str
.
size
();
return
str
.
c_str
();
}
}
void
DeserializeState
(
const
std
::
string
&
state
)
{
void
DeserializeState
(
const
std
::
string
&
str
)
{
// TODO(zhihong) : add lr_policy serialization
LrPolicyState
state
;
state
.
ParseFromString
(
str
);
learning_rate_
=
state
.
learning_rate
();
lr_decay_a_
=
state
.
lr_decay_a
();
lr_decay_b_
=
state
.
lr_decay_b
();
}
}
private:
private:
double
learning_rate
;
double
learning_rate
_
;
double
lr_decay_a
;
double
lr_decay_a
_
;
double
lr_decay_b
;
double
lr_decay_b
_
;
};
};
}
// namespace optimizer
}
// namespace optimizer
...
...
paddle/optimizer/sgd_optimizer.cc
浏览文件 @
5a4f33df
...
@@ -30,16 +30,20 @@ void SGDOptimizer::Update(const Tensor *gradient) {
...
@@ -30,16 +30,20 @@ void SGDOptimizer::Update(const Tensor *gradient) {
const
char
*
SGDOptimizer
::
SerializeState
(
int
*
state_len
)
{
const
char
*
SGDOptimizer
::
SerializeState
(
int
*
state_len
)
{
SGDOptimizerState
state
;
SGDOptimizerState
state
;
state
.
set_num_sample_passed
(
num_sample_passed_
);
state
.
set_num_sample_passed
(
num_sample_passed_
);
std
::
string
lr_str
=
this
->
lr_policy_
->
SerializeState
(
state_len
);
state
.
mutable_lr_state
()
->
ParseFromString
(
lr_str
);
TensorToProto
(
*
parameter_
,
state
.
mutable_parameter
());
TensorToProto
(
*
parameter_
,
state
.
mutable_parameter
());
if
(
momentum_
!=
0.0
)
TensorToProto
(
*
momentums_
,
state
.
mutable_momentums
());
if
(
momentum_
!=
0.0
)
TensorToProto
(
*
momentums_
,
state
.
mutable_momentums
());
auto
str
=
state
.
SerializeAsString
();
auto
str
=
state
.
SerializeAsString
();
*
state_len
=
str
.
size
();
*
state_len
+
=
str
.
size
();
return
str
.
c_str
();
return
str
.
c_str
();
}
}
void
SGDOptimizer
::
DeserializeState
(
const
std
::
string
&
str
)
{
void
SGDOptimizer
::
DeserializeState
(
const
std
::
string
&
str
)
{
SGDOptimizerState
state
;
SGDOptimizerState
state
;
state
.
ParseFromString
(
str
);
state
.
ParseFromString
(
str
);
auto
lr_state
=
state
.
lr_state
();
this
->
lr_policy_
->
DeserializeState
(
lr_state
.
SerializeAsString
());
num_sample_passed_
=
state
.
num_sample_passed
();
num_sample_passed_
=
state
.
num_sample_passed
();
ProtoToTensor
(
state
.
parameter
(),
parameter_
);
ProtoToTensor
(
state
.
parameter
(),
parameter_
);
if
(
momentum_
!=
0.0
)
ProtoToTensor
(
state
.
parameter
(),
momentums_
);
if
(
momentum_
!=
0.0
)
ProtoToTensor
(
state
.
parameter
(),
momentums_
);
...
...
paddle/platform/CMakeLists.txt
浏览文件 @
5a4f33df
...
@@ -4,3 +4,5 @@ nv_test(cuda_test SRCS cuda_test.cu)
...
@@ -4,3 +4,5 @@ nv_test(cuda_test SRCS cuda_test.cu)
cc_library
(
place SRCS place.cc
)
cc_library
(
place SRCS place.cc
)
cc_test
(
place_test SRCS place_test.cc DEPS place glog gflags
)
cc_test
(
place_test SRCS place_test.cc DEPS place glog gflags
)
nv_test
(
device_context_test SRCS device_context_test.cc DEPS dynamic_loader place eigen3 glog gflags
)
paddle/platform/device_context.h
0 → 100644
浏览文件 @
5a4f33df
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/framework/enforce.h"
#ifndef PADDLE_ONLY_CPU
#include "paddle/platform/cuda.h"
#include "paddle/platform/dynload/cublas.h"
#include "paddle/platform/dynload/cudnn.h"
#include "paddle/platform/dynload/curand.h"
#define EIGEN_USE_GPU
#endif
#include "paddle/platform/place.h"
#include "unsupported/Eigen/CXX11/Tensor"
namespace
paddle
{
namespace
platform
{
class
DeviceContext
{
public:
virtual
~
DeviceContext
()
{}
};
class
CPUDeviceContext
:
public
DeviceContext
{};
#ifndef PADDLE_ONLY_CPU
class
GPUPlaceGuard
{
public:
explicit
GPUPlaceGuard
(
GPUPlace
new_place
)
:
previous_
(
GetCurrentDeviceId
())
{
if
(
previous_
!=
new_place
)
{
paddle
::
platform
::
SetDeviceId
(
new_place
.
device
);
}
}
~
GPUPlaceGuard
()
{
paddle
::
platform
::
SetDeviceId
(
previous_
.
device
);
}
private:
GPUPlace
previous_
;
};
class
CUDADeviceContext
:
public
DeviceContext
{
public:
explicit
CUDADeviceContext
(
const
GPUPlace
gpu_place
)
:
gpu_place_
(
gpu_place
)
{
GPUPlaceGuard
guard
(
gpu_place_
);
paddle
::
platform
::
throw_on_error
(
cudaStreamCreate
(
&
stream_
),
"cudaStreamCreate failed"
);
eigen_stream_
=
new
Eigen
::
CudaStreamDevice
(
&
stream_
);
eigen_device_
=
new
Eigen
::
GpuDevice
(
eigen_stream_
);
}
void
Wait
()
{
paddle
::
platform
::
throw_on_error
(
cudaStreamSynchronize
(
stream_
),
"cudaStreamSynchronize failed"
);
}
cudaStream_t
stream
()
{
return
stream_
;
}
Eigen
::
GpuDevice
eigen_device
()
{
return
*
eigen_device_
;
}
cublasHandle_t
cublas_handle
()
{
if
(
!
blas_handle_
)
{
GPUPlaceGuard
guard
(
gpu_place_
);
PADDLE_ENFORCE
(
paddle
::
platform
::
dynload
::
cublasCreate
(
&
blas_handle_
)
==
CUBLAS_STATUS_SUCCESS
,
"cublasCreate failed"
);
PADDLE_ENFORCE
(
paddle
::
platform
::
dynload
::
cublasSetStream
(
blas_handle_
,
stream_
)
==
CUBLAS_STATUS_SUCCESS
,
"cublasSetStream failed"
);
}
return
blas_handle_
;
}
cudnnHandle_t
cudnn_handle
()
{
if
(
!
dnn_handle_
)
{
GPUPlaceGuard
guard
(
gpu_place_
);
PADDLE_ENFORCE
(
paddle
::
platform
::
dynload
::
cudnnCreate
(
&
dnn_handle_
)
==
CUDNN_STATUS_SUCCESS
,
"cudnnCreate failed"
);
PADDLE_ENFORCE
(
paddle
::
platform
::
dynload
::
cudnnSetStream
(
dnn_handle_
,
stream_
)
==
CUDNN_STATUS_SUCCESS
,
"cudnnSetStream failed"
);
}
return
dnn_handle_
;
}
curandGenerator_t
curand_generator
()
{
if
(
!
rand_generator_
)
{
GPUPlaceGuard
guard
(
gpu_place_
);
PADDLE_ENFORCE
(
paddle
::
platform
::
dynload
::
curandCreateGenerator
(
&
rand_generator_
,
CURAND_RNG_PSEUDO_DEFAULT
)
==
CURAND_STATUS_SUCCESS
,
"curandCreateGenerator failed"
);
PADDLE_ENFORCE
(
paddle
::
platform
::
dynload
::
curandSetPseudoRandomGeneratorSeed
(
rand_generator_
,
random_seed_
)
==
CURAND_STATUS_SUCCESS
,
"curandSetPseudoRandomGeneratorSeed failed"
);
PADDLE_ENFORCE
(
paddle
::
platform
::
dynload
::
curandSetStream
(
rand_generator_
,
stream_
)
==
CURAND_STATUS_SUCCESS
,
"curandSetStream failed"
);
}
return
rand_generator_
;
}
~
CUDADeviceContext
()
{
Wait
();
if
(
blas_handle_
)
{
PADDLE_ENFORCE
(
paddle
::
platform
::
dynload
::
cublasDestroy
(
blas_handle_
)
==
CUBLAS_STATUS_SUCCESS
,
"cublasDestroy failed"
);
}
if
(
dnn_handle_
)
{
PADDLE_ENFORCE
(
paddle
::
platform
::
dynload
::
cudnnDestroy
(
dnn_handle_
)
==
CUDNN_STATUS_SUCCESS
,
"cudnnDestroy failed"
);
}
if
(
rand_generator_
)
{
PADDLE_ENFORCE
(
paddle
::
platform
::
dynload
::
curandDestroyGenerator
(
rand_generator_
)
==
CURAND_STATUS_SUCCESS
,
"curandDestroyGenerator failed"
);
}
delete
eigen_stream_
;
delete
eigen_device_
;
paddle
::
platform
::
throw_on_error
(
cudaStreamDestroy
(
stream_
),
"cudaStreamDestroy failed"
);
}
private:
GPUPlace
gpu_place_
;
cudaStream_t
stream_
;
Eigen
::
CudaStreamDevice
*
eigen_stream_
;
Eigen
::
GpuDevice
*
eigen_device_
;
cublasHandle_t
blas_handle_
{
nullptr
};
cudnnHandle_t
dnn_handle_
{
nullptr
};
int
random_seed_
;
curandGenerator_t
rand_generator_
{
nullptr
};
};
#endif
}
// namespace platform
}
// namespace paddle
paddle/platform/device_context_test.cc
0 → 100644
浏览文件 @
5a4f33df
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/platform/device_context.h"
#include "gtest/gtest.h"
TEST
(
CUDADeviceContext
,
Init
)
{
int
count
=
paddle
::
platform
::
GetDeviceCount
();
for
(
int
i
=
0
;
i
<
count
;
i
++
)
{
paddle
::
platform
::
CUDADeviceContext
*
device_context
=
new
paddle
::
platform
::
CUDADeviceContext
(
i
);
Eigen
::
GpuDevice
gpu_device
=
device_context
->
eigen_device
();
ASSERT_NE
(
nullptr
,
gpu_device
.
stream
());
cudnnHandle_t
cudnn_handle
=
device_context
->
cudnn_handle
();
ASSERT_NE
(
nullptr
,
cudnn_handle
);
cublasHandle_t
cublas_handle
=
device_context
->
cublas_handle
();
ASSERT_NE
(
nullptr
,
cublas_handle
);
curandGenerator_t
curand_handle
=
device_context
->
curand_generator
();
ASSERT_NE
(
nullptr
,
curand_handle
);
delete
device_context
;
}
}
proto/OptimizerConfig.proto
浏览文件 @
5a4f33df
...
@@ -78,11 +78,15 @@ enum DataType {
...
@@ -78,11 +78,15 @@ enum DataType {
repeated
bytes
content
=
2
;
repeated
bytes
content
=
2
;
}
}
message
LrPolicyState
{
// learninRate Policy
optional
double
learning_rate
=
1
[
default
=
1.0
];
optional
double
lr_decay_a
=
2
;
optional
double
lr_decay_b
=
3
;
}
message
SGDOptimizerState
{
message
SGDOptimizerState
{
// learning rate policy
optional
LrPolicyState
lr_state
=
101
;
optional
double
learning_rate
=
101
;
optional
double
lr_decay_a
=
102
;
optional
double
lr_decay_b
=
103
;
optional
double
num_sample_passed
=
104
;
optional
double
num_sample_passed
=
104
;
// state
// state
optional
TensorProto
parameter
=
1
;
optional
TensorProto
parameter
=
1
;
...
@@ -91,9 +95,7 @@ message SGDOptimizerState {
...
@@ -91,9 +95,7 @@ message SGDOptimizerState {
message
AdadeltaOptimizerState
{
message
AdadeltaOptimizerState
{
// learning rate policy
// learning rate policy
optional
double
learning_rate
=
101
;
optional
LrPolicyState
lr_state
=
101
;
optional
double
lr_decay_a
=
102
;
optional
double
lr_decay_b
=
103
;
optional
double
num_sample_passed
=
104
;
optional
double
num_sample_passed
=
104
;
// state
// state
optional
TensorProto
parameter
=
1
;
optional
TensorProto
parameter
=
1
;
...
@@ -102,11 +104,9 @@ message AdadeltaOptimizerState {
...
@@ -102,11 +104,9 @@ message AdadeltaOptimizerState {
optional
TensorProto
update_delta
=
4
;
optional
TensorProto
update_delta
=
4
;
}
}
message
AdagradOptimizerState
{
message
AdagradOptimizerState
{
// learning rate policy
optional
LrPolicyState
lr_state
=
101
;
optional
double
learning_rate
=
101
;
optional
double
lr_decay_a
=
102
;
optional
double
lr_decay_b
=
103
;
optional
double
num_sample_passed
=
104
;
optional
double
num_sample_passed
=
104
;
// state
// state
optional
TensorProto
parameter
=
1
;
optional
TensorProto
parameter
=
1
;
...
@@ -114,10 +114,7 @@ message AdagradOptimizerState {
...
@@ -114,10 +114,7 @@ message AdagradOptimizerState {
}
}
message
AdamOptimizerState
{
message
AdamOptimizerState
{
// learning rate policy
optional
LrPolicyState
lr_state
=
101
;
optional
double
learning_rate
=
101
;
optional
double
lr_decay_a
=
102
;
optional
double
lr_decay_b
=
103
;
optional
double
num_sample_passed
=
104
;
optional
double
num_sample_passed
=
104
;
// state
// state
optional
TensorProto
parameter
=
1
;
optional
TensorProto
parameter
=
1
;
...
...
python/paddle/trainer_config_helpers/layers.py
浏览文件 @
5a4f33df
...
@@ -1253,9 +1253,9 @@ def pooling_layer(input,
...
@@ -1253,9 +1253,9 @@ def pooling_layer(input,
If stride > 0, this layer slides a window whose size is determined by stride,
If stride > 0, this layer slides a window whose size is determined by stride,
and return the pooling value of the window as the output. Thus, a long sequence
and return the pooling value of the window as the output. Thus, a long sequence
will be shorten.
will be shorten.
The parameter stride specifies the intervals at which to apply the pooling
The parameter stride specifies the intervals at which to apply the pooling
operation. Note that for sequence with sub-sequence, the default value
operation. Note that for sequence with sub-sequence, the default value
of stride is -1.
of stride is -1.
...
@@ -4805,6 +4805,14 @@ def maxout_layer(input, groups, num_channels=None, name=None, layer_attr=None):
...
@@ -4805,6 +4805,14 @@ def maxout_layer(input, groups, num_channels=None, name=None, layer_attr=None):
So groups should be larger than 1, and the num of channels should be able
So groups should be larger than 1, and the num of channels should be able
to devided by groups.
to devided by groups.
.. math::
y_{si+j} = \max_k x_{gsi + sk + j}
g = groups
s = input.size / num_channels
0 \le i < num_channels / groups
0 \le j < s
0 \le k < groups
Please refer to Paper:
Please refer to Paper:
- Maxout Networks: http://www.jmlr.org/proceedings/papers/v28/goodfellow13.pdf
- Maxout Networks: http://www.jmlr.org/proceedings/papers/v28/goodfellow13.pdf
- Multi-digit Number Recognition from Street View
\
- Multi-digit Number Recognition from Street View
\
...
...
python/paddle/trainer_config_helpers/networks.py
浏览文件 @
5a4f33df
...
@@ -1395,7 +1395,7 @@ def inputs(layers, *args):
...
@@ -1395,7 +1395,7 @@ def inputs(layers, *args):
if
len
(
args
)
!=
0
:
if
len
(
args
)
!=
0
:
layers
.
extend
(
args
)
layers
.
extend
(
args
)
Inputs
(
*
[
l
.
name
for
l
in
layers
])
Inputs
(
*
[
l
.
name
for
l
in
layers
])
def
outputs
(
layers
,
*
args
):
def
outputs
(
layers
,
*
args
):
...
@@ -1438,7 +1438,7 @@ def outputs(layers, *args):
...
@@ -1438,7 +1438,7 @@ def outputs(layers, *args):
assert
len
(
layers
)
>
0
assert
len
(
layers
)
>
0
if
HasInputsSet
():
# input already set
if
HasInputsSet
():
# input already set
Outputs
(
*
[
l
.
name
for
l
in
layers
])
Outputs
(
*
[
l
.
name
for
l
in
layers
])
return
# just return outputs.
return
# just return outputs.
if
len
(
layers
)
!=
1
:
if
len
(
layers
)
!=
1
:
...
...
python/paddle/v2/dataset/wmt14.py
浏览文件 @
5a4f33df
...
@@ -32,9 +32,9 @@ MD5_DEV_TEST = '7d7897317ddd8ba0ae5c5fa7248d3ff5'
...
@@ -32,9 +32,9 @@ MD5_DEV_TEST = '7d7897317ddd8ba0ae5c5fa7248d3ff5'
# this is a small set of data for test. The original data is too large and will be add later.
# this is a small set of data for test. The original data is too large and will be add later.
URL_TRAIN
=
'http://paddlepaddle.cdn.bcebos.com/demo/wmt_shrinked_data/wmt14.tgz'
URL_TRAIN
=
'http://paddlepaddle.cdn.bcebos.com/demo/wmt_shrinked_data/wmt14.tgz'
MD5_TRAIN
=
'0791583d57d5beb693b9414c5b36798c'
MD5_TRAIN
=
'0791583d57d5beb693b9414c5b36798c'
#
this is the pretrained model, whose bleu =
26.92
#
BLEU of this trained model is
26.92
URL_MODEL
=
'http://paddlepaddle.bj.bcebos.com/demo/wmt_14/wmt14_model.tar.gz'
URL_MODEL
=
'http://paddlepaddle.bj.bcebos.com/demo/wmt_14/wmt14_model.tar.gz'
MD5_MODEL
=
'
4ce14a26607fb8a1cc23bcdedb1895e4
'
MD5_MODEL
=
'
0cb4a5366189b6acba876491c8724fa3
'
START
=
"<s>"
START
=
"<s>"
END
=
"<e>"
END
=
"<e>"
...
...
python/setup.py.in
浏览文件 @
5a4f33df
...
@@ -34,6 +34,6 @@ setup(name='paddle',
...
@@ -34,6 +34,6 @@ setup(name='paddle',
'': '${CMAKE_CURRENT_SOURCE_DIR}',
'': '${CMAKE_CURRENT_SOURCE_DIR}',
# The paddle.v2.framework.proto will be generated while compiling.
# The paddle.v2.framework.proto will be generated while compiling.
# So that package points to other directory.
# So that package points to other directory.
'paddle.v2.framework.proto': '${
CMAKE_BINARY_DIR
}/paddle/framework'
'paddle.v2.framework.proto': '${
PROJ_BINARY_ROOT
}/paddle/framework'
},
},
)
)
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录