diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 980a97a07c996eca2e8c126a6ad5ab7f340fa1e5..2ca988c406ae2987e26ca37dbc17cc0a2af43743 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -17,10 +17,14 @@ - id: detect-private-key files: (?!.*third_party)^.*$ | (?!.*book)^.*$ - id: end-of-file-fixer -- repo: https://github.com/PaddlePaddle/clang-format-pre-commit-hook.git - sha: 28c0ea8a67a3e2dbbf4822ef44e85b63a0080a29 +- repo: local hooks: - - id: clang-formater + - id: clang-format + name: clang-format + description: Format files with ClangFormat. + entry: clang-format -i + language: system + files: \.(c|cc|cxx|cpp|h|hpp|hxx)$ - repo: https://github.com/PaddlePaddle/pre-commit-golang sha: 8337620115c25ff8333f1b1a493bd031049bd7c0 hooks: diff --git a/cmake/external/eigen.cmake b/cmake/external/eigen.cmake index 3e6cedbb0d718cfd4454f95dedf7e02a24f2981b..f7483f6be9169eb58f0148cd3a956a8c881e1fe3 100644 --- a/cmake/external/eigen.cmake +++ b/cmake/external/eigen.cmake @@ -7,17 +7,8 @@ INCLUDE_DIRECTORIES(${EIGEN_SOURCE_DIR}/src/extern_eigen3) ExternalProject_Add( extern_eigen3 ${EXTERNAL_PROJECT_LOG_ARGS} - # for latest version, please get from official website - # URL "https://bitbucket.org/eigen/eigen/get/3.3.4.tar.gz" - # URL_MD5 "1a47e78efe365a97de0c022d127607c3" - - # for no-ssl http support, please get from bazel's mirror - # URL "http://mirror.bazel.build/bitbucket.org/eigen/eigen/get/f3a22f35b044.tar.gz" - # URL_MD5 "4645c66075982da6fa0bcf6b20f3e8f7" - - # get from github mirror GIT_REPOSITORY "https://github.com/RLovelett/eigen.git" - GIT_TAG "a46d2e7337c4656f00abe54a8115f6d76153a048" + GIT_TAG "master" PREFIX ${EIGEN_SOURCE_DIR} UPDATE_COMMAND "" CONFIGURE_COMMAND "" diff --git a/cmake/flags.cmake b/cmake/flags.cmake index 34fd348893058980964d723490d9cc220a157b5a..ef31c252038ce18655913c0f41343fe6dc7dbb86 100644 --- a/cmake/flags.cmake +++ b/cmake/flags.cmake @@ -153,7 +153,7 @@ set(CUDA_PROPAGATE_HOST_FLAGS OFF) # 
Release/Debug flags set by cmake. Such as -O3 -g -DNDEBUG etc. # So, don't set these flags here. -LIST(APPEND CUDA_NVCC_FLAGS -std=c++11) +LIST(APPEND CUDA_NVCC_FLAGS -std=c++11 --default-stream per-thread) LIST(APPEND CUDA_NVCC_FLAGS --use_fast_math) if(CMAKE_BUILD_TYPE STREQUAL "Debug") diff --git a/doc/api/v2/config/layer.rst b/doc/api/v2/config/layer.rst index 9a317d416c375b16c345a7ba38cba7b552fc3cab..45614f33e2a57c0a3a1486dffcb8e38c1c74b669 100644 --- a/doc/api/v2/config/layer.rst +++ b/doc/api/v2/config/layer.rst @@ -203,6 +203,10 @@ identity_projection .. autoclass:: paddle.v2.layer.identity_projection :noindex: +slice_projection +------------------- +.. autoclass:: paddle.v2.layer.slice_projection + :noindex: table_projection ---------------- diff --git a/doc/design/scope.md b/doc/design/scope.md index afe6bc028cafc5ee24b0041905857af58d3f5790..c9e0be716b606f6c7bf0373e0c6e632647e07a6f 100644 --- a/doc/design/scope.md +++ b/doc/design/scope.md @@ -37,8 +37,8 @@ Scope is an association of a name to variable. All variables belong to `Scope`. ```cpp class Scope { public: - Variable* CreateVariable(const std::string& name); - const Variable* GetVariable(const std::string& name) const; + Variable* NewVar(const std::string& name); + const Variable* FindVar(const std::string& name) const; private: std::unordered_map> vars_; @@ -58,12 +58,12 @@ class Scope { public: Scope(const std::shared_ptr& scope): parent_(scope) {} - Variable* GetVariable(const std::string& name) const { + Variable* FindVar(const std::string& name) const { auto it = vars_.find(name); if (it != vars_.end()) { return it->second.get(); } else if (parent_ != nullptr) { - return parent_->GetVariable(name); + return parent_->FindVar(name); } else { return nullptr; } @@ -95,10 +95,10 @@ class Scope { static std::shared_ptr Create(const std::shared_ptr& parent = nullptr); // return nullptr if not found. 
- Variable* GetVariable(const std::string& name) const; + Variable* FindVar(const std::string& name) const; // return if already contains same name variable. - Variable* CreateVariable(const std::string& name); + Variable* NewVar(const std::string& name); private: std::shared_ptr parent_; @@ -107,11 +107,11 @@ class Scope { ``` ## Only scope can create a variable -To ensure `only scope can create a variable`, we should mark `Variable`'s constructor as a private member function, and Scope is a friend class of Variable. And then only `CreateVariable` can construct `Variable`. +To ensure `only scope can create a variable`, we should mark `Variable`'s constructor as a private member function, and Scope is a friend class of Variable. And then only `NewVar` can construct `Variable`. ## When scope destroyed, all variables inside this scope should be destroyed together -The scope hold unique pointers for all variables. User can `GetVariable` from scope, but he should not hold this pointer as a member variable. Because when scope is destroyed, all variables inside this scope will be destroyed together. +The scope hold unique pointers for all variables. User can `FindVar` from scope, but he should not hold this pointer as a member variable. Because when scope is destroyed, all variables inside this scope will be destroyed together. ## Sharing a parent scope @@ -121,4 +121,4 @@ Also, as the parent scope is a `shared_ptr`, we can only `Create()` a scope shar ## Orthogonal interface -`GetVariable` will return `nullptr` when `name` is not found. It can be used as `Contains` method. `CreateVariable` will return a `Error` when there is a name conflict locally. Combine `GetVariable` and `CreateVariable`, we can implement `CreateOrGetVariable` easily. +`FindVar` will return `nullptr` when `name` is not found. It can be used as `Contains` method. `NewVar` will return a `Error` when there is a name conflict locally. Combine `FindVar` and `NewVar`, we can implement `NewVar` easily. 
diff --git a/go/cmd/master/master.go b/go/cmd/master/master.go index 287da694915ca383dc29e6d33201dc701cb7de87..739c4c01e02b10f46c36b997f8c4700150da2a26 100644 --- a/go/cmd/master/master.go +++ b/go/cmd/master/master.go @@ -19,6 +19,8 @@ import ( "net" "net/http" "net/rpc" + "os" + "os/signal" "strconv" "strings" "time" @@ -68,6 +70,20 @@ func main() { store = &master.InMemStore{} } + shutdown := func() { + log.Infoln("shutting down gracefully") + err := store.Shutdown() + if err != nil { + log.Errorln(err) + } + } + + // Guaranteed to run even panic happens. + defer shutdown() + + c := make(chan os.Signal, 1) + signal.Notify(c, os.Interrupt) + s, err := master.NewService(store, *chunkPerTask, *taskTimeoutDur, *taskTimeoutMax) if err != nil { log.Fatal(err) @@ -84,8 +100,12 @@ func main() { log.Fatal(err) } - err = http.Serve(l, nil) - if err != nil { - log.Fatal(err) - } + go func() { + err = http.Serve(l, nil) + if err != nil { + log.Fatal(err) + } + }() + + <-c } diff --git a/go/cmd/pserver/pserver.go b/go/cmd/pserver/pserver.go index aa81d0432b1d4f411644e0a5b703d7ea74d144b7..f9cd8f87e8f2e715c87834ee08482be0f511f681 100644 --- a/go/cmd/pserver/pserver.go +++ b/go/cmd/pserver/pserver.go @@ -18,6 +18,8 @@ import ( "net" "net/http" "net/rpc" + "os" + "os/signal" "strconv" "time" @@ -33,7 +35,8 @@ func main() { index := flag.Int("index", -1, "index of this pserver, should be larger or equal than 0") etcdEndpoint := flag.String("etcd-endpoint", "http://127.0.0.1:2379", "comma separated endpoint string for pserver to connect to etcd") - etcdTimeout := flag.Duration("etcd-timeout", 5*time.Second, "timeout for etcd calls") + dialTimeout := flag.Duration("dial-timeout", 5*time.Second, "dial timeout") + etcdTTL := flag.Int("etcd-ttl", 5, "etcd time to live in seconds") numPservers := flag.Int("num-pservers", 1, "total pserver count in a training job") checkpointPath := flag.String("checkpoint-path", "/checkpoints/", "save checkpoint path") checkpointInterval := 
flag.Duration("checkpoint-interval", 600*time.Second, "save checkpoint per interval seconds") @@ -53,7 +56,7 @@ func main() { if *index >= 0 { idx = *index } else { - e = pserver.NewEtcdClient(*etcdEndpoint, *numPservers, *etcdTimeout) + e = pserver.NewEtcdClient(*etcdEndpoint, *numPservers, *dialTimeout, *etcdTTL) idx, err = e.Register(*port) candy.Must(err) @@ -67,6 +70,20 @@ func main() { } } + shutdown := func() { + log.Infoln("shutting down gracefully") + sErr := e.Shutdown() + if sErr != nil { + log.Errorln(sErr) + } + } + + // Guaranteed to run even panic happens. + defer shutdown() + + c := make(chan os.Signal, 1) + signal.Notify(c, os.Interrupt) + s, err := pserver.NewService(idx, *checkpointInterval, *checkpointPath, e, cp) candy.Must(err) @@ -77,7 +94,11 @@ func main() { l, err := net.Listen("tcp", ":"+strconv.Itoa(*port)) candy.Must(err) - log.Infof("start pserver at port %d", *port) - err = http.Serve(l, nil) - candy.Must(err) + go func() { + log.Infof("start pserver at port %d", *port) + err = http.Serve(l, nil) + candy.Must(err) + }() + + <-c } diff --git a/go/glide.lock b/go/glide.lock index f71ae643d68d29846611ec52d0ae7d67e4ced850..1f16abdf66422abcd0ab7987cab3499d02cf1b9c 100644 --- a/go/glide.lock +++ b/go/glide.lock @@ -1,15 +1,105 @@ -hash: a8faea3a363468a88917ddeb3b1c9ea36886fb2c622acbad42604fa9cb4d3855 -updated: 2017-07-11T10:04:40.786745417+08:00 +hash: 2a1c0eca5c07a130e3d224f9821f96cfa37a39bf6bce141c855bbc57ef569f1c +updated: 2017-07-29T07:34:48.722757905+08:00 imports: +- name: github.com/beorn7/perks + version: 4c0e84591b9aa9e6dcfdf3e020114cd81f89d5f9 + subpackages: + - quantile +- name: github.com/boltdb/bolt + version: 583e8937c61f1af6513608ccc75c97b6abdf4ff9 +- name: github.com/cockroachdb/cmux + version: 112f0506e7743d64a6eb8fedbcff13d9979bbf92 - name: github.com/coreos/etcd - version: cb2a496c4ddd1c87a9f280e116649b599999ec79 + version: c31bec0f29facff13f7c3e3d948e55dd6689ed42 subpackages: + - alarm + - auth - auth/authpb + - client - 
clientv3 - clientv3/concurrency + - compactor + - discovery + - embed + - error + - etcdserver + - etcdserver/api + - etcdserver/api/v2http + - etcdserver/api/v2http/httptypes + - etcdserver/api/v3client + - etcdserver/api/v3election + - etcdserver/api/v3election/v3electionpb + - etcdserver/api/v3election/v3electionpb/gw + - etcdserver/api/v3lock + - etcdserver/api/v3lock/v3lockpb + - etcdserver/api/v3lock/v3lockpb/gw + - etcdserver/api/v3rpc - etcdserver/api/v3rpc/rpctypes + - etcdserver/auth - etcdserver/etcdserverpb + - etcdserver/etcdserverpb/gw + - etcdserver/membership + - etcdserver/stats + - lease + - lease/leasehttp + - lease/leasepb + - mvcc + - mvcc/backend - mvcc/mvccpb + - pkg/adt + - pkg/contention + - pkg/cors + - pkg/cpuutil + - pkg/crc + - pkg/debugutil + - pkg/fileutil + - pkg/httputil + - pkg/idutil + - pkg/ioutil + - pkg/logutil + - pkg/monotime + - pkg/netutil + - pkg/pathutil + - pkg/pbutil + - pkg/runtime + - pkg/schedule + - pkg/srv + - pkg/tlsutil + - pkg/transport + - pkg/types + - pkg/wait + - proxy/grpcproxy/adapter + - raft + - raft/raftpb + - rafthttp + - snap + - snap/snappb + - store + - version + - wal + - wal/walpb +- name: github.com/coreos/go-semver + version: 8ab6407b697782a06568d4b7f1db25550ec2e4c6 + subpackages: + - semver +- name: github.com/coreos/go-systemd + version: 48702e0da86bd25e76cfef347e2adeb434a0d0a6 + subpackages: + - daemon + - journal + - util +- name: github.com/coreos/pkg + version: 3ac0863d7acf3bc44daf49afef8919af12f704ef + subpackages: + - capnslog +- name: github.com/dgrijalva/jwt-go + version: d2709f9f1f31ebcda9651b03077758c1f3a0018c +- name: github.com/ghodss/yaml + version: 0ca9ea5df5451ffdf184b4428c902747c2c11cd7 +- name: github.com/gogo/protobuf + version: 909568be09de550ed094403c2bf8a261b5bb730a + subpackages: + - proto - name: github.com/golang/protobuf version: 4bd1920723d7b7c925de087aa32e2187708897f7 subpackages: @@ -17,14 +107,61 @@ imports: - proto - name: github.com/golang/snappy version: 
553a641470496b2327abcac10b36396bd98e45c9 +- name: github.com/google/btree + version: 925471ac9e2131377a91e1595defec898166fe49 +- name: github.com/grpc-ecosystem/go-grpc-prometheus + version: 6b7015e65d366bf3f19b2b2a000a831940f0f7e0 +- name: github.com/grpc-ecosystem/grpc-gateway + version: 18d159699f2e83fc5bb9ef2f79465ca3f3122676 + subpackages: + - runtime + - runtime/internal + - utilities +- name: github.com/jonboulle/clockwork + version: 2eee05ed794112d45db504eb05aa693efd2b8b09 +- name: github.com/matttproud/golang_protobuf_extensions + version: c12348ce28de40eed0136aa2b644d0ee0650e56c + subpackages: + - pbutil - name: github.com/namsral/flag version: 71ceffbeb0ba60fccc853971bb3ed4d7d90bfd04 - name: github.com/PaddlePaddle/recordio - version: edfb82af0739c84f241c87390ec5649c7b28c129 + version: 0432dee9fd4b24fb6840fb20a8c055b0c933fb81 +- name: github.com/prometheus/client_golang + version: c5b7fccd204277076155f10851dad72b76a49317 + subpackages: + - prometheus +- name: github.com/prometheus/client_model + version: 6f3806018612930941127f2a7c6c453ba2c527d2 + subpackages: + - go +- name: github.com/prometheus/common + version: 49fee292b27bfff7f354ee0f64e1bc4850462edf + subpackages: + - expfmt + - internal/bitbucket.org/ww/goautoneg + - model +- name: github.com/prometheus/procfs + version: a1dba9ce8baed984a2495b658c82687f8157b98f + subpackages: + - xfs - name: github.com/sirupsen/logrus - version: 7f976d3a76720c4c27af2ba716b85d2e0a7e38b1 + version: a3f95b5c423586578a4e099b11a46c2479628cac - name: github.com/topicai/candy version: 1b9030d056fa9f8c4b1f9c91b52fe4b8ab4cd8cc +- name: github.com/ugorji/go + version: ded73eae5db7e7a0ef6f55aace87a2873c5d2b74 + subpackages: + - codec +- name: github.com/xiang90/probing + version: 07dd2e8dfe18522e9c447ba95f2fe95262f63bb2 +- name: golang.org/x/crypto + version: 1351f936d976c60a0a48d728281922cf63eafb8d + repo: https://github.com/golang/crypto.git + vcs: git + subpackages: + - bcrypt + - blowfish - name: golang.org/x/net version: 
c8c74377599bd978aee1cf3b9b63a8634051cec2 subpackages: @@ -36,11 +173,15 @@ imports: - lex/httplex - trace - name: golang.org/x/sys - version: abf9c25f54453410d0c6668e519582a9e1115027 + version: 0f826bdd13b500be0f1d4004938ad978fcc6031e + repo: https://github.com/golang/sys.git + vcs: git subpackages: - unix - name: golang.org/x/text - version: cfdf022e86b4ecfb646e1efbd7db175dd623a8fa + version: 836efe42bb4aa16aaa17b9c155d8813d336ed720 + repo: https://github.com/golang/text.git + vcs: git subpackages: - secure/bidirule - transform @@ -60,4 +201,23 @@ imports: - stats - tap - transport -testImports: [] +- name: gopkg.in/yaml.v2 + version: cd8b52f8269e0feb286dfeef29f8fe4d5b397e0b +testImports: +- name: github.com/davecgh/go-spew + version: 04cdfd42973bb9c8589fd6a731800cf222fde1a9 + subpackages: + - spew +- name: github.com/docker/docker + version: b6d164e6c46d8115b146e4c3ac93784e9ef8b49e + subpackages: + - pkg/ioutils + - pkg/longpath +- name: github.com/pmezard/go-difflib + version: d8ed2627bdf02c080bf22230dbb337003b7aba2d + subpackages: + - difflib +- name: github.com/stretchr/testify + version: 05e8a0eda380579888eb53c394909df027f06991 + subpackages: + - assert diff --git a/go/glide.yaml b/go/glide.yaml index ab472c7cda9755d0399bb8376b16589be8b53057..bc23fa6ebf2c3db61e2d63e5f7e7ddcb595dfed0 100644 --- a/go/glide.yaml +++ b/go/glide.yaml @@ -6,8 +6,19 @@ import: subpackages: - clientv3 - clientv3/concurrency + - embed + - etcdserver - package: github.com/namsral/flag version: ^1.7.4-pre - package: github.com/sirupsen/logrus version: ^1.0.0 - package: github.com/topicai/candy +- package: golang.org/x/crypto + vcs: git + repo: https://github.com/golang/crypto.git +- package: golang.org/x/sys + vcs: git + repo: https://github.com/golang/sys.git +- package: golang.org/x/text + vcs: git + repo: https://github.com/golang/text.git diff --git a/go/master/etcd_client.go b/go/master/etcd_client.go index 
ae6b6f776bec9ccaead4465ad233fc8ed6c3a418..94848d887e8bc4b055a7c8b89b9b7f26a39229d1 100644 --- a/go/master/etcd_client.go +++ b/go/master/etcd_client.go @@ -39,15 +39,12 @@ type EtcdClient struct { statePath string client *clientv3.Client lock *concurrency.Mutex + sess *concurrency.Session } // NewEtcdClient creates a new EtcdClient. func NewEtcdClient(endpoints []string, addr string, lockPath, addrPath, statePath string, ttlSec int) (*EtcdClient, error) { log.Debugf("Connecting to etcd at %v", endpoints) - // TODO(helin): gracefully shutdown etcd store. Because etcd - // store holds a etcd lock, even though the lock will expire - // when the lease timeout, we need to implement graceful - // shutdown to release the lock. cli, err := clientv3.New(clientv3.Config{ Endpoints: endpoints, DialTimeout: dialTimeout, @@ -67,12 +64,12 @@ func NewEtcdClient(endpoints []string, addr string, lockPath, addrPath, statePat // one master running, but split-brain problem may cause // multiple master servers running), and the cluster management // software will kill one of them. - log.Debugf("Trying to acquire lock at %s.", lockPath) + log.Infof("Trying to acquire lock at %s.", lockPath) err = lock.Lock(context.TODO()) if err != nil { return nil, err } - log.Debugf("Successfully acquired lock at %s.", lockPath) + log.Infof("Successfully acquired lock at %s.", lockPath) put := clientv3.OpPut(addrPath, addr) resp, err := cli.Txn(context.Background()).If(lock.IsOwner()).Then(put).Commit() @@ -89,6 +86,7 @@ func NewEtcdClient(endpoints []string, addr string, lockPath, addrPath, statePat statePath: statePath, client: cli, lock: lock, + sess: sess, } return e, nil @@ -157,6 +155,21 @@ func (e *EtcdClient) Load() ([]byte, error) { return state, nil } +// Shutdown shuts down the etcd client gracefully. 
+func (e *EtcdClient) Shutdown() error { + err := e.sess.Close() + newErr := e.client.Close() + if newErr != nil { + if err == nil { + err = newErr + } else { + log.Errorln(newErr) + } + } + + return err +} + // GetKey gets the value by the specify key. func GetKey(c *clientv3.Client, key string, timeout time.Duration) (string, error) { ctx, cancel := context.WithTimeout(context.Background(), timeout) diff --git a/go/master/inmem_store.go b/go/master/inmem_store.go index ffd663f7f0b25c29f0bab082d27b29dcfeb60826..a5bd2d4fe150cd34c699ccfae1f3d3e0fb2ef3d6 100644 --- a/go/master/inmem_store.go +++ b/go/master/inmem_store.go @@ -40,3 +40,8 @@ func (m *InMemStore) Load() ([]byte, error) { return m.buf, nil } + +// Shutdown shuts down the in mem store. +func (m *InMemStore) Shutdown() error { + return nil +} diff --git a/go/master/service.go b/go/master/service.go index 1f2112ecfb925ee8bb9545f0bb1100efd3ad11ca..d30e9a33229c0aff354417771b5bf2ae6a781715 100644 --- a/go/master/service.go +++ b/go/master/service.go @@ -50,6 +50,7 @@ var ErrPassAfter = errors.New("pass number larger than master") type Store interface { Save([]byte) error Load() ([]byte, error) + Shutdown() error } // Chunk is a chunk of data consisted of several data instances. 
diff --git a/go/master/service_test.go b/go/master/service_test.go new file mode 100644 index 0000000000000000000000000000000000000000..5f91910ecc8cf32289e71e2e41e8b283acc115e6 --- /dev/null +++ b/go/master/service_test.go @@ -0,0 +1,68 @@ +package master_test + +import ( + "os" + "testing" + "time" + + "github.com/PaddlePaddle/Paddle/go/master" + "github.com/coreos/etcd/clientv3" + "github.com/coreos/etcd/embed" + "github.com/docker/docker/pkg/ioutils" + "github.com/stretchr/testify/assert" +) + +func TestNewServiceWithEtcd(t *testing.T) { + // setup an embed etcd server + etcdDir, err := ioutils.TempDir("", "") + if err != nil { + t.Fatal(err) + } + cfg := embed.NewConfig() + cfg.Dir = etcdDir + e, err := embed.StartEtcd(cfg) + if err != nil { + t.Fatal(err) + } + defer func() { + e.Close() + if err := os.RemoveAll(etcdDir); err != nil { + t.Fatal(err) + } + }() + select { + case <-e.Server.ReadyNotify(): + t.Log("Server is ready!") + case <-time.After(60 * time.Second): + e.Server.Stop() // trigger a shutdown + t.Fatal("Server took too long to start!") + } + + ep := []string{"127.0.0.1:2379"} + masterAddr := "127.0.0.1:3306" + store, err := master.NewEtcdClient(ep, masterAddr, master.DefaultLockPath, master.DefaultAddrPath, master.DefaultStatePath, 30) + if err != nil { + t.Fatal(err) + } + + _, err = master.NewService(store, 10, 10, 3) + if err != nil { + t.Fatal(err) + } + cli, err := clientv3.New(clientv3.Config{ + Endpoints: ep, + DialTimeout: 3 * time.Second, + }) + if err != nil { + t.Fatal(err) + } + v, err := master.GetKey(cli, master.DefaultAddrPath, 3*time.Second) + if err != nil { + t.Fatal(err) + } + if err := cli.Close(); err != nil { + t.Fatal(err) + } + // test master process registry itself into etcd server. 
+ assert.Equal(t, masterAddr, v, "master process should registry itself into etcd server.") +} diff --git a/go/pserver/client/c/cclient.go b/go/pserver/client/c/cclient.go index 0f7e20cdd8d20e37b586c22377a89fca4c3cf7ce..14ad0774550f6e5a5d8610d6007904cd2820432c 100644 --- a/go/pserver/client/c/cclient.go +++ b/go/pserver/client/c/cclient.go @@ -55,10 +55,10 @@ var curHandle C.paddle_pserver_client func add(c *client.Client) C.paddle_pserver_client { mu.Lock() defer mu.Unlock() - client := curHandle + cli := curHandle curHandle++ - handleMap[client] = c - return client + handleMap[cli] = c + return cli } func get(client C.paddle_pserver_client) *client.Client { diff --git a/go/pserver/client/c/test/test_train.py b/go/pserver/client/c/test/test_train.py index 85cb399590f7a5e7e73285ca87c49ea5f24afb32..572a61e4ccaa9ef3d03a60d916e80eab907c6d88 100644 --- a/go/pserver/client/c/test/test_train.py +++ b/go/pserver/client/c/test/test_train.py @@ -3,24 +3,11 @@ import paddle.v2.dataset.uci_housing as uci_housing import paddle.v2.master as master import os import cPickle as pickle +from paddle.v2.reader.creator import cloud_reader etcd_ip = os.getenv("MASTER_IP", "127.0.0.1") -etcd_endpoint = "http://" + etcd_ip + ":2379" -print "connecting to master, etcd endpoints: ", etcd_endpoint -master_client = master.client(etcd_endpoint, 5, 64) - - -def cloud_reader(): - global master_client - master_client.set_dataset( - ["/pfs/dlnel/public/dataset/uci_housing/uci_housing-*"], passes=30) - while 1: - r, e = master_client.next_record() - if not r: - if e != -2: # other errors - print "get record error:", e - break - yield pickle.loads(r) +etcd_endpoints = "http://" + etcd_ip + ":2379" +print "etcd endpoints: ", etcd_endpoints def main(): @@ -49,7 +36,7 @@ def main(): parameters=parameters, update_equation=optimizer, is_local=False, - pserver_spec=etcd_endpoint, + pserver_spec=etcd_endpoints, use_etcd=True) # event_handler to print training and testing info @@ -75,7 +62,11 @@ def 
main(): trainer.train( reader=paddle.batch( paddle.reader.shuffle( - cloud_reader, buf_size=500), batch_size=2), + cloud_reader( + ["/pfs/dlnel/public/dataset/uci_housing/uci_housing*"], + etcd_endpoints), + buf_size=500), + batch_size=2), feeding={'x': 0, 'y': 1}, event_handler=event_handler, diff --git a/go/pserver/etcd_client.go b/go/pserver/etcd_client.go index 98ff8ce827c7cfcd9122cb043f2a6226057cc95a..4fb26307667295ab825d07be6c3d1d4b33f6eb8b 100644 --- a/go/pserver/etcd_client.go +++ b/go/pserver/etcd_client.go @@ -34,16 +34,19 @@ const ( PsPath = "/ps/" // PsCheckpoint is the etcd path for store checkpoints information PsCheckpoint = "/checkpoints/" + + retryTimeout = 5 * time.Second ) // EtcdClient is the etcd client that the pserver uses for fault // tolerance, service registry and coordination. type EtcdClient struct { - numPservers int - etcdEndpoints string - etcdClient *clientv3.Client - // etcdTimeout is also used as retry intervals. - etcdTimeout time.Duration + numPservers int + endpoints string + client *clientv3.Client + sess *concurrency.Session + dialTimeout time.Duration + ttlSec int // FIXME: ensure GetExternalIP gets the correct ip for trainers to connect. externalIP string // desired number of pservers in the job. @@ -52,11 +55,12 @@ type EtcdClient struct { } // NewEtcdClient creates an EtcdClient -func NewEtcdClient(endpoints string, numPservers int, timeout time.Duration) *EtcdClient { +func NewEtcdClient(endpoints string, numPservers int, dialtimeout time.Duration, ttlSec int) *EtcdClient { return &EtcdClient{ - etcdTimeout: timeout, - numPservers: numPservers, - etcdEndpoints: endpoints, + dialTimeout: dialtimeout, + ttlSec: ttlSec, + numPservers: numPservers, + endpoints: endpoints, } } @@ -64,7 +68,6 @@ func NewEtcdClient(endpoints string, numPservers int, timeout time.Duration) *Et // // Register returns the index of the current pserver. 
func (e *EtcdClient) Register(port int) (int, error) { - var err error e.externalIP, err = networkhelper.GetExternalIP() if err != nil { @@ -72,19 +75,26 @@ func (e *EtcdClient) Register(port int) (int, error) { } // initialize connection to etcd. - ep := strings.Split(e.etcdEndpoints, ",") + ep := strings.Split(e.endpoints, ",") for { cli, err := clientv3.New(clientv3.Config{ Endpoints: ep, - DialTimeout: e.etcdTimeout, + DialTimeout: e.dialTimeout, }) if err != nil { log.Errorf("connect to etcd error: %v", err) - time.Sleep(e.etcdTimeout) + time.Sleep(retryTimeout) + continue + } + e.client = cli + sess, err := concurrency.NewSession(cli, concurrency.WithTTL(e.ttlSec)) + if err != nil { + log.Errorf("create etcd session error: %v", err) + time.Sleep(retryTimeout) continue } - e.etcdClient = cli - log.Debugf("inited client to %s", e.etcdEndpoints) + e.sess = sess + log.Debugf("inited client to %s", e.endpoints) break } // init /ps_desired using transaction, for multiple pservers may want to write @@ -95,7 +105,7 @@ func (e *EtcdClient) Register(port int) (int, error) { cancel() if err != nil { log.Warn(err) - time.Sleep(e.etcdTimeout) + time.Sleep(retryTimeout) continue } break @@ -106,18 +116,18 @@ func (e *EtcdClient) Register(port int) (int, error) { // wait and set s.desired init value for { ctx, cancel := context.WithTimeout(context.Background(), time.Second) - resp, err := e.etcdClient.Get(ctx, PsDesired) + resp, err := e.client.Get(ctx, PsDesired) cancel() if err != nil { log.Errorf("getting %s error: %v", PsDesired, err) - time.Sleep(e.etcdTimeout) + time.Sleep(retryTimeout) continue } if len(resp.Kvs) != 0 { e.desired, err = strconv.Atoi(string(resp.Kvs[0].Value)) if err != nil { log.Errorf("value of %s invalid %v\n", PsDesired, err) - time.Sleep(e.etcdTimeout) + time.Sleep(retryTimeout) // NOTE: wait util ps_desired value change continue } @@ -134,7 +144,7 @@ func (e *EtcdClient) Register(port int) (int, error) { cancel() if err != nil { log.Warn(err) - 
time.Sleep(e.etcdTimeout) + time.Sleep(retryTimeout) continue } break @@ -144,10 +154,10 @@ func (e *EtcdClient) Register(port int) (int, error) { } func (e *EtcdClient) initDesiredPservers(ctx context.Context, numPservers int) (*clientv3.TxnResponse, error) { - return concurrency.NewSTM(e.etcdClient, func(c concurrency.STM) error { + return concurrency.NewSTM(e.client, func(c concurrency.STM) error { dsStr := c.Get(PsDesired) if dsStr == "" { - c.Put(PsDesired, strconv.Itoa(numPservers)) + c.Put(PsDesired, strconv.Itoa(numPservers), clientv3.WithLease(e.sess.Lease())) } return nil }, concurrency.WithAbortContext(ctx), concurrency.WithIsolation(concurrency.RepeatableReads)) @@ -156,7 +166,7 @@ func (e *EtcdClient) initDesiredPservers(ctx context.Context, numPservers int) ( // registerPserverEtcd registers pserver node on etcd using transaction. func (e *EtcdClient) registerPserverEtcd(ctx context.Context, port int) (int, error) { var idx int - _, err := concurrency.NewSTM(e.etcdClient, func(c concurrency.STM) error { + _, err := concurrency.NewSTM(e.client, func(c concurrency.STM) error { registered := false for i := 0; i < e.desired; i++ { psKey := PsPath + strconv.Itoa(i) @@ -165,26 +175,10 @@ func (e *EtcdClient) registerPserverEtcd(ctx context.Context, port int) (int, er log.Debugf("got value (%s) for key: %s", ps, psKey) if ps == "" { - resp, err := e.etcdClient.Grant(context.TODO(), 5) - if err != nil { - log.Fatal(err) - } // find the first id and write info pserverAddr := e.externalIP + ":" + strconv.Itoa(port) - c.Put(psKey, pserverAddr, clientv3.WithLease(resp.ID)) + c.Put(psKey, pserverAddr, clientv3.WithLease(e.sess.Lease())) log.Debugf("set pserver node %s with value %s", psKey, pserverAddr) - ch, kaerr := e.etcdClient.KeepAlive(context.TODO(), resp.ID) - if kaerr != nil { - log.Errorf("keepalive etcd node error: %v", kaerr) - return kaerr - } - - // Eat the keep alive message so etcd - // will not expire the lease. 
- go func(ch <-chan *clientv3.LeaseKeepAliveResponse) { - ka := <-ch - log.Debugf("keepalive: %d\n", ka.TTL) - }(ch) log.Debug("register finished") idx = i registered = true @@ -207,7 +201,7 @@ func (e *EtcdClient) registerPserverEtcd(ctx context.Context, port int) (int, er // GetKey gets the value by the specified key func (e *EtcdClient) GetKey(key string, timeout time.Duration) ([]byte, error) { ctx, cancel := context.WithTimeout(context.Background(), timeout) - resp, err := e.etcdClient.Get(ctx, key) + resp, err := e.client.Get(ctx, key) cancel() if err != nil { return []byte{}, err @@ -223,7 +217,27 @@ func (e *EtcdClient) GetKey(key string, timeout time.Duration) ([]byte, error) { // PutKey put into etcd with value by key specified func (e *EtcdClient) PutKey(key string, value []byte, timeout time.Duration) error { ctx, cancel := context.WithTimeout(context.Background(), timeout) - _, err := e.etcdClient.Put(ctx, key, string(value)) + _, err := e.client.Put(ctx, key, string(value), clientv3.WithLease(e.sess.Lease())) cancel() return err } + +// Shutdown shuts down the etcd client gracefully. 
+func (e *EtcdClient) Shutdown() error { + var err error + if e.sess != nil { + err = e.sess.Close() + } + + if e.client != nil { + newErr := e.client.Close() + if newErr != nil { + if err != nil { + log.Errorln(newErr) + } else { + err = newErr + } + } + } + return err +} diff --git a/paddle/cuda/src/hl_cuda_sequence.cu b/paddle/cuda/src/hl_cuda_sequence.cu index 4f650ce03ccb2d14cc2997e9cd426acb91439539..eeee921db54e20ea6a017d2b83f2d7ca9e5e037e 100644 --- a/paddle/cuda/src/hl_cuda_sequence.cu +++ b/paddle/cuda/src/hl_cuda_sequence.cu @@ -269,8 +269,7 @@ void hl_sequence2batch_copy_padding(real* batch, int blockDimY = CUDA_BLOCK_SIZE / blockDimX; dim3 threads(blockDimX, blockDimY); - int gridDimX = (maxSequenceLength * blockDimX + CUDA_BLOCK_SIZE - 1) / - CUDA_BLOCK_SIZE; + int gridDimX = (maxSequenceLength + blockDimY - 1) / blockDimY; int gridDimY = numSequences; dim3 grid(gridDimX, gridDimY); diff --git a/paddle/framework/CMakeLists.txt b/paddle/framework/CMakeLists.txt index 21cb7c7265e0052630b68954fa25f9189e641e7b..12a3a00bba35d476fca9c9fb47ac20b87e6f53f2 100644 --- a/paddle/framework/CMakeLists.txt +++ b/paddle/framework/CMakeLists.txt @@ -8,7 +8,9 @@ cc_test(tensor_test SRCS tensor_test.cc DEPS tensor) cc_test(eigen_test SRCS eigen_test.cc DEPS tensor) cc_test(variable_test SRCS variable_test.cc) -cc_test(scope_test SRCS scope_test.cc) + +cc_library(scope SRCS scope.cc) +cc_test(scope_test SRCS scope_test.cc DEPS scope) proto_library(attr_type SRCS attr_type.proto) proto_library(op_proto SRCS op_proto.proto DEPS attr_type) @@ -16,7 +18,7 @@ proto_library(op_desc SRCS op_desc.proto DEPS attr_type) cc_test(op_proto_test SRCS op_proto_test.cc DEPS op_proto protobuf) cc_test(op_desc_test SRCS op_desc_test.cc DEPS op_desc protobuf) -cc_library(operator SRCS operator.cc DEPS op_desc device_context tensor) +cc_library(operator SRCS operator.cc DEPS op_desc device_context tensor scope) cc_test(operator_test SRCS operator_test.cc DEPS operator op_registry) 
cc_library(grad_op_builder SRCS grad_op_builder.cc DEPS op_proto operator) @@ -30,4 +32,7 @@ add_custom_target(framework_py_proto_init ALL COMMAND ${CMAKE_COMMAND} -E touch add_dependencies(framework_py_proto framework_py_proto_init) cc_library(net SRCS net.cc DEPS op_registry) -cc_test(net_op_test SRCS net_op_test.cc DEPS net add_op mul_op sigmoid_op softmax_op fc_op) +cc_test(net_op_test SRCS net_op_test.cc DEPS net) + +cc_library(backward SRCS backward.cc DEPS net) +cc_test(backward_test SRCS backward_test.cc DEPS backward) diff --git a/paddle/framework/backward.cc b/paddle/framework/backward.cc new file mode 100644 index 0000000000000000000000000000000000000000..0da11b91a7fe4a98e0832f70095c3200956ff001 --- /dev/null +++ b/paddle/framework/backward.cc @@ -0,0 +1,178 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. 
*/ + +#include "paddle/framework/backward.h" +#include +#include "paddle/framework/net.h" +#include "paddle/framework/op_registry.h" + +namespace paddle { +namespace framework { + +static bool AllInSet(const std::vector& names, + const std::string& suffix, + const std::unordered_set& set) { + for (auto& name : names) { + if (set.find(name + suffix) == set.end()) { + return false; + } + } + return true; +} + +static std::shared_ptr NOP() { + auto net_op = std::make_shared(); + net_op->type_ = "@NOP@"; + net_op->CompleteAddOp(); + return net_op; +} + +// Get backward operator from a forward operator, recursively implementation. +// +// no_grad_names the gradient variable names without gradient calculating. +// +// uniq_id is a unique index used inside recursively calling BackwardRecursive. +// use `uid = uniq_id++;` to get the unique index, and pass `uniq_id` through +// recursive calling. +// +// returns The backward operator. For simple situation, it is a simple +// operator. For complex situation, it is a NetOp. +// +// See Backward.h for details +static std::shared_ptr BackwardRecursive( + const OperatorBase& forwardOp, + std::unordered_set& no_grad_names, size_t& uniq_id); +std::shared_ptr BackwardRecursive( + const OperatorBase& forwardOp, + std::unordered_set& no_grad_names, size_t& uniq_id) { + // If all input gradients of forwarding operator do not need to calculate, + // just return an NOP. Not return null ptr because NOP does not take + // too much time for calculation, but it is useful for simplifying logic. + if (AllInSet(forwardOp.inputs_, OperatorBase::GRAD_VAR_SUFFIX(), + no_grad_names)) { + return NOP(); + } + + // All output gradients of forwarding operator do not need to calculate. Then + // all input gradients cannot be computed at all, and we put them into + // `no_grad_names` set. Return an NOP. 
+ if (AllInSet(forwardOp.outputs_, OperatorBase::GRAD_VAR_SUFFIX(), + no_grad_names)) { + for (auto& name : forwardOp.inputs_) { + // Mark all input is not need + no_grad_names.insert(name + OperatorBase::GRAD_VAR_SUFFIX()); + } + return NOP(); + } + + // Returned gradient network + auto net = std::make_shared(); + + if (forwardOp.IsNetOp()) { + // Because forwardOp is a net op, it can static_cast. + auto& forwardNet = static_cast(forwardOp); + + // Map from output gradient variable name to operator's indices in backward + // net. That operator generates that variable. + std::unordered_map> dup_output_ops; + + size_t local_op_id = 0; + // reversely travel forwardNet + for (auto it = forwardNet.ops_.rbegin(); it != forwardNet.ops_.rend(); + ++it, ++local_op_id) { + auto fwd = *it; + auto bwd = BackwardRecursive(*fwd, no_grad_names, uniq_id); + net->AddOp(bwd); + for (auto& out : bwd->outputs_) { + dup_output_ops[out].emplace_back(local_op_id); + } + } + // Get unique ID for this method. + auto uid = uniq_id++; + // TODO(dzh): more comment + using Pos = std::pair>; + std::list insert_position; + for (auto& dup_output_op : dup_output_ops) { + const std::string& name = dup_output_op.first; + auto& dup_op = dup_output_op.second; + if (dup_op.size() == 1) continue; + std::vector dup_outputs; + + for (size_t i = 0; i < dup_op.size(); ++i) { + auto op_offset = dup_op[i]; + dup_outputs.push_back(name + "@RENAME@" + std::to_string(uid) + "@" + + std::to_string(i)); + net->ops_[op_offset]->Rename(name, dup_outputs.back()); + } + insert_position.push_back( + {dup_op.back(), + OpRegistry::CreateOp( + "add", {dup_outputs}, {name}, + {{"input_format", + std::vector{0, static_cast(dup_outputs.size())}}})}); + } + + insert_position.sort( + [](const Pos& l, const Pos& r) { return l.first > r.first; }); + + for (auto& pos : insert_position) { + net->InsertOp(pos.first + 1, pos.second); + } + + } else { + std::shared_ptr grad_op = OpRegistry::CreateGradOp(forwardOp); + for 
(std::string& grad_input : grad_op->inputs_) { + if (no_grad_names.count(grad_input)) { + std::string prefix = grad_input.substr( + 0, grad_input.size() - OperatorBase::GRAD_VAR_SUFFIX().size()); + grad_input = prefix + OperatorBase::ZERO_VAR_SUFFIX(); + + // If part of input gradient of that operator is not calculated, fill + // zero variables to that input gradient. + net->AddOp(OpRegistry::CreateOp("fill_zeros_like", {prefix}, + {grad_input}, {})); + } + } + + for (std::string& grad_output : grad_op->outputs_) { + if (no_grad_names.count(grad_output)) { + grad_output = OperatorBase::EMPTY_VAR_NAME(); + } + } + + if (net->ops_.empty()) { // Current no aux op is added to network + return grad_op; + } + net->AddOp(grad_op); + } + net->type_ = "@GENERATED_BACKWARD@"; + net->CompleteAddOp(); + return net; +} + +// See header for comments +std::shared_ptr Backward( + const OperatorBase& forwardOp, + const std::unordered_set& no_grad_vars) { + std::unordered_set no_grad_names; + no_grad_names.reserve(no_grad_vars.size()); + + for (auto& name : no_grad_vars) { + no_grad_names.insert(name + OperatorBase::GRAD_VAR_SUFFIX()); + } + size_t uid = 0; + return BackwardRecursive(forwardOp, no_grad_names, uid); +} +} // namespace framework +} // namespace paddle diff --git a/paddle/framework/backward.h b/paddle/framework/backward.h new file mode 100644 index 0000000000000000000000000000000000000000..c181919dc165cf0b49362f85e22ceb4131bbd387 --- /dev/null +++ b/paddle/framework/backward.h @@ -0,0 +1,27 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. 
+ You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + +#pragma once +#include +#include "operator.h" +namespace paddle { +namespace framework { + +// Create the backward operator from a forward operator. +// TODO(yuyang18): Add more API reference comment. +extern std::shared_ptr Backward( + const OperatorBase& forwardOp, + const std::unordered_set& no_grad_vars); +} // namespace framework +} // namespace paddle diff --git a/paddle/framework/backward.md b/paddle/framework/backward.md new file mode 100644 index 0000000000000000000000000000000000000000..74c001b06a9e7b2279abf998604f2acf1b1168e4 --- /dev/null +++ b/paddle/framework/backward.md @@ -0,0 +1,38 @@ +## Operator/Expression's Backward + +### Motivation + +In a neural network, the backpropagation algorithm follows the chain rule, so we need to compose the fundamental gradient operators/expressions according to the chain rule. Every forward network needs a backward network to construct the full computation lineage; the operator/expression's Backward feature generates the backward pass with respect to the forward pass. + +### Implementation: gradient operator registry + +| | forward operator | backward operator | +| ---------------------- | ---------------- | -------------------------------- | +| **Operator::inputs_** | Inputs | Inputs, Outputs, OutputGradients | +| **Operator::outputs_** | Outputs | InputGradients | + +Inputs/Outputs are the inputs/outputs of the operator; InputGradients/OutputGradients are the gradients with respect to the forward operator's inputs/outputs. The forward operator and the backward operator are isomorphic: each stores the variables it needs in its `inputs_`/`outputs_` member attributes. 
+ +We use a global hash map to record the available gradient operators. Following the philosophy of a minimal core, this makes each operator a pluggable unit. Each gradient is an operator and needs to register itself. + +grad_op_builder(fengjiayi) + +### Implementation: Backward network + +Given a forward network, it generates the backward network. We only care about the gradients: `OutputGradients` and `InputGradients`. + +1. bla bla bla (yuyang) + +2. NetOp + + When the input forward network is a NetOp, it needs to call the backward function of each sub NetOp/Operator recursively and ensure they are done. During the process, we need to collect the `OutputGradients` names. + + We share variables in the same scope; as a result, duplicate operator `OutputGradients` will overwrite the same variable. + + ![duplicate_op](./images/duplicate_op.png) + + Sharing a variable between operators, or using the same input variable in multiple operators, leads to a duplicate gradient variable. As the demo above shows, we need to rename the gradient names recursively and add a generic add operator that sums the renamed gradients. + +![duplicate_op2](./images/duplicate_op2.png) + + Then collect the subgraph's `OutputGradients`/`InputGradients` as the NetOp's own and return it. diff --git a/paddle/framework/backward_test.cc b/paddle/framework/backward_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..b095c2c3d5dbf21b5ea70e17475a4aaad9b1db44 --- /dev/null +++ b/paddle/framework/backward_test.cc @@ -0,0 +1,389 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
namespace paddle {
namespace framework {

// Operator stub used by the backward tests: InferShape and Run do nothing, so
// only the structure (type, inputs, outputs) of generated operators matters.
class EmptyOp : public OperatorBase {
 public:
  void InferShape(const Scope &scope) const override {}
  void Run(const Scope &scope,
           const platform::DeviceContext &dev_ctx) const override {}
};

// Proto of the test "rowwise_add" operator: inputs X and b, output Out.
// Every variable is flagged IgnoreGradient(); the simple_op_grad test in this
// file expects the generated gradient operator to take a single input.
class RowWiseAddOpMaker : public OpProtoAndCheckerMaker {
 public:
  RowWiseAddOpMaker(OpProto *proto, OpAttrChecker *op_checker)
      : OpProtoAndCheckerMaker(proto, op_checker) {
    AddInput("X", "Input X of Add").IgnoreGradient();
    AddInput("b", "Bias of Add").IgnoreGradient();
    AddOutput("Out", "Out of Add").IgnoreGradient();
    AddComment("Add Op");
  }
};

// Proto of the test "mul" operator: inputs A and B, output Out.
class MulOpMaker : public OpProtoAndCheckerMaker {
 public:
  MulOpMaker(OpProto *proto, OpAttrChecker *op_checker)
      : OpProtoAndCheckerMaker(proto, op_checker) {
    AddInput("A", "A");
    AddInput("B", "B");
    AddOutput("Out", "Out");
    AddComment("Mul");
  }
};

// Proto of the test "sigmoid" operator: input X, output Y.
class SigmoidOpMaker : public OpProtoAndCheckerMaker {
 public:
  SigmoidOpMaker(OpProto *proto, OpAttrChecker *op_checker)
      : OpProtoAndCheckerMaker(proto, op_checker) {
    AddInput("X", "X");
    AddOutput("Y", "Y");
    AddComment("Sigmoid");
  }
};

// Proto of the test "nograd" operator: input X, output Y. It is registered
// without a gradient operator in this file.
class NoGradOpMaker : public OpProtoAndCheckerMaker {
 public:
  NoGradOpMaker(OpProto *proto, OpAttrChecker *op_checker)
      : OpProtoAndCheckerMaker(proto, op_checker) {
    AddInput("X", "X input");
    AddOutput("Y", "Y output");
    AddComment("NoGradOp, same input output. no Grad");
  }
};

// Test-only fully connected layer built as a NetOp:
//   mul(X, W) -> mul_result
//   rowwise_add(mul_result, b) -> add_result   (only when b is given)
//   sigmoid(<last result>) -> Out
class FcOp : public NetOp {
 public:
  void Init() override {
    AddOp(OpRegistry::CreateOp("mul", {Input("X"), Input("W")},
                               {Output("mul_result")}, {}));
    auto b_name = Input("b");
    // Name of the output slot that feeds the activation.
    std::string before_act = "mul_result";
    if (b_name != EMPTY_VAR_NAME()) {
      AddOp(OpRegistry::CreateOp("rowwise_add", {Output("mul_result"), b_name},
                                 {Output("add_result")}, {}));
      before_act = "add_result";
    } else {
      // No bias: the add_result slot is unused, so rename whatever variable
      // was bound to it to the empty placeholder name.
      auto out_varname = Output("add_result");
      if (out_varname != EMPTY_VAR_NAME()) {
        this->Rename(out_varname, EMPTY_VAR_NAME());
      }
    }

    AddOp(OpRegistry::CreateOp("sigmoid", {Output(before_act)}, {Output("Out")},
                               {}));
    // Seal the net with calculate == false.
    // NOTE(review): presumably this keeps the inputs/outputs given at
    // creation instead of recomputing them from the sub-operators - confirm
    // against NetOp::CompleteAddOp.
    CompleteAddOp(false);
  }
};

// Proto of the test "fc" operator: inputs X, W, b; outputs mul_result and
// add_result (both flagged temporary) and Out. The declaration order defines
// the order in which variable names are passed to CreateOp in the tests.
class FcOpMaker : public OpProtoAndCheckerMaker {
 public:
  FcOpMaker(OpProto *proto, OpAttrChecker *op_checker)
      : OpProtoAndCheckerMaker(proto, op_checker) {
    AddInput("X", "x");
    AddInput("W", "w");
    AddInput("b", "b");
    AddOutput("mul_result", "").SetTemporary();
    AddOutput("add_result", "").SetTemporary();
    AddOutput("Out", "");
    AddComment("");
  }
};

// Proto of the test "many_output_op" operator: input x, outputs y and z.
class ManyOutputOpMaker : public OpProtoAndCheckerMaker {
 public:
  ManyOutputOpMaker(OpProto *proto, OpAttrChecker *op_checker)
      : OpProtoAndCheckerMaker(proto, op_checker) {
    AddInput("x", "x");
    AddOutput("y", "y");
    AddOutput("z", "z");
    AddComment("");
  }
};

// Proto of the test "fill_zeros_like" operator: input x, output out.
class FillZeroOpMaker : public OpProtoAndCheckerMaker {
 public:
  FillZeroOpMaker(OpProto *proto, OpAttrChecker *op_checker)
      : OpProtoAndCheckerMaker(proto, op_checker) {
    AddInput("x", "x");
    AddOutput("out", "out");
    AddComment("");
  }
};

// Proto of the test "add" operator: a multiple-valued input X and output Y.
class AddOpMaker : public OpProtoAndCheckerMaker {
 public:
  AddOpMaker(OpProto *proto, OpAttrChecker *op_checker)
      : OpProtoAndCheckerMaker(proto, op_checker) {
    AddInput("X", "x").SetMultiple();
    AddOutput("Y", "y");
    AddComment("");
  }
};
}  // namespace framework
}  // namespace paddle
+REGISTER_OP(rowwise_add, f::EmptyOp, f::RowWiseAddOpMaker); +REGISTER_GRADIENT_OP(rowwise_add, rowwise_add_grad, f::EmptyOp); +REGISTER_OP(mul, f::EmptyOp, f::MulOpMaker); +REGISTER_GRADIENT_OP(mul, mul_grad, f::EmptyOp); +REGISTER_OP(sigmoid, f::EmptyOp, f::SigmoidOpMaker); +REGISTER_GRADIENT_OP(sigmoid, sigmoid_grad, f::EmptyOp); +REGISTER_OP(nograd, f::EmptyOp, f::NoGradOpMaker); +REGISTER_OP(fill_zeros_like, f::EmptyOp, f::FillZeroOpMaker); +REGISTER_OP(add, f::EmptyOp, f::AddOpMaker); +REGISTER_GRADIENT_OP(add, add_grad, f::EmptyOp); +REGISTER_OP(fc, f::FcOp, f::FcOpMaker); +REGISTER_OP(many_output_op, f::EmptyOp, f::ManyOutputOpMaker); +REGISTER_GRADIENT_OP(many_output_op, many_output_op_grad, f::EmptyOp); + +TEST(Backward, simple_op_grad) { + auto fwd = f::OpRegistry::CreateOp("rowwise_add", {"X", "b"}, {"Out"}, {}); + ASSERT_NE(fwd, nullptr); + auto gop = f::OpRegistry::CreateGradOp(*fwd); + ASSERT_EQ(1UL, gop->inputs_.size()); + ASSERT_EQ("Out" + f::OperatorBase::GRAD_VAR_SUFFIX(), gop->inputs_[0]); + ASSERT_EQ("rowwise_add_grad", gop->type_); + ASSERT_EQ("X" + f::OperatorBase::GRAD_VAR_SUFFIX(), gop->outputs_[0]); + ASSERT_EQ("b" + f::OperatorBase::GRAD_VAR_SUFFIX(), gop->outputs_[1]); + + ASSERT_EQ("X" + f::OperatorBase::GRAD_VAR_SUFFIX(), + gop->Output("X" + f::OperatorBase::GRAD_VAR_SUFFIX())); +} + +TEST(Backward, simple_op_not_need_grad) { + auto fwd = f::OpRegistry::CreateOp("rowwise_add", {"X", "b"}, {"Out"}, {}); + ASSERT_NE(fwd, nullptr); + auto gop = f::Backward(*fwd, {"X"}); + ASSERT_EQ(std::find(gop->outputs_.begin(), gop->outputs_.end(), + "X" + f::OperatorBase::GRAD_VAR_SUFFIX()), + gop->outputs_.end()); + + auto no_input_gop = f::Backward(*fwd, {"X", "b"}); + ASSERT_NE(no_input_gop, nullptr); + ASSERT_TRUE(no_input_gop->IsNetOp()); + ASSERT_EQ(0UL, std::static_pointer_cast(no_input_gop)->ops_.size()); +} + +TEST(Backward, net_fc_backward_normal) { + std::shared_ptr fwd = f::OpRegistry::CreateOp( + "fc", {"X", "w", "b"}, {"mul_result", 
"add_result", "out"}, {}); + ASSERT_NE(fwd, nullptr); + std::shared_ptr gop = f::Backward(*fwd, {}); + ASSERT_TRUE(gop->IsNetOp()); + auto net = static_cast(gop.get()); + + ASSERT_NO_THROW(net->DebugString()); + + ASSERT_EQ(3UL, net->ops_.size()); + + f::OperatorBase &d_sigmoid = *net->ops_[0]; + ASSERT_EQ("sigmoid_grad", d_sigmoid.type_); + + f::OperatorBase &d_add = *net->ops_[1]; + ASSERT_EQ("rowwise_add_grad", d_add.type_); + + f::OperatorBase &d_mul = *net->ops_[2]; + ASSERT_EQ("mul_grad", d_mul.type_); +} + +TEST(Backward, net_fc_backward_not_have_b) { + std::shared_ptr fwd = f::OpRegistry::CreateOp( + "fc", {"X", "w", f::OperatorBase::EMPTY_VAR_NAME()}, + {"mul_result", "add_result", "tmp"}, {}); + ASSERT_NE(fwd, nullptr); + std::shared_ptr gop = f::Backward(*fwd, {}); + ASSERT_TRUE(gop->IsNetOp()); + auto net = static_cast(gop.get()); + + ASSERT_NO_THROW(net->DebugString()); + + ASSERT_EQ(2UL, net->ops_.size()); + + f::OperatorBase &d_sigmoid = *net->ops_[0]; + ASSERT_EQ("sigmoid_grad", d_sigmoid.type_); + + f::OperatorBase &d_mul = *net->ops_[1]; + ASSERT_EQ("mul_grad", d_mul.type_); +} + +TEST(Backward, net_input_of_network_not_need_grad) { + f::NetOp net; + net.AddOp(f::OpRegistry::CreateOp("fc", {"X", "W1", "b1"}, + {"mul_tmp_0", "add_tmp_0", "hidden0"}, {})); + net.AddOp(f::OpRegistry::CreateOp("fc", {"hidden0", "W2", "b2"}, + {"mul_tmp_1", "add_tmp_1", "hidden1"}, {})); + net.CompleteAddOp(); + auto bwd = Backward(net, {"X"}); // X@GRAD is not need. 
+ ASSERT_TRUE(bwd->IsNetOp()); + auto bwd_net = static_cast(bwd.get()); + + std::unordered_set all_output = std::unordered_set( + bwd_net->outputs_.begin(), bwd_net->outputs_.end()); + all_output.erase(f::OperatorBase::EMPTY_VAR_NAME()); + + for (auto &out : {"W1", "b1", "hidden0", "W2", "b2"}) { + ASSERT_NE(all_output.find(out + f::OperatorBase::GRAD_VAR_SUFFIX()), + all_output.end()); + } + + // Not Generated X + ASSERT_EQ(all_output.find("X" + f::OperatorBase::GRAD_VAR_SUFFIX()), + all_output.end()); + + ASSERT_EQ(2UL, bwd_net->ops_.size()); + ASSERT_TRUE(bwd_net->ops_[1]->IsNetOp()); + auto first_fc_grad = static_cast(bwd_net->ops_[1].get()); + ASSERT_EQ(3UL, first_fc_grad->ops_.size()); + ASSERT_EQ( + f::OperatorBase::EMPTY_VAR_NAME(), + first_fc_grad->ops_[2]->Output("A" + f::OperatorBase::GRAD_VAR_SUFFIX())); +} + +TEST(Backward, net_shared_weight) { + f::NetOp net; + net.AddOp(f::OpRegistry::CreateOp("mul", {"X", "W"}, {"Out"}, {})); + net.AddOp(f::OpRegistry::CreateOp("mul", {"Out", "W"}, {"FinalOut"}, {})); + net.CompleteAddOp(); + + auto bwd = f::Backward(net, {}); + ASSERT_TRUE(bwd->IsNetOp()); + auto bwd_net = static_cast(bwd.get()); + ASSERT_EQ(3UL, bwd_net->ops_.size()); + ASSERT_EQ("add", bwd_net->ops_[2]->type_); +} + +TEST(Backward, op_register_grad_not_for_network) { + auto fwd = f::OpRegistry::CreateOp( + "fc", {"X", "W", "b"}, {"mul_out", "add_out", "out1"}, + {{"temporary_index", std::vector{0, 1}}}); + + ASSERT_THROW(f::OpRegistry::CreateGradOp(*fwd), EnforceNotMet); +} + +TEST(Backward, op_all_input_are_not_need) { + auto fwd = f::OpRegistry::CreateOp("rowwise_add", {"X", "b"}, {"Out"}, {}); + auto backward = f::Backward(*fwd, {"X", "b"}); + ASSERT_TRUE(backward->IsNetOp()); + auto net = static_cast(backward.get()); + ASSERT_TRUE(net->ops_.empty()); +} + +TEST(Backward, op_all_output_are_not_need) { + auto fwd = f::OpRegistry::CreateOp("rowwise_add", {"X", "b"}, {"Out"}, {}); + auto backward = f::Backward(*fwd, {"Out"}); + 
ASSERT_TRUE(backward->IsNetOp()); + auto net = static_cast(backward.get()); + ASSERT_TRUE(net->ops_.empty()); +} + +TEST(Backward, op_part_of_output_are_not_need) { + auto fwd = f::OpRegistry::CreateOp("many_output_op", {"X"}, {"Y", "Z"}, {}); + auto backward = f::Backward(*fwd, {"Z"}); + ASSERT_TRUE(backward->IsNetOp()); + auto net = static_cast(backward.get()); + ASSERT_EQ(net->ops_.size(), 2UL); + + auto &fill_zero = *net->ops_[0]; + ASSERT_EQ("fill_zeros_like", fill_zero.type_); + ASSERT_EQ(1UL, fill_zero.inputs_.size()); + ASSERT_EQ("Z", fill_zero.inputs_[0]); + ASSERT_EQ(1UL, fill_zero.outputs_.size()); + ASSERT_EQ("Z" + f::OperatorBase::ZERO_VAR_SUFFIX(), fill_zero.outputs_[0]); + + auto &d_many_out = *net->ops_[1]; + ASSERT_EQ("many_output_op_grad", d_many_out.type_); + ASSERT_EQ(1UL + 2UL + 2UL, d_many_out.inputs_.size()); // I/O/OG + ASSERT_EQ("Z" + f::OperatorBase::ZERO_VAR_SUFFIX(), + d_many_out.Input("z" + f::OperatorBase::GRAD_VAR_SUFFIX())); + ASSERT_EQ("Y" + f::OperatorBase::GRAD_VAR_SUFFIX(), + d_many_out.Input("y" + f::OperatorBase::GRAD_VAR_SUFFIX())); + ASSERT_EQ("X" + f::OperatorBase::GRAD_VAR_SUFFIX(), + d_many_out.Output("x" + f::OperatorBase::GRAD_VAR_SUFFIX())); +} + +TEST(Backward, op_part_of_input_are_not_need) { + auto fwd = f::OpRegistry::CreateOp("mul", {"a", "b"}, {"out"}, {}); + auto backward = f::Backward(*fwd, {"a"}); + auto &grad_mul = *backward; + ASSERT_EQ(grad_mul.type_, "mul_grad"); + ASSERT_EQ(grad_mul.inputs_.size(), 2UL + 1UL + 1UL); + ASSERT_EQ(grad_mul.outputs_.size(), 2UL); + ASSERT_EQ(grad_mul.Output("A" + f::OperatorBase::GRAD_VAR_SUFFIX()), + f::OperatorBase::EMPTY_VAR_NAME()); + ASSERT_EQ(grad_mul.Output("B" + f::OperatorBase::GRAD_VAR_SUFFIX()), + "b" + f::OperatorBase::GRAD_VAR_SUFFIX()); + ASSERT_EQ(grad_mul.Input("Out" + f::OperatorBase::GRAD_VAR_SUFFIX()), + "out" + f::OperatorBase::GRAD_VAR_SUFFIX()); + ASSERT_EQ(grad_mul.Input("A"), "a"); + ASSERT_EQ(grad_mul.Input("B"), "b"); + 
ASSERT_EQ(grad_mul.Input("Out"), "out"); +} + +TEST(Backward, linear_net_intermediate_variable_has_no_grad) { + f::NetOp net; + net.AddOp(f::OpRegistry::CreateOp("fc", {"x1", "w1", "b1"}, + {"mul_out1", "add_out1", "out1"}, {})); + net.AddOp(f::OpRegistry::CreateOp("fc", {"out1", "w2", "b2"}, + {"mul_out2", "tmp_out2", "out2"}, {})); + net.AddOp(f::OpRegistry::CreateOp("fc", {"out2", "w3", "b3"}, + {"mul_out3", "tmp_out3", "out3"}, {})); + net.CompleteAddOp(); + auto backward = f::Backward(net, {"mul_out2", "tmp_out2", "out2"}); + ASSERT_TRUE(backward->IsNetOp()); + auto bwd_net = static_cast(backward.get()); + ASSERT_EQ(bwd_net->ops_.size(), 3UL); + auto &grad_fc = *bwd_net->ops_[0]; + EXPECT_EQ(grad_fc.inputs_.size(), + 3UL /* external input number */ + + 1UL /* external output number*/ + + 1UL /* number of gradient of external output*/ + - 1UL /*ignoreGradient varable number*/ + + 2U /* internal variable number*/); + EXPECT_EQ(grad_fc.outputs_.size(), 2UL /* input number of mul*/ + + 2UL /* input number of rowwise_add */ + + 1UL /* input number of sigmod */); + EXPECT_EQ(bwd_net->ops_[1]->inputs_.size(), 0UL); + EXPECT_EQ(bwd_net->ops_[1]->outputs_.size(), 0UL); + EXPECT_EQ(bwd_net->ops_[2]->inputs_.size(), 0UL); + EXPECT_EQ(bwd_net->ops_[2]->outputs_.size(), 0UL); + + /* + EXPECT_EQ(grad_fc.Output("X" + f::OperatorBase::GRAD_VAR_SUFFIX()), + f::OperatorBase::EMPTY_VAR_NAME()); + EXPECT_EQ(grad_fc.Output("W" + f::OperatorBase::GRAD_VAR_SUFFIX()), + "w3" + f::OperatorBase::GRAD_VAR_SUFFIX()); + EXPECT_EQ(grad_fc.Output("b" + f::OperatorBase::GRAD_VAR_SUFFIX()), + "b3" + f::OperatorBase::GRAD_VAR_SUFFIX()); + EXPECT_EQ(grad_fc.Output("mul_result" + f::OperatorBase::GRAD_VAR_SUFFIX()), + "mul_out3" + f::OperatorBase::GRAD_VAR_SUFFIX()); + + EXPECT_EQ(grad_fc.Input("Out" + f::OperatorBase::GRAD_VAR_SUFFIX()), + "out3" + f::OperatorBase::GRAD_VAR_SUFFIX()); + EXPECT_EQ(grad_fc.Input("X"), "out2"); + EXPECT_EQ(grad_fc.Input("W"), "w3"); + 
EXPECT_EQ(grad_fc.Input("mul_result"), "mul_out3"); + EXPECT_EQ(grad_fc.Input("add_result"), "tmp_out3"); + EXPECT_EQ(grad_fc.Input("Out"), "out3"); + */ +} diff --git a/paddle/framework/detail/tensor-inl.h b/paddle/framework/detail/tensor-inl.h index 2acae1b0e20865e786137be09a3973b31b9fba25..e7ff09dd5c954378afeca299e901277c3ebdb96a 100644 --- a/paddle/framework/detail/tensor-inl.h +++ b/paddle/framework/detail/tensor-inl.h @@ -83,56 +83,38 @@ inline void Tensor::ShareDataWith(const Tensor& src) { template inline void Tensor::CopyFrom(const Tensor& src, - const platform::CPUDeviceContext& ctx) { + const platform::Place& dst_place) { src.check_memory_size(); Resize(src.dims()); auto src_place = src.holder_->place(); auto src_ptr = static_cast(src.data()); - auto dst_place = ctx.GetPlace(); auto dst_ptr = static_cast(mutable_data(dst_place)); auto size = product(src.dims_) * sizeof(T); - if (platform::is_cpu_place(src_place)) { + if (platform::is_cpu_place(src_place) && platform::is_cpu_place(dst_place)) { memory::Copy(boost::get(dst_place), dst_ptr, boost::get(src_place), src_ptr, size); } #ifndef PADDLE_ONLY_CPU - else if (platform::is_gpu_place(src_place)) { + else if (platform::is_gpu_place(src_place) && + platform::is_cpu_place(dst_place)) { memory::Copy(boost::get(dst_place), dst_ptr, boost::get(src_place), src_ptr, size, 0); - } -#endif -} - -#ifndef PADDLE_ONLY_CPU -template -inline void Tensor::CopyFrom(const Tensor& src, - const platform::CUDADeviceContext& ctx) { - src.check_memory_size(); - Resize(src.dims()); - - auto src_place = src.holder_->place(); - auto src_ptr = static_cast(src.data()); - - auto dst_place = ctx.GetPlace(); - auto dst_ptr = static_cast(mutable_data(dst_place)); - - auto size = product(src.dims_) * sizeof(T); - - if (platform::is_cpu_place(src_place)) { + } else if (platform::is_cpu_place(src_place) && + platform::is_gpu_place(dst_place)) { memory::Copy(boost::get(dst_place), dst_ptr, - boost::get(src_place), src_ptr, size, - 
ctx.stream()); - } else if (platform::is_gpu_place(src_place)) { + boost::get(src_place), src_ptr, size, 0); + } else if (platform::is_gpu_place(src_place) && + platform::is_gpu_place(dst_place)) { memory::Copy(boost::get(dst_place), dst_ptr, - boost::get(src_place), src_ptr, size, - ctx.stream()); + boost::get(src_place), src_ptr, size, 0); } -} + #endif +} template inline Tensor Tensor::Slice(const int& begin_idx, const int& end_idx) const { diff --git a/paddle/framework/eigen.h b/paddle/framework/eigen.h index 5f3358c69b3fbbbfcd97a96ab50fde3d8b9efad0..a4667cc51fadfc020d3211b7a82356db386fced1 100644 --- a/paddle/framework/eigen.h +++ b/paddle/framework/eigen.h @@ -80,5 +80,21 @@ struct EigenVector : public EigenTensor { } }; +template +struct EigenScalar { + // Scalar tensor (implemented as a rank-0 tensor) of scalar type T. + using Type = Eigen::TensorMap< + Eigen::TensorFixedSize, MajorType, IndexType>>; + using ConstType = Eigen::TensorMap< + Eigen::TensorFixedSize, MajorType, IndexType>>; + + static Type From(Tensor& tensor) { return Type(tensor.data()); } + + static ConstType From(const Tensor& tensor) { + return ConstType(tensor.data()); + } +}; + } // namespace framework } // namespace paddle diff --git a/paddle/framework/eigen_test.cc b/paddle/framework/eigen_test.cc index a9fa728e49a0dcc781e520a22c1ee5f921c4c733..dc1957691b1a202826e10e84c21ac8874df9e378 100644 --- a/paddle/framework/eigen_test.cc +++ b/paddle/framework/eigen_test.cc @@ -46,6 +46,17 @@ TEST(Eigen, Tensor) { } } +TEST(Eigen, ScalarFrom) { + Tensor t; + int* p = t.mutable_data(make_ddim({1}), platform::CPUPlace()); + *p = static_cast(100); + + EigenScalar::Type es = EigenScalar::From(t); + + ASSERT_EQ(0, es.dimension(0)); + ASSERT_EQ(100, es(0)); +} + TEST(Eigen, VectorFrom) { Tensor t; float* p = t.mutable_data(make_ddim({6}), platform::CPUPlace()); diff --git a/paddle/framework/grad_op_builder.cc b/paddle/framework/grad_op_builder.cc index 
6235be75f27dadb65de663ff1b3caf26a649f6cb..dd686cc78246f06cdc3ec7d013086863d7e8fac0 100644 --- a/paddle/framework/grad_op_builder.cc +++ b/paddle/framework/grad_op_builder.cc @@ -20,7 +20,7 @@ namespace framework { OperatorBase* GradOpBuilder::Build() { BuildOpInOutArgList(); - std::string grad_op_type = OpRegistry::grad_ops().at(op_->type_); + std::string grad_op_type = OpRegistry::grad_ops().at(op_.type_); OperatorBase* grad_op = OpRegistry::op_creators().at(grad_op_type)(); grad_op->type_ = grad_op_type; CompleteGradOp(grad_op); @@ -39,15 +39,15 @@ OpInOutArg* GradOpBuilder::BuildArg(const VarProto& var, } void GradOpBuilder::BuildOpInOutArgList() { - const OpProto& op_proto = OpRegistry::protos().at(op_->type_); - const auto& var_map = *(OpRegistry::VarIndexMaps().at(op_->type_)); + const OpProto& op_proto = OpRegistry::protos().at(op_.type_); + const auto& var_map = *(OpRegistry::VarIndexMaps().at(op_.type_)); const std::vector& in_format = - op_->attrs_.count("input_format") - ? op_->GetAttr>("input_format") + op_.attrs_.count("input_format") + ? op_.GetAttr>("input_format") : std::vector(); const std::vector& out_format = - op_->attrs_.count("output_format") - ? op_->GetAttr>("output_format") + op_.attrs_.count("output_format") + ? op_.GetAttr>("output_format") : std::vector(); for (const auto& var : op_proto.inputs()) { arg_list_.emplace_back( @@ -70,8 +70,7 @@ void GradOpBuilder::AddArgIntoGradOp(const OpInOutArg* arg, } (*varmap)[var_name] = idx++; size_t pre_sz = in_out.size(); - auto base_it = - arg->type_ == IN ? op_->inputs_.begin() : op_->outputs_.begin(); + auto base_it = arg->type_ == IN ? 
op_.inputs_.begin() : op_.outputs_.begin(); std::copy(base_it + arg->begin_idx_, base_it + arg->end_idx_, std::back_inserter(in_out)); if (is_grad) { @@ -83,7 +82,7 @@ void GradOpBuilder::AddArgIntoGradOp(const OpInOutArg* arg, } void GradOpBuilder::CompleteGradOp(OperatorBase* grad_op) const { - grad_op->attrs_ = op_->attrs_; + grad_op->attrs_ = op_.attrs_; grad_op->attrs_.erase("input_format"); grad_op->attrs_.erase("output_format"); VarIndexMap* grad_varmap = new VarIndexMap(); diff --git a/paddle/framework/grad_op_builder.h b/paddle/framework/grad_op_builder.h index 2ecf39479b4f4a51f89cd500caf851897df0e599..cc7a76f3726e00a08fbe06bca4c9b9f5bad466b4 100644 --- a/paddle/framework/grad_op_builder.h +++ b/paddle/framework/grad_op_builder.h @@ -29,7 +29,7 @@ class GradOpBuilder { using VarIndexMap = std::unordered_map; public: - GradOpBuilder(const OperatorBase* op) : op_(op) {} + GradOpBuilder(const OperatorBase& op) : op_(op) {} OperatorBase* Build(); private: @@ -40,7 +40,7 @@ class GradOpBuilder { std::vector& format, VarIndexMap* varmap, int& idx, bool is_grad) const; void CompleteGradOp(OperatorBase* grad_op) const; - const OperatorBase* op_; + const OperatorBase& op_; std::vector> arg_list_; }; diff --git a/paddle/framework/grad_op_builder_test.cc b/paddle/framework/grad_op_builder_test.cc index 288a7841cd7c9212d8fa230e38d49dfc26e76256..e9cf3b9798db2cbfb8d26259ae9a6741fbae8278 100644 --- a/paddle/framework/grad_op_builder_test.cc +++ b/paddle/framework/grad_op_builder_test.cc @@ -11,7 +11,7 @@ namespace framework { TEST(GradOpBuilder, AddTwo) { std::shared_ptr add_op( OpRegistry::CreateOp("add_two", {"x", "y"}, {"out"}, {})); - std::shared_ptr grad_add_op = OpRegistry::CreateGradOp(add_op); + std::shared_ptr grad_add_op = OpRegistry::CreateGradOp(*add_op); EXPECT_EQ(static_cast(grad_add_op->inputs_.size()), 4); EXPECT_EQ(static_cast(grad_add_op->outputs_.size()), 2); EXPECT_EQ(grad_add_op->Input("X"), "x"); diff --git 
a/paddle/framework/images/duplicate_op.graffle b/paddle/framework/images/duplicate_op.graffle new file mode 100644 index 0000000000000000000000000000000000000000..5979f792e252f028a615729215529c2be42d9165 Binary files /dev/null and b/paddle/framework/images/duplicate_op.graffle differ diff --git a/paddle/framework/images/duplicate_op.png b/paddle/framework/images/duplicate_op.png new file mode 100644 index 0000000000000000000000000000000000000000..f299c5d37f260a1bb0daec886f0a4ee1c1f31c92 Binary files /dev/null and b/paddle/framework/images/duplicate_op.png differ diff --git a/paddle/framework/images/duplicate_op2.graffle b/paddle/framework/images/duplicate_op2.graffle new file mode 100644 index 0000000000000000000000000000000000000000..2b658085d6a55d368c320051ba7f94ec2900f13c Binary files /dev/null and b/paddle/framework/images/duplicate_op2.graffle differ diff --git a/paddle/framework/images/duplicate_op2.png b/paddle/framework/images/duplicate_op2.png new file mode 100644 index 0000000000000000000000000000000000000000..c5588015d1450fd8c1bda3580680d884494868bb Binary files /dev/null and b/paddle/framework/images/duplicate_op2.png differ diff --git a/paddle/framework/net.h b/paddle/framework/net.h index 089c1355951f59d51db16d4b4bdce4282d6e5c25..acf1a69da9fd8adce1bd89367c882eade052e725 100644 --- a/paddle/framework/net.h +++ b/paddle/framework/net.h @@ -43,7 +43,7 @@ class NetOp : public OperatorBase { * Infer all the operators' input and output variables' shapes, will be called * before every mini-batch */ - void InferShape(const std::shared_ptr& scope) const override { + void InferShape(const Scope& scope) const override { for (auto& op : ops_) { op->InferShape(scope); } @@ -56,7 +56,7 @@ class NetOp : public OperatorBase { * scope will be used instead. If no OpContext is provicded, default context * will be used. 
*/ - void Run(const std::shared_ptr& scope, + void Run(const Scope& scope, const platform::DeviceContext& dev_ctx) const override { for (auto& op : ops_) { op->Run(scope, dev_ctx); @@ -68,9 +68,18 @@ class NetOp : public OperatorBase { */ void AddOp(const std::shared_ptr& op) { PADDLE_ENFORCE(!add_op_done_, "Cannot AddOp when this network is sealed"); + PADDLE_ENFORCE(op != nullptr, "Cannot Insert Null op"); ops_.push_back(op); } + void InsertOp(size_t pos, const std::shared_ptr& op) { + PADDLE_ENFORCE(!add_op_done_, + "Cannot InsertOp when this network is sealed"); + PADDLE_ENFORCE(op != nullptr, "Cannot Insert Null op"); + PADDLE_ENFORCE(pos <= ops_.size(), "Out of range"); + ops_.insert(ops_.begin() + pos, op); + } + void CompleteAddOp(bool calculate = true); std::string DebugString() const override; diff --git a/paddle/framework/net_op_test.cc b/paddle/framework/net_op_test.cc index 8048311fe54ee1827fb5b91577478a1d30803e43..f32e456e5d142bf8203f9ec03e8059772c4f5c99 100644 --- a/paddle/framework/net_op_test.cc +++ b/paddle/framework/net_op_test.cc @@ -3,11 +3,6 @@ #include #include -USE_OP(add_two); -USE_OP(mul); -USE_OP(sigmoid); -USE_OP(softmax); - namespace paddle { namespace framework { @@ -16,16 +11,22 @@ static int run_cnt = 0; class TestOp : public OperatorBase { public: - void InferShape( - const std::shared_ptr& scope) const override { + void InferShape(const framework::Scope& scope) const override { ++infer_shape_cnt; } - void Run(const std::shared_ptr& scope, + void Run(const framework::Scope& scope, const paddle::platform::DeviceContext& dev_ctx) const override { ++run_cnt; } }; +class EmptyOp : public OperatorBase { + public: + void InferShape(const Scope& scope) const override {} + void Run(const Scope& scope, + const platform::DeviceContext& dev_ctx) const override {} +}; + template void AssertSameVectorWithoutOrder(const std::vector& expected, const std::vector& actual) { @@ -62,7 +63,7 @@ TEST(OpKernel, all) { ASSERT_EQ(1UL, tmp_idx.size()); 
ASSERT_EQ("y", net->outputs_[tmp_idx[0]]); - auto scope = std::make_shared(); + Scope scope; platform::CPUDeviceContext dev_ctx; net->InferShape(scope); @@ -72,20 +73,17 @@ TEST(OpKernel, all) { ASSERT_THROW(net->AddOp(op2), paddle::platform::EnforceNotMet); } -//! TODO(yuyang18): Refine Backward Op. -// TEST(AddBackwardOp, TestGradOp) { -// auto net = std::make_shared(); -// ASSERT_NE(net, nullptr); -// net->AddOp(framework::OpRegistry::CreateOp("mul", {"X", "Y"}, {"Out"}, {})); -// net->AddOp( -// framework::OpRegistry::CreateOp("add_two", {"X", "Y"}, {"Out"}, {})); -// net->AddOp(framework::OpRegistry::CreateOp("add_two", {"X", "Y"}, {""}, -// {})); -// auto grad_ops = AddBackwardOp(net); -// for (auto& op : grad_ops->ops_) { -// op->DebugString(); -// } -//} +TEST(Net, insert_op) { + NetOp net; + auto op1 = std::make_shared(); + op1->inputs_ = {"x", "w1", "b1"}; + op1->outputs_ = {"y"}; + net.AddOp(op1); + net.InsertOp(0, op1); + ASSERT_EQ(2UL, net.ops_.size()); + net.InsertOp(2, op1); + ASSERT_EQ(3UL, net.ops_.size()); +} } // namespace framework } // namespace paddle diff --git a/paddle/framework/op_registry.h b/paddle/framework/op_registry.h index 384f0f631dd9b9a4dd7c0c628340afe668bc248f..f10c9297981a4c6aefc6c2072d0ac2b8e562a7a0 100644 --- a/paddle/framework/op_registry.h +++ b/paddle/framework/op_registry.h @@ -86,43 +86,46 @@ class OpProtoAndCheckerMaker { } protected: - void AddInput(const std::string& name, const std::string& comment, - bool multiple = false, bool ignore_gradient = false) { + struct VariableBuilder { + VarProto* var_; + std::function on_multiple_; + std::function on_temporary_; + + VariableBuilder& SetMultiple() { + var_->set_multiple(true); + on_multiple_(); + return *this; + } + + VariableBuilder& SetTemporary() { + PADDLE_ENFORCE(bool(on_temporary_), "Cannot set temporary"); + var_->set_temporary(true); + on_temporary_(); + return *this; + } + + VariableBuilder& IgnoreGradient() { + var_->set_ignore_gradient(true); + return *this; + } 
+ }; + + VariableBuilder AddInput(const std::string& name, + const std::string& comment) { auto input = proto_->mutable_inputs()->Add(); *input->mutable_name() = name; *input->mutable_comment() = comment; - input->set_ignore_gradient(ignore_gradient); - input->set_multiple(multiple); - if (multiple) { - SetHasMultipleInput(); - } - } - - void AddInputs(const std::string& name, const std::string& comment, - bool ignore_gradient = false) { - AddInput(name, comment, true, ignore_gradient); + return VariableBuilder{input, [=] { this->SetHasMultipleInput(); }, + nullptr}; } - void AddOutput(const std::string& name, const std::string& comment, - bool temporary = false, bool multiple = false, - bool ignore_gradient = false) { + VariableBuilder AddOutput(const std::string& name, + const std::string& comment) { auto output = proto_->mutable_outputs()->Add(); *output->mutable_name() = name; *output->mutable_comment() = comment; - output->set_ignore_gradient(ignore_gradient); - output->set_multiple(multiple); - if (multiple) { - SetHasMultipleOutput(); - } - output->set_temporary(temporary); - if (temporary) { - SetHasTemporaryOutput(); - } - } - - void AddOutputs(const std::string& name, const std::string& comment, - bool temporary = false, bool ignore_gradient = false) { - AddOutput(name, comment, temporary, true, ignore_gradient); + return VariableBuilder{output, [=] { this->SetHasMultipleOutput(); }, + [=] { this->SetHasTemporaryOutput(); }}; } template @@ -300,9 +303,10 @@ class OpRegistry { return CreateOp(op_desc.type(), inputs, outputs, attrs); } - static std::shared_ptr CreateGradOp( - std::shared_ptr op) { - GradOpBuilder builder(op.get()); + static std::shared_ptr CreateGradOp(const OperatorBase& op) { + PADDLE_ENFORCE(!op.IsNetOp(), + "Use framework::Backward to get backward ops"); + GradOpBuilder builder(op); std::shared_ptr grad_op(builder.Build()); grad_op->Init(); return grad_op; diff --git a/paddle/framework/op_registry_test.cc 
b/paddle/framework/op_registry_test.cc index 2ef781bf8672c8aa53ae32a44f1ea61973f3792c..9894928a7aa19bc6c7ad8b230562fb9a681cfebd 100644 --- a/paddle/framework/op_registry_test.cc +++ b/paddle/framework/op_registry_test.cc @@ -7,9 +7,9 @@ namespace paddle { namespace framework { class CosineOp : public OperatorBase { public: - void Run(const std::shared_ptr& scope, + void Run(const Scope& scope, const platform::DeviceContext& dev_ctx) const override {} - void InferShape(const std::shared_ptr& scope) const override {} + void InferShape(const Scope& scope) const override {} }; class CosineOpProtoAndCheckerMaker : public OpProtoAndCheckerMaker { @@ -27,8 +27,8 @@ class CosineOpProtoAndCheckerMaker : public OpProtoAndCheckerMaker { class MyTestOp : public OperatorBase { public: - void InferShape(const std::shared_ptr& scope) const override {} - void Run(const std::shared_ptr& scope, + void InferShape(const Scope& scope) const override {} + void Run(const Scope& scope, const platform::DeviceContext& dev_ctx) const override {} }; @@ -36,9 +36,8 @@ class MyTestOpProtoAndCheckerMaker : public OpProtoAndCheckerMaker { public: MyTestOpProtoAndCheckerMaker(OpProto* proto, OpAttrChecker* op_checker) : OpProtoAndCheckerMaker(proto, op_checker) { - AddInputs("input", "input of cosine op"); - AddOutput("output", "output of cosine op", - /*temporary*/ true); + AddInput("input", "input of cosine op").SetMultiple(); + AddOutput("output", "output of cosine op").SetTemporary(); auto my_checker = [](int i) { PADDLE_ENFORCE(i % 2 == 0, "'test_attr' must be even!"); }; @@ -69,7 +68,7 @@ TEST(OpRegistry, CreateOp) { std::shared_ptr op = paddle::framework::OpRegistry::CreateOp(op_desc); - auto scope = std::make_shared(); + paddle::framework::Scope scope; paddle::platform::CPUDeviceContext dev_ctx; op->Run(scope, dev_ctx); float scale_get = op->GetAttr("scale"); @@ -111,7 +110,7 @@ TEST(OpRegistry, DefaultValue) { std::shared_ptr op = paddle::framework::OpRegistry::CreateOp(op_desc); - auto 
scope = std::make_shared(); + paddle::framework::Scope scope; paddle::platform::CPUDeviceContext dev_ctx; op->Run(scope, dev_ctx); ASSERT_EQ(op->GetAttr("scale"), 1.0); @@ -173,7 +172,7 @@ TEST(OpRegistry, CustomChecker) { SetInputFormat(&op_desc); auto op = paddle::framework::OpRegistry::CreateOp(op_desc); paddle::platform::CPUDeviceContext dev_ctx; - auto scope = std::make_shared(); + paddle::framework::Scope scope; op->Run(scope, dev_ctx); int test_attr = op->GetAttr("test_attr"); ASSERT_EQ(test_attr, 4); diff --git a/paddle/framework/operator.cc b/paddle/framework/operator.cc index 1e57e9a20f3eecfac266d67276347ad4b5b780f9..cfe9cba308556475ef64b45e7178dfc418761598 100644 --- a/paddle/framework/operator.cc +++ b/paddle/framework/operator.cc @@ -20,7 +20,7 @@ namespace paddle { namespace framework { template <> -Eigen::DefaultDevice* KernelContext::GetEigenDevice< +Eigen::DefaultDevice* ExecutionContext::GetEigenDevice< platform::CPUPlace, Eigen::DefaultDevice>() const { return device_context_.get_eigen_device(); } @@ -28,28 +28,33 @@ Eigen::DefaultDevice* KernelContext::GetEigenDevice< #ifndef PADDLE_ONLY_CPU template <> Eigen::GpuDevice* -KernelContext::GetEigenDevice() const { +ExecutionContext::GetEigenDevice() const { return device_context_.get_eigen_device(); } #endif const std::string& OperatorBase::Input(const std::string& name) const { + PADDLE_ENFORCE(in_out_idxs_ != nullptr, + "Input Output Indices could not be nullptr"); auto it = in_out_idxs_->find(name); PADDLE_ENFORCE(it != in_out_idxs_->end(), "no key [%s] in in_out_idxs_", name); - if (attrs_.count("input_format") == 0) { - return inputs_[it->second]; + return inputs_.at((size_t)it->second); } else { const auto& input_format = GetAttr>("input_format"); int idx = input_format[it->second]; - return inputs_.at(idx); + return inputs_.at((size_t)idx); } } std::vector OperatorBase::Inputs(const std::string& name) const { + PADDLE_ENFORCE(in_out_idxs_ != nullptr, "IO Idx could not be nullptr"); auto 
input_format = GetAttr>("input_format"); auto offset = in_out_idxs_->at(name); + PADDLE_ENFORCE(input_format.at(static_cast(offset) + 1) <= + static_cast(inputs_.size()), + "Input Out Of Range"); return std::vector{ inputs_.begin() + input_format.at(offset), @@ -57,23 +62,26 @@ std::vector OperatorBase::Inputs(const std::string& name) const { } const std::string& OperatorBase::Output(const std::string& name) const { + PADDLE_ENFORCE(in_out_idxs_ != nullptr, "InOut Indice could not be nullptr"); auto it = in_out_idxs_->find(name); PADDLE_ENFORCE(it != in_out_idxs_->end(), "no key [%s] in in_out_idxs_", name); - if (attrs_.count("output_format") == 0) { - return outputs_[it->second]; + return outputs_.at((size_t)it->second); } else { const auto& output_format = GetAttr>("output_format"); int idx = output_format[it->second]; - return outputs_.at(idx); + return outputs_.at((size_t)idx); } } std::vector OperatorBase::Outputs(const std::string& name) const { + PADDLE_ENFORCE(in_out_idxs_ != nullptr, "InOut Indice could not be nullptr"); auto output_format = GetAttr>("output_format"); auto offset = in_out_idxs_->at(name); - + PADDLE_ENFORCE(output_format.at(static_cast(offset) + 1) <= + static_cast(outputs_.size()), + "Output Out of Range"); return std::vector{ outputs_.begin() + output_format.at(offset), outputs_.begin() + output_format.at(offset + 1)}; @@ -99,5 +107,11 @@ std::string OperatorBase::DebugString() const { return ss.str(); } +void OperatorBase::Rename(const std::string& old_name, + const std::string& new_name) { + std::replace(inputs_.begin(), inputs_.end(), old_name, new_name); + std::replace(outputs_.begin(), outputs_.end(), old_name, new_name); +} + } // namespace framework } // namespace paddle diff --git a/paddle/framework/operator.h b/paddle/framework/operator.h index 0a8c82ee47521713fa96cb423ceca4de858c260c..0832a663dd01fe2921366d70599bc867e73af47c 100644 --- a/paddle/framework/operator.h +++ b/paddle/framework/operator.h @@ -14,6 +14,7 @@ 
limitations under the License. */ #pragma once +#include #include #include #include @@ -31,22 +32,9 @@ limitations under the License. */ namespace paddle { namespace framework { -template -struct EigenDeviceConverter; - -template <> -struct EigenDeviceConverter { - using EigenDeviceType = Eigen::DefaultDevice; -}; - -#ifndef PADDLE_ONLY_CPU -template <> -struct EigenDeviceConverter { - using EigenDeviceType = Eigen::GpuDevice; -}; -#endif - class OperatorBase; +class InferShapeContext; +class ExecutionContext; /** * OperatorBase has the basic element that Net will call to do computation. * Only CreateOperator from OpRegistry will new Operator directly. User @@ -67,6 +55,9 @@ class OperatorBase { /// e.g. Variable "x@GRAD" is the gradient of varibale "x". static std::string GRAD_VAR_SUFFIX() { return "@GRAD"; } + /// Variables with this suffix are supposed to be filled up with zeros. + static std::string ZERO_VAR_SUFFIX() { return "@ZERO"; } + virtual ~OperatorBase() {} template @@ -84,16 +75,20 @@ class OperatorBase { /// InferShape infer the size of Variables used by this Operator with /// information inside scope - virtual void InferShape(const std::shared_ptr& scope) const = 0; + virtual void InferShape(const Scope& scope) const = 0; /// Net will call this function to Run an op. - virtual void Run(const std::shared_ptr& scope, + virtual void Run(const Scope& scope, const platform::DeviceContext& dev_ctx) const = 0; virtual bool IsNetOp() const { return false; } + /// rename inputs outputs name + void Rename(const std::string& old_name, const std::string& new_name); + //! Get a input with argument's name described in `op_proto` const std::string& Input(const std::string& name) const; + //! Get a input which has multiple variables. //! TODO add a vector_view to prevent memory copy. 
std::vector Inputs(const std::string& name) const; @@ -105,53 +100,156 @@ class OperatorBase { public: std::string type_; + // NOTE: in case of OpGrad, inputs_ contains: + // I (Inputs) + // O (Outputs) + // OG (Output Gradients) std::vector inputs_; + // NOTE: in case of OpGrad, outputs_ contains + // IG (Inputs Gradients) std::vector outputs_; AttributeMap attrs_; // store the arguments' offset described in op_desc. std::shared_ptr> in_out_idxs_; }; -class KernelContext { +class OperatorContext { public: - KernelContext(const OperatorBase* op, const std::shared_ptr& scope, - const platform::DeviceContext& device_context) - : op_(*op), scope_(scope), device_context_(device_context) {} + OperatorContext(const OperatorBase* op, const Scope& scope) + : op_(*op), scope_(scope) {} + + size_t InputSize() const { return op_.inputs_.size(); } - const Variable* Input(int index) const { - return scope_->GetVariable(op_.inputs_[index]); + size_t OutputSize() const { return op_.outputs_.size(); } + + const Variable* InputVar(const size_t index) const { + return scope_.FindVar(op_.inputs_.at(index)); } - Variable* Output(int index) const { - return scope_->GetVariable(op_.outputs_[index]); + Variable* OutputVar(const size_t index) const { + return scope_.FindVar(op_.outputs_.at(index)); } - const Variable* Input(const std::string& name) const { - return scope_->GetVariable(op_.Input(name)); + const Variable* InputVar(const std::string& name) const { + return scope_.FindVar(op_.Input(name)); } - const Variable* Output(const std::string& name) const { - return scope_->GetVariable(op_.Output(name)); + Variable* OutputVar(const std::string& name) const { + return scope_.FindVar(op_.Output(name)); } - const std::vector Inputs(const std::string& name) const { + const std::vector MultiInputVar( + const std::string& name) const { auto names = op_.Inputs(name); std::vector res; + res.reserve(names.size()); std::transform( - names.begin(), names.end(), res.begin(), - [this](const 
std::string& name) { return scope_->GetVariable(name); }); + names.begin(), names.end(), std::back_inserter(res), + [this](const std::string& name) { return scope_.FindVar(name); }); return res; } - const std::vector Outputs(const std::string& name) const { + std::vector MultiOutputVar(const std::string& name) const { auto names = op_.Outputs(name); std::vector res; + res.reserve(names.size()); std::transform( - names.begin(), names.end(), res.begin(), - [this](const std::string& name) { return scope_->GetVariable(name); }); + names.begin(), names.end(), std::back_inserter(res), + [this](const std::string& name) { return scope_.FindVar(name); }); + return res; + } + + template + const T* Input(const size_t index) const { + auto var = InputVar(index); + PADDLE_ENFORCE(var != nullptr, "Input(%d) should not be nullptr", index); + return &var->Get(); + } + + template + T* Output(const size_t index) const { + auto var = OutputVar(index); + PADDLE_ENFORCE(var != nullptr, "Output(%d) should not be nullptr", index); + return var->GetMutable(); + } + + template + const T* Input(const std::string& name) const { + auto var = InputVar(name); + PADDLE_ENFORCE(var != nullptr, "Input(%s) should not be nullptr", name); + return &var->Get(); + } + + template + T* Output(const std::string& name) const { + auto var = OutputVar(name); + PADDLE_ENFORCE(var != nullptr, "Output(%s) should not be nullptr", name); + return var->GetMutable(); + } + + template + const std::vector MultiInput(const std::string& name) const { + auto names = op_.Inputs(name); + std::vector res; + res.reserve(names.size()); + std::transform(names.begin(), names.end(), std::back_inserter(res), + [&](const std::string& sub_name) { + auto var = scope_.FindVar(sub_name); + PADDLE_ENFORCE(var != nullptr, + "MultiInput(%s:%s) should not be nullptr", + name, sub_name); + return &var->Get(); + }); + return res; + } + + template + std::vector MultiOutput(const std::string& name) const { + auto names = op_.Outputs(name); + 
std::vector res; + res.reserve(names.size()); + std::transform(names.begin(), names.end(), std::back_inserter(res), + [&](const std::string& sub_name) { + auto var = scope_.FindVar(sub_name); + PADDLE_ENFORCE(var != nullptr, + "MultiOutput(%s:%s) should not be nullptr", + name, sub_name); + return var->GetMutable(); + }); return res; } + const OperatorBase& op_; + const Scope& scope_; +}; + +class InferShapeContext : public OperatorContext { + public: + InferShapeContext(const OperatorBase* op, const Scope& scope) + : OperatorContext(op, scope) {} +}; + +template +struct EigenDeviceConverter; + +template <> +struct EigenDeviceConverter { + using EigenDeviceType = Eigen::DefaultDevice; +}; + +#ifndef PADDLE_ONLY_CPU +template <> +struct EigenDeviceConverter { + using EigenDeviceType = Eigen::GpuDevice; +}; +#endif + +class ExecutionContext : public OperatorContext { + public: + ExecutionContext(const OperatorBase* op, const Scope& scope, + const platform::DeviceContext& device_context) + : OperatorContext(op, scope), device_context_(device_context) {} + template ::EigenDeviceType> @@ -159,38 +257,23 @@ class KernelContext { platform::Place GetPlace() const { return device_context_.GetPlace(); } - const OperatorBase& op_; - const std::shared_ptr& scope_; const platform::DeviceContext& device_context_; }; class OpKernel { public: /** - * KernelContext is the only parameter of Kernel Run function. + * ExecutionContext is the only parameter of Kernel Run function. * Run will get input/output variables, state such as momentum and * device resource such as CUDA stream, cublas handle, etc. from - * KernelContext. User should construct it before run the Operator. + * ExecutionContext. User should construct it before run the Operator. 
*/ - virtual void Compute(const KernelContext& context) const = 0; + virtual void Compute(const ExecutionContext& context) const = 0; virtual ~OpKernel() {} }; -template -struct VarToTensor {}; - -template <> -struct VarToTensor { - Tensor* operator()(Variable* var) { return var->GetMutable(); } -}; - -template <> -struct VarToTensor { - const Tensor* operator()(Variable* var) { return &var->Get(); } -}; - class OperatorWithKernel : public OperatorBase { public: struct OpKernelKey { @@ -216,10 +299,14 @@ class OperatorWithKernel : public OperatorBase { using OpKernelMap = std::unordered_map, OpKernelHash>; - void Run(const std::shared_ptr& scope, + void InferShape(const Scope& scope) const { + InferShape(InferShapeContext(this, scope)); + } + + void Run(const Scope& scope, const platform::DeviceContext& dev_ctx) const final { auto& opKernel = AllOpKernels().at(type_).at(OpKernelKey(dev_ctx)); - opKernel->Compute(KernelContext(this, scope, dev_ctx)); + opKernel->Compute(ExecutionContext(this, scope, dev_ctx)); } static std::unordered_map& @@ -228,34 +315,8 @@ class OperatorWithKernel : public OperatorBase { return g_all_op_kernels; } - void InferShape(const std::shared_ptr& scope) const final { - std::vector ins; - VarNamesToTensors(scope, inputs_, &ins); - std::vector outs; - VarNamesToTensors(scope, outputs_, &outs); - InferShape(ins, outs); - }; - - private: - template - void VarNamesToTensors(const std::shared_ptr& scope, - const std::vector& var_names, - std::vector* container) const { - container->reserve(var_names.size()); - VarToTensor convert; - for (auto& name : var_names) { - auto var = scope->GetVariable(name); - if (var != nullptr) { - container->push_back(convert(var)); - } else { - container->push_back(nullptr); - } - } - } - protected: - virtual void InferShape(const std::vector& inputs, - const std::vector& outputs) const = 0; + virtual void InferShape(const InferShapeContext& ctx) const = 0; }; } // namespace framework diff --git 
a/paddle/framework/operator_test.cc b/paddle/framework/operator_test.cc index 3fae356c3e5d5b44271440b66d6923fd4994b937..6a6a802b7da05c37a317540030836baa28a89cd7 100644 --- a/paddle/framework/operator_test.cc +++ b/paddle/framework/operator_test.cc @@ -24,15 +24,15 @@ static int op_run_num = 0; class OpWithoutKernelTest : public OperatorBase { public: void Init() override { x = 1; } - void InferShape(const std::shared_ptr& scope) const override {} - void Run(const std::shared_ptr& scope, + void InferShape(const Scope& scope) const override {} + void Run(const Scope& scope, const platform::DeviceContext& dev_ctx) const override { op_run_num++; ASSERT_EQ((int)inputs_.size(), 1); ASSERT_EQ((int)outputs_.size(), 1); - ASSERT_EQ(scope->GetVariable(inputs_[0]), nullptr); + ASSERT_EQ(scope.FindVar(inputs_[0]), nullptr); ASSERT_EQ(x, 1); - ASSERT_NE(scope->GetVariable(outputs_[0]), nullptr); + ASSERT_NE(scope.FindVar(outputs_[0]), nullptr); } public: @@ -68,11 +68,12 @@ TEST(OperatorBase, all) { attr->set_f(3.14); paddle::platform::CPUDeviceContext device_context; - auto scope = std::make_shared(); + paddle::framework::Scope scope; auto op = paddle::framework::OpRegistry::CreateOp(op_desc); - scope->CreateVariable("OUT1"); + scope.NewVar("OUT1"); ASSERT_EQ(paddle::framework::op_run_num, 0); + op->InferShape(scope); op->Run(scope, device_context); ASSERT_EQ(paddle::framework::op_run_num, 1); } @@ -97,14 +98,13 @@ static int cpu_kernel_run_num = 0; class OpWithKernelTest : public OperatorWithKernel { protected: - void InferShape(const std::vector& inputs, - const std::vector& outputs) const override {} + void InferShape(const framework::InferShapeContext& ctx) const override {} }; template class CPUKernelTest : public OpKernel { public: - void Compute(const KernelContext& ctx) const { + void Compute(const ExecutionContext& ctx) const { std::cout << "this is cpu kernel" << std::endl; std::cout << ctx.op_.DebugString() << std::endl; cpu_kernel_run_num++; @@ -117,12 +117,12 @@ 
class CPUKernelTest : public OpKernel { class OperatorMultiInputsTest : public OperatorBase { public: void Init() override { x = 1; } - void InferShape(const std::shared_ptr& scope) const override {} - void Run(const std::shared_ptr& scope, + void InferShape(const Scope& scope) const override {} + void Run(const Scope& scope, const platform::DeviceContext& dev_ctx) const override { - ASSERT_EQ(scope->GetVariable(inputs_[0]), nullptr); + ASSERT_EQ(scope.FindVar(inputs_[0]), nullptr); ASSERT_EQ(x, 1); - ASSERT_NE(scope->GetVariable(outputs_[0]), nullptr); + ASSERT_NE(scope.FindVar(outputs_[0]), nullptr); ASSERT_EQ(Input("x"), "IN1"); ASSERT_EQ(Input("y"), "OUT1"); } @@ -137,9 +137,9 @@ class OpKernelTestMultiInputsProtoAndCheckerMaker OpKernelTestMultiInputsProtoAndCheckerMaker(OpProto* proto, OpAttrChecker* op_checker) : OpProtoAndCheckerMaker(proto, op_checker) { - AddInputs("xs", "inputs of test op"); + AddInput("xs", "inputs of test op").SetMultiple(); AddInput("k", "input of test op"); - AddOutputs("ys", "outputs of test op"); + AddOutput("ys", "outputs of test op").SetMultiple(); AddAttr("scale", "scale of cosine op") .SetDefault(1.0) .LargerThan(0.0); @@ -149,13 +149,31 @@ class OpKernelTestMultiInputsProtoAndCheckerMaker class CPUKernalMultiInputsTest : public OpKernel { public: - void Compute(const KernelContext& ctx) const { + void Compute(const ExecutionContext& ctx) const { auto xs = ctx.op_.Inputs("xs"); ASSERT_EQ(xs.size(), 3UL); ASSERT_EQ(xs[0], "x0"); ASSERT_EQ(xs[1], "x1"); ASSERT_EQ(xs[2], "x2"); + auto inVar0 = ctx.MultiInputVar("xs"); + ASSERT_EQ(inVar0.size(), 3); + + auto intVar1 = ctx.InputVar("k"); + ASSERT_NE(intVar1, nullptr); + + auto outVar0 = ctx.MultiOutputVar("ys"); + ASSERT_EQ(outVar0.size(), 2); + + auto inTensor0 = ctx.MultiInput("xs"); + ASSERT_EQ(inTensor0.size(), 3); + + auto intTensor1 = ctx.Input("k"); + ASSERT_NE(intTensor1, nullptr); + + auto outTensor0 = ctx.MultiOutput("ys"); + ASSERT_EQ(outTensor0.size(), 2); + auto k = 
ctx.op_.Input("k"); ASSERT_EQ(k, "k0"); @@ -186,7 +204,7 @@ TEST(OpKernel, all) { attr->set_f(3.14); paddle::platform::CPUDeviceContext cpu_device_context; - auto scope = std::make_shared(); + paddle::framework::Scope scope; auto op = paddle::framework::OpRegistry::CreateOp(op_desc); ASSERT_EQ(paddle::framework::cpu_kernel_run_num, 0); @@ -232,7 +250,13 @@ TEST(OpKernel, multi_inputs) { output_format->Add(2); // y1 paddle::platform::CPUDeviceContext cpu_device_context; - auto scope = std::make_shared(); + paddle::framework::Scope scope; + scope.NewVar("x0")->GetMutable(); + scope.NewVar("x1")->GetMutable(); + scope.NewVar("x2")->GetMutable(); + scope.NewVar("k0")->GetMutable(); + scope.NewVar("y0")->GetMutable(); + scope.NewVar("y1")->GetMutable(); auto op = paddle::framework::OpRegistry::CreateOp(op_desc); op->Run(scope, cpu_device_context); diff --git a/paddle/framework/scope.cc b/paddle/framework/scope.cc new file mode 100644 index 0000000000000000000000000000000000000000..080b4ac621c1b8c0d4b4e7b26f394cf2be263894 --- /dev/null +++ b/paddle/framework/scope.cc @@ -0,0 +1,66 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
*/ + +#include "paddle/framework/scope.h" +#include "paddle/string/printf.h" + +namespace paddle { +namespace framework { + +Scope::~Scope() { + DropKids(); + for (auto& kv : vars_) delete kv.second; +} + +Scope& Scope::NewScope() const { + kids_.push_back(new Scope(this)); + return *kids_.back(); +} + +Variable* Scope::NewVar(const std::string& name) { + auto iter = vars_.find(name); + if (iter != vars_.end()) { + return iter->second; + } + Variable* v = new Variable(); + vars_[name] = v; + v->name_ = &(vars_.find(name)->first); + return v; +} + +Variable* Scope::NewVar() { + return NewVar(string::Sprintf("%p.%d", this, vars_.size())); +} + +Variable* Scope::FindVar(const std::string& name) const { + auto it = vars_.find(name); + if (it != vars_.end()) return it->second; + return (parent_ == nullptr) ? nullptr : parent_->FindVar(name); +} + +const Scope* Scope::FindScope(const Variable* var) const { + for (auto& kv : vars_) { + if (kv.second == var) { + return this; + } + } + return (parent_ == nullptr) ? nullptr : parent_->FindScope(var); +} +void Scope::DropKids() { + for (Scope* s : kids_) delete s; + kids_.clear(); +} + +} // namespace framework +} // namespace paddle diff --git a/paddle/framework/scope.h b/paddle/framework/scope.h index 4faaf841440ba30b79c83d09fea977186bd0270a..2ba3f8ed355b48800cfa4180e4e8a94f2c9958a9 100644 --- a/paddle/framework/scope.h +++ b/paddle/framework/scope.h @@ -14,9 +14,9 @@ limitations under the License. */ #pragma once +#include #include #include -#include #include "paddle/framework/variable.h" @@ -35,73 +35,42 @@ class Scope; */ class Scope { public: - /** - * @brief Initialize s Scope without parent. - */ Scope() {} + ~Scope(); - /** - * @brief Initialize a Scope with parent. - */ - explicit Scope(const std::shared_ptr& parent) : parent_(parent) {} - - /** - * @brief Create Variable - * - * Create Variable in this Scope. Return the exist one if Variable already - * been created. 
- */ - Variable* CreateVariable(const std::string& name) { - auto var = GetVariable(name); - if (var) { - return var; - } else { - auto ptr = new Variable(); - name_to_var_[name] = std::unique_ptr(ptr); - var_to_name_[ptr] = name; - return GetVariable(name); - } - } - - /** - * @brief Get Variable. - * - * Get Variable from this Scope, this function will recursive find Variable - * from it's parent scope. Return nullptr if not found. - */ - Variable* GetVariable(const std::string& name) const { - auto it = name_to_var_.find(name); - if (it != name_to_var_.end()) { - return it->second.get(); - } else if (parent_ != nullptr) { - return parent_->GetVariable(name); - } else { - return nullptr; - } - } - - /** - * @brief If this scope has a Var named name. - * - * Find if there is a Variable in this scope and it's parent scope - */ - bool HasVariable(const std::string& name) const { - return (name_to_var_.find(name) != name_to_var_.end() || - (parent_ && parent_->HasVariable(name))); - } - - std::string GetVariableName(Variable* const var) const { - try { - return var_to_name_.at(var); - } catch (...) { - return ""; - } - } + // Disable Copy, Assign, Move. + Scope(const Scope& other) = delete; + Scope& operator=(const Scope& other) = delete; + Scope(Scope&& other) = delete; + + /// Create a sub-scope. Returns a reference other than a pointer so + /// to prevent from manual deletion. + /// Mark it to const because that new kid scope cannot change parent scope. + Scope& NewScope() const; + + /// Create a variable with given name if it doesn't exist. + Variable* NewVar(const std::string& name); + + /// Create a variable with a scope-unique name. + Variable* NewVar(); + + /// Find a variable in the scope or any of its ancestors. Returns + /// nullptr if cannot find. + Variable* FindVar(const std::string& name) const; + + /// Find the scope or an ancestor scope that contains the given variable. 
+ const Scope* FindScope(const Variable* var) const; + + /// Drop all kids scopes belonged to this scope. + void DropKids(); private: - std::unordered_map var_to_name_; - std::unordered_map> name_to_var_; - std::shared_ptr parent_{nullptr}; + // Call Scope::NewScope for a sub-scope. + explicit Scope(Scope const* parent) : parent_(parent) {} + + std::unordered_map vars_; + mutable std::list kids_; + Scope const* parent_{nullptr}; }; } // namespace framework diff --git a/paddle/framework/scope_test.cc b/paddle/framework/scope_test.cc index ff069c7be002e9bcfd63225c3d80aa958935ba14..9d51e355b0f6336d2f875ff2d77266b261baf5ac 100644 --- a/paddle/framework/scope_test.cc +++ b/paddle/framework/scope_test.cc @@ -15,49 +15,42 @@ limitations under the License. */ #include "paddle/framework/scope.h" #include "gtest/gtest.h" -TEST(Scope, Create) { - using paddle::framework::Scope; - using paddle::framework::Variable; +using paddle::framework::Scope; +using paddle::framework::Variable; - auto scope = std::make_shared(); +TEST(Scope, VarsShadowing) { + Scope s; + Scope& ss1 = s.NewScope(); + Scope& ss2 = s.NewScope(); - Variable* var0 = scope->CreateVariable(""); - EXPECT_NE(var0, nullptr); + Variable* v0 = s.NewVar("a"); + Variable* v1 = ss1.NewVar("a"); - /// GetVariable will return nullptr if not exist. - Variable* var1 = scope->GetVariable("a"); - EXPECT_EQ(var1, nullptr); + EXPECT_NE(v0, v1); - /// CreateVariable will return one. - Variable* var2 = scope->CreateVariable("a"); - EXPECT_NE(var2, nullptr); - - /// Get the created variable. - Variable* var3 = scope->GetVariable("a"); - EXPECT_EQ(var2, var3); + EXPECT_EQ(v0, s.FindVar("a")); + EXPECT_EQ(v1, ss1.FindVar("a")); + EXPECT_EQ(v0, ss2.FindVar("a")); +} - /// CreateVariable will just return the variable if it's - /// already exist. 
- Variable* var4 = scope->CreateVariable("a"); - EXPECT_EQ(var4, var2); +TEST(Scope, FindVar) { + Scope s; + Scope& ss = s.NewScope(); - EXPECT_EQ("a", scope->GetVariableName(var4)); - Scope scope2; - auto var = scope2.CreateVariable("tmp"); - EXPECT_EQ("", scope->GetVariableName(var)); -} + EXPECT_EQ(nullptr, s.FindVar("a")); + EXPECT_EQ(nullptr, ss.FindVar("a")); -TEST(Scope, Parent) { - using paddle::framework::Scope; - using paddle::framework::Variable; + ss.NewVar("a"); - auto parent_scope = std::make_shared(); - auto scope = std::make_shared(parent_scope); + EXPECT_EQ(nullptr, s.FindVar("a")); + EXPECT_NE(nullptr, ss.FindVar("a")); +} - Variable* var0 = parent_scope->CreateVariable("a"); - EXPECT_NE(var0, nullptr); +TEST(Scope, FindScope) { + Scope s; + Scope& ss = s.NewScope(); + Variable* v = s.NewVar("a"); - /// GetVariable will get Variable from parent scope if exist. - Variable* var1 = scope->GetVariable("a"); - EXPECT_EQ(var0, var1); + EXPECT_EQ(&s, s.FindScope(v)); + EXPECT_EQ(&s, ss.FindScope(v)); } diff --git a/paddle/framework/tensor.h b/paddle/framework/tensor.h index d3f56b31cd350fac746b8fd5a37f278b26db3e7d..76070f636b0971f4a136042e056c59adb5dc2d40 100644 --- a/paddle/framework/tensor.h +++ b/paddle/framework/tensor.h @@ -94,14 +94,7 @@ class Tensor { * @note CopyFrom supports CPU <-> GPU, GPU <-> GPU. */ template - inline void CopyFrom(const Tensor& src, - const platform::CPUDeviceContext& ctx); - -#ifndef PADDLE_ONLY_CPU - template - inline void CopyFrom(const Tensor& src, - const platform::CUDADeviceContext& ctx); -#endif + inline void CopyFrom(const Tensor& src, const platform::Place& dst_place); /** * @brief Return the slice of the tensor. 
@@ -129,13 +122,16 @@ class Tensor { virtual platform::Place place() const = 0; }; - template + template struct PlaceholderImpl : public Placeholder { - PlaceholderImpl(PlaceType place, size_t size) + PlaceholderImpl(Place place, size_t size) : ptr_(static_cast(memory::Alloc(place, size)), - memory::PODDeleter(place)), + memory::PODDeleter(place)), place_(place), - size_(size) {} + size_(size) { + PADDLE_ENFORCE(ptr_ != nullptr, "Insufficient %s memory to allocation.", + is_cpu_place(place_) ? "CPU" : "GPU"); + } virtual size_t size() const { return size_; } virtual platform::Place place() const { return place_; } @@ -143,7 +139,7 @@ class Tensor { virtual std::type_index type() const { return std::type_index(typeid(T)); } /*! the pointer of memory block. */ - std::unique_ptr> ptr_; + std::unique_ptr> ptr_; /*! the place of memory block. */ platform::Place place_; diff --git a/paddle/framework/tensor_test.cc b/paddle/framework/tensor_test.cc index fd7143cfaa6ee9c9f1430dab743aa6b67fdd461e..ef1cc10b840896d9ab97f963fc12a4971cd74e1f 100644 --- a/paddle/framework/tensor_test.cc +++ b/paddle/framework/tensor_test.cc @@ -198,8 +198,8 @@ TEST(Tensor, CopyFrom) { int arr[9] = {1, 2, 3, 4, 5, 6, 7, 8, 9}; memcpy(src_ptr, arr, 9 * sizeof(int)); - auto* cpu_ctx = new paddle::platform::CPUDeviceContext(); - dst_tensor.CopyFrom(src_tensor, *cpu_ctx); + auto cpu_place = new paddle::platform::CPUPlace(); + dst_tensor.CopyFrom(src_tensor, *cpu_place); const int* dst_ptr = dst_tensor.data(); ASSERT_NE(src_ptr, dst_ptr); @@ -208,7 +208,7 @@ TEST(Tensor, CopyFrom) { } Tensor slice_tensor = src_tensor.Slice(1, 2); - dst_tensor.CopyFrom(slice_tensor, *cpu_ctx); + dst_tensor.CopyFrom(slice_tensor, *cpu_place); const int* slice_ptr = slice_tensor.data(); dst_ptr = dst_tensor.data(); ASSERT_NE(dst_ptr, slice_ptr); @@ -228,12 +228,12 @@ TEST(Tensor, CopyFrom) { memcpy(src_ptr, arr, 9 * sizeof(int)); // CPU Tensor to GPU Tensor - auto gpu_ctx = new paddle::platform::CUDADeviceContext(0); - 
gpu_tensor.CopyFrom(src_tensor, *gpu_ctx); + auto gpu_place = new paddle::platform::GPUPlace(0); + gpu_tensor.CopyFrom(src_tensor, *gpu_place); // GPU Tensor to CPU Tensor - auto cpu_ctx = new paddle::platform::CPUDeviceContext(); - dst_tensor.CopyFrom(gpu_tensor, *cpu_ctx); + auto cpu_place = new paddle::platform::CPUPlace(); + dst_tensor.CopyFrom(gpu_tensor, *cpu_place); // Compare Tensors const int* dst_ptr = dst_tensor.data(); @@ -245,10 +245,10 @@ TEST(Tensor, CopyFrom) { Tensor slice_tensor = src_tensor.Slice(1, 2); // CPU Slice Tensor to GPU Tensor - gpu_tensor.CopyFrom(slice_tensor, *gpu_ctx); + gpu_tensor.CopyFrom(slice_tensor, *gpu_place); // GPU Tensor to CPU Tensor - dst_tensor.CopyFrom(gpu_tensor, *cpu_ctx); + dst_tensor.CopyFrom(gpu_tensor, *cpu_place); // Compare Slice Tensors const int* slice_ptr = slice_tensor.data(); diff --git a/paddle/framework/variable.h b/paddle/framework/variable.h index 72c4a7a2a1d1cf93a784f24e687727ee8481484c..38fc2720a3023039aa113b32a394bda9c5def4c0 100644 --- a/paddle/framework/variable.h +++ b/paddle/framework/variable.h @@ -16,7 +16,7 @@ #include #include -#include "paddle/platform/assert.h" +#include "paddle/platform/enforce.h" namespace paddle { namespace framework { @@ -25,7 +25,7 @@ class Variable { public: template const T& Get() const { - PADDLE_ASSERT(IsType()); + PADDLE_ENFORCE(IsType(), "Variable must be type %s", typeid(T).name()); return *static_cast(holder_->Ptr()); } @@ -65,6 +65,17 @@ class Variable { std::unique_ptr holder_; // pointers to a PlaceholderImpl object indeed. + + // name_ is only meaningful with a Scope and accessible by it. + // + // NOTE: Please don't expose name_ by adding methods like + // Variable::Name or Scope::VarName! A variable could have a human + // readable name or an auto-generated scope-unique name. 
In the + // former case, the caller knows the name and doesn't need to access + // the name; in the latter case, the variable should be identified + // by its address but not the unreadable name. + friend class Scope; + const std::string* name_; }; } // namespace framework diff --git a/paddle/gserver/layers/SliceProjection.cpp b/paddle/gserver/layers/SliceProjection.cpp new file mode 100644 index 0000000000000000000000000000000000000000..267dd6154b1b21cc9b936384d438a2c3bdf0c246 --- /dev/null +++ b/paddle/gserver/layers/SliceProjection.cpp @@ -0,0 +1,96 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "Projection.h" + +namespace paddle { + +/** + * SliceProjection can slice the input value into multiple parts, + * and then select some of them to merge into a new output. + * + * First, calculate the slices that need to be merged into the output. + * slices = input.slices().for_output() + * + * Second, merge each slice into the output. + * for(auto slice: slices) { + * out.addAtOffset(slice, offset); + * } + * + * Input slices as output: s0, s1, ...: + * ----------------------- + * |///| |//////| | + * |/s0| |//s1//| | + * |///| |//////| | + * ----------------------- + * Output, merge s0, s1, ... into one output: + * ---------------- + * |///|//////| | + * |/s0|//s1//|...| + * |///|//////| | + * ---------------- + * + * The config file api is slice_projection. 
+ */ +class SliceProjection : public Projection { +public: + SliceProjection(const ProjectionConfig& config, + const ParameterPtr& parameter, + bool useGpu); + virtual void forward(); + virtual void backward(const UpdateCallback& callback); + +protected: + std::vector> slices_; +}; + +REGISTER_PROJECTION(slice, SliceProjection); + +/** + * Constructed function. + * @note SliceProjection should not have any parameter. + */ +SliceProjection::SliceProjection(const ProjectionConfig& config, + const ParameterPtr& parameter, + bool useGpu) + : Projection(config, parameter, useGpu) { + CHECK(!parameter) << "'slice' projection should not have any parameter"; + + slices_.reserve(config.slices_size()); + for (const auto& slice : config.slices()) { + slices_.push_back(std::make_pair(slice.start(), slice.end())); + } +} + +void SliceProjection::forward() { + size_t offset = 0; + for (auto& slice : slices_) { + auto slice_out = in_->value->subColMatrix(slice.first, slice.second); + out_->value->addAtOffset(*slice_out, offset); + offset += slice_out->getWidth(); + } +} + +void SliceProjection::backward(const UpdateCallback& callback) { + if (in_->grad) { + size_t offset = 0; + for (auto& slice : slices_) { + auto slice_out = in_->grad->subColMatrix(slice.first, slice.second); + slice_out->addAtOffset(*out_->grad, offset); + offset += slice_out->getWidth(); + } + } +} + +} // namespace paddle diff --git a/paddle/gserver/tests/concat_slice_a.conf b/paddle/gserver/tests/concat_slice_a.conf new file mode 100644 index 0000000000000000000000000000000000000000..dccf911089e16f4f97b1470ee39d192d4557d4bd --- /dev/null +++ b/paddle/gserver/tests/concat_slice_a.conf @@ -0,0 +1,41 @@ +#edit-mode: -*- python -*- +# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from paddle.trainer_config_helpers import * + +settings(batch_size=10) + +data = data_layer(name ="input", size=8*16*16) + +conv1 = img_conv_layer(input=data, filter_size=1, filter_size_y=1, + num_channels=8, + num_filters=16, stride=1, + bias_attr=False, + act=ReluActivation()) +conv2 = img_conv_layer(input=data, filter_size=1, filter_size_y=1, + num_channels=8, + num_filters=16, stride=1, + bias_attr=False, + act=ReluActivation()) + +proj1 = slice_projection(input=conv1, slices=[(0, 4), (4, 12)]) + +proj2 = slice_projection(input=conv2, slices=[(1, 5), (5, 15)]) + +concat = concat_layer(input=[proj1, proj2]) + +outputs(concat) + diff --git a/paddle/gserver/tests/concat_slice_b.conf b/paddle/gserver/tests/concat_slice_b.conf new file mode 100644 index 0000000000000000000000000000000000000000..29686ef2810370af3f84b60b2450d5c7d2e7663d --- /dev/null +++ b/paddle/gserver/tests/concat_slice_b.conf @@ -0,0 +1,41 @@ +#edit-mode: -*- python -*- +# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ + +from paddle.trainer_config_helpers import * + +settings(batch_size=10) + +data = data_layer(name ="input", size=8*16*16) + +conv1 = img_conv_layer(input=data, filter_size=1, filter_size_y=1, + num_channels=8, + num_filters=16, stride=1, + bias_attr=False, + act=ReluActivation()) +conv2 = img_conv_layer(input=data, filter_size=1, filter_size_y=1, + num_channels=8, + num_filters=16, stride=1, + bias_attr=False, + act=ReluActivation()) + +proj1 = slice_projection(input=conv1, slices=[(0, 12)]) + +proj2 = slice_projection(input=conv2, slices=[(1, 15)]) + +concat = concat_layer(input=[proj1, proj2]) + +outputs(concat) + diff --git a/paddle/gserver/tests/test_LayerGrad.cpp b/paddle/gserver/tests/test_LayerGrad.cpp index 0d8789e0a2ed2d3544e63734d439db74b77868c5..24c802ca8fda14a8d4deabd3fe7a0f3d7e5e9b90 100644 --- a/paddle/gserver/tests/test_LayerGrad.cpp +++ b/paddle/gserver/tests/test_LayerGrad.cpp @@ -152,6 +152,26 @@ TEST(Projection, identity) { } } +TEST(Projection, slice) { + ProjectionConfig conf; + conf.set_type("slice"); + conf.set_input_size(100); + SliceConfig& slice1 = *conf.add_slices(); + slice1.set_start(10); + slice1.set_end(20); + SliceConfig& slice2 = *conf.add_slices(); + slice2.set_start(50); + slice2.set_end(70); + conf.set_output_size(30); + for (auto useGpu : {false, true}) { + testProjectionGrad(conf, + INPUT_DATA, + /* parameterSize */ 0, + /* batchSize */ 10, + useGpu); + } +} + TEST(Projection, scaling) { ProjectionConfig conf; conf.set_type("scaling"); diff --git a/paddle/gserver/tests/test_NetworkCompare.cpp b/paddle/gserver/tests/test_NetworkCompare.cpp index 40e662b22bac0a2d22aea31fe99b11695bac3f57..f930c72fde3f5e0a6a45cb6bfd3507a4f48028fc 100644 --- a/paddle/gserver/tests/test_NetworkCompare.cpp +++ b/paddle/gserver/tests/test_NetworkCompare.cpp @@ -237,6 +237,12 @@ TEST(Compare, concat_table) { compareNetwork(config_file_a, config_file_b); } +TEST(Compare, concat_slice) { + std::string config_file_a = 
"./gserver/tests/concat_slice_a.conf"; + std::string config_file_b = "./gserver/tests/concat_slice_b.conf"; + compareNetwork(config_file_a, config_file_b); +} + #ifndef PADDLE_ONLY_CPU TEST(Compare, img_pool) { std::string config_file_a = "./gserver/tests/img_pool_a.conf"; diff --git a/paddle/math/tests/test_matrixCompare.cpp b/paddle/math/tests/test_matrixCompare.cpp index 354f58df39365410ff9aec2576c768e58db9e0d2..4980208e659233d50cd464dfeb213adfd2be3f38 100644 --- a/paddle/math/tests/test_matrixCompare.cpp +++ b/paddle/math/tests/test_matrixCompare.cpp @@ -1141,4 +1141,64 @@ TEST(CpuMatrix, copyFrom) { TensorCheckEqual(cpu, copy); } +void testBatch2seqPadding(int batchSize, int inputDim) { + MatrixPtr cpuInput = std::make_shared(batchSize, inputDim); + MatrixPtr gpuInput = std::make_shared(batchSize, inputDim); + cpuInput->randomizeUniform(); + gpuInput->copyFrom(*cpuInput); + + IVectorPtr cpuSequence; + generateSequenceStartPositions(batchSize, cpuSequence); + IVectorPtr gpuSequence = IVector::create(cpuSequence->getSize(), true); + gpuSequence->copyFrom(*cpuSequence); + + size_t numSeq = cpuSequence->getSize() - 1; + size_t maxSeqLen = *std::max_element(cpuSequence->getData(), + cpuSequence->getData() + numSeq); + + MatrixPtr cBatch = std::make_shared(numSeq * maxSeqLen, inputDim); + MatrixPtr gBatch = std::make_shared(numSeq * maxSeqLen, inputDim); + MatrixPtr cCheck = std::make_shared(numSeq * maxSeqLen, inputDim); + + hl_sequence2batch_copy_padding(gBatch->getData(), + gpuInput->getData(), + cpuSequence->getData(), + inputDim, + maxSeqLen, + numSeq, + false, + true); + cCheck->copyFrom(*gBatch); + + int* seqStart = cpuSequence->getData(); + float* batchData = cBatch->getData(); + float* seqData = cpuInput->getData(); + for (size_t i = 0; i < maxSeqLen; i++) { + for (size_t j = 0; j < numSeq; j++) { + size_t sequenceStart = seqStart[j]; + size_t sequenceLength = seqStart[j + 1] - seqStart[j]; + if (i < sequenceLength) { + memcpy(batchData + (i * numSeq + j) * 
inputDim, + seqData + (sequenceStart + i) * inputDim, + inputDim * sizeof(real)); + } else { + memset(batchData + (i * numSeq + j) * inputDim, + 0, + inputDim * sizeof(real)); + } + } + } + + TensorCheckErr(*cBatch, *cCheck); +} + +TEST(Matrix, warpCTC) { + for (auto batchSize : {51, 526, 2884}) { + for (auto inputDim : {32, 512, 2026}) { + VLOG(3) << " batchSize=" << batchSize << " inputDim=" << inputDim; + testBatch2seqPadding(batchSize, inputDim); + } + } +} + #endif diff --git a/paddle/memory/memcpy.cc b/paddle/memory/memcpy.cc index 098931c887479ce6f1afc8b90e4003758d88c018..aaab1142ca18d3319469a4d685fde9d30929113f 100644 --- a/paddle/memory/memcpy.cc +++ b/paddle/memory/memcpy.cc @@ -35,7 +35,7 @@ void Copy(platform::CPUPlace dst_place, platform::GPUPlace src_place, const void* src, size_t num, cudaStream_t stream) { - platform::GPUPlaceGuard g(src_place.device); + platform::SetDeviceId(src_place.device); platform::GpuMemcpyAsync(dst, src, num, cudaMemcpyDeviceToHost, stream); } @@ -45,7 +45,7 @@ void Copy(platform::GPUPlace dst_place, platform::CPUPlace src_place, const void* src, size_t num, cudaStream_t stream) { - platform::GPUPlaceGuard g(dst_place.device); + platform::SetDeviceId(dst_place.device); platform::GpuMemcpyAsync(dst, src, num, cudaMemcpyHostToDevice, stream); } @@ -56,7 +56,7 @@ void Copy(platform::GPUPlace dst_place, const void* src, size_t num, cudaStream_t stream) { if (dst_place == src_place) { - platform::GPUPlaceGuard g(src_place.device); + platform::SetDeviceId(src_place.device); platform::GpuMemcpyAsync(dst, src, num, cudaMemcpyDeviceToDevice, stream); } else { platform::GpuMemcpyPeer(dst, dst_place.device, src, src_place.device, num, diff --git a/paddle/memory/memcpy.h b/paddle/memory/memcpy.h index 99b1c2e1c3e5ae4facaeb4fd0b773a7531448f03..2b9c0eada6e8406fc81baec7f331a8dd5b8b0ec1 100644 --- a/paddle/memory/memcpy.h +++ b/paddle/memory/memcpy.h @@ -20,13 +20,39 @@ limitations under the License. 
*/ namespace paddle { namespace memory { +/** + * \brief Copy memory from one place to another place. + * + * \param[in] DstPlace Destination allocation place (CPU). + * \param[in] dst Destination memory address. + * \param[in] SrcPlace Source allocation place (CPU). + * \param[in] src Source memory address. + * \param[in] num memory size in bytes to copy. + * + */ template void Copy(DstPlace, void* dst, SrcPlace, const void* src, size_t num); #ifndef PADDLE_ONLY_CPU + +/** + * \brief Copy memory from one place to another place. + * + * \param[in] DstPlace Destination allocation place (CPU or GPU). + * \param[in] dst Destination memory address. + * \param[in] SrcPlace Source allocation place (CPU or GPU). + * \param[in] src Source memory address. + * \param[in] num memory size in bytes to copy. + * \param[in] stream CUDA stream. + * + * \note For GPU memory copy, CUDA stream need to be specified + * for asynchronously memory copy. + * + */ template void Copy(DstPlace, void* dst, SrcPlace, const void* src, size_t num, cudaStream_t stream); + #endif // PADDLE_ONLY_CPU } // namespace memory diff --git a/paddle/memory/memory.cc b/paddle/memory/memory.cc index c2e046926fafd8f4cfc4cd81d8f32e3882ff02ec..207025f9b1c64f0f8943f9fae5edefc9328a1d26 100644 --- a/paddle/memory/memory.cc +++ b/paddle/memory/memory.cc @@ -60,6 +60,7 @@ detail::BuddyAllocator* GetGPUBuddyAllocator(int gpu_id) { platform::GpuMaxChunkSize()); } } + platform::SetDeviceId(gpu_id); return as[gpu_id]; } diff --git a/paddle/memory/memory.h b/paddle/memory/memory.h index fd4d5e7082c06e481e74515c9cd7f7c13f1cff4b..44f567caf9c19775f17988b5142b7693b41a126d 100644 --- a/paddle/memory/memory.h +++ b/paddle/memory/memory.h @@ -20,15 +20,49 @@ limitations under the License. */ namespace paddle { namespace memory { +/** + * \brief Allocate memory block in one place. + * + * \param[in] place Allocation place (CPU or GPU). + * \param[in] size Allocation size. + * + * \return Allocated memory block address. 
+ * + * \note If return nullptr, it indicates memory allocation failed + * because insufficient memory in current system. When Alloc + * function is invoked, you must check the returned memory + * address is valid or not. + */ template -void* Alloc(Place, size_t); +void* Alloc(Place place, size_t size); +/** + * \brief Free memory block in one place. + * + * \param[in] place Allocation place (CPU or GPU). + * \param[in] ptr Memory block address to free. + * + */ template -void Free(Place, void*); +void Free(Place place, void* ptr); +/** + * \brief Total size of used memory in one place. + * + * \param[in] place Allocation place (CPU or GPU). + * + */ template -size_t Used(Place); +size_t Used(Place place); +/** + * \brief Free memory block in one place. + * + * \note In some cases, custom deleter is used to + * deallocate the memory automatically for + * std::unique_ptr in tensor.h. + * + */ template class PODDeleter { static_assert(std::is_pod::value, "T must be POD"); diff --git a/paddle/operators/CMakeLists.txt b/paddle/operators/CMakeLists.txt index 0a14dc21144153f9a45d5227e54102983c6c2659..b910bee836ed488aeb34f28d0503b5efba396583 100644 --- a/paddle/operators/CMakeLists.txt +++ b/paddle/operators/CMakeLists.txt @@ -44,13 +44,26 @@ endfunction() op_library(add_op SRCS add_op.cc add_op.cu) cc_test(add_op_test SRCS add_op_test.cc DEPS add_op) +op_library(mean_op SRCS mean_op.cc mean_op.cu) +cc_test(mean_op_test SRCS mean_op_test.cc DEPS mean_op) + op_library(mul_op SRCS mul_op.cc mul_op.cu) op_library(rowwise_add_op SRCS rowwise_add_op.cu rowwise_add_op.cc) -op_library(sigmoid_op SRCS sigmoid_op.cu sigmoid_op.cc) + +op_library(sigmoid_op SRCS sigmoid_op.cc sigmoid_op.cu) op_library(softmax_op SRCS softmax_op.cc softmax_op.cu) op_library(cross_entropy_op SRCS cross_entropy_op.cc cross_entropy_op.cu) - -op_library(fc_op SRCS fc_op.cc DEPS mul_op rowwise_add_op sigmoid_op - softmax_op net) +op_library(fill_zeros_like_op SRCS fill_zeros_like_op.cc 
fill_zeros_like_op.cu) op_library(sgd_op SRCS sgd_op.cc sgd_op.cu) + +op_library(fc_op + SRCS fc_op.cc + DEPS mul_op rowwise_add_op sigmoid_op softmax_op net) + +op_library(recurrent_network_op + SRCS recurrent_network_op.cc + DEPS op_desc tensor net) +cc_test(recurrent_network_op_test + SRCS recurrent_network_op_test.cc + DEPS recurrent_network_op mul_op add_op) diff --git a/paddle/operators/add_op.cc b/paddle/operators/add_op.cc index 1424b0284372d8dfe9eb93ee251b121a48b19b0b..3a43dbfbada87e458109d8ca22effdb4407b4c1d 100644 --- a/paddle/operators/add_op.cc +++ b/paddle/operators/add_op.cc @@ -19,16 +19,16 @@ namespace operators { class AddOp : public OperatorWithKernel { protected: - void InferShape(const std::vector &inputs, - const std::vector &outputs) const override { - PADDLE_ENFORCE(inputs.size() == 2, "Input size of AddOp must be two"); - PADDLE_ENFORCE(outputs.size() == 1, "Output size of AddOp must be one"); - PADDLE_ENFORCE( - inputs[0] != nullptr && inputs[1] != nullptr && outputs[0] != nullptr, - "Inputs/Outputs of AddOp must all be set"); - PADDLE_ENFORCE(inputs[0]->dims() == inputs[1]->dims(), + void InferShape(const InferShapeContext &ctx) const override { + PADDLE_ENFORCE(ctx.InputSize() == 2, "Input size of AddOp must be two"); + PADDLE_ENFORCE(ctx.OutputSize() == 1, "Output size of AddOp must be one"); + PADDLE_ENFORCE(ctx.InputVar(0) != nullptr && ctx.InputVar(1) != nullptr, + "Inputs of AddOp must all be set"); + PADDLE_ENFORCE(ctx.OutputVar(0) != nullptr, + "Outputs of AddOp must all be set"); + PADDLE_ENFORCE(ctx.Input(0)->dims() == ctx.Input(1)->dims(), "Two input of Add Op's dimension must be same."); - outputs[0]->Resize(inputs[0]->dims()); + ctx.Output(0)->Resize(ctx.Input(0)->dims()); } }; @@ -49,8 +49,7 @@ The equation is: Out = X + Y class AddOpGrad : public OperatorWithKernel { protected: - void InferShape(const std::vector &inputs, - const std::vector &outputs) const override {} + void InferShape(const InferShapeContext &ctx) const 
override {} std::string DebugString() const override { LOG(INFO) << "AddOpGrad"; return ""; diff --git a/paddle/operators/add_op.h b/paddle/operators/add_op.h index 0c39433788e1e07e30aaadc4766028219b05bfa5..d2b649fcbd1e5cac1c8cfcfd4e522e41135f7d1f 100644 --- a/paddle/operators/add_op.h +++ b/paddle/operators/add_op.h @@ -21,16 +21,17 @@ namespace operators { template class AddKernel : public OpKernel { public: - void Compute(const KernelContext& context) const override { - auto input0 = context.Input(0)->Get(); - auto input1 = context.Input(1)->Get(); - auto output = context.Output(0)->GetMutable(); + void Compute(const ExecutionContext& context) const override { + auto input0 = context.Input(0); + auto input1 = context.Input(1); + auto output = context.Output(0); output->mutable_data(context.GetPlace()); EigenVector::Flatten(*output).device( *(context.GetEigenDevice())) = - EigenVector::Flatten(input0) + EigenVector::Flatten(input1); + framework::EigenVector::Flatten(*input0) + + framework::EigenVector::Flatten(*input1); } }; diff --git a/paddle/operators/cross_entropy_op.cc b/paddle/operators/cross_entropy_op.cc index 46c88d4d1a28eeedd02eb699562244651ead6d68..4f5b935fde4d5b0d9efae66554cf890291e26941 100644 --- a/paddle/operators/cross_entropy_op.cc +++ b/paddle/operators/cross_entropy_op.cc @@ -19,20 +19,20 @@ namespace operators { class OnehotCrossEntropyOp : public OperatorWithKernel { protected: - void InferShape(const std::vector &inputs, - const std::vector &outputs) const override { - PADDLE_ENFORCE(inputs.size() == 2, + void InferShape(const InferShapeContext &ctx) const override { + PADDLE_ENFORCE(ctx.InputSize() == 2, "Input size of OnehotCrossEntropyOp must be two"); - PADDLE_ENFORCE(outputs.size() == 1, + PADDLE_ENFORCE(ctx.OutputSize() == 1, "Output size of OnehotCrossEntropyOp must be one"); - PADDLE_ENFORCE(inputs[0] != nullptr && inputs[1] != nullptr, + PADDLE_ENFORCE(ctx.InputVar(0) != nullptr && ctx.InputVar(1) != nullptr, "Inputs of 
OnehotCrossEntropyOp must all be set"); - PADDLE_ENFORCE(outputs[0] != nullptr, + PADDLE_ENFORCE(ctx.OutputVar(0) != nullptr, "Outputs of OnehotCrossEntropyOp must all be set"); - PADDLE_ENFORCE(inputs[0]->dims().size() == 2, "X's dimension must be 2."); - PADDLE_ENFORCE(outputs[0]->dims().size() == 1, + PADDLE_ENFORCE(ctx.Input(0)->dims().size() == 2, + "X's dimension must be 2."); + PADDLE_ENFORCE(ctx.Output(0)->dims().size() == 1, "label's dimension must be 1."); - outputs[0]->Resize({inputs[0]->dims()[0]}); + ctx.Output(0)->Resize({ctx.Input(0)->dims()[0]}); } }; diff --git a/paddle/operators/cross_entropy_op.h b/paddle/operators/cross_entropy_op.h index 0383df46be3a3cea7dde8f1b45857e64d5a2f2d8..c3a3728149950a5c7f2195122e8e0ff728492bdb 100644 --- a/paddle/operators/cross_entropy_op.h +++ b/paddle/operators/cross_entropy_op.h @@ -23,18 +23,18 @@ class OnehotCrossEntropyOpKernel : public OpKernel { public: constexpr T LOG_THRESHOLD() const { return static_cast(1e-20); } - void Compute(const KernelContext& context) const override { - auto X = context.Input(0)->Get(); - const T* X_data = X.data(); - const int* label_data = context.Input(1)->Get().data(); - auto* Y = context.Output(0)->GetMutable(); + void Compute(const ExecutionContext& ctx) const override { + auto X = ctx.Input(0); + const T* X_data = X->data(); + const int* label_data = ctx.Input(1)->data(); + auto Y = ctx.Output(0); - Y->mutable_data(context.GetPlace()); + Y->mutable_data(ctx.GetPlace()); T* Y_data = Y->data(); - int batch_size = X.dims()[0]; - int class_num = X.dims()[1]; + int batch_size = X->dims()[0]; + int class_num = X->dims()[1]; // Y[i] = -log(X[i][j]) for (int i = 0; i < batch_size; ++i) { diff --git a/paddle/operators/fc_op.cc b/paddle/operators/fc_op.cc index c4a9f5937f4fa8c60989bea1726cedbb73330156..71ceda958770796693265c08cb1fcae27e79bcd9 100644 --- a/paddle/operators/fc_op.cc +++ b/paddle/operators/fc_op.cc @@ -50,8 +50,8 @@ public: AddInput("b", "the bias of fc operator"); 
AddOutput("Y", "the output of fc operator"); - AddOutput( - "before_act", "the before activation output of fc operator", true); + AddOutput("before_act", "the before activation output of fc operator") + .SetTemporary(); AddAttr("activation", "The activation key for fc layer") .SetDefault("sigmoid") .InEnum({"sigmoid", "softmax"}); diff --git a/paddle/operators/fill_zeros_like_op.cc b/paddle/operators/fill_zeros_like_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..79a0e3d7e911b728a7a96ceff573976ba2b2e37f --- /dev/null +++ b/paddle/operators/fill_zeros_like_op.cc @@ -0,0 +1,60 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
*/ + +#include "paddle/operators/fill_zeros_like_op.h" +#include "paddle/framework/op_registry.h" +#include "paddle/framework/tensor.h" + +namespace paddle { +namespace operators { + +class FillZerosLikeOp : public framework::OperatorWithKernel { +protected: + void InferShape(const framework::InferShapeContext &ctx) const override { + PADDLE_ENFORCE(ctx.InputSize() == 1UL, + "Input size of FillZerosLikeOp must be one."); + PADDLE_ENFORCE(ctx.OutputSize() == 1UL, + "Output size of AddOp must be one."); + PADDLE_ENFORCE(ctx.InputVar(0) != nullptr, + "Input of FillZerosLikeOp must be set."); + PADDLE_ENFORCE(ctx.OutputVar(0) != nullptr, + "Output of FillZerosLikeOp must be set."); + ctx.Output(0)->Resize( + ctx.Input(0)->dims()); + } +}; + +class FillZerosLikeOpMaker : public framework::OpProtoAndCheckerMaker { +public: + FillZerosLikeOpMaker(framework::OpProto *proto, + framework::OpAttrChecker *op_checker) + : framework::OpProtoAndCheckerMaker(proto, op_checker) { + AddInput("Src", "The input of fill-zeros-like op."); + AddOutput("Dst", "The varibale will be filled up with zeros."); + AddComment(R"DOC( +Fill up a vriable with zeros. + +The output will have the same size with input. 
+)DOC"); + } +}; +} // namespace operators +} // namespace paddle + +REGISTER_OP(fill_zeros_like, + paddle::operators::FillZerosLikeOp, + paddle::operators::FillZerosLikeOpMaker); +REGISTER_OP_CPU_KERNEL( + fill_zeros_like, + paddle::operators::FillZerosLikeKernel); diff --git a/paddle/operators/fill_zeros_like_op.cu b/paddle/operators/fill_zeros_like_op.cu new file mode 100644 index 0000000000000000000000000000000000000000..55ad58f4f17cd4a3e737c01b001675d2690d273e --- /dev/null +++ b/paddle/operators/fill_zeros_like_op.cu @@ -0,0 +1,6 @@ +#include "paddle/framework/op_registry.h" +#include "paddle/operators/fill_zeros_like_op.h" + +REGISTER_OP_GPU_KERNEL( + fill_zeros_like, + paddle::operators::FillZerosLikeKernel); \ No newline at end of file diff --git a/paddle/operators/fill_zeros_like_op.h b/paddle/operators/fill_zeros_like_op.h new file mode 100644 index 0000000000000000000000000000000000000000..05272964abd43bdc2bd5c3cae8b128099e1c888c --- /dev/null +++ b/paddle/operators/fill_zeros_like_op.h @@ -0,0 +1,34 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
*/ + +#pragma once +#include "glog/logging.h" +#include "paddle/framework/eigen.h" +#include "paddle/framework/operator.h" + +namespace paddle { +namespace operators { + +template +class FillZerosLikeKernel : public framework::OpKernel { +public: + void Compute(const framework::ExecutionContext& context) const override { + auto* output = context.Output(0); + output->mutable_data(context.GetPlace()); + framework::EigenVector::Flatten(*output).setZero(); + } +}; + +} // namespace operators +} // namespace paddle diff --git a/paddle/operators/mean_op.cc b/paddle/operators/mean_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..fe34d6ad4015620cac520146850e10563d4c50e0 --- /dev/null +++ b/paddle/operators/mean_op.cc @@ -0,0 +1,45 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
*/ + +#include "paddle/operators/mean_op.h" + +namespace paddle { +namespace operators { + +class MeanOp : public OperatorWithKernel { +protected: + void InferShape(const InferShapeContext &ctx) const override { + PADDLE_ENFORCE(ctx.InputSize() == 1, "Input size of AddOp must be one"); + PADDLE_ENFORCE(ctx.OutputSize() == 1, "Output size of AddOp must be one"); + PADDLE_ENFORCE(ctx.InputVar(0) != nullptr && ctx.OutputVar(0) != nullptr, + "Input/Output of MeanOp must be initialized."); + ctx.Output(0)->Resize(framework::make_ddim({1})); + } +}; + +class MeanOpMaker : public OpProtoAndCheckerMaker { +public: + MeanOpMaker(OpProto *proto, OpAttrChecker *op_checker) + : OpProtoAndCheckerMaker(proto, op_checker) { + AddInput("X", "The input of mean op"); + AddOutput("Out", "The output of mean op"); + AddComment("Mean Operator"); + } +}; + +} // namespace operators +} // namespace paddle + +REGISTER_OP(mean, ops::MeanOp, ops::MeanOpMaker); +REGISTER_OP_CPU_KERNEL(mean, ops::MeanKernel); diff --git a/paddle/operators/mean_op.cu b/paddle/operators/mean_op.cu new file mode 100644 index 0000000000000000000000000000000000000000..740157cbc57a64cafcf109186c630691620f542b --- /dev/null +++ b/paddle/operators/mean_op.cu @@ -0,0 +1,5 @@ +#define EIGEN_USE_GPU + +#include "paddle/operators/mean_op.h" + +REGISTER_OP_GPU_KERNEL(mean, ops::MeanKernel); diff --git a/paddle/operators/mean_op.h b/paddle/operators/mean_op.h new file mode 100644 index 0000000000000000000000000000000000000000..5f7d443751d1cdd7de3b67b0de2758ba1d566fb3 --- /dev/null +++ b/paddle/operators/mean_op.h @@ -0,0 +1,36 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once +#include "paddle/operators/type_alias.h" + +namespace paddle { +namespace operators { + +template +class MeanKernel : public OpKernel { +public: + void Compute(const ExecutionContext& context) const override { + auto input = context.Input(0); + auto output = context.Output(0); + + output->mutable_data(context.GetPlace()); + + EigenScalar::From(*output).device(*(context.GetEigenDevice())) = + EigenVector::Flatten(*input).mean(); + } +}; + +} // namespace operators +} // namespace paddle diff --git a/paddle/operators/mean_op_test.cc b/paddle/operators/mean_op_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..375dcd50e130355c60f82b9d39d1b94fb2c911b0 --- /dev/null +++ b/paddle/operators/mean_op_test.cc @@ -0,0 +1,25 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
*/ + +#include + +#include + +USE_OP(mean); + +TEST(MeanOp, GetOpProto) { + auto& protos = paddle::framework::OpRegistry::protos(); + auto it = protos.find("mean"); + ASSERT_NE(it, protos.end()); +} diff --git a/paddle/operators/mul_op.cc b/paddle/operators/mul_op.cc index 22c1b78005358a934c57d487f5b0cff133f61f0c..d127f3a302a340fe7558f918d6eeb2ea0a3fafe7 100644 --- a/paddle/operators/mul_op.cc +++ b/paddle/operators/mul_op.cc @@ -19,18 +19,17 @@ namespace operators { class MulOp : public OperatorWithKernel { protected: - void InferShape(const std::vector &inputs, - const std::vector &outputs) const override { - PADDLE_ENFORCE(inputs.size() == 2, "The mul op must take two inputs"); - auto dim0 = inputs[0]->dims(); - auto dim1 = inputs[1]->dims(); + void InferShape(const InferShapeContext &ctx) const override { + PADDLE_ENFORCE(ctx.InputSize() == 2, "The mul op must take two inputs"); + auto dim0 = ctx.Input(0)->dims(); + auto dim1 = ctx.Input(1)->dims(); PADDLE_ENFORCE(dim0.size() == 2 && dim1.size() == 2, "The input of mul op must be matrix"); PADDLE_ENFORCE( dim0[1] == dim1[0], "First matrix's width must be equal with second matrix's height."); - PADDLE_ENFORCE(outputs.size() == 1, "The mul op must take one output"); - outputs[0]->Resize({dim0[0], dim1[1]}); + PADDLE_ENFORCE(ctx.OutputSize() == 1, "The mul op must take one output"); + ctx.Output(0)->Resize({dim0[0], dim1[1]}); } }; @@ -51,8 +50,7 @@ The equation is: Out = X * Y class MulOpGrad : public OperatorWithKernel { protected: - void InferShape(const std::vector &inputs, - const std::vector &outputs) const override {} + void InferShape(const InferShapeContext &ctx) const override {} std::string DebugString() const override { LOG(INFO) << "MulGrad"; return ""; diff --git a/paddle/operators/mul_op.h b/paddle/operators/mul_op.h index 467975044638a3f034ceec84173e8d3fed43cc0c..eef72ab293e13a9d05ce0013be41ec4bb75d6077 100644 --- a/paddle/operators/mul_op.h +++ b/paddle/operators/mul_op.h @@ -22,19 +22,17 @@ 
namespace operators { template class MulKernel : public OpKernel { public: - void Compute(const KernelContext& context) const override { + void Compute(const ExecutionContext& context) const override { Eigen::array, 1> dim_pair = { {Eigen::IndexPair(1, 0)}}; - auto input0 = context.Input(0)->Get(); - auto input1 = context.Input(1)->Get(); - auto* output = context.Output(0)->GetMutable(); - + auto output = context.Output(0); output->mutable_data(context.GetPlace()); EigenMatrix::From(*output).device(*(context.GetEigenDevice())) = - EigenMatrix::From(input0).contract(EigenMatrix::From(input1), - dim_pair); + EigenMatrix::From(*context.Input("X")) + .contract(EigenMatrix::From(*context.Input("Y")), + dim_pair); } }; } // namespace operators diff --git a/paddle/operators/recurrent_network_op.cc b/paddle/operators/recurrent_network_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..60d065fc4789f76370840328870165579aa73b67 --- /dev/null +++ b/paddle/operators/recurrent_network_op.cc @@ -0,0 +1,412 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. 
*/ + +#include "paddle/operators/recurrent_network_op.h" + +#include +#include +#include + +#include "paddle/framework/net.h" +#include "paddle/framework/op_registry.h" +#include "paddle/platform/enforce.h" + +namespace paddle { +namespace operators { + +namespace rnn { + +void SegmentInputs(const std::vector& step_scopes, + const std::vector& inlinks, + const size_t seq_len) { + PADDLE_ENFORCE(!inlinks.empty(), "no in links are provided."); + for (size_t i = 0; i < inlinks.size(); ++i) { + Tensor* input = + step_scopes[0]->FindVar(inlinks[i].external)->GetMutable(); + DDim dims = input->dims(); + PADDLE_ENFORCE(static_cast(dims[0]) == seq_len, + "all the inlinks must have same length"); + DDim step_dims = slice_ddim(dims, 1, dims.size()); + for (size_t j = 0; j < seq_len; j++) { + Tensor* step_input = + step_scopes[j]->NewVar(inlinks[i].internal)->GetMutable(); + *step_input = input->Slice(j, j + 1); + step_input->Resize(step_dims); + } + } +} + +void ConcatOutputs(const std::vector& step_scopes, + const std::vector& outlinks, + const size_t seq_len) { + for (size_t i = 0; i < outlinks.size(); i++) { + Tensor* output = + step_scopes[0]->FindVar(outlinks[i].external)->GetMutable(); + + // TODO(qingiqng) remove following code after adding + // InferShape in RecurrentGradientOp + DDim step_dims = step_scopes[0] + ->FindVar(outlinks[i].internal) + ->GetMutable() + ->dims(); + std::vector dims_vec = vectorize(step_dims); + dims_vec.insert(dims_vec.begin(), seq_len); + output->mutable_data(make_ddim(dims_vec), platform::CPUPlace()); + + for (size_t j = 0; j < seq_len; j++) { + Tensor* step_output = + step_scopes[j]->FindVar(outlinks[i].internal)->GetMutable(); + // TODO(luotao02) data type and platform::DeviceContext() should set + // correctly + (output->Slice(j, j + 1)) + .CopyFrom(*step_output, platform::CPUPlace()); + } + } +} + +void LinkMemories(const std::vector& scopes, + const std::vector& memories, + size_t step_id, + int offset) { + PADDLE_ENFORCE(step_id < 
scopes.size(), + "step [%d] is out of range of step scopes' size [%d]", + step_id, + scopes.size()); + PADDLE_ENFORCE(static_cast(step_id) + offset >= 0, + "offset [%d] must be large than -[%d]", + offset, + step_id); + PADDLE_ENFORCE(step_id + offset < scopes.size(), + "offset [%d] is out of range, it must be less than (%d - %d)", + offset, + scopes.size(), + step_id); + auto scope = scopes[step_id]; + auto linked_scope = scopes[step_id + offset]; + for (auto& attr : memories) { + auto mem = scope->NewVar(attr.pre_var)->GetMutable(); + // maybe share variable is better? + auto linked_mem = linked_scope->FindVar(attr.var)->GetMutable(); + mem->ShareDataWith(*linked_mem); + + // TODO(qingqing) remove following code + // the memory of current step should be allocated in step net + auto m = scope->NewVar(attr.var)->GetMutable(); + // for unit test, as addOp and mulOp are null currently, if not + // mutable_data, mem.data() in output will be error. We will + // remove this line after merge the correct addOp and mulOp. 
+ m->mutable_data(mem->dims(), platform::CPUPlace()); + } +} + +void InitArgument(const ArgumentName& name, + Argument* arg, + const OperatorBase& op) { + arg->step_net = op.Input(name.step_net); + arg->step_scopes = op.Output(name.step_scopes); + + auto inlinks = op.Inputs(name.inlinks); + auto inlink_alias = op.GetAttr>(name.inlink_alias); + PADDLE_ENFORCE(inlinks.size() == inlink_alias.size(), + "the size of inlinks and inlink_alias don't match:%d,%d", + inlinks.size(), + inlink_alias.size()); + for (size_t i = 0; i < inlinks.size(); ++i) { + rnn::Link link; + link.external = inlinks[i]; + link.internal = inlink_alias[i]; + (arg->inlinks).push_back(link); + } + + auto outlinks = op.Outputs(name.outlinks); + auto outlink_alias = op.GetAttr>(name.outlink_alias); + PADDLE_ENFORCE(outlinks.size() == outlink_alias.size(), + "the size of outlinks and outlink_alias don't match:%d,%d", + outlinks.size(), + outlink_alias.size()); + for (size_t i = 0; i < outlinks.size(); ++i) { + rnn::Link link; + link.external = outlinks[i]; + link.internal = outlink_alias[i]; + (arg->outlinks).push_back(link); + } + + auto boot_memories = op.Inputs(name.boot_memories); + + // attributes + auto memories = op.GetAttr>(name.memories); + auto pre_memories = op.GetAttr>(name.pre_memories); + + PADDLE_ENFORCE(memories.size() == boot_memories.size(), + "the size of memories, boot_memories don't match:%d,%d", + memories.size(), + boot_memories.size()); + PADDLE_ENFORCE(pre_memories.size() == boot_memories.size(), + "the size of pre_memories, boot_memories don't match:%d,%d", + pre_memories.size(), + boot_memories.size()); + PADDLE_ENFORCE(memories.size() > 0, "more than 1 memories should be set"); + + for (size_t i = 0; i < memories.size(); ++i) { + rnn::MemoryAttr mem_attr; + mem_attr.var = memories[i]; + mem_attr.pre_var = pre_memories[i]; + mem_attr.boot_var = boot_memories[i]; + (arg->memories).push_back(mem_attr); + } +} + +} // namespace rnn + +void RecurrentAlgorithm::InferShape(const 
Scope& scope) const { + seq_len_ = scope.FindVar((arg_->inlinks[0]).external) + ->GetMutable() + ->dims()[0]; + CreateScopes(scope); + auto step_scopes = GetStepScopes(scope); + + // SegmentInputs is called in InferShape, and the input must hold memory in + // SegmentInputs. But the other ops only set dimensions for the output in + // InferShape. That's a problem. Does the RNN op need InferShape or not? + // Do the following functions (SegmentInputs, InitMemories, ...) need + // to be rewritten for the RNN op? + rnn::SegmentInputs(step_scopes, arg_->inlinks, seq_len_); + + InitMemories(step_scopes[0]); + + PADDLE_ENFORCE(scope.FindVar(arg_->step_net) != nullptr, + "stepnet [%s] is not in scope.", + arg_->step_net); + Variable* net = scope.FindVar(arg_->step_net); + PADDLE_ENFORCE(net != nullptr, "failed to get step net"); + // If the InferShape is called in OperatorBase's run function, + // the rnn op only needs to do InferShape for the first time step + for (size_t i = 0; i < seq_len_; i++) { + if (i > 0) { + rnn::LinkMemories(step_scopes, arg_->memories, i, -1); + } + net->GetMutable()->InferShape(*step_scopes[i]); + } + + auto outlinks = arg_->outlinks; + for (size_t i = 0; i < outlinks.size(); i++) { + DDim step_dims = step_scopes[0] + ->FindVar(outlinks[i].internal) + ->GetMutable() + ->dims(); + std::vector dims_vec = vectorize(step_dims); + // only fixed-length sequences are supported for now + dims_vec.insert(dims_vec.begin(), seq_len_); + Tensor* output = + step_scopes[0]->FindVar(outlinks[i].external)->GetMutable(); + output->Resize(make_ddim(dims_vec)); + } +} + +void RecurrentAlgorithm::Run(const Scope& scope, + const platform::DeviceContext& dev_ctx) const { + auto step_scopes = GetStepScopes(scope); + + Variable* net = scope.FindVar(arg_->step_net); + for (size_t step_id = 0; step_id < seq_len_; step_id++) { + // the link memory is done in InferShape + // maybe remove following code after testing + if (step_id > 0) { + rnn::LinkMemories(step_scopes, arg_->memories, step_id, -1);
+ } + net->GetMutable()->Run(*step_scopes[step_id], dev_ctx); + } + + rnn::ConcatOutputs(step_scopes, arg_->outlinks, seq_len_); +} + +void RecurrentAlgorithm::CreateScopes(const Scope& scope) const { + // TODO(xxx) Only two scopes are needed for inference, this case will be + // supported later. + auto step_scopes = + scope.FindVar(arg_->step_scopes)->GetMutable>(); + + if (seq_len_ > step_scopes->size()) { + for (size_t i = step_scopes->size(); i < seq_len_; ++i) { + auto& step_scope = scope.NewScope(); + + // Now all variables in scope must be created outside of op. + auto net_op = scope.FindVar(arg_->step_net)->GetMutable(); + for (auto& input : net_op->inputs_) { + if (!step_scope.FindVar(input)) step_scope.NewVar(input); + } + for (auto& output : net_op->outputs_) { + step_scope.NewVar(output); + } + + step_scopes->emplace_back(&step_scope); + } + } +} + +void RecurrentAlgorithm::InitMemories(Scope* step_scope) const { + for (auto& attr : arg_->memories) { + Tensor* pre_mem = step_scope->NewVar(attr.pre_var)->GetMutable(); + PADDLE_ENFORCE(step_scope->FindVar(attr.boot_var) != nullptr, + "memory [%s]'s boot variable [%s] not exists", + attr.var, + attr.boot_var); + Tensor* boot_mem = step_scope->FindVar(attr.boot_var)->GetMutable(); + pre_mem->ShareDataWith(*boot_mem); + + // TODO(qingqing) remove following code + // the memory of current step should be allocated in step net + // here for unit test + auto cur_step_mem = step_scope->NewVar(attr.var)->GetMutable(); + cur_step_mem->mutable_data(boot_mem->dims(), platform::CPUPlace()); + } +} + +const rnn::ArgumentName RecurrentOp::kArgName{"step_net", + "step_scopes", + "inlinks", + "outlinks", + "inlink_alias", + "outlink_alias", + "memories", + "pre_memories", + "boot_memories"}; + +const rnn::ArgumentName RecurrentGradientOp::kArgName{"step_net", + "step_scopes", + "outlink@grad", + "inlink@grad", + "inlink_alias", + "outlink_alias", + "memories", + "pre_memories", + "boot_memories@grad"}; + +void 
RecurrentOp::Init() { + OperatorBase::Init(); + std::unique_ptr arg(new rnn::Argument()); + rnn::InitArgument(kArgName, arg.get(), *this); + alg_.Init(std::move(arg)); +} + +class RecurrentAlgorithmProtoAndCheckerMaker : public OpProtoAndCheckerMaker { +public: + RecurrentAlgorithmProtoAndCheckerMaker(OpProto* proto, + OpAttrChecker* op_checker) + : OpProtoAndCheckerMaker(proto, op_checker) { + const auto& name = RecurrentOp::kArgName; + // inputs and outputs stored in proto + AddInput(name.inlinks, "the input that need to be segmented for each step.") + .SetMultiple(); + AddInput(name.boot_memories, "variables to initialize memories.") + .SetMultiple(); + AddInput(name.step_net, "network shared by all steps."); + + AddOutput(name.outlinks, "the output that need to concated for all steps.") + .SetMultiple(); + AddOutput(name.step_scopes, "step scopes"); + + // Attributes stored in AttributeMap + AddAttr>(name.inlink_alias, "alias of inlinks"); + AddAttr>(name.outlink_alias, "alias of outlinks"); + AddAttr>(name.pre_memories, + "names of pre-memories"); + AddAttr>(name.memories, "names of memories"); + + AddComment("This is a recurrent group operator."); + } +}; + +void RecurrentGradientAlgorithm::Run( + const Scope& scope, const platform::DeviceContext& dev_ctx) const { + auto step_scopes = GetStepScopes(scope); + rnn::SegmentInputs(step_scopes, arg_->inlinks, seq_len_); + PADDLE_ENFORCE(scope.FindVar(arg_->step_net) != nullptr, + "step net is not in scope."); + Variable* net = scope.FindVar(arg_->step_net); + PADDLE_ENFORCE(net != nullptr, "failed to get step net"); + for (int step_id = seq_len_ - 1; step_id >= 0; --step_id) { + if (static_cast(step_id) != seq_len_ - 1) { + rnn::LinkMemories(step_scopes, arg_->memories, step_id, 1); + } + net->GetMutable()->Run(*step_scopes[step_id], dev_ctx); + } + LinkBootMemoryGradients(step_scopes[0]); + rnn::ConcatOutputs(step_scopes, arg_->outlinks, seq_len_); +} + +void RecurrentGradientAlgorithm::LinkBootMemoryGradients( + 
Scope* step_scope) const { + for (auto& attr : arg_->memories) { + Tensor* mem_grad = step_scope->NewVar(attr.var)->GetMutable(); + PADDLE_ENFORCE(mem_grad != nullptr, + "boot_tensor should be retrieved before"); + PADDLE_ENFORCE(step_scope->FindVar(attr.boot_var) != nullptr, + "memory [%s]'s boot variable [%s] not exists", + attr.var, + attr.boot_var); + Tensor* boot_mem_grad = + step_scope->NewVar(attr.boot_var)->GetMutable(); + boot_mem_grad->ShareDataWith(*mem_grad); + } +} + +void RecurrentGradientAlgorithm::InferShape(const Scope& scope) const { + seq_len_ = scope.FindVar((arg_->inlinks[0]).external) + ->GetMutable() + ->dims()[0]; + auto step_scopes = GetStepScopes(scope); + rnn::SegmentInputs(step_scopes, arg_->inlinks, seq_len_); + + PADDLE_ENFORCE(scope.FindVar(arg_->step_net) != nullptr, + "step net is not in scope."); + Variable* net = scope.FindVar(arg_->step_net); + PADDLE_ENFORCE(net != nullptr, "failed to get step net"); + + for (int step_id = seq_len_ - 1; step_id >= 0; --step_id) { + if (static_cast(step_id) != seq_len_ - 1) { + rnn::LinkMemories(step_scopes, arg_->memories, step_id, 1); + } + net->GetMutable()->InferShape(*step_scopes[step_id]); + } + + auto outlinks = arg_->outlinks; + for (size_t i = 0; i < outlinks.size(); i++) { + DDim step_dims = step_scopes[0] + ->FindVar(outlinks[i].internal) + ->GetMutable() + ->dims(); + std::vector dims_vec = vectorize(step_dims); + // now only support fixed length + dims_vec.insert(dims_vec.begin(), seq_len_); + Tensor* output = + step_scopes[0]->FindVar(outlinks[i].external)->GetMutable(); + output->Resize(make_ddim(dims_vec)); + } + LinkBootMemoryGradients(step_scopes[0]); +} + +void RecurrentGradientOp::Init() { + OperatorBase::Init(); + std::unique_ptr arg(new rnn::Argument()); + rnn::InitArgument(kArgName, arg.get(), *this); + alg_.Init(std::move(arg)); +} + +} // namespace operators +} // namespace paddle + +REGISTER_OP(recurrent_op, + paddle::operators::RecurrentOp, + 
paddle::operators::RecurrentAlgorithmProtoAndCheckerMaker); diff --git a/paddle/operators/recurrent_network_op.h b/paddle/operators/recurrent_network_op.h new file mode 100644 index 0000000000000000000000000000000000000000..d57a1a2e51cbed22549ab6ebce79223e2d4e3bcf --- /dev/null +++ b/paddle/operators/recurrent_network_op.h @@ -0,0 +1,210 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + +#pragma once + +#include "paddle/framework/operator.h" + +namespace paddle { +namespace operators { + +using namespace paddle::framework; + +namespace rnn { + +/** + * Memory of an RNN (same role as `Memory` in PaddlePaddle). + * + * Memory attributes cached by this op; dims will be inferred from + * boot memories in father scope. Other attributes are copied from Op's proto + * attributes. + */ +struct MemoryAttr { + // name of current state variable + std::string var; + // name of previous step's state variable + std::string pre_var; + // name of the variables to init this memory (same role as `boot_layer` in + // PaddlePaddle), which is stored in the father scope. + std::string boot_var; +}; + +struct Link { + // alias of the link in the step net. + std::string internal; + // name of the op's input or output link.
+ std::string external; +}; + +struct Argument { + std::string step_net; + std::string step_scopes; + std::vector inlinks; + std::vector outlinks; + std::vector memories; +}; + +struct ArgumentName { + std::string step_net; + std::string step_scopes; + std::string inlinks; + std::string outlinks; + std::string inlink_alias; // the alias of inlinks in step net. + std::string outlink_alias; // the alias of outlinks in step net. + std::string memories; // the memory name + std::string pre_memories; // the previous memory name + std::string boot_memories; // the boot memory name +}; + +/** + * Prepare inputs for each step net. + */ +void SegmentInputs(const std::vector& step_scopes, + const std::vector& inlinks, + const size_t seq_len); + +/** + * Process outputs of step nets and merge to variables. + */ +void ConcatOutputs(const std::vector& step_scopes, + const std::vector& outlinks, + const size_t seq_len); + +void LinkMemories(const std::vector& step_scopes, + const std::vector& memories, + size_t step_id, + int offset); + +void InitArgument(const ArgumentName& name, Argument* arg, const OperatorBase& op); + +} // namespace rnn + +// The sequence format in RecurrentOp is Tensor now. +// TODO: +// 1. No-padding computing for sequences with indefinite length in one batch. +// 2. Hierarchical RNN for sequence with sub-sequence. +// 3. Internal Memory. +// 4. More Complex RNN architecture, such as Gated Feedback RNN. +// Refer to: https://arxiv.org/pdf/1502.02367.pdf + +class RecurrentAlgorithm { +public: + void Run(const Scope& scope, const platform::DeviceContext& dev_ctx) const; + + void Init(std::unique_ptr arg) { arg_ = std::move(arg); } + + /** + * InferShape must be called before Run. + */ + void InferShape(const Scope& scope) const; + +protected: + /* + * The step scopes will be stored in the father scope as a variable. + * + * NOTE the scopes are reused in both the forward and backward, so just + * create once and expand its size if more steps are needed.
+ */ + void CreateScopes(const Scope& scope) const; + + const std::vector& GetStepScopes(const Scope& scope) const { + return *scope.FindVar(arg_->step_scopes)->GetMutable>(); + } + + void InitMemories(Scope* step_scopes) const; + +private: + std::unique_ptr arg_; + mutable size_t seq_len_; +}; + +class RecurrentGradientAlgorithm { + /** + * RNN's backward algorithm. + * + * To accelerate the development of RecurrentGradientOp, we decouple RNN's + * algorithm from `OperatorBase`'s implementation. + * The former contains the core implementation of an RNN and will stay stable + * even if the framework changes a lot. + * The latter is a wrapper that acts like an adapter for it, making the RNN + * an operator. + */ +public: + void Init(std::unique_ptr arg) { arg_ = std::move(arg); } + + void Run(const Scope& scope, const platform::DeviceContext& dev_ctx) const; + + void LinkBootMemoryGradients(Scope* step_scopes) const; + + /** + * InferShape must be called before Run. + */ + void InferShape(const Scope& scope) const; + +protected: + inline const std::vector& GetStepScopes(const Scope& scope) const { + return *scope.FindVar(arg_->step_scopes)->GetMutable>(); + } + +private: + std::unique_ptr arg_; + mutable size_t seq_len_; +}; + +class RecurrentOp final : public OperatorBase { +public: + void Init() override; + + /** + * InferShape must be called before Run. + */ + virtual void InferShape(const Scope& scope) const override { + alg_.InferShape(scope); + } + + virtual void Run(const Scope& scope, + const platform::DeviceContext& dev_ctx) const override { + alg_.Run(scope, dev_ctx); + } + + static const rnn::ArgumentName kArgName; + +private: + RecurrentAlgorithm alg_; +}; + +class RecurrentGradientOp final : public OperatorBase { +public: + void Init() override; + + /** + * InferShape must be called before Run.
+ */ + virtual void InferShape(const Scope& scope) const override { + alg_.InferShape(scope); + } + + virtual void Run(const Scope& scope, + const platform::DeviceContext& dev_ctx) const override { + alg_.Run(scope, dev_ctx); + } + + static const rnn::ArgumentName kArgName; + +private: + RecurrentGradientAlgorithm alg_; +}; + +} // namespace operators +} // namespace paddle diff --git a/paddle/operators/recurrent_network_op_test.cc b/paddle/operators/recurrent_network_op_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..b0e61fbee611744adb85b498b1c3540f059afc8c --- /dev/null +++ b/paddle/operators/recurrent_network_op_test.cc @@ -0,0 +1,399 @@ +/* + Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. 
+*/ + +#include +#include + +#include "paddle/framework/net.h" +#include "paddle/framework/op_registry.h" +#include "paddle/framework/operator.h" +#include "paddle/framework/tensor.h" +#include "paddle/operators/recurrent_network_op.h" + +namespace paddle { +namespace operators { + +class RecurrentOpTest : public ::testing::Test { +protected: + virtual void SetUp() override { + CreateGlobalVariables(); + CreateStepNet(); + CreateRNNOp(); + } + + virtual void TearDown() override {} + + void CreateGlobalVariables() { + // create input, and init content + LOG(INFO) << "create global variable x"; + for (auto inlink : std::vector{"x", "x0", "x1", "h"}) { + Variable* x = scope_.NewVar(inlink); + DDim dims = make_ddim(std::vector{ + 10 /*sent size*/, 20 /*batch size*/, 30 /*input dim*/}); + x->GetMutable()->mutable_data(dims, platform::CPUPlace()); + } + // create output alias just for test + for (auto inlink : std::vector{"h@alias"}) { + Variable* x = scope_.NewVar(inlink); + DDim dims = + make_ddim(std::vector{20 /*batch size*/, 30 /*input dim*/}); + x->GetMutable()->mutable_data(dims, platform::CPUPlace()); + } + + LOG(INFO) << "create global variable w"; + Variable* w = scope_.NewVar("rnn/w"); + w->GetMutable()->mutable_data( + make_ddim(std::vector{30, 30}), platform::CPUPlace()); + + for (auto boot : std::vector{"x_boot", "h_boot"}) { + LOG(INFO) << "create global variable " << boot; + Variable* h_boot = scope_.NewVar(boot); + h_boot->GetMutable()->mutable_data( + make_ddim(std::vector{20 /*batch size*/, 30 /*input dim*/}), + platform::CPUPlace()); + } + + LOG(INFO) << "create variable step_scopes"; + scope_.NewVar("step_scopes"); + + LOG(INFO) << "create variable h"; + scope_.NewVar("h"); + } + + void CreateRNNOp() { + OpDesc op_desc; + + op_desc.set_type("recurrent_op"); + // inlinks 0 + op_desc.add_inputs("x"); + op_desc.add_inputs("x0"); + op_desc.add_inputs("x1"); + // boot_memories 3 + op_desc.add_inputs("x_boot"); + op_desc.add_inputs("h_boot"); + // step net 
5 + op_desc.add_inputs("step_net"); + // outlinks 6 + op_desc.add_outputs("h"); + // step scopes 7 + op_desc.add_outputs("step_scopes"); + + auto _input_format = std::vector{ + 0, // in_link + 3, // memories + 5 // step_net + }; + auto input_format = op_desc.add_attrs(); + input_format->set_name("input_format"); + input_format->set_type(paddle::framework::AttrType::INTS); + for (auto i : _input_format) { + input_format->add_ints(i); + } + + auto output_format = op_desc.add_attrs(); + output_format->set_name("output_format"); + output_format->set_type(paddle::framework::AttrType::INTS); + for (auto i : std::vector{0, 1, 2}) { + output_format->add_ints(i); + } + + auto inlink_alias = op_desc.add_attrs(); + inlink_alias->set_name("inlink_alias"); + inlink_alias->set_type(paddle::framework::AttrType::STRINGS); + + auto outlink_alias = op_desc.add_attrs(); + outlink_alias->set_name("outlink_alias"); + outlink_alias->set_type(paddle::framework::AttrType::STRINGS); + + auto pre_memories = op_desc.add_attrs(); + pre_memories->set_name("pre_memories"); + pre_memories->set_type(paddle::framework::AttrType::STRINGS); + + auto memories = op_desc.add_attrs(); + memories->set_name("memories"); + memories->set_type(paddle::framework::AttrType::STRINGS); + + // create inlink_alias + for (const auto& item : + std::vector{"x@alias", "x0@alias", "x1@alias"}) { + inlink_alias->add_strings(item); + } + // pre memories + for (const auto& item : + std::vector{"rnn/x@pre", "rnn/h@pre"}) { + pre_memories->add_strings(item); + } + // memories + for (const auto& item : std::vector{"rnn/x", "rnn/h"}) { + memories->add_strings(item); + } + // output alias + for (const auto& item : std::vector{"h@alias"}) { + outlink_alias->add_strings(item); + } + + rnn_op_ = OpRegistry::CreateOp(op_desc); + + LOG(INFO) << "rnn_op finish init"; + } + + void CreateStepNet() { + LOG(INFO) << "create variable step_net"; + Variable* var = scope_.NewVar("step_net"); + auto net = var->GetMutable(); + // rnn/s is 
net's input or output? + net->inputs_ = {"rnn/h@pre", "rnn/w", "rnn/x"}; + net->outputs_ = {"rnn/s", "rnn/h"}; + net->AddOp( + OpRegistry::CreateOp("mul", {"rnn/h@pre", "rnn/w"}, {"rnn/s"}, {})); + + net->AddOp( + OpRegistry::CreateOp("add_two", {"rnn/x", "rnn/s"}, {"rnn/h"}, {})); + net->CompleteAddOp(); + } + + // father scope + Scope scope_; + std::shared_ptr rnn_op_; +}; + +TEST_F(RecurrentOpTest, Run) { + platform::CPUDeviceContext ctx; + rnn_op_->InferShape(scope_); + rnn_op_->Run(scope_, ctx); +} + +class RecurrentGradientAlgorithmTest : public ::testing::Test { +protected: + virtual void SetUp() override { + CreateGlobalVariables(); + CreateStepScopes(); + CreateStepNet(); + CreateRNNGradientAlgorithm(); + + // segment inputs + SegmentInputs(); + // link forward memories + LinkeMemories(); + } + + virtual void TearDown() override {} + + void CreateGlobalVariables() { + // inputs: x + LOG(INFO) << "create global variable x"; + Variable* x = scope_.NewVar("x"); + DDim dims = + make_ddim({10 /*sent size*/, 20 /*batch size*/, 30 /*input dim*/}); + x->GetMutable()->mutable_data(dims, platform::CPUPlace()); + // inputs: h_boot + LOG(INFO) << "create global variable h_boot"; + Variable* h_boot = scope_.NewVar("h_boot"); + h_boot->GetMutable()->mutable_data( + make_ddim({20 /*batch size*/, 30 /*input dim*/}), platform::CPUPlace()); + // inputs: w + LOG(INFO) << "create global variable w"; + Variable* w = scope_.NewVar("rnn/w"); + w->GetMutable()->mutable_data(make_ddim({30, 30}), + platform::CPUPlace()); + // inputs: h_grad + LOG(INFO) << "create variable h_grad"; + Variable* dh = scope_.NewVar("h_grad"); + dh->GetMutable()->mutable_data(make_ddim({10, 20, 30}), + platform::CPUPlace()); + // inputs: step_scopes + LOG(INFO) << "create variable step_scopes"; + scope_.NewVar("step_scopes"); + // inputs: step_net + LOG(INFO) << "create variable step_net"; + scope_.NewVar("step_net"); + // outputs: w_grad + LOG(INFO) << "create global variable w_grad"; +
scope_.NewVar("rnn/w_grad"); + // outputs: x_grad + LOG(INFO) << "create global variable x_grad"; + scope_.NewVar("x_grad"); + // outputs: h_boot_grad + LOG(INFO) << "create global variable h_boot_grad"; + scope_.NewVar("h_boot_grad"); + } + + void CreateStepScopes() { + auto step_scopes = + scope_.FindVar("step_scopes")->GetMutable>(); + for (int i = 0; i < 10; ++i) { + auto& scope = scope_.NewScope(); + auto pre_t = scope.NewVar("rnn/pre_h")->GetMutable(); + pre_t->mutable_data({20, 30}, platform::CPUPlace()); + auto tensor = scope.NewVar("rnn/h")->GetMutable(); + tensor->mutable_data({20, 30}, platform::CPUPlace()); + + // for unit test of ConcatOutputs + auto xg = scope.NewVar("rnn/x_grad")->GetMutable(); + xg->mutable_data({20, 30}, platform::CPUPlace()); + + step_scopes->emplace_back(&scope); + } + + // last time step + auto g = (*step_scopes)[9]->NewVar("rnn/h_pre_grad")->GetMutable(); + g->mutable_data({20, 30}, platform::CPUPlace()); + } + + void CreateRNNGradientAlgorithm() { + std::unique_ptr arg(new rnn::Argument()); + arg->step_net = "step_net"; + arg->step_scopes = "step_scopes"; + rnn::Link inlink; + inlink.external = "h_grad"; + inlink.internal = "rnn/h_grad"; + arg->inlinks = std::vector{inlink}; + + rnn::Link outlink; + outlink.external = "x_grad"; + outlink.internal = "rnn/x_grad"; + arg->outlinks = std::vector{outlink}; + + rnn::MemoryAttr mem_attr; + mem_attr.pre_var = "rnn/h_pre_grad"; + mem_attr.var = "rnn/h_grad"; + mem_attr.boot_var = "h_boot_grad"; + arg->memories = std::vector{mem_attr}; + + rnn_grad_algo_.Init(std::move(arg)); + } + + void CreateStepNet() { + LOG(INFO) << "create variable step_net"; + Variable* var = scope_.NewVar("step_net"); + auto net = var->GetMutable(); + net->AddOp(OpRegistry::CreateOp("mul", + {"rnn/h_pre", "rnn/w", "rnn/s_grad"}, + {"rnn/h_pre_grad", "rnn/w_grad"}, + {})); + + net->AddOp(OpRegistry::CreateOp( + "add_two", {"rnn/h_grad"}, {"rnn/x_grad", "rnn/s_grad"}, {})); + net->CompleteAddOp(); + } + + void 
SegmentInputs() { + LOG(INFO) << "segment inputs"; + std::vector inlinks = {"x"}; + std::vector inlinks_alias = {"rnn/x"}; + + rnn::Link inlink; + inlink.external = "x"; + inlink.internal = "rnn/x"; + auto step_scopes = + scope_.FindVar("step_scopes")->GetMutable>(); + rnn::SegmentInputs(*step_scopes, std::vector{inlink}, 10); + } + + void LinkeMemories() { + LOG(INFO) << "link memories"; + rnn::MemoryAttr mem_attr; + mem_attr.pre_var = "rnn/h_pre"; + mem_attr.var = "rnn/h"; + mem_attr.boot_var = "boot_h"; + std::vector memories; + memories.push_back(mem_attr); + auto step_scopes = + scope_.FindVar("step_scopes")->GetMutable>(); + for (int i = 1; i < 10; ++i) { + rnn::LinkMemories(*step_scopes, memories, i, -1); + } + } + + Scope scope_; + RecurrentGradientAlgorithm rnn_grad_algo_; +}; + +// TEST_F(RecurrentGradientAlgorithmTest, Run) { +// platform::CPUDeviceContext ctx; +// rnn_grad_algo_.Run(scope_, ctx); +// } + +} // namespace operators +} // namespace paddle + +TEST(RecurrentOp, LinkMemories) { + using namespace paddle::framework; + using namespace paddle::platform; + using namespace paddle::operators; + + // create and init step scopes + int len = 10; + std::vector step_scopes; + for (int i = 0; i < len; ++i) { + auto scope = new Scope(); + scope->NewVar("pre_h"); + auto tensor = scope->NewVar("h")->GetMutable(); + float* data = tensor->mutable_data({15, 20}, CPUPlace()); + for (int j = 0; j < 15 * 20; ++j) { + data[j] = rand() * (1. 
/ (double)RAND_MAX); + } + step_scopes.push_back(scope); + } + + // create MemoryAttr + rnn::MemoryAttr mem_attr; + mem_attr.pre_var = "pre_h"; + mem_attr.var = "h"; + mem_attr.boot_var = "boot_h"; + std::vector memories; + memories.push_back(mem_attr); + + for (int i = 1; i < len; ++i) { + rnn::LinkMemories(step_scopes, memories, i, -1); + } + // check + for (int i = 0; i < len - 1; ++i) { + const float* a = + step_scopes[i]->FindVar("h")->GetMutable()->data(); + const float* b = step_scopes[i + 1] + ->FindVar("pre_h") + ->GetMutable() + ->data(); + for (size_t i = 0; i < 15 * 20; ++i) { + ASSERT_FLOAT_EQ(a[i], b[i]); + } + } + + for (int i = len - 2; i >= 0; --i) { + rnn::LinkMemories(step_scopes, memories, i, 1); + } + // check + for (int i = len - 2; i >= 0; --i) { + const float* a = + step_scopes[i]->FindVar("pre_h")->GetMutable()->data(); + const float* b = + step_scopes[i + 1]->FindVar("h")->GetMutable()->data(); + for (size_t i = 0; i < 15 * 20; ++i) { + ASSERT_FLOAT_EQ(a[i], b[i]); + } + } + + for (auto s : step_scopes) { + delete s; + } +} + +USE_OP(add_two); +USE_OP(mul); + +// int main() { +// //! TODO(yuyang18): Temporary disable this unit-test because implementation +// //! error. 
+// return 0; +//} \ No newline at end of file diff --git a/paddle/operators/rnn_design.md b/paddle/operators/rnn_design.md new file mode 100644 index 0000000000000000000000000000000000000000..3d38b9a0ad225fd8e0c1bb037474b292b1887f5b --- /dev/null +++ b/paddle/operators/rnn_design.md @@ -0,0 +1,239 @@ +# RNN 变长输入设计 +对变长序列的学习,现有主流框架比如 tensorflow, pytorch, caffe2, mxnet 等均使用了padding的方式, +即将一个mini-batch内不同长度的序列补0到固定长度参与计算。 + +现有Paddle包括 `RecurrentLayerGroup` 在内的RNN均实现了无padding的变长序列支持,本文也将基于该模块的思路,设计重构后的变长序列支持。 + +## 背景介绍 +由于tensor必须有明确的shape,因此基于tensor 的主流框架在存储变长序列时, +必须用zero-padding的方式将变长序列补全为固定shape的tensor。 + +由于padding是一种框架实现变长序列的妥协, 从用户角度,在使用RNN类模型时自然会比较介意padding的存在, +因此会有pytorch中对非padding方式变长序列支持长篇的讨论[3]。 + +由于padding对内存和计算会有额外的消耗,tensorflow和mxnet均使用了bucketing来进行优化[1][2], +但不管是padding还是bucket,对于用户都是额外的使用负担。 + +因此,**paddle原生支持变长序列的方式,能直接满足用户对变长序列的最直接的需求,在当前主流平台中可以算是一大优势**。 + +但对变长序列的支持,需要对目前框架做一些修改,下面讨论如何在最小修改下支持变长序列。 + +## 多层序列数据格式 `LODTensor` +目前 Paddle 会将一个mini-batch内的数据存储在一维的内存上, +额外使用 `Argument.sequenceStartPositions` 来存储每个句子的信息。 + +Paddle里使用 `Argument.subSequenceStartPositions` 来存储2层的序列信息,更高维度的序列则无法直接支持; + +为了支持 `N-level` 序列的存储,本文将序列信息定义成如下数据结构: + +```c++ +std::shared_ptr>> lod_start_pos_; +``` + +或者更明确的定义 + +```c++ +typedef std::vector level_t; +std::vector lod_start_pos; +``` + +这里的每一个 `level_t` 存储一个粒度(level)的偏移信息,和paddle目前做法一致。 + +为了更透明地传递序列信息,我们引入了一种新的tensor 称为 `LODTensor`[4], +其关于tensor相关的接口都直接继承自 `Tensor`,但另外添加了序列相关接口。 +如此,在操作一个 `LODTensor` 时,普通 `Op` 直接当成 `Tensor` 使用, +而操作序列的 `Op` 会额外操作 `LODTensor` 的变长序列操作的相关接口。 + +`LODTensor` 具体定义如下: + +```c++ +class LODTensor : public Tensor { +public: + size_t Levels() const { return seq_start_positions_.size(); } + size_t Elements(int level = 0) const { + return seq_start_positions_[level].size(); + } + // slice of level[elem_begin: elem_end] + // NOTE low performance in slice seq_start_positions_. + // TODO should call Tensor's Slice. 
+ LODTensor LODSlice(int level, int elem_begin, int elem_end) const; + + // slice with tensor's data shared with this. + LODTensor LODSliceShared(int level, int elem_begin, int elem_end) const; + + // copy other's lod_start_pos_, to share LOD info. + // NOTE the LOD info should not be changed. + void ShareConstLODFrom(const LODTensor &other) { + lod_start_pos_ = other.lod_start_pos_; + } + // copy other's lod_start_pos_'s content, free to mutate. + void ShareMutableLODFrom(const LODTensor &other) { + lod_start_pos_ = std::make_shared < + std::vector>(other.lod_start_pos_.begin(), + other.lod_start_pos_.end()); + } + +private: + std::shared_ptr>> lod_start_pos_; +}; +``` + +其中, `lod_start_pos_` 使用了 `shared_ptr` 来减少存储和复制的代价, +可以认为 `LODTensor` 是 `Tensor` 的扩展,几乎完全兼容原始 `Tensor` 的使用。 + +## 框架支持 +### 框架现有的 `Tensor` 调用替换为 `LODTensor` +为了实现 `LODTensor` 的传递,框架里很多 `Tensor` 都需要变成 `LODTensor`, +简单实现,直接 **把之前所有的`Tensor` 全部替换成 `LODTensor`,这里可以直接修改 `pybind.cc` 里面创建`Tensor`的接口**。 + +此外,用户有可能需要感知序列的存在(比如序列的可视化需要解析模型中输出的序列),因此一些序列操作的API也需要暴露到 python 层。 + +### `lod_start_pos` 随着Op调用链传递 +框架需要支持下列特性,以实现`lod_start_pos`的传递: + +1. 以 `shared_ptr` 的方式实现传递 + - 不修改 `lod_start_pos` 内容的作为 consumer + - 修改 `lod_start_pos` 的作为 producer + - 约定 consumer 只需要复制传递过来的 `shared_ptr` + - producer 需要创建自己的独立的内存,以存储自己独立的修改,并暴露 `shared_ptr` 给后续 consumer + - 由于传递过程是以复制`shared_ptr`的方式实现,因此框架只需要传递一次 `lod_start_pos` + +2. 对于不感知 `lod_start_pos` 的Op足够透明 +3. 
需要修改 `lod_start_pos` 的producer Op可以在 `Run` 时更新自己的 `lod_start_pos` 数据 + +具体的设计分为以下3小节 + +#### `lod_start_pos` 的传递 + +- 对于不需要修改 `lod_start_pos` 的情况,调用 LODTensor的 `ShareConstLODFrom` 接口实现复制 +- 需要修改的,调用`ShareMutableLODFrom` 接口自己分配内存以存储修改 + +#### 框架透明 +传递这一步需要加入到网络跑之前的初始化操作中,并且只需要初始化一次,基于当前框架设计的初步方案如下 + +- 在 Op 的 `attrs` 中添加一项 `do_mutate_lod_info` 的属性,默认为 `false` + - 有需要修改 `lod_start_pos` 的Op需要在定义 `OpProto` 时设置为 `true` +- `OperatorBase` 的 `InferShape` 中会读取 `do_mutate_lod_info` ,并且调用 `LODTensor` 相关的方法实现 `lod_start_pos` 的复制。 +- `OperatorBase` 中添加一个 member `is_lod_inited{false}` 来保证传递只进行一次 + +一些逻辑如下 + +```c++ +class OperatorBase { +public: + // ... + void InferShape() { + if (!is_lod_inited) { + bool do_mutate_lod_info = GetAttr("do_mutate_lod_info"); + // find a input having LOD to copy + auto lod_input = ValidLODInput(); + for (auto &output : outputs) { + if (do_mutate_lod_info) { + output.ShareMutableLODFrom(lod_input); + } else { + output.ShareConstLODFrom(lod_input); + } + } + is_lod_inited = true; + } + + // call op's InferShape + // ... + } + +private: + // ... 
+ bool is_lod_inited{false}; +}; +``` + +如此,`lod_start_pos` 的信息的传递对非LOD的Op的实现是完全透明的。 + +#### `lod_start_pos` 的更新 +上一小节介绍到,对于需要修改 `lod_start_pos` 的Op,`OperatorBase` 会分配一块自己的内存以存储修改, +Op在 `Run` 的实现中,操作更新自己的 `lod_start_pos` , +而所有依赖其 outputs 的 op 会通过共享的指针自动获取到其更新。 + +## 根据长度排序 +按照长度排序后,从前往后的时间步的batch size会自然地递减,可以直接塞入 Net 做batch计算 + +比如原始的输入: + +``` +origin: +xxxx +xx +xxx + +-> sorted: +xxxx +xxx +xx +``` + +经过 `SegmentInputs` 之后,每个会有4个时间步,每个时间步的输入如下(纵向排列) + +``` +0 1 2 3 +x x x x +x x x +x x +``` + +为了追踪排序前后序列的变化,这里用 +```c++ +struct SortedSeqItem { + void *start{nullptr}; + void *end{nullptr}; +}; + +std::vector sorted_seqs; +``` +来追踪序列排序后的位置,并添加一个新的接口 + +```c++ +std::vector SortBySeqLen(const LODTensor& tensor); +``` + +由于输入序列的顺序变化,以下现有的接口需要针对性地修改: + +- InitMemories, memory需要根据 `sorted_seqs` 重新排列 +- SegmentInputs +- ConcatOutputs + +此外,由于 `sorted_seqs` 需要被 `RecurrentGradientOp` 复用,因此会变成 `RecurrentOp` 一个新的output输出, +之后作为 `RecurrentGradientOp` 的一个输入传入。 + +## InitMemories +由于序列顺序的变化,`boot_memories` 的batch上的element的顺序也需要对应重新排列。 + +## SegmentInputs +`SegmentInputs` 会依赖 `sorted_seqs` 的信息,将原始的序列按照排序后的序列顺序,从横向切割,转为每个step中的inputs。 + +即下面的转变: +``` +origin: +xxxx +xx +xxx + + | + | + \ / + ! +0 1 2 3 +x x x x +x x x +x x +``` +## ConcatOutputs +`ConcatOutputs` 需要 + +- 将每个时间步的输出重新还原为原始输入的序列顺序(以防止Infer阶段顺序打乱) +- 将每个序列concat 为规则的mini-batch表示 + +## 参考文献 +1. [Tensorflow Bucketing](https://www.tensorflow.org/versions/r0.12/api_docs/python/contrib.training/bucketing) +2. [mxnet Bucketing](http://mxnet.io/how_to/bucketing.html) +3. [variable length input in RNN scenario](https://discuss.pytorch.org/t/about-the-variable-length-input-in-rnn-scenario/345/5) +4. 
[Level of details](https://en.wikipedia.org/wiki/Level_of_detail) diff --git a/paddle/operators/rowwise_add_op.cc b/paddle/operators/rowwise_add_op.cc index 4129422fa744b2a7cf135b681efa73ffb2ebcdcc..2ad2b66c8f385c858eb34c7ea766f168de9c817e 100644 --- a/paddle/operators/rowwise_add_op.cc +++ b/paddle/operators/rowwise_add_op.cc @@ -18,17 +18,17 @@ namespace operators { class RowWiseAddOp : public OperatorWithKernel { protected: - void InferShape(const std::vector &inputs, - const std::vector &outputs) const override { - PADDLE_ENFORCE(inputs.size() == 2UL, "Two inputs is needed by rowwise add"); - auto dim0 = inputs[0]->dims(); - auto dim1 = inputs[1]->dims(); + void InferShape(const InferShapeContext &ctx) const override { + PADDLE_ENFORCE(ctx.InputSize() == 2UL, + "Two inputs is needed by rowwise add"); + auto dim0 = ctx.Input(0)->dims(); + auto dim1 = ctx.Input(1)->dims(); PADDLE_ENFORCE(dim0.size() == 2, "Input 0 must be matrix"); PADDLE_ENFORCE(dim1.size() == 1, "The second input must be vector"); PADDLE_ENFORCE(dim0[1] == dim1[0], "The width of two input must be same"); - PADDLE_ENFORCE(outputs.size() == 1, "The output size must be 1"); - outputs[0]->Resize(inputs[0]->dims()); + PADDLE_ENFORCE(ctx.OutputSize() == 1, "The output size must be 1"); + ctx.Output(0)->Resize(ctx.Input(0)->dims()); } }; diff --git a/paddle/operators/rowwise_add_op.h b/paddle/operators/rowwise_add_op.h index 4596925e9322f373c822608fd9aa6ecee6144d4c..b86dd5463436bf521f9939b1c421b39f11102769 100644 --- a/paddle/operators/rowwise_add_op.h +++ b/paddle/operators/rowwise_add_op.h @@ -21,14 +21,12 @@ namespace operators { template class RowWiseAddKernel : public OpKernel { public: - void Compute(const KernelContext& context) const override { - auto in0 = context.Input(0)->Get(); - auto in1 = context.Input(1)->Get(); - auto* out = context.Output(0)->GetMutable(); + void Compute(const ExecutionContext& context) const override { + auto out = context.Output(0); 
out->mutable_data(context.GetPlace()); - auto input = EigenMatrix::From(in0); - auto bias = EigenVector::From(in1); + auto input = EigenMatrix::From(*context.Input(0)); + auto bias = EigenVector::From(*context.Input(1)); auto output = EigenMatrix::From(*out); const int bias_size = bias.dimension(0); diff --git a/paddle/operators/sgd_op.cc b/paddle/operators/sgd_op.cc index f6c654a9e7083704e353c276e0abc975f4e61ef9..9a84dc8af3b3e649b776ca8a97dedba1fa3ff48d 100644 --- a/paddle/operators/sgd_op.cc +++ b/paddle/operators/sgd_op.cc @@ -19,16 +19,15 @@ namespace operators { class SGDOp : public OperatorWithKernel { protected: - void InferShape(const std::vector &inputs, - const std::vector &outputs) const override { - PADDLE_ENFORCE(inputs.size() == 2, "Input size of SGDOp must be two"); - PADDLE_ENFORCE(outputs.size() == 1, "Output size of SGDOp must be one"); - PADDLE_ENFORCE(inputs[0] != nullptr, "inputs[0] mast be set"); - PADDLE_ENFORCE(inputs[1] != nullptr, "inputs[1] mast be set"); - PADDLE_ENFORCE(outputs[0] != nullptr, "outputs[0] mast be set"); - PADDLE_ENFORCE(inputs[0]->dims() == inputs[1]->dims(), + void InferShape(const InferShapeContext &ctx) const override { + PADDLE_ENFORCE(ctx.InputSize() == 2, "Input size of SGDOp must be two"); + PADDLE_ENFORCE(ctx.OutputSize() == 1, "Output size of SGDOp must be one"); + PADDLE_ENFORCE(ctx.InputVar(0) != nullptr, "inputs[0] mast be set"); + PADDLE_ENFORCE(ctx.InputVar(1) != nullptr, "inputs[1] mast be set"); + PADDLE_ENFORCE(ctx.OutputVar(0) != nullptr, "outputs[0] mast be set"); + PADDLE_ENFORCE(ctx.Input(0)->dims() == ctx.Input(1)->dims(), "Two input of SGD Op's dimension must be same."); - outputs[0]->Resize(inputs[0]->dims()); + ctx.Output(0)->Resize(ctx.Input(0)->dims()); } }; diff --git a/paddle/operators/sgd_op.h b/paddle/operators/sgd_op.h index 65179d323bd991b8b4e196c069a11cd901c62082..af1dfdd756ceb9991bee6b85c3281c05f0fb5a9f 100644 --- a/paddle/operators/sgd_op.h +++ b/paddle/operators/sgd_op.h @@ -21,16 
+21,16 @@ namespace operators { template class SGDOpKernel : public OpKernel { public: - void Compute(const KernelContext& ctx) const override { - auto param = ctx.Input("param")->Get(); - auto grad = ctx.Input("grad")->Get(); - auto* param_out = ctx.Output(0)->GetMutable(); + void Compute(const ExecutionContext& ctx) const override { + auto param = ctx.Input("param"); + auto grad = ctx.Input("grad"); + auto param_out = ctx.Output(0); float lr = ctx.op_.GetAttr("learning_rate"); param_out->mutable_data(ctx.GetPlace()); EigenVector::Flatten(*param_out).device(*(ctx.GetEigenDevice())) = - EigenVector::Flatten(param) - lr * EigenVector::Flatten(grad); + EigenVector::Flatten(*param) - lr * EigenVector::Flatten(*grad); } }; diff --git a/paddle/operators/sigmoid_op.cc b/paddle/operators/sigmoid_op.cc index 716f1d9c4dbc45e2d5569f8d634b06fd988a149c..a81ab262cc6fe7bdff0045259e0030f3d46f503f 100644 --- a/paddle/operators/sigmoid_op.cc +++ b/paddle/operators/sigmoid_op.cc @@ -18,11 +18,10 @@ namespace operators { class SigmoidOp : public OperatorWithKernel { protected: - void InferShape(const std::vector &inputs, - const std::vector &outputs) const override { - PADDLE_ENFORCE(inputs.size() == 1, "Sigmoid Op only have one input"); - PADDLE_ENFORCE(outputs.size() == 1, "Sigmoid Op only have one output"); - outputs[0]->Resize(inputs[0]->dims()); + void InferShape(const InferShapeContext &ctx) const override { + PADDLE_ENFORCE(ctx.InputSize() == 1, "Sigmoid Op only have one input"); + PADDLE_ENFORCE(ctx.OutputSize() == 1, "Sigmoid Op only have one output"); + ctx.Output(0)->Resize(ctx.Input(0)->dims()); } }; @@ -38,8 +37,7 @@ public: class SigmoidOpGrad : public OperatorWithKernel { protected: - void InferShape(const std::vector &inputs, - const std::vector &outputs) const override {} + void InferShape(const InferShapeContext &ctx) const override {} std::string DebugString() const override { LOG(INFO) << "SigmoidGrad"; return ""; diff --git a/paddle/operators/sigmoid_op.h 
b/paddle/operators/sigmoid_op.h index 896a6f5d83e0f96de50e3aaae6f545172bf5da14..3dd23a9ebc7ac0972d6ee07b9ac051d59e66f62f 100644 --- a/paddle/operators/sigmoid_op.h +++ b/paddle/operators/sigmoid_op.h @@ -22,15 +22,14 @@ namespace operators { template class SigmoidKernel : public OpKernel { public: - void Compute(const KernelContext& context) const override { - auto input = context.Input(0)->Get(); - auto* output = context.Output(0)->GetMutable(); - + void Compute(const ExecutionContext& context) const override { + auto input = context.Input(0); + auto output = context.Output(0); output->mutable_data(context.GetPlace()); EigenVector::Flatten(*output).device( *(context.GetEigenDevice())) = - 1.0 / (1.0 + (-1.0 * EigenVector::Flatten(input)).exp()); + 1.0 / (1.0 + (-1.0 * EigenVector::Flatten(*input)).exp()); } }; } // namespace operators diff --git a/paddle/operators/softmax_op.cc b/paddle/operators/softmax_op.cc index df60b62fa6ac8d67c9dadc40ec49aaedab92bc88..5b59fad7d5f9729b0862f8cd78cb32f94f87f513 100644 --- a/paddle/operators/softmax_op.cc +++ b/paddle/operators/softmax_op.cc @@ -18,14 +18,13 @@ namespace operators { class SoftmaxOp : public OperatorWithKernel { protected: - void InferShape(const std::vector &inputs, - const std::vector &outputs) const override { - PADDLE_ENFORCE(inputs.size() == 1, "Only one input is need for softmax"); - PADDLE_ENFORCE(inputs[0]->dims().size() == 2, + void InferShape(const InferShapeContext &ctx) const override { + PADDLE_ENFORCE(ctx.InputSize() == 1, "Only one input is need for softmax"); + PADDLE_ENFORCE(ctx.Input(0)->dims().size() == 2, "The input of softmax op must be matrix"); - PADDLE_ENFORCE(outputs.size() == 1, "Only one output is need for softmax"); - - outputs[0]->Resize(inputs[0]->dims()); + PADDLE_ENFORCE(ctx.OutputSize() == 1, + "Only one output is need for softmax"); + ctx.Output(0)->Resize(ctx.Input(0)->dims()); } }; @@ -41,8 +40,7 @@ public: class SoftmaxOpGrad : public OperatorWithKernel { protected: - void 
InferShape(const std::vector &inputs, - const std::vector &outputs) const override {} + void InferShape(const InferShapeContext &ctx) const override {} std::string DebugString() const override { LOG(INFO) << "SoftmaxOpGrad"; return ""; diff --git a/paddle/operators/softmax_op.h b/paddle/operators/softmax_op.h index 625a87b58560231572c1cca2a21bd0c47c8cb296..a5c19c5fc7c6f5909dbb355aff09bf15405b6957 100644 --- a/paddle/operators/softmax_op.h +++ b/paddle/operators/softmax_op.h @@ -22,12 +22,12 @@ namespace operators { template class SoftmaxKernel : public OpKernel { public: - void Compute(const KernelContext& context) const override { - auto input = context.Input(0)->Get(); - auto* output = context.Output(0)->GetMutable(); + void Compute(const ExecutionContext& context) const override { + auto input = context.Input(0); + auto output = context.Output(0); output->mutable_data(context.GetPlace()); - auto logits = EigenMatrix::From(input); + auto logits = EigenMatrix::From(*input); auto softmax = EigenMatrix::From(*output); const int kBatchDim = 0; diff --git a/paddle/operators/type_alias.h b/paddle/operators/type_alias.h index b712e457ff60e8b30b87c0d549693d53e9f05d59..93b62cddc819e0d1fd48323e474a294ff0d327e1 100644 --- a/paddle/operators/type_alias.h +++ b/paddle/operators/type_alias.h @@ -22,7 +22,13 @@ namespace paddle { namespace operators { using OpKernel = framework::OpKernel; -using KernelContext = framework::KernelContext; +using InferShapeContext = framework::InferShapeContext; +using ExecutionContext = framework::ExecutionContext; +using Variable = framework::Variable; +template +using EigenScalar = framework::EigenScalar; template diff --git a/paddle/platform/device_context.cc b/paddle/platform/device_context.cc index 9c1d94e9e703caf2db92ca4a8eac975317e6b945..a928e097787db9deebe1c6eab263190caacac7eb 100644 --- a/paddle/platform/device_context.cc +++ b/paddle/platform/device_context.cc @@ -20,12 +20,101 @@ Eigen::DefaultDevice* DeviceContext::get_eigen_device() 
return reinterpret_cast(this)->eigen_device(); } +CPUDeviceContext::CPUDeviceContext() { + eigen_device_.reset(new Eigen::DefaultDevice()); +} + +CPUDeviceContext::CPUDeviceContext(CPUPlace place) { + eigen_device_.reset(new Eigen::DefaultDevice()); +} + +Eigen::DefaultDevice* CPUDeviceContext::eigen_device() const { + return eigen_device_.get(); +} + +Place CPUDeviceContext::GetPlace() const { return CPUPlace(); } + #ifndef PADDLE_ONLY_CPU + template <> Eigen::GpuDevice* DeviceContext::get_eigen_device() const { return reinterpret_cast(this)->eigen_device(); } -#endif + +CUDADeviceContext::CUDADeviceContext(GPUPlace place) : place_(place) { + SetDeviceId(place_.device); + // TODO(qijun) Pass a created cuda stream to Eigen::CudaStreamDevice directly + // here will cause segment fault. We must implement a class derived from + // Eigen::StreamInterface, and reinitialize it with a cuda stream and a gpu id + // later. Please refer to the implementation of class EigenCudaStreamDevice + // in TensorFlow. + // + // We find that CUDA 7 introduces a new option, the per-thread default stream, + // that has two effects. Please refer to https://devblogs.nvidia.com/ + // parallelforall/gpu-pro-tip-cuda-7-streams-simplify-concurrency/ + // + // So, we decide to use default stream and add –default-stream per-thread nvcc + // flag. Than, two threads with two CUDADeviceContexts will run parallelly. 
+ eigen_stream_.reset(new Eigen::CudaStreamDevice()); + eigen_device_.reset(new Eigen::GpuDevice(eigen_stream_.get())); +} + +CUDADeviceContext::~CUDADeviceContext() { + SetDeviceId(place_.device); + Wait(); + if (cublas_handle_) { + PADDLE_ENFORCE(dynload::cublasDestroy(cublas_handle_)); + } + + if (cudnn_handle_) { + PADDLE_ENFORCE(dynload::cudnnDestroy(cudnn_handle_)); + } + + if (curand_generator_) { + PADDLE_ENFORCE(dynload::curandDestroyGenerator(curand_generator_)); + } + eigen_stream_.reset(); + eigen_device_.reset(); +} + +Place CUDADeviceContext::GetPlace() const { return place_; } + +void CUDADeviceContext::Wait() const { + PADDLE_ENFORCE(cudaStreamSynchronize(0)); +} + +Eigen::GpuDevice* CUDADeviceContext::eigen_device() const { + return eigen_device_.get(); +} + +cublasHandle_t CUDADeviceContext::cublas_handle() { + if (!cublas_handle_) { + SetDeviceId(place_.device); + PADDLE_ENFORCE(dynload::cublasCreate(&cublas_handle_)); + } + return cublas_handle_; +} + +cudnnHandle_t CUDADeviceContext::cudnn_handle() { + if (!cudnn_handle_) { + SetDeviceId(place_.device); + PADDLE_ENFORCE(dynload::cudnnCreate(&cudnn_handle_)); + } + return cudnn_handle_; +} + +curandGenerator_t CUDADeviceContext::curand_generator() { + if (!curand_generator_) { + SetDeviceId(place_.device); + PADDLE_ENFORCE(dynload::curandCreateGenerator(&curand_generator_, + CURAND_RNG_PSEUDO_DEFAULT)); + PADDLE_ENFORCE( + dynload::curandSetPseudoRandomGeneratorSeed(curand_generator_, seed_)); + } + return curand_generator_; +} + +#endif // PADDLE_ONLY_CPU } // namespace platform } // namespace paddle diff --git a/paddle/platform/device_context.h b/paddle/platform/device_context.h index 5a366dccdc080ce61cf45ea2e22e23b703313682..2038fafe2e15ec2631726643695ac6cbc317fed9 100644 --- a/paddle/platform/device_context.h +++ b/paddle/platform/device_context.h @@ -39,14 +39,13 @@ class DeviceContext { class CPUDeviceContext : public DeviceContext { public: - CPUDeviceContext() { eigen_device_.reset(new 
Eigen::DefaultDevice()); } + CPUDeviceContext(); + CPUDeviceContext(CPUPlace); + virtual ~CPUDeviceContext() {} - Eigen::DefaultDevice* eigen_device() const { return eigen_device_.get(); } + Eigen::DefaultDevice* eigen_device() const; - Place GetPlace() const override { - Place retv = CPUPlace(); - return retv; - } + Place GetPlace() const override; private: std::unique_ptr eigen_device_; @@ -54,119 +53,46 @@ class CPUDeviceContext : public DeviceContext { #ifndef PADDLE_ONLY_CPU -class GPUPlaceGuard { +class CUDADeviceContext : public DeviceContext { public: - explicit GPUPlaceGuard(GPUPlace new_place) : previous_(GetCurrentDeviceId()) { - if (previous_ != new_place) { - paddle::platform::SetDeviceId(new_place.device); - } - } + explicit CUDADeviceContext(GPUPlace); + virtual ~CUDADeviceContext(); - ~GPUPlaceGuard() { paddle::platform::SetDeviceId(previous_.device); } + /*! \brief Wait for all operations completion in the stream. */ + void Wait() const; - private: - GPUPlace previous_; -}; + /*! \brief Return place in the device context. 
*/ + Place GetPlace() const override; -class CUDADeviceContext : public DeviceContext { - public: - explicit CUDADeviceContext(const GPUPlace gpu_place) : gpu_place_(gpu_place) { - GPUPlaceGuard guard(gpu_place_); - PADDLE_ENFORCE(cudaStreamCreate(&stream_), "cudaStreamCreate failed"); - eigen_stream_.reset(new Eigen::CudaStreamDevice(&stream_)); - eigen_device_.reset(new Eigen::GpuDevice(eigen_stream_.get())); - } - - Place GetPlace() const override { - Place retv = GPUPlace(); - return retv; - } - - void Wait() { - PADDLE_ENFORCE(cudaStreamSynchronize(stream_), - "cudaStreamSynchronize failed"); - } - - cudaStream_t stream() const { return stream_; } - - Eigen::GpuDevice* eigen_device() const { return eigen_device_.get(); } - - cublasHandle_t cublas_handle() { - if (!blas_handle_) { - GPUPlaceGuard guard(gpu_place_); - PADDLE_ENFORCE(paddle::platform::dynload::cublasCreate(&blas_handle_), - "cublasCreate failed"); - PADDLE_ENFORCE( - paddle::platform::dynload::cublasSetStream(blas_handle_, stream_), - "cublasSetStream failed"); - } - return blas_handle_; - } - - cudnnHandle_t cudnn_handle() { - if (!dnn_handle_) { - GPUPlaceGuard guard(gpu_place_); - PADDLE_ENFORCE(paddle::platform::dynload::cudnnCreate(&dnn_handle_), - "cudnnCreate failed"); - PADDLE_ENFORCE( - paddle::platform::dynload::cudnnSetStream(dnn_handle_, stream_), - "cudnnSetStream failed"); - } - return dnn_handle_; - } - - curandGenerator_t curand_generator() { - if (!rand_generator_) { - GPUPlaceGuard guard(gpu_place_); - PADDLE_ENFORCE(paddle::platform::dynload::curandCreateGenerator( - &rand_generator_, CURAND_RNG_PSEUDO_DEFAULT), - "curandCreateGenerator failed"); - PADDLE_ENFORCE( - paddle::platform::dynload::curandSetPseudoRandomGeneratorSeed( - rand_generator_, random_seed_), - "curandSetPseudoRandomGeneratorSeed failed"); - PADDLE_ENFORCE( - paddle::platform::dynload::curandSetStream(rand_generator_, stream_), - "curandSetStream failed"); - } - return rand_generator_; - } - - 
~CUDADeviceContext() { - Wait(); - if (blas_handle_) { - PADDLE_ENFORCE(paddle::platform::dynload::cublasDestroy(blas_handle_), - "cublasDestroy failed"); - } - - if (dnn_handle_) { - PADDLE_ENFORCE(paddle::platform::dynload::cudnnDestroy(dnn_handle_), - "cudnnDestroy failed"); - } - - if (rand_generator_) { - PADDLE_ENFORCE( - paddle::platform::dynload::curandDestroyGenerator(rand_generator_), - "curandDestroyGenerator failed"); - } - eigen_stream_.reset(); - eigen_device_.reset(); - PADDLE_ENFORCE(cudaStreamDestroy(stream_), "cudaStreamDestroy failed"); - } + /*! \brief Return eigen device in the device context. */ + Eigen::GpuDevice* eigen_device() const; + + // clang-format off + /*! \brief Return cublas handle in the device context. */ + cublasHandle_t cublas_handle (); + + /*! \brief Return cudnn handle in the device context. */ + cudnnHandle_t cudnn_handle (); + + /*! \brief Return curand handle in the device context. */ + curandGenerator_t curand_generator(); + // clang-format on private: - GPUPlace gpu_place_; - cudaStream_t stream_; + GPUPlace place_; - std::unique_ptr eigen_stream_; + private: std::unique_ptr eigen_device_; + std::unique_ptr eigen_stream_; - cublasHandle_t blas_handle_{nullptr}; - - cudnnHandle_t dnn_handle_{nullptr}; + private: + uint64_t seed_; - int random_seed_; - curandGenerator_t rand_generator_{nullptr}; + // clang-format off + cudnnHandle_t cudnn_handle_ = nullptr; + cublasHandle_t cublas_handle_ = nullptr; + curandGenerator_t curand_generator_ = nullptr; + // clang-format on }; #endif diff --git a/paddle/platform/enforce.h b/paddle/platform/enforce.h index a3a10fc07fed1cb245dc721d18ea927ea25393a1..26c8eb78e614a68ec9728aad727d8fe3e08547ae 100644 --- a/paddle/platform/enforce.h +++ b/paddle/platform/enforce.h @@ -14,7 +14,9 @@ limitations under the License. 
*/ #pragma once +#include #include +#include #include #include #include @@ -39,12 +41,22 @@ namespace platform { struct EnforceNotMet : public std::exception { std::exception_ptr exp_; std::string err_str_; - EnforceNotMet(std::exception_ptr e, const char* f, int l) : exp_(e) { + static constexpr int TRACE_STACK_LIMIT = 100; try { std::rethrow_exception(exp_); } catch (const std::exception& exp) { - err_str_ = string::Sprintf("%s at [%s:%d]", exp.what(), f, l); + std::ostringstream sout; + sout << string::Sprintf("%s at [%s:%d]", exp.what(), f, l) << std::endl; + sout << "Call Stacks: " << std::endl; + void* call_stack[TRACE_STACK_LIMIT]; + int sz = backtrace(call_stack, TRACE_STACK_LIMIT); + auto line = backtrace_symbols(call_stack, sz); + for (int i = 0; i < sz; ++i) { + sout << line[i] << std::endl; + } + free(line); + err_str_ = sout.str(); } } @@ -58,11 +70,6 @@ struct EnforceNotMet : public std::exception { // For more details, please check https://stackoverflow.com/a/43870188/724872. #define UNLIKELY(condition) __builtin_expect(static_cast(condition), 0) -template -inline void throw_on_error(T e) { - throw_on_error(e, ""); -} - template inline typename std::enable_if::type throw_on_error( int stat, const Args&... args) { @@ -132,6 +139,11 @@ inline typename std::enable_if::type throw_on_error( #endif // PADDLE_ONLY_CPU +template +inline void throw_on_error(T e) { + throw_on_error(e, ""); +} + #define PADDLE_THROW(...) 
\ do { \ throw ::paddle::platform::EnforceNotMet( \ diff --git a/paddle/pybind/CMakeLists.txt b/paddle/pybind/CMakeLists.txt index fd1a142b40e19d257505f0465ce6c7a62e5cbc35..845589dcb1997b662b5175e5cce320eec4be4a8d 100644 --- a/paddle/pybind/CMakeLists.txt +++ b/paddle/pybind/CMakeLists.txt @@ -1,2 +1,9 @@ -cc_library(paddle_pybind SHARED SRCS pybind.cc DEPS pybind python - add_op fc_op sgd_op cross_entropy_op) +cc_library(paddle_pybind SHARED + SRCS pybind.cc + DEPS pybind python + fc_op + sgd_op + add_op + mean_op + cross_entropy_op + recurrent_network_op) diff --git a/paddle/pybind/pybind.cc b/paddle/pybind/pybind.cc index ccefcd2511ca0f132d127463166f9f40779a1d85..801ef50e577d563f4534f33e49aa7b72ab840d89 100644 --- a/paddle/pybind/pybind.cc +++ b/paddle/pybind/pybind.cc @@ -33,9 +33,11 @@ USE_OP(onehot_cross_entropy); USE_OP_WITHOUT_KERNEL(fc); USE_OP(sgd); USE_OP(mul); +USE_OP(mean); USE_OP(sigmoid); USE_OP(softmax); USE_OP(rowwise_add); +USE_OP_WITHOUT_KERNEL(recurrent_op); template void ExposeOperator(ClassType& m) { @@ -94,17 +96,25 @@ All parameter, weight, gradient are variables in Paddle. [](pd::Variable& self) -> pd::Tensor* { return self.GetMutable(); }, + py::return_value_policy::reference) + .def("get_net", + [](pd::Variable& self) -> pd::NetOp* { + return self.GetMutable(); + }, py::return_value_policy::reference); - py::class_>(m, "Scope") - .def(py::init&>()) - .def("get_var", - &pd::Scope::GetVariable, + py::class_(m, "Scope", "") + .def("new_var", + [](pd::Scope& self, const std::string& name) -> pd::Variable* { + return self.NewVar(name); + }, py::return_value_policy::reference) - .def("create_var", - &pd::Scope::CreateVariable, + .def("find_var", &pd::Scope::FindVar, py::return_value_policy::reference) + .def(py::init<>()) + .def("new_scope", + [](pd::Scope& self) -> pd::Scope* { return &self.NewScope(); }, py::return_value_policy::reference) - .def("get_var_name", &pd::Scope::GetVariableName); + .def("drop_kids", &pd::Scope::DropKids); //! 
@note: Be careful! PyBind will return std::string as an unicode, not //! Python str. If you want a str object, you should cast them in Python. diff --git a/proto/ModelConfig.proto b/proto/ModelConfig.proto index 83f72c137bdf5e55f28be908321bd2ccd6c906fe..3bee5b572ae42750332b69e28af980ae325532da 100644 --- a/proto/ModelConfig.proto +++ b/proto/ModelConfig.proto @@ -198,6 +198,11 @@ message RowConvConfig { required uint32 context_length = 1; } +message SliceConfig { + required uint32 start = 1; + required uint32 end = 2; +} + message ProjectionConfig { required string type = 1; required string name = 2; @@ -218,6 +223,10 @@ message ProjectionConfig { // For pool optional PoolConfig pool_conf = 12; + + // For slice + // Each slice output is the input[start, end) + repeated SliceConfig slices = 13; } message OperatorConfig { diff --git a/python/paddle/trainer/config_parser.py b/python/paddle/trainer/config_parser.py index 3587ea1752a0cdd024543c25715448c874011f43..53a9c40619b37492895388552058943435885657 100644 --- a/python/paddle/trainer/config_parser.py +++ b/python/paddle/trainer/config_parser.py @@ -565,6 +565,35 @@ class IdentityOffsetProjection(Projection): return [] +@config_class +class SliceProjection(Projection): + type = 'slice' + + def __init__(self, input_layer_name, slices, **xargs): + super(SliceProjection, self).__init__(input_layer_name, **xargs) + input = g_layer_map[input_layer_name] + if input.type in ["exconv", "cudnn_conv"]: + # the slice operator is for the channel dimension + assert input.num_filters is not None + channels = input.num_filters + image_size = input.size / channels + assert slices[len(slices) - 1][1] <= channels + for i in xrange(len(slices)): + slice = self.proj_conf.slices.add() + slice.start = slices[i][0] * image_size + slice.end = slices[i][1] * image_size + self.size += slice.end - slice.start + else: + config_assert(False, + 'Currently the input should be convolution layer') + + def calc_parameter_size(self, input_size, 
output_size): + return 0 + + def calc_parameter_dims(self, input_size, output_size): + return [] + + # DotMulProjection performs element-wise multiplication with weight @config_class class DotMulProjection(Projection): diff --git a/python/paddle/trainer_config_helpers/layers.py b/python/paddle/trainer_config_helpers/layers.py index 9985a290a547eabb0d2c55d105be7132981599b3..31cef4acfb8d948171a1ed021f572bb18331ade7 100755 --- a/python/paddle/trainer_config_helpers/layers.py +++ b/python/paddle/trainer_config_helpers/layers.py @@ -129,6 +129,7 @@ __all__ = [ 'prelu_layer', 'gated_unit_layer', 'crop_layer', + 'slice_projection', ] @@ -538,6 +539,45 @@ def identity_projection(input, offset=None, size=None): return proj +def slice_projection(input, slices): + """ + slice_projection can slice the input value into multiple parts, + and then select some of them to merge into a new output. + + .. math:: + output = [input.slices()] + + The example usage is: + + .. code-block:: python + + proj = slice_projection(input=layer, slices=[(0, 10), (20, 30)]) + + Note that slice_projection should not have any parameter. + + :param input: Input Layer. + :type input: LayerOutput + :param slices: An array of slice parameters. + Each slice contains the start and end offsets based + on the input. + :type slices: pair of int + :return: A SliceProjection object + :rtype: SliceProjection + """ + assert len(slices) >= 1 + start = 0 + for i in xrange(len(slices)): + assert len(slices[i]) == 2 + # The start position of the next slice needs to be greater than + # or equal to the end position of the previous slice. 
+ assert slices[i][0] >= start + assert slices[i][1] >= slices[i][0] + start = slices[i][1] + proj = SliceProjection(input_layer_name=input.name, slices=slices) + proj.origin = input + return proj + + @wrap_param_attr_default() def scaling_projection(input, param_attr=None): """ diff --git a/python/paddle/v2/framework/default_scope_funcs.py b/python/paddle/v2/framework/default_scope_funcs.py index 4e772326c94b7ee44906c71f13e9420e078a1917..1b5580c8b30f69016f187b1d8710a57b5f7cfa9f 100644 --- a/python/paddle/v2/framework/default_scope_funcs.py +++ b/python/paddle/v2/framework/default_scope_funcs.py @@ -5,7 +5,7 @@ Default scope function. thread-local stack of Scope. Top of that stack is current scope, the bottom of that stack is all scopes' parent. -Invoking `create_var/get_var` can `create/get` variable in current scope. +Invoking `new_var/find_var` can `new/find` variable in current scope. Invoking `enter_local_scope/leave_local_scope` can create or destroy local scope. @@ -19,8 +19,8 @@ import threading __tl_scope__ = threading.local() __all__ = [ - 'get_cur_scope', 'enter_local_scope', 'leave_local_scope', 'create_var', - 'get_var', 'scoped_function' + 'get_cur_scope', 'enter_local_scope', 'leave_local_scope', 'new_var', + 'find_var', 'scoped_function' ] @@ -33,7 +33,7 @@ def get_cur_scope(): if cur_scope_stack is None: __tl_scope__.cur_scope = list() if len(__tl_scope__.cur_scope) == 0: - __tl_scope__.cur_scope.append(paddle.v2.framework.core.Scope(None)) + __tl_scope__.cur_scope.append(paddle.v2.framework.core.Scope()) return __tl_scope__.cur_scope[-1] @@ -42,7 +42,7 @@ def enter_local_scope(): Enter a new local scope """ cur_scope = get_cur_scope() - new_scope = paddle.v2.framework.core.Scope(cur_scope) + new_scope = cur_scope.new_scope() __tl_scope__.cur_scope.append(new_scope) @@ -51,20 +51,21 @@ def leave_local_scope(): Leave local scope """ __tl_scope__.cur_scope.pop() + get_cur_scope().drop_kids() -def create_var(name): +def new_var(name): """ create 
variable in current scope. """ - return get_cur_scope().create_var(name) + return get_cur_scope().new_var(name) -def get_var(name): +def find_var(name): """ get variable in current scope. """ - return get_cur_scope().get_var(name) + return get_cur_scope().find_var(name) def scoped_function(func): diff --git a/python/paddle/v2/framework/network.py b/python/paddle/v2/framework/network.py index c85e87413ef45f40755709e134a277b8d8d1e233..cfeb0e3dec0fd2c6ad4d2d2501f97932495fdd41 100644 --- a/python/paddle/v2/framework/network.py +++ b/python/paddle/v2/framework/network.py @@ -1,6 +1,6 @@ import paddle.v2.framework.core as core from paddle.v2.framework.create_op_creation_methods import op_creations -from default_scope_funcs import create_var, get_var, get_cur_scope +from default_scope_funcs import new_var, find_var, get_cur_scope __all__ = ['Network'] # Only expose Network @@ -29,12 +29,15 @@ class NetworkFunctor(object): if ipt in kwargs: var = kwargs[ipt] if isinstance(var, basestring): - var = create_var(var) + tmp = new_var(var) + self.net.var_names[tmp] = var + var = tmp + if not isinstance(var, core.Variable): raise TypeError( "Input of op creation must be string or variable") - kwargs[ipt] = get_cur_scope().get_var_name(var) + kwargs[ipt] = self.net.var_names[var] notemp_outputs = self.func.all_not_temp_output_args @@ -49,17 +52,20 @@ class NetworkFunctor(object): if opt in kwargs: var = kwargs[opt] if isinstance(var, basestring): - var = create_var(var) + tmp = new_var(var) + self.net.var_names[tmp] = var + var = tmp + if not isinstance(var, core.Variable): raise TypeError( "Output of op creation must be string or variable") - kwargs[opt] = get_cur_scope().get_var_name(var) + kwargs[opt] = self.net.var_names[var] op = self.func(**kwargs) self.net.net.add_op(op) - lst = [get_var(kwargs[opt]) for opt in notemp_outputs] + lst = [find_var(kwargs[opt]) for opt in notemp_outputs] if len(lst) == 1: return lst[0] elif len(lst) == 0: @@ -89,6 +95,7 @@ class 
Network(object): self.net = core.Net.create() funcs = (func_name for func_name in dir(op_creations) if not func_name.startswith("__")) + self.var_names = dict() # TODO(yuyang18): This code can work, but do not generate a good # docstring, try to give a better way generate function in runtime diff --git a/python/paddle/v2/framework/tests/CMakeLists.txt b/python/paddle/v2/framework/tests/CMakeLists.txt index cdaaa60674937c68c38656a5046bcb29f44d6c8b..540636a0e8100fbf97231bd548dbc1176b07daca 100644 --- a/python/paddle/v2/framework/tests/CMakeLists.txt +++ b/python/paddle/v2/framework/tests/CMakeLists.txt @@ -10,6 +10,7 @@ add_python_test(test_framework test_sgd_op.py test_cross_entropy_op.py test_mul_op.py + test_mean_op.py test_sigmoid_op.py test_softmax_op.py test_rowwise_add_op.py diff --git a/python/paddle/v2/framework/tests/op_test_util.py b/python/paddle/v2/framework/tests/op_test_util.py index 7b62313f8aca5e9f515d1a9e6df3bb6f51b974fb..99085c367221150c8386a24e8d90d58fd63894c4 100644 --- a/python/paddle/v2/framework/tests/op_test_util.py +++ b/python/paddle/v2/framework/tests/op_test_util.py @@ -24,13 +24,13 @@ class OpTestMeta(type): func = getattr(creation.op_creations, self.type, None) self.assertIsNotNone(func) - scope = core.Scope(None) + scope = core.Scope() kwargs = dict() for in_name in func.all_input_args: if hasattr(self, in_name): kwargs[in_name] = in_name - var = scope.create_var(in_name).get_tensor() + var = scope.new_var(in_name).get_tensor() arr = getattr(self, in_name) var.set_dims(arr.shape) var.set(arr) @@ -40,7 +40,7 @@ class OpTestMeta(type): for out_name in func.all_output_args: if hasattr(self, out_name): kwargs[out_name] = out_name - scope.create_var(out_name).get_tensor() + scope.new_var(out_name).get_tensor() for attr_name in func.all_attr_args: if hasattr(self, attr_name): @@ -54,7 +54,7 @@ class OpTestMeta(type): op.run(scope, ctx) for out_name in func.all_output_args: - actual = numpy.array(scope.get_var(out_name).get_tensor()) + actual 
= numpy.array(scope.find_var(out_name).get_tensor()) expect = getattr(self, out_name) # TODO(qijun) The default decimal is 7, but numpy.dot and eigen.mul # has some diff, and could not pass unittest. So I set decimal 3 here. diff --git a/python/paddle/v2/framework/tests/test_default_scope_funcs.py b/python/paddle/v2/framework/tests/test_default_scope_funcs.py index 81033deb1546c81e2566ec5474f45dc56781644a..495863c4562b5a2d6755fb02e21a6b0c845fd7b6 100644 --- a/python/paddle/v2/framework/tests/test_default_scope_funcs.py +++ b/python/paddle/v2/framework/tests/test_default_scope_funcs.py @@ -7,19 +7,19 @@ class TestDefaultScopeFuncs(unittest.TestCase): self.assertIsNotNone(get_cur_scope()) def test_none_variable(self): - self.assertIsNone(get_var("test")) + self.assertIsNone(find_var("test")) def test_create_var_get_var(self): - var_a = create_var("var_a") + var_a = new_var("var_a") self.assertIsNotNone(var_a) - self.assertIsNotNone(get_cur_scope().get_var('var_a')) + self.assertIsNotNone(get_cur_scope().find_var('var_a')) enter_local_scope() - self.assertIsNotNone(get_cur_scope().get_var('var_a')) + self.assertIsNotNone(get_cur_scope().find_var('var_a')) leave_local_scope() def test_var_get_int(self): def __new_scope__(): - i = create_var("var_i") + i = new_var("var_i") self.assertFalse(i.is_int()) i.set_int(10) self.assertTrue(i.is_int()) diff --git a/python/paddle/v2/framework/tests/test_fc_op.py b/python/paddle/v2/framework/tests/test_fc_op.py index 59e7e61249e2a7d49a17e5d87209f03b8f35f730..43931aac406cd93beede008066aa1c0c00eba6ea 100644 --- a/python/paddle/v2/framework/tests/test_fc_op.py +++ b/python/paddle/v2/framework/tests/test_fc_op.py @@ -6,13 +6,13 @@ import paddle.v2.framework.create_op_creation_methods as creation class TestFc(unittest.TestCase): def test_fc(self): - scope = core.Scope(None) - x = scope.create_var("X") + scope = core.Scope() + x = scope.new_var("X") x_tensor = x.get_tensor() x_tensor.set_dims([1000, 784]) x_tensor.alloc_float() - w = 
scope.create_var("W") + w = scope.new_var("W") w_tensor = w.get_tensor() w_tensor.set_dims([784, 100]) w_tensor.alloc_float() @@ -25,10 +25,10 @@ class TestFc(unittest.TestCase): op = creation.op_creations.fc(X="X", Y="Y", W="W") for out in op.outputs(): - if scope.get_var(out) is None: - scope.create_var(out).get_tensor() + if scope.find_var(out) is None: + scope.new_var(out).get_tensor() - tensor = scope.get_var("Y").get_tensor() + tensor = scope.find_var("Y").get_tensor() op.infer_shape(scope) self.assertEqual([1000, 100], tensor.shape()) diff --git a/python/paddle/v2/framework/tests/test_mean_op.py b/python/paddle/v2/framework/tests/test_mean_op.py new file mode 100644 index 0000000000000000000000000000000000000000..78fff1eeff998109a51ea662f963a102eff49d3a --- /dev/null +++ b/python/paddle/v2/framework/tests/test_mean_op.py @@ -0,0 +1,16 @@ +import unittest +from op_test_util import OpTestMeta +import numpy as np + + +class TestMeanOp(unittest.TestCase): + __metaclass__ = OpTestMeta + + def setUp(self): + self.type = "mean" + self.X = np.random.random((32, 784)).astype("float32") + self.Out = np.mean(self.X) + + +if __name__ == '__main__': + unittest.main() diff --git a/python/paddle/v2/framework/tests/test_recurrent_op.py b/python/paddle/v2/framework/tests/test_recurrent_op.py new file mode 100644 index 0000000000000000000000000000000000000000..0457e3f16a709140180ce433c1d56d146f0b6974 --- /dev/null +++ b/python/paddle/v2/framework/tests/test_recurrent_op.py @@ -0,0 +1,92 @@ +import paddle.v2.framework.core as core +import unittest +import numpy as np +import paddle.v2.framework.create_op_creation_methods as creation + +ops = creation.op_creations + + +def create_tensor(scope, name, shape): + tensor = scope.create_var(name).get_tensor() + tensor.set_dims(shape) + tensor.alloc_float() + tensor.set(np.random.random(shape)) + return tensor + + +class TestRNN(unittest.TestCase): + ''' + Test RNNOp + + equation: + h_t = \sigma (W x_t + U h_{t-1}) + weights: + - W + 
- U + vars: + - x + memories: + - h + outputs: + - h + ''' + + def init(self): + input_dim = 30 + batch_size = 50 + weight_dim = 15 + + self.scope = core.Scope(None) + + # create vars + create_tensor(self.scope, "x", [batch_size, input_dim]) + create_tensor(self.scope, "W", [input_dim, weight_dim]) + create_tensor(self.scope, "U", [weight_dim, weight_dim]) + create_tensor(self.scope, "h_boot", [batch_size, weight_dim]) + + x_alias = "x@alias" + y_alias = "y@alias" + memory = "h@alias" + prememory = "h@pre" + output = "rnn_out" + output_alias = "rnn_out@alias" + + # create step net + stepnet_var = self.scope.create_var("stepnet") + stepnet = stepnet_var.get_net() + # stepnet = core.Net.create() + x_fc_op = ops.fc(X=x_alias, W="W", Y="Wx") + h_fc_op = ops.fc(X=prememory, W="U", Y="Uh") + sum_op = ops.add_two(X="Wx", Y="Uh", Out="sum") + sig_op = ops.sigmoid(X="sum", Y=memory) + stepnet.add_op(x_fc_op) + stepnet.add_op(h_fc_op) + stepnet.add_op(sum_op) + stepnet.add_op(sig_op) + stepnet.complete_add_op(True) + + # create RNNOp + rnnop = ops.recurrent_op( + # inputs + inlinks=["x"], + boot_memories=["h_boot"], + step_net="stepnet", + # outputs + outlinks=[output], + step_scopes="step_scopes", + # attributes + inlink_alias=["x@alias"], + outlink_alias=[output_alias], + pre_memories=[prememory], + memories=[memory]) + + ctx = core.DeviceContext.cpu_context() + rnnop.infer_shape(self.scope) + rnnop.run(self.scope, ctx) + + def test_recurrent(self): + self.init() + + +if __name__ == '__main__': + unittest.main() diff --git a/python/paddle/v2/framework/tests/test_scope.py b/python/paddle/v2/framework/tests/test_scope.py index f0ee45cfc75e486c693a00d92a97ac0970195581..1ce9454067f91f39f01d9eb4c912857464a3c1cb 100644 --- a/python/paddle/v2/framework/tests/test_scope.py +++ b/python/paddle/v2/framework/tests/test_scope.py @@ -5,29 +5,29 @@ import unittest class TestScope(unittest.TestCase): def test_create_destroy(self): paddle_c = paddle.v2.framework.core - scope = 
paddle_c.Scope(None) + scope = paddle_c.Scope() self.assertIsNotNone(scope) - scope_with_parent = paddle_c.Scope(scope) + scope_with_parent = scope.new_scope() self.assertIsNotNone(scope_with_parent) def test_none_variable(self): paddle_c = paddle.v2.framework.core - scope = paddle_c.Scope(None) - self.assertIsNone(scope.get_var("test")) + scope = paddle_c.Scope() + self.assertIsNone(scope.find_var("test")) def test_create_var_get_var(self): paddle_c = paddle.v2.framework.core - scope = paddle_c.Scope(None) - var_a = scope.create_var("var_a") + scope = paddle_c.Scope() + var_a = scope.new_var("var_a") self.assertIsNotNone(var_a) - self.assertIsNotNone(scope.get_var('var_a')) - scope2 = paddle_c.Scope(scope) - self.assertIsNotNone(scope2.get_var('var_a')) + self.assertIsNotNone(scope.find_var('var_a')) + scope2 = scope.new_scope() + self.assertIsNotNone(scope2.find_var('var_a')) def test_var_get_int(self): paddle_c = paddle.v2.framework.core - scope = paddle_c.Scope(None) - var = scope.create_var("test_int") + scope = paddle_c.Scope() + var = scope.new_var("test_int") var.set_int(10) self.assertTrue(var.is_int()) self.assertEqual(10, var.get_int()) diff --git a/python/paddle/v2/framework/tests/test_tensor.py b/python/paddle/v2/framework/tests/test_tensor.py index b72aff3b9cd16595c7e81856642196b2bb61a790..6d59863cea29832f648139e07a134050e22bfa21 100644 --- a/python/paddle/v2/framework/tests/test_tensor.py +++ b/python/paddle/v2/framework/tests/test_tensor.py @@ -5,8 +5,8 @@ import numpy class TestScope(unittest.TestCase): def test_int_tensor(self): - scope = core.Scope(None) - var = scope.create_var("test_tensor") + scope = core.Scope() + var = scope.new_var("test_tensor") tensor = var.get_tensor() tensor.set_dims([1000, 784]) @@ -23,8 +23,8 @@ class TestScope(unittest.TestCase): self.assertEqual(2.0, tensor_array_2[19, 11]) def test_float_tensor(self): - scope = core.Scope(None) - var = scope.create_var("test_tensor") + scope = core.Scope() + var = 
scope.new_var("test_tensor") tensor = var.get_tensor() tensor.set_dims([1000, 784]) diff --git a/python/paddle/v2/master/client.py b/python/paddle/v2/master/client.py index b658a81630733fea3976b812afe819d76de4cb25..fc718f031e2267e737adbc340226e145bf614bf2 100644 --- a/python/paddle/v2/master/client.py +++ b/python/paddle/v2/master/client.py @@ -76,3 +76,6 @@ class client(object): # Memory created from C should be freed. get_c_lib().mem_free(ret.contents) return record, 0 + + def paddle_start_get_records(self, pass_id): + get_c_lib().paddle_start_get_records(self.c, pass_id) diff --git a/python/paddle/v2/reader/creator.py b/python/paddle/v2/reader/creator.py index 55a0fcdf56af7a8c9bee3255ea6f1d1ae1b34893..d0f18e4b6611fa56654e7f2a0144758339cb9e19 100644 --- a/python/paddle/v2/reader/creator.py +++ b/python/paddle/v2/reader/creator.py @@ -16,7 +16,7 @@ Creator package contains some simple reader creator, which could be used in user program. """ -__all__ = ['np_array', 'text_file', "recordio"] +__all__ = ['np_array', 'text_file', "cloud_reader"] def np_array(x): @@ -81,35 +81,41 @@ def recordio_local(paths, buf_size=100): return dec.buffered(reader, buf_size) -def recordio(paths, buf_size=100): +pass_num = 0 + + +def cloud_reader(paths, etcd_endpoints, timeout_sec=5, buf_size=64): """ - Creates a data reader that outputs record one one by one - from given local or cloud recordio path. + Create a data reader that yield a record one by one from + the paths: :path: path of recordio files. + :etcd_endpoints: the endpoints for etcd cluster :returns: data reader of recordio files. + + .. 
code-block:: python + from paddle.v2.reader.creator import cloud_reader + etcd_endpoints = "http://127.0.0.1:2379" + trainer.train.( + reader=cloud_reader(["/work/dataset/uci_housing/uci_housing*"], etcd_endpoints), + ) """ import os - import paddle.v2.master.client as cloud - - if "KUBERNETES_SERVICE_HOST" not in os.environ.keys(): - return recordio_local(paths) - - host_name = "MASTER_SERVICE_HOST" - if host_name not in os.environ.keys(): - raise Exception('not find ' + host_name + ' in environment variable.') - - addr = os.environ(host) + import cPickle as pickle + import paddle.v2.master as master + c = master.client(etcd_endpoints, timeout_sec, buf_size) + c.set_dataset(paths) def reader(): - c = cloud(addr, buf_size) - c.set_dataset(paths) + global pass_num + c.paddle_start_get_records(pass_num) + pass_num += 1 while True: - r, err = client.next_record() - if err < 0: + r, e = c.next_record() + if not r: + if e != -2: + print "get record error: ", e break - yield r - - c.release() + yield pickle.loads(r) return reader diff --git a/python/paddle/v2/reader/tests/creator_test.py b/python/paddle/v2/reader/tests/creator_test.py index b42d273ecfe6c4bc5706ec52617960b83496d70d..359f3eeefbe8efeb343cc875c707c9251a7087fb 100644 --- a/python/paddle/v2/reader/tests/creator_test.py +++ b/python/paddle/v2/reader/tests/creator_test.py @@ -34,14 +34,5 @@ class TestTextFile(unittest.TestCase): self.assertEqual(e, str(idx * 2) + " " + str(idx * 2 + 1)) -class TestRecordIO(unittest.TestCase): - def test_recordio(self): - path = os.path.join( - os.path.dirname(__file__), "test_recordio_creator.dat") - reader = paddle.v2.reader.creator.recordio([path]) - for idx, r in enumerate(reader()): - self.assertSequenceEqual(r, str(idx)) - - if __name__ == '__main__': unittest.main()