Commit bc146e8f
Authored on Aug 01, 2017 by Yu Yang

    Merge branch 'develop' of github.com:baidu/Paddle into feature/backward

Parents: 213fdad1, 61ebacbc

Showing 51 changed files with 2282 additions and 442 deletions (+2282 -442).
  cmake/external/eigen.cmake                              +1    -10
  cmake/flags.cmake                                       +1    -1
  go/cmd/master/master.go                                 +24   -4
  go/cmd/pserver/pserver.go                               +26   -5
  go/glide.lock                                           +168  -8
  go/glide.yaml                                           +11   -0
  go/master/etcd_client.go                                +19   -6
  go/master/inmem_store.go                                +5    -0
  go/master/service.go                                    +1    -0
  go/master/service_test.go                               +68   -0
  go/pserver/client/c/cclient.go                          +3    -3
  go/pserver/etcd_client.go                               +56   -42
  paddle/cuda/src/hl_cuda_sequence.cu                     +1    -2
  paddle/framework/detail/tensor-inl.h                    +12   -30
  paddle/framework/net_op_test.cc                         +1    -2
  paddle/framework/operator.cc                            +2    -2
  paddle/framework/operator.h                             +106  -75
  paddle/framework/operator_test.cc                       +32   -6
  paddle/framework/tensor.h                               +9    -13
  paddle/framework/tensor_test.cc                         +9    -9
  paddle/math/tests/test_matrixCompare.cpp                +60   -0
  paddle/memory/memcpy.cc                                 +3    -3
  paddle/memory/memcpy.h                                  +26   -0
  paddle/memory/memory.cc                                 +1    -0
  paddle/memory/memory.h                                  +37   -3
  paddle/operators/CMakeLists.txt                         +5    -0
  paddle/operators/add_op.cc                              +10   -11
  paddle/operators/add_op.h                               +6    -5
  paddle/operators/cross_entropy_op.cc                    +9    -9
  paddle/operators/cross_entropy_op.h                     +8    -8
  paddle/operators/mul_op.cc                              +7    -9
  paddle/operators/mul_op.h                               +5    -7
  paddle/operators/recurrent_network_op.cc                +419  -0
  paddle/operators/recurrent_network_op.h                 +216  -0
  paddle/operators/recurrent_network_op_test.cc           +400  -0
  paddle/operators/rnn_design.md                          +239  -0
  paddle/operators/rowwise_add_op.cc                      +7    -7
  paddle/operators/rowwise_add_op.h                       +4    -6
  paddle/operators/sgd_op.cc                              +8    -9
  paddle/operators/sgd_op.h                               +5    -5
  paddle/operators/sigmoid_op.cc                          +5    -7
  paddle/operators/sigmoid_op.h                           +4    -5
  paddle/operators/softmax_op.cc                          +7    -9
  paddle/operators/softmax_op.h                           +4    -4
  paddle/operators/type_alias.h                           +3    -1
  paddle/platform/device_context.cc                       +90   -1
  paddle/platform/device_context.h                        +35   -109
  paddle/platform/enforce.h                               +5    -5
  paddle/pybind/CMakeLists.txt                            +1    -1
  paddle/pybind/pybind.cc                                 +6    -0
  python/paddle/v2/framework/tests/test_recurrent_op.py   +92   -0
cmake/external/eigen.cmake

@@ -7,17 +7,8 @@ INCLUDE_DIRECTORIES(${EIGEN_SOURCE_DIR}/src/extern_eigen3)
 ExternalProject_Add(
     extern_eigen3
     ${EXTERNAL_PROJECT_LOG_ARGS}
-    # for latest version, please get from official website
-    # URL "https://bitbucket.org/eigen/eigen/get/3.3.4.tar.gz"
-    # URL_MD5 "1a47e78efe365a97de0c022d127607c3"
-    # for no-ssl http support, please get from bazel's mirror
-    # URL "http://mirror.bazel.build/bitbucket.org/eigen/eigen/get/f3a22f35b044.tar.gz"
-    # URL_MD5 "4645c66075982da6fa0bcf6b20f3e8f7"
-    # get from github mirror
     GIT_REPOSITORY  "https://github.com/RLovelett/eigen.git"
-    GIT_TAG         "a46d2e7337c4656f00abe54a8115f6d76153a048"
+    GIT_TAG         "master"
     PREFIX          ${EIGEN_SOURCE_DIR}
     UPDATE_COMMAND  ""
     CONFIGURE_COMMAND ""
cmake/flags.cmake

@@ -153,7 +153,7 @@ set(CUDA_PROPAGATE_HOST_FLAGS OFF)
 # Release/Debug flags set by cmake. Such as -O3 -g -DNDEBUG etc.
 # So, don't set these flags here.
-LIST(APPEND CUDA_NVCC_FLAGS -std=c++11)
+LIST(APPEND CUDA_NVCC_FLAGS -std=c++11 --default-stream per-thread)
 LIST(APPEND CUDA_NVCC_FLAGS --use_fast_math)
 if(CMAKE_BUILD_TYPE STREQUAL "Debug")
go/cmd/master/master.go

@@ -19,6 +19,8 @@ import (
 	"net"
 	"net/http"
 	"net/rpc"
+	"os"
+	"os/signal"
 	"strconv"
 	"strings"
 	"time"

@@ -68,6 +70,20 @@ func main() {
 		store = &master.InMemStore{}
 	}

+	shutdown := func() {
+		log.Infoln("shutting down gracefully")
+		err := store.Shutdown()
+		if err != nil {
+			log.Errorln(err)
+		}
+	}
+
+	// Guaranteed to run even panic happens.
+	defer shutdown()
+
+	c := make(chan os.Signal, 1)
+	signal.Notify(c, os.Interrupt)
+
 	s, err := master.NewService(store, *chunkPerTask, *taskTimeoutDur, *taskTimeoutMax)
 	if err != nil {
 		log.Fatal(err)

@@ -84,8 +100,12 @@ func main() {
 		log.Fatal(err)
 	}

-	err = http.Serve(l, nil)
-	if err != nil {
-		log.Fatal(err)
-	}
+	go func() {
+		err = http.Serve(l, nil)
+		if err != nil {
+			log.Fatal(err)
+		}
+	}()
+
+	<-c
 }
go/cmd/pserver/pserver.go

@@ -18,6 +18,8 @@ import (
 	"net"
 	"net/http"
 	"net/rpc"
+	"os"
+	"os/signal"
 	"strconv"
 	"time"

@@ -33,7 +35,8 @@ func main() {
 	index := flag.Int("index", -1, "index of this pserver, should be larger or equal than 0")
 	etcdEndpoint := flag.String("etcd-endpoint", "http://127.0.0.1:2379",
 		"comma separated endpoint string for pserver to connect to etcd")
-	etcdTimeout := flag.Duration("etcd-timeout", 5*time.Second, "timeout for etcd calls")
+	dialTimeout := flag.Duration("dial-timeout", 5*time.Second, "dial timeout")
+	etcdTTL := flag.Int("etcd-ttl", 5, "etcd time to live in seconds")
 	numPservers := flag.Int("num-pservers", 1, "total pserver count in a training job")
 	checkpointPath := flag.String("checkpoint-path", "/checkpoints/", "save checkpoint path")
 	checkpointInterval := flag.Duration("checkpoint-interval", 600*time.Second, "save checkpoint per interval seconds")

@@ -53,7 +56,7 @@ func main() {
 	if *index >= 0 {
 		idx = *index
 	} else {
-		e = pserver.NewEtcdClient(*etcdEndpoint, *numPservers, *etcdTimeout)
+		e = pserver.NewEtcdClient(*etcdEndpoint, *numPservers, *dialTimeout, *etcdTTL)
 		idx, err = e.Register(*port)
 		candy.Must(err)

@@ -67,6 +70,20 @@ func main() {
 		}
 	}

+	shutdown := func() {
+		log.Infoln("shutting down gracefully")
+		sErr := e.Shutdown()
+		if sErr != nil {
+			log.Errorln(sErr)
+		}
+	}
+
+	// Guaranteed to run even panic happens.
+	defer shutdown()
+
+	c := make(chan os.Signal, 1)
+	signal.Notify(c, os.Interrupt)
+
 	s, err := pserver.NewService(idx, *checkpointInterval, *checkpointPath, e, cp)
 	candy.Must(err)

@@ -77,7 +94,11 @@ func main() {
 	l, err := net.Listen("tcp", ":"+strconv.Itoa(*port))
 	candy.Must(err)

-	log.Infof("start pserver at port %d", *port)
-	err = http.Serve(l, nil)
-	candy.Must(err)
+	go func() {
+		log.Infof("start pserver at port %d", *port)
+		err = http.Serve(l, nil)
+		candy.Must(err)
+	}()
+
+	<-c
 }
go/glide.lock

-hash: a8faea3a363468a88917ddeb3b1c9ea36886fb2c622acbad42604fa9cb4d3855
-updated: 2017-07-11T10:04:40.786745417+08:00
+hash: 2a1c0eca5c07a130e3d224f9821f96cfa37a39bf6bce141c855bbc57ef569f1c
+updated: 2017-07-29T07:34:48.722757905+08:00
 imports:
+- name: github.com/beorn7/perks
+  version: 4c0e84591b9aa9e6dcfdf3e020114cd81f89d5f9
+  subpackages:
+  - quantile
+- name: github.com/boltdb/bolt
+  version: 583e8937c61f1af6513608ccc75c97b6abdf4ff9
+- name: github.com/cockroachdb/cmux
+  version: 112f0506e7743d64a6eb8fedbcff13d9979bbf92
 - name: github.com/coreos/etcd
-  version: cb2a496c4ddd1c87a9f280e116649b599999ec79
+  version: c31bec0f29facff13f7c3e3d948e55dd6689ed42
   subpackages:
+  - alarm
+  - auth
+  - auth/authpb
+  - client
   - clientv3
   - clientv3/concurrency
+  - compactor
+  - discovery
+  - embed
+  - error
+  - etcdserver
+  - etcdserver/api
+  - etcdserver/api/v2http
+  - etcdserver/api/v2http/httptypes
+  - etcdserver/api/v3client
+  - etcdserver/api/v3election
+  - etcdserver/api/v3election/v3electionpb
+  - etcdserver/api/v3election/v3electionpb/gw
+  - etcdserver/api/v3lock
+  - etcdserver/api/v3lock/v3lockpb
+  - etcdserver/api/v3lock/v3lockpb/gw
+  - etcdserver/api/v3rpc
+  - etcdserver/api/v3rpc/rpctypes
+  - etcdserver/auth
+  - etcdserver/etcdserverpb
+  - etcdserver/etcdserverpb/gw
+  - etcdserver/membership
+  - etcdserver/stats
+  - lease
+  - lease/leasehttp
+  - lease/leasepb
+  - mvcc
+  - mvcc/backend
+  - mvcc/mvccpb
+  - pkg/adt
+  - pkg/contention
+  - pkg/cors
+  - pkg/cpuutil
+  - pkg/crc
+  - pkg/debugutil
+  - pkg/fileutil
+  - pkg/httputil
+  - pkg/idutil
+  - pkg/ioutil
+  - pkg/logutil
+  - pkg/monotime
+  - pkg/netutil
+  - pkg/pathutil
+  - pkg/pbutil
+  - pkg/runtime
+  - pkg/schedule
+  - pkg/srv
+  - pkg/tlsutil
+  - pkg/transport
+  - pkg/types
+  - pkg/wait
+  - proxy/grpcproxy/adapter
+  - raft
+  - raft/raftpb
+  - rafthttp
+  - snap
+  - snap/snappb
+  - store
+  - version
+  - wal
+  - wal/walpb
+- name: github.com/coreos/go-semver
+  version: 8ab6407b697782a06568d4b7f1db25550ec2e4c6
+  subpackages:
+  - semver
+- name: github.com/coreos/go-systemd
+  version: 48702e0da86bd25e76cfef347e2adeb434a0d0a6
+  subpackages:
+  - daemon
+  - journal
+  - util
+- name: github.com/coreos/pkg
+  version: 3ac0863d7acf3bc44daf49afef8919af12f704ef
+  subpackages:
+  - capnslog
+- name: github.com/dgrijalva/jwt-go
+  version: d2709f9f1f31ebcda9651b03077758c1f3a0018c
+- name: github.com/ghodss/yaml
+  version: 0ca9ea5df5451ffdf184b4428c902747c2c11cd7
+- name: github.com/gogo/protobuf
+  version: 909568be09de550ed094403c2bf8a261b5bb730a
+  subpackages:
+  - proto
 - name: github.com/golang/protobuf
   version: 4bd1920723d7b7c925de087aa32e2187708897f7
   subpackages:

@@ -17,14 +107,61 @@ imports:
   - proto
 - name: github.com/golang/snappy
   version: 553a641470496b2327abcac10b36396bd98e45c9
+- name: github.com/google/btree
+  version: 925471ac9e2131377a91e1595defec898166fe49
+- name: github.com/grpc-ecosystem/go-grpc-prometheus
+  version: 6b7015e65d366bf3f19b2b2a000a831940f0f7e0
+- name: github.com/grpc-ecosystem/grpc-gateway
+  version: 18d159699f2e83fc5bb9ef2f79465ca3f3122676
+  subpackages:
+  - runtime
+  - runtime/internal
+  - utilities
+- name: github.com/jonboulle/clockwork
+  version: 2eee05ed794112d45db504eb05aa693efd2b8b09
+- name: github.com/matttproud/golang_protobuf_extensions
+  version: c12348ce28de40eed0136aa2b644d0ee0650e56c
+  subpackages:
+  - pbutil
 - name: github.com/namsral/flag
   version: 71ceffbeb0ba60fccc853971bb3ed4d7d90bfd04
 - name: github.com/PaddlePaddle/recordio
-  version: edfb82af0739c84f241c87390ec5649c7b28c129
+  version: 0432dee9fd4b24fb6840fb20a8c055b0c933fb81
+- name: github.com/prometheus/client_golang
+  version: c5b7fccd204277076155f10851dad72b76a49317
+  subpackages:
+  - prometheus
+- name: github.com/prometheus/client_model
+  version: 6f3806018612930941127f2a7c6c453ba2c527d2
+  subpackages:
+  - go
+- name: github.com/prometheus/common
+  version: 49fee292b27bfff7f354ee0f64e1bc4850462edf
+  subpackages:
+  - expfmt
+  - internal/bitbucket.org/ww/goautoneg
+  - model
+- name: github.com/prometheus/procfs
+  version: a1dba9ce8baed984a2495b658c82687f8157b98f
+  subpackages:
+  - xfs
 - name: github.com/sirupsen/logrus
-  version: 7f976d3a76720c4c27af2ba716b85d2e0a7e38b1
+  version: a3f95b5c423586578a4e099b11a46c2479628cac
 - name: github.com/topicai/candy
   version: 1b9030d056fa9f8c4b1f9c91b52fe4b8ab4cd8cc
+- name: github.com/ugorji/go
+  version: ded73eae5db7e7a0ef6f55aace87a2873c5d2b74
+  subpackages:
+  - codec
+- name: github.com/xiang90/probing
+  version: 07dd2e8dfe18522e9c447ba95f2fe95262f63bb2
+- name: golang.org/x/crypto
+  version: 1351f936d976c60a0a48d728281922cf63eafb8d
+  repo: https://github.com/golang/crypto.git
+  vcs: git
+  subpackages:
+  - bcrypt
+  - blowfish
 - name: golang.org/x/net
   version: c8c74377599bd978aee1cf3b9b63a8634051cec2
   subpackages:

@@ -36,11 +173,15 @@ imports:
   - lex/httplex
   - trace
 - name: golang.org/x/sys
-  version: abf9c25f54453410d0c6668e519582a9e1115027
+  version: 0f826bdd13b500be0f1d4004938ad978fcc6031e
+  repo: https://github.com/golang/sys.git
+  vcs: git
   subpackages:
   - unix
 - name: golang.org/x/text
-  version: cfdf022e86b4ecfb646e1efbd7db175dd623a8fa
+  version: 836efe42bb4aa16aaa17b9c155d8813d336ed720
+  repo: https://github.com/golang/text.git
+  vcs: git
   subpackages:
   - secure/bidirule
   - transform

@@ -60,4 +201,23 @@ imports:
   - stats
   - tap
   - transport
-testImports: []
+- name: gopkg.in/yaml.v2
+  version: cd8b52f8269e0feb286dfeef29f8fe4d5b397e0b
+testImports:
+- name: github.com/davecgh/go-spew
+  version: 04cdfd42973bb9c8589fd6a731800cf222fde1a9
+  subpackages:
+  - spew
+- name: github.com/docker/docker
+  version: b6d164e6c46d8115b146e4c3ac93784e9ef8b49e
+  subpackages:
+  - pkg/ioutils
+  - pkg/longpath
+- name: github.com/pmezard/go-difflib
+  version: d8ed2627bdf02c080bf22230dbb337003b7aba2d
+  subpackages:
+  - difflib
+- name: github.com/stretchr/testify
+  version: 05e8a0eda380579888eb53c394909df027f06991
+  subpackages:
+  - assert
go/glide.yaml

@@ -6,8 +6,19 @@ import:
   subpackages:
   - clientv3
   - clientv3/concurrency
+  - embed
+  - etcdserver
 - package: github.com/namsral/flag
   version: ^1.7.4-pre
 - package: github.com/sirupsen/logrus
   version: ^1.0.0
 - package: github.com/topicai/candy
+- package: golang.org/x/crypto
+  vcs: git
+  repo: https://github.com/golang/crypto.git
+- package: golang.org/x/sys
+  vcs: git
+  repo: https://github.com/golang/sys.git
+- package: golang.org/x/text
+  vcs: git
+  repo: https://github.com/golang/text.git
go/master/etcd_client.go

@@ -39,15 +39,12 @@ type EtcdClient struct {
 	statePath string
 	client    *clientv3.Client
 	lock      *concurrency.Mutex
+	sess      *concurrency.Session
 }

 // NewEtcdClient creates a new EtcdClient.
 func NewEtcdClient(endpoints []string, addr string, lockPath, addrPath, statePath string, ttlSec int) (*EtcdClient, error) {
 	log.Debugf("Connecting to etcd at %v", endpoints)
-	// TODO(helin): gracefully shutdown etcd store. Because etcd
-	// store holds a etcd lock, even though the lock will expire
-	// when the lease timeout, we need to implement graceful
-	// shutdown to release the lock.
 	cli, err := clientv3.New(clientv3.Config{
 		Endpoints:   endpoints,
 		DialTimeout: dialTimeout,

@@ -67,12 +64,12 @@ func NewEtcdClient(endpoints []string, addr string, lockPath, addrPath, statePat
 	// one master running, but split-brain problem may cause
 	// multiple master servers running), and the cluster management
 	// software will kill one of them.
-	log.Debugf("Trying to acquire lock at %s.", lockPath)
+	log.Infof("Trying to acquire lock at %s.", lockPath)
 	err = lock.Lock(context.TODO())
 	if err != nil {
 		return nil, err
 	}
-	log.Debugf("Successfully acquired lock at %s.", lockPath)
+	log.Infof("Successfully acquired lock at %s.", lockPath)

 	put := clientv3.OpPut(addrPath, addr)
 	resp, err := cli.Txn(context.Background()).If(lock.IsOwner()).Then(put).Commit()

@@ -89,6 +86,7 @@ func NewEtcdClient(endpoints []string, addr string, lockPath, addrPath, statePat
 		statePath: statePath,
 		client:    cli,
 		lock:      lock,
+		sess:      sess,
 	}

 	return e, nil

@@ -157,6 +155,21 @@ func (e *EtcdClient) Load() ([]byte, error) {
 	return state, nil
 }

+// Shutdown shuts down the etcd client gracefully.
+func (e *EtcdClient) Shutdown() error {
+	err := e.sess.Close()
+	newErr := e.client.Close()
+	if newErr != nil {
+		if err == nil {
+			err = newErr
+		} else {
+			log.Errorln(newErr)
+		}
+	}
+
+	return err
+}
+
 // GetKey gets the value by the specify key.
 func GetKey(c *clientv3.Client, key string, timeout time.Duration) (string, error) {
 	ctx, cancel := context.WithTimeout(context.Background(), timeout)
go/master/inmem_store.go

@@ -40,3 +40,8 @@ func (m *InMemStore) Load() ([]byte, error) {
 	return m.buf, nil
 }
+
+// Shutdown shuts down the in mem store.
+func (m *InMemStore) Shutdown() error {
+	return nil
+}
go/master/service.go

@@ -50,6 +50,7 @@ var ErrPassAfter = errors.New("pass number larger than master")
 type Store interface {
 	Save([]byte) error
 	Load() ([]byte, error)
+	Shutdown() error
 }

 // Chunk is a chunk of data consisted of several data instances.
go/master/service_test.go  (new file, mode 0 → 100644)

package master_test

import (
	"os"
	"testing"
	"time"

	"github.com/PaddlePaddle/Paddle/go/master"
	"github.com/coreos/etcd/clientv3"
	"github.com/coreos/etcd/embed"
	"github.com/docker/docker/pkg/ioutils"
	"github.com/stretchr/testify/assert"
)

func TestNewServiceWithEtcd(t *testing.T) {
	// setup an embed etcd server
	etcdDir, err := ioutils.TempDir("", "")
	if err != nil {
		t.Fatal(err)
	}
	cfg := embed.NewConfig()
	cfg.Dir = etcdDir
	e, err := embed.StartEtcd(cfg)
	if err != nil {
		t.Fatal(err)
	}
	defer func() {
		e.Close()
		if err := os.RemoveAll(etcdDir); err != nil {
			t.Fatal(err)
		}
	}()
	select {
	case <-e.Server.ReadyNotify():
		t.Log("Server is ready!")
	case <-time.After(60 * time.Second):
		e.Server.Stop() // trigger a shutdown
		t.Fatal("Server took too long to start!")
	}

	ep := []string{"127.0.0.1:2379"}
	masterAddr := "127.0.0.1:3306"
	store, err := master.NewEtcdClient(ep, masterAddr, master.DefaultLockPath, master.DefaultAddrPath, master.DefaultStatePath, 30)
	if err != nil {
		t.Fatal(err)
	}

	_, err = master.NewService(store, 10, 10, 3)
	if err != nil {
		t.Fatal(err)
	}
	cli, err := clientv3.New(clientv3.Config{
		Endpoints:   ep,
		DialTimeout: 3 * time.Second,
	})
	if err != nil {
		t.Fatal(err)
	}
	v, err := master.GetKey(cli, master.DefaultAddrPath, 3*time.Second)
	if err != nil {
		t.Fatal(err)
	}
	if err := cli.Close(); err != nil {
		t.Fatal(err)
	}
	// test master process registry itself into etcd server.
	assert.Equal(t, masterAddr, v, "master process should registry itself into etcd server.")
}
go/pserver/client/c/cclient.go

@@ -55,10 +55,10 @@ var curHandle C.paddle_pserver_client
 func add(c *client.Client) C.paddle_pserver_client {
 	mu.Lock()
 	defer mu.Unlock()
-	client := curHandle
+	cli := curHandle
 	curHandle++
-	handleMap[client] = c
-	return client
+	handleMap[cli] = c
+	return cli
 }

 func get(client C.paddle_pserver_client) *client.Client {
go/pserver/etcd_client.go

@@ -34,16 +34,19 @@ const (
 	PsPath = "/ps/"
 	// PsCheckpoint is the etcd path for store checkpoints information
 	PsCheckpoint = "/checkpoints/"
+
+	retryTimeout = 5 * time.Second
 )

 // EtcdClient is the etcd client that the pserver uses for fault
 // tolerance, service registry and coordination.
 type EtcdClient struct {
-	numPservers   int
-	etcdEndpoints string
-	etcdClient    *clientv3.Client
-	// etcdTimeout is also used as retry intervals.
-	etcdTimeout time.Duration
+	numPservers int
+	endpoints   string
+	client      *clientv3.Client
+	sess        *concurrency.Session
+	dialTimeout time.Duration
+	ttlSec      int
 	// FIXME: ensure GetExternalIP gets the correct ip for trainers to connect.
 	externalIP string
 	// desired number of pservers in the job.

@@ -52,11 +55,12 @@ type EtcdClient struct {
 }

 // NewEtcdClient creates an EtcdClient
-func NewEtcdClient(endpoints string, numPservers int, timeout time.Duration) *EtcdClient {
+func NewEtcdClient(endpoints string, numPservers int, dialtimeout time.Duration, ttlSec int) *EtcdClient {
 	return &EtcdClient{
-		etcdTimeout:   timeout,
-		numPservers:   numPservers,
-		etcdEndpoints: endpoints,
+		dialTimeout: dialtimeout,
+		ttlSec:      ttlSec,
+		numPservers: numPservers,
+		endpoints:   endpoints,
 	}
 }

@@ -64,7 +68,6 @@ func NewEtcdClient(endpoints string, numPservers int, timeout time.Duration) *Et
 //
 // Register returns the index of the current pserver.
 func (e *EtcdClient) Register(port int) (int, error) {
-
 	var err error
 	e.externalIP, err = networkhelper.GetExternalIP()
 	if err != nil {

@@ -72,19 +75,26 @@ func (e *EtcdClient) Register(port int) (int, error) {
 	}

 	// initialize connection to etcd.
-	ep := strings.Split(e.etcdEndpoints, ",")
+	ep := strings.Split(e.endpoints, ",")
 	for {
 		cli, err := clientv3.New(clientv3.Config{
 			Endpoints:   ep,
-			DialTimeout: e.etcdTimeout,
+			DialTimeout: e.dialTimeout,
 		})
 		if err != nil {
 			log.Errorf("connect to etcd error: %v", err)
-			time.Sleep(e.etcdTimeout)
+			time.Sleep(retryTimeout)
 			continue
 		}
-		e.etcdClient = cli
-		log.Debugf("inited client to %s", e.etcdEndpoints)
+		e.client = cli
+		sess, err := concurrency.NewSession(cli, concurrency.WithTTL(e.ttlSec))
+		if err != nil {
+			log.Errorf("create etcd session error: %v", err)
+			time.Sleep(retryTimeout)
+			continue
+		}
+		e.sess = sess
+		log.Debugf("inited client to %s", e.endpoints)
 		break
 	}
 	// init /ps_desired using transaction, for multiple pservers may want to write

@@ -95,7 +105,7 @@ func (e *EtcdClient) Register(port int) (int, error) {
 		cancel()
 		if err != nil {
 			log.Warn(err)
-			time.Sleep(e.etcdTimeout)
+			time.Sleep(retryTimeout)
 			continue
 		}
 		break

@@ -106,18 +116,18 @@ func (e *EtcdClient) Register(port int) (int, error) {
 	// wait and set s.desired init value
 	for {
 		ctx, cancel := context.WithTimeout(context.Background(), time.Second)
-		resp, err := e.etcdClient.Get(ctx, PsDesired)
+		resp, err := e.client.Get(ctx, PsDesired)
 		cancel()
 		if err != nil {
 			log.Errorf("getting %s error: %v", PsDesired, err)
-			time.Sleep(e.etcdTimeout)
+			time.Sleep(retryTimeout)
 			continue
 		}
 		if len(resp.Kvs) != 0 {
 			e.desired, err = strconv.Atoi(string(resp.Kvs[0].Value))
 			if err != nil {
 				log.Errorf("value of %s invalid %v\n", PsDesired, err)
-				time.Sleep(e.etcdTimeout)
+				time.Sleep(retryTimeout)
 				// NOTE: wait util ps_desired value change
 				continue
 			}

@@ -134,7 +144,7 @@ func (e *EtcdClient) Register(port int) (int, error) {
 		cancel()
 		if err != nil {
 			log.Warn(err)
-			time.Sleep(e.etcdTimeout)
+			time.Sleep(retryTimeout)
 			continue
 		}
 		break

@@ -144,10 +154,10 @@ func (e *EtcdClient) Register(port int) (int, error) {
 }

 func (e *EtcdClient) initDesiredPservers(ctx context.Context, numPservers int) (*clientv3.TxnResponse, error) {
-	return concurrency.NewSTM(e.etcdClient, func(c concurrency.STM) error {
+	return concurrency.NewSTM(e.client, func(c concurrency.STM) error {
 		dsStr := c.Get(PsDesired)
 		if dsStr == "" {
-			c.Put(PsDesired, strconv.Itoa(numPservers))
+			c.Put(PsDesired, strconv.Itoa(numPservers), clientv3.WithLease(e.sess.Lease()))
 		}
 		return nil
 	}, concurrency.WithAbortContext(ctx), concurrency.WithIsolation(concurrency.RepeatableReads))

@@ -156,7 +166,7 @@ func (e *EtcdClient) initDesiredPservers(ctx context.Context, numPservers int) (
 // registerPserverEtcd registers pserver node on etcd using transaction.
 func (e *EtcdClient) registerPserverEtcd(ctx context.Context, port int) (int, error) {
 	var idx int
-	_, err := concurrency.NewSTM(e.etcdClient, func(c concurrency.STM) error {
+	_, err := concurrency.NewSTM(e.client, func(c concurrency.STM) error {
 		registered := false
 		for i := 0; i < e.desired; i++ {
 			psKey := PsPath + strconv.Itoa(i)

@@ -165,26 +175,10 @@ func (e *EtcdClient) registerPserverEtcd(ctx context.Context, port int) (int, er
 			log.Debugf("got value (%s) for key: %s", ps, psKey)

 			if ps == "" {
-				resp, err := e.etcdClient.Grant(context.TODO(), 5)
-				if err != nil {
-					log.Fatal(err)
-				}
 				// find the first id and write info
 				pserverAddr := e.externalIP + ":" + strconv.Itoa(port)
-				c.Put(psKey, pserverAddr, clientv3.WithLease(resp.ID))
+				c.Put(psKey, pserverAddr, clientv3.WithLease(e.sess.Lease()))
 				log.Debugf("set pserver node %s with value %s", psKey, pserverAddr)
-				ch, kaerr := e.etcdClient.KeepAlive(context.TODO(), resp.ID)
-				if kaerr != nil {
-					log.Errorf("keepalive etcd node error: %v", kaerr)
-					return kaerr
-				}
-
-				// Eat the keep alive message so etcd
-				// will not expire the lease.
-				go func(ch <-chan *clientv3.LeaseKeepAliveResponse) {
-					ka := <-ch
-					log.Debugf("keepalive: %d\n", ka.TTL)
-				}(ch)
 				log.Debug("register finished")
 				idx = i
 				registered = true

@@ -207,7 +201,7 @@ func (e *EtcdClient) registerPserverEtcd(ctx context.Context, port int) (int, er
 // GetKey gets the value by the specified key
 func (e *EtcdClient) GetKey(key string, timeout time.Duration) ([]byte, error) {
 	ctx, cancel := context.WithTimeout(context.Background(), timeout)
-	resp, err := e.etcdClient.Get(ctx, key)
+	resp, err := e.client.Get(ctx, key)
 	cancel()
 	if err != nil {
 		return []byte{}, err

@@ -223,7 +217,27 @@ func (e *EtcdClient) GetKey(key string, timeout time.Duration) ([]byte, error) {
 // PutKey put into etcd with value by key specified
 func (e *EtcdClient) PutKey(key string, value []byte, timeout time.Duration) error {
 	ctx, cancel := context.WithTimeout(context.Background(), timeout)
-	_, err := e.etcdClient.Put(ctx, key, string(value))
+	_, err := e.client.Put(ctx, key, string(value), clientv3.WithLease(e.sess.Lease()))
 	cancel()
 	return err
 }
+
+// Shutdown shuts down the etcd client gracefully.
+func (e *EtcdClient) Shutdown() error {
+	var err error
+	if e.sess != nil {
+		err = e.sess.Close()
+	}
+
+	if e.client != nil {
+		newErr := e.client.Close()
+		if newErr != nil {
+			if err != nil {
+				log.Errorln(newErr)
+			} else {
+				err = newErr
+			}
+		}
+	}
+	return err
+}
paddle/cuda/src/hl_cuda_sequence.cu

@@ -269,8 +269,7 @@ void hl_sequence2batch_copy_padding(real* batch,
   int blockDimY = CUDA_BLOCK_SIZE / blockDimX;
   dim3 threads(blockDimX, blockDimY);

-  int gridDimX = (maxSequenceLength * blockDimX + CUDA_BLOCK_SIZE - 1) /
-                 CUDA_BLOCK_SIZE;
+  int gridDimX = (maxSequenceLength + blockDimY - 1) / blockDimY;
   int gridDimY = numSequences;
   dim3 grid(gridDimX, gridDimY);
paddle/framework/detail/tensor-inl.h

@@ -83,56 +83,38 @@ inline void Tensor::ShareDataWith(const Tensor& src) {
 template <typename T>
 inline void Tensor::CopyFrom(const Tensor& src,
-                             const platform::CPUDeviceContext& ctx) {
+                             const platform::Place& dst_place) {
   src.check_memory_size<T>();
   Resize(src.dims());

   auto src_place = src.holder_->place();
   auto src_ptr = static_cast<const void*>(src.data<T>());

-  auto dst_place = ctx.GetPlace();
   auto dst_ptr = static_cast<void*>(mutable_data<T>(dst_place));

   auto size = product(src.dims_) * sizeof(T);

-  if (platform::is_cpu_place(src_place)) {
+  if (platform::is_cpu_place(src_place) && platform::is_cpu_place(dst_place)) {
     memory::Copy(boost::get<platform::CPUPlace>(dst_place), dst_ptr,
                  boost::get<platform::CPUPlace>(src_place), src_ptr, size);
   }
 #ifndef PADDLE_ONLY_CPU
-  else if (platform::is_gpu_place(src_place)) {
+  else if (platform::is_gpu_place(src_place) &&
+           platform::is_cpu_place(dst_place)) {
     memory::Copy(boost::get<platform::CPUPlace>(dst_place), dst_ptr,
                  boost::get<platform::GPUPlace>(src_place), src_ptr, size, 0);
-  }
-#endif
-}
-
-#ifndef PADDLE_ONLY_CPU
-template <typename T>
-inline void Tensor::CopyFrom(const Tensor& src,
-                             const platform::CUDADeviceContext& ctx) {
-  src.check_memory_size<T>();
-  Resize(src.dims());
-
-  auto src_place = src.holder_->place();
-  auto src_ptr = static_cast<const void*>(src.data<T>());
-
-  auto dst_place = ctx.GetPlace();
-  auto dst_ptr = static_cast<void*>(mutable_data<T>(dst_place));
-
-  auto size = product(src.dims_) * sizeof(T);
-
-  if (platform::is_cpu_place(src_place)) {
+  } else if (platform::is_cpu_place(src_place) &&
+             platform::is_gpu_place(dst_place)) {
     memory::Copy(boost::get<platform::GPUPlace>(dst_place), dst_ptr,
-                 boost::get<platform::CPUPlace>(src_place), src_ptr, size,
-                 ctx.stream());
-  } else if (platform::is_gpu_place(src_place)) {
+                 boost::get<platform::CPUPlace>(src_place), src_ptr, size, 0);
+  } else if (platform::is_gpu_place(src_place) &&
+             platform::is_gpu_place(dst_place)) {
     memory::Copy(boost::get<platform::GPUPlace>(dst_place), dst_ptr,
-                 boost::get<platform::GPUPlace>(src_place), src_ptr, size,
-                 ctx.stream());
+                 boost::get<platform::GPUPlace>(src_place), src_ptr, size, 0);
   }
 #endif
 }

 template <typename T>
 inline Tensor Tensor::Slice(const int& begin_idx, const int& end_idx) const {
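The two DeviceContext-specific overloads are thus folded into one Place-driven CopyFrom, with the copy path (CPU<->CPU, CPU<->GPU, GPU<->GPU) selected from the source and destination places. A minimal usage sketch against the new signature (the tensor names and shape are illustrative, not part of this commit):

    // Sketch: same-device copy between two CPU tensors via the unified CopyFrom.
    paddle::framework::Tensor src, dst;
    src.mutable_data<float>(paddle::framework::make_ddim({2, 3}),
                            paddle::platform::CPUPlace());
    dst.CopyFrom<float>(src, paddle::platform::CPUPlace());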
paddle/framework/net_op_test.cc

@@ -11,8 +11,7 @@ static int run_cnt = 0;
 class TestOp : public OperatorBase {
  public:
-  void InferShape(
-      const std::shared_ptr<framework::Scope>& scope) const override {
+  void InferShape(const std::shared_ptr<Scope>& scope) const override {
     ++infer_shape_cnt;
   }
   void Run(const std::shared_ptr<framework::Scope>& scope,
paddle/framework/operator.cc

@@ -20,7 +20,7 @@ namespace paddle {
 namespace framework {

 template <>
-Eigen::DefaultDevice* KernelContext::GetEigenDevice<
+Eigen::DefaultDevice* ExecutionContext::GetEigenDevice<
     platform::CPUPlace, Eigen::DefaultDevice>() const {
   return device_context_.get_eigen_device<Eigen::DefaultDevice>();
 }

@@ -28,7 +28,7 @@ Eigen::DefaultDevice* KernelContext::GetEigenDevice<
 #ifndef PADDLE_ONLY_CPU
 template <>
 Eigen::GpuDevice*
-KernelContext::GetEigenDevice<platform::GPUPlace, Eigen::GpuDevice>() const {
+ExecutionContext::GetEigenDevice<platform::GPUPlace, Eigen::GpuDevice>() const {
   return device_context_.get_eigen_device<Eigen::GpuDevice>();
 }
 #endif
paddle/framework/operator.h

@@ -32,22 +32,9 @@ limitations under the License. */
 namespace paddle {
 namespace framework {

-template <typename T>
-struct EigenDeviceConverter;
-
-template <>
-struct EigenDeviceConverter<platform::CPUPlace> {
-  using EigenDeviceType = Eigen::DefaultDevice;
-};
-
-#ifndef PADDLE_ONLY_CPU
-template <>
-struct EigenDeviceConverter<platform::GPUPlace> {
-  using EigenDeviceType = Eigen::GpuDevice;
-};
-#endif
-
 class OperatorBase;
+class InferShapeContext;
+class ExecutionContext;

 /**
  * OperatorBase has the basic element that Net will call to do computation.
  * Only CreateOperator from OpRegistry will new Operator directly. User

@@ -126,46 +113,127 @@ class OperatorBase {
   std::shared_ptr<std::unordered_map<std::string, int>> in_out_idxs_;
 };

-class KernelContext {
+class OperatorContext {
  public:
-  KernelContext(const OperatorBase* op, const std::shared_ptr<Scope>& scope,
-                const platform::DeviceContext& device_context)
-      : op_(*op), scope_(scope), device_context_(device_context) {}
-
-  const Variable* Input(int index) const {
-    return scope_->GetVariable(op_.inputs_[index]);
-  }
-
-  Variable* Output(int index) const {
-    return scope_->GetVariable(op_.outputs_[index]);
+  OperatorContext(const OperatorBase* op, const std::shared_ptr<Scope>& scope)
+      : op_(*op), scope_(scope) {}
+
+  size_t InputSize() const { return op_.inputs_.size(); }
+
+  size_t OutputSize() const { return op_.outputs_.size(); }
+
+  const Variable* InputVar(const size_t& index) const {
+    return scope_->GetVariable(op_.inputs_.at(index));
+  }
+
+  Variable* OutputVar(const size_t& index) const {
+    return scope_->GetVariable(op_.outputs_.at(index));
   }

-  const Variable* Input(const std::string& name) const {
+  const Variable* InputVar(const std::string& name) const {
     return scope_->GetVariable(op_.Input(name));
   }

-  const Variable* Output(const std::string& name) const {
+  Variable* OutputVar(const std::string& name) const {
     return scope_->GetVariable(op_.Output(name));
   }

-  const std::vector<const Variable*> Inputs(const std::string& name) const {
+  const std::vector<const Variable*> MultiInputVar(
+      const std::string& name) const {
     auto names = op_.Inputs(name);
     std::vector<const Variable*> res;
     res.reserve(names.size());
     std::transform(
-        names.begin(), names.end(), res.begin(),
+        names.begin(), names.end(), std::back_inserter(res),
         [this](const std::string& name) { return scope_->GetVariable(name); });
     return res;
   }

-  const std::vector<const Variable*> Outputs(const std::string& name) const {
+  std::vector<const Variable*> MultiOutputVar(const std::string& name) const {
     auto names = op_.Outputs(name);
     std::vector<const Variable*> res;
     res.reserve(names.size());
     std::transform(
-        names.begin(), names.end(), res.begin(),
+        names.begin(), names.end(), std::back_inserter(res),
         [this](const std::string& name) { return scope_->GetVariable(name); });
     return res;
   }

+  template <typename T>
+  const T* Input(const size_t& index) const {
+    return &(InputVar(index)->Get<T>());
+  }
+
+  template <typename T>
+  T* Output(const size_t& index) const {
+    return OutputVar(index)->GetMutable<T>();
+  }
+
+  template <typename T>
+  const T* Input(const std::string& name) const {
+    return &(InputVar(name)->Get<T>());
+  }
+
+  template <typename T>
+  T* Output(const std::string& name) const {
+    return OutputVar(name)->GetMutable<T>();
+  }
+
+  template <typename T>
+  const std::vector<const T*> MultiInput(const std::string& name) const {
+    auto names = op_.Inputs(name);
+    std::vector<const T*> res;
+    res.reserve(names.size());
+    std::transform(names.begin(), names.end(), std::back_inserter(res),
+                   [this](const std::string& name) {
+                     return &scope_->GetVariable(name)->Get<T>();
+                   });
+    return res;
+  }
+
+  template <typename T>
+  std::vector<const T*> MultiOutput(const std::string& name) const {
+    auto names = op_.Outputs(name);
+    std::vector<const T*> res;
+    res.reserve(names.size());
+    std::transform(names.begin(), names.end(), std::back_inserter(res),
+                   [this](const std::string& name) {
+                     return scope_->GetVariable(name)->GetMutable<T>();
+                   });
+    return res;
+  }
+
   const OperatorBase& op_;
   const std::shared_ptr<Scope>& scope_;
+};
+
+class InferShapeContext : public OperatorContext {
+ public:
+  InferShapeContext(const OperatorBase* op, const std::shared_ptr<Scope>& scope)
+      : OperatorContext(op, scope) {}
+};
+
+template <typename T>
+struct EigenDeviceConverter;
+
+template <>
+struct EigenDeviceConverter<platform::CPUPlace> {
+  using EigenDeviceType = Eigen::DefaultDevice;
+};
+
+#ifndef PADDLE_ONLY_CPU
+template <>
+struct EigenDeviceConverter<platform::GPUPlace> {
+  using EigenDeviceType = Eigen::GpuDevice;
+};
+#endif
+
+class ExecutionContext : public OperatorContext {
+ public:
+  ExecutionContext(const OperatorBase* op, const std::shared_ptr<Scope>& scope,
+                   const platform::DeviceContext& device_context)
+      : OperatorContext(op, scope), device_context_(device_context) {}

   template <typename PlaceType,
             typename DeviceType = typename EigenDeviceConverter<
                 PlaceType>::EigenDeviceType>

@@ -173,38 +241,23 @@ class KernelContext {
   platform::Place GetPlace() const { return device_context_.GetPlace(); }

-  const OperatorBase& op_;
-  const std::shared_ptr<Scope>& scope_;
   const platform::DeviceContext& device_context_;
 };

 class OpKernel {
  public:
   /**
-   * KernelContext is the only parameter of Kernel Run function.
+   * ExecutionContext is the only parameter of Kernel Run function.
    * Run will get input/output variables, state such as momentum and
    * device resource such as CUDA stream, cublas handle, etc. from
-   * KernelContext. User should construct it before run the Operator.
+   * ExecutionContext. User should construct it before run the Operator.
    */

-  virtual void Compute(const KernelContext& context) const = 0;
+  virtual void Compute(const ExecutionContext& context) const = 0;

   virtual ~OpKernel() {}
 };

-template <typename T>
-struct VarToTensor {};
-
-template <>
-struct VarToTensor<Tensor*> {
-  Tensor* operator()(Variable* var) { return var->GetMutable<Tensor>(); }
-};
-
-template <>
-struct VarToTensor<const Tensor*> {
-  const Tensor* operator()(Variable* var) { return &var->Get<Tensor>(); }
-};
-
 class OperatorWithKernel : public OperatorBase {
  public:
   struct OpKernelKey {

@@ -230,10 +283,14 @@ class OperatorWithKernel : public OperatorBase {
   using OpKernelMap =
       std::unordered_map<OpKernelKey, std::unique_ptr<OpKernel>, OpKernelHash>;

+  void InferShape(const std::shared_ptr<Scope>& scope) const {
+    InferShape(InferShapeContext(this, scope));
+  }
+
   void Run(const std::shared_ptr<Scope>& scope,
            const platform::DeviceContext& dev_ctx) const final {
     auto& opKernel = AllOpKernels().at(type_).at(OpKernelKey(dev_ctx));
-    opKernel->Compute(KernelContext(this, scope, dev_ctx));
+    opKernel->Compute(ExecutionContext(this, scope, dev_ctx));
   }

   static std::unordered_map<std::string /* op_type */, OpKernelMap>&

@@ -242,34 +299,8 @@ class OperatorWithKernel : public OperatorBase {
     return g_all_op_kernels;
   }

-  void InferShape(const std::shared_ptr<Scope>& scope) const final {
-    std::vector<const Tensor*> ins;
-    VarNamesToTensors(scope, inputs_, &ins);
-    std::vector<Tensor*> outs;
-    VarNamesToTensors(scope, outputs_, &outs);
-    InferShape(ins, outs);
-  };
-
- private:
-  template <typename T>
-  void VarNamesToTensors(const std::shared_ptr<Scope>& scope,
-                         const std::vector<std::string>& var_names,
-                         std::vector<T>* container) const {
-    container->reserve(var_names.size());
-    VarToTensor<T> convert;
-    for (auto& name : var_names) {
-      auto var = scope->GetVariable(name);
-      if (var != nullptr) {
-        container->push_back(convert(var));
-      } else {
-        container->push_back(nullptr);
-      }
-    }
-  }
-
  protected:
-  virtual void InferShape(const std::vector<const Tensor*>& inputs,
-                          const std::vector<Tensor*>& outputs) const = 0;
+  virtual void InferShape(const InferShapeContext& ctx) const = 0;
 };

 }  // namespace framework
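The net effect of this refactor: InferShape now runs against an InferShapeContext (no device required), while kernels receive an ExecutionContext with typed accessors and an Eigen device. A minimal kernel sketch against the new interface (ScaleKernel and its float-only body are hypothetical, not from this commit):

    // Sketch: a kernel written against ExecutionContext's typed accessors.
    using paddle::framework::ExecutionContext;
    using paddle::framework::OpKernel;
    using paddle::framework::Tensor;

    class ScaleKernel : public OpKernel {
     public:
      void Compute(const ExecutionContext& ctx) const override {
        auto* in = ctx.Input<Tensor>(0);    // was ctx.Input(0)->Get<Tensor>()
        auto* out = ctx.Output<Tensor>(0);  // was ctx.Output(0)->GetMutable<Tensor>()
        out->mutable_data<float>(ctx.GetPlace());
        // ... element-wise math, e.g. via ctx.GetEigenDevice<PlaceType>() ...
      }
    };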
paddle/framework/operator_test.cc

@@ -24,7 +24,8 @@ static int op_run_num = 0;
 class OpWithoutKernelTest : public OperatorBase {
  public:
   void Init() override { x = 1; }
-  void InferShape(const std::shared_ptr<Scope>& scope) const override {}
+  void InferShape(
+      const std::shared_ptr<framework::Scope>& scope) const override {}
   void Run(const std::shared_ptr<Scope>& scope,
            const platform::DeviceContext& dev_ctx) const override {
     op_run_num++;

@@ -73,6 +74,7 @@ TEST(OperatorBase, all) {
   auto op = paddle::framework::OpRegistry::CreateOp(op_desc);
   scope->CreateVariable("OUT1");
   ASSERT_EQ(paddle::framework::op_run_num, 0);
+  op->InferShape(scope);
   op->Run(scope, device_context);
   ASSERT_EQ(paddle::framework::op_run_num, 1);
 }

@@ -97,14 +99,13 @@ static int cpu_kernel_run_num = 0;
 class OpWithKernelTest : public OperatorWithKernel {
  protected:
-  void InferShape(const std::vector<const Tensor*>& inputs,
-                  const std::vector<Tensor*>& outputs) const override {}
+  void InferShape(const framework::InferShapeContext& ctx) const override {}
 };

 template <typename T1, typename T2>
 class CPUKernelTest : public OpKernel {
  public:
-  void Compute(const KernelContext& ctx) const {
+  void Compute(const ExecutionContext& ctx) const {
     std::cout << "this is cpu kernel" << std::endl;
     std::cout << ctx.op_.DebugString() << std::endl;
     cpu_kernel_run_num++;

@@ -117,7 +118,8 @@ class CPUKernelTest : public OpKernel {
 class OperatorMultiInputsTest : public OperatorBase {
  public:
   void Init() override { x = 1; }
-  void InferShape(const std::shared_ptr<Scope>& scope) const override {}
+  void InferShape(
+      const std::shared_ptr<framework::Scope>& scope) const override {}
   void Run(const std::shared_ptr<Scope>& scope,
            const platform::DeviceContext& dev_ctx) const override {
     ASSERT_EQ(scope->GetVariable(inputs_[0]), nullptr);

@@ -149,13 +151,31 @@ class OpKernelTestMultiInputsProtoAndCheckerMaker
 class CPUKernalMultiInputsTest : public OpKernel {
  public:
-  void Compute(const KernelContext& ctx) const {
+  void Compute(const ExecutionContext& ctx) const {
     auto xs = ctx.op_.Inputs("xs");
     ASSERT_EQ(xs.size(), 3UL);
     ASSERT_EQ(xs[0], "x0");
     ASSERT_EQ(xs[1], "x1");
     ASSERT_EQ(xs[2], "x2");

+    auto inVar0 = ctx.MultiInputVar("xs");
+    ASSERT_EQ(inVar0.size(), 3);
+
+    auto intVar1 = ctx.InputVar("k");
+    ASSERT_NE(intVar1, nullptr);
+
+    auto outVar0 = ctx.MultiOutputVar("ys");
+    ASSERT_EQ(outVar0.size(), 2);
+
+    auto inTensor0 = ctx.MultiInput<Tensor>("xs");
+    ASSERT_EQ(inTensor0.size(), 3);
+
+    auto intTensor1 = ctx.Input<Tensor>("k");
+    ASSERT_NE(intTensor1, nullptr);
+
+    auto outTensor0 = ctx.MultiOutput<Tensor>("ys");
+    ASSERT_EQ(outTensor0.size(), 2);
+
     auto k = ctx.op_.Input("k");
     ASSERT_EQ(k, "k0");

@@ -233,6 +253,12 @@ TEST(OpKernel, multi_inputs) {
   paddle::platform::CPUDeviceContext cpu_device_context;
   auto scope = std::make_shared<Scope>();

+  scope->CreateVariable("x0")->GetMutable<Tensor>();
+  scope->CreateVariable("x1")->GetMutable<Tensor>();
+  scope->CreateVariable("x2")->GetMutable<Tensor>();
+  scope->CreateVariable("k0")->GetMutable<Tensor>();
+  scope->CreateVariable("y0")->GetMutable<Tensor>();
+  scope->CreateVariable("y1")->GetMutable<Tensor>();
+
   auto op = paddle::framework::OpRegistry::CreateOp(op_desc);
   op->Run(scope, cpu_device_context);
paddle/framework/tensor.h

@@ -94,14 +94,7 @@ class Tensor {
    * @note CopyFrom supports CPU <-> GPU, GPU <-> GPU.
    */
   template <typename T>
-  inline void CopyFrom(const Tensor& src,
-                       const platform::CPUDeviceContext& ctx);
-
-#ifndef PADDLE_ONLY_CPU
-  template <typename T>
-  inline void CopyFrom(const Tensor& src,
-                       const platform::CUDADeviceContext& ctx);
-#endif
+  inline void CopyFrom(const Tensor& src, const platform::Place& dst_place);

   /**
    * @brief Return the slice of the tensor.

@@ -129,13 +122,16 @@ class Tensor {
     virtual platform::Place place() const = 0;
   };

-  template <typename T, typename PlaceType>
+  template <typename T, typename Place>
   struct PlaceholderImpl : public Placeholder {
-    PlaceholderImpl(PlaceType place, size_t size)
+    PlaceholderImpl(Place place, size_t size)
         : ptr_(static_cast<T*>(memory::Alloc(place, size)),
-               memory::PODDeleter<T, PlaceType>(place)),
+               memory::PODDeleter<T, Place>(place)),
           place_(place),
-          size_(size) {}
+          size_(size) {
+      PADDLE_ENFORCE(ptr_ != nullptr, "Insufficient %s memory to allocation.",
+                     is_cpu_place(place_) ? "CPU" : "GPU");
+    }

     virtual size_t size() const { return size_; }
     virtual platform::Place place() const { return place_; }

@@ -143,7 +139,7 @@ class Tensor {
     virtual std::type_index type() const { return std::type_index(typeid(T)); }

     /*! the pointer of memory block. */
-    std::unique_ptr<T, memory::PODDeleter<T, PlaceType>> ptr_;
+    std::unique_ptr<T, memory::PODDeleter<T, Place>> ptr_;

     /*! the place of memory block. */
     platform::Place place_;
paddle/framework/tensor_test.cc

@@ -198,8 +198,8 @@ TEST(Tensor, CopyFrom) {
     int arr[9] = {1, 2, 3, 4, 5, 6, 7, 8, 9};
     memcpy(src_ptr, arr, 9 * sizeof(int));

-    auto* cpu_ctx = new paddle::platform::CPUDeviceContext();
-    dst_tensor.CopyFrom<int>(src_tensor, *cpu_ctx);
+    auto cpu_place = new paddle::platform::CPUPlace();
+    dst_tensor.CopyFrom<int>(src_tensor, *cpu_place);

     const int* dst_ptr = dst_tensor.data<int>();
     ASSERT_NE(src_ptr, dst_ptr);

@@ -208,7 +208,7 @@ TEST(Tensor, CopyFrom) {
     }

     Tensor slice_tensor = src_tensor.Slice<int>(1, 2);
-    dst_tensor.CopyFrom<int>(slice_tensor, *cpu_ctx);
+    dst_tensor.CopyFrom<int>(slice_tensor, *cpu_place);
     const int* slice_ptr = slice_tensor.data<int>();
     dst_ptr = dst_tensor.data<int>();
     ASSERT_NE(dst_ptr, slice_ptr);

@@ -228,12 +228,12 @@ TEST(Tensor, CopyFrom) {
     memcpy(src_ptr, arr, 9 * sizeof(int));

     // CPU Tensor to GPU Tensor
-    auto gpu_ctx = new paddle::platform::CUDADeviceContext(0);
-    gpu_tensor.CopyFrom<int>(src_tensor, *gpu_ctx);
+    auto gpu_place = new paddle::platform::GPUPlace(0);
+    gpu_tensor.CopyFrom<int>(src_tensor, *gpu_place);

     // GPU Tensor to CPU Tensor
-    auto cpu_ctx = new paddle::platform::CPUDeviceContext();
-    dst_tensor.CopyFrom<int>(gpu_tensor, *cpu_ctx);
+    auto cpu_place = new paddle::platform::CPUPlace();
+    dst_tensor.CopyFrom<int>(gpu_tensor, *cpu_place);

     // Compare Tensors
     const int* dst_ptr = dst_tensor.data<int>();

@@ -245,10 +245,10 @@ TEST(Tensor, CopyFrom) {
     Tensor slice_tensor = src_tensor.Slice<int>(1, 2);

     // CPU Slice Tensor to GPU Tensor
-    gpu_tensor.CopyFrom<int>(slice_tensor, *gpu_ctx);
+    gpu_tensor.CopyFrom<int>(slice_tensor, *gpu_place);

     // GPU Tensor to CPU Tensor
-    dst_tensor.CopyFrom<int>(gpu_tensor, *cpu_ctx);
+    dst_tensor.CopyFrom<int>(gpu_tensor, *cpu_place);

     // Compare Slice Tensors
     const int* slice_ptr = slice_tensor.data<int>();
paddle/math/tests/test_matrixCompare.cpp

@@ -1141,4 +1141,64 @@ TEST(CpuMatrix, copyFrom) {
   TensorCheckEqual(cpu, copy);
 }

+void testBatch2seqPadding(int batchSize, int inputDim) {
+  MatrixPtr cpuInput = std::make_shared<CpuMatrix>(batchSize, inputDim);
+  MatrixPtr gpuInput = std::make_shared<GpuMatrix>(batchSize, inputDim);
+  cpuInput->randomizeUniform();
+  gpuInput->copyFrom(*cpuInput);
+
+  IVectorPtr cpuSequence;
+  generateSequenceStartPositions(batchSize, cpuSequence);
+  IVectorPtr gpuSequence = IVector::create(cpuSequence->getSize(), true);
+  gpuSequence->copyFrom(*cpuSequence);
+
+  size_t numSeq = cpuSequence->getSize() - 1;
+  size_t maxSeqLen = *std::max_element(cpuSequence->getData(),
+                                       cpuSequence->getData() + numSeq);
+
+  MatrixPtr cBatch = std::make_shared<CpuMatrix>(numSeq * maxSeqLen, inputDim);
+  MatrixPtr gBatch = std::make_shared<GpuMatrix>(numSeq * maxSeqLen, inputDim);
+  MatrixPtr cCheck = std::make_shared<CpuMatrix>(numSeq * maxSeqLen, inputDim);
+
+  hl_sequence2batch_copy_padding(gBatch->getData(),
+                                 gpuInput->getData(),
+                                 cpuSequence->getData(),
+                                 inputDim,
+                                 maxSeqLen,
+                                 numSeq,
+                                 false,
+                                 true);
+  cCheck->copyFrom(*gBatch);
+
+  int* seqStart = cpuSequence->getData();
+  float* batchData = cBatch->getData();
+  float* seqData = cpuInput->getData();
+  for (size_t i = 0; i < maxSeqLen; i++) {
+    for (size_t j = 0; j < numSeq; j++) {
+      size_t sequenceStart = seqStart[j];
+      size_t sequenceLength = seqStart[j + 1] - seqStart[j];
+      if (i < sequenceLength) {
+        memcpy(batchData + (i * numSeq + j) * inputDim,
+               seqData + (sequenceStart + i) * inputDim,
+               inputDim * sizeof(real));
+      } else {
+        memset(batchData + (i * numSeq + j) * inputDim,
+               0,
+               inputDim * sizeof(real));
+      }
+    }
+  }
+
+  TensorCheckErr(*cBatch, *cCheck);
+}
+
+TEST(Matrix, warpCTC) {
+  for (auto batchSize : {51, 526, 2884}) {
+    for (auto inputDim : {32, 512, 2026}) {
+      VLOG(3) << " batchSize=" << batchSize << " inputDim=" << inputDim;
+      testBatch2seqPadding(batchSize, inputDim);
+    }
+  }
+}
+
 #endif
paddle/memory/memcpy.cc

@@ -35,7 +35,7 @@ void Copy<platform::CPUPlace, platform::GPUPlace>(platform::CPUPlace dst_place,
                                                   platform::GPUPlace src_place,
                                                   const void* src, size_t num,
                                                   cudaStream_t stream) {
-  platform::GPUPlaceGuard g(src_place.device);
+  platform::SetDeviceId(src_place.device);
   platform::GpuMemcpyAsync(dst, src, num, cudaMemcpyDeviceToHost, stream);
 }

@@ -45,7 +45,7 @@ void Copy<platform::GPUPlace, platform::CPUPlace>(platform::GPUPlace dst_place,
                                                   platform::CPUPlace src_place,
                                                   const void* src, size_t num,
                                                   cudaStream_t stream) {
-  platform::GPUPlaceGuard g(dst_place.device);
+  platform::SetDeviceId(dst_place.device);
   platform::GpuMemcpyAsync(dst, src, num, cudaMemcpyHostToDevice, stream);
 }

@@ -56,7 +56,7 @@ void Copy<platform::GPUPlace, platform::GPUPlace>(platform::GPUPlace dst_place,
                                                   const void* src, size_t num,
                                                   cudaStream_t stream) {
   if (dst_place == src_place) {
-    platform::GPUPlaceGuard g(src_place.device);
+    platform::SetDeviceId(src_place.device);
     platform::GpuMemcpyAsync(dst, src, num, cudaMemcpyDeviceToDevice, stream);
   } else {
     platform::GpuMemcpyPeer(dst, dst_place.device, src, src_place.device, num,
paddle/memory/memcpy.h

@@ -20,13 +20,39 @@ limitations under the License. */
 namespace paddle {
 namespace memory {

+/**
+ * \brief   Copy memory from one place to another place.
+ *
+ * \param[in]  DstPlace Destination allocation place (CPU).
+ * \param[in]  dst      Destination memory address.
+ * \param[in]  SrcPlace Source allocation place (CPU).
+ * \param[in]  src      Source memory address.
+ * \param[in]  num      memory size in bytes to copy.
+ *
+ */
 template <typename DstPlace, typename SrcPlace>
 void Copy(DstPlace, void* dst, SrcPlace, const void* src, size_t num);

 #ifndef PADDLE_ONLY_CPU
+
+/**
+ * \brief   Copy memory from one place to another place.
+ *
+ * \param[in]  DstPlace Destination allocation place (CPU or GPU).
+ * \param[in]  dst      Destination memory address.
+ * \param[in]  SrcPlace Source allocation place (CPU or GPU).
+ * \param[in]  src      Source memory address.
+ * \param[in]  num      memory size in bytes to copy.
+ * \param[in]  stream   CUDA stream.
+ *
+ * \note    For GPU memory copy, CUDA stream need to be specified
+ *          for asynchronously memory copy.
+ *
+ */
 template <typename DstPlace, typename SrcPlace>
 void Copy(DstPlace, void* dst, SrcPlace, const void* src, size_t num,
           cudaStream_t stream);
+
 #endif  // PADDLE_ONLY_CPU

 }  // namespace memory
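A usage sketch of the CPU overload documented above (the buffers and function name here are illustrative):

    // Sketch: host-to-host copy through paddle::memory::Copy.
    #include "paddle/memory/memcpy.h"

    void CopyExample() {
      paddle::platform::CPUPlace cpu;
      float src[4] = {1.f, 2.f, 3.f, 4.f};
      float dst[4];
      // The GPU overload additionally takes a cudaStream_t for async copies.
      paddle::memory::Copy(cpu, dst, cpu, src, sizeof(src));
    }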
paddle/memory/memory.cc

@@ -60,6 +60,7 @@ detail::BuddyAllocator* GetGPUBuddyAllocator(int gpu_id) {
                                     platform::GpuMaxChunkSize());
     }
   }
+  platform::SetDeviceId(gpu_id);
   return as[gpu_id];
 }
paddle/memory/memory.h

@@ -20,15 +20,49 @@ limitations under the License. */
 namespace paddle {
 namespace memory {

+/**
+ * \brief   Allocate memory block in one place.
+ *
+ * \param[in]  place  Allocation place (CPU or GPU).
+ * \param[in]  size   Allocation size.
+ *
+ * \return  Allocated memory block address.
+ *
+ * \note    If return nullptr, it indicates memory allocation failed
+ *          because insufficient memory in current system. When Alloc
+ *          function is invoked, you must check the returned memory
+ *          address is valid or not.
+ */
 template <typename Place>
-void* Alloc(Place, size_t);
+void* Alloc(Place place, size_t size);

+/**
+ * \brief   Free memory block in one place.
+ *
+ * \param[in]  place  Allocation place (CPU or GPU).
+ * \param[in]  ptr    Memory block address to free.
+ *
+ */
 template <typename Place>
-void Free(Place, void*);
+void Free(Place place, void* ptr);

+/**
+ * \brief   Total size of used memory in one place.
+ *
+ * \param[in]  place  Allocation place (CPU or GPU).
+ *
+ */
 template <typename Place>
-size_t Used(Place);
+size_t Used(Place place);

+/**
+ * \brief   Free memory block in one place.
+ *
+ * \note    In some cases, custom deleter is used to
+ *          deallocate the memory automatically for
+ *          std::unique_ptr<T> in tensor.h.
+ *
+ */
 template <typename T, typename Place>
 class PODDeleter {
   static_assert(std::is_pod<T>::value, "T must be POD");
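A usage sketch of the documented contract, including the nullptr check the Alloc note asks for (the size and function name are illustrative):

    // Sketch: allocate, check, and free a block through the documented API.
    #include "paddle/memory/memory.h"

    void AllocExample() {
      paddle::platform::CPUPlace cpu;
      void* p = paddle::memory::Alloc(cpu, 1024);
      if (p != nullptr) {  // Alloc may return nullptr on insufficient memory
        // ... use the 1 KiB block ...
        paddle::memory::Free(cpu, p);
      }
    }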
paddle/operators/CMakeLists.txt

@@ -55,3 +55,8 @@
 op_library(fc_op SRCS fc_op.cc DEPS mul_op rowwise_add_op sigmoid_op
         softmax_op net)
 op_library(sgd_op SRCS sgd_op.cc sgd_op.cu)
+
+op_library(recurrent_network_op SRCS recurrent_network_op.cc DEPS op_desc
+        tensor op_registry operator net)
+cc_test(recurrent_network_op_test SRCS recurrent_network_op_test.cc DEPS
+        recurrent_network_op gtest mul_op add_op)
paddle/operators/add_op.cc

@@ -19,16 +19,16 @@ namespace operators {
 class AddOp : public OperatorWithKernel {
  protected:
-  void InferShape(const std::vector<const Tensor *> &inputs,
-                  const std::vector<Tensor *> &outputs) const override {
-    PADDLE_ENFORCE(inputs.size() == 2, "Input size of AddOp must be two");
-    PADDLE_ENFORCE(outputs.size() == 1, "Output size of AddOp must be one");
-    PADDLE_ENFORCE(inputs[0] != nullptr && inputs[1] != nullptr &&
-                       outputs[0] != nullptr,
-                   "Inputs/Outputs of AddOp must all be set");
-    PADDLE_ENFORCE(inputs[0]->dims() == inputs[1]->dims(),
+  void InferShape(const InferShapeContext &ctx) const override {
+    PADDLE_ENFORCE(ctx.InputSize() == 2, "Input size of AddOp must be two");
+    PADDLE_ENFORCE(ctx.OutputSize() == 1, "Output size of AddOp must be one");
+    PADDLE_ENFORCE(ctx.InputVar(0) != nullptr && ctx.InputVar(1) != nullptr,
+                   "Inputs of AddOp must all be set");
+    PADDLE_ENFORCE(ctx.OutputVar(0) != nullptr,
+                   "Outputs of AddOp must all be set");
+    PADDLE_ENFORCE(ctx.Input<Tensor>(0)->dims() == ctx.Input<Tensor>(1)->dims(),
                    "Two input of Add Op's dimension must be same.");
-    outputs[0]->Resize(inputs[0]->dims());
+    ctx.Output<Tensor>(0)->Resize(ctx.Input<Tensor>(0)->dims());
   }
 };

@@ -49,8 +49,7 @@ The equation is: Out = X + Y
 class AddOpGrad : public OperatorWithKernel {
  protected:
-  void InferShape(const std::vector<const Tensor *> &inputs,
-                  const std::vector<Tensor *> &outputs) const override {}
+  void InferShape(const InferShapeContext &ctx) const override {}
   std::string DebugString() const override {
     LOG(INFO) << "AddOpGrad";
     return "";
paddle/operators/add_op.h

@@ -21,16 +21,17 @@ namespace operators {
 template <typename Place, typename T>
 class AddKernel : public OpKernel {
  public:
-  void Compute(const KernelContext& context) const override {
-    auto input0 = context.Input(0)->Get<Tensor>();
-    auto input1 = context.Input(1)->Get<Tensor>();
-    auto output = context.Output(0)->GetMutable<Tensor>();
+  void Compute(const ExecutionContext& context) const override {
+    auto input0 = context.Input<Tensor>(0);
+    auto input1 = context.Input<Tensor>(1);
+    auto output = context.Output<Tensor>(0);

     output->mutable_data<T>(context.GetPlace());

     EigenVector<T>::Flatten(*output).device(
         *(context.GetEigenDevice<Place>())) =
-        EigenVector<T>::Flatten(input0) + EigenVector<T>::Flatten(input1);
+        framework::EigenVector<T>::Flatten(*input0) +
+        framework::EigenVector<T>::Flatten(*input1);
   }
 };
paddle/operators/cross_entropy_op.cc

@@ -19,20 +19,20 @@ namespace operators {
 class OnehotCrossEntropyOp : public OperatorWithKernel {
  protected:
-  void InferShape(const std::vector<const Tensor *> &inputs,
-                  const std::vector<Tensor *> &outputs) const override {
-    PADDLE_ENFORCE(inputs.size() == 2,
+  void InferShape(const InferShapeContext &ctx) const override {
+    PADDLE_ENFORCE(ctx.InputSize() == 2,
                    "Input size of OnehotCrossEntropyOp must be two");
-    PADDLE_ENFORCE(outputs.size() == 1,
+    PADDLE_ENFORCE(ctx.OutputSize() == 1,
                    "Output size of OnehotCrossEntropyOp must be one");
-    PADDLE_ENFORCE(inputs[0] != nullptr && inputs[1] != nullptr,
+    PADDLE_ENFORCE(ctx.InputVar(0) != nullptr && ctx.InputVar(1) != nullptr,
                    "Inputs of OnehotCrossEntropyOp must all be set");
-    PADDLE_ENFORCE(outputs[0] != nullptr,
+    PADDLE_ENFORCE(ctx.OutputVar(0) != nullptr,
                    "Outputs of OnehotCrossEntropyOp must all be set");
-    PADDLE_ENFORCE(inputs[0]->dims().size() == 2, "X's dimension must be 2.");
-    PADDLE_ENFORCE(outputs[0]->dims().size() == 1,
+    PADDLE_ENFORCE(ctx.Input<Tensor>(0)->dims().size() == 2,
+                   "X's dimension must be 2.");
+    PADDLE_ENFORCE(ctx.Output<Tensor>(0)->dims().size() == 1,
                    "label's dimension must be 1.");
-    outputs[0]->Resize({inputs[0]->dims()[0]});
+    ctx.Output<Tensor>(0)->Resize({ctx.Input<Tensor>(0)->dims()[0]});
   }
 };
paddle/operators/cross_entropy_op.h
```diff
@@ -23,18 +23,18 @@ class OnehotCrossEntropyOpKernel : public OpKernel {
 public:
   constexpr T LOG_THRESHOLD() const { return static_cast<T>(1e-20); }

-  void Compute(const KernelContext &context) const override {
-    auto X = context.Input(0)->Get<Tensor>();
-    const T *X_data = X.data<T>();
-    const int *label_data = context.Input(1)->Get<Tensor>().data<int>();
-    auto *Y = context.Output(0)->GetMutable<Tensor>();
+  void Compute(const ExecutionContext &ctx) const override {
+    auto X = ctx.Input<Tensor>(0);
+    const T *X_data = X->data<T>();
+    const int *label_data = ctx.Input<Tensor>(1)->data<int>();
+    auto Y = ctx.Output<Tensor>(0);

-    Y->mutable_data<T>(context.GetPlace());
+    Y->mutable_data<T>(ctx.GetPlace());

     T *Y_data = Y->data<T>();

-    int batch_size = X.dims()[0];
-    int class_num = X.dims()[1];
+    int batch_size = X->dims()[0];
+    int class_num = X->dims()[1];

     // Y[i] = -log(X[i][j])
     for (int i = 0; i < batch_size; ++i) {
```
paddle/operators/mul_op.cc
```diff
@@ -19,18 +19,17 @@ namespace operators {
 class MulOp : public OperatorWithKernel {
 protected:
-  void InferShape(const std::vector<const Tensor *> &inputs,
-                  const std::vector<Tensor *> &outputs) const override {
-    PADDLE_ENFORCE(inputs.size() == 2, "The mul op must take two inputs");
-    auto dim0 = inputs[0]->dims();
-    auto dim1 = inputs[1]->dims();
+  void InferShape(const InferShapeContext &ctx) const override {
+    PADDLE_ENFORCE(ctx.InputSize() == 2, "The mul op must take two inputs");
+    auto dim0 = ctx.Input<Tensor>(0)->dims();
+    auto dim1 = ctx.Input<Tensor>(1)->dims();
     PADDLE_ENFORCE(dim0.size() == 2 && dim1.size() == 2,
                    "The input of mul op must be matrix");
     PADDLE_ENFORCE(
         dim0[1] == dim1[0],
         "First matrix's width must be equal with second matrix's height.");
-    PADDLE_ENFORCE(outputs.size() == 1, "The mul op must take one output");
-    outputs[0]->Resize({dim0[0], dim1[1]});
+    PADDLE_ENFORCE(ctx.OutputSize() == 1, "The mul op must take one output");
+    ctx.Output<Tensor>(0)->Resize({dim0[0], dim1[1]});
   }
 };
@@ -51,8 +50,7 @@ The equation is: Out = X * Y
 class MulOpGrad : public OperatorWithKernel {
 protected:
-  void InferShape(const std::vector<const Tensor *> &inputs,
-                  const std::vector<Tensor *> &outputs) const override {}
+  void InferShape(const InferShapeContext &ctx) const override {}
   std::string DebugString() const override {
     LOG(INFO) << "MulGrad";
     return "";
```
paddle/operators/mul_op.h
```diff
@@ -22,19 +22,17 @@ namespace operators {
 template <typename Place, typename T>
 class MulKernel : public OpKernel {
 public:
-  void Compute(const KernelContext &context) const override {
+  void Compute(const ExecutionContext &context) const override {
     Eigen::array<Eigen::IndexPair<Eigen::DenseIndex>, 1> dim_pair = {
         {Eigen::IndexPair<Eigen::DenseIndex>(1, 0)}};
-    auto input0 = context.Input(0)->Get<Tensor>();
-    auto input1 = context.Input(1)->Get<Tensor>();
-    auto *output = context.Output(0)->GetMutable<Tensor>();
+    auto output = context.Output<Tensor>(0);
     output->mutable_data<T>(context.GetPlace());
     EigenMatrix<T>::From(*output).device(*(context.GetEigenDevice<Place>())) =
-        EigenMatrix<T>::From(input0).contract(EigenMatrix<T>::From(input1),
-                                              dim_pair);
+        EigenMatrix<T>::From(*context.Input<Tensor>("X"))
+            .contract(EigenMatrix<T>::From(*context.Input<Tensor>("Y")),
+                      dim_pair);
   }
 };
 }  // namespace operators
```
paddle/operators/recurrent_network_op.cc
0 → 100644
```cpp
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.

   Licensed under the Apache License, Version 2.0 (the "License");
   you may not use this file except in compliance with the License.
   You may obtain a copy of the License at

   http://www.apache.org/licenses/LICENSE-2.0

   Unless required by applicable law or agreed to in writing, software
   distributed under the License is distributed on an "AS IS" BASIS,
   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
   See the License for the specific language governing permissions and
   limitations under the License. */

#include "paddle/operators/recurrent_network_op.h"

#include <glog/logging.h>
#include <cstring>
#include <sstream>

#include "paddle/framework/net.h"
#include "paddle/framework/op_registry.h"
#include "paddle/platform/enforce.h"

namespace paddle {
namespace operators {

namespace rnn {

void SegmentInputs(std::vector<std::shared_ptr<Scope>>& step_scopes,
                   const std::vector<Link>& inlinks,
                   const size_t seq_len) {
  PADDLE_ENFORCE(!inlinks.empty(), "no in links are provided.");
  for (size_t i = 0; i < inlinks.size(); ++i) {
    Tensor* input =
        step_scopes[0]->GetVariable(inlinks[i].external)->GetMutable<Tensor>();
    DDim dims = input->dims();
    PADDLE_ENFORCE(static_cast<size_t>(dims[0]) == seq_len,
                   "all the inlinks must have same length");
    DDim step_dims = slice_ddim(dims, 1, dims.size());
    for (size_t j = 0; j < seq_len; j++) {
      Tensor* step_input = step_scopes[j]
                               ->CreateVariable(inlinks[i].internal)
                               ->GetMutable<Tensor>();
      *step_input = input->Slice<float>(j, j + 1);
      step_input->Resize(step_dims);
    }
  }
}

void ConcatOutputs(std::vector<std::shared_ptr<Scope>>& step_scopes,
                   const std::vector<Link>& outlinks,
                   const size_t seq_len) {
  for (size_t i = 0; i < outlinks.size(); i++) {
    Tensor* output =
        step_scopes[0]->GetVariable(outlinks[i].external)->GetMutable<Tensor>();

    // TODO(qingiqng) remove following code after adding
    // InferShape in RecurrentGradientOp
    DDim step_dims = step_scopes[0]
                         ->GetVariable(outlinks[i].internal)
                         ->GetMutable<Tensor>()
                         ->dims();
    std::vector<int> dims_vec = vectorize(step_dims);
    dims_vec.insert(dims_vec.begin(), seq_len);
    output->mutable_data<float>(make_ddim(dims_vec), platform::CPUPlace());

    for (size_t j = 0; j < seq_len; j++) {
      Tensor* step_output = step_scopes[j]
                                ->GetVariable(outlinks[i].internal)
                                ->GetMutable<Tensor>();
      // TODO(luotao02) data type and platform::DeviceContext() should set
      // correctly
      (output->Slice<float>(j, j + 1))
          .CopyFrom<float>(*step_output, platform::CPUPlace());
    }
  }
}

void LinkMemories(std::vector<std::shared_ptr<Scope>>& scopes,
                  const std::vector<rnn::MemoryAttr>& memories,
                  size_t step_id,
                  int offset) {
  PADDLE_ENFORCE(step_id < scopes.size(),
                 "step [%d] is out of range of step scopes' size [%d]",
                 step_id,
                 scopes.size());
  PADDLE_ENFORCE(static_cast<int>(step_id) + offset >= 0,
                 "offset [%d] must be large than -[%d]",
                 offset,
                 step_id);
  PADDLE_ENFORCE(step_id + offset < scopes.size(),
                 "offset [%d] is out of range, it must be less than (%d - %d)",
                 offset,
                 scopes.size(),
                 step_id);
  std::shared_ptr<Scope> scope = scopes[step_id];
  std::shared_ptr<Scope> linked_scope = scopes[step_id + offset];
  for (auto& attr : memories) {
    auto mem = scope->CreateVariable(attr.pre_var)->GetMutable<Tensor>();
    // maybe share variable is better?
    auto linked_mem = linked_scope->GetVariable(attr.var)->GetMutable<Tensor>();
    mem->ShareDataWith<float>(*linked_mem);

    // TODO(qingqing) remove following code
    // the memory of current step should be allocated in step net
    auto m = scope->CreateVariable(attr.var)->GetMutable<Tensor>();
    // for unit test, as addOp and mulOp are null currently, if not
    // mutable_data, mem.data() in output will be error. We will
    // remove this line after merge the correct addOp and mulOp.
    m->mutable_data<float>(mem->dims(), platform::CPUPlace());
  }
}

void InitArgument(const ArgumentName& name,
                  Argument* arg,
                  const OperatorBase& op) {
  arg->step_net = op.Input(name.step_net);
  arg->step_scopes = op.Output(name.step_scopes);

  auto inlinks = op.Inputs(name.inlinks);
  auto inlink_alias = op.GetAttr<std::vector<std::string>>(name.inlink_alias);
  PADDLE_ENFORCE(inlinks.size() == inlink_alias.size(),
                 "the size of inlinks and inlink_alias don't match:%d,%d",
                 inlinks.size(),
                 inlink_alias.size());
  for (size_t i = 0; i < inlinks.size(); ++i) {
    rnn::Link link;
    link.external = inlinks[i];
    link.internal = inlink_alias[i];
    (arg->inlinks).push_back(link);
  }

  auto outlinks = op.Outputs(name.outlinks);
  auto outlink_alias = op.GetAttr<std::vector<std::string>>(name.outlink_alias);
  PADDLE_ENFORCE(outlinks.size() == outlink_alias.size(),
                 "the size of outlinks and outlink_alias don't match:%d,%d",
                 outlinks.size(),
                 outlink_alias.size());
  for (size_t i = 0; i < outlinks.size(); ++i) {
    rnn::Link link;
    link.external = outlinks[i];
    link.internal = outlink_alias[i];
    (arg->outlinks).push_back(link);
  }

  auto boot_memories = op.Inputs(name.boot_memories);

  // attributes
  auto memories = op.GetAttr<std::vector<std::string>>(name.memories);
  auto pre_memories = op.GetAttr<std::vector<std::string>>(name.pre_memories);

  PADDLE_ENFORCE(memories.size() == boot_memories.size(),
                 "the size of memories, boot_memories don't match:%d,%d",
                 memories.size(),
                 boot_memories.size());
  PADDLE_ENFORCE(pre_memories.size() == boot_memories.size(),
                 "the size of pre_memories, boot_memories don't match:%d,%d",
                 pre_memories.size(),
                 boot_memories.size());
  PADDLE_ENFORCE(memories.size() > 0, "more than 1 memories should be set");

  for (size_t i = 0; i < memories.size(); ++i) {
    rnn::MemoryAttr mem_attr;
    mem_attr.var = memories[i];
    mem_attr.pre_var = pre_memories[i];
    mem_attr.boot_var = boot_memories[i];
    (arg->memories).push_back(mem_attr);
  }
}

}  // namespace rnn

void RecurrentAlgorithm::InferShape(const std::shared_ptr<Scope>& scope) const {
  seq_len_ = scope->GetVariable((arg_->inlinks[0]).external)
                 ->GetMutable<Tensor>()
                 ->dims()[0];
  CreateScopes(scope);
  auto step_scopes = GetStepScopes(scope);

  // SegmentInputs is called in InferShape. The input must hold memory in
  // SegmentInputs. But the other op only set dimension for the output in
  // InferShape. That's a problem. Whether the RNN op needs InferShape or not?
  // Whether the following functions (SegmentInputs, InitMemories, ...) need
  // to rewrite for RNN op?
  rnn::SegmentInputs(step_scopes, arg_->inlinks, seq_len_);

  InitMemories(step_scopes[0]);

  PADDLE_ENFORCE(scope->HasVariable(arg_->step_net),
                 "stepnet [%s] is not in scope.",
                 arg_->step_net);
  Variable* net = scope->GetVariable(arg_->step_net);
  PADDLE_ENFORCE(net != nullptr, "failed to get step net");
  // If the InferShape is called in OperatorBase's run function,
  // the rnn op only needs to do InferShape for the first time step
  for (size_t i = 0; i < seq_len_; i++) {
    if (i > 0) {
      rnn::LinkMemories(step_scopes, arg_->memories, i, -1);
    }
    net->GetMutable<NetOp>()->InferShape(step_scopes[i]);
  }

  auto outlinks = arg_->outlinks;
  for (size_t i = 0; i < outlinks.size(); i++) {
    DDim step_dims = step_scopes[0]
                         ->GetVariable(outlinks[i].internal)
                         ->GetMutable<Tensor>()
                         ->dims();
    std::vector<int> dims_vec = vectorize(step_dims);
    // now only support fixed length
    dims_vec.insert(dims_vec.begin(), seq_len_);
    Tensor* output =
        step_scopes[0]->GetVariable(outlinks[i].external)->GetMutable<Tensor>();
    output->Resize(make_ddim(dims_vec));
  }
}

void RecurrentAlgorithm::Run(const std::shared_ptr<Scope>& scope,
                             const platform::DeviceContext& dev_ctx) const {
  auto step_scopes = GetStepScopes(scope);

  Variable* net = scope->GetVariable(arg_->step_net);
  for (size_t step_id = 0; step_id < seq_len_; step_id++) {
    // the link memory is done in InferShape
    // maybe remove following code after testing
    if (step_id > 0) {
      rnn::LinkMemories(step_scopes, arg_->memories, step_id, -1);
    }
    net->GetMutable<NetOp>()->Run(step_scopes[step_id], dev_ctx);
  }

  rnn::ConcatOutputs(step_scopes, arg_->outlinks, seq_len_);
}

void RecurrentAlgorithm::CreateScopes(std::shared_ptr<Scope> scope) const {
  // TODO(xxx) Only two scopes are needed for inference, this case will be
  // supported later.
  auto step_scopes = scope->GetVariable(arg_->step_scopes)
                         ->GetMutable<std::vector<std::shared_ptr<Scope>>>();

  if (seq_len_ > step_scopes->size()) {
    for (size_t i = step_scopes->size(); i < seq_len_; ++i) {
      std::shared_ptr<Scope> step_scope = std::make_shared<Scope>(scope);

      // Now all variables in scope must be created outside of op.
      auto net_op = scope->GetVariable(arg_->step_net)->GetMutable<NetOp>();
      for (auto& input : net_op->inputs_) {
        step_scope->CreateVariable(input);
      }
      for (auto& output : net_op->outputs_) {
        step_scope->CreateVariable(output);
      }

      step_scopes->push_back(std::make_shared<Scope>(step_scope));
    }
  }
}

void RecurrentAlgorithm::InitMemories(std::shared_ptr<Scope> step_scope) const {
  for (auto& attr : arg_->memories) {
    Tensor* pre_mem =
        step_scope->CreateVariable(attr.pre_var)->GetMutable<Tensor>();
    PADDLE_ENFORCE(step_scope->HasVariable(attr.boot_var),
                   "memory [%s]'s boot variable [%s] not exists",
                   attr.var,
                   attr.boot_var);
    Tensor* boot_mem =
        step_scope->GetVariable(attr.boot_var)->GetMutable<Tensor>();
    pre_mem->ShareDataWith<float>(*boot_mem);

    // TODO(qingqing) remove following code
    // the memory of current step should be allocated in step net
    // here for unit test
    auto cur_step_mem =
        step_scope->CreateVariable(attr.var)->GetMutable<Tensor>();
    cur_step_mem->mutable_data<float>(boot_mem->dims(), platform::CPUPlace());
  }
}

const rnn::ArgumentName RecurrentOp::kArgName{"step_net",
                                              "step_scopes",
                                              "inlinks",
                                              "outlinks",
                                              "inlink_alias",
                                              "outlink_alias",
                                              "memories",
                                              "pre_memories",
                                              "boot_memories"};

const rnn::ArgumentName RecurrentGradientOp::kArgName{"step_net",
                                                      "step_scopes",
                                                      "outlink@grad",
                                                      "inlink@grad",
                                                      "inlink_alias",
                                                      "outlink_alias",
                                                      "memories",
                                                      "pre_memories",
                                                      "boot_memories@grad"};

void RecurrentOp::Init() {
  OperatorBase::Init();
  std::unique_ptr<rnn::Argument> arg(new rnn::Argument());

  rnn::InitArgument(kArgName, arg.get(), *this);

  alg_.Init(std::move(arg));
}

class RecurrentAlgorithmProtoAndCheckerMaker : public OpProtoAndCheckerMaker {
public:
  RecurrentAlgorithmProtoAndCheckerMaker(OpProto* proto,
                                         OpAttrChecker* op_checker)
      : OpProtoAndCheckerMaker(proto, op_checker) {
    const auto& name = RecurrentOp::kArgName;
    // inputs and outputs stored in proto
    AddInputs(name.inlinks,
              "the input that need to be segmented for each step.");
    AddInputs(name.boot_memories, "variables to initialize memories.");
    AddInput(name.step_net, "network shared by all steps.");

    AddOutputs(name.outlinks,
               "the output that need to concated for all steps.");
    AddOutput(name.step_scopes, "step scopes");

    // Attributes stored in AttributeMap
    AddAttr<std::vector<std::string>>(name.inlink_alias, "alias of inlinks");
    AddAttr<std::vector<std::string>>(name.outlink_alias, "alias of outlinks");
    AddAttr<std::vector<std::string>>(name.pre_memories,
                                      "names of pre-memories");
    AddAttr<std::vector<std::string>>(name.memories, "names of memories");

    AddComment("This is a recurrent group operator.");
  }
};

void RecurrentGradientAlgorithm::Run(
    const std::shared_ptr<Scope>& scope,
    const platform::DeviceContext& dev_ctx) const {
  auto step_scopes = GetStepScopes(scope);
  rnn::SegmentInputs(step_scopes, arg_->inlinks, seq_len_);
  PADDLE_ENFORCE(scope->HasVariable(arg_->step_net),
                 "step net is not in scope.");
  Variable* net = scope->GetVariable(arg_->step_net);
  PADDLE_ENFORCE(net != nullptr, "failed to get step net");
  for (int step_id = seq_len_ - 1; step_id >= 0; --step_id) {
    if (static_cast<size_t>(step_id) != seq_len_ - 1) {
      rnn::LinkMemories(step_scopes, arg_->memories, step_id, 1);
    }
    net->GetMutable<NetOp>()->Run(step_scopes[step_id], dev_ctx);
  }
  LinkBootMemoryGradients(step_scopes[0]);
  rnn::ConcatOutputs(step_scopes, arg_->outlinks, seq_len_);
}

void RecurrentGradientAlgorithm::LinkBootMemoryGradients(
    std::shared_ptr<Scope> step_scope) const {
  for (auto& attr : arg_->memories) {
    Tensor* mem_grad =
        step_scope->CreateVariable(attr.var)->GetMutable<Tensor>();
    PADDLE_ENFORCE(mem_grad != nullptr,
                   "boot_tensor should be retrieved before");
    PADDLE_ENFORCE(step_scope->HasVariable(attr.boot_var),
                   "memory [%s]'s boot variable [%s] not exists",
                   attr.var,
                   attr.boot_var);
    Tensor* boot_mem_grad =
        step_scope->CreateVariable(attr.boot_var)->GetMutable<Tensor>();
    boot_mem_grad->ShareDataWith<float>(*mem_grad);
  }
}

void RecurrentGradientAlgorithm::InferShape(
    const std::shared_ptr<Scope>& scope) const {
  seq_len_ = scope->GetVariable((arg_->inlinks[0]).external)
                 ->GetMutable<Tensor>()
                 ->dims()[0];
  auto step_scopes = GetStepScopes(scope);
  rnn::SegmentInputs(step_scopes, arg_->inlinks, seq_len_);
  PADDLE_ENFORCE(scope->HasVariable(arg_->step_net),
                 "step net is not in scope.");
  Variable* net = scope->GetVariable(arg_->step_net);
  PADDLE_ENFORCE(net != nullptr, "failed to get step net");
  for (int step_id = seq_len_ - 1; step_id >= 0; --step_id) {
    if (static_cast<size_t>(step_id) != seq_len_ - 1) {
      rnn::LinkMemories(step_scopes, arg_->memories, step_id, 1);
    }
    net->GetMutable<NetOp>()->InferShape(step_scopes[step_id]);
  }

  auto outlinks = arg_->outlinks;
  for (size_t i = 0; i < outlinks.size(); i++) {
    DDim step_dims = step_scopes[0]
                         ->GetVariable(outlinks[i].internal)
                         ->GetMutable<Tensor>()
                         ->dims();
    std::vector<int> dims_vec = vectorize(step_dims);
    // now only support fixed length
    dims_vec.insert(dims_vec.begin(), seq_len_);
    Tensor* output =
        step_scopes[0]->GetVariable(outlinks[i].external)->GetMutable<Tensor>();
    output->Resize(make_ddim(dims_vec));
  }
  LinkBootMemoryGradients(step_scopes[0]);
}

void RecurrentGradientOp::Init() {
  OperatorBase::Init();
  std::unique_ptr<rnn::Argument> arg(new rnn::Argument());
  rnn::InitArgument(kArgName, arg.get(), *this);
  alg_.Init(std::move(arg));
}

}  // namespace operators
}  // namespace paddle

REGISTER_OP(recurrent_op,
            paddle::operators::RecurrentOp,
            paddle::operators::RecurrentAlgorithmProtoAndCheckerMaker);
```
paddle/operators/recurrent_network_op.h
0 → 100644
```cpp
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.

   Licensed under the Apache License, Version 2.0 (the "License");
   you may not use this file except in compliance with the License.
   You may obtain a copy of the License at

   http://www.apache.org/licenses/LICENSE-2.0

   Unless required by applicable law or agreed to in writing, software
   distributed under the License is distributed on an "AS IS" BASIS,
   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
   See the License for the specific language governing permissions and
   limitations under the License. */

#pragma once

#include "paddle/framework/operator.h"

namespace paddle {
namespace operators {

using namespace paddle::framework;

namespace rnn {

/**
 * Memory of a RNN (same as the role of `Memory` in PaddlePaddle).
 *
 * Memory attributes cached by this op, dims will be inferred from
 * boot memories in father scope. Other attributes are copied from Op's proto
 * attributes.
 */
struct MemoryAttr {
  // name of current state variable
  std::string var;
  // name of previous step's state variable
  std::string pre_var;
  // name of the variables to init this memory (same role of `boot_layer` in
  // PaddlePaddle), which is stored in father's scope.
  std::string boot_var;
};

struct Link {
  // input or output links name.
  std::string internal;
  // alias to avoid duplicate keys in scopes.
  std::string external;
};

struct Argument {
  std::string step_net;
  std::string step_scopes;
  std::vector<Link> inlinks;
  std::vector<Link> outlinks;
  std::vector<rnn::MemoryAttr> memories;
};

struct ArgumentName {
  std::string step_net;
  std::string step_scopes;
  std::string inlinks;
  std::string outlinks;
  std::string inlink_alias;   // the alias of inlinks in step net.
  std::string outlink_alias;  // the alias of outlinks in step net.
  std::string memories;       // the memory name
  std::string pre_memories;   // the previous memory name
  std::string boot_memories;  // the boot memory name
};

/**
 * Prepare inputs for each step net.
 */
void SegmentInputs(std::vector<std::shared_ptr<Scope>>& step_scopes,
                   const std::vector<Link>& inlinks,
                   const size_t seq_len);

/**
 * Process outputs of step nets and merge to variables.
 */
void ConcatOutputs(std::vector<std::shared_ptr<Scope>>& step_scopes,
                   const std::vector<Link>& outlinks,
                   const size_t seq_len);

void LinkMemories(std::vector<std::shared_ptr<Scope>>& step_scopes,
                  const std::vector<MemoryAttr>& memories,
                  size_t step_id,
                  int offset);

void InitArgument(const ArgumentName& name, Argument* arg);

};  // namespace rnn

// The sequence format in RecurrentOp is Tensor<seq_len, batch_size, dim> now.
// TODO:
// 1. No-padding computing for sequences with indefinite length in one batch.
// 2. Hierarchical RNN for sequence with sub-sequence.
// 3. Internal Memory.
// 4. More Complex RNN architecture, such as Gated Feedback RNN.
//    Refer to: https://arxiv.org/pdf/1502.02367.pdf

class RecurrentAlgorithm {
public:
  void Run(const std::shared_ptr<Scope>& scope,
           const platform::DeviceContext& dev_ctx) const;

  void Init(std::unique_ptr<rnn::Argument> arg) { arg_ = std::move(arg); }

  /**
   * InferShape must be called before Run.
   */
  void InferShape(const std::shared_ptr<Scope>& scope) const;

protected:
  /*
   * The step scopes will be stored in the father scope as a variable.
   *
   * NOTE the scopes are reused in both the forward and backward, so just
   * create once and expand its size if more steps need.
   */
  void CreateScopes(std::shared_ptr<Scope> scope) const;

  inline const std::vector<std::shared_ptr<Scope>>& GetStepScopes(
      std::shared_ptr<Scope> scope) const {
    return *(scope->GetVariable(arg_->step_scopes))
                ->GetMutable<std::vector<std::shared_ptr<Scope>>>();
  }

  void InitMemories(std::shared_ptr<Scope> step_scopes) const;

private:
  std::unique_ptr<rnn::Argument> arg_;
  mutable size_t seq_len_;
};

class RecurrentGradientAlgorithm {
  /**
   * RNN's backward algorithm.
   *
   * To accelerate the development of RecurrentGradientOp, we decouple RNN's
   * algorithm and `OperatorBase`'s implementation; the former contains the
   * core implementation of a RNN, and will keep stable even if the framework
   * changes a lot, and the latter is a wrapper that acts like an adapter for
   * it to make RNN an operator.
   */
public:
  void Init(std::unique_ptr<rnn::Argument> arg) { arg_ = std::move(arg); }

  void Run(const std::shared_ptr<Scope>& scope,
           const platform::DeviceContext& dev_ctx) const;

  void LinkBootMemoryGradients(std::shared_ptr<Scope> step_scopes) const;

  /**
   * InferShape must be called before Run.
   */
  void InferShape(const std::shared_ptr<Scope>& scope) const;

protected:
  inline const std::vector<std::shared_ptr<Scope>>& GetStepScopes(
      std::shared_ptr<Scope> scope) const {
    return *(scope->GetVariable(arg_->step_scopes))
                ->GetMutable<std::vector<std::shared_ptr<Scope>>>();
  }

private:
  std::unique_ptr<rnn::Argument> arg_;
  mutable size_t seq_len_;
};

class RecurrentOp final : public OperatorBase {
public:
  void Init() override;

  /**
   * InferShape must be called before Run.
   */
  virtual void InferShape(const std::shared_ptr<Scope>& scope) const override {
    alg_.InferShape(scope);
  }

  virtual void Run(const std::shared_ptr<Scope>& scope,
                   const platform::DeviceContext& dev_ctx) const override {
    alg_.Run(scope, dev_ctx);
  }

  static const rnn::ArgumentName kArgName;

private:
  RecurrentAlgorithm alg_;
};

class RecurrentGradientOp final : public OperatorBase {
public:
  void Init() override;

  /**
   * InferShape must be called before Run.
   */
  virtual void InferShape(const std::shared_ptr<Scope>& scope) const override {
    alg_.InferShape(scope);
  }

  virtual void Run(const std::shared_ptr<Scope>& scope,
                   const platform::DeviceContext& dev_ctx) const override {
    alg_.Run(scope, dev_ctx);
  }

  static const rnn::ArgumentName kArgName;

private:
  RecurrentGradientAlgorithm alg_;
};

}  // namespace operators
}  // namespace paddle
```
paddle/operators/recurrent_network_op_test.cc
0 → 100644
```cpp
/*
  Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
  Licensed under the Apache License, Version 2.0 (the "License");
  you may not use this file except in compliance with the License.
  You may obtain a copy of the License at
  http://www.apache.org/licenses/LICENSE-2.0
  Unless required by applicable law or agreed to in writing, software
  distributed under the License is distributed on an "AS IS" BASIS,
  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  See the License for the specific language governing permissions and
  limitations under the License.
*/

#include <glog/logging.h>
#include <gtest/gtest.h>

#include "paddle/framework/net.h"
#include "paddle/framework/op_registry.h"
#include "paddle/framework/operator.h"
#include "paddle/framework/tensor.h"
#include "paddle/operators/recurrent_network_op.h"

namespace paddle {
namespace operators {

class RecurrentOpTest : public ::testing::Test {
protected:
  virtual void SetUp() override {
    CreateGlobalVariables();
    CreateStepNet();
    CreateRNNOp();
  }

  virtual void TearDown() override {}

  void CreateGlobalVariables() {
    scope_ = std::make_shared<Scope>();
    // create input, and init content
    LOG(INFO) << "create global variable x";
    for (auto inlink : std::vector<std::string>{"x", "x0", "x1", "h"}) {
      Variable* x = scope_->CreateVariable(inlink);
      DDim dims = make_ddim(std::vector<int>{
          10 /*sent size*/, 20 /*batch size*/, 30 /*input dim*/});
      x->GetMutable<Tensor>()->mutable_data<float>(dims, platform::CPUPlace());
    }
    // create output alias just for test
    for (auto inlink : std::vector<std::string>{"h@alias"}) {
      Variable* x = scope_->CreateVariable(inlink);
      DDim dims =
          make_ddim(std::vector<int>{20 /*batch size*/, 30 /*input dim*/});
      x->GetMutable<Tensor>()->mutable_data<float>(dims, platform::CPUPlace());
    }

    LOG(INFO) << "create global variable w";
    Variable* w = scope_->CreateVariable("rnn/w");
    w->GetMutable<Tensor>()->mutable_data<float>(
        make_ddim(std::vector<int>{30, 30}), platform::CPUPlace());

    for (auto boot : std::vector<std::string>{"x_boot", "h_boot"}) {
      LOG(INFO) << "create global variable " << boot;
      Variable* h_boot = scope_->CreateVariable(boot);
      h_boot->GetMutable<Tensor>()->mutable_data<float>(
          make_ddim(std::vector<int>{20 /*batch size*/, 30 /*input dim*/}),
          platform::CPUPlace());
    }

    LOG(INFO) << "create variable step_scopes";
    scope_->CreateVariable("step_scopes");

    LOG(INFO) << "create variable h";
    scope_->CreateVariable("h");
  }

  void CreateRNNOp() {
    OpDesc op_desc;

    op_desc.set_type("recurrent_op");
    // inlinks 0
    op_desc.add_inputs("x");
    op_desc.add_inputs("x0");
    op_desc.add_inputs("x1");
    // boot_memories 3
    op_desc.add_inputs("x_boot");
    op_desc.add_inputs("h_boot");
    // step net 5
    op_desc.add_inputs("step_net");
    // outlinks 6
    op_desc.add_outputs("h");
    // step scopes 7
    op_desc.add_outputs("step_scopes");

    auto _input_format = std::vector<int>{
        0,  // in_link
        3,  // memories
        5   // step_net
    };
    auto input_format = op_desc.add_attrs();
    input_format->set_name("input_format");
    input_format->set_type(paddle::framework::AttrType::INTS);
    for (auto i : _input_format) {
      input_format->add_ints(i);
    }

    auto output_format = op_desc.add_attrs();
    output_format->set_name("output_format");
    output_format->set_type(paddle::framework::AttrType::INTS);
    for (auto i : std::vector<int>{0, 1, 2}) {
      output_format->add_ints(i);
    }

    auto inlink_alias = op_desc.add_attrs();
    inlink_alias->set_name("inlink_alias");
    inlink_alias->set_type(paddle::framework::AttrType::STRINGS);

    auto outlink_alias = op_desc.add_attrs();
    outlink_alias->set_name("outlink_alias");
    outlink_alias->set_type(paddle::framework::AttrType::STRINGS);

    auto pre_memories = op_desc.add_attrs();
    pre_memories->set_name("pre_memories");
    pre_memories->set_type(paddle::framework::AttrType::STRINGS);

    auto memories = op_desc.add_attrs();
    memories->set_name("memories");
    memories->set_type(paddle::framework::AttrType::STRINGS);

    // create inlink_alias
    for (const auto& item :
         std::vector<std::string>{"x@alias", "x0@alias", "x1@alias"}) {
      inlink_alias->add_strings(item);
    }
    // pre memories
    for (const auto& item :
         std::vector<std::string>{"rnn/x@pre", "rnn/h@pre"}) {
      pre_memories->add_strings(item);
    }
    // memories
    for (const auto& item : std::vector<std::string>{"rnn/x", "rnn/h"}) {
      memories->add_strings(item);
    }
    // output alias
    for (const auto& item : std::vector<std::string>{"h@alias"}) {
      outlink_alias->add_strings(item);
    }

    rnn_op_ = OpRegistry::CreateOp(op_desc);

    LOG(INFO) << "rnn_op finish init";
  }

  void CreateStepNet() {
    LOG(INFO) << "create variable step_net";
    Variable* var = scope_->CreateVariable("step_net");
    auto net = var->GetMutable<NetOp>();
    // rnn/s is net's input or output?
    net->inputs_ = {"rnn/h@pre", "rnn/w", "rnn/x"};
    net->inputs_ = {"rnn/s", "rnn/h"};
    net->AddOp(
        OpRegistry::CreateOp("mul", {"rnn/h@pre", "rnn/w"}, {"rnn/s"}, {}));

    net->AddOp(
        OpRegistry::CreateOp("add_two", {"rnn/x", "rnn/s"}, {"rnn/h"}, {}));
    net->CompleteAddOp();
  }

  // father scope
  std::shared_ptr<Scope> scope_;
  std::shared_ptr<OperatorBase> rnn_op_;
};

TEST_F(RecurrentOpTest, Run) {
  platform::CPUDeviceContext ctx;
  rnn_op_->InferShape(scope_);
  rnn_op_->Run(scope_, ctx);
}

class RecurrentGradientAlgorithmTest : public ::testing::Test {
protected:
  virtual void SetUp() override {
    CreateGlobalVariables();
    CreateStepScopes();
    CreateStepNet();
    CreateRNNGradientAlgorithm();

    // segment inputs
    SegmentInputs();
    // link forward memories
    LinkeMemories();
  }

  virtual void TearDown() override {}

  void CreateGlobalVariables() {
    scope_ = std::make_shared<Scope>();
    // inputs: x
    LOG(INFO) << "create global variable x";
    Variable* x = scope_->CreateVariable("x");
    DDim dims =
        make_ddim({10 /*sent size*/, 20 /*batch size*/, 30 /*input dim*/});
    x->GetMutable<Tensor>()->mutable_data<float>(dims, platform::CPUPlace());
    // inputs: h_boot
    LOG(INFO) << "create global variable h_boot";
    Variable* h_boot = scope_->CreateVariable("h_boot");
    h_boot->GetMutable<Tensor>()->mutable_data<float>(
        make_ddim({20 /*batch size*/, 30 /*input dim*/}),
        platform::CPUPlace());
    // inputs: w
    LOG(INFO) << "create global variable w";
    Variable* w = scope_->CreateVariable("rnn/w");
    w->GetMutable<Tensor>()->mutable_data<float>(make_ddim({30, 30}),
                                                 platform::CPUPlace());
    // inputs: h_grad
    LOG(INFO) << "create variable h_grad";
    Variable* dh = scope_->CreateVariable("h_grad");
    dh->GetMutable<Tensor>()->mutable_data<float>(make_ddim({10, 20, 30}),
                                                  platform::CPUPlace());
    // inputs: step_scopes
    LOG(INFO) << "create variable step_scopes";
    scope_->CreateVariable("step_scopes");
    // inputs: step_net
    LOG(INFO) << "create variable step_net";
    scope_->CreateVariable("step_net");
    // outputs: w_grad
    LOG(INFO) << "create global variable w_grad";
    scope_->CreateVariable("rnn/w_grad");
    // outputs: x_grad
    LOG(INFO) << "create global variable x_grad";
    scope_->CreateVariable("x_grad");
    // outputs: h_boot_grad
    LOG(INFO) << "create global variable h_boot_grad";
    scope_->CreateVariable("h_boot_grad");
  }

  void CreateStepScopes() {
    std::vector<std::shared_ptr<Scope>>* step_scopes =
        scope_->GetVariable("step_scopes")
            ->GetMutable<std::vector<std::shared_ptr<Scope>>>();
    for (int i = 0; i < 10; ++i) {
      auto scope = std::make_shared<Scope>(scope_);
      auto pre_t = scope->CreateVariable("rnn/pre_h")->GetMutable<Tensor>();
      pre_t->mutable_data<float>(make_ddim({20, 30}), platform::CPUPlace());
      auto tensor = scope->CreateVariable("rnn/h")->GetMutable<Tensor>();
      tensor->mutable_data<float>(make_ddim({20, 30}), platform::CPUPlace());

      // for unit test of ConcatOutputs
      auto xg = scope->CreateVariable("rnn/x_grad")->GetMutable<Tensor>();
      xg->mutable_data<float>(make_ddim({20, 30}), platform::CPUPlace());

      step_scopes->push_back(scope);
    }

    // last time step
    auto g = (*step_scopes)[9]
                 ->CreateVariable("rnn/h_pre_grad")
                 ->GetMutable<Tensor>();
    g->mutable_data<float>(make_ddim({20, 30}), platform::CPUPlace());
  }

  void CreateRNNGradientAlgorithm() {
    std::unique_ptr<rnn::Argument> arg(new rnn::Argument());
    arg->step_net = "step_net";
    arg->step_scopes = "step_scopes";
    rnn::Link inlink;
    inlink.external = "h_grad";
    inlink.internal = "rnn/h_grad";
    arg->inlinks = std::vector<rnn::Link>{inlink};

    rnn::Link outlink;
    outlink.external = "x_grad";
    outlink.internal = "rnn/x_grad";
    arg->outlinks = std::vector<rnn::Link>{outlink};

    rnn::MemoryAttr mem_attr;
    mem_attr.pre_var = "rnn/h_pre_grad";
    mem_attr.var = "rnn/h_grad";
    mem_attr.boot_var = "h_boot_grad";
    arg->memories = std::vector<rnn::MemoryAttr>{mem_attr};

    rnn_grad_algo_.Init(std::move(arg));
  }

  void CreateStepNet() {
    LOG(INFO) << "create variable step_net";
    Variable* var = scope_->CreateVariable("step_net");
    auto net = var->GetMutable<NetOp>();
    net->AddOp(OpRegistry::CreateOp("mul",
                                    {"rnn/h_pre", "rnn/w", "rnn/s_grad"},
                                    {"rnn/h_pre_grad", "rnn/w_grad"},
                                    {}));

    net->AddOp(OpRegistry::CreateOp(
        "add_two", {"rnn/h_grad"}, {"rnn/x_grad", "rnn/s_grad"}, {}));
    net->CompleteAddOp();
  }

  void SegmentInputs() {
    LOG(INFO) << "segment inputs";
    std::vector<std::string> inlinks = {"x"};
    std::vector<std::string> inlinks_alias = {"rnn/x"};

    rnn::Link inlink;
    inlink.external = "x";
    inlink.internal = "rnn/x";
    std::vector<std::shared_ptr<Scope>>* step_scopes =
        scope_->GetVariable("step_scopes")
            ->GetMutable<std::vector<std::shared_ptr<Scope>>>();
    rnn::SegmentInputs(*step_scopes, std::vector<rnn::Link>{inlink}, 10);
  }

  void LinkeMemories() {
    LOG(INFO) << "link memories";
    rnn::MemoryAttr mem_attr;
    mem_attr.pre_var = "rnn/h_pre";
    mem_attr.var = "rnn/h";
    mem_attr.boot_var = "boot_h";
    std::vector<rnn::MemoryAttr> memories;
    memories.push_back(mem_attr);
    std::vector<std::shared_ptr<Scope>>* step_scopes =
        scope_->GetVariable("step_scopes")
            ->GetMutable<std::vector<std::shared_ptr<Scope>>>();
    for (int i = 1; i < 10; ++i) {
      rnn::LinkMemories(*step_scopes, memories, i, -1);
    }
  }

  std::shared_ptr<Scope> scope_;
  RecurrentGradientAlgorithm rnn_grad_algo_;
};

// TEST_F(RecurrentGradientAlgorithmTest, Run) {
//   platform::CPUDeviceContext ctx;
//   rnn_grad_algo_.Run(scope_, ctx);
// }

}  // namespace operators
}  // namespace paddle

TEST(RecurrentOp, LinkMemories) {
  using namespace paddle::framework;
  using namespace paddle::platform;
  using namespace paddle::operators;

  // create and init step scopes
  int len = 10;
  std::vector<std::shared_ptr<Scope>> step_scopes;
  for (int i = 0; i < len; ++i) {
    auto scope = std::make_shared<Scope>();
    scope->CreateVariable("pre_h");
    auto tensor = scope->CreateVariable("h")->GetMutable<Tensor>();
    float* data = tensor->mutable_data<float>(make_ddim({15, 20}), CPUPlace());
    for (int i = 0; i < 15 * 20; ++i) {
      data[i] = rand() * (1. / (double)RAND_MAX);
    }
    step_scopes.push_back(scope);
  }

  // create MemoryAttr
  rnn::MemoryAttr mem_attr;
  mem_attr.pre_var = "pre_h";
  mem_attr.var = "h";
  mem_attr.boot_var = "boot_h";
  std::vector<rnn::MemoryAttr> memories;
  memories.push_back(mem_attr);

  for (int i = 1; i < len; ++i) {
    rnn::LinkMemories(step_scopes, memories, i, -1);
  }
  // check
  for (int i = 0; i < len - 1; ++i) {
    const float* a =
        step_scopes[i]->GetVariable("h")->GetMutable<Tensor>()->data<float>();
    const float* b = step_scopes[i + 1]
                         ->GetVariable("pre_h")
                         ->GetMutable<Tensor>()
                         ->data<float>();
    for (size_t i = 0; i < 15 * 20; ++i) {
      ASSERT_FLOAT_EQ(a[i], b[i]);
    }
  }

  for (int i = len - 2; i >= 0; --i) {
    rnn::LinkMemories(step_scopes, memories, i, 1);
  }
  // check
  for (int i = len - 2; i >= 0; --i) {
    const float* a = step_scopes[i]
                         ->GetVariable("pre_h")
                         ->GetMutable<Tensor>()
                         ->data<float>();
    const float* b = step_scopes[i + 1]
                         ->GetVariable("h")
                         ->GetMutable<Tensor>()
                         ->data<float>();
    for (size_t i = 0; i < 15 * 20; ++i) {
      ASSERT_FLOAT_EQ(a[i], b[i]);
    }
  }
}

USE_OP(add_two);
USE_OP(mul);
```
paddle/operators/rnn_design.md
0 → 100644
# RNN design for variable-length input

To learn from variable-length sequences, today's mainstream frameworks such as TensorFlow, PyTorch, Caffe2 and MXNet all use padding: sequences of different lengths inside one mini-batch are padded with zeros to a fixed length before taking part in computation.
Paddle's existing RNNs, including `RecurrentLayerGroup`, already implement padding-free support for variable-length sequences. Building on the ideas of that module, this document designs variable-length sequence support for the refactored framework.

## Background

Because a tensor must have a definite shape, tensor-based mainstream frameworks have to use zero-padding to complete variable-length sequences into fixed-shape tensors.
Padding is a compromise a framework makes in order to handle variable length. From the user's perspective, padding is naturally unwelcome when using RNN-family models, which is why PyTorch has seen lengthy discussions about padding-free support for variable-length sequences [3].
Since padding also costs extra memory and computation, TensorFlow and MXNet both use bucketing as an optimization [1][2], but padding and bucketing alike put an extra usage burden on the user.

Therefore, **native support for variable-length sequences directly satisfies the user's most immediate need, and can be counted as a major advantage of Paddle among today's mainstream platforms.**
Supporting variable-length sequences, however, requires some changes to the current framework; the following discusses how to support them with minimal modification.
## `LODTensor`: a data format for multi-level sequences

Paddle currently stores the data of a mini-batch in one-dimensional memory and additionally uses `Argument.sequenceStartPositions` to store the boundary information of each sentence.
Paddle uses `Argument.subSequenceStartPositions` to store two-level sequence information; sequences of higher levels cannot be supported directly.
To support the storage of `N-level` sequences, this document defines the sequence information as the following data structure:

```c++
std::shared_ptr<std::vector<std::vector<int>>> lod_start_pos_;
```

or, defined more explicitly,

```c++
typedef std::vector<int> level_t;
std::vector<level_t> lod_start_pos;
```

Each `level_t` here stores the offset information of one granularity (level), consistent with Paddle's current practice.
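To make the offset layout concrete, here is a small hand-written example (our own illustration; reading the offsets as cumulative start positions per level is an assumption based on Paddle's existing `sequenceStartPositions`, not something the definition above pins down). It encodes a batch of two paragraphs, the first containing two sentences of 3 and 2 words, the second one sentence of 4 words.

```c++
#include <vector>

typedef std::vector<int> level_t;

int main() {
  // Level 0: paragraph boundaries over sentences. Paragraph 0 owns
  // sentences [0, 2), paragraph 1 owns sentence [2, 3).
  level_t paragraphs = {0, 2, 3};
  // Level 1: sentence boundaries over words. Sentence lengths 3, 2, 4
  // become cumulative start positions over the 9 words of the batch.
  level_t sentences = {0, 3, 5, 9};

  // Two levels in total; the last offset of each level equals the number
  // of elements at the level below it.
  std::vector<level_t> lod_start_pos = {paragraphs, sentences};
  return 0;
}
```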
To pass sequence information around more transparently, we introduce a new kind of tensor called `LODTensor` [4].
Its tensor-related interfaces are inherited directly from `Tensor`, with additional sequence-related interfaces added on top.
Thus, when operating on a `LODTensor`, an ordinary `Op` simply treats it as a `Tensor`, while a sequence-manipulating `Op` additionally operates on `LODTensor`'s interfaces for variable-length sequences.

The concrete definition of `LODTensor`:

```c++
class LODTensor : public Tensor {
public:
  size_t Levels() const { return seq_start_positions_.size(); }
  size_t Elements(int level = 0) const {
    return seq_start_positions_[level].size();
  }
  // slice of level[elem_begin: elem_end]
  // NOTE low performance in slicing seq_start_positions_.
  // TODO should call Tensor's Slice.
  LODTensor LODSlice(int level, int elem_begin, int elem_end) const;

  // slice with tensor's data shared with this.
  LODTensor LODSliceShared(int level, int elem_begin, int elem_end) const;

  // copy other's lod_start_pos_, to share LOD info.
  // NOTE the LOD info should not be changed.
  void ShareConstLODFrom(const LODTensor &other) {
    lod_start_pos_ = other.lod_start_pos_;
  }
  // copy other's lod_start_pos_'s content, free to mutate.
  void ShareMutableLODFrom(const LODTensor &other) {
    lod_start_pos_ = std::make_shared<std::vector<std::vector<int>>>(
        other.lod_start_pos_.begin(), other.lod_start_pos_.end());
  }

private:
  std::shared_ptr<std::vector<std::vector<int>>> lod_start_pos_;
};
```

Here `lod_start_pos_` uses a `shared_ptr` to reduce the cost of storage and copying.
`LODTensor` can be regarded as an extension of `Tensor`, almost fully compatible with the original usage of `Tensor`.
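A minimal usage sketch of the two sharing interfaces above (hypothetical calling code; it assumes the `LODTensor` sketched here compiles as written):

```c++
// Consumer: a pass-through op only copies the shared_ptr, so every
// downstream reader sees the same, immutable LOD info.
void ForwardLOD(const LODTensor &in, LODTensor *out) {
  out->ShareConstLODFrom(in);
}

// Producer: an op that rewrites sequence boundaries first takes an
// independent copy, then is free to mutate its own offsets at Run time.
void RewriteLOD(const LODTensor &in, LODTensor *out) {
  out->ShareMutableLODFrom(in);
  // ... mutate out's lod_start_pos_ here ...
}
```

This split is exactly the consumer/producer convention that the next section formalizes.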
## Framework support

### Replacing the framework's existing `Tensor` calls with `LODTensor`

To pass `LODTensor` around, many `Tensor`s in the framework need to become `LODTensor`s.
The simple implementation is to **replace all existing `Tensor`s with `LODTensor` directly, which can be done by modifying the `Tensor`-creating interfaces in `pybind.cc`**.
Besides, users may need to be aware of the existence of sequences (for example, visualizing sequences requires parsing the sequences output by a model), so some sequence-operating APIs also need to be exposed to the Python layer.
### Passing `lod_start_pos` along the Op call chain

The framework needs to support the following features to pass `lod_start_pos` along:

1. Pass it by `shared_ptr`
   - an Op that does not modify the content of `lod_start_pos` acts as a consumer
   - an Op that modifies `lod_start_pos` acts as a producer
   - by convention, a consumer only copies the `shared_ptr` passed to it
   - a producer creates its own independent memory to store its own modification, and exposes a `shared_ptr` to subsequent consumers
   - since the passing is implemented by copying a `shared_ptr`, the framework only needs to pass `lod_start_pos` once
2. Be sufficiently transparent to Ops that are unaware of `lod_start_pos`
3. Let a producer Op that needs to modify `lod_start_pos` update its own `lod_start_pos` data at `Run` time

The concrete design falls into the following three subsections.

#### Passing `lod_start_pos`

- When `lod_start_pos` does not need to be modified, call LODTensor's `ShareConstLODFrom` interface to copy it
- When it does, call `ShareMutableLODFrom`, which allocates its own memory to store the modification

#### Framework transparency

This passing step is added to the initialization that runs before the network executes, and only needs to be initialized once. A preliminary scheme based on the current framework design:

- add an attribute `do_mutate_lod_info` to each Op's `attrs`, defaulting to `false`
- an Op that needs to modify `lod_start_pos` sets it to `true` when defining its `OpProto`
- `OperatorBase::InferShape` reads `do_mutate_lod_info` and calls the relevant `LODTensor` methods to copy `lod_start_pos`
- add a member `is_lod_inited{false}` to `OperatorBase` to guarantee that the passing happens only once

Some of the logic is as follows
```c++
class OperatorBase {
public:
  // ...
  void InferShape() {
    if (!is_lod_inited) {
      bool do_mutate_lod_info = GetAttr<bool>("do_mutate_lod_info");
      // find an input having LOD to copy
      auto lod_input = ValidLODInput();
      for (auto &output : outputs) {
        if (do_mutate_lod_info) {
          output.ShareMutableLODFrom(lod_input);
        } else {
          output.ShareConstLODFrom(lod_input);
        }
      }
      is_lod_inited = true;
    }

    // call op's InferShape
    // ...
  }

private:
  // ...
  bool is_lod_inited{false};
};
```
With this, the passing of `lod_start_pos` information is completely transparent to the implementation of non-LOD Ops.

#### Updating `lod_start_pos`

As introduced in the previous subsection, for an Op that needs to modify `lod_start_pos`, `OperatorBase` allocates a piece of its own memory to store the modification;
the Op then updates its own `lod_start_pos` in its `Run` implementation,
and all ops that depend on its outputs obtain the update automatically through the shared pointer.
## Sorting by length

After sorting by length, the batch size of successive time steps naturally decreases, so the steps can be fed straight into the Net for batched computation; the sketch after the diagrams below illustrates the bookkeeping.

For example, the original input:

```
origin:
xxxx
xx
xxx

-> sorted:
xxxx
xxx
xx
```

After `SegmentInputs`, there will be 4 time steps, and the input of each time step is as follows (arranged vertically)

```
0    1    2    3
x    x    x    x
x    x    x
x    x
```
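As a self-contained illustration of that bookkeeping (our own example, not an interface from this design): once lengths are sorted in descending order, the batch size at step `t` is simply the number of sequences longer than `t`.

```c++
#include <algorithm>
#include <cstdio>
#include <functional>
#include <vector>

int main() {
  std::vector<int> lengths = {4, 2, 3};  // "xxxx", "xx", "xxx"
  std::sort(lengths.begin(), lengths.end(), std::greater<int>());

  // Sorted lengths are 4, 3, 2: the sequences alive at each step form a
  // prefix of the sorted batch.
  for (int t = 0; t < lengths.front(); ++t) {
    int batch_size = static_cast<int>(std::count_if(
        lengths.begin(), lengths.end(), [t](int len) { return len > t; }));
    std::printf("step %d: batch size %d\n", t, batch_size);
  }
  // Prints batch sizes 3, 3, 2, 1 -- matching the columns above.
  return 0;
}
```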
To track how sequences move around before and after sorting, we use

```c++
struct SortedSeqItem {
  void *start{nullptr};
  void *end{nullptr};
};

std::vector<SortedSeqItem> sorted_seqs;
```

to track each sequence's position after sorting, and add a new interface

```c++
std::vector<SortedSeqItem> SortBySeqLen(const LODTensor &tensor);
```
Because the order of the input sequences changes, the following existing interfaces need corresponding modifications:

- InitMemories: memory needs to be rearranged according to `sorted_seqs`
- SegmentInputs
- ConcatOutputs

In addition, since `sorted_seqs` needs to be reused by `RecurrentGradientOp`, it becomes a new output of `RecurrentOp`, later passed in as an input of `RecurrentGradientOp`.
## InitMemories

Because of the change in sequence order, the order of elements along the batch dimension of `boot_memories` also needs to be rearranged accordingly.

## SegmentInputs

`SegmentInputs` relies on the information of `sorted_seqs`: it cuts the original sequences horizontally, in the sorted sequence order, into the inputs of each step.

That is, the transformation below:

```
origin:
xxxx
xx
xxx

   |
   |
  \ /
   !
0    1    2    3
x    x    x    x
x    x    x
x    x
```

## ConcatOutputs

`ConcatOutputs` needs to

- restore the output of each time step to the sequence order of the original input (to prevent the order from being shuffled at inference time); a sketch of this restore step follows below
- concat each sequence into a regular mini-batch representation
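A sketch of the restore step (a hypothetical helper of our own; the design above only states the requirement): invert the permutation recorded at sort time, then copy each per-sequence output back to its original slot.

```c++
#include <cstddef>
#include <vector>

// Hypothetical helper: sorted_to_origin[i] is the original index of the
// i-th sorted sequence; T stands in for one sequence's output payload.
template <typename T>
std::vector<T> RestoreOrder(const std::vector<T> &sorted_outputs,
                            const std::vector<std::size_t> &sorted_to_origin) {
  std::vector<T> restored(sorted_outputs.size());
  for (std::size_t i = 0; i < sorted_outputs.size(); ++i) {
    restored[sorted_to_origin[i]] = sorted_outputs[i];
  }
  return restored;
}
```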
## References

1. [Tensorflow Bucketing](https://www.tensorflow.org/versions/r0.12/api_docs/python/contrib.training/bucketing)
2. [mxnet Bucketing](http://mxnet.io/how_to/bucketing.html)
3. [variable length input in RNN scenario](https://discuss.pytorch.org/t/about-the-variable-length-input-in-rnn-scenario/345/5)
4. [Level of details](https://en.wikipedia.org/wiki/Level_of_detail)
paddle/operators/rowwise_add_op.cc
```diff
@@ -18,17 +18,17 @@ namespace operators {
 class RowWiseAddOp : public OperatorWithKernel {
 protected:
-  void InferShape(const std::vector<const Tensor *> &inputs,
-                  const std::vector<Tensor *> &outputs) const override {
-    PADDLE_ENFORCE(inputs.size() == 2UL,
-                   "Two inputs is needed by rowwise add");
-    auto dim0 = inputs[0]->dims();
-    auto dim1 = inputs[1]->dims();
+  void InferShape(const InferShapeContext &ctx) const override {
+    PADDLE_ENFORCE(ctx.InputSize() == 2UL,
+                   "Two inputs is needed by rowwise add");
+    auto dim0 = ctx.Input<Tensor>(0)->dims();
+    auto dim1 = ctx.Input<Tensor>(1)->dims();

     PADDLE_ENFORCE(dim0.size() == 2, "Input 0 must be matrix");
     PADDLE_ENFORCE(dim1.size() == 1, "The second input must be vector");
     PADDLE_ENFORCE(dim0[1] == dim1[0], "The width of two input must be same");
-    PADDLE_ENFORCE(outputs.size() == 1, "The output size must be 1");
-    outputs[0]->Resize(inputs[0]->dims());
+    PADDLE_ENFORCE(ctx.OutputSize() == 1, "The output size must be 1");
+    ctx.Output<Tensor>(0)->Resize(ctx.Input<Tensor>(0)->dims());
   }
 };
```
paddle/operators/rowwise_add_op.h
```diff
@@ -21,14 +21,12 @@ namespace operators {
 template <typename Place, typename T>
 class RowWiseAddKernel : public OpKernel {
 public:
-  void Compute(const KernelContext &context) const override {
-    auto in0 = context.Input(0)->Get<Tensor>();
-    auto in1 = context.Input(1)->Get<Tensor>();
-    auto *out = context.Output(0)->GetMutable<Tensor>();
+  void Compute(const ExecutionContext &context) const override {
+    auto out = context.Output<Tensor>(0);
     out->mutable_data<T>(context.GetPlace());

-    auto input = EigenMatrix<T>::From(in0);
-    auto bias = EigenVector<T>::From(in1);
+    auto input = EigenMatrix<T>::From(*context.Input<Tensor>(0));
+    auto bias = EigenVector<T>::From(*context.Input<Tensor>(1));
     auto output = EigenMatrix<T>::From(*out);

     const int bias_size = bias.dimension(0);
```
paddle/operators/sgd_op.cc
```diff
@@ -19,16 +19,15 @@ namespace operators {
 class SGDOp : public OperatorWithKernel {
 protected:
-  void InferShape(const std::vector<const Tensor *> &inputs,
-                  const std::vector<Tensor *> &outputs) const override {
-    PADDLE_ENFORCE(inputs.size() == 2, "Input size of SGDOp must be two");
-    PADDLE_ENFORCE(outputs.size() == 1, "Output size of SGDOp must be one");
-    PADDLE_ENFORCE(inputs[0] != nullptr, "inputs[0] mast be set");
-    PADDLE_ENFORCE(inputs[1] != nullptr, "inputs[1] mast be set");
-    PADDLE_ENFORCE(outputs[0] != nullptr, "outputs[0] mast be set");
-    PADDLE_ENFORCE(inputs[0]->dims() == inputs[1]->dims(),
+  void InferShape(const InferShapeContext &ctx) const override {
+    PADDLE_ENFORCE(ctx.InputSize() == 2, "Input size of SGDOp must be two");
+    PADDLE_ENFORCE(ctx.OutputSize() == 1, "Output size of SGDOp must be one");
+    PADDLE_ENFORCE(ctx.InputVar(0) != nullptr, "inputs[0] mast be set");
+    PADDLE_ENFORCE(ctx.InputVar(1) != nullptr, "inputs[1] mast be set");
+    PADDLE_ENFORCE(ctx.OutputVar(0) != nullptr, "outputs[0] mast be set");
+    PADDLE_ENFORCE(ctx.Input<Tensor>(0)->dims() == ctx.Input<Tensor>(1)->dims(),
                    "Two input of SGD Op's dimension must be same.");
-    outputs[0]->Resize(inputs[0]->dims());
+    ctx.Output<Tensor>(0)->Resize(ctx.Input<Tensor>(0)->dims());
   }
 };
```
paddle/operators/sgd_op.h
```diff
@@ -21,16 +21,16 @@ namespace operators {
 template <typename Place, typename T>
 class SGDOpKernel : public OpKernel {
 public:
-  void Compute(const KernelContext &ctx) const override {
-    auto param = ctx.Input("param")->Get<Tensor>();
-    auto grad = ctx.Input("grad")->Get<Tensor>();
-    auto *param_out = ctx.Output(0)->GetMutable<Tensor>();
+  void Compute(const ExecutionContext &ctx) const override {
+    auto param = ctx.Input<Tensor>("param");
+    auto grad = ctx.Input<Tensor>("grad");
+    auto param_out = ctx.Output<Tensor>(0);
     float lr = ctx.op_.GetAttr<float>("learning_rate");

     param_out->mutable_data<T>(ctx.GetPlace());

     EigenVector<T>::Flatten(*param_out).device(*(ctx.GetEigenDevice<Place>())) =
-        EigenVector<T>::Flatten(param) - lr * EigenVector<T>::Flatten(grad);
+        EigenVector<T>::Flatten(*param) - lr * EigenVector<T>::Flatten(*grad);
   }
 };
```
paddle/operators/sigmoid_op.cc
```diff
@@ -18,11 +18,10 @@ namespace operators {
 class SigmoidOp : public OperatorWithKernel {
 protected:
-  void InferShape(const std::vector<const Tensor *> &inputs,
-                  const std::vector<Tensor *> &outputs) const override {
-    PADDLE_ENFORCE(inputs.size() == 1, "Sigmoid Op only have one input");
-    PADDLE_ENFORCE(outputs.size() == 1, "Sigmoid Op only have one output");
-    outputs[0]->Resize(inputs[0]->dims());
+  void InferShape(const InferShapeContext &ctx) const override {
+    PADDLE_ENFORCE(ctx.InputSize() == 1, "Sigmoid Op only have one input");
+    PADDLE_ENFORCE(ctx.OutputSize() == 1, "Sigmoid Op only have one output");
+    ctx.Output<Tensor>(0)->Resize(ctx.Input<Tensor>(0)->dims());
   }
 };
@@ -38,8 +37,7 @@ public:
 class SigmoidOpGrad : public OperatorWithKernel {
 protected:
-  void InferShape(const std::vector<const Tensor *> &inputs,
-                  const std::vector<Tensor *> &outputs) const override {}
+  void InferShape(const InferShapeContext &ctx) const override {}
   std::string DebugString() const override {
     LOG(INFO) << "SigmoidGrad";
     return "";
```
paddle/operators/sigmoid_op.h
```diff
@@ -22,15 +22,14 @@ namespace operators {
 template <typename Place, typename T>
 class SigmoidKernel : public OpKernel {
 public:
-  void Compute(const KernelContext &context) const override {
-    auto input = context.Input(0)->Get<Tensor>();
-    auto *output = context.Output(0)->GetMutable<Tensor>();
-
+  void Compute(const ExecutionContext &context) const override {
+    auto input = context.Input<Tensor>(0);
+    auto output = context.Output<Tensor>(0);
     output->mutable_data<T>(context.GetPlace());

     EigenVector<T>::Flatten(*output).device(*(context.GetEigenDevice<Place>())) =
-        1.0 / (1.0 + (-1.0 * EigenVector<T>::Flatten(input)).exp());
+        1.0 / (1.0 + (-1.0 * EigenVector<T>::Flatten(*input)).exp());
   }
 };
 }  // namespace operators
```
paddle/operators/softmax_op.cc
```diff
@@ -18,14 +18,13 @@ namespace operators {
 class SoftmaxOp : public OperatorWithKernel {
 protected:
-  void InferShape(const std::vector<const Tensor *> &inputs,
-                  const std::vector<Tensor *> &outputs) const override {
-    PADDLE_ENFORCE(inputs.size() == 1, "Only one input is need for softmax");
-    PADDLE_ENFORCE(inputs[0]->dims().size() == 2,
+  void InferShape(const InferShapeContext &ctx) const override {
+    PADDLE_ENFORCE(ctx.InputSize() == 1, "Only one input is need for softmax");
+    PADDLE_ENFORCE(ctx.Input<Tensor>(0)->dims().size() == 2,
                    "The input of softmax op must be matrix");
-    PADDLE_ENFORCE(outputs.size() == 1, "Only one output is need for softmax");
-    outputs[0]->Resize(inputs[0]->dims());
+    PADDLE_ENFORCE(ctx.OutputSize() == 1, "Only one output is need for softmax");
+    ctx.Output<Tensor>(0)->Resize(ctx.Input<Tensor>(0)->dims());
   }
 };
@@ -41,8 +40,7 @@ public:
 class SoftmaxOpGrad : public OperatorWithKernel {
 protected:
-  void InferShape(const std::vector<const Tensor *> &inputs,
-                  const std::vector<Tensor *> &outputs) const override {}
+  void InferShape(const InferShapeContext &ctx) const override {}
   std::string DebugString() const override {
     LOG(INFO) << "SoftmaxOpGrad";
     return "";
```
paddle/operators/softmax_op.h
```diff
@@ -22,12 +22,12 @@ namespace operators {
 template <typename Place, typename T>
 class SoftmaxKernel : public OpKernel {
 public:
-  void Compute(const KernelContext &context) const override {
-    auto input = context.Input(0)->Get<Tensor>();
-    auto *output = context.Output(0)->GetMutable<Tensor>();
+  void Compute(const ExecutionContext &context) const override {
+    auto input = context.Input<Tensor>(0);
+    auto output = context.Output<Tensor>(0);
     output->mutable_data<T>(context.GetPlace());

-    auto logits = EigenMatrix<T>::From(input);
+    auto logits = EigenMatrix<T>::From(*input);
     auto softmax = EigenMatrix<T>::From(*output);

     const int kBatchDim = 0;
```
paddle/operators/type_alias.h
```diff
@@ -22,7 +22,9 @@ namespace paddle {
 namespace operators {

 using OpKernel = framework::OpKernel;
-using KernelContext = framework::KernelContext;
+using InferShapeContext = framework::InferShapeContext;
+using ExecutionContext = framework::ExecutionContext;
 using Variable = framework::Variable;
 template <typename T,
           int MajorType = Eigen::RowMajor,
           typename IndexType = Eigen::DenseIndex>
```
paddle/platform/device_context.cc
浏览文件 @
bc146e8f
...
...
```diff
@@ -20,12 +20,101 @@ Eigen::DefaultDevice* DeviceContext::get_eigen_device<Eigen::DefaultDevice>()
   return reinterpret_cast<const CPUDeviceContext*>(this)->eigen_device();
 }

+CPUDeviceContext::CPUDeviceContext() {
+  eigen_device_.reset(new Eigen::DefaultDevice());
+}
+
+CPUDeviceContext::CPUDeviceContext(CPUPlace place) {
+  eigen_device_.reset(new Eigen::DefaultDevice());
+}
+
+Eigen::DefaultDevice* CPUDeviceContext::eigen_device() const {
+  return eigen_device_.get();
+}
+
+Place CPUDeviceContext::GetPlace() const { return CPUPlace(); }
+
 #ifndef PADDLE_ONLY_CPU
+
 template <>
 Eigen::GpuDevice* DeviceContext::get_eigen_device<Eigen::GpuDevice>() const {
   return reinterpret_cast<const CUDADeviceContext*>(this)->eigen_device();
 }
-#endif
+
+CUDADeviceContext::CUDADeviceContext(GPUPlace place) : place_(place) {
+  SetDeviceId(place_.device);
+  // TODO(qijun) Passing a created cuda stream to Eigen::CudaStreamDevice
+  // directly here causes a segmentation fault. We must implement a class
+  // derived from Eigen::StreamInterface and reinitialize it with a cuda
+  // stream and a gpu id later. Please refer to the implementation of class
+  // EigenCudaStreamDevice in TensorFlow.
+  //
+  // We find that CUDA 7 introduces a new option, the per-thread default
+  // stream, that has two effects. Please refer to https://devblogs.nvidia.com/
+  // parallelforall/gpu-pro-tip-cuda-7-streams-simplify-concurrency/
+  //
+  // So, we decide to use the default stream and add the --default-stream
+  // per-thread nvcc flag. Then, two threads with two CUDADeviceContexts
+  // will run in parallel.
+  eigen_stream_.reset(new Eigen::CudaStreamDevice());
+  eigen_device_.reset(new Eigen::GpuDevice(eigen_stream_.get()));
+}
+
+CUDADeviceContext::~CUDADeviceContext() {
+  SetDeviceId(place_.device);
+  Wait();
+  if (cublas_handle_) {
+    PADDLE_ENFORCE(dynload::cublasDestroy(cublas_handle_));
+  }
+  if (cudnn_handle_) {
+    PADDLE_ENFORCE(dynload::cudnnDestroy(cudnn_handle_));
+  }
+  if (curand_generator_) {
+    PADDLE_ENFORCE(dynload::curandDestroyGenerator(curand_generator_));
+  }
+  eigen_stream_.reset();
+  eigen_device_.reset();
+}
+
+Place CUDADeviceContext::GetPlace() const { return place_; }
+
+void CUDADeviceContext::Wait() const {
+  PADDLE_ENFORCE(cudaStreamSynchronize(0));
+}
+
+Eigen::GpuDevice* CUDADeviceContext::eigen_device() const {
+  return eigen_device_.get();
+}
+
+cublasHandle_t CUDADeviceContext::cublas_handle() {
+  if (!cublas_handle_) {
+    SetDeviceId(place_.device);
+    PADDLE_ENFORCE(dynload::cublasCreate(&cublas_handle_));
+  }
+  return cublas_handle_;
+}
+
+cudnnHandle_t CUDADeviceContext::cudnn_handle() {
+  if (!cudnn_handle_) {
+    SetDeviceId(place_.device);
+    PADDLE_ENFORCE(dynload::cudnnCreate(&cudnn_handle_));
+  }
+  return cudnn_handle_;
+}
+
+curandGenerator_t CUDADeviceContext::curand_generator() {
+  if (!curand_generator_) {
+    SetDeviceId(place_.device);
+    PADDLE_ENFORCE(dynload::curandCreateGenerator(&curand_generator_,
+                                                  CURAND_RNG_PSEUDO_DEFAULT));
+    PADDLE_ENFORCE(dynload::curandSetPseudoRandomGeneratorSeed(
+        curand_generator_, seed_));
+  }
+  return curand_generator_;
+}
+
+#endif  // PADDLE_ONLY_CPU

 }  // namespace platform
 }  // namespace paddle
```
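A short usage sketch of the lazy-handle design implemented above: the Eigen device is built in the constructor, while the cuBLAS/cuDNN/cuRAND handles are created on first request and torn down with the context. Hypothetical caller code, assuming GPUPlace takes a device id as elsewhere in paddle/platform:

```cpp
#include "paddle/platform/device_context.h"

// Hypothetical caller, for illustration only.
void Demo() {
  paddle::platform::CUDADeviceContext ctx(paddle::platform::GPUPlace(0));
  Eigen::GpuDevice* dev = ctx.eigen_device();  // available immediately
  cublasHandle_t blas = ctx.cublas_handle();   // created on first call, cached
  cudnnHandle_t dnn = ctx.cudnn_handle();      // likewise lazy
  ctx.Wait();  // cudaStreamSynchronize(0): waits on the default stream
  (void)dev; (void)blas; (void)dnn;
}
```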
paddle/platform/device_context.h
...
...
```diff
@@ -39,14 +39,13 @@ class DeviceContext {
 class CPUDeviceContext : public DeviceContext {
  public:
-  CPUDeviceContext() { eigen_device_.reset(new Eigen::DefaultDevice()); }
+  CPUDeviceContext();
+  CPUDeviceContext(CPUPlace);
+  virtual ~CPUDeviceContext() {}

-  Eigen::DefaultDevice* eigen_device() const { return eigen_device_.get(); }
+  Eigen::DefaultDevice* eigen_device() const;

-  Place GetPlace() const override {
-    Place retv = CPUPlace();
-    return retv;
-  }
+  Place GetPlace() const override;

  private:
   std::unique_ptr<Eigen::DefaultDevice> eigen_device_;
```
...
...
```diff
@@ -54,119 +53,46 @@ class CPUDeviceContext : public DeviceContext {
 #ifndef PADDLE_ONLY_CPU

-class GPUPlaceGuard {
+class CUDADeviceContext : public DeviceContext {
  public:
-  explicit GPUPlaceGuard(GPUPlace new_place) : previous_(GetCurrentDeviceId()) {
-    if (previous_ != new_place) {
-      paddle::platform::SetDeviceId(new_place.device);
-    }
-  }
+  explicit CUDADeviceContext(GPUPlace);
+  virtual ~CUDADeviceContext();

-  ~GPUPlaceGuard() { paddle::platform::SetDeviceId(previous_.device); }
+  /*! \brief Wait for all operations completion in the stream. */
+  void Wait() const;

- private:
-  GPUPlace previous_;
-};
+  /*! \brief Return place in the device context. */
+  Place GetPlace() const override;

-class CUDADeviceContext : public DeviceContext {
- public:
-  explicit CUDADeviceContext(const GPUPlace gpu_place) : gpu_place_(gpu_place) {
-    GPUPlaceGuard guard(gpu_place_);
-    PADDLE_ENFORCE(cudaStreamCreate(&stream_), "cudaStreamCreate failed");
-    eigen_stream_.reset(new Eigen::CudaStreamDevice(&stream_));
-    eigen_device_.reset(new Eigen::GpuDevice(eigen_stream_.get()));
-  }
-
-  Place GetPlace() const override {
-    Place retv = GPUPlace();
-    return retv;
-  }
-
-  void Wait() {
-    PADDLE_ENFORCE(cudaStreamSynchronize(stream_),
-                   "cudaStreamSynchronize failed");
-  }
-
-  cudaStream_t stream() const { return stream_; }
-
-  Eigen::GpuDevice* eigen_device() const { return eigen_device_.get(); }
-
-  cublasHandle_t cublas_handle() {
-    if (!blas_handle_) {
-      GPUPlaceGuard guard(gpu_place_);
-      PADDLE_ENFORCE(paddle::platform::dynload::cublasCreate(&blas_handle_),
-                     "cublasCreate failed");
-      PADDLE_ENFORCE(
-          paddle::platform::dynload::cublasSetStream(blas_handle_, stream_),
-          "cublasSetStream failed");
-    }
-    return blas_handle_;
-  }
-
-  cudnnHandle_t cudnn_handle() {
-    if (!dnn_handle_) {
-      GPUPlaceGuard guard(gpu_place_);
-      PADDLE_ENFORCE(paddle::platform::dynload::cudnnCreate(&dnn_handle_),
-                     "cudnnCreate failed");
-      PADDLE_ENFORCE(
-          paddle::platform::dynload::cudnnSetStream(dnn_handle_, stream_),
-          "cudnnSetStream failed");
-    }
-    return dnn_handle_;
-  }
-
-  curandGenerator_t curand_generator() {
-    if (!rand_generator_) {
-      GPUPlaceGuard guard(gpu_place_);
-      PADDLE_ENFORCE(paddle::platform::dynload::curandCreateGenerator(
-                         &rand_generator_, CURAND_RNG_PSEUDO_DEFAULT),
-                     "curandCreateGenerator failed");
-      PADDLE_ENFORCE(
-          paddle::platform::dynload::curandSetPseudoRandomGeneratorSeed(
-              rand_generator_, random_seed_),
-          "curandSetPseudoRandomGeneratorSeed failed");
-      PADDLE_ENFORCE(
-          paddle::platform::dynload::curandSetStream(rand_generator_, stream_),
-          "curandSetStream failed");
-    }
-    return rand_generator_;
-  }
-
-  ~CUDADeviceContext() {
-    Wait();
-    if (blas_handle_) {
-      PADDLE_ENFORCE(paddle::platform::dynload::cublasDestroy(blas_handle_),
-                     "cublasDestroy failed");
-    }
-    if (dnn_handle_) {
-      PADDLE_ENFORCE(paddle::platform::dynload::cudnnDestroy(dnn_handle_),
-                     "cudnnDestroy failed");
-    }
-    if (rand_generator_) {
-      PADDLE_ENFORCE(
-          paddle::platform::dynload::curandDestroyGenerator(rand_generator_),
-          "curandDestroyGenerator failed");
-    }
-    eigen_stream_.reset();
-    eigen_device_.reset();
-    PADDLE_ENFORCE(cudaStreamDestroy(stream_), "cudaStreamDestroy failed");
-  }
+  /*! \brief Return eigen device in the device context. */
+  Eigen::GpuDevice* eigen_device() const;
+
+  // clang-format off
+  /*! \brief Return cublas handle in the device context. */
+  cublasHandle_t    cublas_handle();
+  /*! \brief Return cudnn handle in the device context. */
+  cudnnHandle_t     cudnn_handle();
+  /*! \brief Return curand handle in the device context. */
+  curandGenerator_t curand_generator();
+  // clang-format on

  private:
-  GPUPlace gpu_place_;
-  cudaStream_t stream_;
+  GPUPlace place_;

-  std::unique_ptr<Eigen::CudaStreamDevice> eigen_stream_;
+ private:
   std::unique_ptr<Eigen::GpuDevice> eigen_device_;
+  std::unique_ptr<Eigen::CudaStreamDevice> eigen_stream_;

-  cublasHandle_t blas_handle_{nullptr};
-  cudnnHandle_t dnn_handle_{nullptr};
+ private:
+  uint64_t seed_;

-  int random_seed_;
-  curandGenerator_t rand_generator_{nullptr};
+  // clang-format off
+  cudnnHandle_t     cudnn_handle_     = nullptr;
+  cublasHandle_t    cublas_handle_    = nullptr;
+  curandGenerator_t curand_generator_ = nullptr;
+  // clang-format on
 };
 #endif
```
...
...
paddle/platform/enforce.h
...
...
```diff
@@ -58,11 +58,6 @@ struct EnforceNotMet : public std::exception {
 // For more details, please check https://stackoverflow.com/a/43870188/724872.
 #define UNLIKELY(condition) __builtin_expect(static_cast<bool>(condition), 0)

-template <typename T>
-inline void throw_on_error(T e) {
-  throw_on_error(e, "");
-}
-
 template <typename... Args>
 inline typename std::enable_if<sizeof...(Args) != 0, void>::type
 throw_on_error(int stat, const Args&... args) {
```
...
...
```diff
@@ -132,6 +127,11 @@ inline typename std::enable_if<sizeof...(Args) != 0, void>::type throw_on_error(
 #endif  // PADDLE_ONLY_CPU

+template <typename T>
+inline void throw_on_error(T e) {
+  throw_on_error(e, "");
+}
+
 #define PADDLE_THROW(...)                       \
   do {                                          \
     throw ::paddle::platform::EnforceNotMet(    \
```
...
...
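The two enforce.h hunks relocate the generic single-argument throw_on_error below the platform-specific two-argument overloads, including the CUDA ones guarded by PADDLE_ONLY_CPU. The ordering plausibly matters because the body calls throw_on_error(e, ""), and for arguments of built-in or enum type an unqualified call inside a template is resolved against declarations visible at the template's definition, with no argument-dependent lookup to pick up later ones. A standalone sketch of that rule (report and check are illustrative names, not from the codebase):

```cpp
#include <iostream>

// Must be declared before check(): int carries no associated namespaces,
// so the unqualified call below gets no ADL and only sees declarations
// visible at the template's definition point.
void report(int code, const char* msg) {
  std::cout << "code=" << code << ", msg=" << msg << "\n";
}

template <typename T>
void check(T code) {
  report(code, "");  // resolved by ordinary lookup at definition point
}

int main() {
  check(42);  // prints "code=42, msg="; declaring report() below check() breaks the build
  return 0;
}
```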
paddle/pybind/CMakeLists.txt
```diff
 cc_library(paddle_pybind SHARED SRCS pybind.cc DEPS pybind python
-           add_op fc_op sgd_op cross_entropy_op)
+           add_op fc_op sgd_op cross_entropy_op
+           recurrent_network_op)
```
paddle/pybind/pybind.cc
...
...
```diff
@@ -36,6 +36,7 @@ USE_OP(mul);
 USE_OP(sigmoid);
 USE_OP(softmax);
 USE_OP(rowwise_add);
+USE_OP_WITHOUT_KERNEL(recurrent_op);

 template <typename ClassType>
 void ExposeOperator(ClassType& m) {
```
...
...
```diff
@@ -94,6 +95,11 @@ All parameter, weight, gradient are variables in Paddle.
            [](pd::Variable& self) -> pd::Tensor* {
              return self.GetMutable<pd::Tensor>();
            },
-           py::return_value_policy::reference);
+           py::return_value_policy::reference)
+      .def("get_net",
+           [](pd::Variable& self) -> pd::NetOp* {
+             return self.GetMutable<pd::NetOp>();
+           },
+           py::return_value_policy::reference);

   py::class_<pd::Scope, std::shared_ptr<pd::Scope>>(m, "Scope")
```
...
...
python/paddle/v2/framework/tests/test_recurrent_op.py
0 → 100644
```python
import paddle.v2.framework.core as core
import unittest
import numpy as np
import paddle.v2.framework.create_op_creation_methods as creation

ops = creation.op_creations


def create_tensor(scope, name, shape):
    tensor = scope.create_var(name).get_tensor()
    tensor.set_dims(shape)
    tensor.alloc_float()
    tensor.set(np.random.random(shape))
    return tensor


class TestRNN(unittest.TestCase):
    '''
    Test RNNOp

    equation:
        h_t = \sigma (W x_t + U h_{t-1})
    weights:
        - W
        - U
    vars:
        - x
    memories:
        - h
    outputs:
        - h
    '''

    def init(self):
        input_dim = 30
        batch_size = 50
        weight_dim = 15

        self.scope = core.Scope(None)

        # create vars
        create_tensor(self.scope, "x", [batch_size, input_dim])
        create_tensor(self.scope, "W", [input_dim, weight_dim])
        create_tensor(self.scope, "U", [weight_dim, weight_dim])
        create_tensor(self.scope, "h_boot", [batch_size, weight_dim])

        x_alias = "x@alias"
        y_alias = "y@alias"
        memory = "h@alias"
        prememory = "h@pre"
        output = "rnn_out"
        output_alias = "rnn_out@alias"

        # create step net
        stepnet_var = self.scope.create_var("stepnet")
        stepnet = stepnet_var.get_net()
        # stepnet = core.Net.create()
        x_fc_op = ops.fc(X=x_alias, W="W", Y="Wx")
        h_fc_op = ops.fc(X=prememory, W="U", Y="Uh")
        sum_op = ops.add_two(X="Wx", Y="Uh", Out="sum")
        sig_op = ops.sigmoid(X="sum", Y=memory)

        stepnet.add_op(x_fc_op)
        stepnet.add_op(h_fc_op)
        stepnet.add_op(sum_op)
        stepnet.add_op(sig_op)
        stepnet.complete_add_op(True)

        # create RNNOp
        rnnop = ops.recurrent_op(
            # inputs
            inlinks=["x"],
            boot_memories=["h_boot"],
            step_net="stepnet",
            # outputs
            outlinks=[output],
            step_scopes="step_scopes",
            # attributes
            inlink_alias=["x@alias"],
            outlink_alias=[output_alias],
            pre_memories=[prememory],
            memories=[memory])

        ctx = core.DeviceContext.cpu_context()
        rnnop.infer_shape(self.scope)
        rnnop.run(self.scope, ctx)

    def test_recurrent(self):
        self.init()


if __name__ == '__main__':
    unittest.main()
```
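The docstring's equation, h_t = \sigma(W x_t + U h_{t-1}), is what the four step-net ops above chain together per time step. A dependency-free C++ sketch of one step, with dimensions as in the test (illustrative code, not part of the commit; it assumes the fc ops here reduce to plain matrix products, matching the stated equation):

```cpp
#include <cmath>
#include <vector>

using Mat = std::vector<std::vector<float>>;  // row-major [rows][cols]

Mat MatMul(const Mat& a, const Mat& b) {
  Mat c(a.size(), std::vector<float>(b[0].size(), 0.0f));
  for (size_t i = 0; i < a.size(); ++i)
    for (size_t k = 0; k < b.size(); ++k)
      for (size_t j = 0; j < b[0].size(); ++j) c[i][j] += a[i][k] * b[k][j];
  return c;
}

// x_t: [batch_size, input_dim], W: [input_dim, weight_dim],
// h_prev: [batch_size, weight_dim], U: [weight_dim, weight_dim]
Mat Step(const Mat& x_t, const Mat& h_prev, const Mat& W, const Mat& U) {
  Mat wx = MatMul(x_t, W);     // W x_t term (cf. the "Wx" intermediate above)
  Mat uh = MatMul(h_prev, U);  // U h_{t-1} term (cf. "Uh")
  Mat h(wx.size(), std::vector<float>(wx[0].size()));
  for (size_t i = 0; i < h.size(); ++i)
    for (size_t j = 0; j < h[i].size(); ++j)  // add, then sigmoid
      h[i][j] = 1.0f / (1.0f + std::exp(-(wx[i][j] + uh[i][j])));
  return h;  // fed back as h_{t-1} for the next step
}
```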