PaddlePaddle / Paddle — Commit 4cc42171

Authored Jul 31, 2017 by qijun

merge baidu/develop

Parents: 4ecf68e0, 0973c2c9

92 changed files with 3,079 additions and 970 deletions (+3079 −970)
Changed files:

.pre-commit-config.yaml  +7 −5
cmake/external/eigen.cmake  +1 −10
cmake/flags.cmake  +1 −1
go/cmd/master/master.go  +24 −4
go/cmd/pserver/pserver.go  +26 −5
go/glide.lock  +168 −8
go/glide.yaml  +11 −0
go/master/c/client.go  +13 −4
go/master/client.go  +39 −31
go/master/client_internal_test.go  +32 −28
go/master/client_test.go  +57 −26
go/master/etcd_client.go  +19 −6
go/master/inmem_store.go  +5 −0
go/master/service.go  +61 −38
go/master/service_internal_test.go  +2 −1
go/master/service_test.go  +68 −0
go/pserver/client/c/cclient.go  +3 −3
go/pserver/client/c/test/test_train.py  +13 −7
go/pserver/etcd_client.go  +56 −42
paddle/api/Evaluator.cpp  +1 −1
paddle/framework/CMakeLists.txt  +2 −4
paddle/framework/detail/tensor-inl.h  +142 −0
paddle/framework/net.cc  +4 −12
paddle/framework/net.h  +5 −19
paddle/framework/net_op_test.cc  +16 −23
paddle/framework/net_proto.proto  +0 −15
paddle/framework/op_registry_test.cc  +5 −5
paddle/framework/operator.cc  +13 −7
paddle/framework/operator.h  +8 −8
paddle/framework/scope.h  +16 −5
paddle/framework/scope_test.cc  +5 −0
paddle/framework/tensor.cc  +1 −1
paddle/framework/tensor.h  +88 −113
paddle/framework/tensor_test.cc  +72 −22
paddle/gserver/activations/ActivationFunction.cpp  +2 −2
paddle/memory/detail/buddy_allocator.cc  +27 −29
paddle/memory/memcpy.cc  +3 −3
paddle/memory/memcpy.h  +26 −0
paddle/memory/memory.cc  +1 −0
paddle/memory/memory.h  +40 −6
paddle/operators/CMakeLists.txt  +5 −0
paddle/operators/add_op.cc  +12 −17
paddle/operators/add_op.cu  +2 −3
paddle/operators/add_op.h  +8 −11
paddle/operators/cross_entropy_op.cc  +11 −17
paddle/operators/cross_entropy_op.cu  +1 −3
paddle/operators/cross_entropy_op.h  +6 −8
paddle/operators/fc_op.cc  +17 −22
paddle/operators/mul_op.cc  +12 −17
paddle/operators/mul_op.cu  +1 −4
paddle/operators/mul_op.h  +9 −12
paddle/operators/recurrent_network_op.cc  +419 −0
paddle/operators/recurrent_network_op.h  +216 −0
paddle/operators/recurrent_network_op_test.cc  +400 −0
paddle/operators/rnn_design.md  +239 −0
paddle/operators/rowwise_add_op.cc  +9 −15
paddle/operators/rowwise_add_op.cu  +2 −4
paddle/operators/rowwise_add_op.h  +9 −11
paddle/operators/sgd_op.cc  +8 −13
paddle/operators/sgd_op.cu  +1 −3
paddle/operators/sgd_op.h  +8 −12
paddle/operators/sigmoid_op.cc  +12 −20
paddle/operators/sigmoid_op.cu  +1 −3
paddle/operators/sigmoid_op.h  +7 −9
paddle/operators/softmax_op.cc  +10 −17
paddle/operators/softmax_op.cu  +1 −2
paddle/operators/softmax_op.h  +7 −9
paddle/operators/type_alias.h  +51 −0
paddle/platform/device_context.cc  +90 −1
paddle/platform/device_context.h  +35 −109
paddle/platform/enforce.h  +42 −36
paddle/platform/enforce_test.cc  +1 −1
paddle/pybind/CMakeLists.txt  +1 −1
paddle/pybind/pybind.cc  +31 −18
paddle/trainer/NewRemoteParameterUpdater.cpp  +5 −1
paddle/utils/Error.h  +4 −9
paddle/utils/tests/test_Error.cpp  +4 −4
python/paddle/trainer/config_parser.py  +1 −2
python/paddle/trainer_config_helpers/attrs.py  +1 −1
python/paddle/trainer_config_helpers/layers.py  +11 −20
python/paddle/v2/__init__.py  +8 −0
python/paddle/v2/dataset/common.py  +13 −31
python/paddle/v2/dataset/mq2007.py  +2 −2
python/paddle/v2/framework/create_op_creation_methods.py  +3 −0
python/paddle/v2/framework/network.py  +124 −0
python/paddle/v2/framework/tests/CMakeLists.txt  +3 −2
python/paddle/v2/framework/tests/test_net.py  +2 −2
python/paddle/v2/framework/tests/test_network.py  +32 −0
python/paddle/v2/framework/tests/test_recurrent_op.py  +92 −0
python/paddle/v2/inference.py  +7 −0
python/paddle/v2/layer.py  +0 −3
python/paddle/v2/master/client.py  +0 −1
.pre-commit-config.yaml
@@ -22,9 +22,11 @@
   hooks:
   - id: clang-formater
 - repo: https://github.com/PaddlePaddle/pre-commit-golang
-  sha: 16398aeccf263adaf53b2495eed0406347d76281
+  sha: 8337620115c25ff8333f1b1a493bd031049bd7c0
   hooks:
   - id: go-fmt
-    types: [go]
+    types:
+    - go
   - id: gometalinter
-    types: [go]
+    types:
+    - go
cmake/external/eigen.cmake
@@ -7,17 +7,8 @@ INCLUDE_DIRECTORIES(${EIGEN_SOURCE_DIR}/src/extern_eigen3)
 ExternalProject_Add(
     extern_eigen3
     ${EXTERNAL_PROJECT_LOG_ARGS}
-    # for latest version, please get from official website
-    # URL "https://bitbucket.org/eigen/eigen/get/3.3.4.tar.gz"
-    # URL_MD5 "1a47e78efe365a97de0c022d127607c3"
-    # for no-ssl http support, please get from bazel's mirror
-    # URL "http://mirror.bazel.build/bitbucket.org/eigen/eigen/get/f3a22f35b044.tar.gz"
-    # URL_MD5 "4645c66075982da6fa0bcf6b20f3e8f7"
-    # get from github mirror
     GIT_REPOSITORY  "https://github.com/RLovelett/eigen.git"
-    GIT_TAG         "a46d2e7337c4656f00abe54a8115f6d76153a048"
+    GIT_TAG         "master"
     PREFIX          ${EIGEN_SOURCE_DIR}
     UPDATE_COMMAND  ""
     CONFIGURE_COMMAND ""
...
cmake/flags.cmake
@@ -153,7 +153,7 @@ set(CUDA_PROPAGATE_HOST_FLAGS OFF)
 # Release/Debug flags set by cmake. Such as -O3 -g -DNDEBUG etc.
 # So, don't set these flags here.
-LIST(APPEND CUDA_NVCC_FLAGS -std=c++11)
+LIST(APPEND CUDA_NVCC_FLAGS -std=c++11 --default-stream per-thread)
 LIST(APPEND CUDA_NVCC_FLAGS --use_fast_math)
 if(CMAKE_BUILD_TYPE STREQUAL "Debug")
...
go/cmd/master/master.go
@@ -19,6 +19,8 @@ import (
     "net"
     "net/http"
     "net/rpc"
+    "os"
+    "os/signal"
     "strconv"
     "strings"
     "time"
...
@@ -68,6 +70,20 @@ func main() {
         store = &master.InMemStore{}
     }
 
+    shutdown := func() {
+        log.Infoln("shutting down gracefully")
+        err := store.Shutdown()
+        if err != nil {
+            log.Errorln(err)
+        }
+    }
+
+    // Guaranteed to run even panic happens.
+    defer shutdown()
+
+    c := make(chan os.Signal, 1)
+    signal.Notify(c, os.Interrupt)
+
     s, err := master.NewService(store, *chunkPerTask, *taskTimeoutDur, *taskTimeoutMax)
     if err != nil {
         log.Fatal(err)
...
@@ -84,8 +100,12 @@ func main() {
         log.Fatal(err)
     }
 
-    err = http.Serve(l, nil)
-    if err != nil {
-        log.Fatal(err)
-    }
+    go func() {
+        err = http.Serve(l, nil)
+        if err != nil {
+            log.Fatal(err)
+        }
+    }()
+
+    <-c
 }
go/cmd/pserver/pserver.go
@@ -18,6 +18,8 @@ import (
     "net"
     "net/http"
     "net/rpc"
+    "os"
+    "os/signal"
     "strconv"
     "time"
...
@@ -33,7 +35,8 @@ func main() {
     index := flag.Int("index", -1, "index of this pserver, should be larger or equal than 0")
     etcdEndpoint := flag.String("etcd-endpoint", "http://127.0.0.1:2379",
         "comma separated endpoint string for pserver to connect to etcd")
-    etcdTimeout := flag.Duration("etcd-timeout", 5*time.Second, "timeout for etcd calls")
+    dialTimeout := flag.Duration("dial-timeout", 5*time.Second, "dial timeout")
+    etcdTTL := flag.Int("etcd-ttl", 5, "etcd time to live in seconds")
     numPservers := flag.Int("num-pservers", 1, "total pserver count in a training job")
     checkpointPath := flag.String("checkpoint-path", "/checkpoints/", "save checkpoint path")
     checkpointInterval := flag.Duration("checkpoint-interval", 600*time.Second, "save checkpoint per interval seconds")
...
@@ -53,7 +56,7 @@ func main() {
     if *index >= 0 {
         idx = *index
     } else {
-        e = pserver.NewEtcdClient(*etcdEndpoint, *numPservers, *etcdTimeout)
+        e = pserver.NewEtcdClient(*etcdEndpoint, *numPservers, *dialTimeout, *etcdTTL)
         idx, err = e.Register(*port)
         candy.Must(err)
...
@@ -67,6 +70,20 @@ func main() {
         }
     }
 
+    shutdown := func() {
+        log.Infoln("shutting down gracefully")
+        sErr := e.Shutdown()
+        if sErr != nil {
+            log.Errorln(sErr)
+        }
+    }
+
+    // Guaranteed to run even panic happens.
+    defer shutdown()
+
+    c := make(chan os.Signal, 1)
+    signal.Notify(c, os.Interrupt)
+
     s, err := pserver.NewService(idx, *checkpointInterval, *checkpointPath, e, cp)
     candy.Must(err)
...
@@ -77,7 +94,11 @@ func main() {
     l, err := net.Listen("tcp", ":"+strconv.Itoa(*port))
     candy.Must(err)
 
-    log.Infof("start pserver at port %d", *port)
-    err = http.Serve(l, nil)
-    candy.Must(err)
+    go func() {
+        log.Infof("start pserver at port %d", *port)
+        err = http.Serve(l, nil)
+        candy.Must(err)
+    }()
+
+    <-c
 }
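Both commands above now follow the same shutdown pattern: serve RPC in a background goroutine, block the main goroutine on an os.Interrupt signal, and release resources from a deferred shutdown closure. The following is a minimal, self-contained sketch of that pattern, not the actual Paddle code; the Closer interface and ServeUntilInterrupt name are hypothetical stand-ins for the store/etcd types used in the diff.

// Sketch of the serve-then-wait-for-interrupt pattern used by master and pserver.
package grace

import (
    "log"
    "net"
    "net/http"
    "os"
    "os/signal"
)

// Closer stands in for whatever must be released on exit (an etcd session,
// an in-memory store, ...). Hypothetical interface, for illustration only.
type Closer interface {
    Shutdown() error
}

// ServeUntilInterrupt serves HTTP/RPC traffic in the background and blocks
// until the process receives os.Interrupt. The deferred shutdown closure
// also runs if the code below panics.
func ServeUntilInterrupt(l net.Listener, res Closer) {
    shutdown := func() {
        log.Println("shutting down gracefully")
        if err := res.Shutdown(); err != nil {
            log.Println(err)
        }
    }
    defer shutdown()

    c := make(chan os.Signal, 1)
    signal.Notify(c, os.Interrupt)

    go func() {
        if err := http.Serve(l, nil); err != nil {
            log.Fatal(err)
        }
    }()

    <-c // returning after this lets the deferred shutdown run
}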
go/glide.lock
-hash: a8faea3a363468a88917ddeb3b1c9ea36886fb2c622acbad42604fa9cb4d3855
+hash: 2a1c0eca5c07a130e3d224f9821f96cfa37a39bf6bce141c855bbc57ef569f1c
-updated: 2017-07-11T10:04:40.786745417+08:00
+updated: 2017-07-29T07:34:48.722757905+08:00
 imports:
+- name: github.com/beorn7/perks
+  version: 4c0e84591b9aa9e6dcfdf3e020114cd81f89d5f9
+  subpackages:
+  - quantile
+- name: github.com/boltdb/bolt
+  version: 583e8937c61f1af6513608ccc75c97b6abdf4ff9
+- name: github.com/cockroachdb/cmux
+  version: 112f0506e7743d64a6eb8fedbcff13d9979bbf92
 - name: github.com/coreos/etcd
-  version: cb2a496c4ddd1c87a9f280e116649b599999ec79
+  version: c31bec0f29facff13f7c3e3d948e55dd6689ed42
   subpackages:
+  - alarm
+  - auth
   - auth/authpb
+  - client
   - clientv3
   - clientv3/concurrency
+  - compactor
+  - discovery
+  - embed
+  - error
+  - etcdserver
+  - etcdserver/api
+  - etcdserver/api/v2http
+  - etcdserver/api/v2http/httptypes
+  - etcdserver/api/v3client
+  - etcdserver/api/v3election
+  - etcdserver/api/v3election/v3electionpb
+  - etcdserver/api/v3election/v3electionpb/gw
+  - etcdserver/api/v3lock
+  - etcdserver/api/v3lock/v3lockpb
+  - etcdserver/api/v3lock/v3lockpb/gw
+  - etcdserver/api/v3rpc
   - etcdserver/api/v3rpc/rpctypes
+  - etcdserver/auth
   - etcdserver/etcdserverpb
+  - etcdserver/etcdserverpb/gw
+  - etcdserver/membership
+  - etcdserver/stats
+  - lease
+  - lease/leasehttp
+  - lease/leasepb
+  - mvcc
+  - mvcc/backend
   - mvcc/mvccpb
+  - pkg/adt
+  - pkg/contention
+  - pkg/cors
+  - pkg/cpuutil
+  - pkg/crc
+  - pkg/debugutil
+  - pkg/fileutil
+  - pkg/httputil
+  - pkg/idutil
+  - pkg/ioutil
+  - pkg/logutil
+  - pkg/monotime
+  - pkg/netutil
+  - pkg/pathutil
+  - pkg/pbutil
+  - pkg/runtime
+  - pkg/schedule
+  - pkg/srv
+  - pkg/tlsutil
+  - pkg/transport
+  - pkg/types
+  - pkg/wait
+  - proxy/grpcproxy/adapter
+  - raft
+  - raft/raftpb
+  - rafthttp
+  - snap
+  - snap/snappb
+  - store
+  - version
+  - wal
+  - wal/walpb
+- name: github.com/coreos/go-semver
+  version: 8ab6407b697782a06568d4b7f1db25550ec2e4c6
+  subpackages:
+  - semver
+- name: github.com/coreos/go-systemd
+  version: 48702e0da86bd25e76cfef347e2adeb434a0d0a6
+  subpackages:
+  - daemon
+  - journal
+  - util
+- name: github.com/coreos/pkg
+  version: 3ac0863d7acf3bc44daf49afef8919af12f704ef
+  subpackages:
+  - capnslog
+- name: github.com/dgrijalva/jwt-go
+  version: d2709f9f1f31ebcda9651b03077758c1f3a0018c
+- name: github.com/ghodss/yaml
+  version: 0ca9ea5df5451ffdf184b4428c902747c2c11cd7
+- name: github.com/gogo/protobuf
+  version: 909568be09de550ed094403c2bf8a261b5bb730a
+  subpackages:
+  - proto
 - name: github.com/golang/protobuf
   version: 4bd1920723d7b7c925de087aa32e2187708897f7
   subpackages:
...
@@ -17,14 +107,61 @@ imports:
   - proto
 - name: github.com/golang/snappy
   version: 553a641470496b2327abcac10b36396bd98e45c9
+- name: github.com/google/btree
+  version: 925471ac9e2131377a91e1595defec898166fe49
+- name: github.com/grpc-ecosystem/go-grpc-prometheus
+  version: 6b7015e65d366bf3f19b2b2a000a831940f0f7e0
+- name: github.com/grpc-ecosystem/grpc-gateway
+  version: 18d159699f2e83fc5bb9ef2f79465ca3f3122676
+  subpackages:
+  - runtime
+  - runtime/internal
+  - utilities
+- name: github.com/jonboulle/clockwork
+  version: 2eee05ed794112d45db504eb05aa693efd2b8b09
+- name: github.com/matttproud/golang_protobuf_extensions
+  version: c12348ce28de40eed0136aa2b644d0ee0650e56c
+  subpackages:
+  - pbutil
 - name: github.com/namsral/flag
   version: 71ceffbeb0ba60fccc853971bb3ed4d7d90bfd04
 - name: github.com/PaddlePaddle/recordio
-  version: edfb82af0739c84f241c87390ec5649c7b28c129
+  version: 0432dee9fd4b24fb6840fb20a8c055b0c933fb81
+- name: github.com/prometheus/client_golang
+  version: c5b7fccd204277076155f10851dad72b76a49317
+  subpackages:
+  - prometheus
+- name: github.com/prometheus/client_model
+  version: 6f3806018612930941127f2a7c6c453ba2c527d2
+  subpackages:
+  - go
+- name: github.com/prometheus/common
+  version: 49fee292b27bfff7f354ee0f64e1bc4850462edf
+  subpackages:
+  - expfmt
+  - internal/bitbucket.org/ww/goautoneg
+  - model
+- name: github.com/prometheus/procfs
+  version: a1dba9ce8baed984a2495b658c82687f8157b98f
+  subpackages:
+  - xfs
 - name: github.com/sirupsen/logrus
-  version: 7f976d3a76720c4c27af2ba716b85d2e0a7e38b1
+  version: a3f95b5c423586578a4e099b11a46c2479628cac
 - name: github.com/topicai/candy
   version: 1b9030d056fa9f8c4b1f9c91b52fe4b8ab4cd8cc
+- name: github.com/ugorji/go
+  version: ded73eae5db7e7a0ef6f55aace87a2873c5d2b74
+  subpackages:
+  - codec
+- name: github.com/xiang90/probing
+  version: 07dd2e8dfe18522e9c447ba95f2fe95262f63bb2
+- name: golang.org/x/crypto
+  version: 1351f936d976c60a0a48d728281922cf63eafb8d
+  repo: https://github.com/golang/crypto.git
+  vcs: git
+  subpackages:
+  - bcrypt
+  - blowfish
 - name: golang.org/x/net
   version: c8c74377599bd978aee1cf3b9b63a8634051cec2
   subpackages:
...
@@ -36,11 +173,15 @@ imports:
   - lex/httplex
   - trace
 - name: golang.org/x/sys
-  version: abf9c25f54453410d0c6668e519582a9e1115027
+  version: 0f826bdd13b500be0f1d4004938ad978fcc6031e
+  repo: https://github.com/golang/sys.git
+  vcs: git
   subpackages:
   - unix
 - name: golang.org/x/text
-  version: cfdf022e86b4ecfb646e1efbd7db175dd623a8fa
+  version: 836efe42bb4aa16aaa17b9c155d8813d336ed720
+  repo: https://github.com/golang/text.git
+  vcs: git
   subpackages:
   - secure/bidirule
   - transform
...
@@ -60,4 +201,23 @@ imports:
   - stats
   - tap
   - transport
-testImports: []
+- name: gopkg.in/yaml.v2
+  version: cd8b52f8269e0feb286dfeef29f8fe4d5b397e0b
+testImports:
+- name: github.com/davecgh/go-spew
+  version: 04cdfd42973bb9c8589fd6a731800cf222fde1a9
+  subpackages:
+  - spew
+- name: github.com/docker/docker
+  version: b6d164e6c46d8115b146e4c3ac93784e9ef8b49e
+  subpackages:
+  - pkg/ioutils
+  - pkg/longpath
+- name: github.com/pmezard/go-difflib
+  version: d8ed2627bdf02c080bf22230dbb337003b7aba2d
+  subpackages:
+  - difflib
+- name: github.com/stretchr/testify
+  version: 05e8a0eda380579888eb53c394909df027f06991
+  subpackages:
+  - assert
go/glide.yaml
@@ -6,8 +6,19 @@ import:
   subpackages:
   - clientv3
   - clientv3/concurrency
+  - embed
+  - etcdserver
 - package: github.com/namsral/flag
   version: ^1.7.4-pre
 - package: github.com/sirupsen/logrus
   version: ^1.0.0
 - package: github.com/topicai/candy
+- package: golang.org/x/crypto
+  vcs: git
+  repo: https://github.com/golang/crypto.git
+- package: golang.org/x/sys
+  vcs: git
+  repo: https://github.com/golang/sys.git
+- package: golang.org/x/text
+  vcs: git
+  repo: https://github.com/golang/text.git
go/master/c/client.go
@@ -18,7 +18,6 @@ package main
 #include <stdlib.h>
 #include <string.h>
 #include <stdio.h>
 #define PADDLE_MASTER_OK 0
 #define PADDLE_MASTER_ERROR -1
...
@@ -101,6 +100,12 @@ func paddle_release_master_client(client C.paddle_master_client) {
     remove(client)
 }
 
+//export paddle_start_get_records
+func paddle_start_get_records(client C.paddle_master_client, pass C.int) {
+    c := get(client)
+    c.StartGetRecords(int(pass))
+}
+
 //export paddle_set_dataset
 func paddle_set_dataset(client C.paddle_master_client, path **C.char, size C.int) C.int {
     c := get(client)
...
@@ -121,15 +126,19 @@ func paddle_set_dataset(client C.paddle_master_client, path **C.char, size C.int
 // paddle_next_record gets the nexts training record.
 //
-// returns number of bytes of the records if success, -1 if failed.
+// returns number of bytes of the records if success, -1 if failed, -2 if pass end.
 //
 //export paddle_next_record
 func paddle_next_record(client C.paddle_master_client, record **C.uchar) C.int {
     c := get(client)
     r, err := c.NextRecord()
     if err != nil {
-        // Error
-        // TODO: return the type of error?
+        // NOTE: use errors to indicate pass ends
+        if err.Error() == master.ErrAllTaskFailed.Error() ||
+            err.Error() == master.ErrNoMoreAvailable.Error() ||
+            err.Error() == master.ErrPassBefore.Error() {
+            return -2
+        }
         *record = (*C.uchar)(nil)
         return -1
     }
...
go/master/client.go
@@ -16,7 +16,6 @@ package master
 import (
     "os"
-    "sync"
     "time"
 
     "github.com/PaddlePaddle/Paddle/go/connection"
...
@@ -29,7 +28,7 @@ import (
 type Client struct {
     conn *connection.Conn
     ch   chan record
-    initChOnce sync.Once
+    bufSize int
 }
 
 type record struct {
...
@@ -46,11 +45,7 @@ func WithBuffer(bufSize int) func(*Client) error {
         if bufSize <= 0 {
             return nil
         }
-        c.initChOnce.Do(func() {
-            c.ch = make(chan record, bufSize)
-            go c.getRecords()
-        })
+        c.bufSize = bufSize
         return nil
     }
 }
...
@@ -104,25 +99,41 @@ func NewClient(opts ...func(*Client) error) (*Client, error) {
         if err != nil {
             return nil, err
         }
     }
+    c.ch = make(chan record, c.bufSize)
+    // FIXME: connection is created asyncrosly in monitorMaster go routine,
+    // ensure the connection is ready for use before calling c.addClient.
+    time.Sleep(time.Second)
     return c, nil
 }
 
-func (c *Client) getRecords() {
+// StartGetRecords must be called at beginning of each pass
+func (c *Client) StartGetRecords(passID int) {
+    go c.getRecords(passID)
+}
+
+func (c *Client) getRecords(passID int) {
     for {
-        t, err := c.getTask()
+        t, err := c.getTask(passID)
         if err != nil {
-            log.Errorf("Get task failed, sleep 3 seconds and continue, %s", err)
-            time.Sleep(3 * time.Second)
-            continue
+            if err.Error() == ErrPassBefore.Error() ||
+                err.Error() == ErrNoMoreAvailable.Error() ||
+                err.Error() == ErrAllTaskFailed.Error() {
+                c.ch <- record{nil, err}
+                break
+            }
+            if err.Error() == ErrPassAfter.Error() {
+                // wait util last pass finishes
+                time.Sleep(time.Second * 3)
+                continue
+            }
+            log.Errorf("getTask error: %s", err)
         }
 
         for _, chunk := range t.Chunks {
-            f, err := os.Open(chunk.Path)
-            if err != nil {
-                log.Errorln(err)
+            f, e := os.Open(chunk.Path)
+            if e != nil {
+                log.Errorln(e)
                 continue
             }
...
@@ -178,18 +189,21 @@ func (c *Client) monitorMaster(addrCh <-chan string) {
     }
 }
 
-// SetDataset set dataset for the master server to dispatch.
+// SetDataset sets dataset to dispatch for the master server.
+//
+// SetDataset can be call multiple times at one pass. But only the first call
+// will be honored.
 //
-// SetDataset can be call multiple times from different nodes. But
-// only the first call will be honored.
+// After all tasks are done, another call of SetDataset will start another pass.
 func (c *Client) SetDataset(globPaths []string) error {
-    return c.conn.Call("Service.SetDataset", globPaths, nil)
+    err := c.conn.Call("Service.SetDataset", globPaths, nil)
+    return err
 }
 
 // getTask gets a new task from the master server.
-func (c *Client) getTask() (Task, error) {
+func (c *Client) getTask(passID int) (Task, error) {
     var t Task
-    err := c.conn.Call("Service.GetTask", 0, &t)
+    err := c.conn.Call("Service.GetTask", passID, &t)
     return t, err
 }
...
@@ -208,12 +222,6 @@ func (c *Client) taskFailed(meta TaskMeta) error {
 // NextRecord will block until the next record is available. It is
 // thread-safe.
 func (c *Client) NextRecord() ([]byte, error) {
-    c.initChOnce.Do(func() {
-        // initialize with in case WithBuffer is not used.
-        c.ch = make(chan record, 0)
-        go c.getRecords()
-    })
-
     r := <-c.ch
     return r.r, r.err
 }
...
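With this change a reader has to start each pass explicitly and treat the pass-boundary sentinel errors as end-of-pass markers rather than failures. Below is a hedged usage sketch, not code from the commit; the readAllPasses name, the numPasses count and the process callback are illustrative assumptions, while StartGetRecords, NextRecord, ErrPassBefore and ErrNoMoreAvailable come from the diff above.

// Sketch of driving the new per-pass client API.
package example

import (
    "log"

    "github.com/PaddlePaddle/Paddle/go/master"
)

// readAllPasses drains numPasses passes from an already-connected client,
// handing each record to process (both parameters are assumptions).
func readAllPasses(c *master.Client, numPasses int, process func([]byte)) {
    for pass := 0; pass < numPasses; pass++ {
        c.StartGetRecords(pass) // must be called at the beginning of each pass
        for {
            r, err := c.NextRecord()
            if err != nil {
                // Pass-boundary sentinels mean "this pass is drained", not a failure.
                if err.Error() == master.ErrPassBefore.Error() ||
                    err.Error() == master.ErrNoMoreAvailable.Error() {
                    break
                }
                log.Println("read error:", err)
                return
            }
            process(r)
        }
    }
}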
go/master/client_internal_test.go
@@ -54,22 +54,22 @@ func TestGetFinishTask(t *testing.T) {
         panic(err)
     }
     go func(l net.Listener) {
-        s, err := NewService(&InMemStore{}, chunkPerTask, time.Second, 1)
-        if err != nil {
-            panic(err)
+        s, sErr := NewService(&InMemStore{}, chunkPerTask, time.Second, 1)
+        if sErr != nil {
+            panic(sErr)
         }
 
         server := rpc.NewServer()
-        err = server.Register(s)
-        if err != nil {
-            panic(err)
+        sErr = server.Register(s)
+        if sErr != nil {
+            panic(sErr)
         }
 
         mux := http.NewServeMux()
         mux.Handle(rpc.DefaultRPCPath, server)
-        err = http.Serve(l, mux)
-        if err != nil {
-            panic(err)
+        sErr = http.Serve(l, mux)
+        if sErr != nil {
+            panic(sErr)
         }
     }(l)
...
@@ -103,6 +103,7 @@ func TestGetFinishTask(t *testing.T) {
     ch := make(chan string, 1)
     ch <- addr
     go c.monitorMaster(ch)
     err = c.SetDataset([]string{path})
     if err != nil {
         panic(err)
...
@@ -111,44 +112,47 @@ func TestGetFinishTask(t *testing.T) {
     checkOnePass := func(i int) {
         var tasks []Task
         for idx := 0; idx < totalTask; idx++ {
-            task, err := c.getTask()
-            if err != nil {
-                t.Fatalf("Error: %v, pass: %d\n", err, i)
+            task, cErr := c.getTask(i)
+            if cErr != nil && cErr.Error() != ErrNoMoreAvailable.Error() && cErr.Error() != ErrPassAfter.Error() {
+                t.Fatalf("error: %v, pass: %d\n", cErr, i)
             }
             tasks = append(tasks, task)
         }
 
-        _, err = c.getTask()
-        if err == nil {
+        // getting task before task finishes should return error
+        _, cErr := c.getTask(i)
+        if cErr == nil {
             t.Fatalf("Should get error, pass: %d\n", i)
         }
 
-        err = c.taskFinished(tasks[0].Meta.ID)
-        if err != nil {
-            t.Fatalf("Error: %v, pass: %d\n", err, i)
+        cErr = c.taskFinished(tasks[0].Meta.ID)
+        if cErr != nil {
+            t.Fatalf("Error: %v, pass: %d\n", cErr, i)
         }
-
-        err = c.taskFailed(tasks[0].Meta)
-        if err != nil {
-            t.Fatalf("Error: %v, pass: %d\n", err, i)
+        // call taskFailed once won't put the task to failed queue, just ensure
+        // the call
+        cErr = c.taskFailed(tasks[0].Meta)
+        if cErr != nil {
+            t.Fatalf("Error: %v, pass: %d\n", cErr, i)
         }
 
         tasks = tasks[1:]
-        task, err := c.getTask()
-        if err != nil {
-            t.Fatal(err)
+        _, cErr = c.getTask(i)
+        if cErr != nil && cErr.Error() != ErrNoMoreAvailable.Error() && cErr.Error() != ErrPassAfter.Error() {
+            t.Fatalf("Should be ErrNoMoreAvailable or ErrPassAfter: %s", cErr)
         }
-        tasks = append(tasks, task)
 
         for _, task := range tasks {
-            err = c.taskFinished(task.Meta.ID)
-            if err != nil {
-                t.Fatalf("Error: %v, pass: %d\n", err, i)
+            cErr = c.taskFinished(task.Meta.ID)
+            if cErr != nil {
+                t.Fatal(cErr)
             }
         }
     }
 
     for i := 0; i < 10; i++ {
+        // init pass data
+        c.StartGetRecords(i)
         checkOnePass(i)
     }
 }
go/master/client_test.go
@@ -20,8 +20,10 @@ import (
     "net/http"
     "net/rpc"
     "os"
+    "runtime"
     "strconv"
     "strings"
+    "sync"
     "testing"
     "time"
...
@@ -29,6 +31,18 @@ import (
     "github.com/PaddlePaddle/recordio"
 )
 
+// tool function for testing output goroutine ids
+func goid() int {
+    var buf [64]byte
+    n := runtime.Stack(buf[:], false)
+    idField := strings.Fields(strings.TrimPrefix(string(buf[:n]), "goroutine "))[0]
+    id, err := strconv.Atoi(idField)
+    if err != nil {
+        panic(fmt.Sprintf("cannot get goroutine id: %v", err))
+    }
+    return id
+}
+
 func TestNextRecord(t *testing.T) {
     const (
         path = "/tmp/master_client_TestFull"
...
@@ -45,7 +59,7 @@ func TestNextRecord(t *testing.T) {
         panic(err)
     }
     go func(l net.Listener) {
-        s, err := master.NewService(&master.InMemStore{}, 10, time.Second, 1)
+        s, err := master.NewService(&master.InMemStore{}, 1, time.Second*60, 1)
         if err != nil {
             panic(err)
         }
...
@@ -69,7 +83,7 @@ func TestNextRecord(t *testing.T) {
         panic(err)
     }
-    w := recordio.NewWriter(f, -1, -1)
+    w := recordio.NewWriter(f, 1, -1)
     for i := 0; i < total; i++ {
         _, err = w.Write([]byte{byte(i)})
         if err != nil {
...
@@ -87,32 +101,49 @@ func TestNextRecord(t *testing.T) {
         panic(err)
     }
 
-    c, err := master.NewClient(master.WithAddr(fmt.Sprintf(":%d", p)), master.WithBuffer(10))
-    if err != nil {
-        panic(err)
-    }
-
-    err = c.SetDataset([]string{path})
-    if err != nil {
-        panic(err)
-    }
-
-    for pass := 0; pass < 50; pass++ {
-        received := make(map[byte]bool)
-        for i := 0; i < total; i++ {
-            r, err := c.NextRecord()
-            if err != nil {
-                t.Fatal(pass, i, "Read error:", err)
-            }
-            if len(r) != 1 {
-                t.Fatal(pass, i, "Length should be 1.", r)
-            }
-            if received[r[0]] {
-                t.Fatal(pass, i, "Received duplicate.", received, r)
-            }
-            received[r[0]] = true
-        }
-    }
+    // start several client to test task fetching
+    var wg sync.WaitGroup
+    for i := 0; i < 4; i++ {
+        wg.Add(1)
+        // test for multiple concurrent clients
+        go func() {
+            defer wg.Done()
+            // each go-routine needs a single client connection instance
+            c, e := master.NewClient(master.WithAddr(fmt.Sprintf(":%d", p)), master.WithBuffer(1))
+            if e != nil {
+                t.Fatal(e)
+            }
+            e = c.SetDataset([]string{path})
+            if e != nil {
+                panic(e)
+            }
+            // test for n passes
+            for pass := 0; pass < 10; pass++ {
+                c.StartGetRecords(pass)
+
+                received := make(map[byte]bool)
+                taskid := 0
+                for {
+                    r, e := c.NextRecord()
+                    if e != nil {
+                        // ErrorPassAfter will wait, else break for next pass
+                        if e.Error() == master.ErrPassBefore.Error() ||
+                            e.Error() == master.ErrNoMoreAvailable.Error() {
+                            break
+                        }
+                        t.Fatal(pass, taskid, "Read error:", e)
+                    }
+                    if len(r) != 1 {
+                        t.Fatal(pass, taskid, "Length should be 1.", r)
+                    }
+                    if received[r[0]] {
+                        t.Fatal(pass, taskid, "Received duplicate.", received, r)
+                    }
+                    taskid++
+                    received[r[0]] = true
+                }
+            }
+        }()
+    }
+    wg.Wait()
 }
go/master/etcd_client.go
@@ -39,15 +39,12 @@ type EtcdClient struct {
     statePath string
     client    *clientv3.Client
     lock      *concurrency.Mutex
+    sess      *concurrency.Session
 }
 
 // NewEtcdClient creates a new EtcdClient.
 func NewEtcdClient(endpoints []string, addr string, lockPath, addrPath, statePath string, ttlSec int) (*EtcdClient, error) {
     log.Debugf("Connecting to etcd at %v", endpoints)
-    // TODO(helin): gracefully shutdown etcd store. Because etcd
-    // store holds a etcd lock, even though the lock will expire
-    // when the lease timeout, we need to implement graceful
-    // shutdown to release the lock.
     cli, err := clientv3.New(clientv3.Config{
         Endpoints:   endpoints,
         DialTimeout: dialTimeout,
...
@@ -67,12 +64,12 @@ func NewEtcdClient(endpoints []string, addr string, lockPath, addrPath, statePat
     // one master running, but split-brain problem may cause
     // multiple master servers running), and the cluster management
     // software will kill one of them.
-    log.Debugf("Trying to acquire lock at %s.", lockPath)
+    log.Infof("Trying to acquire lock at %s.", lockPath)
     err = lock.Lock(context.TODO())
     if err != nil {
         return nil, err
     }
-    log.Debugf("Successfully acquired lock at %s.", lockPath)
+    log.Infof("Successfully acquired lock at %s.", lockPath)
 
     put := clientv3.OpPut(addrPath, addr)
     resp, err := cli.Txn(context.Background()).If(lock.IsOwner()).Then(put).Commit()
...
@@ -89,6 +86,7 @@ func NewEtcdClient(endpoints []string, addr string, lockPath, addrPath, statePat
         statePath: statePath,
         client:    cli,
         lock:      lock,
+        sess:      sess,
     }
 
     return e, nil
...
@@ -157,6 +155,21 @@ func (e *EtcdClient) Load() ([]byte, error) {
     return state, nil
 }
 
+// Shutdown shuts down the etcd client gracefully.
+func (e *EtcdClient) Shutdown() error {
+    err := e.sess.Close()
+    newErr := e.client.Close()
+    if newErr != nil {
+        if err == nil {
+            err = newErr
+        } else {
+            log.Errorln(newErr)
+        }
+    }
+    return err
+}
+
 // GetKey gets the value by the specify key.
 func GetKey(c *clientv3.Client, key string, timeout time.Duration) (string, error) {
     ctx, cancel := context.WithTimeout(context.Background(), timeout)
...
go/master/inmem_store.go
@@ -40,3 +40,8 @@ func (m *InMemStore) Load() ([]byte, error) {
     return m.buf, nil
 }
+
+// Shutdown shuts down the in mem store.
+func (m *InMemStore) Shutdown() error {
+    return nil
+}
go/master/service.go
@@ -19,6 +19,7 @@ import (
     "compress/gzip"
     "encoding/gob"
     "errors"
+    "math/rand"
     "os"
     "path/filepath"
     "sync"
...
@@ -33,10 +34,23 @@ const (
     dialTimeout = 5 * time.Second
 )
 
+// ErrAllTaskFailed occur when tasks are in done or failed state.
+var ErrAllTaskFailed = errors.New("all task finished")
+
+// ErrNoMoreAvailable occur when no task in todo and yet not all done or fail.
+var ErrNoMoreAvailable = errors.New("no more available task")
+
+// ErrPassBefore client side pass number does not match with master counter.
+var ErrPassBefore = errors.New("pass number smaller than master")
+
+// ErrPassAfter client side pass number does not match with master counter.
+var ErrPassAfter = errors.New("pass number larger than master")
+
 // Store is the interface for save and load the master state.
 type Store interface {
     Save([]byte) error
     Load() ([]byte, error)
+    Shutdown() error
 }
 
 // Chunk is a chunk of data consisted of several data instances.
...
@@ -75,17 +89,26 @@ type Service struct {
     chunksPerTask int
     timeoutDur    time.Duration
     failureMax    int
-    ready         chan struct{}
     store         Store
 
-    mu         sync.Mutex
-    initDone   bool
-    taskQueues taskQueues
+    ready    chan struct{}
+    initDone bool
+
+    mu         sync.Mutex
+    taskQueues taskQueues
+    currPass   int
+    jobTasks   []taskEntry
 
     savingTrainer string
 }
 
 func partition(chunks []Chunk, chunksPerTask int) []taskEntry {
-    id := 0
+    // generate uniq id across job using nanosecond + randint + counter
+    // FIXME(typhoonzero): this is a workaround, use uuid
+    randStart := rand.Int()
+    counter := 0
+    timestamp := time.Now().Nanosecond()
+    id := timestamp + randStart + counter
     if chunksPerTask <= 0 {
         chunksPerTask = 1
     }
...
@@ -95,7 +118,8 @@ func partition(chunks []Chunk, chunksPerTask int) []taskEntry {
     for i, c := range chunks {
         if i%chunksPerTask == 0 && len(cur.Task.Chunks) > 0 {
             cur.Task.Meta.ID = id
-            id++
+            counter++
+            id = timestamp + randStart + counter
             result = append(result, cur)
             cur.Task.Chunks = nil
         }
...
@@ -266,19 +290,21 @@ func (s *Service) SetDataset(globPaths []string, _ *int) error {
         return err
     }
 
-    s.taskQueues.Todo = partition(chunks, s.chunksPerTask)
+    s.jobTasks = partition(chunks, s.chunksPerTask)
+    s.taskQueues.Todo = s.jobTasks
 
     err = s.snapshot()
     if err != nil {
         log.Errorln(err)
         return err
     }
 
     close(s.ready)
     s.initDone = true
     return nil
 }
 
+// processFailedTask retry s.failureMax times for failed task.
+// return true if all task are done or failed.
 func (s *Service) processFailedTask(t taskEntry, epoch int) {
     if t.Task.Meta.Epoch != epoch {
         // new epoch, task launched after the
...
@@ -302,8 +328,9 @@ func (s *Service) processFailedTask(t taskEntry, epoch int) {
         return
     }
 
-    log.Warningf("Task %v failed %d times, discard.", t.Task, t.NumFailure)
+    log.Warningf("Task %v failed %d times, re-dispatch.", t.Task, t.NumFailure)
     s.taskQueues.Todo = append(s.taskQueues.Todo, t)
     return
 }
 
 func (s *Service) checkTimeoutFunc(taskID int, epoch int) func() {
...
@@ -331,37 +358,30 @@ func (s *Service) logFields() log.Fields {
 }
 
 // GetTask gets a new task from the service.
-func (s *Service) GetTask(_ int, task *Task) error {
+// passID is the client side pass count
+func (s *Service) GetTask(passID int, task *Task) error {
     select {
     case <-s.ready:
     }
 
     s.mu.Lock()
     defer s.mu.Unlock()
 
+    if passID < s.currPass {
+        return ErrPassBefore
+    }
+    if passID > s.currPass {
+        // Client may get run to pass after master when one client faster than the
+        // other
+        return ErrPassAfter
+    }
+
     if len(s.taskQueues.Todo) == 0 {
-        if len(s.taskQueues.Done) == 0 {
-            if len(s.taskQueues.Pending) == 0 {
-                err := errors.New("all task failed")
-                log.WithFields(s.logFields()).Warningln("All tasks failed.")
-                return err
-            }
-
-            // TODO(helin): client need to retry in this
-            // error case. Gotcha: RPC client can't
-            // compare returned error with predefined
-            // errors like io.EOF, because the error
-            // instance deserialized from RPC is a
-            // different instance than the error defined
-            // in package. So we need to figure out a way
-            // for client to check this error correctly.
-            err := errors.New("no more available task")
-            log.WithFields(s.logFields()).Warningln("No more available task.")
-            return err
-        }
-        s.taskQueues.Todo = s.taskQueues.Done
-        s.taskQueues.Done = nil
-        log.WithFields(s.logFields()).Infoln("No more todo task, but trainer is requesting task to do. Move all done task to todo.")
+        if len(s.taskQueues.Done) == 0 && len(s.taskQueues.Pending) == 0 {
+            log.WithFields(s.logFields()).Warningln("All tasks failed, may start next pass")
+            return ErrAllTaskFailed
+        }
+        log.WithFields(s.logFields()).Warningln("No more available task.")
+        return ErrNoMoreAvailable
     }
 
     t := s.taskQueues.Todo[0]
...
@@ -381,7 +401,7 @@ func (s *Service) GetTask(_ int, task *Task) error {
 }
 
 // TaskFinished tell the service that a task is finished.
-func (s *Service) TaskFinished(taskID int, _ *int) error {
+func (s *Service) TaskFinished(taskID int, dummy *int) error {
     select {
     case <-s.ready:
     }
...
@@ -401,11 +421,14 @@ func (s *Service) TaskFinished(taskID int, _ *int) error {
     delete(s.taskQueues.Pending, taskID)
     log.WithFields(s.logFields()).Infof("Task #%d finished.", taskID)
-    if len(s.taskQueues.Todo) == 0 && len(s.taskQueues.Pending) == 0 {
-        log.WithFields(s.logFields()).Infoln("No more todo and pending task, start a new pass.")
-        s.taskQueues.Todo = append(s.taskQueues.Todo, s.taskQueues.Done...)
-        s.taskQueues.Done = nil
-    }
+    // increase master side pass count if all tasks finished
+    if len(s.taskQueues.Pending) == 0 && len(s.taskQueues.Todo) == 0 {
+        s.currPass++
+        s.taskQueues.Todo = s.jobTasks
+        s.taskQueues.Done = []taskEntry{}
+        // TODO(typhoonzero): deal with failed tasks
+        s.taskQueues.Failed = []taskEntry{}
+        log.WithFields(s.logFields()).Warningf("all task finished, add new pass data, newpass: %d.", s.currPass)
+    }
 
     err := s.snapshot()
...
@@ -416,7 +439,7 @@ func (s *Service) TaskFinished(taskID int, _ *int) error {
 }
 
 // TaskFailed tells the service that a task is failed.
-func (s *Service) TaskFailed(meta TaskMeta, _ *int) error {
+func (s *Service) TaskFailed(meta TaskMeta, dummy *int) error {
     select {
     case <-s.ready:
     }
...
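The removed TODO comment above explains why the new exported sentinels (ErrAllTaskFailed, ErrNoMoreAvailable, ErrPassBefore, ErrPassAfter) are matched by message text on the client side: an error returned over net/rpc is deserialized into a fresh instance, so pointer equality with the package-level sentinel fails. A minimal sketch of that workaround follows; the package and the Matches helper are illustrative assumptions, only the string-comparison idiom itself comes from the diff.

// Sketch of matching sentinel errors across an RPC boundary.
package rpcerr

import "errors"

// ErrPassAfter mirrors the sentinel style used by the master service.
var ErrPassAfter = errors.New("pass number larger than master")

// Matches reports whether err is the given sentinel. net/rpc transports only
// the error text, so the deserialized error is a different instance and
// err == sentinel is false; comparing the messages is the workaround the
// client code above relies on.
func Matches(err, sentinel error) bool {
    if err == nil {
        return false
    }
    return err == sentinel || err.Error() == sentinel.Error()
}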
go/master/service_internal_test.go
@@ -44,7 +44,8 @@ func TestPartionIndex(t *testing.T) {
     cs := make([]Chunk, 100)
     ts := partition(cs, 20)
     for i := range ts {
-        if ts[i].Task.Meta.ID != i {
+        // test auto increament ids
+        if i > 0 && ts[i].Task.Meta.ID != ts[i-1].Task.Meta.ID+1 {
             t.Error(ts[i], i)
         }
     }
...
go/master/service_test.go (new file, 0 → 100644)

package master_test

import (
    "os"
    "testing"
    "time"

    "github.com/PaddlePaddle/Paddle/go/master"
    "github.com/coreos/etcd/clientv3"
    "github.com/coreos/etcd/embed"
    "github.com/docker/docker/pkg/ioutils"
    "github.com/stretchr/testify/assert"
)

func TestNewServiceWithEtcd(t *testing.T) {
    // setup an embed etcd server
    etcdDir, err := ioutils.TempDir("", "")
    if err != nil {
        t.Fatal(err)
    }
    cfg := embed.NewConfig()
    cfg.Dir = etcdDir
    e, err := embed.StartEtcd(cfg)
    if err != nil {
        t.Fatal(err)
    }
    defer func() {
        e.Close()
        if err := os.RemoveAll(etcdDir); err != nil {
            t.Fatal(err)
        }
    }()
    select {
    case <-e.Server.ReadyNotify():
        t.Log("Server is ready!")
    case <-time.After(60 * time.Second):
        e.Server.Stop() // trigger a shutdown
        t.Fatal("Server took too long to start!")
    }

    ep := []string{"127.0.0.1:2379"}
    masterAddr := "127.0.0.1:3306"
    store, err := master.NewEtcdClient(ep, masterAddr, master.DefaultLockPath, master.DefaultAddrPath, master.DefaultStatePath, 30)
    if err != nil {
        t.Fatal(err)
    }

    _, err = master.NewService(store, 10, 10, 3)
    if err != nil {
        t.Fatal(err)
    }
    cli, err := clientv3.New(clientv3.Config{
        Endpoints:   ep,
        DialTimeout: 3 * time.Second,
    })
    if err != nil {
        t.Fatal(err)
    }
    v, err := master.GetKey(cli, master.DefaultAddrPath, 3*time.Second)
    if err != nil {
        t.Fatal(err)
    }
    if err := cli.Close(); err != nil {
        t.Fatal(err)
    }
    // test master process registry itself into etcd server.
    assert.Equal(t, masterAddr, v, "master process should registry itself into etcd server.")
}
go/pserver/client/c/cclient.go
@@ -55,10 +55,10 @@ var curHandle C.paddle_pserver_client
 func add(c *client.Client) C.paddle_pserver_client {
     mu.Lock()
     defer mu.Unlock()
-    client := curHandle
+    cli := curHandle
     curHandle++
-    handleMap[client] = c
-    return client
+    handleMap[cli] = c
+    return cli
 }
 
 func get(client C.paddle_pserver_client) *client.Client {
...
go/pserver/client/c/test/test_train.py
@@ -6,16 +6,19 @@ import cPickle as pickle
 etcd_ip = os.getenv("MASTER_IP", "127.0.0.1")
 etcd_endpoint = "http://" + etcd_ip + ":2379"
+print "connecting to master, etcd endpoints: ", etcd_endpoint
+master_client = master.client(etcd_endpoint, 5, 64)
 
 
 def cloud_reader():
-    print "connecting to master, etcd endpoints: ", etcd_endpoint
-    master_client = master.client(etcd_endpoint, 5, 64)
+    global master_client
     master_client.set_dataset(
-        ["/pfs/dlnel/public/dataset/uci_housing/uci_housing-*-of-*"])
+        ["/pfs/dlnel/public/dataset/uci_housing/uci_housing-*"], passes=30)
     while 1:
         r, e = master_client.next_record()
         if not r:
+            if e != -2:
+                # other errors
+                print "get record error:", e
             break
         yield pickle.loads(r)
...
@@ -27,10 +30,12 @@ def main():
     # network config
     x = paddle.layer.data(name='x', type=paddle.data_type.dense_vector(13))
     y_predict = paddle.layer.fc(input=x,
-                                param_attr=paddle.attr.Param(name='w'),
+                                param_attr=paddle.attr.Param(name='w', learning_rate=1e-3),
                                 size=1,
                                 act=paddle.activation.Linear(),
-                                bias_attr=paddle.attr.Param(name='b'))
+                                bias_attr=paddle.attr.Param(name='b', learning_rate=1e-3))
     y = paddle.layer.data(name='y', type=paddle.data_type.dense_vector(1))
     cost = paddle.layer.mse_cost(input=y_predict, label=y)
...
@@ -38,9 +43,8 @@ def main():
     parameters = paddle.parameters.create(cost)
 
     # create optimizer of new remote updater to pserver
-    optimizer = paddle.optimizer.Momentum(momentum=0)
+    optimizer = paddle.optimizer.Momentum(momentum=0, learning_rate=1e-3)
 
-    print "etcd endoint: ", etcd_endpoint
     trainer = paddle.trainer.SGD(cost=cost,
                                  parameters=parameters,
                                  update_equation=optimizer,
...
@@ -51,6 +55,8 @@ def main():
     # event_handler to print training and testing info
     def event_handler(event):
         if isinstance(event, paddle.event.EndIteration):
+            # FIXME: for cloud data reader, pass number is managed by master
+            # should print the server side pass number
             if event.batch_id % 100 == 0:
                 print "Pass %d, Batch %d, Cost %f" % (
                     event.pass_id, event.batch_id, event.cost)
...
go/pserver/etcd_client.go  View file @ 4cc42171

@@ -34,16 +34,19 @@ const (
 	PsPath = "/ps/"
 	// PsCheckpoint is the etcd path for store checkpoints information
 	PsCheckpoint = "/checkpoints/"
+
+	retryTimeout = 5 * time.Second
 )

 // EtcdClient is the etcd client that the pserver uses for fault
 // tolerance, service registry and coordination.
 type EtcdClient struct {
-	numPservers   int
-	etcdEndpoints string
-	etcdClient    *clientv3.Client
-	// etcdTimeout is also used as retry intervals.
-	etcdTimeout time.Duration
+	numPservers int
+	endpoints   string
+	client      *clientv3.Client
+	sess        *concurrency.Session
+	dialTimeout time.Duration
+	ttlSec      int
 	// FIXME: ensure GetExternalIP gets the correct ip for trainers to connect.
 	externalIP string
 	// desired number of pservers in the job.

@@ -52,11 +55,12 @@ type EtcdClient struct {
 }

 // NewEtcdClient creates an EtcdClient
-func NewEtcdClient(endpoints string, numPservers int, timeout time.Duration) *EtcdClient {
+func NewEtcdClient(endpoints string, numPservers int, dialtimeout time.Duration, ttlSec int) *EtcdClient {
 	return &EtcdClient{
-		etcdTimeout:   timeout,
-		numPservers:   numPservers,
-		etcdEndpoints: endpoints,
+		dialTimeout: dialtimeout,
+		ttlSec:      ttlSec,
+		numPservers: numPservers,
+		endpoints:   endpoints,
 	}
 }

@@ -64,7 +68,6 @@ func NewEtcdClient(endpoints string, numPservers int, timeout time.Duration) *Et
 //
 // Register returns the index of the current pserver.
 func (e *EtcdClient) Register(port int) (int, error) {
 	var err error
 	e.externalIP, err = networkhelper.GetExternalIP()
 	if err != nil {

@@ -72,19 +75,26 @@ func (e *EtcdClient) Register(port int) (int, error) {
 	}

 	// initialize connection to etcd.
-	ep := strings.Split(e.etcdEndpoints, ",")
+	ep := strings.Split(e.endpoints, ",")
 	for {
 		cli, err := clientv3.New(clientv3.Config{
 			Endpoints:   ep,
-			DialTimeout: e.etcdTimeout,
+			DialTimeout: e.dialTimeout,
 		})
 		if err != nil {
 			log.Errorf("connect to etcd error: %v", err)
-			time.Sleep(e.etcdTimeout)
+			time.Sleep(retryTimeout)
+			continue
+		}
+		e.client = cli
+		sess, err := concurrency.NewSession(cli, concurrency.WithTTL(e.ttlSec))
+		if err != nil {
+			log.Errorf("create etcd session error: %v", err)
+			time.Sleep(retryTimeout)
 			continue
 		}
-		e.etcdClient = cli
-		log.Debugf("inited client to %s", e.etcdEndpoints)
+		e.sess = sess
+		log.Debugf("inited client to %s", e.endpoints)
 		break
 	}
 	// init /ps_desired using transaction, for multiple pservers may want to write

@@ -95,7 +105,7 @@ func (e *EtcdClient) Register(port int) (int, error) {
 		cancel()
 		if err != nil {
 			log.Warn(err)
-			time.Sleep(e.etcdTimeout)
+			time.Sleep(retryTimeout)
 			continue
 		}
 		break

@@ -106,18 +116,18 @@ func (e *EtcdClient) Register(port int) (int, error) {
 	// wait and set s.desired init value
 	for {
 		ctx, cancel := context.WithTimeout(context.Background(), time.Second)
-		resp, err := e.etcdClient.Get(ctx, PsDesired)
+		resp, err := e.client.Get(ctx, PsDesired)
 		cancel()
 		if err != nil {
 			log.Errorf("getting %s error: %v", PsDesired, err)
-			time.Sleep(e.etcdTimeout)
+			time.Sleep(retryTimeout)
 			continue
 		}
 		if len(resp.Kvs) != 0 {
 			e.desired, err = strconv.Atoi(string(resp.Kvs[0].Value))
 			if err != nil {
 				log.Errorf("value of %s invalid %v\n", PsDesired, err)
-				time.Sleep(e.etcdTimeout)
+				time.Sleep(retryTimeout)
 				// NOTE: wait util ps_desired value change
 				continue
 			}

@@ -134,7 +144,7 @@ func (e *EtcdClient) Register(port int) (int, error) {
 		cancel()
 		if err != nil {
 			log.Warn(err)
-			time.Sleep(e.etcdTimeout)
+			time.Sleep(retryTimeout)
 			continue
 		}
 		break

@@ -144,10 +154,10 @@ func (e *EtcdClient) Register(port int) (int, error) {
 }

 func (e *EtcdClient) initDesiredPservers(ctx context.Context, numPservers int) (*clientv3.TxnResponse, error) {
-	return concurrency.NewSTM(e.etcdClient, func(c concurrency.STM) error {
+	return concurrency.NewSTM(e.client, func(c concurrency.STM) error {
 		dsStr := c.Get(PsDesired)
 		if dsStr == "" {
-			c.Put(PsDesired, strconv.Itoa(numPservers))
+			c.Put(PsDesired, strconv.Itoa(numPservers), clientv3.WithLease(e.sess.Lease()))
 		}
 		return nil
 	}, concurrency.WithAbortContext(ctx), concurrency.WithIsolation(concurrency.RepeatableReads))

@@ -156,7 +166,7 @@ func (e *EtcdClient) initDesiredPservers(ctx context.Context, numPservers int) (
 // registerPserverEtcd registers pserver node on etcd using transaction.
 func (e *EtcdClient) registerPserverEtcd(ctx context.Context, port int) (int, error) {
 	var idx int
-	_, err := concurrency.NewSTM(e.etcdClient, func(c concurrency.STM) error {
+	_, err := concurrency.NewSTM(e.client, func(c concurrency.STM) error {
 		registered := false
 		for i := 0; i < e.desired; i++ {
 			psKey := PsPath + strconv.Itoa(i)

@@ -165,26 +175,10 @@ func (e *EtcdClient) registerPserverEtcd(ctx context.Context, port int) (int, er
 			log.Debugf("got value (%s) for key: %s", ps, psKey)
 			if ps == "" {
-				resp, err := e.etcdClient.Grant(context.TODO(), 5)
-				if err != nil {
-					log.Fatal(err)
-				}
 				// find the first id and write info
 				pserverAddr := e.externalIP + ":" + strconv.Itoa(port)
-				c.Put(psKey, pserverAddr, clientv3.WithLease(resp.ID))
+				c.Put(psKey, pserverAddr, clientv3.WithLease(e.sess.Lease()))
 				log.Debugf("set pserver node %s with value %s", psKey, pserverAddr)
-				ch, kaerr := e.etcdClient.KeepAlive(context.TODO(), resp.ID)
-				if kaerr != nil {
-					log.Errorf("keepalive etcd node error: %v", kaerr)
-					return kaerr
-				}
-
-				// Eat the keep alive message so etcd
-				// will not expire the lease.
-				go func(ch <-chan *clientv3.LeaseKeepAliveResponse) {
-					ka := <-ch
-					log.Debugf("keepalive: %d\n", ka.TTL)
-				}(ch)
 				log.Debug("register finished")
 				idx = i
 				registered = true

@@ -207,7 +201,7 @@ func (e *EtcdClient) registerPserverEtcd(ctx context.Context, port int) (int, er
 // GetKey gets the value by the specified key
 func (e *EtcdClient) GetKey(key string, timeout time.Duration) ([]byte, error) {
 	ctx, cancel := context.WithTimeout(context.Background(), timeout)
-	resp, err := e.etcdClient.Get(ctx, key)
+	resp, err := e.client.Get(ctx, key)
 	cancel()
 	if err != nil {
 		return []byte{}, err

@@ -223,7 +217,27 @@ func (e *EtcdClient) GetKey(key string, timeout time.Duration) ([]byte, error) {
 // PutKey put into etcd with value by key specified
 func (e *EtcdClient) PutKey(key string, value []byte, timeout time.Duration) error {
 	ctx, cancel := context.WithTimeout(context.Background(), timeout)
-	_, err := e.etcdClient.Put(ctx, key, string(value))
+	_, err := e.client.Put(ctx, key, string(value), clientv3.WithLease(e.sess.Lease()))
 	cancel()
 	return err
 }
+
+// Shutdown shuts down the etcd client gracefully.
+func (e *EtcdClient) Shutdown() error {
+	var err error
+	if e.sess != nil {
+		err = e.sess.Close()
+	}
+
+	if e.client != nil {
+		newErr := e.client.Close()
+		if newErr != nil {
+			if err != nil {
+				log.Errorln(newErr)
+			} else {
+				err = newErr
+			}
+		}
+	}
+	return err
+}
paddle/api/Evaluator.cpp  View file @ 4cc42171

@@ -37,7 +37,7 @@ std::vector<std::string> Evaluator::getNames() const {
 double Evaluator::getValue(const std::string name) const {
   paddle::Error err;
   double v = m->rawPtr->getValue(name, &err);
-  if (err) {
+  if (!err.isOK()) {
     throw std::runtime_error(err.msg());
   }
   return v;
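The same error-handling convention reappears in ActivationFunction.cpp further down: a paddle::Error is now checked explicitly through isOK() instead of its implicit bool conversion. A minimal sketch of the pattern, not taken from this commit — DoSomething is a hypothetical callee and the include path for paddle::Error is an assumption:

#include <stdexcept>
#include "paddle/utils/Error.h"  // assumed location of paddle::Error

// Hypothetical callee that reports failure through paddle::Error.
paddle::Error DoSomething();

void Caller() {
  paddle::Error err = DoSomething();
  // Check the error explicitly; err.msg() carries the message on failure.
  if (!err.isOK()) {
    throw std::runtime_error(err.msg());
  }
}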
paddle/framework/CMakeLists.txt  View file @ 4cc42171

@@ -3,7 +3,7 @@ cc_library(ddim SRCS ddim.cc DEPS eigen3)
 cc_test(ddim_test SRCS ddim_test.cc DEPS ddim)
 nv_test(dim_test SRCS dim_test.cu DEPS ddim)

-cc_library(tensor SRCS tensor.cc DEPS ddim place paddle_memory)
+cc_library(tensor SRCS tensor.cc DEPS ddim place paddle_memory device_context)
 cc_test(tensor_test SRCS tensor_test.cc DEPS tensor)
 cc_test(eigen_test SRCS eigen_test.cc DEPS tensor)

@@ -29,7 +29,5 @@ py_proto_compile(framework_py_proto SRCS attr_type.proto op_proto.proto op_desc.
 add_custom_target(framework_py_proto_init ALL COMMAND ${CMAKE_COMMAND} -E touch __init__.py)
 add_dependencies(framework_py_proto framework_py_proto_init)

-proto_library(net_proto SRCS net_proto.proto DEPS op_proto)
-# cc_library(net SRCS net.cc DEPS operator net_proto op_registry fc_op)
-cc_library(net SRCS net.cc DEPS operator net_proto op_registry)
+cc_library(net SRCS net.cc DEPS op_registry)
 cc_test(net_op_test SRCS net_op_test.cc DEPS net add_op mul_op sigmoid_op softmax_op fc_op)
paddle/framework/detail/tensor-inl.h  0 → 100644  View file @ 4cc42171

/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.

   Licensed under the Apache License, Version 2.0 (the "License");
   you may not use this file except in compliance with the License.
   You may obtain a copy of the License at

   http://www.apache.org/licenses/LICENSE-2.0

   Unless required by applicable law or agreed to in writing, software
   distributed under the License is distributed on an "AS IS" BASIS,
   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
   See the License for the specific language governing permissions and
   limitations under the License. */

#pragma once
#include "paddle/memory/memcpy.h"

namespace paddle {
namespace framework {

template <typename T>
inline void Tensor::check_memory_size() const {
  PADDLE_ENFORCE(holder_ != nullptr,
                 "Tenosr holds no memory. Call Tensor::mutable_data first.");
  PADDLE_ENFORCE(holder_->size() >= product(dims_) * sizeof(T) + offset_,
                 "Tensor's dims_ is out of bound. Call Tensor::mutable_data "
                 "first to re-allocate memory.");
}

template <typename T>
inline const T* Tensor::data() const {
  check_memory_size<T>();
  return reinterpret_cast<const T*>(
      reinterpret_cast<uintptr_t>(holder_->ptr()) + offset_);
}

template <typename T>
inline T* Tensor::data() {
  check_memory_size<T>();
  return reinterpret_cast<T*>(reinterpret_cast<uintptr_t>(holder_->ptr()) +
                              offset_);
}

template <typename T>
inline T* Tensor::mutable_data(DDim dims, platform::Place place) {
  static_assert(std::is_pod<T>::value, "T must be POD");
  Resize(dims);
  return mutable_data<T>(place);
}

template <typename T>
inline T* Tensor::mutable_data(platform::Place place) {
  static_assert(std::is_pod<T>::value, "T must be POD");
  PADDLE_ENFORCE(product(dims_) > 0,
                 "Tensor's numel must be larger than zero to call "
                 "Tensor::mutable_data. Call Tensor::set_dim first.");
  /* some versions of boost::variant don't have operator!= */
  size_t size = product(dims_) * sizeof(T);
  if (holder_ == nullptr || !(holder_->place() == place) ||
      holder_->size() < size + offset_) {
    if (platform::is_cpu_place(place)) {
      holder_.reset(new PlaceholderImpl<T, platform::CPUPlace>(
          boost::get<platform::CPUPlace>(place), size));
    }
#ifndef PADDLE_ONLY_CPU
    else if (platform::is_gpu_place(place)) {
      holder_.reset(new PlaceholderImpl<T, platform::GPUPlace>(
          boost::get<platform::GPUPlace>(place), size));
    }
#endif
    offset_ = 0;
  }
  return reinterpret_cast<T*>(reinterpret_cast<uintptr_t>(holder_->ptr()) +
                              offset_);
}

template <typename T>
inline void Tensor::ShareDataWith(const Tensor& src) {
  src.check_memory_size<T>();
  *this = src;
}

template <typename T>
inline void Tensor::CopyFrom(const Tensor& src,
                             const platform::Place& dst_place) {
  src.check_memory_size<T>();
  Resize(src.dims());

  auto src_place = src.holder_->place();
  auto src_ptr = static_cast<const void*>(src.data<T>());

  auto dst_ptr = static_cast<void*>(mutable_data<T>(dst_place));

  auto size = product(src.dims_) * sizeof(T);

  if (platform::is_cpu_place(src_place) && platform::is_cpu_place(dst_place)) {
    memory::Copy(boost::get<platform::CPUPlace>(dst_place), dst_ptr,
                 boost::get<platform::CPUPlace>(src_place), src_ptr, size);
  }
#ifndef PADDLE_ONLY_CPU
  else if (platform::is_gpu_place(src_place) &&
           platform::is_cpu_place(dst_place)) {
    memory::Copy(boost::get<platform::CPUPlace>(dst_place), dst_ptr,
                 boost::get<platform::GPUPlace>(src_place), src_ptr, size, 0);
  } else if (platform::is_cpu_place(src_place) &&
             platform::is_gpu_place(dst_place)) {
    memory::Copy(boost::get<platform::GPUPlace>(dst_place), dst_ptr,
                 boost::get<platform::CPUPlace>(src_place), src_ptr, size, 0);
  } else if (platform::is_gpu_place(src_place) &&
             platform::is_gpu_place(dst_place)) {
    memory::Copy(boost::get<platform::GPUPlace>(dst_place), dst_ptr,
                 boost::get<platform::GPUPlace>(src_place), src_ptr, size, 0);
  }
#endif
}

template <typename T>
inline Tensor Tensor::Slice(const int& begin_idx, const int& end_idx) const {
  check_memory_size<T>();
  PADDLE_ENFORCE(begin_idx >= 0, "Slice begin index is less than zero.");
  PADDLE_ENFORCE(end_idx <= dims_[0], "Slice end index is out of bound.");
  PADDLE_ENFORCE(begin_idx < end_idx,
                 "Begin index must be less than end index.");
  PADDLE_ENFORCE(dims_[0] != 1, "Can not slice a tensor with dims_[0] = 1.");
  int base = product(dims_) / dims_[0];
  Tensor dst;
  dst.holder_ = holder_;
  DDim dst_dims = dims_;
  dst_dims[0] = end_idx - begin_idx;
  dst.Resize(dst_dims);
  dst.offset_ = offset_ + begin_idx * base * sizeof(T);
  return dst;
}

inline void Tensor::Resize(const DDim& dims) { dims_ = dims; }

inline const DDim& Tensor::dims() const { return dims_; }

}  // namespace framework
}  // namespace paddle
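Slice shares the source tensor's Placeholder and only rewrites dims_[0] and offset_, so a slice is a view, while CopyFrom allocates at dst_place and copies the bytes through memory::Copy. A minimal usage sketch of the inline methods above, restricted to the CPU path; the function and variable names are illustrative and not part of the commit:

#include "paddle/framework/tensor.h"

void TensorSliceSketch() {
  using paddle::framework::Tensor;
  using paddle::framework::make_ddim;
  using paddle::platform::CPUPlace;

  // Allocate a 4x3 float tensor on CPU and fill it.
  Tensor t;
  float* data = t.mutable_data<float>(make_ddim({4, 3}), CPUPlace());
  for (int i = 0; i < 12; ++i) data[i] = static_cast<float>(i);

  // Rows [1, 3): no copy; offset_ becomes 1 * 3 * sizeof(float).
  Tensor rows = t.Slice<float>(1, 3);

  // Deep copy of the slice into freshly allocated CPU memory.
  Tensor copy;
  copy.CopyFrom<float>(rows, CPUPlace());
}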
paddle/framework/net.cc  View file @ 4cc42171

@@ -20,17 +20,7 @@
 namespace paddle {
 namespace framework {

-std::shared_ptr<PlainNet> AddBackwardOp(std::shared_ptr<PlainNet> ForwardOps) {
-  auto grad_ops = std::make_shared<PlainNet>();
-  for (auto& op : ForwardOps->ops_) {
-    auto op_grad = OpRegistry::CreateGradOp(op);
-    grad_ops->AddOp(op_grad);
-  }
-  grad_ops->CompleteAddOp();
-  return grad_ops;
-}
-
-void PlainNet::CompleteAddOp(bool calc) {
+void NetOp::CompleteAddOp(bool calc) {
   add_op_done_ = true;
   if (!calc) return;
   std::unordered_set<std::string> input_set;

@@ -70,7 +60,7 @@ void PlainNet::CompleteAddOp(bool calc) {
   attrs_["temporary_index"] = tmp_index;
 }

-std::string PlainNet::DebugString() const {
+std::string NetOp::DebugString() const {
   std::ostringstream os;
   os << OperatorBase::DebugString() << std::endl;
   for (auto& op : ops_) {

@@ -82,5 +72,7 @@ std::string PlainNet::DebugString() const {
   return os.str();
 }

+bool NetOp::IsNetOp() const { return true; }
+
 }  // namespace framework
 }  // namespace paddle
paddle/framework/net.h  View file @ 4cc42171

@@ -37,21 +37,7 @@ namespace framework {
 * This is the base class of network, all the networks should implement the APIs
 * it defines.
 */
-class Net : public OperatorBase {
- public:
-  virtual void AddOp(const std::shared_ptr<OperatorBase>& op) = 0;
-  virtual void CompleteAddOp(bool calc) = 0;
-};
-
-using NetPtr = std::shared_ptr<Net>;
-
-/**
- * @brief a basic implementation of Net.
- *
- * PlainNet is a very simple Net, it create a list of operators, and run them
- * sequentially following the order they added.
- */
-class PlainNet : public Net {
+class NetOp : public OperatorBase {
  public:
  /**
   * Infer all the operators' input and output variables' shapes, will be called

@@ -80,15 +66,17 @@ class PlainNet : public Net {
  /**
   * @brief Add an operator by ptr
   */
-  void AddOp(const std::shared_ptr<OperatorBase>& op) override {
+  void AddOp(const std::shared_ptr<OperatorBase>& op) {
    PADDLE_ENFORCE(!add_op_done_, "Cannot AddOp when this network is sealed");
    ops_.push_back(op);
  }

-  void CompleteAddOp(bool calculate = true) override;
+  void CompleteAddOp(bool calculate = true);

  std::string DebugString() const override;

+  bool IsNetOp() const override;
+
  std::vector<std::shared_ptr<OperatorBase>> ops_;

 private:

@@ -100,7 +88,5 @@ class PlainNet : public Net {
  }
 };

-std::shared_ptr<PlainNet> AddBackwardOp(std::shared_ptr<PlainNet> ForwardOps);
-
 }  // namespace framework
 }  // namespace paddle
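With the Net/PlainNet split collapsed into a single NetOp, building a network is just adding registered operators and sealing the container. A sketch of the intended use, mirroring net_op_test.cc; the op types ("mul", "add_two") and variable names come from the tests, not from this header:

#include <memory>
#include "paddle/framework/net.h"
#include "paddle/framework/op_registry.h"

void BuildNetSketch() {
  using paddle::framework::NetOp;
  using paddle::framework::OpRegistry;

  auto net = std::make_shared<NetOp>();
  // Operators run sequentially in the order they were added.
  net->AddOp(OpRegistry::CreateOp("mul", {"X", "Y"}, {"Out"}, {}));
  net->AddOp(OpRegistry::CreateOp("add_two", {"X", "Y"}, {"Out"}, {}));

  // Seal the net: infers net-level inputs/outputs/temporaries. After this,
  // another AddOp trips the PADDLE_ENFORCE above (EnforceNotMet).
  net->CompleteAddOp();
}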
paddle/framework/net_op_test.cc  View file @ 4cc42171

@@ -40,7 +40,7 @@ void AssertSameVectorWithoutOrder(const std::vector<T>& expected,
 }

 TEST(OpKernel, all) {
-  auto net = std::make_shared<PlainNet>();
+  auto net = std::make_shared<NetOp>();
   ASSERT_NE(net, nullptr);

   auto op1 = std::make_shared<TestOp>();

@@ -69,30 +69,23 @@ TEST(OpKernel, all) {
   net->Run(scope, dev_ctx);
   ASSERT_EQ(2, infer_shape_cnt);
   ASSERT_EQ(2, run_cnt);
-  ASSERT_THROW(net->AddOp(op2), std::runtime_error);
-}
-
-TEST(AddBackwardOp, TestGradOp) {
-  auto net = std::make_shared<PlainNet>();
-  ASSERT_NE(net, nullptr);
-  net->AddOp(framework::OpRegistry::CreateOp("mul", {"X", "Y"}, {"Out"}, {}));
-  net->AddOp(framework::OpRegistry::CreateOp("add_two", {"X", "Y"}, {"Out"}, {}));
-  net->AddOp(framework::OpRegistry::CreateOp("add_two", {"X", "Y"}, {""}, {}));
-  auto grad_ops = AddBackwardOp(net);
-  for (auto& op : grad_ops->ops_) {
-    op->DebugString();
-  }
-}
-
-// TODO(zhihong): add fc grad without registering.
-// TEST(AddBackwardOp, TestNoGradOp) {
-//   auto net = std::make_shared<PlainNet>();
-//   ASSERT_NE(net, nullptr);
-//   net->AddOp(framework::OpRegistry::CreateOp("fc", {"X", "W", "b"}, {"Y"},
-//   {})); auto grad_ops = AddBackwardOp(net); for (auto& op : grad_ops->ops_) {
-//     op->DebugString();
-//   }
-// }
+  ASSERT_THROW(net->AddOp(op2), paddle::platform::EnforceNotMet);
+}
+
+//! TODO(yuyang18): Refine Backward Op.
+// TEST(AddBackwardOp, TestGradOp) {
+//   auto net = std::make_shared<NetOp>();
+//   ASSERT_NE(net, nullptr);
+//   net->AddOp(framework::OpRegistry::CreateOp("mul", {"X", "Y"}, {"Out"}, {}));
+//   net->AddOp(
+//       framework::OpRegistry::CreateOp("add_two", {"X", "Y"}, {"Out"}, {}));
+//   net->AddOp(framework::OpRegistry::CreateOp("add_two", {"X", "Y"}, {""},
+//   {}));
+//   auto grad_ops = AddBackwardOp(net);
+//   for (auto& op : grad_ops->ops_) {
+//     op->DebugString();
+//   }
+//}

 }  // namespace framework
 }  // namespace paddle
paddle/framework/net_proto.proto  deleted 100644 → 0  View file @ 4ecf68e0

syntax = "proto2";
package paddle.framework;

import "op_proto.proto";

message NetDesc {
  // network identification
  optional string name = 1;
  // operator contains in network
  repeated OpProto operators = 2;
  // network type to run with. e.g "plainNet", "DAG"
  optional string net_type = 3;
  // num worker always
  optional int32 num_workers = 4;
}
paddle/framework/op_registry_test.cc  View file @ 4cc42171

@@ -90,7 +90,7 @@ TEST(OpRegistry, IllegalAttr) {
   bool caught = false;
   try {
     paddle::framework::OpRegistry::CreateOp(op_desc);
-  } catch (std::runtime_error& err) {
+  } catch (paddle::platform::EnforceNotMet err) {
     caught = true;
     std::string msg = "larger_than check fail";
     const char* err_msg = err.what();

@@ -136,7 +136,7 @@ TEST(OpRegistry, CustomChecker) {
   bool caught = false;
   try {
     paddle::framework::OpRegistry::CreateOp(op_desc);
-  } catch (std::runtime_error& err) {
+  } catch (paddle::platform::EnforceNotMet err) {
     caught = true;
     std::string msg = "Attribute 'test_attr' is required!";
     const char* err_msg = err.what();

@@ -154,7 +154,7 @@ TEST(OpRegistry, CustomChecker) {
   caught = false;
   try {
     paddle::framework::OpRegistry::CreateOp(op_desc);
-  } catch (std::runtime_error& err) {
+  } catch (paddle::platform::EnforceNotMet err) {
     caught = true;
     std::string msg = "'test_attr' must be even!";
     const char* err_msg = err.what();

@@ -192,7 +192,7 @@ TEST(ProtoMaker, DuplicatedAttr) {
   pd::OpProto op_proto;
   pd::OpAttrChecker op_checker;
   auto proto_maker = TestAttrProtoMaker(&op_proto, &op_checker);
-  ASSERT_THROW(proto_maker.Validate(), std::runtime_error);
+  ASSERT_THROW(proto_maker.Validate(), paddle::platform::EnforceNotMet);
 }

 class TestInOutProtoMaker : public pd::OpProtoAndCheckerMaker {

@@ -208,5 +208,5 @@ TEST(ProtoMaker, DuplicatedInOut) {
   pd::OpProto op_proto;
   pd::OpAttrChecker op_checker;
   auto proto_maker = TestInOutProtoMaker(&op_proto, &op_checker);
-  ASSERT_THROW(proto_maker.Validate(), std::runtime_error);
+  ASSERT_THROW(proto_maker.Validate(), paddle::platform::EnforceNotMet);
 }
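The recurring change in these tests (and again in net_op_test.cc and tensor_test.cc below) is that a failed PADDLE_ENFORCE now surfaces as paddle::platform::EnforceNotMet rather than a plain std::runtime_error. A minimal sketch of the pattern the updated tests assert on; the function name and condition are illustrative only:

#include "paddle/platform/enforce.h"

bool EnforceIsCaught() {
  bool caught = false;
  try {
    // Any false condition raises paddle::platform::EnforceNotMet.
    PADDLE_ENFORCE(1 == 2, "1 is not equal to 2");
  } catch (paddle::platform::EnforceNotMet err) {
    // err.what() contains the formatted failure message.
    caught = true;
  }
  return caught;
}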
paddle/framework/operator.cc  View file @ 4cc42171

@@ -34,22 +34,26 @@ KernelContext::GetEigenDevice<platform::GPUPlace, Eigen::GpuDevice>() const {
 #endif

 const std::string& OperatorBase::Input(const std::string& name) const {
+  PADDLE_ENFORCE(in_out_idxs_ != nullptr,
+                 "Input Output Indices could not be nullptr");
   auto it = in_out_idxs_->find(name);
   PADDLE_ENFORCE(it != in_out_idxs_->end(), "no key [%s] in in_out_idxs_",
                  name);
   if (attrs_.count("input_format") == 0) {
-    return inputs_[it->second];
+    return inputs_.at((size_t)it->second);
   } else {
     const auto& input_format = GetAttr<std::vector<int>>("input_format");
     int idx = input_format[it->second];
-    return inputs_.at(idx);
+    return inputs_.at((size_t)idx);
   }
 }

 std::vector<std::string> OperatorBase::Inputs(const std::string& name) const {
+  PADDLE_ENFORCE(in_out_idxs_ != nullptr, "IO Idx could not be nullptr");
   auto input_format = GetAttr<std::vector<int>>("input_format");
   auto offset = in_out_idxs_->at(name);
+  PADDLE_ENFORCE(input_format.at((size_t)offset + 1) <= inputs_.size(),
+                 "Input Out Of Range");

   return std::vector<std::string>{
       inputs_.begin() + input_format.at(offset),

@@ -57,23 +61,25 @@ std::vector<std::string> OperatorBase::Inputs(const std::string& name) const {
 }

 const std::string& OperatorBase::Output(const std::string& name) const {
+  PADDLE_ENFORCE(in_out_idxs_ != nullptr, "InOut Indice could not be nullptr");
   auto it = in_out_idxs_->find(name);
   PADDLE_ENFORCE(it != in_out_idxs_->end(), "no key [%s] in in_out_idxs_",
                  name);
   if (attrs_.count("output_format") == 0) {
-    return outputs_[it->second];
+    return outputs_.at((size_t)it->second);
   } else {
     const auto& output_format = GetAttr<std::vector<int>>("output_format");
     int idx = output_format[it->second];
-    return outputs_.at(idx);
+    return outputs_.at((size_t)idx);
   }
 }

 std::vector<std::string> OperatorBase::Outputs(const std::string& name) const {
+  PADDLE_ENFORCE(in_out_idxs_ != nullptr, "InOut Indice could not be nullptr");
   auto output_format = GetAttr<std::vector<int>>("output_format");
   auto offset = in_out_idxs_->at(name);
+  PADDLE_ENFORCE(output_format.at((size_t)offset + 1) <= outputs_.size(),
+                 "Output Out of Range");

   return std::vector<std::string>{
       outputs_.begin() + output_format.at(offset),
       outputs_.begin() + output_format.at(offset + 1)};
paddle/framework/operator.h  View file @ 4cc42171

@@ -90,15 +90,17 @@ class OperatorBase {
   virtual void Run(const std::shared_ptr<Scope>& scope,
                    const platform::DeviceContext& dev_ctx) const = 0;

-  // Get a input with argument's name described in `op_proto`
+  virtual bool IsNetOp() const { return false; }
+
+  //! Get a input with argument's name described in `op_proto`
   const std::string& Input(const std::string& name) const;
-  // Get a input which has multiple variables.
-  // TODO add a vector_view to prevent memory copy.
+  //! Get a input which has multiple variables.
+  //! TODO add a vector_view to prevent memory copy.
   std::vector<std::string> Inputs(const std::string& name) const;
-  // Get a output with argument's name described in `op_proto`
+  //! Get a output with argument's name described in `op_proto`
   const std::string& Output(const std::string& name) const;
-  // Get an output which has multiple variables.
-  // TODO add a vector_view to prevent memory copy.
+  //! Get an output which has multiple variables.
+  //! TODO add a vector_view to prevent memory copy.
   std::vector<std::string> Outputs(const std::string& name) const;

  public:

@@ -199,8 +201,6 @@ class OperatorWithKernel : public OperatorBase {
       place_ = dev_ctx.GetPlace();
     }

-    // bool operator==(const OpKernelKey& o) const { return place_ == o.place_;
-    // }
     bool operator==(const OpKernelKey& o) const {
       return platform::places_are_same_class(place_, o.place_);
     }
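The new IsNetOp() hook lets callers ask whether an OperatorBase is a container of other operators without a dynamic_cast; NetOp overrides it to return true (see net.cc above). A hypothetical helper to illustrate the idea, not part of the header:

#include <iostream>
#include "paddle/framework/operator.h"

// Prints a short description of an operator. A NetOp's DebugString()
// already lists every operator it contains, so no traversal is needed.
void DescribeOp(const paddle::framework::OperatorBase& op) {
  if (op.IsNetOp()) {
    std::cout << "composite operator (NetOp):" << std::endl;
  } else {
    std::cout << "single operator:" << std::endl;
  }
  std::cout << op.DebugString() << std::endl;
}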
paddle/framework/scope.h  View file @ 4cc42171

@@ -56,7 +56,9 @@ class Scope {
     if (var) {
       return var;
     } else {
-      vars_[name] = std::unique_ptr<Variable>(new Variable());
+      auto ptr = new Variable();
+      name_to_var_[name] = std::unique_ptr<Variable>(ptr);
+      var_to_name_[ptr] = name;
       return GetVariable(name);
     }
   }

@@ -68,8 +70,8 @@ class Scope {
   * from it's parent scope. Return nullptr if not found.
   */
   Variable* GetVariable(const std::string& name) const {
-    auto it = vars_.find(name);
-    if (it != vars_.end()) {
+    auto it = name_to_var_.find(name);
+    if (it != name_to_var_.end()) {
       return it->second.get();
     } else if (parent_ != nullptr) {
       return parent_->GetVariable(name);

@@ -84,12 +86,21 @@ class Scope {
   * Find if there is a Variable in this scope and it's parent scope
   */
   bool HasVariable(const std::string& name) const {
-    return (vars_.find(name) != vars_.end() ||
+    return (name_to_var_.find(name) != name_to_var_.end() ||
            (parent_ && parent_->HasVariable(name)));
   }

+  std::string GetVariableName(Variable* const var) const {
+    try {
+      return var_to_name_.at(var);
+    } catch (...) {
+      return "";
+    }
+  }
+
  private:
-  std::unordered_map<std::string, std::unique_ptr<Variable>> vars_;
+  std::unordered_map<Variable*, std::string> var_to_name_;
+  std::unordered_map<std::string, std::unique_ptr<Variable>> name_to_var_;
   std::shared_ptr<Scope> parent_{nullptr};
 };
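Splitting the single vars_ map into name_to_var_ and var_to_name_ gives Scope a reverse lookup: GetVariableName() maps a Variable pointer back to its name and returns an empty string for pointers this scope does not own, which is exactly what the test that follows exercises. A small sketch with illustrative names:

#include <string>
#include "paddle/framework/scope.h"

void ScopeNameLookupSketch() {
  using paddle::framework::Scope;
  using paddle::framework::Variable;

  Scope scope;
  Variable* w = scope.CreateVariable("weights");

  // Forward and reverse lookups both work now.
  bool known = scope.HasVariable("weights");     // true
  std::string name = scope.GetVariableName(w);   // "weights"

  // A variable created by a different scope is unknown here -> "".
  Scope other;
  Variable* tmp = other.CreateVariable("tmp");
  std::string unknown = scope.GetVariableName(tmp);
  (void)known; (void)name; (void)unknown;
}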
paddle/framework/scope_test.cc  View file @ 4cc42171

@@ -40,6 +40,11 @@ TEST(Scope, Create) {
   /// already exist.
   Variable* var4 = scope->CreateVariable("a");
   EXPECT_EQ(var4, var2);
+
+  EXPECT_EQ("a", scope->GetVariableName(var4));
+  Scope scope2;
+  auto var = scope2.CreateVariable("tmp");
+  EXPECT_EQ("", scope->GetVariableName(var));
 }

 TEST(Scope, Parent) {
paddle/framework/tensor.cc  View file @ 4cc42171

@@ -12,7 +12,7 @@
    See the License for the specific language governing permissions and
    limitations under the License. */

-#include <paddle/framework/tensor.h>
+#include "paddle/framework/tensor.h"

 namespace paddle {
 namespace framework {}
paddle/framework/tensor.h  View file @ 4cc42171

@@ -21,6 +21,7 @@ limitations under the License. */
 #include "paddle/framework/ddim.h"
 #include "paddle/memory/memcpy.h"
 #include "paddle/memory/memory.h"
+#include "paddle/platform/device_context.h"
 #include "paddle/platform/enforce.h"
 #include "paddle/platform/place.h"
 #include "unsupported/Eigen/CXX11/Tensor"

@@ -32,9 +33,11 @@ template <bool less, size_t i, typename... args>
 struct CastToPyBufferImpl;
 }  // namespace details
 }  // namespace pybind
+
 namespace framework {

 class Tensor {
+ public:
   template <bool less, size_t i, typename... args>
   friend struct paddle::pybind::details::CastToPyBufferImpl;

@@ -47,151 +50,123 @@ class Tensor {
  public:
   Tensor() : offset_(0) {}

+  /*! Return a pointer to mutable memory block. */
   template <typename T>
-  const T* data() const {
-    EnforceSufficientMemory<T>();
-    return reinterpret_cast<const T*>(
-        reinterpret_cast<uintptr_t>(holder_->ptr()) + offset_);
-  }
+  inline T* data();

+  /*! Return a pointer to constant memory block. */
   template <typename T>
-  T* data() {
-    EnforceSufficientMemory<T>();
-    return reinterpret_cast<T*>(reinterpret_cast<uintptr_t>(holder_->ptr()) +
-                                offset_);
-  }
-
-  template <typename T,  // must be POD types
-            typename std::enable_if<std::is_pod<T>::value>::type* = nullptr>
-  T* mutable_data(DDim dims, platform::Place place) {
-    Resize(dims);
-    return mutable_data<T>(place);
-  }
-
-  template <typename T,  // must be POD types
-            typename std::enable_if<std::is_pod<T>::value>::type* = nullptr>
-  T* mutable_data(platform::Place place) {
-    PADDLE_ENFORCE(product(dims_) > 0,
-                   "Tensor's numel must be larger than zero to call "
-                   "Tensor::mutable_data. Call Tensor::set_dim first.");
-    if (holder_ == nullptr ||
-        !(holder_->place() ==
-          place) /* some versions of boost::variant don't have operator!= */
-        || holder_->size() < product(dims_) * sizeof(T) + offset_) {
-      if (platform::is_cpu_place(place)) {
-        holder_.reset(new PlaceholderImpl<T, platform::CPUPlace>(
-            boost::get<platform::CPUPlace>(place), product(dims_) * sizeof(T)));
-      } else if (platform::is_gpu_place(place)) {
-#ifdef PADDLE_ONLY_CPU
-        PADDLE_THROW("'GPUPlace' is not supported in CPU only device.");
-#else
-        holder_.reset(new PlaceholderImpl<T, platform::GPUPlace>(
-            boost::get<platform::GPUPlace>(place), product(dims_) * sizeof(T)));
-#endif
-      } else {
-        PADDLE_THROW("Unknown 'place'.");
-      }
-      offset_ = 0;
-    }
-    return reinterpret_cast<T*>(reinterpret_cast<uintptr_t>(holder_->ptr()) +
-                                offset_);
-  }
+  inline const T* data() const;

+  /**
+   * @brief   Return a pointer to mutable memory block.
+   * @note    If not exist, then allocation.
+   */
   template <typename T>
-  void ShareDataWith(const Tensor& src) {
-    src.EnforceSufficientMemory<T>();
-    *this = src;
-  }
+  inline T* mutable_data(platform::Place place);

+  /**
+   * @brief     Return a pointer to mutable memory block.
+   *
+   * @param[in] dims    The dimensions of the memory block.
+   * @param[in] place   The place of the memory block.
+   *
+   * @note      If not exist, then allocation.
+   */
   template <typename T>
-  void CopyFrom(const Tensor& src, platform::Place dst_place) {
-    PADDLE_ENFORCE(platform::is_cpu_place(dst_place),
-                   "Tensor::CopyFrom only support dst CPU now.");
-    size_t size = product(src.dims_) * sizeof(T);
-    Resize(src.dims());
-    const void* src_ptr = static_cast<const void*>(src.data<T>());
-    void* dst_ptr = static_cast<void*>(mutable_data<T>(dst_place));
-    if (paddle::platform::is_cpu_place(holder_->place())) {
-      std::memcpy(dst_ptr, src_ptr, size);
-    } else if (paddle::platform::is_gpu_place(holder_->place())) {
-#ifdef PADDLE_ONLY_CPU
-      PADDLE_THROW("'GPUPlace' is not supported in CPU only device.");
-#else
-      platform::GpuMemcpySync(dst_ptr, src_ptr, size, cudaMemcpyDeviceToHost);
-#endif
-    }
-  }
+  inline T* mutable_data(DDim dims, platform::Place place);

+  /*! Return the dimensions of the memory block. */
+  inline const DDim& dims() const;
+
+  /*! Resize the dimensions of the memory block. */
+  inline void Resize(const DDim& dims);
+
+  /*! The internal of two tensors share the same memory block. */
+  template <typename T>
+  inline void ShareDataWith(const Tensor& src);
+
+  /**
+   * @brief   Copy the content of external tensor to a new place.
+   *
+   * @param[in] src   The external tensor.
+   * @param[in] ctx   The device context contains place where to store.
+   *
+   * @note    CopyFrom supports CPU <-> GPU, GPU <-> GPU.
+   */
   template <typename T>
-  Tensor Slice(const int& begin_idx, const int& end_idx) const {
-    EnforceSufficientMemory<T>();
-    PADDLE_ENFORCE(begin_idx >= 0, "Slice begin index is less than zero.");
-    PADDLE_ENFORCE(end_idx <= dims_[0], "Slice end index is out of bound.");
-    PADDLE_ENFORCE(begin_idx < end_idx,
-                   "Begin index must be less than end index.");
-    PADDLE_ENFORCE(dims_[0] != 1, "Can not slice a tensor with dims_[0] = 1.");
-    int base = product(dims_) / dims_[0];
-    Tensor dst;
-    dst.holder_ = holder_;
-    DDim dst_dims = dims_;
-    dst_dims[0] = end_idx - begin_idx;
-    dst.Resize(dst_dims);
-    dst.offset_ = offset_ + begin_idx * base * sizeof(T);
-    return dst;
-  }
+  inline void CopyFrom(const Tensor& src, const platform::Place& dst_place);

-  void Resize(const DDim& dims) { dims_ = dims; }
+  /**
+   * @brief   Return the slice of the tensor.
+   *
+   * @param[in] begin_idx   The begin index of the slice.
+   * @param[in] end_idx     The end index of the slice.
+   */
+  template <typename T>
+  inline Tensor Slice(const int& begin_idx, const int& end_idx) const;

-  const DDim& dims() const { return dims_; }
+ private:
+  template <typename T>
+  inline void check_memory_size() const;

   paddle::platform::Place place() const { return holder_->place(); }

  private:
-  // Placeholder hides type T, so it doesn't appear as a template
-  // parameter of Variable.
+  /**
+   * @note    Placeholder hides type T, so it doesn't appear as a template
+   *          parameter of Variable.
+   */
   struct Placeholder {
     virtual ~Placeholder() {}
     virtual void* ptr() const = 0;
-    virtual platform::Place place() const = 0;
     virtual size_t size() const = 0;
     virtual std::type_index type() const = 0;
+    virtual platform::Place place() const = 0;
   };

-  template <typename T, typename PlaceType>
+  template <typename T, typename Place>
   struct PlaceholderImpl : public Placeholder {
-    PlaceholderImpl(PlaceType place, size_t size)
+    PlaceholderImpl(Place place, size_t size)
         : ptr_(static_cast<T*>(memory::Alloc(place, size)),
-               memory::PODDeleter<T, PlaceType>(place)),
+               memory::PODDeleter<T, Place>(place)),
           place_(place),
-          size_(size) {}
+          size_(size) {
+      PADDLE_ENFORCE(ptr_ != nullptr, "Insufficient %s memory to allocation.",
+                     is_cpu_place(place_) ? "CPU" : "GPU");
+    }

-    virtual void* ptr() const { return static_cast<void*>(ptr_.get()); }
     virtual size_t size() const { return size_; }
-    virtual paddle::platform::Place place() const { return place_; }
+    virtual platform::Place place() const { return place_; }
+    virtual void* ptr() const { return static_cast<void*>(ptr_.get()); }
     virtual std::type_index type() const { return std::type_index(typeid(T)); }

-    std::unique_ptr<T, memory::PODDeleter<T, PlaceType>> ptr_;
-    platform::Place place_;  // record the place of ptr_.
-    size_t size_;            // size of the memory block.
+    /*! the pointer of memory block. */
+    std::unique_ptr<T, memory::PODDeleter<T, Place>> ptr_;
+
+    /*! the place of memory block. */
+    platform::Place place_;
+
+    /*! the size of memory block. */
+    size_t size_;
   };

-  template <typename T>
-  inline void EnforceSufficientMemory() const {
-    PADDLE_ENFORCE(holder_ != nullptr,
-                   "Tenosr holds no memory. Call Tensor::mutable_data first.");
-    PADDLE_ENFORCE(holder_->size() >= product(dims_) * sizeof(T) + offset_,
-                   "Tensor's dims_ is out of bound. Call Tensor::mutable_data "
-                   "first to re-allocate memory.");
-  }
-
-  std::shared_ptr<Placeholder> holder_;  // holds the memory block if allocated.
+  /*! holds the memory block if allocated. */
+  std::shared_ptr<Placeholder> holder_;
+
+  /*! points to dimensions of memory block. */
   DDim dims_;
-  // A PlaceHolder may be shared by more than one tensor. Some of them may be
-  // slices of the others. So the offset_ is introduced here to indicate the
-  // byte offset between PlaceHolder::ptr_ and where tensor's data really
-  // begins.
+
+  /**
+   * @brief   A PlaceHolder may be shared by more than one tensor.
+   *
+   * @note    Some of them may be slices of the others. So the offset_
+   *          is introduced here to indicate the byte offset between
+   *          PlaceHolder::ptr_ and where the tensor data really begins.
+   */
   size_t offset_;
 };

 }  // namespace framework
 }  // namespace paddle
+
+#include "paddle/framework/detail/tensor-inl.h"
paddle/framework/tensor_test.cc  View file @ 4cc42171

@@ -33,7 +33,7 @@ TEST(Tensor, DataAssert) {
   bool caught = false;
   try {
     src_tensor.data<double>();
-  } catch (std::runtime_error& err) {
+  } catch (paddle::platform::EnforceNotMet err) {
     caught = true;
     std::string msg =
         "Tenosr holds no memory. Call Tensor::mutable_data first.";

@@ -72,7 +72,8 @@ TEST(Tensor, MutableData) {
     p2 = src_tensor.mutable_data<float>(make_ddim({2, 2}), CPUPlace());
     EXPECT_EQ(p1, p2);
   }
-#ifdef __CUDACC__
+
+#ifndef PADDLE_ONLY_CPU
   {
     Tensor src_tensor;
     float* p1 = nullptr;

@@ -107,7 +108,7 @@ TEST(Tensor, ShareDataWith) {
     bool caught = false;
     try {
       dst_tensor.ShareDataWith<float>(src_tensor);
-    } catch (std::runtime_error& err) {
+    } catch (paddle::platform::EnforceNotMet err) {
       caught = true;
       std::string msg =
           "Tenosr holds no memory. Call Tensor::mutable_data first.";

@@ -123,7 +124,7 @@ TEST(Tensor, ShareDataWith) {
     ASSERT_EQ(src_tensor.data<int>(), dst_tensor.data<int>());
   }

-#ifdef __CUDACC__
+#ifndef PADDLE_ONLY_CPU
   {
     Tensor src_tensor;
     Tensor dst_tensor;

@@ -160,7 +161,7 @@ TEST(Tensor, Slice) {
     EXPECT_EQ(src_data_address + 3 * 4 * 1 * sizeof(int), slice_data_address);
   }

-#ifdef __CUDACC__
+#ifndef PADDLE_ONLY_CPU
   {
     Tensor src_tensor;
     src_tensor.mutable_data<double>(make_ddim({6, 9}), GPUPlace());

@@ -188,13 +189,53 @@ TEST(Tensor, Slice) {
 TEST(Tensor, CopyFrom) {
   using namespace paddle::framework;
   using namespace paddle::platform;
-
-  Tensor src_tensor;
-  Tensor dst_tensor;
-
-  int* src_ptr = src_tensor.mutable_data<int>(make_ddim({3, 3}), CPUPlace());
-
-  int arr[9] = {1, 2, 3, 4, 5, 6, 7, 8, 9};
-  memcpy(src_ptr, arr, 9 * sizeof(int));
-  dst_tensor.CopyFrom<int>(src_tensor, CPUPlace());
-
-  const int* dst_ptr = dst_tensor.data<int>();
-  ASSERT_NE(src_ptr, dst_ptr);
-  for (size_t i = 0; i < 9; ++i) {
+  {
+    Tensor src_tensor;
+    Tensor dst_tensor;
+
+    int* src_ptr = src_tensor.mutable_data<int>(make_ddim({3, 3}), CPUPlace());
+
+    int arr[9] = {1, 2, 3, 4, 5, 6, 7, 8, 9};
+    memcpy(src_ptr, arr, 9 * sizeof(int));
+
+    auto cpu_place = new paddle::platform::CPUPlace();
+    dst_tensor.CopyFrom<int>(src_tensor, *cpu_place);
+
+    const int* dst_ptr = dst_tensor.data<int>();
+    ASSERT_NE(src_ptr, dst_ptr);
+    for (size_t i = 0; i < 9; ++i) {
+      EXPECT_EQ(src_ptr[i], dst_ptr[i]);
+    }
+
+    Tensor slice_tensor = src_tensor.Slice<int>(1, 2);
+    dst_tensor.CopyFrom<int>(slice_tensor, *cpu_place);
+    const int* slice_ptr = slice_tensor.data<int>();
+    dst_ptr = dst_tensor.data<int>();
+    ASSERT_NE(dst_ptr, slice_ptr);
+    for (size_t i = 0; i < 3; ++i) {
+      EXPECT_EQ(dst_ptr[i], slice_ptr[i]);
+    }
+  }
+#ifndef PADDLE_ONLY_CPU
+  {
+    Tensor src_tensor;
+    Tensor gpu_tensor;
+    Tensor dst_tensor;
+
+    int* src_ptr = src_tensor.mutable_data<int>(make_ddim({3, 3}), CPUPlace());
+
+    int arr[9] = {1, 2, 3, 4, 5, 6, 7, 8, 9};
+    memcpy(src_ptr, arr, 9 * sizeof(int));
+
+    // CPU Tensor to GPU Tensor
+    auto gpu_place = new paddle::platform::GPUPlace(0);
+    gpu_tensor.CopyFrom<int>(src_tensor, *gpu_place);
+
+    // GPU Tensor to CPU Tensor
+    auto cpu_place = new paddle::platform::CPUPlace();
+    dst_tensor.CopyFrom<int>(gpu_tensor, *cpu_place);
+
+    // Compare Tensors
+    const int* dst_ptr = dst_tensor.data<int>();
+    ASSERT_NE(src_ptr, dst_ptr);
+    for (size_t i = 0; i < 9; ++i) {

@@ -202,11 +243,20 @@ TEST(Tensor, CopyFrom) {
       EXPECT_EQ(src_ptr[i], dst_ptr[i]);
     }

     Tensor slice_tensor = src_tensor.Slice<int>(1, 2);
-    dst_tensor.CopyFrom<int>(slice_tensor, CPUPlace());
+
+    // CPU Slice Tensor to GPU Tensor
+    gpu_tensor.CopyFrom<int>(slice_tensor, *gpu_place);
+
+    // GPU Tensor to CPU Tensor
+    dst_tensor.CopyFrom<int>(gpu_tensor, *cpu_place);
+
+    // Compare Slice Tensors
     const int* slice_ptr = slice_tensor.data<int>();
     dst_ptr = dst_tensor.data<int>();
     ASSERT_NE(dst_ptr, slice_ptr);
     for (size_t i = 0; i < 3; ++i) {
       EXPECT_EQ(dst_ptr[i], slice_ptr[i]);
     }
+  }
+#endif
 }
paddle/gserver/activations/ActivationFunction.cpp  View file @ 4cc42171

@@ -207,8 +207,8 @@ Error __must_check backward(Argument& act) {
     argument_.value->setData(act.value->getData() + offset, 1UL, size);
     argument_.grad->setData(act.grad->getData() + offset, 1UL, size);

-    Error status = softmax_.backward(argument_);
-    if (!status) return status;
+    Error err = softmax_.backward(argument_);
+    if (!err.isOK()) return err;
   }
   return Error();
 }
paddle/memory/detail/buddy_allocator.cc
浏览文件 @
4cc42171
...
@@ -27,12 +27,11 @@ BuddyAllocator::BuddyAllocator(SystemAllocator* system_allocator,
...
@@ -27,12 +27,11 @@ BuddyAllocator::BuddyAllocator(SystemAllocator* system_allocator,
system_allocator_
(
std
::
move
(
system_allocator
))
{}
system_allocator_
(
std
::
move
(
system_allocator
))
{}
BuddyAllocator
::~
BuddyAllocator
()
{
BuddyAllocator
::~
BuddyAllocator
()
{
DLOG
(
INFO
)
<<
"BuddyAllocator Disconstructor makes sure that all of these "
VLOG
(
3
)
<<
"BuddyAllocator Disconstructor makes sure that all of these "
"have actually been freed"
;
"have actually been freed"
;
while
(
!
pool_
.
empty
())
{
while
(
!
pool_
.
empty
())
{
auto
block
=
static_cast
<
MemoryBlock
*>
(
std
::
get
<
2
>
(
*
pool_
.
begin
()));
auto
block
=
static_cast
<
MemoryBlock
*>
(
std
::
get
<
2
>
(
*
pool_
.
begin
()));
DLOG
(
INFO
)
<<
"Free from block ("
<<
block
<<
", "
<<
max_chunk_size_
VLOG
(
3
)
<<
"Free from block ("
<<
block
<<
", "
<<
max_chunk_size_
<<
")"
;
<<
")"
;
system_allocator_
->
Free
(
block
,
max_chunk_size_
,
block
->
index
(
cache_
));
system_allocator_
->
Free
(
block
,
max_chunk_size_
,
block
->
index
(
cache_
));
cache_
.
invalidate
(
block
);
cache_
.
invalidate
(
block
);
...
@@ -52,12 +51,11 @@ void* BuddyAllocator::Alloc(size_t unaligned_size) {
...
@@ -52,12 +51,11 @@ void* BuddyAllocator::Alloc(size_t unaligned_size) {
// acquire the allocator lock
// acquire the allocator lock
std
::
lock_guard
<
std
::
mutex
>
lock
(
mutex_
);
std
::
lock_guard
<
std
::
mutex
>
lock
(
mutex_
);
DLOG
(
INFO
)
<<
"Allocate "
<<
unaligned_size
<<
" bytes from chunk size "
+  VLOG(3) << "Allocate " << unaligned_size << " bytes from chunk size "
+          << size;

   // if the allocation is huge, send directly to the system allocator
   if (size > max_chunk_size_) {
-    DLOG(INFO) << "Allocate from system allocator.";
+    VLOG(3) << "Allocate from system allocator.";
     return SystemAlloc(size);
   }
...
@@ -72,7 +70,7 @@ void* BuddyAllocator::Alloc(size_t unaligned_size) {
       return nullptr;
     }
   } else {
-    DLOG(INFO) << "Allocation from existing memory block " << std::get<2>(*it)
+    VLOG(3) << "Allocation from existing memory block " << std::get<2>(*it)
               << " at address "
               << reinterpret_cast<MemoryBlock*>(std::get<2>(*it))->data();
   }
...
@@ -91,10 +89,10 @@ void BuddyAllocator::Free(void* p) {
   // Acquire the allocator lock
   std::lock_guard<std::mutex> lock(mutex_);
-  DLOG(INFO) << "Free from address " << block;
+  VLOG(3) << "Free from address " << block;
   if (block->type(cache_) == MemoryBlock::HUGE_CHUNK) {
-    DLOG(INFO) << "Free directly from system allocator";
+    VLOG(3) << "Free directly from system allocator";
     system_allocator_->Free(block, block->total_size(cache_),
                             block->index(cache_));
...
@@ -111,7 +109,7 @@ void BuddyAllocator::Free(void* p) {
   // Trying to merge the right buddy
   if (block->has_right_buddy(cache_)) {
-    DLOG(INFO) << "Merging this block " << block << " with its right buddy "
+    VLOG(3) << "Merging this block " << block << " with its right buddy "
              << block->right_buddy(cache_);
     auto right_buddy = block->right_buddy(cache_);
...
@@ -129,7 +127,7 @@ void BuddyAllocator::Free(void* p) {
   // Trying to merge the left buddy
   if (block->has_left_buddy(cache_)) {
-    DLOG(INFO) << "Merging this block " << block << " with its left buddy "
+    VLOG(3) << "Merging this block " << block << " with its left buddy "
              << block->left_buddy(cache_);
     auto left_buddy = block->left_buddy(cache_);
...
@@ -146,7 +144,7 @@ void BuddyAllocator::Free(void* p) {
   }
   // Dumping this block into pool
-  DLOG(INFO) << "Inserting free block (" << block << ", "
+  VLOG(3) << "Inserting free block (" << block << ", "
           << block->total_size(cache_) << ")";
   pool_.insert(
       IndexSizeAddress(block->index(cache_), block->total_size(cache_), block));
...
@@ -166,7 +164,7 @@ void* BuddyAllocator::SystemAlloc(size_t size) {
   size_t index = 0;
   void* p = system_allocator_->Alloc(index, size);
-  DLOG(INFO) << "Allocated " << p << " from system allocator.";
+  VLOG(3) << "Allocated " << p << " from system allocator.";
   if (p == nullptr) return nullptr;
...
@@ -192,7 +190,7 @@ BuddyAllocator::PoolSet::iterator BuddyAllocator::RefillPool() {
   if (p == nullptr) return pool_.end();
-  DLOG(INFO) << "Creating and inserting new block " << p
+  VLOG(3) << "Creating and inserting new block " << p
           << " from system allocator";
   static_cast<MemoryBlock*>(p)->init(cache_, MemoryBlock::FREE_CHUNK, index,
...
@@ -237,18 +235,18 @@ void* BuddyAllocator::SplitToAlloc(BuddyAllocator::PoolSet::iterator it,
   auto block = static_cast<MemoryBlock*>(std::get<2>(*it));
   pool_.erase(it);
-  DLOG(INFO) << "Split block (" << block << ", " << block->total_size(cache_)
+  VLOG(3) << "Split block (" << block << ", " << block->total_size(cache_)
          << ") into";
   block->split(cache_, size);
-  DLOG(INFO) << "Left block (" << block << ", " << block->total_size(cache_)
+  VLOG(3) << "Left block (" << block << ", " << block->total_size(cache_)
          << ")";
   block->set_type(cache_, MemoryBlock::ARENA_CHUNK);
   // the rest of memory if exist
   if (block->has_right_buddy(cache_)) {
     if (block->right_buddy(cache_)->type(cache_) == MemoryBlock::FREE_CHUNK) {
-      DLOG(INFO) << "Insert right block (" << block->right_buddy(cache_) << ", "
+      VLOG(3) << "Insert right block (" << block->right_buddy(cache_) << ", "
              << block->right_buddy(cache_)->total_size(cache_) << ")";
       pool_.insert(
...
@@ -276,7 +274,7 @@ void BuddyAllocator::CleanIdleFallBackAlloc() {
       return;
     }
-    DLOG(INFO) << "Return block " << block << " to fallback allocator.";
+    VLOG(3) << "Return block " << block << " to fallback allocator.";
     system_allocator_->Free(block, max_chunk_size_, block->index(cache_));
     cache_.invalidate(block);
...
@@ -312,7 +310,7 @@ void BuddyAllocator::CleanIdleNormalAlloc() {
     MemoryBlock* block = static_cast<MemoryBlock*>(std::get<2>(*pool));
-    DLOG(INFO) << "Return block " << block << " to base allocator.";
+    VLOG(3) << "Return block " << block << " to base allocator.";
     system_allocator_->Free(block, max_chunk_size_, block->index(cache_));
     cache_.invalidate(block);
...
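A note on the change above: DLOG(INFO) is compiled out of release builds, while VLOG(3) stays in the binary and is emitted only when glog's verbosity is raised at run time (for example GLOG_v=3 or --v=3). A minimal, illustrative sketch of driving such verbose tracing; only InitGoogleLogging, FLAGS_logtostderr and VLOG are assumed from glog, and the message text is made up:

#include <glog/logging.h>

// Run as: GLOG_v=3 ./trace_demo   (verbosity >= 3 enables the VLOG(3) line)
int main(int argc, char* argv[]) {
  google::InitGoogleLogging(argv[0]);
  FLAGS_logtostderr = true;           // write to stderr instead of log files
  VLOG(3) << "Allocate 256 bytes";    // shown only when verbosity >= 3
  LOG(INFO) << "always shown at INFO severity";
  return 0;
}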
paddle/memory/memcpy.cc
...
@@ -35,7 +35,7 @@ void Copy<platform::CPUPlace, platform::GPUPlace>(platform::CPUPlace dst_place,
                                                   platform::GPUPlace src_place,
                                                   const void* src, size_t num,
                                                   cudaStream_t stream) {
-  platform::GPUPlaceGuard g(src_place.device);
+  platform::SetDeviceId(src_place.device);
   platform::GpuMemcpyAsync(dst, src, num, cudaMemcpyDeviceToHost, stream);
 }
...
@@ -45,7 +45,7 @@ void Copy<platform::GPUPlace, platform::CPUPlace>(platform::GPUPlace dst_place,
                                                   platform::CPUPlace src_place,
                                                   const void* src, size_t num,
                                                   cudaStream_t stream) {
-  platform::GPUPlaceGuard g(dst_place.device);
+  platform::SetDeviceId(dst_place.device);
   platform::GpuMemcpyAsync(dst, src, num, cudaMemcpyHostToDevice, stream);
 }
...
@@ -56,7 +56,7 @@ void Copy<platform::GPUPlace, platform::GPUPlace>(platform::GPUPlace dst_place,
                                                   const void* src, size_t num,
                                                   cudaStream_t stream) {
   if (dst_place == src_place) {
-    platform::GPUPlaceGuard g(src_place.device);
+    platform::SetDeviceId(src_place.device);
     platform::GpuMemcpyAsync(dst, src, num, cudaMemcpyDeviceToDevice, stream);
   } else {
     platform::GpuMemcpyPeer(dst, dst_place.device, src, src_place.device, num,
...
paddle/memory/memcpy.h
...
@@ -20,13 +20,39 @@ limitations under the License. */
 namespace paddle {
 namespace memory {

+/**
+ * \brief   Copy memory from one place to another place.
+ *
+ * \param[in]  DstPlace  Destination allocation place (CPU).
+ * \param[in]  dst       Destination memory address.
+ * \param[in]  SrcPlace  Source allocation place (CPU).
+ * \param[in]  src       Source memory address.
+ * \param[in]  num       memory size in bytes to copy.
+ *
+ */
 template <typename DstPlace, typename SrcPlace>
 void Copy(DstPlace, void* dst, SrcPlace, const void* src, size_t num);

 #ifndef PADDLE_ONLY_CPU

+/**
+ * \brief   Copy memory from one place to another place.
+ *
+ * \param[in]  DstPlace  Destination allocation place (CPU or GPU).
+ * \param[in]  dst       Destination memory address.
+ * \param[in]  SrcPlace  Source allocation place (CPU or GPU).
+ * \param[in]  src       Source memory address.
+ * \param[in]  num       memory size in bytes to copy.
+ * \param[in]  stream    CUDA stream.
+ *
+ * \note    For GPU memory copy, CUDA stream need to be specified
+ *          for asynchronously memory copy.
+ *
+ */
 template <typename DstPlace, typename SrcPlace>
 void Copy(DstPlace, void* dst, SrcPlace, const void* src, size_t num,
           cudaStream_t stream);

 #endif  // PADDLE_ONLY_CPU

 }  // namespace memory
...
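The comments added above document the two memory::Copy overloads. A small illustrative sketch of the synchronous CPU-to-CPU overload; the buffer names and sizes are invented, and the include of paddle/platform/place.h for CPUPlace is an assumption:

#include <vector>
#include "paddle/memory/memcpy.h"
#include "paddle/platform/place.h"  // assumed location of CPUPlace

void CopyExample() {
  paddle::platform::CPUPlace cpu;
  std::vector<float> src(128, 1.0f), dst(128, 0.0f);
  // Copy 128 floats from src to dst on the host; the GPU overloads take an
  // extra cudaStream_t so the copy can be queued asynchronously.
  paddle::memory::Copy(cpu, dst.data(), cpu, src.data(),
                       src.size() * sizeof(float));
}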
paddle/memory/memory.cc
...
@@ -60,6 +60,7 @@ detail::BuddyAllocator* GetGPUBuddyAllocator(int gpu_id) {
                                     platform::GpuMaxChunkSize());
     }
   }
+  platform::SetDeviceId(gpu_id);
   return as[gpu_id];
 }
...
paddle/memory/memory.h
...
@@ -20,19 +20,53 @@ limitations under the License. */
 namespace paddle {
 namespace memory {

+/**
+ * \brief   Allocate memory block in one place.
+ *
+ * \param[in]  place  Allocation place (CPU or GPU).
+ * \param[in]  size   Allocation size.
+ *
+ * \return  Allocated memory block address.
+ *
+ * \note    If return nullptr, it indicates memory allocation failed
+ *          because insufficient memory in current system. When Alloc
+ *          function is invoked, you must check the returned memory
+ *          address is valid or not.
+ */
 template <typename Place>
-void* Alloc(Place, size_t);
+void* Alloc(Place place, size_t size);

+/**
+ * \brief   Free memory block in one place.
+ *
+ * \param[in]  place  Allocation place (CPU or GPU).
+ * \param[in]  ptr    Memory block address to free.
+ *
+ */
 template <typename Place>
-void Free(Place, void*);
+void Free(Place place, void* ptr);

+/**
+ * \brief   Total size of used memory in one place.
+ *
+ * \param[in]  place  Allocation place (CPU or GPU).
+ *
+ */
 template <typename Place>
-size_t Used(Place);
+size_t Used(Place place);

-template <typename T, /* must be POD types */
-          typename Place /* platform::GPUPlace or platform::CPUPlace */,
-          typename std::enable_if<std::is_pod<T>::value>::type* = nullptr>
+/**
+ * \brief   Free memory block in one place.
+ *
+ * \note    In some cases, custom deleter is used to
+ *          deallocate the memory automatically for
+ *          std::unique_ptr<T> in tensor.h.
+ *
+ */
+template <typename T, typename Place>
 class PODDeleter {
+  static_assert(std::is_pod<T>::value, "T must be POD");
+
  public:
   PODDeleter(Place place) : place_(place) {}
   void operator()(T* ptr) { Free(place_, static_cast<void*>(ptr)); }
...
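As the new comments stress, Alloc may return nullptr and the caller must check it, and PODDeleter exists so a raw allocation can be owned by a std::unique_ptr. A minimal illustrative sketch under those assumptions; the helper name, sizes, and the place.h include are invented or assumed:

#include <memory>
#include "paddle/memory/memory.h"
#include "paddle/platform/place.h"  // assumed location of CPUPlace

void MemoryExample() {
  paddle::platform::CPUPlace cpu;

  // Raw allocate/free; always check for nullptr before using the block.
  void* p = paddle::memory::Alloc(cpu, 1024);
  if (p != nullptr) {
    paddle::memory::Free(cpu, p);
  }

  // Hand ownership to a unique_ptr that releases through PODDeleter.
  using Deleter = paddle::memory::PODDeleter<float, paddle::platform::CPUPlace>;
  std::unique_ptr<float, Deleter> buf(
      static_cast<float*>(paddle::memory::Alloc(cpu, 256 * sizeof(float))),
      Deleter(cpu));
}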
paddle/operators/CMakeLists.txt
...
@@ -54,3 +54,8 @@ op_library(fc_op SRCS fc_op.cc DEPS mul_op rowwise_add_op sigmoid_op
     softmax_op net)

 op_library(sgd_op SRCS sgd_op.cc sgd_op.cu)
+
+op_library(recurrent_network_op SRCS recurrent_network_op.cc DEPS op_desc
+        tensor op_registry operator net)
+cc_test(recurrent_network_op_test SRCS recurrent_network_op_test.cc DEPS
+        recurrent_network_op gtest mul_op add_op)
paddle/operators/add_op.cc
...
@@ -13,17 +13,14 @@ See the License for the specific language governing permissions and
 limitations under the License. */

 #include "paddle/operators/add_op.h"
-#include "paddle/framework/op_registry.h"
-#include "paddle/framework/tensor.h"

 namespace paddle {
 namespace operators {

-class AddOp : public framework::OperatorWithKernel {
+class AddOp : public OperatorWithKernel {
  protected:
-  void InferShape(
-      const std::vector<const framework::Tensor*>& inputs,
-      const std::vector<framework::Tensor*>& outputs) const override {
+  void InferShape(const std::vector<const Tensor*>& inputs,
+                  const std::vector<Tensor*>& outputs) const override {
     PADDLE_ENFORCE(inputs.size() == 2, "Input size of AddOp must be two");
     PADDLE_ENFORCE(outputs.size() == 1, "Output size of AddOp must be one");
     PADDLE_ENFORCE(
...
@@ -35,10 +32,10 @@ protected:
   }
 };

-class AddOpMaker : public framework::OpProtoAndCheckerMaker {
+class AddOpMaker : public OpProtoAndCheckerMaker {
  public:
-  AddOpMaker(framework::OpProto* proto, framework::OpAttrChecker* op_checker)
-      : framework::OpProtoAndCheckerMaker(proto, op_checker) {
+  AddOpMaker(OpProto* proto, OpAttrChecker* op_checker)
+      : OpProtoAndCheckerMaker(proto, op_checker) {
     AddInput("X", "The first input of add op");
     AddInput("Y", "The second input of add op");
     AddOutput("Out", "The output of add op");
...
@@ -50,11 +47,10 @@ The equation is: Out = X + Y
   }
 };

-class AddOpGrad : public framework::OperatorWithKernel {
+class AddOpGrad : public OperatorWithKernel {
  protected:
-  void InferShape(
-      const std::vector<const framework::Tensor*>& inputs,
-      const std::vector<framework::Tensor*>& outputs) const override {}
+  void InferShape(const std::vector<const Tensor*>& inputs,
+                  const std::vector<Tensor*>& outputs) const override {}
   std::string DebugString() const override {
     LOG(INFO) << "AddOpGrad";
     return "";
...
@@ -64,7 +60,6 @@ protected:
 }  // namespace operators
 }  // namespace paddle

-REGISTER_OP(add_two, paddle::operators::AddOp, paddle::operators::AddOpMaker);
-REGISTER_GRADIENT_OP(add_two, add_two_grad, paddle::operators::AddOpGrad);
-REGISTER_OP_CPU_KERNEL(
-    add_two, paddle::operators::AddKernel<paddle::platform::CPUPlace, float>);
+REGISTER_OP(add_two, ops::AddOp, ops::AddOpMaker);
+REGISTER_GRADIENT_OP(add_two, add_two_grad, ops::AddOpGrad);
+REGISTER_OP_CPU_KERNEL(
+    add_two, ops::AddKernel<ops::CPUPlace, float>);
paddle/operators/add_op.cu
-#include "paddle/operators/add_op.h"
 #include "paddle/framework/op_registry.h"
+#include "paddle/operators/add_op.h"

-REGISTER_OP_GPU_KERNEL(add_two,
-                       paddle::operators::AddKernel<paddle::platform::GPUPlace, float>);
+REGISTER_OP_GPU_KERNEL(add_two,
+                       ops::AddKernel<ops::GPUPlace, float>);
\ No newline at end of file
paddle/operators/add_op.h
...
@@ -13,27 +13,24 @@ See the License for the specific language governing permissions and
 limitations under the License. */

 #pragma once
-#include "glog/logging.h"
-#include "paddle/framework/eigen.h"
-#include "paddle/framework/operator.h"
+#include "paddle/operators/type_alias.h"

 namespace paddle {
 namespace operators {

 template <typename Place, typename T>
-class AddKernel : public framework::OpKernel {
+class AddKernel : public OpKernel {
  public:
-  void Compute(const framework::KernelContext& context) const override {
-    auto input0 = context.Input(0)->Get<framework::Tensor>();
-    auto input1 = context.Input(1)->Get<framework::Tensor>();
-    auto* output = context.Output(0)->GetMutable<framework::Tensor>();
+  void Compute(const KernelContext& context) const override {
+    auto input0 = context.Input(0)->Get<Tensor>();
+    auto input1 = context.Input(1)->Get<Tensor>();
+    auto output = context.Output(0)->GetMutable<Tensor>();

     output->mutable_data<T>(context.GetPlace());

-    framework::EigenVector<T>::Flatten(*output).device(
-        *(context.GetEigenDevice<Place>())) =
-        framework::EigenVector<T>::Flatten(input0) +
-        framework::EigenVector<T>::Flatten(input1);
+    EigenVector<T>::Flatten(*output).device(
+        *(context.GetEigenDevice<Place>())) =
+        EigenVector<T>::Flatten(input0) + EigenVector<T>::Flatten(input1);
   }
 };
...
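The AddKernel body above flattens both inputs and the output into 1-D Eigen views and assigns their element-wise sum on the kernel's device. A standalone illustrative sketch of the same idea with raw Eigen tensors, independent of the Paddle wrappers; the buffer names are invented:

#include <unsupported/Eigen/CXX11/Tensor>

void EigenAddExample() {
  float x[6] = {1, 1, 1, 1, 1, 1};
  float y[6] = {2, 2, 2, 2, 2, 2};
  float out[6];

  // Flat 1-D views over the raw buffers, mirroring EigenVector<T>::Flatten().
  Eigen::TensorMap<Eigen::Tensor<float, 1>> xm(x, 6), ym(y, 6), outm(out, 6);
  outm = xm + ym;  // element-wise sum; every entry of out becomes 3.0f
}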
paddle/operators/cross_entropy_op.cc
...
@@ -13,17 +13,14 @@ See the License for the specific language governing permissions and
 limitations under the License. */

 #include "paddle/operators/cross_entropy_op.h"
-#include "paddle/framework/op_registry.h"
-#include "paddle/framework/tensor.h"

 namespace paddle {
 namespace operators {

-class OnehotCrossEntropyOp : public framework::OperatorWithKernel {
+class OnehotCrossEntropyOp : public OperatorWithKernel {
  protected:
-  void InferShape(
-      const std::vector<const framework::Tensor*>& inputs,
-      const std::vector<framework::Tensor*>& outputs) const override {
+  void InferShape(const std::vector<const Tensor*>& inputs,
+                  const std::vector<Tensor*>& outputs) const override {
     PADDLE_ENFORCE(inputs.size() == 2,
                    "Input size of OnehotCrossEntropyOp must be two");
     PADDLE_ENFORCE(outputs.size() == 1,
...
@@ -35,15 +32,14 @@ protected:
     PADDLE_ENFORCE(inputs[0]->dims().size() == 2, "X's dimension must be 2.");
     PADDLE_ENFORCE(outputs[0]->dims().size() == 1,
                    "label's dimension must be 1.");
-    outputs[0]->Resize(framework::make_ddim({inputs[0]->dims()[0]}));
+    outputs[0]->Resize({inputs[0]->dims()[0]});
   }
 };

-class OnehotCrossEntropyOpMaker : public framework::OpProtoAndCheckerMaker {
+class OnehotCrossEntropyOpMaker : public OpProtoAndCheckerMaker {
  public:
-  OnehotCrossEntropyOpMaker(framework::OpProto* proto,
-                            framework::OpAttrChecker* op_checker)
-      : framework::OpProtoAndCheckerMaker(proto, op_checker) {
+  OnehotCrossEntropyOpMaker(OpProto* proto, OpAttrChecker* op_checker)
+      : OpProtoAndCheckerMaker(proto, op_checker) {
     AddInput("X", "The first input of OnehotCrossEntropyOp");
     AddInput("label", "The second input of OnehotCrossEntropyOp");
     AddOutput("Y", "The output of OnehotCrossEntropyOp");
...
@@ -59,9 +55,7 @@ OnehotCrossEntropy Operator.
 }  // namespace paddle

-REGISTER_OP(onehot_cross_entropy,
-            paddle::operators::OnehotCrossEntropyOp,
-            paddle::operators::OnehotCrossEntropyOpMaker);
-REGISTER_OP_CPU_KERNEL(
-    onehot_cross_entropy,
-    paddle::operators::OnehotCrossEntropyOpKernel<::paddle::platform::CPUPlace, float>);
+REGISTER_OP(onehot_cross_entropy, ops::OnehotCrossEntropyOp,
+            ops::OnehotCrossEntropyOpMaker);
+REGISTER_OP_CPU_KERNEL(
+    onehot_cross_entropy,
+    ops::OnehotCrossEntropyOpKernel<ops::CPUPlace, float>);
paddle/operators/cross_entropy_op.cu
 #include "paddle/operators/cross_entropy_op.h"
-#include "paddle/framework/op_registry.h"

-REGISTER_OP_GPU_KERNEL(onehot_cross_entropy,
-                       paddle::operators::OnehotCrossEntropyOpKernel<::paddle::platform::GPUPlace, float>);
+REGISTER_OP_GPU_KERNEL(onehot_cross_entropy,
+                       ops::OnehotCrossEntropyOpKernel<ops::GPUPlace, float>);
\ No newline at end of file
paddle/operators/cross_entropy_op.h
...
@@ -13,23 +13,21 @@ See the License for the specific language governing permissions and
 limitations under the License. */

 #pragma once
-#include "glog/logging.h"
-#include "paddle/framework/operator.h"
+#include "paddle/operators/type_alias.h"

 namespace paddle {
 namespace operators {

 template <typename Place, typename T>
-class OnehotCrossEntropyOpKernel : public framework::OpKernel {
+class OnehotCrossEntropyOpKernel : public OpKernel {
  public:
   constexpr T LOG_THRESHOLD() const { return static_cast<T>(1e-20); }

-  void Compute(const framework::KernelContext& context) const override {
-    auto X = context.Input(0)->Get<framework::Tensor>();
+  void Compute(const KernelContext& context) const override {
+    auto X = context.Input(0)->Get<Tensor>();
     const T* X_data = X.data<T>();
-    const int* label_data =
-        context.Input(1)->Get<framework::Tensor>().data<int>();
-    auto* Y = context.Output(0)->GetMutable<framework::Tensor>();
+    const int* label_data = context.Input(1)->Get<Tensor>().data<int>();
+    auto* Y = context.Output(0)->GetMutable<Tensor>();

     Y->mutable_data<T>(context.GetPlace());
...
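The LOG_THRESHOLD constant above suggests the predicted probability is clipped before taking the log, so the per-row loss is Y[i] = -log(max(X[i][label[i]], threshold)). A plain, illustrative restatement of that computation outside the operator machinery; all names are invented:

#include <algorithm>
#include <cmath>
#include <vector>

// Y[i] = -log(max(X[i * num_classes + label[i]], threshold))
std::vector<float> OnehotCrossEntropy(const std::vector<float>& X,
                                      const std::vector<int>& label,
                                      int num_classes,
                                      float threshold = 1e-20f) {
  std::vector<float> Y(label.size());
  for (size_t i = 0; i < label.size(); ++i) {
    float p = X[i * num_classes + label[i]];  // probability of the true class
    Y[i] = -std::log(std::max(p, threshold));
  }
  return Y;
}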
paddle/operators/fc_op.cc
...
@@ -12,41 +12,38 @@
 See the License for the specific language governing permissions and
 limitations under the License. */

-#include "paddle/framework/net.h"
-#include "paddle/framework/op_registry.h"
-#include "paddle/framework/operator.h"
+#include "type_alias.h"

 namespace paddle {
 namespace operators {

-class FullyConnectedOp : public framework::PlainNet {
+class FullyConnectedOp : public NetOp {
  public:
   void Init() override {
-    AddOp(framework::OpRegistry::CreateOp("mul",
-                                          {
-                                              Input("X"), Input("W"),
-                                          },
-                                          {Output("before_act")},
-                                          {}));
+    AddOp(OpRegistry::CreateOp("mul",
+                               {
+                                   Input("X"), Input("W"),
+                               },
+                               {Output("before_act")},
+                               {}));
     auto b = Input("b");
-    if (b != framework::OperatorBase::EMPTY_VAR_NAME()) {
-      AddOp(framework::OpRegistry::CreateOp("rowwise_add",
-                                            {Output("before_act"), Input("b")},
-                                            {Output("before_act")},
-                                            {}));
+    if (b != EMPTY_VAR_NAME()) {
+      AddOp(OpRegistry::CreateOp("rowwise_add",
+                                 {Output("before_act"), Input("b")},
+                                 {Output("before_act")},
+                                 {}));
     }
     auto activation = GetAttr<std::string>("activation");
-    AddOp(framework::OpRegistry::CreateOp(
-        activation, {Output("before_act")}, {Output("Y")}, {}));
+    AddOp(OpRegistry::CreateOp(
+        activation, {Output("before_act")}, {Output("Y")}, {}));
     CompleteAddOp(false);
   }
 };

-class FullyConnectedOpMaker : public framework::OpProtoAndCheckerMaker {
+class FullyConnectedOpMaker : public OpProtoAndCheckerMaker {
  public:
-  FullyConnectedOpMaker(framework::OpProto* proto,
-                        framework::OpAttrChecker* op_checker)
-      : OpProtoAndCheckerMaker(proto, op_checker) {
+  FullyConnectedOpMaker(OpProto* proto, OpAttrChecker* op_checker)
+      : OpProtoAndCheckerMaker(proto, op_checker) {
     AddInput("X", "the input of fc operator");
     AddInput("W", "the weight of fc operator");
...
@@ -71,6 +68,4 @@ USE_OP(rowwise_add);
 USE_OP(sigmoid);
 USE_OP(softmax);

-REGISTER_OP(fc,
-            paddle::operators::FullyConnectedOp,
-            paddle::operators::FullyConnectedOpMaker);
+REGISTER_OP(fc, ops::FullyConnectedOp, ops::FullyConnectedOpMaker);
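FullyConnectedOp above is not a kernel of its own: it chains a mul op, an optional rowwise_add for the bias, and the chosen activation, i.e. Y = activation(X * W + b). A small, illustrative numeric sketch of that composition in plain C++ with a sigmoid activation; shapes and names are invented:

#include <cmath>
#include <vector>

// Y = sigmoid(X * W + b), with X: [n, k], W: [k, m], b: [m], row-major buffers.
std::vector<float> FullyConnected(const std::vector<float>& X,
                                  const std::vector<float>& W,
                                  const std::vector<float>& b,
                                  int n, int k, int m) {
  std::vector<float> Y(n * m, 0.0f);
  for (int i = 0; i < n; ++i) {
    for (int j = 0; j < m; ++j) {
      float acc = b[j];                                   // rowwise_add step
      for (int t = 0; t < k; ++t) {
        acc += X[i * k + t] * W[t * m + j];               // mul step
      }
      Y[i * m + j] = 1.0f / (1.0f + std::exp(-acc));      // activation step
    }
  }
  return Y;
}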
paddle/operators/mul_op.cc
...
@@ -13,17 +13,14 @@
 limitations under the License. */

 #include "paddle/operators/mul_op.h"
-#include "paddle/framework/op_registry.h"
-#include "paddle/framework/tensor.h"

 namespace paddle {
 namespace operators {

-class MulOp : public framework::OperatorWithKernel {
+class MulOp : public OperatorWithKernel {
  protected:
-  void InferShape(
-      const std::vector<const framework::Tensor*>& inputs,
-      const std::vector<framework::Tensor*>& outputs) const override {
+  void InferShape(const std::vector<const Tensor*>& inputs,
+                  const std::vector<Tensor*>& outputs) const override {
     PADDLE_ENFORCE(inputs.size() == 2, "The mul op must take two inputs");
     auto dim0 = inputs[0]->dims();
     auto dim1 = inputs[1]->dims();
...
@@ -37,10 +34,10 @@ protected:
   }
 };

-class MulOpMaker : public framework::OpProtoAndCheckerMaker {
+class MulOpMaker : public OpProtoAndCheckerMaker {
  public:
-  MulOpMaker(framework::OpProto* proto, framework::OpAttrChecker* op_checker)
-      : framework::OpProtoAndCheckerMaker(proto, op_checker) {
+  MulOpMaker(OpProto* proto, OpAttrChecker* op_checker)
+      : OpProtoAndCheckerMaker(proto, op_checker) {
     AddInput("X", "The first input of mul op");
     AddInput("Y", "The second input of mul op");
     AddOutput("Out", "The output of mul op");
...
@@ -52,11 +49,10 @@ The equation is: Out = X * Y
   }
 };

-class MulOpGrad : public framework::OperatorWithKernel {
+class MulOpGrad : public OperatorWithKernel {
  protected:
-  void InferShape(
-      const std::vector<const framework::Tensor*>& inputs,
-      const std::vector<framework::Tensor*>& outputs) const override {}
+  void InferShape(const std::vector<const Tensor*>& inputs,
+                  const std::vector<Tensor*>& outputs) const override {}
   std::string DebugString() const override {
     LOG(INFO) << "MulGrad";
     return "";
...
@@ -66,8 +62,7 @@ protected:
 }  // namespace operators
 }  // namespace paddle

-REGISTER_OP(mul, paddle::operators::MulOp, paddle::operators::MulOpMaker);
-REGISTER_GRADIENT_OP(mul, mul_grad, paddle::operators::MulOpGrad);
-REGISTER_OP_CPU_KERNEL(
-    mul, paddle::operators::MulKernel<paddle::platform::CPUPlace, float>);
+REGISTER_OP(mul, ops::MulOp, ops::MulOpMaker);
+REGISTER_GRADIENT_OP(mul, mul_grad, ops::MulOpGrad);
+REGISTER_OP_CPU_KERNEL(mul, ops::MulKernel<ops::CPUPlace, float>);
paddle/operators/mul_op.cu
...
@@ -13,8 +13,5 @@
 limitations under the License. */

 #include "paddle/operators/mul_op.h"
-#include "paddle/framework/op_registry.h"

-REGISTER_OP_GPU_KERNEL(mul,
-                       paddle::operators::MulKernel<paddle::platform::GPUPlace, float>);
+REGISTER_OP_GPU_KERNEL(mul, ops::MulKernel<ops::GPUPlace, float>);
\ No newline at end of file
paddle/operators/mul_op.h
...
@@ -14,30 +14,27 @@
 #pragma once

-#include "glog/logging.h"
-#include "paddle/framework/eigen.h"
-#include "paddle/framework/operator.h"
+#include "paddle/operators/type_alias.h"

 namespace paddle {
 namespace operators {

 template <typename Place, typename T>
-class MulKernel : public framework::OpKernel {
+class MulKernel : public OpKernel {
  public:
-  void Compute(const framework::KernelContext& context) const override {
+  void Compute(const KernelContext& context) const override {
     Eigen::array<Eigen::IndexPair<Eigen::DenseIndex>, 1> dim_pair = {
         {Eigen::IndexPair<Eigen::DenseIndex>(1, 0)}};

-    auto input0 = context.Input(0)->Get<framework::Tensor>();
-    auto input1 = context.Input(1)->Get<framework::Tensor>();
-    auto* output = context.Output(0)->GetMutable<framework::Tensor>();
+    auto input0 = context.Input(0)->Get<Tensor>();
+    auto input1 = context.Input(1)->Get<Tensor>();
+    auto* output = context.Output(0)->GetMutable<Tensor>();

     output->mutable_data<T>(context.GetPlace());

-    framework::EigenMatrix<T>::From(*output).device(
-        *(context.GetEigenDevice<Place>())) =
-        framework::EigenMatrix<T>::From(input0).contract(
-            framework::EigenMatrix<T>::From(input1), dim_pair);
+    EigenMatrix<T>::From(*output).device(*(context.GetEigenDevice<Place>())) =
+        EigenMatrix<T>::From(input0).contract(EigenMatrix<T>::From(input1),
+                                              dim_pair);
   }
 };

 }  // namespace operators
...
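The dim_pair of {IndexPair(1, 0)} in MulKernel tells Eigen to contract dimension 1 of the first operand with dimension 0 of the second, which is exactly a matrix product. A standalone illustrative sketch with raw Eigen tensors; names and shapes are invented:

#include <unsupported/Eigen/CXX11/Tensor>

void EigenMatMulExample() {
  Eigen::Tensor<float, 2> a(2, 3), b(3, 4), c(2, 4);
  a.setRandom();
  b.setRandom();

  // Contract a's dim 1 with b's dim 0: c(i, j) = sum_k a(i, k) * b(k, j).
  Eigen::array<Eigen::IndexPair<Eigen::DenseIndex>, 1> dims = {
      {Eigen::IndexPair<Eigen::DenseIndex>(1, 0)}};
  c = a.contract(b, dims);
}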
paddle/operators/recurrent_network_op.cc
0 → 100644
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.

   Licensed under the Apache License, Version 2.0 (the "License");
   you may not use this file except in compliance with the License.
   You may obtain a copy of the License at

   http://www.apache.org/licenses/LICENSE-2.0

   Unless required by applicable law or agreed to in writing, software
   distributed under the License is distributed on an "AS IS" BASIS,
   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
   See the License for the specific language governing permissions and
   limitations under the License. */

#include "paddle/operators/recurrent_network_op.h"

#include <glog/logging.h>
#include <cstring>
#include <sstream>

#include "paddle/framework/net.h"
#include "paddle/framework/op_registry.h"
#include "paddle/platform/enforce.h"

namespace paddle {
namespace operators {

namespace rnn {

void SegmentInputs(std::vector<std::shared_ptr<Scope>>& step_scopes,
                   const std::vector<Link>& inlinks, const size_t seq_len) {
  PADDLE_ENFORCE(!inlinks.empty(), "no in links are provided.");
  for (size_t i = 0; i < inlinks.size(); ++i) {
    Tensor* input =
        step_scopes[0]->GetVariable(inlinks[i].external)->GetMutable<Tensor>();
    DDim dims = input->dims();
    PADDLE_ENFORCE(static_cast<size_t>(dims[0]) == seq_len,
                   "all the inlinks must have same length");
    DDim step_dims = slice_ddim(dims, 1, dims.size());
    for (size_t j = 0; j < seq_len; j++) {
      Tensor* step_input = step_scopes[j]
                               ->CreateVariable(inlinks[i].internal)
                               ->GetMutable<Tensor>();
      *step_input = input->Slice<float>(j, j + 1);
      step_input->Resize(step_dims);
    }
  }
}

void ConcatOutputs(std::vector<std::shared_ptr<Scope>>& step_scopes,
                   const std::vector<Link>& outlinks, const size_t seq_len) {
  for (size_t i = 0; i < outlinks.size(); i++) {
    Tensor* output =
        step_scopes[0]->GetVariable(outlinks[i].external)->GetMutable<Tensor>();

    // TODO(qingiqng) remove following code after adding
    // InferShape in RecurrentGradientOp
    DDim step_dims = step_scopes[0]
                         ->GetVariable(outlinks[i].internal)
                         ->GetMutable<Tensor>()
                         ->dims();
    std::vector<int> dims_vec = vectorize(step_dims);
    dims_vec.insert(dims_vec.begin(), seq_len);
    output->mutable_data<float>(make_ddim(dims_vec), platform::CPUPlace());

    for (size_t j = 0; j < seq_len; j++) {
      Tensor* step_output = step_scopes[j]
                                ->GetVariable(outlinks[i].internal)
                                ->GetMutable<Tensor>();
      // TODO(luotao02) data type and platform::DeviceContext() should set
      // correctly
      (output->Slice<float>(j, j + 1))
          .CopyFrom<float>(*step_output, platform::CPUPlace());
    }
  }
}

void LinkMemories(std::vector<std::shared_ptr<Scope>>& scopes,
                  const std::vector<rnn::MemoryAttr>& memories,
                  size_t step_id, int offset) {
  PADDLE_ENFORCE(step_id < scopes.size(),
                 "step [%d] is out of range of step scopes' size [%d]",
                 step_id, scopes.size());
  PADDLE_ENFORCE(static_cast<int>(step_id) + offset >= 0,
                 "offset [%d] must be large than -[%d]", offset, step_id);
  PADDLE_ENFORCE(step_id + offset < scopes.size(),
                 "offset [%d] is out of range, it must be less than (%d - %d)",
                 offset, scopes.size(), step_id);
  std::shared_ptr<Scope> scope = scopes[step_id];
  std::shared_ptr<Scope> linked_scope = scopes[step_id + offset];
  for (auto& attr : memories) {
    auto mem = scope->CreateVariable(attr.pre_var)->GetMutable<Tensor>();
    // maybe share variable is better?
    auto linked_mem = linked_scope->GetVariable(attr.var)->GetMutable<Tensor>();
    mem->ShareDataWith<float>(*linked_mem);

    // TODO(qingqing) remove following code
    // the memory of current step should be allocated in step net
    auto m = scope->CreateVariable(attr.var)->GetMutable<Tensor>();
    // for unit test, as addOp and mulOp are null currently, if not
    // mutable_data, mem.data() in output will be error. We will
    // remove this line after merge the correct addOp and mulOp.
    m->mutable_data<float>(mem->dims(), platform::CPUPlace());
  }
}

void InitArgument(const ArgumentName& name, Argument* arg,
                  const OperatorBase& op) {
  arg->step_net = op.Input(name.step_net);
  arg->step_scopes = op.Output(name.step_scopes);

  auto inlinks = op.Inputs(name.inlinks);
  auto inlink_alias = op.GetAttr<std::vector<std::string>>(name.inlink_alias);
  PADDLE_ENFORCE(inlinks.size() == inlink_alias.size(),
                 "the size of inlinks and inlink_alias don't match:%d,%d",
                 inlinks.size(), inlink_alias.size());
  for (size_t i = 0; i < inlinks.size(); ++i) {
    rnn::Link link;
    link.external = inlinks[i];
    link.internal = inlink_alias[i];
    (arg->inlinks).push_back(link);
  }

  auto outlinks = op.Outputs(name.outlinks);
  auto outlink_alias = op.GetAttr<std::vector<std::string>>(name.outlink_alias);
  PADDLE_ENFORCE(outlinks.size() == outlink_alias.size(),
                 "the size of outlinks and outlink_alias don't match:%d,%d",
                 outlinks.size(), outlink_alias.size());
  for (size_t i = 0; i < outlinks.size(); ++i) {
    rnn::Link link;
    link.external = outlinks[i];
    link.internal = outlink_alias[i];
    (arg->outlinks).push_back(link);
  }

  auto boot_memories = op.Inputs(name.boot_memories);

  // attributes
  auto memories = op.GetAttr<std::vector<std::string>>(name.memories);
  auto pre_memories = op.GetAttr<std::vector<std::string>>(name.pre_memories);

  PADDLE_ENFORCE(memories.size() == boot_memories.size(),
                 "the size of memories, boot_memories don't match:%d,%d",
                 memories.size(), boot_memories.size());
  PADDLE_ENFORCE(pre_memories.size() == boot_memories.size(),
                 "the size of pre_memories, boot_memories don't match:%d,%d",
                 pre_memories.size(), boot_memories.size());
  PADDLE_ENFORCE(memories.size() > 0, "more than 1 memories should be set");

  for (size_t i = 0; i < memories.size(); ++i) {
    rnn::MemoryAttr mem_attr;
    mem_attr.var = memories[i];
    mem_attr.pre_var = pre_memories[i];
    mem_attr.boot_var = boot_memories[i];
    (arg->memories).push_back(mem_attr);
  }
}

}  // namespace rnn

void RecurrentAlgorithm::InferShape(const std::shared_ptr<Scope>& scope) const {
  seq_len_ = scope->GetVariable((arg_->inlinks[0]).external)
                 ->GetMutable<Tensor>()
                 ->dims()[0];
  CreateScopes(scope);
  auto step_scopes = GetStepScopes(scope);

  // SegmentInputs is called in InferShape. The input must hold memory in
  // SegmentInputs. But the other op only set dimension for the output in
  // InferShape. That's a problem. Wether the RNN op needs InferShape or not?
  // Wether the following functions (SegmentInputs, InitMemories, ...) need
  // to rewrite for RNN op?
  rnn::SegmentInputs(step_scopes, arg_->inlinks, seq_len_);

  InitMemories(step_scopes[0]);

  PADDLE_ENFORCE(scope->HasVariable(arg_->step_net),
                 "stepnet [%s] is not in scope.", arg_->step_net);
  Variable* net = scope->GetVariable(arg_->step_net);
  PADDLE_ENFORCE(net != nullptr, "failed to get step net");
  // If the InferShape is called in OperatorBase's run function,
  // the rnn op only needs to do InferShape for the first time step
  for (size_t i = 0; i < seq_len_; i++) {
    if (i > 0) {
      rnn::LinkMemories(step_scopes, arg_->memories, i, -1);
    }
    net->GetMutable<NetOp>()->InferShape(step_scopes[i]);
  }

  auto outlinks = arg_->outlinks;
  for (size_t i = 0; i < outlinks.size(); i++) {
    DDim step_dims = step_scopes[0]
                         ->GetVariable(outlinks[i].internal)
                         ->GetMutable<Tensor>()
                         ->dims();
    std::vector<int> dims_vec = vectorize(step_dims);
    // now only support fixed length
    dims_vec.insert(dims_vec.begin(), seq_len_);
    Tensor* output =
        step_scopes[0]->GetVariable(outlinks[i].external)->GetMutable<Tensor>();
    output->Resize(make_ddim(dims_vec));
  }
}

void RecurrentAlgorithm::Run(const std::shared_ptr<Scope>& scope,
                             const platform::DeviceContext& dev_ctx) const {
  auto step_scopes = GetStepScopes(scope);

  Variable* net = scope->GetVariable(arg_->step_net);
  for (size_t step_id = 0; step_id < seq_len_; step_id++) {
    // the link memory is done in InferShape
    // maybe remove following code after testing
    if (step_id > 0) {
      rnn::LinkMemories(step_scopes, arg_->memories, step_id, -1);
    }
    net->GetMutable<NetOp>()->Run(step_scopes[step_id], dev_ctx);
  }

  rnn::ConcatOutputs(step_scopes, arg_->outlinks, seq_len_);
}

void RecurrentAlgorithm::CreateScopes(std::shared_ptr<Scope> scope) const {
  // TODO(xxx) Only two scopes are needed for inference, this case will be
  // supported later.
  auto step_scopes = scope->GetVariable(arg_->step_scopes)
                         ->GetMutable<std::vector<std::shared_ptr<Scope>>>();

  if (seq_len_ > step_scopes->size()) {
    for (size_t i = step_scopes->size(); i < seq_len_; ++i) {
      std::shared_ptr<Scope> step_scope = std::make_shared<Scope>(scope);

      // Now all variables in scope must be created outside of op.
      auto net_op = scope->GetVariable(arg_->step_net)->GetMutable<NetOp>();
      for (auto& input : net_op->inputs_) {
        step_scope->CreateVariable(input);
      }
      for (auto& output : net_op->outputs_) {
        step_scope->CreateVariable(output);
      }

      step_scopes->push_back(std::make_shared<Scope>(step_scope));
    }
  }
}

void RecurrentAlgorithm::InitMemories(std::shared_ptr<Scope> step_scope) const {
  for (auto& attr : arg_->memories) {
    Tensor* pre_mem =
        step_scope->CreateVariable(attr.pre_var)->GetMutable<Tensor>();
    PADDLE_ENFORCE(step_scope->HasVariable(attr.boot_var),
                   "memory [%s]'s boot variable [%s] not exists", attr.var,
                   attr.boot_var);
    Tensor* boot_mem =
        step_scope->GetVariable(attr.boot_var)->GetMutable<Tensor>();
    pre_mem->ShareDataWith<float>(*boot_mem);

    // TODO(qingqing) remove following code
    // the memory of current step should be allocated in step net
    // here for unit test
    auto cur_step_mem =
        step_scope->CreateVariable(attr.var)->GetMutable<Tensor>();
    cur_step_mem->mutable_data<float>(boot_mem->dims(), platform::CPUPlace());
  }
}

const rnn::ArgumentName RecurrentOp::kArgName{
    "step_net",      "step_scopes",  "inlinks",
    "outlinks",      "inlink_alias", "outlink_alias",
    "memories",      "pre_memories", "boot_memories"};

const rnn::ArgumentName RecurrentGradientOp::kArgName{
    "step_net",      "step_scopes",  "outlink@grad",
    "inlink@grad",   "inlink_alias", "outlink_alias",
    "memories",      "pre_memories", "boot_memories@grad"};

void RecurrentOp::Init() {
  OperatorBase::Init();
  std::unique_ptr<rnn::Argument> arg(new rnn::Argument());
  rnn::InitArgument(kArgName, arg.get(), *this);
  alg_.Init(std::move(arg));
}

class RecurrentAlgorithmProtoAndCheckerMaker : public OpProtoAndCheckerMaker {
 public:
  RecurrentAlgorithmProtoAndCheckerMaker(OpProto* proto,
                                         OpAttrChecker* op_checker)
      : OpProtoAndCheckerMaker(proto, op_checker) {
    const auto& name = RecurrentOp::kArgName;
    // inputs and outputs stored in proto
    AddInputs(name.inlinks,
              "the input that need to be segmented for each step.");
    AddInputs(name.boot_memories, "variables to initialize memories.");
    AddInput(name.step_net, "network shared by all steps.");

    AddOutputs(name.outlinks,
               "the output that need to concated for all steps.");
    AddOutput(name.step_scopes, "step scopes");

    // Attributes stored in AttributeMap
    AddAttr<std::vector<std::string>>(name.inlink_alias, "alias of inlinks");
    AddAttr<std::vector<std::string>>(name.outlink_alias, "alias of outlinks");
    AddAttr<std::vector<std::string>>(name.pre_memories,
                                      "names of pre-memories");
    AddAttr<std::vector<std::string>>(name.memories, "names of memories");

    AddComment("This is a recurrent group operator.");
  }
};

void RecurrentGradientAlgorithm::Run(
    const std::shared_ptr<Scope>& scope,
    const platform::DeviceContext& dev_ctx) const {
  auto step_scopes = GetStepScopes(scope);
  rnn::SegmentInputs(step_scopes, arg_->inlinks, seq_len_);
  PADDLE_ENFORCE(scope->HasVariable(arg_->step_net),
                 "step net is not in scope.");
  Variable* net = scope->GetVariable(arg_->step_net);
  PADDLE_ENFORCE(net != nullptr, "failed to get step net");
  for (int step_id = seq_len_ - 1; step_id >= 0; --step_id) {
    if (static_cast<size_t>(step_id) != seq_len_ - 1) {
      rnn::LinkMemories(step_scopes, arg_->memories, step_id, 1);
    }
    net->GetMutable<NetOp>()->Run(step_scopes[step_id], dev_ctx);
  }
  LinkBootMemoryGradients(step_scopes[0]);
  rnn::ConcatOutputs(step_scopes, arg_->outlinks, seq_len_);
}

void RecurrentGradientAlgorithm::LinkBootMemoryGradients(
    std::shared_ptr<Scope> step_scope) const {
  for (auto& attr : arg_->memories) {
    Tensor* mem_grad =
        step_scope->CreateVariable(attr.var)->GetMutable<Tensor>();
    PADDLE_ENFORCE(mem_grad != nullptr,
                   "boot_tensor should be retrieved before");
    PADDLE_ENFORCE(step_scope->HasVariable(attr.boot_var),
                   "memory [%s]'s boot variable [%s] not exists", attr.var,
                   attr.boot_var);
    Tensor* boot_mem_grad =
        step_scope->CreateVariable(attr.boot_var)->GetMutable<Tensor>();
    boot_mem_grad->ShareDataWith<float>(*mem_grad);
  }
}

void RecurrentGradientAlgorithm::InferShape(
    const std::shared_ptr<Scope>& scope) const {
  seq_len_ = scope->GetVariable((arg_->inlinks[0]).external)
                 ->GetMutable<Tensor>()
                 ->dims()[0];
  auto step_scopes = GetStepScopes(scope);
  rnn::SegmentInputs(step_scopes, arg_->inlinks, seq_len_);
  PADDLE_ENFORCE(scope->HasVariable(arg_->step_net),
                 "step net is not in scope.");
  Variable* net = scope->GetVariable(arg_->step_net);
  PADDLE_ENFORCE(net != nullptr, "failed to get step net");
  for (int step_id = seq_len_ - 1; step_id >= 0; --step_id) {
    if (static_cast<size_t>(step_id) != seq_len_ - 1) {
      rnn::LinkMemories(step_scopes, arg_->memories, step_id, 1);
    }
    net->GetMutable<NetOp>()->InferShape(step_scopes[step_id]);
  }

  auto outlinks = arg_->outlinks;
  for (size_t i = 0; i < outlinks.size(); i++) {
    DDim step_dims = step_scopes[0]
                         ->GetVariable(outlinks[i].internal)
                         ->GetMutable<Tensor>()
                         ->dims();
    std::vector<int> dims_vec = vectorize(step_dims);
    // now only support fixed length
    dims_vec.insert(dims_vec.begin(), seq_len_);
    Tensor* output =
        step_scopes[0]->GetVariable(outlinks[i].external)->GetMutable<Tensor>();
    output->Resize(make_ddim(dims_vec));
  }
  LinkBootMemoryGradients(step_scopes[0]);
}

void RecurrentGradientOp::Init() {
  OperatorBase::Init();
  std::unique_ptr<rnn::Argument> arg(new rnn::Argument());
  rnn::InitArgument(kArgName, arg.get(), *this);
  alg_.Init(std::move(arg));
}

}  // namespace operators
}  // namespace paddle

REGISTER_OP(recurrent_op,
            paddle::operators::RecurrentOp,
            paddle::operators::RecurrentAlgorithmProtoAndCheckerMaker);
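SegmentInputs above slices an external input of shape <seq_len, batch, dim> into one <batch, dim> view per time step, and ConcatOutputs stacks the per-step outputs back along the sequence axis. A shape-only illustrative sketch of the slicing with plain arrays instead of Scope/Tensor; all names are invented:

#include <cstddef>
#include <vector>

// A sequence laid out row-major as [seq_len, batch, dim]; step j's input is
// the contiguous [batch, dim] block starting at j * batch * dim, mirroring
// input->Slice<float>(j, j + 1) followed by Resize(step_dims).
std::vector<const float*> SegmentInputsSketch(const float* seq, int seq_len,
                                              int batch, int dim) {
  std::vector<const float*> steps;
  steps.reserve(seq_len);
  for (int j = 0; j < seq_len; ++j) {
    steps.push_back(seq + static_cast<std::size_t>(j) * batch * dim);
  }
  return steps;
}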
paddle/operators/recurrent_network_op.h
0 → 100644
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.

   Licensed under the Apache License, Version 2.0 (the "License");
   you may not use this file except in compliance with the License.
   You may obtain a copy of the License at

   http://www.apache.org/licenses/LICENSE-2.0

   Unless required by applicable law or agreed to in writing, software
   distributed under the License is distributed on an "AS IS" BASIS,
   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
   See the License for the specific language governing permissions and
   limitations under the License. */

#pragma once

#include "paddle/framework/operator.h"

namespace paddle {
namespace operators {

using namespace paddle::framework;

namespace rnn {

/**
 * Memory of a RNN (same as the role of `Momory` in PaddlePaddle).
 *
 * Memory attributes cached by this op, dims will be infered from
 * boot memories in father scope. Other attributes are copied from Op's proto
 * attributes.
 */
struct MemoryAttr {
  // name of current state variable
  std::string var;
  // name of previous step's state variable
  std::string pre_var;
  // name of the variables to init this memory (same role of `boot_layer` in
  // PaddlePaddle), which is store in father's scope.
  std::string boot_var;
};

struct Link {
  // input or output links name.
  std::string internal;
  // alias to avoid duplicate keys in scopes.
  std::string external;
};

struct Argument {
  std::string step_net;
  std::string step_scopes;
  std::vector<Link> inlinks;
  std::vector<Link> outlinks;
  std::vector<rnn::MemoryAttr> memories;
};

struct ArgumentName {
  std::string step_net;
  std::string step_scopes;
  std::string inlinks;
  std::string outlinks;
  std::string inlink_alias;   // the alias of inlinks in step net.
  std::string outlink_alias;  // the alias of outlinks in step net.
  std::string memories;       // the memory name
  std::string pre_memories;   // the previous memory name
  std::string boot_memories;  // the boot memory name
};

/**
 * Prepare inputs for each step net.
 */
void SegmentInputs(std::vector<std::shared_ptr<Scope>>& step_scopes,
                   const std::vector<Link>& inlinks, const size_t seq_len);

/**
 * Process outputs of step nets and merge to variables.
 */
void ConcatOutputs(std::vector<std::shared_ptr<Scope>>& step_scopes,
                   const std::vector<Link>& outlinks, const size_t seq_len);

void LinkMemories(std::vector<std::shared_ptr<Scope>>& step_scopes,
                  const std::vector<MemoryAttr>& memories, size_t step_id,
                  int offset);

void InitArgument(const ArgumentName& name, Argument* arg);

};  // namespace rnn

// The sequence format in RecurrentOp is Tensor<seq_len, batch_size, dim> now.
// TODO:
// 1. No-padding computing for sequences with indifinite length in one batch.
// 2. Hierarchical RNN for sequence with sub-sequence.
// 3. Internal Memory.
// 4. More Complex RNN architecture, such as Gated Feedback RNN.
//    Refer to: https://arxiv.org/pdf/1502.02367.pdf

class RecurrentAlgorithm {
 public:
  void Run(const std::shared_ptr<Scope>& scope,
           const platform::DeviceContext& dev_ctx) const;

  void Init(std::unique_ptr<rnn::Argument> arg) { arg_ = std::move(arg); }

  /**
   * InferShape must be called before Run.
   */
  void InferShape(const std::shared_ptr<Scope>& scope) const;

 protected:
  /*
   * The step scopes will be stored in the father scope as a variable.
   *
   * NOTE the scopes are reused in both the forward and backward, so just
   * create once and expand its size if more steps need.
   */
  void CreateScopes(std::shared_ptr<Scope> scope) const;

  inline const std::vector<std::shared_ptr<Scope>>& GetStepScopes(
      std::shared_ptr<Scope> scope) const {
    return *(scope->GetVariable(arg_->step_scopes))
                ->GetMutable<std::vector<std::shared_ptr<Scope>>>();
  }

  void InitMemories(std::shared_ptr<Scope> step_scopes) const;

 private:
  std::unique_ptr<rnn::Argument> arg_;
  mutable size_t seq_len_;
};

class RecurrentGradientAlgorithm {
  /**
   * RNN's backward alogorithm.
   *
   * To accelerate the development of RecurrentGradientOp, we decouple RNN's
   * algorithm and `OperatorBase`'s implementation, the former contains the core
   * implementation of a RNN, and will keep stable even if the framework changes a
   * lot, and the latter is a wrapper acts like an dapter for it to make RNN an
   * operator.
   */
 public:
  void Init(std::unique_ptr<rnn::Argument> arg) { arg_ = std::move(arg); }

  void Run(const std::shared_ptr<Scope>& scope,
           const platform::DeviceContext& dev_ctx) const;

  void LinkBootMemoryGradients(std::shared_ptr<Scope> step_scopes) const;

  /**
   * InferShape must be called before Run.
   */
  void InferShape(const std::shared_ptr<Scope>& scope) const;

 protected:
  inline const std::vector<std::shared_ptr<Scope>>& GetStepScopes(
      std::shared_ptr<Scope> scope) const {
    return *(scope->GetVariable(arg_->step_scopes))
                ->GetMutable<std::vector<std::shared_ptr<Scope>>>();
  }

 private:
  std::unique_ptr<rnn::Argument> arg_;
  mutable size_t seq_len_;
};

class RecurrentOp final : public OperatorBase {
 public:
  void Init() override;

  /**
   * InferShape must be called before Run.
   */
  virtual void InferShape(const std::shared_ptr<Scope>& scope) const override {
    alg_.InferShape(scope);
  }

  virtual void Run(const std::shared_ptr<Scope>& scope,
                   const platform::DeviceContext& dev_ctx) const override {
    alg_.Run(scope, dev_ctx);
  }

  static const rnn::ArgumentName kArgName;

 private:
  RecurrentAlgorithm alg_;
};

class RecurrentGradientOp final : public OperatorBase {
 public:
  void Init() override;

  /**
   * InferShape must be called before Run.
   */
  virtual void InferShape(const std::shared_ptr<Scope>& scope) const override {
    alg_.InferShape(scope);
  }

  virtual void Run(const std::shared_ptr<Scope>& scope,
                   const platform::DeviceContext& dev_ctx) const override {
    alg_.Run(scope, dev_ctx);
  }

  static const rnn::ArgumentName kArgName;

 private:
  RecurrentGradientAlgorithm alg_;
};

}  // namespace operators
}  // namespace paddle
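The forward pass described by RecurrentAlgorithm reduces to: for every time step, point the step's pre-memory at the previous step's memory (the boot memory at step 0), then run the shared step net inside that step's scope. A stripped-down illustrative sketch of that loop with plain values standing in for scopes and tensors; every name here is invented:

#include <cstddef>
#include <functional>
#include <vector>

// h[t] = step(x[t], h[t-1]), with h[-1] given by the boot memory.
std::vector<float> RunRecurrentSketch(
    const std::vector<float>& xs, float boot_memory,
    const std::function<float(float, float)>& step) {
  std::vector<float> hs(xs.size());
  float prev = boot_memory;        // InitMemories: pre-state of step 0
  for (std::size_t t = 0; t < xs.size(); ++t) {
    hs[t] = step(xs[t], prev);     // step net runs in step t's scope
    prev = hs[t];                  // LinkMemories: step t+1 sees h[t]
  }
  return hs;                       // ConcatOutputs: stack per-step outputs
}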
paddle/operators/recurrent_network_op_test.cc
0 → 100644
浏览文件 @
4cc42171
/*
Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include <glog/logging.h>
#include <gtest/gtest.h>
#include "paddle/framework/net.h"
#include "paddle/framework/op_registry.h"
#include "paddle/framework/operator.h"
#include "paddle/framework/tensor.h"
#include "paddle/operators/recurrent_network_op.h"
namespace
paddle
{
namespace
operators
{
class
RecurrentOpTest
:
public
::
testing
::
Test
{
protected:
virtual
void
SetUp
()
override
{
CreateGlobalVariables
();
CreateStepNet
();
CreateRNNOp
();
}
virtual
void
TearDown
()
override
{}
void
CreateGlobalVariables
()
{
scope_
=
std
::
make_shared
<
Scope
>
();
// create input, and init content
LOG
(
INFO
)
<<
"create global variable x"
;
for
(
auto
inlink
:
std
::
vector
<
std
::
string
>
{
"x"
,
"x0"
,
"x1"
,
"h"
})
{
Variable
*
x
=
scope_
->
CreateVariable
(
inlink
);
DDim
dims
=
make_ddim
(
std
::
vector
<
int
>
{
10
/*sent size*/
,
20
/*batch size*/
,
30
/*input dim*/
});
x
->
GetMutable
<
Tensor
>
()
->
mutable_data
<
float
>
(
dims
,
platform
::
CPUPlace
());
}
// create output alias just for test
for
(
auto
inlink
:
std
::
vector
<
std
::
string
>
{
"h@alias"
})
{
Variable
*
x
=
scope_
->
CreateVariable
(
inlink
);
DDim
dims
=
make_ddim
(
std
::
vector
<
int
>
{
20
/*batch size*/
,
30
/*input dim*/
});
x
->
GetMutable
<
Tensor
>
()
->
mutable_data
<
float
>
(
dims
,
platform
::
CPUPlace
());
}
LOG
(
INFO
)
<<
"create global variable w"
;
Variable
*
w
=
scope_
->
CreateVariable
(
"rnn/w"
);
w
->
GetMutable
<
Tensor
>
()
->
mutable_data
<
float
>
(
make_ddim
(
std
::
vector
<
int
>
{
30
,
30
}),
platform
::
CPUPlace
());
for
(
auto
boot
:
std
::
vector
<
std
::
string
>
{
"x_boot"
,
"h_boot"
})
{
LOG
(
INFO
)
<<
"create global variable "
<<
boot
;
Variable
*
h_boot
=
scope_
->
CreateVariable
(
boot
);
h_boot
->
GetMutable
<
Tensor
>
()
->
mutable_data
<
float
>
(
make_ddim
(
std
::
vector
<
int
>
{
20
/*batch size*/
,
30
/*input dim*/
}),
platform
::
CPUPlace
());
}
LOG
(
INFO
)
<<
"create variable step_scopes"
;
scope_
->
CreateVariable
(
"step_scopes"
);
LOG
(
INFO
)
<<
"create variable h"
;
scope_
->
CreateVariable
(
"h"
);
}
  void CreateRNNOp() {
    OpDesc op_desc;
    op_desc.set_type("recurrent_op");
    // inlinks 0
    op_desc.add_inputs("x");
    op_desc.add_inputs("x0");
    op_desc.add_inputs("x1");
    // boot_memories 3
    op_desc.add_inputs("x_boot");
    op_desc.add_inputs("h_boot");
    // step net 5
    op_desc.add_inputs("step_net");
    // outlinks 6
    op_desc.add_outputs("h");
    // step scopes 7
    op_desc.add_outputs("step_scopes");

    auto _input_format = std::vector<int>{
        0,  // in_link
        3,  // memories
        5   // step_net
    };
    auto input_format = op_desc.add_attrs();
    input_format->set_name("input_format");
    input_format->set_type(paddle::framework::AttrType::INTS);
    for (auto i : _input_format) {
      input_format->add_ints(i);
    }

    auto output_format = op_desc.add_attrs();
    output_format->set_name("output_format");
    output_format->set_type(paddle::framework::AttrType::INTS);
    for (auto i : std::vector<int>{0, 1, 2}) {
      output_format->add_ints(i);
    }

    auto inlink_alias = op_desc.add_attrs();
    inlink_alias->set_name("inlink_alias");
    inlink_alias->set_type(paddle::framework::AttrType::STRINGS);

    auto outlink_alias = op_desc.add_attrs();
    outlink_alias->set_name("outlink_alias");
    outlink_alias->set_type(paddle::framework::AttrType::STRINGS);

    auto pre_memories = op_desc.add_attrs();
    pre_memories->set_name("pre_memories");
    pre_memories->set_type(paddle::framework::AttrType::STRINGS);

    auto memories = op_desc.add_attrs();
    memories->set_name("memories");
    memories->set_type(paddle::framework::AttrType::STRINGS);

    // create inlink_alias
    for (const auto& item :
         std::vector<std::string>{"x@alias", "x0@alias", "x1@alias"}) {
      inlink_alias->add_strings(item);
    }
    // pre memories
    for (const auto& item :
         std::vector<std::string>{"rnn/x@pre", "rnn/h@pre"}) {
      pre_memories->add_strings(item);
    }
    // memories
    for (const auto& item : std::vector<std::string>{"rnn/x", "rnn/h"}) {
      memories->add_strings(item);
    }
    // output alias
    for (const auto& item : std::vector<std::string>{"h@alias"}) {
      outlink_alias->add_strings(item);
    }

    rnn_op_ = OpRegistry::CreateOp(op_desc);
    LOG(INFO) << "rnn_op finish init";
  }
  void CreateStepNet() {
    LOG(INFO) << "create variable step_net";
    Variable* var = scope_->CreateVariable("step_net");
    auto net = var->GetMutable<NetOp>();
    // rnn/s is net's input or output?
    net->inputs_ = {"rnn/h@pre", "rnn/w", "rnn/x"};
    net->outputs_ = {"rnn/s", "rnn/h"};
    net->AddOp(
        OpRegistry::CreateOp("mul", {"rnn/h@pre", "rnn/w"}, {"rnn/s"}, {}));
    net->AddOp(
        OpRegistry::CreateOp("add_two", {"rnn/x", "rnn/s"}, {"rnn/h"}, {}));
    net->CompleteAddOp();
  }

  // parent scope
  std::shared_ptr<Scope> scope_;
  std::shared_ptr<OperatorBase> rnn_op_;
};
TEST_F(RecurrentOpTest, Run) {
  platform::CPUDeviceContext ctx;
  rnn_op_->InferShape(scope_);
  rnn_op_->Run(scope_, ctx);
}
class RecurrentGradientAlgorithmTest : public ::testing::Test {
protected:
  virtual void SetUp() override {
    CreateGlobalVariables();
    CreateStepScopes();
    CreateStepNet();
    CreateRNNGradientAlgorithm();

    // segment inputs
    SegmentInputs();
    // link forward memories
    LinkeMemories();
  }

  virtual void TearDown() override {}

  void CreateGlobalVariables() {
    scope_ = std::make_shared<Scope>();
    // inputs: x
    LOG(INFO) << "create global variable x";
    Variable* x = scope_->CreateVariable("x");
    DDim dims =
        make_ddim({10 /*sent size*/, 20 /*batch size*/, 30 /*input dim*/});
    x->GetMutable<Tensor>()->mutable_data<float>(dims, platform::CPUPlace());
    // inputs: h_boot
    LOG(INFO) << "create global variable h_boot";
    Variable* h_boot = scope_->CreateVariable("h_boot");
    h_boot->GetMutable<Tensor>()->mutable_data<float>(
        make_ddim({20 /*batch size*/, 30 /*input dim*/}), platform::CPUPlace());
    // inputs: w
    LOG(INFO) << "create global variable w";
    Variable* w = scope_->CreateVariable("rnn/w");
    w->GetMutable<Tensor>()->mutable_data<float>(make_ddim({30, 30}),
                                                 platform::CPUPlace());
    // inputs: h_grad
    LOG(INFO) << "create variable h_grad";
    Variable* dh = scope_->CreateVariable("h_grad");
    dh->GetMutable<Tensor>()->mutable_data<float>(make_ddim({10, 20, 30}),
                                                  platform::CPUPlace());
    // inputs: step_scopes
    LOG(INFO) << "create variable step_scopes";
    scope_->CreateVariable("step_scopes");
    // inputs: step_net
    LOG(INFO) << "create variable step_net";
    scope_->CreateVariable("step_net");
    // outputs: w_grad
    LOG(INFO) << "create global variable w_grad";
    scope_->CreateVariable("rnn/w_grad");
    // outputs: x_grad
    LOG(INFO) << "create global variable x_grad";
    scope_->CreateVariable("x_grad");
    // outputs: h_boot_grad
    LOG(INFO) << "create global variable h_boot_grad";
    scope_->CreateVariable("h_boot_grad");
  }
  void CreateStepScopes() {
    std::vector<std::shared_ptr<Scope>>* step_scopes =
        scope_->GetVariable("step_scopes")
            ->GetMutable<std::vector<std::shared_ptr<Scope>>>();
    for (int i = 0; i < 10; ++i) {
      auto scope = std::make_shared<Scope>(scope_);
      auto pre_t = scope->CreateVariable("rnn/pre_h")->GetMutable<Tensor>();
      pre_t->mutable_data<float>(make_ddim({20, 30}), platform::CPUPlace());
      auto tensor = scope->CreateVariable("rnn/h")->GetMutable<Tensor>();
      tensor->mutable_data<float>(make_ddim({20, 30}), platform::CPUPlace());

      // for unit test of ConcatOutputs
      auto xg = scope->CreateVariable("rnn/x_grad")->GetMutable<Tensor>();
      xg->mutable_data<float>(make_ddim({20, 30}), platform::CPUPlace());

      step_scopes->push_back(scope);
    }

    // last time step
    auto g =
        (*step_scopes)[9]->CreateVariable("rnn/h_pre_grad")->GetMutable<Tensor>();
    g->mutable_data<float>(make_ddim({20, 30}), platform::CPUPlace());
  }
  void CreateRNNGradientAlgorithm() {
    std::unique_ptr<rnn::Argument> arg(new rnn::Argument());
    arg->step_net = "step_net";
    arg->step_scopes = "step_scopes";
    rnn::Link inlink;
    inlink.external = "h_grad";
    inlink.internal = "rnn/h_grad";
    arg->inlinks = std::vector<rnn::Link>{inlink};

    rnn::Link outlink;
    outlink.external = "x_grad";
    outlink.internal = "rnn/x_grad";
    arg->outlinks = std::vector<rnn::Link>{outlink};

    rnn::MemoryAttr mem_attr;
    mem_attr.pre_var = "rnn/h_pre_grad";
    mem_attr.var = "rnn/h_grad";
    mem_attr.boot_var = "h_boot_grad";
    arg->memories = std::vector<rnn::MemoryAttr>{mem_attr};

    rnn_grad_algo_.Init(std::move(arg));
  }
  void CreateStepNet() {
    LOG(INFO) << "create variable step_net";
    Variable* var = scope_->CreateVariable("step_net");
    auto net = var->GetMutable<NetOp>();
    net->AddOp(OpRegistry::CreateOp("mul",
                                    {"rnn/h_pre", "rnn/w", "rnn/s_grad"},
                                    {"rnn/h_pre_grad", "rnn/w_grad"},
                                    {}));
    net->AddOp(OpRegistry::CreateOp(
        "add_two", {"rnn/h_grad"}, {"rnn/x_grad", "rnn/s_grad"}, {}));
    net->CompleteAddOp();
  }
  void SegmentInputs() {
    LOG(INFO) << "segment inputs";
    std::vector<std::string> inlinks = {"x"};
    std::vector<std::string> inlinks_alias = {"rnn/x"};

    rnn::Link inlink;
    inlink.external = "x";
    inlink.internal = "rnn/x";
    std::vector<std::shared_ptr<Scope>>* step_scopes =
        scope_->GetVariable("step_scopes")
            ->GetMutable<std::vector<std::shared_ptr<Scope>>>();
    rnn::SegmentInputs(*step_scopes, std::vector<rnn::Link>{inlink}, 10);
  }
  void LinkeMemories() {
    LOG(INFO) << "link memories";
    rnn::MemoryAttr mem_attr;
    mem_attr.pre_var = "rnn/h_pre";
    mem_attr.var = "rnn/h";
    mem_attr.boot_var = "boot_h";
    std::vector<rnn::MemoryAttr> memories;
    memories.push_back(mem_attr);
    std::vector<std::shared_ptr<Scope>>* step_scopes =
        scope_->GetVariable("step_scopes")
            ->GetMutable<std::vector<std::shared_ptr<Scope>>>();
    for (int i = 1; i < 10; ++i) {
      rnn::LinkMemories(*step_scopes, memories, i, -1);
    }
  }

  std::shared_ptr<Scope> scope_;
  RecurrentGradientAlgorithm rnn_grad_algo_;
};
// TEST_F(RecurrentGradientAlgorithmTest, Run) {
//   platform::CPUDeviceContext ctx;
//   rnn_grad_algo_.Run(scope_, ctx);
// }

}  // namespace operators
}  // namespace paddle
TEST(RecurrentOp, LinkMemories) {
  using namespace paddle::framework;
  using namespace paddle::platform;
  using namespace paddle::operators;

  // create and init step scopes
  int len = 10;
  std::vector<std::shared_ptr<Scope>> step_scopes;
  for (int i = 0; i < len; ++i) {
    auto scope = std::make_shared<Scope>();
    scope->CreateVariable("pre_h");
    auto tensor = scope->CreateVariable("h")->GetMutable<Tensor>();
    float* data = tensor->mutable_data<float>(make_ddim({15, 20}), CPUPlace());
    for (int i = 0; i < 15 * 20; ++i) {
      data[i] = rand() * (1. / (double)RAND_MAX);
    }
    step_scopes.push_back(scope);
  }

  // create MemoryAttr
  rnn::MemoryAttr mem_attr;
  mem_attr.pre_var = "pre_h";
  mem_attr.var = "h";
  mem_attr.boot_var = "boot_h";
  std::vector<rnn::MemoryAttr> memories;
  memories.push_back(mem_attr);

  for (int i = 1; i < len; ++i) {
    rnn::LinkMemories(step_scopes, memories, i, -1);
  }
  // check
  for (int i = 0; i < len - 1; ++i) {
    const float* a =
        step_scopes[i]->GetVariable("h")->GetMutable<Tensor>()->data<float>();
    const float* b = step_scopes[i + 1]
                         ->GetVariable("pre_h")
                         ->GetMutable<Tensor>()
                         ->data<float>();
    for (size_t i = 0; i < 15 * 20; ++i) {
      ASSERT_FLOAT_EQ(a[i], b[i]);
    }
  }

  for (int i = len - 2; i >= 0; --i) {
    rnn::LinkMemories(step_scopes, memories, i, 1);
  }
  // check
  for (int i = len - 2; i >= 0; --i) {
    const float* a = step_scopes[i]
                         ->GetVariable("pre_h")
                         ->GetMutable<Tensor>()
                         ->data<float>();
    const float* b =
        step_scopes[i + 1]->GetVariable("h")->GetMutable<Tensor>()->data<float>();
    for (size_t i = 0; i < 15 * 20; ++i) {
      ASSERT_FLOAT_EQ(a[i], b[i]);
    }
  }
}

USE_OP(add_two);
USE_OP(mul);
paddle/operators/rnn_design.md
0 → 100644
# RNN design for variable-length input

For learning on variable-length sequences, today's mainstream frameworks such as TensorFlow, PyTorch, Caffe2, and MXNet all rely on padding: sequences of different lengths within a mini-batch are zero-padded to a fixed length before computation.

Paddle's existing RNNs, including `RecurrentLayerGroup`, already support variable-length sequences without padding; this document follows the idea of that module to design variable-length support for the refactored framework.

## Background

Because a tensor must have a well-defined shape, tensor-based frameworks have to zero-pad variable-length sequences into fixed-shape tensors when storing them.

Padding is a compromise a framework makes in order to handle variable-length sequences. From the user's point of view it is an unwelcome artifact of RNN-like models, which is why there are long discussions about non-padding variable-length support in PyTorch [3].

Padding also costs extra memory and computation, so TensorFlow and MXNet both use bucketing as an optimization [1][2]; but whether padding or bucketing, both put an extra burden on the user.

Therefore, **native support for variable-length sequences directly satisfies the user's most immediate need and can be counted as a major advantage of Paddle among today's mainstream platforms.**

Supporting variable-length sequences does require some changes to the current framework; the rest of this document discusses how to support them with minimal modification.
## Multi-level sequence format: `LODTensor`

Currently Paddle stores the data of a mini-batch in one contiguous piece of memory and uses `Argument.sequenceStartPositions` on the side to record where each sentence starts.
`Argument.subSequenceStartPositions` stores 2-level sequence information, but higher-level sequences cannot be represented directly.

To store `N-level` sequences, this document defines the sequence information as the following data structure:

```c++
std::shared_ptr<std::vector<std::vector<int>>> lod_start_pos_;
```

or, more explicitly,

```c++
typedef std::vector<int> level_t;
std::vector<level_t> lod_start_pos;
```
Each `level_t` stores the offsets of one granularity (level), consistent with what Paddle does today.

To pass sequence information around more transparently, we introduce a new tensor type called `LODTensor` [4].
Its tensor-related interface is inherited directly from `Tensor`, with additional sequence-related interfaces on top.
In this way, an ordinary `Op` can treat an `LODTensor` simply as a `Tensor`, while a sequence-aware `Op` additionally operates on the variable-length interfaces of `LODTensor`.

`LODTensor` is defined as follows:

```c++
class LODTensor : public Tensor {
public:
  size_t Levels() const { return seq_start_positions_.size(); }
  size_t Elements(int level = 0) const {
    return seq_start_positions_[level].size();
  }
  // slice of level[elem_begin: elem_end]
  // NOTE low performance in slice seq_start_positions_.
  // TODO should call Tensor's Slice.
  LODTensor LODSlice(int level, int elem_begin, int elem_end) const;

  // slice with tensor's data shared with this.
  LODTensor LODSliceShared(int level, int elem_begin, int elem_end) const;

  // copy other's lod_start_pos_, to share LOD info.
  // NOTE the LOD info should not be changed.
  void ShareConstLODFrom(const LODTensor &other) {
    lod_start_pos_ = other.lod_start_pos_;
  }
  // copy other's lod_start_pos_'s content, free to mutate.
  void ShareMutableLODFrom(const LODTensor &other) {
    lod_start_pos_ = std::make_shared<std::vector<std::vector<int>>>(
        other.lod_start_pos_.begin(), other.lod_start_pos_.end());
  }

private:
  std::shared_ptr<std::vector<std::vector<int>>> lod_start_pos_;
};
```
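As a concrete illustration of the offsets above, here is a minimal standalone sketch (plain STL only, not part of the design itself) of how a 2-level `lod_start_pos` could describe a mini-batch of 3 sentences with lengths 4, 2, and 3:

```c++
#include <iostream>
#include <vector>

int main() {
  // level 0: sentence boundaries in the flattened batch (offsets).
  // level 1: elementary-step boundaries inside the batch (here one per step).
  std::vector<std::vector<int>> lod_start_pos = {
      {0, 4, 6, 9},                    // 3 sentences of lengths 4, 2, 3
      {0, 1, 2, 3, 4, 5, 6, 7, 8, 9},  // 9 elementary steps
  };

  // Levels() of an LODTensor would report 2 here.
  std::cout << "levels: " << lod_start_pos.size() << "\n";
  std::cout << "sentences: " << lod_start_pos[0].size() - 1 << "\n";
  // length of sentence i = start[i + 1] - start[i]
  for (size_t i = 0; i + 1 < lod_start_pos[0].size(); ++i) {
    std::cout << "len(" << i << ") = "
              << lod_start_pos[0][i + 1] - lod_start_pos[0][i] << "\n";
  }
  return 0;
}
```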
Here `lod_start_pos_` is held in a `shared_ptr` to reduce the cost of storage and copying.
`LODTensor` can therefore be regarded as an extension of `Tensor` that remains almost fully compatible with how the original `Tensor` is used.
## Framework support

### Replacing the framework's existing `Tensor` calls with `LODTensor`

To pass `LODTensor` through the framework, many `Tensor`s in the framework need to become `LODTensor`s.
The simple approach is to **replace all existing `Tensor`s with `LODTensor` directly, which can be done by modifying the `Tensor`-creation interfaces in `pybind.cc`**.

In addition, users may need to be aware of the sequences (for example, visualizing sequences requires parsing the sequence output of a model), so some sequence-manipulation APIs also need to be exposed to the Python layer.
### Passing `lod_start_pos` along the Op call chain

The framework needs the following properties in order to pass `lod_start_pos` around:

1. Pass it via `shared_ptr`
   - an Op that does not modify the content of `lod_start_pos` acts as a consumer
   - an Op that modifies `lod_start_pos` acts as a producer
   - by convention, a consumer only copies the `shared_ptr` passed to it
   - a producer creates its own independent memory to hold its own modifications and exposes a `shared_ptr` to downstream consumers
   - since passing is implemented by copying a `shared_ptr`, the framework only needs to pass `lod_start_pos` once (a small sketch of this convention follows below)
2. Stay fully transparent to Ops that are unaware of `lod_start_pos`
3. A producer Op that needs to modify `lod_start_pos` can update its own copy during `Run`
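A minimal standalone sketch of the consumer/producer convention in the list above, using plain `shared_ptr`s instead of the real `LODTensor` interfaces:

```c++
#include <cassert>
#include <memory>
#include <vector>

using LOD = std::vector<std::vector<int>>;

int main() {
  auto lod = std::make_shared<LOD>(LOD{{0, 4, 6, 9}});

  // consumer: only copies the shared_ptr and sees the same object
  std::shared_ptr<LOD> consumer_view = lod;
  assert(consumer_view.get() == lod.get());

  // producer: makes its own copy, mutates it, and exposes the new pointer
  auto producer_copy = std::make_shared<LOD>(*lod);
  (*producer_copy)[0] = {0, 2, 5, 9};  // e.g. lengths changed by the op
  assert(producer_copy.get() != lod.get());
  assert((*lod)[0][1] == 4);  // the original LOD is untouched
  return 0;
}
```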
The concrete design is split into the following three subsections.

#### Passing `lod_start_pos`

- when `lod_start_pos` does not need to be modified, call `LODTensor`'s `ShareConstLODFrom` interface to copy it
- when it does need to be modified, call `ShareMutableLODFrom`, which allocates its own memory to hold the modifications
#### Framework transparency

The passing step should be added to the initialization that runs before the network executes, and it only needs to run once. A preliminary plan based on the current framework design:

- add an attribute `do_mutate_lod_info` to an Op's `attrs`, defaulting to `false`
- an Op that needs to modify `lod_start_pos` sets it to `true` when defining its `OpProto`
- `OperatorBase::InferShape` reads `do_mutate_lod_info` and calls the corresponding `LODTensor` method to copy `lod_start_pos`
- `OperatorBase` keeps a member `is_lod_inited{false}` to guarantee the passing happens only once

The logic looks roughly like this:

```c++
class OperatorBase {
public:
  // ...
  void InferShape() {
    if (!is_lod_inited) {
      bool do_mutate_lod_info = GetAttr<bool>("do_mutate_lod_info");
      // find an input having LOD to copy
      auto lod_input = ValidLODInput();
      for (auto &output : outputs) {
        if (do_mutate_lod_info) {
          output.ShareMutableLODFrom(lod_input);
        } else {
          output.ShareConstLODFrom(lod_input);
        }
      }
      is_lod_inited = true;
    }

    // call op's InferShape
    // ...
  }

private:
  // ...
  bool is_lod_inited{false};
};
```
In this way, the passing of `lod_start_pos` information is completely transparent to Ops that are not LOD-aware.

#### Updating `lod_start_pos`

As described in the previous subsection, for an Op that needs to modify `lod_start_pos`, `OperatorBase` allocates a separate piece of memory to hold the modification.
The Op updates its own `lod_start_pos` inside its `Run` implementation, and every op that depends on its outputs picks up the update automatically through the shared pointer.
## Sorting by length

After sorting by length, the batch size of successive time steps decreases naturally, so the steps can be fed straight into the Net for batched computation.

For example, the original input

```
origin:
xxxx
xx
xxx

-> sorted:
xxxx
xxx
xx
```

After `SegmentInputs`, there are 4 time steps, and the input of each time step is as follows (arranged vertically)

```
0    1    2    3
x    x    x    x
x    x    x
x    x
```
To keep track of how the sequences move during sorting, we use

```c++
struct SortedSeqItem {
   void *start{nullptr};
   void *end{nullptr};
};

std::vector<SortedSeqItem> sorted_seqs;
```

to record the position of each sequence after sorting, and add a new interface

```c++
std::vector<SortedSeqItem> SortBySeqLen(const LODTensor& tensor);
```
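As a rough illustration of what `SortBySeqLen` has to compute, the following standalone sketch (hypothetical helper, not the interface above) derives the descending-length order of the sequences from level-0 offsets:

```c++
#include <algorithm>
#include <iostream>
#include <numeric>
#include <vector>

// Returns sequence indices sorted by descending length, given level-0
// start offsets, e.g. {0, 4, 6, 9} -> lengths {4, 2, 3}.
std::vector<size_t> SortIdxBySeqLen(const std::vector<int>& starts) {
  std::vector<size_t> idx(starts.size() - 1);
  std::iota(idx.begin(), idx.end(), 0);
  std::sort(idx.begin(), idx.end(), [&](size_t a, size_t b) {
    return starts[a + 1] - starts[a] > starts[b + 1] - starts[b];
  });
  return idx;
}

int main() {
  for (size_t i : SortIdxBySeqLen({0, 4, 6, 9})) std::cout << i << " ";
  std::cout << "\n";  // prints: 0 2 1
  return 0;
}
```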
Because the order of the input sequences changes, the following existing interfaces need corresponding modification:

- InitMemories: memory needs to be rearranged according to `sorted_seqs`
- SegmentInputs
- ConcatOutputs

In addition, since `sorted_seqs` needs to be reused by `RecurrentGradientOp`, it becomes a new output of `RecurrentOp` and is later passed in as an input of `RecurrentGradientOp`.

## InitMemories

Because the sequence order changes, the order of the elements in the batch of `boot_memories` also needs to be rearranged accordingly.
## SegmentInputs

`SegmentInputs` relies on the information in `sorted_seqs`. It splits the original sequences horizontally, in the sorted order, into the inputs of each step, i.e. the transformation below (a sketch of the resulting per-step batch sizes follows):

```
origin:
xxxx
xx
xxx

   |
   |
  \ /
   !
0    1    2    3
x    x    x    x
x    x    x
x    x
```
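A minimal sketch of the step-wise view that `SegmentInputs` produces: with lengths sorted in descending order, the batch size of step `t` is the number of sequences still alive at `t` (helper name hypothetical):

```c++
#include <iostream>
#include <vector>

// lengths must be sorted in descending order, e.g. {4, 3, 2}.
std::vector<int> StepBatchSizes(const std::vector<int>& lengths) {
  std::vector<int> batch_sizes(lengths.front(), 0);
  for (int len : lengths) {
    for (int t = 0; t < len; ++t) ++batch_sizes[t];
  }
  return batch_sizes;
}

int main() {
  for (int bs : StepBatchSizes({4, 3, 2})) std::cout << bs << " ";
  std::cout << "\n";  // prints: 3 3 2 1
  return 0;
}
```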
## ConcatOutputs

`ConcatOutputs` needs to

- restore the outputs of every time step to the original input-sequence order (so that the order is not scrambled during the Infer phase)
- concatenate each sequence into a regular mini-batch representation
## References

1. [Tensorflow Bucketing](https://www.tensorflow.org/versions/r0.12/api_docs/python/contrib.training/bucketing)
2. [mxnet Bucketing](http://mxnet.io/how_to/bucketing.html)
3. [variable length input in RNN scenario](https://discuss.pytorch.org/t/about-the-variable-length-input-in-rnn-scenario/345/5)
4. [Level of details](https://en.wikipedia.org/wiki/Level_of_detail)
paddle/operators/rowwise_add_op.cc

@@ -13,15 +13,13 @@
 limitations under the License. */

 #include "paddle/operators/rowwise_add_op.h"
-#include "paddle/framework/op_registry.h"
 namespace paddle {
 namespace operators {

-class RowWiseAddOp : public framework::OperatorWithKernel {
+class RowWiseAddOp : public OperatorWithKernel {
 protected:
-  void InferShape(
-      const std::vector<const framework::Tensor *> &inputs,
-      const std::vector<framework::Tensor *> &outputs) const override {
+  void InferShape(
+      const std::vector<const Tensor *> &inputs,
+      const std::vector<Tensor *> &outputs) const override {
     PADDLE_ENFORCE(inputs.size() == 2UL, "Two inputs is needed by rowwise add");
     auto dim0 = inputs[0]->dims();
     auto dim1 = inputs[1]->dims();

@@ -34,11 +32,10 @@
   }
 };

-class RowWiseAddOpMaker : public framework::OpProtoAndCheckerMaker {
+class RowWiseAddOpMaker : public OpProtoAndCheckerMaker {
 public:
-  RowWiseAddOpMaker(framework::OpProto *proto,
-                    framework::OpAttrChecker *op_checker)
-      : framework::OpProtoAndCheckerMaker(proto, op_checker) {
+  RowWiseAddOpMaker(OpProto *proto, OpAttrChecker *op_checker)
+      : OpProtoAndCheckerMaker(proto, op_checker) {
     AddInput("X", "The left input of row-wise add op, must be matrix");
     AddInput("b", "The right input of row-wise add op, must be vector");
     AddOutput("Out", "The output of row-wise add op");

@@ -53,9 +50,6 @@ for i in xrange(X.shape[0]):
 }  // namespace operators
 }  // namespace paddle

-REGISTER_OP(rowwise_add,
-            paddle::operators::RowWiseAddOp,
-            paddle::operators::RowWiseAddOpMaker);
-REGISTER_OP_CPU_KERNEL(
-    rowwise_add,
-    paddle::operators::RowWiseAddKernel<paddle::platform::CPUPlace, float>);
+REGISTER_OP(rowwise_add, ops::RowWiseAddOp, ops::RowWiseAddOpMaker);
+REGISTER_OP_CPU_KERNEL(rowwise_add,
+                       ops::RowWiseAddKernel<ops::CPUPlace, float>);
paddle/operators/rowwise_add_op.cu

-#include "paddle/framework/op_registry.h"
 #include "paddle/operators/rowwise_add_op.h"

-REGISTER_OP_GPU_KERNEL(
-    rowwise_add,
-    paddle::operators::RowWiseAddKernel<paddle::platform::GPUPlace, float>);
+REGISTER_OP_GPU_KERNEL(rowwise_add,
+                       ops::RowWiseAddKernel<ops::GPUPlace, float>);
paddle/operators/rowwise_add_op.h

@@ -13,25 +13,23 @@
 limitations under the License. */

 #pragma once
-#include "glog/logging.h"
-#include "paddle/framework/eigen.h"
-#include "paddle/framework/operator.h"
+#include "paddle/operators/type_alias.h"

 namespace paddle {
 namespace operators {

 template <typename Place, typename T>
-class RowWiseAddKernel : public framework::OpKernel {
+class RowWiseAddKernel : public OpKernel {
 public:
-  void Compute(const framework::KernelContext& context) const override {
-    auto in0 = context.Input(0)->Get<framework::Tensor>();
-    auto in1 = context.Input(1)->Get<framework::Tensor>();
-    auto* out = context.Output(0)->GetMutable<framework::Tensor>();
+  void Compute(const KernelContext& context) const override {
+    auto in0 = context.Input(0)->Get<Tensor>();
+    auto in1 = context.Input(1)->Get<Tensor>();
+    auto* out = context.Output(0)->GetMutable<Tensor>();
     out->mutable_data<T>(context.GetPlace());

-    auto input = framework::EigenMatrix<T>::From(in0);
-    auto bias = framework::EigenVector<T>::From(in1);
-    auto output = framework::EigenMatrix<T>::From(*out);
+    auto input = EigenMatrix<T>::From(in0);
+    auto bias = EigenVector<T>::From(in1);
+    auto output = EigenMatrix<T>::From(*out);

     const int bias_size = bias.dimension(0);
     const int rest_size = input.size() / bias_size;
paddle/operators/sgd_op.cc

@@ -13,17 +13,14 @@ See the License for the specific language governing permissions and
 limitations under the License. */

 #include "paddle/operators/sgd_op.h"
-#include "paddle/framework/op_registry.h"
-#include "paddle/framework/tensor.h"

 namespace paddle {
 namespace operators {

-class SGDOp : public framework::OperatorWithKernel {
+class SGDOp : public OperatorWithKernel {
 protected:
-  void InferShape(
-      const std::vector<const framework::Tensor *> &inputs,
-      const std::vector<framework::Tensor *> &outputs) const override {
+  void InferShape(
+      const std::vector<const Tensor *> &inputs,
+      const std::vector<Tensor *> &outputs) const override {
     PADDLE_ENFORCE(inputs.size() == 2, "Input size of SGDOp must be two");
     PADDLE_ENFORCE(outputs.size() == 1, "Output size of SGDOp must be one");
     PADDLE_ENFORCE(inputs[0] != nullptr, "inputs[0] mast be set");

@@ -35,10 +32,10 @@
   }
 };

-class SGDOpMaker : public framework::OpProtoAndCheckerMaker {
+class SGDOpMaker : public OpProtoAndCheckerMaker {
 public:
-  SGDOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker)
-      : framework::OpProtoAndCheckerMaker(proto, op_checker) {
+  SGDOpMaker(OpProto *proto, OpAttrChecker *op_checker)
+      : OpProtoAndCheckerMaker(proto, op_checker) {
     AddInput("param", "input parameter");
     AddInput("grad", "input gradient");
     AddOutput("param_out", "output parameter");

@@ -55,7 +52,5 @@ param_out = param - learning_rate * grad;
 }  // namespace operators
 }  // namespace paddle

-REGISTER_OP(sgd, paddle::operators::SGDOp, paddle::operators::SGDOpMaker);
-typedef paddle::operators::SGDOpKernel<::paddle::platform::CPUPlace, float>
-    SGDOpKernel_CPU_float;
-REGISTER_OP_CPU_KERNEL(sgd, SGDOpKernel_CPU_float);
+REGISTER_OP(sgd, ops::SGDOp, ops::SGDOpMaker);
+REGISTER_OP_CPU_KERNEL(sgd, ops::SGDOpKernel<ops::CPUPlace, float>);
paddle/operators/sgd_op.cu

 #include "paddle/operators/sgd_op.h"
-#include "paddle/framework/op_registry.h"

-typedef paddle::operators::SGDOpKernel<::paddle::platform::GPUPlace, float>
-    SGDOpKernel_GPU_float;
-REGISTER_OP_GPU_KERNEL(sgd, SGDOpKernel_GPU_float);
+REGISTER_OP_GPU_KERNEL(sgd, ops::SGDOpKernel<ops::GPUPlace, float>);
\ No newline at end of file
paddle/operators/sgd_op.h

@@ -13,28 +13,24 @@ See the License for the specific language governing permissions and
 limitations under the License. */

 #pragma once
-#include "glog/logging.h"
-#include "paddle/framework/eigen.h"
-#include "paddle/framework/operator.h"
+#include "paddle/operators/type_alias.h"

 namespace paddle {
 namespace operators {

 template <typename Place, typename T>
-class SGDOpKernel : public framework::OpKernel {
+class SGDOpKernel : public OpKernel {
 public:
-  void Compute(const framework::KernelContext& ctx) const override {
-    auto param = ctx.Input("param")->Get<framework::Tensor>();
-    auto grad = ctx.Input("grad")->Get<framework::Tensor>();
-    auto* param_out = ctx.Output(0)->GetMutable<framework::Tensor>();
+  void Compute(const KernelContext& ctx) const override {
+    auto param = ctx.Input("param")->Get<Tensor>();
+    auto grad = ctx.Input("grad")->Get<Tensor>();
+    auto* param_out = ctx.Output(0)->GetMutable<Tensor>();
     float lr = ctx.op_.GetAttr<float>("learning_rate");

     param_out->mutable_data<T>(ctx.GetPlace());

-    framework::EigenVector<T>::Flatten(*param_out)
-        .device(*(ctx.GetEigenDevice<Place>())) =
-        framework::EigenVector<T>::Flatten(param) -
-        lr * framework::EigenVector<T>::Flatten(grad);
+    EigenVector<T>::Flatten(*param_out).device(*(ctx.GetEigenDevice<Place>())) =
+        EigenVector<T>::Flatten(param) - lr * EigenVector<T>::Flatten(grad);
   }
 };
paddle/operators/sigmoid_op.cc

@@ -13,37 +13,33 @@
 limitations under the License. */

 #include "paddle/operators/sigmoid_op.h"
-#include "paddle/framework/op_registry.h"

 namespace paddle {
 namespace operators {

-class SigmoidOp : public framework::OperatorWithKernel {
+class SigmoidOp : public OperatorWithKernel {
 protected:
-  void InferShape(
-      const std::vector<const framework::Tensor *> &inputs,
-      const std::vector<framework::Tensor *> &outputs) const override {
+  void InferShape(
+      const std::vector<const Tensor *> &inputs,
+      const std::vector<Tensor *> &outputs) const override {
     PADDLE_ENFORCE(inputs.size() == 1, "Sigmoid Op only have one input");
     PADDLE_ENFORCE(outputs.size() == 1, "Sigmoid Op only have one output");
     outputs[0]->Resize(inputs[0]->dims());
   }
 };

-class SigmoidOpMaker : public framework::OpProtoAndCheckerMaker {
+class SigmoidOpMaker : public OpProtoAndCheckerMaker {
 public:
-  SigmoidOpMaker(framework::OpProto *proto,
-                 framework::OpAttrChecker *op_checker)
-      : framework::OpProtoAndCheckerMaker(proto, op_checker) {
+  SigmoidOpMaker(OpProto *proto, OpAttrChecker *op_checker)
+      : OpProtoAndCheckerMaker(proto, op_checker) {
     AddInput("X", "sigmoid input");
     AddOutput("Y", "sigmoid output");
     AddComment("Sigmoid function");
   }
 };

-class SigmoidOpGrad : public framework::OperatorWithKernel {
+class SigmoidOpGrad : public OperatorWithKernel {
 protected:
-  void InferShape(
-      const std::vector<const framework::Tensor *> &inputs,
-      const std::vector<framework::Tensor *> &outputs) const override {}
+  void InferShape(
+      const std::vector<const Tensor *> &inputs,
+      const std::vector<Tensor *> &outputs) const override {}
   std::string DebugString() const override {
     LOG(INFO) << "SigmoidGrad";
     return "";

@@ -53,11 +49,7 @@
 }  // namespace operators
 }  // namespace paddle

-REGISTER_OP(sigmoid,
-            paddle::operators::SigmoidOp,
-            paddle::operators::SigmoidOpMaker);
-REGISTER_GRADIENT_OP(sigmoid, sigmoid_grad, paddle::operators::SigmoidOpGrad);
-REGISTER_OP_CPU_KERNEL(
-    sigmoid,
-    paddle::operators::SigmoidKernel<paddle::platform::CPUPlace, float>);
+REGISTER_OP(sigmoid, ops::SigmoidOp, ops::SigmoidOpMaker);
+REGISTER_GRADIENT_OP(sigmoid, sigmoid_grad, ops::SigmoidOpGrad);
+REGISTER_OP_CPU_KERNEL(sigmoid, ops::SigmoidKernel<ops::CPUPlace, float>);
paddle/operators/sigmoid_op.cu

 #include "paddle/operators/sigmoid_op.h"
-#include "paddle/framework/op_registry.h"

-REGISTER_OP_GPU_KERNEL(
-    sigmoid,
-    paddle::operators::SigmoidKernel<paddle::platform::GPUPlace, float>);
+REGISTER_OP_GPU_KERNEL(sigmoid, ops::SigmoidKernel<ops::GPUPlace, float>);
paddle/operators/sigmoid_op.h

@@ -14,25 +14,23 @@
 #pragma once

-#include "glog/logging.h"
-#include "paddle/framework/eigen.h"
-#include "paddle/framework/operator.h"
+#include "paddle/operators/type_alias.h"

 namespace paddle {
 namespace operators {

 template <typename Place, typename T>
-class SigmoidKernel : public framework::OpKernel {
+class SigmoidKernel : public OpKernel {
 public:
-  void Compute(const framework::KernelContext& context) const override {
-    auto input = context.Input(0)->Get<framework::Tensor>();
-    auto* output = context.Output(0)->GetMutable<framework::Tensor>();
+  void Compute(const KernelContext& context) const override {
+    auto input = context.Input(0)->Get<Tensor>();
+    auto* output = context.Output(0)->GetMutable<Tensor>();
     output->mutable_data<T>(context.GetPlace());

-    framework::EigenVector<T>::Flatten(*output).device(
-        *(context.GetEigenDevice<Place>())) =
-        1.0 / (1.0 + (-1.0 * framework::EigenVector<T>::Flatten(input)).exp());
+    EigenVector<T>::Flatten(*output).device(
+        *(context.GetEigenDevice<Place>())) =
+        1.0 / (1.0 + (-1.0 * EigenVector<T>::Flatten(input)).exp());
   }
 };

 }  // namespace operators
paddle/operators/softmax_op.cc

@@ -12,16 +12,14 @@
 See the License for the specific language governing permissions and
 limitations under the License. */

 #include "paddle/operators/softmax_op.h"
-#include "paddle/framework/op_registry.h"

 namespace paddle {
 namespace operators {

-class SoftmaxOp : public framework::OperatorWithKernel {
+class SoftmaxOp : public OperatorWithKernel {
 protected:
-  void InferShape(
-      const std::vector<const framework::Tensor *> &inputs,
-      const std::vector<framework::Tensor *> &outputs) const override {
+  void InferShape(
+      const std::vector<const Tensor *> &inputs,
+      const std::vector<Tensor *> &outputs) const override {
     PADDLE_ENFORCE(inputs.size() == 1, "Only one input is need for softmax");
     PADDLE_ENFORCE(inputs[0]->dims().size() == 2,
                    "The input of softmax op must be matrix");

@@ -31,10 +29,9 @@
   }
 };

-class SoftmaxOpMaker : public framework::OpProtoAndCheckerMaker {
+class SoftmaxOpMaker : public OpProtoAndCheckerMaker {
 public:
-  SoftmaxOpMaker(framework::OpProto *proto,
-                 framework::OpAttrChecker *op_checker)
+  SoftmaxOpMaker(OpProto *proto, OpAttrChecker *op_checker)
       : OpProtoAndCheckerMaker(proto, op_checker) {
     AddInput("X", "input of softmax");
     AddOutput("Y", "output of softmax");

@@ -42,11 +39,10 @@
   }
 };

-class SoftmaxOpGrad : public framework::OperatorWithKernel {
+class SoftmaxOpGrad : public OperatorWithKernel {
 protected:
-  void InferShape(
-      const std::vector<const framework::Tensor *> &inputs,
-      const std::vector<framework::Tensor *> &outputs) const override {}
+  void InferShape(
+      const std::vector<const Tensor *> &inputs,
+      const std::vector<Tensor *> &outputs) const override {}
   std::string DebugString() const override {
     LOG(INFO) << "SoftmaxOpGrad";
     return "";

@@ -56,9 +52,6 @@
 }  // namespace operators
 }  // namespace paddle

-namespace ops = paddle::operators;
 REGISTER_OP(softmax, ops::SoftmaxOp, ops::SoftmaxOpMaker);
-REGISTER_GRADIENT_OP(softmax, softmax_grad, paddle::operators::SoftmaxOpGrad);
-REGISTER_OP_CPU_KERNEL(softmax,
-                       ops::SoftmaxKernel<paddle::platform::CPUPlace, float>);
+REGISTER_GRADIENT_OP(softmax, softmax_grad, ops::SoftmaxOpGrad);
+REGISTER_OP_CPU_KERNEL(softmax, ops::SoftmaxKernel<ops::CPUPlace, float>);
paddle/operators/softmax_op.cu

 #include "paddle/framework/op_registry.h"
 #include "paddle/operators/softmax_op.h"

-REGISTER_OP_GPU_KERNEL(
-    softmax,
-    paddle::operators::SoftmaxKernel<paddle::platform::GPUPlace, float>);
+REGISTER_OP_GPU_KERNEL(softmax, ops::SoftmaxKernel<ops::GPUPlace, float>);
paddle/operators/softmax_op.h

@@ -14,23 +14,21 @@
 #pragma once

-#include "glog/logging.h"
-#include "paddle/framework/eigen.h"
-#include "paddle/framework/operator.h"
+#include "paddle/operators/type_alias.h"

 namespace paddle {
 namespace operators {

 template <typename Place, typename T>
-class SoftmaxKernel : public framework::OpKernel {
+class SoftmaxKernel : public OpKernel {
 public:
-  void Compute(const framework::KernelContext& context) const override {
-    auto input = context.Input(0)->Get<framework::Tensor>();
-    auto* output = context.Output(0)->GetMutable<framework::Tensor>();
+  void Compute(const KernelContext& context) const override {
+    auto input = context.Input(0)->Get<Tensor>();
+    auto* output = context.Output(0)->GetMutable<Tensor>();
     output->mutable_data<T>(context.GetPlace());

-    auto logits = framework::EigenMatrix<T>::From(input);
-    auto softmax = framework::EigenMatrix<T>::From(*output);
+    auto logits = EigenMatrix<T>::From(input);
+    auto softmax = EigenMatrix<T>::From(*output);

     const int kBatchDim = 0;
     const int kClassDim = 1;
paddle/operators/type_alias.h
0 → 100644

/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.

   Licensed under the Apache License, Version 2.0 (the "License");
   you may not use this file except in compliance with the License.
   You may obtain a copy of the License at

   http://www.apache.org/licenses/LICENSE-2.0

   Unless required by applicable law or agreed to in writing, software
   distributed under the License is distributed on an "AS IS" BASIS,
   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
   See the License for the specific language governing permissions and
   limitations under the License. */

#pragma once

#include "paddle/framework/eigen.h"
#include "paddle/framework/net.h"
#include "paddle/framework/op_registry.h"

namespace paddle {
namespace operators {

using OpKernel = framework::OpKernel;
using KernelContext = framework::KernelContext;
template <typename T,
          int MajorType = Eigen::RowMajor,
          typename IndexType = Eigen::DenseIndex>
using EigenVector = framework::EigenVector<T, MajorType, IndexType>;
template <typename T,
          int MajorType = Eigen::RowMajor,
          typename IndexType = Eigen::DenseIndex>
using EigenMatrix = framework::EigenMatrix<T, MajorType, IndexType>;
template <typename T,
          size_t D,
          int MajorType = Eigen::RowMajor,
          typename IndexType = Eigen::DenseIndex>
using EigenTensor = framework::EigenTensor<T, D, MajorType, IndexType>;
using Tensor = framework::Tensor;
using OperatorWithKernel = framework::OperatorWithKernel;
using OpProtoAndCheckerMaker = framework::OpProtoAndCheckerMaker;
using OpProto = framework::OpProto;
using OpAttrChecker = framework::OpAttrChecker;
using CPUPlace = platform::CPUPlace;
using GPUPlace = platform::GPUPlace;
using NetOp = framework::NetOp;
using OpRegistry = framework::OpRegistry;

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
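Taken together, these aliases are what let the op files above drop their framework:: prefixes. As a rough illustration only, a hypothetical kernel written purely against the aliases might look like the sketch below (the `IdentityKernel` op does not exist in this patch; the context calls mirror the kernels shown in the diffs above):

```c++
// Hypothetical example showing how the type_alias.h names are meant to be used.
#include "paddle/operators/type_alias.h"

namespace paddle {
namespace operators {

template <typename Place, typename T>
class IdentityKernel : public OpKernel {
public:
  void Compute(const KernelContext& ctx) const override {
    auto in = ctx.Input(0)->Get<Tensor>();
    auto* out = ctx.Output(0)->GetMutable<Tensor>();
    out->mutable_data<T>(ctx.GetPlace());
    // Copy input to output element-wise via Eigen, like the other kernels.
    EigenVector<T>::Flatten(*out).device(*(ctx.GetEigenDevice<Place>())) =
        EigenVector<T>::Flatten(in);
  }
};

}  // namespace operators
}  // namespace paddle
```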
paddle/platform/device_context.cc

@@ -20,12 +20,101 @@ Eigen::DefaultDevice* DeviceContext::get_eigen_device<Eigen::DefaultDevice>()
   return reinterpret_cast<const CPUDeviceContext*>(this)->eigen_device();
 }

+CPUDeviceContext::CPUDeviceContext() {
+  eigen_device_.reset(new Eigen::DefaultDevice());
+}
+
+CPUDeviceContext::CPUDeviceContext(CPUPlace place) {
+  eigen_device_.reset(new Eigen::DefaultDevice());
+}
+
+Eigen::DefaultDevice* CPUDeviceContext::eigen_device() const {
+  return eigen_device_.get();
+}
+
+Place CPUDeviceContext::GetPlace() const { return CPUPlace(); }
+
 #ifndef PADDLE_ONLY_CPU
 template <>
 Eigen::GpuDevice* DeviceContext::get_eigen_device<Eigen::GpuDevice>() const {
   return reinterpret_cast<const CUDADeviceContext*>(this)->eigen_device();
 }
-#endif
+
+CUDADeviceContext::CUDADeviceContext(GPUPlace place) : place_(place) {
+  SetDeviceId(place_.device);
+  // TODO(qijun) Pass a created cuda stream to Eigen::CudaStreamDevice directly
+  // here will cause segment fault. We must implement a class derived from
+  // Eigen::StreamInterface, and reinitialize it with a cuda stream and a gpu id
+  // later. Please refer to the implementation of class EigenCudaStreamDevice
+  // in TensorFlow.
+  //
+  // We find that CUDA 7 introduces a new option, the per-thread default stream,
+  // that has two effects. Please refer to https://devblogs.nvidia.com/
+  // parallelforall/gpu-pro-tip-cuda-7-streams-simplify-concurrency/
+  //
+  // So, we decide to use default stream and add –default-stream per-thread nvcc
+  // flag. Than, two threads with two CUDADeviceContexts will run parallelly.
+  eigen_stream_.reset(new Eigen::CudaStreamDevice());
+  eigen_device_.reset(new Eigen::GpuDevice(eigen_stream_.get()));
+}
+
+CUDADeviceContext::~CUDADeviceContext() {
+  SetDeviceId(place_.device);
+  Wait();
+  if (cublas_handle_) {
+    PADDLE_ENFORCE(dynload::cublasDestroy(cublas_handle_));
+  }
+  if (cudnn_handle_) {
+    PADDLE_ENFORCE(dynload::cudnnDestroy(cudnn_handle_));
+  }
+  if (curand_generator_) {
+    PADDLE_ENFORCE(dynload::curandDestroyGenerator(curand_generator_));
+  }
+  eigen_stream_.reset();
+  eigen_device_.reset();
+}
+
+Place CUDADeviceContext::GetPlace() const { return place_; }
+
+void CUDADeviceContext::Wait() const {
+  PADDLE_ENFORCE(cudaStreamSynchronize(0));
+}
+
+Eigen::GpuDevice* CUDADeviceContext::eigen_device() const {
+  return eigen_device_.get();
+}
+
+cublasHandle_t CUDADeviceContext::cublas_handle() {
+  if (!cublas_handle_) {
+    SetDeviceId(place_.device);
+    PADDLE_ENFORCE(dynload::cublasCreate(&cublas_handle_));
+  }
+  return cublas_handle_;
+}
+
+cudnnHandle_t CUDADeviceContext::cudnn_handle() {
+  if (!cudnn_handle_) {
+    SetDeviceId(place_.device);
+    PADDLE_ENFORCE(dynload::cudnnCreate(&cudnn_handle_));
+  }
+  return cudnn_handle_;
+}
+
+curandGenerator_t CUDADeviceContext::curand_generator() {
+  if (!curand_generator_) {
+    SetDeviceId(place_.device);
+    PADDLE_ENFORCE(dynload::curandCreateGenerator(&curand_generator_,
+                                                  CURAND_RNG_PSEUDO_DEFAULT));
+    PADDLE_ENFORCE(
+        dynload::curandSetPseudoRandomGeneratorSeed(curand_generator_, seed_));
+  }
+  return curand_generator_;
+}
+
+#endif  // PADDLE_ONLY_CPU

 }  // namespace platform
 }  // namespace paddle
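The three handle getters above share one pattern: the handle is created lazily on first use and destroyed in the destructor. A standalone sketch of that pattern (no CUDA, hypothetical `Handle` type) for readers unfamiliar with it:

```c++
#include <iostream>
#include <memory>

struct Handle { int id; };

class Context {
public:
  ~Context() {
    if (handle_) std::cout << "destroy handle\n";
  }

  // Created lazily on the first call, then reused.
  Handle* handle() {
    if (!handle_) {
      handle_.reset(new Handle{42});
      std::cout << "create handle\n";
    }
    return handle_.get();
  }

private:
  std::unique_ptr<Handle> handle_;
};

int main() {
  Context ctx;
  ctx.handle();  // prints "create handle"
  ctx.handle();  // reused, no message
  return 0;
}
```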
paddle/platform/device_context.h

@@ -39,14 +39,13 @@ class DeviceContext {
 class CPUDeviceContext : public DeviceContext {
  public:
-  CPUDeviceContext() { eigen_device_.reset(new Eigen::DefaultDevice()); }
+  CPUDeviceContext();
+  CPUDeviceContext(CPUPlace);
+  virtual ~CPUDeviceContext() {}

-  Eigen::DefaultDevice* eigen_device() const { return eigen_device_.get(); }
+  Eigen::DefaultDevice* eigen_device() const;

-  Place GetPlace() const override {
-    Place retv = CPUPlace();
-    return retv;
-  }
+  Place GetPlace() const override;

  private:
   std::unique_ptr<Eigen::DefaultDevice> eigen_device_;

@@ -54,119 +53,46 @@ class CPUDeviceContext : public DeviceContext {
 #ifndef PADDLE_ONLY_CPU

-class GPUPlaceGuard {
- public:
-  explicit GPUPlaceGuard(GPUPlace new_place) : previous_(GetCurrentDeviceId()) {
-    if (previous_ != new_place) {
-      paddle::platform::SetDeviceId(new_place.device);
-    }
-  }
-  ~GPUPlaceGuard() { paddle::platform::SetDeviceId(previous_.device); }
- private:
-  GPUPlace previous_;
-};
-
-class CUDADeviceContext : public DeviceContext {
- public:
-  explicit CUDADeviceContext(const GPUPlace gpu_place) : gpu_place_(gpu_place) {
-    GPUPlaceGuard guard(gpu_place_);
-    PADDLE_ENFORCE(cudaStreamCreate(&stream_), "cudaStreamCreate failed");
-    eigen_stream_.reset(new Eigen::CudaStreamDevice(&stream_));
-    eigen_device_.reset(new Eigen::GpuDevice(eigen_stream_.get()));
-  }
-  Place GetPlace() const override { Place retv = GPUPlace(); return retv; }
-  void Wait() {
-    PADDLE_ENFORCE(cudaStreamSynchronize(stream_), "cudaStreamSynchronize failed");
-  }
-  cudaStream_t stream() { return stream_; }
-  Eigen::GpuDevice* eigen_device() const { return eigen_device_.get(); }
-  cublasHandle_t cublas_handle() {
-    if (!blas_handle_) {
-      GPUPlaceGuard guard(gpu_place_);
-      PADDLE_ENFORCE(paddle::platform::dynload::cublasCreate(&blas_handle_), "cublasCreate failed");
-      PADDLE_ENFORCE(paddle::platform::dynload::cublasSetStream(blas_handle_, stream_), "cublasSetStream failed");
-    }
-    return blas_handle_;
-  }
-  cudnnHandle_t cudnn_handle() {
-    if (!dnn_handle_) {
-      GPUPlaceGuard guard(gpu_place_);
-      PADDLE_ENFORCE(paddle::platform::dynload::cudnnCreate(&dnn_handle_), "cudnnCreate failed");
-      PADDLE_ENFORCE(paddle::platform::dynload::cudnnSetStream(dnn_handle_, stream_), "cudnnSetStream failed");
-    }
-    return dnn_handle_;
-  }
-  curandGenerator_t curand_generator() {
-    if (!rand_generator_) {
-      GPUPlaceGuard guard(gpu_place_);
-      PADDLE_ENFORCE(paddle::platform::dynload::curandCreateGenerator(&rand_generator_, CURAND_RNG_PSEUDO_DEFAULT), "curandCreateGenerator failed");
-      PADDLE_ENFORCE(paddle::platform::dynload::curandSetPseudoRandomGeneratorSeed(rand_generator_, random_seed_), "curandSetPseudoRandomGeneratorSeed failed");
-      PADDLE_ENFORCE(paddle::platform::dynload::curandSetStream(rand_generator_, stream_), "curandSetStream failed");
-    }
-    return rand_generator_;
-  }
-  ~CUDADeviceContext() {
-    Wait();
-    if (blas_handle_) { PADDLE_ENFORCE(paddle::platform::dynload::cublasDestroy(blas_handle_), "cublasDestroy failed"); }
-    if (dnn_handle_) { PADDLE_ENFORCE(paddle::platform::dynload::cudnnDestroy(dnn_handle_), "cudnnDestroy failed"); }
-    if (rand_generator_) { PADDLE_ENFORCE(paddle::platform::dynload::curandDestroyGenerator(rand_generator_), "curandDestroyGenerator failed"); }
-    eigen_stream_.reset();
-    eigen_device_.reset();
-    PADDLE_ENFORCE(cudaStreamDestroy(stream_), "cudaStreamDestroy failed");
-  }
- private:
-  GPUPlace gpu_place_;
-  cudaStream_t stream_;
-  std::unique_ptr<Eigen::CudaStreamDevice> eigen_stream_;
-  std::unique_ptr<Eigen::GpuDevice> eigen_device_;
-  cublasHandle_t blas_handle_{nullptr};
-  cudnnHandle_t dnn_handle_{nullptr};
-  int random_seed_;
-  curandGenerator_t rand_generator_{nullptr};
-};
+class CUDADeviceContext : public DeviceContext {
+ public:
+  explicit CUDADeviceContext(GPUPlace);
+  virtual ~CUDADeviceContext();
+
+  /*! \brief  Wait for all operations completion in the stream. */
+  void Wait() const;
+
+  /*! \brief  Return place in the device context. */
+  Place GetPlace() const override;
+
+  /*! \brief  Return eigen device in the device context. */
+  Eigen::GpuDevice* eigen_device() const;
+
+  // clang-format off
+  /*! \brief  Return cublas handle in the device context. */
+  cublasHandle_t    cublas_handle();
+  /*! \brief  Return cudnn  handle in the device context. */
+  cudnnHandle_t     cudnn_handle();
+  /*! \brief  Return curand handle in the device context. */
+  curandGenerator_t curand_generator();
+  // clang-format on
+
+ private:
+  GPUPlace place_;
+
+ private:
+  std::unique_ptr<Eigen::GpuDevice> eigen_device_;
+  std::unique_ptr<Eigen::CudaStreamDevice> eigen_stream_;
+
+ private:
+  uint64_t seed_;
+
+  // clang-format off
+  cudnnHandle_t     cudnn_handle_     = nullptr;
+  cublasHandle_t    cublas_handle_    = nullptr;
+  curandGenerator_t curand_generator_ = nullptr;
+  // clang-format on
+};

 #endif
paddle/platform/enforce.h

@@ -36,6 +36,21 @@ limitations under the License. */
 namespace paddle {
 namespace platform {

+struct EnforceNotMet : public std::exception {
+  std::exception_ptr exp_;
+  std::string err_str_;
+
+  EnforceNotMet(std::exception_ptr e, const char* f, int l) : exp_(e) {
+    try {
+      std::rethrow_exception(exp_);
+    } catch (const std::exception& exp) {
+      err_str_ = string::Sprintf("%s at [%s:%d]", exp.what(), f, l);
+    }
+  }
+
+  const char* what() const noexcept { return err_str_.c_str(); }
+};
+
 // Because most enforce conditions would evaluate to true, we can use
 // __builtin_expect to instruct the C++ compiler to generate code that
 // always forces branch prediction of true.

@@ -43,18 +58,11 @@
 // For more details, please check https://stackoverflow.com/a/43870188/724872.
 #define UNLIKELY(condition) __builtin_expect(static_cast<bool>(condition), 0)

-template <typename T>
-inline void throw_on_error(T e) {
-  throw_on_error(e, "");
-}
-
 template <typename... Args>
 inline typename std::enable_if<sizeof...(Args) != 0, void>::type throw_on_error(
     int stat, const Args&... args) {
   if (UNLIKELY(!(stat))) {
-    throw std::runtime_error(
-        string::Sprintf(args...) +
-        string::Sprintf(" at [%s:%s];", __FILE__, __LINE__));
+    throw std::runtime_error(string::Sprintf(args...));
   }
 }

@@ -64,12 +72,8 @@
 inline typename std::enable_if<sizeof...(Args) != 0, void>::type throw_on_error(
     cudaError_t e, const Args&... args) {
   if (UNLIKELY(e)) {
-    // clang-format off
-    throw thrust::system_error(
-        e, thrust::cuda_category(),
-        string::Sprintf(args...) +
-        string::Sprintf(" at [%s:%s];", __FILE__, __LINE__));
-    // clang-format on
+    throw thrust::system_error(e, thrust::cuda_category(),
+                               string::Sprintf(args...));
   }
 }

@@ -77,12 +81,8 @@
 inline typename std::enable_if<sizeof...(Args) != 0, void>::type throw_on_error(
     curandStatus_t stat, const Args&... args) {
   if (stat != CURAND_STATUS_SUCCESS) {
-    // clang-format off
-    throw thrust::system_error(
-        cudaErrorLaunchFailure, thrust::cuda_category(),
-        string::Sprintf(args...) +
-        string::Sprintf(" at [%s:%s];", __FILE__, __LINE__));
-    // clang-format on
+    throw thrust::system_error(cudaErrorLaunchFailure, thrust::cuda_category(),
+                               string::Sprintf(args...));
   }
 }

@@ -92,12 +92,8 @@
   if (stat == CUDNN_STATUS_SUCCESS) {
     return;
   } else {
-    // clang-format off
-    throw std::runtime_error(
-        platform::dynload::cudnnGetErrorString(stat) +
-        string::Sprintf(args...) +
-        string::Sprintf(" at [%s:%s];", __FILE__, __LINE__));
-    // clang-format on
+    throw std::runtime_error(platform::dynload::cudnnGetErrorString(stat) +
+                             string::Sprintf(args...));
   }
 }

@@ -126,22 +122,32 @@
   } else if (stat == CUBLAS_STATUS_LICENSE_ERROR) {
     err = "CUBLAS: license error, ";
   }
-  throw std::runtime_error(err + string::Sprintf(args...) +
-                           string::Sprintf(" at [%s:%s];", __FILE__, __LINE__));
+  throw std::runtime_error(err + string::Sprintf(args...));
 }
 #endif  // PADDLE_ONLY_CPU

+template <typename T>
+inline void throw_on_error(T e) {
+  throw_on_error(e, "");
+}
+
-#define PADDLE_THROW(...)                                     \
-  do {                                                        \
-    throw std::runtime_error(                                 \
-        string::Sprintf(__VA_ARGS__) +                        \
-        string::Sprintf(" at [%s:%s];", __FILE__, __LINE__)); \
-  } while (0)
+#define PADDLE_THROW(...)                                      \
+  do {                                                         \
+    throw ::paddle::platform::EnforceNotMet(                   \
+        std::make_exception_ptr(                               \
+            std::runtime_error(string::Sprintf(__VA_ARGS__))), \
+        __FILE__, __LINE__);                                   \
+  } while (0)

-#define PADDLE_ENFORCE(...)                          \
-  do {                                               \
-    ::paddle::platform::throw_on_error(__VA_ARGS__); \
-  } while (0)
+#define PADDLE_ENFORCE(...)                                             \
+  do {                                                                  \
+    try {                                                               \
+      ::paddle::platform::throw_on_error(__VA_ARGS__);                  \
+    } catch (...) {                                                     \
+      throw ::paddle::platform::EnforceNotMet(std::current_exception(), \
+                                              __FILE__, __LINE__);      \
+    }                                                                   \
+  } while (0)

 }  // namespace platform
paddle/platform/enforce_test.cc

@@ -23,7 +23,7 @@ TEST(ENFORCE, FAILED) {
   bool in_catch = false;
   try {
     PADDLE_ENFORCE(false, "Enforce is not ok %d at all", 123);
-  } catch (const std::runtime_error& error) {
+  } catch (paddle::platform::EnforceNotMet error) {
     // your error handling code here
     in_catch = true;
     std::string msg = "Enforce is not ok 123 at all";
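To see why the new macros can report file and line even though `throw_on_error` no longer appends them, here is a minimal standalone sketch (no Paddle headers; names simplified and hypothetical) of the same exception_ptr-based mechanism:

```c++
#include <exception>
#include <iostream>
#include <sstream>
#include <stdexcept>
#include <string>

struct EnforceNotMet : public std::exception {
  std::string err_str_;
  EnforceNotMet(std::exception_ptr e, const char* f, int l) {
    try {
      std::rethrow_exception(e);
    } catch (const std::exception& exp) {
      std::ostringstream os;
      os << exp.what() << " at [" << f << ":" << l << "]";
      err_str_ = os.str();
    }
  }
  const char* what() const noexcept override { return err_str_.c_str(); }
};

#define MY_ENFORCE(cond, msg)                                            \
  do {                                                                   \
    try {                                                                \
      if (!(cond)) throw std::runtime_error(msg);                        \
    } catch (...) {                                                      \
      throw EnforceNotMet(std::current_exception(), __FILE__, __LINE__); \
    }                                                                    \
  } while (0)

int main() {
  try {
    MY_ENFORCE(false, "Enforce is not ok 123 at all");
  } catch (const EnforceNotMet& err) {
    std::cout << err.what() << "\n";  // message plus "at [file:line]"
  }
  return 0;
}
```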
paddle/pybind/CMakeLists.txt

 cc_library(paddle_pybind SHARED SRCS pybind.cc DEPS pybind python
-    add_op fc_op sgd_op cross_entropy_op)
+    add_op fc_op sgd_op cross_entropy_op recurrent_network_op)
paddle/pybind/pybind.cc

@@ -38,6 +38,7 @@ USE_OP(mul);
 USE_OP(sigmoid);
 USE_OP(softmax);
 USE_OP(rowwise_add);
+USE_OP_WITHOUT_KERNEL(recurrent_op);

 template <typename ClassType>
 void ExposeOperator(ClassType& m) {

@@ -50,6 +51,11 @@ void ExposeOperator(ClassType& m) {
       .def("__str__", &ClassType::type::DebugString);
 }

+static size_t UniqueIntegerGenerator() {
+  static std::atomic<size_t> generator;
+  return generator.fetch_add(1);
+}
+
 PYBIND11_PLUGIN(core) {
   py::module m("core", "C++ core of PaddlePaddle");

@@ -103,6 +109,11 @@ All parameter, weight, gradient are variables in Paddle.
            [](pd::Variable& self) -> pd::Tensor* {
              return self.GetMutable<pd::Tensor>();
            },
-           py::return_value_policy::reference);
+           py::return_value_policy::reference)
+      .def("get_net",
+           [](pd::Variable& self) -> pd::NetOp* {
+             return self.GetMutable<pd::NetOp>();
+           },
+           py::return_value_policy::reference);

   py::class_<pd::Scope, std::shared_ptr<pd::Scope>>(m, "Scope")

@@ -112,7 +123,8 @@ All parameter, weight, gradient are variables in Paddle.
            py::return_value_policy::reference)
       .def("create_var",
            &pd::Scope::CreateVariable,
-           py::return_value_policy::reference);
+           py::return_value_policy::reference)
+      .def("get_var_name", &pd::Scope::GetVariableName);

   //! @note: Be careful! PyBind will return std::string as an unicode, not
   //! Python str. If you want a str object, you should cast them in Python.

@@ -166,24 +178,25 @@ All parameter, weight, gradient are variables in Paddle.
   });
   ExposeOperator(operator_base);

-  using PlainNetPtr = std::shared_ptr<pd::PlainNet>;
-  py::class_<pd::PlainNet, PlainNetPtr> plain_net(m, "PlainNet");
-
-  plain_net
-      .def_static("create",
-                  []() -> std::shared_ptr<pd::PlainNet> {
-                    auto retv = std::make_shared<pd::PlainNet>();
-                    retv->type_ = "plain_net";
-                    return retv;
-                  })
-      .def("add_op", &pd::PlainNet::AddOp)
-      .def("add_op",
-           [](PlainNetPtr& self, const PlainNetPtr& plain_net) -> void {
-             self->AddOp(std::static_pointer_cast<pd::OperatorBase>(plain_net));
-           })
-      .def("complete_add_op", &pd::PlainNet::CompleteAddOp)
-      .def("complete_add_op", [](PlainNetPtr& self) { self->CompleteAddOp(); });
-  ExposeOperator(plain_net);
+  py::class_<pd::NetOp, std::shared_ptr<pd::NetOp>> net(m, "Net");
+
+  net.def_static("create",
+                 []() -> std::shared_ptr<pd::NetOp> {
+                   auto retv = std::make_shared<pd::NetOp>();
+                   retv->type_ = "plain_net";
+                   return retv;
+                 })
+      .def("add_op", &pd::NetOp::AddOp)
+      .def("add_op",
+           [](pd::NetOp& self, const std::shared_ptr<pd::NetOp>& net) -> void {
+             self.AddOp(std::static_pointer_cast<pd::OperatorBase>(net));
+           })
+      .def("complete_add_op", &pd::NetOp::CompleteAddOp)
+      .def("complete_add_op",
+           [](std::shared_ptr<pd::NetOp>& self) { self->CompleteAddOp(); });
+  ExposeOperator(net);
+
+  m.def("unique_integer", UniqueIntegerGenerator);

   return m.ptr();
 }
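For readers new to the binding layer, here is a minimal self-contained pybind11 module following the same binding pattern as the `Scope` and `Net` code above (hypothetical `Counter` class, unrelated to Paddle; a sketch only, not part of this commit):

```c++
#include <memory>
#include <pybind11/pybind11.h>

namespace py = pybind11;

class Counter {
public:
  void Add(int v) { value_ += v; }
  int value() const { return value_; }

private:
  int value_ = 0;
};

PYBIND11_PLUGIN(counter_core) {
  py::module m("counter_core", "toy module mirroring the Scope binding style");
  // Expose the class with shared_ptr ownership, as pybind.cc does for Scope.
  py::class_<Counter, std::shared_ptr<Counter>>(m, "Counter")
      .def(py::init<>())
      .def("add", &Counter::Add)
      .def("value", &Counter::value);
  return m.ptr();
}
```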
paddle/trainer/NewRemoteParameterUpdater.cpp

@@ -76,7 +76,11 @@ void NewRemoteParameterUpdater::init(
     sgdConfigV2->set_decay(paramConfig.decay_rate());
     optimizeConfigV2.set_lr_policy(paddle::OptimizerConfig::Const);
     auto constlr = optimizeConfigV2.mutable_const_lr();
+    if (paramConfig.has_learning_rate()) {
       constlr->set_learning_rate(paramConfig.learning_rate());
+    } else {
+      constlr->set_learning_rate(trainerConfig_.learning_rate());
+    }
     if (trainerConfig_.algorithm() == "sgd") {
       optimizeConfigV2.set_optimizer(paddle::OptimizerConfig::SGD);
       // FIXME: config all algorithms
paddle/utils/Error.h
View file @ 4cc42171
...
@@ -126,9 +126,11 @@ public:
   }

   /**
-   * @brief operator bool, return True if there is something error.
+   * @brief check this status by glog.
+   * @note It is a temp method used during cleaning Paddle code. It will be
+   * removed later.
    */
-  operator bool() const { return !this->isOK(); }
+  void check() const { CHECK(this->isOK()) << msg(); }

   /**
    * @brief isOK return True if there is no error.
...
@@ -136,13 +138,6 @@ public:
    */
   bool isOK() const { return msg_ == nullptr; }

-  /**
-   * @brief check this status by glog.
-   * @note It is a temp method used during cleaning Paddle code. It will be
-   * removed later.
-   */
-  void check() const { CHECK(this->isOK()) << msg(); }
-
 private:
   std::shared_ptr<std::string> msg_;
 };
...
paddle/utils/tests/test_Error.cpp
View file @ 4cc42171
...
@@ -18,17 +18,17 @@ limitations under the License. */

 TEST(Error, testAll) {
   paddle::Error error;
-  ASSERT_FALSE(error);
+  ASSERT_TRUE(error.isOK());
   error = paddle::Error("I'm the error");
-  ASSERT_TRUE(error);
+  ASSERT_FALSE(error.isOK());
   ASSERT_STREQ("I'm the error", error.msg());

   error = paddle::Error("error2");
-  ASSERT_TRUE(error);
+  ASSERT_FALSE(error.isOK());
   ASSERT_STREQ("error2", error.msg());

   int i = 3;
   auto error3 = paddle::Error("error%d", i);
-  ASSERT_TRUE(error3);
+  ASSERT_FALSE(error3.isOK());
   ASSERT_STREQ("error3", error3.msg());
 }
python/paddle/trainer/config_parser.py
View file @ 4cc42171
...
@@ -2055,8 +2055,7 @@ class BatchNormLayer(LayerBase):
         # Automatically select cudnn_batch_norm for GPU and batch_norm for CPU.
         # Also based on cudnn version.
         use_cudnn = use_gpu and batch_norm_type != "batch_norm" and \
-            ((not parallel_nn) or self.config.device > -1) and \
-            cudnn_version >= 4007
+            ((not parallel_nn) or self.config.device > -1)
         self.layer_type = "cudnn_batch_norm" if use_cudnn else "batch_norm"
         super(BatchNormLayer, self).__init__(
             name, self.layer_type, 0, inputs=inputs, **xargs)
...
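The condition above now chooses between cudnn_batch_norm and batch_norm without the trailing cuDNN-version clause. An illustration of the selection with made-up values (self.config.device is replaced by a plain variable here):

    # illustrative values only; in config_parser.py these come from the layer config
    use_gpu, batch_norm_type, parallel_nn, device = True, None, False, -1

    use_cudnn = use_gpu and batch_norm_type != "batch_norm" and \
        ((not parallel_nn) or device > -1)
    layer_type = "cudnn_batch_norm" if use_cudnn else "batch_norm"
    print layer_type  # cudnn_batch_norm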
python/paddle/trainer_config_helpers/attrs.py
View file @ 4cc42171
...
@@ -272,7 +272,7 @@ class ExtraLayerAttribute(object):
         for key in self.attr:
             if not hasattr(self, 'can_%s' % key) or \
                     not getattr(self, 'can_%s' % key):
-                raise NotImplementedError("Layer %s can not support %s" %
+                raise NotImplementedError("Layer %s does not support %s" %
                                           (layer_name, key))

     @staticmethod
...
python/paddle/trainer_config_helpers/layers.py
View file @ 4cc42171
...
@@ -865,7 +865,7 @@ def data_layer(name, size, height=None, width=None, layer_attr=None):
 @wrap_name_default("embedding")
 @wrap_param_attr_default()
-@layer_support(ERROR_CLIPPING)
+@layer_support(ERROR_CLIPPING, DROPOUT)
 def embedding_layer(input, size, name=None, param_attr=None, layer_attr=None):
     """
     Define a embedding Layer.
...
@@ -1320,7 +1320,7 @@ def pooling_layer(input,
 @wrap_act_default(param_names=['gate_act'], act=SigmoidActivation())
 @wrap_act_default(param_names=["act", 'state_act'], act=TanhActivation())
 @wrap_name_default("lstmemory")
-@layer_support(DROPOUT)
+@layer_support()
 def lstmemory(input,
               name=None,
               size=None,
...
@@ -1429,7 +1429,7 @@ def lstmemory(input,
 @wrap_act_default(param_names=['gate_act'], act=SigmoidActivation())
 @wrap_act_default(param_names=["act"], act=TanhActivation())
 @wrap_name_default("gru")
-@layer_support(DROPOUT)
+@layer_support()
 def grumemory(input,
               size=None,
               name=None,
...
@@ -1793,7 +1793,7 @@ def repeat_layer(input,
 @wrap_name_default("seqreshape")
 @wrap_act_default(act=IdentityActivation())
 @wrap_bias_attr_default(has_bias=False)
-@layer_support()
+@layer_support(ERROR_CLIPPING, DROPOUT)
 def seq_reshape_layer(input,
                       reshape_size,
                       act=None,
...
@@ -2703,7 +2703,7 @@ def img_cmrnorm_layer(input,
     default_factory=lambda _: ParamAttr(initial_mean=1.0, initial_std=0.))
 @wrap_act_default(act=ReluActivation())
 @wrap_name_default("batch_norm")
-@layer_support(DROPOUT)
+@layer_support(DROPOUT, ERROR_CLIPPING)
 def batch_norm_layer(input,
                      act=None,
                      name=None,
...
@@ -2783,15 +2783,6 @@ def batch_norm_layer(input,
     :return: LayerOutput object.
     :rtype: LayerOutput
     """
-    if not isinstance(act, ReluActivation):
-        logger.log(logging.WARN,
-                   "%s is not recommend for batch normalization's activation, "
-                   "maybe the relu is better" % act.name)
-
-    if not isinstance(input.activation, LinearActivation):
-        logger.log(logging.WARN,
-                   "The activation should be inside batch normalization, the "
-                   "previous layer's activation may be Linear")
-
     if num_channels is None:
         if input.num_filters is not None:
...
@@ -2861,7 +2852,7 @@ def sum_to_one_norm_layer(input, name=None, layer_attr=None):
 @wrap_name_default("addto")
 @wrap_act_default(act=LinearActivation())
 @wrap_bias_attr_default(has_bias=False)
-@layer_support(DROPOUT)
+@layer_support(DROPOUT, ERROR_CLIPPING)
 def addto_layer(input, act=None, name=None, bias_attr=None, layer_attr=None):
     """
     AddtoLayer.
...
@@ -2940,7 +2931,7 @@ def addto_layer(input, act=None, name=None, bias_attr=None, layer_attr=None):
 @wrap_act_default(act=IdentityActivation())
 @wrap_name_default("concat")
-@layer_support()
+@layer_support(DROPOUT, ERROR_CLIPPING)
 def concat_layer(input, act=None, name=None, layer_attr=None, bias_attr=None):
     """
     Concat all input vector into one huge vector.
...
@@ -3024,7 +3015,7 @@ def concat_layer(input, act=None, name=None, layer_attr=None, bias_attr=None):
 @wrap_name_default("seqconcat")
 @wrap_act_default(act=IdentityActivation())
 @wrap_bias_attr_default(has_bias=False)
-@layer_support()
+@layer_support(DROPOUT, ERROR_CLIPPING)
 def seq_concat_layer(a, b, act=None, name=None, layer_attr=None,
                      bias_attr=None):
     """
...
@@ -3177,7 +3168,7 @@ def memory(name,
 @wrap_act_default(param_names=['state_act'], act=TanhActivation())
 @wrap_act_default(act=TanhActivation())
 @wrap_name_default('lstm_step')
-@layer_support(ERROR_CLIPPING, DROPOUT)
+@layer_support()
 def lstm_step_layer(input,
                     state,
                     size=None,
...
@@ -4480,7 +4471,7 @@ def tensor_layer(a,
 @wrap_param_attr_default()
 @wrap_bias_attr_default()
 @wrap_act_default()
-@layer_support()
+@layer_support(DROPOUT, ERROR_CLIPPING)
 def selective_fc_layer(input,
                        size,
                        select=None,
...
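The decorator changes above adjust which ExtraLayerAttribute fields each layer accepts; layers that now list DROPOUT or ERROR_CLIPPING no longer raise NotImplementedError for the corresponding attributes. A hedged usage sketch (layer names and sizes are arbitrary):

    from paddle.trainer_config_helpers import *

    data = data_layer(name="word", size=10000)
    emb = embedding_layer(
        input=data,
        size=128,
        layer_attr=ExtraLayerAttribute(
            drop_rate=0.5,                    # allowed now that DROPOUT is declared
            error_clipping_threshold=10.0))   # ERROR_CLIPPING was already declared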
python/paddle/v2/__init__.py
View file @ 4cc42171
...
@@ -34,6 +34,7 @@ import minibatch
 import plot
 import image
 import model
+import paddle.trainer.config_parser as cp

 __all__ = [
     'optimizer',
...
@@ -58,6 +59,8 @@ __all__ = [
     'model',
 ]

+cp.begin_parse()
+

 def init(**kwargs):
     import py_paddle.swig_paddle as api
...
@@ -73,6 +76,11 @@ def init(**kwargs):
     for key in args_dict.keys():
         args.append('--%s=%s' % (key, str(args_dict[key])))

+    if 'use_gpu' in kwargs:
+        cp.g_command_config_args['use_gpu'] = kwargs['use_gpu']
+    assert 'parallel_nn' not in kwargs, ("currently 'parallel_nn' is not "
+                                         "supported in v2 APIs.")
+
     api.initPaddle(*args)
...
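With the change above, paddle.init() mirrors use_gpu into the config parser's global command-line args and rejects parallel_nn in the v2 API. A short usage sketch (flag values are examples only):

    import paddle.v2 as paddle

    paddle.init(use_gpu=False, trainer_count=1)
    # paddle.init(parallel_nn=True)  # would now trip the assertion added above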
python/paddle/v2/dataset/common.py
View file @ 4cc42171
...
@@ -166,55 +166,37 @@ def cluster_files_reader(files_pattern,
     return reader


-def convert(output_path,
-            reader,
-            num_shards,
-            name_prefix,
-            max_lines_to_shuffle=1000):
+def convert(output_path, reader, line_count, name_prefix):
     import recordio
     """
     Convert data from reader to recordio format files.

     :param output_path: directory in which output files will be saved.
     :param reader: a data reader, from which the convert program will read data instances.
-    :param num_shards: the number of shards that the dataset will be partitioned into.
     :param name_prefix: the name prefix of generated files.
     :param max_lines_to_shuffle: the max lines numbers to shuffle before writing.
     """

-    assert num_shards >= 1
-    assert max_lines_to_shuffle >= 1
-
-    def open_writers():
-        w = []
-        for i in range(0, num_shards):
-            n = "%s/%s-%05d-of-%05d" % (output_path, name_prefix, i,
-                                        num_shards - 1)
-            w.append(recordio.writer(n))
-
-        return w
-
-    def close_writers(w):
-        for i in range(0, num_shards):
-            w[i].close()
-
-    def write_data(w, lines):
+    assert line_count >= 1
+    indx_f = 0
+
+    def write_data(indx_f, lines):
         random.shuffle(lines)
-        for i, d in enumerate(lines):
+        filename = "%s/%s-%05d" % (output_path, name_prefix, indx_f)
+        writer = recordio.writer(filename)
+        for l in lines:
             # FIXME(Yancey1989):
             # dumps with protocol: pickle.HIGHEST_PROTOCOL
-            o = pickle.dumps(d)
-            w[i % num_shards].write(o)
-
-    w = open_writers()
-    lines = []
+            writer.write(cPickle.dumps(l))
+        writer.close()

+    lines = []
     for i, d in enumerate(reader()):
         lines.append(d)
-        if i % max_lines_to_shuffle == 0 and i >= max_lines_to_shuffle:
-            write_data(w, lines)
+        if i % line_count == 0 and i >= line_count:
+            write_data(indx_f, lines)
             lines = []
+            indx_f += 1
             continue

-    write_data(w, lines)
-    close_writers(w)
+    write_data(indx_f, lines)
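With the new signature, convert() groups line_count instances per RecordIO file named <name_prefix>-00000, <name_prefix>-00001, and so on. A hedged usage sketch; the uci_housing reader is used only as an example:

    import paddle.v2.dataset.common as common
    import paddle.v2.dataset.uci_housing as uci_housing

    # write ./recordio_out/housing-train-00000, -00001, ... with 1000 records each
    common.convert("./recordio_out", uci_housing.train(), 1000, "housing-train")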
python/paddle/v2/dataset/mq2007.py
View file @ 4cc42171
...
@@ -242,9 +242,9 @@ def gen_list(querylist):
     if not isinstance(querylist, QueryList):
         querylist = QueryList(querylist)
     querylist._correct_ranking_()
-    relevance_score_list = [query.relevance_score for query in querylist]
+    relevance_score_list = [[query.relevance_score] for query in querylist]
     feature_vector_list = [query.feature_vector for query in querylist]
-    yield np.array(relevance_score_list).T, np.array(feature_vector_list)
+    yield np.array(relevance_score_list), np.array(feature_vector_list)


 def query_filter(querylists):
...
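Wrapping each relevance_score in its own list makes np.array produce an (n, 1) column directly, so the transpose is no longer needed. A toy illustration of the shape difference:

    import numpy as np

    scores = [2, 0, 1]
    print np.array([[s] for s in scores]).shape  # (3, 1) -- new form
    print np.array(scores).T.shape               # (3,)   -- old form; .T is a no-op on a 1-D array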
python/paddle/v2/framework/create_op_creation_methods.py
View file @ 4cc42171
...
@@ -220,6 +220,9 @@ def create_op_creation_method(op_proto):
     __impl__.all_input_args = [var.name for var in op_proto.inputs]
     __impl__.all_output_args = [var.name for var in op_proto.outputs]
     __impl__.all_attr_args = [attr.name for attr in op_proto.attrs]
+    __impl__.all_not_temp_output_args = [
+        var.name for var in op_proto.outputs if not var.temporary
+    ]

     return __impl__
...
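The new all_not_temp_output_args attribute is what the Network wrapper added below uses to auto-name outputs the caller did not specify. A hedged sketch of how it is consumed (the fc op and the printed names are assumptions for illustration):

    import paddle.v2.framework.core as core
    from paddle.v2.framework.create_op_creation_methods import op_creations

    fc = op_creations.fc
    print fc.all_output_args            # every output, including temporaries
    print fc.all_not_temp_output_args   # temporaries filtered out
    print fc.__name__ + "@OUT@%d" % core.unique_integer()  # e.g. fc@OUT@0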
python/paddle/v2/framework/network.py
0 → 100644
View file @ 4cc42171
import paddle.v2.framework.core as core
from paddle.v2.framework.create_op_creation_methods import op_creations
from default_scope_funcs import create_var, get_var, get_cur_scope

__all__ = ['Network']  # Only expose Network


class NetworkFunctor(object):
    """
    Network Op Creation Function. Used internally in this module.
    It convert string input to Variable. If it is not created before, just
    create in scope.

    It is a functor object. means the instances are callable.

    :param func: The op creation function which generated in Python.
    :param net: The Network instance.
    """

    def __init__(self, func, net):
        self.func = func
        self.net = net

    def __call__(self, *args, **kwargs):
        if len(args) != 0:
            raise ValueError("Paddle must use keyword argument")
        inputs = self.func.all_input_args
        for ipt in inputs:
            if ipt in kwargs:
                var = kwargs[ipt]
                if isinstance(var, basestring):
                    var = create_var(var)
                if not isinstance(var, core.Variable):
                    raise TypeError(
                        "Input of op creation must be string or variable")
                kwargs[ipt] = get_cur_scope().get_var_name(var)

        notemp_outputs = self.func.all_not_temp_output_args

        for name in notemp_outputs:
            if name not in kwargs:
                kwargs[name] = self.func.__name__ + "@OUT@%d" % core.unique_integer()

        outputs = self.func.all_output_args
        for opt in outputs:
            if opt in kwargs:
                var = kwargs[opt]
                if isinstance(var, basestring):
                    var = create_var(var)
                if not isinstance(var, core.Variable):
                    raise TypeError(
                        "Output of op creation must be string or variable")
                kwargs[opt] = get_cur_scope().get_var_name(var)

        op = self.func(**kwargs)

        self.net.net.add_op(op)

        lst = [get_var(kwargs[opt]) for opt in notemp_outputs]
        if len(lst) == 1:
            return lst[0]
        elif len(lst) == 0:
            return None
        else:
            return lst


class Network(object):
    """
    The network concept. It avoid user to manually create operator, create
    variable, and combine them into a Net. Just use Network.xxx can create the
    operator, create variables in default scope, and add them into `self.net`.

    For example:

    ..  code-block: python

        net = Network()
        out = net.add_two(X="a", Y="b")
        fc_out = net.fc(X="out", W="fc.w")

        net.run(...)
    """

    def __init__(self):
        self.net = core.Net.create()
        funcs = (func_name for func_name in dir(op_creations)
                 if not func_name.startswith("__"))

        # TODO(yuyang18): This code can work, but do not generate a good
        # docstring, try to give a better way generate function in runtime
        # later.
        for func_name in funcs:
            func = getattr(op_creations, func_name)
            impl = NetworkFunctor(func, self)
            setattr(self, func_name, impl.__call__)
        self.__complete_add_op__ = False

    def infer_shape(self):
        self.complete_add_op()
        self.net.infer_shape(get_cur_scope())

    def run(self, device_context):
        self.complete_add_op()
        self.net.run(get_cur_scope(), device_context)

    def __str__(self):
        return str(self.net)

    def complete_add_op(self):
        if not self.__complete_add_op__:
            self.net.complete_add_op()
            self.__complete_add_op__ = True


if __name__ == '__main__':
    net = Network()
    out = net.add_two(X="a", Y="b")
    fc_out = net.fc(X=out, W="fc.w", b="fc.b", activation="softmax")
    net.complete_add_op()
    print net
python/paddle/v2/framework/tests/CMakeLists.txt
View file @ 4cc42171
...
@@ -3,7 +3,7 @@ add_python_test(test_framework
     test_scope.py
     test_default_scope_funcs.py
     test_op_creation_methods.py
-    test_plain_net.py
+    test_net.py
     test_tensor.py
     test_fc_op.py
     test_add_two_op.py
...
@@ -12,4 +12,5 @@ add_python_test(test_framework
     test_mul_op.py
     test_sigmoid_op.py
     test_softmax_op.py
-    test_rowwise_add_op.py)
+    test_rowwise_add_op.py
+    test_network.py)
python/paddle/v2/framework/tests/test_plain_net.py → python/paddle/v2/framework/tests/test_net.py
View file @ 4cc42171
...
@@ -5,11 +5,11 @@ import unittest

 class TestNet(unittest.TestCase):
     def test_net_all(self):
-        net = core.PlainNet.create()
+        net = core.Net.create()
         op1 = op_creations.add_two(X="X", Y="Y", Out="Out")
         net.add_op(op1)

-        net2 = core.PlainNet.create()
+        net2 = core.Net.create()
         net2.add_op(op_creations.fc(X="X", W="w", Y="fc.out"))
         net2.complete_add_op(True)
         net.add_op(net2)
...
python/paddle/v2/framework/tests/test_network.py
0 → 100644
View file @ 4cc42171
from paddle.v2.framework.network import Network
import paddle.v2.framework.core as core
import unittest


class TestNet(unittest.TestCase):
    def test_net_all(self):
        net = Network()
        out = net.add_two(X="X", Y="Y")
        fc_out = net.fc(X=out, W="w")
        net.complete_add_op()
        self.assertTrue(isinstance(fc_out, core.Variable))
        self.assertEqual(
            '''Op(plain_net), inputs:(@EMPTY@, X, Y, w), outputs:(@TEMP@fc@0, add_two@OUT@0, fc@OUT@1).
Op(add_two), inputs:(X, Y), outputs:(add_two@OUT@0).
Op(fc), inputs:(add_two@OUT@0, w, @EMPTY@), outputs:(fc@OUT@1, @TEMP@fc@0).
Op(mul), inputs:(add_two@OUT@0, w), outputs:(@TEMP@fc@0).
Op(sigmoid), inputs:(@TEMP@fc@0), outputs:(fc@OUT@1).
''', str(net))

        net2 = Network()
        tmp = net2.add_two(X="X", Y="Y")
        self.assertTrue(isinstance(tmp, core.Variable))
        net2.complete_add_op()

        self.assertEqual(
            '''Op(plain_net), inputs:(X, Y), outputs:(add_two@OUT@2).
Op(add_two), inputs:(X, Y), outputs:(add_two@OUT@2).
''', str(net2))


if __name__ == '__main__':
    unittest.main()
python/paddle/v2/framework/tests/test_recurrent_op.py
0 → 100644
View file @ 4cc42171
import paddle.v2.framework.core as core
import unittest
import numpy as np
import paddle.v2.framework.create_op_creation_methods as creation

ops = creation.op_creations


def create_tensor(scope, name, shape):
    tensor = scope.create_var(name).get_tensor()
    tensor.set_dims(shape)
    tensor.alloc_float()
    tensor.set(np.random.random(shape))
    return tensor


class TestRNN(unittest.TestCase):
    '''
    Test RNNOp

    equation:
        h_t = \sigma (W x_t + U h_{t-1})
    weights:
        - W
        - U
    vars:
        - x
    memories:
        - h
    outputs:
        - h
    '''

    def init(self):
        input_dim = 30
        batch_size = 50
        weight_dim = 15

        self.scope = core.Scope(None)

        # create vars
        create_tensor(self.scope, "x", [batch_size, input_dim])
        create_tensor(self.scope, "W", [input_dim, weight_dim])
        create_tensor(self.scope, "U", [weight_dim, weight_dim])
        create_tensor(self.scope, "h_boot", [batch_size, weight_dim])

        x_alias = "x@alias"
        y_alias = "y@alias"
        memory = "h@alias"
        prememory = "h@pre"
        output = "rnn_out"
        output_alias = "rnn_out@alias"

        # create step net
        stepnet_var = self.scope.create_var("stepnet")
        stepnet = stepnet_var.get_net()
        # stepnet = core.Net.create()
        x_fc_op = ops.fc(X=x_alias, W="W", Y="Wx")
        h_fc_op = ops.fc(X=prememory, W="U", Y="Uh")
        sum_op = ops.add_two(X="Wx", Y="Uh", Out="sum")
        sig_op = ops.sigmoid(X="sum", Y=memory)

        stepnet.add_op(x_fc_op)
        stepnet.add_op(h_fc_op)
        stepnet.add_op(sum_op)
        stepnet.add_op(sig_op)
        stepnet.complete_add_op(True)

        # create RNNOp
        rnnop = ops.recurrent_op(
            # inputs
            inlinks=["x"],
            boot_memories=["h_boot"],
            step_net="stepnet",
            # outputs
            outlinks=[output],
            step_scopes="step_scopes",
            # attributes
            inlink_alias=["x@alias"],
            outlink_alias=[output_alias],
            pre_memories=[prememory],
            memories=[memory])

        ctx = core.DeviceContext.cpu_context()
        rnnop.infer_shape(self.scope)
        rnnop.run(self.scope, ctx)

    def test_recurrent(self):
        self.init()


if __name__ == '__main__':
    unittest.main()
python/paddle/v2/inference.py
View file @ 4cc42171
...
@@ -35,6 +35,13 @@ class Inference(object):
                 name = param.getName()
                 assert isinstance(val, api.Vector)
                 val.copyFromNumpyArray(parameters.get(name).flatten())
+                # the setValueUpdated function is called in randomize, zeroMem,
+                # load function in paddle/parameter/Parameter.cpp. But in the
+                # inference mode, the setValueUpdated is never called, it will
+                # cause the parameter will not be dispatched
+                # in MultiGradientMachine for multi-GPU. So setValueUpdated is
+                # called here, but it's better to call this function in one place.
+                param.setValueUpdated()
         self.__gradient_machine__ = gm
         self.__data_types__ = topo.data_type()
...
python/paddle/v2/layer.py
View file @ 4cc42171
...
@@ -324,6 +324,3 @@ def parse_network(output_layers, extra_layers=None):

 def get_layer(name):
     return config_base.__layer_map__.get(name)
-
-
-cp.begin_parse()
python/paddle/v2/master/client.py
View file @ 4cc42171
...
@@ -49,7 +49,6 @@ class client(object):
     def set_dataset(self, paths):
         holder_type = ctypes.c_char_p * len(paths)
         holder = holder_type()
-        print paths
         for idx, path in enumerate(paths):
             c_ptr = ctypes.c_char_p(path)
             holder[idx] = c_ptr
...