提交 8b4cbcfc 编写于 作者: Y Yu Yang

Start doing mnist_train_api

上级 06944ee1
import py_paddle.swig_paddle as api import py_paddle.swig_paddle as api
from paddle.trainer.config_parser import parse_config import paddle.trainer.config_parser
import numpy as np
def init_parameter(network):
assert isinstance(network, api.GradientMachine)
for each_param in network.getParameters():
assert isinstance(each_param, api.Parameter)
array = each_param.getBuf(api.PARAMETER_VALUE).toNumpyArrayInplace()
assert isinstance(array, np.ndarray)
for i in xrange(len(array)):
array[i] = np.random.uniform(-1.0, 1.0)
def main(): def main():
api.initPaddle("-use_gpu=false", "-trainer_count=4") # use 4 cpu cores api.initPaddle("-use_gpu=false", "-trainer_count=4") # use 4 cpu cores
config = parse_config('simple_mnist_network.py', '') config = paddle.trainer.config_parser.parse_config(
m = api.GradientMachine.createFromConfigProto(config.model_config) 'simple_mnist_network.py', '')
opt_config = api.OptimizationConfig.createFromProto(config.opt_config)
_temp_optimizer_ = api.ParameterOptimizer.create(opt_config)
enable_types = _temp_optimizer_.getParameterTypes()
m = api.GradientMachine.createFromConfigProto(
config.model_config, api.CREATE_MODE_NORMAL, enable_types)
assert isinstance(m, api.GradientMachine)
init_parameter(network=m)
updater = api.ParameterUpdater.createLocalUpdater(opt_config)
assert isinstance(updater, api.ParameterUpdater)
updater.init(m)
updater.startPass()
if __name__ == '__main__': if __name__ == '__main__':
......
...@@ -5,6 +5,7 @@ set(API_SOURCES ...@@ -5,6 +5,7 @@ set(API_SOURCES
Matrix.cpp Matrix.cpp
Parameter.cpp Parameter.cpp
ParameterOptimizer.cpp ParameterOptimizer.cpp
ParameterUpdater.cpp
SequenceGenerator.cpp SequenceGenerator.cpp
Trainer.cpp Trainer.cpp
Util.cpp Util.cpp
......
...@@ -174,6 +174,7 @@ namespace std { ...@@ -174,6 +174,7 @@ namespace std {
%newobject Parameter::getConfig; %newobject Parameter::getConfig;
%newobject ParameterOptimizer::create; %newobject ParameterOptimizer::create;
%newobject ParameterOptimizer::needSpecialTraversal; %newobject ParameterOptimizer::needSpecialTraversal;
%newobject ParameterUpdater::createLocalUpdater;
%feature("director") UpdateCallback; %feature("director") UpdateCallback;
%feature("autodoc", 1); // To generate method stub, for code hint in ide %feature("autodoc", 1); // To generate method stub, for code hint in ide
...@@ -193,4 +194,4 @@ namespace std { ...@@ -193,4 +194,4 @@ namespace std {
%ignore OptimizationConfigPrivate; %ignore OptimizationConfigPrivate;
%ignore ParameterTraverseCallbackPrivate; %ignore ParameterTraverseCallbackPrivate;
%include "utils/GlobalConstants.h" %include "utils/GlobalConstants.h"
%include "api/PaddleAPI.h" %include "api/PaddleAPI.h"
\ No newline at end of file
...@@ -519,6 +519,7 @@ private: ...@@ -519,6 +519,7 @@ private:
friend class TrainerConfig; friend class TrainerConfig;
friend class ParameterOptimizer; friend class ParameterOptimizer;
friend class ParameterUpdater;
friend class Trainer; friend class Trainer;
}; };
...@@ -557,6 +558,7 @@ private: ...@@ -557,6 +558,7 @@ private:
ParameterPrivate* m; ParameterPrivate* m;
friend class UpdateCallbackWrapper; friend class UpdateCallbackWrapper;
friend class GradientMachine; friend class GradientMachine;
friend class ParameterUpdater;
}; };
struct ModelConfigPrivate; struct ModelConfigPrivate;
...@@ -772,6 +774,24 @@ private: ...@@ -772,6 +774,24 @@ private:
// Not to use c++ 11 init-list, so we use static var as function default arg. // Not to use c++ 11 init-list, so we use static var as function default arg.
static std::vector<int> defaultParamTypes; static std::vector<int> defaultParamTypes;
friend class Trainer; friend class Trainer;
friend class ParameterUpdater;
};
struct ParameterUpdaterPrivate;
class ParameterUpdater {
private:
ParameterUpdater();
public:
static ParameterUpdater* createLocalUpdater(OptimizationConfig* config);
~ParameterUpdater();
void init(const GradientMachine& gm);
void startPass();
private:
ParameterUpdaterPrivate* m;
}; };
struct TrainerPrivate; struct TrainerPrivate;
......
...@@ -11,11 +11,13 @@ distributed under the License is distributed on an "AS IS" BASIS, ...@@ -11,11 +11,13 @@ distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#pragma once
#include <memory>
#include "PaddleAPI.h"
#include "paddle/gserver/gradientmachines/GradientMachine.h" #include "paddle/gserver/gradientmachines/GradientMachine.h"
#include "paddle/trainer/TrainerConfigHelper.h" #include "paddle/trainer/TrainerConfigHelper.h"
#pragma once #include "paddle/parameter/ParameterUpdaterBase.h"
struct GradientMachinePrivate { struct GradientMachinePrivate {
std::shared_ptr<paddle::GradientMachine> machine; std::shared_ptr<paddle::GradientMachine> machine;
...@@ -65,3 +67,24 @@ struct ArgumentsPrivate { ...@@ -65,3 +67,24 @@ struct ArgumentsPrivate {
return *(std::shared_ptr<T>*)(rawPtr); return *(std::shared_ptr<T>*)(rawPtr);
} }
}; };
struct ParameterUpdaterPrivate {
std::unique_ptr<paddle::ParameterUpdater> updater;
};
struct ParameterPrivate {
std::shared_ptr<paddle::Parameter> sharedPtr;
paddle::Parameter* rawPtr; // rawPtr only used in ParameterUpdater,
// in other situation sharedPtr should
// contains value.
ParameterPrivate() : sharedPtr(nullptr), rawPtr(nullptr) {}
paddle::Parameter* getPtr() {
if (sharedPtr) {
return sharedPtr.get();
} else {
return rawPtr;
}
}
};
...@@ -14,21 +14,7 @@ limitations under the License. */ ...@@ -14,21 +14,7 @@ limitations under the License. */
#include "paddle/parameter/Parameter.h" #include "paddle/parameter/Parameter.h"
#include "PaddleAPI.h" #include "PaddleAPI.h"
#include "PaddleAPIPrivate.h"
struct ParameterPrivate {
std::shared_ptr<paddle::Parameter> sharedPtr;
paddle::Parameter* rawPtr;
ParameterPrivate() : sharedPtr(nullptr), rawPtr(nullptr) {}
paddle::Parameter* getPtr() {
if (sharedPtr) {
return sharedPtr.get();
} else {
return rawPtr;
}
}
};
Parameter::Parameter() : m(new ParameterPrivate()) {} Parameter::Parameter() : m(new ParameterPrivate()) {}
......
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "PaddleAPI.h"
#include "PaddleAPIPrivate.h"
#include "paddle/trainer/ThreadParameterUpdater.h"
ParameterUpdater::ParameterUpdater() : m(new ParameterUpdaterPrivate()) {}
ParameterUpdater *ParameterUpdater::createLocalUpdater(
OptimizationConfig *config) {
auto param = new ParameterUpdater();
param->m->updater.reset(new paddle::SgdThreadUpdater(config->m->getConfig()));
return param;
}
ParameterUpdater::~ParameterUpdater() { delete m; }
void ParameterUpdater::init(const GradientMachine &gm) {
m->updater->init(gm.m->machine->getParameters());
}
void ParameterUpdater::startPass() { m->updater->startPass(); }
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册