diff --git a/CMakeLists.txt b/CMakeLists.txt index 5349f59805ba35bb03d876e4f7279840c8f8641c..5bedbbefa85a730ff2934a12597988a67e73c1a4 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -113,7 +113,7 @@ include(coveralls) # set code coverage include_directories("${PROJ_ROOT}") include_directories("${PROJ_ROOT}/paddle/cuda/include") include_directories("${CMAKE_CURRENT_BINARY_DIR}/proto") -include_directories("${CMAKE_CURRENT_BINARY_DIR}/go/pserver/cclient") +include_directories("${CMAKE_CURRENT_BINARY_DIR}/go/pserver/client/c") include_directories(${Boost_INCLUDE_DIRS}) set(EXTERNAL_LIBS diff --git a/cmake/external/any.cmake b/cmake/external/any.cmake index 62eea42692b4191e53d0bbb0805786fd15ac7944..45e3764e8482a4cfc8ee72fe4d79f04a3c9b74fa 100644 --- a/cmake/external/any.cmake +++ b/cmake/external/any.cmake @@ -2,10 +2,10 @@ INCLUDE(ExternalProject) SET(ANY_SOURCE_DIR ${THIRD_PARTY_PATH}/any) -INCLUDE_DIRECTORIES(${ANY_SOURCE_DIR}/src/linb_any) +INCLUDE_DIRECTORIES(${ANY_SOURCE_DIR}/src/extern_lib_any) ExternalProject_Add( - linb_any + extern_lib_any ${EXTERNAL_PROJECT_LOG_ARGS} GIT_REPOSITORY "https://github.com/thelink2012/any.git" GIT_TAG "8fef1e93710a0edf8d7658999e284a1142c4c020" @@ -17,5 +17,15 @@ ExternalProject_Add( TEST_COMMAND "" ) +if (${CMAKE_VERSION} VERSION_LESS "3.3.0") + set(dummyfile ${CMAKE_CURRENT_BINARY_DIR}/lib_any_dummy.c) + file(WRITE ${dummyfile} "const char * dummy_any = \"${dummyfile}\";") + add_library(lib_any STATIC ${dummyfile}) +else() + add_library(lib_any INTERFACE) +endif() + +add_dependencies(lib_any extern_lib_any) + add_definitions(-DANY_IMPL_ANY_CAST_MOVEABLE) -LIST(APPEND external_project_dependencies linb_any) \ No newline at end of file +LIST(APPEND external_project_dependencies lib_any) diff --git a/cmake/external/eigen.cmake b/cmake/external/eigen.cmake index 45f44f617dcb46062355df4e35d537086215a46d..3e6cedbb0d718cfd4454f95dedf7e02a24f2981b 100644 --- a/cmake/external/eigen.cmake +++ b/cmake/external/eigen.cmake @@ -2,10 +2,10 @@ INCLUDE(ExternalProject) SET(EIGEN_SOURCE_DIR ${THIRD_PARTY_PATH}/eigen3) -INCLUDE_DIRECTORIES(${EIGEN_SOURCE_DIR}/src/eigen3) +INCLUDE_DIRECTORIES(${EIGEN_SOURCE_DIR}/src/extern_eigen3) ExternalProject_Add( - eigen3 + extern_eigen3 ${EXTERNAL_PROJECT_LOG_ARGS} # for latest version, please get from official website # URL "https://bitbucket.org/eigen/eigen/get/3.3.4.tar.gz" @@ -26,4 +26,14 @@ ExternalProject_Add( TEST_COMMAND "" ) +if (${CMAKE_VERSION} VERSION_LESS "3.3.0") + set(dummyfile ${CMAKE_CURRENT_BINARY_DIR}/eigen3_dummy.c) + file(WRITE ${dummyfile} "const char * dummy_eigen3 = \"${dummyfile}\";") + add_library(eigen3 STATIC ${dummyfile}) +else() + add_library(eigen3 INTERFACE) +endif() + +add_dependencies(eigen3 extern_eigen3) + LIST(APPEND external_project_dependencies eigen3) diff --git a/go/CMakeLists.txt b/go/CMakeLists.txt index 014697d1555859e4d74c55604f8d65d7abe4cbbf..f00c70a0589a4f41a23164a95d505d4310d9157b 100644 --- a/go/CMakeLists.txt +++ b/go/CMakeLists.txt @@ -13,7 +13,7 @@ # limitations under the License. # -add_subdirectory(pserver/cclient) +add_subdirectory(pserver/client/c) add_subdirectory(cmd/pserver) add_subdirectory(cmd/master) add_subdirectory(master/c) diff --git a/go/cmd/pserver/pserver.go b/go/cmd/pserver/pserver.go index 8a42d4f8af1713e246f9efaf5dc7ba878c3b271e..31ef450f032f756fb32a0444a7e94a18ec2918a0 100644 --- a/go/cmd/pserver/pserver.go +++ b/go/cmd/pserver/pserver.go @@ -15,6 +15,7 @@ import ( func main() { port := flag.Int("port", 0, "port of the pserver") + index := flag.Int("index", -1, "index of this pserver, should be larger or equal than 0") etcdEndpoint := flag.String("etcd-endpoint", "http://127.0.0.1:2379", "comma separated endpoint string for pserver to connect to etcd") etcdTimeout := flag.Int("etcd-timeout", 5, "timeout for etcd calls") @@ -29,11 +30,16 @@ func main() { } log.SetLevel(level) - timeout := time.Second * time.Duration((*etcdTimeout)) - e := pserver.NewEtcdClient(*etcdEndpoint, *numPservers, timeout) - idx, err := e.Register() - if err != nil { - panic(err) + var idx int + if *index >= 0 { + idx = *index + } else { + timeout := time.Second * time.Duration((*etcdTimeout)) + e := pserver.NewEtcdClient(*etcdEndpoint, *numPservers, timeout) + idx, err = e.Register() + if err != nil { + panic(err) + } } s, err := pserver.NewService(idx) diff --git a/go/master/etcd_client.go b/go/master/etcd_client.go index e27c014792f31ca27fe1a1636d69acccc4206ea3..04c1394e963d1eb541b80b91407fb55b0d1e1f2a 100644 --- a/go/master/etcd_client.go +++ b/go/master/etcd_client.go @@ -50,7 +50,7 @@ func NewEtcdClient(endpoints []string, addr string, lockPath, addrPath, statePat lock := concurrency.NewMutex(sess, lockPath) // It's fine for the lock to get stuck, in this case we have // multiple master servers running (only configured to have - // one master running, but split-brain problem may cuase + // one master running, but split-brain problem may cause // multiple master servers running), and the cluster management // software will kill one of them. log.Debugf("Trying to acquire lock at %s.", lockPath) @@ -98,7 +98,7 @@ func (e *EtcdClient) Save(state []byte) error { // We lost the master lock and can not acquire // it back, it means some other master is // already started. We don't want cluster - // managment system to kill the master server + // management system to kill the master server // who is holding the lock and running // correctly. So the most feasible solution is // to kill current master server. The current diff --git a/go/pserver/cclient/CMakeLists.txt b/go/pserver/client/c/CMakeLists.txt similarity index 67% rename from go/pserver/cclient/CMakeLists.txt rename to go/pserver/client/c/CMakeLists.txt index 7fe74c62f109b186eb43383b78f30478b9be74c1..a3fcaeef190a178c1eed806f3e03a14ced780eef 100644 --- a/go/pserver/cclient/CMakeLists.txt +++ b/go/pserver/client/c/CMakeLists.txt @@ -1,5 +1,5 @@ cc_library(paddle_go_optimizer DEPS paddle_optimizer paddle_proto glog gflags protobuf) -go_library(paddle_pserver_cclient STATIC) +go_library(paddle_pserver_cclient STATIC DEPS paddle_go_optimizer) if(WITH_TESTING) add_subdirectory(test) endif() diff --git a/go/pserver/cclient/cclient.go b/go/pserver/client/c/cclient.go similarity index 88% rename from go/pserver/cclient/cclient.go rename to go/pserver/client/c/cclient.go index bbaf43d9f1434a278568bc110a709718b9b8c222..7ddaceb7ed33db32e19a191402100a0c0efa241a 100644 --- a/go/pserver/cclient/cclient.go +++ b/go/pserver/client/c/cclient.go @@ -30,15 +30,16 @@ import ( "unsafe" "github.com/PaddlePaddle/Paddle/go/pserver" + "github.com/PaddlePaddle/Paddle/go/pserver/client" log "github.com/sirupsen/logrus" ) var nullPtr = unsafe.Pointer(uintptr(0)) var mu sync.Mutex -var handleMap = make(map[C.paddle_pserver_client]*pserver.Client) +var handleMap = make(map[C.paddle_pserver_client]*client.Client) var curHandle C.paddle_pserver_client -func add(c *pserver.Client) C.paddle_pserver_client { +func add(c *client.Client) C.paddle_pserver_client { mu.Lock() defer mu.Unlock() client := curHandle @@ -47,13 +48,13 @@ func add(c *pserver.Client) C.paddle_pserver_client { return client } -func get(client C.paddle_pserver_client) *pserver.Client { +func get(client C.paddle_pserver_client) *client.Client { mu.Lock() defer mu.Unlock() return handleMap[client] } -func remove(client C.paddle_pserver_client) *pserver.Client { +func remove(client C.paddle_pserver_client) *client.Client { mu.Lock() defer mu.Unlock() h := handleMap[client] @@ -80,9 +81,9 @@ func (s selector) Select() bool { return bool(s) } -type lister []pserver.Server +type lister []client.Server -func (l lister) List() []pserver.Server { +func (l lister) List() []client.Server { return l } @@ -90,19 +91,22 @@ func (l lister) List() []pserver.Server { func paddle_new_pserver_client(addrs *C.char, selected int) C.paddle_pserver_client { a := C.GoString(addrs) as := strings.Split(a, ",") - servers := make([]pserver.Server, len(as)) + servers := make([]client.Server, len(as)) for i := range as { servers[i].Index = i servers[i].Addr = as[i] } - c := pserver.NewClient(lister(servers), len(as), selector(selected != 0)) + c := client.NewClient(lister(servers), len(as), selector(selected != 0)) return add(c) } //export paddle_new_etcd_pserver_client -func paddle_new_etcd_pserver_client(etcd_addr *C.char) C.paddle_pserver_client { - // TODO(helin): fault tolerant pserver client using etcd. - panic("not implemented.") +func paddle_new_etcd_pserver_client(etcd_endpoints *C.char, selected int) C.paddle_pserver_client { + // TODO(Longfei: use etcd lock to decide which trainer to initialize the parameters) + addr := C.GoString(etcd_endpoints) + etcd_client := client.NewEtcd(addr) + c := client.NewClient(etcd_client, etcd_client.Desired(), selector(selected != 0)) + return add(c) } //export paddle_pserver_client_release diff --git a/go/pserver/cclient/test/CMakeLists.txt b/go/pserver/client/c/test/CMakeLists.txt similarity index 100% rename from go/pserver/cclient/test/CMakeLists.txt rename to go/pserver/client/c/test/CMakeLists.txt diff --git a/go/pserver/cclient/test/test_cclient.c b/go/pserver/client/c/test/test_cclient.c similarity index 100% rename from go/pserver/cclient/test/test_cclient.c rename to go/pserver/client/c/test/test_cclient.c diff --git a/go/pserver/cclient/test/test_mnist.py b/go/pserver/client/c/test/test_mnist.py similarity index 100% rename from go/pserver/cclient/test/test_mnist.py rename to go/pserver/client/c/test/test_mnist.py diff --git a/go/pserver/cclient/test/test_train.py b/go/pserver/client/c/test/test_train.py similarity index 100% rename from go/pserver/cclient/test/test_train.py rename to go/pserver/client/c/test/test_train.py diff --git a/go/pserver/cclient/test/testdata/optimizer.pb b/go/pserver/client/c/test/testdata/optimizer.pb similarity index 100% rename from go/pserver/cclient/test/testdata/optimizer.pb rename to go/pserver/client/c/test/testdata/optimizer.pb diff --git a/go/pserver/client.go b/go/pserver/client/client.go similarity index 92% rename from go/pserver/client.go rename to go/pserver/client/client.go index 6938b9d5ce6f6d73c05bd6e3154777023965c319..aa8bfe30c26fcc0875ad479ecd562700ccefa5a3 100644 --- a/go/pserver/client.go +++ b/go/pserver/client/client.go @@ -1,4 +1,4 @@ -package pserver +package client import ( "errors" @@ -7,6 +7,7 @@ import ( "time" "github.com/PaddlePaddle/Paddle/go/connection" + "github.com/PaddlePaddle/Paddle/go/pserver" log "github.com/sirupsen/logrus" ) @@ -105,7 +106,7 @@ func (c *Client) BeginInitParams() bool { } // InitParam initializes the parameter on parameter servers. -func (c *Client) InitParam(paramWithConfigs ParameterWithConfig) error { +func (c *Client) InitParam(paramWithConfigs pserver.ParameterWithConfig) error { return c.pservers[c.partition(paramWithConfigs.Param.Name)].Call("Service.InitParam", paramWithConfigs, nil) } @@ -123,13 +124,13 @@ func (c *Client) FinishInitParams() error { // SendGrads sends gradients to parameter servers for updating // parameters. -func (c *Client) SendGrads(grads []Gradient) error { +func (c *Client) SendGrads(grads []pserver.Gradient) error { if len(grads) == 0 { return errors.New("no gradient received") } errCh := make(chan error, len(grads)) for _, g := range grads { - go func(g Gradient) { + go func(g pserver.Gradient) { err := c.pservers[c.partition(g.Name)].Call("Service.SendGrad", g, nil) errCh <- err }(g) @@ -151,7 +152,7 @@ func (c *Client) SendGrads(grads []Gradient) error { type result struct { idx int - param Parameter + param pserver.Parameter err error } @@ -170,12 +171,12 @@ func (r results) Swap(i int, j int) { } // GetParams gets parameters from parameter servers. -func (c *Client) GetParams(names []string) ([]Parameter, error) { +func (c *Client) GetParams(names []string) ([]pserver.Parameter, error) { rCh := make(chan result, len(names)) for idx, name := range names { go func(name string, idx int) { - var parameter Parameter + var parameter pserver.Parameter err := c.pservers[c.partition(name)].Call("Service.GetParam", name, ¶meter) rCh <- result{idx: idx, param: parameter, err: err} }(name, idx) @@ -196,7 +197,7 @@ func (c *Client) GetParams(names []string) ([]Parameter, error) { } sort.Sort(rs) - ps := make([]Parameter, len(rs)) + ps := make([]pserver.Parameter, len(rs)) for i := range rs { ps[i] = rs[i].param } diff --git a/go/pserver/client_test.go b/go/pserver/client/client_test.go similarity index 54% rename from go/pserver/client_test.go rename to go/pserver/client/client_test.go index b805efa921630098f7ee2fcce8c02722d57d7485..29b400812c9dc3a5f44700eacbf7ba043248f2f2 100644 --- a/go/pserver/client_test.go +++ b/go/pserver/client/client_test.go @@ -1,6 +1,7 @@ -package pserver_test +package client_test import ( + "context" "io/ioutil" "net" "net/http" @@ -8,15 +9,25 @@ import ( "strconv" "strings" "testing" + "time" "github.com/PaddlePaddle/Paddle/go/pserver" + "github.com/PaddlePaddle/Paddle/go/pserver/client" + "github.com/coreos/etcd/clientv3" + log "github.com/sirupsen/logrus" ) -const numPserver = 10 +const ( + numPserver = 10 + etcdEndpoints = "127.0.0.1:2379" + timeout = 2 * time.Second +) -var port [numPserver]int +var pserverClientPorts [numPserver]int -func init() { +// this function init pserver client and return their ports in an array. +func initClient() [numPserver]int { + var ports [numPserver]int for i := 0; i < numPserver; i++ { l, err := net.Listen("tcp", ":0") if err != nil { @@ -28,7 +39,7 @@ func init() { if err != nil { panic(err) } - port[i] = p + ports[i] = p go func(l net.Listener) { s, err := pserver.NewService(0) @@ -49,6 +60,31 @@ func init() { } }(l) } + return ports +} + +func initNativeClient() { + pserverClientPorts = initClient() +} + +func initEtcdClient() { + client, err := clientv3.New(clientv3.Config{ + Endpoints: []string{etcdEndpoints}, + DialTimeout: time.Second * time.Duration(1), + }) + if err != nil { + log.Errorf("err %v", err) + } + ctx, cancel := context.WithTimeout(context.Background(), timeout) + client.Delete(ctx, pserver.PsDesired) + client.Delete(ctx, pserver.PsPath) + client.Put(ctx, pserver.PsDesired, strconv.Itoa(numPserver)) + ports := initClient() + for i := 0; i < numPserver; i++ { + client.Put(ctx, pserver.PsPath+strconv.Itoa(i), ":"+strconv.Itoa(ports[i])) + } + cancel() + client.Close() } type selector bool @@ -57,25 +93,20 @@ func (s selector) Select() bool { return bool(s) } -type lister []pserver.Server +type lister []client.Server -func (l lister) List() []pserver.Server { +func (l lister) List() []client.Server { return l } -func TestClientFull(t *testing.T) { - servers := make([]pserver.Server, numPserver) - for i := 0; i < numPserver; i++ { - servers[i] = pserver.Server{Index: i, Addr: ":" + strconv.Itoa(port[i])} - } - c := pserver.NewClient(lister(servers), len(servers), selector(true)) +func ClientTest(t *testing.T, c *client.Client) { selected := c.BeginInitParams() if !selected { t.Fatal("should be selected.") } const numParameter = 100 - config, err := ioutil.ReadFile("./cclient/test/testdata/optimizer.pb") + config, err := ioutil.ReadFile("./c/test/testdata/optimizer.pb") if err != nil { t.Fatalf("read optimizer proto failed") } @@ -129,3 +160,21 @@ func TestClientFull(t *testing.T) { } } } + +func TestNativeClient(t *testing.T) { + initNativeClient() + servers := make([]client.Server, numPserver) + for i := 0; i < numPserver; i++ { + servers[i] = client.Server{Index: i, Addr: ":" + strconv.Itoa(pserverClientPorts[i])} + } + c1 := client.NewClient(lister(servers), len(servers), selector(true)) + ClientTest(t, c1) +} + +// TODO: tmperary disable etcdClient test for dependency of etcd) +func EtcdClient(t *testing.T) { + initEtcdClient() + etcd_client := client.NewEtcd(etcdEndpoints) + c2 := client.NewClient(etcd_client, etcd_client.Desired(), selector(true)) + ClientTest(t, c2) +} diff --git a/go/pserver/client/etcd_client.go b/go/pserver/client/etcd_client.go new file mode 100644 index 0000000000000000000000000000000000000000..1fd3479aa88ccbbe7c5067da1e9886b65352e847 --- /dev/null +++ b/go/pserver/client/etcd_client.go @@ -0,0 +1,125 @@ +package client + +import ( + "context" + "strconv" + "strings" + "time" + + "github.com/PaddlePaddle/Paddle/go/pserver" + "github.com/coreos/etcd/clientv3" + log "github.com/sirupsen/logrus" +) + +const ( + DefaultEtcdTimeout time.Duration = 5 * time.Second +) + +// EtcdClient is used by pserver client that is a part of trainer process. +// TODO: +// 1. add watcher to watch the change state of pservers) +// 1. add etcd lock) +type EtcdClient struct { + client *clientv3.Client + timeout time.Duration + endpoints []string +} + +// Desired read ps desired number from etcd. +func (p *EtcdClient) Desired() int { + var psDesired int + for { + ctx, cancel := context.WithTimeout(context.Background(), p.timeout) + resp, err := p.client.Get(ctx, pserver.PsDesired) + cancel() + if err != nil { + log.Errorf("Get ps dresire number failed! recnnectiong..., %v", err) + time.Sleep(p.timeout) + continue + } + + kvs := resp.Kvs + if len(kvs) == 0 { + log.Infoln("Waiting for ps desired registered ...") + time.Sleep(p.timeout) + continue + } + + psDesired, err = strconv.Atoi(string(resp.Kvs[0].Value)) + if err != nil { + log.Errorf("psDesired %s invalid %v", psDesired, err) + time.Sleep(p.timeout) + continue + } + + log.Debugf("Get psDesired number: %d", psDesired) + break + } + return psDesired +} + +// List return the pserver list read from etcd. +func (p *EtcdClient) List() []Server { + psDesired := p.Desired() + + servers := make([]Server, psDesired) + for { + for i := 0; i < psDesired; i++ { + ctx, cancel := context.WithTimeout(context.Background(), p.timeout) + cancel() + psKey := pserver.PsPath + strconv.Itoa(i) + log.Debugf("checking %s", psKey) + resp, err := p.client.Get(ctx, psKey) + if err != nil { + log.Infof("Get psKey= %s error, %v", psKey, err) + time.Sleep(p.timeout) + continue + } + kvs := resp.Kvs + if len(kvs) == 0 { + log.Infof("Waiting for ps addr registered ...") + time.Sleep(p.timeout) + continue + } + + psAddr := string(resp.Kvs[0].Value) + // TODO(Longfei) check the ps address + if psAddr == "" { + log.Infof("Get psKey = %s, psAddr is empty", psKey) + time.Sleep(p.timeout) + continue + } + log.Infof("got value (%s) for key: %s", psAddr, psKey) + servers[i].Index = i + servers[i].Addr = psAddr + } + break + } + return servers +} + +// NewEtcd create a etcd client to return the state of pserver on etcd. +func NewEtcd(endpoints string) *EtcdClient { + ep := strings.Split(endpoints, ",") + var cli *clientv3.Client + var err error + for { + cli, err = clientv3.New(clientv3.Config{ + Endpoints: ep, + DialTimeout: DefaultEtcdTimeout, + }) + if err != nil { + log.Errorf("Init etcd connection failed: %v", err) + time.Sleep(DefaultEtcdTimeout) + continue + } + break + } + log.Infof("Connected to etcd: %s\n", endpoints) + client := &EtcdClient{ + client: cli, + timeout: DefaultEtcdTimeout, + endpoints: ep, + } + return client +} diff --git a/go/pserver/etcd_client.go b/go/pserver/etcd_client.go index 4d88243edd4aa817ddc263ba316a3f6be9e1e67f..37b8d522c1bd07acb41b9515a6d9bc15eae9aa32 100644 --- a/go/pserver/etcd_client.go +++ b/go/pserver/etcd_client.go @@ -13,6 +13,13 @@ import ( log "github.com/sirupsen/logrus" ) +const ( + // PsDesired is etcd path for store desired pserver count + PsDesired = "/ps_desired" + // PsAddr is the base dir for pserver to store their addr + PsPath = "/ps/" +) + // EtcdClient is the etcd client that the pserver uses for fault // tolerance, service registry and coordination. type EtcdClient struct { @@ -68,7 +75,7 @@ func (e *EtcdClient) Register() (int, error) { // it at the same time. for { ctx, cancel := context.WithTimeout(context.Background(), time.Second) - _, err := e.initDesiredPsercers(ctx, e.numPservers) + _, err := e.initDesiredPservers(ctx, e.numPservers) cancel() if err != nil { log.Warn(err) @@ -120,7 +127,7 @@ func (e *EtcdClient) Register() (int, error) { return pserverIdx, nil } -func (e *EtcdClient) initDesiredPsercers(ctx context.Context, numPservers int) (*clientv3.TxnResponse, error) { +func (e *EtcdClient) initDesiredPservers(ctx context.Context, numPservers int) (*clientv3.TxnResponse, error) { return concurrency.NewSTM(e.etcdClient, func(c concurrency.STM) error { dsStr := c.Get(PsDesired) if dsStr == "" { @@ -136,7 +143,7 @@ func (e *EtcdClient) registerPserverEtcd(ctx context.Context) (int, error) { _, err := concurrency.NewSTM(e.etcdClient, func(c concurrency.STM) error { registered := false for i := 0; i < e.desired; i++ { - psKey := "/ps/" + strconv.Itoa(i) + psKey := PsPath + strconv.Itoa(i) log.Debugf("checking %s", psKey) ps := c.Get(psKey) log.Debugf("got value (%s) for key: %s", ps, psKey) diff --git a/go/pserver/optimizer.go b/go/pserver/optimizer.go index b4a040f46bff5c25b193d41e5d36b59762891574..bca3718af32b35416e94606816569dd9e76eccb6 100644 --- a/go/pserver/optimizer.go +++ b/go/pserver/optimizer.go @@ -2,7 +2,7 @@ package pserver // #cgo CFLAGS: -I ../../ // //FIXME: ldflags contain "build" path -// #cgo LDFLAGS: ../../build/go/pserver/cclient/libpaddle_go_optimizer.a -lstdc++ +// #cgo LDFLAGS: ../../build/go/pserver/client/c/libpaddle_go_optimizer.a -lstdc++ // #include "paddle/optimizer/optimizer.h" // #include // #include diff --git a/go/pserver/optimizer_test.go b/go/pserver/optimizer_test.go index b99b5a5f0bfed4d780ea19b75ddaa4129be77bd5..0b2f4cfa41a630645c128ac13826de9d8b1d521b 100644 --- a/go/pserver/optimizer_test.go +++ b/go/pserver/optimizer_test.go @@ -11,7 +11,7 @@ func TestOptimizerCreateRelease(t *testing.T) { ElementType: Int32, } p.Content = []byte{1, 3} - config, err := ioutil.ReadFile("./cclient/test/testdata/optimizer.pb") + config, err := ioutil.ReadFile("./client/c/test/testdata/optimizer.pb") if err != nil { t.Fatalf("read optimizer proto failed") } diff --git a/go/pserver/service.go b/go/pserver/service.go index e15a4e5a58a3bb1a154157b1212d141478e96231..7711dc027e173e862f9b33e7a57224097026872c 100644 --- a/go/pserver/service.go +++ b/go/pserver/service.go @@ -24,9 +24,6 @@ const ( Float64 ) -// PsDesired is etcd path for store desired pserver count -const PsDesired = "/ps_desired" - // Parameter is a piece of data to sync with the parameter server. type Parameter struct { Name string diff --git a/go/pserver/service_test.go b/go/pserver/service_test.go index 30e3ac8ae1ccd1d6c9e71ed113cc2543e8c1e224..b6d20d2c8b7ba0ccd7ab46669a597a21dc11c381 100644 --- a/go/pserver/service_test.go +++ b/go/pserver/service_test.go @@ -10,6 +10,10 @@ import ( "github.com/PaddlePaddle/Paddle/go/pserver" ) +const ( + OptimizerConfig = "./client/c/test/testdata/optimizer.pb" +) + func TestServiceFull(t *testing.T) { s, err := pserver.NewService(0) if err != nil { @@ -19,7 +23,7 @@ func TestServiceFull(t *testing.T) { p.Name = "param_a" p.Content = []byte{1, 0, 0, 0, 2, 0, 0, 0, 3, 0, 0, 0} p.ElementType = pserver.Int32 - config, err := ioutil.ReadFile("./cclient/test/testdata/optimizer.pb") + config, err := ioutil.ReadFile(OptimizerConfig) if err != nil { t.Fatalf("read optimizer proto failed") } @@ -149,7 +153,7 @@ func TestBlockUntilInitialized(t *testing.T) { p.Name = "param_a" p.Content = []byte{1, 0, 0, 0, 2, 0, 0, 0, 3, 0, 0, 0} p.ElementType = pserver.Int32 - config, err := ioutil.ReadFile("./cclient/test/testdata/optimizer.pb") + config, err := ioutil.ReadFile(OptimizerConfig) if err != nil { t.Fatalf("read optimizer proto failed") } diff --git a/paddle/platform/CMakeLists.txt b/paddle/platform/CMakeLists.txt index c7d7b14518ebb8415014a78fc1a3bafa8c386191..cc6b52e9271ff00e91f2ca172815c543eb99261d 100644 --- a/paddle/platform/CMakeLists.txt +++ b/paddle/platform/CMakeLists.txt @@ -1,3 +1,5 @@ +add_subdirectory(dynload) + nv_test(cuda_test SRCS cuda_test.cu) cc_library(place SRCS place.cc) diff --git a/paddle/platform/cuda.h b/paddle/platform/cuda.h index 8fe891f9ce6c3add1df48a8b1f79fd811c7a4362..5ed36c0f02549e29fc680eba097c9d851fa346e5 100644 --- a/paddle/platform/cuda.h +++ b/paddle/platform/cuda.h @@ -34,6 +34,16 @@ int GetDeviceCount(void) { return count; } +int GetCurrentDeviceId(void) { + int device_id; + throw_on_error(cudaGetDevice(&device_id), "cudaGetDevice failed"); + return device_id; +} + +void SetDeviceId(int device_id) { + throw_on_error(cudaSetDevice(device_id), "cudaSetDevice failed"); +} + } // namespace platform } // namespace paddle diff --git a/paddle/platform/dynload/CMakeLists.txt b/paddle/platform/dynload/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..9f829b70128655f018c59db32b95d3f1789da5fc --- /dev/null +++ b/paddle/platform/dynload/CMakeLists.txt @@ -0,0 +1 @@ +cc_library(dynamic_loader SRCS dynamic_loader.cc DEPS glog gflags) diff --git a/paddle/platform/dynload/cublas.h b/paddle/platform/dynload/cublas.h new file mode 100644 index 0000000000000000000000000000000000000000..c9150ac5738f3292f530f053d7c0bb8a1203f13d --- /dev/null +++ b/paddle/platform/dynload/cublas.h @@ -0,0 +1,104 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include +#include "paddle/platform/dynamic_loader.h" + +namespace paddle { +namespace platform { +namespace dynload { + +std::once_flag cublas_dso_flag; +void *cublas_dso_handle = nullptr; + +/** + * The following macro definition can generate structs + * (for each function) to dynamic load cublas routine + * via operator overloading. + * + * note: default dynamic linked libs + */ +#ifdef PADDLE_USE_DSO +#define DYNAMIC_LOAD_CUBLAS_WRAP(__name) \ + struct DynLoad__##__name { \ + template \ + cublasStatus_t operator()(Args... args) { \ + typedef cublasStatus_t (*cublasFunc)(Args...); \ + std::call_once(cublas_dso_flag, \ + paddle::platform::dynload::GetCublasDsoHandle, \ + &cublas_dso_handle); \ + void *p_##__name = dlsym(cublas_dso_handle, #__name); \ + return reinterpret_cast(p_##__name)(args...); \ + } \ + } __name; // struct DynLoad__##__name +#else +#define DYNAMIC_LOAD_CUBLAS_WRAP(__name) \ + struct DynLoad__##__name { \ + template \ + cublasStatus_t operator()(Args... args) { \ + return __name(args...); \ + } \ + } __name; // struct DynLoad__##__name +#endif + +#define DYNAMIC_LOAD_CUBLAS_V2_WRAP(__name) DYNAMIC_LOAD_CUBLAS_WRAP(__name) + +// include all needed cublas functions in HPPL +// clang-format off +#define CUBLAS_BLAS_ROUTINE_EACH(__macro) \ + __macro(cublasSgemv) \ + __macro(cublasDgemv) \ + __macro(cublasSgemm) \ + __macro(cublasDgemm) \ + __macro(cublasSgeam) \ + __macro(cublasDgeam) \ + +DYNAMIC_LOAD_CUBLAS_V2_WRAP(cublasCreate) +DYNAMIC_LOAD_CUBLAS_V2_WRAP(cublasDestroy) +DYNAMIC_LOAD_CUBLAS_V2_WRAP(cublasSetStream) +DYNAMIC_LOAD_CUBLAS_V2_WRAP(cublasSetPointerMode) +DYNAMIC_LOAD_CUBLAS_V2_WRAP(cublasGetPointerMode) +DYNAMIC_LOAD_CUBLAS_WRAP(cublasSgemmBatched) +DYNAMIC_LOAD_CUBLAS_WRAP(cublasDgemmBatched) +DYNAMIC_LOAD_CUBLAS_WRAP(cublasCgemmBatched) +DYNAMIC_LOAD_CUBLAS_WRAP(cublasZgemmBatched) +DYNAMIC_LOAD_CUBLAS_WRAP(cublasSgetrfBatched) +DYNAMIC_LOAD_CUBLAS_WRAP(cublasSgetriBatched) +DYNAMIC_LOAD_CUBLAS_WRAP(cublasDgetrfBatched) +DYNAMIC_LOAD_CUBLAS_WRAP(cublasDgetriBatched) +CUBLAS_BLAS_ROUTINE_EACH(DYNAMIC_LOAD_CUBLAS_V2_WRAP) + +#undef DYNAMIC_LOAD_CUBLAS_WRAP +#undef DYNAMIC_LOAD_CUBLAS_V2_WRAP +#undef CUBLAS_BLAS_ROUTINE_EACH + +// clang-format on +#ifndef PADDLE_TYPE_DOUBLE +#define CUBLAS_GEAM paddle::platform::dynload::cublasSgeam +#define CUBLAS_GEMV paddle::platform::dynload::cublasSgemv +#define CUBLAS_GEMM paddle::platform::dynload::cublasSgemm +#define CUBLAS_GETRF paddle::platform::dynload::cublasSgetrfBatched +#define CUBLAS_GETRI paddle::platform::dynload::cublasSgetriBatched +#else +#define CUBLAS_GEAM paddle::platform::dynload::cublasDgeam +#define CUBLAS_GEMV paddle::platform::dynload::cublasDgemv +#define CUBLAS_GEMM paddle::platform::dynload::cublasDgemm +#define CUBLAS_GETRF paddle::platform::dynload::cublasDgetrfBatched +#define CUBLAS_GETRI paddle::platform::dynload::cublasDgetriBatched +#endif +} // namespace dynload +} // namespace platform +} // namespace paddle diff --git a/paddle/platform/dynload/cudnn.h b/paddle/platform/dynload/cudnn.h new file mode 100644 index 0000000000000000000000000000000000000000..c03424b375e5dc80610eb4ca8146069b7300e016 --- /dev/null +++ b/paddle/platform/dynload/cudnn.h @@ -0,0 +1,134 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include +#include "paddle/platform/dynamic_loader.h" + +namespace paddle { +namespace platform { +namespace dynload { + +std::once_flag cudnn_dso_flag; +void* cudnn_dso_handle = nullptr; + +#ifdef PADDLE_USE_DSO + +#define DYNAMIC_LOAD_CUDNN_WRAP(__name) \ + struct DynLoad__##__name { \ + template \ + auto operator()(Args... args) -> decltype(__name(args...)) { \ + using cudnn_func = decltype(__name(args...)) (*)(Args...); \ + std::call_once(cudnn_dso_flag, \ + paddle::platform::dynload::GetCudnnDsoHandle, \ + &cudnn_dso_handle); \ + void* p_##__name = dlsym(cudnn_dso_handle, #__name); \ + return reinterpret_cast(p_##__name)(args...); \ + } \ + } __name; /* struct DynLoad__##__name */ + +#else + +#define DYNAMIC_LOAD_CUDNN_WRAP(__name) \ + struct DynLoad__##__name { \ + template \ + auto operator()(Args... args) -> decltype(__name(args...)) { \ + return __name(args...); \ + } \ + } __name; /* struct DynLoad__##__name */ + +#endif + +/** + * include all needed cudnn functions in HPPL + * different cudnn version has different interfaces + **/ +// clang-format off +#define CUDNN_DNN_ROUTINE_EACH(__macro) \ + __macro(cudnnSetTensor4dDescriptor) \ + __macro(cudnnSetTensor4dDescriptorEx) \ + __macro(cudnnGetConvolutionNdForwardOutputDim) \ + __macro(cudnnGetConvolutionForwardAlgorithm) \ + __macro(cudnnCreateTensorDescriptor) \ + __macro(cudnnDestroyTensorDescriptor) \ + __macro(cudnnCreateFilterDescriptor) \ + __macro(cudnnSetFilter4dDescriptor) \ + __macro(cudnnSetPooling2dDescriptor) \ + __macro(cudnnDestroyFilterDescriptor) \ + __macro(cudnnCreateConvolutionDescriptor) \ + __macro(cudnnCreatePoolingDescriptor) \ + __macro(cudnnDestroyPoolingDescriptor) \ + __macro(cudnnSetConvolution2dDescriptor) \ + __macro(cudnnDestroyConvolutionDescriptor) \ + __macro(cudnnCreate) \ + __macro(cudnnDestroy) \ + __macro(cudnnSetStream) \ + __macro(cudnnActivationForward) \ + __macro(cudnnConvolutionForward) \ + __macro(cudnnConvolutionBackwardBias) \ + __macro(cudnnGetConvolutionForwardWorkspaceSize) \ + __macro(cudnnTransformTensor) \ + __macro(cudnnPoolingForward) \ + __macro(cudnnPoolingBackward) \ + __macro(cudnnSoftmaxBackward) \ + __macro(cudnnSoftmaxForward) \ + __macro(cudnnGetVersion) \ + __macro(cudnnGetErrorString) +CUDNN_DNN_ROUTINE_EACH(DYNAMIC_LOAD_CUDNN_WRAP) + +#define CUDNN_DNN_ROUTINE_EACH_R2(__macro) \ + __macro(cudnnAddTensor) \ + __macro(cudnnConvolutionBackwardData) \ + __macro(cudnnConvolutionBackwardFilter) +CUDNN_DNN_ROUTINE_EACH_R2(DYNAMIC_LOAD_CUDNN_WRAP) + +// APIs available after R3: +#if CUDNN_VERSION >= 3000 +#define CUDNN_DNN_ROUTINE_EACH_AFTER_R3(__macro) \ + __macro(cudnnGetConvolutionBackwardFilterWorkspaceSize) \ + __macro(cudnnGetConvolutionBackwardDataAlgorithm) \ + __macro(cudnnGetConvolutionBackwardFilterAlgorithm) \ + __macro(cudnnGetConvolutionBackwardDataWorkspaceSize) +CUDNN_DNN_ROUTINE_EACH_AFTER_R3(DYNAMIC_LOAD_CUDNN_WRAP) +#undef CUDNN_DNN_ROUTINE_EACH_AFTER_R3 +#endif + + +// APIs available after R4: +#if CUDNN_VERSION >= 4007 +#define CUDNN_DNN_ROUTINE_EACH_AFTER_R4(__macro) \ + __macro(cudnnBatchNormalizationForwardTraining) \ + __macro(cudnnBatchNormalizationForwardInference) \ + __macro(cudnnBatchNormalizationBackward) +CUDNN_DNN_ROUTINE_EACH_AFTER_R4(DYNAMIC_LOAD_CUDNN_WRAP) +#undef CUDNN_DNN_ROUTINE_EACH_AFTER_R4 +#endif + +// APIs in R5 +#if CUDNN_VERSION >= 5000 +#define CUDNN_DNN_ROUTINE_EACH_R5(__macro) \ + __macro(cudnnCreateActivationDescriptor) \ + __macro(cudnnSetActivationDescriptor) \ + __macro(cudnnGetActivationDescriptor) \ + __macro(cudnnDestroyActivationDescriptor) +CUDNN_DNN_ROUTINE_EACH_R5(DYNAMIC_LOAD_CUDNN_WRAP) +#undef CUDNN_DNN_ROUTINE_EACH_R5 +#endif + +#undef CUDNN_DNN_ROUTINE_EACH +// clang-format on +} // namespace dynload +} // namespace platform +} // namespace paddle diff --git a/paddle/platform/dynload/curand.h b/paddle/platform/dynload/curand.h new file mode 100644 index 0000000000000000000000000000000000000000..1ef7a8c833d7488ec18f7c8df55b16006e94cc72 --- /dev/null +++ b/paddle/platform/dynload/curand.h @@ -0,0 +1,65 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include +#include "paddle/platform/dynamic_loader.h" + +namespace paddle { +namespace platform { +namespace dynload { +std::once_flag curand_dso_flag; +void *curand_dso_handle = nullptr; +#ifdef PADDLE_USE_DSO +#define DYNAMIC_LOAD_CURAND_WRAP(__name) \ + struct DynLoad__##__name { \ + template \ + curandStatus_t operator()(Args... args) { \ + typedef curandStatus_t (*curandFunc)(Args...); \ + std::call_once(curand_dso_flag, \ + paddle::platform::dynload::GetCurandDsoHandle, \ + &curand_dso_handle); \ + void *p_##__name = dlsym(curand_dso_handle, #__name); \ + return reinterpret_cast(p_##__name)(args...); \ + } \ + } __name; /* struct DynLoad__##__name */ +#else +#define DYNAMIC_LOAD_CURAND_WRAP(__name) \ + struct DynLoad__##__name { \ + template \ + curandStatus_t operator()(Args... args) { \ + return __name(args...); \ + } \ + } __name; /* struct DynLoad__##__name */ +#endif + +/* include all needed curand functions in HPPL */ +// clang-format off +#define CURAND_RAND_ROUTINE_EACH(__macro) \ + __macro(curandCreateGenerator) \ + __macro(curandSetStream) \ + __macro(curandSetPseudoRandomGeneratorSeed)\ + __macro(curandGenerateUniform) \ + __macro(curandGenerateUniformDouble) \ + __macro(curandDestroyGenerator) +// clang-format on + +CURAND_RAND_ROUTINE_EACH(DYNAMIC_LOAD_CURAND_WRAP) + +#undef CURAND_RAND_ROUTINE_EACH +#undef DYNAMIC_LOAD_CURAND_WRAP +} // namespace dynload +} // namespace platform +} // namespace paddle diff --git a/paddle/platform/dynload/dynamic_loader.cc b/paddle/platform/dynload/dynamic_loader.cc new file mode 100644 index 0000000000000000000000000000000000000000..8ef67bad8c63cd1dfff967dbf1627ed84645e0d0 --- /dev/null +++ b/paddle/platform/dynload/dynamic_loader.cc @@ -0,0 +1,169 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "dynamic_loader.h" +#include +#include +#include +#include +#include "gflags/gflags.h" +#include "glog/logging.h" + +DEFINE_string(cudnn_dir, "", + "Specify path for loading libcudnn.so. For instance, " + "/usr/local/cudnn/lib. If empty [default], dlopen " + "will search cudnn from LD_LIBRARY_PATH"); + +DEFINE_string(cuda_dir, "", + "Specify path for loading cuda library, such as libcublas, " + "libcurand. For instance, /usr/local/cuda/lib64. If default, " + "dlopen will search cuda from LD_LIBRARY_PATH"); + +DEFINE_string(warpctc_dir, "", "Specify path for loading libwarpctc.so."); + +DEFINE_string(lapack_dir, "", "Specify path for loading liblapack.so."); + +namespace paddle { +namespace platform { +namespace dynload { + +static inline std::string join(const std::string& part1, + const std::string& part2) { + // directory separator + const char sep = '/'; + if (!part2.empty() && part2.front() == sep) { + return part2; + } + std::string ret; + ret.reserve(part1.size() + part2.size() + 1); + ret = part1; + if (!ret.empty() && ret.back() != sep) { + ret += sep; + } + ret += part2; + return ret; +} + +static inline void GetDsoHandleFromDefaultPath(std::string& dso_path, + void** dso_handle, + int dynload_flags) { + VLOG(3) << "Try to find library: " << dso_path + << " from default system path."; + // default search from LD_LIBRARY_PATH/DYLD_LIBRARY_PATH + *dso_handle = dlopen(dso_path.c_str(), dynload_flags); + +// DYLD_LIBRARY_PATH is disabled after Mac OS 10.11 to +// bring System Integrity Projection (SIP), if dso_handle +// is null, search from default package path in Mac OS. +#if defined(__APPLE__) || defined(__OSX__) + if (nullptr == *dso_handle) { + dso_path = join("/usr/local/cuda/lib/", dso_path); + *dso_handle = dlopen(dso_path.c_str(), dynload_flags); + if (nullptr == *dso_handle) { + if (dso_path == "libcudnn.dylib") { + LOG(FATAL) + << "Note: [Recommend] copy cudnn into /usr/local/cuda/ \n" // NOLINT + << "For instance, sudo tar -xzf " + "cudnn-7.5-osx-x64-v5.0-ga.tgz -C " // NOLINT + << "/usr/local \n sudo chmod a+r " + "/usr/local/cuda/include/cudnn.h " // NOLINT + << "/usr/local/cuda/lib/libcudnn*"; + } + } + } +#endif +} + +static inline void GetDsoHandleFromSearchPath(const std::string& search_root, + const std::string& dso_name, + void** dso_handle) { + int dynload_flags = RTLD_LAZY | RTLD_LOCAL; + *dso_handle = nullptr; + + std::string dlPath = dso_name; + if (search_root.empty()) { + GetDsoHandleFromDefaultPath(dlPath, dso_handle, dynload_flags); + } else { + // search xxx.so from custom path + dlPath = join(search_root, dso_name); + *dso_handle = dlopen(dlPath.c_str(), dynload_flags); + // if not found, search from default path + if (nullptr == *dso_handle) { + LOG(WARNING) << "Failed to find dynamic library: " << dlPath << " (" + << dlerror() << ")"; + dlPath = dso_name; + GetDsoHandleFromDefaultPath(dlPath, dso_handle, dynload_flags); + } + } + + CHECK(nullptr != *dso_handle) << "Failed to find dynamic library: " << dlPath + << " (" << dlerror() << ") \n" + << "Please specify its path correctly using " + "following ways: \n" + + << "Method. set environment variable " + "LD_LIBRARY_PATH on Linux or " + << "DYLD_LIBRARY_PATH on Mac OS. \n" + << "For instance, issue command: export " + "LD_LIBRARY_PATH=... \n" + + << "Note: After Mac OS 10.11, using the " + "DYLD_LIBRARY_PATH is impossible " + << "unless System Integrity Protection (SIP) " + "is disabled."; +} + +void GetCublasDsoHandle(void** dso_handle) { +#if defined(__APPLE__) || defined(__OSX__) + GetDsoHandleFromSearchPath(FLAGS_cuda_dir, "libcublas.dylib", dso_handle); +#else + GetDsoHandleFromSearchPath(FLAGS_cuda_dir, "libcublas.so", dso_handle); +#endif +} + +void GetCudnnDsoHandle(void** dso_handle) { +#if defined(__APPLE__) || defined(__OSX__) + GetDsoHandleFromSearchPath(FLAGS_cudnn_dir, "libcudnn.dylib", dso_handle); +#else + GetDsoHandleFromSearchPath(FLAGS_cudnn_dir, "libcudnn.so", dso_handle); +#endif +} + +void GetCurandDsoHandle(void** dso_handle) { +#if defined(__APPLE__) || defined(__OSX__) + GetDsoHandleFromSearchPath(FLAGS_cuda_dir, "libcurand.dylib", dso_handle); +#else + GetDsoHandleFromSearchPath(FLAGS_cuda_dir, "libcurand.so", dso_handle); +#endif +} + +void GetWarpCTCDsoHandle(void** dso_handle) { +#if defined(__APPLE__) || defined(__OSX__) + GetDsoHandleFromSearchPath(FLAGS_warpctc_dir, "libwarpctc.dylib", dso_handle); +#else + GetDsoHandleFromSearchPath(FLAGS_warpctc_dir, "libwarpctc.so", dso_handle); +#endif +} + +void GetLapackDsoHandle(void** dso_handle) { +#if defined(__APPLE__) || defined(__OSX__) + GetDsoHandleFromSearchPath(FLAGS_lapack_dir, "liblapacke.dylib", dso_handle); +#else + GetDsoHandleFromSearchPath(FLAGS_lapack_dir, "liblapacke.so", dso_handle); +#endif +} + +} // namespace dynload +} // namespace platform +} // namespace paddle diff --git a/paddle/platform/dynload/dynamic_loader.h b/paddle/platform/dynload/dynamic_loader.h new file mode 100644 index 0000000000000000000000000000000000000000..a99b05443feb909f10b2c56f4d8bdf3c6fa11e3f --- /dev/null +++ b/paddle/platform/dynload/dynamic_loader.h @@ -0,0 +1,63 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +namespace paddle { +namespace platform { +namespace dynload { + +/** + * @brief load the DSO of CUBLAS + * + * @param **dso_handle dso handler + * + */ +void GetCublasDsoHandle(void** dso_handle); + +/** + * @brief load the DSO of CUDNN + * + * @param **dso_handle dso handler + * + */ +void GetCudnnDsoHandle(void** dso_handle); + +/** + * @brief load the DSO of CURAND + * + * @param **dso_handle dso handler + * + */ +void GetCurandDsoHandle(void** dso_handle); + +/** + * @brief load the DSO of warp-ctc + * + * @param **dso_handle dso handler + * + */ +void GetWarpCTCDsoHandle(void** dso_handle); + +/** + * @brief load the DSO of lapack + * + * @param **dso_handle dso handler + * + */ +void GetLapackDsoHandle(void** dso_handle); + +} // namespace dynload +} // namespace platform +} // namespace paddle