Commit 55217c96 authored by Helin Wang

Implement Pserver RPC, gradient update logic, cgo part

Parent 4cf78e6c
package main
import (
"flag"
"net"
"net/http"
"net/rpc"
"strconv"
"github.com/PaddlePaddle/Paddle/paddle/go/pserver"
)
func main() {
port := flag.Int("p", 0, "port of the pserver")
flag.Parse()
s := pserver.NewService()
err := rpc.Register(s)
if err != nil {
panic(err)
}
rpc.HandleHTTP()
l, err := net.Listen("tcp", ":"+strconv.Itoa(*port))
if err != nil {
panic(err)
}
err = http.Serve(l, nil)
if err != nil {
panic(err)
}
}
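The server above registers pserver.Service with Go's net/rpc and serves it over HTTP, so a trainer-side client can drive the whole protocol through rpc.DialHTTP and "Service.<Method>" calls. Below is a minimal, hypothetical client sketch that is not part of this commit; the address 127.0.0.1:8080 and the parameter name "w" are made up for illustration.
package main

import (
	"log"
	"net/rpc"

	"github.com/PaddlePaddle/Paddle/paddle/go/pserver"
)

func main() {
	// The server calls rpc.HandleHTTP, so the client connects with DialHTTP.
	client, err := rpc.DialHTTP("tcp", "127.0.0.1:8080")
	if err != nil {
		log.Fatal(err)
	}
	var dummy int
	// Init protocol: BeginInitParams -> InitParam (once per parameter) -> FinishInitParams.
	if err := client.Call("Service.BeginInitParams", []byte{}, &dummy); err != nil {
		log.Fatal(err)
	}
	p := pserver.Parameter{Name: "w", ElementType: pserver.Float32, Content: make([]byte, 12)}
	if err := client.Call("Service.InitParam", pserver.ParameterWithConfig{Param: p}, &dummy); err != nil {
		log.Fatal(err)
	}
	if err := client.Call("Service.FinishInitParams", 0, &dummy); err != nil {
		log.Fatal(err)
	}
	// After initialization, send gradients and read back parameters.
	if err := client.Call("Service.SendGrads", []pserver.Gradient{pserver.Gradient(p)}, &dummy); err != nil {
		log.Fatal(err)
	}
	var params []pserver.Parameter
	if err := client.Call("Service.GetParams", []string{"w"}, &params); err != nil {
		log.Fatal(err)
	}
	log.Printf("got %d parameter(s)", len(params))
}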
#include <stdlib.h>
#include "optimizer.h"
typedef struct {
double learning_rate;
} SGD_optimizer;
paddle_optimizer* paddle_create_SGD_optimizer(double learning_rate) {
SGD_optimizer* o = (SGD_optimizer*)malloc(sizeof(SGD_optimizer));
o->learning_rate = learning_rate;
return (paddle_optimizer*)o;
}
void paddle_release_optimizer(paddle_optimizer* o) {
free(o);
}
int paddle_update_parameter(paddle_optimizer* o, void *buffer, paddle_element_type datatype, const void* gradient, int num_bytes) {
// TODO
return 0;
}
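paddle_update_parameter is left as a TODO in this commit. As a rough, hypothetical illustration of the per-element SGD step it is expected to perform, new_param = param - learning_rate * grad, here is a Go sketch (not part of the commit) that assumes packed little-endian float32 buffers.
package main

import (
	"encoding/binary"
	"fmt"
	"math"
)

// sgdUpdateFloat32 applies param[i] -= lr * grad[i] element-wise, treating both
// byte buffers as packed little-endian float32 values (a simplifying assumption).
func sgdUpdateFloat32(param, grad []byte, lr float32) {
	n := len(param)
	if len(grad) < n {
		n = len(grad)
	}
	for i := 0; i+4 <= n; i += 4 {
		p := math.Float32frombits(binary.LittleEndian.Uint32(param[i:]))
		g := math.Float32frombits(binary.LittleEndian.Uint32(grad[i:]))
		binary.LittleEndian.PutUint32(param[i:], math.Float32bits(p-lr*g))
	}
}

func main() {
	param := make([]byte, 8)
	grad := make([]byte, 8)
	binary.LittleEndian.PutUint32(param[0:], math.Float32bits(1.0))
	binary.LittleEndian.PutUint32(param[4:], math.Float32bits(2.0))
	binary.LittleEndian.PutUint32(grad[0:], math.Float32bits(0.5))
	binary.LittleEndian.PutUint32(grad[4:], math.Float32bits(0.5))
	sgdUpdateFloat32(param, grad, 0.01)
	// Each element moved against its gradient by lr*grad (approximately 0.995 and 1.995).
	fmt.Println(
		math.Float32frombits(binary.LittleEndian.Uint32(param[0:])),
		math.Float32frombits(binary.LittleEndian.Uint32(param[4:])),
	)
}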
package pserver
/*
#include "optimizer.h"
*/
import "C"
import (
"fmt"
"unsafe"
)
type optimizerType int
const (
sgd optimizerType = iota
)
var nullPtr = unsafe.Pointer(uintptr(0))
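// optimizer is a Go wrapper around the C optimizer declared in optimizer.h.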
type optimizer struct {
opt *C.paddle_optimizer
}
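// newOptimizer creates an optimizer. Only SGD is implemented, so the optimizer type t is currently ignored.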
func newOptimizer(t optimizerType, learningRate float64) *optimizer {
o := &optimizer{}
o.opt = C.paddle_create_SGD_optimizer(C.double(learningRate))
return o
}
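// UpdateParameter updates the parameter in place with the gradient, using the underlying C optimizer.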
func (o *optimizer) UpdateParameter(p Parameter, g Gradient) error {
if len(p.Content) != len(g.Content) {
return fmt.Errorf("parameter and gradient length not match, parameter: %d, gradient: %d", len(p.Content), len(g.Content))
}
if p.ElementType != g.ElementType {
return fmt.Errorf("parameter and gradient element type not match, parameter: %v, gradient: %v", p.ElementType, g.ElementType)
}
r := C.paddle_update_parameter(o.opt, unsafe.Pointer(&p.Content[0]), C.paddle_element_type(p.ElementType), unsafe.Pointer(&g.Content[0]), C.int(len(g.Content)))
if r != 0 {
return fmt.Errorf("optimier returned error code: %d", r)
}
return nil
}
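// Cleanup releases the memory held by the C optimizer. It is safe to call more than once.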
func (o *optimizer) Cleanup() {
if unsafe.Pointer(o.opt) != nullPtr {
C.paddle_release_optimizer(o.opt)
o.opt = (*C.paddle_optimizer)(nullPtr)
}
}
#ifndef PADDLE_PSERVER_OPTIMIZER_H
#define PADDLE_PSERVER_OPTIMIZER_H
typedef enum {
PADDLE_ELEMENT_TYPE_INT32 = 0,
PADDLE_ELEMENT_TYPE_UINT32 = 1,
PADDLE_ELEMENT_TYPE_INT64 = 2,
PADDLE_ELEMENT_TYPE_UINT64 = 3,
PADDLE_ELEMENT_TYPE_FLOAT32 = 4,
PADDLE_ELEMENT_TYPE_FLOAT64 = 5,
} paddle_element_type;
typedef struct paddle_optimizer paddle_optimizer;
paddle_optimizer* paddle_create_SGD_optimizer(double learning_rate);
void paddle_release_optimizer(paddle_optimizer* o);
int paddle_update_parameter(paddle_optimizer* o, void *buffer, paddle_element_type datatype, const void* gradient, int num_bytes);
#endif /* PADDLE_PSERVER_OPTIMIZER_H */
package pserver
import (
"errors"
"fmt"
"sync"
)
// ElementType is the type of elements of a Parameter.
type ElementType int
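// ErrUnintialized is returned when the parameter server is used before parameter initialization has finished.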
var ErrUnintialized = errors.New("pserver not initialized")
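// ErrAlreadyIntialized is returned when an init call arrives after parameter initialization has already finished.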
var ErrAlreadyIntialized = errors.New("pserver already initialized")
// Supported element types
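// The numeric values must stay in sync with paddle_element_type in optimizer.h, since ElementType is cast directly to the C enum.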
const (
Int32 ElementType = iota
UInt32
Int64
UInt64
Float32
Float64
)
// Parameter is a piece of data to sync with the parameter server.
type Parameter struct {
Name string
ElementType ElementType
Content []byte
}
// ParameterWithConfig contains the parameter and the configuration.
type ParameterWithConfig struct {
Param Parameter
Config []byte // parameter configuration in Protocol Buffers format
}
// Gradient is the gradient of the parameter.
type Gradient Parameter
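// Service is the RPC service exposed by the parameter server. The initialized channel is closed by FinishInitParams; it rejects further init calls and unblocks readers that wait for initialization.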
type Service struct {
initialized chan struct{}
mu sync.Mutex
opt *optimizer
paramMap map[string]Parameter
}
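// NewService creates a new parameter server service.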
func NewService() *Service {
s := &Service{}
s.paramMap = make(map[string]Parameter)
s.initialized = make(chan struct{})
return s
}
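// BeginInitParams tells the parameter server that parameter initialization is about to begin and (re)creates the optimizer. It may be called multiple times before FinishInitParams.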
func (s *Service) BeginInitParams(config []byte, dummy *int) error {
select {
case <-s.initialized:
return ErrAlreadyIntialized
default:
}
s.mu.Lock()
defer s.mu.Unlock()
if s.opt != nil {
s.opt.Cleanup()
}
// TODO(helin): parse learning rate from config
s.opt = newOptimizer(sgd, 0.01)
return nil
}
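// InitParam initializes a parameter on the parameter server. It must be called between BeginInitParams and FinishInitParams.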
func (s *Service) InitParam(paramWithConfigs ParameterWithConfig, dummy *int) error {
select {
case <-s.initialized:
return ErrAlreadyIntialized
default:
}
// TODO(helin): parse parameter config
s.mu.Lock()
defer s.mu.Unlock()
// TODO(helin): check if paramWithConfigs.Param.Content is
// properly memory aligned, if not, make copy to a memory
// aligned region.
s.paramMap[paramWithConfigs.Param.Name] = paramWithConfigs.Param
return nil
}
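// FinishInitParams tells the parameter server that parameter initialization has finished; any further init call returns ErrAlreadyIntialized.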
func (s *Service) FinishInitParams(dummy0 int, dummy1 *int) error {
select {
case <-s.initialized:
return ErrAlreadyIntialized
default:
}
close(s.initialized)
return nil
}
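// SendGrads sends gradients to the parameter server, which updates the corresponding parameters concurrently. It returns ErrUnintialized before initialization has finished.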
func (s *Service) SendGrads(grads []Gradient, dummy *int) error {
select {
case <-s.initialized:
default:
return ErrUnintialized
}
s.mu.Lock()
defer s.mu.Unlock()
for _, g := range grads {
if _, ok := s.paramMap[g.Name]; !ok {
return fmt.Errorf("parameter: %s does not exist", g.Name)
}
}
var wg sync.WaitGroup
errCh := make(chan error, len(grads))
for _, g := range grads {
wg.Add(1)
go func(p Parameter, g Gradient) {
defer wg.Done()
// Collect optimizer errors instead of silently dropping them.
if err := s.opt.UpdateParameter(p, g); err != nil {
errCh <- err
}
}(s.paramMap[g.Name], g)
}
wg.Wait()
select {
case err := <-errCh:
return err
default:
}
return nil
}
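// GetParams returns the requested parameters. It blocks until parameter initialization has finished.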
func (s *Service) GetParams(names []string, parameters *[]Parameter) error {
<-s.initialized
s.mu.Lock()
defer s.mu.Unlock()
for _, n := range names {
if _, ok := s.paramMap[n]; !ok {
return fmt.Errorf("parameter: %s does not exist", n)
}
}
*parameters = make([]Parameter, len(names))
for i, n := range names {
// The parameter content (a byte slice) may change
// during RPC serialization due to writes from other
// goroutines. We allow this because mini-batch based
// deep learning optimization methods are stochastic in
// nature; the race condition is tolerated deliberately
// to avoid copying the parameter content.
(*parameters)[i] = s.paramMap[n]
}
return nil
}
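// SaveModel saves the model to the given path (persisting is still a TODO). It blocks until parameter initialization has finished.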
func (s *Service) SaveModel(path string, dummy *int) error {
<-s.initialized
// TODO
return nil
}
package pserver_test
import (
"reflect"
"sync"
"testing"
"github.com/PaddlePaddle/Paddle/paddle/go/pserver"
)
func TestFull(t *testing.T) {
s := pserver.NewService()
var dummy int
err := s.BeginInitParams(nil, &dummy)
if err != nil {
t.FailNow()
}
var p pserver.Parameter
p.Name = "param_a"
p.Content = []byte{1, 0, 0, 0, 2, 0, 0, 0, 3, 0, 0, 0}
p.ElementType = pserver.Int32
err = s.InitParam(pserver.ParameterWithConfig{Param: p, Config: nil}, &dummy)
if err != nil {
t.FailNow()
}
var p1 pserver.Parameter
p1.Name = "param_b"
p1.Content = []byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}
p1.ElementType = pserver.Float32
err = s.InitParam(pserver.ParameterWithConfig{Param: p1, Config: nil}, &dummy)
if err != nil {
t.FailNow()
}
err = s.FinishInitParams(0, &dummy)
if err != nil {
t.FailNow()
}
var params []pserver.Parameter
err = s.GetParams([]string{"param_b", "param_a"}, &params)
if err != nil {
t.FailNow()
}
if len(params) != 2 || !reflect.DeepEqual(params[0], p1) || !reflect.DeepEqual(params[1], p) {
t.FailNow()
}
grads := []pserver.Gradient{pserver.Gradient(p1), pserver.Gradient(p)}
err = s.SendGrads(grads, &dummy)
if err != nil {
t.FailNow()
}
var params1 []pserver.Parameter
err = s.GetParams([]string{"param_b", "param_a"}, &params1)
if err != nil {
t.FailNow()
}
if len(params1) != 2 {
t.FailNow()
}
// We don't compare the content, since it has already been updated with the gradient.
params1[0].Content = nil
params1[1].Content = nil
p.Content = nil
p1.Content = nil
if !reflect.DeepEqual(params1[0], p1) || !reflect.DeepEqual(params1[1], p) {
t.FailNow()
}
}
func TestMultipleInit(t *testing.T) {
s := pserver.NewService()
var dummy int
err := s.BeginInitParams(nil, &dummy)
if err != nil {
t.FailNow()
}
// This is fine; it is possible for the client to call
// BeginInitParams multiple times.
err = s.BeginInitParams(nil, &dummy)
if err != nil {
t.FailNow()
}
err = s.FinishInitParams(0, &dummy)
if err != nil {
t.FailNow()
}
err = s.FinishInitParams(0, &dummy)
if err != pserver.ErrAlreadyIntialized {
t.FailNow()
}
err = s.BeginInitParams(nil, &dummy)
if err != pserver.ErrAlreadyIntialized {
t.FailNow()
}
}
func TestUninitialized(t *testing.T) {
s := pserver.NewService()
var dummy int
err := s.SendGrads(nil, &dummy)
if err != pserver.ErrUnintialized {
t.FailNow()
}
}
func TestBlockUntilInitialized(t *testing.T) {
s := pserver.NewService()
var wg sync.WaitGroup
wg.Add(1)
go func() {
var params []pserver.Parameter
err := s.GetParams(nil, &params)
if err != nil {
t.Error(err)
}
wg.Done()
}()
wg.Add(1)
go func() {
var dummy int
err := s.SaveModel("", &dummy)
if err != nil {
t.Error(err)
}
wg.Done()
}()
var dummy int
err := s.BeginInitParams(nil, &dummy)
if err != nil {
t.FailNow()
}
err = s.FinishInitParams(0, &dummy)
if err != nil {
t.FailNow()
}
wg.Wait()
}