Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
7702ab1a
P
Paddle
项目概览
PaddlePaddle
/
Paddle
1 年多 前同步成功
通知
2302
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
7702ab1a
编写于
8月 09, 2017
作者:
Y
Yu Yang
浏览文件
操作
浏览文件
下载
差异文件
Merge branch 'develop' of github.com:baidu/Paddle into feature/refactorize_framework_proto
上级
4a788854
b008360b
变更
25
隐藏空白更改
内联
并排
Showing
25 changed file
with
900 addition
and
177 deletion
+900
-177
Dockerfile
Dockerfile
+1
-1
go/glide.lock
go/glide.lock
+3
-7
go/master/service_test.go
go/master/service_test.go
+14
-10
go/pserver/client/c/cclient.go
go/pserver/client/c/cclient.go
+14
-6
go/pserver/client/client.go
go/pserver/client/client.go
+7
-3
go/pserver/client/client_test.go
go/pserver/client/client_test.go
+11
-3
go/pserver/client/etcd_client.go
go/pserver/client/etcd_client.go
+134
-19
go/pserver/client/etcd_client_test.go
go/pserver/client/etcd_client_test.go
+106
-0
paddle/framework/CMakeLists.txt
paddle/framework/CMakeLists.txt
+3
-0
paddle/framework/details/lod_tensor.cc
paddle/framework/details/lod_tensor.cc
+62
-0
paddle/framework/details/lod_tensor.h
paddle/framework/details/lod_tensor.h
+46
-0
paddle/framework/lod_tensor.cc
paddle/framework/lod_tensor.cc
+51
-0
paddle/framework/lod_tensor.h
paddle/framework/lod_tensor.h
+145
-0
paddle/framework/lod_tensor_impl.h
paddle/framework/lod_tensor_impl.h
+60
-0
paddle/framework/lod_tensor_test.cc
paddle/framework/lod_tensor_test.cc
+165
-0
paddle/framework/operator.h
paddle/framework/operator.h
+8
-14
paddle/framework/tensor.h
paddle/framework/tensor.h
+2
-0
paddle/framework/tensor_test.cc
paddle/framework/tensor_test.cc
+1
-1
paddle/function/FunctionTest.cpp
paddle/function/FunctionTest.cpp
+11
-11
paddle/function/TensorShapeTest.cpp
paddle/function/TensorShapeTest.cpp
+1
-1
paddle/gserver/layers/SubNestedSequenceLayer.cpp
paddle/gserver/layers/SubNestedSequenceLayer.cpp
+2
-2
paddle/gserver/tests/test_KmaxSeqScore.cpp
paddle/gserver/tests/test_KmaxSeqScore.cpp
+1
-1
paddle/platform/CMakeLists.txt
paddle/platform/CMakeLists.txt
+1
-1
paddle/platform/enforce_test.cc
paddle/platform/enforce_test.cc
+49
-95
paddle/trainer/NewRemoteParameterUpdater.cpp
paddle/trainer/NewRemoteParameterUpdater.cpp
+2
-2
未找到文件。
Dockerfile
浏览文件 @
7702ab1a
...
...
@@ -28,7 +28,7 @@ RUN apt-get update && \
wget unzip unrar
tar
xz-utils bzip2
gzip
coreutils ntp
\
curl
sed grep
graphviz libjpeg-dev zlib1g-dev
\
python-matplotlib gcc-4.8 g++-4.8
\
automake locales clang-format
-3.8
swig doxygen cmake
\
automake locales clang-format swig doxygen cmake
\
liblapack-dev liblapacke-dev libboost-dev
\
clang-3.8 llvm-3.8 libclang-3.8-dev
\
net-tools
&&
\
...
...
go/glide.lock
浏览文件 @
7702ab1a
hash: 1b9b07408ca7fac27a374dc2ccd2433e4bff090484008a037df967284949a582
updated: 2017-08-0
3T21:46:51.744995189
Z
updated: 2017-08-0
7T23:37:48.867469328
Z
imports:
- name: github.com/beorn7/perks
version: 4c0e84591b9aa9e6dcfdf3e020114cd81f89d5f9
...
...
@@ -10,7 +10,7 @@ imports:
- name: github.com/cockroachdb/cmux
version: 112f0506e7743d64a6eb8fedbcff13d9979bbf92
- name: github.com/coreos/etcd
version:
c31bec0f29facff13f7c3e3d948e55dd6689ed42
version:
d0d1a87aa96ae14914751d42264262cb69eda170
subpackages:
- alarm
- auth
...
...
@@ -24,6 +24,7 @@ imports:
- error
- etcdserver
- etcdserver/api
- etcdserver/api/etcdhttp
- etcdserver/api/v2http
- etcdserver/api/v2http/httptypes
- etcdserver/api/v3client
...
...
@@ -210,11 +211,6 @@ testImports:
version: 04cdfd42973bb9c8589fd6a731800cf222fde1a9
subpackages:
- spew
- name: github.com/docker/docker
version: b6d164e6c46d8115b146e4c3ac93784e9ef8b49e
subpackages:
- pkg/ioutils
- pkg/longpath
- name: github.com/pmezard/go-difflib
version: d8ed2627bdf02c080bf22230dbb337003b7aba2d
subpackages:
...
...
go/master/service_test.go
浏览文件 @
7702ab1a
package
master_test
import
(
"io/ioutil"
"net/url"
"os"
"strings"
"testing"
"time"
"github.com/PaddlePaddle/Paddle/go/master"
"github.com/coreos/etcd/clientv3"
"github.com/coreos/etcd/embed"
"github.com/docker/docker/pkg/ioutils"
"github.com/stretchr/testify/assert"
)
func
TestNewServiceWithEtcd
(
t
*
testing
.
T
)
{
// setup an embed etcd server
etcdDir
,
err
:=
ioutil
s
.
TempDir
(
""
,
""
)
etcdDir
,
err
:=
ioutil
.
TempDir
(
""
,
""
)
if
err
!=
nil
{
t
.
Fatal
(
err
)
}
cfg
:=
embed
.
NewConfig
()
lpurl
,
_
:=
url
.
Parse
(
"http://localhost:0"
)
lcurl
,
_
:=
url
.
Parse
(
"http://localhost:0"
)
cfg
.
LPUrls
=
[]
url
.
URL
{
*
lpurl
}
cfg
.
LCUrls
=
[]
url
.
URL
{
*
lcurl
}
cfg
.
Dir
=
etcdDir
e
,
err
:=
embed
.
StartEtcd
(
cfg
)
if
err
!=
nil
{
...
...
@@ -30,15 +36,13 @@ func TestNewServiceWithEtcd(t *testing.T) {
t
.
Fatal
(
err
)
}
}()
select
{
case
<-
e
.
Server
.
ReadyNotify
()
:
t
.
Log
(
"Server is ready!"
)
case
<-
time
.
After
(
60
*
time
.
Second
)
:
e
.
Server
.
Stop
()
// trigger a shutdown
t
.
Fatal
(
"Server took too long to start!"
)
}
ep
:=
[]
string
{
"127.0.0.1:2379"
}
<-
e
.
Server
.
ReadyNotify
()
port
:=
strings
.
Split
(
e
.
Clients
[
0
]
.
Addr
()
.
String
(),
":"
)[
1
]
endpoint
:=
"127.0.0.1:"
+
port
ep
:=
[]
string
{
endpoint
}
masterAddr
:=
"127.0.0.1:3306"
store
,
err
:=
master
.
NewEtcdClient
(
ep
,
masterAddr
,
master
.
DefaultLockPath
,
master
.
DefaultAddrPath
,
master
.
DefaultStatePath
,
30
)
if
err
!=
nil
{
...
...
go/pserver/client/c/cclient.go
浏览文件 @
7702ab1a
...
...
@@ -90,8 +90,12 @@ func cArrayToSlice(p unsafe.Pointer, len int) []byte {
type
selector
bool
func
(
s
selector
)
Select
()
bool
{
return
bool
(
s
)
func
(
s
selector
)
Select
()
(
bool
,
error
)
{
return
bool
(
s
),
nil
}
func
(
s
selector
)
Done
()
error
{
return
nil
}
type
lister
[]
client
.
Server
...
...
@@ -114,11 +118,10 @@ func paddle_new_pserver_client(addrs *C.char, selected int) C.paddle_pserver_cli
}
//export paddle_new_etcd_pserver_client
func
paddle_new_etcd_pserver_client
(
etcdEndpoints
*
C
.
char
,
selected
int
)
C
.
paddle_pserver_client
{
// TODO(Longfei: use etcd lock to decide which trainer to initialize the parameters)
func
paddle_new_etcd_pserver_client
(
etcdEndpoints
*
C
.
char
)
C
.
paddle_pserver_client
{
addr
:=
C
.
GoString
(
etcdEndpoints
)
etcdClient
:=
client
.
NewEtcd
(
addr
)
c
:=
client
.
NewClient
(
etcdClient
,
etcdClient
.
Desired
(),
selector
(
selected
!=
0
)
)
c
:=
client
.
NewClient
(
etcdClient
,
etcdClient
.
Desired
(),
etcdClient
)
return
add
(
c
)
}
...
...
@@ -136,7 +139,12 @@ func paddle_pserver_client_release(client C.paddle_pserver_client) {
//export paddle_begin_init_params
func
paddle_begin_init_params
(
client
C
.
paddle_pserver_client
)
C
.
int
{
c
:=
get
(
client
)
if
selected
:=
c
.
BeginInitParams
();
selected
{
selected
,
err
:=
c
.
BeginInitParams
()
if
err
!=
nil
{
panic
(
err
)
}
if
selected
{
return
1
}
return
0
...
...
go/pserver/client/client.go
浏览文件 @
7702ab1a
...
...
@@ -27,9 +27,13 @@ import (
// TODO(helin): add RPC call retry logic
// Selector selects if the client should initialize parameter servers.
// Selector selects if the client should initialize parameters and
// reports the initialization process done.
type
Selector
interface
{
Select
()
bool
// Select selects if the client should initialize parameter servers.
Select
()
(
bool
,
error
)
// Done indicates the initialization process is done.
Done
()
error
}
// Server is the identification of a parameter Server.
...
...
@@ -115,7 +119,7 @@ func (c *Client) monitorPservers(l Lister, pserverNum int) {
// servers. Other trainers will be blocked until the initialization is
// done, and they need to get the initialized parameters from
// parameter servers using GetParams.
func
(
c
*
Client
)
BeginInitParams
()
bool
{
func
(
c
*
Client
)
BeginInitParams
()
(
bool
,
error
)
{
return
c
.
sel
.
Select
()
}
...
...
go/pserver/client/client_test.go
浏览文件 @
7702ab1a
...
...
@@ -124,8 +124,12 @@ func initEtcdClient() {
type
selector
bool
func
(
s
selector
)
Select
()
bool
{
return
bool
(
s
)
func
(
s
selector
)
Select
()
(
bool
,
error
)
{
return
bool
(
s
),
nil
}
func
(
s
selector
)
Done
()
error
{
return
nil
}
type
lister
[]
client
.
Server
...
...
@@ -135,7 +139,11 @@ func (l lister) List() []client.Server {
}
func
testClient
(
t
*
testing
.
T
,
c
*
client
.
Client
)
{
selected
:=
c
.
BeginInitParams
()
selected
,
err
:=
c
.
BeginInitParams
()
if
err
!=
nil
{
t
.
Fatal
(
err
)
}
if
!
selected
{
t
.
Fatal
(
"should be selected."
)
}
...
...
go/pserver/client/etcd_client.go
浏览文件 @
7702ab1a
...
...
@@ -16,53 +16,60 @@ package client
import
(
"context"
"errors"
"fmt"
"strconv"
"strings"
"time"
"github.com/PaddlePaddle/Paddle/go/pserver"
"github.com/coreos/etcd/clientv3"
"github.com/coreos/etcd/clientv3/concurrency"
log
"github.com/sirupsen/logrus"
)
const
(
defaultEtcdTimeout
time
.
Duration
=
5
*
time
.
Second
initLockPath
=
"/init_ps/lock"
initDonePath
=
"/init_ps/done"
initDoneVal
=
"1"
)
// Etcd
Client
is used by pserver client that is a part of trainer process.
// Etcd is used by pserver client that is a part of trainer process.
// TODO:
// 1. add watcher to watch the change state of pservers)
// 1. add etcd lock)
type
EtcdClient
struct
{
// 1. add watcher to watch the change state of pservers.
type
Etcd
struct
{
client
*
clientv3
.
Client
timeout
time
.
Duration
endpoints
[]
string
lock
*
concurrency
.
Mutex
}
// Desired read ps desired number from etcd.
func
(
p
*
EtcdClient
)
Desired
()
int
{
func
(
e
*
Etcd
)
Desired
()
int
{
var
psDesired
int
for
{
ctx
,
cancel
:=
context
.
WithTimeout
(
context
.
Background
(),
p
.
timeout
)
resp
,
err
:=
p
.
client
.
Get
(
ctx
,
pserver
.
PsDesired
)
ctx
,
cancel
:=
context
.
WithTimeout
(
context
.
Background
(),
e
.
timeout
)
resp
,
err
:=
e
.
client
.
Get
(
ctx
,
pserver
.
PsDesired
)
cancel
()
if
err
!=
nil
{
log
.
Errorf
(
"Get ps dresire number failed! recnnectiong..., %v"
,
err
)
time
.
Sleep
(
p
.
timeout
)
time
.
Sleep
(
e
.
timeout
)
continue
}
kvs
:=
resp
.
Kvs
if
len
(
kvs
)
==
0
{
log
.
Infoln
(
"Waiting for ps desired registered ..."
)
time
.
Sleep
(
p
.
timeout
)
time
.
Sleep
(
e
.
timeout
)
continue
}
psDesired
,
err
=
strconv
.
Atoi
(
string
(
resp
.
Kvs
[
0
]
.
Value
))
if
err
!=
nil
{
log
.
Errorf
(
"psDesired %d invalid %v"
,
psDesired
,
err
)
time
.
Sleep
(
p
.
timeout
)
time
.
Sleep
(
e
.
timeout
)
continue
}
...
...
@@ -73,26 +80,26 @@ func (p *EtcdClient) Desired() int {
}
// List return the pserver list read from etcd.
func
(
p
*
EtcdClient
)
List
()
[]
Server
{
psDesired
:=
p
.
Desired
()
func
(
e
*
Etcd
)
List
()
[]
Server
{
psDesired
:=
e
.
Desired
()
servers
:=
make
([]
Server
,
psDesired
)
for
{
for
i
:=
0
;
i
<
psDesired
;
i
++
{
ctx
,
cancel
:=
context
.
WithTimeout
(
context
.
Background
(),
p
.
timeout
)
ctx
,
cancel
:=
context
.
WithTimeout
(
context
.
Background
(),
e
.
timeout
)
psKey
:=
pserver
.
PsPath
+
strconv
.
Itoa
(
i
)
log
.
Debugf
(
"checking %s"
,
psKey
)
resp
,
err
:=
p
.
client
.
Get
(
ctx
,
psKey
)
resp
,
err
:=
e
.
client
.
Get
(
ctx
,
psKey
)
cancel
()
if
err
!=
nil
{
log
.
Infof
(
"Get psKey= %s error, %v"
,
psKey
,
err
)
time
.
Sleep
(
p
.
timeout
)
time
.
Sleep
(
e
.
timeout
)
continue
}
kvs
:=
resp
.
Kvs
if
len
(
kvs
)
==
0
{
log
.
Infof
(
"Waiting for ps addr registered ..."
)
time
.
Sleep
(
p
.
timeout
)
time
.
Sleep
(
e
.
timeout
)
continue
}
...
...
@@ -100,7 +107,7 @@ func (p *EtcdClient) List() []Server {
// TODO(Longfei) check the ps address
if
psAddr
==
""
{
log
.
Infof
(
"Get psKey = %s, psAddr is empty"
,
psKey
)
time
.
Sleep
(
p
.
timeout
)
time
.
Sleep
(
e
.
timeout
)
continue
}
log
.
Debugf
(
"got value (%s) for key: %s"
,
psAddr
,
psKey
)
...
...
@@ -113,7 +120,7 @@ func (p *EtcdClient) List() []Server {
}
// NewEtcd create a etcd client to return the state of pserver on etcd.
func
NewEtcd
(
endpoints
string
)
*
Etcd
Client
{
func
NewEtcd
(
endpoints
string
)
*
Etcd
{
ep
:=
strings
.
Split
(
endpoints
,
","
)
var
cli
*
clientv3
.
Client
var
err
error
...
...
@@ -130,10 +137,118 @@ func NewEtcd(endpoints string) *EtcdClient {
break
}
log
.
Infof
(
"Connected to etcd: %s
\n
"
,
endpoints
)
client
:=
&
Etcd
Client
{
client
:=
&
Etcd
{
client
:
cli
,
timeout
:
defaultEtcdTimeout
,
endpoints
:
ep
,
}
return
client
}
// Select indicates if the current trainer is selected to initialize
// the pserver parameters.
func
(
e
*
Etcd
)
Select
()
(
bool
,
error
)
{
sess
,
err
:=
concurrency
.
NewSession
(
e
.
client
,
concurrency
.
WithTTL
(
5
))
if
err
!=
nil
{
return
false
,
err
}
lock
:=
concurrency
.
NewMutex
(
sess
,
initLockPath
)
log
.
Infof
(
"Trying to acquire lock at %s."
,
initLockPath
)
// Do not use timeout context here, since we don't know how
// long does it take for other trainers to initialize the
// parameters.
err
=
lock
.
Lock
(
context
.
Background
())
if
err
!=
nil
{
return
false
,
err
}
log
.
Infof
(
"Successfully acquired lock at %s."
,
initLockPath
)
get
:=
clientv3
.
OpGet
(
initDonePath
)
ctx
,
cancel
:=
context
.
WithTimeout
(
context
.
Background
(),
e
.
timeout
)
tresp
,
err
:=
e
.
client
.
Txn
(
ctx
)
.
If
(
lock
.
IsOwner
())
.
Then
(
get
)
.
Commit
()
cancel
()
if
err
!=
nil
{
return
false
,
err
}
if
!
tresp
.
Succeeded
{
return
false
,
errors
.
New
(
"no longer the owner of the lock"
)
}
resp
:=
tresp
.
Responses
[
0
]
.
GetResponseRange
()
if
len
(
resp
.
Kvs
)
==
0
{
// Key value not set, select current trainer.
e
.
lock
=
lock
log
.
Infoln
(
"Trainer selected."
)
return
true
,
nil
}
if
string
(
resp
.
Kvs
[
0
]
.
Value
)
==
initDoneVal
{
log
.
Infoln
(
"Initialization is already done."
)
ctx
,
cancel
=
context
.
WithTimeout
(
context
.
Background
(),
e
.
timeout
)
err
=
lock
.
Unlock
(
ctx
)
cancel
()
if
err
!=
nil
{
log
.
Errorln
(
err
)
}
return
false
,
nil
}
return
false
,
fmt
.
Errorf
(
"key %s have unexpected value: %v"
,
initDonePath
,
resp
.
Kvs
[
0
]
.
Value
)
}
// Done indicates the parameter initialization process is done.
func
(
e
*
Etcd
)
Done
()
error
{
if
e
.
lock
==
nil
{
return
errors
.
New
(
"lock is nil, Done called unexpectedly"
)
}
put
:=
clientv3
.
OpPut
(
initDonePath
,
initDoneVal
)
ctx
,
cancel
:=
context
.
WithTimeout
(
context
.
Background
(),
e
.
timeout
)
tresp
,
err
:=
e
.
client
.
Txn
(
ctx
)
.
If
(
e
.
lock
.
IsOwner
())
.
Then
(
put
)
.
Commit
()
cancel
()
if
err
!=
nil
{
return
err
}
if
!
tresp
.
Succeeded
{
return
errors
.
New
(
"no longer the owner of the lock"
)
}
ctx
,
cancel
=
context
.
WithTimeout
(
context
.
Background
(),
e
.
timeout
)
err
=
e
.
lock
.
Unlock
(
ctx
)
cancel
()
if
err
!=
nil
{
log
.
Errorln
(
err
)
}
else
{
e
.
lock
=
nil
}
return
nil
}
// Close closes the etcd client.
func
(
e
*
Etcd
)
Close
()
error
{
var
err
error
if
e
.
lock
!=
nil
{
ctx
,
cancel
:=
context
.
WithTimeout
(
context
.
Background
(),
e
.
timeout
)
err
=
e
.
lock
.
Unlock
(
ctx
)
cancel
()
if
err
==
nil
{
e
.
lock
=
nil
}
}
cErr
:=
e
.
client
.
Close
()
if
cErr
!=
nil
{
if
err
!=
nil
{
log
.
Errorln
(
cErr
)
return
err
}
return
cErr
}
return
err
}
go/pserver/client/etcd_client_test.go
0 → 100644
浏览文件 @
7702ab1a
package
client_test
import
(
"io/ioutil"
"net/url"
"os"
"strings"
"sync"
"testing"
"github.com/PaddlePaddle/Paddle/go/pserver/client"
"github.com/coreos/etcd/embed"
)
func
TestSelector
(
t
*
testing
.
T
)
{
etcdDir
,
err
:=
ioutil
.
TempDir
(
""
,
""
)
if
err
!=
nil
{
t
.
Fatal
(
err
)
}
cfg
:=
embed
.
NewConfig
()
lpurl
,
_
:=
url
.
Parse
(
"http://localhost:0"
)
lcurl
,
_
:=
url
.
Parse
(
"http://localhost:0"
)
cfg
.
LPUrls
=
[]
url
.
URL
{
*
lpurl
}
cfg
.
LCUrls
=
[]
url
.
URL
{
*
lcurl
}
cfg
.
Dir
=
etcdDir
e
,
err
:=
embed
.
StartEtcd
(
cfg
)
if
err
!=
nil
{
t
.
Fatal
(
err
)
}
defer
func
()
{
e
.
Close
()
if
err
:=
os
.
RemoveAll
(
etcdDir
);
err
!=
nil
{
t
.
Fatal
(
err
)
}
}()
<-
e
.
Server
.
ReadyNotify
()
port
:=
strings
.
Split
(
e
.
Clients
[
0
]
.
Addr
()
.
String
(),
":"
)[
1
]
endpoint
:=
"127.0.0.1:"
+
port
var
mu
sync
.
Mutex
selectedCount
:=
0
var
wg
sync
.
WaitGroup
selectAndDone
:=
func
(
c
*
client
.
Etcd
)
{
defer
wg
.
Done
()
selected
,
err
:=
c
.
Select
()
if
err
!=
nil
{
panic
(
err
)
}
if
selected
{
mu
.
Lock
()
selectedCount
++
mu
.
Unlock
()
err
=
c
.
Done
()
if
err
!=
nil
{
t
.
Fatal
(
err
)
}
}
}
c0
:=
client
.
NewEtcd
(
endpoint
)
c1
:=
client
.
NewEtcd
(
endpoint
)
c2
:=
client
.
NewEtcd
(
endpoint
)
c3
:=
client
.
NewEtcd
(
endpoint
)
wg
.
Add
(
3
)
go
selectAndDone
(
c0
)
go
selectAndDone
(
c1
)
go
selectAndDone
(
c2
)
wg
.
Wait
()
// simulate trainer crashed and restarted after the
// initialization process.
wg
.
Add
(
1
)
go
selectAndDone
(
c3
)
wg
.
Wait
()
mu
.
Lock
()
if
selectedCount
!=
1
{
t
.
Fatal
(
"selected count wrong:"
,
selectedCount
)
}
mu
.
Unlock
()
err
=
c0
.
Close
()
if
err
!=
nil
{
t
.
Fatal
(
err
)
}
err
=
c1
.
Close
()
if
err
!=
nil
{
t
.
Fatal
(
err
)
}
err
=
c2
.
Close
()
if
err
!=
nil
{
t
.
Fatal
(
err
)
}
err
=
c3
.
Close
()
if
err
!=
nil
{
t
.
Fatal
(
err
)
}
}
paddle/framework/CMakeLists.txt
浏览文件 @
7702ab1a
...
...
@@ -7,6 +7,9 @@ cc_library(tensor SRCS tensor.cc DEPS ddim place paddle_memory device_context)
cc_test
(
tensor_test SRCS tensor_test.cc DEPS tensor
)
cc_test
(
eigen_test SRCS eigen_test.cc DEPS tensor
)
cc_library
(
lod_tensor SRCS lod_tensor.cc details/lod_tensor.cc DEPS ddim place tensor
)
cc_test
(
lod_tensor_test SRCS lod_tensor_test.cc DEPS lod_tensor
)
cc_test
(
variable_test SRCS variable_test.cc
)
cc_library
(
scope SRCS scope.cc
)
...
...
paddle/framework/details/lod_tensor.cc
0 → 100644
浏览文件 @
7702ab1a
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/framework/lod_tensor.h"
#include <memory>
namespace
paddle
{
namespace
framework
{
namespace
details
{
using
LOD
=
LODTensor
::
LOD
;
std
::
shared_ptr
<
LOD
>
SliceLOD
(
const
LOD
&
lod
,
size_t
level_begin
,
size_t
level_end
)
{
auto
new_lod
=
std
::
make_shared
<
LOD
>
();
new_lod
->
reserve
(
level_end
-
level_begin
);
for
(
size_t
i
=
level_begin
;
i
<
level_end
;
i
++
)
{
new_lod
->
emplace_back
(
lod
[
i
]);
}
return
new_lod
;
}
std
::
shared_ptr
<
LOD
>
SliceLOD
(
const
LOD
&
lod
,
size_t
level
,
size_t
elem_begin
,
size_t
elem_end
,
bool
tensor_shared
)
{
// slice the lod.
auto
new_lod
=
std
::
make_shared
<
LOD
>
();
new_lod
->
reserve
(
lod
.
size
()
-
level
);
auto
start
=
lod
.
at
(
level
)[
elem_begin
];
auto
end
=
lod
.
at
(
level
)[
elem_end
];
for
(
auto
it
=
lod
.
begin
()
+
level
;
it
!=
lod
.
end
();
it
++
)
{
auto
it_begin
=
std
::
find
(
it
->
begin
(),
it
->
end
(),
start
);
auto
it_end
=
std
::
find
(
it_begin
,
it
->
end
(),
end
);
PADDLE_ENFORCE
(
it_begin
!=
it
->
end
(),
"error in parsing lod info"
);
PADDLE_ENFORCE
(
it_end
!=
it
->
end
(),
"error in parsing lod info"
);
new_lod
->
emplace_back
(
it_begin
,
it_end
+
1
);
if
(
!
tensor_shared
)
{
// reset offset if tensor is copyed and sliced.
std
::
transform
(
new_lod
->
back
().
begin
(),
new_lod
->
back
().
end
(),
new_lod
->
back
().
begin
(),
[
start
](
int
v
)
{
return
v
-
start
;
});
PADDLE_ENFORCE
(
new_lod
->
back
().
front
()
==
0
,
"error in slice LOD"
);
}
}
return
new_lod
;
}
}
// namespace details
}
// namespace framework
}
// namespace paddle
paddle/framework/details/lod_tensor.h
0 → 100644
浏览文件 @
7702ab1a
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <memory>
namespace
paddle
{
namespace
framework
{
namespace
details
{
/*
* Slice levels from LOD.
*
* @lod: LOD to slice.
* @level_begin: level to begin slice.
* @level_end: level to end slice.
*/
std
::
shared_ptr
<
LODTensor
::
LOD
>
SliceLOD
(
const
LODTensor
::
LOD
&
lod
,
size_t
level_begin
,
size_t
level_end
);
/*
* Slice elements from a level of LOD.
*
* @lod: LOD to slice.
* @level: which level to slice.
* @elem_begin: element's index to begin slice.
* @elem_end: element's index to end slice.
*/
std
::
shared_ptr
<
LODTensor
::
LOD
>
SliceLOD
(
const
LODTensor
::
LOD
&
lod
,
size_t
level
,
size_t
elem_begin
,
size_t
elem_end
,
bool
tensor_shared
);
}
// namespace details
}
// namespace framework
}
// namespace paddle
paddle/framework/lod_tensor.cc
0 → 100644
浏览文件 @
7702ab1a
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/framework/lod_tensor.h"
#include <glog/logging.h>
namespace
paddle
{
namespace
framework
{
LODTensor
LODTensor
::
SliceShared
(
size_t
level_begin
,
size_t
level_end
)
const
{
PADDLE_ENFORCE
(
HasLOD
(),
"has no LOD info, can't be sliced."
);
auto
new_lod
=
details
::
SliceLOD
(
*
lod_start_pos_
,
level_begin
,
level_end
);
// slice levels just need to update LOD info, each level will contains the
// whole tensor_, so no need to modify tensor_.
return
LODTensor
(
tensor_
,
new_lod
);
}
LODTensor
LODTensor
::
SliceShared
(
size_t
level
,
size_t
elem_begin
,
size_t
elem_end
)
const
{
PADDLE_ENFORCE
(
HasLOD
(),
"has no LOD info, can't be sliced."
);
PADDLE_ENFORCE
(
level
<
NumLevels
(),
"level [%d] out of range [%d]"
,
level
,
NumLevels
());
PADDLE_ENFORCE
(
elem_begin
<
NumElements
(
level
),
"element begin [%d] out of range [%d]"
,
elem_begin
,
NumElements
(
level
));
PADDLE_ENFORCE
(
elem_end
<
NumElements
(
level
)
+
1
,
"element end [%d] out of range [%d]"
,
elem_end
,
NumElements
(
level
));
auto
new_lod
=
details
::
SliceLOD
(
*
lod_start_pos_
,
level
,
elem_begin
,
elem_end
,
true
/*tensor_shared*/
);
// slice elements just need to update LOD info, because offsets are not
// changed, so the original tensor_ can be reused.
return
LODTensor
(
tensor_
,
new_lod
);
}
}
// namespace framework
}
// namespace paddle
paddle/framework/lod_tensor.h
0 → 100644
浏览文件 @
7702ab1a
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <memory>
#if (!PADDLE_ONLY_CPU)
#include <thrust/device_vector.h>
#include <thrust/host_vector.h>
#endif
#include "paddle/framework/ddim.h"
#include "paddle/framework/tensor.h"
#include "paddle/platform/enforce.h"
namespace
paddle
{
namespace
framework
{
/*
* LODTensor (Level of details Tensor)
* see https://en.wikipedia.org/wiki/Level_of_details for reference.
*/
class
LODTensor
{
public:
// Level save offsets of each unit.
#ifdef PADDLE_ONLY_CPU
using
Level
=
std
::
vector
<
size_t
>
;
#else
using
Level
=
thrust
::
device_vector
<
size_t
>
;
#endif
// LOD stores offsets of each level of units, the largest units level first,
// then the smaller units level. Each Level stores the offsets of units in
// Tesor.
typedef
std
::
vector
<
Level
>
LOD
;
LODTensor
()
{}
LODTensor
(
const
std
::
shared_ptr
<
Tensor
>
&
tensor
,
const
std
::
shared_ptr
<
LOD
>
&
lod
)
{
Reset
(
tensor
,
lod
);
}
void
Reset
(
const
std
::
shared_ptr
<
Tensor
>
&
tensor
,
const
std
::
shared_ptr
<
LOD
>
&
lod
)
{
tensor_
=
tensor
;
lod_start_pos_
=
lod
;
}
/*
* Get a element from LOD.
*/
size_t
lod_element
(
size_t
level
,
size_t
elem
)
const
{
PADDLE_ENFORCE
(
level
<
NumLevels
(),
"level [%d] out of range [%d]"
,
level
,
NumLevels
());
PADDLE_ENFORCE
(
elem
<
NumElements
(
level
),
"element begin [%d] out of range [%d]"
,
elem
,
NumElements
(
level
));
return
(
*
lod_start_pos_
)[
level
][
elem
];
}
/*
* Number of LODTensor's levels, each level has units of data, for example,
* in the sentence's view, article, paragraph, sentence are 3 levels.
*/
size_t
NumLevels
()
const
{
return
lod_start_pos_
?
lod_start_pos_
->
size
()
:
0UL
;
}
/*
* Number of elements in a level.
*/
size_t
NumElements
(
size_t
level
=
0
)
const
{
PADDLE_ENFORCE
(
level
<
NumLevels
(),
"level [%d] out of range [%d]"
,
level
,
NumLevels
());
// the last offset is the end of last element
return
lod_start_pos_
->
at
(
level
).
size
()
-
1
;
}
/*
* Slice of levels[level_begin:level_end], with tensor copied.
*/
template
<
typename
T
>
LODTensor
SliceCopied
(
size_t
level_begin
,
size_t
level_end
,
const
platform
::
Place
&
dst_place
)
const
;
/*
* Slice of levels[level_begin:level_end], with tensor shared.
*/
LODTensor
SliceShared
(
size_t
level_begin
,
size_t
level_end
)
const
;
/*
* Slice of elements of a level, [elem_begin: elem_end], with tensor copied.
* @note: low performance in slice lod_start_pos_.
*/
template
<
typename
T
>
LODTensor
SliceCopied
(
size_t
level
,
size_t
elem_begin
,
size_t
elem_end
,
const
platform
::
Place
&
dst_place
)
const
;
/*
* Slice of elements of a level, [elem_begin: elem_end], with tensor shared.
* @note: low performance in slice lod_start_pos_.
*/
LODTensor
SliceShared
(
size_t
level
,
size_t
elem_begin
,
size_t
elem_end
)
const
;
/*
* Copy other's lod_start_pos_, to share LOD info.
* @note: the LOD info should not be changed.
*/
void
ShareLOD
(
const
LODTensor
&
other
)
{
lod_start_pos_
=
other
.
lod_start_pos_
;
}
/*
* Copy other's lod_start_pos_'s content, free to mutate.
*/
void
CopyLOD
(
const
LODTensor
&
other
)
{
lod_start_pos_
=
std
::
make_shared
<
LOD
>
(
*
other
.
lod_start_pos_
);
}
/*
* Determine whether LODTensor has a valid LOD info.
*/
bool
HasLOD
()
const
{
return
bool
(
lod_start_pos_
);
}
LOD
*
lod
()
const
{
return
lod_start_pos_
.
get
();
}
std
::
shared_ptr
<
Tensor
>
&
tensor
()
{
return
tensor_
;
}
Tensor
*
raw_tensor
()
{
return
tensor_
.
get
();
}
private:
std
::
shared_ptr
<
LOD
>
lod_start_pos_
;
std
::
shared_ptr
<
Tensor
>
tensor_
;
};
}
// namespace framework
}
// namespace paddle
#include "paddle/framework/lod_tensor_impl.h"
paddle/framework/lod_tensor_impl.h
0 → 100644
浏览文件 @
7702ab1a
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/framework/details/lod_tensor.h"
namespace
paddle
{
namespace
framework
{
template
<
typename
T
>
LODTensor
LODTensor
::
SliceCopied
(
size_t
level_begin
,
size_t
level_end
,
const
platform
::
Place
&
dst_place
)
const
{
PADDLE_ENFORCE
(
HasLOD
(),
"has no LOD info, can't be sliced."
);
auto
new_lod
=
details
::
SliceLOD
(
*
lod_start_pos_
,
level_begin
,
level_end
);
auto
new_tensor
=
std
::
make_shared
<
Tensor
>
();
new_tensor
->
CopyFrom
<
T
>
(
*
tensor_
,
dst_place
);
return
LODTensor
(
new_tensor
,
new_lod
);
}
template
<
typename
T
>
LODTensor
LODTensor
::
SliceCopied
(
size_t
level
,
size_t
elem_begin
,
size_t
elem_end
,
const
platform
::
Place
&
dst_place
)
const
{
PADDLE_ENFORCE
(
HasLOD
(),
"has no LOD info, can't be sliced."
);
PADDLE_ENFORCE
(
level
<
NumLevels
(),
"level [%d] out of range [%d]"
,
level
,
NumLevels
());
PADDLE_ENFORCE
(
elem_begin
<
NumElements
(
level
),
"element begin [%d] out of range [%d]"
,
elem_begin
,
NumElements
(
level
));
PADDLE_ENFORCE
(
elem_end
<
NumElements
(
level
)
+
1
,
"element end [%d] out of range [%d]"
,
elem_end
,
NumElements
(
level
));
auto
new_lod
=
details
::
SliceLOD
(
*
lod_start_pos_
,
level
,
elem_begin
,
elem_end
,
false
/*tensor_shared*/
);
auto
start_idx
=
new_lod
->
front
().
front
();
auto
end_idx
=
new_lod
->
front
().
back
()
-
1
/*the next element's start*/
;
auto
sliced_tensor
=
tensor_
->
Slice
<
T
>
(
start_idx
,
end_idx
);
auto
new_tensor
=
std
::
make_shared
<
Tensor
>
();
new_tensor
->
CopyFrom
<
T
>
(
sliced_tensor
,
dst_place
);
return
LODTensor
(
new_tensor
,
new_lod
);
}
}
// namespace framework
}
// namespace paddle
paddle/framework/lod_tensor_test.cc
0 → 100644
浏览文件 @
7702ab1a
/*
Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "paddle/framework/lod_tensor.h"
#include <glog/logging.h>
#include <gtest/gtest.h>
#include <memory>
namespace
paddle
{
namespace
framework
{
class
LODTensorTester
:
public
::
testing
::
Test
{
public:
virtual
void
SetUp
()
override
{
lod_tensor
.
reset
(
new
LODTensor
);
// tensor's batch_size: 30
// 3 levels
// 0 10 20
// 0 5 10 15 20
// 0 2 5 7 10 12 15 20
auto
lod
=
std
::
make_shared
<
LODTensor
::
LOD
>
();
lod
->
push_back
(
std
::
vector
<
size_t
>
{
0
,
10
,
20
});
lod
->
push_back
(
std
::
vector
<
size_t
>
{
0
,
5
,
10
,
15
,
20
});
lod
->
push_back
(
std
::
vector
<
size_t
>
{
0
,
2
,
5
,
7
,
10
,
12
,
15
,
17
,
20
});
auto
tensor
=
std
::
make_shared
<
Tensor
>
();
tensor
->
Resize
({
20
/*batch size*/
,
128
/*dim*/
});
// malloc memory
tensor
->
mutable_data
<
float
>
(
place
);
lod_tensor
->
Reset
(
tensor
,
lod
);
}
protected:
std
::
unique_ptr
<
LODTensor
>
lod_tensor
;
platform
::
CPUPlace
place
;
};
TEST_F
(
LODTensorTester
,
NumLevels
)
{
ASSERT_EQ
(
lod_tensor
->
NumLevels
(),
3UL
);
}
TEST_F
(
LODTensorTester
,
NumElements
)
{
ASSERT_EQ
(
lod_tensor
->
NumElements
(
0
),
2UL
);
ASSERT_EQ
(
lod_tensor
->
NumElements
(
1
),
4UL
);
ASSERT_EQ
(
lod_tensor
->
NumElements
(
2
),
8UL
);
}
TEST_F
(
LODTensorTester
,
SliceShared_Level
)
{
// slice 1 level
for
(
size_t
level
=
0
;
level
<
3UL
;
++
level
)
{
auto
new_lod_tensor
=
lod_tensor
->
SliceShared
(
level
,
level
+
1
);
ASSERT_EQ
(
new_lod_tensor
.
NumLevels
(),
1UL
);
ASSERT_EQ
(
new_lod_tensor
.
NumElements
(
0UL
),
lod_tensor
->
NumElements
(
level
));
ASSERT_EQ
(
new_lod_tensor
.
tensor
(),
lod_tensor
->
tensor
());
}
// slice 2 level
for
(
size_t
level
=
0
;
level
<
2UL
;
++
level
)
{
auto
new_lod_tensor
=
lod_tensor
->
SliceShared
(
level
,
level
+
2
);
ASSERT_EQ
(
new_lod_tensor
.
NumLevels
(),
2UL
);
ASSERT_EQ
(
new_lod_tensor
.
NumElements
(
0
),
lod_tensor
->
NumElements
(
level
));
ASSERT_EQ
(
new_lod_tensor
.
NumElements
(
1
),
lod_tensor
->
NumElements
(
level
+
1
));
ASSERT_EQ
(
new_lod_tensor
.
tensor
(),
lod_tensor
->
tensor
());
}
}
TEST_F
(
LODTensorTester
,
SliceCopied_Level
)
{
// slice 1 level
for
(
size_t
level
=
0
;
level
<
3UL
;
++
level
)
{
auto
new_lod_tensor
=
lod_tensor
->
SliceCopied
<
float
>
(
level
,
level
+
1
,
place
);
ASSERT_EQ
(
new_lod_tensor
.
NumLevels
(),
1UL
);
ASSERT_EQ
(
new_lod_tensor
.
NumElements
(
0UL
),
lod_tensor
->
NumElements
(
level
));
// ASSERT_EQ(new_lod_tensor.tensor(), lod_tensor->tensor());
// TODO(superjom) add tensor comparation here.
}
// slice 2 level
for
(
size_t
level
=
0
;
level
<
2UL
;
++
level
)
{
auto
new_lod_tensor
=
lod_tensor
->
SliceCopied
<
float
>
(
level
,
level
+
2
,
place
);
ASSERT_EQ
(
new_lod_tensor
.
NumLevels
(),
2UL
);
ASSERT_EQ
(
new_lod_tensor
.
NumElements
(
0
),
lod_tensor
->
NumElements
(
level
));
ASSERT_EQ
(
new_lod_tensor
.
NumElements
(
1
),
lod_tensor
->
NumElements
(
level
+
1
));
// ASSERT_EQ(new_lod_tensor.tensor(), lod_tensor->tensor());
// TODO(superjom) add tensor comparation here.
}
}
TEST_F
(
LODTensorTester
,
SliceShared_Element
)
{
size_t
level
=
0
;
auto
new_lod_tensor
=
lod_tensor
->
SliceShared
(
level
,
0
,
2
);
ASSERT_EQ
(
new_lod_tensor
.
NumLevels
(),
3UL
);
ASSERT_EQ
(
new_lod_tensor
.
NumElements
(
0
),
2UL
);
ASSERT_EQ
(
new_lod_tensor
.
NumElements
(
1
),
4UL
);
ASSERT_EQ
(
new_lod_tensor
.
NumElements
(
2
),
8UL
);
ASSERT_EQ
(
new_lod_tensor
.
raw_tensor
(),
lod_tensor
->
raw_tensor
());
level
=
1
;
new_lod_tensor
=
lod_tensor
->
SliceShared
(
level
,
0
,
2
);
ASSERT_EQ
(
new_lod_tensor
.
NumLevels
(),
2UL
);
ASSERT_EQ
(
new_lod_tensor
.
NumElements
(
0
),
2UL
);
ASSERT_EQ
(
new_lod_tensor
.
NumElements
(
1
),
4UL
);
ASSERT_EQ
(
new_lod_tensor
.
raw_tensor
(),
lod_tensor
->
raw_tensor
());
}
TEST_F
(
LODTensorTester
,
SliceCopied_Element
)
{
size_t
level
=
0
;
auto
new_lod_tensor
=
lod_tensor
->
SliceCopied
<
float
>
(
level
,
0
,
2
,
place
);
ASSERT_EQ
(
new_lod_tensor
.
NumLevels
(),
3UL
);
ASSERT_EQ
(
new_lod_tensor
.
NumElements
(
0
),
2UL
);
ASSERT_EQ
(
new_lod_tensor
.
NumElements
(
1
),
4UL
);
ASSERT_EQ
(
new_lod_tensor
.
NumElements
(
2
),
8UL
);
ASSERT_NE
(
new_lod_tensor
.
raw_tensor
(),
lod_tensor
->
raw_tensor
());
level
=
1
;
new_lod_tensor
=
lod_tensor
->
SliceCopied
<
float
>
(
level
,
0
,
2
,
place
);
ASSERT_EQ
(
new_lod_tensor
.
NumLevels
(),
2UL
);
ASSERT_EQ
(
new_lod_tensor
.
NumElements
(
0
),
2UL
);
ASSERT_EQ
(
new_lod_tensor
.
NumElements
(
1
),
4UL
);
ASSERT_NE
(
new_lod_tensor
.
raw_tensor
(),
lod_tensor
->
raw_tensor
());
level
=
1
;
// LOD is
// 0 5 10
// 0 2 5 7 10
new_lod_tensor
=
lod_tensor
->
SliceCopied
<
float
>
(
level
,
1
,
3
,
place
);
ASSERT_EQ
(
new_lod_tensor
.
NumLevels
(),
2UL
);
ASSERT_EQ
(
new_lod_tensor
.
NumElements
(
0
),
2UL
);
ASSERT_EQ
(
new_lod_tensor
.
NumElements
(
1
),
4UL
);
ASSERT_EQ
(
new_lod_tensor
.
lod_element
(
0
,
0
),
0UL
);
ASSERT_EQ
(
new_lod_tensor
.
lod_element
(
0
,
1
),
5UL
);
ASSERT_EQ
(
new_lod_tensor
.
lod_element
(
1
,
0
),
0UL
);
ASSERT_EQ
(
new_lod_tensor
.
lod_element
(
1
,
1
),
2UL
);
ASSERT_EQ
(
new_lod_tensor
.
lod_element
(
1
,
2
),
5UL
);
ASSERT_EQ
(
new_lod_tensor
.
lod_element
(
1
,
3
),
7UL
);
// TODO(superjom) compare the content of these tensors
}
TEST_F
(
LODTensorTester
,
ShareLOD
)
{
LODTensor
new_lod_tensor
;
new_lod_tensor
.
ShareLOD
(
*
lod_tensor
);
ASSERT_EQ
(
new_lod_tensor
.
lod
(),
lod_tensor
->
lod
());
}
TEST_F
(
LODTensorTester
,
CopyLOD
)
{
LODTensor
new_lod_tensor
;
new_lod_tensor
.
CopyLOD
(
*
lod_tensor
);
ASSERT_NE
(
new_lod_tensor
.
lod
(),
lod_tensor
->
lod
());
}
}
// namespace framework
}
// namespace paddle
paddle/framework/operator.h
浏览文件 @
7702ab1a
...
...
@@ -117,10 +117,10 @@ class OperatorBase {
AttributeMap
attrs_
;
};
class
Operator
Context
{
class
InferShape
Context
{
public:
OperatorContext
(
const
OperatorBase
*
op
,
const
Scope
&
scope
)
:
op_
(
*
op
),
scope_
(
scope
)
{}
InferShapeContext
(
const
OperatorBase
&
op
,
const
Scope
&
scope
)
:
op_
(
op
),
scope_
(
scope
)
{}
size_t
InputSize
(
const
std
::
string
&
name
)
const
{
return
op_
.
inputs_
.
at
(
name
).
size
();
...
...
@@ -209,12 +209,6 @@ class OperatorContext {
const
Scope
&
scope_
;
};
class
InferShapeContext
:
public
OperatorContext
{
public:
InferShapeContext
(
const
OperatorBase
*
op
,
const
Scope
&
scope
)
:
OperatorContext
(
op
,
scope
)
{}
};
template
<
typename
T
>
struct
EigenDeviceConverter
;
...
...
@@ -230,11 +224,11 @@ struct EigenDeviceConverter<platform::GPUPlace> {
};
#endif
class
ExecutionContext
:
public
Operator
Context
{
class
ExecutionContext
:
public
InferShape
Context
{
public:
ExecutionContext
(
const
OperatorBase
*
op
,
const
Scope
&
scope
,
ExecutionContext
(
const
OperatorBase
&
op
,
const
Scope
&
scope
,
const
platform
::
DeviceContext
*
device_context
)
:
Operator
Context
(
op
,
scope
),
device_context_
(
device_context
)
{}
:
InferShape
Context
(
op
,
scope
),
device_context_
(
device_context
)
{}
template
<
typename
PlaceType
,
typename
DeviceType
=
...
...
@@ -286,13 +280,13 @@ class OperatorWithKernel : public OperatorBase {
std
::
unordered_map
<
OpKernelKey
,
std
::
unique_ptr
<
OpKernel
>
,
OpKernelHash
>
;
void
InferShape
(
const
Scope
&
scope
)
const
override
{
InferShape
(
InferShapeContext
(
this
,
scope
));
InferShape
(
InferShapeContext
(
*
this
,
scope
));
}
void
Run
(
const
Scope
&
scope
,
const
platform
::
DeviceContext
&
dev_ctx
)
const
final
{
auto
&
opKernel
=
AllOpKernels
().
at
(
type_
).
at
(
OpKernelKey
(
dev_ctx
));
opKernel
->
Compute
(
ExecutionContext
(
this
,
scope
,
&
dev_ctx
));
opKernel
->
Compute
(
ExecutionContext
(
*
this
,
scope
,
&
dev_ctx
));
}
static
std
::
unordered_map
<
std
::
string
/* op_type */
,
OpKernelMap
>&
...
...
paddle/framework/tensor.h
浏览文件 @
7702ab1a
...
...
@@ -18,6 +18,8 @@ limitations under the License. */
#include <cstring>
#include <memory>
#include <typeindex>
#include <vector>
#include "paddle/framework/ddim.h"
#include "paddle/memory/memory.h"
#include "paddle/platform/device_context.h"
...
...
paddle/framework/tensor_test.cc
浏览文件 @
7702ab1a
...
...
@@ -19,7 +19,7 @@ TEST(Tensor, Dims) {
using
namespace
paddle
::
framework
;
using
namespace
paddle
::
platform
;
Tensor
tt
;
tt
.
Resize
(
make_ddim
({
2
,
3
,
4
})
);
tt
.
Resize
(
{
2
,
3
,
4
}
);
DDim
dims
=
tt
.
dims
();
ASSERT_EQ
(
arity
(
dims
),
3
);
for
(
int
i
=
0
;
i
<
3
;
++
i
)
{
...
...
paddle/function/FunctionTest.cpp
浏览文件 @
7702ab1a
...
...
@@ -93,8 +93,8 @@ TEST(Arguments, Matrix) {
MatrixPtr
matrix
=
Matrix
::
create
(
100
,
200
);
CheckBufferArg
check
=
[
=
](
const
BufferArg
&
arg
)
{
EXPECT_EQ
(
arg
.
shape
().
ndims
(),
2U
);
EXPECT_EQ
(
arg
.
shape
()[
0
],
100
);
EXPECT_EQ
(
arg
.
shape
()[
1
],
200
);
EXPECT_EQ
(
arg
.
shape
()[
0
],
100
U
);
EXPECT_EQ
(
arg
.
shape
()[
1
],
200
U
);
EXPECT_EQ
(
arg
.
data
(),
matrix
->
getData
());
EXPECT_EQ
(
arg
.
matrix
<
DEVICE_TYPE_CPU
>
().
getHeight
(),
matrix
->
getHeight
());
...
...
@@ -112,8 +112,8 @@ TEST(Arguments, Matrix) {
TEST
(
Arguments
,
Vector
)
{
VectorPtr
vector
=
Vector
::
create
(
100
,
false
);
CheckBufferArg
check
=
[
=
](
const
BufferArg
&
arg
)
{
EXPECT_EQ
(
arg
.
shape
().
ndims
(),
1
);
EXPECT_EQ
(
arg
.
shape
()[
0
],
100
);
EXPECT_EQ
(
arg
.
shape
().
ndims
(),
1
U
);
EXPECT_EQ
(
arg
.
shape
()[
0
],
100
U
);
EXPECT_EQ
(
arg
.
data
(),
vector
->
getData
());
CpuVector
inVector
=
arg
.
vector
<
real
,
DEVICE_TYPE_CPU
>
();
...
...
@@ -131,9 +131,9 @@ TEST(Arguments, Vector) {
TEST
(
Arguments
,
CpuSparseMatrix
)
{
CpuSparseMatrix
sparse
(
200
,
300
,
50
);
CheckBufferArg
check
=
[
=
](
const
BufferArg
&
arg
)
{
EXPECT_EQ
(
arg
.
shape
().
ndims
(),
2
);
EXPECT_EQ
(
arg
.
shape
()[
0
],
200
);
EXPECT_EQ
(
arg
.
shape
()[
1
],
300
);
EXPECT_EQ
(
arg
.
shape
().
ndims
(),
2
U
);
EXPECT_EQ
(
arg
.
shape
()[
0
],
200
U
);
EXPECT_EQ
(
arg
.
shape
()[
1
],
300
U
);
EXPECT_EQ
(
arg
.
data
(),
sparse
.
getData
());
// CHECK_EQ(arg.sparse().nnz(), 50);
// CHECK_EQ(arg.sparse().dataFormat(), SPARSE_CSR_FORMAT);
...
...
@@ -152,10 +152,10 @@ TEST(Arguments, CpuSparseMatrix) {
TEST
(
Arguments
,
BufferArg
)
{
BufferArg
arg
(
nullptr
,
VALUE_TYPE_FLOAT
,
{
1
,
2
,
3
});
CheckBufferArg
check
=
[
=
](
const
BufferArg
&
arg
)
{
EXPECT_EQ
(
arg
.
shape
().
ndims
(),
3
);
EXPECT_EQ
(
arg
.
shape
()[
0
],
1
);
EXPECT_EQ
(
arg
.
shape
()[
1
],
2
);
EXPECT_EQ
(
arg
.
shape
()[
2
],
3
);
EXPECT_EQ
(
arg
.
shape
().
ndims
(),
3
U
);
EXPECT_EQ
(
arg
.
shape
()[
0
],
1
U
);
EXPECT_EQ
(
arg
.
shape
()[
1
],
2
U
);
EXPECT_EQ
(
arg
.
shape
()[
2
],
3
U
);
};
BufferArgs
argments
;
...
...
paddle/function/TensorShapeTest.cpp
浏览文件 @
7702ab1a
...
...
@@ -44,7 +44,7 @@ TEST(TensorShape, GetAndSet) {
EXPECT_EQ
(
t
.
ndims
(),
3U
);
EXPECT_EQ
(
t
.
getElements
(),
6U
);
EXPECT_EQ
(
t
[
1
],
2
);
EXPECT_EQ
(
t
[
1
],
2
U
);
t
.
setDim
(
1
,
100
);
EXPECT_EQ
(
t
.
getElements
(),
300U
);
EXPECT_EQ
(
t
[
1
],
100U
);
...
...
paddle/gserver/layers/SubNestedSequenceLayer.cpp
浏览文件 @
7702ab1a
...
...
@@ -96,7 +96,7 @@ void SubNestedSequenceLayer::calSelectedCols(
for
(
size_t
i
=
0
;
i
<
seqNum
;
++
i
)
{
for
(
size_t
j
=
0
;
j
<
beamSize
;
++
j
)
{
if
(
selectedIndices
->
getElement
(
i
,
j
)
==
-
1.
)
break
;
in
t
selSubSeqIdx
=
selectedIndices
->
getElement
(
i
,
j
);
size_
t
selSubSeqIdx
=
selectedIndices
->
getElement
(
i
,
j
);
CHECK_GT
(
inputSeqInfoVec_
[
i
].
size
()
-
1
,
selSubSeqIdx
);
size_t
subSeqLen
=
inputSeqInfoVec_
[
i
][
selSubSeqIdx
+
1
]
-
...
...
@@ -135,7 +135,7 @@ void SubNestedSequenceLayer::forward(PassType passType) {
CHECK
(
inputSeq
.
hasSubseq
())
<<
"The first input of SubNestSequence layer "
<<
"must be a nested sequence."
;
const
MatrixPtr
selectedIndices
=
getInputValue
(
1
);
CHECK_EQ
(
inputSeq
.
getNumSequences
(
),
selectedIndices
->
getHeight
());
CHECK_EQ
(
size_t
(
inputSeq
.
getNumSequences
()
),
selectedIndices
->
getHeight
());
if
(
dynamic_cast
<
GpuMatrix
*>
(
selectedIndices
.
get
()))
{
/*
...
...
paddle/gserver/tests/test_KmaxSeqScore.cpp
浏览文件 @
7702ab1a
...
...
@@ -88,7 +88,7 @@ void checkLayerOut(vector<vector<int>> groundTruth,
TEST
(
Layer
,
kmaxSeqScoreLayer
)
{
const
size_t
maxBeamSize
=
100
;
in
t
beamSize
=
1
+
(
rand
()
%
maxBeamSize
);
size_
t
beamSize
=
1
+
(
rand
()
%
maxBeamSize
);
vector
<
int
>
seqStartPosition
;
vector
<
int
>
subSeqStartPosition
;
...
...
paddle/platform/CMakeLists.txt
浏览文件 @
7702ab1a
...
...
@@ -8,7 +8,7 @@ cc_test(place_test SRCS place_test.cc DEPS place glog gflags)
add_subdirectory
(
dynload
)
cc_test
(
enforce_test SRCS enforce_test.cc
)
cc_test
(
enforce_test SRCS enforce_test.cc
DEPS stringpiece
)
IF
(
WITH_GPU
)
set
(
GPU_CTX_DEPS dynload_cuda dynamic_loader
)
...
...
paddle/platform/enforce_test.cc
浏览文件 @
7702ab1a
...
...
@@ -13,6 +13,10 @@ limitations under the License. */
#include "gtest/gtest.h"
#include "paddle/platform/enforce.h"
#include "paddle/string/piece.h"
using
StringPiece
=
paddle
::
string
::
Piece
;
using
paddle
::
string
::
HasPrefix
;
TEST
(
ENFORCE
,
OK
)
{
PADDLE_ENFORCE
(
true
,
"Enforce is ok %d now %f"
,
123
,
0.345
);
...
...
@@ -22,19 +26,15 @@ TEST(ENFORCE, OK) {
}
TEST
(
ENFORCE
,
FAILED
)
{
bool
in_catch
=
false
;
bool
caught_exception
=
false
;
try
{
PADDLE_ENFORCE
(
false
,
"Enforce is not ok %d at all"
,
123
);
}
catch
(
paddle
::
platform
::
EnforceNotMet
error
)
{
// your error handling code here
in_catch
=
true
;
std
::
string
msg
=
"Enforce is not ok 123 at all"
;
const
char
*
what
=
error
.
what
();
for
(
size_t
i
=
0
;
i
<
msg
.
length
();
++
i
)
{
ASSERT_EQ
(
what
[
i
],
msg
[
i
]);
}
caught_exception
=
true
;
EXPECT_TRUE
(
HasPrefix
(
StringPiece
(
error
.
what
()),
"Enforce is not ok 123 at all"
));
}
ASSERT_TRUE
(
in_catch
);
EXPECT_TRUE
(
caught_exception
);
}
TEST
(
ENFORCE
,
NO_ARG_OK
)
{
...
...
@@ -47,41 +47,27 @@ TEST(ENFORCE, NO_ARG_OK) {
TEST
(
ENFORCE_EQ
,
NO_EXTRA_MSG_FAIL
)
{
int
a
=
2
;
bool
in_catch
=
false
;
bool
caught_exception
=
false
;
try
{
PADDLE_ENFORCE_EQ
(
a
,
1
+
3
);
}
catch
(
paddle
::
platform
::
EnforceNotMet
error
)
{
in_catch
=
true
;
const
std
::
string
msg
=
"enforce a == 1 + 3 failed, 2 != 4"
;
const
char
*
what
=
error
.
what
();
for
(
size_t
i
=
0
;
i
<
msg
.
length
();
++
i
)
{
ASSERT_EQ
(
what
[
i
],
msg
[
i
]);
}
caught_exception
=
true
;
HasPrefix
(
StringPiece
(
error
.
what
()),
"enforce a == 1 + 3 failed, 2 != 4"
);
}
ASSERT_TRUE
(
in_catch
);
EXPECT_TRUE
(
caught_exception
);
}
TEST
(
ENFORCE_EQ
,
EXTRA_MSG_FAIL
)
{
int
a
=
2
;
bool
in_catch
=
false
;
bool
caught_exception
=
false
;
try
{
PADDLE_ENFORCE_EQ
(
a
,
1
+
3
,
"%s size not match"
,
"their"
);
}
catch
(
paddle
::
platform
::
EnforceNotMet
error
)
{
in_catch
=
true
;
const
std
::
string
msg
=
"enforce a == 1 + 3 failed, 2 != 4
\n
their size not match"
;
const
char
*
what
=
error
.
what
();
for
(
size_t
i
=
0
;
i
<
msg
.
length
();
++
i
)
{
ASSERT_EQ
(
what
[
i
],
msg
[
i
]);
}
caught_exception
=
true
;
HasPrefix
(
StringPiece
(
error
.
what
()),
"enforce a == 1 + 3 failed, 2 != 4
\n
their size not match"
);
}
ASSERT_TRUE
(
in_catch
);
EXPECT_TRUE
(
caught_exception
);
}
TEST
(
ENFORCE_NE
,
OK
)
{
...
...
@@ -89,42 +75,32 @@ TEST(ENFORCE_NE, OK) {
PADDLE_ENFORCE_NE
(
1.0
,
2UL
);
}
TEST
(
ENFORCE_NE
,
FAIL
)
{
bool
in_catch
=
false
;
bool
caught_exception
=
false
;
try
{
// 2UL here to check data type compatible
PADDLE_ENFORCE_NE
(
1.0
,
1UL
);
}
catch
(
paddle
::
platform
::
EnforceNotMet
error
)
{
in_catch
=
true
;
const
std
::
string
msg
=
"enforce 1.0 != 1UL failed, 1.000000 == 1"
;
const
char
*
what
=
error
.
what
();
for
(
size_t
i
=
0
;
i
<
msg
.
length
();
++
i
)
{
ASSERT_EQ
(
what
[
i
],
msg
[
i
]);
}
caught_exception
=
true
;
EXPECT_TRUE
(
HasPrefix
(
StringPiece
(
error
.
what
()),
"enforce 1.0 != 1UL failed, 1.000000 == 1"
))
<<
error
.
what
()
<<
" does not have expected prefix"
;
}
ASSERT_TRUE
(
in_catch
);
EXPECT_TRUE
(
caught_exception
);
}
TEST
(
ENFORCE_GT
,
OK
)
{
PADDLE_ENFORCE_GT
(
2
,
1
);
}
TEST
(
ENFORCE_GT
,
FAIL
)
{
bool
in_catch
=
false
;
bool
caught_exception
=
false
;
try
{
// 2UL here to check data type compatible
PADDLE_ENFORCE_GT
(
1
,
2UL
);
}
catch
(
paddle
::
platform
::
EnforceNotMet
error
)
{
in_catch
=
true
;
const
std
::
string
msg
=
"enforce 1 > 2UL failed, 1 <= 2"
;
const
char
*
what
=
error
.
what
();
for
(
size_t
i
=
0
;
i
<
msg
.
length
();
++
i
)
{
ASSERT_EQ
(
what
[
i
],
msg
[
i
]);
}
caught_exception
=
true
;
EXPECT_TRUE
(
HasPrefix
(
StringPiece
(
error
.
what
()),
"enforce 1 > 2UL failed, 1 <= 2"
));
}
ASSERT_TRUE
(
in_catch
);
EXPECT_TRUE
(
caught_exception
);
}
TEST
(
ENFORCE_GE
,
OK
)
{
...
...
@@ -134,21 +110,16 @@ TEST(ENFORCE_GE, OK) {
PADDLE_ENFORCE_GE
(
3.21
,
2UL
);
}
TEST
(
ENFORCE_GE
,
FAIL
)
{
bool
in_catch
=
false
;
bool
caught_exception
=
false
;
try
{
PADDLE_ENFORCE_GE
(
1
,
2UL
);
}
catch
(
paddle
::
platform
::
EnforceNotMet
error
)
{
in_catch
=
true
;
const
std
::
string
msg
=
"enforce 1 >= 2UL failed, 1 < 2"
;
const
char
*
what
=
error
.
what
();
for
(
size_t
i
=
0
;
i
<
msg
.
length
();
++
i
)
{
ASSERT_EQ
(
what
[
i
],
msg
[
i
]);
}
caught_exception
=
true
;
EXPECT_TRUE
(
HasPrefix
(
StringPiece
(
error
.
what
()),
"enforce 1 >= 2UL failed, 1 < 2"
));
}
ASSERT_TRUE
(
in_catch
);
EXPECT_TRUE
(
caught_exception
);
}
TEST
(
ENFORCE_LE
,
OK
)
{
...
...
@@ -159,21 +130,16 @@ TEST(ENFORCE_LE, OK) {
PADDLE_ENFORCE_LE
(
2UL
,
3.2
);
}
TEST
(
ENFORCE_LE
,
FAIL
)
{
bool
in_catch
=
false
;
bool
caught_exception
=
false
;
try
{
PADDLE_ENFORCE_GT
(
1
,
2UL
);
}
catch
(
paddle
::
platform
::
EnforceNotMet
error
)
{
in_catch
=
true
;
const
std
::
string
msg
=
"enforce 1 > 2UL failed, 1 <= 2"
;
const
char
*
what
=
error
.
what
();
for
(
size_t
i
=
0
;
i
<
msg
.
length
();
++
i
)
{
ASSERT_EQ
(
what
[
i
],
msg
[
i
]);
}
caught_exception
=
true
;
EXPECT_TRUE
(
HasPrefix
(
StringPiece
(
error
.
what
()),
"enforce 1 > 2UL failed, 1 <= 2"
));
}
ASSERT_TRUE
(
in_catch
);
EXPECT_TRUE
(
caught_exception
);
}
TEST
(
ENFORCE_LT
,
OK
)
{
...
...
@@ -182,21 +148,15 @@ TEST(ENFORCE_LT, OK) {
PADDLE_ENFORCE_LT
(
2UL
,
3
);
}
TEST
(
ENFORCE_LT
,
FAIL
)
{
bool
in_catch
=
false
;
bool
caught_exception
=
false
;
try
{
PADDLE_ENFORCE_LT
(
1UL
,
0.12
);
}
catch
(
paddle
::
platform
::
EnforceNotMet
error
)
{
in_catch
=
true
;
const
std
::
string
msg
=
"enforce 1UL < 0.12 failed, 1 >= 0.12"
;
const
char
*
what
=
error
.
what
();
for
(
size_t
i
=
0
;
i
<
msg
.
length
();
++
i
)
{
ASSERT_EQ
(
what
[
i
],
msg
[
i
]);
}
caught_exception
=
true
;
EXPECT_TRUE
(
HasPrefix
(
StringPiece
(
error
.
what
()),
"enforce 1UL < 0.12 failed, 1 >= 0.12"
));
}
ASSERT_TRUE
(
in_catch
);
EXPECT_TRUE
(
caught_exception
);
}
TEST
(
ENFORCE_NOT_NULL
,
OK
)
{
...
...
@@ -205,20 +165,14 @@ TEST(ENFORCE_NOT_NULL, OK) {
delete
a
;
}
TEST
(
ENFORCE_NOT_NULL
,
FAIL
)
{
bool
in_catch
=
false
;
int
*
a
{
nullptr
};
bool
caught_exception
=
false
;
try
{
int
*
a
=
nullptr
;
PADDLE_ENFORCE_NOT_NULL
(
a
);
}
catch
(
paddle
::
platform
::
EnforceNotMet
error
)
{
in_catch
=
true
;
const
std
::
string
msg
=
"a should not be null"
;
const
char
*
what
=
error
.
what
();
for
(
size_t
i
=
0
;
i
<
msg
.
length
();
++
i
)
{
ASSERT_EQ
(
what
[
i
],
msg
[
i
]);
}
caught_exception
=
true
;
EXPECT_TRUE
(
HasPrefix
(
StringPiece
(
error
.
what
()),
"a should not be null"
));
}
ASSERT_TRUE
(
in_catch
);
EXPECT_TRUE
(
caught_exception
);
}
paddle/trainer/NewRemoteParameterUpdater.cpp
浏览文件 @
7702ab1a
...
...
@@ -50,8 +50,8 @@ void NewRemoteParameterUpdater::init(
// create parameter server client.
if
(
useEtcd_
)
{
parameterClient_
=
paddle_new_etcd_pserver_client
(
(
char
*
)
pserverSpec_
.
c_str
(),
FLAGS_trainer_id
==
0
);
parameterClient_
=
paddle_new_etcd_pserver_client
((
char
*
)
pserverSpec_
.
c_str
()
);
}
else
{
parameterClient_
=
paddle_new_pserver_client
((
char
*
)
pserverSpec_
.
c_str
(),
FLAGS_trainer_id
==
0
);
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录