package pserver

import (
	"errors"
	"fmt"
	"sync"
)

// ElementType is the type of elements of a Parameter.
type ElementType int

// Messages carried by the errors that Service methods return when they are
// called in the wrong initialization phase.
const (
	// AlreadyInitialized is the message returned by InitParam and
	// FinishInitParams once initialization has been marked finished.
	AlreadyInitialized = "pserver already initialized"
	// Uninitialized is the message returned by SendGrad when it is called
	// before initialization has been marked finished.
	Uninitialized = "pserver not fully initialized"
)

// Supported element types.
const (
	Int32 ElementType = iota
	UInt32
	Int64
	UInt64
	Float32
	Float64
)

// PsDesired is the etcd path that stores the desired pserver count.
const PsDesired = "/ps_desired"

// Parameter is a piece of data to sync with the parameter server.
type Parameter struct {
	Name        string
	ElementType ElementType
	Content     []byte
}

// ParameterWithConfig contains the parameter and the configuration.
type ParameterWithConfig struct {
	Param  Parameter
	Config []byte // parameter configuration in Protocol Buffer format
}

// Gradient is the gradient of the parameter.
type Gradient Parameter

// Service is the RPC service for pserver.
type Service struct {
	// initialized is closed by FinishInitParams; a closed channel means
	// parameter initialization is over.
	initialized chan struct{}
	// idx is the index handed to NewService.
	idx int

	// mu guards paramMap.
	mu sync.Mutex
	// opt applies gradients to parameters in SendGrad.
	opt *optimizer
	// paramMap holds the registered parameters keyed by Parameter.Name.
	paramMap map[string]Parameter
}

// NewService creates a Service for the given index with an empty parameter
// map, an open (not yet closed) initialized channel and the optimizer built
// by newOptimizer(sgd, 0.005). The returned error is always nil.
//
// NOTE(review): the previous comment said this "will bypass etcd
// registration if no endpoints specified"; nothing in this function
// performs etcd registration - confirm against callers.
func NewService(idx int) (*Service, error) {
	svc := &Service{
		initialized: make(chan struct{}),
		idx:         idx,
		opt:         newOptimizer(sgd, 0.005),
		paramMap:    make(map[string]Parameter),
	}
	return svc, nil
}

// InitParam initializes a parameter: it stores paramWithConfigs.Param under
// its name, replacing any parameter previously stored under the same name.
// It returns an error carrying the AlreadyInitialized message if
// initialization has already been marked finished. dummy is unused.
func (s *Service) InitParam(paramWithConfigs ParameterWithConfig, dummy *int) error {
	// Non-blocking probe: the receive only succeeds once the channel is closed.
	select {
	case <-s.initialized:
		return errors.New(AlreadyInitialized)
	default:
	}

	// TODO(helin): parse parameter config

	param := paramWithConfigs.Param

	s.mu.Lock()
	defer s.mu.Unlock()

	// TODO(helin): check if paramWithConfigs.Param.Content is
	// properly memory aligned, if not, make copy to a memory
	// aligned region.
	s.paramMap[param.Name] = param
	return nil
}

// FinishInitParams tells the parameter server that the parameter
// initialization has finished.
func (s *Service) FinishInitParams(dummy0 int, dummy1 *int) error { select { case <-s.initialized: return errors.New(AlreadyInitialized) default: } close(s.initialized) return nil } // SendGrad sends gradient to parameter servers for parameter // optimization. func (s *Service) SendGrad(g Gradient, dummy *int) error { select { case <-s.initialized: default: return errors.New(Uninitialized) } s.mu.Lock() defer s.mu.Unlock() p, ok := s.paramMap[g.Name] if !ok { return fmt.Errorf("parameter: %s does not exist", g.Name) } return s.opt.UpdateParameter(p, g) } // GetParam gets parameters from the parameter server. func (s *Service) GetParam(name string, parameter *Parameter) error { <-s.initialized s.mu.Lock() defer s.mu.Unlock() p, ok := s.paramMap[name] if !ok { return fmt.Errorf("parameter: %s does not exist", name) } // The parameter content (a byte slice) may change // during RPC serialization due to write from other // goroutine, we allow it since mini-batch based deep // learning optimization methods are stochastic in // nature. This race condition is allowed deliberately // to save the program from making a copy of the // paramter content. *parameter = p return nil } // Save tells the parameter server to save parameters. func (s *Service) Save(path string, dummy *int) error { <-s.initialized // TODO return nil }