From 37594eae1737ad7c95016f48a385521ceb0de529 Mon Sep 17 00:00:00 2001 From: qiaolongfei <qiaolongfei@baidu.com> Date: Tue, 13 Jun 2017 08:15:12 +0800 Subject: [PATCH] add paramConfig for each parameter --- go/pserver/cclient/test/main.c | 2 -- go/pserver/cclient/test/test_cclient.c | 3 +-- go/pserver/optimizer.c | 1 - go/pserver/service.go | 6 +----- paddle/trainer/NewRemoteParameterUpdater.cpp | 8 ++++++-- paddle/trainer/NewRemoteParameterUpdater.h | 1 - 6 files changed, 8 insertions(+), 13 deletions(-) diff --git a/go/pserver/cclient/test/main.c b/go/pserver/cclient/test/main.c index 72ec3590768..6adc3c9b533 100644 --- a/go/pserver/cclient/test/main.c +++ b/go/pserver/cclient/test/main.c @@ -76,7 +76,5 @@ retry: fail(); } - printf("test success!\n"); - return 0; } diff --git a/go/pserver/cclient/test/test_cclient.c b/go/pserver/cclient/test/test_cclient.c index 50ba2d5597a..9083064eeeb 100644 --- a/go/pserver/cclient/test/test_cclient.c +++ b/go/pserver/cclient/test/test_cclient.c @@ -21,7 +21,7 @@ void print_parameter(paddle_gradient* param) { printf("content_len: %d\n", param->content_len); printf("content_type: %d\n", param->element_type); int i; - for (i = 0; i < param->content_len / sizeof(real); ++i) { + for (i = 0; i < param->content_len / (int)sizeof(real); ++i) { printf("%f ", ((float*)param->content)[i]); } printf("\n\n"); @@ -110,6 +110,5 @@ retry: fail(); } - printf("test success!\n"); return 0; } diff --git a/go/pserver/optimizer.c b/go/pserver/optimizer.c index 48bbceb343b..f16ba2cbf8e 100644 --- a/go/pserver/optimizer.c +++ b/go/pserver/optimizer.c @@ -32,7 +32,6 @@ int update_SGD(void* optimizer, const void* gradient, int num_bytes) { SGD_optimizer* o = (SGD_optimizer*)optimizer; - // TODO(a simple SGD implement) float* parameter = (float*)buffer; float* grad = (float*)gradient; diff --git a/go/pserver/service.go b/go/pserver/service.go index ab814662b6b..7d2a1fea865 100644 --- a/go/pserver/service.go +++ b/go/pserver/service.go @@ -29,10 +29,6 @@ type Parameter struct { Content []byte } -func (p *Parameter) toString() { - fmt.Println(p.Name, p.ElementType, p.Content) -} - // ParameterWithConfig contains the parameter and the configuration. type ParameterWithConfig struct { Param Parameter @@ -53,7 +49,7 @@ type Service struct { // NewService creates a new service. func NewService() *Service { - s := &Service{opt: newOptimizer(sgd, 0.01)} + s := &Service{opt: newOptimizer(sgd, 0.005)} s.paramMap = make(map[string]Parameter) s.initialized = make(chan struct{}) return s diff --git a/paddle/trainer/NewRemoteParameterUpdater.cpp b/paddle/trainer/NewRemoteParameterUpdater.cpp index d554e09759c..b3655d9d025 100644 --- a/paddle/trainer/NewRemoteParameterUpdater.cpp +++ b/paddle/trainer/NewRemoteParameterUpdater.cpp @@ -31,7 +31,6 @@ NewRemoteParameterUpdater::NewRemoteParameterUpdater( void NewRemoteParameterUpdater::init( const std::vector<ParameterPtr> ¶meters) { ParameterUpdater::init(parameters); - LOG(INFO) << "NewRemoteParameterUpdater init in"; for (auto ¶ : parameters_) { para->getBuf(PARAMETER_VALUE)->zeroMem(); @@ -58,7 +57,12 @@ void NewRemoteParameterUpdater::init( if (paddle_begin_init_params(parameterClient_)) { LOG(INFO) << "paddle_begin_init_params start"; for (int i = 0; i < parameterSize(); ++i) { - paddle_init_param(parameterClient_, *newParameters_[i], NULL, 0); + auto paramConfig = parameters_[i]->getConfig(); + std::string bytes = paramConfig.SerializeAsString(); + const char *array = bytes.data(); + int size = (int)bytes.size(); + paddle_init_param( + parameterClient_, *newParameters_[i], (void *)array, size); } paddle_finish_init_params(parameterClient_); LOG(INFO) << "paddle_begin_init_params done"; diff --git a/paddle/trainer/NewRemoteParameterUpdater.h b/paddle/trainer/NewRemoteParameterUpdater.h index b7c0425982b..1f22c15cef5 100644 --- a/paddle/trainer/NewRemoteParameterUpdater.h +++ b/paddle/trainer/NewRemoteParameterUpdater.h @@ -84,7 +84,6 @@ private: for (int i = 0; i < parameterSize(); ++i) { ParameterPtr param = parameters_[i]; - new_params[i]->content_len = 10; new_params[i]->element_type = PADDLE_ELEMENT_TYPE_FLOAT32; new_params[i]->name = (char*)param->getName().c_str(); new_params[i]->content = -- GitLab