diff --git a/go/.gitignore b/go/.gitignore deleted file mode 100644 index 398d70ca375ffceccdbfc82a4851a6830ca31264..0000000000000000000000000000000000000000 --- a/go/.gitignore +++ /dev/null @@ -1,3 +0,0 @@ -vendor/ -.glide/ -proto/*.go diff --git a/go/CMakeLists.txt b/go/CMakeLists.txt deleted file mode 100644 index f3a9296c2c66cd96419cae37c3ac2c93c2b033f5..0000000000000000000000000000000000000000 --- a/go/CMakeLists.txt +++ /dev/null @@ -1,23 +0,0 @@ -# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# - -add_subdirectory(pserver/client/c) -add_subdirectory(cmd/pserver) -add_subdirectory(cmd/master) -add_subdirectory(master/c) -add_subdirectory(master) -add_subdirectory(pserver) -add_subdirectory(pserver/client) -add_subdirectory(utils/networkhelper) diff --git a/go/cmd/master/CMakeLists.txt b/go/cmd/master/CMakeLists.txt deleted file mode 100644 index fc99d8d3bd1ec1941b7a068cf8417f0663dea8c0..0000000000000000000000000000000000000000 --- a/go/cmd/master/CMakeLists.txt +++ /dev/null @@ -1,15 +0,0 @@ -# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -go_binary(master SRC master.go) diff --git a/go/cmd/master/master.go b/go/cmd/master/master.go deleted file mode 100644 index 537df59c860a3cb77ecd8287cd352397d7f7a4e4..0000000000000000000000000000000000000000 --- a/go/cmd/master/master.go +++ /dev/null @@ -1,120 +0,0 @@ -// Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. - -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at - -// http://www.apache.org/licenses/LICENSE-2.0 - -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package main - -import ( - "fmt" - "net" - "net/http" - "net/rpc" - "os" - "os/signal" - "strconv" - "strings" - "time" - - log "github.com/inconshreveable/log15" - "github.com/namsral/flag" - - "github.com/PaddlePaddle/Paddle/go/master" - "github.com/PaddlePaddle/Paddle/go/utils/networkhelper" -) - -func main() { - port := flag.Int("port", 8080, "port of the master server.") - ttlSec := flag.Int("ttl", 60, "etcd lease TTL in seconds.") - endpoints := flag.String("endpoints", "http://127.0.0.1:2379", "comma separated etcd endpoints. 
If empty, fault tolerance will not be enabled.") - taskTimeoutDur := flag.Duration("task-timout-dur", 20*time.Minute, "task timout duration.") - taskTimeoutMax := flag.Int("task-timeout-max", 3, "max timtout count for each task before it being declared failed task.") - chunkPerTask := flag.Int("chunk-per-task", 10, "chunk per task.") - logLevel := flag.String("log-level", "info", - "log level, possible values: debug, info, warn, error, crit") - flag.Parse() - - lvl, err := log.LvlFromString(*logLevel) - if err != nil { - panic(err) - } - - log.Root().SetHandler( - log.LvlFilterHandler(lvl, log.CallerStackHandler("%+v", log.StderrHandler)), - ) - - if *endpoints == "" { - log.Warn("-endpoints not set, fault tolerance not be enabled.") - } - - var store master.Store - if *endpoints != "" { - eps := strings.Split(*endpoints, ",") - ip, err := networkhelper.GetExternalIP() - if err != nil { - log.Crit("get external ip error", log.Ctx{"error": err}) - panic(err) - } - - addr := fmt.Sprintf("%s:%d", ip, *port) - store, err = master.NewEtcdClient(eps, addr, master.DefaultLockPath, master.DefaultAddrPath, master.DefaultStatePath, *ttlSec) - if err != nil { - log.Crit("error creating etcd client.", log.Ctx{"error": err}) - panic(err) - } - } else { - store = &master.InMemStore{} - } - - shutdown := func() { - log.Info("shutting down gracefully") - err := store.Shutdown() - if err != nil { - log.Error("shutdown error", log.Ctx{"error": err}) - } - } - - // Guaranteed to run even panic happens. 
- defer shutdown() - - c := make(chan os.Signal, 1) - signal.Notify(c, os.Interrupt) - - s, err := master.NewService(store, *chunkPerTask, *taskTimeoutDur, *taskTimeoutMax) - if err != nil { - log.Crit("error creating new service.", log.Ctx{"error": err}) - panic(err) - } - - err = rpc.Register(s) - if err != nil { - log.Crit("error registering to etcd.", log.Ctx{"error": err}) - panic(err) - } - - rpc.HandleHTTP() - l, err := net.Listen("tcp", ":"+strconv.Itoa(*port)) - if err != nil { - log.Crit("error listing to port", log.Ctx{"error": err, "port": *port}) - panic(err) - } - - go func() { - err = http.Serve(l, nil) - if err != nil { - log.Crit("error serving HTTP", log.Ctx{"error": err}) - panic(err) - } - }() - - <-c -} diff --git a/go/cmd/pserver/.gitignore b/go/cmd/pserver/.gitignore deleted file mode 100644 index fffd9adc4fde9681ad2a58fcf594d20bdd86ab45..0000000000000000000000000000000000000000 --- a/go/cmd/pserver/.gitignore +++ /dev/null @@ -1 +0,0 @@ -pserver diff --git a/go/cmd/pserver/CMakeLists.txt b/go/cmd/pserver/CMakeLists.txt deleted file mode 100644 index 20d033c938648d1b1e5c5ed1b8a738a543c325cf..0000000000000000000000000000000000000000 --- a/go/cmd/pserver/CMakeLists.txt +++ /dev/null @@ -1,15 +0,0 @@ -# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -go_binary(pserver SRCS pserver.go DEPS paddle_go_optimizer) diff --git a/go/cmd/pserver/pserver.go b/go/cmd/pserver/pserver.go deleted file mode 100644 index 271274cafc5c94a2c89ac211dba7a3a2bd232026..0000000000000000000000000000000000000000 --- a/go/cmd/pserver/pserver.go +++ /dev/null @@ -1,108 +0,0 @@ -// Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. - -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at - -// http://www.apache.org/licenses/LICENSE-2.0 - -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package main - -import ( - "net" - "net/http" - "net/rpc" - "os" - "os/signal" - "strconv" - "time" - - "github.com/namsral/flag" - "github.com/topicai/candy" - - "github.com/PaddlePaddle/Paddle/go/pserver" - log "github.com/inconshreveable/log15" -) - -func main() { - port := flag.Int("port", 8001, "port of the pserver") - index := flag.Int("index", -1, "index of the pserver, set to -1 if use etcd for auto pserver index registry") - etcdEndpoint := flag.String("etcd-endpoint", "http://127.0.0.1:2379", - "comma separated endpoint string for pserver to connect to etcd") - dialTimeout := flag.Duration("dial-timeout", 5*time.Second, "dial timeout") - etcdTTL := flag.Int("etcd-ttl", 5, "etcd time to live in seconds") - numPservers := flag.Int("num-pservers", 1, "total pserver count in a training job") - checkpointPath := flag.String("checkpoint-path", "/checkpoints/", "save checkpoint path") - checkpointInterval := flag.Duration("checkpoint-interval", 600*time.Second, "save checkpoint per interval seconds") - logLevel := flag.String("log-level", 
"info", - "log level, possible values: debug, info, warn, error, crit") - flag.Parse() - - lvl, err := log.LvlFromString(*logLevel) - if err != nil { - panic(err) - } - - log.Root().SetHandler( - log.LvlFilterHandler(lvl, log.CallerStackHandler("%+v", log.StderrHandler)), - ) - - var idx int - - var cp pserver.Checkpoint - var e *pserver.EtcdClient - if *index >= 0 { - idx = *index - } else { - e = pserver.NewEtcdClient(*etcdEndpoint, *numPservers, *dialTimeout, *etcdTTL) - idx, err = e.Register(*port) - candy.Must(err) - - cp, err = pserver.LoadCheckpoint(e, idx) - if err != nil { - if err == pserver.ErrCheckpointNotFound { - log.Info("load checkpoint error", "error", err) - } else { - panic(err) - } - } - } - - shutdown := func() { - log.Info("shutting down gracefully") - sErr := e.Shutdown() - if sErr != nil { - log.Error("error shutting down", log.Ctx{"error": sErr}) - } - } - - // Guaranteed to run even panic happens. - defer shutdown() - - c := make(chan os.Signal, 1) - signal.Notify(c, os.Interrupt) - - s, err := pserver.NewService(idx, *checkpointInterval, *checkpointPath, e, cp) - candy.Must(err) - - err = rpc.Register(s) - candy.Must(err) - - rpc.HandleHTTP() - l, err := net.Listen("tcp", ":"+strconv.Itoa(*port)) - candy.Must(err) - - go func() { - log.Info("serving pserver", log.Ctx{"port": *port}) - err = http.Serve(l, nil) - candy.Must(err) - }() - - <-c -} diff --git a/go/connection/conn.go b/go/connection/conn.go deleted file mode 100644 index b8353e8e18ed7b40bab057d6226637df1e6e569a..0000000000000000000000000000000000000000 --- a/go/connection/conn.go +++ /dev/null @@ -1,120 +0,0 @@ -// Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. - -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. 
-// You may obtain a copy of the License at - -// http://www.apache.org/licenses/LICENSE-2.0 - -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package connection - -import ( - "errors" - "net/rpc" - "sync" - - log "github.com/sirupsen/logrus" -) - -// TODO(helin): add TCP re-connect logic - -// Conn is a connection to a parameter server -type Conn struct { - mu sync.Mutex - client *rpc.Client - waitConn chan struct{} -} - -// New creates a new connection. -func New() *Conn { - c := &Conn{} - return c -} - -// Close closes the connection. -func (c *Conn) Close() error { - c.mu.Lock() - defer c.mu.Unlock() - - if c.client == nil { - return nil - } - - return c.client.Close() -} - -// Connect connects the connection to a address. -func (c *Conn) Connect(addr string) error { - c.mu.Lock() - if c.client != nil { - err := c.client.Close() - if err != nil { - c.mu.Unlock() - return err - } - - c.client = nil - } - c.mu.Unlock() - - client, err := rpc.DialHTTP("tcp", addr) - if err != nil { - return err - } - - c.mu.Lock() - defer c.mu.Unlock() - - if c.client == nil { - c.client = client - if c.waitConn != nil { - close(c.waitConn) - c.waitConn = nil - } - } else { - err := client.Close() - if err != nil { - log.Errorln(err) - } - - return errors.New("client already set from a concurrent goroutine") - } - - return nil -} - -// TODO(helin): refactor Call to be able to perform given retry -// policy. - -// Call make a RPC call. -// -// Call will be blocked until the connection to remote RPC service -// being established. 
-func (c *Conn) Call(serviceMethod string, args interface{}, reply interface{}) error { - c.mu.Lock() - client := c.client - var waitCh chan struct{} - if client == nil { - if c.waitConn != nil { - waitCh = c.waitConn - } else { - waitCh = make(chan struct{}) - c.waitConn = waitCh - } - } - c.mu.Unlock() - - if waitCh != nil { - // wait until new connection being established - <-waitCh - return c.Call(serviceMethod, args, reply) - } - - return client.Call(serviceMethod, args, reply) -} diff --git a/go/glide.lock b/go/glide.lock deleted file mode 100644 index d15fc934dbe511389cc92ce95cededa41ba32b4d..0000000000000000000000000000000000000000 --- a/go/glide.lock +++ /dev/null @@ -1,233 +0,0 @@ -hash: 107c058cf5c9163a75d40eef2273a793c36112683c25d72aa8288827fdde3a19 -updated: 2017-10-30T03:46:19.137696069Z -imports: -- name: github.com/alecthomas/gometalinter - version: bae2f1293d092fd8167939d5108d1b025eaef9de -- name: github.com/beorn7/perks - version: 4c0e84591b9aa9e6dcfdf3e020114cd81f89d5f9 - subpackages: - - quantile -- name: github.com/boltdb/bolt - version: 583e8937c61f1af6513608ccc75c97b6abdf4ff9 -- name: github.com/cockroachdb/cmux - version: 112f0506e7743d64a6eb8fedbcff13d9979bbf92 -- name: github.com/coreos/etcd - version: f1d7dd87da3e8feab4aaf675b8e29c6a5ed5f58b - subpackages: - - alarm - - auth - - auth/authpb - - client - - clientv3 - - clientv3/concurrency - - compactor - - discovery - - embed - - error - - etcdserver - - etcdserver/api - - etcdserver/api/etcdhttp - - etcdserver/api/v2http - - etcdserver/api/v2http/httptypes - - etcdserver/api/v3client - - etcdserver/api/v3election - - etcdserver/api/v3election/v3electionpb - - etcdserver/api/v3election/v3electionpb/gw - - etcdserver/api/v3lock - - etcdserver/api/v3lock/v3lockpb - - etcdserver/api/v3lock/v3lockpb/gw - - etcdserver/api/v3rpc - - etcdserver/api/v3rpc/rpctypes - - etcdserver/auth - - etcdserver/etcdserverpb - - etcdserver/etcdserverpb/gw - - etcdserver/membership - - etcdserver/stats - - 
lease - - lease/leasehttp - - lease/leasepb - - mvcc - - mvcc/backend - - mvcc/mvccpb - - pkg/adt - - pkg/contention - - pkg/cors - - pkg/cpuutil - - pkg/crc - - pkg/debugutil - - pkg/fileutil - - pkg/httputil - - pkg/idutil - - pkg/ioutil - - pkg/logutil - - pkg/monotime - - pkg/netutil - - pkg/pathutil - - pkg/pbutil - - pkg/runtime - - pkg/schedule - - pkg/srv - - pkg/tlsutil - - pkg/transport - - pkg/types - - pkg/wait - - proxy/grpcproxy/adapter - - raft - - raft/raftpb - - rafthttp - - snap - - snap/snappb - - store - - version - - wal - - wal/walpb -- name: github.com/coreos/go-semver - version: 8ab6407b697782a06568d4b7f1db25550ec2e4c6 - subpackages: - - semver -- name: github.com/coreos/go-systemd - version: 48702e0da86bd25e76cfef347e2adeb434a0d0a6 - subpackages: - - daemon - - journal - - util -- name: github.com/coreos/pkg - version: 3ac0863d7acf3bc44daf49afef8919af12f704ef - subpackages: - - capnslog -- name: github.com/dgrijalva/jwt-go - version: d2709f9f1f31ebcda9651b03077758c1f3a0018c -- name: github.com/ghodss/yaml - version: 0ca9ea5df5451ffdf184b4428c902747c2c11cd7 -- name: github.com/go-stack/stack - version: 817915b46b97fd7bb80e8ab6b69f01a53ac3eebf -- name: github.com/gogo/protobuf - version: 909568be09de550ed094403c2bf8a261b5bb730a - subpackages: - - proto -- name: github.com/golang/protobuf - version: 4bd1920723d7b7c925de087aa32e2187708897f7 - subpackages: - - jsonpb - - proto -- name: github.com/golang/snappy - version: 553a641470496b2327abcac10b36396bd98e45c9 -- name: github.com/google/btree - version: 925471ac9e2131377a91e1595defec898166fe49 -- name: github.com/grpc-ecosystem/go-grpc-prometheus - version: 6b7015e65d366bf3f19b2b2a000a831940f0f7e0 -- name: github.com/grpc-ecosystem/grpc-gateway - version: 18d159699f2e83fc5bb9ef2f79465ca3f3122676 - subpackages: - - runtime - - runtime/internal - - utilities -- name: github.com/inconshreveable/log15 - version: 0decfc6c20d9ca0ad143b0e89dcaa20f810b4fb3 -- name: github.com/jonboulle/clockwork - 
version: 2eee05ed794112d45db504eb05aa693efd2b8b09 -- name: github.com/mattn/go-colorable - version: 5411d3eea5978e6cdc258b30de592b60df6aba96 -- name: github.com/mattn/go-isatty - version: 57fdcb988a5c543893cc61bce354a6e24ab70022 -- name: github.com/matttproud/golang_protobuf_extensions - version: c12348ce28de40eed0136aa2b644d0ee0650e56c - subpackages: - - pbutil -- name: github.com/namsral/flag - version: 71ceffbeb0ba60fccc853971bb3ed4d7d90bfd04 -- name: github.com/PaddlePaddle/recordio - version: 0432dee9fd4b24fb6840fb20a8c055b0c933fb81 -- name: github.com/prometheus/client_golang - version: c5b7fccd204277076155f10851dad72b76a49317 - subpackages: - - prometheus -- name: github.com/prometheus/client_model - version: 6f3806018612930941127f2a7c6c453ba2c527d2 - subpackages: - - go -- name: github.com/prometheus/common - version: 49fee292b27bfff7f354ee0f64e1bc4850462edf - subpackages: - - expfmt - - internal/bitbucket.org/ww/goautoneg - - model -- name: github.com/prometheus/procfs - version: a1dba9ce8baed984a2495b658c82687f8157b98f - subpackages: - - xfs -- name: github.com/satori/go.uuid - version: 879c5887cd475cd7864858769793b2ceb0d44feb -- name: github.com/sirupsen/logrus - version: f006c2ac4710855cf0f916dd6b77acf6b048dc6e -- name: github.com/topicai/candy - version: 1b9030d056fa9f8c4b1f9c91b52fe4b8ab4cd8cc -- name: github.com/ugorji/go - version: ded73eae5db7e7a0ef6f55aace87a2873c5d2b74 - subpackages: - - codec -- name: github.com/xiang90/probing - version: 07dd2e8dfe18522e9c447ba95f2fe95262f63bb2 -- name: golang.org/x/crypto - version: 9419663f5a44be8b34ca85f08abc5fe1be11f8a3 - repo: https://github.com/golang/crypto.git - vcs: git - subpackages: - - bcrypt - - blowfish - - ssh/terminal -- name: golang.org/x/net - version: c8c74377599bd978aee1cf3b9b63a8634051cec2 - subpackages: - - context - - http2 - - http2/hpack - - idna - - internal/timeseries - - lex/httplex - - trace -- name: golang.org/x/sys - version: e48874b42435b4347fc52bdee0424a52abc974d7 - repo: 
https://github.com/golang/sys.git - vcs: git - subpackages: - - unix - - windows -- name: golang.org/x/text - version: 836efe42bb4aa16aaa17b9c155d8813d336ed720 - repo: https://github.com/golang/text.git - vcs: git - subpackages: - - secure/bidirule - - transform - - unicode/bidi - - unicode/norm -- name: google.golang.org/grpc - version: 8050b9cbc271307e5a716a9d782803d09b0d6f2d - subpackages: - - codes - - credentials - - grpclog - - internal - - keepalive - - metadata - - naming - - peer - - stats - - tap - - transport -- name: gopkg.in/yaml.v2 - version: cd8b52f8269e0feb286dfeef29f8fe4d5b397e0b -testImports: -- name: github.com/davecgh/go-spew - version: 04cdfd42973bb9c8589fd6a731800cf222fde1a9 - subpackages: - - spew -- name: github.com/pmezard/go-difflib - version: d8ed2627bdf02c080bf22230dbb337003b7aba2d - subpackages: - - difflib -- name: github.com/stretchr/testify - version: 05e8a0eda380579888eb53c394909df027f06991 - subpackages: - - assert diff --git a/go/glide.yaml b/go/glide.yaml deleted file mode 100644 index c5d66694acd0f45de5002391a7953b7491eaf2bc..0000000000000000000000000000000000000000 --- a/go/glide.yaml +++ /dev/null @@ -1,33 +0,0 @@ -package: github.com/PaddlePaddle/Paddle/go -import: -- package: github.com/PaddlePaddle/recordio -- package: github.com/coreos/etcd - version: ^3.2.1 - subpackages: - - clientv3 - - clientv3/concurrency - - embed - - etcdserver -- package: github.com/namsral/flag - version: ^1.7.4-pre -- package: github.com/sirupsen/logrus - version: ^1.0.0 -- package: github.com/topicai/candy -- package: golang.org/x/crypto - repo: https://github.com/golang/crypto.git - vcs: git -- package: golang.org/x/sys - repo: https://github.com/golang/sys.git - vcs: git -- package: golang.org/x/text - repo: https://github.com/golang/text.git - vcs: git -- package: github.com/satori/go.uuid - version: v1.1.0 -- package: github.com/alecthomas/gometalinter - version: v1.2.1 -- package: github.com/inconshreveable/log15 - version: v2.13 -- 
package: github.com/go-stack/stack - version: v1.6.0 -- package: github.com/golang/protobuf diff --git a/go/master/CMakeLists.txt b/go/master/CMakeLists.txt deleted file mode 100644 index b5101c3479d708418dd662b84e09ad74af86adbe..0000000000000000000000000000000000000000 --- a/go/master/CMakeLists.txt +++ /dev/null @@ -1,17 +0,0 @@ -# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# -if(WITH_TESTING) - go_test(master_test) -endif() diff --git a/go/master/c/CMakeLists.txt b/go/master/c/CMakeLists.txt deleted file mode 100644 index 58b44e6445b63e12eb7d9bfdee93239cf1fab899..0000000000000000000000000000000000000000 --- a/go/master/c/CMakeLists.txt +++ /dev/null @@ -1,15 +0,0 @@ -# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-# -go_library(paddle_master SHARED DEPS paddle_go_optimizer) diff --git a/go/master/c/client.go b/go/master/c/client.go deleted file mode 100644 index 42c176d00bd56f989b05e1d128b5ce030d220c77..0000000000000000000000000000000000000000 --- a/go/master/c/client.go +++ /dev/null @@ -1,196 +0,0 @@ -// Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. - -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at - -// http://www.apache.org/licenses/LICENSE-2.0 - -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package main - -/* -#include -#include -#include -#define PADDLE_MASTER_OK 0 -#define PADDLE_MASTER_ERROR -1 - -#define PADDLE_SAVE_MODEL_OK 1 -#define PADDLE_SAVE_MODEL_SKIP 0 - -typedef int paddle_master_client; -*/ -import "C" - -import ( - "strings" - "sync" - "time" - "unsafe" - - "github.com/PaddlePaddle/Paddle/go/master" - log "github.com/inconshreveable/log15" -) - -var mu sync.Mutex -var handleMap = make(map[C.paddle_master_client]*master.Client) -var curHandle C.paddle_master_client - -func init() { - log.Root().SetHandler( - log.LvlFilterHandler(log.LvlWarn, log.CallerStackHandler("%+v", log.StderrHandler)), - ) -} - -func add(c *master.Client) C.paddle_master_client { - mu.Lock() - defer mu.Unlock() - client := curHandle - curHandle++ - handleMap[client] = c - return client -} - -func get(client C.paddle_master_client) *master.Client { - mu.Lock() - defer mu.Unlock() - return handleMap[client] -} - -func remove(client C.paddle_master_client) *master.Client { - mu.Lock() - defer mu.Unlock() - h := handleMap[client] - delete(handleMap, client) - 
return h -} - -//export paddle_new_etcd_master_client -// -// bufSize is the record buffer size. -func paddle_new_etcd_master_client(etcdEndpoints *C.char, timeout int, bufSize int) C.paddle_master_client { - p := C.GoString(etcdEndpoints) - endpoints := strings.Split(p, ",") - c, err := master.NewClient( - master.WithEtcd(endpoints, time.Duration(timeout)*time.Second), - master.WithBuffer(bufSize), - ) - if err != nil { - panic(err) - } - - return add(c) -} - -//export paddle_new_master_client -// -// bufSize is the record buffer size. -func paddle_new_master_client(addr *C.char, bufSize int) C.paddle_master_client { - a := C.GoString(addr) - c, err := master.NewClient(master.WithAddr(a), master.WithBuffer(bufSize)) - if err != nil { - panic(err) - } - - return add(c) -} - -//export paddle_release_master_client -func paddle_release_master_client(client C.paddle_master_client) { - remove(client) -} - -//export paddle_start_get_records -func paddle_start_get_records(client C.paddle_master_client, pass C.int) { - c := get(client) - c.StartGetRecords(int(pass)) -} - -//export paddle_set_dataset -func paddle_set_dataset(client C.paddle_master_client, path **C.char, size C.int) C.int { - c := get(client) - var paths []string - for i := 0; i < int(size); i++ { - ptr := (**C.char)(unsafe.Pointer(uintptr(unsafe.Pointer(path)) + uintptr(i)*unsafe.Sizeof(*path))) - str := C.GoString(*ptr) - paths = append(paths, str) - } - err := c.SetDataset(paths) - if err != nil { - log.Error("error set dataset", - log.Ctx{"error": err, "paths": paths}) - return C.PADDLE_MASTER_ERROR - } - - return C.PADDLE_MASTER_OK -} - -// paddle_next_record gets the nexts training record. -// -// returns number of bytes of the records if success, -1 if failed, -2 if pass end. 
-// -//export paddle_next_record -func paddle_next_record(client C.paddle_master_client, record **C.uchar) C.int { - c := get(client) - r, err := c.NextRecord() - if err != nil { - // NOTE: use errors to indicate pass ends - if err.Error() == master.ErrAllTaskFailed.Error() || - err.Error() == master.ErrNoMoreAvailable.Error() || - err.Error() == master.ErrPassBefore.Error() { - return -2 - } - *record = (*C.uchar)(nil) - return -1 - } - - if len(r) == 0 { - // Empty record - *record = (*C.uchar)(nil) - return 0 - } - - size := C.size_t(len(r)) - *record = (*C.uchar)(C.malloc(size)) - C.memcpy(unsafe.Pointer(*record), unsafe.Pointer(&r[0]), size) - return C.int(size) -} - -// paddle_request_save_model requests the master server to approve the -// caller to save the model. -// -// returns 1 if the save the model request is approved, 0 if the -// request is rejected because other trainer is saving the model, -1 -// if error happened. -// -//export paddle_request_save_model -func paddle_request_save_model(client C.paddle_master_client, trainerID string, blockMS int) C.int { - c := get(client) - need, err := c.RequestSaveModel(trainerID, time.Duration(blockMS)*time.Millisecond) - if err != nil { - log.Error("error request save model", log.Ctx{"error": err}) - return C.PADDLE_MASTER_ERROR - } - - if need { - return C.PADDLE_SAVE_MODEL_OK - } - - return C.PADDLE_SAVE_MODEL_SKIP -} - -//export mem_free -func mem_free(p unsafe.Pointer) { - // "free" may be a better name for this function, but doing so - // will cause calling any function of this library from Python - // ctypes hanging. - C.free(p) -} - -func main() {} diff --git a/go/master/client.go b/go/master/client.go deleted file mode 100644 index e43903dd14e74047119d9dcea2adc431357305ee..0000000000000000000000000000000000000000 --- a/go/master/client.go +++ /dev/null @@ -1,255 +0,0 @@ -// Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. 
- -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at - -// http://www.apache.org/licenses/LICENSE-2.0 - -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package master - -import ( - "os" - "time" - - "github.com/PaddlePaddle/Paddle/go/connection" - "github.com/PaddlePaddle/recordio" - "github.com/coreos/etcd/clientv3" - log "github.com/inconshreveable/log15" -) - -// Client is the client of the master server. -type Client struct { - conn *connection.Conn - ch chan record - bufSize int -} - -type record struct { - r []byte - err error -} - -// WithBuffer sets the client to buffer the training record. -// -// bufSize is the record buffer size. NextRecord will read from this -// buffer. -func WithBuffer(bufSize int) func(*Client) error { - return func(c *Client) error { - if bufSize <= 0 { - return nil - } - c.bufSize = bufSize - return nil - } -} - -// WithAddr sets the client to use fixed master address. -func WithAddr(addr string) func(c *Client) error { - return func(c *Client) error { - ch := make(chan string, 1) - ch <- addr - go c.monitorMaster(ch) - return nil - } -} - -// WithEtcd sets the client to use etcd for master discovery. 
-func WithEtcd(endpoints []string, timeout time.Duration) func(*Client) error { - return func(c *Client) error { - var cli *clientv3.Client - f := func() error { - var err error - cli, err = clientv3.New(clientv3.Config{ - Endpoints: endpoints, - DialTimeout: timeout, - }) - return err - } - for { - err := f() - if err != nil { - log.Warn("create etcd client error", log.Ctx{"error": err}) - } else { - break - } - time.Sleep(time.Second) - } - - ch := make(chan string, 1) - a, err := GetKey(cli, DefaultAddrPath, timeout) - if err != nil { - return err - } - - if a != "" { - // Master is registered, send to the master address - // channel. - ch <- a - } - - go watchKey(cli, DefaultAddrPath, ch) - go c.monitorMaster(ch) - return nil - } -} - -// NewClient creates a new Client. -func NewClient(opts ...func(*Client) error) (*Client, error) { - c := &Client{} - c.conn = connection.New() - - for _, opt := range opts { - err := opt(c) - if err != nil { - return nil, err - } - } - c.ch = make(chan record, c.bufSize) - return c, nil -} - -// StartGetRecords must be called at beginning of each pass -func (c *Client) StartGetRecords(passID int) { - go c.getRecords(passID) -} - -func (c *Client) getRecords(passID int) { - i := 0 - for { - t, err := c.getTask(passID) - if err != nil { - if err.Error() == ErrPassBefore.Error() || - err.Error() == ErrNoMoreAvailable.Error() || - err.Error() == ErrAllTaskFailed.Error() { - c.ch <- record{nil, err} - break - } - - if i%60 == 0 { - log.Debug("getTask of passID error.", - log.Ctx{"error": err, "passID": passID}) - i = 0 - } - - // if err.Error() == ErrPassAfter.Error() - // wait util last pass finishes - // if other error such as network error - // wait to reconnect or task time out - time.Sleep(time.Second * 3) - i += 3 - continue - } - - for _, chunk := range t.Chunks { - f, e := os.Open(chunk.Path) - if e != nil { - log.Error("error open chunk", log.Ctx{"error": e}) - continue - } - - s := recordio.NewRangeScanner(f, &chunk.Index, 
-1, -1) - for s.Scan() { - c.ch <- record{s.Record(), nil} - } - - if s.Err() != nil { - c.ch <- record{nil, s.Err()} - log.Error( - "error scan chunk", - log.Ctx{"error": err, "path": chunk.Path}, - ) - } - - err = f.Close() - if err != nil { - log.Error("error close record file", log.Ctx{"error": err}) - } - } - - // We treat a task as finished whenever the last data - // instance of the task is read. This is not exactly - // correct, but a reasonable approximation. - err = c.taskFinished(t.Meta.ID) - if err != nil { - log.Error("task finish callback error.", log.Ctx{"error": err}) - } - } -} - -func (c *Client) monitorMaster(addrCh <-chan string) { - lastMaster := "" - for curMaster := range addrCh { - // connect to the new address once address changed. - if curMaster != lastMaster { - if curMaster == "" { - err := c.conn.Close() - if err != nil { - log.Error("close old master addr error", log.Ctx{"error": err}) - } - } else { - err := c.conn.Connect(curMaster) - if err != nil { - log.Error("connect to new master addr error", log.Ctx{"error": err}) - - // connect to addr failed, set - // to last known addr in order - // to retry next time. - curMaster = lastMaster - } - } - } - lastMaster = curMaster - } -} - -// SetDataset sets dataset to dispatch for the master server. -// -// SetDataset can be call multiple times at one pass. But only the first call -// will be honored. -// -// After all tasks are done, another call of SetDataset will start another pass. -func (c *Client) SetDataset(globPaths []string) error { - err := c.conn.Call("Service.SetDataset", globPaths, nil) - return err -} - -// getTask gets a new task from the master server. -func (c *Client) getTask(passID int) (Task, error) { - var t Task - err := c.conn.Call("Service.GetTask", passID, &t) - return t, err -} - -// TaskFinished tells the master server a task is finished. 
-func (c *Client) taskFinished(taskID int) error { - return c.conn.Call("Service.TaskFinished", taskID, nil) -} - -// TaskFailed tell the master server as task is failed. -func (c *Client) taskFailed(meta TaskMeta) error { - return c.conn.Call("Service.TaskFailed", meta, nil) -} - -// NextRecord returns next record in the dataset. -// -// NextRecord will block until the next record is available. It is -// thread-safe. -func (c *Client) NextRecord() ([]byte, error) { - r := <-c.ch - return r.r, r.err -} - -// RequestSaveModel requests the master server to approve the caller -// to save the model. -func (c *Client) RequestSaveModel(trainerID string, blockDur time.Duration) (bool, error) { - var need bool - err := c.conn.Call("Service.RequestSaveModel", SaveModelRequest{TrainerID: trainerID, BlockDur: blockDur}, &need) - return need, err -} diff --git a/go/master/client_internal_test.go b/go/master/client_internal_test.go deleted file mode 100644 index 37028a9e1f884f6660bf1c5630980dccae2beb01..0000000000000000000000000000000000000000 --- a/go/master/client_internal_test.go +++ /dev/null @@ -1,152 +0,0 @@ -// Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. - -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at - -// http://www.apache.org/licenses/LICENSE-2.0 - -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -package master - -import ( - "fmt" - "net" - "net/http" - "net/rpc" - "os" - "strconv" - "strings" - "testing" - "time" - - "github.com/PaddlePaddle/Paddle/go/connection" - "github.com/PaddlePaddle/recordio" -) - -const ( - totalTask = 20 - chunkPerTask = 10 -) - -func TestGetFinishTask(t *testing.T) { - const path = "/tmp/master_client_test_0" - - l, err := net.Listen("tcp", ":0") - if err != nil { - panic(err) - } - - ss := strings.Split(l.Addr().String(), ":") - p, err := strconv.Atoi(ss[len(ss)-1]) - if err != nil { - panic(err) - } - go func(l net.Listener) { - s, sErr := NewService(&InMemStore{}, chunkPerTask, time.Second, 1) - if sErr != nil { - panic(sErr) - } - - server := rpc.NewServer() - sErr = server.Register(s) - if sErr != nil { - panic(sErr) - } - - mux := http.NewServeMux() - mux.Handle(rpc.DefaultRPCPath, server) - sErr = http.Serve(l, mux) - if sErr != nil { - panic(sErr) - } - }(l) - - f, err := os.Create(path) - if err != nil { - panic(err) - } - - for i := 0; i < totalTask*chunkPerTask; i++ { - w := recordio.NewWriter(f, -1, -1) - _, err = w.Write(nil) - if err != nil { - panic(err) - } - - // call Close to force RecordIO writing a chunk. 
- err = w.Close() - if err != nil { - panic(err) - } - } - err = f.Close() - if err != nil { - panic(err) - } - - // Manually intialize client to avoid calling c.getRecords() - c := &Client{} - c.conn = connection.New() - addr := fmt.Sprintf(":%d", p) - ch := make(chan string, 1) - ch <- addr - go c.monitorMaster(ch) - - err = c.SetDataset([]string{path}) - if err != nil { - panic(err) - } - - checkOnePass := func(i int) { - var tasks []Task - for idx := 0; idx < totalTask; idx++ { - task, cErr := c.getTask(i) - if cErr != nil && cErr.Error() != ErrNoMoreAvailable.Error() && cErr.Error() != ErrPassAfter.Error() { - t.Fatalf("error: %v, pass: %d\n", cErr, i) - } - tasks = append(tasks, task) - } - - // getting task before task finishes should return error - _, cErr := c.getTask(i) - if cErr == nil { - t.Fatalf("Should get error, pass: %d\n", i) - } - - cErr = c.taskFinished(tasks[0].Meta.ID) - if cErr != nil { - t.Fatalf("Error: %v, pass: %d\n", cErr, i) - } - // call taskFailed once won't put the task to failed queue, just ensure - // the call - cErr = c.taskFailed(tasks[0].Meta) - if cErr != nil { - t.Fatalf("Error: %v, pass: %d\n", cErr, i) - } - - tasks = tasks[1:] - _, cErr = c.getTask(i) - if cErr != nil && cErr.Error() != ErrNoMoreAvailable.Error() && cErr.Error() != ErrPassAfter.Error() { - t.Fatalf("Should be ErrNoMoreAvailable or ErrPassAfter: %s", cErr) - } - - for _, task := range tasks { - cErr = c.taskFinished(task.Meta.ID) - if cErr != nil { - t.Fatal(cErr) - } - } - } - - for i := 0; i < 10; i++ { - // init pass data - c.StartGetRecords(i) - checkOnePass(i) - } -} diff --git a/go/master/client_test.go b/go/master/client_test.go deleted file mode 100644 index 01ecad2deada7978e6fe030a6f2d25e533749568..0000000000000000000000000000000000000000 --- a/go/master/client_test.go +++ /dev/null @@ -1,150 +0,0 @@ -// Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. 
- -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at - -// http://www.apache.org/licenses/LICENSE-2.0 - -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package master_test - -import ( - "fmt" - "net" - "net/http" - "net/rpc" - "os" - "runtime" - "strconv" - "strings" - "sync" - "testing" - "time" - - "github.com/PaddlePaddle/Paddle/go/master" - "github.com/PaddlePaddle/recordio" -) - -// tool function for testing output goroutine ids -func goid() int { - var buf [64]byte - n := runtime.Stack(buf[:], false) - idField := strings.Fields(strings.TrimPrefix(string(buf[:n]), "goroutine "))[0] - id, err := strconv.Atoi(idField) - if err != nil { - panic(fmt.Sprintf("cannot get goroutine id: %v", err)) - } - return id -} - -func TestNextRecord(t *testing.T) { - const ( - path = "/tmp/master_client_TestFull" - total = 50 - ) - l, err := net.Listen("tcp", ":0") - if err != nil { - panic(err) - } - - ss := strings.Split(l.Addr().String(), ":") - p, err := strconv.Atoi(ss[len(ss)-1]) - if err != nil { - panic(err) - } - go func(l net.Listener) { - s, err := master.NewService(&master.InMemStore{}, 1, time.Second*60, 1) - if err != nil { - panic(err) - } - - server := rpc.NewServer() - err = server.Register(s) - if err != nil { - panic(err) - } - - mux := http.NewServeMux() - mux.Handle(rpc.DefaultRPCPath, server) - err = http.Serve(l, mux) - if err != nil { - panic(err) - } - }(l) - - f, err := os.Create(path) - if err != nil { - panic(err) - } - - w := recordio.NewWriter(f, 1, -1) - for i := 0; i < total; i++ { - _, err = w.Write([]byte{byte(i)}) - if err != nil { - 
panic(err) - } - } - - err = w.Close() - if err != nil { - panic(err) - } - - err = f.Close() - if err != nil { - panic(err) - } - - // start several client to test task fetching - var wg sync.WaitGroup - for i := 0; i < 4; i++ { - wg.Add(1) - // test for multiple concurrent clients - go func() { - defer wg.Done() - // each go-routine needs a single client connection instance - c, e := master.NewClient(master.WithAddr(fmt.Sprintf(":%d", p)), master.WithBuffer(1)) - if e != nil { - t.Fatal(e) - } - e = c.SetDataset([]string{path}) - if e != nil { - panic(e) - } - - // test for n passes - for pass := 0; pass < 10; pass++ { - c.StartGetRecords(pass) - - received := make(map[byte]bool) - taskid := 0 - for { - r, e := c.NextRecord() - if e != nil { - // ErrorPassAfter will wait, else break for next pass - if e.Error() == master.ErrPassBefore.Error() || - e.Error() == master.ErrNoMoreAvailable.Error() { - break - } - t.Fatal(pass, taskid, "Read error:", e) - } - if len(r) != 1 { - t.Fatal(pass, taskid, "Length should be 1.", r) - } - if received[r[0]] { - t.Fatal(pass, taskid, "Received duplicate.", received, r) - } - taskid++ - received[r[0]] = true - } - } - }() - } - wg.Wait() -} diff --git a/go/master/etcd_client.go b/go/master/etcd_client.go deleted file mode 100644 index 36fe61127443dc8f6386295acb1a711d6a93b11c..0000000000000000000000000000000000000000 --- a/go/master/etcd_client.go +++ /dev/null @@ -1,201 +0,0 @@ -// Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. - -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at - -// http://www.apache.org/licenses/LICENSE-2.0 - -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-// See the License for the specific language governing permissions and -// limitations under the License. - -package master - -import ( - "context" - "time" - - "github.com/coreos/etcd/clientv3" - "github.com/coreos/etcd/clientv3/concurrency" - log "github.com/inconshreveable/log15" -) - -const ( - // DefaultLockPath is the default etcd master lock path. - DefaultLockPath = "/master/lock" - // DefaultStatePath is the default etcd key for master state. - DefaultStatePath = "/master/state" - // DefaultAddrPath is the default etcd key for master address. - DefaultAddrPath = "/master/addr" -) - -// EtcdClient is the etcd client that the master uses for fault -// tolerance and service registry. -type EtcdClient struct { - lockPath string - statePath string - client *clientv3.Client - lock *concurrency.Mutex - sess *concurrency.Session -} - -// NewEtcdClient creates a new EtcdClient. -func NewEtcdClient(endpoints []string, addr string, lockPath, addrPath, statePath string, ttlSec int) (*EtcdClient, error) { - log.Debug("Connecting to etcd", log.Ctx{"endpoint": endpoints}) - cli, err := clientv3.New(clientv3.Config{ - Endpoints: endpoints, - DialTimeout: dialTimeout, - }) - if err != nil { - return nil, err - } - - sess, err := concurrency.NewSession(cli, concurrency.WithTTL(ttlSec)) - if err != nil { - return nil, err - } - - lock := concurrency.NewMutex(sess, lockPath) - // It's fine for the lock to get stuck, in this case we have - // multiple master servers running (only configured to have - // one master running, but split-brain problem may cause - // multiple master servers running), and the cluster management - // software will kill one of them. 
- log.Info("Trying to acquire lock.", log.Ctx{"path": lockPath}) - err = lock.Lock(context.TODO()) - if err != nil { - return nil, err - } - log.Info("Successfully acquired lock at %s.", log.Ctx{"path": lockPath}) - - put := clientv3.OpPut(addrPath, addr) - resp, err := cli.Txn(context.Background()).If(lock.IsOwner()).Then(put).Commit() - if err != nil { - return nil, err - } - - if !resp.Succeeded { - log.Crit("No longer owns the master lock. Exiting.") - panic("No longer owns the master lock. Exiting.") - } - - e := &EtcdClient{ - lockPath: lockPath, - statePath: statePath, - client: cli, - lock: lock, - sess: sess, - } - - return e, nil -} - -// Save saves the state into the etcd. -func (e *EtcdClient) Save(state []byte) error { - ctx := context.TODO() - put := clientv3.OpPut(e.statePath, string(state)) - resp, err := e.client.Txn(ctx).If(e.lock.IsOwner()).Then(put).Commit() - if err != nil { - return err - } - - if !resp.Succeeded { - log.Error("No longer owns the lock, trying to lock again") - ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) - err := e.lock.Lock(ctx) - cancel() - if err != nil { - // We lost the master lock and can not acquire - // it back, it means some other master is - // already started. We don't want cluster - // management system to kill the master server - // who is holding the lock and running - // correctly. So the most feasible solution is - // to kill current master server. The current - // state is not saved, but the trainer's RPC - // call will fail, so the trainer will retry. - log.Crit("Could not acquire the lock at %s: %v. Exiting.", log.Ctx{"path": e.lockPath, "error": err}) - panic("Could not acquire the lock at %s: %v. Exiting.") - } - log.Info("Successfully acquired lock at %s.", e.lockPath) - return e.Save(state) - } - - return nil -} - -// Load loads the state from etcd. 
-func (e *EtcdClient) Load() ([]byte, error) { - ctx := context.TODO() - get := clientv3.OpGet(e.statePath) - - resp, err := e.client.Txn(ctx).If(e.lock.IsOwner()).Then(get).Commit() - if err != nil { - return nil, err - } - - if !resp.Succeeded { - log.Error("No longer owns the lock, trying to lock and load again.") - err = e.lock.Lock(context.Background()) - if err != nil { - return nil, err - } - - return e.Load() - } - - kvs := resp.Responses[0].GetResponseRange().Kvs - if len(kvs) == 0 { - // No state exists - return nil, nil - } - - state := kvs[0].Value - return state, nil -} - -// Shutdown shuts down the etcd client gracefully. -func (e *EtcdClient) Shutdown() error { - err := e.sess.Close() - newErr := e.client.Close() - if newErr != nil { - if err == nil { - err = newErr - } else { - log.Error("shutdown error", log.Ctx{"error": newErr}) - } - } - - return err -} - -// GetKey gets the value by the specify key. -func GetKey(c *clientv3.Client, key string, timeout time.Duration) (string, error) { - ctx, cancel := context.WithTimeout(context.Background(), timeout) - resp, err := c.Get(ctx, key) - cancel() - if err != nil { - return "", err - } - kvs := resp.Kvs - if len(kvs) == 0 { - return "", nil - } - v := kvs[0].Value - return string(v), nil -} - -// watchKey watches the specify key and send to valChan if there is some event. 
-func watchKey(c *clientv3.Client, key string, valChan chan<- string) { - rch := c.Watch(context.Background(), key) - for wresp := range rch { - for _, ev := range wresp.Events { - // if received event is DELETE, the value will be an empty string - log.Info("received event.", log.Ctx{"type": ev.Type, "key": ev.Kv.Key, "value": ev.Kv.Value}) - valChan <- string(ev.Kv.Value) - } - } -} diff --git a/go/master/inmem_store.go b/go/master/inmem_store.go deleted file mode 100644 index 33b4714317ff3f1ebbf312ac3a231cd9383bf224..0000000000000000000000000000000000000000 --- a/go/master/inmem_store.go +++ /dev/null @@ -1,47 +0,0 @@ -// Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. - -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at - -// http://www.apache.org/licenses/LICENSE-2.0 - -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package master - -import "sync" - -// InMemStore is an in memory implementation of Store interface. -// -// It does not tolerate the fault that causes the program to crash. -type InMemStore struct { - mu sync.Mutex - buf []byte -} - -// Save saves the state into the in-memory store. -func (m *InMemStore) Save(state []byte) error { - m.mu.Lock() - defer m.mu.Unlock() - - m.buf = state - return nil -} - -// Load loads the state from the in-memory store. -func (m *InMemStore) Load() ([]byte, error) { - m.mu.Lock() - defer m.mu.Unlock() - - return m.buf, nil -} - -// Shutdown shuts down the in mem store. 
-func (m *InMemStore) Shutdown() error { - return nil -} diff --git a/go/master/service.go b/go/master/service.go deleted file mode 100644 index 39f746e528e0c91ecf0f3ccacb01520bab81e0a4..0000000000000000000000000000000000000000 --- a/go/master/service.go +++ /dev/null @@ -1,510 +0,0 @@ -// Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. - -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at - -// http://www.apache.org/licenses/LICENSE-2.0 - -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package master - -import ( - "bytes" - "compress/gzip" - "encoding/gob" - "errors" - "math/rand" - "os" - "path/filepath" - "sync" - "time" - - log "github.com/inconshreveable/log15" - - "github.com/PaddlePaddle/recordio" -) - -const ( - dialTimeout = 5 * time.Second -) - -// ErrAllTaskFailed occur when tasks are in done or failed state. -var ErrAllTaskFailed = errors.New("all task finished") - -// ErrNoMoreAvailable occur when no task in todo and yet not all done or fail. -var ErrNoMoreAvailable = errors.New("no more available task") - -// ErrPassBefore client side pass number does not match with master counter. -var ErrPassBefore = errors.New("pass number smaller than master") - -// ErrPassAfter client side pass number does not match with master counter. -var ErrPassAfter = errors.New("pass number larger than master") - -// Store is the interface for save and load the master state. -type Store interface { - Save([]byte) error - Load() ([]byte, error) - Shutdown() error -} - -// Chunk is a chunk of data consisted of several data instances. 
-type Chunk struct { - Path string - Index recordio.Index // chunk index -} - -// TaskMeta is a struct which stores task's meta info. -type TaskMeta struct { - ID int - Epoch int -} - -// Task is the basic unit of data instances assigned to trainers. -type Task struct { - Meta TaskMeta - Chunks []Chunk -} - -type taskEntry struct { - Task Task - // A task fails if it's timeout or trainer reports it exits unnormally. - NumFailure int -} - -type masterState struct { - Todo []taskEntry - Pending map[int]taskEntry // map from task ID to task entry - Done []taskEntry - Failed []taskEntry - CurPass int -} - -// Service is the master server service. -type Service struct { - chunksPerTask int - timeoutDur time.Duration - failureMax int - store Store - - ready chan struct{} - initDone bool - - mu sync.Mutex - // State to be persisted to snapshot. - state masterState - // The trainer that is currently saving model. This state is - // transient, does not need to be persisted to snapshot. - savingTrainer string -} - -func partition(chunks []Chunk, chunksPerTask int) []taskEntry { - // generate uniq id across job using nanosecond + randint + counter - // FIXME(typhoonzero): this is a workaround, use uuid - randStart := rand.Int() - counter := 0 - timestamp := time.Now().Nanosecond() - id := timestamp + randStart + counter - if chunksPerTask <= 0 { - chunksPerTask = 1 - } - - var result []taskEntry - var cur taskEntry - for i, c := range chunks { - if i%chunksPerTask == 0 && len(cur.Task.Chunks) > 0 { - cur.Task.Meta.ID = id - counter++ - id = timestamp + randStart + counter - result = append(result, cur) - cur.Task.Chunks = nil - } - - cur.Task.Chunks = append(cur.Task.Chunks, c) - } - - if len(cur.Task.Chunks) > 0 { - cur.Task.Meta.ID = id - result = append(result, cur) - } - - return result -} - -// NewService creates a new service. 
-func NewService(store Store, chunksPerTask int, timeoutDur time.Duration, failureMax int) (*Service, error) { - s := &Service{} - s.chunksPerTask = chunksPerTask - s.timeoutDur = timeoutDur - s.failureMax = failureMax - s.state = masterState{} - s.state.Pending = make(map[int]taskEntry) - s.ready = make(chan struct{}) - s.store = store - recovered, err := s.recover() - if err != nil { - return nil, err - } - - if recovered { - // Recovered. Now the state is already initialized, - // and the master is ready. - s.initDone = true - close(s.ready) - log.Info("Master recovered from saved state.") - } - - return s, nil -} - -// recover recovers service state from etcd. -func (s *Service) recover() (bool, error) { - state, err := s.store.Load() - if err != nil { - return false, err - } - - if state == nil { - log.Info("No state exists, not recovered.") - return false, nil - } - - log.Info("Loaded snapshot.", log.Ctx{"size": len(state)}) - gr, err := gzip.NewReader(bytes.NewReader(state)) - if err != nil { - return false, err - } - - dec := gob.NewDecoder(gr) - var tqs masterState - err = dec.Decode(&tqs) - if err != nil { - return false, err - } - - err = gr.Close() - if err != nil { - // Only close failed, recover actually succeed, so - // just log error. - log.Error("error close recover file.", log.Ctx{"error": err}) - } - - s.state = tqs - log.Info("Master recovered from snapshot, scheduling pending task timeout check.", s.logCtx()) - for _, t := range s.state.Pending { - time.AfterFunc(s.timeoutDur, s.checkTimeoutFunc(t.Task.Meta.ID, t.Task.Meta.Epoch)) - } - - return true, nil -} - -// snapshot *must* be called with s.mu being held. -func (s *Service) snapshot() error { - // TODO(helin): etcd request has a size limit, so the snapshot - // size is limited by the max request size. 
We should either - // divide the snapshot into smaller chunks and save under - // different keys, or configure the request size to be big - // enough: - // https://github.com/coreos/etcd/blob/2f84f3d8d8ed8f9537ab6ffa44a3a1c7eddfa9b1/embed/config.go#L44 - var buf bytes.Buffer - gw := gzip.NewWriter(&buf) - enc := gob.NewEncoder(gw) - err := enc.Encode(s.state) - if err != nil { - return err - } - err = gw.Close() - if err != nil { - return err - } - - state := buf.Bytes() - log.Info("Saving snapshot.", log.Ctx{"size bytes": len(state)}) - return s.store.Save(state) -} - -func readChunks(globPaths []string) ([]Chunk, error) { - var chunks []Chunk - var paths []string - - for _, s := range globPaths { - match, err := filepath.Glob(s) - if err != nil { - return nil, err - } - paths = append(paths, match...) - } - - if len(paths) == 0 { - return nil, errors.New("no valid dataset specified") - } - - for _, path := range paths { - f, err := os.Open(path) - if err != nil { - return nil, err - } - - index, err := recordio.LoadIndex(f) - if err != nil { - return nil, err - } - err = f.Close() - if err != nil { - return nil, err - } - - count := index.NumChunks() - log.Info("reading chunks.", log.Ctx{"path": path, "num chunks": count}) - for i := 0; i < count; i++ { - chunk := Chunk{ - Path: path, - Index: *index.ChunkIndex(i), - } - chunks = append(chunks, chunk) - } - } - - return chunks, nil -} - -// SetDataset sets dataset to dispatch for the master server. -// -// SetDataset can be call multiple times. But only the first call will -// be honored. -func (s *Service) SetDataset(globPaths []string, _ *int) error { - if len(globPaths) == 0 { - return errors.New("no dataset specified") - } - - s.mu.Lock() - defer s.mu.Unlock() - if s.initDone { - // Already initialized. All trainer will call - // SetDataset, but we only handle the first one. Treat - // other calls as successful but do nothing. 
- return nil - } - - chunks, err := readChunks(globPaths) - if err != nil { - return err - } - - s.state.Todo = partition(chunks, s.chunksPerTask) - - err = s.snapshot() - if err != nil { - log.Error("snapshot error", log.Ctx{"error": err}) - return err - } - close(s.ready) - s.initDone = true - return nil -} - -// processFailedTask retry s.failureMax times for failed task. -// return true if all task are done or failed. -func (s *Service) processFailedTask(t taskEntry, epoch int) { - if t.Task.Meta.Epoch != epoch { - // new epoch, task launched after the - // schedule of this timeout check or failed status report. - return - } - - defer func() { - err := s.snapshot() - if err != nil { - log.Error("snapshot error", log.Ctx{"error": err}) - } - }() - - delete(s.state.Pending, t.Task.Meta.ID) - - t.NumFailure++ - if t.NumFailure > s.failureMax { - log.Warn("Task failed to many times, discard.", log.Ctx{"task": t.Task, "num failed": t.NumFailure}) - s.state.Failed = append(s.state.Failed, t) - return - } - - log.Warn("Task failed, re-dispatch.", log.Ctx{"task": t.Task, "num failed": t.NumFailure}) - s.state.Todo = append(s.state.Todo, t) - return -} - -func (s *Service) checkTimeoutFunc(taskID int, epoch int) func() { - return func() { - s.mu.Lock() - defer s.mu.Unlock() - - t, ok := s.state.Pending[taskID] - if !ok { - return - } - - s.processFailedTask(t, epoch) - } -} - -// must be called with lock held. -func (s *Service) logCtx() log.Ctx { - return log.Ctx{ - "todoLen": len(s.state.Todo), - "pendingLen": len(s.state.Pending), - "doneLen": len(s.state.Done), - "failedLen": len(s.state.Failed), - "curPass": s.state.CurPass, - } -} - -// GetTask gets a new task from the service. 
-// passID is the client side pass count -func (s *Service) GetTask(passID int, task *Task) error { - select { - case <-s.ready: - } - - s.mu.Lock() - defer s.mu.Unlock() - if passID < s.state.CurPass { - return ErrPassBefore - } - if passID > s.state.CurPass { - // Client may get run to pass after master when one client faster than the - // other - return ErrPassAfter - } - - if len(s.state.Todo) == 0 { - if len(s.state.Done) == 0 && len(s.state.Pending) == 0 { - log.Warn("All tasks failed, may start next pass", s.logCtx()) - return ErrAllTaskFailed - } - log.Warn("No more available task.", s.logCtx()) - return ErrNoMoreAvailable - } - - t := s.state.Todo[0] - t.Task.Meta.Epoch++ - s.state.Todo = s.state.Todo[1:] - s.state.Pending[t.Task.Meta.ID] = t - err := s.snapshot() - if err != nil { - return err - } - - *task = t.Task - ctx := s.logCtx() - ctx["task meta"] = t.Task.Meta - log.Info("Task dispatched.", ctx) - time.AfterFunc(s.timeoutDur, s.checkTimeoutFunc(t.Task.Meta.ID, t.Task.Meta.Epoch)) - return nil -} - -// TaskFinished tell the service that a task is finished. -func (s *Service) TaskFinished(taskID int, dummy *int) error { - select { - case <-s.ready: - } - - s.mu.Lock() - defer s.mu.Unlock() - - t, ok := s.state.Pending[taskID] - if !ok { - ctx := s.logCtx() - ctx["task id"] = taskID - log.Warn("Pending task not found.", ctx) - return nil - } - - // task finished, reset timeout - t.NumFailure = 0 - s.state.Done = append(s.state.Done, t) - delete(s.state.Pending, taskID) - - ctx := s.logCtx() - ctx["task id"] = taskID - log.Info("Task finished.", ctx) - if len(s.state.Todo) == 0 && len(s.state.Pending) == 0 { - // increase master side pass count if all tasks finished - s.state.CurPass++ - s.state.Todo = append(s.state.Done, s.state.Failed...) 
- s.state.Done = []taskEntry{} - // TODO(typhoonzero): deal with failed tasks - s.state.Failed = []taskEntry{} - ctx := s.logCtx() - ctx["new pass"] = s.state.CurPass - log.Warn("all task finished, add new pass data.", ctx) - } - - err := s.snapshot() - if err != nil { - log.Error("snapshot error", log.Ctx{"error": err}) - } - return err -} - -// TaskFailed tells the service that a task is failed. -func (s *Service) TaskFailed(meta TaskMeta, dummy *int) error { - select { - case <-s.ready: - } - - s.mu.Lock() - defer s.mu.Unlock() - - t, ok := s.state.Pending[meta.ID] - if !ok { - log.Warn("TaskFailed:Pending task not found.", log.Ctx{"task": t.Task.Meta}) - return nil - } - - s.processFailedTask(t, meta.Epoch) - return nil -} - -// SaveModelRequest is the request for saving model -type SaveModelRequest struct { - TrainerID string - BlockDur time.Duration -} - -// RequestSaveModel requests the master server to approve the caller -// to save the model. -func (s *Service) RequestSaveModel(req SaveModelRequest, need *bool) error { - s.mu.Lock() - defer s.mu.Unlock() - - if req.TrainerID == "" { - return errors.New("trainer id is empty") - } - - if s.savingTrainer == "" { - *need = true - } else { - if req.TrainerID == s.savingTrainer { - // save trainer asked to save model again - *need = true - } else { - *need = false - } - } - - if *need { - s.savingTrainer = req.TrainerID - time.AfterFunc(req.BlockDur, func() { - s.mu.Lock() - s.savingTrainer = "" - s.mu.Unlock() - }) - } - - return nil -} diff --git a/go/master/service_internal_test.go b/go/master/service_internal_test.go deleted file mode 100644 index dd22f3d548b99ac11250735c74bca3dfca46cf86..0000000000000000000000000000000000000000 --- a/go/master/service_internal_test.go +++ /dev/null @@ -1,52 +0,0 @@ -// Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. - -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. 
-// You may obtain a copy of the License at - -// http://www.apache.org/licenses/LICENSE-2.0 - -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package master - -import "testing" - -func TestPartitionCount(t *testing.T) { - cs := make([]Chunk, 100) - ts := partition(cs, 5) - if len(ts) != 20 { - t.Error(len(ts)) - } - - cs = make([]Chunk, 101) - ts = partition(cs, 5) - if len(ts) != 21 { - t.Error(len(ts)) - } - - ts = partition(cs, 1) - if len(ts) != 101 { - t.Error(len(ts)) - } - - ts = partition(cs, 0) - if len(ts) != 101 { - t.Error(len(ts)) - } -} - -func TestPartionIndex(t *testing.T) { - cs := make([]Chunk, 100) - ts := partition(cs, 20) - for i := range ts { - // test auto increament ids - if i > 0 && ts[i].Task.Meta.ID != ts[i-1].Task.Meta.ID+1 { - t.Error(ts[i], i) - } - } -} diff --git a/go/master/service_test.go b/go/master/service_test.go deleted file mode 100644 index 2d00c22d6feb7177da5c19c557fd16d7925ef6d1..0000000000000000000000000000000000000000 --- a/go/master/service_test.go +++ /dev/null @@ -1,72 +0,0 @@ -package master_test - -import ( - "io/ioutil" - "net/url" - "os" - "strings" - "testing" - "time" - - "github.com/PaddlePaddle/Paddle/go/master" - "github.com/coreos/etcd/clientv3" - "github.com/coreos/etcd/embed" - "github.com/stretchr/testify/assert" -) - -func TestNewServiceWithEtcd(t *testing.T) { - // setup an embed etcd server - etcdDir, err := ioutil.TempDir("", "") - if err != nil { - t.Fatal(err) - } - cfg := embed.NewConfig() - lpurl, _ := url.Parse("http://localhost:0") - lcurl, _ := url.Parse("http://localhost:0") - cfg.LPUrls = []url.URL{*lpurl} - cfg.LCUrls = []url.URL{*lcurl} - cfg.Dir = etcdDir - e, err := embed.StartEtcd(cfg) - if err != nil { - 
t.Fatal(err) - } - defer func() { - e.Close() - if err := os.RemoveAll(etcdDir); err != nil { - t.Fatal(err) - } - }() - - <-e.Server.ReadyNotify() - - port := strings.Split(e.Clients[0].Addr().String(), ":")[1] - endpoint := "127.0.0.1:" + port - - ep := []string{endpoint} - masterAddr := "127.0.0.1:3306" - store, err := master.NewEtcdClient(ep, masterAddr, master.DefaultLockPath, master.DefaultAddrPath, master.DefaultStatePath, 30) - if err != nil { - t.Fatal(err) - } - - _, err = master.NewService(store, 10, 10, 3) - if err != nil { - t.Fatal(err) - } - cli, err := clientv3.New(clientv3.Config{ - Endpoints: ep, - DialTimeout: 3 * time.Second, - }) - if err != nil { - t.Fatal(err) - } - v, err := master.GetKey(cli, master.DefaultAddrPath, 3*time.Second) - if err != nil { - t.Fatal(err) - } - if err := cli.Close(); err != nil { - t.Fatal(err) - } - // test master process registry itself into etcd server. - assert.Equal(t, masterAddr, v, "master process should registry itself into etcd server.") -} diff --git a/go/proto/.gitignore b/go/proto/.gitignore deleted file mode 100644 index 5e7d2734cfc60289debf74293817c0a8f572ff32..0000000000000000000000000000000000000000 --- a/go/proto/.gitignore +++ /dev/null @@ -1,4 +0,0 @@ -# Ignore everything in this directory -* -# Except this file -!.gitignore diff --git a/go/pserver/CMakeLists.txt b/go/pserver/CMakeLists.txt deleted file mode 100644 index 32f3b2baba37238f0ca75e9177f9afa3dcfd4156..0000000000000000000000000000000000000000 --- a/go/pserver/CMakeLists.txt +++ /dev/null @@ -1,17 +0,0 @@ -# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# -if(WITH_TESTING) - go_test(pserver_test DEPS paddle_go_optimizer gen_proto_go) -endif() diff --git a/go/pserver/client/CMakeLists.txt b/go/pserver/client/CMakeLists.txt deleted file mode 100644 index 1d6f45a6642fa8819050f8a21c212369b52d1112..0000000000000000000000000000000000000000 --- a/go/pserver/client/CMakeLists.txt +++ /dev/null @@ -1,17 +0,0 @@ -# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-# -if(WITH_TESTING) - go_test(pserver_client_test DEPS paddle_go_optimizer) -endif() diff --git a/go/pserver/client/c/.gitignore b/go/pserver/client/c/.gitignore deleted file mode 100644 index 4bf05c85386dfcef83453a663dffc5d62efcbcc0..0000000000000000000000000000000000000000 --- a/go/pserver/client/c/.gitignore +++ /dev/null @@ -1 +0,0 @@ -libpaddle_go_optimizer.a diff --git a/go/pserver/client/c/CMakeLists.txt b/go/pserver/client/c/CMakeLists.txt deleted file mode 100644 index 78776219dee06da09e8b6956cd7bc132fb28552b..0000000000000000000000000000000000000000 --- a/go/pserver/client/c/CMakeLists.txt +++ /dev/null @@ -1,30 +0,0 @@ -# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# -cc_library(paddle_go_optimizer DEPS paddle_optimizer paddle_proto glog gflags protobuf) -target_link_libraries(paddle_go_optimizer stdc++ m) - -# Copy library to the required place. -# See: go/pserver/optimizer.go: -# // #cgo LDFLAGS: ${SRCDIR}/client/c/libpaddle_go_optimizer.a -lstdc++ -lm -add_custom_command(TARGET paddle_go_optimizer POST_BUILD - COMMAND cp "${CMAKE_CURRENT_BINARY_DIR}/libpaddle_go_optimizer.a" "${CMAKE_CURRENT_SOURCE_DIR}" - ) - -go_library(paddle_pserver_cclient STATIC DEPS paddle_go_optimizer) -if(WITH_TESTING) - # FIXME: this test requires pserver which is not managed by the test - # we need some kind of e2e testing machanism. 
- # add_subdirectory(test) -endif() diff --git a/go/pserver/client/c/cclient.go b/go/pserver/client/c/cclient.go deleted file mode 100644 index cddc28e46f48799f8643732283c94216e1f5cfb1..0000000000000000000000000000000000000000 --- a/go/pserver/client/c/cclient.go +++ /dev/null @@ -1,300 +0,0 @@ -// Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. - -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at - -// http://www.apache.org/licenses/LICENSE-2.0 - -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package main - -/* -#include <string.h> -typedef enum { - PADDLE_ELEMENT_TYPE_INT32 = 0, - PADDLE_ELEMENT_TYPE_UINT32 = 1, - PADDLE_ELEMENT_TYPE_INT64 = 2, - PADDLE_ELEMENT_TYPE_UINT64 = 3, - PADDLE_ELEMENT_TYPE_FLOAT32 = 4, - PADDLE_ELEMENT_TYPE_FLOAT64 = 5, -} paddle_element_type; - -typedef struct { - char* name; - paddle_element_type element_type; - unsigned char* content; - int content_len; -} paddle_parameter, paddle_gradient; - -typedef int paddle_pserver_client; -#define PSERVER_ERROR -1 -#define PSERVER_OK 0 -*/ -import "C" - -import ( - "strings" - "sync" - "unsafe" - - "github.com/PaddlePaddle/Paddle/go/pserver" - "github.com/PaddlePaddle/Paddle/go/pserver/client" - log "github.com/inconshreveable/log15" -) - -func init() { - log.Root().SetHandler( - log.LvlFilterHandler(log.LvlWarn, log.CallerStackHandler("%+v", log.StderrHandler)), - ) -} - -var mu sync.Mutex -var handleMap = make(map[C.paddle_pserver_client]*client.Client) -var curHandle C.paddle_pserver_client - -func add(c *client.Client) C.paddle_pserver_client { - mu.Lock() - defer mu.Unlock() - cli 
:= curHandle - curHandle++ - handleMap[cli] = c - return cli -} - -func get(client C.paddle_pserver_client) *client.Client { - mu.Lock() - defer mu.Unlock() - return handleMap[client] -} - -func remove(client C.paddle_pserver_client) *client.Client { - mu.Lock() - defer mu.Unlock() - h := handleMap[client] - delete(handleMap, client) - return h -} - -func cArrayToSlice(p unsafe.Pointer, len int) []byte { - if p == nil { - return nil - } - - // create a Go clice backed by a C array, reference: - // https://github.com/golang/go/wiki/cgo#turning-c-arrays-into-go-slices - // - // Go garbage collector will not interact with this data, need - // to be freed properly. - return (*[1 << 30]byte)(p)[:len:len] -} - -type selector bool - -func (s selector) Select() (bool, error) { - return bool(s), nil -} - -func (s selector) Done() error { - return nil -} - -type lister []client.Server - -func (l lister) List() []client.Server { - return l -} - -//export paddle_new_pserver_client -func paddle_new_pserver_client(addrs *C.char, selected int) C.paddle_pserver_client { - a := C.GoString(addrs) - as := strings.Split(a, ",") - servers := make([]client.Server, len(as)) - for i := range as { - servers[i].Index = i - servers[i].Addr = as[i] - } - c := client.NewClient(lister(servers), len(as), selector(selected != 0)) - return add(c) -} - -//export paddle_new_etcd_pserver_client -func paddle_new_etcd_pserver_client(etcdEndpoints *C.char) C.paddle_pserver_client { - addr := C.GoString(etcdEndpoints) - etcdClient := client.NewEtcd(addr) - c := client.NewClient(etcdClient, etcdClient.Desired(), etcdClient) - return add(c) -} - -//export paddle_pserver_client_release -func paddle_pserver_client_release(client C.paddle_pserver_client) { - remove(client) -} - -// paddle_begin_init_params tells trainer if it needs to init the -// parameters. -// -// returns 1 if the trainer needs to init the parameters. 0 if the -// trainer does not need to init the parameters. 
-// -//export paddle_begin_init_params -func paddle_begin_init_params(client C.paddle_pserver_client) C.int { - c := get(client) - selected, err := c.BeginInitParams() - if err != nil { - panic(err) - } - - if selected { - return 1 - } - return 0 -} - -//export paddle_init_param -func paddle_init_param(client C.paddle_pserver_client, param C.paddle_parameter, paramConfig unsafe.Pointer, configLen C.int) C.int { - et := pserver.ElementType(param.element_type) - name := C.GoString(param.name) - content := cArrayToSlice(unsafe.Pointer(param.content), int(param.content_len)) - pc := pserver.ParameterWithConfig{ - Param: pserver.Parameter{Name: name, ElementType: et, Content: content}, - Config: cArrayToSlice(paramConfig, int(configLen)), - } - c := get(client) - err := c.InitParam(pc) - - if err != nil { - if err.Error() == pserver.AlreadyInitialized { - log.Warn( - "parameter already initialized, treat paddle_init_param as successful.", - log.Ctx{"parameter": name}, - ) - return C.PSERVER_OK - } - log.Error("error init param", log.Ctx{"error": err}) - return C.PSERVER_ERROR - } - - return C.PSERVER_OK -} - -//export paddle_finish_init_params -func paddle_finish_init_params(client C.paddle_pserver_client) C.int { - c := get(client) - err := c.FinishInitParams() - if err != nil { - if err.Error() == pserver.AlreadyInitialized { - log.Warn("parameters already initialized, treat paddle_finish_init_params as successful.") - return C.PSERVER_OK - } - - log.Error("error finish init params", log.Ctx{"error": err}) - return C.PSERVER_ERROR - } - - return C.PSERVER_OK -} - -//export paddle_send_grads -func paddle_send_grads(client C.paddle_pserver_client, grads **C.paddle_gradient, total C.int) C.int { - var gs []pserver.Gradient - for i := 0; i < int(total); i++ { - grad := *(**C.paddle_gradient)(unsafe.Pointer((uintptr(unsafe.Pointer(grads)) + uintptr(i)*unsafe.Sizeof(*grads)))) - et := pserver.ElementType(grad.element_type) - name := C.GoString(grad.name) - content := 
cArrayToSlice(unsafe.Pointer(grad.content), int(grad.content_len)) - gs = append(gs, pserver.Gradient{Name: name, ElementType: et, Content: content}) - } - - c := get(client) - err := c.SendGrads(gs) - if err != nil { - log.Error("error send grads", log.Ctx{"error": err}) - return C.PSERVER_ERROR - } - - return C.PSERVER_OK -} - -//export paddle_get_params -func paddle_get_params(client C.paddle_pserver_client, dst **C.paddle_parameter, total C.int) C.int { - var ns []string - for i := 0; i < int(total); i++ { - param := *(**C.paddle_parameter)(unsafe.Pointer((uintptr(unsafe.Pointer(dst)) + uintptr(i)*unsafe.Sizeof(*dst)))) - ns = append(ns, C.GoString(param.name)) - } - c := get(client) - ps, err := c.GetParams(ns) - if err != nil { - log.Error("error get params", log.Ctx{"error": err}) - return C.PSERVER_ERROR - } - - if len(ps) != len(ns) { - pn := make([]string, len(ps)) - for i, p := range ps { - pn[i] = p.Name - } - log.Error( - "pserver returned wrong number of parameters.", - log.Ctx{ - "Requested": strings.Join(pn, ", "), - "Returned": strings.Join(ns, ", "), - }, - ) - return C.PSERVER_ERROR - } - - for i := range ps { - if ns[i] != ps[i].Name { - pn := make([]string, len(ps)) - for i, p := range ps { - pn[i] = p.Name - } - log.Error( - "pserver returned wrong parameters, or not in requested order.", - log.Ctx{ - "Requested": strings.Join(pn, ", "), - "Returned": strings.Join(ns, ", "), - }, - ) - return C.PSERVER_ERROR - } - } - - for i := 0; i < int(total); i++ { - p := ps[i] - param := *(**C.paddle_parameter)(unsafe.Pointer((uintptr(unsafe.Pointer(dst)) + uintptr(i)*unsafe.Sizeof(*dst)))) - - if unsafe.Pointer(param) == nil { - log.Error("must pre-allocate parameter.") - return C.PSERVER_ERROR - } - - if unsafe.Pointer(param.content) != nil { - if int(param.content_len) != len(p.Content) { - log.Error( - "the pre-allocated content len does not match parameter content len.", - log.Ctx{ - "Pre-allocated len": param.content_len, - "Returned len": 
len(p.Content), - }, - ) - return C.PSERVER_ERROR - } - } - - C.memcpy(unsafe.Pointer(param.content), unsafe.Pointer(&p.Content[0]), C.size_t(len(p.Content))) - param.content_len = C.int(len(p.Content)) - param.element_type = C.paddle_element_type(p.ElementType) - } - - return C.PSERVER_OK -} - -func main() {} // Required but ignored diff --git a/go/pserver/client/c/test/CMakeLists.txt b/go/pserver/client/c/test/CMakeLists.txt deleted file mode 100644 index 4500b1f288372ed0e2d9d383234df97ae976c60b..0000000000000000000000000000000000000000 --- a/go/pserver/client/c/test/CMakeLists.txt +++ /dev/null @@ -1,15 +0,0 @@ -# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# -cc_test(test_cclient SRCS test_cclient.c DEPS paddle_pserver_cclient paddle_go_optimizer) diff --git a/go/pserver/client/c/test/test_cclient.c b/go/pserver/client/c/test/test_cclient.c deleted file mode 100644 index 0116e42a0a67f757a786aa6dc9f8097af656d8b2..0000000000000000000000000000000000000000 --- a/go/pserver/client/c/test/test_cclient.c +++ /dev/null @@ -1,115 +0,0 @@ -/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. 
-You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#include <stdio.h> -#include <stdlib.h> - -#include "libpaddle_pserver_cclient.h" - -// TODO(helin): Fix: gtest using cmake is not working, using this -// hacky way for now. -#define fail() \ - fprintf(stderr, "info: %s:%d: ", __FILE__, __LINE__); \ - exit(-1); - -void sendGrads(paddle_pserver_client c) { - unsigned char grad_a[2000] = {2}; - unsigned char grad_b[3000] = {3}; - paddle_gradient grad1 = { - "param_a", PADDLE_ELEMENT_TYPE_FLOAT32, grad_a, 2000}; - paddle_gradient grad2 = { - "param_b", PADDLE_ELEMENT_TYPE_FLOAT32, grad_b, 3000}; - paddle_gradient *grads[2] = {&grad1, &grad2}; - if (paddle_send_grads(c, grads, 2)) { - fail(); - } -} - -void getParams(paddle_pserver_client c) { - paddle_parameter param_a; - paddle_parameter param_b; - char name_a[] = "param_a"; - char name_b[] = "param_b"; - // Must pre-allocate the prameter content before calling paddle_get_params. 
- unsigned char content_a[2000] = {}; - unsigned char content_b[3000] = {}; - param_a.element_type = PADDLE_ELEMENT_TYPE_FLOAT32; - param_a.name = name_a; - param_a.content = content_a; - param_a.content_len = 2000; - param_b.element_type = PADDLE_ELEMENT_TYPE_FLOAT32; - param_b.name = name_b; - param_b.content = content_b; - param_b.content_len = 3000; - - paddle_parameter *params[2] = {&param_a, &param_b}; - if (paddle_get_params(c, params, 2)) { - fail(); - } -} - -int main() { - char addr[] = "localhost:3000"; - paddle_pserver_client c = paddle_new_pserver_client(addr, 1); - char *config_proto; - size_t config_proto_len = 0; - ssize_t nread; - FILE *fp = fopen("testdata/optimizer.pb", "r"); - if (!fp) { - fail(); - } - while ((nread = getline(&config_proto, &config_proto_len, fp)) != -1) { - printf("%s", config_proto); - } - fclose(fp); -retry: - if (paddle_begin_init_params(c)) { - paddle_parameter param; - char name_a[] = "param_a"; - char name_b[] = "param_b"; - unsigned char content_a[2000] = {1}; - unsigned char content_b[3000] = {0}; - param.element_type = PADDLE_ELEMENT_TYPE_FLOAT32; - param.name = name_a; - param.content = content_a; - param.content_len = 2000; - int error = - paddle_init_param(c, param, (void *)config_proto, config_proto_len); - if (error != 0) { - goto retry; - } - - param.element_type = PADDLE_ELEMENT_TYPE_FLOAT32; - param.name = name_b; - param.content = content_b; - param.content_len = 3000; - error = paddle_init_param(c, param, (void *)config_proto, config_proto_len); - if (error != 0) { - goto retry; - } - - error = paddle_finish_init_params(c); - if (error != 0) { - goto retry; - } - } - - int i; - for (i = 0; i < 100; i++) { - sendGrads(c); - getParams(c); - } - - return 0; -} diff --git a/go/pserver/client/c/test/test_mnist.py b/go/pserver/client/c/test/test_mnist.py deleted file mode 100644 index 97f63aeb6d4cdfc639b0d778d4df817525c51430..0000000000000000000000000000000000000000 --- a/go/pserver/client/c/test/test_mnist.py +++ 
/dev/null @@ -1,145 +0,0 @@ -# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import paddle.v2 as paddle -import gzip - - -def softmax_regression(img): - predict = paddle.layer.fc(input=img, - size=10, - act=paddle.activation.Softmax()) - return predict - - -def multilayer_perceptron(img): - # The first fully-connected layer - hidden1 = paddle.layer.fc(input=img, size=128, act=paddle.activation.Relu()) - # The second fully-connected layer and the according activation function - hidden2 = paddle.layer.fc(input=hidden1, - size=64, - act=paddle.activation.Relu()) - # The thrid fully-connected layer, note that the hidden size should be 10, - # which is the number of unique digits - predict = paddle.layer.fc(input=hidden2, - size=10, - act=paddle.activation.Softmax()) - return predict - - -def convolutional_neural_network(img): - # first conv layer - conv_pool_1 = paddle.networks.simple_img_conv_pool( - input=img, - filter_size=5, - num_filters=20, - num_channel=1, - pool_size=2, - pool_stride=2, - act=paddle.activation.Tanh()) - # second conv layer - conv_pool_2 = paddle.networks.simple_img_conv_pool( - input=conv_pool_1, - filter_size=5, - num_filters=50, - num_channel=20, - pool_size=2, - pool_stride=2, - act=paddle.activation.Tanh()) - # The first fully-connected layer - fc1 = paddle.layer.fc(input=conv_pool_2, - size=128, - act=paddle.activation.Tanh()) - # The softmax layer, note that the hidden size 
should be 10, - # which is the number of unique digits - predict = paddle.layer.fc(input=fc1, - size=10, - act=paddle.activation.Softmax()) - return predict - - -def main(): - paddle.init(use_gpu=False, trainer_count=1) - - # define network topology - images = paddle.layer.data( - name='pixel', type=paddle.data_type.dense_vector(784)) - label = paddle.layer.data( - name='label', type=paddle.data_type.integer_value(10)) - - # Here we can build the prediction network in different ways. Please - # choose one by uncomment corresponding line. - predict = softmax_regression(images) - #predict = multilayer_perceptron(images) - #predict = convolutional_neural_network(images) - - cost = paddle.layer.classification_cost(input=predict, label=label) - parameters = paddle.parameters.create(cost) - - optimizer = paddle.optimizer.Momentum( - learning_rate=0.1 / 128.0, - momentum=0.9, - regularization=paddle.optimizer.L2Regularization(rate=0.0005 * 128)) - - trainer = paddle.trainer.SGD(cost=cost, - parameters=parameters, - update_equation=optimizer, - is_local=False, - pserver_spec="localhost:3000") - - lists = [] - - def event_handler(event): - if isinstance(event, paddle.event.EndIteration): - if event.batch_id % 1000 == 0: - print "Pass %d, Batch %d, Cost %f, %s" % ( - event.pass_id, event.batch_id, event.cost, event.metrics) - - elif isinstance(event, paddle.event.EndPass): - result = trainer.test(reader=paddle.batch( - paddle.dataset.mnist.test(), batch_size=128)) - print "Test with Pass %d, Cost %f, %s\n" % ( - event.pass_id, result.cost, result.metrics) - lists.append((event.pass_id, result.cost, - result.metrics['classification_error_evaluator'])) - - trainer.train( - reader=paddle.batch( - paddle.reader.shuffle( - paddle.dataset.mnist.train(), buf_size=8192), - batch_size=128), - event_handler=event_handler, - num_passes=100) - - # find the best pass - best = sorted(lists, key=lambda list: float(list[1]))[0] - print 'Best pass is %s, testing Avgcost is %s' % (best[0], 
best[1]) - print 'The classification accuracy is %.2f%%' % (100 - float(best[2]) * 100) - - test_creator = paddle.dataset.mnist.test() - test_data = [] - for item in test_creator(): - test_data.append((item[0], )) - if len(test_data) == 100: - break - - # output is a softmax layer. It returns probabilities. - # Shape should be (100, 10) - probs = paddle.infer( - output_layer=predict, parameters=parameters, input=test_data) - print probs.shape - - -if __name__ == '__main__': - main() diff --git a/go/pserver/client/c/test/test_train.py b/go/pserver/client/c/test/test_train.py deleted file mode 100644 index 2db5a0bf6a520b8fa29d13ea854b638ebcbbb7d9..0000000000000000000000000000000000000000 --- a/go/pserver/client/c/test/test_train.py +++ /dev/null @@ -1,89 +0,0 @@ -# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -import paddle.v2 as paddle -import paddle.v2.dataset.uci_housing as uci_housing -import paddle.v2.master as master -import os -import cPickle as pickle -from paddle.v2.reader.creator import cloud_reader - -etcd_ip = os.getenv("MASTER_IP", "127.0.0.1") -etcd_endpoints = "http://" + etcd_ip + ":2379" -print "etcd endpoints: ", etcd_endpoints - - -def main(): - # init - paddle.init(use_gpu=False, trainer_count=1) - - # network config - x = paddle.layer.data(name='x', type=paddle.data_type.dense_vector(13)) - y_predict = paddle.layer.fc(input=x, - param_attr=paddle.attr.Param(name='w'), - size=1, - act=paddle.activation.Linear(), - bias_attr=paddle.attr.Param(name='b')) - y = paddle.layer.data(name='y', type=paddle.data_type.dense_vector(1)) - cost = paddle.layer.mse_cost(input=y_predict, label=y) - - # create parameters - parameters = paddle.parameters.create(cost) - - # create optimizer of new remote updater to pserver - optimizer = paddle.optimizer.Momentum(momentum=0, learning_rate=1e-3) - - trainer = paddle.trainer.SGD(cost=cost, - parameters=parameters, - update_equation=optimizer, - is_local=False, - pserver_spec=etcd_endpoints, - use_etcd=True) - - # event_handler to print training and testing info - def event_handler(event): - if isinstance(event, paddle.event.EndIteration): - # FIXME: for cloud data reader, pass number is managed by master - # should print the server side pass number - if event.batch_id % 100 == 0: - print "Pass %d, Batch %d, Cost %f" % ( - event.pass_id, event.batch_id, event.cost) - - if isinstance(event, paddle.event.EndPass): - if (event.pass_id + 1) % 10 == 0: - result = trainer.test( - reader=paddle.batch( - uci_housing.test(), batch_size=2), - feeding={'x': 0, - 'y': 1}) - print "Test %d, %.2f" % (event.pass_id, result.cost) - - # training - # NOTE: use uci_housing.train() as reader for non-paddlecloud training - trainer.train( - reader=paddle.batch( - paddle.reader.shuffle( - cloud_reader( - 
["/pfs/dlnel/public/dataset/uci_housing/uci_housing*"], - etcd_endpoints), - buf_size=500), - batch_size=2), - feeding={'x': 0, - 'y': 1}, - event_handler=event_handler, - num_passes=30) - - -if __name__ == '__main__': - main() diff --git a/go/pserver/client/c/test/testdata/optimizer.pb b/go/pserver/client/c/test/testdata/optimizer.pb deleted file mode 100644 index 27dd3bc5f19e2964b4b674cff8860233cbdb445a..0000000000000000000000000000000000000000 Binary files a/go/pserver/client/c/test/testdata/optimizer.pb and /dev/null differ diff --git a/go/pserver/client/client.go b/go/pserver/client/client.go deleted file mode 100644 index 2a8f66a07c79906288c4179db5cab703cc2b8b61..0000000000000000000000000000000000000000 --- a/go/pserver/client/client.go +++ /dev/null @@ -1,237 +0,0 @@ -// Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. - -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at - -// http://www.apache.org/licenses/LICENSE-2.0 - -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package client - -import ( - "errors" - "hash/fnv" - "sort" - "time" - - "github.com/PaddlePaddle/Paddle/go/connection" - "github.com/PaddlePaddle/Paddle/go/pserver" - log "github.com/inconshreveable/log15" -) - -// TODO(helin): add RPC call retry logic - -// Selector selects if the client should initialize parameters and -// reports the initialization process done. -type Selector interface { - // Select selects if the client should initialize parameter servers. - Select() (bool, error) - // Done indicates the initialization process is done. 
- Done() error -} - -// Server is the identification of a parameter Server. -type Server struct { - Index int - Addr string -} - -// Lister lists currently available parameter servers. -type Lister interface { - List() []Server -} - -// Client is the client to parameter servers. -type Client struct { - sel Selector - pservers []*connection.Conn -} - -// NewClient creates a new client. -func NewClient(l Lister, pserverNum int, sel Selector) *Client { - c := &Client{sel: sel} - c.pservers = make([]*connection.Conn, pserverNum) - for i := 0; i < pserverNum; i++ { - c.pservers[i] = connection.New() - } - go c.monitorPservers(l, pserverNum) - return c -} - -// monitorPservers monitors pserver addresses, and updates connection -// when the address changes. -func (c *Client) monitorPservers(l Lister, pserverNum int) { - lastServers := make([]Server, pserverNum) - ticker := time.NewTicker(10 * time.Second) - monitor := func() { - curServers := make([]Server, pserverNum) - list := l.List() - for _, l := range list { - curServers[l.Index] = l - } - - for i := range lastServers { - if lastServers[i].Addr == curServers[i].Addr { - continue - } - - if curServers[i].Addr == "" { - err := c.pservers[i].Close() - if err != nil { - log.Error("error closing connection to pserver", log.Ctx{"error": err}) - } - - continue - } - - err := c.pservers[i].Connect(curServers[i].Addr) - if err != nil { - log.Error("error connecting to pserver", log.Ctx{"error": err}) - - // connect to addr failed, set - // to last known addr in order - // to retry next time. - curServers[i].Addr = lastServers[i].Addr - } - - } - - lastServers = curServers - } - - monitor() - for range ticker.C { - monitor() - } -} - -// BeginInitParams begins to initialize parameters on parameter -// servers. -// -// BeginInitParams will be called from multiple trainers, only one -// trainer will be selected to initialize the parameters on parameter -// servers. 
Other trainers will be blocked until the initialization is -// done, and they need to get the initialized parameters from -// parameter servers using GetParams. -func (c *Client) BeginInitParams() (bool, error) { - return c.sel.Select() -} - -// InitParam initializes the parameter on parameter servers. -func (c *Client) InitParam(paramWithConfigs pserver.ParameterWithConfig) error { - return c.pservers[c.partition(paramWithConfigs.Param.Name)].Call("Service.InitParam", paramWithConfigs, nil) -} - -// FinishInitParams tells parameter servers client has sent all -// parameters to parameter servers as initialization. -func (c *Client) FinishInitParams() error { - for _, p := range c.pservers { - err := p.Call("Service.FinishInitParams", 0, nil) - if err != nil { - return err - } - } - return c.sel.Done() -} - -// SendGrads sends gradients to parameter servers for updating -// parameters. -func (c *Client) SendGrads(grads []pserver.Gradient) error { - if len(grads) == 0 { - return errors.New("no gradient received") - } - errCh := make(chan error, len(grads)) - for _, g := range grads { - go func(g pserver.Gradient) { - err := c.pservers[c.partition(g.Name)].Call("Service.SendGrad", g, nil) - errCh <- err - }(g) - } - - recv := 0 - for err := range errCh { - if err != nil { - return err - } - - recv++ - if recv == len(grads) { - break - } - } - return nil -} - -type result struct { - idx int - param pserver.Parameter - err error -} - -type results []result - -func (r results) Len() int { - return len(r) -} - -func (r results) Less(i int, j int) bool { - return r[i].idx < r[j].idx -} - -func (r results) Swap(i int, j int) { - r[i], r[j] = r[j], r[i] -} - -// GetParams gets parameters from parameter servers. 
-func (c *Client) GetParams(names []string) ([]pserver.Parameter, error) { - rCh := make(chan result, len(names)) - - for idx, name := range names { - go func(name string, idx int) { - var parameter pserver.Parameter - err := c.pservers[c.partition(name)].Call("Service.GetParam", name, &parameter) - rCh <- result{idx: idx, param: parameter, err: err} - }(name, idx) - } - - var rs results - recv := 0 - for r := range rCh { - if r.err != nil { - return nil, r.err - } - rs = append(rs, r) - - recv++ - if recv == len(names) { - break - } - } - sort.Sort(rs) - - ps := make([]pserver.Parameter, len(rs)) - for i := range rs { - ps[i] = rs[i].param - } - - return ps, nil -} - -func strHash(s string) uint32 { - h := fnv.New32a() - _, _ = h.Write([]byte(s)) - return h.Sum32() -} - -// TODO(helin): now partition only select which parameter server to -// send the entire parameter. We need to partition a parameter into -// small blocks and send to different parameter servers. -func (c *Client) partition(key string) int { - return int(strHash(key) % uint32(len(c.pservers))) -} diff --git a/go/pserver/client/client_test.go b/go/pserver/client/client_test.go deleted file mode 100644 index 3a067ff5188fad8f6a13de88a2802e3e5866e59c..0000000000000000000000000000000000000000 --- a/go/pserver/client/client_test.go +++ /dev/null @@ -1,268 +0,0 @@ -// Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. - -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at - -// http://www.apache.org/licenses/LICENSE-2.0 - -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -package client_test - -import ( - "context" - "io/ioutil" - "math/rand" - "net" - "net/http" - "net/rpc" - "strconv" - "strings" - "sync" - "testing" - "time" - - "github.com/PaddlePaddle/Paddle/go/pserver" - "github.com/PaddlePaddle/Paddle/go/pserver/client" - "github.com/coreos/etcd/clientv3" - log "github.com/inconshreveable/log15" -) - -const ( - numPserver = 10 - etcdEndpoints = "127.0.0.1:2379" - timeout = 2 * time.Second -) - -var pserverClientPorts [numPserver]int - -// this function init pserver client and return their ports in an array. -func initClient() [numPserver]int { - var ports [numPserver]int - for i := 0; i < numPserver; i++ { - l, err := net.Listen("tcp", ":0") - if err != nil { - panic(err) - } - - ss := strings.Split(l.Addr().String(), ":") - p, err := strconv.Atoi(ss[len(ss)-1]) - if err != nil { - panic(err) - } - ports[i] = p - - go func(l net.Listener) { - var cp pserver.Checkpoint - s, err := pserver.NewService(0, time.Hour, "", nil, cp) - if err != nil { - panic(err) - } - server := rpc.NewServer() - err = server.Register(s) - if err != nil { - panic(err) - } - - mux := http.NewServeMux() - mux.Handle(rpc.DefaultRPCPath, server) - err = http.Serve(l, mux) - if err != nil { - panic(err) - } - }(l) - } - return ports -} - -func initNativeClient() { - pserverClientPorts = initClient() -} - -func initEtcdClient() { - client, err := clientv3.New(clientv3.Config{ - Endpoints: []string{etcdEndpoints}, - DialTimeout: time.Second * time.Duration(1), - }) - if err != nil { - log.Error("error init etcd client", log.Ctx{"error": err}) - } - ctx, cancel := context.WithTimeout(context.Background(), timeout) - _, err = client.Delete(ctx, pserver.PsDesired) - if err != nil { - panic(err) - } - - _, err = client.Delete(ctx, pserver.PsPath) - if err != nil { - panic(err) - } - - _, err = client.Put(ctx, pserver.PsDesired, strconv.Itoa(numPserver)) - if err != nil { - panic(err) - } - - ports := initClient() - for i := 0; i < numPserver; i++ { - _, err 
= client.Put(ctx, pserver.PsPath+strconv.Itoa(i), ":"+strconv.Itoa(ports[i])) - if err != nil { - panic(err) - } - } - cancel() - err = client.Close() - if err != nil { - panic(err) - } -} - -type selector bool - -func (s selector) Select() (bool, error) { - return bool(s), nil -} - -func (s selector) Done() error { - return nil -} - -type lister []client.Server - -func (l lister) List() []client.Server { - return l -} - -func testClient(t *testing.T, c *client.Client) { - selected, err := c.BeginInitParams() - if err != nil { - t.Fatal(err) - } - - if !selected { - t.Fatal("should be selected.") - } - - const numParameter = 1000 - config, err := ioutil.ReadFile("./c/test/testdata/optimizer.pb") - if err != nil { - t.Fatalf("read optimizer proto failed") - } - - var wg sync.WaitGroup - for i := 0; i < numParameter; i++ { - wg.Add(1) - go func(i int) { - var p pserver.Parameter - p.Name = "p_" + strconv.Itoa(i) - p.ElementType = pserver.Float32 - p.Content = make([]byte, (i+1)*100) - err := c.InitParam(pserver.ParameterWithConfig{Param: p, Config: config}) - if err != nil { - t.Fatal(err) - } - wg.Done() - }(i) - } - wg.Wait() - - err = c.FinishInitParams() - if err != nil { - t.Fatal(err) - } - - var grads []pserver.Gradient - for i := 0; i < numParameter; i++ { - var g pserver.Gradient - g.Name = "p_" + strconv.Itoa(i) - g.ElementType = pserver.Float32 - g.Content = make([]byte, (i+1)*100) - grads = append(grads, g) - } - - const paramPerGroup = 10 - const numGroups = numParameter / paramPerGroup - - // shuffle send grads order - for i := range grads { - j := rand.Intn(i + 1) - grads[i], grads[j] = grads[j], grads[i] - } - - for i := 0; i < numGroups; i++ { - var gs []pserver.Gradient - if i == numGroups-1 { - gs = grads[i*paramPerGroup:] - } else { - gs = grads[i*paramPerGroup : (i+1)*paramPerGroup] - } - - wg.Add(1) - go func(gs []pserver.Gradient) { - err := c.SendGrads(gs) - if err != nil { - t.Fatal(err) - } - wg.Done() - }(gs) - } - - names := make([]string, 
numParameter) - for i := 0; i < numParameter; i++ { - names[i] = "p_" + strconv.Itoa(i) - } - - for i := 0; i < numGroups; i++ { - var ns []string - if i == numGroups-1 { - ns = names[i*paramPerGroup:] - } else { - ns = names[i*paramPerGroup : (i+1)*paramPerGroup] - } - - wg.Add(1) - go func(ns []string) { - params, err := c.GetParams(ns) - if err != nil { - t.Fatal(err) - } - - if len(ns) != len(params) { - t.Fatalf("parameter size not match, need: %d, have: %d", len(names), len(params)) - } - - for i := range params { - if ns[i] != params[i].Name { - t.Fatalf("order of returned parameter does not required: parameter name: %s, required name: %s", ns[i], params[i].Name) - } - } - wg.Done() - }(ns) - } - - wg.Wait() -} - -func TestNativeClient(t *testing.T) { - initNativeClient() - servers := make([]client.Server, numPserver) - for i := 0; i < numPserver; i++ { - servers[i] = client.Server{Index: i, Addr: ":" + strconv.Itoa(pserverClientPorts[i])} - } - c1 := client.NewClient(lister(servers), len(servers), selector(true)) - testClient(t, c1) -} - -// EtcdClient is a disabled test, since we have not embedded etcd into -// our test. -func EtcdClient(t *testing.T) { - initEtcdClient() - etcdClient := client.NewEtcd(etcdEndpoints) - c2 := client.NewClient(etcdClient, etcdClient.Desired(), selector(true)) - testClient(t, c2) -} diff --git a/go/pserver/client/etcd_client.go b/go/pserver/client/etcd_client.go deleted file mode 100644 index 3fb835a6e165b7493df1a1fbb7440ee27109bbad..0000000000000000000000000000000000000000 --- a/go/pserver/client/etcd_client.go +++ /dev/null @@ -1,266 +0,0 @@ -// Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. - -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. 
-// You may obtain a copy of the License at - -// http://www.apache.org/licenses/LICENSE-2.0 - -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package client - -import ( - "context" - "errors" - "fmt" - "strconv" - "strings" - "time" - - "github.com/PaddlePaddle/Paddle/go/pserver" - "github.com/coreos/etcd/clientv3" - "github.com/coreos/etcd/clientv3/concurrency" - log "github.com/inconshreveable/log15" -) - -const ( - defaultEtcdTimeout time.Duration = 5 * time.Second - - initLockPath = "/init_ps/lock" - initDonePath = "/init_ps/done" - initDoneVal = "1" -) - -// Etcd is used by pserver client that is a part of trainer process. -// TODO: -// 1. add watcher to watch the change state of pservers. -type Etcd struct { - client *clientv3.Client - timeout time.Duration - endpoints []string - lock *concurrency.Mutex -} - -// Desired read ps desired number from etcd. -func (e *Etcd) Desired() int { - var psDesired int - for { - ctx, cancel := context.WithTimeout(context.Background(), e.timeout) - resp, err := e.client.Get(ctx, pserver.PsDesired) - cancel() - if err != nil { - log.Error( - "Get ps dresire number failed! reconnecting...", - log.Ctx{"error": err}, - ) - time.Sleep(e.timeout) - continue - } - - kvs := resp.Kvs - if len(kvs) == 0 { - log.Info("Waiting for ps desired registered ...") - time.Sleep(e.timeout) - continue - } - - psDesired, err = strconv.Atoi(string(resp.Kvs[0].Value)) - if err != nil { - log.Error("atoi failed", log.Ctx{"error": err}) - time.Sleep(e.timeout) - continue - } - - log.Debug("Got psDesired", log.Ctx{"psDesired": psDesired}) - break - } - return psDesired -} - -// List return the pserver list read from etcd. 
-func (e *Etcd) List() []Server { - psDesired := e.Desired() - - servers := make([]Server, psDesired) - for { - for i := 0; i < psDesired; i++ { - ctx, cancel := context.WithTimeout(context.Background(), e.timeout) - psKey := pserver.PsPath + strconv.Itoa(i) - log.Debug("looking for pserver", log.Ctx{"ps key": psKey}) - resp, err := e.client.Get(ctx, psKey) - cancel() - if err != nil { - log.Info( - "Get psKey error", - log.Ctx{"ps key": psKey, "error": err}, - ) - time.Sleep(e.timeout) - continue - } - kvs := resp.Kvs - if len(kvs) == 0 { - log.Info("Waiting for ps addr registered ...") - time.Sleep(e.timeout) - continue - } - - psAddr := string(resp.Kvs[0].Value) - // TODO(Longfei) check the ps address - if psAddr == "" { - log.Info( - "Value under psKey is empty", - log.Ctx{"psKey": psKey}, - ) - time.Sleep(e.timeout) - continue - } - log.Debug( - "got psAddr given psKey", - log.Ctx{"psAddr": psAddr, "psKey": psKey}, - ) - servers[i].Index = i - servers[i].Addr = psAddr - } - break - } - return servers -} - -// NewEtcd create a etcd client to return the state of pserver on etcd. -func NewEtcd(endpoints string) *Etcd { - ep := strings.Split(endpoints, ",") - var cli *clientv3.Client - var err error - for { - cli, err = clientv3.New(clientv3.Config{ - Endpoints: ep, - DialTimeout: defaultEtcdTimeout, - }) - if err != nil { - log.Error("Init etcd connection failed", log.Ctx{"error": err}) - time.Sleep(defaultEtcdTimeout) - continue - } - break - } - log.Info("Connected to etcd endpoint", log.Ctx{"endpoint": endpoints}) - client := &Etcd{ - client: cli, - timeout: defaultEtcdTimeout, - endpoints: ep, - } - return client -} - -// Select indicates if the current trainer is selected to initialize -// the pserver parameters. 
-func (e *Etcd) Select() (bool, error) { - sess, err := concurrency.NewSession(e.client, concurrency.WithTTL(5)) - if err != nil { - return false, err - } - - lock := concurrency.NewMutex(sess, initLockPath) - log.Info("Trying to acquire lock", log.Ctx{"lock path": initLockPath}) - // Do not use timeout context here, since we don't know how - // long does it take for other trainers to initialize the - // parameters. - err = lock.Lock(context.Background()) - if err != nil { - return false, err - } - log.Info("Successfully acquired lock", log.Ctx{"lock path": initLockPath}) - - get := clientv3.OpGet(initDonePath) - ctx, cancel := context.WithTimeout(context.Background(), e.timeout) - tresp, err := e.client.Txn(ctx).If(lock.IsOwner()).Then(get).Commit() - cancel() - if err != nil { - return false, err - } - - if !tresp.Succeeded { - return false, errors.New("no longer the owner of the lock") - } - - resp := tresp.Responses[0].GetResponseRange() - - if len(resp.Kvs) == 0 { - // Key value not set, select current trainer. - e.lock = lock - log.Info("Trainer selected.") - return true, nil - } - - if string(resp.Kvs[0].Value) == initDoneVal { - log.Info("Initialization is already done.") - ctx, cancel = context.WithTimeout(context.Background(), e.timeout) - err = lock.Unlock(ctx) - cancel() - if err != nil { - log.Error("error unlocking", log.Ctx{"error": err}) - } - return false, nil - } - - return false, fmt.Errorf("key %s have unexpected value: %v", initDonePath, resp.Kvs[0].Value) -} - -// Done indicates the parameter initialization process is done. 
-func (e *Etcd) Done() error { - if e.lock == nil { - return errors.New("lock is nil, Done called unexpectedly") - } - - put := clientv3.OpPut(initDonePath, initDoneVal) - ctx, cancel := context.WithTimeout(context.Background(), e.timeout) - tresp, err := e.client.Txn(ctx).If(e.lock.IsOwner()).Then(put).Commit() - cancel() - if err != nil { - return err - } - - if !tresp.Succeeded { - return errors.New("no longer the owner of the lock") - } - - ctx, cancel = context.WithTimeout(context.Background(), e.timeout) - err = e.lock.Unlock(ctx) - cancel() - if err != nil { - log.Error("error unlocking", log.Ctx{"error": err}) - } else { - e.lock = nil - } - - return nil -} - -// Close closes the etcd client. -func (e *Etcd) Close() error { - var err error - if e.lock != nil { - ctx, cancel := context.WithTimeout(context.Background(), e.timeout) - err = e.lock.Unlock(ctx) - cancel() - if err == nil { - e.lock = nil - } - } - - cErr := e.client.Close() - if cErr != nil { - if err != nil { - log.Error("error closing etcd client", log.Ctx{"error": cErr}) - return err - } - return cErr - } - - return err -} diff --git a/go/pserver/client/etcd_client_test.go b/go/pserver/client/etcd_client_test.go deleted file mode 100644 index 08742433e7a266fbd39e34f4b92ac4cc4caeb0fb..0000000000000000000000000000000000000000 --- a/go/pserver/client/etcd_client_test.go +++ /dev/null @@ -1,106 +0,0 @@ -package client_test - -import ( - "io/ioutil" - "net/url" - "os" - "strings" - "sync" - "testing" - - "github.com/PaddlePaddle/Paddle/go/pserver/client" - "github.com/coreos/etcd/embed" -) - -func TestSelector(t *testing.T) { - etcdDir, err := ioutil.TempDir("", "") - if err != nil { - t.Fatal(err) - } - cfg := embed.NewConfig() - lpurl, _ := url.Parse("http://localhost:0") - lcurl, _ := url.Parse("http://localhost:0") - cfg.LPUrls = []url.URL{*lpurl} - cfg.LCUrls = []url.URL{*lcurl} - cfg.Dir = etcdDir - e, err := embed.StartEtcd(cfg) - if err != nil { - t.Fatal(err) - } - - defer func() { - 
e.Close() - if err := os.RemoveAll(etcdDir); err != nil { - t.Fatal(err) - } - }() - - <-e.Server.ReadyNotify() - - port := strings.Split(e.Clients[0].Addr().String(), ":")[1] - endpoint := "127.0.0.1:" + port - - var mu sync.Mutex - selectedCount := 0 - var wg sync.WaitGroup - selectAndDone := func(c *client.Etcd) { - defer wg.Done() - - selected, err := c.Select() - if err != nil { - panic(err) - } - - if selected { - mu.Lock() - selectedCount++ - mu.Unlock() - err = c.Done() - if err != nil { - t.Fatal(err) - } - } - } - - c0 := client.NewEtcd(endpoint) - c1 := client.NewEtcd(endpoint) - c2 := client.NewEtcd(endpoint) - c3 := client.NewEtcd(endpoint) - wg.Add(3) - go selectAndDone(c0) - go selectAndDone(c1) - go selectAndDone(c2) - wg.Wait() - - // simulate trainer crashed and restarted after the - // initialization process. - wg.Add(1) - go selectAndDone(c3) - wg.Wait() - - mu.Lock() - if selectedCount != 1 { - t.Fatal("selected count wrong:", selectedCount) - } - mu.Unlock() - - err = c0.Close() - if err != nil { - t.Fatal(err) - } - - err = c1.Close() - if err != nil { - t.Fatal(err) - } - - err = c2.Close() - if err != nil { - t.Fatal(err) - } - - err = c3.Close() - if err != nil { - t.Fatal(err) - } -} diff --git a/go/pserver/etcd_client.go b/go/pserver/etcd_client.go deleted file mode 100644 index 719013b1bb4e80ff1f3040394803706d97514516..0000000000000000000000000000000000000000 --- a/go/pserver/etcd_client.go +++ /dev/null @@ -1,253 +0,0 @@ -// Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. - -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at - -// http://www.apache.org/licenses/LICENSE-2.0 - -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-// See the License for the specific language governing permissions and -// limitations under the License. - -package pserver - -import ( - "context" - "errors" - "strconv" - "strings" - "time" - - "github.com/PaddlePaddle/Paddle/go/utils/networkhelper" - "github.com/coreos/etcd/clientv3" - "github.com/coreos/etcd/clientv3/concurrency" - log "github.com/inconshreveable/log15" -) - -const ( - // PsDesired is etcd path for store desired pserver count - PsDesired = "/ps_desired" - // PsPath is the base dir for pserver to store their addr - PsPath = "/ps/" - // PsCheckpoint is the etcd path for store checkpoints information - PsCheckpoint = "/checkpoints/" - - retryTimeout = 5 * time.Second -) - -// EtcdClient is the etcd client that the pserver uses for fault -// tolerance, service registry and coordination. -type EtcdClient struct { - numPservers int - endpoints string - client *clientv3.Client - sess *concurrency.Session - dialTimeout time.Duration - ttlSec int - // FIXME: ensure GetExternalIP gets the correct ip for trainers to connect. - externalIP string - // desired number of pservers in the job. - // assume desired will not change during one training job. - desired int -} - -// NewEtcdClient creates an EtcdClient -func NewEtcdClient(endpoints string, numPservers int, dialtimeout time.Duration, ttlSec int) *EtcdClient { - return &EtcdClient{ - dialTimeout: dialtimeout, - ttlSec: ttlSec, - numPservers: numPservers, - endpoints: endpoints, - } -} - -// Register registers the pserver on etcd -// -// Register returns the index of the current pserver. -func (e *EtcdClient) Register(port int) (int, error) { - var err error - e.externalIP, err = networkhelper.GetExternalIP() - if err != nil { - return 0, err - } - - // initialize connection to etcd. 
- ep := strings.Split(e.endpoints, ",") - for { - cli, err := clientv3.New(clientv3.Config{ - Endpoints: ep, - DialTimeout: e.dialTimeout, - }) - if err != nil { - log.Error("connect to etcd error", log.Ctx{"error": err}) - time.Sleep(retryTimeout) - continue - } - e.client = cli - sess, err := concurrency.NewSession(cli, concurrency.WithTTL(e.ttlSec)) - if err != nil { - log.Error("create etcd session error", log.Ctx{"error": err}) - time.Sleep(retryTimeout) - continue - } - e.sess = sess - log.Debug("connected to etcd", log.Ctx{"endpoint": e.endpoints}) - break - } - // init /ps_desired using transaction, for multiple pservers may want to write - // it at the same time. - for { - ctx, cancel := context.WithTimeout(context.Background(), time.Second) - _, err := e.initDesiredPservers(ctx, e.numPservers) - cancel() - if err != nil { - log.Warn("pserver init error", log.Ctx{"error": err, "num pservers": e.numPservers}) - time.Sleep(retryTimeout) - continue - } - break - } - // TODO: when implementing extending or reducing pservers, /ps_desired is - // changed, then we need to watch /ps_desired node for events. For now, just - // write once when init and read from it. 
- // wait and set s.desired init value - for { - ctx, cancel := context.WithTimeout(context.Background(), time.Second) - resp, err := e.client.Get(ctx, PsDesired) - cancel() - if err != nil { - log.Error("get etcd key error", log.Ctx{"key": PsDesired, "error": err}) - time.Sleep(retryTimeout) - continue - } - if len(resp.Kvs) != 0 { - e.desired, err = strconv.Atoi(string(resp.Kvs[0].Value)) - if err != nil { - log.Error( - "psDesired atoi error", - log.Ctx{"error": err, "value": string(resp.Kvs[0].Value)}, - ) - time.Sleep(retryTimeout) - // NOTE: wait util ps_desired value change - continue - } - break - } - } - - var pserverIdx int - // try register pserver node on etcd - for { - ctx, cancel := context.WithTimeout(context.Background(), time.Second) - var err error - pserverIdx, err = e.registerPserverEtcd(ctx, port) - cancel() - if err != nil { - log.Warn("register pserver on etcd error", log.Ctx{"error": err}) - time.Sleep(retryTimeout) - continue - } - break - } - - return pserverIdx, nil -} - -func (e *EtcdClient) initDesiredPservers(ctx context.Context, numPservers int) (*clientv3.TxnResponse, error) { - return concurrency.NewSTM(e.client, func(c concurrency.STM) error { - dsStr := c.Get(PsDesired) - if dsStr == "" { - c.Put(PsDesired, strconv.Itoa(numPservers), clientv3.WithLease(e.sess.Lease())) - } - return nil - }, concurrency.WithAbortContext(ctx), concurrency.WithIsolation(concurrency.RepeatableReads)) -} - -// registerPserverEtcd registers pserver node on etcd using transaction. 
-func (e *EtcdClient) registerPserverEtcd(ctx context.Context, port int) (int, error) { - var idx int - _, err := concurrency.NewSTM(e.client, func(c concurrency.STM) error { - registered := false - for i := 0; i < e.desired; i++ { - psKey := PsPath + strconv.Itoa(i) - ps := c.Get(psKey) - log.Debug( - "register pserver got value", - log.Ctx{"value": ps, "key": psKey}, - ) - - if ps == "" { - // find the first id and write info - pserverAddr := e.externalIP + ":" + strconv.Itoa(port) - c.Put(psKey, pserverAddr, clientv3.WithLease(e.sess.Lease())) - log.Debug("register finished", log.Ctx{"key": psKey, "value": pserverAddr}) - idx = i - registered = true - break - } - } - if registered { - return nil - } - return errors.New("not registered, may due to already have enough pservers") - }, concurrency.WithAbortContext(ctx), concurrency.WithIsolation(concurrency.RepeatableReads)) - - if err != nil { - return 0, err - } - - return idx, nil -} - -// GetKey gets the value by the specified key -func (e *EtcdClient) GetKey(key string, timeout time.Duration) ([]byte, error) { - ctx, cancel := context.WithTimeout(context.Background(), timeout) - resp, err := e.client.Get(ctx, key) - cancel() - if err != nil { - return []byte{}, err - } - - kvs := resp.Kvs - if len(kvs) == 0 { - return []byte{}, nil - } - v := kvs[0].Value - return v, nil -} - -// PutKey put into etcd with value by key specified -func (e *EtcdClient) PutKey(key string, value []byte, timeout time.Duration, withLease bool) error { - ctx, cancel := context.WithTimeout(context.Background(), timeout) - var err error - if withLease { - _, err = e.client.Put(ctx, key, string(value), clientv3.WithLease(e.sess.Lease())) - } else { - _, err = e.client.Put(ctx, key, string(value)) - } - cancel() - return err -} - -// Shutdown shuts down the etcd client gracefully. 
-func (e *EtcdClient) Shutdown() error { - var err error - if e.sess != nil { - err = e.sess.Close() - } - - if e.client != nil { - newErr := e.client.Close() - if newErr != nil { - if err != nil { - log.Error("shutdown error", log.Ctx{"error": newErr}) - } else { - err = newErr - } - } - } - return err -} diff --git a/go/pserver/optimizer.go b/go/pserver/optimizer.go deleted file mode 100644 index eba0c47e195a80fc298f0fdd78c8d6345e963be8..0000000000000000000000000000000000000000 --- a/go/pserver/optimizer.go +++ /dev/null @@ -1,132 +0,0 @@ -// Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. - -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at - -// http://www.apache.org/licenses/LICENSE-2.0 - -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package pserver - -// #cgo CFLAGS: -I ../../ -// #cgo LDFLAGS: ${SRCDIR}/client/c/libpaddle_go_optimizer.a -lstdc++ -lm -// #include "paddle/legacy/optimizer/optimizer.h" -// #include -// #include -import "C" - -import ( - "fmt" - "unsafe" - - log "github.com/inconshreveable/log15" -) - -type optimizer struct { - opt *C.struct_paddle_optimizer - elementType ElementType - contentLen int - config []byte -} - -func cArrayToSlice(p unsafe.Pointer, len int) []byte { - if p == nil { - return nil - } - - // create a Go clice backed by a C array, reference: - // https://github.com/golang/go/wiki/cgo#turning-c-arrays-into-go-slices - // - // Go garbage collector will not interact with this data, need - // to be freed properly. 
- return (*[1 << 30]byte)(p)[:len:len] -} - -func newOptimizer(paramWithConfigs ParameterWithConfig, State []byte) *optimizer { - o := &optimizer{} - o.elementType = paramWithConfigs.Param.ElementType - o.contentLen = len(paramWithConfigs.Param.Content) - p := paramWithConfigs.Param - c := paramWithConfigs.Config - s := State - paramBufferSize := C.size_t(len(p.Content)) - log.Info("New Optimizer Created with config", log.Ctx{ - "ElementType": p.ElementType, - "ParamSize": paramBufferSize, - "ConfigSize": len(c), - "StateSize": len(s), - }) - var cbuffer unsafe.Pointer - cbuffer = C.malloc(paramBufferSize) - - C.memcpy(cbuffer, unsafe.Pointer(&p.Content[0]), paramBufferSize) - var cstate unsafe.Pointer - if len(s) != 0 { - cstate = unsafe.Pointer(&s[0]) - } - - var cptr (*C.uchar) - if len(c) > 0 { - cptr = (*C.uchar)(&c[0]) - } else { - log.Error("empty config", "param name", paramWithConfigs.Param.Name) - } - o.config = c - o.opt = C.paddle_create_optimizer( - cptr, - C.int(len(c)), - C.paddle_element_type(p.ElementType), - cbuffer, - C.int(paramBufferSize), - (*C.char)(cstate), - C.int(len(s)), - ) - return o -} - -func (o *optimizer) GetWeights() []byte { - var buffer unsafe.Pointer - // we do not own the buffer, no need to free later. - bufferLen := C.paddle_optimizer_get_weights(o.opt, &buffer) - return cArrayToSlice(buffer, int(bufferLen)*C.sizeof_float) -} - -func (o *optimizer) GetStates() []byte { - var cbuffer *C.char - // we owns the state buffer, need to free later. 
- cbufferLen := C.paddle_optimizer_get_state(o.opt, &cbuffer) - buf := cArrayToSlice(unsafe.Pointer(cbuffer), int(cbufferLen)) - cpy := make([]byte, len(buf)) - copy(cpy, buf) - C.free(unsafe.Pointer(cbuffer)) - return cpy -} - -func (o *optimizer) UpdateParameter(g Gradient) error { - if o.elementType != g.ElementType { - return fmt.Errorf("Name: %s, parameter and gradient element type not match, parameter: %v, gradient: %v", g.Name, o.elementType, g.ElementType) - } - - if o.contentLen != len(g.Content) { - return fmt.Errorf("Name: %s, parameter and gradient does not have same content len, parameter: %d, gradient: %d", g.Name, o.contentLen, len(g.Content)) - } - - r := C.paddle_update_parameter(o.opt, C.paddle_element_type(g.ElementType), unsafe.Pointer(&g.Content[0]), C.int(len(g.Content))) - if r != 0 { - return fmt.Errorf("optimizer update returned error code: %d", r) - } - return nil -} - -func (o *optimizer) Cleanup() { - if unsafe.Pointer(o.opt) != nil { - C.paddle_release_optimizer(o.opt) - o.opt = (*C.struct_paddle_optimizer)(nil) - } -} diff --git a/go/pserver/optimizer_test.go b/go/pserver/optimizer_test.go deleted file mode 100644 index 3b923879d5ec77675d707ccd40bf44a5148105fb..0000000000000000000000000000000000000000 --- a/go/pserver/optimizer_test.go +++ /dev/null @@ -1,78 +0,0 @@ -// Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. - -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at - -// http://www.apache.org/licenses/LICENSE-2.0 - -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -package pserver - -import ( - "encoding/binary" - "io/ioutil" - "math" - "testing" - - "github.com/stretchr/testify/assert" -) - -func TestOptimizerCreateRelease(t *testing.T) { - p := Parameter{ - Name: "a", - ElementType: Int32, - } - p.Content = []byte{1, 3} - config, err := ioutil.ReadFile("./client/c/test/testdata/optimizer.pb") - if err != nil { - t.Fatalf("read optimizer proto failed") - } - param := ParameterWithConfig{ - Param: p, - Config: config, - } - o := newOptimizer(param, nil) - o.Cleanup() -} - -func float32Bytes(float float32) []byte { - bits := math.Float32bits(float) - bytes := make([]byte, 4) - binary.LittleEndian.PutUint32(bytes, bits) - return bytes -} - -func TestOptimizerState(t *testing.T) { - p := Parameter{ - Name: "a", - ElementType: Int32, - } - weights := float32Bytes(100) - p.Content = weights - config, err := ioutil.ReadFile("./client/c/test/testdata/optimizer.pb") - if err != nil { - t.Fatalf("read optimizer proto failed") - } - param := ParameterWithConfig{ - Param: p, - Config: config, - } - o := newOptimizer(param, nil) - s := o.GetStates() - - // clear param content and check if the state is restored. - param.Param.Content = float32Bytes(300) - o1 := newOptimizer(param, s) - s1 := o1.GetStates() - assert.Equal(t, s, s1) - assert.Equal(t, weights, o.GetWeights()) - assert.Equal(t, weights, o1.GetWeights()) - o.Cleanup() - o1.Cleanup() -} diff --git a/go/pserver/service.go b/go/pserver/service.go deleted file mode 100644 index d6ead774af522ad78e9fe717f0d27bdf24d86246..0000000000000000000000000000000000000000 --- a/go/pserver/service.go +++ /dev/null @@ -1,450 +0,0 @@ -// Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. - -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. 
-// You may obtain a copy of the License at - -// http://www.apache.org/licenses/LICENSE-2.0 - -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package pserver - -import ( - "bufio" - "bytes" - "encoding/binary" - "encoding/gob" - "encoding/json" - "errors" - "fmt" - "hash/crc32" - "io/ioutil" - "os" - "path" - "strconv" - "strings" - "sync" - "time" - - "github.com/golang/protobuf/proto" - uuid "github.com/satori/go.uuid" - - pb "github.com/PaddlePaddle/Paddle/go/proto" - - log "github.com/inconshreveable/log15" -) - -// ElementType is the type of elements of a Parameter. -type ElementType int - -// ErrCheckpointNotFound indicates that the pserver checkpoint could -// not be found. -var ErrCheckpointNotFound = errors.New("checkpoint not found in etcd") - -// RPC error message. -const ( - AlreadyInitialized = "pserver already initialized" - Uninitialized = "pserver not fully initialized" - WrongChecksum = "checkpoint file checksum validation failed" -) - -// Supported element types. -const ( - Int32 ElementType = iota - UInt32 - Int64 - UInt64 - Float32 - Float64 -) - -// Parameter is a piece of data to sync with the parameter server. 
-type Parameter struct { - Name string - ElementType ElementType - Content []byte -} - -func float32ToString(b []byte) string { - f := make([]float32, len(b)/4) - buf := bytes.NewReader(b) - err := binary.Read(buf, binary.LittleEndian, &f) - if err != nil { - return "" - } - return fmt.Sprintf("%v", f) -} - -func float32ByteToString(c []byte) string { - var a []byte - var b []byte - if len(c) <= 80 { - a = c - } else { - a = c[0:40] - b = c[len(c)-40:] - } - - var s string - s = float32ToString(a) - - if b == nil { - return s - } - - s = strings.Replace(s, "]", "", -1) + "..." + strings.Replace(float32ToString(b), "[", "", -1) - return s -} - -func (p Parameter) String() string { - if p.ElementType != Float32 { - return fmt.Sprintf("name:%v ElementType:%v", - p.Name, p.ElementType) - } - - return float32ByteToString(p.Content) -} - -// ParameterWithConfig contains the parameter and the configuration. -type ParameterWithConfig struct { - Param Parameter - Config []byte // parameter configuration in Proto Buffer format -} - -// checkpointMeta saves checkpoint metadata -type checkpointMeta struct { - UUID string `json:"uuid"` - Path string `json:"path"` - CRC32 uint32 `json:"crc32"` - Timestamp int64 `json:"timestamp"` -} - -// Checkpoint is the pserver shard persist in file. -type Checkpoint []parameterCheckpoint - -// Gradient is the gradient of the parameter. -type Gradient Parameter - -// Service is the RPC service for pserver. -type Service struct { - initialized chan struct{} - idx int - checkpointInterval time.Duration - checkpointPath string - client KVStore - - mu sync.Mutex - optMap map[string]*optimizer -} - -// parameterCheckpoint saves parameter checkpoint. 
-type parameterCheckpoint struct { - ParameterWithConfig - State []byte -} - -type KVStore interface { - GetKey(key string, timeout time.Duration) ([]byte, error) - PutKey(key string, value []byte, timeout time.Duration, withLease bool) error -} - -func loadMeta(e KVStore, idx int) (meta checkpointMeta, err error) { - v, err := e.GetKey(PsCheckpoint+strconv.Itoa(idx), 3*time.Second) - if err != nil { - return - } - - if len(v) == 0 { - err = ErrCheckpointNotFound - return - } - - if err = json.Unmarshal(v, &meta); err != nil { - return - } - - return -} - -// LoadCheckpoint loads checkpoint from file. -func LoadCheckpoint(e KVStore, idx int) (Checkpoint, error) { - log.Info("Loading checkpoint", "pserver index", idx) - defer traceTime(time.Now(), "load checkpoint") - - cpMeta, err := loadMeta(e, idx) - if err != nil { - return nil, err - } - - content, err := ioutil.ReadFile(cpMeta.Path) - if err != nil { - return nil, err - } - - crc32 := crc32.ChecksumIEEE(content) - if crc32 != cpMeta.CRC32 { - return nil, errors.New(WrongChecksum) - } - - dec := gob.NewDecoder(bytes.NewReader(content)) - var cp Checkpoint - if err = dec.Decode(&cp); err != nil { - return nil, err - } - - return cp, nil -} - -// NewService creates a new service, will bypass etcd registration if no -// endpoints specified. It will recovery from checkpoint file if a exists a specified checkpoint. -func NewService(idx int, interval time.Duration, path string, client KVStore, cp Checkpoint) (*Service, error) { - s := &Service{ - idx: idx, - checkpointInterval: interval, - checkpointPath: path, - client: client, - } - s.optMap = make(map[string]*optimizer) - s.initialized = make(chan struct{}) - - if cp != nil { - for _, item := range cp { - p := ParameterWithConfig{ - Param: item.Param, - Config: item.Config, - } - s.optMap[p.Param.Name] = newOptimizer(p, item.State) - } - close(s.initialized) - } - return s, nil -} - -// InitParam initializes a parameter. 
-func (s *Service) InitParam(paramWithConfigs ParameterWithConfig, _ *int) error { - select { - case <-s.initialized: - log.Warn("init param called but parameters already initialized.") - return errors.New(AlreadyInitialized) - default: - } - - c := &pb.OptimizerConfig{} - proto.Unmarshal(paramWithConfigs.Config, c) - log.Debug(fmt.Sprintf("OptimizerConfig:%v", c)) - - s.mu.Lock() - defer s.mu.Unlock() - - // TODO(helin): check if paramWithConfigs.Param.Content is - // properly memory aligned, if not, make copy to a memory - // aligned region. - s.optMap[paramWithConfigs.Param.Name] = newOptimizer(paramWithConfigs, nil) - log.Info( - "init parameter", - "name", paramWithConfigs.Param.Name, - "config len", len(paramWithConfigs.Config), - "param len", len(paramWithConfigs.Param.Content), - "type", paramWithConfigs.Param.ElementType, - ) - return nil -} - -// FinishInitParams tells the parameter server that the parameter -// initialization has finished. -func (s *Service) FinishInitParams(_ int, _ *int) error { - select { - case <-s.initialized: - log.Warn("finished init param called but parameters already initialized.") - return errors.New(AlreadyInitialized) - default: - } - - close(s.initialized) - go func() { - t := time.Tick(s.checkpointInterval) - for range t { - err := s.checkpoint() - if err != nil { - log.Error("checkpoint error", log.Ctx{"error": err}) - } - } - }() - - log.Info("init parameter finished.") - return nil -} - -// SendGrad sends gradient to parameter servers for parameter -// optimization. 
-func (s *Service) SendGrad(g Gradient, _ *int) error { - select { - case <-s.initialized: - default: - log.Warn("received gradient before initialization.", - "name", g.Name, "size", len(g.Content), "type", g.ElementType) - return errors.New(Uninitialized) - } - - s.mu.Lock() - defer s.mu.Unlock() - - o, ok := s.optMap[g.Name] - if !ok { - log.Warn("received gradient but can't find name.", - "name", g.Name, "size", len(g.Content), "type", g.ElementType) - return fmt.Errorf("parameter: %s does not exist", g.Name) - } - - log.Debug(Parameter(g).String()) - log.Info("received gradient from trainer, updating gradient.", - "name", g.Name, "size", len(g.Content), "type", g.ElementType) - return o.UpdateParameter(g) -} - -// GetParam gets parameters from the parameter server. -func (s *Service) GetParam(name string, parameter *Parameter) error { - <-s.initialized - s.mu.Lock() - defer s.mu.Unlock() - - opt, ok := s.optMap[name] - if !ok { - log.Warn("trainer wants to get a parameter that does not exist.", "name", name) - return fmt.Errorf("parameter: %s does not exist", name) - } - - // The parameter content (a byte slice) may change - // during RPC serialization due to write from other - // goroutine, we allow it since mini-batch based deep - // learning optimization methods are stochastic in - // nature. This race condition is allowed deliberately - // to save the program from making a copy of the - // parameter content. - parameter.Name = name - parameter.ElementType = opt.elementType - parameter.Content = opt.GetWeights() - log.Debug(parameter.String()) - log.Info("sending parameter to the trainer", "name", parameter.Name, "size", len(parameter.Content), "type", parameter.ElementType) - return nil -} - -func traceTime(start time.Time, name string) { - elapsed := time.Since(start) - log.Info("time elapsed", log.Ctx{"name": name, "elapsed": elapsed}) -} - -// checkpoint saves checkpoint to disk. 
-// -// checkpoint should be only called after the parameters are -// initialized. -func (s *Service) checkpoint() (err error) { - log.Info("Begin save checkpoint.") - defer traceTime(time.Now(), "save checkpoint") - - s.mu.Lock() - cp := make([]parameterCheckpoint, len(s.optMap)) - index := 0 - // TODO(helin): write checkpoint incrementally to reduce memory - // footprint during checkpoint. - for name, opt := range s.optMap { - var pc parameterCheckpoint - pc.Param.Name = name - pc.Param.ElementType = opt.elementType - pc.Param.Content = opt.GetWeights() - pc.Config = opt.config - pc.State = opt.GetStates() - cp[index] = pc - index++ - } - s.mu.Unlock() - - var buf bytes.Buffer - encoder := gob.NewEncoder(&buf) - err = encoder.Encode(cp) - if err != nil { - return - } - - if _, err = os.Stat(s.checkpointPath); os.IsNotExist(err) { - err = os.MkdirAll(s.checkpointPath, os.ModePerm) - if err != nil { - return - } - } - - id := uuid.NewV4().String() - p := path.Join(s.checkpointPath, id) - f, err := os.Create(p) - if err != nil { - return - } - - defer func() { - closeErr := f.Close() - if closeErr != nil { - if err != nil { - log.Error("error close checkpoint file", log.Ctx{"error": closeErr}) - } else { - // Set closeErr as return value. - err = closeErr - } - } - }() - - writer := bufio.NewWriter(f) - _, err = writer.Write(buf.Bytes()) - if err != nil { - return - } - - err = writer.Flush() - if err != nil { - return - } - - oldMeta, err := loadMeta(s.client, s.idx) - if err == ErrCheckpointNotFound { - log.Info("old meta not found, skip removing old meta") - err = nil - } else if err == nil { - log.Info("removing old meta") - if oldMeta.Path != "" { - rmErr := os.Remove(oldMeta.Path) - if rmErr != nil { - // log error, but still treat checkpoint as - // successful. 
- log.Error("remove old meta file error", log.Ctx{"error": rmErr}) - } - } - } - - if err != nil { - return - } - - crc32 := crc32.ChecksumIEEE(buf.Bytes()) - cpMeta := checkpointMeta{ - UUID: id, - Timestamp: time.Now().UnixNano(), - CRC32: crc32, - Path: p, - } - - json, err := json.Marshal(cpMeta) - if err != nil { - return - } - - err = s.client.PutKey(PsCheckpoint+strconv.Itoa(s.idx), json, 3*time.Second, false) - if err != nil { - return - } - - return -} diff --git a/go/pserver/service_internal_test.go b/go/pserver/service_internal_test.go deleted file mode 100644 index 36eca5112b3117cf295288de0de957c4af040f03..0000000000000000000000000000000000000000 --- a/go/pserver/service_internal_test.go +++ /dev/null @@ -1,86 +0,0 @@ -package pserver - -import ( - "bytes" - "encoding/binary" - "fmt" - "testing" - "time" - - "github.com/stretchr/testify/assert" -) - -const testDir = "./test_data" - -type myKV struct { - m map[string][]byte -} - -func (m *myKV) GetKey(key string, timeout time.Duration) ([]byte, error) { - if m.m == nil { - m.m = make(map[string][]byte) - } - return m.m[key], nil -} - -func (m *myKV) PutKey(key string, value []byte, timeout time.Duration, withLease bool) error { - if m.m == nil { - m.m = make(map[string][]byte) - } - m.m[key] = value - return nil -} - -func TestCheckpoint(t *testing.T) { - kv := &myKV{} - s, err := NewService(0, time.Hour, testDir, kv, nil) - assert.Nil(t, err) - err = s.checkpoint() - assert.Nil(t, err) - _, err = LoadCheckpoint(kv, 0) - assert.Nil(t, err) -} - -func float32ToByte(f float32) []byte { - var buf bytes.Buffer - err := binary.Write(&buf, binary.LittleEndian, f) - if err != nil { - fmt.Println("binary.Write failed:", err) - } - return buf.Bytes() -} - -func TestCheckpointWithData(t *testing.T) { - kv := &myKV{} - s, err := NewService(0, time.Hour, testDir, kv, nil) - assert.Nil(t, err) - - var content []byte - for i := 0; i < 50000; i++ { - content = append(content, float32ToByte(float32(i))...) 
- } - - p1 := Parameter{Name: "p1", ElementType: 1, Content: content} - err = s.InitParam(ParameterWithConfig{Param: p1}, nil) - assert.Nil(t, err) - - err = s.FinishInitParams(0, nil) - assert.Nil(t, err) - - var p2 Parameter - err = s.GetParam(p1.Name, &p2) - assert.Nil(t, err) - assert.Equal(t, p1, p2) - - err = s.checkpoint() - assert.Nil(t, err) - cp, err := LoadCheckpoint(kv, 0) - assert.Nil(t, err) - s1, err := NewService(0, time.Hour, testDir, kv, cp) - assert.Nil(t, err) - - var p3 Parameter - err = s1.GetParam(p1.Name, &p3) - assert.Nil(t, err) - assert.Equal(t, p1, p3) -} diff --git a/go/pserver/service_test.go b/go/pserver/service_test.go deleted file mode 100644 index 6949348e933e74d53a99f3b6c8fb928b9b5140f5..0000000000000000000000000000000000000000 --- a/go/pserver/service_test.go +++ /dev/null @@ -1,211 +0,0 @@ -// Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. - -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at - -// http://www.apache.org/licenses/LICENSE-2.0 - -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -package pserver_test - -import ( - "fmt" - "io/ioutil" - "reflect" - "sync" - "testing" - "time" - - "github.com/PaddlePaddle/Paddle/go/pserver" -) - -const ( - OptimizerConfig = "./client/c/test/testdata/optimizer.pb" -) - -func TestServiceFull(t *testing.T) { - var cp pserver.Checkpoint - s, err := pserver.NewService(0, time.Hour, "", nil, cp) - if err != nil { - t.Error(err) - } - var p pserver.Parameter - p.Name = "param_a" - p.Content = []byte{1, 0, 0, 0, 2, 0, 0, 0, 3, 0, 0, 0} - p.ElementType = pserver.Int32 - config, err := ioutil.ReadFile(OptimizerConfig) - if err != nil { - t.Fatalf("read optimizer proto failed") - } - - err = s.InitParam(pserver.ParameterWithConfig{Param: p, Config: config}, nil) - if err != nil { - t.Fatal(err) - } - - var p1 pserver.Parameter - p1.Name = "param_b" - p1.Content = []byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0} - p1.ElementType = pserver.Float32 - err = s.InitParam(pserver.ParameterWithConfig{Param: p1, Config: config}, nil) - if err != nil { - t.Fatal(err) - } - - err = s.FinishInitParams(0, nil) - if err != nil { - t.Fatal(err) - } - - var param pserver.Parameter - err = s.GetParam("param_b", ¶m) - if err != nil { - t.Fatal(err) - } - - if !reflect.DeepEqual(param, p1) { - t.Fatal("not equal:", param, p1) - } - - g1, g2 := pserver.Gradient(p1), pserver.Gradient(p) - - err = s.SendGrad(g1, nil) - if err != nil { - t.Fatal(err) - } - err = s.SendGrad(g2, nil) - - if err != nil { - t.Fatal(err) - } - - var param1 pserver.Parameter - err = s.GetParam("param_a", ¶m1) - if err != nil { - t.Fatal(err) - } - - // don't compare content, since it's already changed by - // gradient update. 
- param1.Content = nil - p.Content = nil - - if !reflect.DeepEqual(param1, p) { - t.Fatal("not equal:", param1, p) - } -} - -func TestMultipleInit(t *testing.T) { - var cp pserver.Checkpoint - s, err := pserver.NewService(0, time.Hour, "", nil, cp) - if err != nil { - t.Fatal(err) - } - err = s.FinishInitParams(0, nil) - if err != nil { - t.Fatal(err) - } - - err = s.FinishInitParams(0, nil) - if err.Error() != pserver.AlreadyInitialized { - t.Fatal(err) - } -} - -func TestUninitialized(t *testing.T) { - var cp pserver.Checkpoint - s, err := pserver.NewService(0, time.Hour, "", nil, cp) - err = s.SendGrad(pserver.Gradient{}, nil) - if err.Error() != pserver.Uninitialized { - t.Fatal(err) - } -} - -func TestBlockUntilInitialized(t *testing.T) { - var cp pserver.Checkpoint - s, err := pserver.NewService(0, time.Hour, "", nil, cp) - if err != nil { - t.Error(err) - } - ch := make(chan struct{}, 2) - errCh := make(chan error, 2) - var wg sync.WaitGroup - wg.Add(1) - go func() { - var param pserver.Parameter - err := s.GetParam("param_a", ¶m) - if err != nil { - errCh <- err - } - wg.Done() - ch <- struct{}{} - }() - - time.Sleep(50 * time.Millisecond) - - select { - case <-ch: - // some function returned before initialization is completed. 
- t.FailNow() - case <-errCh: - t.FailNow() - default: - } - - var p pserver.Parameter - p.Name = "param_a" - p.Content = []byte{1, 0, 0, 0, 2, 0, 0, 0, 3, 0, 0, 0} - p.ElementType = pserver.Int32 - config, err := ioutil.ReadFile(OptimizerConfig) - if err != nil { - t.Fatalf("read optimizer proto failed") - } - err = s.InitParam(pserver.ParameterWithConfig{Param: p, Config: config}, nil) - - if err != nil { - t.Fatal(err) - } - - err = s.FinishInitParams(0, nil) - if err != nil { - t.Fatal(err) - } - - wg.Wait() -} - -func TestGradientString(t *testing.T) { - g := pserver.Parameter{} - g.ElementType = pserver.Float32 - g.Content = []byte{0x18, 0x2d, 0x44, 0x54, 0xfb, 0x21, 0x09, 0x40, 0x18, 0x2d, 0x44, 0x54, 0xfb, 0x21, 0x09, 0x40} - if g.String() != "[3.3702806e+12 2.142699 3.3702806e+12 2.142699]" { - t.Fatal("get float data error!") - } - - g.Content = []byte{0x18, 0x2d, 0x44, 0x54, 0xfb, 0x21, 0x09, 0x40, - 0x18, 0x2d, 0x44, 0x54, 0xfb, 0x21, 0x09, 0x40, - 0x18, 0x2d, 0x44, 0x54, 0xfb, 0x21, 0x09, 0x40, - 0x18, 0x2d, 0x44, 0x54, 0xfb, 0x21, 0x09, 0x40, - 0x18, 0x2d, 0x44, 0x54, 0xfb, 0x21, 0x09, 0x40, - 0x18, 0x2d, 0x44, 0x54, 0xfb, 0x21, 0x09, 0x40, - 0x18, 0x2d, 0x44, 0x54, 0xfb, 0x21, 0x09, 0x40, - 0x18, 0x2d, 0x44, 0x54, 0xfb, 0x21, 0x09, 0x40, - 0x18, 0x2d, 0x44, 0x54, 0xfb, 0x21, 0x09, 0x40, - 0x18, 0x2d, 0x44, 0x54, 0xfb, 0x21, 0x09, 0x40, - 0x18, 0x2d, 0x44, 0x54, 0xfb, 0x21, 0x09, 0x40, - 0x18, 0x2d, 0x44, 0x54, 0xfb, 0x21, 0x09, 0x40, - 0x18, 0x2d, 0x44, 0x54, 0xfb, 0x21, 0x09, 0x40, - 0x18, 0x2d, 0x44, 0x54, 0xfb, 0x21, 0x09, 0x40, - 0x18, 0x2d, 0x44, 0x54, 0xfb, 0x21, 0x09, 0x40, - 0x18, 0x2d, 0x44, 0x54, 0xfb, 0x21, 0x09, 0x40} - if g.String() != "[3.3702806e+12 2.142699 3.3702806e+12 2.142699 3.3702806e+12 2.142699 3.3702806e+12 2.142699 3.3702806e+12 2.142699...3.3702806e+12 2.142699 3.3702806e+12 2.142699 3.3702806e+12 2.142699 3.3702806e+12 2.142699 3.3702806e+12 2.142699]" { - t.Fatal("get float data error!", g.String()) - } - fmt.Println(g) 
-} diff --git a/go/utils/networkhelper/CMakeLists.txt b/go/utils/networkhelper/CMakeLists.txt deleted file mode 100644 index 3100f2b5a527720b5e8edfb4219b42a8a874f67a..0000000000000000000000000000000000000000 --- a/go/utils/networkhelper/CMakeLists.txt +++ /dev/null @@ -1,17 +0,0 @@ -# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# -if(WITH_TESTING) - go_test(network_helper_test) -endif() diff --git a/go/utils/networkhelper/helper.go b/go/utils/networkhelper/helper.go deleted file mode 100644 index d205b6c50202148c6634bb378a03adcca7b074a0..0000000000000000000000000000000000000000 --- a/go/utils/networkhelper/helper.go +++ /dev/null @@ -1,59 +0,0 @@ -// Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. - -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at - -// http://www.apache.org/licenses/LICENSE-2.0 - -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -package networkhelper - -import ( - "errors" - "net" -) - -// GetExternalIP returns the ip address of local network interface, not the -// loopback device. -func GetExternalIP() (string, error) { - ifaces, err := net.Interfaces() - if err != nil { - return "", err - } - for _, iface := range ifaces { - if iface.Flags&net.FlagUp == 0 { - continue // interface down - } - if iface.Flags&net.FlagLoopback != 0 { - continue // loopback interface - } - addrs, err := iface.Addrs() - if err != nil { - return "", err - } - for _, addr := range addrs { - var ip net.IP - switch v := addr.(type) { - case *net.IPNet: - ip = v.IP - case *net.IPAddr: - ip = v.IP - } - if ip == nil || ip.IsLoopback() { - continue - } - ip = ip.To4() - if ip == nil { - continue // not an ipv4 address - } - return ip.String(), nil - } - } - return "", errors.New("are you connected to the network?") -} diff --git a/go/utils/networkhelper/helper_test.go b/go/utils/networkhelper/helper_test.go deleted file mode 100644 index 60b520fae15484e024cccddf169c2c8072c2e990..0000000000000000000000000000000000000000 --- a/go/utils/networkhelper/helper_test.go +++ /dev/null @@ -1,24 +0,0 @@ -// Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. - -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at - -// http://www.apache.org/licenses/LICENSE-2.0 - -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -package networkhelper - -import "testing" - -func TestGetIP(t *testing.T) { - _, err := GetExternalIP() - if err != nil { - t.Errorf("GetExternalIP returns error : %v\n", err) - } -} diff --git a/proto/.gitignore b/proto/.gitignore deleted file mode 100644 index a0f00082c8e5d428fcf98979e38e626b810213b7..0000000000000000000000000000000000000000 --- a/proto/.gitignore +++ /dev/null @@ -1 +0,0 @@ -CMakeLists.txt diff --git a/proto/CMakeLists.txt b/proto/CMakeLists.txt deleted file mode 100644 index a075eeb83bda64133920f9ab0275eb6c0e0fb8c4..0000000000000000000000000000000000000000 --- a/proto/CMakeLists.txt +++ /dev/null @@ -1,57 +0,0 @@ -if (MOBILE_INFERENCE) - file(GLOB proto_filenames . ModelConfig.proto ParameterConfig.proto - TrainerConfig.proto DataConfig.proto) -else() - file(GLOB proto_filenames . *.proto) -endif() - -include_directories(${CMAKE_CURRENT_BINARY_DIR}) -proto_library(paddle_proto SRCS ${proto_filenames}) - -set(PROTO_GEN) -set(PROTO_GEN_PY) - -foreach(filename ${proto_filenames}) - get_filename_component(ABS_FIL ${filename} ABSOLUTE) - get_filename_component(FIL_WE ${filename} NAME_WE) - set(CUR_PROTO_GEN_PY - ${PADDLE_BINARY_DIR}/paddle/python/paddle/proto/${FIL_WE}_pb2.py) - set(PROTO_GEN_PY - ${CUR_PROTO_GEN_PY} - ${PROTO_GEN_PY}) - add_custom_command(OUTPUT ${CUR_PROTO_GEN_PY} - COMMAND ${CMAKE_COMMAND} -E make_directory ${PADDLE_BINARY_DIR}/python/paddle/proto - COMMAND ${PROTOBUF_PROTOC_EXECUTABLE} - ARGS "--python_out=${PADDLE_BINARY_DIR}/python/paddle/proto" - "-I" ${CMAKE_CURRENT_SOURCE_DIR} ${ABS_FIL} - DEPENDS ${ABS_FIL} protoc) -endforeach() - -add_custom_target(gen_proto_py ALL DEPENDS ${PROTO_GEN_PY}) - - -if (WITH_GOLANG) - add_custom_target(protoc-gen-go) - add_custom_command(TARGET protoc-gen-go - COMMAND go - ARGS "get" "-u" "github.com/golang/protobuf/protoc-gen-go") - - set(PROTO_GEN_GO) - file(GLOB proto_filenames . 
OptimizerConfig.proto) - foreach(filename ${proto_filenames}) - message(STATUS ${filename}) - get_filename_component(ABS_FIL ${filename} ABSOLUTE) - get_filename_component(FIL_WE ${filename} NAME_WE) - set(CUR_PROTO_GEN_GO - ${PADDLE_SOURCE_DIR}/paddle/go/proto/${FIL_WE}.pb.go) - set(PROTO_GEN_GO - ${CUR_PROTO_GEN_GO} - ${PROTO_GEN_GO}) - add_custom_command(OUTPUT ${CUR_PROTO_GEN_GO} - COMMAND ${PROTOBUF_PROTOC_EXECUTABLE} - ARGS "--go_out=${PADDLE_SOURCE_DIR}/go/proto" - "-I" ${CMAKE_CURRENT_SOURCE_DIR} ${ABS_FIL} - DEPENDS ${ABS_FIL} protoc protoc-gen-go) - endforeach() - add_custom_target(gen_proto_go ALL DEPENDS ${PROTO_GEN_GO}) -endif() diff --git a/proto/DataConfig.proto b/proto/DataConfig.proto deleted file mode 100644 index 1b2aa8e726d2c567afba8cb7375e44a56cedf228..0000000000000000000000000000000000000000 --- a/proto/DataConfig.proto +++ /dev/null @@ -1,86 +0,0 @@ -/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
*/ -syntax = "proto2"; - -package paddle; - -message FileGroupConf { - optional uint32 queue_capacity = 1 [ default = 1 ]; - // how many files to load for a load file thread - optional int32 load_file_count = 2 [ default = 1 ]; - // how many threads to load files - // Setting to be 5~10 is appropriate when loading files by hadoop vfs - optional int32 load_thread_num = 3 [ default = 1 ]; -}; - -message DataConfig { - - required string type = 1; - - // name of a text file which contains a list of file names at each line - optional string files = 3; - - optional int32 feat_dim = 4; // feature dimension of one frame - repeated int32 slot_dims = 5; // feature slot dims - optional int32 context_len = 6; // max neibour frame numbers - optional uint64 buffer_capacity = 7; // the number of samples - - // part of data used in training - // if not -1, part of train data is used in training - optional int64 train_sample_num = 8 [ default = -1 ]; - - // The number of documents processed once - optional int32 file_load_num = 9 [ default = -1 ]; - optional bool async_load_data = 12 [ default = false ]; - /// Note the field number 10, 11 and 13 have been deprecated. - optional bool for_test = 14 - [ default = false ]; // whether this data is for test - optional FileGroupConf file_group_conf = 15; - repeated int32 float_slot_dims = 16; - - /// Note the field number 17, 18 and 19 have been deprecated. - - // a list of values which will be used to create additional one dimensional - // float - // values slots. These one dimensional slots can be used as the weight input - // for cost layers. - // Currently this is only supported by ProtoDataProvider. - repeated double constant_slots = 20; - - // for PyDataProvider. 
- // Specify the load data script module name, object name and user args - optional string load_data_module = 21; - optional string load_data_object = 22; - optional string load_data_args = 23; - - // for MultiDataProvider - repeated DataConfig sub_data_configs = 24; // sub dataproviders - /* - * the ratio of each sub dataproviders: - * e.g. sub dataprovider A's ratio is 1, B's ratio is 9, batch_size is 100, - * then each mini-batch is combined by 10 instance from A and 90 instances - * from B. - */ - optional int32 data_ratio = 25; - /* - * if one of the sub dataproviders is running out of data, then - * (1) it is "main data", then finish current pass. - * (2) it is not "main data", then reset it, and try getNextBatch again. - */ - optional bool is_main_data = 26 [ default = true ]; - - // the usage ratio of instances. Setting to 1.0 means the use of all - // instances. - optional double usage_ratio = 27 [ default = 1.0 ]; -}; diff --git a/proto/DataFormat.proto b/proto/DataFormat.proto deleted file mode 100644 index 46b1f58bdb805c06964476483966efc7817e2747..0000000000000000000000000000000000000000 --- a/proto/DataFormat.proto +++ /dev/null @@ -1,76 +0,0 @@ -/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ -syntax = "proto2"; - -package paddle; - -/* - If values is not empty and ids is empty, this is a dense vector. - If values is not empty and ids is not empty, this is a sparse vector. 
The - position of each value - is specified by ids. - If values is empty and ids is not empty, this is a sparse vector whose non-zero - values are 1. - The position of each 1 is specified by ids. -*/ -message VectorSlot { - repeated float values = 1 [ packed = true ]; - repeated uint32 ids = 2 [ packed = true ]; - /* For multidimensional data, for example "image width height depth" */ - repeated uint32 dims = 3 [ packed = true ]; - repeated string strs = 4; -}; - -/* - SubseqSlot use to record whether VectorSlot or any other slot in future has - subseq. - If not all VectorSlot have subseq, we only store the one who has subseq, and - use *slot_id* to record it. - One vector_slots has one sequence, and it may have N subseq, thus the number of - *lens* will be N too. -*/ -message SubseqSlot { - required uint32 slot_id = 1; // the id of slot who has subseq - repeated uint32 lens = 2; // lengths of sub-sequence in the slot -}; - -message SlotDef { - enum SlotType { - VECTOR_DENSE = 0; - VECTOR_SPARSE_NON_VALUE = 1; - VECTOR_SPARSE_VALUE = 2; - INDEX = 3; // This can be used as label, or word id, etc. - VAR_MDIM_DENSE = 4; - VAR_MDIM_INDEX = 5; - STRING = 6; - } - required SlotType type = 1; - required uint32 dim = - 2; // For INDEX slots, this means the maximal index plus 1. -}; - -message DataHeader { - // INDEX slot should be always after VECTOR slots. 
- repeated SlotDef slot_defs = 1; -}; - -message DataSample { - optional bool is_beginning = 1 - [ default = true ]; // is the beginning of a sequence - repeated VectorSlot vector_slots = 2; - repeated uint32 id_slots = 3 [ packed = true ]; - /* use ids of VectorSlot */ - repeated VectorSlot var_id_slots = 4; - repeated SubseqSlot subseq_slots = 5; -}; diff --git a/proto/ModelConfig.proto b/proto/ModelConfig.proto deleted file mode 100644 index d78ee9c9f39ed09825dffdfa0a442c0ffac5958f..0000000000000000000000000000000000000000 --- a/proto/ModelConfig.proto +++ /dev/null @@ -1,698 +0,0 @@ -/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ -syntax = "proto2"; - -import "ParameterConfig.proto"; - -package paddle; - -/** - * Various structs for the configuration of a neural network - */ - -message ExternalConfig { - repeated string layer_names = 1; - repeated string input_layer_names = 2; - repeated string output_layer_names = 3; -} - -message ActivationConfig { - // identity: f(x) = x - // sigmoid: f(x) = 1 / (1 + exp(-x)) - // logistic: f(x) = (1 - exp(-x)) / (1+ exp(-x)) - // softmax: y_i = f(x_i) = exp(x_i) / (\sum_i exp(x_i)) - // relu: y = max(0, x) - required string type = 1; -}; - -message ConvConfig { - // filter_size = 5, says that this layer will use - // filters of size 5x5 pixels. - required uint32 filter_size = 1; - - // The image data dimensionality. - // This value must be either 1, 2, 3, or a multiple of 4. 
- required uint32 channels = 2; - - // stride = 1, indicates that the distance between - // successive filter applications should be 1 pixel. - required uint32 stride = 3; - - // padding = 4, instructs the net to implicitly - // pad the images with a 4-pixel border of zeros. - required uint32 padding = 4; - - // If groups = 4 together with the filters = 32 parameter, - // they state that this convolutional layer is to have 4 - // groups of 32 filters. Each filter will connect to 8 - // input channels. - required uint32 groups = 5; - required uint32 filter_channels = 6; - - // The size of output feature map. - required uint32 output_x = 7; - - // The size of input feature map. - required uint32 img_size = 8; - - // caffe mode for output size coherence - required bool caffe_mode = 9 [ default = true ]; - - // if filter_size_y is set , this convolutional layer will use - // filters of size filter_size * filter_size_y pixels. - // if filter_size_y is not set, this convolutional layer will use - // filters of size filter_size * filter_size - required uint32 filter_size_y = 10; - required uint32 padding_y = 11; - required uint32 stride_y = 12; - - // if not set, use output_x - optional uint32 output_y = 13; - - // if not set, use img_size - optional uint32 img_size_y = 14; - - optional uint32 dilation = 15 [ default = 1 ]; - optional uint32 dilation_y = 16 [ default = 1 ]; - - optional uint32 filter_size_z = 17 [ default = 1 ]; - optional uint32 padding_z = 18 [ default = 1 ]; - optional uint32 stride_z = 19 [ default = 1 ]; - optional uint32 output_z = 20 [ default = 1 ]; - optional uint32 img_size_z = 21 [ default = 1 ]; -} - -message PoolConfig { - // max or avg pooling - required string pool_type = 1; - required uint32 channels = 2; - - // Defines the size of the pooling region in - // the x (equivalently, y) dimension. - required uint32 size_x = 3; - - // Tell the net where in the input image to start the pooling. - // start is deprecated now. 
- optional uint32 start = 4; - - // Defines the stride size between successive pooling squares. - required uint32 stride = 5 [ default = 1 ]; - - // The size of output feature map. - required uint32 output_x = 6; - - // The size of input feature map. - required uint32 img_size = 7; - - // padding = 4, instructs the net to implicitly - // pad the images with a 4-pixel border of zeros. - optional uint32 padding = 8 [ default = 0 ]; - - // if not set, use size_x - optional uint32 size_y = 9; - - // if not set, use stride - optional uint32 stride_y = 10; - - // if not set, use output_x - optional uint32 output_y = 11; - - // if not set, use img_size - optional uint32 img_size_y = 12; - - // if not set, use padding - optional uint32 padding_y = 13; - - optional uint32 size_z = 14 [ default = 1 ]; - optional uint32 stride_z = 15 [ default = 1 ]; - optional uint32 output_z = 16 [ default = 1 ]; - optional uint32 img_size_z = 17 [ default = 1 ]; - optional uint32 padding_z = 18 [ default = 1 ]; - - optional bool exclude_mode = 19; -} - -message SppConfig { - required ImageConfig image_conf = 1; - required string pool_type = 2; - required uint32 pyramid_height = 3; -} - -message NormConfig { - // rnorm or cmrnorm - required string norm_type = 1; - required uint32 channels = 2; - - // rnorm: this defines the size of the local regions - // used for response normalization. - // cmrnorm: The size parameter indicates how many - // nearby maps to use for normalization. - required uint32 size = 3; - - // the parameters for normalization - // u = u / (1+scale*sum(u^2 in window))^pow - required double scale = 4; - required double pow = 5; - - // The size of output feature map. - required uint32 output_x = 6; - - // The size of input feature map. 
- required uint32 img_size = 7; - - // normalize with fixed window or sliding window - // u = u / (1+scale*sum(u^2 in window))^pow - // fixed window: shared a fixed window for each value - // sliding window: have a different window for each value - optional bool blocked = 8; - - // if not set, use output_x - optional uint32 output_y = 9; - - // if not set, use img_size - optional uint32 img_size_y = 10; -} - -message BlockExpandConfig { - required uint32 channels = 1; - - required uint32 stride_x = 2; - required uint32 stride_y = 3; - - required uint32 padding_x = 4; - required uint32 padding_y = 5; - - required uint32 block_x = 6; - required uint32 block_y = 7; - - // The size of output feature map. - required uint32 output_x = 8; - required uint32 output_y = 9; - - // The size of input feature map. - required uint32 img_size_x = 10; - required uint32 img_size_y = 11; -} - -message MaxOutConfig { - required ImageConfig image_conf = 1; - required uint32 groups = 2; -} - -message RowConvConfig { required uint32 context_length = 1; } - -message SliceConfig { - required uint32 start = 1; - required uint32 end = 2; -} - -message ProjectionConfig { - required string type = 1; - required string name = 2; - required uint64 input_size = 3; - required uint64 output_size = 4; - - // For ShiftProjection - optional int32 context_start = 5; - optional int32 context_length = 6; - optional bool trainable_padding = 7 [ default = false ]; - - // For convolution - optional ConvConfig conv_conf = 8; - optional int32 num_filters = 9; - - // For IdentityOffsetProjection - optional uint64 offset = 11 [ default = 0 ]; - - // For pool - optional PoolConfig pool_conf = 12; - - // For slice - // Each slice output is the input[start, end) - repeated SliceConfig slices = 13; -} - -message OperatorConfig { - required string type = 1; - repeated int32 input_indices = 2; - repeated uint64 input_sizes = 3; - required uint64 output_size = 4; - - // For DotMulOperator - optional double dotmul_scale 
= 5 [ default = 1.0 ]; - - // For ConvOperator - optional ConvConfig conv_conf = 6; - optional int32 num_filters = 7; -} - -message BilinearInterpConfig { - // The size of input feature map. - required ImageConfig image_conf = 1; - // The size of output feature map. - required uint32 out_size_x = 2; - required uint32 out_size_y = 3; -} - -message ImageConfig { - // The image data dimensionality. - // This value must be either 1, 2, 3, or a multiple of 4. - required uint32 channels = 2; - - // The size of input feature map. - required uint32 img_size = 8; - optional uint32 img_size_y = 9; - optional uint32 img_size_z = 10 [ default = 1 ]; -} - -message PriorBoxConfig { - repeated uint32 min_size = 1; - repeated uint32 max_size = 2; - repeated float aspect_ratio = 3; - repeated float variance = 4; -} - -message PadConfig { - required ImageConfig image_conf = 1; - repeated uint32 pad_c = 2; - repeated uint32 pad_h = 3; - repeated uint32 pad_w = 4; -} - -message ReshapeConfig { - repeated uint32 height_axis = 1; - repeated uint32 width_axis = 2; -} - -message MultiBoxLossConfig { - required uint32 num_classes = 1; - required float overlap_threshold = 2; - required float neg_pos_ratio = 3; - required float neg_overlap = 4; - required uint32 background_id = 5; - required uint32 input_num = 6; - optional uint32 height = 7 [ default = 1 ]; - optional uint32 width = 8 [ default = 1 ]; -} - -message DetectionOutputConfig { - required uint32 num_classes = 1; - required float nms_threshold = 2; - required uint32 nms_top_k = 3; - required uint32 background_id = 4; - required uint32 input_num = 5; - required uint32 keep_top_k = 6; - required float confidence_threshold = 7; - optional uint32 height = 8 [ default = 1 ]; - optional uint32 width = 9 [ default = 1 ]; -} - -message ClipConfig { - required double min = 1; - required double max = 2; -} - -message UpsampleConfig { - required ImageConfig image_conf = 1; - optional uint32 scale = 2 [ default = 2 ]; - optional uint32 
scale_y = 3 [ default = 2 ]; - optional bool pad_out_x = 4 [ default = false ]; - optional bool pad_out_y = 5 [ default = false ]; - optional uint32 upsample_size = 6; - optional uint32 upsample_size_y = 7; -} - -message ROIPoolConfig { - required uint32 pooled_width = 1; - required uint32 pooled_height = 2; - required float spatial_scale = 3; - optional uint32 height = 4 [ default = 1 ]; - optional uint32 width = 5 [ default = 1 ]; -} - -message ScaleSubRegionConfig { - required ImageConfig image_conf = 1; - required float value = 2; -} - -message LayerInputConfig { - required string input_layer_name = 1; - optional string input_parameter_name = 2; - optional ConvConfig conv_conf = 3; - optional PoolConfig pool_conf = 4; - optional NormConfig norm_conf = 5; - optional ProjectionConfig proj_conf = 6; - optional BlockExpandConfig block_expand_conf = 7; - optional ImageConfig image_conf = 8; - // If the input layer has multi-output. - // Set the argument name. - optional string input_layer_argument = 9; - optional BilinearInterpConfig bilinear_interp_conf = 10; - optional MaxOutConfig maxout_conf = 11; - optional SppConfig spp_conf = 12; - optional PriorBoxConfig priorbox_conf = 13; - optional PadConfig pad_conf = 14; - optional RowConvConfig row_conv_conf = 15; - optional MultiBoxLossConfig multibox_loss_conf = 16; - optional DetectionOutputConfig detection_output_conf = 17; - optional ClipConfig clip_conf = 18; - optional ScaleSubRegionConfig scale_sub_region_conf = 19; - optional ROIPoolConfig roi_pool_conf = 20; - optional UpsampleConfig upsample_conf = 21; -} - -message LayerConfig { - required string name = 1; - required string type = 2; - optional uint64 size = 3; - // optional ActivationConfig activation = 4; - optional string active_type = 4; - repeated LayerInputConfig inputs = 5; - optional string bias_parameter_name = 6; - - // This number must be a multiple of 16. 
- optional uint32 num_filters = 7; - - // indicates that the biases of every filter in this layer - // should be shared amongst all applications of that filter - // (which is how convnets are usually trained). Setting this to - // false will untie the biases, yielding a separate bias for - // every location at which the filter is applied. - optional bool shared_biases = 8 [ default = false ]; - - // Valid values are ones that divide the area of the output - // grid in this convolutional layer. For example if this layer - // produces 32-channel 20x20 output grid, valid values of - // partialSum are ones which divide 20*20 = 400. - // I'll update this comments when confirmed - optional uint32 partial_sum = 9; - - // for dropout - optional double drop_rate = 10; - - // for HierarchicalSoftmaxLayer and NCELayer - // the number of classes - optional uint32 num_classes = 11; - - // the gpu device which the Layer's data in. - // Only used by ParallelNeuralNetork. Ignored otherwise. - optional int32 device = 12 [ default = -1 ]; - - // for recurrent layer. If true, the recurrence runs from the end to the - // beginning. - optional bool reversed = 13 [ default = false ]; - - // for lstmemory layer. Different types of nodes have different activation - // type. - optional string active_gate_type = 14; - optional string active_state_type = 15; - - // For NCELayer - // The number of random negative labels for each sample - optional int32 num_neg_samples = 16 [ default = 10 ]; - - // For NCELayer - // The distribution for generating the random negative labels. - // A uniform distribution will be used if not provided - repeated double neg_sampling_dist = 17 [ packed = true ]; - - // For MaxLayer - // default: output VALUE of MaxLayer. set this flag to true for output INDEX - // INDEX will be put in Argument::value as double values. - optional bool output_max_index = 19 [ default = false ]; - - /// The filed number 20 have been deprecated. 
- - // For self-normalized estimation - optional double softmax_selfnorm_alpha = 21 [ default = 0.1 ]; - - /// The filed numbers 22 and 23 have been deprecated. - - // for MDLstmLayer - repeated bool directions = 24; - - // for CTCLayer - optional bool norm_by_times = 25; - - // for CostLayers - optional double coeff = 26 [ default = 1.0 ]; - - // for AverageLayer - // can be set to: 'average', 'sum' or 'squarerootn' - optional string average_strategy = 27; - - // for error clipping - optional double error_clipping_threshold = 28 [ default = 0.0 ]; - - // for operators used by mixed layer - repeated OperatorConfig operator_confs = 29; - - // for lambdaCost - optional int32 NDCG_num = 30; - optional int32 max_sort_size = 31; - - // for SlopeInterceptLayer - optional double slope = 32; - optional double intercept = 33; - - // for CosSimVecMatLayer and CosSimLayer - optional double cos_scale = 34; - - // for DataNormLayer - // can be set to: 'z-score', 'min-max' or 'decimal-scaling' - optional string data_norm_strategy = 36; - - // for bos/eos id - optional uint32 bos_id = 37; - optional uint32 eos_id = 38; - - // for max id layer - optional uint32 beam_size = 39; - - // for seqlastins layer, whether select first instead last - optional bool select_first = 40 [ default = false ]; - - // for seqlastins layer, AverageLayer, MaxLayer and ExpandLayer - // can be set to: 'non-seq','seq' - optional string trans_type = 41 [ default = 'non-seq' ]; - - // to indicate whether selective_fc layer - // is used in sequence generation or not - optional bool selective_fc_pass_generation = 42 [ default = false ]; - - // to indicate whether selective_fc layer take its last input to - // selected several columns and only compute the multiplications - // between the input matrices and the selected columns of - // the parameter matrices of this layer. - // if set false, selective_fc degrades into fc. 
- optional bool has_selected_colums = 43 [ default = true ]; - - // this parameter is for speed consideration. - // if number of the selected columns is less than - // sample number * selective_fc output size * selective_fc_mull_mull_ratio - // sparse multiplication is used, otherwise, using full multiplication. - optional double selective_fc_full_mul_ratio = 44 [ default = 0.02 ]; - - // to indicate how many threads selective_fc use to to accelate - // the plain_mul period - // leave empty or set to 0 to disable multi-thread accleleration - optional uint32 selective_fc_parallel_plain_mul_thread_num = 45 - [ default = 0 ]; - - // for batch normalization layer - // if set use_global_stats true, will use the loaded mean and variance. - optional bool use_global_stats = 46; - - // use to compute moving mean and variance. - optional double moving_average_fraction = 47 [ default = 0.9 ]; - - // bias size - optional uint32 bias_size = 48 [ default = 0 ]; - - // this parameter can be used as a user-defined parameter when necessary, - // without changing the proto file. - // e.g., when a new layer with a user-defined parameter is implemented, - // it can be used to pass that parameter, without modifying the proto file. - // string type is used for flexibility: different types can be converted - // to string and reinterpreted in the user's own layer implementation. - optional string user_arg = 49; - - // to indicate rectangle image data - optional uint64 height = 50; - optional uint64 width = 51; - - // blank label used in ctc loss - optional uint32 blank = 52 [ default = 0 ]; - - // stride parameter for seqlastins layer, AverageLayer, MaxLayer, which - // controls the scope of pooling operation. can be set > 0. - // leave empty or set to -1 to disable this stride pooling. 
- optional int32 seq_pool_stride = 53 [ default = -1 ]; - - // for crop layer - optional int32 axis = 54 [ default = 2 ]; - repeated uint32 offset = 55; - repeated uint32 shape = 56; - - // for HuberRegressionLoss - optional double delta = 57 [ default = 1.0 ]; - - // for 3D data - optional uint64 depth = 58 [ default = 1 ]; - - // for switch order layer - optional ReshapeConfig reshape_conf = 59; - - // for batch normalization layer - // The small constant added to the variance to improve numeric stability. - optional double epsilon = 60 [ default = 0.00001 ]; - - // for factorization machine layer - optional uint32 factor_size = 61; -} - -message EvaluatorConfig { - required string name = 1; - required string type = 2; - repeated string input_layers = 3; - - // Used by ChunkEvaluator - // one of "IOB", "IOE", "IOBES" - optional string chunk_scheme = 4; - // number of chunk types other than "other" - optional int32 num_chunk_types = 5; - - // Used by PrecisionRecallEvaluator and ClassificationErrorEvaluator - // For multi binary labels: true if output > classification_threshold - optional double classification_threshold = 6 [ default = 0.5 ]; - // The positive label. 
-1 means average precision and recall - optional int32 positive_label = 7 [ default = -1 ]; - - // load dict from this file - optional string dict_file = 8; - - // dump result in this file - optional string result_file = 9; - - // top # results for max id printer - optional int32 num_results = 10 [ default = 1 ]; - - // whether to delimit the sequence in the seq_text_printer - optional bool delimited = 11 [ default = true ]; - - // Used by ChunkEvaluator - // chunk of these types are not counted - repeated int32 excluded_chunk_types = 12; - - // Used by ClassificationErrorEvaluator - // top # classification error - optional int32 top_k = 13 [ default = 1 ]; - - // Used by DetectionMAPEvaluator - optional double overlap_threshold = 14 [ default = 0.5 ]; - - optional int32 background_id = 15 [ default = 0 ]; - - optional bool evaluate_difficult = 16 [ default = false ]; - - optional string ap_type = 17 [ default = "11point" ]; -} - -message LinkConfig { - required string layer_name = 1; - required string link_name = 2; - // If true, this link has sub-sequence - optional bool has_subseq = 3 [ default = false ]; -} - -message MemoryConfig { - required string layer_name = 1; - required string link_name = 2; - - optional string boot_layer_name = 3; - optional string boot_bias_parameter_name = 4; - optional string boot_bias_active_type = 5; - optional uint32 boot_with_const_id = 7; - - // memory is a sequence, initailized by a sequence boot layer - optional bool is_sequence = 6 [ default = false ]; -} - -message GeneratorConfig { - required uint32 max_num_frames = 1; - required string eos_layer_name = 2; - optional int32 num_results_per_sample = 3 [ default = 1 ]; - - // for beam search - optional int32 beam_size = 4 [ default = 1 ]; - - optional bool log_prob = 5 [ default = true ]; -} - -message SubModelConfig { - required string name = 1; - repeated string layer_names = 2; // selected layers in sub model - repeated string input_layer_names = 3; - repeated string 
output_layer_names = 4; - repeated string evaluator_names = 5; - - optional bool is_recurrent_layer_group = 6 [ default = false ]; - - // If true, the recurrence runs from the end to the beginning. - optional bool reversed = 7 [ default = false ]; - - // name and link name of memory - repeated MemoryConfig memories = 8; - - // if use recurrent layer group, all layers in submodel will postfix by - // "_in_"+submodel.name, so we add a name pair to link between - // root model and layer group, - // note that these in/out layers are not input/output of the network. - repeated LinkConfig in_links = 9; - repeated LinkConfig out_links = 10; - - optional GeneratorConfig generator = 11; - - // the id of inlink which share info with outlinks, used in recurrent layer - // group - optional int32 target_inlinkid = 12; -} - -message ModelConfig { - // type of the model. - // Currently, "nn", "recurrent_nn" and "recursive_nn" are supported - required string type = 1 [ default = "nn" ]; - - // layers should be ordered in such a way that the forward propagation - // can be correctly executed by going from the first layer to the last layer - repeated LayerConfig layers = 2; - - repeated ParameterConfig parameters = 3; - - // Input layers should have the same order as the data streams provided - // by the data provider. The type of input layers should be "data" - repeated string input_layer_names = 4; - - // For training, the type of a output layer is usually cost layer. - // For prediction, they should be the actual output layers. - repeated string output_layer_names = 5; - - repeated EvaluatorConfig evaluators = 6; - - repeated SubModelConfig sub_models = 8; - - // For External Machine, defining how to split a neural network - // into multiple parts. 
- optional ExternalConfig external_config = 9; -}; diff --git a/proto/OptimizerConfig.proto b/proto/OptimizerConfig.proto deleted file mode 100644 index e9ea1bfbcc66806e53a45623d0e8ec003ad9ed82..0000000000000000000000000000000000000000 --- a/proto/OptimizerConfig.proto +++ /dev/null @@ -1,164 +0,0 @@ -// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -syntax = "proto2"; - -option optimize_for = LITE_RUNTIME; - -package paddle; - -message SGDConfig { - // SGD - // momentum: float >= 0. Parameter updates momentum. - // decay: float >= 0. Learning rate decay over each update. - // nesterov: boolean. Whether to apply Nesterov momentum. - optional double momentum = 21 [ default = 0.0 ]; - optional double decay = 23 [ default = 0.0 ]; - optional bool nesterov = 24 [ default = false ]; -} - -message AdadeltaConfig { - // Adadelta - // It is recommended to leave it at the default value. - // rho: float >= 0. - // epsilon: float >= 0. Fuzz factor. - // decay: float >= 0. Learning rate decay over each update. - - // reference : [Adadelta - an adaptive learning rate - // method](http://arxiv.org/abs/1212.5701) - optional double rho = 33 [ default = 0.90 ]; - optional double epsilon = 31 [ default = 1e-5 ]; - optional double decay = 32 [ default = 0.0 ]; -} - -message AdagradConfig { - // Adagrad - // epsilon: float >= 0. - // decay: float >= 0. Learning rate decay over each update. 
- - // reference : [Adaptive Subgradient Methods for Online Learning and - // Stochastic - // Optimization](http://www.jmlr.org/papers/volume12/duchi11a/duchi11a.pdf) - optional double epsilon = 41 [ default = 1e-5 ]; - optional double decay = 42 [ default = 0.0 ]; -} - -message AdamConfig { - // Adaj - // beta_1: float, 0 < beta < 1. Generally close to 1. - // beta_2: float, 0 < beta < 1. Generally close to 1. - // epsilon: float >= 0. Fuzz factor. - // decay: float >= 0. Learning rate decay over each update. - // reference : [Adam - A Method for Stochastic - // Optimization](http://arxiv.org/abs/1412.6980v8) - optional double beta_1 = 41; - optional double beta_2 = 42; - optional double epsilon = 43; - optional double decay = 44; -} - -message ConstLrConfig { - // learninRate Policy - optional double learning_rate = 1 [ default = 1.0 ]; -} - -message LinearLrConfig { - // learninRate Policy - optional double learning_rate = 1 [ default = 1.0 ]; - optional double lr_decay_a = 2; - optional double lr_decay_b = 3; -} - -message TensorProto { - enum DataType { - PADDLE_ELEMENT_TYPE_INT32 = 0; - PADDLE_ELEMENT_TYPE_UINT32 = 1; - PADDLE_ELEMENT_TYPE_INT64 = 2; - PADDLE_ELEMENT_TYPE_UINT64 = 3; - PADDLE_ELEMENT_TYPE_FLOAT32 = 4; - PADDLE_ELEMENT_TYPE_FLOAT64 = 5; - } - optional DataType data_type = 1; - repeated bytes content = 2; -} - -message LrPolicyState { - // learninRate Policy - optional double learning_rate = 1 [ default = 1.0 ]; - optional double lr_decay_a = 2; - optional double lr_decay_b = 3; -} - -message SGDOptimizerState { - optional LrPolicyState lr_state = 101; - optional double num_sample_passed = 104; - // state - optional TensorProto parameter = 1; - optional TensorProto momentums = 2; -} - -message AdadeltaOptimizerState { - // learning rate policy - optional LrPolicyState lr_state = 101; - optional double num_sample_passed = 104; - // state - optional TensorProto parameter = 1; - optional TensorProto accum_gradient = 2; - optional TensorProto 
accum_delta = 3; - optional TensorProto update_delta = 4; -} - -message AdagradOptimizerState { - optional LrPolicyState lr_state = 101; - optional double num_sample_passed = 104; - // state - optional TensorProto parameter = 1; - optional TensorProto accum_gradient = 2; -} - -message AdamOptimizerState { - optional LrPolicyState lr_state = 101; - optional double num_sample_passed = 104; - // state - optional TensorProto parameter = 1; - optional TensorProto momentums = 2; - optional TensorProto velocitys = 3; -} - -message OptimizerConfig { - enum Optimizer { - SGD = 1; - Adadelta = 2; - Adagrad = 3; - Adam = 4; - } - optional Optimizer optimizer = 1; - optional SGDConfig sgd = 3; - optional AdadeltaConfig adadelta = 4; - optional AdagradConfig adagrad = 5; - optional AdamConfig adam = 6; - - enum LrPolicy { - Const = 0; - Linear = 1; - } - optional LrPolicy lr_policy = 11; - optional ConstLrConfig const_lr = 12; - optional LinearLrConfig linear_lr = 13; - - // common config of optimizer - // gradient clip when L2 exceeding value - optional double clip_norm = 101; - // gradient clip when L1 exceeding value - optional double clip_value = 102; -} diff --git a/proto/ParameterConfig.proto b/proto/ParameterConfig.proto deleted file mode 100644 index 6f8ba9d7605ef19ebcc32407d3f09d2fa7a266f8..0000000000000000000000000000000000000000 --- a/proto/ParameterConfig.proto +++ /dev/null @@ -1,83 +0,0 @@ -/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-See the License for the specific language governing permissions and -limitations under the License. */ -syntax = "proto2"; - -package paddle; - -/** - * Configuration structure for parameter - */ - -enum ParameterInitStrategy { - PARAMETER_INIT_NORMAL = 0; - PARAMETER_INIT_UNIFORM = 1; -} - -message ParameterUpdaterHookConfig { - // hook type such as 'pruning' - required string type = 1; - // this represents the ratio of zero element to be set by the Parameter - optional double sparsity_ratio = 2 [ default = 0.6 ]; -} - -message ParameterConfig { - required string name = 1; - required uint64 size = 2; - optional double learning_rate = 3 [ default = 1.0 ]; - optional double momentum = 4 [ default = 0.0 ]; - optional double initial_mean = 5 [ default = 0.0 ]; - optional double initial_std = 6 [ default = 0.01 ]; - // use L2-regularization if decay_rate set and decay_rate_l1 not set - optional double decay_rate = 7 [ default = 0.0 ]; - // use L1-regularization if decay_rate_l1 set - optional double decay_rate_l1 = 8 [ default = 0.0 ]; - // dims of Parameter, e.g. dims[0] as height, dims[1] as width.. - repeated uint64 dims = 9; - // the gpu device which the parameter in. - // Only used by ParallelNeuralNetork. Ignored otherwise. 
- optional int32 device = 10 [ default = -1 ]; - // how to init the parameter: 0 -> normal, 1 -> uniform - // 0: treat initial_mean as mean, intial_std as standard deviation - // 1: range is (initial_mean - initial_std) to (initial_mean + initial_std) - optional int32 initial_strategy = 11 [ default = 0 ]; - // define the variance when init the parameter, by height of the Matrix - optional bool initial_smart = 12 [ default = false ]; - // apply regularization every # batches - optional int32 num_batches_regularization = 13 [ default = 1 ]; - // if is_sparse is true, para is sparse, else para is dense - optional bool is_sparse = 14 [ default = false ]; - // if para is sparse, format should be "csc" or "csr", empty means is not - // sparse - optional string format = 15 [ default = "" ]; - // sparse remote update or not - optional bool sparse_remote_update = 16 [ default = false ]; - // gradient clipping threshold, no clipping by default - optional double gradient_clipping_threshold = 17 [ default = 0.0 ]; - // static parameters are fixed when training - optional bool is_static = 18 [ default = false ]; - // para_id should NOT be set by config_parser. It is for - // internal use. - optional uint64 para_id = 19; - - repeated ParameterUpdaterHookConfig update_hooks = 20; - // setup load mat -> csr - optional bool need_compact = 21 [ default = false ]; - // whether to do sparse update for this parameter - optional bool sparse_update = 22 [ default = false ]; - - // whether this parameter is shared or not. - optional bool is_shared = 23 [ default = false ]; - // parameter block size - optional uint64 parameter_block_size = 24 [ default = 0 ]; -} diff --git a/proto/ParameterServerConfig.proto b/proto/ParameterServerConfig.proto deleted file mode 100644 index 1404c8aa14327e89d7dde7d2668caac474ea9217..0000000000000000000000000000000000000000 --- a/proto/ParameterServerConfig.proto +++ /dev/null @@ -1,50 +0,0 @@ -/* Copyright (c) 2016 PaddlePaddle Authors. 
All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ -syntax = "proto2"; - -package paddle; - -/** - * Configuration structure for ParameterClient2. - */ -message ParameterClientConfig { required int32 trainer_id = 1; } - -/** - * Configuration structure for ParameterServer2. - */ -message ParameterServerConfig { - // Number of ports for sending dense parameter, - // following ports on parameter server will be visited - // for sending dense parameter: [port, port+ports_num-1] - required int32 ports_num = 1 [ default = 1 ]; - // Number of ports for sending sparse parameter, - // following ports on parameter server will be visited - // for sending sparse parameter: - // [port+ports_num, port+ports_num+ports_num_for_sparse-1] - required int32 ports_num_for_sparse = 2 [ default = 0 ]; - // network device name for pservers - required string nics = 3 [ default = "xgbe0,xgbe1" ]; - required string rdma_tcp = 4 [ default = "tcp" ]; - // Listening port for pserver - required int32 port = 5 [ default = 20134 ]; - // number of gradient servers - required int32 num_gradient_servers = 6 [ default = 1 ]; - // number of threads for sync op exec - required int32 pserver_num_threads = 7 [ default = 1 ]; - // control config_.async_lagged_grad_discard_ratio() min value - required double async_lagged_ratio_min = 8 [ default = 1.0 ]; - // if async_lagged_grad_discard_ratio is not set in trainer_config.conf - // use it as defalut value - required double async_lagged_ratio_default = 9 [ 
default = 1.5 ]; -} diff --git a/proto/ParameterService.proto b/proto/ParameterService.proto deleted file mode 100644 index b56c1bfe7caa0ad1294ae07edd1d7fea8e1e9a27..0000000000000000000000000000000000000000 --- a/proto/ParameterService.proto +++ /dev/null @@ -1,351 +0,0 @@ -/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ -syntax = "proto2"; - -import "ParameterConfig.proto"; -import "TrainerConfig.proto"; - -package paddle; - -/** - * Various structs for communicating with parameter server - */ -enum ParameterUpdateMode { - // Set parameter - PSERVER_UPDATE_MODE_SET_PARAM = 0; // use local param - PSERVER_UPDATE_MODE_SET_PARAM_ZERO = 1; // set zero param - - // Update parameter once a gradient is received - PSERVER_UPDATE_MODE_ASYNC_SGD = 2; - - // Accumulate gradient - PSERVER_UPDATE_MODE_ADD_GRADIENT = 3; - - // Average parameters - PSERVER_UPDATE_MODE_AVERAGE_PARAMETER = 4; - - // No update. Only get parameters back. - PSERVER_UPDATE_MODE_GET_PARAM = 5; - PSERVER_UPDATE_MODE_GET_PARAM_SPARSE = 6; // only get sparse rows -}; - -message ParameterBlock { - // it accurately means parameter id. 
- required uint64 para_id = 1; - // global sparse row or dense block for each block in parameter - required uint64 block_id = 2; - // offset in (local) storage - required uint64 begin_pos = 3; - // actual size of block, size for last block is [endDim -beginDim], - // others is parameter_block_size in ParameterConfig - required uint64 block_size = 4; -} - -enum PServerStatus { - PSERVER_STATUS_NOT_SET = 0; - PSERVER_STATUS_PARAMETER_READY = 1; -}; - -enum BatchStatus { - BATCH_START = 0; - BATCH_ON = 1; - BATCH_FINISH = 2; - BATCH_START_AND_FINISH = 3; -}; - -message SendParameterRequest { - required ParameterUpdateMode update_mode = 1; - repeated ParameterBlock blocks = 2; - required bool send_back_parameter = 3; - - // number of samples used for calculating this update - optional int64 num_samples = 4; - - // cost will be used to calculate global objective value - optional double cost = 5; - - required BatchStatus batch_status = 6; - - optional int32 trainer_id = 7; - - // send back parameter type on pserver, PARAMETER_VALUE by default - optional int32 send_back_parameter_type = 8 [ default = 0 ]; - - // forwardbackward time in usec - optional uint64 forwardbackward_time = 9; -} - -message WaitPassStartRequest {} - -message WaitPassStartResponse {} - -message WaitPassFinishRequest {} - -message WaitPassFinishResponse {} - -enum SyncObject { - SYNC_DEFAULT = 0; // wait for the synchronizeBarrier_ - SYNC_DATA = 1; // wait for the synchronizeDataBarrier_ -} - -message SynchronizeRequest { - required SyncObject sync_object_id = 1 [ default = SYNC_DEFAULT ]; - - optional int32 trainer_id = 2; -} - -message SynchronizeResponse {} - -message SendParameterResponse { repeated ParameterBlock blocks = 1; } - -message SetConfigRequest { - repeated ParameterConfig param_configs = 1; - required OptimizationConfig opt_config = 2; - required string save_dir = 4; - required int32 server_id = 5; - required bool is_sparse_server = 6; -} - -message SetConfigResponse {} - -message 
GetStatusRequest {} - -message GetStatusResponse { required PServerStatus status = 1; } - -message SetStatusRequest { required PServerStatus status = 1; } - -message SetStatusResponse {} - -// create a column vector. The size is the dimension of parameter -message CreateVectorRequest {} - -message CreateVectorResponse { - // error message. Empty if success - optional string return_message = 1; - - required int64 handle = 2; -} - -message ReleaseVectorRequest { required int64 handle = 1; } - -message ReleaseVectorResponse { - // error message. Empty if success - optional string return_message = 1; -} - -// Create a column major matrix. The number of rows is the dimension -// of parameter. The number of columns is specifed by num_cols -message CreateMatrixRequest { required int32 num_cols = 1; } - -message CreateMatrixResponse { - // error message. Empty if success - optional string return_message = 1; - - required int64 handle = 2; -} - -message ReleaseMatrixRequest { required int64 handle = 1; } - -message ReleaseMatrixResponse { - // error message. 
Empty if success - optional string return_message = 1; -} - -/** - * The operations are defined using the variables commented at Operation - * and OperationResult - */ -enum MatrixVectorOperation { - // r = u^T u - PSERVER_OP_utu = 0; - - // r = u^T v - PSERVER_OP_utv = 1; - - // u = a u - PSERVER_OP_au = 2; - - // v = a u + b v - PSERVER_OP_au_bv = 3; - - // u = a A x + b u - PSERVER_OP_aAx_bu = 4; - - // Stochastic gradient update - PSERVER_OP_SGD = 5; - - // u = a - PSERVER_OP_RESET = 6; - - // v = u - PSERVER_OP_COPY = 7; - - // w = a u + b v + c w - PSERVER_OP_au_bv_cw = 8; - - // owlqn: MakeSteepestDescDir - PSERVER_OP_MAKE_STEEPEST_DESC_DIR = 9; - - // owlqn: FixDirSigns - PSERVER_OP_FIX_DIR_SIGNS = 10; - - // owlqn: DirDeriv - PSERVER_OP_DIR_DERIV = 11; - - // owlqn: FixOmegaSigns - PSERVER_OP_FIX_OMEGA_SIGNS = 12; - - // Get overall cost - PSERVER_OP_COST = 13; - - // Pass control - PSERVER_OP_START_PASS = 14; - PSERVER_OP_FINISH_PASS = 15; - - // randomize value - PSERVER_OP_RANDOMIZE = 16; - - // call optimizer apply - PSERVER_OP_APPLY = 17; -} - -message ProtoVector { - required int64 dim = 1; - repeated double values = 2 [ packed = true ]; -} - -message ProtoMatrix { - required int64 num_rows = 1; - required int64 num_cols = 2; - repeated double values = 3 [ packed = true ]; -} - -message Operation { - required MatrixVectorOperation operation = 1; - - // vector handles created on the pserver - repeated int64 pvectors = 2; // u, v, w - - // matrix handles created on the pserver - repeated int64 pmatrices = 3; // A, B, C - - repeated double scalars = 4; // a, b, c - repeated ProtoVector vectors = 5; // x, y, z - repeated ProtoMatrix matrices = 6; // X, Y, Z -} - -message OperationResult { - // error message. 
Empty if success - optional string return_message = 1; - // - repeated double scalars = 2; // d, e, f - repeated ProtoVector vectors = 3; // p, q, r - repeated ProtoMatrix matrices = 4; // P, Q, R -} - -message DoOperationRequest { - repeated Operation operations = 1; - - // If true, wait for gradient to be ready before starting the operations - required bool wait_for_gradient = 2; - - // If true, send back the parameter to clients after the operations are - // finished - required bool send_back_parameter = 3; - - // If true, and if all clients call waitPassFinish, - // signal all clients finish the pass - required bool release_pass = 4; -} - -message DoOperationResponse { - // error message. Empty if success - optional string return_message = 1; - - repeated OperationResult results = 2; - - required bool pass_finish = 3; -} - -message LoadValueRequest { required string dir_name = 1; } - -message LoadValueResponse { - // error message. Empty if success - optional string return_message = 1; -} - -message SaveValueRequest { required string dir_name = 1; } - -message SaveValueResponse { - // error message. 
Empty if success - optional string return_message = 1; -} - -enum DataUpdateMode { - // Client send it's own data to pserver - DATA_UPDATE_MODE_SET_OWN = 0; - // Client get all user data from all pservers - DATA_UPDATE_MODE_GET_ALL = 1; - // Client send it's own ref feature to pserver - DATA_UPDATE_MODE_SET_REF = 2; - // Client get all ref featuers from all pservers - DATA_UPDATE_MODE_GET_REF = 3; - // Client send it's own ref label to pserver - DATA_UPDATE_MODE_SET_REF_LABEL = 4; - // Client get all ref labels from all pservers - DATA_UPDATE_MODE_GET_REF_LABEL = 5; - // Client send it's own ref grad to pserver - DATA_UPDATE_MODE_SET_REF_GRAD = 6; - // Client get all ref grad from all pservers - DATA_UPDATE_MODE_GET_REF_GRAD = 7; -} - -enum SendDataType { - DATA_REF = 0; - DATA_REFLABEL = 1; - DATA_REFGRAD = 2; - DATA_REDUCE_SUM = 3; -} - -enum TransDataType { - TRANS_INT32 = 0; - TRANS_UINT32_T = 1; - TRANS_INT64_T = 2; - TRANS_UINT64_T = 3; - TRANS_FLOAT = 5; - TRANS_DOUBLE = 6; -} - -message DataBlock { - // total byte size of this data blcok - required uint64 total_size = 1; - // byte size of one data type - required int32 data_size = 2; - // data_type - optional TransDataType data_type = 3 [ default = TRANS_DOUBLE ]; -} - -message SendDataRequest { - required SendDataType type = 1; - required DataUpdateMode update_mode = 2; - repeated DataBlock blocks = 3; - required uint64 client_id = 4; - required uint64 server_id = 5; -} - -message SendDataResponse { - required SendDataType type = 1; - repeated DataBlock blocks = 2; - required uint64 server_id = 3; -} diff --git a/proto/README.md b/proto/README.md deleted file mode 100644 index dda7ed7b3c8ea4b541eaafbd0fd239eea789b40e..0000000000000000000000000000000000000000 --- a/proto/README.md +++ /dev/null @@ -1,3 +0,0 @@ -## protos in this folder are legacy v2 protos. - -## Please refer to paddle/fluid for latest version. 
diff --git a/proto/TrainerConfig.proto b/proto/TrainerConfig.proto deleted file mode 100644 index 9cc20b4a3ef3faa1d9ffde69daa579a620de38d8..0000000000000000000000000000000000000000 --- a/proto/TrainerConfig.proto +++ /dev/null @@ -1,160 +0,0 @@ -/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ -syntax = "proto2"; - -import "DataConfig.proto"; -import "ModelConfig.proto"; - -package paddle; - -message OptimizationConfig { - optional int32 batch_size = 3 [ default = 1 ]; - required string algorithm = 4 [ default = "async_sgd" ]; - optional int32 num_batches_per_send_parameter = 5 [ default = 1 ]; - optional int32 num_batches_per_get_parameter = 6 [ default = 1 ]; - - required double learning_rate = 7; - optional double learning_rate_decay_a = 8 [ default = 0 ]; - optional double learning_rate_decay_b = 9 [ default = 0 ]; - optional string learning_rate_schedule = 27 [ default = "constant" ]; - // learning rate will be scaled according to learning_rate_schedule - // 1), constant: - // lr = learning_rate - // 2), poly: - // lr = learning_rate * - // pow(1 + learning_rate_decay_a * num_samples_processed, - // -learning_rate_decay_b) - // 3), exp: - // lr = learning_rate * - // pow(learning_rate_decay_a, - // num_samples_processed / learning_rate_decay_b) - // 4), discexp: - // lr = learning_rate * - // pow(learning_rate_decay_a, - // floor(num_samples_processed / learning_rate_decay_b)) - // 5), linear: - // lr = 
max(learning_rate - learning_rate_decay_a * num_samples_processed, - // learning_rate_decay_b) - - // owlqn related - // L1-regularization - optional double l1weight = 10 [ default = 0.1 ]; - // L2-regularization - optional double l2weight = 11 [ default = 0 ]; - // "c1" in wolfe condition: if (newobj <= oldobj + c1 * origDirDeriv * step) - // then accept the step - optional double c1 = 12 [ default = 0.0001 ]; - // multiply the step with "backoff", when wolfe condition doesn't satisfy - optional double backoff = 13 [ default = 0.5 ]; - // how many "s"s and "y"s are kept in owlqn - optional int32 owlqn_steps = 14 [ default = 10 ]; - // accept the step if encountered "max_backoff" times of "reduce the step" - optional int32 max_backoff = 15 [ default = 5 ]; - // L2-regularization coefficient is reduced linearly from iteration 0 to - // "l2weight_zero_iter", and set to 0 after "l2weight_zero_iter" - // iterations. set "l2weight_zero_iter" to 0 to disable this strategy. - optional int32 l2weight_zero_iter = 17 [ default = 0 ]; - - // averaged sgd - // About average_window * numBatchProcessed parameter are used - // for average. To be accurate, between average_window * numBatchProcessed - // and 2 * average_window * numBatchProcessed parameters are used for - // average. - optional double average_window = 18 [ default = 0 ]; - optional int64 max_average_window = 19 [ default = 0x7fffffffffffffff ]; - - ////////////////////////// - // Options Adaptive SGD // - ////////////////////////// - - // learning method for sgd/asgd, such as "momentum", "adagrad", "adadelta", - // "rmsprop" - // default learning method("momentum") use global decayed learning rate with - // momentum. - // "adagrad", "adadelta" and "rmsprop" can set momentum too. 
- optional string learning_method = 23 [ default = "momentum" ]; - optional double ada_epsilon = 24 [ default = 1e-6 ]; - optional double ada_rou = 26 [ default = 0.95 ]; - - // Force to do average in cpu in order to save gpu memory usage - optional bool do_average_in_cpu = 25 [ default = false ]; - - // delta add rate in pserver, used while num_batches_per_send_parameter>1 - // will be divided by #machines automatically. - optional double delta_add_rate = 28 [ default = 1.0 ]; - - // We split a large size into smaller mini-batches, whose sizes are - // determined by mini_batch_size. It only takes effect when there is - // an ExternalMachine. - optional int32 mini_batch_size = 29 [ default = 128 ]; - - // automatically set if any one of parameters set sparse remote update flag - optional bool use_sparse_remote_updater = 30 [ default = false ]; - - // how to update center parameter and feedback to local parameter, - // when use local sgd update in cluster training. - // A option is elastic_average, proposed by the paper: Deep learning with - // elastic averaging SGD. - // If use elastic_average method, every trainer node should sample from whole - // data sets. 
- optional string center_parameter_update_method = 31 [ default = "average" ]; - - // shrink sparse parameter value - // only works if parameter is remote sparse update and has L1 decay rate - optional double shrink_parameter_value = 32 [ default = 0 ]; - - //////////////////////////// - // Options Adam Optimizer // - //////////////////////////// - optional double adam_beta1 = 33 [ default = 0.9 ]; - optional double adam_beta2 = 34 [ default = 0.999 ]; - optional double adam_epsilon = 35 [ default = 1e-8 ]; - - // arguments for learning rate scheduler - // Format: num1:rate1,num2:rate2,...,numK:rateK - // For learning_rate_schedule="manual", num is the number of samples, - // For learning_rate_schedule="pass_manual", - // num is the number of passes (starting from 0) - optional string learning_rate_args = 36 [ default = "" ]; - - // for async sgd gradient commit control. - // when async_lagged_grad_discard_ratio * num_gradient_servers commit passed, - // current async gradient will be discard silently. - optional double async_lagged_grad_discard_ratio = 37 [ default = 1.5 ]; - - // global threshold for gradient clipping - optional double gradient_clipping_threshold = 38 [ default = 0.0 ]; -}; - -message TrainerConfig { - optional ModelConfig model_config = 1; - optional DataConfig data_config = 2; - required OptimizationConfig opt_config = 3; - optional DataConfig test_data_config = 4; - repeated string config_files = 5; - - // the directory to save/load model files for each training path - optional string save_dir = 6 [ default = "./output/model" ]; - - // Path of the initial model parameters. - // If it was set, start_pass will be ignored. - optional string init_model_path = 7; - - // Start training from this pass. - // Will load parameter from the previous pass. - optional int32 start_pass = 8 [ default = 0 ]; - - // file path to the trainer config file - optional string config_file = 9; -}