提交 1777017a 编写于 作者: T Tao Luo

remove legace go and proto code

上级 ef038743
vendor/
.glide/
proto/*.go
# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
add_subdirectory(pserver/client/c)
add_subdirectory(cmd/pserver)
add_subdirectory(cmd/master)
add_subdirectory(master/c)
add_subdirectory(master)
add_subdirectory(pserver)
add_subdirectory(pserver/client)
add_subdirectory(utils/networkhelper)
# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
go_binary(master SRC master.go)
// Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package main
import (
"fmt"
"net"
"net/http"
"net/rpc"
"os"
"os/signal"
"strconv"
"strings"
"time"
log "github.com/inconshreveable/log15"
"github.com/namsral/flag"
"github.com/PaddlePaddle/Paddle/go/master"
"github.com/PaddlePaddle/Paddle/go/utils/networkhelper"
)
func main() {
port := flag.Int("port", 8080, "port of the master server.")
ttlSec := flag.Int("ttl", 60, "etcd lease TTL in seconds.")
endpoints := flag.String("endpoints", "http://127.0.0.1:2379", "comma separated etcd endpoints. If empty, fault tolerance will not be enabled.")
taskTimeoutDur := flag.Duration("task-timout-dur", 20*time.Minute, "task timout duration.")
taskTimeoutMax := flag.Int("task-timeout-max", 3, "max timtout count for each task before it being declared failed task.")
chunkPerTask := flag.Int("chunk-per-task", 10, "chunk per task.")
logLevel := flag.String("log-level", "info",
"log level, possible values: debug, info, warn, error, crit")
flag.Parse()
lvl, err := log.LvlFromString(*logLevel)
if err != nil {
panic(err)
}
log.Root().SetHandler(
log.LvlFilterHandler(lvl, log.CallerStackHandler("%+v", log.StderrHandler)),
)
if *endpoints == "" {
log.Warn("-endpoints not set, fault tolerance not be enabled.")
}
var store master.Store
if *endpoints != "" {
eps := strings.Split(*endpoints, ",")
ip, err := networkhelper.GetExternalIP()
if err != nil {
log.Crit("get external ip error", log.Ctx{"error": err})
panic(err)
}
addr := fmt.Sprintf("%s:%d", ip, *port)
store, err = master.NewEtcdClient(eps, addr, master.DefaultLockPath, master.DefaultAddrPath, master.DefaultStatePath, *ttlSec)
if err != nil {
log.Crit("error creating etcd client.", log.Ctx{"error": err})
panic(err)
}
} else {
store = &master.InMemStore{}
}
shutdown := func() {
log.Info("shutting down gracefully")
err := store.Shutdown()
if err != nil {
log.Error("shutdown error", log.Ctx{"error": err})
}
}
// Guaranteed to run even panic happens.
defer shutdown()
c := make(chan os.Signal, 1)
signal.Notify(c, os.Interrupt)
s, err := master.NewService(store, *chunkPerTask, *taskTimeoutDur, *taskTimeoutMax)
if err != nil {
log.Crit("error creating new service.", log.Ctx{"error": err})
panic(err)
}
err = rpc.Register(s)
if err != nil {
log.Crit("error registering to etcd.", log.Ctx{"error": err})
panic(err)
}
rpc.HandleHTTP()
l, err := net.Listen("tcp", ":"+strconv.Itoa(*port))
if err != nil {
log.Crit("error listing to port", log.Ctx{"error": err, "port": *port})
panic(err)
}
go func() {
err = http.Serve(l, nil)
if err != nil {
log.Crit("error serving HTTP", log.Ctx{"error": err})
panic(err)
}
}()
<-c
}
# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
go_binary(pserver SRCS pserver.go DEPS paddle_go_optimizer)
// Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package main
import (
"net"
"net/http"
"net/rpc"
"os"
"os/signal"
"strconv"
"time"
"github.com/namsral/flag"
"github.com/topicai/candy"
"github.com/PaddlePaddle/Paddle/go/pserver"
log "github.com/inconshreveable/log15"
)
func main() {
port := flag.Int("port", 8001, "port of the pserver")
index := flag.Int("index", -1, "index of the pserver, set to -1 if use etcd for auto pserver index registry")
etcdEndpoint := flag.String("etcd-endpoint", "http://127.0.0.1:2379",
"comma separated endpoint string for pserver to connect to etcd")
dialTimeout := flag.Duration("dial-timeout", 5*time.Second, "dial timeout")
etcdTTL := flag.Int("etcd-ttl", 5, "etcd time to live in seconds")
numPservers := flag.Int("num-pservers", 1, "total pserver count in a training job")
checkpointPath := flag.String("checkpoint-path", "/checkpoints/", "save checkpoint path")
checkpointInterval := flag.Duration("checkpoint-interval", 600*time.Second, "save checkpoint per interval seconds")
logLevel := flag.String("log-level", "info",
"log level, possible values: debug, info, warn, error, crit")
flag.Parse()
lvl, err := log.LvlFromString(*logLevel)
if err != nil {
panic(err)
}
log.Root().SetHandler(
log.LvlFilterHandler(lvl, log.CallerStackHandler("%+v", log.StderrHandler)),
)
var idx int
var cp pserver.Checkpoint
var e *pserver.EtcdClient
if *index >= 0 {
idx = *index
} else {
e = pserver.NewEtcdClient(*etcdEndpoint, *numPservers, *dialTimeout, *etcdTTL)
idx, err = e.Register(*port)
candy.Must(err)
cp, err = pserver.LoadCheckpoint(e, idx)
if err != nil {
if err == pserver.ErrCheckpointNotFound {
log.Info("load checkpoint error", "error", err)
} else {
panic(err)
}
}
}
shutdown := func() {
log.Info("shutting down gracefully")
sErr := e.Shutdown()
if sErr != nil {
log.Error("error shutting down", log.Ctx{"error": sErr})
}
}
// Guaranteed to run even panic happens.
defer shutdown()
c := make(chan os.Signal, 1)
signal.Notify(c, os.Interrupt)
s, err := pserver.NewService(idx, *checkpointInterval, *checkpointPath, e, cp)
candy.Must(err)
err = rpc.Register(s)
candy.Must(err)
rpc.HandleHTTP()
l, err := net.Listen("tcp", ":"+strconv.Itoa(*port))
candy.Must(err)
go func() {
log.Info("serving pserver", log.Ctx{"port": *port})
err = http.Serve(l, nil)
candy.Must(err)
}()
<-c
}
// Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package connection
import (
"errors"
"net/rpc"
"sync"
log "github.com/sirupsen/logrus"
)
// TODO(helin): add TCP re-connect logic
// Conn is a connection to a parameter server
type Conn struct {
mu sync.Mutex
client *rpc.Client
waitConn chan struct{}
}
// New creates a new connection.
func New() *Conn {
c := &Conn{}
return c
}
// Close closes the connection.
func (c *Conn) Close() error {
c.mu.Lock()
defer c.mu.Unlock()
if c.client == nil {
return nil
}
return c.client.Close()
}
// Connect connects the connection to a address.
func (c *Conn) Connect(addr string) error {
c.mu.Lock()
if c.client != nil {
err := c.client.Close()
if err != nil {
c.mu.Unlock()
return err
}
c.client = nil
}
c.mu.Unlock()
client, err := rpc.DialHTTP("tcp", addr)
if err != nil {
return err
}
c.mu.Lock()
defer c.mu.Unlock()
if c.client == nil {
c.client = client
if c.waitConn != nil {
close(c.waitConn)
c.waitConn = nil
}
} else {
err := client.Close()
if err != nil {
log.Errorln(err)
}
return errors.New("client already set from a concurrent goroutine")
}
return nil
}
// TODO(helin): refactor Call to be able to perform given retry
// policy.
// Call make a RPC call.
//
// Call will be blocked until the connection to remote RPC service
// being established.
func (c *Conn) Call(serviceMethod string, args interface{}, reply interface{}) error {
c.mu.Lock()
client := c.client
var waitCh chan struct{}
if client == nil {
if c.waitConn != nil {
waitCh = c.waitConn
} else {
waitCh = make(chan struct{})
c.waitConn = waitCh
}
}
c.mu.Unlock()
if waitCh != nil {
// wait until new connection being established
<-waitCh
return c.Call(serviceMethod, args, reply)
}
return client.Call(serviceMethod, args, reply)
}
hash: 107c058cf5c9163a75d40eef2273a793c36112683c25d72aa8288827fdde3a19
updated: 2017-10-30T03:46:19.137696069Z
imports:
- name: github.com/alecthomas/gometalinter
version: bae2f1293d092fd8167939d5108d1b025eaef9de
- name: github.com/beorn7/perks
version: 4c0e84591b9aa9e6dcfdf3e020114cd81f89d5f9
subpackages:
- quantile
- name: github.com/boltdb/bolt
version: 583e8937c61f1af6513608ccc75c97b6abdf4ff9
- name: github.com/cockroachdb/cmux
version: 112f0506e7743d64a6eb8fedbcff13d9979bbf92
- name: github.com/coreos/etcd
version: f1d7dd87da3e8feab4aaf675b8e29c6a5ed5f58b
subpackages:
- alarm
- auth
- auth/authpb
- client
- clientv3
- clientv3/concurrency
- compactor
- discovery
- embed
- error
- etcdserver
- etcdserver/api
- etcdserver/api/etcdhttp
- etcdserver/api/v2http
- etcdserver/api/v2http/httptypes
- etcdserver/api/v3client
- etcdserver/api/v3election
- etcdserver/api/v3election/v3electionpb
- etcdserver/api/v3election/v3electionpb/gw
- etcdserver/api/v3lock
- etcdserver/api/v3lock/v3lockpb
- etcdserver/api/v3lock/v3lockpb/gw
- etcdserver/api/v3rpc
- etcdserver/api/v3rpc/rpctypes
- etcdserver/auth
- etcdserver/etcdserverpb
- etcdserver/etcdserverpb/gw
- etcdserver/membership
- etcdserver/stats
- lease
- lease/leasehttp
- lease/leasepb
- mvcc
- mvcc/backend
- mvcc/mvccpb
- pkg/adt
- pkg/contention
- pkg/cors
- pkg/cpuutil
- pkg/crc
- pkg/debugutil
- pkg/fileutil
- pkg/httputil
- pkg/idutil
- pkg/ioutil
- pkg/logutil
- pkg/monotime
- pkg/netutil
- pkg/pathutil
- pkg/pbutil
- pkg/runtime
- pkg/schedule
- pkg/srv
- pkg/tlsutil
- pkg/transport
- pkg/types
- pkg/wait
- proxy/grpcproxy/adapter
- raft
- raft/raftpb
- rafthttp
- snap
- snap/snappb
- store
- version
- wal
- wal/walpb
- name: github.com/coreos/go-semver
version: 8ab6407b697782a06568d4b7f1db25550ec2e4c6
subpackages:
- semver
- name: github.com/coreos/go-systemd
version: 48702e0da86bd25e76cfef347e2adeb434a0d0a6
subpackages:
- daemon
- journal
- util
- name: github.com/coreos/pkg
version: 3ac0863d7acf3bc44daf49afef8919af12f704ef
subpackages:
- capnslog
- name: github.com/dgrijalva/jwt-go
version: d2709f9f1f31ebcda9651b03077758c1f3a0018c
- name: github.com/ghodss/yaml
version: 0ca9ea5df5451ffdf184b4428c902747c2c11cd7
- name: github.com/go-stack/stack
version: 817915b46b97fd7bb80e8ab6b69f01a53ac3eebf
- name: github.com/gogo/protobuf
version: 909568be09de550ed094403c2bf8a261b5bb730a
subpackages:
- proto
- name: github.com/golang/protobuf
version: 4bd1920723d7b7c925de087aa32e2187708897f7
subpackages:
- jsonpb
- proto
- name: github.com/golang/snappy
version: 553a641470496b2327abcac10b36396bd98e45c9
- name: github.com/google/btree
version: 925471ac9e2131377a91e1595defec898166fe49
- name: github.com/grpc-ecosystem/go-grpc-prometheus
version: 6b7015e65d366bf3f19b2b2a000a831940f0f7e0
- name: github.com/grpc-ecosystem/grpc-gateway
version: 18d159699f2e83fc5bb9ef2f79465ca3f3122676
subpackages:
- runtime
- runtime/internal
- utilities
- name: github.com/inconshreveable/log15
version: 0decfc6c20d9ca0ad143b0e89dcaa20f810b4fb3
- name: github.com/jonboulle/clockwork
version: 2eee05ed794112d45db504eb05aa693efd2b8b09
- name: github.com/mattn/go-colorable
version: 5411d3eea5978e6cdc258b30de592b60df6aba96
- name: github.com/mattn/go-isatty
version: 57fdcb988a5c543893cc61bce354a6e24ab70022
- name: github.com/matttproud/golang_protobuf_extensions
version: c12348ce28de40eed0136aa2b644d0ee0650e56c
subpackages:
- pbutil
- name: github.com/namsral/flag
version: 71ceffbeb0ba60fccc853971bb3ed4d7d90bfd04
- name: github.com/PaddlePaddle/recordio
version: 0432dee9fd4b24fb6840fb20a8c055b0c933fb81
- name: github.com/prometheus/client_golang
version: c5b7fccd204277076155f10851dad72b76a49317
subpackages:
- prometheus
- name: github.com/prometheus/client_model
version: 6f3806018612930941127f2a7c6c453ba2c527d2
subpackages:
- go
- name: github.com/prometheus/common
version: 49fee292b27bfff7f354ee0f64e1bc4850462edf
subpackages:
- expfmt
- internal/bitbucket.org/ww/goautoneg
- model
- name: github.com/prometheus/procfs
version: a1dba9ce8baed984a2495b658c82687f8157b98f
subpackages:
- xfs
- name: github.com/satori/go.uuid
version: 879c5887cd475cd7864858769793b2ceb0d44feb
- name: github.com/sirupsen/logrus
version: f006c2ac4710855cf0f916dd6b77acf6b048dc6e
- name: github.com/topicai/candy
version: 1b9030d056fa9f8c4b1f9c91b52fe4b8ab4cd8cc
- name: github.com/ugorji/go
version: ded73eae5db7e7a0ef6f55aace87a2873c5d2b74
subpackages:
- codec
- name: github.com/xiang90/probing
version: 07dd2e8dfe18522e9c447ba95f2fe95262f63bb2
- name: golang.org/x/crypto
version: 9419663f5a44be8b34ca85f08abc5fe1be11f8a3
repo: https://github.com/golang/crypto.git
vcs: git
subpackages:
- bcrypt
- blowfish
- ssh/terminal
- name: golang.org/x/net
version: c8c74377599bd978aee1cf3b9b63a8634051cec2
subpackages:
- context
- http2
- http2/hpack
- idna
- internal/timeseries
- lex/httplex
- trace
- name: golang.org/x/sys
version: e48874b42435b4347fc52bdee0424a52abc974d7
repo: https://github.com/golang/sys.git
vcs: git
subpackages:
- unix
- windows
- name: golang.org/x/text
version: 836efe42bb4aa16aaa17b9c155d8813d336ed720
repo: https://github.com/golang/text.git
vcs: git
subpackages:
- secure/bidirule
- transform
- unicode/bidi
- unicode/norm
- name: google.golang.org/grpc
version: 8050b9cbc271307e5a716a9d782803d09b0d6f2d
subpackages:
- codes
- credentials
- grpclog
- internal
- keepalive
- metadata
- naming
- peer
- stats
- tap
- transport
- name: gopkg.in/yaml.v2
version: cd8b52f8269e0feb286dfeef29f8fe4d5b397e0b
testImports:
- name: github.com/davecgh/go-spew
version: 04cdfd42973bb9c8589fd6a731800cf222fde1a9
subpackages:
- spew
- name: github.com/pmezard/go-difflib
version: d8ed2627bdf02c080bf22230dbb337003b7aba2d
subpackages:
- difflib
- name: github.com/stretchr/testify
version: 05e8a0eda380579888eb53c394909df027f06991
subpackages:
- assert
package: github.com/PaddlePaddle/Paddle/go
import:
- package: github.com/PaddlePaddle/recordio
- package: github.com/coreos/etcd
version: ^3.2.1
subpackages:
- clientv3
- clientv3/concurrency
- embed
- etcdserver
- package: github.com/namsral/flag
version: ^1.7.4-pre
- package: github.com/sirupsen/logrus
version: ^1.0.0
- package: github.com/topicai/candy
- package: golang.org/x/crypto
repo: https://github.com/golang/crypto.git
vcs: git
- package: golang.org/x/sys
repo: https://github.com/golang/sys.git
vcs: git
- package: golang.org/x/text
repo: https://github.com/golang/text.git
vcs: git
- package: github.com/satori/go.uuid
version: v1.1.0
- package: github.com/alecthomas/gometalinter
version: v1.2.1
- package: github.com/inconshreveable/log15
version: v2.13
- package: github.com/go-stack/stack
version: v1.6.0
- package: github.com/golang/protobuf
# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
if(WITH_TESTING)
go_test(master_test)
endif()
# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
go_library(paddle_master SHARED DEPS paddle_go_optimizer)
// Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package main
/*
#include <stdlib.h>
#include <string.h>
#include <stdio.h>
#define PADDLE_MASTER_OK 0
#define PADDLE_MASTER_ERROR -1
#define PADDLE_SAVE_MODEL_OK 1
#define PADDLE_SAVE_MODEL_SKIP 0
typedef int paddle_master_client;
*/
import "C"
import (
"strings"
"sync"
"time"
"unsafe"
"github.com/PaddlePaddle/Paddle/go/master"
log "github.com/inconshreveable/log15"
)
var mu sync.Mutex
var handleMap = make(map[C.paddle_master_client]*master.Client)
var curHandle C.paddle_master_client
func init() {
log.Root().SetHandler(
log.LvlFilterHandler(log.LvlWarn, log.CallerStackHandler("%+v", log.StderrHandler)),
)
}
func add(c *master.Client) C.paddle_master_client {
mu.Lock()
defer mu.Unlock()
client := curHandle
curHandle++
handleMap[client] = c
return client
}
func get(client C.paddle_master_client) *master.Client {
mu.Lock()
defer mu.Unlock()
return handleMap[client]
}
func remove(client C.paddle_master_client) *master.Client {
mu.Lock()
defer mu.Unlock()
h := handleMap[client]
delete(handleMap, client)
return h
}
//export paddle_new_etcd_master_client
//
// bufSize is the record buffer size.
func paddle_new_etcd_master_client(etcdEndpoints *C.char, timeout int, bufSize int) C.paddle_master_client {
p := C.GoString(etcdEndpoints)
endpoints := strings.Split(p, ",")
c, err := master.NewClient(
master.WithEtcd(endpoints, time.Duration(timeout)*time.Second),
master.WithBuffer(bufSize),
)
if err != nil {
panic(err)
}
return add(c)
}
//export paddle_new_master_client
//
// bufSize is the record buffer size.
func paddle_new_master_client(addr *C.char, bufSize int) C.paddle_master_client {
a := C.GoString(addr)
c, err := master.NewClient(master.WithAddr(a), master.WithBuffer(bufSize))
if err != nil {
panic(err)
}
return add(c)
}
//export paddle_release_master_client
func paddle_release_master_client(client C.paddle_master_client) {
remove(client)
}
//export paddle_start_get_records
func paddle_start_get_records(client C.paddle_master_client, pass C.int) {
c := get(client)
c.StartGetRecords(int(pass))
}
//export paddle_set_dataset
func paddle_set_dataset(client C.paddle_master_client, path **C.char, size C.int) C.int {
c := get(client)
var paths []string
for i := 0; i < int(size); i++ {
ptr := (**C.char)(unsafe.Pointer(uintptr(unsafe.Pointer(path)) + uintptr(i)*unsafe.Sizeof(*path)))
str := C.GoString(*ptr)
paths = append(paths, str)
}
err := c.SetDataset(paths)
if err != nil {
log.Error("error set dataset",
log.Ctx{"error": err, "paths": paths})
return C.PADDLE_MASTER_ERROR
}
return C.PADDLE_MASTER_OK
}
// paddle_next_record gets the nexts training record.
//
// returns number of bytes of the records if success, -1 if failed, -2 if pass end.
//
//export paddle_next_record
func paddle_next_record(client C.paddle_master_client, record **C.uchar) C.int {
c := get(client)
r, err := c.NextRecord()
if err != nil {
// NOTE: use errors to indicate pass ends
if err.Error() == master.ErrAllTaskFailed.Error() ||
err.Error() == master.ErrNoMoreAvailable.Error() ||
err.Error() == master.ErrPassBefore.Error() {
return -2
}
*record = (*C.uchar)(nil)
return -1
}
if len(r) == 0 {
// Empty record
*record = (*C.uchar)(nil)
return 0
}
size := C.size_t(len(r))
*record = (*C.uchar)(C.malloc(size))
C.memcpy(unsafe.Pointer(*record), unsafe.Pointer(&r[0]), size)
return C.int(size)
}
// paddle_request_save_model requests the master server to approve the
// caller to save the model.
//
// returns 1 if the save the model request is approved, 0 if the
// request is rejected because other trainer is saving the model, -1
// if error happened.
//
//export paddle_request_save_model
func paddle_request_save_model(client C.paddle_master_client, trainerID string, blockMS int) C.int {
c := get(client)
need, err := c.RequestSaveModel(trainerID, time.Duration(blockMS)*time.Millisecond)
if err != nil {
log.Error("error request save model", log.Ctx{"error": err})
return C.PADDLE_MASTER_ERROR
}
if need {
return C.PADDLE_SAVE_MODEL_OK
}
return C.PADDLE_SAVE_MODEL_SKIP
}
//export mem_free
func mem_free(p unsafe.Pointer) {
// "free" may be a better name for this function, but doing so
// will cause calling any function of this library from Python
// ctypes hanging.
C.free(p)
}
func main() {}
// Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package master
import (
"os"
"time"
"github.com/PaddlePaddle/Paddle/go/connection"
"github.com/PaddlePaddle/recordio"
"github.com/coreos/etcd/clientv3"
log "github.com/inconshreveable/log15"
)
// Client is the client of the master server.
type Client struct {
conn *connection.Conn
ch chan record
bufSize int
}
type record struct {
r []byte
err error
}
// WithBuffer sets the client to buffer the training record.
//
// bufSize is the record buffer size. NextRecord will read from this
// buffer.
func WithBuffer(bufSize int) func(*Client) error {
return func(c *Client) error {
if bufSize <= 0 {
return nil
}
c.bufSize = bufSize
return nil
}
}
// WithAddr sets the client to use fixed master address.
func WithAddr(addr string) func(c *Client) error {
return func(c *Client) error {
ch := make(chan string, 1)
ch <- addr
go c.monitorMaster(ch)
return nil
}
}
// WithEtcd sets the client to use etcd for master discovery.
func WithEtcd(endpoints []string, timeout time.Duration) func(*Client) error {
return func(c *Client) error {
var cli *clientv3.Client
f := func() error {
var err error
cli, err = clientv3.New(clientv3.Config{
Endpoints: endpoints,
DialTimeout: timeout,
})
return err
}
for {
err := f()
if err != nil {
log.Warn("create etcd client error", log.Ctx{"error": err})
} else {
break
}
time.Sleep(time.Second)
}
ch := make(chan string, 1)
a, err := GetKey(cli, DefaultAddrPath, timeout)
if err != nil {
return err
}
if a != "" {
// Master is registered, send to the master address
// channel.
ch <- a
}
go watchKey(cli, DefaultAddrPath, ch)
go c.monitorMaster(ch)
return nil
}
}
// NewClient creates a new Client.
func NewClient(opts ...func(*Client) error) (*Client, error) {
c := &Client{}
c.conn = connection.New()
for _, opt := range opts {
err := opt(c)
if err != nil {
return nil, err
}
}
c.ch = make(chan record, c.bufSize)
return c, nil
}
// StartGetRecords must be called at beginning of each pass
func (c *Client) StartGetRecords(passID int) {
go c.getRecords(passID)
}
func (c *Client) getRecords(passID int) {
i := 0
for {
t, err := c.getTask(passID)
if err != nil {
if err.Error() == ErrPassBefore.Error() ||
err.Error() == ErrNoMoreAvailable.Error() ||
err.Error() == ErrAllTaskFailed.Error() {
c.ch <- record{nil, err}
break
}
if i%60 == 0 {
log.Debug("getTask of passID error.",
log.Ctx{"error": err, "passID": passID})
i = 0
}
// if err.Error() == ErrPassAfter.Error()
// wait util last pass finishes
// if other error such as network error
// wait to reconnect or task time out
time.Sleep(time.Second * 3)
i += 3
continue
}
for _, chunk := range t.Chunks {
f, e := os.Open(chunk.Path)
if e != nil {
log.Error("error open chunk", log.Ctx{"error": e})
continue
}
s := recordio.NewRangeScanner(f, &chunk.Index, -1, -1)
for s.Scan() {
c.ch <- record{s.Record(), nil}
}
if s.Err() != nil {
c.ch <- record{nil, s.Err()}
log.Error(
"error scan chunk",
log.Ctx{"error": err, "path": chunk.Path},
)
}
err = f.Close()
if err != nil {
log.Error("error close record file", log.Ctx{"error": err})
}
}
// We treat a task as finished whenever the last data
// instance of the task is read. This is not exactly
// correct, but a reasonable approximation.
err = c.taskFinished(t.Meta.ID)
if err != nil {
log.Error("task finish callback error.", log.Ctx{"error": err})
}
}
}
func (c *Client) monitorMaster(addrCh <-chan string) {
lastMaster := ""
for curMaster := range addrCh {
// connect to the new address once address changed.
if curMaster != lastMaster {
if curMaster == "" {
err := c.conn.Close()
if err != nil {
log.Error("close old master addr error", log.Ctx{"error": err})
}
} else {
err := c.conn.Connect(curMaster)
if err != nil {
log.Error("connect to new master addr error", log.Ctx{"error": err})
// connect to addr failed, set
// to last known addr in order
// to retry next time.
curMaster = lastMaster
}
}
}
lastMaster = curMaster
}
}
// SetDataset sets dataset to dispatch for the master server.
//
// SetDataset can be call multiple times at one pass. But only the first call
// will be honored.
//
// After all tasks are done, another call of SetDataset will start another pass.
func (c *Client) SetDataset(globPaths []string) error {
err := c.conn.Call("Service.SetDataset", globPaths, nil)
return err
}
// getTask gets a new task from the master server.
func (c *Client) getTask(passID int) (Task, error) {
var t Task
err := c.conn.Call("Service.GetTask", passID, &t)
return t, err
}
// TaskFinished tells the master server a task is finished.
func (c *Client) taskFinished(taskID int) error {
return c.conn.Call("Service.TaskFinished", taskID, nil)
}
// TaskFailed tell the master server as task is failed.
func (c *Client) taskFailed(meta TaskMeta) error {
return c.conn.Call("Service.TaskFailed", meta, nil)
}
// NextRecord returns next record in the dataset.
//
// NextRecord will block until the next record is available. It is
// thread-safe.
func (c *Client) NextRecord() ([]byte, error) {
r := <-c.ch
return r.r, r.err
}
// RequestSaveModel requests the master server to approve the caller
// to save the model.
func (c *Client) RequestSaveModel(trainerID string, blockDur time.Duration) (bool, error) {
var need bool
err := c.conn.Call("Service.RequestSaveModel", SaveModelRequest{TrainerID: trainerID, BlockDur: blockDur}, &need)
return need, err
}
// Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package master
import (
"fmt"
"net"
"net/http"
"net/rpc"
"os"
"strconv"
"strings"
"testing"
"time"
"github.com/PaddlePaddle/Paddle/go/connection"
"github.com/PaddlePaddle/recordio"
)
const (
totalTask = 20
chunkPerTask = 10
)
func TestGetFinishTask(t *testing.T) {
const path = "/tmp/master_client_test_0"
l, err := net.Listen("tcp", ":0")
if err != nil {
panic(err)
}
ss := strings.Split(l.Addr().String(), ":")
p, err := strconv.Atoi(ss[len(ss)-1])
if err != nil {
panic(err)
}
go func(l net.Listener) {
s, sErr := NewService(&InMemStore{}, chunkPerTask, time.Second, 1)
if sErr != nil {
panic(sErr)
}
server := rpc.NewServer()
sErr = server.Register(s)
if sErr != nil {
panic(sErr)
}
mux := http.NewServeMux()
mux.Handle(rpc.DefaultRPCPath, server)
sErr = http.Serve(l, mux)
if sErr != nil {
panic(sErr)
}
}(l)
f, err := os.Create(path)
if err != nil {
panic(err)
}
for i := 0; i < totalTask*chunkPerTask; i++ {
w := recordio.NewWriter(f, -1, -1)
_, err = w.Write(nil)
if err != nil {
panic(err)
}
// call Close to force RecordIO writing a chunk.
err = w.Close()
if err != nil {
panic(err)
}
}
err = f.Close()
if err != nil {
panic(err)
}
// Manually intialize client to avoid calling c.getRecords()
c := &Client{}
c.conn = connection.New()
addr := fmt.Sprintf(":%d", p)
ch := make(chan string, 1)
ch <- addr
go c.monitorMaster(ch)
err = c.SetDataset([]string{path})
if err != nil {
panic(err)
}
checkOnePass := func(i int) {
var tasks []Task
for idx := 0; idx < totalTask; idx++ {
task, cErr := c.getTask(i)
if cErr != nil && cErr.Error() != ErrNoMoreAvailable.Error() && cErr.Error() != ErrPassAfter.Error() {
t.Fatalf("error: %v, pass: %d\n", cErr, i)
}
tasks = append(tasks, task)
}
// getting task before task finishes should return error
_, cErr := c.getTask(i)
if cErr == nil {
t.Fatalf("Should get error, pass: %d\n", i)
}
cErr = c.taskFinished(tasks[0].Meta.ID)
if cErr != nil {
t.Fatalf("Error: %v, pass: %d\n", cErr, i)
}
// call taskFailed once won't put the task to failed queue, just ensure
// the call
cErr = c.taskFailed(tasks[0].Meta)
if cErr != nil {
t.Fatalf("Error: %v, pass: %d\n", cErr, i)
}
tasks = tasks[1:]
_, cErr = c.getTask(i)
if cErr != nil && cErr.Error() != ErrNoMoreAvailable.Error() && cErr.Error() != ErrPassAfter.Error() {
t.Fatalf("Should be ErrNoMoreAvailable or ErrPassAfter: %s", cErr)
}
for _, task := range tasks {
cErr = c.taskFinished(task.Meta.ID)
if cErr != nil {
t.Fatal(cErr)
}
}
}
for i := 0; i < 10; i++ {
// init pass data
c.StartGetRecords(i)
checkOnePass(i)
}
}
// Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package master_test
import (
"fmt"
"net"
"net/http"
"net/rpc"
"os"
"runtime"
"strconv"
"strings"
"sync"
"testing"
"time"
"github.com/PaddlePaddle/Paddle/go/master"
"github.com/PaddlePaddle/recordio"
)
// tool function for testing output goroutine ids
func goid() int {
var buf [64]byte
n := runtime.Stack(buf[:], false)
idField := strings.Fields(strings.TrimPrefix(string(buf[:n]), "goroutine "))[0]
id, err := strconv.Atoi(idField)
if err != nil {
panic(fmt.Sprintf("cannot get goroutine id: %v", err))
}
return id
}
func TestNextRecord(t *testing.T) {
const (
path = "/tmp/master_client_TestFull"
total = 50
)
l, err := net.Listen("tcp", ":0")
if err != nil {
panic(err)
}
ss := strings.Split(l.Addr().String(), ":")
p, err := strconv.Atoi(ss[len(ss)-1])
if err != nil {
panic(err)
}
go func(l net.Listener) {
s, err := master.NewService(&master.InMemStore{}, 1, time.Second*60, 1)
if err != nil {
panic(err)
}
server := rpc.NewServer()
err = server.Register(s)
if err != nil {
panic(err)
}
mux := http.NewServeMux()
mux.Handle(rpc.DefaultRPCPath, server)
err = http.Serve(l, mux)
if err != nil {
panic(err)
}
}(l)
f, err := os.Create(path)
if err != nil {
panic(err)
}
w := recordio.NewWriter(f, 1, -1)
for i := 0; i < total; i++ {
_, err = w.Write([]byte{byte(i)})
if err != nil {
panic(err)
}
}
err = w.Close()
if err != nil {
panic(err)
}
err = f.Close()
if err != nil {
panic(err)
}
// start several client to test task fetching
var wg sync.WaitGroup
for i := 0; i < 4; i++ {
wg.Add(1)
// test for multiple concurrent clients
go func() {
defer wg.Done()
// each go-routine needs a single client connection instance
c, e := master.NewClient(master.WithAddr(fmt.Sprintf(":%d", p)), master.WithBuffer(1))
if e != nil {
t.Fatal(e)
}
e = c.SetDataset([]string{path})
if e != nil {
panic(e)
}
// test for n passes
for pass := 0; pass < 10; pass++ {
c.StartGetRecords(pass)
received := make(map[byte]bool)
taskid := 0
for {
r, e := c.NextRecord()
if e != nil {
// ErrorPassAfter will wait, else break for next pass
if e.Error() == master.ErrPassBefore.Error() ||
e.Error() == master.ErrNoMoreAvailable.Error() {
break
}
t.Fatal(pass, taskid, "Read error:", e)
}
if len(r) != 1 {
t.Fatal(pass, taskid, "Length should be 1.", r)
}
if received[r[0]] {
t.Fatal(pass, taskid, "Received duplicate.", received, r)
}
taskid++
received[r[0]] = true
}
}
}()
}
wg.Wait()
}
// Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package master
import (
"context"
"time"
"github.com/coreos/etcd/clientv3"
"github.com/coreos/etcd/clientv3/concurrency"
log "github.com/inconshreveable/log15"
)
const (
// DefaultLockPath is the default etcd master lock path.
DefaultLockPath = "/master/lock"
// DefaultStatePath is the default etcd key for master state.
DefaultStatePath = "/master/state"
// DefaultAddrPath is the default etcd key for master address.
DefaultAddrPath = "/master/addr"
)
// EtcdClient is the etcd client that the master uses for fault
// tolerance and service registry.
type EtcdClient struct {
lockPath string
statePath string
client *clientv3.Client
lock *concurrency.Mutex
sess *concurrency.Session
}
// NewEtcdClient creates a new EtcdClient.
func NewEtcdClient(endpoints []string, addr string, lockPath, addrPath, statePath string, ttlSec int) (*EtcdClient, error) {
log.Debug("Connecting to etcd", log.Ctx{"endpoint": endpoints})
cli, err := clientv3.New(clientv3.Config{
Endpoints: endpoints,
DialTimeout: dialTimeout,
})
if err != nil {
return nil, err
}
sess, err := concurrency.NewSession(cli, concurrency.WithTTL(ttlSec))
if err != nil {
return nil, err
}
lock := concurrency.NewMutex(sess, lockPath)
// It's fine for the lock to get stuck, in this case we have
// multiple master servers running (only configured to have
// one master running, but split-brain problem may cause
// multiple master servers running), and the cluster management
// software will kill one of them.
log.Info("Trying to acquire lock.", log.Ctx{"path": lockPath})
err = lock.Lock(context.TODO())
if err != nil {
return nil, err
}
log.Info("Successfully acquired lock at %s.", log.Ctx{"path": lockPath})
put := clientv3.OpPut(addrPath, addr)
resp, err := cli.Txn(context.Background()).If(lock.IsOwner()).Then(put).Commit()
if err != nil {
return nil, err
}
if !resp.Succeeded {
log.Crit("No longer owns the master lock. Exiting.")
panic("No longer owns the master lock. Exiting.")
}
e := &EtcdClient{
lockPath: lockPath,
statePath: statePath,
client: cli,
lock: lock,
sess: sess,
}
return e, nil
}
// Save saves the state into the etcd.
func (e *EtcdClient) Save(state []byte) error {
ctx := context.TODO()
put := clientv3.OpPut(e.statePath, string(state))
resp, err := e.client.Txn(ctx).If(e.lock.IsOwner()).Then(put).Commit()
if err != nil {
return err
}
if !resp.Succeeded {
log.Error("No longer owns the lock, trying to lock again")
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
err := e.lock.Lock(ctx)
cancel()
if err != nil {
// We lost the master lock and can not acquire
// it back, it means some other master is
// already started. We don't want cluster
// management system to kill the master server
// who is holding the lock and running
// correctly. So the most feasible solution is
// to kill current master server. The current
// state is not saved, but the trainer's RPC
// call will fail, so the trainer will retry.
log.Crit("Could not acquire the lock at %s: %v. Exiting.", log.Ctx{"path": e.lockPath, "error": err})
panic("Could not acquire the lock at %s: %v. Exiting.")
}
log.Info("Successfully acquired lock at %s.", e.lockPath)
return e.Save(state)
}
return nil
}
// Load loads the state from etcd.
func (e *EtcdClient) Load() ([]byte, error) {
ctx := context.TODO()
get := clientv3.OpGet(e.statePath)
resp, err := e.client.Txn(ctx).If(e.lock.IsOwner()).Then(get).Commit()
if err != nil {
return nil, err
}
if !resp.Succeeded {
log.Error("No longer owns the lock, trying to lock and load again.")
err = e.lock.Lock(context.Background())
if err != nil {
return nil, err
}
return e.Load()
}
kvs := resp.Responses[0].GetResponseRange().Kvs
if len(kvs) == 0 {
// No state exists
return nil, nil
}
state := kvs[0].Value
return state, nil
}
// Shutdown shuts down the etcd client gracefully.
func (e *EtcdClient) Shutdown() error {
err := e.sess.Close()
newErr := e.client.Close()
if newErr != nil {
if err == nil {
err = newErr
} else {
log.Error("shutdown error", log.Ctx{"error": newErr})
}
}
return err
}
// GetKey gets the value by the specify key.
func GetKey(c *clientv3.Client, key string, timeout time.Duration) (string, error) {
ctx, cancel := context.WithTimeout(context.Background(), timeout)
resp, err := c.Get(ctx, key)
cancel()
if err != nil {
return "", err
}
kvs := resp.Kvs
if len(kvs) == 0 {
return "", nil
}
v := kvs[0].Value
return string(v), nil
}
// watchKey watches the specify key and send to valChan if there is some event.
func watchKey(c *clientv3.Client, key string, valChan chan<- string) {
rch := c.Watch(context.Background(), key)
for wresp := range rch {
for _, ev := range wresp.Events {
// if received event is DELETE, the value will be an empty string
log.Info("received event.", log.Ctx{"type": ev.Type, "key": ev.Kv.Key, "value": ev.Kv.Value})
valChan <- string(ev.Kv.Value)
}
}
}
// Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package master
import "sync"
// InMemStore is an in memory implementation of Store interface.
//
// It does not tolerate the fault that causes the program to crash.
type InMemStore struct {
mu sync.Mutex
buf []byte
}
// Save saves the state into the in-memory store.
func (m *InMemStore) Save(state []byte) error {
m.mu.Lock()
defer m.mu.Unlock()
m.buf = state
return nil
}
// Load loads the state from the in-memory store.
func (m *InMemStore) Load() ([]byte, error) {
m.mu.Lock()
defer m.mu.Unlock()
return m.buf, nil
}
// Shutdown shuts down the in mem store.
func (m *InMemStore) Shutdown() error {
return nil
}
// Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package master
import (
"bytes"
"compress/gzip"
"encoding/gob"
"errors"
"math/rand"
"os"
"path/filepath"
"sync"
"time"
log "github.com/inconshreveable/log15"
"github.com/PaddlePaddle/recordio"
)
const (
dialTimeout = 5 * time.Second
)
// ErrAllTaskFailed occur when tasks are in done or failed state.
var ErrAllTaskFailed = errors.New("all task finished")
// ErrNoMoreAvailable occur when no task in todo and yet not all done or fail.
var ErrNoMoreAvailable = errors.New("no more available task")
// ErrPassBefore client side pass number does not match with master counter.
var ErrPassBefore = errors.New("pass number smaller than master")
// ErrPassAfter client side pass number does not match with master counter.
var ErrPassAfter = errors.New("pass number larger than master")
// Store is the interface for save and load the master state.
type Store interface {
Save([]byte) error
Load() ([]byte, error)
Shutdown() error
}
// Chunk is a chunk of data consisted of several data instances.
type Chunk struct {
Path string
Index recordio.Index // chunk index
}
// TaskMeta is a struct which stores task's meta info.
type TaskMeta struct {
ID int
Epoch int
}
// Task is the basic unit of data instances assigned to trainers.
type Task struct {
Meta TaskMeta
Chunks []Chunk
}
type taskEntry struct {
Task Task
// A task fails if it's timeout or trainer reports it exits unnormally.
NumFailure int
}
type masterState struct {
Todo []taskEntry
Pending map[int]taskEntry // map from task ID to task entry
Done []taskEntry
Failed []taskEntry
CurPass int
}
// Service is the master server service.
type Service struct {
chunksPerTask int
timeoutDur time.Duration
failureMax int
store Store
ready chan struct{}
initDone bool
mu sync.Mutex
// State to be persisted to snapshot.
state masterState
// The trainer that is currently saving model. This state is
// transient, does not need to be persisted to snapshot.
savingTrainer string
}
func partition(chunks []Chunk, chunksPerTask int) []taskEntry {
// generate uniq id across job using nanosecond + randint + counter
// FIXME(typhoonzero): this is a workaround, use uuid
randStart := rand.Int()
counter := 0
timestamp := time.Now().Nanosecond()
id := timestamp + randStart + counter
if chunksPerTask <= 0 {
chunksPerTask = 1
}
var result []taskEntry
var cur taskEntry
for i, c := range chunks {
if i%chunksPerTask == 0 && len(cur.Task.Chunks) > 0 {
cur.Task.Meta.ID = id
counter++
id = timestamp + randStart + counter
result = append(result, cur)
cur.Task.Chunks = nil
}
cur.Task.Chunks = append(cur.Task.Chunks, c)
}
if len(cur.Task.Chunks) > 0 {
cur.Task.Meta.ID = id
result = append(result, cur)
}
return result
}
// NewService creates a new service.
func NewService(store Store, chunksPerTask int, timeoutDur time.Duration, failureMax int) (*Service, error) {
s := &Service{}
s.chunksPerTask = chunksPerTask
s.timeoutDur = timeoutDur
s.failureMax = failureMax
s.state = masterState{}
s.state.Pending = make(map[int]taskEntry)
s.ready = make(chan struct{})
s.store = store
recovered, err := s.recover()
if err != nil {
return nil, err
}
if recovered {
// Recovered. Now the state is already initialized,
// and the master is ready.
s.initDone = true
close(s.ready)
log.Info("Master recovered from saved state.")
}
return s, nil
}
// recover recovers service state from etcd.
func (s *Service) recover() (bool, error) {
state, err := s.store.Load()
if err != nil {
return false, err
}
if state == nil {
log.Info("No state exists, not recovered.")
return false, nil
}
log.Info("Loaded snapshot.", log.Ctx{"size": len(state)})
gr, err := gzip.NewReader(bytes.NewReader(state))
if err != nil {
return false, err
}
dec := gob.NewDecoder(gr)
var tqs masterState
err = dec.Decode(&tqs)
if err != nil {
return false, err
}
err = gr.Close()
if err != nil {
// Only close failed, recover actually succeed, so
// just log error.
log.Error("error close recover file.", log.Ctx{"error": err})
}
s.state = tqs
log.Info("Master recovered from snapshot, scheduling pending task timeout check.", s.logCtx())
for _, t := range s.state.Pending {
time.AfterFunc(s.timeoutDur, s.checkTimeoutFunc(t.Task.Meta.ID, t.Task.Meta.Epoch))
}
return true, nil
}
// snapshot *must* be called with s.mu being held.
func (s *Service) snapshot() error {
// TODO(helin): etcd request has a size limit, so the snapshot
// size is limited by the max request size. We should either
// divide the snapshot into smaller chunks and save under
// different keys, or configure the request size to be big
// enough:
// https://github.com/coreos/etcd/blob/2f84f3d8d8ed8f9537ab6ffa44a3a1c7eddfa9b1/embed/config.go#L44
var buf bytes.Buffer
gw := gzip.NewWriter(&buf)
enc := gob.NewEncoder(gw)
err := enc.Encode(s.state)
if err != nil {
return err
}
err = gw.Close()
if err != nil {
return err
}
state := buf.Bytes()
log.Info("Saving snapshot.", log.Ctx{"size bytes": len(state)})
return s.store.Save(state)
}
func readChunks(globPaths []string) ([]Chunk, error) {
var chunks []Chunk
var paths []string
for _, s := range globPaths {
match, err := filepath.Glob(s)
if err != nil {
return nil, err
}
paths = append(paths, match...)
}
if len(paths) == 0 {
return nil, errors.New("no valid dataset specified")
}
for _, path := range paths {
f, err := os.Open(path)
if err != nil {
return nil, err
}
index, err := recordio.LoadIndex(f)
if err != nil {
return nil, err
}
err = f.Close()
if err != nil {
return nil, err
}
count := index.NumChunks()
log.Info("reading chunks.", log.Ctx{"path": path, "num chunks": count})
for i := 0; i < count; i++ {
chunk := Chunk{
Path: path,
Index: *index.ChunkIndex(i),
}
chunks = append(chunks, chunk)
}
}
return chunks, nil
}
// SetDataset sets dataset to dispatch for the master server.
//
// SetDataset can be call multiple times. But only the first call will
// be honored.
func (s *Service) SetDataset(globPaths []string, _ *int) error {
if len(globPaths) == 0 {
return errors.New("no dataset specified")
}
s.mu.Lock()
defer s.mu.Unlock()
if s.initDone {
// Already initialized. All trainer will call
// SetDataset, but we only handle the first one. Treat
// other calls as successful but do nothing.
return nil
}
chunks, err := readChunks(globPaths)
if err != nil {
return err
}
s.state.Todo = partition(chunks, s.chunksPerTask)
err = s.snapshot()
if err != nil {
log.Error("snapshot error", log.Ctx{"error": err})
return err
}
close(s.ready)
s.initDone = true
return nil
}
// processFailedTask retry s.failureMax times for failed task.
// return true if all task are done or failed.
func (s *Service) processFailedTask(t taskEntry, epoch int) {
if t.Task.Meta.Epoch != epoch {
// new epoch, task launched after the
// schedule of this timeout check or failed status report.
return
}
defer func() {
err := s.snapshot()
if err != nil {
log.Error("snapshot error", log.Ctx{"error": err})
}
}()
delete(s.state.Pending, t.Task.Meta.ID)
t.NumFailure++
if t.NumFailure > s.failureMax {
log.Warn("Task failed to many times, discard.", log.Ctx{"task": t.Task, "num failed": t.NumFailure})
s.state.Failed = append(s.state.Failed, t)
return
}
log.Warn("Task failed, re-dispatch.", log.Ctx{"task": t.Task, "num failed": t.NumFailure})
s.state.Todo = append(s.state.Todo, t)
return
}
func (s *Service) checkTimeoutFunc(taskID int, epoch int) func() {
return func() {
s.mu.Lock()
defer s.mu.Unlock()
t, ok := s.state.Pending[taskID]
if !ok {
return
}
s.processFailedTask(t, epoch)
}
}
// must be called with lock held.
func (s *Service) logCtx() log.Ctx {
return log.Ctx{
"todoLen": len(s.state.Todo),
"pendingLen": len(s.state.Pending),
"doneLen": len(s.state.Done),
"failedLen": len(s.state.Failed),
"curPass": s.state.CurPass,
}
}
// GetTask gets a new task from the service.
// passID is the client side pass count
func (s *Service) GetTask(passID int, task *Task) error {
select {
case <-s.ready:
}
s.mu.Lock()
defer s.mu.Unlock()
if passID < s.state.CurPass {
return ErrPassBefore
}
if passID > s.state.CurPass {
// Client may get run to pass after master when one client faster than the
// other
return ErrPassAfter
}
if len(s.state.Todo) == 0 {
if len(s.state.Done) == 0 && len(s.state.Pending) == 0 {
log.Warn("All tasks failed, may start next pass", s.logCtx())
return ErrAllTaskFailed
}
log.Warn("No more available task.", s.logCtx())
return ErrNoMoreAvailable
}
t := s.state.Todo[0]
t.Task.Meta.Epoch++
s.state.Todo = s.state.Todo[1:]
s.state.Pending[t.Task.Meta.ID] = t
err := s.snapshot()
if err != nil {
return err
}
*task = t.Task
ctx := s.logCtx()
ctx["task meta"] = t.Task.Meta
log.Info("Task dispatched.", ctx)
time.AfterFunc(s.timeoutDur, s.checkTimeoutFunc(t.Task.Meta.ID, t.Task.Meta.Epoch))
return nil
}
// TaskFinished tell the service that a task is finished.
func (s *Service) TaskFinished(taskID int, dummy *int) error {
select {
case <-s.ready:
}
s.mu.Lock()
defer s.mu.Unlock()
t, ok := s.state.Pending[taskID]
if !ok {
ctx := s.logCtx()
ctx["task id"] = taskID
log.Warn("Pending task not found.", ctx)
return nil
}
// task finished, reset timeout
t.NumFailure = 0
s.state.Done = append(s.state.Done, t)
delete(s.state.Pending, taskID)
ctx := s.logCtx()
ctx["task id"] = taskID
log.Info("Task finished.", ctx)
if len(s.state.Todo) == 0 && len(s.state.Pending) == 0 {
// increase master side pass count if all tasks finished
s.state.CurPass++
s.state.Todo = append(s.state.Done, s.state.Failed...)
s.state.Done = []taskEntry{}
// TODO(typhoonzero): deal with failed tasks
s.state.Failed = []taskEntry{}
ctx := s.logCtx()
ctx["new pass"] = s.state.CurPass
log.Warn("all task finished, add new pass data.", ctx)
}
err := s.snapshot()
if err != nil {
log.Error("snapshot error", log.Ctx{"error": err})
}
return err
}
// TaskFailed tells the service that a task is failed.
func (s *Service) TaskFailed(meta TaskMeta, dummy *int) error {
select {
case <-s.ready:
}
s.mu.Lock()
defer s.mu.Unlock()
t, ok := s.state.Pending[meta.ID]
if !ok {
log.Warn("TaskFailed:Pending task not found.", log.Ctx{"task": t.Task.Meta})
return nil
}
s.processFailedTask(t, meta.Epoch)
return nil
}
// SaveModelRequest is the request for saving model
type SaveModelRequest struct {
TrainerID string
BlockDur time.Duration
}
// RequestSaveModel requests the master server to approve the caller
// to save the model.
func (s *Service) RequestSaveModel(req SaveModelRequest, need *bool) error {
s.mu.Lock()
defer s.mu.Unlock()
if req.TrainerID == "" {
return errors.New("trainer id is empty")
}
if s.savingTrainer == "" {
*need = true
} else {
if req.TrainerID == s.savingTrainer {
// save trainer asked to save model again
*need = true
} else {
*need = false
}
}
if *need {
s.savingTrainer = req.TrainerID
time.AfterFunc(req.BlockDur, func() {
s.mu.Lock()
s.savingTrainer = ""
s.mu.Unlock()
})
}
return nil
}
// Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package master
import "testing"
func TestPartitionCount(t *testing.T) {
cs := make([]Chunk, 100)
ts := partition(cs, 5)
if len(ts) != 20 {
t.Error(len(ts))
}
cs = make([]Chunk, 101)
ts = partition(cs, 5)
if len(ts) != 21 {
t.Error(len(ts))
}
ts = partition(cs, 1)
if len(ts) != 101 {
t.Error(len(ts))
}
ts = partition(cs, 0)
if len(ts) != 101 {
t.Error(len(ts))
}
}
func TestPartionIndex(t *testing.T) {
cs := make([]Chunk, 100)
ts := partition(cs, 20)
for i := range ts {
// test auto increament ids
if i > 0 && ts[i].Task.Meta.ID != ts[i-1].Task.Meta.ID+1 {
t.Error(ts[i], i)
}
}
}
package master_test
import (
"io/ioutil"
"net/url"
"os"
"strings"
"testing"
"time"
"github.com/PaddlePaddle/Paddle/go/master"
"github.com/coreos/etcd/clientv3"
"github.com/coreos/etcd/embed"
"github.com/stretchr/testify/assert"
)
func TestNewServiceWithEtcd(t *testing.T) {
// setup an embed etcd server
etcdDir, err := ioutil.TempDir("", "")
if err != nil {
t.Fatal(err)
}
cfg := embed.NewConfig()
lpurl, _ := url.Parse("http://localhost:0")
lcurl, _ := url.Parse("http://localhost:0")
cfg.LPUrls = []url.URL{*lpurl}
cfg.LCUrls = []url.URL{*lcurl}
cfg.Dir = etcdDir
e, err := embed.StartEtcd(cfg)
if err != nil {
t.Fatal(err)
}
defer func() {
e.Close()
if err := os.RemoveAll(etcdDir); err != nil {
t.Fatal(err)
}
}()
<-e.Server.ReadyNotify()
port := strings.Split(e.Clients[0].Addr().String(), ":")[1]
endpoint := "127.0.0.1:" + port
ep := []string{endpoint}
masterAddr := "127.0.0.1:3306"
store, err := master.NewEtcdClient(ep, masterAddr, master.DefaultLockPath, master.DefaultAddrPath, master.DefaultStatePath, 30)
if err != nil {
t.Fatal(err)
}
_, err = master.NewService(store, 10, 10, 3)
if err != nil {
t.Fatal(err)
}
cli, err := clientv3.New(clientv3.Config{
Endpoints: ep,
DialTimeout: 3 * time.Second,
})
if err != nil {
t.Fatal(err)
}
v, err := master.GetKey(cli, master.DefaultAddrPath, 3*time.Second)
if err != nil {
t.Fatal(err)
}
if err := cli.Close(); err != nil {
t.Fatal(err)
}
// test master process registry itself into etcd server.
assert.Equal(t, masterAddr, v, "master process should registry itself into etcd server.")
}
# Ignore everything in this directory
*
# Except this file
!.gitignore
# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
if(WITH_TESTING)
go_test(pserver_test DEPS paddle_go_optimizer gen_proto_go)
endif()
# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
if(WITH_TESTING)
go_test(pserver_client_test DEPS paddle_go_optimizer)
endif()
# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
cc_library(paddle_go_optimizer DEPS paddle_optimizer paddle_proto glog gflags protobuf)
target_link_libraries(paddle_go_optimizer stdc++ m)
# Copy library to the required place.
# See: go/pserver/optimizer.go:
# // #cgo LDFLAGS: ${SRCDIR}/client/c/libpaddle_go_optimizer.a -lstdc++ -lm
add_custom_command(TARGET paddle_go_optimizer POST_BUILD
COMMAND cp "${CMAKE_CURRENT_BINARY_DIR}/libpaddle_go_optimizer.a" "${CMAKE_CURRENT_SOURCE_DIR}"
)
go_library(paddle_pserver_cclient STATIC DEPS paddle_go_optimizer)
if(WITH_TESTING)
# FIXME: this test requires pserver which is not managed by the test
# we need some kind of e2e testing machanism.
# add_subdirectory(test)
endif()
// Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package main
/*
#include <string.h>
typedef enum {
PADDLE_ELEMENT_TYPE_INT32 = 0,
PADDLE_ELEMENT_TYPE_UINT32 = 1,
PADDLE_ELEMENT_TYPE_INT64 = 2,
PADDLE_ELEMENT_TYPE_UINT64 = 3,
PADDLE_ELEMENT_TYPE_FLOAT32 = 4,
PADDLE_ELEMENT_TYPE_FLOAT64 = 5,
} paddle_element_type;
typedef struct {
char* name;
paddle_element_type element_type;
unsigned char* content;
int content_len;
} paddle_parameter, paddle_gradient;
typedef int paddle_pserver_client;
#define PSERVER_ERROR -1
#define PSERVER_OK 0
*/
import "C"
import (
"strings"
"sync"
"unsafe"
"github.com/PaddlePaddle/Paddle/go/pserver"
"github.com/PaddlePaddle/Paddle/go/pserver/client"
log "github.com/inconshreveable/log15"
)
func init() {
log.Root().SetHandler(
log.LvlFilterHandler(log.LvlWarn, log.CallerStackHandler("%+v", log.StderrHandler)),
)
}
var mu sync.Mutex
var handleMap = make(map[C.paddle_pserver_client]*client.Client)
var curHandle C.paddle_pserver_client
func add(c *client.Client) C.paddle_pserver_client {
mu.Lock()
defer mu.Unlock()
cli := curHandle
curHandle++
handleMap[cli] = c
return cli
}
func get(client C.paddle_pserver_client) *client.Client {
mu.Lock()
defer mu.Unlock()
return handleMap[client]
}
func remove(client C.paddle_pserver_client) *client.Client {
mu.Lock()
defer mu.Unlock()
h := handleMap[client]
delete(handleMap, client)
return h
}
func cArrayToSlice(p unsafe.Pointer, len int) []byte {
if p == nil {
return nil
}
// create a Go clice backed by a C array, reference:
// https://github.com/golang/go/wiki/cgo#turning-c-arrays-into-go-slices
//
// Go garbage collector will not interact with this data, need
// to be freed properly.
return (*[1 << 30]byte)(p)[:len:len]
}
type selector bool
func (s selector) Select() (bool, error) {
return bool(s), nil
}
func (s selector) Done() error {
return nil
}
type lister []client.Server
func (l lister) List() []client.Server {
return l
}
//export paddle_new_pserver_client
func paddle_new_pserver_client(addrs *C.char, selected int) C.paddle_pserver_client {
a := C.GoString(addrs)
as := strings.Split(a, ",")
servers := make([]client.Server, len(as))
for i := range as {
servers[i].Index = i
servers[i].Addr = as[i]
}
c := client.NewClient(lister(servers), len(as), selector(selected != 0))
return add(c)
}
//export paddle_new_etcd_pserver_client
func paddle_new_etcd_pserver_client(etcdEndpoints *C.char) C.paddle_pserver_client {
addr := C.GoString(etcdEndpoints)
etcdClient := client.NewEtcd(addr)
c := client.NewClient(etcdClient, etcdClient.Desired(), etcdClient)
return add(c)
}
//export paddle_pserver_client_release
func paddle_pserver_client_release(client C.paddle_pserver_client) {
remove(client)
}
// paddle_begin_init_params tells trainer if it needs to init the
// parameters.
//
// returns 1 if the trainer needs to init the parameters. 0 if the
// trainer does not need to init the parameters.
//
//export paddle_begin_init_params
func paddle_begin_init_params(client C.paddle_pserver_client) C.int {
c := get(client)
selected, err := c.BeginInitParams()
if err != nil {
panic(err)
}
if selected {
return 1
}
return 0
}
//export paddle_init_param
func paddle_init_param(client C.paddle_pserver_client, param C.paddle_parameter, paramConfig unsafe.Pointer, configLen C.int) C.int {
et := pserver.ElementType(param.element_type)
name := C.GoString(param.name)
content := cArrayToSlice(unsafe.Pointer(param.content), int(param.content_len))
pc := pserver.ParameterWithConfig{
Param: pserver.Parameter{Name: name, ElementType: et, Content: content},
Config: cArrayToSlice(paramConfig, int(configLen)),
}
c := get(client)
err := c.InitParam(pc)
if err != nil {
if err.Error() == pserver.AlreadyInitialized {
log.Warn(
"parameter already initialized, treat paddle_init_param as successful.",
log.Ctx{"parameter": name},
)
return C.PSERVER_OK
}
log.Error("error init param", log.Ctx{"error": err})
return C.PSERVER_ERROR
}
return C.PSERVER_OK
}
//export paddle_finish_init_params
func paddle_finish_init_params(client C.paddle_pserver_client) C.int {
c := get(client)
err := c.FinishInitParams()
if err != nil {
if err.Error() == pserver.AlreadyInitialized {
log.Warn("parameters already initialized, treat paddle_finish_init_params as successful.")
return C.PSERVER_OK
}
log.Error("error finish init params", log.Ctx{"error": err})
return C.PSERVER_ERROR
}
return C.PSERVER_OK
}
//export paddle_send_grads
func paddle_send_grads(client C.paddle_pserver_client, grads **C.paddle_gradient, total C.int) C.int {
var gs []pserver.Gradient
for i := 0; i < int(total); i++ {
grad := *(**C.paddle_gradient)(unsafe.Pointer((uintptr(unsafe.Pointer(grads)) + uintptr(i)*unsafe.Sizeof(*grads))))
et := pserver.ElementType(grad.element_type)
name := C.GoString(grad.name)
content := cArrayToSlice(unsafe.Pointer(grad.content), int(grad.content_len))
gs = append(gs, pserver.Gradient{Name: name, ElementType: et, Content: content})
}
c := get(client)
err := c.SendGrads(gs)
if err != nil {
log.Error("error send grads", log.Ctx{"error": err})
return C.PSERVER_ERROR
}
return C.PSERVER_OK
}
//export paddle_get_params
func paddle_get_params(client C.paddle_pserver_client, dst **C.paddle_parameter, total C.int) C.int {
var ns []string
for i := 0; i < int(total); i++ {
param := *(**C.paddle_parameter)(unsafe.Pointer((uintptr(unsafe.Pointer(dst)) + uintptr(i)*unsafe.Sizeof(*dst))))
ns = append(ns, C.GoString(param.name))
}
c := get(client)
ps, err := c.GetParams(ns)
if err != nil {
log.Error("error get params", log.Ctx{"error": err})
return C.PSERVER_ERROR
}
if len(ps) != len(ns) {
pn := make([]string, len(ps))
for i, p := range ps {
pn[i] = p.Name
}
log.Error(
"pserver returned wrong number of parameters.",
log.Ctx{
"Requested": strings.Join(pn, ", "),
"Returned": strings.Join(ns, ", "),
},
)
return C.PSERVER_ERROR
}
for i := range ps {
if ns[i] != ps[i].Name {
pn := make([]string, len(ps))
for i, p := range ps {
pn[i] = p.Name
}
log.Error(
"pserver returned wrong parameters, or not in requested order.",
log.Ctx{
"Requested": strings.Join(pn, ", "),
"Returned": strings.Join(ns, ", "),
},
)
return C.PSERVER_ERROR
}
}
for i := 0; i < int(total); i++ {
p := ps[i]
param := *(**C.paddle_parameter)(unsafe.Pointer((uintptr(unsafe.Pointer(dst)) + uintptr(i)*unsafe.Sizeof(*dst))))
if unsafe.Pointer(param) == nil {
log.Error("must pre-allocate parameter.")
return C.PSERVER_ERROR
}
if unsafe.Pointer(param.content) != nil {
if int(param.content_len) != len(p.Content) {
log.Error(
"the pre-allocated content len does not match parameter content len.",
log.Ctx{
"Pre-allocated len": param.content_len,
"Returned len": len(p.Content),
},
)
return C.PSERVER_ERROR
}
}
C.memcpy(unsafe.Pointer(param.content), unsafe.Pointer(&p.Content[0]), C.size_t(len(p.Content)))
param.content_len = C.int(len(p.Content))
param.element_type = C.paddle_element_type(p.ElementType)
}
return C.PSERVER_OK
}
func main() {} // Required but ignored
# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
cc_test(test_cclient SRCS test_cclient.c DEPS paddle_pserver_cclient paddle_go_optimizer)
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <stdio.h>
#include <stdlib.h>
#include "libpaddle_pserver_cclient.h"
// TODO(helin): Fix: gtest using cmake is not working, using this
// hacky way for now.
#define fail() \
fprintf(stderr, "info: %s:%d: ", __FILE__, __LINE__); \
exit(-1);
void sendGrads(paddle_pserver_client c) {
unsigned char grad_a[2000] = {2};
unsigned char grad_b[3000] = {3};
paddle_gradient grad1 = {
"param_a", PADDLE_ELEMENT_TYPE_FLOAT32, grad_a, 2000};
paddle_gradient grad2 = {
"param_b", PADDLE_ELEMENT_TYPE_FLOAT32, grad_b, 3000};
paddle_gradient *grads[2] = {&grad1, &grad2};
if (paddle_send_grads(c, grads, 2)) {
fail();
}
}
void getParams(paddle_pserver_client c) {
paddle_parameter param_a;
paddle_parameter param_b;
char name_a[] = "param_a";
char name_b[] = "param_b";
// Must pre-allocate the prameter content before calling paddle_get_params.
unsigned char content_a[2000] = {};
unsigned char content_b[3000] = {};
param_a.element_type = PADDLE_ELEMENT_TYPE_FLOAT32;
param_a.name = name_a;
param_a.content = content_a;
param_a.content_len = 2000;
param_b.element_type = PADDLE_ELEMENT_TYPE_FLOAT32;
param_b.name = name_b;
param_b.content = content_b;
param_b.content_len = 3000;
paddle_parameter *params[2] = {&param_a, &param_b};
if (paddle_get_params(c, params, 2)) {
fail();
}
}
int main() {
char addr[] = "localhost:3000";
paddle_pserver_client c = paddle_new_pserver_client(addr, 1);
char *config_proto;
size_t config_proto_len = 0;
ssize_t nread;
FILE *fp = fopen("testdata/optimizer.pb", "r");
if (!fp) {
fail();
}
while ((nread = getline(&config_proto, &config_proto_len, fp)) != -1) {
printf("%s", config_proto);
}
fclose(fp);
retry:
if (paddle_begin_init_params(c)) {
paddle_parameter param;
char name_a[] = "param_a";
char name_b[] = "param_b";
unsigned char content_a[2000] = {1};
unsigned char content_b[3000] = {0};
param.element_type = PADDLE_ELEMENT_TYPE_FLOAT32;
param.name = name_a;
param.content = content_a;
param.content_len = 2000;
int error =
paddle_init_param(c, param, (void *)config_proto, config_proto_len);
if (error != 0) {
goto retry;
}
param.element_type = PADDLE_ELEMENT_TYPE_FLOAT32;
param.name = name_b;
param.content = content_b;
param.content_len = 3000;
error = paddle_init_param(c, param, (void *)config_proto, config_proto_len);
if (error != 0) {
goto retry;
}
error = paddle_finish_init_params(c);
if (error != 0) {
goto retry;
}
}
int i;
for (i = 0; i < 100; i++) {
sendGrads(c);
getParams(c);
}
return 0;
}
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle.v2 as paddle
import gzip
def softmax_regression(img):
predict = paddle.layer.fc(input=img,
size=10,
act=paddle.activation.Softmax())
return predict
def multilayer_perceptron(img):
# The first fully-connected layer
hidden1 = paddle.layer.fc(input=img, size=128, act=paddle.activation.Relu())
# The second fully-connected layer and the according activation function
hidden2 = paddle.layer.fc(input=hidden1,
size=64,
act=paddle.activation.Relu())
# The thrid fully-connected layer, note that the hidden size should be 10,
# which is the number of unique digits
predict = paddle.layer.fc(input=hidden2,
size=10,
act=paddle.activation.Softmax())
return predict
def convolutional_neural_network(img):
# first conv layer
conv_pool_1 = paddle.networks.simple_img_conv_pool(
input=img,
filter_size=5,
num_filters=20,
num_channel=1,
pool_size=2,
pool_stride=2,
act=paddle.activation.Tanh())
# second conv layer
conv_pool_2 = paddle.networks.simple_img_conv_pool(
input=conv_pool_1,
filter_size=5,
num_filters=50,
num_channel=20,
pool_size=2,
pool_stride=2,
act=paddle.activation.Tanh())
# The first fully-connected layer
fc1 = paddle.layer.fc(input=conv_pool_2,
size=128,
act=paddle.activation.Tanh())
# The softmax layer, note that the hidden size should be 10,
# which is the number of unique digits
predict = paddle.layer.fc(input=fc1,
size=10,
act=paddle.activation.Softmax())
return predict
def main():
paddle.init(use_gpu=False, trainer_count=1)
# define network topology
images = paddle.layer.data(
name='pixel', type=paddle.data_type.dense_vector(784))
label = paddle.layer.data(
name='label', type=paddle.data_type.integer_value(10))
# Here we can build the prediction network in different ways. Please
# choose one by uncomment corresponding line.
predict = softmax_regression(images)
#predict = multilayer_perceptron(images)
#predict = convolutional_neural_network(images)
cost = paddle.layer.classification_cost(input=predict, label=label)
parameters = paddle.parameters.create(cost)
optimizer = paddle.optimizer.Momentum(
learning_rate=0.1 / 128.0,
momentum=0.9,
regularization=paddle.optimizer.L2Regularization(rate=0.0005 * 128))
trainer = paddle.trainer.SGD(cost=cost,
parameters=parameters,
update_equation=optimizer,
is_local=False,
pserver_spec="localhost:3000")
lists = []
def event_handler(event):
if isinstance(event, paddle.event.EndIteration):
if event.batch_id % 1000 == 0:
print "Pass %d, Batch %d, Cost %f, %s" % (
event.pass_id, event.batch_id, event.cost, event.metrics)
elif isinstance(event, paddle.event.EndPass):
result = trainer.test(reader=paddle.batch(
paddle.dataset.mnist.test(), batch_size=128))
print "Test with Pass %d, Cost %f, %s\n" % (
event.pass_id, result.cost, result.metrics)
lists.append((event.pass_id, result.cost,
result.metrics['classification_error_evaluator']))
trainer.train(
reader=paddle.batch(
paddle.reader.shuffle(
paddle.dataset.mnist.train(), buf_size=8192),
batch_size=128),
event_handler=event_handler,
num_passes=100)
# find the best pass
best = sorted(lists, key=lambda list: float(list[1]))[0]
print 'Best pass is %s, testing Avgcost is %s' % (best[0], best[1])
print 'The classification accuracy is %.2f%%' % (100 - float(best[2]) * 100)
test_creator = paddle.dataset.mnist.test()
test_data = []
for item in test_creator():
test_data.append((item[0], ))
if len(test_data) == 100:
break
# output is a softmax layer. It returns probabilities.
# Shape should be (100, 10)
probs = paddle.infer(
output_layer=predict, parameters=parameters, input=test_data)
print probs.shape
if __name__ == '__main__':
main()
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle.v2 as paddle
import paddle.v2.dataset.uci_housing as uci_housing
import paddle.v2.master as master
import os
import cPickle as pickle
from paddle.v2.reader.creator import cloud_reader
etcd_ip = os.getenv("MASTER_IP", "127.0.0.1")
etcd_endpoints = "http://" + etcd_ip + ":2379"
print "etcd endpoints: ", etcd_endpoints
def main():
# init
paddle.init(use_gpu=False, trainer_count=1)
# network config
x = paddle.layer.data(name='x', type=paddle.data_type.dense_vector(13))
y_predict = paddle.layer.fc(input=x,
param_attr=paddle.attr.Param(name='w'),
size=1,
act=paddle.activation.Linear(),
bias_attr=paddle.attr.Param(name='b'))
y = paddle.layer.data(name='y', type=paddle.data_type.dense_vector(1))
cost = paddle.layer.mse_cost(input=y_predict, label=y)
# create parameters
parameters = paddle.parameters.create(cost)
# create optimizer of new remote updater to pserver
optimizer = paddle.optimizer.Momentum(momentum=0, learning_rate=1e-3)
trainer = paddle.trainer.SGD(cost=cost,
parameters=parameters,
update_equation=optimizer,
is_local=False,
pserver_spec=etcd_endpoints,
use_etcd=True)
# event_handler to print training and testing info
def event_handler(event):
if isinstance(event, paddle.event.EndIteration):
# FIXME: for cloud data reader, pass number is managed by master
# should print the server side pass number
if event.batch_id % 100 == 0:
print "Pass %d, Batch %d, Cost %f" % (
event.pass_id, event.batch_id, event.cost)
if isinstance(event, paddle.event.EndPass):
if (event.pass_id + 1) % 10 == 0:
result = trainer.test(
reader=paddle.batch(
uci_housing.test(), batch_size=2),
feeding={'x': 0,
'y': 1})
print "Test %d, %.2f" % (event.pass_id, result.cost)
# training
# NOTE: use uci_housing.train() as reader for non-paddlecloud training
trainer.train(
reader=paddle.batch(
paddle.reader.shuffle(
cloud_reader(
["/pfs/dlnel/public/dataset/uci_housing/uci_housing*"],
etcd_endpoints),
buf_size=500),
batch_size=2),
feeding={'x': 0,
'y': 1},
event_handler=event_handler,
num_passes=30)
if __name__ == '__main__':
main()
// Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package client
import (
"errors"
"hash/fnv"
"sort"
"time"
"github.com/PaddlePaddle/Paddle/go/connection"
"github.com/PaddlePaddle/Paddle/go/pserver"
log "github.com/inconshreveable/log15"
)
// TODO(helin): add RPC call retry logic
// Selector selects if the client should initialize parameters and
// reports the initialization process done.
type Selector interface {
// Select selects if the client should initialize parameter servers.
Select() (bool, error)
// Done indicates the initialization process is done.
Done() error
}
// Server is the identification of a parameter Server.
type Server struct {
Index int
Addr string
}
// Lister lists currently available parameter servers.
type Lister interface {
List() []Server
}
// Client is the client to parameter servers.
type Client struct {
sel Selector
pservers []*connection.Conn
}
// NewClient creates a new client.
func NewClient(l Lister, pserverNum int, sel Selector) *Client {
c := &Client{sel: sel}
c.pservers = make([]*connection.Conn, pserverNum)
for i := 0; i < pserverNum; i++ {
c.pservers[i] = connection.New()
}
go c.monitorPservers(l, pserverNum)
return c
}
// monitorPservers monitors pserver addresses, and updates connection
// when the address changes.
func (c *Client) monitorPservers(l Lister, pserverNum int) {
lastServers := make([]Server, pserverNum)
ticker := time.NewTicker(10 * time.Second)
monitor := func() {
curServers := make([]Server, pserverNum)
list := l.List()
for _, l := range list {
curServers[l.Index] = l
}
for i := range lastServers {
if lastServers[i].Addr == curServers[i].Addr {
continue
}
if curServers[i].Addr == "" {
err := c.pservers[i].Close()
if err != nil {
log.Error("error closing connection to pserver", log.Ctx{"error": err})
}
continue
}
err := c.pservers[i].Connect(curServers[i].Addr)
if err != nil {
log.Error("error connecting to pserver", log.Ctx{"error": err})
// connect to addr failed, set
// to last known addr in order
// to retry next time.
curServers[i].Addr = lastServers[i].Addr
}
}
lastServers = curServers
}
monitor()
for range ticker.C {
monitor()
}
}
// BeginInitParams begins to initialize parameters on parameter
// servers.
//
// BeginInitParams will be called from multiple trainers, only one
// trainer will be selected to initialize the parameters on parameter
// servers. Other trainers will be blocked until the initialization is
// done, and they need to get the initialized parameters from
// parameter servers using GetParams.
func (c *Client) BeginInitParams() (bool, error) {
return c.sel.Select()
}
// InitParam initializes the parameter on parameter servers.
func (c *Client) InitParam(paramWithConfigs pserver.ParameterWithConfig) error {
return c.pservers[c.partition(paramWithConfigs.Param.Name)].Call("Service.InitParam", paramWithConfigs, nil)
}
// FinishInitParams tells parameter servers client has sent all
// parameters to parameter servers as initialization.
func (c *Client) FinishInitParams() error {
for _, p := range c.pservers {
err := p.Call("Service.FinishInitParams", 0, nil)
if err != nil {
return err
}
}
return c.sel.Done()
}
// SendGrads sends gradients to parameter servers for updating
// parameters.
func (c *Client) SendGrads(grads []pserver.Gradient) error {
if len(grads) == 0 {
return errors.New("no gradient received")
}
errCh := make(chan error, len(grads))
for _, g := range grads {
go func(g pserver.Gradient) {
err := c.pservers[c.partition(g.Name)].Call("Service.SendGrad", g, nil)
errCh <- err
}(g)
}
recv := 0
for err := range errCh {
if err != nil {
return err
}
recv++
if recv == len(grads) {
break
}
}
return nil
}
type result struct {
idx int
param pserver.Parameter
err error
}
type results []result
func (r results) Len() int {
return len(r)
}
func (r results) Less(i int, j int) bool {
return r[i].idx < r[j].idx
}
func (r results) Swap(i int, j int) {
r[i], r[j] = r[j], r[i]
}
// GetParams gets parameters from parameter servers.
func (c *Client) GetParams(names []string) ([]pserver.Parameter, error) {
rCh := make(chan result, len(names))
for idx, name := range names {
go func(name string, idx int) {
var parameter pserver.Parameter
err := c.pservers[c.partition(name)].Call("Service.GetParam", name, &parameter)
rCh <- result{idx: idx, param: parameter, err: err}
}(name, idx)
}
var rs results
recv := 0
for r := range rCh {
if r.err != nil {
return nil, r.err
}
rs = append(rs, r)
recv++
if recv == len(names) {
break
}
}
sort.Sort(rs)
ps := make([]pserver.Parameter, len(rs))
for i := range rs {
ps[i] = rs[i].param
}
return ps, nil
}
func strHash(s string) uint32 {
h := fnv.New32a()
_, _ = h.Write([]byte(s))
return h.Sum32()
}
// TODO(helin): now partition only select which parameter server to
// send the entire parameter. We need to partition a parameter into
// small blocks and send to different parameter servers.
func (c *Client) partition(key string) int {
return int(strHash(key) % uint32(len(c.pservers)))
}
// Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package client_test
import (
"context"
"io/ioutil"
"math/rand"
"net"
"net/http"
"net/rpc"
"strconv"
"strings"
"sync"
"testing"
"time"
"github.com/PaddlePaddle/Paddle/go/pserver"
"github.com/PaddlePaddle/Paddle/go/pserver/client"
"github.com/coreos/etcd/clientv3"
log "github.com/inconshreveable/log15"
)
const (
numPserver = 10
etcdEndpoints = "127.0.0.1:2379"
timeout = 2 * time.Second
)
var pserverClientPorts [numPserver]int
// this function init pserver client and return their ports in an array.
func initClient() [numPserver]int {
var ports [numPserver]int
for i := 0; i < numPserver; i++ {
l, err := net.Listen("tcp", ":0")
if err != nil {
panic(err)
}
ss := strings.Split(l.Addr().String(), ":")
p, err := strconv.Atoi(ss[len(ss)-1])
if err != nil {
panic(err)
}
ports[i] = p
go func(l net.Listener) {
var cp pserver.Checkpoint
s, err := pserver.NewService(0, time.Hour, "", nil, cp)
if err != nil {
panic(err)
}
server := rpc.NewServer()
err = server.Register(s)
if err != nil {
panic(err)
}
mux := http.NewServeMux()
mux.Handle(rpc.DefaultRPCPath, server)
err = http.Serve(l, mux)
if err != nil {
panic(err)
}
}(l)
}
return ports
}
func initNativeClient() {
pserverClientPorts = initClient()
}
func initEtcdClient() {
client, err := clientv3.New(clientv3.Config{
Endpoints: []string{etcdEndpoints},
DialTimeout: time.Second * time.Duration(1),
})
if err != nil {
log.Error("error init etcd client", log.Ctx{"error": err})
}
ctx, cancel := context.WithTimeout(context.Background(), timeout)
_, err = client.Delete(ctx, pserver.PsDesired)
if err != nil {
panic(err)
}
_, err = client.Delete(ctx, pserver.PsPath)
if err != nil {
panic(err)
}
_, err = client.Put(ctx, pserver.PsDesired, strconv.Itoa(numPserver))
if err != nil {
panic(err)
}
ports := initClient()
for i := 0; i < numPserver; i++ {
_, err = client.Put(ctx, pserver.PsPath+strconv.Itoa(i), ":"+strconv.Itoa(ports[i]))
if err != nil {
panic(err)
}
}
cancel()
err = client.Close()
if err != nil {
panic(err)
}
}
type selector bool
func (s selector) Select() (bool, error) {
return bool(s), nil
}
func (s selector) Done() error {
return nil
}
type lister []client.Server
func (l lister) List() []client.Server {
return l
}
func testClient(t *testing.T, c *client.Client) {
selected, err := c.BeginInitParams()
if err != nil {
t.Fatal(err)
}
if !selected {
t.Fatal("should be selected.")
}
const numParameter = 1000
config, err := ioutil.ReadFile("./c/test/testdata/optimizer.pb")
if err != nil {
t.Fatalf("read optimizer proto failed")
}
var wg sync.WaitGroup
for i := 0; i < numParameter; i++ {
wg.Add(1)
go func(i int) {
var p pserver.Parameter
p.Name = "p_" + strconv.Itoa(i)
p.ElementType = pserver.Float32
p.Content = make([]byte, (i+1)*100)
err := c.InitParam(pserver.ParameterWithConfig{Param: p, Config: config})
if err != nil {
t.Fatal(err)
}
wg.Done()
}(i)
}
wg.Wait()
err = c.FinishInitParams()
if err != nil {
t.Fatal(err)
}
var grads []pserver.Gradient
for i := 0; i < numParameter; i++ {
var g pserver.Gradient
g.Name = "p_" + strconv.Itoa(i)
g.ElementType = pserver.Float32
g.Content = make([]byte, (i+1)*100)
grads = append(grads, g)
}
const paramPerGroup = 10
const numGroups = numParameter / paramPerGroup
// shuffle send grads order
for i := range grads {
j := rand.Intn(i + 1)
grads[i], grads[j] = grads[j], grads[i]
}
for i := 0; i < numGroups; i++ {
var gs []pserver.Gradient
if i == numGroups-1 {
gs = grads[i*paramPerGroup:]
} else {
gs = grads[i*paramPerGroup : (i+1)*paramPerGroup]
}
wg.Add(1)
go func(gs []pserver.Gradient) {
err := c.SendGrads(gs)
if err != nil {
t.Fatal(err)
}
wg.Done()
}(gs)
}
names := make([]string, numParameter)
for i := 0; i < numParameter; i++ {
names[i] = "p_" + strconv.Itoa(i)
}
for i := 0; i < numGroups; i++ {
var ns []string
if i == numGroups-1 {
ns = names[i*paramPerGroup:]
} else {
ns = names[i*paramPerGroup : (i+1)*paramPerGroup]
}
wg.Add(1)
go func(ns []string) {
params, err := c.GetParams(ns)
if err != nil {
t.Fatal(err)
}
if len(ns) != len(params) {
t.Fatalf("parameter size not match, need: %d, have: %d", len(names), len(params))
}
for i := range params {
if ns[i] != params[i].Name {
t.Fatalf("order of returned parameter does not required: parameter name: %s, required name: %s", ns[i], params[i].Name)
}
}
wg.Done()
}(ns)
}
wg.Wait()
}
func TestNativeClient(t *testing.T) {
initNativeClient()
servers := make([]client.Server, numPserver)
for i := 0; i < numPserver; i++ {
servers[i] = client.Server{Index: i, Addr: ":" + strconv.Itoa(pserverClientPorts[i])}
}
c1 := client.NewClient(lister(servers), len(servers), selector(true))
testClient(t, c1)
}
// EtcdClient is a disabled test, since we have not embedded etcd into
// our test.
func EtcdClient(t *testing.T) {
initEtcdClient()
etcdClient := client.NewEtcd(etcdEndpoints)
c2 := client.NewClient(etcdClient, etcdClient.Desired(), selector(true))
testClient(t, c2)
}
// Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package client
import (
"context"
"errors"
"fmt"
"strconv"
"strings"
"time"
"github.com/PaddlePaddle/Paddle/go/pserver"
"github.com/coreos/etcd/clientv3"
"github.com/coreos/etcd/clientv3/concurrency"
log "github.com/inconshreveable/log15"
)
const (
defaultEtcdTimeout time.Duration = 5 * time.Second
initLockPath = "/init_ps/lock"
initDonePath = "/init_ps/done"
initDoneVal = "1"
)
// Etcd is used by pserver client that is a part of trainer process.
// TODO:
// 1. add watcher to watch the change state of pservers.
type Etcd struct {
client *clientv3.Client
timeout time.Duration
endpoints []string
lock *concurrency.Mutex
}
// Desired read ps desired number from etcd.
func (e *Etcd) Desired() int {
var psDesired int
for {
ctx, cancel := context.WithTimeout(context.Background(), e.timeout)
resp, err := e.client.Get(ctx, pserver.PsDesired)
cancel()
if err != nil {
log.Error(
"Get ps dresire number failed! reconnecting...",
log.Ctx{"error": err},
)
time.Sleep(e.timeout)
continue
}
kvs := resp.Kvs
if len(kvs) == 0 {
log.Info("Waiting for ps desired registered ...")
time.Sleep(e.timeout)
continue
}
psDesired, err = strconv.Atoi(string(resp.Kvs[0].Value))
if err != nil {
log.Error("atoi failed", log.Ctx{"error": err})
time.Sleep(e.timeout)
continue
}
log.Debug("Got psDesired", log.Ctx{"psDesired": psDesired})
break
}
return psDesired
}
// List return the pserver list read from etcd.
func (e *Etcd) List() []Server {
psDesired := e.Desired()
servers := make([]Server, psDesired)
for {
for i := 0; i < psDesired; i++ {
ctx, cancel := context.WithTimeout(context.Background(), e.timeout)
psKey := pserver.PsPath + strconv.Itoa(i)
log.Debug("looking for pserver", log.Ctx{"ps key": psKey})
resp, err := e.client.Get(ctx, psKey)
cancel()
if err != nil {
log.Info(
"Get psKey error",
log.Ctx{"ps key": psKey, "error": err},
)
time.Sleep(e.timeout)
continue
}
kvs := resp.Kvs
if len(kvs) == 0 {
log.Info("Waiting for ps addr registered ...")
time.Sleep(e.timeout)
continue
}
psAddr := string(resp.Kvs[0].Value)
// TODO(Longfei) check the ps address
if psAddr == "" {
log.Info(
"Value under psKey is empty",
log.Ctx{"psKey": psKey},
)
time.Sleep(e.timeout)
continue
}
log.Debug(
"got psAddr given psKey",
log.Ctx{"psAddr": psAddr, "psKey": psKey},
)
servers[i].Index = i
servers[i].Addr = psAddr
}
break
}
return servers
}
// NewEtcd create a etcd client to return the state of pserver on etcd.
func NewEtcd(endpoints string) *Etcd {
ep := strings.Split(endpoints, ",")
var cli *clientv3.Client
var err error
for {
cli, err = clientv3.New(clientv3.Config{
Endpoints: ep,
DialTimeout: defaultEtcdTimeout,
})
if err != nil {
log.Error("Init etcd connection failed", log.Ctx{"error": err})
time.Sleep(defaultEtcdTimeout)
continue
}
break
}
log.Info("Connected to etcd endpoint", log.Ctx{"endpoint": endpoints})
client := &Etcd{
client: cli,
timeout: defaultEtcdTimeout,
endpoints: ep,
}
return client
}
// Select indicates if the current trainer is selected to initialize
// the pserver parameters.
func (e *Etcd) Select() (bool, error) {
sess, err := concurrency.NewSession(e.client, concurrency.WithTTL(5))
if err != nil {
return false, err
}
lock := concurrency.NewMutex(sess, initLockPath)
log.Info("Trying to acquire lock", log.Ctx{"lock path": initLockPath})
// Do not use timeout context here, since we don't know how
// long does it take for other trainers to initialize the
// parameters.
err = lock.Lock(context.Background())
if err != nil {
return false, err
}
log.Info("Successfully acquired lock", log.Ctx{"lock path": initLockPath})
get := clientv3.OpGet(initDonePath)
ctx, cancel := context.WithTimeout(context.Background(), e.timeout)
tresp, err := e.client.Txn(ctx).If(lock.IsOwner()).Then(get).Commit()
cancel()
if err != nil {
return false, err
}
if !tresp.Succeeded {
return false, errors.New("no longer the owner of the lock")
}
resp := tresp.Responses[0].GetResponseRange()
if len(resp.Kvs) == 0 {
// Key value not set, select current trainer.
e.lock = lock
log.Info("Trainer selected.")
return true, nil
}
if string(resp.Kvs[0].Value) == initDoneVal {
log.Info("Initialization is already done.")
ctx, cancel = context.WithTimeout(context.Background(), e.timeout)
err = lock.Unlock(ctx)
cancel()
if err != nil {
log.Error("error unlocking", log.Ctx{"error": err})
}
return false, nil
}
return false, fmt.Errorf("key %s have unexpected value: %v", initDonePath, resp.Kvs[0].Value)
}
// Done indicates the parameter initialization process is done.
func (e *Etcd) Done() error {
if e.lock == nil {
return errors.New("lock is nil, Done called unexpectedly")
}
put := clientv3.OpPut(initDonePath, initDoneVal)
ctx, cancel := context.WithTimeout(context.Background(), e.timeout)
tresp, err := e.client.Txn(ctx).If(e.lock.IsOwner()).Then(put).Commit()
cancel()
if err != nil {
return err
}
if !tresp.Succeeded {
return errors.New("no longer the owner of the lock")
}
ctx, cancel = context.WithTimeout(context.Background(), e.timeout)
err = e.lock.Unlock(ctx)
cancel()
if err != nil {
log.Error("error unlocking", log.Ctx{"error": err})
} else {
e.lock = nil
}
return nil
}
// Close closes the etcd client.
func (e *Etcd) Close() error {
var err error
if e.lock != nil {
ctx, cancel := context.WithTimeout(context.Background(), e.timeout)
err = e.lock.Unlock(ctx)
cancel()
if err == nil {
e.lock = nil
}
}
cErr := e.client.Close()
if cErr != nil {
if err != nil {
log.Error("error closing etcd client", log.Ctx{"error": cErr})
return err
}
return cErr
}
return err
}
package client_test
import (
"io/ioutil"
"net/url"
"os"
"strings"
"sync"
"testing"
"github.com/PaddlePaddle/Paddle/go/pserver/client"
"github.com/coreos/etcd/embed"
)
func TestSelector(t *testing.T) {
etcdDir, err := ioutil.TempDir("", "")
if err != nil {
t.Fatal(err)
}
cfg := embed.NewConfig()
lpurl, _ := url.Parse("http://localhost:0")
lcurl, _ := url.Parse("http://localhost:0")
cfg.LPUrls = []url.URL{*lpurl}
cfg.LCUrls = []url.URL{*lcurl}
cfg.Dir = etcdDir
e, err := embed.StartEtcd(cfg)
if err != nil {
t.Fatal(err)
}
defer func() {
e.Close()
if err := os.RemoveAll(etcdDir); err != nil {
t.Fatal(err)
}
}()
<-e.Server.ReadyNotify()
port := strings.Split(e.Clients[0].Addr().String(), ":")[1]
endpoint := "127.0.0.1:" + port
var mu sync.Mutex
selectedCount := 0
var wg sync.WaitGroup
selectAndDone := func(c *client.Etcd) {
defer wg.Done()
selected, err := c.Select()
if err != nil {
panic(err)
}
if selected {
mu.Lock()
selectedCount++
mu.Unlock()
err = c.Done()
if err != nil {
t.Fatal(err)
}
}
}
c0 := client.NewEtcd(endpoint)
c1 := client.NewEtcd(endpoint)
c2 := client.NewEtcd(endpoint)
c3 := client.NewEtcd(endpoint)
wg.Add(3)
go selectAndDone(c0)
go selectAndDone(c1)
go selectAndDone(c2)
wg.Wait()
// simulate trainer crashed and restarted after the
// initialization process.
wg.Add(1)
go selectAndDone(c3)
wg.Wait()
mu.Lock()
if selectedCount != 1 {
t.Fatal("selected count wrong:", selectedCount)
}
mu.Unlock()
err = c0.Close()
if err != nil {
t.Fatal(err)
}
err = c1.Close()
if err != nil {
t.Fatal(err)
}
err = c2.Close()
if err != nil {
t.Fatal(err)
}
err = c3.Close()
if err != nil {
t.Fatal(err)
}
}
// Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package pserver
import (
"context"
"errors"
"strconv"
"strings"
"time"
"github.com/PaddlePaddle/Paddle/go/utils/networkhelper"
"github.com/coreos/etcd/clientv3"
"github.com/coreos/etcd/clientv3/concurrency"
log "github.com/inconshreveable/log15"
)
const (
// PsDesired is etcd path for store desired pserver count
PsDesired = "/ps_desired"
// PsPath is the base dir for pserver to store their addr
PsPath = "/ps/"
// PsCheckpoint is the etcd path for store checkpoints information
PsCheckpoint = "/checkpoints/"
retryTimeout = 5 * time.Second
)
// EtcdClient is the etcd client that the pserver uses for fault
// tolerance, service registry and coordination.
type EtcdClient struct {
numPservers int
endpoints string
client *clientv3.Client
sess *concurrency.Session
dialTimeout time.Duration
ttlSec int
// FIXME: ensure GetExternalIP gets the correct ip for trainers to connect.
externalIP string
// desired number of pservers in the job.
// assume desired will not change during one training job.
desired int
}
// NewEtcdClient creates an EtcdClient
func NewEtcdClient(endpoints string, numPservers int, dialtimeout time.Duration, ttlSec int) *EtcdClient {
return &EtcdClient{
dialTimeout: dialtimeout,
ttlSec: ttlSec,
numPservers: numPservers,
endpoints: endpoints,
}
}
// Register registers the pserver on etcd
//
// Register returns the index of the current pserver.
func (e *EtcdClient) Register(port int) (int, error) {
var err error
e.externalIP, err = networkhelper.GetExternalIP()
if err != nil {
return 0, err
}
// initialize connection to etcd.
ep := strings.Split(e.endpoints, ",")
for {
cli, err := clientv3.New(clientv3.Config{
Endpoints: ep,
DialTimeout: e.dialTimeout,
})
if err != nil {
log.Error("connect to etcd error", log.Ctx{"error": err})
time.Sleep(retryTimeout)
continue
}
e.client = cli
sess, err := concurrency.NewSession(cli, concurrency.WithTTL(e.ttlSec))
if err != nil {
log.Error("create etcd session error", log.Ctx{"error": err})
time.Sleep(retryTimeout)
continue
}
e.sess = sess
log.Debug("connected to etcd", log.Ctx{"endpoint": e.endpoints})
break
}
// init /ps_desired using transaction, for multiple pservers may want to write
// it at the same time.
for {
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
_, err := e.initDesiredPservers(ctx, e.numPservers)
cancel()
if err != nil {
log.Warn("pserver init error", log.Ctx{"error": err, "num pservers": e.numPservers})
time.Sleep(retryTimeout)
continue
}
break
}
// TODO: when implementing extending or reducing pservers, /ps_desired is
// changed, then we need to watch /ps_desired node for events. For now, just
// write once when init and read from it.
// wait and set s.desired init value
for {
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
resp, err := e.client.Get(ctx, PsDesired)
cancel()
if err != nil {
log.Error("get etcd key error", log.Ctx{"key": PsDesired, "error": err})
time.Sleep(retryTimeout)
continue
}
if len(resp.Kvs) != 0 {
e.desired, err = strconv.Atoi(string(resp.Kvs[0].Value))
if err != nil {
log.Error(
"psDesired atoi error",
log.Ctx{"error": err, "value": string(resp.Kvs[0].Value)},
)
time.Sleep(retryTimeout)
// NOTE: wait util ps_desired value change
continue
}
break
}
}
var pserverIdx int
// try register pserver node on etcd
for {
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
var err error
pserverIdx, err = e.registerPserverEtcd(ctx, port)
cancel()
if err != nil {
log.Warn("register pserver on etcd error", log.Ctx{"error": err})
time.Sleep(retryTimeout)
continue
}
break
}
return pserverIdx, nil
}
func (e *EtcdClient) initDesiredPservers(ctx context.Context, numPservers int) (*clientv3.TxnResponse, error) {
return concurrency.NewSTM(e.client, func(c concurrency.STM) error {
dsStr := c.Get(PsDesired)
if dsStr == "" {
c.Put(PsDesired, strconv.Itoa(numPservers), clientv3.WithLease(e.sess.Lease()))
}
return nil
}, concurrency.WithAbortContext(ctx), concurrency.WithIsolation(concurrency.RepeatableReads))
}
// registerPserverEtcd registers pserver node on etcd using transaction.
func (e *EtcdClient) registerPserverEtcd(ctx context.Context, port int) (int, error) {
var idx int
_, err := concurrency.NewSTM(e.client, func(c concurrency.STM) error {
registered := false
for i := 0; i < e.desired; i++ {
psKey := PsPath + strconv.Itoa(i)
ps := c.Get(psKey)
log.Debug(
"register pserver got value",
log.Ctx{"value": ps, "key": psKey},
)
if ps == "" {
// find the first id and write info
pserverAddr := e.externalIP + ":" + strconv.Itoa(port)
c.Put(psKey, pserverAddr, clientv3.WithLease(e.sess.Lease()))
log.Debug("register finished", log.Ctx{"key": psKey, "value": pserverAddr})
idx = i
registered = true
break
}
}
if registered {
return nil
}
return errors.New("not registered, may due to already have enough pservers")
}, concurrency.WithAbortContext(ctx), concurrency.WithIsolation(concurrency.RepeatableReads))
if err != nil {
return 0, err
}
return idx, nil
}
// GetKey gets the value by the specified key
func (e *EtcdClient) GetKey(key string, timeout time.Duration) ([]byte, error) {
ctx, cancel := context.WithTimeout(context.Background(), timeout)
resp, err := e.client.Get(ctx, key)
cancel()
if err != nil {
return []byte{}, err
}
kvs := resp.Kvs
if len(kvs) == 0 {
return []byte{}, nil
}
v := kvs[0].Value
return v, nil
}
// PutKey put into etcd with value by key specified
func (e *EtcdClient) PutKey(key string, value []byte, timeout time.Duration, withLease bool) error {
ctx, cancel := context.WithTimeout(context.Background(), timeout)
var err error
if withLease {
_, err = e.client.Put(ctx, key, string(value), clientv3.WithLease(e.sess.Lease()))
} else {
_, err = e.client.Put(ctx, key, string(value))
}
cancel()
return err
}
// Shutdown shuts down the etcd client gracefully.
func (e *EtcdClient) Shutdown() error {
var err error
if e.sess != nil {
err = e.sess.Close()
}
if e.client != nil {
newErr := e.client.Close()
if newErr != nil {
if err != nil {
log.Error("shutdown error", log.Ctx{"error": newErr})
} else {
err = newErr
}
}
}
return err
}
// Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package pserver
// #cgo CFLAGS: -I ../../
// #cgo LDFLAGS: ${SRCDIR}/client/c/libpaddle_go_optimizer.a -lstdc++ -lm
// #include "paddle/legacy/optimizer/optimizer.h"
// #include <stdlib.h>
// #include <string.h>
import "C"
import (
"fmt"
"unsafe"
log "github.com/inconshreveable/log15"
)
type optimizer struct {
opt *C.struct_paddle_optimizer
elementType ElementType
contentLen int
config []byte
}
func cArrayToSlice(p unsafe.Pointer, len int) []byte {
if p == nil {
return nil
}
// create a Go clice backed by a C array, reference:
// https://github.com/golang/go/wiki/cgo#turning-c-arrays-into-go-slices
//
// Go garbage collector will not interact with this data, need
// to be freed properly.
return (*[1 << 30]byte)(p)[:len:len]
}
func newOptimizer(paramWithConfigs ParameterWithConfig, State []byte) *optimizer {
o := &optimizer{}
o.elementType = paramWithConfigs.Param.ElementType
o.contentLen = len(paramWithConfigs.Param.Content)
p := paramWithConfigs.Param
c := paramWithConfigs.Config
s := State
paramBufferSize := C.size_t(len(p.Content))
log.Info("New Optimizer Created with config", log.Ctx{
"ElementType": p.ElementType,
"ParamSize": paramBufferSize,
"ConfigSize": len(c),
"StateSize": len(s),
})
var cbuffer unsafe.Pointer
cbuffer = C.malloc(paramBufferSize)
C.memcpy(cbuffer, unsafe.Pointer(&p.Content[0]), paramBufferSize)
var cstate unsafe.Pointer
if len(s) != 0 {
cstate = unsafe.Pointer(&s[0])
}
var cptr (*C.uchar)
if len(c) > 0 {
cptr = (*C.uchar)(&c[0])
} else {
log.Error("empty config", "param name", paramWithConfigs.Param.Name)
}
o.config = c
o.opt = C.paddle_create_optimizer(
cptr,
C.int(len(c)),
C.paddle_element_type(p.ElementType),
cbuffer,
C.int(paramBufferSize),
(*C.char)(cstate),
C.int(len(s)),
)
return o
}
func (o *optimizer) GetWeights() []byte {
var buffer unsafe.Pointer
// we do not own the buffer, no need to free later.
bufferLen := C.paddle_optimizer_get_weights(o.opt, &buffer)
return cArrayToSlice(buffer, int(bufferLen)*C.sizeof_float)
}
func (o *optimizer) GetStates() []byte {
var cbuffer *C.char
// we owns the state buffer, need to free later.
cbufferLen := C.paddle_optimizer_get_state(o.opt, &cbuffer)
buf := cArrayToSlice(unsafe.Pointer(cbuffer), int(cbufferLen))
cpy := make([]byte, len(buf))
copy(cpy, buf)
C.free(unsafe.Pointer(cbuffer))
return cpy
}
func (o *optimizer) UpdateParameter(g Gradient) error {
if o.elementType != g.ElementType {
return fmt.Errorf("Name: %s, parameter and gradient element type not match, parameter: %v, gradient: %v", g.Name, o.elementType, g.ElementType)
}
if o.contentLen != len(g.Content) {
return fmt.Errorf("Name: %s, parameter and gradient does not have same content len, parameter: %d, gradient: %d", g.Name, o.contentLen, len(g.Content))
}
r := C.paddle_update_parameter(o.opt, C.paddle_element_type(g.ElementType), unsafe.Pointer(&g.Content[0]), C.int(len(g.Content)))
if r != 0 {
return fmt.Errorf("optimizer update returned error code: %d", r)
}
return nil
}
func (o *optimizer) Cleanup() {
if unsafe.Pointer(o.opt) != nil {
C.paddle_release_optimizer(o.opt)
o.opt = (*C.struct_paddle_optimizer)(nil)
}
}
// Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package pserver
import (
"encoding/binary"
"io/ioutil"
"math"
"testing"
"github.com/stretchr/testify/assert"
)
func TestOptimizerCreateRelease(t *testing.T) {
p := Parameter{
Name: "a",
ElementType: Int32,
}
p.Content = []byte{1, 3}
config, err := ioutil.ReadFile("./client/c/test/testdata/optimizer.pb")
if err != nil {
t.Fatalf("read optimizer proto failed")
}
param := ParameterWithConfig{
Param: p,
Config: config,
}
o := newOptimizer(param, nil)
o.Cleanup()
}
func float32Bytes(float float32) []byte {
bits := math.Float32bits(float)
bytes := make([]byte, 4)
binary.LittleEndian.PutUint32(bytes, bits)
return bytes
}
func TestOptimizerState(t *testing.T) {
p := Parameter{
Name: "a",
ElementType: Int32,
}
weights := float32Bytes(100)
p.Content = weights
config, err := ioutil.ReadFile("./client/c/test/testdata/optimizer.pb")
if err != nil {
t.Fatalf("read optimizer proto failed")
}
param := ParameterWithConfig{
Param: p,
Config: config,
}
o := newOptimizer(param, nil)
s := o.GetStates()
// clear param content and check if the state is restored.
param.Param.Content = float32Bytes(300)
o1 := newOptimizer(param, s)
s1 := o1.GetStates()
assert.Equal(t, s, s1)
assert.Equal(t, weights, o.GetWeights())
assert.Equal(t, weights, o1.GetWeights())
o.Cleanup()
o1.Cleanup()
}
// Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package pserver
import (
"bufio"
"bytes"
"encoding/binary"
"encoding/gob"
"encoding/json"
"errors"
"fmt"
"hash/crc32"
"io/ioutil"
"os"
"path"
"strconv"
"strings"
"sync"
"time"
"github.com/golang/protobuf/proto"
uuid "github.com/satori/go.uuid"
pb "github.com/PaddlePaddle/Paddle/go/proto"
log "github.com/inconshreveable/log15"
)
// ElementType is the type of elements of a Parameter.
type ElementType int
// ErrCheckpointNotFound indicates that the pserver checkpoint could
// not be found.
var ErrCheckpointNotFound = errors.New("checkpoint not found in etcd")
// RPC error message.
const (
AlreadyInitialized = "pserver already initialized"
Uninitialized = "pserver not fully initialized"
WrongChecksum = "checkpoint file checksum validation failed"
)
// Supported element types.
const (
Int32 ElementType = iota
UInt32
Int64
UInt64
Float32
Float64
)
// Parameter is a piece of data to sync with the parameter server.
type Parameter struct {
Name string
ElementType ElementType
Content []byte
}
func float32ToString(b []byte) string {
f := make([]float32, len(b)/4)
buf := bytes.NewReader(b)
err := binary.Read(buf, binary.LittleEndian, &f)
if err != nil {
return ""
}
return fmt.Sprintf("%v", f)
}
func float32ByteToString(c []byte) string {
var a []byte
var b []byte
if len(c) <= 80 {
a = c
} else {
a = c[0:40]
b = c[len(c)-40:]
}
var s string
s = float32ToString(a)
if b == nil {
return s
}
s = strings.Replace(s, "]", "", -1) + "..." + strings.Replace(float32ToString(b), "[", "", -1)
return s
}
func (p Parameter) String() string {
if p.ElementType != Float32 {
return fmt.Sprintf("name:%v ElementType:%v",
p.Name, p.ElementType)
}
return float32ByteToString(p.Content)
}
// ParameterWithConfig contains the parameter and the configuration.
type ParameterWithConfig struct {
Param Parameter
Config []byte // parameter configuration in Proto Buffer format
}
// checkpointMeta saves checkpoint metadata
type checkpointMeta struct {
UUID string `json:"uuid"`
Path string `json:"path"`
CRC32 uint32 `json:"crc32"`
Timestamp int64 `json:"timestamp"`
}
// Checkpoint is the pserver shard persist in file.
type Checkpoint []parameterCheckpoint
// Gradient is the gradient of the parameter.
type Gradient Parameter
// Service is the RPC service for pserver.
type Service struct {
initialized chan struct{}
idx int
checkpointInterval time.Duration
checkpointPath string
client KVStore
mu sync.Mutex
optMap map[string]*optimizer
}
// parameterCheckpoint saves parameter checkpoint.
type parameterCheckpoint struct {
ParameterWithConfig
State []byte
}
type KVStore interface {
GetKey(key string, timeout time.Duration) ([]byte, error)
PutKey(key string, value []byte, timeout time.Duration, withLease bool) error
}
func loadMeta(e KVStore, idx int) (meta checkpointMeta, err error) {
v, err := e.GetKey(PsCheckpoint+strconv.Itoa(idx), 3*time.Second)
if err != nil {
return
}
if len(v) == 0 {
err = ErrCheckpointNotFound
return
}
if err = json.Unmarshal(v, &meta); err != nil {
return
}
return
}
// LoadCheckpoint loads checkpoint from file.
func LoadCheckpoint(e KVStore, idx int) (Checkpoint, error) {
log.Info("Loading checkpoint", "pserver index", idx)
defer traceTime(time.Now(), "load checkpoint")
cpMeta, err := loadMeta(e, idx)
if err != nil {
return nil, err
}
content, err := ioutil.ReadFile(cpMeta.Path)
if err != nil {
return nil, err
}
crc32 := crc32.ChecksumIEEE(content)
if crc32 != cpMeta.CRC32 {
return nil, errors.New(WrongChecksum)
}
dec := gob.NewDecoder(bytes.NewReader(content))
var cp Checkpoint
if err = dec.Decode(&cp); err != nil {
return nil, err
}
return cp, nil
}
// NewService creates a new service, will bypass etcd registration if no
// endpoints specified. It will recovery from checkpoint file if a exists a specified checkpoint.
func NewService(idx int, interval time.Duration, path string, client KVStore, cp Checkpoint) (*Service, error) {
s := &Service{
idx: idx,
checkpointInterval: interval,
checkpointPath: path,
client: client,
}
s.optMap = make(map[string]*optimizer)
s.initialized = make(chan struct{})
if cp != nil {
for _, item := range cp {
p := ParameterWithConfig{
Param: item.Param,
Config: item.Config,
}
s.optMap[p.Param.Name] = newOptimizer(p, item.State)
}
close(s.initialized)
}
return s, nil
}
// InitParam initializes a parameter.
func (s *Service) InitParam(paramWithConfigs ParameterWithConfig, _ *int) error {
select {
case <-s.initialized:
log.Warn("init param called but parameters already initialized.")
return errors.New(AlreadyInitialized)
default:
}
c := &pb.OptimizerConfig{}
proto.Unmarshal(paramWithConfigs.Config, c)
log.Debug(fmt.Sprintf("OptimizerConfig:%v", c))
s.mu.Lock()
defer s.mu.Unlock()
// TODO(helin): check if paramWithConfigs.Param.Content is
// properly memory aligned, if not, make copy to a memory
// aligned region.
s.optMap[paramWithConfigs.Param.Name] = newOptimizer(paramWithConfigs, nil)
log.Info(
"init parameter",
"name", paramWithConfigs.Param.Name,
"config len", len(paramWithConfigs.Config),
"param len", len(paramWithConfigs.Param.Content),
"type", paramWithConfigs.Param.ElementType,
)
return nil
}
// FinishInitParams tells the parameter server that the parameter
// initialization has finished.
func (s *Service) FinishInitParams(_ int, _ *int) error {
select {
case <-s.initialized:
log.Warn("finished init param called but parameters already initialized.")
return errors.New(AlreadyInitialized)
default:
}
close(s.initialized)
go func() {
t := time.Tick(s.checkpointInterval)
for range t {
err := s.checkpoint()
if err != nil {
log.Error("checkpoint error", log.Ctx{"error": err})
}
}
}()
log.Info("init parameter finished.")
return nil
}
// SendGrad sends gradient to parameter servers for parameter
// optimization.
func (s *Service) SendGrad(g Gradient, _ *int) error {
select {
case <-s.initialized:
default:
log.Warn("received gradient before initialization.",
"name", g.Name, "size", len(g.Content), "type", g.ElementType)
return errors.New(Uninitialized)
}
s.mu.Lock()
defer s.mu.Unlock()
o, ok := s.optMap[g.Name]
if !ok {
log.Warn("received gradient but can't find name.",
"name", g.Name, "size", len(g.Content), "type", g.ElementType)
return fmt.Errorf("parameter: %s does not exist", g.Name)
}
log.Debug(Parameter(g).String())
log.Info("received gradient from trainer, updating gradient.",
"name", g.Name, "size", len(g.Content), "type", g.ElementType)
return o.UpdateParameter(g)
}
// GetParam gets parameters from the parameter server.
func (s *Service) GetParam(name string, parameter *Parameter) error {
<-s.initialized
s.mu.Lock()
defer s.mu.Unlock()
opt, ok := s.optMap[name]
if !ok {
log.Warn("trainer wants to get a parameter that does not exist.", "name", name)
return fmt.Errorf("parameter: %s does not exist", name)
}
// The parameter content (a byte slice) may change
// during RPC serialization due to write from other
// goroutine, we allow it since mini-batch based deep
// learning optimization methods are stochastic in
// nature. This race condition is allowed deliberately
// to save the program from making a copy of the
// parameter content.
parameter.Name = name
parameter.ElementType = opt.elementType
parameter.Content = opt.GetWeights()
log.Debug(parameter.String())
log.Info("sending parameter to the trainer", "name", parameter.Name, "size", len(parameter.Content), "type", parameter.ElementType)
return nil
}
func traceTime(start time.Time, name string) {
elapsed := time.Since(start)
log.Info("time elapsed", log.Ctx{"name": name, "elapsed": elapsed})
}
// checkpoint saves checkpoint to disk.
//
// checkpoint should be only called after the parameters are
// initialized.
func (s *Service) checkpoint() (err error) {
log.Info("Begin save checkpoint.")
defer traceTime(time.Now(), "save checkpoint")
s.mu.Lock()
cp := make([]parameterCheckpoint, len(s.optMap))
index := 0
// TODO(helin): write checkpoint incrementally to reduce memory
// footprint during checkpoint.
for name, opt := range s.optMap {
var pc parameterCheckpoint
pc.Param.Name = name
pc.Param.ElementType = opt.elementType
pc.Param.Content = opt.GetWeights()
pc.Config = opt.config
pc.State = opt.GetStates()
cp[index] = pc
index++
}
s.mu.Unlock()
var buf bytes.Buffer
encoder := gob.NewEncoder(&buf)
err = encoder.Encode(cp)
if err != nil {
return
}
if _, err = os.Stat(s.checkpointPath); os.IsNotExist(err) {
err = os.MkdirAll(s.checkpointPath, os.ModePerm)
if err != nil {
return
}
}
id := uuid.NewV4().String()
p := path.Join(s.checkpointPath, id)
f, err := os.Create(p)
if err != nil {
return
}
defer func() {
closeErr := f.Close()
if closeErr != nil {
if err != nil {
log.Error("error close checkpoint file", log.Ctx{"error": closeErr})
} else {
// Set closeErr as return value.
err = closeErr
}
}
}()
writer := bufio.NewWriter(f)
_, err = writer.Write(buf.Bytes())
if err != nil {
return
}
err = writer.Flush()
if err != nil {
return
}
oldMeta, err := loadMeta(s.client, s.idx)
if err == ErrCheckpointNotFound {
log.Info("old meta not found, skip removing old meta")
err = nil
} else if err == nil {
log.Info("removing old meta")
if oldMeta.Path != "" {
rmErr := os.Remove(oldMeta.Path)
if rmErr != nil {
// log error, but still treat checkpoint as
// successful.
log.Error("remove old meta file error", log.Ctx{"error": rmErr})
}
}
}
if err != nil {
return
}
crc32 := crc32.ChecksumIEEE(buf.Bytes())
cpMeta := checkpointMeta{
UUID: id,
Timestamp: time.Now().UnixNano(),
CRC32: crc32,
Path: p,
}
json, err := json.Marshal(cpMeta)
if err != nil {
return
}
err = s.client.PutKey(PsCheckpoint+strconv.Itoa(s.idx), json, 3*time.Second, false)
if err != nil {
return
}
return
}
package pserver
import (
"bytes"
"encoding/binary"
"fmt"
"testing"
"time"
"github.com/stretchr/testify/assert"
)
const testDir = "./test_data"
type myKV struct {
m map[string][]byte
}
func (m *myKV) GetKey(key string, timeout time.Duration) ([]byte, error) {
if m.m == nil {
m.m = make(map[string][]byte)
}
return m.m[key], nil
}
func (m *myKV) PutKey(key string, value []byte, timeout time.Duration, withLease bool) error {
if m.m == nil {
m.m = make(map[string][]byte)
}
m.m[key] = value
return nil
}
func TestCheckpoint(t *testing.T) {
kv := &myKV{}
s, err := NewService(0, time.Hour, testDir, kv, nil)
assert.Nil(t, err)
err = s.checkpoint()
assert.Nil(t, err)
_, err = LoadCheckpoint(kv, 0)
assert.Nil(t, err)
}
func float32ToByte(f float32) []byte {
var buf bytes.Buffer
err := binary.Write(&buf, binary.LittleEndian, f)
if err != nil {
fmt.Println("binary.Write failed:", err)
}
return buf.Bytes()
}
func TestCheckpointWithData(t *testing.T) {
kv := &myKV{}
s, err := NewService(0, time.Hour, testDir, kv, nil)
assert.Nil(t, err)
var content []byte
for i := 0; i < 50000; i++ {
content = append(content, float32ToByte(float32(i))...)
}
p1 := Parameter{Name: "p1", ElementType: 1, Content: content}
err = s.InitParam(ParameterWithConfig{Param: p1}, nil)
assert.Nil(t, err)
err = s.FinishInitParams(0, nil)
assert.Nil(t, err)
var p2 Parameter
err = s.GetParam(p1.Name, &p2)
assert.Nil(t, err)
assert.Equal(t, p1, p2)
err = s.checkpoint()
assert.Nil(t, err)
cp, err := LoadCheckpoint(kv, 0)
assert.Nil(t, err)
s1, err := NewService(0, time.Hour, testDir, kv, cp)
assert.Nil(t, err)
var p3 Parameter
err = s1.GetParam(p1.Name, &p3)
assert.Nil(t, err)
assert.Equal(t, p1, p3)
}
// Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package pserver_test
import (
"fmt"
"io/ioutil"
"reflect"
"sync"
"testing"
"time"
"github.com/PaddlePaddle/Paddle/go/pserver"
)
const (
OptimizerConfig = "./client/c/test/testdata/optimizer.pb"
)
func TestServiceFull(t *testing.T) {
var cp pserver.Checkpoint
s, err := pserver.NewService(0, time.Hour, "", nil, cp)
if err != nil {
t.Error(err)
}
var p pserver.Parameter
p.Name = "param_a"
p.Content = []byte{1, 0, 0, 0, 2, 0, 0, 0, 3, 0, 0, 0}
p.ElementType = pserver.Int32
config, err := ioutil.ReadFile(OptimizerConfig)
if err != nil {
t.Fatalf("read optimizer proto failed")
}
err = s.InitParam(pserver.ParameterWithConfig{Param: p, Config: config}, nil)
if err != nil {
t.Fatal(err)
}
var p1 pserver.Parameter
p1.Name = "param_b"
p1.Content = []byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}
p1.ElementType = pserver.Float32
err = s.InitParam(pserver.ParameterWithConfig{Param: p1, Config: config}, nil)
if err != nil {
t.Fatal(err)
}
err = s.FinishInitParams(0, nil)
if err != nil {
t.Fatal(err)
}
var param pserver.Parameter
err = s.GetParam("param_b", &param)
if err != nil {
t.Fatal(err)
}
if !reflect.DeepEqual(param, p1) {
t.Fatal("not equal:", param, p1)
}
g1, g2 := pserver.Gradient(p1), pserver.Gradient(p)
err = s.SendGrad(g1, nil)
if err != nil {
t.Fatal(err)
}
err = s.SendGrad(g2, nil)
if err != nil {
t.Fatal(err)
}
var param1 pserver.Parameter
err = s.GetParam("param_a", &param1)
if err != nil {
t.Fatal(err)
}
// don't compare content, since it's already changed by
// gradient update.
param1.Content = nil
p.Content = nil
if !reflect.DeepEqual(param1, p) {
t.Fatal("not equal:", param1, p)
}
}
func TestMultipleInit(t *testing.T) {
var cp pserver.Checkpoint
s, err := pserver.NewService(0, time.Hour, "", nil, cp)
if err != nil {
t.Fatal(err)
}
err = s.FinishInitParams(0, nil)
if err != nil {
t.Fatal(err)
}
err = s.FinishInitParams(0, nil)
if err.Error() != pserver.AlreadyInitialized {
t.Fatal(err)
}
}
func TestUninitialized(t *testing.T) {
var cp pserver.Checkpoint
s, err := pserver.NewService(0, time.Hour, "", nil, cp)
err = s.SendGrad(pserver.Gradient{}, nil)
if err.Error() != pserver.Uninitialized {
t.Fatal(err)
}
}
func TestBlockUntilInitialized(t *testing.T) {
var cp pserver.Checkpoint
s, err := pserver.NewService(0, time.Hour, "", nil, cp)
if err != nil {
t.Error(err)
}
ch := make(chan struct{}, 2)
errCh := make(chan error, 2)
var wg sync.WaitGroup
wg.Add(1)
go func() {
var param pserver.Parameter
err := s.GetParam("param_a", &param)
if err != nil {
errCh <- err
}
wg.Done()
ch <- struct{}{}
}()
time.Sleep(50 * time.Millisecond)
select {
case <-ch:
// some function returned before initialization is completed.
t.FailNow()
case <-errCh:
t.FailNow()
default:
}
var p pserver.Parameter
p.Name = "param_a"
p.Content = []byte{1, 0, 0, 0, 2, 0, 0, 0, 3, 0, 0, 0}
p.ElementType = pserver.Int32
config, err := ioutil.ReadFile(OptimizerConfig)
if err != nil {
t.Fatalf("read optimizer proto failed")
}
err = s.InitParam(pserver.ParameterWithConfig{Param: p, Config: config}, nil)
if err != nil {
t.Fatal(err)
}
err = s.FinishInitParams(0, nil)
if err != nil {
t.Fatal(err)
}
wg.Wait()
}
func TestGradientString(t *testing.T) {
g := pserver.Parameter{}
g.ElementType = pserver.Float32
g.Content = []byte{0x18, 0x2d, 0x44, 0x54, 0xfb, 0x21, 0x09, 0x40, 0x18, 0x2d, 0x44, 0x54, 0xfb, 0x21, 0x09, 0x40}
if g.String() != "[3.3702806e+12 2.142699 3.3702806e+12 2.142699]" {
t.Fatal("get float data error!")
}
g.Content = []byte{0x18, 0x2d, 0x44, 0x54, 0xfb, 0x21, 0x09, 0x40,
0x18, 0x2d, 0x44, 0x54, 0xfb, 0x21, 0x09, 0x40,
0x18, 0x2d, 0x44, 0x54, 0xfb, 0x21, 0x09, 0x40,
0x18, 0x2d, 0x44, 0x54, 0xfb, 0x21, 0x09, 0x40,
0x18, 0x2d, 0x44, 0x54, 0xfb, 0x21, 0x09, 0x40,
0x18, 0x2d, 0x44, 0x54, 0xfb, 0x21, 0x09, 0x40,
0x18, 0x2d, 0x44, 0x54, 0xfb, 0x21, 0x09, 0x40,
0x18, 0x2d, 0x44, 0x54, 0xfb, 0x21, 0x09, 0x40,
0x18, 0x2d, 0x44, 0x54, 0xfb, 0x21, 0x09, 0x40,
0x18, 0x2d, 0x44, 0x54, 0xfb, 0x21, 0x09, 0x40,
0x18, 0x2d, 0x44, 0x54, 0xfb, 0x21, 0x09, 0x40,
0x18, 0x2d, 0x44, 0x54, 0xfb, 0x21, 0x09, 0x40,
0x18, 0x2d, 0x44, 0x54, 0xfb, 0x21, 0x09, 0x40,
0x18, 0x2d, 0x44, 0x54, 0xfb, 0x21, 0x09, 0x40,
0x18, 0x2d, 0x44, 0x54, 0xfb, 0x21, 0x09, 0x40,
0x18, 0x2d, 0x44, 0x54, 0xfb, 0x21, 0x09, 0x40}
if g.String() != "[3.3702806e+12 2.142699 3.3702806e+12 2.142699 3.3702806e+12 2.142699 3.3702806e+12 2.142699 3.3702806e+12 2.142699...3.3702806e+12 2.142699 3.3702806e+12 2.142699 3.3702806e+12 2.142699 3.3702806e+12 2.142699 3.3702806e+12 2.142699]" {
t.Fatal("get float data error!", g.String())
}
fmt.Println(g)
}
# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
if(WITH_TESTING)
go_test(network_helper_test)
endif()
// Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package networkhelper
import (
"errors"
"net"
)
// GetExternalIP returns the ip address of local network interface, not the
// loopback device.
func GetExternalIP() (string, error) {
ifaces, err := net.Interfaces()
if err != nil {
return "", err
}
for _, iface := range ifaces {
if iface.Flags&net.FlagUp == 0 {
continue // interface down
}
if iface.Flags&net.FlagLoopback != 0 {
continue // loopback interface
}
addrs, err := iface.Addrs()
if err != nil {
return "", err
}
for _, addr := range addrs {
var ip net.IP
switch v := addr.(type) {
case *net.IPNet:
ip = v.IP
case *net.IPAddr:
ip = v.IP
}
if ip == nil || ip.IsLoopback() {
continue
}
ip = ip.To4()
if ip == nil {
continue // not an ipv4 address
}
return ip.String(), nil
}
}
return "", errors.New("are you connected to the network?")
}
// Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package networkhelper
import "testing"
func TestGetIP(t *testing.T) {
_, err := GetExternalIP()
if err != nil {
t.Errorf("GetExternalIP returns error : %v\n", err)
}
}
if (MOBILE_INFERENCE)
file(GLOB proto_filenames . ModelConfig.proto ParameterConfig.proto
TrainerConfig.proto DataConfig.proto)
else()
file(GLOB proto_filenames . *.proto)
endif()
include_directories(${CMAKE_CURRENT_BINARY_DIR})
proto_library(paddle_proto SRCS ${proto_filenames})
set(PROTO_GEN)
set(PROTO_GEN_PY)
foreach(filename ${proto_filenames})
get_filename_component(ABS_FIL ${filename} ABSOLUTE)
get_filename_component(FIL_WE ${filename} NAME_WE)
set(CUR_PROTO_GEN_PY
${PADDLE_BINARY_DIR}/paddle/python/paddle/proto/${FIL_WE}_pb2.py)
set(PROTO_GEN_PY
${CUR_PROTO_GEN_PY}
${PROTO_GEN_PY})
add_custom_command(OUTPUT ${CUR_PROTO_GEN_PY}
COMMAND ${CMAKE_COMMAND} -E make_directory ${PADDLE_BINARY_DIR}/python/paddle/proto
COMMAND ${PROTOBUF_PROTOC_EXECUTABLE}
ARGS "--python_out=${PADDLE_BINARY_DIR}/python/paddle/proto"
"-I" ${CMAKE_CURRENT_SOURCE_DIR} ${ABS_FIL}
DEPENDS ${ABS_FIL} protoc)
endforeach()
add_custom_target(gen_proto_py ALL DEPENDS ${PROTO_GEN_PY})
if (WITH_GOLANG)
add_custom_target(protoc-gen-go)
add_custom_command(TARGET protoc-gen-go
COMMAND go
ARGS "get" "-u" "github.com/golang/protobuf/protoc-gen-go")
set(PROTO_GEN_GO)
file(GLOB proto_filenames . OptimizerConfig.proto)
foreach(filename ${proto_filenames})
message(STATUS ${filename})
get_filename_component(ABS_FIL ${filename} ABSOLUTE)
get_filename_component(FIL_WE ${filename} NAME_WE)
set(CUR_PROTO_GEN_GO
${PADDLE_SOURCE_DIR}/paddle/go/proto/${FIL_WE}.pb.go)
set(PROTO_GEN_GO
${CUR_PROTO_GEN_GO}
${PROTO_GEN_GO})
add_custom_command(OUTPUT ${CUR_PROTO_GEN_GO}
COMMAND ${PROTOBUF_PROTOC_EXECUTABLE}
ARGS "--go_out=${PADDLE_SOURCE_DIR}/go/proto"
"-I" ${CMAKE_CURRENT_SOURCE_DIR} ${ABS_FIL}
DEPENDS ${ABS_FIL} protoc protoc-gen-go)
endforeach()
add_custom_target(gen_proto_go ALL DEPENDS ${PROTO_GEN_GO})
endif()
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
syntax = "proto2";
package paddle;
message FileGroupConf {
optional uint32 queue_capacity = 1 [ default = 1 ];
// how many files to load for a load file thread
optional int32 load_file_count = 2 [ default = 1 ];
// how many threads to load files
// Setting to be 5~10 is appropriate when loading files by hadoop vfs
optional int32 load_thread_num = 3 [ default = 1 ];
};
message DataConfig {
required string type = 1;
// name of a text file which contains a list of file names at each line
optional string files = 3;
optional int32 feat_dim = 4; // feature dimension of one frame
repeated int32 slot_dims = 5; // feature slot dims
optional int32 context_len = 6; // max neibour frame numbers
optional uint64 buffer_capacity = 7; // the number of samples
// part of data used in training
// if not -1, part of train data is used in training
optional int64 train_sample_num = 8 [ default = -1 ];
// The number of documents processed once
optional int32 file_load_num = 9 [ default = -1 ];
optional bool async_load_data = 12 [ default = false ];
/// Note the field number 10, 11 and 13 have been deprecated.
optional bool for_test = 14
[ default = false ]; // whether this data is for test
optional FileGroupConf file_group_conf = 15;
repeated int32 float_slot_dims = 16;
/// Note the field number 17, 18 and 19 have been deprecated.
// a list of values which will be used to create additional one dimensional
// float
// values slots. These one dimensional slots can be used as the weight input
// for cost layers.
// Currently this is only supported by ProtoDataProvider.
repeated double constant_slots = 20;
// for PyDataProvider.
// Specify the load data script module name, object name and user args
optional string load_data_module = 21;
optional string load_data_object = 22;
optional string load_data_args = 23;
// for MultiDataProvider
repeated DataConfig sub_data_configs = 24; // sub dataproviders
/*
* the ratio of each sub dataproviders:
* e.g. sub dataprovider A's ratio is 1, B's ratio is 9, batch_size is 100,
* then each mini-batch is combined by 10 instance from A and 90 instances
* from B.
*/
optional int32 data_ratio = 25;
/*
* if one of the sub dataproviders is running out of data, then
* (1) it is "main data", then finish current pass.
* (2) it is not "main data", then reset it, and try getNextBatch again.
*/
optional bool is_main_data = 26 [ default = true ];
// the usage ratio of instances. Setting to 1.0 means the use of all
// instances.
optional double usage_ratio = 27 [ default = 1.0 ];
};
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
syntax = "proto2";
package paddle;
/*
If values is not empty and ids is empty, this is a dense vector.
If values is not empty and ids is not empty, this is a sparse vector. The
position of each value
is specified by ids.
If values is empty and ids is not empty, this is a sparse vector whose non-zero
values are 1.
The position of each 1 is specified by ids.
*/
message VectorSlot {
repeated float values = 1 [ packed = true ];
repeated uint32 ids = 2 [ packed = true ];
/* For multidimensional data, for example "image width height depth" */
repeated uint32 dims = 3 [ packed = true ];
repeated string strs = 4;
};
/*
SubseqSlot use to record whether VectorSlot or any other slot in future has
subseq.
If not all VectorSlot have subseq, we only store the one who has subseq, and
use *slot_id* to record it.
One vector_slots has one sequence, and it may have N subseq, thus the number of
*lens* will be N too.
*/
message SubseqSlot {
required uint32 slot_id = 1; // the id of slot who has subseq
repeated uint32 lens = 2; // lengths of sub-sequence in the slot
};
message SlotDef {
enum SlotType {
VECTOR_DENSE = 0;
VECTOR_SPARSE_NON_VALUE = 1;
VECTOR_SPARSE_VALUE = 2;
INDEX = 3; // This can be used as label, or word id, etc.
VAR_MDIM_DENSE = 4;
VAR_MDIM_INDEX = 5;
STRING = 6;
}
required SlotType type = 1;
required uint32 dim =
2; // For INDEX slots, this means the maximal index plus 1.
};
message DataHeader {
// INDEX slot should be always after VECTOR slots.
repeated SlotDef slot_defs = 1;
};
message DataSample {
optional bool is_beginning = 1
[ default = true ]; // is the beginning of a sequence
repeated VectorSlot vector_slots = 2;
repeated uint32 id_slots = 3 [ packed = true ];
/* use ids of VectorSlot */
repeated VectorSlot var_id_slots = 4;
repeated SubseqSlot subseq_slots = 5;
};
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
syntax = "proto2";
import "ParameterConfig.proto";
package paddle;
/**
* Various structs for the configuration of a neural network
*/
message ExternalConfig {
repeated string layer_names = 1;
repeated string input_layer_names = 2;
repeated string output_layer_names = 3;
}
message ActivationConfig {
// identity: f(x) = x
// sigmoid: f(x) = 1 / (1 + exp(-x))
// logistic: f(x) = (1 - exp(-x)) / (1+ exp(-x))
// softmax: y_i = f(x_i) = exp(x_i) / (\sum_i exp(x_i))
// relu: y = max(0, x)
required string type = 1;
};
message ConvConfig {
// filter_size = 5, says that this layer will use
// filters of size 5x5 pixels.
required uint32 filter_size = 1;
// The image data dimensionality.
// This value must be either 1, 2, 3, or a multiple of 4.
required uint32 channels = 2;
// stride = 1, indicates that the distance between
// successive filter applications should be 1 pixel.
required uint32 stride = 3;
// padding = 4, instructs the net to implicitly
// pad the images with a 4-pixel border of zeros.
required uint32 padding = 4;
// If groups = 4 together with the filters = 32 parameter,
// they state that this convolutional layer is to have 4
// groups of 32 filters. Each filter will connect to 8
// input channels.
required uint32 groups = 5;
required uint32 filter_channels = 6;
// The size of output feature map.
required uint32 output_x = 7;
// The size of input feature map.
required uint32 img_size = 8;
// caffe mode for output size coherence
required bool caffe_mode = 9 [ default = true ];
// if filter_size_y is set , this convolutional layer will use
// filters of size filter_size * filter_size_y pixels.
// if filter_size_y is not set, this convolutional layer will use
// filters of size filter_size * filter_size
required uint32 filter_size_y = 10;
required uint32 padding_y = 11;
required uint32 stride_y = 12;
// if not set, use output_x
optional uint32 output_y = 13;
// if not set, use img_size
optional uint32 img_size_y = 14;
optional uint32 dilation = 15 [ default = 1 ];
optional uint32 dilation_y = 16 [ default = 1 ];
optional uint32 filter_size_z = 17 [ default = 1 ];
optional uint32 padding_z = 18 [ default = 1 ];
optional uint32 stride_z = 19 [ default = 1 ];
optional uint32 output_z = 20 [ default = 1 ];
optional uint32 img_size_z = 21 [ default = 1 ];
}
message PoolConfig {
// max or avg pooling
required string pool_type = 1;
required uint32 channels = 2;
// Defines the size of the pooling region in
// the x (equivalently, y) dimension.
required uint32 size_x = 3;
// Tell the net where in the input image to start the pooling.
// start is deprecated now.
optional uint32 start = 4;
// Defines the stride size between successive pooling squares.
required uint32 stride = 5 [ default = 1 ];
// The size of output feature map.
required uint32 output_x = 6;
// The size of input feature map.
required uint32 img_size = 7;
// padding = 4, instructs the net to implicitly
// pad the images with a 4-pixel border of zeros.
optional uint32 padding = 8 [ default = 0 ];
// if not set, use size_x
optional uint32 size_y = 9;
// if not set, use stride
optional uint32 stride_y = 10;
// if not set, use output_x
optional uint32 output_y = 11;
// if not set, use img_size
optional uint32 img_size_y = 12;
// if not set, use padding
optional uint32 padding_y = 13;
optional uint32 size_z = 14 [ default = 1 ];
optional uint32 stride_z = 15 [ default = 1 ];
optional uint32 output_z = 16 [ default = 1 ];
optional uint32 img_size_z = 17 [ default = 1 ];
optional uint32 padding_z = 18 [ default = 1 ];
optional bool exclude_mode = 19;
}
message SppConfig {
required ImageConfig image_conf = 1;
required string pool_type = 2;
required uint32 pyramid_height = 3;
}
message NormConfig {
// rnorm or cmrnorm
required string norm_type = 1;
required uint32 channels = 2;
// rnorm: this defines the size of the local regions
// used for response normalization.
// cmrnorm: The size parameter indicates how many
// nearby maps to use for normalization.
required uint32 size = 3;
// the parameters for normalization
// u = u / (1+scale*sum(u^2 in window))^pow
required double scale = 4;
required double pow = 5;
// The size of output feature map.
required uint32 output_x = 6;
// The size of input feature map.
required uint32 img_size = 7;
// normalize with fixed window or sliding window
// u = u / (1+scale*sum(u^2 in window))^pow
// fixed window: shared a fixed window for each value
// sliding window: have a different window for each value
optional bool blocked = 8;
// if not set, use output_x
optional uint32 output_y = 9;
// if not set, use img_size
optional uint32 img_size_y = 10;
}
message BlockExpandConfig {
required uint32 channels = 1;
required uint32 stride_x = 2;
required uint32 stride_y = 3;
required uint32 padding_x = 4;
required uint32 padding_y = 5;
required uint32 block_x = 6;
required uint32 block_y = 7;
// The size of output feature map.
required uint32 output_x = 8;
required uint32 output_y = 9;
// The size of input feature map.
required uint32 img_size_x = 10;
required uint32 img_size_y = 11;
}
message MaxOutConfig {
required ImageConfig image_conf = 1;
required uint32 groups = 2;
}
message RowConvConfig { required uint32 context_length = 1; }
message SliceConfig {
required uint32 start = 1;
required uint32 end = 2;
}
message ProjectionConfig {
required string type = 1;
required string name = 2;
required uint64 input_size = 3;
required uint64 output_size = 4;
// For ShiftProjection
optional int32 context_start = 5;
optional int32 context_length = 6;
optional bool trainable_padding = 7 [ default = false ];
// For convolution
optional ConvConfig conv_conf = 8;
optional int32 num_filters = 9;
// For IdentityOffsetProjection
optional uint64 offset = 11 [ default = 0 ];
// For pool
optional PoolConfig pool_conf = 12;
// For slice
// Each slice output is the input[start, end)
repeated SliceConfig slices = 13;
}
message OperatorConfig {
required string type = 1;
repeated int32 input_indices = 2;
repeated uint64 input_sizes = 3;
required uint64 output_size = 4;
// For DotMulOperator
optional double dotmul_scale = 5 [ default = 1.0 ];
// For ConvOperator
optional ConvConfig conv_conf = 6;
optional int32 num_filters = 7;
}
message BilinearInterpConfig {
// The size of input feature map.
required ImageConfig image_conf = 1;
// The size of output feature map.
required uint32 out_size_x = 2;
required uint32 out_size_y = 3;
}
message ImageConfig {
// The image data dimensionality.
// This value must be either 1, 2, 3, or a multiple of 4.
required uint32 channels = 2;
// The size of input feature map.
required uint32 img_size = 8;
optional uint32 img_size_y = 9;
optional uint32 img_size_z = 10 [ default = 1 ];
}
message PriorBoxConfig {
repeated uint32 min_size = 1;
repeated uint32 max_size = 2;
repeated float aspect_ratio = 3;
repeated float variance = 4;
}
message PadConfig {
required ImageConfig image_conf = 1;
repeated uint32 pad_c = 2;
repeated uint32 pad_h = 3;
repeated uint32 pad_w = 4;
}
message ReshapeConfig {
repeated uint32 height_axis = 1;
repeated uint32 width_axis = 2;
}
message MultiBoxLossConfig {
required uint32 num_classes = 1;
required float overlap_threshold = 2;
required float neg_pos_ratio = 3;
required float neg_overlap = 4;
required uint32 background_id = 5;
required uint32 input_num = 6;
optional uint32 height = 7 [ default = 1 ];
optional uint32 width = 8 [ default = 1 ];
}
message DetectionOutputConfig {
required uint32 num_classes = 1;
required float nms_threshold = 2;
required uint32 nms_top_k = 3;
required uint32 background_id = 4;
required uint32 input_num = 5;
required uint32 keep_top_k = 6;
required float confidence_threshold = 7;
optional uint32 height = 8 [ default = 1 ];
optional uint32 width = 9 [ default = 1 ];
}
message ClipConfig {
required double min = 1;
required double max = 2;
}
message UpsampleConfig {
required ImageConfig image_conf = 1;
optional uint32 scale = 2 [ default = 2 ];
optional uint32 scale_y = 3 [ default = 2 ];
optional bool pad_out_x = 4 [ default = false ];
optional bool pad_out_y = 5 [ default = false ];
optional uint32 upsample_size = 6;
optional uint32 upsample_size_y = 7;
}
message ROIPoolConfig {
required uint32 pooled_width = 1;
required uint32 pooled_height = 2;
required float spatial_scale = 3;
optional uint32 height = 4 [ default = 1 ];
optional uint32 width = 5 [ default = 1 ];
}
message ScaleSubRegionConfig {
required ImageConfig image_conf = 1;
required float value = 2;
}
message LayerInputConfig {
required string input_layer_name = 1;
optional string input_parameter_name = 2;
optional ConvConfig conv_conf = 3;
optional PoolConfig pool_conf = 4;
optional NormConfig norm_conf = 5;
optional ProjectionConfig proj_conf = 6;
optional BlockExpandConfig block_expand_conf = 7;
optional ImageConfig image_conf = 8;
// If the input layer has multi-output.
// Set the argument name.
optional string input_layer_argument = 9;
optional BilinearInterpConfig bilinear_interp_conf = 10;
optional MaxOutConfig maxout_conf = 11;
optional SppConfig spp_conf = 12;
optional PriorBoxConfig priorbox_conf = 13;
optional PadConfig pad_conf = 14;
optional RowConvConfig row_conv_conf = 15;
optional MultiBoxLossConfig multibox_loss_conf = 16;
optional DetectionOutputConfig detection_output_conf = 17;
optional ClipConfig clip_conf = 18;
optional ScaleSubRegionConfig scale_sub_region_conf = 19;
optional ROIPoolConfig roi_pool_conf = 20;
optional UpsampleConfig upsample_conf = 21;
}
message LayerConfig {
required string name = 1;
required string type = 2;
optional uint64 size = 3;
// optional ActivationConfig activation = 4;
optional string active_type = 4;
repeated LayerInputConfig inputs = 5;
optional string bias_parameter_name = 6;
// This number must be a multiple of 16.
optional uint32 num_filters = 7;
// indicates that the biases of every filter in this layer
// should be shared amongst all applications of that filter
// (which is how convnets are usually trained). Setting this to
// false will untie the biases, yielding a separate bias for
// every location at which the filter is applied.
optional bool shared_biases = 8 [ default = false ];
// Valid values are ones that divide the area of the output
// grid in this convolutional layer. For example if this layer
// produces 32-channel 20x20 output grid, valid values of
// partialSum are ones which divide 20*20 = 400.
// I'll update this comments when confirmed
optional uint32 partial_sum = 9;
// for dropout
optional double drop_rate = 10;
// for HierarchicalSoftmaxLayer and NCELayer
// the number of classes
optional uint32 num_classes = 11;
// the gpu device which the Layer's data in.
// Only used by ParallelNeuralNetork. Ignored otherwise.
optional int32 device = 12 [ default = -1 ];
// for recurrent layer. If true, the recurrence runs from the end to the
// beginning.
optional bool reversed = 13 [ default = false ];
// for lstmemory layer. Different types of nodes have different activation
// type.
optional string active_gate_type = 14;
optional string active_state_type = 15;
// For NCELayer
// The number of random negative labels for each sample
optional int32 num_neg_samples = 16 [ default = 10 ];
// For NCELayer
// The distribution for generating the random negative labels.
// A uniform distribution will be used if not provided
repeated double neg_sampling_dist = 17 [ packed = true ];
// For MaxLayer
// default: output VALUE of MaxLayer. set this flag to true for output INDEX
// INDEX will be put in Argument::value as double values.
optional bool output_max_index = 19 [ default = false ];
/// The filed number 20 have been deprecated.
// For self-normalized estimation
optional double softmax_selfnorm_alpha = 21 [ default = 0.1 ];
/// The filed numbers 22 and 23 have been deprecated.
// for MDLstmLayer
repeated bool directions = 24;
// for CTCLayer
optional bool norm_by_times = 25;
// for CostLayers
optional double coeff = 26 [ default = 1.0 ];
// for AverageLayer
// can be set to: 'average', 'sum' or 'squarerootn'
optional string average_strategy = 27;
// for error clipping
optional double error_clipping_threshold = 28 [ default = 0.0 ];
// for operators used by mixed layer
repeated OperatorConfig operator_confs = 29;
// for lambdaCost
optional int32 NDCG_num = 30;
optional int32 max_sort_size = 31;
// for SlopeInterceptLayer
optional double slope = 32;
optional double intercept = 33;
// for CosSimVecMatLayer and CosSimLayer
optional double cos_scale = 34;
// for DataNormLayer
// can be set to: 'z-score', 'min-max' or 'decimal-scaling'
optional string data_norm_strategy = 36;
// for bos/eos id
optional uint32 bos_id = 37;
optional uint32 eos_id = 38;
// for max id layer
optional uint32 beam_size = 39;
// for seqlastins layer, whether select first instead last
optional bool select_first = 40 [ default = false ];
// for seqlastins layer, AverageLayer, MaxLayer and ExpandLayer
// can be set to: 'non-seq','seq'
optional string trans_type = 41 [ default = 'non-seq' ];
// to indicate whether selective_fc layer
// is used in sequence generation or not
optional bool selective_fc_pass_generation = 42 [ default = false ];
// to indicate whether selective_fc layer take its last input to
// selected several columns and only compute the multiplications
// between the input matrices and the selected columns of
// the parameter matrices of this layer.
// if set false, selective_fc degrades into fc.
optional bool has_selected_colums = 43 [ default = true ];
// this parameter is for speed consideration.
// if number of the selected columns is less than
// sample number * selective_fc output size * selective_fc_mull_mull_ratio
// sparse multiplication is used, otherwise, using full multiplication.
optional double selective_fc_full_mul_ratio = 44 [ default = 0.02 ];
// to indicate how many threads selective_fc use to to accelate
// the plain_mul period
// leave empty or set to 0 to disable multi-thread accleleration
optional uint32 selective_fc_parallel_plain_mul_thread_num = 45
[ default = 0 ];
// for batch normalization layer
// if set use_global_stats true, will use the loaded mean and variance.
optional bool use_global_stats = 46;
// use to compute moving mean and variance.
optional double moving_average_fraction = 47 [ default = 0.9 ];
// bias size
optional uint32 bias_size = 48 [ default = 0 ];
// this parameter can be used as a user-defined parameter when necessary,
// without changing the proto file.
// e.g., when a new layer with a user-defined parameter is implemented,
// it can be used to pass that parameter, without modifying the proto file.
// string type is used for flexibility: different types can be converted
// to string and reinterpreted in the user's own layer implementation.
optional string user_arg = 49;
// to indicate rectangle image data
optional uint64 height = 50;
optional uint64 width = 51;
// blank label used in ctc loss
optional uint32 blank = 52 [ default = 0 ];
// stride parameter for seqlastins layer, AverageLayer, MaxLayer, which
// controls the scope of pooling operation. can be set > 0.
// leave empty or set to -1 to disable this stride pooling.
optional int32 seq_pool_stride = 53 [ default = -1 ];
// for crop layer
optional int32 axis = 54 [ default = 2 ];
repeated uint32 offset = 55;
repeated uint32 shape = 56;
// for HuberRegressionLoss
optional double delta = 57 [ default = 1.0 ];
// for 3D data
optional uint64 depth = 58 [ default = 1 ];
// for switch order layer
optional ReshapeConfig reshape_conf = 59;
// for batch normalization layer
// The small constant added to the variance to improve numeric stability.
optional double epsilon = 60 [ default = 0.00001 ];
// for factorization machine layer
optional uint32 factor_size = 61;
}
message EvaluatorConfig {
required string name = 1;
required string type = 2;
repeated string input_layers = 3;
// Used by ChunkEvaluator
// one of "IOB", "IOE", "IOBES"
optional string chunk_scheme = 4;
// number of chunk types other than "other"
optional int32 num_chunk_types = 5;
// Used by PrecisionRecallEvaluator and ClassificationErrorEvaluator
// For multi binary labels: true if output > classification_threshold
optional double classification_threshold = 6 [ default = 0.5 ];
// The positive label. -1 means average precision and recall
optional int32 positive_label = 7 [ default = -1 ];
// load dict from this file
optional string dict_file = 8;
// dump result in this file
optional string result_file = 9;
// top # results for max id printer
optional int32 num_results = 10 [ default = 1 ];
// whether to delimit the sequence in the seq_text_printer
optional bool delimited = 11 [ default = true ];
// Used by ChunkEvaluator
// chunk of these types are not counted
repeated int32 excluded_chunk_types = 12;
// Used by ClassificationErrorEvaluator
// top # classification error
optional int32 top_k = 13 [ default = 1 ];
// Used by DetectionMAPEvaluator
optional double overlap_threshold = 14 [ default = 0.5 ];
optional int32 background_id = 15 [ default = 0 ];
optional bool evaluate_difficult = 16 [ default = false ];
optional string ap_type = 17 [ default = "11point" ];
}
message LinkConfig {
required string layer_name = 1;
required string link_name = 2;
// If true, this link has sub-sequence
optional bool has_subseq = 3 [ default = false ];
}
message MemoryConfig {
required string layer_name = 1;
required string link_name = 2;
optional string boot_layer_name = 3;
optional string boot_bias_parameter_name = 4;
optional string boot_bias_active_type = 5;
optional uint32 boot_with_const_id = 7;
// memory is a sequence, initailized by a sequence boot layer
optional bool is_sequence = 6 [ default = false ];
}
message GeneratorConfig {
required uint32 max_num_frames = 1;
required string eos_layer_name = 2;
optional int32 num_results_per_sample = 3 [ default = 1 ];
// for beam search
optional int32 beam_size = 4 [ default = 1 ];
optional bool log_prob = 5 [ default = true ];
}
message SubModelConfig {
required string name = 1;
repeated string layer_names = 2; // selected layers in sub model
repeated string input_layer_names = 3;
repeated string output_layer_names = 4;
repeated string evaluator_names = 5;
optional bool is_recurrent_layer_group = 6 [ default = false ];
// If true, the recurrence runs from the end to the beginning.
optional bool reversed = 7 [ default = false ];
// name and link name of memory
repeated MemoryConfig memories = 8;
// if use recurrent layer group, all layers in submodel will postfix by
// "_in_"+submodel.name, so we add a name pair to link between
// root model and layer group,
// note that these in/out layers are not input/output of the network.
repeated LinkConfig in_links = 9;
repeated LinkConfig out_links = 10;
optional GeneratorConfig generator = 11;
// the id of inlink which share info with outlinks, used in recurrent layer
// group
optional int32 target_inlinkid = 12;
}
message ModelConfig {
// type of the model.
// Currently, "nn", "recurrent_nn" and "recursive_nn" are supported
required string type = 1 [ default = "nn" ];
// layers should be ordered in such a way that the forward propagation
// can be correctly executed by going from the first layer to the last layer
repeated LayerConfig layers = 2;
repeated ParameterConfig parameters = 3;
// Input layers should have the same order as the data streams provided
// by the data provider. The type of input layers should be "data"
repeated string input_layer_names = 4;
// For training, the type of a output layer is usually cost layer.
// For prediction, they should be the actual output layers.
repeated string output_layer_names = 5;
repeated EvaluatorConfig evaluators = 6;
repeated SubModelConfig sub_models = 8;
// For External Machine, defining how to split a neural network
// into multiple parts.
optional ExternalConfig external_config = 9;
};
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
syntax = "proto2";
option optimize_for = LITE_RUNTIME;
package paddle;
message SGDConfig {
// SGD
// momentum: float >= 0. Parameter updates momentum.
// decay: float >= 0. Learning rate decay over each update.
// nesterov: boolean. Whether to apply Nesterov momentum.
optional double momentum = 21 [ default = 0.0 ];
optional double decay = 23 [ default = 0.0 ];
optional bool nesterov = 24 [ default = false ];
}
message AdadeltaConfig {
// Adadelta
// It is recommended to leave it at the default value.
// rho: float >= 0.
// epsilon: float >= 0. Fuzz factor.
// decay: float >= 0. Learning rate decay over each update.
// reference : [Adadelta - an adaptive learning rate
// method](http://arxiv.org/abs/1212.5701)
optional double rho = 33 [ default = 0.90 ];
optional double epsilon = 31 [ default = 1e-5 ];
optional double decay = 32 [ default = 0.0 ];
}
message AdagradConfig {
// Adagrad
// epsilon: float >= 0.
// decay: float >= 0. Learning rate decay over each update.
// reference : [Adaptive Subgradient Methods for Online Learning and
// Stochastic
// Optimization](http://www.jmlr.org/papers/volume12/duchi11a/duchi11a.pdf)
optional double epsilon = 41 [ default = 1e-5 ];
optional double decay = 42 [ default = 0.0 ];
}
message AdamConfig {
// Adaj
// beta_1: float, 0 < beta < 1. Generally close to 1.
// beta_2: float, 0 < beta < 1. Generally close to 1.
// epsilon: float >= 0. Fuzz factor.
// decay: float >= 0. Learning rate decay over each update.
// reference : [Adam - A Method for Stochastic
// Optimization](http://arxiv.org/abs/1412.6980v8)
optional double beta_1 = 41;
optional double beta_2 = 42;
optional double epsilon = 43;
optional double decay = 44;
}
message ConstLrConfig {
// learninRate Policy
optional double learning_rate = 1 [ default = 1.0 ];
}
message LinearLrConfig {
// learninRate Policy
optional double learning_rate = 1 [ default = 1.0 ];
optional double lr_decay_a = 2;
optional double lr_decay_b = 3;
}
message TensorProto {
enum DataType {
PADDLE_ELEMENT_TYPE_INT32 = 0;
PADDLE_ELEMENT_TYPE_UINT32 = 1;
PADDLE_ELEMENT_TYPE_INT64 = 2;
PADDLE_ELEMENT_TYPE_UINT64 = 3;
PADDLE_ELEMENT_TYPE_FLOAT32 = 4;
PADDLE_ELEMENT_TYPE_FLOAT64 = 5;
}
optional DataType data_type = 1;
repeated bytes content = 2;
}
message LrPolicyState {
// learninRate Policy
optional double learning_rate = 1 [ default = 1.0 ];
optional double lr_decay_a = 2;
optional double lr_decay_b = 3;
}
message SGDOptimizerState {
optional LrPolicyState lr_state = 101;
optional double num_sample_passed = 104;
// state
optional TensorProto parameter = 1;
optional TensorProto momentums = 2;
}
message AdadeltaOptimizerState {
// learning rate policy
optional LrPolicyState lr_state = 101;
optional double num_sample_passed = 104;
// state
optional TensorProto parameter = 1;
optional TensorProto accum_gradient = 2;
optional TensorProto accum_delta = 3;
optional TensorProto update_delta = 4;
}
message AdagradOptimizerState {
optional LrPolicyState lr_state = 101;
optional double num_sample_passed = 104;
// state
optional TensorProto parameter = 1;
optional TensorProto accum_gradient = 2;
}
message AdamOptimizerState {
optional LrPolicyState lr_state = 101;
optional double num_sample_passed = 104;
// state
optional TensorProto parameter = 1;
optional TensorProto momentums = 2;
optional TensorProto velocitys = 3;
}
message OptimizerConfig {
enum Optimizer {
SGD = 1;
Adadelta = 2;
Adagrad = 3;
Adam = 4;
}
optional Optimizer optimizer = 1;
optional SGDConfig sgd = 3;
optional AdadeltaConfig adadelta = 4;
optional AdagradConfig adagrad = 5;
optional AdamConfig adam = 6;
enum LrPolicy {
Const = 0;
Linear = 1;
}
optional LrPolicy lr_policy = 11;
optional ConstLrConfig const_lr = 12;
optional LinearLrConfig linear_lr = 13;
// common config of optimizer
// gradient clip when L2 exceeding value
optional double clip_norm = 101;
// gradient clip when L1 exceeding value
optional double clip_value = 102;
}
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
syntax = "proto2";
package paddle;
/**
* Configuration structure for parameter
*/
enum ParameterInitStrategy {
PARAMETER_INIT_NORMAL = 0;
PARAMETER_INIT_UNIFORM = 1;
}
message ParameterUpdaterHookConfig {
// hook type such as 'pruning'
required string type = 1;
// this represents the ratio of zero element to be set by the Parameter
optional double sparsity_ratio = 2 [ default = 0.6 ];
}
message ParameterConfig {
required string name = 1;
required uint64 size = 2;
optional double learning_rate = 3 [ default = 1.0 ];
optional double momentum = 4 [ default = 0.0 ];
optional double initial_mean = 5 [ default = 0.0 ];
optional double initial_std = 6 [ default = 0.01 ];
// use L2-regularization if decay_rate set and decay_rate_l1 not set
optional double decay_rate = 7 [ default = 0.0 ];
// use L1-regularization if decay_rate_l1 set
optional double decay_rate_l1 = 8 [ default = 0.0 ];
// dims of Parameter, e.g. dims[0] as height, dims[1] as width..
repeated uint64 dims = 9;
// the gpu device which the parameter in.
// Only used by ParallelNeuralNetork. Ignored otherwise.
optional int32 device = 10 [ default = -1 ];
// how to init the parameter: 0 -> normal, 1 -> uniform
// 0: treat initial_mean as mean, intial_std as standard deviation
// 1: range is (initial_mean - initial_std) to (initial_mean + initial_std)
optional int32 initial_strategy = 11 [ default = 0 ];
// define the variance when init the parameter, by height of the Matrix
optional bool initial_smart = 12 [ default = false ];
// apply regularization every # batches
optional int32 num_batches_regularization = 13 [ default = 1 ];
// if is_sparse is true, para is sparse, else para is dense
optional bool is_sparse = 14 [ default = false ];
// if para is sparse, format should be "csc" or "csr", empty means is not
// sparse
optional string format = 15 [ default = "" ];
// sparse remote update or not
optional bool sparse_remote_update = 16 [ default = false ];
// gradient clipping threshold, no clipping by default
optional double gradient_clipping_threshold = 17 [ default = 0.0 ];
// static parameters are fixed when training
optional bool is_static = 18 [ default = false ];
// para_id should NOT be set by config_parser. It is for
// internal use.
optional uint64 para_id = 19;
repeated ParameterUpdaterHookConfig update_hooks = 20;
// setup load mat -> csr
optional bool need_compact = 21 [ default = false ];
// whether to do sparse update for this parameter
optional bool sparse_update = 22 [ default = false ];
// whether this parameter is shared or not.
optional bool is_shared = 23 [ default = false ];
// parameter block size
optional uint64 parameter_block_size = 24 [ default = 0 ];
}
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
syntax = "proto2";
package paddle;
/**
* Configuration structure for ParameterClient2.
*/
message ParameterClientConfig { required int32 trainer_id = 1; }
/**
* Configuration structure for ParameterServer2.
*/
message ParameterServerConfig {
// Number of ports for sending dense parameter,
// following ports on parameter server will be visited
// for sending dense parameter: [port, port+ports_num-1]
required int32 ports_num = 1 [ default = 1 ];
// Number of ports for sending sparse parameter,
// following ports on parameter server will be visited
// for sending sparse parameter:
// [port+ports_num, port+ports_num+ports_num_for_sparse-1]
required int32 ports_num_for_sparse = 2 [ default = 0 ];
// network device name for pservers
required string nics = 3 [ default = "xgbe0,xgbe1" ];
required string rdma_tcp = 4 [ default = "tcp" ];
// Listening port for pserver
required int32 port = 5 [ default = 20134 ];
// number of gradient servers
required int32 num_gradient_servers = 6 [ default = 1 ];
// number of threads for sync op exec
required int32 pserver_num_threads = 7 [ default = 1 ];
// control config_.async_lagged_grad_discard_ratio() min value
required double async_lagged_ratio_min = 8 [ default = 1.0 ];
// if async_lagged_grad_discard_ratio is not set in trainer_config.conf
// use it as defalut value
required double async_lagged_ratio_default = 9 [ default = 1.5 ];
}
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
syntax = "proto2";
import "ParameterConfig.proto";
import "TrainerConfig.proto";
package paddle;
/**
* Various structs for communicating with parameter server
*/
enum ParameterUpdateMode {
// Set parameter
PSERVER_UPDATE_MODE_SET_PARAM = 0; // use local param
PSERVER_UPDATE_MODE_SET_PARAM_ZERO = 1; // set zero param
// Update parameter once a gradient is received
PSERVER_UPDATE_MODE_ASYNC_SGD = 2;
// Accumulate gradient
PSERVER_UPDATE_MODE_ADD_GRADIENT = 3;
// Average parameters
PSERVER_UPDATE_MODE_AVERAGE_PARAMETER = 4;
// No update. Only get parameters back.
PSERVER_UPDATE_MODE_GET_PARAM = 5;
PSERVER_UPDATE_MODE_GET_PARAM_SPARSE = 6; // only get sparse rows
};
message ParameterBlock {
// it accurately means parameter id.
required uint64 para_id = 1;
// global sparse row or dense block for each block in parameter
required uint64 block_id = 2;
// offset in (local) storage
required uint64 begin_pos = 3;
// actual size of block, size for last block is [endDim -beginDim],
// others is parameter_block_size in ParameterConfig
required uint64 block_size = 4;
}
enum PServerStatus {
PSERVER_STATUS_NOT_SET = 0;
PSERVER_STATUS_PARAMETER_READY = 1;
};
enum BatchStatus {
BATCH_START = 0;
BATCH_ON = 1;
BATCH_FINISH = 2;
BATCH_START_AND_FINISH = 3;
};
message SendParameterRequest {
required ParameterUpdateMode update_mode = 1;
repeated ParameterBlock blocks = 2;
required bool send_back_parameter = 3;
// number of samples used for calculating this update
optional int64 num_samples = 4;
// cost will be used to calculate global objective value
optional double cost = 5;
required BatchStatus batch_status = 6;
optional int32 trainer_id = 7;
// send back parameter type on pserver, PARAMETER_VALUE by default
optional int32 send_back_parameter_type = 8 [ default = 0 ];
// forwardbackward time in usec
optional uint64 forwardbackward_time = 9;
}
message WaitPassStartRequest {}
message WaitPassStartResponse {}
message WaitPassFinishRequest {}
message WaitPassFinishResponse {}
enum SyncObject {
SYNC_DEFAULT = 0; // wait for the synchronizeBarrier_
SYNC_DATA = 1; // wait for the synchronizeDataBarrier_
}
message SynchronizeRequest {
required SyncObject sync_object_id = 1 [ default = SYNC_DEFAULT ];
optional int32 trainer_id = 2;
}
message SynchronizeResponse {}
message SendParameterResponse { repeated ParameterBlock blocks = 1; }
message SetConfigRequest {
repeated ParameterConfig param_configs = 1;
required OptimizationConfig opt_config = 2;
required string save_dir = 4;
required int32 server_id = 5;
required bool is_sparse_server = 6;
}
message SetConfigResponse {}
message GetStatusRequest {}
message GetStatusResponse { required PServerStatus status = 1; }
message SetStatusRequest { required PServerStatus status = 1; }
message SetStatusResponse {}
// create a column vector. The size is the dimension of parameter
message CreateVectorRequest {}
message CreateVectorResponse {
// error message. Empty if success
optional string return_message = 1;
required int64 handle = 2;
}
message ReleaseVectorRequest { required int64 handle = 1; }
message ReleaseVectorResponse {
// error message. Empty if success
optional string return_message = 1;
}
// Create a column major matrix. The number of rows is the dimension
// of parameter. The number of columns is specifed by num_cols
message CreateMatrixRequest { required int32 num_cols = 1; }
message CreateMatrixResponse {
// error message. Empty if success
optional string return_message = 1;
required int64 handle = 2;
}
message ReleaseMatrixRequest { required int64 handle = 1; }
message ReleaseMatrixResponse {
// error message. Empty if success
optional string return_message = 1;
}
/**
* The operations are defined using the variables commented at Operation
* and OperationResult
*/
enum MatrixVectorOperation {
// r = u^T u
PSERVER_OP_utu = 0;
// r = u^T v
PSERVER_OP_utv = 1;
// u = a u
PSERVER_OP_au = 2;
// v = a u + b v
PSERVER_OP_au_bv = 3;
// u = a A x + b u
PSERVER_OP_aAx_bu = 4;
// Stochastic gradient update
PSERVER_OP_SGD = 5;
// u = a
PSERVER_OP_RESET = 6;
// v = u
PSERVER_OP_COPY = 7;
// w = a u + b v + c w
PSERVER_OP_au_bv_cw = 8;
// owlqn: MakeSteepestDescDir
PSERVER_OP_MAKE_STEEPEST_DESC_DIR = 9;
// owlqn: FixDirSigns
PSERVER_OP_FIX_DIR_SIGNS = 10;
// owlqn: DirDeriv
PSERVER_OP_DIR_DERIV = 11;
// owlqn: FixOmegaSigns
PSERVER_OP_FIX_OMEGA_SIGNS = 12;
// Get overall cost
PSERVER_OP_COST = 13;
// Pass control
PSERVER_OP_START_PASS = 14;
PSERVER_OP_FINISH_PASS = 15;
// randomize value
PSERVER_OP_RANDOMIZE = 16;
// call optimizer apply
PSERVER_OP_APPLY = 17;
}
message ProtoVector {
required int64 dim = 1;
repeated double values = 2 [ packed = true ];
}
message ProtoMatrix {
required int64 num_rows = 1;
required int64 num_cols = 2;
repeated double values = 3 [ packed = true ];
}
message Operation {
required MatrixVectorOperation operation = 1;
// vector handles created on the pserver
repeated int64 pvectors = 2; // u, v, w
// matrix handles created on the pserver
repeated int64 pmatrices = 3; // A, B, C
repeated double scalars = 4; // a, b, c
repeated ProtoVector vectors = 5; // x, y, z
repeated ProtoMatrix matrices = 6; // X, Y, Z
}
message OperationResult {
// error message. Empty if success
optional string return_message = 1;
//
repeated double scalars = 2; // d, e, f
repeated ProtoVector vectors = 3; // p, q, r
repeated ProtoMatrix matrices = 4; // P, Q, R
}
message DoOperationRequest {
repeated Operation operations = 1;
// If true, wait for gradient to be ready before starting the operations
required bool wait_for_gradient = 2;
// If true, send back the parameter to clients after the operations are
// finished
required bool send_back_parameter = 3;
// If true, and if all clients call waitPassFinish,
// signal all clients finish the pass
required bool release_pass = 4;
}
message DoOperationResponse {
// error message. Empty if success
optional string return_message = 1;
repeated OperationResult results = 2;
required bool pass_finish = 3;
}
message LoadValueRequest { required string dir_name = 1; }
message LoadValueResponse {
// error message. Empty if success
optional string return_message = 1;
}
message SaveValueRequest { required string dir_name = 1; }
message SaveValueResponse {
// error message. Empty if success
optional string return_message = 1;
}
enum DataUpdateMode {
// Client send it's own data to pserver
DATA_UPDATE_MODE_SET_OWN = 0;
// Client get all user data from all pservers
DATA_UPDATE_MODE_GET_ALL = 1;
// Client send it's own ref feature to pserver
DATA_UPDATE_MODE_SET_REF = 2;
// Client get all ref featuers from all pservers
DATA_UPDATE_MODE_GET_REF = 3;
// Client send it's own ref label to pserver
DATA_UPDATE_MODE_SET_REF_LABEL = 4;
// Client get all ref labels from all pservers
DATA_UPDATE_MODE_GET_REF_LABEL = 5;
// Client send it's own ref grad to pserver
DATA_UPDATE_MODE_SET_REF_GRAD = 6;
// Client get all ref grad from all pservers
DATA_UPDATE_MODE_GET_REF_GRAD = 7;
}
enum SendDataType {
DATA_REF = 0;
DATA_REFLABEL = 1;
DATA_REFGRAD = 2;
DATA_REDUCE_SUM = 3;
}
enum TransDataType {
TRANS_INT32 = 0;
TRANS_UINT32_T = 1;
TRANS_INT64_T = 2;
TRANS_UINT64_T = 3;
TRANS_FLOAT = 5;
TRANS_DOUBLE = 6;
}
message DataBlock {
// total byte size of this data blcok
required uint64 total_size = 1;
// byte size of one data type
required int32 data_size = 2;
// data_type
optional TransDataType data_type = 3 [ default = TRANS_DOUBLE ];
}
message SendDataRequest {
required SendDataType type = 1;
required DataUpdateMode update_mode = 2;
repeated DataBlock blocks = 3;
required uint64 client_id = 4;
required uint64 server_id = 5;
}
message SendDataResponse {
required SendDataType type = 1;
repeated DataBlock blocks = 2;
required uint64 server_id = 3;
}
## protos in this folder are legacy v2 protos.
## Please refer to paddle/fluid for latest version.
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
syntax = "proto2";
import "DataConfig.proto";
import "ModelConfig.proto";
package paddle;
message OptimizationConfig {
optional int32 batch_size = 3 [ default = 1 ];
required string algorithm = 4 [ default = "async_sgd" ];
optional int32 num_batches_per_send_parameter = 5 [ default = 1 ];
optional int32 num_batches_per_get_parameter = 6 [ default = 1 ];
required double learning_rate = 7;
optional double learning_rate_decay_a = 8 [ default = 0 ];
optional double learning_rate_decay_b = 9 [ default = 0 ];
optional string learning_rate_schedule = 27 [ default = "constant" ];
// learning rate will be scaled according to learning_rate_schedule
// 1), constant:
// lr = learning_rate
// 2), poly:
// lr = learning_rate *
// pow(1 + learning_rate_decay_a * num_samples_processed,
// -learning_rate_decay_b)
// 3), exp:
// lr = learning_rate *
// pow(learning_rate_decay_a,
// num_samples_processed / learning_rate_decay_b)
// 4), discexp:
// lr = learning_rate *
// pow(learning_rate_decay_a,
// floor(num_samples_processed / learning_rate_decay_b))
// 5), linear:
// lr = max(learning_rate - learning_rate_decay_a * num_samples_processed,
// learning_rate_decay_b)
// owlqn related
// L1-regularization
optional double l1weight = 10 [ default = 0.1 ];
// L2-regularization
optional double l2weight = 11 [ default = 0 ];
// "c1" in wolfe condition: if (newobj <= oldobj + c1 * origDirDeriv * step)
// then accept the step
optional double c1 = 12 [ default = 0.0001 ];
// multiply the step with "backoff", when wolfe condition doesn't satisfy
optional double backoff = 13 [ default = 0.5 ];
// how many "s"s and "y"s are kept in owlqn
optional int32 owlqn_steps = 14 [ default = 10 ];
// accept the step if encountered "max_backoff" times of "reduce the step"
optional int32 max_backoff = 15 [ default = 5 ];
// L2-regularization coefficient is reduced linearly from iteration 0 to
// "l2weight_zero_iter", and set to 0 after "l2weight_zero_iter"
// iterations. set "l2weight_zero_iter" to 0 to disable this strategy.
optional int32 l2weight_zero_iter = 17 [ default = 0 ];
// averaged sgd
// About average_window * numBatchProcessed parameter are used
// for average. To be accurate, between average_window * numBatchProcessed
// and 2 * average_window * numBatchProcessed parameters are used for
// average.
optional double average_window = 18 [ default = 0 ];
optional int64 max_average_window = 19 [ default = 0x7fffffffffffffff ];
//////////////////////////
// Options Adaptive SGD //
//////////////////////////
// learning method for sgd/asgd, such as "momentum", "adagrad", "adadelta",
// "rmsprop"
// default learning method("momentum") use global decayed learning rate with
// momentum.
// "adagrad", "adadelta" and "rmsprop" can set momentum too.
optional string learning_method = 23 [ default = "momentum" ];
optional double ada_epsilon = 24 [ default = 1e-6 ];
optional double ada_rou = 26 [ default = 0.95 ];
// Force to do average in cpu in order to save gpu memory usage
optional bool do_average_in_cpu = 25 [ default = false ];
// delta add rate in pserver, used while num_batches_per_send_parameter>1
// will be divided by #machines automatically.
optional double delta_add_rate = 28 [ default = 1.0 ];
// We split a large size into smaller mini-batches, whose sizes are
// determined by mini_batch_size. It only takes effect when there is
// an ExternalMachine.
optional int32 mini_batch_size = 29 [ default = 128 ];
// automatically set if any one of parameters set sparse remote update flag
optional bool use_sparse_remote_updater = 30 [ default = false ];
// how to update center parameter and feedback to local parameter,
// when use local sgd update in cluster training.
// A option is elastic_average, proposed by the paper: Deep learning with
// elastic averaging SGD.
// If use elastic_average method, every trainer node should sample from whole
// data sets.
optional string center_parameter_update_method = 31 [ default = "average" ];
// shrink sparse parameter value
// only works if parameter is remote sparse update and has L1 decay rate
optional double shrink_parameter_value = 32 [ default = 0 ];
////////////////////////////
// Options Adam Optimizer //
////////////////////////////
optional double adam_beta1 = 33 [ default = 0.9 ];
optional double adam_beta2 = 34 [ default = 0.999 ];
optional double adam_epsilon = 35 [ default = 1e-8 ];
// arguments for learning rate scheduler
// Format: num1:rate1,num2:rate2,...,numK:rateK
// For learning_rate_schedule="manual", num is the number of samples,
// For learning_rate_schedule="pass_manual",
// num is the number of passes (starting from 0)
optional string learning_rate_args = 36 [ default = "" ];
// for async sgd gradient commit control.
// when async_lagged_grad_discard_ratio * num_gradient_servers commit passed,
// current async gradient will be discard silently.
optional double async_lagged_grad_discard_ratio = 37 [ default = 1.5 ];
// global threshold for gradient clipping
optional double gradient_clipping_threshold = 38 [ default = 0.0 ];
};
message TrainerConfig {
optional ModelConfig model_config = 1;
optional DataConfig data_config = 2;
required OptimizationConfig opt_config = 3;
optional DataConfig test_data_config = 4;
repeated string config_files = 5;
// the directory to save/load model files for each training path
optional string save_dir = 6 [ default = "./output/model" ];
// Path of the initial model parameters.
// If it was set, start_pass will be ignored.
optional string init_model_path = 7;
// Start training from this pass.
// Will load parameter from the previous pass.
optional int32 start_pass = 8 [ default = 0 ];
// file path to the trainer config file
optional string config_file = 9;
}
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册