提交 91c6a792 编写于 作者: S Superjom

Merge branch 'develop' of github.com:PaddlePaddle/Paddle into network

...@@ -113,7 +113,7 @@ include(coveralls) # set code coverage ...@@ -113,7 +113,7 @@ include(coveralls) # set code coverage
include_directories("${PROJ_ROOT}") include_directories("${PROJ_ROOT}")
include_directories("${PROJ_ROOT}/paddle/cuda/include") include_directories("${PROJ_ROOT}/paddle/cuda/include")
include_directories("${CMAKE_CURRENT_BINARY_DIR}/proto") include_directories("${CMAKE_CURRENT_BINARY_DIR}/proto")
include_directories("${CMAKE_CURRENT_BINARY_DIR}/go/pserver/cclient") include_directories("${CMAKE_CURRENT_BINARY_DIR}/go/pserver/client/c")
include_directories(${Boost_INCLUDE_DIRS}) include_directories(${Boost_INCLUDE_DIRS})
set(EXTERNAL_LIBS set(EXTERNAL_LIBS
......
...@@ -2,10 +2,10 @@ INCLUDE(ExternalProject) ...@@ -2,10 +2,10 @@ INCLUDE(ExternalProject)
SET(ANY_SOURCE_DIR ${THIRD_PARTY_PATH}/any) SET(ANY_SOURCE_DIR ${THIRD_PARTY_PATH}/any)
INCLUDE_DIRECTORIES(${ANY_SOURCE_DIR}/src/linb_any) INCLUDE_DIRECTORIES(${ANY_SOURCE_DIR}/src/extern_lib_any)
ExternalProject_Add( ExternalProject_Add(
linb_any extern_lib_any
${EXTERNAL_PROJECT_LOG_ARGS} ${EXTERNAL_PROJECT_LOG_ARGS}
GIT_REPOSITORY "https://github.com/thelink2012/any.git" GIT_REPOSITORY "https://github.com/thelink2012/any.git"
GIT_TAG "8fef1e93710a0edf8d7658999e284a1142c4c020" GIT_TAG "8fef1e93710a0edf8d7658999e284a1142c4c020"
...@@ -17,5 +17,15 @@ ExternalProject_Add( ...@@ -17,5 +17,15 @@ ExternalProject_Add(
TEST_COMMAND "" TEST_COMMAND ""
) )
if (${CMAKE_VERSION} VERSION_LESS "3.3.0")
set(dummyfile ${CMAKE_CURRENT_BINARY_DIR}/lib_any_dummy.c)
file(WRITE ${dummyfile} "const char * dummy_any = \"${dummyfile}\";")
add_library(lib_any STATIC ${dummyfile})
else()
add_library(lib_any INTERFACE)
endif()
add_dependencies(lib_any extern_lib_any)
add_definitions(-DANY_IMPL_ANY_CAST_MOVEABLE) add_definitions(-DANY_IMPL_ANY_CAST_MOVEABLE)
LIST(APPEND external_project_dependencies linb_any) LIST(APPEND external_project_dependencies lib_any)
\ No newline at end of file
...@@ -2,10 +2,10 @@ INCLUDE(ExternalProject) ...@@ -2,10 +2,10 @@ INCLUDE(ExternalProject)
SET(EIGEN_SOURCE_DIR ${THIRD_PARTY_PATH}/eigen3) SET(EIGEN_SOURCE_DIR ${THIRD_PARTY_PATH}/eigen3)
INCLUDE_DIRECTORIES(${EIGEN_SOURCE_DIR}/src/eigen3) INCLUDE_DIRECTORIES(${EIGEN_SOURCE_DIR}/src/extern_eigen3)
ExternalProject_Add( ExternalProject_Add(
eigen3 extern_eigen3
${EXTERNAL_PROJECT_LOG_ARGS} ${EXTERNAL_PROJECT_LOG_ARGS}
# for latest version, please get from official website # for latest version, please get from official website
# URL "https://bitbucket.org/eigen/eigen/get/3.3.4.tar.gz" # URL "https://bitbucket.org/eigen/eigen/get/3.3.4.tar.gz"
...@@ -26,4 +26,14 @@ ExternalProject_Add( ...@@ -26,4 +26,14 @@ ExternalProject_Add(
TEST_COMMAND "" TEST_COMMAND ""
) )
if (${CMAKE_VERSION} VERSION_LESS "3.3.0")
set(dummyfile ${CMAKE_CURRENT_BINARY_DIR}/eigen3_dummy.c)
file(WRITE ${dummyfile} "const char * dummy_eigen3 = \"${dummyfile}\";")
add_library(eigen3 STATIC ${dummyfile})
else()
add_library(eigen3 INTERFACE)
endif()
add_dependencies(eigen3 extern_eigen3)
LIST(APPEND external_project_dependencies eigen3) LIST(APPEND external_project_dependencies eigen3)
...@@ -13,7 +13,7 @@ ...@@ -13,7 +13,7 @@
# limitations under the License. # limitations under the License.
# #
add_subdirectory(pserver/cclient) add_subdirectory(pserver/client/c)
add_subdirectory(cmd/pserver) add_subdirectory(cmd/pserver)
add_subdirectory(cmd/master) add_subdirectory(cmd/master)
add_subdirectory(master/c) add_subdirectory(master/c)
...@@ -15,6 +15,7 @@ import ( ...@@ -15,6 +15,7 @@ import (
func main() { func main() {
port := flag.Int("port", 0, "port of the pserver") port := flag.Int("port", 0, "port of the pserver")
index := flag.Int("index", -1, "index of this pserver, should be larger or equal than 0")
etcdEndpoint := flag.String("etcd-endpoint", "http://127.0.0.1:2379", etcdEndpoint := flag.String("etcd-endpoint", "http://127.0.0.1:2379",
"comma separated endpoint string for pserver to connect to etcd") "comma separated endpoint string for pserver to connect to etcd")
etcdTimeout := flag.Int("etcd-timeout", 5, "timeout for etcd calls") etcdTimeout := flag.Int("etcd-timeout", 5, "timeout for etcd calls")
...@@ -29,12 +30,17 @@ func main() { ...@@ -29,12 +30,17 @@ func main() {
} }
log.SetLevel(level) log.SetLevel(level)
var idx int
if *index >= 0 {
idx = *index
} else {
timeout := time.Second * time.Duration((*etcdTimeout)) timeout := time.Second * time.Duration((*etcdTimeout))
e := pserver.NewEtcdClient(*etcdEndpoint, *numPservers, timeout) e := pserver.NewEtcdClient(*etcdEndpoint, *numPservers, timeout)
idx, err := e.Register() idx, err = e.Register()
if err != nil { if err != nil {
panic(err) panic(err)
} }
}
s, err := pserver.NewService(idx) s, err := pserver.NewService(idx)
if err != nil { if err != nil {
......
...@@ -50,7 +50,7 @@ func NewEtcdClient(endpoints []string, addr string, lockPath, addrPath, statePat ...@@ -50,7 +50,7 @@ func NewEtcdClient(endpoints []string, addr string, lockPath, addrPath, statePat
lock := concurrency.NewMutex(sess, lockPath) lock := concurrency.NewMutex(sess, lockPath)
// It's fine for the lock to get stuck, in this case we have // It's fine for the lock to get stuck, in this case we have
// multiple master servers running (only configured to have // multiple master servers running (only configured to have
// one master running, but split-brain problem may cuase // one master running, but split-brain problem may cause
// multiple master servers running), and the cluster management // multiple master servers running), and the cluster management
// software will kill one of them. // software will kill one of them.
log.Debugf("Trying to acquire lock at %s.", lockPath) log.Debugf("Trying to acquire lock at %s.", lockPath)
...@@ -98,7 +98,7 @@ func (e *EtcdClient) Save(state []byte) error { ...@@ -98,7 +98,7 @@ func (e *EtcdClient) Save(state []byte) error {
// We lost the master lock and can not acquire // We lost the master lock and can not acquire
// it back, it means some other master is // it back, it means some other master is
// already started. We don't want cluster // already started. We don't want cluster
// managment system to kill the master server // management system to kill the master server
// who is holding the lock and running // who is holding the lock and running
// correctly. So the most feasible solution is // correctly. So the most feasible solution is
// to kill current master server. The current // to kill current master server. The current
......
cc_library(paddle_go_optimizer DEPS paddle_optimizer paddle_proto glog gflags protobuf) cc_library(paddle_go_optimizer DEPS paddle_optimizer paddle_proto glog gflags protobuf)
go_library(paddle_pserver_cclient STATIC) go_library(paddle_pserver_cclient STATIC DEPS paddle_go_optimizer)
if(WITH_TESTING) if(WITH_TESTING)
add_subdirectory(test) add_subdirectory(test)
endif() endif()
...@@ -30,15 +30,16 @@ import ( ...@@ -30,15 +30,16 @@ import (
"unsafe" "unsafe"
"github.com/PaddlePaddle/Paddle/go/pserver" "github.com/PaddlePaddle/Paddle/go/pserver"
"github.com/PaddlePaddle/Paddle/go/pserver/client"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
) )
var nullPtr = unsafe.Pointer(uintptr(0)) var nullPtr = unsafe.Pointer(uintptr(0))
var mu sync.Mutex var mu sync.Mutex
var handleMap = make(map[C.paddle_pserver_client]*pserver.Client) var handleMap = make(map[C.paddle_pserver_client]*client.Client)
var curHandle C.paddle_pserver_client var curHandle C.paddle_pserver_client
func add(c *pserver.Client) C.paddle_pserver_client { func add(c *client.Client) C.paddle_pserver_client {
mu.Lock() mu.Lock()
defer mu.Unlock() defer mu.Unlock()
client := curHandle client := curHandle
...@@ -47,13 +48,13 @@ func add(c *pserver.Client) C.paddle_pserver_client { ...@@ -47,13 +48,13 @@ func add(c *pserver.Client) C.paddle_pserver_client {
return client return client
} }
func get(client C.paddle_pserver_client) *pserver.Client { func get(client C.paddle_pserver_client) *client.Client {
mu.Lock() mu.Lock()
defer mu.Unlock() defer mu.Unlock()
return handleMap[client] return handleMap[client]
} }
func remove(client C.paddle_pserver_client) *pserver.Client { func remove(client C.paddle_pserver_client) *client.Client {
mu.Lock() mu.Lock()
defer mu.Unlock() defer mu.Unlock()
h := handleMap[client] h := handleMap[client]
...@@ -80,9 +81,9 @@ func (s selector) Select() bool { ...@@ -80,9 +81,9 @@ func (s selector) Select() bool {
return bool(s) return bool(s)
} }
type lister []pserver.Server type lister []client.Server
func (l lister) List() []pserver.Server { func (l lister) List() []client.Server {
return l return l
} }
...@@ -90,19 +91,22 @@ func (l lister) List() []pserver.Server { ...@@ -90,19 +91,22 @@ func (l lister) List() []pserver.Server {
func paddle_new_pserver_client(addrs *C.char, selected int) C.paddle_pserver_client { func paddle_new_pserver_client(addrs *C.char, selected int) C.paddle_pserver_client {
a := C.GoString(addrs) a := C.GoString(addrs)
as := strings.Split(a, ",") as := strings.Split(a, ",")
servers := make([]pserver.Server, len(as)) servers := make([]client.Server, len(as))
for i := range as { for i := range as {
servers[i].Index = i servers[i].Index = i
servers[i].Addr = as[i] servers[i].Addr = as[i]
} }
c := pserver.NewClient(lister(servers), len(as), selector(selected != 0)) c := client.NewClient(lister(servers), len(as), selector(selected != 0))
return add(c) return add(c)
} }
//export paddle_new_etcd_pserver_client //export paddle_new_etcd_pserver_client
func paddle_new_etcd_pserver_client(etcd_addr *C.char) C.paddle_pserver_client { func paddle_new_etcd_pserver_client(etcd_endpoints *C.char, selected int) C.paddle_pserver_client {
// TODO(helin): fault tolerant pserver client using etcd. // TODO(Longfei: use etcd lock to decide which trainer to initialize the parameters)
panic("not implemented.") addr := C.GoString(etcd_endpoints)
etcd_client := client.NewEtcd(addr)
c := client.NewClient(etcd_client, etcd_client.Desired(), selector(selected != 0))
return add(c)
} }
//export paddle_pserver_client_release //export paddle_pserver_client_release
......
package pserver package client
import ( import (
"errors" "errors"
...@@ -7,6 +7,7 @@ import ( ...@@ -7,6 +7,7 @@ import (
"time" "time"
"github.com/PaddlePaddle/Paddle/go/connection" "github.com/PaddlePaddle/Paddle/go/connection"
"github.com/PaddlePaddle/Paddle/go/pserver"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
) )
...@@ -105,7 +106,7 @@ func (c *Client) BeginInitParams() bool { ...@@ -105,7 +106,7 @@ func (c *Client) BeginInitParams() bool {
} }
// InitParam initializes the parameter on parameter servers. // InitParam initializes the parameter on parameter servers.
func (c *Client) InitParam(paramWithConfigs ParameterWithConfig) error { func (c *Client) InitParam(paramWithConfigs pserver.ParameterWithConfig) error {
return c.pservers[c.partition(paramWithConfigs.Param.Name)].Call("Service.InitParam", paramWithConfigs, nil) return c.pservers[c.partition(paramWithConfigs.Param.Name)].Call("Service.InitParam", paramWithConfigs, nil)
} }
...@@ -123,13 +124,13 @@ func (c *Client) FinishInitParams() error { ...@@ -123,13 +124,13 @@ func (c *Client) FinishInitParams() error {
// SendGrads sends gradients to parameter servers for updating // SendGrads sends gradients to parameter servers for updating
// parameters. // parameters.
func (c *Client) SendGrads(grads []Gradient) error { func (c *Client) SendGrads(grads []pserver.Gradient) error {
if len(grads) == 0 { if len(grads) == 0 {
return errors.New("no gradient received") return errors.New("no gradient received")
} }
errCh := make(chan error, len(grads)) errCh := make(chan error, len(grads))
for _, g := range grads { for _, g := range grads {
go func(g Gradient) { go func(g pserver.Gradient) {
err := c.pservers[c.partition(g.Name)].Call("Service.SendGrad", g, nil) err := c.pservers[c.partition(g.Name)].Call("Service.SendGrad", g, nil)
errCh <- err errCh <- err
}(g) }(g)
...@@ -151,7 +152,7 @@ func (c *Client) SendGrads(grads []Gradient) error { ...@@ -151,7 +152,7 @@ func (c *Client) SendGrads(grads []Gradient) error {
type result struct { type result struct {
idx int idx int
param Parameter param pserver.Parameter
err error err error
} }
...@@ -170,12 +171,12 @@ func (r results) Swap(i int, j int) { ...@@ -170,12 +171,12 @@ func (r results) Swap(i int, j int) {
} }
// GetParams gets parameters from parameter servers. // GetParams gets parameters from parameter servers.
func (c *Client) GetParams(names []string) ([]Parameter, error) { func (c *Client) GetParams(names []string) ([]pserver.Parameter, error) {
rCh := make(chan result, len(names)) rCh := make(chan result, len(names))
for idx, name := range names { for idx, name := range names {
go func(name string, idx int) { go func(name string, idx int) {
var parameter Parameter var parameter pserver.Parameter
err := c.pservers[c.partition(name)].Call("Service.GetParam", name, &parameter) err := c.pservers[c.partition(name)].Call("Service.GetParam", name, &parameter)
rCh <- result{idx: idx, param: parameter, err: err} rCh <- result{idx: idx, param: parameter, err: err}
}(name, idx) }(name, idx)
...@@ -196,7 +197,7 @@ func (c *Client) GetParams(names []string) ([]Parameter, error) { ...@@ -196,7 +197,7 @@ func (c *Client) GetParams(names []string) ([]Parameter, error) {
} }
sort.Sort(rs) sort.Sort(rs)
ps := make([]Parameter, len(rs)) ps := make([]pserver.Parameter, len(rs))
for i := range rs { for i := range rs {
ps[i] = rs[i].param ps[i] = rs[i].param
} }
......
package pserver_test package client_test
import ( import (
"context"
"io/ioutil" "io/ioutil"
"net" "net"
"net/http" "net/http"
...@@ -8,15 +9,25 @@ import ( ...@@ -8,15 +9,25 @@ import (
"strconv" "strconv"
"strings" "strings"
"testing" "testing"
"time"
"github.com/PaddlePaddle/Paddle/go/pserver" "github.com/PaddlePaddle/Paddle/go/pserver"
"github.com/PaddlePaddle/Paddle/go/pserver/client"
"github.com/coreos/etcd/clientv3"
log "github.com/sirupsen/logrus"
) )
const numPserver = 10 const (
numPserver = 10
etcdEndpoints = "127.0.0.1:2379"
timeout = 2 * time.Second
)
var port [numPserver]int var pserverClientPorts [numPserver]int
func init() { // this function init pserver client and return their ports in an array.
func initClient() [numPserver]int {
var ports [numPserver]int
for i := 0; i < numPserver; i++ { for i := 0; i < numPserver; i++ {
l, err := net.Listen("tcp", ":0") l, err := net.Listen("tcp", ":0")
if err != nil { if err != nil {
...@@ -28,7 +39,7 @@ func init() { ...@@ -28,7 +39,7 @@ func init() {
if err != nil { if err != nil {
panic(err) panic(err)
} }
port[i] = p ports[i] = p
go func(l net.Listener) { go func(l net.Listener) {
s, err := pserver.NewService(0) s, err := pserver.NewService(0)
...@@ -49,6 +60,31 @@ func init() { ...@@ -49,6 +60,31 @@ func init() {
} }
}(l) }(l)
} }
return ports
}
func initNativeClient() {
pserverClientPorts = initClient()
}
func initEtcdClient() {
client, err := clientv3.New(clientv3.Config{
Endpoints: []string{etcdEndpoints},
DialTimeout: time.Second * time.Duration(1),
})
if err != nil {
log.Errorf("err %v", err)
}
ctx, cancel := context.WithTimeout(context.Background(), timeout)
client.Delete(ctx, pserver.PsDesired)
client.Delete(ctx, pserver.PsPath)
client.Put(ctx, pserver.PsDesired, strconv.Itoa(numPserver))
ports := initClient()
for i := 0; i < numPserver; i++ {
client.Put(ctx, pserver.PsPath+strconv.Itoa(i), ":"+strconv.Itoa(ports[i]))
}
cancel()
client.Close()
} }
type selector bool type selector bool
...@@ -57,25 +93,20 @@ func (s selector) Select() bool { ...@@ -57,25 +93,20 @@ func (s selector) Select() bool {
return bool(s) return bool(s)
} }
type lister []pserver.Server type lister []client.Server
func (l lister) List() []pserver.Server { func (l lister) List() []client.Server {
return l return l
} }
func TestClientFull(t *testing.T) { func ClientTest(t *testing.T, c *client.Client) {
servers := make([]pserver.Server, numPserver)
for i := 0; i < numPserver; i++ {
servers[i] = pserver.Server{Index: i, Addr: ":" + strconv.Itoa(port[i])}
}
c := pserver.NewClient(lister(servers), len(servers), selector(true))
selected := c.BeginInitParams() selected := c.BeginInitParams()
if !selected { if !selected {
t.Fatal("should be selected.") t.Fatal("should be selected.")
} }
const numParameter = 100 const numParameter = 100
config, err := ioutil.ReadFile("./cclient/test/testdata/optimizer.pb") config, err := ioutil.ReadFile("./c/test/testdata/optimizer.pb")
if err != nil { if err != nil {
t.Fatalf("read optimizer proto failed") t.Fatalf("read optimizer proto failed")
} }
...@@ -129,3 +160,21 @@ func TestClientFull(t *testing.T) { ...@@ -129,3 +160,21 @@ func TestClientFull(t *testing.T) {
} }
} }
} }
func TestNativeClient(t *testing.T) {
initNativeClient()
servers := make([]client.Server, numPserver)
for i := 0; i < numPserver; i++ {
servers[i] = client.Server{Index: i, Addr: ":" + strconv.Itoa(pserverClientPorts[i])}
}
c1 := client.NewClient(lister(servers), len(servers), selector(true))
ClientTest(t, c1)
}
// TODO: tmperary disable etcdClient test for dependency of etcd)
func EtcdClient(t *testing.T) {
initEtcdClient()
etcd_client := client.NewEtcd(etcdEndpoints)
c2 := client.NewClient(etcd_client, etcd_client.Desired(), selector(true))
ClientTest(t, c2)
}
package client
import (
"context"
"strconv"
"strings"
"time"
"github.com/PaddlePaddle/Paddle/go/pserver"
"github.com/coreos/etcd/clientv3"
log "github.com/sirupsen/logrus"
)
const (
DefaultEtcdTimeout time.Duration = 5 * time.Second
)
// EtcdClient is used by pserver client that is a part of trainer process.
// TODO:
// 1. add watcher to watch the change state of pservers)
// 1. add etcd lock)
type EtcdClient struct {
client *clientv3.Client
timeout time.Duration
endpoints []string
}
// Desired read ps desired number from etcd.
func (p *EtcdClient) Desired() int {
var psDesired int
for {
ctx, cancel := context.WithTimeout(context.Background(), p.timeout)
resp, err := p.client.Get(ctx, pserver.PsDesired)
cancel()
if err != nil {
log.Errorf("Get ps dresire number failed! recnnectiong..., %v", err)
time.Sleep(p.timeout)
continue
}
kvs := resp.Kvs
if len(kvs) == 0 {
log.Infoln("Waiting for ps desired registered ...")
time.Sleep(p.timeout)
continue
}
psDesired, err = strconv.Atoi(string(resp.Kvs[0].Value))
if err != nil {
log.Errorf("psDesired %s invalid %v", psDesired, err)
time.Sleep(p.timeout)
continue
}
log.Debugf("Get psDesired number: %d", psDesired)
break
}
return psDesired
}
// List return the pserver list read from etcd.
func (p *EtcdClient) List() []Server {
psDesired := p.Desired()
servers := make([]Server, psDesired)
for {
for i := 0; i < psDesired; i++ {
ctx, cancel := context.WithTimeout(context.Background(), p.timeout)
cancel()
psKey := pserver.PsPath + strconv.Itoa(i)
log.Debugf("checking %s", psKey)
resp, err := p.client.Get(ctx, psKey)
if err != nil {
log.Infof("Get psKey= %s error, %v", psKey, err)
time.Sleep(p.timeout)
continue
}
kvs := resp.Kvs
if len(kvs) == 0 {
log.Infof("Waiting for ps addr registered ...")
time.Sleep(p.timeout)
continue
}
psAddr := string(resp.Kvs[0].Value)
// TODO(Longfei) check the ps address
if psAddr == "" {
log.Infof("Get psKey = %s, psAddr is empty", psKey)
time.Sleep(p.timeout)
continue
}
log.Infof("got value (%s) for key: %s", psAddr, psKey)
servers[i].Index = i
servers[i].Addr = psAddr
}
break
}
return servers
}
// NewEtcd create a etcd client to return the state of pserver on etcd.
func NewEtcd(endpoints string) *EtcdClient {
ep := strings.Split(endpoints, ",")
var cli *clientv3.Client
var err error
for {
cli, err = clientv3.New(clientv3.Config{
Endpoints: ep,
DialTimeout: DefaultEtcdTimeout,
})
if err != nil {
log.Errorf("Init etcd connection failed: %v", err)
time.Sleep(DefaultEtcdTimeout)
continue
}
break
}
log.Infof("Connected to etcd: %s\n", endpoints)
client := &EtcdClient{
client: cli,
timeout: DefaultEtcdTimeout,
endpoints: ep,
}
return client
}
...@@ -13,6 +13,13 @@ import ( ...@@ -13,6 +13,13 @@ import (
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
) )
const (
// PsDesired is etcd path for store desired pserver count
PsDesired = "/ps_desired"
// PsAddr is the base dir for pserver to store their addr
PsPath = "/ps/"
)
// EtcdClient is the etcd client that the pserver uses for fault // EtcdClient is the etcd client that the pserver uses for fault
// tolerance, service registry and coordination. // tolerance, service registry and coordination.
type EtcdClient struct { type EtcdClient struct {
...@@ -68,7 +75,7 @@ func (e *EtcdClient) Register() (int, error) { ...@@ -68,7 +75,7 @@ func (e *EtcdClient) Register() (int, error) {
// it at the same time. // it at the same time.
for { for {
ctx, cancel := context.WithTimeout(context.Background(), time.Second) ctx, cancel := context.WithTimeout(context.Background(), time.Second)
_, err := e.initDesiredPsercers(ctx, e.numPservers) _, err := e.initDesiredPservers(ctx, e.numPservers)
cancel() cancel()
if err != nil { if err != nil {
log.Warn(err) log.Warn(err)
...@@ -120,7 +127,7 @@ func (e *EtcdClient) Register() (int, error) { ...@@ -120,7 +127,7 @@ func (e *EtcdClient) Register() (int, error) {
return pserverIdx, nil return pserverIdx, nil
} }
func (e *EtcdClient) initDesiredPsercers(ctx context.Context, numPservers int) (*clientv3.TxnResponse, error) { func (e *EtcdClient) initDesiredPservers(ctx context.Context, numPservers int) (*clientv3.TxnResponse, error) {
return concurrency.NewSTM(e.etcdClient, func(c concurrency.STM) error { return concurrency.NewSTM(e.etcdClient, func(c concurrency.STM) error {
dsStr := c.Get(PsDesired) dsStr := c.Get(PsDesired)
if dsStr == "" { if dsStr == "" {
...@@ -136,7 +143,7 @@ func (e *EtcdClient) registerPserverEtcd(ctx context.Context) (int, error) { ...@@ -136,7 +143,7 @@ func (e *EtcdClient) registerPserverEtcd(ctx context.Context) (int, error) {
_, err := concurrency.NewSTM(e.etcdClient, func(c concurrency.STM) error { _, err := concurrency.NewSTM(e.etcdClient, func(c concurrency.STM) error {
registered := false registered := false
for i := 0; i < e.desired; i++ { for i := 0; i < e.desired; i++ {
psKey := "/ps/" + strconv.Itoa(i) psKey := PsPath + strconv.Itoa(i)
log.Debugf("checking %s", psKey) log.Debugf("checking %s", psKey)
ps := c.Get(psKey) ps := c.Get(psKey)
log.Debugf("got value (%s) for key: %s", ps, psKey) log.Debugf("got value (%s) for key: %s", ps, psKey)
......
...@@ -2,7 +2,7 @@ package pserver ...@@ -2,7 +2,7 @@ package pserver
// #cgo CFLAGS: -I ../../ // #cgo CFLAGS: -I ../../
// //FIXME: ldflags contain "build" path // //FIXME: ldflags contain "build" path
// #cgo LDFLAGS: ../../build/go/pserver/cclient/libpaddle_go_optimizer.a -lstdc++ // #cgo LDFLAGS: ../../build/go/pserver/client/c/libpaddle_go_optimizer.a -lstdc++
// #include "paddle/optimizer/optimizer.h" // #include "paddle/optimizer/optimizer.h"
// #include <stdlib.h> // #include <stdlib.h>
// #include <string.h> // #include <string.h>
......
...@@ -11,7 +11,7 @@ func TestOptimizerCreateRelease(t *testing.T) { ...@@ -11,7 +11,7 @@ func TestOptimizerCreateRelease(t *testing.T) {
ElementType: Int32, ElementType: Int32,
} }
p.Content = []byte{1, 3} p.Content = []byte{1, 3}
config, err := ioutil.ReadFile("./cclient/test/testdata/optimizer.pb") config, err := ioutil.ReadFile("./client/c/test/testdata/optimizer.pb")
if err != nil { if err != nil {
t.Fatalf("read optimizer proto failed") t.Fatalf("read optimizer proto failed")
} }
......
...@@ -24,9 +24,6 @@ const ( ...@@ -24,9 +24,6 @@ const (
Float64 Float64
) )
// PsDesired is etcd path for store desired pserver count
const PsDesired = "/ps_desired"
// Parameter is a piece of data to sync with the parameter server. // Parameter is a piece of data to sync with the parameter server.
type Parameter struct { type Parameter struct {
Name string Name string
......
...@@ -10,6 +10,10 @@ import ( ...@@ -10,6 +10,10 @@ import (
"github.com/PaddlePaddle/Paddle/go/pserver" "github.com/PaddlePaddle/Paddle/go/pserver"
) )
const (
OptimizerConfig = "./client/c/test/testdata/optimizer.pb"
)
func TestServiceFull(t *testing.T) { func TestServiceFull(t *testing.T) {
s, err := pserver.NewService(0) s, err := pserver.NewService(0)
if err != nil { if err != nil {
...@@ -19,7 +23,7 @@ func TestServiceFull(t *testing.T) { ...@@ -19,7 +23,7 @@ func TestServiceFull(t *testing.T) {
p.Name = "param_a" p.Name = "param_a"
p.Content = []byte{1, 0, 0, 0, 2, 0, 0, 0, 3, 0, 0, 0} p.Content = []byte{1, 0, 0, 0, 2, 0, 0, 0, 3, 0, 0, 0}
p.ElementType = pserver.Int32 p.ElementType = pserver.Int32
config, err := ioutil.ReadFile("./cclient/test/testdata/optimizer.pb") config, err := ioutil.ReadFile(OptimizerConfig)
if err != nil { if err != nil {
t.Fatalf("read optimizer proto failed") t.Fatalf("read optimizer proto failed")
} }
...@@ -149,7 +153,7 @@ func TestBlockUntilInitialized(t *testing.T) { ...@@ -149,7 +153,7 @@ func TestBlockUntilInitialized(t *testing.T) {
p.Name = "param_a" p.Name = "param_a"
p.Content = []byte{1, 0, 0, 0, 2, 0, 0, 0, 3, 0, 0, 0} p.Content = []byte{1, 0, 0, 0, 2, 0, 0, 0, 3, 0, 0, 0}
p.ElementType = pserver.Int32 p.ElementType = pserver.Int32
config, err := ioutil.ReadFile("./cclient/test/testdata/optimizer.pb") config, err := ioutil.ReadFile(OptimizerConfig)
if err != nil { if err != nil {
t.Fatalf("read optimizer proto failed") t.Fatalf("read optimizer proto failed")
} }
......
add_subdirectory(dynload)
nv_test(cuda_test SRCS cuda_test.cu) nv_test(cuda_test SRCS cuda_test.cu)
cc_library(place SRCS place.cc) cc_library(place SRCS place.cc)
......
...@@ -34,6 +34,16 @@ int GetDeviceCount(void) { ...@@ -34,6 +34,16 @@ int GetDeviceCount(void) {
return count; return count;
} }
int GetCurrentDeviceId(void) {
int device_id;
throw_on_error(cudaGetDevice(&device_id), "cudaGetDevice failed");
return device_id;
}
void SetDeviceId(int device_id) {
throw_on_error(cudaSetDevice(device_id), "cudaSetDevice failed");
}
} // namespace platform } // namespace platform
} // namespace paddle } // namespace paddle
......
cc_library(dynamic_loader SRCS dynamic_loader.cc DEPS glog gflags)
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <cublas_v2.h>
#include "paddle/platform/dynamic_loader.h"
namespace paddle {
namespace platform {
namespace dynload {
std::once_flag cublas_dso_flag;
void *cublas_dso_handle = nullptr;
/**
* The following macro definition can generate structs
* (for each function) to dynamic load cublas routine
* via operator overloading.
*
* note: default dynamic linked libs
*/
#ifdef PADDLE_USE_DSO
#define DYNAMIC_LOAD_CUBLAS_WRAP(__name) \
struct DynLoad__##__name { \
template <typename... Args> \
cublasStatus_t operator()(Args... args) { \
typedef cublasStatus_t (*cublasFunc)(Args...); \
std::call_once(cublas_dso_flag, \
paddle::platform::dynload::GetCublasDsoHandle, \
&cublas_dso_handle); \
void *p_##__name = dlsym(cublas_dso_handle, #__name); \
return reinterpret_cast<cublasFunc>(p_##__name)(args...); \
} \
} __name; // struct DynLoad__##__name
#else
#define DYNAMIC_LOAD_CUBLAS_WRAP(__name) \
struct DynLoad__##__name { \
template <typename... Args> \
cublasStatus_t operator()(Args... args) { \
return __name(args...); \
} \
} __name; // struct DynLoad__##__name
#endif
#define DYNAMIC_LOAD_CUBLAS_V2_WRAP(__name) DYNAMIC_LOAD_CUBLAS_WRAP(__name)
// include all needed cublas functions in HPPL
// clang-format off
#define CUBLAS_BLAS_ROUTINE_EACH(__macro) \
__macro(cublasSgemv) \
__macro(cublasDgemv) \
__macro(cublasSgemm) \
__macro(cublasDgemm) \
__macro(cublasSgeam) \
__macro(cublasDgeam) \
DYNAMIC_LOAD_CUBLAS_V2_WRAP(cublasCreate)
DYNAMIC_LOAD_CUBLAS_V2_WRAP(cublasDestroy)
DYNAMIC_LOAD_CUBLAS_V2_WRAP(cublasSetStream)
DYNAMIC_LOAD_CUBLAS_V2_WRAP(cublasSetPointerMode)
DYNAMIC_LOAD_CUBLAS_V2_WRAP(cublasGetPointerMode)
DYNAMIC_LOAD_CUBLAS_WRAP(cublasSgemmBatched)
DYNAMIC_LOAD_CUBLAS_WRAP(cublasDgemmBatched)
DYNAMIC_LOAD_CUBLAS_WRAP(cublasCgemmBatched)
DYNAMIC_LOAD_CUBLAS_WRAP(cublasZgemmBatched)
DYNAMIC_LOAD_CUBLAS_WRAP(cublasSgetrfBatched)
DYNAMIC_LOAD_CUBLAS_WRAP(cublasSgetriBatched)
DYNAMIC_LOAD_CUBLAS_WRAP(cublasDgetrfBatched)
DYNAMIC_LOAD_CUBLAS_WRAP(cublasDgetriBatched)
CUBLAS_BLAS_ROUTINE_EACH(DYNAMIC_LOAD_CUBLAS_V2_WRAP)
#undef DYNAMIC_LOAD_CUBLAS_WRAP
#undef DYNAMIC_LOAD_CUBLAS_V2_WRAP
#undef CUBLAS_BLAS_ROUTINE_EACH
// clang-format on
#ifndef PADDLE_TYPE_DOUBLE
#define CUBLAS_GEAM paddle::platform::dynload::cublasSgeam
#define CUBLAS_GEMV paddle::platform::dynload::cublasSgemv
#define CUBLAS_GEMM paddle::platform::dynload::cublasSgemm
#define CUBLAS_GETRF paddle::platform::dynload::cublasSgetrfBatched
#define CUBLAS_GETRI paddle::platform::dynload::cublasSgetriBatched
#else
#define CUBLAS_GEAM paddle::platform::dynload::cublasDgeam
#define CUBLAS_GEMV paddle::platform::dynload::cublasDgemv
#define CUBLAS_GEMM paddle::platform::dynload::cublasDgemm
#define CUBLAS_GETRF paddle::platform::dynload::cublasDgetrfBatched
#define CUBLAS_GETRI paddle::platform::dynload::cublasDgetriBatched
#endif
} // namespace dynload
} // namespace platform
} // namespace paddle
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <cudnn.h>
#include "paddle/platform/dynamic_loader.h"
namespace paddle {
namespace platform {
namespace dynload {
std::once_flag cudnn_dso_flag;
void* cudnn_dso_handle = nullptr;
#ifdef PADDLE_USE_DSO
#define DYNAMIC_LOAD_CUDNN_WRAP(__name) \
struct DynLoad__##__name { \
template <typename... Args> \
auto operator()(Args... args) -> decltype(__name(args...)) { \
using cudnn_func = decltype(__name(args...)) (*)(Args...); \
std::call_once(cudnn_dso_flag, \
paddle::platform::dynload::GetCudnnDsoHandle, \
&cudnn_dso_handle); \
void* p_##__name = dlsym(cudnn_dso_handle, #__name); \
return reinterpret_cast<cudnn_func>(p_##__name)(args...); \
} \
} __name; /* struct DynLoad__##__name */
#else
#define DYNAMIC_LOAD_CUDNN_WRAP(__name) \
struct DynLoad__##__name { \
template <typename... Args> \
auto operator()(Args... args) -> decltype(__name(args...)) { \
return __name(args...); \
} \
} __name; /* struct DynLoad__##__name */
#endif
/**
* include all needed cudnn functions in HPPL
* different cudnn version has different interfaces
**/
// clang-format off
#define CUDNN_DNN_ROUTINE_EACH(__macro) \
__macro(cudnnSetTensor4dDescriptor) \
__macro(cudnnSetTensor4dDescriptorEx) \
__macro(cudnnGetConvolutionNdForwardOutputDim) \
__macro(cudnnGetConvolutionForwardAlgorithm) \
__macro(cudnnCreateTensorDescriptor) \
__macro(cudnnDestroyTensorDescriptor) \
__macro(cudnnCreateFilterDescriptor) \
__macro(cudnnSetFilter4dDescriptor) \
__macro(cudnnSetPooling2dDescriptor) \
__macro(cudnnDestroyFilterDescriptor) \
__macro(cudnnCreateConvolutionDescriptor) \
__macro(cudnnCreatePoolingDescriptor) \
__macro(cudnnDestroyPoolingDescriptor) \
__macro(cudnnSetConvolution2dDescriptor) \
__macro(cudnnDestroyConvolutionDescriptor) \
__macro(cudnnCreate) \
__macro(cudnnDestroy) \
__macro(cudnnSetStream) \
__macro(cudnnActivationForward) \
__macro(cudnnConvolutionForward) \
__macro(cudnnConvolutionBackwardBias) \
__macro(cudnnGetConvolutionForwardWorkspaceSize) \
__macro(cudnnTransformTensor) \
__macro(cudnnPoolingForward) \
__macro(cudnnPoolingBackward) \
__macro(cudnnSoftmaxBackward) \
__macro(cudnnSoftmaxForward) \
__macro(cudnnGetVersion) \
__macro(cudnnGetErrorString)
CUDNN_DNN_ROUTINE_EACH(DYNAMIC_LOAD_CUDNN_WRAP)
#define CUDNN_DNN_ROUTINE_EACH_R2(__macro) \
__macro(cudnnAddTensor) \
__macro(cudnnConvolutionBackwardData) \
__macro(cudnnConvolutionBackwardFilter)
CUDNN_DNN_ROUTINE_EACH_R2(DYNAMIC_LOAD_CUDNN_WRAP)
// APIs available after R3:
#if CUDNN_VERSION >= 3000
#define CUDNN_DNN_ROUTINE_EACH_AFTER_R3(__macro) \
__macro(cudnnGetConvolutionBackwardFilterWorkspaceSize) \
__macro(cudnnGetConvolutionBackwardDataAlgorithm) \
__macro(cudnnGetConvolutionBackwardFilterAlgorithm) \
__macro(cudnnGetConvolutionBackwardDataWorkspaceSize)
CUDNN_DNN_ROUTINE_EACH_AFTER_R3(DYNAMIC_LOAD_CUDNN_WRAP)
#undef CUDNN_DNN_ROUTINE_EACH_AFTER_R3
#endif
// APIs available after R4:
#if CUDNN_VERSION >= 4007
#define CUDNN_DNN_ROUTINE_EACH_AFTER_R4(__macro) \
__macro(cudnnBatchNormalizationForwardTraining) \
__macro(cudnnBatchNormalizationForwardInference) \
__macro(cudnnBatchNormalizationBackward)
CUDNN_DNN_ROUTINE_EACH_AFTER_R4(DYNAMIC_LOAD_CUDNN_WRAP)
#undef CUDNN_DNN_ROUTINE_EACH_AFTER_R4
#endif
// APIs in R5
#if CUDNN_VERSION >= 5000
#define CUDNN_DNN_ROUTINE_EACH_R5(__macro) \
__macro(cudnnCreateActivationDescriptor) \
__macro(cudnnSetActivationDescriptor) \
__macro(cudnnGetActivationDescriptor) \
__macro(cudnnDestroyActivationDescriptor)
CUDNN_DNN_ROUTINE_EACH_R5(DYNAMIC_LOAD_CUDNN_WRAP)
#undef CUDNN_DNN_ROUTINE_EACH_R5
#endif
#undef CUDNN_DNN_ROUTINE_EACH
// clang-format on
} // namespace dynload
} // namespace platform
} // namespace paddle
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <curand.h>
#include "paddle/platform/dynamic_loader.h"
namespace paddle {
namespace platform {
namespace dynload {
std::once_flag curand_dso_flag;
void *curand_dso_handle = nullptr;
#ifdef PADDLE_USE_DSO
#define DYNAMIC_LOAD_CURAND_WRAP(__name) \
struct DynLoad__##__name { \
template <typename... Args> \
curandStatus_t operator()(Args... args) { \
typedef curandStatus_t (*curandFunc)(Args...); \
std::call_once(curand_dso_flag, \
paddle::platform::dynload::GetCurandDsoHandle, \
&curand_dso_handle); \
void *p_##__name = dlsym(curand_dso_handle, #__name); \
return reinterpret_cast<curandFunc>(p_##__name)(args...); \
} \
} __name; /* struct DynLoad__##__name */
#else
#define DYNAMIC_LOAD_CURAND_WRAP(__name) \
struct DynLoad__##__name { \
template <typename... Args> \
curandStatus_t operator()(Args... args) { \
return __name(args...); \
} \
} __name; /* struct DynLoad__##__name */
#endif
/* include all needed curand functions in HPPL */
// clang-format off
#define CURAND_RAND_ROUTINE_EACH(__macro) \
__macro(curandCreateGenerator) \
__macro(curandSetStream) \
__macro(curandSetPseudoRandomGeneratorSeed)\
__macro(curandGenerateUniform) \
__macro(curandGenerateUniformDouble) \
__macro(curandDestroyGenerator)
// clang-format on
CURAND_RAND_ROUTINE_EACH(DYNAMIC_LOAD_CURAND_WRAP)
#undef CURAND_RAND_ROUTINE_EACH
#undef DYNAMIC_LOAD_CURAND_WRAP
} // namespace dynload
} // namespace platform
} // namespace paddle
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "dynamic_loader.h"
#include <dlfcn.h>
#include <memory>
#include <mutex>
#include <string>
#include "gflags/gflags.h"
#include "glog/logging.h"
DEFINE_string(cudnn_dir, "",
"Specify path for loading libcudnn.so. For instance, "
"/usr/local/cudnn/lib. If empty [default], dlopen "
"will search cudnn from LD_LIBRARY_PATH");
DEFINE_string(cuda_dir, "",
"Specify path for loading cuda library, such as libcublas, "
"libcurand. For instance, /usr/local/cuda/lib64. If default, "
"dlopen will search cuda from LD_LIBRARY_PATH");
DEFINE_string(warpctc_dir, "", "Specify path for loading libwarpctc.so.");
DEFINE_string(lapack_dir, "", "Specify path for loading liblapack.so.");
namespace paddle {
namespace platform {
namespace dynload {
static inline std::string join(const std::string& part1,
const std::string& part2) {
// directory separator
const char sep = '/';
if (!part2.empty() && part2.front() == sep) {
return part2;
}
std::string ret;
ret.reserve(part1.size() + part2.size() + 1);
ret = part1;
if (!ret.empty() && ret.back() != sep) {
ret += sep;
}
ret += part2;
return ret;
}
static inline void GetDsoHandleFromDefaultPath(std::string& dso_path,
void** dso_handle,
int dynload_flags) {
VLOG(3) << "Try to find library: " << dso_path
<< " from default system path.";
// default search from LD_LIBRARY_PATH/DYLD_LIBRARY_PATH
*dso_handle = dlopen(dso_path.c_str(), dynload_flags);
// DYLD_LIBRARY_PATH is disabled after Mac OS 10.11 to
// bring System Integrity Projection (SIP), if dso_handle
// is null, search from default package path in Mac OS.
#if defined(__APPLE__) || defined(__OSX__)
if (nullptr == *dso_handle) {
dso_path = join("/usr/local/cuda/lib/", dso_path);
*dso_handle = dlopen(dso_path.c_str(), dynload_flags);
if (nullptr == *dso_handle) {
if (dso_path == "libcudnn.dylib") {
LOG(FATAL)
<< "Note: [Recommend] copy cudnn into /usr/local/cuda/ \n" // NOLINT
<< "For instance, sudo tar -xzf "
"cudnn-7.5-osx-x64-v5.0-ga.tgz -C " // NOLINT
<< "/usr/local \n sudo chmod a+r "
"/usr/local/cuda/include/cudnn.h " // NOLINT
<< "/usr/local/cuda/lib/libcudnn*";
}
}
}
#endif
}
static inline void GetDsoHandleFromSearchPath(const std::string& search_root,
const std::string& dso_name,
void** dso_handle) {
int dynload_flags = RTLD_LAZY | RTLD_LOCAL;
*dso_handle = nullptr;
std::string dlPath = dso_name;
if (search_root.empty()) {
GetDsoHandleFromDefaultPath(dlPath, dso_handle, dynload_flags);
} else {
// search xxx.so from custom path
dlPath = join(search_root, dso_name);
*dso_handle = dlopen(dlPath.c_str(), dynload_flags);
// if not found, search from default path
if (nullptr == *dso_handle) {
LOG(WARNING) << "Failed to find dynamic library: " << dlPath << " ("
<< dlerror() << ")";
dlPath = dso_name;
GetDsoHandleFromDefaultPath(dlPath, dso_handle, dynload_flags);
}
}
CHECK(nullptr != *dso_handle) << "Failed to find dynamic library: " << dlPath
<< " (" << dlerror() << ") \n"
<< "Please specify its path correctly using "
"following ways: \n"
<< "Method. set environment variable "
"LD_LIBRARY_PATH on Linux or "
<< "DYLD_LIBRARY_PATH on Mac OS. \n"
<< "For instance, issue command: export "
"LD_LIBRARY_PATH=... \n"
<< "Note: After Mac OS 10.11, using the "
"DYLD_LIBRARY_PATH is impossible "
<< "unless System Integrity Protection (SIP) "
"is disabled.";
}
void GetCublasDsoHandle(void** dso_handle) {
#if defined(__APPLE__) || defined(__OSX__)
GetDsoHandleFromSearchPath(FLAGS_cuda_dir, "libcublas.dylib", dso_handle);
#else
GetDsoHandleFromSearchPath(FLAGS_cuda_dir, "libcublas.so", dso_handle);
#endif
}
void GetCudnnDsoHandle(void** dso_handle) {
#if defined(__APPLE__) || defined(__OSX__)
GetDsoHandleFromSearchPath(FLAGS_cudnn_dir, "libcudnn.dylib", dso_handle);
#else
GetDsoHandleFromSearchPath(FLAGS_cudnn_dir, "libcudnn.so", dso_handle);
#endif
}
void GetCurandDsoHandle(void** dso_handle) {
#if defined(__APPLE__) || defined(__OSX__)
GetDsoHandleFromSearchPath(FLAGS_cuda_dir, "libcurand.dylib", dso_handle);
#else
GetDsoHandleFromSearchPath(FLAGS_cuda_dir, "libcurand.so", dso_handle);
#endif
}
void GetWarpCTCDsoHandle(void** dso_handle) {
#if defined(__APPLE__) || defined(__OSX__)
GetDsoHandleFromSearchPath(FLAGS_warpctc_dir, "libwarpctc.dylib", dso_handle);
#else
GetDsoHandleFromSearchPath(FLAGS_warpctc_dir, "libwarpctc.so", dso_handle);
#endif
}
void GetLapackDsoHandle(void** dso_handle) {
#if defined(__APPLE__) || defined(__OSX__)
GetDsoHandleFromSearchPath(FLAGS_lapack_dir, "liblapacke.dylib", dso_handle);
#else
GetDsoHandleFromSearchPath(FLAGS_lapack_dir, "liblapacke.so", dso_handle);
#endif
}
} // namespace dynload
} // namespace platform
} // namespace paddle
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
namespace paddle {
namespace platform {
namespace dynload {
/**
* @brief load the DSO of CUBLAS
*
* @param **dso_handle dso handler
*
*/
void GetCublasDsoHandle(void** dso_handle);
/**
* @brief load the DSO of CUDNN
*
* @param **dso_handle dso handler
*
*/
void GetCudnnDsoHandle(void** dso_handle);
/**
* @brief load the DSO of CURAND
*
* @param **dso_handle dso handler
*
*/
void GetCurandDsoHandle(void** dso_handle);
/**
* @brief load the DSO of warp-ctc
*
* @param **dso_handle dso handler
*
*/
void GetWarpCTCDsoHandle(void** dso_handle);
/**
* @brief load the DSO of lapack
*
* @param **dso_handle dso handler
*
*/
void GetLapackDsoHandle(void** dso_handle);
} // namespace dynload
} // namespace platform
} // namespace paddle
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册