Commit 49fd49f7 authored by liaogang

Fix conflicts

@@ -93,6 +93,7 @@ include_directories(${CMAKE_CURRENT_BINARY_DIR})
if(NOT APPLE)
find_package(Threads REQUIRED)
link_libraries(${CMAKE_THREAD_LIBS_INIT})
set(CMAKE_CXX_LINK_EXECUTABLE "${CMAKE_CXX_LINK_EXECUTABLE} -ldl")
endif(NOT APPLE)
function(merge_static_libs TARGET_NAME)
@@ -103,6 +104,7 @@ function(merge_static_libs TARGET_NAME)
foreach(lib ${libs})
list(APPEND libs_deps ${${lib}_LIB_DEPENDS})
endforeach()
list(REMOVE_DUPLICATES libs_deps)
if(APPLE) # Use OSX's libtool to merge archives
# To produce a library we need at least one source file.
......
@@ -20,6 +20,8 @@ func main() {
"comma separated endpoint string for pserver to connect to etcd")
etcdTimeout := flag.Int("etcd-timeout", 5, "timeout for etcd calls")
numPservers := flag.Int("num-pservers", 1, "total pserver count in a training job")
checkpointPath := flag.String("checkpoint-path", "/checkpoints/", "save checkpoint path")
checkpointInterval := flag.Int("checkpoint-interval", 600, "save checkpoint per interval seconds")
logLevel := flag.String("log-level", "info",
"log level, possible values: debug, info, warning, error, fatal, panic")
flag.Parse()
@@ -31,18 +33,20 @@ func main() {
log.SetLevel(level)
var idx int
var cp pserver.Checkpoint
var e *pserver.EtcdClient
if *index >= 0 {
idx = *index
} else {
timeout := time.Second * time.Duration((*etcdTimeout))
e = pserver.NewEtcdClient(*etcdEndpoint, *numPservers, timeout)
idx, err = e.Register()
if err != nil {
panic(err)
}
}
s, err := pserver.NewService(idx, *checkpointInterval, *checkpointPath, e, cp)
if err != nil {
panic(err)
}
......
@@ -18,6 +18,8 @@ const (
PsDesired = "/ps_desired"
// PsAddr is the base dir for pserver to store their addr
PsPath = "/ps/"
// PsCheckpoint is the etcd path used to store checkpoint information
PsCheckpoint = "/checkpoints/"
)
// EtcdClient is the etcd client that the pserver uses for fault
@@ -186,3 +188,14 @@ func (e *EtcdClient) registerPserverEtcd(ctx context.Context) (int, error) {
return idx, nil
}
// PutKey puts a value into etcd under the given key, with a per-call timeout in seconds
func (e *EtcdClient) PutKey(key string, value []byte, timeout int) error {
ctx, cancel := context.WithTimeout(context.Background(), time.Second*time.Duration(timeout))
_, err := e.etcdClient.Put(ctx, key, string(value))
cancel()
if err != nil {
return err
}
return nil
}
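As a usage sketch (not part of this commit): a shard could publish its checkpoint metadata through PutKey as in the hypothetical helper below, where the key layout and the 3-second timeout mirror what doCheckpoint does later in this change.

// saveShardMeta is a hypothetical helper, not part of the commit: it stores the
// JSON-encoded checkpoint metadata of shard idx under PsCheckpoint/<idx>, giving
// the etcd request 3 seconds to complete.
func saveShardMeta(e *EtcdClient, idx int, meta []byte) error {
	return e.PutKey(filepath.Join(PsCheckpoint, strconv.Itoa(idx)), meta, 3)
}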
@@ -35,22 +35,28 @@ func cArrayToSlice(p unsafe.Pointer, len int) []byte {
return (*[1 << 30]byte)(p)[:len:len]
}
func newOptimizer(paramWithConfigs ParameterWithConfig, State []byte) *optimizer {
o := &optimizer{}
o.elementType = paramWithConfigs.Param.ElementType
p := paramWithConfigs.Param
c := paramWithConfigs.Config
s := State
log.WithFields(log.Fields{
"ElementType": p.ElementType,
"ParamSize": len(p.Content),
"ConfigSize": len(c),
"StateSize": len(s),
}).Info("New Optimizer Created with config:")
var cbuffer unsafe.Pointer
cbuffer = C.malloc(C.size_t(len(p.Content)))
C.memcpy(cbuffer, unsafe.Pointer(&p.Content[0]), C.size_t(len(p.Content)))
var cstate unsafe.Pointer
if len(s) != 0 {
cstate = unsafe.Pointer(&s[0])
}
o.opt = C.paddle_create_optimizer((*C.uchar)(&c[0]), C.int(len(c)),
C.paddle_element_type(p.ElementType), cbuffer, C.int(len(p.Content)/C.sizeof_float), (*C.char)(cstate), C.int(len(s)))
return o
}
@@ -60,6 +66,12 @@ func (o *optimizer) GetWeights() []byte {
return cArrayToSlice(buffer, int(bufferLen)*C.sizeof_float)
}
func (o *optimizer) GetStates() []byte {
var cbuffer *C.char
cbuffer_len := C.paddle_optimizer_get_state(o.opt, &cbuffer)
return cArrayToSlice(unsafe.Pointer(cbuffer), int(cbuffer_len))
}
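Together, the new State argument of newOptimizer and GetStates form the serialization round trip used by checkpointing: GetStates captures the optimizer state as bytes, and passing those bytes back into newOptimizer restores it. A minimal sketch of that round trip (hypothetical helper, assuming the pserver package types above):

// snapshotAndRestore is an illustrative helper, not part of the commit: it captures
// the serialized optimizer state and rebuilds an equivalent optimizer from the same
// parameter/config pair plus that state, as a checkpoint recovery path would.
func snapshotAndRestore(p ParameterWithConfig) *optimizer {
	o := newOptimizer(p, nil) // fresh optimizer with no prior state
	state := o.GetStates()    // serialized state bytes, as stored in a checkpoint
	o.Cleanup()
	return newOptimizer(p, state) // restored optimizer resumes from that state
}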
func (o *optimizer) UpdateParameter(g Gradient) error {
if o.elementType != g.ElementType {
return fmt.Errorf("Name: %s, parameter and gradient element type not match, parameter: %v, gradient: %v", g.Name, o.elementType, g.ElementType)
......
@@ -19,6 +19,6 @@ func TestOptimizerCreateRelease(t *testing.T) {
Param: p,
Config: config,
}
o := newOptimizer(param, nil)
o.Cleanup()
}
package pserver
import (
"bufio"
"bytes"
"crypto/md5"
"encoding/gob"
"encoding/hex"
"encoding/json"
"errors"
"fmt"
"os"
"path/filepath"
"strconv"
"sync"
"time"
log "github.com/sirupsen/logrus"
)
// ElementType is the type of elements of a Parameter.
@@ -39,26 +51,55 @@ type ParameterWithConfig struct {
Config []byte // parameter configuration in Proto Buffer format
}
// ParameterCheckpoint is a parameter together with its configuration and optimizer state, as stored in a checkpoint
type ParameterCheckpoint struct {
ParamConfig ParameterWithConfig
State []byte
}
// checkpointMeta is the metadata (signature) describing a saved checkpoint
type checkpointMeta struct {
UUID string `json:"uuid"`
Md5sum string `json:"md5sum"`
Timestamp string `json:"timestamp"`
}
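The Md5sum field is a content checksum over the serialized shard, so a loader can verify a checkpoint file before trusting it. A hypothetical helper showing how such a metadata record could be filled in (names are illustrative, not part of the commit):

// buildMeta sketches how a checkpointMeta entry could be constructed for a
// serialized shard: the md5sum covers the gob-encoded bytes written to disk.
func buildMeta(uuid string, data []byte) checkpointMeta {
	sum := md5.Sum(data) // checksum of the serialized checkpoint contents
	return checkpointMeta{
		UUID:      uuid,
		Md5sum:    hex.EncodeToString(sum[:]),
		Timestamp: time.Now().String(),
	}
}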
// Checkpoint is the pserver shard state persisted to file
type Checkpoint []ParameterCheckpoint
// Gradient is the gradient of the parameter.
type Gradient Parameter
// Service is the RPC service for pserver.
type Service struct {
initialized chan struct{}
idx int
checkpointInterval time.Duration
checkpointPath string
client *EtcdClient
mu sync.Mutex
optMap map[string]*optimizer
}
// NewService creates a new service; it will bypass etcd registration if no
// endpoints are specified.
func NewService(idx int, seconds int, path string, client *EtcdClient, cp Checkpoint) (*Service, error) {
s := &Service{
idx: idx,
checkpointInterval: time.Second * time.Duration(seconds),
checkpointPath: path,
client: client,
}
s.optMap = make(map[string]*optimizer)
s.initialized = make(chan struct{})
if cp != nil {
for _, item := range cp {
p := item.ParamConfig
st := item.State
s.optMap[p.Param.Name] = newOptimizer(p, st)
}
}
return s, nil
}
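NewService only consumes an already-loaded Checkpoint (and main.go above passes an empty one); the code that reads a checkpoint back from disk is not part of this diff. A hypothetical loader, shown here only to make the recovery direction concrete, would be roughly the mirror image of doCheckpoint below:

// loadCheckpoint is an illustrative sketch, not part of the commit: it decodes a
// gob-encoded []ParameterCheckpoint, as written by doCheckpoint, from the given file.
func loadCheckpoint(path string) (Checkpoint, error) {
	f, err := os.Open(path)
	if err != nil {
		return nil, err
	}
	defer f.Close()
	var cp Checkpoint
	if err := gob.NewDecoder(f).Decode(&cp); err != nil {
		return nil, err
	}
	return cp, nil
}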
@@ -78,7 +119,7 @@ func (s *Service) InitParam(paramWithConfigs ParameterWithConfig, dummy *int) er
// TODO(helin): check if paramWithConfigs.Param.Content is
// properly memory aligned, if not, make copy to a memory
// aligned region.
s.optMap[paramWithConfigs.Param.Name] = newOptimizer(paramWithConfigs, nil)
return nil
}
@@ -139,10 +180,57 @@ func (s *Service) GetParam(name string, parameter *Parameter) error {
return nil
}
// doCheckpoint saves the current parameters and optimizer states as a checkpoint.
func (s *Service) doCheckpoint() error {
<-s.initialized
s.mu.Lock()
defer s.mu.Unlock()
cp := make([]ParameterCheckpoint, len(s.optMap))
index := 0
for name, opt := range s.optMap {
var pc ParameterCheckpoint
pc.ParamConfig.Param.Name = name
pc.ParamConfig.Param.ElementType = opt.elementType
pc.ParamConfig.Param.Content = opt.GetWeights()
pc.State = opt.GetStates()
cp[index] = pc
index++
}
var buf bytes.Buffer
encoder := gob.NewEncoder(&buf)
err := encoder.Encode(cp)
if err != nil {
return err
}
cpMeta := checkpointMeta{}
cpMeta.UUID = s.checkpointPath + strconv.Itoa(s.idx)
cpMeta.Timestamp = time.Now().String()
checksum := md5.Sum(buf.Bytes())
cpMeta.Md5sum = hex.EncodeToString(checksum[:])
cpMetajson, _ := json.Marshal(cpMeta)
err = s.client.PutKey(filepath.Join(PsCheckpoint, strconv.Itoa(s.idx)), cpMetajson, 3)
if err != nil {
return err
}
if _, err = os.Stat(cpMeta.UUID); os.IsNotExist(err) {
log.Info("checkpoint does not exist.")
} else {
err = os.Remove(cpMeta.UUID)
log.Infof("checkpoint %s already exists, removing", cpMeta.UUID)
}
f, err := os.Create(cpMeta.UUID)
if err != nil {
return err
}
defer f.Close()
writer := bufio.NewWriter(f)
if _, err = writer.Write(buf.Bytes()); err != nil {
return err
}
if err = writer.Flush(); err != nil {
return err
}
return nil
}
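The checkpointInterval stored in Service is not wired up to anything in this excerpt; presumably a periodic trigger calls doCheckpoint. A minimal sketch of what such a loop could look like (hypothetical method, not part of the commit):

// StartCheckpointLoop illustrates how checkpointInterval might drive periodic
// checkpoints; the actual trigger is not shown in this diff.
func (s *Service) StartCheckpointLoop(done <-chan struct{}) {
	ticker := time.NewTicker(s.checkpointInterval)
	defer ticker.Stop()
	for {
		select {
		case <-ticker.C:
			if err := s.doCheckpoint(); err != nil {
				log.Errorln(err)
			}
		case <-done:
			return
		}
	}
}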
@@ -15,7 +15,8 @@ const (
)
func TestServiceFull(t *testing.T) {
var cp pserver.Checkpoint
s, err := pserver.NewService(0, 1, "", nil, cp)
if err != nil {
t.Error(err)
}
@@ -86,7 +87,8 @@ func TestServiceFull(t *testing.T) {
}
func TestMultipleInit(t *testing.T) {
var cp pserver.Checkpoint
s, err := pserver.NewService(0, 1, "", nil, cp)
if err != nil {
t.Error(err)
}
@@ -102,7 +104,8 @@ func TestMultipleInit(t *testing.T) {
}
func TestUninitialized(t *testing.T) {
var cp pserver.Checkpoint
s, err := pserver.NewService(0, 1, "", nil, cp)
err = s.SendGrad(pserver.Gradient{}, nil)
if err.Error() != pserver.Uninitialized {
t.FailNow()
@@ -110,7 +113,8 @@ func TestUninitialized(t *testing.T) {
}
func TestBlockUntilInitialized(t *testing.T) {
var cp pserver.Checkpoint
s, err := pserver.NewService(0, 1, "", nil, cp)
if err != nil {
t.Error(err)
}
@@ -128,16 +132,6 @@ func TestBlockUntilInitialized(t *testing.T) {
ch <- struct{}{}
}()
wg.Add(1)
go func() {
err := s.Save("", nil)
if err != nil {
errCh <- err
}
wg.Done()
ch <- struct{}{}
}()
time.Sleep(50 * time.Millisecond)
select {
@@ -170,3 +164,7 @@ func TestBlockUntilInitialized(t *testing.T) {
wg.Wait()
}
func TestCheckpointSpeed(t *testing.T) {
//TODO(zhihong): test speed
}
@@ -16,3 +16,6 @@ py_proto_compile(framework_py_proto SRCS attr_type.proto op_proto.proto op_desc.
# Generate an empty __init__.py to make framework_py_proto as a valid python module.
add_custom_target(framework_py_proto_init ALL COMMAND ${CMAKE_COMMAND} -E touch __init__.py)
add_dependencies(framework_py_proto framework_py_proto_init)
proto_library(net_proto SRCS net_proto.proto DEPS op_proto)
cc_library(net SRCS net.cc DEPS net_proto attr_type op_proto)
#include "paddle/framework/net.h"
namespace paddle {
namespace framework {
PlainNet::PlainNet(const NetDesc& def) {}
void PlainNet::InferShape(Scope* scope) {
for (auto& op : ops_) {
op.InferShape();
}
}
void PlainNet::Run(std::shared_ptr<Scope> scope, DeviceContext* ctx) {
for (auto& op : ops_) {
op.Run(ctx);
}
}
} // namespace framework
} // namespace paddle
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/framework/net_proto.pb.h"
#include "paddle/framework/op_proto.pb.h"
#include "paddle/framework/scope.h"
#include "paddle/platform/device_context.h"
namespace paddle {
namespace framework {
using namespace paddle::platform;
// operator's index stored in a network.
typedef int OpIndex;
/**
* NOTE: the following are placeholder definitions for concepts that are not
* implemented yet. They provide just enough of an implementation to make Net
* compilable; these APIs will keep changing as the related concepts are
* implemented.
*/
struct OpDesc;
struct OpAttrs {};
class Operator {
public:
Operator(const OpDesc &def) {}
void InferShape() {}
void Run(DeviceContext *ctx) {}
};
/**
* @brief Network that manages all the operators it contains.
*
* Network is the container and controller of a set of operators. A user can
* build a real network from a NetDesc, which is a protobuf message, and use
* Network.Run() to run all the operators in the network.
* A network object knows all Operators belonging to this network. Variables,
* which are inputs and outputs of these operators, are created and managed by a
* hierarchy of Scope objects.
*
* This is the base class of networks; all networks should implement the APIs
* it defines.
*/
class Net {
public:
/**
* @brief Infer shapes of all inputs and outputs of operators.
*/
virtual void InferShape(Scope *scope) = 0;
/**
* @brief Run the network.
*
* Run all the operators, with all the variables located in `scope`. `ctx`
* describes the detailed execution environment for the ops. All operators in
* `ops_` will be run.
*/
virtual void Run(std::shared_ptr<Scope> scope, DeviceContext *ctx) = 0;
/**
* @brief Add an Operator according to `def`.
*/
virtual OpIndex AddOp(const OpProto &def) = 0;
/**
* @brief Add optimizer operators according to `attrs`.
*/
virtual void AddOptimizerOps(const OpAttrs &attrs) = 0;
/**
* @brief Add backward operators.
*/
virtual void AddBackwardOps() = 0;
/**
* @brief Create a network.
*/
static std::unique_ptr<Net> Create(const NetDesc &def = NetDesc());
virtual ~Net() {}
};
/**
* @brief a basic implementation of Net.
*
* PlainNet is a very simple Net: it creates a list of operators and runs them
* sequentially in the order they were added.
*/
class PlainNet : public Net {
public:
/**
* @brief Initialize a PlainNet.
*
* Initialize from a network described by `def`. NetDesc is the definition of
* a network.
*/
PlainNet(const NetDesc &def);
/**
* Infer the shapes of all the operators' input and output variables; this will
* be called before every mini-batch.
*/
virtual void InferShape(Scope *scope) override;
/**
* @brief Run the network.
*
* Run all the operators with the `scope`; if no scope is provided, the default
* scope will be used instead. If no OpContext is provided, the default context
* will be used.
*/
virtual void Run(std::shared_ptr<Scope> scope, DeviceContext *ctx) override;
/**
* @brief Add an operator to this network.
*/
virtual OpIndex AddOp(const OpProto &def) override;
/**
* @brief Add all related optimizer operators to the network.
*/
virtual void AddOptimizerOps(const OpAttrs &attrs) override;
/**
* @brief Add all related backward operators to the network.
*/
virtual void AddBackwardOps() override;
virtual ~PlainNet() override {}
protected:
/**
* @brief Build the network.
*
* Create operators according to `def`; this will be called by the constructor.
*/
void BuildNet(const NetDesc &def);
/**
* @brief Add an operator into this network.
*
* Add an operator identified by `type` with attributes described in `attrs`;
* `inputs` are the keys of read-only input variables and `outputs` are the
* keys of mutable output variables. An `OpIndex` will be returned to indicate
* the offset of the new operator in `ops_`.
*/
OpIndex AddOp(const std::string &type, const std::vector<std::string> &inputs,
const std::vector<std::string> &outputs,
const OpAttrs &attrs = OpAttrs());
private:
// the operators owned by `Network`.
std::vector<Operator> ops_;
};
} // namespace framework
} // namespace paddle
syntax="proto2";
package paddle.framework;
import "op_proto.proto";
message NetDesc {
// network identification
optional string name = 1;
// operators contained in the network
repeated OpProto operators = 2;
// network type to run with, e.g. "plainNet", "DAG"
optional string net_type = 3;
// number of workers
optional int32 num_workers = 4;
}
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/framework/net.h"
#include "paddle/framework/op_registry.h"
#include <gtest/gtest.h>
namespace paddle {
namespace framework {
class FakeFC : public Operator {};
} // namespace framework
} // namespace paddle
@@ -119,4 +119,4 @@ TEST(OpRegistry, CustomChecker) {
for (size_t i = 0; i < debug_str.length(); ++i) {
ASSERT_EQ(debug_str[i], str[i]);
}
}
\ No newline at end of file
@@ -151,7 +151,13 @@ void BuddyAllocator::Free(void* p) {
pool_.insert(
IndexSizeAddress(block->index(cache_), block->total_size(cache_), block));
// Clean up if there is too much free memory
// Prefer freeing fallback allocation first
CleanIdleFallBackAlloc();
// Free normal allocation
CleanIdleNormalAlloc();
}
size_t BuddyAllocator::Used() { return total_used_; }
@@ -249,6 +255,11 @@ void* BuddyAllocator::SplitToAlloc(BuddyAllocator::PoolSet::iterator it,
return block;
}
void BuddyAllocator::CleanIdleFallBackAlloc() {
}
} // namespace detail
} // namespace memory
} // namespace paddle
@@ -43,10 +43,11 @@ class BuddyAllocator {
size_t Used();
public:
// Disable copy and assignment
BuddyAllocator(const BuddyAllocator&) = delete;
BuddyAllocator& operator=(const BuddyAllocator&) = delete;
private:
// Tuple (allocator index, memory size, memory address)
using IndexSizeAddress = std::tuple<size_t, size_t, void*>;
// Each element in PoolSet is a free allocation
@@ -59,16 +60,25 @@ class BuddyAllocator {
PoolSet::iterator RefillPool();
/**
* \brief Find the suitable chunk from the existing pool and split
* it into left and right buddies
*
* \param it the iterator of the pool list
* \param size the size of the allocation
*
* \return the left buddy address
*/
void* SplitToAlloc(PoolSet::iterator it, size_t size);
/*! \brief Find an existing chunk that can be used for the allocation */
PoolSet::iterator FindExistChunk(size_t size);
/*! \brief Clean idle fallback allocation */
void CleanIdleFallBackAlloc();
/*! \brief Clean idle normal allocation */
void CleanIdleNormalAlloc();
private:
size_t total_used_ = 0; // the total size of used memory
size_t total_free_ = 0; // the total size of free memory
......
@@ -30,6 +30,21 @@ TEST(BuddyAllocator, CPUAllocation) {
paddle::memory::Free(cpu, p);
}
TEST(BuddyAllocator, CPUMultAlloc) {
paddle::platform::CPUPlace cpu;
std::vector<void*> ps;
ps.reserve(8);
for (auto size : {256, 1024, 4096, 16384, 65536, 262144, 1048576, 4194304}) {
ps.emplace_back(paddle::memory::Alloc(cpu, size));
}
for (auto p : ps) {
paddle::memory::Free(cpu, p);
}
}
#ifndef PADDLE_ONLY_CPU
TEST(BuddyAllocator, GPUAllocation) {
......
@@ -7,3 +7,5 @@ cc_library(place SRCS place.cc)
cc_test(place_test SRCS place_test.cc DEPS place glog gflags)
cc_library(dynamic_loader SRCS dynload/dynamic_loader.cc DEPS gflags glog)
nv_test(device_context_test SRCS device_context_test.cc DEPS dynamic_loader place eigen3)
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/framework/enforce.h"
#ifndef PADDLE_ONLY_CPU
#include "paddle/platform/cuda.h"
#include "paddle/platform/dynload/cublas.h"
#include "paddle/platform/dynload/cudnn.h"
#include "paddle/platform/dynload/curand.h"
#define EIGEN_USE_GPU
#endif
#include "paddle/platform/place.h"
#include "unsupported/Eigen/CXX11/Tensor"
namespace paddle {
namespace platform {
class DeviceContext {
public:
virtual ~DeviceContext() {}
};
class CPUDeviceContext : public DeviceContext {};
#ifndef PADDLE_ONLY_CPU
class GPUPlaceGuard {
public:
explicit GPUPlaceGuard(GPUPlace new_place) : previous_(GetCurrentDeviceId()) {
if (previous_ != new_place) {
paddle::platform::SetDeviceId(new_place.device);
}
}
~GPUPlaceGuard() { paddle::platform::SetDeviceId(previous_.device); }
private:
GPUPlace previous_;
};
class CUDADeviceContext : public DeviceContext {
public:
explicit CUDADeviceContext(const GPUPlace gpu_place) : gpu_place_(gpu_place) {
GPUPlaceGuard guard(gpu_place_);
paddle::platform::throw_on_error(cudaStreamCreate(&stream_),
"cudaStreamCreate failed");
eigen_stream_ = new Eigen::CudaStreamDevice(&stream_);
eigen_device_ = new Eigen::GpuDevice(eigen_stream_);
}
void Wait() {
paddle::platform::throw_on_error(cudaStreamSynchronize(stream_),
"cudaStreamSynchronize failed");
}
cudaStream_t stream() { return stream_; }
Eigen::GpuDevice eigen_device() { return *eigen_device_; }
cublasHandle_t cublas_handle() {
if (!blas_handle_) {
GPUPlaceGuard guard(gpu_place_);
PADDLE_ENFORCE(paddle::platform::dynload::cublasCreate(&blas_handle_) ==
CUBLAS_STATUS_SUCCESS,
"cublasCreate failed");
PADDLE_ENFORCE(paddle::platform::dynload::cublasSetStream(
blas_handle_, stream_) == CUBLAS_STATUS_SUCCESS,
"cublasSetStream failed");
}
return blas_handle_;
}
cudnnHandle_t cudnn_handle() {
if (!dnn_handle_) {
GPUPlaceGuard guard(gpu_place_);
PADDLE_ENFORCE(paddle::platform::dynload::cudnnCreate(&dnn_handle_) ==
CUDNN_STATUS_SUCCESS,
"cudnnCreate failed");
PADDLE_ENFORCE(paddle::platform::dynload::cudnnSetStream(
dnn_handle_, stream_) == CUDNN_STATUS_SUCCESS,
"cudnnSetStream failed");
}
return dnn_handle_;
}
curandGenerator_t curand_generator() {
if (!rand_generator_) {
GPUPlaceGuard guard(gpu_place_);
PADDLE_ENFORCE(paddle::platform::dynload::curandCreateGenerator(
&rand_generator_, CURAND_RNG_PSEUDO_DEFAULT) ==
CURAND_STATUS_SUCCESS,
"curandCreateGenerator failed");
PADDLE_ENFORCE(
paddle::platform::dynload::curandSetPseudoRandomGeneratorSeed(
rand_generator_, random_seed_) == CURAND_STATUS_SUCCESS,
"curandSetPseudoRandomGeneratorSeed failed");
PADDLE_ENFORCE(paddle::platform::dynload::curandSetStream(
rand_generator_, stream_) == CURAND_STATUS_SUCCESS,
"curandSetStream failed");
}
return rand_generator_;
}
~CUDADeviceContext() {
Wait();
if (blas_handle_) {
PADDLE_ENFORCE(paddle::platform::dynload::cublasDestroy(blas_handle_) ==
CUBLAS_STATUS_SUCCESS,
"cublasDestroy failed");
}
if (dnn_handle_) {
PADDLE_ENFORCE(paddle::platform::dynload::cudnnDestroy(dnn_handle_) ==
CUDNN_STATUS_SUCCESS,
"cudnnDestroy failed");
}
if (rand_generator_) {
PADDLE_ENFORCE(paddle::platform::dynload::curandDestroyGenerator(
rand_generator_) == CURAND_STATUS_SUCCESS,
"curandDestroyGenerator failed");
}
delete eigen_stream_;
delete eigen_device_;
paddle::platform::throw_on_error(cudaStreamDestroy(stream_),
"cudaStreamDestroy failed");
}
private:
GPUPlace gpu_place_;
cudaStream_t stream_;
Eigen::CudaStreamDevice* eigen_stream_;
Eigen::GpuDevice* eigen_device_;
cublasHandle_t blas_handle_{nullptr};
cudnnHandle_t dnn_handle_{nullptr};
int random_seed_;
curandGenerator_t rand_generator_{nullptr};
};
#endif
} // namespace platform
} // namespace paddle
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/platform/device_context.h"
#include "gtest/gtest.h"
TEST(CUDADeviceContext, Init) {
int count = paddle::platform::GetDeviceCount();
for (int i = 0; i < count; i++) {
paddle::platform::CUDADeviceContext* device_context =
new paddle::platform::CUDADeviceContext(i);
Eigen::GpuDevice gpu_device = device_context->eigen_device();
ASSERT_NE(nullptr, gpu_device.stream());
cudnnHandle_t cudnn_handle = device_context->cudnn_handle();
ASSERT_NE(nullptr, cudnn_handle);
cublasHandle_t cublas_handle = device_context->cublas_handle();
ASSERT_NE(nullptr, cublas_handle);
curandGenerator_t curand_handle = device_context->curand_generator();
ASSERT_NE(nullptr, curand_handle);
delete device_context;
}
}
@@ -1395,7 +1395,7 @@ def inputs(layers, *args):
if len(args) != 0:
layers.extend(args)
Inputs(*[l.name for l in layers])
def outputs(layers, *args):
@@ -1438,7 +1438,7 @@ def outputs(layers, *args):
assert len(layers) > 0
if HasInputsSet(): # input already set
Outputs(*[l.name for l in layers])
return # just return outputs.
if len(layers) != 1:
......
@@ -32,9 +32,9 @@ MD5_DEV_TEST = '7d7897317ddd8ba0ae5c5fa7248d3ff5'
# this is a small set of data for test. The original data is too large and will be added later.
URL_TRAIN = 'http://paddlepaddle.cdn.bcebos.com/demo/wmt_shrinked_data/wmt14.tgz'
MD5_TRAIN = '0791583d57d5beb693b9414c5b36798c'
# BLEU of this trained model is 26.92
URL_MODEL = 'http://paddlepaddle.bj.bcebos.com/demo/wmt_14/wmt14_model.tar.gz'
MD5_MODEL = '0cb4a5366189b6acba876491c8724fa3'
START = "<s>"
END = "<e>"
......