Commit 8b4cbcfc authored by Yu Yang

Start doing mnist_train_api

Parent 06944ee1
import py_paddle.swig_paddle as api
from paddle.trainer.config_parser import parse_config
import numpy as np


def init_parameter(network):
    assert isinstance(network, api.GradientMachine)
    for each_param in network.getParameters():
        assert isinstance(each_param, api.Parameter)
        array = each_param.getBuf(api.PARAMETER_VALUE).toNumpyArrayInplace()
        assert isinstance(array, np.ndarray)
        for i in xrange(len(array)):
            array[i] = np.random.uniform(-1.0, 1.0)


def main():
    api.initPaddle("-use_gpu=false", "-trainer_count=4")  # use 4 CPU cores
    config = parse_config('simple_mnist_network.py', '')

    # Create a throwaway optimizer first, only to learn which parameter
    # types (value, gradient, momentum, ...) the machine must allocate.
    opt_config = api.OptimizationConfig.createFromProto(config.opt_config)
    _temp_optimizer_ = api.ParameterOptimizer.create(opt_config)
    enable_types = _temp_optimizer_.getParameterTypes()

    m = api.GradientMachine.createFromConfigProto(
        config.model_config, api.CREATE_MODE_NORMAL, enable_types)
    assert isinstance(m, api.GradientMachine)
    init_parameter(network=m)

    updater = api.ParameterUpdater.createLocalUpdater(opt_config)
    assert isinstance(updater, api.ParameterUpdater)
    updater.init(m)

    updater.startPass()

if __name__ == '__main__':
......
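The demo wires the new pieces together: parse the trainer config, create a throwaway ParameterOptimizer purely to learn which parameter types the machine must allocate (enable_types), build the GradientMachine with those types, randomize its parameters in place, and attach a local updater. The config file simple_mnist_network.py is not part of this diff; below is a minimal sketch of what such a config could look like with the v1 trainer_config_helpers API. The layer sizes and optimizer settings are illustrative assumptions, not taken from the commit.

from paddle.trainer_config_helpers import *

# Illustrative values only; the real simple_mnist_network.py is not shown
# in this commit.
settings(learning_rate=1e-4, learning_method=AdamOptimizer(), batch_size=1000)

img = data_layer(name='pixel', size=784)
hidden1 = fc_layer(input=img, size=200)
hidden2 = fc_layer(input=hidden1, size=200)
prediction = fc_layer(input=hidden2, size=10, act=SoftmaxActivation())
outputs(classification_cost(input=prediction,
                            label=data_layer(name='label', size=10)))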
......@@ -5,6 +5,7 @@ set(API_SOURCES
    Matrix.cpp
    Parameter.cpp
    ParameterOptimizer.cpp
    ParameterUpdater.cpp
    SequenceGenerator.cpp
    Trainer.cpp
    Util.cpp
......
......@@ -174,6 +174,7 @@ namespace std {
%newobject Parameter::getConfig;
%newobject ParameterOptimizer::create;
%newobject ParameterOptimizer::needSpecialTraversal;
%newobject ParameterUpdater::createLocalUpdater;
%feature("director") UpdateCallback;
%feature("autodoc", 1); // To generate method stub, for code hint in ide
......
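The %newobject directive tells SWIG that ParameterUpdater::createLocalUpdater returns a pointer the caller owns, so the generated Python proxy takes ownership and destroys the C++ object when it is garbage-collected. A small sketch of the visible effect on the Python side (thisown is SWIG's standard ownership flag; opt_config is an OptimizationConfig as in the demo above):

import py_paddle.swig_paddle as api

updater = api.ParameterUpdater.createLocalUpdater(opt_config)
# Because of %newobject, the proxy owns the underlying C++ object, so
# ~ParameterUpdater runs when the Python object is collected.
assert updater.thisown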
......@@ -519,6 +519,7 @@ private:
  friend class TrainerConfig;
  friend class ParameterOptimizer;
  friend class ParameterUpdater;
  friend class Trainer;
};
......@@ -557,6 +558,7 @@ private:
  ParameterPrivate* m;

  friend class UpdateCallbackWrapper;
  friend class GradientMachine;
  friend class ParameterUpdater;
};
struct ModelConfigPrivate;
......@@ -772,6 +774,24 @@ private:
  // We avoid C++11 init-lists here, so a static var serves as the function's
  // default argument.
  static std::vector<int> defaultParamTypes;

  friend class Trainer;
  friend class ParameterUpdater;
};
struct ParameterUpdaterPrivate;
class ParameterUpdater {
private:
  ParameterUpdater();

public:
  static ParameterUpdater* createLocalUpdater(OptimizationConfig* config);
  ~ParameterUpdater();

  void init(const GradientMachine& gm);

  void startPass();

private:
  ParameterUpdaterPrivate* m;
};
struct TrainerPrivate;
......
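Note the shape of the new class: the constructor is private, so the static factory is the only way to obtain an updater, and all state hides behind a ParameterUpdaterPrivate pointer, keeping paddle internals out of the public header. From Python, direct construction is therefore unavailable; a sketch of the intended usage (opt_config as in the demo):

import py_paddle.swig_paddle as api

# api.ParameterUpdater() raises, since the private C++ constructor leaves
# the SWIG proxy without a usable __init__; the factory is the only entry.
updater = api.ParameterUpdater.createLocalUpdater(opt_config)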
......@@ -11,11 +11,13 @@ distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once

#include <memory>

#include "PaddleAPI.h"
#include "paddle/gserver/gradientmachines/GradientMachine.h"
#include "paddle/parameter/ParameterUpdaterBase.h"
#include "paddle/trainer/TrainerConfigHelper.h"
struct GradientMachinePrivate {
  std::shared_ptr<paddle::GradientMachine> machine;
......@@ -65,3 +67,24 @@ struct ArgumentsPrivate {
    return *(std::shared_ptr<T>*)(rawPtr);
  }
};
struct ParameterUpdaterPrivate {
  std::unique_ptr<paddle::ParameterUpdater> updater;
};

struct ParameterPrivate {
  std::shared_ptr<paddle::Parameter> sharedPtr;
  paddle::Parameter* rawPtr;  // rawPtr is only used inside ParameterUpdater;
                              // in every other situation sharedPtr should
                              // hold the value.

  ParameterPrivate() : sharedPtr(nullptr), rawPtr(nullptr) {}

  paddle::Parameter* getPtr() {
    if (sharedPtr) {
      return sharedPtr.get();
    } else {
      return rawPtr;
    }
  }
};
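The two-pointer design supports the demo's in-place initialization: a Parameter reached through GradientMachine::getParameters() travels via the shared_ptr branch of getPtr(), so a numpy view from toNumpyArrayInplace aliases memory the machine itself trains on, while the raw branch exists so ParameterUpdater can refer to parameters it does not own. A sketch of that aliasing, using only calls from the demo above:

import numpy as np

for param in m.getParameters():
    view = param.getBuf(api.PARAMETER_VALUE).toNumpyArrayInplace()
    view[:] = 0.0  # a view, not a copy: writes land directly in the
                   # buffer the GradientMachine trains with
    break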
......@@ -14,21 +14,7 @@ limitations under the License. */
#include "paddle/parameter/Parameter.h"
#include "PaddleAPI.h"
#include "PaddleAPIPrivate.h"
Parameter::Parameter() : m(new ParameterPrivate()) {}
......
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "PaddleAPI.h"
#include "PaddleAPIPrivate.h"
#include "paddle/trainer/ThreadParameterUpdater.h"
ParameterUpdater::ParameterUpdater() : m(new ParameterUpdaterPrivate()) {}

ParameterUpdater *ParameterUpdater::createLocalUpdater(
    OptimizationConfig *config) {
  auto param = new ParameterUpdater();
  param->m->updater.reset(new paddle::SgdThreadUpdater(config->m->getConfig()));
  return param;
}

ParameterUpdater::~ParameterUpdater() { delete m; }

void ParameterUpdater::init(const GradientMachine &gm) {
  m->updater->init(gm.m->machine->getParameters());
}

void ParameterUpdater::startPass() { m->updater->startPass(); }
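With the implementation in place, the pieces connect end to end: createLocalUpdater wraps a paddle::SgdThreadUpdater, init binds the updater to the machine's parameters, and startPass does per-pass bookkeeping. The minimal driver sequence, restated from the demo with only the calls this commit actually exposes:

updater = api.ParameterUpdater.createLocalUpdater(opt_config)  # SgdThreadUpdater inside
updater.init(m)      # hands over the GradientMachine's parameters
updater.startPass()  # per-pass setup; batch-level calls are not exposed yet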