提交 4cc42171 编写于 作者: Q qijun

merge baidu/develop

...@@ -22,9 +22,11 @@ ...@@ -22,9 +22,11 @@
hooks: hooks:
- id: clang-formater - id: clang-formater
- repo: https://github.com/PaddlePaddle/pre-commit-golang - repo: https://github.com/PaddlePaddle/pre-commit-golang
sha: 16398aeccf263adaf53b2495eed0406347d76281 sha: 8337620115c25ff8333f1b1a493bd031049bd7c0
hooks: hooks:
- id: go-fmt - id: go-fmt
types: [go] types:
- go
- id: gometalinter - id: gometalinter
types: [go] types:
- go
...@@ -7,17 +7,8 @@ INCLUDE_DIRECTORIES(${EIGEN_SOURCE_DIR}/src/extern_eigen3) ...@@ -7,17 +7,8 @@ INCLUDE_DIRECTORIES(${EIGEN_SOURCE_DIR}/src/extern_eigen3)
ExternalProject_Add( ExternalProject_Add(
extern_eigen3 extern_eigen3
${EXTERNAL_PROJECT_LOG_ARGS} ${EXTERNAL_PROJECT_LOG_ARGS}
# for latest version, please get from official website
# URL "https://bitbucket.org/eigen/eigen/get/3.3.4.tar.gz"
# URL_MD5 "1a47e78efe365a97de0c022d127607c3"
# for no-ssl http support, please get from bazel's mirror
# URL "http://mirror.bazel.build/bitbucket.org/eigen/eigen/get/f3a22f35b044.tar.gz"
# URL_MD5 "4645c66075982da6fa0bcf6b20f3e8f7"
# get from github mirror
GIT_REPOSITORY "https://github.com/RLovelett/eigen.git" GIT_REPOSITORY "https://github.com/RLovelett/eigen.git"
GIT_TAG "a46d2e7337c4656f00abe54a8115f6d76153a048" GIT_TAG "master"
PREFIX ${EIGEN_SOURCE_DIR} PREFIX ${EIGEN_SOURCE_DIR}
UPDATE_COMMAND "" UPDATE_COMMAND ""
CONFIGURE_COMMAND "" CONFIGURE_COMMAND ""
......
...@@ -153,7 +153,7 @@ set(CUDA_PROPAGATE_HOST_FLAGS OFF) ...@@ -153,7 +153,7 @@ set(CUDA_PROPAGATE_HOST_FLAGS OFF)
# Release/Debug flags set by cmake. Such as -O3 -g -DNDEBUG etc. # Release/Debug flags set by cmake. Such as -O3 -g -DNDEBUG etc.
# So, don't set these flags here. # So, don't set these flags here.
LIST(APPEND CUDA_NVCC_FLAGS -std=c++11) LIST(APPEND CUDA_NVCC_FLAGS -std=c++11 --default-stream per-thread)
LIST(APPEND CUDA_NVCC_FLAGS --use_fast_math) LIST(APPEND CUDA_NVCC_FLAGS --use_fast_math)
if(CMAKE_BUILD_TYPE STREQUAL "Debug") if(CMAKE_BUILD_TYPE STREQUAL "Debug")
......
...@@ -19,6 +19,8 @@ import ( ...@@ -19,6 +19,8 @@ import (
"net" "net"
"net/http" "net/http"
"net/rpc" "net/rpc"
"os"
"os/signal"
"strconv" "strconv"
"strings" "strings"
"time" "time"
...@@ -68,6 +70,20 @@ func main() { ...@@ -68,6 +70,20 @@ func main() {
store = &master.InMemStore{} store = &master.InMemStore{}
} }
shutdown := func() {
log.Infoln("shutting down gracefully")
err := store.Shutdown()
if err != nil {
log.Errorln(err)
}
}
// Guaranteed to run even panic happens.
defer shutdown()
c := make(chan os.Signal, 1)
signal.Notify(c, os.Interrupt)
s, err := master.NewService(store, *chunkPerTask, *taskTimeoutDur, *taskTimeoutMax) s, err := master.NewService(store, *chunkPerTask, *taskTimeoutDur, *taskTimeoutMax)
if err != nil { if err != nil {
log.Fatal(err) log.Fatal(err)
...@@ -84,8 +100,12 @@ func main() { ...@@ -84,8 +100,12 @@ func main() {
log.Fatal(err) log.Fatal(err)
} }
go func() {
err = http.Serve(l, nil) err = http.Serve(l, nil)
if err != nil { if err != nil {
log.Fatal(err) log.Fatal(err)
} }
}()
<-c
} }
...@@ -18,6 +18,8 @@ import ( ...@@ -18,6 +18,8 @@ import (
"net" "net"
"net/http" "net/http"
"net/rpc" "net/rpc"
"os"
"os/signal"
"strconv" "strconv"
"time" "time"
...@@ -33,7 +35,8 @@ func main() { ...@@ -33,7 +35,8 @@ func main() {
index := flag.Int("index", -1, "index of this pserver, should be larger or equal than 0") index := flag.Int("index", -1, "index of this pserver, should be larger or equal than 0")
etcdEndpoint := flag.String("etcd-endpoint", "http://127.0.0.1:2379", etcdEndpoint := flag.String("etcd-endpoint", "http://127.0.0.1:2379",
"comma separated endpoint string for pserver to connect to etcd") "comma separated endpoint string for pserver to connect to etcd")
etcdTimeout := flag.Duration("etcd-timeout", 5*time.Second, "timeout for etcd calls") dialTimeout := flag.Duration("dial-timeout", 5*time.Second, "dial timeout")
etcdTTL := flag.Int("etcd-ttl", 5, "etcd time to live in seconds")
numPservers := flag.Int("num-pservers", 1, "total pserver count in a training job") numPservers := flag.Int("num-pservers", 1, "total pserver count in a training job")
checkpointPath := flag.String("checkpoint-path", "/checkpoints/", "save checkpoint path") checkpointPath := flag.String("checkpoint-path", "/checkpoints/", "save checkpoint path")
checkpointInterval := flag.Duration("checkpoint-interval", 600*time.Second, "save checkpoint per interval seconds") checkpointInterval := flag.Duration("checkpoint-interval", 600*time.Second, "save checkpoint per interval seconds")
...@@ -53,7 +56,7 @@ func main() { ...@@ -53,7 +56,7 @@ func main() {
if *index >= 0 { if *index >= 0 {
idx = *index idx = *index
} else { } else {
e = pserver.NewEtcdClient(*etcdEndpoint, *numPservers, *etcdTimeout) e = pserver.NewEtcdClient(*etcdEndpoint, *numPservers, *dialTimeout, *etcdTTL)
idx, err = e.Register(*port) idx, err = e.Register(*port)
candy.Must(err) candy.Must(err)
...@@ -67,6 +70,20 @@ func main() { ...@@ -67,6 +70,20 @@ func main() {
} }
} }
shutdown := func() {
log.Infoln("shutting down gracefully")
sErr := e.Shutdown()
if sErr != nil {
log.Errorln(sErr)
}
}
// Guaranteed to run even panic happens.
defer shutdown()
c := make(chan os.Signal, 1)
signal.Notify(c, os.Interrupt)
s, err := pserver.NewService(idx, *checkpointInterval, *checkpointPath, e, cp) s, err := pserver.NewService(idx, *checkpointInterval, *checkpointPath, e, cp)
candy.Must(err) candy.Must(err)
...@@ -77,7 +94,11 @@ func main() { ...@@ -77,7 +94,11 @@ func main() {
l, err := net.Listen("tcp", ":"+strconv.Itoa(*port)) l, err := net.Listen("tcp", ":"+strconv.Itoa(*port))
candy.Must(err) candy.Must(err)
go func() {
log.Infof("start pserver at port %d", *port) log.Infof("start pserver at port %d", *port)
err = http.Serve(l, nil) err = http.Serve(l, nil)
candy.Must(err) candy.Must(err)
}()
<-c
} }
hash: a8faea3a363468a88917ddeb3b1c9ea36886fb2c622acbad42604fa9cb4d3855 hash: 2a1c0eca5c07a130e3d224f9821f96cfa37a39bf6bce141c855bbc57ef569f1c
updated: 2017-07-11T10:04:40.786745417+08:00 updated: 2017-07-29T07:34:48.722757905+08:00
imports: imports:
- name: github.com/beorn7/perks
version: 4c0e84591b9aa9e6dcfdf3e020114cd81f89d5f9
subpackages:
- quantile
- name: github.com/boltdb/bolt
version: 583e8937c61f1af6513608ccc75c97b6abdf4ff9
- name: github.com/cockroachdb/cmux
version: 112f0506e7743d64a6eb8fedbcff13d9979bbf92
- name: github.com/coreos/etcd - name: github.com/coreos/etcd
version: cb2a496c4ddd1c87a9f280e116649b599999ec79 version: c31bec0f29facff13f7c3e3d948e55dd6689ed42
subpackages: subpackages:
- alarm
- auth
- auth/authpb - auth/authpb
- client
- clientv3 - clientv3
- clientv3/concurrency - clientv3/concurrency
- compactor
- discovery
- embed
- error
- etcdserver
- etcdserver/api
- etcdserver/api/v2http
- etcdserver/api/v2http/httptypes
- etcdserver/api/v3client
- etcdserver/api/v3election
- etcdserver/api/v3election/v3electionpb
- etcdserver/api/v3election/v3electionpb/gw
- etcdserver/api/v3lock
- etcdserver/api/v3lock/v3lockpb
- etcdserver/api/v3lock/v3lockpb/gw
- etcdserver/api/v3rpc
- etcdserver/api/v3rpc/rpctypes - etcdserver/api/v3rpc/rpctypes
- etcdserver/auth
- etcdserver/etcdserverpb - etcdserver/etcdserverpb
- etcdserver/etcdserverpb/gw
- etcdserver/membership
- etcdserver/stats
- lease
- lease/leasehttp
- lease/leasepb
- mvcc
- mvcc/backend
- mvcc/mvccpb - mvcc/mvccpb
- pkg/adt
- pkg/contention
- pkg/cors
- pkg/cpuutil
- pkg/crc
- pkg/debugutil
- pkg/fileutil
- pkg/httputil
- pkg/idutil
- pkg/ioutil
- pkg/logutil
- pkg/monotime
- pkg/netutil
- pkg/pathutil
- pkg/pbutil
- pkg/runtime
- pkg/schedule
- pkg/srv
- pkg/tlsutil
- pkg/transport
- pkg/types
- pkg/wait
- proxy/grpcproxy/adapter
- raft
- raft/raftpb
- rafthttp
- snap
- snap/snappb
- store
- version
- wal
- wal/walpb
- name: github.com/coreos/go-semver
version: 8ab6407b697782a06568d4b7f1db25550ec2e4c6
subpackages:
- semver
- name: github.com/coreos/go-systemd
version: 48702e0da86bd25e76cfef347e2adeb434a0d0a6
subpackages:
- daemon
- journal
- util
- name: github.com/coreos/pkg
version: 3ac0863d7acf3bc44daf49afef8919af12f704ef
subpackages:
- capnslog
- name: github.com/dgrijalva/jwt-go
version: d2709f9f1f31ebcda9651b03077758c1f3a0018c
- name: github.com/ghodss/yaml
version: 0ca9ea5df5451ffdf184b4428c902747c2c11cd7
- name: github.com/gogo/protobuf
version: 909568be09de550ed094403c2bf8a261b5bb730a
subpackages:
- proto
- name: github.com/golang/protobuf - name: github.com/golang/protobuf
version: 4bd1920723d7b7c925de087aa32e2187708897f7 version: 4bd1920723d7b7c925de087aa32e2187708897f7
subpackages: subpackages:
...@@ -17,14 +107,61 @@ imports: ...@@ -17,14 +107,61 @@ imports:
- proto - proto
- name: github.com/golang/snappy - name: github.com/golang/snappy
version: 553a641470496b2327abcac10b36396bd98e45c9 version: 553a641470496b2327abcac10b36396bd98e45c9
- name: github.com/google/btree
version: 925471ac9e2131377a91e1595defec898166fe49
- name: github.com/grpc-ecosystem/go-grpc-prometheus
version: 6b7015e65d366bf3f19b2b2a000a831940f0f7e0
- name: github.com/grpc-ecosystem/grpc-gateway
version: 18d159699f2e83fc5bb9ef2f79465ca3f3122676
subpackages:
- runtime
- runtime/internal
- utilities
- name: github.com/jonboulle/clockwork
version: 2eee05ed794112d45db504eb05aa693efd2b8b09
- name: github.com/matttproud/golang_protobuf_extensions
version: c12348ce28de40eed0136aa2b644d0ee0650e56c
subpackages:
- pbutil
- name: github.com/namsral/flag - name: github.com/namsral/flag
version: 71ceffbeb0ba60fccc853971bb3ed4d7d90bfd04 version: 71ceffbeb0ba60fccc853971bb3ed4d7d90bfd04
- name: github.com/PaddlePaddle/recordio - name: github.com/PaddlePaddle/recordio
version: edfb82af0739c84f241c87390ec5649c7b28c129 version: 0432dee9fd4b24fb6840fb20a8c055b0c933fb81
- name: github.com/prometheus/client_golang
version: c5b7fccd204277076155f10851dad72b76a49317
subpackages:
- prometheus
- name: github.com/prometheus/client_model
version: 6f3806018612930941127f2a7c6c453ba2c527d2
subpackages:
- go
- name: github.com/prometheus/common
version: 49fee292b27bfff7f354ee0f64e1bc4850462edf
subpackages:
- expfmt
- internal/bitbucket.org/ww/goautoneg
- model
- name: github.com/prometheus/procfs
version: a1dba9ce8baed984a2495b658c82687f8157b98f
subpackages:
- xfs
- name: github.com/sirupsen/logrus - name: github.com/sirupsen/logrus
version: 7f976d3a76720c4c27af2ba716b85d2e0a7e38b1 version: a3f95b5c423586578a4e099b11a46c2479628cac
- name: github.com/topicai/candy - name: github.com/topicai/candy
version: 1b9030d056fa9f8c4b1f9c91b52fe4b8ab4cd8cc version: 1b9030d056fa9f8c4b1f9c91b52fe4b8ab4cd8cc
- name: github.com/ugorji/go
version: ded73eae5db7e7a0ef6f55aace87a2873c5d2b74
subpackages:
- codec
- name: github.com/xiang90/probing
version: 07dd2e8dfe18522e9c447ba95f2fe95262f63bb2
- name: golang.org/x/crypto
version: 1351f936d976c60a0a48d728281922cf63eafb8d
repo: https://github.com/golang/crypto.git
vcs: git
subpackages:
- bcrypt
- blowfish
- name: golang.org/x/net - name: golang.org/x/net
version: c8c74377599bd978aee1cf3b9b63a8634051cec2 version: c8c74377599bd978aee1cf3b9b63a8634051cec2
subpackages: subpackages:
...@@ -36,11 +173,15 @@ imports: ...@@ -36,11 +173,15 @@ imports:
- lex/httplex - lex/httplex
- trace - trace
- name: golang.org/x/sys - name: golang.org/x/sys
version: abf9c25f54453410d0c6668e519582a9e1115027 version: 0f826bdd13b500be0f1d4004938ad978fcc6031e
repo: https://github.com/golang/sys.git
vcs: git
subpackages: subpackages:
- unix - unix
- name: golang.org/x/text - name: golang.org/x/text
version: cfdf022e86b4ecfb646e1efbd7db175dd623a8fa version: 836efe42bb4aa16aaa17b9c155d8813d336ed720
repo: https://github.com/golang/text.git
vcs: git
subpackages: subpackages:
- secure/bidirule - secure/bidirule
- transform - transform
...@@ -60,4 +201,23 @@ imports: ...@@ -60,4 +201,23 @@ imports:
- stats - stats
- tap - tap
- transport - transport
testImports: [] - name: gopkg.in/yaml.v2
version: cd8b52f8269e0feb286dfeef29f8fe4d5b397e0b
testImports:
- name: github.com/davecgh/go-spew
version: 04cdfd42973bb9c8589fd6a731800cf222fde1a9
subpackages:
- spew
- name: github.com/docker/docker
version: b6d164e6c46d8115b146e4c3ac93784e9ef8b49e
subpackages:
- pkg/ioutils
- pkg/longpath
- name: github.com/pmezard/go-difflib
version: d8ed2627bdf02c080bf22230dbb337003b7aba2d
subpackages:
- difflib
- name: github.com/stretchr/testify
version: 05e8a0eda380579888eb53c394909df027f06991
subpackages:
- assert
...@@ -6,8 +6,19 @@ import: ...@@ -6,8 +6,19 @@ import:
subpackages: subpackages:
- clientv3 - clientv3
- clientv3/concurrency - clientv3/concurrency
- embed
- etcdserver
- package: github.com/namsral/flag - package: github.com/namsral/flag
version: ^1.7.4-pre version: ^1.7.4-pre
- package: github.com/sirupsen/logrus - package: github.com/sirupsen/logrus
version: ^1.0.0 version: ^1.0.0
- package: github.com/topicai/candy - package: github.com/topicai/candy
- package: golang.org/x/crypto
vcs: git
repo: https://github.com/golang/crypto.git
- package: golang.org/x/sys
vcs: git
repo: https://github.com/golang/sys.git
- package: golang.org/x/text
vcs: git
repo: https://github.com/golang/text.git
...@@ -18,7 +18,6 @@ package main ...@@ -18,7 +18,6 @@ package main
#include <stdlib.h> #include <stdlib.h>
#include <string.h> #include <string.h>
#include <stdio.h> #include <stdio.h>
#define PADDLE_MASTER_OK 0 #define PADDLE_MASTER_OK 0
#define PADDLE_MASTER_ERROR -1 #define PADDLE_MASTER_ERROR -1
...@@ -101,6 +100,12 @@ func paddle_release_master_client(client C.paddle_master_client) { ...@@ -101,6 +100,12 @@ func paddle_release_master_client(client C.paddle_master_client) {
remove(client) remove(client)
} }
//export paddle_start_get_records
func paddle_start_get_records(client C.paddle_master_client, pass C.int) {
c := get(client)
c.StartGetRecords(int(pass))
}
//export paddle_set_dataset //export paddle_set_dataset
func paddle_set_dataset(client C.paddle_master_client, path **C.char, size C.int) C.int { func paddle_set_dataset(client C.paddle_master_client, path **C.char, size C.int) C.int {
c := get(client) c := get(client)
...@@ -121,15 +126,19 @@ func paddle_set_dataset(client C.paddle_master_client, path **C.char, size C.int ...@@ -121,15 +126,19 @@ func paddle_set_dataset(client C.paddle_master_client, path **C.char, size C.int
// paddle_next_record gets the nexts training record. // paddle_next_record gets the nexts training record.
// //
// returns number of bytes of the records if success, -1 if failed. // returns number of bytes of the records if success, -1 if failed, -2 if pass end.
// //
//export paddle_next_record //export paddle_next_record
func paddle_next_record(client C.paddle_master_client, record **C.uchar) C.int { func paddle_next_record(client C.paddle_master_client, record **C.uchar) C.int {
c := get(client) c := get(client)
r, err := c.NextRecord() r, err := c.NextRecord()
if err != nil { if err != nil {
// Error // NOTE: use errors to indicate pass ends
// TODO: return the type of error? if err.Error() == master.ErrAllTaskFailed.Error() ||
err.Error() == master.ErrNoMoreAvailable.Error() ||
err.Error() == master.ErrPassBefore.Error() {
return -2
}
*record = (*C.uchar)(nil) *record = (*C.uchar)(nil)
return -1 return -1
} }
......
...@@ -16,7 +16,6 @@ package master ...@@ -16,7 +16,6 @@ package master
import ( import (
"os" "os"
"sync"
"time" "time"
"github.com/PaddlePaddle/Paddle/go/connection" "github.com/PaddlePaddle/Paddle/go/connection"
...@@ -29,7 +28,7 @@ import ( ...@@ -29,7 +28,7 @@ import (
type Client struct { type Client struct {
conn *connection.Conn conn *connection.Conn
ch chan record ch chan record
initChOnce sync.Once bufSize int
} }
type record struct { type record struct {
...@@ -46,11 +45,7 @@ func WithBuffer(bufSize int) func(*Client) error { ...@@ -46,11 +45,7 @@ func WithBuffer(bufSize int) func(*Client) error {
if bufSize <= 0 { if bufSize <= 0 {
return nil return nil
} }
c.bufSize = bufSize
c.initChOnce.Do(func() {
c.ch = make(chan record, bufSize)
go c.getRecords()
})
return nil return nil
} }
} }
...@@ -104,25 +99,41 @@ func NewClient(opts ...func(*Client) error) (*Client, error) { ...@@ -104,25 +99,41 @@ func NewClient(opts ...func(*Client) error) (*Client, error) {
if err != nil { if err != nil {
return nil, err return nil, err
} }
} }
c.ch = make(chan record, c.bufSize)
// FIXME: connection is created asyncrosly in monitorMaster go routine,
// ensure the connection is ready for use before calling c.addClient.
time.Sleep(time.Second)
return c, nil return c, nil
} }
func (c *Client) getRecords() { // StartGetRecords must be called at beginning of each pass
func (c *Client) StartGetRecords(passID int) {
go c.getRecords(passID)
}
func (c *Client) getRecords(passID int) {
for { for {
t, err := c.getTask() t, err := c.getTask(passID)
if err != nil { if err != nil {
log.Errorf("Get task failed, sleep 3 seconds and continue, %s", err) if err.Error() == ErrPassBefore.Error() ||
time.Sleep(3 * time.Second) err.Error() == ErrNoMoreAvailable.Error() ||
err.Error() == ErrAllTaskFailed.Error() {
c.ch <- record{nil, err}
break
}
if err.Error() == ErrPassAfter.Error() {
// wait util last pass finishes
time.Sleep(time.Second * 3)
continue continue
} }
log.Errorf("getTask error: %s", err)
}
for _, chunk := range t.Chunks { for _, chunk := range t.Chunks {
f, err := os.Open(chunk.Path) f, e := os.Open(chunk.Path)
if err != nil { if e != nil {
log.Errorln(err) log.Errorln(e)
continue continue
} }
...@@ -178,18 +189,21 @@ func (c *Client) monitorMaster(addrCh <-chan string) { ...@@ -178,18 +189,21 @@ func (c *Client) monitorMaster(addrCh <-chan string) {
} }
} }
// SetDataset set dataset for the master server to dispatch. // SetDataset sets dataset to dispatch for the master server.
//
// SetDataset can be call multiple times at one pass. But only the first call
// will be honored.
// //
// SetDataset can be call multiple times from different nodes. But // After all tasks are done, another call of SetDataset will start another pass.
// only the first call will be honored.
func (c *Client) SetDataset(globPaths []string) error { func (c *Client) SetDataset(globPaths []string) error {
return c.conn.Call("Service.SetDataset", globPaths, nil) err := c.conn.Call("Service.SetDataset", globPaths, nil)
return err
} }
// getTask gets a new task from the master server. // getTask gets a new task from the master server.
func (c *Client) getTask() (Task, error) { func (c *Client) getTask(passID int) (Task, error) {
var t Task var t Task
err := c.conn.Call("Service.GetTask", 0, &t) err := c.conn.Call("Service.GetTask", passID, &t)
return t, err return t, err
} }
...@@ -208,12 +222,6 @@ func (c *Client) taskFailed(meta TaskMeta) error { ...@@ -208,12 +222,6 @@ func (c *Client) taskFailed(meta TaskMeta) error {
// NextRecord will block until the next record is available. It is // NextRecord will block until the next record is available. It is
// thread-safe. // thread-safe.
func (c *Client) NextRecord() ([]byte, error) { func (c *Client) NextRecord() ([]byte, error) {
c.initChOnce.Do(func() {
// initialize with in case WithBuffer is not used.
c.ch = make(chan record, 0)
go c.getRecords()
})
r := <-c.ch r := <-c.ch
return r.r, r.err return r.r, r.err
} }
......
...@@ -54,22 +54,22 @@ func TestGetFinishTask(t *testing.T) { ...@@ -54,22 +54,22 @@ func TestGetFinishTask(t *testing.T) {
panic(err) panic(err)
} }
go func(l net.Listener) { go func(l net.Listener) {
s, err := NewService(&InMemStore{}, chunkPerTask, time.Second, 1) s, sErr := NewService(&InMemStore{}, chunkPerTask, time.Second, 1)
if err != nil { if sErr != nil {
panic(err) panic(sErr)
} }
server := rpc.NewServer() server := rpc.NewServer()
err = server.Register(s) sErr = server.Register(s)
if err != nil { if sErr != nil {
panic(err) panic(sErr)
} }
mux := http.NewServeMux() mux := http.NewServeMux()
mux.Handle(rpc.DefaultRPCPath, server) mux.Handle(rpc.DefaultRPCPath, server)
err = http.Serve(l, mux) sErr = http.Serve(l, mux)
if err != nil { if sErr != nil {
panic(err) panic(sErr)
} }
}(l) }(l)
...@@ -103,6 +103,7 @@ func TestGetFinishTask(t *testing.T) { ...@@ -103,6 +103,7 @@ func TestGetFinishTask(t *testing.T) {
ch := make(chan string, 1) ch := make(chan string, 1)
ch <- addr ch <- addr
go c.monitorMaster(ch) go c.monitorMaster(ch)
err = c.SetDataset([]string{path}) err = c.SetDataset([]string{path})
if err != nil { if err != nil {
panic(err) panic(err)
...@@ -111,44 +112,47 @@ func TestGetFinishTask(t *testing.T) { ...@@ -111,44 +112,47 @@ func TestGetFinishTask(t *testing.T) {
checkOnePass := func(i int) { checkOnePass := func(i int) {
var tasks []Task var tasks []Task
for idx := 0; idx < totalTask; idx++ { for idx := 0; idx < totalTask; idx++ {
task, err := c.getTask() task, cErr := c.getTask(i)
if err != nil { if cErr != nil && cErr.Error() != ErrNoMoreAvailable.Error() && cErr.Error() != ErrPassAfter.Error() {
t.Fatalf("Error: %v, pass: %d\n", err, i) t.Fatalf("error: %v, pass: %d\n", cErr, i)
} }
tasks = append(tasks, task) tasks = append(tasks, task)
} }
_, err = c.getTask() // getting task before task finishes should return error
if err == nil { _, cErr := c.getTask(i)
if cErr == nil {
t.Fatalf("Should get error, pass: %d\n", i) t.Fatalf("Should get error, pass: %d\n", i)
} }
err = c.taskFinished(tasks[0].Meta.ID) cErr = c.taskFinished(tasks[0].Meta.ID)
if err != nil { if cErr != nil {
t.Fatalf("Error: %v, pass: %d\n", err, i) t.Fatalf("Error: %v, pass: %d\n", cErr, i)
} }
// call taskFailed once won't put the task to failed queue, just ensure
err = c.taskFailed(tasks[0].Meta) // the call
if err != nil { cErr = c.taskFailed(tasks[0].Meta)
t.Fatalf("Error: %v, pass: %d\n", err, i) if cErr != nil {
t.Fatalf("Error: %v, pass: %d\n", cErr, i)
} }
tasks = tasks[1:] tasks = tasks[1:]
task, err := c.getTask() _, cErr = c.getTask(i)
if err != nil { if cErr != nil && cErr.Error() != ErrNoMoreAvailable.Error() && cErr.Error() != ErrPassAfter.Error() {
t.Fatal(err) t.Fatalf("Should be ErrNoMoreAvailable or ErrPassAfter: %s", cErr)
} }
tasks = append(tasks, task)
for _, task := range tasks { for _, task := range tasks {
err = c.taskFinished(task.Meta.ID) cErr = c.taskFinished(task.Meta.ID)
if err != nil { if cErr != nil {
t.Fatalf("Error: %v, pass: %d\n", err, i) t.Fatal(cErr)
} }
} }
} }
for i := 0; i < 10; i++ { for i := 0; i < 10; i++ {
// init pass data
c.StartGetRecords(i)
checkOnePass(i) checkOnePass(i)
} }
} }
...@@ -20,8 +20,10 @@ import ( ...@@ -20,8 +20,10 @@ import (
"net/http" "net/http"
"net/rpc" "net/rpc"
"os" "os"
"runtime"
"strconv" "strconv"
"strings" "strings"
"sync"
"testing" "testing"
"time" "time"
...@@ -29,6 +31,18 @@ import ( ...@@ -29,6 +31,18 @@ import (
"github.com/PaddlePaddle/recordio" "github.com/PaddlePaddle/recordio"
) )
// tool function for testing output goroutine ids
func goid() int {
var buf [64]byte
n := runtime.Stack(buf[:], false)
idField := strings.Fields(strings.TrimPrefix(string(buf[:n]), "goroutine "))[0]
id, err := strconv.Atoi(idField)
if err != nil {
panic(fmt.Sprintf("cannot get goroutine id: %v", err))
}
return id
}
func TestNextRecord(t *testing.T) { func TestNextRecord(t *testing.T) {
const ( const (
path = "/tmp/master_client_TestFull" path = "/tmp/master_client_TestFull"
...@@ -45,7 +59,7 @@ func TestNextRecord(t *testing.T) { ...@@ -45,7 +59,7 @@ func TestNextRecord(t *testing.T) {
panic(err) panic(err)
} }
go func(l net.Listener) { go func(l net.Listener) {
s, err := master.NewService(&master.InMemStore{}, 10, time.Second, 1) s, err := master.NewService(&master.InMemStore{}, 1, time.Second*60, 1)
if err != nil { if err != nil {
panic(err) panic(err)
} }
...@@ -69,7 +83,7 @@ func TestNextRecord(t *testing.T) { ...@@ -69,7 +83,7 @@ func TestNextRecord(t *testing.T) {
panic(err) panic(err)
} }
w := recordio.NewWriter(f, -1, -1) w := recordio.NewWriter(f, 1, -1)
for i := 0; i < total; i++ { for i := 0; i < total; i++ {
_, err = w.Write([]byte{byte(i)}) _, err = w.Write([]byte{byte(i)})
if err != nil { if err != nil {
...@@ -87,32 +101,49 @@ func TestNextRecord(t *testing.T) { ...@@ -87,32 +101,49 @@ func TestNextRecord(t *testing.T) {
panic(err) panic(err)
} }
c, err := master.NewClient(master.WithAddr(fmt.Sprintf(":%d", p)), master.WithBuffer(10)) // start several client to test task fetching
if err != nil { var wg sync.WaitGroup
panic(err) for i := 0; i < 4; i++ {
wg.Add(1)
// test for multiple concurrent clients
go func() {
defer wg.Done()
// each go-routine needs a single client connection instance
c, e := master.NewClient(master.WithAddr(fmt.Sprintf(":%d", p)), master.WithBuffer(1))
if e != nil {
t.Fatal(e)
} }
e = c.SetDataset([]string{path})
err = c.SetDataset([]string{path}) if e != nil {
if err != nil { panic(e)
panic(err)
} }
// test for n passes
for pass := 0; pass < 10; pass++ {
c.StartGetRecords(pass)
for pass := 0; pass < 50; pass++ {
received := make(map[byte]bool) received := make(map[byte]bool)
for i := 0; i < total; i++ { taskid := 0
r, err := c.NextRecord() for {
if err != nil { r, e := c.NextRecord()
t.Fatal(pass, i, "Read error:", err) if e != nil {
// ErrorPassAfter will wait, else break for next pass
if e.Error() == master.ErrPassBefore.Error() ||
e.Error() == master.ErrNoMoreAvailable.Error() {
break
}
t.Fatal(pass, taskid, "Read error:", e)
} }
if len(r) != 1 { if len(r) != 1 {
t.Fatal(pass, i, "Length should be 1.", r) t.Fatal(pass, taskid, "Length should be 1.", r)
} }
if received[r[0]] { if received[r[0]] {
t.Fatal(pass, i, "Received duplicate.", received, r) t.Fatal(pass, taskid, "Received duplicate.", received, r)
} }
taskid++
received[r[0]] = true received[r[0]] = true
} }
} }
}()
}
wg.Wait()
} }
...@@ -39,15 +39,12 @@ type EtcdClient struct { ...@@ -39,15 +39,12 @@ type EtcdClient struct {
statePath string statePath string
client *clientv3.Client client *clientv3.Client
lock *concurrency.Mutex lock *concurrency.Mutex
sess *concurrency.Session
} }
// NewEtcdClient creates a new EtcdClient. // NewEtcdClient creates a new EtcdClient.
func NewEtcdClient(endpoints []string, addr string, lockPath, addrPath, statePath string, ttlSec int) (*EtcdClient, error) { func NewEtcdClient(endpoints []string, addr string, lockPath, addrPath, statePath string, ttlSec int) (*EtcdClient, error) {
log.Debugf("Connecting to etcd at %v", endpoints) log.Debugf("Connecting to etcd at %v", endpoints)
// TODO(helin): gracefully shutdown etcd store. Because etcd
// store holds a etcd lock, even though the lock will expire
// when the lease timeout, we need to implement graceful
// shutdown to release the lock.
cli, err := clientv3.New(clientv3.Config{ cli, err := clientv3.New(clientv3.Config{
Endpoints: endpoints, Endpoints: endpoints,
DialTimeout: dialTimeout, DialTimeout: dialTimeout,
...@@ -67,12 +64,12 @@ func NewEtcdClient(endpoints []string, addr string, lockPath, addrPath, statePat ...@@ -67,12 +64,12 @@ func NewEtcdClient(endpoints []string, addr string, lockPath, addrPath, statePat
// one master running, but split-brain problem may cause // one master running, but split-brain problem may cause
// multiple master servers running), and the cluster management // multiple master servers running), and the cluster management
// software will kill one of them. // software will kill one of them.
log.Debugf("Trying to acquire lock at %s.", lockPath) log.Infof("Trying to acquire lock at %s.", lockPath)
err = lock.Lock(context.TODO()) err = lock.Lock(context.TODO())
if err != nil { if err != nil {
return nil, err return nil, err
} }
log.Debugf("Successfully acquired lock at %s.", lockPath) log.Infof("Successfully acquired lock at %s.", lockPath)
put := clientv3.OpPut(addrPath, addr) put := clientv3.OpPut(addrPath, addr)
resp, err := cli.Txn(context.Background()).If(lock.IsOwner()).Then(put).Commit() resp, err := cli.Txn(context.Background()).If(lock.IsOwner()).Then(put).Commit()
...@@ -89,6 +86,7 @@ func NewEtcdClient(endpoints []string, addr string, lockPath, addrPath, statePat ...@@ -89,6 +86,7 @@ func NewEtcdClient(endpoints []string, addr string, lockPath, addrPath, statePat
statePath: statePath, statePath: statePath,
client: cli, client: cli,
lock: lock, lock: lock,
sess: sess,
} }
return e, nil return e, nil
...@@ -157,6 +155,21 @@ func (e *EtcdClient) Load() ([]byte, error) { ...@@ -157,6 +155,21 @@ func (e *EtcdClient) Load() ([]byte, error) {
return state, nil return state, nil
} }
// Shutdown shuts down the etcd client gracefully.
func (e *EtcdClient) Shutdown() error {
err := e.sess.Close()
newErr := e.client.Close()
if newErr != nil {
if err == nil {
err = newErr
} else {
log.Errorln(newErr)
}
}
return err
}
// GetKey gets the value by the specify key. // GetKey gets the value by the specify key.
func GetKey(c *clientv3.Client, key string, timeout time.Duration) (string, error) { func GetKey(c *clientv3.Client, key string, timeout time.Duration) (string, error) {
ctx, cancel := context.WithTimeout(context.Background(), timeout) ctx, cancel := context.WithTimeout(context.Background(), timeout)
......
...@@ -40,3 +40,8 @@ func (m *InMemStore) Load() ([]byte, error) { ...@@ -40,3 +40,8 @@ func (m *InMemStore) Load() ([]byte, error) {
return m.buf, nil return m.buf, nil
} }
// Shutdown shuts down the in mem store.
func (m *InMemStore) Shutdown() error {
return nil
}
...@@ -19,6 +19,7 @@ import ( ...@@ -19,6 +19,7 @@ import (
"compress/gzip" "compress/gzip"
"encoding/gob" "encoding/gob"
"errors" "errors"
"math/rand"
"os" "os"
"path/filepath" "path/filepath"
"sync" "sync"
...@@ -33,10 +34,23 @@ const ( ...@@ -33,10 +34,23 @@ const (
dialTimeout = 5 * time.Second dialTimeout = 5 * time.Second
) )
// ErrAllTaskFailed occur when tasks are in done or failed state.
var ErrAllTaskFailed = errors.New("all task finished")
// ErrNoMoreAvailable occur when no task in todo and yet not all done or fail.
var ErrNoMoreAvailable = errors.New("no more available task")
// ErrPassBefore client side pass number does not match with master counter.
var ErrPassBefore = errors.New("pass number smaller than master")
// ErrPassAfter client side pass number does not match with master counter.
var ErrPassAfter = errors.New("pass number larger than master")
// Store is the interface for save and load the master state. // Store is the interface for save and load the master state.
type Store interface { type Store interface {
Save([]byte) error Save([]byte) error
Load() ([]byte, error) Load() ([]byte, error)
Shutdown() error
} }
// Chunk is a chunk of data consisted of several data instances. // Chunk is a chunk of data consisted of several data instances.
...@@ -75,17 +89,26 @@ type Service struct { ...@@ -75,17 +89,26 @@ type Service struct {
chunksPerTask int chunksPerTask int
timeoutDur time.Duration timeoutDur time.Duration
failureMax int failureMax int
ready chan struct{}
store Store store Store
mu sync.Mutex ready chan struct{}
initDone bool initDone bool
mu sync.Mutex
taskQueues taskQueues taskQueues taskQueues
currPass int
jobTasks []taskEntry
savingTrainer string savingTrainer string
} }
func partition(chunks []Chunk, chunksPerTask int) []taskEntry { func partition(chunks []Chunk, chunksPerTask int) []taskEntry {
id := 0 // generate uniq id across job using nanosecond + randint + counter
// FIXME(typhoonzero): this is a workaround, use uuid
randStart := rand.Int()
counter := 0
timestamp := time.Now().Nanosecond()
id := timestamp + randStart + counter
if chunksPerTask <= 0 { if chunksPerTask <= 0 {
chunksPerTask = 1 chunksPerTask = 1
} }
...@@ -95,7 +118,8 @@ func partition(chunks []Chunk, chunksPerTask int) []taskEntry { ...@@ -95,7 +118,8 @@ func partition(chunks []Chunk, chunksPerTask int) []taskEntry {
for i, c := range chunks { for i, c := range chunks {
if i%chunksPerTask == 0 && len(cur.Task.Chunks) > 0 { if i%chunksPerTask == 0 && len(cur.Task.Chunks) > 0 {
cur.Task.Meta.ID = id cur.Task.Meta.ID = id
id++ counter++
id = timestamp + randStart + counter
result = append(result, cur) result = append(result, cur)
cur.Task.Chunks = nil cur.Task.Chunks = nil
} }
...@@ -266,19 +290,21 @@ func (s *Service) SetDataset(globPaths []string, _ *int) error { ...@@ -266,19 +290,21 @@ func (s *Service) SetDataset(globPaths []string, _ *int) error {
return err return err
} }
s.taskQueues.Todo = partition(chunks, s.chunksPerTask) s.jobTasks = partition(chunks, s.chunksPerTask)
s.taskQueues.Todo = s.jobTasks
err = s.snapshot() err = s.snapshot()
if err != nil { if err != nil {
log.Errorln(err) log.Errorln(err)
return err return err
} }
close(s.ready) close(s.ready)
s.initDone = true s.initDone = true
return nil return nil
} }
// processFailedTask retry s.failureMax times for failed task.
// return true if all task are done or failed.
func (s *Service) processFailedTask(t taskEntry, epoch int) { func (s *Service) processFailedTask(t taskEntry, epoch int) {
if t.Task.Meta.Epoch != epoch { if t.Task.Meta.Epoch != epoch {
// new epoch, task launched after the // new epoch, task launched after the
...@@ -302,8 +328,9 @@ func (s *Service) processFailedTask(t taskEntry, epoch int) { ...@@ -302,8 +328,9 @@ func (s *Service) processFailedTask(t taskEntry, epoch int) {
return return
} }
log.Warningf("Task %v failed %d times, discard.", t.Task, t.NumFailure) log.Warningf("Task %v failed %d times, re-dispatch.", t.Task, t.NumFailure)
s.taskQueues.Todo = append(s.taskQueues.Todo, t) s.taskQueues.Todo = append(s.taskQueues.Todo, t)
return
} }
func (s *Service) checkTimeoutFunc(taskID int, epoch int) func() { func (s *Service) checkTimeoutFunc(taskID int, epoch int) func() {
...@@ -331,37 +358,30 @@ func (s *Service) logFields() log.Fields { ...@@ -331,37 +358,30 @@ func (s *Service) logFields() log.Fields {
} }
// GetTask gets a new task from the service. // GetTask gets a new task from the service.
func (s *Service) GetTask(_ int, task *Task) error { // passID is the client side pass count
func (s *Service) GetTask(passID int, task *Task) error {
select { select {
case <-s.ready: case <-s.ready:
} }
s.mu.Lock() s.mu.Lock()
defer s.mu.Unlock() defer s.mu.Unlock()
if passID < s.currPass {
return ErrPassBefore
}
if passID > s.currPass {
// Client may get run to pass after master when one client faster than the
// other
return ErrPassAfter
}
if len(s.taskQueues.Todo) == 0 { if len(s.taskQueues.Todo) == 0 {
if len(s.taskQueues.Done) == 0 { if len(s.taskQueues.Done) == 0 && len(s.taskQueues.Pending) == 0 {
if len(s.taskQueues.Pending) == 0 { log.WithFields(s.logFields()).Warningln("All tasks failed, may start next pass")
err := errors.New("all task failed") return ErrAllTaskFailed
log.WithFields(s.logFields()).Warningln("All tasks failed.")
return err
} }
// TODO(helin): client need to retry in this
// error case. Gotcha: RPC client can't
// compare returned error with predefined
// errors like io.EOF, because the error
// instance deserialized from RPC is a
// different instance than the error defined
// in package. So we need to figure out a way
// for client to check this error correctly.
err := errors.New("no more available task")
log.WithFields(s.logFields()).Warningln("No more available task.") log.WithFields(s.logFields()).Warningln("No more available task.")
return err return ErrNoMoreAvailable
}
s.taskQueues.Todo = s.taskQueues.Done
s.taskQueues.Done = nil
log.WithFields(s.logFields()).Infoln("No more todo task, but trainer is requesting task to do. Move all done task to todo.")
} }
t := s.taskQueues.Todo[0] t := s.taskQueues.Todo[0]
...@@ -381,7 +401,7 @@ func (s *Service) GetTask(_ int, task *Task) error { ...@@ -381,7 +401,7 @@ func (s *Service) GetTask(_ int, task *Task) error {
} }
// TaskFinished tell the service that a task is finished. // TaskFinished tell the service that a task is finished.
func (s *Service) TaskFinished(taskID int, _ *int) error { func (s *Service) TaskFinished(taskID int, dummy *int) error {
select { select {
case <-s.ready: case <-s.ready:
} }
...@@ -401,11 +421,14 @@ func (s *Service) TaskFinished(taskID int, _ *int) error { ...@@ -401,11 +421,14 @@ func (s *Service) TaskFinished(taskID int, _ *int) error {
delete(s.taskQueues.Pending, taskID) delete(s.taskQueues.Pending, taskID)
log.WithFields(s.logFields()).Infof("Task #%d finished.", taskID) log.WithFields(s.logFields()).Infof("Task #%d finished.", taskID)
if len(s.taskQueues.Todo) == 0 && len(s.taskQueues.Pending) == 0 {
if len(s.taskQueues.Pending) == 0 && len(s.taskQueues.Todo) == 0 { // increase master side pass count if all tasks finished
log.WithFields(s.logFields()).Infoln("No more todo and pending task, start a new pass.") s.currPass++
s.taskQueues.Todo = append(s.taskQueues.Todo, s.taskQueues.Done...) s.taskQueues.Todo = s.jobTasks
s.taskQueues.Done = nil s.taskQueues.Done = []taskEntry{}
// TODO(typhoonzero): deal with failed tasks
s.taskQueues.Failed = []taskEntry{}
log.WithFields(s.logFields()).Warningf("all task finished, add new pass data, newpass: %d.", s.currPass)
} }
err := s.snapshot() err := s.snapshot()
...@@ -416,7 +439,7 @@ func (s *Service) TaskFinished(taskID int, _ *int) error { ...@@ -416,7 +439,7 @@ func (s *Service) TaskFinished(taskID int, _ *int) error {
} }
// TaskFailed tells the service that a task is failed. // TaskFailed tells the service that a task is failed.
func (s *Service) TaskFailed(meta TaskMeta, _ *int) error { func (s *Service) TaskFailed(meta TaskMeta, dummy *int) error {
select { select {
case <-s.ready: case <-s.ready:
} }
......
...@@ -44,7 +44,8 @@ func TestPartionIndex(t *testing.T) { ...@@ -44,7 +44,8 @@ func TestPartionIndex(t *testing.T) {
cs := make([]Chunk, 100) cs := make([]Chunk, 100)
ts := partition(cs, 20) ts := partition(cs, 20)
for i := range ts { for i := range ts {
if ts[i].Task.Meta.ID != i { // test auto increament ids
if i > 0 && ts[i].Task.Meta.ID != ts[i-1].Task.Meta.ID+1 {
t.Error(ts[i], i) t.Error(ts[i], i)
} }
} }
......
package master_test
import (
"os"
"testing"
"time"
"github.com/PaddlePaddle/Paddle/go/master"
"github.com/coreos/etcd/clientv3"
"github.com/coreos/etcd/embed"
"github.com/docker/docker/pkg/ioutils"
"github.com/stretchr/testify/assert"
)
func TestNewServiceWithEtcd(t *testing.T) {
// setup an embed etcd server
etcdDir, err := ioutils.TempDir("", "")
if err != nil {
t.Fatal(err)
}
cfg := embed.NewConfig()
cfg.Dir = etcdDir
e, err := embed.StartEtcd(cfg)
if err != nil {
t.Fatal(err)
}
defer func() {
e.Close()
if err := os.RemoveAll(etcdDir); err != nil {
t.Fatal(err)
}
}()
select {
case <-e.Server.ReadyNotify():
t.Log("Server is ready!")
case <-time.After(60 * time.Second):
e.Server.Stop() // trigger a shutdown
t.Fatal("Server took too long to start!")
}
ep := []string{"127.0.0.1:2379"}
masterAddr := "127.0.0.1:3306"
store, err := master.NewEtcdClient(ep, masterAddr, master.DefaultLockPath, master.DefaultAddrPath, master.DefaultStatePath, 30)
if err != nil {
t.Fatal(err)
}
_, err = master.NewService(store, 10, 10, 3)
if err != nil {
t.Fatal(err)
}
cli, err := clientv3.New(clientv3.Config{
Endpoints: ep,
DialTimeout: 3 * time.Second,
})
if err != nil {
t.Fatal(err)
}
v, err := master.GetKey(cli, master.DefaultAddrPath, 3*time.Second)
if err != nil {
t.Fatal(err)
}
if err := cli.Close(); err != nil {
t.Fatal(err)
}
// test master process registry itself into etcd server.
assert.Equal(t, masterAddr, v, "master process should registry itself into etcd server.")
}
...@@ -55,10 +55,10 @@ var curHandle C.paddle_pserver_client ...@@ -55,10 +55,10 @@ var curHandle C.paddle_pserver_client
func add(c *client.Client) C.paddle_pserver_client { func add(c *client.Client) C.paddle_pserver_client {
mu.Lock() mu.Lock()
defer mu.Unlock() defer mu.Unlock()
client := curHandle cli := curHandle
curHandle++ curHandle++
handleMap[client] = c handleMap[cli] = c
return client return cli
} }
func get(client C.paddle_pserver_client) *client.Client { func get(client C.paddle_pserver_client) *client.Client {
......
...@@ -6,16 +6,19 @@ import cPickle as pickle ...@@ -6,16 +6,19 @@ import cPickle as pickle
etcd_ip = os.getenv("MASTER_IP", "127.0.0.1") etcd_ip = os.getenv("MASTER_IP", "127.0.0.1")
etcd_endpoint = "http://" + etcd_ip + ":2379" etcd_endpoint = "http://" + etcd_ip + ":2379"
print "connecting to master, etcd endpoints: ", etcd_endpoint
master_client = master.client(etcd_endpoint, 5, 64)
def cloud_reader(): def cloud_reader():
print "connecting to master, etcd endpoints: ", etcd_endpoint global master_client
master_client = master.client(etcd_endpoint, 5, 64)
master_client.set_dataset( master_client.set_dataset(
["/pfs/dlnel/public/dataset/uci_housing/uci_housing-*-of-*"]) ["/pfs/dlnel/public/dataset/uci_housing/uci_housing-*"], passes=30)
while 1: while 1:
r, e = master_client.next_record() r, e = master_client.next_record()
if not r: if not r:
if e != -2: # other errors
print "get record error:", e
break break
yield pickle.loads(r) yield pickle.loads(r)
...@@ -27,10 +30,12 @@ def main(): ...@@ -27,10 +30,12 @@ def main():
# network config # network config
x = paddle.layer.data(name='x', type=paddle.data_type.dense_vector(13)) x = paddle.layer.data(name='x', type=paddle.data_type.dense_vector(13))
y_predict = paddle.layer.fc(input=x, y_predict = paddle.layer.fc(input=x,
param_attr=paddle.attr.Param(name='w'), param_attr=paddle.attr.Param(
name='w', learning_rate=1e-3),
size=1, size=1,
act=paddle.activation.Linear(), act=paddle.activation.Linear(),
bias_attr=paddle.attr.Param(name='b')) bias_attr=paddle.attr.Param(
name='b', learning_rate=1e-3))
y = paddle.layer.data(name='y', type=paddle.data_type.dense_vector(1)) y = paddle.layer.data(name='y', type=paddle.data_type.dense_vector(1))
cost = paddle.layer.mse_cost(input=y_predict, label=y) cost = paddle.layer.mse_cost(input=y_predict, label=y)
...@@ -38,9 +43,8 @@ def main(): ...@@ -38,9 +43,8 @@ def main():
parameters = paddle.parameters.create(cost) parameters = paddle.parameters.create(cost)
# create optimizer of new remote updater to pserver # create optimizer of new remote updater to pserver
optimizer = paddle.optimizer.Momentum(momentum=0) optimizer = paddle.optimizer.Momentum(momentum=0, learning_rate=1e-3)
print "etcd endoint: ", etcd_endpoint
trainer = paddle.trainer.SGD(cost=cost, trainer = paddle.trainer.SGD(cost=cost,
parameters=parameters, parameters=parameters,
update_equation=optimizer, update_equation=optimizer,
...@@ -51,6 +55,8 @@ def main(): ...@@ -51,6 +55,8 @@ def main():
# event_handler to print training and testing info # event_handler to print training and testing info
def event_handler(event): def event_handler(event):
if isinstance(event, paddle.event.EndIteration): if isinstance(event, paddle.event.EndIteration):
# FIXME: for cloud data reader, pass number is managed by master
# should print the server side pass number
if event.batch_id % 100 == 0: if event.batch_id % 100 == 0:
print "Pass %d, Batch %d, Cost %f" % ( print "Pass %d, Batch %d, Cost %f" % (
event.pass_id, event.batch_id, event.cost) event.pass_id, event.batch_id, event.cost)
......
...@@ -34,16 +34,19 @@ const ( ...@@ -34,16 +34,19 @@ const (
PsPath = "/ps/" PsPath = "/ps/"
// PsCheckpoint is the etcd path for store checkpoints information // PsCheckpoint is the etcd path for store checkpoints information
PsCheckpoint = "/checkpoints/" PsCheckpoint = "/checkpoints/"
retryTimeout = 5 * time.Second
) )
// EtcdClient is the etcd client that the pserver uses for fault // EtcdClient is the etcd client that the pserver uses for fault
// tolerance, service registry and coordination. // tolerance, service registry and coordination.
type EtcdClient struct { type EtcdClient struct {
numPservers int numPservers int
etcdEndpoints string endpoints string
etcdClient *clientv3.Client client *clientv3.Client
// etcdTimeout is also used as retry intervals. sess *concurrency.Session
etcdTimeout time.Duration dialTimeout time.Duration
ttlSec int
// FIXME: ensure GetExternalIP gets the correct ip for trainers to connect. // FIXME: ensure GetExternalIP gets the correct ip for trainers to connect.
externalIP string externalIP string
// desired number of pservers in the job. // desired number of pservers in the job.
...@@ -52,11 +55,12 @@ type EtcdClient struct { ...@@ -52,11 +55,12 @@ type EtcdClient struct {
} }
// NewEtcdClient creates an EtcdClient // NewEtcdClient creates an EtcdClient
func NewEtcdClient(endpoints string, numPservers int, timeout time.Duration) *EtcdClient { func NewEtcdClient(endpoints string, numPservers int, dialtimeout time.Duration, ttlSec int) *EtcdClient {
return &EtcdClient{ return &EtcdClient{
etcdTimeout: timeout, dialTimeout: dialtimeout,
ttlSec: ttlSec,
numPservers: numPservers, numPservers: numPservers,
etcdEndpoints: endpoints, endpoints: endpoints,
} }
} }
...@@ -64,7 +68,6 @@ func NewEtcdClient(endpoints string, numPservers int, timeout time.Duration) *Et ...@@ -64,7 +68,6 @@ func NewEtcdClient(endpoints string, numPservers int, timeout time.Duration) *Et
// //
// Register returns the index of the current pserver. // Register returns the index of the current pserver.
func (e *EtcdClient) Register(port int) (int, error) { func (e *EtcdClient) Register(port int) (int, error) {
var err error var err error
e.externalIP, err = networkhelper.GetExternalIP() e.externalIP, err = networkhelper.GetExternalIP()
if err != nil { if err != nil {
...@@ -72,19 +75,26 @@ func (e *EtcdClient) Register(port int) (int, error) { ...@@ -72,19 +75,26 @@ func (e *EtcdClient) Register(port int) (int, error) {
} }
// initialize connection to etcd. // initialize connection to etcd.
ep := strings.Split(e.etcdEndpoints, ",") ep := strings.Split(e.endpoints, ",")
for { for {
cli, err := clientv3.New(clientv3.Config{ cli, err := clientv3.New(clientv3.Config{
Endpoints: ep, Endpoints: ep,
DialTimeout: e.etcdTimeout, DialTimeout: e.dialTimeout,
}) })
if err != nil { if err != nil {
log.Errorf("connect to etcd error: %v", err) log.Errorf("connect to etcd error: %v", err)
time.Sleep(e.etcdTimeout) time.Sleep(retryTimeout)
continue
}
e.client = cli
sess, err := concurrency.NewSession(cli, concurrency.WithTTL(e.ttlSec))
if err != nil {
log.Errorf("create etcd session error: %v", err)
time.Sleep(retryTimeout)
continue continue
} }
e.etcdClient = cli e.sess = sess
log.Debugf("inited client to %s", e.etcdEndpoints) log.Debugf("inited client to %s", e.endpoints)
break break
} }
// init /ps_desired using transaction, for multiple pservers may want to write // init /ps_desired using transaction, for multiple pservers may want to write
...@@ -95,7 +105,7 @@ func (e *EtcdClient) Register(port int) (int, error) { ...@@ -95,7 +105,7 @@ func (e *EtcdClient) Register(port int) (int, error) {
cancel() cancel()
if err != nil { if err != nil {
log.Warn(err) log.Warn(err)
time.Sleep(e.etcdTimeout) time.Sleep(retryTimeout)
continue continue
} }
break break
...@@ -106,18 +116,18 @@ func (e *EtcdClient) Register(port int) (int, error) { ...@@ -106,18 +116,18 @@ func (e *EtcdClient) Register(port int) (int, error) {
// wait and set s.desired init value // wait and set s.desired init value
for { for {
ctx, cancel := context.WithTimeout(context.Background(), time.Second) ctx, cancel := context.WithTimeout(context.Background(), time.Second)
resp, err := e.etcdClient.Get(ctx, PsDesired) resp, err := e.client.Get(ctx, PsDesired)
cancel() cancel()
if err != nil { if err != nil {
log.Errorf("getting %s error: %v", PsDesired, err) log.Errorf("getting %s error: %v", PsDesired, err)
time.Sleep(e.etcdTimeout) time.Sleep(retryTimeout)
continue continue
} }
if len(resp.Kvs) != 0 { if len(resp.Kvs) != 0 {
e.desired, err = strconv.Atoi(string(resp.Kvs[0].Value)) e.desired, err = strconv.Atoi(string(resp.Kvs[0].Value))
if err != nil { if err != nil {
log.Errorf("value of %s invalid %v\n", PsDesired, err) log.Errorf("value of %s invalid %v\n", PsDesired, err)
time.Sleep(e.etcdTimeout) time.Sleep(retryTimeout)
// NOTE: wait util ps_desired value change // NOTE: wait util ps_desired value change
continue continue
} }
...@@ -134,7 +144,7 @@ func (e *EtcdClient) Register(port int) (int, error) { ...@@ -134,7 +144,7 @@ func (e *EtcdClient) Register(port int) (int, error) {
cancel() cancel()
if err != nil { if err != nil {
log.Warn(err) log.Warn(err)
time.Sleep(e.etcdTimeout) time.Sleep(retryTimeout)
continue continue
} }
break break
...@@ -144,10 +154,10 @@ func (e *EtcdClient) Register(port int) (int, error) { ...@@ -144,10 +154,10 @@ func (e *EtcdClient) Register(port int) (int, error) {
} }
func (e *EtcdClient) initDesiredPservers(ctx context.Context, numPservers int) (*clientv3.TxnResponse, error) { func (e *EtcdClient) initDesiredPservers(ctx context.Context, numPservers int) (*clientv3.TxnResponse, error) {
return concurrency.NewSTM(e.etcdClient, func(c concurrency.STM) error { return concurrency.NewSTM(e.client, func(c concurrency.STM) error {
dsStr := c.Get(PsDesired) dsStr := c.Get(PsDesired)
if dsStr == "" { if dsStr == "" {
c.Put(PsDesired, strconv.Itoa(numPservers)) c.Put(PsDesired, strconv.Itoa(numPservers), clientv3.WithLease(e.sess.Lease()))
} }
return nil return nil
}, concurrency.WithAbortContext(ctx), concurrency.WithIsolation(concurrency.RepeatableReads)) }, concurrency.WithAbortContext(ctx), concurrency.WithIsolation(concurrency.RepeatableReads))
...@@ -156,7 +166,7 @@ func (e *EtcdClient) initDesiredPservers(ctx context.Context, numPservers int) ( ...@@ -156,7 +166,7 @@ func (e *EtcdClient) initDesiredPservers(ctx context.Context, numPservers int) (
// registerPserverEtcd registers pserver node on etcd using transaction. // registerPserverEtcd registers pserver node on etcd using transaction.
func (e *EtcdClient) registerPserverEtcd(ctx context.Context, port int) (int, error) { func (e *EtcdClient) registerPserverEtcd(ctx context.Context, port int) (int, error) {
var idx int var idx int
_, err := concurrency.NewSTM(e.etcdClient, func(c concurrency.STM) error { _, err := concurrency.NewSTM(e.client, func(c concurrency.STM) error {
registered := false registered := false
for i := 0; i < e.desired; i++ { for i := 0; i < e.desired; i++ {
psKey := PsPath + strconv.Itoa(i) psKey := PsPath + strconv.Itoa(i)
...@@ -165,26 +175,10 @@ func (e *EtcdClient) registerPserverEtcd(ctx context.Context, port int) (int, er ...@@ -165,26 +175,10 @@ func (e *EtcdClient) registerPserverEtcd(ctx context.Context, port int) (int, er
log.Debugf("got value (%s) for key: %s", ps, psKey) log.Debugf("got value (%s) for key: %s", ps, psKey)
if ps == "" { if ps == "" {
resp, err := e.etcdClient.Grant(context.TODO(), 5)
if err != nil {
log.Fatal(err)
}
// find the first id and write info // find the first id and write info
pserverAddr := e.externalIP + ":" + strconv.Itoa(port) pserverAddr := e.externalIP + ":" + strconv.Itoa(port)
c.Put(psKey, pserverAddr, clientv3.WithLease(resp.ID)) c.Put(psKey, pserverAddr, clientv3.WithLease(e.sess.Lease()))
log.Debugf("set pserver node %s with value %s", psKey, pserverAddr) log.Debugf("set pserver node %s with value %s", psKey, pserverAddr)
ch, kaerr := e.etcdClient.KeepAlive(context.TODO(), resp.ID)
if kaerr != nil {
log.Errorf("keepalive etcd node error: %v", kaerr)
return kaerr
}
// Eat the keep alive message so etcd
// will not expire the lease.
go func(ch <-chan *clientv3.LeaseKeepAliveResponse) {
ka := <-ch
log.Debugf("keepalive: %d\n", ka.TTL)
}(ch)
log.Debug("register finished") log.Debug("register finished")
idx = i idx = i
registered = true registered = true
...@@ -207,7 +201,7 @@ func (e *EtcdClient) registerPserverEtcd(ctx context.Context, port int) (int, er ...@@ -207,7 +201,7 @@ func (e *EtcdClient) registerPserverEtcd(ctx context.Context, port int) (int, er
// GetKey gets the value by the specified key // GetKey gets the value by the specified key
func (e *EtcdClient) GetKey(key string, timeout time.Duration) ([]byte, error) { func (e *EtcdClient) GetKey(key string, timeout time.Duration) ([]byte, error) {
ctx, cancel := context.WithTimeout(context.Background(), timeout) ctx, cancel := context.WithTimeout(context.Background(), timeout)
resp, err := e.etcdClient.Get(ctx, key) resp, err := e.client.Get(ctx, key)
cancel() cancel()
if err != nil { if err != nil {
return []byte{}, err return []byte{}, err
...@@ -223,7 +217,27 @@ func (e *EtcdClient) GetKey(key string, timeout time.Duration) ([]byte, error) { ...@@ -223,7 +217,27 @@ func (e *EtcdClient) GetKey(key string, timeout time.Duration) ([]byte, error) {
// PutKey put into etcd with value by key specified // PutKey put into etcd with value by key specified
func (e *EtcdClient) PutKey(key string, value []byte, timeout time.Duration) error { func (e *EtcdClient) PutKey(key string, value []byte, timeout time.Duration) error {
ctx, cancel := context.WithTimeout(context.Background(), timeout) ctx, cancel := context.WithTimeout(context.Background(), timeout)
_, err := e.etcdClient.Put(ctx, key, string(value)) _, err := e.client.Put(ctx, key, string(value), clientv3.WithLease(e.sess.Lease()))
cancel() cancel()
return err return err
} }
// Shutdown shuts down the etcd client gracefully.
func (e *EtcdClient) Shutdown() error {
var err error
if e.sess != nil {
err = e.sess.Close()
}
if e.client != nil {
newErr := e.client.Close()
if newErr != nil {
if err != nil {
log.Errorln(newErr)
} else {
err = newErr
}
}
}
return err
}
...@@ -37,7 +37,7 @@ std::vector<std::string> Evaluator::getNames() const { ...@@ -37,7 +37,7 @@ std::vector<std::string> Evaluator::getNames() const {
double Evaluator::getValue(const std::string name) const { double Evaluator::getValue(const std::string name) const {
paddle::Error err; paddle::Error err;
double v = m->rawPtr->getValue(name, &err); double v = m->rawPtr->getValue(name, &err);
if (err) { if (!err.isOK()) {
throw std::runtime_error(err.msg()); throw std::runtime_error(err.msg());
} }
return v; return v;
......
...@@ -3,7 +3,7 @@ cc_library(ddim SRCS ddim.cc DEPS eigen3) ...@@ -3,7 +3,7 @@ cc_library(ddim SRCS ddim.cc DEPS eigen3)
cc_test(ddim_test SRCS ddim_test.cc DEPS ddim) cc_test(ddim_test SRCS ddim_test.cc DEPS ddim)
nv_test(dim_test SRCS dim_test.cu DEPS ddim) nv_test(dim_test SRCS dim_test.cu DEPS ddim)
cc_library(tensor SRCS tensor.cc DEPS ddim place paddle_memory) cc_library(tensor SRCS tensor.cc DEPS ddim place paddle_memory device_context)
cc_test(tensor_test SRCS tensor_test.cc DEPS tensor) cc_test(tensor_test SRCS tensor_test.cc DEPS tensor)
cc_test(eigen_test SRCS eigen_test.cc DEPS tensor) cc_test(eigen_test SRCS eigen_test.cc DEPS tensor)
...@@ -29,7 +29,5 @@ py_proto_compile(framework_py_proto SRCS attr_type.proto op_proto.proto op_desc. ...@@ -29,7 +29,5 @@ py_proto_compile(framework_py_proto SRCS attr_type.proto op_proto.proto op_desc.
add_custom_target(framework_py_proto_init ALL COMMAND ${CMAKE_COMMAND} -E touch __init__.py) add_custom_target(framework_py_proto_init ALL COMMAND ${CMAKE_COMMAND} -E touch __init__.py)
add_dependencies(framework_py_proto framework_py_proto_init) add_dependencies(framework_py_proto framework_py_proto_init)
proto_library(net_proto SRCS net_proto.proto DEPS op_proto) cc_library(net SRCS net.cc DEPS op_registry)
# cc_library(net SRCS net.cc DEPS operator net_proto op_registry fc_op)
cc_library(net SRCS net.cc DEPS operator net_proto op_registry)
cc_test(net_op_test SRCS net_op_test.cc DEPS net add_op mul_op sigmoid_op softmax_op fc_op) cc_test(net_op_test SRCS net_op_test.cc DEPS net add_op mul_op sigmoid_op softmax_op fc_op)
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/memory/memcpy.h"
namespace paddle {
namespace framework {
template <typename T>
inline void Tensor::check_memory_size() const {
PADDLE_ENFORCE(holder_ != nullptr,
"Tenosr holds no memory. Call Tensor::mutable_data first.");
PADDLE_ENFORCE(holder_->size() >= product(dims_) * sizeof(T) + offset_,
"Tensor's dims_ is out of bound. Call Tensor::mutable_data "
"first to re-allocate memory.");
}
template <typename T>
inline const T* Tensor::data() const {
check_memory_size<T>();
return reinterpret_cast<const T*>(
reinterpret_cast<uintptr_t>(holder_->ptr()) + offset_);
}
template <typename T>
inline T* Tensor::data() {
check_memory_size<T>();
return reinterpret_cast<T*>(reinterpret_cast<uintptr_t>(holder_->ptr()) +
offset_);
}
template <typename T>
inline T* Tensor::mutable_data(DDim dims, platform::Place place) {
static_assert(std::is_pod<T>::value, "T must be POD");
Resize(dims);
return mutable_data<T>(place);
}
template <typename T>
inline T* Tensor::mutable_data(platform::Place place) {
static_assert(std::is_pod<T>::value, "T must be POD");
PADDLE_ENFORCE(product(dims_) > 0,
"Tensor's numel must be larger than zero to call "
"Tensor::mutable_data. Call Tensor::set_dim first.");
/* some versions of boost::variant don't have operator!= */
size_t size = product(dims_) * sizeof(T);
if (holder_ == nullptr || !(holder_->place() == place) ||
holder_->size() < size + offset_) {
if (platform::is_cpu_place(place)) {
holder_.reset(new PlaceholderImpl<T, platform::CPUPlace>(
boost::get<platform::CPUPlace>(place), size));
}
#ifndef PADDLE_ONLY_CPU
else if (platform::is_gpu_place(place)) {
holder_.reset(new PlaceholderImpl<T, platform::GPUPlace>(
boost::get<platform::GPUPlace>(place), size));
}
#endif
offset_ = 0;
}
return reinterpret_cast<T*>(reinterpret_cast<uintptr_t>(holder_->ptr()) +
offset_);
}
template <typename T>
inline void Tensor::ShareDataWith(const Tensor& src) {
src.check_memory_size<T>();
*this = src;
}
template <typename T>
inline void Tensor::CopyFrom(const Tensor& src,
const platform::Place& dst_place) {
src.check_memory_size<T>();
Resize(src.dims());
auto src_place = src.holder_->place();
auto src_ptr = static_cast<const void*>(src.data<T>());
auto dst_ptr = static_cast<void*>(mutable_data<T>(dst_place));
auto size = product(src.dims_) * sizeof(T);
if (platform::is_cpu_place(src_place) && platform::is_cpu_place(dst_place)) {
memory::Copy(boost::get<platform::CPUPlace>(dst_place), dst_ptr,
boost::get<platform::CPUPlace>(src_place), src_ptr, size);
}
#ifndef PADDLE_ONLY_CPU
else if (platform::is_gpu_place(src_place) &&
platform::is_cpu_place(dst_place)) {
memory::Copy(boost::get<platform::CPUPlace>(dst_place), dst_ptr,
boost::get<platform::GPUPlace>(src_place), src_ptr, size, 0);
} else if (platform::is_cpu_place(src_place) &&
platform::is_gpu_place(dst_place)) {
memory::Copy(boost::get<platform::GPUPlace>(dst_place), dst_ptr,
boost::get<platform::CPUPlace>(src_place), src_ptr, size, 0);
} else if (platform::is_gpu_place(src_place) &&
platform::is_gpu_place(dst_place)) {
memory::Copy(boost::get<platform::GPUPlace>(dst_place), dst_ptr,
boost::get<platform::GPUPlace>(src_place), src_ptr, size, 0);
}
#endif
}
template <typename T>
inline Tensor Tensor::Slice(const int& begin_idx, const int& end_idx) const {
check_memory_size<T>();
PADDLE_ENFORCE(begin_idx >= 0, "Slice begin index is less than zero.");
PADDLE_ENFORCE(end_idx <= dims_[0], "Slice end index is out of bound.");
PADDLE_ENFORCE(begin_idx < end_idx,
"Begin index must be less than end index.");
PADDLE_ENFORCE(dims_[0] != 1, "Can not slice a tensor with dims_[0] = 1.");
int base = product(dims_) / dims_[0];
Tensor dst;
dst.holder_ = holder_;
DDim dst_dims = dims_;
dst_dims[0] = end_idx - begin_idx;
dst.Resize(dst_dims);
dst.offset_ = offset_ + begin_idx * base * sizeof(T);
return dst;
}
inline void Tensor::Resize(const DDim& dims) { dims_ = dims; }
inline const DDim& Tensor::dims() const { return dims_; }
} // namespace framework
} // namespace paddle
...@@ -20,17 +20,7 @@ ...@@ -20,17 +20,7 @@
namespace paddle { namespace paddle {
namespace framework { namespace framework {
std::shared_ptr<PlainNet> AddBackwardOp(std::shared_ptr<PlainNet> ForwardOps) { void NetOp::CompleteAddOp(bool calc) {
auto grad_ops = std::make_shared<PlainNet>();
for (auto& op : ForwardOps->ops_) {
auto op_grad = OpRegistry::CreateGradOp(op);
grad_ops->AddOp(op_grad);
}
grad_ops->CompleteAddOp();
return grad_ops;
}
void PlainNet::CompleteAddOp(bool calc) {
add_op_done_ = true; add_op_done_ = true;
if (!calc) return; if (!calc) return;
std::unordered_set<std::string> input_set; std::unordered_set<std::string> input_set;
...@@ -70,7 +60,7 @@ void PlainNet::CompleteAddOp(bool calc) { ...@@ -70,7 +60,7 @@ void PlainNet::CompleteAddOp(bool calc) {
attrs_["temporary_index"] = tmp_index; attrs_["temporary_index"] = tmp_index;
} }
std::string PlainNet::DebugString() const { std::string NetOp::DebugString() const {
std::ostringstream os; std::ostringstream os;
os << OperatorBase::DebugString() << std::endl; os << OperatorBase::DebugString() << std::endl;
for (auto& op : ops_) { for (auto& op : ops_) {
...@@ -82,5 +72,7 @@ std::string PlainNet::DebugString() const { ...@@ -82,5 +72,7 @@ std::string PlainNet::DebugString() const {
return os.str(); return os.str();
} }
bool NetOp::IsNetOp() const { return true; }
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
...@@ -37,21 +37,7 @@ namespace framework { ...@@ -37,21 +37,7 @@ namespace framework {
* This is the base class of network, all the networks should implement the APIs * This is the base class of network, all the networks should implement the APIs
* it defines. * it defines.
*/ */
class Net : public OperatorBase { class NetOp : public OperatorBase {
public:
virtual void AddOp(const std::shared_ptr<OperatorBase>& op) = 0;
virtual void CompleteAddOp(bool calc) = 0;
};
using NetPtr = std::shared_ptr<Net>;
/**
* @brief a basic implementation of Net.
*
* PlainNet is a very simple Net, it create a list of operators, and run them
* sequentially following the order they added.
*/
class PlainNet : public Net {
public: public:
/** /**
* Infer all the operators' input and output variables' shapes, will be called * Infer all the operators' input and output variables' shapes, will be called
...@@ -80,15 +66,17 @@ class PlainNet : public Net { ...@@ -80,15 +66,17 @@ class PlainNet : public Net {
/** /**
* @brief Add an operator by ptr * @brief Add an operator by ptr
*/ */
void AddOp(const std::shared_ptr<OperatorBase>& op) override { void AddOp(const std::shared_ptr<OperatorBase>& op) {
PADDLE_ENFORCE(!add_op_done_, "Cannot AddOp when this network is sealed"); PADDLE_ENFORCE(!add_op_done_, "Cannot AddOp when this network is sealed");
ops_.push_back(op); ops_.push_back(op);
} }
void CompleteAddOp(bool calculate = true) override; void CompleteAddOp(bool calculate = true);
std::string DebugString() const override; std::string DebugString() const override;
bool IsNetOp() const override;
std::vector<std::shared_ptr<OperatorBase>> ops_; std::vector<std::shared_ptr<OperatorBase>> ops_;
private: private:
...@@ -100,7 +88,5 @@ class PlainNet : public Net { ...@@ -100,7 +88,5 @@ class PlainNet : public Net {
} }
}; };
std::shared_ptr<PlainNet> AddBackwardOp(std::shared_ptr<PlainNet> ForwardOps);
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
...@@ -40,7 +40,7 @@ void AssertSameVectorWithoutOrder(const std::vector<T>& expected, ...@@ -40,7 +40,7 @@ void AssertSameVectorWithoutOrder(const std::vector<T>& expected,
} }
TEST(OpKernel, all) { TEST(OpKernel, all) {
auto net = std::make_shared<PlainNet>(); auto net = std::make_shared<NetOp>();
ASSERT_NE(net, nullptr); ASSERT_NE(net, nullptr);
auto op1 = std::make_shared<TestOp>(); auto op1 = std::make_shared<TestOp>();
...@@ -69,30 +69,23 @@ TEST(OpKernel, all) { ...@@ -69,30 +69,23 @@ TEST(OpKernel, all) {
net->Run(scope, dev_ctx); net->Run(scope, dev_ctx);
ASSERT_EQ(2, infer_shape_cnt); ASSERT_EQ(2, infer_shape_cnt);
ASSERT_EQ(2, run_cnt); ASSERT_EQ(2, run_cnt);
ASSERT_THROW(net->AddOp(op2), std::runtime_error); ASSERT_THROW(net->AddOp(op2), paddle::platform::EnforceNotMet);
}
TEST(AddBackwardOp, TestGradOp) {
auto net = std::make_shared<PlainNet>();
ASSERT_NE(net, nullptr);
net->AddOp(framework::OpRegistry::CreateOp("mul", {"X", "Y"}, {"Out"}, {}));
net->AddOp(
framework::OpRegistry::CreateOp("add_two", {"X", "Y"}, {"Out"}, {}));
net->AddOp(framework::OpRegistry::CreateOp("add_two", {"X", "Y"}, {""}, {}));
auto grad_ops = AddBackwardOp(net);
for (auto& op : grad_ops->ops_) {
op->DebugString();
}
} }
// TODO(zhihong): add fc grad without registering. //! TODO(yuyang18): Refine Backward Op.
// TEST(AddBackwardOp, TestNoGradOp) { // TEST(AddBackwardOp, TestGradOp) {
// auto net = std::make_shared<PlainNet>(); // auto net = std::make_shared<NetOp>();
// ASSERT_NE(net, nullptr); // ASSERT_NE(net, nullptr);
// net->AddOp(framework::OpRegistry::CreateOp("fc", {"X", "W", "b"}, {"Y"}, // net->AddOp(framework::OpRegistry::CreateOp("mul", {"X", "Y"}, {"Out"}, {}));
// {})); auto grad_ops = AddBackwardOp(net); for (auto& op : grad_ops->ops_) { // net->AddOp(
// framework::OpRegistry::CreateOp("add_two", {"X", "Y"}, {"Out"}, {}));
// net->AddOp(framework::OpRegistry::CreateOp("add_two", {"X", "Y"}, {""},
// {}));
// auto grad_ops = AddBackwardOp(net);
// for (auto& op : grad_ops->ops_) {
// op->DebugString(); // op->DebugString();
// } // }
// } //}
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
syntax="proto2";
package paddle.framework;
import "op_proto.proto";
message NetDesc {
// network identification
optional string name = 1;
// operator contains in network
repeated OpProto operators = 2;
// network type to run with. e.g "plainNet", "DAG"
optional string net_type = 3;
// num worker always
optional int32 num_workers = 4;
}
...@@ -90,7 +90,7 @@ TEST(OpRegistry, IllegalAttr) { ...@@ -90,7 +90,7 @@ TEST(OpRegistry, IllegalAttr) {
bool caught = false; bool caught = false;
try { try {
paddle::framework::OpRegistry::CreateOp(op_desc); paddle::framework::OpRegistry::CreateOp(op_desc);
} catch (std::runtime_error& err) { } catch (paddle::platform::EnforceNotMet err) {
caught = true; caught = true;
std::string msg = "larger_than check fail"; std::string msg = "larger_than check fail";
const char* err_msg = err.what(); const char* err_msg = err.what();
...@@ -136,7 +136,7 @@ TEST(OpRegistry, CustomChecker) { ...@@ -136,7 +136,7 @@ TEST(OpRegistry, CustomChecker) {
bool caught = false; bool caught = false;
try { try {
paddle::framework::OpRegistry::CreateOp(op_desc); paddle::framework::OpRegistry::CreateOp(op_desc);
} catch (std::runtime_error& err) { } catch (paddle::platform::EnforceNotMet err) {
caught = true; caught = true;
std::string msg = "Attribute 'test_attr' is required!"; std::string msg = "Attribute 'test_attr' is required!";
const char* err_msg = err.what(); const char* err_msg = err.what();
...@@ -154,7 +154,7 @@ TEST(OpRegistry, CustomChecker) { ...@@ -154,7 +154,7 @@ TEST(OpRegistry, CustomChecker) {
caught = false; caught = false;
try { try {
paddle::framework::OpRegistry::CreateOp(op_desc); paddle::framework::OpRegistry::CreateOp(op_desc);
} catch (std::runtime_error& err) { } catch (paddle::platform::EnforceNotMet err) {
caught = true; caught = true;
std::string msg = "'test_attr' must be even!"; std::string msg = "'test_attr' must be even!";
const char* err_msg = err.what(); const char* err_msg = err.what();
...@@ -192,7 +192,7 @@ TEST(ProtoMaker, DuplicatedAttr) { ...@@ -192,7 +192,7 @@ TEST(ProtoMaker, DuplicatedAttr) {
pd::OpProto op_proto; pd::OpProto op_proto;
pd::OpAttrChecker op_checker; pd::OpAttrChecker op_checker;
auto proto_maker = TestAttrProtoMaker(&op_proto, &op_checker); auto proto_maker = TestAttrProtoMaker(&op_proto, &op_checker);
ASSERT_THROW(proto_maker.Validate(), std::runtime_error); ASSERT_THROW(proto_maker.Validate(), paddle::platform::EnforceNotMet);
} }
class TestInOutProtoMaker : public pd::OpProtoAndCheckerMaker { class TestInOutProtoMaker : public pd::OpProtoAndCheckerMaker {
...@@ -208,5 +208,5 @@ TEST(ProtoMaker, DuplicatedInOut) { ...@@ -208,5 +208,5 @@ TEST(ProtoMaker, DuplicatedInOut) {
pd::OpProto op_proto; pd::OpProto op_proto;
pd::OpAttrChecker op_checker; pd::OpAttrChecker op_checker;
auto proto_maker = TestInOutProtoMaker(&op_proto, &op_checker); auto proto_maker = TestInOutProtoMaker(&op_proto, &op_checker);
ASSERT_THROW(proto_maker.Validate(), std::runtime_error); ASSERT_THROW(proto_maker.Validate(), paddle::platform::EnforceNotMet);
} }
...@@ -34,22 +34,26 @@ KernelContext::GetEigenDevice<platform::GPUPlace, Eigen::GpuDevice>() const { ...@@ -34,22 +34,26 @@ KernelContext::GetEigenDevice<platform::GPUPlace, Eigen::GpuDevice>() const {
#endif #endif
const std::string& OperatorBase::Input(const std::string& name) const { const std::string& OperatorBase::Input(const std::string& name) const {
PADDLE_ENFORCE(in_out_idxs_ != nullptr,
"Input Output Indices could not be nullptr");
auto it = in_out_idxs_->find(name); auto it = in_out_idxs_->find(name);
PADDLE_ENFORCE(it != in_out_idxs_->end(), "no key [%s] in in_out_idxs_", PADDLE_ENFORCE(it != in_out_idxs_->end(), "no key [%s] in in_out_idxs_",
name); name);
if (attrs_.count("input_format") == 0) { if (attrs_.count("input_format") == 0) {
return inputs_[it->second]; return inputs_.at((size_t)it->second);
} else { } else {
const auto& input_format = GetAttr<std::vector<int>>("input_format"); const auto& input_format = GetAttr<std::vector<int>>("input_format");
int idx = input_format[it->second]; int idx = input_format[it->second];
return inputs_.at(idx); return inputs_.at((size_t)idx);
} }
} }
std::vector<std::string> OperatorBase::Inputs(const std::string& name) const { std::vector<std::string> OperatorBase::Inputs(const std::string& name) const {
PADDLE_ENFORCE(in_out_idxs_ != nullptr, "IO Idx could not be nullptr");
auto input_format = GetAttr<std::vector<int>>("input_format"); auto input_format = GetAttr<std::vector<int>>("input_format");
auto offset = in_out_idxs_->at(name); auto offset = in_out_idxs_->at(name);
PADDLE_ENFORCE(input_format.at((size_t)offset + 1) <= inputs_.size(),
"Input Out Of Range");
return std::vector<std::string>{ return std::vector<std::string>{
inputs_.begin() + input_format.at(offset), inputs_.begin() + input_format.at(offset),
...@@ -57,23 +61,25 @@ std::vector<std::string> OperatorBase::Inputs(const std::string& name) const { ...@@ -57,23 +61,25 @@ std::vector<std::string> OperatorBase::Inputs(const std::string& name) const {
} }
const std::string& OperatorBase::Output(const std::string& name) const { const std::string& OperatorBase::Output(const std::string& name) const {
PADDLE_ENFORCE(in_out_idxs_ != nullptr, "InOut Indice could not be nullptr");
auto it = in_out_idxs_->find(name); auto it = in_out_idxs_->find(name);
PADDLE_ENFORCE(it != in_out_idxs_->end(), "no key [%s] in in_out_idxs_", PADDLE_ENFORCE(it != in_out_idxs_->end(), "no key [%s] in in_out_idxs_",
name); name);
if (attrs_.count("output_format") == 0) { if (attrs_.count("output_format") == 0) {
return outputs_[it->second]; return outputs_.at((size_t)it->second);
} else { } else {
const auto& output_format = GetAttr<std::vector<int>>("output_format"); const auto& output_format = GetAttr<std::vector<int>>("output_format");
int idx = output_format[it->second]; int idx = output_format[it->second];
return outputs_.at(idx); return outputs_.at((size_t)idx);
} }
} }
std::vector<std::string> OperatorBase::Outputs(const std::string& name) const { std::vector<std::string> OperatorBase::Outputs(const std::string& name) const {
PADDLE_ENFORCE(in_out_idxs_ != nullptr, "InOut Indice could not be nullptr");
auto output_format = GetAttr<std::vector<int>>("output_format"); auto output_format = GetAttr<std::vector<int>>("output_format");
auto offset = in_out_idxs_->at(name); auto offset = in_out_idxs_->at(name);
PADDLE_ENFORCE(output_format.at((size_t)offset + 1) <= outputs_.size(),
"Output Out of Range");
return std::vector<std::string>{ return std::vector<std::string>{
outputs_.begin() + output_format.at(offset), outputs_.begin() + output_format.at(offset),
outputs_.begin() + output_format.at(offset + 1)}; outputs_.begin() + output_format.at(offset + 1)};
......
...@@ -90,15 +90,17 @@ class OperatorBase { ...@@ -90,15 +90,17 @@ class OperatorBase {
virtual void Run(const std::shared_ptr<Scope>& scope, virtual void Run(const std::shared_ptr<Scope>& scope,
const platform::DeviceContext& dev_ctx) const = 0; const platform::DeviceContext& dev_ctx) const = 0;
// Get a input with argument's name described in `op_proto` virtual bool IsNetOp() const { return false; }
//! Get a input with argument's name described in `op_proto`
const std::string& Input(const std::string& name) const; const std::string& Input(const std::string& name) const;
// Get a input which has multiple variables. //! Get a input which has multiple variables.
// TODO add a vector_view to prevent memory copy. //! TODO add a vector_view to prevent memory copy.
std::vector<std::string> Inputs(const std::string& name) const; std::vector<std::string> Inputs(const std::string& name) const;
// Get a output with argument's name described in `op_proto` //! Get a output with argument's name described in `op_proto`
const std::string& Output(const std::string& name) const; const std::string& Output(const std::string& name) const;
// Get an output which has multiple variables. //! Get an output which has multiple variables.
// TODO add a vector_view to prevent memory copy. //! TODO add a vector_view to prevent memory copy.
std::vector<std::string> Outputs(const std::string& name) const; std::vector<std::string> Outputs(const std::string& name) const;
public: public:
...@@ -199,8 +201,6 @@ class OperatorWithKernel : public OperatorBase { ...@@ -199,8 +201,6 @@ class OperatorWithKernel : public OperatorBase {
place_ = dev_ctx.GetPlace(); place_ = dev_ctx.GetPlace();
} }
// bool operator==(const OpKernelKey& o) const { return place_ == o.place_;
// }
bool operator==(const OpKernelKey& o) const { bool operator==(const OpKernelKey& o) const {
return platform::places_are_same_class(place_, o.place_); return platform::places_are_same_class(place_, o.place_);
} }
......
...@@ -56,7 +56,9 @@ class Scope { ...@@ -56,7 +56,9 @@ class Scope {
if (var) { if (var) {
return var; return var;
} else { } else {
vars_[name] = std::unique_ptr<Variable>(new Variable()); auto ptr = new Variable();
name_to_var_[name] = std::unique_ptr<Variable>(ptr);
var_to_name_[ptr] = name;
return GetVariable(name); return GetVariable(name);
} }
} }
...@@ -68,8 +70,8 @@ class Scope { ...@@ -68,8 +70,8 @@ class Scope {
* from it's parent scope. Return nullptr if not found. * from it's parent scope. Return nullptr if not found.
*/ */
Variable* GetVariable(const std::string& name) const { Variable* GetVariable(const std::string& name) const {
auto it = vars_.find(name); auto it = name_to_var_.find(name);
if (it != vars_.end()) { if (it != name_to_var_.end()) {
return it->second.get(); return it->second.get();
} else if (parent_ != nullptr) { } else if (parent_ != nullptr) {
return parent_->GetVariable(name); return parent_->GetVariable(name);
...@@ -84,12 +86,21 @@ class Scope { ...@@ -84,12 +86,21 @@ class Scope {
* Find if there is a Variable in this scope and it's parent scope * Find if there is a Variable in this scope and it's parent scope
*/ */
bool HasVariable(const std::string& name) const { bool HasVariable(const std::string& name) const {
return (vars_.find(name) != vars_.end() || return (name_to_var_.find(name) != name_to_var_.end() ||
(parent_ && parent_->HasVariable(name))); (parent_ && parent_->HasVariable(name)));
} }
std::string GetVariableName(Variable* const var) const {
try {
return var_to_name_.at(var);
} catch (...) {
return "";
}
}
private: private:
std::unordered_map<std::string, std::unique_ptr<Variable>> vars_; std::unordered_map<Variable*, std::string> var_to_name_;
std::unordered_map<std::string, std::unique_ptr<Variable>> name_to_var_;
std::shared_ptr<Scope> parent_{nullptr}; std::shared_ptr<Scope> parent_{nullptr};
}; };
......
...@@ -40,6 +40,11 @@ TEST(Scope, Create) { ...@@ -40,6 +40,11 @@ TEST(Scope, Create) {
/// already exist. /// already exist.
Variable* var4 = scope->CreateVariable("a"); Variable* var4 = scope->CreateVariable("a");
EXPECT_EQ(var4, var2); EXPECT_EQ(var4, var2);
EXPECT_EQ("a", scope->GetVariableName(var4));
Scope scope2;
auto var = scope2.CreateVariable("tmp");
EXPECT_EQ("", scope->GetVariableName(var));
} }
TEST(Scope, Parent) { TEST(Scope, Parent) {
......
...@@ -12,7 +12,7 @@ ...@@ -12,7 +12,7 @@
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include <paddle/framework/tensor.h> #include "paddle/framework/tensor.h"
namespace paddle { namespace paddle {
namespace framework {} namespace framework {}
......
...@@ -21,6 +21,7 @@ limitations under the License. */ ...@@ -21,6 +21,7 @@ limitations under the License. */
#include "paddle/framework/ddim.h" #include "paddle/framework/ddim.h"
#include "paddle/memory/memcpy.h" #include "paddle/memory/memcpy.h"
#include "paddle/memory/memory.h" #include "paddle/memory/memory.h"
#include "paddle/platform/device_context.h"
#include "paddle/platform/enforce.h" #include "paddle/platform/enforce.h"
#include "paddle/platform/place.h" #include "paddle/platform/place.h"
#include "unsupported/Eigen/CXX11/Tensor" #include "unsupported/Eigen/CXX11/Tensor"
...@@ -32,9 +33,11 @@ template <bool less, size_t i, typename... args> ...@@ -32,9 +33,11 @@ template <bool less, size_t i, typename... args>
struct CastToPyBufferImpl; struct CastToPyBufferImpl;
} // namespace details } // namespace details
} // namespace pybind } // namespace pybind
namespace framework { namespace framework {
class Tensor { class Tensor {
public:
template <bool less, size_t i, typename... args> template <bool less, size_t i, typename... args>
friend struct paddle::pybind::details::CastToPyBufferImpl; friend struct paddle::pybind::details::CastToPyBufferImpl;
...@@ -47,151 +50,123 @@ class Tensor { ...@@ -47,151 +50,123 @@ class Tensor {
public: public:
Tensor() : offset_(0) {} Tensor() : offset_(0) {}
/*! Return a pointer to mutable memory block. */
template <typename T> template <typename T>
const T* data() const { inline T* data();
EnforceSufficientMemory<T>();
return reinterpret_cast<const T*>(
reinterpret_cast<uintptr_t>(holder_->ptr()) + offset_);
}
/*! Return a pointer to constant memory block. */
template <typename T> template <typename T>
T* data() { inline const T* data() const;
EnforceSufficientMemory<T>();
return reinterpret_cast<T*>(reinterpret_cast<uintptr_t>(holder_->ptr()) +
offset_);
}
template <typename T, // must be POD types /**
typename std::enable_if<std::is_pod<T>::value>::type* = nullptr> * @brief Return a pointer to mutable memory block.
T* mutable_data(DDim dims, platform::Place place) { * @note If not exist, then allocation.
Resize(dims); */
return mutable_data<T>(place); template <typename T>
} inline T* mutable_data(platform::Place place);
/**
* @brief Return a pointer to mutable memory block.
*
* @param[in] dims The dimensions of the memory block.
* @param[in] place The place of the memory block.
*
* @note If not exist, then allocation.
*/
template <typename T>
inline T* mutable_data(DDim dims, platform::Place place);
template <typename T, // must be POD types /*! Return the dimensions of the memory block. */
typename std::enable_if<std::is_pod<T>::value>::type* = nullptr> inline const DDim& dims() const;
T* mutable_data(platform::Place place) {
PADDLE_ENFORCE(product(dims_) > 0,
"Tensor's numel must be larger than zero to call "
"Tensor::mutable_data. Call Tensor::set_dim first.");
if (holder_ == nullptr ||
!(holder_->place() ==
place) /* some versions of boost::variant don't have operator!= */
|| holder_->size() < product(dims_) * sizeof(T) + offset_) {
if (platform::is_cpu_place(place)) {
holder_.reset(new PlaceholderImpl<T, platform::CPUPlace>(
boost::get<platform::CPUPlace>(place), product(dims_) * sizeof(T)));
} else if (platform::is_gpu_place(place)) {
#ifdef PADDLE_ONLY_CPU
PADDLE_THROW("'GPUPlace' is not supported in CPU only device.");
#else
holder_.reset(new PlaceholderImpl<T, platform::GPUPlace>(
boost::get<platform::GPUPlace>(place), product(dims_) * sizeof(T)));
#endif
} else {
PADDLE_THROW("Unknown 'place'.");
}
offset_ = 0;
}
return reinterpret_cast<T*>(reinterpret_cast<uintptr_t>(holder_->ptr()) +
offset_);
}
template <typename T> /*! Resize the dimensions of the memory block. */
void ShareDataWith(const Tensor& src) { inline void Resize(const DDim& dims);
src.EnforceSufficientMemory<T>();
*this = src;
}
/*! The internal of two tensors share the same memory block. */
template <typename T> template <typename T>
void CopyFrom(const Tensor& src, platform::Place dst_place) { inline void ShareDataWith(const Tensor& src);
PADDLE_ENFORCE(platform::is_cpu_place(dst_place),
"Tensor::CopyFrom only support dst CPU now."); /**
size_t size = product(src.dims_) * sizeof(T); * @brief Copy the content of external tensor to a new place.
Resize(src.dims()); *
const void* src_ptr = static_cast<const void*>(src.data<T>()); * @param[in] src The external tensor.
void* dst_ptr = static_cast<void*>(mutable_data<T>(dst_place)); * @param[in] ctx The device context contains place where to store.
if (paddle::platform::is_cpu_place(holder_->place())) { *
std::memcpy(dst_ptr, src_ptr, size); * @note CopyFrom supports CPU <-> GPU, GPU <-> GPU.
} else if (paddle::platform::is_gpu_place(holder_->place())) { */
#ifdef PADDLE_ONLY_CPU
PADDLE_THROW("'GPUPlace' is not supported in CPU only device.");
#else
platform::GpuMemcpySync(dst_ptr, src_ptr, size, cudaMemcpyDeviceToHost);
#endif
}
}
template <typename T> template <typename T>
Tensor Slice(const int& begin_idx, const int& end_idx) const { inline void CopyFrom(const Tensor& src, const platform::Place& dst_place);
EnforceSufficientMemory<T>();
PADDLE_ENFORCE(begin_idx >= 0, "Slice begin index is less than zero.");
PADDLE_ENFORCE(end_idx <= dims_[0], "Slice end index is out of bound.");
PADDLE_ENFORCE(begin_idx < end_idx,
"Begin index must be less than end index.");
PADDLE_ENFORCE(dims_[0] != 1, "Can not slice a tensor with dims_[0] = 1.");
int base = product(dims_) / dims_[0];
Tensor dst;
dst.holder_ = holder_;
DDim dst_dims = dims_;
dst_dims[0] = end_idx - begin_idx;
dst.Resize(dst_dims);
dst.offset_ = offset_ + begin_idx * base * sizeof(T);
return dst;
}
void Resize(const DDim& dims) { dims_ = dims; } /**
* @brief Return the slice of the tensor.
*
* @param[in] begin_idx The begin index of the slice.
* @param[in] end_idx The end index of the slice.
*/
inline Tensor Slice(const int& begin_idx, const int& end_idx) const;
const DDim& dims() const { return dims_; } private:
template <typename T>
inline void check_memory_size() const;
paddle::platform::Place place() const { return holder_->place(); } paddle::platform::Place place() const { return holder_->place(); }
private: private:
// Placeholder hides type T, so it doesn't appear as a template /**
// parameter of Variable. * @note Placeholder hides type T, so it doesn't appear as a template
* parameter of Variable.
*/
struct Placeholder { struct Placeholder {
virtual ~Placeholder() {} virtual ~Placeholder() {}
virtual void* ptr() const = 0; virtual void* ptr() const = 0;
virtual platform::Place place() const = 0;
virtual size_t size() const = 0; virtual size_t size() const = 0;
virtual std::type_index type() const = 0; virtual std::type_index type() const = 0;
virtual platform::Place place() const = 0;
}; };
template <typename T, typename PlaceType> template <typename T, typename Place>
struct PlaceholderImpl : public Placeholder { struct PlaceholderImpl : public Placeholder {
PlaceholderImpl(PlaceType place, size_t size) PlaceholderImpl(Place place, size_t size)
: ptr_(static_cast<T*>(memory::Alloc(place, size)), : ptr_(static_cast<T*>(memory::Alloc(place, size)),
memory::PODDeleter<T, PlaceType>(place)), memory::PODDeleter<T, Place>(place)),
place_(place), place_(place),
size_(size) {} size_(size) {
PADDLE_ENFORCE(ptr_ != nullptr, "Insufficient %s memory to allocation.",
is_cpu_place(place_) ? "CPU" : "GPU");
}
virtual void* ptr() const { return static_cast<void*>(ptr_.get()); }
virtual size_t size() const { return size_; } virtual size_t size() const { return size_; }
virtual paddle::platform::Place place() const { return place_; } virtual platform::Place place() const { return place_; }
virtual void* ptr() const { return static_cast<void*>(ptr_.get()); }
virtual std::type_index type() const { return std::type_index(typeid(T)); } virtual std::type_index type() const { return std::type_index(typeid(T)); }
std::unique_ptr<T, memory::PODDeleter<T, PlaceType>> ptr_; /*! the pointer of memory block. */
platform::Place place_; // record the place of ptr_. std::unique_ptr<T, memory::PODDeleter<T, Place>> ptr_;
size_t size_; // size of the memory block.
/*! the place of memory block. */
platform::Place place_;
/*! the size of memory block. */
size_t size_;
}; };
template <typename T> /*! holds the memory block if allocated. */
inline void EnforceSufficientMemory() const { std::shared_ptr<Placeholder> holder_;
PADDLE_ENFORCE(holder_ != nullptr,
"Tenosr holds no memory. Call Tensor::mutable_data first.");
PADDLE_ENFORCE(holder_->size() >= product(dims_) * sizeof(T) + offset_,
"Tensor's dims_ is out of bound. Call Tensor::mutable_data "
"first to re-allocate memory.");
}
std::shared_ptr<Placeholder> holder_; // holds the memory block if allocated. /*! points to dimensions of memory block. */
DDim dims_; DDim dims_;
// A PlaceHolder may be shared by more than one tensor. Some of them may be
// slices of the others. So the offset_ is introduced here to indicate the /**
// byte offset between PlaceHolder::ptr_ and where tensor's data really * @brief A PlaceHolder may be shared by more than one tensor.
// begins. *
* @note Some of them may be slices of the others. So the offset_
* is introduced here to indicate the byte offset between
* PlaceHolder::ptr_ and where the tensor data really begins.
*/
size_t offset_; size_t offset_;
}; };
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
#include "paddle/framework/detail/tensor-inl.h"
...@@ -33,7 +33,7 @@ TEST(Tensor, DataAssert) { ...@@ -33,7 +33,7 @@ TEST(Tensor, DataAssert) {
bool caught = false; bool caught = false;
try { try {
src_tensor.data<double>(); src_tensor.data<double>();
} catch (std::runtime_error& err) { } catch (paddle::platform::EnforceNotMet err) {
caught = true; caught = true;
std::string msg = std::string msg =
"Tenosr holds no memory. Call Tensor::mutable_data first."; "Tenosr holds no memory. Call Tensor::mutable_data first.";
...@@ -72,7 +72,8 @@ TEST(Tensor, MutableData) { ...@@ -72,7 +72,8 @@ TEST(Tensor, MutableData) {
p2 = src_tensor.mutable_data<float>(make_ddim({2, 2}), CPUPlace()); p2 = src_tensor.mutable_data<float>(make_ddim({2, 2}), CPUPlace());
EXPECT_EQ(p1, p2); EXPECT_EQ(p1, p2);
} }
#ifdef __CUDACC__
#ifndef PADDLE_ONLY_CPU
{ {
Tensor src_tensor; Tensor src_tensor;
float* p1 = nullptr; float* p1 = nullptr;
...@@ -107,7 +108,7 @@ TEST(Tensor, ShareDataWith) { ...@@ -107,7 +108,7 @@ TEST(Tensor, ShareDataWith) {
bool caught = false; bool caught = false;
try { try {
dst_tensor.ShareDataWith<float>(src_tensor); dst_tensor.ShareDataWith<float>(src_tensor);
} catch (std::runtime_error& err) { } catch (paddle::platform::EnforceNotMet err) {
caught = true; caught = true;
std::string msg = std::string msg =
"Tenosr holds no memory. Call Tensor::mutable_data first."; "Tenosr holds no memory. Call Tensor::mutable_data first.";
...@@ -123,7 +124,7 @@ TEST(Tensor, ShareDataWith) { ...@@ -123,7 +124,7 @@ TEST(Tensor, ShareDataWith) {
ASSERT_EQ(src_tensor.data<int>(), dst_tensor.data<int>()); ASSERT_EQ(src_tensor.data<int>(), dst_tensor.data<int>());
} }
#ifdef __CUDACC__ #ifndef PADDLE_ONLY_CPU
{ {
Tensor src_tensor; Tensor src_tensor;
Tensor dst_tensor; Tensor dst_tensor;
...@@ -160,7 +161,7 @@ TEST(Tensor, Slice) { ...@@ -160,7 +161,7 @@ TEST(Tensor, Slice) {
EXPECT_EQ(src_data_address + 3 * 4 * 1 * sizeof(int), slice_data_address); EXPECT_EQ(src_data_address + 3 * 4 * 1 * sizeof(int), slice_data_address);
} }
#ifdef __CUDACC__ #ifndef PADDLE_ONLY_CPU
{ {
Tensor src_tensor; Tensor src_tensor;
src_tensor.mutable_data<double>(make_ddim({6, 9}), GPUPlace()); src_tensor.mutable_data<double>(make_ddim({6, 9}), GPUPlace());
...@@ -188,13 +189,53 @@ TEST(Tensor, Slice) { ...@@ -188,13 +189,53 @@ TEST(Tensor, Slice) {
TEST(Tensor, CopyFrom) { TEST(Tensor, CopyFrom) {
using namespace paddle::framework; using namespace paddle::framework;
using namespace paddle::platform; using namespace paddle::platform;
{
Tensor src_tensor; Tensor src_tensor;
Tensor dst_tensor;
int* src_ptr = src_tensor.mutable_data<int>(make_ddim({3, 3}), CPUPlace()); int* src_ptr = src_tensor.mutable_data<int>(make_ddim({3, 3}), CPUPlace());
int arr[9] = {1, 2, 3, 4, 5, 6, 7, 8, 9}; int arr[9] = {1, 2, 3, 4, 5, 6, 7, 8, 9};
memcpy(src_ptr, arr, 9 * sizeof(int)); memcpy(src_ptr, arr, 9 * sizeof(int));
auto cpu_place = new paddle::platform::CPUPlace();
dst_tensor.CopyFrom<int>(src_tensor, *cpu_place);
const int* dst_ptr = dst_tensor.data<int>();
ASSERT_NE(src_ptr, dst_ptr);
for (size_t i = 0; i < 9; ++i) {
EXPECT_EQ(src_ptr[i], dst_ptr[i]);
}
Tensor slice_tensor = src_tensor.Slice<int>(1, 2);
dst_tensor.CopyFrom<int>(slice_tensor, *cpu_place);
const int* slice_ptr = slice_tensor.data<int>();
dst_ptr = dst_tensor.data<int>();
ASSERT_NE(dst_ptr, slice_ptr);
for (size_t i = 0; i < 3; ++i) {
EXPECT_EQ(dst_ptr[i], slice_ptr[i]);
}
}
#ifndef PADDLE_ONLY_CPU
{
Tensor src_tensor;
Tensor gpu_tensor;
Tensor dst_tensor; Tensor dst_tensor;
dst_tensor.CopyFrom<int>(src_tensor, CPUPlace());
int* src_ptr = src_tensor.mutable_data<int>(make_ddim({3, 3}), CPUPlace());
int arr[9] = {1, 2, 3, 4, 5, 6, 7, 8, 9};
memcpy(src_ptr, arr, 9 * sizeof(int));
// CPU Tensor to GPU Tensor
auto gpu_place = new paddle::platform::GPUPlace(0);
gpu_tensor.CopyFrom<int>(src_tensor, *gpu_place);
// GPU Tensor to CPU Tensor
auto cpu_place = new paddle::platform::CPUPlace();
dst_tensor.CopyFrom<int>(gpu_tensor, *cpu_place);
// Compare Tensors
const int* dst_ptr = dst_tensor.data<int>(); const int* dst_ptr = dst_tensor.data<int>();
ASSERT_NE(src_ptr, dst_ptr); ASSERT_NE(src_ptr, dst_ptr);
for (size_t i = 0; i < 9; ++i) { for (size_t i = 0; i < 9; ++i) {
...@@ -202,11 +243,20 @@ TEST(Tensor, CopyFrom) { ...@@ -202,11 +243,20 @@ TEST(Tensor, CopyFrom) {
} }
Tensor slice_tensor = src_tensor.Slice<int>(1, 2); Tensor slice_tensor = src_tensor.Slice<int>(1, 2);
dst_tensor.CopyFrom<int>(slice_tensor, CPUPlace());
// CPU Slice Tensor to GPU Tensor
gpu_tensor.CopyFrom<int>(slice_tensor, *gpu_place);
// GPU Tensor to CPU Tensor
dst_tensor.CopyFrom<int>(gpu_tensor, *cpu_place);
// Compare Slice Tensors
const int* slice_ptr = slice_tensor.data<int>(); const int* slice_ptr = slice_tensor.data<int>();
dst_ptr = dst_tensor.data<int>(); dst_ptr = dst_tensor.data<int>();
ASSERT_NE(dst_ptr, slice_ptr); ASSERT_NE(dst_ptr, slice_ptr);
for (size_t i = 0; i < 3; ++i) { for (size_t i = 0; i < 3; ++i) {
EXPECT_EQ(dst_ptr[i], slice_ptr[i]); EXPECT_EQ(dst_ptr[i], slice_ptr[i]);
} }
}
#endif
} }
...@@ -207,8 +207,8 @@ Error __must_check backward(Argument& act) { ...@@ -207,8 +207,8 @@ Error __must_check backward(Argument& act) {
argument_.value->setData(act.value->getData() + offset, 1UL, size); argument_.value->setData(act.value->getData() + offset, 1UL, size);
argument_.grad->setData(act.grad->getData() + offset, 1UL, size); argument_.grad->setData(act.grad->getData() + offset, 1UL, size);
Error status = softmax_.backward(argument_); Error err = softmax_.backward(argument_);
if (!status) return status; if (!err.isOK()) return err;
} }
return Error(); return Error();
} }
......
...@@ -27,12 +27,11 @@ BuddyAllocator::BuddyAllocator(SystemAllocator* system_allocator, ...@@ -27,12 +27,11 @@ BuddyAllocator::BuddyAllocator(SystemAllocator* system_allocator,
system_allocator_(std::move(system_allocator)) {} system_allocator_(std::move(system_allocator)) {}
BuddyAllocator::~BuddyAllocator() { BuddyAllocator::~BuddyAllocator() {
DLOG(INFO) << "BuddyAllocator Disconstructor makes sure that all of these " VLOG(3) << "BuddyAllocator Disconstructor makes sure that all of these "
"have actually been freed"; "have actually been freed";
while (!pool_.empty()) { while (!pool_.empty()) {
auto block = static_cast<MemoryBlock*>(std::get<2>(*pool_.begin())); auto block = static_cast<MemoryBlock*>(std::get<2>(*pool_.begin()));
DLOG(INFO) << "Free from block (" << block << ", " << max_chunk_size_ VLOG(3) << "Free from block (" << block << ", " << max_chunk_size_ << ")";
<< ")";
system_allocator_->Free(block, max_chunk_size_, block->index(cache_)); system_allocator_->Free(block, max_chunk_size_, block->index(cache_));
cache_.invalidate(block); cache_.invalidate(block);
...@@ -52,12 +51,11 @@ void* BuddyAllocator::Alloc(size_t unaligned_size) { ...@@ -52,12 +51,11 @@ void* BuddyAllocator::Alloc(size_t unaligned_size) {
// acquire the allocator lock // acquire the allocator lock
std::lock_guard<std::mutex> lock(mutex_); std::lock_guard<std::mutex> lock(mutex_);
DLOG(INFO) << "Allocate " << unaligned_size << " bytes from chunk size " VLOG(3) << "Allocate " << unaligned_size << " bytes from chunk size " << size;
<< size;
// if the allocation is huge, send directly to the system allocator // if the allocation is huge, send directly to the system allocator
if (size > max_chunk_size_) { if (size > max_chunk_size_) {
DLOG(INFO) << "Allocate from system allocator."; VLOG(3) << "Allocate from system allocator.";
return SystemAlloc(size); return SystemAlloc(size);
} }
...@@ -72,7 +70,7 @@ void* BuddyAllocator::Alloc(size_t unaligned_size) { ...@@ -72,7 +70,7 @@ void* BuddyAllocator::Alloc(size_t unaligned_size) {
return nullptr; return nullptr;
} }
} else { } else {
DLOG(INFO) << "Allocation from existing memory block " << std::get<2>(*it) VLOG(3) << "Allocation from existing memory block " << std::get<2>(*it)
<< " at address " << " at address "
<< reinterpret_cast<MemoryBlock*>(std::get<2>(*it))->data(); << reinterpret_cast<MemoryBlock*>(std::get<2>(*it))->data();
} }
...@@ -91,10 +89,10 @@ void BuddyAllocator::Free(void* p) { ...@@ -91,10 +89,10 @@ void BuddyAllocator::Free(void* p) {
// Acquire the allocator lock // Acquire the allocator lock
std::lock_guard<std::mutex> lock(mutex_); std::lock_guard<std::mutex> lock(mutex_);
DLOG(INFO) << "Free from address " << block; VLOG(3) << "Free from address " << block;
if (block->type(cache_) == MemoryBlock::HUGE_CHUNK) { if (block->type(cache_) == MemoryBlock::HUGE_CHUNK) {
DLOG(INFO) << "Free directly from system allocator"; VLOG(3) << "Free directly from system allocator";
system_allocator_->Free(block, block->total_size(cache_), system_allocator_->Free(block, block->total_size(cache_),
block->index(cache_)); block->index(cache_));
...@@ -111,7 +109,7 @@ void BuddyAllocator::Free(void* p) { ...@@ -111,7 +109,7 @@ void BuddyAllocator::Free(void* p) {
// Trying to merge the right buddy // Trying to merge the right buddy
if (block->has_right_buddy(cache_)) { if (block->has_right_buddy(cache_)) {
DLOG(INFO) << "Merging this block " << block << " with its right buddy " VLOG(3) << "Merging this block " << block << " with its right buddy "
<< block->right_buddy(cache_); << block->right_buddy(cache_);
auto right_buddy = block->right_buddy(cache_); auto right_buddy = block->right_buddy(cache_);
...@@ -129,7 +127,7 @@ void BuddyAllocator::Free(void* p) { ...@@ -129,7 +127,7 @@ void BuddyAllocator::Free(void* p) {
// Trying to merge the left buddy // Trying to merge the left buddy
if (block->has_left_buddy(cache_)) { if (block->has_left_buddy(cache_)) {
DLOG(INFO) << "Merging this block " << block << " with its left buddy " VLOG(3) << "Merging this block " << block << " with its left buddy "
<< block->left_buddy(cache_); << block->left_buddy(cache_);
auto left_buddy = block->left_buddy(cache_); auto left_buddy = block->left_buddy(cache_);
...@@ -146,7 +144,7 @@ void BuddyAllocator::Free(void* p) { ...@@ -146,7 +144,7 @@ void BuddyAllocator::Free(void* p) {
} }
// Dumping this block into pool // Dumping this block into pool
DLOG(INFO) << "Inserting free block (" << block << ", " VLOG(3) << "Inserting free block (" << block << ", "
<< block->total_size(cache_) << ")"; << block->total_size(cache_) << ")";
pool_.insert( pool_.insert(
IndexSizeAddress(block->index(cache_), block->total_size(cache_), block)); IndexSizeAddress(block->index(cache_), block->total_size(cache_), block));
...@@ -166,7 +164,7 @@ void* BuddyAllocator::SystemAlloc(size_t size) { ...@@ -166,7 +164,7 @@ void* BuddyAllocator::SystemAlloc(size_t size) {
size_t index = 0; size_t index = 0;
void* p = system_allocator_->Alloc(index, size); void* p = system_allocator_->Alloc(index, size);
DLOG(INFO) << "Allocated " << p << " from system allocator."; VLOG(3) << "Allocated " << p << " from system allocator.";
if (p == nullptr) return nullptr; if (p == nullptr) return nullptr;
...@@ -192,7 +190,7 @@ BuddyAllocator::PoolSet::iterator BuddyAllocator::RefillPool() { ...@@ -192,7 +190,7 @@ BuddyAllocator::PoolSet::iterator BuddyAllocator::RefillPool() {
if (p == nullptr) return pool_.end(); if (p == nullptr) return pool_.end();
DLOG(INFO) << "Creating and inserting new block " << p VLOG(3) << "Creating and inserting new block " << p
<< " from system allocator"; << " from system allocator";
static_cast<MemoryBlock*>(p)->init(cache_, MemoryBlock::FREE_CHUNK, index, static_cast<MemoryBlock*>(p)->init(cache_, MemoryBlock::FREE_CHUNK, index,
...@@ -237,18 +235,18 @@ void* BuddyAllocator::SplitToAlloc(BuddyAllocator::PoolSet::iterator it, ...@@ -237,18 +235,18 @@ void* BuddyAllocator::SplitToAlloc(BuddyAllocator::PoolSet::iterator it,
auto block = static_cast<MemoryBlock*>(std::get<2>(*it)); auto block = static_cast<MemoryBlock*>(std::get<2>(*it));
pool_.erase(it); pool_.erase(it);
DLOG(INFO) << "Split block (" << block << ", " << block->total_size(cache_) VLOG(3) << "Split block (" << block << ", " << block->total_size(cache_)
<< ") into"; << ") into";
block->split(cache_, size); block->split(cache_, size);
DLOG(INFO) << "Left block (" << block << ", " << block->total_size(cache_) VLOG(3) << "Left block (" << block << ", " << block->total_size(cache_)
<< ")"; << ")";
block->set_type(cache_, MemoryBlock::ARENA_CHUNK); block->set_type(cache_, MemoryBlock::ARENA_CHUNK);
// the rest of memory if exist // the rest of memory if exist
if (block->has_right_buddy(cache_)) { if (block->has_right_buddy(cache_)) {
if (block->right_buddy(cache_)->type(cache_) == MemoryBlock::FREE_CHUNK) { if (block->right_buddy(cache_)->type(cache_) == MemoryBlock::FREE_CHUNK) {
DLOG(INFO) << "Insert right block (" << block->right_buddy(cache_) << ", " VLOG(3) << "Insert right block (" << block->right_buddy(cache_) << ", "
<< block->right_buddy(cache_)->total_size(cache_) << ")"; << block->right_buddy(cache_)->total_size(cache_) << ")";
pool_.insert( pool_.insert(
...@@ -276,7 +274,7 @@ void BuddyAllocator::CleanIdleFallBackAlloc() { ...@@ -276,7 +274,7 @@ void BuddyAllocator::CleanIdleFallBackAlloc() {
return; return;
} }
DLOG(INFO) << "Return block " << block << " to fallback allocator."; VLOG(3) << "Return block " << block << " to fallback allocator.";
system_allocator_->Free(block, max_chunk_size_, block->index(cache_)); system_allocator_->Free(block, max_chunk_size_, block->index(cache_));
cache_.invalidate(block); cache_.invalidate(block);
...@@ -312,7 +310,7 @@ void BuddyAllocator::CleanIdleNormalAlloc() { ...@@ -312,7 +310,7 @@ void BuddyAllocator::CleanIdleNormalAlloc() {
MemoryBlock* block = static_cast<MemoryBlock*>(std::get<2>(*pool)); MemoryBlock* block = static_cast<MemoryBlock*>(std::get<2>(*pool));
DLOG(INFO) << "Return block " << block << " to base allocator."; VLOG(3) << "Return block " << block << " to base allocator.";
system_allocator_->Free(block, max_chunk_size_, block->index(cache_)); system_allocator_->Free(block, max_chunk_size_, block->index(cache_));
cache_.invalidate(block); cache_.invalidate(block);
......
...@@ -35,7 +35,7 @@ void Copy<platform::CPUPlace, platform::GPUPlace>(platform::CPUPlace dst_place, ...@@ -35,7 +35,7 @@ void Copy<platform::CPUPlace, platform::GPUPlace>(platform::CPUPlace dst_place,
platform::GPUPlace src_place, platform::GPUPlace src_place,
const void* src, size_t num, const void* src, size_t num,
cudaStream_t stream) { cudaStream_t stream) {
platform::GPUPlaceGuard g(src_place.device); platform::SetDeviceId(src_place.device);
platform::GpuMemcpyAsync(dst, src, num, cudaMemcpyDeviceToHost, stream); platform::GpuMemcpyAsync(dst, src, num, cudaMemcpyDeviceToHost, stream);
} }
...@@ -45,7 +45,7 @@ void Copy<platform::GPUPlace, platform::CPUPlace>(platform::GPUPlace dst_place, ...@@ -45,7 +45,7 @@ void Copy<platform::GPUPlace, platform::CPUPlace>(platform::GPUPlace dst_place,
platform::CPUPlace src_place, platform::CPUPlace src_place,
const void* src, size_t num, const void* src, size_t num,
cudaStream_t stream) { cudaStream_t stream) {
platform::GPUPlaceGuard g(dst_place.device); platform::SetDeviceId(dst_place.device);
platform::GpuMemcpyAsync(dst, src, num, cudaMemcpyHostToDevice, stream); platform::GpuMemcpyAsync(dst, src, num, cudaMemcpyHostToDevice, stream);
} }
...@@ -56,7 +56,7 @@ void Copy<platform::GPUPlace, platform::GPUPlace>(platform::GPUPlace dst_place, ...@@ -56,7 +56,7 @@ void Copy<platform::GPUPlace, platform::GPUPlace>(platform::GPUPlace dst_place,
const void* src, size_t num, const void* src, size_t num,
cudaStream_t stream) { cudaStream_t stream) {
if (dst_place == src_place) { if (dst_place == src_place) {
platform::GPUPlaceGuard g(src_place.device); platform::SetDeviceId(src_place.device);
platform::GpuMemcpyAsync(dst, src, num, cudaMemcpyDeviceToDevice, stream); platform::GpuMemcpyAsync(dst, src, num, cudaMemcpyDeviceToDevice, stream);
} else { } else {
platform::GpuMemcpyPeer(dst, dst_place.device, src, src_place.device, num, platform::GpuMemcpyPeer(dst, dst_place.device, src, src_place.device, num,
......
...@@ -20,13 +20,39 @@ limitations under the License. */ ...@@ -20,13 +20,39 @@ limitations under the License. */
namespace paddle { namespace paddle {
namespace memory { namespace memory {
/**
* \brief Copy memory from one place to another place.
*
* \param[in] DstPlace Destination allocation place (CPU).
* \param[in] dst Destination memory address.
* \param[in] SrcPlace Source allocation place (CPU).
* \param[in] src Source memory address.
* \param[in] num memory size in bytes to copy.
*
*/
template <typename DstPlace, typename SrcPlace> template <typename DstPlace, typename SrcPlace>
void Copy(DstPlace, void* dst, SrcPlace, const void* src, size_t num); void Copy(DstPlace, void* dst, SrcPlace, const void* src, size_t num);
#ifndef PADDLE_ONLY_CPU #ifndef PADDLE_ONLY_CPU
/**
* \brief Copy memory from one place to another place.
*
* \param[in] DstPlace Destination allocation place (CPU or GPU).
* \param[in] dst Destination memory address.
* \param[in] SrcPlace Source allocation place (CPU or GPU).
* \param[in] src Source memory address.
* \param[in] num memory size in bytes to copy.
* \param[in] stream CUDA stream.
*
* \note For GPU memory copy, CUDA stream need to be specified
* for asynchronously memory copy.
*
*/
template <typename DstPlace, typename SrcPlace> template <typename DstPlace, typename SrcPlace>
void Copy(DstPlace, void* dst, SrcPlace, const void* src, size_t num, void Copy(DstPlace, void* dst, SrcPlace, const void* src, size_t num,
cudaStream_t stream); cudaStream_t stream);
#endif // PADDLE_ONLY_CPU #endif // PADDLE_ONLY_CPU
} // namespace memory } // namespace memory
......
...@@ -60,6 +60,7 @@ detail::BuddyAllocator* GetGPUBuddyAllocator(int gpu_id) { ...@@ -60,6 +60,7 @@ detail::BuddyAllocator* GetGPUBuddyAllocator(int gpu_id) {
platform::GpuMaxChunkSize()); platform::GpuMaxChunkSize());
} }
} }
platform::SetDeviceId(gpu_id);
return as[gpu_id]; return as[gpu_id];
} }
......
...@@ -20,19 +20,53 @@ limitations under the License. */ ...@@ -20,19 +20,53 @@ limitations under the License. */
namespace paddle { namespace paddle {
namespace memory { namespace memory {
/**
* \brief Allocate memory block in one place.
*
* \param[in] place Allocation place (CPU or GPU).
* \param[in] size Allocation size.
*
* \return Allocated memory block address.
*
* \note If return nullptr, it indicates memory allocation failed
* because insufficient memory in current system. When Alloc
* function is invoked, you must check the returned memory
* address is valid or not.
*/
template <typename Place> template <typename Place>
void* Alloc(Place, size_t); void* Alloc(Place place, size_t size);
/**
* \brief Free memory block in one place.
*
* \param[in] place Allocation place (CPU or GPU).
* \param[in] ptr Memory block address to free.
*
*/
template <typename Place> template <typename Place>
void Free(Place, void*); void Free(Place place, void* ptr);
/**
* \brief Total size of used memory in one place.
*
* \param[in] place Allocation place (CPU or GPU).
*
*/
template <typename Place> template <typename Place>
size_t Used(Place); size_t Used(Place place);
template <typename T, /* must be POD types */ /**
typename Place /* platform::GPUPlace or platform::CPUPlace */, * \brief Free memory block in one place.
typename std::enable_if<std::is_pod<T>::value>::type* = nullptr> *
* \note In some cases, custom deleter is used to
* deallocate the memory automatically for
* std::unique_ptr<T> in tensor.h.
*
*/
template <typename T, typename Place>
class PODDeleter { class PODDeleter {
static_assert(std::is_pod<T>::value, "T must be POD");
public: public:
PODDeleter(Place place) : place_(place) {} PODDeleter(Place place) : place_(place) {}
void operator()(T* ptr) { Free(place_, static_cast<void*>(ptr)); } void operator()(T* ptr) { Free(place_, static_cast<void*>(ptr)); }
......
...@@ -54,3 +54,8 @@ op_library(fc_op SRCS fc_op.cc DEPS mul_op rowwise_add_op sigmoid_op ...@@ -54,3 +54,8 @@ op_library(fc_op SRCS fc_op.cc DEPS mul_op rowwise_add_op sigmoid_op
softmax_op net) softmax_op net)
op_library(sgd_op SRCS sgd_op.cc sgd_op.cu) op_library(sgd_op SRCS sgd_op.cc sgd_op.cu)
op_library(recurrent_network_op SRCS recurrent_network_op.cc DEPS op_desc
tensor op_registry operator net)
cc_test(recurrent_network_op_test SRCS recurrent_network_op_test.cc DEPS
recurrent_network_op gtest mul_op add_op)
...@@ -13,17 +13,14 @@ See the License for the specific language governing permissions and ...@@ -13,17 +13,14 @@ See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "paddle/operators/add_op.h" #include "paddle/operators/add_op.h"
#include "paddle/framework/op_registry.h"
#include "paddle/framework/tensor.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
class AddOp : public framework::OperatorWithKernel { class AddOp : public OperatorWithKernel {
protected: protected:
void InferShape( void InferShape(const std::vector<const Tensor *> &inputs,
const std::vector<const framework::Tensor *> &inputs, const std::vector<Tensor *> &outputs) const override {
const std::vector<framework::Tensor *> &outputs) const override {
PADDLE_ENFORCE(inputs.size() == 2, "Input size of AddOp must be two"); PADDLE_ENFORCE(inputs.size() == 2, "Input size of AddOp must be two");
PADDLE_ENFORCE(outputs.size() == 1, "Output size of AddOp must be one"); PADDLE_ENFORCE(outputs.size() == 1, "Output size of AddOp must be one");
PADDLE_ENFORCE( PADDLE_ENFORCE(
...@@ -35,10 +32,10 @@ protected: ...@@ -35,10 +32,10 @@ protected:
} }
}; };
class AddOpMaker : public framework::OpProtoAndCheckerMaker { class AddOpMaker : public OpProtoAndCheckerMaker {
public: public:
AddOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker) AddOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: framework::OpProtoAndCheckerMaker(proto, op_checker) { : OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "The first input of add op"); AddInput("X", "The first input of add op");
AddInput("Y", "The second input of add op"); AddInput("Y", "The second input of add op");
AddOutput("Out", "The output of add op"); AddOutput("Out", "The output of add op");
...@@ -50,11 +47,10 @@ The equation is: Out = X + Y ...@@ -50,11 +47,10 @@ The equation is: Out = X + Y
} }
}; };
class AddOpGrad : public framework::OperatorWithKernel { class AddOpGrad : public OperatorWithKernel {
protected: protected:
void InferShape( void InferShape(const std::vector<const Tensor *> &inputs,
const std::vector<const framework::Tensor *> &inputs, const std::vector<Tensor *> &outputs) const override {}
const std::vector<framework::Tensor *> &outputs) const override {}
std::string DebugString() const override { std::string DebugString() const override {
LOG(INFO) << "AddOpGrad"; LOG(INFO) << "AddOpGrad";
return ""; return "";
...@@ -64,7 +60,6 @@ protected: ...@@ -64,7 +60,6 @@ protected:
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
REGISTER_OP(add_two, paddle::operators::AddOp, paddle::operators::AddOpMaker); REGISTER_OP(add_two, ops::AddOp, ops::AddOpMaker);
REGISTER_GRADIENT_OP(add_two, add_two_grad, paddle::operators::AddOpGrad); REGISTER_GRADIENT_OP(add_two, add_two_grad, ops::AddOpGrad);
REGISTER_OP_CPU_KERNEL( REGISTER_OP_CPU_KERNEL(add_two, ops::AddKernel<ops::CPUPlace, float>);
add_two, paddle::operators::AddKernel<paddle::platform::CPUPlace, float>);
#include "paddle/operators/add_op.h"
#include "paddle/framework/op_registry.h" #include "paddle/framework/op_registry.h"
#include "paddle/operators/add_op.h"
REGISTER_OP_GPU_KERNEL(add_two, REGISTER_OP_GPU_KERNEL(add_two, ops::AddKernel<ops::GPUPlace, float>);
paddle::operators::AddKernel<paddle::platform::GPUPlace, float>);
\ No newline at end of file
...@@ -13,27 +13,24 @@ See the License for the specific language governing permissions and ...@@ -13,27 +13,24 @@ See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#pragma once #pragma once
#include "glog/logging.h" #include "paddle/operators/type_alias.h"
#include "paddle/framework/eigen.h"
#include "paddle/framework/operator.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
template <typename Place, typename T> template <typename Place, typename T>
class AddKernel : public framework::OpKernel { class AddKernel : public OpKernel {
public: public:
void Compute(const framework::KernelContext& context) const override { void Compute(const KernelContext& context) const override {
auto input0 = context.Input(0)->Get<framework::Tensor>(); auto input0 = context.Input(0)->Get<Tensor>();
auto input1 = context.Input(1)->Get<framework::Tensor>(); auto input1 = context.Input(1)->Get<Tensor>();
auto* output = context.Output(0)->GetMutable<framework::Tensor>(); auto output = context.Output(0)->GetMutable<Tensor>();
output->mutable_data<T>(context.GetPlace()); output->mutable_data<T>(context.GetPlace());
framework::EigenVector<T>::Flatten(*output).device( EigenVector<T>::Flatten(*output).device(
*(context.GetEigenDevice<Place>())) = *(context.GetEigenDevice<Place>())) =
framework::EigenVector<T>::Flatten(input0) + EigenVector<T>::Flatten(input0) + EigenVector<T>::Flatten(input1);
framework::EigenVector<T>::Flatten(input1);
} }
}; };
......
...@@ -13,17 +13,14 @@ See the License for the specific language governing permissions and ...@@ -13,17 +13,14 @@ See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "paddle/operators/cross_entropy_op.h" #include "paddle/operators/cross_entropy_op.h"
#include "paddle/framework/op_registry.h"
#include "paddle/framework/tensor.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
class OnehotCrossEntropyOp : public framework::OperatorWithKernel { class OnehotCrossEntropyOp : public OperatorWithKernel {
protected: protected:
void InferShape( void InferShape(const std::vector<const Tensor *> &inputs,
const std::vector<const framework::Tensor *> &inputs, const std::vector<Tensor *> &outputs) const override {
const std::vector<framework::Tensor *> &outputs) const override {
PADDLE_ENFORCE(inputs.size() == 2, PADDLE_ENFORCE(inputs.size() == 2,
"Input size of OnehotCrossEntropyOp must be two"); "Input size of OnehotCrossEntropyOp must be two");
PADDLE_ENFORCE(outputs.size() == 1, PADDLE_ENFORCE(outputs.size() == 1,
...@@ -35,15 +32,14 @@ protected: ...@@ -35,15 +32,14 @@ protected:
PADDLE_ENFORCE(inputs[0]->dims().size() == 2, "X's dimension must be 2."); PADDLE_ENFORCE(inputs[0]->dims().size() == 2, "X's dimension must be 2.");
PADDLE_ENFORCE(outputs[0]->dims().size() == 1, PADDLE_ENFORCE(outputs[0]->dims().size() == 1,
"label's dimension must be 1."); "label's dimension must be 1.");
outputs[0]->Resize(framework::make_ddim({inputs[0]->dims()[0]})); outputs[0]->Resize({inputs[0]->dims()[0]});
} }
}; };
class OnehotCrossEntropyOpMaker : public framework::OpProtoAndCheckerMaker { class OnehotCrossEntropyOpMaker : public OpProtoAndCheckerMaker {
public: public:
OnehotCrossEntropyOpMaker(framework::OpProto *proto, OnehotCrossEntropyOpMaker(OpProto *proto, OpAttrChecker *op_checker)
framework::OpAttrChecker *op_checker) : OpProtoAndCheckerMaker(proto, op_checker) {
: framework::OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "The first input of OnehotCrossEntropyOp"); AddInput("X", "The first input of OnehotCrossEntropyOp");
AddInput("label", "The second input of OnehotCrossEntropyOp"); AddInput("label", "The second input of OnehotCrossEntropyOp");
AddOutput("Y", "The output of OnehotCrossEntropyOp"); AddOutput("Y", "The output of OnehotCrossEntropyOp");
...@@ -59,9 +55,7 @@ OnehotCrossEntropy Operator. ...@@ -59,9 +55,7 @@ OnehotCrossEntropy Operator.
} // namespace paddle } // namespace paddle
REGISTER_OP(onehot_cross_entropy, REGISTER_OP(onehot_cross_entropy,
paddle::operators::OnehotCrossEntropyOp, ops::OnehotCrossEntropyOp,
paddle::operators::OnehotCrossEntropyOpMaker); ops::OnehotCrossEntropyOpMaker);
REGISTER_OP_CPU_KERNEL( REGISTER_OP_CPU_KERNEL(onehot_cross_entropy,
onehot_cross_entropy, ops::OnehotCrossEntropyOpKernel<ops::CPUPlace, float>);
paddle::operators::OnehotCrossEntropyOpKernel<::paddle::platform::CPUPlace,
float>);
#include "paddle/operators/cross_entropy_op.h" #include "paddle/operators/cross_entropy_op.h"
#include "paddle/framework/op_registry.h"
REGISTER_OP_GPU_KERNEL(onehot_cross_entropy, REGISTER_OP_GPU_KERNEL(onehot_cross_entropy,
paddle::operators::OnehotCrossEntropyOpKernel< ops::OnehotCrossEntropyOpKernel<ops::GPUPlace, float>);
::paddle::platform::GPUPlace, float>); \ No newline at end of file
\ No newline at end of file
...@@ -13,23 +13,21 @@ See the License for the specific language governing permissions and ...@@ -13,23 +13,21 @@ See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#pragma once #pragma once
#include "glog/logging.h" #include "paddle/operators/type_alias.h"
#include "paddle/framework/operator.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
template <typename Place, typename T> template <typename Place, typename T>
class OnehotCrossEntropyOpKernel : public framework::OpKernel { class OnehotCrossEntropyOpKernel : public OpKernel {
public: public:
constexpr T LOG_THRESHOLD() const { return static_cast<T>(1e-20); } constexpr T LOG_THRESHOLD() const { return static_cast<T>(1e-20); }
void Compute(const framework::KernelContext& context) const override { void Compute(const KernelContext& context) const override {
auto X = context.Input(0)->Get<framework::Tensor>(); auto X = context.Input(0)->Get<Tensor>();
const T* X_data = X.data<T>(); const T* X_data = X.data<T>();
const int* label_data = const int* label_data = context.Input(1)->Get<Tensor>().data<int>();
context.Input(1)->Get<framework::Tensor>().data<int>(); auto* Y = context.Output(0)->GetMutable<Tensor>();
auto* Y = context.Output(0)->GetMutable<framework::Tensor>();
Y->mutable_data<T>(context.GetPlace()); Y->mutable_data<T>(context.GetPlace());
......
...@@ -12,41 +12,38 @@ ...@@ -12,41 +12,38 @@
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "paddle/framework/net.h" #include "type_alias.h"
#include "paddle/framework/op_registry.h"
#include "paddle/framework/operator.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
class FullyConnectedOp : public framework::PlainNet { class FullyConnectedOp : public NetOp {
public: public:
void Init() override { void Init() override {
AddOp(framework::OpRegistry::CreateOp("mul", AddOp(OpRegistry::CreateOp("mul",
{ {
Input("X"), Input("W"), Input("X"), Input("W"),
}, },
{Output("before_act")}, {Output("before_act")},
{})); {}));
auto b = Input("b"); auto b = Input("b");
if (b != framework::OperatorBase::EMPTY_VAR_NAME()) { if (b != EMPTY_VAR_NAME()) {
AddOp(framework::OpRegistry::CreateOp("rowwise_add", AddOp(OpRegistry::CreateOp("rowwise_add",
{Output("before_act"), Input("b")}, {Output("before_act"), Input("b")},
{Output("before_act")}, {Output("before_act")},
{})); {}));
} }
auto activation = GetAttr<std::string>("activation"); auto activation = GetAttr<std::string>("activation");
AddOp(framework::OpRegistry::CreateOp( AddOp(OpRegistry::CreateOp(
activation, {Output("before_act")}, {Output("Y")}, {})); activation, {Output("before_act")}, {Output("Y")}, {}));
CompleteAddOp(false); CompleteAddOp(false);
} }
}; };
class FullyConnectedOpMaker : public framework::OpProtoAndCheckerMaker { class FullyConnectedOpMaker : public OpProtoAndCheckerMaker {
public: public:
FullyConnectedOpMaker(framework::OpProto *proto, FullyConnectedOpMaker(OpProto *proto, OpAttrChecker *op_checker)
framework::OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) { : OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "the input of fc operator"); AddInput("X", "the input of fc operator");
AddInput("W", "the weight of fc operator"); AddInput("W", "the weight of fc operator");
...@@ -71,6 +68,4 @@ USE_OP(rowwise_add); ...@@ -71,6 +68,4 @@ USE_OP(rowwise_add);
USE_OP(sigmoid); USE_OP(sigmoid);
USE_OP(softmax); USE_OP(softmax);
REGISTER_OP(fc, REGISTER_OP(fc, ops::FullyConnectedOp, ops::FullyConnectedOpMaker);
paddle::operators::FullyConnectedOp,
paddle::operators::FullyConnectedOpMaker);
...@@ -13,17 +13,14 @@ ...@@ -13,17 +13,14 @@
limitations under the License. */ limitations under the License. */
#include "paddle/operators/mul_op.h" #include "paddle/operators/mul_op.h"
#include "paddle/framework/op_registry.h"
#include "paddle/framework/tensor.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
class MulOp : public framework::OperatorWithKernel { class MulOp : public OperatorWithKernel {
protected: protected:
void InferShape( void InferShape(const std::vector<const Tensor *> &inputs,
const std::vector<const framework::Tensor *> &inputs, const std::vector<Tensor *> &outputs) const override {
const std::vector<framework::Tensor *> &outputs) const override {
PADDLE_ENFORCE(inputs.size() == 2, "The mul op must take two inputs"); PADDLE_ENFORCE(inputs.size() == 2, "The mul op must take two inputs");
auto dim0 = inputs[0]->dims(); auto dim0 = inputs[0]->dims();
auto dim1 = inputs[1]->dims(); auto dim1 = inputs[1]->dims();
...@@ -37,10 +34,10 @@ protected: ...@@ -37,10 +34,10 @@ protected:
} }
}; };
class MulOpMaker : public framework::OpProtoAndCheckerMaker { class MulOpMaker : public OpProtoAndCheckerMaker {
public: public:
MulOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker) MulOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: framework::OpProtoAndCheckerMaker(proto, op_checker) { : OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "The first input of mul op"); AddInput("X", "The first input of mul op");
AddInput("Y", "The second input of mul op"); AddInput("Y", "The second input of mul op");
AddOutput("Out", "The output of mul op"); AddOutput("Out", "The output of mul op");
...@@ -52,11 +49,10 @@ The equation is: Out = X * Y ...@@ -52,11 +49,10 @@ The equation is: Out = X * Y
} }
}; };
class MulOpGrad : public framework::OperatorWithKernel { class MulOpGrad : public OperatorWithKernel {
protected: protected:
void InferShape( void InferShape(const std::vector<const Tensor *> &inputs,
const std::vector<const framework::Tensor *> &inputs, const std::vector<Tensor *> &outputs) const override {}
const std::vector<framework::Tensor *> &outputs) const override {}
std::string DebugString() const override { std::string DebugString() const override {
LOG(INFO) << "MulGrad"; LOG(INFO) << "MulGrad";
return ""; return "";
...@@ -66,8 +62,7 @@ protected: ...@@ -66,8 +62,7 @@ protected:
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
REGISTER_OP(mul, paddle::operators::MulOp, paddle::operators::MulOpMaker); REGISTER_OP(mul, ops::MulOp, ops::MulOpMaker);
REGISTER_GRADIENT_OP(mul, mul_grad, paddle::operators::MulOpGrad); REGISTER_GRADIENT_OP(mul, mul_grad, ops::MulOpGrad);
REGISTER_OP_CPU_KERNEL( REGISTER_OP_CPU_KERNEL(mul, ops::MulKernel<ops::CPUPlace, float>);
mul, paddle::operators::MulKernel<paddle::platform::CPUPlace, float>);
...@@ -13,8 +13,5 @@ ...@@ -13,8 +13,5 @@
limitations under the License. */ limitations under the License. */
#include "paddle/operators/mul_op.h" #include "paddle/operators/mul_op.h"
#include "paddle/framework/op_registry.h"
REGISTER_OP_GPU_KERNEL(mul, REGISTER_OP_GPU_KERNEL(mul, ops::MulKernel<ops::GPUPlace, float>);
paddle::operators::MulKernel<paddle::platform \ No newline at end of file
::GPUPlace, float>);
\ No newline at end of file
...@@ -14,30 +14,27 @@ ...@@ -14,30 +14,27 @@
#pragma once #pragma once
#include "glog/logging.h" #include "paddle/operators/type_alias.h"
#include "paddle/framework/eigen.h"
#include "paddle/framework/operator.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
template <typename Place, typename T> template <typename Place, typename T>
class MulKernel : public framework::OpKernel { class MulKernel : public OpKernel {
public: public:
void Compute(const framework::KernelContext& context) const override { void Compute(const KernelContext& context) const override {
Eigen::array<Eigen::IndexPair<Eigen::DenseIndex>, 1> dim_pair = { Eigen::array<Eigen::IndexPair<Eigen::DenseIndex>, 1> dim_pair = {
{Eigen::IndexPair<Eigen::DenseIndex>(1, 0)}}; {Eigen::IndexPair<Eigen::DenseIndex>(1, 0)}};
auto input0 = context.Input(0)->Get<framework::Tensor>(); auto input0 = context.Input(0)->Get<Tensor>();
auto input1 = context.Input(1)->Get<framework::Tensor>(); auto input1 = context.Input(1)->Get<Tensor>();
auto* output = context.Output(0)->GetMutable<framework::Tensor>(); auto* output = context.Output(0)->GetMutable<Tensor>();
output->mutable_data<T>(context.GetPlace()); output->mutable_data<T>(context.GetPlace());
framework::EigenMatrix<T>::From(*output).device( EigenMatrix<T>::From(*output).device(*(context.GetEigenDevice<Place>())) =
*(context.GetEigenDevice<Place>())) = EigenMatrix<T>::From(input0).contract(EigenMatrix<T>::From(input1),
framework::EigenMatrix<T>::From(input0).contract( dim_pair);
framework::EigenMatrix<T>::From(input1), dim_pair);
} }
}; };
} // namespace operators } // namespace operators
......
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/operators/recurrent_network_op.h"
#include <glog/logging.h>
#include <cstring>
#include <sstream>
#include "paddle/framework/net.h"
#include "paddle/framework/op_registry.h"
#include "paddle/platform/enforce.h"
namespace paddle {
namespace operators {
namespace rnn {
void SegmentInputs(std::vector<std::shared_ptr<Scope>>& step_scopes,
const std::vector<Link>& inlinks,
const size_t seq_len) {
PADDLE_ENFORCE(!inlinks.empty(), "no in links are provided.");
for (size_t i = 0; i < inlinks.size(); ++i) {
Tensor* input =
step_scopes[0]->GetVariable(inlinks[i].external)->GetMutable<Tensor>();
DDim dims = input->dims();
PADDLE_ENFORCE(static_cast<size_t>(dims[0]) == seq_len,
"all the inlinks must have same length");
DDim step_dims = slice_ddim(dims, 1, dims.size());
for (size_t j = 0; j < seq_len; j++) {
Tensor* step_input = step_scopes[j]
->CreateVariable(inlinks[i].internal)
->GetMutable<Tensor>();
*step_input = input->Slice<float>(j, j + 1);
step_input->Resize(step_dims);
}
}
}
void ConcatOutputs(std::vector<std::shared_ptr<Scope>>& step_scopes,
const std::vector<Link>& outlinks,
const size_t seq_len) {
for (size_t i = 0; i < outlinks.size(); i++) {
Tensor* output =
step_scopes[0]->GetVariable(outlinks[i].external)->GetMutable<Tensor>();
// TODO(qingiqng) remove following code after adding
// InferShape in RecurrentGradientOp
DDim step_dims = step_scopes[0]
->GetVariable(outlinks[i].internal)
->GetMutable<Tensor>()
->dims();
std::vector<int> dims_vec = vectorize(step_dims);
dims_vec.insert(dims_vec.begin(), seq_len);
output->mutable_data<float>(make_ddim(dims_vec), platform::CPUPlace());
for (size_t j = 0; j < seq_len; j++) {
Tensor* step_output = step_scopes[j]
->GetVariable(outlinks[i].internal)
->GetMutable<Tensor>();
// TODO(luotao02) data type and platform::DeviceContext() should set
// correctly
(output->Slice<float>(j, j + 1))
.CopyFrom<float>(*step_output, platform::CPUPlace());
}
}
}
void LinkMemories(std::vector<std::shared_ptr<Scope>>& scopes,
const std::vector<rnn::MemoryAttr>& memories,
size_t step_id,
int offset) {
PADDLE_ENFORCE(step_id < scopes.size(),
"step [%d] is out of range of step scopes' size [%d]",
step_id,
scopes.size());
PADDLE_ENFORCE(static_cast<int>(step_id) + offset >= 0,
"offset [%d] must be large than -[%d]",
offset,
step_id);
PADDLE_ENFORCE(step_id + offset < scopes.size(),
"offset [%d] is out of range, it must be less than (%d - %d)",
offset,
scopes.size(),
step_id);
std::shared_ptr<Scope> scope = scopes[step_id];
std::shared_ptr<Scope> linked_scope = scopes[step_id + offset];
for (auto& attr : memories) {
auto mem = scope->CreateVariable(attr.pre_var)->GetMutable<Tensor>();
// maybe share variable is better?
auto linked_mem = linked_scope->GetVariable(attr.var)->GetMutable<Tensor>();
mem->ShareDataWith<float>(*linked_mem);
// TODO(qingqing) remove following code
// the memory of current step should be allocated in step net
auto m = scope->CreateVariable(attr.var)->GetMutable<Tensor>();
// for unit test, as addOp and mulOp are null currently, if not
// mutable_data, mem.data() in output will be error. We will
// remove this line after merge the correct addOp and mulOp.
m->mutable_data<float>(mem->dims(), platform::CPUPlace());
}
}
void InitArgument(const ArgumentName& name,
Argument* arg,
const OperatorBase& op) {
arg->step_net = op.Input(name.step_net);
arg->step_scopes = op.Output(name.step_scopes);
auto inlinks = op.Inputs(name.inlinks);
auto inlink_alias = op.GetAttr<std::vector<std::string>>(name.inlink_alias);
PADDLE_ENFORCE(inlinks.size() == inlink_alias.size(),
"the size of inlinks and inlink_alias don't match:%d,%d",
inlinks.size(),
inlink_alias.size());
for (size_t i = 0; i < inlinks.size(); ++i) {
rnn::Link link;
link.external = inlinks[i];
link.internal = inlink_alias[i];
(arg->inlinks).push_back(link);
}
auto outlinks = op.Outputs(name.outlinks);
auto outlink_alias = op.GetAttr<std::vector<std::string>>(name.outlink_alias);
PADDLE_ENFORCE(outlinks.size() == outlink_alias.size(),
"the size of outlinks and outlink_alias don't match:%d,%d",
outlinks.size(),
outlink_alias.size());
for (size_t i = 0; i < outlinks.size(); ++i) {
rnn::Link link;
link.external = outlinks[i];
link.internal = outlink_alias[i];
(arg->outlinks).push_back(link);
}
auto boot_memories = op.Inputs(name.boot_memories);
// attributes
auto memories = op.GetAttr<std::vector<std::string>>(name.memories);
auto pre_memories = op.GetAttr<std::vector<std::string>>(name.pre_memories);
PADDLE_ENFORCE(memories.size() == boot_memories.size(),
"the size of memories, boot_memories don't match:%d,%d",
memories.size(),
boot_memories.size());
PADDLE_ENFORCE(pre_memories.size() == boot_memories.size(),
"the size of pre_memories, boot_memories don't match:%d,%d",
pre_memories.size(),
boot_memories.size());
PADDLE_ENFORCE(memories.size() > 0, "more than 1 memories should be set");
for (size_t i = 0; i < memories.size(); ++i) {
rnn::MemoryAttr mem_attr;
mem_attr.var = memories[i];
mem_attr.pre_var = pre_memories[i];
mem_attr.boot_var = boot_memories[i];
(arg->memories).push_back(mem_attr);
}
}
} // namespace rnn
void RecurrentAlgorithm::InferShape(const std::shared_ptr<Scope>& scope) const {
seq_len_ = scope->GetVariable((arg_->inlinks[0]).external)
->GetMutable<Tensor>()
->dims()[0];
CreateScopes(scope);
auto step_scopes = GetStepScopes(scope);
// SegmentInputs is called in InferShape. The input must hold memory in
// SegmentInputs. But the other op only set dimension for the output in
// InferShape. That's a problem. Wether the RNN op needs InferShape or not?
// Wether the following functions (SegmentInputs, InitMemories, ...) need
// to rewrite for RNN op?
rnn::SegmentInputs(step_scopes, arg_->inlinks, seq_len_);
InitMemories(step_scopes[0]);
PADDLE_ENFORCE(scope->HasVariable(arg_->step_net),
"stepnet [%s] is not in scope.",
arg_->step_net);
Variable* net = scope->GetVariable(arg_->step_net);
PADDLE_ENFORCE(net != nullptr, "failed to get step net");
// If the InferShape is called in OperatorBase's run function,
// the rnn op only needs to do InferShape for the first time step
for (size_t i = 0; i < seq_len_; i++) {
if (i > 0) {
rnn::LinkMemories(step_scopes, arg_->memories, i, -1);
}
net->GetMutable<NetOp>()->InferShape(step_scopes[i]);
}
auto outlinks = arg_->outlinks;
for (size_t i = 0; i < outlinks.size(); i++) {
DDim step_dims = step_scopes[0]
->GetVariable(outlinks[i].internal)
->GetMutable<Tensor>()
->dims();
std::vector<int> dims_vec = vectorize(step_dims);
// now only support fixed length
dims_vec.insert(dims_vec.begin(), seq_len_);
Tensor* output =
step_scopes[0]->GetVariable(outlinks[i].external)->GetMutable<Tensor>();
output->Resize(make_ddim(dims_vec));
}
}
void RecurrentAlgorithm::Run(const std::shared_ptr<Scope>& scope,
const platform::DeviceContext& dev_ctx) const {
auto step_scopes = GetStepScopes(scope);
Variable* net = scope->GetVariable(arg_->step_net);
for (size_t step_id = 0; step_id < seq_len_; step_id++) {
// the link memory is done in InferShape
// maybe remove following code after testing
if (step_id > 0) {
rnn::LinkMemories(step_scopes, arg_->memories, step_id, -1);
}
net->GetMutable<NetOp>()->Run(step_scopes[step_id], dev_ctx);
}
rnn::ConcatOutputs(step_scopes, arg_->outlinks, seq_len_);
}
void RecurrentAlgorithm::CreateScopes(std::shared_ptr<Scope> scope) const {
// TODO(xxx) Only two scopes are needed for inference, this case will be
// supported later.
auto step_scopes = scope->GetVariable(arg_->step_scopes)
->GetMutable<std::vector<std::shared_ptr<Scope>>>();
if (seq_len_ > step_scopes->size()) {
for (size_t i = step_scopes->size(); i < seq_len_; ++i) {
std::shared_ptr<Scope> step_scope = std::make_shared<Scope>(scope);
// Now all variables in scope must be created outside of op.
auto net_op = scope->GetVariable(arg_->step_net)->GetMutable<NetOp>();
for (auto& input : net_op->inputs_) {
step_scope->CreateVariable(input);
}
for (auto& output : net_op->outputs_) {
step_scope->CreateVariable(output);
}
step_scopes->push_back(std::make_shared<Scope>(step_scope));
}
}
}
void RecurrentAlgorithm::InitMemories(std::shared_ptr<Scope> step_scope) const {
for (auto& attr : arg_->memories) {
Tensor* pre_mem =
step_scope->CreateVariable(attr.pre_var)->GetMutable<Tensor>();
PADDLE_ENFORCE(step_scope->HasVariable(attr.boot_var),
"memory [%s]'s boot variable [%s] not exists",
attr.var,
attr.boot_var);
Tensor* boot_mem =
step_scope->GetVariable(attr.boot_var)->GetMutable<Tensor>();
pre_mem->ShareDataWith<float>(*boot_mem);
// TODO(qingqing) remove following code
// the memory of current step should be allocated in step net
// here for unit test
auto cur_step_mem =
step_scope->CreateVariable(attr.var)->GetMutable<Tensor>();
cur_step_mem->mutable_data<float>(boot_mem->dims(), platform::CPUPlace());
}
}
const rnn::ArgumentName RecurrentOp::kArgName{"step_net",
"step_scopes",
"inlinks",
"outlinks",
"inlink_alias",
"outlink_alias",
"memories",
"pre_memories",
"boot_memories"};
const rnn::ArgumentName RecurrentGradientOp::kArgName{"step_net",
"step_scopes",
"outlink@grad",
"inlink@grad",
"inlink_alias",
"outlink_alias",
"memories",
"pre_memories",
"boot_memories@grad"};
void RecurrentOp::Init() {
OperatorBase::Init();
std::unique_ptr<rnn::Argument> arg(new rnn::Argument());
rnn::InitArgument(kArgName, arg.get(), *this);
alg_.Init(std::move(arg));
}
class RecurrentAlgorithmProtoAndCheckerMaker : public OpProtoAndCheckerMaker {
public:
RecurrentAlgorithmProtoAndCheckerMaker(OpProto* proto,
OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
const auto& name = RecurrentOp::kArgName;
// inputs and outputs stored in proto
AddInputs(name.inlinks,
"the input that need to be segmented for each step.");
AddInputs(name.boot_memories, "variables to initialize memories.");
AddInput(name.step_net, "network shared by all steps.");
AddOutputs(name.outlinks,
"the output that need to concated for all steps.");
AddOutput(name.step_scopes, "step scopes");
// Attributes stored in AttributeMap
AddAttr<std::vector<std::string>>(name.inlink_alias, "alias of inlinks");
AddAttr<std::vector<std::string>>(name.outlink_alias, "alias of outlinks");
AddAttr<std::vector<std::string>>(name.pre_memories,
"names of pre-memories");
AddAttr<std::vector<std::string>>(name.memories, "names of memories");
AddComment("This is a recurrent group operator.");
}
};
void RecurrentGradientAlgorithm::Run(
const std::shared_ptr<Scope>& scope,
const platform::DeviceContext& dev_ctx) const {
auto step_scopes = GetStepScopes(scope);
rnn::SegmentInputs(step_scopes, arg_->inlinks, seq_len_);
PADDLE_ENFORCE(scope->HasVariable(arg_->step_net),
"step net is not in scope.");
Variable* net = scope->GetVariable(arg_->step_net);
PADDLE_ENFORCE(net != nullptr, "failed to get step net");
for (int step_id = seq_len_ - 1; step_id >= 0; --step_id) {
if (static_cast<size_t>(step_id) != seq_len_ - 1) {
rnn::LinkMemories(step_scopes, arg_->memories, step_id, 1);
}
net->GetMutable<NetOp>()->Run(step_scopes[step_id], dev_ctx);
}
LinkBootMemoryGradients(step_scopes[0]);
rnn::ConcatOutputs(step_scopes, arg_->outlinks, seq_len_);
}
void RecurrentGradientAlgorithm::LinkBootMemoryGradients(
std::shared_ptr<Scope> step_scope) const {
for (auto& attr : arg_->memories) {
Tensor* mem_grad =
step_scope->CreateVariable(attr.var)->GetMutable<Tensor>();
PADDLE_ENFORCE(mem_grad != nullptr,
"boot_tensor should be retrieved before");
PADDLE_ENFORCE(step_scope->HasVariable(attr.boot_var),
"memory [%s]'s boot variable [%s] not exists",
attr.var,
attr.boot_var);
Tensor* boot_mem_grad =
step_scope->CreateVariable(attr.boot_var)->GetMutable<Tensor>();
boot_mem_grad->ShareDataWith<float>(*mem_grad);
}
}
void RecurrentGradientAlgorithm::InferShape(
const std::shared_ptr<Scope>& scope) const {
seq_len_ = scope->GetVariable((arg_->inlinks[0]).external)
->GetMutable<Tensor>()
->dims()[0];
auto step_scopes = GetStepScopes(scope);
rnn::SegmentInputs(step_scopes, arg_->inlinks, seq_len_);
PADDLE_ENFORCE(scope->HasVariable(arg_->step_net),
"step net is not in scope.");
Variable* net = scope->GetVariable(arg_->step_net);
PADDLE_ENFORCE(net != nullptr, "failed to get step net");
for (int step_id = seq_len_ - 1; step_id >= 0; --step_id) {
if (static_cast<size_t>(step_id) != seq_len_ - 1) {
rnn::LinkMemories(step_scopes, arg_->memories, step_id, 1);
}
net->GetMutable<NetOp>()->InferShape(step_scopes[step_id]);
}
auto outlinks = arg_->outlinks;
for (size_t i = 0; i < outlinks.size(); i++) {
DDim step_dims = step_scopes[0]
->GetVariable(outlinks[i].internal)
->GetMutable<Tensor>()
->dims();
std::vector<int> dims_vec = vectorize(step_dims);
// now only support fixed length
dims_vec.insert(dims_vec.begin(), seq_len_);
Tensor* output =
step_scopes[0]->GetVariable(outlinks[i].external)->GetMutable<Tensor>();
output->Resize(make_ddim(dims_vec));
}
LinkBootMemoryGradients(step_scopes[0]);
}
void RecurrentGradientOp::Init() {
OperatorBase::Init();
std::unique_ptr<rnn::Argument> arg(new rnn::Argument());
rnn::InitArgument(kArgName, arg.get(), *this);
alg_.Init(std::move(arg));
}
} // namespace operators
} // namespace paddle
REGISTER_OP(recurrent_op,
paddle::operators::RecurrentOp,
paddle::operators::RecurrentAlgorithmProtoAndCheckerMaker);
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/framework/operator.h"
namespace paddle {
namespace operators {
using namespace paddle::framework;
namespace rnn {
/**
* Memory of a RNN (same as the role of `Momory` in PaddlePaddle).
*
* Memory attributes cached by this op, dims will be infered from
* boot memories in father scope. Other attributes are copied from Op's proto
* attributes.
*/
struct MemoryAttr {
// name of current state variable
std::string var;
// name of previous step's state variable
std::string pre_var;
// name of the variables to init this memory (same role of `boot_layer` in
// PaddlePaddle), which is store in father's scope.
std::string boot_var;
};
struct Link {
// input or output links name.
std::string internal;
// alias to avoid duplicate keys in scopes.
std::string external;
};
struct Argument {
std::string step_net;
std::string step_scopes;
std::vector<Link> inlinks;
std::vector<Link> outlinks;
std::vector<rnn::MemoryAttr> memories;
};
struct ArgumentName {
std::string step_net;
std::string step_scopes;
std::string inlinks;
std::string outlinks;
std::string inlink_alias; // the alias of inlinks in step net.
std::string outlink_alias; // the alias of outlinks in step net.
std::string memories; // the memory name
std::string pre_memories; // the previous memory name
std::string boot_memories; // the boot memory name
};
/**
* Prepare inputs for each step net.
*/
void SegmentInputs(std::vector<std::shared_ptr<Scope>>& step_scopes,
const std::vector<Link>& inlinks,
const size_t seq_len);
/**
* Process outputs of step nets and merge to variables.
*/
void ConcatOutputs(std::vector<std::shared_ptr<Scope>>& step_scopes,
const std::vector<Link>& outlinks,
const size_t seq_len);
void LinkMemories(std::vector<std::shared_ptr<Scope>>& step_scopes,
const std::vector<MemoryAttr>& memories,
size_t step_id,
int offset);
void InitArgument(const ArgumentName& name, Argument* arg);
}; // namespace rnn
// The sequence format in RecurrentOp is Tensor<seq_len, batch_size, dim> now.
// TODO:
// 1. No-padding computing for sequences with indifinite length in one batch.
// 2. Hierarchical RNN for sequence with sub-sequence.
// 3. Internal Memory.
// 4. More Complex RNN architecture, such as Gated Feedback RNN.
// Refer to: https://arxiv.org/pdf/1502.02367.pdf
class RecurrentAlgorithm {
public:
void Run(const std::shared_ptr<Scope>& scope,
const platform::DeviceContext& dev_ctx) const;
void Init(std::unique_ptr<rnn::Argument> arg) { arg_ = std::move(arg); }
/**
* InferShape must be called before Run.
*/
void InferShape(const std::shared_ptr<Scope>& scope) const;
protected:
/*
* The step scopes will be stored in the father scope as a variable.
*
* NOTE the scopes are reused in both the forward and backward, so just
* create once and expand its size if more steps need.
*/
void CreateScopes(std::shared_ptr<Scope> scope) const;
inline const std::vector<std::shared_ptr<Scope>>& GetStepScopes(
std::shared_ptr<Scope> scope) const {
return *(scope->GetVariable(arg_->step_scopes))
->GetMutable<std::vector<std::shared_ptr<Scope>>>();
}
void InitMemories(std::shared_ptr<Scope> step_scopes) const;
private:
std::unique_ptr<rnn::Argument> arg_;
mutable size_t seq_len_;
};
class RecurrentGradientAlgorithm {
/**
* RNN's backward alogorithm.
*
* To accelerate the development of RecurrentGradientOp, we decouple RNN's
* algorithm and `OperatorBase`'s implementation, the former contains the core
* implementation of a RNN, and will keep stable even if the framework changes
* a
* lot, and the latter is a wrapper acts like an dapter for it to make RNN an
* operator.
*/
public:
void Init(std::unique_ptr<rnn::Argument> arg) { arg_ = std::move(arg); }
void Run(const std::shared_ptr<Scope>& scope,
const platform::DeviceContext& dev_ctx) const;
void LinkBootMemoryGradients(std::shared_ptr<Scope> step_scopes) const;
/**
* InferShape must be called before Run.
*/
void InferShape(const std::shared_ptr<Scope>& scope) const;
protected:
inline const std::vector<std::shared_ptr<Scope>>& GetStepScopes(
std::shared_ptr<Scope> scope) const {
return *(scope->GetVariable(arg_->step_scopes))
->GetMutable<std::vector<std::shared_ptr<Scope>>>();
}
private:
std::unique_ptr<rnn::Argument> arg_;
mutable size_t seq_len_;
};
class RecurrentOp final : public OperatorBase {
public:
void Init() override;
/**
* InferShape must be called before Run.
*/
virtual void InferShape(const std::shared_ptr<Scope>& scope) const override {
alg_.InferShape(scope);
}
virtual void Run(const std::shared_ptr<Scope>& scope,
const platform::DeviceContext& dev_ctx) const override {
alg_.Run(scope, dev_ctx);
}
static const rnn::ArgumentName kArgName;
private:
RecurrentAlgorithm alg_;
};
class RecurrentGradientOp final : public OperatorBase {
public:
void Init() override;
/**
* InferShape must be called before Run.
*/
virtual void InferShape(const std::shared_ptr<Scope>& scope) const override {
alg_.InferShape(scope);
}
virtual void Run(const std::shared_ptr<Scope>& scope,
const platform::DeviceContext& dev_ctx) const override {
alg_.Run(scope, dev_ctx);
}
static const rnn::ArgumentName kArgName;
private:
RecurrentGradientAlgorithm alg_;
};
} // namespace operators
} // namespace paddle
/*
Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include <glog/logging.h>
#include <gtest/gtest.h>
#include "paddle/framework/net.h"
#include "paddle/framework/op_registry.h"
#include "paddle/framework/operator.h"
#include "paddle/framework/tensor.h"
#include "paddle/operators/recurrent_network_op.h"
namespace paddle {
namespace operators {
class RecurrentOpTest : public ::testing::Test {
protected:
virtual void SetUp() override {
CreateGlobalVariables();
CreateStepNet();
CreateRNNOp();
}
virtual void TearDown() override {}
void CreateGlobalVariables() {
scope_ = std::make_shared<Scope>();
// create input, and init content
LOG(INFO) << "create global variable x";
for (auto inlink : std::vector<std::string>{"x", "x0", "x1", "h"}) {
Variable* x = scope_->CreateVariable(inlink);
DDim dims = make_ddim(std::vector<int>{
10 /*sent size*/, 20 /*batch size*/, 30 /*input dim*/});
x->GetMutable<Tensor>()->mutable_data<float>(dims, platform::CPUPlace());
}
// create output alias just for test
for (auto inlink : std::vector<std::string>{"h@alias"}) {
Variable* x = scope_->CreateVariable(inlink);
DDim dims =
make_ddim(std::vector<int>{20 /*batch size*/, 30 /*input dim*/});
x->GetMutable<Tensor>()->mutable_data<float>(dims, platform::CPUPlace());
}
LOG(INFO) << "create global variable w";
Variable* w = scope_->CreateVariable("rnn/w");
w->GetMutable<Tensor>()->mutable_data<float>(
make_ddim(std::vector<int>{30, 30}), platform::CPUPlace());
for (auto boot : std::vector<std::string>{"x_boot", "h_boot"}) {
LOG(INFO) << "create global variable " << boot;
Variable* h_boot = scope_->CreateVariable(boot);
h_boot->GetMutable<Tensor>()->mutable_data<float>(
make_ddim(std::vector<int>{20 /*batch size*/, 30 /*input dim*/}),
platform::CPUPlace());
}
LOG(INFO) << "create variable step_scopes";
scope_->CreateVariable("step_scopes");
LOG(INFO) << "create variable h";
scope_->CreateVariable("h");
}
void CreateRNNOp() {
OpDesc op_desc;
op_desc.set_type("recurrent_op");
// inlinks 0
op_desc.add_inputs("x");
op_desc.add_inputs("x0");
op_desc.add_inputs("x1");
// boot_memories 3
op_desc.add_inputs("x_boot");
op_desc.add_inputs("h_boot");
// step net 5
op_desc.add_inputs("step_net");
// outlinks 6
op_desc.add_outputs("h");
// step scopes 7
op_desc.add_outputs("step_scopes");
auto _input_format = std::vector<int>{
0, // in_link
3, // memories
5 // step_net
};
auto input_format = op_desc.add_attrs();
input_format->set_name("input_format");
input_format->set_type(paddle::framework::AttrType::INTS);
for (auto i : _input_format) {
input_format->add_ints(i);
}
auto output_format = op_desc.add_attrs();
output_format->set_name("output_format");
output_format->set_type(paddle::framework::AttrType::INTS);
for (auto i : std::vector<int>{0, 1, 2}) {
output_format->add_ints(i);
}
auto inlink_alias = op_desc.add_attrs();
inlink_alias->set_name("inlink_alias");
inlink_alias->set_type(paddle::framework::AttrType::STRINGS);
auto outlink_alias = op_desc.add_attrs();
outlink_alias->set_name("outlink_alias");
outlink_alias->set_type(paddle::framework::AttrType::STRINGS);
auto pre_memories = op_desc.add_attrs();
pre_memories->set_name("pre_memories");
pre_memories->set_type(paddle::framework::AttrType::STRINGS);
auto memories = op_desc.add_attrs();
memories->set_name("memories");
memories->set_type(paddle::framework::AttrType::STRINGS);
// create inlink_alias
for (const auto& item :
std::vector<std::string>{"x@alias", "x0@alias", "x1@alias"}) {
inlink_alias->add_strings(item);
}
// pre memories
for (const auto& item :
std::vector<std::string>{"rnn/x@pre", "rnn/h@pre"}) {
pre_memories->add_strings(item);
}
// memories
for (const auto& item : std::vector<std::string>{"rnn/x", "rnn/h"}) {
memories->add_strings(item);
}
// output alias
for (const auto& item : std::vector<std::string>{"h@alias"}) {
outlink_alias->add_strings(item);
}
rnn_op_ = OpRegistry::CreateOp(op_desc);
LOG(INFO) << "rnn_op finish init";
}
void CreateStepNet() {
LOG(INFO) << "create variable step_net";
Variable* var = scope_->CreateVariable("step_net");
auto net = var->GetMutable<NetOp>();
// rnn/s is net's input or output?
net->inputs_ = {"rnn/h@pre", "rnn/w", "rnn/x"};
net->inputs_ = {"rnn/s", "rnn/h"};
net->AddOp(
OpRegistry::CreateOp("mul", {"rnn/h@pre", "rnn/w"}, {"rnn/s"}, {}));
net->AddOp(
OpRegistry::CreateOp("add_two", {"rnn/x", "rnn/s"}, {"rnn/h"}, {}));
net->CompleteAddOp();
}
// father scope
std::shared_ptr<Scope> scope_;
std::shared_ptr<OperatorBase> rnn_op_;
};
TEST_F(RecurrentOpTest, Run) {
platform::CPUDeviceContext ctx;
rnn_op_->InferShape(scope_);
rnn_op_->Run(scope_, ctx);
}
class RecurrentGradientAlgorithmTest : public ::testing::Test {
protected:
virtual void SetUp() override {
CreateGlobalVariables();
CreateStepScopes();
CreateStepNet();
CreateRNNGradientAlgorithm();
// segment inputs
SegmentInputs();
// link forward memories
LinkeMemories();
}
virtual void TearDown() override {}
void CreateGlobalVariables() {
scope_ = std::make_shared<Scope>();
// inputs: x
LOG(INFO) << "create global variable x";
Variable* x = scope_->CreateVariable("x");
DDim dims =
make_ddim({10 /*sent size*/, 20 /*batch size*/, 30 /*input dim*/});
x->GetMutable<Tensor>()->mutable_data<float>(dims, platform::CPUPlace());
// inputs: h_boot
LOG(INFO) << "create global variable h_boot";
Variable* h_boot = scope_->CreateVariable("h_boot");
h_boot->GetMutable<Tensor>()->mutable_data<float>(
make_ddim({20 /*batch size*/, 30 /*input dim*/}), platform::CPUPlace());
// inputs: w
LOG(INFO) << "create global variable w";
Variable* w = scope_->CreateVariable("rnn/w");
w->GetMutable<Tensor>()->mutable_data<float>(make_ddim({30, 30}),
platform::CPUPlace());
// inputs: h_grad
LOG(INFO) << "create variable h_grad";
Variable* dh = scope_->CreateVariable("h_grad");
dh->GetMutable<Tensor>()->mutable_data<float>(make_ddim({10, 20, 30}),
platform::CPUPlace());
// inputs: step_scopes
LOG(INFO) << "create variable step_scopes";
scope_->CreateVariable("step_scopes");
// inputs: step_net
LOG(INFO) << "create variable step_net";
scope_->CreateVariable("step_net");
// outputs: w_grad
LOG(INFO) << "create global variable w_grad";
scope_->CreateVariable("rnn/w_grad");
// outputs: x_grad
LOG(INFO) << "create global variable x_grad";
scope_->CreateVariable("x_grad");
// outputs: h_boot_grad
LOG(INFO) << "create global variable h_boot_grad";
scope_->CreateVariable("h_boot_grad");
}
void CreateStepScopes() {
std::vector<std::shared_ptr<Scope>>* step_scopes =
scope_->GetVariable("step_scopes")
->GetMutable<std::vector<std::shared_ptr<Scope>>>();
for (int i = 0; i < 10; ++i) {
auto scope = std::make_shared<Scope>(scope_);
auto pre_t = scope->CreateVariable("rnn/pre_h")->GetMutable<Tensor>();
pre_t->mutable_data<float>(make_ddim({20, 30}), platform::CPUPlace());
auto tensor = scope->CreateVariable("rnn/h")->GetMutable<Tensor>();
tensor->mutable_data<float>(make_ddim({20, 30}), platform::CPUPlace());
// for unit test of ConcatOutputs
auto xg = scope->CreateVariable("rnn/x_grad")->GetMutable<Tensor>();
xg->mutable_data<float>(make_ddim({20, 30}), platform::CPUPlace());
step_scopes->push_back(scope);
}
// last time step
auto g = (*step_scopes)[9]
->CreateVariable("rnn/h_pre_grad")
->GetMutable<Tensor>();
g->mutable_data<float>(make_ddim({20, 30}), platform::CPUPlace());
}
void CreateRNNGradientAlgorithm() {
std::unique_ptr<rnn::Argument> arg(new rnn::Argument());
arg->step_net = "step_net";
arg->step_scopes = "step_scopes";
rnn::Link inlink;
inlink.external = "h_grad";
inlink.internal = "rnn/h_grad";
arg->inlinks = std::vector<rnn::Link>{inlink};
rnn::Link outlink;
outlink.external = "x_grad";
outlink.internal = "rnn/x_grad";
arg->outlinks = std::vector<rnn::Link>{outlink};
rnn::MemoryAttr mem_attr;
mem_attr.pre_var = "rnn/h_pre_grad";
mem_attr.var = "rnn/h_grad";
mem_attr.boot_var = "h_boot_grad";
arg->memories = std::vector<rnn::MemoryAttr>{mem_attr};
rnn_grad_algo_.Init(std::move(arg));
}
void CreateStepNet() {
LOG(INFO) << "create variable step_net";
Variable* var = scope_->CreateVariable("step_net");
auto net = var->GetMutable<NetOp>();
net->AddOp(OpRegistry::CreateOp("mul",
{"rnn/h_pre", "rnn/w", "rnn/s_grad"},
{"rnn/h_pre_grad", "rnn/w_grad"},
{}));
net->AddOp(OpRegistry::CreateOp(
"add_two", {"rnn/h_grad"}, {"rnn/x_grad", "rnn/s_grad"}, {}));
net->CompleteAddOp();
}
void SegmentInputs() {
LOG(INFO) << "segment inputs";
std::vector<std::string> inlinks = {"x"};
std::vector<std::string> inlinks_alias = {"rnn/x"};
rnn::Link inlink;
inlink.external = "x";
inlink.internal = "rnn/x";
std::vector<std::shared_ptr<Scope>>* step_scopes =
scope_->GetVariable("step_scopes")
->GetMutable<std::vector<std::shared_ptr<Scope>>>();
rnn::SegmentInputs(*step_scopes, std::vector<rnn::Link>{inlink}, 10);
}
void LinkeMemories() {
LOG(INFO) << "link memories";
rnn::MemoryAttr mem_attr;
mem_attr.pre_var = "rnn/h_pre";
mem_attr.var = "rnn/h";
mem_attr.boot_var = "boot_h";
std::vector<rnn::MemoryAttr> memories;
memories.push_back(mem_attr);
std::vector<std::shared_ptr<Scope>>* step_scopes =
scope_->GetVariable("step_scopes")
->GetMutable<std::vector<std::shared_ptr<Scope>>>();
for (int i = 1; i < 10; ++i) {
rnn::LinkMemories(*step_scopes, memories, i, -1);
}
}
std::shared_ptr<Scope> scope_;
RecurrentGradientAlgorithm rnn_grad_algo_;
};
// TEST_F(RecurrentGradientAlgorithmTest, Run) {
// platform::CPUDeviceContext ctx;
// rnn_grad_algo_.Run(scope_, ctx);
// }
} // namespace operators
} // namespace paddle
TEST(RecurrentOp, LinkMemories) {
using namespace paddle::framework;
using namespace paddle::platform;
using namespace paddle::operators;
// create and init step scopes
int len = 10;
std::vector<std::shared_ptr<Scope>> step_scopes;
for (int i = 0; i < len; ++i) {
auto scope = std::make_shared<Scope>();
scope->CreateVariable("pre_h");
auto tensor = scope->CreateVariable("h")->GetMutable<Tensor>();
float* data = tensor->mutable_data<float>(make_ddim({15, 20}), CPUPlace());
for (int i = 0; i < 15 * 20; ++i) {
data[i] = rand() * (1. / (double)RAND_MAX);
}
step_scopes.push_back(scope);
}
// create MemoryAttr
rnn::MemoryAttr mem_attr;
mem_attr.pre_var = "pre_h";
mem_attr.var = "h";
mem_attr.boot_var = "boot_h";
std::vector<rnn::MemoryAttr> memories;
memories.push_back(mem_attr);
for (int i = 1; i < len; ++i) {
rnn::LinkMemories(step_scopes, memories, i, -1);
}
// check
for (int i = 0; i < len - 1; ++i) {
const float* a =
step_scopes[i]->GetVariable("h")->GetMutable<Tensor>()->data<float>();
const float* b = step_scopes[i + 1]
->GetVariable("pre_h")
->GetMutable<Tensor>()
->data<float>();
for (size_t i = 0; i < 15 * 20; ++i) {
ASSERT_FLOAT_EQ(a[i], b[i]);
}
}
for (int i = len - 2; i >= 0; --i) {
rnn::LinkMemories(step_scopes, memories, i, 1);
}
// check
for (int i = len - 2; i >= 0; --i) {
const float* a = step_scopes[i]
->GetVariable("pre_h")
->GetMutable<Tensor>()
->data<float>();
const float* b = step_scopes[i + 1]
->GetVariable("h")
->GetMutable<Tensor>()
->data<float>();
for (size_t i = 0; i < 15 * 20; ++i) {
ASSERT_FLOAT_EQ(a[i], b[i]);
}
}
}
USE_OP(add_two);
USE_OP(mul);
# RNN 变长输入设计
对变长序列的学习,现有主流框架比如 tensorflow, pytorch, caffe2, mxnet 等均使用了padding的方式,
即将一个mini-batch内不同长度的序列补0到固定长度参与计算。
现有Paddle包括 `RecurrentLayerGroup` 在内的RNN均实现了无padding的变长序列支持,本文也将基于该模块的思路,设计重构后的变长序列支持。
## 背景介绍
由于tensor必须有明确的shape,因此基于tensor 的主流框架在存储变长序列时,
必须用zero-padding的方式将变长序列补全为固定shape的tensor。
由于padding是一种框架实现变长序列的妥协, 从用户角度,在使用RNN类模型时自然会比较介意padding的存在,
因此会有pytorch中对非padding方式变长序列支持长篇的讨论[3]。
由于padding对内存和计算会有额外的消耗,tensorflow和mxnet均使用了bucketing来进行优化[1][2]
但不管是padding还是bucket,对于用户都是额外的使用负担。
因此,**paddle原生支持变长序列的方式,能直接满足用户对变长序列的最直接的需求,在当前主流平台中可以算是一大优势**
但对变长序列的支持,需要对目前框架做一些修改,下面讨论如何在最小修改下支持变长序列。
## 多层序列数据格式 `LODTensor`
目前 Paddle 会将一个mini-batch内的数据存储在一维的内存上,
额外使用 `Argument.sequenceStartPositions` 来存储每个句子的信息。
Paddle里使用 `Argument.subSequenceStartPositions` 来存储2层的序列信息,更高维度的序列则无法直接支持;
为了支持 `N-level` 序列的存储,本文将序列信息定义成如下数据结构:
```c++
std::shared_ptr<std::vector<std::vector<int>>> lod_start_pos_;
```
或者更明确的定义
```c++
typedef std::vector<int> level_t;
std::vector<level_t> lod_start_pos;
```
这里的每一个 `level_t` 存储一个粒度(level)的偏移信息,和paddle目前做法一致。
为了更透明地传递序列信息,我们引入了一种新的tensor 称为 `LODTensor`[4],
其关于tensor相关的接口都直接继承自 `Tensor`,但另外添加了序列相关接口。
如此,在操作一个 `LODTensor` 时,普通 `Op` 直接当成 `Tensor` 使用,
而操作序列的 `Op` 会额外操作 `LODTensor` 的变长序列操作的相关接口。
`LODTensor` 具体定义如下:
```c++
class LODTensor : public Tensor {
public:
size_t Levels() const { return seq_start_positions_.size(); }
size_t Elements(int level = 0) const {
return seq_start_positions_[level].size();
}
// slice of level[elem_begin: elem_end]
// NOTE low performance in slice seq_start_positions_.
// TODO should call Tensor's Slice.
LODTensor LODSlice(int level, int elem_begin, int elem_end) const;
// slice with tensor's data shared with this.
LODTensor LODSliceShared(int level, int elem_begin, int elem_end) const;
// copy other's lod_start_pos_, to share LOD info.
// NOTE the LOD info sould not be changed.
void ShareConstLODFrom(const LODTensor &other) {
lod_start_pos_ = other.lod_start_pos_;
}
// copy other's lod_start_pos_'s content, free to mutate.
void ShareMutableLODFrom(const LODTensor &other) {
lod_start_pos_ = std::make_shared <
std::vector<std::vector<int>>(other.lod_start_pos_.begin(),
other.lod_start_pos_.end());
}
private:
std::shared_ptr<std::vector<std::vector<int>>> lod_start_pos_;
};
```
其中, `lod_start_pos_` 使用了 `shared_ptr` 来减少存储和复制的代价,
可以认为 `LODTensor``Tensor` 的扩展,几乎完全兼容原始 `Tensor` 的使用。
## 框架支持
### 框架现有的 `Tensor` 调用替换为 `LODTensor`
为了实现 `LODTensor` 的传递,框架里很多 `Tensor` 都需要变成 `LODTensor`
简单实现,直接 **把之前所有的`Tensor` 全部替换成 `LODTensor`,这里可以直接修改 `pybind.cc` 里面创建`Tensor`的接口**
此外,用户有可能需要感知序列的存在(比如序列的可视化需要解析模型中输出的序列),因此一些序列操作的API也需要暴露到 python 层。
### `lod_start_pos` 随着Op调用链传递
框架需要支持下列特性,以实现`lod_start_pos`的传递:
1.`shared_ptr` 的方式实现传递
- 不修改 `lod_start_pos` 内容的作为 consumer
- 修改 `lod_start_pos` 的作为 producer
- 约定 consumer 只需要复制传递过来的 `shared_ptr`
- producer 需要创建自己的独立的内存,以存储自己独立的修改,并暴露 `shared_ptr` 给后续 consumer
- 由于传递过程是以复制`shared_ptr`的方式实现,因此框架只需要传递一次 `lod_start_pos`
2. 对于不感知 `lod_start_pos` 的Op足够透明
3. 需要修改 `lod_start_pos` 的producer Op可以在 `Run` 时更新自己的 `lod_start_pos` 数据
具体的设计分为以下3小节
#### `load_start_pos` 的传递
- 对于不需要修改 `lod_start_pos` 的情况,调用 LODTensor的 `ShareConstLODFrom` 接口实现复制
- 需要修改的,调用`ShareMutableLODFrom` 接口自己分配内存以存储修改
#### 框架透明
传递这一步需要加入到网络跑之前的初始化操作中,并且只需要初始化一次,基于当前框架设计的初步方案如下
- 在 Op 的 `attrs` 中添加一项 `do_mutate_lod_info` 的属性,默认为 `false`
- 有需要修改 `lod_start_pos` 的Op需要在定义 `OpProto` 时设置为 `true`
- `OperatorBase``InferShape` 中会读取 `do_mutate_lod_info` ,并且调用 `LODTensor` 相关的方法实现 `lod_start_pos` 的复制。
- `OperatorBase` 中添加一个 member `is_lod_inited{false}` 来保证传递只进行一次
一些逻辑如下
```c++
class OperatorBase {
public:
// ...
void InferShape() {
if (!is_load_inited) {
bool do_mutate_lod_info = GetAttr<bool>("do_mutate_load_info");
// find a input having LOD to copy
auto lod_input = ValidLODInput();
for (auto &output : outputs) {
if (do_mutate_load_info) {
output.ShareMutableLODFrom(lod_input);
} else {
output.ShareConstLODFrom(load_input);
}
}
is_pod_inited = true;
}
// call op's InferShape
// ...
}
private:
// ...
bool is_lod_inited{false};
};
```
如此,`lod_start_pos` 的信息的传递对非OLD的Op的实现是完全透明的。
#### `lod_start_pos` 的更新
上一小节介绍到,对于需要修改 `load_start_pos` 的Op,`OperatorBase` 会分配一块自己的内存以存储修改,
Op在 `Run` 的实现中,操作更新自己的 `load_start_pos`
而所有依赖其 outputs 的 op 会通过共享的指针自动获取到其更新。
## 根据长度排序
按照长度排序后,从前往后的时间步的batch size会自然地递减,可以直接塞入 Net 做batch计算
比如原始的输入:
```
origin:
xxxx
xx
xxx
-> sorted:
xxxx
xxx
xx
```
经过 `SegmentInputs` 之后,每个会有4个时间步,每个时间步的输入如下(纵向排列)
```
0 1 2 3
x x x x
x x x
x x
```
为了追踪排序前后序列的变化,这里用
```c++
struct SortedSeqItem {
void *start{nullptr};
void *end{nullptr};
};
std::vector<SortedSeqItem> sorted_seqs;
```
来追踪序列排序后的位置,并添加一个新的接口
```c++
std::vector<SortedSeqItem> SortBySeqLen(const LODTensor& tensor);
```
由于输入序列的顺序变化,以下现有的接口需要针对性地修改:
- InitMemories, memory需要根据 `sorted_seqs` 重新排列
- SetmentInputs
- ConcatOutputs
此外,由于 `sorted_seqs` 需要被 `RecurrentGradientOp` 复用,因此会变成 `RecurrentOp` 一个新的output输出,
之后作为 `RecurrentGradientOp` 的一个输入传入。
## InitMemories
由于序列顺序的变化,`boot_memories` 的batch上的element的顺序也需要对应重新排列。
## SegmentInputs
`SegmentInputs` 会依赖 `sorted_seqs` 的信息,将原始的序列按照排序后的序列顺序,从横向切割,转为每个step中的inputs。
即下面的转变:
```
origin:
xxxx
xx
xxx
|
|
\ /
!
0 1 2 3
x x x x
x x x
x x
```
## ConcatOutputs
`ConcatOutputs` 需要
- 将每个时间步的输出重新还原为原始输入的序列顺序(以防止Infer阶段顺序打乱)
- 将每个序列concat 为规则的mini-batch表示
## 参考文献
1. [Tensorflow Bucketing](https://www.tensorflow.org/versions/r0.12/api_docs/python/contrib.training/bucketing)
2. [mxnet Bucketing](http://mxnet.io/how_to/bucketing.html)
3. [variable length input in RNN scenario](https://discuss.pytorch.org/t/about-the-variable-length-input-in-rnn-scenario/345/5)
4. [Level of details](https://en.wikipedia.org/wiki/Level_of_detail)
...@@ -13,15 +13,13 @@ ...@@ -13,15 +13,13 @@
limitations under the License. */ limitations under the License. */
#include "paddle/operators/rowwise_add_op.h" #include "paddle/operators/rowwise_add_op.h"
#include "paddle/framework/op_registry.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
class RowWiseAddOp : public framework::OperatorWithKernel { class RowWiseAddOp : public OperatorWithKernel {
protected: protected:
void InferShape( void InferShape(const std::vector<const Tensor *> &inputs,
const std::vector<const framework::Tensor *> &inputs, const std::vector<Tensor *> &outputs) const override {
const std::vector<framework::Tensor *> &outputs) const override {
PADDLE_ENFORCE(inputs.size() == 2UL, "Two inputs is needed by rowwise add"); PADDLE_ENFORCE(inputs.size() == 2UL, "Two inputs is needed by rowwise add");
auto dim0 = inputs[0]->dims(); auto dim0 = inputs[0]->dims();
auto dim1 = inputs[1]->dims(); auto dim1 = inputs[1]->dims();
...@@ -34,11 +32,10 @@ protected: ...@@ -34,11 +32,10 @@ protected:
} }
}; };
class RowWiseAddOpMaker : public framework::OpProtoAndCheckerMaker { class RowWiseAddOpMaker : public OpProtoAndCheckerMaker {
public: public:
RowWiseAddOpMaker(framework::OpProto *proto, RowWiseAddOpMaker(OpProto *proto, OpAttrChecker *op_checker)
framework::OpAttrChecker *op_checker) : OpProtoAndCheckerMaker(proto, op_checker) {
: framework::OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "The left input of row-wise add op, must be matrix"); AddInput("X", "The left input of row-wise add op, must be matrix");
AddInput("b", "The right input of row-wise add op, must be vector"); AddInput("b", "The right input of row-wise add op, must be vector");
AddOutput("Out", "The output of row-wise add op"); AddOutput("Out", "The output of row-wise add op");
...@@ -53,9 +50,6 @@ for i in xrange(X.shape[0]): ...@@ -53,9 +50,6 @@ for i in xrange(X.shape[0]):
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
REGISTER_OP(rowwise_add, REGISTER_OP(rowwise_add, ops::RowWiseAddOp, ops::RowWiseAddOpMaker);
paddle::operators::RowWiseAddOp, REGISTER_OP_CPU_KERNEL(rowwise_add,
paddle::operators::RowWiseAddOpMaker); ops::RowWiseAddKernel<ops::CPUPlace, float>);
REGISTER_OP_CPU_KERNEL(
rowwise_add,
paddle::operators::RowWiseAddKernel<paddle::platform::CPUPlace, float>);
#include "paddle/framework/op_registry.h"
#include "paddle/operators/rowwise_add_op.h" #include "paddle/operators/rowwise_add_op.h"
REGISTER_OP_GPU_KERNEL( REGISTER_OP_GPU_KERNEL(rowwise_add,
rowwise_add, ops::RowWiseAddKernel<ops::GPUPlace, float>);
paddle::operators::RowWiseAddKernel<paddle::platform ::GPUPlace, float>);
...@@ -13,25 +13,23 @@ ...@@ -13,25 +13,23 @@
limitations under the License. */ limitations under the License. */
#pragma once #pragma once
#include "glog/logging.h" #include "paddle/operators/type_alias.h"
#include "paddle/framework/eigen.h"
#include "paddle/framework/operator.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
template <typename Place, typename T> template <typename Place, typename T>
class RowWiseAddKernel : public framework::OpKernel { class RowWiseAddKernel : public OpKernel {
public: public:
void Compute(const framework::KernelContext& context) const override { void Compute(const KernelContext& context) const override {
auto in0 = context.Input(0)->Get<framework::Tensor>(); auto in0 = context.Input(0)->Get<Tensor>();
auto in1 = context.Input(1)->Get<framework::Tensor>(); auto in1 = context.Input(1)->Get<Tensor>();
auto* out = context.Output(0)->GetMutable<framework::Tensor>(); auto* out = context.Output(0)->GetMutable<Tensor>();
out->mutable_data<T>(context.GetPlace()); out->mutable_data<T>(context.GetPlace());
auto input = framework::EigenMatrix<T>::From(in0); auto input = EigenMatrix<T>::From(in0);
auto bias = framework::EigenVector<T>::From(in1); auto bias = EigenVector<T>::From(in1);
auto output = framework::EigenMatrix<T>::From(*out); auto output = EigenMatrix<T>::From(*out);
const int bias_size = bias.dimension(0); const int bias_size = bias.dimension(0);
const int rest_size = input.size() / bias_size; const int rest_size = input.size() / bias_size;
......
...@@ -13,17 +13,14 @@ See the License for the specific language governing permissions and ...@@ -13,17 +13,14 @@ See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "paddle/operators/sgd_op.h" #include "paddle/operators/sgd_op.h"
#include "paddle/framework/op_registry.h"
#include "paddle/framework/tensor.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
class SGDOp : public framework::OperatorWithKernel { class SGDOp : public OperatorWithKernel {
protected: protected:
void InferShape( void InferShape(const std::vector<const Tensor *> &inputs,
const std::vector<const framework::Tensor *> &inputs, const std::vector<Tensor *> &outputs) const override {
const std::vector<framework::Tensor *> &outputs) const override {
PADDLE_ENFORCE(inputs.size() == 2, "Input size of SGDOp must be two"); PADDLE_ENFORCE(inputs.size() == 2, "Input size of SGDOp must be two");
PADDLE_ENFORCE(outputs.size() == 1, "Output size of SGDOp must be one"); PADDLE_ENFORCE(outputs.size() == 1, "Output size of SGDOp must be one");
PADDLE_ENFORCE(inputs[0] != nullptr, "inputs[0] mast be set"); PADDLE_ENFORCE(inputs[0] != nullptr, "inputs[0] mast be set");
...@@ -35,10 +32,10 @@ protected: ...@@ -35,10 +32,10 @@ protected:
} }
}; };
class SGDOpMaker : public framework::OpProtoAndCheckerMaker { class SGDOpMaker : public OpProtoAndCheckerMaker {
public: public:
SGDOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker) SGDOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: framework::OpProtoAndCheckerMaker(proto, op_checker) { : OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("param", "input parameter"); AddInput("param", "input parameter");
AddInput("grad", "input gradient"); AddInput("grad", "input gradient");
AddOutput("param_out", "output parameter"); AddOutput("param_out", "output parameter");
...@@ -55,7 +52,5 @@ param_out = param - learning_rate * grad; ...@@ -55,7 +52,5 @@ param_out = param - learning_rate * grad;
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
REGISTER_OP(sgd, paddle::operators::SGDOp, paddle::operators::SGDOpMaker); REGISTER_OP(sgd, ops::SGDOp, ops::SGDOpMaker);
typedef paddle::operators::SGDOpKernel<::paddle::platform::CPUPlace, float> REGISTER_OP_CPU_KERNEL(sgd, ops::SGDOpKernel<ops::CPUPlace, float>);
SGDOpKernel_CPU_float;
REGISTER_OP_CPU_KERNEL(sgd, SGDOpKernel_CPU_float);
#include "paddle/operators/sgd_op.h" #include "paddle/operators/sgd_op.h"
#include "paddle/framework/op_registry.h"
typedef paddle::operators::SGDOpKernel<::paddle::platform::GPUPlace, float> SGDOpKernel_GPU_float; REGISTER_OP_GPU_KERNEL(sgd, ops::SGDOpKernel<ops::GPUPlace, float>);
REGISTER_OP_GPU_KERNEL(sgd, SGDOpKernel_GPU_float); \ No newline at end of file
\ No newline at end of file
...@@ -13,28 +13,24 @@ See the License for the specific language governing permissions and ...@@ -13,28 +13,24 @@ See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#pragma once #pragma once
#include "glog/logging.h" #include "paddle/operators/type_alias.h"
#include "paddle/framework/eigen.h"
#include "paddle/framework/operator.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
template <typename Place, typename T> template <typename Place, typename T>
class SGDOpKernel : public framework::OpKernel { class SGDOpKernel : public OpKernel {
public: public:
void Compute(const framework::KernelContext& ctx) const override { void Compute(const KernelContext& ctx) const override {
auto param = ctx.Input("param")->Get<framework::Tensor>(); auto param = ctx.Input("param")->Get<Tensor>();
auto grad = ctx.Input("grad")->Get<framework::Tensor>(); auto grad = ctx.Input("grad")->Get<Tensor>();
auto* param_out = ctx.Output(0)->GetMutable<framework::Tensor>(); auto* param_out = ctx.Output(0)->GetMutable<Tensor>();
float lr = ctx.op_.GetAttr<float>("learning_rate"); float lr = ctx.op_.GetAttr<float>("learning_rate");
param_out->mutable_data<T>(ctx.GetPlace()); param_out->mutable_data<T>(ctx.GetPlace());
framework::EigenVector<T>::Flatten(*param_out) EigenVector<T>::Flatten(*param_out).device(*(ctx.GetEigenDevice<Place>())) =
.device(*(ctx.GetEigenDevice<Place>())) = EigenVector<T>::Flatten(param) - lr * EigenVector<T>::Flatten(grad);
framework::EigenVector<T>::Flatten(param) -
lr * framework::EigenVector<T>::Flatten(grad);
} }
}; };
......
...@@ -13,37 +13,33 @@ ...@@ -13,37 +13,33 @@
limitations under the License. */ limitations under the License. */
#include "paddle/operators/sigmoid_op.h" #include "paddle/operators/sigmoid_op.h"
#include "paddle/framework/op_registry.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
class SigmoidOp : public framework::OperatorWithKernel { class SigmoidOp : public OperatorWithKernel {
protected: protected:
void InferShape( void InferShape(const std::vector<const Tensor *> &inputs,
const std::vector<const framework::Tensor *> &inputs, const std::vector<Tensor *> &outputs) const override {
const std::vector<framework::Tensor *> &outputs) const override {
PADDLE_ENFORCE(inputs.size() == 1, "Sigmoid Op only have one input"); PADDLE_ENFORCE(inputs.size() == 1, "Sigmoid Op only have one input");
PADDLE_ENFORCE(outputs.size() == 1, "Sigmoid Op only have one output"); PADDLE_ENFORCE(outputs.size() == 1, "Sigmoid Op only have one output");
outputs[0]->Resize(inputs[0]->dims()); outputs[0]->Resize(inputs[0]->dims());
} }
}; };
class SigmoidOpMaker : public framework::OpProtoAndCheckerMaker { class SigmoidOpMaker : public OpProtoAndCheckerMaker {
public: public:
SigmoidOpMaker(framework::OpProto *proto, SigmoidOpMaker(OpProto *proto, OpAttrChecker *op_checker)
framework::OpAttrChecker *op_checker) : OpProtoAndCheckerMaker(proto, op_checker) {
: framework::OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "sigmoid input"); AddInput("X", "sigmoid input");
AddOutput("Y", "sigmoid output"); AddOutput("Y", "sigmoid output");
AddComment("Sigmoid function"); AddComment("Sigmoid function");
} }
}; };
class SigmoidOpGrad : public framework::OperatorWithKernel { class SigmoidOpGrad : public OperatorWithKernel {
protected: protected:
void InferShape( void InferShape(const std::vector<const Tensor *> &inputs,
const std::vector<const framework::Tensor *> &inputs, const std::vector<Tensor *> &outputs) const override {}
const std::vector<framework::Tensor *> &outputs) const override {}
std::string DebugString() const override { std::string DebugString() const override {
LOG(INFO) << "SigmoidGrad"; LOG(INFO) << "SigmoidGrad";
return ""; return "";
...@@ -53,11 +49,7 @@ protected: ...@@ -53,11 +49,7 @@ protected:
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
REGISTER_OP(sigmoid, REGISTER_OP(sigmoid, ops::SigmoidOp, ops::SigmoidOpMaker);
paddle::operators::SigmoidOp, REGISTER_GRADIENT_OP(sigmoid, sigmoid_grad, ops::SigmoidOpGrad);
paddle::operators::SigmoidOpMaker);
REGISTER_GRADIENT_OP(sigmoid, sigmoid_grad, paddle::operators::SigmoidOpGrad);
REGISTER_OP_CPU_KERNEL( REGISTER_OP_CPU_KERNEL(sigmoid, ops::SigmoidKernel<ops::CPUPlace, float>);
sigmoid,
paddle::operators::SigmoidKernel<paddle::platform::CPUPlace, float>);
#include "paddle/operators/sigmoid_op.h" #include "paddle/operators/sigmoid_op.h"
#include "paddle/framework/op_registry.h"
REGISTER_OP_GPU_KERNEL( REGISTER_OP_GPU_KERNEL(sigmoid, ops::SigmoidKernel<ops::GPUPlace, float>);
sigmoid, paddle::operators::SigmoidKernel<paddle::platform::GPUPlace, float>);
...@@ -14,25 +14,23 @@ ...@@ -14,25 +14,23 @@
#pragma once #pragma once
#include "glog/logging.h" #include "paddle/operators/type_alias.h"
#include "paddle/framework/eigen.h"
#include "paddle/framework/operator.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
template <typename Place, typename T> template <typename Place, typename T>
class SigmoidKernel : public framework::OpKernel { class SigmoidKernel : public OpKernel {
public: public:
void Compute(const framework::KernelContext& context) const override { void Compute(const KernelContext& context) const override {
auto input = context.Input(0)->Get<framework::Tensor>(); auto input = context.Input(0)->Get<Tensor>();
auto* output = context.Output(0)->GetMutable<framework::Tensor>(); auto* output = context.Output(0)->GetMutable<Tensor>();
output->mutable_data<T>(context.GetPlace()); output->mutable_data<T>(context.GetPlace());
framework::EigenVector<T>::Flatten(*output).device( EigenVector<T>::Flatten(*output).device(
*(context.GetEigenDevice<Place>())) = *(context.GetEigenDevice<Place>())) =
1.0 / (1.0 + (-1.0 * framework::EigenVector<T>::Flatten(input)).exp()); 1.0 / (1.0 + (-1.0 * EigenVector<T>::Flatten(input)).exp());
} }
}; };
} // namespace operators } // namespace operators
......
...@@ -12,16 +12,14 @@ ...@@ -12,16 +12,14 @@
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "paddle/operators/softmax_op.h" #include "paddle/operators/softmax_op.h"
#include "paddle/framework/op_registry.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
class SoftmaxOp : public framework::OperatorWithKernel { class SoftmaxOp : public OperatorWithKernel {
protected: protected:
void InferShape( void InferShape(const std::vector<const Tensor *> &inputs,
const std::vector<const framework::Tensor *> &inputs, const std::vector<Tensor *> &outputs) const override {
const std::vector<framework::Tensor *> &outputs) const override {
PADDLE_ENFORCE(inputs.size() == 1, "Only one input is need for softmax"); PADDLE_ENFORCE(inputs.size() == 1, "Only one input is need for softmax");
PADDLE_ENFORCE(inputs[0]->dims().size() == 2, PADDLE_ENFORCE(inputs[0]->dims().size() == 2,
"The input of softmax op must be matrix"); "The input of softmax op must be matrix");
...@@ -31,10 +29,9 @@ protected: ...@@ -31,10 +29,9 @@ protected:
} }
}; };
class SoftmaxOpMaker : public framework::OpProtoAndCheckerMaker { class SoftmaxOpMaker : public OpProtoAndCheckerMaker {
public: public:
SoftmaxOpMaker(framework::OpProto *proto, SoftmaxOpMaker(OpProto *proto, OpAttrChecker *op_checker)
framework::OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) { : OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "input of softmax"); AddInput("X", "input of softmax");
AddOutput("Y", "output of softmax"); AddOutput("Y", "output of softmax");
...@@ -42,11 +39,10 @@ public: ...@@ -42,11 +39,10 @@ public:
} }
}; };
class SoftmaxOpGrad : public framework::OperatorWithKernel { class SoftmaxOpGrad : public OperatorWithKernel {
protected: protected:
void InferShape( void InferShape(const std::vector<const Tensor *> &inputs,
const std::vector<const framework::Tensor *> &inputs, const std::vector<Tensor *> &outputs) const override {}
const std::vector<framework::Tensor *> &outputs) const override {}
std::string DebugString() const override { std::string DebugString() const override {
LOG(INFO) << "SoftmaxOpGrad"; LOG(INFO) << "SoftmaxOpGrad";
return ""; return "";
...@@ -56,9 +52,6 @@ protected: ...@@ -56,9 +52,6 @@ protected:
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP(softmax, ops::SoftmaxOp, ops::SoftmaxOpMaker); REGISTER_OP(softmax, ops::SoftmaxOp, ops::SoftmaxOpMaker);
REGISTER_GRADIENT_OP(softmax, softmax_grad, paddle::operators::SoftmaxOpGrad); REGISTER_GRADIENT_OP(softmax, softmax_grad, ops::SoftmaxOpGrad);
REGISTER_OP_CPU_KERNEL(softmax, REGISTER_OP_CPU_KERNEL(softmax, ops::SoftmaxKernel<ops::CPUPlace, float>);
ops::SoftmaxKernel<paddle::platform::CPUPlace, float>);
#include "paddle/framework/op_registry.h" #include "paddle/framework/op_registry.h"
#include "paddle/operators/softmax_op.h" #include "paddle/operators/softmax_op.h"
REGISTER_OP_GPU_KERNEL( REGISTER_OP_GPU_KERNEL(softmax, ops::SoftmaxKernel<ops::GPUPlace, float>);
softmax, paddle::operators::SoftmaxKernel<paddle::platform::GPUPlace, float>);
...@@ -14,23 +14,21 @@ ...@@ -14,23 +14,21 @@
#pragma once #pragma once
#include "glog/logging.h" #include "paddle/operators/type_alias.h"
#include "paddle/framework/eigen.h"
#include "paddle/framework/operator.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
template <typename Place, typename T> template <typename Place, typename T>
class SoftmaxKernel : public framework::OpKernel { class SoftmaxKernel : public OpKernel {
public: public:
void Compute(const framework::KernelContext& context) const override { void Compute(const KernelContext& context) const override {
auto input = context.Input(0)->Get<framework::Tensor>(); auto input = context.Input(0)->Get<Tensor>();
auto* output = context.Output(0)->GetMutable<framework::Tensor>(); auto* output = context.Output(0)->GetMutable<Tensor>();
output->mutable_data<T>(context.GetPlace()); output->mutable_data<T>(context.GetPlace());
auto logits = framework::EigenMatrix<T>::From(input); auto logits = EigenMatrix<T>::From(input);
auto softmax = framework::EigenMatrix<T>::From(*output); auto softmax = EigenMatrix<T>::From(*output);
const int kBatchDim = 0; const int kBatchDim = 0;
const int kClassDim = 1; const int kClassDim = 1;
......
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/framework/eigen.h"
#include "paddle/framework/net.h"
#include "paddle/framework/op_registry.h"
namespace paddle {
namespace operators {
using OpKernel = framework::OpKernel;
using KernelContext = framework::KernelContext;
template <typename T,
int MajorType = Eigen::RowMajor,
typename IndexType = Eigen::DenseIndex>
using EigenVector = framework::EigenVector<T, MajorType, IndexType>;
template <typename T,
int MajorType = Eigen::RowMajor,
typename IndexType = Eigen::DenseIndex>
using EigenMatrix = framework::EigenMatrix<T, MajorType, IndexType>;
template <typename T,
size_t D,
int MajorType = Eigen::RowMajor,
typename IndexType = Eigen::DenseIndex>
using EigenTensor = framework::EigenTensor<T, D, MajorType, IndexType>;
using Tensor = framework::Tensor;
using OperatorWithKernel = framework::OperatorWithKernel;
using OpProtoAndCheckerMaker = framework::OpProtoAndCheckerMaker;
using OpProto = framework::OpProto;
using OpAttrChecker = framework::OpAttrChecker;
using CPUPlace = platform::CPUPlace;
using GPUPlace = platform::GPUPlace;
using NetOp = framework::NetOp;
using OpRegistry = framework::OpRegistry;
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
...@@ -20,12 +20,101 @@ Eigen::DefaultDevice* DeviceContext::get_eigen_device<Eigen::DefaultDevice>() ...@@ -20,12 +20,101 @@ Eigen::DefaultDevice* DeviceContext::get_eigen_device<Eigen::DefaultDevice>()
return reinterpret_cast<const CPUDeviceContext*>(this)->eigen_device(); return reinterpret_cast<const CPUDeviceContext*>(this)->eigen_device();
} }
CPUDeviceContext::CPUDeviceContext() {
eigen_device_.reset(new Eigen::DefaultDevice());
}
CPUDeviceContext::CPUDeviceContext(CPUPlace place) {
eigen_device_.reset(new Eigen::DefaultDevice());
}
Eigen::DefaultDevice* CPUDeviceContext::eigen_device() const {
return eigen_device_.get();
}
Place CPUDeviceContext::GetPlace() const { return CPUPlace(); }
#ifndef PADDLE_ONLY_CPU #ifndef PADDLE_ONLY_CPU
template <> template <>
Eigen::GpuDevice* DeviceContext::get_eigen_device<Eigen::GpuDevice>() const { Eigen::GpuDevice* DeviceContext::get_eigen_device<Eigen::GpuDevice>() const {
return reinterpret_cast<const CUDADeviceContext*>(this)->eigen_device(); return reinterpret_cast<const CUDADeviceContext*>(this)->eigen_device();
} }
#endif
CUDADeviceContext::CUDADeviceContext(GPUPlace place) : place_(place) {
SetDeviceId(place_.device);
// TODO(qijun) Pass a created cuda stream to Eigen::CudaStreamDevice directly
// here will cause segment fault. We must implement a class derived from
// Eigen::StreamInterface, and reinitialize it with a cuda stream and a gpu id
// later. Please refer to the implementation of class EigenCudaStreamDevice
// in TensorFlow.
//
// We find that CUDA 7 introduces a new option, the per-thread default stream,
// that has two effects. Please refer to https://devblogs.nvidia.com/
// parallelforall/gpu-pro-tip-cuda-7-streams-simplify-concurrency/
//
// So, we decide to use default stream and add –default-stream per-thread nvcc
// flag. Than, two threads with two CUDADeviceContexts will run parallelly.
eigen_stream_.reset(new Eigen::CudaStreamDevice());
eigen_device_.reset(new Eigen::GpuDevice(eigen_stream_.get()));
}
CUDADeviceContext::~CUDADeviceContext() {
SetDeviceId(place_.device);
Wait();
if (cublas_handle_) {
PADDLE_ENFORCE(dynload::cublasDestroy(cublas_handle_));
}
if (cudnn_handle_) {
PADDLE_ENFORCE(dynload::cudnnDestroy(cudnn_handle_));
}
if (curand_generator_) {
PADDLE_ENFORCE(dynload::curandDestroyGenerator(curand_generator_));
}
eigen_stream_.reset();
eigen_device_.reset();
}
Place CUDADeviceContext::GetPlace() const { return place_; }
void CUDADeviceContext::Wait() const {
PADDLE_ENFORCE(cudaStreamSynchronize(0));
}
Eigen::GpuDevice* CUDADeviceContext::eigen_device() const {
return eigen_device_.get();
}
cublasHandle_t CUDADeviceContext::cublas_handle() {
if (!cublas_handle_) {
SetDeviceId(place_.device);
PADDLE_ENFORCE(dynload::cublasCreate(&cublas_handle_));
}
return cublas_handle_;
}
cudnnHandle_t CUDADeviceContext::cudnn_handle() {
if (!cudnn_handle_) {
SetDeviceId(place_.device);
PADDLE_ENFORCE(dynload::cudnnCreate(&cudnn_handle_));
}
return cudnn_handle_;
}
curandGenerator_t CUDADeviceContext::curand_generator() {
if (!curand_generator_) {
SetDeviceId(place_.device);
PADDLE_ENFORCE(dynload::curandCreateGenerator(&curand_generator_,
CURAND_RNG_PSEUDO_DEFAULT));
PADDLE_ENFORCE(
dynload::curandSetPseudoRandomGeneratorSeed(curand_generator_, seed_));
}
return curand_generator_;
}
#endif // PADDLE_ONLY_CPU
} // namespace platform } // namespace platform
} // namespace paddle } // namespace paddle
...@@ -39,14 +39,13 @@ class DeviceContext { ...@@ -39,14 +39,13 @@ class DeviceContext {
class CPUDeviceContext : public DeviceContext { class CPUDeviceContext : public DeviceContext {
public: public:
CPUDeviceContext() { eigen_device_.reset(new Eigen::DefaultDevice()); } CPUDeviceContext();
CPUDeviceContext(CPUPlace);
virtual ~CPUDeviceContext() {}
Eigen::DefaultDevice* eigen_device() const { return eigen_device_.get(); } Eigen::DefaultDevice* eigen_device() const;
Place GetPlace() const override { Place GetPlace() const override;
Place retv = CPUPlace();
return retv;
}
private: private:
std::unique_ptr<Eigen::DefaultDevice> eigen_device_; std::unique_ptr<Eigen::DefaultDevice> eigen_device_;
...@@ -54,119 +53,46 @@ class CPUDeviceContext : public DeviceContext { ...@@ -54,119 +53,46 @@ class CPUDeviceContext : public DeviceContext {
#ifndef PADDLE_ONLY_CPU #ifndef PADDLE_ONLY_CPU
class GPUPlaceGuard { class CUDADeviceContext : public DeviceContext {
public: public:
explicit GPUPlaceGuard(GPUPlace new_place) : previous_(GetCurrentDeviceId()) { explicit CUDADeviceContext(GPUPlace);
if (previous_ != new_place) { virtual ~CUDADeviceContext();
paddle::platform::SetDeviceId(new_place.device);
}
}
~GPUPlaceGuard() { paddle::platform::SetDeviceId(previous_.device); } /*! \brief Wait for all operations completion in the stream. */
void Wait() const;
private: /*! \brief Return place in the device context. */
GPUPlace previous_; Place GetPlace() const override;
};
class CUDADeviceContext : public DeviceContext { /*! \brief Return eigen device in the device context. */
public: Eigen::GpuDevice* eigen_device() const;
explicit CUDADeviceContext(const GPUPlace gpu_place) : gpu_place_(gpu_place) {
GPUPlaceGuard guard(gpu_place_); // clang-format off
PADDLE_ENFORCE(cudaStreamCreate(&stream_), "cudaStreamCreate failed"); /*! \brief Return cublas handle in the device context. */
eigen_stream_.reset(new Eigen::CudaStreamDevice(&stream_)); cublasHandle_t cublas_handle ();
eigen_device_.reset(new Eigen::GpuDevice(eigen_stream_.get()));
} /*! \brief Return cudnn handle in the device context. */
cudnnHandle_t cudnn_handle ();
Place GetPlace() const override {
Place retv = GPUPlace(); /*! \brief Return curand handle in the device context. */
return retv; curandGenerator_t curand_generator();
} // clang-format on
void Wait() {
PADDLE_ENFORCE(cudaStreamSynchronize(stream_),
"cudaStreamSynchronize failed");
}
cudaStream_t stream() { return stream_; }
Eigen::GpuDevice* eigen_device() const { return eigen_device_.get(); }
cublasHandle_t cublas_handle() {
if (!blas_handle_) {
GPUPlaceGuard guard(gpu_place_);
PADDLE_ENFORCE(paddle::platform::dynload::cublasCreate(&blas_handle_),
"cublasCreate failed");
PADDLE_ENFORCE(
paddle::platform::dynload::cublasSetStream(blas_handle_, stream_),
"cublasSetStream failed");
}
return blas_handle_;
}
cudnnHandle_t cudnn_handle() {
if (!dnn_handle_) {
GPUPlaceGuard guard(gpu_place_);
PADDLE_ENFORCE(paddle::platform::dynload::cudnnCreate(&dnn_handle_),
"cudnnCreate failed");
PADDLE_ENFORCE(
paddle::platform::dynload::cudnnSetStream(dnn_handle_, stream_),
"cudnnSetStream failed");
}
return dnn_handle_;
}
curandGenerator_t curand_generator() {
if (!rand_generator_) {
GPUPlaceGuard guard(gpu_place_);
PADDLE_ENFORCE(paddle::platform::dynload::curandCreateGenerator(
&rand_generator_, CURAND_RNG_PSEUDO_DEFAULT),
"curandCreateGenerator failed");
PADDLE_ENFORCE(
paddle::platform::dynload::curandSetPseudoRandomGeneratorSeed(
rand_generator_, random_seed_),
"curandSetPseudoRandomGeneratorSeed failed");
PADDLE_ENFORCE(
paddle::platform::dynload::curandSetStream(rand_generator_, stream_),
"curandSetStream failed");
}
return rand_generator_;
}
~CUDADeviceContext() {
Wait();
if (blas_handle_) {
PADDLE_ENFORCE(paddle::platform::dynload::cublasDestroy(blas_handle_),
"cublasDestroy failed");
}
if (dnn_handle_) {
PADDLE_ENFORCE(paddle::platform::dynload::cudnnDestroy(dnn_handle_),
"cudnnDestroy failed");
}
if (rand_generator_) {
PADDLE_ENFORCE(
paddle::platform::dynload::curandDestroyGenerator(rand_generator_),
"curandDestroyGenerator failed");
}
eigen_stream_.reset();
eigen_device_.reset();
PADDLE_ENFORCE(cudaStreamDestroy(stream_), "cudaStreamDestroy failed");
}
private: private:
GPUPlace gpu_place_; GPUPlace place_;
cudaStream_t stream_;
std::unique_ptr<Eigen::CudaStreamDevice> eigen_stream_; private:
std::unique_ptr<Eigen::GpuDevice> eigen_device_; std::unique_ptr<Eigen::GpuDevice> eigen_device_;
std::unique_ptr<Eigen::CudaStreamDevice> eigen_stream_;
cublasHandle_t blas_handle_{nullptr}; private:
uint64_t seed_;
cudnnHandle_t dnn_handle_{nullptr};
int random_seed_; // clang-format off
curandGenerator_t rand_generator_{nullptr}; cudnnHandle_t cudnn_handle_ = nullptr;
cublasHandle_t cublas_handle_ = nullptr;
curandGenerator_t curand_generator_ = nullptr;
// clang-format on
}; };
#endif #endif
......
...@@ -36,6 +36,21 @@ limitations under the License. */ ...@@ -36,6 +36,21 @@ limitations under the License. */
namespace paddle { namespace paddle {
namespace platform { namespace platform {
struct EnforceNotMet : public std::exception {
std::exception_ptr exp_;
std::string err_str_;
EnforceNotMet(std::exception_ptr e, const char* f, int l) : exp_(e) {
try {
std::rethrow_exception(exp_);
} catch (const std::exception& exp) {
err_str_ = string::Sprintf("%s at [%s:%d]", exp.what(), f, l);
}
}
const char* what() const noexcept { return err_str_.c_str(); }
};
// Because most enforce conditions would evaluate to true, we can use // Because most enforce conditions would evaluate to true, we can use
// __builtin_expect to instruct the C++ compiler to generate code that // __builtin_expect to instruct the C++ compiler to generate code that
// always forces branch prediction of true. // always forces branch prediction of true.
...@@ -43,18 +58,11 @@ namespace platform { ...@@ -43,18 +58,11 @@ namespace platform {
// For more details, please check https://stackoverflow.com/a/43870188/724872. // For more details, please check https://stackoverflow.com/a/43870188/724872.
#define UNLIKELY(condition) __builtin_expect(static_cast<bool>(condition), 0) #define UNLIKELY(condition) __builtin_expect(static_cast<bool>(condition), 0)
template <typename T>
inline void throw_on_error(T e) {
throw_on_error(e, "");
}
template <typename... Args> template <typename... Args>
inline typename std::enable_if<sizeof...(Args) != 0, void>::type throw_on_error( inline typename std::enable_if<sizeof...(Args) != 0, void>::type throw_on_error(
int stat, const Args&... args) { int stat, const Args&... args) {
if (UNLIKELY(!(stat))) { if (UNLIKELY(!(stat))) {
throw std::runtime_error( throw std::runtime_error(string::Sprintf(args...));
string::Sprintf(args...) +
string::Sprintf(" at [%s:%s];", __FILE__, __LINE__));
} }
} }
...@@ -64,12 +72,8 @@ template <typename... Args> ...@@ -64,12 +72,8 @@ template <typename... Args>
inline typename std::enable_if<sizeof...(Args) != 0, void>::type throw_on_error( inline typename std::enable_if<sizeof...(Args) != 0, void>::type throw_on_error(
cudaError_t e, const Args&... args) { cudaError_t e, const Args&... args) {
if (UNLIKELY(e)) { if (UNLIKELY(e)) {
// clang-format off throw thrust::system_error(e, thrust::cuda_category(),
throw thrust::system_error( string::Sprintf(args...));
e, thrust::cuda_category(),
string::Sprintf(args...) +
string::Sprintf(" at [%s:%s];", __FILE__, __LINE__));
// clang-format on
} }
} }
...@@ -77,12 +81,8 @@ template <typename... Args> ...@@ -77,12 +81,8 @@ template <typename... Args>
inline typename std::enable_if<sizeof...(Args) != 0, void>::type throw_on_error( inline typename std::enable_if<sizeof...(Args) != 0, void>::type throw_on_error(
curandStatus_t stat, const Args&... args) { curandStatus_t stat, const Args&... args) {
if (stat != CURAND_STATUS_SUCCESS) { if (stat != CURAND_STATUS_SUCCESS) {
// clang-format off throw thrust::system_error(cudaErrorLaunchFailure, thrust::cuda_category(),
throw thrust::system_error( string::Sprintf(args...));
cudaErrorLaunchFailure, thrust::cuda_category(),
string::Sprintf(args...) +
string::Sprintf(" at [%s:%s];", __FILE__, __LINE__));
// clang-format on
} }
} }
...@@ -92,12 +92,8 @@ inline typename std::enable_if<sizeof...(Args) != 0, void>::type throw_on_error( ...@@ -92,12 +92,8 @@ inline typename std::enable_if<sizeof...(Args) != 0, void>::type throw_on_error(
if (stat == CUDNN_STATUS_SUCCESS) { if (stat == CUDNN_STATUS_SUCCESS) {
return; return;
} else { } else {
// clang-format off throw std::runtime_error(platform::dynload::cudnnGetErrorString(stat) +
throw std::runtime_error( string::Sprintf(args...));
platform::dynload::cudnnGetErrorString(stat) +
string::Sprintf(args...) +
string::Sprintf(" at [%s:%s];", __FILE__, __LINE__));
// clang-format on
} }
} }
...@@ -126,22 +122,32 @@ inline typename std::enable_if<sizeof...(Args) != 0, void>::type throw_on_error( ...@@ -126,22 +122,32 @@ inline typename std::enable_if<sizeof...(Args) != 0, void>::type throw_on_error(
} else if (stat == CUBLAS_STATUS_LICENSE_ERROR) { } else if (stat == CUBLAS_STATUS_LICENSE_ERROR) {
err = "CUBLAS: license error, "; err = "CUBLAS: license error, ";
} }
throw std::runtime_error(err + string::Sprintf(args...) + throw std::runtime_error(err + string::Sprintf(args...));
string::Sprintf(" at [%s:%s];", __FILE__, __LINE__));
} }
#endif // PADDLE_ONLY_CPU #endif // PADDLE_ONLY_CPU
template <typename T>
inline void throw_on_error(T e) {
throw_on_error(e, "");
}
#define PADDLE_THROW(...) \ #define PADDLE_THROW(...) \
do { \ do { \
throw std::runtime_error( \ throw ::paddle::platform::EnforceNotMet( \
string::Sprintf(__VA_ARGS__) + \ std::make_exception_ptr( \
string::Sprintf(" at [%s:%s];", __FILE__, __LINE__)); \ std::runtime_error(string::Sprintf(__VA_ARGS__))), \
__FILE__, __LINE__); \
} while (0) } while (0)
#define PADDLE_ENFORCE(...) \ #define PADDLE_ENFORCE(...) \
do { \ do { \
try { \
::paddle::platform::throw_on_error(__VA_ARGS__); \ ::paddle::platform::throw_on_error(__VA_ARGS__); \
} catch (...) { \
throw ::paddle::platform::EnforceNotMet(std::current_exception(), \
__FILE__, __LINE__); \
} \
} while (0) } while (0)
} // namespace platform } // namespace platform
......
...@@ -23,7 +23,7 @@ TEST(ENFORCE, FAILED) { ...@@ -23,7 +23,7 @@ TEST(ENFORCE, FAILED) {
bool in_catch = false; bool in_catch = false;
try { try {
PADDLE_ENFORCE(false, "Enforce is not ok %d at all", 123); PADDLE_ENFORCE(false, "Enforce is not ok %d at all", 123);
} catch (const std::runtime_error& error) { } catch (paddle::platform::EnforceNotMet error) {
// your error handling code here // your error handling code here
in_catch = true; in_catch = true;
std::string msg = "Enforce is not ok 123 at all"; std::string msg = "Enforce is not ok 123 at all";
......
cc_library(paddle_pybind SHARED SRCS pybind.cc DEPS pybind python cc_library(paddle_pybind SHARED SRCS pybind.cc DEPS pybind python
add_op fc_op sgd_op cross_entropy_op) add_op fc_op sgd_op cross_entropy_op recurrent_network_op)
...@@ -38,6 +38,7 @@ USE_OP(mul); ...@@ -38,6 +38,7 @@ USE_OP(mul);
USE_OP(sigmoid); USE_OP(sigmoid);
USE_OP(softmax); USE_OP(softmax);
USE_OP(rowwise_add); USE_OP(rowwise_add);
USE_OP_WITHOUT_KERNEL(recurrent_op);
template <typename ClassType> template <typename ClassType>
void ExposeOperator(ClassType& m) { void ExposeOperator(ClassType& m) {
...@@ -50,6 +51,11 @@ void ExposeOperator(ClassType& m) { ...@@ -50,6 +51,11 @@ void ExposeOperator(ClassType& m) {
.def("__str__", &ClassType::type::DebugString); .def("__str__", &ClassType::type::DebugString);
} }
static size_t UniqueIntegerGenerator() {
static std::atomic<size_t> generator;
return generator.fetch_add(1);
}
PYBIND11_PLUGIN(core) { PYBIND11_PLUGIN(core) {
py::module m("core", "C++ core of PaddlePaddle"); py::module m("core", "C++ core of PaddlePaddle");
...@@ -103,6 +109,11 @@ All parameter, weight, gradient are variables in Paddle. ...@@ -103,6 +109,11 @@ All parameter, weight, gradient are variables in Paddle.
[](pd::Variable& self) -> pd::Tensor* { [](pd::Variable& self) -> pd::Tensor* {
return self.GetMutable<pd::Tensor>(); return self.GetMutable<pd::Tensor>();
}, },
py::return_value_policy::reference)
.def("get_net",
[](pd::Variable& self) -> pd::NetOp* {
return self.GetMutable<pd::NetOp>();
},
py::return_value_policy::reference); py::return_value_policy::reference);
py::class_<pd::Scope, std::shared_ptr<pd::Scope>>(m, "Scope") py::class_<pd::Scope, std::shared_ptr<pd::Scope>>(m, "Scope")
...@@ -112,7 +123,8 @@ All parameter, weight, gradient are variables in Paddle. ...@@ -112,7 +123,8 @@ All parameter, weight, gradient are variables in Paddle.
py::return_value_policy::reference) py::return_value_policy::reference)
.def("create_var", .def("create_var",
&pd::Scope::CreateVariable, &pd::Scope::CreateVariable,
py::return_value_policy::reference); py::return_value_policy::reference)
.def("get_var_name", &pd::Scope::GetVariableName);
//! @note: Be careful! PyBind will return std::string as an unicode, not //! @note: Be careful! PyBind will return std::string as an unicode, not
//! Python str. If you want a str object, you should cast them in Python. //! Python str. If you want a str object, you should cast them in Python.
...@@ -166,24 +178,25 @@ All parameter, weight, gradient are variables in Paddle. ...@@ -166,24 +178,25 @@ All parameter, weight, gradient are variables in Paddle.
}); });
ExposeOperator(operator_base); ExposeOperator(operator_base);
using PlainNetPtr = std::shared_ptr<pd::PlainNet>; py::class_<pd::NetOp, std::shared_ptr<pd::NetOp>> net(m, "Net");
py::class_<pd::PlainNet, PlainNetPtr> plain_net(m, "PlainNet");
plain_net net.def_static("create",
.def_static("create", []() -> std::shared_ptr<pd::NetOp> {
[]() -> std::shared_ptr<pd::PlainNet> { auto retv = std::make_shared<pd::NetOp>();
auto retv = std::make_shared<pd::PlainNet>();
retv->type_ = "plain_net"; retv->type_ = "plain_net";
return retv; return retv;
}) })
.def("add_op", &pd::PlainNet::AddOp) .def("add_op", &pd::NetOp::AddOp)
.def("add_op", .def("add_op",
[](PlainNetPtr& self, const PlainNetPtr& plain_net) -> void { [](pd::NetOp& self, const std::shared_ptr<pd::NetOp>& net) -> void {
self->AddOp(std::static_pointer_cast<pd::OperatorBase>(plain_net)); self.AddOp(std::static_pointer_cast<pd::OperatorBase>(net));
}) })
.def("complete_add_op", &pd::PlainNet::CompleteAddOp) .def("complete_add_op", &pd::NetOp::CompleteAddOp)
.def("complete_add_op", [](PlainNetPtr& self) { self->CompleteAddOp(); }); .def("complete_add_op",
ExposeOperator(plain_net); [](std::shared_ptr<pd::NetOp>& self) { self->CompleteAddOp(); });
ExposeOperator(net);
m.def("unique_integer", UniqueIntegerGenerator);
return m.ptr(); return m.ptr();
} }
...@@ -76,7 +76,11 @@ void NewRemoteParameterUpdater::init( ...@@ -76,7 +76,11 @@ void NewRemoteParameterUpdater::init(
sgdConfigV2->set_decay(paramConfig.decay_rate()); sgdConfigV2->set_decay(paramConfig.decay_rate());
optimizeConfigV2.set_lr_policy(paddle::OptimizerConfig::Const); optimizeConfigV2.set_lr_policy(paddle::OptimizerConfig::Const);
auto constlr = optimizeConfigV2.mutable_const_lr(); auto constlr = optimizeConfigV2.mutable_const_lr();
if (paramConfig.has_learning_rate()) {
constlr->set_learning_rate(paramConfig.learning_rate()); constlr->set_learning_rate(paramConfig.learning_rate());
} else {
constlr->set_learning_rate(trainerConfig_.learning_rate());
}
if (trainerConfig_.algorithm() == "sgd") { if (trainerConfig_.algorithm() == "sgd") {
optimizeConfigV2.set_optimizer(paddle::OptimizerConfig::SGD); optimizeConfigV2.set_optimizer(paddle::OptimizerConfig::SGD);
// FIXME: config all algorithms // FIXME: config all algorithms
......
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册