Commit 26d95a6b, authored by wuyi05

fix new remote updater for go pserver

Parent: 7b810553
@@ -19,7 +19,7 @@ def main():
 
     # create parameters
     parameters = paddle.parameters.create(cost)
-    # create optimizer
+    # create optimizer of new remote updater to pserver
     optimizer = paddle.optimizer.Momentum(momentum=0)
 
     #TODO(zhihong) : replace optimizer with new OptimizerConfig
...
@@ -42,12 +42,12 @@ func newOptimizer(paramWithConfigs ParameterWithConfig) *optimizer {
 	c := paramWithConfigs.Config
 	log.WithFields(log.Fields{
 		"ElementType": p.ElementType,
-		"ParamSize":   len(p.Content),
+		"ParamSize":   len(p.Content) / C.sizeof_float,
 		"ConfigSize":  len(c),
 	}).Info("New Optimizer Created with config:")
 	var cbuffer unsafe.Pointer
 	cbuffer = C.malloc(C.size_t(len(p.Content)))
-	C.memcpy(cbuffer, unsafe.Pointer(&p.Content[0]), C.size_t(len(p.Content)))
+	C.memcpy(cbuffer, unsafe.Pointer(&p.Content[0]), C.size_t(len(p.Content)/C.sizeof_float))
 	o.opt = C.paddle_create_optimizer((*C.uchar)(&c[0]), C.int(len(c)),
 		C.paddle_element_type(p.ElementType), cbuffer, C.int(len(p.Content)/C.sizeof_float),
 		(*C.char)(nullPtr), 0)
...
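The size arithmetic in the hunk above is the heart of the fix: p.Content is a raw byte slice, while the logged "ParamSize" and the optimizer C API count float32 elements, each C.sizeof_float (4) bytes wide. A minimal Python sketch of the same bytes-to-elements conversion, with struct standing in for the Go byte slice (names here are illustrative only):

import struct

values = [0.1, 0.2, 0.3, 0.4]  # a 4-element float32 parameter
content = struct.pack("<%df" % len(values), *values)  # raw bytes, like p.Content
sizeof_float = 4  # C.sizeof_float for float32
num_elements = len(content) // sizeof_float
assert num_elements == len(values)  # 16 bytes / 4 bytes per float = 4
print("ParamSize:", num_elements)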
@@ -22,7 +22,8 @@ DECLARE_string(save_dir);
 namespace paddle {
 NewRemoteParameterUpdater::NewRemoteParameterUpdater(
     const OptimizationConfig &config, const std::string pserverSpec)
-    : parameterClient_(-1),
+    : trainerConfig_(config),
+      parameterClient_(-1),
       newParameters_(nullptr),
       newGradients_(nullptr),
       pserverSpec_(pserverSpec) {}
@@ -51,7 +52,22 @@ void NewRemoteParameterUpdater::init(
   LOG(INFO) << "paddle_begin_init_params start";
   for (int i = 0; i < parameterSize(); ++i) {
     auto paramConfig = parameters_[i]->getConfig();
-    std::string bytes = paramConfig.SerializeAsString();
+    LOG(INFO) << "old param config: " << paramConfig.DebugString();
+    // FIXME(typhoonzero): convert old paramConfig to optimizerConfig
+    OptimizerConfig optimizeConfigV2;
+    auto sgdConfigV2 = optimizeConfigV2.mutable_sgd();
+    sgdConfigV2->set_momentum(paramConfig.momentum());
+    sgdConfigV2->set_decay(paramConfig.decay_rate());
+    optimizeConfigV2.set_lr_policy(paddle::OptimizerConfig::Const);
+    auto constlr = optimizeConfigV2.mutable_const_lr();
+    constlr->set_learning_rate(paramConfig.learning_rate());
+    if (trainerConfig_.algorithm() == "sgd") {
+      optimizeConfigV2.set_optimizer(paddle::OptimizerConfig::SGD);
+      // FIXME: config all algorithms
+    } else {
+      optimizeConfigV2.set_optimizer(paddle::OptimizerConfig::SGD);
+    }
+    std::string bytes = optimizeConfigV2.SerializeAsString();
     const char *array = bytes.data();
     int size = (int)bytes.size();
     paddle_init_param(
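The new block above converts the legacy per-parameter ParameterConfig into the OptimizerConfig protobuf that the Go pserver deserializes. A minimal Python sketch of that field mapping, with a plain dict standing in for the generated protobuf message (field names follow the C++ diff; as the FIXMEs note, only SGD is wired up, so both branches select SGD):

def to_optimizer_config_v2(param_config):
    """Mirror the C++ conversion above (dict as protobuf stand-in)."""
    return {
        "optimizer": "SGD",  # fixed until other algorithms are configured
        "lr_policy": "Const",
        "const_lr": {"learning_rate": param_config["learning_rate"]},
        "sgd": {
            "momentum": param_config["momentum"],
            "decay": param_config["decay_rate"],
        },
    }

print(to_optimizer_config_v2(
    {"momentum": 0.0, "decay_rate": 5e-4, "learning_rate": 0.01}))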
@@ -83,4 +99,4 @@ void NewRemoteParameterUpdater::finishBatch(real cost) {
 void NewRemoteParameterUpdater::startPass() {}
 
 bool NewRemoteParameterUpdater::finishPass() { return true; }
-}
+}  // namespace paddle
@@ -16,6 +16,7 @@ limitations under the License. */
 #include <functional>
 #include <thread>
 
+#include "OptimizerConfig.pb.h"
 #include "ParameterUpdater.h"
 #include "libpaddle_pserver_cclient.h"
 #include "paddle/pserver/ParameterClient2.h"
@@ -101,6 +102,7 @@ private:
   }
 
 protected:
+  const OptimizationConfig& trainerConfig_;
   /// internal parameter client object for exchanging data with pserver
   paddle_pserver_client parameterClient_;
   /// the parameters for new pserver client
...
@@ -66,6 +66,8 @@ class Optimizer(object):
             if use_sparse_remote_updater:
                 gradient_machine.prefetch(in_args)
                 parameter_updater.getParametersRemote()
+
+        :param pserver_spec: pserver location, e.g. localhost:3000
         :return: parameter_updater
         """
         if is_local:
...
@@ -41,6 +41,7 @@ class SGD(object):
     :type parameters: paddle.v2.parameters.Parameters
     :param extra_layers: Some layers in the neural network graph are not
                          in the path of cost layer.
+    :param pserver_spec: pserver location, e.g. localhost:3000
     :type extra_layers: paddle.v2.config_base.Layer
     """
...
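For context, a hedged end-to-end sketch of how pserver_spec flows from the v2 API down to the new remote updater. The tiny network and data shapes are placeholders, the address follows the docstring's example, a running Go pserver is assumed at that address, and exact layer helper names and the SGD keyword arguments may differ between v2 releases:

import paddle.v2 as paddle

paddle.init(use_gpu=False, trainer_count=1)

# placeholder network: a one-layer linear regression
x = paddle.layer.data(name='x', type=paddle.data_type.dense_vector(13))
y = paddle.layer.data(name='y', type=paddle.data_type.dense_vector(1))
y_predict = paddle.layer.fc(input=x, size=1, act=paddle.activation.Linear())
cost = paddle.layer.mse_cost(input=y_predict, label=y)

parameters = paddle.parameters.create(cost)
optimizer = paddle.optimizer.Momentum(momentum=0)

# is_local=False selects the remote updater; pserver_spec points it at the
# (assumed running) Go pserver.
trainer = paddle.trainer.SGD(cost=cost,
                             parameters=parameters,
                             update_equation=optimizer,
                             is_local=False,
                             pserver_spec="localhost:3000")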