提交 91c6a792 编写于 作者: S Superjom

Merge branch 'develop' of github.com:PaddlePaddle/Paddle into network

......@@ -113,7 +113,7 @@ include(coveralls) # set code coverage
include_directories("${PROJ_ROOT}")
include_directories("${PROJ_ROOT}/paddle/cuda/include")
include_directories("${CMAKE_CURRENT_BINARY_DIR}/proto")
include_directories("${CMAKE_CURRENT_BINARY_DIR}/go/pserver/cclient")
include_directories("${CMAKE_CURRENT_BINARY_DIR}/go/pserver/client/c")
include_directories(${Boost_INCLUDE_DIRS})
set(EXTERNAL_LIBS
......
......@@ -2,10 +2,10 @@ INCLUDE(ExternalProject)
SET(ANY_SOURCE_DIR ${THIRD_PARTY_PATH}/any)
INCLUDE_DIRECTORIES(${ANY_SOURCE_DIR}/src/linb_any)
INCLUDE_DIRECTORIES(${ANY_SOURCE_DIR}/src/extern_lib_any)
ExternalProject_Add(
linb_any
extern_lib_any
${EXTERNAL_PROJECT_LOG_ARGS}
GIT_REPOSITORY "https://github.com/thelink2012/any.git"
GIT_TAG "8fef1e93710a0edf8d7658999e284a1142c4c020"
......@@ -17,5 +17,15 @@ ExternalProject_Add(
TEST_COMMAND ""
)
if (${CMAKE_VERSION} VERSION_LESS "3.3.0")
set(dummyfile ${CMAKE_CURRENT_BINARY_DIR}/lib_any_dummy.c)
file(WRITE ${dummyfile} "const char * dummy_any = \"${dummyfile}\";")
add_library(lib_any STATIC ${dummyfile})
else()
add_library(lib_any INTERFACE)
endif()
add_dependencies(lib_any extern_lib_any)
add_definitions(-DANY_IMPL_ANY_CAST_MOVEABLE)
LIST(APPEND external_project_dependencies linb_any)
\ No newline at end of file
LIST(APPEND external_project_dependencies lib_any)
......@@ -2,10 +2,10 @@ INCLUDE(ExternalProject)
SET(EIGEN_SOURCE_DIR ${THIRD_PARTY_PATH}/eigen3)
INCLUDE_DIRECTORIES(${EIGEN_SOURCE_DIR}/src/eigen3)
INCLUDE_DIRECTORIES(${EIGEN_SOURCE_DIR}/src/extern_eigen3)
ExternalProject_Add(
eigen3
extern_eigen3
${EXTERNAL_PROJECT_LOG_ARGS}
# for latest version, please get from official website
# URL "https://bitbucket.org/eigen/eigen/get/3.3.4.tar.gz"
......@@ -26,4 +26,14 @@ ExternalProject_Add(
TEST_COMMAND ""
)
if (${CMAKE_VERSION} VERSION_LESS "3.3.0")
set(dummyfile ${CMAKE_CURRENT_BINARY_DIR}/eigen3_dummy.c)
file(WRITE ${dummyfile} "const char * dummy_eigen3 = \"${dummyfile}\";")
add_library(eigen3 STATIC ${dummyfile})
else()
add_library(eigen3 INTERFACE)
endif()
add_dependencies(eigen3 extern_eigen3)
LIST(APPEND external_project_dependencies eigen3)
......@@ -13,7 +13,7 @@
# limitations under the License.
#
add_subdirectory(pserver/cclient)
add_subdirectory(pserver/client/c)
add_subdirectory(cmd/pserver)
add_subdirectory(cmd/master)
add_subdirectory(master/c)
......@@ -15,6 +15,7 @@ import (
func main() {
port := flag.Int("port", 0, "port of the pserver")
index := flag.Int("index", -1, "index of this pserver, should be larger or equal than 0")
etcdEndpoint := flag.String("etcd-endpoint", "http://127.0.0.1:2379",
"comma separated endpoint string for pserver to connect to etcd")
etcdTimeout := flag.Int("etcd-timeout", 5, "timeout for etcd calls")
......@@ -29,11 +30,16 @@ func main() {
}
log.SetLevel(level)
timeout := time.Second * time.Duration((*etcdTimeout))
e := pserver.NewEtcdClient(*etcdEndpoint, *numPservers, timeout)
idx, err := e.Register()
if err != nil {
panic(err)
var idx int
if *index >= 0 {
idx = *index
} else {
timeout := time.Second * time.Duration((*etcdTimeout))
e := pserver.NewEtcdClient(*etcdEndpoint, *numPservers, timeout)
idx, err = e.Register()
if err != nil {
panic(err)
}
}
s, err := pserver.NewService(idx)
......
......@@ -50,7 +50,7 @@ func NewEtcdClient(endpoints []string, addr string, lockPath, addrPath, statePat
lock := concurrency.NewMutex(sess, lockPath)
// It's fine for the lock to get stuck, in this case we have
// multiple master servers running (only configured to have
// one master running, but split-brain problem may cuase
// one master running, but split-brain problem may cause
// multiple master servers running), and the cluster management
// software will kill one of them.
log.Debugf("Trying to acquire lock at %s.", lockPath)
......@@ -98,7 +98,7 @@ func (e *EtcdClient) Save(state []byte) error {
// We lost the master lock and can not acquire
// it back, it means some other master is
// already started. We don't want cluster
// managment system to kill the master server
// management system to kill the master server
// who is holding the lock and running
// correctly. So the most feasible solution is
// to kill current master server. The current
......
cc_library(paddle_go_optimizer DEPS paddle_optimizer paddle_proto glog gflags protobuf)
go_library(paddle_pserver_cclient STATIC)
go_library(paddle_pserver_cclient STATIC DEPS paddle_go_optimizer)
if(WITH_TESTING)
add_subdirectory(test)
endif()
......@@ -30,15 +30,16 @@ import (
"unsafe"
"github.com/PaddlePaddle/Paddle/go/pserver"
"github.com/PaddlePaddle/Paddle/go/pserver/client"
log "github.com/sirupsen/logrus"
)
var nullPtr = unsafe.Pointer(uintptr(0))
var mu sync.Mutex
var handleMap = make(map[C.paddle_pserver_client]*pserver.Client)
var handleMap = make(map[C.paddle_pserver_client]*client.Client)
var curHandle C.paddle_pserver_client
func add(c *pserver.Client) C.paddle_pserver_client {
func add(c *client.Client) C.paddle_pserver_client {
mu.Lock()
defer mu.Unlock()
client := curHandle
......@@ -47,13 +48,13 @@ func add(c *pserver.Client) C.paddle_pserver_client {
return client
}
func get(client C.paddle_pserver_client) *pserver.Client {
func get(client C.paddle_pserver_client) *client.Client {
mu.Lock()
defer mu.Unlock()
return handleMap[client]
}
func remove(client C.paddle_pserver_client) *pserver.Client {
func remove(client C.paddle_pserver_client) *client.Client {
mu.Lock()
defer mu.Unlock()
h := handleMap[client]
......@@ -80,9 +81,9 @@ func (s selector) Select() bool {
return bool(s)
}
type lister []pserver.Server
type lister []client.Server
func (l lister) List() []pserver.Server {
func (l lister) List() []client.Server {
return l
}
......@@ -90,19 +91,22 @@ func (l lister) List() []pserver.Server {
func paddle_new_pserver_client(addrs *C.char, selected int) C.paddle_pserver_client {
a := C.GoString(addrs)
as := strings.Split(a, ",")
servers := make([]pserver.Server, len(as))
servers := make([]client.Server, len(as))
for i := range as {
servers[i].Index = i
servers[i].Addr = as[i]
}
c := pserver.NewClient(lister(servers), len(as), selector(selected != 0))
c := client.NewClient(lister(servers), len(as), selector(selected != 0))
return add(c)
}
//export paddle_new_etcd_pserver_client
func paddle_new_etcd_pserver_client(etcd_addr *C.char) C.paddle_pserver_client {
// TODO(helin): fault tolerant pserver client using etcd.
panic("not implemented.")
func paddle_new_etcd_pserver_client(etcd_endpoints *C.char, selected int) C.paddle_pserver_client {
// TODO(Longfei: use etcd lock to decide which trainer to initialize the parameters)
addr := C.GoString(etcd_endpoints)
etcd_client := client.NewEtcd(addr)
c := client.NewClient(etcd_client, etcd_client.Desired(), selector(selected != 0))
return add(c)
}
//export paddle_pserver_client_release
......
package pserver
package client
import (
"errors"
......@@ -7,6 +7,7 @@ import (
"time"
"github.com/PaddlePaddle/Paddle/go/connection"
"github.com/PaddlePaddle/Paddle/go/pserver"
log "github.com/sirupsen/logrus"
)
......@@ -105,7 +106,7 @@ func (c *Client) BeginInitParams() bool {
}
// InitParam initializes the parameter on parameter servers.
func (c *Client) InitParam(paramWithConfigs ParameterWithConfig) error {
func (c *Client) InitParam(paramWithConfigs pserver.ParameterWithConfig) error {
return c.pservers[c.partition(paramWithConfigs.Param.Name)].Call("Service.InitParam", paramWithConfigs, nil)
}
......@@ -123,13 +124,13 @@ func (c *Client) FinishInitParams() error {
// SendGrads sends gradients to parameter servers for updating
// parameters.
func (c *Client) SendGrads(grads []Gradient) error {
func (c *Client) SendGrads(grads []pserver.Gradient) error {
if len(grads) == 0 {
return errors.New("no gradient received")
}
errCh := make(chan error, len(grads))
for _, g := range grads {
go func(g Gradient) {
go func(g pserver.Gradient) {
err := c.pservers[c.partition(g.Name)].Call("Service.SendGrad", g, nil)
errCh <- err
}(g)
......@@ -151,7 +152,7 @@ func (c *Client) SendGrads(grads []Gradient) error {
type result struct {
idx int
param Parameter
param pserver.Parameter
err error
}
......@@ -170,12 +171,12 @@ func (r results) Swap(i int, j int) {
}
// GetParams gets parameters from parameter servers.
func (c *Client) GetParams(names []string) ([]Parameter, error) {
func (c *Client) GetParams(names []string) ([]pserver.Parameter, error) {
rCh := make(chan result, len(names))
for idx, name := range names {
go func(name string, idx int) {
var parameter Parameter
var parameter pserver.Parameter
err := c.pservers[c.partition(name)].Call("Service.GetParam", name, &parameter)
rCh <- result{idx: idx, param: parameter, err: err}
}(name, idx)
......@@ -196,7 +197,7 @@ func (c *Client) GetParams(names []string) ([]Parameter, error) {
}
sort.Sort(rs)
ps := make([]Parameter, len(rs))
ps := make([]pserver.Parameter, len(rs))
for i := range rs {
ps[i] = rs[i].param
}
......
package pserver_test
package client_test
import (
"context"
"io/ioutil"
"net"
"net/http"
......@@ -8,15 +9,25 @@ import (
"strconv"
"strings"
"testing"
"time"
"github.com/PaddlePaddle/Paddle/go/pserver"
"github.com/PaddlePaddle/Paddle/go/pserver/client"
"github.com/coreos/etcd/clientv3"
log "github.com/sirupsen/logrus"
)
const numPserver = 10
const (
numPserver = 10
etcdEndpoints = "127.0.0.1:2379"
timeout = 2 * time.Second
)
var port [numPserver]int
var pserverClientPorts [numPserver]int
func init() {
// this function init pserver client and return their ports in an array.
func initClient() [numPserver]int {
var ports [numPserver]int
for i := 0; i < numPserver; i++ {
l, err := net.Listen("tcp", ":0")
if err != nil {
......@@ -28,7 +39,7 @@ func init() {
if err != nil {
panic(err)
}
port[i] = p
ports[i] = p
go func(l net.Listener) {
s, err := pserver.NewService(0)
......@@ -49,6 +60,31 @@ func init() {
}
}(l)
}
return ports
}
func initNativeClient() {
pserverClientPorts = initClient()
}
func initEtcdClient() {
client, err := clientv3.New(clientv3.Config{
Endpoints: []string{etcdEndpoints},
DialTimeout: time.Second * time.Duration(1),
})
if err != nil {
log.Errorf("err %v", err)
}
ctx, cancel := context.WithTimeout(context.Background(), timeout)
client.Delete(ctx, pserver.PsDesired)
client.Delete(ctx, pserver.PsPath)
client.Put(ctx, pserver.PsDesired, strconv.Itoa(numPserver))
ports := initClient()
for i := 0; i < numPserver; i++ {
client.Put(ctx, pserver.PsPath+strconv.Itoa(i), ":"+strconv.Itoa(ports[i]))
}
cancel()
client.Close()
}
type selector bool
......@@ -57,25 +93,20 @@ func (s selector) Select() bool {
return bool(s)
}
type lister []pserver.Server
type lister []client.Server
func (l lister) List() []pserver.Server {
func (l lister) List() []client.Server {
return l
}
func TestClientFull(t *testing.T) {
servers := make([]pserver.Server, numPserver)
for i := 0; i < numPserver; i++ {
servers[i] = pserver.Server{Index: i, Addr: ":" + strconv.Itoa(port[i])}
}
c := pserver.NewClient(lister(servers), len(servers), selector(true))
func ClientTest(t *testing.T, c *client.Client) {
selected := c.BeginInitParams()
if !selected {
t.Fatal("should be selected.")
}
const numParameter = 100
config, err := ioutil.ReadFile("./cclient/test/testdata/optimizer.pb")
config, err := ioutil.ReadFile("./c/test/testdata/optimizer.pb")
if err != nil {
t.Fatalf("read optimizer proto failed")
}
......@@ -129,3 +160,21 @@ func TestClientFull(t *testing.T) {
}
}
}
func TestNativeClient(t *testing.T) {
initNativeClient()
servers := make([]client.Server, numPserver)
for i := 0; i < numPserver; i++ {
servers[i] = client.Server{Index: i, Addr: ":" + strconv.Itoa(pserverClientPorts[i])}
}
c1 := client.NewClient(lister(servers), len(servers), selector(true))
ClientTest(t, c1)
}
// TODO: tmperary disable etcdClient test for dependency of etcd)
func EtcdClient(t *testing.T) {
initEtcdClient()
etcd_client := client.NewEtcd(etcdEndpoints)
c2 := client.NewClient(etcd_client, etcd_client.Desired(), selector(true))
ClientTest(t, c2)
}
package client
import (
"context"
"strconv"
"strings"
"time"
"github.com/PaddlePaddle/Paddle/go/pserver"
"github.com/coreos/etcd/clientv3"
log "github.com/sirupsen/logrus"
)
const (
DefaultEtcdTimeout time.Duration = 5 * time.Second
)
// EtcdClient is used by pserver client that is a part of trainer process.
// TODO:
// 1. add watcher to watch the change state of pservers)
// 1. add etcd lock)
type EtcdClient struct {
client *clientv3.Client
timeout time.Duration
endpoints []string
}
// Desired read ps desired number from etcd.
func (p *EtcdClient) Desired() int {
var psDesired int
for {
ctx, cancel := context.WithTimeout(context.Background(), p.timeout)
resp, err := p.client.Get(ctx, pserver.PsDesired)
cancel()
if err != nil {
log.Errorf("Get ps dresire number failed! recnnectiong..., %v", err)
time.Sleep(p.timeout)
continue
}
kvs := resp.Kvs
if len(kvs) == 0 {
log.Infoln("Waiting for ps desired registered ...")
time.Sleep(p.timeout)
continue
}
psDesired, err = strconv.Atoi(string(resp.Kvs[0].Value))
if err != nil {
log.Errorf("psDesired %s invalid %v", psDesired, err)
time.Sleep(p.timeout)
continue
}
log.Debugf("Get psDesired number: %d", psDesired)
break
}
return psDesired
}
// List return the pserver list read from etcd.
func (p *EtcdClient) List() []Server {
psDesired := p.Desired()
servers := make([]Server, psDesired)
for {
for i := 0; i < psDesired; i++ {
ctx, cancel := context.WithTimeout(context.Background(), p.timeout)
cancel()
psKey := pserver.PsPath + strconv.Itoa(i)
log.Debugf("checking %s", psKey)
resp, err := p.client.Get(ctx, psKey)
if err != nil {
log.Infof("Get psKey= %s error, %v", psKey, err)
time.Sleep(p.timeout)
continue
}
kvs := resp.Kvs
if len(kvs) == 0 {
log.Infof("Waiting for ps addr registered ...")
time.Sleep(p.timeout)
continue
}
psAddr := string(resp.Kvs[0].Value)
// TODO(Longfei) check the ps address
if psAddr == "" {
log.Infof("Get psKey = %s, psAddr is empty", psKey)
time.Sleep(p.timeout)
continue
}
log.Infof("got value (%s) for key: %s", psAddr, psKey)
servers[i].Index = i
servers[i].Addr = psAddr
}
break
}
return servers
}
// NewEtcd create a etcd client to return the state of pserver on etcd.
func NewEtcd(endpoints string) *EtcdClient {
ep := strings.Split(endpoints, ",")
var cli *clientv3.Client
var err error
for {
cli, err = clientv3.New(clientv3.Config{
Endpoints: ep,
DialTimeout: DefaultEtcdTimeout,
})
if err != nil {
log.Errorf("Init etcd connection failed: %v", err)
time.Sleep(DefaultEtcdTimeout)
continue
}
break
}
log.Infof("Connected to etcd: %s\n", endpoints)
client := &EtcdClient{
client: cli,
timeout: DefaultEtcdTimeout,
endpoints: ep,
}
return client
}
......@@ -13,6 +13,13 @@ import (
log "github.com/sirupsen/logrus"
)
const (
// PsDesired is etcd path for store desired pserver count
PsDesired = "/ps_desired"
// PsAddr is the base dir for pserver to store their addr
PsPath = "/ps/"
)
// EtcdClient is the etcd client that the pserver uses for fault
// tolerance, service registry and coordination.
type EtcdClient struct {
......@@ -68,7 +75,7 @@ func (e *EtcdClient) Register() (int, error) {
// it at the same time.
for {
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
_, err := e.initDesiredPsercers(ctx, e.numPservers)
_, err := e.initDesiredPservers(ctx, e.numPservers)
cancel()
if err != nil {
log.Warn(err)
......@@ -120,7 +127,7 @@ func (e *EtcdClient) Register() (int, error) {
return pserverIdx, nil
}
func (e *EtcdClient) initDesiredPsercers(ctx context.Context, numPservers int) (*clientv3.TxnResponse, error) {
func (e *EtcdClient) initDesiredPservers(ctx context.Context, numPservers int) (*clientv3.TxnResponse, error) {
return concurrency.NewSTM(e.etcdClient, func(c concurrency.STM) error {
dsStr := c.Get(PsDesired)
if dsStr == "" {
......@@ -136,7 +143,7 @@ func (e *EtcdClient) registerPserverEtcd(ctx context.Context) (int, error) {
_, err := concurrency.NewSTM(e.etcdClient, func(c concurrency.STM) error {
registered := false
for i := 0; i < e.desired; i++ {
psKey := "/ps/" + strconv.Itoa(i)
psKey := PsPath + strconv.Itoa(i)
log.Debugf("checking %s", psKey)
ps := c.Get(psKey)
log.Debugf("got value (%s) for key: %s", ps, psKey)
......
......@@ -2,7 +2,7 @@ package pserver
// #cgo CFLAGS: -I ../../
// //FIXME: ldflags contain "build" path
// #cgo LDFLAGS: ../../build/go/pserver/cclient/libpaddle_go_optimizer.a -lstdc++
// #cgo LDFLAGS: ../../build/go/pserver/client/c/libpaddle_go_optimizer.a -lstdc++
// #include "paddle/optimizer/optimizer.h"
// #include <stdlib.h>
// #include <string.h>
......
......@@ -11,7 +11,7 @@ func TestOptimizerCreateRelease(t *testing.T) {
ElementType: Int32,
}
p.Content = []byte{1, 3}
config, err := ioutil.ReadFile("./cclient/test/testdata/optimizer.pb")
config, err := ioutil.ReadFile("./client/c/test/testdata/optimizer.pb")
if err != nil {
t.Fatalf("read optimizer proto failed")
}
......
......@@ -24,9 +24,6 @@ const (
Float64
)
// PsDesired is etcd path for store desired pserver count
const PsDesired = "/ps_desired"
// Parameter is a piece of data to sync with the parameter server.
type Parameter struct {
Name string
......
......@@ -10,6 +10,10 @@ import (
"github.com/PaddlePaddle/Paddle/go/pserver"
)
const (
OptimizerConfig = "./client/c/test/testdata/optimizer.pb"
)
func TestServiceFull(t *testing.T) {
s, err := pserver.NewService(0)
if err != nil {
......@@ -19,7 +23,7 @@ func TestServiceFull(t *testing.T) {
p.Name = "param_a"
p.Content = []byte{1, 0, 0, 0, 2, 0, 0, 0, 3, 0, 0, 0}
p.ElementType = pserver.Int32
config, err := ioutil.ReadFile("./cclient/test/testdata/optimizer.pb")
config, err := ioutil.ReadFile(OptimizerConfig)
if err != nil {
t.Fatalf("read optimizer proto failed")
}
......@@ -149,7 +153,7 @@ func TestBlockUntilInitialized(t *testing.T) {
p.Name = "param_a"
p.Content = []byte{1, 0, 0, 0, 2, 0, 0, 0, 3, 0, 0, 0}
p.ElementType = pserver.Int32
config, err := ioutil.ReadFile("./cclient/test/testdata/optimizer.pb")
config, err := ioutil.ReadFile(OptimizerConfig)
if err != nil {
t.Fatalf("read optimizer proto failed")
}
......
add_subdirectory(dynload)
nv_test(cuda_test SRCS cuda_test.cu)
cc_library(place SRCS place.cc)
......
......@@ -34,6 +34,16 @@ int GetDeviceCount(void) {
return count;
}
int GetCurrentDeviceId(void) {
int device_id;
throw_on_error(cudaGetDevice(&device_id), "cudaGetDevice failed");
return device_id;
}
void SetDeviceId(int device_id) {
throw_on_error(cudaSetDevice(device_id), "cudaSetDevice failed");
}
} // namespace platform
} // namespace paddle
......
cc_library(dynamic_loader SRCS dynamic_loader.cc DEPS glog gflags)
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <cublas_v2.h>
#include "paddle/platform/dynamic_loader.h"
namespace paddle {
namespace platform {
namespace dynload {
std::once_flag cublas_dso_flag;
void *cublas_dso_handle = nullptr;
/**
* The following macro definition can generate structs
* (for each function) to dynamic load cublas routine
* via operator overloading.
*
* note: default dynamic linked libs
*/
#ifdef PADDLE_USE_DSO
#define DYNAMIC_LOAD_CUBLAS_WRAP(__name) \
struct DynLoad__##__name { \
template <typename... Args> \
cublasStatus_t operator()(Args... args) { \
typedef cublasStatus_t (*cublasFunc)(Args...); \
std::call_once(cublas_dso_flag, \
paddle::platform::dynload::GetCublasDsoHandle, \
&cublas_dso_handle); \
void *p_##__name = dlsym(cublas_dso_handle, #__name); \
return reinterpret_cast<cublasFunc>(p_##__name)(args...); \
} \
} __name; // struct DynLoad__##__name
#else
#define DYNAMIC_LOAD_CUBLAS_WRAP(__name) \
struct DynLoad__##__name { \
template <typename... Args> \
cublasStatus_t operator()(Args... args) { \
return __name(args...); \
} \
} __name; // struct DynLoad__##__name
#endif
#define DYNAMIC_LOAD_CUBLAS_V2_WRAP(__name) DYNAMIC_LOAD_CUBLAS_WRAP(__name)
// include all needed cublas functions in HPPL
// clang-format off
#define CUBLAS_BLAS_ROUTINE_EACH(__macro) \
__macro(cublasSgemv) \
__macro(cublasDgemv) \
__macro(cublasSgemm) \
__macro(cublasDgemm) \
__macro(cublasSgeam) \
__macro(cublasDgeam) \
DYNAMIC_LOAD_CUBLAS_V2_WRAP(cublasCreate)
DYNAMIC_LOAD_CUBLAS_V2_WRAP(cublasDestroy)
DYNAMIC_LOAD_CUBLAS_V2_WRAP(cublasSetStream)
DYNAMIC_LOAD_CUBLAS_V2_WRAP(cublasSetPointerMode)
DYNAMIC_LOAD_CUBLAS_V2_WRAP(cublasGetPointerMode)
DYNAMIC_LOAD_CUBLAS_WRAP(cublasSgemmBatched)
DYNAMIC_LOAD_CUBLAS_WRAP(cublasDgemmBatched)
DYNAMIC_LOAD_CUBLAS_WRAP(cublasCgemmBatched)
DYNAMIC_LOAD_CUBLAS_WRAP(cublasZgemmBatched)
DYNAMIC_LOAD_CUBLAS_WRAP(cublasSgetrfBatched)
DYNAMIC_LOAD_CUBLAS_WRAP(cublasSgetriBatched)
DYNAMIC_LOAD_CUBLAS_WRAP(cublasDgetrfBatched)
DYNAMIC_LOAD_CUBLAS_WRAP(cublasDgetriBatched)
CUBLAS_BLAS_ROUTINE_EACH(DYNAMIC_LOAD_CUBLAS_V2_WRAP)
#undef DYNAMIC_LOAD_CUBLAS_WRAP
#undef DYNAMIC_LOAD_CUBLAS_V2_WRAP
#undef CUBLAS_BLAS_ROUTINE_EACH
// clang-format on
#ifndef PADDLE_TYPE_DOUBLE
#define CUBLAS_GEAM paddle::platform::dynload::cublasSgeam
#define CUBLAS_GEMV paddle::platform::dynload::cublasSgemv
#define CUBLAS_GEMM paddle::platform::dynload::cublasSgemm
#define CUBLAS_GETRF paddle::platform::dynload::cublasSgetrfBatched
#define CUBLAS_GETRI paddle::platform::dynload::cublasSgetriBatched
#else
#define CUBLAS_GEAM paddle::platform::dynload::cublasDgeam
#define CUBLAS_GEMV paddle::platform::dynload::cublasDgemv
#define CUBLAS_GEMM paddle::platform::dynload::cublasDgemm
#define CUBLAS_GETRF paddle::platform::dynload::cublasDgetrfBatched
#define CUBLAS_GETRI paddle::platform::dynload::cublasDgetriBatched
#endif
} // namespace dynload
} // namespace platform
} // namespace paddle
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <cudnn.h>
#include "paddle/platform/dynamic_loader.h"
namespace paddle {
namespace platform {
namespace dynload {
std::once_flag cudnn_dso_flag;
void* cudnn_dso_handle = nullptr;
#ifdef PADDLE_USE_DSO
#define DYNAMIC_LOAD_CUDNN_WRAP(__name) \
struct DynLoad__##__name { \
template <typename... Args> \
auto operator()(Args... args) -> decltype(__name(args...)) { \
using cudnn_func = decltype(__name(args...)) (*)(Args...); \
std::call_once(cudnn_dso_flag, \
paddle::platform::dynload::GetCudnnDsoHandle, \
&cudnn_dso_handle); \
void* p_##__name = dlsym(cudnn_dso_handle, #__name); \
return reinterpret_cast<cudnn_func>(p_##__name)(args...); \
} \
} __name; /* struct DynLoad__##__name */
#else
#define DYNAMIC_LOAD_CUDNN_WRAP(__name) \
struct DynLoad__##__name { \
template <typename... Args> \
auto operator()(Args... args) -> decltype(__name(args...)) { \
return __name(args...); \
} \
} __name; /* struct DynLoad__##__name */
#endif
/**
* include all needed cudnn functions in HPPL
* different cudnn version has different interfaces
**/
// clang-format off
#define CUDNN_DNN_ROUTINE_EACH(__macro) \
__macro(cudnnSetTensor4dDescriptor) \
__macro(cudnnSetTensor4dDescriptorEx) \
__macro(cudnnGetConvolutionNdForwardOutputDim) \
__macro(cudnnGetConvolutionForwardAlgorithm) \
__macro(cudnnCreateTensorDescriptor) \
__macro(cudnnDestroyTensorDescriptor) \
__macro(cudnnCreateFilterDescriptor) \
__macro(cudnnSetFilter4dDescriptor) \
__macro(cudnnSetPooling2dDescriptor) \
__macro(cudnnDestroyFilterDescriptor) \
__macro(cudnnCreateConvolutionDescriptor) \
__macro(cudnnCreatePoolingDescriptor) \
__macro(cudnnDestroyPoolingDescriptor) \
__macro(cudnnSetConvolution2dDescriptor) \
__macro(cudnnDestroyConvolutionDescriptor) \
__macro(cudnnCreate) \
__macro(cudnnDestroy) \
__macro(cudnnSetStream) \
__macro(cudnnActivationForward) \
__macro(cudnnConvolutionForward) \
__macro(cudnnConvolutionBackwardBias) \
__macro(cudnnGetConvolutionForwardWorkspaceSize) \
__macro(cudnnTransformTensor) \
__macro(cudnnPoolingForward) \
__macro(cudnnPoolingBackward) \
__macro(cudnnSoftmaxBackward) \
__macro(cudnnSoftmaxForward) \
__macro(cudnnGetVersion) \
__macro(cudnnGetErrorString)
CUDNN_DNN_ROUTINE_EACH(DYNAMIC_LOAD_CUDNN_WRAP)
#define CUDNN_DNN_ROUTINE_EACH_R2(__macro) \
__macro(cudnnAddTensor) \
__macro(cudnnConvolutionBackwardData) \
__macro(cudnnConvolutionBackwardFilter)
CUDNN_DNN_ROUTINE_EACH_R2(DYNAMIC_LOAD_CUDNN_WRAP)
// APIs available after R3:
#if CUDNN_VERSION >= 3000
#define CUDNN_DNN_ROUTINE_EACH_AFTER_R3(__macro) \
__macro(cudnnGetConvolutionBackwardFilterWorkspaceSize) \
__macro(cudnnGetConvolutionBackwardDataAlgorithm) \
__macro(cudnnGetConvolutionBackwardFilterAlgorithm) \
__macro(cudnnGetConvolutionBackwardDataWorkspaceSize)
CUDNN_DNN_ROUTINE_EACH_AFTER_R3(DYNAMIC_LOAD_CUDNN_WRAP)
#undef CUDNN_DNN_ROUTINE_EACH_AFTER_R3
#endif
// APIs available after R4:
#if CUDNN_VERSION >= 4007
#define CUDNN_DNN_ROUTINE_EACH_AFTER_R4(__macro) \
__macro(cudnnBatchNormalizationForwardTraining) \
__macro(cudnnBatchNormalizationForwardInference) \
__macro(cudnnBatchNormalizationBackward)
CUDNN_DNN_ROUTINE_EACH_AFTER_R4(DYNAMIC_LOAD_CUDNN_WRAP)
#undef CUDNN_DNN_ROUTINE_EACH_AFTER_R4
#endif
// APIs in R5
#if CUDNN_VERSION >= 5000
#define CUDNN_DNN_ROUTINE_EACH_R5(__macro) \
__macro(cudnnCreateActivationDescriptor) \
__macro(cudnnSetActivationDescriptor) \
__macro(cudnnGetActivationDescriptor) \
__macro(cudnnDestroyActivationDescriptor)
CUDNN_DNN_ROUTINE_EACH_R5(DYNAMIC_LOAD_CUDNN_WRAP)
#undef CUDNN_DNN_ROUTINE_EACH_R5
#endif
#undef CUDNN_DNN_ROUTINE_EACH
// clang-format on
} // namespace dynload
} // namespace platform
} // namespace paddle
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <curand.h>
#include "paddle/platform/dynamic_loader.h"
namespace paddle {
namespace platform {
namespace dynload {
std::once_flag curand_dso_flag;
void *curand_dso_handle = nullptr;
#ifdef PADDLE_USE_DSO
#define DYNAMIC_LOAD_CURAND_WRAP(__name) \
struct DynLoad__##__name { \
template <typename... Args> \
curandStatus_t operator()(Args... args) { \
typedef curandStatus_t (*curandFunc)(Args...); \
std::call_once(curand_dso_flag, \
paddle::platform::dynload::GetCurandDsoHandle, \
&curand_dso_handle); \
void *p_##__name = dlsym(curand_dso_handle, #__name); \
return reinterpret_cast<curandFunc>(p_##__name)(args...); \
} \
} __name; /* struct DynLoad__##__name */
#else
#define DYNAMIC_LOAD_CURAND_WRAP(__name) \
struct DynLoad__##__name { \
template <typename... Args> \
curandStatus_t operator()(Args... args) { \
return __name(args...); \
} \
} __name; /* struct DynLoad__##__name */
#endif
/* include all needed curand functions in HPPL */
// clang-format off
#define CURAND_RAND_ROUTINE_EACH(__macro) \
__macro(curandCreateGenerator) \
__macro(curandSetStream) \
__macro(curandSetPseudoRandomGeneratorSeed)\
__macro(curandGenerateUniform) \
__macro(curandGenerateUniformDouble) \
__macro(curandDestroyGenerator)
// clang-format on
CURAND_RAND_ROUTINE_EACH(DYNAMIC_LOAD_CURAND_WRAP)
#undef CURAND_RAND_ROUTINE_EACH
#undef DYNAMIC_LOAD_CURAND_WRAP
} // namespace dynload
} // namespace platform
} // namespace paddle
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "dynamic_loader.h"
#include <dlfcn.h>
#include <memory>
#include <mutex>
#include <string>
#include "gflags/gflags.h"
#include "glog/logging.h"
DEFINE_string(cudnn_dir, "",
"Specify path for loading libcudnn.so. For instance, "
"/usr/local/cudnn/lib. If empty [default], dlopen "
"will search cudnn from LD_LIBRARY_PATH");
DEFINE_string(cuda_dir, "",
"Specify path for loading cuda library, such as libcublas, "
"libcurand. For instance, /usr/local/cuda/lib64. If default, "
"dlopen will search cuda from LD_LIBRARY_PATH");
DEFINE_string(warpctc_dir, "", "Specify path for loading libwarpctc.so.");
DEFINE_string(lapack_dir, "", "Specify path for loading liblapack.so.");
namespace paddle {
namespace platform {
namespace dynload {
static inline std::string join(const std::string& part1,
const std::string& part2) {
// directory separator
const char sep = '/';
if (!part2.empty() && part2.front() == sep) {
return part2;
}
std::string ret;
ret.reserve(part1.size() + part2.size() + 1);
ret = part1;
if (!ret.empty() && ret.back() != sep) {
ret += sep;
}
ret += part2;
return ret;
}
static inline void GetDsoHandleFromDefaultPath(std::string& dso_path,
void** dso_handle,
int dynload_flags) {
VLOG(3) << "Try to find library: " << dso_path
<< " from default system path.";
// default search from LD_LIBRARY_PATH/DYLD_LIBRARY_PATH
*dso_handle = dlopen(dso_path.c_str(), dynload_flags);
// DYLD_LIBRARY_PATH is disabled after Mac OS 10.11 to
// bring System Integrity Projection (SIP), if dso_handle
// is null, search from default package path in Mac OS.
#if defined(__APPLE__) || defined(__OSX__)
if (nullptr == *dso_handle) {
dso_path = join("/usr/local/cuda/lib/", dso_path);
*dso_handle = dlopen(dso_path.c_str(), dynload_flags);
if (nullptr == *dso_handle) {
if (dso_path == "libcudnn.dylib") {
LOG(FATAL)
<< "Note: [Recommend] copy cudnn into /usr/local/cuda/ \n" // NOLINT
<< "For instance, sudo tar -xzf "
"cudnn-7.5-osx-x64-v5.0-ga.tgz -C " // NOLINT
<< "/usr/local \n sudo chmod a+r "
"/usr/local/cuda/include/cudnn.h " // NOLINT
<< "/usr/local/cuda/lib/libcudnn*";
}
}
}
#endif
}
static inline void GetDsoHandleFromSearchPath(const std::string& search_root,
const std::string& dso_name,
void** dso_handle) {
int dynload_flags = RTLD_LAZY | RTLD_LOCAL;
*dso_handle = nullptr;
std::string dlPath = dso_name;
if (search_root.empty()) {
GetDsoHandleFromDefaultPath(dlPath, dso_handle, dynload_flags);
} else {
// search xxx.so from custom path
dlPath = join(search_root, dso_name);
*dso_handle = dlopen(dlPath.c_str(), dynload_flags);
// if not found, search from default path
if (nullptr == *dso_handle) {
LOG(WARNING) << "Failed to find dynamic library: " << dlPath << " ("
<< dlerror() << ")";
dlPath = dso_name;
GetDsoHandleFromDefaultPath(dlPath, dso_handle, dynload_flags);
}
}
CHECK(nullptr != *dso_handle) << "Failed to find dynamic library: " << dlPath
<< " (" << dlerror() << ") \n"
<< "Please specify its path correctly using "
"following ways: \n"
<< "Method. set environment variable "
"LD_LIBRARY_PATH on Linux or "
<< "DYLD_LIBRARY_PATH on Mac OS. \n"
<< "For instance, issue command: export "
"LD_LIBRARY_PATH=... \n"
<< "Note: After Mac OS 10.11, using the "
"DYLD_LIBRARY_PATH is impossible "
<< "unless System Integrity Protection (SIP) "
"is disabled.";
}
void GetCublasDsoHandle(void** dso_handle) {
#if defined(__APPLE__) || defined(__OSX__)
GetDsoHandleFromSearchPath(FLAGS_cuda_dir, "libcublas.dylib", dso_handle);
#else
GetDsoHandleFromSearchPath(FLAGS_cuda_dir, "libcublas.so", dso_handle);
#endif
}
void GetCudnnDsoHandle(void** dso_handle) {
#if defined(__APPLE__) || defined(__OSX__)
GetDsoHandleFromSearchPath(FLAGS_cudnn_dir, "libcudnn.dylib", dso_handle);
#else
GetDsoHandleFromSearchPath(FLAGS_cudnn_dir, "libcudnn.so", dso_handle);
#endif
}
void GetCurandDsoHandle(void** dso_handle) {
#if defined(__APPLE__) || defined(__OSX__)
GetDsoHandleFromSearchPath(FLAGS_cuda_dir, "libcurand.dylib", dso_handle);
#else
GetDsoHandleFromSearchPath(FLAGS_cuda_dir, "libcurand.so", dso_handle);
#endif
}
void GetWarpCTCDsoHandle(void** dso_handle) {
#if defined(__APPLE__) || defined(__OSX__)
GetDsoHandleFromSearchPath(FLAGS_warpctc_dir, "libwarpctc.dylib", dso_handle);
#else
GetDsoHandleFromSearchPath(FLAGS_warpctc_dir, "libwarpctc.so", dso_handle);
#endif
}
void GetLapackDsoHandle(void** dso_handle) {
#if defined(__APPLE__) || defined(__OSX__)
GetDsoHandleFromSearchPath(FLAGS_lapack_dir, "liblapacke.dylib", dso_handle);
#else
GetDsoHandleFromSearchPath(FLAGS_lapack_dir, "liblapacke.so", dso_handle);
#endif
}
} // namespace dynload
} // namespace platform
} // namespace paddle
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
namespace paddle {
namespace platform {
namespace dynload {
/**
* @brief load the DSO of CUBLAS
*
* @param **dso_handle dso handler
*
*/
void GetCublasDsoHandle(void** dso_handle);
/**
* @brief load the DSO of CUDNN
*
* @param **dso_handle dso handler
*
*/
void GetCudnnDsoHandle(void** dso_handle);
/**
* @brief load the DSO of CURAND
*
* @param **dso_handle dso handler
*
*/
void GetCurandDsoHandle(void** dso_handle);
/**
* @brief load the DSO of warp-ctc
*
* @param **dso_handle dso handler
*
*/
void GetWarpCTCDsoHandle(void** dso_handle);
/**
* @brief load the DSO of lapack
*
* @param **dso_handle dso handler
*
*/
void GetLapackDsoHandle(void** dso_handle);
} // namespace dynload
} // namespace platform
} // namespace paddle
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册