提交 7c9b53c3 编写于 作者: D dangqingqing

Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into from_tar

...@@ -27,6 +27,7 @@ if(NOT CMAKE_CROSSCOMPILING) ...@@ -27,6 +27,7 @@ if(NOT CMAKE_CROSSCOMPILING)
endif(NOT CMAKE_CROSSCOMPILING) endif(NOT CMAKE_CROSSCOMPILING)
find_package(Git REQUIRED) find_package(Git REQUIRED)
find_package(Threads REQUIRED) find_package(Threads REQUIRED)
find_package(Boost QUIET)
include(simd) include(simd)
...@@ -110,6 +111,7 @@ include_directories("${PROJ_ROOT}") ...@@ -110,6 +111,7 @@ include_directories("${PROJ_ROOT}")
include_directories("${PROJ_ROOT}/paddle/cuda/include") include_directories("${PROJ_ROOT}/paddle/cuda/include")
include_directories("${CMAKE_CURRENT_BINARY_DIR}/proto") include_directories("${CMAKE_CURRENT_BINARY_DIR}/proto")
include_directories("${CMAKE_CURRENT_BINARY_DIR}/go/pserver/cclient") include_directories("${CMAKE_CURRENT_BINARY_DIR}/go/pserver/cclient")
include_directories(${Boost_INCLUDE_DIRS})
set(EXTERNAL_LIBS set(EXTERNAL_LIBS
${GFLAGS_LIBRARIES} ${GFLAGS_LIBRARIES}
......
...@@ -30,7 +30,13 @@ func main() { ...@@ -30,7 +30,13 @@ func main() {
log.SetLevel(level) log.SetLevel(level)
timeout := time.Second * time.Duration((*etcdTimeout)) timeout := time.Second * time.Duration((*etcdTimeout))
s, err := pserver.NewService(*etcdEndpoint, *numPservers, timeout) e := pserver.NewEtcdClient(*etcdEndpoint, *numPservers, timeout)
idx, err := e.Register()
if err != nil {
panic(err)
}
s, err := pserver.NewService(idx)
if err != nil { if err != nil {
panic(err) panic(err)
} }
......
...@@ -18,8 +18,8 @@ const ( ...@@ -18,8 +18,8 @@ const (
DefaultAddrPath = "/master/addr" DefaultAddrPath = "/master/addr"
) )
// EtcdClient is the etcd client that master uses for fault tolerance // EtcdClient is the etcd client that the master uses for fault
// and service registry. // tolerance and service registry.
type EtcdClient struct { type EtcdClient struct {
lockPath string lockPath string
statePath string statePath string
......
package pserver package pserver
import ( import (
"errors"
"hash/fnv" "hash/fnv"
"sort" "sort"
"time" "time"
...@@ -123,6 +124,9 @@ func (c *Client) FinishInitParams() error { ...@@ -123,6 +124,9 @@ func (c *Client) FinishInitParams() error {
// SendGrads sends gradients to parameter servers for updating // SendGrads sends gradients to parameter servers for updating
// parameters. // parameters.
func (c *Client) SendGrads(grads []Gradient) error { func (c *Client) SendGrads(grads []Gradient) error {
if len(grads) == 0 {
return errors.New("no gradient received")
}
errCh := make(chan error, len(grads)) errCh := make(chan error, len(grads))
for _, g := range grads { for _, g := range grads {
go func(g Gradient) { go func(g Gradient) {
......
...@@ -7,7 +7,6 @@ import ( ...@@ -7,7 +7,6 @@ import (
"strconv" "strconv"
"strings" "strings"
"testing" "testing"
"time"
"github.com/PaddlePaddle/Paddle/go/pserver" "github.com/PaddlePaddle/Paddle/go/pserver"
) )
...@@ -31,7 +30,7 @@ func init() { ...@@ -31,7 +30,7 @@ func init() {
port[i] = p port[i] = p
go func(l net.Listener) { go func(l net.Listener) {
s, err := pserver.NewService("", time.Second*5) s, err := pserver.NewService(0)
if err != nil { if err != nil {
panic(err) panic(err)
} }
......
package pserver
import (
"context"
"errors"
"strconv"
"strings"
"time"
"github.com/PaddlePaddle/Paddle/go/utils/networkhelper"
"github.com/coreos/etcd/clientv3"
"github.com/coreos/etcd/clientv3/concurrency"
log "github.com/sirupsen/logrus"
)
// EtcdClient is the etcd client that the pserver uses for fault
// tolerance, service registry and coordination.
type EtcdClient struct {
numPservers int
etcdEndpoints string
etcdClient *clientv3.Client
// etcdTimeout is also used as retry intervals.
etcdTimeout time.Duration
// FIXME: ensure GetExternalIP gets the correct ip for trainers to connect.
externalIP string
// desired number of pservers in the job.
// assume desired will not change during one training job.
desired int
}
// NewEtcdClient creates an EtcdClient
func NewEtcdClient(endpoints string, numPservers int, timeout time.Duration) *EtcdClient {
return &EtcdClient{
etcdTimeout: timeout,
numPservers: numPservers,
etcdEndpoints: endpoints,
}
}
// Register registers the pserver on etcd
//
// Register returns the index of the current pserver.
func (e *EtcdClient) Register() (int, error) {
var err error
e.externalIP, err = networkhelper.GetExternalIP()
if err != nil {
return 0, err
}
// initialize connection to etcd.
ep := strings.Split(e.etcdEndpoints, ",")
for {
cli, err := clientv3.New(clientv3.Config{
Endpoints: ep,
DialTimeout: e.etcdTimeout,
})
if err != nil {
log.Errorf("connect to etcd error: %v", err)
time.Sleep(e.etcdTimeout)
continue
}
e.etcdClient = cli
log.Debugf("inited client to %s", e.etcdEndpoints)
break
}
// init /ps_desired using transaction, for multiple pservers may want to write
// it at the same time.
for {
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
_, err := e.initDesiredPsercers(ctx, e.numPservers)
cancel()
if err != nil {
log.Warn(err)
time.Sleep(e.etcdTimeout)
continue
}
break
}
// TODO: when implementing extending or reducing pservers, /ps_desired is
// changed, then we need to watch /ps_desired node for events. For now, just
// write once when init and read from it.
// wait and set s.desired init value
for {
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
resp, err := e.etcdClient.Get(ctx, PsDesired)
cancel()
if err != nil {
log.Errorf("getting %s error: %v", PsDesired, err)
time.Sleep(e.etcdTimeout)
continue
}
if len(resp.Kvs) != 0 {
e.desired, err = strconv.Atoi(string(resp.Kvs[0].Value))
if err != nil {
log.Errorf("value of %s invalid %v\n", PsDesired, err)
time.Sleep(e.etcdTimeout)
// NOTE: wait util ps_desired value change
continue
}
break
}
}
var pserverIdx int
// try register pserver node on etcd
for {
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
var err error
pserverIdx, err = e.registerPserverEtcd(ctx)
cancel()
if err != nil {
log.Warn(err)
time.Sleep(e.etcdTimeout)
continue
}
break
}
return pserverIdx, nil
}
func (e *EtcdClient) initDesiredPsercers(ctx context.Context, numPservers int) (*clientv3.TxnResponse, error) {
return concurrency.NewSTM(e.etcdClient, func(c concurrency.STM) error {
dsStr := c.Get(PsDesired)
if dsStr == "" {
c.Put(PsDesired, strconv.Itoa(numPservers))
}
return nil
}, concurrency.WithAbortContext(ctx), concurrency.WithIsolation(concurrency.RepeatableReads))
}
// registerPserverEtcd registers pserver node on etcd using transaction.
func (e *EtcdClient) registerPserverEtcd(ctx context.Context) (int, error) {
var idx int
_, err := concurrency.NewSTM(e.etcdClient, func(c concurrency.STM) error {
registered := false
for i := 0; i < e.desired; i++ {
psKey := "/ps/" + strconv.Itoa(i)
log.Debugf("checking %s", psKey)
ps := c.Get(psKey)
log.Debugf("got value (%s) for key: %s", ps, psKey)
if ps == "" {
resp, err := e.etcdClient.Grant(context.TODO(), 5)
if err != nil {
log.Fatal(err)
}
// find the first id and write info
c.Put(psKey, e.externalIP, clientv3.WithLease(resp.ID))
log.Debugf("set pserver node %s with value %s", psKey, e.externalIP)
ch, kaerr := e.etcdClient.KeepAlive(context.TODO(), resp.ID)
if kaerr != nil {
log.Errorf("keepalive etcd node error: %v", kaerr)
return kaerr
}
// Eat the keep alive message so etcd
// will not expire the lease.
go func(ch <-chan *clientv3.LeaseKeepAliveResponse) {
ka := <-ch
log.Debugf("keepalive: %d\n", ka.TTL)
}(ch)
log.Debug("register finished")
idx = i
registered = true
break
}
}
if registered == true {
return nil
}
return errors.New("not registerd, may due to already have enough pservers")
}, concurrency.WithAbortContext(ctx), concurrency.WithIsolation(concurrency.RepeatableReads))
if err != nil {
return 0, err
}
return idx, nil
}
package pserver package pserver
import ( import (
"context"
"errors" "errors"
"fmt" "fmt"
"strconv"
"strings"
"sync" "sync"
"time"
"github.com/PaddlePaddle/Paddle/go/utils/networkhelper"
"github.com/coreos/etcd/clientv3"
"github.com/coreos/etcd/clientv3/concurrency"
log "github.com/sirupsen/logrus"
) )
// ElementType is the type of elements of a Parameter. // ElementType is the type of elements of a Parameter.
...@@ -55,160 +46,25 @@ type Gradient Parameter ...@@ -55,160 +46,25 @@ type Gradient Parameter
// Service is the RPC service for pserver. // Service is the RPC service for pserver.
type Service struct { type Service struct {
initialized chan struct{} initialized chan struct{}
idx int
mu sync.Mutex mu sync.Mutex
opt *optimizer opt *optimizer
paramMap map[string]Parameter paramMap map[string]Parameter
etcdEndpoints string
etcdClient *clientv3.Client
// etcdTimeout is also used as retry intervals.
etcdTimeout time.Duration
// desired number of pservers in the job.
// assume desired will not change during one training job.
desired int
// FIXME: ensure GetExternalIP gets the correct ip for trainers to connect.
externalIP string
} }
// NewService creates a new service, will bypass etcd registration if no // NewService creates a new service, will bypass etcd registration if no
// endpoints specified. // endpoints specified.
func NewService(endpoints string, numPservers int, timeout time.Duration) (*Service, error) { func NewService(idx int) (*Service, error) {
s := &Service{opt: newOptimizer(sgd, 0.005)} s := &Service{
idx: idx,
opt: newOptimizer(sgd, 0.005),
}
s.paramMap = make(map[string]Parameter) s.paramMap = make(map[string]Parameter)
s.initialized = make(chan struct{}) s.initialized = make(chan struct{})
s.etcdEndpoints = endpoints
s.etcdTimeout = timeout
var err error
s.externalIP, err = networkhelper.GetExternalIP()
if err != nil {
return nil, err
}
if endpoints != "" {
// initialize connection to etcd, try
ep := strings.Split(s.etcdEndpoints, ",")
for {
cli, err := clientv3.New(clientv3.Config{
Endpoints: ep,
DialTimeout: s.etcdTimeout,
})
if err != nil {
log.Errorf("connect to etcd error: %v", err)
time.Sleep(s.etcdTimeout)
continue
}
s.etcdClient = cli
log.Debugf("inited client to %s", s.etcdEndpoints)
break
}
// init /ps_desired using transaction, for multiple pservers may want to write
// it at the same time.
for {
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
_, err := s.initDesiredPsercers(ctx, numPservers)
cancel()
if err != nil {
log.Warn(err)
time.Sleep(s.etcdTimeout)
continue
}
break
}
// TODO: when implementing extending or reducing pservers, /ps_desired is
// changed, then we need to watch /ps_desired node for events. For now, just
// write once when init and read from it.
// wait and set s.desired init value
for {
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
resp, err := s.etcdClient.Get(ctx, PsDesired)
cancel()
if err != nil {
log.Errorf("getting %s error: %v", PsDesired, err)
time.Sleep(s.etcdTimeout)
continue
}
if len(resp.Kvs) != 0 {
s.desired, err = strconv.Atoi(string(resp.Kvs[0].Value))
if err != nil {
log.Errorf("value of %s invalid %v\n", PsDesired, err)
time.Sleep(s.etcdTimeout)
// NOTE: wait util ps_desired value change
continue
}
break
}
}
// try register pserver node on etcd
for {
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
_, err := s.registerPserverEtcd(ctx)
cancel()
if err != nil {
log.Warn(err)
time.Sleep(s.etcdTimeout)
continue
}
break
}
} // if endpoints != ""
// Bypass etcd registration if no endpoints specified
return s, nil return s, nil
} }
func (s *Service) initDesiredPsercers(ctx context.Context, numPservers int) (*clientv3.TxnResponse, error) {
return concurrency.NewSTM(s.etcdClient, func(c concurrency.STM) error {
dsStr := c.Get(PsDesired)
if dsStr == "" {
c.Put(PsDesired, strconv.Itoa(numPservers))
}
return nil
}, concurrency.WithAbortContext(ctx), concurrency.WithIsolation(concurrency.RepeatableReads))
}
// registerPserverEtcd registers pserver node on etcd using transaction.
func (s *Service) registerPserverEtcd(ctx context.Context) (*clientv3.TxnResponse, error) {
return concurrency.NewSTM(s.etcdClient, func(c concurrency.STM) error {
registered := false
for i := 0; i < s.desired; i++ {
psKey := "/ps/" + strconv.Itoa(i)
log.Debugf("checking %s", psKey)
ps := c.Get(psKey)
log.Debugf("got value (%s) for key: %s", ps, psKey)
if ps == "" {
resp, err := s.etcdClient.Grant(context.TODO(), 5)
if err != nil {
log.Fatal(err)
}
// find the first id and write info
c.Put(psKey, s.externalIP, clientv3.WithLease(resp.ID))
log.Debugf("set pserver node %s with value %s", psKey, s.externalIP)
ch, kaerr := s.etcdClient.KeepAlive(context.TODO(), resp.ID)
if kaerr != nil {
log.Errorf("keepalive etcd node error: %v", kaerr)
return kaerr
}
// Eat the keep alive message so etcd
// will not expire the lease.
go func(ch <-chan *clientv3.LeaseKeepAliveResponse) {
ka := <-ch
log.Debugf("keepalive: %d\n", ka.TTL)
}(ch)
log.Debug("register finished")
registered = true
break
}
}
if registered == true {
return nil
}
return errors.New("not registerd, may due to already have enough pservers")
}, concurrency.WithAbortContext(ctx), concurrency.WithIsolation(concurrency.RepeatableReads))
}
// InitParam initializes a parameter. // InitParam initializes a parameter.
func (s *Service) InitParam(paramWithConfigs ParameterWithConfig, dummy *int) error { func (s *Service) InitParam(paramWithConfigs ParameterWithConfig, dummy *int) error {
select { select {
......
...@@ -10,7 +10,7 @@ import ( ...@@ -10,7 +10,7 @@ import (
) )
func TestFull(t *testing.T) { func TestFull(t *testing.T) {
s, err := pserver.NewService("", time.Second*5) s, err := pserver.NewService(0)
if err != nil { if err != nil {
t.Error(err) t.Error(err)
} }
...@@ -75,7 +75,7 @@ func TestFull(t *testing.T) { ...@@ -75,7 +75,7 @@ func TestFull(t *testing.T) {
} }
func TestMultipleInit(t *testing.T) { func TestMultipleInit(t *testing.T) {
s, err := pserver.NewService("", time.Second*5) s, err := pserver.NewService(0)
if err != nil { if err != nil {
t.Error(err) t.Error(err)
} }
...@@ -91,7 +91,7 @@ func TestMultipleInit(t *testing.T) { ...@@ -91,7 +91,7 @@ func TestMultipleInit(t *testing.T) {
} }
func TestUninitialized(t *testing.T) { func TestUninitialized(t *testing.T) {
s, err := pserver.NewService("", time.Second*5) s, err := pserver.NewService(0)
err = s.SendGrad(pserver.Gradient{}, nil) err = s.SendGrad(pserver.Gradient{}, nil)
if err.Error() != pserver.Uninitialized { if err.Error() != pserver.Uninitialized {
t.FailNow() t.FailNow()
...@@ -99,7 +99,7 @@ func TestUninitialized(t *testing.T) { ...@@ -99,7 +99,7 @@ func TestUninitialized(t *testing.T) {
} }
func TestBlockUntilInitialized(t *testing.T) { func TestBlockUntilInitialized(t *testing.T) {
s, err := pserver.NewService("", time.Second*5) s, err := pserver.NewService(0)
if err != nil { if err != nil {
t.Error(err) t.Error(err)
} }
......
...@@ -9,17 +9,10 @@ add_subdirectory(pserver) ...@@ -9,17 +9,10 @@ add_subdirectory(pserver)
add_subdirectory(trainer) add_subdirectory(trainer)
add_subdirectory(scripts) add_subdirectory(scripts)
add_subdirectory(optimizer) add_subdirectory(optimizer)
add_subdirectory(strings) add_subdirectory(string)
# Do not build go directory until go cmake is working smoothly.
# if(CMAKE_Go_COMPILER)
# add_subdirectory(go)
# endif()
find_package(Boost QUIET)
if(Boost_FOUND) if(Boost_FOUND)
include_directories(${Boost_INCLUDE_DIRS}) add_subdirectory(memory)
add_subdirectory(platform) add_subdirectory(platform)
add_subdirectory(framework) add_subdirectory(framework)
endif() endif()
......
---
Language: Cpp
BasedOnStyle: Google
Standard: Cpp11
...
add_subdirectory(detail)
...@@ -97,6 +97,7 @@ class BuddyAllocator { ...@@ -97,6 +97,7 @@ class BuddyAllocator {
struct Block { struct Block {
size_t size; size_t size;
Block* left, right; Block* left, right;
size_t index; // allocator id
}; };
... ...
}; };
......
if(${WITH_GPU})
nv_library(system_allocator SRCS system_allocator.cc DEPS gflags)
nv_test(system_allocator_test SRCS system_allocator_test.cc DEPS system_allocator gflags)
else(${WITH_GPU})
cc_library(system_allocator SRCS system_allocator.cc DEPS gflags)
cc_test(system_allocator_test SRCS system_allocator_test.cc DEPS system_allocator gflags)
endif(${WITH_GPU})
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/memory/detail/buddy_allocator.h"
namespace paddle {
namespace memory {
namespace detail {
BuddyAllocator::BuddyAllocator(size_t pool_size, size_t max_pools,
SystemAllocator* system_allocator)
: pool_size_(pool_size),
max_pools_(max_pools),
system_allocator_(system_allocator) {
PADDLE_ASSERT(pool_size > 0);
PADDLE_ASSERT(max_pools > 0);
PADDLE_ASSERT(system_allocator != nullptr);
}
} // namespace detail
} // namespace memory
} // namespace paddle
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/memory/detail/system_allocator.h"
#include <mutex>
#include <vector>
namespace paddle {
namespace memory {
namespace detail {
class BuddyAllocator {
public:
BuddyAllocator(size_t pool_size, size_t max_pools,
SystemAllocator* system_allocator);
~BuddyAllocator();
void* Alloc(size_t size);
void Free(void*);
size_t Used();
private:
struct Block {
size_t size_;
Block* left_; // left buddy
Block* right_; // right buddy
};
// Initially, there is only one pool. If a Alloc founds not enough
// memory from that pool, and there has not been max_num_pools_,
// create a new pool by calling system_allocator_.Alloc(pool_size_).
std::vector<void*> pools_;
size_t pool_size_; // the size of each pool;
size_t max_num_pools_; // the size of all pools;
SystemAllocator* system_allocator_;
std::mutex mutex_;
// Disable copy and assignment.
BuddyAllocator(const BuddyAllocator&) = delete;
BuddyAllocator& operator=(const BuddyAllocator&) = delete;
};
BuddyAllocator<CPUAllocator>* GetCPUBuddyAllocator() {
static BuddyAllocator<CPUAllocator>* a = nullptr;
if (a == nullptr) {
a = new BuddyAllocator<CPUAllocator>();
}
return a;
}
#ifndef PADDLE_ONLY_CPU // The following code are for CUDA.
BuddyAllocator<GPUAllocator>* GetGPUBuddyAllocator(int gpu_id) {
static BuddyAllocator<GPUAllocator>** as = NULL;
if (as == NULL) {
int gpu_num = platform::GetDeviceCount();
as = new BuddyAllocator<GPUAllocator>*[gpu_num];
for (int gpu = 0; gpu < gpu_num; gpu++) {
as[gpu] = new BuddyAllocator<GPUAllocator>();
}
}
return as[gpu_id];
}
#endif // PADDLE_ONLY_CPU
} // namespace detail
} // namespace memory
} // namespace paddle
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/memory/detail/system_allocator.h"
#include <stdlib.h> // for malloc and free
#include <sys/mman.h> // for mlock and munlock
#include "gflags/gflags.h"
#include "paddle/platform/assert.h"
#include "paddle/platform/cuda.h"
// If use_pinned_memory is true, CPUAllocator calls mlock, which
// returns pinned and locked memory as staging areas for data exchange
// between host and device. Allocates too much would reduce the amount
// of memory available to the system for paging. So, by default, we
// should set false to use_pinned_memory.
DEFINE_bool(use_pinned_memory, false,
"If set, allocate cpu/gpu pinned memory.");
namespace paddle {
namespace memory {
namespace detail {
void* CPUAllocator::Alloc(size_t size) {
// According to http://www.cplusplus.com/reference/cstdlib/malloc/,
// malloc might not return nullptr if size is zero, but the returned
// pointer shall not be dereferenced -- so we make it nullptr.
if (size <= 0) return nullptr;
void* p = malloc(size);
if (p != nullptr && FLAGS_use_pinned_memory) {
mlock(p, size);
}
return p;
}
void CPUAllocator::Free(void* p, size_t size) {
if (p != nullptr && FLAGS_use_pinned_memory) {
munlock(p, size);
}
free(p);
}
#ifndef PADDLE_ONLY_CPU
void* GPUAllocator::Alloc(size_t size) {
// CUDA documentation doesn't explain if cudaMalloc returns nullptr
// if size is 0. We just make sure it does.
if (size <= 0) {
return nullptr;
}
void* p = 0;
cudaError_t result =
FLAGS_use_pinned_memory ? cudaMallocHost(&p, size) : cudaMalloc(&p, size);
if (result != cudaSuccess) {
cudaGetLastError(); // clear error if there is any.
}
return result == cudaSuccess ? p : nullptr;
}
void GPUAllocator::Free(void* p, size_t size) {
// Purposefully allow cudaErrorCudartUnloading, because
// that is returned if you ever call cudaFree after the
// driver has already shutdown. This happens only if the
// process is terminating, in which case we don't care if
// cudaFree succeeds.
cudaError_t err = FLAGS_use_pinned_memory ? cudaFreeHost(p) : cudaFree(p);
if (err != cudaErrorCudartUnloading) {
platform::throw_on_error(err, "cudaFree{Host} failed");
}
}
#endif // PADDLE_ONLY_CPU
} // namespace detail
} // namespace memory
} // namespace paddle
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <stddef.h> // for size_t
namespace paddle {
namespace memory {
namespace detail {
// SystemAllocator is the parent class of CPUAllocator and
// GPUAllocator. A BuddyAllocator object uses a SystemAllocator*
// pointing to the underlying system allocator. An alternative to
// this class hierarchy is to pass a system allocator class to
// BuddyAllocator as a template parameter. This approach makes
// BuddyAllocator a class template, and it's very complicated
// algorithm would make the buddy_allocator.h messy.
class SystemAllocator {
public:
virtual ~SystemAllocator() {}
virtual void* Alloc(size_t size) = 0;
virtual void Free(void* p, size_t size) = 0;
};
class CPUAllocator : public SystemAllocator {
public:
virtual void* Alloc(size_t size);
virtual void Free(void* p, size_t size);
};
#ifndef PADDLE_ONLY_CPU
class GPUAllocator : public SystemAllocator {
public:
virtual void* Alloc(size_t size);
virtual void Free(void* p, size_t size);
};
#endif // PADDLE_ONLY_CPU
} // namespace detail
} // namespace memory
} // namespace paddle
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/memory/detail/system_allocator.h"
#include <memory>
#include <vector>
#include "gflags/gflags.h"
#include "gtest/gtest.h"
DECLARE_bool(use_pinned_memory);
void TestAllocator(paddle::memory::detail::SystemAllocator& a, size_t size) {
bool freed = false;
{
void* p = a.Alloc(size);
if (size > 0) {
EXPECT_NE(p, nullptr);
} else {
EXPECT_EQ(p, nullptr);
}
int* i = static_cast<int*>(p);
std::shared_ptr<int> ptr(i, [&](void* p) {
freed = true;
a.Free(p, size);
});
}
EXPECT_TRUE(freed);
}
TEST(CPUAllocator, NoLockMem) {
FLAGS_use_pinned_memory = false;
paddle::memory::detail::CPUAllocator a;
TestAllocator(a, 2048);
TestAllocator(a, 0);
}
TEST(CPUAllocator, LockMem) {
FLAGS_use_pinned_memory = true;
paddle::memory::detail::CPUAllocator a;
TestAllocator(a, 2048);
TestAllocator(a, 0);
}
#ifndef PADDLE_ONLY_CPU
TEST(GPUAllocator, NoStaging) {
FLAGS_use_pinned_memory = false;
paddle::memory::detail::GPUAllocator a;
TestAllocator(a, 2048);
TestAllocator(a, 0);
}
TEST(GPUAllocator, Staging) {
FLAGS_use_pinned_memory = true;
paddle::memory::detail::GPUAllocator a;
TestAllocator(a, 2048);
TestAllocator(a, 0);
}
#endif // PADDLE_ONLY_CPU
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/memory/memory.h"
#include "paddle/memory/detail/buddy_allocator.h"
#include "paddle/memory/detail/system_allocator.h"
#include "paddle/platform/assert.h"
#include <boost/variant.hpp>
namespace paddle {
namespace memory {
void* Alloc(platform::Place pl, size_t size) {
#ifndef PADDLE_ONLY_CPU
if (paddle::platform::is_gpu_place(pl)) {
size_t gpu_id = boost::get<platform::GPUPlace>(pl).device;
return detail::GetGPUBuddyAllocator(gpu_id)->Alloc(size);
}
#endif // PADDLE_ONLY_CPU
PADDLE_ASSERT(paddle::platform::is_cpu_place(pl));
return detail::GetCPUBuddyAllocator()->Alloc(size);
}
void Free(paddle::platform::Place pl, void* p) {
#ifndef PADDLE_ONLY_CPU
if (paddle::platform::is_gpu_place(pl)) {
size_t gpu_id = boost::get<platform::GPUPlace>(pl).device;
detail::GetGPUBuddyAllocator(gpu_id)->Free(p);
}
#endif // PADDLE_ONLY_CPU
PADDLE_ASSERT(paddle::platform::is_cpu_place(pl));
detail::GetCPUBuddyAllocator()->Free(p);
}
size_t Used(paddle::platform::Place pl) {
#ifndef PADDLE_ONLY_CPU
if (paddle::platform::is_gpu_place(pl)) {
size_t gpu_id = boost::get<platform::GPUPlace>(pl).device;
return detail::GetGPUBuddyAllocator(gpu_id)->Used();
}
#endif // PADDLE_ONLY_CPU
PADDLE_ASSERT(paddle::platform::is_cpu_place(pl));
return detail::GetCPUBuddyAllocator()->Used();
}
} // namespace memory
} // namespace paddle
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/platform/place.h"
namespace paddle {
namespace memory {
void* Alloc(paddle::platform::Place, size_t);
void Free(paddle::platform::Place, void*);
size_t Used(paddle::platform::Place);
} // namespace memory
} // namespace paddle
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#ifndef PADDLE_ONLY_CPU
#include <thrust/system/cuda/error.h>
#include <thrust/system_error.h>
namespace paddle {
namespace platform {
inline void throw_on_error(cudaError_t e, const char* message) {
if (e) {
throw thrust::system_error(e, thrust::cuda_category(), message);
}
}
int GetDeviceCount(void) {
int count;
throw_on_error(cudaGetDeviceCount(&count), "cudaGetDeviceCount failed");
return count;
}
} // namespace platform
} // namespace paddle
#endif // PADDLE_ONLY_CPU
...@@ -8,8 +8,8 @@ namespace detail { ...@@ -8,8 +8,8 @@ namespace detail {
class PlacePrinter : public boost::static_visitor<> { class PlacePrinter : public boost::static_visitor<> {
public: public:
PlacePrinter(std::ostream &os) : os_(os) {} PlacePrinter(std::ostream &os) : os_(os) {}
void operator()(const CpuPlace &) { os_ << "CpuPlace"; } void operator()(const CPUPlace &) { os_ << "CPUPlace"; }
void operator()(const GpuPlace &p) { os_ << "GpuPlace(" << p.device << ")"; } void operator()(const GPUPlace &p) { os_ << "GPUPlace(" << p.device << ")"; }
private: private:
std::ostream &os_; std::ostream &os_;
...@@ -22,14 +22,14 @@ static Place the_default_place; ...@@ -22,14 +22,14 @@ static Place the_default_place;
void set_place(const Place &place) { the_default_place = place; } void set_place(const Place &place) { the_default_place = place; }
const Place &get_place() { return the_default_place; } const Place &get_place() { return the_default_place; }
const GpuPlace default_gpu() { return GpuPlace(0); } const GPUPlace default_gpu() { return GPUPlace(0); }
const CpuPlace default_cpu() { return CpuPlace(); } const CPUPlace default_cpu() { return CPUPlace(); }
bool is_gpu_place(const Place &p) { bool is_gpu_place(const Place &p) {
return boost::apply_visitor(IsGpuPlace(), p); return boost::apply_visitor(IsGPUPlace(), p);
} }
bool is_cpu_place(const Place &p) { bool is_cpu_place(const Place &p) {
return !boost::apply_visitor(IsGpuPlace(), p); return !boost::apply_visitor(IsGPUPlace(), p);
} }
bool places_are_same_class(const Place &p1, const Place &p2) { bool places_are_same_class(const Place &p1, const Place &p2) {
......
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once #pragma once
#include <boost/variant.hpp> #include <boost/variant.hpp>
#include <iostream> #include <iostream>
namespace paddle { namespace paddle {
namespace platform { namespace platform {
struct CpuPlace { struct CPUPlace {
// WORKAROUND: for some reason, omitting this constructor // WORKAROUND: for some reason, omitting this constructor
// causes errors with boost 1.59 and OSX // causes errors with boost 1.59 and OSX
CpuPlace() {} CPUPlace() {}
// needed for variant equality comparison // needed for variant equality comparison
inline bool operator==(const CpuPlace &) const { return true; } inline bool operator==(const CPUPlace &) const { return true; }
inline bool operator!=(const CpuPlace &) const { return false; } inline bool operator!=(const CPUPlace &) const { return false; }
}; };
struct GpuPlace { struct GPUPlace {
GpuPlace() : GpuPlace(0) {} GPUPlace() : GPUPlace(0) {}
GpuPlace(int d) : device(d) {} GPUPlace(int d) : device(d) {}
// needed for variant equality comparison // needed for variant equality comparison
inline bool operator==(const GpuPlace &o) const { return device == o.device; } inline bool operator==(const GPUPlace &o) const { return device == o.device; }
inline bool operator!=(const GpuPlace &o) const { return !(*this == o); } inline bool operator!=(const GPUPlace &o) const { return !(*this == o); }
int device; int device;
}; };
struct IsGpuPlace : public boost::static_visitor<bool> { struct IsGPUPlace : public boost::static_visitor<bool> {
bool operator()(const CpuPlace &) const { return false; } bool operator()(const CPUPlace &) const { return false; }
bool operator()(const GpuPlace &gpu) const { return true; } bool operator()(const GPUPlace &gpu) const { return true; }
}; };
typedef boost::variant<GpuPlace, CpuPlace> Place; typedef boost::variant<GPUPlace, CPUPlace> Place;
void set_place(const Place &); void set_place(const Place &);
const Place &get_place(); const Place &get_place();
const GpuPlace default_gpu(); const GPUPlace default_gpu();
const CpuPlace default_cpu(); const CPUPlace default_cpu();
bool is_gpu_place(const Place &); bool is_gpu_place(const Place &);
bool is_cpu_place(const Place &); bool is_cpu_place(const Place &);
......
...@@ -3,8 +3,8 @@ ...@@ -3,8 +3,8 @@
#include "gtest/gtest.h" #include "gtest/gtest.h"
TEST(Place, Equality) { TEST(Place, Equality) {
paddle::platform::CpuPlace cpu; paddle::platform::CPUPlace cpu;
paddle::platform::GpuPlace g0(0), g1(1), gg0(0); paddle::platform::GPUPlace g0(0), g1(1), gg0(0);
EXPECT_EQ(cpu, cpu); EXPECT_EQ(cpu, cpu);
EXPECT_EQ(g0, g0); EXPECT_EQ(g0, g0);
...@@ -22,19 +22,19 @@ TEST(Place, Default) { ...@@ -22,19 +22,19 @@ TEST(Place, Default) {
EXPECT_TRUE(paddle::platform::is_gpu_place(paddle::platform::default_gpu())); EXPECT_TRUE(paddle::platform::is_gpu_place(paddle::platform::default_gpu()));
EXPECT_TRUE(paddle::platform::is_cpu_place(paddle::platform::default_cpu())); EXPECT_TRUE(paddle::platform::is_cpu_place(paddle::platform::default_cpu()));
paddle::platform::set_place(paddle::platform::CpuPlace()); paddle::platform::set_place(paddle::platform::CPUPlace());
EXPECT_TRUE(paddle::platform::is_cpu_place(paddle::platform::get_place())); EXPECT_TRUE(paddle::platform::is_cpu_place(paddle::platform::get_place()));
} }
TEST(Place, Print) { TEST(Place, Print) {
{ {
std::stringstream ss; std::stringstream ss;
ss << paddle::platform::GpuPlace(1); ss << paddle::platform::GPUPlace(1);
EXPECT_EQ("GpuPlace(1)", ss.str()); EXPECT_EQ("GPUPlace(1)", ss.str());
} }
{ {
std::stringstream ss; std::stringstream ss;
ss << paddle::platform::CpuPlace(); ss << paddle::platform::CPUPlace();
EXPECT_EQ("CpuPlace", ss.str()); EXPECT_EQ("CPUPlace", ss.str());
} }
} }
cc_library(stringpiece SRCS piece.cc)
cc_test(stringpiece_test SRCS piece_test.cc DEPS stringpiece glog gflags)
cc_test(stringprintf_test SRCS printf_test.cc DEPS glog gflags)
...@@ -14,7 +14,7 @@ ...@@ -14,7 +14,7 @@
limitations under the License. limitations under the License.
*/ */
#include "paddle/strings/stringpiece.h" #include "paddle/string/piece.h"
#include <string.h> #include <string.h>
...@@ -23,29 +23,25 @@ ...@@ -23,29 +23,25 @@
#include <stdexcept> #include <stdexcept>
namespace paddle { namespace paddle {
namespace string {
StringPiece::StringPiece() : data_(NULL), size_(0) {} Piece::Piece() : data_(NULL), size_(0) {}
StringPiece::StringPiece(const char* d, size_t n) : data_(d), size_(n) { Piece::Piece(const char* d, size_t n) : data_(d), size_(n) {
if (d == NULL && n != 0) if (d == NULL && n != 0)
throw std::invalid_argument( throw std::invalid_argument("Piece requires len to be 0 for NULL data");
"StringPiece requires len to be 0 for NULL data");
} }
StringPiece::StringPiece(const char* s) : data_(s) { Piece::Piece(const char* s) : data_(s) { size_ = (s == NULL) ? 0 : strlen(s); }
size_ = (s == NULL) ? 0 : strlen(s);
}
StringPiece::StringPiece(const std::string& s) Piece::Piece(const std::string& s) : data_(s.data()), size_(s.size()) {}
: data_(s.data()), size_(s.size()) {}
char StringPiece::operator[](size_t n) const { char Piece::operator[](size_t n) const {
if (n >= len()) if (n >= len()) throw std::invalid_argument("index out of Piece length");
throw std::invalid_argument("index out of StringPiece length");
return data_[n]; return data_[n];
} }
int Compare(StringPiece a, StringPiece b) { int Compare(Piece a, Piece b) {
const size_t min_len = (a.len() < b.len()) ? a.len() : b.len(); const size_t min_len = (a.len() < b.len()) ? a.len() : b.len();
int r = memcmp(a.data(), b.data(), min_len); int r = memcmp(a.data(), b.data(), min_len);
if (r == 0) { if (r == 0) {
...@@ -57,85 +53,86 @@ int Compare(StringPiece a, StringPiece b) { ...@@ -57,85 +53,86 @@ int Compare(StringPiece a, StringPiece b) {
return r; return r;
} }
bool operator==(StringPiece x, StringPiece y) { bool operator==(Piece x, Piece y) {
return ((x.len() == y.len()) && return ((x.len() == y.len()) &&
(x.data() == y.data() || memcmp(x.data(), y.data(), x.len()) == 0)); (x.data() == y.data() || memcmp(x.data(), y.data(), x.len()) == 0));
} }
bool operator!=(StringPiece x, StringPiece y) { return !(x == y); } bool operator!=(Piece x, Piece y) { return !(x == y); }
bool operator<(StringPiece x, StringPiece y) { return Compare(x, y) < 0; } bool operator<(Piece x, Piece y) { return Compare(x, y) < 0; }
bool operator>(StringPiece x, StringPiece y) { return Compare(x, y) > 0; } bool operator>(Piece x, Piece y) { return Compare(x, y) > 0; }
bool operator<=(StringPiece x, StringPiece y) { return Compare(x, y) <= 0; } bool operator<=(Piece x, Piece y) { return Compare(x, y) <= 0; }
bool operator>=(StringPiece x, StringPiece y) { return Compare(x, y) >= 0; } bool operator>=(Piece x, Piece y) { return Compare(x, y) >= 0; }
bool HasPrefix(StringPiece s, StringPiece x) { bool HasPrefix(Piece s, Piece x) {
return ((s.len() >= x.len()) && (memcmp(s.data(), x.data(), x.len()) == 0)); return ((s.len() >= x.len()) && (memcmp(s.data(), x.data(), x.len()) == 0));
} }
bool HasSuffix(StringPiece s, StringPiece x) { bool HasSuffix(Piece s, Piece x) {
return ((s.len() >= x.len()) && return ((s.len() >= x.len()) &&
(memcmp(s.data() + (s.len() - x.len()), x.data(), x.len()) == 0)); (memcmp(s.data() + (s.len() - x.len()), x.data(), x.len()) == 0));
} }
StringPiece SkipPrefix(StringPiece s, size_t n) { Piece SkipPrefix(Piece s, size_t n) {
if (n > s.len()) if (n > s.len())
throw std::invalid_argument("Skip distance larger than StringPiece length"); throw std::invalid_argument("Skip distance larger than Piece length");
return StringPiece(s.data() + n, s.len() - n); return Piece(s.data() + n, s.len() - n);
} }
StringPiece SkipSuffix(StringPiece s, size_t n) { Piece SkipSuffix(Piece s, size_t n) {
if (n > s.len()) if (n > s.len())
throw std::invalid_argument("Skip distance larger than StringPiece length"); throw std::invalid_argument("Skip distance larger than Piece length");
return StringPiece(s.data(), s.len() - n); return Piece(s.data(), s.len() - n);
} }
StringPiece TrimPrefix(StringPiece s, StringPiece x) { Piece TrimPrefix(Piece s, Piece x) {
return HasPrefix(s, x) ? SkipPrefix(s, x.len()) : s; return HasPrefix(s, x) ? SkipPrefix(s, x.len()) : s;
} }
StringPiece TrimSuffix(StringPiece s, StringPiece x) { Piece TrimSuffix(Piece s, Piece x) {
return HasSuffix(s, x) ? SkipSuffix(s, x.len()) : s; return HasSuffix(s, x) ? SkipSuffix(s, x.len()) : s;
} }
bool Contains(StringPiece s, StringPiece sub) { bool Contains(Piece s, Piece sub) {
return std::search(s.begin(), s.end(), sub.begin(), sub.end()) != s.end(); return std::search(s.begin(), s.end(), sub.begin(), sub.end()) != s.end();
} }
size_t Index(StringPiece s, StringPiece sub) { size_t Index(Piece s, Piece sub) {
auto e = std::search(s.begin(), s.end(), sub.begin(), sub.end()); auto e = std::search(s.begin(), s.end(), sub.begin(), sub.end());
return e != s.end() ? e - s.data() : StringPiece::npos; return e != s.end() ? e - s.data() : Piece::npos;
} }
size_t Find(StringPiece s, char c, size_t pos) { size_t Find(Piece s, char c, size_t pos) {
if (pos >= s.len()) { if (pos >= s.len()) {
return StringPiece::npos; return Piece::npos;
} }
const char* result = const char* result =
reinterpret_cast<const char*>(memchr(s.data() + pos, c, s.len() - pos)); reinterpret_cast<const char*>(memchr(s.data() + pos, c, s.len() - pos));
return result != nullptr ? result - s.data() : StringPiece::npos; return result != nullptr ? result - s.data() : Piece::npos;
} }
size_t RFind(StringPiece s, char c, size_t pos) { size_t RFind(Piece s, char c, size_t pos) {
if (s.len() == 0) return StringPiece::npos; if (s.len() == 0) return Piece::npos;
for (const char* p = s.data() + std::min(pos, s.len() - 1); p >= s.data(); for (const char* p = s.data() + std::min(pos, s.len() - 1); p >= s.data();
p--) { p--) {
if (*p == c) { if (*p == c) {
return p - s.data(); return p - s.data();
} }
} }
return StringPiece::npos; return Piece::npos;
} }
StringPiece SubStr(StringPiece s, size_t pos, size_t n) { Piece SubStr(Piece s, size_t pos, size_t n) {
if (pos > s.len()) pos = s.len(); if (pos > s.len()) pos = s.len();
if (n > s.len() - pos) n = s.len() - pos; if (n > s.len() - pos) n = s.len() - pos;
return StringPiece(s.data() + pos, n); return Piece(s.data() + pos, n);
} }
std::ostream& operator<<(std::ostream& o, StringPiece piece) { std::ostream& operator<<(std::ostream& o, Piece piece) {
return o << piece.ToString(); return o << piece.ToString();
} }
} // namespace string
} // namespace paddle } // namespace paddle
...@@ -20,33 +20,34 @@ ...@@ -20,33 +20,34 @@
#include <string> #include <string>
namespace paddle { namespace paddle {
namespace string {
// StringPiece points into a std::string object but doesn't own the // Piece points into a std::string object but doesn't own the
// string. It is for efficient access to strings. Like Go's string // string. It is for efficient access to strings. Like Go's string
// type. Not that StringPiece doesn't mutate the underlying string, // type. Not that Piece doesn't mutate the underlying string,
// so it is thread-safe given that the underlying string doesn't // so it is thread-safe given that the underlying string doesn't
// change. Because StringPiece contains a little data members, and // change. Because Piece contains a little data members, and
// its syntax is simple as it doesn't own/manage the string, it is // its syntax is simple as it doesn't own/manage the string, it is
// cheap to construct StringPieces and pass them around. // cheap to construct Pieces and pass them around.
class StringPiece { class Piece {
public: public:
static const size_t npos = static_cast<size_t>(-1); static const size_t npos = static_cast<size_t>(-1);
// We provide non-explicit singleton constructors so users can // We provide non-explicit singleton constructors so users can
// pass in a "const char*" or a "string" wherever a "StringPiece" // pass in a "const char*" or a "string" wherever a "Piece"
// is expected. These contructors ensure that if data_ is NULL, // is expected. These contructors ensure that if data_ is NULL,
// size_ is 0. // size_ is 0.
StringPiece(); Piece();
StringPiece(const char* d, size_t n); Piece(const char* d, size_t n);
StringPiece(const char* d); Piece(const char* d);
StringPiece(const std::string& s); Piece(const std::string& s);
const char* data() const { return data_; } const char* data() const { return data_; }
size_t len() const { return size_; } size_t len() const { return size_; }
char operator[](size_t n) const; char operator[](size_t n) const;
// StringPiece doesn't own the string, so both iterator and const // Piece doesn't own the string, so both iterator and const
// iterator are const char* indeed. // iterator are const char* indeed.
typedef const char* const_iterator; typedef const char* const_iterator;
typedef const char* iterator; typedef const char* iterator;
...@@ -63,43 +64,44 @@ private: ...@@ -63,43 +64,44 @@ private:
// Intentionally copyable // Intentionally copyable
}; };
int Compare(StringPiece a, StringPiece b); int Compare(Piece a, Piece b);
bool operator==(StringPiece x, StringPiece y); bool operator==(Piece x, Piece y);
bool operator!=(StringPiece x, StringPiece y); bool operator!=(Piece x, Piece y);
bool operator<(StringPiece x, StringPiece y); bool operator<(Piece x, Piece y);
bool operator>(StringPiece x, StringPiece y); bool operator>(Piece x, Piece y);
bool operator<=(StringPiece x, StringPiece y); bool operator<=(Piece x, Piece y);
bool operator>=(StringPiece x, StringPiece y); bool operator>=(Piece x, Piece y);
bool HasPrefix(StringPiece s, StringPiece prefix); bool HasPrefix(Piece s, Piece prefix);
bool HasSuffix(StringPiece s, StringPiece suffix); bool HasSuffix(Piece s, Piece suffix);
StringPiece SkipPrefix(StringPiece s, size_t n); Piece SkipPrefix(Piece s, size_t n);
StringPiece SkipSuffix(StringPiece s, size_t n); Piece SkipSuffix(Piece s, size_t n);
// Skip the prefix (or suffix) if it matches with the string. // Skip the prefix (or suffix) if it matches with the string.
StringPiece TrimPrefix(StringPiece s, StringPiece prefix); Piece TrimPrefix(Piece s, Piece prefix);
StringPiece TrimSuffix(StringPiece s, StringPiece suffix); Piece TrimSuffix(Piece s, Piece suffix);
// Returns if s contains sub. Any s except for empty s contains an // Returns if s contains sub. Any s except for empty s contains an
// empty sub. // empty sub.
bool Contains(StringPiece s, StringPiece sub); bool Contains(Piece s, Piece sub);
// Return the first occurrence of sub in s, or npos. If both s and // Return the first occurrence of sub in s, or npos. If both s and
// sub is empty, it returns npos; otherwise, if only sub is empty, it // sub is empty, it returns npos; otherwise, if only sub is empty, it
// returns 0. // returns 0.
size_t Index(StringPiece s, StringPiece sub); size_t Index(Piece s, Piece sub);
// Return the first occurrence of c in s[pos:end], or npos. // Return the first occurrence of c in s[pos:end], or npos.
size_t Find(StringPiece s, char c, size_t pos); size_t Find(Piece s, char c, size_t pos);
// Search range is [0..pos] inclusive. If pos == npos, search everything. // Search range is [0..pos] inclusive. If pos == npos, search everything.
size_t RFind(StringPiece s, char c, size_t pos); size_t RFind(Piece s, char c, size_t pos);
StringPiece SubStr(StringPiece s, size_t pos, size_t n); Piece SubStr(Piece s, size_t pos, size_t n);
// allow StringPiece to be logged // allow Piece to be logged
std::ostream& operator<<(std::ostream& o, StringPiece piece); std::ostream& operator<<(std::ostream& o, Piece piece);
} // namespace string
} // namespace paddle } // namespace paddle
...@@ -14,7 +14,7 @@ ...@@ -14,7 +14,7 @@
limitations under the License. limitations under the License.
*/ */
#include "paddle/strings/stringpiece.h" #include "paddle/string/piece.h"
#include <sstream> #include <sstream>
...@@ -22,42 +22,44 @@ ...@@ -22,42 +22,44 @@
TEST(StringPiece, Construct) { TEST(StringPiece, Construct) {
{ {
paddle::StringPiece s; paddle::string::Piece s;
EXPECT_EQ(NULL, s.data()); EXPECT_EQ(NULL, s.data());
EXPECT_EQ(0U, s.len()); EXPECT_EQ(0U, s.len());
} }
{ EXPECT_THROW(paddle::StringPiece s(NULL, 10000U), std::invalid_argument); }
{ {
paddle::StringPiece s(NULL); EXPECT_THROW(paddle::string::Piece s(NULL, 10000U), std::invalid_argument);
}
{
paddle::string::Piece s(NULL);
EXPECT_EQ(0U, s.len()); EXPECT_EQ(0U, s.len());
} }
{ {
std::string a; std::string a;
EXPECT_EQ(0U, a.size()); EXPECT_EQ(0U, a.size());
paddle::StringPiece s(a); paddle::string::Piece s(a);
EXPECT_EQ(0U, s.len()); EXPECT_EQ(0U, s.len());
} }
} }
TEST(StringPiece, CopyAndAssign) { TEST(StringPiece, CopyAndAssign) {
paddle::StringPiece empty; paddle::string::Piece empty;
EXPECT_EQ(0U, empty.len()); EXPECT_EQ(0U, empty.len());
paddle::StringPiece a("hello"); paddle::string::Piece a("hello");
paddle::StringPiece b = a; paddle::string::Piece b = a;
EXPECT_EQ(b.len(), strlen("hello")); EXPECT_EQ(b.len(), strlen("hello"));
EXPECT_EQ(a, b); EXPECT_EQ(a, b);
std::string storage("hello"); std::string storage("hello");
paddle::StringPiece c(storage); paddle::string::Piece c(storage);
EXPECT_EQ(a, c); EXPECT_EQ(a, c);
EXPECT_NE(a.data(), c.data()); EXPECT_NE(a.data(), c.data());
} }
TEST(StringPiece, Compare) { TEST(StringPiece, Compare) {
{ {
paddle::StringPiece a("hello"); paddle::string::Piece a("hello");
paddle::StringPiece b("world"); paddle::string::Piece b("world");
EXPECT_TRUE(a != b); EXPECT_TRUE(a != b);
EXPECT_FALSE(a == b); EXPECT_FALSE(a == b);
EXPECT_TRUE(a < b); EXPECT_TRUE(a < b);
...@@ -68,7 +70,7 @@ TEST(StringPiece, Compare) { ...@@ -68,7 +70,7 @@ TEST(StringPiece, Compare) {
EXPECT_GT(Compare(b, a), 0); EXPECT_GT(Compare(b, a), 0);
} }
{ {
paddle::StringPiece a, b; paddle::string::Piece a, b;
EXPECT_TRUE(a == b); EXPECT_TRUE(a == b);
EXPECT_FALSE(a != b); EXPECT_FALSE(a != b);
EXPECT_FALSE(a < b); EXPECT_FALSE(a < b);
...@@ -82,31 +84,31 @@ TEST(StringPiece, Compare) { ...@@ -82,31 +84,31 @@ TEST(StringPiece, Compare) {
TEST(StringPiece, ToString) { TEST(StringPiece, ToString) {
{ {
paddle::StringPiece s; paddle::string::Piece s;
EXPECT_EQ(std::string(""), s.ToString()); EXPECT_EQ(std::string(""), s.ToString());
} }
{ {
paddle::StringPiece s(NULL); paddle::string::Piece s(NULL);
EXPECT_EQ(std::string(""), s.ToString()); EXPECT_EQ(std::string(""), s.ToString());
} }
{ {
paddle::StringPiece s("hello"); paddle::string::Piece s("hello");
EXPECT_EQ(std::string("hello"), s.ToString()); EXPECT_EQ(std::string("hello"), s.ToString());
} }
} }
TEST(StringPiece, HasPrefixSuffix) { TEST(StringPiece, HasPrefixSuffix) {
using paddle::HasPrefix; using paddle::string::HasPrefix;
using paddle::HasSuffix; using paddle::string::HasSuffix;
{ {
paddle::StringPiece s; paddle::string::Piece s;
EXPECT_FALSE(HasPrefix(s, "something")); EXPECT_FALSE(HasPrefix(s, "something"));
EXPECT_TRUE(HasPrefix(s, "")); EXPECT_TRUE(HasPrefix(s, ""));
EXPECT_FALSE(HasSuffix(s, "something")); EXPECT_FALSE(HasSuffix(s, "something"));
EXPECT_TRUE(HasSuffix(s, "")); EXPECT_TRUE(HasSuffix(s, ""));
} }
{ {
paddle::StringPiece s("app"); paddle::string::Piece s("app");
EXPECT_TRUE(HasPrefix(s, "")); EXPECT_TRUE(HasPrefix(s, ""));
EXPECT_TRUE(HasPrefix(s, "a")); EXPECT_TRUE(HasPrefix(s, "a"));
EXPECT_TRUE(HasPrefix(s, "ap")); EXPECT_TRUE(HasPrefix(s, "ap"));
...@@ -120,10 +122,10 @@ TEST(StringPiece, HasPrefixSuffix) { ...@@ -120,10 +122,10 @@ TEST(StringPiece, HasPrefixSuffix) {
} }
TEST(StringPiece, SkipPrefixSuffix) { TEST(StringPiece, SkipPrefixSuffix) {
using paddle::SkipPrefix; using paddle::string::SkipPrefix;
using paddle::SkipSuffix; using paddle::string::SkipSuffix;
{ {
paddle::StringPiece s; paddle::string::Piece s;
EXPECT_EQ("", SkipPrefix(s, 0)); EXPECT_EQ("", SkipPrefix(s, 0));
EXPECT_THROW(SkipPrefix(s, 1), std::invalid_argument); EXPECT_THROW(SkipPrefix(s, 1), std::invalid_argument);
...@@ -131,7 +133,7 @@ TEST(StringPiece, SkipPrefixSuffix) { ...@@ -131,7 +133,7 @@ TEST(StringPiece, SkipPrefixSuffix) {
EXPECT_THROW(SkipSuffix(s, 1), std::invalid_argument); EXPECT_THROW(SkipSuffix(s, 1), std::invalid_argument);
} }
{ {
paddle::StringPiece s("app"); paddle::string::Piece s("app");
EXPECT_EQ("app", SkipPrefix(s, 0)); EXPECT_EQ("app", SkipPrefix(s, 0));
EXPECT_EQ("pp", SkipPrefix(s, 1)); EXPECT_EQ("pp", SkipPrefix(s, 1));
EXPECT_EQ("p", SkipPrefix(s, 2)); EXPECT_EQ("p", SkipPrefix(s, 2));
...@@ -147,10 +149,10 @@ TEST(StringPiece, SkipPrefixSuffix) { ...@@ -147,10 +149,10 @@ TEST(StringPiece, SkipPrefixSuffix) {
} }
TEST(StringPiece, TrimPrefixSuffix) { TEST(StringPiece, TrimPrefixSuffix) {
using paddle::TrimPrefix; using paddle::string::TrimPrefix;
using paddle::TrimSuffix; using paddle::string::TrimSuffix;
{ {
paddle::StringPiece s; paddle::string::Piece s;
EXPECT_EQ("", TrimPrefix(s, "")); EXPECT_EQ("", TrimPrefix(s, ""));
EXPECT_EQ("", TrimPrefix(s, "something")); EXPECT_EQ("", TrimPrefix(s, "something"));
...@@ -158,7 +160,7 @@ TEST(StringPiece, TrimPrefixSuffix) { ...@@ -158,7 +160,7 @@ TEST(StringPiece, TrimPrefixSuffix) {
EXPECT_EQ("", TrimSuffix(s, "something")); EXPECT_EQ("", TrimSuffix(s, "something"));
} }
{ {
paddle::StringPiece s("app"); paddle::string::Piece s("app");
EXPECT_EQ("app", TrimPrefix(s, "")); EXPECT_EQ("app", TrimPrefix(s, ""));
EXPECT_EQ("pp", TrimPrefix(s, "a")); EXPECT_EQ("pp", TrimPrefix(s, "a"));
EXPECT_EQ("p", TrimPrefix(s, "ap")); EXPECT_EQ("p", TrimPrefix(s, "ap"));
...@@ -174,14 +176,14 @@ TEST(StringPiece, TrimPrefixSuffix) { ...@@ -174,14 +176,14 @@ TEST(StringPiece, TrimPrefixSuffix) {
} }
TEST(StringPiece, Contains) { TEST(StringPiece, Contains) {
using paddle::Contains; using paddle::string::Contains;
{ {
paddle::StringPiece s; paddle::string::Piece s;
EXPECT_FALSE(Contains(s, "")); EXPECT_FALSE(Contains(s, ""));
EXPECT_FALSE(Contains(s, "something")); EXPECT_FALSE(Contains(s, "something"));
} }
{ {
paddle::StringPiece s("app"); paddle::string::Piece s("app");
EXPECT_TRUE(Contains(s, "")); EXPECT_TRUE(Contains(s, ""));
EXPECT_TRUE(Contains(s, "a")); EXPECT_TRUE(Contains(s, "a"));
EXPECT_TRUE(Contains(s, "p")); EXPECT_TRUE(Contains(s, "p"));
...@@ -193,15 +195,15 @@ TEST(StringPiece, Contains) { ...@@ -193,15 +195,15 @@ TEST(StringPiece, Contains) {
} }
TEST(StringPiece, Index) { TEST(StringPiece, Index) {
using paddle::Index; using paddle::string::Index;
auto npos = paddle::StringPiece::npos; auto npos = paddle::string::Piece::npos;
{ {
paddle::StringPiece s; paddle::string::Piece s;
EXPECT_EQ(npos, Index(s, "")); EXPECT_EQ(npos, Index(s, ""));
EXPECT_EQ(npos, Index(s, "something")); EXPECT_EQ(npos, Index(s, "something"));
} }
{ {
paddle::StringPiece s("app"); paddle::string::Piece s("app");
EXPECT_EQ(0U, Index(s, "")); EXPECT_EQ(0U, Index(s, ""));
EXPECT_EQ(0U, Index(s, "a")); EXPECT_EQ(0U, Index(s, "a"));
EXPECT_EQ(1U, Index(s, "p")); EXPECT_EQ(1U, Index(s, "p"));
...@@ -213,14 +215,14 @@ TEST(StringPiece, Index) { ...@@ -213,14 +215,14 @@ TEST(StringPiece, Index) {
} }
TEST(StringPiece, Find) { TEST(StringPiece, Find) {
using paddle::Find; using paddle::string::Find;
auto npos = paddle::StringPiece::npos; auto npos = paddle::string::Piece::npos;
{ {
paddle::StringPiece s; paddle::string::Piece s;
EXPECT_EQ(npos, Find(s, 'a', 0U)); EXPECT_EQ(npos, Find(s, 'a', 0U));
} }
{ {
paddle::StringPiece s("app"); paddle::string::Piece s("app");
EXPECT_EQ(0U, Find(s, 'a', 0U)); EXPECT_EQ(0U, Find(s, 'a', 0U));
EXPECT_EQ(1U, Find(s, 'p', 0U)); EXPECT_EQ(1U, Find(s, 'p', 0U));
EXPECT_EQ(1U, Find(s, 'p', 1U)); EXPECT_EQ(1U, Find(s, 'p', 1U));
...@@ -230,14 +232,14 @@ TEST(StringPiece, Find) { ...@@ -230,14 +232,14 @@ TEST(StringPiece, Find) {
} }
TEST(StringPiece, RFind) { TEST(StringPiece, RFind) {
using paddle::RFind; using paddle::string::RFind;
auto npos = paddle::StringPiece::npos; auto npos = paddle::string::Piece::npos;
{ {
paddle::StringPiece s; paddle::string::Piece s;
EXPECT_EQ(npos, RFind(s, 'a', 0U)); EXPECT_EQ(npos, RFind(s, 'a', 0U));
} }
{ {
paddle::StringPiece s("app"); paddle::string::Piece s("app");
EXPECT_EQ(2U, RFind(s, 'p', 2U)); EXPECT_EQ(2U, RFind(s, 'p', 2U));
EXPECT_EQ(0U, RFind(s, 'a', 2U)); EXPECT_EQ(0U, RFind(s, 'a', 2U));
EXPECT_EQ(1U, RFind(s, 'p', 1U)); EXPECT_EQ(1U, RFind(s, 'p', 1U));
...@@ -247,15 +249,15 @@ TEST(StringPiece, RFind) { ...@@ -247,15 +249,15 @@ TEST(StringPiece, RFind) {
} }
TEST(StringPiece, SubStr) { TEST(StringPiece, SubStr) {
using paddle::SubStr; using paddle::string::SubStr;
{ {
paddle::StringPiece s; paddle::string::Piece s;
EXPECT_EQ("", SubStr(s, 0, 0)); EXPECT_EQ("", SubStr(s, 0, 0));
EXPECT_EQ("", SubStr(s, 0, 1)); EXPECT_EQ("", SubStr(s, 0, 1));
EXPECT_EQ("", SubStr(s, 1, 0)); EXPECT_EQ("", SubStr(s, 1, 0));
} }
{ {
paddle::StringPiece s("app"); paddle::string::Piece s("app");
EXPECT_EQ("", SubStr(s, 0, 0)); EXPECT_EQ("", SubStr(s, 0, 0));
EXPECT_EQ("", SubStr(s, 1, 0)); EXPECT_EQ("", SubStr(s, 1, 0));
EXPECT_EQ("", SubStr(s, 2, 0)); EXPECT_EQ("", SubStr(s, 2, 0));
...@@ -279,15 +281,15 @@ TEST(StringPiece, SubStr) { ...@@ -279,15 +281,15 @@ TEST(StringPiece, SubStr) {
} }
TEST(StringPiece, StreamOutput) { TEST(StringPiece, StreamOutput) {
using paddle::StringPiece; using paddle::string::Piece;
std::stringstream o; std::stringstream o;
o << StringPiece(); o << paddle::string::Piece();
EXPECT_EQ("", o.str()); EXPECT_EQ("", o.str());
o << StringPiece("hello"); o << paddle::string::Piece("hello");
EXPECT_EQ("hello", o.str()); EXPECT_EQ("hello", o.str());
o << StringPiece(); o << paddle::string::Piece();
EXPECT_EQ("hello", o.str()); EXPECT_EQ("hello", o.str());
} }
/*
Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
// Compared with std::stringstream, there are primary purpose of
// string::Printf:
//
// 1. Type-safe printing, with why and how explained in
// http://www.drdobbs.com/stringprintf-a-typesafe-printf-family-fo/184401999.
// Implementation includes
//
// https://github.com/c42f/tinyformat
// boost::format
// std::stringstream
//
// std::stringstream is not convenient enough in many cases. For example:
//
// std::cout << std::setprecision(2) << std::fixed << 1.23456 << "\n";
//
// boost::format is the most convenient one. We can have
//
// std::cout << format("%2% %1%") % 36 % 77;
//
// or
//
// format fmter("%2% %1%");
// fmter % 36; fmter % 77;
// std::cout << fmter.c_str();
//
// But the overloading of % might be overkilling and it would be
// more efficient if it can write to std::cout directly.
//
// tinyformat has an interface compatible with the C-printf style,
// and it can writes to a stream or returns a std::string:
//
// std::cout << tfm::printf(
// "%s, %s %d, %.2d:%.2d\n",
// weekday, month, day, hour, min);
//
// or
//
// tfm::format(std::cout,
// "%s, %s %d, %.2d:%.2d\n",
// weekday, month, day, hour, min);
//
// 2. High-performance -- most printed strings are not too long and
// doens't need dynamic memory allocation. Many StringPrintf
// implementations doesn't enforce type-safe, but are
// high-performance, including
//
// https://developers.google.com/optimization/reference/base/stringprintf/
// https://github.com/adobe/chromium/blob/master/base/stringprintf.h
// https://github.com/google/protobuf/blob/master/src/google/protobuf/stubs/stringprintf.h
//
// According to
// https://github.com/c42f/tinyformat#compile-time-and-code-bloat,
// boost::format runs too slow and results in large executable binary
// files. So here we port tinyformat.
#pragma once
#include <iostream>
#include <sstream>
#include "paddle/string/tinyformat/tinyformat.h" // https://github.com/c42f/tinyformat
namespace paddle {
namespace string {
template <typename... Args>
void Fprintf(std::ostream& out, const char* fmt, const Args&... args) {
tinyformat::vformat(out, fmt, tinyformat::makeFormatList(args...));
}
template <typename... Args>
std::string Sprintf(const char* fmt, const Args&... args) {
std::ostringstream oss;
Fprintf(oss, fmt, args...);
return oss.str();
}
template <typename... Args>
void Printf(const char* fmt, const Args&... args) {
Fprintf(std::cout, fmt, args...);
}
} // namespace string
} // namespace paddle
#include "paddle/string/printf.h"
#include <string>
#include "gtest/gtest.h"
TEST(StringPrintf, StringPrintf) {
std::string weekday = "Wednesday";
const char* month = "July";
size_t day = 27;
long hour = 14;
int min = 44;
EXPECT_EQ(std::string("Wednesday, July 27, 14:44"),
paddle::string::Sprintf(
"%s, %s %d, %.2d:%.2d", weekday, month, day, hour, min));
}
此差异已折叠。
cc_library(stringpiece SRCS stringpiece.cc)
cc_test(stringpiece_test SRCS stringpiece_test.cc DEPS stringpiece glog gflags)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册