diff --git a/demo/mnist/api_train.py b/demo/mnist/api_train.py
index 6abb5d4e562ee9d24a1f914757f3b8e4a3e5cb12..5d4ef90f10d3d4faeb43f61a8c20862c2f8dbbd1 100644
--- a/demo/mnist/api_train.py
+++ b/demo/mnist/api_train.py
@@ -1,11 +1,36 @@
 import py_paddle.swig_paddle as api
-from paddle.trainer.config_parser import parse_config
+import paddle.trainer.config_parser
+import numpy as np
+
+
+def init_parameter(network):
+    assert isinstance(network, api.GradientMachine)
+    for each_param in network.getParameters():
+        assert isinstance(each_param, api.Parameter)
+        array = each_param.getBuf(api.PARAMETER_VALUE).toNumpyArrayInplace()
+        assert isinstance(array, np.ndarray)
+        for i in xrange(len(array)):
+            array[i] = np.random.uniform(-1.0, 1.0)
 
 
 def main():
     api.initPaddle("-use_gpu=false", "-trainer_count=4")  # use 4 cpu cores
-    config = parse_config('simple_mnist_network.py', '')
-    m = api.GradientMachine.createFromConfigProto(config.model_config)
+    config = paddle.trainer.config_parser.parse_config(
+        'simple_mnist_network.py', '')
+
+    opt_config = api.OptimizationConfig.createFromProto(config.opt_config)
+    _temp_optimizer_ = api.ParameterOptimizer.create(opt_config)
+    enable_types = _temp_optimizer_.getParameterTypes()
+
+    m = api.GradientMachine.createFromConfigProto(
+        config.model_config, api.CREATE_MODE_NORMAL, enable_types)
+    assert isinstance(m, api.GradientMachine)
+    init_parameter(network=m)
+
+    updater = api.ParameterUpdater.createLocalUpdater(opt_config)
+    assert isinstance(updater, api.ParameterUpdater)
+    updater.init(m)
+    updater.startPass()
 
 
 if __name__ == '__main__':
diff --git a/paddle/api/CMakeLists.txt b/paddle/api/CMakeLists.txt
index 6ad1d79e59b11b2c1f7aacf22d13347b3fd8e0e2..39fe43556595ccfccb6190314ff43d323dc5c2e9 100644
--- a/paddle/api/CMakeLists.txt
+++ b/paddle/api/CMakeLists.txt
@@ -5,6 +5,7 @@ set(API_SOURCES
     Matrix.cpp
     Parameter.cpp
     ParameterOptimizer.cpp
+    ParameterUpdater.cpp
     SequenceGenerator.cpp
     Trainer.cpp
     Util.cpp
diff --git a/paddle/api/Paddle.swig b/paddle/api/Paddle.swig
index 9194a6371be9e00c037967464ee2b63c1e4f6192..b0fa8beb166b3438d2b6cbf7afd46791979f41bb 100644
--- a/paddle/api/Paddle.swig
+++ b/paddle/api/Paddle.swig
@@ -174,6 +174,7 @@ namespace std {
 %newobject Parameter::getConfig;
 %newobject ParameterOptimizer::create;
 %newobject ParameterOptimizer::needSpecialTraversal;
+%newobject ParameterUpdater::createLocalUpdater;
 
 %feature("director") UpdateCallback;
 %feature("autodoc", 1); // To generate method stub, for code hint in ide
@@ -193,4 +194,4 @@ namespace std {
 %ignore OptimizationConfigPrivate;
 %ignore ParameterTraverseCallbackPrivate;
 %include "utils/GlobalConstants.h"
-%include "api/PaddleAPI.h"
\ No newline at end of file
+%include "api/PaddleAPI.h"
diff --git a/paddle/api/PaddleAPI.h b/paddle/api/PaddleAPI.h
index 84a66719c33678fc4aeb038bb81a6b7c5d0c93fb..bd413eb1e9d9a945965bdf6767da82b4d631bbb5 100644
--- a/paddle/api/PaddleAPI.h
+++ b/paddle/api/PaddleAPI.h
@@ -519,6 +519,7 @@ private:
 
   friend class TrainerConfig;
   friend class ParameterOptimizer;
+  friend class ParameterUpdater;
   friend class Trainer;
 };
 
@@ -557,6 +558,7 @@ private:
   ParameterPrivate* m;
   friend class UpdateCallbackWrapper;
   friend class GradientMachine;
+  friend class ParameterUpdater;
 };
 
 struct ModelConfigPrivate;
@@ -772,6 +774,24 @@ private:
   // Not to use c++ 11 init-list, so we use static var as function default arg.
   static std::vector<int> defaultParamTypes;
   friend class Trainer;
+  friend class ParameterUpdater;
+};
+
+struct ParameterUpdaterPrivate;
+class ParameterUpdater {
+private:
+  ParameterUpdater();
+
+public:
+  static ParameterUpdater* createLocalUpdater(OptimizationConfig* config);
+  ~ParameterUpdater();
+
+  void init(const GradientMachine& gm);
+
+  void startPass();
+
+private:
+  ParameterUpdaterPrivate* m;
 };
 
 struct TrainerPrivate;
diff --git a/paddle/api/PaddleAPIPrivate.h b/paddle/api/PaddleAPIPrivate.h
index d2b56fc41c8aadb136ad6812f848e764e031073c..905668a62f24fbd8db4a8833d92df2fe43b6d0c1 100644
--- a/paddle/api/PaddleAPIPrivate.h
+++ b/paddle/api/PaddleAPIPrivate.h
@@ -11,11 +11,13 @@ distributed under the License is distributed on an "AS IS" BASIS,
 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */
-
+#pragma once
+#include <memory>
+#include "PaddleAPI.h"
 #include "paddle/gserver/gradientmachines/GradientMachine.h"
 #include "paddle/trainer/TrainerConfigHelper.h"
 
-#pragma once
+#include "paddle/parameter/ParameterUpdaterBase.h"
 
 struct GradientMachinePrivate {
   std::shared_ptr<paddle::GradientMachine> machine;
@@ -65,3 +67,24 @@ struct ArgumentsPrivate {
     return *(std::shared_ptr<T>*)(rawPtr);
   }
 };
+
+struct ParameterUpdaterPrivate {
+  std::unique_ptr<paddle::ParameterUpdater> updater;
+};
+
+struct ParameterPrivate {
+  std::shared_ptr<paddle::Parameter> sharedPtr;
+  paddle::Parameter* rawPtr;  // rawPtr only used in ParameterUpdater,
+                              // in other situation sharedPtr should
+                              // contains value.
+
+  ParameterPrivate() : sharedPtr(nullptr), rawPtr(nullptr) {}
+
+  paddle::Parameter* getPtr() {
+    if (sharedPtr) {
+      return sharedPtr.get();
+    } else {
+      return rawPtr;
+    }
+  }
+};
diff --git a/paddle/api/Parameter.cpp b/paddle/api/Parameter.cpp
index 4eed00a84a695f2c48ff93b33419ae2b3dd03768..41cf50043cc2b076dad49b9e772252b9243f39d6 100644
--- a/paddle/api/Parameter.cpp
+++ b/paddle/api/Parameter.cpp
@@ -14,21 +14,7 @@ limitations under the License. */
 
 #include "paddle/parameter/Parameter.h"
 #include "PaddleAPI.h"
-
-struct ParameterPrivate {
-  std::shared_ptr<paddle::Parameter> sharedPtr;
-  paddle::Parameter* rawPtr;
-
-  ParameterPrivate() : sharedPtr(nullptr), rawPtr(nullptr) {}
-
-  paddle::Parameter* getPtr() {
-    if (sharedPtr) {
-      return sharedPtr.get();
-    } else {
-      return rawPtr;
-    }
-  }
-};
+#include "PaddleAPIPrivate.h"
 
 Parameter::Parameter() : m(new ParameterPrivate()) {}
 
diff --git a/paddle/api/ParameterUpdater.cpp b/paddle/api/ParameterUpdater.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..af5b746a7cd0825dcb6839b64e464228713efbd5
--- /dev/null
+++ b/paddle/api/ParameterUpdater.cpp
@@ -0,0 +1,35 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "PaddleAPI.h"
+
+#include "PaddleAPIPrivate.h"
+#include "paddle/trainer/ThreadParameterUpdater.h"
+
+ParameterUpdater::ParameterUpdater() : m(new ParameterUpdaterPrivate()) {}
+
+ParameterUpdater *ParameterUpdater::createLocalUpdater(
+    OptimizationConfig *config) {
+  auto param = new ParameterUpdater();
+  param->m->updater.reset(new paddle::SgdThreadUpdater(config->m->getConfig()));
+  return param;
+}
+
+ParameterUpdater::~ParameterUpdater() { delete m; }
+
+void ParameterUpdater::init(const GradientMachine &gm) {
+  m->updater->init(gm.m->machine->getParameters());
+}
+
+void ParameterUpdater::startPass() { m->updater->startPass(); }
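
Note on the structure above: the commit moves ParameterPrivate out of Parameter.cpp into the shared PaddleAPIPrivate.h and declares the new ParameterUpdater a friend of OptimizationConfig, Parameter, and ParameterOptimizer, so that ParameterUpdater.cpp can reach the wrapped paddle objects through each class's private `m` pointer while the SWIG-facing PaddleAPI.h exposes only forward declarations. Below is a minimal, self-contained C++ sketch of that friend-plus-pimpl arrangement; all names in it (Widget, WidgetPrivate, Updater) are hypothetical stand-ins, not part of the PaddlePaddle API.

#include <iostream>

// Public header analogue (cf. PaddleAPI.h): clients and SWIG see only a
// forward declaration of the private struct, never the internals.
struct WidgetPrivate;

class Widget {
public:
  Widget();
  ~Widget();
  void print() const;

private:
  WidgetPrivate* m;      // pimpl pointer, owned by Widget
  friend class Updater;  // friends may reach through m (cf. ParameterUpdater)
};

// Shared private header analogue (cf. PaddleAPIPrivate.h): the full
// definition lives here so every friend sees the same layout.
struct WidgetPrivate {
  int state = 0;
};

// Implementation analogue (cf. Parameter.cpp / ParameterUpdater.cpp).
Widget::Widget() : m(new WidgetPrivate()) {}
Widget::~Widget() { delete m; }
void Widget::print() const { std::cout << "state=" << m->state << "\n"; }

class Updater {
public:
  // A friend can mutate the hidden state without widening Widget's public API.
  static void bump(Widget& w) { w.m->state += 1; }
};

int main() {
  Widget w;
  Updater::bump(w);
  w.print();  // prints "state=1"
  return 0;
}

Keeping a plain raw pointer such as `m` in the public header (rather than a smart pointer) mirrors the ParameterUpdaterPrivate* and ParameterPrivate* members in the diff: the private structs stay opaque to SWIG, and only the C++ side ever dereferences them.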