// Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // http://www.apache.org/licenses/LICENSE-2.0 // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. package pserver import ( "bufio" "bytes" "crypto/md5" "encoding/gob" "encoding/hex" "encoding/json" "errors" "fmt" "io/ioutil" "os" "path" "strconv" "sync" "time" uuid "github.com/satori/go.uuid" log "github.com/sirupsen/logrus" ) // ElementType is the type of elements of a Parameter. type ElementType int // ErrCheckpointNotFound indicates that the pserver checkpoint could // not be found. var ErrCheckpointNotFound = errors.New("checkpoint not found") // RPC error message. const ( AlreadyInitialized = "pserver already initialized" Uninitialized = "pserver not fully initialized" WrongChecksum = "checkpoint file checksum validation failed" ) // Supported element types. const ( Int32 ElementType = iota UInt32 Int64 UInt64 Float32 Float64 ) // Parameter is a piece of data to sync with the parameter server. type Parameter struct { Name string ElementType ElementType Content []byte } // ParameterWithConfig contains the parameter and the configuration. type ParameterWithConfig struct { Param Parameter Config []byte // parameter configuration in Proto Buffer format } // checkpointMeta saves checkpoint metadata type checkpointMeta struct { UUID string `json:"uuid"` Path string `json:"path"` MD5 string `json:"md5"` Timestamp int64 `json:"timestamp"` } // Checkpoint is the pserver shard persist in file. type Checkpoint []parameterCheckpoint // Gradient is the gradient of the parameter. type Gradient Parameter // Service is the RPC service for pserver. type Service struct { initialized chan struct{} idx int checkpointInterval time.Duration checkpointPath string client *EtcdClient mu sync.Mutex optMap map[string]*optimizer } // parameterCheckpoint saves parameter checkpoint. type parameterCheckpoint struct { ParameterWithConfig State []byte } func loadMeta(e *EtcdClient, idx int) (meta checkpointMeta, err error) { v, err := e.GetKey(PsCheckpoint+strconv.Itoa(idx), 3*time.Second) if err != nil { return } if len(v) == 0 { err = ErrCheckpointNotFound return } if err = json.Unmarshal(v, &meta); err != nil { return } return } // LoadCheckpoint loads checkpoint from file. func LoadCheckpoint(e *EtcdClient, idx int) (Checkpoint, error) { cpMeta, err := loadMeta(e, idx) if err != nil { return nil, err } content, err := ioutil.ReadFile(cpMeta.Path) if err != nil { return nil, err } // TODO(helin): change MD5 to CRC since CRC is better for file // checksum in our use case (emphasize speed over security). h := md5.New() md5 := hex.EncodeToString(h.Sum(content)) if md5 != cpMeta.MD5 { return nil, errors.New(WrongChecksum) } dec := gob.NewDecoder(bytes.NewReader(content)) var cp Checkpoint if err = dec.Decode(&cp); err != nil { return nil, err } return cp, nil } // NewService creates a new service, will bypass etcd registration if no // endpoints specified. It will recovery from checkpoint file if a exists a specified checkpoint. func NewService(idx int, interval time.Duration, path string, client *EtcdClient, cp Checkpoint) (*Service, error) { s := &Service{ idx: idx, checkpointInterval: interval, checkpointPath: path, client: client, } s.optMap = make(map[string]*optimizer) s.initialized = make(chan struct{}) if cp != nil { for _, item := range cp { p := ParameterWithConfig{ Param: item.Param, Config: item.Config, } s.optMap[p.Param.Name] = newOptimizer(p, item.State) } } return s, nil } // InitParam initializes a parameter. func (s *Service) InitParam(paramWithConfigs ParameterWithConfig, _ *int) error { select { case <-s.initialized: return errors.New(AlreadyInitialized) default: } // TODO(helin): parse parameter config s.mu.Lock() defer s.mu.Unlock() // TODO(helin): check if paramWithConfigs.Param.Content is // properly memory aligned, if not, make copy to a memory // aligned region. s.optMap[paramWithConfigs.Param.Name] = newOptimizer(paramWithConfigs, nil) return nil } // FinishInitParams tells the parameter server that the parameter // initialization has finished. func (s *Service) FinishInitParams(_ int, _ *int) error { select { case <-s.initialized: return errors.New(AlreadyInitialized) default: } close(s.initialized) go func() { t := time.Tick(s.checkpointInterval) for range t { err := s.checkpoint() if err != nil { log.Errorln(err) } } }() return nil } // SendGrad sends gradient to parameter servers for parameter // optimization. func (s *Service) SendGrad(g Gradient, _ *int) error { select { case <-s.initialized: default: return errors.New(Uninitialized) } s.mu.Lock() defer s.mu.Unlock() o, ok := s.optMap[g.Name] if !ok { return fmt.Errorf("parameter: %s does not exist", g.Name) } return o.UpdateParameter(g) } // GetParam gets parameters from the parameter server. func (s *Service) GetParam(name string, parameter *Parameter) error { <-s.initialized s.mu.Lock() defer s.mu.Unlock() opt, ok := s.optMap[name] if !ok { return fmt.Errorf("parameter: %s does not exist", name) } // The parameter content (a byte slice) may change // during RPC serialization due to write from other // goroutine, we allow it since mini-batch based deep // learning optimization methods are stochastic in // nature. This race condition is allowed deliberately // to save the program from making a copy of the // parameter content. parameter.Name = name parameter.ElementType = opt.elementType parameter.Content = opt.GetWeights() return nil } func traceTime(start time.Time, name string) { elapsed := time.Since(start) log.Infof("%s took %v", name, elapsed) } // checkpoint saves checkpoint to disk. // // checkpoint should be only called after the parameters are // initialized. func (s *Service) checkpoint() (err error) { log.Infoln("Begin save checkpoint.") defer traceTime(time.Now(), "save checkpoint") s.mu.Lock() cp := make([]parameterCheckpoint, len(s.optMap)) index := 0 // TODO(helin): write checkpoint incrementally to reduce memory // footprint during checkpoint. for name, opt := range s.optMap { var pc parameterCheckpoint pc.Param.Name = name pc.Param.ElementType = opt.elementType pc.Param.Content = opt.GetWeights() pc.Config = opt.config pc.State = opt.GetStates() cp[index] = pc index++ } s.mu.Unlock() var buf bytes.Buffer encoder := gob.NewEncoder(&buf) err = encoder.Encode(cp) if err != nil { return } id := uuid.NewV4().String() p := path.Join(s.checkpointPath, id) f, err := os.Create(p) if err != nil { return } defer func() { closeErr := f.Close() if closeErr != nil { if err != nil { log.Errorln(closeErr) } else { // Set closeErr as return value. err = closeErr } } }() writer := bufio.NewWriter(f) _, err = writer.Write(buf.Bytes()) if err != nil { return } err = writer.Flush() if err != nil { return } oldMeta, err := loadMeta(s.client, s.idx) if err == ErrCheckpointNotFound { log.Infoln("Do not have existing checkpoint.") err = nil } if err != nil { return } h := md5.New() md5 := hex.EncodeToString(h.Sum(buf.Bytes())) cpMeta := checkpointMeta{ UUID: id, Timestamp: time.Now().UnixNano(), MD5: md5, Path: p, } json, err := json.Marshal(cpMeta) if err != nil { return } err = s.client.PutKey(PsCheckpoint+strconv.Itoa(s.idx), json, 3*time.Second, false) if err != nil { return } if oldMeta.Path != "" { rmErr := os.Remove(oldMeta.Path) if rmErr != nil { // log error, but still treat checkpoint as // successful. log.Errorln(rmErr) } } return }