diff --git a/CMakeLists.txt b/CMakeLists.txt index 9be75f4a7de4c6ecac461f9dcb62ab10e79cee6b..24a7066adc57c510030b0926c81849daa4caa6ca 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -27,6 +27,7 @@ if(NOT CMAKE_CROSSCOMPILING) endif(NOT CMAKE_CROSSCOMPILING) find_package(Git REQUIRED) find_package(Threads REQUIRED) +find_package(Boost QUIET) include(simd) @@ -110,6 +111,7 @@ include_directories("${PROJ_ROOT}") include_directories("${PROJ_ROOT}/paddle/cuda/include") include_directories("${CMAKE_CURRENT_BINARY_DIR}/proto") include_directories("${CMAKE_CURRENT_BINARY_DIR}/go/pserver/cclient") +include_directories(${Boost_INCLUDE_DIRS}) set(EXTERNAL_LIBS ${GFLAGS_LIBRARIES} diff --git a/go/cmd/pserver/pserver.go b/go/cmd/pserver/pserver.go index 6c85b1804bb9c5f3a8bc46bb3f54cc62c56cca70..8a42d4f8af1713e246f9efaf5dc7ba878c3b271e 100644 --- a/go/cmd/pserver/pserver.go +++ b/go/cmd/pserver/pserver.go @@ -30,7 +30,13 @@ func main() { log.SetLevel(level) timeout := time.Second * time.Duration((*etcdTimeout)) - s, err := pserver.NewService(*etcdEndpoint, *numPservers, timeout) + e := pserver.NewEtcdClient(*etcdEndpoint, *numPservers, timeout) + idx, err := e.Register() + if err != nil { + panic(err) + } + + s, err := pserver.NewService(idx) if err != nil { panic(err) } diff --git a/go/master/etcd_client.go b/go/master/etcd_client.go index b7293a759896f113d630d57d14b4b4ac8963f54a..f7b463857735070241611af98030c102d1907356 100644 --- a/go/master/etcd_client.go +++ b/go/master/etcd_client.go @@ -18,8 +18,8 @@ const ( DefaultAddrPath = "/master/addr" ) -// EtcdClient is the etcd client that master uses for fault tolerance -// and service registry. +// EtcdClient is the etcd client that the master uses for fault +// tolerance and service registry. type EtcdClient struct { lockPath string statePath string diff --git a/go/pserver/client.go b/go/pserver/client.go index dda915977282d4880ddcc8c18ef6fd80ede9e01b..6938b9d5ce6f6d73c05bd6e3154777023965c319 100644 --- a/go/pserver/client.go +++ b/go/pserver/client.go @@ -1,6 +1,7 @@ package pserver import ( + "errors" "hash/fnv" "sort" "time" @@ -123,6 +124,9 @@ func (c *Client) FinishInitParams() error { // SendGrads sends gradients to parameter servers for updating // parameters. func (c *Client) SendGrads(grads []Gradient) error { + if len(grads) == 0 { + return errors.New("no gradient received") + } errCh := make(chan error, len(grads)) for _, g := range grads { go func(g Gradient) { diff --git a/go/pserver/client_test.go b/go/pserver/client_test.go index 6ecf1fa08a02ed2ce04fae0903cebd46a7b768a4..5bd16118a7f70b766016abfce55f6bb2adf8cc60 100644 --- a/go/pserver/client_test.go +++ b/go/pserver/client_test.go @@ -7,7 +7,6 @@ import ( "strconv" "strings" "testing" - "time" "github.com/PaddlePaddle/Paddle/go/pserver" ) @@ -31,7 +30,7 @@ func init() { port[i] = p go func(l net.Listener) { - s, err := pserver.NewService("", time.Second*5) + s, err := pserver.NewService(0) if err != nil { panic(err) } diff --git a/go/pserver/etcd_client.go b/go/pserver/etcd_client.go new file mode 100644 index 0000000000000000000000000000000000000000..4d88243edd4aa817ddc263ba316a3f6be9e1e67f --- /dev/null +++ b/go/pserver/etcd_client.go @@ -0,0 +1,181 @@ +package pserver + +import ( + "context" + "errors" + "strconv" + "strings" + "time" + + "github.com/PaddlePaddle/Paddle/go/utils/networkhelper" + "github.com/coreos/etcd/clientv3" + "github.com/coreos/etcd/clientv3/concurrency" + log "github.com/sirupsen/logrus" +) + +// EtcdClient is the etcd client that the pserver uses for fault +// tolerance, service registry and coordination. +type EtcdClient struct { + numPservers int + etcdEndpoints string + etcdClient *clientv3.Client + // etcdTimeout is also used as retry intervals. + etcdTimeout time.Duration + // FIXME: ensure GetExternalIP gets the correct ip for trainers to connect. + externalIP string + // desired number of pservers in the job. + // assume desired will not change during one training job. + desired int +} + +// NewEtcdClient creates an EtcdClient +func NewEtcdClient(endpoints string, numPservers int, timeout time.Duration) *EtcdClient { + return &EtcdClient{ + etcdTimeout: timeout, + numPservers: numPservers, + etcdEndpoints: endpoints, + } +} + +// Register registers the pserver on etcd +// +// Register returns the index of the current pserver. +func (e *EtcdClient) Register() (int, error) { + + var err error + e.externalIP, err = networkhelper.GetExternalIP() + if err != nil { + return 0, err + } + + // initialize connection to etcd. + ep := strings.Split(e.etcdEndpoints, ",") + for { + cli, err := clientv3.New(clientv3.Config{ + Endpoints: ep, + DialTimeout: e.etcdTimeout, + }) + if err != nil { + log.Errorf("connect to etcd error: %v", err) + time.Sleep(e.etcdTimeout) + continue + } + e.etcdClient = cli + log.Debugf("inited client to %s", e.etcdEndpoints) + break + } + // init /ps_desired using transaction, for multiple pservers may want to write + // it at the same time. + for { + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + _, err := e.initDesiredPsercers(ctx, e.numPservers) + cancel() + if err != nil { + log.Warn(err) + time.Sleep(e.etcdTimeout) + continue + } + break + } + // TODO: when implementing extending or reducing pservers, /ps_desired is + // changed, then we need to watch /ps_desired node for events. For now, just + // write once when init and read from it. + // wait and set s.desired init value + for { + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + resp, err := e.etcdClient.Get(ctx, PsDesired) + cancel() + if err != nil { + log.Errorf("getting %s error: %v", PsDesired, err) + time.Sleep(e.etcdTimeout) + continue + } + if len(resp.Kvs) != 0 { + e.desired, err = strconv.Atoi(string(resp.Kvs[0].Value)) + if err != nil { + log.Errorf("value of %s invalid %v\n", PsDesired, err) + time.Sleep(e.etcdTimeout) + // NOTE: wait util ps_desired value change + continue + } + break + } + } + + var pserverIdx int + // try register pserver node on etcd + for { + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + var err error + pserverIdx, err = e.registerPserverEtcd(ctx) + cancel() + if err != nil { + log.Warn(err) + time.Sleep(e.etcdTimeout) + continue + } + break + } + + return pserverIdx, nil +} + +func (e *EtcdClient) initDesiredPsercers(ctx context.Context, numPservers int) (*clientv3.TxnResponse, error) { + return concurrency.NewSTM(e.etcdClient, func(c concurrency.STM) error { + dsStr := c.Get(PsDesired) + if dsStr == "" { + c.Put(PsDesired, strconv.Itoa(numPservers)) + } + return nil + }, concurrency.WithAbortContext(ctx), concurrency.WithIsolation(concurrency.RepeatableReads)) +} + +// registerPserverEtcd registers pserver node on etcd using transaction. +func (e *EtcdClient) registerPserverEtcd(ctx context.Context) (int, error) { + var idx int + _, err := concurrency.NewSTM(e.etcdClient, func(c concurrency.STM) error { + registered := false + for i := 0; i < e.desired; i++ { + psKey := "/ps/" + strconv.Itoa(i) + log.Debugf("checking %s", psKey) + ps := c.Get(psKey) + log.Debugf("got value (%s) for key: %s", ps, psKey) + + if ps == "" { + resp, err := e.etcdClient.Grant(context.TODO(), 5) + if err != nil { + log.Fatal(err) + } + // find the first id and write info + c.Put(psKey, e.externalIP, clientv3.WithLease(resp.ID)) + log.Debugf("set pserver node %s with value %s", psKey, e.externalIP) + ch, kaerr := e.etcdClient.KeepAlive(context.TODO(), resp.ID) + if kaerr != nil { + log.Errorf("keepalive etcd node error: %v", kaerr) + return kaerr + } + + // Eat the keep alive message so etcd + // will not expire the lease. + go func(ch <-chan *clientv3.LeaseKeepAliveResponse) { + ka := <-ch + log.Debugf("keepalive: %d\n", ka.TTL) + }(ch) + log.Debug("register finished") + idx = i + registered = true + break + } + } + if registered == true { + return nil + } + return errors.New("not registerd, may due to already have enough pservers") + }, concurrency.WithAbortContext(ctx), concurrency.WithIsolation(concurrency.RepeatableReads)) + + if err != nil { + return 0, err + } + + return idx, nil +} diff --git a/go/pserver/service.go b/go/pserver/service.go index f966595fdccbf23e23f94a857503ce05815164ef..f386ebea1eb8659a988de2a807303bb6687fa429 100644 --- a/go/pserver/service.go +++ b/go/pserver/service.go @@ -1,18 +1,9 @@ package pserver import ( - "context" "errors" "fmt" - "strconv" - "strings" "sync" - "time" - - "github.com/PaddlePaddle/Paddle/go/utils/networkhelper" - "github.com/coreos/etcd/clientv3" - "github.com/coreos/etcd/clientv3/concurrency" - log "github.com/sirupsen/logrus" ) // ElementType is the type of elements of a Parameter. @@ -55,160 +46,25 @@ type Gradient Parameter // Service is the RPC service for pserver. type Service struct { initialized chan struct{} + idx int mu sync.Mutex opt *optimizer paramMap map[string]Parameter - - etcdEndpoints string - etcdClient *clientv3.Client - // etcdTimeout is also used as retry intervals. - etcdTimeout time.Duration - // desired number of pservers in the job. - // assume desired will not change during one training job. - desired int - // FIXME: ensure GetExternalIP gets the correct ip for trainers to connect. - externalIP string } // NewService creates a new service, will bypass etcd registration if no // endpoints specified. -func NewService(endpoints string, numPservers int, timeout time.Duration) (*Service, error) { - s := &Service{opt: newOptimizer(sgd, 0.005)} +func NewService(idx int) (*Service, error) { + s := &Service{ + idx: idx, + opt: newOptimizer(sgd, 0.005), + } s.paramMap = make(map[string]Parameter) s.initialized = make(chan struct{}) - s.etcdEndpoints = endpoints - s.etcdTimeout = timeout - - var err error - s.externalIP, err = networkhelper.GetExternalIP() - if err != nil { - return nil, err - } - - if endpoints != "" { - // initialize connection to etcd, try - ep := strings.Split(s.etcdEndpoints, ",") - for { - cli, err := clientv3.New(clientv3.Config{ - Endpoints: ep, - DialTimeout: s.etcdTimeout, - }) - if err != nil { - log.Errorf("connect to etcd error: %v", err) - time.Sleep(s.etcdTimeout) - continue - } - s.etcdClient = cli - log.Debugf("inited client to %s", s.etcdEndpoints) - break - } - // init /ps_desired using transaction, for multiple pservers may want to write - // it at the same time. - for { - ctx, cancel := context.WithTimeout(context.Background(), time.Second) - _, err := s.initDesiredPsercers(ctx, numPservers) - cancel() - if err != nil { - log.Warn(err) - time.Sleep(s.etcdTimeout) - continue - } - break - } - // TODO: when implementing extending or reducing pservers, /ps_desired is - // changed, then we need to watch /ps_desired node for events. For now, just - // write once when init and read from it. - // wait and set s.desired init value - for { - ctx, cancel := context.WithTimeout(context.Background(), time.Second) - resp, err := s.etcdClient.Get(ctx, PsDesired) - cancel() - if err != nil { - log.Errorf("getting %s error: %v", PsDesired, err) - time.Sleep(s.etcdTimeout) - continue - } - if len(resp.Kvs) != 0 { - s.desired, err = strconv.Atoi(string(resp.Kvs[0].Value)) - if err != nil { - log.Errorf("value of %s invalid %v\n", PsDesired, err) - time.Sleep(s.etcdTimeout) - // NOTE: wait util ps_desired value change - continue - } - break - } - } - // try register pserver node on etcd - for { - ctx, cancel := context.WithTimeout(context.Background(), time.Second) - _, err := s.registerPserverEtcd(ctx) - cancel() - if err != nil { - log.Warn(err) - time.Sleep(s.etcdTimeout) - continue - } - break - } - } // if endpoints != "" - // Bypass etcd registration if no endpoints specified return s, nil } -func (s *Service) initDesiredPsercers(ctx context.Context, numPservers int) (*clientv3.TxnResponse, error) { - return concurrency.NewSTM(s.etcdClient, func(c concurrency.STM) error { - dsStr := c.Get(PsDesired) - if dsStr == "" { - c.Put(PsDesired, strconv.Itoa(numPservers)) - } - return nil - }, concurrency.WithAbortContext(ctx), concurrency.WithIsolation(concurrency.RepeatableReads)) -} - -// registerPserverEtcd registers pserver node on etcd using transaction. -func (s *Service) registerPserverEtcd(ctx context.Context) (*clientv3.TxnResponse, error) { - return concurrency.NewSTM(s.etcdClient, func(c concurrency.STM) error { - registered := false - for i := 0; i < s.desired; i++ { - psKey := "/ps/" + strconv.Itoa(i) - log.Debugf("checking %s", psKey) - ps := c.Get(psKey) - log.Debugf("got value (%s) for key: %s", ps, psKey) - - if ps == "" { - resp, err := s.etcdClient.Grant(context.TODO(), 5) - if err != nil { - log.Fatal(err) - } - // find the first id and write info - c.Put(psKey, s.externalIP, clientv3.WithLease(resp.ID)) - log.Debugf("set pserver node %s with value %s", psKey, s.externalIP) - ch, kaerr := s.etcdClient.KeepAlive(context.TODO(), resp.ID) - if kaerr != nil { - log.Errorf("keepalive etcd node error: %v", kaerr) - return kaerr - } - - // Eat the keep alive message so etcd - // will not expire the lease. - go func(ch <-chan *clientv3.LeaseKeepAliveResponse) { - ka := <-ch - log.Debugf("keepalive: %d\n", ka.TTL) - }(ch) - log.Debug("register finished") - registered = true - break - } - } - if registered == true { - return nil - } - return errors.New("not registerd, may due to already have enough pservers") - }, concurrency.WithAbortContext(ctx), concurrency.WithIsolation(concurrency.RepeatableReads)) -} - // InitParam initializes a parameter. func (s *Service) InitParam(paramWithConfigs ParameterWithConfig, dummy *int) error { select { diff --git a/go/pserver/service_test.go b/go/pserver/service_test.go index f317535592165b921491120888badd30c6795c12..d9d887cffd462eed48b972466a7d83bae35d9a1c 100644 --- a/go/pserver/service_test.go +++ b/go/pserver/service_test.go @@ -10,7 +10,7 @@ import ( ) func TestFull(t *testing.T) { - s, err := pserver.NewService("", time.Second*5) + s, err := pserver.NewService(0) if err != nil { t.Error(err) } @@ -75,7 +75,7 @@ func TestFull(t *testing.T) { } func TestMultipleInit(t *testing.T) { - s, err := pserver.NewService("", time.Second*5) + s, err := pserver.NewService(0) if err != nil { t.Error(err) } @@ -91,7 +91,7 @@ func TestMultipleInit(t *testing.T) { } func TestUninitialized(t *testing.T) { - s, err := pserver.NewService("", time.Second*5) + s, err := pserver.NewService(0) err = s.SendGrad(pserver.Gradient{}, nil) if err.Error() != pserver.Uninitialized { t.FailNow() @@ -99,7 +99,7 @@ func TestUninitialized(t *testing.T) { } func TestBlockUntilInitialized(t *testing.T) { - s, err := pserver.NewService("", time.Second*5) + s, err := pserver.NewService(0) if err != nil { t.Error(err) } diff --git a/paddle/CMakeLists.txt b/paddle/CMakeLists.txt index 573bd937a351a6f308974e14f3bc92cbe1b541bc..307e99bbe3a833f1fe26057ec38d0b96e04bc0fe 100644 --- a/paddle/CMakeLists.txt +++ b/paddle/CMakeLists.txt @@ -9,17 +9,10 @@ add_subdirectory(pserver) add_subdirectory(trainer) add_subdirectory(scripts) add_subdirectory(optimizer) -add_subdirectory(strings) - -# Do not build go directory until go cmake is working smoothly. -# if(CMAKE_Go_COMPILER) -# add_subdirectory(go) -# endif() - -find_package(Boost QUIET) +add_subdirectory(string) if(Boost_FOUND) - include_directories(${Boost_INCLUDE_DIRS}) + add_subdirectory(memory) add_subdirectory(platform) add_subdirectory(framework) endif() diff --git a/paddle/gserver/gradientmachines/MultiGradientMachine.cpp b/paddle/gserver/gradientmachines/MultiGradientMachine.cpp index 8ef5e9d0c116dd088b5c5c318dfb47c245b471fa..018da6c76dc27a74b074ec52c18347beba8164fc 100644 --- a/paddle/gserver/gradientmachines/MultiGradientMachine.cpp +++ b/paddle/gserver/gradientmachines/MultiGradientMachine.cpp @@ -601,7 +601,7 @@ void TrainerThread::backward() { void TrainerThread::backwardCallback(Parameter* para) { // CPU parameters are merged in the end - if (!para->useGpu()) return; + if (!para->useGpu() || para->isStatic()) return; int paramId = para->getID(); if (multiMachine_->getNumThreads() == 1) { diff --git a/paddle/memory/.clang-format b/paddle/memory/.clang-format new file mode 100644 index 0000000000000000000000000000000000000000..29282dc87e2c499988c17d90d47d44cd5cf7f115 --- /dev/null +++ b/paddle/memory/.clang-format @@ -0,0 +1,5 @@ +--- +Language: Cpp +BasedOnStyle: Google +Standard: Cpp11 +... diff --git a/paddle/memory/CMakeLists.txt b/paddle/memory/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..3943c3cfad31d13a00645aba6fc153d3d13da987 --- /dev/null +++ b/paddle/memory/CMakeLists.txt @@ -0,0 +1 @@ +add_subdirectory(detail) diff --git a/paddle/memory/README.md b/paddle/memory/README.md index e5f7880e4cad346da5399815f5e76b7b9b99bdea..96a331a486f57d3e030408fee182199bad5b38c2 100644 --- a/paddle/memory/README.md +++ b/paddle/memory/README.md @@ -97,6 +97,7 @@ class BuddyAllocator { struct Block { size_t size; Block* left, right; + size_t index; // allocator id }; ... }; diff --git a/paddle/memory/detail/CMakeLists.txt b/paddle/memory/detail/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..72d3749ad789eca9a4b10944131171c0cf8dfe5a --- /dev/null +++ b/paddle/memory/detail/CMakeLists.txt @@ -0,0 +1,7 @@ +if(${WITH_GPU}) + nv_library(system_allocator SRCS system_allocator.cc DEPS gflags) + nv_test(system_allocator_test SRCS system_allocator_test.cc DEPS system_allocator gflags) +else(${WITH_GPU}) + cc_library(system_allocator SRCS system_allocator.cc DEPS gflags) + cc_test(system_allocator_test SRCS system_allocator_test.cc DEPS system_allocator gflags) +endif(${WITH_GPU}) diff --git a/paddle/memory/detail/buddy_allocator.cc b/paddle/memory/detail/buddy_allocator.cc new file mode 100644 index 0000000000000000000000000000000000000000..ebe680f5eea4948339fb8c5584a5b9f5d71c752e --- /dev/null +++ b/paddle/memory/detail/buddy_allocator.cc @@ -0,0 +1,35 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + +#pragma once + +#include "paddle/memory/detail/buddy_allocator.h" + +namespace paddle { +namespace memory { +namespace detail { + +BuddyAllocator::BuddyAllocator(size_t pool_size, size_t max_pools, + SystemAllocator* system_allocator) + : pool_size_(pool_size), + max_pools_(max_pools), + system_allocator_(system_allocator) { + PADDLE_ASSERT(pool_size > 0); + PADDLE_ASSERT(max_pools > 0); + PADDLE_ASSERT(system_allocator != nullptr); +} + +} // namespace detail +} // namespace memory +} // namespace paddle diff --git a/paddle/memory/detail/buddy_allocator.h b/paddle/memory/detail/buddy_allocator.h new file mode 100644 index 0000000000000000000000000000000000000000..82e6aaedc719966b4074449ce1ef7193c73dc265 --- /dev/null +++ b/paddle/memory/detail/buddy_allocator.h @@ -0,0 +1,86 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + +#pragma once + +#include "paddle/memory/detail/system_allocator.h" + +#include +#include + +namespace paddle { +namespace memory { +namespace detail { + +class BuddyAllocator { + public: + BuddyAllocator(size_t pool_size, size_t max_pools, + SystemAllocator* system_allocator); + ~BuddyAllocator(); + + void* Alloc(size_t size); + void Free(void*); + size_t Used(); + + private: + struct Block { + size_t size_; + Block* left_; // left buddy + Block* right_; // right buddy + }; + + // Initially, there is only one pool. If a Alloc founds not enough + // memory from that pool, and there has not been max_num_pools_, + // create a new pool by calling system_allocator_.Alloc(pool_size_). + std::vector pools_; + + size_t pool_size_; // the size of each pool; + size_t max_num_pools_; // the size of all pools; + + SystemAllocator* system_allocator_; + + std::mutex mutex_; + + // Disable copy and assignment. + BuddyAllocator(const BuddyAllocator&) = delete; + BuddyAllocator& operator=(const BuddyAllocator&) = delete; +}; + +BuddyAllocator* GetCPUBuddyAllocator() { + static BuddyAllocator* a = nullptr; + if (a == nullptr) { + a = new BuddyAllocator(); + } + return a; +} + +#ifndef PADDLE_ONLY_CPU // The following code are for CUDA. + +BuddyAllocator* GetGPUBuddyAllocator(int gpu_id) { + static BuddyAllocator** as = NULL; + if (as == NULL) { + int gpu_num = platform::GetDeviceCount(); + as = new BuddyAllocator*[gpu_num]; + for (int gpu = 0; gpu < gpu_num; gpu++) { + as[gpu] = new BuddyAllocator(); + } + } + return as[gpu_id]; +} + +#endif // PADDLE_ONLY_CPU + +} // namespace detail +} // namespace memory +} // namespace paddle diff --git a/paddle/memory/detail/system_allocator.cc b/paddle/memory/detail/system_allocator.cc new file mode 100644 index 0000000000000000000000000000000000000000..50bec926f83dee8a4343d0b16aeb088f9d2a4871 --- /dev/null +++ b/paddle/memory/detail/system_allocator.cc @@ -0,0 +1,90 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/memory/detail/system_allocator.h" + +#include // for malloc and free +#include // for mlock and munlock + +#include "gflags/gflags.h" +#include "paddle/platform/assert.h" +#include "paddle/platform/cuda.h" + +// If use_pinned_memory is true, CPUAllocator calls mlock, which +// returns pinned and locked memory as staging areas for data exchange +// between host and device. Allocates too much would reduce the amount +// of memory available to the system for paging. So, by default, we +// should set false to use_pinned_memory. +DEFINE_bool(use_pinned_memory, false, + "If set, allocate cpu/gpu pinned memory."); + +namespace paddle { +namespace memory { +namespace detail { + +void* CPUAllocator::Alloc(size_t size) { + // According to http://www.cplusplus.com/reference/cstdlib/malloc/, + // malloc might not return nullptr if size is zero, but the returned + // pointer shall not be dereferenced -- so we make it nullptr. + if (size <= 0) return nullptr; + + void* p = malloc(size); + if (p != nullptr && FLAGS_use_pinned_memory) { + mlock(p, size); + } + return p; +} + +void CPUAllocator::Free(void* p, size_t size) { + if (p != nullptr && FLAGS_use_pinned_memory) { + munlock(p, size); + } + free(p); +} + +#ifndef PADDLE_ONLY_CPU + +void* GPUAllocator::Alloc(size_t size) { + // CUDA documentation doesn't explain if cudaMalloc returns nullptr + // if size is 0. We just make sure it does. + if (size <= 0) { + return nullptr; + } + + void* p = 0; + cudaError_t result = + FLAGS_use_pinned_memory ? cudaMallocHost(&p, size) : cudaMalloc(&p, size); + if (result != cudaSuccess) { + cudaGetLastError(); // clear error if there is any. + } + return result == cudaSuccess ? p : nullptr; +} + +void GPUAllocator::Free(void* p, size_t size) { + // Purposefully allow cudaErrorCudartUnloading, because + // that is returned if you ever call cudaFree after the + // driver has already shutdown. This happens only if the + // process is terminating, in which case we don't care if + // cudaFree succeeds. + cudaError_t err = FLAGS_use_pinned_memory ? cudaFreeHost(p) : cudaFree(p); + if (err != cudaErrorCudartUnloading) { + platform::throw_on_error(err, "cudaFree{Host} failed"); + } +} + +#endif // PADDLE_ONLY_CPU + +} // namespace detail +} // namespace memory +} // namespace paddle diff --git a/paddle/memory/detail/system_allocator.h b/paddle/memory/detail/system_allocator.h new file mode 100644 index 0000000000000000000000000000000000000000..184b383f7f78244fa6632a3bffb1a0a78b3aa664 --- /dev/null +++ b/paddle/memory/detail/system_allocator.h @@ -0,0 +1,53 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include // for size_t + +namespace paddle { +namespace memory { +namespace detail { + +// SystemAllocator is the parent class of CPUAllocator and +// GPUAllocator. A BuddyAllocator object uses a SystemAllocator* +// pointing to the underlying system allocator. An alternative to +// this class hierarchy is to pass a system allocator class to +// BuddyAllocator as a template parameter. This approach makes +// BuddyAllocator a class template, and it's very complicated +// algorithm would make the buddy_allocator.h messy. +class SystemAllocator { + public: + virtual ~SystemAllocator() {} + virtual void* Alloc(size_t size) = 0; + virtual void Free(void* p, size_t size) = 0; +}; + +class CPUAllocator : public SystemAllocator { + public: + virtual void* Alloc(size_t size); + virtual void Free(void* p, size_t size); +}; + +#ifndef PADDLE_ONLY_CPU +class GPUAllocator : public SystemAllocator { + public: + virtual void* Alloc(size_t size); + virtual void Free(void* p, size_t size); +}; +#endif // PADDLE_ONLY_CPU + +} // namespace detail +} // namespace memory +} // namespace paddle diff --git a/paddle/memory/detail/system_allocator_test.cc b/paddle/memory/detail/system_allocator_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..9bd5706a4e4d1546a8c879ebbac0f3349c9d59f6 --- /dev/null +++ b/paddle/memory/detail/system_allocator_test.cc @@ -0,0 +1,71 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/memory/detail/system_allocator.h" + +#include +#include + +#include "gflags/gflags.h" +#include "gtest/gtest.h" + +DECLARE_bool(use_pinned_memory); + +void TestAllocator(paddle::memory::detail::SystemAllocator& a, size_t size) { + bool freed = false; + { + void* p = a.Alloc(size); + if (size > 0) { + EXPECT_NE(p, nullptr); + } else { + EXPECT_EQ(p, nullptr); + } + + int* i = static_cast(p); + std::shared_ptr ptr(i, [&](void* p) { + freed = true; + a.Free(p, size); + }); + } + EXPECT_TRUE(freed); +} + +TEST(CPUAllocator, NoLockMem) { + FLAGS_use_pinned_memory = false; + paddle::memory::detail::CPUAllocator a; + TestAllocator(a, 2048); + TestAllocator(a, 0); +} + +TEST(CPUAllocator, LockMem) { + FLAGS_use_pinned_memory = true; + paddle::memory::detail::CPUAllocator a; + TestAllocator(a, 2048); + TestAllocator(a, 0); +} + +#ifndef PADDLE_ONLY_CPU +TEST(GPUAllocator, NoStaging) { + FLAGS_use_pinned_memory = false; + paddle::memory::detail::GPUAllocator a; + TestAllocator(a, 2048); + TestAllocator(a, 0); +} +TEST(GPUAllocator, Staging) { + FLAGS_use_pinned_memory = true; + paddle::memory::detail::GPUAllocator a; + TestAllocator(a, 2048); + TestAllocator(a, 0); +} +#endif // PADDLE_ONLY_CPU diff --git a/paddle/memory/memory.cc b/paddle/memory/memory.cc new file mode 100644 index 0000000000000000000000000000000000000000..0d123d99e234a378ee64850eebacece223e2b121 --- /dev/null +++ b/paddle/memory/memory.cc @@ -0,0 +1,59 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/memory/memory.h" +#include "paddle/memory/detail/buddy_allocator.h" +#include "paddle/memory/detail/system_allocator.h" +#include "paddle/platform/assert.h" + +#include + +namespace paddle { +namespace memory { + +void* Alloc(platform::Place pl, size_t size) { +#ifndef PADDLE_ONLY_CPU + if (paddle::platform::is_gpu_place(pl)) { + size_t gpu_id = boost::get(pl).device; + return detail::GetGPUBuddyAllocator(gpu_id)->Alloc(size); + } +#endif // PADDLE_ONLY_CPU + PADDLE_ASSERT(paddle::platform::is_cpu_place(pl)); + return detail::GetCPUBuddyAllocator()->Alloc(size); +} + +void Free(paddle::platform::Place pl, void* p) { +#ifndef PADDLE_ONLY_CPU + if (paddle::platform::is_gpu_place(pl)) { + size_t gpu_id = boost::get(pl).device; + detail::GetGPUBuddyAllocator(gpu_id)->Free(p); + } +#endif // PADDLE_ONLY_CPU + PADDLE_ASSERT(paddle::platform::is_cpu_place(pl)); + detail::GetCPUBuddyAllocator()->Free(p); +} + +size_t Used(paddle::platform::Place pl) { +#ifndef PADDLE_ONLY_CPU + if (paddle::platform::is_gpu_place(pl)) { + size_t gpu_id = boost::get(pl).device; + return detail::GetGPUBuddyAllocator(gpu_id)->Used(); + } +#endif // PADDLE_ONLY_CPU + PADDLE_ASSERT(paddle::platform::is_cpu_place(pl)); + return detail::GetCPUBuddyAllocator()->Used(); +} + +} // namespace memory +} // namespace paddle diff --git a/paddle/platform/must_check.h b/paddle/memory/memory.h similarity index 53% rename from paddle/platform/must_check.h rename to paddle/memory/memory.h index 4fcc62afc05b14949fc43266f0d05be1f1b7891a..a33092bade65e6df0faee226a8967c9fc9caa032 100644 --- a/paddle/platform/must_check.h +++ b/paddle/memory/memory.h @@ -1,8 +1,11 @@ -/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 + Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. @@ -10,17 +13,15 @@ See the License for the specific language governing permissions and limitations under the License. */ #pragma once -/** - * __must_check macro. It make the function's return value must be used, - * otherwise it will raise a compile warning. And also Paddle treat all compile - * warnings as errors. - */ -#ifdef __GNUC__ -#if (__GNUC__ * 10000 + __GNUC_MINOR__ * 100 + __GNUC_PATCHLEVEL__) >= 30400 -#define __must_check __attribute__((warn_unused_result)) -#else -#define __must_check -#endif -#else -#define __must_check -#endif + +#include "paddle/platform/place.h" + +namespace paddle { +namespace memory { + +void* Alloc(paddle::platform::Place, size_t); +void Free(paddle::platform::Place, void*); +size_t Used(paddle::platform::Place); + +} // namespace memory +} // namespace paddle diff --git a/paddle/platform/CMakeLists.txt b/paddle/platform/CMakeLists.txt index 8435410564b2770b89fa28dbf96c9421335ca889..bc72e62be41e127b80e7597788a73f97bfef0e2f 100644 --- a/paddle/platform/CMakeLists.txt +++ b/paddle/platform/CMakeLists.txt @@ -2,5 +2,4 @@ nv_test(cuda_test SRCS cuda_test.cu) cc_library(place SRCS place.cc) cc_test(place_test SRCS place_test.cc DEPS place glog gflags) -cc_test(must_check_test SRCS must_check_test.cc) cc_test(enforce_test SRCS enforce_test.cc) diff --git a/paddle/platform/cuda.h b/paddle/platform/cuda.h new file mode 100644 index 0000000000000000000000000000000000000000..8fe891f9ce6c3add1df48a8b1f79fd811c7a4362 --- /dev/null +++ b/paddle/platform/cuda.h @@ -0,0 +1,40 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#ifndef PADDLE_ONLY_CPU + +#include +#include + +namespace paddle { +namespace platform { + +inline void throw_on_error(cudaError_t e, const char* message) { + if (e) { + throw thrust::system_error(e, thrust::cuda_category(), message); + } +} + +int GetDeviceCount(void) { + int count; + throw_on_error(cudaGetDeviceCount(&count), "cudaGetDeviceCount failed"); + return count; +} + +} // namespace platform +} // namespace paddle + +#endif // PADDLE_ONLY_CPU diff --git a/paddle/platform/must_check_test.cc b/paddle/platform/must_check_test.cc deleted file mode 100644 index 6ee3ea49acdc4384b5d5df353bfa1290856e982c..0000000000000000000000000000000000000000 --- a/paddle/platform/must_check_test.cc +++ /dev/null @@ -1,10 +0,0 @@ -#include -#include - -int __must_check SomeFunctionMustCheck() { return 0; } - -TEST(MustCheck, all) { - // This line should not be compiled, because the - // return value of SomeFunctionMustCheck marked as __must_check - // SomeFunctionMustCheck(); -} \ No newline at end of file diff --git a/paddle/platform/place.cc b/paddle/platform/place.cc index 1afd03c01169d395b086c1da458ce25c66a12a51..0704820aa05079401eb56814d689d6e280311edb 100644 --- a/paddle/platform/place.cc +++ b/paddle/platform/place.cc @@ -8,8 +8,8 @@ namespace detail { class PlacePrinter : public boost::static_visitor<> { public: PlacePrinter(std::ostream &os) : os_(os) {} - void operator()(const CpuPlace &) { os_ << "CpuPlace"; } - void operator()(const GpuPlace &p) { os_ << "GpuPlace(" << p.device << ")"; } + void operator()(const CPUPlace &) { os_ << "CPUPlace"; } + void operator()(const GPUPlace &p) { os_ << "GPUPlace(" << p.device << ")"; } private: std::ostream &os_; @@ -22,14 +22,14 @@ static Place the_default_place; void set_place(const Place &place) { the_default_place = place; } const Place &get_place() { return the_default_place; } -const GpuPlace default_gpu() { return GpuPlace(0); } -const CpuPlace default_cpu() { return CpuPlace(); } +const GPUPlace default_gpu() { return GPUPlace(0); } +const CPUPlace default_cpu() { return CPUPlace(); } bool is_gpu_place(const Place &p) { - return boost::apply_visitor(IsGpuPlace(), p); + return boost::apply_visitor(IsGPUPlace(), p); } bool is_cpu_place(const Place &p) { - return !boost::apply_visitor(IsGpuPlace(), p); + return !boost::apply_visitor(IsGPUPlace(), p); } bool places_are_same_class(const Place &p1, const Place &p2) { diff --git a/paddle/platform/place.h b/paddle/platform/place.h index 489572c526e162500c8f747f0ec8df10da9d86a2..7cead183884bc9379355cd931921b40d6c11ce90 100644 --- a/paddle/platform/place.h +++ b/paddle/platform/place.h @@ -1,43 +1,58 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + #pragma once + #include #include namespace paddle { namespace platform { -struct CpuPlace { +struct CPUPlace { // WORKAROUND: for some reason, omitting this constructor // causes errors with boost 1.59 and OSX - CpuPlace() {} + CPUPlace() {} // needed for variant equality comparison - inline bool operator==(const CpuPlace &) const { return true; } - inline bool operator!=(const CpuPlace &) const { return false; } + inline bool operator==(const CPUPlace &) const { return true; } + inline bool operator!=(const CPUPlace &) const { return false; } }; -struct GpuPlace { - GpuPlace() : GpuPlace(0) {} - GpuPlace(int d) : device(d) {} +struct GPUPlace { + GPUPlace() : GPUPlace(0) {} + GPUPlace(int d) : device(d) {} // needed for variant equality comparison - inline bool operator==(const GpuPlace &o) const { return device == o.device; } - inline bool operator!=(const GpuPlace &o) const { return !(*this == o); } + inline bool operator==(const GPUPlace &o) const { return device == o.device; } + inline bool operator!=(const GPUPlace &o) const { return !(*this == o); } int device; }; -struct IsGpuPlace : public boost::static_visitor { - bool operator()(const CpuPlace &) const { return false; } - bool operator()(const GpuPlace &gpu) const { return true; } +struct IsGPUPlace : public boost::static_visitor { + bool operator()(const CPUPlace &) const { return false; } + bool operator()(const GPUPlace &gpu) const { return true; } }; -typedef boost::variant Place; +typedef boost::variant Place; void set_place(const Place &); const Place &get_place(); -const GpuPlace default_gpu(); -const CpuPlace default_cpu(); +const GPUPlace default_gpu(); +const CPUPlace default_cpu(); bool is_gpu_place(const Place &); bool is_cpu_place(const Place &); diff --git a/paddle/platform/place_test.cc b/paddle/platform/place_test.cc index 73fccceedf6918148a26100f64cf322305c3ac20..33e2e5a439ce6801c02daba4bcbd462a74d7a614 100644 --- a/paddle/platform/place_test.cc +++ b/paddle/platform/place_test.cc @@ -3,8 +3,8 @@ #include "gtest/gtest.h" TEST(Place, Equality) { - paddle::platform::CpuPlace cpu; - paddle::platform::GpuPlace g0(0), g1(1), gg0(0); + paddle::platform::CPUPlace cpu; + paddle::platform::GPUPlace g0(0), g1(1), gg0(0); EXPECT_EQ(cpu, cpu); EXPECT_EQ(g0, g0); @@ -22,19 +22,19 @@ TEST(Place, Default) { EXPECT_TRUE(paddle::platform::is_gpu_place(paddle::platform::default_gpu())); EXPECT_TRUE(paddle::platform::is_cpu_place(paddle::platform::default_cpu())); - paddle::platform::set_place(paddle::platform::CpuPlace()); + paddle::platform::set_place(paddle::platform::CPUPlace()); EXPECT_TRUE(paddle::platform::is_cpu_place(paddle::platform::get_place())); } TEST(Place, Print) { { std::stringstream ss; - ss << paddle::platform::GpuPlace(1); - EXPECT_EQ("GpuPlace(1)", ss.str()); + ss << paddle::platform::GPUPlace(1); + EXPECT_EQ("GPUPlace(1)", ss.str()); } { std::stringstream ss; - ss << paddle::platform::CpuPlace(); - EXPECT_EQ("CpuPlace", ss.str()); + ss << paddle::platform::CPUPlace(); + EXPECT_EQ("CPUPlace", ss.str()); } } diff --git a/paddle/string/CMakeLists.txt b/paddle/string/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..5becf62672d0c606c98ea1a1a4383df97088ab05 --- /dev/null +++ b/paddle/string/CMakeLists.txt @@ -0,0 +1,4 @@ +cc_library(stringpiece SRCS piece.cc) +cc_test(stringpiece_test SRCS piece_test.cc DEPS stringpiece glog gflags) + +cc_test(stringprintf_test SRCS printf_test.cc DEPS glog gflags) diff --git a/paddle/string/piece.cc b/paddle/string/piece.cc new file mode 100644 index 0000000000000000000000000000000000000000..b80afdec82d642fd3a8245b96ce1bb2bea17cbae --- /dev/null +++ b/paddle/string/piece.cc @@ -0,0 +1,138 @@ +/* + Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +*/ + +#include "paddle/string/piece.h" + +#include + +#include +#include +#include + +namespace paddle { +namespace string { + +Piece::Piece() : data_(NULL), size_(0) {} + +Piece::Piece(const char* d, size_t n) : data_(d), size_(n) { + if (d == NULL && n != 0) + throw std::invalid_argument("Piece requires len to be 0 for NULL data"); +} + +Piece::Piece(const char* s) : data_(s) { size_ = (s == NULL) ? 0 : strlen(s); } + +Piece::Piece(const std::string& s) : data_(s.data()), size_(s.size()) {} + +char Piece::operator[](size_t n) const { + if (n >= len()) throw std::invalid_argument("index out of Piece length"); + return data_[n]; +} + +int Compare(Piece a, Piece b) { + const size_t min_len = (a.len() < b.len()) ? a.len() : b.len(); + int r = memcmp(a.data(), b.data(), min_len); + if (r == 0) { + if (a.len() < b.len()) + return -1; + else if (a.len() > b.len()) + return 1; + } + return r; +} + +bool operator==(Piece x, Piece y) { + return ((x.len() == y.len()) && + (x.data() == y.data() || memcmp(x.data(), y.data(), x.len()) == 0)); +} + +bool operator!=(Piece x, Piece y) { return !(x == y); } + +bool operator<(Piece x, Piece y) { return Compare(x, y) < 0; } +bool operator>(Piece x, Piece y) { return Compare(x, y) > 0; } + +bool operator<=(Piece x, Piece y) { return Compare(x, y) <= 0; } +bool operator>=(Piece x, Piece y) { return Compare(x, y) >= 0; } + +bool HasPrefix(Piece s, Piece x) { + return ((s.len() >= x.len()) && (memcmp(s.data(), x.data(), x.len()) == 0)); +} + +bool HasSuffix(Piece s, Piece x) { + return ((s.len() >= x.len()) && + (memcmp(s.data() + (s.len() - x.len()), x.data(), x.len()) == 0)); +} + +Piece SkipPrefix(Piece s, size_t n) { + if (n > s.len()) + throw std::invalid_argument("Skip distance larger than Piece length"); + return Piece(s.data() + n, s.len() - n); +} + +Piece SkipSuffix(Piece s, size_t n) { + if (n > s.len()) + throw std::invalid_argument("Skip distance larger than Piece length"); + return Piece(s.data(), s.len() - n); +} + +Piece TrimPrefix(Piece s, Piece x) { + return HasPrefix(s, x) ? SkipPrefix(s, x.len()) : s; +} + +Piece TrimSuffix(Piece s, Piece x) { + return HasSuffix(s, x) ? SkipSuffix(s, x.len()) : s; +} + +bool Contains(Piece s, Piece sub) { + return std::search(s.begin(), s.end(), sub.begin(), sub.end()) != s.end(); +} + +size_t Index(Piece s, Piece sub) { + auto e = std::search(s.begin(), s.end(), sub.begin(), sub.end()); + return e != s.end() ? e - s.data() : Piece::npos; +} + +size_t Find(Piece s, char c, size_t pos) { + if (pos >= s.len()) { + return Piece::npos; + } + const char* result = + reinterpret_cast(memchr(s.data() + pos, c, s.len() - pos)); + return result != nullptr ? result - s.data() : Piece::npos; +} + +size_t RFind(Piece s, char c, size_t pos) { + if (s.len() == 0) return Piece::npos; + for (const char* p = s.data() + std::min(pos, s.len() - 1); p >= s.data(); + p--) { + if (*p == c) { + return p - s.data(); + } + } + return Piece::npos; +} + +Piece SubStr(Piece s, size_t pos, size_t n) { + if (pos > s.len()) pos = s.len(); + if (n > s.len() - pos) n = s.len() - pos; + return Piece(s.data() + pos, n); +} + +std::ostream& operator<<(std::ostream& o, Piece piece) { + return o << piece.ToString(); +} + +} // namespace string +} // namespace paddle diff --git a/paddle/strings/stringpiece.h b/paddle/string/piece.h similarity index 57% rename from paddle/strings/stringpiece.h rename to paddle/string/piece.h index adff713e86f49349b8f189c1d24584bfc1bb8aa7..db7c3e69804a6a8f0510ba376432fe560ae74442 100644 --- a/paddle/strings/stringpiece.h +++ b/paddle/string/piece.h @@ -20,33 +20,34 @@ #include namespace paddle { +namespace string { -// StringPiece points into a std::string object but doesn't own the +// Piece points into a std::string object but doesn't own the // string. It is for efficient access to strings. Like Go's string -// type. Not that StringPiece doesn't mutate the underlying string, +// type. Not that Piece doesn't mutate the underlying string, // so it is thread-safe given that the underlying string doesn't -// change. Because StringPiece contains a little data members, and +// change. Because Piece contains a little data members, and // its syntax is simple as it doesn't own/manage the string, it is -// cheap to construct StringPieces and pass them around. -class StringPiece { +// cheap to construct Pieces and pass them around. +class Piece { public: static const size_t npos = static_cast(-1); // We provide non-explicit singleton constructors so users can - // pass in a "const char*" or a "string" wherever a "StringPiece" + // pass in a "const char*" or a "string" wherever a "Piece" // is expected. These contructors ensure that if data_ is NULL, // size_ is 0. - StringPiece(); - StringPiece(const char* d, size_t n); - StringPiece(const char* d); - StringPiece(const std::string& s); + Piece(); + Piece(const char* d, size_t n); + Piece(const char* d); + Piece(const std::string& s); const char* data() const { return data_; } size_t len() const { return size_; } char operator[](size_t n) const; - // StringPiece doesn't own the string, so both iterator and const + // Piece doesn't own the string, so both iterator and const // iterator are const char* indeed. typedef const char* const_iterator; typedef const char* iterator; @@ -63,43 +64,44 @@ private: // Intentionally copyable }; -int Compare(StringPiece a, StringPiece b); +int Compare(Piece a, Piece b); -bool operator==(StringPiece x, StringPiece y); -bool operator!=(StringPiece x, StringPiece y); -bool operator<(StringPiece x, StringPiece y); -bool operator>(StringPiece x, StringPiece y); -bool operator<=(StringPiece x, StringPiece y); -bool operator>=(StringPiece x, StringPiece y); +bool operator==(Piece x, Piece y); +bool operator!=(Piece x, Piece y); +bool operator<(Piece x, Piece y); +bool operator>(Piece x, Piece y); +bool operator<=(Piece x, Piece y); +bool operator>=(Piece x, Piece y); -bool HasPrefix(StringPiece s, StringPiece prefix); -bool HasSuffix(StringPiece s, StringPiece suffix); +bool HasPrefix(Piece s, Piece prefix); +bool HasSuffix(Piece s, Piece suffix); -StringPiece SkipPrefix(StringPiece s, size_t n); -StringPiece SkipSuffix(StringPiece s, size_t n); +Piece SkipPrefix(Piece s, size_t n); +Piece SkipSuffix(Piece s, size_t n); // Skip the prefix (or suffix) if it matches with the string. -StringPiece TrimPrefix(StringPiece s, StringPiece prefix); -StringPiece TrimSuffix(StringPiece s, StringPiece suffix); +Piece TrimPrefix(Piece s, Piece prefix); +Piece TrimSuffix(Piece s, Piece suffix); // Returns if s contains sub. Any s except for empty s contains an // empty sub. -bool Contains(StringPiece s, StringPiece sub); +bool Contains(Piece s, Piece sub); // Return the first occurrence of sub in s, or npos. If both s and // sub is empty, it returns npos; otherwise, if only sub is empty, it // returns 0. -size_t Index(StringPiece s, StringPiece sub); +size_t Index(Piece s, Piece sub); // Return the first occurrence of c in s[pos:end], or npos. -size_t Find(StringPiece s, char c, size_t pos); +size_t Find(Piece s, char c, size_t pos); // Search range is [0..pos] inclusive. If pos == npos, search everything. -size_t RFind(StringPiece s, char c, size_t pos); +size_t RFind(Piece s, char c, size_t pos); -StringPiece SubStr(StringPiece s, size_t pos, size_t n); +Piece SubStr(Piece s, size_t pos, size_t n); -// allow StringPiece to be logged -std::ostream& operator<<(std::ostream& o, StringPiece piece); +// allow Piece to be logged +std::ostream& operator<<(std::ostream& o, Piece piece); +} // namespace string } // namespace paddle diff --git a/paddle/strings/stringpiece_test.cc b/paddle/string/piece_test.cc similarity index 77% rename from paddle/strings/stringpiece_test.cc rename to paddle/string/piece_test.cc index 2ba66a04f641c3457efa713383484491a213668f..cf5152ff5a3cb0a2afae0c90b787abf291122fa3 100644 --- a/paddle/strings/stringpiece_test.cc +++ b/paddle/string/piece_test.cc @@ -14,7 +14,7 @@ limitations under the License. */ -#include "paddle/strings/stringpiece.h" +#include "paddle/string/piece.h" #include @@ -22,42 +22,44 @@ TEST(StringPiece, Construct) { { - paddle::StringPiece s; + paddle::string::Piece s; EXPECT_EQ(NULL, s.data()); EXPECT_EQ(0U, s.len()); } - { EXPECT_THROW(paddle::StringPiece s(NULL, 10000U), std::invalid_argument); } { - paddle::StringPiece s(NULL); + EXPECT_THROW(paddle::string::Piece s(NULL, 10000U), std::invalid_argument); + } + { + paddle::string::Piece s(NULL); EXPECT_EQ(0U, s.len()); } { std::string a; EXPECT_EQ(0U, a.size()); - paddle::StringPiece s(a); + paddle::string::Piece s(a); EXPECT_EQ(0U, s.len()); } } TEST(StringPiece, CopyAndAssign) { - paddle::StringPiece empty; + paddle::string::Piece empty; EXPECT_EQ(0U, empty.len()); - paddle::StringPiece a("hello"); - paddle::StringPiece b = a; + paddle::string::Piece a("hello"); + paddle::string::Piece b = a; EXPECT_EQ(b.len(), strlen("hello")); EXPECT_EQ(a, b); std::string storage("hello"); - paddle::StringPiece c(storage); + paddle::string::Piece c(storage); EXPECT_EQ(a, c); EXPECT_NE(a.data(), c.data()); } TEST(StringPiece, Compare) { { - paddle::StringPiece a("hello"); - paddle::StringPiece b("world"); + paddle::string::Piece a("hello"); + paddle::string::Piece b("world"); EXPECT_TRUE(a != b); EXPECT_FALSE(a == b); EXPECT_TRUE(a < b); @@ -68,7 +70,7 @@ TEST(StringPiece, Compare) { EXPECT_GT(Compare(b, a), 0); } { - paddle::StringPiece a, b; + paddle::string::Piece a, b; EXPECT_TRUE(a == b); EXPECT_FALSE(a != b); EXPECT_FALSE(a < b); @@ -82,31 +84,31 @@ TEST(StringPiece, Compare) { TEST(StringPiece, ToString) { { - paddle::StringPiece s; + paddle::string::Piece s; EXPECT_EQ(std::string(""), s.ToString()); } { - paddle::StringPiece s(NULL); + paddle::string::Piece s(NULL); EXPECT_EQ(std::string(""), s.ToString()); } { - paddle::StringPiece s("hello"); + paddle::string::Piece s("hello"); EXPECT_EQ(std::string("hello"), s.ToString()); } } TEST(StringPiece, HasPrefixSuffix) { - using paddle::HasPrefix; - using paddle::HasSuffix; + using paddle::string::HasPrefix; + using paddle::string::HasSuffix; { - paddle::StringPiece s; + paddle::string::Piece s; EXPECT_FALSE(HasPrefix(s, "something")); EXPECT_TRUE(HasPrefix(s, "")); EXPECT_FALSE(HasSuffix(s, "something")); EXPECT_TRUE(HasSuffix(s, "")); } { - paddle::StringPiece s("app"); + paddle::string::Piece s("app"); EXPECT_TRUE(HasPrefix(s, "")); EXPECT_TRUE(HasPrefix(s, "a")); EXPECT_TRUE(HasPrefix(s, "ap")); @@ -120,10 +122,10 @@ TEST(StringPiece, HasPrefixSuffix) { } TEST(StringPiece, SkipPrefixSuffix) { - using paddle::SkipPrefix; - using paddle::SkipSuffix; + using paddle::string::SkipPrefix; + using paddle::string::SkipSuffix; { - paddle::StringPiece s; + paddle::string::Piece s; EXPECT_EQ("", SkipPrefix(s, 0)); EXPECT_THROW(SkipPrefix(s, 1), std::invalid_argument); @@ -131,7 +133,7 @@ TEST(StringPiece, SkipPrefixSuffix) { EXPECT_THROW(SkipSuffix(s, 1), std::invalid_argument); } { - paddle::StringPiece s("app"); + paddle::string::Piece s("app"); EXPECT_EQ("app", SkipPrefix(s, 0)); EXPECT_EQ("pp", SkipPrefix(s, 1)); EXPECT_EQ("p", SkipPrefix(s, 2)); @@ -147,10 +149,10 @@ TEST(StringPiece, SkipPrefixSuffix) { } TEST(StringPiece, TrimPrefixSuffix) { - using paddle::TrimPrefix; - using paddle::TrimSuffix; + using paddle::string::TrimPrefix; + using paddle::string::TrimSuffix; { - paddle::StringPiece s; + paddle::string::Piece s; EXPECT_EQ("", TrimPrefix(s, "")); EXPECT_EQ("", TrimPrefix(s, "something")); @@ -158,7 +160,7 @@ TEST(StringPiece, TrimPrefixSuffix) { EXPECT_EQ("", TrimSuffix(s, "something")); } { - paddle::StringPiece s("app"); + paddle::string::Piece s("app"); EXPECT_EQ("app", TrimPrefix(s, "")); EXPECT_EQ("pp", TrimPrefix(s, "a")); EXPECT_EQ("p", TrimPrefix(s, "ap")); @@ -174,14 +176,14 @@ TEST(StringPiece, TrimPrefixSuffix) { } TEST(StringPiece, Contains) { - using paddle::Contains; + using paddle::string::Contains; { - paddle::StringPiece s; + paddle::string::Piece s; EXPECT_FALSE(Contains(s, "")); EXPECT_FALSE(Contains(s, "something")); } { - paddle::StringPiece s("app"); + paddle::string::Piece s("app"); EXPECT_TRUE(Contains(s, "")); EXPECT_TRUE(Contains(s, "a")); EXPECT_TRUE(Contains(s, "p")); @@ -193,15 +195,15 @@ TEST(StringPiece, Contains) { } TEST(StringPiece, Index) { - using paddle::Index; - auto npos = paddle::StringPiece::npos; + using paddle::string::Index; + auto npos = paddle::string::Piece::npos; { - paddle::StringPiece s; + paddle::string::Piece s; EXPECT_EQ(npos, Index(s, "")); EXPECT_EQ(npos, Index(s, "something")); } { - paddle::StringPiece s("app"); + paddle::string::Piece s("app"); EXPECT_EQ(0U, Index(s, "")); EXPECT_EQ(0U, Index(s, "a")); EXPECT_EQ(1U, Index(s, "p")); @@ -213,14 +215,14 @@ TEST(StringPiece, Index) { } TEST(StringPiece, Find) { - using paddle::Find; - auto npos = paddle::StringPiece::npos; + using paddle::string::Find; + auto npos = paddle::string::Piece::npos; { - paddle::StringPiece s; + paddle::string::Piece s; EXPECT_EQ(npos, Find(s, 'a', 0U)); } { - paddle::StringPiece s("app"); + paddle::string::Piece s("app"); EXPECT_EQ(0U, Find(s, 'a', 0U)); EXPECT_EQ(1U, Find(s, 'p', 0U)); EXPECT_EQ(1U, Find(s, 'p', 1U)); @@ -230,14 +232,14 @@ TEST(StringPiece, Find) { } TEST(StringPiece, RFind) { - using paddle::RFind; - auto npos = paddle::StringPiece::npos; + using paddle::string::RFind; + auto npos = paddle::string::Piece::npos; { - paddle::StringPiece s; + paddle::string::Piece s; EXPECT_EQ(npos, RFind(s, 'a', 0U)); } { - paddle::StringPiece s("app"); + paddle::string::Piece s("app"); EXPECT_EQ(2U, RFind(s, 'p', 2U)); EXPECT_EQ(0U, RFind(s, 'a', 2U)); EXPECT_EQ(1U, RFind(s, 'p', 1U)); @@ -247,15 +249,15 @@ TEST(StringPiece, RFind) { } TEST(StringPiece, SubStr) { - using paddle::SubStr; + using paddle::string::SubStr; { - paddle::StringPiece s; + paddle::string::Piece s; EXPECT_EQ("", SubStr(s, 0, 0)); EXPECT_EQ("", SubStr(s, 0, 1)); EXPECT_EQ("", SubStr(s, 1, 0)); } { - paddle::StringPiece s("app"); + paddle::string::Piece s("app"); EXPECT_EQ("", SubStr(s, 0, 0)); EXPECT_EQ("", SubStr(s, 1, 0)); EXPECT_EQ("", SubStr(s, 2, 0)); @@ -279,15 +281,15 @@ TEST(StringPiece, SubStr) { } TEST(StringPiece, StreamOutput) { - using paddle::StringPiece; + using paddle::string::Piece; std::stringstream o; - o << StringPiece(); + o << paddle::string::Piece(); EXPECT_EQ("", o.str()); - o << StringPiece("hello"); + o << paddle::string::Piece("hello"); EXPECT_EQ("hello", o.str()); - o << StringPiece(); + o << paddle::string::Piece(); EXPECT_EQ("hello", o.str()); } diff --git a/paddle/string/printf.h b/paddle/string/printf.h new file mode 100644 index 0000000000000000000000000000000000000000..8b5ce63a8e8dfe77962ff1e7415911d381a28aac --- /dev/null +++ b/paddle/string/printf.h @@ -0,0 +1,99 @@ +/* + Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +*/ + +// Compared with std::stringstream, there are primary purpose of +// string::Printf: +// +// 1. Type-safe printing, with why and how explained in +// http://www.drdobbs.com/stringprintf-a-typesafe-printf-family-fo/184401999. +// Implementation includes +// +// https://github.com/c42f/tinyformat +// boost::format +// std::stringstream +// +// std::stringstream is not convenient enough in many cases. For example: +// +// std::cout << std::setprecision(2) << std::fixed << 1.23456 << "\n"; +// +// boost::format is the most convenient one. We can have +// +// std::cout << format("%2% %1%") % 36 % 77; +// +// or +// +// format fmter("%2% %1%"); +// fmter % 36; fmter % 77; +// std::cout << fmter.c_str(); +// +// But the overloading of % might be overkilling and it would be +// more efficient if it can write to std::cout directly. +// +// tinyformat has an interface compatible with the C-printf style, +// and it can writes to a stream or returns a std::string: +// +// std::cout << tfm::printf( +// "%s, %s %d, %.2d:%.2d\n", +// weekday, month, day, hour, min); +// +// or +// +// tfm::format(std::cout, +// "%s, %s %d, %.2d:%.2d\n", +// weekday, month, day, hour, min); +// +// 2. High-performance -- most printed strings are not too long and +// doens't need dynamic memory allocation. Many StringPrintf +// implementations doesn't enforce type-safe, but are +// high-performance, including +// +// https://developers.google.com/optimization/reference/base/stringprintf/ +// https://github.com/adobe/chromium/blob/master/base/stringprintf.h +// https://github.com/google/protobuf/blob/master/src/google/protobuf/stubs/stringprintf.h +// +// According to +// https://github.com/c42f/tinyformat#compile-time-and-code-bloat, +// boost::format runs too slow and results in large executable binary +// files. So here we port tinyformat. + +#pragma once + +#include +#include +#include "paddle/string/tinyformat/tinyformat.h" // https://github.com/c42f/tinyformat + +namespace paddle { +namespace string { + +template +void Fprintf(std::ostream& out, const char* fmt, const Args&... args) { + tinyformat::vformat(out, fmt, tinyformat::makeFormatList(args...)); +} + +template +std::string Sprintf(const char* fmt, const Args&... args) { + std::ostringstream oss; + Fprintf(oss, fmt, args...); + return oss.str(); +} + +template +void Printf(const char* fmt, const Args&... args) { + Fprintf(std::cout, fmt, args...); +} + +} // namespace string +} // namespace paddle diff --git a/paddle/string/printf_test.cc b/paddle/string/printf_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..d8f2454165d741b3937f908dcfd87501940750d5 --- /dev/null +++ b/paddle/string/printf_test.cc @@ -0,0 +1,16 @@ +#include "paddle/string/printf.h" + +#include + +#include "gtest/gtest.h" + +TEST(StringPrintf, StringPrintf) { + std::string weekday = "Wednesday"; + const char* month = "July"; + size_t day = 27; + long hour = 14; + int min = 44; + EXPECT_EQ(std::string("Wednesday, July 27, 14:44"), + paddle::string::Sprintf( + "%s, %s %d, %.2d:%.2d", weekday, month, day, hour, min)); +} diff --git a/paddle/string/tinyformat/tinyformat.h b/paddle/string/tinyformat/tinyformat.h new file mode 100644 index 0000000000000000000000000000000000000000..f0e5e0160fb018b813c1dade727da2861a295147 --- /dev/null +++ b/paddle/string/tinyformat/tinyformat.h @@ -0,0 +1,902 @@ +// tinyformat.h +// Copyright (C) 2011, Chris Foster [chris42f (at) gmail (d0t) com] +// +// Boost Software License - Version 1.0 +// +// Permission is hereby granted, free of charge, to any person or organization +// obtaining a copy of the software and accompanying documentation covered by +// this license (the "Software") to use, reproduce, display, distribute, +// execute, and transmit the Software, and to prepare derivative works of the +// Software, and to permit third-parties to whom the Software is furnished to +// do so, all subject to the following: +// +// The copyright notices in the Software and this entire statement, including +// the above license grant, this restriction and the following disclaimer, +// must be included in all copies of the Software, in whole or in part, and +// all derivative works of the Software, unless such copies or derivative +// works are solely in the form of machine-executable object code generated by +// a source language processor. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE, TITLE AND NON-INFRINGEMENT. IN NO EVENT +// SHALL THE COPYRIGHT HOLDERS OR ANYONE DISTRIBUTING THE SOFTWARE BE LIABLE +// FOR ANY DAMAGES OR OTHER LIABILITY, WHETHER IN CONTRACT, TORT OR OTHERWISE, +// ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +// DEALINGS IN THE SOFTWARE. + +//------------------------------------------------------------------------------ +// Tinyformat: A minimal type safe printf replacement +// +// tinyformat.h is a type safe printf replacement library in a single C++ +// header file. Design goals include: +// +// * Type safety and extensibility for user defined types. +// * C99 printf() compatibility, to the extent possible using std::ostream +// * Simplicity and minimalism. A single header file to include and distribute +// with your projects. +// * Augment rather than replace the standard stream formatting mechanism +// * C++98 support, with optional C++11 niceties +// +// +// Main interface example usage +// ---------------------------- +// +// To print a date to std::cout: +// +// std::string weekday = "Wednesday"; +// const char* month = "July"; +// size_t day = 27; +// long hour = 14; +// int min = 44; +// +// tfm::printf("%s, %s %d, %.2d:%.2d\n", weekday, month, day, hour, min); +// +// The strange types here emphasize the type safety of the interface; it is +// possible to print a std::string using the "%s" conversion, and a +// size_t using the "%d" conversion. A similar result could be achieved +// using either of the tfm::format() functions. One prints on a user provided +// stream: +// +// tfm::format(std::cerr, "%s, %s %d, %.2d:%.2d\n", +// weekday, month, day, hour, min); +// +// The other returns a std::string: +// +// std::string date = tfm::format("%s, %s %d, %.2d:%.2d\n", +// weekday, month, day, hour, min); +// std::cout << date; +// +// These are the three primary interface functions. There is also a +// convenience function printfln() which appends a newline to the usual result +// of printf() for super simple logging. +// +// +// User defined format functions +// ----------------------------- +// +// Simulating variadic templates in C++98 is pretty painful since it requires +// writing out the same function for each desired number of arguments. To make +// this bearable tinyformat comes with a set of macros which are used +// internally to generate the API, but which may also be used in user code. +// +// The three macros TINYFORMAT_ARGTYPES(n), TINYFORMAT_VARARGS(n) and +// TINYFORMAT_PASSARGS(n) will generate a list of n argument types, +// type/name pairs and argument names respectively when called with an integer +// n between 1 and 16. We can use these to define a macro which generates the +// desired user defined function with n arguments. To generate all 16 user +// defined function bodies, use the macro TINYFORMAT_FOREACH_ARGNUM. For an +// example, see the implementation of printf() at the end of the source file. +// +// Sometimes it's useful to be able to pass a list of format arguments through +// to a non-template function. The FormatList class is provided as a way to do +// this by storing the argument list in a type-opaque way. Continuing the +// example from above, we construct a FormatList using makeFormatList(): +// +// FormatListRef formatList = tfm::makeFormatList(weekday, month, day, hour, +// min); +// +// The format list can now be passed into any non-template function and used +// via a call to the vformat() function: +// +// tfm::vformat(std::cout, "%s, %s %d, %.2d:%.2d\n", formatList); +// +// +// Additional API information +// -------------------------- +// +// Error handling: Define TINYFORMAT_ERROR to customize the error handling for +// format strings which are unsupported or have the wrong number of format +// specifiers (calls assert() by default). +// +// User defined types: Uses operator<< for user defined types by default. +// Overload formatValue() for more control. + +#pragma once + +#include +#include +#include +#include + +namespace paddle { +namespace string { +namespace tinyformat { + +#ifndef TINYFORMAT_ERROR +#define TINYFORMAT_ERROR(reason) assert(0 && reason) +#endif + +//------------------------------------------------------------------------------ +namespace detail { + +// Test whether type T1 is convertible to type T2 +template +struct is_convertible { +private: + // two types of different size + struct fail { + char dummy[2]; + }; + struct succeed { + char dummy; + }; + // Try to convert a T1 to a T2 by plugging into tryConvert + static fail tryConvert(...); + static succeed tryConvert(const T2 &); + static const T1 &makeT1(); + +public: + // Standard trick: the (...) version of tryConvert will be chosen from + // the overload set only if the version taking a T2 doesn't match. + // Then we compare the sizes of the return types to check which + // function matched. Very neat, in a disgusting kind of way :) + static const bool value = sizeof(tryConvert(makeT1())) == sizeof(succeed); +}; + +// Format the value by casting to type fmtT. This default implementation +// should never be called. +template ::value> +struct formatValueAsType { + static void invoke(std::ostream & /*out*/, const T & /*value*/) { assert(0); } +}; +// Specialized version for types that can actually be converted to fmtT, as +// indicated by the "convertible" template parameter. +template +struct formatValueAsType { + static void invoke(std::ostream &out, const T &value) { + out << static_cast(value); + } +}; + +// Convert an arbitrary type to integer. The version with convertible=false +// throws an error. +template ::value> +struct convertToInt { + static int invoke(const T & /*value*/) { + TINYFORMAT_ERROR( + "tinyformat: Cannot convert from argument type to " + "integer for use as variable width or precision"); + return 0; + } +}; +// Specialization for convertToInt when conversion is possible +template +struct convertToInt { + static int invoke(const T &value) { return static_cast(value); } +}; + +// Format at most ntrunc characters to the given stream. +template +inline void formatTruncated(std::ostream &out, const T &value, int ntrunc) { + std::ostringstream tmp; + tmp << value; + std::string result = tmp.str(); + out.write(result.c_str(), + (std::min)(ntrunc, static_cast(result.size()))); +} +#define TINYFORMAT_DEFINE_FORMAT_TRUNCATED_CSTR(type) \ + inline void formatTruncated(std::ostream &out, type *value, int ntrunc) { \ + std::streamsize len = 0; \ + while (len < ntrunc && value[len] != 0) ++len; \ + out.write(value, len); \ + } +// Overload for const char* and char*. Could overload for signed & unsigned +// char too, but these are technically unneeded for printf compatibility. +TINYFORMAT_DEFINE_FORMAT_TRUNCATED_CSTR(const char) +TINYFORMAT_DEFINE_FORMAT_TRUNCATED_CSTR(char) +#undef TINYFORMAT_DEFINE_FORMAT_TRUNCATED_CSTR + +} // namespace detail + +//------------------------------------------------------------------------------ +// Variable formatting functions. May be overridden for user-defined types if +// desired. + +/// Format a value into a stream, delegating to operator<< by default. +/// +/// Users may override this for their own types. When this function is called, +/// the stream flags will have been modified according to the format string. +/// The format specification is provided in the range [fmtBegin, fmtEnd). For +/// truncating conversions, ntrunc is set to the desired maximum number of +/// characters, for example "%.7s" calls formatValue with ntrunc = 7. +/// +/// By default, formatValue() uses the usual stream insertion operator +/// operator<< to format the type T, with special cases for the %c and %p +/// conversions. +template +inline void formatValue(std::ostream &out, + const char * /*fmtBegin*/, + const char *fmtEnd, + int ntrunc, + const T &value) { + // The mess here is to support the %c and %p conversions: if these + // conversions are active we try to convert the type to a char or const + // void* respectively and format that instead of the value itself. For the + // %p conversion it's important to avoid dereferencing the pointer, which + // could otherwise lead to a crash when printing a dangling (const char*). + const bool canConvertToChar = detail::is_convertible::value; + const bool canConvertToVoidPtr = + detail::is_convertible::value; + if (canConvertToChar && *(fmtEnd - 1) == 'c') + detail::formatValueAsType::invoke(out, value); + else if (canConvertToVoidPtr && *(fmtEnd - 1) == 'p') + detail::formatValueAsType::invoke(out, value); + else if (ntrunc >= 0) { + // Take care not to overread C strings in truncating conversions like + // "%.4s" where at most 4 characters may be read. + detail::formatTruncated(out, value, ntrunc); + } else + out << value; +} + +// Overloaded version for char types to support printing as an integer +#define TINYFORMAT_DEFINE_FORMATVALUE_CHAR(charType) \ + inline void formatValue(std::ostream &out, \ + const char * /*fmtBegin*/, \ + const char *fmtEnd, \ + int /**/, \ + charType value) { \ + switch (*(fmtEnd - 1)) { \ + case 'u': \ + case 'd': \ + case 'i': \ + case 'o': \ + case 'X': \ + case 'x': \ + out << static_cast(value); \ + break; \ + default: \ + out << value; \ + break; \ + } \ + } +// per 3.9.1: char, signed char and unsigned char are all distinct types +TINYFORMAT_DEFINE_FORMATVALUE_CHAR(char) +TINYFORMAT_DEFINE_FORMATVALUE_CHAR(signed char) +TINYFORMAT_DEFINE_FORMATVALUE_CHAR(unsigned char) +#undef TINYFORMAT_DEFINE_FORMATVALUE_CHAR + +//------------------------------------------------------------------------------ +// Tools for emulating variadic templates in C++98. The basic idea here is +// stolen from the boost preprocessor metaprogramming library and cut down to +// be just general enough for what we need. + +#define TINYFORMAT_ARGTYPES(n) TINYFORMAT_ARGTYPES_##n +#define TINYFORMAT_VARARGS(n) TINYFORMAT_VARARGS_##n +#define TINYFORMAT_PASSARGS(n) TINYFORMAT_PASSARGS_##n +#define TINYFORMAT_PASSARGS_TAIL(n) TINYFORMAT_PASSARGS_TAIL_##n + +// To keep it as transparent as possible, the macros below have been generated +// using python via the excellent cog.py code generation script. This avoids +// the need for a bunch of complex (but more general) preprocessor tricks as +// used in boost.preprocessor. +// +// To rerun the code generation in place, use `cog.py -r tinyformat.h` +// (see http://nedbatchelder.com/code/cog). Alternatively you can just create +// extra versions by hand. + +/*[[[cog +maxParams = 16 + +def makeCommaSepLists(lineTemplate, elemTemplate, startInd=1): + for j in range(startInd,maxParams+1): + list = ', '.join([elemTemplate % {'i':i} for i in range(startInd,j+1)]) + cog.outl(lineTemplate % {'j':j, 'list':list}) + +makeCommaSepLists('#define TINYFORMAT_ARGTYPES_%(j)d %(list)s', + 'class T%(i)d') + +cog.outl() +makeCommaSepLists('#define TINYFORMAT_VARARGS_%(j)d %(list)s', + 'const T%(i)d& v%(i)d') + +cog.outl() +makeCommaSepLists('#define TINYFORMAT_PASSARGS_%(j)d %(list)s', 'v%(i)d') + +cog.outl() +cog.outl('#define TINYFORMAT_PASSARGS_TAIL_1') +makeCommaSepLists('#define TINYFORMAT_PASSARGS_TAIL_%(j)d , %(list)s', + 'v%(i)d', startInd = 2) + +cog.outl() +cog.outl('#define TINYFORMAT_FOREACH_ARGNUM(m) \\\n ' + + ' '.join(['m(%d)' % (j,) for j in range(1,maxParams+1)])) +]]]*/ +#define TINYFORMAT_ARGTYPES_1 class T1 +#define TINYFORMAT_ARGTYPES_2 class T1, class T2 +#define TINYFORMAT_ARGTYPES_3 class T1, class T2, class T3 +#define TINYFORMAT_ARGTYPES_4 class T1, class T2, class T3, class T4 +#define TINYFORMAT_ARGTYPES_5 class T1, class T2, class T3, class T4, class T5 +#define TINYFORMAT_ARGTYPES_6 \ + class T1, class T2, class T3, class T4, class T5, class T6 +#define TINYFORMAT_ARGTYPES_7 \ + class T1, class T2, class T3, class T4, class T5, class T6, class T7 +#define TINYFORMAT_ARGTYPES_8 \ + class T1, class T2, class T3, class T4, class T5, class T6, class T7, class T8 +#define TINYFORMAT_ARGTYPES_9 \ + class T1, class T2, class T3, class T4, class T5, class T6, class T7, \ + class T8, class T9 +#define TINYFORMAT_ARGTYPES_10 \ + class T1, class T2, class T3, class T4, class T5, class T6, class T7, \ + class T8, class T9, class T10 +#define TINYFORMAT_ARGTYPES_11 \ + class T1, class T2, class T3, class T4, class T5, class T6, class T7, \ + class T8, class T9, class T10, class T11 +#define TINYFORMAT_ARGTYPES_12 \ + class T1, class T2, class T3, class T4, class T5, class T6, class T7, \ + class T8, class T9, class T10, class T11, class T12 +#define TINYFORMAT_ARGTYPES_13 \ + class T1, class T2, class T3, class T4, class T5, class T6, class T7, \ + class T8, class T9, class T10, class T11, class T12, class T13 +#define TINYFORMAT_ARGTYPES_14 \ + class T1, class T2, class T3, class T4, class T5, class T6, class T7, \ + class T8, class T9, class T10, class T11, class T12, class T13, \ + class T14 +#define TINYFORMAT_ARGTYPES_15 \ + class T1, class T2, class T3, class T4, class T5, class T6, class T7, \ + class T8, class T9, class T10, class T11, class T12, class T13, \ + class T14, class T15 +#define TINYFORMAT_ARGTYPES_16 \ + class T1, class T2, class T3, class T4, class T5, class T6, class T7, \ + class T8, class T9, class T10, class T11, class T12, class T13, \ + class T14, class T15, class T16 + +#define TINYFORMAT_VARARGS_1 const T1 &v1 +#define TINYFORMAT_VARARGS_2 const T1 &v1, const T2 &v2 +#define TINYFORMAT_VARARGS_3 const T1 &v1, const T2 &v2, const T3 &v3 +#define TINYFORMAT_VARARGS_4 \ + const T1 &v1, const T2 &v2, const T3 &v3, const T4 &v4 +#define TINYFORMAT_VARARGS_5 \ + const T1 &v1, const T2 &v2, const T3 &v3, const T4 &v4, const T5 &v5 +#define TINYFORMAT_VARARGS_6 \ + const T1 &v1, const T2 &v2, const T3 &v3, const T4 &v4, const T5 &v5, \ + const T6 &v6 +#define TINYFORMAT_VARARGS_7 \ + const T1 &v1, const T2 &v2, const T3 &v3, const T4 &v4, const T5 &v5, \ + const T6 &v6, const T7 &v7 +#define TINYFORMAT_VARARGS_8 \ + const T1 &v1, const T2 &v2, const T3 &v3, const T4 &v4, const T5 &v5, \ + const T6 &v6, const T7 &v7, const T8 &v8 +#define TINYFORMAT_VARARGS_9 \ + const T1 &v1, const T2 &v2, const T3 &v3, const T4 &v4, const T5 &v5, \ + const T6 &v6, const T7 &v7, const T8 &v8, const T9 &v9 +#define TINYFORMAT_VARARGS_10 \ + const T1 &v1, const T2 &v2, const T3 &v3, const T4 &v4, const T5 &v5, \ + const T6 &v6, const T7 &v7, const T8 &v8, const T9 &v9, const T10 &v10 +#define TINYFORMAT_VARARGS_11 \ + const T1 &v1, const T2 &v2, const T3 &v3, const T4 &v4, const T5 &v5, \ + const T6 &v6, const T7 &v7, const T8 &v8, const T9 &v9, const T10 &v10, \ + const T11 &v11 +#define TINYFORMAT_VARARGS_12 \ + const T1 &v1, const T2 &v2, const T3 &v3, const T4 &v4, const T5 &v5, \ + const T6 &v6, const T7 &v7, const T8 &v8, const T9 &v9, const T10 &v10, \ + const T11 &v11, const T12 &v12 +#define TINYFORMAT_VARARGS_13 \ + const T1 &v1, const T2 &v2, const T3 &v3, const T4 &v4, const T5 &v5, \ + const T6 &v6, const T7 &v7, const T8 &v8, const T9 &v9, const T10 &v10, \ + const T11 &v11, const T12 &v12, const T13 &v13 +#define TINYFORMAT_VARARGS_14 \ + const T1 &v1, const T2 &v2, const T3 &v3, const T4 &v4, const T5 &v5, \ + const T6 &v6, const T7 &v7, const T8 &v8, const T9 &v9, const T10 &v10, \ + const T11 &v11, const T12 &v12, const T13 &v13, const T14 &v14 +#define TINYFORMAT_VARARGS_15 \ + const T1 &v1, const T2 &v2, const T3 &v3, const T4 &v4, const T5 &v5, \ + const T6 &v6, const T7 &v7, const T8 &v8, const T9 &v9, const T10 &v10, \ + const T11 &v11, const T12 &v12, const T13 &v13, const T14 &v14, \ + const T15 &v15 +#define TINYFORMAT_VARARGS_16 \ + const T1 &v1, const T2 &v2, const T3 &v3, const T4 &v4, const T5 &v5, \ + const T6 &v6, const T7 &v7, const T8 &v8, const T9 &v9, const T10 &v10, \ + const T11 &v11, const T12 &v12, const T13 &v13, const T14 &v14, \ + const T15 &v15, const T16 &v16 + +#define TINYFORMAT_PASSARGS_1 v1 +#define TINYFORMAT_PASSARGS_2 v1, v2 +#define TINYFORMAT_PASSARGS_3 v1, v2, v3 +#define TINYFORMAT_PASSARGS_4 v1, v2, v3, v4 +#define TINYFORMAT_PASSARGS_5 v1, v2, v3, v4, v5 +#define TINYFORMAT_PASSARGS_6 v1, v2, v3, v4, v5, v6 +#define TINYFORMAT_PASSARGS_7 v1, v2, v3, v4, v5, v6, v7 +#define TINYFORMAT_PASSARGS_8 v1, v2, v3, v4, v5, v6, v7, v8 +#define TINYFORMAT_PASSARGS_9 v1, v2, v3, v4, v5, v6, v7, v8, v9 +#define TINYFORMAT_PASSARGS_10 v1, v2, v3, v4, v5, v6, v7, v8, v9, v10 +#define TINYFORMAT_PASSARGS_11 v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11 +#define TINYFORMAT_PASSARGS_12 v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12 +#define TINYFORMAT_PASSARGS_13 \ + v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13 +#define TINYFORMAT_PASSARGS_14 \ + v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14 +#define TINYFORMAT_PASSARGS_15 \ + v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15 +#define TINYFORMAT_PASSARGS_16 \ + v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16 + +#define TINYFORMAT_PASSARGS_TAIL_1 +#define TINYFORMAT_PASSARGS_TAIL_2 , v2 +#define TINYFORMAT_PASSARGS_TAIL_3 , v2, v3 +#define TINYFORMAT_PASSARGS_TAIL_4 , v2, v3, v4 +#define TINYFORMAT_PASSARGS_TAIL_5 , v2, v3, v4, v5 +#define TINYFORMAT_PASSARGS_TAIL_6 , v2, v3, v4, v5, v6 +#define TINYFORMAT_PASSARGS_TAIL_7 , v2, v3, v4, v5, v6, v7 +#define TINYFORMAT_PASSARGS_TAIL_8 , v2, v3, v4, v5, v6, v7, v8 +#define TINYFORMAT_PASSARGS_TAIL_9 , v2, v3, v4, v5, v6, v7, v8, v9 +#define TINYFORMAT_PASSARGS_TAIL_10 , v2, v3, v4, v5, v6, v7, v8, v9, v10 +#define TINYFORMAT_PASSARGS_TAIL_11 , v2, v3, v4, v5, v6, v7, v8, v9, v10, v11 +#define TINYFORMAT_PASSARGS_TAIL_12 \ + , v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12 +#define TINYFORMAT_PASSARGS_TAIL_13 \ + , v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13 +#define TINYFORMAT_PASSARGS_TAIL_14 \ + , v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14 +#define TINYFORMAT_PASSARGS_TAIL_15 \ + , v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15 +#define TINYFORMAT_PASSARGS_TAIL_16 \ + , v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16 + +#define TINYFORMAT_FOREACH_ARGNUM(m) \ + m(1) m(2) m(3) m(4) m(5) m(6) m(7) m(8) m(9) m(10) m(11) m(12) m(13) m(14) \ + m(15) m(16) +//[[[end]]] + +namespace detail { + +// Type-opaque holder for an argument to format(), with associated actions on +// the type held as explicit function pointers. This allows FormatArg's for +// each argument to be allocated as a homogenous array inside FormatList +// whereas a naive implementation based on inheritance does not. +class FormatArg { +public: + FormatArg() {} + + template + FormatArg(const T &value) + : m_value(static_cast(&value)), + m_formatImpl(&formatImpl), + m_toIntImpl(&toIntImpl) {} + + void format(std::ostream &out, + const char *fmtBegin, + const char *fmtEnd, + int ntrunc) const { + m_formatImpl(out, fmtBegin, fmtEnd, ntrunc, m_value); + } + + int toInt() const { return m_toIntImpl(m_value); } + +private: + template + static void formatImpl(std::ostream &out, + const char *fmtBegin, + const char *fmtEnd, + int ntrunc, + const void *value) { + formatValue(out, fmtBegin, fmtEnd, ntrunc, *static_cast(value)); + } + + template + static int toIntImpl(const void *value) { + return convertToInt::invoke(*static_cast(value)); + } + + const void *m_value; + void (*m_formatImpl)(std::ostream &out, + const char *fmtBegin, + const char *fmtEnd, + int ntrunc, + const void *value); + int (*m_toIntImpl)(const void *value); +}; + +// Parse and return an integer from the string c, as atoi() +// On return, c is set to one past the end of the integer. +inline int parseIntAndAdvance(const char *&c) { + int i = 0; + for (; *c >= '0' && *c <= '9'; ++c) i = 10 * i + (*c - '0'); + return i; +} + +// Print literal part of format string and return next format spec +// position. +// +// Skips over any occurrences of '%%', printing a literal '%' to the +// output. The position of the first % character of the next +// nontrivial format spec is returned, or the end of string. +inline const char *printFormatStringLiteral(std::ostream &out, + const char *fmt) { + const char *c = fmt; + for (;; ++c) { + switch (*c) { + case '\0': + out.write(fmt, c - fmt); + return c; + case '%': + out.write(fmt, c - fmt); + if (*(c + 1) != '%') return c; + // for "%%", tack trailing % onto next literal section. + fmt = ++c; + break; + default: + break; + } + } +} + +// Parse a format string and set the stream state accordingly. +// +// The format mini-language recognized here is meant to be the one from C99, +// with the form "%[flags][width][.precision][length]type". +// +// Formatting options which can't be natively represented using the ostream +// state are returned in spacePadPositive (for space padded positive numbers) +// and ntrunc (for truncating conversions). argIndex is incremented if +// necessary to pull out variable width and precision . The function returns a +// pointer to the character after the end of the current format spec. +inline const char *streamStateFromFormat(std::ostream &out, + bool &spacePadPositive, + int &ntrunc, + const char *fmtStart, + const detail::FormatArg *formatters, + int &argIndex, + int numFormatters) { + if (*fmtStart != '%') { + TINYFORMAT_ERROR( + "tinyformat: Not enough conversion specifiers in format string"); + return fmtStart; + } + // Reset stream state to defaults. + out.width(0); + out.precision(6); + out.fill(' '); + // Reset most flags; ignore irrelevant unitbuf & skipws. + out.unsetf(std::ios::adjustfield | std::ios::basefield | + std::ios::floatfield | std::ios::showbase | std::ios::boolalpha | + std::ios::showpoint | std::ios::showpos | std::ios::uppercase); + bool precisionSet = false; + bool widthSet = false; + int widthExtra = 0; + const char *c = fmtStart + 1; + // 1) Parse flags + for (;; ++c) { + switch (*c) { + case '#': + out.setf(std::ios::showpoint | std::ios::showbase); + continue; + case '0': + // overridden by left alignment ('-' flag) + if (!(out.flags() & std::ios::left)) { + // Use internal padding so that numeric values are + // formatted correctly, eg -00010 rather than 000-10 + out.fill('0'); + out.setf(std::ios::internal, std::ios::adjustfield); + } + continue; + case '-': + out.fill(' '); + out.setf(std::ios::left, std::ios::adjustfield); + continue; + case ' ': + // overridden by show positive sign, '+' flag. + if (!(out.flags() & std::ios::showpos)) spacePadPositive = true; + continue; + case '+': + out.setf(std::ios::showpos); + spacePadPositive = false; + widthExtra = 1; + continue; + default: + break; + } + break; + } + // 2) Parse width + if (*c >= '0' && *c <= '9') { + widthSet = true; + out.width(parseIntAndAdvance(c)); + } + if (*c == '*') { + widthSet = true; + int width = 0; + if (argIndex < numFormatters) + width = formatters[argIndex++].toInt(); + else + TINYFORMAT_ERROR( + "tinyformat: Not enough arguments to read variable width"); + if (width < 0) { + // negative widths correspond to '-' flag set + out.fill(' '); + out.setf(std::ios::left, std::ios::adjustfield); + width = -width; + } + out.width(width); + ++c; + } + // 3) Parse precision + if (*c == '.') { + ++c; + int precision = 0; + if (*c == '*') { + ++c; + if (argIndex < numFormatters) + precision = formatters[argIndex++].toInt(); + else + TINYFORMAT_ERROR( + "tinyformat: Not enough arguments to read variable precision"); + } else { + if (*c >= '0' && *c <= '9') + precision = parseIntAndAdvance(c); + else if (*c == '-') // negative precisions ignored, treated as zero. + parseIntAndAdvance(++c); + } + out.precision(precision); + precisionSet = true; + } + // 4) Ignore any C99 length modifier + while (*c == 'l' || *c == 'h' || *c == 'L' || *c == 'j' || *c == 'z' || + *c == 't') + ++c; + // 5) We're up to the conversion specifier character. + // Set stream flags based on conversion specifier (thanks to the + // boost::format class for forging the way here). + bool intConversion = false; + switch (*c) { + case 'u': + case 'd': + case 'i': + out.setf(std::ios::dec, std::ios::basefield); + intConversion = true; + break; + case 'o': + out.setf(std::ios::oct, std::ios::basefield); + intConversion = true; + break; + case 'X': + out.setf(std::ios::uppercase); + case 'x': + case 'p': + out.setf(std::ios::hex, std::ios::basefield); + intConversion = true; + break; + case 'E': + out.setf(std::ios::uppercase); + case 'e': + out.setf(std::ios::scientific, std::ios::floatfield); + out.setf(std::ios::dec, std::ios::basefield); + break; + case 'F': + out.setf(std::ios::uppercase); + case 'f': + out.setf(std::ios::fixed, std::ios::floatfield); + break; + case 'G': + out.setf(std::ios::uppercase); + case 'g': + out.setf(std::ios::dec, std::ios::basefield); + // As in boost::format, let stream decide float format. + out.flags(out.flags() & ~std::ios::floatfield); + break; + case 'a': + case 'A': + TINYFORMAT_ERROR( + "tinyformat: the %a and %A conversion specs " + "are not supported"); + break; + case 'c': + // Handled as special case inside formatValue() + break; + case 's': + if (precisionSet) ntrunc = static_cast(out.precision()); + // Make %s print booleans as "true" and "false" + out.setf(std::ios::boolalpha); + break; + case 'n': + // Not supported - will cause problems! + TINYFORMAT_ERROR("tinyformat: %n conversion spec not supported"); + break; + case '\0': + TINYFORMAT_ERROR( + "tinyformat: Conversion spec incorrectly " + "terminated by end of string"); + return c; + default: + break; + } + if (intConversion && precisionSet && !widthSet) { + // "precision" for integers gives the minimum number of digits (to be + // padded with zeros on the left). This isn't really supported by the + // iostreams, but we can approximately simulate it with the width if + // the width isn't otherwise used. + out.width(out.precision() + widthExtra); + out.setf(std::ios::internal, std::ios::adjustfield); + out.fill('0'); + } + return c + 1; +} + +//------------------------------------------------------------------------------ +inline void formatImpl(std::ostream &out, + const char *fmt, + const detail::FormatArg *formatters, + int numFormatters) { + // Saved stream state + std::streamsize origWidth = out.width(); + std::streamsize origPrecision = out.precision(); + std::ios::fmtflags origFlags = out.flags(); + char origFill = out.fill(); + + for (int argIndex = 0; argIndex < numFormatters; ++argIndex) { + // Parse the format string + fmt = printFormatStringLiteral(out, fmt); + bool spacePadPositive = false; + int ntrunc = -1; + const char *fmtEnd = streamStateFromFormat(out, + spacePadPositive, + ntrunc, + fmt, + formatters, + argIndex, + numFormatters); + if (argIndex >= numFormatters) { + // Check args remain after reading any variable width/precision + TINYFORMAT_ERROR("tinyformat: Not enough format arguments"); + return; + } + const FormatArg &arg = formatters[argIndex]; + // Format the arg into the stream. + if (!spacePadPositive) + arg.format(out, fmt, fmtEnd, ntrunc); + else { + // The following is a special case with no direct correspondence + // between stream formatting and the printf() behaviour. Simulate + // it crudely by formatting into a temporary string stream and + // munging the resulting string. + std::ostringstream tmpStream; + tmpStream.copyfmt(out); + tmpStream.setf(std::ios::showpos); + arg.format(tmpStream, fmt, fmtEnd, ntrunc); + std::string result = tmpStream.str(); // allocates... yuck. + for (size_t i = 0, iend = result.size(); i < iend; ++i) + if (result[i] == '+') result[i] = ' '; + out << result; + } + fmt = fmtEnd; + } + + // Print remaining part of format string. + fmt = printFormatStringLiteral(out, fmt); + if (*fmt != '\0') + TINYFORMAT_ERROR( + "tinyformat: Too many conversion specifiers in format string"); + + // Restore stream state + out.width(origWidth); + out.precision(origPrecision); + out.flags(origFlags); + out.fill(origFill); +} + +} // namespace detail + +/// List of template arguments format(), held in a type-opaque way. +/// +/// A const reference to FormatList (typedef'd as FormatListRef) may be +/// conveniently used to pass arguments to non-template functions: All type +/// information has been stripped from the arguments, leaving just enough of a +/// common interface to perform formatting as required. +class FormatList { +public: + FormatList(detail::FormatArg *formatters, int N) + : m_formatters(formatters), m_N(N) {} + + friend void vformat(std::ostream &out, + const char *fmt, + const FormatList &list); + +private: + const detail::FormatArg *m_formatters; + int m_N; +}; + +/// Reference to type-opaque format list for passing to vformat() +typedef const FormatList &FormatListRef; + +namespace detail { + +// Format list subclass with fixed storage to avoid dynamic allocation +template +class FormatListN : public FormatList { +public: + template + FormatListN(const Args &... args) + : FormatList(&m_formatterStore[0], N), + m_formatterStore{FormatArg(args)...} { + static_assert(sizeof...(args) == N, "Number of args must be N"); + } + +private: + FormatArg m_formatterStore[N]; +}; + +// Special 0-arg version - MSVC says zero-sized C array in struct is nonstandard +template <> +class FormatListN<0> : public FormatList { +public: + FormatListN() : FormatList(0, 0) {} +}; + +} // namespace detail + +//------------------------------------------------------------------------------ +// Primary API functions + +/// Make type-agnostic format list from list of template arguments. +/// +/// The exact return type of this function is an implementation detail and +/// shouldn't be relied upon. Instead it should be stored as a FormatListRef: +/// +/// FormatListRef formatList = makeFormatList( /*...*/ ); +template +detail::FormatListN makeFormatList(const Args &... args) { + return detail::FormatListN(args...); +} + +/// Format list of arguments to the stream according to the given format string. +/// +/// The name vformat() is chosen for the semantic similarity to vprintf(): the +/// list of format arguments is held in a single function argument. +inline void vformat(std::ostream &out, const char *fmt, FormatListRef list) { + detail::formatImpl(out, fmt, list.m_formatters, list.m_N); +} + +/// Format list of arguments to the stream according to given format string. +template +void format(std::ostream &out, const char *fmt, const Args &... args) { + vformat(out, fmt, makeFormatList(args...)); +} + +/// Format list of arguments according to the given format string and return +/// the result as a string. +template +std::string format(const char *fmt, const Args &... args) { + std::ostringstream oss; + format(oss, fmt, args...); + return oss.str(); +} + +/// Format list of arguments to std::cout, according to the given format string +template +void printf(const char *fmt, const Args &... args) { + format(std::cout, fmt, args...); +} + +template +void printfln(const char *fmt, const Args &... args) { + format(std::cout, fmt, args...); + std::cout << '\n'; +} + +} // namespace tinyformat +} // namespace string +} // namespace paddle diff --git a/paddle/strings/CMakeLists.txt b/paddle/strings/CMakeLists.txt deleted file mode 100644 index 4e55eecd484c0e218ecd51bbd19b3eb4f6f92a25..0000000000000000000000000000000000000000 --- a/paddle/strings/CMakeLists.txt +++ /dev/null @@ -1,2 +0,0 @@ -cc_library(stringpiece SRCS stringpiece.cc) -cc_test(stringpiece_test SRCS stringpiece_test.cc DEPS stringpiece glog gflags) diff --git a/paddle/strings/stringpiece.cc b/paddle/strings/stringpiece.cc deleted file mode 100644 index 415b3558d5dfffde26275bcb16ea3922424ca9f3..0000000000000000000000000000000000000000 --- a/paddle/strings/stringpiece.cc +++ /dev/null @@ -1,141 +0,0 @@ -/* - Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. -*/ - -#include "paddle/strings/stringpiece.h" - -#include - -#include -#include -#include - -namespace paddle { - -StringPiece::StringPiece() : data_(NULL), size_(0) {} - -StringPiece::StringPiece(const char* d, size_t n) : data_(d), size_(n) { - if (d == NULL && n != 0) - throw std::invalid_argument( - "StringPiece requires len to be 0 for NULL data"); -} - -StringPiece::StringPiece(const char* s) : data_(s) { - size_ = (s == NULL) ? 0 : strlen(s); -} - -StringPiece::StringPiece(const std::string& s) - : data_(s.data()), size_(s.size()) {} - -char StringPiece::operator[](size_t n) const { - if (n >= len()) - throw std::invalid_argument("index out of StringPiece length"); - return data_[n]; -} - -int Compare(StringPiece a, StringPiece b) { - const size_t min_len = (a.len() < b.len()) ? a.len() : b.len(); - int r = memcmp(a.data(), b.data(), min_len); - if (r == 0) { - if (a.len() < b.len()) - return -1; - else if (a.len() > b.len()) - return 1; - } - return r; -} - -bool operator==(StringPiece x, StringPiece y) { - return ((x.len() == y.len()) && - (x.data() == y.data() || memcmp(x.data(), y.data(), x.len()) == 0)); -} - -bool operator!=(StringPiece x, StringPiece y) { return !(x == y); } - -bool operator<(StringPiece x, StringPiece y) { return Compare(x, y) < 0; } -bool operator>(StringPiece x, StringPiece y) { return Compare(x, y) > 0; } - -bool operator<=(StringPiece x, StringPiece y) { return Compare(x, y) <= 0; } -bool operator>=(StringPiece x, StringPiece y) { return Compare(x, y) >= 0; } - -bool HasPrefix(StringPiece s, StringPiece x) { - return ((s.len() >= x.len()) && (memcmp(s.data(), x.data(), x.len()) == 0)); -} - -bool HasSuffix(StringPiece s, StringPiece x) { - return ((s.len() >= x.len()) && - (memcmp(s.data() + (s.len() - x.len()), x.data(), x.len()) == 0)); -} - -StringPiece SkipPrefix(StringPiece s, size_t n) { - if (n > s.len()) - throw std::invalid_argument("Skip distance larger than StringPiece length"); - return StringPiece(s.data() + n, s.len() - n); -} - -StringPiece SkipSuffix(StringPiece s, size_t n) { - if (n > s.len()) - throw std::invalid_argument("Skip distance larger than StringPiece length"); - return StringPiece(s.data(), s.len() - n); -} - -StringPiece TrimPrefix(StringPiece s, StringPiece x) { - return HasPrefix(s, x) ? SkipPrefix(s, x.len()) : s; -} - -StringPiece TrimSuffix(StringPiece s, StringPiece x) { - return HasSuffix(s, x) ? SkipSuffix(s, x.len()) : s; -} - -bool Contains(StringPiece s, StringPiece sub) { - return std::search(s.begin(), s.end(), sub.begin(), sub.end()) != s.end(); -} - -size_t Index(StringPiece s, StringPiece sub) { - auto e = std::search(s.begin(), s.end(), sub.begin(), sub.end()); - return e != s.end() ? e - s.data() : StringPiece::npos; -} - -size_t Find(StringPiece s, char c, size_t pos) { - if (pos >= s.len()) { - return StringPiece::npos; - } - const char* result = - reinterpret_cast(memchr(s.data() + pos, c, s.len() - pos)); - return result != nullptr ? result - s.data() : StringPiece::npos; -} - -size_t RFind(StringPiece s, char c, size_t pos) { - if (s.len() == 0) return StringPiece::npos; - for (const char* p = s.data() + std::min(pos, s.len() - 1); p >= s.data(); - p--) { - if (*p == c) { - return p - s.data(); - } - } - return StringPiece::npos; -} - -StringPiece SubStr(StringPiece s, size_t pos, size_t n) { - if (pos > s.len()) pos = s.len(); - if (n > s.len() - pos) n = s.len() - pos; - return StringPiece(s.data() + pos, n); -} - -std::ostream& operator<<(std::ostream& o, StringPiece piece) { - return o << piece.ToString(); -} - -} // namespace paddle diff --git a/paddle/utils/Error.h b/paddle/utils/Error.h index f3d535c69c53fa350612459560dd9ac7c279aa72..27ddaab3f003110a2684a871a2de17afb473d660 100644 --- a/paddle/utils/Error.h +++ b/paddle/utils/Error.h @@ -19,7 +19,21 @@ limitations under the License. */ #include #include #include -#include "paddle/platform/must_check.h" + +/** + * __must_check macro. It make the function's return value must be used, + * otherwise it will raise a compile warning. And also Paddle treat all compile + * warnings as errors. + */ +#ifdef __GNUC__ +#if (__GNUC__ * 10000 + __GNUC_MINOR__ * 100 + __GNUC_PATCHLEVEL__) >= 30400 +#define __must_check __attribute__((warn_unused_result)) +#else +#define __must_check +#endif +#else +#define __must_check +#endif namespace paddle {