Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PaddleDetection
提交
46d766e2
P
PaddleDetection
项目概览
PaddlePaddle
/
PaddleDetection
大约 1 年 前同步成功
通知
694
Star
11112
Fork
2696
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
184
列表
看板
标记
里程碑
合并请求
40
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleDetection
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
184
Issue
184
列表
看板
标记
里程碑
合并请求
40
合并请求
40
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
体验新版 GitCode,发现更多精彩内容 >>
提交
46d766e2
编写于
7月 28, 2017
作者:
Y
Yu Yang
浏览文件
操作
浏览文件
下载
差异文件
Merge branch 'feature/unittest_for_inputs' into feature/backward
上级
7088654a
8e7c3253
变更
30
隐藏空白更改
内联
并排
Showing
30 changed file
with
642 addition
and
382 deletion
+642
-382
.pre-commit-config.yaml
.pre-commit-config.yaml
+7
-5
go/master/c/client.go
go/master/c/client.go
+13
-4
go/master/client.go
go/master/client.go
+39
-31
go/master/client_internal_test.go
go/master/client_internal_test.go
+32
-28
go/master/client_test.go
go/master/client_test.go
+57
-26
go/master/service.go
go/master/service.go
+60
-38
go/master/service_internal_test.go
go/master/service_internal_test.go
+2
-1
go/pserver/client/c/test/test_train.py
go/pserver/client/c/test/test_train.py
+13
-7
paddle/api/Evaluator.cpp
paddle/api/Evaluator.cpp
+1
-1
paddle/framework/CMakeLists.txt
paddle/framework/CMakeLists.txt
+1
-1
paddle/framework/detail/tensor-inl.h
paddle/framework/detail/tensor-inl.h
+160
-0
paddle/framework/net.h
paddle/framework/net.h
+0
-4
paddle/framework/op_registry.h
paddle/framework/op_registry.h
+4
-3
paddle/framework/operator.cc
paddle/framework/operator.cc
+13
-7
paddle/framework/operator.h
paddle/framework/operator.h
+3
-1
paddle/framework/tensor.cc
paddle/framework/tensor.cc
+1
-1
paddle/framework/tensor.h
paddle/framework/tensor.h
+88
-102
paddle/framework/tensor_test.cc
paddle/framework/tensor_test.cc
+70
-20
paddle/gserver/activations/ActivationFunction.cpp
paddle/gserver/activations/ActivationFunction.cpp
+2
-2
paddle/memory/detail/buddy_allocator.cc
paddle/memory/detail/buddy_allocator.cc
+27
-29
paddle/memory/memory.h
paddle/memory/memory.h
+3
-3
paddle/platform/device_context.h
paddle/platform/device_context.h
+1
-1
paddle/trainer/NewRemoteParameterUpdater.cpp
paddle/trainer/NewRemoteParameterUpdater.cpp
+5
-1
paddle/utils/Error.h
paddle/utils/Error.h
+4
-9
paddle/utils/tests/test_Error.cpp
paddle/utils/tests/test_Error.cpp
+4
-4
python/paddle/trainer_config_helpers/attrs.py
python/paddle/trainer_config_helpers/attrs.py
+1
-1
python/paddle/trainer_config_helpers/layers.py
python/paddle/trainer_config_helpers/layers.py
+11
-20
python/paddle/v2/dataset/common.py
python/paddle/v2/dataset/common.py
+13
-31
python/paddle/v2/inference.py
python/paddle/v2/inference.py
+7
-0
python/paddle/v2/master/client.py
python/paddle/v2/master/client.py
+0
-1
未找到文件。
.pre-commit-config.yaml
浏览文件 @
46d766e2
...
...
@@ -22,9 +22,11 @@
hooks
:
-
id
:
clang-formater
-
repo
:
https://github.com/PaddlePaddle/pre-commit-golang
sha
:
16398aeccf263adaf53b2495eed0406347d76281
sha
:
8337620115c25ff8333f1b1a493bd031049bd7c0
hooks
:
-
id
:
go-fmt
types
:
[
go
]
-
id
:
gometalinter
types
:
[
go
]
-
id
:
go-fmt
types
:
-
go
-
id
:
gometalinter
types
:
-
go
go/master/c/client.go
浏览文件 @
46d766e2
...
...
@@ -18,7 +18,6 @@ package main
#include <stdlib.h>
#include <string.h>
#include <stdio.h>
#define PADDLE_MASTER_OK 0
#define PADDLE_MASTER_ERROR -1
...
...
@@ -101,6 +100,12 @@ func paddle_release_master_client(client C.paddle_master_client) {
remove
(
client
)
}
//export paddle_start_get_records
func
paddle_start_get_records
(
client
C
.
paddle_master_client
,
pass
C
.
int
)
{
c
:=
get
(
client
)
c
.
StartGetRecords
(
int
(
pass
))
}
//export paddle_set_dataset
func
paddle_set_dataset
(
client
C
.
paddle_master_client
,
path
**
C
.
char
,
size
C
.
int
)
C
.
int
{
c
:=
get
(
client
)
...
...
@@ -121,15 +126,19 @@ func paddle_set_dataset(client C.paddle_master_client, path **C.char, size C.int
// paddle_next_record gets the nexts training record.
//
// returns number of bytes of the records if success, -1 if failed.
// returns number of bytes of the records if success, -1 if failed
, -2 if pass end
.
//
//export paddle_next_record
func
paddle_next_record
(
client
C
.
paddle_master_client
,
record
**
C
.
uchar
)
C
.
int
{
c
:=
get
(
client
)
r
,
err
:=
c
.
NextRecord
()
if
err
!=
nil
{
// Error
// TODO: return the type of error?
// NOTE: use errors to indicate pass ends
if
err
.
Error
()
==
master
.
ErrAllTaskFailed
.
Error
()
||
err
.
Error
()
==
master
.
ErrNoMoreAvailable
.
Error
()
||
err
.
Error
()
==
master
.
ErrPassBefore
.
Error
()
{
return
-
2
}
*
record
=
(
*
C
.
uchar
)(
nil
)
return
-
1
}
...
...
go/master/client.go
浏览文件 @
46d766e2
...
...
@@ -16,7 +16,6 @@ package master
import
(
"os"
"sync"
"time"
"github.com/PaddlePaddle/Paddle/go/connection"
...
...
@@ -27,9 +26,9 @@ import (
// Client is the client of the master server.
type
Client
struct
{
conn
*
connection
.
Conn
ch
chan
record
initChOnce
sync
.
Once
conn
*
connection
.
Conn
ch
chan
record
bufSize
int
}
type
record
struct
{
...
...
@@ -46,11 +45,7 @@ func WithBuffer(bufSize int) func(*Client) error {
if
bufSize
<=
0
{
return
nil
}
c
.
initChOnce
.
Do
(
func
()
{
c
.
ch
=
make
(
chan
record
,
bufSize
)
go
c
.
getRecords
()
})
c
.
bufSize
=
bufSize
return
nil
}
}
...
...
@@ -104,25 +99,41 @@ func NewClient(opts ...func(*Client) error) (*Client, error) {
if
err
!=
nil
{
return
nil
,
err
}
}
c
.
ch
=
make
(
chan
record
,
c
.
bufSize
)
// FIXME: connection is created asyncrosly in monitorMaster go routine,
// ensure the connection is ready for use before calling c.addClient.
time
.
Sleep
(
time
.
Second
)
return
c
,
nil
}
func
(
c
*
Client
)
getRecords
()
{
// StartGetRecords must be called at beginning of each pass
func
(
c
*
Client
)
StartGetRecords
(
passID
int
)
{
go
c
.
getRecords
(
passID
)
}
func
(
c
*
Client
)
getRecords
(
passID
int
)
{
for
{
t
,
err
:=
c
.
getTask
()
t
,
err
:=
c
.
getTask
(
passID
)
if
err
!=
nil
{
log
.
Errorf
(
"Get task failed, sleep 3 seconds and continue, %s"
,
err
)
time
.
Sleep
(
3
*
time
.
Second
)
continue
if
err
.
Error
()
==
ErrPassBefore
.
Error
()
||
err
.
Error
()
==
ErrNoMoreAvailable
.
Error
()
||
err
.
Error
()
==
ErrAllTaskFailed
.
Error
()
{
c
.
ch
<-
record
{
nil
,
err
}
break
}
if
err
.
Error
()
==
ErrPassAfter
.
Error
()
{
// wait util last pass finishes
time
.
Sleep
(
time
.
Second
*
3
)
continue
}
log
.
Errorf
(
"getTask error: %s"
,
err
)
}
for
_
,
chunk
:=
range
t
.
Chunks
{
f
,
e
rr
:=
os
.
Open
(
chunk
.
Path
)
if
e
rr
!=
nil
{
log
.
Errorln
(
e
rr
)
f
,
e
:=
os
.
Open
(
chunk
.
Path
)
if
e
!=
nil
{
log
.
Errorln
(
e
)
continue
}
...
...
@@ -178,18 +189,21 @@ func (c *Client) monitorMaster(addrCh <-chan string) {
}
}
// SetDataset set dataset for the master server to dispatch.
// SetDataset sets dataset to dispatch for the master server.
//
// SetDataset can be call multiple times at one pass. But only the first call
// will be honored.
//
// SetDataset can be call multiple times from different nodes. But
// only the first call will be honored.
// After all tasks are done, another call of SetDataset will start another pass.
func
(
c
*
Client
)
SetDataset
(
globPaths
[]
string
)
error
{
return
c
.
conn
.
Call
(
"Service.SetDataset"
,
globPaths
,
nil
)
err
:=
c
.
conn
.
Call
(
"Service.SetDataset"
,
globPaths
,
nil
)
return
err
}
// getTask gets a new task from the master server.
func
(
c
*
Client
)
getTask
()
(
Task
,
error
)
{
func
(
c
*
Client
)
getTask
(
passID
int
)
(
Task
,
error
)
{
var
t
Task
err
:=
c
.
conn
.
Call
(
"Service.GetTask"
,
0
,
&
t
)
err
:=
c
.
conn
.
Call
(
"Service.GetTask"
,
passID
,
&
t
)
return
t
,
err
}
...
...
@@ -208,12 +222,6 @@ func (c *Client) taskFailed(meta TaskMeta) error {
// NextRecord will block until the next record is available. It is
// thread-safe.
func
(
c
*
Client
)
NextRecord
()
([]
byte
,
error
)
{
c
.
initChOnce
.
Do
(
func
()
{
// initialize with in case WithBuffer is not used.
c
.
ch
=
make
(
chan
record
,
0
)
go
c
.
getRecords
()
})
r
:=
<-
c
.
ch
return
r
.
r
,
r
.
err
}
...
...
go/master/client_internal_test.go
浏览文件 @
46d766e2
...
...
@@ -54,22 +54,22 @@ func TestGetFinishTask(t *testing.T) {
panic
(
err
)
}
go
func
(
l
net
.
Listener
)
{
s
,
e
rr
:=
NewService
(
&
InMemStore
{},
chunkPerTask
,
time
.
Second
,
1
)
if
e
rr
!=
nil
{
panic
(
e
rr
)
s
,
sE
rr
:=
NewService
(
&
InMemStore
{},
chunkPerTask
,
time
.
Second
,
1
)
if
sE
rr
!=
nil
{
panic
(
sE
rr
)
}
server
:=
rpc
.
NewServer
()
e
rr
=
server
.
Register
(
s
)
if
e
rr
!=
nil
{
panic
(
e
rr
)
sE
rr
=
server
.
Register
(
s
)
if
sE
rr
!=
nil
{
panic
(
sE
rr
)
}
mux
:=
http
.
NewServeMux
()
mux
.
Handle
(
rpc
.
DefaultRPCPath
,
server
)
e
rr
=
http
.
Serve
(
l
,
mux
)
if
e
rr
!=
nil
{
panic
(
e
rr
)
sE
rr
=
http
.
Serve
(
l
,
mux
)
if
sE
rr
!=
nil
{
panic
(
sE
rr
)
}
}(
l
)
...
...
@@ -103,6 +103,7 @@ func TestGetFinishTask(t *testing.T) {
ch
:=
make
(
chan
string
,
1
)
ch
<-
addr
go
c
.
monitorMaster
(
ch
)
err
=
c
.
SetDataset
([]
string
{
path
})
if
err
!=
nil
{
panic
(
err
)
...
...
@@ -111,44 +112,47 @@ func TestGetFinishTask(t *testing.T) {
checkOnePass
:=
func
(
i
int
)
{
var
tasks
[]
Task
for
idx
:=
0
;
idx
<
totalTask
;
idx
++
{
task
,
err
:=
c
.
getTask
(
)
if
err
!=
nil
{
t
.
Fatalf
(
"
Error: %v, pass: %d
\n
"
,
e
rr
,
i
)
task
,
cErr
:=
c
.
getTask
(
i
)
if
cErr
!=
nil
&&
cErr
.
Error
()
!=
ErrNoMoreAvailable
.
Error
()
&&
cErr
.
Error
()
!=
ErrPassAfter
.
Error
()
{
t
.
Fatalf
(
"
error: %v, pass: %d
\n
"
,
cE
rr
,
i
)
}
tasks
=
append
(
tasks
,
task
)
}
_
,
err
=
c
.
getTask
()
if
err
==
nil
{
// getting task before task finishes should return error
_
,
cErr
:=
c
.
getTask
(
i
)
if
cErr
==
nil
{
t
.
Fatalf
(
"Should get error, pass: %d
\n
"
,
i
)
}
e
rr
=
c
.
taskFinished
(
tasks
[
0
]
.
Meta
.
ID
)
if
e
rr
!=
nil
{
t
.
Fatalf
(
"Error: %v, pass: %d
\n
"
,
e
rr
,
i
)
cE
rr
=
c
.
taskFinished
(
tasks
[
0
]
.
Meta
.
ID
)
if
cE
rr
!=
nil
{
t
.
Fatalf
(
"Error: %v, pass: %d
\n
"
,
cE
rr
,
i
)
}
err
=
c
.
taskFailed
(
tasks
[
0
]
.
Meta
)
if
err
!=
nil
{
t
.
Fatalf
(
"Error: %v, pass: %d
\n
"
,
err
,
i
)
// call taskFailed once won't put the task to failed queue, just ensure
// the call
cErr
=
c
.
taskFailed
(
tasks
[
0
]
.
Meta
)
if
cErr
!=
nil
{
t
.
Fatalf
(
"Error: %v, pass: %d
\n
"
,
cErr
,
i
)
}
tasks
=
tasks
[
1
:
]
task
,
err
:=
c
.
getTask
(
)
if
err
!=
nil
{
t
.
Fatal
(
e
rr
)
_
,
cErr
=
c
.
getTask
(
i
)
if
cErr
!=
nil
&&
cErr
.
Error
()
!=
ErrNoMoreAvailable
.
Error
()
&&
cErr
.
Error
()
!=
ErrPassAfter
.
Error
()
{
t
.
Fatal
f
(
"Should be ErrNoMoreAvailable or ErrPassAfter: %s"
,
cE
rr
)
}
tasks
=
append
(
tasks
,
task
)
for
_
,
task
:=
range
tasks
{
e
rr
=
c
.
taskFinished
(
task
.
Meta
.
ID
)
if
e
rr
!=
nil
{
t
.
Fatal
f
(
"Error: %v, pass: %d
\n
"
,
err
,
i
)
cE
rr
=
c
.
taskFinished
(
task
.
Meta
.
ID
)
if
cE
rr
!=
nil
{
t
.
Fatal
(
cErr
)
}
}
}
for
i
:=
0
;
i
<
10
;
i
++
{
// init pass data
c
.
StartGetRecords
(
i
)
checkOnePass
(
i
)
}
}
go/master/client_test.go
浏览文件 @
46d766e2
...
...
@@ -20,8 +20,10 @@ import (
"net/http"
"net/rpc"
"os"
"runtime"
"strconv"
"strings"
"sync"
"testing"
"time"
...
...
@@ -29,6 +31,18 @@ import (
"github.com/PaddlePaddle/recordio"
)
// tool function for testing output goroutine ids
func
goid
()
int
{
var
buf
[
64
]
byte
n
:=
runtime
.
Stack
(
buf
[
:
],
false
)
idField
:=
strings
.
Fields
(
strings
.
TrimPrefix
(
string
(
buf
[
:
n
]),
"goroutine "
))[
0
]
id
,
err
:=
strconv
.
Atoi
(
idField
)
if
err
!=
nil
{
panic
(
fmt
.
Sprintf
(
"cannot get goroutine id: %v"
,
err
))
}
return
id
}
func
TestNextRecord
(
t
*
testing
.
T
)
{
const
(
path
=
"/tmp/master_client_TestFull"
...
...
@@ -45,7 +59,7 @@ func TestNextRecord(t *testing.T) {
panic
(
err
)
}
go
func
(
l
net
.
Listener
)
{
s
,
err
:=
master
.
NewService
(
&
master
.
InMemStore
{},
1
0
,
time
.
Second
,
1
)
s
,
err
:=
master
.
NewService
(
&
master
.
InMemStore
{},
1
,
time
.
Second
*
60
,
1
)
if
err
!=
nil
{
panic
(
err
)
}
...
...
@@ -69,7 +83,7 @@ func TestNextRecord(t *testing.T) {
panic
(
err
)
}
w
:=
recordio
.
NewWriter
(
f
,
-
1
,
-
1
)
w
:=
recordio
.
NewWriter
(
f
,
1
,
-
1
)
for
i
:=
0
;
i
<
total
;
i
++
{
_
,
err
=
w
.
Write
([]
byte
{
byte
(
i
)})
if
err
!=
nil
{
...
...
@@ -87,32 +101,49 @@ func TestNextRecord(t *testing.T) {
panic
(
err
)
}
c
,
err
:=
master
.
NewClient
(
master
.
WithAddr
(
fmt
.
Sprintf
(
":%d"
,
p
)),
master
.
WithBuffer
(
10
))
if
err
!=
nil
{
panic
(
err
)
}
err
=
c
.
SetDataset
([]
string
{
path
})
if
err
!=
nil
{
panic
(
err
)
}
for
pass
:=
0
;
pass
<
50
;
pass
++
{
received
:=
make
(
map
[
byte
]
bool
)
for
i
:=
0
;
i
<
total
;
i
++
{
r
,
err
:=
c
.
NextRecord
()
if
err
!=
nil
{
t
.
Fatal
(
pass
,
i
,
"Read error:"
,
err
)
// start several client to test task fetching
var
wg
sync
.
WaitGroup
for
i
:=
0
;
i
<
4
;
i
++
{
wg
.
Add
(
1
)
// test for multiple concurrent clients
go
func
()
{
defer
wg
.
Done
()
// each go-routine needs a single client connection instance
c
,
e
:=
master
.
NewClient
(
master
.
WithAddr
(
fmt
.
Sprintf
(
":%d"
,
p
)),
master
.
WithBuffer
(
1
))
if
e
!=
nil
{
t
.
Fatal
(
e
)
}
if
len
(
r
)
!=
1
{
t
.
Fatal
(
pass
,
i
,
"Length should be 1."
,
r
)
e
=
c
.
SetDataset
([]
string
{
path
})
if
e
!=
nil
{
panic
(
e
)
}
if
received
[
r
[
0
]]
{
t
.
Fatal
(
pass
,
i
,
"Received duplicate."
,
received
,
r
)
// test for n passes
for
pass
:=
0
;
pass
<
10
;
pass
++
{
c
.
StartGetRecords
(
pass
)
received
:=
make
(
map
[
byte
]
bool
)
taskid
:=
0
for
{
r
,
e
:=
c
.
NextRecord
()
if
e
!=
nil
{
// ErrorPassAfter will wait, else break for next pass
if
e
.
Error
()
==
master
.
ErrPassBefore
.
Error
()
||
e
.
Error
()
==
master
.
ErrNoMoreAvailable
.
Error
()
{
break
}
t
.
Fatal
(
pass
,
taskid
,
"Read error:"
,
e
)
}
if
len
(
r
)
!=
1
{
t
.
Fatal
(
pass
,
taskid
,
"Length should be 1."
,
r
)
}
if
received
[
r
[
0
]]
{
t
.
Fatal
(
pass
,
taskid
,
"Received duplicate."
,
received
,
r
)
}
taskid
++
received
[
r
[
0
]]
=
true
}
}
received
[
r
[
0
]]
=
true
}
}()
}
wg
.
Wait
()
}
go/master/service.go
浏览文件 @
46d766e2
...
...
@@ -19,6 +19,7 @@ import (
"compress/gzip"
"encoding/gob"
"errors"
"math/rand"
"os"
"path/filepath"
"sync"
...
...
@@ -33,6 +34,18 @@ const (
dialTimeout
=
5
*
time
.
Second
)
// ErrAllTaskFailed occur when tasks are in done or failed state.
var
ErrAllTaskFailed
=
errors
.
New
(
"all task finished"
)
// ErrNoMoreAvailable occur when no task in todo and yet not all done or fail.
var
ErrNoMoreAvailable
=
errors
.
New
(
"no more available task"
)
// ErrPassBefore client side pass number does not match with master counter.
var
ErrPassBefore
=
errors
.
New
(
"pass number smaller than master"
)
// ErrPassAfter client side pass number does not match with master counter.
var
ErrPassAfter
=
errors
.
New
(
"pass number larger than master"
)
// Store is the interface for save and load the master state.
type
Store
interface
{
Save
([]
byte
)
error
...
...
@@ -75,17 +88,26 @@ type Service struct {
chunksPerTask
int
timeoutDur
time
.
Duration
failureMax
int
ready
chan
struct
{}
store
Store
mu
sync
.
Mutex
initDone
bool
taskQueues
taskQueues
ready
chan
struct
{}
initDone
bool
mu
sync
.
Mutex
taskQueues
taskQueues
currPass
int
jobTasks
[]
taskEntry
savingTrainer
string
}
func
partition
(
chunks
[]
Chunk
,
chunksPerTask
int
)
[]
taskEntry
{
id
:=
0
// generate uniq id across job using nanosecond + randint + counter
// FIXME(typhoonzero): this is a workaround, use uuid
randStart
:=
rand
.
Int
()
counter
:=
0
timestamp
:=
time
.
Now
()
.
Nanosecond
()
id
:=
timestamp
+
randStart
+
counter
if
chunksPerTask
<=
0
{
chunksPerTask
=
1
}
...
...
@@ -95,7 +117,8 @@ func partition(chunks []Chunk, chunksPerTask int) []taskEntry {
for
i
,
c
:=
range
chunks
{
if
i
%
chunksPerTask
==
0
&&
len
(
cur
.
Task
.
Chunks
)
>
0
{
cur
.
Task
.
Meta
.
ID
=
id
id
++
counter
++
id
=
timestamp
+
randStart
+
counter
result
=
append
(
result
,
cur
)
cur
.
Task
.
Chunks
=
nil
}
...
...
@@ -266,19 +289,21 @@ func (s *Service) SetDataset(globPaths []string, _ *int) error {
return
err
}
s
.
taskQueues
.
Todo
=
partition
(
chunks
,
s
.
chunksPerTask
)
s
.
jobTasks
=
partition
(
chunks
,
s
.
chunksPerTask
)
s
.
taskQueues
.
Todo
=
s
.
jobTasks
err
=
s
.
snapshot
()
if
err
!=
nil
{
log
.
Errorln
(
err
)
return
err
}
close
(
s
.
ready
)
s
.
initDone
=
true
return
nil
}
// processFailedTask retry s.failureMax times for failed task.
// return true if all task are done or failed.
func
(
s
*
Service
)
processFailedTask
(
t
taskEntry
,
epoch
int
)
{
if
t
.
Task
.
Meta
.
Epoch
!=
epoch
{
// new epoch, task launched after the
...
...
@@ -302,8 +327,9 @@ func (s *Service) processFailedTask(t taskEntry, epoch int) {
return
}
log
.
Warningf
(
"Task %v failed %d times,
discard
."
,
t
.
Task
,
t
.
NumFailure
)
log
.
Warningf
(
"Task %v failed %d times,
re-dispatch
."
,
t
.
Task
,
t
.
NumFailure
)
s
.
taskQueues
.
Todo
=
append
(
s
.
taskQueues
.
Todo
,
t
)
return
}
func
(
s
*
Service
)
checkTimeoutFunc
(
taskID
int
,
epoch
int
)
func
()
{
...
...
@@ -331,37 +357,30 @@ func (s *Service) logFields() log.Fields {
}
// GetTask gets a new task from the service.
func
(
s
*
Service
)
GetTask
(
_
int
,
task
*
Task
)
error
{
// passID is the client side pass count
func
(
s
*
Service
)
GetTask
(
passID
int
,
task
*
Task
)
error
{
select
{
case
<-
s
.
ready
:
}
s
.
mu
.
Lock
()
defer
s
.
mu
.
Unlock
()
if
passID
<
s
.
currPass
{
return
ErrPassBefore
}
if
passID
>
s
.
currPass
{
// Client may get run to pass after master when one client faster than the
// other
return
ErrPassAfter
}
if
len
(
s
.
taskQueues
.
Todo
)
==
0
{
if
len
(
s
.
taskQueues
.
Done
)
==
0
{
if
len
(
s
.
taskQueues
.
Pending
)
==
0
{
err
:=
errors
.
New
(
"all task failed"
)
log
.
WithFields
(
s
.
logFields
())
.
Warningln
(
"All tasks failed."
)
return
err
}
// TODO(helin): client need to retry in this
// error case. Gotcha: RPC client can't
// compare returned error with predefined
// errors like io.EOF, because the error
// instance deserialized from RPC is a
// different instance than the error defined
// in package. So we need to figure out a way
// for client to check this error correctly.
err
:=
errors
.
New
(
"no more available task"
)
log
.
WithFields
(
s
.
logFields
())
.
Warningln
(
"No more available task."
)
return
err
if
len
(
s
.
taskQueues
.
Done
)
==
0
&&
len
(
s
.
taskQueues
.
Pending
)
==
0
{
log
.
WithFields
(
s
.
logFields
())
.
Warningln
(
"All tasks failed, may start next pass"
)
return
ErrAllTaskFailed
}
s
.
taskQueues
.
Todo
=
s
.
taskQueues
.
Done
s
.
taskQueues
.
Done
=
nil
log
.
WithFields
(
s
.
logFields
())
.
Infoln
(
"No more todo task, but trainer is requesting task to do. Move all done task to todo."
)
log
.
WithFields
(
s
.
logFields
())
.
Warningln
(
"No more available task."
)
return
ErrNoMoreAvailable
}
t
:=
s
.
taskQueues
.
Todo
[
0
]
...
...
@@ -381,7 +400,7 @@ func (s *Service) GetTask(_ int, task *Task) error {
}
// TaskFinished tell the service that a task is finished.
func
(
s
*
Service
)
TaskFinished
(
taskID
int
,
_
*
int
)
error
{
func
(
s
*
Service
)
TaskFinished
(
taskID
int
,
dummy
*
int
)
error
{
select
{
case
<-
s
.
ready
:
}
...
...
@@ -401,11 +420,14 @@ func (s *Service) TaskFinished(taskID int, _ *int) error {
delete
(
s
.
taskQueues
.
Pending
,
taskID
)
log
.
WithFields
(
s
.
logFields
())
.
Infof
(
"Task #%d finished."
,
taskID
)
if
len
(
s
.
taskQueues
.
Pending
)
==
0
&&
len
(
s
.
taskQueues
.
Todo
)
==
0
{
log
.
WithFields
(
s
.
logFields
())
.
Infoln
(
"No more todo and pending task, start a new pass."
)
s
.
taskQueues
.
Todo
=
append
(
s
.
taskQueues
.
Todo
,
s
.
taskQueues
.
Done
...
)
s
.
taskQueues
.
Done
=
nil
if
len
(
s
.
taskQueues
.
Todo
)
==
0
&&
len
(
s
.
taskQueues
.
Pending
)
==
0
{
// increase master side pass count if all tasks finished
s
.
currPass
++
s
.
taskQueues
.
Todo
=
s
.
jobTasks
s
.
taskQueues
.
Done
=
[]
taskEntry
{}
// TODO(typhoonzero): deal with failed tasks
s
.
taskQueues
.
Failed
=
[]
taskEntry
{}
log
.
WithFields
(
s
.
logFields
())
.
Warningf
(
"all task finished, add new pass data, newpass: %d."
,
s
.
currPass
)
}
err
:=
s
.
snapshot
()
...
...
@@ -416,7 +438,7 @@ func (s *Service) TaskFinished(taskID int, _ *int) error {
}
// TaskFailed tells the service that a task is failed.
func
(
s
*
Service
)
TaskFailed
(
meta
TaskMeta
,
_
*
int
)
error
{
func
(
s
*
Service
)
TaskFailed
(
meta
TaskMeta
,
dummy
*
int
)
error
{
select
{
case
<-
s
.
ready
:
}
...
...
go/master/service_internal_test.go
浏览文件 @
46d766e2
...
...
@@ -44,7 +44,8 @@ func TestPartionIndex(t *testing.T) {
cs
:=
make
([]
Chunk
,
100
)
ts
:=
partition
(
cs
,
20
)
for
i
:=
range
ts
{
if
ts
[
i
]
.
Task
.
Meta
.
ID
!=
i
{
// test auto increament ids
if
i
>
0
&&
ts
[
i
]
.
Task
.
Meta
.
ID
!=
ts
[
i
-
1
]
.
Task
.
Meta
.
ID
+
1
{
t
.
Error
(
ts
[
i
],
i
)
}
}
...
...
go/pserver/client/c/test/test_train.py
浏览文件 @
46d766e2
...
...
@@ -6,16 +6,19 @@ import cPickle as pickle
etcd_ip
=
os
.
getenv
(
"MASTER_IP"
,
"127.0.0.1"
)
etcd_endpoint
=
"http://"
+
etcd_ip
+
":2379"
print
"connecting to master, etcd endpoints: "
,
etcd_endpoint
master_client
=
master
.
client
(
etcd_endpoint
,
5
,
64
)
def
cloud_reader
():
print
"connecting to master, etcd endpoints: "
,
etcd_endpoint
master_client
=
master
.
client
(
etcd_endpoint
,
5
,
64
)
global
master_client
master_client
.
set_dataset
(
[
"/pfs/dlnel/public/dataset/uci_housing/uci_housing-*
-of-*"
]
)
[
"/pfs/dlnel/public/dataset/uci_housing/uci_housing-*
"
],
passes
=
30
)
while
1
:
r
,
e
=
master_client
.
next_record
()
if
not
r
:
if
e
!=
-
2
:
# other errors
print
"get record error:"
,
e
break
yield
pickle
.
loads
(
r
)
...
...
@@ -27,10 +30,12 @@ def main():
# network config
x
=
paddle
.
layer
.
data
(
name
=
'x'
,
type
=
paddle
.
data_type
.
dense_vector
(
13
))
y_predict
=
paddle
.
layer
.
fc
(
input
=
x
,
param_attr
=
paddle
.
attr
.
Param
(
name
=
'w'
),
param_attr
=
paddle
.
attr
.
Param
(
name
=
'w'
,
learning_rate
=
1e-3
),
size
=
1
,
act
=
paddle
.
activation
.
Linear
(),
bias_attr
=
paddle
.
attr
.
Param
(
name
=
'b'
))
bias_attr
=
paddle
.
attr
.
Param
(
name
=
'b'
,
learning_rate
=
1e-3
))
y
=
paddle
.
layer
.
data
(
name
=
'y'
,
type
=
paddle
.
data_type
.
dense_vector
(
1
))
cost
=
paddle
.
layer
.
mse_cost
(
input
=
y_predict
,
label
=
y
)
...
...
@@ -38,9 +43,8 @@ def main():
parameters
=
paddle
.
parameters
.
create
(
cost
)
# create optimizer of new remote updater to pserver
optimizer
=
paddle
.
optimizer
.
Momentum
(
momentum
=
0
)
optimizer
=
paddle
.
optimizer
.
Momentum
(
momentum
=
0
,
learning_rate
=
1e-3
)
print
"etcd endoint: "
,
etcd_endpoint
trainer
=
paddle
.
trainer
.
SGD
(
cost
=
cost
,
parameters
=
parameters
,
update_equation
=
optimizer
,
...
...
@@ -51,6 +55,8 @@ def main():
# event_handler to print training and testing info
def
event_handler
(
event
):
if
isinstance
(
event
,
paddle
.
event
.
EndIteration
):
# FIXME: for cloud data reader, pass number is managed by master
# should print the server side pass number
if
event
.
batch_id
%
100
==
0
:
print
"Pass %d, Batch %d, Cost %f"
%
(
event
.
pass_id
,
event
.
batch_id
,
event
.
cost
)
...
...
paddle/api/Evaluator.cpp
浏览文件 @
46d766e2
...
...
@@ -37,7 +37,7 @@ std::vector<std::string> Evaluator::getNames() const {
double
Evaluator
::
getValue
(
const
std
::
string
name
)
const
{
paddle
::
Error
err
;
double
v
=
m
->
rawPtr
->
getValue
(
name
,
&
err
);
if
(
err
)
{
if
(
!
err
.
isOK
()
)
{
throw
std
::
runtime_error
(
err
.
msg
());
}
return
v
;
...
...
paddle/framework/CMakeLists.txt
浏览文件 @
46d766e2
...
...
@@ -3,7 +3,7 @@ cc_library(ddim SRCS ddim.cc DEPS eigen3)
cc_test
(
ddim_test SRCS ddim_test.cc DEPS ddim
)
nv_test
(
dim_test SRCS dim_test.cu DEPS ddim
)
cc_library
(
tensor SRCS tensor.cc DEPS ddim place paddle_memory
)
cc_library
(
tensor SRCS tensor.cc DEPS ddim place paddle_memory
device_context
)
cc_test
(
tensor_test SRCS tensor_test.cc DEPS tensor
)
cc_test
(
eigen_test SRCS eigen_test.cc DEPS tensor
)
...
...
paddle/framework/detail/tensor-inl.h
0 → 100644
浏览文件 @
46d766e2
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/memory/memcpy.h"
namespace
paddle
{
namespace
framework
{
template
<
typename
T
>
inline
void
Tensor
::
check_memory_size
()
const
{
PADDLE_ENFORCE
(
holder_
!=
nullptr
,
"Tenosr holds no memory. Call Tensor::mutable_data first."
);
PADDLE_ENFORCE
(
holder_
->
size
()
>=
product
(
dims_
)
*
sizeof
(
T
)
+
offset_
,
"Tensor's dims_ is out of bound. Call Tensor::mutable_data "
"first to re-allocate memory."
);
}
template
<
typename
T
>
inline
const
T
*
Tensor
::
data
()
const
{
check_memory_size
<
T
>
();
return
reinterpret_cast
<
const
T
*>
(
reinterpret_cast
<
uintptr_t
>
(
holder_
->
ptr
())
+
offset_
);
}
template
<
typename
T
>
inline
T
*
Tensor
::
data
()
{
check_memory_size
<
T
>
();
return
reinterpret_cast
<
T
*>
(
reinterpret_cast
<
uintptr_t
>
(
holder_
->
ptr
())
+
offset_
);
}
template
<
typename
T
>
inline
T
*
Tensor
::
mutable_data
(
DDim
dims
,
platform
::
Place
place
)
{
static_assert
(
std
::
is_pod
<
T
>::
value
,
"T must be POD"
);
Resize
(
dims
);
return
mutable_data
<
T
>
(
place
);
}
template
<
typename
T
>
inline
T
*
Tensor
::
mutable_data
(
platform
::
Place
place
)
{
static_assert
(
std
::
is_pod
<
T
>::
value
,
"T must be POD"
);
PADDLE_ENFORCE
(
product
(
dims_
)
>
0
,
"Tensor's numel must be larger than zero to call "
"Tensor::mutable_data. Call Tensor::set_dim first."
);
/* some versions of boost::variant don't have operator!= */
size_t
size
=
product
(
dims_
)
*
sizeof
(
T
);
if
(
holder_
==
nullptr
||
!
(
holder_
->
place
()
==
place
)
||
holder_
->
size
()
<
size
+
offset_
)
{
if
(
platform
::
is_cpu_place
(
place
))
{
holder_
.
reset
(
new
PlaceholderImpl
<
T
,
platform
::
CPUPlace
>
(
boost
::
get
<
platform
::
CPUPlace
>
(
place
),
size
));
}
#ifndef PADDLE_ONLY_CPU
else
if
(
platform
::
is_gpu_place
(
place
))
{
holder_
.
reset
(
new
PlaceholderImpl
<
T
,
platform
::
GPUPlace
>
(
boost
::
get
<
platform
::
GPUPlace
>
(
place
),
size
));
}
#endif
offset_
=
0
;
}
return
reinterpret_cast
<
T
*>
(
reinterpret_cast
<
uintptr_t
>
(
holder_
->
ptr
())
+
offset_
);
}
template
<
typename
T
>
inline
void
Tensor
::
ShareDataWith
(
const
Tensor
&
src
)
{
src
.
check_memory_size
<
T
>
();
*
this
=
src
;
}
template
<
typename
T
>
inline
void
Tensor
::
CopyFrom
(
const
Tensor
&
src
,
const
platform
::
CPUDeviceContext
&
ctx
)
{
src
.
check_memory_size
<
T
>
();
Resize
(
src
.
dims
());
auto
src_place
=
src
.
holder_
->
place
();
auto
src_ptr
=
static_cast
<
const
void
*>
(
src
.
data
<
T
>
());
auto
dst_place
=
ctx
.
GetPlace
();
auto
dst_ptr
=
static_cast
<
void
*>
(
mutable_data
<
T
>
(
dst_place
));
auto
size
=
product
(
src
.
dims_
)
*
sizeof
(
T
);
if
(
platform
::
is_cpu_place
(
src_place
))
{
memory
::
Copy
(
boost
::
get
<
platform
::
CPUPlace
>
(
dst_place
),
dst_ptr
,
boost
::
get
<
platform
::
CPUPlace
>
(
src_place
),
src_ptr
,
size
);
}
#ifndef PADDLE_ONLY_CPU
else
if
(
platform
::
is_gpu_place
(
src_place
))
{
memory
::
Copy
(
boost
::
get
<
platform
::
CPUPlace
>
(
dst_place
),
dst_ptr
,
boost
::
get
<
platform
::
GPUPlace
>
(
src_place
),
src_ptr
,
size
,
0
);
}
#endif
}
#ifndef PADDLE_ONLY_CPU
template
<
typename
T
>
inline
void
Tensor
::
CopyFrom
(
const
Tensor
&
src
,
const
platform
::
CUDADeviceContext
&
ctx
)
{
src
.
check_memory_size
<
T
>
();
Resize
(
src
.
dims
());
auto
src_place
=
src
.
holder_
->
place
();
auto
src_ptr
=
static_cast
<
const
void
*>
(
src
.
data
<
T
>
());
auto
dst_place
=
ctx
.
GetPlace
();
auto
dst_ptr
=
static_cast
<
void
*>
(
mutable_data
<
T
>
(
dst_place
));
auto
size
=
product
(
src
.
dims_
)
*
sizeof
(
T
);
if
(
platform
::
is_cpu_place
(
src_place
))
{
memory
::
Copy
(
boost
::
get
<
platform
::
GPUPlace
>
(
dst_place
),
dst_ptr
,
boost
::
get
<
platform
::
CPUPlace
>
(
src_place
),
src_ptr
,
size
,
ctx
.
stream
());
}
else
if
(
platform
::
is_gpu_place
(
src_place
))
{
memory
::
Copy
(
boost
::
get
<
platform
::
GPUPlace
>
(
dst_place
),
dst_ptr
,
boost
::
get
<
platform
::
GPUPlace
>
(
src_place
),
src_ptr
,
size
,
ctx
.
stream
());
}
}
#endif
template
<
typename
T
>
inline
Tensor
Tensor
::
Slice
(
const
int
&
begin_idx
,
const
int
&
end_idx
)
const
{
check_memory_size
<
T
>
();
PADDLE_ENFORCE
(
begin_idx
>=
0
,
"Slice begin index is less than zero."
);
PADDLE_ENFORCE
(
end_idx
<=
dims_
[
0
],
"Slice end index is out of bound."
);
PADDLE_ENFORCE
(
begin_idx
<
end_idx
,
"Begin index must be less than end index."
);
PADDLE_ENFORCE
(
dims_
[
0
]
!=
1
,
"Can not slice a tensor with dims_[0] = 1."
);
int
base
=
product
(
dims_
)
/
dims_
[
0
];
Tensor
dst
;
dst
.
holder_
=
holder_
;
DDim
dst_dims
=
dims_
;
dst_dims
[
0
]
=
end_idx
-
begin_idx
;
dst
.
Resize
(
dst_dims
);
dst
.
offset_
=
offset_
+
begin_idx
*
base
*
sizeof
(
T
);
return
dst
;
}
inline
void
Tensor
::
Resize
(
const
DDim
&
dims
)
{
dims_
=
dims
;
}
inline
const
DDim
&
Tensor
::
dims
()
const
{
return
dims_
;
}
}
// namespace framework
}
// namespace paddle
paddle/framework/net.h
浏览文件 @
46d766e2
...
...
@@ -97,9 +97,5 @@ class NetOp : public OperatorBase {
}
};
/**
* @brief Identify operator in local Net. used in backward
*/
}
// namespace framework
}
// namespace paddle
paddle/framework/op_registry.h
浏览文件 @
46d766e2
...
...
@@ -407,15 +407,16 @@ class GradOpRegisterHelper {
STATIC_ASSERT_GLOBAL_NAMESPACE( \
__reg_op_kernel_##type##_##DEVICE_TYPE##__, \
"REGISTER_OP_KERNEL must be in global namespace"); \
struct __op_kernel_register__##type##__
{
\
__op_kernel_register__##type##__
() {
\
struct __op_kernel_register__##type##__
##DEVICE_TYPE##__ {
\
__op_kernel_register__##type##__
##DEVICE_TYPE##__() {
\
::paddle::framework::OperatorWithKernel::OpKernelKey key; \
key.place_ = PlaceType(); \
::paddle::framework::OperatorWithKernel::AllOpKernels()[#type][key] \
.reset(new __VA_ARGS__()); \
} \
}; \
static __op_kernel_register__##type##__ __reg_kernel_##type##__; \
static __op_kernel_register__##type##__##DEVICE_TYPE##__ \
__reg_kernel_##type##__##DEVICE_TYPE##__; \
int __op_kernel_register_##type##_handle_##DEVICE_TYPE##__() { return 0; }
// (type, KernelType)
...
...
paddle/framework/operator.cc
浏览文件 @
46d766e2
...
...
@@ -34,22 +34,26 @@ KernelContext::GetEigenDevice<platform::GPUPlace, Eigen::GpuDevice>() const {
#endif
const
std
::
string
&
OperatorBase
::
Input
(
const
std
::
string
&
name
)
const
{
PADDLE_ENFORCE
(
in_out_idxs_
!=
nullptr
,
"Input Output Indices could not be nullptr"
);
auto
it
=
in_out_idxs_
->
find
(
name
);
PADDLE_ENFORCE
(
it
!=
in_out_idxs_
->
end
(),
"no key [%s] in in_out_idxs_"
,
name
);
if
(
attrs_
.
count
(
"input_format"
)
==
0
)
{
return
inputs_
[
it
->
second
]
;
return
inputs_
.
at
((
size_t
)
it
->
second
)
;
}
else
{
const
auto
&
input_format
=
GetAttr
<
std
::
vector
<
int
>>
(
"input_format"
);
int
idx
=
input_format
[
it
->
second
];
return
inputs_
.
at
(
idx
);
return
inputs_
.
at
(
(
size_t
)
idx
);
}
}
std
::
vector
<
std
::
string
>
OperatorBase
::
Inputs
(
const
std
::
string
&
name
)
const
{
PADDLE_ENFORCE
(
in_out_idxs_
!=
nullptr
,
"IO Idx could not be nullptr"
);
auto
input_format
=
GetAttr
<
std
::
vector
<
int
>>
(
"input_format"
);
auto
offset
=
in_out_idxs_
->
at
(
name
);
PADDLE_ENFORCE
(
input_format
.
at
((
size_t
)
offset
+
1
)
<=
inputs_
.
size
(),
"Input Out Of Range"
);
return
std
::
vector
<
std
::
string
>
{
inputs_
.
begin
()
+
input_format
.
at
(
offset
),
...
...
@@ -57,23 +61,25 @@ std::vector<std::string> OperatorBase::Inputs(const std::string& name) const {
}
const
std
::
string
&
OperatorBase
::
Output
(
const
std
::
string
&
name
)
const
{
PADDLE_ENFORCE
(
in_out_idxs_
!=
nullptr
,
"InOut Indice could not be nullptr"
);
auto
it
=
in_out_idxs_
->
find
(
name
);
PADDLE_ENFORCE
(
it
!=
in_out_idxs_
->
end
(),
"no key [%s] in in_out_idxs_"
,
name
);
if
(
attrs_
.
count
(
"output_format"
)
==
0
)
{
return
outputs_
[
it
->
second
]
;
return
outputs_
.
at
((
size_t
)
it
->
second
)
;
}
else
{
const
auto
&
output_format
=
GetAttr
<
std
::
vector
<
int
>>
(
"output_format"
);
int
idx
=
output_format
[
it
->
second
];
return
outputs_
.
at
(
idx
);
return
outputs_
.
at
(
(
size_t
)
idx
);
}
}
std
::
vector
<
std
::
string
>
OperatorBase
::
Outputs
(
const
std
::
string
&
name
)
const
{
PADDLE_ENFORCE
(
in_out_idxs_
!=
nullptr
,
"InOut Indice could not be nullptr"
);
auto
output_format
=
GetAttr
<
std
::
vector
<
int
>>
(
"output_format"
);
auto
offset
=
in_out_idxs_
->
at
(
name
);
PADDLE_ENFORCE
(
output_format
.
at
((
size_t
)
offset
+
1
)
<=
outputs_
.
size
(),
"Output Out of Range"
);
return
std
::
vector
<
std
::
string
>
{
outputs_
.
begin
()
+
output_format
.
at
(
offset
),
outputs_
.
begin
()
+
output_format
.
at
(
offset
+
1
)};
...
...
paddle/framework/operator.h
浏览文件 @
46d766e2
...
...
@@ -214,7 +214,9 @@ class OperatorWithKernel : public OperatorBase {
place_
=
dev_ctx
.
GetPlace
();
}
bool
operator
==
(
const
OpKernelKey
&
o
)
const
{
return
place_
==
o
.
place_
;
}
bool
operator
==
(
const
OpKernelKey
&
o
)
const
{
return
platform
::
places_are_same_class
(
place_
,
o
.
place_
);
}
};
struct
OpKernelHash
{
...
...
paddle/framework/tensor.cc
浏览文件 @
46d766e2
...
...
@@ -12,7 +12,7 @@
See the License for the specific language governing permissions and
limitations under the License. */
#include
<paddle/framework/tensor.h>
#include
"paddle/framework/tensor.h"
namespace
paddle
{
namespace
framework
{}
...
...
paddle/framework/tensor.h
浏览文件 @
46d766e2
...
...
@@ -20,6 +20,7 @@ limitations under the License. */
#include <typeindex>
#include "paddle/framework/ddim.h"
#include "paddle/memory/memory.h"
#include "paddle/platform/device_context.h"
#include "paddle/platform/enforce.h"
#include "paddle/platform/place.h"
#include "unsupported/Eigen/CXX11/Tensor"
...
...
@@ -31,9 +32,11 @@ template <bool less, size_t i, typename... args>
struct
CastToPyBufferImpl
;
}
// namespace details
}
// namespace pybind
namespace
framework
{
class
Tensor
{
public:
template
<
bool
less
,
size_t
i
,
typename
...
args
>
friend
struct
paddle
::
pybind
::
details
::
CastToPyBufferImpl
;
...
...
@@ -46,106 +49,84 @@ class Tensor {
public:
Tensor
()
:
offset_
(
0
)
{}
/*! Return a pointer to mutable memory block. */
template
<
typename
T
>
const
T
*
data
()
const
{
EnforceSufficientMemory
<
T
>
();
return
reinterpret_cast
<
const
T
*>
(
reinterpret_cast
<
uintptr_t
>
(
holder_
->
ptr
())
+
offset_
);
}
inline
T
*
data
();
/*! Return a pointer to constant memory block. */
template
<
typename
T
>
T
*
data
()
{
EnforceSufficientMemory
<
T
>
();
return
reinterpret_cast
<
T
*>
(
reinterpret_cast
<
uintptr_t
>
(
holder_
->
ptr
())
+
offset_
);
}
template
<
typename
T
,
// must be POD types
typename
std
::
enable_if
<
std
::
is_pod
<
T
>
::
value
>::
type
*
=
nullptr
>
T
*
mutable_data
(
DDim
dims
,
platform
::
Place
place
)
{
Resize
(
dims
);
return
mutable_data
<
T
>
(
place
);
}
template
<
typename
T
,
// must be POD types
typename
std
::
enable_if
<
std
::
is_pod
<
T
>
::
value
>::
type
*
=
nullptr
>
T
*
mutable_data
(
platform
::
Place
place
)
{
PADDLE_ENFORCE
(
product
(
dims_
)
>
0
,
"Tensor's numel must be larger than zero to call "
"Tensor::mutable_data. Call Tensor::set_dim first."
);
if
(
holder_
==
nullptr
||
!
(
holder_
->
place
()
==
place
)
/* some versions of boost::variant don't have operator!= */
||
holder_
->
size
()
<
product
(
dims_
)
*
sizeof
(
T
)
+
offset_
)
{
if
(
platform
::
is_cpu_place
(
place
))
{
holder_
.
reset
(
new
PlaceholderImpl
<
T
,
platform
::
CPUPlace
>
(
boost
::
get
<
platform
::
CPUPlace
>
(
place
),
product
(
dims_
)
*
sizeof
(
T
)));
}
else
if
(
platform
::
is_gpu_place
(
place
))
{
#ifdef PADDLE_ONLY_CPU
PADDLE_THROW
(
"'GPUPlace' is not supported in CPU only device."
);
#else
holder_
.
reset
(
new
PlaceholderImpl
<
T
,
platform
::
GPUPlace
>
(
boost
::
get
<
platform
::
GPUPlace
>
(
place
),
product
(
dims_
)
*
sizeof
(
T
)));
#endif
}
else
{
PADDLE_THROW
(
"Unknown 'place'."
);
}
offset_
=
0
;
}
return
reinterpret_cast
<
T
*>
(
reinterpret_cast
<
uintptr_t
>
(
holder_
->
ptr
())
+
offset_
);
}
inline
const
T
*
data
()
const
;
/**
* @brief Return a pointer to mutable memory block.
* @note If not exist, then allocation.
*/
template
<
typename
T
>
inline
T
*
mutable_data
(
platform
::
Place
place
);
/**
* @brief Return a pointer to mutable memory block.
*
* @param[in] dims The dimensions of the memory block.
* @param[in] place The place of the memory block.
*
* @note If not exist, then allocation.
*/
template
<
typename
T
>
inline
T
*
mutable_data
(
DDim
dims
,
platform
::
Place
place
);
/*! Return the dimensions of the memory block. */
inline
const
DDim
&
dims
()
const
;
/*! Resize the dimensions of the memory block. */
inline
void
Resize
(
const
DDim
&
dims
);
/*! The internal of two tensors share the same memory block. */
template
<
typename
T
>
inline
void
ShareDataWith
(
const
Tensor
&
src
);
/**
* @brief Copy the content of external tensor to a new place.
*
* @param[in] src The external tensor.
* @param[in] ctx The device context contains place where to store.
*
* @note CopyFrom supports CPU <-> GPU, GPU <-> GPU.
*/
template
<
typename
T
>
void
ShareDataWith
(
const
Tensor
&
src
)
{
src
.
EnforceSufficientMemory
<
T
>
();
*
this
=
src
;
}
inline
void
CopyFrom
(
const
Tensor
&
src
,
const
platform
::
CPUDeviceContext
&
ctx
);
#ifndef PADDLE_ONLY_CPU
template
<
typename
T
>
void
CopyFrom
(
const
Tensor
&
src
,
platform
::
Place
dst_place
)
{
PADDLE_ENFORCE
(
platform
::
is_cpu_place
(
src
.
holder_
->
place
())
&&
platform
::
is_cpu_place
(
dst_place
),
"Tensor::CopyFrom only support CPU now."
);
src
.
EnforceSufficientMemory
<
T
>
();
size_t
size
=
product
(
src
.
dims_
)
*
sizeof
(
T
);
Resize
(
src
.
dims
());
const
void
*
src_ptr
=
static_cast
<
const
void
*>
(
src
.
data
<
T
>
());
void
*
dst_ptr
=
static_cast
<
void
*>
(
mutable_data
<
T
>
(
dst_place
));
memcpy
(
dst_ptr
,
src_ptr
,
size
);
}
inline
void
CopyFrom
(
const
Tensor
&
src
,
const
platform
::
CUDADeviceContext
&
ctx
);
#endif
/**
* @brief Return the slice of the tensor.
*
* @param[in] begin_idx The begin index of the slice.
* @param[in] end_idx The end index of the slice.
*/
template
<
typename
T
>
Tensor
Slice
(
const
int
&
begin_idx
,
const
int
&
end_idx
)
const
{
EnforceSufficientMemory
<
T
>
();
PADDLE_ENFORCE
(
begin_idx
>=
0
,
"Slice begin index is less than zero."
);
PADDLE_ENFORCE
(
end_idx
<=
dims_
[
0
],
"Slice end index is out of bound."
);
PADDLE_ENFORCE
(
begin_idx
<
end_idx
,
"Begin index must be less than end index."
);
PADDLE_ENFORCE
(
dims_
[
0
]
!=
1
,
"Can not slice a tensor with dims_[0] = 1."
);
int
base
=
product
(
dims_
)
/
dims_
[
0
];
Tensor
dst
;
dst
.
holder_
=
holder_
;
DDim
dst_dims
=
dims_
;
dst_dims
[
0
]
=
end_idx
-
begin_idx
;
dst
.
Resize
(
dst_dims
);
dst
.
offset_
=
offset_
+
begin_idx
*
base
*
sizeof
(
T
);
return
dst
;
}
void
Resize
(
const
DDim
&
dims
)
{
dims_
=
dims
;
}
const
DDim
&
dims
()
const
{
return
dims_
;
}
inline
Tensor
Slice
(
const
int
&
begin_idx
,
const
int
&
end_idx
)
const
;
private:
// Placeholder hides type T, so it doesn't appear as a template
// parameter of Variable.
template
<
typename
T
>
inline
void
check_memory_size
()
const
;
private:
/**
* @note Placeholder hides type T, so it doesn't appear as a template
* parameter of Variable.
*/
struct
Placeholder
{
virtual
~
Placeholder
()
{}
virtual
void
*
ptr
()
const
=
0
;
virtual
platform
::
Place
place
()
const
=
0
;
virtual
size_t
size
()
const
=
0
;
virtual
std
::
type_index
type
()
const
=
0
;
virtual
platform
::
Place
place
()
const
=
0
;
};
template
<
typename
T
,
typename
PlaceType
>
...
...
@@ -156,33 +137,38 @@ class Tensor {
place_
(
place
),
size_
(
size
)
{}
virtual
void
*
ptr
()
const
{
return
static_cast
<
void
*>
(
ptr_
.
get
());
}
virtual
size_t
size
()
const
{
return
size_
;
}
virtual
paddle
::
platform
::
Place
place
()
const
{
return
place_
;
}
virtual
platform
::
Place
place
()
const
{
return
place_
;
}
virtual
void
*
ptr
()
const
{
return
static_cast
<
void
*>
(
ptr_
.
get
());
}
virtual
std
::
type_index
type
()
const
{
return
std
::
type_index
(
typeid
(
T
));
}
/*! the pointer of memory block. */
std
::
unique_ptr
<
T
,
memory
::
PODDeleter
<
T
,
PlaceType
>>
ptr_
;
platform
::
Place
place_
;
// record the place of ptr_.
size_t
size_
;
// size of the memory block.
/*! the place of memory block. */
platform
::
Place
place_
;
/*! the size of memory block. */
size_t
size_
;
};
template
<
typename
T
>
inline
void
EnforceSufficientMemory
()
const
{
PADDLE_ENFORCE
(
holder_
!=
nullptr
,
"Tenosr holds no memory. Call Tensor::mutable_data first."
);
PADDLE_ENFORCE
(
holder_
->
size
()
>=
product
(
dims_
)
*
sizeof
(
T
)
+
offset_
,
"Tensor's dims_ is out of bound. Call Tensor::mutable_data "
"first to re-allocate memory."
);
}
std
::
shared_ptr
<
Placeholder
>
holder_
;
// holds the memory block if allocated.
/*! holds the memory block if allocated. */
std
::
shared_ptr
<
Placeholder
>
holder_
;
/*! points to dimensions of memory block. */
DDim
dims_
;
// A PlaceHolder may be shared by more than one tensor. Some of them may be
// slices of the others. So the offset_ is introduced here to indicate the
// byte offset between PlaceHolder::ptr_ and where tensor's data really
// begins.
/**
* @brief A PlaceHolder may be shared by more than one tensor.
*
* @note Some of them may be slices of the others. So the offset_
* is introduced here to indicate the byte offset between
* PlaceHolder::ptr_ and where the tensor data really begins.
*/
size_t
offset_
;
};
}
// namespace framework
}
// namespace paddle
#include "paddle/framework/detail/tensor-inl.h"
paddle/framework/tensor_test.cc
浏览文件 @
46d766e2
...
...
@@ -72,7 +72,8 @@ TEST(Tensor, MutableData) {
p2
=
src_tensor
.
mutable_data
<
float
>
(
make_ddim
({
2
,
2
}),
CPUPlace
());
EXPECT_EQ
(
p1
,
p2
);
}
#ifdef __CUDACC__
#ifndef PADDLE_ONLY_CPU
{
Tensor
src_tensor
;
float
*
p1
=
nullptr
;
...
...
@@ -123,7 +124,7 @@ TEST(Tensor, ShareDataWith) {
ASSERT_EQ
(
src_tensor
.
data
<
int
>
(),
dst_tensor
.
data
<
int
>
());
}
#if
def __CUDACC__
#if
ndef PADDLE_ONLY_CPU
{
Tensor
src_tensor
;
Tensor
dst_tensor
;
...
...
@@ -160,7 +161,7 @@ TEST(Tensor, Slice) {
EXPECT_EQ
(
src_data_address
+
3
*
4
*
1
*
sizeof
(
int
),
slice_data_address
);
}
#if
def __CUDACC__
#if
ndef PADDLE_ONLY_CPU
{
Tensor
src_tensor
;
src_tensor
.
mutable_data
<
double
>
(
make_ddim
({
6
,
9
}),
GPUPlace
());
...
...
@@ -188,25 +189,74 @@ TEST(Tensor, Slice) {
TEST
(
Tensor
,
CopyFrom
)
{
using
namespace
paddle
::
framework
;
using
namespace
paddle
::
platform
;
{
Tensor
src_tensor
;
Tensor
dst_tensor
;
int
*
src_ptr
=
src_tensor
.
mutable_data
<
int
>
(
make_ddim
({
3
,
3
}),
CPUPlace
());
int
arr
[
9
]
=
{
1
,
2
,
3
,
4
,
5
,
6
,
7
,
8
,
9
};
memcpy
(
src_ptr
,
arr
,
9
*
sizeof
(
int
));
Tensor
src_tensor
;
int
*
src_ptr
=
src_tensor
.
mutable_data
<
int
>
(
make_ddim
({
3
,
3
}),
CPUPlace
());
int
arr
[
9
]
=
{
1
,
2
,
3
,
4
,
5
,
6
,
7
,
8
,
9
};
memcpy
(
src_ptr
,
arr
,
9
*
sizeof
(
int
));
Tensor
dst_tensor
;
dst_tensor
.
CopyFrom
<
int
>
(
src_tensor
,
CPUPlace
());
const
int
*
dst_ptr
=
dst_tensor
.
data
<
int
>
();
ASSERT_NE
(
src_ptr
,
dst_ptr
);
for
(
size_t
i
=
0
;
i
<
9
;
++
i
)
{
EXPECT_EQ
(
src_ptr
[
i
],
dst_ptr
[
i
]);
auto
*
cpu_ctx
=
new
paddle
::
platform
::
CPUDeviceContext
();
dst_tensor
.
CopyFrom
<
int
>
(
src_tensor
,
*
cpu_ctx
);
const
int
*
dst_ptr
=
dst_tensor
.
data
<
int
>
();
ASSERT_NE
(
src_ptr
,
dst_ptr
);
for
(
size_t
i
=
0
;
i
<
9
;
++
i
)
{
EXPECT_EQ
(
src_ptr
[
i
],
dst_ptr
[
i
]);
}
Tensor
slice_tensor
=
src_tensor
.
Slice
<
int
>
(
1
,
2
);
dst_tensor
.
CopyFrom
<
int
>
(
slice_tensor
,
*
cpu_ctx
);
const
int
*
slice_ptr
=
slice_tensor
.
data
<
int
>
();
dst_ptr
=
dst_tensor
.
data
<
int
>
();
ASSERT_NE
(
dst_ptr
,
slice_ptr
);
for
(
size_t
i
=
0
;
i
<
3
;
++
i
)
{
EXPECT_EQ
(
dst_ptr
[
i
],
slice_ptr
[
i
]);
}
}
#ifndef PADDLE_ONLY_CPU
{
Tensor
src_tensor
;
Tensor
gpu_tensor
;
Tensor
dst_tensor
;
int
*
src_ptr
=
src_tensor
.
mutable_data
<
int
>
(
make_ddim
({
3
,
3
}),
CPUPlace
());
int
arr
[
9
]
=
{
1
,
2
,
3
,
4
,
5
,
6
,
7
,
8
,
9
};
memcpy
(
src_ptr
,
arr
,
9
*
sizeof
(
int
));
// CPU Tensor to GPU Tensor
auto
gpu_ctx
=
new
paddle
::
platform
::
CUDADeviceContext
(
0
);
gpu_tensor
.
CopyFrom
<
int
>
(
src_tensor
,
*
gpu_ctx
);
// GPU Tensor to CPU Tensor
auto
cpu_ctx
=
new
paddle
::
platform
::
CPUDeviceContext
();
dst_tensor
.
CopyFrom
<
int
>
(
gpu_tensor
,
*
cpu_ctx
);
// Compare Tensors
const
int
*
dst_ptr
=
dst_tensor
.
data
<
int
>
();
ASSERT_NE
(
src_ptr
,
dst_ptr
);
for
(
size_t
i
=
0
;
i
<
9
;
++
i
)
{
EXPECT_EQ
(
src_ptr
[
i
],
dst_ptr
[
i
]);
}
Tensor
slice_tensor
=
src_tensor
.
Slice
<
int
>
(
1
,
2
);
// CPU Slice Tensor to GPU Tensor
gpu_tensor
.
CopyFrom
<
int
>
(
slice_tensor
,
*
gpu_ctx
);
Tensor
slice_tensor
=
src_tensor
.
Slice
<
int
>
(
1
,
2
);
dst_tensor
.
CopyFrom
<
int
>
(
slice_tensor
,
CPUPlace
());
const
int
*
slice_ptr
=
slice_tensor
.
data
<
int
>
();
dst_ptr
=
dst_tensor
.
data
<
int
>
();
ASSERT_NE
(
dst_ptr
,
slice_ptr
);
for
(
size_t
i
=
0
;
i
<
3
;
++
i
)
{
EXPECT_EQ
(
dst_ptr
[
i
],
slice_ptr
[
i
]);
// GPU Tensor to CPU Tensor
dst_tensor
.
CopyFrom
<
int
>
(
gpu_tensor
,
*
cpu_ctx
);
// Compare Slice Tensors
const
int
*
slice_ptr
=
slice_tensor
.
data
<
int
>
();
dst_ptr
=
dst_tensor
.
data
<
int
>
();
ASSERT_NE
(
dst_ptr
,
slice_ptr
);
for
(
size_t
i
=
0
;
i
<
3
;
++
i
)
{
EXPECT_EQ
(
dst_ptr
[
i
],
slice_ptr
[
i
]);
}
}
#endif
}
paddle/gserver/activations/ActivationFunction.cpp
浏览文件 @
46d766e2
...
...
@@ -207,8 +207,8 @@ Error __must_check backward(Argument& act) {
argument_
.
value
->
setData
(
act
.
value
->
getData
()
+
offset
,
1UL
,
size
);
argument_
.
grad
->
setData
(
act
.
grad
->
getData
()
+
offset
,
1UL
,
size
);
Error
status
=
softmax_
.
backward
(
argument_
);
if
(
!
status
)
return
status
;
Error
err
=
softmax_
.
backward
(
argument_
);
if
(
!
err
.
isOK
())
return
err
;
}
return
Error
();
}
...
...
paddle/memory/detail/buddy_allocator.cc
浏览文件 @
46d766e2
...
...
@@ -27,12 +27,11 @@ BuddyAllocator::BuddyAllocator(SystemAllocator* system_allocator,
system_allocator_
(
std
::
move
(
system_allocator
))
{}
BuddyAllocator
::~
BuddyAllocator
()
{
DLOG
(
INFO
)
<<
"BuddyAllocator Disconstructor makes sure that all of these "
"have actually been freed"
;
VLOG
(
3
)
<<
"BuddyAllocator Disconstructor makes sure that all of these "
"have actually been freed"
;
while
(
!
pool_
.
empty
())
{
auto
block
=
static_cast
<
MemoryBlock
*>
(
std
::
get
<
2
>
(
*
pool_
.
begin
()));
DLOG
(
INFO
)
<<
"Free from block ("
<<
block
<<
", "
<<
max_chunk_size_
<<
")"
;
VLOG
(
3
)
<<
"Free from block ("
<<
block
<<
", "
<<
max_chunk_size_
<<
")"
;
system_allocator_
->
Free
(
block
,
max_chunk_size_
,
block
->
index
(
cache_
));
cache_
.
invalidate
(
block
);
...
...
@@ -52,12 +51,11 @@ void* BuddyAllocator::Alloc(size_t unaligned_size) {
// acquire the allocator lock
std
::
lock_guard
<
std
::
mutex
>
lock
(
mutex_
);
DLOG
(
INFO
)
<<
"Allocate "
<<
unaligned_size
<<
" bytes from chunk size "
<<
size
;
VLOG
(
3
)
<<
"Allocate "
<<
unaligned_size
<<
" bytes from chunk size "
<<
size
;
// if the allocation is huge, send directly to the system allocator
if
(
size
>
max_chunk_size_
)
{
DLOG
(
INFO
)
<<
"Allocate from system allocator."
;
VLOG
(
3
)
<<
"Allocate from system allocator."
;
return
SystemAlloc
(
size
);
}
...
...
@@ -72,9 +70,9 @@ void* BuddyAllocator::Alloc(size_t unaligned_size) {
return
nullptr
;
}
}
else
{
DLOG
(
INFO
)
<<
"Allocation from existing memory block "
<<
std
::
get
<
2
>
(
*
it
)
<<
" at address "
<<
reinterpret_cast
<
MemoryBlock
*>
(
std
::
get
<
2
>
(
*
it
))
->
data
();
VLOG
(
3
)
<<
"Allocation from existing memory block "
<<
std
::
get
<
2
>
(
*
it
)
<<
" at address "
<<
reinterpret_cast
<
MemoryBlock
*>
(
std
::
get
<
2
>
(
*
it
))
->
data
();
}
total_used_
+=
size
;
...
...
@@ -91,10 +89,10 @@ void BuddyAllocator::Free(void* p) {
// Acquire the allocator lock
std
::
lock_guard
<
std
::
mutex
>
lock
(
mutex_
);
DLOG
(
INFO
)
<<
"Free from address "
<<
block
;
VLOG
(
3
)
<<
"Free from address "
<<
block
;
if
(
block
->
type
(
cache_
)
==
MemoryBlock
::
HUGE_CHUNK
)
{
DLOG
(
INFO
)
<<
"Free directly from system allocator"
;
VLOG
(
3
)
<<
"Free directly from system allocator"
;
system_allocator_
->
Free
(
block
,
block
->
total_size
(
cache_
),
block
->
index
(
cache_
));
...
...
@@ -111,8 +109,8 @@ void BuddyAllocator::Free(void* p) {
// Trying to merge the right buddy
if
(
block
->
has_right_buddy
(
cache_
))
{
DLOG
(
INFO
)
<<
"Merging this block "
<<
block
<<
" with its right buddy "
<<
block
->
right_buddy
(
cache_
);
VLOG
(
3
)
<<
"Merging this block "
<<
block
<<
" with its right buddy "
<<
block
->
right_buddy
(
cache_
);
auto
right_buddy
=
block
->
right_buddy
(
cache_
);
...
...
@@ -129,8 +127,8 @@ void BuddyAllocator::Free(void* p) {
// Trying to merge the left buddy
if
(
block
->
has_left_buddy
(
cache_
))
{
DLOG
(
INFO
)
<<
"Merging this block "
<<
block
<<
" with its left buddy "
<<
block
->
left_buddy
(
cache_
);
VLOG
(
3
)
<<
"Merging this block "
<<
block
<<
" with its left buddy "
<<
block
->
left_buddy
(
cache_
);
auto
left_buddy
=
block
->
left_buddy
(
cache_
);
...
...
@@ -146,8 +144,8 @@ void BuddyAllocator::Free(void* p) {
}
// Dumping this block into pool
DLOG
(
INFO
)
<<
"Inserting free block ("
<<
block
<<
", "
<<
block
->
total_size
(
cache_
)
<<
")"
;
VLOG
(
3
)
<<
"Inserting free block ("
<<
block
<<
", "
<<
block
->
total_size
(
cache_
)
<<
")"
;
pool_
.
insert
(
IndexSizeAddress
(
block
->
index
(
cache_
),
block
->
total_size
(
cache_
),
block
));
...
...
@@ -166,7 +164,7 @@ void* BuddyAllocator::SystemAlloc(size_t size) {
size_t
index
=
0
;
void
*
p
=
system_allocator_
->
Alloc
(
index
,
size
);
DLOG
(
INFO
)
<<
"Allocated "
<<
p
<<
" from system allocator."
;
VLOG
(
3
)
<<
"Allocated "
<<
p
<<
" from system allocator."
;
if
(
p
==
nullptr
)
return
nullptr
;
...
...
@@ -192,8 +190,8 @@ BuddyAllocator::PoolSet::iterator BuddyAllocator::RefillPool() {
if
(
p
==
nullptr
)
return
pool_
.
end
();
DLOG
(
INFO
)
<<
"Creating and inserting new block "
<<
p
<<
" from system allocator"
;
VLOG
(
3
)
<<
"Creating and inserting new block "
<<
p
<<
" from system allocator"
;
static_cast
<
MemoryBlock
*>
(
p
)
->
init
(
cache_
,
MemoryBlock
::
FREE_CHUNK
,
index
,
max_chunk_size_
,
nullptr
,
nullptr
);
...
...
@@ -237,19 +235,19 @@ void* BuddyAllocator::SplitToAlloc(BuddyAllocator::PoolSet::iterator it,
auto
block
=
static_cast
<
MemoryBlock
*>
(
std
::
get
<
2
>
(
*
it
));
pool_
.
erase
(
it
);
DLOG
(
INFO
)
<<
"Split block ("
<<
block
<<
", "
<<
block
->
total_size
(
cache_
)
<<
") into"
;
VLOG
(
3
)
<<
"Split block ("
<<
block
<<
", "
<<
block
->
total_size
(
cache_
)
<<
") into"
;
block
->
split
(
cache_
,
size
);
DLOG
(
INFO
)
<<
"Left block ("
<<
block
<<
", "
<<
block
->
total_size
(
cache_
)
<<
")"
;
VLOG
(
3
)
<<
"Left block ("
<<
block
<<
", "
<<
block
->
total_size
(
cache_
)
<<
")"
;
block
->
set_type
(
cache_
,
MemoryBlock
::
ARENA_CHUNK
);
// the rest of memory if exist
if
(
block
->
has_right_buddy
(
cache_
))
{
if
(
block
->
right_buddy
(
cache_
)
->
type
(
cache_
)
==
MemoryBlock
::
FREE_CHUNK
)
{
DLOG
(
INFO
)
<<
"Insert right block ("
<<
block
->
right_buddy
(
cache_
)
<<
", "
<<
block
->
right_buddy
(
cache_
)
->
total_size
(
cache_
)
<<
")"
;
VLOG
(
3
)
<<
"Insert right block ("
<<
block
->
right_buddy
(
cache_
)
<<
", "
<<
block
->
right_buddy
(
cache_
)
->
total_size
(
cache_
)
<<
")"
;
pool_
.
insert
(
IndexSizeAddress
(
block
->
right_buddy
(
cache_
)
->
index
(
cache_
),
...
...
@@ -276,7 +274,7 @@ void BuddyAllocator::CleanIdleFallBackAlloc() {
return
;
}
DLOG
(
INFO
)
<<
"Return block "
<<
block
<<
" to fallback allocator."
;
VLOG
(
3
)
<<
"Return block "
<<
block
<<
" to fallback allocator."
;
system_allocator_
->
Free
(
block
,
max_chunk_size_
,
block
->
index
(
cache_
));
cache_
.
invalidate
(
block
);
...
...
@@ -312,7 +310,7 @@ void BuddyAllocator::CleanIdleNormalAlloc() {
MemoryBlock
*
block
=
static_cast
<
MemoryBlock
*>
(
std
::
get
<
2
>
(
*
pool
));
DLOG
(
INFO
)
<<
"Return block "
<<
block
<<
" to base allocator."
;
VLOG
(
3
)
<<
"Return block "
<<
block
<<
" to base allocator."
;
system_allocator_
->
Free
(
block
,
max_chunk_size_
,
block
->
index
(
cache_
));
cache_
.
invalidate
(
block
);
...
...
paddle/memory/memory.h
浏览文件 @
46d766e2
...
...
@@ -29,10 +29,10 @@ void Free(Place, void*);
template
<
typename
Place
>
size_t
Used
(
Place
);
template
<
typename
T
,
/* must be POD types */
typename
Place
/* platform::GPUPlace or platform::CPUPlace */
,
typename
std
::
enable_if
<
std
::
is_pod
<
T
>
::
value
>::
type
*
=
nullptr
>
template
<
typename
T
,
typename
Place
>
class
PODDeleter
{
static_assert
(
std
::
is_pod
<
T
>::
value
,
"T must be POD"
);
public:
PODDeleter
(
Place
place
)
:
place_
(
place
)
{}
void
operator
()(
T
*
ptr
)
{
Free
(
place_
,
static_cast
<
void
*>
(
ptr
));
}
...
...
paddle/platform/device_context.h
浏览文件 @
46d766e2
...
...
@@ -87,7 +87,7 @@ class CUDADeviceContext : public DeviceContext {
"cudaStreamSynchronize failed"
);
}
cudaStream_t
stream
()
{
return
stream_
;
}
cudaStream_t
stream
()
const
{
return
stream_
;
}
Eigen
::
GpuDevice
*
eigen_device
()
const
{
return
eigen_device_
.
get
();
}
...
...
paddle/trainer/NewRemoteParameterUpdater.cpp
浏览文件 @
46d766e2
...
...
@@ -76,7 +76,11 @@ void NewRemoteParameterUpdater::init(
sgdConfigV2
->
set_decay
(
paramConfig
.
decay_rate
());
optimizeConfigV2
.
set_lr_policy
(
paddle
::
OptimizerConfig
::
Const
);
auto
constlr
=
optimizeConfigV2
.
mutable_const_lr
();
constlr
->
set_learning_rate
(
paramConfig
.
learning_rate
());
if
(
paramConfig
.
has_learning_rate
())
{
constlr
->
set_learning_rate
(
paramConfig
.
learning_rate
());
}
else
{
constlr
->
set_learning_rate
(
trainerConfig_
.
learning_rate
());
}
if
(
trainerConfig_
.
algorithm
()
==
"sgd"
)
{
optimizeConfigV2
.
set_optimizer
(
paddle
::
OptimizerConfig
::
SGD
);
// FIXME: config all algorithms
...
...
paddle/utils/Error.h
浏览文件 @
46d766e2
...
...
@@ -126,9 +126,11 @@ public:
}
/**
* @brief operator bool, return True if there is something error.
* @brief check this status by glog.
* @note It is a temp method used during cleaning Paddle code. It will be
* removed later.
*/
operator
bool
()
const
{
return
!
this
->
isOK
();
}
void
check
()
const
{
CHECK
(
this
->
isOK
())
<<
msg
();
}
/**
* @brief isOK return True if there is no error.
...
...
@@ -136,13 +138,6 @@ public:
*/
bool
isOK
()
const
{
return
msg_
==
nullptr
;
}
/**
* @brief check this status by glog.
* @note It is a temp method used during cleaning Paddle code. It will be
* removed later.
*/
void
check
()
const
{
CHECK
(
this
->
isOK
())
<<
msg
();
}
private:
std
::
shared_ptr
<
std
::
string
>
msg_
;
};
...
...
paddle/utils/tests/test_Error.cpp
浏览文件 @
46d766e2
...
...
@@ -18,17 +18,17 @@ limitations under the License. */
TEST
(
Error
,
testAll
)
{
paddle
::
Error
error
;
ASSERT_
FALSE
(
error
);
ASSERT_
TRUE
(
error
.
isOK
()
);
error
=
paddle
::
Error
(
"I'm the error"
);
ASSERT_
TRUE
(
error
);
ASSERT_
FALSE
(
error
.
isOK
()
);
ASSERT_STREQ
(
"I'm the error"
,
error
.
msg
());
error
=
paddle
::
Error
(
"error2"
);
ASSERT_
TRUE
(
error
);
ASSERT_
FALSE
(
error
.
isOK
()
);
ASSERT_STREQ
(
"error2"
,
error
.
msg
());
int
i
=
3
;
auto
error3
=
paddle
::
Error
(
"error%d"
,
i
);
ASSERT_
TRUE
(
error3
);
ASSERT_
FALSE
(
error3
.
isOK
()
);
ASSERT_STREQ
(
"error3"
,
error3
.
msg
());
}
python/paddle/trainer_config_helpers/attrs.py
浏览文件 @
46d766e2
...
...
@@ -272,7 +272,7 @@ class ExtraLayerAttribute(object):
for
key
in
self
.
attr
:
if
not
hasattr
(
self
,
'can_%s'
%
key
)
or
\
not
getattr
(
self
,
'can_%s'
%
key
):
raise
NotImplementedError
(
"Layer %s
can
not support %s"
%
raise
NotImplementedError
(
"Layer %s
does
not support %s"
%
(
layer_name
,
key
))
@
staticmethod
...
...
python/paddle/trainer_config_helpers/layers.py
浏览文件 @
46d766e2
...
...
@@ -865,7 +865,7 @@ def data_layer(name, size, height=None, width=None, layer_attr=None):
@
wrap_name_default
(
"embedding"
)
@
wrap_param_attr_default
()
@
layer_support
(
ERROR_CLIPPING
)
@
layer_support
(
ERROR_CLIPPING
,
DROPOUT
)
def
embedding_layer
(
input
,
size
,
name
=
None
,
param_attr
=
None
,
layer_attr
=
None
):
"""
Define a embedding Layer.
...
...
@@ -1320,7 +1320,7 @@ def pooling_layer(input,
@
wrap_act_default
(
param_names
=
[
'gate_act'
],
act
=
SigmoidActivation
())
@
wrap_act_default
(
param_names
=
[
"act"
,
'state_act'
],
act
=
TanhActivation
())
@
wrap_name_default
(
"lstmemory"
)
@
layer_support
(
DROPOUT
)
@
layer_support
()
def
lstmemory
(
input
,
name
=
None
,
size
=
None
,
...
...
@@ -1429,7 +1429,7 @@ def lstmemory(input,
@
wrap_act_default
(
param_names
=
[
'gate_act'
],
act
=
SigmoidActivation
())
@
wrap_act_default
(
param_names
=
[
"act"
],
act
=
TanhActivation
())
@
wrap_name_default
(
"gru"
)
@
layer_support
(
DROPOUT
)
@
layer_support
()
def
grumemory
(
input
,
size
=
None
,
name
=
None
,
...
...
@@ -1793,7 +1793,7 @@ def repeat_layer(input,
@
wrap_name_default
(
"seqreshape"
)
@
wrap_act_default
(
act
=
IdentityActivation
())
@
wrap_bias_attr_default
(
has_bias
=
False
)
@
layer_support
()
@
layer_support
(
ERROR_CLIPPING
,
DROPOUT
)
def
seq_reshape_layer
(
input
,
reshape_size
,
act
=
None
,
...
...
@@ -2703,7 +2703,7 @@ def img_cmrnorm_layer(input,
default_factory
=
lambda
_
:
ParamAttr
(
initial_mean
=
1.0
,
initial_std
=
0.
))
@
wrap_act_default
(
act
=
ReluActivation
())
@
wrap_name_default
(
"batch_norm"
)
@
layer_support
(
DROPOUT
)
@
layer_support
(
DROPOUT
,
ERROR_CLIPPING
)
def
batch_norm_layer
(
input
,
act
=
None
,
name
=
None
,
...
...
@@ -2783,15 +2783,6 @@ def batch_norm_layer(input,
:return: LayerOutput object.
:rtype: LayerOutput
"""
if
not
isinstance
(
act
,
ReluActivation
):
logger
.
log
(
logging
.
WARN
,
"%s is not recommend for batch normalization's activation, "
"maybe the relu is better"
%
act
.
name
)
if
not
isinstance
(
input
.
activation
,
LinearActivation
):
logger
.
log
(
logging
.
WARN
,
"The activation should be inside batch normalization, the "
"previous layer's activation may be Linear"
)
if
num_channels
is
None
:
if
input
.
num_filters
is
not
None
:
...
...
@@ -2861,7 +2852,7 @@ def sum_to_one_norm_layer(input, name=None, layer_attr=None):
@
wrap_name_default
(
"addto"
)
@
wrap_act_default
(
act
=
LinearActivation
())
@
wrap_bias_attr_default
(
has_bias
=
False
)
@
layer_support
(
DROPOUT
)
@
layer_support
(
DROPOUT
,
ERROR_CLIPPING
)
def
addto_layer
(
input
,
act
=
None
,
name
=
None
,
bias_attr
=
None
,
layer_attr
=
None
):
"""
AddtoLayer.
...
...
@@ -2940,7 +2931,7 @@ def addto_layer(input, act=None, name=None, bias_attr=None, layer_attr=None):
@
wrap_act_default
(
act
=
IdentityActivation
())
@
wrap_name_default
(
"concat"
)
@
layer_support
()
@
layer_support
(
DROPOUT
,
ERROR_CLIPPING
)
def
concat_layer
(
input
,
act
=
None
,
name
=
None
,
layer_attr
=
None
,
bias_attr
=
None
):
"""
Concat all input vector into one huge vector.
...
...
@@ -3024,7 +3015,7 @@ def concat_layer(input, act=None, name=None, layer_attr=None, bias_attr=None):
@
wrap_name_default
(
"seqconcat"
)
@
wrap_act_default
(
act
=
IdentityActivation
())
@
wrap_bias_attr_default
(
has_bias
=
False
)
@
layer_support
()
@
layer_support
(
DROPOUT
,
ERROR_CLIPPING
)
def
seq_concat_layer
(
a
,
b
,
act
=
None
,
name
=
None
,
layer_attr
=
None
,
bias_attr
=
None
):
"""
...
...
@@ -3177,7 +3168,7 @@ def memory(name,
@
wrap_act_default
(
param_names
=
[
'state_act'
],
act
=
TanhActivation
())
@
wrap_act_default
(
act
=
TanhActivation
())
@
wrap_name_default
(
'lstm_step'
)
@
layer_support
(
ERROR_CLIPPING
,
DROPOUT
)
@
layer_support
()
def
lstm_step_layer
(
input
,
state
,
size
=
None
,
...
...
@@ -4480,7 +4471,7 @@ def tensor_layer(a,
@
wrap_param_attr_default
()
@
wrap_bias_attr_default
()
@
wrap_act_default
()
@
layer_support
()
@
layer_support
(
DROPOUT
,
ERROR_CLIPPING
)
def
selective_fc_layer
(
input
,
size
,
select
=
None
,
...
...
@@ -5974,7 +5965,7 @@ def crop_layer(input, offset, axis=2, shape=None, name=None, layer_attr=None):
"""
The crop layer crops images by offset and shape. User can set crop shape by
args 'shape' explicitly or by reference input layer.
The example usage is:
.. code-block:: python
...
...
python/paddle/v2/dataset/common.py
浏览文件 @
46d766e2
...
...
@@ -166,55 +166,37 @@ def cluster_files_reader(files_pattern,
return
reader
def
convert
(
output_path
,
reader
,
num_shards
,
name_prefix
,
max_lines_to_shuffle
=
1000
):
def
convert
(
output_path
,
reader
,
line_count
,
name_prefix
):
import
recordio
"""
Convert data from reader to recordio format files.
:param output_path: directory in which output files will be saved.
:param reader: a data reader, from which the convert program will read data instances.
:param num_shards: the number of shards that the dataset will be partitioned into.
:param name_prefix: the name prefix of generated files.
:param max_lines_to_shuffle: the max lines numbers to shuffle before writing.
"""
assert
num_shards
>=
1
assert
max_lines_to_shuffle
>=
1
def
open_writers
():
w
=
[]
for
i
in
range
(
0
,
num_shards
):
n
=
"%s/%s-%05d-of-%05d"
%
(
output_path
,
name_prefix
,
i
,
num_shards
-
1
)
w
.
append
(
recordio
.
writer
(
n
))
return
w
def
close_writers
(
w
):
for
i
in
range
(
0
,
num_shards
):
w
[
i
].
close
()
assert
line_count
>=
1
indx_f
=
0
def
write_data
(
w
,
lines
):
def
write_data
(
indx_f
,
lines
):
random
.
shuffle
(
lines
)
for
i
,
d
in
enumerate
(
lines
):
filename
=
"%s/%s-%05d"
%
(
output_path
,
name_prefix
,
indx_f
)
writer
=
recordio
.
writer
(
filename
)
for
l
in
lines
:
# FIXME(Yancey1989):
# dumps with protocol: pickle.HIGHEST_PROTOCOL
o
=
pickle
.
dumps
(
d
)
w
[
i
%
num_shards
].
write
(
o
)
writer
.
write
(
cPickle
.
dumps
(
l
)
)
writer
.
close
(
)
w
=
open_writers
()
lines
=
[]
for
i
,
d
in
enumerate
(
reader
()):
lines
.
append
(
d
)
if
i
%
max_lines_to_shuffle
==
0
and
i
>=
max_lines_to_shuffle
:
write_data
(
w
,
lines
)
if
i
%
line_count
==
0
and
i
>=
line_count
:
write_data
(
indx_f
,
lines
)
lines
=
[]
indx_f
+=
1
continue
write_data
(
w
,
lines
)
close_writers
(
w
)
write_data
(
indx_f
,
lines
)
python/paddle/v2/inference.py
浏览文件 @
46d766e2
...
...
@@ -35,6 +35,13 @@ class Inference(object):
name
=
param
.
getName
()
assert
isinstance
(
val
,
api
.
Vector
)
val
.
copyFromNumpyArray
(
parameters
.
get
(
name
).
flatten
())
# the setValueUpdated function is called in randomize, zeroMem,
# load function in paddle/parameter/Parameter.cpp. But in the
# inference mode, the setValueUpdated is never called, it will
# cause the parameter will not be dispatched
# in MultiGradientMachine for multi-GPU. So setValueUpdated is
# called here, but it's better to call this function in one place.
param
.
setValueUpdated
()
self
.
__gradient_machine__
=
gm
self
.
__data_types__
=
topo
.
data_type
()
...
...
python/paddle/v2/master/client.py
浏览文件 @
46d766e2
...
...
@@ -49,7 +49,6 @@ class client(object):
def
set_dataset
(
self
,
paths
):
holder_type
=
ctypes
.
c_char_p
*
len
(
paths
)
holder
=
holder_type
()
print
paths
for
idx
,
path
in
enumerate
(
paths
):
c_ptr
=
ctypes
.
c_char_p
(
path
)
holder
[
idx
]
=
c_ptr
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录