提交 f448edf1 编写于 作者: D dzhwinter 提交者: GitHub

Merge pull request #2610 from dzhwinter/go_optimizer

Go optimizer: integrate Go with optimizer library
cc_library(paddle_go_optimizer DEPS paddle_optimizer paddle_proto glog gflags protobuf)
go_library(paddle_pserver_cclient STATIC) go_library(paddle_pserver_cclient STATIC)
if(WITH_TESTING)
add_subdirectory(test) add_subdirectory(test)
endif()
cc_binary(main SRCS main.c DEPS paddle_pserver_cclient)
cc_test(test_cclient SRCS test_cclient.c DEPS paddle_pserver_cclient) cc_test(test_cclient SRCS test_cclient.c DEPS paddle_pserver_cclient)
add_style_check_target(test_cclient test_cclient.c)
#include <stdio.h>
#include <stdlib.h>
#include "libpaddle_pserver_cclient.h"
// TODO(helin): Fix: gtest using cmake is not working, using this
// hacky way for now.
#define fail() \
fprintf(stderr, "info: %s:%d: ", __FILE__, __LINE__); \
exit(-1);
void sendGrads(paddle_pserver_client c) {
unsigned char grad_a[2000] = {2};
unsigned char grad_b[3000] = {3};
paddle_gradient grad1 = {
"param_a", PADDLE_ELEMENT_TYPE_FLOAT32, grad_a, 2000};
paddle_gradient grad2 = {
"param_b", PADDLE_ELEMENT_TYPE_FLOAT32, grad_b, 3000};
paddle_gradient* grads[2] = {&grad1, &grad2};
if (paddle_send_grads(c, grads, 2)) {
fail();
}
}
void getParams(paddle_pserver_client c) {
paddle_parameter param_a;
paddle_parameter param_b;
char name_a[] = "param_a";
char name_b[] = "param_b";
// Must pre-allocate the prameter content before calling paddle_get_params.
unsigned char content_a[2000] = {};
unsigned char content_b[3000] = {};
param_a.element_type = PADDLE_ELEMENT_TYPE_FLOAT32;
param_a.name = name_a;
param_a.content = content_a;
param_a.content_len = 2000;
param_b.element_type = PADDLE_ELEMENT_TYPE_FLOAT32;
param_b.name = name_b;
param_b.content = content_b;
param_b.content_len = 3000;
paddle_parameter* params[2] = {&param_a, &param_b};
if (paddle_get_params(c, params, 2)) {
fail();
}
}
int main() {
char addr[] = "localhost:3000";
paddle_pserver_client c = paddle_new_pserver_client(addr, 1);
retry:
if (paddle_begin_init_params(c)) {
paddle_parameter param;
char name_a[] = "param_a";
char name_b[] = "param_b";
unsigned char content_a[2000] = {1};
unsigned char content_b[3000] = {0};
param.element_type = PADDLE_ELEMENT_TYPE_FLOAT32;
param.name = name_a;
param.content = content_a;
param.content_len = 2000;
int error = paddle_init_param(c, param, NULL, 0);
if (error != 0) {
goto retry;
}
param.element_type = PADDLE_ELEMENT_TYPE_FLOAT32;
param.name = name_b;
param.content = content_b;
param.content_len = 3000;
error = paddle_init_param(c, param, NULL, 0);
if (error != 0) {
goto retry;
}
error = paddle_finish_init_params(c);
if (error != 0) {
goto retry;
}
}
int i;
for (i = 0; i < 100; i++) {
sendGrads(c);
getParams(c);
}
if (paddle_save_model(c, "/tmp/")) {
fail();
}
return 0;
}
...@@ -3,113 +3,101 @@ ...@@ -3,113 +3,101 @@
#include "libpaddle_pserver_cclient.h" #include "libpaddle_pserver_cclient.h"
typedef float real; // TODO(helin): Fix: gtest using cmake is not working, using this
// hacky way for now.
void fail() { #define fail() \
// TODO(helin): fix: gtest using cmake is not working, using this fprintf(stderr, "info: %s:%d: ", __FILE__, __LINE__); \
// hacky way for now.
printf("test failed.\n");
exit(-1); exit(-1);
void sendGrads(paddle_pserver_client c) {
unsigned char grad_a[2000] = {2};
unsigned char grad_b[3000] = {3};
paddle_gradient grad1 = {
"param_a", PADDLE_ELEMENT_TYPE_FLOAT32, grad_a, 2000};
paddle_gradient grad2 = {
"param_b", PADDLE_ELEMENT_TYPE_FLOAT32, grad_b, 3000};
paddle_gradient *grads[2] = {&grad1, &grad2};
if (paddle_send_grads(c, grads, 2)) {
fail();
}
} }
void print_parameter(paddle_gradient* param) { void getParams(paddle_pserver_client c) {
if (param == NULL) { paddle_parameter param_a;
printf("param is NULL!!\n"); paddle_parameter param_b;
} else { char name_a[] = "param_a";
printf("==== parameter ====\n"); char name_b[] = "param_b";
printf("name: %s\n", param->name); // Must pre-allocate the prameter content before calling paddle_get_params.
printf("content_len: %d\n", param->content_len); unsigned char content_a[2000] = {};
printf("content_type: %d\n", param->element_type); unsigned char content_b[3000] = {};
int i; param_a.element_type = PADDLE_ELEMENT_TYPE_FLOAT32;
for (i = 0; i < param->content_len / (int)sizeof(real); ++i) { param_a.name = name_a;
printf("%f ", ((float*)param->content)[i]); param_a.content = content_a;
} param_a.content_len = 2000;
printf("\n\n"); param_b.element_type = PADDLE_ELEMENT_TYPE_FLOAT32;
param_b.name = name_b;
param_b.content = content_b;
param_b.content_len = 3000;
paddle_parameter *params[2] = {&param_a, &param_b};
if (paddle_get_params(c, params, 2)) {
fail();
} }
} }
int main() { int main() {
char addr[] = "localhost:3000"; char addr[] = "localhost:3000";
paddle_pserver_client c = paddle_new_pserver_client(addr, 1); paddle_pserver_client c = paddle_new_pserver_client(addr, 1);
char *config_proto;
char* names[] = {"param_a", "param_b"}; size_t config_proto_len = 0;
ssize_t nread;
FILE *fp = fopen("testdata/optimizer.pb.txt", "r");
if (!fp) {
fail();
}
while ((nread = getline(&config_proto, &config_proto_len, fp)) != -1) {
printf("%s", config_proto);
}
fclose(fp);
retry: retry:
printf("init parameter to pserver:\n");
real param_content1[] = {0.1, 0.2, 0.3};
real param_content2[] = {0.4, 0.5, 0.6};
paddle_parameter** params =
(paddle_parameter**)malloc(sizeof(paddle_parameter*) * 2);
params[0] = (paddle_parameter*)malloc(sizeof(paddle_parameter));
params[0]->name = names[0];
params[0]->content = (unsigned char*)param_content1;
params[0]->content_len = 3 * sizeof(real);
params[0]->element_type = PADDLE_ELEMENT_TYPE_FLOAT32;
params[1] = (paddle_parameter*)malloc(sizeof(paddle_parameter));
params[1]->name = names[1];
params[1]->content = (unsigned char*)param_content2;
params[1]->content_len = 3 * sizeof(real);
params[1]->element_type = PADDLE_ELEMENT_TYPE_INT32;
if (paddle_begin_init_params(c)) { if (paddle_begin_init_params(c)) {
if (paddle_init_param(c, *params[0], NULL, 0) != 0) { paddle_parameter param;
char name_a[] = "param_a";
char name_b[] = "param_b";
unsigned char content_a[2000] = {1};
unsigned char content_b[3000] = {0};
param.element_type = PADDLE_ELEMENT_TYPE_FLOAT32;
param.name = name_a;
param.content = content_a;
param.content_len = 2000;
int error =
paddle_init_param(c, param, (void *)config_proto, config_proto_len);
if (error != 0) {
goto retry; goto retry;
} }
if (paddle_init_param(c, *params[1], NULL, 0) != 0) {
param.element_type = PADDLE_ELEMENT_TYPE_FLOAT32;
param.name = name_b;
param.content = content_b;
param.content_len = 3000;
error = paddle_init_param(c, param, (void *)config_proto, config_proto_len);
if (error != 0) {
goto retry; goto retry;
} }
if (paddle_finish_init_params(c) != 0) {
error = paddle_finish_init_params(c);
if (error != 0) {
goto retry; goto retry;
} }
} else {
fail();
}
printf("get inited parameters from pserver:\n");
// get parameters again by reusing the allocated parameter buffers.
if (paddle_get_params(c, params, 2) != 0) {
fail();
}
print_parameter(params[0]);
print_parameter(params[1]);
printf("send gradient to pserver:\n");
real gradient_content1[] = {0.01, 0.02, 0.03};
real gradinet_content2[] = {0.04, 0.05, 0.06};
paddle_gradient** grads =
(paddle_gradient**)malloc(sizeof(paddle_gradient*) * 2);
grads[0] = (paddle_gradient*)malloc(sizeof(paddle_gradient));
grads[0]->name = names[0];
grads[0]->content = (unsigned char*)gradient_content1;
grads[0]->content_len = 3 * sizeof(real);
grads[0]->element_type = PADDLE_ELEMENT_TYPE_FLOAT32;
grads[1] = (paddle_gradient*)malloc(sizeof(paddle_gradient));
grads[1]->name = names[1];
grads[1]->content = (unsigned char*)gradinet_content2;
grads[1]->content_len = 3 * sizeof(real);
grads[1]->element_type = PADDLE_ELEMENT_TYPE_INT32;
printf("print gradient sent to pserver:\n");
print_parameter(grads[0]);
print_parameter(grads[1]);
if (paddle_send_grads(c, grads, 2) != 0) {
fail();
} }
printf("get updated parameters from pserver:\n"); int i;
// get parameters again by reusing the allocated parameter buffers. for (i = 0; i < 100; i++) {
if (paddle_get_params(c, params, 2) != 0) { sendGrads(c);
fail(); getParams(c);
} }
print_parameter(params[0]);
print_parameter(params[1]);
if (paddle_save_model(c, "/tmp/") != 0) { if (paddle_save_model(c, "/tmp/")) {
fail(); fail();
} }
......
...@@ -22,6 +22,8 @@ def main(): ...@@ -22,6 +22,8 @@ def main():
# create optimizer # create optimizer
optimizer = paddle.optimizer.Momentum(momentum=0) optimizer = paddle.optimizer.Momentum(momentum=0)
#TODO(zhihong) : replace optimizer with new OptimizerConfig
trainer = paddle.trainer.SGD(cost=cost, trainer = paddle.trainer.SGD(cost=cost,
parameters=parameters, parameters=parameters,
update_equation=optimizer, update_equation=optimizer,
......
package pserver_test package pserver_test
import ( import (
"io/ioutil"
"net" "net"
"net/http" "net/http"
"net/rpc" "net/rpc"
...@@ -74,18 +75,22 @@ func TestClientFull(t *testing.T) { ...@@ -74,18 +75,22 @@ func TestClientFull(t *testing.T) {
} }
const numParameter = 100 const numParameter = 100
config, err := ioutil.ReadFile("./cclient/test/testdata/optimizer.pb.txt")
if err != nil {
t.Fatalf("read optimizer proto failed")
}
for i := 0; i < numParameter; i++ { for i := 0; i < numParameter; i++ {
var p pserver.Parameter var p pserver.Parameter
p.Name = "p_" + strconv.Itoa(i) p.Name = "p_" + strconv.Itoa(i)
p.ElementType = pserver.Float32 p.ElementType = pserver.Float32
p.Content = make([]byte, (i+1)*100) p.Content = make([]byte, (i+1)*100)
err := c.InitParam(pserver.ParameterWithConfig{Param: p}) err := c.InitParam(pserver.ParameterWithConfig{Param: p, Config: config})
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
} }
err := c.FinishInitParams() err = c.FinishInitParams()
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
......
#include <stdlib.h>
#include "optimizer.h"
typedef int (*update_func)(void*, void*, paddle_element_type, const void*, int);
typedef void (*release_func)(void*);
typedef struct paddle_optimizer {
update_func update;
release_func release;
void* optimizer;
} paddle_optimizer;
void paddle_release_optimizer(paddle_optimizer* o) {
o->release(o->optimizer);
free(o);
}
int paddle_update_parameter(paddle_optimizer* o,
void* buffer,
paddle_element_type element_type,
const void* gradient,
int num_bytes) {
return o->update(o->optimizer, buffer, element_type, gradient, num_bytes);
}
typedef struct { double learning_rate; } SGD_optimizer;
int update_SGD(void* optimizer,
void* buffer,
paddle_element_type element_type,
const void* gradient,
int num_bytes) {
SGD_optimizer* o = (SGD_optimizer*)optimizer;
float* parameter = (float*)buffer;
float* grad = (float*)gradient;
int i;
for (i = 0; i < num_bytes / sizeof(float); ++i) {
parameter[i] -= o->learning_rate * grad[i];
}
return 0;
}
void release_SGD(void* optimizer) {
SGD_optimizer* o = (SGD_optimizer*)optimizer;
// nothing allocated on heap
}
paddle_optimizer* paddle_create_SGD_optimizer(double learning_rate) {
SGD_optimizer* impl = (SGD_optimizer*)malloc(sizeof(SGD_optimizer));
impl->learning_rate = learning_rate;
paddle_optimizer* opt = (paddle_optimizer*)malloc(sizeof(paddle_optimizer));
opt->update = update_SGD;
opt->release = release_SGD;
opt->optimizer = impl;
return opt;
}
package pserver package pserver
/* // #cgo CFLAGS: -I ../../
#include "optimizer.h" // //FIXME: ldflags contain "build" path
*/ // #cgo LDFLAGS: ../../build/go/pserver/cclient/libpaddle_go_optimizer.a -lstdc++
// #include "paddle/optimizer/optimizer.h"
// #include <stdlib.h>
// #include <string.h>
import "C" import "C"
import ( import (
"fmt" "fmt"
"unsafe" "unsafe"
)
type optimizerType int
const ( log "github.com/sirupsen/logrus"
sgd optimizerType = iota
) )
var nullPtr = unsafe.Pointer(uintptr(0)) var nullPtr = unsafe.Pointer(uintptr(0))
type optimizer struct { type optimizer struct {
opt *C.struct_paddle_optimizer opt *C.struct_paddle_optimizer
elementType ElementType
} }
func newOptimizer(t optimizerType, learning_rate float64) *optimizer { func cArrayToSlice(p unsafe.Pointer, len int) []byte {
if p == nullPtr {
return nil
}
// create a Go clice backed by a C array, reference:
// https://github.com/golang/go/wiki/cgo#turning-c-arrays-into-go-slices
//
// Go garbage collector will not interact with this data, need
// to be freed properly.
return (*[1 << 30]byte)(p)[:len:len]
}
func newOptimizer(paramWithConfigs ParameterWithConfig) *optimizer {
o := &optimizer{} o := &optimizer{}
o.opt = C.paddle_create_SGD_optimizer(C.double(learning_rate)) o.elementType = paramWithConfigs.Param.ElementType
p := paramWithConfigs.Param
c := paramWithConfigs.Config
log.WithFields(log.Fields{
"ElementType": p.ElementType,
"ParamSize": len(p.Content),
"ConfigSize": len(c),
}).Info("New Optimizer Created with config:")
var cbuffer unsafe.Pointer
cbuffer = C.malloc(C.size_t(len(p.Content)))
C.memcpy(cbuffer, unsafe.Pointer(&p.Content[0]), C.size_t(len(p.Content)))
o.opt = C.paddle_create_optimizer((*C.uchar)(&c[0]), C.int(len(c)),
C.paddle_element_type(p.ElementType), cbuffer, C.int(len(p.Content)/C.sizeof_float),
(*C.char)(nullPtr), 0)
return o return o
} }
func (o *optimizer) UpdateParameter(p Parameter, g Gradient) error { func (o *optimizer) GetWeights() []byte {
if len(p.Content) != len(g.Content) { var buffer unsafe.Pointer
return fmt.Errorf("Name: %s, parameter and gradient length not match, parameter: %d, gradient: %d", p.Name, len(p.Content), len(g.Content)) buffer_len := C.paddle_optimizer_get_weights(o.opt, &buffer)
} return cArrayToSlice(buffer, int(buffer_len)*C.sizeof_float)
}
if p.ElementType != g.ElementType { func (o *optimizer) UpdateParameter(g Gradient) error {
return fmt.Errorf("Name: %s, parameter and gradient element type not match, parameter: %v, gradient: %v", p.Name, p.ElementType, g.ElementType) if o.elementType != g.ElementType {
return fmt.Errorf("Name: %s, parameter and gradient element type not match, parameter: %v, gradient: %v", g.Name, o.elementType, g.ElementType)
} }
r := C.paddle_update_parameter(o.opt, unsafe.Pointer(&p.Content[0]), C.paddle_element_type(p.ElementType), unsafe.Pointer(&g.Content[0]), C.int(len(g.Content))) r := C.paddle_update_parameter(o.opt, C.paddle_element_type(g.ElementType), unsafe.Pointer(&g.Content[0]), C.int(len(g.Content))/C.sizeof_float)
if r != 0 { if r != 0 {
return fmt.Errorf("optimizer update returned error code: %d", r) return fmt.Errorf("optimizer update returned error code: %d", r)
} }
......
#ifndef PADDLE_PSERVER_OPTIMIZER_H
#define PADDLE_PSERVER_OPTIMIZER_H
typedef enum {
PADDLE_ELEMENT_TYPE_INT32 = 0,
PADDLE_ELEMENT_TYPE_UINT32 = 1,
PADDLE_ELEMENT_TYPE_INT64 = 2,
PADDLE_ELEMENT_TYPE_UINT64 = 3,
PADDLE_ELEMENT_TYPE_FLOAT32 = 4,
PADDLE_ELEMENT_TYPE_FLOAT64 = 5,
} paddle_element_type;
struct paddle_optimizer;
struct paddle_optimizer* paddle_create_SGD_optimizer(double learning_rate);
void paddle_release_optimizer(struct paddle_optimizer* o);
int paddle_update_parameter(struct paddle_optimizer* o,
void* buffer,
paddle_element_type element_type,
const void* gradient,
int num_bytes);
#endif /* PADDLE_PSERVER_OPTIMIZER_H */
package pserver package pserver
import "testing" import (
"io/ioutil"
"testing"
)
func TestSGDCreateRelease(t *testing.T) { func TestOptimizerCreateRelease(t *testing.T) {
o := newOptimizer(sgd, 1) p := Parameter{
Name: "a",
ElementType: Int32,
}
p.Content = []byte{1, 3}
config, err := ioutil.ReadFile("./cclient/test/testdata/optimizer.pb.txt")
if err != nil {
t.Fatalf("read optimizer proto failed")
}
param := ParameterWithConfig{
Param: p,
Config: config,
}
o := newOptimizer(param)
o.Cleanup() o.Cleanup()
} }
...@@ -48,9 +48,8 @@ type Service struct { ...@@ -48,9 +48,8 @@ type Service struct {
initialized chan struct{} initialized chan struct{}
idx int idx int
mu sync.Mutex mu sync.Mutex
opt *optimizer optMap map[string]*optimizer
paramMap map[string]Parameter
} }
// NewService creates a new service, will bypass etcd registration if no // NewService creates a new service, will bypass etcd registration if no
...@@ -58,9 +57,8 @@ type Service struct { ...@@ -58,9 +57,8 @@ type Service struct {
func NewService(idx int) (*Service, error) { func NewService(idx int) (*Service, error) {
s := &Service{ s := &Service{
idx: idx, idx: idx,
opt: newOptimizer(sgd, 0.005),
} }
s.paramMap = make(map[string]Parameter) s.optMap = make(map[string]*optimizer)
s.initialized = make(chan struct{}) s.initialized = make(chan struct{})
return s, nil return s, nil
} }
...@@ -81,7 +79,7 @@ func (s *Service) InitParam(paramWithConfigs ParameterWithConfig, dummy *int) er ...@@ -81,7 +79,7 @@ func (s *Service) InitParam(paramWithConfigs ParameterWithConfig, dummy *int) er
// TODO(helin): check if paramWithConfigs.Param.Content is // TODO(helin): check if paramWithConfigs.Param.Content is
// properly memory aligned, if not, make copy to a memory // properly memory aligned, if not, make copy to a memory
// aligned region. // aligned region.
s.paramMap[paramWithConfigs.Param.Name] = paramWithConfigs.Param s.optMap[paramWithConfigs.Param.Name] = newOptimizer(paramWithConfigs)
return nil return nil
} }
...@@ -110,12 +108,12 @@ func (s *Service) SendGrad(g Gradient, dummy *int) error { ...@@ -110,12 +108,12 @@ func (s *Service) SendGrad(g Gradient, dummy *int) error {
s.mu.Lock() s.mu.Lock()
defer s.mu.Unlock() defer s.mu.Unlock()
p, ok := s.paramMap[g.Name] o, ok := s.optMap[g.Name]
if !ok { if !ok {
return fmt.Errorf("parameter: %s does not exist", g.Name) return fmt.Errorf("parameter: %s does not exist", g.Name)
} }
return s.opt.UpdateParameter(p, g) return o.UpdateParameter(g)
} }
// GetParam gets parameters from the parameter server. // GetParam gets parameters from the parameter server.
...@@ -124,7 +122,7 @@ func (s *Service) GetParam(name string, parameter *Parameter) error { ...@@ -124,7 +122,7 @@ func (s *Service) GetParam(name string, parameter *Parameter) error {
s.mu.Lock() s.mu.Lock()
defer s.mu.Unlock() defer s.mu.Unlock()
p, ok := s.paramMap[name] opt, ok := s.optMap[name]
if !ok { if !ok {
return fmt.Errorf("parameter: %s does not exist", name) return fmt.Errorf("parameter: %s does not exist", name)
} }
...@@ -136,7 +134,9 @@ func (s *Service) GetParam(name string, parameter *Parameter) error { ...@@ -136,7 +134,9 @@ func (s *Service) GetParam(name string, parameter *Parameter) error {
// nature. This race condition is allowed deliberately // nature. This race condition is allowed deliberately
// to save the program from making a copy of the // to save the program from making a copy of the
// paramter content. // paramter content.
*parameter = p parameter.Name = name
parameter.ElementType = opt.elementType
parameter.Content = opt.GetWeights()
return nil return nil
} }
......
package pserver_test package pserver_test
import ( import (
"io/ioutil"
"reflect" "reflect"
"sync" "sync"
"testing" "testing"
...@@ -9,7 +10,7 @@ import ( ...@@ -9,7 +10,7 @@ import (
"github.com/PaddlePaddle/Paddle/go/pserver" "github.com/PaddlePaddle/Paddle/go/pserver"
) )
func TestFull(t *testing.T) { func TestServiceFull(t *testing.T) {
s, err := pserver.NewService(0) s, err := pserver.NewService(0)
if err != nil { if err != nil {
t.Error(err) t.Error(err)
...@@ -18,7 +19,12 @@ func TestFull(t *testing.T) { ...@@ -18,7 +19,12 @@ func TestFull(t *testing.T) {
p.Name = "param_a" p.Name = "param_a"
p.Content = []byte{1, 0, 0, 0, 2, 0, 0, 0, 3, 0, 0, 0} p.Content = []byte{1, 0, 0, 0, 2, 0, 0, 0, 3, 0, 0, 0}
p.ElementType = pserver.Int32 p.ElementType = pserver.Int32
err = s.InitParam(pserver.ParameterWithConfig{Param: p, Config: nil}, nil) config, err := ioutil.ReadFile("./cclient/test/testdata/optimizer.pb.txt")
if err != nil {
t.Fatalf("read optimizer proto failed")
}
err = s.InitParam(pserver.ParameterWithConfig{Param: p, Config: config}, nil)
if err != nil { if err != nil {
t.FailNow() t.FailNow()
} }
...@@ -27,7 +33,7 @@ func TestFull(t *testing.T) { ...@@ -27,7 +33,7 @@ func TestFull(t *testing.T) {
p1.Name = "param_b" p1.Name = "param_b"
p1.Content = []byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0} p1.Content = []byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}
p1.ElementType = pserver.Float32 p1.ElementType = pserver.Float32
err = s.InitParam(pserver.ParameterWithConfig{Param: p1, Config: nil}, nil) err = s.InitParam(pserver.ParameterWithConfig{Param: p1, Config: config}, nil)
if err != nil { if err != nil {
t.FailNow() t.FailNow()
} }
...@@ -48,6 +54,7 @@ func TestFull(t *testing.T) { ...@@ -48,6 +54,7 @@ func TestFull(t *testing.T) {
} }
g1, g2 := pserver.Gradient(p1), pserver.Gradient(p) g1, g2 := pserver.Gradient(p1), pserver.Gradient(p)
err = s.SendGrad(g1, nil) err = s.SendGrad(g1, nil)
if err != nil { if err != nil {
t.FailNow() t.FailNow()
...@@ -142,7 +149,12 @@ func TestBlockUntilInitialized(t *testing.T) { ...@@ -142,7 +149,12 @@ func TestBlockUntilInitialized(t *testing.T) {
p.Name = "param_a" p.Name = "param_a"
p.Content = []byte{1, 0, 0, 0, 2, 0, 0, 0, 3, 0, 0, 0} p.Content = []byte{1, 0, 0, 0, 2, 0, 0, 0, 3, 0, 0, 0}
p.ElementType = pserver.Int32 p.ElementType = pserver.Int32
err = s.InitParam(pserver.ParameterWithConfig{Param: p, Config: nil}, nil) config, err := ioutil.ReadFile("./cclient/test/testdata/optimizer.pb.txt")
if err != nil {
t.Fatalf("read optimizer proto failed")
}
err = s.InitParam(pserver.ParameterWithConfig{Param: p, Config: config}, nil)
if err != nil { if err != nil {
t.FailNow() t.FailNow()
} }
......
...@@ -12,6 +12,7 @@ set(OPITMIZER_SRCS ...@@ -12,6 +12,7 @@ set(OPITMIZER_SRCS
add_library(paddle_optimizer STATIC ${OPITMIZER_SRCS}) add_library(paddle_optimizer STATIC ${OPITMIZER_SRCS})
add_dependencies(paddle_optimizer paddle_proto ${external_project_dependencies}) add_dependencies(paddle_optimizer paddle_proto ${external_project_dependencies})
if(WITH_TESTING) if(WITH_TESTING)
add_simple_unittest(serialization_test) add_simple_unittest(serialization_test)
add_simple_unittest(parameter_optimizer_test) add_simple_unittest(parameter_optimizer_test)
......
...@@ -5,6 +5,8 @@ import paddle.trainer_config_helpers.optimizers as v1_optimizers ...@@ -5,6 +5,8 @@ import paddle.trainer_config_helpers.optimizers as v1_optimizers
""" """
Optimizers(update equation) for SGD method. Optimizers(update equation) for SGD method.
TODO(zhihong) : create new optimizer with proto config, add new optimizer here
TODO(yuyang18): Complete comments. TODO(yuyang18): Complete comments.
""" """
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册