Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
s920243400
PaddleDetection
提交
d263ccef
P
PaddleDetection
项目概览
s920243400
/
PaddleDetection
与 Fork 源项目一致
Fork自
PaddlePaddle / PaddleDetection
通知
2
Star
0
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleDetection
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
d263ccef
编写于
7月 25, 2017
作者:
D
dongzhihong
浏览文件
操作
浏览文件
下载
差异文件
Merge remote-tracking branch 'origin/develop' into random_op
上级
6f80b5f1
0d26a158
变更
38
隐藏空白更改
内联
并排
Showing
38 changed file
with
802 addition
and
223 deletion
+802
-223
doc/design/cluster_train/save_model.md
doc/design/cluster_train/save_model.md
+5
-4
doc/design/simple_op_design.md
doc/design/simple_op_design.md
+1
-0
doc/faq/index_cn.rst
doc/faq/index_cn.rst
+10
-0
go/cmd/pserver/pserver.go
go/cmd/pserver/pserver.go
+5
-1
go/master/c/client.go
go/master/c/client.go
+45
-19
go/master/client.go
go/master/client.go
+85
-9
go/master/client_test.go
go/master/client_test.go
+5
-3
go/master/etcd_client.go
go/master/etcd_client.go
+4
-4
go/master/service.go
go/master/service.go
+47
-7
go/pserver/client/c/cclient.go
go/pserver/client/c/cclient.go
+7
-14
go/pserver/client/c/test/test_cclient.c
go/pserver/client/c/test/test_cclient.c
+0
-4
go/pserver/client/client.go
go/pserver/client/client.go
+0
-26
go/pserver/service.go
go/pserver/service.go
+11
-3
paddle/framework/CMakeLists.txt
paddle/framework/CMakeLists.txt
+6
-3
paddle/framework/grad_op_builder.cc
paddle/framework/grad_op_builder.cc
+116
-0
paddle/framework/grad_op_builder.h
paddle/framework/grad_op_builder.h
+48
-0
paddle/framework/grad_op_builder_test.cc
paddle/framework/grad_op_builder_test.cc
+26
-0
paddle/framework/net.cc
paddle/framework/net.cc
+11
-1
paddle/framework/net.h
paddle/framework/net.h
+7
-5
paddle/framework/net_op_test.cc
paddle/framework/net_op_test.cc
+39
-8
paddle/framework/net_test.cc
paddle/framework/net_test.cc
+0
-24
paddle/framework/op_proto.proto
paddle/framework/op_proto.proto
+6
-0
paddle/framework/op_registry.h
paddle/framework/op_registry.h
+91
-30
paddle/framework/op_registry_test.cc
paddle/framework/op_registry_test.cc
+10
-14
paddle/framework/operator.h
paddle/framework/operator.h
+8
-4
paddle/framework/operator_test.cc
paddle/framework/operator_test.cc
+5
-7
paddle/framework/scope.h
paddle/framework/scope.h
+2
-3
paddle/operators/add_op.cc
paddle/operators/add_op.cc
+13
-0
paddle/operators/add_op_test.cc
paddle/operators/add_op_test.cc
+6
-1
paddle/operators/mul_op.cc
paddle/operators/mul_op.cc
+13
-0
paddle/operators/sigmoid_op.cc
paddle/operators/sigmoid_op.cc
+13
-0
paddle/operators/softmax_op.cc
paddle/operators/softmax_op.cc
+12
-0
paddle/pybind/pybind.cc
paddle/pybind/pybind.cc
+3
-2
python/paddle/trainer_config_helpers/networks.py
python/paddle/trainer_config_helpers/networks.py
+32
-15
python/paddle/v2/__init__.py
python/paddle/v2/__init__.py
+2
-0
python/paddle/v2/master/client.py
python/paddle/v2/master/client.py
+29
-6
python/paddle/v2/model.py
python/paddle/v2/model.py
+73
-0
python/paddle/v2/reader/creator.py
python/paddle/v2/reader/creator.py
+6
-6
未找到文件。
doc/design/cluster_train/save_model.md
浏览文件 @
d263ccef
...
...
@@ -75,10 +75,11 @@ snapshot to a model will be a TODO for future.
### Trainer Election
One trainer will be elected as the one to save the model. When using
etcd, trainer ID is a randomly generated UUID, we will utilize etcd to
elect one trainer. When not using etcd, unique trainer IDs will be
given by the administrator, the trainer whose ID is "0" is elected to
save the model.
etcd, trainer ID is a randomly generated UUID, the trainer will
contact the master server requesting to save the model, and find out
if itself is elected. When the master server is not used, unique
trainer IDs will be given by the administrator, the trainer whose ID
is "0" is elected to save the model.
### Model Save Path
...
...
doc/design/simple_op_design.md
浏览文件 @
d263ccef
...
...
@@ -49,6 +49,7 @@ message AttrProto {
message
VarProto
{
required
string
name
=
1
;
required
string
comment
=
2
;
required
bool
is_tensor
=
3
;
};
message
OpProto
{
...
...
doc/faq/index_cn.rst
浏览文件 @
d263ccef
...
...
@@ -311,3 +311,13 @@ Paddle二进制在运行时捕获了浮点数异常,只要出现浮点数异
* 训练数据有问题,导致参数收敛到了一些奇异的情况。或者输入数据尺度过大,有些特征的取值达到数百万,这时进行矩阵乘法运算就可能导致浮点数溢出。
主要的解决办法是减小学习律或者对数据进行归一化处理。
15. 编译安装后执行 import paddle.v2 as paddle 报ImportError: No module named v2
------------------------------------------------------------------------
先查看一下是否曾经安装过paddle v1版本,有的话需要先卸载:
pip uninstall py_paddle paddle
然后安装paddle的python环境, 在build目录下执行
pip install python/dist/paddle*.whl && pip install ../paddle/dist/py_paddle*.whl
go/cmd/pserver/pserver.go
浏览文件 @
d263ccef
...
...
@@ -59,7 +59,11 @@ func main() {
cp
,
err
=
pserver
.
NewCheckpointFromFile
(
*
checkpointPath
,
idx
,
e
)
if
err
!=
nil
{
log
.
Errorf
(
"Fetch checkpoint failed, %s"
,
err
)
if
err
==
pserver
.
ErrCheckpointNotFound
{
log
.
Infof
(
"Could not find the pserver checkpoint."
)
}
else
{
log
.
Errorf
(
"Fetch checkpoint failed, %s"
,
err
)
}
}
}
...
...
go/master/c/client.go
浏览文件 @
d263ccef
...
...
@@ -22,6 +22,9 @@ package main
#define PADDLE_MASTER_OK 0
#define PADDLE_MASTER_ERROR -1
#define PADDLE_SAVE_MODEL_OK 1
#define PADDLE_SAVE_MODEL_SKIP 0
typedef int paddle_master_client;
*/
import
"C"
...
...
@@ -33,7 +36,6 @@ import (
"unsafe"
"github.com/PaddlePaddle/Paddle/go/master"
"github.com/coreos/etcd/clientv3"
log
"github.com/sirupsen/logrus"
)
...
...
@@ -65,32 +67,32 @@ func remove(client C.paddle_master_client) *master.Client {
}
//export paddle_new_etcd_master_client
//
// bufSize is the record buffer size.
func
paddle_new_etcd_master_client
(
etcdEndpoints
*
C
.
char
,
timeout
int
,
bufSize
int
)
C
.
paddle_master_client
{
p
:=
C
.
GoString
(
etcdEndpoints
)
cli
,
err
:=
clientv3
.
New
(
clientv3
.
Config
{
Endpoints
:
strings
.
Split
(
p
,
","
),
DialTimeout
:
time
.
Second
*
time
.
Duration
(
timeout
),
})
endpoints
:=
strings
.
Split
(
p
,
","
)
c
,
err
:=
master
.
NewClient
(
master
.
WithEtcd
(
endpoints
,
time
.
Duration
(
timeout
)
*
time
.
Second
),
master
.
WithBuffer
(
bufSize
),
)
if
err
!=
nil
{
panic
(
err
)
}
ch
:=
make
(
chan
string
,
1
)
a
,
err
:=
master
.
GetKey
(
cli
,
master
.
DefaultAddrPath
,
timeout
)
if
err
!=
nil
{
panic
(
err
)
}
ch
<-
a
go
master
.
WatchKey
(
cli
,
master
.
DefaultAddrPath
,
ch
)
c
:=
master
.
NewClient
(
ch
,
bufSize
)
return
add
(
c
)
}
//export paddle_new_master_client
//
// bufSize is the record buffer size.
func
paddle_new_master_client
(
addr
*
C
.
char
,
bufSize
int
)
C
.
paddle_master_client
{
a
:=
C
.
GoString
(
addr
)
ch
:=
make
(
chan
string
,
1
)
ch
<-
a
c
:=
master
.
NewClient
(
ch
,
bufSize
)
c
,
err
:=
master
.
NewClient
(
master
.
WithAddr
(
a
),
master
.
WithBuffer
(
bufSize
))
if
err
!=
nil
{
panic
(
err
)
}
return
add
(
c
)
}
...
...
@@ -117,9 +119,10 @@ func paddle_set_dataset(client C.paddle_master_client, path **C.char, size C.int
return
C
.
PADDLE_MASTER_OK
}
// return value:
// 0:ok
// -1:error
// paddle_next_record gets the nexts training record.
//
// returns number of bytes of the records if success, -1 if failed.
//
//export paddle_next_record
func
paddle_next_record
(
client
C
.
paddle_master_client
,
record
**
C
.
uchar
)
C
.
int
{
c
:=
get
(
client
)
...
...
@@ -143,6 +146,29 @@ func paddle_next_record(client C.paddle_master_client, record **C.uchar) C.int {
return
C
.
int
(
size
)
}
// paddle_request_save_model requests the master server to approve the
// caller to save the model.
//
// returns 1 if the save the model request is approved, 0 if the
// request is rejected because other trainer is saving the model, -1
// if error happened.
//
//export paddle_request_save_model
func
paddle_request_save_model
(
client
C
.
paddle_master_client
,
trainerID
string
,
blockMS
int
)
C
.
int
{
c
:=
get
(
client
)
need
,
err
:=
c
.
RequestSaveModel
(
trainerID
,
time
.
Duration
(
blockMS
)
*
time
.
Millisecond
)
if
err
!=
nil
{
log
.
Errorln
(
err
)
return
C
.
PADDLE_MASTER_ERROR
}
if
need
{
return
C
.
PADDLE_SAVE_MODEL_OK
}
return
C
.
PADDLE_SAVE_MODEL_SKIP
}
//export mem_free
func
mem_free
(
p
unsafe
.
Pointer
)
{
// "free" may be a better name for this function, but doing so
...
...
go/master/client.go
浏览文件 @
d263ccef
...
...
@@ -16,17 +16,20 @@ package master
import
(
"os"
"sync"
"time"
"github.com/PaddlePaddle/Paddle/go/connection"
"github.com/PaddlePaddle/recordio"
"github.com/coreos/etcd/clientv3"
log
"github.com/sirupsen/logrus"
)
// Client is the client of the master server.
type
Client
struct
{
conn
*
connection
.
Conn
ch
chan
record
conn
*
connection
.
Conn
ch
chan
record
initChOnce
sync
.
Once
}
type
record
struct
{
...
...
@@ -34,24 +37,83 @@ type record struct {
err
error
}
//
NewClient creates a new Client
.
//
WithBuffer sets the client to buffer the training record
.
//
// bufSize is the record buffer size. NextRecord will read from this
// buffer.
func
NewClient
(
addrCh
<-
chan
string
,
bufSize
int
)
*
Client
{
func
WithBuffer
(
bufSize
int
)
func
(
*
Client
)
error
{
return
func
(
c
*
Client
)
error
{
if
bufSize
<=
0
{
return
nil
}
c
.
initChOnce
.
Do
(
func
()
{
c
.
ch
=
make
(
chan
record
,
bufSize
)
go
c
.
getRecords
()
})
return
nil
}
}
// WithAddr sets the client to use fixed master address.
func
WithAddr
(
addr
string
)
func
(
c
*
Client
)
error
{
return
func
(
c
*
Client
)
error
{
ch
:=
make
(
chan
string
,
1
)
ch
<-
addr
go
c
.
monitorMaster
(
ch
)
return
nil
}
}
// WithEtcd sets the client to use etcd for master discovery.
func
WithEtcd
(
endpoints
[]
string
,
timeout
time
.
Duration
)
func
(
*
Client
)
error
{
return
func
(
c
*
Client
)
error
{
cli
,
err
:=
clientv3
.
New
(
clientv3
.
Config
{
Endpoints
:
endpoints
,
DialTimeout
:
timeout
,
})
if
err
!=
nil
{
return
err
}
ch
:=
make
(
chan
string
,
1
)
a
,
err
:=
GetKey
(
cli
,
DefaultAddrPath
,
timeout
)
if
err
!=
nil
{
return
err
}
if
a
!=
""
{
// Master is registered, send to the master address
// channel.
ch
<-
a
}
go
watchKey
(
cli
,
DefaultAddrPath
,
ch
)
go
c
.
monitorMaster
(
ch
)
return
nil
}
}
// NewClient creates a new Client.
func
NewClient
(
opts
...
func
(
*
Client
)
error
)
(
*
Client
,
error
)
{
c
:=
&
Client
{}
c
.
conn
=
connection
.
New
()
c
.
ch
=
make
(
chan
record
,
bufSize
)
go
c
.
monitorMaster
(
addrCh
)
go
c
.
getRecords
()
return
c
for
_
,
opt
:=
range
opts
{
err
:=
opt
(
c
)
if
err
!=
nil
{
return
nil
,
err
}
}
return
c
,
nil
}
func
(
c
*
Client
)
getRecords
()
{
for
{
t
,
err
:=
c
.
getTask
()
if
err
!=
nil
{
// getTask call.
log
.
Errorf
(
"Get task failed, sleep 3 seconds and continue, %s"
,
err
)
time
.
Sleep
(
3
*
time
.
Second
)
continue
...
...
@@ -146,6 +208,20 @@ func (c *Client) taskFailed(meta TaskMeta) error {
// NextRecord will block until the next record is available. It is
// thread-safe.
func
(
c
*
Client
)
NextRecord
()
([]
byte
,
error
)
{
c
.
initChOnce
.
Do
(
func
()
{
// initialize with in case WithBuffer is not used.
c
.
ch
=
make
(
chan
record
,
0
)
go
c
.
getRecords
()
})
r
:=
<-
c
.
ch
return
r
.
r
,
r
.
err
}
// RequestSaveModel requests the master server to approve the caller
// to save the model.
func
(
c
*
Client
)
RequestSaveModel
(
trainerID
string
,
blockDur
time
.
Duration
)
(
bool
,
error
)
{
var
need
bool
err
:=
c
.
conn
.
Call
(
"Service.RequestSaveModel"
,
SaveModelRequest
{
TrainerID
:
trainerID
,
BlockDur
:
blockDur
},
&
need
)
return
need
,
err
}
go/master/client_test.go
浏览文件 @
d263ccef
...
...
@@ -87,9 +87,11 @@ func TestNextRecord(t *testing.T) {
panic
(
err
)
}
curAddr
:=
make
(
chan
string
,
1
)
curAddr
<-
fmt
.
Sprintf
(
":%d"
,
p
)
c
:=
master
.
NewClient
(
curAddr
,
10
)
c
,
err
:=
master
.
NewClient
(
master
.
WithAddr
(
fmt
.
Sprintf
(
":%d"
,
p
)),
master
.
WithBuffer
(
10
))
if
err
!=
nil
{
panic
(
err
)
}
err
=
c
.
SetDataset
([]
string
{
path
})
if
err
!=
nil
{
panic
(
err
)
...
...
go/master/etcd_client.go
浏览文件 @
d263ccef
...
...
@@ -158,8 +158,8 @@ func (e *EtcdClient) Load() ([]byte, error) {
}
// GetKey gets the value by the specify key.
func
GetKey
(
c
*
clientv3
.
Client
,
key
string
,
timeout
int
)
(
string
,
error
)
{
ctx
,
cancel
:=
context
.
WithTimeout
(
context
.
Background
(),
time
.
Second
*
time
.
Duration
(
timeout
)
)
func
GetKey
(
c
*
clientv3
.
Client
,
key
string
,
timeout
time
.
Duration
)
(
string
,
error
)
{
ctx
,
cancel
:=
context
.
WithTimeout
(
context
.
Background
(),
time
out
)
resp
,
err
:=
c
.
Get
(
ctx
,
key
)
cancel
()
if
err
!=
nil
{
...
...
@@ -173,8 +173,8 @@ func GetKey(c *clientv3.Client, key string, timeout int) (string, error) {
return
string
(
v
),
nil
}
//
W
atchKey watches the specify key and send to valChan if there is some event.
func
W
atchKey
(
c
*
clientv3
.
Client
,
key
string
,
valChan
chan
<-
string
)
{
//
w
atchKey watches the specify key and send to valChan if there is some event.
func
w
atchKey
(
c
*
clientv3
.
Client
,
key
string
,
valChan
chan
<-
string
)
{
rch
:=
c
.
Watch
(
context
.
Background
(),
key
)
for
wresp
:=
range
rch
{
for
_
,
ev
:=
range
wresp
.
Events
{
...
...
go/master/service.go
浏览文件 @
d263ccef
...
...
@@ -78,9 +78,10 @@ type Service struct {
ready
chan
struct
{}
store
Store
mu
sync
.
Mutex
initDone
bool
taskQueues
taskQueues
mu
sync
.
Mutex
initDone
bool
taskQueues
taskQueues
savingTrainer
string
}
func
partition
(
chunks
[]
Chunk
,
chunksPerTask
int
)
[]
taskEntry
{
...
...
@@ -246,7 +247,7 @@ func readChunks(globPaths []string) ([]Chunk, error) {
//
// SetDataset can be call multiple times. But only the first call will
// be honored.
func
(
s
*
Service
)
SetDataset
(
globPaths
[]
string
,
dummy
*
int
)
error
{
func
(
s
*
Service
)
SetDataset
(
globPaths
[]
string
,
_
*
int
)
error
{
if
len
(
globPaths
)
==
0
{
return
errors
.
New
(
"no dataset specified"
)
}
...
...
@@ -330,7 +331,7 @@ func (s *Service) logFields() log.Fields {
}
// GetTask gets a new task from the service.
func
(
s
*
Service
)
GetTask
(
dummy
int
,
task
*
Task
)
error
{
func
(
s
*
Service
)
GetTask
(
_
int
,
task
*
Task
)
error
{
select
{
case
<-
s
.
ready
:
}
...
...
@@ -380,7 +381,7 @@ func (s *Service) GetTask(dummy int, task *Task) error {
}
// TaskFinished tell the service that a task is finished.
func
(
s
*
Service
)
TaskFinished
(
taskID
int
,
dummy
*
int
)
error
{
func
(
s
*
Service
)
TaskFinished
(
taskID
int
,
_
*
int
)
error
{
select
{
case
<-
s
.
ready
:
}
...
...
@@ -415,7 +416,7 @@ func (s *Service) TaskFinished(taskID int, dummy *int) error {
}
// TaskFailed tells the service that a task is failed.
func
(
s
*
Service
)
TaskFailed
(
meta
TaskMeta
,
dummy
*
int
)
error
{
func
(
s
*
Service
)
TaskFailed
(
meta
TaskMeta
,
_
*
int
)
error
{
select
{
case
<-
s
.
ready
:
}
...
...
@@ -432,3 +433,42 @@ func (s *Service) TaskFailed(meta TaskMeta, dummy *int) error {
s
.
processFailedTask
(
t
,
meta
.
Epoch
)
return
nil
}
// SaveModelRequest is the request for saving model
type
SaveModelRequest
struct
{
TrainerID
string
BlockDur
time
.
Duration
}
// RequestSaveModel requests the master server to approve the caller
// to save the model.
func
(
s
*
Service
)
RequestSaveModel
(
req
SaveModelRequest
,
need
*
bool
)
error
{
s
.
mu
.
Lock
()
defer
s
.
mu
.
Unlock
()
if
req
.
TrainerID
==
""
{
return
errors
.
New
(
"trainer id is empty"
)
}
if
s
.
savingTrainer
==
""
{
*
need
=
true
}
else
{
if
req
.
TrainerID
==
s
.
savingTrainer
{
// save trainer asked to save model again
*
need
=
true
}
else
{
*
need
=
false
}
}
if
*
need
{
s
.
savingTrainer
=
req
.
TrainerID
time
.
AfterFunc
(
req
.
BlockDur
,
func
()
{
s
.
mu
.
Lock
()
s
.
savingTrainer
=
""
s
.
mu
.
Unlock
()
})
}
return
nil
}
go/pserver/client/c/cclient.go
浏览文件 @
d263ccef
...
...
@@ -127,13 +127,19 @@ func paddle_pserver_client_release(client C.paddle_pserver_client) {
remove
(
client
)
}
// paddle_begin_init_params tells trainer if it needs to init the
// parameters.
//
// returns 1 if the trainer needs to init the parameters. 0 if the
// trainer does not need to init the parameters.
//
//export paddle_begin_init_params
func
paddle_begin_init_params
(
client
C
.
paddle_pserver_client
)
C
.
int
{
c
:=
get
(
client
)
if
selected
:=
c
.
BeginInitParams
();
selected
{
return
1
}
return
C
.
PSERVER_OK
return
0
}
//export paddle_init_param
...
...
@@ -256,17 +262,4 @@ func paddle_get_params(client C.paddle_pserver_client, dst **C.paddle_parameter,
return
C
.
PSERVER_OK
}
//export paddle_save_model
func
paddle_save_model
(
client
C
.
paddle_pserver_client
,
path
*
C
.
char
)
C
.
int
{
p
:=
C
.
GoString
(
path
)
c
:=
get
(
client
)
err
:=
c
.
Save
(
p
)
if
err
!=
nil
{
log
.
Errorln
(
err
)
return
C
.
PSERVER_ERROR
}
return
C
.
PSERVER_OK
}
func
main
()
{}
// Required but ignored
go/pserver/client/c/test/test_cclient.c
浏览文件 @
d263ccef
...
...
@@ -111,9 +111,5 @@ retry:
getParams
(
c
);
}
if
(
paddle_save_model
(
c
,
"/tmp/"
))
{
fail
();
}
return
0
;
}
go/pserver/client/client.go
浏览文件 @
d263ccef
...
...
@@ -219,32 +219,6 @@ func (c *Client) GetParams(names []string) ([]pserver.Parameter, error) {
return
ps
,
nil
}
// Save indicates parameters to save the parameter to the given path.
func
(
c
*
Client
)
Save
(
path
string
)
error
{
errCh
:=
make
(
chan
error
,
len
(
c
.
pservers
))
for
_
,
p
:=
range
c
.
pservers
{
err
:=
p
.
Call
(
"Service.Save"
,
path
,
nil
)
errCh
<-
err
}
recv
:=
0
for
err
:=
range
errCh
{
if
err
!=
nil
{
return
err
}
recv
++
if
recv
==
len
(
c
.
pservers
)
{
break
}
}
// TODO(helin): there will be many files under path, need to
// merge them into a single file.
return
nil
}
func
strHash
(
s
string
)
uint32
{
h
:=
fnv
.
New32a
()
_
,
_
=
h
.
Write
([]
byte
(
s
))
...
...
go/pserver/service.go
浏览文件 @
d263ccef
...
...
@@ -36,6 +36,10 @@ import (
// ElementType is the type of elements of a Parameter.
type
ElementType
int
// ErrCheckpointNotFound indicates that the pserver checkpoint could
// not be found.
var
ErrCheckpointNotFound
=
errors
.
New
(
"checkpoint not found"
)
// RPC error message.
const
(
AlreadyInitialized
=
"pserver already initialized"
...
...
@@ -103,6 +107,10 @@ func NewCheckpointFromFile(cpPath string, idx int, e *EtcdClient) (Checkpoint, e
return
nil
,
err
}
if
len
(
v
)
==
0
{
return
nil
,
ErrCheckpointNotFound
}
var
cpMeta
checkpointMeta
if
err
=
json
.
Unmarshal
(
v
,
&
cpMeta
);
err
!=
nil
{
return
nil
,
err
...
...
@@ -156,7 +164,7 @@ func NewService(idx int, interval time.Duration, path string, client *EtcdClient
}
// InitParam initializes a parameter.
func
(
s
*
Service
)
InitParam
(
paramWithConfigs
ParameterWithConfig
,
dummy
*
int
)
error
{
func
(
s
*
Service
)
InitParam
(
paramWithConfigs
ParameterWithConfig
,
_
*
int
)
error
{
select
{
case
<-
s
.
initialized
:
return
errors
.
New
(
AlreadyInitialized
)
...
...
@@ -177,7 +185,7 @@ func (s *Service) InitParam(paramWithConfigs ParameterWithConfig, dummy *int) er
// FinishInitParams tells the parameter server that the parameter
// initialization has finished.
func
(
s
*
Service
)
FinishInitParams
(
dummy0
int
,
dummy1
*
int
)
error
{
func
(
s
*
Service
)
FinishInitParams
(
_
int
,
_
*
int
)
error
{
select
{
case
<-
s
.
initialized
:
return
errors
.
New
(
AlreadyInitialized
)
...
...
@@ -190,7 +198,7 @@ func (s *Service) FinishInitParams(dummy0 int, dummy1 *int) error {
// SendGrad sends gradient to parameter servers for parameter
// optimization.
func
(
s
*
Service
)
SendGrad
(
g
Gradient
,
dummy
*
int
)
error
{
func
(
s
*
Service
)
SendGrad
(
g
Gradient
,
_
*
int
)
error
{
select
{
case
<-
s
.
initialized
:
default
:
...
...
paddle/framework/CMakeLists.txt
浏览文件 @
d263ccef
...
...
@@ -19,8 +19,10 @@ cc_test(op_desc_test SRCS op_desc_test.cc DEPS op_desc protobuf)
cc_library
(
operator SRCS operator.cc DEPS op_desc device_context tensor
)
cc_test
(
operator_test SRCS operator_test.cc DEPS operator op_registry
)
cc_library
(
op_registry SRCS op_registry.cc DEPS op_proto op_desc
)
cc_test
(
op_registry_test SRCS op_registry_test.cc DEPS op_registry operator
)
cc_library
(
grad_op_builder SRCS grad_op_builder.cc DEPS op_proto operator
)
cc_library
(
op_registry SRCS op_registry.cc DEPS op_desc grad_op_builder
)
cc_test
(
op_registry_test SRCS op_registry_test.cc DEPS op_registry
)
cc_test
(
grad_op_builder_test SRCS grad_op_builder_test.cc DEPS grad_op_builder op_registry add_op
)
py_proto_compile
(
framework_py_proto SRCS attr_type.proto op_proto.proto op_desc.proto
)
# Generate an empty __init__.py to make framework_py_proto as a valid python module.
...
...
@@ -28,5 +30,6 @@ add_custom_target(framework_py_proto_init ALL COMMAND ${CMAKE_COMMAND} -E touch
add_dependencies
(
framework_py_proto framework_py_proto_init
)
proto_library
(
net_proto SRCS net_proto.proto DEPS op_proto
)
# cc_library(net SRCS net.cc DEPS operator net_proto op_registry fc_op)
cc_library
(
net SRCS net.cc DEPS operator net_proto op_registry
)
cc_test
(
net_op_test SRCS net_op_test.cc DEPS net
)
cc_test
(
net_op_test SRCS net_op_test.cc DEPS net
add_op mul_op sigmoid_op softmax_op fc_op
)
paddle/framework/grad_op_builder.cc
0 → 100644
浏览文件 @
d263ccef
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/framework/grad_op_builder.h"
#include "paddle/framework/op_registry.h"
namespace
paddle
{
namespace
framework
{
OperatorBase
*
GradOpBuilder
::
Build
()
{
BuildOpInOutArgList
();
std
::
string
grad_op_type
=
OpRegistry
::
grad_ops
().
at
(
op_
->
type_
);
OperatorBase
*
grad_op
=
OpRegistry
::
op_creators
().
at
(
grad_op_type
)();
grad_op
->
type_
=
grad_op_type
;
CompleteGradOp
(
grad_op
);
return
grad_op
;
}
OpInOutArg
*
GradOpBuilder
::
BuildArg
(
const
VarProto
&
var
,
const
VarIndexMap
&
var_map
,
const
std
::
vector
<
int
>&
format
,
InOutType
type
)
{
int
idx
=
var_map
.
at
(
var
.
name
());
int
begin_idx
=
format
.
empty
()
?
idx
:
format
.
at
(
idx
);
int
end_idx
=
format
.
empty
()
?
idx
+
1
:
format
.
at
(
idx
+
1
);
return
new
OpInOutArg
(
var
.
name
(),
type
,
!
var
.
ignore_gradient
(),
begin_idx
,
end_idx
);
}
void
GradOpBuilder
::
BuildOpInOutArgList
()
{
const
OpProto
&
op_proto
=
OpRegistry
::
protos
().
at
(
op_
->
type_
);
const
auto
&
var_map
=
*
(
OpRegistry
::
VarIndexMaps
().
at
(
op_
->
type_
));
const
std
::
vector
<
int
>&
in_format
=
op_
->
attrs_
.
count
(
"input_format"
)
?
op_
->
GetAttr
<
std
::
vector
<
int
>>
(
"input_format"
)
:
std
::
vector
<
int
>
();
const
std
::
vector
<
int
>&
out_format
=
op_
->
attrs_
.
count
(
"output_format"
)
?
op_
->
GetAttr
<
std
::
vector
<
int
>>
(
"output_format"
)
:
std
::
vector
<
int
>
();
for
(
const
auto
&
var
:
op_proto
.
inputs
())
{
arg_list_
.
emplace_back
(
std
::
shared_ptr
<
OpInOutArg
>
(
BuildArg
(
var
,
var_map
,
in_format
,
IN
)));
}
for
(
const
auto
&
var
:
op_proto
.
outputs
())
{
arg_list_
.
emplace_back
(
std
::
shared_ptr
<
OpInOutArg
>
(
BuildArg
(
var
,
var_map
,
out_format
,
OUT
)));
}
}
void
GradOpBuilder
::
AddArgIntoGradOp
(
const
OpInOutArg
*
arg
,
std
::
vector
<
std
::
string
>&
in_out
,
std
::
vector
<
int
>&
format
,
VarIndexMap
*
varmap
,
int
&
idx
,
bool
is_grad
)
const
{
std
::
string
var_name
=
arg
->
proto_name_
;
if
(
is_grad
)
{
var_name
+=
OperatorBase
::
GRAD_VAR_SUFFIX
();
}
(
*
varmap
)[
var_name
]
=
idx
++
;
size_t
pre_sz
=
in_out
.
size
();
auto
base_it
=
arg
->
type_
==
IN
?
op_
->
inputs_
.
begin
()
:
op_
->
outputs_
.
begin
();
std
::
copy
(
base_it
+
arg
->
begin_idx_
,
base_it
+
arg
->
end_idx_
,
std
::
back_inserter
(
in_out
));
if
(
is_grad
)
{
for
(
size_t
i
=
pre_sz
;
i
<
in_out
.
size
();
++
i
)
{
in_out
[
i
]
+=
OperatorBase
::
GRAD_VAR_SUFFIX
();
}
}
format
.
push_back
(
in_out
.
size
());
}
void
GradOpBuilder
::
CompleteGradOp
(
OperatorBase
*
grad_op
)
const
{
grad_op
->
attrs_
=
op_
->
attrs_
;
grad_op
->
attrs_
.
erase
(
"input_format"
);
grad_op
->
attrs_
.
erase
(
"output_format"
);
VarIndexMap
*
grad_varmap
=
new
VarIndexMap
();
int
in_idx
=
0
;
int
out_idx
=
0
;
std
::
vector
<
int
>
in_format
({
0
});
std
::
vector
<
int
>
out_format
({
0
});
for
(
const
auto
&
arg
:
arg_list_
)
{
// op_'s inputs_ and outputs_
if
(
arg
->
needed_in_grad_
)
{
AddArgIntoGradOp
(
arg
.
get
(),
grad_op
->
inputs_
,
in_format
,
grad_varmap
,
in_idx
,
false
);
}
if
(
arg
->
type_
==
IN
)
{
// gradients of op_'s inputs_
AddArgIntoGradOp
(
arg
.
get
(),
grad_op
->
outputs_
,
out_format
,
grad_varmap
,
out_idx
,
true
);
}
else
{
// gradients of op_'s outputs_
AddArgIntoGradOp
(
arg
.
get
(),
grad_op
->
inputs_
,
in_format
,
grad_varmap
,
in_idx
,
true
);
}
}
grad_op
->
attrs_
[
"input_format"
]
=
in_format
;
grad_op
->
attrs_
[
"output_format"
]
=
out_format
;
grad_op
->
in_out_idxs_
.
reset
(
grad_varmap
);
}
}
// namespace framework
}
// namespace paddle
paddle/framework/grad_op_builder.h
0 → 100644
浏览文件 @
d263ccef
#pragma once
#include "paddle/framework/op_proto.pb.h"
#include "paddle/framework/operator.h"
namespace
paddle
{
namespace
framework
{
class
OpRegistry
;
enum
InOutType
{
IN
,
OUT
};
struct
OpInOutArg
{
OpInOutArg
(
const
std
::
string
&
proto_name
,
const
InOutType
&
type
,
bool
needed_in_grad
,
size_t
begin_idx
,
size_t
end_idx
)
:
proto_name_
(
proto_name
),
type_
(
type
),
needed_in_grad_
(
needed_in_grad
),
begin_idx_
(
begin_idx
),
end_idx_
(
end_idx
)
{}
std
::
string
proto_name_
;
InOutType
type_
;
bool
needed_in_grad_
;
size_t
begin_idx_
;
size_t
end_idx_
;
};
class
GradOpBuilder
{
using
VarIndexMap
=
std
::
unordered_map
<
std
::
string
,
int
>
;
public:
GradOpBuilder
(
const
OperatorBase
*
op
)
:
op_
(
op
)
{}
OperatorBase
*
Build
();
private:
OpInOutArg
*
BuildArg
(
const
VarProto
&
var
,
const
VarIndexMap
&
var_map
,
const
std
::
vector
<
int
>&
format
,
InOutType
type
);
void
BuildOpInOutArgList
();
void
AddArgIntoGradOp
(
const
OpInOutArg
*
arg
,
std
::
vector
<
std
::
string
>&
in_out
,
std
::
vector
<
int
>&
format
,
VarIndexMap
*
varmap
,
int
&
idx
,
bool
is_grad
)
const
;
void
CompleteGradOp
(
OperatorBase
*
grad_op
)
const
;
const
OperatorBase
*
op_
;
std
::
vector
<
std
::
shared_ptr
<
OpInOutArg
>>
arg_list_
;
};
}
// namespace framework
}
// namespace paddle
paddle/framework/grad_op_builder_test.cc
0 → 100644
浏览文件 @
d263ccef
#include "paddle/framework/grad_op_builder.h"
#include <gtest/gtest.h>
#include "paddle/framework/op_registry.h"
#include "paddle/framework/operator.h"
USE_OP
(
add_two
);
namespace
paddle
{
namespace
framework
{
TEST
(
GradOpBuilder
,
AddTwo
)
{
std
::
shared_ptr
<
OperatorBase
>
add_op
(
OpRegistry
::
CreateOp
(
"add_two"
,
{
"x"
,
"y"
},
{
"out"
},
{}));
std
::
shared_ptr
<
OperatorBase
>
grad_add_op
=
OpRegistry
::
CreateGradOp
(
add_op
);
EXPECT_EQ
(
static_cast
<
int
>
(
grad_add_op
->
inputs_
.
size
()),
4
);
EXPECT_EQ
(
static_cast
<
int
>
(
grad_add_op
->
outputs_
.
size
()),
2
);
EXPECT_EQ
(
grad_add_op
->
Input
(
"X"
),
"x"
);
EXPECT_EQ
(
grad_add_op
->
Input
(
"Y"
),
"y"
);
EXPECT_EQ
(
grad_add_op
->
Input
(
"Out"
),
"out"
);
EXPECT_EQ
(
grad_add_op
->
Input
(
"Out@GRAD"
),
"out@GRAD"
);
EXPECT_EQ
(
grad_add_op
->
Output
(
"X@GRAD"
),
"x@GRAD"
);
EXPECT_EQ
(
grad_add_op
->
Output
(
"Y@GRAD"
),
"y@GRAD"
);
}
}
// namespace framework
}
// namespace paddle
\ No newline at end of file
paddle/framework/net.cc
浏览文件 @
d263ccef
...
...
@@ -15,14 +15,24 @@
*/
#include "paddle/framework/net.h"
#include "paddle/framework/op_registry.h"
namespace
paddle
{
namespace
framework
{
std
::
shared_ptr
<
PlainNet
>
AddBackwardOp
(
std
::
shared_ptr
<
PlainNet
>
ForwardOps
)
{
auto
grad_ops
=
std
::
make_shared
<
PlainNet
>
();
for
(
auto
&
op
:
ForwardOps
->
ops_
)
{
auto
op_grad
=
OpRegistry
::
CreateGradOp
(
op
);
grad_ops
->
AddOp
(
op_grad
);
}
grad_ops
->
CompleteAddOp
();
return
grad_ops
;
}
void
PlainNet
::
CompleteAddOp
(
bool
calc
)
{
add_op_done_
=
true
;
if
(
!
calc
)
return
;
std
::
unordered_set
<
std
::
string
>
input_set
;
std
::
unordered_set
<
std
::
string
>
output_set
;
std
::
unordered_set
<
std
::
string
>
temp_output
;
...
...
paddle/framework/net.h
浏览文件 @
d263ccef
...
...
@@ -39,7 +39,7 @@ namespace framework {
*/
class
Net
:
public
OperatorBase
{
public:
virtual
void
AddOp
(
const
OperatorPtr
&
op
)
=
0
;
virtual
void
AddOp
(
const
std
::
shared_ptr
<
OperatorBase
>
&
op
)
=
0
;
virtual
void
CompleteAddOp
(
bool
calc
)
=
0
;
};
...
...
@@ -57,7 +57,7 @@ class PlainNet : public Net {
* Infer all the operators' input and output variables' shapes, will be called
* before every mini-batch
*/
void
InferShape
(
const
ScopePtr
&
scope
)
const
override
{
void
InferShape
(
const
std
::
shared_ptr
<
Scope
>
&
scope
)
const
override
{
for
(
auto
&
op
:
ops_
)
{
op
->
InferShape
(
scope
);
}
...
...
@@ -70,7 +70,7 @@ class PlainNet : public Net {
* scope will be used instead. If no OpContext is provicded, default context
* will be used.
*/
void
Run
(
const
ScopePtr
&
scope
,
void
Run
(
const
std
::
shared_ptr
<
Scope
>
&
scope
,
const
platform
::
DeviceContext
&
dev_ctx
)
const
override
{
for
(
auto
&
op
:
ops_
)
{
op
->
Run
(
scope
,
dev_ctx
);
...
...
@@ -80,7 +80,7 @@ class PlainNet : public Net {
/**
* @brief Add an operator by ptr
*/
void
AddOp
(
const
OperatorPtr
&
op
)
override
{
void
AddOp
(
const
std
::
shared_ptr
<
OperatorBase
>
&
op
)
override
{
PADDLE_ENFORCE
(
!
add_op_done_
,
"Cannot AddOp when this network is sealed"
);
ops_
.
push_back
(
op
);
}
...
...
@@ -89,7 +89,7 @@ class PlainNet : public Net {
std
::
string
DebugString
()
const
override
;
std
::
vector
<
OperatorPtr
>
ops_
;
std
::
vector
<
std
::
shared_ptr
<
OperatorBase
>
>
ops_
;
private:
bool
add_op_done_
{
false
};
...
...
@@ -100,5 +100,7 @@ class PlainNet : public Net {
}
};
std
::
shared_ptr
<
PlainNet
>
AddBackwardOp
(
std
::
shared_ptr
<
PlainNet
>
ForwardOps
);
}
// namespace framework
}
// namespace paddle
paddle/framework/net_op_test.cc
浏览文件 @
d263ccef
...
...
@@ -3,17 +3,24 @@
#include <paddle/framework/op_registry.h>
#include <paddle/framework/operator.h>
namespace
pd
=
paddle
::
framework
;
USE_OP
(
add_two
);
USE_OP
(
mul
);
USE_OP
(
sigmoid
);
USE_OP
(
softmax
);
namespace
paddle
{
namespace
framework
{
static
int
infer_shape_cnt
=
0
;
static
int
run_cnt
=
0
;
class
TestOp
:
public
pd
::
OperatorBase
{
class
TestOp
:
public
OperatorBase
{
public:
void
InferShape
(
const
paddle
::
framework
::
ScopePtr
&
scope
)
const
override
{
void
InferShape
(
const
std
::
shared_ptr
<
framework
::
Scope
>&
scope
)
const
override
{
++
infer_shape_cnt
;
}
void
Run
(
const
paddle
::
framework
::
ScopePtr
&
scope
,
void
Run
(
const
std
::
shared_ptr
<
framework
::
Scope
>
&
scope
,
const
paddle
::
platform
::
DeviceContext
&
dev_ctx
)
const
override
{
++
run_cnt
;
}
...
...
@@ -33,7 +40,7 @@ void AssertSameVectorWithoutOrder(const std::vector<T>& expected,
}
TEST
(
OpKernel
,
all
)
{
auto
net
=
std
::
make_shared
<
paddle
::
framework
::
PlainNet
>
();
auto
net
=
std
::
make_shared
<
PlainNet
>
();
ASSERT_NE
(
net
,
nullptr
);
auto
op1
=
std
::
make_shared
<
TestOp
>
();
...
...
@@ -55,13 +62,37 @@ TEST(OpKernel, all) {
ASSERT_EQ
(
1UL
,
tmp_idx
.
size
());
ASSERT_EQ
(
"y"
,
net
->
outputs_
[
tmp_idx
[
0
]]);
auto
scope
=
std
::
make_shared
<
pd
::
Scope
>
();
p
addle
::
p
latform
::
CPUDeviceContext
dev_ctx
;
auto
scope
=
std
::
make_shared
<
Scope
>
();
platform
::
CPUDeviceContext
dev_ctx
;
net
->
InferShape
(
scope
);
net
->
Run
(
scope
,
dev_ctx
);
ASSERT_EQ
(
2
,
infer_shape_cnt
);
ASSERT_EQ
(
2
,
run_cnt
);
ASSERT_THROW
(
net
->
AddOp
(
op2
),
std
::
runtime_error
);
}
TEST
(
AddBackwardOp
,
TestGradOp
)
{
auto
net
=
std
::
make_shared
<
PlainNet
>
();
ASSERT_NE
(
net
,
nullptr
);
net
->
AddOp
(
framework
::
OpRegistry
::
CreateOp
(
"mul"
,
{
"X"
,
"Y"
},
{
"Out"
},
{}));
net
->
AddOp
(
framework
::
OpRegistry
::
CreateOp
(
"add_two"
,
{
"X"
,
"Y"
},
{
"Out"
},
{}));
net
->
AddOp
(
framework
::
OpRegistry
::
CreateOp
(
"add_two"
,
{
"X"
,
"Y"
},
{
""
},
{}));
auto
grad_ops
=
AddBackwardOp
(
net
);
for
(
auto
&
op
:
grad_ops
->
ops_
)
{
op
->
DebugString
();
}
}
// TODO(zhihong): add fc grad without registering.
// TEST(AddBackwardOp, TestNoGradOp) {
// auto net = std::make_shared<PlainNet>();
// ASSERT_NE(net, nullptr);
// net->AddOp(framework::OpRegistry::CreateOp("fc", {"X", "W", "b"}, {"Y"},
// {})); auto grad_ops = AddBackwardOp(net); for (auto& op : grad_ops->ops_) {
// op->DebugString();
// }
// }
}
// namespace framework
}
// namespace paddle
paddle/framework/net_test.cc
已删除
100644 → 0
浏览文件 @
6f80b5f1
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/framework/net.h"
#include "paddle/framework/op_registry.h"
#include <gtest/gtest.h>
namespace
paddle
{
namespace
framework
{
class
FakeFC
:
public
Operator
{}
}
// namespace framework
}
// namespace paddle
paddle/framework/op_proto.proto
浏览文件 @
d263ccef
...
...
@@ -84,6 +84,11 @@ message VarProto {
// "temporary_index": [1]
// }
optional
bool
temporary
=
4
[
default
=
false
];
// The gradient of operator can be ignored immediately
// e.g. operator AddOp, y = x1 + x2, the gradient of dy/dx1, dy/dx2
// can be ignored for the future optimized on graph.
optional
bool
ignore_gradient
=
6
;
}
// Op protocol message for 3rd-party language binding.
...
...
@@ -105,4 +110,5 @@ message OpProto {
// The type of that Op.
required
string
type
=
5
;
}
paddle/framework/op_registry.h
浏览文件 @
d263ccef
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <algorithm>
...
...
@@ -6,9 +20,9 @@
#include <unordered_map>
#include <unordered_set>
#include "paddle/framework/attr_checker.h"
#include "paddle/framework/grad_op_builder.h"
#include "paddle/framework/op_desc.pb.h"
#include "paddle/framework/op_proto.pb.h"
#include "paddle/framework/operator.h"
#include "paddle/framework/scope.h"
namespace
paddle
{
namespace
framework
{
...
...
@@ -73,25 +87,29 @@ class OpProtoAndCheckerMaker {
protected:
void
AddInput
(
const
std
::
string
&
name
,
const
std
::
string
&
comment
,
bool
multiple
=
false
)
{
bool
multiple
=
false
,
bool
ignore_gradient
=
false
)
{
auto
input
=
proto_
->
mutable_inputs
()
->
Add
();
*
input
->
mutable_name
()
=
name
;
*
input
->
mutable_comment
()
=
comment
;
input
->
set_ignore_gradient
(
ignore_gradient
);
input
->
set_multiple
(
multiple
);
if
(
multiple
)
{
SetHasMultipleInput
();
}
}
void
AddInputs
(
const
std
::
string
&
name
,
const
std
::
string
&
comment
)
{
AddInput
(
name
,
comment
,
true
);
void
AddInputs
(
const
std
::
string
&
name
,
const
std
::
string
&
comment
,
bool
ignore_gradient
=
false
)
{
AddInput
(
name
,
comment
,
true
,
ignore_gradient
);
}
void
AddOutput
(
const
std
::
string
&
name
,
const
std
::
string
&
comment
,
bool
temporary
=
false
,
bool
multiple
=
false
)
{
bool
temporary
=
false
,
bool
multiple
=
false
,
bool
ignore_gradient
=
false
)
{
auto
output
=
proto_
->
mutable_outputs
()
->
Add
();
*
output
->
mutable_name
()
=
name
;
*
output
->
mutable_comment
()
=
comment
;
output
->
set_ignore_gradient
(
ignore_gradient
);
output
->
set_multiple
(
multiple
);
if
(
multiple
)
{
SetHasMultipleOutput
();
...
...
@@ -103,8 +121,8 @@ class OpProtoAndCheckerMaker {
}
void
AddOutputs
(
const
std
::
string
&
name
,
const
std
::
string
&
comment
,
bool
temporary
=
false
)
{
AddOutput
(
name
,
comment
,
temporary
,
true
);
bool
temporary
=
false
,
bool
ignore_gradient
=
false
)
{
AddOutput
(
name
,
comment
,
temporary
,
true
,
ignore_gradient
);
}
template
<
typename
T
>
...
...
@@ -204,9 +222,9 @@ class OpRegistry {
public:
template
<
typename
OpType
,
typename
ProtoMakerType
>
static
void
RegisterOp
(
const
std
::
string
&
op_type
)
{
creators
()[
op_type
]
=
[]
{
return
new
OpType
;
};
OpProto
&
op_proto
=
protos
()[
op_type
];
op_creators
()[
op_type
]
=
[]
{
return
new
OpType
;
};
OpAttrChecker
&
op_checker
=
op_checkers
()[
op_type
];
OpProto
&
op_proto
=
protos
()[
op_type
];
auto
maker
=
ProtoMakerType
(
&
op_proto
,
&
op_checker
);
maker
.
Validate
();
*
op_proto
.
mutable_type
()
=
op_type
;
...
...
@@ -227,18 +245,26 @@ class OpRegistry {
}
}
static
OperatorPtr
CreateOp
(
const
std
::
string
&
type
,
const
VarNameList
&
inputs
,
const
VarNameList
&
outputs
,
const
AttributeMap
&
attrs
)
{
auto
op_create_it
=
creators
().
find
(
type
);
PADDLE_ENFORCE
(
op_create_it
!=
creators
().
end
(),
"Operator %s cannot be found"
,
type
);
template
<
typename
GradOpType
>
static
void
RegisterGradOp
(
const
std
::
string
&
op_type
,
const
std
::
string
&
grad_op_type
)
{
op_creators
()[
grad_op_type
]
=
[]
{
return
new
GradOpType
;
};
grad_ops
()[
op_type
]
=
grad_op_type
;
}
static
std
::
shared_ptr
<
OperatorBase
>
CreateOp
(
const
std
::
string
&
type
,
const
VarNameList
&
inputs
,
const
VarNameList
&
outputs
,
const
AttributeMap
&
attrs
)
{
auto
op_create_it
=
op_creators
().
find
(
type
);
PADDLE_ENFORCE
(
op_create_it
!=
op_creators
().
end
(),
"Operator %s cannot be found."
,
type
);
auto
op
=
op_create_it
->
second
();
op
->
type_
=
type
;
op
->
inputs_
=
inputs
;
op
->
outputs_
=
outputs
;
op
->
attrs_
=
attrs
;
op_checkers
().
at
(
type
).
Check
(
op
->
attrs_
);
...
...
@@ -252,10 +278,10 @@ class OpRegistry {
}
op
->
Init
();
return
OperatorPtr
(
op
);
return
std
::
shared_ptr
<
OperatorBase
>
(
op
);
}
static
OperatorPtr
CreateOp
(
const
OpDesc
&
op_desc
)
{
static
std
::
shared_ptr
<
OperatorBase
>
CreateOp
(
const
OpDesc
&
op_desc
)
{
std
::
vector
<
std
::
string
>
inputs
;
inputs
.
reserve
((
size_t
)
op_desc
.
inputs_size
());
std
::
copy
(
op_desc
.
inputs
().
begin
(),
op_desc
.
inputs
().
end
(),
...
...
@@ -274,18 +300,41 @@ class OpRegistry {
return
CreateOp
(
op_desc
.
type
(),
inputs
,
outputs
,
attrs
);
}
static
std
::
shared_ptr
<
OperatorBase
>
CreateGradOp
(
std
::
shared_ptr
<
OperatorBase
>
op
)
{
GradOpBuilder
builder
(
op
.
get
());
std
::
shared_ptr
<
OperatorBase
>
grad_op
(
builder
.
Build
());
grad_op
->
Init
();
return
grad_op
;
}
static
std
::
unordered_map
<
std
::
string
,
OpProto
>&
protos
()
{
static
std
::
unordered_map
<
std
::
string
,
OpProto
>
protos_
;
return
protos_
;
};
private:
static
std
::
unordered_map
<
std
::
string
,
std
::
string
>&
grad_ops
()
{
static
std
::
unordered_map
<
std
::
string
,
std
::
string
>
grad_ops_
;
return
grad_ops_
;
}
static
std
::
unordered_map
<
std
::
string
,
std
::
shared_ptr
<
VarIndexMap
>>&
VarIndexMaps
()
{
static
std
::
unordered_map
<
std
::
string
,
std
::
shared_ptr
<
VarIndexMap
>>
maps_
;
return
maps_
;
}
static
std
::
unordered_map
<
std
::
string
,
OpCreator
>&
op_creators
()
{
static
std
::
unordered_map
<
std
::
string
,
OpCreator
>
op_creators_
;
return
op_creators_
;
}
private:
static
std
::
unordered_map
<
std
::
string
,
OpAttrChecker
>&
op_checkers
()
{
static
std
::
unordered_map
<
std
::
string
,
OpAttrChecker
>
op_checkers_
;
return
op_checkers_
;
};
static
void
GenerateTempVariableName
(
OperatorBase
*
op
)
{
static
std
::
atomic
<
size_t
>
gUniqId
(
0UL
);
for
(
auto
&
outname
:
op
->
outputs_
)
{
...
...
@@ -296,16 +345,6 @@ class OpRegistry {
}
}
}
static
std
::
unordered_map
<
std
::
string
,
OpCreator
>&
creators
()
{
static
std
::
unordered_map
<
std
::
string
,
OpCreator
>
creators_
;
return
creators_
;
}
static
std
::
unordered_map
<
std
::
string
,
OpAttrChecker
>&
op_checkers
()
{
static
std
::
unordered_map
<
std
::
string
,
OpAttrChecker
>
op_checkers_
;
return
op_checkers_
;
};
};
template
<
typename
OpType
,
typename
ProtoMakerType
>
...
...
@@ -316,6 +355,14 @@ class OpRegisterHelper {
}
};
template
<
typename
GradOpType
>
class
GradOpRegisterHelper
{
public:
GradOpRegisterHelper
(
const
char
*
op_type
,
const
char
*
grad_op_type
)
{
OpRegistry
::
RegisterGradOp
<
GradOpType
>
(
op_type
,
grad_op_type
);
}
};
/**
* check if MACRO is used in GLOBAL NAMESPACE.
*/
...
...
@@ -335,6 +382,20 @@ class OpRegisterHelper {
__op_register_##__op_type##__(#__op_type); \
int __op_register_##__op_type##_handle__() { return 0; }
/**
* Macro to Register Gradient Operator.
*/
#define REGISTER_GRADIENT_OP(__op_type, __grad_op_type, __grad_op_class) \
STATIC_ASSERT_GLOBAL_NAMESPACE( \
__reg_gradient_op__##__op_type##__grad_op_type, \
"REGISTER_GRADIENT_OP must be in global namespace"); \
static ::paddle::framework::GradOpRegisterHelper<__grad_op_class> \
__op_gradient_register_##__op_type##__grad_op_type##__(#__op_type, \
#__grad_op_type); \
int __op_gradient_register_##__op_type##__grad_op_type##_handle__() { \
return 0; \
}
/**
* Macro to Register OperatorKernel.
*/
...
...
paddle/framework/op_registry_test.cc
浏览文件 @
d263ccef
...
...
@@ -7,9 +7,9 @@ namespace paddle {
namespace
framework
{
class
CosineOp
:
public
OperatorBase
{
public:
void
Run
(
const
ScopePtr
&
scope
,
void
Run
(
const
std
::
shared_ptr
<
Scope
>
&
scope
,
const
platform
::
DeviceContext
&
dev_ctx
)
const
override
{}
void
InferShape
(
const
ScopePtr
&
scope
)
const
override
{}
void
InferShape
(
const
std
::
shared_ptr
<
Scope
>
&
scope
)
const
override
{}
};
class
CosineOpProtoAndCheckerMaker
:
public
OpProtoAndCheckerMaker
{
...
...
@@ -27,8 +27,8 @@ class CosineOpProtoAndCheckerMaker : public OpProtoAndCheckerMaker {
class
MyTestOp
:
public
OperatorBase
{
public:
void
InferShape
(
const
ScopePtr
&
scope
)
const
override
{}
void
Run
(
const
ScopePtr
&
scope
,
void
InferShape
(
const
std
::
shared_ptr
<
Scope
>
&
scope
)
const
override
{}
void
Run
(
const
std
::
shared_ptr
<
Scope
>
&
scope
,
const
platform
::
DeviceContext
&
dev_ctx
)
const
override
{}
};
...
...
@@ -67,7 +67,7 @@ TEST(OpRegistry, CreateOp) {
attr
->
set_type
(
paddle
::
framework
::
AttrType
::
FLOAT
);
attr
->
set_f
(
scale
);
paddle
::
framework
::
OperatorPtr
op
=
std
::
shared_ptr
<
paddle
::
framework
::
OperatorBase
>
op
=
paddle
::
framework
::
OpRegistry
::
CreateOp
(
op_desc
);
auto
scope
=
std
::
make_shared
<
paddle
::
framework
::
Scope
>
();
paddle
::
platform
::
CPUDeviceContext
dev_ctx
;
...
...
@@ -89,8 +89,7 @@ TEST(OpRegistry, IllegalAttr) {
bool
caught
=
false
;
try
{
paddle
::
framework
::
OperatorPtr
op
__attribute__
((
unused
))
=
paddle
::
framework
::
OpRegistry
::
CreateOp
(
op_desc
);
paddle
::
framework
::
OpRegistry
::
CreateOp
(
op_desc
);
}
catch
(
std
::
runtime_error
&
err
)
{
caught
=
true
;
std
::
string
msg
=
"larger_than check fail"
;
...
...
@@ -110,7 +109,7 @@ TEST(OpRegistry, DefaultValue) {
ASSERT_TRUE
(
op_desc
.
IsInitialized
());
paddle
::
framework
::
OperatorPtr
op
=
std
::
shared_ptr
<
paddle
::
framework
::
OperatorBase
>
op
=
paddle
::
framework
::
OpRegistry
::
CreateOp
(
op_desc
);
auto
scope
=
std
::
make_shared
<
paddle
::
framework
::
Scope
>
();
paddle
::
platform
::
CPUDeviceContext
dev_ctx
;
...
...
@@ -136,8 +135,7 @@ TEST(OpRegistry, CustomChecker) {
// attr 'test_attr' is not set
bool
caught
=
false
;
try
{
paddle
::
framework
::
OperatorPtr
op
__attribute__
((
unused
))
=
paddle
::
framework
::
OpRegistry
::
CreateOp
(
op_desc
);
paddle
::
framework
::
OpRegistry
::
CreateOp
(
op_desc
);
}
catch
(
std
::
runtime_error
&
err
)
{
caught
=
true
;
std
::
string
msg
=
"Attribute 'test_attr' is required!"
;
...
...
@@ -155,8 +153,7 @@ TEST(OpRegistry, CustomChecker) {
attr
->
set_i
(
3
);
caught
=
false
;
try
{
paddle
::
framework
::
OperatorPtr
op
__attribute__
((
unused
))
=
paddle
::
framework
::
OpRegistry
::
CreateOp
(
op_desc
);
paddle
::
framework
::
OpRegistry
::
CreateOp
(
op_desc
);
}
catch
(
std
::
runtime_error
&
err
)
{
caught
=
true
;
std
::
string
msg
=
"'test_attr' must be even!"
;
...
...
@@ -174,8 +171,7 @@ TEST(OpRegistry, CustomChecker) {
attr
->
set_type
(
paddle
::
framework
::
AttrType
::
INT
);
attr
->
set_i
(
4
);
SetInputFormat
(
&
op_desc
);
paddle
::
framework
::
OperatorPtr
op
=
paddle
::
framework
::
OpRegistry
::
CreateOp
(
op_desc
);
auto
op
=
paddle
::
framework
::
OpRegistry
::
CreateOp
(
op_desc
);
paddle
::
platform
::
CPUDeviceContext
dev_ctx
;
auto
scope
=
std
::
make_shared
<
paddle
::
framework
::
Scope
>
();
op
->
Run
(
scope
,
dev_ctx
);
...
...
paddle/framework/operator.h
浏览文件 @
d263ccef
...
...
@@ -47,7 +47,6 @@ struct EigenDeviceConverter<platform::GPUPlace> {
#endif
class
OperatorBase
;
using
OperatorPtr
=
std
::
shared_ptr
<
OperatorBase
>
;
/**
* OperatorBase has the basic element that Net will call to do computation.
* Only CreateOperator from OpRegistry will new Operator directly. User
...
...
@@ -63,6 +62,11 @@ class OperatorBase {
/// but it will be convert to a unique name in scope after OpCreator.
static
std
::
string
TMP_VAR_NAME
()
{
return
"@TEMP@"
;
}
/// If a variable's name has a certain suffix, it means that the
/// variable is the gradient of another varibale.
/// e.g. Variable "x@GRAD" is the gradient of varibale "x".
static
std
::
string
GRAD_VAR_SUFFIX
()
{
return
"@GRAD"
;
}
virtual
~
OperatorBase
()
{}
template
<
typename
T
>
...
...
@@ -80,10 +84,10 @@ class OperatorBase {
/// InferShape infer the size of Variables used by this Operator with
/// information inside scope
virtual
void
InferShape
(
const
ScopePtr
&
scope
)
const
=
0
;
virtual
void
InferShape
(
const
std
::
shared_ptr
<
Scope
>
&
scope
)
const
=
0
;
/// Net will call this function to Run an op.
virtual
void
Run
(
const
ScopePtr
&
scope
,
virtual
void
Run
(
const
std
::
shared_ptr
<
Scope
>
&
scope
,
const
platform
::
DeviceContext
&
dev_ctx
)
const
=
0
;
// Get a input with argument's name described in `op_proto`
...
...
@@ -208,7 +212,7 @@ class OperatorWithKernel : public OperatorBase {
using
OpKernelMap
=
std
::
unordered_map
<
OpKernelKey
,
std
::
unique_ptr
<
OpKernel
>
,
OpKernelHash
>
;
void
Run
(
const
ScopePtr
&
scope
,
void
Run
(
const
std
::
shared_ptr
<
Scope
>
&
scope
,
const
platform
::
DeviceContext
&
dev_ctx
)
const
final
{
auto
&
opKernel
=
AllOpKernels
().
at
(
type_
).
at
(
OpKernelKey
(
dev_ctx
));
opKernel
->
Compute
(
KernelContext
(
this
,
scope
,
dev_ctx
));
...
...
paddle/framework/operator_test.cc
浏览文件 @
d263ccef
...
...
@@ -24,8 +24,8 @@ static int op_run_num = 0;
class
OpWithoutKernelTest
:
public
OperatorBase
{
public:
void
Init
()
override
{
x
=
1
;
}
void
InferShape
(
const
ScopePtr
&
scope
)
const
override
{}
void
Run
(
const
ScopePtr
&
scope
,
void
InferShape
(
const
std
::
shared_ptr
<
Scope
>
&
scope
)
const
override
{}
void
Run
(
const
std
::
shared_ptr
<
Scope
>
&
scope
,
const
platform
::
DeviceContext
&
dev_ctx
)
const
override
{
op_run_num
++
;
ASSERT_EQ
((
int
)
inputs_
.
size
(),
1
);
...
...
@@ -70,8 +70,7 @@ TEST(OperatorBase, all) {
paddle
::
platform
::
CPUDeviceContext
device_context
;
auto
scope
=
std
::
make_shared
<
paddle
::
framework
::
Scope
>
();
paddle
::
framework
::
OperatorPtr
op
=
paddle
::
framework
::
OpRegistry
::
CreateOp
(
op_desc
);
auto
op
=
paddle
::
framework
::
OpRegistry
::
CreateOp
(
op_desc
);
scope
->
CreateVariable
(
"OUT1"
);
ASSERT_EQ
(
paddle
::
framework
::
op_run_num
,
0
);
op
->
Run
(
scope
,
device_context
);
...
...
@@ -189,8 +188,7 @@ TEST(OpKernel, all) {
paddle
::
platform
::
CPUDeviceContext
cpu_device_context
;
auto
scope
=
std
::
make_shared
<
paddle
::
framework
::
Scope
>
();
paddle
::
framework
::
OperatorPtr
op
=
paddle
::
framework
::
OpRegistry
::
CreateOp
(
op_desc
);
auto
op
=
paddle
::
framework
::
OpRegistry
::
CreateOp
(
op_desc
);
ASSERT_EQ
(
paddle
::
framework
::
cpu_kernel_run_num
,
0
);
op
->
Run
(
scope
,
cpu_device_context
);
ASSERT_EQ
(
paddle
::
framework
::
cpu_kernel_run_num
,
1
);
...
...
@@ -236,6 +234,6 @@ TEST(OpKernel, multi_inputs) {
paddle
::
platform
::
CPUDeviceContext
cpu_device_context
;
auto
scope
=
std
::
make_shared
<
Scope
>
();
OperatorPtr
op
(
paddle
::
framework
::
OpRegistry
::
CreateOp
(
op_desc
)
);
auto
op
=
paddle
::
framework
::
OpRegistry
::
CreateOp
(
op_desc
);
op
->
Run
(
scope
,
cpu_device_context
);
}
paddle/framework/scope.h
浏览文件 @
d263ccef
...
...
@@ -24,7 +24,6 @@ namespace paddle {
namespace
framework
{
class
Scope
;
using
ScopePtr
=
std
::
shared_ptr
<
Scope
>
;
/**
* @brief Scope that manage all variables.
...
...
@@ -44,7 +43,7 @@ class Scope {
/**
* @brief Initialize a Scope with parent.
*/
explicit
Scope
(
const
ScopePtr
&
parent
)
:
parent_
(
parent
)
{}
explicit
Scope
(
const
std
::
shared_ptr
<
Scope
>
&
parent
)
:
parent_
(
parent
)
{}
/**
* @brief Create Variable
...
...
@@ -91,7 +90,7 @@ class Scope {
private:
std
::
unordered_map
<
std
::
string
,
std
::
unique_ptr
<
Variable
>>
vars_
;
ScopePtr
parent_
{
nullptr
};
std
::
shared_ptr
<
Scope
>
parent_
{
nullptr
};
};
}
// namespace framework
...
...
paddle/operators/add_op.cc
浏览文件 @
d263ccef
...
...
@@ -49,9 +49,22 @@ The equation is: Out = X + Y
)DOC"
);
}
};
class
AddOpGrad
:
public
framework
::
OperatorWithKernel
{
protected:
void
InferShape
(
const
std
::
vector
<
const
framework
::
Tensor
*>
&
inputs
,
const
std
::
vector
<
framework
::
Tensor
*>
&
outputs
)
const
override
{}
std
::
string
DebugString
()
const
override
{
LOG
(
INFO
)
<<
"AddOpGrad"
;
return
""
;
}
};
}
// namespace operators
}
// namespace paddle
REGISTER_OP
(
add_two
,
paddle
::
operators
::
AddOp
,
paddle
::
operators
::
AddOpMaker
);
REGISTER_GRADIENT_OP
(
add_two
,
add_two_grad
,
paddle
::
operators
::
AddOpGrad
);
REGISTER_OP_CPU_KERNEL
(
add_two
,
paddle
::
operators
::
AddKernel
<
paddle
::
platform
::
CPUPlace
,
float
>
);
paddle/operators/add_op_test.cc
浏览文件 @
d263ccef
...
...
@@ -16,8 +16,13 @@ limitations under the License. */
#define private public
#include <paddle/framework/op_registry.h>
USE_OP
(
add_two
);
// USE_OP(add_two_grad);
TEST
(
AddOp
,
GetOpProto
)
{
auto
&
protos
=
paddle
::
framework
::
OpRegistry
::
protos
();
auto
it
=
protos
.
find
(
"add_two"
);
ASSERT_NE
(
it
,
protos
.
end
());
}
\ No newline at end of file
auto
&
op_creators
=
paddle
::
framework
::
OpRegistry
::
op_creators
();
auto
it1
=
op_creators
.
find
(
"add_two_grad"
);
ASSERT_NE
(
it1
,
op_creators
.
end
());
}
paddle/operators/mul_op.cc
浏览文件 @
d263ccef
...
...
@@ -52,9 +52,22 @@ The equation is: Out = X * Y
}
};
class
MulOpGrad
:
public
framework
::
OperatorWithKernel
{
protected:
void
InferShape
(
const
std
::
vector
<
const
framework
::
Tensor
*>
&
inputs
,
const
std
::
vector
<
framework
::
Tensor
*>
&
outputs
)
const
override
{}
std
::
string
DebugString
()
const
override
{
LOG
(
INFO
)
<<
"MulGrad"
;
return
""
;
}
};
}
// namespace operators
}
// namespace paddle
REGISTER_OP
(
mul
,
paddle
::
operators
::
MulOp
,
paddle
::
operators
::
MulOpMaker
);
REGISTER_GRADIENT_OP
(
mul
,
mul_grad
,
paddle
::
operators
::
MulOpGrad
);
REGISTER_OP_CPU_KERNEL
(
mul
,
paddle
::
operators
::
MulKernel
<
paddle
::
platform
::
CPUPlace
,
float
>
);
paddle/operators/sigmoid_op.cc
浏览文件 @
d263ccef
...
...
@@ -39,12 +39,25 @@ public:
}
};
class
SigmoidOpGrad
:
public
framework
::
OperatorWithKernel
{
protected:
void
InferShape
(
const
std
::
vector
<
const
framework
::
Tensor
*>
&
inputs
,
const
std
::
vector
<
framework
::
Tensor
*>
&
outputs
)
const
override
{}
std
::
string
DebugString
()
const
override
{
LOG
(
INFO
)
<<
"SigmoidGrad"
;
return
""
;
}
};
}
// namespace operators
}
// namespace paddle
REGISTER_OP
(
sigmoid
,
paddle
::
operators
::
SigmoidOp
,
paddle
::
operators
::
SigmoidOpMaker
);
REGISTER_GRADIENT_OP
(
sigmoid
,
sigmoid_grad
,
paddle
::
operators
::
SigmoidOpGrad
);
REGISTER_OP_CPU_KERNEL
(
sigmoid
,
paddle
::
operators
::
SigmoidKernel
<
paddle
::
platform
::
CPUPlace
,
float
>
);
paddle/operators/softmax_op.cc
浏览文件 @
d263ccef
...
...
@@ -42,11 +42,23 @@ public:
}
};
class
SoftmaxOpGrad
:
public
framework
::
OperatorWithKernel
{
protected:
void
InferShape
(
const
std
::
vector
<
const
framework
::
Tensor
*>
&
inputs
,
const
std
::
vector
<
framework
::
Tensor
*>
&
outputs
)
const
override
{}
std
::
string
DebugString
()
const
override
{
LOG
(
INFO
)
<<
"SoftmaxOpGrad"
;
return
""
;
}
};
}
// namespace operators
}
// namespace paddle
namespace
ops
=
paddle
::
operators
;
REGISTER_OP
(
softmax
,
ops
::
SoftmaxOp
,
ops
::
SoftmaxOpMaker
);
REGISTER_GRADIENT_OP
(
softmax
,
softmax_grad
,
paddle
::
operators
::
SoftmaxOpGrad
);
REGISTER_OP_CPU_KERNEL
(
softmax
,
ops
::
SoftmaxKernel
<
paddle
::
platform
::
CPUPlace
,
float
>
);
paddle/pybind/pybind.cc
浏览文件 @
d263ccef
...
...
@@ -126,9 +126,10 @@ All parameter, weight, gradient are variables in Paddle.
return
new
paddle
::
platform
::
CPUDeviceContext
();
});
py
::
class_
<
pd
::
OperatorBase
,
pd
::
OperatorPtr
>
operator_base
(
m
,
"Operator"
);
py
::
class_
<
pd
::
OperatorBase
,
std
::
shared_ptr
<
pd
::
OperatorBase
>>
operator_base
(
m
,
"Operator"
);
operator_base
.
def_static
(
"create"
,
[](
py
::
bytes
protobin
)
->
pd
::
OperatorPtr
{
operator_base
.
def_static
(
"create"
,
[](
py
::
bytes
protobin
)
{
pd
::
OpDesc
desc
;
PADDLE_ENFORCE
(
desc
.
ParsePartialFromString
(
protobin
),
"Cannot parse user input to OpDesc"
);
...
...
python/paddle/trainer_config_helpers/networks.py
浏览文件 @
d263ccef
...
...
@@ -340,24 +340,40 @@ def img_conv_group(input,
conv_with_batchnorm
=
False
,
conv_batchnorm_drop_rate
=
0
,
pool_stride
=
1
,
pool_type
=
None
):
pool_type
=
None
,
param_attr
=
None
):
"""
Image Convolution Group, Used for vgg net.
TODO(yuyang18): Complete docs
:param conv_batchnorm_drop_rate:
:param input:
:param conv_num_filter:
:param pool_size:
:param num_channels:
:param conv_padding:
:param conv_filter_size:
:param conv_act:
:param conv_with_batchnorm:
:param pool_stride:
:param pool_type:
:return:
:param conv_batchnorm_drop_rate: if conv_with_batchnorm[i] is true,
conv_batchnorm_drop_rate[i] represents the drop rate of each batch norm.
:type conv_batchnorm_drop_rate: list
:param input: layer's input.
:type input: LayerOutput
:param conv_num_filter: output channels num.
:type conv_num_filter: int
:param pool_size: pooling filter size.
:type pool_size: int
:param num_channels: input channels num.
:type num_channels: int
:param conv_padding: convolution padding size.
:type conv_padding: int
:param conv_filter_size: convolution filter size.
:type conv_filter_size: int
:param conv_act: activation funciton after convolution.
:type conv_act: BaseActivation
:param conv_with_batchnorm: conv_with_batchnorm[i] represents
if there is a batch normalization after each convolution.
:type conv_with_batchnorm: list
:param pool_stride: pooling stride size.
:type pool_stride: int
:param pool_type: pooling type.
:type pool_type: BasePoolingType
:param param_attr: Convolution param attribute.
None means default attribute.
:type param_attr: ParameterAttribute
:return: Layer's output
:type: LayerOutput
"""
tmp
=
input
...
...
@@ -397,6 +413,7 @@ def img_conv_group(input,
padding
=
conv_padding
[
i
],
filter_size
=
conv_filter_size
[
i
],
num_filters
=
conv_num_filter
[
i
],
param_attr
=
param_attr
,
**
extra_kwargs
)
# logger.debug("tmp.num_filters = %d" % tmp.num_filters)
...
...
python/paddle/v2/__init__.py
浏览文件 @
d263ccef
...
...
@@ -33,6 +33,7 @@ import networks
import
minibatch
import
plot
import
image
import
model
__all__
=
[
'optimizer'
,
...
...
@@ -54,6 +55,7 @@ __all__ = [
'evaluator'
,
'image'
,
'master'
,
'model'
,
]
...
...
python/paddle/v2/master/client.py
浏览文件 @
d263ccef
...
...
@@ -10,11 +10,31 @@ class client(object):
client is a client to the master server.
"""
def
__init__
(
self
,
etcd_endpoints
,
timeout
,
buf_size
):
self
.
c
=
lib
.
paddle_new_etcd_master_client
(
etcd_endpoints
,
timeout
,
def
__init__
(
self
,
etcd_endpoints
,
timeout
_sec
,
buf_size
=
0
):
self
.
c
=
lib
.
paddle_new_etcd_master_client
(
etcd_endpoints
,
timeout
_sec
,
buf_size
)
def
close
(
self
):
def
request_save_model
(
self
,
trainer_id
,
block_ms
):
"""request to save model
Conventionally the 0-th trainer will save model. But in
distributed training, any trainer could be killed. This
function asks the master server if the trainer should proceed
with saving model.
:param trainer_id: trainer id.
:param block_ms: number of millisecond that other save model
will be blocked if this save model request succeeded.
Returns:
int: 1 if the save the model request is approved, 0 if
does the request is rejected because other trainer is
saving the model, -1 if error happened.
"""
return
lib
.
paddle_request_save_model
(
self
.
c
,
trainer_id
,
block_ms
)
def
release
(
self
):
lib
.
paddle_release_master_client
(
self
.
c
)
self
.
c
=
None
...
...
@@ -27,10 +47,13 @@ class client(object):
holder
[
idx
]
=
c_ptr
lib
.
paddle_set_dataset
(
self
.
c
,
holder
,
len
(
paths
))
# return format: (record, errno)
# errno = 0: ok
# < 0: error
def
next_record
(
self
):
"""gets next record for training
Returns:
string: the record.
int: error code, 0 if successful, < 0 otherwise.
"""
p
=
ctypes
.
c_char_p
()
ret
=
ctypes
.
pointer
(
p
)
size
=
lib
.
paddle_next_record
(
self
.
c
,
ret
)
...
...
python/paddle/v2/model.py
0 → 100644
浏览文件 @
d263ccef
# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
os
import
errno
import
uuid
import
paddle.v2.master
__all__
=
[
"save_model"
,
"load_model"
]
trainer_id
=
str
(
uuid
.
uuid4
())
def
mkdir_p
(
path
):
try
:
os
.
makedirs
(
path
)
except
OSError
as
exc
:
if
exc
.
errno
==
errno
.
EEXIST
and
os
.
path
.
isdir
(
path
):
pass
else
:
raise
def
save_model
(
parameters
,
path
):
need_request
=
"KUBERNETES_SERVICE_HOST"
in
os
.
environ
.
keys
()
if
need_request
:
# TODO(helin): figure out how MPI trains, since MPI only save
# model when trainer_id == "0", we can consolidate the logic
# here.
# TODO(helin): change this environment variable name from
# MASTER_IP to ETCD_IP
etcd_name
=
"MASTER_IP"
if
etcd_name
not
in
os
.
environ
.
keys
():
raise
Exception
(
'not find '
+
etcd_name
+
' in environment variable.'
)
etcd_ip
=
os
.
environ
.
get
(
etcd_name
)
client
=
master
.
client
(
"http://"
+
etcd_ip
+
":2379"
,
5
,
0
)
r
=
client
.
request_save_model
(
trainer_id
,
5000
)
if
r
==
0
:
# do not need to save
return
elif
r
<
0
:
# error
return
else
:
# save model
path
=
os
.
path
.
join
(
path
,
trainer_id
)
path
=
os
.
path
.
join
(
path
,
"model.tar"
)
mkdir_p
(
path
)
with
open
(
path
,
'wb'
)
as
f
:
parameters
.
to_tar
(
f
)
def
load_model
(
parameters
,
path
):
with
open
(
path
,
'rb'
)
as
f
:
parameters
.
from_tar
(
f
)
python/paddle/v2/reader/creator.py
浏览文件 @
d263ccef
...
...
@@ -12,8 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Creator package contains some simple reader creator, which could
be used in user
program.
Creator package contains some simple reader creator, which could
be used in user
program.
"""
__all__
=
[
'np_array'
,
'text_file'
,
"recordio"
]
...
...
@@ -59,7 +59,7 @@ def text_file(path):
def
recordio_local
(
paths
,
buf_size
=
100
):
"""
Creates a data reader from given RecordIO file paths separated by ",",
Creates a data reader from given RecordIO file paths separated by ",",
glob pattern is supported.
:path: path of recordio files.
:returns: data reader of recordio files.
...
...
@@ -83,7 +83,7 @@ def recordio_local(paths, buf_size=100):
def
recordio
(
paths
,
buf_size
=
100
):
"""
Creates a data reader that outputs record one one by one
Creates a data reader that outputs record one one by one
from given local or cloud recordio path.
:path: path of recordio files.
:returns: data reader of recordio files.
...
...
@@ -96,7 +96,7 @@ def recordio(paths, buf_size=100):
host_name
=
"MASTER_SERVICE_HOST"
if
host_name
not
in
os
.
environ
.
keys
():
raise
Exception
(
'not find '
+
host_name
+
' in environ.'
)
raise
Exception
(
'not find '
+
host_name
+
' in environ
ment variable
.'
)
addr
=
os
.
environ
(
host
)
...
...
@@ -110,6 +110,6 @@ def recordio(paths, buf_size=100):
break
yield
r
c
.
clo
se
()
c
.
relea
se
()
return
reader
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录