Commit bcf9f421 authored by 武毅, committed by GitHub

Merge pull request #2774 from typhoonzero/fix_newupdater

Fix new remote updater for go pserver
@@ -19,7 +19,7 @@ def main():
     # create parameters
     parameters = paddle.parameters.create(cost)
-    # create optimizer
+    # create optimizer of new remote updater to pserver
     optimizer = paddle.optimizer.Momentum(momentum=0)
     #TODO(zhihong) : replace optimizer with new OptimizerConfig
......
@@ -41,22 +41,24 @@ func newOptimizer(paramWithConfigs ParameterWithConfig, State []byte) *optimizer
 	p := paramWithConfigs.Param
 	c := paramWithConfigs.Config
 	s := State
+	paramBufferSize := C.size_t(len(p.Content) / C.sizeof_float)
 	log.WithFields(log.Fields{
 		"ElementType": p.ElementType,
-		"ParamSize":   len(p.Content),
+		"ParamSize":   paramBufferSize,
 		"ConfigSize":  len(c),
 		"StateSize":   len(s),
 	}).Info("New Optimizer Created with config:")
 	var cbuffer unsafe.Pointer
-	cbuffer = C.malloc(C.size_t(len(p.Content)))
-	C.memcpy(cbuffer, unsafe.Pointer(&p.Content[0]), C.size_t(len(p.Content)))
+	cbuffer = C.malloc(paramBufferSize)
+	C.memcpy(cbuffer, unsafe.Pointer(&p.Content[0]), paramBufferSize)
 	var cstate unsafe.Pointer
 	if len(s) != 0 {
 		cstate = unsafe.Pointer(&s[0])
 	}
 	o.opt = C.paddle_create_optimizer((*C.uchar)(&c[0]), C.int(len(c)),
-		C.paddle_element_type(p.ElementType), cbuffer, C.int(len(p.Content)/C.sizeof_float), (*C.char)(cstate), C.int(len(s)))
+		C.paddle_element_type(p.ElementType), cbuffer, C.int(paramBufferSize), (*C.char)(cstate), C.int(len(s)))
 	return o
 }
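Note that `paramBufferSize` above is an element count, not a byte count (`len(p.Content) / C.sizeof_float`). A minimal, cgo-free sketch of that arithmetic on a toy buffer; `elementCount` is a hypothetical helper, not part of the pserver code:

```go
package main

import (
	"encoding/binary"
	"fmt"
	"math"
)

// elementCount mirrors len(p.Content) / C.sizeof_float from the diff:
// it converts a parameter buffer's byte length into a float32 count.
func elementCount(content []byte) int {
	return len(content) / 4 // 4 bytes per float32
}

func main() {
	// Build a toy parameter buffer holding three float32 values.
	vals := []float32{1.0, 2.5, -3.0}
	buf := make([]byte, 4*len(vals))
	for i, v := range vals {
		binary.LittleEndian.PutUint32(buf[i*4:], math.Float32bits(v))
	}
	fmt.Println(len(buf), elementCount(buf)) // prints: 12 3
}
```

One consequence worth flagging: after this change `C.malloc` and `C.memcpy` also receive the element count, so they allocate and copy only a quarter of `p.Content`'s raw bytes; whether that matches what `paddle_create_optimizer` expects is not visible from this diff.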
@@ -68,8 +70,8 @@ func (o *optimizer) GetWeights() []byte {
 func (o *optimizer) GetStates() []byte {
 	var cbuffer *C.char
-	cbuffer_len := C.paddle_optimizer_get_state(o.opt, &cbuffer)
-	return cArrayToSlice(unsafe.Pointer(cbuffer), int(cbuffer_len))
+	cbufferLen := C.paddle_optimizer_get_state(o.opt, &cbuffer)
+	return cArrayToSlice(unsafe.Pointer(cbuffer), int(cbufferLen))
 }

 func (o *optimizer) UpdateParameter(g Gradient) error {
......
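`GetStates` depends on a `cArrayToSlice` helper that this diff does not show. A plausible sketch, assuming the helper copies the C buffer into Go-managed memory so the slice survives a later `C.free` (requires Go 1.17+ for `unsafe.Slice`; the repository's actual implementation may differ):

```go
package pserver

import "unsafe"

// cArrayToSlice copies n bytes from a C-allocated buffer into a fresh
// Go slice, so the result stays valid after the C buffer is freed.
func cArrayToSlice(p unsafe.Pointer, n int) []byte {
	if p == nil {
		return nil
	}
	out := make([]byte, n)
	copy(out, unsafe.Slice((*byte)(p), n))
	return out
}
```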
@@ -22,7 +22,8 @@ DECLARE_string(save_dir);
 namespace paddle {

 NewRemoteParameterUpdater::NewRemoteParameterUpdater(
     const OptimizationConfig &config, const std::string pserverSpec)
-    : parameterClient_(-1),
+    : trainerConfig_(config),
+      parameterClient_(-1),
       newParameters_(nullptr),
       newGradients_(nullptr),
       pserverSpec_(pserverSpec) {}
......
@@ -51,7 +52,22 @@ void NewRemoteParameterUpdater::init(
   LOG(INFO) << "paddle_begin_init_params start";
   for (int i = 0; i < parameterSize(); ++i) {
     auto paramConfig = parameters_[i]->getConfig();
-    std::string bytes = paramConfig.SerializeAsString();
+    LOG(INFO) << "old param config: " << paramConfig.DebugString();
+    // FIXME(typhoonzero): convert old paramConfig to optimizerConfig
+    OptimizerConfig optimizeConfigV2;
+    auto sgdConfigV2 = optimizeConfigV2.mutable_sgd();
+    sgdConfigV2->set_momentum(paramConfig.momentum());
+    sgdConfigV2->set_decay(paramConfig.decay_rate());
+    optimizeConfigV2.set_lr_policy(paddle::OptimizerConfig::Const);
+    auto constlr = optimizeConfigV2.mutable_const_lr();
+    constlr->set_learning_rate(paramConfig.learning_rate());
+    if (trainerConfig_.algorithm() == "sgd") {
+      optimizeConfigV2.set_optimizer(paddle::OptimizerConfig::SGD);
+      // FIXME: config all algorithms
+    } else {
+      optimizeConfigV2.set_optimizer(paddle::OptimizerConfig::SGD);
+    }
+    std::string bytes = optimizeConfigV2.SerializeAsString();
     const char *array = bytes.data();
     int size = (int)bytes.size();
     paddle_init_param(
......
@@ -83,4 +99,4 @@ void NewRemoteParameterUpdater::finishBatch(real cost) {
 void NewRemoteParameterUpdater::startPass() {}

 bool NewRemoteParameterUpdater::finishPass() { return true; }
-}
+}  // namespace paddle
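The `init` change above rewrites each old-style `ParameterConfig` into a v2 `OptimizerConfig` before serializing it for `paddle_init_param`. A sketch of that mapping in Go, with hypothetical plain structs standing in for the generated protobuf messages:

```go
package pserver

// oldParamConfig and optimizerConfigV2 are hypothetical stand-ins for
// the ParameterConfig and OptimizerConfig protobuf messages.
type oldParamConfig struct {
	Momentum     float64
	DecayRate    float64
	LearningRate float64
}

type optimizerConfigV2 struct {
	Optimizer    string // optimizer kind, e.g. "SGD"
	LRPolicy     string // learning-rate policy, e.g. "Const"
	LearningRate float64
	Momentum     float64
	Decay        float64
}

// convertConfig mirrors the FIXME(typhoonzero) block in
// NewRemoteParameterUpdater::init above.
func convertConfig(old oldParamConfig, algorithm string) optimizerConfigV2 {
	cfg := optimizerConfigV2{
		LRPolicy:     "Const",
		LearningRate: old.LearningRate,
		Momentum:     old.Momentum,
		Decay:        old.DecayRate,
	}
	if algorithm == "sgd" {
		cfg.Optimizer = "SGD"
	} else {
		// As in the diff, every other algorithm currently also
		// falls back to SGD ("FIXME: config all algorithms").
		cfg.Optimizer = "SGD"
	}
	return cfg
}
```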
@@ -16,6 +16,7 @@ limitations under the License. */
 #include <functional>
 #include <thread>

+#include "OptimizerConfig.pb.h"
 #include "ParameterUpdater.h"
 #include "libpaddle_pserver_cclient.h"
 #include "paddle/pserver/ParameterClient2.h"
......
@@ -101,6 +102,7 @@ private:
   }

 protected:
+  const OptimizationConfig& trainerConfig_;
   /// internal parameter client object for exchanging data with pserver
   paddle_pserver_client parameterClient_;
   /// the parameters for new pserver client
......
@@ -66,6 +66,8 @@ class Optimizer(object):
         if use_sparse_remote_updater:
             gradient_machine.prefetch(in_args)
             parameter_updater.getParametersRemote()
+
+        :param pserver_spec: pserver location, eg: localhost:3000
         :return: parameter_updater
         """
         if is_local:
......
@@ -41,6 +41,7 @@ class SGD(object):
     :type parameters: paddle.v2.parameters.Parameters
     :param extra_layers: Some layers in the neural network graph are not
                          in the path of cost layer.
+    :param pserver_spec: pserver location, eg: localhost:3000
     :type extra_layers: paddle.v2.config_base.Layer
     """
......