Commit bcf9f421 authored by 武毅, committed by GitHub

Merge pull request #2774 from typhoonzero/fix_newupdater

Fix new remote updater for go pserver
......
@@ -19,7 +19,7 @@ def main():
     # create parameters
     parameters = paddle.parameters.create(cost)
-    # create optimizer
+    # create optimizer of new remote updater to pserver
     optimizer = paddle.optimizer.Momentum(momentum=0)
     #TODO(zhihong) : replace optimizer with new OptimizerConfig
......
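Reviewer note: the TODO above refers to the OptimizerConfig conversion implemented in the C++ hunks further down. For orientation, parameter initialization with the new remote updater is a handshake against the pserver C client (`paddle_begin_init_params` / `paddle_init_param` appear below). A hedged Go sketch of that flow; `Client`, `BeginInitParams`, `InitParam` and `FinishInitParams` are hypothetical names mirroring those `paddle_*` C functions:

```go
// A hedged sketch of trainer-side parameter initialization; Client and its
// methods are hypothetical Go names mirroring the paddle_* C client
// functions used in the C++ hunks below.
func initParams(c *Client, params []ParameterWithConfig) error {
	// Assumption: BeginInitParams elects exactly one trainer to perform
	// the initialization; the others simply skip it and wait.
	if !c.BeginInitParams() {
		return nil
	}
	for _, pwc := range params {
		if err := c.InitParam(pwc); err != nil {
			return err
		}
	}
	return c.FinishInitParams()
}
```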
......
@@ -41,22 +41,24 @@ func newOptimizer(paramWithConfigs ParameterWithConfig, State []byte) *optimizer
 	p := paramWithConfigs.Param
 	c := paramWithConfigs.Config
 	s := State
+	paramBufferSize := C.size_t(len(p.Content) / C.sizeof_float)
 	log.WithFields(log.Fields{
 		"ElementType": p.ElementType,
-		"ParamSize":   len(p.Content),
+		"ParamSize":   paramBufferSize,
 		"ConfigSize":  len(c),
 		"StateSize":   len(s),
 	}).Info("New Optimizer Created with config:")
 	var cbuffer unsafe.Pointer
-	cbuffer = C.malloc(C.size_t(len(p.Content)))
-	C.memcpy(cbuffer, unsafe.Pointer(&p.Content[0]), C.size_t(len(p.Content)))
+	cbuffer = C.malloc(paramBufferSize)
+	C.memcpy(cbuffer, unsafe.Pointer(&p.Content[0]), paramBufferSize)
 	var cstate unsafe.Pointer
 	if len(s) != 0 {
 		cstate = unsafe.Pointer(&s[0])
 	}
 	o.opt = C.paddle_create_optimizer((*C.uchar)(&c[0]), C.int(len(c)),
-		C.paddle_element_type(p.ElementType), cbuffer, C.int(len(p.Content)/C.sizeof_float), (*C.char)(cstate), C.int(len(s)))
+		C.paddle_element_type(p.ElementType), cbuffer, C.int(paramBufferSize), (*C.char)(cstate), C.int(len(s)))
 	return o
 }
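Reviewer note on units: `len(p.Content)` counts raw bytes, so the new `paramBufferSize` (`len(p.Content) / C.sizeof_float`) is the number of float32 elements, which is what `paddle_create_optimizer` receives as the buffer length. A tiny self-contained sketch of that arithmetic, with `sizeofFloat` standing in for cgo's `C.sizeof_float`:

```go
package main

import "fmt"

// sizeofFloat stands in for cgo's C.sizeof_float (4 bytes for float32).
const sizeofFloat = 4

func main() {
	// Parameter.Content holds the raw bytes of the weights.
	content := make([]byte, 4096)
	numElements := len(content) / sizeofFloat
	fmt.Printf("%d bytes -> %d float32 elements\n", len(content), numElements)
	// Output: 4096 bytes -> 1024 float32 elements
}
```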
......
@@ -68,8 +70,8 @@ func (o *optimizer) GetWeights() []byte {
 func (o *optimizer) GetStates() []byte {
 	var cbuffer *C.char
-	cbuffer_len := C.paddle_optimizer_get_state(o.opt, &cbuffer)
-	return cArrayToSlice(unsafe.Pointer(cbuffer), int(cbuffer_len))
+	cbufferLen := C.paddle_optimizer_get_state(o.opt, &cbuffer)
+	return cArrayToSlice(unsafe.Pointer(cbuffer), int(cbufferLen))
 }

 func (o *optimizer) UpdateParameter(g Gradient) error {
......
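Reviewer note: `cArrayToSlice` is used here but defined elsewhere in the package, so it does not appear in this diff. A minimal sketch of what such a helper typically looks like, assuming it copies the C buffer into Go-managed memory:

```go
// A minimal sketch, assuming cArrayToSlice copies a C buffer into a
// Go-owned slice; it would live in a cgo file (import "C") of this package.
func cArrayToSlice(p unsafe.Pointer, length int) []byte {
	if p == nil {
		return nil
	}
	// C.GoBytes allocates a fresh []byte and copies length bytes out of
	// the C buffer, so the C side stays free to release its memory.
	return C.GoBytes(p, C.int(length))
}
```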
......
@@ -22,7 +22,8 @@ DECLARE_string(save_dir);
 namespace paddle {
 NewRemoteParameterUpdater::NewRemoteParameterUpdater(
     const OptimizationConfig &config, const std::string pserverSpec)
-    : parameterClient_(-1),
+    : trainerConfig_(config),
+      parameterClient_(-1),
       newParameters_(nullptr),
       newGradients_(nullptr),
       pserverSpec_(pserverSpec) {}
......
@@ -51,7 +52,22 @@ void NewRemoteParameterUpdater::init(
   LOG(INFO) << "paddle_begin_init_params start";
   for (int i = 0; i < parameterSize(); ++i) {
     auto paramConfig = parameters_[i]->getConfig();
-    std::string bytes = paramConfig.SerializeAsString();
+    LOG(INFO) << "old param config: " << paramConfig.DebugString();
+    // FIXME(typhoonzero): convert old paramConfig to optimizerConfig
+    OptimizerConfig optimizeConfigV2;
+    auto sgdConfigV2 = optimizeConfigV2.mutable_sgd();
+    sgdConfigV2->set_momentum(paramConfig.momentum());
+    sgdConfigV2->set_decay(paramConfig.decay_rate());
+    optimizeConfigV2.set_lr_policy(paddle::OptimizerConfig::Const);
+    auto constlr = optimizeConfigV2.mutable_const_lr();
+    constlr->set_learning_rate(paramConfig.learning_rate());
+    if (trainerConfig_.algorithm() == "sgd") {
+      optimizeConfigV2.set_optimizer(paddle::OptimizerConfig::SGD);
+      // FIXME: config all algorithms
+    } else {
+      optimizeConfigV2.set_optimizer(paddle::OptimizerConfig::SGD);
+    }
+    std::string bytes = optimizeConfigV2.SerializeAsString();
     const char *array = bytes.data();
     int size = (int)bytes.size();
     paddle_init_param(
......
@@ -83,4 +99,4 @@ void NewRemoteParameterUpdater::finishBatch(real cost) {
 void NewRemoteParameterUpdater::startPass() {}

 bool NewRemoteParameterUpdater::finishPass() { return true; }
-}
+}  // namespace paddle
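Reviewer note: the init hunk above maps the old per-parameter ParameterConfig onto the new OptimizerConfig protobuf before shipping the serialized bytes to the pserver via `paddle_init_param`. A hedged mirror of that field mapping in plain Go, with ordinary structs standing in for the generated protobuf types (hypothetical; the real code uses the C++ protobuf API shown above):

```go
// Plain-Go stand-ins for the OptimizerConfig protobuf messages; these
// struct types are hypothetical and exist only to show the field mapping.
type SGDConfig struct {
	Momentum float64
	Decay    float64
}

type ConstLR struct {
	LearningRate float64
}

type OptimizerConfigV2 struct {
	Optimizer string // only "SGD" is mapped so far (see the FIXME above)
	LRPolicy  string // always "Const" in this conversion
	SGD       SGDConfig
	ConstLR   ConstLR
}

// fromOldParamConfig mirrors the conversion in NewRemoteParameterUpdater::init:
// momentum and decay_rate feed the sgd section, learning_rate feeds const_lr.
func fromOldParamConfig(momentum, decayRate, learningRate float64) OptimizerConfigV2 {
	return OptimizerConfigV2{
		Optimizer: "SGD",
		LRPolicy:  "Const",
		SGD:       SGDConfig{Momentum: momentum, Decay: decayRate},
		ConstLR:   ConstLR{LearningRate: learningRate},
	}
}
```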
......
@@ -16,6 +16,7 @@ limitations under the License. */
 #include <functional>
 #include <thread>
+#include "OptimizerConfig.pb.h"
 #include "ParameterUpdater.h"
 #include "libpaddle_pserver_cclient.h"
 #include "paddle/pserver/ParameterClient2.h"
......
@@ -101,6 +102,7 @@ private:
   }

 protected:
+  const OptimizationConfig& trainerConfig_;
   /// internal parameter client object for exchanging data with pserver
   paddle_pserver_client parameterClient_;
   /// the parameters for new pserver client
......
......
@@ -66,6 +66,8 @@ class Optimizer(object):
         if use_sparse_remote_updater:
             gradient_machine.prefetch(in_args)
             parameter_updater.getParametersRemote()
+        :param pserver_spec: pserver location, eg: localhost:3000
+        :return: parameter_updater
         """
         if is_local:
......
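Reviewer note: `pserver_spec` travels from this Python API down to the `pserverSpec` constructor argument of NewRemoteParameterUpdater above. The diff only documents the `host:port` form; a short Go sketch, assuming (not confirmed by this diff) that a comma-separated list may name several pservers:

```go
package main

import (
	"fmt"
	"strings"
)

// splitPserverSpec splits a spec such as "localhost:3000" (or, by
// assumption, "localhost:3000,localhost:3001") into addresses.
func splitPserverSpec(spec string) []string {
	addrs := strings.Split(spec, ",")
	for i := range addrs {
		addrs[i] = strings.TrimSpace(addrs[i])
	}
	return addrs
}

func main() {
	fmt.Println(splitPserverSpec("localhost:3000, localhost:3001"))
	// Output: [localhost:3000 localhost:3001]
}
```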
......
@@ -41,6 +41,7 @@ class SGD(object):
     :type parameters: paddle.v2.parameters.Parameters
     :param extra_layers: Some layers in the neural network graph are not
                          in the path of cost layer.
+    :param pserver_spec: pserver location, eg: localhost:3000
     :type extra_layers: paddle.v2.config_base.Layer
     """
......