Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
机器未来
Paddle
提交
0467cd2d
P
Paddle
项目概览
机器未来
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
0467cd2d
编写于
7月 25, 2017
作者:
Y
Yu Yang
浏览文件
操作
浏览文件
下载
差异文件
Merge branch 'develop' into feature/middle_level_net_api
上级
0ceeacbe
91689b6b
变更
33
隐藏空白更改
内联
并排
Showing
33 changed file
with
727 addition
and
164 deletion
+727
-164
doc/design/cluster_train/save_model.md
doc/design/cluster_train/save_model.md
+5
-4
doc/design/simple_op_design.md
doc/design/simple_op_design.md
+1
-0
doc/faq/index_cn.rst
doc/faq/index_cn.rst
+10
-0
go/cmd/pserver/pserver.go
go/cmd/pserver/pserver.go
+5
-1
go/master/c/client.go
go/master/c/client.go
+45
-19
go/master/client.go
go/master/client.go
+85
-9
go/master/client_test.go
go/master/client_test.go
+5
-3
go/master/etcd_client.go
go/master/etcd_client.go
+4
-4
go/master/service.go
go/master/service.go
+47
-7
go/pserver/client/c/cclient.go
go/pserver/client/c/cclient.go
+7
-14
go/pserver/client/c/test/test_cclient.c
go/pserver/client/c/test/test_cclient.c
+0
-4
go/pserver/client/client.go
go/pserver/client/client.go
+0
-26
go/pserver/service.go
go/pserver/service.go
+11
-3
paddle/framework/CMakeLists.txt
paddle/framework/CMakeLists.txt
+6
-3
paddle/framework/grad_op_creator.cc
paddle/framework/grad_op_creator.cc
+115
-0
paddle/framework/grad_op_creator.h
paddle/framework/grad_op_creator.h
+48
-0
paddle/framework/grad_op_creator_test.cc
paddle/framework/grad_op_creator_test.cc
+26
-0
paddle/framework/net.cc
paddle/framework/net.cc
+11
-1
paddle/framework/net.h
paddle/framework/net.h
+2
-0
paddle/framework/net_op_test.cc
paddle/framework/net_op_test.cc
+39
-8
paddle/framework/net_test.cc
paddle/framework/net_test.cc
+0
-24
paddle/framework/op_proto.proto
paddle/framework/op_proto.proto
+6
-0
paddle/framework/op_registry.h
paddle/framework/op_registry.h
+77
-21
paddle/framework/operator.h
paddle/framework/operator.h
+5
-0
paddle/operators/add_op.cc
paddle/operators/add_op.cc
+13
-0
paddle/operators/add_op_test.cc
paddle/operators/add_op_test.cc
+6
-1
paddle/operators/mul_op.cc
paddle/operators/mul_op.cc
+13
-0
paddle/operators/sigmoid_op.cc
paddle/operators/sigmoid_op.cc
+13
-0
paddle/operators/softmax_op.cc
paddle/operators/softmax_op.cc
+12
-0
python/paddle/v2/__init__.py
python/paddle/v2/__init__.py
+2
-0
python/paddle/v2/master/client.py
python/paddle/v2/master/client.py
+29
-6
python/paddle/v2/model.py
python/paddle/v2/model.py
+73
-0
python/paddle/v2/reader/creator.py
python/paddle/v2/reader/creator.py
+6
-6
未找到文件。
doc/design/cluster_train/save_model.md
浏览文件 @
0467cd2d
...
@@ -75,10 +75,11 @@ snapshot to a model will be a TODO for future.
...
@@ -75,10 +75,11 @@ snapshot to a model will be a TODO for future.
### Trainer Election
### Trainer Election
One trainer will be elected as the one to save the model. When using
One trainer will be elected as the one to save the model. When using
etcd, trainer ID is a randomly generated UUID, we will utilize etcd to
etcd, trainer ID is a randomly generated UUID, the trainer will
elect one trainer. When not using etcd, unique trainer IDs will be
contact the master server requesting to save the model, and find out
given by the administrator, the trainer whose ID is "0" is elected to
if itself is elected. When the master server is not used, unique
save the model.
trainer IDs will be given by the administrator, the trainer whose ID
is "0" is elected to save the model.
### Model Save Path
### Model Save Path
...
...
doc/design/simple_op_design.md
浏览文件 @
0467cd2d
...
@@ -49,6 +49,7 @@ message AttrProto {
...
@@ -49,6 +49,7 @@ message AttrProto {
message
VarProto
{
message
VarProto
{
required
string
name
=
1
;
required
string
name
=
1
;
required
string
comment
=
2
;
required
string
comment
=
2
;
required
bool
is_tensor
=
3
;
};
};
message
OpProto
{
message
OpProto
{
...
...
doc/faq/index_cn.rst
浏览文件 @
0467cd2d
...
@@ -311,3 +311,13 @@ Paddle二进制在运行时捕获了浮点数异常,只要出现浮点数异
...
@@ -311,3 +311,13 @@ Paddle二进制在运行时捕获了浮点数异常,只要出现浮点数异
* 训练数据有问题,导致参数收敛到了一些奇异的情况。或者输入数据尺度过大,有些特征的取值达到数百万,这时进行矩阵乘法运算就可能导致浮点数溢出。
* 训练数据有问题,导致参数收敛到了一些奇异的情况。或者输入数据尺度过大,有些特征的取值达到数百万,这时进行矩阵乘法运算就可能导致浮点数溢出。
主要的解决办法是减小学习律或者对数据进行归一化处理。
主要的解决办法是减小学习律或者对数据进行归一化处理。
15. 编译安装后执行 import paddle.v2 as paddle 报ImportError: No module named v2
------------------------------------------------------------------------
先查看一下是否曾经安装过paddle v1版本,有的话需要先卸载:
pip uninstall py_paddle paddle
然后安装paddle的python环境, 在build目录下执行
pip install python/dist/paddle*.whl && pip install ../paddle/dist/py_paddle*.whl
go/cmd/pserver/pserver.go
浏览文件 @
0467cd2d
...
@@ -59,7 +59,11 @@ func main() {
...
@@ -59,7 +59,11 @@ func main() {
cp
,
err
=
pserver
.
NewCheckpointFromFile
(
*
checkpointPath
,
idx
,
e
)
cp
,
err
=
pserver
.
NewCheckpointFromFile
(
*
checkpointPath
,
idx
,
e
)
if
err
!=
nil
{
if
err
!=
nil
{
log
.
Errorf
(
"Fetch checkpoint failed, %s"
,
err
)
if
err
==
pserver
.
ErrCheckpointNotFound
{
log
.
Infof
(
"Could not find the pserver checkpoint."
)
}
else
{
log
.
Errorf
(
"Fetch checkpoint failed, %s"
,
err
)
}
}
}
}
}
...
...
go/master/c/client.go
浏览文件 @
0467cd2d
...
@@ -22,6 +22,9 @@ package main
...
@@ -22,6 +22,9 @@ package main
#define PADDLE_MASTER_OK 0
#define PADDLE_MASTER_OK 0
#define PADDLE_MASTER_ERROR -1
#define PADDLE_MASTER_ERROR -1
#define PADDLE_SAVE_MODEL_OK 1
#define PADDLE_SAVE_MODEL_SKIP 0
typedef int paddle_master_client;
typedef int paddle_master_client;
*/
*/
import
"C"
import
"C"
...
@@ -33,7 +36,6 @@ import (
...
@@ -33,7 +36,6 @@ import (
"unsafe"
"unsafe"
"github.com/PaddlePaddle/Paddle/go/master"
"github.com/PaddlePaddle/Paddle/go/master"
"github.com/coreos/etcd/clientv3"
log
"github.com/sirupsen/logrus"
log
"github.com/sirupsen/logrus"
)
)
...
@@ -65,32 +67,32 @@ func remove(client C.paddle_master_client) *master.Client {
...
@@ -65,32 +67,32 @@ func remove(client C.paddle_master_client) *master.Client {
}
}
//export paddle_new_etcd_master_client
//export paddle_new_etcd_master_client
//
// bufSize is the record buffer size.
func
paddle_new_etcd_master_client
(
etcdEndpoints
*
C
.
char
,
timeout
int
,
bufSize
int
)
C
.
paddle_master_client
{
func
paddle_new_etcd_master_client
(
etcdEndpoints
*
C
.
char
,
timeout
int
,
bufSize
int
)
C
.
paddle_master_client
{
p
:=
C
.
GoString
(
etcdEndpoints
)
p
:=
C
.
GoString
(
etcdEndpoints
)
cli
,
err
:=
clientv3
.
New
(
clientv3
.
Config
{
endpoints
:=
strings
.
Split
(
p
,
","
)
Endpoints
:
strings
.
Split
(
p
,
","
),
c
,
err
:=
master
.
NewClient
(
DialTimeout
:
time
.
Second
*
time
.
Duration
(
timeout
),
master
.
WithEtcd
(
endpoints
,
time
.
Duration
(
timeout
)
*
time
.
Second
),
})
master
.
WithBuffer
(
bufSize
),
)
if
err
!=
nil
{
if
err
!=
nil
{
panic
(
err
)
panic
(
err
)
}
}
ch
:=
make
(
chan
string
,
1
)
a
,
err
:=
master
.
GetKey
(
cli
,
master
.
DefaultAddrPath
,
timeout
)
if
err
!=
nil
{
panic
(
err
)
}
ch
<-
a
go
master
.
WatchKey
(
cli
,
master
.
DefaultAddrPath
,
ch
)
c
:=
master
.
NewClient
(
ch
,
bufSize
)
return
add
(
c
)
return
add
(
c
)
}
}
//export paddle_new_master_client
//export paddle_new_master_client
//
// bufSize is the record buffer size.
func
paddle_new_master_client
(
addr
*
C
.
char
,
bufSize
int
)
C
.
paddle_master_client
{
func
paddle_new_master_client
(
addr
*
C
.
char
,
bufSize
int
)
C
.
paddle_master_client
{
a
:=
C
.
GoString
(
addr
)
a
:=
C
.
GoString
(
addr
)
ch
:=
make
(
chan
string
,
1
)
c
,
err
:=
master
.
NewClient
(
master
.
WithAddr
(
a
),
master
.
WithBuffer
(
bufSize
))
ch
<-
a
if
err
!=
nil
{
c
:=
master
.
NewClient
(
ch
,
bufSize
)
panic
(
err
)
}
return
add
(
c
)
return
add
(
c
)
}
}
...
@@ -117,9 +119,10 @@ func paddle_set_dataset(client C.paddle_master_client, path **C.char, size C.int
...
@@ -117,9 +119,10 @@ func paddle_set_dataset(client C.paddle_master_client, path **C.char, size C.int
return
C
.
PADDLE_MASTER_OK
return
C
.
PADDLE_MASTER_OK
}
}
// return value:
// paddle_next_record gets the nexts training record.
// 0:ok
//
// -1:error
// returns number of bytes of the records if success, -1 if failed.
//
//export paddle_next_record
//export paddle_next_record
func
paddle_next_record
(
client
C
.
paddle_master_client
,
record
**
C
.
uchar
)
C
.
int
{
func
paddle_next_record
(
client
C
.
paddle_master_client
,
record
**
C
.
uchar
)
C
.
int
{
c
:=
get
(
client
)
c
:=
get
(
client
)
...
@@ -143,6 +146,29 @@ func paddle_next_record(client C.paddle_master_client, record **C.uchar) C.int {
...
@@ -143,6 +146,29 @@ func paddle_next_record(client C.paddle_master_client, record **C.uchar) C.int {
return
C
.
int
(
size
)
return
C
.
int
(
size
)
}
}
// paddle_request_save_model requests the master server to approve the
// caller to save the model.
//
// returns 1 if the save the model request is approved, 0 if the
// request is rejected because other trainer is saving the model, -1
// if error happened.
//
//export paddle_request_save_model
func
paddle_request_save_model
(
client
C
.
paddle_master_client
,
trainerID
string
,
blockMS
int
)
C
.
int
{
c
:=
get
(
client
)
need
,
err
:=
c
.
RequestSaveModel
(
trainerID
,
time
.
Duration
(
blockMS
)
*
time
.
Millisecond
)
if
err
!=
nil
{
log
.
Errorln
(
err
)
return
C
.
PADDLE_MASTER_ERROR
}
if
need
{
return
C
.
PADDLE_SAVE_MODEL_OK
}
return
C
.
PADDLE_SAVE_MODEL_SKIP
}
//export mem_free
//export mem_free
func
mem_free
(
p
unsafe
.
Pointer
)
{
func
mem_free
(
p
unsafe
.
Pointer
)
{
// "free" may be a better name for this function, but doing so
// "free" may be a better name for this function, but doing so
...
...
go/master/client.go
浏览文件 @
0467cd2d
...
@@ -16,17 +16,20 @@ package master
...
@@ -16,17 +16,20 @@ package master
import
(
import
(
"os"
"os"
"sync"
"time"
"time"
"github.com/PaddlePaddle/Paddle/go/connection"
"github.com/PaddlePaddle/Paddle/go/connection"
"github.com/PaddlePaddle/recordio"
"github.com/PaddlePaddle/recordio"
"github.com/coreos/etcd/clientv3"
log
"github.com/sirupsen/logrus"
log
"github.com/sirupsen/logrus"
)
)
// Client is the client of the master server.
// Client is the client of the master server.
type
Client
struct
{
type
Client
struct
{
conn
*
connection
.
Conn
conn
*
connection
.
Conn
ch
chan
record
ch
chan
record
initChOnce
sync
.
Once
}
}
type
record
struct
{
type
record
struct
{
...
@@ -34,24 +37,83 @@ type record struct {
...
@@ -34,24 +37,83 @@ type record struct {
err
error
err
error
}
}
//
NewClient creates a new Client
.
//
WithBuffer sets the client to buffer the training record
.
//
//
// bufSize is the record buffer size. NextRecord will read from this
// bufSize is the record buffer size. NextRecord will read from this
// buffer.
// buffer.
func
NewClient
(
addrCh
<-
chan
string
,
bufSize
int
)
*
Client
{
func
WithBuffer
(
bufSize
int
)
func
(
*
Client
)
error
{
return
func
(
c
*
Client
)
error
{
if
bufSize
<=
0
{
return
nil
}
c
.
initChOnce
.
Do
(
func
()
{
c
.
ch
=
make
(
chan
record
,
bufSize
)
go
c
.
getRecords
()
})
return
nil
}
}
// WithAddr sets the client to use fixed master address.
func
WithAddr
(
addr
string
)
func
(
c
*
Client
)
error
{
return
func
(
c
*
Client
)
error
{
ch
:=
make
(
chan
string
,
1
)
ch
<-
addr
go
c
.
monitorMaster
(
ch
)
return
nil
}
}
// WithEtcd sets the client to use etcd for master discovery.
func
WithEtcd
(
endpoints
[]
string
,
timeout
time
.
Duration
)
func
(
*
Client
)
error
{
return
func
(
c
*
Client
)
error
{
cli
,
err
:=
clientv3
.
New
(
clientv3
.
Config
{
Endpoints
:
endpoints
,
DialTimeout
:
timeout
,
})
if
err
!=
nil
{
return
err
}
ch
:=
make
(
chan
string
,
1
)
a
,
err
:=
GetKey
(
cli
,
DefaultAddrPath
,
timeout
)
if
err
!=
nil
{
return
err
}
if
a
!=
""
{
// Master is registered, send to the master address
// channel.
ch
<-
a
}
go
watchKey
(
cli
,
DefaultAddrPath
,
ch
)
go
c
.
monitorMaster
(
ch
)
return
nil
}
}
// NewClient creates a new Client.
func
NewClient
(
opts
...
func
(
*
Client
)
error
)
(
*
Client
,
error
)
{
c
:=
&
Client
{}
c
:=
&
Client
{}
c
.
conn
=
connection
.
New
()
c
.
conn
=
connection
.
New
()
c
.
ch
=
make
(
chan
record
,
bufSize
)
go
c
.
monitorMaster
(
addrCh
)
for
_
,
opt
:=
range
opts
{
go
c
.
getRecords
()
err
:=
opt
(
c
)
return
c
if
err
!=
nil
{
return
nil
,
err
}
}
return
c
,
nil
}
}
func
(
c
*
Client
)
getRecords
()
{
func
(
c
*
Client
)
getRecords
()
{
for
{
for
{
t
,
err
:=
c
.
getTask
()
t
,
err
:=
c
.
getTask
()
if
err
!=
nil
{
if
err
!=
nil
{
// getTask call.
log
.
Errorf
(
"Get task failed, sleep 3 seconds and continue, %s"
,
err
)
log
.
Errorf
(
"Get task failed, sleep 3 seconds and continue, %s"
,
err
)
time
.
Sleep
(
3
*
time
.
Second
)
time
.
Sleep
(
3
*
time
.
Second
)
continue
continue
...
@@ -146,6 +208,20 @@ func (c *Client) taskFailed(meta TaskMeta) error {
...
@@ -146,6 +208,20 @@ func (c *Client) taskFailed(meta TaskMeta) error {
// NextRecord will block until the next record is available. It is
// NextRecord will block until the next record is available. It is
// thread-safe.
// thread-safe.
func
(
c
*
Client
)
NextRecord
()
([]
byte
,
error
)
{
func
(
c
*
Client
)
NextRecord
()
([]
byte
,
error
)
{
c
.
initChOnce
.
Do
(
func
()
{
// initialize with in case WithBuffer is not used.
c
.
ch
=
make
(
chan
record
,
0
)
go
c
.
getRecords
()
})
r
:=
<-
c
.
ch
r
:=
<-
c
.
ch
return
r
.
r
,
r
.
err
return
r
.
r
,
r
.
err
}
}
// RequestSaveModel requests the master server to approve the caller
// to save the model.
func
(
c
*
Client
)
RequestSaveModel
(
trainerID
string
,
blockDur
time
.
Duration
)
(
bool
,
error
)
{
var
need
bool
err
:=
c
.
conn
.
Call
(
"Service.RequestSaveModel"
,
SaveModelRequest
{
TrainerID
:
trainerID
,
BlockDur
:
blockDur
},
&
need
)
return
need
,
err
}
go/master/client_test.go
浏览文件 @
0467cd2d
...
@@ -87,9 +87,11 @@ func TestNextRecord(t *testing.T) {
...
@@ -87,9 +87,11 @@ func TestNextRecord(t *testing.T) {
panic
(
err
)
panic
(
err
)
}
}
curAddr
:=
make
(
chan
string
,
1
)
c
,
err
:=
master
.
NewClient
(
master
.
WithAddr
(
fmt
.
Sprintf
(
":%d"
,
p
)),
master
.
WithBuffer
(
10
))
curAddr
<-
fmt
.
Sprintf
(
":%d"
,
p
)
if
err
!=
nil
{
c
:=
master
.
NewClient
(
curAddr
,
10
)
panic
(
err
)
}
err
=
c
.
SetDataset
([]
string
{
path
})
err
=
c
.
SetDataset
([]
string
{
path
})
if
err
!=
nil
{
if
err
!=
nil
{
panic
(
err
)
panic
(
err
)
...
...
go/master/etcd_client.go
浏览文件 @
0467cd2d
...
@@ -158,8 +158,8 @@ func (e *EtcdClient) Load() ([]byte, error) {
...
@@ -158,8 +158,8 @@ func (e *EtcdClient) Load() ([]byte, error) {
}
}
// GetKey gets the value by the specify key.
// GetKey gets the value by the specify key.
func
GetKey
(
c
*
clientv3
.
Client
,
key
string
,
timeout
int
)
(
string
,
error
)
{
func
GetKey
(
c
*
clientv3
.
Client
,
key
string
,
timeout
time
.
Duration
)
(
string
,
error
)
{
ctx
,
cancel
:=
context
.
WithTimeout
(
context
.
Background
(),
time
.
Second
*
time
.
Duration
(
timeout
)
)
ctx
,
cancel
:=
context
.
WithTimeout
(
context
.
Background
(),
time
out
)
resp
,
err
:=
c
.
Get
(
ctx
,
key
)
resp
,
err
:=
c
.
Get
(
ctx
,
key
)
cancel
()
cancel
()
if
err
!=
nil
{
if
err
!=
nil
{
...
@@ -173,8 +173,8 @@ func GetKey(c *clientv3.Client, key string, timeout int) (string, error) {
...
@@ -173,8 +173,8 @@ func GetKey(c *clientv3.Client, key string, timeout int) (string, error) {
return
string
(
v
),
nil
return
string
(
v
),
nil
}
}
//
W
atchKey watches the specify key and send to valChan if there is some event.
//
w
atchKey watches the specify key and send to valChan if there is some event.
func
W
atchKey
(
c
*
clientv3
.
Client
,
key
string
,
valChan
chan
<-
string
)
{
func
w
atchKey
(
c
*
clientv3
.
Client
,
key
string
,
valChan
chan
<-
string
)
{
rch
:=
c
.
Watch
(
context
.
Background
(),
key
)
rch
:=
c
.
Watch
(
context
.
Background
(),
key
)
for
wresp
:=
range
rch
{
for
wresp
:=
range
rch
{
for
_
,
ev
:=
range
wresp
.
Events
{
for
_
,
ev
:=
range
wresp
.
Events
{
...
...
go/master/service.go
浏览文件 @
0467cd2d
...
@@ -78,9 +78,10 @@ type Service struct {
...
@@ -78,9 +78,10 @@ type Service struct {
ready
chan
struct
{}
ready
chan
struct
{}
store
Store
store
Store
mu
sync
.
Mutex
mu
sync
.
Mutex
initDone
bool
initDone
bool
taskQueues
taskQueues
taskQueues
taskQueues
savingTrainer
string
}
}
func
partition
(
chunks
[]
Chunk
,
chunksPerTask
int
)
[]
taskEntry
{
func
partition
(
chunks
[]
Chunk
,
chunksPerTask
int
)
[]
taskEntry
{
...
@@ -246,7 +247,7 @@ func readChunks(globPaths []string) ([]Chunk, error) {
...
@@ -246,7 +247,7 @@ func readChunks(globPaths []string) ([]Chunk, error) {
//
//
// SetDataset can be call multiple times. But only the first call will
// SetDataset can be call multiple times. But only the first call will
// be honored.
// be honored.
func
(
s
*
Service
)
SetDataset
(
globPaths
[]
string
,
dummy
*
int
)
error
{
func
(
s
*
Service
)
SetDataset
(
globPaths
[]
string
,
_
*
int
)
error
{
if
len
(
globPaths
)
==
0
{
if
len
(
globPaths
)
==
0
{
return
errors
.
New
(
"no dataset specified"
)
return
errors
.
New
(
"no dataset specified"
)
}
}
...
@@ -330,7 +331,7 @@ func (s *Service) logFields() log.Fields {
...
@@ -330,7 +331,7 @@ func (s *Service) logFields() log.Fields {
}
}
// GetTask gets a new task from the service.
// GetTask gets a new task from the service.
func
(
s
*
Service
)
GetTask
(
dummy
int
,
task
*
Task
)
error
{
func
(
s
*
Service
)
GetTask
(
_
int
,
task
*
Task
)
error
{
select
{
select
{
case
<-
s
.
ready
:
case
<-
s
.
ready
:
}
}
...
@@ -380,7 +381,7 @@ func (s *Service) GetTask(dummy int, task *Task) error {
...
@@ -380,7 +381,7 @@ func (s *Service) GetTask(dummy int, task *Task) error {
}
}
// TaskFinished tell the service that a task is finished.
// TaskFinished tell the service that a task is finished.
func
(
s
*
Service
)
TaskFinished
(
taskID
int
,
dummy
*
int
)
error
{
func
(
s
*
Service
)
TaskFinished
(
taskID
int
,
_
*
int
)
error
{
select
{
select
{
case
<-
s
.
ready
:
case
<-
s
.
ready
:
}
}
...
@@ -415,7 +416,7 @@ func (s *Service) TaskFinished(taskID int, dummy *int) error {
...
@@ -415,7 +416,7 @@ func (s *Service) TaskFinished(taskID int, dummy *int) error {
}
}
// TaskFailed tells the service that a task is failed.
// TaskFailed tells the service that a task is failed.
func
(
s
*
Service
)
TaskFailed
(
meta
TaskMeta
,
dummy
*
int
)
error
{
func
(
s
*
Service
)
TaskFailed
(
meta
TaskMeta
,
_
*
int
)
error
{
select
{
select
{
case
<-
s
.
ready
:
case
<-
s
.
ready
:
}
}
...
@@ -432,3 +433,42 @@ func (s *Service) TaskFailed(meta TaskMeta, dummy *int) error {
...
@@ -432,3 +433,42 @@ func (s *Service) TaskFailed(meta TaskMeta, dummy *int) error {
s
.
processFailedTask
(
t
,
meta
.
Epoch
)
s
.
processFailedTask
(
t
,
meta
.
Epoch
)
return
nil
return
nil
}
}
// SaveModelRequest is the request for saving model
type
SaveModelRequest
struct
{
TrainerID
string
BlockDur
time
.
Duration
}
// RequestSaveModel requests the master server to approve the caller
// to save the model.
func
(
s
*
Service
)
RequestSaveModel
(
req
SaveModelRequest
,
need
*
bool
)
error
{
s
.
mu
.
Lock
()
defer
s
.
mu
.
Unlock
()
if
req
.
TrainerID
==
""
{
return
errors
.
New
(
"trainer id is empty"
)
}
if
s
.
savingTrainer
==
""
{
*
need
=
true
}
else
{
if
req
.
TrainerID
==
s
.
savingTrainer
{
// save trainer asked to save model again
*
need
=
true
}
else
{
*
need
=
false
}
}
if
*
need
{
s
.
savingTrainer
=
req
.
TrainerID
time
.
AfterFunc
(
req
.
BlockDur
,
func
()
{
s
.
mu
.
Lock
()
s
.
savingTrainer
=
""
s
.
mu
.
Unlock
()
})
}
return
nil
}
go/pserver/client/c/cclient.go
浏览文件 @
0467cd2d
...
@@ -127,13 +127,19 @@ func paddle_pserver_client_release(client C.paddle_pserver_client) {
...
@@ -127,13 +127,19 @@ func paddle_pserver_client_release(client C.paddle_pserver_client) {
remove
(
client
)
remove
(
client
)
}
}
// paddle_begin_init_params tells trainer if it needs to init the
// parameters.
//
// returns 1 if the trainer needs to init the parameters. 0 if the
// trainer does not need to init the parameters.
//
//export paddle_begin_init_params
//export paddle_begin_init_params
func
paddle_begin_init_params
(
client
C
.
paddle_pserver_client
)
C
.
int
{
func
paddle_begin_init_params
(
client
C
.
paddle_pserver_client
)
C
.
int
{
c
:=
get
(
client
)
c
:=
get
(
client
)
if
selected
:=
c
.
BeginInitParams
();
selected
{
if
selected
:=
c
.
BeginInitParams
();
selected
{
return
1
return
1
}
}
return
C
.
PSERVER_OK
return
0
}
}
//export paddle_init_param
//export paddle_init_param
...
@@ -256,17 +262,4 @@ func paddle_get_params(client C.paddle_pserver_client, dst **C.paddle_parameter,
...
@@ -256,17 +262,4 @@ func paddle_get_params(client C.paddle_pserver_client, dst **C.paddle_parameter,
return
C
.
PSERVER_OK
return
C
.
PSERVER_OK
}
}
//export paddle_save_model
func
paddle_save_model
(
client
C
.
paddle_pserver_client
,
path
*
C
.
char
)
C
.
int
{
p
:=
C
.
GoString
(
path
)
c
:=
get
(
client
)
err
:=
c
.
Save
(
p
)
if
err
!=
nil
{
log
.
Errorln
(
err
)
return
C
.
PSERVER_ERROR
}
return
C
.
PSERVER_OK
}
func
main
()
{}
// Required but ignored
func
main
()
{}
// Required but ignored
go/pserver/client/c/test/test_cclient.c
浏览文件 @
0467cd2d
...
@@ -111,9 +111,5 @@ retry:
...
@@ -111,9 +111,5 @@ retry:
getParams
(
c
);
getParams
(
c
);
}
}
if
(
paddle_save_model
(
c
,
"/tmp/"
))
{
fail
();
}
return
0
;
return
0
;
}
}
go/pserver/client/client.go
浏览文件 @
0467cd2d
...
@@ -219,32 +219,6 @@ func (c *Client) GetParams(names []string) ([]pserver.Parameter, error) {
...
@@ -219,32 +219,6 @@ func (c *Client) GetParams(names []string) ([]pserver.Parameter, error) {
return
ps
,
nil
return
ps
,
nil
}
}
// Save indicates parameters to save the parameter to the given path.
func
(
c
*
Client
)
Save
(
path
string
)
error
{
errCh
:=
make
(
chan
error
,
len
(
c
.
pservers
))
for
_
,
p
:=
range
c
.
pservers
{
err
:=
p
.
Call
(
"Service.Save"
,
path
,
nil
)
errCh
<-
err
}
recv
:=
0
for
err
:=
range
errCh
{
if
err
!=
nil
{
return
err
}
recv
++
if
recv
==
len
(
c
.
pservers
)
{
break
}
}
// TODO(helin): there will be many files under path, need to
// merge them into a single file.
return
nil
}
func
strHash
(
s
string
)
uint32
{
func
strHash
(
s
string
)
uint32
{
h
:=
fnv
.
New32a
()
h
:=
fnv
.
New32a
()
_
,
_
=
h
.
Write
([]
byte
(
s
))
_
,
_
=
h
.
Write
([]
byte
(
s
))
...
...
go/pserver/service.go
浏览文件 @
0467cd2d
...
@@ -36,6 +36,10 @@ import (
...
@@ -36,6 +36,10 @@ import (
// ElementType is the type of elements of a Parameter.
// ElementType is the type of elements of a Parameter.
type
ElementType
int
type
ElementType
int
// ErrCheckpointNotFound indicates that the pserver checkpoint could
// not be found.
var
ErrCheckpointNotFound
=
errors
.
New
(
"checkpoint not found"
)
// RPC error message.
// RPC error message.
const
(
const
(
AlreadyInitialized
=
"pserver already initialized"
AlreadyInitialized
=
"pserver already initialized"
...
@@ -103,6 +107,10 @@ func NewCheckpointFromFile(cpPath string, idx int, e *EtcdClient) (Checkpoint, e
...
@@ -103,6 +107,10 @@ func NewCheckpointFromFile(cpPath string, idx int, e *EtcdClient) (Checkpoint, e
return
nil
,
err
return
nil
,
err
}
}
if
len
(
v
)
==
0
{
return
nil
,
ErrCheckpointNotFound
}
var
cpMeta
checkpointMeta
var
cpMeta
checkpointMeta
if
err
=
json
.
Unmarshal
(
v
,
&
cpMeta
);
err
!=
nil
{
if
err
=
json
.
Unmarshal
(
v
,
&
cpMeta
);
err
!=
nil
{
return
nil
,
err
return
nil
,
err
...
@@ -156,7 +164,7 @@ func NewService(idx int, interval time.Duration, path string, client *EtcdClient
...
@@ -156,7 +164,7 @@ func NewService(idx int, interval time.Duration, path string, client *EtcdClient
}
}
// InitParam initializes a parameter.
// InitParam initializes a parameter.
func
(
s
*
Service
)
InitParam
(
paramWithConfigs
ParameterWithConfig
,
dummy
*
int
)
error
{
func
(
s
*
Service
)
InitParam
(
paramWithConfigs
ParameterWithConfig
,
_
*
int
)
error
{
select
{
select
{
case
<-
s
.
initialized
:
case
<-
s
.
initialized
:
return
errors
.
New
(
AlreadyInitialized
)
return
errors
.
New
(
AlreadyInitialized
)
...
@@ -177,7 +185,7 @@ func (s *Service) InitParam(paramWithConfigs ParameterWithConfig, dummy *int) er
...
@@ -177,7 +185,7 @@ func (s *Service) InitParam(paramWithConfigs ParameterWithConfig, dummy *int) er
// FinishInitParams tells the parameter server that the parameter
// FinishInitParams tells the parameter server that the parameter
// initialization has finished.
// initialization has finished.
func
(
s
*
Service
)
FinishInitParams
(
dummy0
int
,
dummy1
*
int
)
error
{
func
(
s
*
Service
)
FinishInitParams
(
_
int
,
_
*
int
)
error
{
select
{
select
{
case
<-
s
.
initialized
:
case
<-
s
.
initialized
:
return
errors
.
New
(
AlreadyInitialized
)
return
errors
.
New
(
AlreadyInitialized
)
...
@@ -190,7 +198,7 @@ func (s *Service) FinishInitParams(dummy0 int, dummy1 *int) error {
...
@@ -190,7 +198,7 @@ func (s *Service) FinishInitParams(dummy0 int, dummy1 *int) error {
// SendGrad sends gradient to parameter servers for parameter
// SendGrad sends gradient to parameter servers for parameter
// optimization.
// optimization.
func
(
s
*
Service
)
SendGrad
(
g
Gradient
,
dummy
*
int
)
error
{
func
(
s
*
Service
)
SendGrad
(
g
Gradient
,
_
*
int
)
error
{
select
{
select
{
case
<-
s
.
initialized
:
case
<-
s
.
initialized
:
default
:
default
:
...
...
paddle/framework/CMakeLists.txt
浏览文件 @
0467cd2d
...
@@ -19,8 +19,10 @@ cc_test(op_desc_test SRCS op_desc_test.cc DEPS op_desc protobuf)
...
@@ -19,8 +19,10 @@ cc_test(op_desc_test SRCS op_desc_test.cc DEPS op_desc protobuf)
cc_library
(
operator SRCS operator.cc DEPS op_desc device_context tensor
)
cc_library
(
operator SRCS operator.cc DEPS op_desc device_context tensor
)
cc_test
(
operator_test SRCS operator_test.cc DEPS operator op_registry
)
cc_test
(
operator_test SRCS operator_test.cc DEPS operator op_registry
)
cc_library
(
op_registry SRCS op_registry.cc DEPS op_proto op_desc
)
cc_library
(
grad_op_creator SRCS grad_op_creator.cc DEPS op_proto operator
)
cc_test
(
op_registry_test SRCS op_registry_test.cc DEPS op_registry operator
)
cc_library
(
op_registry SRCS op_registry.cc DEPS op_desc grad_op_creator
)
cc_test
(
op_registry_test SRCS op_registry_test.cc DEPS op_registry
)
cc_test
(
grad_op_creator_test SRCS grad_op_creator_test.cc DEPS grad_op_creator op_registry add_op
)
py_proto_compile
(
framework_py_proto SRCS attr_type.proto op_proto.proto op_desc.proto
)
py_proto_compile
(
framework_py_proto SRCS attr_type.proto op_proto.proto op_desc.proto
)
# Generate an empty __init__.py to make framework_py_proto as a valid python module.
# Generate an empty __init__.py to make framework_py_proto as a valid python module.
...
@@ -28,5 +30,6 @@ add_custom_target(framework_py_proto_init ALL COMMAND ${CMAKE_COMMAND} -E touch
...
@@ -28,5 +30,6 @@ add_custom_target(framework_py_proto_init ALL COMMAND ${CMAKE_COMMAND} -E touch
add_dependencies
(
framework_py_proto framework_py_proto_init
)
add_dependencies
(
framework_py_proto framework_py_proto_init
)
proto_library
(
net_proto SRCS net_proto.proto DEPS op_proto
)
proto_library
(
net_proto SRCS net_proto.proto DEPS op_proto
)
# cc_library(net SRCS net.cc DEPS operator net_proto op_registry fc_op)
cc_library
(
net SRCS net.cc DEPS operator net_proto op_registry
)
cc_library
(
net SRCS net.cc DEPS operator net_proto op_registry
)
cc_test
(
net_op_test SRCS net_op_test.cc DEPS net
)
cc_test
(
net_op_test SRCS net_op_test.cc DEPS net
add_op mul_op sigmoid_op softmax_op fc_op
)
paddle/framework/grad_op_creator.cc
0 → 100644
浏览文件 @
0467cd2d
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/framework/grad_op_creator.h"
#include "paddle/framework/op_registry.h"
namespace
paddle
{
namespace
framework
{
OperatorBase
*
GradOpCreator
::
Create
()
{
BuildOpInOutArgList
();
OperatorBase
*
grad_op
=
OpRegistry
::
grad_creators
().
at
(
op_
->
type_
)();
CompleteGradOp
(
grad_op
);
return
grad_op
;
}
OpInOutArg
*
GradOpCreator
::
BuildArg
(
const
VarProto
&
var
,
const
VarIndexMap
&
var_map
,
const
std
::
vector
<
int
>&
format
,
InOutType
type
)
{
int
idx
=
var_map
.
at
(
var
.
name
());
int
begin_idx
=
format
.
empty
()
?
idx
:
format
.
at
(
idx
);
int
end_idx
=
format
.
empty
()
?
idx
+
1
:
format
.
at
(
idx
+
1
);
return
new
OpInOutArg
(
var
.
name
(),
type
,
!
var
.
ignore_gradient
(),
begin_idx
,
end_idx
);
}
void
GradOpCreator
::
BuildOpInOutArgList
()
{
const
OpProto
&
op_proto
=
OpRegistry
::
protos
().
at
(
op_
->
type_
);
const
auto
&
var_map
=
*
(
OpRegistry
::
VarIndexMaps
().
at
(
op_
->
type_
));
const
std
::
vector
<
int
>&
in_format
=
op_
->
attrs_
.
count
(
"input_format"
)
?
op_
->
GetAttr
<
std
::
vector
<
int
>>
(
"input_format"
)
:
std
::
vector
<
int
>
();
const
std
::
vector
<
int
>&
out_format
=
op_
->
attrs_
.
count
(
"output_format"
)
?
op_
->
GetAttr
<
std
::
vector
<
int
>>
(
"output_format"
)
:
std
::
vector
<
int
>
();
for
(
const
auto
&
var
:
op_proto
.
inputs
())
{
arg_list_
.
emplace_back
(
std
::
shared_ptr
<
OpInOutArg
>
(
BuildArg
(
var
,
var_map
,
in_format
,
IN
)));
}
for
(
const
auto
&
var
:
op_proto
.
outputs
())
{
arg_list_
.
emplace_back
(
std
::
shared_ptr
<
OpInOutArg
>
(
BuildArg
(
var
,
var_map
,
out_format
,
OUT
)));
}
}
void
GradOpCreator
::
AddArgIntoGradOp
(
const
OpInOutArg
*
arg
,
std
::
vector
<
std
::
string
>&
in_out
,
std
::
vector
<
int
>&
format
,
VarIndexMap
*
varmap
,
int
&
idx
,
bool
is_grad
)
const
{
std
::
string
var_name
=
arg
->
proto_name_
;
if
(
is_grad
)
{
var_name
+=
OperatorBase
::
GRAD_VAR_SUFFIX
();
}
(
*
varmap
)[
var_name
]
=
idx
++
;
size_t
pre_sz
=
in_out
.
size
();
auto
base_it
=
arg
->
type_
==
IN
?
op_
->
inputs_
.
begin
()
:
op_
->
outputs_
.
begin
();
std
::
copy
(
base_it
+
arg
->
begin_idx_
,
base_it
+
arg
->
end_idx_
,
std
::
back_inserter
(
in_out
));
if
(
is_grad
)
{
for
(
size_t
i
=
pre_sz
;
i
<
in_out
.
size
();
++
i
)
{
in_out
[
i
]
+=
OperatorBase
::
GRAD_VAR_SUFFIX
();
}
}
format
.
push_back
(
in_out
.
size
());
}
void
GradOpCreator
::
CompleteGradOp
(
OperatorBase
*
grad_op
)
const
{
grad_op
->
type_
=
op_
->
type_
+
"@GRAD"
;
// not necessary
grad_op
->
attrs_
=
op_
->
attrs_
;
grad_op
->
attrs_
.
erase
(
"input_format"
);
grad_op
->
attrs_
.
erase
(
"output_format"
);
VarIndexMap
*
grad_varmap
=
new
VarIndexMap
();
int
in_idx
=
0
;
int
out_idx
=
0
;
std
::
vector
<
int
>
in_format
({
0
});
std
::
vector
<
int
>
out_format
({
0
});
for
(
const
auto
&
arg
:
arg_list_
)
{
// op_'s inputs_ and outputs_
if
(
arg
->
needed_in_grad_
)
{
AddArgIntoGradOp
(
arg
.
get
(),
grad_op
->
inputs_
,
in_format
,
grad_varmap
,
in_idx
,
false
);
}
if
(
arg
->
type_
==
IN
)
{
// gradients of op_'s inputs_
AddArgIntoGradOp
(
arg
.
get
(),
grad_op
->
outputs_
,
out_format
,
grad_varmap
,
out_idx
,
true
);
}
else
{
// gradients of op_'s outputs_
AddArgIntoGradOp
(
arg
.
get
(),
grad_op
->
inputs_
,
in_format
,
grad_varmap
,
in_idx
,
true
);
}
}
grad_op
->
attrs_
[
"input_format"
]
=
in_format
;
grad_op
->
attrs_
[
"output_format"
]
=
out_format
;
grad_op
->
in_out_idxs_
.
reset
(
grad_varmap
);
}
}
// namespace framework
}
// namespace paddle
paddle/framework/grad_op_creator.h
0 → 100644
浏览文件 @
0467cd2d
#pragma once
#include "paddle/framework/op_proto.pb.h"
#include "paddle/framework/operator.h"
namespace
paddle
{
namespace
framework
{
class
OpRegistry
;
enum
InOutType
{
IN
,
OUT
};
struct
OpInOutArg
{
OpInOutArg
(
const
std
::
string
&
proto_name
,
const
InOutType
&
type
,
bool
needed_in_grad
,
size_t
begin_idx
,
size_t
end_idx
)
:
proto_name_
(
proto_name
),
type_
(
type
),
needed_in_grad_
(
needed_in_grad
),
begin_idx_
(
begin_idx
),
end_idx_
(
end_idx
)
{}
std
::
string
proto_name_
;
InOutType
type_
;
bool
needed_in_grad_
;
size_t
begin_idx_
;
size_t
end_idx_
;
};
class
GradOpCreator
{
using
VarIndexMap
=
std
::
unordered_map
<
std
::
string
,
int
>
;
public:
GradOpCreator
(
const
OperatorBase
*
op
)
:
op_
(
op
)
{}
OperatorBase
*
Create
();
private:
OpInOutArg
*
BuildArg
(
const
VarProto
&
var
,
const
VarIndexMap
&
var_map
,
const
std
::
vector
<
int
>&
format
,
InOutType
type
);
void
BuildOpInOutArgList
();
void
AddArgIntoGradOp
(
const
OpInOutArg
*
arg
,
std
::
vector
<
std
::
string
>&
in_out
,
std
::
vector
<
int
>&
format
,
VarIndexMap
*
varmap
,
int
&
idx
,
bool
is_grad
)
const
;
void
CompleteGradOp
(
OperatorBase
*
grad_op
)
const
;
const
OperatorBase
*
op_
;
std
::
vector
<
std
::
shared_ptr
<
OpInOutArg
>>
arg_list_
;
};
}
// namespace framework
}
// namespace paddle
paddle/framework/grad_op_creator_test.cc
0 → 100644
浏览文件 @
0467cd2d
#include "paddle/framework/grad_op_creator.h"
#include <gtest/gtest.h>
#include "paddle/framework/op_registry.h"
#include "paddle/framework/operator.h"
USE_OP
(
add_two
);
namespace
paddle
{
namespace
framework
{
TEST
(
GradOpCreator
,
AddTwo
)
{
std
::
shared_ptr
<
OperatorBase
>
add_op
(
OpRegistry
::
CreateOp
(
"add_two"
,
{
"x"
,
"y"
},
{
"out"
},
{}));
std
::
shared_ptr
<
OperatorBase
>
grad_add_op
=
OpRegistry
::
CreateGradOp
(
add_op
);
EXPECT_EQ
(
static_cast
<
int
>
(
grad_add_op
->
inputs_
.
size
()),
4
);
EXPECT_EQ
(
static_cast
<
int
>
(
grad_add_op
->
outputs_
.
size
()),
2
);
EXPECT_EQ
(
grad_add_op
->
Input
(
"X"
),
"x"
);
EXPECT_EQ
(
grad_add_op
->
Input
(
"Y"
),
"y"
);
EXPECT_EQ
(
grad_add_op
->
Input
(
"Out"
),
"out"
);
EXPECT_EQ
(
grad_add_op
->
Input
(
"Out@GRAD"
),
"out@GRAD"
);
EXPECT_EQ
(
grad_add_op
->
Output
(
"X@GRAD"
),
"x@GRAD"
);
EXPECT_EQ
(
grad_add_op
->
Output
(
"Y@GRAD"
),
"y@GRAD"
);
}
}
// namespace framework
}
// namespace paddle
\ No newline at end of file
paddle/framework/net.cc
浏览文件 @
0467cd2d
...
@@ -15,14 +15,24 @@
...
@@ -15,14 +15,24 @@
*/
*/
#include "paddle/framework/net.h"
#include "paddle/framework/net.h"
#include "paddle/framework/op_registry.h"
namespace
paddle
{
namespace
paddle
{
namespace
framework
{
namespace
framework
{
std
::
shared_ptr
<
PlainNet
>
AddBackwardOp
(
std
::
shared_ptr
<
PlainNet
>
ForwardOps
)
{
auto
grad_ops
=
std
::
make_shared
<
PlainNet
>
();
for
(
auto
&
op
:
ForwardOps
->
ops_
)
{
auto
op_grad
=
OpRegistry
::
CreateGradOp
(
op
);
grad_ops
->
AddOp
(
op_grad
);
}
grad_ops
->
CompleteAddOp
();
return
grad_ops
;
}
void
PlainNet
::
CompleteAddOp
(
bool
calc
)
{
void
PlainNet
::
CompleteAddOp
(
bool
calc
)
{
add_op_done_
=
true
;
add_op_done_
=
true
;
if
(
!
calc
)
return
;
if
(
!
calc
)
return
;
std
::
unordered_set
<
std
::
string
>
input_set
;
std
::
unordered_set
<
std
::
string
>
input_set
;
std
::
unordered_set
<
std
::
string
>
output_set
;
std
::
unordered_set
<
std
::
string
>
output_set
;
std
::
unordered_set
<
std
::
string
>
temp_output
;
std
::
unordered_set
<
std
::
string
>
temp_output
;
...
...
paddle/framework/net.h
浏览文件 @
0467cd2d
...
@@ -100,5 +100,7 @@ class PlainNet : public Net {
...
@@ -100,5 +100,7 @@ class PlainNet : public Net {
}
}
};
};
std
::
shared_ptr
<
PlainNet
>
AddBackwardOp
(
std
::
shared_ptr
<
PlainNet
>
ForwardOps
);
}
// namespace framework
}
// namespace framework
}
// namespace paddle
}
// namespace paddle
paddle/framework/net_op_test.cc
浏览文件 @
0467cd2d
...
@@ -3,17 +3,24 @@
...
@@ -3,17 +3,24 @@
#include <paddle/framework/op_registry.h>
#include <paddle/framework/op_registry.h>
#include <paddle/framework/operator.h>
#include <paddle/framework/operator.h>
namespace
pd
=
paddle
::
framework
;
USE_OP
(
add_two
);
USE_OP
(
mul
);
USE_OP
(
sigmoid
);
USE_OP
(
softmax
);
namespace
paddle
{
namespace
framework
{
static
int
infer_shape_cnt
=
0
;
static
int
infer_shape_cnt
=
0
;
static
int
run_cnt
=
0
;
static
int
run_cnt
=
0
;
class
TestOp
:
public
pd
::
OperatorBase
{
class
TestOp
:
public
OperatorBase
{
public:
public:
void
InferShape
(
const
std
::
shared_ptr
<
pd
::
Scope
>&
scope
)
const
override
{
void
InferShape
(
const
std
::
shared_ptr
<
framework
::
Scope
>&
scope
)
const
override
{
++
infer_shape_cnt
;
++
infer_shape_cnt
;
}
}
void
Run
(
const
std
::
shared_ptr
<
pd
::
Scope
>&
scope
,
void
Run
(
const
std
::
shared_ptr
<
framework
::
Scope
>&
scope
,
const
paddle
::
platform
::
DeviceContext
&
dev_ctx
)
const
override
{
const
paddle
::
platform
::
DeviceContext
&
dev_ctx
)
const
override
{
++
run_cnt
;
++
run_cnt
;
}
}
...
@@ -33,7 +40,7 @@ void AssertSameVectorWithoutOrder(const std::vector<T>& expected,
...
@@ -33,7 +40,7 @@ void AssertSameVectorWithoutOrder(const std::vector<T>& expected,
}
}
TEST
(
OpKernel
,
all
)
{
TEST
(
OpKernel
,
all
)
{
auto
net
=
std
::
make_shared
<
paddle
::
framework
::
PlainNet
>
();
auto
net
=
std
::
make_shared
<
PlainNet
>
();
ASSERT_NE
(
net
,
nullptr
);
ASSERT_NE
(
net
,
nullptr
);
auto
op1
=
std
::
make_shared
<
TestOp
>
();
auto
op1
=
std
::
make_shared
<
TestOp
>
();
...
@@ -55,13 +62,37 @@ TEST(OpKernel, all) {
...
@@ -55,13 +62,37 @@ TEST(OpKernel, all) {
ASSERT_EQ
(
1UL
,
tmp_idx
.
size
());
ASSERT_EQ
(
1UL
,
tmp_idx
.
size
());
ASSERT_EQ
(
"y"
,
net
->
outputs_
[
tmp_idx
[
0
]]);
ASSERT_EQ
(
"y"
,
net
->
outputs_
[
tmp_idx
[
0
]]);
auto
scope
=
std
::
make_shared
<
pd
::
Scope
>
();
auto
scope
=
std
::
make_shared
<
Scope
>
();
p
addle
::
p
latform
::
CPUDeviceContext
dev_ctx
;
platform
::
CPUDeviceContext
dev_ctx
;
net
->
InferShape
(
scope
);
net
->
InferShape
(
scope
);
net
->
Run
(
scope
,
dev_ctx
);
net
->
Run
(
scope
,
dev_ctx
);
ASSERT_EQ
(
2
,
infer_shape_cnt
);
ASSERT_EQ
(
2
,
infer_shape_cnt
);
ASSERT_EQ
(
2
,
run_cnt
);
ASSERT_EQ
(
2
,
run_cnt
);
ASSERT_THROW
(
net
->
AddOp
(
op2
),
std
::
runtime_error
);
ASSERT_THROW
(
net
->
AddOp
(
op2
),
std
::
runtime_error
);
}
}
TEST
(
AddBackwardOp
,
TestGradOp
)
{
auto
net
=
std
::
make_shared
<
PlainNet
>
();
ASSERT_NE
(
net
,
nullptr
);
net
->
AddOp
(
framework
::
OpRegistry
::
CreateOp
(
"mul"
,
{
"X"
,
"Y"
},
{
"Out"
},
{}));
net
->
AddOp
(
framework
::
OpRegistry
::
CreateOp
(
"add_two"
,
{
"X"
,
"Y"
},
{
"Out"
},
{}));
net
->
AddOp
(
framework
::
OpRegistry
::
CreateOp
(
"add_two"
,
{
"X"
,
"Y"
},
{
""
},
{}));
auto
grad_ops
=
AddBackwardOp
(
net
);
for
(
auto
&
op
:
grad_ops
->
ops_
)
{
op
->
DebugString
();
}
}
// TODO(zhihong): add fc grad without registering.
// TEST(AddBackwardOp, TestNoGradOp) {
// auto net = std::make_shared<PlainNet>();
// ASSERT_NE(net, nullptr);
// net->AddOp(framework::OpRegistry::CreateOp("fc", {"X", "W", "b"}, {"Y"},
// {})); auto grad_ops = AddBackwardOp(net); for (auto& op : grad_ops->ops_) {
// op->DebugString();
// }
// }
}
// namespace framework
}
// namespace paddle
paddle/framework/net_test.cc
已删除
100644 → 0
浏览文件 @
0ceeacbe
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/framework/net.h"
#include "paddle/framework/op_registry.h"
#include <gtest/gtest.h>
namespace
paddle
{
namespace
framework
{
class
FakeFC
:
public
Operator
{}
}
// namespace framework
}
// namespace paddle
paddle/framework/op_proto.proto
浏览文件 @
0467cd2d
...
@@ -84,6 +84,11 @@ message VarProto {
...
@@ -84,6 +84,11 @@ message VarProto {
// "temporary_index": [1]
// "temporary_index": [1]
// }
// }
optional
bool
temporary
=
4
[
default
=
false
];
optional
bool
temporary
=
4
[
default
=
false
];
// The gradient of operator can be ignored immediately
// e.g. operator AddOp, y = x1 + x2, the gradient of dy/dx1, dy/dx2
// can be ignored for the future optimized on graph.
optional
bool
ignore_gradient
=
6
;
}
}
// Op protocol message for 3rd-party language binding.
// Op protocol message for 3rd-party language binding.
...
@@ -105,4 +110,5 @@ message OpProto {
...
@@ -105,4 +110,5 @@ message OpProto {
// The type of that Op.
// The type of that Op.
required
string
type
=
5
;
required
string
type
=
5
;
}
}
paddle/framework/op_registry.h
浏览文件 @
0467cd2d
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#pragma once
#include <algorithm>
#include <algorithm>
...
@@ -6,9 +20,9 @@
...
@@ -6,9 +20,9 @@
#include <unordered_map>
#include <unordered_map>
#include <unordered_set>
#include <unordered_set>
#include "paddle/framework/attr_checker.h"
#include "paddle/framework/attr_checker.h"
#include "paddle/framework/grad_op_creator.h"
#include "paddle/framework/op_desc.pb.h"
#include "paddle/framework/op_desc.pb.h"
#include "paddle/framework/op_proto.pb.h"
#include "paddle/framework/scope.h"
#include "paddle/framework/operator.h"
namespace
paddle
{
namespace
paddle
{
namespace
framework
{
namespace
framework
{
...
@@ -73,25 +87,29 @@ class OpProtoAndCheckerMaker {
...
@@ -73,25 +87,29 @@ class OpProtoAndCheckerMaker {
protected:
protected:
void
AddInput
(
const
std
::
string
&
name
,
const
std
::
string
&
comment
,
void
AddInput
(
const
std
::
string
&
name
,
const
std
::
string
&
comment
,
bool
multiple
=
false
)
{
bool
multiple
=
false
,
bool
ignore_gradient
=
false
)
{
auto
input
=
proto_
->
mutable_inputs
()
->
Add
();
auto
input
=
proto_
->
mutable_inputs
()
->
Add
();
*
input
->
mutable_name
()
=
name
;
*
input
->
mutable_name
()
=
name
;
*
input
->
mutable_comment
()
=
comment
;
*
input
->
mutable_comment
()
=
comment
;
input
->
set_ignore_gradient
(
ignore_gradient
);
input
->
set_multiple
(
multiple
);
input
->
set_multiple
(
multiple
);
if
(
multiple
)
{
if
(
multiple
)
{
SetHasMultipleInput
();
SetHasMultipleInput
();
}
}
}
}
void
AddInputs
(
const
std
::
string
&
name
,
const
std
::
string
&
comment
)
{
void
AddInputs
(
const
std
::
string
&
name
,
const
std
::
string
&
comment
,
AddInput
(
name
,
comment
,
true
);
bool
ignore_gradient
=
false
)
{
AddInput
(
name
,
comment
,
true
,
ignore_gradient
);
}
}
void
AddOutput
(
const
std
::
string
&
name
,
const
std
::
string
&
comment
,
void
AddOutput
(
const
std
::
string
&
name
,
const
std
::
string
&
comment
,
bool
temporary
=
false
,
bool
multiple
=
false
)
{
bool
temporary
=
false
,
bool
multiple
=
false
,
bool
ignore_gradient
=
false
)
{
auto
output
=
proto_
->
mutable_outputs
()
->
Add
();
auto
output
=
proto_
->
mutable_outputs
()
->
Add
();
*
output
->
mutable_name
()
=
name
;
*
output
->
mutable_name
()
=
name
;
*
output
->
mutable_comment
()
=
comment
;
*
output
->
mutable_comment
()
=
comment
;
output
->
set_ignore_gradient
(
ignore_gradient
);
output
->
set_multiple
(
multiple
);
output
->
set_multiple
(
multiple
);
if
(
multiple
)
{
if
(
multiple
)
{
SetHasMultipleOutput
();
SetHasMultipleOutput
();
...
@@ -103,8 +121,8 @@ class OpProtoAndCheckerMaker {
...
@@ -103,8 +121,8 @@ class OpProtoAndCheckerMaker {
}
}
void
AddOutputs
(
const
std
::
string
&
name
,
const
std
::
string
&
comment
,
void
AddOutputs
(
const
std
::
string
&
name
,
const
std
::
string
&
comment
,
bool
temporary
=
false
)
{
bool
temporary
=
false
,
bool
ignore_gradient
=
false
)
{
AddOutput
(
name
,
comment
,
temporary
,
true
);
AddOutput
(
name
,
comment
,
temporary
,
true
,
ignore_gradient
);
}
}
template
<
typename
T
>
template
<
typename
T
>
...
@@ -205,8 +223,8 @@ class OpRegistry {
...
@@ -205,8 +223,8 @@ class OpRegistry {
template
<
typename
OpType
,
typename
ProtoMakerType
>
template
<
typename
OpType
,
typename
ProtoMakerType
>
static
void
RegisterOp
(
const
std
::
string
&
op_type
)
{
static
void
RegisterOp
(
const
std
::
string
&
op_type
)
{
creators
()[
op_type
]
=
[]
{
return
new
OpType
;
};
creators
()[
op_type
]
=
[]
{
return
new
OpType
;
};
OpProto
&
op_proto
=
protos
()[
op_type
];
OpAttrChecker
&
op_checker
=
op_checkers
()[
op_type
];
OpAttrChecker
&
op_checker
=
op_checkers
()[
op_type
];
OpProto
&
op_proto
=
protos
()[
op_type
];
auto
maker
=
ProtoMakerType
(
&
op_proto
,
&
op_checker
);
auto
maker
=
ProtoMakerType
(
&
op_proto
,
&
op_checker
);
maker
.
Validate
();
maker
.
Validate
();
*
op_proto
.
mutable_type
()
=
op_type
;
*
op_proto
.
mutable_type
()
=
op_type
;
...
@@ -227,18 +245,24 @@ class OpRegistry {
...
@@ -227,18 +245,24 @@ class OpRegistry {
}
}
}
}
template
<
typename
OpType
>
static
void
RegisterGradOp
(
const
std
::
string
&
op_type
)
{
grad_creators
()[
op_type
]
=
[]
{
return
new
OpType
;
};
}
static
std
::
shared_ptr
<
OperatorBase
>
CreateOp
(
const
std
::
string
&
type
,
static
std
::
shared_ptr
<
OperatorBase
>
CreateOp
(
const
std
::
string
&
type
,
const
VarNameList
&
inputs
,
const
VarNameList
&
inputs
,
const
VarNameList
&
outputs
,
const
VarNameList
&
outputs
,
const
AttributeMap
&
attrs
)
{
const
AttributeMap
&
attrs
)
{
auto
op_create_it
=
creators
().
find
(
type
);
auto
op_create_it
=
creators
().
find
(
type
);
PADDLE_ENFORCE
(
op_create_it
!=
creators
().
end
(),
PADDLE_ENFORCE
(
op_create_it
!=
creators
().
end
(),
"Operator %s cannot be found"
,
type
);
"Operator %s cannot be found
.
"
,
type
);
auto
op
=
op_create_it
->
second
();
auto
op
=
op_create_it
->
second
();
op
->
type_
=
type
;
op
->
type_
=
type
;
op
->
inputs_
=
inputs
;
op
->
inputs_
=
inputs
;
op
->
outputs_
=
outputs
;
op
->
outputs_
=
outputs
;
op
->
attrs_
=
attrs
;
op
->
attrs_
=
attrs
;
op_checkers
().
at
(
type
).
Check
(
op
->
attrs_
);
op_checkers
().
at
(
type
).
Check
(
op
->
attrs_
);
...
@@ -274,18 +298,41 @@ class OpRegistry {
...
@@ -274,18 +298,41 @@ class OpRegistry {
return
CreateOp
(
op_desc
.
type
(),
inputs
,
outputs
,
attrs
);
return
CreateOp
(
op_desc
.
type
(),
inputs
,
outputs
,
attrs
);
}
}
static
std
::
shared_ptr
<
OperatorBase
>
CreateGradOp
(
std
::
shared_ptr
<
OperatorBase
>
op
)
{
GradOpCreator
creator
(
op
.
get
());
std
::
shared_ptr
<
OperatorBase
>
grad_op
(
creator
.
Create
());
grad_op
->
Init
();
return
grad_op
;
}
static
std
::
unordered_map
<
std
::
string
,
OpProto
>&
protos
()
{
static
std
::
unordered_map
<
std
::
string
,
OpProto
>&
protos
()
{
static
std
::
unordered_map
<
std
::
string
,
OpProto
>
protos_
;
static
std
::
unordered_map
<
std
::
string
,
OpProto
>
protos_
;
return
protos_
;
return
protos_
;
};
};
private:
static
std
::
unordered_map
<
std
::
string
,
OpCreator
>&
grad_creators
()
{
static
std
::
unordered_map
<
std
::
string
,
OpCreator
>
grad_creators_
;
return
grad_creators_
;
}
static
std
::
unordered_map
<
std
::
string
,
std
::
shared_ptr
<
VarIndexMap
>>&
static
std
::
unordered_map
<
std
::
string
,
std
::
shared_ptr
<
VarIndexMap
>>&
VarIndexMaps
()
{
VarIndexMaps
()
{
static
std
::
unordered_map
<
std
::
string
,
std
::
shared_ptr
<
VarIndexMap
>>
maps_
;
static
std
::
unordered_map
<
std
::
string
,
std
::
shared_ptr
<
VarIndexMap
>>
maps_
;
return
maps_
;
return
maps_
;
}
}
private:
static
std
::
unordered_map
<
std
::
string
,
OpCreator
>&
creators
()
{
static
std
::
unordered_map
<
std
::
string
,
OpCreator
>
creators_
;
return
creators_
;
}
static
std
::
unordered_map
<
std
::
string
,
OpAttrChecker
>&
op_checkers
()
{
static
std
::
unordered_map
<
std
::
string
,
OpAttrChecker
>
op_checkers_
;
return
op_checkers_
;
};
static
void
GenerateTempVariableName
(
OperatorBase
*
op
)
{
static
void
GenerateTempVariableName
(
OperatorBase
*
op
)
{
static
std
::
atomic
<
size_t
>
gUniqId
(
0UL
);
static
std
::
atomic
<
size_t
>
gUniqId
(
0UL
);
for
(
auto
&
outname
:
op
->
outputs_
)
{
for
(
auto
&
outname
:
op
->
outputs_
)
{
...
@@ -296,16 +343,6 @@ class OpRegistry {
...
@@ -296,16 +343,6 @@ class OpRegistry {
}
}
}
}
}
}
static
std
::
unordered_map
<
std
::
string
,
OpCreator
>&
creators
()
{
static
std
::
unordered_map
<
std
::
string
,
OpCreator
>
creators_
;
return
creators_
;
}
static
std
::
unordered_map
<
std
::
string
,
OpAttrChecker
>&
op_checkers
()
{
static
std
::
unordered_map
<
std
::
string
,
OpAttrChecker
>
op_checkers_
;
return
op_checkers_
;
};
};
};
template
<
typename
OpType
,
typename
ProtoMakerType
>
template
<
typename
OpType
,
typename
ProtoMakerType
>
...
@@ -316,6 +353,14 @@ class OpRegisterHelper {
...
@@ -316,6 +353,14 @@ class OpRegisterHelper {
}
}
};
};
template
<
typename
OpType
>
class
GradOpRegisterHelper
{
public:
GradOpRegisterHelper
(
const
char
*
op_type
)
{
OpRegistry
::
RegisterGradOp
<
OpType
>
(
op_type
);
}
};
/**
/**
* check if MACRO is used in GLOBAL NAMESPACE.
* check if MACRO is used in GLOBAL NAMESPACE.
*/
*/
...
@@ -335,6 +380,17 @@ class OpRegisterHelper {
...
@@ -335,6 +380,17 @@ class OpRegisterHelper {
__op_register_##__op_type##__(#__op_type); \
__op_register_##__op_type##__(#__op_type); \
int __op_register_##__op_type##_handle__() { return 0; }
int __op_register_##__op_type##_handle__() { return 0; }
/**
* Macro to Register Gradient Operator.
*/
#define REGISTER_GRADIENT_OP(__op_type, __op_class) \
STATIC_ASSERT_GLOBAL_NAMESPACE( \
__reg_gradient_op__##__op_type, \
"REGISTER_GRADIENT_OP must be in global namespace"); \
static ::paddle::framework::GradOpRegisterHelper<__op_class> \
__op_gradient_register_##__op_type##__(#__op_type); \
int __op_gradient_register_##__op_type##_handle__() { return 0; }
/**
/**
* Macro to Register OperatorKernel.
* Macro to Register OperatorKernel.
*/
*/
...
...
paddle/framework/operator.h
浏览文件 @
0467cd2d
...
@@ -62,6 +62,11 @@ class OperatorBase {
...
@@ -62,6 +62,11 @@ class OperatorBase {
/// but it will be convert to a unique name in scope after OpCreator.
/// but it will be convert to a unique name in scope after OpCreator.
static
std
::
string
TMP_VAR_NAME
()
{
return
"@TEMP@"
;
}
static
std
::
string
TMP_VAR_NAME
()
{
return
"@TEMP@"
;
}
/// If a variable's name has a certain suffix, it means that the
/// variable is the gradient of another varibale.
/// e.g. Variable "x@GRAD" is the gradient of varibale "x".
static
std
::
string
GRAD_VAR_SUFFIX
()
{
return
"@GRAD"
;
}
virtual
~
OperatorBase
()
{}
virtual
~
OperatorBase
()
{}
template
<
typename
T
>
template
<
typename
T
>
...
...
paddle/operators/add_op.cc
浏览文件 @
0467cd2d
...
@@ -49,9 +49,22 @@ The equation is: Out = X + Y
...
@@ -49,9 +49,22 @@ The equation is: Out = X + Y
)DOC"
);
)DOC"
);
}
}
};
};
class
AddOpGrad
:
public
framework
::
OperatorWithKernel
{
protected:
void
InferShape
(
const
std
::
vector
<
const
framework
::
Tensor
*>
&
inputs
,
const
std
::
vector
<
framework
::
Tensor
*>
&
outputs
)
const
override
{}
std
::
string
DebugString
()
const
override
{
LOG
(
INFO
)
<<
"AddOpGrad"
;
return
""
;
}
};
}
// namespace operators
}
// namespace operators
}
// namespace paddle
}
// namespace paddle
REGISTER_OP
(
add_two
,
paddle
::
operators
::
AddOp
,
paddle
::
operators
::
AddOpMaker
);
REGISTER_OP
(
add_two
,
paddle
::
operators
::
AddOp
,
paddle
::
operators
::
AddOpMaker
);
REGISTER_GRADIENT_OP
(
add_two
,
paddle
::
operators
::
AddOpGrad
);
REGISTER_OP_CPU_KERNEL
(
REGISTER_OP_CPU_KERNEL
(
add_two
,
paddle
::
operators
::
AddKernel
<
paddle
::
platform
::
CPUPlace
,
float
>
);
add_two
,
paddle
::
operators
::
AddKernel
<
paddle
::
platform
::
CPUPlace
,
float
>
);
paddle/operators/add_op_test.cc
浏览文件 @
0467cd2d
...
@@ -16,8 +16,13 @@ limitations under the License. */
...
@@ -16,8 +16,13 @@ limitations under the License. */
#define private public
#define private public
#include <paddle/framework/op_registry.h>
#include <paddle/framework/op_registry.h>
USE_OP
(
add_two
);
USE_OP
(
add_two
);
// USE_OP(add_two_grad);
TEST
(
AddOp
,
GetOpProto
)
{
TEST
(
AddOp
,
GetOpProto
)
{
auto
&
protos
=
paddle
::
framework
::
OpRegistry
::
protos
();
auto
&
protos
=
paddle
::
framework
::
OpRegistry
::
protos
();
auto
it
=
protos
.
find
(
"add_two"
);
auto
it
=
protos
.
find
(
"add_two"
);
ASSERT_NE
(
it
,
protos
.
end
());
ASSERT_NE
(
it
,
protos
.
end
());
}
auto
&
grad_creators
=
paddle
::
framework
::
OpRegistry
::
grad_creators
();
\ No newline at end of file
auto
it1
=
grad_creators
.
find
(
"add_two"
);
ASSERT_NE
(
it1
,
grad_creators
.
end
());
}
paddle/operators/mul_op.cc
浏览文件 @
0467cd2d
...
@@ -52,9 +52,22 @@ The equation is: Out = X * Y
...
@@ -52,9 +52,22 @@ The equation is: Out = X * Y
}
}
};
};
class
MulOpGrad
:
public
framework
::
OperatorWithKernel
{
protected:
void
InferShape
(
const
std
::
vector
<
const
framework
::
Tensor
*>
&
inputs
,
const
std
::
vector
<
framework
::
Tensor
*>
&
outputs
)
const
override
{}
std
::
string
DebugString
()
const
override
{
LOG
(
INFO
)
<<
"MulGrad"
;
return
""
;
}
};
}
// namespace operators
}
// namespace operators
}
// namespace paddle
}
// namespace paddle
REGISTER_OP
(
mul
,
paddle
::
operators
::
MulOp
,
paddle
::
operators
::
MulOpMaker
);
REGISTER_OP
(
mul
,
paddle
::
operators
::
MulOp
,
paddle
::
operators
::
MulOpMaker
);
REGISTER_GRADIENT_OP
(
mul
,
paddle
::
operators
::
MulOpGrad
);
REGISTER_OP_CPU_KERNEL
(
REGISTER_OP_CPU_KERNEL
(
mul
,
paddle
::
operators
::
MulKernel
<
paddle
::
platform
::
CPUPlace
,
float
>
);
mul
,
paddle
::
operators
::
MulKernel
<
paddle
::
platform
::
CPUPlace
,
float
>
);
paddle/operators/sigmoid_op.cc
浏览文件 @
0467cd2d
...
@@ -39,12 +39,25 @@ public:
...
@@ -39,12 +39,25 @@ public:
}
}
};
};
class
SigmoidOpGrad
:
public
framework
::
OperatorWithKernel
{
protected:
void
InferShape
(
const
std
::
vector
<
const
framework
::
Tensor
*>
&
inputs
,
const
std
::
vector
<
framework
::
Tensor
*>
&
outputs
)
const
override
{}
std
::
string
DebugString
()
const
override
{
LOG
(
INFO
)
<<
"SigmoidGrad"
;
return
""
;
}
};
}
// namespace operators
}
// namespace operators
}
// namespace paddle
}
// namespace paddle
REGISTER_OP
(
sigmoid
,
REGISTER_OP
(
sigmoid
,
paddle
::
operators
::
SigmoidOp
,
paddle
::
operators
::
SigmoidOp
,
paddle
::
operators
::
SigmoidOpMaker
);
paddle
::
operators
::
SigmoidOpMaker
);
REGISTER_GRADIENT_OP
(
sigmoid
,
paddle
::
operators
::
SigmoidOpGrad
);
REGISTER_OP_CPU_KERNEL
(
REGISTER_OP_CPU_KERNEL
(
sigmoid
,
sigmoid
,
paddle
::
operators
::
SigmoidKernel
<
paddle
::
platform
::
CPUPlace
,
float
>
);
paddle
::
operators
::
SigmoidKernel
<
paddle
::
platform
::
CPUPlace
,
float
>
);
paddle/operators/softmax_op.cc
浏览文件 @
0467cd2d
...
@@ -42,11 +42,23 @@ public:
...
@@ -42,11 +42,23 @@ public:
}
}
};
};
class
SoftmaxOpGrad
:
public
framework
::
OperatorWithKernel
{
protected:
void
InferShape
(
const
std
::
vector
<
const
framework
::
Tensor
*>
&
inputs
,
const
std
::
vector
<
framework
::
Tensor
*>
&
outputs
)
const
override
{}
std
::
string
DebugString
()
const
override
{
LOG
(
INFO
)
<<
"SoftmaxOpGrad"
;
return
""
;
}
};
}
// namespace operators
}
// namespace operators
}
// namespace paddle
}
// namespace paddle
namespace
ops
=
paddle
::
operators
;
namespace
ops
=
paddle
::
operators
;
REGISTER_OP
(
softmax
,
ops
::
SoftmaxOp
,
ops
::
SoftmaxOpMaker
);
REGISTER_OP
(
softmax
,
ops
::
SoftmaxOp
,
ops
::
SoftmaxOpMaker
);
REGISTER_GRADIENT_OP
(
softmax
,
paddle
::
operators
::
SoftmaxOpGrad
);
REGISTER_OP_CPU_KERNEL
(
softmax
,
REGISTER_OP_CPU_KERNEL
(
softmax
,
ops
::
SoftmaxKernel
<
paddle
::
platform
::
CPUPlace
,
float
>
);
ops
::
SoftmaxKernel
<
paddle
::
platform
::
CPUPlace
,
float
>
);
python/paddle/v2/__init__.py
浏览文件 @
0467cd2d
...
@@ -33,6 +33,7 @@ import networks
...
@@ -33,6 +33,7 @@ import networks
import
minibatch
import
minibatch
import
plot
import
plot
import
image
import
image
import
model
__all__
=
[
__all__
=
[
'optimizer'
,
'optimizer'
,
...
@@ -54,6 +55,7 @@ __all__ = [
...
@@ -54,6 +55,7 @@ __all__ = [
'evaluator'
,
'evaluator'
,
'image'
,
'image'
,
'master'
,
'master'
,
'model'
,
]
]
...
...
python/paddle/v2/master/client.py
浏览文件 @
0467cd2d
...
@@ -10,11 +10,31 @@ class client(object):
...
@@ -10,11 +10,31 @@ class client(object):
client is a client to the master server.
client is a client to the master server.
"""
"""
def
__init__
(
self
,
etcd_endpoints
,
timeout
,
buf_size
):
def
__init__
(
self
,
etcd_endpoints
,
timeout
_sec
,
buf_size
=
0
):
self
.
c
=
lib
.
paddle_new_etcd_master_client
(
etcd_endpoints
,
timeout
,
self
.
c
=
lib
.
paddle_new_etcd_master_client
(
etcd_endpoints
,
timeout
_sec
,
buf_size
)
buf_size
)
def
close
(
self
):
def
request_save_model
(
self
,
trainer_id
,
block_ms
):
"""request to save model
Conventionally the 0-th trainer will save model. But in
distributed training, any trainer could be killed. This
function asks the master server if the trainer should proceed
with saving model.
:param trainer_id: trainer id.
:param block_ms: number of millisecond that other save model
will be blocked if this save model request succeeded.
Returns:
int: 1 if the save the model request is approved, 0 if
does the request is rejected because other trainer is
saving the model, -1 if error happened.
"""
return
lib
.
paddle_request_save_model
(
self
.
c
,
trainer_id
,
block_ms
)
def
release
(
self
):
lib
.
paddle_release_master_client
(
self
.
c
)
lib
.
paddle_release_master_client
(
self
.
c
)
self
.
c
=
None
self
.
c
=
None
...
@@ -27,10 +47,13 @@ class client(object):
...
@@ -27,10 +47,13 @@ class client(object):
holder
[
idx
]
=
c_ptr
holder
[
idx
]
=
c_ptr
lib
.
paddle_set_dataset
(
self
.
c
,
holder
,
len
(
paths
))
lib
.
paddle_set_dataset
(
self
.
c
,
holder
,
len
(
paths
))
# return format: (record, errno)
# errno = 0: ok
# < 0: error
def
next_record
(
self
):
def
next_record
(
self
):
"""gets next record for training
Returns:
string: the record.
int: error code, 0 if successful, < 0 otherwise.
"""
p
=
ctypes
.
c_char_p
()
p
=
ctypes
.
c_char_p
()
ret
=
ctypes
.
pointer
(
p
)
ret
=
ctypes
.
pointer
(
p
)
size
=
lib
.
paddle_next_record
(
self
.
c
,
ret
)
size
=
lib
.
paddle_next_record
(
self
.
c
,
ret
)
...
...
python/paddle/v2/model.py
0 → 100644
浏览文件 @
0467cd2d
# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
os
import
errno
import
uuid
import
paddle.v2.master
__all__
=
[
"save_model"
,
"load_model"
]
trainer_id
=
str
(
uuid
.
uuid4
())
def
mkdir_p
(
path
):
try
:
os
.
makedirs
(
path
)
except
OSError
as
exc
:
if
exc
.
errno
==
errno
.
EEXIST
and
os
.
path
.
isdir
(
path
):
pass
else
:
raise
def
save_model
(
parameters
,
path
):
need_request
=
"KUBERNETES_SERVICE_HOST"
in
os
.
environ
.
keys
()
if
need_request
:
# TODO(helin): figure out how MPI trains, since MPI only save
# model when trainer_id == "0", we can consolidate the logic
# here.
# TODO(helin): change this environment variable name from
# MASTER_IP to ETCD_IP
etcd_name
=
"MASTER_IP"
if
etcd_name
not
in
os
.
environ
.
keys
():
raise
Exception
(
'not find '
+
etcd_name
+
' in environment variable.'
)
etcd_ip
=
os
.
environ
.
get
(
etcd_name
)
client
=
master
.
client
(
"http://"
+
etcd_ip
+
":2379"
,
5
,
0
)
r
=
client
.
request_save_model
(
trainer_id
,
5000
)
if
r
==
0
:
# do not need to save
return
elif
r
<
0
:
# error
return
else
:
# save model
path
=
os
.
path
.
join
(
path
,
trainer_id
)
path
=
os
.
path
.
join
(
path
,
"model.tar"
)
mkdir_p
(
path
)
with
open
(
path
,
'wb'
)
as
f
:
parameters
.
to_tar
(
f
)
def
load_model
(
parameters
,
path
):
with
open
(
path
,
'rb'
)
as
f
:
parameters
.
from_tar
(
f
)
python/paddle/v2/reader/creator.py
浏览文件 @
0467cd2d
...
@@ -12,8 +12,8 @@
...
@@ -12,8 +12,8 @@
# See the License for the specific language governing permissions and
# See the License for the specific language governing permissions and
# limitations under the License.
# limitations under the License.
"""
"""
Creator package contains some simple reader creator, which could
be used in user
Creator package contains some simple reader creator, which could
program.
be used in user
program.
"""
"""
__all__
=
[
'np_array'
,
'text_file'
,
"recordio"
]
__all__
=
[
'np_array'
,
'text_file'
,
"recordio"
]
...
@@ -59,7 +59,7 @@ def text_file(path):
...
@@ -59,7 +59,7 @@ def text_file(path):
def
recordio_local
(
paths
,
buf_size
=
100
):
def
recordio_local
(
paths
,
buf_size
=
100
):
"""
"""
Creates a data reader from given RecordIO file paths separated by ",",
Creates a data reader from given RecordIO file paths separated by ",",
glob pattern is supported.
glob pattern is supported.
:path: path of recordio files.
:path: path of recordio files.
:returns: data reader of recordio files.
:returns: data reader of recordio files.
...
@@ -83,7 +83,7 @@ def recordio_local(paths, buf_size=100):
...
@@ -83,7 +83,7 @@ def recordio_local(paths, buf_size=100):
def
recordio
(
paths
,
buf_size
=
100
):
def
recordio
(
paths
,
buf_size
=
100
):
"""
"""
Creates a data reader that outputs record one one by one
Creates a data reader that outputs record one one by one
from given local or cloud recordio path.
from given local or cloud recordio path.
:path: path of recordio files.
:path: path of recordio files.
:returns: data reader of recordio files.
:returns: data reader of recordio files.
...
@@ -96,7 +96,7 @@ def recordio(paths, buf_size=100):
...
@@ -96,7 +96,7 @@ def recordio(paths, buf_size=100):
host_name
=
"MASTER_SERVICE_HOST"
host_name
=
"MASTER_SERVICE_HOST"
if
host_name
not
in
os
.
environ
.
keys
():
if
host_name
not
in
os
.
environ
.
keys
():
raise
Exception
(
'not find '
+
host_name
+
' in environ.'
)
raise
Exception
(
'not find '
+
host_name
+
' in environ
ment variable
.'
)
addr
=
os
.
environ
(
host
)
addr
=
os
.
environ
(
host
)
...
@@ -110,6 +110,6 @@ def recordio(paths, buf_size=100):
...
@@ -110,6 +110,6 @@ def recordio(paths, buf_size=100):
break
break
yield
r
yield
r
c
.
clo
se
()
c
.
relea
se
()
return
reader
return
reader
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录